Skip to content

Commit

Permalink
Merge pull request #8 from NillionNetwork/feature/nada-algebra-0.3.2
Browse files Browse the repository at this point in the history
Use built-in nada functions in favor of numpy
  • Loading branch information
mathias-nillion authored Jun 6, 2024
2 parents 33c245e + 4aecabd commit 9e48be0
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 125 deletions.
37 changes: 12 additions & 25 deletions nada_ai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import nada_algebra as na
import nada_algebra.client as na_client
from collections import OrderedDict
from typing import Any, Dict, Iterable, Union
from typing import Any, Dict, Union

from sklearn.linear_model import (
LinearRegression,
Expand All @@ -26,18 +26,20 @@
nillion.PublicVariableInteger,
nillion.PublicVariableUnsignedInteger,
]
_Tensor = Union[np.ndarray, torch.Tensor]
_LinearModel = Union[LinearRegression, LogisticRegression, LogisticRegressionCV]


class ModelClient:
"""ML model client"""

def __init__(self, model: Any, state_dict: OrderedDict[str, np.ndarray]) -> None:
def __init__(self, model: Any, state_dict: OrderedDict[str, _Tensor]) -> None:
"""
Initialization.
Args:
model (Any): Model object to wrap around.
state_dict (OrderedDict[str, np.ndarray]): Model state.
state_dict (OrderedDict[str, _Tensor]): Model state.
"""
self.model = model
self.state_dict = state_dict
Expand All @@ -64,31 +66,17 @@ def from_sklearn(cls, model: sklearn.base.BaseEstimator) -> "ModelClient":
Args:
model (sklearn.base.BaseEstimator): Sklearn estimator object.
Raises:
NotImplementedError: Raised when unsupported Scikit-learn model is passed.
Returns:
ModelClient: Instantiated model client.
"""
if not isinstance(model, sklearn.base.BaseEstimator):
raise TypeError(
"Cannot interpret type `%s` as Sklearn model. Expected (sub)type of `sklearn.base.BaseEstimator`"
% type(model).__name__
)

if isinstance(model, LinearRegression):
state_dict = OrderedDict(
{
"coef": model.coef_,
"intercept": (
model.intercept_
if isinstance(model.intercept_, Iterable)
else np.array([model.intercept_])
),
}
)
elif isinstance(model, (LogisticRegression, LogisticRegressionCV)):
if isinstance(model, _LinearModel):
state_dict = OrderedDict(
{
"coef": model.coef_,
"intercept": model.intercept_,
"intercept": np.array(model.intercept_),
}
)
else:
Expand Down Expand Up @@ -119,9 +107,8 @@ def export_state_as_secrets(
state_secrets = {}
for state_layer_name, state_layer_weight in self.state_dict.items():
layer_name = f"{name}_{state_layer_name}"
state_secret = na_client.array(
self.__ensure_numpy(state_layer_weight), layer_name, nada_type
)
layer_state = self.__ensure_numpy(state_layer_weight)
state_secret = na_client.array(layer_state, layer_name, nada_type)
state_secrets.update(state_secret)

return state_secrets
Expand Down
20 changes: 12 additions & 8 deletions nada_ai/nn/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from typing import Union
import nada_algebra as na
from nada_ai.nn.module import Module
from nada_dsl import Integer, NadaType, SecretBoolean, PublicBoolean
from nada_dsl import (
Integer,
NadaType,
SecretBoolean,
PublicBoolean,
SecretInteger,
PublicInteger,
)


class ReLU(Module):
Expand All @@ -19,11 +26,8 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
Returns:
na.NadaArray: Module output.
"""
if x.is_rational:
mask = x.apply(self._rational_relu)
else:
mask = x.apply(self._relu)

relu = self._rational_relu if x.is_rational else self._relu
mask = x.apply(relu)
return x * mask

@staticmethod
Expand All @@ -43,15 +47,15 @@ def _rational_relu(
return above_zero.if_else(na.rational(1), na.rational(0))

@staticmethod
def _relu(value: NadaType) -> NadaType:
def _relu(value: NadaType) -> Union[PublicInteger, SecretInteger]:
"""
Element-wise ReLU logic for NadaType values.
Args:
value (NadaType): Input nada value.
Returns:
NadaType: Output nada value.
Union[PublicInteger, SecretInteger]: Output nada value.
"""
above_zero: Union[PublicBoolean, SecretBoolean] = value > Integer(0)
return above_zero.if_else(Integer(1), Integer(0))
88 changes: 34 additions & 54 deletions nada_ai/nn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
"""
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
kernel_height, kernel_width = kernel_size
self.kernel_size = kernel_size

if isinstance(padding, int):
padding = (padding, padding)
Expand All @@ -77,9 +77,7 @@ def __init__(
stride = (stride, stride)
self.stride = stride

self.weight = Parameter(
(out_channels, in_channels, kernel_height, kernel_width)
)
self.weight = Parameter((out_channels, in_channels, *kernel_size))
self.bias = Parameter(out_channels) if include_bias else None

def forward(self, x: na.NadaArray) -> na.NadaArray:
Expand All @@ -93,18 +91,17 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
na.NadaArray: Module output.
"""
unbatched = False
if len(x.shape) == 3:
if x.ndim == 3:
# Assume unbatched --> assign batch_size of 1
x = x.reshape(1, *x.shape)
unbatched = True

batch_size, _, input_rows, input_cols = x.shape
batch_size, _, input_height, input_width = x.shape
out_channels, _, kernel_rows, kernel_cols = self.weight.shape

if any(pad > 0 for pad in self.padding):
# TODO: avoid side-step to NumPy
padded_input = np.pad(
x.inner,
x = na.pad(
x,
[
(0, 0),
(0, 0),
Expand All @@ -113,48 +110,39 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
],
mode="constant",
)
padded_input = np.frompyfunc(
lambda x: Integer(x.item()) if isinstance(x, np.int64) else x, 1, 1
)(padded_input)
else:
padded_input = x.inner

output_rows = (input_rows + 2 * self.padding[0] - kernel_rows) // self.stride[
0
] + 1
output_cols = (input_cols + 2 * self.padding[1] - kernel_cols) // self.stride[
1
] + 1

