Update examples: async_http_client/async_http_server/async_tcp_server

Make them following best practices and remove `unsafe` code.
This commit is contained in:
Alex Orlenko 2021-10-12 11:55:25 +01:00
parent ed48b11e7f
commit 25a4879cde
No known key found for this signature in database
GPG key ID: 4C150C250863B96D
4 changed files with 186 additions and 159 deletions

View file

@ -84,7 +84,7 @@ required-features = ["async", "serialize", "macros"]
[[example]]
name = "async_http_server"
required-features = ["async", "send"]
required-features = ["async", "macros"]
[[example]]
name = "async_tcp_server"

View file

@ -1,26 +1,17 @@
use std::collections::HashMap;
use std::sync::Arc;
use hyper::body::{Body as HyperBody, HttpBody as _};
use hyper::Client as HyperClient;
use tokio::sync::Mutex;
use mlua::{chunk, ExternalResult, Lua, Result, UserData, UserDataMethods};
use mlua::{chunk, AnyUserData, ExternalResult, Lua, Result, UserData, UserDataMethods};
#[derive(Clone)]
struct BodyReader(Arc<Mutex<HyperBody>>);
impl BodyReader {
fn new(body: HyperBody) -> Self {
BodyReader(Arc::new(Mutex::new(body)))
}
}
struct BodyReader(HyperBody);
impl UserData for BodyReader {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_method("read", |lua, reader, ()| async move {
let mut reader = reader.0.lock().await;
if let Some(bytes) = reader.data().await {
methods.add_async_function("read", |lua, reader: AnyUserData| async move {
let mut reader = reader.borrow_mut::<Self>()?;
if let Some(bytes) = reader.0.data().await {
let bytes = bytes.to_lua_err()?;
return Some(lua.create_string(&bytes)).transpose();
}
@ -50,7 +41,7 @@ async fn main() -> Result<()> {
}
lua_resp.set("headers", headers)?;
lua_resp.set("body", BodyReader::new(resp.into_body()))?;
lua_resp.set("body", BodyReader(resp.into_body()))?;
Ok(lua_resp)
})?;
@ -58,7 +49,7 @@ async fn main() -> Result<()> {
let f = lua
.load(chunk! {
local res = $fetch_url(...)
print(res.status)
print("status: "..res.status)
for key, vals in pairs(res.headers) do
for _, val in ipairs(vals) do
print(key..": "..val)

View file

@ -1,10 +1,16 @@
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::service::Service;
use hyper::{Body, Request, Response, Server};
use mlua::{Error, Function, Lua, Result, Table, UserData, UserDataMethods};
use mlua::{
chunk, Error as LuaError, Function, Lua, String as LuaString, Table, UserData, UserDataMethods,
};
struct LuaRequest(SocketAddr, Request<Body>);
@ -15,75 +21,106 @@ impl UserData for LuaRequest {
}
}
async fn run_server(handler: Function<'static>) -> Result<()> {
let make_svc = make_service_fn(|socket: &AddrStream| {
let remote_addr = socket.remote_addr();
let handler = handler.clone();
async move {
Ok::<_, Error>(service_fn(move |req: Request<Body>| {
let handler = handler.clone();
async move {
let lua_req = LuaRequest(remote_addr, req);
let lua_resp: Table = handler.call_async(lua_req).await?;
let body = lua_resp
.get::<_, Option<String>>("body")?
.unwrap_or_default();
pub struct Svc(Rc<Lua>, SocketAddr);
let mut resp = Response::builder()
.status(lua_resp.get::<_, Option<u16>>("status")?.unwrap_or(200));
impl Service<Request<Body>> for Svc {
type Response = Response<Body>;
type Error = LuaError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
// If handler returns an error then generate 5xx response
let lua = self.0.clone();
let lua_req = LuaRequest(self.1, req);
Box::pin(async move {
let handler: Function = lua.named_registry_value("http_handler")?;
match handler.call_async::<_, Table>(lua_req).await {
Ok(lua_resp) => {
let status = lua_resp.get::<_, Option<u16>>("status")?.unwrap_or(200);
let mut resp = Response::builder().status(status);
// Set headers
if let Some(headers) = lua_resp.get::<_, Option<Table>>("headers")? {
for pair in headers.pairs::<String, String>() {
for pair in headers.pairs::<String, LuaString>() {
let (h, v) = pair?;
resp = resp.header(&h, v);
resp = resp.header(&h, v.as_bytes());
}
}
Ok::<_, Error>(resp.body(Body::from(body)).unwrap())
let body = lua_resp
.get::<_, Option<LuaString>>("body")?
.map(|b| Body::from(b.as_bytes().to_vec()))
.unwrap_or_else(Body::empty);
Ok(resp.body(body).unwrap())
}
}))
}
});
Err(err) => {
eprintln!("{}", err);
Ok(Response::builder()
.status(500)
.body(Body::from("Internal Server Error"))
.unwrap())
}
}
})
}
}
#[tokio::main(flavor = "current_thread")]
async fn main() {
let lua = Rc::new(Lua::new());
// Create Lua handler function
let handler: Function = lua
.load(chunk! {
function(req)
return {
status = 200,
headers = {
["X-Req-Method"] = req:method(),
["X-Remote-Addr"] = req:remote_addr(),
},
body = "Hello from Lua!\n"
}
end
})
.eval()
.expect("cannot create Lua handler");
// Store it in the Registry
lua.set_named_registry_value("http_handler", handler)
.expect("cannot store Lua handler");
let addr = ([127, 0, 0, 1], 3000).into();
let server = Server::bind(&addr).executor(LocalExec).serve(make_svc);
let server = Server::bind(&addr).executor(LocalExec).serve(MakeSvc(lua));
println!("Listening on http://{}", addr);
tokio::task::LocalSet::new()
.run_until(server)
.await
.map_err(Error::external)
// Create `LocalSet` to spawn !Send futures
let local = tokio::task::LocalSet::new();
local.run_until(server).await.expect("cannot run server")
}
#[tokio::main]
async fn main() -> Result<()> {
let lua = Lua::new().into_static();
struct MakeSvc(Rc<Lua>);
let handler: Function = lua
.load(
r#"
function(req)
return {
status = 200,
headers = {
["X-Req-Method"] = req:method(),
["X-Remote-Addr"] = req:remote_addr(),
},
body = "Hello, World!\n"
}
end
"#,
)
.eval()?;
impl Service<&AddrStream> for MakeSvc {
type Response = Svc;
type Error = hyper::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
run_server(handler).await?;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
// Consume the static reference and drop it.
// This is safe as long as we don't hold any other references to Lua
// or alive resources.
unsafe { Lua::from_static(lua) };
Ok(())
fn call(&mut self, stream: &AddrStream) -> Self::Future {
let lua = self.0.clone();
let remote_addr = stream.remote_addr();
Box::pin(async move { Ok(Svc(lua, remote_addr)) })
}
}
#[derive(Clone, Copy, Debug)]

View file

@ -1,122 +1,121 @@
use std::sync::Arc;
use std::io;
use std::net::SocketAddr;
use std::rc::Rc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tokio::task;
use mlua::{chunk, Function, Lua, Result, String as LuaString, UserData, UserDataMethods};
use mlua::{
chunk, AnyUserData, Function, Lua, RegistryKey, String as LuaString, UserData, UserDataMethods,
};
struct LuaTcp;
#[derive(Clone)]
struct LuaTcpListener(Arc<Mutex<TcpListener>>);
#[derive(Clone)]
struct LuaTcpStream(Arc<Mutex<TcpStream>>);
impl UserData for LuaTcp {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_function("bind", |_, addr: String| async move {
let listener = TcpListener::bind(addr).await?;
Ok(LuaTcpListener(Arc::new(Mutex::new(listener))))
});
}
}
impl UserData for LuaTcpListener {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_method("accept", |_, listener, ()| async move {
let (stream, _) = listener.0.lock().await.accept().await?;
Ok(LuaTcpStream(Arc::new(Mutex::new(stream))))
});
}
}
struct LuaTcpStream(TcpStream);
impl UserData for LuaTcpStream {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_method("peer_addr", |_, stream, ()| async move {
Ok(stream.0.lock().await.peer_addr()?.to_string())
methods.add_method("peer_addr", |_, this, ()| {
Ok(this.0.peer_addr()?.to_string())
});
methods.add_async_method("read", |lua, stream, size: usize| async move {
let mut buf = vec![0; size];
let n = stream.0.lock().await.read(&mut buf).await?;
buf.truncate(n);
lua.create_string(&buf)
});
methods.add_async_function(
"read",
|lua, (this, size): (AnyUserData, usize)| async move {
let mut this = this.borrow_mut::<Self>()?;
let mut buf = vec![0; size];
let n = this.0.read(&mut buf).await?;
buf.truncate(n);
lua.create_string(&buf)
},
);
methods.add_async_method("write", |_, stream, data: LuaString| async move {
let n = stream.0.lock().await.write(&data.as_bytes()).await?;
Ok(n)
});
methods.add_async_function(
"write",
|_, (this, data): (AnyUserData, LuaString)| async move {
let mut this = this.borrow_mut::<Self>()?;
let n = this.0.write(&data.as_bytes()).await?;
Ok(n)
},
);
methods.add_async_method("close", |_, stream, ()| async move {
stream.0.lock().await.shutdown().await?;
methods.add_async_function("close", |_, this: AnyUserData| async move {
let mut this = this.borrow_mut::<Self>()?;
this.0.shutdown().await?;
Ok(())
});
}
}
async fn run_server(lua: &'static Lua) -> Result<()> {
let spawn = lua.create_function(move |_, func: Function| {
task::spawn_local(async move { func.call_async::<_, ()>(()).await });
Ok(())
})?;
async fn run_server(lua: Lua, handler: RegistryKey) -> io::Result<()> {
let addr: SocketAddr = ([127, 0, 0, 1], 3000).into();
let listener = TcpListener::bind(addr).await.expect("cannot bind addr");
let tcp = LuaTcp;
println!("Listening on {}", addr);
let server = lua
let lua = Rc::new(lua);
let handler = Rc::new(handler);
loop {
let (stream, _) = match listener.accept().await {
Ok(res) => res,
Err(err) if is_transient_error(&err) => continue,
Err(err) => return Err(err),
};
let lua = lua.clone();
let handler = handler.clone();
task::spawn_local(async move {
let handler: Function = lua
.registry_value(&handler)
.expect("cannot get Lua handler");
let stream = LuaTcpStream(stream);
if let Err(err) = handler.call_async::<_, ()>(stream).await {
eprintln!("{}", err);
}
});
}
}
#[tokio::main(flavor = "current_thread")]
async fn main() {
let lua = Lua::new();
// Create Lua handler function
let handler_fn = lua
.load(chunk! {
local addr = ...
local listener = $tcp.bind(addr)
print("listening on "..addr)
local accept_new = true
while true do
local stream = listener:accept()
function(stream)
local peer_addr = stream:peer_addr()
print("connected from "..peer_addr)
if not accept_new then
return
end
$spawn(function()
while true do
local data = stream:read(100)
data = data:match("^%s*(.-)%s*$") -- trim
print("["..peer_addr.."] "..data)
if data == "bye" then
stream:write("bye bye\n")
stream:close()
return
end
if data == "exit" then
stream:close()
accept_new = false
return
end
stream:write("echo: "..data.."\n")
while true do
local data = stream:read(100)
data = data:match("^%s*(.-)%s*$") // trim
print("["..peer_addr.."] "..data)
if data == "bye" then
stream:write("bye bye\n")
stream:close()
return
end
end)
stream:write("echo: "..data.."\n")
end
end
})
.into_function()?;
.eval::<Function>()
.expect("cannot create Lua handler");
// Store it in the Registry
let handler = lua
.create_registry_value(handler_fn)
.expect("cannot store Lua handler");
task::LocalSet::new()
.run_until(server.call_async::<_, ()>("0.0.0.0:1234"))
.run_until(run_server(lua, handler))
.await
.expect("cannot run server")
}
#[tokio::main]
async fn main() {
let lua = Lua::new().into_static();
run_server(lua).await.unwrap();
// Consume the static reference and drop it.
// This is safe as long as we don't hold any other references to Lua
// or alive resources.
unsafe { Lua::from_static(lua) };
fn is_transient_error(e: &io::Error) -> bool {
e.kind() == io::ErrorKind::ConnectionRefused
|| e.kind() == io::ErrorKind::ConnectionAborted
|| e.kind() == io::ErrorKind::ConnectionReset
}