Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BBPBGLIB-712] Estimate memory usage for synapse and connection #22

Merged
merged 6 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions _benchmarks/synstats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Small script to measure memory usage of synapse objects in NEURON.
# Usage: nrniv (or special) -python synstats.py

from neuron import h
import os


def get_mem_usage():
"""
Return memory usage information in KB.
"""
with open("/proc/self/statm") as fd:
_, data_size, _ = fd.read().split(maxsplit=2)
usage_kb = float(data_size) * os.sysconf("SC_PAGE_SIZE") / 1024

return usage_kb


# Dummy class to pass to synapse objects
class SynParams:
def __getattr__(self, item):
return 1


n_inst = 1000000

h.load_file("RNGSettings.hoc")
h.load_file("Map.hoc")

map_hoc = h.Map()
RNGset = h.RNGSettings()
RNGset.interpret(map_hoc)

pc = h.ParallelContext()

sec = h.Section()
sec.push()
params_obj = SynParams()

h.load_file("AMPANMDAHelper.hoc")
mem = get_mem_usage()
AMPA_helper = [h.AMPANMDAHelper(1, params_obj, 0.5, i, 0) for i in range(n_inst)]
netcon_ampa = [pc.gid_connect(1000, helper.synapse) for helper in AMPA_helper]
mem2 = get_mem_usage()
print('Memory usage per object ProbAMPA: %f KB' % ((mem2 - mem) / n_inst))

h.load_file("GABAABHelper.hoc")
mem = get_mem_usage()
GABAAB_helper = [h.GABAABHelper(1, params_obj, 0.5, i, 0) for i in range(n_inst)]
netcon_gabaab = [pc.gid_connect(1000, helper.synapse) for helper in GABAAB_helper]
mem2 = get_mem_usage()
print('Memory usage per object ProbGABAAB: %f KB' % ((mem2 - mem) / n_inst))

h.load_file("GluSynapseHelper.hoc")
mem = get_mem_usage()
GluSynapse_helper = [h.GluSynapseHelper(1, params_obj, 0.5, i, 0, map_hoc) for i in range(n_inst)]
netcon_glu = [pc.gid_connect(1000, helper.synapse) for helper in GluSynapse_helper]
mem2 = get_mem_usage()
print('Memory usage per object GluSynapse: %f KB' % ((mem2 - mem) / n_inst))

