import random
import string

import sys
import tools
from tools import pp

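# the representation for a point is:
#
#    (name, value)
#
# the representation for a cluster is:
#
#    ([name1, name2, ...], mean_value, size)
#
# where mean_value is the cluster's mean vector (or, for string clusters,
# its profile) and size is the number of points it contains.
#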

# Field indices shared by points and clusters:
#   NAME  -> the read name (point) or list of member names (cluster)
#   VALUE -> the vector of the read, or the mean vector/profile of the cluster
#   SIZE  -> the number of points in a cluster
NAME, VALUE, SIZE = 0, 1, 2

############################################################
#
#                POINTS IN EUCLIDEAN SPACE
#
############################################################

def euclidian_distance(x1, x2):
    """Squared Euclidean distance between two sample vectors.

        dist = Sum_i( (x1[i] - x2[i])^2 )

    NOTE: despite the name, no square root is taken, so any threshold
    compared against this value is in squared units.
    """
    squared_diffs = []
    for diff in map(lambda a, b: a - b, x1, x2):
        squared_diffs.append(diff * diff)
    return tools.sum(squared_diffs)

def weighted_average(xs, ws):
    """Weighted mean of equal-length vectors.

        result[i] = (x1[i]*w1 + x2[i]*w2 + ...) / (w1 + w2 + ...)

    xs: non-empty sequence of vectors [x1, x2, ...], all of length len(x1).
    ws: sequence of weights [w1, w2, ...], one per vector.

    Returns a list of length len(x1). The division is always done in
    floating point, so integer inputs are not truncated (under Python 2,
    int / int floors).

    Raises ValueError if xs is empty or len(xs) != len(ws); a zero total
    weight raises ZeroDivisionError.
    """
    # accept any iterables (e.g. the result of map) by materializing them once
    xs = list(xs)
    ws = list(ws)
    if not xs:
        raise ValueError("weighted_average needs at least one vector")
    if len(xs) != len(ws):
        raise ValueError("Expected one weight per vector", (len(xs), len(ws)))

    total = float(sum(ws))
    pairs = list(zip(xs, ws))
    return [sum([x[i] * w for x, w in pairs]) / total
            for i in range(len(xs[0]))]

def points2cluster(points):
    """Collapse a list of (name, value) points into one cluster tuple.

    Returns ([name, ...], mean of the values with equal weights, number of
    points).
    """
    names = [point[NAME] for point in points]
    values = [point[VALUE] for point in points]
    mean = weighted_average(values, [1] * len(points))
    return (names, mean, len(points))

############################################################
#
#                N-MERS IN STRING SPACE
#
############################################################

def create_empty_profile(length):
    """ Create an empty profile (a cluster with no points assigned).

    A profile is a list with one dictionary per position; every dictionary
    maps each of 'A', 'C', 'G', 'T' to a count, here all 0.
    """
    profile = [{} for _ in range(length)]
    for counts in profile:
        tools.mset(counts, 'ACGT', [0, 0, 0, 0])
    return profile

def strings2profile(strings):
    """ Creates a profile from a set of strings.

    Each unambiguous base (A, C, G, T) adds 1 to its count at that position.
    Any other character is expanded with tools.IUB_expansion, and each base
    of the expansion receives an equal fractional share (1 / number of
    bases). Returns None when `strings` is empty.

    Assumes that all strings are the same length as the first one."""
    if not strings:
        return None

    profile = create_empty_profile(len(strings[0]))
    for s in strings:
        for pos, base in enumerate(s):
            if base in 'ACGT':
                # unambiguous: one full count
                profile[pos][base] += 1
            else:
                # ambiguous: spread one count over the possible bases
                expansion = tools.IUB_expansion(base)
                share = 1.0 / float(len(expansion))
                for option in expansion:
                    profile[pos][option] += share
    return profile

def distance_string_profile(s, p):
    """ Evaluates the distance from a string to a profile.

    The score counts, position by position, how many of the strings folded
    into the profile agree with s:
      * an unambiguous base (A, C, G, T) scores the profile count of that base;
      * an ambiguous character is expanded with tools.IUB_expansion, and each
        possible base contributes its profile count times an equal share
        1 / len(expansion).
    The distance is the maximum score (total count at position 0 times
    len(s)) minus the actual score.

    An empty/None profile (a cluster with no points) gives len(s) * .75.
    """
    if not p:
        return len(s) * .75

    # total count at the first position, i.e. the number of points in the cluster
    tot_p = tools.sum(tools.mget(p[0], 'ACGT'))

    score = 0
    for i in range(len(s)):
        char = s[i]
        counts = p[i]
        if char in 'ACGT':
            score = score + counts[char]
        else:
            expansion = tools.IUB_expansion(char)
            share = 1.0 / float(len(expansion))
            for base in expansion:
                # weight by the profile: previously `share` was added
                # regardless of the counts, so an ambiguous character always
                # scored exactly 1 no matter how large or what the cluster was
                score = score + counts[base] * share
    return (tot_p * len(s)) - score

def display_profile(profile):
    """Render a profile as a compact string such as '[A][CT][G](3)'.

    Each position shows, in A, C, G, T order, the bases with a nonzero
    count. The trailing number is the total count at the first position.
    An empty or None profile renders as '[](0)'.
    """
    if not profile:
        return '[](0)'
    tot_p = sum([profile[0][base] for base in 'ACGT'])
    positions = []
    for pos in profile:
        # fixed A,C,G,T order instead of the arbitrary dict item order
        present = ''.join([base for base in 'ACGT' if pos[base]])
        positions.append('[%s]' % present)
    return ''.join(positions) + '(%s)' % tot_p

######## More operations, not needed for k-means
#
#def update_profile(profile, string):
#    """ Adding a string into a profile (adding a point to a cluster) """
#    if len(profile)!=len(string): raise "Inconsistent lengths", (len(profile),len(string))
#    for i in range(0,len(string)):
#        char = string[i]
#        profile[i][char] = profile[i][char] + 1
#
#def string2profile(string):
#    """ Transforming a string into a profile (point -> cluster). """
#    if type(string) == type(''): 
#        profile = create_empty_profile(len(string))
#        update_profile(profile,string)
#        return profile
#    else:
#        return string
#
#def distance_string(p1,p2):
#    """ Distance in string space.  Works for either profiles or strings. """
#    # s1 and s2 can be either strings or profiles
#    if type(p1) == type(''): p1 = string2profile(p1)
#    if type(p2) == type(''): p2 = string2profile(p2)
#    return distance_profiles(p1,p2)

def distance_profiles(p1, p2):
    """ Distance between two profiles, in [0, 1].

    At each position the per-base frequencies (count / total count) of the
    two profiles are multiplied and summed. This is the chance that a string
    drawn from p1 and one drawn from p2 share that base. The distance is 1
    minus the average of this agreement over all positions:
      * identical single-string profiles are at distance 0;
      * profiles with no base in common are at distance 1.

    (The previous version returned the agreement itself, a similarity. The
    min-joining clustering therefore merged the least similar clusters
    first.)

    Two empty profiles are at distance 0.0.

    Raises ValueError if the profiles have different lengths.
    """
    if len(p1) != len(p2):
        raise ValueError("Uncomparable lengths", (p1, p2))
    if not p1:
        return 0.0

    tot_p1 = float(sum([p1[0][base] for base in 'ACGT']))
    tot_p2 = float(sum([p2[0][base] for base in 'ACGT']))

    matches = []
    for pos1, pos2 in zip(p1, p2):
        match = 0
        for base in 'ACGT':
            match = match + (pos1[base] / tot_p1) * (pos2[base] / tot_p2)
        matches.append(match)
    return 1.0 - sum(matches) / len(matches)

def profile_sum(profiles):
    """ Joining clusters. Returns one which contains all points in any of them.

    profiles: non-empty list of profiles of equal length.

    Returns a new profile whose count for each position and base is the sum
    of the input counts. The inputs are not modified.

    Raises ValueError if the list is empty or the lengths differ.
    """
    # (a leftover debug dump of every input profile was removed here)
    if not profiles:
        raise ValueError("No profiles to join")
    length = len(profiles[0])
    for profile in profiles:
        if len(profile) != length:
            raise ValueError("Incompatible lengths", [len(p) for p in profiles])

    return [dict([(base, sum([profile[i][base] for profile in profiles]))
                  for base in 'ACGT'])
            for i in range(length)]

#def profile_pairwise_sum(p1,p2):
#
#    if len(p1)!=len(p2): raise "Incompatible lengths", (len(p1),len(p2))
#    new_p = create_empty_profile(len(p1))
#    for i in range(0,len(p1)):
#        for char in 'ACGT':
#            new_p[i][char] = p1[i][char] + p2[i][char]
#    return new_p
###########################################

def string_points2cluster(points):
    """Turn (name, string) points into a string cluster.

    Returns [names, profile of the strings, number of points].
    """
    names = [point[NAME] for point in points]
    profile = strings2profile([point[VALUE] for point in points])
    return [names, profile, len(points)]

def k_means_string(sequences, n, k):
    """Cluster every n-mer of `sequences` into k profiles with k-means.

    Prints the non-empty clusters (pretty-printed, then one profile per line)
    and returns them.
    """
    kmer_points = gather_kmers_names(sequences, n)

    result = k_means(kmer_points, k,
                     distance_func=distance_string_profile,
                     ps2c_func=strings2profile,
                     display_cluster=display_profile)

    # NOTE(review): k_means returns 3-tuples, so keying the sort on len gives
    # every cluster the same key; k_means already orders them largest first.
    nonempty = [cluster for cluster in tools.my_sort(result, len)
                if cluster[SIZE]]
    pp(nonempty, 2, 60)
    print('\n'.join([display_profile(cluster[VALUE]) for cluster in nonempty]))
    return nonempty

def hierarchical_string(sequences, n, max_join=.01):
    """Hierarchically cluster every n-mer of `sequences` by profile distance.

    Clusters are merged with profile_sum while the closest pair is within
    max_join (default .01, the previously hard-coded threshold) under
    distance_profiles.

    Returns the cluster list produced by hierarchical(). It was previously
    computed and then discarded.
    """
    points = gather_kmers_names(sequences, n)

    return hierarchical(points, max_join,
                        distance_cluster_cluster=distance_profiles,
                        join_clusters=profile_sum,
                        ps2c_func=strings2profile,
                        display_cluster=display_profile)


#def spoint2scluster(spoint):
#    return [[spoint[NAME]],strings2profile(spoint[VALUE]),1]

#def average_profiles(profiles, scaling):
#    return profile_sum(profiles)
#
#def gather_kmers(sequences, k):
#    kmers = []
#    for seq in sequences:
#        for i in range(0,len(seq)-k+1):
#            kmers.append(seq[i:][:k])
#    return kmers

def gather_kmers_names(sequences, k):
    """Return every k-mer of every sequence as a named point.

    Each point is ((sequence_index, start_offset), kmer). Points are ordered
    by sequence, then by offset. Sequences shorter than k contribute nothing.
    k is assumed to be a positive integer.
    """
    points = []
    for seq_index, seq in enumerate(sequences):
        for start in range(len(seq) - k + 1):
            points.append(((seq_index, start), seq[start:start + k]))
    return points

############################################################
#
#                DIFFERENT ALGORITHMS
#
############################################################

def hierarchical_optimized(points,
                           distance_cluster_clutser=euclidian_distance,
                           join_clusters=weighted_average,
                           ps2c_func=points2cluster,
                           display_cluster=pp,
                           max_join=None,
                           distance_cluster_cluster=None):
    """Agglomerative clustering with a cached table of pairwise distances.

    Repeatedly joins the two closest clusters. Every pairwise distance is
    computed once up front. After each join, only the distances from the
    new cluster to the remaining ones are computed, instead of re-evaluating
    every pair on every iteration.

    points: list of (name, value) pairs.
    distance_cluster_clutser: distance between two cluster values. The
        misspelled keyword is kept so existing callers keep working; the
        correctly spelled distance_cluster_cluster, when given, takes
        precedence.
    join_clusters: maps the list of the two cluster values to the merged
        value. The default weighted_average needs explicit weights, so it is
        given the two cluster sizes.
    ps2c_func: maps a list of point values to a cluster value. The default
        points2cluster expects (name, value) points instead, so it is
        replaced by the unweighted mean of the values.
    display_cluster: renders a cluster value. Merged clusters are printed
        at every iteration.
    max_join: stop when the closest pair is farther apart than this. None
        (the default) joins until a single cluster remains.

    Returns the clusters as a list sorted by size, largest first.
    """
    if distance_cluster_cluster is None:
        distance_cluster_cluster = distance_cluster_clutser
    if ps2c_func is points2cluster:
        ps2c_func = lambda values: weighted_average(values, [1] * len(values))

    print("Converting every point to a cluster")
    # a dictionary (rather than a list) so that indices are never reused
    clusters = tools.list2dic_i([[[p[NAME]], ps2c_func([p[VALUE]]), 1]
                                 for p in points])
    next_index = len(clusters)

    print("Calculating all pairwise distances")
    # distances[(i, j)] with i < j, for every pair of live clusters
    distances = {}
    indices = list(clusters.keys())
    for i in indices:
        for j in indices:
            if j > i:
                distances[(i, j)] = distance_cluster_cluster(clusters[i][VALUE],
                                                             clusters[j][VALUE])
    print('Done calculating %s distances' % len(distances))

    while distances:
        print('\n'.join([display_cluster(c[VALUE])
                         for c in clusters.values() if c[SIZE] > 1]))

        # closest pair; ties go to the smallest (i, j)
        min_dist, (i, j) = min([(dist, pair) for pair, dist in distances.items()])
        if max_join is not None and min_dist > max_join:
            break

        values = [clusters[i][VALUE], clusters[j][VALUE]]
        if join_clusters is weighted_average:
            # weight each mean by the number of points already in its cluster
            new_value = weighted_average(values,
                                         [clusters[i][SIZE], clusters[j][SIZE]])
        else:
            new_value = join_clusters(values)
        new_name = tools.flatten([clusters[i][NAME], clusters[j][NAME]])
        new_size = clusters[i][SIZE] + clusters[j][SIZE]
        del clusters[i]
        del clusters[j]

        # forget every cached distance involving one of the merged clusters
        for pair in list(distances.keys()):
            if i in pair or j in pair:
                del distances[pair]

        # the new cluster has the largest index, so it goes second in each key
        for other in clusters.keys():
            distances[(other, next_index)] = distance_cluster_cluster(
                clusters[other][VALUE], new_value)
        clusters[next_index] = (new_name, new_value, new_size)
        next_index = next_index + 1

    cluster_list = tools.dic2list(clusters)
    cluster_list.sort(key=lambda c: -c[SIZE])
    return cluster_list

def hierarchical(points, max_join,
                 distance_cluster_cluster=euclidian_distance,
                 join_clusters=weighted_average,
                 ps2c_func=points2cluster,
                 display_cluster=pp):
    """Agglomerative clustering: repeatedly join the two closest clusters.

    Every (name, value) point starts as its own cluster
    [[name], ps2c_func([value]), 1]. At each step the pair whose values are
    closest under distance_cluster_cluster is joined. This continues until
    the closest remaining pair is farther apart than max_join, or only one
    cluster is left.

    join_clusters: maps the list of the two cluster values to the merged
        value. The default weighted_average needs explicit weights, so it is
        given the two cluster sizes.
    ps2c_func: maps a list of point values to a cluster value. The default
        points2cluster expects (name, value) points instead, so it is
        replaced by the unweighted mean of the values.
    display_cluster: renders a cluster value. Merged clusters are printed
        at every iteration.

    Returns the clusters as a list sorted by size, largest first.
    """
    # Fixed here:
    #  - an inverted early exit returned unclustered points exactly when
    #    every pair was within max_join;
    #  - seeding the search with the maximum distance meant pairs at exactly
    #    that distance could never be joined.
    if ps2c_func is points2cluster:
        ps2c_func = lambda values: weighted_average(values, [1] * len(values))

    print("Converting every point to a cluster")
    # a dictionary (rather than a list) so that indices are never reused
    clusters = tools.list2dic_i([[[p[NAME]], ps2c_func([p[VALUE]]), 1]
                                 for p in points])
    next_index = len(clusters)

    while 1:
        print('\n'.join([display_cluster(c[VALUE])
                         for c in clusters.values() if c[SIZE] > 1]))

        # find the closest pair of clusters (the first one found wins ties)
        min_dist, min_pair = None, None
        keys = list(clusters.keys())
        for i in keys:
            for j in keys:
                if j <= i:
                    continue
                dist = distance_cluster_cluster(clusters[i][VALUE],
                                                clusters[j][VALUE])
                if min_pair is None or dist < min_dist:
                    min_dist, min_pair = dist, (i, j)

        # stop when nothing is left to join or the best join is too far
        if min_pair is None or min_dist > max_join:
            break

        i, j = min_pair
        values = [clusters[i][VALUE], clusters[j][VALUE]]
        if join_clusters is weighted_average:
            # weight each mean by the number of points already in its cluster
            new_value = weighted_average(values,
                                         [clusters[i][SIZE], clusters[j][SIZE]])
        else:
            new_value = join_clusters(values)
        new_name = tools.flatten([clusters[i][NAME], clusters[j][NAME]])
        new_size = clusters[i][SIZE] + clusters[j][SIZE]
        clusters[next_index] = (new_name, new_value, new_size)
        next_index = next_index + 1
        del clusters[i]
        del clusters[j]

    cluster_list = tools.dic2list(clusters)
    cluster_list.sort(key=lambda c: -c[SIZE])
    return cluster_list

def make_k_means_func(k):
    """Return a clustering function that runs k_means with k centers.

    The returned function takes the points, plus an optional second
    argument that overrides k.
    """
    def cluster_with_k(points, k=k):
        return k_means(points, k)
    return cluster_with_k

def _pick_initial_centers(points, k, distance_func, ps2c_func):
    # Choose k random points, no two at distance 0, as the starting centers.
    candidates = list(points)
    random.shuffle(candidates)
    centers = []
    for candidate in candidates:
        if len(centers) == k:
            break
        # a center on top of an existing one would never separate from it
        on_existing = 0
        for center in centers:
            if distance_func(candidate[VALUE], center[VALUE]) == 0:
                on_existing = 1
                break
        if not on_existing:
            centers.append([[candidate[NAME]],
                            ps2c_func([candidate[VALUE]]),
                            1])
    if len(centers) < k:
        raise ValueError("Not enough distinct data points for the centers",
                         (k, len(centers)))
    return centers


def _closest_center(value, centers, distance_func):
    # Index of the center nearest to value; the first one wins ties.
    closest = 0
    min_dist = distance_func(value, centers[0][VALUE])
    for c in range(1, len(centers)):
        dist = distance_func(value, centers[c][VALUE])
        if dist < min_dist:
            closest, min_dist = c, dist
    return closest


def k_means(points, k,
            distance_func=euclidian_distance,
            ps2c_func=points2cluster,
            display_cluster=pp,
            max_iter=3):
    """Partition (name, value) points into k clusters with k-means.

    points: list of (name, value) pairs. Names key a name -> value lookup
        table, so they are presumably unique.
    k: number of clusters, 1 <= k <= len(points).
    distance_func(value, center_value): distance from a point to a center.
    ps2c_func(values): builds a center value from a list of point values
        (e.g. strings2profile). The default points2cluster expects
        (name, value) points instead, so it is replaced by the unweighted
        mean of the values. With that mean, a center that receives no
        points keeps its previous value.
    display_cluster(center_value): string printed for each center every
        iteration.
    max_iter: maximum number of assign/update rounds (default 3, the
        previous fixed count). The loop stops earlier once no point changes
        center.

    Returns a list of (names, value, size) tuples, largest cluster first.

    Raises ValueError if k is out of range, or if fewer than k points at
    non-zero distance from each other exist. The latter previously looped
    forever.
    """
    if k > len(points):
        raise ValueError("More centers than data points!", (k, len(points)))
    if k < 1:
        raise ValueError("k must be at least 1", k)

    use_mean = ps2c_func is points2cluster
    if use_mean:
        ps2c_func = lambda values: weighted_average(values, [1] * len(values))

    # name -> value lookup used to rebuild each center from its members
    point_dic = {}
    tools.mset(point_dic,
               [p[NAME] for p in points],
               [p[VALUE] for p in points])

    centers = _pick_initial_centers(points, k, distance_func, ps2c_func)

    iteration = 0
    while 1:
        print("*" * 40)
        iteration = iteration + 1
        print("Iteration %s" % iteration)
        print('\n'.join([display_cluster(c[VALUE]) for c in centers]))

        # save the old point assignments
        old_names = [c[NAME] for c in centers]

        # 1. assign each point to its closest center
        for center in centers:
            center[NAME] = []
            center[SIZE] = 0
        for point in points:
            best = centers[_closest_center(point[VALUE], centers, distance_func)]
            best[NAME].append(point[NAME])
            best[SIZE] = best[SIZE] + 1

        # 2. re-evaluate each center from its members.
        #    One could turn this into EM by weighting each point by the
        #    probability of it belonging to that center.
        for center in centers:
            if use_mean and not center[NAME]:
                # the mean of no vectors is undefined; keep the old center
                continue
            center[VALUE] = ps2c_func(tools.mget(point_dic, center[NAME]))

        # stop once the assignments are stable or the budget is used up
        new_names = [c[NAME] for c in centers]
        if new_names == old_names or iteration >= max_iter:
            break

        pairs = list(zip(new_names, old_names))
        additions = [tools.set_subtract(new, old) for new, old in pairs]
        subtractions = [tools.set_subtract(old, new) for new, old in pairs]
        print("Before:       " + ' '.join(['%2s' % len(n) for n in old_names]))
        print("Subtractions: " + ' '.join(['%2s' % len(n) for n in subtractions]))
        print("Additions:    " + ' '.join(['%2s' % len(n) for n in additions]))
        print("After:        " + ' '.join(['%2s' % len(n) for n in new_names]))

    centers.sort(key=lambda c: -c[SIZE])
    return [(c[NAME], c[VALUE], c[SIZE]) for c in centers]

############################################################
#
#                NORMALIZATION
#
############################################################

def mean0_stdev1(points):
    """Standardize each point's own value vector to mean 0 and stdev 1.

    Each (name, value) point becomes (name, [(x - mean) / stdev, ...]).
    The mean and stdev are computed per point with tools.avg and
    tools.stdev. A vector whose stdev is 0 cannot be scaled, so it is only
    centered (x - mean), instead of raising ZeroDivisionError.

    Returns a new list of points; the input is not modified.
    """
    new_points = []
    for point in points:
        values = point[VALUE]
        mean = tools.avg(values)
        stdev = tools.stdev(values)
        if stdev:
            new_value = [(x - mean) / stdev for x in values]
        else:
            new_value = [x - mean for x in values]
        new_points.append((point[NAME], new_value))
    return new_points

def restore_points(clusters, old_points):
    """Recompute cluster values from the original (unnormalized) points.

    The clustering was done on normalized points, so the cluster values do
    not correspond to anything meaningful. The member names are still
    correct, so each cluster's value is recomputed as the unweighted
    average of its members' original values.

    clusters: list of (names, value, size).
    old_points: list of (name, value) with the original values.

    Returns a new list of (names, restored_value, size) tuples. It was
    previously built and then discarded. A cluster with no members keeps
    its existing value, because an average of nothing is undefined.
    """
    # lookup table name -> original value
    point_dic = {}
    tools.mset(point_dic,
               [p[NAME] for p in old_points],
               [p[VALUE] for p in old_points])

    new_clusters = []
    for cluster in clusters:
        restored_values = [point_dic[name] for name in cluster[NAME]]
        if restored_values:
            value = weighted_average(restored_values,
                                     [1] * len(restored_values))
        else:
            value = cluster[VALUE]
        new_clusters.append((cluster[NAME], value, cluster[SIZE]))
    return new_clusters


def normalize_and_cluster(clustering_function, normalization_function, points):
    """Normalize points, cluster them, and report clusters in original units.

    normalization_function(points) -> normalized points.
    clustering_function(normalized_points) -> list of (names, value, size).

    Returns the clusters with their values recomputed from the original
    points. Previously the recomputed clusters were discarded and the
    normalized ones were returned.
    """
    new_points = normalization_function(points)
    clusters = clustering_function(new_points)
    return restore_points(clusters, points)

############################################################
#
#  CLUSTERING in ANY SPACE
#
############################################################
        
def follow_references(references, clusters, i, debug=0):
    """Follow the chain i -> references[i] -> ... to a live cluster index.

    Returns the first index in the chain (starting with i itself) that is a
    key of `clusters`. With debug set, the path is written to stdout.

    Raises KeyError if the chain reaches an index that is neither in
    `clusters` nor in `references`. Raises ValueError if the references
    form a cycle; that case previously looped forever.
    """
    if debug: sys.stdout.write('Looking for %s' % i)
    seen = {}
    while i not in clusters:
        if i in seen:
            raise ValueError("Cycle in cluster references", i)
        seen[i] = 1
        i = references[i]
        if debug: sys.stdout.write('->%s' % i)
    if debug: sys.stdout.write(' Found!\n')
    return i







## def generic_clustering(points, distances, threshold, points2cluster, cluster2points, compare_clusters, join_clusters, debug=0):

##     # distances is a list:  [(i,j,dist_ij),
##     #                        (i,k,dist_ik),
##     #                        (j,k,dist_jk)]
##     # threshold is the minimum distance threshold
##     # 
##     # doesn't really matter what the points are,
##     # we simply use their indices.

##     if debug: print "Gathering all points indices"
##     assert(len(tools.unique(points))==len(points)) # make sure all points are unique
##     distances = tools.my_sort(distances,lambda d: -d[2]) # sort the distances from best score to worst
##     points_with_distances = tools.unique(tools.flatten(map(lambda d: [d[0],d[1]],distances))) # get the point names from distance table
##     assert(len(tools.set_intersect(points_with_distances,points))==len(points_with_distances)) # make sure all points mentioned are listed
##     if tools.set_subtract(points,points_with_distances): # and see which have no distances whatsoever
##         print "%s points with distances / %s points total (%2.0f%%).\nNo dist info for %s"%(
##             len(points_with_distances),len(points),100.0*len(points_with_distances)/len(points),
##             tools.set_subtract(points,points_with_distances))
##     # this step here is not needed. moreover, i'd like to output clusters in order of input points
##     #points = tools.unique(points_with_distances+points) # then list all points together

##     #print "Points are %s"%points

##     # here's a little trick on actually how we look indices up:
##     # instead of just looking at them in order, we first sort the
##     # index pairs, by the distance that separates them.
##     #
##     # then since, we always look things up in order, we'll end up
##     # joining the shortest motifs together first

##     if debug:
##         print "Making a dictionary for quick distance lookup"
        
##     distdic = {}
##     for i,j,dist in distances:
##         distdic[(min(i,j),max(i,j))] = dist

##     if debug:
##         print "Distance dictionary is: "
##         pp(distdic)
        
##     if debug:
##         print "%s points defined by %s distances"%(len(points),len(distances))
##         print "Distance distribution is: "
##         tools.quick_histogram(tools.cget(distances,2))

##     if debug: print "Making every point its own cluster"
##     references = {}
##     clusters = {}
##     used_names = {}
##     tools.mset(used_names,points,[None]*len(points))
##     maxi,p_i = -1,0
##     while p_i < len(points):
##         maxi = maxi+1
##         if used_names.has_key(maxi): continue
##         # we found an index for the cluster
##         clusters[maxi] = points2cluster([points[p_i]])
##         references[points[p_i]] = maxi
##         p_i = p_i+1
        
##     if debug:
##         print "The clusters are: "
##         pp(clusters,1)

##     #maxi = len(clusters)

##     hierarchy = {}
##     for cluster,elements in clusters.items():
##         assert(len(elements)==1)
##         hierarchy[cluster] = elements[0]

##     #print "Hierarchy is %s"%hierarchy

##     join_score = {}

##     suggestions = {}
##     vetos = {}
##     #references = {}

##     comparison_cache = {}

##     # two ways or sorting the order we visit the pairs
##     # 1) sort items
##     # 2) sort pairs

##     # 1) this is the sorting of the items:  bad
##     #    coz it leads to 1+1,2+1,3+1,4+1,5+1 joinings
##     # keep trying until no more joins are made
##     #indices = tools.my_sort(clusters.keys())
##     #pairs = []
##     #for i in range(0,len(indices)):
##     #    for j in range(i+1,len(indices)):
##     #        pairs.append((indices[i],indices[j]))
##     #pairs = tools.my_sort(pairs,max)

##     # 2) this is the sorting of the pairs
##     # this is better coz it does 1+1,1+1,2+2,1+1,4+2, etc
##     pairs = map(None, tools.cget(distances,0),tools.cget(distances,1))

##     # i will only compare points that had *any* similarity to start with

##     #boom

##     retry_all = 1
##     while retry_all:

##         if debug: print "Re-starting all the loops (%s by %s)"%(len(clusters),len(clusters))
##         if debug: print "Sizes are: %s"%tools.describe_elements(map(len,clusters.values()),lambda size: -size)

##         retry_all = 0

##         for i_tmp,j_tmp in pairs:

##             # probably joined and renamed
##             if clusters.has_key(i_tmp): i = i_tmp
##             else: i = follow_references(references, clusters, i_tmp, debug=debug)
            
##             if clusters.has_key(j_tmp): j = j_tmp
##             else: j = follow_references(references, clusters, j_tmp, debug=debug)


##             if debug: print "Testing Pair (%s,%s) now in clusters (%s,%s)"%(i_tmp,j_tmp,i,j)


##             # it is possible that we've already joined the clusters they now belong in
##             if i==j: continue
            
##             # cluster2cluster_distance

##             #if debug: 
##             #    print "Comparison cache is"
##             #    pp(comparison_cache,1)

##             # PLEASE NOTE!!  IF WE RE-USE CLUSTER NAMES, WE SHOULD FLUSH THE CACHE
##             # AT EVERY ITERATION
##             if comparison_cache.has_key((i,j)):
##                 score = comparison_cache[(i,j)]
##             else:
##                 score = compare2clusters(distdic, clusters[i], clusters[j])
##                 comparison_cache[(i,j)] = score

##             if debug: print "Avg max linkage between %s and %s is %s for %s%% tested and %s for %s tested"%(
##                 i,j,value1,perc1,value2,perc2)

##             if score > threshold:

##                 # now i know i'm joining, increment maxi
##                 maxi = maxi+1

##                 # make a new cluster out of the two joined ones
##                 newcluster = join_clusters(clusters[i],clusters[j])
##                 del(clusters[i])
##                 del(clusters[j])
##                 clusters[maxi] = newcluster

##                 hierarchy[maxi] = (hierarchy[i],hierarchy[j])
##                 join_score[maxi] = weighted_linkage
##                 del(hierarchy[i])
##                 del(hierarchy[j])

##                 references[i] = maxi
##                 references[j] = maxi

##                 #clusters[i] = newcluster
                
##                 retry_all = 1

##     keys = tools.cget(tools.my_sort(clusters.items(), lambda item: min(item[1])),0)
##     #keys = tools.my_sort(clusters.keys())
##     #keys = tools.my_sort(keys,lambda k,join_score=join_score: -(join_score.has_key(k) and join_score[k]))
##     # resorting the keys to put singletons at the end

##     #singletons = tools.lte(keys,len(points))
##     #joined     = tools.gt(keys,len(points))
##     #keys = joined+singletons

##     groups = tools.mget(clusters,keys)
##     hierarchy = tools.mget(hierarchy, keys)

##     #pp(hierarchy,3)
    
##     linkages = []
##     for group in groups: 
##         linkage = []
##         for m1 in range(0,len(group)):
##             for m2 in range(m1+1,len(group)):
##                 if distdic.has_key((group[m1],group[m2])): 
##                     linkage.append(distdic[(group[m1],group[m2])])
##                 else:
##                     linkage.append(0)
##         #linkages.append(tools.avg(linkage))
##         linkages.append(tools.reverse(tools.my_sort(linkage)))

##     #linkages, groups = tools.unpack(tools.reverse(tools.my_sort(map(None,linkages,groups))))
##     #lens, avglink, linkages, groups = tools.unpack(tools.reverse(tools.my_sort(
##     #    map(None,
##     #        map(len,groups),
##     #        map(tools.avg,linkages),
##     #        linkages,
##     #        groups))))
            
    
##     return groups, linkages, hierarchy


def avg_max_linkage(points, distances, threshold, linkage_type='max', debug=0):

    # distances is a list:  [(i,j,dist_ij),
    #                        (i,k,dist_ik),
    #                        (j,k,dist_jk)]
    # threshold is the minimum distance threshold
    # 
    # doesn't really matter what the points are,
    # we simply use their indices.

    if debug: print "Gathering all points indices"
    distances = tools.my_sort(distances,lambda d: -d[2])
    points_with_distances = tools.unique(tools.flatten(map(lambda d: [d[0],d[1]],distances)))
    assert(len(tools.set_intersect(points_with_distances,points))==len(points_with_distances))
    if tools.set_subtract(points,points_with_distances):
        print "%s points with distances / %s points total (%2.0f%%).\nNo dist info for %s"%(
            len(points_with_distances),len(points),100.0*len(points_with_distances)/len(points),
            tools.set_subtract(points,points_with_distances))
    points = tools.unique(points_with_distances+points)

    #print "Points are %s"%points

    # here's a little trick on actually how we look indices up:
    # instead of just looking at them in order, we first sort the
    # index pairs, by the distance that separates them.
    #
    # then since, we always look things up in order, we'll end up
    # joining the shortest motifs together first

    if debug:
        print "Making a dictionary for quick distance lookup"
        
    distdic = {}
    for i,j,dist in distances:
        distdic[(min(i,j),max(i,j))] = dist

    if debug:
        print "Distance dictionary is: "
        pp(distdic)
        
    if debug:
        print "%s points defined by %s distances"%(len(points),len(distances))
        print "Distance distribution is: "
        tools.quick_histogram(tools.cget(distances,2))

    if debug: print "Making every point its own cluster"
    references = {}
    clusters = {}
    used_names = {}
    tools.mset(used_names,points,[None]*len(points))
    maxi,p_i = -1,0
    while p_i < len(points):
        maxi = maxi+1
        if used_names.has_key(maxi): continue
        # we found an index for the cluster
        clusters[maxi] = [points[p_i]]
        references[points[p_i]] = maxi
        p_i = p_i+1
        
        
    if debug:
        print "The clusters are: "
        pp(clusters,1)

    #maxi = len(clusters)

    hierarchy = {}
    for cluster,elements in clusters.items():
        assert(len(elements)==1)
        hierarchy[cluster] = elements[0]

    #print "Hierarchy is %s"%hierarchy

    join_score = {}

    suggestions = {}
    vetos = {}
    #references = {}

    comparison_cache = {}

    # two ways or sorting the order we visit the pairs
    # 1) sort items
    # 2) sort pairs

    # 1) this is the sorting of the items:  bad
    #    coz it leads to 1+1,2+1,3+1,4+1,5+1 joinings
    # keep trying until no more joins are made
    #indices = tools.my_sort(clusters.keys())
    #pairs = []
    #for i in range(0,len(indices)):
    #    for j in range(i+1,len(indices)):
    #        pairs.append((indices[i],indices[j]))
    #pairs = tools.my_sort(pairs,max)

    # 2) this is the sorting of the pairs
    # this is better coz it does 1+1,1+1,2+2,1+1,4+2, etc
    pairs = map(None, tools.cget(distances,0),tools.cget(distances,1))

    #boom

    retry_all = 1
    while retry_all:

        if debug: print "Re-starting all the loops (%s by %s)"%(len(clusters),len(clusters))
        if debug: print "Sizes are: %s"%tools.describe_elements(map(len,clusters.values()),lambda size: -size)

        retry_all = 0

        for i_tmp,j_tmp in pairs:

            # probably joined and renamed
            if clusters.has_key(i_tmp): i = i_tmp
            else: i = follow_references(references, clusters, i_tmp, debug=debug)
            
            if clusters.has_key(j_tmp): j = j_tmp
            else: j = follow_references(references, clusters, j_tmp, debug=debug)


            if debug: print "Testing Pair (%s,%s) now in clusters (%s,%s)"%(i_tmp,j_tmp,i,j)


            # it is possible that we've already joined the clusters they now belong in
            if i==j: continue
            
            # cluster2cluster_distance

            #if debug: 
            #    print "Comparison cache is"
            #    pp(comparison_cache,1)

            # PLEASE NOTE!!  IF WE RE-USE CLUSTER NAMES, WE SHOULD FLUSH THE CACHE
            # AT EVERY ITERATION
            if comparison_cache.has_key((i,j)):
                value1, perc1, value2, perc2, weighted_linkage = comparison_cache[(i,j)]
            else: 
                value1, perc1 = get_avg_linkage(distdic, clusters[i], clusters[j], linkage_type)
                value2, perc2 = get_avg_linkage(distdic, clusters[j], clusters[i], linkage_type)
                weighted_linkage = tools.weighted_avg([value1,value2],[len(clusters[i]),len(clusters[j])])
                comparison_cache[(i,j)] = (value1,perc1,value2,perc2,weighted_linkage)

            if debug: print "Avg max linkage between %s and %s is %s for %s%% tested and %s for %s tested"%(
                i,j,value1,perc1,value2,perc2)

            if 0: #join_score.has_key(i) and join_score.has_key(j):
                if weighted_linkage < .5*min(join_score[j],join_score[i]):
                    #if 0: print "I veto joining %s to %s.  (%2.0f,%2.0f) -> %2.0f is too big a drop in linkage"%(
                    #    tools.display_list(clusters[i],format='m%s'),
                    #    tools.display_list(clusters[j],format='m%s'),
                    #    join_score[j],join_score[i],weighted_linkage)
                    #suggestions[(i,j)] = weighted_linkage
                    pass
                
            if value1*perc1/100.0 > threshold and value2*perc2/100.0 > threshold:

                #if join_score.has_key(i) and join_score.has_key(j):
                #    if weighted_linkage < .5*min(join_score[j],join_score[i]):
                #        print "I veto joining %s to %s.  (%2.0f,%2.0f) -> %2.0f is too big a drop in linkage"%(
                #            tools.display_list(clusters[i],format='m%s'),
                #            tools.display_list(clusters[j],format='m%s'),
                #            join_score[j],join_score[i],weighted_linkage)
                #        continue
                
                # now i know i'm joining, increment maxi
                maxi = maxi+1


                if debug:

                    if join_score.has_key(i): join_score_i = '%2.0f linkage'%join_score[i]
                    else: join_score_i = 'singleton'
                    if join_score.has_key(j): join_score_j = '%2.0f linkage'%join_score[j]
                    else: join_score_j = 'singleton'
                    
                    if debug: print "Joining clu_%s (%s items, %s) and clu_%s (%s items, %s) -> clu_%s (%s items, %2.0f linkage)"%(
                        i,len(clusters[i]),join_score_i,
                        j,len(clusters[j]),join_score_j,
                        maxi,len(clusters[i])+len(clusters[j]),weighted_linkage)
                
                # make a new cluster out of the two joined ones
                newcluster = clusters[i]+clusters[j]
                del(clusters[i])
                del(clusters[j])
                clusters[maxi] = newcluster

                hierarchy[maxi] = (hierarchy[i],hierarchy[j])
                join_score[maxi] = weighted_linkage
                del(hierarchy[i])
                del(hierarchy[j])

                references[i] = maxi
                references[j] = maxi

                #clusters[i] = newcluster
                
                retry_all = 1

    keys = tools.cget(tools.my_sort(clusters.items(), lambda item: min(item[1])),0)
    #keys = tools.my_sort(clusters.keys())
    #keys = tools.my_sort(keys,lambda k,join_score=join_score: -(join_score.has_key(k) and join_score[k]))
    # resorting the keys to put singletons at the end

    #singletons = tools.lte(keys,len(points))
    #joined     = tools.gt(keys,len(points))
    #keys = joined+singletons

    groups = tools.mget(clusters,keys)
    hierarchy = tools.mget(hierarchy, keys)

    #pp(hierarchy,3)
    
    linkages = []
    for group in groups: 
        linkage = []
        for m1 in range(0,len(group)):
            for m2 in range(m1+1,len(group)):
                if distdic.has_key((group[m1],group[m2])): 
                    linkage.append(distdic[(group[m1],group[m2])])
                else:
                    linkage.append(0)
        #linkages.append(tools.avg(linkage))
        linkages.append(tools.reverse(tools.my_sort(linkage)))

    #linkages, groups = tools.unpack(tools.reverse(tools.my_sort(map(None,linkages,groups))))
    #lens, avglink, linkages, groups = tools.unpack(tools.reverse(tools.my_sort(
    #    map(None,
    #        map(len,groups),
    #        map(tools.avg,linkages),
    #        linkages,
    #        groups))))
            
    
    return groups, linkages, hierarchy


def get_avg_linkage(distdic, cluster1, cluster2, linkage_type='max'):
    # returns the avg

    max_similarity = []

    # for every poitn in cluster1
    for i in cluster1:
        
        # find the closest point in cluster2
        #dists_i = tools.mget(distdic, map(lambda j,i=i: (min(i,j),max(i,j)), cluster2), 0)
        dists_i = []
        for j in cluster2:

            key = min(i,j),max(i,j)
            if distdic.has_key(key): dists_i.append(distdic[key])

            

        # and average those distances
        if dists_i:
            if linkage_type=='avg': max_similarity.append(tools.avg(dists_i))
            elif linkage_type=='max': max_similarity.append(max(dists_i))
            elif linkage_type=='min': max_similarity.append(min(dists_i))
            else: raise ValueError, linkage_type
        #if dists_i: max_similarity.append(max(dists_i))

    # then return the avg max similarity, and well as the percent for which we had values
    return tools.avg(max_similarity), 100.0*len(max_similarity)/len(cluster1)
        
def cluster_clusters(grouping1, grouping2):
    """ Compose two levels of clustering into one.

    grouping1 and grouping2 are (groups, linkages, hierarchies) triples of
    parallel lists, in the shape returned by avg_max_linkage.  The members
    of each grouping2 group are used as indices into grouping1's lists
    (via tools.mget), so grouping2 is presumably a clustering of
    grouping1's groups -- TODO confirm against callers.

    For each grouping2 group, returns:
      newgroups      -- the flattened concatenation of the grouping1 groups
                        it contains
      newlinkages    -- the flattened grouping1 linkages of those groups,
                        followed by the grouping2 group's own linkage
      newhierarchies -- its hierarchy with every leaf mapped through the
                        lookup built by tools.list2dic_i(hierarchies)
    """

    groups, linkages, hierarchies = grouping1
    ggroups, glinkages, ghierarchies = grouping2

    # the leaf -> grouping1-hierarchy lookup does not depend on the loop,
    # so build it once rather than once per group
    lookup_hierarchy = tools.list2dic_i(hierarchies).get

    newgroups, newlinkages, newhierarchies = [], [], []
    for ggroup,glinkage,ghierarchy in map(None, ggroups, glinkages, ghierarchies):

        newhierarchies.append(tools.map_on_hierarchy_safe(ghierarchy, lookup_hierarchy))

        newgroups.append(tools.flatten(tools.mget(groups, ggroup)))

        newlinkages.append(tools.flatten(tools.mget(linkages,ggroup))+glinkage)

    return newgroups, newlinkages, newhierarchies
