From 25a4879cdeff7f6c23f2d157507555a721e8a3a9 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Tue, 12 Oct 2021 11:55:25 +0100 Subject: [PATCH] Update examples: async_http_client/async_http_server/async_tcp_server Make them following best practices and remove `unsafe` code. --- Cargo.toml | 2 +- examples/async_http_client.rs | 23 ++--- examples/async_http_server.rs | 143 +++++++++++++++++---------- examples/async_tcp_server.rs | 177 +++++++++++++++++----------------- 4 files changed, 186 insertions(+), 159 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8644f3b..b0c22b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -84,7 +84,7 @@ required-features = ["async", "serialize", "macros"] [[example]] name = "async_http_server" -required-features = ["async", "send"] +required-features = ["async", "macros"] [[example]] name = "async_tcp_server" diff --git a/examples/async_http_client.rs b/examples/async_http_client.rs index 6fd9389..e791808 100644 --- a/examples/async_http_client.rs +++ b/examples/async_http_client.rs @@ -1,26 +1,17 @@ use std::collections::HashMap; -use std::sync::Arc; use hyper::body::{Body as HyperBody, HttpBody as _}; use hyper::Client as HyperClient; -use tokio::sync::Mutex; -use mlua::{chunk, ExternalResult, Lua, Result, UserData, UserDataMethods}; +use mlua::{chunk, AnyUserData, ExternalResult, Lua, Result, UserData, UserDataMethods}; -#[derive(Clone)] -struct BodyReader(Arc>); - -impl BodyReader { - fn new(body: HyperBody) -> Self { - BodyReader(Arc::new(Mutex::new(body))) - } -} +struct BodyReader(HyperBody); impl UserData for BodyReader { fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_method("read", |lua, reader, ()| async move { - let mut reader = reader.0.lock().await; - if let Some(bytes) = reader.data().await { + methods.add_async_function("read", |lua, reader: AnyUserData| async move { + let mut reader = reader.borrow_mut::()?; + if let Some(bytes) = reader.0.data().await { let bytes = bytes.to_lua_err()?; return Some(lua.create_string(&bytes)).transpose(); } @@ -50,7 +41,7 @@ async fn main() -> Result<()> { } lua_resp.set("headers", headers)?; - lua_resp.set("body", BodyReader::new(resp.into_body()))?; + lua_resp.set("body", BodyReader(resp.into_body()))?; Ok(lua_resp) })?; @@ -58,7 +49,7 @@ async fn main() -> Result<()> { let f = lua .load(chunk! { local res = $fetch_url(...) - print(res.status) + print("status: "..res.status) for key, vals in pairs(res.headers) do for _, val in ipairs(vals) do print(key..": "..val) diff --git a/examples/async_http_server.rs b/examples/async_http_server.rs index 17d60b0..43ae7a9 100644 --- a/examples/async_http_server.rs +++ b/examples/async_http_server.rs @@ -1,10 +1,16 @@ +use std::future::Future; use std::net::SocketAddr; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; use hyper::server::conn::AddrStream; -use hyper::service::{make_service_fn, service_fn}; +use hyper::service::Service; use hyper::{Body, Request, Response, Server}; -use mlua::{Error, Function, Lua, Result, Table, UserData, UserDataMethods}; +use mlua::{ + chunk, Error as LuaError, Function, Lua, String as LuaString, Table, UserData, UserDataMethods, +}; struct LuaRequest(SocketAddr, Request); @@ -15,75 +21,106 @@ impl UserData for LuaRequest { } } -async fn run_server(handler: Function<'static>) -> Result<()> { - let make_svc = make_service_fn(|socket: &AddrStream| { - let remote_addr = socket.remote_addr(); - let handler = handler.clone(); - async move { - Ok::<_, Error>(service_fn(move |req: Request| { - let handler = handler.clone(); - async move { - let lua_req = LuaRequest(remote_addr, req); - let lua_resp: Table = handler.call_async(lua_req).await?; - let body = lua_resp - .get::<_, Option>("body")? - .unwrap_or_default(); +pub struct Svc(Rc, SocketAddr); - let mut resp = Response::builder() - .status(lua_resp.get::<_, Option>("status")?.unwrap_or(200)); +impl Service> for Svc { + type Response = Response; + type Error = LuaError; + type Future = Pin>>>; + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + // If handler returns an error then generate 5xx response + let lua = self.0.clone(); + let lua_req = LuaRequest(self.1, req); + Box::pin(async move { + let handler: Function = lua.named_registry_value("http_handler")?; + match handler.call_async::<_, Table>(lua_req).await { + Ok(lua_resp) => { + let status = lua_resp.get::<_, Option>("status")?.unwrap_or(200); + let mut resp = Response::builder().status(status); + + // Set headers if let Some(headers) = lua_resp.get::<_, Option>("headers")? { - for pair in headers.pairs::() { + for pair in headers.pairs::() { let (h, v) = pair?; - resp = resp.header(&h, v); + resp = resp.header(&h, v.as_bytes()); } } - Ok::<_, Error>(resp.body(Body::from(body)).unwrap()) + let body = lua_resp + .get::<_, Option>("body")? + .map(|b| Body::from(b.as_bytes().to_vec())) + .unwrap_or_else(Body::empty); + + Ok(resp.body(body).unwrap()) } - })) - } - }); + Err(err) => { + eprintln!("{}", err); + Ok(Response::builder() + .status(500) + .body(Body::from("Internal Server Error")) + .unwrap()) + } + } + }) + } +} + +#[tokio::main(flavor = "current_thread")] +async fn main() { + let lua = Rc::new(Lua::new()); + + // Create Lua handler function + let handler: Function = lua + .load(chunk! { + function(req) + return { + status = 200, + headers = { + ["X-Req-Method"] = req:method(), + ["X-Remote-Addr"] = req:remote_addr(), + }, + body = "Hello from Lua!\n" + } + end + }) + .eval() + .expect("cannot create Lua handler"); + + // Store it in the Registry + lua.set_named_registry_value("http_handler", handler) + .expect("cannot store Lua handler"); let addr = ([127, 0, 0, 1], 3000).into(); - let server = Server::bind(&addr).executor(LocalExec).serve(make_svc); + let server = Server::bind(&addr).executor(LocalExec).serve(MakeSvc(lua)); println!("Listening on http://{}", addr); - tokio::task::LocalSet::new() - .run_until(server) - .await - .map_err(Error::external) + // Create `LocalSet` to spawn !Send futures + let local = tokio::task::LocalSet::new(); + local.run_until(server).await.expect("cannot run server") } -#[tokio::main] -async fn main() -> Result<()> { - let lua = Lua::new().into_static(); +struct MakeSvc(Rc); - let handler: Function = lua - .load( - r#" - function(req) - return { - status = 200, - headers = { - ["X-Req-Method"] = req:method(), - ["X-Remote-Addr"] = req:remote_addr(), - }, - body = "Hello, World!\n" - } - end - "#, - ) - .eval()?; +impl Service<&AddrStream> for MakeSvc { + type Response = Svc; + type Error = hyper::Error; + type Future = Pin>>>; - run_server(handler).await?; + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } - // Consume the static reference and drop it. - // This is safe as long as we don't hold any other references to Lua - // or alive resources. - unsafe { Lua::from_static(lua) }; - Ok(()) + fn call(&mut self, stream: &AddrStream) -> Self::Future { + let lua = self.0.clone(); + let remote_addr = stream.remote_addr(); + Box::pin(async move { Ok(Svc(lua, remote_addr)) }) + } } #[derive(Clone, Copy, Debug)] diff --git a/examples/async_tcp_server.rs b/examples/async_tcp_server.rs index 21c81c2..edfc114 100644 --- a/examples/async_tcp_server.rs +++ b/examples/async_tcp_server.rs @@ -1,122 +1,121 @@ -use std::sync::Arc; +use std::io; +use std::net::SocketAddr; +use std::rc::Rc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::Mutex; use tokio::task; -use mlua::{chunk, Function, Lua, Result, String as LuaString, UserData, UserDataMethods}; +use mlua::{ + chunk, AnyUserData, Function, Lua, RegistryKey, String as LuaString, UserData, UserDataMethods, +}; -struct LuaTcp; - -#[derive(Clone)] -struct LuaTcpListener(Arc>); - -#[derive(Clone)] -struct LuaTcpStream(Arc>); - -impl UserData for LuaTcp { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_function("bind", |_, addr: String| async move { - let listener = TcpListener::bind(addr).await?; - Ok(LuaTcpListener(Arc::new(Mutex::new(listener)))) - }); - } -} - -impl UserData for LuaTcpListener { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_method("accept", |_, listener, ()| async move { - let (stream, _) = listener.0.lock().await.accept().await?; - Ok(LuaTcpStream(Arc::new(Mutex::new(stream)))) - }); - } -} +struct LuaTcpStream(TcpStream); impl UserData for LuaTcpStream { fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_method("peer_addr", |_, stream, ()| async move { - Ok(stream.0.lock().await.peer_addr()?.to_string()) + methods.add_method("peer_addr", |_, this, ()| { + Ok(this.0.peer_addr()?.to_string()) }); - methods.add_async_method("read", |lua, stream, size: usize| async move { - let mut buf = vec![0; size]; - let n = stream.0.lock().await.read(&mut buf).await?; - buf.truncate(n); - lua.create_string(&buf) - }); + methods.add_async_function( + "read", + |lua, (this, size): (AnyUserData, usize)| async move { + let mut this = this.borrow_mut::()?; + let mut buf = vec![0; size]; + let n = this.0.read(&mut buf).await?; + buf.truncate(n); + lua.create_string(&buf) + }, + ); - methods.add_async_method("write", |_, stream, data: LuaString| async move { - let n = stream.0.lock().await.write(&data.as_bytes()).await?; - Ok(n) - }); + methods.add_async_function( + "write", + |_, (this, data): (AnyUserData, LuaString)| async move { + let mut this = this.borrow_mut::()?; + let n = this.0.write(&data.as_bytes()).await?; + Ok(n) + }, + ); - methods.add_async_method("close", |_, stream, ()| async move { - stream.0.lock().await.shutdown().await?; + methods.add_async_function("close", |_, this: AnyUserData| async move { + let mut this = this.borrow_mut::()?; + this.0.shutdown().await?; Ok(()) }); } } -async fn run_server(lua: &'static Lua) -> Result<()> { - let spawn = lua.create_function(move |_, func: Function| { - task::spawn_local(async move { func.call_async::<_, ()>(()).await }); - Ok(()) - })?; +async fn run_server(lua: Lua, handler: RegistryKey) -> io::Result<()> { + let addr: SocketAddr = ([127, 0, 0, 1], 3000).into(); + let listener = TcpListener::bind(addr).await.expect("cannot bind addr"); - let tcp = LuaTcp; + println!("Listening on {}", addr); - let server = lua + let lua = Rc::new(lua); + let handler = Rc::new(handler); + loop { + let (stream, _) = match listener.accept().await { + Ok(res) => res, + Err(err) if is_transient_error(&err) => continue, + Err(err) => return Err(err), + }; + + let lua = lua.clone(); + let handler = handler.clone(); + task::spawn_local(async move { + let handler: Function = lua + .registry_value(&handler) + .expect("cannot get Lua handler"); + + let stream = LuaTcpStream(stream); + if let Err(err) = handler.call_async::<_, ()>(stream).await { + eprintln!("{}", err); + } + }); + } +} + +#[tokio::main(flavor = "current_thread")] +async fn main() { + let lua = Lua::new(); + + // Create Lua handler function + let handler_fn = lua .load(chunk! { - local addr = ... - local listener = $tcp.bind(addr) - print("listening on "..addr) - - local accept_new = true - while true do - local stream = listener:accept() + function(stream) local peer_addr = stream:peer_addr() print("connected from "..peer_addr) - if not accept_new then - return - end - - $spawn(function() - while true do - local data = stream:read(100) - data = data:match("^%s*(.-)%s*$") -- trim - print("["..peer_addr.."] "..data) - if data == "bye" then - stream:write("bye bye\n") - stream:close() - return - end - if data == "exit" then - stream:close() - accept_new = false - return - end - stream:write("echo: "..data.."\n") + while true do + local data = stream:read(100) + data = data:match("^%s*(.-)%s*$") // trim + print("["..peer_addr.."] "..data) + if data == "bye" then + stream:write("bye bye\n") + stream:close() + return end - end) + stream:write("echo: "..data.."\n") + end end }) - .into_function()?; + .eval::() + .expect("cannot create Lua handler"); + + // Store it in the Registry + let handler = lua + .create_registry_value(handler_fn) + .expect("cannot store Lua handler"); task::LocalSet::new() - .run_until(server.call_async::<_, ()>("0.0.0.0:1234")) + .run_until(run_server(lua, handler)) .await + .expect("cannot run server") } -#[tokio::main] -async fn main() { - let lua = Lua::new().into_static(); - - run_server(lua).await.unwrap(); - - // Consume the static reference and drop it. - // This is safe as long as we don't hold any other references to Lua - // or alive resources. - unsafe { Lua::from_static(lua) }; +fn is_transient_error(e: &io::Error) -> bool { + e.kind() == io::ErrorKind::ConnectionRefused + || e.kind() == io::ErrorKind::ConnectionAborted + || e.kind() == io::ErrorKind::ConnectionReset }