diff --git a/quickwit/Cargo.lock b/quickwit/Cargo.lock index ffd9a62060a..ff7fbf56982 100644 --- a/quickwit/Cargo.lock +++ b/quickwit/Cargo.lock @@ -4476,6 +4476,7 @@ version = "0.6.1" dependencies = [ "async-trait", "dyn-clone", + "futures", "http", "hyper", "mockall", diff --git a/quickwit/quickwit-codegen/example/Cargo.toml b/quickwit/quickwit-codegen/example/Cargo.toml index 02596b24b98..f77911bbe2e 100644 --- a/quickwit/quickwit-codegen/example/Cargo.toml +++ b/quickwit/quickwit-codegen/example/Cargo.toml @@ -12,6 +12,7 @@ documentation = "https://quickwit.io/docs/" [dependencies] async-trait = { workspace = true } dyn-clone = { workspace = true } +futures = { workspace = true } http = { workspace = true } hyper = { workspace = true } prost = { workspace = true } diff --git a/quickwit/quickwit-codegen/example/src/codegen/hello.rs b/quickwit/quickwit-codegen/example/src/codegen/hello.rs index 81d7f9f3d8d..17288e9c4c4 100644 --- a/quickwit/quickwit-codegen/example/src/codegen/hello.rs +++ b/quickwit/quickwit-codegen/example/src/codegen/hello.rs @@ -41,7 +41,8 @@ pub struct PingResponse { pub message: ::prost::alloc::string::String, } /// BEGIN quickwit-codegen -type HelloStream = quickwit_common::ServiceStream; +use tower::{Layer, Service, ServiceExt}; +pub type HelloStream = quickwit_common::ServiceStream>; #[cfg_attr(any(test, feature = "testsuite"), mockall::automock)] #[async_trait::async_trait] pub trait Hello: std::fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static { @@ -55,7 +56,7 @@ pub trait Hello: std::fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static { ) -> crate::HelloResult; async fn ping( &mut self, - request: PingRequest, + request: quickwit_common::ServiceStream, ) -> crate::HelloResult>; } dyn_clone::clone_trait_object!(Hello); @@ -127,7 +128,7 @@ impl Hello for HelloClient { } async fn ping( &mut self, - request: PingRequest, + request: quickwit_common::ServiceStream, ) -> crate::HelloResult> { self.inner.ping(request).await } @@ -155,7 +156,7 @@ pub mod mock { } async fn ping( &mut self, - request: PingRequest, + request: quickwit_common::ServiceStream, ) -> crate::HelloResult> { self.inner.lock().await.ping(request).await } @@ -204,7 +205,7 @@ impl tower::Service for Box { Box::pin(fut) } } -impl tower::Service for Box { +impl tower::Service> for Box { type Response = HelloStream; type Error = crate::HelloError; type Future = BoxFuture; @@ -214,7 +215,10 @@ impl tower::Service for Box { ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } - fn call(&mut self, request: PingRequest) -> Self::Future { + fn call( + &mut self, + request: quickwit_common::ServiceStream, + ) -> Self::Future { let mut svc = self.clone(); let fut = async move { svc.ping(request).await }; Box::pin(fut) @@ -234,7 +238,7 @@ struct HelloTowerBlock { crate::HelloError, >, ping_svc: quickwit_common::tower::BoxService< - PingRequest, + quickwit_common::ServiceStream, HelloStream, crate::HelloError, >, @@ -264,7 +268,7 @@ impl Hello for HelloTowerBlock { } async fn ping( &mut self, - request: PingRequest, + request: quickwit_common::ServiceStream, ) -> crate::HelloResult> { self.ping_svc.ready().await?.call(request).await } @@ -293,7 +297,7 @@ pub struct HelloTowerBlockBuilder { ping_layer: Option< quickwit_common::tower::BoxLayer< Box, - PingRequest, + quickwit_common::ServiceStream, HelloStream, crate::HelloError, >, @@ -316,11 +320,13 @@ impl HelloTowerBlockBuilder { > + Clone + Send + Sync + 'static, >::Future: Send + 'static, L::Service: tower::Service< - PingRequest, + quickwit_common::ServiceStream, Response = HelloStream, Error = crate::HelloError, > + Clone + Send + Sync + 'static, - >::Future: Send + 'static, + , + >>::Future: Send + 'static, { self.hello_layer = Some(quickwit_common::tower::BoxLayer::new(layer.clone())); self.goodbye_layer = Some(quickwit_common::tower::BoxLayer::new(layer.clone())); @@ -357,11 +363,13 @@ impl HelloTowerBlockBuilder { where L: tower::Layer> + Send + Sync + 'static, L::Service: tower::Service< - PingRequest, + quickwit_common::ServiceStream, Response = HelloStream, Error = crate::HelloError, > + Clone + Send + Sync + 'static, - >::Future: Send + 'static, + , + >>::Future: Send + 'static, { self.ping_layer = Some(quickwit_common::tower::BoxLayer::new(layer)); self @@ -460,13 +468,12 @@ impl Clone for HelloMailbox { Self { inner } } } -use tower::{Layer, Service, ServiceExt}; impl tower::Service for HelloMailbox where A: quickwit_actors::Actor + quickwit_actors::DeferableReplyHandler> + Send + 'static, - M: std::fmt::Debug + Send + Sync + 'static, + M: std::fmt::Debug + Send + 'static, T: Send + 'static, E: std::fmt::Debug + Send + 'static, crate::HelloError: From>, @@ -494,7 +501,7 @@ where #[async_trait::async_trait] impl Hello for HelloMailbox where - A: quickwit_actors::Actor + std::fmt::Debug + Send + Sync + 'static, + A: quickwit_actors::Actor + std::fmt::Debug, HelloMailbox< A, >: tower::Service< @@ -510,7 +517,7 @@ where Future = BoxFuture, > + tower::Service< - PingRequest, + quickwit_common::ServiceStream, Response = HelloStream, Error = crate::HelloError, Future = BoxFuture, crate::HelloError>, @@ -530,7 +537,7 @@ where } async fn ping( &mut self, - request: PingRequest, + request: quickwit_common::ServiceStream, ) -> crate::HelloResult> { self.call(request).await } @@ -576,15 +583,15 @@ where } async fn ping( &mut self, - request: PingRequest, + request: quickwit_common::ServiceStream, ) -> crate::HelloResult> { self.inner .ping(request) .await .map(|response| { - let stream = response.into_inner(); - let service_stream = quickwit_common::ServiceStream::from(stream); - service_stream.map_err(|error| error.into()) + let streaming: tonic::Streaming<_> = response.into_inner(); + let stream = quickwit_common::ServiceStream::from(streaming); + stream.map_err(|error| error.into()) }) .map_err(|error| error.into()) } @@ -625,14 +632,17 @@ impl hello_grpc_server::HelloGrpc for HelloGrpcServerAdapter { .map(tonic::Response::new) .map_err(|error| error.into()) } - type PingStream = quickwit_common::ServiceStream; + type PingStream = quickwit_common::ServiceStream>; async fn ping( &self, - request: tonic::Request, + request: tonic::Request>, ) -> Result, tonic::Status> { self.inner .clone() - .ping(request.into_inner()) + .ping({ + let streaming: tonic::Streaming<_> = request.into_inner(); + quickwit_common::ServiceStream::from(streaming) + }) .await .map(|stream| tonic::Response::new(stream.map_err(|error| error.into()))) .map_err(|error| error.into()) @@ -766,7 +776,7 @@ pub mod hello_grpc_client { } pub async fn ping( &mut self, - request: impl tonic::IntoRequest, + request: impl tonic::IntoStreamingRequest, ) -> std::result::Result< tonic::Response>, tonic::Status, @@ -782,9 +792,9 @@ pub mod hello_grpc_client { })?; let codec = tonic::codec::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static("/hello.Hello/Ping"); - let mut req = request.into_request(); + let mut req = request.into_streaming_request(); req.extensions_mut().insert(GrpcMethod::new("hello.Hello", "Ping")); - self.inner.server_streaming(req, path, codec).await + self.inner.streaming(req, path, codec).await } } } @@ -811,7 +821,7 @@ pub mod hello_grpc_server { + 'static; async fn ping( &self, - request: tonic::Request, + request: tonic::Request>, ) -> std::result::Result, tonic::Status>; } #[derive(Debug)] @@ -982,7 +992,7 @@ pub mod hello_grpc_server { struct PingSvc(pub Arc); impl< T: HelloGrpc, - > tonic::server::ServerStreamingService + > tonic::server::StreamingService for PingSvc { type Response = super::PingResponse; type ResponseStream = T::PingStream; @@ -992,7 +1002,7 @@ pub mod hello_grpc_server { >; fn call( &mut self, - request: tonic::Request, + request: tonic::Request>, ) -> Self::Future { let inner = Arc::clone(&self.0); let fut = async move { (*inner).ping(request).await }; @@ -1017,7 +1027,7 @@ pub mod hello_grpc_server { max_decoding_message_size, max_encoding_message_size, ); - let res = grpc.server_streaming(method, req).await; + let res = grpc.streaming(method, req).await; Ok(res) }; Box::pin(fut) diff --git a/quickwit/quickwit-codegen/example/src/hello.proto b/quickwit/quickwit-codegen/example/src/hello.proto index 97356c61b22..77b69c6a473 100644 --- a/quickwit/quickwit-codegen/example/src/hello.proto +++ b/quickwit/quickwit-codegen/example/src/hello.proto @@ -48,5 +48,5 @@ message PingResponse { service Hello { rpc Hello(HelloRequest) returns (HelloResponse); rpc Goodbye(GoodbyeRequest) returns (GoodbyeResponse); - rpc Ping(PingRequest) returns (stream PingResponse); + rpc Ping(stream PingRequest) returns (stream PingResponse); } diff --git a/quickwit/quickwit-codegen/example/src/lib.rs b/quickwit/quickwit-codegen/example/src/lib.rs index 055a94d04f5..62aa573f67b 100644 --- a/quickwit/quickwit-codegen/example/src/lib.rs +++ b/quickwit/quickwit-codegen/example/src/lib.rs @@ -28,6 +28,7 @@ use std::task::{Context, Poll}; use std::time::Duration; use async_trait::async_trait; +use futures::StreamExt; use quickwit_common::ServiceStream; use tower::{Layer, Service}; @@ -76,21 +77,35 @@ impl Layer for CounterLayer { } } -fn spawn_ping_response_stream(name: String) -> ServiceStream { +fn spawn_ping_response_stream( + mut request_stream: ServiceStream, +) -> ServiceStream> { let (ping_tx, service_stream) = ServiceStream::new_bounded(1); let future = async move { + let mut name = "".to_string(); let mut interval = tokio::time::interval(Duration::from_millis(100)); - loop { - interval.tick().await; - if ping_tx - .send(Ok(PingResponse { - message: format!("Pong, {}!", name), - })) - .await - .is_err() - { - break; + loop { + tokio::select! { + request_opt = request_stream.next() => { + match request_opt { + Some(request) => name = request.name, + _ => break, + }; + } + _ = interval.tick() => { + if name.is_empty() { + continue; + } + if name == "stop" { + break; + } + if ping_tx.send(Ok(PingResponse { + message: format!("Pong, {name}!") + })).await.is_err() { + break; + } + } } } }; @@ -103,13 +118,13 @@ struct HelloImpl; #[async_trait] impl Hello for HelloImpl { - async fn hello(&mut self, request: HelloRequest) -> crate::HelloResult { + async fn hello(&mut self, request: HelloRequest) -> HelloResult { Ok(HelloResponse { message: format!("Hello, {}!", request.name), }) } - async fn goodbye(&mut self, request: GoodbyeRequest) -> crate::HelloResult { + async fn goodbye(&mut self, request: GoodbyeRequest) -> HelloResult { Ok(GoodbyeResponse { message: format!("Goodbye, {}!", request.name), }) @@ -117,9 +132,9 @@ impl Hello for HelloImpl { async fn ping( &mut self, - request: PingRequest, - ) -> crate::HelloResult> { - Ok(spawn_ping_response_stream(request.name)) + request: ServiceStream, + ) -> HelloResult> { + Ok(spawn_ping_response_stream(request)) } } @@ -131,6 +146,7 @@ mod tests { use quickwit_actors::{Actor, ActorContext, ActorExitStatus, Handler, Universe}; use quickwit_common::tower::{BalanceChannel, Change}; + use tokio::sync::mpsc::error::TrySendError; use tokio_stream::StreamExt; use tonic::transport::{Endpoint, Server}; use tower::timeout::Timeout; @@ -139,7 +155,7 @@ mod tests { use crate::hello::hello_grpc_server::HelloGrpcServer; use crate::hello::MockHello; use crate::hello_grpc_client::HelloGrpcClient; - use crate::{CounterLayer, GoodbyeRequest, GoodbyeResponse, HelloError}; + use crate::{CounterLayer, GoodbyeRequest, GoodbyeResponse}; #[tokio::test] async fn test_hello_codegen() { @@ -171,16 +187,40 @@ mod tests { } ); - let mut pong_stream = client - .ping(PingRequest { + let (ping_stream_tx, ping_stream) = ServiceStream::new_bounded(1); + let mut pong_stream = client.ping(ping_stream).await.unwrap(); + + ping_stream_tx + .try_send(PingRequest { name: "World".to_string(), }) - .await .unwrap(); assert_eq!( pong_stream.next().await.unwrap().unwrap().message, "Pong, World!" ); + ping_stream_tx + .try_send(PingRequest { + name: "Mundo".to_string(), + }) + .unwrap(); + assert_eq!( + pong_stream.next().await.unwrap().unwrap().message, + "Pong, Mundo!" + ); + ping_stream_tx + .try_send(PingRequest { + name: "stop".to_string(), + }) + .unwrap(); + assert!(pong_stream.next().await.is_none()); + + let error = ping_stream_tx + .try_send(PingRequest { + name: "stop".to_string(), + }) + .unwrap_err(); + assert!(matches!(error, TrySendError::Closed(_))); let mut mock_hello = MockHello::new(); @@ -189,6 +229,7 @@ mod tests { message: "Hello, Mock!".to_string(), }) }); + assert_eq!( mock_hello .hello(HelloRequest { @@ -200,8 +241,11 @@ mod tests { message: "Hello, Mock!".to_string() } ); + } - let grpc_server_adapter = HelloGrpcServerAdapter::new(hello.clone()); + #[tokio::test] + async fn test_hello_codegen_grpc() { + let grpc_server_adapter = HelloGrpcServerAdapter::new(HelloImpl); let grpc_server = HelloGrpcServer::new(grpc_server_adapter); let addr: SocketAddr = "127.0.0.1:6666".parse().unwrap(); @@ -223,26 +267,45 @@ mod tests { assert_eq!( grpc_client .hello(HelloRequest { - name: "Client".to_string() + name: "gRPC client".to_string() }) .await .unwrap(), HelloResponse { - message: "Hello, Client!".to_string() + message: "Hello, gRPC client!".to_string() } ); - let mut pong_stream = grpc_client - .ping(PingRequest { - name: "Client".to_string(), + let (ping_stream_tx, ping_stream) = ServiceStream::new_bounded(1); + let mut pong_stream = grpc_client.ping(ping_stream).await.unwrap(); + + ping_stream_tx + .try_send(PingRequest { + name: "gRPC client".to_string(), }) - .await .unwrap(); assert_eq!( pong_stream.next().await.unwrap().unwrap().message, - "Pong, Client!" + "Pong, gRPC client!" ); + ping_stream_tx + .try_send(PingRequest { + name: "stop".to_string(), + }) + .unwrap(); + assert!(pong_stream.next().await.is_none()); + + let error = ping_stream_tx + .try_send(PingRequest { + name: "stop".to_string(), + }) + .unwrap_err(); + assert!(matches!(error, TrySendError::Closed(_))); + } + + #[tokio::test] + async fn test_hello_codegen_actor() { #[derive(Debug)] struct HelloActor; @@ -283,15 +346,15 @@ mod tests { } #[async_trait] - impl Handler for HelloActor { - type Reply = HelloResult>; + impl Handler> for HelloActor { + type Reply = HelloResult>; async fn handle( &mut self, - message: PingRequest, + message: ServiceStream, _ctx: &ActorContext, ) -> Result { - Ok(Ok(spawn_ping_response_stream(message.name))) + Ok(Ok(spawn_ping_response_stream(message))) } } @@ -312,11 +375,13 @@ mod tests { } ); - let mut pong_stream = actor_client - .ping(PingRequest { + let (ping_stream_tx, ping_stream) = ServiceStream::new_bounded(1); + let mut pong_stream = actor_client.ping(ping_stream).await.unwrap(); + + ping_stream_tx + .try_send(PingRequest { name: "beautiful actor".to_string(), }) - .await .unwrap(); assert_eq!( pong_stream.next().await.unwrap().unwrap().message, @@ -349,15 +414,17 @@ mod tests { } ); - let mut pong_stream = actor_client - .ping(PingRequest { - name: "Tower actor".to_string(), + let (ping_stream_tx, ping_stream) = ServiceStream::new_bounded(1); + let mut pong_stream = actor_client.ping(ping_stream).await.unwrap(); + + ping_stream_tx + .try_send(PingRequest { + name: "beautiful Tower actor".to_string(), }) - .await .unwrap(); assert_eq!( pong_stream.next().await.unwrap().unwrap().message, - "Pong, Tower actor!" + "Pong, beautiful Tower actor!" ); universe.assert_quit().await; @@ -389,12 +456,18 @@ mod tests { .await .unwrap(); - hello_tower - .ping(PingRequest { + let (ping_stream_tx, ping_stream) = ServiceStream::new_bounded(1); + let mut pong_stream = hello_tower.ping(ping_stream).await.unwrap(); + + ping_stream_tx + .try_send(PingRequest { name: "Tower".to_string(), }) - .await .unwrap(); + assert_eq!( + pong_stream.next().await.unwrap().unwrap().message, + "Pong, Tower!" + ); assert_eq!(hello_layer.counter.load(Ordering::Relaxed), 1); assert_eq!(goodbye_layer.counter.load(Ordering::Relaxed), 1); @@ -423,12 +496,18 @@ mod tests { .await .unwrap(); - hello_tower - .ping(PingRequest { + let (ping_stream_tx, ping_stream) = ServiceStream::new_bounded(1); + let mut pong_stream = hello_tower.ping(ping_stream).await.unwrap(); + + ping_stream_tx + .try_send(PingRequest { name: "Tower".to_string(), }) - .await .unwrap(); + assert_eq!( + pong_stream.next().await.unwrap().unwrap().message, + "Pong, Tower!" + ); assert_eq!(layer.counter.load(Ordering::Relaxed), 3); } diff --git a/quickwit/quickwit-codegen/src/codegen.rs b/quickwit/quickwit-codegen/src/codegen.rs index 134d1c036ae..8f541dc47a0 100644 --- a/quickwit/quickwit-codegen/src/codegen.rs +++ b/quickwit/quickwit-codegen/src/codegen.rs @@ -129,13 +129,13 @@ impl CodegenContext { let stream_type = quote::format_ident!("{}Stream", service.name); let stream_type_alias = if service.methods.iter().any(|method| method.server_streaming) { quote! { - type #stream_type = quickwit_common::ServiceStream; + pub type #stream_type = quickwit_common::ServiceStream<#result_type>; } } else { TokenStream::new() }; - let methods = parse_methods(&service.methods); + let methods = SynMethod::parse_prost_methods(&service.methods); let client_name = quote::format_ident!("{}Client", service.name); let tower_block_name = quote::format_ident!("{}TowerBlock", service.name); @@ -189,6 +189,9 @@ fn generate_all(service: &Service, result_type_path: &str, error_type_path: &str quote! { // The line below is necessary to opt out of the license header check. /// BEGIN quickwit-codegen + + use tower::{Layer, Service, ServiceExt}; + #stream_type_alias #service_trait @@ -216,31 +219,51 @@ struct SynMethod { proto_name: Ident, request_type: syn::Path, response_type: syn::Path, + client_streaming: bool, server_streaming: bool, } -fn parse_methods(methods: &[Method]) -> Vec { - let mut syn_methods = Vec::with_capacity(methods.len()); +impl SynMethod { + fn request_type(&self) -> TokenStream { + if self.client_streaming { + let request_type = &self.request_type; + quote! { quickwit_common::ServiceStream<#request_type> } + } else { + self.request_type.to_token_stream() + } + } - for method in methods { - let name = quote::format_ident!("{}", method.name); - let proto_name = quote::format_ident!("{}", method.proto_name); - let request_type = syn::parse_str::(&method.input_type).unwrap(); - let response_type = syn::parse_str::(&method.output_type).unwrap(); + fn response_type(&self, context: &CodegenContext) -> TokenStream { + if self.server_streaming { + let stream_type = &context.stream_type; + let response_type = &self.response_type; + quote! { #stream_type<#response_type> } + } else { + self.response_type.to_token_stream() + } + } - if method.client_streaming && method.server_streaming { - panic!("Client-side or bidirectional streaming RPCs are not supported."); + fn parse_prost_methods(methods: &[Method]) -> Vec { + let mut syn_methods = Vec::with_capacity(methods.len()); + + for method in methods { + let name = quote::format_ident!("{}", method.name); + let proto_name = quote::format_ident!("{}", method.proto_name); + let request_type = syn::parse_str::(&method.input_type).unwrap(); + let response_type = syn::parse_str::(&method.output_type).unwrap(); + + let syn_method = SynMethod { + name, + proto_name, + request_type, + response_type, + client_streaming: method.client_streaming, + server_streaming: method.server_streaming, + }; + syn_methods.push(syn_method); } - let syn_method = SynMethod { - name, - proto_name, - request_type, - response_type, - server_streaming: method.server_streaming, - }; - syn_methods.push(syn_method); + syn_methods } - syn_methods } fn generate_service_trait(context: &CodegenContext) -> TokenStream { @@ -268,19 +291,13 @@ fn generate_service_trait(context: &CodegenContext) -> TokenStream { fn generate_service_trait_methods(context: &CodegenContext) -> TokenStream { let result_type = &context.result_type; - let stream_type = &context.stream_type; let mut stream = TokenStream::new(); for syn_method in &context.methods { let method_name = syn_method.name.to_token_stream(); - let request_type = syn_method.request_type.to_token_stream(); - let response_type = if syn_method.server_streaming { - let response_type = &syn_method.response_type; - quote! { #stream_type<#response_type> } - } else { - syn_method.response_type.to_token_stream() - }; + let request_type = syn_method.request_type(); + let response_type = syn_method.response_type(context); let method = quote! { async fn #method_name(&mut self, request: #request_type) -> #result_type<#response_type>; }; @@ -295,8 +312,8 @@ fn generate_client(context: &CodegenContext) -> TokenStream { let grpc_client_adapter_name = &context.grpc_client_adapter_name; let grpc_client_package_name = &context.grpc_client_package_name; let grpc_client_name = &context.grpc_client_name; - let client_methods = generate_client_methods(context); - let mock_methods = generate_mock_methods(context); + let client_methods = generate_client_methods(context, false); + let mock_methods = generate_client_methods(context, true); let mailbox_name = &context.mailbox_name; let tower_block_builder_name = &context.tower_block_builder_name; let mock_name = &context.mock_name; @@ -382,51 +399,28 @@ fn generate_client(context: &CodegenContext) -> TokenStream { } } -fn generate_client_methods(context: &CodegenContext) -> TokenStream { +fn generate_client_methods(context: &CodegenContext, mock: bool) -> TokenStream { let result_type = &context.result_type; - let stream_type = &context.stream_type; let mut stream = TokenStream::new(); for syn_method in &context.methods { let method_name = syn_method.name.to_token_stream(); - let request_type = syn_method.request_type.to_token_stream(); + let request_type = syn_method.request_type(); + let response_type = syn_method.response_type(context); - let response_type = if syn_method.server_streaming { - let response_type = &syn_method.response_type; - quote! { #stream_type<#response_type> } - } else { - syn_method.response_type.to_token_stream() - }; - let method = quote! { - async fn #method_name(&mut self, request: #request_type) -> #result_type<#response_type> { + let body = if !mock { + quote! { self.inner.#method_name(request).await } - }; - stream.extend(method); - } - stream -} - -fn generate_mock_methods(context: &CodegenContext) -> TokenStream { - let result_type = &context.result_type; - let stream_type = &context.stream_type; - - let mut stream = TokenStream::new(); - - for syn_method in &context.methods { - let method_name = syn_method.name.to_token_stream(); - let request_type = syn_method.request_type.to_token_stream(); - - let response_type = if syn_method.server_streaming { - let response_type = &syn_method.response_type; - quote! { #stream_type<#response_type> } } else { - syn_method.response_type.to_token_stream() + quote! { + self.inner.lock().await.#method_name(request).await + } }; let method = quote! { async fn #method_name(&mut self, request: #request_type) -> #result_type<#response_type> { - self.inner.lock().await.#method_name(request).await + #body } }; stream.extend(method); @@ -437,20 +431,13 @@ fn generate_mock_methods(context: &CodegenContext) -> TokenStream { fn generate_tower_services(context: &CodegenContext) -> TokenStream { let service_name = &context.service_name; let error_type = &context.error_type; - let stream_type = &context.stream_type; let mut stream = TokenStream::new(); for syn_method in &context.methods { let method_name = syn_method.name.to_token_stream(); - let request_type = syn_method.request_type.to_token_stream(); - - let response_type = if syn_method.server_streaming { - let response_type = &syn_method.response_type; - quote! { #stream_type<#response_type> } - } else { - syn_method.response_type.to_token_stream() - }; + let request_type = syn_method.request_type(); + let response_type = syn_method.response_type(context); let service = quote! { impl tower::Service<#request_type> for Box { @@ -498,20 +485,14 @@ fn generate_tower_block(context: &CodegenContext) -> TokenStream { fn generate_tower_block_attributes(context: &CodegenContext) -> TokenStream { let error_type = &context.error_type; - let stream_type = &context.stream_type; let mut stream = TokenStream::new(); for syn_method in &context.methods { let attribute_name = quote::format_ident!("{}_svc", syn_method.name); - let request_type = syn_method.request_type.to_token_stream(); + let request_type = syn_method.request_type(); + let response_type = syn_method.response_type(context); - let response_type = if syn_method.server_streaming { - let response_type = &syn_method.response_type; - quote! { #stream_type<#response_type> } - } else { - syn_method.response_type.to_token_stream() - }; let attribute = quote! { #attribute_name: quickwit_common::tower::BoxService<#request_type, #response_type, #error_type>, }; @@ -549,21 +530,15 @@ fn generate_tower_block_service_impl(context: &CodegenContext) -> TokenStream { let tower_block_name = &context.tower_block_name; let result_type = &context.result_type; - let stream_type = &context.stream_type; let mut methods = TokenStream::new(); for syn_method in &context.methods { let attribute_name = quote::format_ident!("{}_svc", syn_method.name); let method_name = syn_method.name.to_token_stream(); - let request_type = syn_method.request_type.to_token_stream(); + let request_type = syn_method.request_type(); + let response_type = syn_method.response_type(context); - let response_type = if syn_method.server_streaming { - let response_type = &syn_method.response_type; - quote! { #stream_type<#response_type> } - } else { - syn_method.response_type.to_token_stream() - }; let attribute = quote! { async fn #method_name(&mut self, request: #request_type) -> #result_type<#response_type> { self.#attribute_name.ready().await?.call(request).await @@ -598,20 +573,14 @@ fn generate_tower_block_builder(context: &CodegenContext) -> TokenStream { fn generate_tower_block_builder_attributes(context: &CodegenContext) -> TokenStream { let service_name = &context.service_name; let error_type = &context.error_type; - let stream_type = &context.stream_type; let mut stream = TokenStream::new(); for syn_method in &context.methods { let attribute_name = quote::format_ident!("{}_layer", syn_method.name); - let request_type = syn_method.request_type.to_token_stream(); + let request_type = syn_method.request_type(); + let response_type = syn_method.response_type(context); - let response_type = if syn_method.server_streaming { - let response_type = &syn_method.response_type; - quote! { #stream_type<#response_type> } - } else { - syn_method.response_type.to_token_stream() - }; let attribute = quote! { #[allow(clippy::type_complexity)] #attribute_name: Option, #request_type, #response_type, #error_type>>, @@ -628,7 +597,6 @@ fn generate_tower_block_builder_impl(context: &CodegenContext) -> TokenStream { let tower_block_name = &context.tower_block_name; let tower_block_builder_name = &context.tower_block_builder_name; let error_type = &context.error_type; - let stream_type = &context.stream_type; let mut layer_method_bounds = TokenStream::new(); let mut layer_method_statements = TokenStream::new(); @@ -639,14 +607,8 @@ fn generate_tower_block_builder_impl(context: &CodegenContext) -> TokenStream { for (i, syn_method) in context.methods.iter().enumerate() { let layer_attribute_name = quote::format_ident!("{}_layer", syn_method.name); let svc_attribute_name = quote::format_ident!("{}_svc", syn_method.name); - let request_type = syn_method.request_type.to_token_stream(); - - let response_type = if syn_method.server_streaming { - let response_type = &syn_method.response_type; - quote! { #stream_type<#response_type> } - } else { - syn_method.response_type.to_token_stream() - }; + let request_type = syn_method.request_type(); + let response_type = syn_method.response_type(context); let layer_method_bound = quote! { L::Service: tower::Service<#request_type, Response = #response_type, Error = #error_type> + Clone + Send + Sync + 'static, @@ -797,12 +759,10 @@ fn generate_tower_mailbox(context: &CodegenContext) -> TokenStream { } } - use tower::{Layer, Service, ServiceExt}; - impl tower::Service for #mailbox_name where A: quickwit_actors::Actor + quickwit_actors::DeferableReplyHandler> + Send + 'static, - M: std::fmt::Debug + Send + Sync + 'static, + M: std::fmt::Debug + Send + 'static, T: Send + 'static, E: std::fmt::Debug + Send + 'static, #error_type: From>, @@ -833,7 +793,7 @@ fn generate_tower_mailbox(context: &CodegenContext) -> TokenStream { #[async_trait::async_trait] impl #service_name for #mailbox_name where - A: quickwit_actors::Actor + std::fmt::Debug + Send + Sync + 'static, + A: quickwit_actors::Actor + std::fmt::Debug, #mailbox_name: #(#mailbox_bounds)+*, { #mailbox_methods @@ -846,21 +806,15 @@ fn generate_mailbox_bounds_and_methods( ) -> (Vec, TokenStream) { let result_type = &context.result_type; let error_type = &context.error_type; - let stream_type = &context.stream_type; let mut bounds = Vec::with_capacity(context.methods.len()); let mut methods = TokenStream::new(); for syn_method in &context.methods { let method_name = syn_method.name.to_token_stream(); - let request_type = syn_method.request_type.to_token_stream(); + let request_type = syn_method.request_type(); + let response_type = syn_method.response_type(context); - let response_type = if syn_method.server_streaming { - let response_type = &syn_method.response_type; - quote! { #stream_type<#response_type> } - } else { - syn_method.response_type.to_token_stream() - }; let bound = quote! { tower::Service<#request_type, Response = #response_type, Error = #error_type, Future = BoxFuture<#response_type, #error_type>> }; @@ -912,26 +866,20 @@ fn generate_grpc_client_adapter(context: &CodegenContext) -> TokenStream { fn generate_grpc_client_adapter_methods(context: &CodegenContext) -> TokenStream { let result_type = &context.result_type; - let stream_type = &context.stream_type; let mut stream = TokenStream::new(); for syn_method in &context.methods { let method_name = syn_method.name.to_token_stream(); - let request_type = syn_method.request_type.to_token_stream(); + let request_type = syn_method.request_type(); + let response_type = syn_method.response_type(context); - let response_type = if syn_method.server_streaming { - let response_type = &syn_method.response_type; - quote! { #stream_type<#response_type> } - } else { - syn_method.response_type.to_token_stream() - }; let into_response_type = if syn_method.server_streaming { quote! { |response| { - let stream = response.into_inner(); - let service_stream = quickwit_common::ServiceStream::from(stream); - service_stream.map_err(|error| error.into()) + let streaming: tonic::Streaming<_> = response.into_inner(); + let stream = quickwit_common::ServiceStream::from(streaming); + stream.map_err(|error| error.into()) } } } else { @@ -985,8 +933,22 @@ fn generate_grpc_server_adapter_methods(context: &CodegenContext) -> TokenStream for syn_method in &context.methods { let method_name = syn_method.name.to_token_stream(); - let request_type = syn_method.request_type.to_token_stream(); - + let request_type = if syn_method.client_streaming { + let request_type = &syn_method.request_type; + quote! { tonic::Streaming<#request_type> } + } else { + syn_method.request_type.to_token_stream() + }; + let method_arg = if syn_method.client_streaming { + quote! { + { + let streaming: tonic::Streaming<_> = request.into_inner(); + quickwit_common::ServiceStream::from(streaming) + } + } + } else { + quote! { request.into_inner() } + }; let response_type = if syn_method.server_streaming { let associated_type_name = quote::format_ident!("{}Stream", syn_method.proto_name); quote! { Self::#associated_type_name } @@ -996,7 +958,7 @@ fn generate_grpc_server_adapter_methods(context: &CodegenContext) -> TokenStream let associated_type = if syn_method.server_streaming { let associated_type_name = quote::format_ident!("{}Stream", syn_method.proto_name); let response_type = &syn_method.response_type; - quote! { type #associated_type_name = quickwit_common::ServiceStream<#response_type, tonic::Status>; } + quote! { type #associated_type_name = quickwit_common::ServiceStream>; } } else { TokenStream::new() }; @@ -1013,7 +975,7 @@ fn generate_grpc_server_adapter_methods(context: &CodegenContext) -> TokenStream async fn #method_name(&self, request: tonic::Request<#request_type>) -> Result, tonic::Status> { self.inner .clone() - .#method_name(request.into_inner()) + .#method_name(#method_arg) .await .map(#into_response_type) .map_err(|error| error.into()) diff --git a/quickwit/quickwit-common/src/stream_utils.rs b/quickwit/quickwit-common/src/stream_utils.rs index f68edd42eeb..70d46f3b0ae 100644 --- a/quickwit/quickwit-common/src/stream_utils.rs +++ b/quickwit/quickwit-common/src/stream_utils.rs @@ -17,37 +17,52 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . +use std::any::TypeId; +use std::fmt; use std::pin::Pin; -use futures::{Stream, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; use tokio::sync::mpsc; use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream}; +use tracing::warn; pub type BoxStream = Pin + Send + Unpin + 'static>>; /// A stream impl for code-generated services with streaming endpoints. -pub struct ServiceStream { - inner: BoxStream>, +pub struct ServiceStream { + inner: BoxStream, } -impl Unpin for ServiceStream {} +impl fmt::Debug for ServiceStream +where T: 'static +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ServiceStream<{:?}>", TypeId::of::()) + } +} -impl ServiceStream -where - T: Send + 'static, - E: Send + 'static, +impl Unpin for ServiceStream {} + +impl ServiceStream +where T: Send + 'static { - pub fn new_bounded(capacity: usize) -> (mpsc::Sender>, Self) { + pub fn new_bounded(capacity: usize) -> (mpsc::Sender, Self) { let (sender, receiver) = mpsc::channel(capacity); (sender, receiver.into()) } - pub fn new_unbounded() -> (mpsc::UnboundedSender>, Self) { + pub fn new_unbounded() -> (mpsc::UnboundedSender, Self) { let (sender, receiver) = mpsc::unbounded_channel(); (sender, receiver.into()) } +} - pub fn map_err(self, f: F) -> ServiceStream +impl ServiceStream> +where + T: Send + 'static, + E: Send + 'static, +{ + pub fn map_err(self, f: F) -> ServiceStream> where F: FnMut(E) -> U + Send + 'static, U: Send + 'static, @@ -58,8 +73,8 @@ where } } -impl Stream for ServiceStream { - type Item = Result; +impl Stream for ServiceStream { + type Item = T; fn poll_next( mut self: std::pin::Pin<&mut Self>, @@ -69,31 +84,30 @@ impl Stream for ServiceStream { } } -impl From>> for ServiceStream -where - T: Send + 'static, - E: Send + 'static, +impl From> for ServiceStream +where T: Send + 'static { - fn from(receiver: mpsc::Receiver>) -> Self { + fn from(receiver: mpsc::Receiver) -> Self { Self { inner: Box::pin(ReceiverStream::new(receiver)), } } } -impl From>> for ServiceStream -where - T: Send + 'static, - E: Send + 'static, +impl From> for ServiceStream +where T: Send + 'static { - fn from(receiver: mpsc::UnboundedReceiver>) -> Self { + fn from(receiver: mpsc::UnboundedReceiver) -> Self { Self { inner: Box::pin(UnboundedReceiverStream::new(receiver)), } } } -impl From> for ServiceStream +/// Adapts a server-side tonic::Streaming into a ServiceStream of `Result`. Once +/// an error is encountered, the stream will be closed and subsequent calls to `poll_next` will +/// return `None`. +impl From> for ServiceStream> where T: Send + 'static { fn from(streaming: tonic::Streaming) -> Self { @@ -102,3 +116,26 @@ where T: Send + 'static } } } + +/// Adapts a client-side tonic::Streaming into a ServiceStream of `T`. Once an error is encountered, +/// the stream will be closed and subsequent calls to `poll_next` will return `None`. +impl From> for ServiceStream +where T: Send + 'static +{ + fn from(streaming: tonic::Streaming) -> Self { + let ok_streaming = streaming.filter_map(|message| { + Box::pin(async move { + message + .map_err(|status| { + warn!(status=?status, "gRPC transport error."); + status + }) + .ok() + }) + }); + + Self { + inner: Box::pin(ok_streaming), + } + } +} diff --git a/quickwit/quickwit-control-plane/src/codegen/control_plane_service.rs b/quickwit/quickwit-control-plane/src/codegen/control_plane_service.rs index a4ef56d2f56..3824b963d94 100644 --- a/quickwit/quickwit-control-plane/src/codegen/control_plane_service.rs +++ b/quickwit/quickwit-control-plane/src/codegen/control_plane_service.rs @@ -7,6 +7,7 @@ pub struct NotifyIndexChangeRequest {} #[derive(Clone, PartialEq, ::prost::Message)] pub struct NotifyIndexChangeResponse {} /// BEGIN quickwit-codegen +use tower::{Layer, Service, ServiceExt}; #[cfg_attr(any(test, feature = "testsuite"), mockall::automock)] #[async_trait::async_trait] pub trait ControlPlaneService: std::fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static { @@ -281,13 +282,12 @@ impl Clone for ControlPlaneServiceMailbox { Self { inner } } } -use tower::{Layer, Service, ServiceExt}; impl tower::Service for ControlPlaneServiceMailbox where A: quickwit_actors::Actor + quickwit_actors::DeferableReplyHandler> + Send + 'static, - M: std::fmt::Debug + Send + Sync + 'static, + M: std::fmt::Debug + Send + 'static, T: Send + 'static, E: std::fmt::Debug + Send + 'static, crate::ControlPlaneError: From>, @@ -315,7 +315,7 @@ where #[async_trait::async_trait] impl ControlPlaneService for ControlPlaneServiceMailbox where - A: quickwit_actors::Actor + std::fmt::Debug + Send + Sync + 'static, + A: quickwit_actors::Actor + std::fmt::Debug, ControlPlaneServiceMailbox< A, >: tower::Service< diff --git a/quickwit/quickwit-ingest/src/codegen/ingest_service.rs b/quickwit/quickwit-ingest/src/codegen/ingest_service.rs index 6ff6df0c637..4fff2e4e56e 100644 --- a/quickwit/quickwit-ingest/src/codegen/ingest_service.rs +++ b/quickwit/quickwit-ingest/src/codegen/ingest_service.rs @@ -115,6 +115,7 @@ pub struct ListQueuesResponse { pub queues: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } /// BEGIN quickwit-codegen +use tower::{Layer, Service, ServiceExt}; #[cfg_attr(any(test, feature = "testsuite"), mockall::automock)] #[async_trait::async_trait] pub trait IngestService: std::fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static { @@ -508,13 +509,12 @@ impl Clone for IngestServiceMailbox { Self { inner } } } -use tower::{Layer, Service, ServiceExt}; impl tower::Service for IngestServiceMailbox where A: quickwit_actors::Actor + quickwit_actors::DeferableReplyHandler> + Send + 'static, - M: std::fmt::Debug + Send + Sync + 'static, + M: std::fmt::Debug + Send + 'static, T: Send + 'static, E: std::fmt::Debug + Send + 'static, crate::IngestServiceError: From>, @@ -542,7 +542,7 @@ where #[async_trait::async_trait] impl IngestService for IngestServiceMailbox where - A: quickwit_actors::Actor + std::fmt::Debug + Send + Sync + 'static, + A: quickwit_actors::Actor + std::fmt::Debug, IngestServiceMailbox< A, >: tower::Service<