Skip to content

Commit

Permalink
[TensorRT EP] Update ORT kernel output with TRT DDS int64 output for …
Browse files Browse the repository at this point in the history
…TRT 10 (#20738)

TRT 10 now natively supports int64 tensor, so needs to updating the code
where binding the ORT kernel output with DDS int64 output.
  • Loading branch information
chilo-ms authored May 21, 2024
1 parent 8a98874 commit df01e0d
Showing 1 changed file with 4 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1062,8 +1062,12 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
#if NV_TENSORRT_MAJOR >= 10
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
#else
// The allocation buffer holds the int32 output data since TRT doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output.
CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int32_t, int64_t)
#endif
// The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output.
CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double)
default: {
Expand Down

0 comments on commit df01e0d

Please sign in to comment.