Implement PartialEq trait for Value (and subtypes)

Add equals() method to compare values optionally invoking __eq.
This commit is contained in:
Alex Orlenko 2020-01-06 23:59:50 +00:00
parent 831161bfda
commit 5eec0ef56b
12 changed files with 331 additions and 37 deletions

View file

@ -160,3 +160,9 @@ impl<'lua> Function<'lua> {
}
}
}
impl<'lua> PartialEq for Function<'lua> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}

View file

@ -166,6 +166,62 @@ impl<'lua> Table<'lua> {
self.get::<_, Function>(key)?.call(args)
}
/// Compares two tables for equality.
///
/// Tables are compared by reference first.
/// If they are not primitively equals, then mlua will try to invoke the `__eq` metamethod.
/// mlua will check `self` first for the metamethod, then `other` if not found.
///
/// # Examples
///
/// Compare two tables using `__eq` metamethod:
///
/// ```
/// # use mlua::{Lua, Result, Table};
/// # fn main() -> Result<()> {
/// # let lua = Lua::new();
/// let table1 = lua.create_table()?;
/// table1.set(1, "value")?;
///
/// let table2 = lua.create_table()?;
/// table2.set(2, "value")?;
///
/// let always_equals_mt = lua.create_table()?;
/// always_equals_mt.set("__eq", lua.create_function(|_, (_t1, _t2): (Table, Table)| Ok(true))?)?;
/// table2.set_metatable(Some(always_equals_mt));
///
/// assert!(table1.equals(&table1.clone())?);
/// assert!(table1.equals(&table2)?);
/// # Ok(())
/// # }
/// ```
pub fn equals<T: AsRef<Self>>(&self, other: T) -> Result<bool> {
let other = other.as_ref();
if self == other {
return Ok(true);
}
// Compare using __eq metamethod if exists
// First, check the self for the metamethod.
// If self does not define it, then check the other table.
if let Some(mt) = self.get_metatable() {
if mt.contains_key("__eq")? {
return mt
.get::<_, Function>("__eq")?
.call((self.clone(), other.clone()));
}
}
if let Some(mt) = other.get_metatable() {
if mt.contains_key("__eq")? {
return mt
.get::<_, Function>("__eq")?
.call((self.clone(), other.clone()));
}
}
Ok(false)
}
/// Removes a key from the table, returning the value at the key
/// if the key was previously in the table.
pub fn raw_remove<K: ToLua<'lua>>(&self, key: K) -> Result<()> {
@ -368,6 +424,19 @@ impl<'lua> Table<'lua> {
}
}
impl<'lua> PartialEq for Table<'lua> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<'lua> AsRef<Table<'lua>> for Table<'lua> {
#[inline]
fn as_ref(&self) -> &Self {
self
}
}
/// An iterator over the pairs of a Lua table.
///
/// This struct is created by the [`Table::pairs`] method.

View file

@ -143,3 +143,9 @@ impl<'lua> Thread<'lua> {
}
}
}
impl<'lua> PartialEq for Thread<'lua> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}

View file

@ -5,6 +5,7 @@ use std::{fmt, mem, ptr};
use crate::error::Result;
use crate::ffi;
use crate::lua::Lua;
use crate::util::{assert_stack, StackGuard};
use crate::value::MultiValue;
/// Type of Lua integer numbers.
@ -92,3 +93,16 @@ impl<'lua> Drop for LuaRef<'lua> {
self.lua.drop_ref(self)
}
}
impl<'lua> PartialEq for LuaRef<'lua> {
fn eq(&self, other: &Self) -> bool {
let lua = self.lua;
unsafe {
let _sg = StackGuard::new(lua.state);
assert_stack(lua.state, 2);
lua.push_ref(&self);
lua.push_ref(&other);
ffi::lua_rawequal(lua.state, -1, -2) == 1
}
}
}

View file

@ -2,7 +2,9 @@ use std::cell::{Ref, RefCell, RefMut};
use crate::error::{Error, Result};
use crate::ffi;
use crate::function::Function;
use crate::lua::Lua;
use crate::table::Table;
use crate::types::LuaRef;
use crate::util::{assert_stack, get_userdata, StackGuard};
use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti};
@ -398,6 +400,42 @@ impl<'lua> AnyUserData<'lua> {
V::from_lua(res, lua)
}
fn get_metatable(&self) -> Result<Table<'lua>> {
unsafe {
let lua = self.0.lua;
let _sg = StackGuard::new(lua.state);
assert_stack(lua.state, 3);
lua.push_ref(&self.0);
if ffi::lua_getmetatable(lua.state, -1) == 0 {
return Err(Error::UserDataTypeMismatch);
}
Ok(Table(lua.pop_ref()))
}
}
pub(crate) fn equals<T: AsRef<Self>>(&self, other: T) -> Result<bool> {
let other = other.as_ref();
if self == other {
return Ok(true);
}
let mt = self.get_metatable()?;
if mt != other.get_metatable()? {
return Ok(false);
}
if mt.contains_key("__eq")? {
return mt
.get::<_, Function>("__eq")?
.call((self.clone(), other.clone()));
}
Ok(false)
}
fn inspect<'a, T, R, F>(&'a self, func: F) -> Result<R>
where
T: 'static + UserData,
@ -428,3 +466,16 @@ impl<'lua> AnyUserData<'lua> {
}
}
}
impl<'lua> PartialEq for AnyUserData<'lua> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<'lua> AsRef<AnyUserData<'lua>> for AnyUserData<'lua> {
#[inline]
fn as_ref(&self) -> &Self {
self
}
}

