diff --git a/src/lua.rs b/src/lua.rs index d7bcb8b..b294ec9 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -7,6 +7,7 @@ use std::fmt; use std::marker::PhantomData; use std::os::raw::{c_char, c_int, c_void}; use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe, Location}; +use std::path::Path; use std::sync::{Arc, Mutex}; use std::{mem, ptr, str}; @@ -1296,6 +1297,76 @@ impl Lua { } } + /// Loads a lua file using `luaL_loadfile`. Call the returned function in order to execute the + /// file. + /// + /// You can pass an optional `env` parameter to set the environment the File will run in. + pub fn load_file<'lua>( + &'lua self, + path: impl AsRef, + env: Option>, + ) -> Result> { + let path = path.as_ref(); + + // made in collaboration with stack overflow: https://stackoverflow.com/a/59224987 + let mut path_cstr = Vec::new(); + + #[cfg(unix)] + { + use std::os::unix::ffi::OsStrExt; + path_cstr.extend(path.as_os_str().as_bytes()); + path_cstr.push(0); + } + + // fun and games for windows... + #[cfg(windows)] + { + use std::os::windows::ffi::OsStrExt; + buf.extend(path.as_os_str().encode_wide().chain(Some(0)).flat_map(|b| { + let b = b.to_ne_bytes(); + b.get(0).map(|s| *s).into_iter().chain(b.get(1).map(|s| *s)) + })); + } + + #[cfg(not(feature = "luau"))] + unsafe { + let _sg = StackGuard::new(self.state); + check_stack(self.state, 1)?; + + match ffi::luaL_loadfile(self.state, path_cstr.as_ptr() as _) { + ffi::LUA_OK => { + if let Some(env) = env { + self.push_value(env)?; + #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] + ffi::lua_setupvalue(self.state, -2, 1); + #[cfg(any(feature = "lua51", feature = "luajit"))] + ffi::lua_setfenv(self.state, -2); + } + + Ok(Function(self.pop_ref())) + } + err => Err(pop_error(self.state, err)), + } + } + + #[cfg(feature = "luau")] + unsafe { + // luau has no native load file function, so we read and load the file normally + let source = std::fs::read(path).map_err(Error::external)?; + self.load_chunk( + &source, + // has to be some, otherwise luau will segfault + Some(&CString::from_vec_with_nul_unchecked(path_cstr)), + env, + if source[0] < b'\n' { + Some(ChunkMode::Binary) + } else { + None + }, + ) + } + } + pub(crate) fn load_chunk<'lua>( &'lua self, source: &[u8], diff --git a/tests/loadfile.lua b/tests/loadfile.lua new file mode 100644 index 0000000..2cedbb0 --- /dev/null +++ b/tests/loadfile.lua @@ -0,0 +1 @@ +return test diff --git a/tests/loadfile.rs b/tests/loadfile.rs new file mode 100644 index 0000000..cbddecd --- /dev/null +++ b/tests/loadfile.rs @@ -0,0 +1,17 @@ +use mlua::{Lua, Value}; + +#[test] +fn test_loadfile() { + let lua = Lua::new(); + + let env = lua.create_table().unwrap(); + env.set("test", 42u8).unwrap(); + + let file = lua + .load_file(file!().replace(".rs", ".lua"), Some(Value::Table(env))) + .unwrap(); + + let result = file.call::<_, u8>(()).unwrap(); + + assert_eq!(result, 42); +}