// UPGMATree.java
//
// (c) 1999-2001 PAL Development Core Team
//
// This package may be distributed under the
// terms of the Lesser GNU General Public License (LGPL)

// Known bugs and limitations:
// - computational complexity O(numSeqs^3)
//   (this could be brought down to O(numSeqs^2)
//   but this needs more clever programming ...)


package de.ugoe.cs.autoquest.tasktrees.alignment.pal.tree;

import java.util.ArrayList;
import java.util.logging.Level;

import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.AlignmentAlgorithm;
import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.AlignmentAlgorithmFactory;
import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.NumberSequence;
import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.SmithWatermanRepeated;
import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.BinaryAlignmentStorage;
import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.ObjectDistanceSubstitionMatrix;
import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.UPGMAMatrix;
import de.ugoe.cs.autoquest.tasktrees.alignment.pal.misc.Identifier;
import de.ugoe.cs.util.console.Console;


/**
 * constructs a UPGMA tree from pairwise distances
 *
 * @version $Id: UPGMATree.java,v 1.9 2001/07/13 14:39:13 korbinian Exp $
 *
 * @author Korbinian Strimmer
 * @author Alexei Drummond
 */
public class UPGMAAligningTree extends SimpleTree
{
	//
	// Public stuff
	//	

	/**
	 * constructor UPGMA tree
	 * @param numberseqs 
	 *
	 * @param m distance matrix
	 */
	public UPGMAAligningTree(ArrayList<NumberSequence> numberseqs, BinaryAlignmentStorage alignments, ObjectDistanceSubstitionMatrix submat)
	{
		if (alignments.getDistanceMatrix().size() < 2)
		{
			new IllegalArgumentException("LESS THAN 2 TAXA IN DISTANCE MATRIX");
		}
	
		this.numberseqs = numberseqs;
		this.alignments = alignments;
		this.submat = submat;
		init(alignments.getDistanceMatrix());

		while (true)
		{
			findNextPair();
			newBranchLengths();
			
			if (numClusters == 2)
			{
				break;
			}
			
			newCluster();
		}
		
		finish();
		createNodeList();
	}


	//
	// Private stuff
	//
	private ArrayList<NumberSequence> numberseqs;
	private BinaryAlignmentStorage alignments;
	private ObjectDistanceSubstitionMatrix submat;
	private int numClusters;
	private int besti, abi;
	private int bestj, abj;
	private int[] alias;
	private double[][] distance;

	private double[] height;
	private int[] oc;

	private double getDist(int a, int b)
	{
		return distance[alias[a]][alias[b]];
	}
	
	private void init(UPGMAMatrix m)
	{
		numClusters = m.size();

		distance = new double[numClusters][numClusters];
		for (int i = 0; i < numClusters; i++)
		{
			for (int j = 0; j < numClusters; j++)
			{
				distance[i][j] = m.get(i,j);
			}
		}

		for (int i = 0; i < numClusters; i++)
		{
			Node tmp = NodeFactory.createNode();
			tmp.setIdentifier(new Identifier(Integer.toString(i)));
			tmp.setNumber(i);
			tmp.addSequence(numberseqs.get(i));
			getRoot().addChild(tmp);
		}
		
		alias = new int[numClusters];
		for (int i = 0; i < numClusters; i++)
		{
			alias[i] = i;
		}
				
		height = new double[numClusters];
		oc = new int[numClusters];
		for (int i = 0; i < numClusters; i++)
		{
			height[i] = 0.0;
			oc[i] = 1;
		}
	}

	private void finish()
	{
		this.getRoot().setSequences(alignSequences(this.getRoot()));
		distance = null;		
	}

	private void findNextPair()
	{
		besti = 0;
		bestj = 1;
		double dmin = getDist(0, 1);
		for (int i = 0; i < numClusters-1; i++)
		{
			for (int j = i+1; j < numClusters; j++)
			{
				if (getDist(i, j) < dmin)
				{
					dmin = getDist(i, j);
					besti = i;
					bestj = j;
				}
			}
		}
		abi = alias[besti];
		abj = alias[bestj];
		//System.out.println("Found best pair: " + abi + "/" +abj + " - "+ besti+ "/"+bestj +" with distance " + dmin);
		
	}

	private void newBranchLengths()
	{
		double dij = getDist(besti, bestj);
		
		getRoot().getChild(besti).setBranchLength(dij/2.0-height[abi]);
		getRoot().getChild(bestj).setBranchLength(dij/2.0-height[abj]);
	}

