diff --git a/Cargo.toml b/Cargo.toml index 226146294..8ecadfccf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ members = [ resolver = "2" [workspace.package] -version = "0.1.12" +version = "0.1.13" edition = "2021" description = "Fast and easy LLM serving." homepage = "https://github.com/EricLBuehler/mistral.rs" diff --git a/mistralrs-bench/Cargo.toml b/mistralrs-bench/Cargo.toml index 87407ba19..60b41bddf 100644 --- a/mistralrs-bench/Cargo.toml +++ b/mistralrs-bench/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true serde.workspace = true serde_json.workspace = true clap.workspace = true -mistralrs-core = { version = "0.1.12", path = "../mistralrs-core" } +mistralrs-core = { version = "0.1.13", path = "../mistralrs-core" } tracing.workspace = true either.workspace = true tokio.workspace = true diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index c3c2f7dee..8132071ad 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -151,34 +151,26 @@ fn set_gemm_reduced_precision_f16() { candle_core::cuda::set_gemm_reduced_precision_bf16(true); match a.matmul(&a) { Ok(_) => (), - Err(e) => match e { - candle_core::Error::Cuda(e) => { - let x = e.downcast::(); - if format!("{x:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") { - tracing::info!("GEMM reduced precision in BF16 not supported."); - candle_core::cuda::set_gemm_reduced_precision_bf16(false); - INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed); - } + Err(e) => { + if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") { + tracing::info!("GEMM reduced precision in BF16 not supported."); + candle_core::cuda::set_gemm_reduced_precision_bf16(false); + INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed); } - _ => (), - }, + } } let a = Tensor::zeros((2, 2), DType::F16, &Device::new_cuda(0).unwrap()).unwrap(); candle_core::cuda::set_gemm_reduced_precision_f16(true); match a.matmul(&a) { Ok(_) => (), - Err(e) => match e { - candle_core::Error::Cuda(e) => { - let x = e.downcast::(); - if format!("{x:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") { - tracing::info!("GEMM reduced precision in F16 not supported."); - candle_core::cuda::set_gemm_reduced_precision_f16(false); - INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed); - } + Err(e) => { + if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") { + tracing::info!("GEMM reduced precision in F16 not supported."); + candle_core::cuda::set_gemm_reduced_precision_f16(false); + INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed); } - _ => (), - }, + } } } diff --git a/mistralrs-core/src/pipeline/paths.rs b/mistralrs-core/src/pipeline/paths.rs index cc7cf64ea..76af55680 100644 --- a/mistralrs-core/src/pipeline/paths.rs +++ b/mistralrs-core/src/pipeline/paths.rs @@ -307,6 +307,11 @@ pub(crate) fn get_chat_template( { panic!("Template filename {template_filename:?} must end with `.json`."); } + let template: ChatTemplate = serde_json::from_str( + &fs::read_to_string(&template_filename).expect("Deserialization of chat template failed."), + ) + .unwrap(); + #[derive(Debug, serde::Deserialize)] struct SpecifiedTemplate { chat_template: String, @@ -318,40 +323,46 @@ pub(crate) fn get_chat_template( let mut deser: HashMap = serde_json::from_str(&fs::read_to_string(&template_filename).unwrap()).unwrap(); - match chat_template.clone() { - Some(t) => { - if t.ends_with(".json") { - info!("Loading specified loading chat template file at `{t}`."); - let templ: SpecifiedTemplate = - serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap(); - deser.insert( - "chat_template".to_string(), - Value::String(templ.chat_template), - ); - if templ.bos_token.is_some() { - deser.insert( - "bos_token".to_string(), - Value::String(templ.bos_token.unwrap()), - ); + match &template.chat_template { + Some(_) => template, + None => { + match chat_template.clone() { + Some(t) => { + if t.ends_with(".json") { + info!("Loading specified loading chat template file at `{t}`."); + let templ: SpecifiedTemplate = + serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap(); + deser.insert( + "chat_template".to_string(), + Value::String(templ.chat_template), + ); + if templ.bos_token.is_some() { + deser.insert( + "bos_token".to_string(), + Value::String(templ.bos_token.unwrap()), + ); + } + if templ.eos_token.is_some() { + deser.insert( + "eos_token".to_string(), + Value::String(templ.eos_token.unwrap()), + ); + } + info!("Loaded chat template file."); + } else { + deser.insert("chat_template".to_string(), Value::String(t)); + info!("Loaded specified literal chat template."); + } } - if templ.eos_token.is_some() { - deser.insert( - "eos_token".to_string(), - Value::String(templ.eos_token.unwrap()), - ); + None => { + info!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages."); + deser.insert("chat_template".to_string(), Value::Null); } - info!("Loaded chat template file."); - } else { - deser.insert("chat_template".to_string(), Value::String(t)); - info!("Loaded specified literal chat template."); } + + let ser = serde_json::to_string_pretty(&deser) + .expect("Serialization of modified chat template failed."); + serde_json::from_str(&ser).unwrap() } - None => { - info!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages."); - deser.insert("chat_template".to_string(), Value::Null); - } - }; - let ser = serde_json::to_string_pretty(&deser) - .expect("Serialization of modified chat template failed."); - serde_json::from_str(&ser).unwrap() + } } diff --git a/mistralrs-pyo3/Cargo.toml b/mistralrs-pyo3/Cargo.toml index 63315d921..574d7b1c7 100644 --- a/mistralrs-pyo3/Cargo.toml +++ b/mistralrs-pyo3/Cargo.toml @@ -17,7 +17,7 @@ doc = false [dependencies] pyo3.workspace = true -mistralrs-core = { version = "0.1.12", path = "../mistralrs-core", features = ["pyo3_macros"] } +mistralrs-core = { version = "0.1.13", path = "../mistralrs-core", features = ["pyo3_macros"] } serde.workspace = true serde_json.workspace = true candle-core.workspace = true diff --git a/mistralrs-pyo3/Cargo_template.toml b/mistralrs-pyo3/Cargo_template.toml index bede453ea..d7be7710c 100644 --- a/mistralrs-pyo3/Cargo_template.toml +++ b/mistralrs-pyo3/Cargo_template.toml @@ -17,7 +17,7 @@ doc = false [dependencies] pyo3.workspace = true -mistralrs-core = { version = "0.1.12", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] } +mistralrs-core = { version = "0.1.13", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] } serde.workspace = true serde_json.workspace = true candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", features=["$feature_name"] } diff --git a/mistralrs-pyo3/pyproject.toml b/mistralrs-pyo3/pyproject.toml index 368f3ca2b..e684a4091 100644 --- a/mistralrs-pyo3/pyproject.toml +++ b/mistralrs-pyo3/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "mistralrs" -version = "0.1.12" +version = "0.1.13" requires-python = ">=3.8" classifiers = [ "Programming Language :: Rust", diff --git a/mistralrs-pyo3/pyproject_template.toml b/mistralrs-pyo3/pyproject_template.toml index 251a5cba3..3fa3ad241 100644 --- a/mistralrs-pyo3/pyproject_template.toml +++ b/mistralrs-pyo3/pyproject_template.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "$name" -version = "0.1.12" +version = "0.1.13" requires-python = ">=3.8" classifiers = [ "Programming Language :: Rust", diff --git a/mistralrs-server/Cargo.toml b/mistralrs-server/Cargo.toml index c76822e94..4355a15a2 100644 --- a/mistralrs-server/Cargo.toml +++ b/mistralrs-server/Cargo.toml @@ -22,7 +22,7 @@ axum = { version = "0.7.4", features = ["tokio"] } tower-http = { version = "0.5.1", features = ["cors"]} utoipa = { version = "4.2", features = ["axum_extras"] } utoipa-swagger-ui = { version = "7.1.0", features = ["axum"]} -mistralrs-core = { version = "0.1.12", path = "../mistralrs-core" } +mistralrs-core = { version = "0.1.13", path = "../mistralrs-core" } dyn-fmt = "0.4.0" indexmap.workspace = true accelerate-src = { workspace = true, optional = true } diff --git a/mistralrs/Cargo.toml b/mistralrs/Cargo.toml index 80170da5e..e0134537f 100644 --- a/mistralrs/Cargo.toml +++ b/mistralrs/Cargo.toml @@ -12,7 +12,7 @@ license.workspace = true homepage.workspace = true [dependencies] -mistralrs-core = { version = "0.1.12", path = "../mistralrs-core" } +mistralrs-core = { version = "0.1.13", path = "../mistralrs-core" } anyhow.workspace = true tokio.workspace = true candle-core.workspace = true