Fix scoped async destruction of partially polled futures

This commit is contained in:
Alex Orlenko 2021-02-21 23:34:35 +00:00
parent 6a77b5f003
commit 2aed548747
3 changed files with 43 additions and 9 deletions

View file

@ -1787,9 +1787,11 @@ impl Lua {
})?,
)?;
// We set `poll` variable in the env table to be able to destroy upvalues
self.load(
r#"
local poll = get_poll(...)
poll = get_poll(...)
local poll = poll
while true do
ready, res = poll()
if ready then

View file

@ -23,9 +23,8 @@ use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti, Value};
#[cfg(feature = "async")]
use {
crate::types::AsyncCallback,
futures_core::future::Future,
futures_core::future::{Future, LocalBoxFuture},
futures_util::future::{self, TryFutureExt},
std::os::raw::c_char,
};
/// Constructed by the [`Lua::scope`] method, allows temporarily creating Lua userdata and
@ -420,12 +419,11 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
#[cfg(any(feature = "lua51", feature = "luajit"))]
ffi::lua_getfenv(state, -1);
// Then, get the get_poll() closure using the corresponding key
let key = "get_poll";
ffi::lua_pushlstring(state, key.as_ptr() as *const c_char, key.len());
// Second, get the `get_poll()` closure using the corresponding key
ffi::lua_pushstring(state, cstr!("get_poll"));
ffi::lua_rawget(state, -2);
// Finally, destroy all upvalues
// Destroy all upvalues
ffi::lua_getupvalue(state, -1, 1);
let ud1 = take_userdata::<AsyncCallback>(state);
ffi::lua_pushnil(state);
@ -437,8 +435,25 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
ffi::lua_setupvalue(state, -2, 2);
ffi::lua_pop(state, 1);
let mut data: Vec<Box<dyn Any>> = vec![Box::new(ud1), Box::new(ud2)];
vec![Box::new(ud1), Box::new(ud2)]
// Finally, get polled future and destroy it
ffi::lua_pushstring(state, cstr!("poll"));
if ffi::lua_rawget(state, -2) == ffi::LUA_TFUNCTION {
ffi::lua_getupvalue(state, -1, 1);
let ud3 = take_userdata::<LocalBoxFuture<Result<MultiValue>>>(state);
ffi::lua_pushnil(state);
ffi::lua_setupvalue(state, -2, 1);
data.push(Box::new(ud3));
ffi::lua_getupvalue(state, -1, 2);
let ud4 = take_userdata::<Lua>(state);
ffi::lua_pushnil(state);
ffi::lua_setupvalue(state, -2, 2);
data.push(Box::new(ud4));
}
data
}));
Ok(f)

View file

@ -22,7 +22,9 @@ use std::time::Duration;
use futures_timer::Delay;
use futures_util::stream::TryStreamExt;
use mlua::{Error, Function, Lua, Result, Table, TableExt, UserData, UserDataMethods};
use mlua::{
Error, Function, Lua, Result, Table, TableExt, Thread, UserData, UserDataMethods, Value,
};
#[tokio::test]
async fn test_async_function() -> Result<()> {
@ -332,11 +334,18 @@ async fn test_async_scope() -> Result<()> {
let _ = f.call_async::<u64, ()>(10).await?;
assert_eq!(Rc::strong_count(rc), 1);
// Create future in partialy polled state (Poll::Pending)
let g = lua.create_thread(f)?;
g.resume::<u64, ()>(10)?;
lua.globals().set("g", g)?;
assert_eq!(Rc::strong_count(rc), 2);
Ok(())
});
assert_eq!(Rc::strong_count(rc), 1);
let _ = fut.await?;
assert_eq!(Rc::strong_count(rc), 1);
match lua
.globals()
@ -351,6 +360,14 @@ async fn test_async_scope() -> Result<()> {
r => panic!("improper return for destructed function: {:?}", r),
};
match lua.globals().get::<_, Thread>("g")?.resume::<_, Value>(()) {
Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
Error::CallbackDestructed => {}
e => panic!("expected `CallbackDestructed` error cause, got {:?}", e),
},
r => panic!("improper return for destructed function: {:?}", r),
};
Ok(())
}