Update examples: async_http_client/async_http_server/async_tcp_server

Make them following best practices and remove `unsafe` code.
This commit is contained in:
Alex Orlenko 2021-10-12 11:55:25 +01:00
parent ed48b11e7f
commit 25a4879cde
No known key found for this signature in database
GPG key ID: 4C150C250863B96D
4 changed files with 186 additions and 159 deletions

View file

@ -84,7 +84,7 @@ required-features = ["async", "serialize", "macros"]
[[example]] [[example]]
name = "async_http_server" name = "async_http_server"
required-features = ["async", "send"] required-features = ["async", "macros"]
[[example]] [[example]]
name = "async_tcp_server" name = "async_tcp_server"

View file

@ -1,26 +1,17 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc;
use hyper::body::{Body as HyperBody, HttpBody as _}; use hyper::body::{Body as HyperBody, HttpBody as _};
use hyper::Client as HyperClient; use hyper::Client as HyperClient;
use tokio::sync::Mutex;
use mlua::{chunk, ExternalResult, Lua, Result, UserData, UserDataMethods}; use mlua::{chunk, AnyUserData, ExternalResult, Lua, Result, UserData, UserDataMethods};
#[derive(Clone)] struct BodyReader(HyperBody);
struct BodyReader(Arc<Mutex<HyperBody>>);
impl BodyReader {
fn new(body: HyperBody) -> Self {
BodyReader(Arc::new(Mutex::new(body)))
}
}
impl UserData for BodyReader { impl UserData for BodyReader {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_method("read", |lua, reader, ()| async move { methods.add_async_function("read", |lua, reader: AnyUserData| async move {
let mut reader = reader.0.lock().await; let mut reader = reader.borrow_mut::<Self>()?;
if let Some(bytes) = reader.data().await { if let Some(bytes) = reader.0.data().await {
let bytes = bytes.to_lua_err()?; let bytes = bytes.to_lua_err()?;
return Some(lua.create_string(&bytes)).transpose(); return Some(lua.create_string(&bytes)).transpose();
} }
@ -50,7 +41,7 @@ async fn main() -> Result<()> {
} }
lua_resp.set("headers", headers)?; lua_resp.set("headers", headers)?;
lua_resp.set("body", BodyReader::new(resp.into_body()))?; lua_resp.set("body", BodyReader(resp.into_body()))?;
Ok(lua_resp) Ok(lua_resp)
})?; })?;
@ -58,7 +49,7 @@ async fn main() -> Result<()> {
let f = lua let f = lua
.load(chunk! { .load(chunk! {
local res = $fetch_url(...) local res = $fetch_url(...)
print(res.status) print("status: "..res.status)
for key, vals in pairs(res.headers) do for key, vals in pairs(res.headers) do
for _, val in ipairs(vals) do for _, val in ipairs(vals) do
print(key..": "..val) print(key..": "..val)

View file

@ -1,10 +1,16 @@
use std::future::Future;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use hyper::server::conn::AddrStream; use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn}; use hyper::service::Service;
use hyper::{Body, Request, Response, Server}; use hyper::{Body, Request, Response, Server};
use mlua::{Error, Function, Lua, Result, Table, UserData, UserDataMethods}; use mlua::{
chunk, Error as LuaError, Function, Lua, String as LuaString, Table, UserData, UserDataMethods,
};
struct LuaRequest(SocketAddr, Request<Body>); struct LuaRequest(SocketAddr, Request<Body>);
@ -15,75 +21,106 @@ impl UserData for LuaRequest {
} }
} }
async fn run_server(handler: Function<'static>) -> Result<()> { pub struct Svc(Rc<Lua>, SocketAddr);
let make_svc = make_service_fn(|socket: &AddrStream| {
let remote_addr = socket.remote_addr();
let handler = handler.clone();
async move {
Ok::<_, Error>(service_fn(move |req: Request<Body>| {
let handler = handler.clone();
async move {
let lua_req = LuaRequest(remote_addr, req);
let lua_resp: Table = handler.call_async(lua_req).await?;
let body = lua_resp
.get::<_, Option<String>>("body")?
.unwrap_or_default();
let mut resp = Response::builder() impl Service<Request<Body>> for Svc {
.status(lua_resp.get::<_, Option<u16>>("status")?.unwrap_or(200)); type Response = Response<Body>;
type Error = LuaError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
// If handler returns an error then generate 5xx response
let lua = self.0.clone();
let lua_req = LuaRequest(self.1, req);
Box::pin(async move {
let handler: Function = lua.named_registry_value("http_handler")?;
match handler.call_async::<_, Table>(lua_req).await {
Ok(lua_resp) => {
let status = lua_resp.get::<_, Option<u16>>("status")?.unwrap_or(200);
let mut resp = Response::builder().status(status);
// Set headers
if let Some(headers) = lua_resp.get::<_, Option<Table>>("headers")? { if let Some(headers) = lua_resp.get::<_, Option<Table>>("headers")? {
for pair in headers.pairs::<String, String>() { for pair in headers.pairs::<String, LuaString>() {
let (h, v) = pair?; let (h, v) = pair?;
resp = resp.header(&h, v); resp = resp.header(&h, v.as_bytes());
} }
} }
Ok::<_, Error>(resp.body(Body::from(body)).unwrap()) let body = lua_resp
.get::<_, Option<LuaString>>("body")?
.map(|b| Body::from(b.as_bytes().to_vec()))
.unwrap_or_else(Body::empty);
Ok(resp.body(body).unwrap())
} }
})) Err(err) => {
} eprintln!("{}", err);
}); Ok(Response::builder()
.status(500)
.body(Body::from("Internal Server Error"))
.unwrap())
}
}
})
}
}
#[tokio::main(flavor = "current_thread")]
async fn main() {
let lua = Rc::new(Lua::new());
// Create Lua handler function
let handler: Function = lua
.load(chunk! {
function(req)
return {
status = 200,
headers = {
["X-Req-Method"] = req:method(),
["X-Remote-Addr"] = req:remote_addr(),
},
body = "Hello from Lua!\n"
}
end
})
.eval()
.expect("cannot create Lua handler");
// Store it in the Registry
lua.set_named_registry_value("http_handler", handler)
.expect("cannot store Lua handler");
let addr = ([127, 0, 0, 1], 3000).into(); let addr = ([127, 0, 0, 1], 3000).into();
let server = Server::bind(&addr).executor(LocalExec).serve(make_svc); let server = Server::bind(&addr).executor(LocalExec).serve(MakeSvc(lua));
println!("Listening on http://{}", addr); println!("Listening on http://{}", addr);
tokio::task::LocalSet::new() // Create `LocalSet` to spawn !Send futures
.run_until(server) let local = tokio::task::LocalSet::new();
.await local.run_until(server).await.expect("cannot run server")
.map_err(Error::external)
} }
#[tokio::main] struct MakeSvc(Rc<Lua>);
async fn main() -> Result<()> {
let lua = Lua::new().into_static();
let handler: Function = lua impl Service<&AddrStream> for MakeSvc {
.load( type Response = Svc;
r#" type Error = hyper::Error;
function(req) type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
return {
status = 200,
headers = {
["X-Req-Method"] = req:method(),
["X-Remote-Addr"] = req:remote_addr(),
},
body = "Hello, World!\n"
}
end
"#,
)
.eval()?;
run_server(handler).await?; fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
// Consume the static reference and drop it. fn call(&mut self, stream: &AddrStream) -> Self::Future {
// This is safe as long as we don't hold any other references to Lua let lua = self.0.clone();
// or alive resources. let remote_addr = stream.remote_addr();
unsafe { Lua::from_static(lua) }; Box::pin(async move { Ok(Svc(lua, remote_addr)) })
Ok(()) }
} }
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]

