diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 14195d17bd3..e1f75ff4b33 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -1482,13 +1482,19 @@ public PtNDArray percentile(Number percentile, int[] axes) { /** {@inheritDoc} */ @Override public PtNDArray median() { - throw new UnsupportedOperationException("Not implemented"); + return median(new int[] {-1}); } /** {@inheritDoc} */ @Override public PtNDArray median(int[] axes) { - throw new UnsupportedOperationException("Not implemented"); + if (axes.length != 1) { + throw new UnsupportedOperationException( + "Not supporting zero or multi-dimension median"); + } + NDList result = JniUtils.median(this, axes[0], false); + result.get(1).close(); + return (PtNDArray) result.get(0); } /** {@inheritDoc} */ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 694387f33fc..aad38ae8f0c 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -875,6 +875,13 @@ public static PtNDArray min(PtNDArray ndArray, long dim, boolean keepDim) { PyTorchLibrary.LIB.torchMin(ndArray.getHandle(), dim, keepDim)); } + public static NDList median(PtNDArray ndArray, long dim, boolean keepDim) { + long[] handles = PyTorchLibrary.LIB.torchMedian(ndArray.getHandle(), dim, keepDim); + return new NDList( + new PtNDArray(ndArray.getManager(), handles[0]), + new PtNDArray(ndArray.getManager(), handles[1])); + } + public static PtNDArray mean(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle())); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 65f5e8a479b..63524ff60a4 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -238,6 +238,8 @@ native void torchIndexPut( native long torchMinimum(long self, long other); + native long[] torchMedian(long self, long dim, boolean keepDim); + native long torchMin(long handle); native long torchMin(long handle, long dim, boolean keepDim); diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc index c30d1e8ebbd..cfc4e97681a 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc @@ -213,6 +213,20 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMinimum( API_END_RETURN() } +JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMedian( + JNIEnv* env, jobject jthis, jlong jself, jlong jdim, jboolean keep_dim) { + API_BEGIN() + const auto* self_ptr = reinterpret_cast(jself); + const auto result = self_ptr->median(jdim, keep_dim); + const auto* value_ptr = new torch::Tensor(std::get<0>(result)); + const auto* indices_ptr = new torch::Tensor(std::get<1>(result)); + std::vector vect; + vect.push_back(reinterpret_cast(value_ptr)); + vect.push_back(reinterpret_cast(indices_ptr)); + return djl::utils::jni::GetLongArrayFromVec(env, vect); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAbs(JNIEnv* env, jobject jthis, jlong jhandle) { API_BEGIN() const auto* tensor_ptr = reinterpret_cast(jhandle); diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementComparisonOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementComparisonOpTest.java index 7d4da81c71f..1f5b215bcf0 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementComparisonOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementComparisonOpTest.java @@ -601,6 +601,16 @@ public void testMinimumNDArray() { } } + @Test + public void testMedian() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + NDArray array1 = manager.create(new float[] {1, 3, 2, 5, 4}); + Assert.assertEquals(array1.median(), manager.create(3.0f)); + array1 = manager.create(new float[] {1, 3, 2, 5, 4, 8}, new Shape(2, 3)); + Assert.assertEquals(array1.median(new int[] {1}), manager.create(new float[] {2, 5})); + } + } + @Test public void testWhere() { try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {