
Merge pull request #2748 from fermyon/llm-features
Put local LLM support behind feature flags, as it used to be.
rylev authored Aug 23, 2024
commit 752939c (parents: 6db2872 + 85b55a3)
Showing 5 changed files with 94 additions and 34 deletions.
7 changes: 3 additions & 4 deletions Cargo.toml
@@ -113,10 +113,9 @@ wit-component = "0.19.0"
 # TODO(factors): default = ["llm"]
 all-tests = ["extern-dependencies-tests"]
 extern-dependencies-tests = []
-# TODO(factors):
-# llm = ["spin-trigger-http/llm"]
-# llm-metal = ["llm", "spin-trigger-http/llm-metal"]
-# llm-cublas = ["llm", "spin-trigger-http/llm-cublas"]
+llm = ["spin-trigger-http2/llm"]
+llm-metal = ["llm", "spin-trigger-http2/llm-metal"]
+llm-cublas = ["llm", "spin-trigger-http2/llm-cublas"]
 
 [workspace]
 members = [
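With these workspace features restored, local LLM support is once again an opt-in build-time choice. A rough sketch of how one might build the Spin CLI from the repository root under that assumption (exact profiles and default features may differ):

    # Default build: no local LLM engine is compiled in.
    cargo build --release

    # Opt in to the local LLM engine.
    cargo build --release --features llm

    # Local engine with GPU acceleration (Metal or cuBLAS backends).
    cargo build --release --features llm-metal
    cargo build --release --features llm-cublas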
9 changes: 7 additions & 2 deletions crates/factor-llm/Cargo.toml
@@ -8,19 +8,24 @@ homepage.workspace = true
 repository.workspace = true
 rust-version.workspace = true
 
+[features]
+llm = ["spin-llm-local"]
+llm-metal = ["llm", "spin-llm-local/metal"]
+llm-cublas = ["llm", "spin-llm-local/cublas"]
+
 [dependencies]
 anyhow = "1.0"
 async-trait = "0.1"
 serde = "1.0"
 spin-factors = { path = "../factors" }
-spin-llm-local = { path = "../llm-local" }
+spin-llm-local = { path = "../llm-local", optional = true }
 spin-llm-remote-http = { path = "../llm-remote-http" }
 spin-locked-app = { path = "../locked-app" }
 spin-world = { path = "../world" }
 tracing = { workspace = true }
 tokio = { version = "1", features = ["sync"] }
 toml = "0.8"
-url = "2"
+url = { version = "2", features = ["serde"] }
 
 [dev-dependencies]
 spin-factors-test = { path = "../factors-test" }
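With spin-llm-local now optional, the local engine only compiles when a feature pulls it in. A hypothetical consumer crate would forward the feature through its own manifest; the consuming crate here is made up for illustration, while the paths and feature names follow this diff (the spin-trigger2 change later in this commit does exactly this):

    # Cargo.toml of a hypothetical crate embedding the LLM factor
    [features]
    llm = ["spin-factor-llm/llm"]
    llm-metal = ["spin-factor-llm/llm-metal"]

    [dependencies]
    spin-factor-llm = { path = "../factor-llm" }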
102 changes: 74 additions & 28 deletions crates/factor-llm/src/spin.rs
@@ -1,8 +1,6 @@
 use std::path::PathBuf;
 use std::sync::Arc;
 
-pub use spin_llm_local::LocalLlmEngine;
-
 use spin_llm_remote_http::RemoteHttpLlmEngine;
 use spin_world::async_trait;
 use spin_world::v1::llm::{self as v1};
@@ -12,26 +10,48 @@ use url::Url;
 
 use crate::{LlmEngine, LlmEngineCreator, RuntimeConfig};
 
-#[async_trait]
-impl LlmEngine for LocalLlmEngine {
-    async fn infer(
-        &mut self,
-        model: v1::InferencingModel,
-        prompt: String,
-        params: v2::InferencingParams,
-    ) -> Result<v2::InferencingResult, v2::Error> {
-        self.infer(model, prompt, params).await
-    }
+#[cfg(feature = "llm")]
+mod local {
+    use super::*;
+    pub use spin_llm_local::LocalLlmEngine;
 
-    async fn generate_embeddings(
-        &mut self,
-        model: v2::EmbeddingModel,
-        data: Vec<String>,
-    ) -> Result<v2::EmbeddingsResult, v2::Error> {
-        self.generate_embeddings(model, data).await
+    #[async_trait]
+    impl LlmEngine for LocalLlmEngine {
+        async fn infer(
+            &mut self,
+            model: v2::InferencingModel,
+            prompt: String,
+            params: v2::InferencingParams,
+        ) -> Result<v2::InferencingResult, v2::Error> {
+            self.infer(model, prompt, params).await
+        }
+
+        async fn generate_embeddings(
+            &mut self,
+            model: v2::EmbeddingModel,
+            data: Vec<String>,
+        ) -> Result<v2::EmbeddingsResult, v2::Error> {
+            self.generate_embeddings(model, data).await
+        }
     }
 }
 
+/// The default engine creator for the LLM factor when used in the Spin CLI.
+pub fn default_engine_creator(
+    state_dir: PathBuf,
+    use_gpu: bool,
+) -> impl LlmEngineCreator + 'static {
+    #[cfg(feature = "llm")]
+    let engine = spin_llm_local::LocalLlmEngine::new(state_dir.join("ai-models"), use_gpu);
+    #[cfg(not(feature = "llm"))]
+    let engine = {
+        let _ = (state_dir, use_gpu);
+        noop::NoopLlmEngine
+    };
+    let engine = Arc::new(Mutex::new(engine)) as Arc<Mutex<dyn LlmEngine>>;
+    move || engine.clone()
+}
+
 #[async_trait]
 impl LlmEngine for RemoteHttpLlmEngine {
     async fn infer(
@@ -77,6 +97,12 @@ pub enum LlmCompute {
 impl LlmCompute {
     fn into_engine(self, state_dir: PathBuf, use_gpu: bool) -> Arc<Mutex<dyn LlmEngine>> {
         match self {
+            #[cfg(not(feature = "llm"))]
+            LlmCompute::Spin => {
+                let _ = (state_dir, use_gpu);
+                Arc::new(Mutex::new(noop::NoopLlmEngine))
+            }
+            #[cfg(feature = "llm")]
             LlmCompute::Spin => default_engine_creator(state_dir, use_gpu).create(),
             LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
                 config.url,
@@ -92,15 +118,35 @@ pub struct RemoteHttpCompute {
     auth_token: String,
 }
 
-/// The default engine creator for the LLM factor when used in the Spin CLI.
-pub fn default_engine_creator(
-    state_dir: PathBuf,
-    use_gpu: bool,
-) -> impl LlmEngineCreator + 'static {
-    move || {
-        Arc::new(Mutex::new(LocalLlmEngine::new(
-            state_dir.join("ai-models"),
-            use_gpu,
-        ))) as _
+/// A noop engine used when the local engine feature is disabled.
+#[cfg(not(feature = "llm"))]
+mod noop {
+    use super::*;
+
+    #[derive(Clone, Copy)]
+    pub(super) struct NoopLlmEngine;
+
+    #[async_trait]
+    impl LlmEngine for NoopLlmEngine {
+        async fn infer(
+            &mut self,
+            _model: v2::InferencingModel,
+            _prompt: String,
+            _params: v2::InferencingParams,
+        ) -> Result<v2::InferencingResult, v2::Error> {
+            Err(v2::Error::RuntimeError(
+                "Local LLM operations are not supported in this version of Spin.".into(),
+            ))
+        }
+
+        async fn generate_embeddings(
+            &mut self,
+            _model: v2::EmbeddingModel,
+            _data: Vec<String>,
+        ) -> Result<v2::EmbeddingsResult, v2::Error> {
+            Err(v2::Error::RuntimeError(
+                "Local LLM operations are not supported in this version of Spin.".into(),
+            ))
+        }
     }
 }
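The shape of this change is a common cfg-gated fallback: the real implementation compiles only behind the feature, a noop stand-in compiles otherwise, and both are erased behind the same trait object so callers never branch. Below is a self-contained sketch of the pattern, not the actual Spin types: the trait is simplified to a synchronous method and std's Mutex, whereas the real code uses async_trait and tokio's Mutex, and the sketch assumes the crate declares an `llm` feature in its Cargo.toml.

    use std::sync::{Arc, Mutex};

    // Simplified stand-in for the LlmEngine trait in this diff.
    trait Engine {
        fn infer(&mut self, prompt: &str) -> Result<String, String>;
    }

    #[cfg(feature = "llm")]
    mod local {
        use super::Engine;

        pub struct LocalEngine;

        impl Engine for LocalEngine {
            fn infer(&mut self, prompt: &str) -> Result<String, String> {
                Ok(format!("local inference for: {prompt}"))
            }
        }
    }

    #[cfg(not(feature = "llm"))]
    mod noop {
        use super::Engine;

        pub struct NoopEngine;

        impl Engine for NoopEngine {
            fn infer(&mut self, _prompt: &str) -> Result<String, String> {
                Err("local LLM support was not compiled in".to_string())
            }
        }
    }

    // Both cfg branches yield the same trait-object type, so callers never
    // need to know which engine was compiled in.
    fn default_engine() -> Arc<Mutex<dyn Engine>> {
        #[cfg(feature = "llm")]
        let engine = local::LocalEngine;
        #[cfg(not(feature = "llm"))]
        let engine = noop::NoopEngine;
        Arc::new(Mutex::new(engine)) as Arc<Mutex<dyn Engine>>
    }

    fn main() {
        // With the llm feature: Ok(...); without it: the noop error.
        println!("{:?}", default_engine().lock().unwrap().infer("hello"));
    }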
5 changes: 5 additions & 0 deletions crates/trigger-http2/Cargo.toml
@@ -7,6 +7,11 @@ edition = { workspace = true }
 [lib]
 doctest = false
 
+[features]
+llm = ["spin-trigger2/llm"]
+llm-metal = ["spin-trigger2/llm-metal"]
+llm-cublas = ["spin-trigger2/llm-cublas"]
+
 [dependencies]
 anyhow = "1.0"
 async-trait = "0.1"
5 changes: 5 additions & 0 deletions crates/trigger2/Cargo.toml
@@ -8,6 +8,11 @@ homepage.workspace = true
 repository.workspace = true
 rust-version.workspace = true
 
+[features]
+llm = ["spin-factor-llm/llm"]
+llm-metal = ["spin-factor-llm/llm-metal"]
+llm-cublas = ["spin-factor-llm/llm-cublas"]
+
 [dependencies]
 anyhow = "1"
 clap = { version = "3.1.18", features = ["derive", "env"] }
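Taken together, the five files wire a single feature chain from the CLI surface down to the local engine crate, so enabling `llm` (or one of the GPU variants) at the workspace root cascades through every layer. A sketch of the chain as implied by this diff:

    # Feature forwarding implied by this commit, one hop per crate:
    #
    #   root Cargo.toml      llm = ["spin-trigger-http2/llm"]
    #   spin-trigger-http2   llm = ["spin-trigger2/llm"]
    #   spin-trigger2        llm = ["spin-factor-llm/llm"]
    #   spin-factor-llm      llm = ["spin-llm-local"]   (optional dependency)
    #
    # llm-metal and llm-cublas follow the same path and end at
    # spin-llm-local/metal and spin-llm-local/cublas respectively.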