View file

@ -2,6 +2,7 @@ use std::iter::{self, FromIterator};
use std::{slice, str, vec};
use crate::error::{Error, Result};
use crate::ffi;
use crate::function::Function;
use crate::lua::Lua;
use crate::string::String;
@ -61,6 +62,51 @@ impl<'lua> Value<'lua> {
Value::UserData(_) | Value::Error(_) => "userdata",
}
}
/// Compares two values for equality.
///
/// Equality comparisons do not convert strings to numbers or vice versa.
/// Tables, Functions, Threads, and Userdata are compared by reference:
/// two objects are considered equal only if they are the same object.
///
/// If Tables or Userdata have `__eq` metamethod then mlua will try to invoke it.
/// The first value is checked first. If that value does not define a metamethod
/// for `__eq`, then mlua will check the second value.
/// Then mlua calls the metamethod with the two values as arguments, if found.
pub fn equals<T: AsRef<Self>>(&self, other: T) -> Result<bool> {
match (self, other.as_ref()) {
(Value::Table(a), Value::Table(b)) => a.equals(b),
(Value::UserData(a), Value::UserData(b)) => a.equals(b),
_ => Ok(self == other.as_ref()),
}
}
}
impl<'lua> PartialEq for Value<'lua> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Value::Nil, Value::Nil) => true,
(Value::Boolean(a), Value::Boolean(b)) => a == b,
(Value::LightUserData(a), Value::LightUserData(b)) => a == b,
(Value::Integer(a), Value::Integer(b)) => *a == *b,
(Value::Integer(a), Value::Number(b)) => *a as ffi::lua_Number == *b,
(Value::Number(a), Value::Integer(b)) => *a == *b as ffi::lua_Number,
(Value::Number(a), Value::Number(b)) => *a == *b,
(Value::String(a), Value::String(b)) => a == b,
(Value::Table(a), Value::Table(b)) => a == b,
(Value::Function(a), Value::Function(b)) => a == b,
(Value::Thread(a), Value::Thread(b)) => a == b,
(Value::UserData(a), Value::UserData(b)) => a == b,
_ => false,
}
}
}
impl<'lua> AsRef<Value<'lua>> for Value<'lua> {
#[inline]
fn as_ref(&self) -> &Self {
self
}
}
/// Trait for types convertible to `Value`.

View file

