Skip to content

Commit

Permalink
Roll custom secrecy impl with sqlx support. Add create_user service
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Mar 25, 2024
1 parent 38c35a2 commit 551da94
Show file tree
Hide file tree
Showing 13 changed files with 209 additions and 51 deletions.
29 changes: 18 additions & 11 deletions server/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 2 additions & 6 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,12 @@ oauth2 = "4.4.2"
once_cell = "1.19.0"
password-auth = "1.0.0"
reqwest = { version = "0.12.1", features = ["json"] }
secrecy = { version = "0.8.0", features = ["serde"] }
serde = "1.0.197"
serde_json = "1.0.114"
sqlx = { version = "0.7.4", features = ["postgres", "runtime-tokio", "tls-rustls"] }
sqlx = { version = "0.7.4", features = ["postgres", "runtime-tokio", "tls-rustls", "migrate", "uuid", "time"] }
thiserror = "1.0.58"
tokio = { version = "1.36.0", features = ["full"] }
tower-http = { version = "0.5.2", features = ["trace", "cors"] }
tracing = "0.1.40"
tower = "0.4.13"

# Compile-time verifying queries does not need optimizations
[profile.dev.package.sqlx-macros]
opt-level = 3
uuid = { version = "1.8.0", features = ["serde"] }
1 change: 0 additions & 1 deletion server/config/dev.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

db = "postgresql://postgres:postgres@localhost/curieo"

[log]
Expand Down
2 changes: 1 addition & 1 deletion server/src/health_check/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub use routes::*;

mod handlers;
mod routes;
mod selectors;
4 changes: 2 additions & 2 deletions server/src/health_check/routes.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use axum::routing::get;
use axum::Router;
use axum::routing::get;

use crate::health_check::handlers::health_check;
use crate::health_check::selectors::health_check;
use crate::startup::AppState;

pub fn routes() -> Router<AppState> {
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use startup::Application;
mod err;
mod health_check;
pub mod routing;
pub mod secrets;
pub mod settings;
pub mod startup;
pub mod users;
Expand Down
84 changes: 84 additions & 0 deletions server/src/secrets.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use std::error::Error;
use std::fmt::{Debug, Display};

use serde::{Deserialize, Deserializer};
use sqlx::{Decode, Postgres};
use sqlx::database::HasValueRef;

/// A wrapper around a value that should be kept secret
/// when displayed. This is useful for fields like passwords
/// and access tokens. The value is redacted when displayed
/// or debugged.
#[derive(Default, Clone)]
pub struct Secret<T>(pub T)
where
T: Default + Clone;

impl<T> Display for Secret<T>
where
T: Default + Clone + Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[redacted]")
}
}

impl<T> Debug for Secret<T>
where
T: Default + Clone + Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[redacted]")
}
}

impl<T> sqlx::Type<Postgres> for Secret<T>
where
T: Default + Clone + sqlx::Type<Postgres>,
{
fn type_info() -> sqlx::postgres::PgTypeInfo {
<T as sqlx::Type<Postgres>>::type_info()
}
}

impl<T> sqlx::Decode<'_, Postgres> for Secret<T>
where
for<'a> T: sqlx::Type<Postgres> + sqlx::Decode<'a, Postgres> + Default + Clone,
{
fn decode(
value: <Postgres as HasValueRef<'_>>::ValueRef,
) -> Result<Self, Box<dyn Error + 'static + Send + Sync>> {
let value = <T as Decode<Postgres>>::decode(value)?;
Ok(Secret(value))
}
}

impl<'de, T> Deserialize<'de> for Secret<T>
where
T: Deserialize<'de> + Default + Clone + Debug,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
T::deserialize(deserializer).map(Secret)
}
}

impl<T> AsRef<T> for Secret<T>
where
T: Default + Clone,
{
fn as_ref(&self) -> &T {
&self.0
}
}

impl<T> From<T> for Secret<T>
where
T: Default + Clone,
{
fn from(s: T) -> Self {
Self(s)
}
}
5 changes: 4 additions & 1 deletion server/src/users/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
mod models;
pub use models::*;

mod models;
mod selectors;
mod services;
63 changes: 34 additions & 29 deletions server/src/users/models.rs
Original file line number Diff line number Diff line change
@@ -1,39 +1,44 @@
use std::fmt::Debug;

