Skip to content

Commit

Permalink
Add metal axis op for memory pool and fix clippy
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos committed Oct 21, 2024
1 parent 638b171 commit 7aff5aa
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 12 deletions.
1 change: 1 addition & 0 deletions metal/src/memory/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ impl MetalMemoryPool {
dt: DatumType,
shape: &[usize],
) -> Result<MetalTensor> {
// unsafe { Tensor::uninitialized_dt(dt, shape)?.into_metal() }
// ensure!(!self.node_seen.borrow().contains(&node_id), "Tensor for node {:?} was already requested. Maybe the memory pool was not reset properly.", node_id);
let alignment = dt.alignment();
(self.alignment % alignment == 0)
Expand Down
10 changes: 7 additions & 3 deletions metal/src/memory/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ impl Scope {
self.start <= step && step < self.end
}

/// Returns `true` when [`Self::len`] reports zero for this scope.
pub fn is_empty(&self) -> bool {
    let steps = self.len();
    steps == 0
}

/// Returns the size of the scope, computed as `end - start`.
///
/// NOTE(review): assumes `start <= end`; otherwise the subtraction would
/// underflow. Confirm that this invariant is enforced where scopes are built.
pub fn len(&self) -> usize {
    let (first, last) = (self.start, self.end);
    last - first
}
Expand Down Expand Up @@ -159,7 +163,7 @@ impl MetalMemSchema {
.map(|active_nodes| {
active_nodes
.iter()
.flat_map(|it| it)
.flatten()
.map(|it| it.mem_size.clone())
.sum::<TDim>()
.eval_to_i64(symbols)
Expand Down Expand Up @@ -217,7 +221,7 @@ impl MetalMemSchema {
order: &[usize],
hint: &SymbolValues,
) -> TractResult<MetalMemSchema> {
let mut scoped_nodes_mem = eval_metal_scope_node_mem(&model, &order)?;
let mut scoped_nodes_mem = eval_metal_scope_node_mem(model, order)?;

let hinted_mem_size = scoped_nodes_mem
.iter()
Expand Down Expand Up @@ -245,7 +249,7 @@ impl MetalMemSchema {
.collect::<Vec<_>>();

available.sort_by_cached_key(|n| {
n.nodes.iter().flat_map(|it| hinted_mem_size.get(&it.node)).sum::<i64>() * -1
-n.nodes.iter().flat_map(|it| hinted_mem_size.get(&it.node)).sum::<i64>()
});

match available.first_mut() {
Expand Down
22 changes: 16 additions & 6 deletions metal/src/ops/change_axes.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::kernels::array::PermuteAxes;
use crate::kernels::array::{Memcpy, PermuteAxes};
use crate::ops::MetalEvalOp;
use crate::{MetalContext, MetalTensorExt};
use std::fmt::Debug;
Expand Down Expand Up @@ -88,11 +88,21 @@ impl MetalEvalOp for MetalAxisOp {
}
};

if new_shape.as_slice() != input.shape() {
Ok(tvec![input.reshaped(new_shape)?.into_opaque_tensor().into_tvalue()])
} else {
Ok(tvec![opaque.into_tvalue()])
}
// TODO: find a way to avoid this copy, which is currently required by the memory pool integration

// if new_shape.as_slice() != input.shape() {
// Ok(tvec![input.reshaped(new_shape)?.into_opaque_tensor().into_tvalue()])
// } else {
// Ok(tvec![opaque.into_tvalue()])
// }

// Perform copy because of memory pool integration

let output =
crate::ops::make_tensor_for_node(context, node_id, input.datum_type(), &new_shape)?;

Memcpy.dispatch_eval(context, input, 0, &output)?;
Ok(tvec!(output.into_opaque_tensor().into_tvalue()))
}
}

Expand Down
4 changes: 2 additions & 2 deletions metal/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,6 @@ pub fn make_tensor_for_node(
context
.memory_pool()
.as_ref()
.map(|mem| mem.tensor_for_node(node_id, dt, &shape))
.unwrap_or_else(|| unsafe { MetalTensor::uninitialized_dt(dt, &shape) })
.map(|mem| mem.tensor_for_node(node_id, dt, shape))
.unwrap_or_else(|| unsafe { MetalTensor::uninitialized_dt(dt, shape) })
}
2 changes: 1 addition & 1 deletion metal/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ macro_rules! impl_eval_op_for_metal_op {
session: &mut tract_core::internal::SessionState,
node_id: usize,
) -> TractResult<Option<Box<dyn OpState>>> {
Ok(Some(Box::new(crate::ops::MetalOpState::new(node_id, self.clone()))))
Ok(Some(Box::new($crate::ops::MetalOpState::new(node_id, self.clone()))))
}
}
};
Expand Down

0 comments on commit 7aff5aa

Please sign in to comment.