Skip to content

Commit

Permalink
Fixed bug in batch assignment, refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
st4rl3ss committed Feb 7, 2024
1 parent 09eaddc commit 3a33bc4
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 101 deletions.
5 changes: 1 addition & 4 deletions neurodamus/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from .utils import compat
from .utils.logging import log_stage, log_verbose, log_all
from .utils.memory import DryRunStats, trim_memory, pool_shrink, free_event_queues, print_mem_usage
from .utils.memory import print_allocation_stats, export_allocation_stats, distribute_cells
from .utils.timeit import TimerManager, timeit
from .core.coreneuron_configuration import CoreConfig
# Internal Plugins
Expand Down Expand Up @@ -1966,9 +1965,7 @@ def run(self):
self._dry_run_stats.display_node_suggestions()
ranks = int(SimConfig.num_target_ranks)
self._dry_run_stats.collect_all_mpi()
allocation, total_memory_per_rank = distribute_cells(self._dry_run_stats, ranks)
print_allocation_stats(allocation, total_memory_per_rank)
export_allocation_stats(allocation, "allocation.bin")
self._dry_run_stats.distribute_cells(ranks)
return
if not SimConfig.simulate_model:
self.sim_init()
Expand Down
189 changes: 98 additions & 91 deletions neurodamus/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
import psutil
import multiprocessing
import heapq
import pickle

from ..core import MPI, NeurodamusCore as Nd, run_only_rank0
from .compat import Vector
from collections import defaultdict

import numpy as np
Expand Down Expand Up @@ -162,83 +164,6 @@ def pretty_printing_memory_mb(memory_mb):
return "%.2lf PB" % (memory_mb / 1024 ** 3)


@run_only_rank0
def distribute_cells(dry_run_stats, num_ranks, batch_size=10) -> (dict, dict):
"""
Distributes cells across ranks based on their estimated memory load.
Cells are consumed in the iteration order of ``dry_run_stats.metype_gids``,
``batch_size`` at a time, and every batch is handed to the rank that currently
has the smallest accumulated memory load (greedy, min-heap based).
No sorting by memory load is performed.
Args:
dry_run_stats (DryRunStats): A DryRunStats object.
num_ranks (int): The number of ranks.
batch_size (int): Number of cells accumulated before a rank is picked.
Returns:
rank_allocation (dict): A dictionary keyed by population, whose values are
dictionaries mapping rank IDs to lists of cell IDs.
rank_memory (dict): A dictionary keyed by population, whose values are
dictionaries mapping rank IDs to the memory load recorded for that rank.
"""
logging.debug("Distributing cells across %d ranks", num_ranks)

# Check inputs
dry_run_stats.validate_inputs_distribute(num_ranks, batch_size)

# Multiply the average number of synapses per cell by 2.0
# This is done since the biggest memory load for a synapse is 2.0 kB and at this point in the
# code we have lost the information on whether they are excitatory or inhibitory
# so we just take the biggest value to be safe. (the difference between the two is minimal)
average_syns_mem_per_cell = {k: v * 2.0 for k, v in dry_run_stats.average_syns_per_cell.items()}

# Lazily yield (cell_id, memory_load) pairs
# We sum the memory load of the cell type and the average number of synapses per cell
def generate_cells(metype_gids):
for cell_type, gids in metype_gids.items():
memory_usage = (dry_run_stats.metype_memory[cell_type] +
average_syns_mem_per_cell[cell_type])
for gid in gids:
yield gid, memory_usage

# Initialize structures
# The heap is created once, so rank loads carry over from one population to the next
ranks = [(0, i) for i in range(num_ranks)] # (total_memory, rank_id)
heapq.heapify(ranks)
rank_allocation = defaultdict(dict)
rank_memory = defaultdict(dict)

def assign_cells_to_rank(batch_memory, pop):
# Pop the least loaded rank; ties are broken by the lowest rank_id
total_memory, rank_id = heapq.heappop(ranks)
logging.debug("Assigning batch to rank %d", rank_id)
if rank_id not in rank_allocation[pop]:
rank_allocation[pop][rank_id] = []
# NOTE(review): `cell_id` is not a parameter here; it resolves to the enclosing
# loop variable at call time, so only the most recent cell of the batch is stored
# although `batch_memory` accounts for the whole batch.
rank_allocation[pop][rank_id].append(cell_id)
# Update the total memory load of the rank
total_memory += batch_memory
rank_memory[pop][rank_id] = total_memory
# Update total memory and re-add to the heap
heapq.heappush(ranks, (total_memory, rank_id))

