fix: cumsum add_constant bug fix (add dtype for np zeros) #3258

Open · wants to merge 1 commit into base: main
py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py (2 changes: 1 addition & 1 deletion)
@@ -370,7 +370,7 @@ def cumsum(
         )
     else:
         new_dims = tuple(data.shape)
-        zeros = np.zeros(new_dims)
+        zeros = np.zeros(new_dims, dtype=np.float32)
Collaborator

Should this dtype be dependent on the input dtype, or always float32?

Collaborator Author

Using np.float32 here works fine regardless of the input type.

However, if we trace the root cause of the error, we can see that this line uses the name truncate_double, while we pass truncate_long_and_double as an argument to torch.compile, as shown here.

Because of this, the truncate_long_and_double argument is not recognized at that point and is simply dropped, which leads to an error when the default float64 dtype of np.zeros is processed.

According to this section, torch_tensorrt.dynamo.compile prefers truncate_double as the input but can also handle truncate_long_and_double.
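
For context, a minimal sketch of the failing path, assuming the torch_tensorrt dynamo backend is installed and a CUDA device is available (the module, shapes, and option values are illustrative, not the exact example from the docs):

import torch
import torch_tensorrt  # noqa: F401  registers the "torch_tensorrt" compile backend

class CumsumModule(torch.nn.Module):
    def forward(self, x):
        return torch.cumsum(x, dim=1)

model = CumsumModule().eval().cuda()
x = torch.randn(2, 4, device="cuda")

# The deprecated key is not a recognized compilation setting in this path,
# so it is dropped and truncate_double stays False; the converter then
# fails on the float64 constant produced by np.zeros.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"truncate_long_and_double": True},
)
out = compiled(x)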

What would be the best way to fix this issue?

Collaborator

Change truncate_long_and_double to truncate_double in this example

"truncate_long_and_double": True,

and that should work, right?
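
A sketch of the updated options, reusing the module from the sketch above:

compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"truncate_double": True},  # replaces the deprecated truncate_long_and_double
)
out = compiled(x)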

Collaborator

truncate_long_and_double is deprecated, but it looks like it is not correctly handled if the user provides this argument in the torch.compile workflow. Can you add this check near

valid_attrs = {attr.name for attr in fields(settings)}

similar to https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/_compiler.py#L180-L185?
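
A rough sketch of what that check could look like; the helper name and exact call site are assumptions, mirroring the linked _compiler.py handling:

import warnings

def resolve_deprecated_truncate_flag(kwargs: dict) -> dict:
    # Hypothetical helper: remap the deprecated flag onto truncate_double
    # so it is not silently dropped when unknown keys are later filtered
    # against the CompilationSettings fields.
    if "truncate_long_and_double" in kwargs:
        if kwargs.get("truncate_double", False):
            raise ValueError(
                "Provide only truncate_double; truncate_long_and_double is deprecated."
            )
        warnings.warn(
            "truncate_long_and_double is deprecated in favor of truncate_double",
            DeprecationWarning,
        )
        kwargs["truncate_double"] = kwargs.pop("truncate_long_and_double")
    return kwargs

Calling this right before valid_attrs = {attr.name for attr in fields(settings)} would let the remapped key survive the filtering.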

         zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
 
     running_sum = loop.add_recurrence(zero_trttensor)