From ac75933b8dd17b6182d093899f75dd0fe34c7bea Mon Sep 17 00:00:00 2001 From: Gianluca Ficarelli <26835404+GianlucaFicarelli@users.noreply.github.com> Date: Fri, 9 Jun 2023 11:19:32 +0200 Subject: [PATCH] Add t_step parameter to frame reports (#218) --- CHANGELOG.rst | 1 + bluepysnap/frame_report.py | 15 +++++- tests/test_frame_report.py | 104 +++++++++++++++++++++++++------------ 3 files changed, 84 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b1bf2cc6..f00161e2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,7 @@ Improvements - Clarification for partial circuit configs - Publish version as ``bluepysnap.__version__`` - Support lazy loading of nodes attributes. +- Add t_step parameter to frame reports. - Add python 3.11 tests. - Drop python 3.7 support. diff --git a/bluepysnap/frame_report.py b/bluepysnap/frame_report.py index ecb78006..6cc440c0 100644 --- a/bluepysnap/frame_report.py +++ b/bluepysnap/frame_report.py @@ -76,20 +76,31 @@ def _wrap_columns(columns): """Allows to change the columns names if needed.""" return columns - def get(self, group=None, t_start=None, t_stop=None): + def get(self, group=None, t_start=None, t_stop=None, t_step=None): """Fetch data from the report. Args: group (None/int/list/np.array/dict): Get frames filtered by :ref:`Group Concept`. t_start (float): Include only frames occurring at or after this time. t_stop (float): Include only frames occurring at or before this time. + t_step (float): Optional time step, useful to reduce the number of samples. + It should be a multiple of the report time step dt, and it's equal to dt by default. + If the given t_step isn't an exact multiple, it's rounded to the closer multiple. + Only the samples at t = t0 + k * t_step, for k = 0, 1... are returned, + where t0 is the first sample time >= t_start. Returns: pandas.DataFrame: frame as columns indexed by timestamps. """ + t_stride = round(t_step / self.frame_report.dt) if t_step is not None else 1 + if t_stride < 1: + msg = f"Invalid {t_step=}. It should be None or a multiple of {self.frame_report.dt}." + raise BluepySnapError(msg) ids = self._resolve(group).tolist() try: - view = self._frame_population.get(node_ids=ids, tstart=t_start, tstop=t_stop) + view = self._frame_population.get( + node_ids=ids, tstart=t_start, tstop=t_stop, tstride=t_stride + ) except SonataError as e: raise BluepySnapError(e) from e diff --git a/tests/test_frame_report.py b/tests/test_frame_report.py index b79504f1..5127d339 100644 --- a/tests/test_frame_report.py +++ b/tests/test_frame_report.py @@ -220,65 +220,101 @@ def test_nodes_invalid_population(self): with pytest.raises(BluepySnapError): test_obj.nodes - def test_get(self): - pdt.assert_frame_equal(self.test_obj.get(), self.df) - pdt.assert_frame_equal(self.test_obj.get([]), pd.DataFrame()) - pdt.assert_frame_equal(self.test_obj.get(np.array([])), pd.DataFrame()) - pdt.assert_frame_equal(self.test_obj.get(()), pd.DataFrame()) - - pdt.assert_frame_equal(self.test_obj.get(2), self.df.loc[:, [2]]) - pdt.assert_frame_equal(self.test_obj.get(CircuitNodeId("default", 2)), self.df.loc[:, [2]]) + @pytest.mark.parametrize("t_step", [None, 0.02, 0.04, 0.0401, 0.0399, 0.05, 200000]) + def test_get(self, t_step): + def _assert_frame_equal(df1, df2): + # compare df1 and df2, after filtering df2 according to t_stride + df2 = df2.iloc[::t_stride] + pdt.assert_frame_equal(df1, df2) + + # calculate the expected t_stride, depending on t_step and dt (varying across tests) + t_stride = round(t_step / self.test_obj.frame_report.dt) if t_step is not None else 1 + + _assert_frame_equal(self.test_obj.get(t_step=t_step), self.df) + _assert_frame_equal(self.test_obj.get([], t_step=t_step), pd.DataFrame()) + _assert_frame_equal(self.test_obj.get(np.array([]), t_step=t_step), pd.DataFrame()) + _assert_frame_equal(self.test_obj.get((), t_step=t_step), pd.DataFrame()) + + _assert_frame_equal(self.test_obj.get(2, t_step=t_step), self.df.loc[:, [2]]) + _assert_frame_equal( + self.test_obj.get(CircuitNodeId("default", 2), t_step=t_step), self.df.loc[:, [2]] + ) # not from this population - pdt.assert_frame_equal(self.test_obj.get(CircuitNodeId("default2", 2)), pd.DataFrame()) + _assert_frame_equal( + self.test_obj.get(CircuitNodeId("default2", 2), t_step=t_step), pd.DataFrame() + ) - pdt.assert_frame_equal(self.test_obj.get([2, 0]), self.df.loc[:, [0, 2]]) + _assert_frame_equal(self.test_obj.get([2, 0], t_step=t_step), self.df.loc[:, [0, 2]]) - pdt.assert_frame_equal(self.test_obj.get([0, 2]), self.df.loc[:, [0, 2]]) + _assert_frame_equal(self.test_obj.get([0, 2], t_step=t_step), self.df.loc[:, [0, 2]]) - pdt.assert_frame_equal(self.test_obj.get(np.asarray([0, 2])), self.df.loc[:, [0, 2]]) + _assert_frame_equal( + self.test_obj.get(np.asarray([0, 2]), t_step=t_step), self.df.loc[:, [0, 2]] + ) - pdt.assert_frame_equal(self.test_obj.get([2], t_stop=0.5), self.df.iloc[:6].loc[:, [2]]) + _assert_frame_equal( + self.test_obj.get([2], t_stop=0.5, t_step=t_step), self.df.iloc[:6].loc[:, [2]] + ) - pdt.assert_frame_equal(self.test_obj.get([2], t_stop=0.55), self.df.iloc[:6].loc[:, [2]]) + _assert_frame_equal( + self.test_obj.get([2], t_stop=0.55, t_step=t_step), self.df.iloc[:6].loc[:, [2]] + ) - pdt.assert_frame_equal(self.test_obj.get([2], t_start=0.5), self.df.iloc[5:].loc[:, [2]]) + _assert_frame_equal( + self.test_obj.get([2], t_start=0.5, t_step=t_step), self.df.iloc[5:].loc[:, [2]] + ) - pdt.assert_frame_equal( - self.test_obj.get([2], t_start=0.5, t_stop=0.8), self.df.iloc[5:9].loc[:, [2]] + _assert_frame_equal( + self.test_obj.get([2], t_start=0.5, t_stop=0.8, t_step=t_step), + self.df.iloc[5:9].loc[:, [2]], ) - pdt.assert_frame_equal( - self.test_obj.get([2, 1], t_start=0.5, t_stop=0.8), self.df.iloc[5:9].loc[:, [1, 2]] + _assert_frame_equal( + self.test_obj.get([2, 1], t_start=0.5, t_stop=0.8, t_step=t_step), + self.df.iloc[5:9].loc[:, [1, 2]], ) - pdt.assert_frame_equal( - self.test_obj.get([2, 1], t_start=0.2, t_stop=0.8), self.df.iloc[2:9].loc[:, [1, 2]] + _assert_frame_equal( + self.test_obj.get([2, 1], t_start=0.2, t_stop=0.8, t_step=t_step), + self.df.iloc[2:9].loc[:, [1, 2]], ) - pdt.assert_frame_equal( - self.test_obj.get(group={Cell.MTYPE: "L6_Y"}, t_start=0.2, t_stop=0.8), + _assert_frame_equal( + self.test_obj.get(group={Cell.MTYPE: "L6_Y"}, t_start=0.2, t_stop=0.8, t_step=t_step), self.df.iloc[2:9].loc[:, [1, 2]], ) - pdt.assert_frame_equal(self.test_obj.get(group={Cell.MTYPE: "L2_X"}), self.df.loc[:, [0]]) + _assert_frame_equal( + self.test_obj.get(group={Cell.MTYPE: "L2_X"}, t_step=t_step), self.df.loc[:, [0]] + ) - pdt.assert_frame_equal(self.test_obj.get(group="Layer23"), self.df.loc[:, [0]]) + _assert_frame_equal(self.test_obj.get(group="Layer23", t_step=t_step), self.df.loc[:, [0]]) ids = CircuitNodeIds.from_arrays(["default", "default", "default2"], [0, 2, 1]) - pdt.assert_frame_equal(self.test_obj.get(group=ids), self.df.loc[:, [0, 2]]) + _assert_frame_equal(self.test_obj.get(group=ids, t_step=t_step), self.df.loc[:, [0, 2]]) - with pytest.raises(BluepySnapError): - self.test_obj.get(-1, t_start=0.2) + with pytest.raises( + BluepySnapError, match="All node IDs must be >= 0 and < 3 for population 'default'" + ): + self.test_obj.get(-1, t_start=0.2, t_step=t_step) - with pytest.raises(BluepySnapError): - self.test_obj.get(0, t_start=-1) + with pytest.raises(BluepySnapError, match="Times cannot be negative"): + self.test_obj.get(0, t_start=-1, t_step=t_step) - with pytest.raises(BluepySnapError): - self.test_obj.get([0, 2], t_start=15) + with pytest.raises(BluepySnapError, match="tstart is after the end of the range"): + self.test_obj.get([0, 2], t_start=15, t_step=t_step) - with pytest.raises(BluepySnapError): - self.test_obj.get(4) + with pytest.raises( + BluepySnapError, match="All node IDs must be >= 0 and < 3 for population 'default'" + ): + self.test_obj.get(4, t_step=t_step) + + @pytest.mark.parametrize("t_step", [0, -1, 0.0000001]) + def test_get_with_invalid_t_step(self, t_step): + match = f"Invalid t_step={t_step}. It should be None or a multiple of" + with pytest.raises(BluepySnapError, match=match): + self.test_obj.get(t_step=t_step) def test_get_partially_not_in_report(self): with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([0, 4])):