# Start distributing cells across ranks
for pop, metype_gids in dry_run_stats.metype_gids.items():
logging.info("Distributing cells of population %s", pop)
batch = []
batch_memory = 0

for cell_id, memory in generate_cells(metype_gids):
batch.append(cell_id)
batch_memory += memory
if len(batch) == batch_size:
assign_cells_to_rank(batch_memory, pop)
batch = []
batch_memory = 0

# Assign any remaining cells in the last, potentially incomplete batch
if batch:
assign_cells_to_rank(batch_memory, pop)

return rank_allocation, rank_memory


@run_only_rank0
def print_allocation_stats(rank_allocation, rank_memory):
"""
Expand Down Expand Up @@ -269,27 +194,25 @@ def export_allocation_stats(rank_allocation, filename):
"""
Export allocation dictionary to serialized pickle file.
"""
import pickle
try:
with open(filename, 'wb') as f:
pickle.dump(rank_allocation, f)
except Exception as e:
logging.warning("Unable to export allocation stats: {}".format(e))
with open(filename, 'wb') as f:
pickle.dump(rank_allocation, f)


@run_only_rank0
def import_allocation_stats(filename):
"""
Import allocation dictionary from serialized pickle file.
"""
import pickle
try:
with open(filename, 'rb') as f:
rank_allocation = pickle.load(f)
return rank_allocation
except Exception as e:
logging.warning("Unable to import allocation stats: {}".format(e))
return None
with open(filename, 'rb') as f:
return pickle.load(f)


@run_only_rank0
def allocation_stats_exists(filename):
    """
    Report whether the allocation stats path is present on disk.

    Args:
        filename: Path to look up; passed unchanged to ``os.path.exists``.

    Returns:
        The boolean result of ``os.path.exists(filename)``.
    """
    found = os.path.exists(filename)
    return found


class SynapseMemoryUsage:
Expand All @@ -311,6 +234,7 @@ def get_memory_usage(cls, count, synapse_type):

class DryRunStats:
_MEMORY_USAGE_FILENAME = "cell_memory_usage.json"
_ALLOCATION_FILENAME = "allocation.bin"

def __init__(self) -> None:
self.metype_memory = {}
Expand Down Expand Up @@ -472,6 +396,89 @@ def display_node_suggestions(self):
logging.info("Please remember that it is suggested to use the same class of nodes "
"for both the dryrun and the actual simulation.")

@run_only_rank0
def distribute_cells(self, num_ranks, batch_size=10) -> (dict, dict):
    """
    Distributes cells across ranks based on their estimated memory load.

    Cells are consumed in the iteration order of ``self.metype_gids``,
    ``batch_size`` at a time, and every batch is handed to the rank that
    currently has the smallest accumulated memory load (greedy, min-heap based).
    No sorting by memory load is performed. The resulting allocation is printed
    and exported to ``self._ALLOCATION_FILENAME``.

    Args:
        num_ranks (int): The number of ranks.
        batch_size (int): Number of cells accumulated before a rank is picked.

    Returns:
        all_allocation (dict): A dictionary keyed by population, whose values map
            rank IDs to a ``Vector`` of the cell IDs assigned to that rank.
        all_memory (dict): A dictionary keyed by population, whose values map
            rank IDs to the load accumulated by that rank at its latest
            assignment within the population. The heap is shared by all
            populations, so loads carry over from one population to the next.
    """
    logging.debug("Distributing cells across %d ranks", num_ranks)

    # Check inputs
    self.validate_inputs_distribute(num_ranks, batch_size)

    # Multiply the average number of synapses per cell by 2.0
    # This is done since the biggest memory load for a synapse is 2.0 kB and at this point in
    # the code we have lost the information on whether they are excitatory or inhibitory
    # so we just take the biggest value to be safe. (the difference between the two is minimal)
    average_syns_mem_per_cell = {k: v * 2.0 for k, v in self.average_syns_per_cell.items()}

    # Lazily yield (cell_id, memory_load) pairs
    # We sum the memory load of the cell type and the average number of synapses per cell
    def generate_cells(metype_gids):
        for cell_type, gids in metype_gids.items():
            memory_usage = (self.metype_memory[cell_type] +
                            average_syns_mem_per_cell[cell_type])
            for gid in gids:
                yield gid, memory_usage

    # Initialize structures
    ranks = [(0, i) for i in range(num_ranks)]  # (total_memory, rank_id)
    heapq.heapify(ranks)
    all_allocation = {}
    all_memory = {}

    def assign_cells_to_rank(rank_allocation, rank_memory, batch, batch_memory):
        # Pop the least loaded rank; ties are broken by the lowest rank_id
        total_memory, rank_id = heapq.heappop(ranks)
        logging.debug("Assigning batch to rank %d", rank_id)
        # The whole batch is stored, matching the memory accounted below
        rank_allocation[rank_id].extend(batch)
        # Update the total memory load of the rank
        total_memory += batch_memory
        rank_memory[rank_id] = total_memory
        # Update total memory and re-add to the heap
        heapq.heappush(ranks, (total_memory, rank_id))

    # Start distributing cells across ranks
    for pop, metype_gids in self.metype_gids.items():
        logging.info("Distributing cells of population %s", pop)
        rank_allocation = defaultdict(Vector)
        rank_memory = {}
        batch = []
        batch_memory = 0

        for cell_id, memory in generate_cells(metype_gids):
            batch.append(cell_id)
            batch_memory += memory
            if len(batch) == batch_size:
                assign_cells_to_rank(rank_allocation, rank_memory, batch, batch_memory)
                batch = []
                batch_memory = 0

        # Assign any remaining cells in the last, potentially incomplete batch
        if batch:
            assign_cells_to_rank(rank_allocation, rank_memory, batch, batch_memory)

        all_allocation[pop] = rank_allocation
        all_memory[pop] = rank_memory

    # Print and export allocation stats
    print_allocation_stats(all_allocation, all_memory)
    export_allocation_stats(all_allocation, self._ALLOCATION_FILENAME)

    # Return the per-population memory map. `rank_memory` would only be the last
    # population's dict and is unbound when there are no populations at all.
    return all_allocation, all_memory

def validate_inputs_distribute(self, num_ranks, batch_size):
assert isinstance(num_ranks, int), "num_ranks must be an integer"
assert num_ranks > 0, "num_ranks must be a positive integer"
Expand Down
23 changes: 17 additions & 6 deletions tests/integration-e2e/test_dry_run_worflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from neurodamus.utils.memory import (distribute_cells,
export_allocation_stats,
import_allocation_stats)
from neurodamus.utils.memory import import_allocation_stats, export_allocation_stats


def convert_to_standard_types(obj):
    """Build a plain dict-of-dict-of-list copy of a two-level mapping.

    The outer keys and inner keys are kept as they are; every inner value
    (e.g. a Vector) is materialised with ``list(...)``, and each inner mapping
    (e.g. a defaultdict) becomes a regular ``dict``.
    """
    return {
        node: {key: list(vector) for key, vector in vectors.items()}
        for node, vectors in obj.items()
    }


def test_dry_run_workflow(USECASE3):
Expand Down Expand Up @@ -29,9 +35,14 @@ def test_dry_run_workflow(USECASE3):
assert nd._dry_run_stats.suggest_nodes(0.3) > 0

# Test that the allocation works and can be saved and loaded
rank_allocation, _ = distribute_cells(nd._dry_run_stats, 2)
rank_allocation, _ = nd._dry_run_stats.distribute_cells(2)
export_allocation_stats(rank_allocation, USECASE3 / "allocation.bin")
rank_allocation = import_allocation_stats(USECASE3 / "allocation.bin")
rank_allocation_standard = convert_to_standard_types(rank_allocation)

expected_items = {
'NodeA': {0: [1, 2, 3]},
'NodeB': {1: [1, 2]}
}

expected_items = {'NodeA': {0: [3]}, 'NodeB': {1: [2]}}
assert rank_allocation == expected_items
assert rank_allocation_standard == expected_items

0 comments on commit 3a33bc4

Please sign in to comment.