use async_trait::async_trait;
use axum::http::header::{AUTHORIZATION, USER_AGENT};
use axum_login::{AuthUser, AuthnBackend, UserId};
use axum_login::{AuthnBackend, AuthUser, UserId};
use oauth2::{
AuthorizationCode,
basic::{BasicClient, BasicRequestTokenError},
reqwest::{async_http_client, AsyncHttpClientError},
url::Url,
AuthorizationCode, CsrfToken, TokenResponse,
CsrfToken,
reqwest::{async_http_client, AsyncHttpClientError}, TokenResponse, url::Url,
};
use password_auth::verify_password;
use serde::{Deserialize, Serialize};
use sqlx::{FromRow, PgPool};
use sqlx::types::time;
use tokio::task;

#[derive(sqlx::Type, Deserialize, Clone)]
#[sqlx(transparent)]
pub struct SecretString(String);
use crate::secrets::Secret;

impl AsRef<str> for SecretString {
fn as_ref(&self) -> &str {
&self.0
}
}
#[derive(sqlx::FromRow, Clone, Deserialize, Debug)]
pub struct User {
pub user_id: uuid::Uuid,
pub email: String,
pub username: String,
#[sqlx(default)]
pub password_hash: Secret<Option<String>>,
#[sqlx(default)]
pub access_token: Secret<Option<String>>,

impl std::fmt::Debug for SecretString {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.pad("[redacted]")
}
pub created_at: time::OffsetDateTime,
pub updated_at: Option<time::OffsetDateTime>,
}

#[derive(Clone, Deserialize, Debug, FromRow)]
pub struct User {
id: i64,
#[derive(sqlx::Encode, Clone, Deserialize, Debug)]
pub struct CreateUser {
pub email: String,
pub username: String,
pub password: Option<SecretString>,
pub access_token: Option<SecretString>,
#[sqlx(default)]
pub password_hash: Secret<Option<String>>,
#[sqlx(default)]
pub access_token: Secret<Option<String>>,
}

#[derive(Clone, Serialize, Deserialize, FromRow)]
Expand All @@ -43,19 +48,19 @@ pub struct UserOut {
}

impl AuthUser for User {
type Id = i64;
type Id = uuid::Uuid;

fn id(&self) -> Self::Id {
self.id
self.user_id
}

fn session_auth_hash(&self) -> &[u8] {
if let Some(access_token) = &self.access_token {
return access_token.as_ref().as_bytes();
if let Some(access_token) = &self.access_token.as_ref() {
return access_token.as_bytes();
}

if let Some(password) = &self.password {
return password.as_ref().as_bytes();
if let Some(password) = &self.password_hash.as_ref() {
return password.as_bytes();
}

&[]
Expand All @@ -71,7 +76,7 @@ pub enum Credentials {
#[derive(Debug, Clone, Deserialize)]
pub struct PasswordCreds {
pub username: String,
pub password: SecretString,
pub password: Secret<String>,
pub next: Option<String>,
}

Expand Down Expand Up @@ -146,7 +151,7 @@ impl AuthnBackend for PostgresBackend {
// We're using password-based authentication: this works by comparing our form
// input with an argon2 password hash.
Ok(user.filter(|user| {
let Some(ref password) = user.password else {
let Some(password) = user.password_hash.as_ref() else {
return false;
};
verify_password(password_cred.password.as_ref(), password.as_ref()).is_ok()
Expand Down
18 changes: 18 additions & 0 deletions server/src/users/selectors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use axum::extract::State;
use sqlx::{Acquire, PgPool};
use sqlx::types::uuid;

use crate::users::User;

#[tracing::instrument(level = "debug", ret, err)]
async fn get_user(
State(pool): State<PgPool>,
user_id: uuid::Uuid,
) -> color_eyre::Result<Option<User>> {
let mut conn = pool.acquire().await?;
let user = sqlx::query_as!(User, "SELECT * FROM users WHERE user_id = $1", user_id)
.fetch_optional(conn.acquire().await?)
.await?;

Ok(user)
}
Loading

0 comments on commit 551da94

Please sign in to comment.