[Pytorch] add basic median support (#2701)
lanking520 authored Jul 11, 2023
1 parent 814a269 commit fc37be9
Showing 5 changed files with 41 additions and 2 deletions.
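
In short, the commit routes NDArray.median() and NDArray.median(int[] axes) through a new torchMedian JNI binding for the PyTorch engine, keeping only the values tensor and discarding the indices tensor. A minimal usage sketch, mirroring the new test below (the class name and the default NDManager.newBaseManager() setup are illustrative and not part of this commit; it assumes the PyTorch engine is the one resolved at runtime):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public final class MedianExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // median() delegates to median(new int[] {-1}), i.e. the median along the last axis.
            NDArray a = manager.create(new float[] {1, 3, 2, 5, 4});
            System.out.println(a.median()); // expected value: 3.0

            // Single-axis reduction; zero or multiple axes throw UnsupportedOperationException.
            NDArray b = manager.create(new float[] {1, 3, 2, 5, 4, 8}, new Shape(2, 3));
            System.out.println(b.median(new int[] {1})); // expected values: [2.0, 5.0]
        }
    }
}
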
@@ -1482,13 +1482,19 @@ public PtNDArray percentile(Number percentile, int[] axes) {
     /** {@inheritDoc} */
     @Override
     public PtNDArray median() {
-        throw new UnsupportedOperationException("Not implemented");
+        return median(new int[] {-1});
     }
 
     /** {@inheritDoc} */
     @Override
     public PtNDArray median(int[] axes) {
-        throw new UnsupportedOperationException("Not implemented");
+        if (axes.length != 1) {
+            throw new UnsupportedOperationException(
+                    "Not supporting zero or multi-dimension median");
+        }
+        NDList result = JniUtils.median(this, axes[0], false);
+        result.get(1).close();
+        return (PtNDArray) result.get(0);
     }
 
     /** {@inheritDoc} */
@@ -875,6 +875,13 @@ public static PtNDArray min(PtNDArray ndArray, long dim, boolean keepDim) {
                 PyTorchLibrary.LIB.torchMin(ndArray.getHandle(), dim, keepDim));
     }
 
+    public static NDList median(PtNDArray ndArray, long dim, boolean keepDim) {
+        long[] handles = PyTorchLibrary.LIB.torchMedian(ndArray.getHandle(), dim, keepDim);
+        return new NDList(
+                new PtNDArray(ndArray.getManager(), handles[0]),
+                new PtNDArray(ndArray.getManager(), handles[1]));
+    }
+
     public static PtNDArray mean(PtNDArray ndArray) {
         return new PtNDArray(
                 ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle()));
@@ -238,6 +238,8 @@ native void torchIndexPut(
 
     native long torchMinimum(long self, long other);
 
+    native long[] torchMedian(long self, long dim, boolean keepDim);
+
     native long torchMin(long handle);
 
     native long torchMin(long handle, long dim, boolean keepDim);
@@ -213,6 +213,20 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMinimum(
   API_END_RETURN()
 }
 
+JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMedian(
+    JNIEnv* env, jobject jthis, jlong jself, jlong jdim, jboolean keep_dim) {
+  API_BEGIN()
+  const auto* self_ptr = reinterpret_cast<torch::Tensor*>(jself);
+  const auto result = self_ptr->median(jdim, keep_dim);
+  const auto* value_ptr = new torch::Tensor(std::get<0>(result));
+  const auto* indices_ptr = new torch::Tensor(std::get<1>(result));
+  std::vector<uintptr_t> vect;
+  vect.push_back(reinterpret_cast<uintptr_t>(value_ptr));
+  vect.push_back(reinterpret_cast<uintptr_t>(indices_ptr));
+  return djl::utils::jni::GetLongArrayFromVec(env, vect);
+  API_END_RETURN()
+}
+
 JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAbs(JNIEnv* env, jobject jthis, jlong jhandle) {
   API_BEGIN()
   const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
@@ -601,6 +601,16 @@ public void testMinimumNDArray() {
         }
     }
 
+    @Test
+    public void testMedian() {
+        try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
+            NDArray array1 = manager.create(new float[] {1, 3, 2, 5, 4});
+            Assert.assertEquals(array1.median(), manager.create(3.0f));
+            array1 = manager.create(new float[] {1, 3, 2, 5, 4, 8}, new Shape(2, 3));
+            Assert.assertEquals(array1.median(new int[] {1}), manager.create(new float[] {2, 5}));
+        }
+    }
+
     @Test
     public void testWhere() {
         try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
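
One behavioral note, hedged because it comes from PyTorch's documentation for the per-dimension median rather than from this diff: when the reduced axis has an even number of elements, the lower of the two middle values is returned, not their average. A self-contained sketch of that expectation (illustrative class and values, not part of the commit):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public final class EvenMedianNote {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray even = manager.create(new float[] {1, 2, 3, 4});
            // Lower-median semantics: expected 2.0, not the interpolated 2.5.
            System.out.println(even.median());
        }
    }
}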
