Skip to content

Commit

Permalink
Update & add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mathias-nillion committed Jun 7, 2024
1 parent 93e4ed7 commit 6739dcc
Showing 1 changed file with 44 additions and 36 deletions.
80 changes: 44 additions & 36 deletions tests/python-tests/test_model_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Model client unit tests"""

from collections import OrderedDict
import pytest
import torch
import numpy as np

import nada_algebra as na
from sklearn.linear_model import LinearRegression, LogisticRegression
from torch import nn
from nada_ai import ModelClient
from nada_ai.client import ModelClient
from nada_ai import BaseClient, TorchClient, SklearnClient
import py_nillion_client as nillion


Expand All @@ -18,28 +19,28 @@ def test_sklearn_1(self):

# Exporting untrained model should not be possible
with pytest.raises(AttributeError):
ModelClient.from_sklearn(lin_reg)
SklearnClient(lin_reg)

X = np.array([[1, 2, 3], [2, 3, 4]])
y = np.array([0, 1])

lin_reg_fit = lin_reg.fit(X, y)

ModelClient.from_sklearn(lin_reg_fit)
SklearnClient(lin_reg_fit)

def test_sklearn_2(self):
log_reg = LogisticRegression(fit_intercept=False)

# Exporting untrained model should not be possible
with pytest.raises(AttributeError):
ModelClient.from_sklearn(log_reg)
SklearnClient(log_reg)

X = np.array([[1, 2, 3], [2, 3, 4]])
y = np.array([0, 1])

log_reg_fit = log_reg.fit(X, y)

ModelClient.from_sklearn(log_reg_fit)
SklearnClient(log_reg_fit)

def test_sklearn_3(self):
log_reg = LogisticRegression(fit_intercept=False)
Expand All @@ -49,16 +50,15 @@ def test_sklearn_3(self):

log_reg_fit = log_reg.fit(X, y)

model_client = ModelClient.from_sklearn(log_reg_fit)
model_client = SklearnClient(log_reg_fit)

secrets = model_client.export_state_as_secrets("test_model")
secrets = model_client.export_state_as_secrets("test_model", na.SecretRational)

assert len(secrets.keys()) == 4
assert len(secrets.keys()) == 3

assert "test_model_coef_0_0" in secrets.keys()
assert "test_model_coef_0_1" in secrets.keys()
assert "test_model_coef_0_2" in secrets.keys()
assert "test_model_intercept_0" in secrets.keys()

def test_sklearn_4(self):
log_reg = LogisticRegression(fit_intercept=False)
Expand All @@ -68,33 +68,41 @@ def test_sklearn_4(self):

log_reg_fit = log_reg.fit(X, y)

model_client = ModelClient.from_sklearn(log_reg_fit)
model_client = SklearnClient(log_reg_fit)

model_client.export_state_as_secrets("test_model", nillion.SecretInteger)
with pytest.raises(NotImplementedError):
model_client.export_state_as_secrets("test_model", nillion.SecretInteger)

def test_sklearn_5(self):
log_reg = LogisticRegression(fit_intercept=False)
def test_base_client_1(self):
base_client = BaseClient({"some_value": 1})

X = np.array([[1, 2, 3], [2, 3, 4]])
y = np.array([0, 1])
secrets = base_client.export_state_as_secrets("test_model", na.Rational)

log_reg_fit = log_reg.fit(X, y)
assert list(secrets.keys()) == ["test_model_some_value_0"]

model_client = ModelClient(
log_reg_fit,
OrderedDict(
{"coef": log_reg_fit.coef_, "intercept": log_reg_fit.intercept_}
),
)
def test_custom_client_1(self):
class MyModelClient(ModelClient):
def __init__(self) -> None:
self.state_dict = {"some_value": [1, 2, 3]}

secrets = model_client.export_state_as_secrets("test_model")
model_client = MyModelClient()

assert len(secrets.keys()) == 4
secrets = model_client.export_state_as_secrets("test_model", na.Rational)

assert "test_model_coef_0_0" in secrets.keys()
assert "test_model_coef_0_1" in secrets.keys()
assert "test_model_coef_0_2" in secrets.keys()
assert "test_model_intercept_0" in secrets.keys()
assert list(sorted(secrets.keys())) == [
"test_model_some_value_0",
"test_model_some_value_1",
"test_model_some_value_2",
]

def test_custom_client_2(self):
class MyModelClient(ModelClient):
def __init__(self) -> None:
self.some_value = {"some_value": 1}

# Invalid model client: no state_dict defined
with pytest.raises(AttributeError):
MyModelClient()

def test_torch_1(self):
class TestModule(nn.Module):
Expand All @@ -110,7 +118,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

mod = TestModule()

ModelClient.from_torch(mod)
TorchClient(mod)

def test_torch_2(self):
class TestModule(nn.Module):
Expand All @@ -126,9 +134,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

mod = TestModule()

model_client = ModelClient.from_torch(mod)
model_client = TorchClient(mod)

secrets = model_client.export_state_as_secrets("test_model")
secrets = model_client.export_state_as_secrets("test_model", na.SecretRational)

assert len(secrets) == 14

Expand Down Expand Up @@ -166,9 +174,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

mod = TestModule()

model_client = ModelClient.from_torch(mod)
model_client = TorchClient(mod)

secrets = model_client.export_state_as_secrets("test_model")
secrets = model_client.export_state_as_secrets("test_model", na.SecretRational)

assert len(secrets) == 30

Expand Down Expand Up @@ -222,9 +230,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

mod = TestModule()

model_client = ModelClient(mod, mod.state_dict())
model_client = TorchClient(mod)

secrets = model_client.export_state_as_secrets("test_model")
secrets = model_client.export_state_as_secrets("test_model", na.SecretRational)

assert len(secrets) == 14

Expand Down

0 comments on commit 6739dcc

Please sign in to comment.