Skip to content

Commit

Permalink
🚧 added more error codes and restructured error sending process
Browse files Browse the repository at this point in the history
  • Loading branch information
rathijitpapon committed Jul 2, 2024
1 parent 0f34fb6 commit d136ff3
Show file tree
Hide file tree
Showing 14 changed files with 278 additions and 166 deletions.
27 changes: 25 additions & 2 deletions server/src/auth/api_models.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,36 @@
use crate::secrets::Secret;
use oauth2::{basic::BasicRequestTokenError, reqwest::AsyncHttpClientError};
use serde::de::Error;
use serde::{Deserialize, Deserializer};
use tokio::task;
use validator::Validate;

#[derive(Debug)]
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
#[error(transparent)]
Sqlx(#[from] sqlx::Error),

#[error(transparent)]
Reqwest(#[from] reqwest::Error),

#[error(transparent)]
OAuth2(#[from] BasicRequestTokenError<AsyncHttpClientError>),

#[error(transparent)]
TaskJoin(#[from] task::JoinError),

#[error("Unauthorized: {0}")]
Unauthorized(String),
#[error("Invalid session: {0}")]
InvalidSession(String),
BackendError(String),
#[error("Other error: {0}")]
UserAlreadyExists(String),
#[error("Other error: {0}")]
InvalidData(String),
#[error("Not whitelisted: {0}")]
NotWhitelisted(String),
#[error("Other error: {0}")]
Other(String),
}

#[derive(Deserialize, Clone, Debug, Validate)]
Expand Down
18 changes: 10 additions & 8 deletions server/src/auth/models.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use crate::auth::{oauth2::OAuth2Client, utils, AuthError};
use crate::err::AppError;
use crate::secrets::Secret;
use crate::telemetry::spawn_blocking_with_tracing;
use crate::users::{User, UserError};
use crate::users::User;
use async_trait::async_trait;
use axum::http::header::{AUTHORIZATION, USER_AGENT};
use axum_login::{AuthnBackend, UserId};
Expand Down Expand Up @@ -39,7 +38,7 @@ impl PostgresBackend {
async fn password_authenticate(
db: &PgPool,
password_credentials: PasswordCredentials,
) -> crate::Result<Option<User>> {
) -> Result<Option<User>, AuthError> {
let user = sqlx::query_as!(
User,
"select * from users where email = $1 and password_hash is not null",
Expand All @@ -61,7 +60,7 @@ async fn oauth_authenticate(
db: &PgPool,
oauth2_clients: &[OAuth2Client],
oauth_creds: OAuthCredentials,
) -> crate::Result<Option<User>> {
) -> Result<Option<User>, AuthError> {
// Ensure the CSRF state has not been tampered with.
if oauth_creds.old_state.secret() != oauth_creds.new_state.secret() {
return Ok(None);
Expand Down Expand Up @@ -111,15 +110,18 @@ async fn oauth_authenticate(
impl AuthnBackend for PostgresBackend {
type User = User;
type Credentials = Credentials;
type Error = AppError;
type Error = AuthError;

#[tracing::instrument(level = "info", skip(self), ret, err)]
async fn authenticate(&self, creds: Self::Credentials) -> crate::Result<Option<Self::User>> {
async fn authenticate(
&self,
creds: Self::Credentials,
) -> Result<Option<Self::User>, AuthError> {
match creds {
Credentials::Password(password_cred) => {
password_cred
.validate()
.map_err(|e| UserError::InvalidData(format!("Invalid credentials: {}", e)))?;
.map_err(|e| AuthError::Unauthorized(format!("Invalid credentials: {}", e)))?;
password_authenticate(&self.db, password_cred).await
}
Credentials::OAuth(oauth_creds) => {
Expand All @@ -129,7 +131,7 @@ impl AuthnBackend for PostgresBackend {
}

#[tracing::instrument(level = "info", skip(self), ret, err)]
async fn get_user(&self, user_id: &UserId<Self>) -> crate::Result<Option<Self::User>> {
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, AuthError> {
sqlx::query_as!(
Self::User,
"select * from users where user_id = $1",
Expand Down
20 changes: 8 additions & 12 deletions server/src/auth/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::auth::{
AuthError, AuthSession, Credentials, OAuthCredentials, PasswordCredentials, RegisterUserRequest,
};
use crate::startup::AppState;
use crate::users::{UserError, UserRecord};
use crate::users::UserRecord;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Redirect, Response};
Expand All @@ -22,10 +22,10 @@ async fn register_handler(
) -> crate::Result<impl IntoResponse> {
request
.validate()
.map_err(|e| UserError::InvalidData(format!("Invalid data: {}", e)))?;
.map_err(|e| AuthError::InvalidData(format!("Invalid data: {}", e)))?;

if !services::is_email_whitelisted(&pool, &request.email).await? {
return Err(UserError::NotWhitelisted(format!("This email is not whitelisted!")).into());
return Err(AuthError::NotWhitelisted(format!("This email is not whitelisted!")).into());
}
services::register(pool, request)
.await
Expand All @@ -43,13 +43,11 @@ async fn login_handler(
{
Ok(Some(user)) => user,
Ok(None) => return Err(AuthError::Unauthorized(format!("Invalid credentials")).into()),
Err(_) => {
return Err(AuthError::BackendError(format!("Could not authenticate user")).into())
}
Err(_) => return Err(AuthError::Other(format!("Could not authenticate user")).into()),
};

if auth_session.login(&user).await.is_err() {
return Err(AuthError::BackendError(format!("Could not login user")).into());
return Err(AuthError::Other(format!("Could not login user")).into());
}
//if let Credentials::Password(_pw_creds) = creds {
// if let Some(ref next) = pw_creds.next {
Expand Down Expand Up @@ -123,13 +121,11 @@ pub async fn oauth_callback_handler(
let user = match auth_session.authenticate(creds).await {
Ok(Some(user)) => user,
Ok(None) => return Err(AuthError::Unauthorized(format!("Invalid credentials")).into()),
Err(_) => {
return Err(AuthError::BackendError(format!("Could not authenticate user")).into())
}
Err(_) => return Err(AuthError::Other(format!("Could not authenticate user")).into()),
};

if auth_session.login(&user).await.is_err() {
return Err(AuthError::BackendError(format!("Could not login user")).into());
return Err(AuthError::Other(format!("Could not login user")).into());
}

if let Ok(Some(next)) = session.remove::<String>(NEXT_URL_KEY).await {
Expand All @@ -144,7 +140,7 @@ async fn logout_handler(mut auth_session: AuthSession) -> crate::Result<()> {
auth_session
.logout()
.await
.map_err(|e| AuthError::BackendError(format!("Could not logout user: {}", e)))?;
.map_err(|e| AuthError::Other(format!("Could not logout user: {}", e)))?;

Ok(())
}
Expand Down
40 changes: 35 additions & 5 deletions server/src/auth/services.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::auth::AuthError;
use crate::auth::{api_models, models, utils};
use crate::users::{User, UserError, UserRecord};
use crate::err::ResultExt;
use crate::users::{User, UserRecord};
use sqlx::PgPool;

#[tracing::instrument(level = "info", ret, err)]
Expand All @@ -17,7 +19,21 @@ pub async fn register(
password_hash.expose()
)
.fetch_one(&pool)
.await?;
.await
.on_constraint("users_username_key", |_| {
AuthError::UserAlreadyExists(format!(
"User with username {} already exists",
request.email
))
.into()
})
.on_constraint("users_email_key", |_| {
AuthError::UserAlreadyExists(format!(
"User with email {} already exists",
request.email
))
.into()
})?;

return Ok(user.into());
} else if let Some(access_token) = request.access_token {
Expand All @@ -29,18 +45,32 @@ pub async fn register(
access_token.expose()
)
.fetch_one(&pool)
.await?;
.await
.on_constraint("users_username_key", |_| {
AuthError::UserAlreadyExists(format!(
"User with username {} already exists",
request.email
))
.into()
})
.on_constraint("users_email_key", |_| {
AuthError::UserAlreadyExists(format!(
"User with email {} already exists",
request.email
))
.into()
})?;

return Ok(user.into());
}

Err(UserError::InvalidData(format!(
Err(AuthError::InvalidData(format!(
"Either password or access_token must be provided to create a user"
))
.into())
}

pub async fn is_email_whitelisted(pool: &PgPool, email: &String) -> crate::Result<bool> {
pub async fn is_email_whitelisted(pool: &PgPool, email: &String) -> Result<bool, AuthError> {
let whitelisted_email = sqlx::query_as!(
models::WhitelistedEmail,
"SELECT * FROM whitelisted_emails WHERE email = $1",
Expand Down
7 changes: 4 additions & 3 deletions server/src/auth/utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::auth::AuthError;
use crate::secrets::Secret;
use crate::users::User;
use password_auth::{generate_hash, verify_password};
Expand All @@ -7,7 +8,7 @@ use tokio::task::spawn_blocking;
pub fn verify_user_password(
user: Option<User>,
password_candidate: Secret<String>,
) -> crate::Result<Option<User>> {
) -> Result<Option<User>, AuthError> {
// password-based authentication. Compare our form input with an argon2
// password hash.
// To prevent timed side-channel attacks, so we always compare the password
Expand All @@ -31,7 +32,7 @@ pub fn verify_user_password(
}

// Prevent side-channel attacks by always verifying the password.
pub fn dummy_verify_password(pw: Secret<impl AsRef<[u8]>>) -> crate::Result<Option<User>> {
pub fn dummy_verify_password(pw: Secret<impl AsRef<[u8]>>) -> Result<Option<User>, AuthError> {
let _ = verify_password(
pw.expose_owned().as_ref(),
"$argon2id$v=19$m=15000,t=2,p=1$\
Expand All @@ -42,6 +43,6 @@ pub fn dummy_verify_password(pw: Secret<impl AsRef<[u8]>>) -> crate::Result<Opti
Ok(None)
}

pub async fn hash_password(password: Secret<String>) -> crate::Result<Secret<String>> {
pub async fn hash_password(password: Secret<String>) -> Result<Secret<String>, AuthError> {
Ok(spawn_blocking(move || Secret::new(generate_hash(password.expose().as_bytes()))).await?)
}
7 changes: 0 additions & 7 deletions server/src/cache.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::err::AppError;
use crate::secrets::Secret;
use axum::extract::FromRef;
use bb8::Pool;
Expand Down Expand Up @@ -34,12 +33,6 @@ pub enum CacheError {
Serde(#[from] serde_json::Error),
}

impl From<CacheError> for AppError {
fn from(e: CacheError) -> Self {
AppError::Cache(e)
}
}

impl CachePool {
pub async fn new(cache_settings: &CacheSettings) -> Result<Self, CacheError> {
tracing::debug!("Connecting to redis");
Expand Down
Loading

0 comments on commit d136ff3

Please sign in to comment.