Skip to content

Commit

Permalink
Fix chat template and F16/BF16 CUDA GEMM detection when RUST_BACKTRACE is set (#370)
Browse files Browse the repository at this point in the history
* Fix chat template deserialization

* Fix GEMM capability detection when in debug

* Bump version 0.1.12 -> 0.1.13
  • Loading branch information
EricLBuehler authored Jun 2, 2024
1 parent 24b33b1 commit d904085
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 60 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ members = [
resolver = "2"

[workspace.package]
version = "0.1.12"
version = "0.1.13"
edition = "2021"
description = "Fast and easy LLM serving."
homepage = "https://github.com/EricLBuehler/mistral.rs"
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-bench/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ candle-core.workspace = true
serde.workspace = true
serde_json.workspace = true
clap.workspace = true
mistralrs-core = { version = "0.1.12", path = "../mistralrs-core" }
mistralrs-core = { version = "0.1.13", path = "../mistralrs-core" }
tracing.workspace = true
either.workspace = true
tokio.workspace = true
Expand Down
32 changes: 12 additions & 20 deletions mistralrs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,34 +151,26 @@ fn set_gemm_reduced_precision_f16() {
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
match a.matmul(&a) {
Ok(_) => (),
Err(e) => match e {
candle_core::Error::Cuda(e) => {
let x = e.downcast::<candle_core::cuda::cudarc::cublas::result::CublasError>();
if format!("{x:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") {
tracing::info!("GEMM reduced precision in BF16 not supported.");
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed);
}
Err(e) => {
if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") {
tracing::info!("GEMM reduced precision in BF16 not supported.");
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed);
}
_ => (),
},
}
}

let a = Tensor::zeros((2, 2), DType::F16, &Device::new_cuda(0).unwrap()).unwrap();
candle_core::cuda::set_gemm_reduced_precision_f16(true);
match a.matmul(&a) {
Ok(_) => (),
Err(e) => match e {
candle_core::Error::Cuda(e) => {
let x = e.downcast::<candle_core::cuda::cudarc::cublas::result::CublasError>();
if format!("{x:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") {
tracing::info!("GEMM reduced precision in F16 not supported.");
candle_core::cuda::set_gemm_reduced_precision_f16(false);
INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed);
}
Err(e) => {
if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") {
tracing::info!("GEMM reduced precision in F16 not supported.");
candle_core::cuda::set_gemm_reduced_precision_f16(false);
INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed);
}
_ => (),
},
}
}
}

Expand Down
75 changes: 43 additions & 32 deletions mistralrs-core/src/pipeline/paths.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,11 @@ pub(crate) fn get_chat_template(
{
panic!("Template filename {template_filename:?} must end with `.json`.");
}
let template: ChatTemplate = serde_json::from_str(
&fs::read_to_string(&template_filename).expect("Deserialization of chat template failed."),
)
.unwrap();

#[derive(Debug, serde::Deserialize)]
struct SpecifiedTemplate {
chat_template: String,
Expand All @@ -318,40 +323,46 @@ pub(crate) fn get_chat_template(
let mut deser: HashMap<String, Value> =
serde_json::from_str(&fs::read_to_string(&template_filename).unwrap()).unwrap();

match chat_template.clone() {
Some(t) => {
if t.ends_with(".json") {
info!("Loading specified loading chat template file at `{t}`.");
let templ: SpecifiedTemplate =
serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
deser.insert(
"chat_template".to_string(),
Value::String(templ.chat_template),
);
if templ.bos_token.is_some() {
deser.insert(
"bos_token".to_string(),
Value::String(templ.bos_token.unwrap()),
);
match &template.chat_template {
Some(_) => template,
None => {
match chat_template.clone() {
Some(t) => {
if t.ends_with(".json") {
info!("Loading specified loading chat template file at `{t}`.");
let templ: SpecifiedTemplate =
serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
deser.insert(
"chat_template".to_string(),
Value::String(templ.chat_template),
);
if templ.bos_token.is_some() {
deser.insert(
"bos_token".to_string(),
Value::String(templ.bos_token.unwrap()),
);
}
if templ.eos_token.is_some() {
deser.insert(
"eos_token".to_string(),
Value::String(templ.eos_token.unwrap()),
);
}
info!("Loaded chat template file.");
} else {
deser.insert("chat_template".to_string(), Value::String(t));
info!("Loaded specified literal chat template.");
}
}
if templ.eos_token.is_some() {
deser.insert(
"eos_token".to_string(),
Value::String(templ.eos_token.unwrap()),
);
None => {
info!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
deser.insert("chat_template".to_string(), Value::Null);
}
info!("Loaded chat template file.");
} else {
deser.insert("chat_template".to_string(), Value::String(t));
info!("Loaded specified literal chat template.");
}

let ser = serde_json::to_string_pretty(&deser)
.expect("Serialization of modified chat template failed.");
serde_json::from_str(&ser).unwrap()
}
None => {
info!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
deser.insert("chat_template".to_string(), Value::Null);
}
};
let ser = serde_json::to_string_pretty(&deser)
.expect("Serialization of modified chat template failed.");
serde_json::from_str(&ser).unwrap()
}
}
2 changes: 1 addition & 1 deletion mistralrs-pyo3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ doc = false

[dependencies]
pyo3.workspace = true
mistralrs-core = { version = "0.1.12", path = "../mistralrs-core", features = ["pyo3_macros"] }
mistralrs-core = { version = "0.1.13", path = "../mistralrs-core", features = ["pyo3_macros"] }
serde.workspace = true
serde_json.workspace = true
candle-core.workspace = true
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-pyo3/Cargo_template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ doc = false

[dependencies]
pyo3.workspace = true
mistralrs-core = { version = "0.1.12", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] }
mistralrs-core = { version = "0.1.13", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] }
serde.workspace = true
serde_json.workspace = true
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", features=["$feature_name"] }
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-pyo3/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "maturin"

[project]
name = "mistralrs"
version = "0.1.12"
version = "0.1.13"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Rust",
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-pyo3/pyproject_template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "maturin"

[project]
name = "$name"
version = "0.1.12"
version = "0.1.13"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Rust",
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ axum = { version = "0.7.4", features = ["tokio"] }
tower-http = { version = "0.5.1", features = ["cors"]}
utoipa = { version = "4.2", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "7.1.0", features = ["axum"]}
mistralrs-core = { version = "0.1.12", path = "../mistralrs-core" }
mistralrs-core = { version = "0.1.13", path = "../mistralrs-core" }
dyn-fmt = "0.4.0"
indexmap.workspace = true
accelerate-src = { workspace = true, optional = true }
Expand Down
2 changes: 1 addition & 1 deletion mistralrs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ license.workspace = true
homepage.workspace = true

[dependencies]
mistralrs-core = { version = "0.1.12", path = "../mistralrs-core" }
mistralrs-core = { version = "0.1.13", path = "../mistralrs-core" }
anyhow.workspace = true
tokio.workspace = true
candle-core.workspace = true
Expand Down

0 comments on commit d904085

Please sign in to comment.