Redesign metal op for a memory pool mechanism
hubertdelajonquieresonos committed Oct 18, 2024
1 parent 8ef7a78 commit 5b021b1
Showing 27 changed files with 911 additions and 594 deletions.
202 changes: 0 additions & 202 deletions core/src/model/memory.rs
@@ -1,4 +1,3 @@
use std::fmt;
use super::*;
use crate::prelude::*;
use std::collections::HashSet;
@@ -63,204 +62,3 @@ where
    }
    Ok(mem_by_steps)
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ScopedNodeMemory {
    pub node: usize,
    pub scope: Scope,
    pub mem_size: usize,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Scope {
    pub start: usize,
    pub end: usize,
}

impl Scope {
    fn is_disjoint(&self, other: &Scope) -> bool {
        self.start >= other.end || other.start >= self.end
    }

    pub fn is_alive_at_step(&self, step: usize) -> bool {
        self.start <= step && step < self.end
    }

    pub fn len(&self) -> usize {
        self.end - self.start
    }
}

pub fn eval_scoped_node_memories<F, O, Flushable>(
    model: &Graph<F, O>,
    order: &[usize],
    flushable: Flushable) -> TractResult<TVec<ScopedNodeMemory>>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
    Flushable: Fn(&Node<F, O>) -> bool,
{

    let outputs = model.output_outlets()?.to_vec();
    let flush_lists = super::order::build_flush_list(model, order, &outputs, &flushable);
    let mut scoped_nodes = tvec![];

    for (step, n) in order.iter().enumerate() {
        let scope_start = step;
        let scope_end = flush_lists.iter().enumerate().find(|(_step, flush_list)| flush_list.contains(n))
            .map(|it| usize::min(it.0 + 1, order.len()));

        let Some(scope_end) = scope_end else { continue; };

        let out_facts = model
            .node_output_facts(*n)?
            .into_iter()
            .map(|it| it.to_typed_fact())
            .collect::<TractResult<TVec<_>>>()?;

        scoped_nodes.push(ScopedNodeMemory {
            node: *n,
            scope: Scope {
                start: scope_start,
                end: scope_end
            },
            mem_size: out_facts.iter().map(|it| it.mem_size()).sum::<TDim>().to_usize()?,
        })
    }

    Ok(scoped_nodes)
}


#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MemoryPlan {
    pub by_partition: Vec<Vec<ScopedNodeMemory>>,
    pub by_steps: Vec<Vec<Option<ScopedNodeMemory>>>,
}

impl MemoryPlan {

    pub fn size_by_partition(&self) -> Vec<usize> {
        self.by_partition.iter().map(|it| it.iter().map(|it| it.mem_size).max().unwrap_or(0)).collect()
    }

    pub fn memory_size(&self) -> usize {
        self.by_partition.iter().map(|it| it.iter().map(|it| it.mem_size).max().unwrap_or(0)).sum()
    }
}

impl fmt::Display for MemoryPlan {
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        for (step, mem_step) in self.by_steps.iter().enumerate() {
            writeln!(
                fmt,
                "step: {:5} => |{}|",
                step,
                mem_step.iter()
                    .map(|n| -> String { n.as_ref().map(|it| format!("{:^7}", it.node)).unwrap_or(format!("{:^7}", "*"))})
                    .collect::<Vec<String>>()
                    .join("|")
            )?;
        }
        writeln!(fmt, "memory_size: {}", self.memory_size())?;
        Ok(())
    }
}


pub fn eval_memory_plan<F, O, Flushable>(
    model: &Graph<F, O>,
    order: &[usize],
    flushable: Flushable) -> TractResult<MemoryPlan>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
    Flushable: Fn(&Node<F, O>) -> bool,
{

    let mut scoped_node_memories = eval_scoped_node_memories(&model, &order, flushable)?;
    scoped_node_memories
        .sort_by(|lhs, rhs| {
            lhs.scope.end.cmp(&rhs.scope.end).reverse()
                .then(lhs.scope.len().cmp(&rhs.scope.len()).reverse())
                .then(lhs.mem_size.cmp(&rhs.mem_size).reverse())
        });

    let mut partitions: Vec<Vec<ScopedNodeMemory>> = vec![];
    for node_mem in scoped_node_memories {
        // Find partitions where node scope is disjoint from existing.
        let mut available = partitions
            .iter_mut()
            .filter(|it| it.iter().all(|n| n.scope.is_disjoint(&node_mem.scope)))
            .collect::<Vec<_>>();

        available.sort_by_key(|n| n.iter().map(|it| it.mem_size as isize * -1).sum::<isize>());

        match available.first_mut() {
            Some(available) => {
                available.push(node_mem);
            },
            None => {
                partitions.push(vec![node_mem])
            },
        }
    }

    let by_steps: Vec<Vec<Option<ScopedNodeMemory>>> = (0..order.len())
        .map(|step| {
            let mem_step: Vec<_> = partitions.iter()
                .map(|p| {
                    p.iter().find(|it| it.scope.is_alive_at_step(step)).cloned()
                })
                .collect();
            ensure!(mem_step.len() <= partitions.len());
            Ok(mem_step)
        })
        .collect::<TractResult<Vec<_>>>()?;

    Ok(MemoryPlan { by_partition: partitions, by_steps })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::konst::Const;
    use crate::internal::*;
    use crate::ops::array::Gather;
    use crate::ops::math;

    #[test]
    fn test_node_scope() -> TractResult<()> {
        let mut model = TypedModel::default();
        let b = model.add_const("b", tensor1(&[0i64; 1000]))?; // 0
        let d = model.add_const("d", tensor1(&[0i64; 100]))?; // 1
        let a = model.add_source("a", i32::fact([10]))?; // 2
        let c = model.wire_node("c", Gather::new(0), &[a, b])?[0]; // 3
        let e = model.wire_node("e", Gather::new(0), &[c, d])?[0]; // 4
        model.set_output_outlets(&[e]).unwrap();

        let order = model.eval_order()?;
        let scoped_node_memory = eval_scoped_node_memories(&model, &order, |n| !n.op_is::<Const>())?;
        let plan = eval_memory_plan(&model, &order, |n| !n.op_is::<Const>())?;

        assert_eq!(order, &[2, 0, 3, 1, 4]);

        eprintln!("{model}");
        eprintln!("{:?}", order);
        eprintln!("{:#?}", scoped_node_memory);
        eprintln!("{plan}");

        // assert!(model.eval_order_opt_ram()?[2..] == [c.node, d.node, e.node]);
        Ok(())
    }
}
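
The planner deleted above packs each node's output buffer, live over a scope of execution steps, into partitions of pairwise-disjoint scopes, so one physical allocation (sized to the largest member) can be reused across a partition. Below is a dependency-free sketch of that greedy strategy, with illustrative names rather than the tract API: sort scoped buffers by descending scope end, scope length, and size, then drop each one into the already-largest compatible partition.

```rust
// Standalone sketch of the greedy scope-partitioning idea (names are illustrative).
#[derive(Debug, Clone)]
struct ScopedMem {
    node: usize,
    start: usize, // first step at which the buffer is live
    end: usize,   // first step at which the buffer is dead
    mem_size: usize,
}

impl ScopedMem {
    fn is_disjoint(&self, other: &ScopedMem) -> bool {
        self.start >= other.end || other.start >= self.end
    }
}

fn plan(mut mems: Vec<ScopedMem>) -> Vec<Vec<ScopedMem>> {
    // Same ordering as the deleted code: latest-ending, then longest, then largest first.
    mems.sort_by(|l, r| {
        r.end
            .cmp(&l.end)
            .then((r.end - r.start).cmp(&(l.end - l.start)))
            .then(r.mem_size.cmp(&l.mem_size))
    });

    let mut partitions: Vec<Vec<ScopedMem>> = vec![];
    for m in mems {
        // Candidate partitions are those whose members never overlap `m` in time;
        // prefer the one already holding the most memory so large buffers get reused.
        let best = partitions
            .iter()
            .enumerate()
            .filter(|(_, p)| p.iter().all(|n| n.is_disjoint(&m)))
            .max_by_key(|(_, p)| p.iter().map(|n| n.mem_size).sum::<usize>())
            .map(|(i, _)| i);
        match best {
            Some(i) => partitions[i].push(m),
            None => partitions.push(vec![m]),
        }
    }
    partitions
}

fn main() {
    // Two buffers with disjoint scopes share a partition; the overlapping one gets its own.
    let parts = plan(vec![
        ScopedMem { node: 0, start: 0, end: 2, mem_size: 4096 },
        ScopedMem { node: 1, start: 2, end: 4, mem_size: 1024 },
        ScopedMem { node: 2, start: 1, end: 3, mem_size: 512 },
    ]);
    let peak: usize =
        parts.iter().map(|p| p.iter().map(|m| m.mem_size).max().unwrap_or(0)).sum();
    println!("{parts:?}");
    println!("partitions: {}, peak bytes: {}", parts.len(), peak);
}
```

The peak figure computed here corresponds to `MemoryPlan::memory_size` above: the sum over partitions of each partition's largest member.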


2 changes: 1 addition & 1 deletion core/src/model/mod.rs
@@ -38,9 +38,9 @@ use std::str;

mod fact;
mod graph;
pub mod memory;
mod node;
pub mod order;
pub mod memory;
mod patch;
mod rewriter;
pub mod translator;
52 changes: 21 additions & 31 deletions metal/src/context.rs
@@ -1,17 +1,13 @@
use crate::func_constants::ConstantValues;
use crate::kernels::matmul::mps;
pub use crate::kernels::{LibraryContent, LibraryName};
use crate::tensor::MetalArena;
pub use crate::tensor::MetalTensor;
use crate::{MetalMemoryPool, MetalTensor};
use core::cell::Ref;
use metal::Buffer;
use metal::MTLResourceOptions;
use metal::NSUInteger;
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::cell::RefCell;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::{OnceLock, RwLock};
use std::sync::{Arc, OnceLock, RwLock};

use anyhow::{anyhow, Context, Result};
use metal::{
@@ -198,7 +194,7 @@ pub struct MetalContext {
    command_buffer: RefCell<CommandBuffer>,
    command_buffer_used: RefCell<usize>,
    command_buffer_id: AtomicUsize,
    arena: RefCell<Option<MetalArena>>,
    mem_pool: RefCell<Option<MetalMemoryPool>>,
    retained_tensors: RefCell<Vec<MetalTensor>>,
}

@@ -220,44 +216,37 @@ impl MetalContext {
            command_buffer_used: RefCell::new(0),
            command_buffer_id: AtomicUsize::new(0),
            retained_tensors: RefCell::new(vec![]),
            arena: RefCell::new(None),
            mem_pool: RefCell::new(None),
            shared,
        }
    }

    /// Execute callback inside a MetalArena for MetalTensor allocation used during
    /// Metal kernel execution. When the arena is full, MetalTensor are allocated in the heap and available
    /// to kernels.
    pub fn execute_in_arena<T>(
    /// Execute callback inside a MetalMemoryPool for MetalTensor allocation used during
    /// Metal kernel execution.
    pub fn execute_in_mem_pool<T>(
        &self,
        arena: MetalArena,
        mem_pool: MetalMemoryPool,
        exe: impl FnOnce() -> Result<T>,
    ) -> Result<(MetalArena, T)> {
    ) -> Result<(MetalMemoryPool, T)> {
        anyhow::ensure!(
            self.arena.borrow().is_none(),
            "Cannot execute inside an arena because an MetalArena is already in use"
            self.mem_pool.borrow().is_none(),
            "Cannot execute inside a memory pool because a MetalMemoryPool is already in use"
        );
        *self.arena.borrow_mut() = Some(arena);
        *self.mem_pool.borrow_mut() = Some(mem_pool);
        let res = (exe)()?;
        let arena =
            self.arena.borrow_mut().take().ok_or_else(|| {
                anyhow!("Unexpected None arena while executing inside a metal arena")
            })?;
        log::debug!("MetalArena: {:.3} %", arena.used_capacity() as f32 / arena.capacity() as f32);
        arena.try_reset();
        log::debug!(
            "MetalArena after reset: {:.3} %",
            arena.used_capacity() as f32 / arena.capacity() as f32
        );
        Ok((arena, res))
        let mem_pool = self.mem_pool.borrow_mut().take().ok_or_else(|| {
            anyhow!("Unexpected None memory pool while executing inside a metal memory pool")
        })?;
        mem_pool.reset();
        Ok((mem_pool, res))
    }

    pub fn device(&self) -> &Device {
        &self.shared.device
    }

    pub fn memory_arena(&self) -> Ref<'_, Option<MetalArena>> {
        self.arena.borrow()
    pub fn memory_pool(&self) -> Ref<'_, Option<MetalMemoryPool>> {
        self.mem_pool.borrow()
    }

    pub fn shared_context(&self) -> &SharedMetalContext {
@@ -320,6 +309,7 @@ impl MetalContext {
        self.retained_tensors.borrow_mut().clear();

        *command_buffer = self.command_queue.new_command_buffer().to_owned();
        command_buffer.enqueue();
        *command_buffer_used = 0;
        self.command_buffer_id.fetch_add(1, Ordering::Relaxed);
        Ok(())
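
The new `execute_in_mem_pool` follows a check-install-run-take-back pattern around the `RefCell<Option<MetalMemoryPool>>` field: refuse re-entry if a pool is already installed, install the pool, run the closure, then take the pool back and reset it. Here is a minimal, dependency-free sketch of that same pattern with toy `Pool`/`Context` types, not the tract-metal API:

```rust
use std::cell::RefCell;

// Toy stand-in for MetalMemoryPool; `reset` mirrors the pool reset after execution.
struct Pool {
    used: usize,
}

impl Pool {
    fn reset(&mut self) {
        self.used = 0;
    }
}

struct Context {
    mem_pool: RefCell<Option<Pool>>,
}

impl Context {
    // Install the pool for the duration of `exe`, then hand it back, reset.
    fn execute_in_mem_pool<T>(
        &self,
        pool: Pool,
        exe: impl FnOnce() -> Result<T, String>,
    ) -> Result<(Pool, T), String> {
        if self.mem_pool.borrow().is_some() {
            return Err("a memory pool is already in use".to_string());
        }
        *self.mem_pool.borrow_mut() = Some(pool);
        let res = exe()?;
        let mut pool = self
            .mem_pool
            .borrow_mut()
            .take()
            .ok_or_else(|| "pool vanished during execution".to_string())?;
        pool.reset();
        Ok((pool, res))
    }
}

fn main() -> Result<(), String> {
    let ctx = Context { mem_pool: RefCell::new(None) };
    let (pool, out) = ctx.execute_in_mem_pool(Pool { used: 0 }, || {
        // A real op would allocate its temporary MetalTensors from the installed pool here.
        Ok(42)
    })?;
    assert_eq!(pool.used, 0);
    println!("result: {out}");
    Ok(())
}
```

Returning the pool to the caller, as the real method does, keeps ownership explicit and lets the same pool be reused across successive runs.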
8 changes: 8 additions & 0 deletions metal/src/fact.rs
@@ -27,6 +27,14 @@ impl MetalFact {
        Self { kind: MetalFactKind::Temporary, ..self }
    }

    pub fn is_temporary(&self) -> bool {
        matches!(self.kind, MetalFactKind::Temporary)
    }

    pub fn is_shared(&self) -> bool {
        matches!(self.kind, MetalFactKind::Shared)
    }

    pub fn into_typed_fact(self) -> TypedFact {
        self.fact
    }
6 changes: 5 additions & 1 deletion metal/src/lib.rs
@@ -6,7 +6,9 @@ pub mod encoder;
pub mod fact;
pub mod func_constants;
pub mod kernels;
pub mod memory;
pub mod ops;
pub mod plan;
pub mod rewrite_rules;
pub mod tensor;
pub mod transform;
@@ -15,7 +17,9 @@ pub mod utils;
pub use crate::context::{MetalContext, METAL_CONTEXT};
use crate::func_constants::{ConstantValues, Value};
pub use crate::kernels::{matmul::MetalGemmImplKind, LibraryContent, LibraryName};
pub use crate::tensor::{MetalArena, MetalTensor, MetalTensorExt};
pub use crate::memory::MetalMemoryPool;
pub use crate::plan::MetalPlanState;
pub use crate::tensor::{MetalTensor, MetalTensorExt};
pub use crate::transform::MetalTransform;
use anyhow::Result;
pub use fact::MetalFact;
5 changes: 5 additions & 0 deletions metal/src/memory/mod.rs
@@ -0,0 +1,5 @@
mod pool;
mod schema;

pub use pool::*;
pub use schema::*;
(Diffs for the remaining changed files are not shown here.)
