Add hooks support (based on rlua v0.17 implementation)

This feature works on lua54, lua53, lua52 and lua51 only.
LuaJIT is unstable.
This commit is contained in:
Alex Orlenko 2020-05-22 00:35:03 +01:00
parent f6da437d8b
commit c3822219e0
6 changed files with 598 additions and 8 deletions

View file

@ -739,7 +739,7 @@ pub const LUA_MASKLINE: c_int = 1 << (LUA_HOOKLINE as usize);
pub const LUA_MASKCOUNT: c_int = 1 << (LUA_HOOKCOUNT as usize);
/// Type for functions to be called on debug events.
pub type lua_Hook = extern "C" fn(L: *mut lua_State, ar: *mut lua_Debug);
pub type lua_Hook = unsafe extern "C" fn(L: *mut lua_State, ar: *mut lua_Debug);
extern "C" {
pub fn lua_getstack(L: *mut lua_State, level: c_int, ar: *mut lua_Debug) -> c_int;
@ -754,7 +754,7 @@ extern "C" {
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
pub fn lua_upvaluejoin(L: *mut lua_State, fidx1: c_int, n1: c_int, fidx2: c_int, n2: c_int);
pub fn lua_sethook(L: *mut lua_State, func: lua_Hook, mask: c_int, count: c_int);
pub fn lua_sethook(L: *mut lua_State, func: Option<lua_Hook>, mask: c_int, count: c_int);
pub fn lua_gethook(L: *mut lua_State) -> Option<lua_Hook>;
pub fn lua_gethookmask(L: *mut lua_State) -> c_int;
pub fn lua_gethookcount(L: *mut lua_State) -> c_int;

204
src/hook.rs Normal file
View file

@ -0,0 +1,204 @@
#![cfg_attr(
not(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "lua51")),
allow(dead_code)
)]
use std::ffi::CStr;
use std::marker::PhantomData;
use std::os::raw::{c_char, c_int};
use crate::ffi::{self, lua_Debug, lua_State};
use crate::lua::Lua;
use crate::util::callback_error;
/// Contains information about currently executing Lua code.
///
/// The `Debug` structure is provided as a parameter to the hook function set with
/// [`Lua::set_hook`]. You may call the methods on this structure to retrieve information about the
/// Lua code executing at the time that the hook function was called. Further information can be
/// found in the [Lua 5.3 documentaton][lua_doc].
///
/// Requires `feature = "lua54/lua53/lua52/lua51"`
///
/// [lua_doc]: https://www.lua.org/manual/5.3/manual.html#lua_Debug
/// [`Lua::set_hook`]: struct.Lua.html#method.set_hook
#[derive(Clone)]
pub struct Debug<'a> {
ar: *mut lua_Debug,
state: *mut lua_State,
_phantom: PhantomData<&'a ()>,
}
impl<'a> Debug<'a> {
/// Corresponds to the `n` what mask.
pub fn names(&self) -> DebugNames<'a> {
unsafe {
mlua_assert!(
ffi::lua_getinfo(self.state, cstr!("n"), self.ar) != 0,
"lua_getinfo failed with `n`"
);
DebugNames {
name: ptr_to_str((*self.ar).name),
name_what: ptr_to_str((*self.ar).namewhat),
}
}
}
/// Corresponds to the `n` what mask.
pub fn source(&self) -> DebugSource<'a> {
unsafe {
mlua_assert!(
ffi::lua_getinfo(self.state, cstr!("S"), self.ar) != 0,
"lua_getinfo failed with `S`"
);
DebugSource {
source: ptr_to_str((*self.ar).source),
short_src: ptr_to_str((*self.ar).short_src.as_ptr()),
line_defined: (*self.ar).linedefined as i32,
last_line_defined: (*self.ar).lastlinedefined as i32,
what: ptr_to_str((*self.ar).what),
}
}
}
/// Corresponds to the `l` what mask. Returns the current line.
pub fn curr_line(&self) -> i32 {
unsafe {
mlua_assert!(
ffi::lua_getinfo(self.state, cstr!("l"), self.ar) != 0,
"lua_getinfo failed with `l`"
);
(*self.ar).currentline as i32
}
}
/// Corresponds to the `t` what mask. Returns true if the hook is in a function tail call, false
/// otherwise.
pub fn is_tail_call(&self) -> bool {
unsafe {
mlua_assert!(
ffi::lua_getinfo(self.state, cstr!("t"), self.ar) != 0,
"lua_getinfo failed with `t`"
);
(*self.ar).currentline != 0
}
}
/// Corresponds to the `u` what mask.
pub fn stack(&self) -> DebugStack {
unsafe {
mlua_assert!(
ffi::lua_getinfo(self.state, cstr!("u"), self.ar) != 0,
"lua_getinfo failed with `u`"
);
DebugStack {
num_ups: (*self.ar).nups as i32,
#[cfg(any(feature = "lua52", feature = "lua53", feature = "lua54"))]
num_params: (*self.ar).nparams as i32,
#[cfg(any(feature = "lua52", feature = "lua53", feature = "lua54"))]
is_vararg: (*self.ar).isvararg != 0,
}
}
}
}
#[derive(Clone, Debug)]
pub struct DebugNames<'a> {
pub name: Option<&'a [u8]>,
pub name_what: Option<&'a [u8]>,
}
#[derive(Clone, Debug)]
pub struct DebugSource<'a> {
pub source: Option<&'a [u8]>,
pub short_src: Option<&'a [u8]>,
pub line_defined: i32,
pub last_line_defined: i32,
pub what: Option<&'a [u8]>,
}
#[derive(Copy, Clone, Debug)]
pub struct DebugStack {
pub num_ups: i32,
/// Requires `feature = "lua54/lua53/lua52"`
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", doc))]
pub num_params: i32,
/// Requires `feature = "lua54/lua53/lua52"`
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", doc))]
pub is_vararg: bool,
}
/// Determines when a hook function will be called by Lua.
///
/// Requires `feature = "lua54/lua53/lua52/lua51"`
#[derive(Clone, Copy, Debug, Default)]
pub struct HookTriggers {
/// Before a function call.
pub on_calls: bool,
/// When Lua returns from a function.
pub on_returns: bool,
/// Before executing a new line, or returning from a function call.
pub every_line: bool,
/// After a certain number of VM instructions have been executed. When set to `Some(count)`,
/// `count` is the number of VM instructions to execute before calling the hook.
///
/// # Performance
///
/// Setting this option to a low value can incur a very high overhead.
pub every_nth_instruction: Option<u32>,
}
impl HookTriggers {
// Compute the mask to pass to `lua_sethook`.
pub(crate) fn mask(&self) -> c_int {
let mut mask: c_int = 0;
if self.on_calls {
mask |= ffi::LUA_MASKCALL
}
if self.on_returns {
mask |= ffi::LUA_MASKRET
}
if self.every_line {
mask |= ffi::LUA_MASKLINE
}
if self.every_nth_instruction.is_some() {
mask |= ffi::LUA_MASKCOUNT
}
mask
}
// Returns the `count` parameter to pass to `lua_sethook`, if applicable. Otherwise, zero is
// returned.
pub(crate) fn count(&self) -> c_int {
self.every_nth_instruction.unwrap_or(0) as c_int
}
}
pub(crate) unsafe extern "C" fn hook_proc(state: *mut lua_State, ar: *mut lua_Debug) {
callback_error(state, |_| {
let debug = Debug {
ar,
state,
_phantom: PhantomData,
};
let lua = Lua::make_from_ptr(state);
let hook_cb = mlua_expect!(lua.hook_callback(), "no hook callback set in hook_proc");
#[allow(clippy::match_wild_err_arm)]
match hook_cb.try_borrow_mut() {
Ok(mut b) => (&mut *b)(&lua, debug),
Err(_) => mlua_panic!("Lua should not allow hooks to be called within another hook"),
}?;
Ok(())
});
}
unsafe fn ptr_to_str<'a>(input: *const c_char) -> Option<&'a [u8]> {
if input.is_null() {
None
} else {
Some(CStr::from_ptr(input).to_bytes())
}
}