mem = get_mem_usage()
Gap_helper = [h.Gap(0.5) for i in range(n_inst)]
mem2 = get_mem_usage()
print('Memory usage per object Gap: %f KB' % ((mem2 - mem) / n_inst))
19 changes: 18 additions & 1 deletion neurodamus/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import hashlib
import logging
import numpy
from collections import defaultdict
from collections import Counter, defaultdict
from itertools import chain
from os import path as ospath
from typing import List, Optional
Expand Down Expand Up @@ -273,6 +273,9 @@ def __init__(self, circuit_conf, target_manager, cell_manager, src_cell_manager=
self._load_offsets = False
self._src_target_filter = None # filter by src target in all_connect (E.g: GapJ)

# An internal var to enable collection of synapse statistics to a Counter
self._synapse_counter: Counter = kw.get("synapse_counter")

def __str__(self):
return "<{:s} | {:s} -> {:s}>".format(
self.__class__.__name__, str(self._src_cell_manager), str(self._cell_manager))
Expand Down Expand Up @@ -530,6 +533,11 @@ def connect_all(self, weight_factor=1, only_gids=None):
weight_factor: Factor to scale all netcon weights (default: 1)
only_gids: Create connections only for these tgids (default: Off)
"""
if self._synapse_counter is not None:
counts = self._get_conn_stats(self._src_target_filter, None)
self._synapse_counter.update(counts)
return

conn_options = {'weight_factor': weight_factor}
pop = self._cur_population

Expand Down Expand Up @@ -565,6 +573,11 @@ def connect_group(self, conn_source, conn_destination, synapse_type_restrict=Non
src_tname, dst_tname)
return

if self._synapse_counter is not None:
counts = self._get_conn_stats(src_target, dst_target)
self._synapse_counter.update(counts)
return

for sgid, tgid, syns_params, extra_params, offset in \
self._iterate_conn_params(src_target, dst_target, mod_override=mod_override):
if sgid == tgid:
Expand Down Expand Up @@ -706,6 +719,10 @@ def target_gids(gids):
pathway_repr = "Pathway {} -> {}".format(src_target.name, dst_target.name)
logging.info(" * %s. Created %d connections", pathway_repr, all_created)

def _get_conn_stats(self, _src_target, dst_target):
raw_gids = dst_target.get_local_gids(raw_gids=True) if dst_target else self._raw_gids
return self._synapse_reader.get_counts(raw_gids, group_by="syn_type_id")

# -
def get_target_connections(self, src_target_name,
dst_target_name,
Expand Down
12 changes: 9 additions & 3 deletions neurodamus/core/nodeset.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def raw_gids(self):
def final_gids(self):
return numpy.add(self.raw_gids(), self._offset, dtype="uint32")

def intersection(self, _other):
def intersection(self, _other, _raw_gids=False):
return NotImplemented

def intersects(self, other):
Expand Down Expand Up @@ -244,7 +244,7 @@ def items(self, final_gid=False):
for gid in self._gidvec:
yield gid + offset_add, self._gid_info.get(gid)

def intersection(self, other):
def intersection(self, other, raw_gids=False):
"""Computes the intersection of two NodeSet's

For nodesets to intersect they must belong to the same population and
Expand All @@ -253,6 +253,8 @@ def intersection(self, other):
if self.population_name != other.population_name:
return []
intersect = numpy.intersect1d(self.raw_gids(), other.raw_gids(), assume_unique=True)
if raw_gids:
return intersect
return numpy.add(intersect, self._offset, dtype="uint32")

def clear_cell_info(self):
Expand Down Expand Up @@ -287,7 +289,7 @@ def final_gids_iter(self):
for gid in self.raw_gids_iter():
yield gid + self._offset

def intersection(self, other: _NodeSetBase, _quick_check=False):
def intersection(self, other: _NodeSetBase, raw_gids=False, _quick_check=False):
"""Computes intersection of two nodesets.
"""
# NOTE: A _quick_check param can be set to True so that we effectively only check for
Expand All @@ -306,6 +308,10 @@ def intersection(self, other: _NodeSetBase, _quick_check=False):
if _quick_check:
return intersect
if len(intersect):
if raw_gids:
# TODO: We should change the return type to be another `SelectionNodeSet`
# Like that we could still keep ranges internally and have PROPER API to get raw ids
return numpy.add(intersect, 1, dtype=intersect.dtype)
return numpy.add(intersect, self.offset + 1, dtype=intersect.dtype)
return []

Expand Down
9 changes: 9 additions & 0 deletions neurodamus/io/synapse_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,15 @@ class CustomSynapseParameters(self.Parameters):

return conn_syn_params

def get_counts(self, raw_ids, group_by):
"""
Counts synapses and groups by the given field.
"""
edge_ids = self._population.afferent_edges(raw_ids - 1)
data = self._population.get_attribute(group_by, edge_ids)
values, counts = np.unique(data, return_counts=True)
return dict(zip(values, counts))


class SynReaderNRN(SynapseReader):
""" Synapse Reader for NRN format only, using the hdf5_reader mod.
Expand Down
45 changes: 40 additions & 5 deletions neurodamus/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import subprocess
from os import path as ospath
from collections import namedtuple, defaultdict
from collections import Counter, namedtuple, defaultdict
from contextlib import contextmanager
from shutil import copyfileobj, move

Expand All @@ -33,6 +33,7 @@
from .utils import compat
from .utils.logging import log_stage, log_verbose, log_all
from .utils.memory import trim_memory, pool_shrink, free_event_queues, print_mem_usage
from .utils.memory import SynapseMemoryUsage
from .utils.timeit import TimerManager, timeit
# Internal Plugins
from . import ngv as _ngv # NOQA
Expand Down Expand Up @@ -471,19 +472,28 @@ def create_synapses(self):
"""
log_stage("LOADING CIRCUIT CONNECTIVITY")
target_manager = self._target_manager
self._create_synapse_manager(SynapseRuleManager, self._base_circuit, target_manager,
load_offsets=self._is_ngv_run)
manager_kwa = {"load_offsets": self._is_ngv_run}

if SimConfig.dry_run:
synapse_counter = Counter()
manager_kwa["synapse_counter"] = synapse_counter

if circuit := self._base_circuit:
self._create_synapse_manager(SynapseRuleManager, circuit, target_manager, **manager_kwa)

for circuit in self._extra_circuits.values():
Engine = circuit.Engine or METypeEngine
SynManagerCls = Engine.InnerConnectivityCls
self._create_synapse_manager(SynManagerCls, circuit, target_manager,
load_offsets=self._is_ngv_run)
self._create_synapse_manager(SynManagerCls, circuit, target_manager, **manager_kwa)

log_stage("Handling projections...")
for pname, projection in SimConfig.projections.items():
self._load_projections(pname, projection)

if SimConfig.dry_run:
self._collect_display_syn_counts(synapse_counter)
return

log_stage("Configuring connections...")
for conn_conf in SimConfig.connections.values():
self._process_connection_configure(conn_conf)
Expand Down Expand Up @@ -616,6 +626,30 @@ def _find_config_file(self, filepath, path_conf_entries=(), alt_filename=None):
if self._run_conf.get(path_key)]
return find_input_file(filepath, search_paths, alt_filename)

