Skip to content

Commit

Permalink
OpenAPI working
Browse files Browse the repository at this point in the history
  • Loading branch information
oscartbeaumont committed Jul 17, 2024
1 parent 687c3f7 commit 5abc0e8
Show file tree
Hide file tree
Showing 11 changed files with 306 additions and 46 deletions.
3 changes: 2 additions & 1 deletion examples/axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ thiserror = "1.0.62"
async-stream = "0.3.5"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
rspc-tracing = { version = "0.0.0", path = "../../middleware/tracing" } # TODO: Remove?
rspc-tracing = { version = "0.0.0", path = "../../middleware/tracing" }
rspc-openapi = { version = "0.0.0", path = "../../middleware/openapi" }
serde = { version = "1", features = ["derive"] }
specta = { version = "=2.0.0-rc.15", features = ["derive"] } # TODO: Drop requirement on `derive`
specta-util = "0.0.2" # TODO: We need this for `TypeCollection` which is cringe
Expand Down
6 changes: 4 additions & 2 deletions examples/axum/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ use specta_typescript::Typescript;
use specta_util::TypeCollection;
use thiserror::Error;

mod chat;
pub(crate) mod chat;
pub(crate) mod store;

#[derive(Debug, Error)]
pub enum Error {}

// `Clone` is only required for usage with Websockets
#[derive(Default, Clone)]
#[derive(Clone)]
pub struct Context {
pub chat: chat::Ctx,
}
Expand All @@ -40,6 +41,7 @@ pub fn mount() -> Router {
<BaseProcedure>::builder().query(|_, _: ()| async { Ok(env!("CARGO_PKG_VERSION")) })
})
.merge("chat", chat::mount())
.merge("store", store::mount())
// TODO: I dislike this API
.ext({
let mut types = TypeCollection::default();
Expand Down
6 changes: 3 additions & 3 deletions examples/axum/src/api/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ pub struct Ctx {
chat: broadcast::Sender<Message>,
}

impl Default for Ctx {
fn default() -> Self {
impl Ctx {
pub fn new(chat: broadcast::Sender<Message>) -> Self {
Self {
author: Arc::new(Mutex::new("Anonymous".into())),
chat: broadcast::channel(100).0,
chat,
}
}
}
Expand Down
25 changes: 25 additions & 0 deletions examples/axum/src/api/store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use rspc_openapi::OpenAPI;

use super::{BaseProcedure, Router};

pub fn mount() -> Router {
Router::new()
.procedure("get", {
<BaseProcedure>::builder()
.with(OpenAPI::get("/api/get").build())
.mutation(|ctx, _: ()| async move {
// TODO

Ok("Hello From rspc!")
})
})
.procedure("set", {
<BaseProcedure>::builder()
.with(OpenAPI::post("/api/set").build())
.mutation(|ctx, value: String| async move {
// TODO

Ok(())
})
})
}
17 changes: 9 additions & 8 deletions examples/axum/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::net::Ipv6Addr;

use axum::{routing::get, Router};
use tokio::sync::broadcast;
use tracing::info;

mod api;
Expand All @@ -11,18 +12,18 @@ async fn main() {

let router = api::mount().build().unwrap();

let chat_tx = broadcast::channel(100).0;
let ctx_fn = move || api::Context {
chat: api::chat::Ctx::new(chat_tx.clone()),
};

let app = Router::new()
.route("/", get(|| async { "Hello, World!" }))
.nest(
"/rspc",
rspc_axum::Endpoint::new(router, || api::Context {
chat: Default::default(),
})
.with_endpoints()
.with_websocket()
.with_batching()
.build(),
);
rspc_axum::Endpoint::new(router.clone(), ctx_fn.clone()),

Check failure on line 24 in examples/axum/src/main.rs

View workflow job for this annotation

GitHub Actions / Clippy

mismatched types

error[E0308]: mismatched types --> examples/axum/src/main.rs:24:13 | 22 | .nest( | ---- arguments to this method are incorrect 23 | "/rspc", 24 | rspc_axum::Endpoint::new(router.clone(), ctx_fn.clone()), | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `Router<_>`, found `Endpoint<Context>` | = note: expected struct `axum::Router<_>` found struct `rspc_axum::Endpoint<api::Context>` note: method defined here --> /home/runner/.cargo/registry/src/index.crates.io-6f17d22bba15001f/axum-0.7.5/src/routing/mod.rs:189:12 | 189 | pub fn nest(self, path: &str, router: Router<S>) -> Self { | ^^^^
)
.nest("/", rspc_openapi::mount(router, ctx_fn));

info!("Listening on http://[::1]:3000");
let listener = tokio::net::TcpListener::bind((Ipv6Addr::UNSPECIFIED, 3000))
Expand Down
3 changes: 3 additions & 0 deletions middleware/openapi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ publish = false # TODO: Crate metadata & publish

[dependencies]
rspc = { path = "../../rspc" }
axum = { version = "0.7.5", default-features = false }
serde_json = "1.0.120"
futures = "0.3.30"

# /bin/sh RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features
[package.metadata."docs.rs"]
Expand Down
245 changes: 219 additions & 26 deletions middleware/openapi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,230 @@
html_favicon_url = "https://github.com/oscartbeaumont/rspc/raw/main/docs/public/logo.png"
)]

use std::{borrow::Cow, collections::HashMap};
use std::{borrow::Cow, collections::HashMap, hash::Hash, sync::Arc};

use rspc::middleware::Middleware;
use axum::{
body::Bytes,
extract::Query,
http::StatusCode,
response::Html,
routing::{get, post},
Json,
};
use futures::StreamExt;
use rspc::{
middleware::Middleware,
procedure::{Procedure, ProcedureInput},
BuiltRouter,
};
use serde_json::json;

#[derive(Default)]
pub struct OpenAPIState(HashMap<Cow<'static, str>, ()>);
// TODO: Properly handle inputs from query params
// TODO: Properly handle responses from query params
// TODO: Support input's coming from URL. Eg. `/todos/{id}` like tRPC-OpenAPI
// TODO: Support `application/x-www-form-urlencoded` bodies like tRPC-OpenAPI
// TODO: Probs put SwaggerUI behind a feature flag

pub struct OpenAPI {
method: &'static str,
path: Cow<'static, str>,
}

impl OpenAPI {
// TODO
// pub fn new(method: Method, path: impl Into<Cow<'static, str>>) {}

pub fn get(path: impl Into<Cow<'static, str>>) -> Self {
Self {
method: "GET",
path: path.into(),
}
}

pub fn post(path: impl Into<Cow<'static, str>>) -> Self {
Self {
method: "GET",
path: path.into(),
}
}

pub fn put(path: impl Into<Cow<'static, str>>) -> Self {
Self {
method: "GET",
path: path.into(),
}
}

pub fn patch(path: impl Into<Cow<'static, str>>) -> Self {
Self {
method: "GET",
path: path.into(),
}
}

pub fn delete(path: impl Into<Cow<'static, str>>) -> Self {
Self {
method: "GET",
path: path.into(),
}
}

// TODO: Configure other OpenAPI stuff like auth???

pub fn build<TError, TThisCtx, TThisInput, TThisResult>(
self,
) -> Middleware<TError, TThisCtx, TThisInput, TThisResult>
where
TError: 'static,
TThisCtx: Send + 'static,
TThisInput: Send + 'static,
TThisResult: Send + 'static,
{
// TODO: Can we have a middleware with only a `setup` function to avoid the extra future boxing???
Middleware::new(|ctx, input, next| async move { next.exec(ctx, input).await }).setup(
move |state, meta| {
state
.get_mut_or_init::<OpenAPIState>(Default::default)
.0
.insert((self.method, self.path), meta.name().to_string());
},
)
}
}

// TODO: Configure other OpenAPI stuff like auth
// The state that is stored into rspc.
// A map of (method, path) to procedure name.
#[derive(Default)]
struct OpenAPIState(HashMap<(&'static str, Cow<'static, str>), String>);

// TODO: Make convert this into a builder like: Endpoint::get("/todo").some_other_stuff().build()
pub fn openapi<TError, TThisCtx, TThisInput, TThisResult>(
// method: Method,
path: impl Into<Cow<'static, str>>,
) -> Middleware<TError, TThisCtx, TThisInput, TThisResult>
// TODO: Axum should be behind feature flag
// TODO: Can we decouple webserver from OpenAPI while keeping something maintainable????
pub fn mount<TCtx, S>(
router: BuiltRouter<TCtx>,
ctx_fn: impl Fn() -> TCtx + Clone + Send + Sync + 'static,
) -> axum::Router<S>
where
TError: 'static,
TThisCtx: Send + 'static,
TThisInput: Send + 'static,
TThisResult: Send + 'static,
S: Clone + Send + Sync + 'static,
TCtx: Send + 'static,
{
let path = path.into();
Middleware::new(|ctx, input, next| async move {
let _result = next.exec(ctx, input).await;
_result
})
.setup(|state, meta| {
state
.get_mut_or_init::<OpenAPIState>(Default::default)
.0
.insert(path, ());
})
let mut r = axum::Router::new();

let mut paths: HashMap<_, HashMap<_, _>> = HashMap::new();
if let Some(endpoints) = router.state.get::<OpenAPIState>() {
for ((method, path), procedure_name) in endpoints.0.iter() {
let procedure = router
.procedures
.get(&Cow::Owned(procedure_name.clone()))
.expect("unreachable: a procedure was registered that doesn't exist")
.clone();
let ctx_fn = ctx_fn.clone();

paths
.entry(path.clone())
.or_default()
.insert(method.to_lowercase(), procedure.clone());

r = r.route(
path,
match *method {
"GET" => {
// TODO: By moving `procedure` into the closure we hang onto the types for the duration of the program which is probs undesirable.
get(move |query: Query<HashMap<String, String>>| async move {
let ctx = (ctx_fn)();

handle_procedure(
ctx,
&mut serde_json::Deserializer::from_str(
query.get("input").map(|v| &**v).unwrap_or("null"),
),
procedure,
)
.await
})
}
"POST" => {
// TODO: By moving `procedure` into the closure we hang onto the types for the duration of the program which is probs undesirable.
post(move |body: Bytes| async move {
let ctx = (ctx_fn)();

handle_procedure(
ctx,
&mut serde_json::Deserializer::from_slice(&body),
procedure,
)
.await
})
}
// "PUT" => axum::routing::put,
// "PATCH" => axum::routing::patch,
// "DELETE" => axum::routing::delete,
_ => panic!("Unsupported method"),
},
);
}
}

let schema = Arc::new(json!({
"openapi": "3.0.3",
"info": {
"title": "rspc OpenAPI",
"description": "This is a demo of rspc OpenAPI",
"version": "0.0.0"
},
"paths": paths.into_iter()
.map(|(path, procedures)| {
let mut methods = HashMap::new();
for (method, procedure) in procedures {
methods.insert(method.to_string(), json!({
"operationId": procedure.ty().key.to_string(),
"responses": {
"200": {
"description": "Successful operation"
}
}
}));
}

(path, methods)
})
.collect::<HashMap<_, _>>()
})); // TODO: Maybe convert to string now cause it will be more efficient to clone

r.route(
// TODO: Allow the user to configure this URL & turn it off
"/api/docs",
get(|| async { Html(include_str!("swagger.html")) }),
)
.route(
// TODO: Allow the user to configure this URL & turn it off
"/api/openapi.json",
get(move || async move { Json((*schema).clone()) }),
)
}

// TODO: Convert into API endpoint
// Used for `GET` and `POST` endpoints
async fn handle_procedure<'de, TCtx>(
ctx: TCtx,
input: impl ProcedureInput<'de>,
procedure: Procedure<TCtx>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<String>)> {
let mut stream = procedure.exec(ctx, input).map_err(|err| {
// TODO: Error code by matching off `InternalError`
(StatusCode::INTERNAL_SERVER_ERROR, Json(err.to_string()))
})?;

// TODO: Support for streaming
while let Some(value) = stream.next().await {
// TODO: We should probs deserialize into buffer instead of value???
return match value.map(|v| v.serialize(serde_json::value::Serializer)) {
Ok(Ok(value)) => Ok(Json(value)),
Ok(Err(err)) => {
// TODO: Error code by matching off `InternalError`
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(err.to_string())))
}
Err(err) => panic!("{err:?}"), // TODO: Error handling -> How to serialize `TError`??? -> Should this be done in procedure?
};
}

Ok(Json(serde_json::Value::Null))
}
Loading

0 comments on commit 5abc0e8

Please sign in to comment.