## Tuesday, June 16, 2009

### Equality & Scala

I've just finished reading three different sources on equality in less than a week. They all said pretty much the same thing, with a minor variation here or there. It got me thinking about it, and I have some thoughts to share. For this discussion, I'll assume you are familiar with how to do equality correctly. My examples follow the model given in this article.

/** A point on the integer grid, with public immutable coordinates `x` and `y`. */
class Point(val x: Int, val y: Int) {
  /** Renders the point as `(x, y)`. */
  override def toString: String = "(%d, %d)".format(x, y)
}

/** A [[Point]] extended with a third, public coordinate `z`. */
class Point3D(x: Int, y: Int, val z: Int) extends Point(x, y) {
  /** Renders the point as `(x, y, z)`. */
  override def toString: String = "(%d, %d, %d)".format(x, y, z)
}

Now, a proper equality method in these classes would look like the following:

/** An integer 2-D point with hand-written structural equality (the `canEqual` pattern). */
class Point(val x: Int, val y: Int) {
  override def toString: String = "(%d, %d)".format(x, y)

  /** Mixes both coordinates with the prime 41; equal points get equal hash codes. */
  override def hashCode: Int = 41 * (41 + x) + y

  /**
   * Equal to another `Point` that agrees it can be compared with this one
   * (`canEqual`) and has the same coordinates.
   */
  override def equals(other: Any): Boolean = other match {
    case that: Point if that.canEqual(this) => x == that.x && y == that.y
    case _                                  => false
  }

  /** Any `Point` instance may be compared with a `Point`. */
  def canEqual(other: Any): Boolean = other.isInstanceOf[Point]
}

/**
 * A [[Point]] with a third coordinate. It narrows `canEqual`, so a plain `Point`
 * and a `Point3D` never compare equal.
 */
class Point3D(x: Int, y: Int, val z: Int) extends Point(x, y) {
  override def toString: String = "(%d, %d, %d)".format(x, y, z)

  /** Folds `z` into the 2-D hash so that equal 3-D points hash alike. */
  override def hashCode: Int = 41 * (41 * (41 + x) + y) + z

  /** Equal when `that` accepts the comparison, `Point.equals` agrees on x/y, and z matches. */
  override def equals(other: Any): Boolean = other match {
    case that: Point3D if that.canEqual(this) && super.equals(that) => z == that.z
    case _                                                          => false
  }

  override def canEqual(other: Any): Boolean = other.isInstanceOf[Point3D]
}

Now, Scala has, starting with version 2.8, a Hashable trait, with which we can simplify things:

/** 2-D point that leaves hash-code computation to `scala.util.Hashable`. */
class Point(val x: Int, val y: Int) extends scala.util.Hashable {
  override def toString: String = "(%d, %d)".format(x, y)

  /** The fields fed to `Hashable` to derive the hash code. */
  override def hashValues = List(x, y)

  /** Equal to a `Point` that accepts the comparison and has the same coordinates. */
  override def equals(other: Any): Boolean = other match {
    case that: Point if that.canEqual(this) => x == that.x && y == that.y
    case _                                  => false
  }

  def canEqual(other: Any): Boolean = other.isInstanceOf[Point]
}

/** 3-D point whose hash code `Hashable` derives from all three coordinates. */
class Point3D(x: Int, y: Int, val z: Int) extends Point(x, y) {
  override def toString: String = "(%d, %d, %d)".format(x, y, z)

  /** The fields fed to `Hashable`, inherited coordinates included. */
  override def hashValues = List(x, y, z)

  /** Equal when `that` accepts the comparison, `Point.equals` agrees on x/y, and z matches. */
  override def equals(other: Any): Boolean = other match {
    case that: Point3D if that.canEqual(this) && super.equals(that) => z == that.z
    case _                                                          => false
  }

  override def canEqual(other: Any): Boolean = other.isInstanceOf[Point3D]
}

While it doesn't seem to have gained us anything here, it might for larger objects; at any rate, it removes the "magic" from the hash code and lets someone else worry about how to compute it.

Still, there's a lot of stuff in there just to get equality right, and these are pretty simple classes. What we see is a programming pattern, but one so common and so important that, in my opinion, it merits special attention from the language itself.

Barring that, let's see what we can do programmatically about it. I'll start with a helper function, and how it would be used:

/**
 * Reflection-based structural equality helper, meant to be called from an
 * `equals` override after the caller has done its own static type test.
 *
 * Returns true only when all of these hold:
 *  - `other`'s runtime class is `one`'s runtime class or a subclass of it;
 *  - `other.canEqual(one)` returns true, if `other` has a public `canEqual(Any)`
 *    method (when it has none, this check passes);
 *  - `elementsOne` and `elementsOther` have the same length and are pairwise `==`.
 *
 * The parent class's `equals` is deliberately NOT invoked. `Method.invoke` always
 * dispatches virtually, so reflectively "calling the superclass's equals" on `one`
 * re-enters the most-derived override (the very `equals` that called this helper)
 * and recurses until the stack overflows. Callers must therefore list every value
 * that matters, inherited fields included, in the element sequences.
 *
 * @param one           the receiver of `equals` (normally `this`)
 * @param other         the object being compared against
 * @param elementsOne   the significant values of `one`
 * @param elementsOther the corresponding values of `other`, in the same order
 * @return whether `one` and `other` should be considered equal
 * @throws IllegalStateException if `other.canEqual` returns something other than a boolean
 */
def testEquality(one : AnyRef, other : AnyRef, elementsOne : Seq[Any], elementsOther : Seq[Any]) : Boolean = {
  val classOfOne = one.getClass
  val classOfOther = other.getClass

  // Asymmetric on purpose; symmetry comes from the canEqual check below.
  def sameClassOrSubclass : Boolean = classOfOne.isAssignableFrom(classOfOther)

  def canEqual : Boolean = {
    // Only a missing canEqual means "no restriction". The invocation is kept outside the
    // result match so a non-boolean result is reported instead of being swallowed.
    val result =
      try Some(classOfOther.getMethod("canEqual", classOf[Any]).invoke(other, one))
      catch { case _ : NoSuchMethodException => None }
    result match {
      case None => true
      case Some(flag : java.lang.Boolean) => flag.booleanValue // Translate boxed boolean into boolean
      case Some(_) => throw new IllegalStateException("Method canEqual on object " + other + " does not return a boolean")
    }
  }

  // Same length and pairwise equal.
  def elementsEquals : Boolean = elementsOne.corresponds(elementsOther)(_ == _)

  sameClassOrSubclass && canEqual && elementsEquals
}

/** 2-D point whose `equals` delegates to the reflective `testEquality` helper. */
class Point(val x: Int, val y: Int) extends scala.util.Hashable {
  override def toString: String = "(%d, %d)".format(x, y)

  // Hashable: the values the hash code is derived from
  override def hashValues = List(x, y)

  // Equality: after the static type test, compare the coordinates pairwise
  override def equals(other: Any): Boolean = other match {
    case p: Point => testEquality(this, p, List(x, y), List(p.x, p.y))
    case _        => false
  }

  def canEqual(other: Any): Boolean = other.isInstanceOf[Point]
}

/** 3-D point whose `equals` delegates to the reflective `testEquality` helper. */
class Point3D(x : Int, y : Int, val z : Int) extends Point(x, y) {
  override def toString = "(%d, %d, %d)" format (x, y, z)

  // Hashable: include z, matching the values equality compares. With only List(x, y),
  // points that differ just in z would always share a hash code.
  override def hashValues = List(x, y, z)

  // Equality: all three coordinates are compared, inherited ones included
  override def equals(other : Any) : Boolean = other match {
    case that : Point3D => testEquality(this, that, List(x, y, z), List(that.x, that.y, that.z))
    case _ => false
  }

