Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
mathias-nillion committed Jun 11, 2024
1 parent 8e43b30 commit 1ef4145
Show file tree
Hide file tree
Showing 16 changed files with 16 additions and 26 deletions.
2 changes: 1 addition & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ The following are the currently available examples:
- [Complex Model](./complex_model): shows how to build more intricate model architectures using Nada AI. Contains convolutions, pooling operations, linear layers and activations
- [Time Series](./time_series): shows how to run a Facebook Prophet time series forecasting model using Nada AI

The Nada program source code is stored in `src/main.py`.
The Nada program source code is stored in `src/<EXAMPLE_NAME>.py`.

In order to follow the end-to-end example, head to `network/compute.py`. You can run it by simply running `nada build` to build the Nada program followed by `python network/compute.py`.
2 changes: 1 addition & 1 deletion examples/complex_model/nada-project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ version = "0.1.0"
authors = [""]

[[programs]]
path = "src/main.py"
path = "src/complex_model.py"
prime_size = 128
File renamed without changes.
2 changes: 1 addition & 1 deletion examples/complex_model/tests/complex_model.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
program: main
program: complex_model
inputs:
secrets:
# We assume all values were originally floats, scaled & rounded by a factor of 2**16
Expand Down
2 changes: 1 addition & 1 deletion examples/linear_regression/nada-project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ version = "0.1.0"
authors = [""]

[[programs]]
path = "src/main.py"
path = "src/linear_regression.py"
prime_size = 128
File renamed without changes.
2 changes: 1 addition & 1 deletion examples/linear_regression/tests/linear_regression.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
program: main
program: linear_regression
inputs:
secrets:
# We assume all values were originally floats, scaled & rounded by a factor of 2**16
Expand Down
4 changes: 2 additions & 2 deletions examples/neural_net/nada-project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "complex_model"
name = "neural_net"
version = "0.1.0"
authors = [""]

[[programs]]
path = "src/main.py"
path = "src/neural_net.py"
prime_size = 128
File renamed without changes.
2 changes: 1 addition & 1 deletion examples/neural_net/tests/neural_net.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
program: main
program: neural_net
inputs:
secrets:
# We assume all values were originally floats, scaled & rounded by a factor of 2**16
Expand Down
2 changes: 1 addition & 1 deletion examples/time_series/nada-project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ version = "0.1.0"
authors = [""]

[[programs]]
path = "src/main.py"
path = "src/time_series.py"
prime_size = 128
File renamed without changes.
2 changes: 1 addition & 1 deletion examples/time_series/tests/time_series.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
program: main
program: time_series
inputs:
secrets:
my_prophet_changepoints_t_8:
Expand Down
2 changes: 1 addition & 1 deletion nada_ai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def __init__(self, model: sklearn.base.BaseEstimator) -> None:
class ProphetClient(ModelClient):
"""ModelClient for Prophet models"""

def __init__(self, model) -> None:
def __init__(self, model: "prophet.forecaster.Prophet") -> None:
"""
Client initialization.
Expand Down
10 changes: 5 additions & 5 deletions nada_ai/time_series/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,8 @@ def predict(
Returns:
na.NadaArray: Forecasted values.
"""
assert len(dates) == len(
floor
), "Provided Prophet inputs must be equally sized."
assert len(floor) == len(t), "Provided Prophet inputs must be equally sized."
assert len(dates) == len(floor), "Prophet inputs must be equally sized."
assert len(floor) == len(t), "Prophet inputs must be equally sized."

dates = self.ensure_numeric_dates(dates)
trend = self.predict_trend(floor, t)
Expand All @@ -221,7 +219,9 @@ def ensure_numeric_dates(self, dates: np.ndarray) -> np.ndarray:
Returns:
np.ndarray: Standardized dates.
"""
if np.issubdtype(dates.dtype, (np.integer, np.floating)):
if np.issubdtype(dates.dtype, np.integer) or np.issubdtype(
dates.dtype, np.floating
):
return dates
if np.issubdtype(dates.dtype, np.datetime64):
return dates.astype(np.float64)
Expand Down
10 changes: 0 additions & 10 deletions tests/python-tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ def forward(x: na.NadaArray) -> na.NadaArray: ...
def test_parameters_5(self):
param = Parameter((2, 3))

assert param[0][0] == Integer(0)

alphas = na.alphas((2, 3), alpha=Integer(42))
param.load_state(alphas)

Expand All @@ -96,9 +94,6 @@ def forward(x: na.NadaArray) -> na.NadaArray: ...

mod = TestModule()

for _, param in mod.named_parameters():
assert param[0][0] == Integer(0)

alphas = na.alphas((2, 3), alpha=Integer(42))

for _, param in mod.named_parameters():
Expand All @@ -124,9 +119,6 @@ def forward(x: na.NadaArray) -> na.NadaArray: ...

mod = TestModule2()

for _, param in mod.named_parameters():
assert param[0][0] == Integer(0)

alphas = na.alphas((2, 3), alpha=Integer(42))

for _, param in mod.named_parameters():
Expand All @@ -138,8 +130,6 @@ def forward(x: na.NadaArray) -> na.NadaArray: ...
def test_parameters_8(self):
param = Parameter((2, 3))

assert param[0][0] == Integer(0)

alphas = na.alphas((3, 3), alpha=Integer(42))

with pytest.raises(MismatchedShapesException):
Expand Down

0 comments on commit 1ef4145

Please sign in to comment.