From f1beea97fca1b364d44f183c94846556829a98d3 Mon Sep 17 00:00:00 2001 From: Dan Date: Thu, 2 May 2024 22:33:09 +0000 Subject: [PATCH] enforce irpa extension on save --- shark_turbine/aot/params.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/shark_turbine/aot/params.py b/shark_turbine/aot/params.py index 54858648..fadd6c86 100644 --- a/shark_turbine/aot/params.py +++ b/shark_turbine/aot/params.py @@ -258,7 +258,10 @@ def __init__(self): def save(self, file_path: Union[str, Path]): """Saves the archive.""" - self._index.create_archive_file(str(file_path)) + str_file_path = str(file_path) + if not str_file_path.endswith(".irpa"): + file_path = str_file_path + ".irpa" + self._index.create_archive_file(str_file_path) def add_tensor(self, name: str, tensor: torch.Tensor): """Adds an named tensor to the archive."""