-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add cpu function #21255
feat: add cpu function #21255
Changes from all commits
8178af7
6257d55
628b7da
d8ce0bc
265956b
577e193
b4b7165
1a298a4
7b8ae32
4d72bc0
db0c288
2f2ad38
44688cc
b86109b
32fc93f
44812e5
f568a40
d8e747e
1dcb9ed
2747f29
61d3859
bc18489
1e765d8
602cc21
6193216
311532b
bfa1299
aecd4f9
4befcf7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,8 +52,14 @@ | |
_quantile_helper, | ||
) | ||
|
||
import unittest | ||
import torch | ||
from unittest.mock import patch | ||
from ivy_test import helpers | ||
from ivy_test.helpers import CLASS_TREE, handle_frontend_methodtry: | ||
|
||
try: | ||
import torch | ||
import torch | ||
except ImportError: | ||
torch = SimpleNamespace() | ||
|
||
|
@@ -7831,6 +7837,55 @@ def test_torch_index_select( | |
) | ||
|
||
|
||
#cpu | ||
@handle_frontend_method | ||
class TestTorchInstanceToCPU(unittest.TestCase):( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why have you implemented the test in this fashion, any reason? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If not, stick to the way other test functions are implemented. |
||
class_tree=CLASS_TREE, | ||
init_tree="torch.tensor", | ||
method_name="ivy.to_device", | ||
dtype_and_x=helpers.dtype_and_values( | ||
available_dtypes=helpers.get_dtypes("float"), | ||
num_arrays=1, | ||
min_value=-1e04, | ||
max_value=1e04, | ||
allow_inf=False, | ||
), | ||
) | ||
def test_torch_instance_to_cpu( | ||
self, | ||
dtype_and_x, | ||
frontend, | ||
backend_fw, | ||
frontend_method_data, | ||
init_flags, | ||
method_flags, | ||
): | ||
input_dtype, x = dtype_and_x | ||
with patch("ivy_framework.current_framework_str", return_value="torch"), \ | ||
patch("ivy_framework.current_device_str", return_value="cpu"): | ||
instance = frontend.init_all_as_kwargs_np( | ||
input_dtypes=input_dtype, data=x[0] | ||
) | ||
|
||
result = frontend.frontend_method_data( | ||
instance, method_name="to_cpu", input_dtypes=input_dtype | ||
) | ||
|
||
self.assertTrue(torch.all(result.data.cpu() == instance.data.cpu())) | ||
|
||
with patch("ivy_framework.current_framework_str", return_value="numpy"): | ||
result = frontend.frontend_method_data( | ||
instance, method_name="to_cpu", input_dtypes=input_dtype | ||
) | ||
|
||
|
||
self.assertEqual(result, instance) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() | ||
|
||
Comment on lines
+7885
to
+7887
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think these lines are unnecessary. |
||
|
||
# int | ||
@handle_frontend_method( | ||
class_tree=CLASS_TREE, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use correct name as mentioned in open issues list here: #3612 or here in original functions here https://pytorch.org/docs/stable/generated/torch.Tensor.cpu.html#torch.Tensor.cpu.