Skip to content

Commit

Permalink
Use a high level Arc for ApiContext to avoid multiple deep clones
Browse files Browse the repository at this point in the history
  • Loading branch information
jakewmeyer committed Sep 8, 2023
1 parent 42f3c65 commit 07a2239
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 48 deletions.
24 changes: 13 additions & 11 deletions src/api/accounts.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use axum::{
extract::{rejection::PathRejection, Path, State},
http::StatusCode,
Expand Down Expand Up @@ -25,7 +27,7 @@ pub struct UpdateAccount {
pub provider_id: Option<String>,
}

pub fn routes() -> Router<ApiContext> {
pub fn routes() -> Router<Arc<ApiContext>> {
Router::new()
.route("/v1/accounts", get(list_accounts_handler))
.route("/v1/accounts/:id", get(get_account_by_id_handler))
Expand All @@ -35,7 +37,7 @@ pub fn routes() -> Router<ApiContext> {
}

pub async fn list_accounts(
ctx: &ApiContext,
ctx: &Arc<ApiContext>,
page: &Pagination,
) -> Result<Vec<accounts::Model>, Error> {
let accounts = Accounts::find()
Expand All @@ -47,22 +49,22 @@ pub async fn list_accounts(
Ok(accounts)
}

pub async fn get_account_by_id(ctx: &ApiContext, id: Uuid) -> Result<accounts::Model, Error> {
pub async fn get_account_by_id(ctx: &Arc<ApiContext>, id: Uuid) -> Result<accounts::Model, Error> {
Accounts::find_by_id(id)
.one(&ctx.db)
.await?
.ok_or(Error::NotFound)
}

pub async fn create_account(ctx: &ApiContext) -> Result<accounts::Model, Error> {
pub async fn create_account(ctx: &Arc<ApiContext>) -> Result<accounts::Model, Error> {
let account = accounts::ActiveModel {
..Default::default()
};
let account = account.insert(&ctx.db).await?;
Ok(account)
}

pub async fn delete_account(ctx: &ApiContext, id: Uuid) -> Result<accounts::Model, Error> {
pub async fn delete_account(ctx: &Arc<ApiContext>, id: Uuid) -> Result<accounts::Model, Error> {
let account = Accounts::find_by_id(id).one(&ctx.db).await?;
let account = account.ok_or(Error::NotFound)?;
let mut account: accounts::ActiveModel = account.into();
Expand All @@ -72,7 +74,7 @@ pub async fn delete_account(ctx: &ApiContext, id: Uuid) -> Result<accounts::Mode
}

pub async fn list_account_users(
ctx: &ApiContext,
ctx: &Arc<ApiContext>,
id: Uuid,
page: &Pagination,
) -> Result<Vec<users::Model>, Error> {
Expand All @@ -90,7 +92,7 @@ pub async fn list_account_users(

pub async fn list_accounts_handler(
user: AuthUser,
State(ctx): State<ApiContext>,
State(ctx): State<Arc<ApiContext>>,
page: Pagination,
) -> Result<impl IntoResponse, Error> {
user.has_permission("list:account")?;
Expand All @@ -100,7 +102,7 @@ pub async fn list_accounts_handler(

pub async fn create_account_handler(
user: AuthUser,
State(ctx): State<ApiContext>,
State(ctx): State<Arc<ApiContext>>,
) -> Result<impl IntoResponse, Error> {
user.has_permission("create:account")?;
let created = create_account(&ctx).await?;
Expand All @@ -109,7 +111,7 @@ pub async fn create_account_handler(

pub async fn get_account_by_id_handler(
user: AuthUser,
State(ctx): State<ApiContext>,
State(ctx): State<Arc<ApiContext>>,
account_id: Result<Path<Uuid>, PathRejection>,
) -> Result<impl IntoResponse, Error> {
user.has_permission("retrieve:account")?;
Expand All @@ -120,7 +122,7 @@ pub async fn get_account_by_id_handler(

pub async fn delete_account_handler(
user: AuthUser,
State(ctx): State<ApiContext>,
State(ctx): State<Arc<ApiContext>>,
account_id: Result<Path<Uuid>, PathRejection>,
) -> Result<impl IntoResponse, Error> {
user.has_permission("delete:account")?;
Expand All @@ -131,7 +133,7 @@ pub async fn delete_account_handler(

pub async fn list_account_users_handler(
user: AuthUser,
State(ctx): State<ApiContext>,
State(ctx): State<Arc<ApiContext>>,
account_id: Result<Path<Uuid>, PathRejection>,
page: Pagination,
) -> Result<impl IntoResponse, Error> {
Expand Down
8 changes: 5 additions & 3 deletions src/api/auth.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use axum::{
async_trait,
extract::{FromRef, FromRequestParts},
Expand All @@ -21,7 +23,7 @@ pub struct AuthUser {

impl AuthUser {
async fn from_authorization(
ctx: &ApiContext,
ctx: &Arc<ApiContext>,
auth_header: &HeaderValue,
) -> Result<Self, Error> {
let token = auth_header
Expand Down Expand Up @@ -71,12 +73,12 @@ impl AuthUser {
#[async_trait]
impl<S> FromRequestParts<S> for AuthUser
where
ApiContext: FromRef<S>,
Arc<ApiContext>: FromRef<S>,
S: Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let ctx = ApiContext::from_ref(state);
let ctx = Arc::from_ref(state);
if let Some(auth_header) = parts.headers.get(header::AUTHORIZATION) {
Ok(Self::from_authorization(&ctx, auth_header).await?)
} else {
Expand Down
37 changes: 21 additions & 16 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,19 @@ use ::stripe::Client as StripeClient;
use ::stripe::RequestStrategy::ExponentialBackoff;
use anyhow::Result;
use axum::{middleware, Router};
use sea_orm::{Database, DatabaseConnection};
use tower_http::timeout::TimeoutLayer;
use sea_orm::{ConnectOptions, Database, DatabaseConnection};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::signal;
use tokio::{signal, select};
use tokio::sync::Mutex;
use tower_default_headers::DefaultHeadersLayer;
use tower_http::compression::CompressionLayer;
use tower_http::cors::CorsLayer;
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
use tower_http::timeout::TimeoutLayer;
use tower_http::trace::TraceLayer;
use tracing::info;
use tracing::{info, error};

use self::ratelimit::TokenBucket;

Expand All @@ -32,17 +31,16 @@ mod ratelimit;
mod stripe;
mod users;

#[derive(Clone)]
pub struct ApiContext {
db: DatabaseConnection,
config: Config,
rate_limit: Arc<Mutex<HashMap<String, TokenBucket>>>,
rate_limit: Mutex<HashMap<String, TokenBucket>>,
stripe_client: StripeClient,
auth0_client: Client,
}

/// Creates a signal handler for graceful shutdown.
async fn shutdown_signal() {
async fn shutdown_signal(ctx: Arc<ApiContext>) {
// Handle SIGINT
let ctrl_c = async {
signal::ctrl_c()
Expand All @@ -62,21 +60,27 @@ async fn shutdown_signal() {
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();

tokio::select! {
select! {
_ = ctrl_c => {},
_ = terminate => {},
}

// Any other graceful shutdow logic goes here
info!("Signal received, starting graceful shutdown...");
ctx.db.clone().close().await.unwrap_or_else(|e| {
error!("Failed to close database connection: {}", e);
});
info!("Graceful shutdown complete");
}

/// Create and serve an Axum server with pre-registered routes
/// and middleware
pub async fn serve(config: Config) -> Result<()> {
let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;

let db = Database::connect(&config.database_url).await?;
let mut opts = ConnectOptions::new(&config.database_url);
opts.connect_timeout(config.database_timeout);
let db = Database::connect(opts).await?;

let stripe_client =
StripeClient::new(&config.stripe_secret_key).with_strategy(ExponentialBackoff(5));
Expand All @@ -88,20 +92,20 @@ pub async fn serve(config: Config) -> Result<()> {
);
auth0_client.load_jwk().await?;

let state = ApiContext {
let state = Arc::new(ApiContext {
config: config.clone(),
db,
rate_limit: Arc::new(Mutex::new(HashMap::new())),
rate_limit: Mutex::new(HashMap::new()),
stripe_client,
auth0_client,
};
});

let app = Router::new()
.merge(public::routes())
.merge(accounts::routes())
.merge(users::routes())
.merge(stripe::routes())
.layer(TimeoutLayer::new(Duration::from_secs(config.request_timeout)))
.layer(TimeoutLayer::new(config.request_timeout))
.layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
.layer(middleware::from_fn_with_state(
Expand All @@ -112,12 +116,13 @@ pub async fn serve(config: Config) -> Result<()> {
.layer(PropagateRequestIdLayer::x_request_id())
.layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
.layer(DefaultHeadersLayer::new(owasp_headers::headers()))
.with_state(state);
.with_state(state.clone());

info!("Listening on {}", addr);
axum::Server::try_bind(&addr)?
.http1_header_read_timeout(config.request_timeout)
.serve(app.into_make_service())
.with_graceful_shutdown(shutdown_signal())
.with_graceful_shutdown(shutdown_signal(state.clone()))
.await?;
Ok(())
}
4 changes: 3 additions & 1 deletion src/api/public.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::sync::Arc;

use axum::{http::StatusCode, response::IntoResponse, routing::get, Router};

use super::ApiContext;

pub fn routes() -> Router<ApiContext> {
pub fn routes() -> Router<Arc<ApiContext>> {
Router::new().route("/health", get(healthcheck))
}

Expand Down
4 changes: 2 additions & 2 deletions src/api/ratelimit.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::time::Instant;
use std::{time::Instant, sync::Arc};

use axum::{extract::State, http::Request, middleware::Next, response::IntoResponse};

Expand Down Expand Up @@ -48,7 +48,7 @@ impl TokenBucket {
}

pub async fn limiter<B>(
State(ctx): State<ApiContext>,
State(ctx): State<Arc<ApiContext>>,
req: Request<B>,
next: Next<B>,
) -> Result<impl IntoResponse, Error> {
Expand Down
6 changes: 3 additions & 3 deletions src/api/stripe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@ use crate::error::Error;

use super::ApiContext;

pub fn routes() -> Router<ApiContext> {
pub fn routes() -> Router<Arc<ApiContext>> {
Router::new().route("/v1/stripe/webhooks", post(stripe_webhook_handler))
}

// Handler for GET /v1/stripe/webhooks
async fn stripe_webhook_handler(
State(ctx): State<ApiContext>,
State(ctx): State<Arc<ApiContext>>,
headers: HeaderMap,
body: String,
) -> Result<impl IntoResponse, Error> {
let stripe_signature = headers
.get("stripe-signature")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let stripe_webhook_secret = ctx.config.stripe_webhook_secret;
let stripe_webhook_secret = &ctx.config.stripe_webhook_secret;
let event = Webhook::construct_event(&body, stripe_signature, &stripe_webhook_secret)?;
let _event = Arc::new(event);
Ok(StatusCode::OK)
Expand Down
Loading

0 comments on commit 07a2239

Please sign in to comment.