From 8a0365f0af494b753f9e19ffaa245462d89d0fe2 Mon Sep 17 00:00:00 2001 From: Luca Palmieri <20745048+LukeMathWalker@users.noreply.github.com> Date: Fri, 26 Apr 2024 08:23:26 +0200 Subject: [PATCH] fix: Allow &mut references to be held by Next's state. (#280) Two issues combined: - We were running `Next`'s state constructors through the same validation of user-registered constructors, but that's inappropriate. It's OK for them to hold &mut references, since they might be needed later in the pipeline. - There was a bug in our code to determine which request scoped components should be built. The set of prebuilt ids didn't include components that might later become relevant because they are "passed through" without being touched. --- .github/workflows/docs.yml | 2 + doc_examples/tutorial_generator/Cargo.lock | 8 +- .../expectations/stderr.txt | 12 - .../expectations/app.rs | 235 ++++++++++++++++++ .../expectations/diagnostics.dot | 87 +++++++ .../next_handles_mut_references/lib.rs | 36 +++ .../test_config.toml | 4 + .../compiler/analyses/components/db/mod.rs | 21 ++ .../analyses/processing_pipeline/pipeline.rs | 24 +- 9 files changed, 400 insertions(+), 29 deletions(-) delete mode 100644 libs/pavex_cli/tests/ui_tests/middlewares/next_handles_lifetimes/expectations/stderr.txt create mode 100644 libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/expectations/app.rs create mode 100644 libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/expectations/diagnostics.dot create mode 100644 libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/lib.rs create mode 100644 libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/test_config.toml diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 847db6fb0..0488e941c 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -445,9 +445,11 @@ jobs: --base site --exclude-loopback --exclude-path="site/api_reference/pavex/http" + --exclude-path="site/api_reference/pavex/time" --exclude-path="site/api_reference/help.html" --exclude-path="site/api_reference/settings.html" --exclude="https://doc.rust-lang.org/*" + --exclude="https://stackoverflow.com/*" --exclude="https://github.com/LukeMathWalker/pavex/edit/main/*" --exclude="https://docs.rs/**/*" --exclude-path="site/api_reference/static.files" diff --git a/doc_examples/tutorial_generator/Cargo.lock b/doc_examples/tutorial_generator/Cargo.lock index 064a4f3b3..a5ab7e446 100644 --- a/doc_examples/tutorial_generator/Cargo.lock +++ b/doc_examples/tutorial_generator/Cargo.lock @@ -749,16 +749,16 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[patch.unused]] name = "pavex" -version = "0.1.25" +version = "0.1.34" [[patch.unused]] name = "pavex_bp_schema" -version = "0.1.25" +version = "0.1.34" [[patch.unused]] name = "pavex_cli_client" -version = "0.1.25" +version = "0.1.34" [[patch.unused]] name = "pavex_tracing" -version = "0.1.25" +version = "0.1.34" diff --git a/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_lifetimes/expectations/stderr.txt b/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_lifetimes/expectations/stderr.txt deleted file mode 100644 index cdccce4da..000000000 --- a/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_lifetimes/expectations/stderr.txt +++ /dev/null @@ -1,12 +0,0 @@ -ERROR: - ร— Wrapping middlewares must take an instance of `pavex::middleware::Next<_>` - โ”‚ as input parameter. - โ”‚ This middleware doesn't. - โ”‚ - โ”‚ โ•ญโ”€[src/lib.rs:14:1] - โ”‚ 14 โ”‚ let mut bp = Blueprint::new(); - โ”‚ 15 โ”‚ bp.wrap(f!(crate::mw)); - โ”‚ ยท  โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€ - โ”‚ ยท โ•ฐโ”€โ”€ The wrapping middleware was registered here - โ”‚ 16 โ”‚ bp.route(GET, "/home", f!(crate::handler)); - โ”‚ โ•ฐโ”€โ”€โ”€โ”€ \ No newline at end of file diff --git a/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/expectations/app.rs b/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/expectations/app.rs new file mode 100644 index 000000000..14bc24a9d --- /dev/null +++ b/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/expectations/app.rs @@ -0,0 +1,235 @@ +//! Do NOT edit this code. +//! It was automatically generated by Pavex. +//! All manual edits will be lost next time the code is generated. +extern crate alloc; +struct ServerState { + router: pavex_matchit::Router, + #[allow(dead_code)] + application_state: ApplicationState, +} +pub struct ApplicationState {} +pub async fn build_application_state() -> crate::ApplicationState { + crate::ApplicationState {} +} +pub fn run( + server_builder: pavex::server::Server, + application_state: ApplicationState, +) -> pavex::server::ServerHandle { + let server_state = std::sync::Arc::new(ServerState { + router: build_router(), + application_state, + }); + server_builder.serve(route_request, server_state) +} +fn build_router() -> pavex_matchit::Router { + let mut router = pavex_matchit::Router::new(); + router.insert("/home", 0u32).unwrap(); + router +} +async fn route_request( + request: http::Request, + _connection_info: Option, + server_state: std::sync::Arc, +) -> pavex::response::Response { + let (request_head, request_body) = request.into_parts(); + #[allow(unused)] + let request_body = pavex::request::body::RawIncomingBody::from(request_body); + let request_head: pavex::request::RequestHead = request_head.into(); + let matched_route = match server_state.router.at(&request_head.target.path()) { + Ok(m) => m, + Err(_) => { + let allowed_methods: pavex::router::AllowedMethods = pavex::router::MethodAllowList::from_iter( + vec![], + ) + .into(); + return route_1::entrypoint(&allowed_methods).await; + } + }; + let route_id = matched_route.value; + #[allow(unused)] + let url_params: pavex::request::path::RawPathParams<'_, '_> = matched_route + .params + .into(); + match route_id { + 0u32 => { + match &request_head.method { + &pavex::http::Method::GET => route_0::entrypoint().await, + _ => { + let allowed_methods: pavex::router::AllowedMethods = pavex::router::MethodAllowList::from_iter([ + pavex::http::Method::GET, + ]) + .into(); + route_1::entrypoint(&allowed_methods).await + } + } + } + i => unreachable!("Unknown route id: {}", i), + } +} +pub mod route_0 { + pub async fn entrypoint() -> pavex::response::Response { + let response = wrapping_0().await; + response + } + async fn stage_1(mut s_0: app::A) -> pavex::response::Response { + let response = wrapping_1(&mut s_0).await; + let response = post_processing_0(s_0, response).await; + response + } + async fn stage_2<'a>(s_0: &'a mut app::A) -> pavex::response::Response { + let response = handler(s_0).await; + response + } + async fn wrapping_0() -> pavex::response::Response { + let v0 = app::a(); + let v1 = crate::route_0::Next0 { + s_0: v0, + next: stage_1, + }; + let v2 = pavex::middleware::Next::new(v1); + let v3 = pavex::middleware::wrap_noop(v2).await; + ::into_response(v3) + } + async fn wrapping_1(v0: &mut app::A) -> pavex::response::Response { + let v1 = crate::route_0::Next1 { + s_0: v0, + next: stage_2, + }; + let v2 = pavex::middleware::Next::new(v1); + let v3 = app::mw(v2); + ::into_response(v3) + } + async fn post_processing_0( + v0: app::A, + v1: pavex::response::Response, + ) -> pavex::response::Response { + let v2 = app::post(v0, v1); + ::into_response(v2) + } + async fn handler(v0: &mut app::A) -> pavex::response::Response { + let v1 = app::handler(v0); + ::into_response(v1) + } + struct Next0 + where + T: std::future::Future, + { + s_0: app::A, + next: fn(app::A) -> T, + } + impl std::future::IntoFuture for Next0 + where + T: std::future::Future, + { + type Output = pavex::response::Response; + type IntoFuture = T; + fn into_future(self) -> Self::IntoFuture { + (self.next)(self.s_0) + } + } + struct Next1<'a, T> + where + T: std::future::Future, + { + s_0: &'a mut app::A, + next: fn(&'a mut app::A) -> T, + } + impl<'a, T> std::future::IntoFuture for Next1<'a, T> + where + T: std::future::Future, + { + type Output = pavex::response::Response; + type IntoFuture = T; + fn into_future(self) -> Self::IntoFuture { + (self.next)(self.s_0) + } + } +} +pub mod route_1 { + pub async fn entrypoint<'a>( + s_0: &'a pavex::router::AllowedMethods, + ) -> pavex::response::Response { + let response = wrapping_0(s_0).await; + response + } + async fn stage_1<'a>( + s_0: &'a pavex::router::AllowedMethods, + ) -> pavex::response::Response { + let response = wrapping_1(s_0).await; + let response = post_processing_0(response).await; + response + } + async fn stage_2<'a>( + s_0: &'a pavex::router::AllowedMethods, + ) -> pavex::response::Response { + let response = handler(s_0).await; + response + } + async fn wrapping_0( + v0: &pavex::router::AllowedMethods, + ) -> pavex::response::Response { + let v1 = crate::route_1::Next0 { + s_0: v0, + next: stage_1, + }; + let v2 = pavex::middleware::Next::new(v1); + let v3 = pavex::middleware::wrap_noop(v2).await; + ::into_response(v3) + } + async fn wrapping_1( + v0: &pavex::router::AllowedMethods, + ) -> pavex::response::Response { + let v1 = crate::route_1::Next1 { + s_0: v0, + next: stage_2, + }; + let v2 = pavex::middleware::Next::new(v1); + let v3 = app::mw(v2); + ::into_response(v3) + } + async fn post_processing_0( + v0: pavex::response::Response, + ) -> pavex::response::Response { + let v1 = app::a(); + let v2 = app::post(v1, v0); + ::into_response(v2) + } + async fn handler(v0: &pavex::router::AllowedMethods) -> pavex::response::Response { + let v1 = pavex::router::default_fallback(v0).await; + ::into_response(v1) + } + struct Next0<'a, T> + where + T: std::future::Future, + { + s_0: &'a pavex::router::AllowedMethods, + next: fn(&'a pavex::router::AllowedMethods) -> T, + } + impl<'a, T> std::future::IntoFuture for Next0<'a, T> + where + T: std::future::Future, + { + type Output = pavex::response::Response; + type IntoFuture = T; + fn into_future(self) -> Self::IntoFuture { + (self.next)(self.s_0) + } + } + struct Next1<'a, T> + where + T: std::future::Future, + { + s_0: &'a pavex::router::AllowedMethods, + next: fn(&'a pavex::router::AllowedMethods) -> T, + } + impl<'a, T> std::future::IntoFuture for Next1<'a, T> + where + T: std::future::Future, + { + type Output = pavex::response::Response; + type IntoFuture = T; + fn into_future(self) -> Self::IntoFuture { + (self.next)(self.s_0) + } + } +} \ No newline at end of file diff --git a/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/expectations/diagnostics.dot b/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/expectations/diagnostics.dot new file mode 100644 index 000000000..27200d2bc --- /dev/null +++ b/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/expectations/diagnostics.dot @@ -0,0 +1,87 @@ +digraph "GET /home - 0" { + 0 [ label = "pavex::middleware::wrap_noop(pavex::middleware::Next) -> pavex::response::Response"] + 1 [ label = "pavex::middleware::Next::new(crate::route_0::Next0) -> pavex::middleware::Next"] + 2 [ label = "crate::route_0::Next0(app::A) -> crate::route_0::Next0"] + 3 [ label = "app::a() -> app::A"] + 4 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] + 1 -> 0 [ ] + 2 -> 1 [ ] + 3 -> 2 [ ] + 0 -> 4 [ ] +} + +digraph "GET /home - 1" { + 0 [ label = "app::mw(pavex::middleware::Next>) -> pavex::response::Response"] + 1 [ label = "pavex::middleware::Next::new(crate::route_0::Next1<'a>) -> pavex::middleware::Next>"] + 2 [ label = "crate::route_0::Next1(&'a mut app::A) -> crate::route_0::Next1<'a>"] + 4 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] + 5 [ label = "&mut app::A"] + 1 -> 0 [ ] + 2 -> 1 [ ] + 0 -> 4 [ ] + 5 -> 2 [ ] +} + +digraph "GET /home - 2" { + 0 [ label = "app::post(app::A, pavex::response::Response) -> pavex::response::Response"] + 1 [ label = "app::A"] + 2 [ label = "pavex::response::Response"] + 3 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] + 2 -> 0 [ ] + 1 -> 0 [ ] + 0 -> 3 [ ] +} + +digraph "GET /home - 3" { + 0 [ label = "app::handler(&mut app::A) -> pavex::response::Response"] + 2 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] + 3 [ label = "&mut app::A"] + 0 -> 2 [ ] + 3 -> 0 [ ] +} + +digraph "* /home - 0" { + 0 [ label = "pavex::middleware::wrap_noop(pavex::middleware::Next>) -> pavex::response::Response"] + 1 [ label = "pavex::middleware::Next::new(crate::route_1::Next0<'a>) -> pavex::middleware::Next>"] + 2 [ label = "crate::route_1::Next0(&'a pavex::router::AllowedMethods) -> crate::route_1::Next0<'a>"] + 4 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] + 5 [ label = "&pavex::router::AllowedMethods"] + 1 -> 0 [ ] + 2 -> 1 [ ] + 0 -> 4 [ ] + 5 -> 2 [ ] +} + +digraph "* /home - 1" { + 0 [ label = "app::mw(pavex::middleware::Next>) -> pavex::response::Response"] + 1 [ label = "pavex::middleware::Next::new(crate::route_1::Next1<'a>) -> pavex::middleware::Next>"] + 2 [ label = "crate::route_1::Next1(&'a pavex::router::AllowedMethods) -> crate::route_1::Next1<'a>"] + 4 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] + 5 [ label = "&pavex::router::AllowedMethods"] + 1 -> 0 [ ] + 2 -> 1 [ ] + 0 -> 4 [ ] + 5 -> 2 [ ] +} + +digraph "* /home - 2" { + 0 [ label = "app::post(app::A, pavex::response::Response) -> pavex::response::Response"] + 1 [ label = "app::a() -> app::A"] + 2 [ label = "pavex::response::Response"] + 3 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] + 2 -> 0 [ ] + 1 -> 0 [ ] + 0 -> 3 [ ] +} + +digraph "* /home - 3" { + 0 [ label = "pavex::router::default_fallback(&pavex::router::AllowedMethods) -> pavex::response::Response"] + 2 [ label = "::into_response(pavex::response::Response) -> pavex::response::Response"] + 3 [ label = "&pavex::router::AllowedMethods"] + 0 -> 2 [ ] + 3 -> 0 [ ] +} + +digraph app_state { + 0 [ label = "crate::ApplicationState() -> crate::ApplicationState"] +} \ No newline at end of file diff --git a/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/lib.rs b/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/lib.rs new file mode 100644 index 000000000..193323036 --- /dev/null +++ b/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/lib.rs @@ -0,0 +1,36 @@ +use std::future::IntoFuture; + +use pavex::blueprint::{constructor::Lifecycle, router::GET, Blueprint}; +use pavex::f; +use pavex::middleware::Next; +use pavex::response::Response; + +pub struct A; + +pub fn a() -> A { + A +} + +pub fn mw(_next: Next) -> Response +where + T: IntoFuture, +{ + todo!() +} + +pub fn post(_a: A, _r: Response) -> Response { + todo!() +} + +pub fn handler(_a: &mut A) -> Response { + todo!() +} + +pub fn blueprint() -> Blueprint { + let mut bp = Blueprint::new(); + bp.request_scoped(f!(crate::a)); + bp.post_process(f!(crate::post)); + bp.wrap(f!(crate::mw)); + bp.route(GET, "/home", f!(crate::handler)); + bp +} diff --git a/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/test_config.toml b/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/test_config.toml new file mode 100644 index 000000000..6266b245c --- /dev/null +++ b/libs/pavex_cli/tests/ui_tests/middlewares/next_handles_mut_references/test_config.toml @@ -0,0 +1,4 @@ +description = """The state for `Next` can handle &mut references as fields.""" + +[expectations] +codegen = "pass" diff --git a/libs/pavexc/src/compiler/analyses/components/db/mod.rs b/libs/pavexc/src/compiler/analyses/components/db/mod.rs index 582db3e45..108026ff9 100644 --- a/libs/pavexc/src/compiler/analyses/components/db/mod.rs +++ b/libs/pavexc/src/compiler/analyses/components/db/mod.rs @@ -1159,6 +1159,27 @@ impl ComponentDb { Ok(self.get_or_intern(constructor_component, computation_db)) } + /// Raw access, only used for synthetic constructors (in particular, `Next`'s state + /// constructors). + pub fn get_or_intern_constructor_without_validation( + &mut self, + callable_id: ComputationId, + lifecycle: Lifecycle, + scope_id: ScopeId, + cloning_strategy: CloningStrategy, + computation_db: &mut ComputationDb, + derived_from: Option, + ) -> Result { + let constructor_component = UnregisteredComponent::SyntheticConstructor { + lifecycle, + computation_id: callable_id, + scope_id, + cloning_strategy, + derived_from, + }; + Ok(self.get_or_intern(constructor_component, computation_db)) + } + pub fn get_or_intern_wrapping_middleware( &mut self, callable: Cow<'_, Callable>, diff --git a/libs/pavexc/src/compiler/analyses/processing_pipeline/pipeline.rs b/libs/pavexc/src/compiler/analyses/processing_pipeline/pipeline.rs index 146c60a39..8d741cc59 100644 --- a/libs/pavexc/src/compiler/analyses/processing_pipeline/pipeline.rs +++ b/libs/pavexc/src/compiler/analyses/processing_pipeline/pipeline.rs @@ -1,6 +1,6 @@ use std::collections::BTreeMap; -use ahash::{HashMap, HashMapExt}; +use ahash::{HashMap, HashMapExt, HashSet}; use guppy::graph::PackageGraph; use guppy::PackageId; use indexmap::{IndexMap, IndexSet}; @@ -54,6 +54,7 @@ pub struct Stage { pub(crate) pre_processing_ids: Vec, } +#[derive(Debug)] struct PipelineIds(Vec); impl PipelineIds { @@ -65,6 +66,7 @@ impl PipelineIds { } } +#[derive(Debug)] struct StageIds { pre_processing_ids: Vec, /// Either a wrapping middleware or a request handler. @@ -244,19 +246,16 @@ impl RequestHandlerPipeline { let call_graph = &id2call_graphs[&middleware_id]; let mut prebuilt_ids = IndexSet::new(); - let required_scope_ids = + let required_scope_ids: HashSet<_> = extract_request_scoped_compute_nodes(&call_graph.call_graph, component_db) - .collect_vec(); - for &request_scoped_id in &required_scope_ids { - let Some(&built_at) = - request_scoped2built_at_stage_index.get(&request_scoped_id) - else { - continue; - }; + .collect(); + for (request_scoped_id, &built_at) in &request_scoped2built_at_stage_index { if built_at < stage_index { - prebuilt_ids.insert(request_scoped_id); + prebuilt_ids.insert(*request_scoped_id); } else if built_at == stage_index { - assert!(component_db.is_wrapping_middleware(middleware_id)); + if required_scope_ids.contains(request_scoped_id) { + assert!(component_db.is_wrapping_middleware(middleware_id)); + } } } middleware_id2prebuilt_rs_ids.insert(middleware_id, prebuilt_ids.clone()); @@ -397,13 +396,12 @@ impl RequestHandlerPipeline { let next_state_callable_id = computation_db.get_or_intern(next_state_constructor); let next_state_scope_id = component_db.scope_id(middleware_id); let next_state_constructor_id = component_db - .get_or_intern_constructor( + .get_or_intern_constructor_without_validation( next_state_callable_id, Lifecycle::RequestScoped, next_state_scope_id, CloningStrategy::NeverClone, computation_db, - framework_item_db, None, ) .unwrap();