#![cfg(feature = "async")]
//! Tests for mlua's async support: async functions, threads driven as futures/streams,
//! async table and userdata methods, and async scopes.

use std::cell::Cell;
use std::rc::Rc;
use std::sync::{
    atomic::{AtomicI64, AtomicU64, Ordering},
    Arc,
};
use std::time::Duration;

use futures_timer::Delay;
use futures_util::stream::TryStreamExt;

use mlua::{
    Error, Function, Lua, LuaOptions, Result, StdLib, Table, TableExt, Thread, UserData,
    UserDataMethods, Value,
};

/// An async Rust function registered as a Lua global is callable from a chunk
/// evaluated with `eval_async`.
#[tokio::test]
async fn test_async_function() -> Result<()> {
    let lua = Lua::new();

    let f = lua
        .create_async_function(|_lua, (a, b, c): (i64, i64, i64)| async move { Ok((a + b) * c) })?;
    lua.globals().set("f", f)?;

    let res: i64 = lua.load("f(1, 2, 3)").eval_async().await?;
    assert_eq!(res, 9);

    Ok(())
}

/// An async function that awaits a timer; the argument given to `call_async`
/// reaches the chunk through its varargs (`...`).
#[tokio::test]
async fn test_async_sleep() -> Result<()> {
    let lua = Lua::new();

    let sleep = lua.create_async_function(move |_lua, n: u64| async move {
        Delay::new(Duration::from_millis(n)).await;
        Ok(format!("elapsed:{}ms", n))
    })?;
    lua.globals().set("sleep", sleep)?;

    let res: String = lua.load(r"return sleep(...)").call_async(100).await?;
    assert_eq!(res, "elapsed:100ms");

    Ok(())
}

/// Calling an async function synchronously fails with `RuntimeError` once it has to
/// yield; `call_async` works for both async and plain (non-async) functions.
#[tokio::test]
async fn test_async_call() -> Result<()> {
    let lua = Lua::new();

    let hello = lua.create_async_function(|_lua, name: String| async move {
        Delay::new(Duration::from_millis(10)).await;
        Ok(format!("hello, {}!", name))
    })?;

    assert!(
        matches!(hello.call::<_, ()>("alex"), Err(Error::RuntimeError(_))),
        "non-async executing async function must fail on the yield stage with RuntimeError"
    );

    assert_eq!(hello.call_async::<_, String>("alex").await?, "hello, alex!");

    // Executing non-async functions using async call is allowed
    let sum = lua.create_function(|_lua, (a, b): (i64, i64)| Ok(a + b))?;
    assert_eq!(sum.call_async::<_, i64>((5, 1)).await?, 6);

    Ok(())
}

/// `Function::bind` works on async functions: the bound value (10) is passed as the
/// first argument ahead of the caller's arguments.
#[tokio::test]
async fn test_async_bind_call() -> Result<()> {
    let lua = Lua::new();

    let sum = lua.create_async_function(|_lua, (a, b): (i64, i64)| async move {
        tokio::task::yield_now().await;
        Ok(a + b)
    })?;

    let plus_10 = sum.bind(10)?;
    lua.globals().set("plus_10", plus_10)?;

    assert_eq!(lua.load("plus_10(-1)").eval_async::<i64>().await?, 9);
    assert_eq!(lua.load("plus_10(1)").eval_async::<i64>().await?, 11);

    Ok(())
}

/// A plain Lua `coroutine.yield` inside code driven by `call_async` does not end the
/// call: execution continues until the chunk/function actually returns.
#[tokio::test]
async fn test_async_handle_yield() -> Result<()> {
    let lua = Lua::new();

    let sum = lua.create_async_function(|_lua, (a, b): (i64, i64)| async move {
        Delay::new(Duration::from_millis(10)).await;
        Ok(a + b)
    })?;

    lua.globals().set("sleep_sum", sum)?;

    let res: String = lua
        .load(
            r#"
            sum = sleep_sum(6, 7)
            assert(sum == 13)
            coroutine.yield("in progress")
            return "done"
        "#,
        )
        .call_async(())
        .await?;

    assert_eq!(res, "done");

    let min = lua
        .load(
            r#"
            function (a, b)
                coroutine.yield("ignore me")
                if a < b then return a else return b end
            end
        "#,
        )
        .eval::<Function>()?;
    assert_eq!(min.call_async::<_, i64>((-1, 1)).await?, -1);

    Ok(())
}

/// A leading `nil` in an async multi-value return must not drop the values after it.
#[tokio::test]
async fn test_async_multi_return_nil() -> Result<()> {
    let lua = Lua::new();
    lua.globals().set(
        "func",
        lua.create_async_function(|_, _: ()| async { Ok((Option::<String>::None, "error")) })?,
    )?;

    lua.load(
        r#"
        local ok, err = func()
        assert(err == "error")
    "#,
    )
    .exec_async()
    .await
}

