diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index f50d2557a..576e32de4 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -70,3 +70,9 @@ def get_source_path(self): def get_lib_path(self): return self.libpath + + def set_lib_path(self, libpath): + self.libpath = libpath + + def set_src_path(self, srcpath): + self.srcpath = srcpath diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 2e9078727..39fc2e785 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -302,9 +302,11 @@ def update_runtime_module(self, rt_mod, srcpath=None, libpath=None): self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle self.torch_func = to_pytorch_func(rt_mod) if srcpath is not None: - self.srcpath = srcpath + assert self.lib_generator is not None, "lib_generator is not initialized" + self.lib_generator.set_src_path(srcpath) if libpath is not None: - self.libpath = libpath + assert self.lib_generator is not None, "lib_generator is not initialized" + self.lib_generator.set_lib_path(libpath) self.lib = ctypes.CDLL(libpath) self.lib.init()