Skip to content

Commit

Permalink
Merge pull request #7 from NillionNetwork/feature/nada-algebra-0.3.1
Browse files Browse the repository at this point in the history
Feature/nada algebra 0.3.1
  • Loading branch information
mathias-nillion authored Jun 5, 2024
2 parents ba8dc6b + 134099c commit 33c245e
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 14 deletions.
14 changes: 7 additions & 7 deletions nada_ai/nn/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Union
import nada_algebra as na
from nada_ai.nn.module import Module
from nada_dsl import Integer, NadaType, SecretInteger, SecretBoolean, PublicBoolean
from nada_dsl import Integer, NadaType, SecretBoolean, PublicBoolean


class ReLU(Module):
Expand All @@ -19,7 +19,7 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
Returns:
na.NadaArray: Module output.
"""
if x.dtype in (na.Rational, na.SecretRational):
if x.is_rational:
mask = x.apply(self._rational_relu)
else:
mask = x.apply(self._relu)
Expand All @@ -29,29 +29,29 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
@staticmethod
def _rational_relu(
value: Union[na.Rational, na.SecretRational]
) -> na.SecretRational:
) -> Union[na.Rational, na.SecretRational]:
"""
Element-wise ReLU logic for rational values.
Args:
value (Union[na.Rational, na.SecretRational]): Input rational.
Returns:
na.SecretRational: ReLU output rational.
Union[na.Rational, na.SecretRational]: ReLU output rational.
"""
above_zero: Union[PublicBoolean, SecretBoolean] = value > na.rational(0)
above_zero: Union[na.PublicBoolean, na.SecretBoolean] = value > na.rational(0)
return above_zero.if_else(na.rational(1), na.rational(0))

@staticmethod
def _relu(value: NadaType) -> SecretInteger:
def _relu(value: NadaType) -> NadaType:
"""
Element-wise ReLU logic for NadaType values.
Args:
value (NadaType): Input nada value.
Returns:
SecretInteger: Output nada value.
NadaType: Output nada value.
"""
above_zero: Union[PublicBoolean, SecretBoolean] = value > Integer(0)
return above_zero.if_else(Integer(1), Integer(0))
4 changes: 2 additions & 2 deletions nada_ai/nn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
unbatched = True

batch_size, channels, input_height, input_width = x.shape
dtype = x.dtype
is_rational = x.is_rational

if any(pad > 0 for pad in self.padding):
# TODO: avoid side-step to NumPy
Expand Down Expand Up @@ -246,7 +246,7 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:

pool_region = padded_input[b, c, start_h:end_h, start_w:end_w]

if dtype in (na.Rational, na.SecretRational):
if is_rational:
pool_size = na.rational(pool_region.size)
else:
pool_size = Integer(pool_region.size)
Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ python = "^3.10"
numpy = "^1.26.4"
nada-dsl = "^0.2.1"
py-nillion-client = "^0.2.1"
nada-algebra = "^0.3.0"
nada-algebra = "^0.3.1"


[tool.poetry.group.dev.dependencies]
Expand Down

0 comments on commit 33c245e

Please sign in to comment.