diff --git a/core/src/ops/cnn/conv/conv.rs b/core/src/ops/cnn/conv/conv.rs
index 5ec7dac5cf..15f88862ae 100644
--- a/core/src/ops/cnn/conv/conv.rs
+++ b/core/src/ops/cnn/conv/conv.rs
@@ -516,6 +516,7 @@ impl Conv {
                     c_m_axis,
                     c_n_axis,
                     ops,
+                    packing == 0 && self.group == 1,
                 )?,
                 &wires,
             )
diff --git a/core/src/ops/einsum/optimize.rs b/core/src/ops/einsum/optimize.rs
index 773b470b44..c67f593afd 100644
--- a/core/src/ops/einsum/optimize.rs
+++ b/core/src/ops/einsum/optimize.rs
@@ -402,6 +402,8 @@ fn optimized_mat_mul(
     let outputs =
         mmms.iter().map(|(mmm, _packing)| unsafe { mmm.c_view(op.c_m(), op.c_n()) }).collect();
     let (mmms, packings): (Vec<_>, Vec<_>) = mmms.into_iter().unzip();
+    let trivial_packing =
+        mmms.len() == 1 && packings[0] == 0 && patch.outlet_fact(a)?.opaque_fact.is_none();
     let opt = OptMatMul::new(
         mmms,
         c_fact,
@@ -411,6 +413,7 @@ fn optimized_mat_mul(
             ProtoFusedSpec::AddMatMul { geo, a: 0, b: 1, packings },
             ProtoFusedSpec::Store(outputs),
         ],
+        trivial_packing,
     )
     .context("Creating OptMatMul")?;
     let output = patch.wire_node(name, opt, &[a, b])?[0];
diff --git a/core/src/ops/matmul/optimized.rs b/core/src/ops/matmul/optimized.rs
index d4d7556423..f059b52ff8 100644
--- a/core/src/ops/matmul/optimized.rs
+++ b/core/src/ops/matmul/optimized.rs
@@ -87,7 +87,7 @@ impl ProtoFusedSpec {
                     AsInputValue::Owned(Box::new(PanelExtractInput {
                         format,
                         data: data.clone(),
-                        to: a_packing.downcast_ref::<PackedFormat>().unwrap().clone()
+                        to: a_packing.downcast_ref::<PackedFormat>().unwrap().clone(),
                     }))
                 } else {
                     panic!("Un-matchable input and output for weights {:?} -> {a_packing}", a);
@@ -142,49 +142,43 @@ impl ProtoFusedSpec {
         &'t self,
         inputs: &'t [TValue],
         output: &mut Tensor,
-        mmm: &dyn MatMatMul,
+        _mmm: &dyn MatMatMul,
         scenario: usize,
     ) -> FusedSpec<'t> {
         let fs = match self {
-            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => {
-                let a = &inputs[*a];
-                let b = &inputs[*b];
-                let a = a.as_slice::<Opaque>().unwrap()[0]
-                    .downcast_ref::<Box<dyn MMMInputValue>>()
-                    .unwrap();
-                let b = b.as_slice::<Opaque>().unwrap()[0]
-                    .downcast_ref::<Box<dyn MMMInputValue>>()
-                    .unwrap();
-                let (a_packing, b_packing) = &mmm.packings()[packings[scenario]];
-                let pa = if a_packing.same_as(a.format()) {
-                    AsInputValue::Borrowed(&**a)
-                } else if a_packing.is::<PackedFormat>()
-                    && a_packing.r() == a.format().r()
-                    && a.is::<EagerPackedInput>()
-                    && a.format().is::<PackedBlockQuantFormat>()
+            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
+                debug_assert!(inputs.get(*a).is_some());
+                debug_assert!(inputs.get(*b).is_some());
+                let a = inputs.get_unchecked(*a);
+                let b = inputs.get_unchecked(*b);
+                debug_assert!(a.datum_type().is_opaque());
+                debug_assert!(a.len() == 1);
+                debug_assert!(b.datum_type().is_opaque());
+                debug_assert!(b.len() == 1);
+                let a = a.as_slice_unchecked::<Opaque>().get_unchecked(0);
+                let b = b.as_slice_unchecked::<Opaque>().get_unchecked(0);
+                debug_assert!(a.is::<Box<dyn MMMInputValue>>());
+                debug_assert!(b.is::<Box<dyn MMMInputValue>>());
+                let a = a.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
+                let b = b.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
+                #[cfg(debug_assertions)]
                 {
-                    let format = PanelExtractFormat {
-                        pbqf: a.format().downcast_ref::<PackedBlockQuantFormat>().unwrap().clone(),
-                    };
-                    let data = a.downcast_ref::<EagerPackedInput>().unwrap();
-                    AsInputValue::Owned(Box::new(PanelExtractInput {
-                        format,
-                        data: data.clone(),
-                        to: a_packing.downcast_ref::<PackedFormat>().unwrap().clone()
-                    }))
-                } else {
-                    panic!("Un-matchable input and output for weights {:?} -> {a_packing}", a);
-                };
-                assert!(
-                    b_packing.same_as(b.format())
-                        || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
-                );
+                    let (a_packing, b_packing) = &_mmm.packings()[packings[scenario]];
+                    debug_assert!(
+                        a_packing.same_as(a.format())
+                            || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
+                    );
+                    debug_assert!(
+                        b_packing.same_as(b.format())
+                            || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
+                    );
+                }
                 FusedSpec::AddMatMul {
-                    a: pa,
+                    a: AsInputValue::Borrowed(&**a),
                     b: AsInputValue::Borrowed(&**b),
                     packing: packings[scenario],
                 }
-            }
+            },
             ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
             ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
             ProtoFusedSpec::BinPerRow(v, op, _) => {
@@ -302,6 +296,7 @@ pub struct OptMatMul {
     pub mmm: Vec<Box<dyn MatMatMul>>,
     pub c_m_axis: usize,
     pub c_n_axis: usize,
+    pub trivial_packing: bool,
     pub trivial_path: bool,
 }
 
@@ -591,10 +586,19 @@ impl OptMatMul {
         c_m_axis: usize,
         c_n_axis: usize,
         micro_ops: Vec<ProtoFusedSpec>,
+        trivial_packing: bool,
     ) -> TractResult<Self> {
         ensure!(c_m_axis < c_fact.rank());
         ensure!(c_n_axis < c_fact.rank());
-        let mut it = OptMatMul { mmm, c_fact, c_m_axis, c_n_axis, micro_ops, trivial_path: false };
+        let mut it = OptMatMul {
+            mmm,
+            c_fact,
+            c_m_axis,
+            c_n_axis,
+            micro_ops,
+            trivial_path: false,
+            trivial_packing,
+        };
         it.update_trivial_path();
         Ok(it)
     }
@@ -626,6 +630,7 @@ impl OptMatMul {
             .iter()
             .enumerate()
             .all(|(ax, dim)| ax == self.c_m_axis || ax == self.c_n_axis || dim.is_one())
+            && self.trivial_packing
             && self.micro_ops.iter().all(|o| o.is_trivial())
     }
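
The `trivial_packing` flag added in `optimize.rs` is computed once at graph-optimization time (a single kernel, its default packing 0, no pre-packed opaque payload on the A input) and then folded into `update_trivial_path`, so the per-run dispatch stays a single boolean test. A minimal sketch of that construction-time gating, with illustrative names that are not tract's API:

```rust
// Sketch only: `Plan`, `kernels` and `selected_packing` stand in for
// OptMatMul, its mmm list and the chosen packing index in the diff above.
struct Plan {
    kernels: Vec<String>,
    trivial_packing: bool,
    trivial_path: bool,
}

impl Plan {
    fn new(kernels: Vec<String>, selected_packing: usize) -> Plan {
        // Mirrors `mmms.len() == 1 && packings[0] == 0`: the fast path is
        // only eligible with a single kernel using its default packing.
        let trivial_packing = kernels.len() == 1 && selected_packing == 0;
        let mut plan = Plan { kernels, trivial_packing, trivial_path: false };
        plan.update_trivial_path();
        plan
    }

    fn update_trivial_path(&mut self) {
        // Construction-time eligibility is one conjunct of the run-time
        // fast-path decision, as in the `is_trivial` change above.
        self.trivial_path = self.trivial_packing && !self.kernels.is_empty();
    }
}

fn main() {
    assert!(Plan::new(vec!["mmm_f32".into()], 0).trivial_path);
    assert!(!Plan::new(vec!["a".into(), "b".into()], 0).trivial_path);
}
```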
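The rewritten `resolve_trivial` arm applies a common Rust hot-path idiom: state every invariant with `debug_assert!` so debug builds keep the checks, then rely on unchecked accessors (`get_unchecked`, `as_slice_unchecked`, `unwrap_unchecked`) in release builds, where the checks would cost on every call. A self-contained sketch of the idiom, under the assumption that a prior planning step has established the invariant (names are illustrative):

```rust
// Sketch only: `planned_get` stands in for the opaque-input lookups in
// `ProtoFusedSpec::resolve_trivial`; the "planner" is simulated by the caller.
fn planned_get(inputs: &[f32], index: usize) -> f32 {
    // Debug builds still trip loudly if the invariant is violated...
    debug_assert!(inputs.get(index).is_some());
    // ...while release builds skip the bounds check entirely.
    // SAFETY: callers only reach this path after verifying the index,
    // just as `update_trivial_path` vets the fast path ahead of time.
    unsafe { *inputs.get_unchecked(index) }
}

fn main() {
    let inputs = [1.0f32, 2.0, 3.0];
    assert_eq!(planned_get(&inputs, 2), 3.0);
}
```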
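The `mmm` parameter becomes `_mmm` because release builds no longer read it: the packing-compatibility checks that used it now live in a `#[cfg(debug_assertions)]` block, and the underscore silences the unused-parameter lint when that block compiles away. A small sketch of the same idiom (hypothetical names):

```rust
// Sketch only: `_label` plays the role of `_mmm`, consulted exclusively
// inside the debug-only block.
fn store(_label: &str, slot: &mut Option<u32>, value: u32) {
    #[cfg(debug_assertions)]
    {
        // Compiled out in release builds, so `_label` is otherwise unused.
        debug_assert!(slot.is_none(), "slot {_label} written twice");
    }
    *slot = Some(value);
}

fn main() {
    let mut slot = None;
    store("c0", &mut slot, 7);
    assert_eq!(slot, Some(7));
}
```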