Skip to content

Commit

Permalink
feat: send an error when a request is made with an unknown certificate
Browse files Browse the repository at this point in the history
  • Loading branch information
dufkan committed May 10, 2024
1 parent bb498e6 commit 971f1a8
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 36 deletions.
79 changes: 45 additions & 34 deletions src/interfaces/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ impl MeeSignService {
pub fn new(state: Arc<Mutex<State>>) -> Self {
MeeSignService { state }
}

async fn check_client_auth(
&self,
certs: &Option<Arc<Vec<Certificate>>>,
required: bool,
) -> Result<(), Status> {
if let Some(certs) = certs {
let device_id = certs.get(0).map(cert_to_id).unwrap_or(vec![]);
if !self.state.lock().await.device_activated(&device_id) {
return Err(Status::unauthenticated("Unknown device certificate"));
}
} else if required {
return Err(Status::unauthenticated("Authentication required"));
}
Ok(())
}
}

#[tonic::async_trait]
Expand All @@ -43,13 +59,9 @@ impl MeeSign for MeeSignService {
&self,
request: Request<msg::ServerInfoRequest>,
) -> Result<Response<msg::ServerInfo>, Status> {
self.check_client_auth(&request.peer_certs(), false).await?;

debug!("ServerInfoRequest");
if let Some(certs) = request.peer_certs() {
let device_id = certs.get(0).map(cert_to_id).unwrap_or(vec![]);
if !self.state.lock().await.device_activated(&device_id) {
return Err(Status::unauthenticated("Unknown device certificate"));
}
}
Ok(Response::new(msg::ServerInfo {
version: crate::VERSION.unwrap_or("unknown").to_string(),
}))
Expand All @@ -59,6 +71,8 @@ impl MeeSign for MeeSignService {
&self,
request: Request<msg::RegistrationRequest>,
) -> Result<Response<msg::RegistrationResponse>, Status> {
self.check_client_auth(&request.peer_certs(), false).await?;

let request = request.into_inner();
let name = request.name;
let kind = DeviceKind::try_from(request.kind).unwrap();
Expand Down Expand Up @@ -90,6 +104,8 @@ impl MeeSign for MeeSignService {
&self,
request: Request<msg::SignRequest>,
) -> Result<Response<msg::Task>, Status> {
self.check_client_auth(&request.peer_certs(), false).await?;

let request = request.into_inner();
let group_id = request.group_id;
let name = request.name;
Expand All @@ -109,6 +125,8 @@ impl MeeSign for MeeSignService {
&self,
request: Request<msg::DecryptRequest>,
) -> Result<Response<msg::Task>, Status> {
self.check_client_auth(&request.peer_certs(), false).await?;

let request = request.into_inner();
let group_id = request.group_id;
let name = request.name;
Expand All @@ -129,6 +147,8 @@ impl MeeSign for MeeSignService {
&self,
request: Request<msg::TaskRequest>,
) -> Result<Response<msg::Task>, Status> {
self.check_client_auth(&request.peer_certs(), false).await?;

let request = request.into_inner();
let task_id = Uuid::from_slice(&request.task_id).unwrap();
let device_id = request.device_id;
Expand All @@ -144,9 +164,6 @@ impl MeeSign for MeeSignService {
);

let state = self.state.lock().await;
if device_id.is_some() {
state.device_activated(device_id.as_ref().unwrap());
}
let task = state.get_task(&task_id).unwrap();
let request = Some(task.get_request());

Expand All @@ -158,9 +175,8 @@ impl MeeSign for MeeSignService {
&self,
request: Request<msg::TaskUpdate>,
) -> Result<Response<msg::Resp>, Status> {
if request.peer_certs().is_none() {
return Err(Status::unauthenticated("Authentication required"));
}
self.check_client_auth(&request.peer_certs(), true).await?;

let device_id = request
.peer_certs()
.and_then(|certs| certs.get(0).map(cert_to_id))
Expand All @@ -187,7 +203,6 @@ impl MeeSign for MeeSignService {
);

let mut state = self.state.lock().await;
state.device_activated(&device_id);
let result = state.update_task(&task_id, &device_id, &data, attempt);

match result {
Expand All @@ -202,6 +217,8 @@ impl MeeSign for MeeSignService {
&self,
request: Request<msg::TasksRequest>,
) -> Result<Response<msg::Tasks>, Status> {
self.check_client_auth(&request.peer_certs(), false).await?;

let request = request.into_inner();
let device_id = request.device_id;
let device_str = device_id
Expand All @@ -212,7 +229,6 @@ impl MeeSign for MeeSignService {

let state = self.state.lock().await;
let tasks = if let Some(device_id) = device_id {
state.device_activated(&device_id);
state
.get_device_tasks(&device_id)
.iter()
Expand All @@ -233,6 +249,8 @@ impl MeeSign for MeeSignService {
&self,
request: Request<msg::GroupsRequest>,
) -> Result<Response<msg::Groups>, Status> {
self.check_client_auth(&request.peer_certs(), false).await?;

let request = request.into_inner();
let device_id = request.device_id;
let device_str = device_id
Expand All @@ -243,7 +261,6 @@ impl MeeSign for MeeSignService {

let state = self.state.lock().await;
let groups = if let Some(device_id) = device_id {
state.device_activated(&device_id);
state
.get_device_groups(&device_id)
.iter()
Expand All @@ -264,6 +281,8 @@ impl MeeSign for MeeSignService {
&self,
request: Request<msg::GroupRequest>,
) -> Result<Response<msg::Task>, Status> {
self.check_client_auth(&request.peer_certs(), false).await?;

let request = request.into_inner();
let name = request.name;
let device_ids = request.device_ids;
Expand Down Expand Up @@ -295,8 +314,10 @@ impl MeeSign for MeeSignService {

async fn get_devices(
&self,
_request: Request<msg::DevicesRequest>,
request: Request<msg::DevicesRequest>,
) -> Result<Response<msg::Devices>, Status> {
self.check_client_auth(&request.peer_certs(), false).await?;

debug!("DevicesRequest");

let resp = msg::Devices {
Expand All @@ -313,6 +334,8 @@ impl MeeSign for MeeSignService {
}

async fn log(&self, request: Request<msg::LogRequest>) -> Result<Response<msg::Resp>, Status> {
self.check_client_auth(&request.peer_certs(), false).await?;

let device_id = request
.peer_certs()
.and_then(|certs| certs.get(0).map(cert_to_id));
Expand All @@ -324,13 +347,6 @@ impl MeeSign for MeeSignService {
let message = request.into_inner().message.replace('\n', "\\n");
debug!("LogRequest device_id={} message={}", device_str, message);

if device_id.is_some() {
self.state
.lock()
.await
.device_activated(device_id.as_ref().unwrap());
}

Ok(Response::new(msg::Resp {
message: "OK".into(),
}))
Expand All @@ -340,9 +356,8 @@ impl MeeSign for MeeSignService {
&self,
request: Request<msg::TaskDecision>,
) -> Result<Response<msg::Resp>, Status> {
if request.peer_certs().is_none() {
return Err(Status::unauthenticated("Authentication required"));
}
self.check_client_auth(&request.peer_certs(), true).await?;

let device_id = request
.peer_certs()
.and_then(|certs| certs.get(0).map(cert_to_id))
Expand All @@ -362,7 +377,6 @@ impl MeeSign for MeeSignService {
let state = self.state.clone();
tokio::task::spawn(async move {
let mut state = state.lock().await;
state.device_activated(&device_id);
state.decide_task(&task_id, &device_id, accept);
});

Expand All @@ -375,9 +389,8 @@ impl MeeSign for MeeSignService {
&self,
request: Request<msg::TaskAcknowledgement>,
) -> Result<Response<msg::Resp>, Status> {
if request.peer_certs().is_none() {
return Err(Status::unauthenticated("Authentication required"));
}
self.check_client_auth(&request.peer_certs(), true).await?;

let device_id = request
.peer_certs()
.and_then(|certs| certs.get(0).map(cert_to_id))
Expand All @@ -392,7 +405,6 @@ impl MeeSign for MeeSignService {
);

let mut state = self.state.lock().await;
state.device_activated(&device_id);
state.acknowledge_task(&Uuid::from_slice(&task_id).unwrap(), &device_id);

Ok(Response::new(msg::Resp {
Expand All @@ -404,9 +416,8 @@ impl MeeSign for MeeSignService {
&self,
request: Request<msg::SubscribeRequest>,
) -> Result<Response<Self::SubscribeUpdatesStream>, Status> {
if request.peer_certs().is_none() {
return Err(Status::unauthenticated("Authentication required"));
}
self.check_client_auth(&request.peer_certs(), true).await?;

let device_id = request
.peer_certs()
.and_then(|certs| certs.get(0).map(cert_to_id))
Expand Down
4 changes: 2 additions & 2 deletions src/state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashMap;

use log::{debug, error, warn};
use log::{debug, warn};
use uuid::Uuid;

use crate::device::Device;
Expand Down Expand Up @@ -293,7 +293,7 @@ impl State {
device.activated();
true
} else {
error!("Unknown Device ID {}", utils::hextrunc(device_id));
debug!("Unknown Device ID {}", utils::hextrunc(device_id));
false
}
}
Expand Down

0 comments on commit 971f1a8

Please sign in to comment.