Skip to content

Commit

Permalink
Add t_step parameter to frame reports (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
GianlucaFicarelli authored Jun 9, 2023
1 parent 2a510ac commit ac75933
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 36 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Improvements
- Clarification for partial circuit configs
- Publish version as ``bluepysnap.__version__``
- Support lazy loading of nodes attributes.
- Add t_step parameter to frame reports.
- Add python 3.11 tests.
- Drop python 3.7 support.

Expand Down
15 changes: 13 additions & 2 deletions bluepysnap/frame_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,31 @@ def _wrap_columns(columns):
"""Allows to change the columns names if needed."""
return columns

def get(self, group=None, t_start=None, t_stop=None):
def get(self, group=None, t_start=None, t_stop=None, t_step=None):
"""Fetch data from the report.
Args:
group (None/int/list/np.array/dict): Get frames filtered by :ref:`Group Concept`.
t_start (float): Include only frames occurring at or after this time.
t_stop (float): Include only frames occurring at or before this time.
t_step (float): Optional time step, useful to reduce the number of samples.
It should be a multiple of the report time step dt, and it's equal to dt by default.
If the given t_step isn't an exact multiple, it's rounded to the closer multiple.
Only the samples at t = t0 + k * t_step, for k = 0, 1... are returned,
where t0 is the first sample time >= t_start.
Returns:
pandas.DataFrame: frame as columns indexed by timestamps.
"""
t_stride = round(t_step / self.frame_report.dt) if t_step is not None else 1
if t_stride < 1:
msg = f"Invalid {t_step=}. It should be None or a multiple of {self.frame_report.dt}."
raise BluepySnapError(msg)
ids = self._resolve(group).tolist()
try:
view = self._frame_population.get(node_ids=ids, tstart=t_start, tstop=t_stop)
view = self._frame_population.get(
node_ids=ids, tstart=t_start, tstop=t_stop, tstride=t_stride
)
except SonataError as e:
raise BluepySnapError(e) from e

Expand Down
104 changes: 70 additions & 34 deletions tests/test_frame_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,65 +220,101 @@ def test_nodes_invalid_population(self):
with pytest.raises(BluepySnapError):
test_obj.nodes

def test_get(self):
pdt.assert_frame_equal(self.test_obj.get(), self.df)
pdt.assert_frame_equal(self.test_obj.get([]), pd.DataFrame())
pdt.assert_frame_equal(self.test_obj.get(np.array([])), pd.DataFrame())
pdt.assert_frame_equal(self.test_obj.get(()), pd.DataFrame())

pdt.assert_frame_equal(self.test_obj.get(2), self.df.loc[:, [2]])
pdt.assert_frame_equal(self.test_obj.get(CircuitNodeId("default", 2)), self.df.loc[:, [2]])
@pytest.mark.parametrize("t_step", [None, 0.02, 0.04, 0.0401, 0.0399, 0.05, 200000])
def test_get(self, t_step):
def _assert_frame_equal(df1, df2):
# compare df1 and df2, after filtering df2 according to t_stride
df2 = df2.iloc[::t_stride]
pdt.assert_frame_equal(df1, df2)

# calculate the expected t_stride, depending on t_step and dt (varying across tests)
t_stride = round(t_step / self.test_obj.frame_report.dt) if t_step is not None else 1

_assert_frame_equal(self.test_obj.get(t_step=t_step), self.df)
_assert_frame_equal(self.test_obj.get([], t_step=t_step), pd.DataFrame())
_assert_frame_equal(self.test_obj.get(np.array([]), t_step=t_step), pd.DataFrame())
_assert_frame_equal(self.test_obj.get((), t_step=t_step), pd.DataFrame())

_assert_frame_equal(self.test_obj.get(2, t_step=t_step), self.df.loc[:, [2]])
_assert_frame_equal(
self.test_obj.get(CircuitNodeId("default", 2), t_step=t_step), self.df.loc[:, [2]]
)

# not from this population
pdt.assert_frame_equal(self.test_obj.get(CircuitNodeId("default2", 2)), pd.DataFrame())
_assert_frame_equal(
self.test_obj.get(CircuitNodeId("default2", 2), t_step=t_step), pd.DataFrame()
)

pdt.assert_frame_equal(self.test_obj.get([2, 0]), self.df.loc[:, [0, 2]])
_assert_frame_equal(self.test_obj.get([2, 0], t_step=t_step), self.df.loc[:, [0, 2]])

pdt.assert_frame_equal(self.test_obj.get([0, 2]), self.df.loc[:, [0, 2]])
_assert_frame_equal(self.test_obj.get([0, 2], t_step=t_step), self.df.loc[:, [0, 2]])

pdt.assert_frame_equal(self.test_obj.get(np.asarray([0, 2])), self.df.loc[:, [0, 2]])
_assert_frame_equal(
self.test_obj.get(np.asarray([0, 2]), t_step=t_step), self.df.loc[:, [0, 2]]
)

pdt.assert_frame_equal(self.test_obj.get([2], t_stop=0.5), self.df.iloc[:6].loc[:, [2]])
_assert_frame_equal(
self.test_obj.get([2], t_stop=0.5, t_step=t_step), self.df.iloc[:6].loc[:, [2]]
)

pdt.assert_frame_equal(self.test_obj.get([2], t_stop=0.55), self.df.iloc[:6].loc[:, [2]])
_assert_frame_equal(
self.test_obj.get([2], t_stop=0.55, t_step=t_step), self.df.iloc[:6].loc[:, [2]]
)

pdt.assert_frame_equal(self.test_obj.get([2], t_start=0.5), self.df.iloc[5:].loc[:, [2]])
_assert_frame_equal(
self.test_obj.get([2], t_start=0.5, t_step=t_step), self.df.iloc[5:].loc[:, [2]]
)

pdt.assert_frame_equal(
self.test_obj.get([2], t_start=0.5, t_stop=0.8), self.df.iloc[5:9].loc[:, [2]]
_assert_frame_equal(
self.test_obj.get([2], t_start=0.5, t_stop=0.8, t_step=t_step),
self.df.iloc[5:9].loc[:, [2]],
)

pdt.assert_frame_equal(
self.test_obj.get([2, 1], t_start=0.5, t_stop=0.8), self.df.iloc[5:9].loc[:, [1, 2]]
_assert_frame_equal(
self.test_obj.get([2, 1], t_start=0.5, t_stop=0.8, t_step=t_step),
self.df.iloc[5:9].loc[:, [1, 2]],
)

pdt.assert_frame_equal(
self.test_obj.get([2, 1], t_start=0.2, t_stop=0.8), self.df.iloc[2:9].loc[:, [1, 2]]
_assert_frame_equal(
self.test_obj.get([2, 1], t_start=0.2, t_stop=0.8, t_step=t_step),
self.df.iloc[2:9].loc[:, [1, 2]],
)

pdt.assert_frame_equal(
self.test_obj.get(group={Cell.MTYPE: "L6_Y"}, t_start=0.2, t_stop=0.8),
_assert_frame_equal(
self.test_obj.get(group={Cell.MTYPE: "L6_Y"}, t_start=0.2, t_stop=0.8, t_step=t_step),
self.df.iloc[2:9].loc[:, [1, 2]],
)

pdt.assert_frame_equal(self.test_obj.get(group={Cell.MTYPE: "L2_X"}), self.df.loc[:, [0]])
_assert_frame_equal(
self.test_obj.get(group={Cell.MTYPE: "L2_X"}, t_step=t_step), self.df.loc[:, [0]]
)

pdt.assert_frame_equal(self.test_obj.get(group="Layer23"), self.df.loc[:, [0]])
_assert_frame_equal(self.test_obj.get(group="Layer23", t_step=t_step), self.df.loc[:, [0]])

ids = CircuitNodeIds.from_arrays(["default", "default", "default2"], [0, 2, 1])
pdt.assert_frame_equal(self.test_obj.get(group=ids), self.df.loc[:, [0, 2]])
_assert_frame_equal(self.test_obj.get(group=ids, t_step=t_step), self.df.loc[:, [0, 2]])

with pytest.raises(BluepySnapError):
self.test_obj.get(-1, t_start=0.2)
with pytest.raises(
BluepySnapError, match="All node IDs must be >= 0 and < 3 for population 'default'"
):
self.test_obj.get(-1, t_start=0.2, t_step=t_step)

with pytest.raises(BluepySnapError):
self.test_obj.get(0, t_start=-1)
with pytest.raises(BluepySnapError, match="Times cannot be negative"):
self.test_obj.get(0, t_start=-1, t_step=t_step)

with pytest.raises(BluepySnapError):
self.test_obj.get([0, 2], t_start=15)
with pytest.raises(BluepySnapError, match="tstart is after the end of the range"):
self.test_obj.get([0, 2], t_start=15, t_step=t_step)

with pytest.raises(BluepySnapError):
self.test_obj.get(4)
with pytest.raises(
BluepySnapError, match="All node IDs must be >= 0 and < 3 for population 'default'"
):
self.test_obj.get(4, t_step=t_step)

@pytest.mark.parametrize("t_step", [0, -1, 0.0000001])
def test_get_with_invalid_t_step(self, t_step):
match = f"Invalid t_step={t_step}. It should be None or a multiple of"
with pytest.raises(BluepySnapError, match=match):
self.test_obj.get(t_step=t_step)

def test_get_partially_not_in_report(self):
with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([0, 4])):
Expand Down

0 comments on commit ac75933

Please sign in to comment.