output_tensor = np.zeros(
(batch_size, out_channels, output_rows, output_cols)
).astype(Integer)

out_height = (
input_height + 2 * self.padding[0] - self.kernel_size[0]
) // self.stride[0] + 1
out_width = (
input_width + 2 * self.padding[1] - self.kernel_size[1]
) // self.stride[1] + 1

output_tensor = na.zeros((batch_size, out_channels, out_height, out_width))
for b in range(batch_size):
for oc in range(out_channels):
for i in range(output_rows):
for j in range(output_cols):
for i in range(out_height):
for j in range(out_width):
start_i = i * self.stride[0]
start_j = j * self.stride[1]

receptive_field = padded_input[
receptive_field = x[
b,
:,
start_i : start_i + kernel_rows,
start_j : start_j + kernel_cols,
]
output_tensor[b, oc, i, j] = np.sum(
self.weight.inner[oc] * receptive_field
output_tensor[b, oc, i, j] = na.sum(
self.weight[oc] * receptive_field
)

if self.bias is not None:
output_tensor = output_tensor + self.bias.inner.reshape(
1, out_channels, 1, 1
)
output_tensor += self.bias.reshape(1, out_channels, 1, 1)

if unbatched:
output_tensor = output_tensor[0]

return na.NadaArray(output_tensor)
return output_tensor


class AvgPool2d(Module):
Expand Down Expand Up @@ -199,7 +187,7 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
na.NadaArray: Module output.
"""
unbatched = False
if len(x.shape) == 3:
if x.ndim == 3:
# Assume unbatched --> assign batch_size of 1
x = x.reshape(1, *x.shape)
unbatched = True
Expand All @@ -208,9 +196,8 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
is_rational = x.is_rational

if any(pad > 0 for pad in self.padding):
# TODO: avoid side-step to NumPy
padded_input = np.pad(
x.inner,
x = na.pad(
x,
(
(0, 0),
(0, 0),
Expand All @@ -219,44 +206,37 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
),
mode="constant",
)
padded_input = np.frompyfunc(
lambda x: Integer(x.item()) if isinstance(x, np.int64) else x, 1, 1
)(padded_input)
else:
padded_input = x.inner

output_height = (
out_height = (
input_height + 2 * self.padding[0] - self.kernel_size[0]
) // self.stride[0] + 1
output_width = (
out_width = (
input_width + 2 * self.padding[1] - self.kernel_size[1]
) // self.stride[1] + 1

output_array = np.zeros(
(batch_size, channels, output_height, output_width)
).astype(Integer)
output_tensor = na.zeros((batch_size, channels, out_height, out_width))
for b in range(batch_size):
for c in range(channels):
for i in range(output_height):
for j in range(output_width):
for i in range(out_height):
for j in range(out_width):
start_h = i * self.stride[0]
start_w = j * self.stride[1]
end_h = start_h + self.kernel_size[0]
end_w = start_w + self.kernel_size[1]

pool_region = padded_input[b, c, start_h:end_h, start_w:end_w]
pool_region = x[b, c, start_h:end_h, start_w:end_w]

if is_rational:
pool_size = na.rational(pool_region.size)
else:
pool_size = Integer(pool_region.size)

output_array[b, c, i, j] = np.sum(pool_region) / pool_size
output_tensor[b, c, i, j] = na.sum(pool_region) / pool_size

if unbatched:
output_array = output_array[0]
output_tensor = output_tensor[0]

return na.NadaArray(output_array)
return na.NadaArray(output_tensor)


class Flatten(Module):
Expand Down
6 changes: 2 additions & 4 deletions nada_ai/nn/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
import nada_algebra as na
from nada_ai.exceptions import MismatchedShapesException
from nada_dsl import Integer

_ShapeLike = Union[int, Iterable[int]]

Expand All @@ -20,9 +19,8 @@ def __init__(self, shape: _ShapeLike) -> None:
Args:
shape (_ShapeLike, optional): Parameter array shape.
"""
zeros = np.zeros(shape, dtype=int)
zeros = np.frompyfunc(Integer, 1, 1)(zeros)
super().__init__(inner=zeros)
zeros = na.zeros(shape)
super().__init__(inner=zeros.inner)

def numel(self) -> int:
"""
Expand Down
Loading

0 comments on commit 9e48be0

Please sign in to comment.