diff --git a/src/interfaces/grpc.rs b/src/interfaces/grpc.rs index d42cc36..8cd01ae 100644 --- a/src/interfaces/grpc.rs +++ b/src/interfaces/grpc.rs @@ -32,6 +32,22 @@ impl MeeSignService { pub fn new(state: Arc>) -> Self { MeeSignService { state } } + + async fn check_client_auth( + &self, + certs: &Option>>, + required: bool, + ) -> Result<(), Status> { + if let Some(certs) = certs { + let device_id = certs.get(0).map(cert_to_id).unwrap_or(vec![]); + if !self.state.lock().await.device_activated(&device_id) { + return Err(Status::unauthenticated("Unknown device certificate")); + } + } else if required { + return Err(Status::unauthenticated("Authentication required")); + } + Ok(()) + } } #[tonic::async_trait] @@ -43,13 +59,9 @@ impl MeeSign for MeeSignService { &self, request: Request, ) -> Result, Status> { + self.check_client_auth(&request.peer_certs(), false).await?; + debug!("ServerInfoRequest"); - if let Some(certs) = request.peer_certs() { - let device_id = certs.get(0).map(cert_to_id).unwrap_or(vec![]); - if !self.state.lock().await.device_activated(&device_id) { - return Err(Status::unauthenticated("Unknown device certificate")); - } - } Ok(Response::new(msg::ServerInfo { version: crate::VERSION.unwrap_or("unknown").to_string(), })) @@ -59,6 +71,8 @@ impl MeeSign for MeeSignService { &self, request: Request, ) -> Result, Status> { + self.check_client_auth(&request.peer_certs(), false).await?; + let request = request.into_inner(); let name = request.name; let kind = DeviceKind::try_from(request.kind).unwrap(); @@ -90,6 +104,8 @@ impl MeeSign for MeeSignService { &self, request: Request, ) -> Result, Status> { + self.check_client_auth(&request.peer_certs(), false).await?; + let request = request.into_inner(); let group_id = request.group_id; let name = request.name; @@ -109,6 +125,8 @@ impl MeeSign for MeeSignService { &self, request: Request, ) -> Result, Status> { + self.check_client_auth(&request.peer_certs(), false).await?; + let request = request.into_inner(); let group_id = request.group_id; let name = request.name; @@ -129,6 +147,8 @@ impl MeeSign for MeeSignService { &self, request: Request, ) -> Result, Status> { + self.check_client_auth(&request.peer_certs(), false).await?; + let request = request.into_inner(); let task_id = Uuid::from_slice(&request.task_id).unwrap(); let device_id = request.device_id; @@ -144,9 +164,6 @@ impl MeeSign for MeeSignService { ); let state = self.state.lock().await; - if device_id.is_some() { - state.device_activated(device_id.as_ref().unwrap()); - } let task = state.get_task(&task_id).unwrap(); let request = Some(task.get_request()); @@ -158,9 +175,8 @@ impl MeeSign for MeeSignService { &self, request: Request, ) -> Result, Status> { - if request.peer_certs().is_none() { - return Err(Status::unauthenticated("Authentication required")); - } + self.check_client_auth(&request.peer_certs(), true).await?; + let device_id = request .peer_certs() .and_then(|certs| certs.get(0).map(cert_to_id)) @@ -187,7 +203,6 @@ impl MeeSign for MeeSignService { ); let mut state = self.state.lock().await; - state.device_activated(&device_id); let result = state.update_task(&task_id, &device_id, &data, attempt); match result { @@ -202,6 +217,8 @@ impl MeeSign for MeeSignService { &self, request: Request, ) -> Result, Status> { + self.check_client_auth(&request.peer_certs(), false).await?; + let request = request.into_inner(); let device_id = request.device_id; let device_str = device_id @@ -212,7 +229,6 @@ impl MeeSign for MeeSignService { let state = self.state.lock().await; let tasks = if let Some(device_id) = device_id { - state.device_activated(&device_id); state .get_device_tasks(&device_id) .iter() @@ -233,6 +249,8 @@ impl MeeSign for MeeSignService { &self, request: Request, ) -> Result, Status> { + self.check_client_auth(&request.peer_certs(), false).await?; + let request = request.into_inner(); let device_id = request.device_id; let device_str = device_id @@ -243,7 +261,6 @@ impl MeeSign for MeeSignService { let state = self.state.lock().await; let groups = if let Some(device_id) = device_id { - state.device_activated(&device_id); state .get_device_groups(&device_id) .iter() @@ -264,6 +281,8 @@ impl MeeSign for MeeSignService { &self, request: Request, ) -> Result, Status> { + self.check_client_auth(&request.peer_certs(), false).await?; + let request = request.into_inner(); let name = request.name; let device_ids = request.device_ids; @@ -295,8 +314,10 @@ impl MeeSign for MeeSignService { async fn get_devices( &self, - _request: Request, + request: Request, ) -> Result, Status> { + self.check_client_auth(&request.peer_certs(), false).await?; + debug!("DevicesRequest"); let resp = msg::Devices { @@ -313,6 +334,8 @@ impl MeeSign for MeeSignService { } async fn log(&self, request: Request) -> Result, Status> { + self.check_client_auth(&request.peer_certs(), false).await?; + let device_id = request .peer_certs() .and_then(|certs| certs.get(0).map(cert_to_id)); @@ -324,13 +347,6 @@ impl MeeSign for MeeSignService { let message = request.into_inner().message.replace('\n', "\\n"); debug!("LogRequest device_id={} message={}", device_str, message); - if device_id.is_some() { - self.state - .lock() - .await - .device_activated(device_id.as_ref().unwrap()); - } - Ok(Response::new(msg::Resp { message: "OK".into(), })) @@ -340,9 +356,8 @@ impl MeeSign for MeeSignService { &self, request: Request, ) -> Result, Status> { - if request.peer_certs().is_none() { - return Err(Status::unauthenticated("Authentication required")); - } + self.check_client_auth(&request.peer_certs(), true).await?; + let device_id = request .peer_certs() .and_then(|certs| certs.get(0).map(cert_to_id)) @@ -362,7 +377,6 @@ impl MeeSign for MeeSignService { let state = self.state.clone(); tokio::task::spawn(async move { let mut state = state.lock().await; - state.device_activated(&device_id); state.decide_task(&task_id, &device_id, accept); }); @@ -375,9 +389,8 @@ impl MeeSign for MeeSignService { &self, request: Request, ) -> Result, Status> { - if request.peer_certs().is_none() { - return Err(Status::unauthenticated("Authentication required")); - } + self.check_client_auth(&request.peer_certs(), true).await?; + let device_id = request .peer_certs() .and_then(|certs| certs.get(0).map(cert_to_id)) @@ -392,7 +405,6 @@ impl MeeSign for MeeSignService { ); let mut state = self.state.lock().await; - state.device_activated(&device_id); state.acknowledge_task(&Uuid::from_slice(&task_id).unwrap(), &device_id); Ok(Response::new(msg::Resp { @@ -404,9 +416,8 @@ impl MeeSign for MeeSignService { &self, request: Request, ) -> Result, Status> { - if request.peer_certs().is_none() { - return Err(Status::unauthenticated("Authentication required")); - } + self.check_client_auth(&request.peer_certs(), true).await?; + let device_id = request .peer_certs() .and_then(|certs| certs.get(0).map(cert_to_id)) diff --git a/src/state.rs b/src/state.rs index c672964..bb14682 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use log::{debug, error, warn}; +use log::{debug, warn}; use uuid::Uuid; use crate::device::Device; @@ -293,7 +293,7 @@ impl State { device.activated(); true } else { - error!("Unknown Device ID {}", utils::hextrunc(device_id)); + debug!("Unknown Device ID {}", utils::hextrunc(device_id)); false } }