Skip to content

Commit

Permalink
Improved generator to sort on the fly
Browse files Browse the repository at this point in the history
  • Loading branch information
st4rl3ss committed Jan 30, 2024
1 parent b4cf01f commit 6e8d575
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions neurodamus/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,20 @@ def distribute_cells(dry_run_stats, num_ranks) -> (dict, dict):

# Prepare a list of tuples (cell_id, memory_load)
# We sum the memory load of the cell type and the average number of synapses per cell
# cells = [(gid, dry_run_stats.metype_memory[cell_type] +
# average_syns_mem_per_cell[cell_type])
# for cell_type, gids in dry_run_stats.metype_gids.items() for gid in gids]
# # Distribute cells with higher memory load first
# cells.sort(key=lambda x: x[1], reverse=True)
logging.debug("Generating cells...")
logging.debug("Creating generator...")

def generate_cells():
heap = []
for cell_type, gids in dry_run_stats.metype_gids.items():
for gid in gids:
yield gid, dry_run_stats.metype_memory[cell_type] + average_syns_mem_per_cell[cell_type]
memory_usage = dry_run_stats.metype_memory[cell_type] + average_syns_mem_per_cell[cell_type]
# Use negative memory usage as the priority for descending order
heapq.heappush(heap, (-memory_usage, gid))

# Yield from the heap in sorted order
while heap:
memory_usage, gid = heapq.heappop(heap)
yield gid, -memory_usage

# Initialize structures
logging.debug("Initializing structures...")
Expand All @@ -218,7 +221,7 @@ def generate_cells():

# Start distributing cells across ranks starting with the ones with higher memory load
logging.debug("Distributing cells across ranks...")
for cell_id, memory in sorted(generate_cells(), key=lambda x: x[1], reverse=True):
for cell_id, memory in generate_cells():
# Get the rank with the lowest memory load
total_memory, rank_id = heapq.heappop(ranks)
logging.debug("Assigning cell %d to rank %d", cell_id, rank_id)
Expand Down Expand Up @@ -247,9 +250,9 @@ def print_allocation_stats(rank_allocation, rank_memory):
print("Total memory per rank: ", rank_memory)
import statistics
values = list(rank_memory.values())
print("Mean: ", statistics.mean(values))
print("Median: ", statistics.median(values))
print("Stdev: ", statistics.stdev(values))
print("Mean: ", round(statistics.mean(values)))
print("Median: ", round(statistics.median(values)))
print("Stdev: ", round(statistics.stdev(values)))


class SynapseMemoryUsage:
Expand Down

0 comments on commit 6e8d575

Please sign in to comment.