Skip to content

Commit

Permalink
Update async examples (use send feature)
Browse files Browse the repository at this point in the history
  • Loading branch information
khvzak committed Jul 31, 2024
1 parent a86d6ab commit 8337909
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 73 deletions.
5 changes: 2 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ libloading = { version = "0.8", optional = true }

[dev-dependencies]
trybuild = "1.0"
futures = "0.3.5"
hyper = { version = "1.2", features = ["full"] }
hyper-util = { version = "0.1.3", features = ["full"] }
http-body-util = "0.1.1"
Expand Down Expand Up @@ -101,11 +100,11 @@ required-features = ["async", "serialize", "macros"]

[[example]]
name = "async_http_server"
required-features = ["async", "macros"]
required-features = ["async", "macros", "send"]

[[example]]
name = "async_tcp_server"
required-features = ["async", "macros"]
required-features = ["async", "macros", "send"]

[[example]]
name = "guided_tour"
Expand Down
69 changes: 20 additions & 49 deletions examples/async_http_server.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
use std::convert::Infallible;
use std::future::Future;
use std::net::SocketAddr;
use std::rc::Rc;
use std::pin::Pin;

use futures::future::LocalBoxFuture;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt as _, Empty, Full};
use hyper::body::{Bytes, Incoming};
use hyper::server::conn::http1;
use hyper::{Request, Response};
use hyper_util::rt::TokioIo;
use hyper_util::server::conn::auto::Builder as ServerConnBuilder;
use tokio::net::TcpListener;
use tokio::task::LocalSet;

use mlua::{
chunk, Error as LuaError, Function, Lua, RegistryKey, String as LuaString, Table, UserData,
UserDataMethods,
};
use mlua::{chunk, Error as LuaError, Function, Lua, String as LuaString, Table, UserData, UserDataMethods};

/// Wrapper around incoming request that implements UserData
struct LuaRequest(SocketAddr, Request<Incoming>);
Expand All @@ -32,33 +27,26 @@ impl UserData for LuaRequest {
/// Service that handles incoming requests
#[derive(Clone)]
pub struct Svc {
lua: Rc<Lua>,
handler: Rc<RegistryKey>,
handler: Function,
peer_addr: SocketAddr,
}

impl Svc {
pub fn new(lua: Rc<Lua>, handler: Rc<RegistryKey>, peer_addr: SocketAddr) -> Self {
Self {
lua,
handler,
peer_addr,
}
pub fn new(handler: Function, peer_addr: SocketAddr) -> Self {
Self { handler, peer_addr }
}
}

impl hyper::service::Service<Request<Incoming>> for Svc {
type Response = Response<BoxBody<Bytes, Infallible>>;
type Error = LuaError;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

fn call(&self, req: Request<Incoming>) -> Self::Future {
// If handler returns an error then generate 5xx response
let lua = self.lua.clone();
let handler_key = self.handler.clone();
let handler = self.handler.clone();
let lua_req = LuaRequest(self.peer_addr, req);
Box::pin(async move {
let handler: Function = lua.registry_value(&handler_key)?;
match handler.call_async::<_, Table>(lua_req).await {
Ok(lua_resp) => {
let status = lua_resp.get::<_, Option<u16>>("status")?.unwrap_or(200);
Expand Down Expand Up @@ -94,10 +82,10 @@ impl hyper::service::Service<Request<Incoming>> for Svc {

#[tokio::main(flavor = "current_thread")]
async fn main() {
let lua = Rc::new(Lua::new());
let lua = Lua::new();

// Create Lua handler function
let handler: RegistryKey = lua
let handler = lua
.load(chunk! {
function(req)
return {
Expand All @@ -111,15 +99,13 @@ async fn main() {
}
end
})
.eval()
.eval::<Function>()
.expect("Failed to create Lua handler");
let handler = Rc::new(handler);

let listen_addr = "127.0.0.1:3000";
let listener = TcpListener::bind(listen_addr).await.unwrap();
println!("Listening on http://{listen_addr}");

let local = LocalSet::new();
loop {
let (stream, peer_addr) = match listener.accept().await {
Ok(x) => x,
Expand All @@ -129,29 +115,14 @@ async fn main() {
}
};

let svc = Svc::new(lua.clone(), handler.clone(), peer_addr);
local
.run_until(async move {
let result = ServerConnBuilder::new(LocalExec)
.http1()
.serve_connection(TokioIo::new(stream), svc)
.await;
if let Err(err) = result {
eprintln!("Error serving connection: {err:?}");
}
})
.await;
}
}

#[derive(Clone, Copy, Debug)]
struct LocalExec;

impl<F> hyper::rt::Executor<F> for LocalExec
where
F: Future + 'static, // not requiring `Send`
{
fn execute(&self, fut: F) {
tokio::task::spawn_local(fut);
let svc = Svc::new(handler.clone(), peer_addr);
tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(TokioIo::new(stream), svc)
.await
{
eprintln!("Error serving connection: {:?}", err);
}
});
}
}
25 changes: 5 additions & 20 deletions examples/async_tcp_server.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use std::io;
use std::net::SocketAddr;
use std::rc::Rc;

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::task;

use mlua::{chunk, Function, Lua, RegistryKey, String as LuaString, UserData, UserDataMethods};
use mlua::{chunk, Function, Lua, String as LuaString, UserData, UserDataMethods};

struct LuaTcpStream(TcpStream);

Expand All @@ -33,26 +31,21 @@ impl UserData for LuaTcpStream {
}
}

async fn run_server(lua: Lua, handler: RegistryKey) -> io::Result<()> {
async fn run_server(handler: Function) -> io::Result<()> {
let addr: SocketAddr = ([127, 0, 0, 1], 3000).into();
let listener = TcpListener::bind(addr).await.expect("cannot bind addr");

println!("Listening on {}", addr);

let lua = Rc::new(lua);
let handler = Rc::new(handler);
loop {
let (stream, _) = match listener.accept().await {
Ok(res) => res,
Err(err) if is_transient_error(&err) => continue,
Err(err) => return Err(err),
};

let lua = lua.clone();
let handler = handler.clone();
task::spawn_local(async move {
let handler: Function = lua.registry_value(&handler).expect("cannot get Lua handler");

tokio::task::spawn(async move {
let stream = LuaTcpStream(stream);
if let Err(err) = handler.call_async::<_, ()>(stream).await {
eprintln!("{}", err);
Expand All @@ -66,7 +59,7 @@ async fn main() {
let lua = Lua::new();

// Create Lua handler function
let handler_fn = lua
let handler = lua
.load(chunk! {
function(stream)
local peer_addr = stream:peer_addr()
Expand All @@ -88,15 +81,7 @@ async fn main() {
.eval::<Function>()
.expect("cannot create Lua handler");

// Store it in the Registry
let handler = lua
.create_registry_value(handler_fn)
.expect("cannot store Lua handler");

task::LocalSet::new()
.run_until(run_server(lua, handler))
.await
.expect("cannot run server")
run_server(handler).await.expect("cannot run server")
}

fn is_transient_error(e: &io::Error) -> bool {
Expand Down
2 changes: 1 addition & 1 deletion src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ impl Thread {
///
/// ```
/// # use mlua::{Lua, Result, Thread};
/// use futures::stream::TryStreamExt;
/// use futures_util::stream::TryStreamExt;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// # let lua = Lua::new();
Expand Down

0 comments on commit 8337909

Please sign in to comment.