Skip to content

Commit

Permalink
Merge pull request #65 from polyphony-chat/bitfl0wer/issue64
Browse files Browse the repository at this point in the history
Issue #64
  • Loading branch information
bitfl0wer authored Oct 24, 2024
2 parents 08729f9 + bf5e74d commit 8341331
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 19 deletions.
34 changes: 22 additions & 12 deletions src/gateway/establish_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use tokio_tungstenite::{
};

use crate::gateway::ready::create_ready;
use crate::gateway::Event;
use crate::{
database::entities::Config,
errors::{Error, GatewayError},
Expand Down Expand Up @@ -155,8 +156,17 @@ async fn finish_connecting(
}
};
debug!(target: "symfonia::gateway::establish_connection::finish_connecting", "Received message");

if let Ok(heartbeat) = from_str::<GatewayHeartbeat>(&raw_message.to_string()) {
trace!("Message: {}", raw_message);
let event = match Event::try_from(raw_message.clone()) {
Ok(event) => event,
Err(e) => {
log::debug!("Message could not be deserialized to Event: {e}");
return Err(Error::Gateway(GatewayError::UnexpectedMessage(
e.to_string(),
)));
}
};
if let Event::Heartbeat(heartbeat) = event {
log::trace!(target: "symfonia::gateway::establish_connection::finish_connecting", "Received heartbeat");
match heartbeat_handler_handle {
None => {
Expand Down Expand Up @@ -185,13 +195,11 @@ async fn finish_connecting(
state.heartbeat_send.send(heartbeat);
}
}
} else if let Ok(identify) =
from_str::<GatewayPayload<GatewayIdentifyPayload>>(&raw_message.to_string())
{
} else if let Event::Identify(identify) = event {
log::trace!(target: "symfonia::gateway::establish_connection::finish_connecting", "Received identify payload");
let claims = match check_token(
&state.db,
&identify.event_data.token,
&identify.event_data.as_ref().unwrap().token,
&state.config.security.jwt_secret,
)
.await
Expand Down Expand Up @@ -243,11 +251,14 @@ async fn finish_connecting(
}
}),
},
&identify.event_data.token,
&identify.event_data.as_ref().unwrap().token,
state.sequence_number.clone(),
)
.await;
match state.session_id_send.send(identify.event_data.token) {
match state
.session_id_send
.send(identify.event_data.unwrap().token)
{
Ok(_) => (),
Err(_) => {
log::error!(target: "symfonia::gateway::establish_connection::finish_connecting", "Failed to send session_id to heartbeat handler");
Expand All @@ -261,7 +272,7 @@ async fn finish_connecting(
}
let formatted_payload = GatewayPayload::<GatewayReady> {
op_code: 0,
event_data: create_ready(claims.id, &state.db).await?,
event_data: Some(create_ready(claims.id, &state.db).await?),
sequence_number: None,
event_name: Some("READY".to_string()),
};
Expand All @@ -274,7 +285,7 @@ async fn finish_connecting(
user: gateway_user,
client: gateway_client.clone(),
});
} else if let Ok(resume) = from_str::<GatewayResume>(&raw_message.to_string()) {
} else if let Event::Resume(resume) = event {
log::trace!(target: "symfonia::gateway::establish_connection::finish_connecting", "Received resume payload");
log::warn!(target: "symfonia::gateway::establish_connection::finish_connecting", "Resuming connections is not yet implemented. Telling client to identify instead.");
state
Expand All @@ -291,8 +302,7 @@ async fn finish_connecting(
.send(())
.expect("Failed to send kill signal");
} else {
debug!(target: "symfonia::gateway::establish_connection::finish_connecting", "Message could not be decoded as resume, heartbeat or identify.");
debug!(target: "symfonia::gateway::establish_connection::finish_connecting", "Message: {}", raw_message);
debug!(target: "symfonia::gateway::establish_connection::finish_connecting", "Message could not be decoded as resume, heartbeat or identify: {}", raw_message);
return Err(GatewayError::UnexpectedMessage("Received payload other than Heartbeat, Identify or Resume before the connection was established".to_string()).into());
}
}
Expand Down
12 changes: 10 additions & 2 deletions src/gateway/gateway_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{sync::Arc, time::Duration};

use chorus::types::{GatewayHeartbeat, GatewaySendPayload, Opcode, Snowflake};
use futures::StreamExt;
use log::debug;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::{from_str, json};
Expand Down Expand Up @@ -136,6 +137,7 @@ fn unwrap_event(
}
}

