From a4d41dc6997fb22b1bd1d4a4aace0ffa1f9b66da Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Sun, 6 Oct 2024 20:52:40 -0400
Subject: [PATCH] Add cublaslt matmul

---
 Cargo.lock                                    |  27 +-
 Cargo.toml                                    |   5 +-
 mistralrs-core/Cargo.toml                     |   5 +-
 mistralrs-core/src/cublaslt/api.rs            | 243 +++++++-
 mistralrs-core/src/cublaslt/matmul.rs         | 560 ++++++++++++++++++
 mistralrs-core/src/cublaslt/mod.rs            |   8 +-
 mistralrs-core/src/diffusion_models/t5/mod.rs |   3 +
 mistralrs-core/src/ops.rs                     |   8 +
 mistralrs-paged-attn/Cargo.toml               |   1 +
 mistralrs-paged-attn/src/backend/mod.rs       |   1 +
 mistralrs-pyo3/Cargo_template.toml            |   2 +-
 mistralrs-quant/Cargo.toml                    |   1 +
 mistralrs-quant/src/utils/ops.rs              |   8 +
 mistralrs-quant/src/utils/uqff.rs             |  10 +-
 14 files changed, 864 insertions(+), 18 deletions(-)
 create mode 100644 mistralrs-core/src/cublaslt/matmul.rs

diff --git a/Cargo.lock b/Cargo.lock
index 121b6c542..12a10b160 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -393,13 +393,14 @@ checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3"
 
 [[package]]
 name = "candle-core"
 version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=4754afe#4754afe3b735d1c8c9bdddd5849250b411b471a7"
 dependencies = [
  "accelerate-src",
  "byteorder",
  "candle-kernels",
  "candle-metal-kernels",
  "cudarc",
+ "float8",
  "gemm",
  "half",
  "intel-mkl-src",
@@ -420,7 +421,7 @@ dependencies = [
 [[package]]
 name = "candle-flash-attn"
 version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=4754afe#4754afe3b735d1c8c9bdddd5849250b411b471a7"
 dependencies = [
  "anyhow",
  "bindgen_cuda 0.1.5",
@@ -431,7 +432,7 @@ dependencies = [
 [[package]]
 name = "candle-kernels"
 version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=4754afe#4754afe3b735d1c8c9bdddd5849250b411b471a7"
 dependencies = [
  "bindgen_cuda 0.1.5",
 ]
@@ -439,7 +440,7 @@ dependencies = [
 [[package]]
 name = "candle-metal-kernels"
 version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=4754afe#4754afe3b735d1c8c9bdddd5849250b411b471a7"
 dependencies = [
  "metal",
  "once_cell",
@@ -450,7 +451,7 @@ dependencies = [
 [[package]]
 name = "candle-nn"
 version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=4754afe#4754afe3b735d1c8c9bdddd5849250b411b471a7"
 dependencies = [
  "accelerate-src",
  "candle-core",
@@ -1119,6 +1120,19 @@ dependencies = [
  "miniz_oxide 0.8.0",
 ]
 
+[[package]]
+name = "float8"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c7c3475274d374d263c4c40c43ad854c5bdf733c7db775bbd3c1ca2ad7427978"
+dependencies = [
+ "cudarc",
+ "half",
+ "num-traits",
+ "rand",
+ "rand_distr",
+]
+
 [[package]]
 name = "flume"
 version = "0.11.0"
@@ -2155,6 +2169,7 @@ dependencies = [
  "derive_more",
  "dirs",
  "either",
+ "float8",
  "futures",
  "galil-seiferas",
  "half",
@@ -2208,6 +2223,7 @@ dependencies = [
  "anyhow",
  "bindgen_cuda 0.1.6",
  "candle-core",
+ "float8",
  "half",
 ]
 
@@ -2243,6 +2259,7 @@ dependencies = [
 "byteorder",
 "candle-core",
 "candle-nn",
+ "float8",
 "half",
 "lazy_static",
 "paste",
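The only functional additions to the lockfile are the new `float8` crate (plus the candle rev bump), which supplies the `F8E4M3` type used throughout the rest of this patch. A minimal sketch of its semantics — `from_f32`/`to_f32` are assumptions based on the crate mirroring `half`'s API, while `ZERO`, `MAX` and `to_f64` are the items this patch actually uses:

```rust
use float8::F8E4M3;

fn main() {
    // E4M3: 1 sign, 4 exponent, 3 mantissa bits; the largest finite value is 448.
    let x = F8E4M3::from_f32(3.14159); // assumed constructor, mirroring `half`
    println!("pi quantizes to {}", x.to_f32());
    assert_eq!(F8E4M3::ZERO.to_f32(), 0.0); // used by the t5 changes below
    assert!(F8E4M3::MAX.to_f64() > 400.0); // used by clamp_for_f16 below
}
```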
"byteorder", "candle-core", "candle-nn", + "float8", "half", "lazy_static", "paste", diff --git a/Cargo.toml b/Cargo.toml index 109717ae2..9e13101dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,8 +25,8 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" } -candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "4754afe" } +candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "4754afe" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { version = "2.2.5", features = ["serde"] } @@ -49,3 +49,4 @@ rayon = "1.1.0" url = "2.5.2" data-url = "0.3.1" buildstructor = "0.5.4" +float8 = "0.1.1" diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index dbcfab697..1b0027d79 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true candle-nn.workspace = true serde.workspace = true serde_json.workspace = true -candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4", optional = true } +candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "4754afe", optional = true } dirs = "5.0.1" hf-hub = "0.3.2" thiserror = "1.0.57" @@ -78,10 +78,11 @@ regex = "1.10.6" safetensors = "0.4.5" serde_plain = "1.0.2" as-any = "0.3.1" +float8.workspace = true [features] pyo3_macros = ["pyo3"] -cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda"] +cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda", "float8/cuda"] cudnn = ["candle-core/cudnn"] metal = ["candle-core/metal", "candle-nn/metal"] flash-attn = ["cuda", "dep:candle-flash-attn"] diff --git a/mistralrs-core/src/cublaslt/api.rs b/mistralrs-core/src/cublaslt/api.rs index 24aca6ba2..25b2e636e 100644 --- a/mistralrs-core/src/cublaslt/api.rs +++ b/mistralrs-core/src/cublaslt/api.rs @@ -1,4 +1,4 @@ -pub use candle_core::cuda_backend::cudarc::cublaslt::Activation; +use float8::F8E4M3; use std::ffi::c_int; use candle_core::backend::BackendStorage; @@ -7,7 +7,7 @@ use candle_core::{CpuStorage, Device, Layout, Result, Shape, Storage, Tensor}; use half::{bf16, f16}; use std::sync::Arc; -use candle_core::cuda_backend::cudarc::cublaslt::{CudaBlasLT, Matmul, MatmulConfig}; +use super::matmul::{Activation, CudaBlasLT, Matmul, MatmulConfig}; #[derive(Debug, Clone)] pub struct CublasLt(Arc); @@ -318,6 +318,101 @@ impl CublasLTMatmul { Ok((out, out_shape)) } + + pub fn fwd_f8e4m3( + &self, + a: &candle_core::CudaStorage, + a_l: &Layout, + b: &candle_core::CudaStorage, + b_l: &Layout, + bias: Option<&candle_core::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle_core::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN + let (m, k) = a_l.shape().dims2()?; + + let (n, b_1) = b_l.shape().dims2()?; + + if b_1 != k { + candle_core::bail!("This layer only supports TN layout"); + } + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let 
diff --git a/mistralrs-core/src/cublaslt/api.rs b/mistralrs-core/src/cublaslt/api.rs
index 24aca6ba2..25b2e636e 100644
--- a/mistralrs-core/src/cublaslt/api.rs
+++ b/mistralrs-core/src/cublaslt/api.rs
@@ -1,4 +1,4 @@
-pub use candle_core::cuda_backend::cudarc::cublaslt::Activation;
+use float8::F8E4M3;
 use std::ffi::c_int;
 
 use candle_core::backend::BackendStorage;
@@ -7,7 +7,7 @@ use candle_core::{CpuStorage, Device, Layout, Result, Shape, Storage, Tensor};
 use half::{bf16, f16};
 use std::sync::Arc;
 
-use candle_core::cuda_backend::cudarc::cublaslt::{CudaBlasLT, Matmul, MatmulConfig};
+use super::matmul::{Activation, CudaBlasLT, Matmul, MatmulConfig};
 
 #[derive(Debug, Clone)]
 pub struct CublasLt(Arc<CudaBlasLT>);
@@ -318,6 +318,101 @@ impl CublasLTMatmul {
 
         Ok((out, out_shape))
     }
+
+    pub fn fwd_f8e4m3(
+        &self,
+        a: &candle_core::CudaStorage,
+        a_l: &Layout,
+        b: &candle_core::CudaStorage,
+        b_l: &Layout,
+        bias: Option<&candle_core::CudaStorage>,
+        bias_l: Option<&Layout>,
+    ) -> Result<(candle_core::CudaStorage, Shape)> {
+        let dev = a.device();
+
+        // Assume TN
+        let (m, k) = a_l.shape().dims2()?;
+
+        let (n, b_1) = b_l.shape().dims2()?;
+
+        if b_1 != k {
+            candle_core::bail!("This layer only supports TN layout");
+        }
+
+        let lda = k;
+        let ldb = k;
+        let ldc = m;
+
+        let out_shape = Shape::from((n, m));
+
+        let a = a.as_cuda_slice::<F8E4M3>()?.slice(a_l.start_offset()..);
+        let b = b.as_cuda_slice::<F8E4M3>()?.slice(b_l.start_offset()..);
+
+        let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) {
+            if bias_l.shape().dims1()? != m {
+                candle_core::bail!("Bias does not have the correct shape");
+            }
+
+            Some(bias.as_cuda_slice::<bf16>()?.slice(bias_l.start_offset()..))
+        } else {
+            None
+        };
+
+        let mut out = if let Some(c) = &self.c {
+            let (c, c_l) = c.storage_and_layout();
+            let c = match &*c {
+                Storage::Cuda(storage) => storage.as_cuda_slice::<bf16>()?,
+                _ => candle_core::bail!("`c` must be a cuda tensor"),
+            };
+            match c_l.contiguous_offsets() {
+                Some((o1, o2)) => {
+                    if o1 != 0 {
+                        candle_core::bail!("`c` start offset must be 0");
+                    }
+                    if o2 != out_shape.elem_count() {
+                        candle_core::bail!("`c` end offset must be {}", out_shape.elem_count())
+                    }
+                }
+                None => candle_core::bail!("`c` has to be contiguous"),
+            };
+            if c_l.shape().dims2()? != (n, m) {
+                candle_core::bail!("`c` does not have the correct shape");
+            }
+
+            c.clone()
+        } else {
+            // Allocate out tensor
+            unsafe { dev.alloc::<bf16>(out_shape.elem_count()).w()? }
+        };
+
+        let config = MatmulConfig {
+            transa: true,
+            transb: false,
+            m: m as u64,
+            n: n as u64,
+            k: k as u64,
+            alpha: self.alpha.unwrap_or(1.0),
+            lda: lda as i64,
+            ldb: ldb as i64,
+            beta: self.beta.unwrap_or(0.0),
+            ldc: ldc as i64,
+            stride_a: None,
+            stride_b: None,
+            stride_c: None,
+            stride_bias: None,
+            batch_size: None,
+        };
+
+        unsafe {
+            self.cublaslt
+                .matmul_fp8_like(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref())
+                .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?;
+        }
+
+        let out = candle_core::CudaStorage::wrap_cuda_slice(out, dev.clone());
+
+        Ok((out, out_shape))
+    }
 }
 
 impl candle_core::CustomOp2 for CublasLTMatmul {
@@ -346,7 +441,10 @@ impl candle_core::CustomOp2 for CublasLTMatmul {
             candle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None),
             candle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None),
             candle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None),
-            dt => candle_core::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})"),
+            candle_core::DType::F8E4M3 => self.fwd_f8e4m3(a, a_l, b, b_l, None, None),
+            dt => candle_core::bail!(
+                "cublaslt-matmul is only supported for f16/bf16/f32/f8e4m3 ({dt:?})"
+            ),
         }
     }
 }
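For orientation, this is roughly how the non-batched FP8 path would be driven through `fused_matmul`, assuming it keeps the same (a, b, c, alpha, beta, bias, act, cublaslt) argument order as the `fused_batch_matmul` call in the new test further down; `a`/`b` are F8E4M3 while the bias stays BF16, as the kernel requires. Hypothetical sketch, written as if alongside api.rs:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn fp8_matmul_example(device: &Device) -> Result<Tensor> {
    let a = Tensor::randn(0., 1., (8, 4), device)?.to_dtype(DType::F8E4M3)?; // (m, k)
    let b = Tensor::randn(0., 1., (2, 4), device)?.to_dtype(DType::F8E4M3)?; // (n, k)
    let bias = Tensor::randn(0., 1., 8, device)?.to_dtype(DType::BF16)?; // length m
    let cublaslt = CublasLt::new(device)?;
    // Output is (n, m) = (2, 8), produced in BF16 by the FP8 kernel.
    fused_matmul(&a, &b, None, None, None, Some(&bias), None, cublaslt)
}
```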
@@ -744,6 +842,109 @@ impl CublasLTBatchMatmul {
 
         Ok((out, out_shape))
     }
+
+    pub fn fwd_f8e4m3(
+        &self,
+        a: &candle_core::CudaStorage,
+        a_l: &Layout,
+        b: &candle_core::CudaStorage,
+        b_l: &Layout,
+        bias: Option<&candle_core::CudaStorage>,
+        bias_l: Option<&Layout>,
+    ) -> Result<(candle_core::CudaStorage, Shape)> {
+        let dev = a.device();
+
+        // Assume TN
+        let (batch_size, m, k) = a_l.shape().dims3()?;
+        let (b_0, n, b_2) = b_l.shape().dims3()?;
+
+        if b_2 != k {
+            candle_core::bail!("This layer only supports TN layout");
+        }
+
+        if b_0 != batch_size {
+            candle_core::bail!("`b` must have the same batch size as `a`")
+        }
+
+        let lda = k;
+        let ldb = k;
+        let ldc = m;
+
+        let out_shape = Shape::from((batch_size, n, m));
+
+        let a = a.as_cuda_slice::<F8E4M3>()?.slice(a_l.start_offset()..);
+        let b = b.as_cuda_slice::<F8E4M3>()?.slice(b_l.start_offset()..);
+
+        let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) {
+            if bias_l.shape().dims1()? != m {
+                candle_core::bail!("Bias does not have the correct shape");
+            }
+
+            Some(bias.as_cuda_slice::<bf16>()?.slice(bias_l.start_offset()..))
+        } else {
+            None
+        };
+
+        let (mut out, stride_c) = if let Some(c) = &self.c {
+            let (c, c_l) = c.storage_and_layout();
+            let c = match &*c {
+                Storage::Cuda(storage) => storage.as_cuda_slice::<bf16>()?,
+                _ => candle_core::bail!("`c` must be a cuda tensor"),
+            };
+            match c_l.contiguous_offsets() {
+                Some((o1, o2)) => {
+                    if o1 != 0 {
+                        candle_core::bail!("`c` start offset must be 0");
+                    }
+                    if o2 != out_shape.elem_count() {
+                        candle_core::bail!("`c` end offset must be {}", out_shape.elem_count())
+                    }
+                }
+                None => candle_core::bail!("`c` has to be contiguous"),
+            };
+
+            if c_l.shape().dims3()? != (batch_size, n, m) {
+                candle_core::bail!("`c` does not have the correct shape");
+            }
+
+            // Set beta to 0.0 if it is not set
+            (c.clone(), c_l.stride()[0])
+        } else {
+            // Allocate out tensor
+            (
+                unsafe { dev.alloc::<bf16>(out_shape.elem_count()).w()? },
+                (n * m),
+            )
+        };
+
+        let config = MatmulConfig {
+            transa: true,
+            transb: false,
+            m: m as u64,
+            n: n as u64,
+            k: k as u64,
+            alpha: self.alpha.unwrap_or(1.0),
+            lda: lda as i64,
+            ldb: ldb as i64,
+            beta: self.beta.unwrap_or(0.0),
+            ldc: ldc as i64,
+            stride_a: Some(a_l.stride()[0] as i64),
+            stride_b: Some(b_l.stride()[0] as i64),
+            stride_c: Some(stride_c as i64),
+            stride_bias: None,
+            batch_size: Some(c_int::try_from(batch_size)?),
+        };
+
+        unsafe {
+            self.cublaslt
+                .matmul_fp8_like(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref())
+                .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?;
+        }
+
+        let out = candle_core::CudaStorage::wrap_cuda_slice(out, dev.clone());
+
+        Ok((out, out_shape))
+    }
 }
 
 impl candle_core::CustomOp2 for CublasLTBatchMatmul {
@@ -811,8 +1012,9 @@ impl candle_core::CustomOp3 for CublasLTBatchMatmul {
             candle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)),
             candle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)),
             candle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)),
+            candle_core::DType::F8E4M3 => self.fwd_f8e4m3(a, a_l, b, b_l, Some(bias), Some(bias_l)),
             dt => candle_core::bail!(
-                "cublaslt-batch-matmul-add is only supported for f16/bf16/f32 ({dt:?})"
+                "cublaslt-batch-matmul-add is only supported for f16/bf16/f32/f8e4m3 ({dt:?})"
             ),
         }
     }
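For concreteness, here is the `MatmulConfig` that the batched `fwd_f8e4m3` above assembles for the shapes used in the test that follows (`a`: (3, 8, 4), `b`: (3, 2, 4), so m=8, n=2, k=4). Illustrative only — `MatmulConfig` lives in the new private `matmul` module introduced below, so this is written as if inside `mistralrs-core::cublaslt`:

```rust
use super::matmul::MatmulConfig;

fn example_config() -> MatmulConfig {
    MatmulConfig {
        transa: true, // `a` is (m, k) row-major and consumed transposed (TN)
        transb: false,
        m: 8,
        n: 2,
        k: 4,
        alpha: 1.0, // defaults when alpha/beta are unset
        lda: 4,     // lda = ldb = k, ldc = m
        ldb: 4,
        beta: 0.0,
        ldc: 8,
        stride_a: Some(32), // batch strides in elements: m*k, n*k, n*m
        stride_b: Some(8),
        stride_c: Some(16),
        stride_bias: None,
        batch_size: Some(3),
    }
}
```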
@@ -937,4 +1139,37 @@ mod tests {
             .all(|x| x.into_iter().all(|y| y.into_iter().all(|x| *x <= range))));
         Ok(())
     }
+
+    #[test]
+    fn test_fused_batch_matmul_f8e4m3() -> Result<()> {
+        let device = Device::new_cuda(0)?;
+
+        let a = Tensor::randn(0., 1., (3, 8, 4), &device)?.to_dtype(DType::F32)?;
+        let b = Tensor::randn(0., 1., (3, 2, 4), &device)?.to_dtype(DType::F32)?;
+        let c = Tensor::randn(0., 1., (3, 2, 8), &device)?.to_dtype(DType::F32)?;
+        let bias = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
+
+        let cublaslt = CublasLt::new(&device)?;
+
+        let res = fused_batch_matmul(
+            &a.to_dtype(DType::F8E4M3)?,
+            &b.to_dtype(DType::F8E4M3)?,
+            Some(&c.to_dtype(DType::BF16)?),
+            None,
+            Some(1.0),
+            Some(&bias.to_dtype(DType::BF16)?),
+            None,
+            cublaslt,
+        )?;
+        let expected = (b.matmul(&a.t()?)?.add(&c)? + bias.broadcast_left((3, 2))?)?;
+
+        let abs_diff = (res.to_dtype(DType::F32)? - expected)?
+            .abs()?
+            .to_vec3::<f32>()?;
+        let range = 1e-02;
+        assert!(abs_diff
+            .iter()
+            .all(|x| x.into_iter().all(|y| y.into_iter().all(|x| *x <= range))));
+        Ok(())
+    }
 }
diff --git a/mistralrs-core/src/cublaslt/matmul.rs b/mistralrs-core/src/cublaslt/matmul.rs
new file mode 100644
index 000000000..f2e55d3f3
--- /dev/null
+++ b/mistralrs-core/src/cublaslt/matmul.rs
@@ -0,0 +1,560 @@
+use candle_core::cuda::cudarc::cublaslt::result::set_matrix_layout_attribute;
+use candle_core::cuda::cudarc::cublaslt::{result, result::CublasError, sys};
+use candle_core::cuda::cudarc::driver::sys::{CUdevice_attribute, CUdeviceptr, CUstream};
+use candle_core::cuda::cudarc::driver::{
+    CudaDevice, CudaSlice, DevicePtr, DevicePtrMut, DriverError,
+};
+use core::ffi::c_int;
+use core::mem;
+use float8::F8E4M3;
+use half::bf16;
+use std::sync::Arc;
+
+/// Wrapper around [sys::cublasLtHandle_t]
+///
+/// 1. Create with [CudaBlasLT::new()]
+/// 2. Execute matmul kernel with matmul. f32 is supported. f16 and bf16 are supported
+///    if feature `half` is activated
+///
+/// Note: This maintains an instance of [`Arc<CudaDevice>`], so will prevent the device
+/// from being dropped. Kernels will be launched on the device's default stream.
+#[derive(Debug)]
+pub struct CudaBlasLT {
+    handle: sys::cublasLtHandle_t,
+    workspace: Workspace,
+    device: Arc<CudaDevice>,
+}
+
+unsafe impl Send for CudaBlasLT {}
+
+unsafe impl Sync for CudaBlasLT {}
+
+impl CudaBlasLT {
+    /// Creates a new cublasLt handle.
+    pub fn new(device: Arc<CudaDevice>) -> Result<Self, CublasError> {
+        let handle = result::create_handle()?;
+        let workspace = Workspace::new(device.clone()).unwrap();
+
+        Ok(Self {
+            handle,
+            workspace,
+            device,
+        })
+    }
+}
+
+impl Drop for CudaBlasLT {
+    fn drop(&mut self) {
+        let handle = mem::replace(&mut self.handle, std::ptr::null_mut());
+        if !handle.is_null() {
+            unsafe { result::destroy_handle(handle) }.unwrap();
+        }
+    }
+}
+
+/// User owned CublasLt workspace buffer.
+/// The workspace is initialised following the Nvidia recommendations:
+///
+/// 1. NVIDIA Hopper Architecture: 32 MiB
+/// 2. Other: 4 MiB
+#[derive(Debug, Clone)]
+pub struct Workspace {
+    pub(crate) buffer: CudaSlice<u8>,
+    pub(crate) size: usize,
+}
+
+impl Workspace {
+    /// Creates a CublasLt workspace buffer on the provided device
+    pub fn new(device: Arc<CudaDevice>) -> Result<Self, DriverError> {
+        device.bind_to_thread()?;
+
+        let major =
+            device.attribute(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?;
+        let workspace_size = if major >= 9 { 33_554_432 } else { 4_194_304 };
+
+        let buffer = unsafe { device.alloc::<u8>(workspace_size)? };
+        Ok(Self {
+            buffer,
+            size: workspace_size,
+        })
+    }
+}
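`Workspace::new` above follows NVIDIA's sizing guidance: 32 MiB on Hopper (compute capability major ≥ 9), 4 MiB otherwise. The same rule as a standalone helper, using only calls that already appear in this file:

```rust
use candle_core::cuda::cudarc::driver::{sys::CUdevice_attribute, CudaDevice, DriverError};

// 33_554_432 == 32 * 1024 * 1024 and 4_194_304 == 4 * 1024 * 1024,
// i.e. the 32 MiB / 4 MiB recommendation spelled out.
fn recommended_workspace_bytes(device: &CudaDevice) -> Result<usize, DriverError> {
    let major =
        device.attribute(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?;
    Ok(if major >= 9 { 32 * 1024 * 1024 } else { 4 * 1024 * 1024 })
}
```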
+
+/// Available activation for kernel fusing in matmul
+#[derive(Debug, Clone)]
+pub enum Activation {
+    Relu,
+    Gelu,
+}
+
+/// MatrixLayout helper type
+struct MatrixLayout {
+    handle: sys::cublasLtMatrixLayout_t,
+}
+
+impl MatrixLayout {
+    fn new(
+        matrix_type: sys::cudaDataType,
+        rows: u64,
+        cols: u64,
+        ld: i64,
+    ) -> Result<Self, CublasError> {
+        let handle = result::create_matrix_layout(matrix_type, rows, cols, ld)?;
+        Ok(Self { handle })
+    }
+
+    fn set_batch(&self, size: c_int, stride: i64) -> Result<(), CublasError> {
+        unsafe {
+            // Set batch size
+            set_matrix_layout_attribute(
+                self.handle,
+                sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
+                (&size) as *const _ as *const _,
+                mem::size_of::<c_int>(),
+            )?;
+            // Set batch stride
+            set_matrix_layout_attribute(
+                self.handle,
+                sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+                (&stride) as *const _ as *const _,
+                mem::size_of::<i64>(),
+            )?;
+        }
+        Ok(())
+    }
+}
+
+impl Drop for MatrixLayout {
+    fn drop(&mut self) {
+        // panic on failure
+        unsafe {
+            result::destroy_matrix_layout(self.handle).expect("Unable to destroy matrix layout")
+        }
+    }
+}
+
+enum Matrix {
+    A,
+    B,
+    #[allow(dead_code)]
+    C,
+}
+
+/// MatmulDesc helper type
+struct MatmulDesc {
+    handle: sys::cublasLtMatmulDesc_t,
+}
+
+impl MatmulDesc {
+    fn new(
+        compute_type: sys::cublasComputeType_t,
+        scale_type: sys::cudaDataType,
+    ) -> Result<Self, CublasError> {
+        let handle = result::create_matmul_desc(compute_type, scale_type)?;
+        Ok(Self { handle })
+    }
+
+    fn set_transpose(&self, transpose: bool, matrix: Matrix) -> Result<(), CublasError> {
+        // Set transpose
+        // 1 == T, 0 == N
+        let transpose = transpose as i32;
+        let attr = match matrix {
+            Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA,
+            Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB,
+            Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSC,
+        };
+
+        unsafe {
+            result::set_matmul_desc_attribute(
+                self.handle,
+                attr,
+                (&transpose) as *const _ as *const _,
+                mem::size_of::<i32>(),
+            )?;
+        }
+        Ok(())
+    }
+
+    // Epilogue system can be leveraged to fuse add and activation operations
+    fn set_epilogue(
+        &self,
+        act: Option<&Activation>,
+        bias_ptr: Option<&CUdeviceptr>,
+        stride_bias: Option<i64>,
+    ) -> Result<(), CublasError> {
+        let epilogue = if let Some(bias_ptr) = bias_ptr {
+            let epilogue = act
+                .map(|act| match act {
+                    // Act + bias
+                    Activation::Relu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS,
+                    Activation::Gelu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS,
+                })
+                // Only bias
+                .unwrap_or(sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS);
+
+            // Set bias CUdeviceptr in matmul_desc
+            unsafe {
+                result::set_matmul_desc_attribute(
+                    self.handle,
+                    sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+                    bias_ptr as *const CUdeviceptr as *const _,
+                    mem::size_of::<CUdeviceptr>(),
+                )?;
+            }
+
+            if let Some(stride_bias) = stride_bias {
+                // Set bias batch stride
+                unsafe {
+                    result::set_matmul_desc_attribute(
+                        self.handle,
+                        sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE,
+                        (&stride_bias) as *const _ as *const _,
+                        mem::size_of::<i64>(),
+                    )?;
+                }
+            }
+            epilogue
+        } else if let Some(act) = act {
+            // Only Act
+            match act {
+                Activation::Relu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU,
+                Activation::Gelu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU,
+            }
+        } else {
+            // No epilogue
+            sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT
+        };
+
+        // Set epilogue
+        unsafe {
+            result::set_matmul_desc_attribute(
+                self.handle,
+                sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE,
+                (&epilogue) as *const _ as *const _,
+                mem::size_of::<sys::cublasLtEpilogue_t>(),
+            )?;
+        }
+        Ok(())
+    }
+}
+
+impl Drop for MatmulDesc {
+    fn drop(&mut self) {
+        unsafe { result::destroy_matmul_desc(self.handle).expect("Unable to destroy matmul desc") }
+    }
+}
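`set_epilogue` is what lets a single kernel launch fold in the bias add and the activation. For reference, what the RELU_BIAS epilogue computes, spelled out with plain candle ops under the TN convention used by api.rs (a checking aid, not part of the patch):

```rust
use candle_core::{Result, Tensor};

// CUBLASLT_EPILOGUE_RELU_BIAS == relu(matmul + bias), fused into one kernel.
// `a` is (m, k), `b` is (n, k), `bias` has length m; output is (n, m).
fn relu_bias_reference(a: &Tensor, b: &Tensor, bias: &Tensor) -> Result<Tensor> {
    b.matmul(&a.t()?)?.broadcast_add(bias)?.relu()
}
```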
+
+/// MatmulPref helper type
+struct MatmulPref {
+    handle: sys::cublasLtMatmulPreference_t,
+}
+
+impl MatmulPref {
+    fn new() -> Result<Self, CublasError> {
+        let handle = result::create_matmul_pref()?;
+        Ok(Self { handle })
+    }
+
+    fn set_workspace_size(&self, size: usize) -> Result<(), CublasError> {
+        unsafe {
+            // Set workspace size
+            result::set_matmul_pref_attribute(
+                self.handle,
+                sys::cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+                (&size) as *const _ as *const _,
+                mem::size_of::<usize>(),
+            )?;
+        }
+        Ok(())
+    }
+}
+
+impl Drop for MatmulPref {
+    fn drop(&mut self) {
+        unsafe { result::destroy_matmul_pref(self.handle).expect("Unable to destroy matmul pref") }
+    }
+}
+
+/// [Matmul] super-trait
+pub trait MatmulShared {
+    /// Returns a reference to the underlying cublasLt handle.
+    fn handle(&self) -> &sys::cublasLtHandle_t;
+
+    /// Returns a reference to the underlying cublasLt workspace
+    fn workspace(&self) -> &Workspace;
+
+    /// Returns a reference to the underlying stream
+    fn stream(&self) -> &CUstream;
+}
+
+/// Configuration for [Matmul]
+#[derive(Debug, Copy, Clone)]
+pub struct MatmulConfig {
+    pub transa: bool,
+    pub transb: bool,
+    pub m: u64,
+    pub n: u64,
+    pub k: u64,
+    pub alpha: f32,
+    pub lda: i64,
+    pub ldb: i64,
+    pub beta: f32,
+    pub ldc: i64,
+    pub stride_a: Option<i64>,
+    pub stride_b: Option<i64>,
+    pub stride_c: Option<i64>,
+    pub stride_bias: Option<i64>,
+    pub batch_size: Option<c_int>,
+}
+
+/// Matrix matrix multiplication with elements of type `T`.
+pub trait Matmul<T>: MatmulShared {
+    /// Underlying CUDA Type for `T`
+    fn matrix_type() -> sys::cudaDataType;
+
+    /// Underlying CUDA Compute Type for `T`
+    fn compute_type() -> sys::cublasComputeType_t;
+
+    /// Matrix matrix multiplication. See
+    /// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul)
+    ///
+    /// # Safety
+    /// This is unsafe because improper arguments may lead to invalid
+    /// memory accesses.
+    unsafe fn matmul<I: DevicePtr<T>, O: DevicePtrMut<T>>(
+        &self,
+        cfg: MatmulConfig,
+        a: &I,
+        b: &I,
+        c: &mut O,
+        bias: Option<&I>,
+        act: Option<&Activation>,
+    ) -> Result<(), CublasError> {
+        let (a_rows, a_cols) = if cfg.transa {
+            (cfg.k, cfg.m)
+        } else {
+            (cfg.m, cfg.k)
+        };
+        let (b_rows, b_cols) = if cfg.transb {
+            (cfg.n, cfg.k)
+        } else {
+            (cfg.k, cfg.n)
+        };
+
+        // Creates matrix layouts
+        let a_layout = MatrixLayout::new(Self::matrix_type(), a_rows, a_cols, cfg.lda)?;
+        if let (Some(batch_size), Some(stride_a)) = (cfg.batch_size, cfg.stride_a) {
+            a_layout.set_batch(batch_size, stride_a)?;
+        }
+
+        let b_layout = MatrixLayout::new(Self::matrix_type(), b_rows, b_cols, cfg.ldb)?;
+        if let (Some(batch_size), Some(stride_b)) = (cfg.batch_size, cfg.stride_b) {
+            b_layout.set_batch(batch_size, stride_b)?;
+        }
+
+        let c_layout = MatrixLayout::new(Self::matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
+        if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
+            c_layout.set_batch(batch_size, stride_c)?;
+        }
+
+        // Matmul description
+        let matmul_desc = MatmulDesc::new(Self::compute_type(), sys::cudaDataType_t::CUDA_R_32F)?;
+
+        // Set transa
+        matmul_desc.set_transpose(cfg.transa, Matrix::A)?;
+        // Set transb
+        matmul_desc.set_transpose(cfg.transb, Matrix::B)?;
+
+        // Epilogue system can be leveraged to fuse add and activation operations
+        matmul_desc.set_epilogue(act, bias.map(|b| b.device_ptr()), cfg.stride_bias)?;
+
+        // Create matmul heuristic search preferences
+        let matmul_pref = MatmulPref::new()?;
+
+        // Set workspace size
+        matmul_pref.set_workspace_size(self.workspace().size)?;
+
+        // Get heuristic given Config, bias, act and workspace size
+        let heuristic = result::get_matmul_algo_heuristic(
+            *self.handle(),
+            matmul_desc.handle,
+            a_layout.handle,
+            b_layout.handle,
+            c_layout.handle,
+            c_layout.handle,
+            matmul_pref.handle,
+        )?;
+
+        // Launch matmul kernel
+        result::matmul(
+            *self.handle(),
+            matmul_desc.handle,
+            (&cfg.alpha) as *const _ as *const _,
+            (&cfg.beta) as *const _ as *const _,
+            *a.device_ptr() as *const _,
+            a_layout.handle,
+            *b.device_ptr() as *const _,
+            b_layout.handle,
+            *c.device_ptr_mut() as *const _,
+            c_layout.handle,
+            *c.device_ptr_mut() as *mut _,
+            c_layout.handle,
+            (&heuristic.algo) as *const _,
+            *self.workspace().buffer.device_ptr() as *const CUdeviceptr as *mut _,
+            self.workspace().size,
+            *self.stream() as *mut _,
+        )
+    }
+
+    /// Matrix matrix multiplication. See
+    /// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul)
+    ///
+    /// https://docs.nvidia.com/cuda/cublas/#cublasltmatmul
+    /// There are a few requirements:
+    /// - Compute type must be f32 (upheld)
+    /// - `transa && !transb` (upheld)
+    /// - Scale type must be f32 (upheld)
+    /// - A and B must be f8e4m3, but C must be bf16 (upheld)
+    ///
+    /// # Safety
+    /// This is unsafe because improper arguments may lead to invalid
+    /// memory accesses.
+    unsafe fn matmul_fp8_like<I: DevicePtr<T>, O: DevicePtrMut<bf16>, B: DevicePtr<bf16>>(
+        &self,
+        cfg: MatmulConfig,
+        a: &I,
+        b: &I,
+        c: &mut O,
+        bias: Option<&B>,
+        act: Option<&Activation>,
+    ) -> Result<(), CublasError> {
+        let (a_rows, a_cols) = if cfg.transa {
+            (cfg.k, cfg.m)
+        } else {
+            (cfg.m, cfg.k)
+        };
+        let (b_rows, b_cols) = if cfg.transb {
+            (cfg.n, cfg.k)
+        } else {
+            (cfg.k, cfg.n)
+        };
+
+        // Creates matrix layouts
+        let a_layout = MatrixLayout::new(Self::matrix_type(), a_rows, a_cols, cfg.lda)?;
+        if let (Some(batch_size), Some(stride_a)) = (cfg.batch_size, cfg.stride_a) {
+            a_layout.set_batch(batch_size, stride_a)?;
+        }
+
+        let b_layout = MatrixLayout::new(Self::matrix_type(), b_rows, b_cols, cfg.ldb)?;
+        if let (Some(batch_size), Some(stride_b)) = (cfg.batch_size, cfg.stride_b) {
+            b_layout.set_batch(batch_size, stride_b)?;
+        }
+
+        let c_layout = MatrixLayout::new(Self::matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
+        if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
+            c_layout.set_batch(batch_size, stride_c)?;
+        }
+
+        // Matmul description
+        let matmul_desc = MatmulDesc::new(Self::compute_type(), sys::cudaDataType_t::CUDA_R_32F)?;
+
+        // Set transa
+        matmul_desc.set_transpose(cfg.transa, Matrix::A)?;
+        // Set transb
+        matmul_desc.set_transpose(cfg.transb, Matrix::B)?;
+
+        // Epilogue system can be leveraged to fuse add and activation operations
+        matmul_desc.set_epilogue(act, bias.map(|b| b.device_ptr()), cfg.stride_bias)?;
+
+        // Create matmul heuristic search preferences
+        let matmul_pref = MatmulPref::new()?;
+
+        // Set workspace size
+        matmul_pref.set_workspace_size(self.workspace().size)?;
+
+        // Get heuristic given Config, bias, act and workspace size
+        let heuristic = result::get_matmul_algo_heuristic(
+            *self.handle(),
+            matmul_desc.handle,
+            a_layout.handle,
+            b_layout.handle,
+            c_layout.handle,
+            c_layout.handle,
+            matmul_pref.handle,
+        )?;
+
+        // Launch matmul kernel
+        result::matmul(
+            *self.handle(),
+            matmul_desc.handle,
+            (&cfg.alpha) as *const _ as *const _,
+            (&cfg.beta) as *const _ as *const _,
+            *a.device_ptr() as *const _,
+            a_layout.handle,
+            *b.device_ptr() as *const _,
+            b_layout.handle,
+            *c.device_ptr_mut() as *const _,
+            c_layout.handle,
+            *c.device_ptr_mut() as *mut _,
+            c_layout.handle,
+            (&heuristic.algo) as *const _,
+            *self.workspace().buffer.device_ptr() as *const CUdeviceptr as *mut _,
+            self.workspace().size,
+            *self.stream() as *mut _,
+        )
+    }
+}
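`matmul_fp8_like` documents hard requirements above (f32 compute and scale types, TN layout, F8E4M3 A/B with BF16 C/bias). A hypothetical call-site guard expressing the dtype half of that contract at the tensor level — `check_fp8_like_inputs` is not part of this patch:

```rust
use candle_core::{DType, Result, Tensor};

fn check_fp8_like_inputs(a: &Tensor, b: &Tensor, c: Option<&Tensor>) -> Result<()> {
    if a.dtype() != DType::F8E4M3 || b.dtype() != DType::F8E4M3 {
        candle_core::bail!("FP8 matmul expects F8E4M3 `a` and `b`");
    }
    if let Some(c) = c {
        if c.dtype() != DType::BF16 {
            candle_core::bail!("FP8 matmul expects a BF16 `c`");
        }
    }
    Ok(())
}
```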
+
+impl MatmulShared for CudaBlasLT {
+    fn handle(&self) -> &sys::cublasLtHandle_t {
+        &self.handle
+    }
+
+    fn workspace(&self) -> &Workspace {
+        &self.workspace
+    }
+
+    fn stream(&self) -> &CUstream {
+        &self.device.cu_stream()
+    }
+}
+
+impl Matmul<f32> for CudaBlasLT {
+    fn matrix_type() -> sys::cudaDataType {
+        sys::cudaDataType_t::CUDA_R_32F
+    }
+
+    fn compute_type() -> sys::cublasComputeType_t {
+        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
+    }
+}
+
+impl Matmul<half::f16> for CudaBlasLT {
+    fn matrix_type() -> sys::cudaDataType {
+        sys::cudaDataType_t::CUDA_R_16F
+    }
+
+    fn compute_type() -> sys::cublasComputeType_t {
+        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+    }
+}
+
+impl Matmul<bf16> for CudaBlasLT {
+    fn matrix_type() -> sys::cudaDataType {
+        sys::cudaDataType_t::CUDA_R_16BF
+    }
+
+    fn compute_type() -> sys::cublasComputeType_t {
+        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+    }
+}
+
+impl Matmul<F8E4M3> for CudaBlasLT {
+    fn matrix_type() -> sys::cudaDataType {
+        sys::cudaDataType_t::CUDA_R_8F_E4M3
+    }
+
+    fn compute_type() -> sys::cublasComputeType_t {
+        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+    }
+}
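Since `CudaBlasLT` implements `Matmul<T>` for all four element types (f32, f16, bf16, F8E4M3), callers can be written once and stay generic over `T`. A sketch, as if written inside this module — `gemm_tn` is hypothetical and the `DeviceRepr` bound is an assumption about cudarc's `DevicePtr` impls:

```rust
use candle_core::cuda::cudarc::cublaslt::result::CublasError;
use candle_core::cuda::cudarc::driver::{CudaSlice, DeviceRepr};

/// # Safety
/// Same contract as `Matmul::matmul`: `cfg` must accurately describe `a`, `b` and `c`.
unsafe fn gemm_tn<T: DeviceRepr>(
    lt: &CudaBlasLT,
    cfg: MatmulConfig,
    a: &CudaSlice<T>,
    b: &CudaSlice<T>,
    c: &mut CudaSlice<T>,
) -> Result<(), CublasError>
where
    CudaBlasLT: Matmul<T>,
{
    lt.matmul(cfg, a, b, c, None, None)
}
```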
diff --git a/mistralrs-core/src/cublaslt/mod.rs b/mistralrs-core/src/cublaslt/mod.rs
index 7657186ca..9d6046b38 100644
--- a/mistralrs-core/src/cublaslt/mod.rs
+++ b/mistralrs-core/src/cublaslt/mod.rs
@@ -9,9 +9,11 @@ use std::sync::{Mutex, Once};
 
 #[cfg(feature = "cuda")]
 mod api;
+#[cfg(feature = "cuda")]
+mod matmul;
 
 #[cfg(feature = "cuda")]
-use api::{fused_batch_matmul, fused_matmul, Activation, CublasLt};
+use api::{fused_batch_matmul, fused_matmul, CublasLt};
 
 static INIT: Once = Once::new();
 static mut CUBLASLT: Option<CublasLtWrapper> = None;
@@ -70,8 +72,8 @@ impl CublasLtWrapper {
         #[cfg(feature = "cuda")]
         {
             let inner_act = act.map(|a| match a {
-                CandleActivation::Relu => Activation::Relu,
-                CandleActivation::Gelu => Activation::Gelu,
+                CandleActivation::Relu => matmul::Activation::Relu,
+                CandleActivation::Gelu => matmul::Activation::Gelu,
                 _ => unreachable!("Unsupported activation in cublaslt matmul"),
             });
             let mut result = fused_batch_matmul(
diff --git a/mistralrs-core/src/diffusion_models/t5/mod.rs b/mistralrs-core/src/diffusion_models/t5/mod.rs
index ac66d0faa..e9980ae6b 100644
--- a/mistralrs-core/src/diffusion_models/t5/mod.rs
+++ b/mistralrs-core/src/diffusion_models/t5/mod.rs
@@ -5,6 +5,7 @@
 use candle_core::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, VarBuilder};
+use float8::F8E4M3;
 use serde::Deserialize;
 use std::sync::Arc;
 
@@ -626,6 +627,7 @@ impl TensorInfExtend for Tensor {
             DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? == half::bf16::from_f32_const(0.)),
             DType::F32 => Ok(sum.to_scalar::<f32>()? == 0.),
             DType::F64 => Ok(sum.to_scalar::<f64>()? == 0.),
+            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? == F8E4M3::ZERO),
         }
     }
 }
@@ -641,6 +643,7 @@ fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
         DType::BF16 => half::bf16::MAX.to_f64_const() - 1000.,
         DType::F32 => f32::MAX as f64 - 1000.,
         DType::F64 => f64::MAX - 1000.,
+        DType::F8E4M3 => F8E4M3::MAX.to_f64() - 1000.,
     };
     if xs.is_inf()?.any()? {
         max -= 1000.;
diff --git a/mistralrs-core/src/ops.rs b/mistralrs-core/src/ops.rs
index 0d7b5321d..020bfe04f 100644
--- a/mistralrs-core/src/ops.rs
+++ b/mistralrs-core/src/ops.rs
@@ -123,6 +123,7 @@ impl CustomOp2 for BitWise {
             CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")),
             CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")),
             CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise")),
+            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise")),
         }
     }
     #[cfg(feature = "cuda")]
@@ -191,6 +192,9 @@ impl CustomOp2 for BitWise {
             DType::F64 => {
                 return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise"));
             }
+            DType::F8E4M3 => {
+                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise"));
+            }
         };
         let dst = match s1.dtype() {
             DType::U8 => {
@@ -397,6 +401,7 @@ fn count_nonzero_cuda(dtype: candle_core::DType, d_in: *const c_void, n: u32) -> u32 {
             candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n),
             candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n),
             candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n),
+            candle_core::DType::F8E4M3 => todo!(),
         }
     }
 }
@@ -438,6 +443,7 @@ fn nonzero_cuda(
             candle_core::DType::F64 => {
                 ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out)
             }
+            candle_core::DType::F8E4M3 => todo!(),
         }
     }
 }
@@ -461,6 +467,7 @@ impl CustomOp1 for NonZero {
            candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
+           candle_core::CpuStorage::F8E4M3(_vs) => todo!(),
        };
        let index_len = layout.dims().len();
        let result_len = result.len() / index_len;
@@ -488,6 +495,7 @@ impl CustomOp1 for NonZero {
            candle_core::DType::F16 => *storage.as_cuda_slice::<half::f16>()?.device_ptr(),
            candle_core::DType::F32 => *storage.as_cuda_slice::<f32>()?.device_ptr(),
            candle_core::DType::F64 => *storage.as_cuda_slice::<f64>()?.device_ptr(),
+           candle_core::DType::F8E4M3 => todo!(),
        } as *const c_void;
        let n = layout.shape().elem_count();
        let num_nonzero = count_nonzero_cuda(storage.dtype(), d_in, u32::try_from(n)?);
diff --git a/mistralrs-paged-attn/Cargo.toml b/mistralrs-paged-attn/Cargo.toml
index 9c65dfd58..5e16e57a0 100644
--- a/mistralrs-paged-attn/Cargo.toml
+++ b/mistralrs-paged-attn/Cargo.toml
@@ -14,6 +14,7 @@ homepage.workspace = true
 [dependencies]
 candle-core.workspace = true
 half.workspace = true
+float8.workspace = true
 
 [build-dependencies]
 bindgen_cuda = {git = "https://github.com/guoqingbao/bindgen_cuda.git", version = "0.1.6"}
diff --git a/mistralrs-paged-attn/src/backend/mod.rs b/mistralrs-paged-attn/src/backend/mod.rs
index ad40a237c..579caf44d 100644
--- a/mistralrs-paged-attn/src/backend/mod.rs
+++ b/mistralrs-paged-attn/src/backend/mod.rs
@@ -33,6 +33,7 @@ pub fn get_or_load_func(
         DType::F16 => "_f16",
         DType::F32 => "_f32",
         DType::F64 => "_f64",
+        DType::F8E4M3 => "_f8_e4m3",
     };
     let spec = if let Some(suffix) = suffix {
         spec.to_owned() + suffix
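Note that the F8E4M3 arms in the nonzero kernels above are left as `todo!()`, so FP8 tensors must be up-cast before index ops. A hypothetical workaround, assuming ops.rs's `NonZeroOp` extension trait is in scope:

```rust
use candle_core::{DType, Result, Tensor};

// Up-cast to f32 first; f32 is fully supported by the nonzero kernels.
fn nonzero_fp8(t: &Tensor) -> Result<Tensor> {
    t.to_dtype(DType::F32)?.nonzero()
}
```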
features=["$feature_name"] } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "4754afe", features=["$feature_name"] } indexmap.workspace = true accelerate-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true } diff --git a/mistralrs-quant/Cargo.toml b/mistralrs-quant/Cargo.toml index cc6ba5ef0..eafc13ee7 100644 --- a/mistralrs-quant/Cargo.toml +++ b/mistralrs-quant/Cargo.toml @@ -21,6 +21,7 @@ paste = "1.0.15" tracing.workspace = true rayon.workspace = true byteorder = "1.5.0" +float8.workspace = true [features] cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda"] diff --git a/mistralrs-quant/src/utils/ops.rs b/mistralrs-quant/src/utils/ops.rs index de59be5d0..3b38d4e69 100644 --- a/mistralrs-quant/src/utils/ops.rs +++ b/mistralrs-quant/src/utils/ops.rs @@ -70,6 +70,7 @@ impl CustomOp2 for BitWiseOr { CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise-or")), CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise-or")), CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise-or")), + CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise-or")), } } #[cfg(feature = "cuda")] @@ -141,6 +142,9 @@ impl CustomOp2 for BitWiseOr { DType::F64 => { return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise-or")); } + DType::F8E4M3 => { + return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise-or")); + } }; let dst = match s1.dtype() { DType::U8 => { @@ -226,6 +230,7 @@ impl CustomOp1 for Leftshift { CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshifr")), CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshifr")), CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshifr")), + CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshifr")), } } #[cfg(feature = "cuda")] @@ -269,6 +274,9 @@ impl CustomOp1 for Leftshift { DType::F64 => { return Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift")); } + DType::F8E4M3 => { + return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshift")); + } }; let dst = match s1.dtype() { DType::U8 => { diff --git a/mistralrs-quant/src/utils/uqff.rs b/mistralrs-quant/src/utils/uqff.rs index 1494357aa..8c13c7b83 100644 --- a/mistralrs-quant/src/utils/uqff.rs +++ b/mistralrs-quant/src/utils/uqff.rs @@ -1,11 +1,16 @@ use byteorder::{LittleEndian, ReadBytesExt}; use candle_core::{DType, Device, Result, Tensor, WithDType}; +use float8::F8E4M3; use half::{bf16, f16}; +// v0.1.0: initial release +// v0.1.1: add i16 dtype +// v0.1.2: add F8E4M3 + const HQFF_VERSION_MAJOR: u32 = 0; const HQFF_VERSION_MINOR: u32 = 1; -const HQFF_VERSION_PATCH: u32 = 1; +const HQFF_VERSION_PATCH: u32 = 2; /// Format 4 bytes, little endian: [ UNSPECIFIED ] [ MAJOR ] [ MINOR ] [ PATCH ] pub(crate) const HQFF_VERSION: u32 = @@ -54,6 +59,7 @@ pub(crate) fn serialize_tensor(buffer: &mut Vec, tensor: &Tensor) -> Result< DType::BF16 => data_to_bytes::(tensor.to_vec1()?), DType::F32 => data_to_bytes::(tensor.to_vec1()?), DType::F64 => data_to_bytes::(tensor.to_vec1()?), + DType::F8E4M3 => data_to_bytes::(tensor.to_vec1()?), }; buffer.extend(&(bias.len() as u32).to_le_bytes()); @@ -67,6 +73,7 @@ pub(crate) fn serialize_tensor(buffer: &mut Vec, tensor: &Tensor) -> Result< DType::F32 => 6, DType::F64 => 7, DType::I16 => 8, + DType::F8E4M3 => 9, }; buffer.extend(&dtype.to_le_bytes()); @@ -121,6 +128,7 @@ 
@@ -121,6 +128,7 @@ pub(crate) fn deserialize_tensor(
         DType::I16 => bytes_to_data::<i16>(&tensor_data, &dims, device),
         DType::U32 => bytes_to_data::<u32>(&tensor_data, &dims, device),
         DType::U8 => bytes_to_data::<u8>(&tensor_data, &dims, device),
+        DType::F8E4M3 => bytes_to_data::<F8E4M3>(&tensor_data, &dims, device),
     }
 }
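Deserialization has to invert the dtype tags written by `serialize_tensor`. A sketch of the inverse mapping, restricted to the tags actually visible in this patch (the full table lives in uqff.rs):

```rust
use candle_core::DType;

fn dtype_from_tag(tag: u32) -> Option<DType> {
    Some(match tag {
        6 => DType::F32,
        7 => DType::F64,
        8 => DType::I16,    // added in HQFF v0.1.1
        9 => DType::F8E4M3, // added in HQFF v0.1.2 by this patch
        _ => return None,   // remaining tags (0..=5) are omitted here
    })
}
```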