From bac88bd712964c67a7b9d72c48d0b2a6a61c834d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 7 Mar 2024 17:08:25 -0800 Subject: [PATCH] Use `Clip` to implement aten::clamp | fix(torchlib) --- onnxscript/function_libs/torch_lib/ops/core.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index be8dd62d0..b8f3db689 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1548,18 +1548,17 @@ def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = if min is None and max is None: return clamped - # If min is greater than max torch.clamp(..., min, max) - # sets all elements in input to the value of max. - # So this order is important. if min is not None: min_clamp = op.CastLike(min, self) - clamped = op.Max(clamped, min_clamp) + else: + min_clamp = None if max is not None: max_clamp = op.CastLike(max, self) - clamped = op.Min(clamped, max_clamp) + else: + max_clamp = None - return clamped + return op.Clip(self, min_clamp, max_clamp) @torch_op("aten::clamp_max", traceable=True)