Saturday, May 14, 2011

A Cute Algorithm

Recently I read about an algorithm challenge: given two sorted arrays, find the k-th smallest element of their merge.

Well, if you do merge them, you can simply take the k-th element of the result, and the merge can be done in O(n + m), where n and m are the sizes of the two arrays.
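For reference, here is a minimal sketch of that merge-then-index baseline. The object name `KMinMerge` is only for illustration, and k is taken as 1-based (k = 1 is the smallest), which is the convention the code later in this post uses.

```scala
object KMinMerge {
  // Baseline: merge both sorted arrays in O(n + m), then index.
  // k is 1-based: k = 1 returns the smallest element overall.
  def kMin(a1: Array[Int], a2: Array[Int], k: Int): Int = {
    require(k >= 1 && k <= a1.length + a2.length, "need 1 <= k <= total size")
    val merged = new Array[Int](a1.length + a2.length)
    var i = 0; var j = 0; var o = 0
    while (i < a1.length || j < a2.length) {
      // Take from a1 when a2 is exhausted or a1's head is not larger than a2's.
      if (j >= a2.length || (i < a1.length && a1(i) <= a2(j))) { merged(o) = a1(i); i += 1 }
      else { merged(o) = a2(j); j += 1 }
      o += 1
    }
    merged(k - 1)
  }
}
```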

The solution given there is O(k) and pretty simple: keep an index into each array, and advance one or the other until you reach k. It can be done in O(log k), though, and, fortunately for me, my first idea for solving it in O(k) is more easily adapted to that.
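The challenge's own code isn't reproduced here, so the following is only a sketch of how that two-index approach typically looks: each step consumes whichever head is smaller, so after k steps the last element consumed is the k-th smallest. `KMinTwoPointer` is a hypothetical name.

```scala
object KMinTwoPointer {
  // O(k): advance the index pointing at the smaller head, k times.
  def kMin(a1: Array[Int], a2: Array[Int], k: Int): Int = {
    require(k >= 1 && k <= a1.length + a2.length, "need 1 <= k <= total size")
    var i = 0
    var j = 0
    var last = 0
    for (_ <- 1 to k) {
      // Consume from a1 if a2 is exhausted or a1's head is not larger.
      if (j >= a2.length || (i < a1.length && a1(i) <= a2(j))) { last = a1(i); i += 1 }
      else { last = a2(j); j += 1 }
    }
    last // the k-th element consumed
  }
}
```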

My own O(k) version goes like this: point one index at the k-th element of the first array, and another at the first element of the second array. If the element in the first array is smaller than the element in the second array, return it. Otherwise, as long as the element in the first array is bigger than the element in the second array, decrease the first index and increase the second. When that stops, you have the elements from each array that make up the k smallest, and the k-th smallest is the larger of the last element taken from each array.

In code, something like this:

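    def kMin(a1: Array[Int], a2: Array[Int], k: Int): Int = {
      // k2 = how many of the k smallest come from a2; the other k - k2 come from a1.
      // Keep taking from a2 while its next element is smaller than a1's last taken one.
      def recurse(k2: Int): Int =
        if (k2 < k && a1(k - k2 - 1) > a2(k2)) recurse(k2 + 1)
        else k2

      if (a1(k - 1) < a2(0)) a1(k - 1)
      else {
        val k2 = recurse(1)
        if (k2 == k) a2(k - 1)
        else a1(k - k2 - 1) max a2(k2 - 1)
      }
    }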

Now, that code isn't particularly good, as there are some conditions that can break it. For instance, if the first array has fewer than k elements, or the second array runs out before the k smallest have been found, you'll get an ArrayIndexOutOfBoundsException. However, it gives the basis for explaining the O(log k) algorithm.
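One hedged way to remove those failure cases while keeping the same linear scan is to clamp how many elements each array can contribute before scanning. In this sketch (hypothetical name `KMinLinearSafe`), `lo` and `hi` are the same bounds the binary-search `kMin` further down starts from, and the loop walks `k2`, the number of elements taken from the second array, upward from `lo`.

```scala
object KMinLinearSafe {
  // Bounds-checked variant of the linear scan; k is 1-based.
  // k2 = how many of the k smallest come from a2; a1 supplies the other k - k2.
  def kMin(a1: Array[Int], a2: Array[Int], k: Int): Int = {
    require(k >= 1 && k <= a1.length + a2.length, "need 1 <= k <= total size")
    val lo = math.max(0, k - a1.length) // a1 cannot supply more than a1.length
    val hi = math.min(k, a2.length)     // a2 cannot supply more than a2.length
    var k2 = lo
    // Trade a1's largest taken element for a2's next one while that one is smaller.
    while (k2 < hi && a2(k2) < a1(k - k2 - 1)) k2 += 1
    // The k-th smallest is the larger of the last elements taken from each array.
    if (k2 == 0) a1(k - 1)
    else if (k2 == k) a2(k - 1)
    else math.max(a1(k - k2 - 1), a2(k2 - 1))
  }
}
```

Replacing the `while` loop with a binary search over `k2` in `[lo, hi]` is exactly the step the next paragraphs describe.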

Here we search linearly for the k smallest elements of both arrays together, but we know the arrays are sorted. So, instead of going one by one, we can use binary search and turn it into O(log k)!

The concept is simple. We are looking for the k smallest elements of the two arrays together, so we know beforehand that we never need to look at more than k elements of either array.

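We'll search one array for the biggest element that is smaller than or equal to the k-th minimum, with the upper bound being the k-th element of that array (or its last element, if it is shorter), and the lower bound being 0, meaning the k smallest elements are all in the other array (unless the other array has fewer than k elements, in which case this one must supply at least the difference).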

To check whether the x-th element of one array is among the k smallest, we see if it is smaller than the (k - x + 1)-th element of the other array (index k - x, counting from zero). If it is, then the x-th element is among the k smallest. Intuitively, if you take x elements from one array and (k - x) from the other, you get exactly k elements. No element after the x-th in its own array can be smaller than it, since the array is sorted. And since the element at index k - x in the other array is bigger than it, no later element in that array can be smaller either.
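That test can be written as a small predicate. This is an illustrative sketch with a hypothetical name; it uses a strict `<`, which matches the `kMin` code later in this post, where `a1(k - x) <= a2(x - 1)` is the case that rejects the x-th element of `a2`.

```scala
object Membership {
  // `own` and `other` are the two sorted arrays; x is 1-based (1 <= x <= k).
  // own(x - 1) is among the k smallest if it is smaller than other(k - x),
  // the first element left over after taking k - x elements from `other`.
  def inKSmallest(own: Array[Int], other: Array[Int], k: Int, x: Int): Boolean = {
    require(x >= 1 && x <= k && x <= own.length)
    // If `other` has at most k - x elements, x from `own` plus all of `other`
    // is still no more than k elements, so own(x - 1) is trivially among them.
    k - x >= other.length || own(x - 1) < other(k - x)
  }
}
```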

So, whenever we find an element that belongs among the k smallest, we move the lower bound up to it. If we find an element that does not belong among the k smallest, we move the upper bound below it.

Once we know how many elements of one array belong among the k smallest, we also know how many we must take from the other array. The larger of the last elements taken from each array is the k-th smallest.

Here's the code, which is much more concise than the explanation above. It finds the k-th smallest element, with k = 1 being the smallest of all. It assumes there are at least k elements in the two arrays combined, though k may be bigger than the size of either array on its own. In fact, either array may be empty (but not both). You can find this code in my GitHub repository, along with an sbt project and a ScalaCheck test case.

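    def kMin(a1: Array[Int], a2: Array[Int], k: Int): Int = {
      // k2 of the k smallest come from a2, the other k - k2 from a1.
      def select(k2: Int) = k2 match {
        case `k` => a2(k - 1) // all k come from a2
        case 0   => a1(k - 1) // all k come from a1
        case _   => a1(k - k2 - 1) max a2(k2 - 1)
      }
      // Binary search for k2 in [bottom, top].
      def recurse(top: Int, bottom: Int): Int = {
        if (top == bottom) select(top)
        else {
          val x = (bottom + top) / 2 max bottom + 1
          // a2's x-th element is not among the k smallest: lower the upper bound.
          if (a1(k - x) <= a2(x - 1)) recurse(x - 1, bottom)
          else recurse(top, x)
        }
      }
      recurse(k min a2.size, 0 max k - a1.size)
    }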
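The ScalaCheck test from the repository isn't reproduced here. As a stand-in that needs only the standard library, a sketch like the following compares any `kMin` against "concatenate, sort, index" on small random sorted arrays, which often come out empty or full of duplicates. `KMinCheck` and its parameters are hypothetical.

```scala
import scala.util.Random

object KMinCheck {
  // Returns true if `kMin` agrees with a sort-based reference on every trial.
  def check(kMin: (Array[Int], Array[Int], Int) => Int,
            trials: Int = 1000, seed: Long = 42L): Boolean = {
    val rnd = new Random(seed)
    // Short arrays with values in 0..9, so empty arrays and duplicates are common.
    def sortedArray(): Array[Int] = Array.fill(rnd.nextInt(6))(rnd.nextInt(10)).sorted
    (1 to trials).forall { _ =>
      val a1 = sortedArray()
      val a2 = sortedArray()
      val total = a1.length + a2.length
      total == 0 || { // skip the case where both arrays are empty
        val k = 1 + rnd.nextInt(total) // 1-based k in 1..total
        kMin(a1, a2, k) == (a1 ++ a2).sorted.apply(k - 1)
      }
    }
  }
}
```

Calling `KMinCheck.check(kMin)` with the `kMin` above is expected to return true; the small sizes are deliberate, since empty arrays and ties are where index arithmetic like `k - x` tends to break.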


  1. Perhaps this is stated in the original problem, but your solution assumes that the 'k'-th smallest is based on "position" within the array instead of being the 'k'-th smallest number overall.

    Limiting the example to one array, there are two interpretations of 'k-th smallest':


    Given your solution, you would choose 1 as the k-th smallest integer, whereas I would argue that the k-th smallest is 4.

  2. @corruptmemory, if you allow duplicate elements, the elementary action "get the k-th element of the sorted array" requires Ω(min(k, n - k)) time. To see why, imagine you have checked elements 1..k and they are all distinct. In that case, ar[k] is the k-th element if ar[0] ≠ ar[1], but ar[k+1] is the k-th element if ar[0] = ar[1].

    You can keep the same algorithm, though, and use the O(k) method described for finding the k-th element instead of just applying ar(k).

  3. Well, the point I was getting at is that the solution provided only works if the two sublists to be merged do not contain duplicate elements before the k-th element (and of course you may not know that before looking for the k-th element). So a general, and I would argue correct, solution needs to handle duplicate elements. I do not see an O(k) solution for the general case; take, for example:

    val k = 3
    val m = Array(1,1,1,1,1,1,1,1,1,2,3)
    val n = Array(1,1,1,1,1,1,1,1,1)

    This particular problem will require m.length + n.length tests, as far as I can tell.

  4. Correction: m.length tests, sorry.

  5. @corruptmemory

    There are two issues here: whether my interpretation of k-smallest number is correct or not, and how to adapt the algorithm to make it work according to your interpretation.

    The k-smallest number is formally known as the k-th order statistic. It is usually defined over distinct numbers, but, for a list of n numbers, the minimum is the 1st order statistic, the maximum is the n-th order statistic, and, if n = 2m, the median is the (m+1)-th order statistic. Given that, I feel my definition is correct -- or, at least, more common.

    Regardless, there's the question of how to adapt the algorithm to meet your own definition. First, I can guarantee that you won't get better than O(min(k, log(N + M))), since you'll need to "confirm" that at least k numbers are distinct.

    However, you _can_ get that performance, since one can count the distinct elements of a sorted list in O(d), where d is the number of distinct elements. In fact, the very same blog that inspired this post has that algorithm as well:

    So, one possible solution is the following. First, adapt that algorithm to just store the numbers it finds in an array. It only needs to get the first k numbers; once it has them, it can stop.

    Now apply this algorithm to each of the two sorted arrays. Finally, run my algorithm over the output of the previous step, which is guaranteed to only contain distinct numbers.

    Seems to me that the performance should be O(k), or, perhaps, O(k log k).

  6. Interesting.


    A couple of points:
    1. The algorithm on that page doesn't work. It counts incorrectly:

    "if you have, A[] = {1,1,1,2,2,2,3,3,4}
    you should print, 1=>3, 2=>3, 3=>2, 4=>1"

    The program produces:
    1=>3, 2=>1, 3=>1, 4=>1

    The author even states:
    "I came up with a approach but cannot make that into a working code"

    Now, the O(n) refers to the number of comparisons, not the number of recursive calls of the function. Given that, I augmented the code to count the number of comparisons actually executed:

    Input: 1,1,1,2,2,2,3,3,4
    Comparisons: 9

    Input: 1,2,2,3,3,3,4,4,4,5,5,5,5,6,6,6,6,7,7,8,9
    Comparisons: 34

    Input: 1,1,2,2,2,3,4,4,4,5,5,6,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,8,9
    Comparisons: 61

    So the performance is worse than the naive approach of just walking the list. This is because the algorithm is based on guessing, and the cost of a bad guess is very high.

    There may be an algorithm that can do this in better than O(n), but I haven't seen it.

    Frankly, I think the answer the interviewer is looking for is a parallel one, where this problem can be solved in O(n/m) time, m being the number of parallel partitions.

    RE: definition of k-smallest

    That's fine. That particular definition is not one I've encountered; on the other hand, the one I proposed is one I encounter quite frequently. Different strokes for different folks, I guess.
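For the duplicate-aware reading argued in the comments (the k-th smallest *distinct* value of the two arrays combined), a straightforward sketch is a merge walk that counts each value once, whether it repeats within one array or appears in both. This is an illustration only, not code from the post or the comments, and `KthDistinctMerged` is a hypothetical name. Its cost grows with how far into the arrays the k-th distinct value sits, not with k alone, which is the concern raised about inputs full of repeated 1s.

```scala
object KthDistinctMerged {
  // Returns the k-th smallest distinct value (k is 1-based) of two sorted arrays,
  // or None if they hold fewer than k distinct values between them.
  def kthDistinct(m: Array[Int], n: Array[Int], k: Int): Option[Int] = {
    require(k >= 1)
    var i = 0; var j = 0; var count = 0
    var prev: Option[Int] = None
    while (i < m.length || j < n.length) {
      // Take the smaller head, as in a merge, so values arrive in sorted order.
      val v =
        if (j >= n.length || (i < m.length && m(i) <= n(j))) { val x = m(i); i += 1; x }
        else { val x = n(j); j += 1; x }
      // Values arrive in sorted order, so a repeat always equals the previous value.
      if (!prev.contains(v)) {
        count += 1
        if (count == k) return Some(v)
        prev = Some(v)
      }
    }
    None
  }
}
```

With `m`, `n`, and `k = 3` from the example in the comments, it returns `Some(3)`, but only after walking past every 1 in both arrays.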