diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index 88e745286..0a477c661 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -682,7 +682,7 @@ def __init__(self, tensor): @classmethod def from_numpy(cls, array): tensor = torch.from_numpy(array) - return cls(array) + return cls(tensor) def shape(self) -> List[int]: return list(self.tensor.shape)