/// An async function can create and return another async function that captures
/// the outer function's argument.
#[tokio::test]
async fn test_async_return_async_closure() -> Result<()> {
    let lua = Lua::new();

    let f = lua.create_async_function(|lua, a: i64| async move {
        Delay::new(Duration::from_millis(10)).await;

        let g = lua.create_async_function(move |_, b: i64| async move {
            Delay::new(Duration::from_millis(10)).await;
            Ok(a + b)
        })?;

        Ok(g)
    })?;

    lua.globals().set("f", f)?;

    let res: i64 = lua
        .load("local g = f(1); return g(2) + g(3)")
        .call_async(())
        .await?;

    assert_eq!(res, 7);

    Ok(())
}

/// (Lua 5.4) To-be-closed variables: a failing `call_async` runs the `__close`
/// handler, while a thread awaited through `into_async` runs it only on `reset`.
#[cfg(feature = "lua54")]
#[tokio::test]
async fn test_async_lua54_to_be_closed() -> Result<()> {
    let lua = Lua::new();

    let globals = lua.globals();
    globals.set("close_count", 0)?;

    // `<close>` marks `t` as a to-be-closed variable; without it `__close` never runs.
    let code = r#"
        local t <close> = setmetatable({}, {
            __close = function()
                close_count = close_count + 1
            end
        })
        error "test"
    "#;
    let f = lua.load(code).into_function()?;

    // Test close using call_async
    let _ = f.call_async::<_, ()>(()).await;
    assert_eq!(globals.get::<_, usize>("close_count")?, 1);

    // Don't close by default when awaiting async threads
    let co = lua.create_thread(f.clone())?;
    let _ = co.clone().into_async::<_, ()>(()).await;
    assert_eq!(globals.get::<_, usize>("close_count")?, 1);
    // The reset result is ignored (it may report the thread's original error);
    // only its side effect of running the pending `__close` handler is checked.
    let _ = co.reset(f);
    assert_eq!(globals.get::<_, usize>("close_count")?, 2);

    Ok(())
}

/// A thread converted with `into_async` is a stream that yields every
/// `coroutine.yield` value and then the final return value:
/// (2+4+7+11+16+22+29+37+46+56) + 56 = 286.
#[tokio::test]
async fn test_async_thread_stream() -> Result<()> {
    let lua = Lua::new();

    let thread = lua.create_thread(
        lua.load(
            r#"
            function (sum)
                for i = 1,10 do
                    sum = sum + i
                    coroutine.yield(sum)
                end
                return sum
            end
        "#,
        )
        .eval()?,
    )?;

    let mut stream = thread.into_async::<_, i64>(1);
    let mut sum = 0;
    while let Some(n) = stream.try_next().await? {
        sum += n;
    }

    assert_eq!(sum, 286);

    Ok(())
}

/// A thread running an async function can be awaited to completion; after a GC cycle
/// the finished thread (and the function's captured `Arc`) is released.
#[tokio::test]
async fn test_async_thread() -> Result<()> {
    let lua = Lua::new();

    let cnt = Arc::new(10); // sleep 10ms
    let cnt2 = cnt.clone();
    let f = lua.create_async_function(move |_lua, ()| {
        let cnt3 = cnt2.clone();
        async move {
            Delay::new(Duration::from_millis(*cnt3.as_ref())).await;
            Ok("done")
        }
    })?;

    let res: String = lua.create_thread(f)?.into_async(()).await?;

    assert_eq!(res, "done");

    assert_eq!(Arc::strong_count(&cnt), 2);
    // The finished thread is non-resumable and subject to garbage collection.
    lua.gc_collect()?;
    assert_eq!(Arc::strong_count(&cnt), 1);

    Ok(())
}

/// Async functions stored in a table are callable via `call_async_method` (which
/// passes the table as the first argument) and `call_async_function`.
#[tokio::test]
async fn test_async_table() -> Result<()> {
    let options = LuaOptions::new().thread_cache_size(4);
    let lua = Lua::new_with(StdLib::ALL_SAFE, options)?;

    let table = lua.create_table()?;
    table.set("val", 10)?;

    let get_value = lua.create_async_function(|_, table: Table| async move {
        Delay::new(Duration::from_millis(10)).await;
        table.get::<_, i64>("val")
    })?;
    table.set("get_value", get_value)?;

    let set_value = lua.create_async_function(|_, (table, n): (Table, i64)| async move {
        Delay::new(Duration::from_millis(10)).await;
        table.set("val", n)
    })?;
    table.set("set_value", set_value)?;

    let sleep = lua.create_async_function(|_, n| async move {
        Delay::new(Duration::from_millis(n)).await;
        Ok(format!("elapsed:{}ms", n))
    })?;
    table.set("sleep", sleep)?;

    assert_eq!(
        table
            .call_async_method::<_, _, i64>("get_value", ())
            .await?,
        10
    );
    // Spell out the `()` result type instead of relying on never-type fallback.
    table
        .call_async_method::<_, _, ()>("set_value", 15)
        .await?;
    assert_eq!(
        table
            .call_async_method::<_, _, i64>("get_value", ())
            .await?,
        15
    );
    assert_eq!(
        table
            .call_async_function::<_, _, String>("sleep", 7)
            .await?,
        "elapsed:7ms"
    );

    Ok(())
}

