Skip to content

Commit

Permalink
Added intersection() method to circuit ids (#192)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Joni Herttuainen <[email protected]>
  • Loading branch information
HDictus and Joni Herttuainen authored Jun 23, 2023
1 parent 08daf28 commit c833b8e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Changelog
Version v1.0.7
--------------

New Features
~~~~~~~~~~~~
- Added ``CircuitIds.intersection`` to take the intersection of two ``CircuitIds``.

Bug Fixes
~~~~~~~~~
- Fix CircuitIds.sample() to always return different samples.
Expand Down
12 changes: 12 additions & 0 deletions bluepysnap/circuit_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,18 @@ def unique(self, inplace=False):
"""
return self._apply(lambda x: x.unique(), inplace)

def intersection(self, circuit_ids, inplace=False):
"""Take the intersection of this CircuitIds and the input.
The index of the resulting object is sorted if ``inplace=False``.
Otherwise, the orginal order of the index is kept.
Args:
circuit_ids (CircuitIds): The CircuitIds to intersect with.
inplace (bool): if set to True, do the transformation inplace.
"""
return self._apply(lambda x: x.intersection(circuit_ids.index), inplace)

def to_csv(self, filepath):
"""Save CircuitIds to csv format."""
self.index.to_frame(index=False).to_csv(filepath, index=False)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_circuit_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,23 @@ def test_append(self):
)
assert test_obj == expected

def test_intersection(self):
test_obj = self.ids_cls.from_tuples(_multi_index(), sort_index=False)
other = self.ids_cls.from_tuples([("b", 0), ("a", 3), ("a", 2)], sort_index=False)
expected = self.ids_cls.from_tuples([("a", 2), ("b", 0)])
res = test_obj.intersection(other)

# res should be sorted when inplace=False
assert res == expected
assert test_obj != expected

res = test_obj.intersection(other, inplace=True)
assert res is None

# test_obj index should not be sorted when inplace=True
assert test_obj != expected
assert all(expected.index == test_obj.index.sort_values())

def test_sample(self):
with patch("numpy.random.choice", return_value=np.array([0, 3])):
tested = self.test_obj_unsorted.sample(2, inplace=False)
Expand Down

0 comments on commit c833b8e

Please sign in to comment.