from os import system
import time
import os
import stat
import string
import operator
import sys
import math
import cPickle
import types
import random

# category 1
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       BOOLEAN  FUNCTIONS                 ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

############################################################
#
#  Adding scheme-like capabilities to if expressions
#
############################################################

def ifab(test, a, b):
    """x = ifab(test, a, b)

       Return `a` when `test` is true, otherwise `b`.

       WARNING:  Both 'a' and 'b' are evaluated, because they are
       ordinary arguments computed before ifab is entered.
       C equivalent: x = test?a:b;
       Scheme equiv: (set x (if test a b))
       Python equiv: test and a or b   (only safe when `a` is truthy)
       None of the equivalents evaluates both arguments
    """
    chosen = b
    if test:
        chosen = a
    return chosen

def case(variable, case2value, default=None):
    """Switch/case helper: look `variable` up in the dict `case2value`.

    Returns case2value[variable] when the key is present.  Otherwise it
    returns `default`; if `default` is None (so None cannot serve as a
    real fallback value), a KeyError listing the known keys is raised.
    """
    if variable in case2value:
        return case2value[variable]
    if default is None:
        # A real exception class: string exceptions are rejected with a
        # TypeError by Python 2.6+ and never carried this message.
        raise KeyError("Unexpected value: %s not in dictionary %s" % (
            variable, list(case2value.keys())))
    return default

def my_not(a):
    """Function form of the `not` operator (handy with map/filter)."""
    if a:
        return False
    return True

def xor(a,b):
    """Logical exclusive-or: true when exactly one of `a`, `b` is truthy.

    Always returns a bool.  The previous `(a and not b) or (not a and b)`
    had the same truthiness but returned `b` itself when only `b` was
    true, and True when only `a` was.
    """
    return bool(a) != bool(b)

def filter_not(func, list):
    """Return a new list of the items of `list` for which func(item) is false.

    The complement of the builtin filter().  The parameter name `list`
    shadows the builtin inside this function; it is kept for callers
    that pass it by keyword.
    """
    return [item for item in list if not func(item)]

def modify_and_return(f, a):
    """Call f(a) for its side effect on `a`, then return `a` itself.

    Lets an in-place mutator (e.g. list.sort) be used inside an
    expression, reusing the modified object.  Whatever f returns is
    discarded.
    """
    # f(a) runs first; the tuple's second slot hands back `a` unchanged
    return (f(a), a)[1]

def wait_a_bit(how_long):
    """Busy-wait by burning CPU instead of sleeping.

    `how_long` is not a duration: it sizes a list (built with the module's
    reverse() helper, presumably a reversed copy of range(how_long)) that
    is scanned once for every pair of indices into it, i.e. len**2 scans
    of len elements each.
    """
    numbers = reverse(range(0, how_long))
    scans_left = len(numbers) * len(numbers)
    while scans_left > 0:
        map(lambda x: x == 3, numbers)
        scans_left = scans_left - 1

# category 2
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       PRINTING                           ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

############################################################
#
#  Recursively exploring and printing nested data structures
#
############################################################

def pp(d,level=-1,maxw=0,maxh=0,parsable=0):
    """Pretty-print `d` to stdout.

    With `parsable` false (the default) this delegates to pretty_print()
    with the given level/maxw/maxh limits.  With `parsable` true it uses
    the standard pprint module instead: `maxw`, when non-zero, becomes the
    line width, and `level`/`maxh` are ignored.
    """
    if parsable:
        import pprint
        options = {'indent': 1}
        if maxw:
            options['width'] = maxw
        pprint.PrettyPrinter(**options).pprint(d)
        return
    pretty_print(sys.stdout, d, level, maxw, maxh, '', '', '')

def test_pp():
    """Smoke test for pp(): prints a nested structure mixing a dict,
    tuples, lists, a lambda and a tuple used as a dict key."""
    inner = ['ele', {'ven': 12,
                     (13, 14): '15'}]
    sample = {'one': ('two', 3, [4, 5, 6]),
              7: (lambda x: 8*9),
              'ten': inner}
    pp(sample)

def pretty_print(f, d, level=-1, maxw=0, maxh=0, gap="", first_gap='', last_gap=''):
    """Write an indented, human-readable rendering of `d` to stream `f`.

    Lists, tuples, dicts and (Python 2) old-style class instances are
    explored recursively, one element per line.  Every other value is
    written on one line: strings as-is (no quotes), anything else via
    repr().

    f         output stream (anything with a .write(str) method)
    d         the data structure to print
    level     number of recursive levels still allowed.  At 0 the value is
              printed flat (repr(), or the string itself).  The default of
              -1 never stops early.
    maxw      maximum width of a leaf line.  Longer leaves are cut in the
              middle and annotated with their full length.  0 = no limit.
    maxh      maximum number of elements printed per list/tuple/dict/
              instance; the remainder is summarised.  0 = no limit.
    gap       indentation placed before every element of a container
    first_gap prefix written before the opening bracket/paren/brace (or
              before a leaf value)
    last_gap  prefix written before the closing bracket/paren/brace

    Uses the module helpers ifab() and sfill().
    """
    if level == 0:
        # depth limit reached: print the value flat on one line
        if type(d) != str: d = repr(d)
        if maxw and len(d) > maxw:
            final = ifab(maxw > 20, 10, maxw // 2)
            f.write(first_gap+d[:maxw-final]+'...'+d[-final:]+' (%s chars)\n' % len(d))
        else:
            f.write(first_gap+d+'\n')
    elif type(d) == list:
        _pretty_print_sequence(f, d, level, maxw, maxh, gap, first_gap, last_gap,
                               '[', ']', ' ->', 'list')
    elif type(d) == tuple:
        _pretty_print_sequence(f, d, level, maxw, maxh, gap, first_gap, last_gap,
                               '(', ')', ' =>', 'tuple')
    elif type(d) == dict:
        if not d:
            f.write(first_gap+"{}\n")
            return
        f.write(first_gap+"{\n")
        keys = list(d.keys())
        keys.sort()
        _pretty_print_fields(f, keys, lambda k: d[k], level, maxw, maxh, gap)
        f.write(last_gap+"}\n")
    elif type(d) == getattr(types, 'InstanceType', None):
        # old-style class instances only exist on Python 2
        fields = dir(d)
        if not fields:
            f.write(first_gap+"*EmptyClass*\n")
            return
        f.write(first_gap+"*ClassInstance %s\n"%d)
        fields.sort()
        # getattr instead of eval('d.'+k): same lookup, no code evaluation
        _pretty_print_fields(f, fields, lambda k: getattr(d, k), level, maxw, maxh, gap)
        f.write(last_gap+"*\n")
    elif type(d) == str:
        # strings are printed as-is, without quotes
        if maxw and len(d) > maxw:
            final = ifab(maxw > 20, 10, maxw // 2)
            f.write(first_gap+d[:maxw-final]+'..'+d[-final:]+' (%s)\n' % len(d))
        else:
            f.write(first_gap+d+'\n')
    else:
        # every other type is shown through its repr()
        text = repr(d)
        if maxw and len(text) > maxw:
            final = ifab(maxw > 20, 10, maxw // 2)
            f.write(first_gap+text[:maxw-final]+'..'+text[-final:]+' (%s)\n' % len(text))
        else:
            f.write(first_gap+text+'\n')

def _pretty_print_sequence(f, d, level, maxw, maxh, gap, first_gap, last_gap,
                           opening, closing, arrow, kind):
    # Helper for pretty_print: writes a list or tuple one element per line
    # (each prefixed by `arrow`), stopping after maxh elements when set.
    if not d:
        f.write(first_gap+opening+closing+"\n")
        return
    f.write(first_gap+opening+"\n")
    count = 0
    for el in d:
        pretty_print(f, el, level-1, maxw, maxh, gap+'   ', gap+arrow, gap+'   ')
        count = count + 1
        if maxh and count >= maxh and maxh < len(d):
            f.write(gap+arrow+' ... (%s in %s)\n' % (len(d), kind))
            break
    f.write(last_gap+closing+"\n")

def _pretty_print_fields(f, keys, get_value, level, maxw, maxh, gap):
    # Helper for pretty_print: writes one "key: value" entry per key (keys
    # padded with '.' to a common width via sfill) and, once maxh entries
    # are printed, summarises the remaining keys on a single line.
    key_strings = []
    for k in keys:
        if type(k) == str: key_strings.append(k)
        else: key_strings.append(repr(k))
    maxlen = max(map(len, key_strings))
    count = 0
    for k, key_string in zip(keys, key_strings):
        key_string = sfill(key_string, maxlen, '.')
        blank_string = ' '*len(key_string)
        pretty_print(f, get_value(k),
                     level     = level-1,
                     maxw      = maxw,
                     maxh      = maxh,
                     gap       = gap+'    %s'%blank_string,
                     first_gap = gap+'  %s: '%key_string,
                     last_gap  = gap+'    %s'%blank_string)
        count = count + 1
        if maxh and count >= maxh and maxh < len(keys):
            remaining = []
            for rest in keys[count:]:
                if type(rest) == tuple: remaining.append(repr(rest))
                else: remaining.append('%s' % rest)
            pretty_print(f, '  %s (%s keys)' % (','.join(remaining), len(keys)),
                         0, maxw, 0, gap, gap, '')
            break

############################################################
#
#  Functions for printing
#
############################################################

def disp(x):
    """Print `x` followed by a newline: a function form of the print
    statement, e.g. map(disp, ['hello', 'there', 'how', 'are', 'you'])."""
    print(x)

def display_list(list, join_char=', ', format='%s'):
    """Format every element of `list` with `format` and join the pieces
    with `join_char`.

    Each element is used directly as the right operand of `%`.  A tuple
    element is therefore spread over several placeholders (e.g.
    format='%s=%s'), and with the default single '%s' a tuple element
    raises TypeError.
    """
    pieces = []
    for element in list:
        pieces.append(format % element)
    return join_char.join(pieces)

def display_bignum(n,digits=1):
    """Abbreviate a big number with a unit suffix (k, M, G, T, P).

    n       the number; its absolute value is truncated to an int
    digits  how many digits to keep after the decimal point of the chosen
            unit.  Digits are truncated, not rounded; 0 drops the point.

    ex: display_bignum(1300403)    -> '1.3M'
    ex: display_bignum(13004)      -> '13.0k'
    ex: display_bignum(134)        -> '134'
    ex: display_bignum(1300403, 0) -> '1M'

    Numbers below 1000, or at/above 10**18, are returned unabbreviated.
    """
    if n >= 0: sign = ''
    else: sign = '-'
    n = int(abs(n))
    # str(), not repr()/backticks: on Python 2 repr() of a long ends in
    # 'L', which would corrupt the digit slicing below.
    text = str(n)
    if n < 1000: return sign + text
    elif n < 10**6: k, units = 3, 'k'
    elif n < 10**9: k, units = 6, 'M'
    elif n < 10**12: k, units = 9, 'G'
    elif n < 10**15: k, units = 12, 'T'
    elif n < 10**18: k, units = 15, 'P'
    else: return sign + text
    main = text[:-k]
    if digits:
        return sign + main + '.' + text[-k:][:digits] + units
    return sign + main + units

def dec(n):
    """Insert thousands separators into a number's decimal representation.

        134       -> '134'
        13402     -> '13,402'
        134020.7  -> '134,020.7'
        -1234567  -> '-1,234,567'

    Works for any magnitude and keeps the sign of values between -1 and 0
    (e.g. -0.5 -> '-0.5').  The fractional part is taken verbatim from
    repr(n).  Values whose repr has no plain integer part (exponent
    notation like 1e+16, inf, nan) are returned as repr(n).
    """
    text = repr(n)
    if text[-1:] == 'L':
        # Python 2 long literal suffix
        text = text[:-1]
    sign = ''
    if text[:1] == '-':
        sign, text = '-', text[1:]
    parts = text.split('.')
    int_part = parts[0]
    if len(parts) == 2: decimal_part = '.' + parts[1]
    else: decimal_part = ''
    if not int_part.isdigit():
        return repr(n)
    # peel off groups of three digits from the right
    groups = []
    while len(int_part) > 3:
        groups.insert(0, int_part[-3:])
        int_part = int_part[:-3]
    groups.insert(0, int_part)
    return sign + ','.join(groups) + decimal_part

def safe_div(num, den, default):
    """Return num/den, or `default` when `den` is zero (or otherwise false).

    Uses the `/` operator unchanged, so two ints give Python 2 floor
    division.
    """
    if not den:
        return default
    return num / den

def safe_float(num, den, format='%2.2f'):
    """Format num/den as a float using `format`.

    Returns 'Inf' when only the denominator is zero, and '?' when both
    are zero (or otherwise false).
    """
    if not den:
        if num:
            return 'Inf'
        return '?'
    ratio = float(num) / float(den)
    return format % ratio

def safe_percentile(num, den, format='%2.1f%%'):
    """Format 100*num/den as a percentage using `format`.

    Returns 'Inf' when only the denominator is zero, and '?' when both
    are zero (or otherwise false).
    """
    if den:
        return format % (100 * float(num) / float(den))
    if num:
        return 'Inf'
    return '?'

def safe_min(list, default):
    """Return min(list), or `default` when `list` is empty (or false)."""
    if list:
        return min(list)
    return default

def safe_max(list, default):
    """Return max(list), or `default` when `list` is empty (or false)."""
    if list:
        return max(list)
    return default

def perc(num, den, format='%s/%s (%s)'):
    """Format a count and its percentage, e.g. '3/4 (75.0%)', using
    safe_percentile() for the percentage."""
    percentage = safe_percentile(num, den)
    return format % (num, den, percentage)

def display_within_range(i, increment, guessmax, format='%s-%s'):
    """Format the bucket from within_range(i, increment), e.g. '0100-0199'.

    Both bounds are zero-padded (string.zfill) to as many characters as
    repr(guessmax), so names for buckets up to guessmax sort correctly.
    """
    low, high = within_range(i, increment)
    width = len(repr(guessmax))
    return format % (string.zfill(low, width), string.zfill(high, width))

def within_range(i, increment):
    """Return the inclusive (low, high) bounds of the bucket holding `i`.

    Buckets start at multiples of `increment`.  For example, when
    splitting directories every 100 contigs, 53 -> (0, 99) and
    153 -> (100, 199).
    """
    # `//` is explicit floor division: identical to Python 2's int `/`,
    # but also correct for float arguments and under true division.
    low = (i // increment) * increment
    high = low + increment - 1
    return low, high

def display_fraction(num, den):
    """Return num/den as a short string, reduced to lowest terms.

    Whole numbers are shown without a denominator.  Reduced fractions
    with denominator 2 or 4 are shown in decimal form, unless the
    numerator is 1:
        (6, 3) -> '2'      (2, 4) -> '1/2'    (9, 4) -> '2.25'
        (3, 4) -> '0.75'   (7, 3) -> '7/3'    (-3, 2) -> '-1.5'
    A zero numerator gives '0'; a zero denominator gives '<num>/0'.
    The sign is applied to the whole result, so negative fractions
    mirror positive ones.
    """
    if den < 0:
        # keep the sign on the numerator so the rules below see den > 0
        num, den = -num, -den
    if num == 0:
        return '0'
    if den == 0:
        return str(num) + '/0'
    # reduce to lowest terms (Euclid's algorithm on the magnitudes)
    a, b = abs(num), den
    while b:
        a, b = b, a % b
    num, den = num // a, den // a
    if num < 0: sign = '-'
    else: sign = ''
    num = abs(num)
    if den == 1:
        return sign + str(num)
    if num == 1:
        return sign + '1/' + str(den)
    whole = num // den
    if den == 2:
        return sign + str(whole) + '.5'
    if den == 4:
        # num is odd after reduction, so num % 4 is 1 or 3
        if num % 4 == 1: return sign + str(whole) + '.25'
        return sign + str(whole) + '.75'
    return sign + str(num) + '/' + str(den)

def lower_bound_mod(i, k):
    """Return the largest multiple of k that is <= i (i itself if k divides it).

    e.g. lower_bound_mod(153, 100) -> 100, lower_bound_mod(200, 100) -> 200
    """
    # `//` floors explicitly: same as Python 2 int `/`, and also correct
    # for floats and under true division.
    return k * (i // k)

def upper_bound_mod(i, k):
    """Return the smallest multiple of k that is strictly greater than i.

    e.g. upper_bound_mod(153, 100) -> 200, upper_bound_mod(200, 100) -> 300
    """
    # `//` floors explicitly: same as Python 2 int `/`, and also correct
    # for floats and under true division.
    return k * (i // k + 1)

# category 3
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       TEXT ART - or drawing with TEXT    ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

############################################################
#
#  Printing intervals based on x coordinates
# 
############################################################

def print_scale(scale = None,  # e.g. {'min':0,'max':100,'tick':10,'width':120,'print_scale':1}
                items = [[(20,40,'Hello'),
                          (60,80,'World!'),
                          (90,94,':o)')],
                         [(40,60,'Niiice'), (90,87,'(o:')],
                         [(0,100,'A hundred')]]):
    """Draw `items` as labelled text intervals on stdout (see write_scale).

    scale  dict of drawing options.  write_scale() stores defaults and
           autoscaled 'min'/'max' into it *in place*, so a fresh dict is
           made per call when none is given; a shared `{}` default would
           accumulate those values across calls.
    items  list of rows; each row is a list of (start, end, text[, char])
    """
    if scale is None:
        scale = {}
    write_scale(sys.stdout, scale, items)

def write_scale(f, scale, items):
    """Draw rows of labelled intervals as ASCII art on stream `f`.

    scale  dict of drawing options, modified in place:
             'width'        columns spanned by [min, max] (set to 120 if
                            missing)
             'min', 'max'   coordinate range.  If 'min' is missing,
                            'autoscale' is set and both are computed from
                            the items.
             'print_scale'  if true, draw a labelled axis first.  Needs
                            'tick' (tick spacing), or a true 'autotick'
                            (with optional 'numticks', default 5) to
                            derive it.
    items  list of rows; each row is a list of (start, end, text) or
           (start, end, text, char) tuples.  An interval is drawn as
           |====> (or <====| when start > end), with `text` on the line
           above, just after the start column.

    Each row produces two output lines.  Characters that would land
    outside a line are dropped.  With empty `items`, a message is printed
    to stdout and nothing is written to `f`.
    """
    if not items:
        print("Write_scale: Nothing to display")
        return

    if 'width' not in scale: scale['width'] = 120
    if 'min' not in scale: scale['autoscale'] = 1

    def put(line, i, ch):
        # Only write in-range positions: a negative index would silently
        # wrap around to the end of the line.
        if 0 <= i < len(line):
            line[i] = ch

    def add_range(line1, line2, start, end, text, char='=', scale=scale):
        def convert_coord(coord, scale):
            # `or 1` avoids dividing by zero when max == min (e.g. a single
            # autoscaled point)
            span = (scale['max'] - scale['min']) or 1
            return int(float(float(coord-scale['min']) * scale['width']) / span)
        x1, x2 = convert_coord(start, scale), convert_coord(end, scale)
        if x1 > x2: x1, x2, up = x2, x1, 0
        else: up = 1
        for i in range(x1, x2):
            put(line2, i, char)
        if up: start_char, end_char = '|', '>'
        else: start_char, end_char = '<', '|'
        put(line2, x2, end_char)
        put(line2, x1, start_char)
        for i in range(0, len(text)):
            put(line1, x1+1+i, text[i])

    # autoscale if you have to
    if scale.get('autoscale'):
        scale['min'] = min([min([min(a[0], a[1]) for a in row]) for row in items])
        scale['max'] = max([max([max(a[0], a[1]) for a in row]) for row in items])
    if scale.get('print_scale') and scale.get('autotick'):
        num_ticks = scale.get('numticks', 5)
        tick = int((scale['max'] - scale['min']) / num_ticks)
        # Round down to one significant digit (37 -> 30).  str() avoids the
        # Python 2 long 'L' suffix and float reprs like '12.5'; a tick of 0
        # would make range() below fail, so use at least 1.
        tick_text = str(max(tick, 1))
        scale['tick'] = int(tick_text[0] + '0' * (len(tick_text) - 1))

    # Print a coordinate scale with ticks every tick
    if scale.get('print_scale'):
        tick = scale['tick']
        line1, line2 = [' '] * (scale['width'] + tick), [' '] * (scale['width'] + 1)
        min_tick = (scale['min'] // tick + 1) * tick
        if (min_tick - scale['min']) < tick // 2: min_tick = min_tick + tick
        add_range(line1, line2, scale['min'], min_tick,
                  display_bignum(scale['min'], 0), '_')
        for mark in range(min_tick, scale['max'], tick):
            add_range(line1, line2, mark, mark + tick,
                      display_bignum(mark, 0), '_')
        add_range(line1, line2, scale['max'], scale['max'],
                  display_bignum(scale['max'], 0), '_')
        f.write(''.join(line1)[:scale['width']+10] + '\n')
        f.write(''.join(line2)[:scale['width']+10] + '\n')

    # print every row: text labels on the first line, intervals on the second
    for row in items:
        line1 = [' '] * (scale['width'] + 10)  # 10 leaves room for trailing text
        line2 = [' '] * (scale['width'] + 10)
        for item in row:
            if len(item) > 3: dot = item[3]
            else: dot = '='
            add_range(line1, line2, item[0], item[1], item[2], dot)
        f.write(''.join(line1) + '\n')
        f.write(''.join(line2) + '\n')

############################################################

def make_annotation(seq, orfs, start_offset):
    """Build a text track that marks each ORF's extent under aligned `seq`.

    seq           an aligned sequence string in which '-' marks a gap
    orfs          list of ORF dicts; reads 'start', 'end', 'up', and
                  'gene' (falling back to 'name' when 'gene' is false)
    start_offset  subtracted from every ORF coordinate

    ORFs with a true 'up' are drawn as 'S->...E', the others as
    'E-...<-S'.  Coordinates are clipped below at 0 and above at the
    ungapped length of `seq`, then mapped to alignment columns with
    coords_seq2ali().  Prints two progress lines to stdout and returns
    the string built by make_annotation_string() over len(seq) blanks.
    """
    gapless_length = len(seq) - seq.count('-')

    print("Making annotation on sequence of length %s" % gapless_length)
    print("Start offset is %s and coords are %s" % (
        start_offset,
        map(lambda s, e: '%s-%s' % (s, e), cget(orfs, 'start'), cget(orfs, 'end'))))

    names = [o['gene'] or o['name'] for o in orfs]
    starts, ends = [], []
    for orf in orfs:
        first = orf['start'] - start_offset
        # up-strand ORFs extend one past 'end', the others one before it
        last = orf['end'] - start_offset
        if orf['up']: last = last + 1
        else: last = last - 1
        if first < 0: first = 0
        if last > gapless_length: last = gapless_length
        starts.append(first)
        ends.append(last)
    ups = cget(orfs, 'up')

    coords = coords_seq2ali(seq, flatten([starts, ends]))
    half = len(coords) // 2
    starts, ends = coords[:half], coords[half:]

    annotation_list = []
    for start, end, up, name in map(None, starts, ends, ups, names):
        if up:
            annotation_list.append((start, end, name, 'S->', '=', 'E'))
        else:
            annotation_list.append((start, end, name, 'E-', '=', '<-S'))

    return make_annotation_string(' ' * len(seq), annotation_list)

def make_annotation_string(str, annotation_list):
    """Apply every annotation to a character copy of the template `str`
    and return the joined result.

    annotation_list holds tuples in the order add_annotation_string()
    unpacks them: (start, end, name, first_char, middle_char, last_char).
    The parameter name `str` shadows the builtin and is kept for callers.
    """
    chars = list(str)
    for annotation in annotation_list:
        add_annotation_string(chars, annotation)
    return ''.join(chars)

def add_annotation_string(str, a):
    """Write one annotation into the character list `str` via sadd().

    a is (start, end, name, first_char, middle_char, last_char).  The
    label is first_char + name (name cut to end-start characters), passed
    through sfill(label, end-start, middle_char).  When last_char is
    non-empty, the label's final character is dropped and last_char is
    appended.  The result is handed to sadd() at position `start`.
    """
    start, end, name, first_char, middle_char, last_char = a
    width = end - start
    label = sfill(first_char + name[:width], width, middle_char)
    if last_char:
        label = label[:-1] + last_char
    sadd(str, label, start)

############################################################
#
# 2D plots
#
############################################################

def test_multi_plot():
    """Demo for multi_plot(): four plots of x and x*x (in each axis
    combination) printed side by side with an empty separator."""
    xs = range(0, 100)
    squares = [v * v for v in xs]
    data_sets = [[xs, squares, 20, 40],
                 [squares, xs, 30, 20],
                 [xs, xs, 20, 30],
                 [squares, squares, 10, 10]]
    multi_plot(data_sets, '')

def weird_plot():
    """Demo: print three plot columns side by side on stdout.

    The middle column combines x vs x*x and its transpose (via flatten()).
    It is flanked by plots of the concatenated cos(x/10) and sin(x/10)
    curves against x+x (reversed on the left).
    """
    x = list(range(0, 100))
    y = [v * v for v in x]
    # (unused sin/cos(x/4) curves that were computed and discarded are gone)
    cos_curve = [math.cos(v / 10.0) for v in x]
    sin_curve = [math.sin(v / 10.0) for v in x]
    lines1 = quick_plot_lines(x, y, 18, 18)
    lines2 = quick_plot_lines(y, x, 18, 18)
    lines3 = quick_plot_lines(cos_curve + sin_curve, reverse_list(x + x), 40, 40)
    lines4 = quick_plot_lines(sin_curve + cos_curve, x + x, 40, 40)

    lines = multi_plot_combine_lines([lines3, flatten([lines1, lines2]), lines4], ' ')
    # explicit loop: map() used only for side effects does nothing once
    # map becomes lazy (Python 3)
    for line in lines:
        sys.stdout.write(line + '\n')

def quick_plot(xs,ys,width=60,height=30,logx=0,logy=0,minx=None,maxx=None,miny=None,maxy=None,f=None):
    """Draw a text scatter plot of (xs, ys) on stream `f`.

    Points outside the optional [minx, maxx] / [miny, maxy] limits are
    dropped first; xs and ys are paired positionally.  width, height,
    logx and logy are passed to quick_plot_lines().

    `f` defaults to sys.stdout, looked up at call time.  A signature
    default of sys.stdout would be frozen at import time and would ignore
    later redirection.
    """
    if f is None:
        f = sys.stdout

    if minx is not None or maxx is not None or miny is not None or maxy is not None:
        xys = list(zip(xs, ys))
        if minx is not None: xys = [xy for xy in xys if xy[0] >= minx]
        if maxx is not None: xys = [xy for xy in xys if xy[0] <= maxx]
        if miny is not None: xys = [xy for xy in xys if xy[1] >= miny]
        if maxy is not None: xys = [xy for xy in xys if xy[1] <= maxy]
        xs, ys = cget(xys, 0), cget(xys, 1)

    for line in quick_plot_lines(xs, ys, width, height, logx, logy):
        f.write(line + '\n')

def multi_plot(x_y_w_hs,separator=' || ',f=None):
    """Plot several (xs, ys, width, height) data sets side by side on `f`.

    Plots are rendered by multi_plot_lines(), with columns joined by
    `separator`.  `f` defaults to sys.stdout, looked up at call time.  A
    signature default of sys.stdout would be frozen at import time and
    would ignore later redirection.
    """
    if f is None:
        f = sys.stdout
    for line in multi_plot_lines(x_y_w_hs, separator):
        f.write(line + '\n')

def multi_plot_lines(x_y_w_hs,separator):
    """Render each (xs, ys, width, height) data set with quick_plot_lines()
    and lay the plots out side by side, columns joined by `separator`.
    Returns the list of output lines."""
    plots = [quick_plot_lines(xs, ys, w, h) for xs, ys, w, h in x_y_w_hs]
    return multi_plot_combine_lines(plots, separator)

def multi_plot_combine_lines(all_lines, separator): 
    """Lay several blocks of text lines out side by side.

    all_lines is a list of blocks, each a list of strings.  Every block is
    left-justified to the width of its own longest line, and shorter
    blocks are padded with blank cells at the bottom.  Returns one string
    per output row, with the cells joined by `separator`.
    """
    formats = ['%-' + repr(max(map(len, block))) + 's' for block in all_lines]
    height = max(map(len, all_lines))

    combined = []
    for row in range(0, height):
        cells = []
        for block, fmt in zip(all_lines, formats):
            if row < len(block): text = block[row]
            else: text = ''
            cells.append(fmt % text)
        combined.append(separator.join(cells))
    return combined
        
def quick_plot_lines(xs,ys,width,height,logx=0,logy=0):
    """Render a scatter plot of (xs, ys) as a list of text lines.

    Each cell of the width x height grid shows how many points fall in it
    as a 'chint' (see int2chint: 1-9, a-z, A-Z, then '#' once a cell
    saturates).  With logx/logy the natural log of that axis is plotted.
    The y axis is labelled on the left with max(ys), min(ys) and their
    difference (dy=).  The x axis is labelled underneath with min(xs),
    max(xs) and their difference (dx=).  All labels use the original,
    non-log values.
    """
    xs_orig, ys_orig = xs, ys

    if logx: xs = [math.log(v) for v in xs]
    if logy: ys = [math.log(v) for v in ys]

    # part I - constructing the empty bitmap (one list of chars per row)
    rows = []
    for row in range(0, height):
        rows.append([' '] * width)
    # part II - finding boundaries of data (padded by 1), and scaling factors
    minx, maxx = min(xs) - 1, max(xs) + 1
    miny, maxy = min(ys) - 1, max(ys) + 1
    scalex = (maxx - minx) / float(width - 1)
    scaley = (maxy - miny) / float(height - 1)

    # part III - filling in the appropriate pixels
    for x, y in zip(xs, ys):
        # IIIa - converting (x,y) to (i,j)
        i = int((x - minx) / scalex)
        j = int((y - miny) / scaley)
        assert(i >= 0 and j >= 0)

        # IIIb - incrementing point count
        try:
            rows[j][i] = chint_inc(rows[j][i])
        except (IndexError, ValueError):
            # IndexError: defensive, for a point outside the grid.
            # ValueError: the cell already holds the saturated '#', which
            # chint2int() cannot parse, so the cell stays '#'.
            pass

    rows.reverse()

    # part IV - preparing the legend on the left
    legend = [''] * len(rows)
    legend[0] = repr(max(ys_orig))
    legend[-1] = repr(min(ys_orig))
    legend[len(legend) // 2 - 1] = 'dy='
    legend[len(legend) // 2] = repr(max(ys_orig) - min(ys_orig))
    legwidth = max(map(len, legend))

    # part V - displaying the filled in matrix and the x-axis labels
    lines = []
    lines.append(' ' * legwidth + '^')
    for i in range(0, len(rows)):
        lines.append(legend[i].rjust(legwidth) + '|' + ''.join(rows[i]))
    lines.append(' ' * legwidth + '+%s>' % ('-' * width))
    s = ' ' * width
    # left label: the data minimum in original units, matching the
    # max(xs_orig) label on the right (not the padded/log plot edge minx)
    s = sinsert(s, repr(min(xs_orig)), 0)
    s = sinsert(s, repr(max(xs_orig)), width - len(repr(max(xs_orig))))
    sub = 'dx=' + repr(max(xs_orig) - min(xs_orig))
    s = sinsert(s, sub, width // 2 - len(sub) // 2)
    lines.append(' ' * legwidth + ' ' + s)
    return lines

def int2chint(integer):
    """Convert an integer count to a one-character 'chint'.

    A chint represents 0..61 in a single character: 0-9 as '0'-'9',
    10-35 as 'a'-'z' and 36-61 as 'A'-'Z'.  Values below 0 give '.' and
    values above 61 give '#'; chint2int() cannot convert these two
    markers back.
    """
    if integer < 0:
        return '.'
    if integer > 61:
        return '#'
    if integer < 10:
        # str(), not repr()/backticks: a Python 2 long would give e.g. '5L'
        return str(integer)
    if integer < 36:
        return chr(ord('a') + integer - 10)
    return chr(ord('A') + integer - 36)

def chint2int(chint):
    """Convert a one-character 'chint' back to its integer value.

    ' ' (an empty plot cell) is 0, '0'-'9' are 0-9, 'a'-'z' are 10-35
    and 'A'-'Z' are 36-61.  Any other character goes to int(), which
    raises ValueError for non-digits such as '#' or '.'.
    """
    if chint == ' ':
        return 0
    code = ord(chint)
    if ord('A') <= code <= ord('Z'):
        return code - ord('A') + 36
    if ord('a') <= code <= ord('z'):
        return code - ord('a') + 10
    return int(chint)

def chint_inc(chint):
    """Return the chint one greater than `chint` (e.g. ' ' -> '1')."""
    value = chint2int(chint)
    return int2chint(value + 1)

def test_chint2int():
    """Print every chint for 0..61 and assert that each one survives a
    round trip through int2chint() and chint2int()."""
    for value in range(62):
        chint = int2chint(value)
        print('%2s = %s' % (value, chint))
        assert(chint2int(chint) == value)

def roman2int(roman):
    """Map a Roman numeral 'I'..'XVI' to 1..16, and 'Mito' to 17.

    Anything else is returned unchanged.
    """
    numerals = ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII',
                'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI',
                'Mito']
    table = {}
    for position in range(len(numerals)):
        table[numerals[position]] = position + 1
    if roman not in table.keys():
        return roman
    return table[roman]

def int2roman(int):
    """Return the Roman numeral for `int`: 0 -> 'zero', 1..16 -> 'I'..'XVI'.

    Any other value, including negative numbers, is returned unchanged.
    The parameter name `int` shadows the builtin and is kept for callers.
    """
    chr_names = ['zero', 'I', 'II', 'III', 'IV', 'V',
                 'VI', 'VII', 'VIII', 'IX', 'X',
                 'XI', 'XII', 'XIII', 'XIV', 'XV',
                 'XVI']
    # lower bound too: a negative index would silently pick from the end
    # of the list (e.g. -1 -> 'XVI')
    if 0 <= int < len(chr_names):
        return chr_names[int]
    return int

def chr2int(chr):
    """Map a capital letter 'A'..'Q' to 1..17 ('A' -> 1, 'Q' -> 17).

    The range is checked with assert, so the check disappears under -O.
    Note that int2chr() only goes up to 16 ('P').
    """
    assert('A' <= chr <= 'Q')
    return 1 + (ord(chr) - ord('A'))
    
def int2chr(int):
    """Map 1..16 to a capital letter 'A'..'P' (1 -> 'A').

    The range is checked with assert, so the check disappears under -O.
    """
    assert(1 <= int <= 16)
    letter_code = ord('A') - 1 + int
    return chr(letter_code)
    

############################################################
#
# HISTOGRAMS
#
############################################################

def quick_histogram(values, num_bins = 20, width=80, height=20, logx=0,minv=None,maxv=None):
    """Print a text histogram of `values` split into num_bins bins.

    Each bar segment is drawn as '[' + 'o'*(width//num_bins - 3) + ']'.
    logx/minv/maxv are forwarded to bin_values(), and `height` goes into
    the print_histogram() scale.  Bin labels are 'lo-hi', re-written with
    display_bignum() when they are not shorter than a bar segment.
    """
    bar = '[' + 'o' * (width // num_bins - 3) + ']'

    def display_range(v1, v2, maxlen=len(bar)):
        label = '%s-%s' % (v1, v2)
        if len(label) >= maxlen:
            label = '%s-%s' % (display_bignum(v1, 0), display_bignum(v2, 0))
        return label

    binned, binnames = bin_values(values, num_bins=num_bins,
                                  display_range=display_range, logx=logx,
                                  minv=minv, maxv=maxv)
    counts = [len(members) for members in binned]

    histogram_scale = {'height': height,
                       'on': bar,
                       'space': ' ',
                       'display_bottom': 'name',
                       'display_top': 'value',
                       'autoscale': 1,
                       'front_gap': ''}
    print_histogram(counts, binnames, scale = histogram_scale)

def bin_values(values, do_not_set=None, num_bins=None, increment=None, f=None,
               display_range=None, logx=0, logincrement=None, minv=None, maxv=None, name_tag='-'):
    """Split `values` into consecutive bins; return (binvalues, binnames).

    Bin sizing (num_bins cannot be combined with increment/logincrement):
      num_bins       number of equal-width bins between minv and maxv
      increment      bin width for linear bins
      logincrement   bin width in natural-log space (with logx)
    With logx the boundaries are equally spaced in log space, so minv and
    maxv must be positive.

    values         numbers to bin, or (when `f` is given) arbitrary objects
                   that f() maps to numbers
    f              optional key function; binvalues then holds the
                   original objects instead of the numbers
    minv, maxv     range to bin; each defaults to the min/max of the
                   (mapped) values, or 0 when there are none
    display_range  optional function (low, high) -> bin name; otherwise
                   names are '<low><name_tag><high>', e.g. '3.7-4.2',
                   parseable with "[-]?\\d+-[-]?\\d+"
    do_not_set     guard: catches later arguments passed positionally

    binvalues[i] lists the elements whose value v has low <= v < high for
    bin i, and binnames[i] names that bin's boundaries.  The last boundary
    is raised by .0101 so values equal to maxv land in the last bin.
    Values outside all bins are not returned.  If no boundaries can be
    built, a single bin holding every element is returned, named by
    (minv, maxv).

    Raises TypeError when arguments are passed positionally, and
    ValueError for an invalid num_bins / increment / logincrement
    combination.
    """
    # Real exception classes: string exceptions are rejected with a
    # TypeError by Python 2.6+ and never carried these messages.
    if do_not_set:
        raise TypeError("Please explicitly name the arguments to bin_values")

    if num_bins and (increment or logincrement):
        raise ValueError("cannot specify both num_bins and increment: %r"
                         % ((num_bins, increment),))
    elif not num_bins and not increment and not logx:
        raise ValueError("Must specify either num_bins or increment: %r"
                         % ((num_bins, increment),))
    elif logx and not num_bins and not logincrement:
        raise ValueError("For LOG histogram, must specify either num_bins "
                         "or logincrement: %r" % ((num_bins, logincrement),))

    # Apply the key function before computing the default range, so that
    # minv/maxv describe the same numbers that are binned (they used to be
    # taken from the raw objects).
    if f:
        objects = values
        values = [f(o) for o in objects]
    else:
        objects = None

    if minv is None:
        if values: minv = min(values)
        else: minv = 0
    if maxv is None:
        if values: maxv = max(values)
        else: maxv = 0

    # if increment is not specified, derive it from the number of bins
    if logx:
        minlog = math.log(minv)
        maxlog = math.log(maxv)
        if num_bins:
            logincrement = float(maxlog - minlog) / float(num_bins)
        bins = [safe_exp(v) for v in float_range(minlog, maxlog, logincrement)]
    else:
        if num_bins:
            increment = float(maxv - minv) / float(num_bins)
        bins = list(float_range(minv, maxv, increment))
    # make sure the last boundary reaches maxv
    if bins and abs(bins[-1] - maxv) > .01:
        bins.append(maxv)

    if objects is not None: members = objects
    else: members = values

    if not bins:
        # one bin for everything; named by (minv, maxv) in both branches,
        # which also avoids min()/max() of an empty list
        if display_range: names = [display_range(minv, maxv)]
        else: names = ['%s%s%s' % (minv, name_tag, maxv)]
        return [members], names

    # first construct the names
    binnames = []
    for low, high in zip(bins[:-1], bins[1:]):
        if display_range: binnames.append(display_range(low, high))
        else: binnames.append('%s%s%s' % (low, name_tag, high))

    # then widen the last bin, so that datapoints exactly equal to the
    # maximum are included
    bins[-1] = bins[-1] + .0101

    # and construct the bins (pairs built once, not once per bin)
    pairs = list(zip(values, members))
    binvalues = []
    for low, high in zip(bins[:-1], bins[1:]):
        binvalues.append([m for v, m in pairs if low <= v < high])

    return binvalues, binnames

def test_quick_bin():
    """Smoke test: bin a small sample, then show the cumulative counts.

    Pretty-prints the quick_bin result, then multi_count_lt and
    multi_count_gte for the same cutoffs.
    """
    sample = [3, 1, 2, 1, 4, 5, 6, 1, 2, 3, 18, -40, -32, 2, 500]
    cutoffs = [0, 2, 4, 6, 8, 10, 510]

    pp(quick_bin(sample, cutoffs), 1)
    for counter in (multi_count_lt, multi_count_gte):
        pp(counter(sample, cutoffs))

def multi_count_lt(values, cutoffs):
    """For every cutoff, count how many of `values` are strictly below it.

    Returns a list parallel to `cutoffs` (looked up with mget).
    quick_bin is called with ignore_low_values=0, so values below the
    smallest cutoff stay in its extra low bin and are counted for every
    cutoff.
    """
    bins = quick_bin(values, cutoffs, ignore_low_values=0)
    keys = my_sort(bins.keys())

    # cum_count[k] = number of values in the bins before keys[k]
    cum_count = [0]
    for key in keys:
        cum_count.append(cum_count[-1] + len(bins[key]))

    # Build the lookup once, after every count is known. This also keeps
    # `res` defined when quick_bin returns no bins at all.
    res = items2dic(list(zip(keys, cum_count[:-1])))
    return mget(res, cutoffs)

def multi_count_gte(values, cutoffs):
    """For every cutoff, count how many of `values` are greater or equal to it.

    Returns a list parallel to `cutoffs` (looked up with mget).
    Bins are walked from the highest key down, so each key's count
    includes its own bin and every bin above it.
    """
    bins = quick_bin(values, cutoffs, ignore_low_values=0)
    keys = my_sort_rev(bins.keys())

    # cum_count[k+1] = number of values in keys[0..k] (highest first)
    cum_count = [0]
    for key in keys:
        cum_count.append(cum_count[-1] + len(bins[key]))

    # Build the lookup once, after every count is known. This also keeps
    # `res` defined when quick_bin returns no bins at all.
    res = items2dic(list(zip(keys, cum_count[1:])))
    return mget(res, cutoffs)

def quick_bin(values, the_range, ignore_low_values=0):
    """Distribute `values` into bins whose lower edges are `the_range`.

    the_range must be strictly increasing (asserted). Each value lands in
    the bin of the largest edge <= value; the last bin is open-ended.
    Values below the_range[0] are collected under an extra key equal to
    min(values). That key is dropped when ignore_low_values is true.

    Example:
      values    = [3,1,2,1,4,5,6,1,0,2,3,18,-40,-32,2,500]
      the_range = [0,5,10]
      result    = {-40: [-40,-32], 0: [0,1,1,1,2,2,2,3,3,4],
                   5: [5,6], 10: [18,500]}
    Every bin list is sorted.
    """
    for lower, upper in zip(the_range[:-1], the_range[1:]):
        assert(lower < upper)
    ordered = my_sort(values)

    bins = {}
    for edge in the_range:
        bins[edge] = []

    if not ordered:
        return bins

    low_key_added = 0
    lowest = min(ordered)
    edges = the_range
    if lowest < edges[0]:
        # open an extra bin for everything below the first edge
        edges = [lowest] + edges
        low_key_added = 1
        bins[lowest] = []

    # each bin takes values up to (not including) the next edge; the last
    # bin's bound is beyond the largest value so nothing is left over
    upper_bounds = edges[1:] + [max(ordered) + 1]
    pos = 0
    for edge, bound in zip(edges, upper_bounds):
        while pos < len(ordered) and ordered[pos] < bound:
            bins[edge].append(ordered[pos])
            pos += 1

    assert(my_sort(flatten(bins.values())) == ordered)

    if low_key_added and ignore_low_values:
        del bins[lowest]

    return bins
        
def print_histogram(values, names, scale=None):
    """Draw a text histogram of `values` (labelled by `names`) on standard output."""
    out = sys.stdout
    display_histogram(out, values, names, scale)

def display_histogram(f, values, names, scale_in=None):
    """Draw a vertical text histogram of `values` on the stream `f`.

    Each value becomes a column len(scale['on']) characters wide. Filled
    cells hold scale['on'], the top of each bar is drawn with '_', and the
    cell just above it holds the optional top label. An optional line of
    bottom labels follows the bars.

    names:    per-column labels, used when display_top / display_bottom
              is 'name' (assumed to have at least len(values) entries).
    scale_in: dict with the keys
                'height', 'on', 'space' (column separator),
                'display_top' / 'display_bottom' ('none', 'name' or 'value'),
                'autoscale', 'min', 'max', 'front_gap' (line prefix).
              When falsy, the defaults below are used. When given, the
              SAME dict is used, so with 'autoscale' set its 'min'/'max'
              entries are overwritten.

    An empty `values` prints "Empty histogram" to standard output (not
    to `f`) and returns.
    NOTE(review): the error paths raise string exceptions, which
    Python >= 2.6 rejects with a TypeError instead.
    """

    if not values: 
        print "Empty histogram"
        return

    if not scale_in:
        scale = {'height': 6,
                 'on': '|',
                 'space': '',
                 'display_top': 'none',
                 'display_bottom': 'name',
                 'autoscale': 1,
                 'min': 0,
                 'max': 100,
                 'front_gap': ''}
    else: scale=scale_in

    # 0. Input choice dictates what to print at the top and bottom
    def values2names(values, maxlen):
        # repr() of each value, or display_bignum(v,0) when the repr is
        # wider than maxlen (the column width)
        names = []
        for v in values:
            n = `v`
            if len(n) > maxlen:
                n = display_bignum(v,0)
            names.append(n)
        return names
        
    # only 'none' is compared case-insensitively; 'name'/'value' must match exactly
    if not scale.has_key('display_top') or string.lower(scale['display_top']) == 'none':
        names_top = ['']*len(values)
    elif scale['display_top'] == 'name': names_top = names[:]
    elif scale['display_top'] == 'value': names_top = values2names(values,len(scale['on']))#map(lambda v: `v`,values)
    else: raise "scale['display_top'] not one of name,value,none: ", scale['display_top']

    if not scale.has_key('display_bottom') or string.lower(scale['display_bottom']) == 'none':
        names_bottom = ['']*len(values)
    elif scale['display_bottom'] == 'name': names_bottom = names[:]
    elif scale['display_bottom'] == 'value': names_bottom = values2names(values,len(scale['on']))#map(lambda v: `v`,values)
    else: raise "scale['display_bottom'] not one of name,value,none: ", scale['display_bottom']

    # 1. first construct that matrix of characters
    # height+2 rows of blank cells: room for a full-height bar plus the
    # '_' cap row and the top-label row. rows[0] is the BOTTOM line (the
    # list is reversed before printing).
    rows = []
    for h in range(0,scale['height']+2):
        row = []
        for v in values:
            row.append(' '*len(scale['on']))
        rows.append(row)
    if scale['autoscale']:
        # min is lowered by one so (v - min) >= 1 for every value
        scale['min'] = min(values)-1
        scale['max'] = max(values)
    if scale['min'] > min(values): raise "Min specified too high",(values,scale['min'])
    if scale['max'] < max(values): raise "Max specified too low",(values,scale['max'])

    # 2. then insert all the bars of the histogram
    for j in range(0,len(values)):
        v = values[j]
        # do the appropriate rescaling
        # top = bar height in rows, 0..height (truncated); when max == min
        # the division fails and the bar is drawn at full height
        bottom = 0
        try:
            top = int((float(v-scale['min'])/float(scale['max']-scale['min']))*scale['height'])
        except ZeroDivisionError:
            top = scale['height']
        #print "%s -> Top %s"%(v,top)
        
        for i in range(bottom, top):
            rows[i][j] = scale['on']
        # cap the bar, then put the (centered, clipped) top label above it
        rows[top][j] = '_'*len(scale['on'])
        rows[top+1][j] = string.center(names_top[j],len(scale['on']))[:len(scale['on'])]

    # 3. finally output the histogram
    rows.reverse()
    for row in rows:
        f.write(scale['front_gap']+string.join(row, scale['space'])+'\n')
    # bottom labels, each centered and clipped to the column width
    f.write(scale['front_gap']+string.join(map(lambda name_bottom,scale=scale:
                                               string.center(name_bottom,
                                                             len(scale['on']))[:len(scale['on'])],
                                               names_bottom),
                                           scale['space'])+'\n')

def print_wrap_histogram(values, names, w, scale=None):
    """Like display_wrap_histogram, but always writes to standard output."""
    out = sys.stdout
    display_wrap_histogram(out, values, names, w, scale)

def display_wrap_histogram(f, values, names, w, scale=None):
    """Write a histogram of `values` to `f`, wrapped into chunks of at most `w` bars.

    Every chunk is drawn with display_histogram using one shared vertical
    scale, so bar heights are comparable across chunks. When
    scale['autoscale'] is set, min/max are computed once over ALL values
    (min lowered by one, as display_histogram's own autoscale does) and
    autoscaling is then switched off for the individual chunks.

    scale: same keys as for display_histogram. When None or empty, the
           same defaults display_histogram uses are applied. The caller's
           dict is never modified.
    """
    if not scale:
        scale = {'height': 6,
                 'on': '|',
                 'space': '',
                 'display_top': 'none',
                 'display_bottom': 'name',
                 'autoscale': 1,
                 'min': 0,
                 'max': 100,
                 'front_gap': ''}
    else:
        # work on a copy: the autoscale step below rewrites min/max/autoscale,
        # which would otherwise leak into the caller's later calls
        scale = scale.copy()

    if not values:
        # let display_histogram report the empty input (min/max would fail)
        display_histogram(f, values, names, scale)
        return

    if scale['autoscale']:
        scale['min'] = min(values) - 1
        scale['max'] = max(values)
        scale['autoscale'] = 0
    for i in range(0, len(values), w):
        display_histogram(f, values[i:i+w], names[i:i+w], scale)

def test_histogram():
    """Visual check: draw seven bars labelled a-g on standard output."""
    settings = {
        'height': 10,
        'on': '||',
        'space': '',
        'autoscale': 1,
        'front_gap': '',
        'display_top': 'none',
        'display_bottom': 'name',
    }
    heights = [100, 2, 300, 400, 200, 300, 1000]
    display_wrap_histogram(sys.stdout, heights, 'abcdefg', 7, scale=settings)

############################################################
# guess a distribution center and spread
#  
#  def guess_center(values, divisions = 20):
#  
#      bins = bin_values(values, num_bins=20)
#      i = 0
#      while i<len(values):
#          if 5*len(bins[i]) <= len(values):
#              # this can erase a bin, even in the middle of the distribution
#              del(bins[i])
#          else:
#              i = i+1
#      values = flatten(bins)
#  
#      return guess_center(values, divisions)

def guess_center(values, do_not_set=None, num_bins=None, increment=None, fraction=5, debug=1):
    """Estimate the center of the dominant peak of `values`.

    The values are binned with bin_values using EITHER num_bins OR
    increment; exactly one must be given, by keyword. The search starts
    from the fullest bin. The window then grows one bin at a time, first
    to the left and then to the right, while
    fraction * len(next bin) >= len(current edge bin).
    Returns avg() of every value inside the final window. A single value
    is returned unchanged.

    do_not_set: guard against positional use of num_bins; raises when set.
    debug:      when true (the default) progress messages and
                quick_histogram plots are printed to standard output.

    NOTE(review): errors are raised as strings, which Python >= 2.6
    turns into a TypeError.
    """
    if len(values) == 1: return values[0]

    # refuse positional num_bins/increment before doing any work
    if do_not_set: raise "Please name arguments to guess_center"

    if debug:
        print("Guessing center of %s values, with num_bins=%s and increment=%s. Mean %s, Stdev: %s" % (
            len(values), num_bins, increment, avg(values), stdev(values)))

    if num_bins and not increment:
        if debug: quick_histogram(values, num_bins=num_bins)
        bins, binnames = bin_values(values, num_bins=num_bins)
    elif increment and not num_bins:
        if debug: quick_histogram(values, num_bins=20)
        bins, binnames = bin_values(values, increment=increment)
    elif increment and num_bins:
        raise "Do not set both num_bins and increment"
    else:
        raise "Please set either num_bins or increment"

    # start from the bin holding the most values
    sizes = [len(b) for b in bins]
    max_bin = argmax(sizes)
    if debug:
        print("Max bin is %s/%s and it contains %s/%s items" % (max_bin, len(bins), sizes[max_bin], len(values)))
        print("The other bins contain %s" % sizes)

    min_i, max_i = max_bin, max_bin
    # extend to the left while the neighbouring bin is still well populated
    while min_i >= 1 and fraction * sizes[min_i-1] >= sizes[min_i]:
        min_i = min_i - 1
        if debug:
            print("Extending to the left.  Now from %s/%s to %s/%s, there's %s items" % (
                min_i, len(bins), max_i, len(bins), sum(sizes[min_i:max_i+1])))
    # then extend to the right under the same rule
    while max_i < len(bins)-1 and fraction * sizes[max_i+1] >= sizes[max_i]:
        max_i = max_i + 1
        if debug:
            print("Extending to the right.  Now from %s/%s to %s/%s, there's %s items" % (
                min_i, len(bins), max_i, len(bins), sum(sizes[min_i:max_i+1])))

    values_in = flatten(bins[min_i:max_i+1])
    if debug: quick_histogram(values_in, num_bins=20)
    return avg(values_in)

# category 4
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       FILES and SAVING                   ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

def binsave(file, data):
    """Pickle `data` into the file named `file` (cPickle, binary protocol 1).

    The file is opened in binary mode. Protocol 1 is a binary format, and
    a text-mode handle can corrupt it on platforms that translate newlines.
    The handle is closed even if pickling fails.
    """
    f = open(file, 'wb')
    try:
        cPickle.dump(data, f, 1)
    finally:
        f.close()

def binload(file):
    """Unpickle and return the object stored in `file` (as written by binsave).

    The file is opened in binary mode to match binsave's binary protocol.
    The handle is closed even if unpickling fails.
    NOTE: unpickling can execute arbitrary code -- only load trusted files.
    """
    f = open(file, 'rb')
    try:
        return cPickle.load(f)
    finally:
        f.close()
    
############################################################
#
# file stuff
#
############################################################

def quick_system(cmd, stdin):
    """Run the shell command `cmd` with the string `stdin` as its standard input.

    Returns everything the command wrote to standard output. `cmd` goes to
    the shell unchanged, so it must come from a trusted source. The input
    is staged in a private temporary file, which is removed afterwards.
    """
    import tempfile

    # mkstemp creates the file atomically; mktemp only picked a name,
    # leaving a window for another process to create it first
    fd, fn_in = tempfile.mkstemp('.tmp')
    try:
        f_in = os.fdopen(fd, 'w')
        try:
            f_in.write(stdin)
        finally:
            f_in.close()

        pipe = os.popen('%s < %s' % (cmd, fn_in), 'r')
        try:
            stdout = pipe.read()
        finally:
            pipe.close()
    finally:
        os.remove(fn_in)

    return stdout

def decompose_filename(filename):
    """Split `filename` into (directory, root, extension).

    directory keeps its trailing '/' ('' when there is no '/'); a name
    ending in '/' is all directory. The extension starts at the FIRST '.'
    of the base name, e.g. 'dir/a.tar.gz' -> ('dir/', 'a', '.tar.gz').
    """
    slash = filename.rfind('/')
    if slash == -1:
        directory, base = '', filename
    else:
        directory, base = filename[:slash+1], filename[slash+1:]

    dot = base.find('.')
    if dot == -1:
        return directory, base, ''
    return directory, base[:dot], base[dot:]


def common_subdir(filenames):
    """Find the deepest directory shared by all the absolute paths `filenames`.

    Returns (common, short_filenames). `common` ends with '/', and each
    short filename is the matching input with `common` stripped off.
    Empty input returns ('', filenames). A path containing '//' or not
    starting with '/' fails an assertion.
    """
    if not filenames: return '', filenames

    for name in filenames:
        assert(name.find('//') == -1)
    # every path must be absolute
    assert(all_same(['/'] + cget(filenames, 0)))

    common = '/'
    for part in [p for p in filenames[0].split('/') if p]:
        candidate = common + part + '/'
        outsiders = [n for n in filenames if not n.startswith(candidate)]
        if outsiders:
            break
        common = candidate

    short_filenames = []
    for name in filenames:
        assert(name.startswith(common))
        short_filenames.append(name[len(common):])

    return common, short_filenames
    
def change_extension(filename, new_extension):
    """Replace the last extension of `filename` with `new_extension`.

    `new_extension` is appended verbatim, so include the leading '.'
    (e.g. '.txt'). Only a '.' inside the final path component counts, so
    'x.y/z' keeps its name. A name without any extension gets
    `new_extension` appended instead of being erased.
    """
    dot = filename.rfind('.')
    # strip only a real extension: a '.' exists and is in the last component
    if dot != -1 and '/' not in filename[dot+1:]:
        filename = filename[:dot]
    return filename + new_extension

def opens(filename, mode='r'):
    """open() that creates missing parent directories when writing.

    Read modes ('r', 'rb', 'rU', 'r+') behave exactly like open(), and never
    touch the file system. For any other mode, if the first open() fails
    with IOError, every directory on the way to `filename` is created with
    mkdirs and the open is retried once. A second failure propagates.
    """
    if mode[:1] == 'r':
        return open(filename, mode)
    try:
        return open(filename, mode)
    except IOError:
        parent = '/'.join(filename.split('/')[:-1])
        mkdirs(parent)
        return open(filename, mode)

def mkdirs(filename, num_tries=2):
    """Create directory `filename` along with every missing parent.

    Works for absolute and relative paths; components that already exist
    are skipped. An empty path means the current directory and is a no-op.
    If the directory still does not exist afterwards (e.g. a concurrent
    removal), the whole walk is retried until `num_tries` attempts have
    been made in total. After that, OSError is raised.
    """
    if not filename:
        return

    subdirs = filename.split('/')
    for i in range(1, len(subdirs)+1):
        partial = '/'.join(subdirs[:i])
        # '' is the root of an absolute path (or a doubled '/'): nothing to create
        if not partial or os.path.exists(partial):
            continue
        try:
            os.mkdir(partial)
        except OSError:
            # another process may have created it in the meantime
            if not os.path.exists(partial):
                raise

    if not os.path.exists(filename):
        if num_tries > 1:
            mkdirs(filename, num_tries-1)
        else:
            raise OSError("mkdirs: could not create %s" % filename)

def is_directory(filename):
    """True if `filename` itself is a directory. Symlinks are not followed (lstat)."""
    mode = os.lstat(filename).st_mode
    return stat.S_ISDIR(mode)

def filesize(filename):
    """Size in bytes of `filename`; for a symlink, the link itself (lstat)."""
    info = os.lstat(filename)
    return info.st_size

def cdate(filename):
    """time.ctime()-formatted string of `filename`'s st_ctime, taken from lstat."""
    changed = os.lstat(filename)[stat.ST_CTIME]
    return time.ctime(changed)

def non_empty(filename):
    """True if `filename` exists and its (lstat) size is not zero."""
    if not os.path.exists(filename):
        return False
    return filesize(filename) != 0

# file_exists is an alias: "exists" here means exists AND is non-empty
file_exists = non_empty

def touch(filename):
    """Create `filename` as an empty file.

    NOTE: unlike the Unix `touch` command, an existing file is truncated
    to zero length.
    """
    handle = open(filename, 'w')
    handle.close()

#def rm(filename):
#    os.remove(filename)

def empty_file(filename):
    """Truthy if `filename` is missing (returns 1) or has size zero."""
    if os.path.exists(filename):
        return filesize(filename) == 0
    return 1

def exists_but_empty(filename):
    """True if `filename` exists and has size zero."""
    if not os.path.exists(filename):
        return False
    return filesize(filename) == 0

def quote_filename(filename):
    """Backslash-escape '(', ')' and ' ' in `filename` for use on a shell command line.

    Only these three characters are escaped; other shell metacharacters
    pass through unchanged.
    """
    # Each pass only inserts a backslash before the matched character, so
    # later passes cannot re-match what earlier ones produced.
    for char in '() ':
        filename = filename.replace(char, '\\' + char)
    return filename

############################################################

def space_usage(file):
    """Parse a disk-usage listing (e.g. saved `du` output) and sort it by size.

    Each non-blank line of `file` is split on whitespace, and its first
    field is converted to int. Returns the list of split lines, largest
    size first; ties keep their order in the file.
    """
    f = open(file, 'r')
    try:
        raw = f.readlines()
    finally:
        f.close()

    lines = []
    for line in raw:
        fields = line.split()
        if not fields:
            continue  # a blank line has no size field
        fields[0] = int(fields[0])
        lines.append(fields)

    # sort is stable even with reverse=True, so equal sizes keep file order
    lines.sort(key=lambda fields: fields[0], reverse=True)
    return lines

def display_usage(lines): 
    """Print the shallow entries of a space_usage() result.

    Prints the size (formatted with dec) right-aligned in 14 columns,
    then the path. Entries whose path contains more than one '/' are
    skipped.
    """
    for entry in lines:
        path = entry[1]
        if path.count('/') <= 1:
            print('%14s %s' % (dec(entry[0]), path))

############################################################

def split_command_file(filename, increment):
    """Split a file of independent one-line commands into LSF batch files.

    Input:  `filename` holds one command per line; execution order must
            not matter. Lines starting with '#' are skipped.
    Action: writes the commands, `increment` at a time, to filename_1,
            filename_2, ... (made executable). filename_sourcem gets one
            'bsub filename_i &' line per batch, so sourcing it submits
            every batch to LSF.
    Use:    when 30000 separate LSF jobs are too many, but running the
            commands one after the other is too slow.
    Returns `filename`.
    """
    # the commands to run
    f_in = open(filename, 'r')
    try:
        commands = [cmd for cmd in f_in.readlines() if cmd[0] != '#']
    finally:
        f_in.close()

    # number of batch files, counting a final partial batch
    num_files = (len(commands) + increment - 1) // increment
    print("Splitting %s.  %s commands in increments of %s.  %s simultaneous_files." % (
        filename, len(commands), increment, num_files))

    # rwx for user, group and others (0777)
    all_access = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO

    file_master = open(filename + '_sourcem', 'w')
    try:
        batch = 0
        for first in range(0, len(commands), increment):
            batch = batch + 1
            file_i = open('%s_%s' % (filename, batch), 'w')
            try:
                file_i.writelines(commands[first:first + increment])
            finally:
                file_i.close()
            os.chmod(file_i.name, all_access)

            file_master.write('bsub %s_%s &\n' % (filename, batch))
    finally:
        file_master.close()
    os.chmod(file_master.name, all_access)

    print("Done splitting.  Now source: %s_sourcem" % filename)

    return filename

############################################################
#
# HTML files
#
############################################################

def html_eliminate_tags(s):
    """Return `s` with every tag located by html_find_tags removed.

    The text outside the tags is kept in order; if no tag is found, `s`
    is returned unchanged.
    """
    tags = html_find_tags(s)
    if not tags: return s

    # text spans lying between consecutive tags ...
    gaps = list2pairs(tags, lambda prev, nxt: (prev[1], nxt[0]))
    # ... plus the text before the first tag and after the last one
    gaps = [(0, tags[0][0])] + gaps + [(tags[-1][1], len(s))]

    return ''.join([s[lo:hi] for lo, hi in gaps])
        
def html_find_tags(s):
    """Locate the HTML tags in `s`.

    Returns a list of (start, end) pairs such that s[start:end] is a tag,
    including its '<' and '>'. A '<...>' span counts as a tag only when
    the character after '<' is a letter, '/' or '!', AND at least one of
    these holds:
      * it starts with '<a ' or '<img ',
      * it contains no newline,
      * it is shorter than 8 characters.
    These rules keep a stray '<' ... '>' pair (e.g. inequality signs in
    text) from swallowing everything in between.
    """
    tags = []
    last = 0
    while 1:
        first = s.find('<', last)
        if first == -1: break
        last = s.find('>', first)
        if last == -1: break

        candidate = s[first:last+1]
        # compare the full 3-char prefix: a 2-char slice can never equal '<a '
        ok = (candidate[:3] == '<a ' or
              candidate[:5] == '<img ' or
              candidate.find('\n') == -1 or
              len(candidate) < 8)
        # s[first+1] exists: last > first because s[first] is '<', not '>'
        nxt = s[first+1]
        if not ('a' <= nxt <= 'z' or 'A' <= nxt <= 'Z' or nxt in '/!'):
            ok = 0

        if ok: tags.append((first, last+1))
        else: last = first+1   # not a tag: resume right after this '<'
    return tags

def html_generate_listing(f, dir, level=2):
    """Write to `f` an HTML page listing the entries `level` directories below `dir`.

    Entries matching dir/*/... (`level` '*' components) are shown relative
    to their deepest common directory, grouped by sub-directory. Each row
    has a name (linked with that relative path), a type (the extension,
    or 'dir') and a size ('N files' for directories, 'N kb' for files).
    """
    import glob
    if not dir.endswith('/'): dir = dir + '/'
    paths = my_sort(glob.glob(dir + '/'.join(['*'] * level)))
    common, paths = common_subdir(paths)

    # </html> belongs at the very end; it was previously also emitted here,
    # closing the document before <body>
    f.write('<html>\n<head>\n<title>Directory Listing of %s</title>\n</head>\n' % dir)
    f.write('<body>\n')

    f.write('<table border=1>\n')
    f.write('<tr><th>Directory</th><th>Filename</th><th>Type</th><th>Size</th><th>Comment</th></tr>\n')
    lastdir = ''
    for path in paths:
        parts = path.split('/')
        thisdir = '/'.join(parts[:-1])
        # print the directory name only on the first row of each group
        if thisdir != lastdir:
            f.write('<tr><td><b>%s</b></td>' % thisdir)
            lastdir = thisdir
        else:
            f.write('<tr><td></td>')

        if is_directory(common + path):
            filename, kind = parts[-1], 'dir'
            sizeinfo = '%s files' % len(glob.glob(common + path + '/*'))
        else:
            pieces = parts[-1].split('.')
            if len(pieces) > 1:
                filename, kind = '.'.join(pieces[:-1]), pieces[-1]
            else:
                filename, kind = parts[-1], ''
            sizeinfo = '%s kb' % dec(filesize(common + path) // 1000)

        f.write('<td><a href="%s">%s</a></td><td align=center>%s</td><td align=right>%s</td><td>None</td>' % (
            path, filename, kind, sizeinfo))
        f.write('</tr>\n')

    f.write('</table>')
    f.write('</body>\n</html>\n')
    

# category 5
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       FASTA FILES AND FILE POINTERS      ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

############################################################
#
#  Reading a portion of a file, until a separator is met
#
############################################################

def parse_tab(fn,debug=0):
    """Read the tab-separated file `fn` and parse it with parse_tab_lines.

    Carriage returns are removed and blank lines dropped. The lines are
    then passed through filter_diclist_not(lines, 0, '#') (presumably
    dropping comment lines -- confirm) before parsing.
    """
    handle = open(fn, 'r')
    text = handle.read()
    handle.close()
    text = text.replace('\015', '')
    lines = [line for line in text.split('\n') if line]
    lines = filter_diclist_not(lines, 0, '#')
    return parse_tab_lines(lines, debug=debug)

def parse_tab_lines(lines,debug=0): 
    """Turn tab-separated `lines` into a list of dicts.

    lines[0] holds the column titles. Every following line becomes one
    dict, filled with mset(dict, titles, fields). With `debug`, the
    parsed table is echoed to standard output as left-aligned,
    '|'-terminated columns with a dashed rule under the header.
    """
    titles = lines[0].split('\t')
    parsed = []
    for row in lines[1:]:
        record = {}
        mset(record, titles, row.split('\t'))
        parsed.append(record)

    if debug:
        out = sys.stdout
        widths = {}
        for key in titles:
            widths[key] = max([len('%s' % key)] + [len(v) for v in cget(parsed, key)])

        def cell(key, value):
            # left-align `value` in the column width of `key`
            return ('%-' + repr(widths[key]) + 's|') % value

        for title in titles:
            out.write(cell(title, title))
        out.write('\n')
        for title in titles:
            out.write('-' * widths[title] + '|')
        out.write('\n')
        for record in parsed:
            for key in titles:
                out.write(cell(key, record[key]))
            out.write('\n')

    return parsed

def output_tab(diclist, order=None, f=None, separator='\t', default='', exclude=None, only_specified=0):
    """Write the dicts in `diclist` as a separated table: a header line, then one row per dict.

    order:          columns to put first (those present in the data); the
                    remaining keys follow in my_sort order.
    f:              output stream. None (the default) means the sys.stdout
                    in effect at call time, so later redirections are
                    honoured.
    separator:      field separator. It must not occur in any key or in
                    `default` (asserted).
    default:        text written for keys a dict lacks.
    exclude:        keys to leave out (ignored when only_specified is set).
    only_specified: when true, output only the keys listed in `order`.
    """
    if f is None: f = sys.stdout
    if exclude is None: exclude = []

    keys = my_sort(unique(flatten([list(d.keys()) for d in diclist])))

    # with no `order` there is nothing specified, so everything is excluded
    if only_specified: exclude = set_subtract(keys, order or [])

    if not order: order = keys
    else: order = set_intersect(order, keys) + set_subtract(keys, order)

    # make sure the separator does not appear as a key
    for key in order:
        assert(key.find(separator) == -1)
    assert(default.find(separator) == -1)

    order = set_subtract(order, exclude)

    f.write(separator.join(order) + '\n')
    for dic in diclist:
        fields = []
        for key in order:
            if key in dic: fields.append('%s' % dic[key])
            else: fields.append(default)
        f.write(separator.join(fields) + '\n')

def read_until(f, separator, increment=500):
    """Return the data from f's current position up to the next `separator`.

    f is left positioned just after that separator, so consecutive calls
    return consecutive records. When the file ends, the remaining data is
    returned; once nothing is left, -1 is returned. Two adjacent
    separators yield an empty string.

    Ex:  if the file is:  BABOON BAD BABAK and separator is 'BA'
    consecutive calls will return '', 'BOON ', 'D ', '', 'K', -1

    The file is read `increment` characters at a time. A separator that
    straddles two reads is still found, and always at its earliest
    occurrence. Rewinding uses a relative seek, so f must support
    f.seek(offset, 1) (a Python 2 file, or a binary file object). Bytes
    data works too, when `separator` is bytes.
    """
    if increment < len(separator):
        increment = len(separator)

    sep_len = len(separator)
    empty = separator[:0]   # '' (or b'') of the same type as the data
    parts = []              # chunks read so far in this call
    tail = empty            # last sep_len-1 characters of those chunks

    while 1:
        chunk = f.read(increment)

        # done reading file: return the last portion, or -1 if none
        if not chunk:
            if parts:
                return empty.join(parts)
            return -1

        # A separator ending inside `chunk` can start at most sep_len-1
        # characters earlier, i.e. inside `tail`. Searching tail+chunk
        # therefore finds boundary-straddling separators before any later
        # one in the chunk.
        window = tail + chunk
        found = window.find(separator)
        if found > -1:
            prefix = empty.join(parts)
            # rewind f to just past the separator
            f.seek(-(len(window) - found - sep_len), 1)
            # `tail` duplicates the end of `prefix`: drop it, keep window[:found]
            return prefix[:len(prefix) - len(tail)] + window[:found]

        parts.append(chunk)
        if sep_len > 1:
            tail = window[-(sep_len - 1):]
        else:
            tail = empty

def test_read_until(separator = '|||',
                    filename = '/seq/manoli/yeast/tmp/separator.txt'):
    """Print repr() of every record of `filename`, as split by read_until.

    Reads 6 characters at a time, so separators straddling reads are
    exercised. `filename` defaults to the original hard-coded test file.
    """
    f = open(filename, 'r')
    try:
        while 1:
            txt = read_until(f, separator, 6)
            if txt == -1:
                break
            print(repr(txt))
    finally:
        f.close()

############################################################
#
#  How to read into a file, without loading it all in memory
#
############################################################

def get_separator_positions(filename, separator, increment=500): 
    """Return (start, end) offsets of the regions BETWEEN occurrences of `separator`.

    The first region starts at 0 and the last one ends at the end of the
    file; with no separator the whole file is a single region. Each region
    can be read back with f.seek(start, 0); f.read(end - start). The file
    is scanned `increment` characters at a time, without loading it all.
    """
    f = open(filename, 'r')

    sep_len = len(separator)
    if sep_len > increment:
        # a single read must be able to hold a whole separator
        increment = sep_len

    hits = []
    while 1:
        block = f.read(increment)

        # fewer characters than a separator left: end of file
        if len(block) < sep_len:
            break

        offset = block.find(separator)
        if offset == -1:
            # step back so a separator cut by this read is seen next time
            f.seek(1 - sep_len, 1)
            continue

        start = f.tell() - len(block) + offset
        hits.append((start, start + sep_len))
        # continue right after this separator
        f.seek(start + sep_len, 0)

    if hits:
        f.seek(0, 2)
        between = [(0, hits[0][0])]
        for prev, nxt in zip(hits[:-1], hits[1:]):
            between.append((prev[1], nxt[0]))
        between.append((hits[-1][1], f.tell()))
    else:
        between = [(0, f.tell())]

    f.close()

    return between

def use_separator_positions(filename, offsets):
    """Print every (start, end) region of `filename`, one print per region."""
    f = open(filename, 'r')
    for start, end in offsets:
        f.seek(start, 0)
        print(f.read(end - start))
    f.close()

def test_separator_positions(filename = '/seq/manoli/yeast/tmp/separator.txt',
                             separator = '||||'):
    """Show the regions of `filename` between `separator`s: offsets first, then text."""
    regions = get_separator_positions(filename, separator)
    pp(regions)
    use_separator_positions(filename, regions)

############################################################
#
# The above, more precisely for FASTA files
#
############################################################

def make_random_fasta(filename, num_seqs, min_length, max_length, wrap=60):
    """Write `num_seqs` random sequences to the FASTA file `filename`.

    Sequence i is named 'seq<i>(<length>bp)'. Its length is drawn with
    random.choice from min_length..max_length inclusive. `wrap` is
    accepted but not used by this function.
    """
    import blast, random, simulation
    out = open(filename, 'w')
    lengths = range(min_length, max_length + 1)
    for i in range(0, num_seqs):
        length = random.choice(lengths)
        title = 'seq%s(%sbp)' % (i, length)
        blast.append_fasta(out, title, simulation.random_seq(length))
    out.close()

def make_test_fasta(filename, num_iterations, before, after, wrap=60):
    """Write random FASTA test sequences with lengths around multiples of 60.

    For k in 1..num_iterations-1 and each offset i in [-before, after),
    a sequence 'seq<k>_<i>' of length 60*k + i is written. `wrap` is
    accepted but not used by this function.
    """
    import blast, simulation
    out = open(filename, 'w')
    for full_lines in range(1, num_iterations):
        base_len = 60 * full_lines
        for offset in range(-before, after):
            name = 'seq%s_%s' % (full_lines, offset)
            blast.append_fasta(out, name, simulation.random_seq(base_len + offset))
    out.close()
            
    


def test_fasta_dictionary(filename):
    """Self-check fasta_dictionary and get_sequence on a generated file.

    Writes a test FASTA file to `filename`, then:
      * asserts the index built by fasta_dictionary has the same titles
        and sequence lengths as multiparse_fasta reports;
      * prints every (key, low, high) slice for which
        get_sequence(..., up=1) disagrees with slicing the parsed sequence.
    The file handle is closed even when an assertion fails.
    """
    make_test_fasta(filename, 3, 1, 1, 10)

    f = open(filename, 'r')
    try:
        dic = fasta_dictionary(f.name)

        names, seqs = multiparse_fasta(f.name)
        seq_dic, len_dic = {}, {}
        for name, seq in map(None, names, seqs):
            name = name.strip()
            seq_dic[name] = seq
            len_dic[name] = len(seq)

        # the index holds every sequence plus the '>width', '>fn', '>f' metadata
        assert(my_sort(names) == my_sort(set_subtract(dic.keys(), ['>width', '>fn', '>f'])))

        for name in names:
            assert(dic[name][1] == len_dic[name])

        for key in dic.keys():
            if key[0] == '>':
                continue   # metadata entry, not a sequence
            start, length = dic[key]

            # include out-of-range and reversed bounds on purpose
            for low in range(-10, length):
                for high in range(low - 10, length + 3):
                    seq1 = get_sequence(f, dic, key, low, high, up=1)
                    seq2 = seq_dic[key][low:high]
                    if seq1 != seq2:
                        print('Get sequence (%s, %s, %s)' % (key, low, high))
                        print(' ' * low + seq1 + ' <- OOPS ' + seq2)
    finally:
        f.close()

def fasta_dictionary(filename):

    debug = 1
    
    # first get all the starts of fasta entries
    positions = get_fasta_positions(filename)

    # now construct a dictionary which maps sequence names to positions
    dic = {}

    # reopen the file
    f = open(filename, 'r')

    # find the width of the file
    w = verify_fasta_width(f)

    for start,end in positions:

        # go to the start and read the title
        f.seek(start, 0)
        title = string.strip(read_until(f,'\n',60))
        #title = string.strip(f.readline())

        # make sure all entries are unique
        if dic.has_key(title):
            raise "Duplicate fasta title", title

        # now add an entry to the dictionary for the sequence start
        seq_start = f.tell()
        #if debug: 
        #    print "Now i'm at: %s"%f.read(20)
        #    f.seek(seq_start,0)

        # calculate the length of the sequence
        seq_length = end - seq_start - num_wraps_forward(end-seq_start, w)

        #if debug: 
        #    n = num_wraps_forward(end-seq_start, w)
        #    print "%s: total chars: %s, num wraps: %s, length: %s"%(
        #        title,end-seq_start,n,seq_length)
            

        # map the title to the start and end of the sequence
        dic[title] = (seq_start, seq_length)

    dic['>width'] = w
    dic['>fn'] = filename
    dic['>f'] = None

    return dic

def num_wraps_old(length, w):
    """Number of line breaks in a sequence of `length` characters wrapped every `w`.

    Same formula as num_wraps_reverse. Returns 0 for length <= 0. Floor
    division keeps the result an integer on every Python version.
    """
    if length <= 0: return 0
    return (length-1) // w

def num_wraps_forward(length, w):
    """Number of line breaks contained in `length` characters of a wrapped sequence.

    `length` counts the sequence characters AND their line breaks, with a
    break after every `w` characters. So each full line takes w+1
    characters, and length minus this count is the bare sequence length.
    Returns 0 for length <= 0. Floor division keeps the result an integer.
    """
    if length <= 0: return 0
    return length // (w+1)

def num_wraps_reverse(length, w):
    """Number of line breaks needed to wrap `length` sequence characters every `w`.

    No break follows the last line. Returns 0 for length <= 0. Floor
    division keeps the result an integer on every Python version.
    """
    if length <= 0: return 0
    return (length-1) // w

def smart_wraps(start, end, w):
    """Number of line breaks between sequence offsets `start` and `end`.

    The sequence is assumed to be wrapped every `w` characters; the count
    is num_wraps_reverse(end) minus num_wraps_reverse(start).
    """
    wraps_to_end = num_wraps_reverse(end, w)
    wraps_to_start = num_wraps_reverse(start, w)
    return wraps_to_end - wraps_to_start

def get_fasta_positions(filename):
    f = open(filename,'r')
    if f.read(1) != '>': raise "Not a fasta file", filename
    pos = get_separator_positions(filename, '\n>', 500)

    # adjust for the fact that the first entry does not have a carriage return before the >
    pos[0] = (1, pos[0][1])

    fix_fasta_positions(f,pos)

    return pos

def fix_fasta_positions(f, positions):
    """Normalise fasta entry offsets in place.

    positions is a list of (start, end) offsets into the open file f.  The
    character at start-1 must be '>' (asserted).  Each end is moved back
    while the character at end-1 is a space or a newline, so that it points
    just past the entry's last other character.  Nothing is returned.
    """
    for idx, (start, end) in enumerate(positions):
        f.seek(start - 1, 0)
        assert(f.read(1) == '>')

        while 1:
            f.seek(end - 1, 0)
            # read(1) gives '' past end of file and '' in ' \n' is true,
            # so such positions are stepped over too
            if f.read(1) not in ' \n':
                break
            end = end - 1

        positions[idx] = (start, end)

def test_get_fasta_positions(filename):
    """Debug helper: print the (start, end) offsets that get_fasta_positions
    finds in filename, and the text read between them for each entry."""
    # get_fasta_positions already runs fix_fasta_positions on its result
    positions = get_fasta_positions(filename)
    f = open(filename,'r')
    try:
        for start, end in positions:
            print("(%s,%s)" % (start, end))
            f.seek(start,0)
            print("Reading %s (end-start): '%s'" % (end-start, f.read(end-start)))
    finally:
        f.close()


def guess_fasta_width(f):
    """Guess the width at which the sequences of fasta file f are wrapped.

    The width is the length of the first sequence line that is directly
    followed by another sequence line (i.e. a line that was wrapped).  If no
    sequence is ever wrapped, the length of the longest sequence line is
    returned, since any width at least that large describes the file.

    The caller's file position is restored before returning, also on error.
    Raises ValueError if f contains no sequence lines.
    """
    where = f.tell()
    try:
        f.seek(0,0)
        prev_len = -1   # length of the previous line if it was a sequence line
        longest = 0
        while 1:
            raw = f.readline()
            if not raw:
                break
            line = raw.rstrip('\n')
            if not line or line[0] == '>':
                # titles and blank lines separate runs of sequence lines
                prev_len = -1
                continue
            if prev_len >= 0:
                return prev_len
            prev_len = len(line)
            longest = max(longest, prev_len)
        if not longest:
            raise ValueError("No sequence lines found in fasta file")
        return longest
    finally:
        # seek(offset, whence): return to where the caller was
        f.seek(where, 0)

def verify_fasta_width(f):
    """Check that every sequence of fasta file f is wrapped at a single width.

    The width w comes from guess_fasta_width.  Each title must occupy one
    line.  Within a sequence every line must be exactly w characters long
    except the last, which may be shorter; after a short line or a blank
    line, the next non-blank line must be a title.

    Returns w and leaves f at its start.
    Raises ValueError naming the first offending (1-based) line.
    """
    w = guess_fasta_width(f)

    # assumes that every title only takes a single line w/o line breaks
    f.seek(0,0)

    next_must_be_title = 1
    i = 0
    while 1:
        i = i+1
        raw = f.readline()
        # only an empty read is end of file; a blank line reads as '\n' and
        # must not end the scan early
        if not raw: break
        line = raw.rstrip('\n')
        if not line:
            # a blank line ends the current sequence
            next_must_be_title = 1
            continue
        is_title = line[0] == '>'
        if next_must_be_title and not is_title:
            # i'm expecting a title, but i don't get one
            raise ValueError("Line %i should be a title" % i)
        if is_title or len(line) == w:
            # a title, or a full-width line: more sequence may follow
            next_must_be_title = 0
        elif len(line) > w:
            raise ValueError("Line %i is too long" % i)
        else:
            # a short line must be the last line of its sequence
            next_must_be_title = 1
    f.seek(0,0)
    return w

def fix_fasta_width(filename_in, filename_out=None, newwidth=None):
    """Rewrite fasta file filename_in with sequences wrapped every newwidth chars.

    filename_out -- destination file.  When omitted, the first of
                    filename_in+'1', filename_in+'2', ... that does not exist
                    is written and then renamed over filename_in.
    newwidth     -- wrap width; guessed from the input with guess_fasta_width
                    when not given.

    Every entry is copied, including entries whose sequence is empty; those
    previously stopped the copy early, which, combined with the rename,
    discarded the rest of the input file.
    """
    import blast

    # choose a filename_out that doesn't already exist
    if not filename_out:
        replace = 1
        i = 1
        while os.path.exists(filename_in + '%d' % i):
            i = i+1
        filename_out = filename_in + '%d' % i
    else:
        replace = 0

    f1 = open(filename_in,'r')
    try:
        f2 = open(filename_out,'w')
        try:
            # if the newwidth is not specified, then simply guess it
            if not newwidth:
                newwidth = guess_fasta_width(f1)

            while 1:
                where = f1.tell()
                name,seq = incremental_fasta(f1)
                # incremental_fasta returns ('', '') both at end of file and
                # for an entry with an empty title and sequence; only nothing
                # having been read means end of file
                if f1.tell() == where:
                    break
                blast.append_fasta(f2,name,seq,newwidth)
        finally:
            f2.close()
    finally:
        f1.close()

    if replace:
        os.rename(filename_out, filename_in)

def get_sequence(f, dic, title, start, end, up=1, safe=0):
    """Return the subsequence [start, end) of fasta entry `title` read from f.

    f      -- open fasta file
    dic    -- positions dictionary from fasta_dictionary(f.name): maps each
              title to (seq_start, seq_length) and '>width' to the line width
    title  -- fasta title of the entry (say, chromosome name); surrounding
              whitespace is ignored
    start, end -- 0-based, end-exclusive coordinates within the entry.
              Negative values count from the end of the entry as in Python
              slicing; out-of-range values are clamped to the entry.
    up     -- strand; when false, the reverse complement (revcom) is returned
    safe   -- when true, assert the result contains only ACGTN/acgtn

    Returns '' for an empty range, and also if removing line breaks raises
    MemoryError.
    Raises KeyError if title is not in dic.
    """
    title = title.strip()

    if title not in dic:
        raise KeyError("Title not present in fasta_dictionary for %s: %s"
                       % (f.name, title))

    # read the dictionary for positional information
    seq_start, seq_length = dic[title]

    # resolve negative coordinates and keep the range inside the entry; a
    # start still negative after resolving would seek before the sequence
    if start < 0: start = max(seq_length + start, 0)
    if end < 0: end = seq_length + end
    if start > seq_length: return ''
    if start > end: return ''
    if end > seq_length: end = seq_length

    # The entry is stored wrapped every w characters: offset the seek by
    # num_wraps_reverse(start, w) line breaks and extend the read by the
    # breaks inside the range (smart_wraps).  Any newline read is removed.
    w = dic['>width']
    f.seek(seq_start+start+num_wraps_reverse(start,w))
    seq = f.read(end-start+smart_wraps(start,end,w))
    try:
        seq = seq.replace('\n','')
    except MemoryError:
        seq = ''

    # reverse complement
    if not up: seq = revcom(seq)

    if safe:
        assert(sum([seq.count(char) for char in 'ACGTNacgtn'])==len(seq))

    return seq

def fasta_dic_pos2seq(filename, dic):
    """Read the text of each (start, end) span listed in dic from filename.

    dic maps keys to (start, end) file offsets, end exclusive.  Returns a new
    dict mapping the same keys to the text of each span with surrounding
    whitespace stripped; newlines inside a span are kept.
    """
    f = open(filename,'r')
    try:
        seq_dic = {}
        for key,(start,end) in dic.items():
            seq_dic[key] = get_seq(f, start, end).strip()
    finally:
        # the file was previously left open
        f.close()
    return seq_dic

def get_seq(f, start, end):
    """Return the end-start characters of open file f beginning at offset start."""
    span = end - start
    f.seek(start)
    return f.read(span)

############################################################
#
# An older version of this
#
############################################################

def get_fasta_entry(f, title, reset=1):
    """Older lookup: return the sequence of the first entry whose title line
    starts with `title`, with newlines removed, or -1 if none is found.

    Entries are read with read_until(f, '\\n>', 500).  When reset is true
    the search restarts at file offset 1 (presumably just past the leading
    '>'); otherwise it continues from the current position.
    NOTE: the match is a prefix match, so 'chr1' also matches 'chr10 ...'.
    """
    if reset:
        f.seek(1,0)
    while 1:
        chunk = read_until(f, '\n>', 500)
        if chunk == -1:
            return -1
        if chunk.startswith(title):
            # drop the title line, join the sequence lines
            return ''.join(chunk.split('\n')[1:])

def get_fasta_entries(filename, titles, reset=0):
    """Look up each title in filename with get_fasta_entry, one shared file
    handle, searching forward by default.

    Assumes: filename contains all titles, in order.
    """
    handle = open(filename,'r')
    found = [get_fasta_entry(handle, title, reset) for title in titles]
    handle.close()
    return found

############################################################
#
# Parsing FASTA files
#
############################################################

def singleparse_fasta(name):
    """Return the sequence of a single-entry fasta file: every line after the
    first joined together, with carriage returns removed."""
    text = open(name).read()
    lines = text.replace('\015', '').split('\n')
    return ''.join(lines[1:])

def multiparse_fasta(file):
    """Parse a (possibly multi-entry) fasta file.

    Returns (titles, seqs): each title is the rest of the line after '>',
    each sequence its following lines joined, carriage returns removed.
    """
    text = open(file).read().replace('\015','')
    records = [chunk.split('\n') for chunk in text.split('>')[1:]]
    titles = [lines[0] for lines in records]
    seqs = [''.join(lines[1:]) for lines in records]
    return titles, seqs

def multiparse_qual(file):
    """Parse a (possibly multi-entry) fasta-formatted quality file.

    Returns (titles, quals): quals[k] is the list of whitespace-separated
    integer scores under titles[k].
    """
    text = open(file).read().replace('\015','')
    records = [chunk.split('\n') for chunk in text.split('>')[1:]]
    titles = [lines[0] for lines in records]
    quals = []
    for lines in records:
        tokens = ' '.join(lines[1:]).split()
        quals.append([int(tok) for tok in tokens])
    return titles, quals


def incremental_fasta(f):
    """Read the next entry from an open fasta file.

    Assumes f is positioned at the beginning of a title line.  Returns
    (name, seq): name is the title without '>' and surrounding whitespace;
    seq is the following lines, each stripped, joined together.  On return
    f is positioned at the beginning of the next title (or at end of file).

    End of file: returns ('', '').
    Raises ValueError if the first line read is not a title.
    """
    line = f.readline()
    if not line: return '',''
    if line[0] != '>':
        raise ValueError("First line is not a valid FASTA title: %r" % line)
    # strip() also removes the newline; line[1:-1] would drop a real
    # character from a final title line that has no newline
    name = line[1:].strip()

    seq = []
    while 1:
        # remember where this line starts so we can come back if it is the
        # next title; seeking back by -len(line) is wrong whenever the decoded
        # length differs from the bytes consumed, and io text streams refuse
        # nonzero relative seeks
        where = f.tell()
        line = f.readline()
        if not line:
            # done reading, end of file
            break
        if line[0] == '>':
            # done reading, next title: backtrack
            f.seek(where)
            break
        seq.append(line.strip())
    return name, ''.join(seq)

def incremental_qual(f):
    """Read the next entry from an open fasta-formatted quality file.

    Assumes f is positioned at the beginning of a title line.  Returns
    (name, scores): name is the title without '>' and surrounding
    whitespace; scores is the list of whitespace-separated integers on the
    following lines.  On return f is positioned at the beginning of the
    next title (or at end of file).

    End of file: returns ('', []).
    Raises ValueError if the first line is not a title or a score is not
    an integer.
    """
    line = f.readline()
    if not line: return '',[]
    if line[0] != '>':
        raise ValueError("First line is not a valid FASTA title: %r" % line)
    name = line[1:].strip()

    # then parse the quality scores
    scores = []
    while 1:
        # absolute position to backtrack to (see incremental_fasta)
        where = f.tell()
        line = f.readline()
        if not line: break
        if line[0] == '>':
            f.seek(where)
            break
        scores.extend([int(tok) for tok in line.split()])
    return name, scores

def get_fasta_lengths(file):
    """Return a dict mapping each title of fasta file `file` to the length
    of its sequence (via multiparse_fasta and mset)."""
    titles, seqs = multiparse_fasta(file)
    seq_lengths = [len(s) for s in seqs]
    lengths = {}
    mset(lengths, titles, seq_lengths)
    return lengths

############################################################
#
# Comparing FASTA files
#
############################################################

def compare_with_screen(species):
    """Compare a species' PHRP fasta file with its '.screen' masked copy.

    Returns compare_fasta's list of (name, num_mismatches) for the entries
    that differ.
    """
    phrap_file = '/seq/comparative02/%s/sequal_dir/%sPHRP' % (species, species)
    screened_file = phrap_file + '.screen'
    return compare_fasta([phrap_file, screened_file])

def compare_fasta(files):
    """
    Input: two or more fasta files that all have the same order in their reads,
           possibly masked with different masking programs (ex: vector, repeat)
    Output: the list of (name, num_mismatches) for the fasta entries whose
            (uppercased) sequences differ, where num_mismatches is the number
            of positions at which not all files agree (computed via unpack).

    Prints one line per differing entry and a summary.
    Raises ValueError if the entry names fall out of sync between files.
    All files are closed on return or error.
    """
    fs = [open(name) for name in files]
    inconsistencies = []
    try:
        i = 0
        while 1:
            i = i+1
            name_seq_list = [incremental_fasta(f) for f in fs]
            names = [ns[0] for ns in name_seq_list]
            seqs = [ns[1].upper() for ns in name_seq_list]

            if not names[0]:
                break

            for name in names:
                if name != names[0]:
                    print("%s fasta entries considered" % i)
                    raise ValueError("Files out of sync: %r" % ((name, names[0]),))

            differing = [seq for seq in seqs if seq != seqs[0]]
            if not differing:
                continue

            # count the number of bases that are different
            num_mismatches = 0
            for chars in unpack(seqs):
                if not all_same(chars):
                    num_mismatches = num_mismatches + 1

            # an empty reference sequence would divide by zero; every base
            # of the other sequences then differs, so report 100%
            if seqs[0]:
                percent = 100*num_mismatches // len(seqs[0])
            else:
                percent = 100

            print("%19s: %3s / %3s mismatches (%2s%%)" % (names[0], num_mismatches,
                                                         len(seqs[0]), percent))
            inconsistencies.append((names[0], num_mismatches))
    finally:
        for f in fs:
            f.close()

    if inconsistencies:
        print("%s fasta entries were different among %s" % (len(inconsistencies),
                                                            ', '.join(files)))
    else:
        print("Files are identical")
    return inconsistencies

# category 6
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       STRING OPERATIONS                  ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

def sfill(s, length, fill_char = '.'):
    """Pad s on the right with fill_char up to `length` characters.

    ex:  sfill('hello',18,'.') -> 'hello.............'  (18 chars)
    Strings already at least `length` long are returned unchanged.  Useful
    for printing dictionaries in a cute way:
        one......: 1
        five.....: 5
        seventeen: 17
    """
    missing = length - len(s)
    padding = fill_char * missing   # '' when missing <= 0
    return s + padding


def sinsert(s, sub, start):
    """Overwrite the characters of s from index `start` on with sub, keeping
    the length of s (sub is truncated at the end of s).

    ex: sinsert('ATGCGGATATAT','*->',3) -> 'ATG*->ATATAT'
    Useful for adding annotations.
    """
    chars = list(s)
    for offset in range(min(len(sub), len(chars) - start)):
        chars[start + offset] = sub[offset]
    return ''.join(chars)

def sadd(s,sub,start):
    """Overwrite the mutable sequence s in place with the items of sub,
    starting at index `start` and stopping at the end of s.  Returns None."""
    count = min(len(sub), len(s) - start)
    for offset in range(count):
        s[start + offset] = sub[offset]

def mcount(s, chars):
    """Sum the number of (non-overlapping) occurrences in s of each item of chars."""
    return sum([s.count(item) for item in chars])

def mreplace(s, chars, targets):
    """Apply several replacements to s, one after the other: chars[k] is
    replaced by targets[k], in order.

    The advantage of mreplace over translate_chars is that both chars and
    targets can be strings, not just single characters.

    ex: mreplace('mississipi',['ssi','mi','pi'],['ze','mme','']) -> 'mmezeze'
    ex: mreplace('ACG|-.AAGCG','-.|',['','','']) -> 'ACGAAGCG'

    Note: the output depends on the order chars are given in; moreover, the
    output of one replacement can be changed by a later one.
    """
    assert(len(chars)==len(targets))

    for old, new in zip(chars, targets):
        # A deletion that would consume the whole string yields '' directly
        # (the same guard safe_replace applies).  The former whole-call
        # shortcut ignored targets and multi-character patterns, e.g. it
        # turned mreplace('ab', 'ab', ['X', 'Y']) into ''.
        if not new and s.count(old)*len(old) == len(s):
            s = ''
        else:
            s = s.replace(old, new)
    return s

def safe_replace(s,old,new):
    """Replace old with new in s, returning '' directly when the replacement
    is a deletion and s consists only of copies of old.  Empty s is
    returned as is."""
    if not s:
        return s
    deletes_everything = not new and s.count(old)*len(old) == len(s)
    if deletes_everything:
        return ''
    return s.replace(old, new)

def msplit(seq, separators):
    """Split seq at runs of any of the characters in separators.

    Like seq.split(sep) but for several separators at once; consecutive
    separators count as one break.  A leading (trailing) separator run
    yields an empty first (last) piece.
    ex: msplit('a,,b;c', ',;') -> ['a', 'b', 'c']
    """
    # dictionary for constant-time membership tests
    is_sep = {}
    for sep in separators:
        is_sep[sep] = None

    pieces = []
    piece_start = 0
    in_piece = 1
    for pos in range(len(seq)):
        if seq[pos] in is_sep:
            if in_piece:
                pieces.append(seq[piece_start:pos])
                in_piece = 0
        elif not in_piece:
            piece_start = pos
            in_piece = 1

    if in_piece:
        pieces.append(seq[piece_start:])
    else:
        pieces.append(seq[len(seq):])
    return pieces

def substrings_including(seq, chars):
    """Return the pieces of seq made of characters from chars.

    Each position is marked by membership in chars and passed to
    count_islands; its islands for value 1 are used as 1-based inclusive
    (start, end) spans of seq.  Returns [] if there are none.
    """
    in_chars = {}
    mset(in_chars, chars, [None]*len(chars))

    islands = count_islands([c in in_chars for c in seq])
    if 1 not in islands:
        return []
    return [seq[start-1:end] for start, end in islands[1]]

def substrings_excluding(seq, chars):
    """Return the pieces of seq containing none of the characters in chars.

    Each position is marked by membership in chars and passed to
    count_islands; its islands for value 0 are used as 1-based inclusive
    (start, end) spans of seq.  Returns [] if there are none.
    """
    in_chars = {}
    mset(in_chars, chars, [None]*len(chars))

    islands = count_islands([c in in_chars for c in seq])
    if 0 not in islands:
        return []
    return [seq[start-1:end] for start, end in islands[0]]

def translate_chars(s,chars,targets):
    # make a dictionary and use it
    # ex: translate_chars('mississipi',['i','s'],['e','t']) -> mettettepe
    # ex: translate_chars('mississipi','is','et') -> mettettepe
    # note only char->char substitutions are possible. For multiple
    # substring -> substring substitutions, one can use mreplace
    return string.translate(s,string.maketrans(chars,targets))

def multi_find(seq, list_of_queries,i=0):
    """Return the lowest index >= i at which any query occurs in seq, or -1."""
    hits = [seq.find(query, i) for query in list_of_queries]
    hits = [h for h in hits if h > -1]
    if not hits:
        return -1
    return min(hits)

def multi_rfind(seq, list_of_queries):
    """Return the highest index at which any query occurs in seq, or -1."""
    return max([-1] + [seq.rfind(query) for query in list_of_queries])

def find_all(seq, sub):
    """Return every index at which sub occurs in seq, overlaps included."""
    hits = []
    pos = seq.find(sub)
    while pos >= 0:
        hits.append(pos)
        pos = seq.find(sub, pos + 1)
    return hits

def find_all_list(list, element):
    """Return the indices of every item of `list` equal to element."""
    return [idx for idx, item in enumerate(list) if item == element]

def find_ignoring_gaps(sequence, sub, gaps='-'):
    """Locate sub in sequence, ignoring gap characters.

    The call

        start, end = find_ignoring_gaps(seq, sub, gaps)

    returns 1-based inclusive coordinates such that

        seq[start-1:end].replace(gaps, '') == sub

    Useful for finding 'CG' in 'A--C---GA'.  The span runs from the first
    to the last character of sub, so it includes no gaps on either flank:
    'CG' in 'A--C---GA' gives (4, 8), i.e. 'C---G'.

    gaps is a single gap character.

    Raises ValueError if sub does not occur in the ungapped sequence
    ("NotFound"), or occurs there more than once, overlaps included
    ("MultipleInstances").
    """
    no_gaps = sequence.replace(gaps, '')
    first = no_gaps.find(sub)
    if first == -1:
        raise ValueError("NotFound: %s" % sub)
    if no_gaps.find(sub, first + 1) != -1:
        raise ValueError("MultipleInstances: %s" % sub)

    # we know sub only appears once in the ungapped sequence
    uninterrupted = sequence.find(sub)
    if uninterrupted != -1:
        # fast path: sub occurs with no gaps inside it
        start, end = uninterrupted + 1, uninterrupted + len(sub)
    else:
        # the occurrence is interrupted by at least one gap: walk the
        # sequence counting non-gap characters up to the first and the last
        # character of sub
        last = first + len(sub) - 1
        seen = -1            # ungapped index of sequence[i]
        start = end = None
        for i in range(len(sequence)):
            if sequence[i] != gaps:
                seen = seen + 1
                if seen == first:
                    start = i + 1
                if seen == last:
                    end = i + 1
                    break

    # verify that it's correct (with the caller's gap character)
    assert(sequence[start-1:end].replace(gaps, '') == sub)

    return start, end

def gap_translate_start_end(seq, start, end):
    """Convert a 1-based inclusive range on gapped seq to gap-free coordinates.

    Returns (new_start, new_end) such that
        seq.replace('-','')[new_start-1:new_end] == seq[start-1:end].replace('-','')
    start may fall on a gap ('-'); a gap at position start is counted as
    part of the range, not as preceding it.  start must be >= 1.
    """
    assert(start >= 1)

    # gaps strictly before position start (positions 1..start-1); counting
    # seq[:start] would count a gap at start twice, here and below
    gaps_before_start  = seq[:start-1].count('-')
    gaps_in_the_middle = seq[start-1:end].count('-')

    new_start = start -  gaps_before_start
    new_end   = end   - (gaps_before_start + gaps_in_the_middle)

    before = seq.replace('-','')[new_start-1:new_end]
    try:
        after = seq[start-1:end].replace('-','')
    except MemoryError:
        print("Before was: %s but after failed: %s" % (before, seq[start-1:end]))
        print("Start, end = (%s,%s)" % (start, end))
        after = ''

    assert(before == after)

    return new_start, new_end

def gap_untranslate_start_end(seq, gap_free_start, gap_free_end, safe=1):
    """Map a 1-based range on the gap-free version of seq back onto seq.

    start/end are 1 + the first index at which cumulative_sum of the
    non-gap flags reaches gap_free_start / gap_free_end (presumably the
    columns of those residues).  With safe set, the residues of both ranges
    are compared, printed with print_wrap on mismatch, and asserted equal.
    """
    is_residue = [ch != '-' for ch in seq]
    running = cumulative_sum(is_residue)
    start = running.index(gap_free_start) + 1
    end = running.index(gap_free_end) + 1

    if safe:
        expected = seq.replace('-','')[gap_free_start-1:gap_free_end]
        recovered = seq[start-1:end].replace('-','')
        if recovered != expected:
            print_wrap([recovered, expected], 100, ['before', 'after'])
        assert(expected == recovered)

    return start, end

#def gap_translate_coords(seq, coords, safe=1):
def coords_seq2ali(seq, coords, offsets=None):
    """Map gap-free sequence coordinates to alignment columns of seq.

    offsets -- precomputed gap_translation_offsets(seq); computed when falsy.
    With fewer than 10 coordinates each one is looked up with
    offsets.index (first column with that offset); otherwise a reversed
    table from reverse_map_list(offsets, 0) is built once and queried with
    mget, avoiding repeated linear scans.
    """
    if not offsets:
        offsets = gap_translation_offsets(seq)

    if len(coords) >= 10:
        return mget(reverse_map_list(offsets,0), coords)
    return map(offsets.index, coords)

def coords_starts_ends_seq2ali(starts, ends, seq):
    """Translate gap-free start and end coordinates to alignment columns of seq.

    Both sequences are translated in a single coords_seq2ali call and the
    result is split back at len(starts), so starts and ends may differ in
    length (splitting at half mispaired them) and may be any sequences.
    Returns (new_starts, new_ends) as lists.
    """
    starts, ends = list(starts), list(ends)
    new = list(coords_seq2ali(seq, starts+ends))
    n_starts = len(starts)
    return new[:n_starts], new[n_starts:]

def interval_coords_seq2ali(intervals, seq):
    """Convert the 'start' and 'end' of every interval from gap-free
    coordinates on seq to alignment columns, in place (cget/cset).
    Returns None."""
    starts = cget(intervals, 'start')
    ends   = cget(intervals, 'end')

    translated = coords_seq2ali(seq, starts+ends)
    # one start and one end per interval: the first half are the starts
    half = len(translated)/2
    new_starts, new_ends = translated[:half], translated[half:]

    cset(intervals,'start',new_starts)
    cset(intervals,'end',new_ends)

def interval_coords_ali2seq(intervals, seq):
    """Convert the 'start' and 'end' of every interval from alignment
    columns of seq to gap-free coordinates, in place (cget/cset).
    Returns None."""
    starts = cget(intervals, 'start')
    ends   = cget(intervals, 'end')

    translated = coords_ali2seq(seq, starts+ends)
    # one start and one end per interval: the first half are the starts
    half = len(translated)/2
    new_starts, new_ends = translated[:half], translated[half:]

    cset(intervals,'start',new_starts)
    cset(intervals,'end',new_ends)

#def gap_untranslate_coords(seq, coords, safe=1):
def coords_ali2seq(seq, coords, offsets=None):
    """Map alignment columns of seq to gap-free sequence coordinates by
    looking each column up in gap_translation_offsets(seq) (or the given
    precomputed offsets) with mget."""
    table = offsets
    if not table:
        table = gap_translation_offsets(seq)
    return mget(table, coords)

def coords_ali2allseqs(seqs, coords, all_offsets=None):
    """Translate alignment columns to gap-free coordinates in each of seqs.

    all_offsets -- optional per-sequence gap_translation_offsets.
    Returns unpack() of the per-sequence coordinate lists (presumably one
    entry per column holding its coordinate in each sequence).
    """
    if not all_offsets:
        all_offsets = map(gap_translation_offsets,seqs)
    per_seq = [coords_ali2seq(one_seq, coords, one_offsets)
               for one_seq, one_offsets in map(None, seqs, all_offsets)]
    return unpack(per_seq)

def gap_translation_offsets(seq):
    """Build the column -> gap-free-position table used by the coords_*
    translators.

    Returns [0] + cumulative_sum(flags) + [last + 1], where flags[i] is 1
    for a residue and 0 for a gap ('-') at column i of seq.
    """
    flags = [int(ch != '-') for ch in seq]
    offsets = cumulative_sum(flags)
    offsets.insert(0,0)
    offsets.append(offsets[-1]+1)
    return offsets

def test_gap_translate():
    """Self-test of coords_seq2ali / coords_ali2seq on a random 500-column
    gapped sequence from simulation.random_seq.

    Checks that sequence -> alignment -> sequence returns the original
    position, and that alignment -> sequence -> alignment lands on a column
    that maps back to the same sequence position.  Raises AssertionError on
    failure.
    """
    import simulation
    seq = simulation.random_seq(500,'ACGT-')
    gap_free = seq.replace('-','')

    print("Testing gap translation on: \n%s and \n%s" % (seq, gap_free))
    for index in range(len(gap_free)):
        i2 = coords_seq2ali(seq,[index])[0]
        i3 = coords_ali2seq(seq,[i2])[0]
        if index != i3:
            print("\n\nOOPs\n\n")
        assert(index==i3)

    print("Testing gap reverse translation on: \n%s and \n%s" % (seq, gap_free))
    for i2 in range(len(seq)):
        i3 = coords_ali2seq(seq,[i2])[0]
        i4 = coords_seq2ali(seq,[i3])[0]
        # i4 need not equal i2 (coords_seq2ali returns the first column with
        # offset i3), but it must map back to the same sequence position
        assert(coords_ali2seq(seq,[i4])[0] == i3)
        
############################################################
#
#  Flipping the case of aligned sequences (to mark exons, for example)
#
############################################################

def upcase_alignment(seqs_aligned, seqs_upcased):
    """Combine alignment and casing, sequence by sequence (see upcase_ali).

    Takes the gap layout from seqs_aligned and the up/down case from
    seqs_upcased, producing sequences that keep the alignment while also
    showing the casing information (example: gene/intergenic).
    tools.upcase_alignment(['--M---A-NO---L-IS--'],['mAnOlIs']) -> ['--m---A-nO---l-Is--']
    """
    return [upcase_ali(aligned, cased)
            for aligned, cased in map(None, seqs_aligned, seqs_upcased)]

def upcase_ali(seq_aligned, seq_upcased):
    """Return seq_aligned with its residues replaced, in order, by the
    characters of seq_upcased; '-' and '.' are kept.

    seq_upcased uppercased must equal seq_aligned without '-' and '.'
    (asserted; on mismatch both are displayed with pp/pw first).
    """
    gapless_seq = gapless(seq_aligned,'-.')
    if string.upper(seq_upcased) != gapless_seq:
        pp({'gapless': gapless_seq, 'upcase': seq_upcased},2,60)
        pw(['upcased','gapless','consensus'],
           [seq_upcased,gapless_seq,compute_clustal_consensus([string.upper(seq_upcased),gapless_seq])])
    assert(string.upper(seq_upcased)==gapless_seq)

    merged = []
    k = 0
    for ch in seq_aligned:
        if ch in '-.':
            merged.append(ch)
        else:
            merged.append(seq_upcased[k])
            k = k+1
    return ''.join(merged)

def upcase_only(seq):
    """Return the characters of seq that are unchanged by uppercasing
    (uppercase letters and non-letters), in order."""
    upper_seq = seq.upper()
    kept = [ch for ch, up in zip(seq, upper_seq) if ch == up]
    return ''.join(kept)

def seqs_upcase_only(seqs):
    """Trim equal-length seqs to the span of uppercase bases common to all.

    Over every sequence without '*', maxstart is the largest index of its
    first uppercase A/C/G/T and minend the smallest index of its last one
    (starting from 0 and len(seqs[0])).  Returns
    get_subseqs(seqs, maxstart, minend); get_subseqs' coordinate convention
    is not visible here.
    """
    assert(all_same([len(seq) for seq in seqs]))
    maxstart,minend = 0,len(seqs[0])
    for seq in seqs:
        if '*' in seq: continue
        # ignore bases that do not occur: min() over raw find() results
        # returned -1 whenever any of A/C/G/T was missing
        firsts = [pos for pos in [seq.find(b) for b in 'ACGT'] if pos >= 0]
        if firsts:
            thisstart = min(firsts)
        else:
            thisstart = -1
        maxstart = max(thisstart,maxstart)
        thisend = max([seq.rfind(b) for b in 'ACGT'])
        minend = min(thisend,minend)
    return get_subseqs(seqs,maxstart,minend)

def first_upcase(seq):
    """Return the smallest index of an uppercase A, C, G or T in seq, or -1
    if there is none (non-negative indices are kept with gte)."""
    positions = tuple([seq.find(base) for base in 'ACGT'])
    if max(positions) == -1:
        return -1
    return min(gte(positions,0))
    
def last_upcase(seq):
    """Return the largest index of an uppercase A, C, G or T in seq, or -1."""
    return max([seq.rfind(base) for base in 'ACGT'])
    
def first_lowcase(seq):
    """Return the smallest index of a lowercase a, c, g or t in seq, or -1
    if there is none (non-negative indices are kept with gte)."""
    positions = [seq.find(base) for base in 'acgt']
    if max(positions) == -1:
        return -1
    return min(gte(positions,0))
    
def last_lowcase(seq):
    """Return the largest index of a lowercase a, c, g or t in seq, or -1."""
    return max([seq.rfind(base) for base in 'acgt'])
    
def upcase_this(this, case_teller):
    """Copy the case of case_teller onto the equally long string `this`.

    Where case_teller's character is unchanged by upper() (uppercase or
    non-letter) the character of `this` is uppercased; else, where it is
    unchanged by lower(), it is lowercased; otherwise it is kept.
    """
    assert(len(this)==len(case_teller))
    out = []
    for k, ch in enumerate(this):
        guide = case_teller[k]
        if guide.upper() == guide:
            out.append(ch.upper())
        elif guide.lower() == guide:
            out.append(ch.lower())
        else:
            out.append(ch)
    return ''.join(out)
    
    

#def upcase_portion(gapseq,upseq):
#    # find the best match of upseq in seq in upcase it
#    seq = string.
#
#
#    # searching the entire sequence
#    complete = string.find_all(upseq)
#    if len(complete)==1: 
#        start,end = complete[0],complete[0]+len(upseq)
#    else:
#        # first find the beginning
#        length_to_search = len(upseq)/2
#        while 1:
#            starts = find_all(upseq)
#            
#        
#    else:
#        
#        
#        
#        
#    alistart,ali = coords_seq2ali(gapseq,[start,end])
    

############################################################
#
# Kinda like find only with sw as the search method
#
############################################################

def sw_find(seq, subsequence):
    """Like string find, but using sw.quick_sw as the search method: return
    the first hit's 'start' minus 1, or -1 when there are no hits."""
    import sw
    hits = sw.quick_sw(subsequence, seq)
    if not hits:
        return -1
    return hits[0]['start'] - 1

# category 7
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       LIST OPERATIONS                    ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

def multimap(functions, list):
    """Apply each of functions, in order, to every element of `list`.

    multimap([f, g], xs) gives [g(f(x)) for x in xs]; `list` itself is
    returned when functions is empty.
    """
    result = list
    for fn in functions:
        result = [fn(item) for item in result]
    return result

def flatten(list_of_lists):
    """Concatenate the sub-sequences of list_of_lists into one new list.

    ex: flatten([[1, 2], [3], []]) -> [1, 2, 3]
    """
    flat_list = []
    # explicit loop: map(flat_list.extend, ...) built a throwaway list of
    # Nones, and does nothing at all where map is lazy
    for sub in list_of_lists:
        flat_list.extend(sub)
    return flat_list

def reverse_list(list):
    """Return a reversed copy of `list`, leaving the original untouched."""
    backwards = list[:]
    backwards.reverse()
    return backwards

def lappend(list, element):
    """Return a copy of `list` with element appended (same as
    flatten([list, [element]])); the original is not modified."""
    extended = list[:]
    extended.append(element)
    return extended

def all_same(list, compfunc=None):
    """Return 1 if all elements of `list` equal its first element, else 0.

    With compfunc, an element x passes when compfunc(x, first) is true.
    Empty and one-element lists give 1.
    """
    if not list or len(list) == 1:
        return 1
    first = list[0]
    for element in list:
        if compfunc:
            matches = compfunc(element, first)
        else:
            matches = element == first
        if not matches:
            return 0
    return 1

def all_diff(list):
    """True when no two elements of `list` are equal (per unique_len)."""
    distinct_count = unique_len(list)
    return distinct_count == len(list)

def all_same_key(diclist, key):
    """Return 1 if every dict in diclist has the same value under key, else 0.
    Empty and one-element lists give 1 without reading key."""
    if not diclist or len(diclist) == 1:
        return 1
    reference = diclist[0][key]
    agree = 1
    for d in diclist:
        if d[key] != reference:
            agree = 0
            break
    return agree

def unique(list, comp=None):
    """Return the distinct elements of `list` in first-seen order.
    With comp, equality is decided by comp (unique_slow); otherwise by
    hashing (unique_fast)."""
    if not comp:
        return unique_fast(list)
    return unique_slow(list, comp)

def unique_len(list):
    """Number of distinct elements in `list` (per unique)."""
    distinct = unique(list)
    return len(distinct)

def unique_fast(list):
    """Return the distinct (hashable) elements of `list` in first-seen order."""
    seen = {}
    result = []
    for item in list:
        if item not in seen:
            seen[item] = None
            result.append(item)
    return result

def unique_slow(list, comp = (lambda x,y: x==y)):
    """Return the elements of `list` not matching (per comp) any earlier
    kept element, in order.  Works for unhashable elements; O(n^2).
    An empty input returns a copy of itself."""
    if not list:
        return list[:]
    kept = []
    for candidate in list:
        for prior in kept:
            if comp(candidate, prior):
                break
        else:
            kept.append(candidate)
    return kept

def unique_hash(list, el2hash=lambda el: el, strict=0):
    """Return elements of `list` with distinct el2hash(element), first occurrence kept.

    el2hash must return a hashable key for every element.  With strict=1,
    it asserts that elements sharing a key are themselves equal.
    """
    first_seen = {}
    kept = []
    for element in list:
        key = el2hash(element)
        if key not in first_seen:
            first_seen[key] = element
            kept.append(element)
        elif strict:
            assert(first_seen[key] == element)
    return kept

def diclist_unique(diclist, key):
    """Keep the first dict for each distinct value of dic[key].

    NOTE: this module defines diclist_unique() again further down; at import
    time that later definition replaces this one.
    """
    def value_of(dic):
        return dic[key]
    return unique_hash(diclist, value_of)

def eliminate_duplicates(list):
    """Collapse runs of equal adjacent elements into a single copy.

    On sorted input this removes all duplicates.  Always returns a new
    sequence; the argument is not modified.
    """
    if len(list) <= 1:
        return list[:]
    collapsed = [list[0]]
    for item in list[1:]:
        if collapsed[-1] != item:
            collapsed.append(item)
    return collapsed

def list_find_all(list, element):
    """Return every index i at which list[i] == element, in ascending order."""
    return [index for index, item in enumerate(list) if item == element]

def merge_sorted(list1, list2):
    """Merge two ascending lists of unique elements into one ascending list.

    An element present in both inputs appears only once in the result.
    """
    merged = []
    left, right = 0, 0
    while left < len(list1) and right < len(list2):
        a, b = list1[left], list2[right]
        if a < b:
            merged.append(a)
            left += 1
        elif a > b:
            merged.append(b)
            right += 1
        else:
            merged.append(a)
            left += 1
            right += 1
    # At most one of these tails is non-empty once the loop exits.
    merged.extend(list1[left:])
    merged.extend(list2[right:])
    return merged

#def list_find(list, element) -> list.index(element)

#def join(list, element):
#    # modifies list, by appending element
#    # returns the list
#    # equivalent to flatten([x,[y]])
#    list.append(element)
#    return list

#def append_unique(list, item):
#    # appends item to list only if it doesn't yet contain it
#    # note:  O(n) has to search the entire list for item first
#    if not item in list: 
#        list.append(item)

def split_list(list, is_separator):
    """Split `list` at every separator element.

    is_separator(item) must return 1 (or True) exactly for the separators.
    Returns (groups, separators).  There is always one more group than there
    are separators, and groups between consecutive separators are empty.

    Example: [3,3,3,'a','b',3,'c','d',3] with is_separator = (x == 3)
        groups     -> [], [], [], ['a','b'], ['c','d'], []
        separators ->   3   3   3           3           3
    """
    boundaries = ([-1]
                  + list_find_all(map(is_separator, list), 1)
                  + [len(list) + 1])
    groups = [list[start + 1:stop]
              for start, stop in zip(boundaries[:-1], boundaries[1:])]
    return groups, mget(list, boundaries[1:-1])

############################################################
#
#  Sorting lists and dictionaries
#
############################################################

def my_sort(list, func=None):
    """Return an ascending sorted copy of `list`; with `func`, sort by func(element)."""
    if not func:
        return sort_copy(list)
    return sort_transform(list, func)

def my_sort_rev(list, func=None):
    """Return a descending sorted copy of `list`; with `func`, sort by func(element).

    This is the ascending my_sort() result, reversed.
    """
    return my_reverse(my_sort(list, func))

def my_reverse(list):
    """Return a reversed copy of `list` without modifying the argument."""
    return list[::-1]

def sort_copy(list):
    """Return a new ascending-sorted list holding the elements of `list`.

    Useful for sorting on the fly without modifying the original sequence.
    """
    ordered = [item for item in list]
    ordered.sort()
    return ordered

def sort_transform(list, function = lambda a: a):
    """Return the elements of `list` sorted ascending by function(element).

    Uses decorate-sort-undecorate.  Ties on function(element) fall back to
    comparing the elements themselves.
    """
    decorated = [(function(item), item) for item in list]
    decorated.sort()
    return cget(decorated, 1)

def sort_diclist(diclist, key):
    """Return the dicts of diclist sorted ascending by dic[key].

    Decorates with the key value instead of using a lambda comparator.
    """
    decorated = list(zip(cget(diclist, key), diclist))
    decorated.sort()
    return cget(decorated, 1)

def sort_diclist_rev(diclist, key):
    """Return the dicts of diclist sorted descending by dic[key].

    Sorts ascending and then reverses, so equal entries end up in reverse
    order.
    """
    decorated = list(zip(cget(diclist, key), diclist))
    decorated.sort()
    decorated.reverse()
    return cget(decorated, 1)

def sort_diclist_multiple(diclist, keys):
    """Return the dicts of diclist sorted ascending by the list of their values at `keys`.

    Values are gathered with cget_multiple() in its default non-strict
    mode, so missing keys are skipped.
    """
    decorated = list(zip(cget_multiple(diclist, keys), diclist))
    decorated.sort()
    return cget(decorated, 1)

def sort_diclist_multiple_rev(diclist, keys):
    """Return the dicts of diclist sorted descending by their values at `keys`.

    This is the ascending sort_diclist_multiple() ordering, reversed.
    """
    decorated = list(zip(cget_multiple(diclist, keys), diclist))
    decorated.sort()
    decorated.reverse()
    return cget(decorated, 1)

def diclist_unique(diclist, key):
    """Keep only the first dict for each distinct value of dic[key], preserving order."""
    seen = {}
    unique_dics = []
    for dic in diclist:
        value = dic[key]
        if value in seen:
            continue
        seen[value] = None
        unique_dics.append(dic)
    return unique_dics
            
    

def shuffle(list):
    """Return a uniformly random permutation of `list`; the argument is not modified.

    Delegates to random.shuffle (Fisher-Yates).  The previous version
    swapped each position with a position drawn from the *whole* list.
    That produces n**n equally likely swap sequences, and these cannot map
    evenly onto the n! permutations, so some orders came out more often
    than others.
    """
    shuffled = list[:]
    random.shuffle(shuffled)
    return shuffled

#def shuffle2(list):
#    NOTE:  Not a correct algorithm.  
#    return my_sort(list, lambda x,random=random: random.uniform(-.5,.5))

def test_shuffle(length=100, iters=20):
    """Eyeball check for shuffle(): the average value landing in each position.

    Shuffles [1..length] `iters` times.  For each position it prints the
    mean value seen there, which should hover around (length+1)/2 for a
    fair shuffle, and then draws quick_histogram() of those means.
    Returns nothing.

    The printed expectation used to be length/2, but the mean of
    1..length is (length+1)/2.
    """
    totals = [0] * length
    for _ in range(0, iters):
        shuffled = shuffle(list(range(1, length + 1)))
        for i in range(0, length):
            totals[i] = totals[i] + shuffled[i]
    averages = [total / float(iters) for total in totals]
    print("The average for each bin after %s iterations should be approximately %s" % (iters, (length + 1) / 2.0))
    print([int(round(value)) for value in averages])
    quick_histogram(averages)

def reorder_names_seqs(names, seqs, ordered):
    """Reorder the parallel lists (names, seqs) so that names follow `ordered`.

    Names listed in `ordered` come first, in that order; entries absent
    from `names` are skipped.  The remaining names follow in whatever
    order set_subtract(names, ordered) returns them.
    Returns (reordered_names, reordered_seqs).
    """
    full_order = ordered + set_subtract(names, ordered)
    positions = [names.index(name) for name in full_order if names.count(name)]
    return mget(names, positions), mget(seqs, positions)

def common_names_seqs(names_seqs_list):
    """Restrict several (names, seqs) alignments to the names they all share.

    input:  [(names1, seqs1), (names2, pseqs2), (names3, tseqs3)]
    output: [common_names, seqs, pseqs, tseqs], where each sequence list is
            ordered like common_names.
    """
    lookups = []
    for names, seqs in names_seqs_list:
        lookup = {}
        mset(lookup, names, seqs)
        lookups.append(lookup)

    common_names = set_intersect_all(cget(names_seqs_list, 0))
    return [common_names] + [mget(lookup, common_names) for lookup in lookups]
        
def missing_data_consensus(names, seqs, pieces):
    """Build a consensus string from the aligned `pieces` of one sequence.

    In a multiple alignment (names, seqs), several pieces of one sequence
    (the names listed in `pieces`) are aligned against the full length of
    the other sequences.  For every alignment column, this collects the
    pieces' characters (after dropping '.') and appends:
      '.'        if no piece covers the column,
      that char  if one character remains, or all remaining ones agree,
      'N'        otherwise.
    Returns the consensus joined into one string, one character per column
    of seqs[0].  It presumably assumes equal-length string sequences;
    TODO confirm.
    """
    # in a multiple alignment (names, seqs)
    # pieces of one of the sequences are aligned
    # with the entire length of the other sequences
    # construct a consensus for the pieces

    # name -> row index in names/seqs (a repeated name keeps its last index)
    name2i = {}
    for name_i in range(0,len(names)):
        name2i[names[name_i]] = name_i
    # presumably rewrites leading/trailing gaps as '.' so they can be told
    # apart from internal gaps -- TODO confirm against clustal_endgaps2dots
    seqs = clustal_endgaps2dots(seqs)

    consensus = []
    for char_i in range(0,len(seqs[0])):

        # characters of every piece at this column
        chars = []
        for seq_i in mget(name2i, pieces):
            chars.append(seqs[seq_i][char_i])

        # Gaps at the end are treated as if
        # we never saw any sequence for that column.
        chars = set_subtract(chars, '.')

        # NOTE(review): name2i was built from *every* entry of names, so
        # has_key() below is always true and `others` always stays empty.
        # The intent was probably to test membership in `pieces`.  As
        # written, the all_same(others) / others[0] checks further down
        # never constrain the result.
        others = []
        for seq_i in range(0,len(seqs)):
            if name2i.has_key(names[seq_i]): pass
            else: others.append(seqs[seq_i][char_i])

        if len(chars) == 0:
            # none of the sequence covered that part
            consensus.append('.')
        elif len(chars) == 1:
            consensus.append(chars[0])
        elif all_same(chars) and all_same(others) and (not others or others[0]==chars[0]):
            # all the sequence pieces covering that base
            # agree (either a char or a gap).  Append it.
            consensus.append(chars[0])
        else:
            # well now, we can use the other sequences in the
            # alignment to do some sort of a majority vote
            # on this particular base, but then we'd be
            # reasoning circularly.  So we just include an N.
            # others would argue i should put a ".", but it's
            # not really missing. we know there's a base there.

            # consensus.append(generate_profile(chars))
            # consensus.append(string.lower(generate_profile(chars)))

            #if all_same(others) and others[0] in chars:
            #    consensus.append(others[0])
            #else:
            consensus.append('N')
    return string.join(consensus,'')

############################################################
############################################################
############################################################

def my_filter(x, test):
    """Return the elements of x whose matching flag in `test` is true.

    `test` is a parallel list of 0/1 flags; any other value fails the
    assertion.  The sequences are paired with map(None, ...), which pads
    the shorter one with None.
    """
    assert(set_subtract(unique(test), [0, 1]) == [])
    flagged = [pair for pair in map(None, x, test) if pair[1]]
    return cget(flagged, 0)

def lt(list, cutoff):
    """Return the elements of `list` that are strictly less than `cutoff`."""
    def below(el):
        return el < cutoff
    return filter(below, list)

def gt(list, cutoff):
    """Return the elements of `list` that are strictly greater than `cutoff`."""
    def above(el):
        return el > cutoff
    return filter(above, list)

def lte(list, cutoff):
    """Return the elements of `list` that are less than or equal to `cutoff`."""
    def at_most(el):
        return el <= cutoff
    return filter(at_most, list)

def gte(list, cutoff):
    """Return the elements of `list` that are greater than or equal to `cutoff`."""
    def at_least(el):
        return el >= cutoff
    return filter(at_least, list)

def within(list, left, right):
    """Return the elements el of `list` with left <= el < right (half-open interval)."""
    def inside(el):
        return left <= el < right
    return filter(inside, list)

def all_gt(list, cutoff):
    """Return 1 if every element is strictly greater than cutoff (or list is empty), else 0."""
    for el in list:
        if el > cutoff:
            continue
        return 0
    return 1

def all_gte(list, cutoff):
    """Return 1 if every element is >= cutoff (or list is empty), else 0."""
    for el in list:
        if el >= cutoff:
            continue
        return 0
    return 1

def all_lt(list, cutoff):
    """Return 1 if every element is strictly less than cutoff (or list is empty), else 0."""
    for el in list:
        if el < cutoff:
            continue
        return 0
    return 1

def all_lte(list, cutoff):
    """Return 1 if every element is <= cutoff (or list is empty), else 0."""
    for el in list:
        if el <= cutoff:
            continue
        return 0
    return 1


def count_lt(list, cutoff):
    """Return how many elements of `list` are strictly less than `cutoff`.

    (The old comment wrongly said it returned the elements themselves.)
    """
    count = 0
    for el in list:
        if el < cutoff:
            count = count + 1
    return count

def count_gt(list, cutoff):
    """Return how many elements of `list` are strictly greater than `cutoff`.

    (The old comment was copied from count_lt and described the wrong
    comparison and return value.)
    """
    count = 0
    for el in list:
        if el > cutoff:
            count = count + 1
    return count

def count_lte(list, cutoff):
    """Return how many elements of `list` are less than or equal to `cutoff`.

    (The old comment was copied from count_lt and described the wrong
    comparison and return value.)
    """
    count = 0
    for el in list:
        if el <= cutoff:
            count = count + 1
    return count

def count_gte(list, cutoff):
    """Return how many elements of `list` are greater than or equal to `cutoff`.

    (The old comment was copied from count_lt and described the wrong
    comparison and return value.)
    """
    count = 0
    for el in list:
        if el >= cutoff:
            count = count + 1
    return count

def count_within(list, left, right):
    """Return how many elements el of `list` satisfy left <= el < right.

    (The old comment said it returned the elements; it returns a count.)
    """
    count = 0
    for el in list:
        if left <= el < right:
            count = count + 1
    return count

#def expand(initial, operation, num_iters):
#    # the opposite of reduce
#    result = [initial]
#    for element in list:
#        result.append(operation(initial, result[-1]))

def map_constant(list, func=lambda x,y: x, constant=0):
    """Return [func(item, constant) for each item of list].

    With the default func, this is a plain copy of the elements.
    """
    return [func(item, constant) for item in list]

def map_repeat(num_iters, func=lambda x,y: x, constant=0):
    """Iterate func from `constant`: [c, f(c,c), f(f(c,c),c), ...].

    The result has num_iters elements, with a minimum of one (the seed
    `constant` itself).
    """
    sequence = [constant]
    while len(sequence) < num_iters:
        sequence.append(func(sequence[-1], constant))
    return sequence

def map_multiple(funcs, list): # same as multimap
    """Apply each function in `funcs`, in order, element-wise over `list`.

    map_multiple([f, g], xs) is map(g, map(f, xs)).
    """
    current = list
    for func in funcs:
        current = map(func, current)
    return current

def pick_n(list, n):
    """Return n elements drawn uniformly at random from `list`, with replacement."""
    import random
    return [random.choice(list) for _ in range(0, n)]

def pick_one(dic):
    """Draw one key of `dic` at random, with probability proportional to its value.

    Example: {'A': .18, 'C': .32, 'G': .32, 'T': .18} returns 'A' with
    probability .18, and so on.  Weights need not sum to 1, and keys with
    zero weight are never returned.

    The previous version returned items[bin+1] for the bin given by
    which_bin(..., safe=1).  That index runs past the end of the item list
    whenever random.uniform() returns its upper bound.  It also carried a
    dead `if 1: ... else:` debug branch.
    """
    items = list(dic.items())
    total = 0
    for key, weight in items:
        total = total + weight
    x = random.uniform(0, total)

    running = 0
    fallback = len(items) - 1
    for index, (key, weight) in enumerate(items):
        running = running + weight
        if weight > 0:
            fallback = index
        if x < running:
            return key
    # Reached only when x equals the total.  uniform() may return its upper
    # bound, and that draw belongs to the last key that has any weight.
    return items[fallback][0]

def which_bin(bins, x, safe=0):
    """Return the index of the bin that x falls into.

    `bins` is an ascending list of boundaries.  With boundaries 0, 5, 10, 15:
        x < 0         -> -1
        0 <= x < 5    -> 0
        5 <= x < 10   -> 1
        10 <= x < 15  -> 2
        x >= 15       -> 3   (len(bins)-1)
    With safe=1, an x exactly equal to the last boundary is counted in the
    last closed interval instead (here 10 <= x <= 15 -> 2).  Callers that
    draw x from [bins[0], bins[-1]] inclusive therefore always land in a
    real interval.

    Previously x >= bins[-1] raised TypeError (`len(i)+1` on an int), and
    safe=1 returned len(bins), which is past every valid bin.
    """
    if x < bins[0]:
        return -1
    for i in range(1, len(bins)):
        if x < bins[i]:
            return i - 1
    # x is at or beyond the last boundary.
    if safe and x == bins[-1] and len(bins) > 1:
        return len(bins) - 2
    return len(bins) - 1

def test_which_bin():
    """Randomized self-check of which_bin() on the unit bins 0, 1, ..., 9.

    Draws 1000 values uniformly from [0, 9] and asserts that each lands in
    the bin whose half-open interval contains it.  The previous version
    called which_bin() but never checked the answer.
    """
    bins = list(range(0, 10))
    for _ in range(0, 1000):
        x = random.uniform(bins[0], bins[-1])
        b = which_bin(bins, x)
        if b == len(bins) - 1:
            # only possible when uniform() returns its upper bound exactly
            assert x >= bins[-1]
        else:
            assert 0 <= b < len(bins) - 1
            assert bins[b] <= x < bins[b + 1]

############################################################
#
#  Compute aggregate values on lists
#
############################################################

def is_increasing(list):
    """Return 1 if `list` never decreases (equal neighbours allowed), else 0."""
    for earlier, later in zip(list, list[1:]):
        if later < earlier:
            return 0
    return 1

def is_decreasing(list):
    """Return 1 if `list` never increases (equal neighbours allowed), else 0."""
    for earlier, later in zip(list, list[1:]):
        if later > earlier:
            return 0
    return 1

def longest_colinear(list):
    """Return the longest candidate from superset(list) that is monotone.

    A candidate qualifies if it is non-decreasing or non-increasing.  The
    first candidate of maximal length wins.  Cost grows with the number of
    candidates superset() produces (presumably all subsets, i.e.
    exponential; verify).
    """
    best = []
    for candidate in superset(list):
        if len(candidate) <= len(best):
            continue
        if is_increasing(candidate) or is_decreasing(candidate):
            best = candidate
    return best

def diclist_longest_colinear(diclist, key):
    """Like longest_colinear, but judges monotonicity on each dict's value at `key`."""
    best = []
    for candidate in superset(diclist):
        if len(candidate) <= len(best):
            continue
        values = cget(candidate, key)
        if is_increasing(values) or is_decreasing(values):
            best = candidate
    return best

def is_positive(x):
    """Return whether x is strictly greater than zero."""
    positive = x > 0
    return positive

def sign(x):
    """Return +1, -1 or 0 according to the sign of x."""
    if x > 0:
        return +1
    if x < 0:
        return -1
    return 0

def sum(l):
    """Return the total of the elements of l, or 0 for an empty/false input.

    NOTE: this module-level definition shadows the builtin sum() for the
    rest of this module.
    """
    if not l:
        return 0
    total = 0
    for item in l:
        total = total + item
    return total

def min_max(l):
    """Return a one-line summary of l: min, median, average, N50 and max."""
    smallest = display_bignum(min(l))
    middle = display_bignum(median(l))
    mean = display_bignum(avg(l))
    n50_value = display_bignum(n50(l))
    largest = display_bignum(max(l))
    return "Min: %s, Median: %s, Avg: %s, N50: %s, Max: %s"%(
        smallest, middle, mean, n50_value, largest)

def n50(l, fifty=50):
    """Return the N50 of a list of lengths (more generally N<fifty>).

    Lengths are sorted in descending order.  The result is the first length
    at which the running total reaches `fifty` percent of the overall
    total.  If that total is never reached (fifty > 100), the smallest
    length is returned.  An empty list gives 0; it previously raised
    UnboundLocalError.
    """
    if not l:
        return 0
    lengths = sorted(l, reverse=True)
    target = sum(lengths) * fifty / 100.0
    running = 0
    for length in lengths:
        running = running + length
        if running >= target:
            return length
    return lengths[-1]

def n50s(l, fiftys=[50,60,70,80,85,90,95], show_n=1):
    """Return a comma-separated report of N<p> values for each p in `fiftys`.

    With show_n, each entry also says how many lengths are at least that
    N<p> value (e.g. 'n50=12,345 in 17').
    """
    parts = []
    for fifty in fiftys:
        cutoff_length = n50(l, fifty)
        part = 'n%s=%s' % (fifty, display_bignum(cutoff_length, 0))
        if show_n:
            # how many supercontigs are at that length or greater
            at_least = len([length for length in l if length >= cutoff_length])
            part = part + ' in %s' % at_least
        parts.append(part)
    return ', '.join(parts)

def percentiles(l, ps=[25,50,75]):
    """Return the values of l at each percentile in `ps` (0..100).

    For percentile p, the value is sorted(l)[int(len(l)*p/100.0)], with the
    rank clamped into range so that p=100 returns the maximum.  An empty
    `l` yields None for every requested percentile.

    The previous version called the undefined `tools.my_sort`.  It also
    used l.index(rank), which looked up the *position of a value equal to
    the rank*, instead of l[rank].
    """
    if not l:
        return [None] * len(ps)
    ordered = sorted(l)
    last = len(ordered) - 1
    result = []
    for p in ps:
        rank = int(len(ordered) * p / 100.0)
        result.append(ordered[max(0, min(rank, last))])
    return result

def log_avg(l, strict=1):
    """Return the geometric mean of l, computed as exp(avg(log(x))).

    With strict=1, every element must be positive (math.log raises
    otherwise).  With strict=0, false values such as 0 are dropped first.
    """
    if strict:
        values = l
    else:
        values = filter(None, l)
    logs = map(math.log, values)
    return math.exp(avg(logs))

def log_variance(l):
    """Return exp(variance(log(x) for x in l)).

    This is the variance computed in log space and mapped back through
    exp(); it is not the variance of l itself.
    """
    logs = map(math.log, l)
    return math.exp(variance(logs))

def avg(l, precise=0):
    """Return the arithmetic mean of l, or 0 for an empty list.

    With precise=0, the total is divided by len(l) with `/`; under Python 2
    integer totals truncate.  With precise=1, the division is always done
    in floating point.
    """
    if not l:
        return 0
    total = 0
    for item in l:
        total = total + item
    if precise:
        return total / float(len(l))
    return total / len(l)

def weighted_avg(l, weights, precise=0, safe=None):
    """Return sum(l[i]*weights[i]) / sum(weights).

    Returns 0 for an empty l.  If the weights sum to zero and `safe` is
    given, `safe` is returned instead of dividing by zero.  precise=1
    forces floating-point division.
    """
    if not l:
        return 0
    assert(len(l) == len(weights))
    numerator, denominator = 0, 0
    for value, weight in zip(l, weights):
        numerator = numerator + value * weight
        denominator = denominator + weight
    if safe is not None and not denominator:
        return safe
    if precise:
        return numerator / float(denominator)
    return numerator / denominator

def safe_weighted_avg(l, weights, default, precise=0):
    """Weighted average of l that returns `default` when the weights sum to zero.

    Returns 0 for an empty l.  precise=1 forces floating-point division.
    """
    if not l:
        return 0
    assert(len(l) == len(weights))
    numerator, denominator = 0, 0
    for value, weight in zip(l, weights):
        numerator = numerator + value * weight
        denominator = denominator + weight
    if not denominator:
        return default
    if precise:
        return numerator / float(denominator)
    return numerator / denominator

#def weighted_avg(l,weights,precise=0):
#    if not l: return 0
#    if sum(weights)==0:
#        print "tools.weighted_avg:  All weights are zero!!  Making them all 1"
#        weights = [1]*len(l)
#    if precise:
#        return sum(vector_vector_mul(l,weights))/float(sum(weights))
#    else: 
#        return sum(vector_vector_mul(l,weights))/sum(weights)

def median(l):
    """Return the median of l, or None for an empty list.

    For odd lengths this is the middle element of the sorted data.  For
    even lengths it is the mean of the two middle elements, as a float.
    The previous version sorted the data twice on odd lengths.  It also
    indexed with `/`, which breaks under true division; `//` is used now.
    """
    if not l:
        return None
    ordered = sorted(l)
    middle = len(ordered) // 2
    if len(ordered) % 2:
        return ordered[middle]
    return (ordered[middle] + ordered[middle - 1]) / 2.0

def majority(l):
    """Return the most frequent element of l, or None for an empty list.

    On a tie among several most-frequent values, returns the middle one
    (index len//2) of the tied values in sorted order.
    """
    if not l:
        return None
    counts = count_same(l)
    top_count = max(counts.values())
    winners = filter_diclist(counts.items(), 1, top_count)
    if len(winners) == 1:
        return winners[0][0]
    tied_values = my_sort(cget(winners, 0))
    return tied_values[len(winners) // 2]

def argmax(list): # def max_i
    """Return the index of the first largest element of a non-empty list."""
    best_index = 0
    best_value = list[0]
    for index, value in enumerate(list):
        if value > best_value:
            best_index, best_value = index, value
    return best_index

def argmin(list):
    """Return the index of the first smallest element of a non-empty list."""
    best_index = 0
    best_value = list[0]
    for index, value in enumerate(list):
        if value < best_value:
            best_index, best_value = index, value
    return best_index

def my_max(list, f=lambda x: x):
    """Return the first element of `list` with the largest f(element)."""
    scores = [f(item) for item in list]
    best = 0
    for index in range(1, len(scores)):
        if scores[index] > scores[best]:
            best = index
    return list[best]

def my_min(list, f=lambda x: x):
    """Return the first element of `list` with the *smallest* f(element).

    (The old comment was copied from my_max and said "largest".)
    Raises IndexError on an empty list.
    """
    scores = [f(item) for item in list]
    best = 0
    for index in range(1, len(scores)):
        if scores[index] < scores[best]:
            best = index
    return list[best]

def stdev(l, failfast=1):
    """Return the sample standard deviation of l (sqrt of variance(l)).

    `failfast` is passed through to variance().
    """
    spread = variance(l, failfast=failfast)
    return math.sqrt(spread)

def variance(l, failfast=1):
    """Return the unbiased sample variance of l (sum of squared deviations / (n-1)).

    With fewer than two samples, this raises ValueError when failfast is
    set and returns 0 otherwise.  The error used to be a string exception.
    Python 2.6+ turns that into an uncatchable-by-name TypeError, and
    Python 3 rejects it outright.
    """
    if (not l) or len(l) == 1:
        if failfast:
            raise ValueError("tools.variance: Not enough samples.  Need >= 2, got %s" % len(l))
        return 0
    total = 0
    for value in l:
        total = total + value
    mean = total / float(len(l))
    squares = 0
    for value in l:
        squares = squares + (value - mean) * (value - mean)
    return squares / (len(l) - 1)

def normalize_mean_shift(list, desired_mean):
    """Shift every element so the list's mean becomes desired_mean.

    The shift uses avg(list) without precise=1, so integer data may be
    shifted by a truncated mean under Python 2.
    """
    shift = avg(list) - desired_mean
    return vector_scalar_sub(list, shift)

def normalize_sum_scale(list, desired_sum):
    """Scale every element so the elements add up to desired_sum."""
    factor = desired_sum / float(sum(list))
    return vector_scalar_mul(list, factor)

def normalize_max_scale(list, desired_max):
    """Scale every element so the largest element becomes desired_max."""
    factor = desired_max / float(max(list))
    return vector_scalar_mul(list, factor)

def normalize_max_min(list, desired_max, desired_min):
    """Linearly rescale `list` so its maximum maps to desired_max and its minimum to desired_min.

    Each x becomes desired_min + (x - lo) * (desired_max - desired_min) / (hi - lo).
    If all elements are equal (hi == lo), every element maps to desired_max.
    Previously desired_min was ignored and the list was only multiplied by
    desired_max / max(list), which made this identical to
    normalize_max_scale().
    """
    hi = max(list)
    lo = min(list)
    if hi == lo:
        return [float(desired_max) for _ in list]
    scale = (desired_max - desired_min) / float(hi - lo)
    return [desired_min + (x - lo) * scale for x in list]

def normalize_mean_scale(list, desired_mean):
    """Scale every element so the list's (floating-point) mean becomes desired_mean."""
    factor = desired_mean / float(avg(list, 1))
    return vector_scalar_mul(list, factor)

def normalize_mean_stdev(list, desired_mean):
    """Shift every element so the list's mean becomes desired_mean.

    NOTE(review): despite the name, this does not rescale the standard
    deviation.  It is identical to normalize_mean_shift().
    """
    shift = avg(list) - desired_mean
    return vector_scalar_sub(list, shift)

def normalize_sum_to(list, new_total):
    """Return a copy of `list` scaled so its elements add up to new_total."""
    factor = new_total / float(sum(list))
    return [factor * item for item in list]

def covariance(x, y):
    """Return the sample covariance of paired samples x and y (divides by n-1).

    Computed as sum((x[i]-mean_x) * (y[i]-mean_y)) / (n-1).  The previous
    one-pass form, n/(n-1) * (E[XY] - E[X]E[Y]), is mathematically equal.
    It loses most of its precision to cancellation when the means are
    large relative to the spread.
    Raises ZeroDivisionError for fewer than two samples.
    """
    assert(len(x) == len(y))
    n = len(x)
    total_x, total_y = 0, 0
    for i in range(n):
        total_x = total_x + x[i]
        total_y = total_y + y[i]
    mean_x = total_x / float(n)
    mean_y = total_y / float(n)
    cross = 0
    for i in range(n):
        cross = cross + (x[i] - mean_x) * (y[i] - mean_y)
    return cross / (n - 1)
    
def correlation(x, y):
    """Return the Pearson correlation of the points (x[i], y[i])."""
    cov = covariance(x, y)
    return cov / (stdev(x) * stdev(y))

def fit_line(x, y):
    """Fit a least-squares y-on-x regression line to the points (x[i], y[i]).

    Returns (slope, offset, noise_stdev).  noise_stdev is the sample
    standard deviation of the residuals y - (slope*x + offset).
    Raises ValueError when variance(x) is zero; it used to raise a string
    exception.
    """
    varx = variance(x)
    if not varx:
        raise ValueError("tools.fit_line: y-on-x regression impossible when variance(x) is zero.")

    slope = covariance(x, y) / varx
    # precise=1: for integer samples, avg() would otherwise truncate under
    # Python 2 integer division and shift the intercept.
    offset = avg(y, 1) - slope * avg(x, 1)

    noise = vector_vector_sub(y, ax_plus_b(x, slope, offset))

    return slope, offset, stdev(noise)

def ax_plus_b(x, a, b):
    """Return [a*v + b for each v in x]."""
    return [a * value + b for value in x]

def ax_plus_b_noise(x, a, b, s):
    """Return [a*v + b + e for each v in x], with independent noise e ~ N(0, s^2)."""
    return [a * value + b + random.gauss(0, s) for value in x]

def test_fit_line():
    """Visual self-check for fit_line() on noisy synthetic data.

    Generates 10000 points on y = 2x - 300 with N(0, 20^2) noise and plots
    them.  It then fits a line, shows a histogram of the residuals, and
    prints the true and fitted parameters.
    """
    x = [100 * random.random() for _ in range(0, 10000)]
    true_slope, true_offset = 2, -300
    y = ax_plus_b_noise(x, true_slope, true_offset, 20)

    quick_plot(x, y)

    fitted_slope, fitted_offset, sigma = fit_line(x, y)
    predicted = ax_plus_b(x, fitted_slope, fitted_offset)

    quick_histogram(vector_vector_sub(y, predicted))

    print('a,b=(%s,%s) guess=(%s,%s) with SigmaNoise=%s' % (
        true_slope, true_offset, fitted_slope, fitted_offset, sigma))

def find_outliers(samples):
    """Print how `samples` spread around their mean, in one-stdev bands.

    The report has three parts:
      - the count below mean-3*std;
      - the count and percentage in each band [mean+i*std, mean+(i+1)*std)
        for i = -3..2;
      - the count at or above mean+3*std.
    Returns nothing.

    Two fixes: the band loop used range(-3, 2), which silently skipped the
    [mean+2*std, mean+3*std) band.  And the mean came from avg() without
    precise=1, so it was truncated for integer samples under Python 2.
    """
    samples = samples[:]  # make a copy
    samples.sort()
    mean, std = avg(samples, 1), stdev(samples)
    low_cut = mean - 3 * std
    print("Less than %s: %s" % (low_cut, len([s for s in samples if s < low_cut])))
    for i in range(-3, 3):
        low, high = mean + i * std, mean + (i + 1) * std
        samples_in = [s for s in samples if low <= s < high]
        print("Range %s to %s: %s samples (%s)" % (
            low, high, len(samples_in), safe_percentile(len(samples_in), len(samples))))
    high_cut = mean + 3 * std
    print("More than %s: %s" % (high_cut, len([s for s in samples if high_cut <= s])))

############################################################
#
#  Primitive Operations
#
############################################################

def vector_add_noise(x, s):
    """Return a copy of x with independent N(0, s^2) noise added to each element."""
    import random
    return [value + random.gauss(0, s) for value in x]

def vector_scalar_op(v1, a, op):
    """Return [op(v, a) for each v in v1] as a new list."""
    return [op(item, a) for item in v1]

def vector_scalar_add(v1, a):
    """Return a new list with scalar `a` added to every element of v1."""
    return vector_scalar_op(v1, a, op=operator.add)

def vector_scalar_sub(v1, a):
    """Return a new list with scalar `a` subtracted from every element of v1."""
    return vector_scalar_op(v1, a, op=operator.sub)

def vector_scalar_mul(v1, a):
    """Return a new list with every element of v1 multiplied by scalar `a`."""
    return vector_scalar_op(v1, a, op=operator.mul)

def vector_vector_op(v1, v2, op):
    """Return [op(v1[i], v2[i]) for every i]; v1 and v2 must have equal length."""
    assert(len(v1) == len(v2))
    return [op(left, right) for left, right in zip(v1, v2)]

def vector_vector_add(v1, v2):
    """Return the element-wise sum of two equal-length vectors."""
    return vector_vector_op(v1, v2, op=operator.add)
    
def vector_vector_sub(v1, v2):
    """Return the element-wise difference v1 - v2 of two equal-length vectors."""
    return vector_vector_op(v1, v2, op=operator.sub)
    
def vector_vector_mul(v1, v2):
    """Return the element-wise product of two equal-length vectors."""
    return vector_vector_op(v1, v2, op=operator.mul)
    
def vector_vector_div(v1, v2):
    """Return the element-wise quotient v1 / v2 (operator.div, Python 2 semantics)."""
    return vector_vector_op(v1, v2, op=operator.div)

############################################################

def float_range(minv, maxv, increment):
    """Return [minv, minv+increment, minv+2*increment, ...], keeping values < maxv.

    This is range() for floats; increment must be positive.  Each value is
    computed as minv + k*increment rather than by repeated addition, so
    rounding error no longer accumulates.  float_range(0, 1, 0.1) now has
    10 elements; the accumulating version produced 11, the last being
    0.9999999999999999.
    """
    assert(increment > 0)
    values = []
    k = 0
    current = minv
    while current < maxv:
        values.append(current)
        k = k + 1
        current = minv + k * increment
    return values

def float_range_rev(minv, maxv, increment):
    """Return [minv, minv+increment, ...] counting down, keeping values > maxv.

    increment must be negative.  Each value is computed as
    minv + k*increment rather than by repeated addition, so rounding error
    no longer accumulates.  float_range_rev(1, 0, -0.1) now has 10
    elements; the accumulating version also emitted ~1.4e-16.
    """
    assert(increment < 0)
    values = []
    k = 0
    current = minv
    while current > maxv:
        values.append(current)
        k = k + 1
        current = minv + k * increment
    return values

############################################################
# all_pairs
    
def cumulative_sum(quality):
    """Return running totals: result[i] = quality[0] + ... + quality[i].

    An empty/false input is returned unchanged (the same object).
    """
    if not quality:
        return quality
    totals = [quality[0]]
    for value in quality[1:]:
        totals.append(totals[-1] + value)
    return totals

def cumulative_avg(quality):
    """Return running means: result[i] = mean(quality[0..i]).

    result[0] is quality[0] itself; later entries are floats.  An
    empty/false input is returned unchanged.  Previously the i-th running
    sum was divided by i instead of i+1, the number of elements summed.
    """
    if not quality:
        return quality
    running_sum = quality[0]
    averages = [quality[0]]
    for i in range(1, len(quality)):
        running_sum = running_sum + quality[i]
        averages.append(running_sum / float(i + 1))
    return averages

def list2pairs(list, dist=lambda a,b: (a,b)):
    """Apply dist(a, b) to each pair of consecutive elements of `list`.

    Given an ordered list of n positions and dist=lambda a,b: b-a, this
    yields the n-1 gaps between neighbours.  The default dist returns the
    pairs themselves.
    """
    return [dist(a, b) for a, b in zip(list[:-1], list[1:])]

def list2pairs_all(list,
                   dist=lambda a,b: (a,b),
                   cutoff = None):
    """Apply dist(a, b) to every pair (list[i], list[j]) with i < j.

    If `cutoff` is given, only pairs with cutoff(a, b, i, j) true are kept.
    Prints progress every 100 outer iterations and every 10000 kept pairs.
    """
    pairs = []
    total = len(list)
    for i in range(total):
        if i % 100 == 0:
            print('i=%s/%s' % (i, total))
        a = list[i]
        for j in range(i + 1, total):
            b = list[j]
            if cutoff and not cutoff(a, b, i, j):
                continue
            pairs.append(dist(a, b))
            if len(pairs) % 10000 == 0:
                print("pairs=%s" % len(pairs))
    return pairs

def all_pairs_in_range(min, max):
    """Return every pair (i, j) with min <= i < j < max, in lexicographic order."""
    return [(i, j) for i in range(min, max) for j in range(i + 1, max)]

def all_pairs(list):
    """Return every pair (list[i], list[j]) with i < j, in index order."""
    n = len(list)
    return [(list[i], list[j]) for i in range(n) for j in range(i + 1, n)]

############################################################
#
#  Window-based computations
#
############################################################

def sum_n_continuous(a, n):
    """Return left-aligned window sums of width n, zero-padded to len(a).

    For i in 0..len(a)-n, s[i] = a[i] + a[i+1] + ... + a[i+n-1].  The
    trailing positions, where a full window does not fit, are 0.
    Example: a=[1,2,3,4,5,6,7], n=3 -> [6,9,12,15,18,0,0]
    """
    windows = [0] * len(a)
    for start in range(len(a) - n + 1):
        total = 0
        for value in a[start:start + n]:
            total = total + value
        windows[start] = total
    return windows

def sum_window(quality, n):
    """Return the sum of every run of n consecutive elements of quality.

    The result has len(quality)-n+1 entries; entry k is
    quality[k] + ... + quality[k+n-1], computed from prefix sums.

    The previous version did cumulative_sum(quality).insert(0, 0).  For an
    empty input, cumulative_sum returns the caller's own list, so that
    insert silently mutated the argument into [0].
    """
    prefix = [0]
    for value in quality:
        prefix.append(prefix[-1] + value)
    return map(operator.sub, prefix[n:], prefix[0:-n])

def best_n_continuous(quality, n):
    """Return (start, end), inclusive, of the n-long window with the highest total quality.

    The earliest best window wins ties.  If no window of length n fits,
    returns (0, len(quality)-1).
    """
    window_sums = sum_window(quality, n)
    if not window_sums:
        return 0, len(quality) - 1
    start = window_sums.index(max(window_sums))
    return start, start + n - 1

############################################################
#
#   Multiple synchronized lists
#
############################################################

def unpack(el):
    """Transpose a sequence of equal-length rows into a list of column lists.

    unpack(([1,2,3], ['a','b','c'])) -> [[1,'a'], [2,'b'], [3,'c']]
    Returns [] for an empty/false input.  The row count is taken from
    el[0].
    """
    if not el:
        return []
    return [[row[index] for row in el] for index in range(len(el[0]))]

def intercalate(lists): # shuffle
    """Interleave equal-length lists, taking one element from each in turn.

    intercalate([[1,2,3],['a','b','c']]) -> [1,'a',2,'b',3,'c']
    """
    assert(all_same(map(len, lists)))

    merged = []
    for position in range(len(lists[0])):
        merged.extend(cget(lists, position))
    return merged

#example of unpack
# list_of_elements = map(None, ['do','re','mi'], [1,2,3], ['one','two','three'], ['a','b','c'])
# notes, nums, numbers, chars = unpack(list_of_elements)

# unpack(([1,2,3], ['a','b','c'], ['do','re','mi'])) -> [[1, 'a', 'do'], [2, 'b', 're'], [3, 'c', 'mi']]
# unpack(([1, 'a', 'do'], [2, 'b', 're'], [3, 'c', 'mi'])) -> [[1,2,3], ['a','b','c'], ['do','re','mi']]

def sort_synchronized(tuple_of_lists, i, comp_operation = cmp):
    """Sort several parallel lists together, ordered by the i-th list.

    comp_operation is a Python 2 style cmp function applied to the i-th
    values.  Returns the reordered lists.
    Usage: a,b,c,d,e = sort_synchronized((a,b,c,d,e), 2, cmp)
    """
    rows = unpack(tuple_of_lists)
    def compare(left, right):
        return comp_operation(left[i], right[i])
    rows.sort(compare)
    return unpack(rows)

# to sort on the third item (index 2) use:
#a,b,c,d,e = sort_synchronized((a,b,c,d,e),2,cmp)

def filter_synchronized(tuple_of_lists, i, filter_operation = None):
    """Filter several parallel lists together, keeping positions chosen by the i-th list.

    A position is kept when filter_operation(value) is true, or, without a
    filter_operation, when the i-th value itself is true.  If nothing
    survives, returns one empty list per input list.
    """
    rows = unpack(tuple_of_lists)
    if filter_operation:
        rows = [row for row in rows if filter_operation(row[i])]
    else:
        rows = [row for row in rows if row[i]]
    if not rows:
        return [[] for _ in tuple_of_lists]
    return unpack(rows)

############################################################
#
#  PACK And UNPACK for DICTIONARIES
#
############################################################

def dict_pack(dict, keys, newkeys):
    """Turn a dict of parallel lists into a list of dicts, one per position.

    Example:
        dict_pack(d, ['num_list', 'notes_list', 'letter_list'],
                     ['num', 'note', 'letter'])
    with d = {'num_list':    [1, 2, 3],
              'notes_list':  ['do', 're', 'mi'],
              'letter_list': ['a', 'b', 'c']}
    returns
        [{'num': 1, 'note': 'do', 'letter': 'a'},
         {'num': 2, 'note': 're', 'letter': 'b'},
         {'num': 3, 'note': 'mi', 'letter': 'c'}]
    This is the inverse of dict_unpack().

    keys and newkeys should have equal length.  A mismatch only prints an
    alert.  Pairing uses map(None, ...), so the shorter side is padded
    with None: surplus newkeys get None values, and surplus values are
    stored under the key None.
    (The old docstring called a non-existent `pack` and showed unrelated
    result keys.)
    """
    if not len(keys) == len(newkeys):
        print("ALERT! unmatched lengths of keys=" + repr(keys) + " and newkeys=" + repr(newkeys))
    tuple_of_lists = mget(dict, keys)
    list_of_tuples = unpack(tuple_of_lists)
    list_of_dicts = map(lambda tuple, newkeys=newkeys: items2dic(map(None, newkeys, tuple)),
                        list_of_tuples)
    return list_of_dicts

def dict_unpack(list_of_dicts, keys, newkeys):
    """Turn a list of dicts into a dict of parallel lists (the inverse of dict_pack).

    The values at newkeys[i] are gathered from every dict, in order, and
    stored as a list under keys[i].
    """
    rows = [mget(dic, newkeys) for dic in list_of_dicts]
    columns = unpack(rows)
    return items2dic(map(None, keys, columns))

# category 8
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       DICTIONARY OPERATIONS              ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

############################################################
#
#  Creating and copying dictionaries
#
############################################################

def empty_dics(n):
    """Return a list of n distinct empty dictionaries.

    Unlike [{}]*n, the entries are separate objects rather than n
    references to one dict.
    """
    return [{} for _ in range(0, n)]

def empty_lists(n):
    """Return a list of n distinct empty lists.

    Unlike [[]]*n, the entries are separate objects rather than n
    references to one list.
    """
    return [[] for _ in range(0, n)]

def items2dic(items):
    """Build a dictionary from a sequence of (key, value) pairs.

    items2dic(d.items()) returns a copy of d.  Later pairs overwrite
    earlier ones with the same key.
    """
    return dict([pair for pair in items])

#def dic_subset(dic, keys):
#    keys = set_intersect(keys,dic.keys())
#    newdic = items2dic(map(None,keys,tools.mget(dic,keys)))
#    return newdic

def copy_dic(dic):
    """Return a shallow copy of a dictionary (the dict analogue of a[:])."""
    return dict([pair for pair in dic.items()])

def map_on_dic(dic, func):
    """Return a new dict with the same keys as `dic` and each value replaced by func(value)."""
    transformed = {}
    for key, value in dic.items():
        transformed[key] = func(value)
    return transformed

def set_key(dictionary, key, value):
    """Assign dictionary[key] = value; item assignment as a function, for use in map().

    Returns None.
    """
    dictionary[key] = value
    return None

def reverse_dictionary(dic, strict=0):
    """Return the inverse mapping value -> key.

    When several keys share a value, the last one visited wins.  With
    strict=1, it asserts that values are unique.
    """
    inverse = {}
    for key, value in dic.items():
        if strict:
            assert(value not in inverse)
        inverse[value] = key
    return inverse

############################################################
#
#  Getting and setting key(s) in dictionary(ies)
#
############################################################

def mget(list, keys, strict=1):
    """Return [list[k] for k in keys] for a dict or sequence.

    With strict=0, keys for which list.has_key(k) is false are skipped, so
    the non-strict form needs a dict-like container.
    """
    if strict:
        return [list[key] for key in keys]
    return [list[key] for key in keys if list.has_key(key)]

def mget_l(list, indices, strict=1):
    """Return [list[i] for i in indices].

    With strict=0, indices outside the valid range -len(list) <= i < len(list)
    are skipped.  Previously the non-strict test was
    0 <= abs(i) < len(list), which wrongly dropped the valid index
    -len(list).
    """
    results = []
    for index in indices:
        if strict or -len(list) <= index < len(list):
            results.append(list[index])
    return results

def mget_d(dict, keys):
    """Return [dict.get(k) for k in keys]; missing keys give None."""
    return [dict.get(key) for key in keys]

def mset(dict, keys, values):
    """Set dict[keys[i]] = values[i] for every i.

    This reproduces the pairing of Python 2's map(None, keys, values): if
    the sequences differ in length, the shorter is padded with None.
    Extra keys are therefore set to None, and extra values are stored
    under the key None.
    """
    key_items = [key for key in keys]
    value_items = [value for value in values]
    for position in range(max(len(key_items), len(value_items))):
        if position < len(key_items):
            key = key_items[position]
        else:
            key = None
        if position < len(value_items):
            value = value_items[position]
        else:
            value = None
        dict[key] = value

def cget(diclist, key, strict=1): # cross_get was: gather(diclist,key)
    """Gather item `key` from every element of diclist (dicts or sequences).

    strict=1: every element must have `key`; the result is parallel to
    diclist.
    strict=0: elements that are empty/false, or that lack `key` according
    to generic_has_key(), are skipped.
    """
    if strict:
        return [diclist[i][key] for i in range(len(diclist))]
    return [dic[key] for dic in diclist if dic and generic_has_key(dic, key)]

def my_has_key(dic, key, generic=0):
    """Membership test for a dictionary key.

    With `generic` true, generic_has_key is used instead, so lists, tuples
    and strings are also accepted (with `key` as an index)."""
    if not generic:
        return key in dic
    return generic_has_key(dic, key)

def generic_has_key(dic_or_list, key):
    """Tell whether `key` can be looked up in `dic_or_list` without error.

    For lists, tuples and strings `key` is an integer index; negative
    indices count from the end.  For dictionaries it is a key.  The type
    check is exact, so subclasses are not accepted.

    Raises TypeError for any other container type.
    """
    which_type = type(dic_or_list)
    if which_type in (list, tuple, str):
        if key >= 0:
            return len(dic_or_list) > key
        return len(dic_or_list) >= -key
    if which_type is dict:
        return key in dic_or_list
    # A real exception class: a string exception ("raise 'msg', value") is
    # itself a TypeError on Python >= 2.6 and hid this message.
    raise TypeError("Unexpected type (not Dict or List): %s" % (which_type,))

def cget_l(lists, index, strict=1):
    """Gather element `index` from every sequence in `lists`.

    With `strict` true every sequence must have that index (IndexError
    otherwise).  With `strict` false, sequences too short for `index` are
    skipped.  A sequence of length n accepts the indices -n .. n-1.
    """
    if strict:
        return [seq[index] for seq in lists]
    # The former negative-index test (len >= abs(index)+1) wrongly skipped
    # sequences of length exactly -index, e.g. index -1 on a 1-element list.
    return [seq[index] for seq in lists if -len(seq) <= index < len(seq)]

def cset(diclist, key, valuelist):
    """Set diclist[i][key] = valuelist[i] for every position i, in place.

    As with the map(None, ...) idiom this mirrors, the shorter sequence is
    padded with None.  Missing values therefore store None, and missing
    dictionaries fail on the None placeholder.
    """
    dics, vals = list(diclist), list(valuelist)
    width = max(len(dics), len(vals))
    dics.extend([None] * (width - len(dics)))
    vals.extend([None] * (width - len(vals)))
    for i in range(width):
        dics[i][key] = vals[i]

def cget_multiple(diclist, keys, strict=0):
    """For every dictionary in `diclist`, fetch all `keys` with mget.

    Returns one list of values per dictionary."""
    return [mget(dic, keys, strict) for dic in diclist]

def cget_deep(diclist, key_series, strict=0):
    """Follow a path of keys through every element of `diclist`.

    Equivalent to applying cget once per key in `key_series`, feeding each
    result into the next lookup."""
    layer = diclist
    for path_key in key_series:
        layer = cget(layer, path_key, strict)
    return layer

def dic_subset(dic, keys, strict=1):
    """Return a new dictionary restricted to `keys` (formerly `dicsel`).

    With `strict` true every key must be present in `dic` (KeyError
    otherwise).  With `strict` false, missing keys are silently left out.
    """
    subset = {}
    for k in keys:
        if strict or k in dic:
            subset[k] = dic[k]
    return subset

def diclist_subset(diclist, keys, strict=1):
    """Apply dic_subset(dic, keys, strict) to every dictionary in `diclist`
    and return the list of new dictionaries."""
    return [dic_subset(dic, keys, strict) for dic in diclist]

def rename_key(diclist, oldkey, newkey):
    """Rename `oldkey` to `newkey` in every dictionary of `diclist`, in place.

    Every dictionary must contain `oldkey`.  Otherwise KeyError is raised
    before any dictionary is modified.  An existing `newkey` entry is
    overwritten.  Renaming a key to itself is a no-op; previously the value
    was copied onto itself and then deleted, losing the key.
    """
    if oldkey == newkey:
        return
    # collect first so a missing key aborts before anything is changed
    values = [dic[oldkey] for dic in diclist]
    for dic, value in zip(diclist, values):
        dic[newkey] = value
    for dic in diclist:
        # the same dictionary may appear more than once in diclist
        if oldkey in dic:
            del dic[oldkey]

def rename_keys(diclist, oldkeys, newkeys):
    """Apply rename_key for every (oldkeys[i], newkeys[i]) pair, in order.

    As with map(None, ...), the shorter of the two key lists is padded with
    None."""
    olds, news = list(oldkeys), list(newkeys)
    width = max(len(olds), len(news))
    olds.extend([None] * (width - len(olds)))
    news.extend([None] * (width - len(news)))
    for i in range(width):
        rename_key(diclist, olds[i], news[i])

def apply_to_key(diclist, key, function):
    """Replace d[key] with function(d[key]) in every dictionary of
    `diclist` (all of them must contain `key`)."""
    transformed = [function(old) for old in cget(diclist, key)]
    cset(diclist, key, transformed)
    
def apply_to_keys(diclist, keys, function):
    """Run apply_to_key for each key in `keys`, one key at a time."""
    for one_key in keys:
        apply_to_key(diclist, one_key, function)

def cdel(diclist, key):
    """Delete `key` from every dictionary in `diclist` that has it.

    Returns the number of deletions performed."""
    removed = 0
    for dic in diclist:
        if key not in dic:
            continue
        del dic[key]
        removed += 1
    return removed

def cdel_multiple(diclist, keys):
    """Run cdel for each key; returns the per-key deletion counts."""
    return [cdel(diclist, one_key) for one_key in keys]

def mdel(dic, keys):
    """Delete every key in `keys` that `dic` contains.

    Returns the number of entries actually removed."""
    removed = 0
    for k in keys:
        if k not in dic:
            continue
        del dic[k]
        removed += 1
    return removed

def mdel_list(list, indices):
    """Delete the elements at `indices` from `list`, in place.

    Returns the number of elements actually deleted.
    - Indices at or beyond the end of the list are ignored.
    - Negative indices count from the end; those reaching before the start
      are ignored.
    - Each position is deleted at most once, even if listed several times.
    - The caller's `indices` sequence is left untouched.
    """
    size = len(list)
    targets = {}
    for index in indices:
        if index < 0:
            index = index + size  # normalise so it sorts with the others
        if 0 <= index < size:
            targets[index] = 1
    # delete from the back so earlier positions keep their meaning
    doomed = sorted(targets.keys(), reverse=True)
    for index in doomed:
        del list[index]
    return len(doomed)

def get_all_but(list, i):
    """Return a new list with every element of `list` except position `i`."""
    return [list[j] for j in range(len(list)) if j != i]

############################################################
#
#  Counting dictionary keys
#
############################################################

def describe_elements(list,sorter=lambda x:x):
    """Summarise a list as 'a,b(x3),c'.

    Each distinct item appears once, followed by '(xN)' when it occurs
    N > 1 times.  Items are ordered by my_sort (defined elsewhere) with
    sorter(item) as the sort function.
    """
    result = []
    ordered = my_sort(count_same(list).items(),
                      lambda pair, sorter=sorter: sorter(pair[0]))
    for item, count in ordered:
        # Wrap item in a 1-tuple: '%s' % item would treat a tuple item as
        # the argument list and raise TypeError.
        if count == 1:
            result.append('%s' % (item,))
        else:
            result.append('%s(x%s)' % (item, count))
    return ','.join(result)

def count_same(list):
    """Count occurrences of each distinct item of `list`.

    Returns a dictionary mapping every unique item to the number of times
    it appears."""
    tally = {}
    for element in list:
        tally[element] = tally.get(element, 0) + 1
    return tally

def countdic2percdic(dic, total=None):
    """Turn a dictionary of counts into one of percentages.

    Each value becomes perc(count, total); perc is defined elsewhere in
    this module.  `total` defaults to the sum of all counts."""
    if total is None:
        total = sum(dic.values())
    return dict([(key, perc(dic[key], total)) for key in dic.keys()])

def countdic2ratiodic(dic):
    """Turn a dictionary of counts into fractions of the total (floats
    summing to 1)."""
    grand_total = float(sum(dic.values()))
    return dict([(key, dic[key] / grand_total) for key in dic.keys()])

def countdic_sum(countdic1, countdic2, strict=0):
    """Add two count dictionaries key by key.

    With `strict` true both dictionaries must contain every key.  Otherwise
    a key missing from one side contributes 0 (gathered via a non-strict
    cget).
    """
    summed = {}
    for key in unique(countdic1.keys() + countdic2.keys()):
        if strict:
            summed[key] = countdic1[key] + countdic2[key]
        else:
            summed[key] = sum([0] + cget([countdic1, countdic2], key, 0))
    return summed
    
def count_key_instances(dictionary_list):
    """Count in how many dictionaries each key occurs.

    Example: of 15 dictionaries, 13 had 'name' and 12 had 'length' gives
    {'name': 13, 'length': 12, ...}.
    """
    counts = {}
    for dictionary in dictionary_list:
        for key in dictionary.keys():
            counts[key] = counts.get(key, 0) + 1
    return counts

def sum_counts(counts):
    """Sum several count dictionaries (e.g. from count_key_instances).

    Returns a dictionary mapping every key seen in any of `counts` to the
    sum of its values.  Each sum starts from 0 and adds values in the order
    of `counts`, as before.  This is now a single pass: the old version
    recomputed the full sum once for every occurrence of every key.
    """
    totals = {}
    for countdic in counts:
        for key, value in countdic.items():
            totals[key] = totals.get(key, 0) + value
    return totals

def mul_counts(countdic1, countdic2):
    """Multiply the counts of the keys the two dictionaries have in common;
    keys present in only one of them are dropped."""
    return dict([(key, countdic1[key] * countdic2[key])
                 for key in countdic1.keys() if key in countdic2])
    

############################################################
#
#  Grouping Into Dictionaries
#
############################################################

def list2map(list):
    """Map each value of `list` to the ascending list of indices holding it."""
    positions = {}
    for i in range(len(list)):
        positions.setdefault(list[i], []).append(i)
    return positions

def diclist2map(diclist, key):
    """Map each value of d[key] (for d in diclist) to the indices of the
    dictionaries carrying that value.  Every dictionary must have `key`."""
    values = cget(diclist, key)
    return list2map(values)

def reverse_map_dic(dic,strict=1):
    """Invert `dic`, mapping each value back to a key that holds it.

    With `strict` true (the default) the mapping must be one-to-one, and
    ValueError is raised when two keys share a value.  With `strict` false
    the first key seen in dic.items() order wins.
    """
    reverse = {}
    for key, value in dic.items():
        if value not in reverse:
            reverse[value] = key
        elif strict:
            # A real exception class: string exceptions are themselves a
            # TypeError on Python >= 2.6, and '...%s' % value also broke for
            # tuple values before the exception was even raised.
            raise ValueError(
                "Not one-to-one mapping.  Two keys map to same value %s: %s"
                % (value, (key, reverse[value])))
    return reverse

def reverse_map_list(list,strict=1):
    """Map each value of `list` back to its index.

    With `strict` true (the default) values must be distinct, and
    ValueError is raised when two indices hold the same value.  With
    `strict` false the first index of each value wins.
    """
    reverse = {}
    for index in range(len(list)):
        value = list[index]
        if value not in reverse:
            reverse[value] = index
        elif strict:
            # A real exception class instead of a string exception (a
            # TypeError on Python >= 2.6); value is wrapped so tuples format.
            raise ValueError(
                "Not one-to-one mapping.  Two indices contain same value %s: %s"
                % (value, (index, reverse[value])))
    return reverse

def group_diclist(diclist, key):
    """Group the dictionaries of `diclist` by their value for `key`.

    Returns {value: [dictionaries with d[key] == value]}.  Each sublist
    keeps the dictionaries in their original order.
    """
    buckets = {}
    for dic in diclist:
        buckets.setdefault(dic[key], []).append(dic)
    return buckets

def group_single(list, elmt2key = lambda name: name, elmt2value = lambda name: name):
    """Group the elements of `list` by elmt2key(element).

    Returns a dictionary where result[k] lists elmt2value(e) for every
    element e with elmt2key(e) == k, in input order.
    """
    groups = {}
    for element in list:
        groups.setdefault(elmt2key(element), []).append(elmt2value(element))
    return groups

def group_bipartite(links, edge2point1=lambda link: link[0], edge2point2=lambda link: link[1]):
    """Join the edges (links) of a bipartite graph into equivalence classes.

    Two edges belong to the same class when they share one of their two
    vertices.  The two vertex sets get separate name spaces by prefixing
    'x' to vertices from edge2point1 and 'y' to those from edge2point2.
    The resulting keys are grouped with group_anyof.
    """
    def link2keys(link, f1=edge2point1, f2=edge2point2):
        # 1-tuples so that tuple-valued vertices are formatted rather than
        # unpacked as format arguments (which raised TypeError)
        return ['x%s' % (f1(link),), 'y%s' % (f2(link),)]
    return group_anyof(links, link2keys)

def group_anyof(list, elmt2keys = lambda elmt: elmt, debug=0):
    """Partition `list` into groups of elements connected through shared keys.

    elmt2keys(elmt) returns the keys by which an element can be indexed.
    With the default identity function, every element must itself be a
    sequence of keys.  Two elements sharing a key belong to the same group,
    and so do elements linked through a chain of such elements.

    Returns a list of groups, each a list of original elements; the group
    sizes add up to len(list).
    - Groups are ordered with my_sort(groups, lambda g: -len(g)), presumably
      largest first.
    - Within a group, elements appear in the order the merging below
      produced them, which is not necessarily their order in `list`.
    - With `debug` true, progress is printed for every element.
    """
    # Algorithm: a two-level hash table
    #   key -> element -> group
    #
    # For every new element we take its keys, find the elements already
    # indexed by those keys (first mapping, key2elmt), and then the groups
    # those elements belong to (second mapping, elmt2group).
    #
    #   - no groups:          start a new group
    #   - exactly one group:  extend it with the new element
    #   - several groups:     merge them into one and update the
    #                         element->group entry of every moved element
    #
    # The new element then gets its own element->group entry, and every key
    # not seen before is made to point to it.  Keys already seen keep
    # pointing at their earlier element.  That is fine, because the earlier
    # element maps to the same (possibly merged) group.
    
    key2elmt = {} # maps from every key to some element that is indexed by it 
    elmt2group = [] # maps from an element to the group that contains it
    groups = [] # the list of groups

    # only call the function once
    i2keys = map(elmt2keys, list)

    for i in range(0,len(list)):

        # the keys by which that element is indexed
        keys = i2keys[i]

        # other elements that contain such keys
        # note: strict=0, since keys may be unseen before
        other_elmts = mget(key2elmt, keys, strict=0)

        # their groups
        # note: strict, because every element seen so far has a group.
        # unique() is defined elsewhere; assumed to drop duplicate group
        # indices (the branches below rely on that) -- verify
        groups_merged = unique(mget(elmt2group, other_elmts,strict=1))

        if debug: print "Element %s (%s) has %s keys: %s, that join it with %s elements in %s groups"%(
            i, list[i], len(keys), keys, len(other_elmts), len(groups_merged))
        
        if len(groups_merged) == 0:
            # no elements contained any of the keys (hence no groups)
            # hence i'm starting a new group
            groups.append([i]) # only ever append, so existing group indices stay valid
            dad_i = len(groups)-1 # the index of the group is the length-1, since it's the last one
        elif len(groups_merged)==1:
            # only one group contained elements with common keys
            dad_i = groups_merged[0] # which group am i appending to
            groups[dad_i].append(i)
        else:
            # more than one group contained elements with common keys.
            # i must merge them
            dad_i = groups_merged[0] # pick one as the parent, that will include all others
            dad = groups[dad_i]
            for group_index in groups_merged[1:]:
                # empty every sibling group, giving all its elements to the parent

                group = groups[group_index]

                # 1. first of all, the elmt2group will now point to the parent directly
                mset(elmt2group, group, [dad_i]*len(group)) 
                
                # 2. then extend the parent with all the elements of each other group in the set to be merged
                dad.extend(group)

                # 3. finally empty each of the other groups
                groups[group_index] = None # None rather than an empty list, so an accidental append fails fast

            # 4. and then add our latest element into the group merge
            dad.append(i)

        # now, reset all the keys that index the current element to point to the current element
        # mset(key2elmt, keys, [i]*len(keys))
        # note:  only keys never seen before are (re)set; see the note at the top
        for key in keys:
            if not key2elmt.has_key(key): key2elmt[key] = i

        # finally add another entry in our elmt2group table, that points to the parent group
        elmt2group.append(dad_i) # must be append (not insert): the table is indexed by i

        if debug: pp(map(lambda g,list=list: mget(list,g), filter(None,groups)),1)

    # throw away all the deleted groups, and sort the remaining ones by length
    groups = filter(None, groups)
    groups = my_sort(groups, lambda g: -len(g))

    # instead of lists of indices, now turn the groups into actual sets of
    # elements of the original list
    groups = map(lambda g,list=list: mget(list,g), groups)

    # sanity check: every element ended up in exactly one group
    assert(sum(map(len,groups))==len(list))

    return groups
    

#def group_single(lst, name2key = lambda name: name):
#    distinct_names = []
#    for elmt in lst:
#        name = name2key(elmt)
#        if not name in distinct_names:
#            distinct_names.append(name)
#    # construct the empty groups
#    groups = {}
#    for name in distinct_names:
#        groups[name] = []
#        
#    for elmt in lst:
#        groups[name2key(elmt)].append(elmt)
#    return groups

def group_pairs(pairs, name2group = lambda name: name):
    """Group the values of (x1, x2, value) triples by the groups of x1/x2.

    Input:  pairs, a list of (x1, x2, value); name2group, a mapping
            x -> group(x).
    Output: d[group1][group2] = list of all values of triples whose ends
            fall into group1 and group2.  Each unordered group pair is
            stored once, with group1 <= group2.  Every such combination of
            groups seen anywhere is present, possibly with an empty list.
    """
    # every distinct group name, in first-seen order, then sorted
    seen = []
    for pair in pairs:
        first, second = [name2group(x) for x in pair[:2]]
        for name in (first, second):
            if name not in seen:
                seen.append(name)
    seen.sort()

    # empty buckets for every ordered combination outer <= inner
    groups = {}
    for i in range(len(seen)):
        groups[seen[i]] = dict([(inner, []) for inner in seen[i:]])

    # file each value under its (smaller, larger) group pair
    for pair in pairs:
        first, second = [name2group(x) for x in pair[:2]]
        value = pair[2]
        groups[min(first, second)][max(first, second)].append(value)
    return groups

############################################################
#
#  PRINTING dictionaries
#
############################################################

def print_dictionary(dic, keys, lengths, separators, options):
    """Write one row of `dic` to standard output with write_dictionary.

    `options` is forwarded as write_dictionary's per-key formatting
    functions."""
    out = sys.stdout
    write_dictionary(out, dic, keys, lengths, separators, options)
    
def write_dictionary(f, dic, keys, lengths, separators, functions = {}):
    """Write one row of `dic` to the file-like object `f` as fixed-width cells.

    ex: write_dictionary(sys.stdout, {'a': 7, 'b': 3},
                         ['a','b'], [3,-7], [' | ', '\\n'],
                         {'a': display_bignum})

    For each column i:
    - the cell content is dic[keys[i]], passed through functions[keys[i]]
      when such a function is given;
    - it is formatted as '%<lengths[i]>s' (a negative width left-justifies);
    - separators[i] is written after it.

    Keys missing from `dic` produce a blank cell.  As with the map(None, ...)
    idiom this mirrors, the shorter of keys/lengths/separators is padded
    with None.  `functions` is only read, never modified.
    """
    ncols = max(len(keys), len(lengths), len(separators))
    keys = list(keys) + [None] * (ncols - len(keys))
    lengths = list(lengths) + [None] * (ncols - len(lengths))
    separators = list(separators) + [None] * (ncols - len(separators))
    for i in range(ncols):
        key, length, separator = keys[i], lengths[i], separators[i]
        format_str = '%%%ds' % length
        content = ''  # a missing key leaves the cell blank
        if key in dic:
            content = dic[key]
            if key in functions:
                content = functions[key](content)
        # 1-tuple so that tuple contents are formatted rather than unpacked
        # as format arguments (which raised TypeError)
        f.write(format_str % (content,))
        f.write(separator)

def print_dictionary_title(titles, lengths, separators):
    """Write a framed header row to standard output with
    write_dictionary_title."""
    out = sys.stdout
    write_dictionary_title(out, titles, lengths, separators)

def write_dictionary_title(f, titles, lengths, separators):
    """Write a header row framed by two rules to the file-like object `f`.

    Each title is formatted with '%<length>s' and followed by its separator.
    The rule has '+' under every '|' of the header and '-' elsewhere.  It is
    built from the header minus its last character, so the last separator
    is expected to be '\\n'.  As with map(None, ...), the shorter of
    titles/lengths/separators is padded with None.
    """
    width = max(len(titles), len(lengths), len(separators))
    titles = list(titles) + [None] * (width - len(titles))
    lengths = list(lengths) + [None] * (width - len(lengths))
    separators = list(separators) + [None] * (width - len(separators))

    cells = []
    for i in range(width):
        cells.append(('%' + repr(lengths[i]) + 's') % titles[i])
        cells.append(separators[i])
    header = ''.join(cells)

    rule_chars = []
    for ch in header[:-1]:
        if ch == '|':
            rule_chars.append('+')
        else:
            rule_chars.append('-')
    rule = ''.join(rule_chars) + '\n'

    for line in (rule, header, rule):
        f.write(line)

# category 9
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######   DICTIONARIES AND LISTS                 ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

# switching between lists and dictionaries

def list2dic(list, func, rep_ok=None, check_unambiguous=0):
    """Transform a list of elements into a dictionary keyed by func(elmt).

    list:              the elements to index.
    func:              elmt -> key.
    rep_ok:            optional (key, elmt) -> bool predicate, used to pick
                       one element when several share a key.
    check_unambiguous: when true, every element must pass rep_ok before it
                       is accepted, even if its key is new (rep_ok is then
                       required).

    Key collisions:
    - exactly one of the two colliding elements passes rep_ok: that one is
      kept and the other goes to `rejected`;
    - neither passes: both are rejected, the key is removed and marked
      ambiguous; a later element for that key is accepted only if it
      passes rep_ok;
    - both pass, or no rep_ok was given: an exception is raised.

    Prints a summary of the rejected and ambiguous keys, then returns the
    dictionary.

    NOTE(review): the raises below use string exceptions; on Python >= 2.6
    these surface as TypeError without the message.
    """
    
    dic,rejected,ambiguous = {},[],{}
    keys_elmts = map(lambda elmt,f=func: (f(elmt),elmt), list)
    for key,elmt in keys_elmts:
        # check if the key already exists
        if not dic.has_key(key): # it's a new key, never seen before
            if check_unambiguous: # every new element must pass rep_ok
                if not rep_ok: raise "Unspecified function rep_ok(key,elmt)"
                if rep_ok(key, elmt): # does it pass the rep_ok
                    dic[key] = elmt # then set the dictionary
                else:
                    rejected.append((key,elmt)) # skip it
            else:
                if not ambiguous.has_key(key): 
                    dic[key] = elmt # always set an unambiguous key
                else: # earlier elements with this key all failed rep_ok
                    # rep_ok is non-None here: keys only become ambiguous in
                    # the collision branch below, which requires rep_ok
                    if rep_ok(key, elmt): # this one passes the rep_ok test
                        dic[key] = elmt
                        del(ambiguous[key]) # the key is no longer ambiguous
                    # NOTE(review): an element failing rep_ok here is dropped
                    # without being added to `rejected`
        else:  # the key already exists
            if not rep_ok: # no discriminating function exists
                raise "Two entries share a key",key
            else:
                
                old_good = rep_ok(key, dic[key])
                new_good = rep_ok(key, elmt)

                if old_good and not new_good:
                    rejected.append((key,elmt))
                elif new_good and not old_good:
                    rejected.append((key,dic[key]))
                    dic[key] = elmt
                elif not old_good and not new_good:
                    # not only reject current, but also notice that
                    # you had accepted an ambiguous one
                    rejected.append((key,elmt))
                    rejected.append((key,dic[key]))
                    del(dic[key])
                    ambiguous[key] = None # mark key as one to check always
                else: 
                    # both are good
                    raise "rep_ok passes for both %s and %s for %s.  No discrimination possible."%(
                        dic[key], elmt, key), rep_ok

    # report how many keys were rejected / remain ambiguous
    # NOTE(review): '%s'%r[0] and '%s'%a fail for tuple keys
    if rejected:
        print '%s keys were rejected'%len(rejected)
        rejected.sort()
        print string.join(map(lambda r: '%s'%r[0],rejected),', ')
    if ambiguous:
        print '%s ambiguous'%len(ambiguous)
        print string.join(map(lambda a: '%s'%a,my_sort(ambiguous.keys())),', ')
    return dic

def diclist2dicdic(diclist, key, rep_ok=None, check_unambiguous=0):
    """Index a list of dictionaries by their value for `key`.

    Thin wrapper around list2dic; see it for rep_ok and check_unambiguous."""
    def key_of(dic, key=key):
        return dic[key]
    return list2dic(diclist, key_of, rep_ok, check_unambiguous)

def list2dic_i(list):
    """Turn a list x into a dictionary d with d[i] = x[i] for every index.

    Useful when elements must be erasable without reindexing the others
    (deleting d[i] leaves every other key untouched).
    """
    return dict([(i, list[i]) for i in range(len(list))])

def dic2list_i(dic):
    """Return the values of `dic` ordered by their sorted keys (the inverse
    of list2dic_i when the keys are 0..n-1)."""
    return [dic[k] for k in sorted(dic.keys())]

def mergedics(diclist):
    """Merge dictionaries with disjoint key spaces into one new dictionary.

    Useful, for example, for combining ORFs and INTERs: they have similar
    representations but different key namespaces.  Raises ValueError when
    two dictionaries share a key.
    """
    superdic = {}
    for dic in diclist:
        for key, value in dic.items():
            if key in superdic:
                # A real exception class: a string exception is itself a
                # TypeError on Python >= 2.6 and loses this message.
                raise ValueError("Two dictionaries share a key: %s" % (key,))
            superdic[key] = value
    return superdic

def filter_diclist(diclist, key, value):
    """Return the dictionaries whose d[key] equals `value`."""
    return [d for d in diclist if d[key] == value]

def filter_diclist_func(diclist, key, func):
    """Return the dictionaries for which func(d[key]) is true."""
    return [d for d in diclist if func(d[key])]

def filter_diclist_not(diclist, key, value):
    """Return the dictionaries whose d[key] differs from `value`."""
    return [d for d in diclist if d[key] != value]

def filter_diclist_gt(diclist, key, value):
    """Return the dictionaries with d[key] > value."""
    return [d for d in diclist if d[key] > value]

def filter_diclist_gte(diclist, key, value):
    """Return the dictionaries with d[key] >= value."""
    return [d for d in diclist if d[key] >= value]

def filter_diclist_lt(diclist, key, value):
    """Return the dictionaries with d[key] < value."""
    return [d for d in diclist if d[key] < value]

def filter_diclist_lte(diclist, key, value):
    """Return the dictionaries with d[key] <= value."""
    return [d for d in diclist if d[key] <= value]

def filter_diclist_within(diclist, key, min, max):
    """Return the dictionaries with min <= d[key] < max (half-open range)."""
    return [d for d in diclist if min <= d[key] < max]

# category 10
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       HIERARCHY OPERATIONS               ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

def flatten_hierarchy(hierarchy):
    """Text-based flattening of a hierarchy.

    Works on repr(hierarchy): mreplace (defined elsewhere) presumably turns
    each of '()[]' into ',', and the text is then split on whitespace.  The
    result is a list of string tokens from the repr (they can still carry
    commas and quotes), not the original leaves.

    NOTE: this definition is shadowed by the flatten_hierarchy defined right
    after it, so it is never the one that gets called.
    """
    text = mreplace(repr(hierarchy), '()[]', ',,,,')
    return [token for token in text.split() if token]

def flatten_hierarchy(hierarchy):
    """Return the leaves of a nested list/tuple hierarchy, depth first.

    Nested lists and tuples (exact types) are expanded in place, so the
    leaves come out in reading order, e.g. [1, [2, (3, 4)], 5] ->
    [1, 2, 3, 4, 5].  Any other value is a leaf.  The top-level
    `hierarchy` itself is always iterated.

    The previous implementation appended nested children to the end of its
    work queue.  That produced breadth-first order ([1, 5, 2, 3, 4] above),
    contradicting its own "depth first" comment.  It also copied the queue
    on every step (quadratic).
    """
    leaves = []
    # stack of pending items; the next item to visit is on top (the end)
    pending = [item for item in hierarchy]
    pending.reverse()
    while pending:
        item = pending.pop()
        if type(item) in (list, tuple):
            children = [child for child in item]
            children.reverse()  # so the first child is popped first
            pending.extend(children)
        else:
            leaves.append(item)
    return leaves

def map_on_hierarchy_safe(hierarchy,func):
    """Like map_on_hierarchy, but also accepts a bare value.

    A value that is not a list or tuple has func applied to it directly."""
    if type(hierarchy) not in (list, tuple):
        return func(hierarchy)
    return map_on_hierarchy(hierarchy, func)

def map_on_hierarchy(hierarchy,func):
    """Apply func to every leaf of a nested list/tuple hierarchy.

    The nesting is preserved: nested lists/tuples (exact types) are rebuilt
    recursively, depth first.  The outermost container comes back as a
    tuple if `hierarchy` is a tuple, and as a list otherwise.  An empty or
    false `hierarchy` gives () for tuples and [] for anything else.
    """
    if not hierarchy:
        if type(hierarchy) is tuple:
            return ()
        return []

    mapped = []
    for element in hierarchy:
        if type(element) in (list, tuple):
            mapped.append(map_on_hierarchy(element, func))
        else:
            mapped.append(func(element))

    if type(hierarchy) is tuple:
        return tuple(mapped)
    return mapped

# def map_on_hierarchy_breadth(hierarchy, func, depth=0):
#     # assume it's always called on a list or a tuple
#     if not hierarchy: return []
# 
#     results = [None]*len(hierarchy)
#     elmt_types = map(type, element)
# 
#     lists  = map(lambda t: t==types.ListType, elmt_types)
#     tuples = map(lambda t: t==types.TupleType, elmt_types)
#     recurse = map(operator.or_, lists, tuples)
#     
#     do_first = find_all(recurse,0)
#     
#     for i in do_first:
#         results[i] = func(hierarchy[i])
# 
#         if elmt_type == types.ListType:
#             res = map_on_hierarchy(element,func)
#         elif elmt_type == types.TupleType:
#             res = tuple(map_on_hierarchy(element,func))
#         else:
#             res = func(element)
#         results.append(res)
# 
#     return results

# section
############################################################
# evaluating a function on a hierarchy one depth level at a time (experimental)

def map_on_hierarchy_breadth(hierarchy, func):
    """Experimental: evaluate func one depth level at a time.

    For each depth from 0 to the hierarchy's maximum depth, prints the
    result of hierarchy_evaluate_at_level.  Returns [] for an empty
    hierarchy and None otherwise; the per-level results are only printed.
    """
    if not hierarchy:
        return []

    deepest = get_hierarchy_depth(hierarchy)
    depth = 0
    while depth <= deepest:
        level_results = hierarchy_evaluate_at_level(hierarchy, func, depth)
        print("Results at depth=%s are: %s" % (depth, level_results))
        depth += 1

def get_hierarchy_depth(hierarchy):
    """Return the nesting depth of a list/tuple hierarchy.

    A flat sequence has depth 0, and every level of nested lists/tuples
    (exact types) adds 1.  An empty sequence has depth 0.  Previously
    max() of an empty list raised ValueError, so hierarchies such as
    [[], 1] could not be measured.  A debug line is printed for every
    (sub)hierarchy measured.
    """
    depth = 0
    for elmt in hierarchy:
        if type(elmt) in (list, tuple):
            depth = max(depth, get_hierarchy_depth(elmt) + 1)
    print("Depth of %s is %s" % (hierarchy, depth))
    return depth

def hierarchy_evaluate_at_level(hierarchy, func, depth, cur_depth=0):
    """Apply func only to the leaves that sit exactly at nesting level `depth`.

    The nested structure is mirrored in the result, with placeholders:
    - 'SKIP' marks a leaf at a shallower level;
    - 'WAIT' marks a list/tuple found at level `depth`, whose contents are
      deeper than requested.
    """
    assert cur_depth <= depth
    out = []
    for elmt in hierarchy:
        nested = type(elmt) in (list, tuple)
        if nested and cur_depth < depth:
            out.append(hierarchy_evaluate_at_level(elmt, func, depth, cur_depth + 1))
        elif nested:
            out.append('WAIT')
        elif cur_depth == depth:
            out.append(func(elmt))
        else:
            out.append('SKIP')
    return out
                
# section
############################################################
# evaluating a function on a hierarchy from the inside out

def hierarchy_evaluate_joins(hierarchy, func, followTuples=1, followLists=1):
    """Evaluate func bottom-up over a hierarchy.

    A container is "followed" if it is a tuple with `followTuples` set or a
    list with `followLists` set.
    - Each followed container is rebuilt from its elements, with nested
      followed containers replaced by their own results, and the rebuilt
      container (same list/tuple type) is passed to func.
    - Other elements are kept as they are.
    - If `hierarchy` itself is not followed, func(hierarchy) is returned.
    """
    def followed(obj, followTuples=followTuples, followLists=followLists):
        kind = type(obj)
        return (kind is tuple and followTuples) or (kind is list and followLists)

    if not followed(hierarchy):
        return func(hierarchy)

    joined = []
    for elmt in hierarchy:
        if followed(elmt):
            joined.append(hierarchy_evaluate_joins(elmt, func, followTuples, followLists))
        else:
            joined.append(elmt)

    if type(hierarchy) is tuple:
        joined = tuple(joined)
    return func(joined)
    
# category 11
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       INTERVAL OPERATIONS                ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

def alignments_overlap(bn,tx):
    """True when alignments `bn` and `tx` lie on the same chromosome and
    their inclusive [start, end] ranges overlap."""
    if tx['chromosome'] != bn['chromosome']:
        return False
    return tx['start'] <= bn['end'] and bn['start'] <= tx['end']

def orthologs_in_interval(orthologs, interval):
    """Return the orthologs whose alignment overlaps `interval`.

    An ortholog matches when it is on the same chromosome and the inclusive
    [start, end] ranges intersect, the same test as alignments_overlap."""
    def overlaps(ortholog, iv=interval):
        return (iv['chromosome'] == ortholog['chromosome'] and
                iv['start'] <= ortholog['end'] and
                ortholog['start'] <= iv['end'])
    return filter(overlaps, orthologs)

def intervals_intersect(a,b,start='start',end='end',offset=0):
    """True when the inclusive intervals `a` and `b` overlap or touch.

    With a positive `offset`, intervals up to `offset` apart also count as
    intersecting.  `start`/`end` name the fields (or indices) holding the
    bounds."""
    if a[start] > b[end] + offset:
        return False
    return not (b[start] > a[end] + offset)

def items2intervals(items):
    """Turn (start, end, label) items into interval dictionaries with the
    keys 'start', 'end' and 'label'."""
    return [{'start': item[0], 'end': item[1], 'label': item[2]}
            for item in items]

def join_intervals(intervals,start='start',end='end'):
    """Merge overlapping intervals into a sorted list of disjoint ones.

    `intervals` is a list of records indexable by `start` and `end`, e.g.
        [{'start': 1, 'end': 2},
         {'start': 3, 'end': 5},
         {'start': 4, 'end': 7}]
    Tuples or lists are also accepted when start=0 and end=1.

    Intervals are merged when they overlap or touch; the test is inclusive,
    as in intervals_intersect.  The result is sorted by start and uses the
    input's representation: dicts holding only the start/end keys, tuples,
    or lists.  An empty input is returned as is.  The input list is not
    modified.
    """
    if not intervals: return intervals

    # Sort a copy with a key function.  The former in-place cmp sort
    # returned a[start]-b[start], which is not an int for float coordinates
    # (TypeError in Python 2), and it also reordered the caller's list.
    ordered = sorted(intervals, key=lambda iv, start=start: iv[start])

    non_overlapping = []
    current = {start: ordered[0][start],
               end: ordered[0][end]}

    for nxt in ordered[1:]:
        # same inclusive test as intervals_intersect(current, nxt)
        if not (current[start] > nxt[end] or nxt[start] > current[end]):
            # overlapping the current run: extend its end
            current[end] = max(current[end], nxt[end])
        else:
            non_overlapping.append(current)
            current = {start: nxt[start],
                       end: nxt[end]}

    non_overlapping.append(current)

    # convert back to the caller's representation
    rep_type = type(ordered[0])
    if rep_type is tuple:
        assert(start==0 and end==1)
        non_overlapping = [(iv[start], iv[end]) for iv in non_overlapping]
    elif rep_type is list:
        assert(start==0 and end==1)
        non_overlapping = [[iv[start], iv[end]] for iv in non_overlapping]

    return non_overlapping

def group_overlapping(intervals, start='start', end='end', offset=0): 
    """Cluster intervals into groups of (transitively) overlapping ones.

    The intervals are scanned by increasing start.  Each one joins the
    current cluster when it intersects the cluster's span: the inclusive
    test of intervals_intersect, where intervals up to `offset` apart also
    count.

    Returns a list of clusters in order of increasing start.  Each cluster
    is a list of the original interval objects, also by increasing start.
    The input list is not modified.
    """
    if not intervals: return []

    # Sort a copy with a key function.  The former in-place cmp sort
    # returned a[start]-b[start], which is not an int for float coordinates
    # (TypeError in Python 2), and it also reordered the caller's list.
    ordered = sorted(intervals, key=lambda iv, start=start: iv[start])

    clusters = []
    cluster = [ordered[0]]
    span_start, span_end = ordered[0][start], ordered[0][end]

    for nxt in ordered[1:]:
        # same test as intervals_intersect(span, nxt, start, end, offset)
        if not (span_start > nxt[end] + offset or nxt[start] > span_end + offset):
            span_end = max(span_end, nxt[end])
            cluster.append(nxt)
        else:
            clusters.append(cluster)
            cluster = [nxt]
            span_start, span_end = nxt[start], nxt[end]

    clusters.append(cluster)
    return clusters

#def interval_union(as,bs):
#    "a,b are lists of intervals containing {'start','end'}"
#    news = []
#    all = as[:]
#    all.extend(bs)
#    all.sort(lambda x,y: cmp(x['start'],y['start']))
#    i = 0
#    junctions_made = 1
#    while junctions_made:
#        junctions_made = 0
#        while i in range(0,len(all)-1):
#            print_intervals([all])
#            s = all[i]
#            t = all[i+1]
#            if s['end'] <= t['start']:
#                new = {'start': s['start'], 'end': max(t['end'],s['end'])}
#                del(all[i])
#                del(all[i+1])
#                all.insert(i,new)
#                junctions_made = 1
#            else:
#                i = i+1
#    return all
#
#def test_interval_union():
#    a = [{'start': 2, 'end': 4},
#         {'start': 5, 'end': 6}]
#    b = [{'start': 1.5, 'end': 2.5},
#         {'start': 3, 'end': 3.5},
#         {'start': 4.5, 'end': 7}]
#    print_intervals([a,b])
#    print interval_union(a,b)

def interval_union(interval_list, start='start', end='end'):
    """Merge intervals into the disjoint stretches they cover.

    interval_list: list of dicts holding `start` and `end` coordinates.  The
                   key names default to 'start' and 'end', like the other
                   interval helpers in this module.

    Returns a list of {start: ..., end: ...} dicts, ordered by coordinate,
    one per maximal covered stretch.  Intervals that merely touch (one ends
    exactly where another starts) are merged.  The input is not modified.
    """
    # Sweep events: +1 at every start, -1 at every end.  The middle field
    # breaks ties explicitly: at an equal coordinate, openings (0) are
    # processed before closings (1), so touching intervals are merged.
    events = []
    for interval in interval_list:
        events.append((interval[start], 0, +1))
        events.append((interval[end], 1, -1))
    events.sort()

    depth = 0
    openings, closings = [], []
    for coord, tiebreak, delta in events:
        if depth == 0:
            openings.append(coord)      # coverage begins here
        depth = depth + delta
        if depth == 0:
            closings.append(coord)      # coverage ends here

    result = []
    for lo, hi in zip(openings, closings):
        result.append({start: lo, end: hi})
    return result

def test_interval_cut():
    """Visual check of interval_cut().

    Prints two labelled tracks, then the regions produced by interval_cut,
    then the label set of every region ('+'-joined, tab-separated).
    """
    interval_dic = {'a': [{'start': 20, 'end': 40},
                          {'start': 50, 'end': 60}],
                    'b': [{'start': 15, 'end': 25},
                          {'start': 30, 'end': 35},
                          {'start': 45, 'end': 80}]}
    print_intervals(interval_dic.values(),'start','end')
    cut = interval_cut(interval_dic,'start','end','set')
    print_intervals([cut],'start','end')
    print(string.join(map(lambda i: string.join(i['set'],'+'), cut),'\t'))

def interval_subtract(intervals1, intervals2, start='start', end='end'):
    """Return the parts of intervals1 that no interval of intervals2 covers.

    Both arguments are lists of interval dicts keyed by `start`/`end`.  The
    work is delegated to interval_cut() with the labels 'keep' and 'skip'.
    Regions labelled only 'keep' are kept, and their 'set' entry is removed
    with cdel().  Empty regions (end == start - 1) are then dropped.
    """
    regions = interval_cut({'keep': intervals1, 'skip': intervals2},
                           start=start, end=end)

    kept = [region for region in regions
            if 'keep' in region['set'] and 'skip' not in region['set']]
    for region in kept:
        assert(region['set'] == ['keep'])
    cdel(kept, 'set')

    # empty regions arise when two boundary events share a coordinate
    return [region for region in kept
            if region[end] - region[start] + 1 != 0]


def interval_cut(interval_dic,start='start',end='end',set='set'):
    """Cut labelled interval tracks into elementary regions.

    interval_dic maps a label to a list of interval dicts holding `start`
    and `end` coordinates (see test_interval_cut for an example).  The
    result is a flat list of region dicts with keys `start`, `end` and
    `set`.  `set` lists the labels whose intervals cover the region, in
    the order in which they became active.

    Coordinate convention:
    - If every coordinate is integral (int(c) == c), intervals are closed
      integer ranges.  An interval [s, e] stops covering at e+1, and each
      region ends one before the next region starts.
    - Otherwise a region ends exactly where the next one starts.

    One region is emitted per boundary event.  Events that share a
    coordinate therefore yield degenerate regions: end == start - 1 for
    integer coordinates, end == start otherwise.  interval_subtract filters
    these out.

    Intervals carrying the same label may overlap or touch.  The label
    stays active until the last of them has closed.
    """
    # Part 1. Flatten the dictionary into (label, start, end) triples.
    labelled = []
    for key, intervals in interval_dic.items():
        for interval in intervals:
            labelled.append((key, interval[start], interval[end]))

    # Part 2. Integer coordinates are closed ranges, hence the +1 on ends.
    coords = []
    for key, lo, hi in labelled:
        coords.append(lo)
        coords.append(hi)
    if [int(c) for c in coords] == coords: adjustment = 1
    else: adjustment = 0

    # Part 3. One event per boundary: (coordinate, sequence, operation, label).
    # The sequence number makes the ordering total and deterministic.  At an
    # equal coordinate events keep their creation order, so every opening
    # is handled before any closing.
    events = []
    for key, lo, hi in labelled:
        events.append((lo, len(events), 'add', key))
    for key, lo, hi in labelled:
        events.append((hi+adjustment, len(events), 'sub', key))
    events.sort()

    # Part 4. Sweep the events, counting the open intervals of each label.
    # With membership alone, a label would be dropped at the first closing
    # of two overlapping or adjacent same-label intervals.
    open_count = {}
    active = []
    all_sets = []
    for coord, seq, operation, key in events:
        if operation == 'add':
            open_count[key] = open_count.get(key, 0) + 1
        else:
            open_count[key] = open_count.get(key, 0) - 1
        if open_count[key] > 0:
            if key not in active: active.append(key)
        elif key in active:
            active.remove(key)
        all_sets.append({start: coord, set: active[:]})

    # Part 5. Each region ends where the next one begins.  The region opened
    # by the final event only served to close the previous one: drop it.
    for i in range(len(all_sets)-1):
        all_sets[i][end] = all_sets[i+1][start]-adjustment
    if all_sets: del all_sets[-1]

    return all_sets

def test_interval_union():
    """Visual check: print two interval tracks and their union."""
    track_a = [{'start': 2, 'end': 4},
               {'start': 5, 'end': 6}]
    track_b = [{'start': 1.5, 'end': 2.5},
               {'start': 3, 'end': 3.5},
               {'start': 4.5, 'end': 7}]
    print_intervals([track_a, track_b, interval_union(track_a + track_b)])

def print_intervals(intervals,start='start',end='end'):
    """Draw several tracks of intervals on one shared text scale.

    intervals: list of tracks; each track is a list of interval dicts with
    `start`/`end` coordinates.  The scale runs from the smallest start to
    the largest end over all tracks.  Each interval is labelled with the
    repr of its (start, end) pair.  Drawing is delegated to print_scale().
    """
    lowest = min([min([iv[start] for iv in track]) for track in intervals])
    highest = max([max([iv[end] for iv in track]) for track in intervals])
    scale = {'min': lowest,
             'max': highest,
             'tick': 10,
             'width': 120,
             'print_scale': 1}
    rows = [[(iv[start], iv[end], repr((iv[start], iv[end]))) for iv in track]
            for track in intervals]
    print_scale(scale, rows)

def common_intervals(intervals_list,start='start',end='end'):
    """Return the extended intervals common to every list in intervals_list.

    Works pairwise.  For each further list, keep the intervals of the
    running result, and of that list, that intersect at least one interval
    of the other side.  The survivors are finally merged with
    join_intervals().  Returns [] for an empty intervals_list.
    """
    if not intervals_list: return []
    shared = intervals_list[0]
    for other in intervals_list[1:]:
        hits_shared, hits_other = [], []
        for i in range(0, len(shared)):
            for j in range(0, len(other)):
                if intervals_intersect(shared[i], other[j], start=start, end=end):
                    hits_shared.append(i)
                    hits_other.append(j)
        hits_shared = unique(hits_shared)
        hits_other = unique(hits_other)
        shared = mget(shared, hits_shared) + mget(other, hits_other)

    return join_intervals(shared, start=start, end=end)
    

# category 12
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       ISLAND OPERATIONS                  ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

def count_islands(list):
    """Find the runs ("islands") of identical consecutive elements.

    Accepts a list or a string.  Returns a dict mapping each distinct
    element to its runs, in order of appearance.  Each run is a
    (start, end) pair with a 1-based start and an inclusive end, so
    list[start-1:end] holds only that element.  Empty input gives {}.

    ex: count_islands('aabbba') -> {'a': [(1, 2), (6, 6)], 'b': [(3, 5)]}
    """
    if not list: return {}

    islands = {}
    run_start = 0
    size = len(list)
    for pos in range(1, size + 1):
        # keep extending the current run while the element repeats
        if pos < size and list[run_start] == list[pos]:
            continue
        islands.setdefault(list[run_start], []).append((run_start + 1, pos))
        run_start = pos
    return islands

# category 13
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       SET OPERATIONS                     ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

def set_compare(set1, set2):
    """Compare two collections as sets (duplicates removed with unique()).

    Returns a dict mapping 'Set1', 'Set2', 'Inter', 'Union', 'set1_only'
    and 'set2_only' to (size, elements) pairs.
    """
    first = unique(set1)
    second = unique(set2)
    report = {'Set1': (len(first), first),
              'Set2': (len(second), second)}
    for label, members in [('Union', set_union(first, second)),
                           ('Inter', set_intersect(first, second)),
                           ('set1_only', set_subtract(first, second)),
                           ('set2_only', set_subtract(second, first))]:
        report[label] = (len(members), members)
    return report

def set_union(set1, set2):
    """Return a copy of set1 followed by the elements of set2 absent from set1."""
    merged = set1[:]
    missing = set_subtract(set2, set1)
    merged.extend(missing)
    return merged

def set_union_all(sets):
    """Union of all the lists in `sets`: flatten them, then unique()."""
    everything = flatten(sets)
    return unique(everything)

def set_intersect_all(sets):
    """Intersection of every list in `sets` ([] when `sets` is empty).

    Each step calls set_intersect(next_list, accumulated), so the result
    follows the element order of the last list.
    """
    if not sets: return []
    common = sets[0]
    position = 1
    while position < len(sets):
        common = set_intersect(sets[position], common)
        position = position + 1
    return common

def venn_diagram(sets):
    """Split the elements of several lists into Venn-diagram regions.

    For every non-empty combination of list indices (as produced by
    superset()), collect the elements that are in all of those lists and
    in none of the others.  Returns a list of (index_subset, elements)
    pairs.  Regions that turn out empty are included.
    """
    sets = [unique(s) for s in sets]
    indices = range(0, len(sets))

    regions = []
    for subset in superset(indices):
        others = set_subtract(indices, subset)
        inside = set_intersect_all(mget(sets, subset))
        outside = set_union_all(mget(sets, others))
        only_here = set_subtract(inside, outside)
        # the empty combination is computed but not reported
        if subset:
            regions.append((subset, only_here))
    return regions



def fast_set_intersect(set1,set2):
    dic1 = {}
    for elmt in set1: dic1[elmt] = None
    return filter(dic1.has_key, set2)

def smart_set_intersect(set1,set2):
    """Multiset intersection.

    Returns the elements of set2, in set2 order, each matched against a
    distinct, not-yet-used occurrence in set1.

    ex: smart_set_intersect([1,1,2,3], [1,2,2,1,1,4]) -> [1, 2, 1]
    """
    available = {}
    for item in set1:
        available[item] = available.get(item, 0) + 1

    common = []
    for item in set2:
        remaining = available.get(item, 0)
        if remaining == 0:
            continue
        common.append(item)
        if remaining == 1:
            del available[item]
        else:
            available[item] = remaining - 1
    return common

def set_intersect(set1, set2, el2hash=lambda x: x):
    """Elements of set1 (order and duplicates kept) whose el2hash() key
    also occurs among the keys of set2's elements.

    el2hash maps an element to a hashable key; by default the element
    itself is the key.  Returns [] if either input is empty.
    """
    if not set1 or not set2:
        return []
    wanted = {}
    for item in set2:
        wanted[el2hash(item)] = ''
    return [item for item in set1 if el2hash(item) in wanted]

def set_equal(set1, set2):
    """True when both lists hold the same elements with the same multiplicity.

    Returns 0 right away when the lengths differ; otherwise compares the
    my_sort()ed lists.
    """
    if len(set1) == len(set2):
        return my_sort(set1) == my_sort(set2)
    return 0

def set_subtract(set, subset, el2hash=lambda x: x):
    """Elements of `set` (order and duplicates kept) whose el2hash() key does
    not occur among the keys of `subset`'s elements.

    el2hash maps an element to a hashable key; by default the element
    itself is the key.
    """
    excluded = {}
    for item in subset:
        excluded[el2hash(item)] = ''
    return [item for item in set if el2hash(item) not in excluded]

def set_included(set, superset):
    """True when every element of `set` also occurs in `superset`."""
    return not set_subtract(set, superset)

def permutations(set):
    """All orderings of the elements of `set`, each returned as a list.

    The first element is inserted at every position of every permutation
    of the rest, position by position:
    [1,2,3] -> [[1,2,3],[1,3,2],[2,1,3],[3,1,2],[2,3,1],[3,2,1]]
    """
    if len(set) == 0:
        return [[]]
    head = set[0]
    tails = permutations(set[1:])
    return [tail[:slot] + [head] + tail[slot:]
            for slot in range(len(set))
            for tail in tails]

def superset_sorted(set):
    """All subsets of `set`, each keeping the elements in their original order.

    The subsets are built from index subsets ordered with my_sort and the
    key (size, indices).  Presumably that means by size, then by index
    positions.
    """
    index_subsets = my_sort(superset(range(0, len(set))),
                            lambda c: (len(c), c))
    return [mget(set, indices) for indices in index_subsets]

def superset(set):
    # Return every subset of `set`, i.e. its power set (despite the name).
    # Subsets are built by slicing and concatenation, so each subset has the
    # same sequence type as `set` (list, string, tuple) and keeps the
    # original element order.
    #
    # ex: [1,2,3] -> [[],[1],[2],[1,2],[3],[1,3],[2,3],[1,2,3]]
    # (that is the actual output order; the merge below does not sort the
    # whole result by size -- see superset_sorted for a re-sorted listing)

    if len(set) == 0:
        return [set]
    else:
        # subsets that do not contain the last element
        subsuper = superset(set[:-1])
        # the last element as a length-1 slice, so it can be concatenated
        lastel = set[-1:]

        super = []
        i,j = 0,0

        # merge two streams: subsuper[i] (without lastel) and
        # subsuper[j]+lastel (the same subsets with lastel appended)
        next1 = subsuper[i]
        next2 = subsuper[j]+lastel

        # NOTE(review): maxlen is computed but never used
        maxlen = max(map(len,subsuper))
        while 1:
            # take the subset without lastel when it is no longer than, and
            # compares <= to, the candidate with lastel
            if len(next1) <= len(next2) and next1 <= next2: 
                super.append(next1)
                i = i+1
                if i == len(subsuper): break
                next1 = subsuper[i]
            else:
                super.append(next2)
                j = j+1
                if j == len(subsuper): break
                next2 = subsuper[j]+lastel
        # the merge is expected to exhaust the subsets without lastel first;
        # the remaining subsets with lastel are appended below
        assert(i == len(subsuper))
        for j in range(j,len(subsuper)):
            super.append(subsuper[j]+lastel)

        #print 'Super has %s elements'%len(super)
            
    return super

def all_combinations(sets_list):
    """Cartesian product of the sequences in sets_list, as a list of lists.

    ex: [[1,2],'ab'] -> [[1,'a'],[1,'b'],[2,'a'],[2,'b']]
    An empty sets_list yields [[]]; any empty member yields [].
    """
    if not len(sets_list):
        return [[]]
    tails = all_combinations(sets_list[1:])
    return [[head] + tail for head in sets_list[0] for tail in tails]

def equivalence_classes(identity_pairs): 
    """Merge identity pairs into equivalence classes (transitive closure).

    identity_pairs: iterable of 2-element pairs (a, b), each meaning a == b.
    Elements must be hashable (they are used as dictionary keys).

    Returns a sorted list of classes (lists of elements).
    - A class built by merging is sorted internally.
    - A class that came from a single pair and was never merged keeps the
      pair's order, e.g. [('b','a')] -> [['b', 'a']].
    - A pair (a, a) gives the class [a].

    As written this is Python 2 only: it relies on dict.has_key and on
    values() returning a list that can be sorted in place.
    """
    # having a list of pairs such as (a,b), (a,c) which means a==b and a==c,
    # it returns a list of equivalence classes such as [[a,b,c]]
    # equivalence classes are the transitive closure of the join operation
    # of two sets, where you join two sets if they have an element in common
    # ex: ['ab','ac','ad','ef'] -> [['a', 'b', 'c', 'd'], ['e', 'f']]

    sets_of = {} # for each element x, the sets_of[x] are the sets that x belongs to
    sets = {} # all the unique sets ever created
    i = 0
    # Phase 1: one set per pair, each under a fresh id i
    for pair in identity_pairs:
        a,b = pair
        if not sets_of.has_key(a): sets_of[a] = []
        if not sets_of.has_key(b): sets_of[b] = []
        # create a new set with only elements a and b in it
        if a != b: 
            sets[i] = [a,b]
            sets_of[a].append(i)
            sets_of[b].append(i)
        else:
            sets[i] = [a]
            sets_of[a].append(i)
        i = i + 1 # increment i, coz a set id never reappears
        
    #pp({'sets_of': sets_of, 'sets': sets},2)

    # Phase 2: repeat full passes until no element belongs to more than one
    # set.  Whenever an element `a` is in several sets, all of them are
    # replaced by one new set (fresh id i) holding the union of their
    # elements, and every member's sets_of entry is redirected to it.
    some_junction_was_made = 1
    while some_junction_was_made:
        some_junction_was_made = 0
        for a in sets_of.keys():
            # if a belongs to more than one set, join them
            if len(sets_of[a]) > 1:
                some_junction_was_made = 1
                # print "\nElement %s belongs to more than one set %s" % (a, sets_of[a])
                
                # create a new set
                # print "    Creating a new set %s for the junction of %s" % (i, sets_of[a])
                sets[i] = []
                
                # remove 'a' from every set it belongs to
                # print "    Removing %s from every set it belongs to" % a
                for old_set in sets_of[a]:
                    sets[old_set].remove(a)

                # print "  Adding %s to the new set %s" % (a, i)
                sets[i].append(a)
                
                # for every set that 'a' belongs to
                for old_set in sets_of[a]:
                    # first removing a from every set it belonged to
                    # print "  Considering old set %s" % (old_set)
                    # destroy the old set
                    for element in sets[old_set]:
                        # element doesn't belong it old set anymore
                        # print "    Removing %s from old set %s" % (element, old_set)
                        sets_of[element].remove(old_set)
                        # actually add the element to the new set
                        if not element in sets[i]:
                            # print "    Adding %s to new set %s" % (element, i)
                            sets_of[element].append(i)
                            sets[i].append(element)
                        else:
                            # print "    Element %s already belonged to set %s" % (element, i)
                            pass
                    # print "    Deleting old set %s" % old_set
                    del(sets[old_set])
                # a now belongs only to the merged set
                sets_of[a] = [i]
                sets[i].sort()
                i = i+1
                # print "Done with element %s" % a
        #if some_junction_was_made: print "some junction was made, please continue"
    else:
        # while/else: runs once the loop ends normally; nothing left to do
        #print "\nNo more junctions are possible.  I'm done"
        pass
    #pp({'sets_of': sets_of, 'sets': sets},2)
    equivalence_classes = sets.values()
    equivalence_classes.sort()
    return equivalence_classes

def group_into_equivalence_classes(list, are_same = lambda a,b: 0, debug=0):
    """Partition `list` into equivalence classes under are_same().

    Two items share a class when are_same() holds for them, or when they are
    linked through a chain of such pairs.  The transitive closure is
    computed by equivalence_classes().  are_same is called once for every
    unordered pair (i < j).

    Returns a list of classes (lists of items): first the multi-item classes
    from equivalence_classes(), then one singleton class per item that was
    paired with nothing.

    debug: when true, print progress counts.
    """
    n = len(list)
    # are_same is evaluated once per unordered pair: n*(n-1)/2 calls
    if debug: print("Testing %s possible pairings"%(n*(n-1)//2))

    used = {}    # indices that appear in at least one pair
    pairs = []
    for i in range(n):
        for j in range(i+1, n):
            if are_same(list[i], list[j]):
                pairs.append((i,j))
                used[i] = None
                used[j] = None

    if debug: print("Reducing %s actual pairs to equivalence classes"%len(pairs))

    classes = equivalence_classes(pairs)

    if debug: print("Found %s equivalence classes"%len(classes))

    # map each class of indices back to the items themselves
    result = []
    for each_class in classes:
        result.append(mget(list, each_class))

    # items never paired with anything form a class of their own
    for i in range(n):
        if i not in used:
            result.append([list[i]])

    if debug: print("And %s singletons"%(len(result)-len(classes)))

    return result

def group_into_equivalence_classes2(list, are_same = lambda a,b: 0, debug=0):
    """Greedy grouping: each element joins every existing group containing an
    element it is are_same() with; several such groups are merged.

    Progress is printed unconditionally.  `debug` is accepted but unused,
    matching the signature of group_into_equivalence_classes.  After every
    element the groups are re-sorted with my_sort, largest first.  Returns
    the list of groups.
    """
    print("I must group a total of %s elements"%len(list))
    groups = []
    n = 0 # the number of comparisons
    for element1 in list:
        groups_merged = []

        for i in range(0,len(groups)):
            for element2 in groups[i]:
                n = n+1
                if are_same(element1,element2):
                    # one match is enough to tie element1 to group i
                    groups_merged.append(i)
                    break

        if len(groups_merged) == 0:
            # no match: start a new group
            groups.append([element1])
        elif len(groups_merged) == 1:
            # exactly one matching group: join it
            groups[groups_merged[0]].append(element1)
        else:
            print("Merging groups %s"%groups_merged)
            dad = groups[groups_merged[0]]
            for i in groups_merged[1:]:
                dad.extend(groups[i])
                groups[i] = []
            # element1 bridged these groups, so it belongs to the merged group
            dad.append(element1)
        groups = my_sort(filter(None, groups),lambda g: -len(g))
        if len(groups)%100 == 0:
            print("After %s comparisons, %s groups contain %s elements"%(
                n, len(groups), sum(map(len,groups))))
    print("%s comparisons"%n)
    return groups

def group_into_equivalence_classes_hash(list, element2classes = lambda a: [a]):
    """Group elements that share at least one hash key, transitively.

    element2classes(element) returns the list of hash keys under which the
    element is indexed (default: the element itself).  Two elements end up
    in the same group when they share a key, directly or through a chain of
    other elements.  Keys must be hashable.  An element with no keys forms
    a group of its own.

    Returns a list of groups (lists of elements).  Groups are ordered by
    their earliest element, and each group keeps input order.

    ex: group_into_equivalence_classes_hash(['ab','cd','bc','ef'],
                                            lambda w: [c for c in w])
        -> [['ab', 'cd', 'bc'], ['ef']]
    """
    key2group = {}    # hash key -> id of the group that currently owns it
    members = {}      # group id -> indices (into `list`) of its elements
    group_keys = {}   # group id -> keys currently pointing at that group

    for i in range(len(list)):
        keys = element2classes(list[i])

        # ids of the existing groups reachable through this element's keys
        touched = []
        for key in keys:
            if key in key2group:
                gid = key2group[key]
                if gid not in touched:
                    touched.append(gid)

        if touched:
            # merge every touched group into the oldest one
            touched.sort()
            target = touched[0]
            for gid in touched[1:]:
                members[target].extend(members[gid])
                for key in group_keys[gid]:
                    key2group[key] = target
                group_keys[target].extend(group_keys[gid])
                del members[gid]
                del group_keys[gid]
        else:
            # new group, identified by the index of its first element
            target = i
            members[target] = []
            group_keys[target] = []

        members[target].append(i)
        for key in keys:
            if key not in key2group:
                key2group[key] = target
                group_keys[target].append(key)

    group_ids = [gid for gid in members]
    group_ids.sort()
    groups = []
    for gid in group_ids:
        indices = members[gid]
        indices.sort()
        groups.append([list[j] for j in indices])
    return groups


def bipartite2tree(connections, edge2point1=lambda a: a[0], edge2point2=lambda a: a[1]):
    """UNFINISHED stub.

    Computes two groupings of `connections`, uses neither, and returns None.
    edge2point1 / edge2point2 extract the two endpoints of an edge (by
    default a[0] and a[1]); edge2point2 is currently never called.
    """

    # first group according to the first point
    groups = group_single(connections, edge2point1)

    # NOTE(review): presumably meant group_single(connections, edge2point2),
    # i.e. grouping by the second point -- confirm.  As written, no key
    # function is passed.
    groups2 = group_single(connections)
    
    
    

        
            
    
    
    

            
# category 14
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       SIGNIFICANCE TESTS                 ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

def sample_variance(samples):
    """Unbiased sample variance: sum((s - mean)**2) / (len(samples) - 1).

    The mean is computed in floating point, so integer samples do not hit
    Python 2 integer division.  Raises ZeroDivisionError for fewer than two
    samples.
    """
    mean = float(sum(samples)) / len(samples)
    squared_deviations = [(s - mean)**2 for s in samples]
    return sum(squared_deviations) / (len(samples) - 1)

def t_test(samples1, samples2):
    """Two-sample t statistic for raw samples.

    Computes size, mean (avg) and sample_variance of each list and hands
    them to t_test_only().  Returns (t, degrees_of_freedom).
    """
    sizes = len(samples1), len(samples2)
    means = avg(samples1), avg(samples2)
    variances = sample_variance(samples1), sample_variance(samples2)
    return t_test_only(sizes[0], means[0], variances[0],
                       sizes[1], means[1], variances[1])

def t_test_only(n1,mean1,var1,n2,mean2,var2):
    """Two-sample Student's t statistic with pooled variance, from summary
    statistics.

    n1, n2: sample sizes; mean1, mean2: sample means; var1, var2: sample
    variances.  Prints the pooled variance ("s^2 = ...") as a side effect.
    Returns (t, dof) with dof = n1 + n2 - 2.
    """
    dof = n1+n2-2
    # float() so integer inputs do not trigger Python 2 integer division
    var = float((n1-1)*var1 + (n2-1)*var2) / dof
    print("s^2 = %s"%var)

    t = (mean1-mean2) / math.sqrt(var/n1 + var/n2)

    return t,dof

# category 15
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       MATH PRIMITIVES                    ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

def gaussian(x, mu, sigma):
    """
    Evaluate N(mu,sigma) at x.
    where N(mu,sigma) is a gaussian of mean mu and stddev sigma

    i.e. exp(-(x-mu)**2 / (2*sigma**2)) / (sigma*sqrt(2*pi))
    """
    # float() avoids Python 2 integer division in the exponent
    sigma = float(sigma)
    # the normalisation is sigma*sqrt(2*pi), i.e. sqrt(2*pi*sigma**2)
    return (math.exp(-((x-mu)**2)/(2*sigma**2))
            / (sigma*math.sqrt(2*math.pi)))

def make_gaussian(mu, sigma):
    """ usage: 
    N2_3 = make_gaussian(2,3)
    N2_3(4) -> gaussian N(2,3) evaluated at 4

    Returns a function of x computing
    exp(-(x-mu)**2 / (2*sigma**2)) / (sigma*sqrt(2*pi)).
    mu and sigma may also be overridden per call as extra arguments.
    """
    # 2.0 keeps the exponent in floating point for integer inputs; the
    # normalisation is sigma*sqrt(2*pi), i.e. sqrt(2*pi*sigma**2)
    return lambda x,mu=mu,sigma=sigma: (math.exp(-((x-mu)**2)/(2.0*sigma**2))
                                        / (sigma*math.sqrt(2*math.pi)))

def make_adder(n):
    """ usage: 
    Add2 = make_adder(2)
    Add2(3) -> 5

    The returned function also accepts n as an optional second argument,
    which overrides the bound value.
    """
    def adder(x, n=n):
        return x + n
    return adder


def prob2score(prob):
    """Phred-style score of a probability: -10*log10(prob).

    ex: 1/100 -> 20
    Returns -1 when the score is undefined: prob <= 0, or prob is not
    convertible to float.
    """
    try: 
        return -10*float(math.log10(float(prob)))
    except (ValueError, TypeError, OverflowError):
        # log10 of a non-positive number, or a non-numeric prob; narrower
        # than a bare except so KeyboardInterrupt etc. still propagate
        return -1

def prob2stdevs(prob):
    """Not implemented: ignores prob and returns None."""
    return None


# natural log of 2, precomputed for base-2 logarithms: log2(x) = log(x)/loge_2
loge_2 = math.log(2)

def log2(x):
    """Base-2 logarithm of x.

    Change of base: log_a(b) = log_c(b)/log_c(a), so
    log_2(x) = log_e(x)/log_e(2) = log_10(x)/log_10(2).
    log_e(2) is precomputed in the module constant loge_2.
    """
    return math.log(x)/loge_2

#def log_k(x,k):
#    # take the k-th base log
#    log

def p2bits(p):
    """Surprisal of probability p in bits, i.e. -log2(p)."""
    natural = math.log(p)
    return -natural/loge_2

def profile2bits(profile):
    """Information content, in bits, of a profile written in IUB codes.

    Each character contributes a factor .25 * len(IUB_expansion[char]) to a
    probability p, and the result is -log2(p).  IUB_expansion is a
    module-level table; presumably it maps each IUB code to the bases it
    stands for.
    """
    matches = [len(IUB_expansion[code]) for code in profile]
    prob = 1.0
    for count in matches:
        prob = prob*.25*count
    return -math.log(prob)/loge_2

def binomial_likelihood_ratio(ps,k,n):
    """Likelihood ratio L(H1)/L(H0) of seeing k successes in n trials.

    ps[0]: success probability under the null hypothesis H0
    ps[1]: success probability under the hypothesis being tested, H1
    Returns binomial(ps[1],k,n) / binomial(ps[0],k,n).  When the null
    likelihood is 0, prints a warning and returns sys.maxint.
    See binomial_log_likelihood_ratio for the log-space version.
    """
    assert(len(ps)==2)
    likelihoods = [binomial(p,k,n) for p in ps]
    if likelihoods[0]:
        return likelihoods[1] / likelihoods[0]
    print("Warning: likelihood ratio set to sys.maxint.  p(H1)=%s, p(H0)=0"%(likelihoods[1]))
    return sys.maxint

def binomial_log_likelihood_ratio(ps,k,n):
    """log L(H1) - log L(H0) for k successes in n trials.

    ps[0] is the null success probability and ps[1] the tested one.
    """
    alternative = log_binomial(ps[1],k,n)
    null = log_binomial(ps[0],k,n)
    return alternative - null

def poisson_expected(rate):
    """Print a tab-separated table for x = 1..49 with columns
    x, P(X = x) for X ~ Poisson(rate), and 12000000 * P(X = x).

    NOTE(review): the 12000000 scale factor is hard-coded; its meaning is
    not documented here.
    """
    x = 1
    while x < 50:
        prob = poisson(rate,x)
        print("%s\t%s\t%s"%(x,prob,12000000*prob))
        x = x+1

def poisson(rate, x):
    """P(X = x) for X ~ Poisson(rate): exp(-rate) * rate**x / x!."""
    numerator = math.exp(-rate)*(rate**x)
    return numerator/factorial(x)
        
       
    

############################################################
############################################################
############################################################

def smart_binomial_sigmas(ps,ks,ns):
    """Signed significance, in sigmas, of seeing ks[i] successes in ns[i]
    trials when the success probability is ps[i].

    The sign is '+' when k/n >= p (at least as many successes as expected)
    and '-' otherwise.  It is always '-' when n == 0.  The matching tail is
    computed with binomial_tails, converted with ps2sigmas, and negated for
    '-'.

    Probabilities of exactly 0 or 1 are nudged to 0.5/n and (n-0.5)/n so
    the tails stay finite.  The nudging is done on a private copy: the
    caller's `ps` is left untouched.
    """
    assert(len(ps)==len(ks)==len(ns))
    ps = list(ps)   # adjusted locally; never write back into the caller's list
    signs = []
    for i in range(0,len(ks)):
        p,k,n = ps[i],ks[i],ns[i]
        assert(k<=n)
        if n==0:
            # no trials: nothing to nudge (0.5/n would divide by zero)
            signs.append('-')
            continue
        if p==0:
            p=ps[i]=0.5/n
        elif p==1:
            p=ps[i]=(n-0.5)/n
        # do i see more k's than i'd expect
        if p <= float(k)/float(n): signs.append('+')
        else: signs.append('-')

    # now compute the appropriate tail area (right or left, depending on sign)
    tails = binomial_tails(ps,ks,ns,signs)

    # and transform them into sigmas
    sigmas = ps2sigmas(tails)

    # then flip the sigmas depending on the sign
    for i in range(0,len(sigmas)): 
        if signs[i]=='-': sigmas[i] = -sigmas[i]

    return sigmas

def binomial_tails(ps,ks,ns,signs=None,calctail='/home/franklin/nickp/bin/calctail'):
    """Tail probabilities of several binomials, computed by an external program.

    ps, ks, ns: parallel sequences (success probability, observed successes,
                number of trials).
    signs:      optional parallel sequence of '+' / '-' selecting the tail
                for each query; defaults to all '+'.
    calctail:   path of the helper program; defaults to the originally
                hard-coded location.  It is run through quick_system() with
                one "n k p sign" line per query and must print one number
                per line.

    Returns a list of floats, one per query.  Empty input returns [] without
    running the program.
    """
    assert(len(ps)==len(ks)==len(ns))
    if len(ps) == 0: return []
    if not signs: signs = ['+']*len(ps)
    assert(len(signs)==len(ps))
    assert(0<=min(ps)<=max(ps)<=1)
    for i in range(0,len(ks)): assert(ks[i]<=ns[i])
    inlines = map(lambda p,k,n,sign: '%s %s %s %s'%(n,k,p,sign), ps,ks,ns,signs)
    output = quick_system(calctail, string.join(inlines,'\n'))
    # [:-1] drops the piece after the final newline; this assumes the
    # helper's output ends with one
    outlines = map(string.strip,string.split(output,'\n'))[:-1]
    assert(len(inlines)==len(outlines))
    return map(float,outlines)

def ps2sigmas(probabilities, calcz='/home/franklin/nickp/bin/calcz'):
    """Convert tail probabilities to sigmas with an external program.

    The sigmas are presumably normal z-scores -- confirm against the calcz
    helper.
    probabilities: sequence of numbers, sent one per line ('%s'-formatted)
                   to the helper through quick_system().
    calcz:         path of the helper program; defaults to the originally
                   hard-coded location.
    Returns one float per probability.  Empty input returns [] without
    running the program.
    """
    if len(probabilities) == 0: return []
    inlines = string.join(map(lambda s: '%s'%s, probabilities),'\n')
    # [:-1] drops the piece after the final newline; this assumes the
    # helper's output ends with one
    outlines = string.split(quick_system(calcz,inlines),'\n')[:-1]
    sigmas = map(float,map(string.strip,outlines))
    assert(len(sigmas)==len(probabilities))
    return sigmas
    
############################################################
############################################################
############################################################

def binomial_tail(p,k,n,cache=None,debug=0):
    # NOTE(review): this definition is shadowed by the later
    # `def binomial_tail(...)` in this module (same name), so the binding
    # created here is replaced at import time and this body is never
    # reachable under this name.
    #
    # The probability (not the log probability: every term goes through
    # safe_exp) of seeing k or more successes in n trials, given the
    # probability of success is p.
    #
    # cache: optional dict memoizing the individual terms.  Entries are keyed
    #        by (k_iter, n) only -- p is NOT part of the key -- so a cache
    #        must not be shared between calls with different p.
    # debug: when true, count cache lookups and print a summary if the cache
    #        is non-empty.
    if cache == None:
        sum = 0
        for k_iter in range(k,n+1): sum = sum+safe_exp(log_binomial(p,k_iter,n))
        return sum

    sum = 0
    lookups,total = 0,0
    for k_iter in range(k,n+1):
        if debug: total = total+1
        if cache.has_key((k_iter,n)):
            add = cache[(k_iter,n)]
            lookups = lookups+1
        else:
            add = safe_exp(log_binomial(p,k_iter,n))
            cache[(k_iter,n)] = add
        sum = sum+add
    if debug and cache: print "p=%s k=%s n=%s %s / %s binomials were lookups.  Cache now has %s items"%(
        p,k,n,lookups,total,len(cache))
    return sum

def binomial_tail(p,k,n,cache=None,debug=0):
    """Probability (not log probability) of seeing k or more successes in n
    trials, given the probability of success is p.

    When k < n//2 the terms P(X = k..n) are summed directly.  Otherwise the
    complement 1 - P(X = 0..k-1) is used, which needs fewer terms.

    cache: optional dict, kept by the caller across calls, memoizing the
           individual binomial terms under the key (p, k_iter, n).
    debug: when true and the cache is non-empty, print how many terms were
           served from the cache.

    The optional cache/debug arguments keep this definition call-compatible
    with the earlier binomial_tail(p,k,n,cache=None,debug=0) in this module,
    which this definition replaces at import time.
    """
    # Boundaries answered directly: log_binomial would need log(0) at p in
    # {0, 1}, and log_n_choose_k asserts k <= n.
    if k > n: return 0.0
    if k <= 0: return 1.0
    if p == 1: return 1.0
    if p == 0: return 0.0

    if k < n//2:
        # short upper tail: add up P(X = k), ..., P(X = n)
        k_values = range(k,n+1)
        total, sign = 0, 1
    else:
        # long upper tail: subtract P(X = 0), ..., P(X = k-1) from 1
        k_values = range(0,k)
        total, sign = 1.0, -1

    lookups = 0
    for k_iter in k_values:
        if cache is None:
            term = safe_exp(log_binomial(p,k_iter,n))
        elif (p,k_iter,n) in cache:
            term = cache[(p,k_iter,n)]
            lookups = lookups+1
        else:
            term = safe_exp(log_binomial(p,k_iter,n))
            cache[(p,k_iter,n)] = term
        total = total + sign*term

    if debug and cache:
        print("p=%s k=%s n=%s %s / %s binomials were lookups.  Cache now has %s items"%(
            p,k,n,lookups,len(k_values),len(cache)))
    return total


def safe_log(n):
    """Natural log of n that tolerates the extremes.

    Returns -1e400 (-inf) for n == 0 and +1e400 (+inf) when math.log
    overflows for a non-zero n.  Other invalid input (e.g. a negative n)
    still raises ValueError.
    """
    try: return math.log(n)
    except ValueError:
        # math.log(0) raises ValueError ("math domain error"), not OverflowError
        if n==0: return -1e400
        raise
    except OverflowError:
        if n==0: return -1e400
        else: return 1e400

def safe_exp(n):
    """exp(n), returning 1e400 (+inf) instead of raising on overflow.

    The n < 0 branch mirrors the n == 0 handling of safe_log.  math.exp
    underflows to 0.0 without raising, so in practice the branch is only a
    safeguard.
    """
    try:
        result = math.exp(n)
    except OverflowError:
        if n < 0:
            result = 0.0
        else:
            result = 1e400
    return result

def log_binomial(p,k,n):
    """Log probability of exactly k successes in n trials, given the
    probability of success is p.

    Uses the convention 0*log(0) = 0, so p == 0 with k == 0 and p == 1 with
    k == n work.  An impossible outcome (p == 0 with k > 0, or p == 1 with
    k < n) returns -1e400, i.e. log(0) = -inf.  Raises ValueError for p
    outside [0, 1].
    """
    if p < 0 or p > 1:
        raise ValueError("probability out of range: %s" % p)
    result = log_n_choose_k(n,k)
    if k:
        if p == 0: return -1e400
        result = result+math.log(p)*k
    if n-k:
        if p == 1: return -1e400
        result = result+math.log(1-p)*(n-k)
    return result

def binomial(p,k,n):
    """Probability of exactly k successes in n trials, given the probability
    of success is p: C(n,k) * p**k * (1-p)**(n-k).
    """
    ways = n_choose_k(n,k)
    successes = p**k
    failures = (1-p)**(n-k)
    return ways*successes*failures

def n_choose_k(n,k):
    """Binomial coefficient C(n, k) = n! / (k! (n-k)!), returned as a float.

    Evaluated as
        n*(n-1)*...*(n-k+1) / (k*(k-1)*...*1)
    by interleaving one multiplication and one division per step, after
    replacing k with min(k, n-k).  Asserts k <= n.
    """
    assert(k<=n)
    k = min(k, n-k)
    value = 1.0
    for top, bottom in zip(range(n,n-k,-1), range(k,0,-1)):
        value = (value*top)/bottom
    return value

def log_n_choose_k(n,k):
    """Natural log of the binomial coefficient C(n, k).

    Sums log(numerator factor) - log(denominator factor) over
    n*(n-1)*...*(n-k+1) / (k*(k-1)*...*1), after replacing k with
    min(k, n-k).  Returns 0 when there is nothing to sum.  Asserts k <= n.
    """
    assert(k<=n)
    k = min(k, n-k)
    total = 0
    for top, bottom in zip(range(n,n-k,-1), range(k,0,-1)):
        total = (total + math.log(top)) - math.log(bottom)
    return total

def factorial(n):
    """n! as an exact integer; 1 for n <= 1."""
    product = 1
    for factor in range(2, n+1):
        product = product*factor
    return product

def factorial_partial(n,k):
    """Product n*(n-1)*...*k with both ends included; 1 when k == n+1.

    Asserts k <= n+1.
    """
    assert(k<=n+1)
    product = 1
    for factor in range(k, n+1):
        product = product*factor
    return product

def test_chi_square():
    """Smoke test: run chi_square on a fixed 2x2 table (result discarded)."""
    table = [[45, 448],
             [57, 157]]
    chi_square(table)

def make_expected(rows):
    """Expected counts of a contingency table under independence.

    Each cell is row_total * column_total / grand_total.  Column totals come
    from unpack(rows), presumably the transposed table.
    """
    row_sums = map(sum, rows)
    col_sums = map(sum, unpack(rows))
    total = float(sum(row_sums))

    table = []
    for row, row_sum in map(None, rows, row_sums):
        table.append([row_sum*col_sum/total
                      for cell, col_sum in map(None, row, col_sums)])
    return table

def chi_square(rows, expected=None):
    """Pearson chi-square test of independence on a contingency table.

    rows: list of equal-length rows of observed counts, e.g.
        [[1,2,3],[1,4,5]].
    expected: optional table of expected counts with the same shape.  When
        omitted (or empty) it is computed with make_expected(rows).

    Returns (chisq, p), where p comes from chi_square_lookup.  Returns
    (0, 1.0) when any row or column sums to zero.  Raises KeyError when the
    degrees of freedom are not in chi_square_table.
    """
    assert(all_same(map(len,rows)))

    # A zero margin would make some expected count zero; bail out early.
    if 0 in map(sum,rows): return 0,1.0
    cols = map(lambda i,rows=rows: cget(rows,i), range(0,len(rows[0])))
    if 0 in map(sum,cols): return 0,1.0

    if not expected:
        expected = make_expected(rows)

    chisq = 0
    for obss,exps in map(None,rows,expected):
        for obs, exp in map(None, obss, exps):
            # float() prevents Python 2 integer (floor) division when the
            # caller supplies integer expected counts.
            chisq = chisq + ((obs-exp)**2)/float(exp)

    df = (len(rows)-1)*(len(rows[0])-1)

    p = chi_square_lookup(chisq,df)

    return chisq,p

# Critical values of the chi-square distribution, keyed by degrees of
# freedom (1-30).  Column i holds the critical value for the upper-tail
# probability ps[i] used by chi_square_lookup:
#     ps = [0.20, 0.10, 0.05, 0.025, 0.01, 0.001]
# Each row therefore increases from left to right.
chi_square_table = {
    1: [1.64, 2.71, 3.84, 5.02, 6.64, 10.83],
    2: [3.22, 4.61, 5.99, 7.38, 9.21, 13.82],
    3: [4.64, 6.25, 7.82, 9.35, 11.34, 16.27],
    4: [5.99, 7.78, 9.49, 11.14, 13.28, 18.47],
    5: [7.29, 9.24, 11.07, 12.83, 15.09, 20.52],
    6: [8.56, 10.64, 12.59, 14.45, 16.81, 22.46],
    7: [9.80, 12.02, 14.07, 16.01, 18.48, 24.32],
    8: [11.03, 13.36, 15.51, 17.53, 20.09, 26.12],
    9: [12.24, 14.68, 16.92, 19.02, 21.67, 27.88],
    10: [13.44, 15.99, 18.31, 20.48, 23.21, 29.59],
    11: [14.63, 17.28, 19.68, 21.92, 24.72, 31.26],
    12: [15.81, 18.55, 21.03, 23.34, 26.22, 32.91],
    13: [16.98, 19.81, 22.36, 24.74, 27.69, 34.53],
    14: [18.15, 21.06, 23.68, 26.12, 29.14, 36.12],
    15: [19.31, 22.31, 25.00, 27.49, 30.58, 37.70],
    16: [20.47, 23.54, 26.30, 28.85, 32.00, 39.25],
    17: [21.61, 24.77, 27.59, 30.19, 33.41, 40.79],
    18: [22.76, 25.99, 28.87, 31.53, 34.81, 42.31],
    19: [23.90, 27.20, 30.14, 32.85, 36.19, 43.82],
    20: [25.04, 28.41, 31.41, 34.17, 37.57, 45.31],
    21: [26.17, 29.62, 32.67, 35.48, 38.93, 46.80],
    22: [27.30, 30.81, 33.92, 36.78, 40.29, 48.27],
    23: [28.43, 32.01, 35.17, 38.08, 41.64, 49.73],
    24: [29.55, 33.20, 36.42, 39.36, 42.98, 51.18],
    25: [30.68, 34.38, 37.65, 40.65, 44.31, 52.62],
    26: [31.79, 35.56, 38.89, 41.92, 45.64, 54.05],
    27: [32.91, 36.74, 40.11, 43.19, 46.96, 55.48],
    28: [34.03, 37.92, 41.34, 44.46, 48.28, 56.89],
    29: [35.14, 39.09, 42.56, 45.72, 49.59, 58.30],
    30: [36.25, 40.26, 43.77, 46.98, 50.89, 59.70]}

def chi_square_lookup(value, df):
    """Return a tabulated upper bound on the p-value of a chi-square statistic.

    value: the chi-square statistic.
    df: degrees of freedom; must be a key of chi_square_table (1-30),
        otherwise KeyError is raised.

    Returns the smallest tabulated probability p (from 0.20, 0.10, 0.05,
    0.025, 0.01, 0.001) whose critical value is strictly below value, so
    that P <= p.  Returns 1.0 when value does not exceed even the p=0.20
    critical value.  This is a float, matching the 1.0 that chi_square
    returns for degenerate tables.
    """
    ps = [0.20, 0.10, 0.05, 0.025, 0.01, 0.001]
    row = chi_square_table[df]

    # Count the leading critical values that the statistic exceeds; rows
    # increase left to right, so we can stop at the first one it doesn't.
    exceeded = 0
    for critical in row:
        if critical >= value:
            break
        exceeded = exceeded + 1

    if exceeded == 0:
        return 1.0
    return ps[exceeded - 1]



# category 16
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######       SEQUENCE OPERATIONS                ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

############################################################
#
#  Sequence primitives
#
############################################################

def reverse(something):
    """Return a reversed copy of a string or a list.

    Only objects whose exact type is str or list are accepted.  Strings go
    through reverse_str and lists through reverse_list.

    Raises TypeError for any other type.  The former string exception
    `raise "..", type` already surfaced as TypeError on Python 2.6+, and is
    a syntax error on Python 3.
    """
    type_something = type(something)
    if type_something == str:
        return reverse_str(something)
    elif type_something == list:
        return reverse_list(something)
    raise TypeError("Unexpected type for reversal: %s" % type_something)
        

def reverse_str(str):
    """Return the items of str in reverse order, concatenated into one string."""
    return ''.join(list(str)[::-1])

# Translation table mapping each IUB nucleotide code (upper and lower case)
# to the code of its complementary base set; used by complement() and
# revcom().  See the mapping notes below.
comp_trans = string.maketrans('ACGTKMYRSWBVDHNacgtkmyrswbvdhn',
                              'TGCAMKRYSWVBHDNtgcamkryswvbhdn')

# complement mapping encoded by comp_trans (revcom() additionally reverses the sequence)
#    A -> T
#    C -> G
#    G -> C
#    T -> A
#    
#    K -> M  # keto (GT) -> amine (CA)
#    M -> K  # amine (AC) -> keto (TG)
#    
#    Y -> R  # pyrimidine (CT) -> purine (GA)
#    R -> Y  # purine (AG) -> pyrimidine (TC)
#    
#    S -> S  # strong (GC) -> strong (CG)
#    W -> W  # weak (AT) -> weak (TA)
#    
#    B -> V  # not A -> not T
#    V -> B  # not T -> not A
#    D -> H  # not C -> not G
#    H -> D  # not G -> not C
#    
#    N -> N  # any -> any
#   
#   # and the procedure which helped us find the above :o)
#   def make_full_translation():
#       for char,expansion in sort(IUB_expansion.items()):
#   
#           print '(%s->%s)'%(char,expansion)
#   
#           if char == '-': continue
#           target = IUB_code[string.join(my_sort(complement(expansion)),'')]
#           print '%s -> %s'%(char, target)

def complement(str):
    """Return the base-wise IUB complement of str (no reversal), via comp_trans."""
    complemented = string.translate(str, comp_trans)
    return complemented

def revcom(str):
    """Return the reverse complement of str.

    The characters are reversed first, then each IUB code is complemented
    through comp_trans.
    """
    return complement(reverse_str(str))

def sequence_type(seq):
    # Sketch of a nucleotide-vs-protein classifier for seq.
    # NOTE(review): unfinished.  It computes the characters of seq that are
    # not IUB nucleotide codes and those not in protein.aminoacids, but never
    # uses them, so the function always returns None.  `protein` is not
    # defined in this part of the file; presumably it is a module imported
    # elsewhere -- confirm before relying on this function.
    chars = unique(map(None, seq))
    non_nucleotide = set_subtract(chars, IUB_code.values())
    non_amino_acid = set_subtract(chars, protein.aminoacids)

def add_gaps_to_gapless(ali,seq,gap='-'):
    """Thread the residues of seq through the gap pattern of alignment row ali.

    ali: aligned string; every position equal to `gap` is a gap.
    seq: the ungapped residues, one per non-gap position of ali.
    gap: the single gap character (default '-').

    Returns a string with the same length and gap positions as ali, whose
    non-gap positions hold the characters of seq in order.  Asserts that ali
    has exactly len(seq) non-gap positions.

    The previous version inserted gaps at ali's gap positions in DESCENDING
    order.  That shifts every later insertion point, so for example 'A--B'
    with 'AB' gave 'A-B-'.  Walking ali left to right avoids the problem.
    """
    # count only `gap` so the check agrees with the positions filled below
    assert(len(ali) - ali.count(gap) == len(seq))
    out = []
    next_residue = 0
    for c in ali:
        if c == gap:
            out.append(gap)
        else:
            out.append(seq[next_residue])
            next_residue = next_residue + 1
    return ''.join(out)

def gapless(seq,gaps=['-']):
    """Return seq with every occurrence of each string in gaps removed.

    Short-circuits to '' when gapless_len reports no residues left.  The
    default list is only read, never mutated.
    """
    if not gapless_len(seq,gaps):
        return ''
    stripped = seq
    for gap_char in gaps:
        stripped = stripped.replace(gap_char, '')
    return stripped

def gapless_portions(seqs,gaps=['-'],maxwidth=0):
    """Split an alignment into the column blocks that contain no gap in any sequence.

    Each returned portion is a list of subsequences, one per input sequence,
    as built by get_subseqs.  When maxwidth is non-zero, each block is
    further cut into windows with seqs2windows(block, maxwidth) and those
    windows are returned instead.
    """
    gap_columns = []
    for seq in seqs:
        for gap in gaps:
            gap_columns.extend(find_all(seq,gap))
    gap_columns = my_sort(unique(gap_columns))

    # sentinels at -1 and the alignment length make the outer blocks appear
    boundaries = [-1] + gap_columns + [len(seqs[0])]
    portions = []
    for left, right in list2pairs(boundaries):
        if left >= right - 1:
            # adjacent gap columns: nothing lies between them
            continue
        block = get_subseqs(seqs, left+1, right)
        if maxwidth:
            portions.extend(seqs2windows(block, maxwidth))
        else:
            portions.append(block)
    return portions

def gapless_in_species(seqs,gaps=['-'],which_species=0):
    """Keep only the alignment columns where seqs[which_species] has no gap.

    seqs: list of aligned sequences.
    gaps: gap characters to look for in the reference sequence.
    which_species: index of the reference sequence.

    Returns a new list of sequences, one per input sequence, made of the
    surviving columns.
    """
    gapped_positions = []
    for gap in gaps:
        gapped_positions.extend(find_all(seqs[which_species],gap))
    # Positions for different gap characters are appended one after the
    # other; list2pairs needs the breakpoints in increasing order (as in
    # gapless_portions), so sort them.
    gapped_positions.sort()

    portions = []
    for start,end in list2pairs([-1]+gapped_positions+[len(seqs[0])]):
        if start>=end-1: continue
        portions.append(get_subseqs(seqs,start+1,end))

    newseqs = []
    for i in range(0,len(seqs)):
        newseqs.append(''.join(cget(portions,i)))

    return newseqs

def gapless_len(seq,gaps=['-']):
    """Length of seq after discounting every occurrence of each string in gaps."""
    return len(seq) - sum([seq.count(g) for g in gaps])

def append_gaps(seqs):
    """Return a copy of seqs with every shorter sequence right-padded with '-'
    up to the length of the longest one.  The input is not modified."""
    padded = seqs[:]
    width = max(map(len, padded))
    for idx, s in enumerate(padded):
        shortfall = width - len(s)
        if shortfall:
            padded[idx] = s + '-' * shortfall
    return padded

def join_subseqs(subseqs_list):
    """Concatenate alignment pieces back into whole sequences.

    subseqs_list is a list of pieces, each a list with one subsequence per
    sequence.  All pieces must list the same number of sequences (asserted).
    Returns one string per sequence.
    """
    assert(all_same(map(len,subseqs_list)))
    return [''.join(cget(subseqs_list, idx))
            for idx in range(len(subseqs_list[0]))]
    
def get_subseqs(seqs, start, end=None):
    """Slice every sequence in seqs to [start:end], or to [start:] when end is None."""
    if end is None:
        return [s[start:] for s in seqs]
    return [s[start:end] for s in seqs]

def get_subseqs_with_context(seqs, start, end=None, context=5):
    """Return each sequence's [start:end] slice framed by lower-cased flanks.

    Each result is 'pre|MIDDLE|post'.  pre holds up to `context` characters
    before start and post up to `context` characters after end; both are
    lower-cased.  With end=None the middle runs to the end of the sequence
    and post is empty.

    The previous version crashed on its own default end=None (None+context).
    It also compared end+context with the number of sequences instead of a
    sequence length; that branch was a no-op because slicing already clamps.
    """
    pre_start = max(start-context, 0)
    subseqs = []
    for seq in seqs:
        pre = seq[pre_start:start]
        if end is None:
            middle = seq[start:]
            post = ''
        else:
            middle = seq[start:end]
            # slicing clamps at the end of the sequence; no explicit bound needed
            post = seq[end:end+context]
        subseqs.append('%s|%s|%s'%(pre.lower(), middle, post.lower()))
    return subseqs

def get_multisubseqs(seqs, start_ends):
    """For each sequence, concatenate its slices [start:end] for every
    (start, end) pair in start_ends, in the given order."""
    pairs = list(start_ends)
    return [''.join([seq[a:b] for a, b in pairs]) for seq in seqs]

def smart_getsubseqs(seqs, start, end):
    """Apply smart_getsubseq(seq, start, end) to every sequence in seqs."""
    return [smart_getsubseq(seq, start, end) for seq in seqs]

def smart_getsubseq(seq, start, end):
    """Return seq[start:end], extending past '-' gaps so the result has
    (up to) end-start residues.

    - No '-' in the window: the plain slice is returned.
    - First window char is a residue: the window is extended up to 20
      characters to the right, gaps are dropped, and the first end-start
      residues are kept.
    - Otherwise, if the last char is a residue: the window is extended up to
      20 characters to the left, and the last end-start residues are kept.
    - Both ends are gaps: a warning is printed and the raw slice is returned.

    Each adjustment prints a message to stdout.  Left extensions are clamped
    at 0, because a negative slice start would wrap to the end of seq.
    """
    length = end-start
    subseq = seq[start:end]
    if '-' not in subseq:
        return subseq
    if subseq[0]!='-':
        subseq = seq[start:end+20].replace('-', '')[:length]
        print("Dealt with gaps %s -> %s"%(seq[start:end], subseq))
    elif subseq[-1]!='-':
        subseq = seq[max(start-20, 0):end].replace('-', '')[-length:]
        print("Dealt with gaps %s -> %s"%(seq[start:end], subseq))
    else:
        print("Gaps everywhere!  Don't know what to do! %s"%seq[max(start-(2*length), 0):end+(2*length)])
    return subseq


def get_inside(seqs, coords):
    """Extract the alignment columns covered by each coordinate interval.

    seqs: a multiple alignment (list of equal-length sequences).
    coords: list like [{'start': 40, 'end': 100}, {'start': 150, 'end': 200}],
        1-based with both ends inclusive.
    Returns one list of subsequences (one per sequence) per interval.
    """
    pieces = []
    for c in coords:
        lo, hi = c['start'] - 1, c['end']
        pieces.append([seq[lo:hi] for seq in seqs])
    return pieces

def coords_in2out(coords, maxlen):
    """Return the intervals lying between (and around) the given intervals.

    coords: list of {'start','end'} dicts, 1-based inclusive, presumably
        sorted and non-overlapping.
    maxlen: length of the full range; the last complement interval ends here.
    The result starts with [1, first start - 1], then the gap between each
    consecutive pair, then [last end + 1, maxlen].  Some of these may be
    empty (end < start).
    """
    def between(left, right):
        return {'start': left['end']+1, 'end': right['start']-1}
    middle = list2pairs(coords, between)
    head = {'start': 1, 'end': coords[0]['start']-1}
    tail = {'start': coords[-1]['end']+1, 'end': maxlen}
    return [head] + middle + [tail]

def get_outside(seqs, coords):
    """Extract the alignment columns lying OUTSIDE the given intervals.

    seqs: a multiple alignment (list of equal-length sequences).
    coords: list like [{'start': 40, 'end': 100}, {'start': 150, 'end': 200}],
        1-based with both ends inclusive.
    Returns, for each complementary interval computed by coords_in2out, one
    list of subsequences (one per sequence), as get_inside does.
    """
    complement_coords = coords_in2out(coords, len(seqs[0]))
    return get_inside(seqs, complement_coords)

def flatten_seqslist(seqs_list):
    """Concatenate a list of alignment pieces into whole sequences.

    Each piece lists one subsequence per sequence; the number of sequences is
    taken from the first piece.
    """
    return [''.join(cget(seqs_list, i)) for i in range(len(seqs_list[0]))]
    
############################################################
#
# Printing aligned sequences that span multiple lines
#
############################################################

def write_wrap_label2str(label, max_label_len):
    """Pad label to max_label_len with sfill and append the ': ' separator."""
    padded = sfill(label,max_label_len)
    return padded + ': '

def clustal_wrap_label2str(label, max_label_len):
    """Left-justify label in a column of at least 16 characters, and at least
    one character wider than the longest label (Clustal layout)."""
    width = max(max_label_len+1, 16)
    return '%-*s' % (width, label)

def print_wrap(sequences, indentation, labels=[], ignore_lines=None):
    """Print aligned sequences to stdout, wrapped every `indentation` columns
    with position headers (see write_wrap)."""
    write_wrap(sys.stdout, sequences, indentation,
               labels=labels, ignore_lines=ignore_lines, write_nums=1)

def write_wrap(f, sequences, indentation, labels=[], ignore_lines=None, write_nums=1,label2str = write_wrap_label2str, offset=0):
    """Write aligned sequences to file-like f in chunks of `indentation` columns.

    f: object with a write() method.
    sequences: list of sequence strings.
    indentation: number of columns per chunk.
    labels: one label per sequence.  Missing labels are filled with '' and
        extra labels are dropped; a note is printed in either case.
    ignore_lines: optional iterable of strings.  A chunk consisting entirely
        of these characters (counted with str.count) is skipped.
    write_nums: if true, each chunk is preceded by a 'start~end:' header
        (shifted by offset); otherwise by an empty line.
    label2str: function(label, max_label_len) producing the line prefix; it
        is only used when some label is non-empty.
    offset: added to the numbers in the chunk headers.
    """
    # labels is only rebound below, never mutated, so the [] default is safe
    if not labels:
        labels = [''] * len(sequences)
    if len(sequences)!=len(labels):
        diff = len(sequences) - len(labels)
        print("Len(seqs)=%s, Len(labels)=%s. %s extra labels"%(len(sequences), len(labels), ifab(diff>0,'Adding','Ignoring')))
        if diff>0: labels = labels+['']*diff
        else: labels = labels[:len(sequences)]

    max_label_len = max(map(len, labels))
    max_length = max(map(len, sequences))
    iterations = range(0, max_length, indentation)
    if not iterations: iterations = [0]
    for i in iterations:
        if write_nums:
            # Clamp to the longest sequence, which bounds the iteration.
            # Clamping to sequences[0] gave nonsense headers (e.g. "60~10")
            # when the first sequence was shorter than the others.
            upper_bound = min(i+indentation, max_length)
            f.write('%d~%d:\n' % (i+offset, upper_bound+offset))
        else: f.write('\n')
        for seq,label in zip(sequences,labels):

            subseq = seq[i:i+indentation]
            if ignore_lines:
                ignored = sum([subseq.count(ignore) for ignore in ignore_lines])
                if ignored == len(subseq):
                    continue

            if max_label_len: f.write(label2str(label,max_label_len))
            f.write(subseq+'\n')

def pw(names, seqs, indent=90, offset=0,strict=1):
    """Pretty-print an alignment (names, seqs) to stdout, wrapped at `indent` columns.

    When strict is false:
    - mismatched name/sequence counts are padded with empty strings;
    - sequences are converted with '%s';
    - sequences are space-padded to a common length.
    A warning is printed for each of these fixes.
    """
    if not strict:
        n_names, n_seqs = len(names), len(seqs)
        if n_names != n_seqs:
            print("BEWARE!!  %s seqs vs. %s names.  IGNORING EXTRA"%(n_seqs, n_names))
            if n_names < n_seqs:
                names = names + ['']*(n_seqs - n_names)
            if n_names > n_seqs:
                seqs = seqs + ['']*(n_names - n_seqs)
        seqs = ['%s'%s for s in seqs]
        if not all_same(map(len,seqs)):
            maxlen = max(map(len,seqs))
            print("BEWARE!! seqs are not all the same length")
            seqs = [s + ' '*(maxlen-len(s)) for s in seqs]

    write_wrap(sys.stdout, seqs, indent, names, write_nums=1, label2str=write_wrap_label2str, offset=offset)

def ww(f, names, seqs, indent=80):
    """Write an alignment to file f wrapped at `indent` columns, with labels
    and position headers (see write_wrap)."""
    write_wrap(f, seqs, indent, labels=names, write_nums=1,
               label2str=write_wrap_label2str)

def write_clustal(f, names, seqs):
    """Write names/seqs to f in a Clustal-like layout, 60 columns per block.

    If the last name is '' or 'consensus', the last sequence is taken to be
    the consensus line already and its label is blanked.  Otherwise a
    consensus from compute_clustal_consensus is appended with an empty
    label.  Spaces in names become underscores.
    """
    if names[-1] in ('', 'consensus'):
        names = names[:-1] + ['']
    else:
        names = names + ['']
        seqs = seqs + [compute_clustal_consensus(seqs)]

    safe_names = [safe_replace(n, ' ', '_') for n in names]
    write_wrap(f, seqs, 60, safe_names, write_nums=0, label2str=clustal_wrap_label2str)

############################################################
#
#  Calculate consensus from two sequences
#
############################################################

def idperc(q,s,countgaps=0):
    """Percent identity between aligned sequences q and s.

    Columns where either character is a gap (is_gap) are skipped.  With
    countgaps=1 they are instead counted in the denominator as mismatches.
    map(None, ...) pads the shorter sequence with None, so any overhang
    counts as mismatches.  Returns 0 when no column was counted.
    """
    assert(countgaps in [0,1])
    matches = 0
    compared = 0
    for a, b in map(None, q, s):
        if is_gap(a) or is_gap(b):
            compared = compared + countgaps
        else:
            if a == b:
                matches = matches + 1
            compared = compared + 1
    if not compared:
        return 0
    return 100.0*matches/compared

def q_and_s(q0,s0):
    """Return the shared character if q0 and s0 are identical and neither '-'
    nor 'X'; otherwise return ' '."""
    if q0 != s0 or q0 in ('-', 'X'):
        return ' '
    return q0

def sequence_identities(q, s):
    """Identity line between q and s: the shared character where they match, ' ' elsewhere."""
    return ''.join(map(q_and_s, q, s))

def bar_q_and_s(q0,s0):
    """'|' when q0 and s0 are identical and neither '-' nor 'X', else ' '."""
    if q0 != s0 or q0 in ('-', 'X'):
        return ' '
    return '|'
    
def bar_sequence_identities(q, s):
    """Case-insensitive identity line of '|' and ' ' characters between q and s."""
    return ''.join(map(bar_q_and_s, q.upper(), s.upper()))

def star_q_and_s(q0,s0):
    """'*' when q0 and s0 are identical and neither '-' nor 'X', else ' '."""
    if q0 != s0 or q0 in ('-', 'X'):
        return ' '
    return '*'
    
def star_sequence_identities(q, s):
    """Case-insensitive identity line of '*' and ' ' characters between q and s."""
    return ''.join(map(star_q_and_s, q.upper(), s.upper()))

#def sequence_identities_new_but_slower(q, s):
#    result_list = map(None, q)
#    for i,q0,s0 in map(None, range(0,len(result_list)),q,s):
#        if q0!=s0: result_list[i] = ' '
#    return string.join(result_list,'')

def q_pos_s(q0,s0):
    """IUB code for the bases that IUB characters q0 and s0 have in common,
    or ' ' when their expansions share nothing."""
    s_bases = IUB_expansion[s0]
    shared = [b for b in IUB_expansion[q0] if b in s_bases]
    if not shared:
        return ' '
    return get_IUB_code(''.join(shared))

def sequence_positives(q, s):
    """Per-column IUB code of the bases shared by profiles q and s (case-insensitive)."""
    return ''.join(map(q_pos_s, q.upper(), s.upper()))

def bar_q_pos_s(q0,s0):
    """'|' if q0 == s0, '+' if their IUB base sets overlap, otherwise ' '."""
    if q0==s0:
        return '|'
    s_bases = IUB_expansion[s0]
    overlap = [b for b in IUB_expansion[q0] if b in s_bases]
    if overlap:
        return '+'
    return ' '

def bar_sequence_positives(q, s):
    """Per-column '|' / '+' / ' ' line comparing profiles q and s through IUB codes."""
    return ''.join(map(bar_q_pos_s, q.upper(), s.upper()))

############################################################
#
# Clustal stuff
#
############################################################

# Translation tables that turn residues into Clustal-style conservation marks.
# clustal_NN2star: every nucleotide letter (either case) -> '*'.
# clustal_AA2star: '*' and the uppercase letters listed (amino-acid one-letter
# codes) -> '*', '+' -> ':', and ' ' stays ' '.
clustal_NN2star = string.maketrans('ACGTacgt','********')
clustal_AA2star = string.maketrans('*ABCDEFGHIKLMNPQRSTVWXYZ+ ','************************: ')

# def compute_clustal_consensus(list_of_seqs):
#     sum = list_of_seqs[0]
#     for seq in list_of_seqs[1:]:
#         #print sum
#         sum = sequence_identities(sum, seq)
#     return string.translate(sum, clustal_NN2star)

def test_missing2pieces():
    """Demo for missing2pieces: print a small alignment with '.' as missing
    data, then each piece it is split into."""
    names = ['seq1','seq2','seq3']
    seqs = ['ATTTAGATACGAGT',
            'ATTAAG...CGACT',
            'AT..AG.TACGACT']

    # -> AT TT AG A TA CGAGT
    # -> AT TA AG . .. CGACT
    # -> AT .. AG . TA CGACT

    pw(names, seqs)

    for piece_names, piece_seqs in missing2pieces(names, seqs):
        print("\nNext piece")
        pw(piece_names, piece_seqs)

def missing2pieces(allnames, allseqs):
    # Split an alignment into column ranges ("pieces") according to where
    # each sequence has data versus '.' (missing data).
    # see test_missing2pieces above for example
    # Returns a list of (names, subseqs) tuples.  Every piece carries all of
    # allnames, with every sequence sliced to that piece's columns.

    # per sequence, a list of flags: True where the character is '.'
    alldots = map(lambda seq: map(lambda char: char=='.', seq), allseqs)
    names, seqs, dots = [], [], []
    # keep only the sequences that have at least one non-'.' character
    for name,seq,dot in map(None,allnames,allseqs,alldots):
        if 0 in dot:
            names.append(name)
            seqs.append(seq)
            dots.append(dot)
    # no sequence has any data: the whole alignment is a single piece
    if not names: return [(allnames, allseqs)]
    
    # NOTE(review): count_islands, mset and interval_cut are defined
    # elsewhere.  This assumes islands[0] is the list of (start, end) runs
    # for a sequence's flags, and that interval_cut returns regions with
    # 1-based inclusive 'start'/'end' (implied by the start-1 slice below).
    # Confirm against those helpers.
    islands_list = map(count_islands,dots)
    #pp(islands_list,1)
    intervals = map(lambda islands: map(lambda ab: {'start': ab[0], 'end': ab[1]},
                                        islands[0]),
                    islands_list)
    intervald = {}
    # presumably builds name -> that sequence's interval list
    mset(intervald, names, intervals)
    cut = interval_cut(intervald)
    #pp(cut,1)

    pieces = []
    for region in cut:
        # skip empty regions (end before start)
        if not region['end'] >= region['start']: continue
        #print "Piece1: %(start)s-%(end)s"%region
        pieces.append((allnames, get_subseqs(allseqs,region['start']-1,region['end'])))
    return pieces


def char_perfect_or_gap(c):
    """'*' if c occurs in 'ACGT' or in 'acgt' (substring test), else ' '."""
    if c in 'ACGT' or c in 'acgt':
        return '*'
    return ' '

def consensus2idperc(consensus):
    """Integer (floored) percentage of '*' characters in a consensus line; 0 if empty."""
    if not consensus:
        return 0
    stars = consensus.count('*')
    return 100*stars//len(consensus)

def consensus2idperc_float(consensus):
    """Float percentage of '*' characters in a consensus line; 0 if empty."""
    if not consensus:
        return 0
    stars = consensus.count('*')
    return 100.0*stars/len(consensus)

def describe_idperc(consensus):
    """Describe the share of '*' columns in consensus using perc(stars, total)."""
    stars = count_stars(consensus)
    return perc(stars, len(consensus))

def count_stars(consensus):
    """Number of '*' (conserved) columns in a consensus line."""
    return consensus.count('*')

def compute_clustal_consensus_ignoring_gaps(seqs):
    """Clustal consensus line treating both '-' and '.' as missing data."""
    return compute_clustal_consensus(seqs, chars_to_ignore='-.')

def compute_clustal_consensus_ignoring_endgaps(seqs,chars_to_ignore='.'):
    """Clustal consensus line in which terminal gaps count as missing data.

    End gaps are first converted to dots by clustal_endgaps2dots.  '.' is
    then always ignored, in addition to chars_to_ignore.
    """
    dotted = clustal_endgaps2dots(seqs)
    ignore = '.' + chars_to_ignore
    return compute_clustal_consensus(dotted, chars_to_ignore=ignore)

#def compute_clustal_consensus(seqs):
#    numseqs = len(seqs)
#    length = len(seqs[0])
#
#    first = seqs[0]
#    others = seqs[1:]
#    
#    consensus = [' ']*length
#    for i in range(0,length):
#        char = first[i]
#        for seq in others:
#            if char == seq[i]: continue
#            else: break
#        else:
#            consensus[i] = '*'
#    return string.join(consensus,'')

def compute_clustal_consensus(seqs,chars_to_ignore='.',result=None):
    """Return a Clustal-style conservation line for an alignment.

    For each column, characters listed in chars_to_ignore are skipped.  The
    column gets '*' when the remaining characters are all the same letter up
    to case, and ' ' otherwise.  A column with nothing left, or whose only
    remaining character is the gap '-', gets ' '.

    seqs: list of equal-length sequences (asserted).
    chars_to_ignore: characters meaning "no information" (string or list).
    result: optional pre-allocated list of len(seqs[0]).  It is filled in
        place, so the caller's list is modified, and then joined into the
        return value.
    See also generate_profile().

    The old single-character branch read `if seen == '-': result[i]==' '`,
    which compared a dict to a string and used `==` instead of `=`.  As a
    result, all-gap columns were wrongly marked '*'.
    """
    # dictionary of ignored characters for O(1) membership tests
    ignoreit = dict.fromkeys(chars_to_ignore)
    # allocate the necessary memory for the result in one shot
    if not result: result = ['']*len(seqs[0])
    lengths = [len(x) for x in seqs+[result]]
    assert(lengths.count(lengths[0]) == len(lengths))

    for i in range(0,len(seqs[0])):
        # the distinct, non-ignored characters in this column
        seen = {}
        for seq in seqs:
            if seq[i] in ignoreit: continue
            seen[seq[i]] = None
        if len(seen)==0:
            result[i] = ' '
        elif len(seen) == 1:
            # a column made only of gaps is not a conserved position
            if '-' in seen: result[i] = ' '
            else: result[i] = '*'
        elif len(seen) > 2:
            result[i] = ' '
        else:
            # exactly two characters: conserved only if they differ by case
            one,two = seen.keys()
            if one.upper() == two.upper(): result[i] = '*'
            else: result[i] = ' '
    return ''.join(result)

def quick_clustal_consensus(seqs):
    """Faster, case-insensitive conservation line.

    A column gets '*' when every sequence has the same non-gap character
    there, and ' ' otherwise; any '-' in the column disqualifies it.
    """
    upper_seqs = [s.upper() for s in seqs]
    assert(all_same(map(len,upper_seqs)))
    marks = [' ']*len(upper_seqs[0])
    for col in range(len(upper_seqs[0])):
        ref = upper_seqs[0][col]
        if ref == '-':
            continue
        conserved = 1
        for other in upper_seqs[1:]:
            if other[col] == '-' or other[col] != ref:
                conserved = 0
                break
        if conserved:
            marks[col] = '*'
    return ''.join(marks)

def compute_soft_consensus(seqs):
    """Graded conservation line based on IUB base-set overlap per column.

    Each column's characters are expanded with expand_IUB, and
    ratio = 100 * |common bases| / |all bases|.  Marks:
      '_'  common and all both have 4 elements (every sequence fully ambiguous)
      ' '  ratio == 0
      '.'  ratio < 50
      ':'  ratio < 100
      '*'  ratio == 100
    """
    assert(all_same(map(len,seqs)))
    # TODO: sequences are upper-cased here, so lowercase (gap-inclusive)
    # IUB codes lose their gap meaning.  FIX THIS.
    seqs = [s.upper() for s in seqs]

    marks = []
    for col in range(len(seqs[0])):
        expansions = [expand_IUB(c) for c in cget(seqs, col)]
        shared = len(set_intersect_all(expansions))
        total = len(set_union_all(expansions))

        ratio = 100.0 * shared / total
        if shared == total == 4:
            marks.append('_')
        elif ratio == 0:
            marks.append(' ')
        elif ratio < 50:
            marks.append('.')
        elif ratio < 100:
            marks.append(':')
        elif ratio == 100:
            marks.append('*')
        else:
            raise "OOPS!"
    return ''.join(marks)

#def compute_clustal_consensus_ignoring_chrs(list_of_seqs,chars_to_ignore=['.'],result=''):
#    # computes the consensus of a list of sequences, where '.' or other specified characters
#    # mean that information is just not available
#
#    # result is a list of characters that will be used in evaluating 
#    
#    assert(all_same(map(len, list_of_seqs)))
#    if not result: result = map(None, '*'*len(list_of_seqs[0]))
#    else: result = map(None, result)
#    assert(not ' ' in result)
#    cols_of_chars = unpack(list_of_seqs)
#    for i in range(0,len(cols_of_chars)):
#        if not len(set_subtract(unique(cols_of_chars[i]), chars_to_ignore)) == 1:
#            result[i] = ' '
#    try: result = string.join(result,'')
#    except TypeError: pass # this means it was already a string
#
#    return result

def compute_evidence(seqs, chars_to_ignore=['.']):
    """For each alignment column, count the sequences whose character is not
    in chars_to_ignore (i.e. that carry actual data)."""
    assert(all_same(map(len, seqs)))
    counts = []
    for column in unpack(seqs):
        counts.append(len([c for c in column if c not in chars_to_ignore]))
    return counts

def clustal_endgaps2dots(seqs):
    """Apply endgaps2dots to every sequence of an alignment.

    Leading and trailing gaps become dots, so a consensus can treat them as
    missing data.
    """
    return [endgaps2dots(seq) for seq in seqs]

def endgaps2dots(seq,gaps='-',dots='.'):
    """Replace each leading and trailing `gaps` character of seq by `dots`.

    Interior gaps are kept.  A sequence made only of gaps becomes all dots.
    """
    if seq.count(gaps) == len(seq):
        return dots*len(seq)

    gap_flags = [ch == gaps for ch in seq]
    lead = gap_flags.index(0)
    trail = gap_flags[::-1].index(0)

    if trail:
        core = seq[lead:-trail]
    else:
        core = seq[lead:]
    return dots*lead + core + dots*trail

def mstrip(seq,gaps):
    """Strip leading and trailing characters of seq that occur in gaps.

    Returns '' when mcount reports that every character is a gap.
    """
    if mcount(seq,gaps)==len(seq):
        return ''

    flags = [ch in gaps for ch in seq]
    lead = flags.index(0)
    trail = flags[::-1].index(0)

    if trail:
        return seq[lead:-trail]
    return seq[lead:]
    
    
        
# def endgaps2dots(seq,gaps='-',dots='.'):
#     assert(len(gaps)==len(dots)==1)
#     unique_char = ' '
#     
#     assert(string.count(seq,unique_char)==0)
# 
#     #print seq
# 
#     seq = string.replace(seq,gaps,unique_char)
#     #print seq
# 
#     # strip left
#     right_portion = string.lstrip(seq)
#     seq = dots*(len(seq) - len(right_portion)) + right_portion
#     #print seq
# 
#     # strip right
#     left_portion = string.rstrip(seq)
#     seq = left_portion + dots*(len(seq) - len(left_portion))
#     #print seq
# 
#     # put gaps back in the middle
#     seq = string.replace(seq,unique_char,gaps)
#     #print seq
# 
#     return seq

def compute_conserved_genome(seqs,nomatch='_'):
    """Build one sequence from the columns that are conserved across seqs.

    Where all sequences agree ignoring case, the column contributes the
    character chosen by majority() over the original-case characters.
    Every other column contributes nomatch.
    """
    assert(all_same(map(len,seqs)))
    upper_seqs = [s.upper() for s in seqs]

    out = [nomatch]*len(seqs[0])
    for col in range(len(seqs[0])):
        if all_same(cget(upper_seqs, col)):
            out[col] = majority(cget(seqs, col))
    return ''.join(out)

############################################################
#
# Sliding windows and computing scores
#
############################################################

def score_seqs(seqs, method, increment=50, length=25):
    """Score an alignment in sliding windows using the chosen method.

    method:
      'mutations' -- mutation counts over a tree from build_tree, scored by
                     score_mutations (which is unimplemented and raises)
      'consensus' -- '*' density of compute_clustal_consensus(seqs)
      'profile'   -- '*' counts over generate_profile(seqs)
    increment, length: window step and window width passed to the scorer.

    Raises ValueError for an unknown method.  The old string exception
    `raise "..", method` was a TypeError on Python 2.6+ and a syntax error
    on Python 3.
    """
    # NOTE(review): these defaults (windows of 25 every 50 positions) are the
    # reverse of the score_* helpers' defaults (windows of 50 every 25), so
    # half of each sequence is never scored.  Confirm whether they were swapped.
    if method=='mutations':
        tree = build_tree(seqs)
        mutations = number_of_mutations(seqs,tree)['mut']
        return score_mutations(mutations, increment, length)
    elif method=='consensus':
        consensus = compute_clustal_consensus(seqs)
        return score_consensus(consensus, increment, length)
    elif method=='profile':
        profile = generate_profile(seqs)
        return score_profile(profile, increment, length)
    raise ValueError("Unknown method: %s" % method)

# scoring profiles, consensuses
def score_profile(profile, increment=25,length=50):
    """Slide a `length`-wide window along profile every `increment` positions.

    Returns one (stars, window_length) tuple per window, where stars is the
    number of '*' characters in the window.  The last window may be shorter
    than `length`.

    The old body called score.append(count, len).  list.append takes a
    single argument, so it always raised TypeError; the pair is now
    appended as a tuple.
    """
    score = []
    for i in range(0,len(profile),increment):
        window = profile[i:i+length]
        score.append((window.count('*'), len(window)))
    return score

def score_consensus(consensus, increment=25,length=50,floats=0):
    """Percentage of '*' characters in `length`-wide windows of a consensus
    line, taken every `increment` positions.

    Percentages are floored integers unless floats is true, in which case
    they are floats.
    """
    score = []
    for start in range(0, len(consensus), increment):
        window = consensus[start:start+length]
        stars = window.count('*')
        if floats:
            score.append(100.0*float(stars)/float(len(window)))
        else:
            score.append(100*stars//len(window))
    return score

def score_mutations(mutations, increment=25,length=50):
    """Intended to turn per-position mutation counts into window scores.

    The plan was to average over `length`-wide windows every `increment`
    positions, then invert the result into a score-like metric.  This is
    not implemented.

    Raises NotImplementedError.  The former `raise 'unimplemented'` string
    exception surfaced as a TypeError on Python 2.6+.
    """
    # Sketch of the intended first step, kept for reference:
    #   averages = [avg(mutations[i:i+length])
    #               for i in range(0, len(mutations), increment)]
    raise NotImplementedError('unimplemented')

############################################################
#
# PROFILES and IUB
#
############################################################

def is_gap(char):
    """True for the alignment gap characters '-' and '.'."""
    return char in ('-', '.')

# IUB/IUPAC nucleotide ambiguity code for each set of bases.  Keys are the
# bases in alphabetical order (lookups are by exact string); '-' maps to
# itself.  IUB_expansion below goes the other way.
IUB_code = {'A':'A',  'C':'C',  'G':'G',  'T':'T',
            'AC':'M', 'AG':'R', 'AT':'W', 
            'CG':'S', 'CT':'Y', 'GT':'K',
            'ACG':'V','ACT':'H','AGT':'D','CGT':'B',
            'ACGT':'N','-': '-'}

# Expansion of each IUB character into the symbols it may stand for.
# Uppercase codes expand to bases only.  Lowercase codes expand to the same
# bases plus '-', meaning "this base set or a gap"; get_IUB_code produces
# this lowercase form when its input contained a gap.  ' ' expands like 'N',
# '.' like 'n', and '-' to itself.
IUB_expansion = {'A':'A',  'C':'C',  'G':'G',  'T':'T',
                 'M':'AC', 'R':'AG', 'W':'AT', 
                 'S':'CG', 'Y':'CT', 'K':'GT',
                 'V':'ACG','H':'ACT','D':'AGT','B':'CGT',
                 'N':'ACGT', ' ':'ACGT', '-': '-',

                 'a':'A-',  'c':'C-',  'g':'G-',  't':'T-',
                 'm':'AC-', 'r':'AG-', 'w':'AT-', 
                 's':'CG-', 'y':'CT-', 'k':'GT-',
                 'v':'ACG-','h':'ACT-','d':'AGT-','b':'CGT-',
                 'n':'ACGT-', '.': 'ACGT-'}

#def IUB_matches(char,char):
#    # IUB_includes['Y']['C'] = 1 all of C  in CT
#    # IUB_includes['V']['Y'] = 0 not all of CT in ACG
#    # IUB_includes['H']['Y'] = 1 all of 
#    
#    {
#
#    }

def IUB_superset(IUB_char):
    """Return the IUB codes (via mget on IUB_code) of every non-empty subset
    produced by superset() over the bases that IUB_char expands to."""
    bases = IUB_expansion[IUB_char]
    keys = []
    for subset in superset(list(bases)):
        key = ''.join(my_sort(subset))
        if key:
            keys.append(key)
    return mget(IUB_code, keys)

def expand_IUB(chars):
    """Expand each IUB character in chars via IUB_expansion.

    Returns one flat list of the resulting symbols (bases, plus '-' for
    gap-inclusive codes); duplicates are kept.
    """
    expanded = ''.join([IUB_expansion[c] for c in chars])
    return list(expanded)

def get_IUB_code(str):
    """Map a string of bases (optionally containing '-') to one IUB code.

    A lone '-' is returned unchanged.  Otherwise '-' characters are
    removed; any 'N' forces 'N', and the rest is looked up in IUB_code.
    The lookup is by exact string, so the bases must appear in IUB_code's
    key order.  If the input contained a gap, the code is lower-cased
    rather than building a second-order model.
    """
    if str == '-':
        return str
    has_gap = '-' in str
    letters = str.replace('-', '')
    if 'N' in letters:
        code = 'N'
    else:
        code = IUB_code[letters]
    if has_gap:
        return code.lower()
    return code

def generate_profile(list_of_seqs):
    """Build an IUB profile string for an alignment, one code per column.

    Each column's characters are expanded, de-duplicated, sorted and mapped
    back with get_IUB_code.  The sequences are assumed (not checked) to have
    the same length.
    """
    codes = []
    for column in unpack(list_of_seqs):
        bases = ''.join(my_sort(unique(expand_IUB(column))))
        codes.append(get_IUB_code(bases))
    return ''.join(codes)

def IUPAC_from_chardic(char):
    # Summarise a column of nucleotide counts as one IUPAC/IUB code,
    # following the TRANSFAC rules quoted below (see NOTE for deviations).
    # char is of the form:  {'A': 3, 'G': 2, 'C': 1, 'T': 0}
    # and must have exactly four entries (asserted below).

    # from transfac: 
    # A single nucleotide is shown if its frequency is greater than
    # 50% and at least twice as high as the second most frequent
    # nucleotide.  A double-degenerate code indicates that the
    # corresponding two nucleotides occur in more than 75% of the
    # underlying sequences, but each of them is present in less than
    # 50%.  Usage of triple-degenerate codes is restricted to those
    # positions where one of the nucleotides did not show up at all in
    # the sequence set and none of the afore-mentioned rules applies.
    # All other frequency distributions are represented by the letter
    # "N".
    #
    # NOTE(review): the code applies the thresholds inclusively (>= 50%,
    # >= 2x the runner-up, >= 75% for the top two), whereas the quoted text
    # says "greater than" / "more than".  The double-degenerate branch also
    # does not check that each of the two bases is below 50%.  Confirm which
    # behaviour is intended before changing either.

    # first count total number of chars
    total = sum(char.values())

    # presumably my_sort orders the (base, count) pairs by the key, i.e. by
    # descending count; the order of ties depends on my_sort -- confirm
    items = my_sort(char.items(), lambda c_n: -c_n[1])

    assert(len(items) == 4)

    if (2*items[0][1] >= total and # top char is at least 50% of time
        2*items[1][1] <= items[0][1]): # and at least twice the second most frequent

        return items[0][0]

    # top two bases together cover at least 75% -> double-degenerate code
    if 4*(items[0][1] + items[1][1]) >= 3*total:

        return IUB_code[string.join(my_sort(cget(items[:2],0)),'')]

    # least frequent base never seen -> triple-degenerate code of the other three
    if items[-1][1] == 0:
        
        return IUB_code[string.join(my_sort(cget(items[:3],0)),'')]
    
    return 'N'

# Character-class spelling of each ambiguous IUB code (either case), used
# by explain_profile() to write a profile out as explicit base choices.
# The replacement text contains only '[', ']' and lowercase a/c/g/t, none
# of which are keys, so the replacement order does not matter.
ambiguity_explanation = {'M':'[ac]','R':'[ag]','W':'[at]','S':'[cg]','Y':'[ct]','K':'[gt]',
                         'm':'[ac]','r':'[ag]','w':'[at]','s':'[cg]','y':'[ct]','k':'[gt]',
                         'V':'[acg]','H':'[act]','D':'[agt]','B':'[cgt]','N':'[acgt]',
                         'v':'[acg]','h':'[act]','d':'[agt]','b':'[cgt]','n':'[acgt]'}

def explain_profile(seq):
    """Replace every ambiguous IUB character in seq with its bracketed base
    class from ambiguity_explanation, e.g. 'Y' -> '[ct]'."""
    explained = seq
    for code in ambiguity_explanation.keys():
        explained = explained.replace(code, ambiguity_explanation[code])
    return explained


############################################################
#
#   Sequence Motifs
#
############################################################

# Bases denoted by each uppercase IUB code, used by profile2seqs() to expand
# a motif into concrete sequences.  The letter order inside each value
# fixes the order of the expansions it produces.
profile_ambiguities = {'A': 'A', 'C': 'C', 'G':'G', 'T':'T',
                       'S':'CG','W':'AT','R':'AG','Y':'CT','M':'AC','K':'TG',                       
                       'B':'TCG','D':'ATG','H':'ATC','V':'ACG','N':'ATCG'}

def profile2seqs(profile):
    """Expand an IUB motif into every concrete sequence it can denote.

    For example 'AY' gives ['AC', 'AT'].  Earlier positions vary slowest,
    and each position follows the letter order in profile_ambiguities.  An
    empty profile gives [''].
    """
    expansions = ['']
    for code in profile:
        expansions = [prefix + base
                      for prefix in expansions
                      for base in profile_ambiguities[code]]
    return expansions

def find_profile_in_genome(profile,all_chrs):
    """Print every concrete instance of profile, then the find_all positions
    of that instance in each sequence of all_chrs."""
    for motif in profile2seqs(profile):
        print("Looking for %s"%motif)
        for chromosome in all_chrs:
            print(find_all(chromosome, motif))


def star_patterns(length):
    """Return one `length`-wide string per subset produced by
    superset(range(length)), with '*' at the chosen positions and ' '
    elsewhere."""
    patterns = []
    for chosen in superset(range(0,length)):
        row = [' '] * length
        for pos in chosen:
            row[pos] = '*'
        patterns.append(''.join(row))
    return patterns
        
def find_profile(seq, profile):
    """Sorted positions in seq of every concrete instance of IUB motif profile."""
    hits = [find_all(seq, motif) for motif in profile2seqs(profile)]
    return my_sort(flatten(hits))

def find_multimer(genome, multimer):
    """Find the start positions of a gapped multimer in `genome`.

    `multimer` is a dict such as
        {'seqs': ['CGG', 'CCG'],
         'gaps': [(11, 11)]}
    which means 'CGG', then a gap of between 11 and 11 bases, then 'CCG'.
    'gaps' holds one (mingap, maxgap) pair per consecutive pair of 'seqs'.

    Each element of 'seqs' is located with find_profile.  join_lists then
    keeps the positions of the FIRST element that have an occurrence of the
    next element inside the allowed distance window.  Returns the surviving
    start positions of the first element.

    Because only first-element positions are carried forward, the window for
    element k is measured from the first element's start.  It is the
    cumulative lengths plus gaps of the preceding elements.

    Raises ValueError if the number of gaps is not len(seqs) - 1.
    """
    seqs = multimer['seqs']
    gaps = multimer['gaps']
    if len(gaps) != len(seqs) - 1:
        raise ValueError("expected %s gaps for %s seqs, got %s" % (
            len(seqs) - 1, len(seqs), len(gaps)))

    positions = find_profile(genome, seqs[0])
    min_offset = max_offset = 0
    for prev, following, (mingap, maxgap) in zip(seqs[:-1], seqs[1:], gaps):
        # distance from the first element's start to this element's start
        min_offset = min_offset + len(prev) + mingap
        max_offset = max_offset + len(prev) + maxgap
        candidates = find_profile(genome, following)
        positions = join_lists(positions, candidates, min_offset, max_offset)
    return positions
        
        
    
def join_lists(list1, list2, mindist, maxdist):
    """Keep the elements of list1 that have a partner in list2 ahead of them.

    An element x of list1 is kept when some y in list2 satisfies
    mindist <= y - x <= maxdist.  Both bounds are inclusive, and the test is
    one-sided: y must come after x by that much.  Both lists must be sorted
    ascending; the scan is a single merge-style pass.

    Each qualifying x is reported once, however many y match it.  For
    mindist=3, maxdist=5, list1=[1, 2, 7, 8] and list2=[4, 5], the matching
    (x, y) pairs are (1, 4), (1, 5) and (2, 5), and the result is [1, 2].
    """
    kept = []
    a = b = 0
    while a < len(list1) and b < len(list2):
        gap = list2[b] - list1[a]
        if gap < mindist:
            # list2[b] is too close to (or behind) list1[a], and so also for
            # every later, larger element of list1: it can never match
            b = b + 1
            continue
        if gap <= maxdist:
            kept.append(list1[a])
        # list1[a] either matched or already lies too far behind list2[b]
        a = a + 1
    return kept

def trim_lists_nearby(list1, list2, mindist):
    # NOTE: dead code.  A second `def trim_lists_nearby` further down in this
    # module rebinds the name; that version also returns the matching pairs.
    # Callers therefore never reach this definition.
    #
    # Intended contract: return the elements of list1 that are within
    # mindist of some element of list2 (abs(difference) < mindist), and
    # vice-versa, each list passed through eliminate_duplicates.
    #
    # NOTE(review): the walk assumes both lists are sorted, and it only
    # compares the two elements currently under its pointers, so it can miss
    # partners.  With list1=[0], list2=[1,2], mindist=5, the element 2 is
    # never compared with 0 and is not reported.

    newlist1,newlist2 = [],[]
    i = j = 0
    while i<len(list1) and j<len(list2):

        diff = list2[j] - list1[i]
        if abs(diff) < mindist:
            # the two current elements are close enough: keep both
            newlist1.append(list1[i])
            newlist2.append(list2[j])

        # advance the pointer sitting on the smaller value (ties advance j)
        if diff>0: i = i+1
        else: j = j+1

    return (eliminate_duplicates(newlist1),
            eliminate_duplicates(newlist2))
            
def trim_lists_nearby(list1, list2, mindist):
    """Find the elements of two sorted lists that lie near each other.

    Returns a 3-tuple (near1, near2, pairs):
      near1 -- elements x of list1 with abs(y - x) < mindist for some y in list2
      near2 -- elements y of list2 with abs(y - x) < mindist for some x in list1
      pairs -- every (x, y) satisfying that condition, ordered by x, then y
    near1 and near2 are passed through eliminate_duplicates.

    Both lists must be sorted ascending.  For each x, the scan keeps a window
    into list2 whose lower edge only moves forward.  Every partner of every
    x is therefore found, not only the ones adjacent in merged order.
    """
    pairs = []
    near1, near2 = [], []
    n2 = len(list2)
    start = 0
    for x in list1:
        # skip list2 elements at or below x - mindist; x only grows, so they
        # are too far below every later x as well
        while start < n2 and list2[start] - x <= -mindist:
            start = start + 1
        k = start
        while k < n2 and list2[k] - x < mindist:
            near1.append(x)
            near2.append(list2[k])
            pairs.append((x, list2[k]))
            k = k + 1
    return (eliminate_duplicates(near1),
            eliminate_duplicates(near2),
            pairs)
            
############################################################
#
#  Statistical significance of a match to a profile...
#
############################################################

def random_posperc(seq):
    """Expected percent identity between profile `seq` and a uniform random sequence.

    Each base is drawn with probability .25, so a position whose code
    expands (via IUB_expansion) to k bases contributes 0.25 * k matches.  The
    total is divided by len(seq), scaled to a percentage and truncated with
    int().  An empty `seq` raises ZeroDivisionError.
    """
    expected_matches = sum([0.25 * len(IUB_expansion[code]) for code in seq])
    return int(100 * expected_matches / len(seq))

def GCaware_posperc(seq, base_occurences={'A': 1, 'C': 1, 'G': 1, 'T': 1}):
    """Like random_posperc, but against a random sequence of given composition.

    `base_occurences` maps bases to relative counts, which are normalized
    into frequencies; the default dict is only read, never modified.  Each
    position of the profile `seq` contributes the summed frequency of the
    bases in its IUB_expansion.  Bases absent from `base_occurences`
    contribute nothing.  Returns the average over len(seq) as a percentage,
    truncated with int().
    """
    total = sum(base_occurences.values())
    freq = {}
    for base, occurrences in base_occurences.items():
        freq[base] = float(occurrences) / float(total)

    expected = 0
    for code in seq:
        for base in IUB_expansion[code]:
            if base in freq:
                expected = expected + freq[base]
    return int(100 * expected / len(seq))

############################################################
#
#  Sequence Analysis
#
############################################################

def acgt_distribution(list_of_chars):
    """Count the 'A', 'C', 'G', 'T' and '-' items in `list_of_chars`.

    Returns a dict with exactly those five keys.  Any other item is ignored:
    lower-case bases, 'N', the empty string, or multi-character strings.
    """
    acgt = {'A': 0, 'C': 0, 'G': 0, 'T': 0, '-': 0}
    for char in list_of_chars:
        # dict membership rather than a substring test: '' in 'ACGT-' and
        # 'CG' in 'ACGT-' are both True and would raise KeyError below
        if char in acgt:
            acgt[char] = acgt[char] + 1
    return acgt

def gc_content(seq):
    """Percentage of G/C among the A/C/G/T characters of `seq` (either case).

    Both counts come from mcount, which presumably counts the characters of
    `seq` that belong to the given set.  If so, characters outside ACGTacgt
    are left out of numerator and denominator alike.  The division is `/`:
    with integer counts it truncates under Python 2.  A sequence with no
    A/C/G/T raises ZeroDivisionError.
    """
    strong = mcount(seq, 'GCgc')
    weak = mcount(seq, 'ATat')
    return 100 * strong / (strong + weak)

def gc_content_profile(profile):
    """Average GC percentage of the bases each IUB code in `profile` stands for.

    'N' characters are dropped first.  Each remaining code is looked up in
    IUB_expansion (through mget), and gc_content is computed on each
    expansion.  The mean of those per-position values (avg) is returned.
    """
    informative = profile.replace('N', '')
    expansions = mget(IUB_expansion, informative)
    return avg([gc_content(bases) for bases in expansions])

def is_cpg(seq):
    """Return 1 if some window of 50 positions holds at least 25 G/C, else 0.

    Each position of `seq` is flagged True when it is one of 'GCgc'.
    sum_window(flags, 50) presumably yields the per-window sums.  When it
    yields nothing (presumably a `seq` shorter than the window), the result
    is 0.
    """
    gc_flags = [base in 'GCgc' for base in seq]
    window_sums = sum_window(gc_flags, 50)
    if not window_sums:
        return 0
    return int(max(window_sums) >= 25)

def prob_seq(seq, pGC=.5):
    """Probability of drawing exactly `seq` from an i.i.d. background.

    Each 'C' or 'G' has probability pGC/2 and each 'A' or 'T' (1-pGC)/2.
    The result is the product over all positions, or 1 for an empty sequence.
    Only upper-case A, C, G, T are accepted.

    Raises AssertionError unless 0 <= pGC <= 1, and ValueError on any other
    character.
    """
    assert(0<=pGC<=1)
    # float() so that an integer pGC (0 or 1) is not floor-divided in Python 2
    p_strong = float(pGC) / 2
    p_weak = (1 - float(pGC)) / 2
    prob = 1
    for char in seq:
        if char in 'CG':
            prob = prob * p_strong
        elif char in 'AT':
            prob = prob * p_weak
        else:
            # a real exception class: raising a plain string is itself a
            # TypeError on Python 2.6+
            raise ValueError("Unexpected char: %s" % char)
    return prob

def prob_profile(profile,pGC):
    """Probability of an IUB `profile` under an i.i.d. background with GC fraction pGC.

    `profile` is upper-cased first.  For each position, the bases of its
    IUB_expansion get pGC/2 (C/G) or (1-pGC)/2 (A/T), and these are averaged
    with avg.  The result is the product of the per-position values.

    Raises AssertionError unless 0 <= pGC <= 1, KeyError for a code missing
    from IUB_expansion, and ValueError if an expansion contains anything
    other than A/C/G/T.
    """
    assert(0<=pGC<=1)
    # float() so that an integer pGC (0 or 1) is not floor-divided in Python 2
    p_strong = float(pGC) / 2
    p_weak = (1 - float(pGC)) / 2
    prob = 1
    for IUB_char in profile.upper():
        per_base = []
        for char in IUB_expansion[IUB_char]:
            if char in 'CG':
                per_base.append(p_strong)
            elif char in 'AT':
                per_base.append(p_weak)
            else:
                # a real exception class: raising a plain string is itself
                # a TypeError on Python 2.6+
                raise ValueError("Unexpected char: %s" % char)
        prob = prob * avg(per_base)
    return prob

def prob_match(seq1, seq2, pGC, debug=0):
    """Likelihood of the observed seq1/seq2 match, given background GC content pGC.

    Computed as P(same base) * P(seq1 | pGC) * P(seq2 | pGC), where
      - P(seqX | pGC) comes from prob_profile, and
      - P(same base) = sum over bases b of f1(b) * f2(b), with fX the base
        composition of each profile from prob_each_char.
    When `debug` is true, also prints the ratio of the result to the uniform
    null value .25 * .25**len(seq1) * .25**len(seq2).
    """
    p_seq1 = prob_profile(seq1, pGC)
    p_seq2 = prob_profile(seq2, pGC)

    comp1 = prob_each_char(seq1)
    comp2 = prob_each_char(seq2)
    p_same = 0
    for base, freq1 in comp1.items():
        p_same = p_same + freq1 * comp2[base]

    p = p_same * p_seq1 * p_seq2

    if debug:
        null0 = (.25)*(.25**len(seq1))*(.25**len(seq2))
        print("P(same|seq1=%s seq2=%s GC=%s)/P(same|GC=.5)=%s"%(
            seq1,seq2,pGC,p/null0))

    return p

def prob_each_char(profile):
    """Base composition implied by an IUB `profile`.

    Each position spreads a weight of 1 evenly over the bases of its
    IUB_expansion; a two-base code gives 0.5 to each.  The weights are
    normalized to sum to 1, giving the probability that a random position of
    the profile emits 'A', 'C', 'G' or 'T'.  An empty profile raises
    ZeroDivisionError.
    """
    weights = {'A': 0, 'C': 0, 'G': 0, 'T': 0}
    for IUB_char in profile:
        bases = IUB_expansion[IUB_char]
        share = 1.0 / len(bases)
        for base in bases:
            weights[base] = weights[base] + share
    total = float(sum(weights.values()))
    return dict([(base, weight / total) for base, weight in weights.items()])
        
############################################################

def cluster_significance_unsafe(pickedpos, picked, totpos, tot):
    """Upper-tail hypergeometric p-value from raw binomial coefficients.

    `picked` objects were drawn from `tot`, of which `totpos` belong to a
    class, and `pickedpos` of the drawn objects fell in the class.  Returns
    the probability of drawing `pickedpos` *or more* class members by chance:

        sum_k C(totpos, k) * C(tot-totpos, picked-k) / C(tot, picked)
        for k = pickedpos .. min(picked, totpos)

    with C = n_choose_k.  "unsafe" presumably because the raw coefficients
    can become huge; cluster_significance is a log-space variant.
    """
    totneg = tot - totpos
    denominator = n_choose_k(tot, picked)
    numerator = 0
    for k in range(pickedpos, min(picked + 1, totpos + 1)):
        numerator = numerator + (n_choose_k(totpos, k) *
                                 n_choose_k(totneg, picked - k))
    # true division: if n_choose_k returns integers, Python 2 `/` would
    # floor the probability to 0 (or 1)
    return operator.truediv(numerator, denominator)

def test_cluster_significance(num_trials = 20000):
    # Monte-Carlo sanity check of hypergeometric, cluster_significance and
    # binomial_tail.
    #
    # A population of `tot` items, `totpos` of them labelled 'A', is
    # shuffled once.  Each trial draws `picked` items with pick_n and records
    # how many 'A's came out.  Every 1000 trials, a table is printed that
    # compares observed counts with expected counts (probability * trials so
    # far).  Output goes to stdout; returns None.
    

    totpos,tot = 300,6000
    picked = 80
    
    # NOTE: `list` shadows the builtin inside this function
    list = shuffle(['A']*totpos+['B']*(tot-totpos))

    # number of 'A's in one draw -> number of trials that produced it
    picked_counts = {}
    
    # NOTE(review): nothing is ever appended to `res`, so the
    # "hypers computed" figure printed below is always 0
    res = []
    for i in range(0,num_trials):

        random_points = pick_n(list,picked)
        pickedpos = random_points.count('A')

        if not picked_counts.has_key(pickedpos): picked_counts[pickedpos] = 0
        picked_counts[pickedpos] = picked_counts[pickedpos] + 1

        # report after every 1000 trials
        if i%1000==999:
            print "%s tested %s to go, %s hypers computed"%(i+1,num_trials-i-1,len(res))

            pp(picked_counts)

            # running total of the observed counts seen so far in the walk
            cumulative = 0

            print "Tested\tValue\tCount\tHyper\t\tCumul\tCluster\t\tTails\tBinom"
            # walk the observed values via my_sort_rev (presumably from the
            # largest number of 'A's down), accumulating the upper tail
            for pickedpos,count in my_sort_rev(picked_counts.items()):

                cumulative = cumulative + count

                clust = cluster_significance(pickedpos, picked, totpos, tot)
                hyper = hypergeometric(pickedpos, picked, totpos, tot)

                # presumably P(>= pickedpos successes) for totpos trials with
                # success probability picked/tot -- confirm binomial_tail's
                # argument order
                binom = binomial_tail(picked/float(tot),pickedpos,totpos)
                
                # above the expected fraction the relevant tail is ">=",
                # below it "<="; `tail` is the observed count in that tail
                # (the "<=" form assumes the descending walk above)
                if pickedpos/float(picked) >= totpos/float(tot):
                    symbol='>='
                    tail = cumulative
                else:
                    symbol='<='
                    tail = i+1-cumulative+count

                # three column groups (exact, tail, cumulative): observed,
                # expected (probability * trials so far), observed/expected
                # as a percentage, shown only when count > 10.
                # NOTE(review): ifab evaluates both branches, so each ratio is
                # computed (and may divide by zero) even when count <= 10
                print "%s\t%s%s\t%s\t%0.1f\t%s\t%s\t%0.1f\t%s\t%s\t%0.1f\t%s"%(
                    i+1,
                    symbol,pickedpos,
                    count, hyper*(i+1), ifab(count>10,'%0.0f%%'%(100.0*count/(hyper*(i+1))),'-'),
                    tail,  clust*(i+1), ifab(count>10,'%0.0f%%'%(100.0*tail/(clust*(i+1))),'-'),
                    cumulative, binom*(i+1), ifab(count>10,'%0.0f%%'%(100.0*cumulative/(binom*(i+1))),'-'))
                    
                    #ifab(observed > 10, '='*int(40.0*observed/expected),' Too few to judge'))
                
                    
def minuslog_hypergeometric(pickedpos, picked, totpos, tot):
    """Return -log P(X = pickedpos) for a hypergeometric draw.

    X counts the class members among `picked` objects drawn from `tot`, of
    which `totpos` are in the class:
        P = C(totpos, pickedpos) * C(tot-totpos, picked-pickedpos) / C(tot, picked)
    Evaluated with log_n_choose_k.  The log is presumably natural, since
    hypergeometric undoes it with math.exp.
    """
    log_ways_pos = log_n_choose_k(totpos, pickedpos)
    log_ways_neg = log_n_choose_k(tot - totpos, picked - pickedpos)
    log_ways_all = log_n_choose_k(tot, picked)
    return -(log_ways_pos + log_ways_neg - log_ways_all)


def hypergeometric(pickedpos, picked, totpos, tot):
    """Probability of exactly `pickedpos` class members among `picked` draws.

    Exponentiates the negated result of minuslog_hypergeometric.
    """
    neg_log_p = minuslog_hypergeometric(pickedpos, picked, totpos, tot)
    return math.exp(-neg_log_p)

def cluster_significance(pickedpos, picked, totpos, tot):
    """One-tailed hypergeometric p-value for drawing `pickedpos` class members.

    `picked` objects are drawn without replacement from `tot` objects, of
    which `totpos` belong to a class; `pickedpos` of the draws fell in the
    class.  The probability of exactly k class members is

               C(totpos, k) * C(tot-totpos, picked-k)
        P(k) = --------------------------------------
                          C(tot, picked)

    For example, picking 400 of 6000 yeast genes, 300 of which fall in a
    600-gene category, gives picked=400, pickedpos=300, totpos=600 and
    tot=6000.

    The tail is chosen by comparing the picked fraction with the background
    fraction:
      - pickedpos/picked <  totpos/tot : P(X <= pickedpos)   (too few)
      - otherwise                      : P(X >= pickedpos)   (that many)
    Each term is computed with log_n_choose_k and exponentiated with
    safe_exp.  The ">=" tail is summed from its most extreme term inward.

    Raises AssertionError if pickedpos > totpos or picked-pickedpos exceeds
    tot-totpos.
    """
    pickedneg = picked - pickedpos
    totneg = tot - totpos

    assert(pickedpos <= totpos)
    assert(pickedneg <= totneg)

    picked_fraction = float(pickedpos) / picked
    background_fraction = float(totpos) / tot

    if picked_fraction < background_fraction:
        # so few: every feasible k from the smallest possible up to pickedpos
        ks = range(max(picked - totneg, 0), pickedpos + 1)
    else:
        # that many: k from the largest feasible value down to pickedpos
        ks = range(min(picked, totpos), pickedpos - 1, -1)

    log_total = log_n_choose_k(tot, picked)
    p_value = 0
    for k in ks:
        log_term = (log_n_choose_k(totpos, k) +
                    log_n_choose_k(totneg, picked - k) -
                    log_total)
        p_value = p_value + safe_exp(log_term)
    return p_value

def birthday_paradox(total, chosen):
    """Probability that some of `chosen` uniform draws from `total` values coincide.

    P(all distinct) = prod_{i=1}^{chosen-1} (1 - i/total), accumulated in log
    space with log1p.  The complement is taken with expm1, so very small
    collision probabilities keep their precision.  chosen == total is valid.
    For chosen > total a collision is certain (pigeonhole), and 1.0 is
    returned.
    """
    if chosen > total:
        return 1.0
    log_p_distinct = 0.0
    for i in range(1, chosen):
        log_p_distinct = log_p_distinct + math.log1p(-float(i) / total)
    # 0.0 - ... keeps the no-collision result at +0.0 rather than -0.0
    return 0.0 - math.expm1(log_p_distinct)
    
    
        
    
    



# category 17
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######    PAIRWISE AND MULTIPLE ALIGNMENTS      ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

############################################################
#
#  Counting Substitutions
#
############################################################

def nn_counts(seq):
    """Count each nucleotide 'A', 'C', 'G', 'T' and the gap '-' in `seq`.

    Matching is case-sensitive, and other characters are not counted.
    Returns a dict with exactly those five keys.
    """
    return dict([(symbol, seq.count(symbol)) for symbol in 'ACGT-'])

def compare_sequences(seq1,seq2):
    """Unfinished stub: prints a header and returns None.

    It currently only totals the ungapped base counts of `seq2` (via
    nn_counts) and discards the result; `seq1` is not used.
    """
    print("Comparing two sequences")
    gapless2 = seq2.replace('-', '')
    # computed but not used yet
    len2 = sum(nn_counts(gapless2).values())

def summarize_seq(seq):
    """One-line composition summary of an (optionally gapped) sequence.

    Reports the ungapped length (A+C+G+T, formatted with dec) and the GC
    percentage relative to that ungapped length.  It then reports the A, T,
    C, G and '-' percentages relative to the total of all five counts.
    """
    counts = nn_counts(seq)
    ungapped = sum(mget(counts, 'ACGT'))
    everything = sum(counts.values())
    fields = [dec(ungapped),
              safe_percentile(counts['G'] + counts['C'], ungapped)]
    for symbol in 'ATCG-':
        fields.append(safe_percentile(counts[symbol], everything))
    return "Total %s bp; GC %s; A %s, T %s, C %s, G %s, - %s" % tuple(fields)

def substitutions_pstg(seq1,seq2):
    """Summarize how the aligned sequence seq1 mutates into seq2.

    Prints a composition summary of both sequences.  Returns the 4-tuple
    from substitutions2probs(count_substitutions(seq1, seq2)):
      p -- probability that a base stays unchanged
      s -- probability of a transition (A<->G, C<->T: stays purine/pyrimidine)
      t -- probability of a transversion (purine <-> pyrimidine)
      g -- mean of the insertion and deletion rates
    """
    for label, seq in (("Seq1", seq1), ("Seq2", seq2)):
        print("%s: %s" % (label, summarize_seq(seq)))
    return substitutions2probs(count_substitutions(seq1, seq2))

def count_substitutions(seq1,seq2):
    """Tally how often each character of seq1 is aligned to each character of seq2.

    Returns a nested dict with subst[c1][c2] equal to the number of positions
    i where seq1[i] == c1 and seq2[i] == c2.  It is zero-filled for every
    pair of characters occurring in either sequence, e.g.
        {'A': {'A': 467, 'C': 101, 'G': 89,  'T': 90},
         'C': {'A': 102, 'C': 451, 'G': 111, 'T': 117}, ...}
    The sequences must have equal length (AssertionError otherwise).
    """
    seen = {}
    for seq in (seq1, seq2):
        for char in seq:
            seen[char] = None
    alphabet = my_sort(seen.keys())

    assert(len(seq1)==len(seq2))
    subst = {}
    for c1 in alphabet:
        subst[c1] = dict([(c2, 0) for c2 in alphabet])
    for c1, c2 in zip(seq1, seq2):
        subst[c1][c2] = subst[c1][c2] + 1
    return subst

def substitution_counts2symm(counts):
    """Fold a substitution-count matrix into a symmetric one.

    Input: counts as returned by count_substitutions(seq1, seq2).
    Output: symm[c1][c2] = counts[c1][c2] + counts[c2][c1] for c1 != c2, and
    symm[c][c] = counts[c][c] on the diagonal.  Substitutions are thus
    counted regardless of direction.  The values are counts, not
    percentages.
    """
    symm = {}
    for c1 in counts.keys():
        row = {}
        for c2 in counts.keys():
            if c1 == c2:
                row[c2] = counts[c1][c2]
            else:
                row[c2] = counts[c1][c2] + counts[c2][c1]
        symm[c1] = row
    return symm

def substitution_counts2perc(counts):
    """Convert a substitution-count matrix into per-source fractions.

    Input: counts as returned by count_substitutions(seq1, seq2).
    Output: a new nested dict; the input is not modified.  Each entry
    counts[src][dst] is divided by the number of times `src` was aligned to
    one of A/C/G/T.  Entries for other targets (e.g. '-') are divided by the
    same total, so only the A/C/G/T fractions of a row sum to 1.

    A row whose A/C/G/T total is zero gets all-zero fractions instead of
    raising ZeroDivisionError.  This happens for the '-' row when seq2 has
    gaps but seq1 has none facing a base.  Missing A/C/G/T targets count as
    zero.
    """
    import copy
    percs = copy.deepcopy(counts)
    for source, targets in percs.items():
        total = float(sum([targets.get(base, 0) for base in 'ACGT']))
        for key, value in targets.items():
            if total:
                targets[key] = value / total
            else:
                targets[key] = 0.0
    return percs

def substitutions2probs(counts):
    """Turn substitution counts into overall substitution probabilities.

    `counts` is the nested dict from count_substitutions(seq1, seq2); rows
    are seq1 characters.  Each source base is weighted by its frequency in
    seq1, i.e. its row total with gaps included.  The per-row fractions from
    substitution_counts2perc are passed to pp, and a one-line summary is
    printed.

    Returns (Psame, Ptransition, Ptransversion, Pgap):
      Psame         -- base unchanged
      Ptransition   -- A<->G or C<->T
      Ptransversion -- purine <-> pyrimidine
      Pgap          -- mean of the insertion rate (gap in seq1 facing a base
                       in seq2) and the deletion rate (base facing a gap),
                       each per base of seq1
    Rows or columns missing from `counts` count as zero; for example, there
    is no '-' when neither sequence has gaps.
    """
    def count(src, dst):
        return counts.get(src, {}).get(dst, 0)

    # frequency of each base in seq1 (row totals, gaps included)
    Na, Nc, Ng, Nt = [sum([count(b, t) for t in 'ACGT-']) for b in 'ACGT']
    S = float(Na + Nc + Ng + Nt)
    Pa, Pc, Pg, Pt = Na / S, Nc / S, Ng / S, Nt / S

    percs = substitution_counts2perc(counts)
    pp(percs)

    def perc(src, dst):
        return percs.get(src, {}).get(dst, 0)

    Psame = (Pa*perc('A', 'A') + Pc*perc('C', 'C') +
             Pg*perc('G', 'G') + Pt*perc('T', 'T'))

    Ptransition = (Pa*perc('A', 'G') + Pg*perc('G', 'A') +
                   Pt*perc('T', 'C') + Pc*perc('C', 'T'))

    Ptransversion = (Pa*(perc('A', 'C') + perc('A', 'T')) +
                     Pc*(perc('C', 'A') + perc('C', 'G')) +
                     Pg*(perc('G', 'C') + perc('G', 'T')) +
                     Pt*(perc('T', 'A') + perc('T', 'G')))

    # insertions: a gap in seq1 aligned to one of ACGT in seq2
    Pins = sum([count('-', b) for b in 'ACGT']) / S
    # deletions: one of ACGT in seq1 aligned to a gap in seq2
    Pdel = sum([count(b, '-') for b in 'ACGT']) / S

    # each label is paired with its own value (del -> Pdel, ins -> Pins)
    print("p=%2.2f%%, s=%2.2f%%, t=%2.2f%%, del=%2.2f%%, ins=%2.2f%%" % (
        100*Psame, 100*Ptransition, 100*Ptransversion, 100*Pdel, 100*Pins))

    return Psame, Ptransition, Ptransversion, (Pins + Pdel) / 2

############################################################
#
# Multiple Pairwise distances
#
############################################################

def names_seqs_distances(names, seqs):
    """Print and return the pairwise percent-identity table for named sequences.

    Chains pairwise_distances, squareform and rename_square, then prints the
    square with print_square, with rows and columns in `names` order.
    Returns the dict keyed by (name_i, name_j).
    """
    by_index = squareform(pairwise_distances(seqs))
    by_name = rename_square(names, by_index)
    print_square(by_name, rows=names, cols=names)
    return by_name

def rename_square(names, square):
    """Re-key a square indexed by (i, j) into one indexed by (names[i], names[j]).

    Only the first two components of each key are used.  If two keys map to
    the same name pair (duplicate names), the one visited last wins.
    """
    return dict([((names[key[0]], names[key[1]]), value)
                 for key, value in square.items()])

def pairwise_distances(seqs):
    """Percent identity for every pair of sequences in `seqs`.

    Pairs are visited in the order (0,1) (0,2) ... (1,2) (1,3) ...; use
    squareform to turn the flat list into a square.  For each pair:
      - common gap columns are removed with eliminate_common_gaps;
      - bar_sequence_identities marks the alignment, presumably with '|'
        at identical columns;
      - the score is 100 * (number of '|') divided by (aligned length minus
        the '.' characters of both sequences).
    The division is `/`, which truncates under Python 2 for integer counts.
    A pair with nothing left to compare scores 0.
    """
    scores = []
    n = len(seqs)
    for i in range(0, n):
        for j in range(i + 1, n):
            first, second = eliminate_common_gaps([seqs[i], seqs[j]])
            marks = bar_sequence_identities(first, second)
            comparable = len(first) - (first + second).count('.')
            if comparable:
                scores.append(100 * marks.count('|') / comparable)
            else:
                scores.append(0)
    return scores

def span_pairs(ii, jj):
    """Return every (i, j) with i from `ii` and j from `jj`; `ii` varies slowest.

    Flattens a nested loop:
        for i in range(0,10):
            for j in 'ABCD':
                body_of_loop
    into
        for (i,j) in span_pairs(range(0,10),'ABCD'):
            body_of_loop
    """
    return [(i, j) for i in ii for j in jj]

def span_triplets(ii, jj, kk):
    """Three-way analogue of span_pairs: every (i, j, k), `ii` varying slowest."""
    return [(i, j, k) for i in ii for j in jj for k in kk]

def squareform(pairwise,identity=None):
    """Expand a condensed list of pairwise values into a square dict.

    Equivalent to matlab's squareform.  `pairwise` lists the values for the
    pairs (0,1), (0,2), ..., (0,n-1), (1,2), ..., (n-2,n-1), which is the
    order produced by pairwise_distances.  Both (i,j) and (j,i) map to that
    value.  The diagonal (i,i) comes from `identity`:
      - None or an empty sequence: 100 for every sequence
      - a single number, including 0: that number for every sequence
      - a sequence of n values: one per sequence (AssertionError otherwise)
    Prints the number of sequences.  Raises ValueError if len(pairwise) is
    not n*(n-1)/2 for any n.

    See also: names_seqs_distances, pairwise_distances, print_square.
    """
    num_seqs = int(round((1 + math.sqrt(1 + 8 * len(pairwise))) / 2))
    if num_seqs * (num_seqs - 1) // 2 != len(pairwise):
        raise ValueError("%s values cannot fill the upper triangle of a square"
                         % len(pairwise))

    # `is None` rather than truthiness, so identity=0 is honoured
    if identity is None:
        identity = [100] * num_seqs
    elif not hasattr(identity, '__len__'):
        # a single number (int or float): use it on the whole diagonal
        identity = [identity] * num_seqs
    elif len(identity) == 0:
        # an empty sequence keeps meaning "use the default"
        identity = [100] * num_seqs

    assert(len(identity) == num_seqs)

    square = {}
    print("There are %s sequences" % num_seqs)
    m = 0
    for i in range(0, num_seqs):
        square[(i, i)] = identity[i]
        for j in range(i + 1, num_seqs):
            square[(i, j)] = pairwise[m]
            square[(j, i)] = pairwise[m]
            m = m + 1
    return square
            
def print_square(square, f=sys.stdout, disp=None,
                 rows=None, cols=None):
    """Write a dict keyed by (row, col) pairs to `f` as a text table.

    Typically fed the output of squareform or rename_square.
    - `rows` and `cols` default to the sorted distinct first and second key
      components.
    - If `disp` is given, it is applied to every value (on a copy made with
      copy_dic) before printing.
    - Cells missing from `square` are left blank.
    - Every cell is right-aligned to the width of the longest value or label.

    See also: names_seqs_distances, pairwise_distances, squareform.
    """
    if disp:
        square = copy_dic(square)
        for key, value in square.items():
            square[key] = disp(value)

    if not rows:
        rows = my_sort(unique([key[0] for key in square.keys()]))
    if not cols:
        cols = my_sort(unique([key[1] for key in square.keys()]))

    value_width = max([len('%s' % v) for v in square.values()])
    label_width = max([len('%s' % v) for v in flatten([rows, cols])])
    width = max(value_width, label_width)
    cell = '%' + str(width) + 's'
    rule = '-' * width

    # header row: an empty corner cell followed by the column labels
    f.write(cell % '')
    for col in cols:
        f.write('|' + cell % col)
    f.write('\n')

    for row in rows:
        # horizontal rule above each row
        f.write(rule)
        for col in cols:
            f.write('+' + rule)
        f.write('\n')
        # row label, then one cell per column (blank if the pair is absent)
        f.write(cell % row)
        for col in cols:
            if (row, col) in square:
                f.write('|' + cell % square[(row, col)])
            else:
                f.write('|' + cell % '')
        f.write('\n')

############################################################
#
#  Transforming many pairwise alignments into a multiple alignment
#
############################################################

def test_pairwise2multiple_simulations():
    """Smoke test for pairwise2multiple on simulated data; prints the result.

    Builds 10 clustal alignments.  Each aligns one fixed random 'human'
    sequence against 0-9 random sequences with random species names.  The
    alignments are merged on 'human' with pairwise2multiple, and the
    multiple alignment is printed.  Needs the project-local `simulation` and
    `clustal` modules.
    """
    import simulation, clustal
    name1 = 'human'
    random_names = ['plant','bee','chicken','pig','rooster','bear','hen','fruit','banana']
    names_seqs = []
    seqlength = 40
    human_seq = simulation.random_seq(seqlength)
    for i in range(0,10):

        names, seqs = [name1], [human_seq]
        for sp in range(0,random.choice(range(0,10))):
            name2 = random.choice(random_names)
            seq2 = simulation.random_seq(seqlength)

            names.append(name2)
            seqs.append(seq2)

        names,seqs = clustal.quick_clustal(names,seqs)

        # list.append takes exactly one argument: store the alignment as a
        # (names, seqs) tuple, the shape pairwise2multiple unpacks
        names_seqs.append((names,seqs))

    pp(names_seqs)

    names,seqs = pairwise2multiple(names_seqs,'human')
    print_wrap(seqs, 120, names)

def test_pairwise2multiple():
    """Merge three pairwise alignments that all contain 'human'; print the result."""
    alignments = [
        (['human', 'mouse'],  ['A-GGG-T', 'ACC-GTT']),
        (['human', 'baboon'], ['AGG-GT', 'A-GGGT']),
        (['human', 'rat'],    ['--AG--GGT', 'AGATCAGG-']),
    ]
    names, seqs = pairwise2multiple(alignments, 'human')
    print_wrap(seqs, 120, names)

def test_pairwise2multiple2():
    """Two-stage merge: join on 'human' first, then add mouse/rat on 'mouse'; print it."""
    human_baboon = (['human', 'baboon'], ['AGG-GT', 'A-GGGT'])
    human_mouse = (['human', 'mouse'], ['A-GGG-T', 'ACC-GTT'])
    mouse_rat = (['mouse', 'rat'], ['A--CC-GTT', 'AGATCAGG-'])

    merged = pairwise2multiple([human_baboon, human_mouse], 'human')
    names, seqs = pairwise2multiple([merged, mouse_rat], 'mouse')
    print_wrap(seqs, 120, names)

def smart_pairwise2multiple():
    """Placeholder, not implemented: takes no arguments and returns None."""
    return None


def pairwise2multiple(names_seqs_list, common_key):
    """Merge several alignments that share one sequence into a multiple alignment.

    `names_seqs_list` is a list of (names, seqs) alignments.  Each one must
    contain the name `common_key` exactly once, and the ungapped length of
    that sequence (gapless_len) must agree across alignments; otherwise
    AssertionError is raised.  Gaps are propagated between alignments by
    add_common_gaps, and the result is flattened with flatten_nsis.

    Returns (names, seqs): the common sequence first, then every other
    sequence in input order.
    """
    common_len = None
    nsis = []
    for names, seqs in names_seqs_list:
        assert(names.count(common_key)==1)
        i = names.index(common_key)
        # `is None` so that a zero-length common sequence is still checked
        if common_len is None:
            common_len = gapless_len(seqs[i])
        else:
            assert(common_len == gapless_len(seqs[i]))
        # list.append takes one argument: store the (names, seqs, index)
        # triple that add_common_gaps and flatten_nsis unpack
        nsis.append((names, seqs, i))
    add_common_gaps(nsis)
    names, seqs = flatten_nsis(nsis)
    return names, seqs

def flatten_nsis(nsis):
    """Concatenate gap-reconciled (names, seqs, index) alignments into one.

    The shared sequence (names[index], seqs[index]) of the first entry is
    emitted once, first.  It is followed by every non-shared sequence of
    every entry, in order.  Each entry's shared name and sequence must equal
    the first entry's (AssertionError otherwise).  Returns
    (all_names, all_seqs).
    """
    first_names, first_seqs, first_i = nsis[0]
    anchor_name = first_names[first_i]
    anchor_seq = first_seqs[first_i]
    out_names, out_seqs = [anchor_name], [anchor_seq]

    for names, seqs, i in nsis:
        assert(names[i] == anchor_name and seqs[i] == anchor_seq)
        for j, name in enumerate(names):
            if j == i:
                continue
            out_names.append(name)
            out_seqs.append(seqs[j])
    return out_names, out_seqs

def add_common_gaps(nsis):
    """Give the shared sequence identical gaps in every alignment.

    `nsis` is a list of (names, seqs, i) triples, where seqs[i] is the
    sequence common to all alignments: the same letters, gapped differently.

    The list is walked left to right, and each adjacent pair of shared
    sequences is reconciled with compute_gap_conversion:
      - the gaps the next alignment lacks (b2a) go into each of its sequences;
      - the gaps the alignments seen so far lack (a2b) go into each of theirs.

    `nsis` is modified in place and also returned.
    """
    for a in range(0, len(nsis) - 1):
        names, seqs, i = nsis[a]
        next_names, next_seqs, next_i = nsis[a + 1]
        a2b, b2a = compute_gap_conversion(seqs[i], next_seqs[next_i])

        # the next alignment receives the gaps it was missing
        nsis[a + 1] = (next_names,
                       [apply_gap_conversion(seq, b2a) for seq in next_seqs],
                       next_i)

        # every alignment processed so far receives the gaps it was missing
        for c in range(0, a + 1):
            done_names, done_seqs, done_i = nsis[c]
            nsis[c] = (done_names,
                       [apply_gap_conversion(seq, a2b) for seq in done_seqs],
                       done_i)
    return nsis

#def pairwise2multiple(nsis):
#    # nsis = [(names, sequences, index), (names, sequences, index)...]
#    # for every element nsi == (names, seqs, i) of nsis:
#    #  name[i] and seqs[i] are the sequences to reconcile
#
#    conversion = {}
#    
#    for a in range(0,len(nsis)):
#
#        
#        names,seqs,i = nsis[a]
#        seqa = seqs[i]
#
#        for b in range(0,len(nsis)):
#            if a>=b: continue
#
#            names2,seqs2,i2 = nsis[b]
#            seqb = seqs2[i2]
#
#            a2b,b2a = compute_gap_conversion(seqa,seqb)
#            
#            conversion[(a,b)] = a2b
#            conversion[(b,a)] = b2a
#
#    #pp(conversion,1)
#    nsis2 = nsis[:]
#
#    for a,b in conversion.keys():
#
#        a2b = conversion[(a,b)]
#        print 'Applying %s to %s'%(a2b,a)
#        
#        nsis2[a] = (nsis2[a][0],
#                    map(lambda seq,a2b=a2b: apply_gap_conversion(seq,a2b),
#                        nsis2[a][1]),
#                    nsis2[a][2])
#        
#    return nsis2

#def pairwise2multiple(nsis):
#    # nsis = [(names, sequences, index), (names, sequences, index)...]
#    # for every element nsi == (names, seqs, i) of nsis:
#    #  name[i] and seqs[i] are the sequences to reconcile
#
#    for a in range(0,len(nsis)):
#
#        names,seqs,i = nsis[a]
#        seqa = seqs[i]
#
#        for b in range(0,len(nsis)):
#            if a>=b: continue
#
#            names2,seqs2,i2 = nsis[b]
#            seqb = seqs2[i2]
#
#            a2b,b2a = compute_gap_conversion(seqa,seqb)
#            
#            nsis[a] = (names,
#                       map(lambda seq,a2b=a2b: apply_gap_conversion(seq,a2b),
#                           seqs), 
#                       i)
#
#            nsis[b] = (names2,
#                       map(lambda seq,b2a=b2a: apply_gap_conversion(seq,b2a),
#                           seqs2), 
#                       i2)

def compute_gap_conversion(seqa, seqb, gap='-'):

    # seqa and seqb are two gapped versions of the same sequence: they are
    # equal once every `gap` character is removed.  There is then a common
    # sequence that can be obtained from either one by ADDING the minimum
    # number of gaps; gaps are never removed.
    #
    # Returns (a2b, b2a): lists of gap-insertion positions in that common
    # sequence, such that
    #     apply_gap_conversion(seqa, a2b) == apply_gap_conversion(seqb, b2a)
    # (checked by the assert at the end).
    #
    # Raises AssertionError if the ungapped sequences differ, or if two
    # aligned non-gap characters disagree.
    #
    # Useful for transforming pairwise alignments into multiple alignments.

    # NOTE(review): the diagnostic below counts '-' rather than `gap`
    if not string.replace(seqa,gap,'') == string.replace(seqb,gap,''):
        print "Inputs of length: %s and %s"%(
            len(seqa)-string.count(seqa,'-'),
            len(seqb)-string.count(seqb,'-'))

    assert(string.replace(seqa,gap,'') == string.replace(seqb,gap,''))

    a2b,b2a = [],[]
    # i walks seqa, j walks seqb
    i, j = 0,0
    
    while i<len(seqa) and j<len(seqb):

        #sys.stdout.write("Now observing a[%s] == %s and b[%s] == %s :: "%(
        #    i,seqa[i],j,seqb[j]))
        
        gapa = seqa[i] == gap
        gapb = seqb[j] == gap

        if gapa and gapb:
            # both gapped: the column already agrees
            #print "Both are gap"
            pass
        elif gapa and not gapb:
            # seqb needs a gap here.  Its position in the common sequence is
            # i plus the gaps inserted into seqa so far.
            #print "Gap in A, not in B"
            b2a.append(i+len(a2b))
            # hold j still: the shared increment below cancels this
            j = j-1
        elif not gapa and gapb:
            # symmetric case: seqa needs a gap at j plus the gaps inserted
            # into seqb so far
            #print "Gap in B, not in A"
            a2b.append(j+len(b2a))
            # hold i still
            i = i-1
        else:
            # both letters: they must be the same character
            assert(seqa[i]==seqb[j])
            #print "None is gap"

            pass

        i = i+1
        j = j+1

    # add the final gaps: leftover tail columns of one sequence need gaps
    # in the other
    while i < len(seqa):
        b2a.append(i+len(a2b))
        i = i+1
    while j < len(seqb):
        a2b.append(j+len(b2a))
        j = j+1

    #print "A2B(A) = %s\nB2A(B) = %s"%(apply_gap_conversion(seqa,a2b), apply_gap_conversion(seqb,b2a))

    # both conversions must produce the same common sequence
    assert(apply_gap_conversion(seqa,a2b) == apply_gap_conversion(seqb,b2a))

    return a2b,b2a

def apply_gap_conversion_orig(a, a2b, gap='-'):
    """Reference implementation of apply_gap_conversion (quadratic).

    Inserts `gap` into a copy of sequence `a` at every position in `a2b`,
    processing positions in ascending order so each one refers to the
    coordinates of the final result.  Positions past the current end are
    appended (list.insert semantics).

    The caller's `a2b` list is not modified.  Returns a string.
    """
    chars = list(a)
    for pos in sorted(a2b):
        chars.insert(pos, gap)
    return ''.join(chars)

def apply_gap_conversion(a, a2b, gap='-'):
    """Insert gap characters into sequence `a` at the positions in `a2b`.

    a   -- the sequence (string) to expand.
    a2b -- positions, in the coordinates of the RESULT, at which a gap
           must appear (as produced by compute_gap_conversion).  Order
           does not matter and the caller's list is not modified.
    gap -- the character to insert (default '-').

    Returns a string of length len(a) + len(a2b).  Gives the same result
    as apply_gap_conversion_orig (inserting gaps one at a time in ascending
    order), but builds it in a single linear merge pass.  Positions beyond
    the end of the result are treated like list.insert does: the gap is
    appended.
    """
    positions = sorted(a2b)
    lena, npos = len(a), len(positions)
    newstring = [None] * (lena + npos)
    i = j = 0
    for k in range(lena + npos):
        # emit a gap when the next requested position is reached, or when
        # `a` is exhausted (out-of-range positions end up at the tail)
        if j < npos and (i >= lena or k >= positions[j]):
            newstring[k] = gap
            j += 1
        else:
            newstring[k] = a[i]
            i += 1
    return ''.join(newstring)

def add_gaps_to_ungapped(a, b):
    """Return ungapped sequence `a` with gaps inserted so it lines up with `b`.

    a -- sequence without any '-' characters.
    b -- the same sequence with '-' gaps.

    Because `a` has no gaps, `b` never needs extra gaps to match it, so
    the result should spell out the same string as `b`.

    Raises ValueError if `a` already contains a gap.
    """
    if '-' in a:
        # a bare string raise is itself a TypeError on Python >= 2.6
        raise ValueError("First argument seq should not contain any gaps")
    a2b, b2a = compute_gap_conversion(a, b)
    assert b2a == []
    return apply_gap_conversion(a, a2b)

############################################################
#
#  The opposite:  removing sequences from an alignment
#
############################################################

def names_seqs_subtract_name(names, seqs, subnames):
    """Remove the named sequences from an alignment and drop all-gap columns.

    names, seqs -- parallel lists; both are modified IN PLACE (the
                   selected entries are deleted).
    subnames    -- names to remove.  Names not present in `names` are
                   ignored; listing a name more than once removes it once.

    Returns (names, new_seqs), where `names` is the same (mutated) list
    and new_seqs is eliminate_common_gaps() applied to the remaining seqs.
    """
    # collect each index only once; deleting the same index twice would
    # silently remove a neighbouring, unrelated sequence
    doomed = {}
    for subname in subnames:
        if subname in names:
            doomed[names.index(subname)] = 1
    # delete from the back so the remaining indices stay valid
    for i in sorted(doomed, reverse=True):
        del names[i]
        del seqs[i]
    return names, eliminate_common_gaps(seqs)

def eliminate_common_gaps(seqs, chars_to_eliminate='-. '):
    """Remove the columns in which every aligned sequence has a gap.

    seqs               -- list of aligned sequences (strings or lists of
                          characters), all of the same length.
    chars_to_eliminate -- characters that count as gaps in a column.

    Returns a new list of strings; the input sequences are not modified.
    An empty `seqs` gives an empty list.

    Raises AssertionError if the sequences differ in length.
    """
    if not seqs:
        return []

    if not all_same(map(len, seqs)):
        pp("Lengths are: "+display_list(map(len,seqs)))
        raise AssertionError("Sequences must be of the same length")

    # a column survives as soon as one sequence has a non-gap in it
    keep = []
    for i in range(len(seqs[0])):
        for seq in seqs:
            if seq[i] not in chars_to_eliminate:
                keep.append(i)
                break

    # build each result in one pass instead of deleting column by column
    return [''.join([seq[i] for i in keep]) for seq in seqs]

def eliminate_ref_gaps(seqs, ref, is_gap=is_gap):
    """Keep only the alignment columns selected by the reference sequence.

    seqs   -- list of aligned sequences of equal length.
    ref    -- index into `seqs` of the reference sequence.
    is_gap -- predicate applied to each character of seqs[ref]; a column
              is kept where the predicate returned 0 / false.

    Returns a new list of strings, one per input sequence.
    """
    assert(all_same(map(len, seqs)))
    # NOTE(review): assumes list_find_all returns the indices of items equal
    # to 0 and mget picks those positions out of a sequence - confirm
    kept_columns = list_find_all(map(is_gap, seqs[ref]), 0)
    return [''.join(mget(seq, kept_columns)) for seq in seqs]

def induced_alignment(names, seqs, subnames):
    """Restrict an alignment to the sequences named in `subnames`.

    The result follows the order of `subnames`.  Columns that become
    all-gap after the restriction are removed with eliminate_common_gaps.
    Raises ValueError (from list.index) if a requested name is absent.
    """
    picked = [names.index(subname) for subname in subnames]
    return ([names[k] for k in picked],
            eliminate_common_gaps([seqs[k] for k in picked]))

# def eliminate_common_gaps(list_of_seqs):
#    # an older implementation of the procedure above
#    new_list = []*len(list_of_seqs)
#    for position in unpack(list_of_seqs):
#        pos = string.join(position,'')
#        if string.count(pos,'-')==len(pos):
#            pass
#        else:
#            new_list.append(position)
#    return map(lambda s: string.join(s,''), unpack(new_list))

############################################################
#
#  Counting similarities (identical positions) between species pairs
#
############################################################

def count_similarities(names, seqs):
    """Record, for every unordered pair of sequences, where they are identical.

    names, seqs -- parallel lists of names and aligned sequences.

    Returns a dict mapping (names[i], names[j]) for every i < j to the
    list of positions at which star_sequence_identities(seqs[i], seqs[j])
    has a '*' (presumably the columns where the two sequences agree).
    """
    similarities = {}
    count = len(names)
    # explicit i < j loops: the pair range for j depends on i, so it
    # cannot be built before the loop over i starts
    for i in range(count):
        for j in range(i + 1, count):
            stars = star_sequence_identities(seqs[i], seqs[j])
            similarities[(names[i], names[j])] = list_find_all(stars, '*')
    return similarities

# category 18
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######    Visualizing multiple alignments       ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

############################################################
#
#  Visualize a multiple alignment
#
############################################################


def test_visualize_alignment(f=None):
    """Draw a tiny three-sequence alignment with visualize_alignment.

    f -- writable file object for the generated MATLAB commands; it must
         have a .name attribute.  Defaults to sys.stdout.
    """
    #names, seqs = binsave.load('/seq/manoli/hox/glass/hoxAglass_HB_MR_HM')
    if f is None:
        f = sys.stdout

    names = 'seq1','seq2','seq3'
    seqs = ['ACGTAGGATATTATGAGGAGATATTATAGAGAGAGTATGAGAT',
            'ACGTAGGATATTATGAGGAGATATTATAGAGAGAGTATGAGAT',
            'ACGTAGGATATTATGAGGAGATATTATAGAGAGAGTATGAGAT']

    # visualize_alignment takes the output file as its first argument
    visualize_alignment(f, names, seqs, [])

def visualize_alignment(f, names, seqs, all_features):
    """Write MATLAB commands to `f` that draw an overview of an alignment.

    The alignment is cut at a set of alignment columns ("ticks"); each
    region between consecutive ticks is filled with a colour encoding the
    percentage of '*' columns in the consensus over that region.  Each
    sequence is also drawn as a horizontal line of its ungapped length,
    labelled with its name and length, and a colour legend is added.

    f            -- writable file object with a .name attribute.
    names, seqs  -- parallel lists of names and aligned sequences.  If
                    names[-1] == 'consensus', seqs[-1] is taken to be the
                    "*   *** *" consensus line and is not drawn; otherwise
                    quick_clustal_consensus(seqs) supplies the consensus.
    all_features -- empty/None, in which case ticks are placed roughly
                    every 1/80th of the alignment; or a list parallel to
                    `seqs` whose entries are lists of (start, end, name)
                    features in that sequence's ungapped coordinates.
                    Their endpoints become the ticks and the features are
                    drawn as labelled arrows.
    """

    if names[-1] == 'consensus':
        consensus = seqs[-1]
        names = names[:-1]
        seqs  = seqs[:-1]
    else:
        consensus = quick_clustal_consensus(seqs)

    import matlab

    print("Filling in %s" % f.name)

    # NOTE(review): presumably one vertical offset per sequence, 5000 apart
    y_offsets = map_repeat(len(seqs), operator.add, -5000)

    # gap-less length of every sequence
    lens = [len(seq) - seq.count('-') for seq in seqs]

    print("Names are %s. Lengths are %s" % (', '.join(names), lens))

    # preprocessing for translation between sequence and alignment coordinates
    gap_offsets = [gap_translation_offsets(seq) for seq in seqs]

    if not all_features:
        # no features: a tick roughly every 1/80th of the alignment.
        # max(1, ...) keeps range() from getting a zero step when the
        # alignment is shorter than 80 columns.
        increment = max(1, len(seqs[0]) / 80)
        ali_features = range(0, len(seqs[0]) - 1, increment)
    else:
        # use the endpoints of every in-range feature as ticks
        ali_features = []
        for features, seq, gap_offset in zip(all_features, seqs, gap_offsets):
            features_coords = []
            for start, end, name in features or []:
                if not max(start, end) < gap_offset[-1]: continue
                if not min(start, end) >= 0: continue
                features_coords.extend([start, end])
            # translate with the sequence these features belong to
            ali_features.extend(coords_seq2ali(seq, features_coords, gap_offset))
        ali_features = unique(ali_features)
        ali_features.sort()

    # NOTE(review): assumed to return one list per tick, holding that
    # tick's coordinate in every sequence - confirm
    print("ali2seq using offsets")
    seq_indices = coords_ali2allseqs(seqs, ali_features, gap_offsets)

    # plot lines spanning all sequences
    f.write('hold on;\n')

    # percent identity ('*' columns of the consensus) between consecutive ticks
    print("creating all the subsets")
    idpercs = []
    for start, end in zip(ali_features[:-1], ali_features[1:]):
        subconsensus = consensus[start:end]
        if not subconsensus:
            print("Zero length consensus for subseqs")
            idpercs.append(0)
        else:
            idpercs.append(100 * subconsensus.count('*') / len(subconsensus))

    print("Idpercs: %s" % idpercs)
    quick_histogram(idpercs)

    boundaries = [85, 80, 75, 65, 50]   # alternative: [25, 20, 15, 10, 5]
    boundary_colors = 'rmbcg'

    def idperc2color(idperc, boundaries=boundaries, boundary_colors=boundary_colors):
        """Colour of the first boundary exceeded; 'w' for 0%, else 'y'."""
        for cutoff, color in zip(boundaries, boundary_colors):
            if idperc > cutoff: return color
        if idperc == 0: return 'w'
        return 'y'

    print("Plotting")
    idperc_colors = []
    fill_regions = 1   # 0: draw one averaged line per region instead
    # now plot regions crosscutting the sequences, using appropriate colors
    for i1, i2, idperc in zip(seq_indices[:-1], seq_indices[1:], idpercs):

        idperc_color = idperc2color(idperc)
        idperc_colors.append(idperc_color)

        if fill_regions:
            # polygon traced along the left tick's points, then back
            # along the right tick's points in reverse order
            list1 = i1[:]
            list2 = reverse_list(i2)
            offsets1 = y_offsets[:]
            offsets2 = reverse_list(y_offsets)

            f.write("h = fill(%s,%s,'%s-');\n" % (
                flatten([list1, list2]),
                flatten([offsets1, offsets2]),
                idperc_color
                ))
            f.write("set(h,'EdgeColor','%s');\n" % idperc_color)
        else:
            f.write("plot(%s,%s,'%s-');\n" % (
                map(avg, zip(i1, i2)),
                y_offsets,
                idperc_color
                ))

    # the tick lines themselves
    seq_features = coords_ali2allseqs(seqs, ali_features, gap_offsets)
    for features in seq_features:
        f.write("plot(%s,%s,'k-');\n" % (features, y_offsets))

    # write the names of sequences and their lengths left and right
    for name, length, y_offset in zip(names, lens, y_offsets):
        print(name)
        matlab.text(f, name, 0, y_offset, 'middle', 'right')
        f.write("plot([%s,%s],[%s,%s],'k-');\n" % (
            0, length,
            y_offset, y_offset))
        matlab.text(f, '%sbp' % length, length, y_offset, 'middle', 'left')

    # colour legend below the last sequence
    for y2, (boundary, boundary_color) in enumerate(zip(boundaries, boundary_colors)):
        ymin = y_offsets[-1] - (y2 + 1) * 1000
        ymax = ymin + 500
        xmin = 0
        xmax = 1000
        matlab.square(f, xmin, xmax, ymin, ymax, boundary_color, 'k', 1)
        matlab.text(f, '%s features at conservation > %s%%' % (
            idperc_colors.count(boundary_color), boundary),
                    xmax + 200, (ymin + ymax) / 2, 'middle', 'left', small=1)

    # plot the per-sequence features as originally designated.  They are in
    # sequence coordinates, like the sequence lines above, so no coordinate
    # translation is needed.
    arrow_width = abs(y_offsets[0] - y_offsets[-1]) / 20
    for features, y_offset in zip(all_features or [], y_offsets):
        for start, end, name in features or []:
            matlab.arrow(f, start, end, y_offset, y_offset, arrow_width, 400)
            h = matlab.text(f, name, (start + end) / 2, y_offset + arrow_width,
                            'middle', 'left', rotation=90)
            matlab.shrink_text(f, h)

    matlab.axis_y(f,
                  y_offsets[-1] - abs(y_offsets[0] - y_offsets[-1]),
                  y_offsets[0] + abs(y_offsets[0] - y_offsets[-1]))
    f.write("axis('off'); axis('equal'); \n")

# category 19
############################################################
############################################################
############################################################
#######                                          ###########
#######                                          ###########
#######    PHYLOGENETIC TREES AND ANCESTORS      ###########
#######                                          ###########
#######                                          ###########
############################################################
############################################################
############################################################

############################################################
#
#  PHYLOGENETIC TREES
#
############################################################

def compute_tree(seqs):
    """Placeholder for tree construction; not implemented, returns None."""
    return None

def test_number_of_mutations():
    """Print the expected answer for a 4-species toy tree, then compute it.

    Returns whatever number_of_mutations produces for the toy input.
    """
    tree = [('HUMAN', 'BABOON'), ('MOUSE', 'RAT')]
    seqs = {'HUMAN':  'AAAAACAAA',
            'BABOON': 'AAATTAACC',
            'MOUSE':  'AATTAACCG',
            'RAT':    'ATTTTTTTT'}

    expected = {'seq':  'AAWTWAHCN',
                'mut':  '011122223'}
    print('Answer: ' + repr(expected))

    return number_of_mutations(seqs, tree)

def number_of_mutations(seqs, tree):
    """Compute the ancestor and per-position mutation counts over `tree`.

    seqs -- dict from leaf name to sequence.
    tree -- nested lists/tuples of leaf names.
    Returns the {'seq': ..., 'mut': ...} dict built for the root.
    """
    return recurse_number_of_mutations(build_seq_tree(seqs, tree))

def build_seq_tree(seqs, tree):
    """Replace every leaf name in `tree` with a per-sequence record.

    seqs -- dict from leaf name to sequence.
    tree -- a leaf name (string) or a sequence of subtrees.

    A leaf becomes {'seq': seqs[name], 'mut': [0, 0, ...]} (one zero per
    position); an inner node becomes a list of its converted children.
    """
    if isinstance(tree, str):
        return {'seq': seqs[tree],
                'mut': [0] * len(seqs[tree])}
    return [build_seq_tree(seqs, branch) for branch in tree]

def recurse_number_of_mutations(tree):
    # Collapse a tree built by build_seq_tree into a single
    # {'seq': ancestor, 'mut': per-position mutation counts} record.
    #
    # Branches that are not yet dicts (i.e. unresolved subtrees) are
    # resolved recursively first.  The mutation counts of all children are
    # summed position by position and added to the mutations that
    # find_ancestor reports for inferring this node's ancestor.
    # NOTE(review): assumes unpack() transposes a list of equal-length
    # sequences into per-position columns - confirm.
    seqs = []
    muts = []
    for tree_branch in tree:
        # if the branch hasn't been followed, follow it
        if type(tree_branch) != type({}):
            tree_branch = recurse_number_of_mutations(tree_branch)

        # then add to the flat list both seq and # of mutations
        seqs.append(tree_branch['seq'])
        muts.append(tree_branch['mut'])

    # add up all the mutations that were needed to get here
    mutations = map(sum, unpack(muts))

    # compute the ancestor (+ how many mutations needed for that ancestor)
    ancestor,new_mutations = find_ancestor(seqs)

    #print "New mutations: "+`new_mutations`
    
    total_mutations = map(sum, unpack([mutations, new_mutations]))
    
    return {'seq': ancestor,
            'mut': total_mutations}

def find_ancestor(seqs):
    """Infer an ancestral sequence column by column.

    find_ancestral_char is applied to every alignment column; returns
    (ancestor string, list of per-column mutation counts).
    """
    columns = unpack(seqs)
    per_column = map(find_ancestral_char, columns)
    ancestor_chars, mutations = unpack(per_column)
    return ''.join(ancestor_chars), mutations

def find_ancestral_char(chars):
    # Infer the ancestral character of one alignment column.
    #
    # chars: the characters of the column (IUB codes, looked up in
    #        IUB_expansion).
    # Returns (code, mutations): when the expansions of all characters
    # share at least one base, the IUB code for that shared set and 0
    # mutations; otherwise generate_profile(chars) and 1 mutation.

    # intersect the base sets that every character can stand for
    common_chars = IUB_expansion[chars[0]]
    for char in chars[1:]:
        common_chars = set_intersect(common_chars,
                                     IUB_expansion[char])

    # a non-empty intersection needs no mutation; otherwise return the
    # union of possibilities (via generate_profile), and 1 mutation
    if common_chars:
        return get_IUB_code(string.join(common_chars,'')),0

    else:
        return generate_profile(chars),1
############################################################

def species_divergence(names, seqs):
    """Tally, column by column, how the first three sequences agree.

    NOTE: this three-species version is shadowed by the general
    species_divergence defined right after it in this file, so calls to
    species_divergence never reach it.

    Slots of the returned list of 6 counts:
      0 -- sequences 0, 1 and 2 all agree
      1 -- 0 and 1 agree (2 differs)
      2 -- 0 and 2 agree (1 differs)
      3 -- 1 and 2 agree (0 differs)
      4 -- all three differ
      5 -- the column contains a '-'
    `names` is unused.  The running tally is printed every 1000 columns.
    NOTE(review): assumes cget(seqs, i) returns the i-th character of every
    sequence - confirm.
    """
    diff_types = [0] * 6
    for i in range(0, len(seqs[0])):
        chars = cget(seqs, i)

        # `kind` rather than `type`, which would shadow the builtin
        if '-' in chars: kind = 5
        elif chars[0] == chars[1] == chars[2]: kind = 0
        elif chars[0] == chars[1]: kind = 1
        elif chars[0] == chars[2]: kind = 2
        elif chars[1] == chars[2]: kind = 3
        else: kind = 4

        diff_types[kind] = diff_types[kind] + 1

        if i % 1000 == 0: print(diff_types)

    return diff_types

def species_divergence(names, seqs):
    """Judge how often each species is the lone outlier in a column.

    A species is counted when it differs while all the others agree, a
    measure of its out-groupness conditioned on the others.

    Returns a list of len(names)+3 counts:
      [0]                -- all sequences agree
      [sp+1]             -- species sp is the only one that disagrees
      [len(names)+1]     -- the column has no single outlier
      [-1]               -- the column contains a '-'
    The running tally is printed every 1000 columns and a summary at the end.
    NOTE(review): assumes cget(seqs, pos) returns the pos-th character of
    every sequence and get_all_but(chars, sp) drops entry sp - confirm.
    """
    no_outlier = len(names) + 1
    diff_types = [0] * (len(names) + 3)
    for position in range(0, len(seqs[0])):
        chars = cget(seqs, position)

        # `kind` rather than `type`, which would shadow the builtin
        kind = no_outlier
        if '-' in chars: kind = -1
        elif all_same(chars): kind = 0
        else:
            for sp in range(0, len(names)):
                the_rest = get_all_but(chars, sp)
                if all_same(the_rest):
                    kind = sp + 1
                    break

        diff_types[kind] = diff_types[kind] + 1

        if position % 1000 == 0: print(diff_types)

    print("All agree: %s" % diff_types[0])
    for sp in range(0, len(names)):
        print("%s disagrees: %s" % (names[sp], diff_types[sp + 1]))
    # this bucket was tallied but not reported before
    print("No single outlier: %s" % diff_types[no_outlier])
    print("Some gap: %s" % diff_types[-1])

    return diff_types

def test_species_divergence():
    """Print a small four-species alignment and its divergence tally."""
    rows = [('human',  'AGAAAAAGGAAAAAA'),
            ('baboon', 'AAATAAAAAAAAAAA'),
            ('mouse',  'AAAAACAAACAAAAC'),
            ('rat',    'AAAA-AAAAAAAAAC')]
    names = [label for label, residues in rows]
    seqs = [residues for label, residues in rows]

    print_wrap(seqs, 120, names)
    species_divergence(names, seqs)

def seq2windows(seq, k=1000):
    """Split `seq` into consecutive, non-overlapping chunks of length k.

    The final chunk holds whatever remains and may be shorter than k.
    """
    return [seq[offset:offset + k] for offset in range(0, len(seq), k)]

def seqs2windows(seqs, k=1000):
    """Cut an alignment into consecutive column windows of width k.

    Each window is get_subseqs(seqs, start, start+k); the last window may
    be narrower.  Window boundaries follow the length of seqs[0].
    """
    return [get_subseqs(seqs, offset, offset + k)
            for offset in range(0, len(seqs[0]), k)]

def seqs2events(seqs):
    """Recode an alignment so each column shows frequency ranks.

    In every column, the most prevalent character becomes '1', the second
    most prevalent '2', and so on.  When two characters are equally
    frequent, the one appearing first (top to bottom) gets the lower rank.

    seqs -- list of aligned sequences of equal length.  If the last entry
            contains more than one '*', it is taken to be a clustal
            consensus line and excluded from the result.

    Returns a list of strings, one per (non-consensus) input sequence.
    Raises AssertionError if the sequences differ in length.
    """
    if not seqs:
        return []
    if [s for s in seqs if len(s) != len(seqs[0])]:
        raise AssertionError("Sequences must be of the same length")

    # look for stars in the last sequence itself (counting '*' over the
    # list would only count entries that are exactly '*')
    if seqs[-1].count('*') > 1:
        seqs = seqs[:-1]

    newcols = []
    for col in zip(*seqs):
        counts = {}
        first_seen = {}
        for idx, char in enumerate(col):
            counts[char] = counts.get(char, 0) + 1
            if char not in first_seen:
                first_seen[char] = idx
        # most frequent first; ties broken by first appearance
        order = sorted(first_seen, key=lambda c: (-counts[c], first_seen[c]))
        rank = {}
        for pos, char in enumerate(order):
            rank[char] = str(pos + 1)
        newcols.append([rank[c] for c in col])

    # transpose back; one output row per sequence even for empty columns
    return [''.join([col[row] for col in newcols]) for row in range(len(seqs))]

############################################################

#def dict2class(dictionary):
#    this = hello()
#    class hello:
#       for key, value in dictionary.items():
#            eval('this.%s = %s'%(key,value))
    
            
