Saturday, May 21, 2011

Scala 2.9 optimizes for comprehensions way better!

Ok, I completely missed this. For comprehensions in Scala 2.9 are optimized far better with the parameter -optimize than they were before! Take this code:

class OptEx {
    def sum(l: Array[Int]) = {
        var acc = 0
        for (i <- 0 until l.length) acc += l(i)
        acc
    }
}
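For comparison, here is the hand-written while loop that for comprehension stands in for -- roughly the code we'd hope the optimizer approaches (the class name OptExWhile is mine, for illustration):

```scala
// Hand-written equivalent of the for comprehension above: no Range
// allocation, no closure call per element, just a counter and an
// accumulator.
class OptExWhile {
  def sum(l: Array[Int]): Int = {
    var acc = 0
    var i = 0
    while (i < l.length) {
      acc += l(i)
      i += 1
    }
    acc
  }
}
```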
This is the java bytecode generated with Scala 2.8.1 for the method sum:
public int sum(int[]);
   0:   new     #7; //class scala/runtime/IntRef
   3:   dup
   4:   iconst_0
   5:   invokespecial   #12; //Method scala/runtime/IntRef."<init>":(I)V
   8:   astore_2
   9:   new     #14; //class scala/runtime/RichInt
   12:  dup
   13:  iconst_0
   14:  invokespecial   #15; //Method scala/runtime/RichInt."<init>":(I)V
   17:  aload_1
   18:  arraylength
   19:  invokevirtual   #19; //Method scala/runtime/RichInt.until:(I)Lscala/collection/immutable/Range$ByOne;
   22:  new     #21; //class OptEx$$anonfun$sum$1
   25:  dup
   26:  aload_0
   27:  aload_1
   28:  aload_2
   29:  invokespecial   #24; //Method OptEx$$anonfun$sum$1."<init>":(LOptEx;[ILscala/runtime/IntRef;)V
   32:  invokeinterface #30,  2; //InterfaceMethod scala/collection/immutable/Range$ByOne.foreach$mVc$sp:(Lscala/Function1;)V
   37:  aload_2
   38:  getfield        #34; //Field scala/runtime/IntRef.elem:I
   41:  ireturn
And this is what Scala 2.9.0 does:
public int sum(int[]);
   0:   new     #7; //class scala/runtime/IntRef
   3:   dup
   4:   iconst_0
   5:   invokespecial   #12; //Method scala/runtime/IntRef."<init>":(I)V
   8:   astore  6
   10:  new     #14; //class scala/runtime/RichInt
   13:  dup
   14:  iconst_0
   15:  invokespecial   #15; //Method scala/runtime/RichInt."<init>":(I)V
   18:  aload_1
   19:  arraylength
   20:  istore_3
   21:  astore_2
   22:  getstatic       #21; //Field scala/collection/immutable/Range$.MODULE$:Lscala/collection/immutable/Range$;
   25:  aload_2
   26:  invokevirtual   #25; //Method scala/runtime/RichInt.self:()I
   29:  iload_3
   30:  invokevirtual   #29; //Method scala/collection/immutable/Range$.apply:(II)Lscala/collection/immutable/Range;
   33:  dup
   34:  astore  8
   36:  invokevirtual   #34; //Method scala/collection/immutable/Range.length:()I
   39:  iconst_0
   40:  if_icmple       83
   43:  aload   8
   45:  invokevirtual   #37; //Method scala/collection/immutable/Range.last:()I
   48:  istore  4
   50:  aload   8
   52:  invokevirtual   #40; //Method scala/collection/immutable/Range.start:()I
   55:  istore  9
   57:  iload   9
   59:  iload   4
   61:  if_icmpne       89
   64:  iload   9
   66:  istore  5
   68:  aload   6
   70:  aload   6
   72:  getfield        #44; //Field scala/runtime/IntRef.elem:I
   75:  aload_1
   76:  iload   5
   78:  iaload
   79:  iadd
   80:  putfield        #44; //Field scala/runtime/IntRef.elem:I
   83:  aload   6
   85:  getfield        #44; //Field scala/runtime/IntRef.elem:I
   88:  ireturn
   89:  iload   9
   91:  istore  7
   93:  aload   6
   95:  aload   6
   97:  getfield        #44; //Field scala/runtime/IntRef.elem:I
   100: aload_1
   101: iload   7
   103: iaload
   104: iadd
   105: putfield        #44; //Field scala/runtime/IntRef.elem:I
   108: iload   9
   110: aload   8
   112: invokevirtual   #47; //Method scala/collection/immutable/Range.step:()I
   115: iadd
   116: istore  9
   118: goto    57

Time to take your old benchmarks out of the closet, people!

Thursday, May 19, 2011

Regex Again

I have been thinking about regexes lately. I have never felt comfortable with how Scala's regex support works, but I could never settle on what should be done about it. Recently, I have found myself thinking of a regex more and more like this:

class RegexF(pattern: String) extends (String => Option[Seq[String]])

or, perhaps,

class RegexPF(pattern: String) extends PartialFunction[String, Seq[String]]

In fact, RegexPF.lift would (could) yield a RegexF. It then caught my attention that RegexF.apply has the same signature as Regex.unapplySeq, which is the standard way of handling regex in Scala!
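A minimal sketch of the idea (the name RegexPF and its exact shape are my speculation, built on top of the existing scala.util.matching.Regex):

```scala
import scala.util.matching.Regex

// Sketch: a regex as a PartialFunction from a String to its capture
// groups, defined exactly where the pattern matches the whole input.
class RegexPF(pattern: String) extends PartialFunction[String, Seq[String]] {
  private val regex = new Regex(pattern)
  def isDefinedAt(s: String): Boolean = regex.unapplySeq(s).isDefined
  def apply(s: String): Seq[String] =
    regex.unapplySeq(s).getOrElse(throw new MatchError(s))
}

val ymd = new RegexPF("""(\d{4})-(\d{2})-(\d{2})""")
// ymd.lift is a String => Option[Seq[String]] -- the RegexF shape
```

Note that `lift` comes for free from PartialFunction, yielding exactly the total-function form above.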

Might this be what has been bugging me about Scala's regex all along? Should we translate

val YYYYMMDD = """(\d{4})-(\d{2})-(\d{2})""".r
val MMDDYYYY = """(\d{2})/(\d{2})/(\d{4})""".r

def getYear(s: String) = s match {
    case YYYYMMDD(year, _, _) => year
    case MMDDYYYY(_, _, year) => year
}

into
val YYYYMMDD = """(\d{4})-(\d{2})-(\d{2})""".r
val MMDDYYYY = """(\d{2})/(\d{2})/(\d{4})""".r andThen (fields => fields.last +: fields.init)

def getYear(s: String) = ((YYYYMMDD orElse MMDDYYYY) andThen (_.head))(s)

I can certainly see the advantages of pattern matching, but... it doesn't compose very well. It also has some performance issues, which is a big deal for most regex usage. And being a PartialFunction would not prevent a Regex from having extractors as well.

Saturday, May 14, 2011

A Cute Algorithm

Recently I read about an algorithm challenge: given two sorted arrays, find the k-th smallest element of their merge.

Well, if you do merge them, you can just get the element at index k, and the merge can be done in O(n + m), where n and m are the respective sizes of the arrays.

The solution given is O(k) and pretty simple: keep an index into each array, and increase one or the other until you reach k. It can be done in O(log k), though, and, fortunately for me, my first idea on how to solve it in O(k) is more easily adaptable.
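That linear solution, as I understand it, looks something like this (a sketch of my own; it assumes the two arrays jointly hold at least k elements):

```scala
// O(k) two-index walk: advance whichever array has the smaller
// current element until k elements have been consumed; the last
// one consumed is the k-th smallest.
def kMinLinear(a1: Array[Int], a2: Array[Int], k: Int): Int = {
  var i = 0; var j = 0
  var last = 0
  while (i + j < k) {
    if (j >= a2.length || (i < a1.length && a1(i) <= a2(j))) {
      last = a1(i); i += 1
    } else {
      last = a2(j); j += 1
    }
  }
  last
}
```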

My own O(k) version goes like this: point an index at the k-th element of the first array, and another at the first element of the second array. If the element in the first array is smaller than the element in the second array, return it. Otherwise, as long as the element in the first array is bigger than the element in the second array, decrease the first index and increase the second. After doing that, you'll have the elements in each array that make up the k smallest, the k-th being the bigger of the two elements the indices end up pointing at.

In code, something like this:

def kMin(a1: Array[Int], a2: Array[Int], k: Int) = {
    def recurse(k2: Int): Int =
        if (a1(k - k2 - 1) > a2(k2)) recurse(k2 + 1)
        else k2

    if (a1(k - 1) < a2(0)) a1(k - 1)
    else {
      val k2 = recurse(1)
      a1(k - k2 - 1) max a2(k2 - 1)
    }
}

Now, that code isn't particularly good, as there are some conditions that can break it. For instance, if the first array's size is smaller than k, you'll get an array index out of bounds exception. However, it gives the basis for explaining the O(k) algorithm.

Here we search linearly for the k smallest elements of both arrays together, but we know these arrays are sorted. So, instead of going one by one, we can use binary search instead, and turn it into O(log k)!

The concept is simple. We are looking at the k smallest elements of the two arrays together, so we know beforehand that the maximum number of elements we need to take from either array is k.

We'll search one array for the biggest element that is smaller than or equal to the k-th minimum, with the upper bound being the k-th element of that array, and the lower bound being 0 (meaning the k smallest elements are all in the other array).

To check if the x-th number is among the k smallest, we see if that number is smaller than the (k - x)-th element in the other array. If it is, then x is among the k smallest. The intuitive explanation is that, if you take (k - x) elements from one array and x elements from the other, you get exactly k elements. No element y > x in x's array will be smaller than x, since the array is sorted. And since the (k - x)-th element in the other array is also bigger than it, no other element in that array can be smaller either.

So, as long as we find an element that belongs among the k smallest, we move the lower bound up. If we find an element that does not belong among the k smallest, we move the upper bound below it.

Once we find how many elements in one array belong among the k smallest, we also know how many we must take from the other array. Pick the bigger of the biggest taken from each array, and you have the k-th smallest.

Here's the code, which is much more concise than the explanation above. It finds the k-th smallest element, with k=1 being the smallest element of all. It assumes there are at least k elements overall in the arrays, though k may be bigger than the number of elements in one array. In fact, either array may be empty (but not both). You can find this code in my github repository, along with an sbt project and a ScalaCheck test case.

    def kMin(a1: Array[Int], a2: Array[Int], k: Int): Int = {
      def select(k2: Int) = k2 match {
        case `k` => a2(k - 1)
        case 0   => a1(k - 1)
        case _   => a1(k - k2 - 1) max a2(k2 - 1)
      }
      def recurse(top: Int, bottom: Int): Int = {
        if (top == bottom) select(top)
        else {
          val x = (bottom + top) / 2 max bottom + 1
          if (a1(k - x) <= a2(x - 1)) recurse(x - 1, bottom)
          else recurse(top, x)
        }
      }
      recurse(k min a2.size, 0 max k - a1.size)
    }
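To sanity-check implementations like the one above, a brute-force reference is handy (this helper is mine, not part of the original code):

```scala
// Naive O((n + m) log(n + m)) reference: concatenate, sort, index.
// As in the text, k = 1 means the smallest element overall.
def kMinNaive(a1: Array[Int], a2: Array[Int], k: Int): Int =
  (a1 ++ a2).sorted.apply(k - 1)
```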

Thursday, May 12, 2011

Scala 2.9 and Parallel collections

So, Scala 2.9.0 is out. The Typesafe Stack is out as well, bringing together Scala, Akka, and a few other things to get one up and running quickly. Much fun.

On the collection side of things, one of the first questions I saw was: do parallel collections share a common interface with standard collections? The answer is yes, they do, but not one that existed in 2.8.1.

You see, one trouble with parallel collections is that, now that they are available, people will probably be passing them around. If they could be passed to old code -- as was briefly contemplated -- that old code could crash in mysterious ways. In fact, it happens with the REPL itself.

For that reason, ALL of your existing code comes with a guarantee that it will only accept sequential collections. In other words, Iterable, Seq, Set, etc. now all guarantee sequential processing, which means you cannot pass a parallel sequence to a method expecting Seq.

The parallel collections start with Par: ParIterable, ParSeq, ParSet and ParMap. No ParTraversable for now. These are guaranteed to be parallel. They can be found inside scala.collection.parallel, scala.collection.parallel.immutable, etc.

You can also get a parallel collection just by calling the ".par" method on it, and, similarly, the ".seq" method will return a sequential collection.
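For instance (assuming a pre-2.13 Scala, where the parallel collections still live in the standard library; from 2.13 on they moved to a separate module):

```scala
val xs = 1 to 10

val p = xs.par                         // parallel counterpart
val doubledPar = p.map(_ * 2).sum      // runs the map in parallel
val doubledSeq = p.seq.map(_ * 2).sum  // .seq goes back to sequential
```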

Now, if you want your code to not care whether it receives a parallel or sequential collection, you should prefix it with Gen: GenTraversable, GenIterable, GenSeq, etc. These can be either parallel or sequential.

And, now, something fun to try out:

def p[T](coll: collection.GenIterable[T]) = coll foreach println

p(1 to 20)
p((1 to 20).par)