@ -51,15 +51,15 @@ fn test_gc_error() {
match lua
.load(
r#"
val = nil
table = {}
setmetatable(table, {
__gc = function()
error("gcwascalled")
end
})
table = nil
collectgarbage("collect")
val = nil
table = {}
setmetatable(table, {
__gc = function()
error("gcwascalled")
end
})
table = nil
collectgarbage("collect")
"#,
)
.exec()

View file

@ -148,6 +148,40 @@ fn test_metatable() -> Result<()> {
Ok(())
}
#[test]
fn test_table_eq() -> Result<()> {
let lua = Lua::new();
let globals = lua.globals();
lua.load(
r#"
table1 = {1}
table2 = {1}
table3 = table1
table4 = {1}
setmetatable(table4, {
__eq = function(a, b) return a[1] == b[1] end
})
"#,
)
.exec()?;
let table1 = globals.get::<_, Table>("table1")?;
let table2 = globals.get::<_, Table>("table2")?;
let table3 = globals.get::<_, Table>("table3")?;
let table4 = globals.get::<_, Table>("table4")?;
assert!(table1 != table2);
assert!(!table1.equals(&table2)?);
assert!(table1 == table3);
assert!(table1.equals(&table3)?);
assert!(table1 != table4);
assert!(table1.equals(&table4)?);
Ok(())
}
#[test]
fn test_table_error() -> Result<()> {
let lua = Lua::new();

View file

@ -92,13 +92,13 @@ fn test_lua_multi() -> Result<()> {
lua.load(
r#"
function concat(arg1, arg2)
return arg1 .. arg2
end
function concat(arg1, arg2)
return arg1 .. arg2
end
function mreturn()
return 1, 2, 3, 4, 5, 6
end
function mreturn()
return 1, 2, 3, 4, 5, 6
end
"#,
)
.exec()?;

View file

@ -20,13 +20,13 @@ fn test_thread() -> Result<()> {
let thread = lua.create_thread(
lua.load(
r#"
function (s)
local sum = s
for i = 1,4 do
sum = sum + coroutine.yield(sum)
end
return sum
function (s)
local sum = s
for i = 1,4 do
sum = sum + coroutine.yield(sum)
end
return sum
end
"#,
)
.eval()?,
@ -47,11 +47,11 @@ fn test_thread() -> Result<()> {
let accumulate = lua.create_thread(
lua.load(
r#"
function (sum)
while true do
sum = sum + coroutine.yield(sum)
end
function (sum)
while true do
sum = sum + coroutine.yield(sum)
end
end
"#,
)
.eval::<Function>()?,

View file

@ -13,7 +13,7 @@ use std::sync::Arc;
use mlua::{
AnyUserData, ExternalError, Function, Lua, MetaMethod, Result, String, UserData,
UserDataMethods,
UserDataMethods, Value,
};
#[test]
@ -96,6 +96,9 @@ fn test_metamethods() -> Result<()> {
MetaMethod::Sub,
|_, (lhs, rhs): (MyUserData, MyUserData)| Ok(MyUserData(lhs.0 - rhs.0)),
);
methods.add_meta_function(MetaMethod::Eq, |_, (lhs, rhs): (MyUserData, MyUserData)| {
Ok(lhs.0 == rhs.0)
});
methods.add_meta_method(MetaMethod::Index, |_, data, index: String| {
if index.to_str()? == "inner" {
Ok(data.0)
@ -122,6 +125,7 @@ fn test_metamethods() -> Result<()> {
let globals = lua.globals();
globals.set("userdata1", MyUserData(7))?;
globals.set("userdata2", MyUserData(3))?;
globals.set("userdata3", MyUserData(3))?;
assert_eq!(
lua.load("userdata1 + userdata2").eval::<MyUserData>()?.0,
10
@ -151,6 +155,13 @@ fn test_metamethods() -> Result<()> {
assert_eq!(ipairs_it.call::<_, i64>(())?, 28);
assert!(lua.load("userdata2.nonexist_field").eval::<()>().is_err());
let userdata2: Value = globals.get("userdata2")?;
let userdata3: Value = globals.get("userdata3")?;
assert!(lua.load("userdata2 == userdata3").eval::<bool>()?);
assert!(userdata2 != userdata3); // because references are differ
assert!(userdata2.equals(userdata3)?);
Ok(())
}
@ -175,18 +186,18 @@ fn test_gc_userdata() -> Result<()> {
assert!(lua
.load(
r#"
local tbl = setmetatable({
userdata = userdata
}, { __gc = function(self)
-- resurrect userdata
hatch = self.userdata
end })
local tbl = setmetatable({
userdata = userdata
}, { __gc = function(self)
-- resurrect userdata
hatch = self.userdata
end })
tbl = nil
userdata = nil -- make table and userdata collectable
collectgarbage("collect")
hatch:access()
"#
tbl = nil
userdata = nil -- make table and userdata collectable
collectgarbage("collect")
hatch:access()
"#
)
.exec()
.is_err());

57
tests/value.rs Normal file
View file

@ -0,0 +1,57 @@
use mlua::{Lua, Result, Value};
#[test]
fn test_value_eq() -> Result<()> {
let lua = Lua::new();
let globals = lua.globals();
lua.load(
r#"
table1 = {1}
table2 = {1}
string1 = "hello"
string2 = "hello"
num1 = 1
num2 = 1.0
num3 = "1"
func1 = function() end
func2 = func1
func3 = function() end
thread1 = coroutine.create(function() end)
thread2 = thread1
setmetatable(table1, {
__eq = function(a, b) return a[1] == b[1] end
})
"#,
)
.exec()?;
let table1: Value = globals.get("table1")?;
let table2: Value = globals.get("table2")?;
let string1: Value = globals.get("string1")?;
let string2: Value = globals.get("string2")?;
let num1: Value = globals.get("num1")?;
let num2: Value = globals.get("num2")?;
let num3: Value = globals.get("num3")?;
let func1: Value = globals.get("func1")?;
let func2: Value = globals.get("func2")?;
let func3: Value = globals.get("func3")?;
let thread1: Value = globals.get("thread1")?;
let thread2: Value = globals.get("thread2")?;
assert!(table1 != table2);
assert!(table1.equals(table2)?);
assert!(string1 == string2);
assert!(string1.equals(string2)?);
assert!(num1 == num2);
assert!(num1.equals(num2)?);
assert!(num1 != num3);
assert!(func1 == func2);
assert!(func1 != func3);
assert!(!func1.equals(func3)?);
assert!(thread1 == thread2);
assert!(thread1.equals(thread2)?);
Ok(())
}