From d50ea887216d7daab048b6873fa24f41a9f0f70c Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sat, 12 Oct 2024 07:46:05 -0400 Subject: [PATCH] Expose config for rust api, tweak modekind --- mistralrs-core/src/lib.rs | 16 ++++++++++++++++ mistralrs-core/src/pipeline/ggml.rs | 6 +++--- mistralrs-core/src/pipeline/gguf.rs | 8 ++++---- mistralrs-core/src/pipeline/loaders/mod.rs | 18 +++++++++--------- mistralrs-core/src/utils/memory_usage.rs | 2 ++ 5 files changed, 34 insertions(+), 16 deletions(-) diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index 59281e0e6..86ae99a30 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -1,5 +1,6 @@ #![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)] +use candle_core::Device; use cublaslt::setup_cublas_lt_wrapper; use engine::Engine; pub use engine::{EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP}; @@ -105,6 +106,11 @@ pub use utils::paged_attn_supported; pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false); static ENGINE_ID: AtomicUsize = AtomicUsize::new(0); +pub struct MistralRsConfig { + pub kind: ModelKind, + pub device: Device, +} + /// The MistralRs struct handles sending requests to the engine. /// It is the core multi-threaded component of mistral.rs, and uses `mspc` /// `Sender` and `Receiver` primitives to send and receive requests to the @@ -119,6 +125,7 @@ pub struct MistralRs { engine_handler: RwLock>, engine_id: usize, category: ModelCategory, + config: MistralRsConfig, } #[derive(Clone)] @@ -322,6 +329,10 @@ impl MistralRs { let sender = RwLock::new(tx); let id = pipeline.try_lock().unwrap().name(); + let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone(); + let device = pipeline.try_lock().unwrap().device(); + let config = MistralRsConfig { kind, device }; + let engine_handler = thread::spawn(move || { let rt = Runtime::new().unwrap(); rt.block_on(async move { @@ -355,6 +366,7 @@ impl MistralRs { reboot_state, engine_handler: RwLock::new(engine_handler), category, + config, }) } @@ -481,4 +493,8 @@ impl MistralRs { .expect("Unable to write data"); } } + + pub fn config(&self) -> &MistralRsConfig { + &self.config + } } diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index 0b64fc206..33385878d 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -109,7 +109,7 @@ impl GGMLLoaderBuilder { quantized_model_id: String, quantized_filename: String, ) -> Self { - let kind = ModelKind::Quantized { + let kind = ModelKind::GgufQuantized { quant: QuantizationKind::Ggml, }; @@ -339,8 +339,8 @@ impl Loader for GGMLLoader { // Config into model: // NOTE: No architecture to infer like GGUF, Llama model is implicitly matched let model = match self.kind { - ModelKind::Quantized { .. } => Model::Llama(QLlama::try_from(model_config)?), - ModelKind::AdapterQuantized { .. } => { + ModelKind::GgufQuantized { .. } => Model::Llama(QLlama::try_from(model_config)?), + ModelKind::GgufAdapter { .. } => { Model::XLoraLlama(XLoraQLlama::try_from(model_config)?) } _ => unreachable!(), diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index febd76e99..432d1e050 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -124,7 +124,7 @@ impl GGUFLoaderBuilder { quantized_filenames: Vec, config: GGUFSpecificConfig, ) -> Self { - let kind = ModelKind::Quantized { + let kind = ModelKind::GgufQuantized { quant: QuantizationKind::Gguf, }; @@ -394,7 +394,7 @@ impl Loader for GGUFLoader { let has_adapter = self.kind.is_adapted(); let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora()); - let paged_attn_config = if matches!(self.kind, ModelKind::AdapterQuantized { .. }) { + let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) { warn!("Adapter models do not currently support PagedAttention, running without"); None } else { @@ -431,7 +431,7 @@ impl Loader for GGUFLoader { // Config into model: let model = match self.kind { - ModelKind::Quantized { .. } => match arch { + ModelKind::GgufQuantized { .. } => match arch { GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?), GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?), GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?), @@ -440,7 +440,7 @@ impl Loader for GGUFLoader { } a => bail!("Unsupported architecture `{a:?}` for GGUF"), }, - ModelKind::AdapterQuantized { adapter, .. } => match arch { + ModelKind::GgufAdapter { adapter, .. } => match arch { GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?), GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?), a => bail!( diff --git a/mistralrs-core/src/pipeline/loaders/mod.rs b/mistralrs-core/src/pipeline/loaders/mod.rs index 72eace924..f6c6fefd2 100644 --- a/mistralrs-core/src/pipeline/loaders/mod.rs +++ b/mistralrs-core/src/pipeline/loaders/mod.rs @@ -236,17 +236,17 @@ impl fmt::Display for TokenSource { #[derive(Clone, Default, derive_more::From, strum::Display)] pub enum ModelKind { #[default] - #[strum(to_string = "normal (no quant, no adapters)")] + #[strum(to_string = "normal (no adapters)")] Normal, - #[strum(to_string = "quantized from {quant} (no adapters)")] - Quantized { quant: QuantizationKind }, + #[strum(to_string = "gguf quantized from {quant} (no adapters)")] + GgufQuantized { quant: QuantizationKind }, - #[strum(to_string = "{adapter}, (no quant)")] + #[strum(to_string = "{adapter}")] Adapter { adapter: AdapterKind }, - #[strum(to_string = "{adapter}, quantized from {quant}")] - AdapterQuantized { + #[strum(to_string = "{adapter}, gguf quantized from {quant}")] + GgufAdapter { adapter: AdapterKind, quant: QuantizationKind, }, @@ -311,7 +311,7 @@ impl ModelKind { match self { Normal | Adapter { .. } => vec![None], - Quantized { quant } | AdapterQuantized { quant, .. } => vec![Some(*quant)], + GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)], Speculative { target, draft } => { let t = *target.clone(); let d = *draft.clone(); @@ -335,8 +335,8 @@ impl ModelKind { use ModelKind::*; match self { - Normal | Quantized { .. } => vec![None], - Adapter { adapter } | AdapterQuantized { adapter, .. } => vec![Some(*adapter)], + Normal | GgufQuantized { .. } => vec![None], + Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)], Speculative { target, draft } => { let t = *target.clone(); let d = *draft.clone(); diff --git a/mistralrs-core/src/utils/memory_usage.rs b/mistralrs-core/src/utils/memory_usage.rs index 611cd8ee2..dcf52a479 100644 --- a/mistralrs-core/src/utils/memory_usage.rs +++ b/mistralrs-core/src/utils/memory_usage.rs @@ -6,6 +6,7 @@ const KB_TO_BYTES: usize = 1024; pub struct MemoryUsage; impl MemoryUsage { + /// Amount of available memory in bytes. pub fn get_memory_available(&self, device: &Device) -> Result { match device { Device::Cpu => { @@ -30,6 +31,7 @@ impl MemoryUsage { } } + /// Amount of total memory in bytes. pub fn get_total_memory(&self, device: &Device) -> Result { match device { Device::Cpu => {