Skip to content

Commit

Permalink
Properly handle state_dir
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Levick <[email protected]>
  • Loading branch information
rylev committed Aug 26, 2024
1 parent 8050ec9 commit 7c4b268
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 26 deletions.
19 changes: 13 additions & 6 deletions crates/factor-llm/src/spin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,25 @@ mod local {

/// The default engine creator for the LLM factor when used in the Spin CLI.
pub fn default_engine_creator(
state_dir: PathBuf,
state_dir: Option<PathBuf>,
use_gpu: bool,
) -> impl LlmEngineCreator + 'static {
) -> anyhow::Result<impl LlmEngineCreator + 'static> {
#[cfg(feature = "llm")]
let engine = spin_llm_local::LocalLlmEngine::new(state_dir.join("ai-models"), use_gpu);
let engine = {
use anyhow::Context as _;
let models_dir_parent = match state_dir {
Some(ref dir) => dir.clone(),
None => std::env::current_dir().context("failed to get current working directory")?,
};
spin_llm_local::LocalLlmEngine::new(models_dir_parent.join("ai-models"), use_gpu)
};
#[cfg(not(feature = "llm"))]
let engine = {
let _ = (state_dir, use_gpu);
noop::NoopLlmEngine
};
let engine = Arc::new(Mutex::new(engine)) as Arc<Mutex<dyn LlmEngine>>;
move || engine.clone()
Ok(move || engine.clone())
}

