Skip to content

Commit

Permalink
fix: cumsum add_constant bug fix (add dtype for np zeros)
Browse files Browse the repository at this point in the history
  • Loading branch information
chohk88 committed Oct 22, 2024
1 parent 40193dc commit 123e2a8
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def cumsum(
)
else:
new_dims = tuple(data.shape)
zeros = np.zeros(new_dims)
zeros = np.zeros(new_dims, dtype=np.float32)
zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")

running_sum = loop.add_recurrence(zero_trttensor)
Expand Down

0 comments on commit 123e2a8

Please sign in to comment.