Skip to content

Commit

Permalink
enumify + remove a generic
Browse files Browse the repository at this point in the history
  • Loading branch information
oscartbeaumont committed Jul 17, 2023
1 parent bd89813 commit 4cb9ad8
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 97 deletions.
16 changes: 11 additions & 5 deletions src/integrations/httpz/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::pin::pin;

use futures::SinkExt;
use httpz::{
http::{Response, StatusCode},
ws::WebsocketUpgrade,
Expand Down Expand Up @@ -44,11 +47,14 @@ where
};

let cookies = req.cookies(); // TODO: Reorder args of next func so cookies goes first
WebsocketUpgrade::from_req_with_cookies(req, cookies, move |_, socket| {
ConnectionTask::<TokioRuntime, TCtx, _, _, _, _>::new(
Connection::new(ctx, executor),
socket,
)
WebsocketUpgrade::from_req_with_cookies(req, cookies, move |_, socket| async move {
let socket = socket.with(|v: String| async move {
Ok(httpz::ws::Message::Text(v)) as Result<_, httpz::Error>
});
let socket = pin!(socket);

ConnectionTask::<TokioRuntime, TCtx, _, _, _>::new(Connection::new(ctx, executor), socket)
.await;
})
.into_response()
}
21 changes: 7 additions & 14 deletions src/integrations/tauri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ use tokio::sync::mpsc;

use crate::{
internal::exec::{
AsyncRuntime, Connection, ConnectionTask, Executor, IncomingMessage, OutgoingMessage,
SubscriptionMap, TokioRuntime,
AsyncRuntime, Connection, ConnectionTask, Executor, IncomingMessage, SubscriptionMap,
TokioRuntime,
},
BuiltRouter,
};
Expand Down Expand Up @@ -76,7 +76,7 @@ where
};
let ctx = (self.ctx_fn)(window.clone());
let handle = R::spawn(async move {
ConnectionTask::<R, TCtx, _, _, _, _>::new(Connection::new(ctx, executor), socket)
ConnectionTask::<R, TCtx, _, _, _>::new(Connection::new(ctx, executor), socket)
.await;
});

Expand Down Expand Up @@ -157,21 +157,16 @@ struct Socket {
window: Window,
}

