-
Notifications
You must be signed in to change notification settings - Fork 658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Issue #2693 Implement PtNDArrayEx.multiBoxPrior with validation #2715
Issue #2693 Implement PtNDArrayEx.multiBoxPrior with validation #2715
Conversation
engines/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/engine/MxNDArrayExTest.java
Outdated
Show resolved
Hide resolved
engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/engine/PtNDArrayExTest.java
Outdated
Show resolved
Hide resolved
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
Outdated
Show resolved
Hide resolved
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
Outdated
Show resolved
Hide resolved
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
Outdated
Show resolved
Hide resolved
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
Outdated
Show resolved
Hide resolved
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
Outdated
Show resolved
Hide resolved
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
Show resolved
Hide resolved
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
Outdated
Show resolved
Hide resolved
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
Outdated
Show resolved
Hide resolved
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
Outdated
Show resolved
Hide resolved
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
Outdated
Show resolved
Hide resolved
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
Outdated
Show resolved
Hide resolved
|
||
NDManager ndManager = array.getManager().getParentManager(); | ||
|
||
Float stepX = steps.get(1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Float stepX = steps.get(1); | |
float stepX = steps.get(1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thankyou. This has been updated.
NDManager ndManager = array.getManager().getParentManager(); | ||
|
||
Float stepX = steps.get(1); | ||
Float stepY = steps.get(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Float stepY = steps.get(0); | |
float stepY = steps.get(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thankyou. This has been updated.
@@ -694,7 +694,54 @@ public NDList multiBoxPrior( | |||
List<Float> steps, | |||
List<Float> offsets, | |||
boolean clip) { | |||
throw new UnsupportedOperationException("Not implemented"); | |||
|
|||
NDManager ndManager = array.getManager().getParentManager(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NDManager ndManager = array.getManager().getParentManager(); | |
NDManager ndManager = array.getManager(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thankyou. This has been updated.
} | ||
} | ||
} | ||
NDArray ndArray = ndManager.create(out).expandDims(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't look into your algorithm, but seems the final output is different from MXNet. Can you debug why it's different?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is some discussion of this here:
#2693
I'd love to debug it on the mxnet side, but this is non-trivial. I'm working on a method to do so, but would appreciate some pointers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @frankfliu I found the bug and put in a fix.
For me the first element in the result is -0.09902344
in my implementation and mxnet.
From your perspective is this now returning the correct values?
Codecov ReportPatch coverage:
❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more. Additional details and impacted files@@ Coverage Diff @@
## master #2715 +/- ##
============================================
+ Coverage 72.08% 72.17% +0.08%
- Complexity 5126 7034 +1908
============================================
Files 473 698 +225
Lines 21970 31320 +9350
Branches 2351 3234 +883
============================================
+ Hits 15838 22606 +6768
- Misses 4925 7171 +2246
- Partials 1207 1543 +336
☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for your contribution.
Description
The big goal here is to implement
PtNDArrayEx.multiBoxPrior()
(currently throwsNotImplementException
) so that you can runPikachuTraining
with Pytorch and so you can run SSD Training on M1 Macs.This is discussed in: #2693
The big idea is to prove this works by writing matching unit tests for
MxNDAarrayEx.multiBoxPrior()
and comparing the results.The implementation for this is derived from:
https://github.com/apache/mxnet/blob/master/src/operator/contrib/multibox_prior.cc
and
https://github.com/apache/mxnet/blob/a720b15b5fa011a9610dfaeabf0792443c6abec5/src/operator/contrib/multibox_prior-inl.h#L116
This implementation is based on implementing the unit test on the C++ side (see unit test in pull request there) and comparing the steps internally.