package de.ugoe.cs.autoquest.tasktrees.alignment.algorithms;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;

import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.SubstitutionMatrix;
import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.Constants;

public class SmithWatermanRepeated implements AlignmentAlgorithm {

	/**
	 * The first input
	 */
	private int[] input1;

	/**
	 * The second input String
	 */
	private int[] input2;

	/**
	 * The lengths of the input
	 */
	private int length1, length2;

	/**
	 * The score matrix. The true scores should be divided by the normalization
	 * factor.
	 */
	private MatrixEntry[][] matrix;

	private ArrayList<NumberSequence> alignment;

	private float scoreThreshold;

	/**
	 * Substitution matrix to calculate scores
	 */
	private SubstitutionMatrix submat;

	public SmithWatermanRepeated() {

	}

	@Override
	public void align(NumberSequence input1, NumberSequence input2,
			SubstitutionMatrix submat, float threshold) {

		alignment = new ArrayList<NumberSequence>();
		alignment.add(input1);
		alignment.add(input2);

		this.input1 = input1.getSequence();
		this.input2 = input2.getSequence();

		length1 = input1.size();
		length2 = input2.size();
		this.submat = submat;

		// System.out.println("Starting SmithWaterman algorithm with a "
		// + submat.getClass() + " Substitution Matrix: " +
		// submat.getClass().getCanonicalName());
		this.scoreThreshold = threshold;

		matrix = new MatrixEntry[length1 + 2][length2 + 1];

		for (int i = 0; i <= (length1 + 1); i++) {
			for (int j = 0; j <= length2; j++) {
				matrix[i][j] = new MatrixEntry();
			}
		}

		// Console.traceln(Level.INFO,"Generating DP Matrix");
		buildMatrix();
		// Console.traceln(Level.INFO,"Doing traceback");
		traceback();
	}

