Monday, January 27, 2014

Dynamic Programming and Simple Memoization

Dynamic programming has always been one of my favorite topics, so I thought I'd spend some time visualizing the transformation of an exponential-time algorithm into a polynomial-time algorithm. A lot of dynamic programming algorithms are fairly straightforward translations of recursions via memoization. The real key, it turns out, is getting the recursion right.

The Viterbi Algorithm and Hidden Markov Models

A Markov model is a probabilistic model that can be imagined with the help of a state machine. Each state has some probability of emitting each output. After emitting an output, the state then has some probability of transitioning to each of the other states. Then, you can construct a stream of data by repeatedly querying this model. Of course, you can also go in the opposite direction and attempt to explain some series of data using a Markov model. In that case, you'll need to construct the most probable series of states for a given series of observations. How can you do this?
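To make that concrete, here is a minimal sketch of how such a model might be represented and sampled in Processing-style Java. The names (pi, trans, emit, sample, generate) and the particular probabilities are my own placeholders; nothing later in the post depends on them.

  // A tiny hidden Markov model: 2 states, 2 possible outputs.
  // pi[k]       = probability of starting in state k
  // trans[x][k] = probability of moving from state x to state k
  // emit[k][y]  = probability that state k emits output y
  float[] pi      = { 0.6f, 0.4f };
  float[][] trans = { { 0.7f, 0.3f },
                      { 0.4f, 0.6f } };
  float[][] emit  = { { 0.9f, 0.1f },
                      { 0.2f, 0.8f } };

  // Sample an index from a discrete distribution.
  int sample(float[] dist){
    float r = random(1.0f);
    for(int i = 0; i < dist.length; i++){
      r -= dist[i];
      if(r <= 0) return i;
    }
    return dist.length - 1;
  }

  // Generate a stream of n observations by walking the model:
  // emit from the current state, then transition.
  int[] generate(int n){
    int[] obs = new int[n];
    int state = sample(pi);
    for(int t = 0; t < n; t++){
      obs[t] = sample(emit[state]);
      state = sample(trans[state]);
    }
    return obs;
  }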

The Viterbi algorithm provides a recursion relationship that you can use to figure out the most likely series of transitions. You base this recursion on Vt,k, the probability of the most probable state sequence that accounts for the first t observations and ends at state k. From the Wikipedia entry:

Vt,k = Pr[ yt | k ] * max over all states x of ( ax,k * Vt-1,x )

Where yt is the tth observation, and ax,k is the probability of transitioning from state x to state k. As you can see, this gives a recursion which can be solved directly.
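One detail the formula leaves implicit is the base case: when t = 1 there is nothing to maximize over, and the probability is just the chance of starting in state k and emitting the first observation,

V1,k = Pr[ y1 | k ] * pi(k)

where pi(k) is the initial probability of state k. This is exactly the t == 1 branch in the code below. For a two-state model, the next layer then unrolls to V2,k = Pr[ y2 | k ] * max( a1,k * V1,1 , a2,k * V1,2 ), and so on, one layer per observation.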

Solving the Recursion Directly

Okay, so how does a direct solution look in the code?

  float p_observations_to_state(int t, int k){
    /*
      P_observations_to_state (V_t_k)
      :=  Pr(the most probable sequence for t observations
      which arrive at state k)

      Helper functions (defined elsewhere in the sketch):
        p(y, k)          - emission probability Pr[ y | k ]
        pi(k)            - initial probability of state k
        transition(x, k) - transition probability a_x,k
        obs(i)           - the i-th observation (0-indexed)
        states()         - the number of states K
    */
    visit(t, k);
    
    float ret;
    if(t == 1){
      // base case
      ret =  p(obs(0), k) * pi(k);
    }else{
      // recursion
      float max_p_rec = -1;
      for(int x = 0; x < states(); x++){
         float p_x_to_k = transition(x, k) * p_observations_to_state(t-1, x);
         if(p_x_to_k > max_p_rec)
            max_p_rec = p_x_to_k;
      }
      
      ret = p(obs(t-1), k) * max_p_rec;
    }
    
    return ret;
  }

Great! So what's going on here? First of all, I use a bunch of function calls to actually compute the transition and emission probabilities. That's fine. What else is happening? I make a visit() call so that I can track the runtime of the algorithm. You can also see the exponential nature of the recursion more directly once it is translated into code. Each call actually generates K recursions: one for each state in the state-space of the Markov model.
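To put a number on that blow-up (a back-of-the-envelope count of my own, using nothing beyond the structure of the code above): each call with t > 1 triggers exactly K recursive calls, so a single top-level call produces C(t) = 1 + K * C(t-1) visits, with C(1) = 1. That solves to C(t) = (K^t - 1) / (K - 1). For K = 5 states and t = 10 observations, that's (5^10 - 1) / 4, or roughly 2.4 million visits from one top-level call.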

What does this actually look like in practice? Let's imagine each function call of this algorithm as a pair of integers (t, k). How many times is each pair visited? Are pairs being visited many times? Or is there some fundamental reason for the exponential recursion that cannot be fixed with memoization? For a Markov model with 5 states and a series of 10 observations, I'll plot the number of times each pair (t, k) is visited.
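The visit() call is all the instrumentation this plot needs. The post doesn't show the visualizer's internals, so here is a minimal sketch of how the tallying might look; the array name visit_counts and the constants are my own placeholders:

  int NUM_OBSERVATIONS = 10;
  int NUM_STATES = 5;

  // visit_counts[t][k] counts how many times the pair (t, k) has been entered.
  int[][] visit_counts = new int[NUM_OBSERVATIONS + 1][NUM_STATES];

  void visit(int t, int k){
    visit_counts[t][k]++;
  }

The plot is then just this table, rendered one cell per pair.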

As you can see, the pairs on the left edge of the graph, which comprise the "base cases" of the recursion, are visited very frequently. Pairs to the right are visited less frequently, with the counts shrinking by roughly 5x at each step. Of course, this isn't too surprising, given the implementation. What this does indicate, though, is that there is a lot to be gained from memoization. Why? Well, the pairs to the right are generating a whole slew of calls to the left, so memoization isn't only eliminating a bunch of recomputation in the "body" of the recursive calls, it's actually eliminating further recursive calls. Essentially, it reduces the number of times the question "What is Vt,k?" is even asked. That is the crucial effect, and it's what really changes the runtime of the algorithm.

Memoization and Dynamic Programming

Oftentimes, dynamic programming is presented as a technique involving matrices, operations that update cells, and so on. That's great, but it kind of obscures what I believe to be the real underlying principle, namely memoization. The general idea is to look at the "call graph" of a recursive algorithm and ask: if I cache results, does this thin out the call graph significantly?

First let's see how memoization changes our implementation:

  float p_observations_to_state(int t, int k){
    /*
      P_observations_to_state (V_t_k)
      :=  P(the most probable sequence for t observations
      which arrive at state k)
    */
    
    visit(t, k);

    if(single.memoizing && single.memoized[t][k] != -1){
      return single.memoized[t][k];
    }
    
    float ret;
    if(t == 1){
      // base case
      ret =  p(obs(0), k) * pi(k);
    }else{
      // recursion
      float max_p_rec = -1;
      for(int x = 0; x < states(); x++){
          float p_x_to_k = transition(x, k) * p_observations_to_state(t-1, x);
          if(p_x_to_k > max_p_rec)
              max_p_rec = p_x_to_k;
      }
      
      ret = p(obs(t-1), k) * max_p_rec;
    }
    
    if(single.memoizing)
      single.memoized[t][k] = ret;
    return ret;
  }

Pretty straightforward. It's a little hack-y in that I use -1 to indicate that a value hasn't been cached yet, but I hack, so hack-y is what you get. Notice also that the visit() call occurs before the cache is consulted: this ensures that every function entry is counted, even the ones that only do a cached read. This keeps our runtime numbers honest. So how does this change the number of visits each pair receives?
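The cache itself isn't shown above; here is a minimal sketch of how it might be set up, with single as a little global holder. The field and function names are my own, inferred from how the function uses them:

  class Memo {
    boolean memoizing;
    float[][] memoized;
  }

  Memo single = new Memo();

  void reset_memo(int num_observations, int num_states){
    single.memoizing = true;
    single.memoized = new float[num_observations + 1][num_states];
    // -1 marks "not yet computed"; any real probability is >= 0.
    for(int t = 0; t <= num_observations; t++){
      for(int k = 0; k < num_states; k++){
        single.memoized[t][k] = -1;
      }
    }
  }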

Woah. As you can see, the leftmost pairs are visited vastly less than before. And this, of course, is why dynamic programming is so powerful: it can drastically reduce the cost of an algorithm without really changing the reasoning required. Another thing to note is that pairs are still being visited multiple times; however, the body of the function is only evaluated once per pair. Why the discrepancy, then? Well, the pair (1, 1) is visited whenever any pair (2, k) is evaluated. Each of those pairs is evaluated once, and there are five such pairs, so (1, 1) will be visited five times. Those extra visits are cheap, though: they hit the cache immediately, so the expensive work behind each pair still happens exactly once.

Monday, January 13, 2014

Visualizing What is and What is Not the Sieve of Eratosthenes

The Classic Algorithm

The Sieve is a classic algorithm for generating a list of prime numbers. However, this algorithm is often presented incorrectly, as discussed in an absolutely delightful paper by Melissa O'Neill. This is especially true when people present functional implementations of the algorithm. This post will explore how we can visualize the differences between implementations.

So what is the classic algorithm? To find all the prime numbers less than or equal to a given integer n, Wikipedia gives us the following method:

  1. Create a list of integers from 2 through n: (2, 3, 4, ..., n).
  2. Let p equal 2, the first prime number.
  3. Enumerate all multiples of p by counting to n in increments of p, and mark them as composite.
  4. Let p' equal p and let p equal the first number greater than p' in the list that is not marked composite. If there is no such number, stop. Otherwise, repeat the algorithm from step 3.

Okay, great, so what does this look like in actual code? Well, I'm going to implement it using Java and Processing, because I think it'll be way easier to visualize the algorithm this way.

void classic_seive(int n){
  boolean[] l = new boolean[n + 1];
  // l is a table denoting whether an integer is composite 
  int p = 2;
  while(p <= n){
    for(int marker = 2 * p; marker <= n; marker += p){
      assert(marker % p == 0);
      l[marker] = true;
      visit(marker);
    }
    for(p++; p <= n && l[p]; p++){
      // increment p until l[p] is false.
    }
  }
}

This is a pretty straightforward implementation of the algorithm. The only real noteworthy addition is the visit() call. All this does is tell my visualizer to register that value in the visualization. So what does this algorithm look like? I think the best way to capture this sieve is to look at the numbers being checked over time, and that is exactly what the visit() call lets us do.
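The visualizer itself isn't part of the post, but all visit() has to do here is bump a counter per number. A minimal sketch, with the array name check_counts being my own choice:

int N = 1000;

// check_counts[m] counts how many times the number m has been checked so far.
int[] check_counts = new int[N + 1];

void visit(int m){
  check_counts[m]++;
}

A point can then be drawn for each visited number over time, and the red-hue effect described further down can be driven off these same counts.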

That graph shows which numbers are being checked over time as the sieve computes all primes less than 1000. If you look at what's happening in that graph, you'll see that early on we're eliminating lots of composite numbers, and that later primes eliminate fewer and fewer (this makes sense, as you only check the multiples of the prime, and the gaps between multiples keep getting larger).

The Often Presented Functional Alternative

One problem with the sieve is that the algorithm is incredibly imperative. You initialize a large list and then mutate it a whole bunch as you repeatedly iterate through it. This does not lend itself to clever or beautiful functional implementations. However, people often supply an example implementation which is really quite lovely. From SICP:

(define (sieve stream)
  (cons-stream
   (stream-car stream)
   (sieve (stream-filter
           (lambda (x)
             (not (divisible? x (stream-car stream))))
           (stream-cdr stream)))))

I love the way Scheme reads. Unfortunately, this is not a true implementation of the sieve! Examine the recursive operation: the filter is applied to all remaining elements in the stream. Lazy evaluation muddles this a bit, but the whole cdr is being filtered. Every number still in the sieve that is greater than the current prime will be checked for divisibility. Therefore, every composite number m will be checked once for every prime up to and including its smallest prime factor. In the true sieve, every composite is marked once for every distinct prime factor it has.
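To make the gap concrete, here is a worked count of my own. Take 221 = 13 * 17. In the unfaithful version, 221 survives the divisibility checks against 2, 3, 5, 7, and 11, and is finally removed by the check against 13: six checks in total. In the true sieve it is simply marked twice, once while striding through the multiples of 13 and once through the multiples of 17. Primes fare even worse in the unfaithful version: a prime gets checked against every smaller prime before it surfaces, whereas the true sieve never touches it at all.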

Let's visualize this difference by reimplementing (a non-lazy version) in Processing:

void func_seive(int n){
  ArrayList<Integer> vals = new ArrayList<Integer>();
  for (int t = 2; t <= n; t++){
    vals.add(t);
  }
  vals = func_seive_helper(vals);
}

ArrayList<Integer> func_seive_helper(ArrayList<Integer> vals){
  if(vals.size() == 0)
    return vals;
  int cur = vals.get(0);
  ArrayList<Integer> filtered = new ArrayList<Integer>();
  for(Integer i : vals){
    if(i == cur){
      continue;
    }else{
      visit(i);
      if(i % cur != 0){
        filtered.add(i);
      }
    }
  }
  filtered = func_seive_helper(filtered);
  filtered.add(0, cur);
  return filtered;
}

My implementation is significantly uglier. I blame this on my comparatively bad coding technique (I cannot compete with SICP) and my need to reimplement filter. To further illustrate the difference, I changed visit() so that the red hue of a point deepens the more that particular number is checked. I actually did this before, but because numbers below 1000 are checked at most 4 times in the true sieve (210 = 2 * 3 * 5 * 7 is the first number with four distinct prime factors), the hue shift is basically imperceptible.

As you can see, numbers are being visited way more often. For example, 11 is visited when p equals 2, 3, 5, and 7. In the true sieve, 11 is never visited!

But What About Incremental Construction?

The functional "sieve" is not without its merits. In the classic sieve algorithm, you have to commit to the size of the sieve at the start of the algorithm. In the functional sieve, if you use lazy computation or streams then you can construct arbitrarily sized sieves. That's a nice difference, especially if you don't know at the outset how many primes you need to generate.

Of course, you can achieve a similar property with some hacking on the imperative classic sieve. Instead of constructing the sieve at its final size up front, you start with some "starter" size, and as the sieve needs to get larger, you resize your table and update the new entries by rechecking them against the previously computed primes. In order to do this, your table needs to be a bit more complicated, because you need to remember the last multiple that was marked for each prime. That is easy enough to do by upgrading our boolean table to an int table.

void fix_table(int cur_max, int[] l){
  // l[p] == 0 : p has not been processed yet
  // l[p] == 1 : p has been marked composite
  // otherwise : p is prime and l[p] is the largest multiple of p marked so far
  int p = 2;
  while (p <= cur_max){
    if(l[p] == 0){
      l[p] = p;
    }
    for(int marker = p + l[p]; marker <= cur_max; marker += p){
      assert(marker % p == 0);
      l[marker] = 1;
      visit(marker);
      l[p] = marker;
    }
    for(p++; p <= cur_max && l[p] == 1; p++){
      // skip numbers already marked composite
    }
  }
}

void incremental_seive(int n){
  int table_size = min(100, n);
  int[] l = new int[table_size + 1];
  fix_table(table_size, l);
  while(table_size < n){
    table_size *= 2;
    if(table_size > n){
      table_size = n;
    }
    l = Arrays.copyOf(l, table_size + 1);
    fix_table(table_size, l);
  }
}
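One consequence of the int encoding is that reading the primes back out takes a moment's thought: composite entries hold 1, while prime entries hold the last multiple that was marked (or the prime itself). Here is a small sketch of how the result might be collected; it assumes the caller keeps the final table around, since incremental_seive above discards it:

ArrayList<Integer> primes_from_table(int[] l){
  ArrayList<Integer> primes = new ArrayList<Integer>();
  for(int p = 2; p < l.length; p++){
    // anything not marked composite (== 1) is prime
    if(l[p] != 1){
      primes.add(p);
    }
  }
  return primes;
}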

So what does this one look like?

Pretty cool looking, I think. Also, notice that no numbers are "turning red" in this implementation, either.