Skip to content

Commit

Permalink
use IpAddr type for ratelimit map key instead of String
Browse files Browse the repository at this point in the history
  • Loading branch information
jakewmeyer committed Sep 27, 2023
1 parent 75a7b90 commit db16007
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
6 changes: 3 additions & 3 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

use crate::auth0::Client;
use crate::config::Config;
use dashmap::DashMap;
use ::stripe::Client as StripeClient;
use ::stripe::RequestStrategy::ExponentialBackoff;
use anyhow::Result;
use axum::{middleware, Router};
use dashmap::DashMap;
use sea_orm::{ConnectOptions, Database, DatabaseConnection};
use std::net::SocketAddr;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use tokio::{select, signal};
use tower_default_headers::DefaultHeadersLayer;
Expand All @@ -33,7 +33,7 @@ mod users;
pub struct ApiContext {
db: DatabaseConnection,
config: Config,
rate_limit: DashMap<String, TokenBucket>,
rate_limit: DashMap<IpAddr, TokenBucket>,
stripe_client: StripeClient,
auth0_client: Client,
}
Expand Down
25 changes: 17 additions & 8 deletions src/api/ratelimit.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
use std::{sync::Arc, time::Instant};
use std::{
net::{IpAddr, Ipv4Addr},
sync::Arc,
time::Instant,
};

use axum::{extract::State, http::Request, middleware::Next, response::IntoResponse};

use crate::error::Error;

use super::ApiContext;

const TAKE_RATE: u8 = 1;
const IP_HEADER: &str = "X-Real-IP";
const DEFAULT_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

#[derive(Debug, Clone, Copy)]
pub struct TokenBucket {
capacity: u8,
Expand All @@ -24,10 +32,10 @@ impl TokenBucket {
}
}

fn take(&mut self, tokens: u8) -> bool {
fn take(&mut self) -> bool {
self.update();
if self.available_tokens >= tokens {
self.available_tokens -= tokens;
if self.available_tokens >= TAKE_RATE {
self.available_tokens -= TAKE_RATE;
true
} else {
false
Expand All @@ -54,16 +62,17 @@ pub async fn limiter<B>(
) -> Result<impl IntoResponse, Error> {
let ip = req
.headers()
.get("Fly-Client-IP")
.get(IP_HEADER)
.and_then(|v| v.to_str().ok())
.unwrap_or("127.0.0.1");
let mut bucket = ctx.rate_limit.entry(ip.to_string()).or_insert_with(|| {
.and_then(|v| v.parse().ok())
.unwrap_or(DEFAULT_IP);
let mut bucket = ctx.rate_limit.entry(ip).or_insert_with(|| {
TokenBucket::new(
ctx.config.rate_limit_capacity,
ctx.config.rate_limit_fill_rate,
)
});
if bucket.take(1) {
if bucket.take() {
Ok(next.run(req).await)
} else {
Err(Error::TooManyRequests)
Expand Down

0 comments on commit db16007

Please sign in to comment.