From 5b021b1ce48475c8b335c3c10368e4c30fd4c5cb Mon Sep 17 00:00:00 2001 From: Hubert de La Jonquiere Date: Fri, 18 Oct 2024 13:15:18 +0200 Subject: [PATCH] Redesign metal op for a memory pool mechanism --- core/src/model/memory.rs | 202 ------------------------- core/src/model/mod.rs | 2 +- metal/src/context.rs | 52 +++---- metal/src/fact.rs | 8 + metal/src/lib.rs | 6 +- metal/src/memory/mod.rs | 5 + metal/src/memory/pool.rs | 138 +++++++++++++++++ metal/src/memory/schema.rs | 266 +++++++++++++++++++++++++++++++++ metal/src/ops/binary.rs | 44 +++--- metal/src/ops/broadcast.rs | 31 ++-- metal/src/ops/cast.rs | 29 ++-- metal/src/ops/change_axes.rs | 41 +++-- metal/src/ops/concat.rs | 46 +++--- metal/src/ops/element_wise.rs | 33 ++-- metal/src/ops/gemm.rs | 54 ++++--- metal/src/ops/mod.rs | 60 ++++++++ metal/src/ops/new_gelu.rs | 41 ++--- metal/src/ops/rms_norm.rs | 33 ++-- metal/src/ops/silu.rs | 33 ++-- metal/src/ops/slice.rs | 47 +++--- metal/src/ops/softmax.rs | 33 ++-- metal/src/plan.rs | 81 ++++++++++ metal/src/tensor/arena.rs | 117 --------------- metal/src/tensor/arena_view.rs | 48 +++++- metal/src/tensor/mod.rs | 28 ++-- metal/src/transform.rs | 7 + metal/src/utils.rs | 20 +++ 27 files changed, 911 insertions(+), 594 deletions(-) create mode 100644 metal/src/memory/mod.rs create mode 100644 metal/src/memory/pool.rs create mode 100644 metal/src/memory/schema.rs create mode 100644 metal/src/plan.rs delete mode 100644 metal/src/tensor/arena.rs diff --git a/core/src/model/memory.rs b/core/src/model/memory.rs index 78a010d2bf..099c6e2a9b 100644 --- a/core/src/model/memory.rs +++ b/core/src/model/memory.rs @@ -1,4 +1,3 @@ -use std::fmt; use super::*; use crate::prelude::*; use std::collections::HashSet; @@ -63,204 +62,3 @@ where } Ok(mem_by_steps) } - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ScopedNodeMemory { - pub node: usize, - pub scope: Scope, - pub mem_size: usize, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct Scope { - pub start: usize, - pub end: usize, -} - -impl Scope { - fn is_disjoint(&self, other: &Scope) -> bool { - self.start >= other.end || other.start >= self.end - } - - pub fn is_alive_at_step(&self, step: usize) -> bool { - self.start <= step && step < self.end - } - - pub fn len(&self) -> usize { - self.end - self.start - } -} - -pub fn eval_scoped_node_memories( - model: &Graph, - order: &[usize], - flushable: Flushable) -> TractResult> -where - F: Fact + Clone + 'static, - O: Debug + Display + AsRef + AsMut + Clone + 'static, - Flushable: Fn(&Node) -> bool, -{ - - let outputs = model.output_outlets()?.to_vec(); - let flush_lists = super::order::build_flush_list(model, order, &outputs, &flushable); - let mut scoped_nodes = tvec![]; - - for (step, n) in order.iter().enumerate() { - let scope_start = step; - let scope_end = flush_lists.iter().enumerate().find(|(_step, flush_list)| flush_list.contains(n)) - .map(|it| usize::min(it.0 + 1, order.len())); - - let Some(scope_end) = scope_end else { continue; }; - - let out_facts = model - .node_output_facts(*n)? 
- .into_iter() - .map(|it| it.to_typed_fact()) - .collect::>>()?; - - - scoped_nodes.push(ScopedNodeMemory { - node: *n, - scope: Scope { - start: scope_start, - end: scope_end - }, - mem_size: out_facts.iter().map(|it| it.mem_size()).sum::().to_usize()?, - }) - } - - Ok(scoped_nodes) -} - - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct MemoryPlan { - pub by_partition: Vec>, - pub by_steps: Vec>>, -} - -impl MemoryPlan { - - pub fn size_by_partition(&self) -> Vec { - self.by_partition.iter().map(|it| it.iter().map(|it| it.mem_size).max().unwrap_or(0)).collect() - } - - pub fn memory_size(&self) -> usize { - self.by_partition.iter().map(|it| it.iter().map(|it| it.mem_size).max().unwrap_or(0)).sum() - } -} - -impl fmt::Display for MemoryPlan { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - for (step, mem_step) in self.by_steps.iter().enumerate() { - writeln!( - fmt, - "step: {:5} => |{}|", - step, - mem_step.iter() - .map(|n| -> String { n.as_ref().map(|it| format!("{:^7}", it.node)).unwrap_or(format!("{:^7}", "*"))}) - .collect::>() - .join("|") - )?; - - } - writeln!(fmt, "memory_size: {}", self.memory_size())?; - Ok(()) - } -} - - -pub fn eval_memory_plan( - model: &Graph, - order: &[usize], - flushable: Flushable) -> TractResult -where - F: Fact + Clone + 'static, - O: Debug + Display + AsRef + AsMut + Clone + 'static, - Flushable: Fn(&Node) -> bool, -{ - - let mut scoped_node_memories = eval_scoped_node_memories(&model, &order, flushable)?; - scoped_node_memories - .sort_by(|lhs, rhs| { - lhs.scope.end.cmp(&rhs.scope.end).reverse() - .then(lhs.scope.len().cmp(&rhs.scope.len()).reverse()) - .then(lhs.mem_size.cmp(&rhs.mem_size).reverse()) - }); - - let mut partitions: Vec> = vec![]; - for node_mem in scoped_node_memories { - // Find partitions where node scope is disjoint from existing. 
- let mut available = partitions - .iter_mut() - .filter(|it| it.iter().all(|n| n.scope.is_disjoint(&node_mem.scope))) - .collect::>(); - - available.sort_by_key(|n| n.iter().map(|it| it.mem_size as isize * -1).sum::()); - - match available.first_mut() { - Some(available) => { - available.push(node_mem); - }, - None => { - partitions.push(vec![node_mem]) - }, - } - } - - let by_steps: Vec>> = (0..order.len()) - .map(|step| { - let mem_step: Vec<_> = partitions.iter() - .map(|p| { - p.iter().find(|it| it.scope.is_alive_at_step(step)).cloned() - }) - .collect(); - ensure!(mem_step.len() <= partitions.len()); - Ok(mem_step) - }) - .collect::>>()?; - - - - Ok(MemoryPlan { by_partition: partitions, by_steps }) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::ops::konst::Const; - use crate::internal::*; - use crate::ops::array::Gather; - use crate::ops::math; - - #[test] - fn test_node_scope() -> TractResult<()> { - let mut model = TypedModel::default(); - let b = model.add_const("b", tensor1(&[0i64; 1000]))?; // 0 - let d = model.add_const("d", tensor1(&[0i64; 100]))?; // 1 - let a = model.add_source("a", i32::fact([10]))?; // 2 - let c = model.wire_node("c", Gather::new(0), &[a, b])?[0]; // 3 - let e = model.wire_node("e", Gather::new(0), &[c, d])?[0]; // 4 - model.set_output_outlets(&[e]).unwrap(); - - let order = model.eval_order()?; - let scoped_node_memory = eval_scoped_node_memories(&model, &order, |n| !n.op_is::())?; - let plan = eval_memory_plan(&model, &order, |n| !n.op_is::())?; - - - assert_eq!(order, &[2, 0, 3, 1, 4]); - - - eprintln!("{model}"); - - eprintln!("{:?}", order); - eprintln!("{:#?}", scoped_node_memory); - - eprintln!("{plan}"); - - // assert!(model.eval_order_opt_ram()?[2..] == [c.node, d.node, e.node]); - Ok(()) - } -} - - diff --git a/core/src/model/mod.rs b/core/src/model/mod.rs index 2612d063e4..6b88aae589 100644 --- a/core/src/model/mod.rs +++ b/core/src/model/mod.rs @@ -38,9 +38,9 @@ use std::str; mod fact; mod graph; +pub mod memory; mod node; pub mod order; -pub mod memory; mod patch; mod rewriter; pub mod translator; diff --git a/metal/src/context.rs b/metal/src/context.rs index d0f9266726..aa491397f2 100644 --- a/metal/src/context.rs +++ b/metal/src/context.rs @@ -1,17 +1,13 @@ use crate::func_constants::ConstantValues; use crate::kernels::matmul::mps; pub use crate::kernels::{LibraryContent, LibraryName}; -use crate::tensor::MetalArena; -pub use crate::tensor::MetalTensor; +use crate::{MetalMemoryPool, MetalTensor}; use core::cell::Ref; -use metal::Buffer; -use metal::MTLResourceOptions; -use metal::NSUInteger; +use metal::{Buffer, MTLResourceOptions, NSUInteger}; use std::cell::RefCell; use std::path::Path; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::sync::{OnceLock, RwLock}; +use std::sync::{Arc, OnceLock, RwLock}; use anyhow::{anyhow, Context, Result}; use metal::{ @@ -198,7 +194,7 @@ pub struct MetalContext { command_buffer: RefCell, command_buffer_used: RefCell, command_buffer_id: AtomicUsize, - arena: RefCell>, + mem_pool: RefCell>, retained_tensors: RefCell>, } @@ -220,44 +216,37 @@ impl MetalContext { command_buffer_used: RefCell::new(0), command_buffer_id: AtomicUsize::new(0), retained_tensors: RefCell::new(vec![]), - arena: RefCell::new(None), + mem_pool: RefCell::new(None), shared, } } - /// Execute callback inside a MetalArena for MetalTensor allocation used during - /// Metal kernel execution. When the arena is full, MetalTensor are allocated in the heap and available - /// to kernels. 
- pub fn execute_in_arena( + /// Execute callback inside a MetalMemoryPool for MetalTensor allocation used during + /// Metal kernel execution. + pub fn execute_in_mem_pool( &self, - arena: MetalArena, + mem_pool: MetalMemoryPool, exe: impl FnOnce() -> Result, - ) -> Result<(MetalArena, T)> { + ) -> Result<(MetalMemoryPool, T)> { anyhow::ensure!( - self.arena.borrow().is_none(), - "Cannot execute inside an arena because an MetalArena is already in use" + self.mem_pool.borrow().is_none(), + "Cannot execute inside a memory pool because a MetalMemoryPool is already in use" ); - *self.arena.borrow_mut() = Some(arena); + *self.mem_pool.borrow_mut() = Some(mem_pool); let res = (exe)()?; - let arena = - self.arena.borrow_mut().take().ok_or_else(|| { - anyhow!("Unexpected None arena while executing inside a metal arena") - })?; - log::debug!("MetalArena: {:.3} %", arena.used_capacity() as f32 / arena.capacity() as f32); - arena.try_reset(); - log::debug!( - "MetalArena after reset: {:.3} %", - arena.used_capacity() as f32 / arena.capacity() as f32 - ); - Ok((arena, res)) + let mem_pool = self.mem_pool.borrow_mut().take().ok_or_else(|| { + anyhow!("Unexpected None memory pool while executing inside a metal memory pool") + })?; + mem_pool.reset(); + Ok((mem_pool, res)) } pub fn device(&self) -> &Device { &self.shared.device } - pub fn memory_arena(&self) -> Ref<'_, Option> { - self.arena.borrow() + pub fn memory_pool(&self) -> Ref<'_, Option> { + self.mem_pool.borrow() } pub fn shared_context(&self) -> &SharedMetalContext { @@ -320,6 +309,7 @@ impl MetalContext { self.retained_tensors.borrow_mut().clear(); *command_buffer = self.command_queue.new_command_buffer().to_owned(); + command_buffer.enqueue(); *command_buffer_used = 0; self.command_buffer_id.fetch_add(1, Ordering::Relaxed); Ok(()) diff --git a/metal/src/fact.rs b/metal/src/fact.rs index ee98b5cec3..cf31017418 100644 --- a/metal/src/fact.rs +++ b/metal/src/fact.rs @@ -27,6 +27,14 @@ impl MetalFact { Self { kind: MetalFactKind::Temporary, ..self } } + pub fn is_temporary(&self) -> bool { + matches!(self.kind, MetalFactKind::Temporary) + } + + pub fn is_shared(&self) -> bool { + matches!(self.kind, MetalFactKind::Shared) + } + pub fn into_typed_fact(self) -> TypedFact { self.fact } diff --git a/metal/src/lib.rs b/metal/src/lib.rs index 0c204fe4c1..1a5e35ab0e 100644 --- a/metal/src/lib.rs +++ b/metal/src/lib.rs @@ -6,7 +6,9 @@ pub mod encoder; pub mod fact; pub mod func_constants; pub mod kernels; +pub mod memory; pub mod ops; +pub mod plan; pub mod rewrite_rules; pub mod tensor; pub mod transform; @@ -15,7 +17,9 @@ pub mod utils; pub use crate::context::{MetalContext, METAL_CONTEXT}; use crate::func_constants::{ConstantValues, Value}; pub use crate::kernels::{matmul::MetalGemmImplKind, LibraryContent, LibraryName}; -pub use crate::tensor::{MetalArena, MetalTensor, MetalTensorExt}; +pub use crate::memory::MetalMemoryPool; +pub use crate::plan::MetalPlanState; +pub use crate::tensor::{MetalTensor, MetalTensorExt}; pub use crate::transform::MetalTransform; use anyhow::Result; pub use fact::MetalFact; diff --git a/metal/src/memory/mod.rs b/metal/src/memory/mod.rs new file mode 100644 index 0000000000..fafbd6e408 --- /dev/null +++ b/metal/src/memory/mod.rs @@ -0,0 +1,5 @@ +mod pool; +mod schema; + +pub use pool::*; +pub use schema::*; diff --git a/metal/src/memory/pool.rs b/metal/src/memory/pool.rs new file mode 100644 index 0000000000..1babee4eab --- /dev/null +++ b/metal/src/memory/pool.rs @@ -0,0 +1,138 @@ +use 
crate::memory::MetalResolvedMemSchema; +use crate::tensor::{MetalArenaStorage, MetalArenaView}; +use crate::{IntoMetal, MetalContext, MetalTensor}; +use anyhow::Result; +use std::cell::RefCell; +use std::collections::HashSet; +use tract_core::internal::*; + +#[derive(Debug)] +pub struct MetalMemoryPool { + storage: Arc, + alignment: usize, + resolved_schema: MetalResolvedMemSchema, + node_seen: RefCell>, +} + +impl MetalMemoryPool { + pub fn from_schema( + context: &MetalContext, + resolved_schema: MetalResolvedMemSchema, + ) -> Result { + let alignment = std::mem::size_of::(); + let storage = Arc::new(MetalArenaStorage::with_capacity( + context, + resolved_schema.memory_size, + alignment, + )?); + + Ok(Self { storage, alignment, resolved_schema, node_seen: RefCell::new(HashSet::new()) }) + } + + pub fn tensor_for_node( + &self, + node_id: usize, + dt: DatumType, + shape: &[usize], + ) -> Result { + // ensure!(!self.node_seen.borrow().contains(&node_id), "Tensor for node {:?} was already requested. Maybe the memory pool was not reset properly.", node_id); + let alignment = dt.alignment(); + (self.alignment % alignment != 0) + .then(|| self.resolved_schema.offsets_by_node[node_id]) + .map(|offset| { + // self.node_seen.borrow_mut().insert(node_id); + Ok(MetalArenaView { + arena: Arc::clone(&self.storage), + dt, + shape: shape.into(), + strides: Tensor::natural_strides(shape), + offset_bytes: offset, + } + .into()) + }) + .unwrap_or_else(|| unsafe { Tensor::uninitialized_dt(dt, shape)?.into_metal() }) + } + + pub fn reset(&self) { + self.node_seen.borrow_mut().clear(); + } +} + +// #[derive(Debug)] +// pub struct MetalArena { +// storage: Arc, +// cursor: AtomicUsize, +// capacity: usize, +// alignment: usize, +// } + +// impl MetalArena { +// pub fn with_capacity(context: &MetalContext, capacity: usize) -> TractResult { +// let alignment = std::mem::size_of::(); +// let tensor = unsafe { +// Tensor::uninitialized_aligned_dt(DatumType::U8, &[capacity], alignment).with_context( +// || anyhow!("Error while allocating a tensor of {:?} bytes", capacity), +// )? +// }; +// let buffer = context.device().new_buffer_with_bytes_no_copy( +// tensor.as_bytes().as_ptr() as *const core::ffi::c_void, +// capacity as _, +// MTLResourceOptions::StorageModeShared, +// None, +// ); +// Ok(Self { +// storage: Arc::new(MetalArenaStorage { tensor, metal: buffer }), +// cursor: AtomicUsize::new(0), +// capacity, +// alignment, +// }) +// } + +// pub fn free_capacity(&self) -> usize { +// let cursor = self.cursor.load(Ordering::SeqCst); +// self.capacity - cursor +// } + +// pub fn capacity(&self) -> usize { +// self.capacity +// } + +// pub fn used_capacity(&self) -> usize { +// self.cursor.load(Ordering::SeqCst) +// } + +// pub fn view_uninitialized_dt(&self, dt: DatumType, shape: &[usize]) -> Option { +// // Check if we can reset the cursor of the arena for next +// // view. 
+// self.try_reset(); + +// let alignment = dt.alignment(); +// if self.alignment % alignment != 0 { +// return None; +// } +// let size = dt.size_of() * shape.iter().product::(); + +// let cursor = self.cursor.load(Ordering::SeqCst); + +// let start = if cursor % alignment != 0 { +// cursor + (alignment - cursor % alignment) +// } else { +// cursor +// }; + +// let end = start + size; +// if self.capacity < end { +// return None; +// } + +// self.cursor.store(end, Ordering::SeqCst); + +// Some(MetalArenaView { +// arena: Arc::clone(&self.storage), +// dt, +// shape: shape.into(), +// strides: Tensor::natural_strides(shape), +// offset_bytes: start, +// }) +// } +// } diff --git a/metal/src/memory/schema.rs b/metal/src/memory/schema.rs new file mode 100644 index 0000000000..f89242ee5a --- /dev/null +++ b/metal/src/memory/schema.rs @@ -0,0 +1,266 @@ +use crate::fact::MetalTypedFactExt; +use std::fmt; +use std::fmt::Debug; +use tract_core::internal::*; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ScopedNodeMemory { + pub node: usize, + pub scope: Scope, + pub mem_size: TDim, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Scope { + pub start: usize, + pub end: usize, +} + +impl Scope { + pub fn is_disjoint(&self, other: &Scope) -> bool { + self.start >= other.end || other.start >= self.end + } + + pub fn is_alive_at_step(&self, step: usize) -> bool { + self.start <= step && step < self.end + } + + pub fn len(&self) -> usize { + self.end - self.start + } +} + +pub fn eval_metal_scope_node_mem( + model: &TypedModel, + order: &[usize], +) -> TractResult> { + let outputs = model.output_outlets()?.to_vec(); + let flush_lists = order::build_flush_list(model, order, &outputs, |node| { + let Ok(facts) = model.node_output_facts(node.id) else { return false }; + + facts.iter().any(|it| it.to_metal_fact().map(|it| it.is_temporary()).unwrap_or(false)) + }); + let mut scoped_nodes = tvec![]; + + for (step, n) in order.iter().enumerate() { + let scope_start = step; + let scope_end = flush_lists + .iter() + .enumerate() + .find(|(_step, flush_list)| flush_list.contains(n)) + .map(|it| usize::min(it.0 + 1, order.len())); + + let Some(scope_end) = scope_end else { + continue; + }; + + let out_metal_tmp_facts = model + .node_output_facts(*n)? + .into_iter() + .flat_map(|it| it.to_metal_fact().ok()) + .filter(|it| it.is_temporary()) + .collect::>(); + + if out_metal_tmp_facts.is_empty() { + continue; + } + + scoped_nodes.push(ScopedNodeMemory { + node: *n, + scope: Scope { start: scope_start, end: scope_end }, + mem_size: out_metal_tmp_facts.iter().map(|it| it.mem_size()).sum::(), + }) + } + + Ok(scoped_nodes) +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Partition { + pub nodes: Vec, +} + +impl Partition { + pub fn eval_size_to_i64(&self, symbols: &SymbolValues) -> TractResult { + Ok(self + .nodes + .iter() + .map(|it| it.mem_size.eval_to_i64(symbols)) + .collect::>>()? 
+ .into_iter() + .max() + .unwrap_or(0)) + } + + pub fn size(&self) -> TDim { + TDim::Max(self.nodes.iter().map(|s| s.mem_size.clone()).collect()) + } + + pub fn is_disjoint(&self, scope: &Scope) -> bool { + self.nodes.iter().all(|n| n.scope.is_disjoint(scope)) + } + + pub fn find_node_alive_at_step(&self, step: usize) -> Option<&ScopedNodeMemory> { + self.nodes.iter().find(|it| it.scope.is_alive_at_step(step)) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MetalResolvedMemSchema { + pub offsets_by_node: Vec, + pub memory_size: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MetalMemSchema { + pub by_partition: Vec, + pub by_steps: Vec>>, +} + +impl MetalMemSchema { + pub fn eval_size_by_partition(&self, symbols: &SymbolValues) -> TractResult> { + self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).collect() + } + + pub fn size_by_partition(&self) -> Vec { + self.by_partition.iter().map(|it| it.size()).collect() + } + + pub fn memory_size(&self) -> TDim { + self.by_partition.iter().map(|it| it.size()).sum() + } + + pub fn eval_memory_size(&self, symbols: &SymbolValues) -> TractResult { + self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).sum() + } + + pub fn compute_offset_by_node( + &self, + num_nodes: usize, + symbols: &SymbolValues, + ) -> TractResult> { + let mut cursor = 0; + let mut offset_by_node = vec![0; num_nodes]; + + for partition in self.by_partition.iter() { + for node_mem in partition.nodes.iter() { + offset_by_node[node_mem.node] = cursor; + } + cursor += partition.eval_size_to_i64(symbols)? as usize; + } + + Ok(offset_by_node) + } + + pub fn eval_peak_memory_size(&self, symbols: &SymbolValues) -> TractResult { + Ok(self + .by_steps + .iter() + .map(|active_nodes| { + active_nodes + .iter() + .flat_map(|it| it) + .map(|it| it.mem_size.clone()) + .sum::() + .eval_to_i64(symbols) + }) + .collect::>>()? + .into_iter() + .max() + .unwrap_or(0)) + } + + pub fn eval_usage(&self, symbols: &SymbolValues) -> TractResult { + let memory_size = self.eval_memory_size(symbols)? as f32; + let peak_memory_size = self.eval_peak_memory_size(symbols)? 
as f32; + Ok(peak_memory_size / memory_size) + } +} + +impl fmt::Display for MetalMemSchema { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + for (step, mem_step) in self.by_steps.iter().enumerate() { + writeln!( + fmt, + "step: {:5} => |{}|", + step, + mem_step + .iter() + .map(|n| -> String { + n.as_ref() + .map(|it| format!("{:^7}", it.node)) + .unwrap_or(format!("{:^7}", "*")) + }) + .collect::>() + .join("|") + )?; + } + writeln!(fmt, "memory_size: {}", self.memory_size())?; + Ok(()) + } +} + +impl MetalMemSchema { + pub fn resolve(&self, num_nodes: usize, symbols: &SymbolValues) -> TractResult { + Ok(MetalResolvedMemSchema { + offsets_by_node: self.compute_offset_by_node(num_nodes, symbols)?, + memory_size: self.eval_memory_size(symbols)?.try_into()?, + }) + } + + pub fn build( + model: &TypedModel, + order: &[usize], + hint: &SymbolValues, + ) -> TractResult { + let mut scoped_nodes_mem = eval_metal_scope_node_mem(&model, &order)?; + + let hinted_mem_size = scoped_nodes_mem + .iter() + .map(|node_mem| Ok((node_mem.node, node_mem.mem_size.eval_to_i64(hint)?))) + .collect::>>()?; + + scoped_nodes_mem.sort_by(|lhs, rhs| { + let lhs_hint_mem_size = hinted_mem_size.get(&lhs.node); + let rhs_hint_mem_size = hinted_mem_size.get(&rhs.node); + + lhs.scope + .end + .cmp(&rhs.scope.end) + .reverse() + .then(lhs.scope.len().cmp(&rhs.scope.len()).reverse()) + .then(lhs_hint_mem_size.cmp(&rhs_hint_mem_size).reverse()) + }); + + let mut partitions: Vec = vec![]; + for node_mem in scoped_nodes_mem { + // Find partitions where node scope is disjoint from existing. + let mut available = partitions + .iter_mut() + .filter(|it| it.is_disjoint(&node_mem.scope)) + .collect::>(); + + available.sort_by_cached_key(|n| { + n.nodes.iter().flat_map(|it| hinted_mem_size.get(&it.node)).sum::() * -1 + }); + + match available.first_mut() { + Some(available) => { + available.nodes.push(node_mem); + } + None => partitions.push(Partition { nodes: vec![node_mem] }), + } + } + + let by_steps: Vec>> = (0..order.len()) + .map(|step| { + let mem_step: Vec<_> = + partitions.iter().map(|p| p.find_node_alive_at_step(step).cloned()).collect(); + ensure!(mem_step.len() <= partitions.len()); + Ok(mem_step) + }) + .collect::>>()?; + + Ok(MetalMemSchema { by_partition: partitions, by_steps }) + } +} diff --git a/metal/src/ops/binary.rs b/metal/src/ops/binary.rs index 35c10f3c45..7250d0ed06 100644 --- a/metal/src/ops/binary.rs +++ b/metal/src/ops/binary.rs @@ -1,5 +1,6 @@ pub use crate::kernels::BinOps; -use crate::{MetalTensor, MetalTensorExt}; +use crate::ops::MetalEvalOp; +use crate::{MetalContext, MetalTensorExt}; use tract_core::internal::*; #[derive(Debug, Clone)] @@ -40,29 +41,28 @@ impl Op for MetalBinOp { op_as_typed_op!(); } -impl EvalOp for MetalBinOp { - fn is_stateless(&self) -> bool { - true - } - - fn eval(&self, inputs: TVec) -> TractResult> { - objc::rc::autoreleasepool(|| { - crate::METAL_CONTEXT.with_borrow(|context| { - let (opaque_a, opaque_b) = args_2!(inputs); - let a = opaque_a.to_metal_tensor()?; - let b = opaque_b.to_metal_tensor()?; +crate::impl_eval_op_for_metal_op!(MetalBinOp); - let out_shape = self.0.output_shape(a.shape(), b.shape())?; - let out_dt = self.0.output_datum_type(a.datum_type(), b.datum_type())?; - let output = unsafe { MetalTensor::uninitialized_dt(out_dt, &out_shape)? 
}; - self.0 - .dispatch_eval(context, a, b, &output) - .with_context(|| "Error while dispatching eval for Metal Bin Op")?; +impl MetalEvalOp for MetalBinOp { + fn metal_eval( + &self, + context: &MetalContext, + node_id: usize, + _session: &mut SessionState, + inputs: TVec, + ) -> TractResult> { + let (opaque_a, opaque_b) = args_2!(inputs); + let a = opaque_a.to_metal_tensor()?; + let b = opaque_b.to_metal_tensor()?; + let out_shape = self.0.output_shape(a.shape(), b.shape())?; + let out_dt = self.0.output_datum_type(a.datum_type(), b.datum_type())?; + let output = crate::ops::make_tensor_for_node(context, node_id, out_dt, &out_shape)?; + self.0 + .dispatch_eval(context, a, b, &output) + .with_context(|| "Error while dispatching eval for Metal Bin Op")?; - ensure!(a.rank() == b.rank()); - Ok(tvec!(output.into_opaque_tensor().into_tvalue())) - }) - }) + ensure!(a.rank() == b.rank()); + Ok(tvec!(output.into_opaque_tensor().into_tvalue())) } } diff --git a/metal/src/ops/broadcast.rs b/metal/src/ops/broadcast.rs index 8ff42eb416..79a9438a52 100644 --- a/metal/src/ops/broadcast.rs +++ b/metal/src/ops/broadcast.rs @@ -1,5 +1,5 @@ -use crate::tensor::MetalTensorExt; -use crate::{kernels, MetalTensor}; +use crate::ops::MetalEvalOp; +use crate::{kernels, MetalContext, MetalTensorExt}; use derive_new::new; use std::fmt::Debug; use tract_core::internal::*; @@ -17,26 +17,23 @@ impl Op for MetalMultiBroadcastTo { op_as_typed_op!(); } -impl EvalOp for MetalMultiBroadcastTo { - fn is_stateless(&self) -> bool { - true - } +crate::impl_eval_op_for_metal_op!(MetalMultiBroadcastTo); - fn eval_with_session( +impl MetalEvalOp for MetalMultiBroadcastTo { + fn metal_eval( &self, - session: &SessionState, + context: &MetalContext, + node_id: usize, + session: &mut SessionState, inputs: TVec, ) -> TractResult> { + let opaque = args_1!(inputs); let shape = self.shape.eval_to_usize(&session.resolved_symbols)?; - objc::rc::autoreleasepool(|| { - crate::METAL_CONTEXT.with_borrow(|context| { - let opaque = args_1!(inputs); - let input = opaque.to_metal_tensor()?; - let output = unsafe { MetalTensor::uninitialized_dt(input.datum_type(), &shape)? 
}; - kernels::array::MultiBroadcast.dispatch_eval(context, input, 0, &output)?; - Ok(tvec![output.into_opaque_tensor().into_tvalue()]) - }) - }) + let input = opaque.to_metal_tensor()?; + let output = + crate::ops::make_tensor_for_node(context, node_id, input.datum_type(), &shape)?; + kernels::array::MultiBroadcast.dispatch_eval(context, input, 0, &output)?; + Ok(tvec![output.into_opaque_tensor().into_tvalue()]) } } diff --git a/metal/src/ops/cast.rs b/metal/src/ops/cast.rs index 34b0b52026..372e893f33 100644 --- a/metal/src/ops/cast.rs +++ b/metal/src/ops/cast.rs @@ -1,5 +1,6 @@ use crate::kernels; -use crate::{MetalTensor, MetalTensorExt}; +use crate::ops::MetalEvalOp; +use crate::{MetalContext, MetalTensorExt}; use tract_core::internal::*; #[derive(Debug, Clone, Hash, PartialEq, Eq)] @@ -26,25 +27,25 @@ impl Op for MetalCast { impl_op_same_as!(); } -impl EvalOp for MetalCast { - fn is_stateless(&self) -> bool { - true - } +crate::impl_eval_op_for_metal_op!(MetalCast); - fn eval(&self, inputs: TVec) -> TractResult> { +impl MetalEvalOp for MetalCast { + fn metal_eval( + &self, + context: &MetalContext, + node_id: usize, + _session: &mut SessionState, + inputs: TVec, + ) -> TractResult> { let opaque = args_1!(inputs); let input = opaque.to_metal_tensor()?; if input.datum_type() == self.to { Ok(tvec!(opaque)) } else { - objc::rc::autoreleasepool(|| { - crate::METAL_CONTEXT.with_borrow(|context| { - let output = unsafe { MetalTensor::uninitialized_dt(self.to, input.shape())? }; - kernels::array::Cast.dispatch_eval(context, input, &output)?; - - Ok(tvec![output.into_opaque_tensor().into_tvalue()]) - }) - }) + let output = + crate::ops::make_tensor_for_node(context, node_id, self.to, input.shape())?; + kernels::array::Cast.dispatch_eval(context, input, &output)?; + Ok(tvec![output.into_opaque_tensor().into_tvalue()]) } } } diff --git a/metal/src/ops/change_axes.rs b/metal/src/ops/change_axes.rs index 19d1111374..219069092e 100644 --- a/metal/src/ops/change_axes.rs +++ b/metal/src/ops/change_axes.rs @@ -1,5 +1,6 @@ use crate::kernels::array::PermuteAxes; -use crate::{MetalTensor, MetalTensorExt}; +use crate::ops::MetalEvalOp; +use crate::{MetalContext, MetalTensorExt}; use std::fmt::Debug; use tract_core::internal::*; use tract_itertools::Itertools; @@ -47,35 +48,31 @@ impl Op for MetalAxisOp { op_as_typed_op!(); } -impl EvalOp for MetalAxisOp { - fn is_stateless(&self) -> bool { - true - } +crate::impl_eval_op_for_metal_op!(MetalAxisOp); - fn eval_with_session( +impl MetalEvalOp for MetalAxisOp { + fn metal_eval( &self, - session: &SessionState, + context: &MetalContext, + node_id: usize, + session: &mut SessionState, inputs: TVec, ) -> TractResult> { let opaque = args_1!(inputs).into_tensor(); let input = opaque.to_metal_tensor()?; let new_shape = match &self.0 { AxisOp::Move(from, to) => { - return objc::rc::autoreleasepool(|| { - crate::METAL_CONTEXT.with_borrow(|context| -> TractResult<_> { - let mut permutation: Vec = (0..input.rank()).collect(); - permutation.remove(*from); - permutation.insert(*to, *from); - let output = unsafe { - MetalTensor::uninitialized_dt( - input.datum_type(), - &PermuteAxes::output_shape(input.shape(), &permutation)?, - )? 
- }; - PermuteAxes.dispatch_eval(context, input, &permutation, &output)?; - Ok(tvec!(output.into_opaque_tensor().into_tvalue())) - }) - }); + let mut permutation: Vec = (0..input.rank()).collect(); + permutation.remove(*from); + permutation.insert(*to, *from); + let output = crate::ops::make_tensor_for_node( + context, + node_id, + input.datum_type(), + &PermuteAxes::output_shape(input.shape(), &permutation)?, + )?; + PermuteAxes.dispatch_eval(context, input, &permutation, &output)?; + return Ok(tvec!(output.into_opaque_tensor().into_tvalue())); } AxisOp::Reshape(skip, from, to) => { let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect(); diff --git a/metal/src/ops/concat.rs b/metal/src/ops/concat.rs index 0bccd50b40..0e7a9e575c 100644 --- a/metal/src/ops/concat.rs +++ b/metal/src/ops/concat.rs @@ -1,5 +1,6 @@ use crate::kernels::array::Concat; -use crate::{MetalTensor, MetalTensorExt}; +use crate::ops::MetalEvalOp; +use crate::{MetalContext, MetalTensorExt}; use derive_new::new; use tract_core::internal::*; use tract_core::ops::array::TypedConcat; @@ -41,29 +42,32 @@ impl Op for MetalConcat { op_as_typed_op!(); } -impl EvalOp for MetalConcat { - fn is_stateless(&self) -> bool { - true - } +crate::impl_eval_op_for_metal_op!(MetalConcat); - fn eval(&self, opaque_inputs: TVec) -> TractResult> { - objc::rc::autoreleasepool(|| { - crate::METAL_CONTEXT.with_borrow(|context| { - let inputs = opaque_inputs - .iter() - .map(|it| it.to_metal_tensor()) - .collect::>>()?; +impl MetalEvalOp for MetalConcat { + fn metal_eval( + &self, + context: &MetalContext, + node_id: usize, + _session: &mut SessionState, + opaque_inputs: TVec, + ) -> TractResult> { + let inputs = opaque_inputs + .iter() + .map(|it| it.to_metal_tensor()) + .collect::>>()?; - let mut output_shape = inputs[0].shape().to_vec(); - output_shape[self.axis()] = inputs.iter().map(|it| it.shape()[self.axis()]).sum(); - let output = unsafe { - MetalTensor::uninitialized_dt(inputs[0].datum_type(), &output_shape)? - }; - self.kernel.dispatch_eval(context, &inputs, &output)?; + let mut output_shape = inputs[0].shape().to_vec(); + output_shape[self.axis()] = inputs.iter().map(|it| it.shape()[self.axis()]).sum(); + let output = crate::ops::make_tensor_for_node( + context, + node_id, + inputs[0].datum_type(), + &output_shape, + )?; + self.kernel.dispatch_eval(context, &inputs, &output)?; - Ok(tvec!(output.into_opaque_tensor().into_tvalue())) - }) - }) + Ok(tvec!(output.into_opaque_tensor().into_tvalue())) } } diff --git a/metal/src/ops/element_wise.rs b/metal/src/ops/element_wise.rs index 1cc26858e0..5330318a4d 100644 --- a/metal/src/ops/element_wise.rs +++ b/metal/src/ops/element_wise.rs @@ -1,5 +1,6 @@ pub use crate::kernels::ElementWiseOps; -use crate::{MetalTensor, MetalTensorExt}; +use crate::ops::MetalEvalOp; +use crate::{MetalContext, MetalTensorExt}; use tract_core::internal::*; #[derive(Debug, Clone)] @@ -22,21 +23,21 @@ impl Op for MetalElementWiseOp { op_as_typed_op!(); } -impl EvalOp for MetalElementWiseOp { - fn is_stateless(&self) -> bool { - true - } - - fn eval(&self, inputs: TVec) -> TractResult> { - objc::rc::autoreleasepool(|| { - crate::METAL_CONTEXT.with_borrow(|context| { - let opaque_a = args_1!(inputs); - let a = opaque_a.to_metal_tensor()?; - let output = unsafe { MetalTensor::uninitialized_dt(a.datum_type(), a.shape())? 
}; - self.0.dispatch_eval(context, a, &output)?; - Ok(tvec![output.into_opaque_tensor().into_tvalue()]) - }) - }) +crate::impl_eval_op_for_metal_op!(MetalElementWiseOp); + +impl MetalEvalOp for MetalElementWiseOp { + fn metal_eval( + &self, + context: &MetalContext, + node_id: usize, + _session: &mut SessionState, + inputs: TVec, + ) -> TractResult> { + let opaque_a = args_1!(inputs); + let a = opaque_a.to_metal_tensor()?; + let output = crate::ops::make_tensor_for_node(context, node_id, a.datum_type(), a.shape())?; + self.0.dispatch_eval(context, a, &output)?; + Ok(tvec![output.into_opaque_tensor().into_tvalue()]) } } diff --git a/metal/src/ops/gemm.rs b/metal/src/ops/gemm.rs index 4c640efc20..ea005e29d5 100644 --- a/metal/src/ops/gemm.rs +++ b/metal/src/ops/gemm.rs @@ -1,6 +1,7 @@ use crate::kernels::matmul::{GemmImpl, GemmKernel}; +use crate::ops::MetalEvalOp; -use crate::{MetalTensor, MetalTensorExt}; +use crate::{MetalContext, MetalTensorExt}; use anyhow::{bail, ensure}; use tract_core::internal::*; @@ -60,28 +61,41 @@ impl MetalGemm { } } -impl EvalOp for MetalGemm { +impl MetalEvalOp for MetalGemm { + fn metal_eval( + &self, + context: &MetalContext, + node_id: usize, + _session: &mut SessionState, + inputs: TVec, + ) -> TractResult> { + let (a_opaque, b_opaque) = args_2!(inputs); + let a = a_opaque + .to_metal_tensor() + .with_context(|| anyhow!("A tensor is not a metal tensor: {:?}", a_opaque))?; + let b = b_opaque + .to_metal_tensor() + .with_context(|| anyhow!("B tensor is not a metal tensor {:?}", b_opaque))?; + let c_dt = a.datum_type(); + let c_shape = self.kernel.output_shape(a.shape(), b.shape()); + let c = crate::ops::make_tensor_for_node(context, node_id, c_dt, &c_shape)?; + self.kernel.dispatch_eval(context, a, b, &c)?; + Ok(tvec![c.into_opaque_tensor().into_tvalue()]) + } +} + +impl EvalOp for MetalGemm { fn is_stateless(&self) -> bool { - true + false } - fn eval(&self, inputs: TVec) -> TractResult> { - let (a_opaque, b_opaque) = args_2!(inputs); - objc::rc::autoreleasepool(|| { - crate::METAL_CONTEXT.with_borrow(|context| { - let a = a_opaque - .to_metal_tensor() - .with_context(|| anyhow!("A tensor is not a metal tensor: {:?}", a_opaque))?; - let b = b_opaque - .to_metal_tensor() - .with_context(|| anyhow!("B tensor is not a metal tensor {:?}", b_opaque))?; - let c_dt = a.datum_type(); - let c_shape = self.kernel.output_shape(a.shape(), b.shape()); - let c = unsafe { MetalTensor::uninitialized_dt(c_dt, &c_shape)? 
}; - self.kernel.dispatch_eval(context, a, b, &c)?; - Ok(tvec![c.into_opaque_tensor().into_tvalue()]) - }) - }) + #[allow(unused_variables)] + fn state( + &self, + session: &mut tract_core::internal::SessionState, + node_id: usize, + ) -> TractResult>> { + Ok(Some(Box::new(crate::ops::MetalOpState::new(node_id, self.clone())))) } } diff --git a/metal/src/ops/mod.rs b/metal/src/ops/mod.rs index 77a7d3fe89..bc6e1f9c91 100644 --- a/metal/src/ops/mod.rs +++ b/metal/src/ops/mod.rs @@ -29,3 +29,63 @@ pub use silu::MetalSilu; pub use slice::MetalSlice; pub use softmax::MetalSoftmax; pub use sync::{MetalSync, MetalSyncKind}; + +use crate::{MetalContext, MetalTensor}; +use derive_new::new; +use tract_core::internal::*; +use tract_core::ops::OpStateFreeze; + +pub trait MetalEvalOp: EvalOp + Op + Clone { + fn metal_eval( + &self, + context: &MetalContext, + node_id: usize, + session: &mut SessionState, + inputs: TVec, + ) -> TractResult>; +} + +#[derive(Debug, Clone, new)] +pub struct MetalOpState { + node_id: usize, + op: O, +} + +impl OpStateFreeze for MetalOpState { + fn freeze(&self) -> Box<(dyn FrozenOpState + 'static)> { + Box::new(self.clone()) + } +} + +impl FrozenOpState for MetalOpState { + fn unfreeze(&self) -> Box { + Box::new(self.clone()) + } +} + +impl OpState for MetalOpState { + fn eval( + &mut self, + session: &mut SessionState, + _op: &dyn Op, + inputs: TVec, + ) -> TractResult> { + objc::rc::autoreleasepool(|| { + crate::METAL_CONTEXT + .with_borrow(|context| self.op.metal_eval(context, self.node_id, session, inputs)) + }) + } +} + +pub fn make_tensor_for_node( + context: &MetalContext, + node_id: usize, + dt: DatumType, + shape: &[usize], +) -> TractResult { + context + .memory_pool() + .as_ref() + .map(|mem| mem.tensor_for_node(node_id, dt, &shape)) + .unwrap_or_else(|| unsafe { MetalTensor::uninitialized_dt(dt, &shape) }) +} diff --git a/metal/src/ops/new_gelu.rs b/metal/src/ops/new_gelu.rs index 2f88b14e45..f73547cd3a 100644 --- a/metal/src/ops/new_gelu.rs +++ b/metal/src/ops/new_gelu.rs @@ -1,5 +1,6 @@ use crate::kernels::nn::NewGelu; -use crate::{MetalTensor, MetalTensorExt}; +use crate::ops::MetalEvalOp; +use crate::{MetalContext, MetalTensorExt}; use derive_new::new; use tract_core::internal::*; @@ -14,24 +15,26 @@ impl Op for MetalNewGelu { op_as_typed_op!(); } -impl EvalOp for MetalNewGelu { - fn is_stateless(&self) -> bool { - true - } - - fn eval(&self, inputs: TVec) -> TractResult> { - objc::rc::autoreleasepool(|| { - crate::METAL_CONTEXT.with_borrow(|context| { - let input = args_1!(inputs); - let input_metal = input.to_metal_tensor()?; - let output = unsafe { - MetalTensor::uninitialized_dt(input_metal.datum_type(), input_metal.shape())? 
- }; - NewGelu::accurate().dispatch_eval(context, input_metal, &output)?; - - Ok(tvec!(output.into_opaque_tensor().into_tvalue())) - }) - }) +crate::impl_eval_op_for_metal_op!(MetalNewGelu); + +impl MetalEvalOp for MetalNewGelu { + fn metal_eval( + &self, + context: &MetalContext, + node_id: usize, + _session: &mut SessionState, + inputs: TVec, + ) -> TractResult> { + let input = args_1!(inputs); + let input_metal = input.to_metal_tensor()?; + let output = crate::ops::make_tensor_for_node( + context, + node_id, + input_metal.datum_type(), + input_metal.shape(), + )?; + NewGelu::accurate().dispatch_eval(context, input_metal, &output)?; + Ok(tvec!(output.into_opaque_tensor().into_tvalue())) } } diff --git a/metal/src/ops/rms_norm.rs b/metal/src/ops/rms_norm.rs index 68c03b5590..ba63ef81ca 100644 --- a/metal/src/ops/rms_norm.rs +++ b/metal/src/ops/rms_norm.rs @@ -1,5 +1,6 @@ use crate::kernels::nn::RmsNorm; -use crate::{MetalTensor, MetalTensorExt}; +use crate::ops::MetalEvalOp; +use crate::{MetalContext, MetalTensorExt}; use derive_new::new; use std::sync::Arc; use tract_core::internal::*; @@ -20,22 +21,22 @@ impl Op for MetalRmsNorm { op_as_typed_op!(); } -impl EvalOp for MetalRmsNorm { - fn is_stateless(&self) -> bool { - true - } +crate::impl_eval_op_for_metal_op!(MetalRmsNorm); - fn eval(&self, inputs: TVec) -> TractResult> { - objc::rc::autoreleasepool(|| { - crate::METAL_CONTEXT.with_borrow(|context| { - let opaque = args_1!(inputs); - let input = opaque.to_metal_tensor()?; - let output = - unsafe { MetalTensor::uninitialized_dt(input.datum_type(), input.shape())? }; - RmsNorm.dispatch_eval(context, input, self.axis, &self.eps, &output)?; - Ok(tvec!(output.into_opaque_tensor().into_tvalue())) - }) - }) +impl MetalEvalOp for MetalRmsNorm { + fn metal_eval( + &self, + context: &MetalContext, + node_id: usize, + _session: &mut SessionState, + inputs: TVec, + ) -> TractResult> { + let opaque = args_1!(inputs); + let input = opaque.to_metal_tensor()?; + let output = + crate::ops::make_tensor_for_node(context, node_id, input.datum_type(), input.shape())?; + RmsNorm.dispatch_eval(context, input, self.axis, &self.eps, &output)?; + Ok(tvec!(output.into_opaque_tensor().into_tvalue())) } } diff --git a/metal/src/ops/silu.rs b/metal/src/ops/silu.rs index 6c2e294e26..8b7f1c746c 100644 --- a/metal/src/ops/silu.rs +++ b/metal/src/ops/silu.rs @@ -1,5 +1,6 @@ use crate::kernels::nn::Silu; -use crate::{MetalTensor, MetalTensorExt}; +use crate::ops::MetalEvalOp; +use crate::{MetalContext, MetalTensorExt}; use derive_new::new; use tract_core::internal::*; @@ -14,22 +15,22 @@ impl Op for MetalSilu { op_as_typed_op!(); } -impl EvalOp for MetalSilu { - fn is_stateless(&self) -> bool { - true - } +crate::impl_eval_op_for_metal_op!(MetalSilu); - fn eval(&self, inputs: TVec) -> TractResult> { - objc::rc::autoreleasepool(|| { - crate::METAL_CONTEXT.with_borrow(|context| { - let opaque = args_1!(inputs); - let input = opaque.to_metal_tensor()?; - let output = - unsafe { MetalTensor::uninitialized_dt(input.datum_type(), input.shape())? 
}; - Silu.dispatch_eval(context, input, &output)?; - Ok(tvec!(output.into_opaque_tensor().into_tvalue())) - }) - }) +impl MetalEvalOp for MetalSilu { + fn metal_eval( + &self, + context: &MetalContext, + node_id: usize, + _session: &mut SessionState, + inputs: TVec, + ) -> TractResult> { + let opaque = args_1!(inputs); + let input = opaque.to_metal_tensor()?; + let output = + crate::ops::make_tensor_for_node(context, node_id, input.datum_type(), input.shape())?; + Silu.dispatch_eval(context, input, &output)?; + Ok(tvec!(output.into_opaque_tensor().into_tvalue())) } } diff --git a/metal/src/ops/slice.rs b/metal/src/ops/slice.rs index 206bd49593..d23ec69a2e 100644 --- a/metal/src/ops/slice.rs +++ b/metal/src/ops/slice.rs @@ -1,5 +1,6 @@ use crate::kernels; -use crate::tensor::{MetalTensor, MetalTensorExt}; +use crate::ops::MetalEvalOp; +use crate::{MetalContext, MetalTensorExt}; use tract_core::internal::*; use tract_core::ops::array::Slice; @@ -40,27 +41,26 @@ impl Op for MetalSlice { } } -impl EvalOp for MetalSlice { - fn is_stateless(&self) -> bool { - true - } +crate::impl_eval_op_for_metal_op!(MetalSlice); - fn eval_with_session( +impl MetalEvalOp for MetalSlice { + fn metal_eval( &self, - session: &SessionState, + context: &MetalContext, + node_id: usize, + session: &mut SessionState, inputs: TVec, ) -> TractResult> { - let input = args_1!(inputs); + let opaque = args_1!(inputs); + let input = opaque.to_metal_tensor()?; + let start = self.0.start.eval(&session.resolved_symbols).to_usize()?; let end = self.0.end.eval(&session.resolved_symbols).to_usize()?; let axis = self.0.axis; - let input_shape = input.as_metal_tensor().map(|it| it.shape()).unwrap_or(input.shape()); - let input_strides = - input.as_metal_tensor().map(|it| it.strides()).unwrap_or(input.strides()); - - let input_dt = - input.as_metal_tensor().map(|it| it.datum_type()).unwrap_or(input.datum_type()); + let input_shape = input.shape(); + let input_strides = input.strides(); + let input_dt = input.datum_type(); ensure!( end <= input_shape[axis] && start <= end, @@ -76,19 +76,14 @@ impl EvalOp for MetalSlice { let offset = (start * input_strides[axis] as usize) * input_dt.size_of(); - let input = input.to_metal_tensor()?; - let output = unsafe { MetalTensor::uninitialized_dt(input.datum_type(), &o_shape)? }; + let output = + crate::ops::make_tensor_for_node(context, node_id, input.datum_type(), &o_shape)?; - objc::rc::autoreleasepool(|| { - crate::METAL_CONTEXT.with_borrow(|context| { - // Perform slicing only if the output is not empty. - if o_shape[axis] != 0 { - kernels::array::MultiBroadcast - .dispatch_eval(context, input, offset, &output)?; - } - Ok(tvec![output.into_opaque_tensor().into_tvalue()]) - }) - }) + // Perform slicing only if the output is not empty. 
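+ // (When o_shape[axis] == 0 the output view holds no elements, so there is nothing to copy from the input.)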
+ if o_shape[axis] != 0 { + kernels::array::MultiBroadcast.dispatch_eval(context, input, offset, &output)?; + } + Ok(tvec![output.into_opaque_tensor().into_tvalue()]) } } diff --git a/metal/src/ops/softmax.rs b/metal/src/ops/softmax.rs index 5ef5de2c26..99245cb447 100644 --- a/metal/src/ops/softmax.rs +++ b/metal/src/ops/softmax.rs @@ -1,5 +1,6 @@ use crate::kernels::nn::Softmax; -use crate::{MetalTensor, MetalTensorExt}; +use crate::ops::MetalEvalOp; +use crate::{MetalContext, MetalTensorExt}; use std::fmt::Debug; use tract_core::internal::*; use tract_core::ops::nn as core_ops_nn; @@ -75,22 +76,22 @@ impl TypedOp for MetalSoftmax { as_op!(); } -impl EvalOp for MetalSoftmax { - fn is_stateless(&self) -> bool { - true - } +crate::impl_eval_op_for_metal_op!(MetalSoftmax); - fn eval(&self, inputs: TVec) -> TractResult> { - objc::rc::autoreleasepool(|| { - crate::METAL_CONTEXT.with_borrow(|context| { - let opaque = args_1!(inputs); - let input = opaque.to_metal_tensor()?; - let output = - unsafe { MetalTensor::uninitialized_dt(input.datum_type(), input.shape())? }; - Softmax.dispatch_eval(context, input, self.axes[0], &output)?; +impl MetalEvalOp for MetalSoftmax { + fn metal_eval( + &self, + context: &MetalContext, + node_id: usize, + _session: &mut SessionState, + inputs: TVec, + ) -> TractResult> { + let opaque = args_1!(inputs); + let input = opaque.to_metal_tensor()?; + let output = + crate::ops::make_tensor_for_node(context, node_id, input.datum_type(), input.shape())?; + Softmax.dispatch_eval(context, input, self.axes[0], &output)?; - Ok(tvec!(output.into_opaque_tensor().into_tvalue())) - }) - }) + Ok(tvec!(output.into_opaque_tensor().into_tvalue())) } } diff --git a/metal/src/plan.rs b/metal/src/plan.rs new file mode 100644 index 0000000000..f8a2fad8f8 --- /dev/null +++ b/metal/src/plan.rs @@ -0,0 +1,81 @@ +use crate::memory::MetalMemSchema; +use crate::MetalMemoryPool; +use std::borrow::Borrow; +use tract_core::internal::*; + +pub struct MetalPlanState +where + M: Borrow>>, + P: Borrow> + Clone, +{ + pub mem_schema: MetalMemSchema, + pub state: TypedSimpleState, +} + +impl MetalPlanState +where + M: Borrow>>, + P: Borrow, M>> + Clone, +{ + pub fn new(plan: P, memory_hint: &SymbolValues) -> TractResult { + let state = TypedSimpleState::new(plan)?; + let mem_schema = MetalMemSchema::build( + state.plan().model(), + state.plan().order_without_consts(), + memory_hint, + )?; + Ok(Self { state, mem_schema }) + } + + pub fn run_plan_with_eval( + &mut self, + inputs: TVec, + eval: Eval, + ) -> TractResult> + where + Eval: for<'a, 'b, 'c> FnMut( + &'a mut SessionState, + Option<&'b mut (dyn OpState + 'static)>, + &'c TypedNode, + TVec, + ) -> Result, E>, + E: Into + Send + Sync + 'static, + { + self.state.session_state = SessionState::default(); + self.state.set_inputs(inputs)?; + + let resolved_mem_schema = + self.mem_schema.resolve(self.model().nodes().len(), &self.state.session_state.resolved_symbols)?; + + objc::rc::autoreleasepool(|| { + crate::METAL_CONTEXT.with_borrow(|context| { + let memory_pool = MetalMemoryPool::from_schema(context, resolved_mem_schema)?; + let (_, outputs) = context.execute_in_mem_pool(memory_pool, || { + self.state.exec_plan_with_eval(eval)?; + let outputs = self.state.outputs()?; + self.state.reset_turn()?; + Ok(outputs) + })?; + Ok(outputs) + }) + }) + } + + pub fn model(&self) -> &TypedModel { + self.state.plan().model() + } + + pub fn run(&mut self, inputs: TVec) -> TractResult> { + self.run_plan_with_eval(inputs, tract_core::plan::eval) + } + + /// Reset 
wires state. + pub fn reset_turn(&mut self) -> TractResult<()> { + self.state.reset_turn() + } + + /// Reset op inner state. + pub fn reset_op_states(&mut self) -> TractResult<()> { + self.state.reset_op_states() + } +} diff --git a/metal/src/tensor/arena.rs b/metal/src/tensor/arena.rs deleted file mode 100644 index b3edbb32e1..0000000000 --- a/metal/src/tensor/arena.rs +++ /dev/null @@ -1,117 +0,0 @@ -use crate::tensor::MetalArenaView; -use crate::MetalContext; -use anyhow::{anyhow, Context}; -use core::sync::atomic::AtomicUsize; -use metal::{Buffer, MTLResourceOptions}; -use std::sync::atomic::Ordering; -use tract_core::internal::*; - -#[derive(Debug)] -pub struct MetalArena { - storage: Arc, - cursor: AtomicUsize, - capacity: usize, - alignment: usize, -} - -impl MetalArena { - pub fn with_capacity(context: &MetalContext, capacity: usize) -> TractResult { - let alignment = std::mem::size_of::(); - let tensor = unsafe { - Tensor::uninitialized_aligned_dt(DatumType::U8, &[capacity], alignment).with_context( - || anyhow!("Error while allocating a tensor of {:?} bytes", capacity), - )? - }; - let buffer = context.device().new_buffer_with_bytes_no_copy( - tensor.as_bytes().as_ptr() as *const core::ffi::c_void, - capacity as _, - MTLResourceOptions::StorageModeShared, - None, - ); - Ok(Self { - storage: Arc::new(MetalArenaStorage { tensor, metal: buffer }), - cursor: AtomicUsize::new(0), - capacity, - alignment, - }) - } - - pub fn free_capacity(&self) -> usize { - let cursor = self.cursor.load(Ordering::SeqCst); - self.capacity - cursor - } - - pub fn capacity(&self) -> usize { - self.capacity - } - - pub fn used_capacity(&self) -> usize { - self.cursor.load(Ordering::SeqCst) - } - - pub fn view_uninitialized_dt(&self, dt: DatumType, shape: &[usize]) -> Option { - // Check if we can reset the cursor of the arena for next - // view. - self.try_reset(); - - let alignment = dt.alignment(); - if self.alignment % alignment != 0 { - return None; - } - let size = dt.size_of() * shape.iter().product::(); - - let cursor = self.cursor.load(Ordering::SeqCst); - - let start = if cursor % alignment != 0 { - cursor + (alignment - cursor % alignment) - } else { - cursor - }; - - let end = start + size; - if self.capacity < end { - return None; - } - - self.cursor.store(end, Ordering::SeqCst); - - Some(MetalArenaView { - arena: Arc::clone(&self.storage), - dt, - shape: shape.into(), - strides: Tensor::natural_strides(shape), - offset_bytes: start, - }) - } - - pub fn try_reset(&self) { - let cursor = self.cursor.load(Ordering::SeqCst); - if Arc::strong_count(&self.storage) == 1 && cursor != 0 { - self.cursor.store(0, Ordering::SeqCst) - } - } -} - -#[derive(Debug, Clone)] -pub struct MetalArenaStorage { - tensor: Tensor, - metal: Buffer, -} - -impl MetalArenaStorage { - /// Get underlying inner metal buffer. 
- pub fn metal(&self) -> &Buffer { - &self.metal - } - - pub fn tensor(&self) -> &Tensor { - &self.tensor - } -} - -impl Hash for MetalArenaStorage { - #[inline] - fn hash(&self, state: &mut H) { - self.tensor.hash(state) - } -} diff --git a/metal/src/tensor/arena_view.rs b/metal/src/tensor/arena_view.rs index 001736d277..c0608c2c13 100644 --- a/metal/src/tensor/arena_view.rs +++ b/metal/src/tensor/arena_view.rs @@ -1,10 +1,56 @@ -use crate::tensor::MetalArenaStorage; +use crate::MetalContext; use anyhow::Result; use metal::Buffer; +use metal::MTLResourceOptions; use num_traits::AsPrimitive; use std::fmt::Display; use tract_core::internal::*; +#[derive(Debug, Clone)] +pub struct MetalArenaStorage { + tensor: Tensor, + metal: Buffer, +} + +impl MetalArenaStorage { + pub fn with_capacity( + context: &MetalContext, + capacity: usize, + alignment: usize, + ) -> TractResult { + let tensor = unsafe { + Tensor::uninitialized_aligned_dt(DatumType::U8, &[capacity], alignment).with_context( + || anyhow!("Error while allocating a tensor of {:?} bytes", capacity), + )? + }; + let buffer = context.device().new_buffer_with_bytes_no_copy( + tensor.as_bytes().as_ptr() as *const core::ffi::c_void, + capacity as _, + MTLResourceOptions::StorageModeShared, + None, + ); + Ok(MetalArenaStorage { tensor, metal: buffer }) + } +} + +impl MetalArenaStorage { + /// Get underlying inner metal buffer. + pub fn metal(&self) -> &Buffer { + &self.metal + } + + pub fn tensor(&self) -> &Tensor { + &self.tensor + } +} + +impl Hash for MetalArenaStorage { + #[inline] + fn hash(&self, state: &mut H) { + self.tensor.hash(state) + } +} + #[derive(Debug, Clone, Hash)] pub struct MetalArenaView { pub(crate) arena: Arc, diff --git a/metal/src/tensor/mod.rs b/metal/src/tensor/mod.rs index d437694565..d1e6e78d87 100644 --- a/metal/src/tensor/mod.rs +++ b/metal/src/tensor/mod.rs @@ -1,8 +1,6 @@ -mod arena; mod arena_view; mod owned; -pub use arena::*; pub use arena_view::*; pub use owned::*; @@ -54,28 +52,20 @@ impl MetalTensor { }) } - // Create a metal tensor with a given shape and a slice of elements. The data is copied and aligned to size of T. - pub fn from_shape(shape: &[usize], data: &[T]) -> Result { - Tensor::from_shape(shape, data)?.into_metal() - } - /// Create an uninitialized MetalTensor pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> Result { - crate::METAL_CONTEXT - .with_borrow(|ctxt| { - ctxt.memory_arena() - .as_ref() - .and_then(|arena| arena.view_uninitialized_dt(dt, shape)) - .map(Self::ArenaView) - .map(Result::Ok) - }) - .unwrap_or_else(|| Tensor::uninitialized_dt(dt, shape)?.into_metal()) + Tensor::uninitialized_dt(dt, shape)?.into_metal() } pub unsafe fn uninitialized(shape: &[usize]) -> Result { Self::uninitialized_dt(T::datum_type(), shape) } + // Create a metal tensor with a given shape and a slice of elements. The data is copied and aligned to size of T. 
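+ // (Data always lands in a fresh owned buffer here; pooled allocations go through `crate::ops::make_tensor_for_node` instead.)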
+ pub fn from_shape(shape: &[usize], data: &[T]) -> Result { + Tensor::from_shape(shape, data)?.into_metal() + } + pub fn is_supported_dt(dt: DatumType) -> bool { Self::SUPPORTED_DT.contains(&dt) } @@ -265,6 +255,12 @@ impl From for Opaque { } } +impl From for MetalTensor { + fn from(view: MetalArenaView) -> Self { + Self::ArenaView(view) + } +} + impl OpaquePayload for MetalTensor { fn clarify_to_tensor(&self) -> Option> { Some(self.to_cpu()) diff --git a/metal/src/transform.rs b/metal/src/transform.rs index 4f6d770b05..0feec49b79 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -160,6 +160,13 @@ impl Translate, TypedFact, Box> for Met target: &mut TypedModel, mapping: &HashMap, ) -> TractResult> { + + if let Some(op) = node.op_as::() { + dbg!(op); + dbg!(source + .node_input_facts(node.id)?); + } + let in_dts_metal_compatible = source .node_input_facts(node.id)? .iter() diff --git a/metal/src/utils.rs b/metal/src/utils.rs index 374c7c790f..7ff2776d11 100644 --- a/metal/src/utils.rs +++ b/metal/src/utils.rs @@ -2,6 +2,26 @@ use crate::fact::{MetalFact, MetalFactKind, MetalTypedFactExt}; use num_traits::{AsPrimitive, Zero}; use tract_core::internal::*; +#[macro_export] +macro_rules! impl_eval_op_for_metal_op { + ($op:ty) => { + impl tract_core::internal::EvalOp for $op { + fn is_stateless(&self) -> bool { + false + } + + #[allow(unused_variables)] + fn state( + &self, + session: &mut tract_core::internal::SessionState, + node_id: usize, + ) -> TractResult>> { + Ok(Some(Box::new(crate::ops::MetalOpState::new(node_id, self.clone())))) + } + } + }; +} + pub fn metal_tmp_output_facts( facts: &[&TypedFact], resolve_facts: impl Fn(&[&TypedFact]) -> TractResult>,