Add Lua::load_file function

This commit is contained in:
LordMZTE 2022-04-10 22:44:38 +02:00
parent d607039a31
commit 50c1380c2b
No known key found for this signature in database
GPG key ID: B64802DC33A64FF6
3 changed files with 89 additions and 0 deletions

View file

@ -7,6 +7,7 @@ use std::fmt;
use std::marker::PhantomData;
use std::os::raw::{c_char, c_int, c_void};
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe, Location};
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::{mem, ptr, str};
@ -1296,6 +1297,76 @@ impl Lua {
}
}
/// Loads a lua file using `luaL_loadfile`. Call the returned function in order to execute the
/// file.
///
/// You can pass an optional `env` parameter to set the environment the File will run in.
pub fn load_file<'lua>(
&'lua self,
path: impl AsRef<Path>,
env: Option<Value<'lua>>,
) -> Result<Function<'lua>> {
let path = path.as_ref();
// made in collaboration with stack overflow: https://stackoverflow.com/a/59224987
let mut path_cstr = Vec::new();
#[cfg(unix)]
{
use std::os::unix::ffi::OsStrExt;
path_cstr.extend(path.as_os_str().as_bytes());
path_cstr.push(0);
}
// fun and games for windows...
#[cfg(windows)]
{
use std::os::windows::ffi::OsStrExt;
buf.extend(path.as_os_str().encode_wide().chain(Some(0)).flat_map(|b| {
let b = b.to_ne_bytes();
b.get(0).map(|s| *s).into_iter().chain(b.get(1).map(|s| *s))
}));
}
#[cfg(not(feature = "luau"))]
unsafe {
let _sg = StackGuard::new(self.state);
check_stack(self.state, 1)?;
match ffi::luaL_loadfile(self.state, path_cstr.as_ptr() as _) {
ffi::LUA_OK => {
if let Some(env) = env {
self.push_value(env)?;
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
ffi::lua_setupvalue(self.state, -2, 1);
#[cfg(any(feature = "lua51", feature = "luajit"))]
ffi::lua_setfenv(self.state, -2);
}
Ok(Function(self.pop_ref()))
}
err => Err(pop_error(self.state, err)),
}
}
#[cfg(feature = "luau")]
unsafe {
// luau has no native load file function, so we read and load the file normally
let source = std::fs::read(path).map_err(Error::external)?;
self.load_chunk(
&source,
// has to be some, otherwise luau will segfault
Some(&CString::from_vec_with_nul_unchecked(path_cstr)),
env,
if source[0] < b'\n' {
Some(ChunkMode::Binary)
} else {
None
},
)
}
}
pub(crate) fn load_chunk<'lua>(
&'lua self,
source: &[u8],

1
tests/loadfile.lua Normal file
View file

@ -0,0 +1 @@
return test

17
tests/loadfile.rs Normal file
View file

@ -0,0 +1,17 @@
use mlua::{Lua, Value};
#[test]
fn test_loadfile() {
let lua = Lua::new();
let env = lua.create_table().unwrap();
env.set("test", 42u8).unwrap();
let file = lua
.load_file(file!().replace(".rs", ".lua"), Some(Value::Table(env)))
.unwrap();
let result = file.call::<_, u8>(()).unwrap();
assert_eq!(result, 42);
}