#[async_trait]
Expand All @@ -74,7 +81,7 @@ impl LlmEngine for RemoteHttpLlmEngine {

pub fn runtime_config_from_toml(
table: &toml::Table,
state_dir: PathBuf,
state_dir: Option<PathBuf>,
use_gpu: bool,
) -> anyhow::Result<Option<RuntimeConfig>> {
let Some(value) = table.get("llm_compute") else {
Expand All @@ -95,7 +102,7 @@ pub enum LlmCompute {
}

impl LlmCompute {
fn into_engine(self, state_dir: PathBuf, use_gpu: bool) -> Arc<Mutex<dyn LlmEngine>> {
fn into_engine(self, state_dir: Option<PathBuf>, use_gpu: bool) -> Arc<Mutex<dyn LlmEngine>> {
match self {
#[cfg(not(feature = "llm"))]
LlmCompute::Spin => {
Expand Down
46 changes: 37 additions & 9 deletions crates/runtime-config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use spin_factor_sqlite::runtime_config::spin as sqlite;
use spin_factor_sqlite::SqliteFactor;
use spin_factor_variables::{spin_cli as variables, VariablesFactor};
use spin_factor_wasi::WasiFactor;
use spin_factors::runtime_config::toml::GetTomlValue as _;
use spin_factors::{
runtime_config::toml::TomlKeyTracker, FactorRuntimeConfigSource, RuntimeConfigSourceFinalizer,
};
Expand All @@ -33,6 +34,10 @@ pub struct ResolvedRuntimeConfig<T> {
pub key_value_resolver: key_value::RuntimeConfigResolver,
/// The resolver used to resolve sqlite databases from runtime configuration.
pub sqlite_resolver: sqlite::RuntimeConfigResolver,
/// The fully resolved state directory.
///
/// `None` is used for an "unset" state directory which each factor will treat differently.
pub state_dir: Option<PathBuf>,
}

impl<T> ResolvedRuntimeConfig<T>
Expand All @@ -41,6 +46,9 @@ where
for<'a> <T as TryFrom<TomlRuntimeConfigSource<'a>>>::Error: Into<anyhow::Error>,
{
/// Creates a new resolved runtime configuration from a runtime config source TOML file.
///
/// `state_dir` is the explicitly provided state directory, if any. Some("") will be treated as
/// as `None`.
pub fn from_file(
runtime_config_path: &Path,
state_dir: Option<&str>,
Expand All @@ -64,21 +72,22 @@ where
runtime_config_path.display()
)
})?;
let runtime_config: T = TomlRuntimeConfigSource::new(
let source = TomlRuntimeConfigSource::new(
&toml,
state_dir.unwrap_or(DEFAULT_STATE_DIR).into(),
state_dir,
&key_value_config_resolver,
&tls_resolver,
&sqlite_config_resolver,
use_gpu,
)
.try_into()
.map_err(Into::into)?;
);
let state_dir = source.state_dir();
let runtime_config: T = source.try_into().map_err(Into::into)?;

Ok(Self {
runtime_config,
key_value_resolver: key_value_config_resolver,
sqlite_resolver: sqlite_config_resolver,
state_dir,
})
}

Expand All @@ -102,24 +111,32 @@ where
}
Ok(())
}

/// The fully resolved state directory.
pub fn state_dir(&self) -> Option<PathBuf> {
self.state_dir.clone()
}
}

impl<T: Default> ResolvedRuntimeConfig<T> {
/// Creates a new resolved runtime configuration with default values.
pub fn default(state_dir: Option<&str>) -> Self {
let state_dir = state_dir.map(PathBuf::from);
Self {
sqlite_resolver: sqlite_config_resolver(state_dir.clone())
.expect("failed to resolve sqlite runtime config"),
key_value_resolver: key_value_config_resolver(state_dir),
key_value_resolver: key_value_config_resolver(state_dir.clone()),
runtime_config: Default::default(),
state_dir,
}
}
}

/// The TOML based runtime configuration source Spin CLI.
pub struct TomlRuntimeConfigSource<'a> {
table: TomlKeyTracker<'a>,
state_dir: PathBuf,
/// Explicitly provided state directory.
state_dir: Option<&'a str>,
key_value: &'a key_value::RuntimeConfigResolver,
tls: &'a SpinTlsRuntimeConfig,
sqlite: &'a sqlite::RuntimeConfigResolver,
Expand All @@ -129,7 +146,7 @@ pub struct TomlRuntimeConfigSource<'a> {
impl<'a> TomlRuntimeConfigSource<'a> {
pub fn new(
table: &'a toml::Table,
state_dir: PathBuf,
state_dir: Option<&'a str>,
key_value: &'a key_value::RuntimeConfigResolver,
tls: &'a SpinTlsRuntimeConfig,
sqlite: &'a sqlite::RuntimeConfigResolver,
Expand All @@ -144,6 +161,17 @@ impl<'a> TomlRuntimeConfigSource<'a> {
use_gpu,
}
}

/// Get the configured state_directory.
pub fn state_dir(&self) -> Option<PathBuf> {
let from_toml = || self.table.get("state_dir").and_then(|v| v.as_str());
// Prefer explicitly provided state directory, then take from toml.
self.state_dir
.or_else(from_toml)
// Treat "" as None.
.filter(|s| !s.is_empty())
.map(PathBuf::from)
}
}

impl FactorRuntimeConfigSource<KeyValueFactor> for TomlRuntimeConfigSource<'_> {
Expand Down Expand Up @@ -187,7 +215,7 @@ impl FactorRuntimeConfigSource<OutboundMysqlFactor> for TomlRuntimeConfigSource<

impl FactorRuntimeConfigSource<LlmFactor> for TomlRuntimeConfigSource<'_> {
fn get_runtime_config(&mut self) -> anyhow::Result<Option<spin_factor_llm::RuntimeConfig>> {
llm::runtime_config_from_toml(self.table.as_ref(), self.state_dir.clone(), self.use_gpu)
llm::runtime_config_from_toml(self.table.as_ref(), self.state_dir(), self.use_gpu)
}
}

Expand Down
7 changes: 4 additions & 3 deletions crates/trigger/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use spin_common::ui::quoted_path;
use spin_common::url::parse_file_url;
use spin_common::{arg_parser::parse_kv, sloth};
use spin_factors_executor::{ComponentLoader, FactorsExecutor};
use spin_runtime_config::{ResolvedRuntimeConfig, DEFAULT_STATE_DIR};
use spin_runtime_config::ResolvedRuntimeConfig;

use crate::factors::{TriggerFactors, TriggerFactorsRuntimeConfig};
use crate::stdio::{FollowComponents, StdioLoggingExecutorHooks};
Expand Down Expand Up @@ -321,13 +321,14 @@ impl<T: Trigger> TriggerAppBuilder<T> {
.await?;

let factors = TriggerFactors::new(
options.state_dir.unwrap_or(DEFAULT_STATE_DIR),
runtime_config.state_dir(),
self.working_dir.clone(),
options.allow_transient_write,
runtime_config.key_value_resolver,
runtime_config.sqlite_resolver,
use_gpu,
);
)
.context("failed to create factors")?;

// TODO: move these into Factor methods/constructors
// let init_data = crate::HostComponentInitData::new(
Expand Down
17 changes: 9 additions & 8 deletions crates/trigger/src/factors.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::path::PathBuf;

use anyhow::Context as _;
use spin_factor_key_value::KeyValueFactor;
use spin_factor_llm::LlmFactor;
use spin_factor_outbound_http::OutboundHttpFactor;
Expand Down Expand Up @@ -31,14 +32,14 @@ pub struct TriggerFactors {

impl TriggerFactors {
pub fn new(
state_dir: impl Into<PathBuf>,
state_dir: Option<PathBuf>,
working_dir: impl Into<PathBuf>,
allow_transient_writes: bool,
default_key_value_label_resolver: impl spin_factor_key_value::DefaultLabelResolver + 'static,
default_sqlite_label_resolver: impl spin_factor_sqlite::DefaultLabelResolver + 'static,
use_gpu: bool,
) -> Self {
Self {
) -> anyhow::Result<Self> {
Ok(Self {
wasi: wasi_factor(working_dir, allow_transient_writes),
variables: VariablesFactor::default(),
key_value: KeyValueFactor::new(default_key_value_label_resolver),
Expand All @@ -49,11 +50,11 @@ impl TriggerFactors {
mqtt: OutboundMqttFactor::new(NetworkedMqttClient::creator()),
pg: OutboundPgFactor::new(),
mysql: OutboundMysqlFactor::new(),
llm: LlmFactor::new(spin_factor_llm::spin::default_engine_creator(
state_dir.into(),
use_gpu,
)),
}
llm: LlmFactor::new(
spin_factor_llm::spin::default_engine_creator(state_dir, use_gpu)
.context("failed to configure LLM factor")?,
),
})
}
}

Expand Down

0 comments on commit 7c4b268

Please sign in to comment.