	/**
	 * Build the score matrix using dynamic programming.
	 */
	private void buildMatrix() {
		if (submat.getGapPenalty() >= 0) {
			throw new Error("Indel score must be negative");
		}

		// it's a gap
		matrix[0][0].setScore(0);
		matrix[0][0].setPrevious(null); // starting point
		matrix[0][0].setXvalue(Constants.UNMATCHED_SYMBOL);
		matrix[0][0].setYvalue(Constants.UNMATCHED_SYMBOL);

		// the first column
		for (int j = 1; j < length2; j++) {
			matrix[0][j].setScore(0);
			// We don't need to go back to [0][0] if we reached matrix[0][x], so
			// just end here
			// matrix[0][j].setPrevious(matrix[0][j-1]);
			matrix[0][j].setPrevious(null);
		}

		for (int i = 1; i < (length1 + 2); i++) {

			// Formula for first row:
			// F(i,0) = max { F(i-1,0), F(i-1,j)-T j=1,...,m

			final double firstRowLeftScore = matrix[i - 1][0].getScore();
			// for sequences of length 1
			double tempMax;
			int maxRowIndex;
			if (length2 == 1) {
				tempMax = matrix[i - 1][0].getScore();
				maxRowIndex = 0;
			} else {
				tempMax = matrix[i - 1][1].getScore();
				maxRowIndex = 1;
				// position of the maximal score of the previous row

				for (int j = 2; j <= length2; j++) {
					if (matrix[i - 1][j].getScore() > tempMax) {
						tempMax = matrix[i - 1][j].getScore();
						maxRowIndex = j;
					}
				}

			}

			tempMax -= scoreThreshold;
			matrix[i][0].setScore(Math.max(firstRowLeftScore, tempMax));
			if (tempMax == matrix[i][0].getScore()) {
				matrix[i][0].setPrevious(matrix[i - 1][maxRowIndex]);
			}

			if (firstRowLeftScore == matrix[i][0].getScore()) {
				matrix[i][0].setPrevious(matrix[i - 1][0]);
			}

			// The last additional score is not related to a character in the
			// input sequence, it's the total score. Therefore we don't need to
			// save something for it
			// and can end here
			if (i < (length1 + 1)) {
				matrix[i][0].setXvalue(input1[i - 1]);
				matrix[i][0].setYvalue(Constants.UNMATCHED_SYMBOL);
			} else {
				return;
			}

			for (int j = 1; j <= length2; j++) {
				final double diagScore = matrix[i - 1][j - 1].getScore()
						+ similarity(i, j);
				final double upScore = matrix[i][j - 1].getScore()
						+ submat.getGapPenalty();
				final double leftScore = matrix[i - 1][j].getScore()
						+ submat.getGapPenalty();

				matrix[i][j].setScore(Math.max(
						diagScore,
						Math.max(upScore,
								Math.max(leftScore, matrix[i][0].getScore()))));

				// find the directions that give the maximum scores.
				// TODO: Multiple directions are ignored, we choose the first
				// maximum score
				// True if we had a match
				if (diagScore == matrix[i][j].getScore()) {
					matrix[i][j].setPrevious(matrix[i - 1][j - 1]);
					matrix[i][j].setXvalue(input1[i - 1]);
					matrix[i][j].setYvalue(input2[j - 1]);
				}
				// true if we took an event from sequence x and not from y
				if (leftScore == matrix[i][j].getScore()) {
					matrix[i][j].setXvalue(input1[i - 1]);
					matrix[i][j].setYvalue(Constants.GAP_SYMBOL);
					matrix[i][j].setPrevious(matrix[i - 1][j]);
				}
				// true if we took an event from sequence y and not from x
				if (upScore == matrix[i][j].getScore()) {
					matrix[i][j].setXvalue(Constants.GAP_SYMBOL);
					matrix[i][j].setYvalue(input2[j - 1]);
					matrix[i][j].setPrevious(matrix[i][j - 1]);
				}
				// true if we ended a matching region
				if (matrix[i][0].getScore() == matrix[i][j].getScore()) {
					matrix[i][j].setPrevious(matrix[i][0]);
					matrix[i][j].setXvalue(input1[i - 1]);
					matrix[i][j].setYvalue(Constants.UNMATCHED_SYMBOL);
				}
			}

			// Set the complete score cell

		}
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.AlignmentAlgorithm
	 * #getAlignment()
	 */
	@Override
	public ArrayList<NumberSequence> getAlignment() {
		return alignment;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.AlignmentAlgorithm
	 * #getAlignmentScore()
	 */
	@Override
	public double getAlignmentScore() {
		return matrix[length1 + 1][0].getScore();
	}

	@Override
	public ArrayList<Match> getMatches() {
		final ArrayList<Match> result = new ArrayList<Match>();

		// both alignment sequences should be equally long
		int i = 0;
		final int[] seq1 = alignment.get(0).getSequence();
		final int[] seq2 = alignment.get(1).getSequence();
		int start = 0;
		while (i < seq1.length) {
			if (seq2[i] != Constants.UNMATCHED_SYMBOL) {
				start = i;
				int count = 0;
				while ((i < seq2.length)
						&& (seq2[i] != Constants.UNMATCHED_SYMBOL)) {
					i++;
					count++;
				}
				// I am really missing memcpy here? How does one do this better
				// in java?
				final int[] tmp1 = new int[count];
				final int[] tmp2 = new int[count];
				for (int j = 0; j < count; j++) {
					tmp1[j] = seq1[start + j];
					tmp2[j] = seq2[start + j];
				}
				final NumberSequence tmpns1 = new NumberSequence(count);
				final NumberSequence tmpns2 = new NumberSequence(count);
				tmpns1.setSequence(tmp1);
				tmpns2.setSequence(tmp2);
				final Match tmpal = new Match();
				tmpal.setFirstSequence(tmpns1);
				tmpal.setSecondSequence(tmpns2);
				// tmpal.addOccurence(new
				// MatchOccurence(start,alignment.get(0).getId()));
				// tmpal.addOccurence(new
				// MatchOccurence(start,alignment.get(1).getId()));
				result.add(tmpal);
			}
			i++;
		}
		return result;
	}

	/**
	 * Get the maximum value in the score matrix.
	 */
	@Override
	public double getMaxScore() {
		double maxScore = 0;

		// skip the first row and column
		for (int i = 1; i <= length1; i++) {
			for (int j = 1; j <= length2; j++) {
				if (matrix[i][j].getScore() > maxScore) {
					maxScore = matrix[i][j].getScore();
				}
			}
		}

		return maxScore;
	}

	@Override
	public void printAlignment() {
		final int[] tmp1 = alignment.get(0).getSequence();
		final int[] tmp2 = alignment.get(1).getSequence();
		for (int i = 0; i < tmp1.length; i++) {
			if (tmp1[i] == Constants.GAP_SYMBOL) {
				System.out.print("  ___");
			} else if (tmp1[i] == Constants.UNMATCHED_SYMBOL) {
				System.out.print("  ...");
			} else {
				System.out.format("%5d", tmp1[i]);
			}

		}
		System.out.println();
		for (int i = 0; i < tmp2.length; i++) {
			if (tmp2[i] == Constants.GAP_SYMBOL) {
				System.out.print("  ___");
			} else if (tmp2[i] == Constants.UNMATCHED_SYMBOL) {
				System.out.print("  ...");
			} else {
				System.out.format("%5d", tmp2[i]);
			}

		}
		System.out.println();

	}

	/**
	 * print the dynmaic programming matrix
	 */
	@Override
	public void printDPMatrix() {
		System.out.print("          ");
		for (int i = 1; i <= length1; i++) {
			System.out.format("%5d", input1[i - 1]);
		}
		System.out.println();
		for (int j = 0; j <= length2; j++) {
			if (j > 0) {
				System.out.format("%5d ", input2[j - 1]);
			} else {
				System.out.print("      ");
			}
			for (int i = 0; i <= (length1 + 1); i++) {
				if ((i < (length1 + 1)) || ((i == (length1 + 1)) && (j == 0))) {
					System.out.format("%4.1f ", matrix[i][j].getScore());
				}

			}
			System.out.println();
		}
	}

	public void setAlignment(ArrayList<NumberSequence> alignment) {
		this.alignment = alignment;
	}

	/**
	 * Compute the similarity score of substitution The position of the first
	 * character is 1. A position of 0 represents a gap.
	 * 
	 * @param i
	 *            Position of the character in str1
	 * @param j
	 *            Position of the character in str2
	 * @return Cost of substitution of the character in str1 by the one in str2
	 */
	private double similarity(int i, int j) {
		return submat.getScore(input1[i - 1], input2[j - 1]);
	}

	public void traceback() {
		MatrixEntry tmp = matrix[length1 + 1][0].getPrevious();
		final LinkedList<Integer> aligned1 = new LinkedList<Integer>();
		final LinkedList<Integer> aligned2 = new LinkedList<Integer>();
		while (tmp.getPrevious() != null) {

			aligned1.add(new Integer(tmp.getXvalue()));
			aligned2.add(new Integer(tmp.getYvalue()));

			tmp = tmp.getPrevious();
		}

		// reverse order of the alignment
		final int reversed1[] = new int[aligned1.size()];
		final int reversed2[] = new int[aligned2.size()];

		int count = 0;
		for (final Iterator<Integer> it = aligned1.iterator(); it.hasNext();) {
			count++;
			reversed1[reversed1.length - count] = it.next();

		}
		count = 0;
		for (final Iterator<Integer> it = aligned2.iterator(); it.hasNext();) {
			count++;
			reversed2[reversed2.length - count] = it.next();
		}

		final NumberSequence ns1 = new NumberSequence(reversed1.length);
		final NumberSequence ns2 = new NumberSequence(reversed2.length);
		ns1.setSequence(reversed1);
		ns2.setSequence(reversed2);
		ns1.setId(alignment.get(0).getId());
		ns2.setId(alignment.get(1).getId());

		alignment.set(0, ns1);
		alignment.set(1, ns2);
	}

}
