Skip to content

Commit

Permalink
More informative error messages when policies / SUFs have the wrong o…
Browse files Browse the repository at this point in the history
…utput (#318)

* add error msg

---------

Co-authored-by: Emanuel Lima <[email protected]>
  • Loading branch information
danlessa and emanuellima1 authored Dec 15, 2023
1 parent 546881a commit 676f0e0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
4 changes: 2 additions & 2 deletions cadCAD/configuration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict, Callable, List, Tuple
from pandas.core.frame import DataFrame
from pandas.core.frame import DataFrame # type: ignore
from datetime import datetime
from collections import deque
from copy import deepcopy
import pandas as pd
import pandas as pd # type: ignore

from cadCAD.utils import key_filter
from cadCAD.configuration.utils import exo_update_per_ts, configs_as_objs
Expand Down
17 changes: 14 additions & 3 deletions cadCAD/engine/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@ def compose(init_reduction_funct, funct_list, val_list):
return result

col_results = get_col_results(sweep_dict, sub_step, sL, s, funcs)
key_set = list(set(list(reduce(lambda a, b: a + b, list(map(lambda x: list(x.keys()), col_results))))))
try:
reducer_arg = list(map(lambda x: list(x.keys()), col_results))
except:
raise ValueError("There is a Policy Function that has not properly returned a Dictionary")
reducer_function = lambda a, b: a + b
key_set = list(set(reduce(reducer_function, reducer_arg)))
new_dict = {k: [] for k in key_set}
for d in col_results:
for k in d.keys():
Expand Down Expand Up @@ -146,10 +151,16 @@ def transfer_missing_fields(source, destination):
for k in source:
if k not in destination:
destination[k] = source[k]
del source
del source
return destination

last_in_copy: Dict[str, Any] = transfer_missing_fields(last_in_obj, dict(generate_record(state_funcs)))
try:
new_state_vars = dict(generate_record(state_funcs))
except (ValueError, TypeError):
raise ValueError("There is a State Update Function which is not returning an proper tuple")


last_in_copy: Dict[str, Any] = transfer_missing_fields(last_in_obj, new_state_vars)
last_in_copy: Dict[str, Any] = self.apply_env_proc(sweep_dict, env_processes, last_in_copy)
last_in_copy['substep'], last_in_copy['timestep'], last_in_copy['run'] = sub_step, time_step, run

Expand Down
4 changes: 2 additions & 2 deletions testing/test_arg_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ def test_sufs():
psubs = [
{
'policies': {
'p_A': lambda _1, _2, _3, _4: {}
'p_A': lambda _1, _2, _3, _4, _5: {}
},
'variables': {
'v_A': lambda _1, _2, _3, _4, _5: ('v_a', None)
'v_A': lambda _1, _2, _3, _4, _5, _6: ('a', 1)
}
}
]
Expand Down

0 comments on commit 676f0e0

Please sign in to comment.