-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
For double sided maxwell #21264
For double sided maxwell #21264
Conversation
Thanks for contributing to Ivy! 😊👏 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @stalemate1
uint32
is not supported by torch
and paddle
, kindly add those in the unsupported_dtypes
decorator.
Also, do make sure that the tests are being passing when you request a review.
Feel free to comment on the PR if you need to ask any question.
Thanks!
Hello @zaeemansari70! Today, it says more than 90 files are failing for the same function. Could you please help me understand where I might be going wrong. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @stalemate1
You are passing in a uint32
dtype which is neither supported by torch
nor paddle
.
Kindly fix this 🙂
Also do not worry about the CI tests, I'll test your PR locally before merging 👍
@handle_frontend_test( | ||
fn_tree="jax.random.double_sided_maxwell", | ||
dtype_key=helpers.dtype_and_values( | ||
available_dtypes=["uint32"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are specifying the dtype here, this should be valid
dtype and the dtypes that should be skipped should be added in the decorator above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and pass in valid
here, so that the decorator skips the dtypes which are not to be tested on. Rather than us trying to restrict those in the tests.
hi @zaeemansari70 I did try different things including using |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addresses your comments 🙂
Let me know if you have some questions, thanks! 🙂
@handle_frontend_test( | ||
fn_tree="jax.random.double_sided_maxwell", | ||
dtype_key=helpers.dtype_and_values( | ||
available_dtypes=["uint32"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and pass in valid
here, so that the decorator skips the dtypes which are not to be tested on. Rather than us trying to restrict those in the tests.
…or_double_sided_maxwell # Conflicts: # ivy/functional/frontends/jax/random.py
@stalemate1 |
…or_double_sided_maxwell # Conflicts: # ivy/functional/frontends/jax/random.py
sorry @Ookamice I did not notice that. I have corrected those files now. |
@stalemate1 PS: weirdly rademacher doesn't pass the tests on my end for some reason so I'm not sure why it is merged. Perhaps this may be the cause as to why pytorch test for double_sided_maxwell not working? |
@Ookamice but when I change it back to 0.5 rademacher is producing only one element, thus it is not working with double_sided_maxwell. Hence, this time I created an array of 0.5's so that it is compatible with any shape. I hope this is okay. |
@stalemate1 PS: make sure that when multiplying by 0.5, you get a float rather than a int, im not sure how the type promotion may carry through as your original array of 1s is a int |
@Ookamice |
@Ookamice Could this be merged now? |
@stalemate1 Secondly, I had a look at maxwell and plotted your output distribution compared to jax's native implementation. I plotted the output of both and it seems like there's a descrepency that should be fixed: PS: The code used to test this is: import ivy
import jax
from ivy.functional.frontends.jax.random import _get_seed
def maxwell_current(key, shape=None, dtype="float64"):
seed = _get_seed(key)
# generate uniform random numbers between 0 and 1
z = ivy.random_uniform(seed=seed, shape=shape, dtype=dtype)
# applying inverse transform sampling
x = (z**2) * ivy.exp(-(z**2) / 2)
return x
n = 1000000
import matplotlib.pyplot as plt
x = maxwell_current((123,123), (n,))
z = jax.random.maxwell(jax.numpy.array([123,123], dtype='uint32'), shape=(n,))
num_bins = int(ivy.sqrt(n))
n0, bins0, patches0 = plt.hist(x, bins=num_bins, histtype='step', color='blue')
n, bins, patches = plt.hist(z, bins=num_bins, histtype='step', color='red')
plt.show() |
I have trouble installing and running jax. But I had changed the code for maxwell. I was just waiting to submit this PR so that I could start a new PR for maxwell. The edited code is as follows,
and it is passing all the tests |
@stalemate1 Edit: just linking the issue regarding fixing this #23337 here for convivence (merged and resolved) |
@dash96 |
close #19483