Update examples: async_http_client/async_http_server/async_tcp_server
Make them following best practices and remove `unsafe` code.
This commit is contained in:
parent
ed48b11e7f
commit
25a4879cde
|
@ -84,7 +84,7 @@ required-features = ["async", "serialize", "macros"]
|
|||
|
||||
[[example]]
|
||||
name = "async_http_server"
|
||||
required-features = ["async", "send"]
|
||||
required-features = ["async", "macros"]
|
||||
|
||||
[[example]]
|
||||
name = "async_tcp_server"
|
||||
|
|
|
@ -1,26 +1,17 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use hyper::body::{Body as HyperBody, HttpBody as _};
|
||||
use hyper::Client as HyperClient;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use mlua::{chunk, ExternalResult, Lua, Result, UserData, UserDataMethods};
|
||||
use mlua::{chunk, AnyUserData, ExternalResult, Lua, Result, UserData, UserDataMethods};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct BodyReader(Arc<Mutex<HyperBody>>);
|
||||
|
||||
impl BodyReader {
|
||||
fn new(body: HyperBody) -> Self {
|
||||
BodyReader(Arc::new(Mutex::new(body)))
|
||||
}
|
||||
}
|
||||
struct BodyReader(HyperBody);
|
||||
|
||||
impl UserData for BodyReader {
|
||||
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
|
||||
methods.add_async_method("read", |lua, reader, ()| async move {
|
||||
let mut reader = reader.0.lock().await;
|
||||
if let Some(bytes) = reader.data().await {
|
||||
methods.add_async_function("read", |lua, reader: AnyUserData| async move {
|
||||
let mut reader = reader.borrow_mut::<Self>()?;
|
||||
if let Some(bytes) = reader.0.data().await {
|
||||
let bytes = bytes.to_lua_err()?;
|
||||
return Some(lua.create_string(&bytes)).transpose();
|
||||
}
|
||||
|
@ -50,7 +41,7 @@ async fn main() -> Result<()> {
|
|||
}
|
||||
|
||||
lua_resp.set("headers", headers)?;
|
||||
lua_resp.set("body", BodyReader::new(resp.into_body()))?;
|
||||
lua_resp.set("body", BodyReader(resp.into_body()))?;
|
||||
|
||||
Ok(lua_resp)
|
||||
})?;
|
||||
|
@ -58,7 +49,7 @@ async fn main() -> Result<()> {
|
|||
let f = lua
|
||||
.load(chunk! {
|
||||
local res = $fetch_url(...)
|
||||
print(res.status)
|
||||
print("status: "..res.status)
|
||||
for key, vals in pairs(res.headers) do
|
||||
for _, val in ipairs(vals) do
|
||||
print(key..": "..val)
|
||||
|
|
|
@ -1,10 +1,16 @@
|
|||
use std::future::Future;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::rc::Rc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use hyper::server::conn::AddrStream;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::service::Service;
|
||||
use hyper::{Body, Request, Response, Server};
|
||||
|
||||
use mlua::{Error, Function, Lua, Result, Table, UserData, UserDataMethods};
|
||||
use mlua::{
|
||||
chunk, Error as LuaError, Function, Lua, String as LuaString, Table, UserData, UserDataMethods,
|
||||
};
|
||||
|
||||
struct LuaRequest(SocketAddr, Request<Body>);
|
||||
|
||||
|
@ -15,75 +21,106 @@ impl UserData for LuaRequest {
|
|||
}
|
||||
}
|
||||
|
||||
async fn run_server(handler: Function<'static>) -> Result<()> {
|
||||
let make_svc = make_service_fn(|socket: &AddrStream| {
|
||||
let remote_addr = socket.remote_addr();
|
||||
let handler = handler.clone();
|
||||
async move {
|
||||
Ok::<_, Error>(service_fn(move |req: Request<Body>| {
|
||||
let handler = handler.clone();
|
||||
async move {
|
||||
let lua_req = LuaRequest(remote_addr, req);
|
||||
let lua_resp: Table = handler.call_async(lua_req).await?;
|
||||
let body = lua_resp
|
||||
.get::<_, Option<String>>("body")?
|
||||
.unwrap_or_default();
|
||||
pub struct Svc(Rc<Lua>, SocketAddr);
|
||||
|
||||
let mut resp = Response::builder()
|
||||
.status(lua_resp.get::<_, Option<u16>>("status")?.unwrap_or(200));
|
||||
impl Service<Request<Body>> for Svc {
|
||||
type Response = Response<Body>;
|
||||
type Error = LuaError;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||
// If handler returns an error then generate 5xx response
|
||||
let lua = self.0.clone();
|
||||
let lua_req = LuaRequest(self.1, req);
|
||||
Box::pin(async move {
|
||||
let handler: Function = lua.named_registry_value("http_handler")?;
|
||||
match handler.call_async::<_, Table>(lua_req).await {
|
||||
Ok(lua_resp) => {
|
||||
let status = lua_resp.get::<_, Option<u16>>("status")?.unwrap_or(200);
|
||||
let mut resp = Response::builder().status(status);
|
||||
|
||||
// Set headers
|
||||
if let Some(headers) = lua_resp.get::<_, Option<Table>>("headers")? {
|
||||
for pair in headers.pairs::<String, String>() {
|
||||
for pair in headers.pairs::<String, LuaString>() {
|
||||
let (h, v) = pair?;
|
||||
resp = resp.header(&h, v);
|
||||
resp = resp.header(&h, v.as_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
Ok::<_, Error>(resp.body(Body::from(body)).unwrap())
|
||||
let body = lua_resp
|
||||
.get::<_, Option<LuaString>>("body")?
|
||||
.map(|b| Body::from(b.as_bytes().to_vec()))
|
||||
.unwrap_or_else(Body::empty);
|
||||
|
||||
Ok(resp.body(body).unwrap())
|
||||
}
|
||||
}))
|
||||
}
|
||||
});
|
||||
Err(err) => {
|
||||
eprintln!("{}", err);
|
||||
Ok(Response::builder()
|
||||
.status(500)
|
||||
.body(Body::from("Internal Server Error"))
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
async fn main() {
|
||||
let lua = Rc::new(Lua::new());
|
||||
|
||||
// Create Lua handler function
|
||||
let handler: Function = lua
|
||||
.load(chunk! {
|
||||
function(req)
|
||||
return {
|
||||
status = 200,
|
||||
headers = {
|
||||
["X-Req-Method"] = req:method(),
|
||||
["X-Remote-Addr"] = req:remote_addr(),
|
||||
},
|
||||
body = "Hello from Lua!\n"
|
||||
}
|
||||
end
|
||||
})
|
||||
.eval()
|
||||
.expect("cannot create Lua handler");
|
||||
|
||||
// Store it in the Registry
|
||||
lua.set_named_registry_value("http_handler", handler)
|
||||
.expect("cannot store Lua handler");
|
||||
|
||||
let addr = ([127, 0, 0, 1], 3000).into();
|
||||
let server = Server::bind(&addr).executor(LocalExec).serve(make_svc);
|
||||
let server = Server::bind(&addr).executor(LocalExec).serve(MakeSvc(lua));
|
||||
|
||||
println!("Listening on http://{}", addr);
|
||||
|
||||
tokio::task::LocalSet::new()
|
||||
.run_until(server)
|
||||
.await
|
||||
.map_err(Error::external)
|
||||
// Create `LocalSet` to spawn !Send futures
|
||||
let local = tokio::task::LocalSet::new();
|
||||
local.run_until(server).await.expect("cannot run server")
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let lua = Lua::new().into_static();
|
||||
struct MakeSvc(Rc<Lua>);
|
||||
|
||||
let handler: Function = lua
|
||||
.load(
|
||||
r#"
|
||||
function(req)
|
||||
return {
|
||||
status = 200,
|
||||
headers = {
|
||||
["X-Req-Method"] = req:method(),
|
||||
["X-Remote-Addr"] = req:remote_addr(),
|
||||
},
|
||||
body = "Hello, World!\n"
|
||||
}
|
||||
end
|
||||
"#,
|
||||
)
|
||||
.eval()?;
|
||||
impl Service<&AddrStream> for MakeSvc {
|
||||
type Response = Svc;
|
||||
type Error = hyper::Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
|
||||
|
||||
run_server(handler).await?;
|
||||
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
// Consume the static reference and drop it.
|
||||
// This is safe as long as we don't hold any other references to Lua
|
||||
// or alive resources.
|
||||
unsafe { Lua::from_static(lua) };
|
||||
Ok(())
|
||||
fn call(&mut self, stream: &AddrStream) -> Self::Future {
|
||||
let lua = self.0.clone();
|
||||
let remote_addr = stream.remote_addr();
|
||||
Box::pin(async move { Ok(Svc(lua, remote_addr)) })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
|
|
|
@ -1,122 +1,121 @@
|
|||
use std::sync::Arc;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::rc::Rc;
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task;
|
||||
|
||||
use mlua::{chunk, Function, Lua, Result, String as LuaString, UserData, UserDataMethods};
|
||||
use mlua::{
|
||||
chunk, AnyUserData, Function, Lua, RegistryKey, String as LuaString, UserData, UserDataMethods,
|
||||
};
|
||||
|
||||
struct LuaTcp;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct LuaTcpListener(Arc<Mutex<TcpListener>>);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct LuaTcpStream(Arc<Mutex<TcpStream>>);
|
||||
|
||||
impl UserData for LuaTcp {
|
||||
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
|
||||
methods.add_async_function("bind", |_, addr: String| async move {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
Ok(LuaTcpListener(Arc::new(Mutex::new(listener))))
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl UserData for LuaTcpListener {
|
||||
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
|
||||
methods.add_async_method("accept", |_, listener, ()| async move {
|
||||
let (stream, _) = listener.0.lock().await.accept().await?;
|
||||
Ok(LuaTcpStream(Arc::new(Mutex::new(stream))))
|
||||
});
|
||||
}
|
||||
}
|
||||
struct LuaTcpStream(TcpStream);
|
||||
|
||||
impl UserData for LuaTcpStream {
|
||||
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
|
||||
methods.add_async_method("peer_addr", |_, stream, ()| async move {
|
||||
Ok(stream.0.lock().await.peer_addr()?.to_string())
|
||||
methods.add_method("peer_addr", |_, this, ()| {
|
||||
Ok(this.0.peer_addr()?.to_string())
|
||||
});
|
||||
|
||||
methods.add_async_method("read", |lua, stream, size: usize| async move {
|
||||
let mut buf = vec![0; size];
|
||||
let n = stream.0.lock().await.read(&mut buf).await?;
|
||||
buf.truncate(n);
|
||||
lua.create_string(&buf)
|
||||
});
|
||||
methods.add_async_function(
|
||||
"read",
|
||||
|lua, (this, size): (AnyUserData, usize)| async move {
|
||||
let mut this = this.borrow_mut::<Self>()?;
|
||||
let mut buf = vec![0; size];
|
||||
let n = this.0.read(&mut buf).await?;
|
||||
buf.truncate(n);
|
||||
lua.create_string(&buf)
|
||||
},
|
||||
);
|
||||
|
||||
methods.add_async_method("write", |_, stream, data: LuaString| async move {
|
||||
let n = stream.0.lock().await.write(&data.as_bytes()).await?;
|
||||
Ok(n)
|
||||
});
|
||||
methods.add_async_function(
|
||||
"write",
|
||||
|_, (this, data): (AnyUserData, LuaString)| async move {
|
||||
let mut this = this.borrow_mut::<Self>()?;
|
||||
let n = this.0.write(&data.as_bytes()).await?;
|
||||
Ok(n)
|
||||
},
|
||||
);
|
||||
|
||||
methods.add_async_method("close", |_, stream, ()| async move {
|
||||
stream.0.lock().await.shutdown().await?;
|
||||
methods.add_async_function("close", |_, this: AnyUserData| async move {
|
||||
let mut this = this.borrow_mut::<Self>()?;
|
||||
this.0.shutdown().await?;
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_server(lua: &'static Lua) -> Result<()> {
|
||||
let spawn = lua.create_function(move |_, func: Function| {
|
||||
task::spawn_local(async move { func.call_async::<_, ()>(()).await });
|
||||
Ok(())
|
||||
})?;
|
||||
async fn run_server(lua: Lua, handler: RegistryKey) -> io::Result<()> {
|
||||
let addr: SocketAddr = ([127, 0, 0, 1], 3000).into();
|
||||
let listener = TcpListener::bind(addr).await.expect("cannot bind addr");
|
||||
|
||||
let tcp = LuaTcp;
|
||||
println!("Listening on {}", addr);
|
||||
|
||||
let server = lua
|
||||
let lua = Rc::new(lua);
|
||||
let handler = Rc::new(handler);
|
||||
loop {
|
||||
let (stream, _) = match listener.accept().await {
|
||||
Ok(res) => res,
|
||||
Err(err) if is_transient_error(&err) => continue,
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
|
||||
let lua = lua.clone();
|
||||
let handler = handler.clone();
|
||||
task::spawn_local(async move {
|
||||
let handler: Function = lua
|
||||
.registry_value(&handler)
|
||||
.expect("cannot get Lua handler");
|
||||
|
||||
let stream = LuaTcpStream(stream);
|
||||
if let Err(err) = handler.call_async::<_, ()>(stream).await {
|
||||
eprintln!("{}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
async fn main() {
|
||||
let lua = Lua::new();
|
||||
|
||||
// Create Lua handler function
|
||||
let handler_fn = lua
|
||||
.load(chunk! {
|
||||
local addr = ...
|
||||
local listener = $tcp.bind(addr)
|
||||
print("listening on "..addr)
|
||||
|
||||
local accept_new = true
|
||||
while true do
|
||||
local stream = listener:accept()
|
||||
function(stream)
|
||||
local peer_addr = stream:peer_addr()
|
||||
print("connected from "..peer_addr)
|
||||
|
||||
if not accept_new then
|
||||
return
|
||||
end
|
||||
|
||||
$spawn(function()
|
||||
while true do
|
||||
local data = stream:read(100)
|
||||
data = data:match("^%s*(.-)%s*$") -- trim
|
||||
print("["..peer_addr.."] "..data)
|
||||
if data == "bye" then
|
||||
stream:write("bye bye\n")
|
||||
stream:close()
|
||||
return
|
||||
end
|
||||
if data == "exit" then
|
||||
stream:close()
|
||||
accept_new = false
|
||||
return
|
||||
end
|
||||
stream:write("echo: "..data.."\n")
|
||||
while true do
|
||||
local data = stream:read(100)
|
||||
data = data:match("^%s*(.-)%s*$") // trim
|
||||
print("["..peer_addr.."] "..data)
|
||||
if data == "bye" then
|
||||
stream:write("bye bye\n")
|
||||
stream:close()
|
||||
return
|
||||
end
|
||||
end)
|
||||
stream:write("echo: "..data.."\n")
|
||||
end
|
||||
end
|
||||
})
|
||||
.into_function()?;
|
||||
.eval::<Function>()
|
||||
.expect("cannot create Lua handler");
|
||||
|
||||
// Store it in the Registry
|
||||
let handler = lua
|
||||
.create_registry_value(handler_fn)
|
||||
.expect("cannot store Lua handler");
|
||||
|
||||
task::LocalSet::new()
|
||||
.run_until(server.call_async::<_, ()>("0.0.0.0:1234"))
|
||||
.run_until(run_server(lua, handler))
|
||||
.await
|
||||
.expect("cannot run server")
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let lua = Lua::new().into_static();
|
||||
|
||||
run_server(lua).await.unwrap();
|
||||
|
||||
// Consume the static reference and drop it.
|
||||
// This is safe as long as we don't hold any other references to Lua
|
||||
// or alive resources.
|
||||
unsafe { Lua::from_static(lua) };
|
||||
fn is_transient_error(e: &io::Error) -> bool {
|
||||
e.kind() == io::ErrorKind::ConnectionRefused
|
||||
|| e.kind() == io::ErrorKind::ConnectionAborted
|
||||
|| e.kind() == io::ErrorKind::ConnectionReset
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue