Fix __index and __newindex wrappers for Luau

This commit is contained in:
Alex Orlenko 2022-03-21 01:06:31 +00:00
parent eed7b1f3af
commit fa99f62a99
No known key found for this signature in database
GPG key ID: 4C150C250863B96D
3 changed files with 122 additions and 131 deletions

View file

@ -209,33 +209,24 @@ fn create_userdata(c: &mut Criterion) {
}); });
} }
fn userdata_index(c: &mut Criterion) { fn call_userdata_index(c: &mut Criterion) {
struct UserData(i64); struct UserData(i64);
impl LuaUserData for UserData { impl LuaUserData for UserData {
fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_meta_method(mlua::MetaMethod::Index, move |_, _, index: String| { methods.add_meta_method(LuaMetaMethod::Index, move |_, _, index: String| Ok(index));
Ok(index)
});
} }
} }
let lua = Lua::new(); let lua = Lua::new();
lua.globals().set("userdata", UserData(10)).unwrap(); lua.globals().set("userdata", UserData(10)).unwrap();
c.bench_function("index [table userdata] 10", |b| { c.bench_function("call [userdata index] 10", |b| {
b.iter_batched_ref( b.iter_batched_ref(
|| { || {
collect_gc_twice(&lua); collect_gc_twice(&lua);
lua.load( lua.load("function() for i = 1,10 do local v = userdata.test end end")
r#" .eval::<LuaFunction>()
function() .unwrap()
for i = 1,10 do
local v = userdata.test
end
end"#,
)
.eval::<LuaFunction>()
.unwrap()
}, },
|function| { |function| {
function.call::<_, ()>(()).unwrap(); function.call::<_, ()>(()).unwrap();
@ -319,7 +310,7 @@ criterion_group! {
call_concat_callback, call_concat_callback,
create_registry_values, create_registry_values,
create_userdata, create_userdata,
userdata_index, call_userdata_index,
call_userdata_method, call_userdata_method,
call_async_userdata_method, call_async_userdata_method,
} }

View file

@ -1,4 +1,5 @@
use std::any::{Any, TypeId}; use std::any::{Any, TypeId};
use std::ffi::CStr;
use std::fmt::Write; use std::fmt::Write;
use std::os::raw::{c_char, c_int, c_void}; use std::os::raw::{c_char, c_int, c_void};
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
@ -9,7 +10,7 @@ use once_cell::sync::Lazy;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::ffi::{self, lua_error}; use crate::ffi;
static METATABLE_CACHE: Lazy<FxHashMap<TypeId, u8>> = Lazy::new(|| { static METATABLE_CACHE: Lazy<FxHashMap<TypeId, u8>> = Lazy::new(|| {
let mut map = FxHashMap::with_capacity_and_hasher(32, Default::default()); let mut map = FxHashMap::with_capacity_and_hasher(32, Default::default());
@ -370,110 +371,117 @@ pub unsafe fn get_gc_userdata<T: Any>(state: *mut ffi::lua_State, index: c_int)
ud ud
} }
unsafe extern "C" fn isfunction_impl(state: *mut ffi::lua_State) -> c_int { unsafe extern "C" fn lua_error_impl(state: *mut ffi::lua_State) -> c_int {
// stack: var ffi::lua_error(state);
ffi::luaL_checkstack(state, 1, ptr::null()); }
unsafe extern "C" fn lua_isfunction_impl(state: *mut ffi::lua_State) -> c_int {
let t = ffi::lua_type(state, -1); let t = ffi::lua_type(state, -1);
ffi::lua_pop(state, 1); ffi::lua_pop(state, 1);
ffi::lua_pushboolean(state, if t == ffi::LUA_TFUNCTION { 1 } else { 0 }); ffi::lua_pushboolean(state, (t == ffi::LUA_TFUNCTION) as c_int);
1 1
} }
unsafe extern "C" fn error_impl(state: *mut ffi::lua_State) -> c_int { unsafe fn init_userdata_metatable_index(state: *mut ffi::lua_State) -> Result<()> {
// stack: message let index_key = &USERDATA_METATABLE_INDEX as *const u8 as *const _;
ffi::luaL_checkstack(state, 1, ptr::null()); if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, index_key) == ffi::LUA_TFUNCTION {
return Ok(());
}
ffi::lua_pop(state, 1);
lua_error(state); // Create and cache `__index` helper
} let code = cstr!(
r#"
pub unsafe fn init_userdata_metatable_index(state: *mut ffi::lua_State) -> Result<()> { local error, isfunction = ...
protect_lua!(state, 0, 1, |state| { return function (__index, field_getters, methods)
let ret = ffi::luaL_dostring( return function (self, key)
state, if field_getters ~= nil then
cstr!( local field_getter = field_getters[key]
r#" if field_getter ~= nil then
return function (isfunction, error) return field_getter(self)
return function (__index, field_getters, methods)
return function (self, key)
if field_getters ~= nil then
local field_getter = field_getters[key]
if field_getter ~= nil then
return field_getter(self)
end
end
if methods ~= nil then
local method = methods[key]
if method ~= nil then
return method
end
end
if isfunction(__index) then
return __index(self, key)
elseif __index == nil then
error('attempt to get an unknown field \'' .. key .. '\'')
else
return __index[key]
end
end end
end end
if methods ~= nil then
local method = methods[key]
if method ~= nil then
return method
end
end
if isfunction(__index) then
return __index(self, key)
elseif __index == nil then
error("attempt to get an unknown field '"..key.."'")
else
return __index[key]
end
end end
"# end
), "#
); );
let code_len = CStr::from_ptr(code).to_bytes().len();
protect_lua!(state, 0, 1, |state| {
let ret = ffi::luaL_loadbuffer(state, code, code_len, cstr!("__mlua_index"));
if ret != ffi::LUA_OK { if ret != ffi::LUA_OK {
ffi::lua_error(state); ffi::lua_error(state);
} }
ffi::lua_pushcfunction(state, isfunction_impl); ffi::lua_pushcfunction(state, lua_error_impl);
ffi::lua_pushcfunction(state, error_impl); ffi::lua_pushcfunction(state, lua_isfunction_impl);
ffi::lua_call(state, 2, 1); ffi::lua_call(state, 2, 1);
})?;
Ok(()) // Store in the registry
ffi::lua_pushvalue(state, -1);
ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, index_key);
})
} }
pub unsafe fn init_userdata_metatable_newindex(state: *mut ffi::lua_State) -> Result<()> { pub unsafe fn init_userdata_metatable_newindex(state: *mut ffi::lua_State) -> Result<()> {
protect_lua!(state, 0, 1, |state| { let newindex_key = &USERDATA_METATABLE_NEWINDEX as *const u8 as *const _;
let ret = ffi::luaL_dostring( if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, newindex_key) == ffi::LUA_TFUNCTION {
state, return Ok(());
cstr!( }
r#" ffi::lua_pop(state, 1);
return function (isfunction, error)
return function (__newindex, field_setters)
return function (self, key, value)
if field_setters ~= nil then
local field_setter = field_setters[key]
if field_setter ~= nil then
field_setter(self, value)
return
end
end
if isfunction(__newindex) then // Create and cache `__newindex` helper
__newindex(self, key, value) let code = cstr!(
elseif __newindex == nil then r#"
error('attempt to set an unknown field \'' .. key .. '\'') local error, isfunction = ...
else return function (__newindex, field_setters)
__newindex[key] = value return function (self, key, value)
end if field_setters ~= nil then
local field_setter = field_setters[key]
if field_setter ~= nil then
field_setter(self, value)
return
end end
end end
if isfunction(__newindex) then
__newindex(self, key, value)
elseif __newindex == nil then
error("attempt to set an unknown field '"..key.."'")
else
__newindex[key] = value
end
end end
"# end
), "#
); );
let code_len = CStr::from_ptr(code).to_bytes().len();
protect_lua!(state, 0, 1, |state| {
let ret = ffi::luaL_loadbuffer(state, code, code_len, cstr!("__mlua_newindex"));
if ret != ffi::LUA_OK { if ret != ffi::LUA_OK {
ffi::lua_error(state); ffi::lua_error(state);
} }
ffi::lua_pushcfunction(state, isfunction_impl); ffi::lua_pushcfunction(state, lua_error_impl);
ffi::lua_pushcfunction(state, error_impl); ffi::lua_pushcfunction(state, lua_isfunction_impl);
ffi::lua_call(state, 2, 1); ffi::lua_call(state, 2, 1);
})?;
Ok(()) // Store in the registry
ffi::lua_pushvalue(state, -1);
ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, newindex_key);
})
} }
// Populates the given table with the appropriate members to be a userdata metatable for the given type. // Populates the given table with the appropriate members to be a userdata metatable for the given type.
@ -494,6 +502,7 @@ pub unsafe fn init_userdata_metatable<T>(
ffi::lua_pushvalue(state, metatable); ffi::lua_pushvalue(state, metatable);
if field_getters.is_some() || methods.is_some() { if field_getters.is_some() || methods.is_some() {
// Push `__index` generator function
init_userdata_metatable_index(state)?; init_userdata_metatable_index(state)?;
push_string(state, "__index")?; push_string(state, "__index")?;
@ -508,7 +517,8 @@ pub unsafe fn init_userdata_metatable<T>(
} }
} }
protect_lua!(state, 4, 1, |state| ffi::lua_call(state, 3, 1))?; // Generate `__index`
protect_lua!(state, 4, 1, fn(state) ffi::lua_call(state, 3, 1))?;
} }
_ => mlua_panic!("improper __index type {}", index_type), _ => mlua_panic!("improper __index type {}", index_type),
} }
@ -517,6 +527,7 @@ pub unsafe fn init_userdata_metatable<T>(
} }
if let Some(field_setters) = field_setters { if let Some(field_setters) = field_setters {
// Push `__newindex` generator function
init_userdata_metatable_newindex(state)?; init_userdata_metatable_newindex(state)?;
push_string(state, "__newindex")?; push_string(state, "__newindex")?;
@ -524,8 +535,8 @@ pub unsafe fn init_userdata_metatable<T>(
match newindex_type { match newindex_type {
ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => { ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => {
ffi::lua_pushvalue(state, field_setters); ffi::lua_pushvalue(state, field_setters);
// Generate `__newindex`
protect_lua!(state, 3, 1, |state| ffi::lua_call(state, 2, 1))?; protect_lua!(state, 3, 1, fn(state) ffi::lua_call(state, 2, 1))?;
} }
_ => mlua_panic!("improper __newindex type {}", newindex_type), _ => mlua_panic!("improper __newindex type {}", newindex_type),
} }
@ -976,3 +987,5 @@ pub(crate) unsafe fn get_destructed_userdata_metatable(state: *mut ffi::lua_Stat
static DESTRUCTED_USERDATA_METATABLE: u8 = 0; static DESTRUCTED_USERDATA_METATABLE: u8 = 0;
static ERROR_PRINT_BUFFER_KEY: u8 = 0; static ERROR_PRINT_BUFFER_KEY: u8 = 0;
static USERDATA_METATABLE_INDEX: u8 = 0;
static USERDATA_METATABLE_NEWINDEX: u8 = 0;

View file

@ -7,14 +7,13 @@ use std::sync::{
Arc, Arc,
}; };
use std::time::Duration; use std::time::Duration;
use std::unreachable;
use futures_timer::Delay; use futures_timer::Delay;
use futures_util::stream::TryStreamExt; use futures_util::stream::TryStreamExt;
use mlua::{ use mlua::{
Error, ExternalError, Function, Lua, LuaOptions, Result, StdLib, Table, TableExt, Thread, Error, Function, Lua, LuaOptions, Result, StdLib, Table, TableExt, Thread, UserData,
ToLua, UserData, UserDataMethods, Value, UserDataMethods, Value,
}; };
#[tokio::test] #[tokio::test]
@ -306,40 +305,28 @@ async fn test_async_userdata() -> Result<()> {
Ok(format!("elapsed:{}ms", n)) Ok(format!("elapsed:{}ms", n))
}); });
#[cfg(not(feature = "lua51"))] #[cfg(not(any(feature = "lua51", feature = "luau")))]
methods.add_async_meta_method(MetaMethod::Index, |lua, data, key: String| async move {
Delay::new(Duration::from_millis(10)).await;
match key.as_str() {
"ms" => Ok(data.0.load(Ordering::Relaxed).to_lua(lua)?),
"s" => Ok(((data.0.load(Ordering::Relaxed) as f64) / 1000.0).to_lua(lua)?),
_ => Ok(Value::Nil),
}
});
#[cfg(not(feature = "lua51"))]
methods.add_async_meta_method( methods.add_async_meta_method(
MetaMethod::NewIndex, mlua::MetaMethod::Index,
|_, data, (key, value): (String, Value)| async move { |_, data, key: String| async move {
Delay::new(Duration::from_millis(10)).await; Delay::new(Duration::from_millis(10)).await;
match key.as_str() { match key.as_str() {
"ms" | "s" => { "ms" => Ok(Some(data.0.load(Ordering::Relaxed) as f64)),
let value = match value { "s" => Ok(Some((data.0.load(Ordering::Relaxed) as f64) / 1000.0)),
Value::Integer(value) => value as f64, _ => Ok(None),
Value::Number(value) => value, }
_ => Err("wrong type for value".to_lua_err())?, },
}; );
let value = match key.as_str() {
"ms" => value,
"s" => value * 1000.0,
_ => unreachable!(),
};
data.0.store(value as u64, Ordering::Relaxed);
Ok(()) #[cfg(not(any(feature = "lua51", feature = "luau")))]
} methods.add_async_meta_method(
_ => Err(format!("key '{}' not found", key).to_lua_err()), mlua::MetaMethod::NewIndex,
|_, data, (key, value): (String, f64)| async move {
Delay::new(Duration::from_millis(10)).await;
match key.as_str() {
"ms" => Ok(data.0.store(value as u64, Ordering::Relaxed)),
"s" => Ok(data.0.store((value * 1000.0) as u64, Ordering::Relaxed)),
_ => Err(Error::external(format!("key '{}' not found", key))),
} }
}, },
); );