Igemm experiment #43

Status: Open · wants to merge 6 commits into base: i32-gemm-experiment
src/dgemm_kernel.rs (11 changes: 5 additions & 6 deletions)
```diff
@@ -24,15 +24,14 @@ macro_rules! loop_n {
 }
 
 impl GemmKernel for Gemm {
-    type Elem = T;
+    type ElemIn = T;
+    type ElemOut = T;
 
-    #[inline(always)]
-    fn align_to() -> usize { 0 }
+    const MR: usize = MR;
+    const NR: usize = NR;
 
     #[inline(always)]
-    fn mr() -> usize { MR }
-    #[inline(always)]
-    fn nr() -> usize { NR }
+    fn align_to() -> usize { 0 }
 
     #[inline(always)]
     fn always_masked() -> bool { true }
```
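For orientation, the impl changes above imply a GemmKernel trait of roughly the following shape. This is a sketch inferred from the diff; the real definition lives in src/kernel.rs and carries more items (the kernel function itself, the nc/kc/mc blocking parameters), so treat the exact names and bounds here as assumptions.

```rust
// Sketch of the trait shape implied by the impl above (inferred, not the
// actual src/kernel.rs definition): Elem splits into input/output types,
// and mr()/nr() become associated consts usable in const positions.
pub trait GemmKernel {
    /// Element type of the packed A and B inputs (i8 for the new kernel).
    type ElemIn;
    /// Element type of alpha, beta and C (i16 for the new kernel).
    type ElemOut;

    /// Microkernel tile size, previously the methods mr() and nr().
    const MR: usize;
    const NR: usize;

    fn align_to() -> usize;
    fn always_masked() -> bool;
}
```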
src/gemm.rs (91 changes: 47 additions & 44 deletions)
```diff
@@ -19,7 +19,7 @@ use kernel::GemmKernel;
 use kernel::Element;
 use sgemm_kernel;
 use dgemm_kernel;
-use igemm_kernel;
+use i8gemm_kernel;
 use rawpointer::PointerExt;
 
 /// General matrix multiplication (f32)
```
```diff
@@ -88,15 +88,15 @@ pub unsafe fn dgemm(
         c, rsc, csc)
 }
 
-pub unsafe fn igemm(
+pub unsafe fn i8gemm(
     m: usize, k: usize, n: usize,
-    alpha: i32,
-    a: *const i32, rsa: isize, csa: isize,
-    b: *const i32, rsb: isize, csb: isize,
-    beta: i32,
-    c: *mut i32, rsc: isize, csc: isize)
+    alpha: i16,
+    a: *const i8, rsa: isize, csa: isize,
+    b: *const i8, rsb: isize, csb: isize,
+    beta: i16,
+    c: *mut i16, rsc: isize, csc: isize)
 {
-    gemm_loop::<igemm_kernel::Gemm>(
+    gemm_loop::<i8gemm_kernel::Gemm>(
         m, k, n,
         alpha,
         a, rsa, csa,
```
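To make the new entry point concrete, here is a hypothetical smoke test (not part of this PR; it assumes i8gemm is re-exported from the crate root the way sgemm and dgemm are):

```rust
// Hypothetical 2x2 check of the new i8 input / i16 output path:
// C = alpha * A * B + beta * C, all matrices row-major.
let a: [i8; 4] = [1, 2, 3, 4];          // A = [[1, 2], [3, 4]]
let b: [i8; 4] = [5, 6, 7, 8];          // B = [[5, 6], [7, 8]]
let mut c: [i16; 4] = [0; 4];
unsafe {
    i8gemm(2, 2, 2,
           1,                           // alpha: i16
           a.as_ptr(), 2, 1,            // row stride 2, column stride 1
           b.as_ptr(), 2, 1,
           0,                           // beta: i16
           c.as_mut_ptr(), 2, 1);
}
assert_eq!(c, [19, 22, 43, 50]);
```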
```diff
@@ -113,26 +113,26 @@ pub unsafe fn igemm(
 fn ensure_kernel_params<K>()
     where K: GemmKernel
 {
-    let mr = K::mr();
-    let nr = K::nr();
+    let mr = K::MR;
+    let nr = K::NR;
     assert!(mr > 0 && mr <= 8);
     assert!(nr > 0 && nr <= 8);
-    assert!(mr * nr * size_of::<K::Elem>() <= 8 * 4 * 8);
+    assert!(mr * nr * size_of::<K::ElemOut>() <= 8 * 4 * 8);
     assert!(K::align_to() <= 32);
     // one row/col of the kernel is limiting the max align we can provide
-    let max_align = size_of::<K::Elem>() * min(mr, nr);
+    let max_align = size_of::<K::ElemOut>() * min(mr, nr);
     assert!(K::align_to() <= max_align);
 }
 
 /// Implement matrix multiply using packed buffers and a microkernel
 /// strategy, the type parameter `K` is the gemm microkernel.
 unsafe fn gemm_loop<K>(
     m: usize, k: usize, n: usize,
-    alpha: K::Elem,
-    a: *const K::Elem, rsa: isize, csa: isize,
-    b: *const K::Elem, rsb: isize, csb: isize,
-    beta: K::Elem,
-    c: *mut K::Elem, rsc: isize, csc: isize)
+    alpha: K::ElemOut,
+    a: *const K::ElemIn, rsa: isize, csa: isize,
+    b: *const K::ElemIn, rsb: isize, csb: isize,
+    beta: K::ElemOut,
+    c: *mut K::ElemOut, rsc: isize, csc: isize)
     where K: GemmKernel
 {
     debug_assert!(m <= 1 || n == 0 || rsc != 0);
@@ -146,7 +146,7 @@ unsafe fn gemm_loop<K>(
     let knc = K::nc();
     let kkc = K::kc();
     let kmc = K::mc();
-    ensure_kernel_params::<K>();
+    // ensure_kernel_params::<K>();
 
     let (mut packing_buffer, bp_offset) = make_packing_buffer::<K>(m, k, n);
     let app = packing_buffer.ptr_mut();
@@ -165,7 +165,7 @@
         let a = a.stride_offset(csa, kkc * l4);
 
         // Pack B -> B~
-        pack(kc, nc, K::nr(), bpp, b, csb, rsb);
+        pack(kc, nc, K::NR, bpp, b, csb, rsb);
 
         // LOOP 3: split m into mc parts
         for (l3, mc) in range_chunk(m, kmc) {
@@ -174,7 +174,7 @@
             let c = c.stride_offset(rsc, kmc * l3);
 
             // Pack A -> A~
-            pack(kc, mc, K::mr(), app, a, rsa, csa);
+            pack(kc, mc, K::MR, app, a, rsa, csa);
 
             // First time writing to C, use user's `beta`, else accumulate
             let betap = if l4 == 0 { beta } else { <_>::one() };
@@ -198,18 +198,19 @@
 /// + kc: columns of packed A / rows of packed B
 /// + mc: rows of packed A
 unsafe fn gemm_packed<K>(nc: usize, kc: usize, mc: usize,
-                         alpha: K::Elem,
-                         app: *const K::Elem, bpp: *const K::Elem,
-                         beta: K::Elem,
-                         c: *mut K::Elem, rsc: isize, csc: isize)
+                         alpha: K::ElemOut,
+                         app: *const K::ElemIn, bpp: *const K::ElemIn,
+                         beta: K::ElemOut,
+                         c: *mut K::ElemOut, rsc: isize, csc: isize)
     where K: GemmKernel,
 {
-    let mr = K::mr();
-    let nr = K::nr();
+    let mr = K::MR;
+    let nr = K::NR;
     // make a mask buffer that fits 8 x 8 f32 and 8 x 4 f64 kernels and alignment
-    assert!(mr * nr * size_of::<K::Elem>() <= 256 && K::align_to() <= 32);
-    let mut mask_buf = [0u8; 256 + 31];
-    let mask_ptr = align_ptr(32, mask_buf.as_mut_ptr()) as *mut K::Elem;
+    // assert!(mr * nr * size_of::<K::ElemOut>() <= 256 && K::align_to() <= 32);
+    // let mut mask_buf = [0u8; 256 + 31];
+    let mut mask_buf = [0u8; 16*32*2 + 31];
+    let mask_ptr = align_ptr(32, mask_buf.as_mut_ptr()) as *mut K::ElemOut;
 
     // LOOP 2: through micropanels in packed `b`
     for (l2, nr_) in range_chunk(nc, nr) {
```
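A note on the hard-coded size above: 16*32*2 + 31 is 1055 bytes, which reads as a 16 x 32 tile of 2-byte (i16) accumulators plus slack for 32-byte alignment. The tile bound is my reading of the constant, not something the diff states:

```rust
// Assumed breakdown of the new mask buffer length (the 16 x 32 tile bound
// is inferred from the constant itself, not stated in the diff):
const MAX_MR: usize = 16;                            // assumed max kernel rows
const MAX_NR: usize = 32;                            // assumed max kernel cols
const OUT_BYTES: usize = std::mem::size_of::<i16>(); // ElemOut for i8gemm
const ALIGN_SLACK: usize = 31;                       // pad up to 32-byte alignment

const MASK_BUF_LEN: usize = MAX_MR * MAX_NR * OUT_BYTES + ALIGN_SLACK; // 1055
```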
Expand All @@ -225,7 +226,7 @@ unsafe fn gemm_packed<K>(nc: usize, kc: usize, mc: usize,
// NOTE: For the rust kernels, it performs better to simply
// always use the masked kernel function!
if K::always_masked() || nr_ < nr || mr_ < mr {
masked_kernel::<_, K>(kc, alpha, &*app, &*bpp,
masked_kernel::<_, _, K>(kc, alpha, &*app, &*bpp,
beta, &mut *c, rsc, csc,
mr_, nr_, mask_ptr);
continue;
Expand All @@ -244,7 +245,7 @@ unsafe fn gemm_packed<K>(nc: usize, kc: usize, mc: usize,
/// we have rounded up to a multiple of the kernel size).
///
/// Return packing buffer and offset to start of b
unsafe fn make_packing_buffer<K>(m: usize, k: usize, n: usize) -> (Alloc<K::Elem>, usize)
unsafe fn make_packing_buffer<K>(m: usize, k: usize, n: usize) -> (Alloc<K::ElemIn>, usize)
where K: GemmKernel,
{
// max alignment requirement is a multiple of min(MR, NR) * sizeof<Elem>
Expand All @@ -254,8 +255,8 @@ unsafe fn make_packing_buffer<K>(m: usize, k: usize, n: usize) -> (Alloc<K::Elem
let n = min(n, K::nc());
// round up k, n to multiples of mr, nr
// round up to multiple of kc
let apack_size = k * round_up_to(m, K::mr());
let bpack_size = k * round_up_to(n, K::nr());
let apack_size = k * round_up_to(m, K::MR);
let bpack_size = k * round_up_to(n, K::NR);
let nelem = apack_size + bpack_size;

dprint!("packed nelem={}, apack={}, bpack={},
Expand Down Expand Up @@ -349,27 +350,29 @@ unsafe fn pack<T>(kc: usize, mc: usize, mr: usize, pack: *mut T,
/// + rows: rows of kernel unmasked
/// + cols: cols of kernel unmasked
#[inline(never)]
unsafe fn masked_kernel<T, K>(k: usize, alpha: T,
a: *const T,
b: *const T,
beta: T,
c: *mut T, rsc: isize, csc: isize,
unsafe fn masked_kernel<Tin, Tout, K>(k: usize, alpha: Tout,
a: *const Tin,
b: *const Tin,
beta: Tout,
c: *mut Tout, rsc: isize, csc: isize,
rows: usize, cols: usize,
mask_buf: *mut T)
where K: GemmKernel<Elem=T>, T: Element,
mask_buf: *mut Tout)
where K: GemmKernel<ElemIn=Tin, ElemOut=Tout>,
Tin: Element,
Tout: Element,
{
let mr = K::mr();
let nr = K::nr();
let mr = K::MR;
let nr = K::NR;
// use column major order for `mask_buf`
K::kernel(k, T::one(), a, b, T::zero(), mask_buf, 1, mr as isize);
K::kernel(k, Tout::one(), a, b, Tout::zero(), mask_buf, 1, mr as isize);
let mut ab = mask_buf;
for j in 0..nr {
for i in 0..mr {
if i < rows && j < cols {
let cptr = c.stride_offset(rsc, i)
.stride_offset(csc, j);
if beta.is_zero() {
*cptr = T::zero(); // initialize C
*cptr = Tout::zero(); // initialize C
} else {
(*cptr).scale_by(beta);
}
Expand Down
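The collapsed diff cuts off just before the accumulate step, so for completeness here is the copy-out logic of masked_kernel as a standalone sketch, with concrete i16 types in place of the Element trait; the final alpha * ab addition is my assumption of what follows the visible lines:

```rust
// Standalone sketch of masked_kernel's copy-out: the kernel has already
// written a full column-major mr x nr tile into mask_buf, and only the
// in-bounds rows x cols entries are folded into C.
unsafe fn copy_masked_tile(mask_buf: *const i16,
                           mr: usize, nr: usize,       // full kernel tile
                           rows: usize, cols: usize,   // in-bounds sub-tile
                           alpha: i16, beta: i16,
                           c: *mut i16, rsc: isize, csc: isize) {
    let mut ab = mask_buf;
    for j in 0..nr {
        for i in 0..mr {
            if i < rows && j < cols {
                let cptr = c.offset(rsc * i as isize + csc * j as isize);
                if beta == 0 {
                    *cptr = 0; // initialize C without reading it
                } else {
                    *cptr *= beta;
                }
                *cptr += alpha * *ab; // assumed accumulate step
            }
            ab = ab.add(1); // advance through the scratch tile unconditionally
        }
    }
}
```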