View file

@ -1,122 +1,121 @@
use std::sync::Arc; use std::io;
use std::net::SocketAddr;
use std::rc::Rc;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tokio::task; use tokio::task;
use mlua::{chunk, Function, Lua, Result, String as LuaString, UserData, UserDataMethods}; use mlua::{
chunk, AnyUserData, Function, Lua, RegistryKey, String as LuaString, UserData, UserDataMethods,
};
struct LuaTcp; struct LuaTcpStream(TcpStream);
#[derive(Clone)]
struct LuaTcpListener(Arc<Mutex<TcpListener>>);
#[derive(Clone)]
struct LuaTcpStream(Arc<Mutex<TcpStream>>);
impl UserData for LuaTcp {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_function("bind", |_, addr: String| async move {
let listener = TcpListener::bind(addr).await?;
Ok(LuaTcpListener(Arc::new(Mutex::new(listener))))
});
}
}
impl UserData for LuaTcpListener {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_method("accept", |_, listener, ()| async move {
let (stream, _) = listener.0.lock().await.accept().await?;
Ok(LuaTcpStream(Arc::new(Mutex::new(stream))))
});
}
}
impl UserData for LuaTcpStream { impl UserData for LuaTcpStream {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_method("peer_addr", |_, stream, ()| async move { methods.add_method("peer_addr", |_, this, ()| {
Ok(stream.0.lock().await.peer_addr()?.to_string()) Ok(this.0.peer_addr()?.to_string())
}); });
methods.add_async_method("read", |lua, stream, size: usize| async move { methods.add_async_function(
let mut buf = vec![0; size]; "read",
let n = stream.0.lock().await.read(&mut buf).await?; |lua, (this, size): (AnyUserData, usize)| async move {
buf.truncate(n); let mut this = this.borrow_mut::<Self>()?;
lua.create_string(&buf) let mut buf = vec![0; size];
}); let n = this.0.read(&mut buf).await?;
buf.truncate(n);
lua.create_string(&buf)
},
);
methods.add_async_method("write", |_, stream, data: LuaString| async move { methods.add_async_function(
let n = stream.0.lock().await.write(&data.as_bytes()).await?; "write",
Ok(n) |_, (this, data): (AnyUserData, LuaString)| async move {
}); let mut this = this.borrow_mut::<Self>()?;
let n = this.0.write(&data.as_bytes()).await?;
Ok(n)
},
);
methods.add_async_method("close", |_, stream, ()| async move { methods.add_async_function("close", |_, this: AnyUserData| async move {
stream.0.lock().await.shutdown().await?; let mut this = this.borrow_mut::<Self>()?;
this.0.shutdown().await?;
Ok(()) Ok(())
}); });
} }
} }
async fn run_server(lua: &'static Lua) -> Result<()> { async fn run_server(lua: Lua, handler: RegistryKey) -> io::Result<()> {
let spawn = lua.create_function(move |_, func: Function| { let addr: SocketAddr = ([127, 0, 0, 1], 3000).into();
task::spawn_local(async move { func.call_async::<_, ()>(()).await }); let listener = TcpListener::bind(addr).await.expect("cannot bind addr");
Ok(())
})?;
let tcp = LuaTcp; println!("Listening on {}", addr);
let server = lua let lua = Rc::new(lua);
let handler = Rc::new(handler);
loop {
let (stream, _) = match listener.accept().await {
Ok(res) => res,
Err(err) if is_transient_error(&err) => continue,
Err(err) => return Err(err),
};
let lua = lua.clone();
let handler = handler.clone();
task::spawn_local(async move {
let handler: Function = lua
.registry_value(&handler)
.expect("cannot get Lua handler");
let stream = LuaTcpStream(stream);
if let Err(err) = handler.call_async::<_, ()>(stream).await {
eprintln!("{}", err);
}
});
}
}
#[tokio::main(flavor = "current_thread")]
async fn main() {
let lua = Lua::new();
// Create Lua handler function
let handler_fn = lua
.load(chunk! { .load(chunk! {
local addr = ... function(stream)
local listener = $tcp.bind(addr)
print("listening on "..addr)
local accept_new = true
while true do
local stream = listener:accept()
local peer_addr = stream:peer_addr() local peer_addr = stream:peer_addr()
print("connected from "..peer_addr) print("connected from "..peer_addr)
if not accept_new then while true do
return local data = stream:read(100)
end data = data:match("^%s*(.-)%s*$") // trim
print("["..peer_addr.."] "..data)
$spawn(function() if data == "bye" then
while true do stream:write("bye bye\n")
local data = stream:read(100) stream:close()
data = data:match("^%s*(.-)%s*$") -- trim return
print("["..peer_addr.."] "..data)
if data == "bye" then
stream:write("bye bye\n")
stream:close()
return
end
if data == "exit" then
stream:close()
accept_new = false
return
end
stream:write("echo: "..data.."\n")
end end
end) stream:write("echo: "..data.."\n")
end
end end
}) })
.into_function()?; .eval::<Function>()
.expect("cannot create Lua handler");
// Store it in the Registry
let handler = lua
.create_registry_value(handler_fn)
.expect("cannot store Lua handler");
task::LocalSet::new() task::LocalSet::new()
.run_until(server.call_async::<_, ()>("0.0.0.0:1234")) .run_until(run_server(lua, handler))
.await .await
.expect("cannot run server")
} }
#[tokio::main] fn is_transient_error(e: &io::Error) -> bool {
async fn main() { e.kind() == io::ErrorKind::ConnectionRefused
let lua = Lua::new().into_static(); || e.kind() == io::ErrorKind::ConnectionAborted
|| e.kind() == io::ErrorKind::ConnectionReset
run_server(lua).await.unwrap();
// Consume the static reference and drop it.
// This is safe as long as we don't hold any other references to Lua
// or alive resources.
unsafe { Lua::from_static(lua) };
} }