@staticmethod
def _collect_display_syn_counts(local_syn_counter):
xelist = [local_syn_counter] + [None] * (MPI.size - 1) # send to rank0
counters = MPI.py_alltoall(xelist)
inh = exc = 0

if MPI.rank == 0:
log_stage("Synapse memory estimate (per type)")
master_counter = Counter()
for c in counters:
master_counter.update(c)

for synapse_type, count in master_counter.items():
logging.debug(f" - {synapse_type}: {count}")
if synapse_type < 100:
inh += count
if synapse_type >= 100:
exc += count
logging.info(" - Estimated synapse memory usage (KB):")
in_mem = SynapseMemoryUsage.get_memory_usage(inh, "ProbGABAAB")
ex_mem = SynapseMemoryUsage.get_memory_usage(exc, "ProbAMPANMDA")
logging.info(f" - Inhibitory: {in_mem}")
logging.info(f" - Excitatory: {ex_mem}")

# -
@mpi_no_errors
@timeit(name="Enable Stimulus")
Expand Down Expand Up @@ -1672,6 +1706,7 @@ def __init__(
if SimConfig.dry_run:
self.load_targets()
self.create_cells()
self.create_synapses()
self._init_ok = True
return

Expand Down
20 changes: 12 additions & 8 deletions neurodamus/target_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .core import MPI, NeurodamusCore as Nd
from .core.configuration import ConfigurationError, SimConfig, GlobalConfig, find_input_file
from .core.nodeset import NodeSet, SelectionNodeSet
from .core.nodeset import _NodeSetBase, NodeSet, SelectionNodeSet
from .utils import compat
from .utils.logging import log_verbose

Expand Down Expand Up @@ -453,7 +453,7 @@ def get_offset(self, *_):


class NodesetTarget(_TargetInterface, _HocTargetInterface):
def __init__(self, name, nodesets: List[NodeSet], local_nodes=None, **_kw):
def __init__(self, name, nodesets: List[_NodeSetBase], local_nodes=None, **_kw):
self.name = name
self.nodesets = nodesets
self.local_nodes = local_nodes
Expand Down Expand Up @@ -514,22 +514,26 @@ def get_hoc_target(self):
def update_local_nodes(self, local_nodes):
self.local_nodes = local_nodes

def get_local_gids(self):
def get_local_gids(self, raw_gids=False):
"""Return the list of target gids in this rank (with offset)
"""
assert self.local_nodes, "Local nodes not set"

def pop_gid_intersect(local_nodes):
for n in self.nodesets:
if n.population_name == local_nodes.population_name:
return n.intersection(local_nodes)
def pop_gid_intersect(nodeset: _NodeSetBase, raw_gids=False):
for local_ns in self.local_nodes:
if local_ns.population_name == nodeset.population_name:
return nodeset.intersection(local_ns, raw_gids)
return []

if raw_gids:
assert len(self.nodesets) != 1, "Multiple populations when asking for raw gids"
return pop_gid_intersect(self.nodesets[0], raw_gids=True)

# If target is named Mosaic, basically we don't filter and use local_gids
if self.name == "Mosaic" or self.name.startswith("Mosaic#"):
gids_groups = tuple(n.final_gids() for n in self.local_nodes)
else:
gids_groups = tuple(pop_gid_intersect(nodes) for nodes in self.local_nodes)
gids_groups = tuple(pop_gid_intersect(ns) for ns in self.nodesets)

return numpy.concatenate(gids_groups)

Expand Down
17 changes: 17 additions & 0 deletions neurodamus/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,20 @@ def get_mem_usage():
usage_mb = float(data_size) * os.sysconf("SC_PAGE_SIZE") / 1024 ** 2

return usage_mb


class SynapseMemoryUsage:
''' A small class that works as a lookup table
for the memory used by each type of synapse.
The values are in KB. The values cannot be set by the user.
'''
_synapse_memory_usage = {
"ProbAMPANMDA": 1.7,
"ProbGABAAB": 2.0,
"Gap": 2.0,
"Glue": 0.5
}

@classmethod
def get_memory_usage(cls, count, synapse_type):
return count * cls._synapse_memory_usage[synapse_type]