Add ISQ FP8 #832

Merged — 12 commits, merged on Oct 12, 2024
Cargo.lock — 28 changes: 23 additions & 5 deletions

Some generated files are not rendered by default.

Cargo.toml — 5 changes: 3 additions & 2 deletions
@@ -25,8 +25,8 @@ license = "MIT"

[workspace.dependencies]
anyhow = "1.0.80"
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" }
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "60eb251" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "60eb251" }
serde = "1.0.197"
serde_json = "1.0.114"
indexmap = { version = "2.2.5", features = ["serde"] }
@@ -49,3 +49,4 @@ rayon = "1.1.0"
url = "2.5.2"
data-url = "0.3.1"
buildstructor = "0.5.4"
float8 = "0.1.1"
docs/ISQ.md — 1 change: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ To set the ISQ type for individual layers, use a model [`topology`](TOPOLOGY.md)
- Q8K (*not available on CUDA*)
- HQQ4
- HQQ8
- FP8

When using ISQ, it will automatically load ISQ-able weights into CPU memory before applying ISQ. The ISQ application process moves the weights to device memory. This process is implemented to avoid memory spikes from loading the model in full precision.
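As an editorial aside (not part of the diff): at its core, per-tensor FP8 ISQ picks a dequantization scale so the weights fit E4M3's finite range (±448), then stores the scaled values as FP8 bytes. A minimal sketch, assuming the `float8` crate added in this PR exposes `F8E4M3::from_f32`/`to_f32` conversions (an assumed API, mirroring the `half` crate):

```rust
use float8::F8E4M3; // crate added to the workspace in this PR

const F8_E4M3_MAX: f32 = 448.0; // largest finite E4M3 magnitude

/// Per-tensor FP8 quantization sketch: returns FP8 values plus the
/// dequantization scale. Reconstruct with `q.to_f32() * dequant_scale`.
fn quantize_fp8(weights: &[f32]) -> (Vec<F8E4M3>, f32) {
    let max_abs = weights.iter().fold(0.0f32, |m, w| m.max(w.abs()));
    let dequant_scale = if max_abs > 0.0 { max_abs / F8_E4M3_MAX } else { 1.0 };
    let quantized = weights
        .iter()
        .map(|w| F8E4M3::from_f32(w / dequant_scale)) // assumed conversion API
        .collect();
    (quantized, dequant_scale)
}
```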

docs/UQFF.md — 3 changes: 3 additions & 0 deletions
@@ -51,6 +51,9 @@ The following quantization formats are supported in UQFF. One can, of course, be
- HQQ4
- HQQ8

- FP8:
- FP8 E4M3 (4-bit exponent, 3-bit mantissa)

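For reference (an illustration added here, not part of the PR's diff): an E4M3 byte packs 1 sign bit, 4 exponent bits (bias 7), and 3 mantissa bits, with no infinities and a single NaN encoding, giving a finite range of ±448. A self-contained decoder sketch:

```rust
/// Decode one FP8 E4M3 byte to f32 (illustrative; the PR itself relies on
/// the `float8` crate rather than hand-rolled conversion).
fn f8_e4m3_to_f32(bits: u8) -> f32 {
    let sign = if bits & 0x80 != 0 { -1.0f32 } else { 1.0f32 };
    let exp = ((bits >> 3) & 0x0F) as i32;
    let man = (bits & 0x07) as f32;
    if exp == 0x0F && man == 7.0 {
        return f32::NAN; // E4M3 has no infinities; 0x7F/0xFF encode NaN
    }
    if exp == 0 {
        // Subnormal: (-1)^s * 2^-6 * (mantissa / 8)
        sign * (man / 8.0) * 2f32.powi(-6)
    } else {
        // Normal: (-1)^s * 2^(exp - 7) * (1 + mantissa / 8)
        sign * (1.0 + man / 8.0) * 2f32.powi(exp - 7)
    }
}
```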
## Loading a UQFF model

To load a UQFF model, one should specify the artifact path. This can either be a path to a local UQFF file or a Hugging Face model ID with the format `<MODEL ID>/<FILE>`. For example, the following work:
docs/UQFF/LAYOUT.md — 14 changes: 13 additions & 1 deletion
@@ -32,7 +32,6 @@ The following describes the exact memory layout of HQFF tensors of version 0.1.0
| **Array** Weight tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |
| **[Optional]** **Array** Bias tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |


## HQQ quantization
| ID | Element type | Endianness |
| -------- | -------- | -------- |
@@ -51,6 +50,19 @@ The following describes the exact memory layout of HQFF tensors of version 0.1.0
| CFG round zeroes (boolean) | u8 | little endian |
| CFG channel wise (boolean) | u8 | little endian |

## FP8 layers
| ID | Element type | Endianness |
| -------- | -------- | -------- |
| HQFF version | u32 | little endian |
| ISQ type (3) | u8 | little endian |
| Whether bias data is included (boolean) | u8 | little endian |
| **Array** Weight tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |
| Dequant scale W | f32 | little endian |
| Dequant scale X | f32 | little endian |
| Quant scale | f32 | little endian |
| Layer dtype | u32 | little endian |
| **[Optional]** **Array** Bias tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |

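To make the field order concrete, here is a hedged sketch of emitting the fixed-size fields of this layout in little-endian order (tensor payloads elided); it is illustrative only, not the serializer mistral.rs actually uses:

```rust
use std::io::{self, Write};

/// Write the FP8 layer header fields in the order given by the table above.
/// The weight/bias tensor payloads (see "Standard tensors") are omitted.
fn write_fp8_layer_header<W: Write>(
    w: &mut W,
    hqff_version: u32,
    has_bias: bool,
    dequant_scale_w: f32,
    dequant_scale_x: f32,
    quant_scale: f32,
    layer_dtype: u32,
) -> io::Result<()> {
    w.write_all(&hqff_version.to_le_bytes())?;    // HQFF version
    w.write_all(&[3u8])?;                         // ISQ type tag for FP8 (3)
    w.write_all(&[has_bias as u8])?;              // bias-present flag
    // ... weight tensor data would be written here (see "Standard tensors") ...
    w.write_all(&dequant_scale_w.to_le_bytes())?; // Dequant scale W
    w.write_all(&dequant_scale_x.to_le_bytes())?; // Dequant scale X
    w.write_all(&quant_scale.to_le_bytes())?;     // Quant scale
    w.write_all(&layer_dtype.to_le_bytes())?;     // Layer dtype
    // ... optional bias tensor data would follow ...
    Ok(())
}
```
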
## Standard tensors
| ID | Element type | Endianness |
| -------- | -------- | -------- |
mistralrs-core/Cargo.toml — 5 changes: 3 additions & 2 deletions
@@ -17,7 +17,7 @@ candle-core.workspace = true
candle-nn.workspace = true
serde.workspace = true
serde_json.workspace = true
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4", optional = true }
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "60eb251", optional = true }
dirs = "5.0.1"
hf-hub = "0.3.2"
thiserror = "1.0.57"
@@ -78,10 +78,11 @@ regex = "1.10.6"
safetensors = "0.4.5"
serde_plain = "1.0.2"
as-any = "0.3.1"
float8.workspace = true

[features]
pyo3_macros = ["pyo3"]
cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda"]
cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda", "float8/cuda"]
cudnn = ["candle-core/cudnn"]
metal = ["candle-core/metal", "candle-nn/metal"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mistralrs-core/src/cublaslt/api.rs — 12 changes: 7 additions & 5 deletions
@@ -1,13 +1,14 @@
pub use candle_core::cuda_backend::cudarc::cublaslt::Activation;
use candle_core::cuda::cudarc::driver::DevicePtr;
use float8::F8E4M3;
use std::ffi::c_int;

use candle_core::backend::BackendStorage;
use candle_core::cuda_backend::WrapErr;
use candle_core::{CpuStorage, Device, Layout, Result, Shape, Storage, Tensor};
use candle_core::{CpuStorage, DType, Device, Layout, Result, Shape, Storage, Tensor};
use half::{bf16, f16};
use std::sync::Arc;

use candle_core::cuda_backend::cudarc::cublaslt::{CudaBlasLT, Matmul, MatmulConfig};
use super::matmul::{Activation, CudaBlasLT, Matmul, MatmulConfig};

#[derive(Debug, Clone)]
pub struct CublasLt(Arc<CudaBlasLT>);
@@ -858,11 +859,12 @@ pub fn fused_batch_matmul(
a.apply_op2(b, op)
}
}

#[cfg(test)]
mod tests {
use std::f32::consts::PI;

use super::*;
use candle_core::{DType, Device};
use candle_core::{DType, Device, IndexOp};

fn to_vec2_round(t: Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
let b = 10f32.powi(digits);