## Tuesday, June 16, 2009

### Equality & Scala

I have just finished reading three different sources on equality in less than a week. They all said pretty much the same thing, with a minor variation here or there. That got me thinking, and I have some thoughts to share. For this discussion, I'll assume you are familiar with how to do equality correctly. My examples follow the model given in this article.

```scala
class Point (val x : Int, val y : Int) {
  override def toString = "(%d, %d)" format (x, y)
}

class Point3D (x : Int, y : Int, val z : Int) extends Point(x,y) {
  override def toString = "(%d, %d, %d)" format (x, y, z)
}
```

Now, a proper equality method in these classes would look like the following:

```scala
class Point (val x : Int, val y : Int) {
  override def toString = "(%d, %d)" format (x, y)
  override def hashCode = 41 * (41 + x) + y
  override def equals(other : Any) : Boolean = other match {
    case that : Point => (
      that.canEqual(this)
      && this.x == that.x
      && this.y == that.y
    )
    case _ => false
  }
  def canEqual(other : Any) : Boolean = other.isInstanceOf[Point]
}

class Point3D (x : Int, y : Int, val z : Int) extends Point(x,y) {
  override def toString = "(%d, %d, %d)" format (x, y, z)
  override def hashCode = 41 * (41 * (41 + x) + y) + z
  override def equals(other : Any) : Boolean = other match {
    case that : Point3D => (
      that.canEqual(this)
      && super.equals(that)
      && this.z == that.z
    )
    case _ => false
  }
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point3D]
}
```
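To see what `canEqual` is buying us, here is a small sketch of the same two classes without it. The names `NaivePoint` and `NaivePoint3D` are made up for this illustration. Without the `canEqual` check, equality stops being symmetric as soon as a subclass adds state:

```scala
object CanEqualDemo {
  // Hypothetical variant of Point with no canEqual check.
  class NaivePoint(val x: Int, val y: Int) {
    override def hashCode: Int = 41 * (41 + x) + y
    override def equals(other: Any): Boolean = other match {
      case that: NaivePoint => this.x == that.x && this.y == that.y
      case _ => false
    }
  }

  // Hypothetical variant of Point3D with no canEqual check.
  class NaivePoint3D(x: Int, y: Int, val z: Int) extends NaivePoint(x, y) {
    override def hashCode: Int = 41 * super.hashCode + z
    override def equals(other: Any): Boolean = other match {
      case that: NaivePoint3D => super.equals(that) && this.z == that.z
      case _ => false
    }
  }

  val p = new NaivePoint(1, 2)
  val p3 = new NaivePoint3D(1, 2, 0)

  // p accepts p3 (it only looks at x and y), but p3 rejects p (not a NaivePoint3D).
  val asymmetric: Boolean = (p == p3) && !(p3 == p)
}
```

With `canEqual` in place, the subclass gets a veto over the comparison, so both directions return false.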

Now, Scala has, starting with version 2.8, a Hashable trait, with which we can simplify things:

```scala
class Point (val x : Int, val y : Int) extends scala.util.Hashable {
  override def toString = "(%d, %d)" format (x, y)
  override def hashValues = List(x, y)
  override def equals(other : Any) : Boolean = other match {
    case that : Point => (
      that.canEqual(this)
      && this.x == that.x
      && this.y == that.y
    )
    case _ => false
  }
  def canEqual(other : Any) : Boolean = other.isInstanceOf[Point]
}

class Point3D (x : Int, y : Int, val z : Int) extends Point(x,y) {
  override def toString = "(%d, %d, %d)" format (x, y, z)
  override def hashValues = List(x, y, z)
  override def equals(other : Any) : Boolean = other match {
    case that : Point3D => (
      that.canEqual(this)
      && super.equals(that)
      && this.z == that.z
    )
    case _ => false
  }
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point3D]
}
```

While it doesn't seem to have gained us anything, it might for larger objects, and, at any rate, it removes the "magic" of a hash code, and lets someone else worry about how to do it.
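To make the idea concrete, here is a minimal sketch of what a trait of this kind could do with `hashValues`. The name `HashFromValues` and the fold with 41 are assumptions for illustration, not the actual implementation of `scala.util.Hashable`. The fold is chosen so that it reproduces the hand-written formulas shown earlier.

```scala
object HashSketch {
  // Hypothetical stand-in for a Hashable-like trait (not the library's code).
  trait HashFromValues {
    def hashValues: Seq[Any]

    // Start from 1, then multiply by 41 and add each element's hash,
    // which unrolls to 41 * (41 + x) + y for List(x, y).
    override def hashCode: Int =
      hashValues.foldLeft(1)((acc, v) => 41 * acc + v.hashCode)
  }

  class Point(val x: Int, val y: Int) extends HashFromValues {
    def hashValues: Seq[Any] = List(x, y)
  }

  class Point3D(x: Int, y: Int, val z: Int) extends Point(x, y) {
    override def hashValues: Seq[Any] = List(x, y, z)
  }
}
```

The class only lists the values that matter, and the combining rule lives in one place.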

Still, there's a lot of stuff in there just to get equality right, and these are pretty simple classes. What we see is a programming pattern, but one so common and so important that, in my opinion, it merits special attention from the language itself.

Barring that, let's see what we can do programmatically about it. I'll start with a helper function, and how it would be used:

```scala
def testEquality(one : AnyRef, other : AnyRef, elementsOne : Seq[Any], elementsOther : Seq[Any]) : Boolean = {
  val classOfOne = one.getClass
  val classOfOther = other.getClass

  def sameClassOrSubclass : Boolean = classOfOne.isAssignableFrom(classOfOther)

  def superEquals : Boolean = try {
    val superEqualsMethod = classOfOne.getSuperclass.getMethod("equals", classOf[Any])
    if (superEqualsMethod.getDeclaringClass != classOf[Any])
      superEqualsMethod.invoke(one, other) match {
        case flag : java.lang.Boolean => flag.booleanValue // Translate boxed boolean into boolean
        case _ => error("Method equals on the parent class of object " + one + " does not return a boolean")
      }
    else true
  } catch {
    case _ => true
  }

  def canEqual : Boolean = try {
    val canEqualMethod = classOfOther.getMethod("canEqual", classOf[Any])
    canEqualMethod.invoke(other, one) match {
      case flag : java.lang.Boolean => flag.booleanValue // Translate boxed boolean into boolean
      case _ => error("Method canEqual on object " + other + " does not return a boolean")
    }
  } catch {
    case _ => true
  }

  def elementsEquals : Boolean = {
    elementsOne.zip(elementsOther).foldLeft(true) {
      (equals, pair) => equals && pair._1 == pair._2
    } && elementsOne.size == elementsOther.size
  }

  // sameClassOrSubclass takes no parameters; it closes over the two classes
  (sameClassOrSubclass
    && canEqual
    && superEquals
    && elementsEquals
  )
}

class Point (val x : Int, val y : Int) extends scala.util.Hashable {
  override def toString = "(%d, %d)" format (x, y)

  // Hashable
  override def hashValues = List(x, y)

  // Equality
  override def equals(other : Any) : Boolean = other match {
    case that : Point => testEquality(this, that, List(x, y), List(that.x, that.y))
    case _ => false
  }
  def canEqual(other : Any) : Boolean = other.isInstanceOf[Point]
}

class Point3D(x : Int,y : Int, val z : Int) extends Point(x,y) {
  override def toString = "(%d, %d, %d)" format (x, y, z)

  // Hashable (z included, as in the previous version)
  override def hashValues = List(x, y, z)

  // Equality
  override def equals(other : Any) : Boolean = other match {
    case that : Point3D => testEquality(this, that, List(x, y, z), List(that.x, that.y, that.z))
    case _ => false
  }
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point3D]
}
```

Now, this method always calls the parent's equals, unless it's Any's equals. You might want to parametrize this. Also, it expects canEqual to be defined if needed, which might lead to bugs. Furthermore, its usage of reflection makes it slower than needed. Finally, the definition of equals is not that much simpler than what we had before.

So, can we do better? Ideally, one could build a trait similar to Hashable, but it turns out that is not that simple. Let's try:

```scala
trait Equatable extends scala.util.Hashable {
  protected type EquateThis <: Equatable

  private def equalsFromAny : Boolean = {
    this.getClass.getSuperclass.getMethod("equals", classOf[Any])
    .getDeclaringClass == classOf[Any]
  }

  protected def equateValues : Seq[Any]
  def canEqual (that : Any) : Boolean = true

  abstract override def equals(other : Any) : Boolean = (other : @unchecked) match {
    case that : EquateThis =>
      (
        that.canEqual(this)
        && (equalsFromAny || super.equals(that))
        && hashCode == that.hashCode // Can speed up or slow down
        && equateValues.zip(that.equateValues).foldLeft(true) {
          (equals, tuple) => equals && tuple._1 == tuple._2
        } && equateValues.size == that.equateValues.size
      )
    case _ => false
  }
}

class Point (val x : Int, val y : Int) extends Equatable {
  override def toString = "(%d, %d)" format (x, y)

  // Hashable definitions
  override def hashValues = List(x, y)

  // Equatable definitions
  override type EquateThis = Point
  override def equateValues = List(x, y)
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point]
}

class Point3D(x : Int,y : Int, val z : Int) extends Point(x,y) with Equatable {
  override def toString = "(%d, %d, %d)" format (x, y, z)

  // Hashable definitions
  override def hashValues = List(x, y, z)

  // Equatable definitions
  override type EquateThis = Point3D
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point3D]
  override def equateValues = List(z)
}
```

That looks more like it, but it still has a few problems. For one thing, it still uses reflection to test for super's equals method. Also, you can't parametrize that as it is. It won't get Equatable's own equals, though, as traits compile down to part of the class being defined, not to an ancestor of it.

Next, it has a default for canEqual, and a dangerous one at that. If the programmer forgets to override it, it will lead to trouble.

But, most importantly, it doesn't work. The class Point3D can't override Point's definition for type EquateThis. I don't understand precisely why this is the case, and I'd be glad if anyone stepped in to explain this.
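A probable explanation, offered as a sketch rather than a definitive answer: once `Point` writes `type EquateThis = Point`, the member is a type alias, and as far as the overriding rules go an alias can only be overridden by an alias of the same type. Only a declaration that is still abstract (one with just bounds) can be narrowed further down the hierarchy. The names `HasSelf`, `Base` and `Leaf` below are hypothetical and exist only to show the variant that does compile.

```scala
object TypeMemberSketch {
  trait HasSelf {
    type Self <: HasSelf // abstract: only an upper bound
    def self: Self
  }

  // Still abstract here, so the bound may be narrowed.
  abstract class Base extends HasSelf {
    type Self <: Base
  }

  // Fixing the alias is the last step; no subclass may change it afterwards.
  class Leaf extends Base {
    type Self = Leaf
    def self: Leaf = this
  }

  // Rejected by the compiler, for the same reason Point3D is rejected above:
  // class Broken extends Leaf { override type Self = Broken }
}
```

The catch is that a concrete class cannot leave a type member abstract, so `Point` itself cannot keep `EquateThis` open for `Point3D` to refine.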

Anyway, let's fix these problems:

```scala
trait Equatable extends scala.util.Hashable {
  protected def testSuperEquals : Boolean
  protected def equateValues : Seq[Any]
  def canEqual (that : Any) : Boolean

  protected def equalsTo[This <: Equatable](other : Any) : Boolean = (other : @unchecked) match {
    case that : This =>
      (
        that.canEqual(this)
        && ((!testSuperEquals) || super.equals(that))
        && hashCode == that.hashCode // Can speed up or slow down
        && equateValues.zip(that.equateValues).foldLeft(true) {
          (equals, tuple) => equals && tuple._1 == tuple._2
        } && equateValues.size == that.equateValues.size
      )
    case _ => false
  }
}

class Point (val x : Int, val y : Int) extends Equatable {
  override def toString = "(%d, %d)" format (x, y)

  // Hashable definitions
  override def hashValues = List(x, y)

  // Equatable definitions
  override def testSuperEquals = false
  override def equateValues = List(x, y)
  override def equals(other : Any) = equalsTo[Point](other)
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point]
}

class Point3D(x : Int,y : Int, val z : Int) extends Point(x,y) with Equatable {
  override def toString = "(%d, %d, %d)" format (x, y, z)

  // Hashable definitions
  override def hashValues = List(x, y, z)

  // Equatable definitions
  override def testSuperEquals = true
  override def equals(other : Any) = equalsTo[Point3D](other)
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point3D]
  override def equateValues = List(z)
}
```

This finally gets us where we wanted. Or as close to it as I could get, at least. :-)

The definition of canEqual is made abstract. That forces the first class that mixes Equatable in to define it, even if only to a default of "true". Next, instead of trying to figure out whether equality must be called on the super, we simply require the class to tell us.

Finally, the equals method. We can't (or I couldn't) get the trait to define it, but I got pretty close. In the end, the class still has to define an equals method, but that method pretty much gets reduced to a simple call to one defined in the trait, passing the object being compared to and the class expected. The class expected gets passed as an explicit type parameter.

Neat. I didn't think I'd be able to get this much! Now, who's going to port it to Java?

Update:
Ok, I spoke too soon. This method fails here:

```scala
scala> val x = new Point(1, 2); val x2 = new Point(1, 2)
x: Point = (1, 2)
x2: Point = (1, 2)

scala> x == x2 // super.equals does not get called, so we do not perform reference equality
res7: Boolean = true

scala> val y = new Point(2, 1)
y: Point = (2, 1)

scala> x == y // expected false
res8: Boolean = false

scala> val z = new Point3D(1, 2, 0)
z: Point3D = (1, 2, 0)

scala> x == z // false in the canEqual call
res9: Boolean = false

scala> z == x // false in the type check
res10: Boolean = false

scala> val z2 = new Point3D(2, 1, 0)
z2: Point3D = (2, 1, 0)

scala> z == z2 // false in the super.equals call
res11: Boolean = false

scala> val z3 = new Point3D(1, 2, 0)
z3: Point3D = (1, 2, 0)

scala> z == z3 // false in the super.super.equals call - it should have been true
res12: Boolean = false
```

Anyone up for fixing it?

Update 2:
Ok, I fixed it. I resorted to delegating this task to the calling class. That means I do away with testSuperEquals, but require a second parameter to equalsTo. I make it by name, so that it doesn't get evaluated needlessly.

```scala
trait Equatable extends scala.util.Hashable {
  protected def equateValues : Seq[Any]
  def canEqual (that : Any) : Boolean

  protected def equalsTo[This <: Equatable](other : Any, superEquals : => Boolean) : Boolean = (other : @unchecked) match {
    case that : This =>
      (
        that.canEqual(this)
        && superEquals
        && hashCode == that.hashCode // Can speed up or slow down
        && equateValues.zip(that.equateValues).foldLeft(true) {
          (equals, tuple) => equals && tuple._1 == tuple._2
        } && equateValues.size == that.equateValues.size
      )
    case _ => false
  }
}

class Point (val x : Int, val y : Int) extends Equatable {
  override def toString = "(%d, %d)" format (x, y)

  // Hashable definitions
  override def hashValues = List(x, y)

  // Equatable definitions
  override def equateValues = List(x, y)
  override def equals(other : Any) = equalsTo[Point](other, true)
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point]
}

class Point3D(x : Int,y : Int, val z : Int) extends Point(x,y) with Equatable {
  override def toString = "(%d, %d, %d)" format (x, y, z)

  // Hashable definitions
  override def hashValues = List(x, y, z)

  // Equatable definitions
  override def equals(other : Any) = equalsTo[Point3D](other, super.equals(other))
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point3D]
  override def equateValues = List(z)
}
```
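The by-name `superEquals` parameter is what keeps the parent comparison from running when an earlier check has already failed. A minimal sketch of that behaviour in isolation, using hypothetical names (`expensiveCheck`, `both`) rather than the classes above:

```scala
object ByNameSketch {
  var evaluations = 0

  // Hypothetical stand-in for a costly super.equals call.
  def expensiveCheck(): Boolean = { evaluations += 1; true }

  // second is passed by name: it is evaluated only if first is true.
  def both(first: Boolean, second: => Boolean): Boolean = first && second
}
```

Calling `ByNameSketch.both(false, ByNameSketch.expensiveCheck())` leaves the counter untouched, whereas a by-value parameter would have run the check before `both` was even entered.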

And the testing:

```scala
scala> val x = new Point(1, 2); val x2 = new Point(1, 2)
x: Point = (1, 2)
x2: Point = (1, 2)

scala> x == x2 // super.equals does not get called, so we do not perform reference equality
res105: Boolean = true

scala> val y = new Point(2, 1)
y: Point = (2, 1)

scala> x == y // expected false
res106: Boolean = false

scala> val z = new Point3D(1, 2, 0)
z: Point3D = (1, 2, 0)

scala> x == z // false in the canEqual call
res107: Boolean = false

scala> z == x // false in the type check
res108: Boolean = false

scala> val z2 = new Point3D(2, 1, 0)
z2: Point3D = (2, 1, 0)

scala> z == z2 // false in the super.equals call
res109: Boolean = true

scala> val z3 = new Point3D(1, 2, 0)
z3: Point3D = (1, 2, 0)

scala> z == z3 // expected true
res110: Boolean = true
```

Update 3:
This is still seriously broken, as the very test above indicates (z == z2 returns true). What is happening is that when super.equals gets called, it uses this.equateValues instead of the super's version of it. It might be easy to fix for "this", but not for "that". At this point, I'm giving up on super.equals. Let's assume the equality for all classes is defined by the equalsTo method, and require a call to super.equateValues at every subclass (that finds it necessary). Here it is:

```scala
trait Equatable extends scala.util.Hashable {
  protected def equateValues : Seq[Any]
  def canEqual (that : Any) : Boolean

  protected def equalsTo[This <: Equatable](other : Any) : Boolean = (other : @unchecked) match {
    case that : This =>
      (
        that.canEqual(this)
        && hashCode == that.hashCode // Can speed up or slow down
        && equateValues.zip(that.equateValues).foldLeft(true) {
          (equals, tuple) => equals && tuple._1 == tuple._2
        } && equateValues.size == that.equateValues.size
      )
    case _ => false
  }
}

class Point (val x : Int, val y : Int) extends Equatable {
  override def toString = "(%d, %d)" format (x, y)

  // Hashable definitions
  override def hashValues = List(x, y)

  // Equatable definitions
  override def equateValues = List(x, y)
  override def equals(other : Any) = equalsTo[Point](other)
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point]
}

class Point3D(x : Int,y : Int, val z : Int) extends Point(x,y) with Equatable {
  override def toString = "(%d, %d, %d)" format (x, y, z)

  // Hashable definitions
  override def hashValues = List(x, y, z)

  // Equatable definitions
  override def equals(other : Any) = equalsTo[Point3D](other)
  override def canEqual(other : Any) : Boolean = other.isInstanceOf[Point3D]
  override def equateValues = z :: super.equateValues
}
```
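The failure that sank the previous version is ordinary dynamic dispatch: code reached through `super` still resolves every other method it calls against the runtime class. A minimal sketch with hypothetical classes `Base` and `Derived`, unrelated to the Point classes:

```scala
object DispatchSketch {
  class Base {
    def values: List[Int] = List(1, 2)
    // Calls values on the runtime object, not necessarily Base's version.
    def describe: String = values.mkString(",")
  }

  class Derived extends Base {
    override def values: List[Int] = List(3)
    // super.describe runs Base's code, which still sees Derived.values.
    def viaSuper: String = super.describe
  }
}
```

This is why `super.equals` in Point3D ended up comparing only `List(z)` on both sides, and why the version above concatenates `super.equateValues` explicitly instead.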

And test:

```scala
scala> val x = new Point(1, 2); val x2 = new Point(1, 2)
x: Point = (1, 2)
x2: Point = (1, 2)

scala> x == x2 // super.equals does not get called, so we do not perform reference equality
res138: Boolean = true

scala> val y = new Point(2, 1)
y: Point = (2, 1)

scala> x == y // expected false
res139: Boolean = false

scala> val z = new Point3D(1, 2, 0)
z: Point3D = (1, 2, 0)

scala> x == z // false in the canEqual call
res140: Boolean = false

scala> z == x // false in the type check
res141: Boolean = false

scala> val z2 = new Point3D(2, 1, 0)
z2: Point3D = (2, 1, 0)

scala> z == z2 // expected false
res142: Boolean = false

scala> val z3 = new Point3D(1, 2, 0)
z3: Point3D = (1, 2, 0)

scala> z == z3 // expected true
res143: Boolean = true
```
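For readers who want to try the final design without `scala.util.Hashable`, here is a self-contained sketch under a few assumptions: `hashCode` is derived from `equateValues` as a stand-in for Hashable, and a `scala.reflect.ClassTag` context bound is added so that the `case that: This` pattern is a real runtime test (as far as I can tell, without it the pattern is erased to a check against the bound `Equatable`, which is what `@unchecked` was silencing). It was written with Scala 2.13 in mind.

```scala
import scala.reflect.ClassTag

object EquatableSketch {
  trait Equatable {
    protected def equateValues: Seq[Any]
    def canEqual(that: Any): Boolean

    // Assumption: hash derived from equateValues, standing in for Hashable.
    override def hashCode: Int = equateValues.hashCode

    // The ClassTag lets the type pattern check the runtime class of other.
    protected def equalsTo[This <: Equatable: ClassTag](other: Any): Boolean =
      other match {
        case that: This =>
          that.canEqual(this) && equateValues == that.equateValues
        case _ => false
      }
  }

  class Point(val x: Int, val y: Int) extends Equatable {
    protected def equateValues: Seq[Any] = List(x, y)
    def canEqual(other: Any): Boolean = other.isInstanceOf[Point]
    override def equals(other: Any): Boolean = equalsTo[Point](other)
  }

  class Point3D(x: Int, y: Int, val z: Int) extends Point(x, y) {
    override protected def equateValues: Seq[Any] = z +: super.equateValues
    override def canEqual(other: Any): Boolean = other.isInstanceOf[Point3D]
    override def equals(other: Any): Boolean = equalsTo[Point3D](other)
  }
}
```

Comparing the sequences with `==` covers both the element-wise check and the size check of the original fold.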