From d7fc65c7cb2698dda2142d4b52bfde3ab0f1df13 Mon Sep 17 00:00:00 2001 From: Oscar Beaumont Date: Wed, 17 Jul 2024 22:39:11 +0800 Subject: [PATCH] fix invalidation example? --- examples/axum/src/api/invalidation.rs | 32 ++++++++++++++++++--------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/examples/axum/src/api/invalidation.rs b/examples/axum/src/api/invalidation.rs index 1b7eac93..56ea3f12 100644 --- a/examples/axum/src/api/invalidation.rs +++ b/examples/axum/src/api/invalidation.rs @@ -4,11 +4,11 @@ use std::{ collections::HashMap, - future::Future, sync::{Arc, Mutex, PoisonError}, }; use async_stream::stream; +use futures::Stream; use rspc::middleware::Middleware; use serde::{Deserialize, Serialize}; use specta::Type; @@ -47,11 +47,12 @@ pub fn mount() -> Router { .procedure("get", { ::builder() // TODO: Why does `TCtx` need a hardcoded type??? - .with(invalidation(|ctx: Context, key, _result| async move { - ctx.invalidation - .tx - .send(InvalidateEvent::InvalidateKey(key)) - .unwrap(); + .with(invalidation(|ctx: Context, key, event| { + if let InvalidateEvent::InvalidateKey(k) = event { + k == key + } else { + false + } })) .mutation(|ctx, key: String| async move { let value = ctx @@ -71,7 +72,13 @@ pub fn mount() -> Router { .keys .lock() .unwrap_or_else(PoisonError::into_inner) - .insert(input.key, input.value); + .insert(input.key.clone(), input.value); + + // This will trigger invalidation + ctx.invalidation + .tx + .send(InvalidateEvent::InvalidateKey(input.key)) + .unwrap(); Ok(()) }) @@ -82,6 +89,8 @@ pub fn mount() -> Router { Ok(stream! { let mut tx = ctx.invalidation.tx.subscribe(); while let Ok(msg) = tx.recv().await { + // TODO: Run all the invalidation closures currently on `TCtx` to map the msg into a list of query keys to invalidate. + yield Ok(msg); } }) @@ -89,24 +98,25 @@ pub fn mount() -> Router { }) } -fn invalidation( - handler: impl Fn(TCtx, TInput, &Result) -> F + Send + Sync + 'static, +fn invalidation( + handler: impl Fn(TCtx, TInput, InvalidateEvent) -> bool + Send + Sync + 'static, ) -> Middleware where TError: Send + 'static, TCtx: Clone + Send + 'static, TInput: Clone + Send + 'static, TResult: Send + 'static, - F: Future + Send + 'static, { let handler = Arc::new(handler); Middleware::new(move |ctx: TCtx, input: TInput, next| { let handler = handler.clone(); async move { + // TODO: Register this with `TCtx` let ctx2 = ctx.clone(); let input2 = input.clone(); let result = next.exec(ctx, input).await; - handler(ctx2, input2, &result).await; + + // TODO: Unregister this with `TCtx` result } })