  // Narrowed so a plain Point is never considered equal to a Point3D
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point3D]
}

Now, this method always calls the parent's equals, unless it's Any's equals. You might want to parametrize this. Also, it expects canEqual to be defined if needed, which might lead to bugs. Furthermore, its usage of reflection makes it slower than needed. Finally, the definition of equals is not that much simpler than what we had before.

So, can we do better? Ideally, one could build a trait similar to Hashable, but it turns out that is not that simple. Let's try:

// First attempt at a reusable equality mixin; as the text below explains, it does not work
// as written. Each class binds EquateThis to its own type, lists the values that define its
// equality in equateValues, and overrides canEqual.
trait Equatable extends scala.util.Hashable {
// The type `other` must have for equals to succeed; bound by each implementing class.
protected type EquateThis <: Equatable
// True when the equals found on the runtime class's superclass is the one declared by Any
// (java.lang.Object), i.e. there is no user-defined parent equality to consult.
// Uses reflection, recomputed on every equals call.
private def equalsFromAny : Boolean = {
this.getClass.getSuperclass.getMethod("equals", classOf[Any])
.getDeclaringClass == classOf[Any]
}

// The values compared pairwise, in order, by equals.
protected def equateValues : Seq[Any]

// Permissive default: returns true for any argument unless a class overrides it.
def canEqual (that : Any) : Boolean = true

// `abstract override` allows calling super.equals, which resolves to the equals that follows
// this trait in the mixing class's linearization.
// NOTE(review): EquateThis is an abstract type, so after erasure the pattern below only tests
// for Equatable; @unchecked silences that warning rather than making the test precise.
abstract override def equals(other : Any) : Boolean = (other : @unchecked) match {
case that : EquateThis =>
(
that.canEqual(this)
&& (equalsFromAny || super.equals(that))
&& hashCode == that.hashCode // Can speed up or slow down
&& equateValues.zip(that.equateValues).foldLeft(true) {
(equals, tuple) => equals && tuple._1 == tuple._2
} && equateValues.size == that.equateValues.size
)
case _ => false
}
}

/** 2-D point using the first `Equatable` design. */
class Point(val x: Int, val y: Int) extends Equatable {
  // --- Equatable ---
  override type EquateThis = Point
  override def canEqual(other: Any): Boolean = other.isInstanceOf[Point]
  override def equateValues = List(x, y)

  // --- Hashable ---
  override def hashValues = List(x, y)

  override def toString: String = "(%d, %d)".format(x, y)
}

// 3-D point for the first Equatable design. As the text notes, this does not compile: the
// `override type EquateThis = Point3D` line is rejected because Point already fixed
// EquateThis to Point.
// NOTE(review): presumably because a concrete type alias can only be overridden by an
// equivalent type, so it cannot be re-bound to Point3D here.
// (`with Equatable` is redundant: Point already mixes it in.)
class Point3D(x : Int,y : Int, val z : Int) extends Point(x,y) with Equatable {
override def toString = "(%d, %d, %d)" format (x, y, z)

// Hashable definitions
override def hashValues = List(x, y, z)

// Equatable definitions: only z is listed; x and y are meant to be covered by the
// super.equals check inside the trait's equals.
override type EquateThis = Point3D
override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point3D]
override def equateValues = List(z)
}

That looks more like it, but it still has a few problems. For one thing, it still uses reflection to test for super's equals method. Also, you can't parametrize that as it is. It won't pick up Equatable's own equals, though, as traits compile down to part of the class being defined, not to an ancestor of it.

Next, it has a default for canEqual, and a dangerous one at that. If the programmer forgets to override it, it will lead to trouble.

But, most importantly, it doesn't work. The class Point3D can't override Point's definition for type EquateThis. I don't understand precisely why this is the case, and I'd be glad if anyone stepped in to explain this.

Anyway, let's fix these problems:

// Second attempt. Each class now states whether its parent's equality must also hold
// (testSuperEquals), defines canEqual itself, and implements equals by calling equalsTo
// with its own type as the type argument.
trait Equatable extends scala.util.Hashable {
// When true, equalsTo additionally requires super.equals(that).
protected def testSuperEquals : Boolean
// The values compared pairwise, in order, by equalsTo.
protected def equateValues : Seq[Any]
// Abstract, so the first class that mixes Equatable in has to define it.
def canEqual (that : Any) : Boolean

// Shared body for the implementing classes' equals methods.
// NOTE(review): `This` is erased at runtime, so `case that : This` only tests for Equatable;
// @unchecked hides the warning rather than making the test precise.
// NOTE(review): inside a trait, super.equals resolves to the equals that follows Equatable in
// the mixing class's linearization (presumably Hashable's or AnyRef's for these Point
// classes), not to Point.equals -- likely why equal Point3D instances compare unequal
// when testSuperEquals is true.
protected def equalsTo[This <: Equatable](other : Any) : Boolean = (other : @unchecked) match {
case that : This =>
(
that.canEqual(this)
&& ((!testSuperEquals) || super.equals(that))
&& hashCode == that.hashCode // Can speed up or slow down
&& equateValues.zip(that.equateValues).foldLeft(true) {
(equals, tuple) => equals && tuple._1 == tuple._2
} && equateValues.size == that.equateValues.size
)
case _ => false
}
}

/** 2-D point built on the second `Equatable` design. */
class Point(val x: Int, val y: Int) extends Equatable {
  /** Formats the point as `(x, y)`. */
  override def toString: String = "(%d, %d)".format(x, y)

  // --- Equatable ---
  /** Delegates to the trait, requiring the other side to be a `Point`. */
  override def equals(other: Any): Boolean = equalsTo[Point](other)
  override def canEqual(other: Any): Boolean = other.isInstanceOf[Point]
  override def equateValues = List(x, y)
  /** Point is the first class to mix Equatable in, so no parent equality is checked. */
  override def testSuperEquals: Boolean = false

  // --- Hashable ---
  override def hashValues = List(x, y)
}

/** 3-D point built on the second `Equatable` design; it asks the trait to also check super's equality. */
class Point3D(x: Int, y: Int, val z: Int) extends Point(x, y) with Equatable {
  override def toString: String = "(%d, %d, %d)".format(x, y, z)

  // --- Equatable ---
  override def testSuperEquals: Boolean = true
  override def equateValues = List(z)
  override def canEqual(other: Any): Boolean = other.isInstanceOf[Point3D]
  override def equals(other: Any): Boolean = equalsTo[Point3D](other)

  // --- Hashable ---
  override def hashValues = List(x, y, z)
}
This finally gets us where we wanted to be. Or as close as I could get, at least. :-)

The definition of canEqual is made abstract. That forces the first class to mix Equatable in to define it, even if to a default of "true". Next, instead of trying to figure out if equality must be called on the super, we simply require the class to tell us.

Finally, the equals method. We can't (or I couldn't) get the trait to define it, but I got pretty close. In the end, the class still has to define an equals method, but that method is reduced to a simple call to one defined in the trait, passing the object being compared against and the expected class. The expected class is passed as an explicit type parameter.

Neat. I didn't think I'd be able to get this much! Now, who's going to port it to Java?

Update:
Ok, I spoke too soon. This method fails here:

scala> val x = new Point(1, 2); val x2 = new Point(1, 2)
x: Point = (1, 2)
x2: Point = (1, 2)

scala> x == x2 // super.equals does not get called, so we do not perform reference equality
res7: Boolean = true

scala> val y = new Point(2, 1)
y: Point = (2, 1)

scala> x == y // expected false
res8: Boolean = false

scala> val z = new Point3D(1, 2, 0)
z: Point3D = (1, 2, 0)

scala> x == z // false in the canEqual call
res9: Boolean = false

scala> z == x // false in the type check
res10: Boolean = false

scala> val z2 = new Point3D(2, 1, 0)
z2: Point3D = (2, 1, 0)

scala> z == z2 // false in the super.equals call
res11: Boolean = false

scala> val z3 = new Point3D(1, 2, 0)
z3: Point3D = (1, 2, 0)

scala> z == z3 // false in the super.super.equals call - it should have been true
res12: Boolean = false

Anyone up for fixing it?

Update 2:
Ok, I fixed it. I resorted to delegating this task to the calling class. That means I do away with testSuperEquals, but require a second parameter to equalsTo. I make it a by-name parameter, so that it doesn't get evaluated needlessly.

/**
 * Equality mixin: classes list their significant values in `equateValues`, define
 * `canEqual`, and implement `equals` as a call to `equalsTo`, passing their parent's
 * verdict explicitly.
 */
trait Equatable extends scala.util.Hashable {
  /** Values compared pairwise, in order; sequences of different length are never equal. */
  protected def equateValues : Seq[Any]

  /** Whether `that` accepts being compared with this object; each class narrows it to its own type. */
  def canEqual (that : Any) : Boolean

  /**
   * Implementation of `equals` for classes mixing in this trait.
   *
   * @tparam This       the class `other` must be an instance of (normally the caller's class)
   * @param other       the object being compared
   * @param superEquals the parent class's verdict, e.g. `super.equals(other)`; by-name, so it is
   *                    only evaluated once the type and `canEqual` checks have passed
   * @param tag         runtime class of `This`; with it, `case that : This` is really checked
   *                    instead of being erased to a test for `Equatable`
   * @return true when `other` is a `This`, accepts the comparison, satisfies `superEquals`,
   *         and has the same hash code and `equateValues`
   */
  protected def equalsTo[This <: Equatable](other : Any, superEquals : => Boolean)(implicit tag : scala.reflect.ClassTag[This]) : Boolean = other match {
    case that : This =>
      (
        that.canEqual(this)
        && superEquals
        && hashCode == that.hashCode // Can speed up or slow down
        && equateValues.corresponds(that.equateValues)(_ == _)
      )
    case _ => false
  }
}

/** 2-D point for the by-name `superEquals` design; it is the first Equatable class, so the parent verdict is `true`. */
class Point(val x: Int, val y: Int) extends Equatable {
  override def toString: String = "(%d, %d)".format(x, y)

  override def hashValues = List(x, y) // Hashable

  // --- Equatable ---
  override def canEqual(other: Any): Boolean = other.isInstanceOf[Point]
  override def equateValues = List(x, y)
  override def equals(other: Any): Boolean = equalsTo[Point](other, superEquals = true)
}

/** 3-D point for the by-name `superEquals` design; Point's equals supplies the verdict on the parent's part. */
class Point3D(x: Int, y: Int, val z: Int) extends Point(x, y) with Equatable {
  override def toString: String = "(%d, %d, %d)".format(x, y, z)

  override def hashValues = List(x, y, z) // Hashable

  // --- Equatable ---
  override def equateValues = List(z)
  override def canEqual(other: Any): Boolean = other.isInstanceOf[Point3D]
  // super.equals is passed by name, so equalsTo decides whether it runs at all.
  override def equals(other: Any): Boolean = equalsTo[Point3D](other, superEquals = super.equals(other))
}

And the testing:

scala> val x = new Point(1, 2); val x2 = new Point(1, 2)
x: Point = (1, 2)
x2: Point = (1, 2)

scala> x == x2 // super.equals does not get called, so we do not perform reference equality
res105: Boolean = true

scala> val y = new Point(2, 1)
y: Point = (2, 1)

scala> x == y // expected false
res106: Boolean = false

scala> val z = new Point3D(1, 2, 0)
z: Point3D = (1, 2, 0)

scala> x == z // false in the canEqual call
res107: Boolean = false

scala> z == x // false in the type check
res108: Boolean = false

scala> val z2 = new Point3D(2, 1, 0)
z2: Point3D = (2, 1, 0)

scala> z == z2 // false in the super.equals call
res109: Boolean = true

scala> val z3 = new Point3D(1, 2, 0)
z3: Point3D = (1, 2, 0)

scala> z == z3 // expected true
res110: Boolean = true

Update 3:
This is still seriously broken, as the very test above indicates (z == z2 returned true). What is happening is that when super.equals gets called, it uses this.equateValues instead of the super's version of it. It might be easy to fix for "this", but not for "that". At this point, I'm giving up on super.equals. Let's assume the equality for all classes is defined by the equalsTo method, and require a call to super.equateValues in every subclass (that finds it necessary). Here it is:

/**
 * Equality mixin (final design). Every class in a hierarchy implements `equals` as
 * `equalsTo[OwnClass](other)`, overrides `canEqual`, and lists in `equateValues` all the
 * values that define its equality, typically prepending its own fields to
 * `super.equateValues`.
 */
trait Equatable extends scala.util.Hashable {
  /** Values compared pairwise, in order; sequences of different length are never equal. */
  protected def equateValues : Seq[Any]

  /** Whether `that` accepts being compared with this object; each class narrows it to its own type. */
  def canEqual (that : Any) : Boolean

  /**
   * Implementation of `equals` for classes mixing in this trait.
   *
   * @tparam This the class `other` must be an instance of (normally the caller's class)
   * @param other the object being compared
   * @param tag   runtime class of `This`; with it, `case that : This` is really checked
   *              instead of being erased to a test for `Equatable`
   * @return true when `other` is a `This`, accepts the comparison via `canEqual`, and has
   *         the same hash code and the same `equateValues`
   */
  protected def equalsTo[This <: Equatable](other : Any)(implicit tag : scala.reflect.ClassTag[This]) : Boolean = other match {
    case that : This =>
      (
        that.canEqual(this)
        && hashCode == that.hashCode // Can speed up or slow down
        && equateValues.corresponds(that.equateValues)(_ == _)
      )
    case _ => false
  }
}

/** 2-D point for the final `Equatable` design. */
class Point(val x: Int, val y: Int) extends Equatable {
  override def toString: String = "(%d, %d)".format(x, y)

  override def hashValues = List(x, y) // Hashable

  // --- Equatable ---
  /** Everything that defines Point equality; subclasses extend this list. */
  override def equateValues = List(x, y)
  override def canEqual(other: Any): Boolean = other.isInstanceOf[Point]
  override def equals(other: Any): Boolean = equalsTo[Point](other)
}

/** 3-D point for the final `Equatable` design. */
class Point3D(x: Int, y: Int, val z: Int) extends Point(x, y) with Equatable {
  override def toString: String = "(%d, %d, %d)".format(x, y, z)

  override def hashValues = List(x, y, z) // Hashable

  // --- Equatable ---
  override def canEqual(other: Any): Boolean = other.isInstanceOf[Point3D]
  override def equals(other: Any): Boolean = equalsTo[Point3D](other)
  /** Prepends z to Point's values, so x and y are still compared. */
  override def equateValues = z :: super.equateValues
}

And test:

scala> val x = new Point(1, 2); val x2 = new Point(1, 2)
x: Point = (1, 2)
x2: Point = (1, 2)

scala> x == x2 // super.equals does not get called, so we do not perform reference equality
res138: Boolean = true

scala> val y = new Point(2, 1)
y: Point = (2, 1)

scala> x == y // expected false
res139: Boolean = false

scala> val z = new Point3D(1, 2, 0)
z: Point3D = (1, 2, 0)

scala> x == z // false in the canEqual call
res140: Boolean = false

scala> z == x // false in the type check
res141: Boolean = false

scala> val z2 = new Point3D(2, 1, 0)
z2: Point3D = (2, 1, 0)

scala> z == z2 // expected false
res142: Boolean = false

scala> val z3 = new Point3D(1, 2, 0)
z3: Point3D = (1, 2, 0)

scala> z == z3 // expected true
res143: Boolean = true