diff --git a/src/expect_interface.rs b/src/expect_interface.rs index a56b2aa..dccc74e 100644 --- a/src/expect_interface.rs +++ b/src/expect_interface.rs @@ -51,7 +51,7 @@ impl<'a> ExpectGetBufferBytes<'a> { } } - pub fn returning(&mut self, buffer_data: Option<&str>) -> &mut Tester { + pub fn returning(&mut self, buffer_data: Option<&[u8]>) -> &mut Tester { self.tester .get_expect_handle() .staged @@ -150,3 +150,48 @@ impl<'a> ExpectHttpCall<'a> { self.tester } } + +pub struct ExpectGrpcCall<'a> { + tester: &'a mut Tester, + service: Option<&'a str>, + service_name: Option<&'a str>, + method_name: Option<&'a str>, + initial_metadata: Option<&'a [u8]>, + request: Option<&'a [u8]>, + timeout: Option, +} + +impl<'a> ExpectGrpcCall<'a> { + pub fn expecting( + tester: &'a mut Tester, + service: Option<&'a str>, + service_name: Option<&'a str>, + method_name: Option<&'a str>, + initial_metadata: Option<&'a [u8]>, + request: Option<&'a [u8]>, + timeout: Option, + ) -> Self { + Self { + tester, + service, + service_name, + method_name, + initial_metadata, + request, + timeout, + } + } + + pub fn returning(&mut self, token_id: Option) -> &mut Tester { + self.tester.get_expect_handle().staged.set_expect_grpc_call( + self.service, + self.service_name, + self.method_name, + self.initial_metadata, + self.request, + self.timeout, + token_id, + ); + self.tester + } +} diff --git a/src/expectations.rs b/src/expectations.rs index f2bf2dc..32b3abb 100644 --- a/src/expectations.rs +++ b/src/expectations.rs @@ -41,19 +41,20 @@ impl ExpectHandle { self.staged = Expect::new(allow_unexpected); } - pub fn assert_stage(&self) { + pub fn assert_stage(&self) -> Option { if self.staged.expect_count > 0 { - panic!( + return Some(format!( "Error: failed to consume all expectations - total remaining: {}", self.staged.expect_count - ); + )); } else if self.staged.expect_count < 0 { - panic!( + return Some(format!( "Error: expectations failed to account for all host calls by {} \n\ if this is intended, please use --allow-unexpected (-a) mode", -1 * self.staged.expect_count - ); + )); } + None } pub fn print_staged(&self) { @@ -86,6 +87,15 @@ pub struct Expect { Option, Option, )>, + grpc_call: Vec<( + Option, + Option, + Option, + Option, + Option, + Option, + Option, + )>, } impl Expect { @@ -106,6 +116,7 @@ impl Expect { add_header_map_value: vec![], send_local_response: vec![], http_call: vec![], + grpc_call: vec![], } } @@ -190,13 +201,11 @@ impl Expect { pub fn set_expect_get_buffer_bytes( &mut self, buffer_type: Option, - buffer_data: Option<&str>, + buffer_data: Option<&[u8]>, ) { self.expect_count += 1; - self.get_buffer_bytes.push(( - buffer_type, - buffer_data.map(|data| data.as_bytes().to_vec()), - )); + self.get_buffer_bytes + .push((buffer_type, buffer_data.map(|data| data.to_vec()))); } pub fn get_expect_get_buffer_bytes(&mut self, buffer_type: i32) -> Option { @@ -571,4 +580,73 @@ impl Expect { } } } + + pub fn set_expect_grpc_call( + &mut self, + service: Option<&str>, + service_name: Option<&str>, + method_name: Option<&str>, + initial_metadata: Option<&[u8]>, + request: Option<&[u8]>, + timeout: Option, + token_id: Option, + ) { + self.expect_count += 1; + self.grpc_call.push(( + service.map(ToString::to_string), + service_name.map(ToString::to_string), + method_name.map(ToString::to_string), + initial_metadata.map(|s| s.to_vec()), + request.map(|s| s.to_vec()), + timeout.map(Duration::from_millis), + token_id, + )); + } + + pub fn get_expect_grpc_call( + &mut self, + service: String, + service_name: String, + method: String, + initial_metadata: &[u8], + request: &[u8], + timeout: i32, + ) -> Option { + match self.grpc_call.len() { + 0 => { + if !self.allow_unexpected { + self.expect_count -= 1; + } + set_status(ExpectStatus::Unexpected); + None + } + _ => { + self.expect_count -= 1; + let ( + expected_service, + expected_service_name, + expected_method, + expected_initial_metadata, + expected_request, + expected_duration, + result, + ) = self.grpc_call.remove(0); + + let expected = expected_service.map(|e| e == service).unwrap_or(true) + && expected_service_name + .map(|e| e == service_name) + .unwrap_or(true) + && expected_method.map(|e| e == method).unwrap_or(true) + && expected_initial_metadata + .map(|e| e == initial_metadata) + .unwrap_or(true) + && expected_request.map(|e| e == request).unwrap_or(true) + && expected_duration + .map(|e| e.as_millis() as i32 == timeout) + .unwrap_or(true); + set_expect_status(expected); + return result; + } + } + } } diff --git a/src/host_settings.rs b/src/host_settings.rs index b315f9f..8ec6c8e 100644 --- a/src/host_settings.rs +++ b/src/host_settings.rs @@ -294,5 +294,9 @@ pub fn default_buffer_bytes() -> HashMap { BufferType::HttpCallResponseBody as i32, "default_call_response_body".as_bytes().to_vec(), ); + default_bytes.insert( + BufferType::GrpcReceiveBuffer as i32, + "default_grpc_receive_buffer".as_bytes().to_vec(), + ); default_bytes } diff --git a/src/hostcalls.rs b/src/hostcalls.rs index 3323f82..1f2534f 100644 --- a/src/hostcalls.rs +++ b/src/hostcalls.rs @@ -1390,20 +1390,67 @@ fn get_hostfunc( "proxy_grpc_call" => { Some(Func::wrap( store, - |_caller: Caller<'_, ()>, - _service_ptr: i32, - _service_size: i32, - _service_name_ptr: i32, - _service_name_size: i32, - _method_name_ptr: i32, - _method_name_size: i32, - _initial_metadata_ptr: i32, - _initial_metadata_size: i32, - _request_ptr: i32, - _request_size: i32, - _timeout_milliseconds: i32, - _token_ptr: i32| + |mut caller: Caller<'_, ()>, + service_ptr: i32, + service_size: i32, + service_name_ptr: i32, + service_name_size: i32, + method_name_ptr: i32, + method_name_size: i32, + initial_metadata_ptr: i32, + initial_metadata_size: i32, + request_ptr: i32, + request_size: i32, + timeout_milliseconds: i32, + token_ptr: i32| -> i32 { + print!("[vm->host] proxy_grpc_call({initial_metadata_ptr}, {initial_metadata_size})"); + + // Default Function: receives and displays http call from proxy-wasm module + // Expectation: asserts equal the receieved http call with the expected one + let mem = match caller.get_export("memory") { + Some(Extern::Memory(mem)) => mem, + _ => { + println!("Error: proxy_http_call cannot get export \"memory\""); + println!( + "[vm<-host] proxy_http_call(...) -> (return_token) return: {:?}", + Status::InternalFailure + ); + return Status::InternalFailure as i32; + } + }; + + let service = read_string(&caller, mem, service_ptr, service_size); + let service_name = + read_string(&caller, mem, service_name_ptr, service_name_size); + let method_name = read_string(&caller, mem, method_name_ptr, method_name_size); + let initial_metadata = + read_bytes(&caller, mem, initial_metadata_ptr, initial_metadata_size) + .unwrap(); + let request = read_bytes(&caller, mem, request_ptr, request_size).unwrap(); + + println!( + "[vm->host] proxy_grpc_call(service={service}, service_name={service_name}, method_name={method_name}, initial_metadata={initial_metadata:?}, request={request:?}, timeout={timeout_milliseconds}"); + + let token_id = match EXPECT.lock().unwrap().staged.get_expect_grpc_call( + service, + service_name, + method_name, + initial_metadata, + request, + timeout_milliseconds, + ) { + Some(expect_token) => expect_token, + None => 0, + }; + + unsafe { + let return_token_add = mem.data_mut(&mut caller).get_unchecked_mut( + token_ptr as u32 as usize..token_ptr as u32 as usize + 4, + ); + return_token_add.copy_from_slice(&token_id.to_le_bytes()); + } + // Default Function: // Expectation: println!( @@ -1412,9 +1459,10 @@ fn get_hostfunc( ); println!( "[vm<-host] proxy_grpc_call() -> (..) return: {:?}", - Status::InternalFailure + Status::Ok ); - return Status::InternalFailure as i32; + assert_ne!(get_status(), ExpectStatus::Failed); + return Status::Ok as i32; }, )) } @@ -1641,6 +1689,18 @@ fn get_hostfunc( } } +fn read_string(caller: &Caller<()>, mem: Memory, ptr: i32, size: i32) -> String { + read_bytes(caller, mem, ptr, size) + .map(String::from_utf8_lossy) + .unwrap() + .to_string() +} + +fn read_bytes<'a>(caller: &'a Caller<()>, mem: Memory, ptr: i32, size: i32) -> Option<&'a [u8]> { + mem.data(caller) + .get(ptr as usize..ptr as usize + size as usize) +} + pub mod serial_utils { type Bytes = Vec; diff --git a/src/tester.rs b/src/tester.rs index 9c74b40..eed8c0e 100644 --- a/src/tester.rs +++ b/src/tester.rs @@ -268,6 +268,26 @@ impl Tester { ExpectHttpCall::expecting(self, upstream, headers, body, trailers, timeout) } + pub fn expect_grpc_call( + &mut self, + service: Option<&'static str>, + service_name: Option<&'static str>, + method_name: Option<&'static str>, + initial_metadata: Option<&'static [u8]>, + request: Option<&'static [u8]>, + timeout: Option, + ) -> ExpectGrpcCall { + ExpectGrpcCall::expecting( + self, + service, + service_name, + method_name, + initial_metadata, + request, + timeout, + ) + } + /* ------------------------------------- High-level Expectation Setting ------------------------------------- */ pub fn set_quiet(&mut self, quiet: bool) { @@ -323,7 +343,10 @@ impl Tester { } fn assert_expect_stage(&mut self) { - self.expect.lock().unwrap().assert_stage(); + let err = self.expect.lock().unwrap().assert_stage(); + if let Some(msg) = err { + panic!("{}", msg) + } } pub fn get_settings_handle(&self) -> MutexGuard { diff --git a/src/types.rs b/src/types.rs index 7921713..72f0874 100644 --- a/src/types.rs +++ b/src/types.rs @@ -88,6 +88,9 @@ pub enum BufferType { DownstreamData = 2, UpstreamData = 3, HttpCallResponseBody = 4, + GrpcReceiveBuffer = 5, + VmConfiguration = 6, + PluginConfiguration = 7, } #[repr(u32)]