From 2348884fda6637cba93fad8eec00469cf803cfa4 Mon Sep 17 00:00:00 2001 From: Mathias Leys Date: Fri, 7 Jun 2024 14:49:59 +0200 Subject: [PATCH] Rename base client to state client --- nada_ai/__init__.py | 2 +- nada_ai/client.py | 4 ++-- tests/python-tests/test_model_client.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/nada_ai/__init__.py b/nada_ai/__init__.py index 7e083c5..0cbc867 100644 --- a/nada_ai/__init__.py +++ b/nada_ai/__init__.py @@ -1 +1 @@ -from nada_ai.client import BaseClient, TorchClient, SklearnClient +from nada_ai.client import StateClient, TorchClient, SklearnClient diff --git a/nada_ai/client.py b/nada_ai/client.py index 4dced00..027c6d8 100644 --- a/nada_ai/client.py +++ b/nada_ai/client.py @@ -106,8 +106,8 @@ def __ensure_numpy(self, array_like: Any) -> np.ndarray: ) -class BaseClient(ModelClient): - """Base ModelClient for generic model states""" +class StateClient(ModelClient): + """ModelClient for generic model states""" def __init__(self, state_dict: Dict[str, Any]) -> None: """ diff --git a/tests/python-tests/test_model_client.py b/tests/python-tests/test_model_client.py index 768378d..7dd696c 100644 --- a/tests/python-tests/test_model_client.py +++ b/tests/python-tests/test_model_client.py @@ -8,7 +8,7 @@ from sklearn.linear_model import LinearRegression, LogisticRegression from torch import nn from nada_ai.client import ModelClient -from nada_ai import BaseClient, TorchClient, SklearnClient +from nada_ai import StateClient, TorchClient, SklearnClient import py_nillion_client as nillion @@ -73,10 +73,10 @@ def test_sklearn_4(self): with pytest.raises(NotImplementedError): model_client.export_state_as_secrets("test_model", nillion.SecretInteger) - def test_base_client_1(self): - base_client = BaseClient({"some_value": 1}) + def test_state_client_1(self): + state_client = StateClient({"some_value": 1}) - secrets = base_client.export_state_as_secrets("test_model", na.Rational) + secrets = state_client.export_state_as_secrets("test_model", na.Rational) assert list(secrets.keys()) == ["test_model_some_value_0"]