-
Notifications
You must be signed in to change notification settings - Fork 350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: cumsum add_constant bug fix (add dtype for np zeros) #3258
base: main
Are you sure you want to change the base?
Conversation
@@ -370,7 +370,7 @@ def cumsum( | |||
) | |||
else: | |||
new_dims = tuple(data.shape) | |||
zeros = np.zeros(new_dims) | |||
zeros = np.zeros(new_dims, dtype=np.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this dtype be dependent on input dtype or always float32 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using np.float32
for the input works fine regardless of the input type.
However, if we trace the root cause of the error, we can see that in this line, the name truncate_double
is used, but we are passing truncate_long_and_double
as an argument to torch.compile
, as shown here.
Because of this, at this point, the truncate_long_and_double
argument is not handled and then removed, which leads to an error when trying to process the default type float64
of np.zeros
.
According to this section, torch_tensorrt.dynamo.compile
prefers truncate_double
as the input but can also handle truncate_long_and_double
.
What would be the best way to fix this issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change the truncate_long_and_double
to truncate_double
in this example
"truncate_long_and_double": True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
truncate_long_and_double
is deprecated. But it looks like it is not correctly handled if user provides this argument in torch.compile
workflow. Can you add this check in
TensorRT/py/torch_tensorrt/dynamo/utils.py
Line 483 in 6d40ff1
valid_attrs = {attr.name for attr in fields(settings)} |
Description
When compiling the
roberta-base
model from Hugging Face (https://huggingface.co/FacebookAI/roberta-base), aTypeError
occurs in thecumsum
operation. For static shape input, the default datatype ofnp.zeros(new_dims)
function isnp.float64
which is not handled properly by thecreate_constant
utility function.Fixes # (issue)
Reproduction Code:
Type of change
Checklist: