Wednesday, June 17, 2009

Equality & Scala 2

My efforts yesterday to come up with an Equatable trait that could spare people some of the knowledge required for, and the sheer drudgery of, writing a valid “equals” method met with unexpected difficulties. While I was tempted to replace the incorrect code with correct code before the article got many hits, I decided that the problems I encountered and the mistakes I made taught a lesson in themselves.

Anyway, I'd like first to make a few remarks that were missing from that post. One thing to notice here is that, while this trait might be useful to some people, it is slower than a well thought-out equals definition. What I'm trying to do is see how much Scala lets me make the job of creating valid equals methods both easier and safer.

It all comes down to this: since equality in languages with subclassing is so full of pitfalls, and since the general solution to it is a well-defined pattern, those languages ought to do something about it at the language level. Or at the library level, if possible.

It does cross my mind that this problem might be solved much more efficiently and elegantly in languages that make it possible to generate code at compile time, such as Lisp with its macros, or Ruby with access to the AST. This is one thing I miss in Scala, and while I understand the reasons for its absence and empathize with them, I do come up against some roadblocks to a scalable language now and then.

That said, let’s analyze what happened. I tried to model my trait after the Hashable class. There are two things, though, that made my job harder than Hashable’s. First, I depended on super.equals, while hashCode doesn’t depend on its super. This matters because super.equals refers to definitions such as testSuperEquals and equateValues, and those definitions get overridden in the subclass. Therefore, when Point3D.equals called Point.equals (its super.equals), Point.equals used Point3D.equateValues and Point3D.testSuperEquals instead of Point.equateValues and Point.testSuperEquals.

This is difficult enough to get around, but it gets worse because, unlike hashCode, equals has to deal with not one but two objects. Calling one’s own super method is easy; calling the super version of a method on another object is not.
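The first problem can be reproduced in a few lines. The sketch below is a simplified, hypothetical model of yesterday's design, not the post's actual code (the names Base, Derived and the helper equateValues are illustrative): because equateValues is an ordinary virtual method, the call inside Base.equals dispatches to Derived's override, even when it is reached through super.equals.

```scala
// Simplified, hypothetical model of the first attempt: a helper used by
// the superclass's equals is overridden in the subclass.
class Base(val a: Int) {
  // The values Base *intends* to compare.
  def equateValues: List[Any] = List(a)

  override def equals(other: Any): Boolean = other match {
    // Dynamic dispatch: on a Derived instance, this calls Derived.equateValues.
    case that: Base => equateValues == that.equateValues
    case _          => false
  }
  override def hashCode: Int = a
}

class Derived(a0: Int, val b: Int) extends Base(a0) {
  // Meant to list only the fields Derived adds.
  override def equateValues: List[Any] = List(b)

  override def equals(other: Any): Boolean = other match {
    // super.equals runs Base.equals, but Base.equals now sees List(b), not List(a).
    case that: Derived => super.equals(that) && b == that.b
    case _             => false
  }
}

val d1 = new Derived(1, 5)
val d2 = new Derived(2, 5)
// The field `a` differs, yet it is never compared, so this prints true.
println(d1 == d2)
```

Note that the hashCode contract is broken as well: d1 and d2 compare equal but have different hash codes (1 and 2).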

Another problem is the warning about type erasure. Because "This" is erased, the case match never actually tests for type "This", so it is necessary to resort to reflection to make sure we don't bind an instance of a superclass to a variable typed as the subclass. I hadn't noticed the error because the hash code test was returning false first. Well, I fixed that, and changed the test code to produce a constant hashCode, so the hash test can no longer hide such mistakes.
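To see the erasure problem in isolation, here is a minimal sketch; Animal, Dog, matchesErased and sameOrSubclass are hypothetical names for illustration. A type pattern on a type parameter compiles (the @unchecked annotation silences the warning) but only checks the erased upper bound at runtime, while the getClass.isAssignableFrom guard checks the real runtime classes.

```scala
// Two illustrative classes (hypothetical, for the demonstration only).
class Animal
class Dog extends Animal

// The type test `case _: T` cannot be checked at runtime: T is erased to its
// upper bound (AnyRef), so every non-null reference matches.
def matchesErased[T <: AnyRef](x: Any): Boolean = (x: @unchecked) match {
  case _: T => true
  case _    => false
}

// The reflective guard used in the post: accept `that` only if its runtime
// class is the same as, or a subclass of, the runtime class of `self`.
def sameOrSubclass(self: AnyRef, that: Any): Boolean = that match {
  case ref: AnyRef => self.getClass.isAssignableFrom(ref.getClass)
  case _           => false
}

println(matchesErased[Dog](new Animal))      // true: the Dog test was erased
println(sameOrSubclass(new Dog, new Animal)) // false: an Animal is not a Dog
println(sameOrSubclass(new Animal, new Dog)) // true: a Dog is an Animal
```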

The goal, then, is to make equals independent of any definition that might get overridden by a subclass’s equals. Or, in other words, independent of any other definition related to the trait Equatable itself. Of course, it will still need to access members of the other object, and those might get overridden. That, however, is expected and shouldn’t have any influence on the equals method.

To begin with, let’s think about what our equals definition would look like in the class. In the next-to-last version, we had this:


override def equals(other : Any) = equalsTo[Point](other, true)

where equalsTo was defined as

protected def equalsTo[This <: Equatable](other : Any, superEquals : => Boolean) : Boolean

This definition delegates to the class the task of calling super.equals, or of passing “true” if that isn’t appropriate. We can’t do that inside equalsTo, because equalsTo is never overridden: “super”, inside it, has only one meaning. So this solution will do.
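Since superEquals is declared by-name (=> Boolean), the super.equals(other) expression the class writes at the call site is evaluated only if the tests before it in the && chain succeed. A small sketch of that behaviour, where combine, expensiveSuperEquals and the superCalls counter are hypothetical names:

```scala
// Counts how often the "super" check actually runs (hypothetical instrumentation).
var superCalls = 0
def expensiveSuperEquals(): Boolean = { superCalls += 1; true }

// Same shape as equalsTo's last parameter: superEquals is by-name (=> Boolean),
// so the argument expression is evaluated only if && gets that far.
def combine(sameClass: Boolean, superEquals: => Boolean): Boolean =
  sameClass && superEquals

combine(false, expensiveSuperEquals()) // short-circuits: superCalls stays 0
combine(true, expensiveSuperEquals())  // evaluates the argument: superCalls becomes 1
combine(true, true)                    // a class with no Equatable parent just passes true
```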

Now, how do we deal with equateValues? One obvious solution would be doing this:

protected def equalsTo[This <: Equatable](other : Any, equateValues : Seq[Any], superEquals : => Boolean) : Boolean

There are at least two reasons this isn’t going to work. First, while it solves the problem of equateValues for “this” object, since the values get passed as a parameter in the equals definition, it doesn’t solve it for “that” object! In fact, we now have no way of finding out which elements are to be compared in the other object.

A second problem, though, might not be as obvious. We might depend on mutable data, or on data which hasn’t been computed yet at the time we define equals. Getting around that is possible, but it would be much worse than defining an equals method by oneself!

So, what we’ll do is pass a function instead: a function which, given a “that” object, returns the sequence we need. In other words, we want an object of this type:

(that : This) => Seq[Any]

The only problem with that is that “this” inside our trait does not have type This. We’ll need to receive a reference to ourselves, properly typed! Our definition, then, should be:

protected def equalsTo[This <: Equatable](self : This, other : Any, equateValues : This => Seq[Any], superEquals : => Boolean) : Boolean
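Passing a function addresses both objections: the same function is applied to self and to that, and it reads the fields when equals runs rather than when the function is created. A minimal sketch outside the trait, assuming illustrative names (sameValues, Pt and ptValues are hypothetical) and comparing with Seq’s own == instead of the fold used in the post:

```scala
// A plain class with mutable fields (hypothetical, for illustration).
class Pt(var x: Int, var y: Int)

// Same idea as equalsTo's equateValues parameter: one function, typed on the
// class, applied to *both* objects at comparison time.
def sameValues[T](self: T, that: T)(values: T => Seq[Any]): Boolean =
  values(self) == values(that) // Seq equality compares element by element

val ptValues = (p: Pt) => List(p.x, p.y)

val a = new Pt(1, 2)
val b = new Pt(1, 2)
val before = sameValues(a, b)(ptValues) // true
b.y = 3                                 // mutate after the function was defined
val after = sameValues(a, b)(ptValues)  // false: the current values are read
```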

The body of our method, then, becomes:

(other : @unchecked) match {
  case that : This =>
    (
      that.canEqual(this)
      && superEquals
      && hashCode == that.hashCode // Can speed up or slow down
      && equateValues(self).zip(equateValues(that)).foldLeft(true) {
        (equals, tuple) => equals && tuple._1 == tuple._2
      } && equateValues(self).size == equateValues(that).size
    )
  case _ => false
}
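The final size comparison matters because zip silently truncates to the shorter sequence. The sketch below isolates that comparison (foldEquals, foldWithoutSize and forallEquals are hypothetical names) and adds a variant, my suggestion rather than the post’s code, that checks sizes first and lets forall stop at the first mismatch instead of visiting every pair.

```scala
// The post's comparison: zip pairs elements, foldLeft ANDs the pairwise results.
def foldEquals(xs: Seq[Any], ys: Seq[Any]): Boolean =
  xs.zip(ys).foldLeft(true) { (equals, tuple) => equals && tuple._1 == tuple._2 } &&
    xs.size == ys.size

// zip stops at the shorter sequence, so without the size test a prefix would "match".
def foldWithoutSize(xs: Seq[Any], ys: Seq[Any]): Boolean =
  xs.zip(ys).foldLeft(true) { (equals, tuple) => equals && tuple._1 == tuple._2 }

// Alternative: compare sizes first, then let forall stop at the first mismatch.
def forallEquals(xs: Seq[Any], ys: Seq[Any]): Boolean =
  xs.size == ys.size && xs.zip(ys).forall { case (x, y) => x == y }

println(foldWithoutSize(List(1, 2), List(1, 2, 3))) // true: the extra 3 is ignored
println(foldEquals(List(1, 2), List(1, 2, 3)))      // false
println(forallEquals(List(1, 2), List(1, 2, 3)))    // false
```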

Now, what would our equals definitions look like? For Point and for Point3D, respectively:

override def equals(other : Any) = equalsTo[Point](this, other, that => List(that.x, that.y), true)
override def equals(other : Any) = equalsTo[Point3D](this, other, that => List(that.z), super.equals(other))

For big objects, writing the function inline in the argument list might be awkward. Instead, we might prefer to assign the function to a val first. For example:

private val pointValues = (that : Point) => List(that.x, that.y)

So, here’s everything together, with a bit of further editing for performance reasons:

trait Equatable extends scala.util.Hashable {
  def canEqual (that : Any) : Boolean

  protected def equalsTo[This <: Equatable](self : This, other : Any, equateValues : This => Seq[Any], superEquals : => Boolean) : Boolean =
    (other : @unchecked) match {
      case that : This if self.getClass.isAssignableFrom(that.getClass) =>
        // Testing for hash code can improve or decrease performance, depending on the implementation;
        // if hashCode gets implemented as a val, it will make equality faster
        if (that.canEqual(this) && superEquals && hashCode == that.hashCode) {
          val thisValues = equateValues(self)
          val thatValues = equateValues(that)
          (thisValues.zip(thatValues).foldLeft(true) { (equals, tuple) => equals && tuple._1 == tuple._2 }
            && thisValues.size == thatValues.size)
        } else false
      case _ => false
    }
}

class Point (val x : Int, val y : Int) extends Equatable {
  override def toString = "(%d, %d)" format (x, y)

  // Hashable definitions
  // override def hashValues = List(x, y)
  override def hashValues = List(0) // We don't want the hash skipping our tests

  // Equatable definitions
  private val pointValues = (that : Point) => List(that.x, that.y) // one way
  override def equals(other : Any) = equalsTo[Point](this, other, pointValues, true)
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point]
}

class Point3D(x : Int, y : Int, val z : Int) extends Point(x, y) with Equatable {
  override def toString = "(%d, %d, %d)" format (x, y, z)

  // Hashable definitions
  // override def hashValues = List(x, y, z)
  override def hashValues = List(0) // We don't want the hash skipping our tests

  // Equatable definitions
  // private val point3DValues = (that : Point3D) => List(that.z)
  override def equals(other : Any) = equalsTo[Point3D](this, other, that => List(that.z), super.equals(other))
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point3D]
}


And the tests. I thought about writing them as assertions, but that was too silent for my taste. Anyway:

scala> val x = new Point(1, 2); val x2 = new Point(1, 2)
x: Point = (1, 2)
x2: Point = (1, 2)

scala> x == x2 // super.equals does not get called, so we do not perform reference equality
res0: Boolean = true

scala> val y = new Point(2, 1)
y: Point = (2, 1)

scala> x == y // expected false
res1: Boolean = false

scala> val z = new Point3D(1, 2, 0)
z: Point3D = (1, 2, 0)

scala> x == z // false in that canEqual this test
res2: Boolean = false

scala> z == x // false through reflection isAssignableFrom
res3: Boolean = false

scala> val z2 = new Point3D(2, 1, 0)
z2: Point3D = (2, 1, 0)

scala> z == z2 // expected false
res4: Boolean = false

scala> val z3 = new Point3D(1, 2, 0)
z3: Point3D = (1, 2, 0)

scala> z == z3 // expected true
res5: Boolean = true

scala> val z4 = new Point3D(1, 2, 1)
z4: Point3D = (1, 2, 1)

scala> z == z4 // expected false
res6: Boolean = false

scala> x == x
res7: Boolean = true
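The code above extends scala.util.Hashable, which I don’t believe ships with current Scala releases. As a sketch for a current compiler, under that assumption, here is the same design with a hypothetical hashValues/hashCode pair defined directly in the trait. The equality logic is otherwise the same, except that it uses sameElements for the value comparison, and Point3D’s constructor parameters are renamed (x0, y0) so they don’t shadow Point’s fields.

```scala
// Sketch for current Scala: scala.util.Hashable is assumed unavailable, so the
// trait supplies its own (hypothetical) hashValues and derives hashCode from it.
trait Equatable {
  def canEqual(that: Any): Boolean

  // Values that feed hashCode; subclasses override this.
  protected def hashValues: Seq[Any]
  override def hashCode: Int = hashValues.##

  protected def equalsTo[This <: Equatable](self: This, other: Any,
      equateValues: This => Seq[Any], superEquals: => Boolean): Boolean =
    (other: @unchecked) match {
      // `This` is erased, so the real class check is done reflectively in the guard.
      case that: This if self.getClass.isAssignableFrom(that.getClass) =>
        that.canEqual(this) && superEquals && hashCode == that.hashCode &&
          equateValues(self).sameElements(equateValues(that))
      case _ => false
    }
}

class Point(val x: Int, val y: Int) extends Equatable {
  override def toString = "(%d, %d)".format(x, y)
  // Constant hash, as in the post's tests, so the hash check can't hide mistakes.
  override protected def hashValues: Seq[Any] = List(0)

  private val pointValues = (that: Point) => List(that.x, that.y)
  override def equals(other: Any): Boolean = equalsTo[Point](this, other, pointValues, true)
  override def canEqual(other: Any): Boolean = other.isInstanceOf[Point]
}

class Point3D(x0: Int, y0: Int, val z: Int) extends Point(x0, y0) {
  override def toString = "(%d, %d, %d)".format(x, y, z)
  override protected def hashValues: Seq[Any] = List(0)

  override def equals(other: Any): Boolean =
    equalsTo[Point3D](this, other, that => List(that.z), super.equals(other))
  override def canEqual(other: Any): Boolean = other.isInstanceOf[Point3D]
}

val p = new Point(1, 2)
val q = new Point3D(1, 2, 0)
println(p == q)                    // false: q.canEqual(p) fails
println(q == new Point3D(1, 2, 0)) // true
```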
