Skip to content

Commit

Permalink
Begin work of fixing tests, formatting, linting, etc
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Jun 25, 2024
1 parent a404292 commit 1a8985b
Show file tree
Hide file tree
Showing 15 changed files with 190 additions and 205 deletions.
5 changes: 5 additions & 0 deletions server/config/default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,23 @@ top_p = 0.7
[openai]
api_url = "https://api.openai.com/v1/chat/completions"
model = "gpt-4o"
api_key = "<openai-api-key>"

[query_rephraser]
model = "mistralai/Mistral-7B-Instruct-v0.2"
max_tokens = 100
api_key = "<query-rephraser-api-key>"

[llm]
toxicity_auth_token = "<toxicity-auth-token>"
toxicity_threshold = 0.75

[pubmed]
url_prefix = "https://pubmed.ncbi.nlm.nih.gov"

[brave]
subscription_key = "<subscription-key>"
goggles_id = "<goggles-id>"
url = "https://api.search.brave.com/res/v1/web/search"
count = 10
result_filter = "query,web"
Expand Down
2 changes: 1 addition & 1 deletion server/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ pub mod models;
pub mod oauth2;
pub mod routes;
pub mod services;
pub mod utils;
pub(crate) mod sessions;
pub mod utils;
3 changes: 2 additions & 1 deletion server/src/auth/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ async fn oauth_authenticate(
.map_err(BackendError::Reqwest)?;

// Persist user in our database, so we can use `get_user`.
let user = sqlx::query_as!(User,
let user = sqlx::query_as!(
User,
"
insert into users (username, access_token)
values ($1, $2)
Expand Down
15 changes: 11 additions & 4 deletions server/src/auth/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use color_eyre::eyre::eyre;
use sqlx::PgPool;

#[tracing::instrument(level = "debug", ret, err)]
pub async fn register(pool: PgPool, request: models::RegisterUserRequest) -> crate::Result<UserRecord> {
pub async fn register(
pool: PgPool,
request: models::RegisterUserRequest,
) -> crate::Result<UserRecord> {
if let Some(password) = request.password {
let password_hash = utils::hash_password(password).await?;
let user = sqlx::query_as!(
Expand Down Expand Up @@ -49,12 +52,16 @@ pub async fn register(pool: PgPool, request: models::RegisterUserRequest) -> cra
}

pub async fn is_email_whitelisted(pool: &PgPool, email: &String) -> crate::Result<bool> {
let whitelisted_email = sqlx::query_as!(models::WhitelistedEmail, "SELECT * FROM whitelisted_emails WHERE email = $1", email)
let whitelisted_email = sqlx::query_as!(
models::WhitelistedEmail,
"SELECT * FROM whitelisted_emails WHERE email = $1",
email
)
.fetch_one(pool)
.await;

match whitelisted_email {
Ok(whitelisted_email) => Ok(whitelisted_email.approved),
_ => Ok(false)
_ => Ok(false),
}
}
}
21 changes: 13 additions & 8 deletions server/src/llms/toxicity.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use crate::llms::LLMSettings;
use color_eyre::eyre::eyre;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};

/// Serde-(de)serializable container holding a single `inputs` string.
///
/// NOTE(review): presumably the request body sent to the toxicity endpoint by
/// `predict_toxicity` — confirm against the call site, which is not visible here.
#[derive(Debug, Serialize, Deserialize)]
pub struct ToxicityInput {
/// The text carried by this payload.
pub inputs: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct ToxicityAPIResponse (pub Vec<ToxicityScore>);
struct ToxicityAPIResponse(pub Vec<ToxicityScore>);

#[derive(Debug, Serialize, Deserialize)]
struct ToxicityScore {
Expand All @@ -25,8 +25,10 @@ pub async fn predict_toxicity(
) -> crate::Result<bool> {
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_bytes(b"Authorization").map_err(|e| eyre!("Failed to create header: {e}"))?,
HeaderValue::from_str(&llm_settings.toxicity_auth_token.expose()).map_err(|e| eyre!("Failed to create header: {e}"))?,
HeaderName::from_bytes(b"Authorization")
.map_err(|e| eyre!("Failed to create header: {e}"))?,
HeaderValue::from_str(&llm_settings.toxicity_auth_token.expose())
.map_err(|e| eyre!("Failed to create header: {e}"))?,
);
let client = Client::new();

Expand All @@ -43,9 +45,12 @@ pub async fn predict_toxicity(
.await
.map_err(|e| eyre!("Failed to parse toxicity response: {e}"))?;

let toxicity_score = toxicity_api_response.into_iter().find(|x| x.label == String::from("toxic")).unwrap_or(ToxicityScore {
score: 0.0,
label: String::from(""),
});
let toxicity_score = toxicity_api_response
.into_iter()
.find(|x| x.label == String::from("toxic"))
.unwrap_or(ToxicityScore {
score: 0.0,
label: String::from(""),
});
Ok(toxicity_score.score > llm_settings.toxicity_threshold)
}
94 changes: 48 additions & 46 deletions server/src/rag/brave_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,54 @@ pub struct BraveAPIConfig {
pub headers: HeaderMap<HeaderValue>,
}

/// Converts owned Brave settings into the query parameters and headers that
/// make up a [`BraveAPIConfig`].
impl From<BraveSettings> for BraveAPIConfig {
    /// Builds the `(name, value)` query pairs and the header map from `brave_settings`.
    ///
    /// # Panics
    ///
    /// Panics if the configured subscription key cannot be encoded as an HTTP
    /// header value.
    fn from(brave_settings: BraveSettings) -> Self {
        // Build the headers first: they only borrow the subscription key,
        // which leaves the remaining fields free to be moved below.
        let mut headers = HeaderMap::new();
        for (name, value) in [
            ("Accept", "application/json"),
            ("Accept-Encoding", "gzip"),
            (
                "X-Subscription-Token",
                brave_settings.subscription_key.expose(),
            ),
        ] {
            headers.insert(
                HeaderName::from_bytes(name.as_bytes())
                    .expect("hard-coded Brave header names are valid"),
                HeaderValue::from_str(value)
                    .expect("Brave header values must be valid HTTP header values"),
            );
        }

        // `brave_settings` is owned, so its string fields are moved into the
        // query list instead of being cloned.
        let queries = vec![
            (String::from("count"), brave_settings.count.to_string()),
            (String::from("goggles_id"), brave_settings.goggles_id),
            (String::from("result_filter"), brave_settings.result_filter),
            (String::from("search_lang"), brave_settings.search_lang),
            (
                String::from("extra_snippets"),
                brave_settings.extra_snippets.to_string(),
            ),
            (String::from("safesearch"), brave_settings.safesearch),
        ];

        BraveAPIConfig { queries, headers }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BraveWebSearchResult {
pub title: String,
Expand All @@ -46,52 +94,6 @@ struct BraveAPIResponse {
pub web: BraveWebAPIResponse,
}

pub fn prepare_brave_api_config(brave_settings: &BraveSettings) -> BraveAPIConfig {
let queries = vec![
(String::from("count"), brave_settings.count.to_string()),
(
String::from("goggles_id"),
brave_settings.goggles_id.clone(),
),
(
String::from("result_filter"),
brave_settings.result_filter.clone(),
),
(
String::from("search_lang"),
brave_settings.search_lang.clone(),
),
(
String::from("extra_snippets"),
brave_settings.extra_snippets.to_string(),
),
(
String::from("safesearch"),
brave_settings.safesearch.clone(),
),
];

let headers = HeaderMap::from_iter(
vec![
("Accept", "application/json"),
("Accept-Encoding", "gzip"),
(
"X-Subscription-Token",
brave_settings.subscription_key.expose(),
),
]
.into_iter()
.map(|(k, v)| {
(
HeaderName::from_bytes(k.as_bytes()).unwrap(),
HeaderValue::from_str(v).unwrap(),
)
}),
);

BraveAPIConfig { queries, headers }
}

#[tracing::instrument(level = "debug", ret, err)]
pub async fn web_search(
brave_settings: &BraveSettings,
Expand Down
6 changes: 3 additions & 3 deletions server/src/rag/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub async fn search(
brave_api_config: &brave_search::BraveAPIConfig,
cache: &CachePool,
agency_service: &mut AgencyServiceClient<Channel>,
search_query: &String,
search_query: &str,
) -> crate::Result<rag::SearchResponse> {
if let Some(response) = cache.get(&search_query).await {
return Ok(response);
Expand Down Expand Up @@ -47,7 +47,7 @@ pub async fn search(
let compressed_results = prompt_compression::compress(
&settings.llm,
prompt_compression::PromptCompressionInput {
query: search_query.clone(),
query: search_query.to_string(),
target_token: 300,
context_texts_list: retrieved_results.iter().map(|r| r.text.clone()).collect(),
},
Expand All @@ -67,7 +67,7 @@ pub async fn search(
async fn retrieve_result_from_agency(
settings: &Settings,
agency_service: &mut AgencyServiceClient<Channel>,
search_query: &String,
search_query: &str,
) -> crate::Result<Vec<rag::RetrievedResult>> {
let agency_service = Arc::new(agency_service.clone());
let query_embeddings =
Expand Down
2 changes: 1 addition & 1 deletion server/src/search/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub async fn insert_new_search(
pool: &PgPool,
user_id: &Uuid,
search_query_request: &api_models::SearchQueryRequest,
rephrased_query: &String,
rephrased_query: &str,
) -> crate::Result<data_models::Search> {
let thread = match search_query_request.thread_id {
Some(thread_id) => {
Expand Down
2 changes: 1 addition & 1 deletion server/src/startup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl AppState {
cache: CachePool::new(&settings.cache).await?,
agency_service: agency_service_connect(settings.agency_api.expose()).await?,
oauth2_clients: settings.oauth2_clients.clone(),
brave_config: brave_search::prepare_brave_api_config(&settings.brave),
brave_config: settings.brave.clone().into(),
settings,
openai_stream_regex: Regex::new(r#"\"content\":\"(.*?)\"}"#)
.map_err(|e| eyre!("Failed to compile OpenAI stream regex: {}", e))?,
Expand Down
39 changes: 20 additions & 19 deletions server/src/users/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ use std::fmt::Debug;

#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub enum UserGroup {
Alpha,
Beta,
Public,
Alpha,
Beta,
Public,
} // Move Public to top before public release

impl From<i32> for UserGroup {
fn from(value: i32) -> Self {
match value {
0 => UserGroup::Alpha,
1 => UserGroup::Beta,
_ => UserGroup::Public,
}
}
fn from(value: i32) -> Self {
match value {
0 => UserGroup::Alpha,
1 => UserGroup::Beta,
_ => UserGroup::Public,
}
}
}

#[derive(sqlx::FromRow, Serialize, Clone, Debug)]
Expand Down Expand Up @@ -117,7 +117,6 @@ pub struct UpdatePasswordRequest {
pub new_password: Secret<String>,
}


#[derive(Serialize, Deserialize, Debug)]
pub struct UpdateProfileRequest {
pub username: Option<String>,
Expand All @@ -127,14 +126,16 @@ pub struct UpdateProfileRequest {
pub company: Option<String>,
}


impl UpdateProfileRequest {
pub fn has_any_value(&self) -> bool {
[self.username.is_some(),
self.email.is_some(),
self.fullname.is_some(),
self.title.is_some(),
self.company.is_some()
].iter().any(|&x| x)
[
self.username.is_some(),
self.email.is_some(),
self.fullname.is_some(),
self.title.is_some(),
self.company.is_some(),
]
.iter()
.any(|&x| x)
}
}
}
Loading

0 comments on commit 1a8985b

Please sign in to comment.