View file

@ -63,6 +63,7 @@ mod conversion;
mod error;
mod ffi;
mod function;
mod hook;
mod lua;
mod multi;
mod scope;
@ -90,6 +91,15 @@ pub use crate::types::{Integer, LightUserData, Number, RegistryKey};
pub use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods};
pub use crate::value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti, Value};
#[cfg(any(
feature = "lua54",
feature = "lua53",
feature = "lua52",
feature = "lua51",
doc
))]
pub use crate::hook::{Debug, DebugNames, DebugSource, DebugStack, HookTriggers};
#[cfg(feature = "async")]
pub use crate::thread::AsyncThread;

View file

@ -4,7 +4,7 @@ use std::collections::HashMap;
use std::ffi::CString;
use std::marker::PhantomData;
use std::os::raw::{c_char, c_int, c_void};
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, Weak};
use std::{mem, ptr, str};
use crate::error::{Error, Result};
@ -15,7 +15,9 @@ use crate::stdlib::StdLib;
use crate::string::String;
use crate::table::Table;
use crate::thread::Thread;
use crate::types::{Callback, Integer, LightUserData, LuaRef, MaybeSend, Number, RegistryKey};
use crate::types::{
Callback, HookCallback, Integer, LightUserData, LuaRef, MaybeSend, Number, RegistryKey,
};
use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods};
use crate::util::{
assert_stack, callback_error, check_stack, get_gc_userdata, get_main_state,
@ -25,6 +27,15 @@ use crate::util::{
};
use crate::value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti, Value};
#[cfg(any(
feature = "lua54",
feature = "lua53",
feature = "lua52",
feature = "lua51",
doc
))]
use crate::hook::{hook_proc, Debug, HookTriggers};
#[cfg(feature = "async")]
use {
crate::types::AsyncCallback,
@ -58,6 +69,8 @@ struct ExtraData {
ref_stack_size: c_int,
ref_stack_max: c_int,
ref_free: Vec<c_int>,
hook_callback: Option<HookCallback>,
}
#[cfg_attr(any(feature = "lua51", feature = "luajit"), allow(dead_code))]
@ -85,6 +98,7 @@ pub enum GCMode {
pub(crate) struct AsyncPollPending;
#[cfg(feature = "async")]
pub(crate) static WAKER_REGISTRY_KEY: u8 = 0;
pub(crate) static EXTRA_REGISTRY_KEY: u8 = 0;
/// Requires `feature = "send"`
#[cfg(feature = "send")]
@ -276,6 +290,7 @@ impl Lua {
init_gc_metatable_for::<Callback>(state, None);
init_gc_metatable_for::<Lua>(state, None);
init_gc_metatable_for::<Weak<Mutex<ExtraData>>>(state, None);
#[cfg(feature = "async")]
{
init_gc_metatable_for::<AsyncCallback>(state, None);
@ -305,8 +320,24 @@ impl Lua {
ref_stack_size: ffi::LUA_MINSTACK - 1,
ref_stack_max: 0,
ref_free: Vec::new(),
hook_callback: None,
}));
mlua_expect!(
push_gc_userdata(state, Arc::downgrade(&extra)),
"Error while storing extra data",
);
mlua_expect!(
protect_lua_closure(main_state, 1, 0, |state| {
ffi::lua_rawsetp(
state,
ffi::LUA_REGISTRYINDEX,
&EXTRA_REGISTRY_KEY as *const u8 as *mut c_void,
);
}),
"Error while storing extra data"
);
mlua_debug_assert!(
ffi::lua_gettop(main_state) == main_state_top,
"stack leak during creation"
@ -387,6 +418,90 @@ impl Lua {
unsafe { self.push_value(cb.call(())?).map(|_| 1) }
}
/// Sets a 'hook' function that will periodically be called as Lua code executes.
///
/// When exactly the hook function is called depends on the contents of the `triggers`
/// parameter, see [`HookTriggers`] for more details.
///
/// The provided hook function can error, and this error will be propagated through the Lua code
/// that was executing at the time the hook was triggered. This can be used to implement a
/// limited form of execution limits by setting [`HookTriggers.every_nth_instruction`] and
/// erroring once an instruction limit has been reached.
///
/// Requires `feature = "lua54/lua53/lua52/lua51"`
///
/// # Example
///
/// Shows each line number of code being executed by the Lua interpreter.
///
/// ```
/// # #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "lua51"))]
/// # use mlua::{Lua, HookTriggers, Result};
/// # #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "lua51"))]
/// # fn main() -> Result<()> {
/// let lua = Lua::new();
/// lua.set_hook(HookTriggers {
/// every_line: true, ..Default::default()
/// }, |_lua, debug| {
/// println!("line {}", debug.curr_line());
/// Ok(())
/// });
///
/// lua.load(r#"
/// local x = 2 + 3
/// local y = x * 63
/// local z = string.len(x..", "..y)
/// "#).exec()
/// # }
///
/// # #[cfg(not(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "lua51")))]
/// # fn main() {}
/// ```
///
/// [`HookTriggers`]: struct.HookTriggers.html
/// [`HookTriggers.every_nth_instruction`]: struct.HookTriggers.html#field.every_nth_instruction
#[cfg(any(
feature = "lua54",
feature = "lua53",
feature = "lua52",
feature = "lua51",
doc
))]
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
where
F: 'static + MaybeSend + FnMut(&Lua, Debug) -> Result<()>,
{
unsafe {
let mut extra = mlua_expect!(self.extra.lock(), "extra is poisoned");
extra.hook_callback = Some(Arc::new(RefCell::new(callback)));
ffi::lua_sethook(
self.main_state,
Some(hook_proc),
triggers.mask(),
triggers.count(),
);
}
}
/// Remove any hook previously set by `set_hook`. This function has no effect if a hook was not
/// previously set.
///
/// Requires `feature = "lua54/lua53/lua52/lua51"`
#[cfg(any(
feature = "lua54",
feature = "lua53",
feature = "lua52",
feature = "lua51",
doc
))]
pub fn remove_hook(&self) {
let mut extra = mlua_expect!(self.extra.lock(), "extra is poisoned");
unsafe {
extra.hook_callback = None;
ffi::lua_sethook(self.main_state, None, 0, 0);
}
}
/// Returns the amount of memory (in bytes) currently used inside this Lua state.
pub fn used_memory(&self) -> usize {
let extra = mlua_expect!(self.extra.lock(), "extra is poisoned");
@ -1544,10 +1659,7 @@ impl Lua {
let mut waker = noop_waker();
// Try to get an outer poll waker
ffi::lua_pushlightuserdata(
state,
&WAKER_REGISTRY_KEY as *const u8 as *mut c_void,
);
ffi::lua_pushlightuserdata(state, &WAKER_REGISTRY_KEY as *const u8 as *mut c_void);
ffi::lua_rawget(state, ffi::LUA_REGISTRYINDEX);
if let Some(w) = get_gc_userdata::<Waker>(state, -1).as_ref() {
waker = (*w).clone();
@ -1675,6 +1787,36 @@ impl Lua {
Ok(())
}
pub(crate) unsafe fn make_from_ptr(state: *mut ffi::lua_State) -> Self {
let _sg = StackGuard::new(state);
assert_stack(state, 3);
ffi::lua_rawgetp(
state,
ffi::LUA_REGISTRYINDEX,
&EXTRA_REGISTRY_KEY as *const u8 as *mut c_void,
);
let extra = mlua_expect!(
(*get_gc_userdata::<Weak<Mutex<ExtraData>>>(state, -1)).upgrade(),
"extra is destroyed"
);
ffi::lua_pop(state, 1);
Lua {
state,
main_state: get_main_state(state),
extra,
ephemeral: true,
safe: true, // TODO: Inherit the attribute
_no_ref_unwind_safe: PhantomData,
}
}
pub(crate) unsafe fn hook_callback(&self) -> Option<HookCallback> {
let extra = mlua_expect!(self.extra.lock(), "extra is poisoned");
extra.hook_callback.clone()
}
}
/// Returned from [`Lua::load`] and is used to finalize loading and executing Lua main chunks.

View file

@ -1,3 +1,4 @@
use std::cell::RefCell;
use std::os::raw::{c_int, c_void};
use std::sync::{Arc, Mutex};
use std::{fmt, mem, ptr};
@ -7,6 +8,7 @@ use futures_core::future::LocalBoxFuture;
use crate::error::Result;
use crate::ffi;
use crate::hook::Debug;
use crate::lua::Lua;
use crate::util::{assert_stack, StackGuard};
use crate::value::MultiValue;
@ -27,6 +29,8 @@ pub(crate) type Callback<'lua, 'a> =
pub(crate) type AsyncCallback<'lua, 'a> =
Box<dyn Fn(&'lua Lua, MultiValue<'lua>) -> LocalBoxFuture<'lua, Result<MultiValue<'lua>>> + 'a>;
pub(crate) type HookCallback = Arc<RefCell<dyn FnMut(&Lua, Debug) -> Result<()>>>;
#[cfg(feature = "send")]
pub trait MaybeSend: Send {}
#[cfg(feature = "send")]

230
tests/hooks.rs Normal file
View file

@ -0,0 +1,230 @@
#![cfg(any(
feature = "lua54",
feature = "lua53",
feature = "lua52",
feature = "lua51"
))]
use std::cell::RefCell;
use std::ops::Deref;
use std::str;
use std::sync::{Arc, Mutex};
use mlua::{Error, HookTriggers, Lua, Result, Value};
#[test]
fn line_counts() -> Result<()> {
let output = Arc::new(Mutex::new(Vec::new()));
let hook_output = output.clone();
let lua = Lua::new();
lua.set_hook(
HookTriggers {
every_line: true,
..Default::default()
},
move |_lua, debug| {
hook_output.lock().unwrap().push(debug.curr_line());
Ok(())
},
);
lua.load(
r#"
local x = 2 + 3
local y = x * 63
local z = string.len(x..", "..y)
"#,
)
.exec()?;
let output = output.lock().unwrap();
assert_eq!(*output, vec![2, 3, 4]);
Ok(())
}
#[test]
fn function_calls() -> Result<()> {
let output = Arc::new(Mutex::new(Vec::new()));
let hook_output = output.clone();
let lua = Lua::new();
lua.set_hook(
HookTriggers {
on_calls: true,
..Default::default()
},
move |_lua, debug| {
let names = debug.names();
let source = debug.source();
let name = names.name.map(|s| str::from_utf8(s).unwrap().to_owned());
let what = source.what.map(|s| str::from_utf8(s).unwrap().to_owned());
hook_output.lock().unwrap().push((name, what));
Ok(())
},
);
lua.load(
r#"
local v = string.len("Hello World")
"#,
)
.exec()?;
let output = output.lock().unwrap();
assert_eq!(
*output,
vec![
(None, Some("main".to_string())),
(Some("len".to_string()), Some("C".to_string()))
]
);
Ok(())
}
#[test]
fn error_within_hook() {
let lua = Lua::new();
lua.set_hook(
HookTriggers {
every_line: true,
..Default::default()
},
|_lua, _debug| {
Err(Error::RuntimeError(
"Something happened in there!".to_string(),
))
},
);
let err = lua
.load("x = 1")
.exec()
.expect_err("panic didn't propagate");
match err {
Error::CallbackError { cause, .. } => match cause.deref() {
Error::RuntimeError(s) => assert_eq!(s, "Something happened in there!"),
_ => panic!("wrong callback error kind caught"),
},
_ => panic!("wrong error kind caught"),
};
}
#[test]
fn limit_execution_instructions() {
let lua = Lua::new();
let mut max_instructions = 10000;
lua.set_hook(
HookTriggers {
every_nth_instruction: Some(30),
..Default::default()
},
move |_lua, _debug| {
max_instructions -= 30;
if max_instructions < 0 {
Err(Error::RuntimeError("time's up".to_string()))
} else {
Ok(())
}
},
);
lua.globals().set("x", Value::Integer(0)).unwrap();
let _ = lua
.load(
r#"
for i = 1, 10000 do
x = x + 1
end
"#,
)
.exec()
.expect_err("instruction limit didn't occur");
}
#[test]
fn hook_removal() {
let lua = Lua::new();
lua.set_hook(
HookTriggers {
every_nth_instruction: Some(1),
..Default::default()
},
|_lua, _debug| {
Err(Error::RuntimeError(
"this hook should've been removed by this time".to_string(),
))
},
);
assert!(lua.load("local x = 1").exec().is_err());
lua.remove_hook();
assert!(lua.load("local x = 1").exec().is_ok());
}
#[test]
fn hook_swap_within_hook() {
thread_local! {
static TL_LUA: RefCell<Option<Lua>> = RefCell::new(None);
}
TL_LUA.with(|tl| {
*tl.borrow_mut() = Some(Lua::new());
});
TL_LUA.with(|tl| {
tl.borrow().as_ref().unwrap().set_hook(
HookTriggers {
every_line: true,
..Default::default()
},
move |lua, _debug| {
lua.globals().set("ok", 1i64).unwrap();
TL_LUA.with(|tl| {
tl.borrow().as_ref().unwrap().set_hook(
HookTriggers {
every_line: true,
..Default::default()
},
move |lua, _debug| {
lua.load(
r#"
if ok ~= nil then
ok = ok + 1
end
"#,
)
.exec()
.expect("exec failure within hook");
TL_LUA.with(|tl| {
tl.borrow().as_ref().unwrap().remove_hook();
});
Ok(())
},
);
});
Ok(())
},
);
});
TL_LUA.with(|tl| {
let tl = tl.borrow();
let lua = tl.as_ref().unwrap();
assert!(lua
.load(
r#"
local x = 1
x = 2
local y = 3
"#,
)
.exec()
.is_ok());
assert_eq!(lua.globals().get::<_, i64>("ok").unwrap_or(-1), 2);
});
}