Skip to content

Commit

Permalink
Add support for parameters dataclass (#57)
Browse files Browse the repository at this point in the history
* Add support for parameter dataclasses

* Update test

* Update version

* Update README

* Update typing

* Allow single value parameters

* Update changelog
  • Loading branch information
BenSchZA authored Sep 3, 2022
1 parent 744b5f4 commit 74afce7
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 9 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.11.1] - 2022-09-03
### Added
- Add support for parameters dataclass
- Add support for single value parameters (i.e. not a list)

## [0.11.0] - 2022-09-03
### Changed
- Adapted core deepcopy processes to be more efficient and avoid unintended mutation of state between policy and state update functions
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ result = experiment.run()
* [x] Distributed computing and remote execution in a cluster (AWS, GCP, Kubernetes, ...) using [Ray - Fast and Simple Distributed Computing](https://ray.io/)
* [x] Hooks to easily extend the functionality - e.g. save results to HDF5 file format after completion
* [x] Model classes are iterable, so you can iterate over them step-by-step from one state to the next (useful for gradient descent, live digital twins)
* [x] Parameters can be a dataclass! This enables typing and dot notation for accessing parameters.

## Installation

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "radcad"
version = "0.11.0"
version = "0.11.1"
description = "A Python package for dynamical systems modelling & simulation, inspired by and compatible with cadCAD"
authors = ["CADLabs <[email protected]>"]
packages = [
Expand Down
2 changes: 1 addition & 1 deletion radcad/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.11.0"
__version__ = "0.11.1"

from radcad.wrappers import Context, Model, Simulation, Experiment
from radcad.engine import Engine
Expand Down
25 changes: 18 additions & 7 deletions radcad/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import logging
import pickle
import traceback
from typing import Dict, List, Tuple, Callable
from typing import Union, Dict, List, Tuple, Callable
from dataclasses import asdict, dataclass, field, is_dataclass

from radcad.utils import Dataclass


# Use "radCAD" logging instance to avoid conflict with other projects
Expand Down Expand Up @@ -169,16 +172,21 @@ def _single_run_wrapper(args):
return [], e


def generate_parameter_sweep(params: Dict[str, List[any]]):
def generate_parameter_sweep(params: Union[Dict[str, List[any]], Dataclass]):
_is_dataclass = is_dataclass(params)
_params = asdict(params) if _is_dataclass else params

param_sweep = []
max_len = 0
for value in params.values():
if len(value) > max_len:
max_len = 1
for value in _params.values():
if isinstance(value, list) and len(value) > max_len:
max_len = len(value)

for sweep_index in range(0, max_len):
param_set = {}
for (key, value) in params.items():
for (key, value) in _params.items():
if not isinstance(value, list):
value = [value]
param = (
value[sweep_index]
if sweep_index < len(value)
Expand All @@ -187,7 +195,10 @@ def generate_parameter_sweep(params: Dict[str, List[any]]):
param_set[key] = param
param_sweep.append(param_set)

return param_sweep
if _is_dataclass:
return [params.__class__(**subset) for subset in param_sweep]
else:
return param_sweep


def _add_signals(acc, a: Dict[str, any]):
Expand Down
96 changes: 96 additions & 0 deletions radcad/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import itertools
from typing import Dict
from typing_extensions import Protocol
import copy
import datetime
from dataclasses import field
from functools import partial


def flatten(nested_list):
Expand All @@ -22,3 +28,93 @@ def generate_cartesian_product_parameter_sweep(params):
cartesian_product = list(itertools.product(*params.values()))
param_sweep = {key: [x[i] for x in cartesian_product] for i, key in enumerate(params.keys())}
return param_sweep


class Dataclass(Protocol):
# Checking for this attribute is currently
# the most reliable way to ascertain that something is a dataclass
__dataclass_fields__: Dict


def _update_from_signal(
state_variable,
signal_key,
optional_update,
params,
substep,
state_history,
previous_state,
policy_input,
):
"""A private function used to generate the partial function returned by `update_from_signal(...)`."""
if not signal_key in policy_input and optional_update:
return state_variable, previous_state[state_variable]
else:
return state_variable, policy_input[signal_key]


def update_from_signal(state_variable, signal_key=None, optional_update=False):
"""
A generic State Update Function to update a State Variable directly from a Policy Signal,
useful to avoid boilerplate code.
Args:
state_variable (str): State Variable key
signal_key (str, optional): Policy Signal key. Defaults to None.
optional_update (bool, optional): If True, only update State Variable if Policy Signal key exists.
Returns:
Callable: A generic State Update Function
"""
if not signal_key:
signal_key = state_variable
return partial(_update_from_signal, state_variable, signal_key, optional_update)


def _accumulate_from_signal(
state_variable,
signal_key,
params,
substep,
state_history,
previous_state,
policy_input,
):
"""A private function used to generate the partial function returned by `accumulate_from_signal(...)`."""
return state_variable, previous_state[state_variable] + policy_input[signal_key]


def accumulate_from_signal(state_variable, signal_key=None):
"""
A generic State Update Function to accumulate a State Variable directly from a Policy Signal,
useful to avoid boilerplate code.
"""
if not signal_key:
signal_key = state_variable
return partial(_accumulate_from_signal, state_variable, signal_key)


def update_timestamp(params, substep, state_history, previous_state, policy_input):
"""
A radCAD State Update Function used to calculate and update the current timestamp
given a timestep and starting date parameter.
"""
# Parameters
dt = params["dt"]
date_start = params["date_start"]

# State Variables
timestep = previous_state["timestep"]

# Calculate current timestamp from timestep
timestamp = date_start + datetime.timedelta(days=timestep * dt)

return "timestamp", timestamp


def local_variables(_locals):
"""Return a dictionary of all local variables, useful for debugging."""
return {key: _locals[key] for key in [_key for _key in _locals.keys() if "__" not in _key]}


def default(obj):
"""Used and necessary when setting the default value of a dataclass field to a list."""
return field(default_factory=lambda: copy.copy(obj))
77 changes: 77 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import List
import pytest
from dataclasses import dataclass

import radcad.core as core
from radcad.core import generate_parameter_sweep, reduce_signals

from radcad import Model, Simulation, Experiment
from radcad.engine import flatten
from radcad.utils import default

from tests.test_cases import basic


def test_generate_parameter_sweep():
params = {
'a': [0],
Expand All @@ -31,6 +35,79 @@ def test_generate_parameter_sweep():
param_sweep = generate_parameter_sweep(params)
assert param_sweep == [{'a': 0, 'b': 0, 'c': 0}, {'a': 1, 'b': 1, 'c': 0}, {'a': 2, 'b': 1, 'c': 0}]


def test_generate_dataclass_parameter_sweep():
@dataclass
class P1:
a: List[int] = default([0])
b: List[int] = default([0])
param_sweep = generate_parameter_sweep(P1())
assert param_sweep == [P1(**{'a': 0, 'b': 0})]

@dataclass
class P2:
a: List[int] = default([0, 1, 2])
b: List[int] = default([0])
param_sweep = generate_parameter_sweep(P2())
assert param_sweep == [P2(**{'a': 0, 'b': 0}), P2(**{'a': 1, 'b': 0}), P2(**{'a': 2, 'b': 0})]

@dataclass
class P3:
a: List[int] = default([0, 1, 2])
b: List[int] = default([0, 1])
c: List[int] = default([0])
param_sweep = generate_parameter_sweep(P3())
assert param_sweep == [P3(**{'a': 0, 'b': 0, 'c': 0}), P3(**{'a': 1, 'b': 1, 'c': 0}), P3(**{'a': 2, 'b': 1, 'c': 0})]


def test_generate_single_value_dataclass_parameter_sweep():
@dataclass
class P1:
a: int = 0
b: int = 0
param_sweep = generate_parameter_sweep(P1())
assert param_sweep == [P1(**{'a': 0, 'b': 0})]

@dataclass
class P2:
a: List[int] = default([0, 1, 2])
b: int = 0
param_sweep = generate_parameter_sweep(P2())
assert param_sweep == [P2(**{'a': 0, 'b': 0}), P2(**{'a': 1, 'b': 0}), P2(**{'a': 2, 'b': 0})]

@dataclass
class P3:
a: List[int] = default([0, 1, 2])
b: List[int] = default([0, 1])
c: int = 0
param_sweep = generate_parameter_sweep(P3())
assert param_sweep == [P3(**{'a': 0, 'b': 0, 'c': 0}), P3(**{'a': 1, 'b': 1, 'c': 0}), P3(**{'a': 2, 'b': 1, 'c': 0})]


def test_generate_single_value_parameter_sweep():
params = {
'a': 0,
'b': 0
}
param_sweep = generate_parameter_sweep(params)
assert param_sweep == [{'a': 0, 'b': 0}]

params = {
'a': [0, 1, 2],
'b': 0
}
param_sweep = generate_parameter_sweep(params)
assert param_sweep == [{'a': 0, 'b': 0}, {'a': 1, 'b': 0}, {'a': 2, 'b': 0}]

params = {
'a': [0, 1, 2],
'b': [0, 1],
'c': 0
}
param_sweep = generate_parameter_sweep(params)
assert param_sweep == [{'a': 0, 'b': 0, 'c': 0}, {'a': 1, 'b': 1, 'c': 0}, {'a': 2, 'b': 1, 'c': 0}]


def test_reduce_signals():
psu = {
'policies': {
Expand Down
41 changes: 41 additions & 0 deletions tests/test_dataclass_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from dataclasses import dataclass
from typing import List
from radcad import Model, Simulation, Experiment, Backend
from radcad.utils import default


# NOTE To pickle a dataclass, it must be defined in module and not function scope
@dataclass
class P1:
subset: List[int] = default([0, 1, 2])
a: List[int] = default([0, 1])


def policy(params: P1, substep, state_history, previous_state):
subset = previous_state['subset']
assert subset == params.subset
assert params.a == params.subset if subset < 2 else params.a == 1
return {}


def test_basic_state_update():
initial_state = {}

state_update_blocks = [
{
'policies': {
'p': policy
},
'variables': {}
},
]

params = P1()

TIMESTEPS = 10
RUNS = 3

model = Model(initial_state=initial_state, state_update_blocks=state_update_blocks, params=params)
simulation = Simulation(model=model, timesteps=TIMESTEPS, runs=RUNS)
experiment = Experiment(simulation)
_result = experiment.run()

0 comments on commit 74afce7

Please sign in to comment.