diff --git a/src/api/mod.rs b/src/api/mod.rs index 7091a09..007f89c 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -3,13 +3,13 @@ use crate::auth0::Client; use crate::config::Config; -use dashmap::DashMap; use ::stripe::Client as StripeClient; use ::stripe::RequestStrategy::ExponentialBackoff; use anyhow::Result; use axum::{middleware, Router}; +use dashmap::DashMap; use sea_orm::{ConnectOptions, Database, DatabaseConnection}; -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use tokio::{select, signal}; use tower_default_headers::DefaultHeadersLayer; @@ -33,7 +33,7 @@ mod users; pub struct ApiContext { db: DatabaseConnection, config: Config, - rate_limit: DashMap, + rate_limit: DashMap, stripe_client: StripeClient, auth0_client: Client, } diff --git a/src/api/ratelimit.rs b/src/api/ratelimit.rs index 3a6feec..d96f29b 100644 --- a/src/api/ratelimit.rs +++ b/src/api/ratelimit.rs @@ -1,4 +1,8 @@ -use std::{sync::Arc, time::Instant}; +use std::{ + net::{IpAddr, Ipv4Addr}, + sync::Arc, + time::Instant, +}; use axum::{extract::State, http::Request, middleware::Next, response::IntoResponse}; @@ -6,6 +10,10 @@ use crate::error::Error; use super::ApiContext; +const TAKE_RATE: u8 = 1; +const IP_HEADER: &str = "X-Real-IP"; +const DEFAULT_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); + #[derive(Debug, Clone, Copy)] pub struct TokenBucket { capacity: u8, @@ -24,10 +32,10 @@ impl TokenBucket { } } - fn take(&mut self, tokens: u8) -> bool { + fn take(&mut self) -> bool { self.update(); - if self.available_tokens >= tokens { - self.available_tokens -= tokens; + if self.available_tokens >= TAKE_RATE { + self.available_tokens -= TAKE_RATE; true } else { false @@ -54,16 +62,17 @@ pub async fn limiter( ) -> Result { let ip = req .headers() - .get("Fly-Client-IP") + .get(IP_HEADER) .and_then(|v| v.to_str().ok()) - .unwrap_or("127.0.0.1"); - let mut bucket = ctx.rate_limit.entry(ip.to_string()).or_insert_with(|| { + .and_then(|v| v.parse().ok()) + .unwrap_or(DEFAULT_IP); + let mut bucket = ctx.rate_limit.entry(ip).or_insert_with(|| { TokenBucket::new( ctx.config.rate_limit_capacity, ctx.config.rate_limit_fill_rate, ) }); - if bucket.take(1) { + if bucket.take() { Ok(next.run(req).await) } else { Err(Error::TooManyRequests)