less flexible types to reduce matmul overhead
kali committed Oct 11, 2024
1 parent 9a1cb28 commit 7e6a457
Showing 8 changed files with 17 additions and 18 deletions.
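
Every file below makes the same substitution: `packings()` no longer returns `Cow<[(&dyn MMMInputFormat, &dyn MMMInputFormat)]>`, which forced each call to collect a fresh `Vec` of borrowed pairs, but instead returns `&[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)]`, a plain borrow of the slice the kernel already owns. A minimal before/after sketch of that trade-off, using simplified stand-in types (the `Kernel`, `PackedFormat` and `MMMInputFormat` here are illustrative, not tract's real definitions):

use std::borrow::Cow;
use std::fmt::Debug;

// Illustrative stand-in for tract's MMMInputFormat trait.
trait MMMInputFormat: Debug {}

#[derive(Debug)]
struct PackedFormat;
impl MMMInputFormat for PackedFormat {}

struct Kernel {
    packings: Vec<(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)>,
}

impl Kernel {
    // Before: the flexible Cow signature forces an owned Vec (one allocation
    // per call) because the stored Boxes must be re-borrowed as &dyn pairs.
    fn packings_old(&self) -> Cow<[(&dyn MMMInputFormat, &dyn MMMInputFormat)]> {
        Cow::Owned(self.packings.iter().map(|p| (&*p.0, &*p.1)).collect())
    }

    // After: hand out the stored slice directly; no allocation, no copying.
    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
        &self.packings
    }
}

fn main() {
    let pair: (Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>) =
        (Box::new(PackedFormat), Box::new(PackedFormat));
    let kernel = Kernel { packings: vec![pair] };

    let _old = kernel.packings_old();   // builds a Vec every time it is called
    let (a, b) = &kernel.packings()[0]; // just a borrow into the kernel's storage
    println!("packing pair: {a:?} / {b:?}");
}

The `&mmm.packings()[...]` edits in the callers below follow directly from the second signature.
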
core/src/ops/matmul/optimized.rs (6 changes: 3 additions & 3 deletions)
@@ -34,7 +34,7 @@ impl ProtoFusedSpec {
use ProtoFusedSpec::*;
match self {
AddMatMul { geo, packings: packing, .. } => {
- let (a, b) = mmm.packings()[packing[scenario]];
+ let (a, b) = &mmm.packings()[packing[scenario]];
format!("matmul(k={}, {a:?}•{b:?})", geo.k)
}
BinScalar(_, op) => format!("scalar{op:?}"),
@@ -72,7 +72,7 @@ impl ProtoFusedSpec {
let b = b.as_slice::<Opaque>().unwrap()[0]
.downcast_ref::<Box<dyn MMMInputValue>>()
.unwrap();
- let (a_packing, b_packing) = mmm.packings()[packings[scenario]];
+ let (a_packing, b_packing) = &mmm.packings()[packings[scenario]];
let pa = if a_packing.same_as(a.format()) {
AsInputValue::Borrowed(&**a)
} else if a_packing.is::<PackedFormat>()
@@ -155,7 +155,7 @@ impl ProtoFusedSpec {
let b = b.as_slice::<Opaque>().unwrap()[0]
.downcast_ref::<Box<dyn MMMInputValue>>()
.unwrap();
- let (a_packing, b_packing) = mmm.packings()[packings[scenario]];
+ let (a_packing, b_packing) = &mmm.packings()[packings[scenario]];
let pa = if a_packing.same_as(a.format()) {
AsInputValue::Borrowed(&**a)
} else if a_packing.is::<PackedFormat>()
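
The `&` added at these call sites is not cosmetic: the old slice elements were `(&dyn _, &dyn _)` pairs, which are `Copy`, so indexing could copy a pair out; a pair of `Box`es is not `Copy` and can only be borrowed. A tiny stand-alone illustration of the compiler error the borrow avoids (std types only, nothing from tract):

use std::fmt::Debug;

fn main() {
    // A pair of boxed trait objects, like the new packings() slice elements.
    let pair: (Box<dyn Debug>, Box<dyn Debug>) = (Box::new(1u8), Box::new(2u8));
    let pairs = vec![pair];

    // let (a, b) = pairs[0];  // error[E0507]: cannot move out of index of `Vec<...>`
    let (a, b) = &pairs[0];    // borrow the pair instead, as the diff above does
    println!("{a:?}, {b:?}");
}
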
linalg/benches/mat_vec.rs (2 changes: 1 addition & 1 deletion)
@@ -16,7 +16,7 @@ fn mat_vec_mul(c: &mut Criterion) {
&(m, k),
|be, (&m, &k)| {
let mmm = tract_linalg::ops().mmm(F32, Some(m), Some(k), Some(1)).unwrap();
- let packing = mmm.packings()[0];
+ let packing = &mmm.packings()[0];
let a = Tensor::zero::<f32>(&[m, k]).unwrap();
let pa = packing.0.prepare_tensor(&a, 1, 0).unwrap();
let b = Tensor::zero::<f32>(&[k, 1]).unwrap();
linalg/benches/mm_for_inception.rs (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ fn mat_mul_smmm(be: &mut criterion::Bencher, &(m, k, n): &(usize, usize, usize))
let mmm = tract_linalg::ops().mmm(F32, Some(m), Some(k), Some(n)).unwrap();
let a = Tensor::zero::<f32>(&[m, k]).unwrap();
let b = Tensor::zero::<f32>(&[k, n]).unwrap();
- let packing = mmm.packings()[0];
+ let packing = &mmm.packings()[0];
let pa = packing.0.prepare_tensor(&a, 1, 0).unwrap();
let pb = packing.1.prepare_tensor(&b, 0, 1).unwrap();

linalg/benches/utils.rs (2 changes: 1 addition & 1 deletion)
@@ -82,7 +82,7 @@ pub fn mat_mat_with_mm(
) {
let a = Tensor::zero_dt(dt, &[m, k]).unwrap();
let b = Tensor::zero_dt(dt, &[k, n]).unwrap();
- let packing = mmm.packings()[0];
+ let packing = &mmm.packings()[0];
let pa = packing.0.prepare_tensor(&a, 1, 0).unwrap();
let pb = packing.1.prepare_tensor(&b, 0, 1).unwrap();
unsafe {
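
In the benches the rest of each call chain is untouched, because `packing.0.prepare_tensor(...)` still reads the same: field access auto-derefs through the new reference, and the method call auto-derefs through the `Box`. A small stand-alone sketch of that auto-deref chain, with a hypothetical `Format` trait standing in for `MMMInputFormat`:

use std::fmt::Debug;

// Hypothetical trait standing in for MMMInputFormat, with one method.
trait Format: Debug {
    fn name(&self) -> &'static str;
}

#[derive(Debug)]
struct Packed;
impl Format for Packed {
    fn name(&self) -> &'static str {
        "packed"
    }
}

fn main() {
    let pair: (Box<dyn Format>, Box<dyn Format>) = (Box::new(Packed), Box::new(Packed));
    let packings = vec![pair];

    // Borrow the pair, as the benches now do.
    let packing = &packings[0];
    // `.0` reaches through the reference, `.name()` reaches through the Box,
    // so the call shape is identical to the pre-change code.
    println!("a: {}, b: {}", packing.0.name(), packing.1.name());
}
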
linalg/src/frame/mmm/kernel.rs (11 changes: 5 additions & 6 deletions)
@@ -1,5 +1,4 @@
use pack::PackedFormat;
- use tract_itertools::Itertools;

use super::*;
use std::borrow::Cow;
@@ -14,7 +13,7 @@ pub trait MatMatMulKer: Clone + Debug + Send + Sync + 'static {
fn mr(&self) -> usize;
fn nr(&self) -> usize;

- fn packings(&self) -> Cow<[(&dyn MMMInputFormat, &dyn MMMInputFormat)]>;
+ fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
fn stores(&self) -> Cow<[DatumType]>;

#[allow(unused_variables)]
@@ -46,7 +45,7 @@ impl<const MR: usize, const NR: usize, Acc: LADatum> DynKernel<MR, NR, Acc> {
name: &str,
kernel: Kernel<Acc>,
default_packing_alignments: (usize, usize),
- ) -> Self {
+ ) -> Self {
let kernel = DynKernel {
name: name.to_string(),
kernel,
@@ -130,11 +129,11 @@ impl<const MR: usize, const NR: usize, Acc: LADatum> MatMatMulKer for DynKernel<
unsafe { (self.kernel)(op) }
}

- fn packings(&self) -> Cow<[(&dyn MMMInputFormat, &dyn MMMInputFormat)]> {
- Cow::Owned(self.packings.iter().map(|p| (&*p.0, &*p.1)).collect_vec())
+ fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
+ &self.packings
}

fn stores(&self) -> Cow<[DatumType]> {
Cow::Borrowed(&self.stores)
}
- }
+ }
linalg/src/frame/mmm/mod.rs (4 changes: 2 additions & 2 deletions)
@@ -34,7 +34,7 @@ pub trait MatMatMul: Debug + dyn_clone::DynClone + Send + Sync + std::any::Any {
fn mr(&self) -> usize;
fn nr(&self) -> usize;

- fn packings(&self) -> Cow<[(&dyn MMMInputFormat, &dyn MMMInputFormat)]>;
+ fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];

fn internal_type(&self) -> DatumType;

@@ -89,7 +89,7 @@ impl<K: MatMatMulKer> MatMatMul for K {
self.nr()
}

- fn packings(&self) -> Cow<[(&dyn MMMInputFormat, &dyn MMMInputFormat)]> {
+ fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
self.packings()
}

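
The same signature lands on the object-safe `MatMatMul` trait, and its blanket impl over `MatMatMulKer` now forwards a borrowed slice instead of rebuilding a `Cow`, so both statically dispatched and `Box<dyn MatMatMul>` callers get the packings without any per-call work. A reduced sketch of that two-trait arrangement, with simplified bounds and an illustrative `DummyKernel` (only the two `packings` signatures are taken from the diff; the real impl writes `self.packings()`, while the sketch uses the fully qualified call to keep it unambiguous):

use std::fmt::Debug;

// Illustrative stand-ins; bounds and methods are heavily simplified.
trait MMMInputFormat: Debug {}

#[derive(Debug)]
struct PackedFormat;
impl MMMInputFormat for PackedFormat {}

// Kernel-side trait, statically dispatched (sketch of MatMatMulKer).
trait MatMatMulKer {
    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
}

// Object-safe facade used as Box<dyn MatMatMul> (sketch of MatMatMul).
trait MatMatMul {
    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)];
}

// Blanket impl: forwarding a borrowed slice is free for either dispatch path.
impl<K: MatMatMulKer> MatMatMul for K {
    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
        MatMatMulKer::packings(self)
    }
}

struct DummyKernel {
    packings: Vec<(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)>,
}

impl MatMatMulKer for DummyKernel {
    fn packings(&self) -> &[(Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>)] {
        &self.packings
    }
}

fn main() {
    let pair: (Box<dyn MMMInputFormat>, Box<dyn MMMInputFormat>) =
        (Box::new(PackedFormat), Box::new(PackedFormat));
    let mmm: Box<dyn MatMatMul> = Box::new(DummyKernel { packings: vec![pair] });

    // Same call shape as the benches: borrow the first packing pair.
    let (a, b) = &mmm.packings()[0];
    println!("{a:?} / {b:?}");
}
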
linalg/src/frame/mmm/tests/packed_packed.rs (6 changes: 3 additions & 3 deletions)
@@ -254,7 +254,7 @@ impl<K: MatMatMulKer> PackedPackedProblem<K> {
}

pub fn padded_inputs(&self) -> TractResult<(Tensor, Tensor)> {
- let (pack_a, pack_b) = self.ker.packings()[self.packing];
+ let (pack_a, pack_b) = &self.ker.packings()[self.packing];
assert!(pack_b.k_alignment() == 1);
let (m, k, n) = self.mkn();
let k_aligned = k.next_multiple_of(pack_a.k_alignment());
@@ -283,7 +283,7 @@ impl<K: MatMatMulKer> PackedPackedProblem<K> {

pub fn reference(&self) -> TractResult<Tensor> {
let (m, k, n) = self.mkn();
- let pack_a = self.ker.packings()[self.packing].0;
+ let pack_a = &self.ker.packings()[self.packing].0;
let (mut a, b) = self.padded_inputs()?;
let k_aligned = k.next_multiple_of(pack_a.k_alignment());
if let Some(pbqf) = pack_a.downcast_ref::<PackedBlockQuantFormat>() {
@@ -310,7 +310,7 @@ impl<K: MatMatMulKer> PackedPackedProblem<K> {

pub fn run(&self) -> TractResult<Tensor> {
let (m, k, n) = self.mkn();
- let (pack_a, pack_b) = self.ker.packings()[self.packing];
+ let (pack_a, pack_b) = &self.ker.packings()[self.packing];
assert!(pack_b.k_alignment() == 1);
let k_aligned = k.next_multiple_of(pack_a.k_alignment());

linalg/tests/virtual_im2col.rs (2 changes: 1 addition & 1 deletion)
@@ -151,7 +151,7 @@ impl ConvProblem {
let mmm = tract_linalg::ops().mmm(F32, Some(m), Some(k), Some(n)).unwrap();
let output = Tensor::zero::<f32>(&internal_output_shape)?;
let reshaped_filters = self.filters.clone().into_shape(&[k, m])?;
- let (a_pack, b_pack) = mmm.packings()[0];
+ let (a_pack, b_pack) = &mmm.packings()[0];
let a = a_pack.prepare_tensor(&reshaped_filters, 0, 1)?;
unsafe {
let im2col: Box<dyn MMMInputValue> = if self.lazy_im2col {