/// With thread caching enabled, a failed async call must not break the next
/// async call (which is expected to reuse a cached thread).
#[tokio::test]
async fn test_async_thread_cache() -> Result<()> {
    let options = LuaOptions::new().thread_cache_size(4);
    let lua = Lua::new_with(StdLib::ALL_SAFE, options)?;

    let error_f = lua.create_async_function(|_, ()| async move {
        Delay::new(Duration::from_millis(10)).await;
        Err::<(), _>(Error::RuntimeError("test".to_string()))
    })?;

    let sleep = lua.create_async_function(|_, n| async move {
        Delay::new(Duration::from_millis(n)).await;
        Ok(format!("elapsed:{}ms", n))
    })?;

    assert!(error_f.call_async::<_, ()>(()).await.is_err());
    // Next call should use cached thread
    assert_eq!(sleep.call_async::<_, String>(3).await?, "elapsed:3ms");

    Ok(())
}

/// Userdata with async methods and an async function, plus (outside Lua 5.1 and
/// Luau) async `__call`, `__index` and `__newindex` metamethods.
#[tokio::test]
async fn test_async_userdata() -> Result<()> {
    #[derive(Clone)]
    struct MyUserData(Arc<AtomicU64>);

    impl UserData for MyUserData {
        fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
            methods.add_async_method("get_value", |_, data, ()| async move {
                Delay::new(Duration::from_millis(10)).await;
                Ok(data.0.load(Ordering::Relaxed))
            });

            methods.add_async_method("set_value", |_, data, n| async move {
                Delay::new(Duration::from_millis(10)).await;
                data.0.store(n, Ordering::Relaxed);
                Ok(())
            });

            methods.add_async_function("sleep", |_, n| async move {
                Delay::new(Duration::from_millis(n)).await;
                Ok(format!("elapsed:{}ms", n))
            });

            // `__call` sleeps for the stored number of milliseconds.
            #[cfg(not(any(feature = "lua51", feature = "luau")))]
            methods.add_async_meta_method(mlua::MetaMethod::Call, |_, data, ()| async move {
                let n = data.0.load(Ordering::Relaxed);
                Delay::new(Duration::from_millis(n)).await;
                Ok(format!("elapsed:{}ms", n))
            });

            // `__index` exposes the stored value as `ms` and as `s` (seconds).
            #[cfg(not(any(feature = "lua51", feature = "luau")))]
            methods.add_async_meta_method(
                mlua::MetaMethod::Index,
                |_, data, key: String| async move {
                    Delay::new(Duration::from_millis(10)).await;
                    match key.as_str() {
                        "ms" => Ok(Some(data.0.load(Ordering::Relaxed) as f64)),
                        "s" => Ok(Some((data.0.load(Ordering::Relaxed) as f64) / 1000.0)),
                        _ => Ok(None),
                    }
                },
            );

            // `__newindex` accepts `ms` or `s` and stores milliseconds; other keys error.
            #[cfg(not(any(feature = "lua51", feature = "luau")))]
            methods.add_async_meta_method(
                mlua::MetaMethod::NewIndex,
                |_, data, (key, value): (String, f64)| async move {
                    Delay::new(Duration::from_millis(10)).await;
                    match key.as_str() {
                        "ms" => data.0.store(value as u64, Ordering::Relaxed),
                        "s" => data.0.store((value * 1000.0) as u64, Ordering::Relaxed),
                        _ => return Err(Error::external(format!("key '{}' not found", key))),
                    }
                    Ok(())
                },
            );
        }
    }

    let lua = Lua::new();
    let globals = lua.globals();

    let userdata = lua.create_userdata(MyUserData(Arc::new(AtomicU64::new(11))))?;
    globals.set("userdata", userdata.clone())?;

    lua.load(
        r#"
        assert(userdata:get_value() == 11)
        userdata:set_value(12)
        assert(userdata.sleep(5) == "elapsed:5ms")
        assert(userdata:get_value() == 12)
    "#,
    )
    .exec_async()
    .await?;

    #[cfg(not(any(feature = "lua51", feature = "luau")))]
    lua.load(
        r#"
        userdata:set_value(15)
        assert(userdata() == "elapsed:15ms")

        userdata.ms = 2000
        assert(userdata.s == 2)

        userdata.s = 15
        assert(userdata.ms == 15000)
    "#,
    )
    .exec_async()
    .await?;

    Ok(())
}