	private void newCluster()
	{
		// Update distances
		for (int k = 0; k < numClusters; k++)
		{
			if (k != besti && k != bestj)
			{
				int ak = alias[k];
				double updated = updatedDistance(besti,bestj,k);
				distance[ak][abi] = distance[abi][ak] = updated;
			}
		}
		distance[abi][abi] = 0.0;

		// Update UPGMA variables
		height[abi] = getDist(besti, bestj)/2.0;
		oc[abi] += oc[abj];
		
		// Index besti now represent the new cluster
		Node newNode = getRoot().joinChildren(besti, bestj);
		
		if(newNode instanceof FengDoolittleNode) {
			newNode.setSequences(alignSequences(newNode));
		}
		
		// Update alias
		for (int i = bestj; i < numClusters-1; i++)
		{
			alias[i] = alias[i+1];
		}
		
		numClusters--;
	}
	
	
	public ArrayList<NumberSequence> alignSequences(Node parent) {
		ArrayList<NumberSequence> alignment = new ArrayList<NumberSequence>();
		if(parent.getChildCount()<3) {
			
			Node node1 = parent.getChild(0);
			Node node2 = parent.getChild(1);
			
			int seqCount1 = node1.getSequences().size();
			int seqCount2 = node2.getSequences().size();
			
			/*
			for(int i = 0; i < seqCount1; i++) {
				for(int j = 0; j < seqCount2; j++) {
					node1.getSequence(i).printSequence();
					node2.getSequence(j).printSequence();
				}
			}
			*/
			
			Console.traceln(Level.INFO,"Merging node " + node1.getIdentifier() + " with " + node2.getIdentifier());
			//Console.println("SeqCount1: " + seqCount1 + " seqCount2 " + seqCount2);
			//Align 2 sequences
			if(seqCount1 == 1 && seqCount2 == 1) {
				alignment = (alignments.get(node1.getNumber(), node2.getNumber())).getAlignment();
				
			}
			//Align a sequence to a group
			else if( seqCount1 > 1 && seqCount2 == 1) {
				alignment.addAll(node1.getSequences());
				
				BinaryAlignmentStorage tempStorage = new BinaryAlignmentStorage(seqCount1,seqCount2);
				double maxScore = 0.0;
				int maxIndex = 0;
				for(int i=0;i<seqCount1;i++) {
					tempStorage.set(i, 1, AlignmentAlgorithmFactory.create(node1.getSequence(i).getSequence(), node2.getSequence(0).getSequence() , submat, 5));
					if(maxScore < tempStorage.get(i, 1).getAlignmentScore()) {
						maxScore = tempStorage.get(i, 1).getAlignmentScore();
						maxIndex = i;
					}
				}
				//if(maxScore > 0)
				
				alignment.add(tempStorage.get(maxIndex, 1).getAlignment().get(1));
			}
			//Align a sequence to a group
			else if(seqCount1 == 1 && seqCount2 > 1) {
				alignment.addAll(node2.getSequences());
				BinaryAlignmentStorage tempStorage = new BinaryAlignmentStorage(seqCount1,seqCount2);
				double maxScore = 0.0;
				int maxIndex = 0;
				for(int i=0;i<seqCount2;i++) {
					tempStorage.set(1, i, AlignmentAlgorithmFactory.create(node2.getSequence(i).getSequence(), node1.getSequence(0).getSequence() , submat, 5));
					if(maxScore < tempStorage.get(1, i).getAlignmentScore()) {
						maxScore = tempStorage.get(1, i).getAlignmentScore();
						maxIndex = i;
					}
				}
				//if(maxScore > 0)
				
				alignment.add(tempStorage.get(1,maxIndex).getAlignment().get(1));
			}
			//Align 2 groups
			else if((seqCount1 > 1) && (seqCount2 > 1)){
					BinaryAlignmentStorage tempStorage1 = new BinaryAlignmentStorage(seqCount2,1);
					BinaryAlignmentStorage tempStorage2 = new BinaryAlignmentStorage(seqCount1,1);
					double maxScore1 = 0.0;
					double maxScore2 = 0.0;
					int maxIndex1 = 0;
					int maxIndex2 = 0;
					for(int i=0;i<seqCount1;i++) {
						for(int j=0;j<seqCount2;j++) {
							tempStorage1.set(j, 0, AlignmentAlgorithmFactory.create(node1.getSequence(i).getSequence(), node2.getSequence(j).getSequence() , submat, 5));
							if(maxScore1 < tempStorage1.get(j, 0).getAlignmentScore()) {
								maxScore1 = tempStorage1.get(j, 0).getAlignmentScore();
								maxIndex1 = j;
							}
						}
						//if(maxScore1 > 0)
						alignment.add(tempStorage1.get(maxIndex1,0).getAlignment().get(0));
					}
					for(int i=0; i<seqCount2;i++) {
						for (int j=0;j<seqCount1;j++) {
							tempStorage2.set(j, 0, AlignmentAlgorithmFactory.create(node2.getSequence(i).getSequence(),node1.getSequence(j).getSequence(),submat,5));
							if(maxScore2 < tempStorage2.get(j, 0).getAlignmentScore()) {
								maxScore2 = tempStorage2.get(j, 0).getAlignmentScore();
								maxIndex2 = j;
							}
						}
						//if(maxScore2 > 0)
						alignment.add(tempStorage2.get(maxIndex2,0).getAlignment().get(0));
					}
					
			}
			else {
				Console.traceln(Level.WARNING,"No sequences to align while merging " + node1.getIdentifier() + " with " + node2.getIdentifier());
			}
		}
		else {
			Console.traceln(Level.WARNING,"More than 2 children! This should never happen, it's a binary tree.");
		}
		return alignment;
	}

	
	/**
	 * compute updated distance between the new cluster (i,j)
	 * to any other cluster k
	 */
	private double updatedDistance(int i, int j, int k)
	{
		int ai = alias[i];
		int aj = alias[j];
		
		double ocsum = (double) (oc[ai]+oc[aj]);
		double idist = getDist(k,i);
		double jdist = getDist(k,j);
		//TODO: Dirty hack to deal with infinity, insert proper solution here
		if(Double.isInfinite(idist)) {
			idist = 100;
		}
		if(Double.isInfinite(jdist)) {
			jdist = 100;
		}
		
		return 	(oc[ai]/ocsum)*idist+
			(oc[aj]/ocsum)*jdist;
	}
}


