Skip to content

Commit

Permalink
Improve Secret wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Mar 25, 2024
1 parent 551da94 commit e870efd
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 30 deletions.
2 changes: 1 addition & 1 deletion server/src/health_check/routes.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use axum::Router;
use axum::routing::get;
use axum::Router;

use crate::health_check::selectors::health_check;
use crate::startup::AppState;
Expand Down
25 changes: 14 additions & 11 deletions server/src/secrets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,30 @@ use std::error::Error;
use std::fmt::{Debug, Display};

use serde::{Deserialize, Deserializer};
use sqlx::{Decode, Postgres};
use sqlx::database::HasValueRef;
use sqlx::{Decode, Postgres};

/// A wrapper around a value that should be kept secret
/// when displayed. This is useful for fields like passwords
/// and access tokens. The value is redacted when displayed
/// or debugged.
#[derive(Default, Clone)]
pub struct Secret<T>(pub T)
pub struct Secret<T>(T)
where
T: Default + Clone;

impl<T> Secret<T>
where
T: Default + Clone,
{
pub fn expose(&self) -> &T {
&self.0
}
pub fn expose_owned(self) -> T {
self.0
}
}

impl<T> Display for Secret<T>
where
T: Default + Clone + Display,
Expand Down Expand Up @@ -65,15 +77,6 @@ where
}
}

impl<T> AsRef<T> for Secret<T>
where
T: Default + Clone,
{
fn as_ref(&self) -> &T {
&self.0
}
}

impl<T> From<T> for Secret<T>
where
T: Default + Clone,
Expand Down
5 changes: 3 additions & 2 deletions server/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ use std::{env, fmt::Display};
use config::{Config, Environment, File};
use dotenvy::dotenv;
use once_cell::sync::Lazy;
use secrecy::SecretString;
use serde::{Deserialize, Deserializer};

use crate::secrets::Secret;

#[derive(Debug, Clone)]
pub enum LogFmt {
Json,
Expand Down Expand Up @@ -76,7 +77,7 @@ pub struct Settings {
pub log: Log,
pub host: String,
pub port: u16,
pub db: SecretString,
pub db: Secret<String>,
}

impl Settings {
Expand Down
3 changes: 1 addition & 2 deletions server/src/startup.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use axum::{extract::FromRef, routing::IntoMakeService, serve::Serve, Router};
use color_eyre::eyre::eyre;
use color_eyre::Result;
use secrecy::ExposeSecret;
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use tokio::net::TcpListener;
Expand Down Expand Up @@ -53,7 +52,7 @@ pub async fn db_connect(database_url: &str) -> Result<PgPool> {

async fn run(listener: TcpListener) -> Result<Serve<IntoMakeService<Router>, Router>> {
let settings = Settings::new();
let db = db_connect(settings.db.expose_secret()).await?;
let db = db_connect(settings.db.expose()).await?;

let state = AppState { db, settings };

Expand Down
18 changes: 9 additions & 9 deletions server/src/users/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@ use std::fmt::Debug;

use async_trait::async_trait;
use axum::http::header::{AUTHORIZATION, USER_AGENT};
use axum_login::{AuthnBackend, AuthUser, UserId};
use axum_login::{AuthUser, AuthnBackend, UserId};
use oauth2::{
AuthorizationCode,
basic::{BasicClient, BasicRequestTokenError},
CsrfToken,
reqwest::{async_http_client, AsyncHttpClientError}, TokenResponse, url::Url,
reqwest::{async_http_client, AsyncHttpClientError},
url::Url,
AuthorizationCode, CsrfToken, TokenResponse,
};
use password_auth::verify_password;
use serde::{Deserialize, Serialize};
use sqlx::{FromRow, PgPool};
use sqlx::types::time;
use sqlx::{FromRow, PgPool};
use tokio::task;

use crate::secrets::Secret;
Expand Down Expand Up @@ -55,11 +55,11 @@ impl AuthUser for User {
}

fn session_auth_hash(&self) -> &[u8] {
if let Some(access_token) = &self.access_token.as_ref() {
if let Some(access_token) = &self.access_token.expose() {
return access_token.as_bytes();
}

if let Some(password) = &self.password_hash.as_ref() {
if let Some(password) = &self.password_hash.expose() {
return password.as_bytes();
}

Expand Down Expand Up @@ -151,10 +151,10 @@ impl AuthnBackend for PostgresBackend {
// We're using password-based authentication: this works by comparing our form
// input with an argon2 password hash.
Ok(user.filter(|user| {
let Some(password) = user.password_hash.as_ref() else {
let Some(password) = user.password_hash.expose() else {
return false;
};
verify_password(password_cred.password.as_ref(), password.as_ref()).is_ok()
verify_password(password_cred.password.expose(), password.as_ref()).is_ok()
}))
})
.await?
Expand Down
2 changes: 1 addition & 1 deletion server/src/users/selectors.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use axum::extract::State;
use sqlx::{Acquire, PgPool};
use sqlx::types::uuid;
use sqlx::{Acquire, PgPool};

use crate::users::User;

Expand Down
4 changes: 2 additions & 2 deletions server/src/users/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::users::{CreateUser, User};
async fn create_user(pool: PgPool, create_user: CreateUser) -> color_eyre::Result<Option<User>> {
let mut conn = pool.acquire().await?;

if let Some(password_hash) = create_user.password_hash.as_ref() {
if let Some(password_hash) = create_user.password_hash.expose() {
let user = sqlx::query_as!(
User,
"INSERT INTO users (email, username, password_hash) VALUES ($1, $2, $3) RETURNING *",
Expand All @@ -19,7 +19,7 @@ async fn create_user(pool: PgPool, create_user: CreateUser) -> color_eyre::Resul
.await?;

return Ok(user);
} else if let Some(access_token) = create_user.access_token.as_ref() {
} else if let Some(access_token) = create_user.access_token.expose() {
let user = sqlx::query_as!(
User,
"INSERT INTO users (email, username, access_token) VALUES ($1, $2, $3) RETURNING *",
Expand Down
3 changes: 1 addition & 2 deletions server/tests/health_check.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use axum::body::Body;
use axum::http::{Request, StatusCode};
use secrecy::ExposeSecret;
use tower::ServiceExt;

use server::routing::router;
Expand All @@ -11,7 +10,7 @@ use server::startup::{db_connect, AppState};
async fn health_check_works() {
let settings = Settings::new();

let db = db_connect(settings.db.expose_secret())
let db = db_connect(settings.db.expose())
.await
.expect("Failed to connect to Postgres.");

Expand Down

0 comments on commit e870efd

Please sign in to comment.