impl futures::Sink<OutgoingMessage> for Socket {
impl futures::Sink<String> for Socket {
type Error = Infallible;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn start_send(
self: std::pin::Pin<&mut Self>,
item: OutgoingMessage,
) -> Result<(), Self::Error> {
println!("OUTGOING: {:?}", item.0); // TODO

fn start_send(self: std::pin::Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
self.window
.emit("plugin:rspc:transport:resp", item.0)
.emit("plugin:rspc:transport:resp", item)
.map_err(|err| {
#[cfg(feature = "tracing")]
tracing::error!("failed to emit JSON-RPC response: {}", err);
Expand All @@ -194,8 +189,6 @@ impl futures::Stream for Socket {
type Item = Result<IncomingMessage, Infallible>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let y = self.recv.poll_recv(cx).map(|v| v.map(Ok));
println!("INCOMING: {:#?}", y); // TODO
y
self.recv.poll_recv(cx).map(|v| v.map(Ok))
}
}
6 changes: 3 additions & 3 deletions src/internal/exec/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use super::{
// TODO: Seal this shit up tight

/// TODO
#[pin_project(project = PlzNameThisEnumProj)]
#[pin_project(project = StreamOrFutProj)]
enum StreamOrFut<TCtx: 'static> {
OwnedStream(#[pin] OwnedStream<TCtx>),
ExecRequestFut(#[pin] PinnedOption<ExecRequestFut>),
Expand All @@ -29,7 +29,7 @@ impl<TCtx: 'static> Stream for StreamOrFut<TCtx> {

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.project() {
PlzNameThisEnumProj::OwnedStream(s) => {
StreamOrFutProj::OwnedStream(s) => {
let s = s.project();

let v = ready!(s.reference.poll_next(cx));
Expand All @@ -42,7 +42,7 @@ impl<TCtx: 'static> Stream for StreamOrFut<TCtx> {
},
}))
}
PlzNameThisEnumProj::ExecRequestFut(mut s) => match s.as_mut().project() {
StreamOrFutProj::ExecRequestFut(mut s) => match s.as_mut().project() {
PinnedOptionProj::Some(ss) => ss.poll(cx).map(|v| {
s.set(PinnedOption::None);
Some(v)
Expand Down
118 changes: 43 additions & 75 deletions src/internal/exec/connection2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,54 +9,29 @@ use futures::{ready, Sink, Stream};
use pin_project::pin_project;
use serde_json::Value;

use super::{AsyncRuntime, Batcher, Connection};
use super::{AsyncRuntime, Batcher, Connection, IncomingMessage};
use crate::internal::exec;

// TODO: Maybe merge this `Connection` and `Batch` abstractions into this?
// TODO: Rewrite this to named enums instead of `Option<bool>` and the like
enum PollResult {
/// The poller has done some progressed work.
/// WARNING: this does not guarantee any wakers have been registered so to uphold the `Future` invariants you can not return.
Progressed,

#[derive(Debug)]
pub(crate) enum IncomingMessage {
Msg(Result<Value, serde_json::Error>),
Close,
Skip,
}
/// The poller has queued a message to be sent.
/// WARNING: You must call `Self::poll_send` to prior to returning from the `Future::poll` method.
QueueSend,

pub(crate) struct OutgoingMessage(pub String);

#[cfg(feature = "httpz")]
impl From<httpz::ws::Message> for IncomingMessage {
fn from(value: httpz::ws::Message) -> Self {
match value {
httpz::ws::Message::Text(v) => Self::Msg(serde_json::from_str(&v)),
httpz::ws::Message::Binary(v) => Self::Msg(serde_json::from_slice(&v)),
httpz::ws::Message::Ping(_) | httpz::ws::Message::Pong(_) => Self::Skip,
httpz::ws::Message::Close(_) => Self::Close,
httpz::ws::Message::Frame(_) => {
#[cfg(debug_assertions)]
unreachable!("Reading a 'httpz::ws::Message::Frame' is impossible");

#[cfg(not(debug_assertions))]
return Self::Skip;
}
}
}
}

#[cfg(feature = "httpz")]
impl From<OutgoingMessage> for httpz::ws::Message {
fn from(value: OutgoingMessage) -> Self {
Self::Text(value.0)
}
/// The future is complete
Complete,
}

/// TODO
#[pin_project(project = ConnectionTaskProj)]
pub(crate) struct ConnectionTask<
R: AsyncRuntime,
TCtx: Clone + Send + 'static,
S: Sink<M, Error = E> + Stream<Item = Result<M2, E>> + Send + Unpin,
M: From<OutgoingMessage>,
S: Sink<String, Error = E> + Stream<Item = Result<M2, E>> + Send + Unpin,
// TODO: Remove both?
M2: Into<IncomingMessage>,
E: std::fmt::Debug + std::error::Error,
> {
Expand All @@ -67,19 +42,18 @@ pub(crate) struct ConnectionTask<

#[pin]
socket: S,
tx_queue: Option<M>,
tx_queue: Option<String>,

phantom: PhantomData<M2>,
}

impl<
R: AsyncRuntime,
TCtx: Clone + Send + 'static,
S: Sink<M, Error = E> + Stream<Item = Result<M2, E>> + Send + Unpin,
M: From<OutgoingMessage>,
S: Sink<String, Error = E> + Stream<Item = Result<M2, E>> + Send + Unpin,
M2: Into<IncomingMessage>,
E: std::fmt::Debug + std::error::Error,
> ConnectionTask<R, TCtx, S, M, M2, E>
> ConnectionTask<R, TCtx, S, M2, E>
{
pub fn new(conn: Connection<R, TCtx>, socket: S) -> Self {
Self {
Expand All @@ -92,16 +66,14 @@ impl<
}

/// Poll sending
///
/// `Poll::Ready(())` is returned no wakers have been registered. This invariant must be maintained by caller!
fn poll_send(
this: &mut ConnectionTaskProj<R, TCtx, S, M, M2, E>,
this: &mut ConnectionTaskProj<R, TCtx, S, M2, E>,
cx: &mut Context<'_>,
) -> Poll<()> {
) -> Poll<PollResult> {
// If nothing in `tx_queue`, poll the batcher to populate it
if this.tx_queue.is_none() {
match ready!(this.batch.as_mut().poll_next(cx)) {
Some(Some(json)) => *this.tx_queue = Some(OutgoingMessage(json).into()),
Some(Some(json)) => *this.tx_queue = Some(json),
Some(None) => {}
None => panic!("rspc: batcher stream ended unexpectedly"),
};
Expand All @@ -114,7 +86,7 @@ impl<
#[cfg(feature = "tracing")]
tracing::error!("Error waiting for websocket to be ready: {}", err);

return Poll::Ready(());
return PollResult::Progressed.into();
};

let item = this
Expand All @@ -135,7 +107,7 @@ impl<
tracing::error!("Error flushing message to websocket: {}", err);
}

Poll::Ready(())
PollResult::Progressed.into()
}

/// Poll receiving
Expand All @@ -144,9 +116,9 @@ impl<
/// `Poll::Ready(Some(false))` means you must `Self::poll_send`. This invariant must be maintained by caller!
/// `Poll::Ready(None)` is returned no wakers have been registered. This invariant must be maintained by caller!
fn poll_recv(
this: &mut ConnectionTaskProj<R, TCtx, S, M, M2, E>,
this: &mut ConnectionTaskProj<R, TCtx, S, M2, E>,
cx: &mut Context<'_>,
) -> Poll<Option<bool>> {
) -> Poll<PollResult> {
match ready!(this.socket.as_mut().poll_next(cx)) {
Some(Ok(msg)) => {
let res = match msg.into() {
Expand All @@ -158,21 +130,17 @@ impl<
// TODO: Terminate all subscriptions
// TODO: Tell frontend all subscriptions were terminated

return Poll::Ready(Some(true));
return PollResult::Complete.into();
}
IncomingMessage::Skip => return Poll::Ready(None),
IncomingMessage::Skip => return PollResult::Progressed.into(),
};

match res.and_then(|v| match v.is_array() {
true => serde_json::from_value::<Vec<exec::Request>>(v),
false => serde_json::from_value::<exec::Request>(v).map(|v| vec![v]),
}) {
Ok(reqs) => {
let mut a = this.conn.exec(reqs);
println!("C {a:?}");
this.batch.as_mut().append(&mut a);

return Poll::Ready(Some(false));
this.batch.as_mut().append(&mut this.conn.exec(reqs));
}
Err(_err) => {
#[cfg(feature = "tracing")]
Expand All @@ -182,13 +150,17 @@ impl<
println!("D {_err:?}");
}
}

PollResult::QueueSend
}
Some(Err(_err)) => {
#[cfg(feature = "tracing")]
tracing::debug!("Error reading from websocket connection: {:?}", _err);

// TODO: Send report of error to frontend
println!("E {_err:?}");

PollResult::QueueSend
}
None => {
#[cfg(feature = "tracing")]
Expand All @@ -197,38 +169,33 @@ impl<
// TODO: Terminate all subscriptions
// TODO: Tell frontend all subscriptions were terminated

return Poll::Ready(Some(true));
PollResult::Complete
}
}

Poll::Ready(None)
.into()
}

/// Poll active streams
///
/// `Poll::Ready(false)` is returned no wakers have been registered. This invariant must be maintained by caller!
/// `Poll::Ready(true)` means you must `Self::poll_send`. This invariant must be maintained by caller!
fn poll_streams(
this: &mut ConnectionTaskProj<R, TCtx, S, M, M2, E>,
this: &mut ConnectionTaskProj<R, TCtx, S, M2, E>,
cx: &mut Context<'_>,
) -> Poll<bool> {
) -> Poll<PollResult> {
if let Some(batch) = ready!(this.conn.as_mut().poll_next(cx)).expect("rspc unreachable") {
this.batch.as_mut().insert(batch);
return Poll::Ready(true);
return PollResult::QueueSend.into();
}

Poll::Ready(false)
PollResult::Progressed.into()
}
}

impl<
R: AsyncRuntime,
TCtx: Clone + Send + 'static,
S: Sink<M, Error = E> + Stream<Item = Result<M2, E>> + Send + Unpin,
M: From<OutgoingMessage>,
S: Sink<String, Error = E> + Stream<Item = Result<M2, E>> + Send + Unpin,
M2: Into<IncomingMessage>,
E: std::fmt::Debug + std::error::Error,
> Future for ConnectionTask<R, TCtx, S, M, M2, E>
> Future for ConnectionTask<R, TCtx, S, M2, E>
{
type Output = ();

Expand All @@ -244,17 +211,18 @@ impl<
}

match Self::poll_recv(&mut this, cx) {
Poll::Ready(Some(true)) => return Poll::Ready(()),
Poll::Ready(Some(false)) => continue,
Poll::Ready(None) => {}
Poll::Ready(PollResult::Complete) => return Poll::Ready(()),
Poll::Ready(PollResult::Progressed) => {}
Poll::Ready(PollResult::QueueSend) => continue,
Poll::Pending => {
is_pending = true;
}
}

match Self::poll_streams(&mut this, cx) {
Poll::Ready(true) => continue,
Poll::Ready(false) => {}
Poll::Ready(PollResult::Complete) => return Poll::Ready(()),
Poll::Ready(PollResult::Progressed) => {}
Poll::Ready(PollResult::QueueSend) => continue,
Poll::Pending => {
is_pending = true;
}
Expand Down
27 changes: 27 additions & 0 deletions src/internal/exec/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,33 @@ mod private {
#[serde(flatten)]
pub result: ValueOrError,
}

/// TODO
#[derive(Debug)]
pub(crate) enum IncomingMessage {
Msg(Result<Value, serde_json::Error>),
Close,
Skip,
}

#[cfg(feature = "httpz")]
impl From<httpz::ws::Message> for IncomingMessage {
fn from(value: httpz::ws::Message) -> Self {
match value {
httpz::ws::Message::Text(v) => Self::Msg(serde_json::from_str(&v)),
httpz::ws::Message::Binary(v) => Self::Msg(serde_json::from_slice(&v)),
httpz::ws::Message::Ping(_) | httpz::ws::Message::Pong(_) => Self::Skip,
httpz::ws::Message::Close(_) => Self::Close,
httpz::ws::Message::Frame(_) => {
#[cfg(debug_assertions)]
unreachable!("Reading a 'httpz::ws::Message::Frame' is impossible");

#[cfg(not(debug_assertions))]
return Self::Skip;
}
}
}
}
}

#[cfg(feature = "unstable")]
Expand Down

0 comments on commit 4cb9ad8

Please sign in to comment.