Skip to content

Commit

Permalink
feat: improve tokenizer config (#21)
Browse files Browse the repository at this point in the history
* feat: improve tokenizer config

* fix: add untagged decorator to `TokenizerConfig`

* feat: bump version to `0.2.0`
  • Loading branch information
McPatate authored Sep 21, 2023
1 parent eeb443f commit 787f2a1
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 56 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/llm-ls/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "llm-ls"
version = "0.1.1"
version = "0.2.0"
edition = "2021"

[[bin]]
Expand Down
127 changes: 75 additions & 52 deletions crates/llm-ls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ use tracing_subscriber::EnvFilter;
const NAME: &str = "llm-ls";
const VERSION: &str = env!("CARGO_PKG_VERSION");

/// Where to obtain the `tokenizer.json` used for prompt token counting.
///
/// `#[serde(untagged)]` means incoming JSON carries no variant tag: serde
/// picks the variant by which fields are present (`path` vs `repository`
/// vs `url` + `to`), trying variants in declaration order.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
enum TokenizerConfig {
    // Load a tokenizer file straight from the local filesystem.
    Local { path: PathBuf },
    // Fetch `tokenizer.json` from a Hugging Face Hub repository.
    HuggingFace { repository: String },
    // Download the file from an arbitrary URL and cache it at `to`.
    Download { url: String, to: PathBuf },
}

#[derive(Clone, Debug, Deserialize, Serialize)]
struct RequestParams {
max_new_tokens: u32,
Expand Down Expand Up @@ -120,36 +128,31 @@ struct Completion {
generated_text: String,
}

#[derive(Debug, Deserialize, Serialize)]
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
enum IDE {
enum Ide {
Neovim,
VSCode,
JetBrains,
Emacs,
Jupyter,
Sublime,
VisualStudio,
#[default]
Unknown,
}

impl Default for IDE {
fn default() -> Self {
IDE::Unknown
}
}

impl Display for IDE {
/// Display an `Ide` through its serde serialization, so user-facing text
/// matches the wire format (lowercased variant names per
/// `#[serde(rename_all = "lowercase")]` on the enum).
impl Display for Ide {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // serde provides a `Serializer` impl for `fmt::Formatter`; a unit
        // variant serializes as its (renamed) name written into `f`.
        self.serialize(f)
    }
}

fn parse_ide<'de, D>(d: D) -> std::result::Result<IDE, D::Error>
fn parse_ide<'de, D>(d: D) -> std::result::Result<Ide, D::Error>
where
D: Deserializer<'de>,
{
Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(IDE::Unknown))
Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
}

#[derive(Debug, Deserialize, Serialize)]
Expand All @@ -159,12 +162,12 @@ struct CompletionParams {
request_params: RequestParams,
#[serde(default)]
#[serde(deserialize_with = "parse_ide")]
ide: IDE,
ide: Ide,
fim: FimParams,
api_token: Option<String>,
model: String,
tokens_to_clear: Vec<String>,
tokenizer_path: Option<String>,
tokenizer_config: Option<TokenizerConfig>,
context_window: usize,
tls_skip_verify_insecure: bool,
}
Expand All @@ -183,7 +186,7 @@ fn build_prompt(
pos: Position,
text: &Rope,
fim: &FimParams,
tokenizer: Arc<Tokenizer>,
tokenizer: Option<Arc<Tokenizer>>,
context_window: usize,
) -> Result<String> {
let t = Instant::now();
Expand All @@ -206,10 +209,14 @@ fn build_prompt(
while before_line.is_some() || after_line.is_some() {
if let Some(before_line) = before_line {
let before_line = before_line.to_string();
let tokens = tokenizer
.encode(before_line.clone(), false)
.map_err(internal_error)?
.len();
let tokens = if let Some(tokenizer) = tokenizer.clone() {
tokenizer
.encode(before_line.clone(), false)
.map_err(internal_error)?
.len()
} else {
before_line.len()
};
if tokens > token_count {
break;
}
Expand All @@ -218,10 +225,14 @@ fn build_prompt(
}
if let Some(after_line) = after_line {
let after_line = after_line.to_string();
let tokens = tokenizer
.encode(after_line.clone(), false)
.map_err(internal_error)?
.len();
let tokens = if let Some(tokenizer) = tokenizer.clone() {
tokenizer
.encode(after_line.clone(), false)
.map_err(internal_error)?
.len()
} else {
after_line.len()
};
if tokens > token_count {
break;
}
Expand Down Expand Up @@ -253,10 +264,14 @@ fn build_prompt(
first = false;
}
let line = line.to_string();
let tokens = tokenizer
.encode(line.clone(), false)
.map_err(internal_error)?
.len();
let tokens = if let Some(tokenizer) = tokenizer.clone() {
tokenizer
.encode(line.clone(), false)
.map_err(internal_error)?
.len()
} else {
line.len()
};
if tokens > token_count {
break;
}
Expand All @@ -272,7 +287,7 @@ fn build_prompt(

async fn request_completion(
http_client: &reqwest::Client,
ide: IDE,
ide: Ide,
model: &str,
request_params: RequestParams,
api_token: Option<&String>,
Expand Down Expand Up @@ -311,9 +326,10 @@ fn parse_generations(generations: Vec<Generation>, tokens_to_clear: &[String]) -

async fn download_tokenizer_file(
http_client: &reqwest::Client,
model: &str,
url: &str,
api_token: Option<&String>,
to: impl AsRef<Path>,
ide: Ide,
) -> Result<()> {
if to.as_ref().exists() {
return Ok(());
Expand All @@ -325,13 +341,9 @@ async fn download_tokenizer_file(
)
.await
.map_err(internal_error)?;
let mut req = http_client.get(format!(
"https://huggingface.co/{model}/resolve/main/tokenizer.json"
));
if let Some(api_token) = api_token {
req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
}
let res = req
let res = http_client
.get(url)
.headers(build_headers(api_token, ide)?)
.send()
.await
.map_err(internal_error)?
Expand All @@ -352,27 +364,37 @@ async fn download_tokenizer_file(
async fn get_tokenizer(
model: &str,
tokenizer_map: &mut HashMap<String, Arc<Tokenizer>>,
tokenizer_path: Option<&String>,
tokenizer_config: Option<TokenizerConfig>,
http_client: &reqwest::Client,
cache_dir: impl AsRef<Path>,
api_token: Option<&String>,
) -> Result<Arc<Tokenizer>> {
ide: Ide,
) -> Result<Option<Arc<Tokenizer>>> {
if let Some(tokenizer) = tokenizer_map.get(model) {
return Ok(tokenizer.clone());
return Ok(Some(tokenizer.clone()));
}
let tokenizer = if model.starts_with("http://") || model.starts_with("https://") {
match tokenizer_path {
Some(path) => Arc::new(Tokenizer::from_file(path).map_err(internal_error)?),
None => return Err(internal_error("`tokenizer_path` is null")),
}
if let Some(config) = tokenizer_config {
let tokenizer = match config {
TokenizerConfig::Local { path } => {
Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
}
TokenizerConfig::HuggingFace { repository } => {
let path = cache_dir.as_ref().join(model).join("tokenizer.json");
let url =
format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
}
TokenizerConfig::Download { url, to } => {
download_tokenizer_file(http_client, &url, api_token, &to, ide).await?;
Arc::new(Tokenizer::from_file(to).map_err(internal_error)?)
}
};
tokenizer_map.insert(model.to_owned(), tokenizer.clone());
Ok(Some(tokenizer))
} else {
let path = cache_dir.as_ref().join(model).join("tokenizer.json");
download_tokenizer_file(http_client, model, api_token, &path).await?;
Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
};

tokenizer_map.insert(model.to_owned(), tokenizer.clone());
Ok(tokenizer)
Ok(None)
}
}

fn build_url(model: &str) -> String {
Expand All @@ -394,10 +416,11 @@ impl Backend {
let tokenizer = get_tokenizer(
&params.model,
&mut *self.tokenizer_map.write().await,
params.tokenizer_path.as_ref(),
params.tokenizer_config,
&self.http_client,
&self.cache_dir,
params.api_token.as_ref(),
params.ide,
)
.await?;
let prompt = build_prompt(
Expand Down Expand Up @@ -508,7 +531,7 @@ impl LanguageServer for Backend {
}
}

fn build_headers(api_token: Option<&String>, ide: IDE) -> Result<HeaderMap> {
fn build_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
headers.insert(
Expand Down
18 changes: 16 additions & 2 deletions crates/mock_server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use axum::{extract::State, http::HeaderMap, routing::post, Json, Router};
use serde::{Deserialize, Serialize};
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
use tokio::{
sync::Mutex,
time::{sleep, Duration},
};

#[derive(Clone)]
struct AppState {
Expand Down Expand Up @@ -41,6 +44,16 @@ async fn log_headers(headers: HeaderMap, state: State<AppState>) -> Json<Generat
})
}

/// Mock endpoint that deliberately responds slowly.
///
/// The tokio mutex guard is held across the `sleep` on purpose: concurrent
/// requests to `/wait` are serialized, which lets tests observe queuing
/// behavior in the client.
async fn wait(state: State<AppState>) -> Json<GeneratedText> {
    let mut count = state.counter.lock().await;
    *count += 1;
    sleep(Duration::from_millis(200)).await;
    println!("waited for req {}", count);
    let generated_text = "dummy".to_owned();
    Json(GeneratedText { generated_text })
}

#[tokio::main]
async fn main() {
let app_state = AppState {
Expand All @@ -50,11 +63,12 @@ async fn main() {
.route("/", post(default))
.route("/tgi", post(tgi))
.route("/headers", post(log_headers))
.route("/wait", post(wait))
.with_state(app_state);
let addr: SocketAddr = format!("{}:{}", "0.0.0.0", 4242)
.parse()
.expect("string to parse to socket addr");
println!("starting server {}:{}", addr.ip().to_string(), addr.port(),);
println!("starting server {}:{}", addr.ip(), addr.port(),);

axum::Server::bind(&addr)
.serve(app.into_make_service())
Expand Down

0 comments on commit 787f2a1

Please sign in to comment.