diff --git a/nada_ai/client/clients/__init__.py b/nada_ai/client/clients/__init__.py index c6b4741..934d3b4 100644 --- a/nada_ai/client/clients/__init__.py +++ b/nada_ai/client/clients/__init__.py @@ -1,5 +1,4 @@ -from .sklearn_client import SklearnClient -from .state_client import StateClient -from .torch_client import TorchClient +from .sklearn import SklearnClient +from .torch import TorchClient -__all__ = ["SklearnClient", "StateClient", "TorchClient"] +__all__ = ["SklearnClient", "TorchClient"] diff --git a/nada_ai/client/clients/sklearn_client.py b/nada_ai/client/clients/sklearn.py similarity index 100% rename from nada_ai/client/clients/sklearn_client.py rename to nada_ai/client/clients/sklearn.py diff --git a/nada_ai/client/clients/state_client.py b/nada_ai/client/clients/state_client.py deleted file mode 100644 index 24f0fad..0000000 --- a/nada_ai/client/clients/state_client.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Generic state client implementation""" - -from typing import Any, Dict -from nada_ai.client.model_client import ModelClient - -__all__ = ["StateClient"] - - -class StateClient(ModelClient): - """ModelClient for generic model states""" - - def __init__(self, state_dict: Dict[str, Any]) -> None: - """ - Client initialization. - This client accepts an arbitrary model state as input. - - Args: - state_dict (Dict[str, Any]): State dict. - """ - self.state_dict = state_dict diff --git a/nada_ai/client/clients/torch_client.py b/nada_ai/client/clients/torch.py similarity index 100% rename from nada_ai/client/clients/torch_client.py rename to nada_ai/client/clients/torch.py diff --git a/tests/python-tests/test_model_client.py b/tests/python-tests/test_model_client.py index e142a03..ce9f19b 100644 --- a/tests/python-tests/test_model_client.py +++ b/tests/python-tests/test_model_client.py @@ -7,7 +7,7 @@ import nada_algebra as na from sklearn.linear_model import LinearRegression, LogisticRegression from torch import nn -from nada_ai.client import ModelClient, StateClient, TorchClient, SklearnClient +from nada_ai.client import ModelClient, TorchClient, SklearnClient import py_nillion_client as nillion @@ -72,13 +72,6 @@ def test_sklearn_4(self): with pytest.raises(NotImplementedError): model_client.export_state_as_secrets("test_model", nillion.SecretInteger) - def test_state_client_1(self): - state_client = StateClient({"some_value": 1}) - - secrets = state_client.export_state_as_secrets("test_model", na.Rational) - - assert list(secrets.keys()) == ["test_model_some_value_0"] - def test_custom_client_1(self): class MyModelClient(ModelClient): def __init__(self) -> None: