diff --git a/server/src/auth/api_models.rs b/server/src/auth/api_models.rs index 05016ee3..738c5ab9 100644 --- a/server/src/auth/api_models.rs +++ b/server/src/auth/api_models.rs @@ -1,13 +1,36 @@ use crate::secrets::Secret; +use oauth2::{basic::BasicRequestTokenError, reqwest::AsyncHttpClientError}; use serde::de::Error; use serde::{Deserialize, Deserializer}; +use tokio::task; use validator::Validate; -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum AuthError { + #[error(transparent)] + Sqlx(#[from] sqlx::Error), + + #[error(transparent)] + Reqwest(#[from] reqwest::Error), + + #[error(transparent)] + OAuth2(#[from] BasicRequestTokenError), + + #[error(transparent)] + TaskJoin(#[from] task::JoinError), + + #[error("Unauthorized: {0}")] Unauthorized(String), + #[error("Invalid session: {0}")] InvalidSession(String), - BackendError(String), + #[error("Other error: {0}")] + UserAlreadyExists(String), + #[error("Other error: {0}")] + InvalidData(String), + #[error("Not whitelisted: {0}")] + NotWhitelisted(String), + #[error("Other error: {0}")] + Other(String), } #[derive(Deserialize, Clone, Debug, Validate)] diff --git a/server/src/auth/models.rs b/server/src/auth/models.rs index 21c5454e..94b93622 100644 --- a/server/src/auth/models.rs +++ b/server/src/auth/models.rs @@ -1,8 +1,7 @@ use crate::auth::{oauth2::OAuth2Client, utils, AuthError}; -use crate::err::AppError; use crate::secrets::Secret; use crate::telemetry::spawn_blocking_with_tracing; -use crate::users::{User, UserError}; +use crate::users::User; use async_trait::async_trait; use axum::http::header::{AUTHORIZATION, USER_AGENT}; use axum_login::{AuthnBackend, UserId}; @@ -39,7 +38,7 @@ impl PostgresBackend { async fn password_authenticate( db: &PgPool, password_credentials: PasswordCredentials, -) -> crate::Result> { +) -> Result, AuthError> { let user = sqlx::query_as!( User, "select * from users where email = $1 and password_hash is not null", @@ -61,7 +60,7 @@ async fn oauth_authenticate( db: &PgPool, oauth2_clients: &[OAuth2Client], oauth_creds: OAuthCredentials, -) -> crate::Result> { +) -> Result, AuthError> { // Ensure the CSRF state has not been tampered with. if oauth_creds.old_state.secret() != oauth_creds.new_state.secret() { return Ok(None); @@ -111,15 +110,18 @@ async fn oauth_authenticate( impl AuthnBackend for PostgresBackend { type User = User; type Credentials = Credentials; - type Error = AppError; + type Error = AuthError; #[tracing::instrument(level = "info", skip(self), ret, err)] - async fn authenticate(&self, creds: Self::Credentials) -> crate::Result> { + async fn authenticate( + &self, + creds: Self::Credentials, + ) -> Result, AuthError> { match creds { Credentials::Password(password_cred) => { password_cred .validate() - .map_err(|e| UserError::InvalidData(format!("Invalid credentials: {}", e)))?; + .map_err(|e| AuthError::Unauthorized(format!("Invalid credentials: {}", e)))?; password_authenticate(&self.db, password_cred).await } Credentials::OAuth(oauth_creds) => { @@ -129,7 +131,7 @@ impl AuthnBackend for PostgresBackend { } #[tracing::instrument(level = "info", skip(self), ret, err)] - async fn get_user(&self, user_id: &UserId) -> crate::Result> { + async fn get_user(&self, user_id: &UserId) -> Result, AuthError> { sqlx::query_as!( Self::User, "select * from users where user_id = $1", diff --git a/server/src/auth/routes.rs b/server/src/auth/routes.rs index 1b2dab14..44f1c99e 100644 --- a/server/src/auth/routes.rs +++ b/server/src/auth/routes.rs @@ -3,7 +3,7 @@ use crate::auth::{ AuthError, AuthSession, Credentials, OAuthCredentials, PasswordCredentials, RegisterUserRequest, }; use crate::startup::AppState; -use crate::users::{UserError, UserRecord}; +use crate::users::UserRecord; use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Redirect, Response}; @@ -22,10 +22,10 @@ async fn register_handler( ) -> crate::Result { request .validate() - .map_err(|e| UserError::InvalidData(format!("Invalid data: {}", e)))?; + .map_err(|e| AuthError::InvalidData(format!("Invalid data: {}", e)))?; if !services::is_email_whitelisted(&pool, &request.email).await? { - return Err(UserError::NotWhitelisted(format!("This email is not whitelisted!")).into()); + return Err(AuthError::NotWhitelisted(format!("This email is not whitelisted!")).into()); } services::register(pool, request) .await @@ -43,13 +43,11 @@ async fn login_handler( { Ok(Some(user)) => user, Ok(None) => return Err(AuthError::Unauthorized(format!("Invalid credentials")).into()), - Err(_) => { - return Err(AuthError::BackendError(format!("Could not authenticate user")).into()) - } + Err(_) => return Err(AuthError::Other(format!("Could not authenticate user")).into()), }; if auth_session.login(&user).await.is_err() { - return Err(AuthError::BackendError(format!("Could not login user")).into()); + return Err(AuthError::Other(format!("Could not login user")).into()); } //if let Credentials::Password(_pw_creds) = creds { // if let Some(ref next) = pw_creds.next { @@ -123,13 +121,11 @@ pub async fn oauth_callback_handler( let user = match auth_session.authenticate(creds).await { Ok(Some(user)) => user, Ok(None) => return Err(AuthError::Unauthorized(format!("Invalid credentials")).into()), - Err(_) => { - return Err(AuthError::BackendError(format!("Could not authenticate user")).into()) - } + Err(_) => return Err(AuthError::Other(format!("Could not authenticate user")).into()), }; if auth_session.login(&user).await.is_err() { - return Err(AuthError::BackendError(format!("Could not login user")).into()); + return Err(AuthError::Other(format!("Could not login user")).into()); } if let Ok(Some(next)) = session.remove::(NEXT_URL_KEY).await { @@ -144,7 +140,7 @@ async fn logout_handler(mut auth_session: AuthSession) -> crate::Result<()> { auth_session .logout() .await - .map_err(|e| AuthError::BackendError(format!("Could not logout user: {}", e)))?; + .map_err(|e| AuthError::Other(format!("Could not logout user: {}", e)))?; Ok(()) } diff --git a/server/src/auth/services.rs b/server/src/auth/services.rs index 8a434e6b..a295a079 100644 --- a/server/src/auth/services.rs +++ b/server/src/auth/services.rs @@ -1,5 +1,7 @@ +use crate::auth::AuthError; use crate::auth::{api_models, models, utils}; -use crate::users::{User, UserError, UserRecord}; +use crate::err::ResultExt; +use crate::users::{User, UserRecord}; use sqlx::PgPool; #[tracing::instrument(level = "info", ret, err)] @@ -17,7 +19,21 @@ pub async fn register( password_hash.expose() ) .fetch_one(&pool) - .await?; + .await + .on_constraint("users_username_key", |_| { + AuthError::UserAlreadyExists(format!( + "User with username {} already exists", + request.email + )) + .into() + }) + .on_constraint("users_email_key", |_| { + AuthError::UserAlreadyExists(format!( + "User with email {} already exists", + request.email + )) + .into() + })?; return Ok(user.into()); } else if let Some(access_token) = request.access_token { @@ -29,18 +45,32 @@ pub async fn register( access_token.expose() ) .fetch_one(&pool) - .await?; + .await + .on_constraint("users_username_key", |_| { + AuthError::UserAlreadyExists(format!( + "User with username {} already exists", + request.email + )) + .into() + }) + .on_constraint("users_email_key", |_| { + AuthError::UserAlreadyExists(format!( + "User with email {} already exists", + request.email + )) + .into() + })?; return Ok(user.into()); } - Err(UserError::InvalidData(format!( + Err(AuthError::InvalidData(format!( "Either password or access_token must be provided to create a user" )) .into()) } -pub async fn is_email_whitelisted(pool: &PgPool, email: &String) -> crate::Result { +pub async fn is_email_whitelisted(pool: &PgPool, email: &String) -> Result { let whitelisted_email = sqlx::query_as!( models::WhitelistedEmail, "SELECT * FROM whitelisted_emails WHERE email = $1", diff --git a/server/src/auth/utils.rs b/server/src/auth/utils.rs index 7bbf54e6..b67eed14 100644 --- a/server/src/auth/utils.rs +++ b/server/src/auth/utils.rs @@ -1,3 +1,4 @@ +use crate::auth::AuthError; use crate::secrets::Secret; use crate::users::User; use password_auth::{generate_hash, verify_password}; @@ -7,7 +8,7 @@ use tokio::task::spawn_blocking; pub fn verify_user_password( user: Option, password_candidate: Secret, -) -> crate::Result> { +) -> Result, AuthError> { // password-based authentication. Compare our form input with an argon2 // password hash. // To prevent timed side-channel attacks, so we always compare the password @@ -31,7 +32,7 @@ pub fn verify_user_password( } // Prevent side-channel attacks by always verifying the password. -pub fn dummy_verify_password(pw: Secret>) -> crate::Result> { +pub fn dummy_verify_password(pw: Secret>) -> Result, AuthError> { let _ = verify_password( pw.expose_owned().as_ref(), "$argon2id$v=19$m=15000,t=2,p=1$\ @@ -42,6 +43,6 @@ pub fn dummy_verify_password(pw: Secret>) -> crate::Result) -> crate::Result> { +pub async fn hash_password(password: Secret) -> Result, AuthError> { Ok(spawn_blocking(move || Secret::new(generate_hash(password.expose().as_bytes()))).await?) } diff --git a/server/src/cache.rs b/server/src/cache.rs index 6f5a2006..99198a77 100644 --- a/server/src/cache.rs +++ b/server/src/cache.rs @@ -1,4 +1,3 @@ -use crate::err::AppError; use crate::secrets::Secret; use axum::extract::FromRef; use bb8::Pool; @@ -34,12 +33,6 @@ pub enum CacheError { Serde(#[from] serde_json::Error), } -impl From for AppError { - fn from(e: CacheError) -> Self { - AppError::Cache(e) - } -} - impl CachePool { pub async fn new(cache_settings: &CacheSettings) -> Result { tracing::debug!("Connecting to redis"); diff --git a/server/src/err.rs b/server/src/err.rs index 7be1acfe..d91d88b4 100644 --- a/server/src/err.rs +++ b/server/src/err.rs @@ -3,25 +3,53 @@ use axum::http::header::WWW_AUTHENTICATE; use axum::http::{HeaderMap, HeaderValue, StatusCode}; use axum::response::{IntoResponse, Response}; use axum::Json; -use std::fmt::{Debug, Display}; +use sqlx::error::DatabaseError; +use std::fmt::Debug; use std::{borrow::Cow, collections::HashMap}; use tokio::task; #[derive(Debug, thiserror::Error)] pub enum AppError { - SearchError(SearchError), - AuthError(AuthError), - UserError(UserError), - Sqlx(sqlx::Error), - GenericError(color_eyre::eyre::Error), - Cache(CacheError), - Reqwest(reqwest::Error), + #[error(transparent)] + SearchError(#[from] SearchError), + #[error(transparent)] + AuthError(#[from] AuthError), + #[error(transparent)] + UserError(#[from] UserError), + + #[error(transparent)] + Sqlx(#[from] sqlx::Error), + + #[error(transparent)] + GenericError(#[from] color_eyre::eyre::Error), + + #[error(transparent)] + Cache(#[from] CacheError), + + #[error(transparent)] + Reqwest(#[from] reqwest::Error), + + #[error(transparent)] TaskJoin(#[from] task::JoinError), } impl AppError { fn to_error_code(&self) -> String { match self { + AppError::SearchError(SearchError::Sqlx(err)) + | AppError::AuthError(AuthError::Sqlx(err)) + | AppError::UserError(UserError::Sqlx(err)) + | AppError::Sqlx(err) => match err { + sqlx::Error::RowNotFound => "resource_not_found".to_string(), + sqlx::Error::Protocol(_) => "invalid_data".to_string(), + sqlx::Error::Database(db_err) => match db_err.code().as_deref() { + Some("23505") => "unique_key_violation".to_string(), + Some("23503") => "foreign_key_violation".to_string(), + _ => "internal_server_error".to_string(), + }, + _ => "internal_server_error".to_string(), + }, + AppError::SearchError(err) => match err { SearchError::ToxicQuery(_) => format!("toxic_query"), SearchError::InvalidQuery(_) => format!("invalid_query"), @@ -29,29 +57,17 @@ impl AppError { _ => format!("internal_server_error"), }, AppError::UserError(err) => match err { - UserError::NotWhitelisted(_) => "not_whitelisted".to_string(), UserError::InvalidData(_) => "invalid_data".to_string(), UserError::InvalidPassword(_) => "invalid_password".to_string(), + _ => "internal_server_error".to_string(), }, AppError::AuthError(err) => match err { - AuthError::Unauthorized(_) => "unauthorized".to_string(), + AuthError::Unauthorized(_) | AuthError::OAuth2(_) => "unauthorized".to_string(), AuthError::InvalidSession(_) => "invalid_session".to_string(), - AuthError::BackendError(_) => "backend_error".to_string(), - }, - - AppError::Sqlx(err) => match err { - sqlx::Error::RowNotFound => "resource_not_found".to_string(), - sqlx::Error::Protocol(_) => "invalid_data".to_string(), - sqlx::Error::ColumnDecode { - index: _, - source: _, - } => "internal_server_error".to_string(), - sqlx::Error::Database(db_err) => match db_err.code().as_deref() { - Some("23505") => "unique_key_violation".to_string(), - Some("23503") => "foreign_key_violation".to_string(), - _ => "database_error".to_string(), - }, - _ => "database_error".to_string(), + AuthError::NotWhitelisted(_) => "not_whitelisted".to_string(), + AuthError::UserAlreadyExists(_) => "user_already_exists".to_string(), + AuthError::InvalidData(_) => "invalid_data".to_string(), + _ => "internal_server_error".to_string(), }, _ => "internal_server_error".to_string(), } @@ -59,6 +75,20 @@ impl AppError { fn to_status_code(&self) -> StatusCode { match self { + AppError::SearchError(SearchError::Sqlx(err)) + | AppError::AuthError(AuthError::Sqlx(err)) + | AppError::UserError(UserError::Sqlx(err)) + | AppError::Sqlx(err) => match err { + sqlx::Error::RowNotFound => StatusCode::NOT_FOUND, + sqlx::Error::Protocol(_) => StatusCode::BAD_REQUEST, + sqlx::Error::Database(db_err) => match db_err.code().as_deref() { + Some("23505") => StatusCode::CONFLICT, + Some("23503") => StatusCode::BAD_REQUEST, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }, + AppError::SearchError(err) => match err { SearchError::ToxicQuery(_) | SearchError::InvalidQuery(_) => { StatusCode::UNPROCESSABLE_ENTITY @@ -67,25 +97,17 @@ impl AppError { _ => StatusCode::INTERNAL_SERVER_ERROR, }, AppError::AuthError(err) => match err { - AuthError::Unauthorized(_) | AuthError::InvalidSession(_) => { - StatusCode::UNAUTHORIZED - } - AuthError::BackendError(_) => StatusCode::INTERNAL_SERVER_ERROR, + AuthError::Unauthorized(_) + | AuthError::InvalidSession(_) + | AuthError::OAuth2(_) => StatusCode::UNAUTHORIZED, + AuthError::NotWhitelisted(_) => StatusCode::FORBIDDEN, + AuthError::UserAlreadyExists(_) => StatusCode::CONFLICT, + AuthError::InvalidData(_) => StatusCode::UNPROCESSABLE_ENTITY, + _ => StatusCode::INTERNAL_SERVER_ERROR, }, AppError::UserError(err) => match err { - UserError::NotWhitelisted(_) => StatusCode::FORBIDDEN, UserError::InvalidData(_) => StatusCode::UNPROCESSABLE_ENTITY, - UserError::InvalidPassword(_) => StatusCode::UNAUTHORIZED, - }, - - AppError::Sqlx(err) => match err { - sqlx::Error::RowNotFound => StatusCode::NOT_FOUND, - sqlx::Error::Protocol(_) => StatusCode::BAD_REQUEST, - sqlx::Error::Database(db_err) => match db_err.code().as_deref() { - Some("23505") => StatusCode::CONFLICT, - Some("23503") => StatusCode::BAD_REQUEST, - _ => StatusCode::INTERNAL_SERVER_ERROR, - }, + UserError::InvalidPassword(_) => StatusCode::BAD_REQUEST, _ => StatusCode::INTERNAL_SERVER_ERROR, }, _ => StatusCode::INTERNAL_SERVER_ERROR, @@ -93,59 +115,12 @@ impl AppError { } } -impl From for AppError { - fn from(inner: SearchError) -> Self { - AppError::SearchError(inner) - } -} -impl From for AppError { - fn from(inner: AuthError) -> Self { - AppError::AuthError(inner) - } -} -impl From for AppError { - fn from(inner: UserError) -> Self { - AppError::UserError(inner) - } -} - -impl From for AppError { - fn from(inner: color_eyre::eyre::Error) -> Self { - AppError::GenericError(inner) - } -} -impl From for AppError { - fn from(inner: reqwest::Error) -> Self { - AppError::Reqwest(inner) - } -} - -impl From for AppError { - fn from(inner: sqlx::Error) -> Self { - AppError::Sqlx(inner) - } -} impl From for AppError { fn from(inner: sqlx::migrate::MigrateError) -> Self { AppError::Sqlx(inner.into()) } } -impl Display for AppError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - AppError::SearchError(e) => f.write_fmt(format_args!("{:?}", e)), - AppError::AuthError(e) => e.fmt(f), - AppError::UserError(e) => e.fmt(f), - AppError::Sqlx(e) => write!(f, "{}", e), - AppError::GenericError(e) => write!(f, "{}", e), - AppError::Cache(e) => write!(f, "{}", e), - AppError::Reqwest(e) => write!(f, "{}", e), - AppError::TaskJoin(e) => write!(f, "{}", e), - } - } -} - #[derive(serde::Serialize, Debug)] pub struct ErrorMap { errors: HashMap, Cow<'static, str>>, @@ -201,3 +176,78 @@ impl IntoResponse for AppError { } } } + +/// A little helper trait for more easily converting database constraint errors into API errors. +/// +/// ```rust,ignore +/// let user_id = sqlx::query_scalar!( +/// r#"insert into "user" (username, email, password_hash) values ($1, $2, $3) returning user_id"#, +/// username, +/// email, +/// password_hash +/// ) +/// .fetch_one(&pool) +/// .await +/// .on_constraint("user_username_key", |_| Error::unprocessable_entity([("username", "already taken")]))?; +/// ``` +/// +/// Something like this would ideally live in a `sqlx-axum` crate if it made sense to author one, +/// however its definition is tied pretty intimately to the `Error` type, which is itself +/// tied directly to application semantics. +/// +/// To actually make this work in a generic context would make it quite a bit more complex, +/// as you'd need an intermediate error type to represent either a mapped or an unmapped error, +/// and even then it's not clear how to handle `?` in the unmapped case without more boilerplate. +pub trait ResultExt { + /// If `self` contains a SQLx database constraint error with the given name, + /// transform the error. + /// + /// Otherwise, the result is passed through unchanged. + fn on_constraint( + self, + name: &str, + f: impl FnOnce(Box) -> AppError, + ) -> Result; + + fn catch_constraints( + self, + names: &[&str], + map_err: impl FnOnce(Box) -> AppError, + ) -> Result; +} + +impl ResultExt for Result +where + E: Into, +{ + fn on_constraint( + self, + name: &str, + map_err: impl FnOnce(Box) -> AppError, + ) -> Result { + self.map_err(|e| match e.into() { + AppError::Sqlx(sqlx::Error::Database(dbe)) if dbe.constraint() == Some(name) => { + map_err(dbe) + } + e => e, + }) + } + + fn catch_constraints( + self, + names: &[&str], + map_err: impl FnOnce(Box) -> AppError, + ) -> Result { + self.map_err(|e| match e.into() { + AppError::Sqlx(sqlx::Error::Database(dbe)) => { + // TODO: Needs some extra logic to handle multiple constraints. + if names.contains(&dbe.constraint().unwrap_or_default()) { + map_err(dbe) + } else { + AppError::Sqlx(sqlx::Error::Database(dbe)) + } + } + e => e, + }) + } +} diff --git a/server/src/search/api_models.rs b/server/src/search/api_models.rs index 5c99386e..e0aaa51c 100644 --- a/server/src/search/api_models.rs +++ b/server/src/search/api_models.rs @@ -98,21 +98,34 @@ pub struct UpdateThreadRequest { #[derive(Debug, thiserror::Error)] pub enum SearchError { + #[error(transparent)] Reqwest(#[from] reqwest::Error), + + #[error(transparent)] ReqwestHeaderName(#[from] InvalidHeaderName), + + #[error(transparent)] ReqwestHeaderValue(#[from] InvalidHeaderValue), + + #[error(transparent)] Serde(#[from] SerdeError), + + #[error(transparent)] Tonic(#[from] TonicStatus), + + #[error(transparent)] + Sqlx(#[from] sqlx::Error), + + #[error("Toxic query: {0}")] ToxicQuery(String), + #[error("Agency failure: {0}")] InvalidQuery(String), + #[error("No results: {0}")] AgencyFailure(String), + #[error("No sources: {0}")] NoResults(String), + #[error("No sources: {0}")] NoSources(String), + #[error("Other error: {0}")] Other(String), } - -impl std::fmt::Display for SearchError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "SearchError: {}", self) - } -} diff --git a/server/src/search/routes.rs b/server/src/search/routes.rs index 300cc37a..8a91ca50 100644 --- a/server/src/search/routes.rs +++ b/server/src/search/routes.rs @@ -8,7 +8,6 @@ use crate::{settings::Settings, startup::AppState}; use axum::extract::{Query, State}; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::routing::{get, patch}; -use axum::{http::StatusCode, response::IntoResponse}; use axum::{Json, Router}; use futures::{stream::StreamExt, Stream}; use regex::Regex; @@ -110,11 +109,10 @@ async fn get_one_search_result_handler( State(pool): State, user: User, Query(search_by_id_request): Query, -) -> crate::Result { +) -> crate::Result> { let search_history = services::get_one_search(&pool, &user.user_id, &search_by_id_request).await?; - - Ok((StatusCode::OK, Json(search_history))) + Ok(Json(search_history)) } #[tracing::instrument(level = "info", skip_all, ret, err(Debug))] @@ -122,15 +120,14 @@ async fn get_threads_handler( State(pool): State, user: User, Query(thread_history_request): Query, -) -> crate::Result { +) -> crate::Result> { thread_history_request .validate() .map_err(|e| SearchError::InvalidQuery(format!("Invalid thread history request: {}", e)))?; let search_history = services::get_threads(&pool, &user.user_id, &thread_history_request).await?; - - Ok((StatusCode::OK, Json(search_history))) + Ok(Json(search_history)) } #[tracing::instrument(level = "info", skip_all, ret, err(Debug))] @@ -138,14 +135,13 @@ async fn get_one_thread_handler( State(pool): State, user: User, Query(get_thread_request): Query, -) -> crate::Result { +) -> crate::Result> { get_thread_request .validate() .map_err(|e| SearchError::InvalidQuery(format!("Invalid get thread request: {}", e)))?; let search_thread = services::get_one_thread(&pool, &user.user_id, &get_thread_request).await?; - - Ok((StatusCode::OK, Json(search_thread))) + Ok(Json(search_thread)) } #[tracing::instrument(level = "info", skip_all, ret, err(Debug))] @@ -153,14 +149,13 @@ async fn update_thread_handler( State(pool): State, user: User, Json(update_thread_request): Json, -) -> crate::Result { +) -> crate::Result<()> { update_thread_request .validate() .map_err(|e| SearchError::InvalidQuery(format!("Invalid update thread request: {}", e)))?; services::update_thread(&pool, &user.user_id, &update_thread_request).await?; - - Ok(StatusCode::OK) + Ok(()) } #[tracing::instrument(level = "info", skip_all, ret, err(Debug))] @@ -168,10 +163,9 @@ async fn update_search_reaction_handler( State(pool): State, user: User, Json(search_reaction_request): Json, -) -> crate::Result { +) -> crate::Result<()> { services::update_search_reaction(&pool, &user.user_id, &search_reaction_request).await?; - - Ok(StatusCode::OK) + Ok(()) } pub fn routes() -> Router { diff --git a/server/src/search/services.rs b/server/src/search/services.rs index 3bed00bb..f5df93a3 100644 --- a/server/src/search/services.rs +++ b/server/src/search/services.rs @@ -4,13 +4,15 @@ use sqlx::PgPool; use std::collections::HashSet; use uuid::Uuid; +type Result = std::result::Result; + #[tracing::instrument(level = "info", ret, err)] pub async fn insert_new_search( pool: &PgPool, user_id: &Uuid, search_query_request: &api_models::SearchQueryRequest, rephrased_query: &String, -) -> crate::Result { +) -> Result { let thread = match search_query_request.thread_id { Some(thread_id) => { sqlx::query_as!( @@ -53,7 +55,7 @@ pub async fn append_search_result( pool: &PgPool, search: &data_models::Search, result_suffix: &String, -) -> crate::Result { +) -> Result { // Only used by internal services, so no need to check if user_id is the owner of the search let search = sqlx::query_as!( data_models::Search, @@ -72,7 +74,7 @@ pub async fn add_search_sources( pool: &PgPool, search: &data_models::Search, sources: &Vec, -) -> crate::Result> { +) -> Result> { if sources.len() == 0 { return Err(SearchError::NoSources("No sources to add".to_string()).into()); } @@ -125,7 +127,7 @@ pub async fn get_one_search( pool: &PgPool, user_id: &Uuid, search_by_id_request: &api_models::SearchByIdRequest, -) -> crate::Result { +) -> Result { let search = sqlx::query_as!( data_models::Search, "select s.* from searches s \ @@ -155,7 +157,7 @@ pub async fn get_last_n_searches( pool: &PgPool, last_n: u8, thread_id: &Uuid, -) -> crate::Result> { +) -> Result> { // Only used by internal services, so no need to check if user_id is the owner of the search let searches = sqlx::query_as!( data_models::Search, @@ -176,7 +178,7 @@ pub async fn get_threads( pool: &PgPool, user_id: &Uuid, thread_history_request: &api_models::ThreadHistoryRequest, -) -> crate::Result { +) -> Result { let threads = sqlx::query_as!( data_models::Thread, "select * from threads where user_id = $1 order by created_at desc limit $2 offset $3", @@ -195,7 +197,7 @@ pub async fn get_one_thread( pool: &PgPool, user_id: &Uuid, thread_by_id_request: &api_models::GetThreadRequest, -) -> crate::Result { +) -> Result { let thread = sqlx::query_as!( data_models::Thread, "select * from threads where thread_id = $1 and user_id = $2", @@ -250,7 +252,7 @@ pub async fn update_thread( pool: &PgPool, user_id: &Uuid, update_thread_request: &api_models::UpdateThreadRequest, -) -> crate::Result { +) -> Result { let thread = sqlx::query_as!( data_models::Thread, "update threads set title = $1 where thread_id = $2 and user_id = $3 returning *", @@ -269,7 +271,7 @@ pub async fn update_search_reaction( pool: &PgPool, user_id: &Uuid, search_reaction_request: &api_models::SearchReactionRequest, -) -> crate::Result { +) -> Result { let search = sqlx::query_as!( data_models::Search, "update searches s set reaction = $1 from threads t \ diff --git a/server/src/users/api_models.rs b/server/src/users/api_models.rs index dada998e..d55dfc35 100644 --- a/server/src/users/api_models.rs +++ b/server/src/users/api_models.rs @@ -30,9 +30,15 @@ impl UpdateProfileRequest { } } -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum UserError { - NotWhitelisted(String), + #[error("Invalid data: {0}")] InvalidData(String), + #[error("Invalid password: {0}")] InvalidPassword(String), + #[error("Other error: {0}")] + Other(String), + + #[error(transparent)] + Sqlx(#[from] sqlx::Error), } diff --git a/server/src/users/routes.rs b/server/src/users/routes.rs index d2a5b34a..5028cfbc 100644 --- a/server/src/users/routes.rs +++ b/server/src/users/routes.rs @@ -2,7 +2,7 @@ use crate::auth::utils::verify_user_password; use crate::startup::AppState; use crate::users::{api_models, services, User, UserError, UserRecord}; use axum::routing::{get, patch}; -use axum::{extract::State, http::StatusCode, response::IntoResponse, Form, Json, Router}; +use axum::{extract::State, Form, Json, Router}; use sqlx::PgPool; use validator::Validate; @@ -37,7 +37,7 @@ async fn update_password_handler( State(pool): State, user: User, Form(update_password_request): Form, -) -> crate::Result { +) -> crate::Result<()> { let user_id = user.user_id; if update_password_request.old_password.expose() @@ -52,7 +52,7 @@ async fn update_password_handler( Ok(Some(_user)) => { services::update_password(&pool, &user_id, update_password_request.new_password) .await?; - Ok((StatusCode::OK, ())) + Ok(()) } _ => Err(UserError::InvalidPassword(format!("Failed to authenticate old password")).into()), } diff --git a/server/src/users/selectors.rs b/server/src/users/selectors.rs index be6b16e6..8152d574 100644 --- a/server/src/users/selectors.rs +++ b/server/src/users/selectors.rs @@ -1,9 +1,9 @@ -use crate::users::{User, UserRecord}; +use crate::users::{User, UserError, UserRecord}; use sqlx::types::uuid; use sqlx::PgPool; #[tracing::instrument(level = "info", ret, err)] -pub async fn get_user(pool: PgPool, user_id: uuid::Uuid) -> crate::Result> { +pub async fn get_user(pool: PgPool, user_id: uuid::Uuid) -> Result, UserError> { let user = sqlx::query_as!(User, "select * from users where user_id = $1", user_id) .fetch_optional(&pool) .await?; diff --git a/server/src/users/services.rs b/server/src/users/services.rs index 4d282eed..3f5c2818 100644 --- a/server/src/users/services.rs +++ b/server/src/users/services.rs @@ -1,15 +1,17 @@ use crate::auth::utils::hash_password; use crate::secrets::Secret; -use crate::users::{api_models, models}; +use crate::users::{api_models, models, UserError}; use sqlx::PgPool; use uuid::Uuid; +type Result = std::result::Result; + #[tracing::instrument(level = "info", ret, err)] pub async fn update_profile( pool: &PgPool, user_id: &Uuid, update_profile_request: api_models::UpdateProfileRequest, -) -> crate::Result { +) -> Result { let user = sqlx::query_as!( models::User, "