From 123e2a8666f205f816948b58f24461d78088ba68 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Wed, 23 Oct 2024 00:33:43 +0900 Subject: [PATCH] fix: cumsum add_constant bug fix (add dtype for np zeros) --- py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index b58435b489..3274d78c2b 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -370,7 +370,7 @@ def cumsum( ) else: new_dims = tuple(data.shape) - zeros = np.zeros(new_dims) + zeros = np.zeros(new_dims, dtype=np.float32) zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value") running_sum = loop.add_recurrence(zero_trttensor)