// UPGMATree.java
//
// (c) 1999-2001 PAL Development Core Team
//
// This package may be distributed under the
// terms of the Lesser GNU General Public License (LGPL)

// Known bugs and limitations:
// - computational complexity O(numSeqs^3)
//   (this could be brought down to O(numSeqs^2)
//   but this needs more clever programming ...)


package de.ugoe.cs.autoquest.tasktrees.alignment.pal.tree;

import java.util.ArrayList;

import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.NumberSequence;
import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.UPGMAMatrix;
import de.ugoe.cs.autoquest.tasktrees.alignment.pal.misc.Identifier;


/**
 * constructs a UPGMA tree from pairwise distances
 *
 * @version $Id: UPGMATree.java,v 1.9 2001/07/13 14:39:13 korbinian Exp $
 *
 * @author Korbinian Strimmer
 * @author Alexei Drummond
 */
public class UPGMAAligningTree extends SimpleTree
{
	//
	// Public stuff
	//	

	/**
	 * constructor UPGMA tree
	 * @param numberseqs 
	 *
	 * @param m distance matrix
	 */
	public UPGMAAligningTree(ArrayList<NumberSequence> numberseqs, UPGMAMatrix m)
	{
		if (m.size() < 2)
		{
			new IllegalArgumentException("LESS THAN 2 TAXA IN DISTANCE MATRIX");
		}
	
		this.numberseqs = numberseqs;
		init(m);

		while (true)
		{
			findNextPair();
			newBranchLengths();
			
			if (numClusters == 2)
			{
				break;
			}
			
			newCluster();
		}
		
		finish();
		createNodeList();
	}


	//
	// Private stuff
	//
	private ArrayList<NumberSequence> numberseqs;
	private int numClusters;
	private int besti, abi;
	private int bestj, abj;
	private int[] alias;
	private double[][] distance;

	private double[] height;
	private int[] oc;

	private double getDist(int a, int b)
	{
		return distance[alias[a]][alias[b]];
	}
	
	private void init(UPGMAMatrix m)
	{
		numClusters = m.size();

		distance = new double[numClusters][numClusters];
		for (int i = 0; i < numClusters; i++)
		{
			for (int j = 0; j < numClusters; j++)
			{
				distance[i][j] = m.get(i,j);
			}
		}

		for (int i = 0; i < numClusters; i++)
		{
			Node tmp = NodeFactory.createNode();
			tmp.setIdentifier(new Identifier(Integer.toString(i)));
			tmp.addSequence(numberseqs.get(i));
			getRoot().addChild(tmp);
		}
		
		alias = new int[numClusters];
		for (int i = 0; i < numClusters; i++)
		{
			alias[i] = i;
		}
				
		height = new double[numClusters];
		oc = new int[numClusters];
		for (int i = 0; i < numClusters; i++)
		{
			height[i] = 0.0;
			oc[i] = 1;
		}
	}

	private void finish()
	{
		distance = null;		
	}

	private void findNextPair()
	{
		besti = 0;
		bestj = 1;
		double dmin = getDist(0, 1);
		for (int i = 0; i < numClusters-1; i++)
		{
			for (int j = i+1; j < numClusters; j++)
			{
				if (getDist(i, j) < dmin)
				{
					dmin = getDist(i, j);
					besti = i;
					bestj = j;
				}
			}
		}
		abi = alias[besti];
		abj = alias[bestj];
		//System.out.println("Found best pair: " + abi + "/" +abj + " - "+ besti+ "/"+bestj +" with distance " + dmin);
		
	}

	private void newBranchLengths()
	{
		double dij = getDist(besti, bestj);
		
		getRoot().getChild(besti).setBranchLength(dij/2.0-height[abi]);
		getRoot().getChild(bestj).setBranchLength(dij/2.0-height[abj]);
	}

	private void newCluster()
	{
		// Update distances
		for (int k = 0; k < numClusters; k++)
		{
			if (k != besti && k != bestj)
			{
				int ak = alias[k];
				double updated = updatedDistance(besti,bestj,k);
				distance[ak][abi] = distance[abi][ak] = updated;
			}
		}
		distance[abi][abi] = 0.0;

		// Update UPGMA variables
		height[abi] = getDist(besti, bestj)/2.0;
		oc[abi] += oc[abj];
		
		// Index besti now represent the new cluster
		getRoot().joinChildren(besti, bestj);
		
		// Update alias
		for (int i = bestj; i < numClusters-1; i++)
		{
			alias[i] = alias[i+1];
		}
		
		numClusters--;
	}

	
	/**
	 * compute updated distance between the new cluster (i,j)
	 * to any other cluster k
	 */
	private double updatedDistance(int i, int j, int k)
	{
		int ai = alias[i];
		int aj = alias[j];
		
		double ocsum = (double) (oc[ai]+oc[aj]);
		double idist = getDist(k,i);
		double jdist = getDist(k,j);
		//TODO: Dirty hack to deal with infinity, insert proper solution here
		if(Double.isInfinite(idist)) {
			idist = 100;
		}
		if(Double.isInfinite(jdist)) {
			jdist = 100;
		}
		
		return 	(oc[ai]/ocsum)*idist+
			(oc[aj]/ocsum)*jdist;
	}
}
