Skip to content

Commit

Permalink
[BBPBGLIB-1044] Sonata Replay (#8)
Browse files Browse the repository at this point in the history
* implement a simple reader for Sonata spikes
* Integrate with replay manager
  • Loading branch information
ferdonline authored Aug 2, 2023
1 parent f388c35 commit 9606979
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 22 deletions.
17 changes: 11 additions & 6 deletions neurodamus/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .cell_distributor import LoadBalance, LoadBalanceMode
from .connection_manager import SynapseRuleManager, edge_node_pop_names
from .gap_junction import GapJunctionManager
from .replay import SpikeManager
from .replay import MissingSpikesPopulationError, SpikeManager
from .stimulus_manager import StimulusManager
from .modification_manager import ModificationManager
from .neuromodulation_manager import NeuroModulationManager
Expand Down Expand Up @@ -734,7 +734,6 @@ def enable_replay(self):
def _enable_replay(self, source, target, stim_conf, tshift=.0, delay=.0,
connectivity_type=None):
spike_filepath = find_input_file(stim_conf["SpikeFile"])
spike_manager = SpikeManager(spike_filepath, tshift) # Disposable
ptype_cls = EngineBase.connection_types.get(connectivity_type)
src_target = self.target_manager.get_target(source)
dst_target = self.target_manager.get_target(target)
Expand All @@ -743,6 +742,13 @@ def _enable_replay(self, source, target, stim_conf, tshift=.0, delay=.0,
pop_offsets, alias_pop = CircuitManager.read_population_offsets(read_virtual_pop=True)

for src_pop in src_target.population_names:
try:
log_verbose("Loading replay spikes for population '%s'", src_pop)
spike_manager = SpikeManager(spike_filepath, tshift, src_pop) # Disposable
except MissingSpikesPopulationError:
logging.info(" > No replay for src population: '%s'", src_pop)
continue

for dst_pop in dst_target.population_names:
src_pop_str, dst_pop_str = src_pop or "(base)", dst_pop or "(base)"

Expand All @@ -751,10 +757,9 @@ def _enable_replay(self, source, target, stim_conf, tshift=.0, delay=.0,
else pop_offsets[alias_pop[src_pop]]
else:
conn_manager = self._circuits.get_edge_manager(src_pop, dst_pop, ptype_cls)
if not conn_manager:
logging.error("No edge manager found among populations %s -> %s",
src_pop_str, dst_pop_str)
raise ConfigurationError("Unknown replay pathway. Check Source / Target")
if not conn_manager and SimConfig.cli_options.restrict_connectivity >= 1:
continue
assert conn_manager, f"Missing edge manager for {src_pop_str} -> {dst_pop_str}"
src_pop_offset = conn_manager.src_pop_offset

logging.info("=> Population pathway %s -> %s. Source offset: %d",
Expand Down
28 changes: 22 additions & 6 deletions neurodamus/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
"""
from __future__ import absolute_import
import os
import h5py
import logging
import numpy
from .utils.logging import log_verbose
from .utils.multimap import GroupedMultiMap
from .utils.timeit import timeit


class SpikeManager(object):
class SpikeManager:
""" Holds and manages gid spike time information, specially for Replay.
A SynapseReplay stim can be used for a single gid that has all the synapses instantiated.
Expand All @@ -22,7 +23,7 @@ class SpikeManager(object):
_ascii_spike_dtype = [('time', 'double'), ('gid', 'uint32')]

@timeit(name="Replay init")
def __init__(self, spike_filename, delay=0):
def __init__(self, spike_filename, delay=0, population=None):
"""Constructor for SynapseReplay.
Args:
Expand All @@ -32,10 +33,10 @@ def __init__(self, spike_filename, delay=0):
"""
self._gid_fire_events = None
# Nd.distributedSpikes = 0 # Wonder the effects of this
self.open_spike_file(spike_filename, delay)
self.open_spike_file(spike_filename, delay, population)

#
def open_spike_file(self, filename, delay):
def open_spike_file(self, filename, delay, population=None):
"""Opens a given spike file.
Args:
Expand All @@ -46,7 +47,9 @@ def open_spike_file(self, filename, delay):
# TODO: filename should be able to handle relative paths,
# using the Run.CurrentDir as an initial path
# _read_spikes_xxx shall return numpy arrays
if filename.endswith(".bin"):
if filename.endswith(".h5"):
tvec, gidvec = self._read_spikes_sonata(filename, population)
elif filename.endswith(".bin"):
tvec, gidvec = self._read_spikes_binary(filename)
else:
tvec, gidvec = self._read_spikes_ascii(filename)
Expand All @@ -56,6 +59,15 @@ def open_spike_file(self, filename, delay):

self._store_events(tvec, gidvec)

@classmethod
def _read_spikes_sonata(cls, filename, population):
spikes_file = h5py.File(filename, "r")
# File should have been validated earlier
spikes = spikes_file.get("spikes/" + population)
if spikes is None:
raise MissingSpikesPopulationError("Spikes population not found: " + population)
return spikes["timestamps"][...], spikes["node_ids"][...]

@classmethod
def _read_spikes_ascii(cls, filename):
log_verbose("Reading ascii spike file %s", filename)
Expand Down Expand Up @@ -99,7 +111,6 @@ def _read_spikes_binary(filename):
return tvec, gidvec

#
@timeit(name="BinEvents")
def _store_events(self, tvec, gidvec):
"""Stores the events in the _gid_fire_events GroupedMultiMap.
Expand Down Expand Up @@ -150,3 +161,8 @@ def dump_ascii(self, f, gid_offset=None):
numpy.savetxt(f, expanded_ds, fmt='%.3lf\t%d')

log_verbose("Replay: Written %d entries", len(expanded_ds))


class MissingSpikesPopulationError(Exception):
"""An exception triggered when a given node population is not found, we may want to handle"""
pass
2 changes: 1 addition & 1 deletion neurodamus/target_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _is_sonata_file(file_name):
def _try_open_start_target(self, circuit):
start_target_file = os.path.join(circuit.CircuitPath, "start.target")
if not os.path.isfile(start_target_file):
logging.warning("Circuit %s start.target not available! Skipping", circuit._name)
log_verbose("Circuit %s start.target not available! Skipping", circuit._name)
else:
self.parser.open(start_target_file, False)
self._has_hoc_targets = True
Expand Down
Binary file added tests/sample_data/out.h5
Binary file not shown.
70 changes: 61 additions & 9 deletions tests/scientific/test_replay.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,43 @@
import json
import numpy
import numpy.testing as npt
import os
import pytest
from pathlib import Path
from tempfile import NamedTemporaryFile

USECASE3 = Path(__file__).parent.absolute() / "usecase3"
SAMPLE_DATA_DIR = Path(__file__).parent.parent.absolute() / "sample_data"


@pytest.mark.skipif(
not os.environ.get("NEURODAMUS_NEOCORTEX_ROOT"),
reason="Test requires loading a neocortex model to run"
)
def test_replay(sonata_config):
from neurodamus import Neurodamus
from neurodamus.core.configuration import Feature
def replay_sim_config(sonata_config, replay_file):
sonata_config["inputs"] = {
"spikeReplay": {
"module": "synapse_replay",
"input_type": "spikes",
"spike_file": str(USECASE3 / "input.dat"),
"spike_file": replay_file,
"delay": 0,
"duration": 1000,
"node_set": "nodesPopA"
"node_set": "Mosaic", # no limits!
}
}

# create a tmp json file to read usecase3/no_edge_circuit_config.json
with NamedTemporaryFile("w", suffix='.json', delete=False) as config_file:
json.dump(sonata_config, config_file)

return config_file


@pytest.mark.skipif(
not os.environ.get("NEURODAMUS_NEOCORTEX_ROOT"),
reason="Test requires loading a neocortex model to run"
)
def test_replay_sim(sonata_config):
from neurodamus import Neurodamus
from neurodamus.core.configuration import Feature

config_file = replay_sim_config(sonata_config, str(USECASE3 / "input.dat"))
nd = Neurodamus(
config_file.name,
restrict_node_populations=["NodeA"],
Expand Down Expand Up @@ -62,3 +70,47 @@ def test_replay(sonata_config):
assert numpy.allclose(times, [0.75])

os.unlink(config_file.name)


# A more comprehensive example, using Sonata replay with two populations
# ======================================================================
@pytest.mark.skipif(
not os.environ.get("NEURODAMUS_NEOCORTEX_ROOT"),
reason="Test requires loading a neocortex model to run"
)
def test_replay_sonata_spikes(sonata_config):
from neurodamus import Neurodamus
from neurodamus.core.configuration import Feature

config_file = replay_sim_config(sonata_config, str(SAMPLE_DATA_DIR / "out.h5"))
nd = Neurodamus(
config_file.name,
restrict_features=[Feature.Replay],
disable_reports=True,
cleanup_atexit=False,
logging_level=2,
)

node_managers = nd.circuits.node_managers
assert set(node_managers) == set([None, "NodeA", "NodeB"])

edges_a = nd.circuits.get_edge_manager("NodeA", "NodeA")

conn_2_1 = next(edges_a.get_connections(1, 2))
time_vec = conn_2_1._replay.time_vec.as_numpy()
assert len(time_vec) == 11
npt.assert_allclose(time_vec[:8], [0.15, 3.45, 6.975, 11.05, 15.5, 20.225, 25.175, 30.3])

conn_1_2 = next(edges_a.get_connections(2, 1))
time_vec = conn_1_2._replay.time_vec.as_numpy()
assert len(time_vec) == 11
npt.assert_allclose(time_vec[:8], [0.175, 3.025, 5.7, 8.975, 13.95, 20.15, 26.125, 31.725])

# projections get replay too
edges_a = nd.circuits.get_edge_manager("NodeA", "NodeB")
conn_1_1001 = next(edges_a.get_connections(1001, 2))
time_vec = conn_1_1001._replay.time_vec.as_numpy()
assert len(time_vec) == 11
npt.assert_allclose(time_vec[:8], [0.15, 3.45, 6.975, 11.05, 15.5, 20.225, 25.175, 30.3])

os.unlink(config_file.name)
19 changes: 19 additions & 0 deletions tests/test_replay_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest
import numpy.testing as npt
from pathlib import Path

SAMPLE_DATA_DIR = Path(__file__).parent.absolute() / "sample_data"


@pytest.mark.forked
def test_replay_manager_sonata():
from neurodamus.replay import SpikeManager, MissingSpikesPopulationError
spikes_sonata = SAMPLE_DATA_DIR / "out.h5"

timestamps, spike_gids = SpikeManager._read_spikes_sonata(spikes_sonata, "NodeA")
npt.assert_allclose(timestamps[:8], [0.1, 0.15, 0.175, 2.275, 3.025, 3.45, 4.35, 5.7])
npt.assert_equal(spike_gids[:8], [0, 2, 1, 0, 1, 2, 0, 1])

# We do an internal assertion when the population doesnt exist. Verify it works as expected
with pytest.raises(MissingSpikesPopulationError, match="Spikes population not found"):
SpikeManager._read_spikes_sonata(spikes_sonata, "wont-exist")

0 comments on commit 9606979

Please sign in to comment.