Skip to content

Commit

Permalink
Handle user join and leave
Browse files Browse the repository at this point in the history
  • Loading branch information
Enrico Marconi committed Aug 3, 2023
1 parent 1ac5f67 commit 53d4987
Show file tree
Hide file tree
Showing 13 changed files with 115 additions and 71 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ sqlx = { version = "0.6.3", default-features = false, features = [
tokio = { version = "1.29.1", features = ["full"] }
tower = "0.4.13"
tower-http = { version = "0.4.3", features = ["trace", "catch-panic"] }
tracing = "0.1.37"
tracing = { version = "0.1.37", features = ["attributes"] }
tracing-log = "0.1.3"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
uuid = { version = "1.4.1", features = ["v4", "serde"] }

Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use coodo_be::{settings::get_settings, startup::Application, telemetry};

#[tokio::main(flavor = "multi_thread")]
async fn main() -> anyhow::Result<()> {
telemetry::init();
telemetry::init_with_filter("info");

let settings = get_settings().context("Failed to parse app settings")?;
let app = Application::build(settings)
Expand Down
34 changes: 31 additions & 3 deletions src/routes/todo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@ async fn create_todo_list(
Ok(Json(todo_list.id()))
}

// #[derive(Serialize, Debug)]
// struct TodoListInfo {
// name: String,
// id: Uuid,
// }

// async fn get_users_todo_lists(
// session: ReadableSession,
// State(state): State<AppState>
// ) -> Json<TodoListInfo> {

// }

async fn join_todo_list(
session: ReadableSession,
Path(todo_id): Path<Uuid>,
Expand All @@ -53,7 +66,9 @@ async fn join_todo_list(
return (StatusCode::UNAUTHORIZED, "Establish a session first").into_response();
};
if let Ok((todo_watch, command_tx)) = state.join_todo_list(todo_id, *user.id()).await {
ws.on_upgrade(move |socket| ws_handler(socket, todo_id, todo_watch, command_tx, user))
ws.on_upgrade(move |socket| {
ws_handler(socket, state, todo_id, todo_watch, command_tx, user)
})
} else {
(
StatusCode::INTERNAL_SERVER_ERROR,
Expand All @@ -73,14 +88,24 @@ async fn join_todo_list(
)]
async fn ws_handler(
ws: WebSocket,
state: AppState,
todo_list_id: Uuid,
mut todo: TodoListWatcher,
command_tx: TodoCommandSender,
user: User,
) {
let (mut ws_tx, mut ws_rx) = ws.split();
if let Err(e) = send_todo_list(&mut todo, &mut ws_tx).await {
tracing::error!("{}", e);
if command_tx
.send(Command::UserJoin(user.clone()).with_issuer(*user.id()))
.await
.is_err()
{
tracing::error!(
"User {} failed to join todo list {}",
user.id(),
todo_list_id
);
return;
}

loop {
Expand All @@ -95,6 +120,9 @@ async fn ws_handler(
},
else => {
let _ = ws_tx.close().await;
let _ = command_tx.send(Command::UserLeave(user.clone()).with_issuer(*user.id())).await;
state.leave_todo_list(todo_list_id, *user.id()).await;

break;
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ impl DbSettings {
}

pub fn with_db(&self) -> PgConnectOptions {
use tracing::log::LevelFilter;

let mut options = self.without_db().database(&self.name);
options.log_statements(axum_sessions::async_session::log::LevelFilter::Trace);
options.log_statements(LevelFilter::Trace);

options
}
Expand Down
14 changes: 13 additions & 1 deletion src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,19 @@ impl AppState {
.or_insert(
TodoListHandle::spawn(todo, self.db_pool.clone(), self.store_interval).await?,
)
.connect(user_id)
.get_connection(user_id)
.context("Failed to connect to given todo list")
}

pub async fn leave_todo_list(&self, todo: Uuid, user_id: Uuid) {
let mut todo_lists = self.todo_lists.write().await;
let mut empty = false;
if let Some(handle) = todo_lists.get_mut(&todo) {
handle.disconnect_user(user_id);
empty = handle.is_empty();
}
if empty {
todo_lists.remove(&todo);
}
}
}
17 changes: 7 additions & 10 deletions src/telemetry.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
use tracing_subscriber::{fmt, prelude::*, EnvFilter};

pub fn init() {
init_with_filter(std::env::var("RUST_LOG").as_deref().unwrap_or("info"))
}

pub fn init_with_filter(filter: &str) {
let fmt_layer = fmt::layer().with_target(false).compact();
let filter_layer = EnvFilter::try_new(filter)
.or_else(|_| EnvFilter::try_new("info"))
.unwrap();
let filter_layer = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(filter));

tracing_subscriber::registry()
//LogTracer::init_with_filter(tracing::log::LevelFilter::Info).expect("Failed to set logger");

let subscriber = tracing_subscriber::registry()
.with(filter_layer)
.with(fmt_layer)
.init();
.with(fmt_layer);

tracing::subscriber::set_global_default(subscriber).expect("Failed to set global subscriber");
}
8 changes: 8 additions & 0 deletions src/todo/command.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use crate::user::User;

use super::list::{TodoList, TodoTask};

pub trait Applicable {
Expand All @@ -18,6 +20,9 @@ pub struct TodoCommand {
pub enum Command {
TaskCommand(TaskCommandMeta),
CreateTask,
UserJoin(User),
UserLeave(User),
SetListName(String),
}

impl Command {
Expand All @@ -34,6 +39,9 @@ impl Applicable for Command {
match self {
Command::TaskCommand(task_command) => task_command.apply(todo, issuer),
Command::CreateTask => todo.add_task(TodoTask::new(issuer)),
Command::UserJoin(user) => todo.add_user(user),
Command::UserLeave(user) => todo.remove_user(*user.id()),
Command::SetListName(name) => todo.rename(name),
}
}
}
Expand Down
10 changes: 9 additions & 1 deletion src/todo/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,19 @@ impl TodoListHandle {
})
}

pub fn connect(&mut self, user: Uuid) -> Option<(TodoListWatcher, TodoCommandSender)> {
pub fn get_connection(&mut self, user: Uuid) -> Option<(TodoListWatcher, TodoCommandSender)> {
self.connected_users
.insert(user)
.then(|| (self.todo_watcher.clone(), self.command_tx.clone()))
}

pub fn disconnect_user(&mut self, user: Uuid) {
self.connected_users.remove(&user);
}

pub fn is_empty(&self) -> bool {
self.connected_users.is_empty()
}
}

impl Drop for TodoListHandle {
Expand Down
21 changes: 21 additions & 0 deletions src/todo/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use chrono::{DateTime, Utc};
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
use uuid::Uuid;

use crate::user::User;

#[derive(Debug, serde::Serialize, Clone, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TodoList {
Expand All @@ -10,6 +12,7 @@ pub struct TodoList {
tasks: Vec<TodoTask>,
created_at: DateTime<Utc>,
last_updated_at: DateTime<Utc>,
connected_users: Vec<User>,
}

impl Default for TodoList {
Expand All @@ -20,6 +23,7 @@ impl Default for TodoList {
tasks: vec![],
created_at: Utc::now(),
last_updated_at: Utc::now(),
connected_users: vec![],
}
}
}
Expand Down Expand Up @@ -60,13 +64,18 @@ WHERE id = $1
tasks,
created_at: record.created_at,
last_updated_at: record.last_updated_at,
connected_users: vec![],
})
}

pub const fn id(&self) -> Uuid {
self.id
}

pub fn rename(&mut self, name: String) {
self.name = name;
}

fn update_time(&mut self) {
self.last_updated_at = Utc::now();
}
Expand Down Expand Up @@ -117,6 +126,18 @@ INSERT INTO todo_lists (id, name, created_at, last_updated_at)

Ok(())
}

pub fn connected_users(&self) -> &[User] {
&self.connected_users[..]
}

pub fn add_user(&mut self, user: User) {
self.connected_users.push(user);
}

pub fn remove_user(&mut self, id: Uuid) {
self.connected_users.retain(|user| user.id() != &id);
}
}

#[derive(Debug, Clone, FromRow, serde::Serialize, serde::Deserialize)]
Expand Down
4 changes: 2 additions & 2 deletions src/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl Default for UserHandleGenerator {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct UserHandle(String);

impl AsRef<str> for UserHandle {
Expand All @@ -65,7 +65,7 @@ impl UserHandle {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct User {
id: Uuid,
handle: UserHandle,
Expand Down
45 changes: 3 additions & 42 deletions tests/api/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use core::task::Context as Ctx;
use std::{net::TcpListener, pin::Pin, task::Poll, time::Duration};
use std::{net::TcpListener, time::Duration};

use anyhow::Context;
use coodo_be::{settings::TodoHandlerSettings, telemetry, user::User};
use futures_util::{
stream::{SplitSink, SplitStream},
Future, Stream, StreamExt,
StreamExt,
};
use once_cell::sync::Lazy;
use reqwest::{cookie::Jar, Client};
use sqlx::PgPool;
use tokio::{net::TcpStream, task::JoinHandle, time::Interval};
use tokio::{net::TcpStream, task::JoinHandle};
use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
use uuid::Uuid;

Expand Down Expand Up @@ -101,44 +100,6 @@ impl TestApp {
}
}

pub trait StreamExtTimed: StreamExt {
fn next_with_timeout(&mut self, timeout: Duration) -> TimedNext<'_, Self>
where
Self: Unpin,
{
TimedNext::new(self, timeout)
}
}

impl<T: ?Sized> StreamExtTimed for T where T: StreamExt {}

pub struct TimedNext<'a, S: ?Sized> {
stream: &'a mut S,
timeout: Interval,
}

impl<S: ?Sized + Unpin> Unpin for TimedNext<'_, S> {}

impl<'a, S: ?Sized + Stream + Unpin> TimedNext<'a, S> {
pub fn new(stream: &'a mut S, timeout: Duration) -> Self {
Self {
stream,
timeout: tokio::time::interval(timeout),
}
}
}

impl<S: ?Sized + Stream + Unpin> Future for TimedNext<'_, S> {
type Output = Option<S::Item>;

fn poll(mut self: Pin<&mut Self>, ctx: &mut Ctx<'_>) -> Poll<Self::Output> {
match self.timeout.poll_tick(ctx) {
Poll::Pending => self.stream.poll_next_unpin(ctx),
Poll::Ready(_) => Poll::Ready(None),
}
}
}

fn get_sid(cookie_jar: &Jar) -> anyhow::Result<String> {
use reqwest::cookie::CookieStore;

Expand Down
Loading

0 comments on commit 53d4987

Please sign in to comment.