Monday, January 27, 2014

Dynamic Programming and Simple Memoization

Dynamic programming has always been one of my favorite topics, so I thought I'd spend some time visualizing the transformation of an exponential time algorithm into a polynomial time algorithm. A lot of dynamic programming algorithms are fairly straightforward translations of recursions via memoization. The real key, it turns out, is getting the recursion right.

The Viterbi Algorithm and Hidden Markov Models

A Markov model is a probabilistic model that can be imagined with the help of a state machine. Each state has some probability of emitting an output. After an output, each state then has some probability of transitioning to other states. You can construct a stream of data by repeatedly querying this model. Of course, you can also attempt to go in the opposite direction and explain some series of data using a Markov model. In this case, you'll need to construct the most probable series of states for a given series of observations. How can you do this?
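
To make that concrete, here's a tiny, made-up two-state weather model in Python. The states, outputs, and probabilities below are all hypothetical, in the style of the classic rainy/sunny example:

import random

# Hypothetical two-state weather HMM: each state emits an
# observation, then transitions. All numbers are made up.
emit = {"Rainy": {"walk": 0.1, "shop": 0.4, "clean": 0.5},
        "Sunny": {"walk": 0.6, "shop": 0.3, "clean": 0.1}}
trans = {"Rainy": {"Rainy": 0.7, "Sunny": 0.3},
         "Sunny": {"Rainy": 0.4, "Sunny": 0.6}}

def draw(dist):
    # Sample a key from a {key: probability} dict.
    r, total = random.random(), 0.0
    for k, p in dist.items():
        total += p
        if r < total:
            return k
    return k  # guard against float rounding

def sample(n, state="Rainy"):
    out = []
    for _ in range(n):
        out.append(draw(emit[state]))   # emit from the current state
        state = draw(trans[state])      # then transition
    return out

print(sample(10))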

The Viterbi algorithm provides a recursion relationship that you can use to figure out the most likely series of transitions. You base this recursion on Vt,k, the probability of the most probable state sequence that accounts for the first t observations and ends at state k. From the Wikipedia entry:

Vt,k = Pr[ yt | k ] * max over all states x of ( ax,k * Vt-1,x )

Where yt is the tth observation, and ax,k is the probability of transitioning from state x to state k. As you can see, this gives a recursion which can be solved directly.

Solving the Recursion Directly

Okay, so how does a direct solution look in the code?

  float p_observations_to_state(int t, int k){
    /*
      P_observations_to_state (V_t_k)
      :=  Pr(the most probable sequence for t observations
      which arrive at state k)
    */
    visit(t, k);
    
    float ret;
    if(t == 1){
      // base case
      ret =  p(obs(0), k) * pi(k);
    }else{
      // recursion
      float max_p_rec = -1;
      for(int x = 0; x < states(); x++){
         float p_x_to_k = transition(x, k) * p_observations_to_state(t-1, x);
         if(p_x_to_k > max_p_rec)
            max_p_rec = p_x_to_k;
      }
      
      ret = p(obs(t-1), k) * max_p_rec;
    }
    
    return ret;
  }

Great! So what's going on here? First of all, I use a bunch of function calls to actually compute the transition and emission probabilities. That's fine. What else is happening? I make a visit() call so that I can track the runtime of the algorithm. You can also see the exponential nature of the recursion more directly once it is translated into code. Each call actually generates K recursions: one for each state in the state-space of the Markov model.

What does this actually look like in practice? Let's imagine each function call of this algorithm as a pair of integers. How many times is each pair visited? Are pairs being visited many times? Or is there some fundamental reason for the exponential recursion that cannot be fixed with memoization? For a Markov model with 5 states and a series of 10 observations, I'll plot the number of times each pair (t, k) is visited.

As you can see, the pairs on the left edge of the graph, which comprise the "base cases" of the recursion, are visited very frequently. Pairs to the right are visited less frequently, with roughly a 5x shrinking from column to column. Of course, this isn't too surprising, given the implementation. What this does indicate, though, is that there is a lot to be gained from memoization. Why? Well, the pairs to the right are generating a whole slew of calls to the left, so memoization isn't only eliminating a bunch of recomputation in the "body" of the recursive calls, it's actually eliminating further recursive calls. Essentially, it reduces the number of times the question "What is Vt,k?" is even asked. That is the important effect, and it's what really changes the runtime of the algorithm.
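
To put rough numbers on that: each evaluation of Vt,k spawns K evaluations at column t-1, so the direct recursion costs on the order of K^T calls (here, 5^10, nearly ten million). But there are only T*K = 50 distinct pairs, and each can be computed once with a K-way max, which is O(T*K^2), about 250 multiplications' worth of work. That gap is exactly what memoization recovers.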

Memoization and Dynamic Programming

Oftentimes, dynamic programming is presented as the technique involving matrices, operations updating cells, and so on. That's great, but it kind of obscures what I believe to be the real underlying principle, namely, memoization. The general idea is to look at the "call graph" of a recursive algorithm and to ask: if I cache results, does this thin out the call graph significantly?

First let's see how memoization changes our implementation:

  float p_observations_to_state(int t, int k){
    /*
      P_observations_to_state (V_t_k)
      :=  P(the most probable sequence for t observations
      which arrive at state k)
    */
    
    visit(t, k);

    if(single.memoizing && single.memoized[t][k] != -1){
      return single.memoized[t][k];
    }
    
    float ret;
    if(t == 1){
      // base case
      ret =  p(obs(0), k) * pi(k);
    }else{
      // recursion
      float max_p_rec = -1;
      for(int x = 0; x < states(); x++){
          float p_x_to_k = transition(x, k) * p_observations_to_state(t-1, x);
          if(p_x_to_k > max_p_rec)
              max_p_rec = p_x_to_k;
      }
      
      ret = p(obs(t-1), k) * max_p_rec;
    }
    
    if(single.memoizing)
      single.memoized[t][k] = ret;
    return ret;
  }

Pretty straightforward. It's a little hack-y in that I use "-1" to indicate that a value hasn't been cached, but I hack, so hack-y is what you get. Notice also that the visit() call occurs before the cache is consulted: this ensures that every function entry is counted, even the ones that only do a cached read. This keeps our runtime numbers honest. So how does this change the number of visits each pair receives?

Woah. As you can see, the leftmost pairs are visited vastly less than before. And this, of course, is why dynamic programming is so powerful. It can drastically reduce the cost of an algorithm without really changing the reasoning required. Another thing to note is that pairs are still being visited multiple times; however, the body of the function is only being evaluated once for each pair. Why is there this discrepancy, then? Well, the pair (1, 1) is visited when any pair (2, k) is evaluated. Each of those pairs is evaluated once, and there are five such pairs, so (1, 1) will be visited five times. This accounts for the small number of repeat visits that remain: a pair is visited once by each of its successors, but its body only ever runs once.
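
As an aside, if you write this kind of recursion in Python, the caching half of the pattern comes for free with functools.lru_cache. Here's a minimal sketch with stand-in model parameters (uniform probabilities and random observations, just so it runs end to end):

from functools import lru_cache  # Python 3.2+
import random

# Stand-in model: K states, T observations, uniform probabilities.
# In the code above these come from the HMM's parameter tables.
random.seed(1)
K, T = 5, 10
obs = [random.randrange(3) for _ in range(T)]
pi = [1.0 / K] * K
trans = [[1.0 / K] * K for _ in range(K)]
emit = [[1.0 / 3] * 3 for _ in range(K)]

@lru_cache(maxsize=None)
def V(t, k):
    # Probability of the most likely state sequence over the
    # first t observations that ends in state k.
    if t == 1:
        return emit[k][obs[0]] * pi[k]
    best = max(trans[x][k] * V(t - 1, x) for x in range(K))
    return emit[k][obs[t - 1]] * best

print(max(V(T, k) for k in range(K)))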

Monday, January 13, 2014

Visualizing What is and What is Not the Sieve of Eratosthenes

The Classic Algorithm

The Sieve is a classic algorithm to create a list of prime numbers. However, this algorithm is often presented wrong, as discussed in an absolutely delightful paper by Melissa O'Neill. This is especially true when people present functional implementations of the algorithm. This post will explore how we can visualize the differences between implementations.

So what is the classic algorithm? To find all the prime numbers less than or equal to a given integer n, Wikipedia gives us the following method:

  1. Create a list of integers from 2 through n: (2, 3, 4, ..., n).
  2. Let p equal 2, the first prime number.
  3. Enumerate all multiples of p by counting to n in increments of p, and mark them as composite.
  4. Let p' equal p and let p equal the first number greater than p' in the list that is not marked composite. If there is no such number, stop. Otherwise, repeat the algorithm from step 3.

Okay, great, what does this look like in actual code then? Well, I'm going to implement it using Java and Processing. I'm doing this because I think it'll be way easier to visualize the algorithm this way.

void classic_seive(int n){
  boolean[] l = new boolean[n + 1];
  // l is a table denoting whether an integer is composite 
  int p = 2;
  while(p <= n){
    for(int marker = 2 * p; marker <= n; marker += p){
      assert(marker % p == 0);
      l[marker] = true;
      visit(marker);
    }
    for(p++; p <= n && l[p]; p++){
      // increment p until l[p] is false.
    }
  }
}

This is a pretty straightforward implementation of the algorithm. The only noteworthy addition is the visit() call. All this does is tell my visualizer to register that value in the visualization. So what does this algorithm look like? I think the best way to capture this sieve is to look at the numbers being checked over time, and this is exactly what the visit() call is going to do for us.

That graph shows which numbers are being checked over time as the sieve computes all primes less than 1000. If you look at what's happening in that graph, you'll see that early on we're eliminating lots of composite numbers and that later primes eliminate fewer and fewer (this makes sense, as you only check the multiples of the prime, and the gaps between multiples keep getting larger).

The Often Presented Functional Alternative

One problem with the sieve is that the algorithm is incredibly imperative. You initialize a large list and then mutate it a whole bunch as you repeatedly iterate through it. This does not lend itself to clever or beautiful functional implementations. However, people often supply an example implementation which is really quite lovely. From SICP:

(define (sieve stream)
  (cons-stream
   (stream-car stream)
   (sieve (stream-filter
           (lambda (x)
             (not (divisible? x (stream-car stream))))
           (stream-cdr stream)))))

I love the way Scheme reads. Unfortunately, this is not a true implementation of the sieve! Examine the recursive operation. The filter is applied to all remaining elements in the stream. Of course lazy evaluation muddles this a bit, but the whole cdr is being filtered. Every number still in the sieve which is greater than the current prime will be checked for divisibility. Therefore, every composite number m will be checked once for every prime up to and including its smallest prime factor. In the true sieve, every composite is checked once for every distinct prime factor of that composite number.
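
Before visualizing it, we can put a rough number on that difference with a quick Python sketch that just counts the checks each strategy performs (a coarse model of the two algorithms, not a faithful port of either):

def count_true_sieve(n):
    # Each composite is visited once per distinct prime that divides it.
    checks = 0
    composite = [False] * (n + 1)
    for p in range(2, n + 1):
        if not composite[p]:
            for m in range(2 * p, n + 1, p):
                composite[m] = True
                checks += 1
    return checks

def count_unfaithful_sieve(n):
    # Every survivor is tested against each prime found so far, so a
    # number is checked until its smallest prime factor removes it.
    checks = 0
    vals = list(range(2, n + 1))
    while vals:
        p, rest = vals[0], vals[1:]
        checks += len(rest)
        vals = [v for v in rest if v % p != 0]
    return checks

print(count_true_sieve(1000), count_unfaithful_sieve(1000))
# the unfaithful version does several times more work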

Let's visualize this difference by reimplementing (a non-lazy version) in Processing:

void func_seive(int n){
  ArrayList<Integer> vals = new ArrayList<Integer>();
  for (int t = 2; t <= n; t++){
    vals.add(t);
  }
  vals = func_seive_helper(vals);
}

ArrayList<Integer> func_seive_helper(ArrayList<Integer> vals){
  if(vals.size() == 0)
    return vals;
  int cur = vals.get(0);
  ArrayList<Integer> filtered = new ArrayList<Integer>();
  for(Integer i : vals){
    if(i == cur){
      continue;
    }else{
      visit(i);
      if(i % cur != 0){
        filtered.add(i);
      }
    }
  }
  filtered = func_seive_helper(filtered);
  filtered.add(0, cur);
  return filtered;
}

My implementation is significantly uglier. I blame this on my comparatively bad coding techniques (I cannot compete with SICP) and my need to reimplement filter. To further illustrate the difference, I changed visit() so that the red hue of a point deepens the more that particular number is checked. I actually did this for the earlier graph too, but because numbers are checked at most a few times in the true sieve, the hue shift is basically imperceptible.

As you can see, numbers are being visited way more often. For example, 11 is visited when p equals 2, 3, 5, and 7. In the true sieve, 11 is never visited!

But What About Incremental Construction?

The functional "sieve" is not without its merits. In the classic sieve algorithm, you have to commit to the size of the sieve at the start of the algorithm. In the functional sieve, if you use lazy computation or streams then you can construct arbitrarily sized sieves. That's a nice difference, especially if you don't know at the outset how many primes you need to generate.

Of course, you can achieve a similar property with some hacking on the classic imperative sieve. Instead of committing to a size initially, you start with some "starter" size, and as the sieve needs to get larger, you resize your table and update the new entries by rechecking them against the previously computed primes. In order to do this, your table needs to be a bit more complicated, because you need to remember the last multiple that was marked for each prime. That is easy enough to do by upgrading our boolean table to an int table.

void fix_table(int cur_max, int[] l){
  // Table encoding: l[i] == 0 means i is untouched (a prime candidate),
  // l[i] == 1 means i has been marked composite, and for a prime p,
  // l[p] holds the last multiple of p marked so far.
  int p = 2;
  while (p <= cur_max){
    if(l[p] == 0){
      l[p] = p; // first time we see this prime; start marking at 2p
    }
    for(int marker = p + l[p]; marker <= cur_max; marker += p){
      assert(marker % p == 0);
      l[marker] = 1;
      visit(marker);
      l[p] = marker; // remember how far we've marked for p
    }
    for(p++; p <= cur_max && l[p] == 1; p++){
      // advance p past entries marked composite.
    }
  }
}

void incremental_seive(int n){
  int table_size = min(100, n); // don't let the starter size exceed n
  int[] l = new int[table_size + 1];
  fix_table(table_size, l);
  while(table_size < n){
    table_size *= 2;
    if(table_size > n){
      table_size = n;
    }
    l = Arrays.copyOf(l, table_size + 1); // grow, keeping old entries
    fix_table(table_size, l);
  }
}

So what does this one look like?

Pretty cool looking, I think. Also, notice that no numbers are "turning red" in this implementation, either.

Monday, October 8, 2012

Custom Serialization in Pyro4

I am a pretty big fan of the Python remote object library Pyro. It has a great interface that's small and easy to use. However, there's been one big problem for me. I've been working on a project that composes applications into components running in separate processes to achieve some isolation properties. These components need to communicate with each other, so they do this through a normal Unix socket. I used Pyro to set up the RPCs easily.

The problem is that Pyro uses Python's pickle module. If you click on that link, you should see a nice warning right there at the top: never unpickle data from an untrusted source. Suddenly my nice isolation properties have vanished! This problem is examined very nicely over at the blog post Why Python Pickle is Insecure. That article, and numerous other places on the internet, offer suggestions for how to make pickle safer or for safer serialization libraries to use instead.
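
If you've never seen why that warning exists, here's the classic demonstration: a pickle payload gets to name an arbitrary callable for the unpickler to invoke. This one just runs a harmless shell echo, but it could run anything:

import pickle, os

class Evil(object):
    def __reduce__(self):
        # pickle will serialize this as "call os.system('echo owned')",
        # and the unpickler will happily make that call.
        return (os.system, ("echo owned",))

payload = pickle.dumps(Evil())
pickle.loads(payload)   # prints "owned": arbitrary code just ran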

What's the problem then? Well, Pyro doesn't really expose a great interface for replacing the serialization library. Or at least they don't advertise one. But I didn't want to give up Pyro, so I poked around in the codebase of Pyro4 for a little bit and found this in Pyro4/util.py:

class Serializer(object):
    """
    A (de)serializer that wraps a certain serialization protocol.
    Currently it only supports the standard pickle protocol.
    It can optionally compress the serialized data, and is thread safe.
    """

    def serialize(self, data, compress = False):
        ...
    def deserialize(self, data, compressed=False):
        ...
    def __eq__(self, other):
        ...
    def __ne__(self, other):
        ...

So it turns out there's a class that defines serialize() and deserialize() methods for use with Pyro. But how do we get our Proxy objects and Daemons to use a custom serializer? Looking at those two class definitions, we find that each has an instance variable storing that instance's serializer object. So installing your own serializer looks like this:
def start_daemon():
    daemon = Pyro4.Daemon()
    daemon.serializer = customSerializer()
    daemon.register(foo, "foo_name")
    daemon.requestLoop()

def get_proxy():
    proxy = Pyro4.Proxy(uri)
    proxy._pyroSerializer = customSerializer()
    return proxy

Great! So this will work for you as long as your serializer defines the functions that the default serializer does. You only need to be able to serialize and deserialize the objects that your project uses, but remember that Pyro will try to serialize Exceptions and the like.
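
For reference, here's a minimal sketch of what a pickle-free customSerializer might look like, using JSON. The tuple-return convention here mirrors what I remember of the default Serializer (payload plus a compressed flag), so double-check it against the Pyro4 version you're running; and note that JSON only handles simple data types, so Exceptions and custom classes need extra work:

import json, zlib

class customSerializer(object):
    """Drop-in replacement for Pyro4's pickle-based Serializer
    that uses JSON, which won't execute untrusted input."""

    def serialize(self, data, compress=False):
        body = json.dumps(data).encode("utf-8")
        if compress and len(body) > 200:
            compressed = zlib.compress(body)
            if len(compressed) < len(body):
                return compressed, True
        return body, False

    def deserialize(self, data, compressed=False):
        if compressed:
            data = zlib.decompress(data)
        return json.loads(data)

    def __eq__(self, other):
        return isinstance(other, customSerializer)

    def __ne__(self, other):
        return not self.__eq__(other)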

Monday, July 2, 2012

Disassembling Python One Line at a Time

Recently, I had a desire to inspect Python bytecodes during the course of execution. My (admittedly non-exhaustive) search yielded no good pointers, so I figured I could do this using some crazy combination of Python tracing, inspection, and disassembly. This turned out to be true, though I had to futz around with some of the supplied Python code.

If you take a look at Python's dis module, you'll find that it does lots of really neat things. However, if you just want to disassemble a single line of code (for example, the line about to be executed) and analyze it, then you're a little SOL. Looking at the code, though, it's actually pretty simple. Below, I've got a modified version of disassemble that takes a line number as a parameter and only prints the bytecodes associated with that line. Of course, this is heavily cribbed from the CPython source code, so nearly all of the credit rests with the hackers over there.

import inspect, dis, opcode

def foo(arg):
    result = arg + 1
    return result

def __find(l, func):
    for ix, v in enumerate(l):
        if func(v):
            return ix
    return -1

def disassemble_line(co, lineno):
    code = co.co_code
    labels = dis.findlabels(code)
    linestarts = list(dis.findlinestarts(co))
    line_offset_ix = __find(linestarts, lambda val : val[1] == lineno)
    line_offset = linestarts[line_offset_ix][0]
    n = len(code)

    if line_offset_ix + 1 < len(linestarts):
        next_offset = linestarts[line_offset_ix + 1][0]
    else:
        next_offset = n

    i = line_offset
    extended_arg = 0
    free = None
    while i < next_offset:
        c = code[i]
        op = ord(c)
        if i in labels: print '>>>',
        else: print '   ',
        print repr(i).rjust(4),
        print opcode.opname[op].ljust(20),
        
        i = i + 1
        if op >= opcode.HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
            extended_arg = 0
            i = i + 2
            if op == opcode.EXTENDED_ARG:
                extended_arg = oparg*65536L
            print repr(oparg).rjust(5),
            if op in opcode.hasconst:
                print '(' + repr(co.co_consts[oparg]) + ')',
            elif op in opcode.hasname:
                print '(' + co.co_names[oparg] + ')',
            elif op in opcode.hasjrel:
                print '(to ' + repr(i + oparg) + ')',
            elif op in opcode.haslocal:
                print '(' + co.co_varnames[oparg] + ')',
            elif op in opcode.hascompare:
                print '(' + opcode.cmp_op[oparg] + ')',
            elif op in opcode.hasfree:
                if free is None:
                    free = co.co_cellvars + co.co_freevars
                print '(' + free[oparg] + ')',
        print

if __name__ == "__main__":
    disassemble_line(foo.func_code, 4)
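
Running this prints the bytecodes for line 4, result = arg + 1. The output should look something like this (exact offsets and indices can vary across Python 2.x versions):

     0 LOAD_FAST                0 (arg)
     3 LOAD_CONST               1 (1)
     6 BINARY_ADD
     7 STORE_FAST               1 (result)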

Wednesday, May 30, 2012

Fun with Python and Bit Sets

Bit sets? What are they? And Why?

First off, what's a bit set? Well, it's a particular way to represent sets of integers. And it's a pretty cool way, too! Let's say you expect integers in the range [0, N); you can represent any such set as an N-bit string, where the ath bit is 1 if and only if a is in the set.
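
Here's the idea in miniature, using a plain Python int as the bit string (bit a is set if and only if a is in the set):

s = 0               # the empty set over [0, N)
s |= 1 << 1         # add 1
s |= 1 << 3         # add 3
print(bin(s))              # 0b1010
print(bool(s & (1 << 3)))  # membership test for 3: True
print(bool(s & (1 << 2)))  # membership test for 2: False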

Now for an Example

For whatever reason, I recently found myself trying to solve the following problem: For n points in the plane, find a set of lines such that for any line that partitions the n points into two sets of points, there exists a line in the set that forms that same partition.

How do we go about solving this? I chose the most straightforward approach: exhaustive search. I took every pair of points in the set (that's n choose 2 pairs) and formed two lines that cut between each pair. Now, as I check all the partitions formed by these lines, I need to check whether or not I have seen a particular partition yet. How do I represent partitions, and how do I check them quickly?

Bit sets.

Not only are bit sets fast and compact, but they have a wonderful quality: a particular set can only be represented by one sequence of bytes. So to check whether or not I have seen a set before, I can just store the previously seen sets in a tree or hash set. This makes that check very fast in comparison to checking against every set that has come before.

Can't You Use Hash Sets?

The problem with hash sets is that a particular set can be represented by different underlying structures. For example, if you have multiple values that hash to the same bucket in your hash set, you need to store a linked list of those values. That linked list can be reordered. It could contain different pointer values. For all of these reasons, checking whether or not two hash sets are equal requires more complicated checks.

Using Bit Sets in Python

So how are we going to do this? First let's see if anyone has implemented this before. A quick Google search leads us to the bitarray library. Oh wonderful, it looks like bitarray is more or less what we're looking for.

Aside from some baffling initialization behavior and the lack of a good __hash__ function, this is good stuff. Let's see how it stacks up. First, I'll hack together a simple testing script...

import bitarray, random, timeit, array
import sys
set_size = 100
num_sets = set_size*set_size

def test_pytuple():
    seen_before = set()
    rng = random.Random()
    
    for i in range(0, num_sets):
        b = (rng.randint(0,1) for x in range(0, set_size))
        if b not in seen_before:
            seen_before.add(b)
        
def test_bitarray():
    seen_before = set()
    rng = random.Random()
    
    for i in range(0, num_sets):
        b = bitarray.bitarray([rng.randint(0,1) 
                               for x in range(0, set_size)])
        if b not in seen_before:
            seen_before.add(b)

def test_array():
    seen_before = set()
    rng = random.Random()
    
    for i in range(0, num_sets):
        b = array.array('b' , (rng.randint(0,1) 
                               for x in range(0, set_size)))
        if b.tostring() not in seen_before:
            seen_before.add(b.tostring())
    

if __name__ == "__main__":
    if sys.argv[1] == "bitarray":
        t = timeit.Timer(test_bitarray)
        print "Bitarray average time: %f s" % (t.timeit(10)/10.0)
    elif sys.argv[1] == "tuple":
        t = timeit.Timer(test_pytuple)
        print "pytuple average time: %f s" % (t.timeit(10)/10.0)
    elif sys.argv[1] == "array":
        t = timeit.Timer(test_array)
        print "array average time: %f s" % (t.timeit(10)/10.0)

Here, we can see that I added some code for testing some alternatives as well... Let's see what happens:

$ python bitarrays.py bitarray
Bitarray average time: 2.327128 s

Well, that's not great. What if I just use Python tuples or the array module?

$ python bitarrays.py tuple
pytuple average time: 0.037775 s
$ python bitarrays.py array
array average time: 2.461786 s

Tuples are Better?!

Hammer of Thor! It looks like the tuple implementation is a lot faster. Why is that happening? Is there something wrong with the way I am testing this? Something looks fishy here, particularly this line:

b = (rng.randint(0,1) for x in range(0, set_size))

Why is this line fishy? Well, it doesn't actually build a tuple at all: it's a generator expression, so the membership check hashes the generator object itself, and the integers are never even produced. What if we force an actual tuple to be constructed, like this:

b = tuple([rng.randint(0,1) for x in range(0, set_size)])

Then its runtime shoots up to 2.292623 s! Okay, so we've discovered that what really dominated this silly test was constructing the values, more than actually using the set. Okay, but what about memory usage? Surely, bitarrays must have an advantage there. After checking peak memory usage using a memusg script, I found the following:

Structure    Peak Memory Allocation (mb)
tuple        21408
bitarray     7400
array        7400

Importantly, this measures allocated memory and not actual memory usage. Still, we can see, as one would expect, that tuples require more memory. They represent each boolean as a full Python integer object, while arrays and bitarrays use more compact representations.

So what have we learned? Well, you can represent compact sets using an array of booleans, and it's efficient to do this. One clever way to do this is using bitarrays, but they may be a bit too clever for what you actually want to do. In Python, you may just want to use a tuple of integers and call it a day.

Tuesday, January 17, 2012

Customized Boxplots with Matplotlib

Recently, I wanted to use a box plot to describe some data I had collected. I had a program whose response latency depended on a variable k and I wanted to show some information about the distribution of latencies for each value of k that I tested. Box plots, in my opinion, are perfect for this kind of display.

My toolkit of choice for plotting or graphing data happens to be the combination of numpy, matplotlib, and scipy. I immediately found the function I was looking for: matplotlib.pyplot.boxplot. Unfortunately, I wanted the box plot to show information about the 99th percentile, and woe, this function will only draw whiskers based on the IQR.

So, I wrote my own box plot implementation that's a little bit more generic and, most importantly, met my needs.

# @author: Aaron Blankstein 

from scipy.stats import scoreatpercentile

class boxplotter(object):
    def __init__(self, median, top, bottom, whisk_top=None, 
                 whisk_bottom=None):
        self.median = median
        self.top = top
        self.bott = bottom
        self.whisk_top = whisk_top
        self.whisk_bott = whisk_bottom
    def draw_on(self, ax, index, box_color = "blue", 
                median_color = "red", whisker_color = "black"):
        width = .7
        w2 = width / 2
        ax.broken_barh([(index - w2, width)],
                       (self.bott,self.top - self.bott), 
                       facecolor="white",edgecolor=box_color)
        ax.broken_barh([(index - w2, width)],
                       (self.median,0), 
                       facecolor="white", edgecolor=median_color)
        if self.whisk_top is not None:
            ax.broken_barh([(index - w2, width)],
                           (self.whisk_top,0), 
                           facecolor="white", edgecolor=whisker_color)
            ax.broken_barh([(index , 0)], 
                           (self.whisk_top, self.top-self.whisk_top),
                           edgecolor=box_color,linestyle="dashed")
        if self.whisk_bott is not None:
            ax.broken_barh([(index - w2, width)],
                           (self.whisk_bott,0), 
                           facecolor="white", edgecolor=whisker_color)
            ax.broken_barh([(index , 0)], 
                           (self.whisk_bott,self.bott-self.whisk_bott),
                           edgecolor=box_color,linestyle="dashed")

def percentile_box_plot(ax, data, indexer=None, box_top=75, 
                        box_bottom=25,whisker_top=99,whisker_bottom=1):
    if indexer is None:
        indexed_data = zip(range(1,len(data)+1), data)
    else:
        indexed_data = [(indexer(datum), datum) for datum in data]
    def get_whisk(vector, w):
        if w is None:
            return None
        return scoreatpercentile(vector, w)

    for index, x in indexed_data:
        bp = boxplotter(scoreatpercentile(x, 50),
                        scoreatpercentile(x, box_top),
                        scoreatpercentile(x, box_bottom),
                        get_whisk(x, whisker_top),
                        get_whisk(x, whisker_bottom))
        bp.draw_on(ax, index)

def example():

    from pylab import rand, ones, concatenate
    import matplotlib.pyplot as plt
    # EXAMPLE data code from: 
    # http://matplotlib.sourceforge.net/pyplots/boxplot_demo.py
    # fake up some data
    spread= rand(50) * 100
    center = ones(25) * 50
    flier_high = rand(10) * 100 + 100
    flier_low = rand(10) * -100
    data =concatenate((spread, center, flier_high, flier_low), 0)
    # fake up some more data
    spread= rand(50) * 100
    center = ones(25) * 40
    flier_high = rand(10) * 100 + 100
    flier_low = rand(10) * -100
    d2 = concatenate( (spread, center, flier_high, flier_low), 0 )
    data.shape = (-1, 1)
    d2.shape = (-1, 1)
    data = [data, d2, d2[::2,0]]

    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    ax.set_xlim(0,4)
    percentile_box_plot(ax, data)
    plt.savefig('example.png')
 
if __name__ == "__main__":
    example()


The example() method produced the lovely box plot above. If you supply None arguments to either of the whiskers, it won't draw that particular whisker. Anyways, happy plotting.