Skip to content

Commit

Permalink
Merge pull request #12 from NillionNetwork/feature/major-refactor
Browse files Browse the repository at this point in the history
Modularize tech design
  • Loading branch information
mathias-nillion authored Jun 12, 2024
2 parents 2f57586 + e4bce1f commit a6191e5
Show file tree
Hide file tree
Showing 20 changed files with 423 additions and 391 deletions.
1 change: 0 additions & 1 deletion nada_ai/__init__.py

This file was deleted.

2 changes: 2 additions & 0 deletions nada_ai/client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .model_client import ModelClient
from .clients import *
4 changes: 4 additions & 0 deletions nada_ai/client/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .sklearn import SklearnClient
from .torch import TorchClient

__all__ = ["SklearnClient", "TorchClient"]
29 changes: 29 additions & 0 deletions nada_ai/client/clients/sklearn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Scikit-learn client implementation"""

import sklearn
from nada_ai.client.model_client import ModelClient
from nada_ai.typing import LinearModel

__all__ = ["SklearnClient"]


class SklearnClient(ModelClient):
"""ModelClient for Scikit-learn models"""

def __init__(self, model: sklearn.base.BaseEstimator) -> None:
"""
Client initialization.
Args:
model (sklearn.base.BaseEstimator): Sklearn model object to wrap around.
"""
if isinstance(model, LinearModel):
state_dict = {"coef": model.coef_}
if model.fit_intercept is True:
state_dict.update({"intercept": model.intercept_})
else:
raise NotImplementedError(
f"Instantiating ModelClient from Sklearn model type `{type(model).__name__}` is not yet implemented."
)

self.state_dict = state_dict
19 changes: 19 additions & 0 deletions nada_ai/client/clients/torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""PyTorch client implementation"""

from torch import nn
from nada_ai.client.model_client import ModelClient

__all__ = ["TorchClient"]


class TorchClient(ModelClient):
"""ModelClient for PyTorch models"""

def __init__(self, model: nn.Module) -> None:
"""
Client initialization.
Args:
model (nn.Module): PyTorch model object to wrap around.
"""
self.state_dict = model.state_dict()
86 changes: 9 additions & 77 deletions nada_ai/client.py → nada_ai/client/model_client.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,13 @@
"""
This module provides functions to work with the Python Nillion Client
"""
"""Model client implementation"""

from abc import ABC, ABCMeta
import nada_algebra as na
import nada_algebra.client as na_client
from typing import Any, Dict, Sequence, Union
from typing import Any, Dict, Sequence
from nada_ai.typing import NillionType

from sklearn.linear_model import (
LinearRegression,
LogisticRegression,
LogisticRegressionCV,
)
import torch
from torch import nn
import sklearn
import numpy as np
import py_nillion_client as nillion

_NillionType = Union[
na.Rational,
na.SecretRational,
nillion.SecretInteger,
nillion.SecretUnsignedInteger,
nillion.PublicVariableInteger,
nillion.PublicVariableUnsignedInteger,
]
_LinearModel = Union[LinearRegression, LogisticRegression, LogisticRegressionCV]


class ModelClientMeta(ABCMeta):
Expand All @@ -42,9 +23,9 @@ def __call__(self, *args, **kwargs) -> object:
Returns:
object: Result object.
"""
obj = super(ModelClientMeta, self).__call__(*args, **kwargs)
obj = super().__call__(*args, **kwargs)
if not getattr(obj, "state_dict"):
raise AttributeError("required attribute `state_dict` not set")
raise AttributeError("Required attribute `state_dict` not set")
return obj


Expand All @@ -54,21 +35,21 @@ class ModelClient(ABC, metaclass=ModelClientMeta):
def export_state_as_secrets(
self,
name: str,
nada_type: _NillionType,
) -> Dict[str, _NillionType]:
nada_type: NillionType,
) -> Dict[str, NillionType]:
"""
Exports model state as a Dict of Nillion secret types.
Args:
name (str): Name to be used to store state secrets in the network.
nada_type (_NillionType): Data type to convert weights to.
nada_type (NillionType): Data type to convert weights to.
Raises:
NotImplementedError: Raised when unsupported model state type is passed.
TypeError: Raised when model state has incompatible values.
Returns:
Dict[str, _NillionType]: Dict of Nillion secret types that represents model state.
Dict[str, NillionType]: Dict of Nillion secret types that represents model state.
"""
if nada_type not in (na.Rational, na.SecretRational):
raise NotImplementedError("Exporting non-rational state is not supported")
Expand Down Expand Up @@ -104,52 +85,3 @@ def __ensure_numpy(self, array_like: Any) -> np.ndarray:
raise TypeError(
"Could not convert type `%s` to NumPy array" % type(array_like).__name__
)


class StateClient(ModelClient):
"""ModelClient for generic model states"""

def __init__(self, state_dict: Dict[str, Any]) -> None:
"""
Client initialization.
This client accepts an arbitrary model state as input.
Args:
state_dict (Dict[str, Any]): State dict.
"""
self.state_dict = state_dict


class TorchClient(ModelClient):
"""ModelClient for PyTorch models"""

def __init__(self, model: nn.Module) -> None:
"""
Client initialization.
Args:
model (nn.Module): PyTorch model object to wrap around.
"""
self.state_dict = model.state_dict()


class SklearnClient(ModelClient):
"""ModelClient for Scikit-learn models"""

def __init__(self, model: sklearn.base.BaseEstimator) -> None:
"""
Client initialization.
Args:
model (sklearn.base.BaseEstimator): Sklearn model object to wrap around.
"""
if isinstance(model, _LinearModel):
state_dict = {"coef": model.coef_}
if model.fit_intercept is True:
state_dict.update({"intercept": model.intercept_})
else:
raise NotImplementedError(
f"Instantiating ModelClient from Sklearn model type `{type(model).__name__}` is not yet implemented."
)

self.state_dict = state_dict
2 changes: 2 additions & 0 deletions nada_ai/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Custom exceptions"""

__all__ = ["MismatchedShapesException"]


class MismatchedShapesException(Exception):
"""Raised when NadaArray shapes are incompatible"""
Expand Down
2 changes: 2 additions & 0 deletions nada_ai/linear_model/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from nada_ai.nn.module import Module
from nada_ai.nn.parameter import Parameter

__all__ = ["LinearRegression"]


class LinearRegression(Module):
"""Linear regression implementation"""
Expand Down
3 changes: 1 addition & 2 deletions nada_ai/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .module import Module
from .parameter import Parameter
from .layers import *
from .activations import *
from .modules import *
Loading

0 comments on commit a6191e5

Please sign in to comment.