diff --git a/libs/pavex_cli/tests/ui_tests/blueprint/router/http_method_routing_variants/expectations/app.rs b/libs/pavex_cli/tests/ui_tests/blueprint/router/http_method_routing_variants/expectations/app.rs index 4e3f1bc3e..4f0520883 100644 --- a/libs/pavex_cli/tests/ui_tests/blueprint/router/http_method_routing_variants/expectations/app.rs +++ b/libs/pavex_cli/tests/ui_tests/blueprint/router/http_method_routing_variants/expectations/app.rs @@ -26,15 +26,17 @@ fn build_router() -> matchit::Router { let mut router = matchit::Router::new(); router.insert("/any", 0u32).unwrap(); router.insert("/connect", 1u32).unwrap(); - router.insert("/delete", 2u32).unwrap(); - router.insert("/get", 3u32).unwrap(); - router.insert("/head", 4u32).unwrap(); - router.insert("/mixed", 5u32).unwrap(); - router.insert("/options", 6u32).unwrap(); - router.insert("/patch", 7u32).unwrap(); - router.insert("/post", 8u32).unwrap(); - router.insert("/put", 9u32).unwrap(); - router.insert("/trace", 10u32).unwrap(); + router.insert("/custom", 2u32).unwrap(); + router.insert("/delete", 3u32).unwrap(); + router.insert("/get", 4u32).unwrap(); + router.insert("/head", 5u32).unwrap(); + router.insert("/mixed", 6u32).unwrap(); + router.insert("/mixed_with_custom", 7u32).unwrap(); + router.insert("/options", 8u32).unwrap(); + router.insert("/patch", 9u32).unwrap(); + router.insert("/post", 10u32).unwrap(); + router.insert("/put", 11u32).unwrap(); + router.insert("/trace", 12u32).unwrap(); router } async fn route_request( @@ -49,7 +51,7 @@ async fn route_request( Ok(m) => m, Err(_) => { let allowed_methods = pavex::router::AllowedMethods::new(vec![]); - return route_11::handler(&allowed_methods).await; + return route_13::handler(&allowed_methods).await; } }; let route_id = matched_route.value; @@ -66,44 +68,58 @@ async fn route_request( let allowed_methods = pavex::router::AllowedMethods::new( vec![pavex::http::Method::CONNECT], ); - route_11::handler(&allowed_methods).await + route_13::handler(&allowed_methods).await } } } 2u32 => { + match &request_head.method { + s if s.as_str() == "CUSTOM" => route_11::handler().await, + _ => { + let allowed_methods = pavex::router::AllowedMethods::new( + vec![ + pavex::http::Method::try_from("CUSTOM") + .expect("Failed to parse custom method") + ], + ); + route_13::handler(&allowed_methods).await + } + } + } + 3u32 => { match &request_head.method { &pavex::http::Method::DELETE => route_1::handler().await, _ => { let allowed_methods = pavex::router::AllowedMethods::new( vec![pavex::http::Method::DELETE], ); - route_11::handler(&allowed_methods).await + route_13::handler(&allowed_methods).await } } } - 3u32 => { + 4u32 => { match &request_head.method { &pavex::http::Method::GET => route_2::handler().await, _ => { let allowed_methods = pavex::router::AllowedMethods::new( vec![pavex::http::Method::GET], ); - route_11::handler(&allowed_methods).await + route_13::handler(&allowed_methods).await } } } - 4u32 => { + 5u32 => { match &request_head.method { &pavex::http::Method::HEAD => route_3::handler().await, _ => { let allowed_methods = pavex::router::AllowedMethods::new( vec![pavex::http::Method::HEAD], ); - route_11::handler(&allowed_methods).await + route_13::handler(&allowed_methods).await } } } - 5u32 => { + 6u32 => { match &request_head.method { &pavex::http::Method::PATCH | &pavex::http::Method::POST => { route_10::handler().await @@ -112,62 +128,82 @@ async fn route_request( let allowed_methods = pavex::router::AllowedMethods::new( vec![pavex::http::Method::PATCH, pavex::http::Method::POST], ); - route_11::handler(&allowed_methods).await + route_13::handler(&allowed_methods).await } } } - 6u32 => { + 7u32 => { + match &request_head.method { + &pavex::http::Method::GET => route_12::handler().await, + s if s.as_str() == "CUSTOM" || s.as_str() == "HEY" => { + route_12::handler().await + } + _ => { + let allowed_methods = pavex::router::AllowedMethods::new( + vec![ + pavex::http::Method::try_from("CUSTOM") + .expect("Failed to parse custom method"), + pavex::http::Method::GET, + pavex::http::Method::try_from("HEY") + .expect("Failed to parse custom method") + ], + ); + route_13::handler(&allowed_methods).await + } + } + } + 8u32 => { match &request_head.method { &pavex::http::Method::OPTIONS => route_4::handler().await, _ => { let allowed_methods = pavex::router::AllowedMethods::new( vec![pavex::http::Method::OPTIONS], ); - route_11::handler(&allowed_methods).await + route_13::handler(&allowed_methods).await } } } - 7u32 => { + 9u32 => { match &request_head.method { &pavex::http::Method::PATCH => route_5::handler().await, _ => { let allowed_methods = pavex::router::AllowedMethods::new( vec![pavex::http::Method::PATCH], ); - route_11::handler(&allowed_methods).await + route_13::handler(&allowed_methods).await } } } - 8u32 => { + 10u32 => { match &request_head.method { &pavex::http::Method::POST => route_6::handler().await, _ => { let allowed_methods = pavex::router::AllowedMethods::new( vec![pavex::http::Method::POST], ); - route_11::handler(&allowed_methods).await + route_13::handler(&allowed_methods).await } } } - 9u32 => { + 11u32 => { match &request_head.method { &pavex::http::Method::PUT => route_7::handler().await, _ => { let allowed_methods = pavex::router::AllowedMethods::new( vec![pavex::http::Method::PUT], ); - route_11::handler(&allowed_methods).await + route_13::handler(&allowed_methods).await } } } - 10u32 => { + 12u32 => { match &request_head.method { &pavex::http::Method::TRACE => route_8::handler().await, _ => { let allowed_methods = pavex::router::AllowedMethods::new( vec![pavex::http::Method::TRACE], ); - route_11::handler(&allowed_methods).await + route_13::handler(&allowed_methods).await } } } @@ -241,6 +277,18 @@ pub mod route_10 { } } pub mod route_11 { + pub async fn handler() -> pavex::response::Response { + let v0 = app::handler(); + ::into_response(v0) + } +} +pub mod route_12 { + pub async fn handler() -> pavex::response::Response { + let v0 = app::handler(); + ::into_response(v0) + } +} +pub mod route_13 { pub async fn handler( v0: &pavex::router::AllowedMethods, ) -> pavex::response::Response { diff --git a/libs/pavex_cli/tests/ui_tests/blueprint/router/http_method_routing_variants/expectations/diagnostics.dot b/libs/pavex_cli/tests/ui_tests/blueprint/router/http_method_routing_variants/expectations/diagnostics.dot index eece195d7..e07f0f9b2 100644 --- a/libs/pavex_cli/tests/ui_tests/blueprint/router/http_method_routing_variants/expectations/diagnostics.dot +++ b/libs/pavex_cli/tests/ui_tests/blueprint/router/http_method_routing_variants/expectations/diagnostics.dot @@ -18,6 +18,20 @@ digraph "* /connect - 0" { 3 -> 0 [ ] } +digraph "CUSTOM /custom - 0" { + 0 [ label = "app::handler() -> pavex::response::Response"] + 1 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] + 0 -> 1 [ ] +} + +digraph "* /custom - 0" { + 0 [ label = "pavex::router::default_fallback(&pavex::router::AllowedMethods) -> pavex::response::Response"] + 2 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] + 3 [ label = "&pavex::router::AllowedMethods"] + 0 -> 2 [ ] + 3 -> 0 [ ] +} + digraph "DELETE /delete - 0" { 0 [ label = "app::handler() -> pavex::response::Response"] 1 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] @@ -74,6 +88,20 @@ digraph "* /mixed - 0" { 3 -> 0 [ ] } +digraph "CUSTOM | GET | HEY /mixed_with_custom - 0" { + 0 [ label = "app::handler() -> pavex::response::Response"] + 1 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] + 0 -> 1 [ ] +} + +digraph "* /mixed_with_custom - 0" { + 0 [ label = "pavex::router::default_fallback(&pavex::router::AllowedMethods) -> pavex::response::Response"] + 2 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] + 3 [ label = "&pavex::router::AllowedMethods"] + 0 -> 2 [ ] + 3 -> 0 [ ] +} + digraph "OPTIONS /options - 0" { 0 [ label = "app::handler() -> pavex::response::Response"] 1 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] diff --git a/libs/pavex_cli/tests/ui_tests/blueprint/router/http_method_routing_variants/lib.rs b/libs/pavex_cli/tests/ui_tests/blueprint/router/http_method_routing_variants/lib.rs index 94d4b256b..8b0dfafa8 100644 --- a/libs/pavex_cli/tests/ui_tests/blueprint/router/http_method_routing_variants/lib.rs +++ b/libs/pavex_cli/tests/ui_tests/blueprint/router/http_method_routing_variants/lib.rs @@ -1,8 +1,9 @@ -use pavex::f; use pavex::blueprint::{ router::{MethodGuard, ANY, CONNECT, DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT, TRACE}, Blueprint, }; +use pavex::f; +use pavex::http::Method; pub fn handler() -> pavex::response::Response { todo!() @@ -21,9 +22,21 @@ pub fn blueprint() -> Blueprint { bp.route(TRACE, "/trace", f!(crate::handler)); bp.route(ANY, "/any", f!(crate::handler)); bp.route( - MethodGuard::new([pavex::http::Method::PATCH, pavex::http::Method::POST]), + MethodGuard::new([Method::PATCH, Method::POST]), "/mixed", f!(crate::handler), ); + let custom_method = Method::from_bytes(b"CUSTOM").unwrap(); + let custom2_method = Method::from_bytes(b"HEY").unwrap(); + bp.route( + MethodGuard::new(vec![custom_method.clone()]), + "/custom", + f!(crate::handler), + ); + bp.route( + MethodGuard::new(vec![custom_method, custom2_method, Method::GET]), + "/mixed_with_custom", + f!(crate::handler), + ); bp } diff --git a/libs/pavexc/src/compiler/codegen.rs b/libs/pavexc/src/compiler/codegen.rs index 0c07a6b30..8cbd05879 100644 --- a/libs/pavexc/src/compiler/codegen.rs +++ b/libs/pavexc/src/compiler/codegen.rs @@ -1,11 +1,12 @@ use std::collections::{BTreeMap, BTreeSet}; -use ahash::HashMap; +use ahash::{HashMap, HashSet}; use bimap::{BiBTreeMap, BiHashMap}; use cargo_manifest::{Dependency, DependencyDetail, Edition}; use guppy::graph::{ExternalSource, PackageSource}; use guppy::PackageId; use indexmap::{IndexMap, IndexSet}; +use once_cell::sync::Lazy; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; use syn::{ItemEnum, ItemFn, ItemStruct}; @@ -319,6 +320,15 @@ fn get_request_dispatcher( http: &Ident, hyper: &Ident, ) -> ItemFn { + static WELL_KNOWN_METHODS: Lazy> = Lazy::new(|| { + HashSet::from_iter( + [ + "GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "CONNECT", "TRACE", + ] + .into_iter(), + ) + }); + let mut route_dispatch_table = quote! {}; let server_state_ident = format_ident!("server_state"); @@ -337,18 +347,14 @@ fn get_request_dispatcher( .methods_and_pipelines .iter() .flat_map(|(methods, _)| methods) - .map(|m| match m.as_str() { - "GET" | "POST" | "PUT" | "DELETE" | "PATCH" | "HEAD" | "OPTIONS" - | "CONNECT" | "TRACE" => { - let i = format_ident!("{}", m); - quote! { + .map(|m| if WELL_KNOWN_METHODS.contains(m.as_str()) { + let i = format_ident!("{}", m); + quote! { #pavex::http::Method::#i - } } - s => { - quote! { - #pavex::http::Method::try_from(#s).unwrap() - } + } else { + quote! { + #pavex::http::Method::try_from(#m).expect("Failed to parse custom method") } }); quote! { @@ -364,7 +370,6 @@ fn get_request_dispatcher( request_scoped_bindings, &server_state_ident, ); - let methods = methods.iter().map(|m| format_ident!("{}", m)); let invocation = if request_pipeline.needs_allowed_methods(framework_items_db) { quote! { { @@ -375,10 +380,35 @@ fn get_request_dispatcher( } else { invocation }; - sub_router_dispatch_table = quote! { - #sub_router_dispatch_table - #(&#pavex::http::Method::#methods)|* => #invocation, - } + + let (well_known_methods, custom_methods) = methods + .iter() + .partition::, _>(|m| WELL_KNOWN_METHODS.contains(m.as_str())); + + if !well_known_methods.is_empty() { + let well_known_methods = well_known_methods.into_iter().map(|m| { + let m = format_ident!("{}", m); + quote! { + #pavex::http::Method::#m + } + }); + sub_router_dispatch_table = quote! { + #sub_router_dispatch_table + #(&#well_known_methods)|* => #invocation, + }; + }; + + if !custom_methods.is_empty() { + let custom_methods = custom_methods.into_iter().map(|m| { + quote! { + s.as_str() == #m + } + }); + sub_router_dispatch_table = quote! { + #sub_router_dispatch_table + s if #(#custom_methods)||* => #invocation, + }; + }; } let matched_route_template = if sub_router.needs_matched_route(framework_items_db) { let path = route_id2path.get_by_left(route_id).unwrap();