/// Process events triggered by the HTTP API.
async fn process_inbox(
mut connection: super::WebSocketConnection,
mut inbox: tokio::sync::broadcast::Receiver<Event>,
Expand All @@ -148,8 +150,14 @@ async fn process_inbox(
event = inbox.recv() => {
match event {
Ok(event) => {
todo!();
// TODO: Process event
let send_result = connection.sender.send(Message::Text(json!(event).to_string()));
match send_result {
Ok(_) => (), // TODO: Increase sequence number here
Err(_) => {
debug!("Failed to send event to WebSocket. Sending kill_send");
connection.kill_send.send(()).expect("Failed to send kill_send");
},
}
}
Err(_) => {
return;
Expand Down
16 changes: 15 additions & 1 deletion src/gateway/ready.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashMap;

use chorus::types::{GatewayReady, Snowflake, UserNote};
use chorus::types::{ClientInfo, GatewayReady, ReadState, Session, Snowflake, UserNote};
use serde_json::json;
use sqlx::PgPool;

Expand Down Expand Up @@ -52,6 +52,14 @@ pub async fn create_ready(user_id: Snowflake, db: &PgPool) -> Result<GatewayRead
// session disconnect. This is a temporary solution.
let session_id = Snowflake::generate().to_string();

// TODO: This is also just temporary.
let session = Session {
activities: None,
client_info: ClientInfo::default(),
session_id: session_id.clone(),
status: "Testing symfonia".to_string(),
};

// TODO: There are a lot of missing fields here. Ideally, all of the fields should be
// populated with the correct data.
let ready = GatewayReady {
Expand All @@ -62,6 +70,12 @@ pub async fn create_ready(user_id: Snowflake, db: &PgPool) -> Result<GatewayRead
relationships,
private_channels,
notes,
sessions: Some([session].into()),
read_state: ReadState {
entries: Default::default(),
partial: false,
version: 0,
},
..Default::default()
};
log::debug!(target: "symfonia::gateway::ready::create_ready", "Created READY json payload: {:#?}", json!(ready));
Expand Down
34 changes: 32 additions & 2 deletions src/gateway/types/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,15 @@ impl TryFrom<tokio_tungstenite::tungstenite::Message> for Event {
let message_as_string = message.to_string();
// Payload type of option string is okay, since raw_gateway_payload is only used to look at
// the opcode and, if the opcode is 0 (= dispatch), the event name in the received message
let raw_gateway_payload: GatewayPayload<Option<String>> = from_str(&message_as_string)?;
let raw_gateway_payload: GatewayPayload<Option<serde_json::Value>> =
from_str(&message_as_string)?;
match Opcode::try_from(raw_gateway_payload.op_code).map_err(|_| {
Error::Gateway(GatewayError::UnexpectedOpcode(
raw_gateway_payload.op_code.into(),
))
})? {
Opcode::Heartbeat => return convert_to!(Event::Heartbeat, message_as_string),
Opcode::Identify => return convert_to!(Event::Heartbeat, message_as_string),
Opcode::Identify => return convert_to!(Event::Identify, message_as_string),
Opcode::PresenceUpdate => return convert_to!(Event::PresenceUpdate, message_as_string),
Opcode::VoiceStateUpdate => {
return convert_to!(Event::VoiceStateUpdate, message_as_string)
Expand Down Expand Up @@ -709,3 +710,32 @@ impl TryFrom<tokio_tungstenite::tungstenite::Message> for Event {
}
}
}

#[cfg(test)]
mod tests {

use serde_json::Value;

use super::*;
#[test]
fn identify_from_raw_json() {
let json = r#"{"op":2,"d":{"token":"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJleHAiOjE3Mjk4Njg3MzQsImlhdCI6MTcyOTc4MjMzNCwiZW1haWwiOiJkZmdkc2Znc2RmZ0Bkc2Zmc2Rmc2QuZGUiLCJpZCI6IjEyOTkwMjYwMDU1MzIzNDg0MTYifQ.3mFo83e0ehI4JWUFy631hUXPJKxjJWUSIT5laDTbzzU","capabilities":16381,"properties":{"browser":"Spacebar Web","client_build_number":0,"release_channel":"dev","browser_user_agent":"Mozilla/5.0 (X11; Linux x86_64; rv:131.0) Gecko/20100101 Firefox/131.0"},"compress":false,"presence":{"status":"online","since":1729782873344,"activities":[],"afk":false}}}"#;
let message = Message::Text(json.to_string());
let gateway_payload_string =
from_str::<GatewayPayload<Option<Value>>>(&message.to_string()).unwrap();
dbg!(gateway_payload_string);
let event = Event::try_from(message).unwrap();
dbg!(event);
}

#[test]
fn heartbeat_from_raw_json() {
let json = r#"{"op":1}"#;
let message = Message::Text(json.to_string());
let gateway_payload_string =
from_str::<GatewayPayload<Option<Value>>>(&message.to_string()).unwrap();
dbg!(gateway_payload_string);
let event = Event::try_from(message).unwrap();
dbg!(event);
}
}
19 changes: 17 additions & 2 deletions src/gateway/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ where
{
#[serde(rename = "op")]
pub op_code: u8,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "d")]
pub event_data: T,
pub event_data: Option<T>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "s")]
pub sequence_number: Option<u64>,
Expand All @@ -71,6 +72,20 @@ where
pub event_name: Option<String>,
}

impl<T: Serialize + DeserializeOwned> GatewayPayload<T> {
pub fn has_data(&self) -> bool {
self.event_data.is_some()
}

pub fn has_sequence(&self) -> bool {
self.sequence_number.is_some()
}

pub fn has_event_name(&self) -> bool {
self.event_name.is_some()
}
}

impl<'de, T: DeserializeOwned + Serialize> Deserialize<'de> for GatewayPayload<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
Expand All @@ -83,7 +98,7 @@ impl<'de, T: DeserializeOwned + Serialize> Deserialize<'de> for GatewayPayload<T
Ok(t) => t,
Err(e) => return Err(::serde::de::Error::custom(e)),
},
None => return Err(::serde::de::Error::missing_field("d")),
None => None,
};
let sequence_number = value.get("s").cloned().map(|v| v.as_u64().unwrap());
let event_name = match value.get("t") {
Expand Down

0 comments on commit 8341331

Please sign in to comment.