Monday, January 24, 2011

Testing with Scala for Fun and Profit

Sometimes I like Scala, and sometimes I really like Scala.

So, I was writing some algorithms to compute the median of a sequence of values, just for the fun of it, adapting them from a pseudo-code description of an algorithm. The problem with pseudo-code is that it is pseudo-precise, meaning my algorithm was pseudo-correct, so I wanted to test it to ensure it really worked.

As it happens, the median has a trivial implementation (sort the values and take the middle one) that I could test against. So, one import and one one-liner later, I had a quick way to test it from the REPL, where I could then experiment with any failing test case to understand what went wrong:

import org.scalacheck.Prop._
forAll((lst: List[Double]) => lst.nonEmpty ==> (myAlgorithm(lst) == lst.sorted.apply((lst.size - 1) / 2))).check

That uses the ScalaCheck library, which in my opinion is the best way to test algorithms. With forAll, I'm saying that for all inputs of that function (here, a List[Double]), the condition must hold true. Specifically, if the input is not empty, then the result of my algorithm must be equal to the result of the trivial median implementation. That results in a property (class Prop).

I then tell it to check that property, which it does by automatically generating inputs (with some heuristics to improve the chance of catching boundary cases) and testing the condition for each of them. If 100 tests pass, it finishes by saying:

+ OK, passed 100 tests.

Or, if it fails at some point, it will say something like this:

! Falsified after 0 passed tests.                                             
> ARG_0: List("-1.0") (orig arg: List("-1.0", "8.988465674311579E307"))
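To make the mechanics of forAll, ==> and check a little more concrete, here is a dependency-free sketch of what such a check does: generate random inputs, discard the ones that fail the precondition, and stop at the first counterexample. This is not how ScalaCheck is implemented (its generators are smarter, and it shrinks counterexamples, which is what the "orig arg" line above reflects). The names trivialMedian, randomList and checkProperty are hypothetical:

```scala
import scala.util.Random

// The trivial reference: the lower median, i.e. the element at index (n - 1) / 2 after sorting
def trivialMedian(xs: List[Double]): Double = xs.sorted.apply((xs.size - 1) / 2)

// Hypothetical generator: lists of 0 to 20 doubles in [-100, 100)
val randomList: Random => List[Double] =
  rng => List.fill(rng.nextInt(21))(rng.nextDouble() * 200 - 100)

// Hypothetical checker: inputs failing `pre` are discarded (like `==>`) and do not count.
// Returns Right(number of passed tests) or Left(first counterexample found).
def checkProperty[A](tests: Int, gen: Random => A, seed: Long = 42L)(
    pre: A => Boolean)(prop: A => Boolean): Either[A, Int] = {
  val rng = new Random(seed)
  var passed = 0
  var attempts = 0
  while (passed < tests && attempts < tests * 10) {
    attempts += 1
    val input = gen(rng)
    if (pre(input)) {
      if (!prop(input)) return Left(input)
      passed += 1
    }
  }
  Right(passed)
}

// A correct alternative: sort descending and take index n / 2
val viaDescending = checkProperty(100, randomList)(_.nonEmpty) { xs =>
  xs.sortWith(_ > _).apply(xs.size / 2) == trivialMedian(xs)
}

// A wrong "median": the middle element of the unsorted list
val viaUnsorted = checkProperty(100, randomList)(_.nonEmpty) { xs =>
  xs((xs.size - 1) / 2) == trivialMedian(xs)
}
```

Here viaDescending comes back as Right(100), while viaUnsorted comes back as a Left holding a list on which the unsorted middle element is not the median.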

To be honest, my one-liner was slightly longer: I was using arrays, and Java arrays do not have a useful toString method, so I had to tell ScalaCheck how to print them. Both ScalaTest and Specs support integration with ScalaCheck too, so this can easily be turned into part of a test suite.
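To see why arrays need that extra help: on the JVM, an array's toString is inherited from java.lang.Object and only shows a type code and a hash, so a counterexample printed that way tells you nothing. A small illustration (the value names are just for the example):

```scala
val arr = Array(3.0, 1.0, 2.0)

// Inherited from java.lang.Object: "[D@" followed by a hash code, not the contents
val unhelpful: String = arr.toString

// Two ways to render the contents instead
val scalaStyle: String = arr.mkString("Array(", ", ", ")") // "Array(3.0, 1.0, 2.0)"
val javaStyle: String = java.util.Arrays.toString(arr)     // "[3.0, 1.0, 2.0]"
```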

Now, this helped me achieve correctness, but I was also interested in how fast the code could run. I had three different algorithms (including the trivial implementation), and two of them could be further refined by abstracting over how the pivot is selected (quicksort-style). At that point, I decided to build a small framework that would help me test everything automatically.
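The median algorithms themselves aren't shown here. As a hypothetical illustration of abstracting over the pivot, the following is an immutable quickselect in which the pivot choice is a function parameter. findMedian has the same shape as the findMedian call used further down, but this body is an assumed sketch, not the original code. It assumes the pivot function returns an element of the array and that there are no NaNs; otherwise the recursion is not guaranteed to shrink:

```scala
import scala.util.Random

// Hypothetical quickselect: returns the k-th smallest element (0-based) of arr.
// Assumes choosePivot returns an element of arr and arr contains no NaN.
def select(arr: Array[Double], k: Int)(choosePivot: Array[Double] => Double): Double = {
  val pivot = choosePivot(arr)
  // Three-way partition into new arrays; the input is never modified
  val (smaller, rest) = arr.partition(_ < pivot)
  val (equal, larger) = rest.partition(_ == pivot)
  if (k < smaller.length) select(smaller, k)(choosePivot)
  else if (k < smaller.length + equal.length) pivot
  else select(larger, k - smaller.length - equal.length)(choosePivot)
}

// Lower median, matching lst.sorted.apply((lst.size - 1) / 2)
def findMedian(arr: Array[Double])(choosePivot: Array[Double] => Double): Double =
  select(arr, (arr.length - 1) / 2)(choosePivot)

// Two hypothetical pivot strategies
val midpointPivot: Array[Double] => Double = arr => arr((arr.length - 1) / 2)
val randomPivot: Array[Double] => Double = arr => arr(Random.nextInt(arr.length))
```

Because the pivot is always an element of the current array, each recursive call works on a strictly smaller array, which is what guarantees termination.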

I wanted to run each benchmark three times for each of several input sizes. The basic measurement was done by one of the methods in my toolkit:

import java.lang.System.currentTimeMillis

// Runs body n times and returns the elapsed wall-clock time in milliseconds
def bench[A](n: Int)(body: => A): Long = {
  val start = currentTimeMillis()
  1 to n foreach { _ => body }
  currentTimeMillis() - start
}

This should be familiar to anyone who has ever done microbenchmarking in Scala. With that in hand, another one-liner got the results I wanted:


Which worked well enough for a while, but really didn't scale as I got more algorithm variations and tested them with different settings. So, to get the results for each algorithm, I wrote this:

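import scala.util.Random.nextDouble

// For each array size, run the benchmark three times. Each run times 50000 calls,
// each on a freshly generated random array (so array creation is part of the timing).
def benchmark(algorithm: Array[Double] => Double,
              arraySizes: List[Int]): List[Iterable[Long]] =
    for (size <- arraySizes)
    yield for (_ <- 1 to 3)
        yield bench(50000)(algorithm(Array.fill(size)(nextDouble())))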

Which let me pass in a list of the sizes I wanted to test, and run each benchmark three times to give me a feel for the variation in the results. Next, I made a list of the algorithms I wanted tested:

val algorithms = sortingAlgorithm :: immutableAlgorithms

That's the list I started with, but it grew as I added other algorithms. As for the immutable algorithms, they were all the same method call, just with a different pivot selection function passed as a parameter. As that got a bit verbose, I decided to apply a bit of DRY. First, I made a list of my pivot selection algorithms:

val immutablePivotSelection: List[(String, Array[Double] => Double)] = List(
    "Random Pivot"      -> chooseRandomPivot,
    "Median of Medians" -> medianOfMedians,
    "Midpoint"          -> ((arr: Array[Double]) => arr((arr.size - 1) / 2))
)

Next, I used that list to produce a list of median algorithms using each of these pivot selections:

val immutableAlgorithms = for ((name, pivotSelection) <- immutablePivotSelection)
        yield name -> (findMedian(_: Array[Double])(pivotSelection))
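The findMedian(_: Array[Double])(pivotSelection) expression relies on placeholder syntax: the typed underscore leaves the array open while the second argument list is fixed, so each entry becomes an Array[Double] => Double. A self-contained illustration of the same pattern, with a hypothetical curried function pickSorted standing in for findMedian:

```scala
// Hypothetical curried function: sort, then pick the element at the chosen index
def pickSorted(arr: Array[Double])(index: Array[Double] => Int): Double =
  arr.sorted.apply(index(arr))

val indexStrategies: List[(String, Array[Double] => Int)] = List(
  "Lower middle" -> ((arr: Array[Double]) => (arr.length - 1) / 2),
  "Upper middle" -> ((arr: Array[Double]) => arr.length / 2))

// Fix the second argument list, leave the array open: one named function per strategy
val pickers: List[(String, Array[Double] => Double)] =
  for ((name, index) <- indexStrategies)
    yield name -> (pickSorted(_: Array[Double])(index))
```

Applied to Array(4.0, 1.0, 3.0, 2.0), the "Lower middle" picker returns 2.0 and the "Upper middle" picker returns 3.0.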

With the list of algorithms in hand, it was a simple for comprehension to produce a list of results:

val benchmarkResults: List[String] = for {
    (name, algorithm) <- algorithms
    results <- benchmark(algorithm, arraySizes).transpose
} yield formattingString format (name, formatResults(results))
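To picture the shape of the data: benchmark returns one inner collection per array size, each holding the three runs, and transpose regroups that into one collection per run, so each results value in the for comprehension above is a single run across all sizes. A tiny illustration, using labels instead of timings ("s2r1" means size 2, run 1):

```scala
// Shape produced by benchmark: outer = array sizes, inner = the three runs
val perSize: List[List[String]] = List(
  List("s1r1", "s1r2", "s1r3"),
  List("s2r1", "s2r2", "s2r3"))

// After transposing: outer = runs, inner = one entry per size
val perRun: List[List[String]] = perSize.transpose
// perRun == List(List("s1r1", "s2r1"), List("s1r2", "s2r2"), List("s1r3", "s2r3"))
```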

The transposition let me see each size on a different column, making it easier to compare the algorithms when displayed together. Anyway, once I had that ready, I could also easily use the list of algorithms to test them:

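def test(algorithm: Array[Double] => Double,
         reference: Array[Double] => Double): String = {
    // Java arrays have no useful toString, so label each case by hand
    def prettyPrintArray(arr: Array[Double]) = arr mkString ("Array(", ", ", ")")
    // For non-empty arrays, the algorithm must agree with the reference
    val resultEqualsReference = forAll { (arr: Array[Double]) =>
        arr.nonEmpty ==> (algorithm(arr) == reference(arr)) :| prettyPrintArray(arr)
    }
    // Run the check and render the result tersely (verbosity 0) as a String
    Test.check(Test.Params(), resultEqualsReference)(Pretty.Params(verbosity = 0))
}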

val testResults = for ((name, algorithm) <- algorithms)
    yield formattingString format (name, test(algorithm, sortingAlgorithm._2))

I could even make the test method more general, by parameterizing the type of the algorithm and adding a parameter for a no-argument function that generates the input. That would have been easy to do, but I didn't need it by the time I was finished.
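A dependency-free sketch of that generalization, assuming plain random testing instead of ScalaCheck: there is no shrinking, and the generator is expected to produce only valid inputs, which takes the place of the ==> precondition. The names testAgainst, genArray, sortingMedian and descendingMedian are hypothetical:

```scala
// Hypothetical generalization of test: A is the input type, B the result type,
// and genInput is a no-argument function producing a fresh input on each call.
def testAgainst[A, B](algorithm: A => B, reference: A => B)(
    genInput: () => A, describe: A => String, trials: Int): String = {
  // Lazily generate inputs and stop at the first one where the two disagree
  val counterexample = Iterator.fill(trials)(genInput()).find(a => algorithm(a) != reference(a))
  counterexample match {
    case None        => s"OK, passed $trials tests."
    case Some(input) => s"Falsified: ${describe(input)}"
  }
}

// Example use with arrays, mirroring the median test (non-empty arrays only)
val inputRng = new scala.util.Random(1)
val genArray = () => Array.fill(inputRng.nextInt(10) + 1)(inputRng.nextDouble())
val sortingMedian: Array[Double] => Double = a => a.sorted.apply((a.length - 1) / 2)
val descendingMedian: Array[Double] => Double = a => a.sortWith(_ > _).apply(a.length / 2)

val report = testAgainst(descendingMedian, sortingMedian)(genArray, _.mkString("Array(", ", ", ")"), 100)
```

Putting the generator and printer in a second parameter list lets Scala infer A and B from the two algorithms first, so `_.mkString(...)` type-checks without annotations.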

While I find the final result has a certain elegance (then again, I'm obviously biased), that's not the reason I really like Scala. What I really, really like about it is how I could start very small, with a few easy commands in the REPL, and then use that as the basis for an increasingly flexible framework to run the tests I wanted.

If anyone is interested in seeing the full code, it can be found here.