less flexible types to reduce matmul overhead #1555

Merged · 5 commits · Oct 14, 2024 · Changes from all commits
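At a glance (a summary inferred from the hunks below, not wording from the PR itself): the packings() accessor on the MatMatMulKer and MatMatMul traits stops returning an owned Cow of trait-object references and instead borrows the kernel's own packing list, and OptMatMul gains a trivial_packing flag so the common case (one kernel, default packing, plain tensor inputs) can skip per-call packing checks in eval. The two trait signatures, quoted from the linalg hunks below:

```rust
// Before: each call may allocate a fresh Vec behind the Cow.
fn packings(&self) -> Cow<[(&dyn MMMInputFormat, &dyn MMMInputFormat)]>;

// After: callers borrow the kernel-owned list; no per-call allocation.
fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
```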
1 change: 1 addition & 0 deletions core/src/ops/cnn/conv/conv.rs
@@ -516,6 +516,7 @@ impl Conv {
                 c_m_axis,
                 c_n_axis,
                 ops,
+                packing == 0 && self.group == 1,
             )?,
             &wires,
         )
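(Note, inferred from the hunk above: the convolution call site reports trivial packing only for ungrouped convolutions using packing slot 0; grouped convolutions presumably still need the per-call format matching that the trivial path skips.)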
3 changes: 3 additions & 0 deletions core/src/ops/einsum/optimize.rs
@@ -402,6 +402,8 @@ fn optimized_mat_mul(
     let outputs =
         mmms.iter().map(|(mmm, _packing)| unsafe { mmm.c_view(op.c_m(), op.c_n()) }).collect();
     let (mmms, packings): (Vec<_>, Vec<_>) = mmms.into_iter().unzip();
+    let trivial_packing =
+        mmms.len() == 1 && packings[0] == 0 && patch.outlet_fact(a)?.opaque_fact.is_none();
     let opt = OptMatMul::new(
         mmms,
         c_fact,
@@ -411,6 +413,7 @@
             ProtoFusedSpec::AddMatMul { geo, a: 0, b: 1, packings },
             ProtoFusedSpec::Store(outputs),
         ],
+        trivial_packing,
     )
     .context("Creating OptMatMul")?;
     let output = patch.wire_node(name, opt, &[a, b])?[0];
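In the einsum path, then, "trivial" means a single selected kernel, the kernel's default packing (index 0, conventionally the native one), and an A input carrying no opaque fact, i.e. not pre-packed or block-quantized. Only in that case may eval hand its inputs straight to the kernel, as the optimized.rs hunk below shows.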
83 changes: 44 additions & 39 deletions core/src/ops/matmul/optimized.rs
@@ -34,7 +34,7 @@ impl ProtoFusedSpec {
         use ProtoFusedSpec::*;
         match self {
             AddMatMul { geo, packings: packing, .. } => {
-                let (a, b) = mmm.packings()[packing[scenario]];
+                let (a, b) = &mmm.packings()[packing[scenario]];
                 format!("matmul(k={}, {a:?}•{b:?})", geo.k)
             }
             BinScalar(_, op) => format!("scalar{op:?}"),
@@ -72,7 +72,7 @@ impl ProtoFusedSpec {
                 let b = b.as_slice::<Opaque>().unwrap()[0]
                     .downcast_ref::<Box<dyn MMMInputValue>>()
                     .unwrap();
-                let (a_packing, b_packing) = mmm.packings()[packings[scenario]];
+                let (a_packing, b_packing) = &mmm.packings()[packings[scenario]];
                 let pa = if a_packing.same_as(a.format()) {
                     AsInputValue::Borrowed(&**a)
                 } else if a_packing.is::<PackedFormat>()
@@ -87,7 +87,7 @@
                     AsInputValue::Owned(Box::new(PanelExtractInput {
                         format,
                         data: data.clone(),
-                        to: a_packing.downcast_ref::<PackedFormat>().unwrap().clone()
+                        to: a_packing.downcast_ref::<PackedFormat>().unwrap().clone(),
                     }))
                 } else {
                     panic!("Un-matchable input and output for weights {:?} -> {a_packing}", a);
@@ -142,49 +142,43 @@ impl ProtoFusedSpec {
         &'t self,
         inputs: &'t [TValue],
         output: &mut Tensor,
-        mmm: &dyn MatMatMul,
+        _mmm: &dyn MatMatMul,
         scenario: usize,
     ) -> FusedSpec<'t> {
         let fs = match self {
-            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => {
-                let a = &inputs[*a];
-                let b = &inputs[*b];
-                let a = a.as_slice::<Opaque>().unwrap()[0]
-                    .downcast_ref::<Box<dyn MMMInputValue>>()
-                    .unwrap();
-                let b = b.as_slice::<Opaque>().unwrap()[0]
-                    .downcast_ref::<Box<dyn MMMInputValue>>()
-                    .unwrap();
-                let (a_packing, b_packing) = mmm.packings()[packings[scenario]];
-                let pa = if a_packing.same_as(a.format()) {
-                    AsInputValue::Borrowed(&**a)
-                } else if a_packing.is::<PackedFormat>()
-                    && a_packing.r() == a.format().r()
-                    && a.is::<EagerPackedInput>()
-                    && a.format().is::<PackedBlockQuantFormat>()
-                {
-                    let format = PanelExtractFormat {
-                        pbqf: a.format().downcast_ref::<PackedBlockQuantFormat>().unwrap().clone(),
-                    };
-                    let data = a.downcast_ref::<EagerPackedInput>().unwrap();
-                    AsInputValue::Owned(Box::new(PanelExtractInput {
-                        format,
-                        data: data.clone(),
-                        to: a_packing.downcast_ref::<PackedFormat>().unwrap().clone()
-                    }))
-                } else {
-                    panic!("Un-matchable input and output for weights {:?} -> {a_packing}", a);
-                };
-                assert!(
-                    b_packing.same_as(b.format())
-                        || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
-                );
+            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
+                debug_assert!(inputs.get(*a).is_some());
+                debug_assert!(inputs.get(*b).is_some());
+                let a = inputs.get_unchecked(*a);
+                let b = inputs.get_unchecked(*b);
+                debug_assert!(a.datum_type().is_opaque());
+                debug_assert!(a.len() == 1);
+                debug_assert!(b.datum_type().is_opaque());
+                debug_assert!(b.len() == 1);
+                let a = a.as_slice_unchecked::<Opaque>().get_unchecked(0);
+                let b = b.as_slice_unchecked::<Opaque>().get_unchecked(0);
+                debug_assert!(a.is::<Box<dyn MMMInputValue>>());
+                debug_assert!(b.is::<Box<dyn MMMInputValue>>());
+                let a = a.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
+                let b = b.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
+                #[cfg(debug_assertions)]
+                {
+                    let (a_packing, b_packing) = &_mmm.packings()[packings[scenario]];
+                    debug_assert!(
+                        a_packing.same_as(a.format())
+                            || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
+                    );
+                    debug_assert!(
+                        b_packing.same_as(b.format())
+                            || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
+                    );
+                }
                 FusedSpec::AddMatMul {
-                    a: pa,
+                    a: AsInputValue::Borrowed(&**a),
                     b: AsInputValue::Borrowed(&**b),
                     packing: packings[scenario],
                 }
-            }
+            },
             ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
             ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
             ProtoFusedSpec::BinPerRow(v, op, _) => {
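The new eval body leans on a standard Rust pattern: assert every invariant with debug_assert! (compiled out of release builds), then use unchecked accessors on the hot path. A minimal, self-contained sketch of the pattern (hypothetical function, not PR code):

```rust
/// # Safety
/// The caller must guarantee `idx < values.len()`; debug builds verify it.
unsafe fn pick(values: &[u32], idx: usize) -> u32 {
    debug_assert!(idx < values.len()); // free in release builds
    *values.get_unchecked(idx) // no bounds check in release builds
}

fn main() {
    let v = [1u32, 2, 3];
    // The caller upholds the invariant that idx is in range.
    assert_eq!(unsafe { pick(&v, 2) }, 3);
}
```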
@@ -302,6 +296,7 @@ pub struct OptMatMul {
     pub mmm: Vec<Box<dyn MatMatMul>>,
     pub c_m_axis: usize,
     pub c_n_axis: usize,
+    pub trivial_packing: bool,
     pub trivial_path: bool,
 }

@@ -591,10 +586,19 @@ impl OptMatMul {
         c_m_axis: usize,
         c_n_axis: usize,
         micro_ops: Vec<ProtoFusedSpec>,
+        trivial_packing: bool,
     ) -> TractResult<Self> {
         ensure!(c_m_axis < c_fact.rank());
         ensure!(c_n_axis < c_fact.rank());
-        let mut it = OptMatMul { mmm, c_fact, c_m_axis, c_n_axis, micro_ops, trivial_path: false };
+        let mut it = OptMatMul {
+            mmm,
+            c_fact,
+            c_m_axis,
+            c_n_axis,
+            micro_ops,
+            trivial_path: false,
+            trivial_packing,
+        };
         it.update_trivial_path();
         Ok(it)
     }
@@ -626,6 +630,7 @@ impl OptMatMul {
             .iter()
             .enumerate()
             .all(|(ax, dim)| ax == self.c_m_axis || ax == self.c_n_axis || dim.is_one())
+            && self.trivial_packing
             && self.micro_ops.iter().all(|o| o.is_trivial())
     }

31 changes: 22 additions & 9 deletions data/src/tensor.rs
@@ -159,6 +159,7 @@ pub fn vector_size() -> usize {

 impl Tensor {
     #[allow(unreachable_code, unexpected_cfgs)]
+    #[inline]
     pub fn default_alignment(dt: DatumType, shape: &[usize]) -> usize {
         if shape.len() == 0 {
             dt.alignment()
@@ -168,16 +169,19 @@ impl Tensor {
     }

     /// Create an uninitialized tensor (dt as type paramater).
+    #[inline]
     pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
         Self::uninitialized_dt(T::datum_type(), shape)
     }

     /// Create an uninitialized tensor (dt as regular parameter).
+    #[inline]
     pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
         Self::uninitialized_aligned_dt(dt, shape, dt.alignment())
     }

     /// Create an uninitialized tensor with a given alignment (in bytes).
+    #[inline]
     pub unsafe fn uninitialized_aligned<T: Datum>(
         shape: &[usize],
         alignment: usize,
@@ -194,7 +198,11 @@
         let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
         let data = Blob::new_for_size_and_align(bytes, alignment);
         let mut tensor = Tensor { strides: tvec!(), dt, shape: shape.into(), data, len: 0 };
-        tensor.update_strides_and_len();
+        if tensor.shape.len() == 0 {
+            tensor.len = 1;
+        } else {
+            tensor.update_strides_and_len();
+        }
         if !tensor.data.is_empty() {
             if dt == String::datum_type() || dt == Blob::datum_type() {
                 // assumes zero-initialized string and blob are valid
@@ -445,12 +453,12 @@ impl Tensor {

     fn update_strides_and_len(&mut self) {
         self.strides.clear();
-        compute_natural_stride_to(&mut self.strides, &self.shape);
-        self.len = if self.rank() == 0 {
-            1
-        } else {
-            unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) }
+        if self.shape.len() == 0 {
+            self.len = 1;
+            return;
         }
+        compute_natural_stride_to(&mut self.strides, &self.shape);
+        self.len = unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) };
     }

     /// Force the tensor shape, no consistency check.
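Both tensor.rs changes special-case rank-0 tensors. The underlying arithmetic fact, as a self-contained check (not PR code): an empty shape has an empty product, which is 1, so a scalar tensor holds exactly one element and needs no strides.

```rust
fn main() {
    let shape: [usize; 0] = [];
    // The empty product is 1: a rank-0 tensor holds exactly one element.
    assert_eq!(shape.iter().product::<usize>(), 1);
}
```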
@@ -1336,18 +1344,23 @@ impl Tensor {
         }
     }

-    fn from_datum<T: Datum>(it: ArrayD<T>) -> Tensor {
+    fn from_datum<T: Datum>(mut it: ArrayD<T>) -> Tensor {
         unsafe {
             let mut t = Self::uninitialized::<T>(it.shape()).unwrap();
-            if let Some(slice) = it.as_slice() {
+            if let Some(slice) = it.as_slice_mut() {
                 if t.datum_type().is_copy() {
                     std::ptr::copy_nonoverlapping(
                         slice.as_ptr() as *const i8,
                         t.as_ptr_mut_unchecked(),
                         t.data.layout().size(),
                     );
-                    return t;
+                } else {
+                    t.as_slice_mut_unchecked::<T>()
+                        .iter_mut()
+                        .zip(slice.iter_mut())
+                        .for_each(|(t, s)| *t = std::mem::take(s));
                 }
+                return t;
             }
             if it.strides().iter().all(|&s| s > 0) {
                 let mut len_and_strides: TVec<(usize, usize)> = tvec!();
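For non-Copy element types, from_datum now moves elements out of the ndarray buffer instead of falling through to the slower generic path. A self-contained sketch of the std::mem::take move (illustrative, not PR code):

```rust
fn main() {
    let mut src = vec![String::from("a"), String::from("b")];
    let mut dst = vec![String::new(); src.len()];
    // Move each element out, leaving a cheap default behind: no clones.
    dst.iter_mut()
        .zip(src.iter_mut())
        .for_each(|(d, s)| *d = std::mem::take(s));
    assert_eq!(dst, ["a", "b"]);
    assert_eq!(src, ["", ""]); // sources are left in their default state
}
```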
30 changes: 16 additions & 14 deletions data/src/tensor/litteral.rs
@@ -3,7 +3,7 @@ use crate::datum::Datum;
 use ndarray::*;
 use std::sync::Arc;

-pub fn arr4<A, const N: usize, const M: usize, const T: usize>(xs: &[[[[A;T];M];N]]) -> Array4<A>
+pub fn arr4<A, const N: usize, const M: usize, const T: usize>(xs: &[[[[A; T]; M]; N]]) -> Array4<A>
 where
     A: Clone,
 {
@@ -28,25 +28,28 @@ where
 }

 pub fn tensor0<A: Datum>(x: A) -> Tensor {
-    Tensor::from(arr0(x))
+    unsafe {
+        let mut tensor = Tensor::uninitialized::<A>(&[]).unwrap();
+        tensor.as_slice_mut_unchecked::<A>()[0] = x;
+        tensor
+    }
 }

 pub fn tensor1<A: Datum>(xs: &[A]) -> Tensor {
     Tensor::from(arr1(xs))
 }

-pub fn tensor2<A: Datum, const N: usize>(xs: &[[A;N]]) -> Tensor
-{
+pub fn tensor2<A: Datum, const N: usize>(xs: &[[A; N]]) -> Tensor {
     Tensor::from(arr2(xs))
 }

-pub fn tensor3<A: Datum, const N: usize, const M: usize>(xs: &[[[A;M];N]]) -> Tensor
-{
+pub fn tensor3<A: Datum, const N: usize, const M: usize>(xs: &[[[A; M]; N]]) -> Tensor {
     Tensor::from(arr3(xs))
 }

-pub fn tensor4<A: Datum, const N: usize, const M: usize, const T: usize>(xs: &[[[[A;T];M];N]]) -> Tensor
-{
+pub fn tensor4<A: Datum, const N: usize, const M: usize, const T: usize>(
+    xs: &[[[[A; T]; M]; N]],
+) -> Tensor {
     Tensor::from(arr4(xs))
 }
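tensor0 now writes the scalar straight into an uninitialized rank-0 tensor instead of round-tripping through ndarray's arr0; this is sound because a rank-0 tensor has len 1, per the tensor.rs change above. Usage is unchanged; a sketch assuming tract_data's prelude exports:

```rust
use tract_data::prelude::*;

fn main() -> TractResult<()> {
    let t = tensor0(3.14f32);
    assert_eq!(t.rank(), 0);
    assert_eq!(t.to_scalar::<f32>()?, &3.14);
    Ok(())
}
```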

@@ -58,17 +61,16 @@ pub fn rctensor1<A: Datum>(xs: &[A]) -> Arc<Tensor> {
     Arc::new(Tensor::from(arr1(xs)))
 }

-pub fn rctensor2<A: Datum, const N: usize>(xs: &[[A;N]]) -> Arc<Tensor>
-{
+pub fn rctensor2<A: Datum, const N: usize>(xs: &[[A; N]]) -> Arc<Tensor> {
     Arc::new(Tensor::from(arr2(xs)))
 }

-pub fn rctensor3<A: Datum, const N: usize, const M: usize>(xs: &[[[A;M];N]]) -> Arc<Tensor>
-{
+pub fn rctensor3<A: Datum, const N: usize, const M: usize>(xs: &[[[A; M]; N]]) -> Arc<Tensor> {
     Arc::new(Tensor::from(arr3(xs)))
 }

-pub fn rctensor4<A: Datum, const N: usize, const M: usize, const T: usize>(xs: &[[[[A;T];M];N]]) -> Arc<Tensor>
-{
+pub fn rctensor4<A: Datum, const N: usize, const M: usize, const T: usize>(
+    xs: &[[[[A; T]; M]; N]],
+) -> Arc<Tensor> {
     Arc::new(Tensor::from(arr4(xs)))
 }
2 changes: 1 addition & 1 deletion linalg/benches/mat_vec.rs
@@ -16,7 +16,7 @@ fn mat_vec_mul(c: &mut Criterion) {
         &(m, k),
         |be, (&m, &k)| {
             let mmm = tract_linalg::ops().mmm(F32, Some(m), Some(k), Some(1)).unwrap();
-            let packing = mmm.packings()[0];
+            let packing = &mmm.packings()[0];
             let a = Tensor::zero::<f32>(&[m, k]).unwrap();
             let pa = packing.0.prepare_tensor(&a, 1, 0).unwrap();
             let b = Tensor::zero::<f32>(&[k, 1]).unwrap();
2 changes: 1 addition & 1 deletion linalg/benches/mm_for_inception.rs
@@ -11,7 +11,7 @@ fn mat_mul_smmm(be: &mut criterion::Bencher, &(m, k, n): &(usize, usize, usize))
     let mmm = tract_linalg::ops().mmm(F32, Some(m), Some(k), Some(n)).unwrap();
     let a = Tensor::zero::<f32>(&[m, k]).unwrap();
     let b = Tensor::zero::<f32>(&[k, n]).unwrap();
-    let packing = mmm.packings()[0];
+    let packing = &mmm.packings()[0];
     let pa = packing.0.prepare_tensor(&a, 1, 0).unwrap();
     let pb = packing.1.prepare_tensor(&b, 0, 1).unwrap();
2 changes: 1 addition & 1 deletion linalg/benches/utils.rs
@@ -82,7 +82,7 @@ pub fn mat_mat_with_mm(
 ) {
     let a = Tensor::zero_dt(dt, &[m, k]).unwrap();
     let b = Tensor::zero_dt(dt, &[k, n]).unwrap();
-    let packing = mmm.packings()[0];
+    let packing = &mmm.packings()[0];
     let pa = packing.0.prepare_tensor(&a, 1, 0).unwrap();
     let pb = packing.1.prepare_tensor(&b, 0, 1).unwrap();
     unsafe {
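All three bench fixes are the same one-character change: packings() now yields (Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>) pairs, which are not Copy, so call sites borrow the pair (&mmm.packings()[0]) instead of copying two &dyn references out of the old Cow.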
9 changes: 5 additions & 4 deletions linalg/src/frame/mmm/kernel.rs
@@ -1,5 +1,4 @@
 use pack::PackedFormat;
-use tract_itertools::Itertools;

 use super::*;
 use std::borrow::Cow;
@@ -14,7 +13,8 @@ pub trait MatMatMulKer: Clone + Debug + Send + Sync + 'static {
     fn mr(&self) -> usize;
     fn nr(&self) -> usize;

-    fn packings(&self) -> Cow<[(&dyn MMMInputFormat, &dyn MMMInputFormat)]>;
+    #[allow(clippy::type_complexity)]
+    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
     fn stores(&self) -> Cow<[DatumType]>;

     #[allow(unused_variables)]
@@ -130,8 +130,9 @@ impl<const MR: usize, const NR: usize, Acc: LADatum> MatMatMulKer for DynKernel<
         unsafe { (self.kernel)(op) }
     }

-    fn packings(&self) -> Cow<[(&dyn MMMInputFormat, &dyn MMMInputFormat)]> {
-        Cow::Owned(self.packings.iter().map(|p| (&*p.0, &*p.1)).collect_vec())
+    #[allow(clippy::type_complexity)]
+    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
+        &self.packings
     }

     fn stores(&self) -> Cow<[DatumType]> {
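This hunk is the heart of the PR's "less flexible types" trade: DynKernel used to build a fresh Vec behind Cow::Owned on every packings() call, on the matmul hot path, while the new version hands back a borrow of its own field. A self-contained sketch of the cost being removed (all names hypothetical, not tract API):

```rust
use std::borrow::Cow;

struct Kernel {
    packings: Vec<(u32, u32)>,
}

impl Kernel {
    // Shape of the old API: allocates a fresh Vec on every call.
    fn packings_cow(&self) -> Cow<[(u32, u32)]> {
        Cow::Owned(self.packings.iter().copied().collect())
    }
    // Shape of the new API: borrows the field, allocation-free.
    fn packings_slice(&self) -> &[(u32, u32)] {
        &self.packings
    }
}

fn main() {
    let k = Kernel { packings: vec![(0, 0), (1, 1)] };
    assert_eq!(k.packings_cow()[0], k.packings_slice()[0]);
}
```

The less flexible return type gives up the option of synthesizing packings on the fly, and that restriction is exactly what makes the allocation-free borrow possible.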
5 changes: 3 additions & 2 deletions linalg/src/frame/mmm/mod.rs
@@ -34,7 +34,8 @@ pub trait MatMatMul: Debug + dyn_clone::DynClone + Send + Sync + std::any::Any {
     fn mr(&self) -> usize;
     fn nr(&self) -> usize;

-    fn packings(&self) -> Cow<[(&dyn MMMInputFormat, &dyn MMMInputFormat)]>;
+    #[allow(clippy::type_complexity)]
+    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];

     fn internal_type(&self) -> DatumType;
@@ -89,7 +90,7 @@ impl<K: MatMatMulKer> MatMatMul for K {
         self.nr()
     }

-    fn packings(&self) -> Cow<[(&dyn MMMInputFormat, &dyn MMMInputFormat)]> {
+    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
         self.packings()
     }