Skip to content

Commit

Permalink
Expose config for rust api, tweak modekind
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Oct 12, 2024
1 parent 9dfbab1 commit d50ea88
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 16 deletions.
16 changes: 16 additions & 0 deletions mistralrs-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::Device;
use cublaslt::setup_cublas_lt_wrapper;
use engine::Engine;
pub use engine::{EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP};
Expand Down Expand Up @@ -105,6 +106,11 @@ pub use utils::paged_attn_supported;
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
static ENGINE_ID: AtomicUsize = AtomicUsize::new(0);

/// Read-only description of the model backing a `MistralRs` instance.
///
/// The values are captured from the pipeline once, while the `MistralRs`
/// instance is being constructed, and are exposed through `MistralRs::config`.
pub struct MistralRsConfig {
/// Kind of the loaded model, cloned from the pipeline's metadata.
pub kind: ModelKind,
/// Device reported by the pipeline at construction time.
pub device: Device,
}

/// The MistralRs struct handles sending requests to the engine.
/// It is the core multi-threaded component of mistral.rs, and uses `mpsc`
/// `Sender` and `Receiver` primitives to send and receive requests to the
Expand All @@ -119,6 +125,7 @@ pub struct MistralRs {
engine_handler: RwLock<JoinHandle<()>>,
engine_id: usize,
category: ModelCategory,
config: MistralRsConfig,
}

#[derive(Clone)]
Expand Down Expand Up @@ -322,6 +329,10 @@ impl MistralRs {
let sender = RwLock::new(tx);
let id = pipeline.try_lock().unwrap().name();

let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
let device = pipeline.try_lock().unwrap().device();
let config = MistralRsConfig { kind, device };

let engine_handler = thread::spawn(move || {
let rt = Runtime::new().unwrap();
rt.block_on(async move {
Expand Down Expand Up @@ -355,6 +366,7 @@ impl MistralRs {
reboot_state,
engine_handler: RwLock::new(engine_handler),
category,
config,
})
}

Expand Down Expand Up @@ -481,4 +493,8 @@ impl MistralRs {
.expect("Unable to write data");
}
}

/// Returns a shared reference to the configuration stored on this instance.
///
/// The configuration holds the model kind and device that were read from the
/// pipeline when this instance was constructed.
pub fn config(&self) -> &MistralRsConfig {
&self.config
}
}
6 changes: 3 additions & 3 deletions mistralrs-core/src/pipeline/ggml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl GGMLLoaderBuilder {
quantized_model_id: String,
quantized_filename: String,
) -> Self {
let kind = ModelKind::Quantized {
let kind = ModelKind::GgufQuantized {
quant: QuantizationKind::Ggml,
};

Expand Down Expand Up @@ -339,8 +339,8 @@ impl Loader for GGMLLoader {
// Config into model:
// NOTE: No architecture to infer like GGUF, Llama model is implicitly matched
let model = match self.kind {
ModelKind::Quantized { .. } => Model::Llama(QLlama::try_from(model_config)?),
ModelKind::AdapterQuantized { .. } => {
ModelKind::GgufQuantized { .. } => Model::Llama(QLlama::try_from(model_config)?),
ModelKind::GgufAdapter { .. } => {
Model::XLoraLlama(XLoraQLlama::try_from(model_config)?)
}
_ => unreachable!(),
Expand Down
8 changes: 4 additions & 4 deletions mistralrs-core/src/pipeline/gguf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl GGUFLoaderBuilder {
quantized_filenames: Vec<String>,
config: GGUFSpecificConfig,
) -> Self {
let kind = ModelKind::Quantized {
let kind = ModelKind::GgufQuantized {
quant: QuantizationKind::Gguf,
};

Expand Down Expand Up @@ -394,7 +394,7 @@ impl Loader for GGUFLoader {
let has_adapter = self.kind.is_adapted();
let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

let paged_attn_config = if matches!(self.kind, ModelKind::AdapterQuantized { .. }) {
let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
warn!("Adapter models do not currently support PagedAttention, running without");
None
} else {
Expand Down Expand Up @@ -431,7 +431,7 @@ impl Loader for GGUFLoader {

// Config into model:
let model = match self.kind {
ModelKind::Quantized { .. } => match arch {
ModelKind::GgufQuantized { .. } => match arch {
GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?),
GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
Expand All @@ -440,7 +440,7 @@ impl Loader for GGUFLoader {
}
a => bail!("Unsupported architecture `{a:?}` for GGUF"),
},
ModelKind::AdapterQuantized { adapter, .. } => match arch {
ModelKind::GgufAdapter { adapter, .. } => match arch {
GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?),
GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
a => bail!(
Expand Down
18 changes: 9 additions & 9 deletions mistralrs-core/src/pipeline/loaders/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,17 @@ impl fmt::Display for TokenSource {
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
#[default]
#[strum(to_string = "normal (no quant, no adapters)")]
#[strum(to_string = "normal (no adapters)")]
Normal,

#[strum(to_string = "quantized from {quant} (no adapters)")]
Quantized { quant: QuantizationKind },
#[strum(to_string = "gguf quantized from {quant} (no adapters)")]
GgufQuantized { quant: QuantizationKind },

#[strum(to_string = "{adapter}, (no quant)")]
#[strum(to_string = "{adapter}")]
Adapter { adapter: AdapterKind },

#[strum(to_string = "{adapter}, quantized from {quant}")]
AdapterQuantized {
#[strum(to_string = "{adapter}, gguf quantized from {quant}")]
GgufAdapter {
adapter: AdapterKind,
quant: QuantizationKind,
},
Expand Down Expand Up @@ -311,7 +311,7 @@ impl ModelKind {

match self {
Normal | Adapter { .. } => vec![None],
Quantized { quant } | AdapterQuantized { quant, .. } => vec![Some(*quant)],
GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
Speculative { target, draft } => {
let t = *target.clone();
let d = *draft.clone();
Expand All @@ -335,8 +335,8 @@ impl ModelKind {
use ModelKind::*;

match self {
Normal | Quantized { .. } => vec![None],
Adapter { adapter } | AdapterQuantized { adapter, .. } => vec![Some(*adapter)],
Normal | GgufQuantized { .. } => vec![None],
Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
Speculative { target, draft } => {
let t = *target.clone();
let d = *draft.clone();
Expand Down
2 changes: 2 additions & 0 deletions mistralrs-core/src/utils/memory_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ const KB_TO_BYTES: usize = 1024;
pub struct MemoryUsage;

impl MemoryUsage {
/// Amount of available memory in bytes.
pub fn get_memory_available(&self, device: &Device) -> Result<usize> {
match device {
Device::Cpu => {
Expand All @@ -30,6 +31,7 @@ impl MemoryUsage {
}
}

/// Amount of total memory in bytes.
pub fn get_total_memory(&self, device: &Device) -> Result<usize> {
match device {
Device::Cpu => {
Expand Down

0 comments on commit d50ea88

Please sign in to comment.