/// An error raised with a userdata value in an async call yields a `RuntimeError`
/// whose message includes the userdata's `__tostring` output.
#[tokio::test]
async fn test_async_thread_error() -> Result<()> {
    struct MyUserData;

    impl UserData for MyUserData {
        fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
            methods.add_meta_method("__tostring", |_, _this, ()| Ok("myuserdata error"))
        }
    }

    let lua = Lua::new();
    let result = lua
        .load("function x(...) error(...) end x(...)")
        .set_name("chunk")?
        .call_async::<_, ()>(MyUserData)
        .await;
    assert!(
        matches!(result, Err(Error::RuntimeError(cause)) if cause.contains("myuserdata error")),
        "improper error traceback from dead thread"
    );

    Ok(())
}

/// Async functions created in `async_scope` may borrow non-`'static` state. Once the
/// scope ends, the borrowed `Rc` is released and calling the function, or resuming
/// a thread left suspended inside it, fails with `CallbackDestructed`.
#[tokio::test]
async fn test_async_scope() -> Result<()> {
    let lua = &Lua::new();
    let rc = &Rc::new(Cell::new(0));

    let fut = lua.async_scope(|scope| async move {
        let f = scope.create_async_function(move |_, n: u64| {
            let rc2 = rc.clone();
            async move {
                rc2.set(42);
                Delay::new(Duration::from_millis(n)).await;
                assert_eq!(Rc::strong_count(&rc2), 2);
                Ok(())
            }
        })?;

        lua.globals().set("f", f.clone())?;

        assert_eq!(Rc::strong_count(rc), 1);
        f.call_async::<u64, ()>(10).await?;
        assert_eq!(Rc::strong_count(rc), 1);

        // Create future in partially polled state (Poll::Pending)
        let g = lua.create_thread(f)?;
        g.resume::<u64, ()>(10)?;
        lua.globals().set("g", g)?;
        assert_eq!(Rc::strong_count(rc), 2);

        Ok(())
    });

    assert_eq!(Rc::strong_count(rc), 1);
    fut.await?;
    assert_eq!(Rc::strong_count(rc), 1);

    match lua
        .globals()
        .get::<_, Function>("f")?
        .call_async::<_, ()>(10)
        .await
    {
        Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
            Error::CallbackDestructed => {}
            e => panic!("expected `CallbackDestructed` error cause, got {:?}", e),
        },
        r => panic!("improper return for destructed function: {:?}", r),
    };

    match lua.globals().get::<_, Thread>("g")?.resume::<_, Value>(()) {
        Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
            Error::CallbackDestructed => {}
            e => panic!("expected `CallbackDestructed` error cause, got {:?}", e),
        },
        r => panic!("improper return for destructed function: {:?}", r),
    };

    Ok(())
}

/// Userdata created in `async_scope` supports async methods inside the scope. It
/// releases its `Arc` when the scope ends, and later method calls fail with
/// `CallbackDestructed`.
#[tokio::test]
async fn test_async_scope_userdata() -> Result<()> {
    #[derive(Clone)]
    struct MyUserData(Arc<AtomicI64>);

    impl UserData for MyUserData {
        fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
            methods.add_async_method("get_value", |_, data, ()| async move {
                Delay::new(Duration::from_millis(10)).await;
                Ok(data.0.load(Ordering::Relaxed))
            });

            methods.add_async_method("set_value", |_, data, n| async move {
                Delay::new(Duration::from_millis(10)).await;
                data.0.store(n, Ordering::Relaxed);
                Ok(())
            });

            methods.add_async_function("sleep", |_, n| async move {
                Delay::new(Duration::from_millis(n)).await;
                Ok(format!("elapsed:{}ms", n))
            });
        }
    }

    let lua = &Lua::new();
    let arc = &Arc::new(AtomicI64::new(11));

    lua.async_scope(|scope| async move {
        let ud = scope.create_userdata(MyUserData(arc.clone()))?;
        lua.globals().set("userdata", ud)?;
        lua.load(
            r#"
            assert(userdata:get_value() == 11)
            userdata:set_value(12)
            assert(userdata.sleep(5) == "elapsed:5ms")
            assert(userdata:get_value() == 12)
        "#,
        )
        .exec_async()
        .await
    })
    .await?;

    assert_eq!(Arc::strong_count(arc), 1);

    match lua.load("userdata:get_value()").exec_async().await {
        Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
            Error::CallbackDestructed => {}
            e => panic!("expected `CallbackDestructed` error cause, got {:?}", e),
        },
        r => panic!("improper return for destructed userdata: {:?}", r),
    };

    Ok(())
}