#!/usr/bin/python

def decode(Stringlist):
    """decode(Stringlist)

    Find the most probable HMM state path that accepts Stringlist, using
    the Viterbi algorithm.

    Stringlist is a sequence of observations (in the main program, the
    characters of one input line); each element must be a key of B.

    The HMM is read from module-level globals (set by the main program):
      states     -- iterable of states; they index lists directly, so they
                    must be the integers 0 .. len(states)-1
      s0         -- the state the path must be in at t=0
      A          -- A[from][to], log prob of the transition from -> to
      B          -- B[o][from][to], log prob of emitting observation o on
                    the arc from -> to (emissions are attached to arcs)
      nice_names -- display name for each state, used for nice_path_string
    Scores are summed, so A and B must already hold log probabilities
    (log 0 represented as float('-inf')).

    Returns a 4-tuple (trellis, back, path, nice_path_string):
      trellis[t][s]    -- Viterbi score (a log prob) of the best path that is
                          in state s at time t, for t in 0..len(Stringlist)
      back[t][s]       -- the best state to have come FROM at time t-1 to
                          reach s at time t; the placeholder 'init' if none
      path             -- list of len(Stringlist)+1 states, starting at s0
      nice_path_string -- path rendered with nice_names, each step labelled
                          with its observation, e.g. 'start --h--> heads'
    If an observation is not a key of B, 'Illegal input: <o>' is printed and
    the partially filled trellis and back are returned with path and
    nice_path_string set to None.  If no path accepts Stringlist (every
    final score is -inf), path and nice_path_string are also None.
    """
    neg_inf = float('-inf')  # log prob of an impossible event
    T = len(Stringlist)

    # Initialize one column per time step t = 0..T (t=0 is before any input).
    # Only s0 is possible at t=0 (log prob 0.0 = prob 1); every other cell
    # starts at -inf, a score any real path beats.  'init' marks cells whose
    # backpointer has not been set.  (range instead of xrange keeps this
    # function runnable on Python 2 and 3; the trellis dominates memory.)
    trellis = []
    back = []
    for t in range(T + 1):
        trellis.append([0.0 if (t == 0 and s == s0) else neg_inf
                        for s in states])
        back.append(['init' for _ in states])

    # Main body: fill column t from column t-1.
    for t in range(1, T + 1):
        o = Stringlist[t - 1]  # the observation consumed between t-1 and t
        # Only the vocabulary lookup can legitimately fail on bad input, so
        # keep the try narrow (a KeyError anywhere else would be a real bug).
        try:
            emission_probs = B[o]
        except KeyError:
            print('Illegal input: %s' % o)
            # Same 4-tuple shape as the success path so callers can unpack it.
            return (trellis, back, None, None)
        previous = trellis[t - 1]
        scores = trellis[t]
        pointers = back[t]
        for s in states:
            for s1 in states:
                # Log probs: adding scores multiplies the probabilities.
                score = previous[s1] + A[s1][s] + emission_probs[s1][s]
                if score > scores[s]:
                    scores[s] = score
                    pointers[s] = s1

    # Best final state; strict '>' keeps s0 (then the earliest state) on ties.
    best = s0
    for s in states:
        if trellis[T][s] > trellis[T][best]:
            best = s
    if trellis[T][best] == neg_inf:
        # No path accepts Stringlist: the backpointers would lead to the
        # 'init' placeholder, so there is nothing to trace.
        return (trellis, back, None, None)

    # Trace backpointers from t=T down to t=0.  Both lists are built in
    # reverse with O(1) appends and flipped at the end.
    path = [best]
    nice_path = [nice_names[best]]
    for t in range(T, 0, -1):
        best = back[t][best]
        path.append(best)
        nice_path.append('--%s-->' % Stringlist[t - 1])
        nice_path.append(nice_names[best])
    path.reverse()
    nice_path.reverse()
    return (trellis, back, path, ' '.join(nice_path))

############################################################
# Main Program
############################################################

if __name__ == '__main__':
    import sys
    import string
    import math
    # psyco is an optional Python 2 JIT; without it the script just runs slower.
    try:
        import psyco
        psyco.full()
    except ImportError:
        print('Warning: No psyco available')
    # COMMANDLINE: viterbi.py
    # Reads lines from stdin; each character of a line is one observation.

    # States
    states = range(3)  # [0,1,2]
    nice_names = ['start', 'heads', 'tails']  ## more mnemonic names for states [0,1,2]
    s0 = 0  # start state

    # transition probs A[from][to], a list of lists
    A = [[0.0, 0.5, 0.5],   # From State 0
         [0.0, 0.5, 0.5],   # From State 1 ...
         [0.0, 0.5, 0.5]]
    # emission probs: a dictionary whose key is a member of the input vocab
    # (a single character) and whose value is a len(states) x len(states)
    # matrix B[word][from][to] -- the probability of emitting word on the
    # arc from -> to (emissions are attached to transitions).
    B = {'h': [[0.0, 0.5, 0.5],   # from State 0
               [0.0, 1.0, 1.0],
               [0.0, 0.0, 0.0]],
         't': [[0.0, 0.5, 0.5],
               [0.0, 0.0, 0.0],
               [0.0, 1.0, 1.0]],
         }
    # decode() relies on this global too; log(0) is represented as -inf.
    neg_infinity = float('-inf')

    def to_log2(p):
        """Return the base-2 log of probability p, or neg_infinity if p <= 0."""
        if p > 0:
            return math.log(p, 2)
        return neg_infinity

    # switch A and B to log probs (base 2); decode() adds scores.
    A = [[to_log2(p) for p in row] for row in A]
    for word in B:
        # Rebinding an existing key does not resize the dict, so this is safe.
        B[word] = [[to_log2(p) for p in row] for row in B[word]]

    print('\nOutput: ')
    # iter(readline, '') stops at EOF without the read-ahead buffering of
    # "for line in sys.stdin" on Python 2, so interactive use still works.
    for line in iter(sys.stdin.readline, ''):
        result = decode(list(line.rstrip()))
        print('Trellis: %s' % result[0])
        print('Back: %s' % result[1])
        # decode may report failure (illegal input / no accepting path) with
        # a short tuple or a None path; keep reading instead of crashing.
        if len(result) == 4 and result[2] is not None:
            print('path: %s' % result[2])
            print('nice path: %s' % result[3])
        else:
            print('path: no path accepts this input')
        sys.stdout.write('\nOutput: ')
        sys.stdout.flush()  # show the prompt before blocking on stdin

