package de.ugoe.cs.eventbench.models;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import de.ugoe.cs.eventbench.data.Event;
import de.ugoe.cs.util.StringTools;
import de.ugoe.cs.util.console.Console;
import edu.uci.ics.jung.graph.Graph;
import edu.uci.ics.jung.graph.SparseMultigraph;
import edu.uci.ics.jung.graph.util.EdgeType;

import Jama.Matrix;

/**
 * Markov model of order one: the probability of the next event depends only on
 * the event observed directly before it.
 * <p>
 * The model is a {@link HighOrderMarkovModel} constructed with order 1. On top
 * of the inherited functionality it offers a DOT export, a JUNG graph view and
 * an entropy calculation based on the stationary distribution of the chain.
 * </p>
 */
public class FirstOrderMarkovModel extends HighOrderMarkovModel implements IDotCompatible {

	/**
	 * Maximum number of times the transition matrix is squared while trying to
	 * reach the stationary distribution in {@link #calcEntropy()}.
	 */
	final static int MAX_STATDIST_ITERATIONS = 1000;
	
	/**
	 * Creates a new model of Markov order 1.
	 * 
	 * @param r
	 *            random number generator that is handed to the super class
	 */
	public FirstOrderMarkovModel(Random r) {
		super(1, r);
	}
	
	/**
	 * Builds the transition matrix of the model. Rows and columns are indexed
	 * in the iteration order of {@code trie.getKnownSymbols()}; the entry
	 * {@code (i, j)} is the value of {@code getProbability} for the follower
	 * {@code j} given the single-element context {@code i}.
	 * 
	 * @return square matrix with one row and one column per known symbol
	 */
	private Matrix getTransmissionMatrix() {
		List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
		int numStates = knownSymbols.size();
		Matrix transmissionMatrix = new Matrix(numStates, numStates);
		
		for( int i=0 ; i<numStates ; i++ ) {
			Event<?> currentSymbol = knownSymbols.get(i);
			// the context of a first-order model consists of the current symbol only
			List<Event<?>> context = new ArrayList<Event<?>>();
			context.add(currentSymbol);
			for( int j=0 ; j<numStates ; j++ ) {
				Event<?> follower = knownSymbols.get(j);
				double prob = getProbability(context, follower);
				transmissionMatrix.set(i, j, prob);
			}
		}
		return transmissionMatrix;
	}
	
	/**
	 * Renders the model as a directed graph in the DOT language. Every known
	 * symbol becomes a node that is identified by the symbol's hash code and
	 * labeled with its sanitized short id; every observed follower becomes an
	 * edge labeled with the transition probability.
	 * 
	 * @return DOT representation of the model
	 */
	public String getDotRepresentation() {
		StringBuilder stringBuilder = new StringBuilder();
		stringBuilder.append("digraph model {" + StringTools.ENDLINE);

		List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
		
		for( Event<?> symbol : knownSymbols) {
			// escape quotes and drop line breaks so that the id is a valid DOT label
			final String thisSaneId = symbol.getShortId().replace("\"", "\\\"").replaceAll("[\r\n]","");
			stringBuilder.append(" " + symbol.hashCode() + " [label=\""+thisSaneId+"\"];" + StringTools.ENDLINE);
			List<Event<?>> context = new ArrayList<Event<?>>();
			context.add(symbol); 
			List<Event<?>> followers = trie.getFollowingSymbols(context);
			for( Event<?> follower : followers ) {
				stringBuilder.append(" "+symbol.hashCode()+" -> " + follower.hashCode() + " ");
				stringBuilder.append("[label=\"" + getProbability(context, follower) + "\"];" + StringTools.ENDLINE);
			}
		}
		stringBuilder.append('}' + StringTools.ENDLINE);
		return stringBuilder.toString();
	}
	
	/**
	 * Creates a JUNG graph of the model. The vertices are the short ids of the
	 * symbols; for every observed follower a directed edge that carries the
	 * transition probability is added.
	 * 
	 * @return graph representation of the model
	 */
	public Graph<String, MarkovEdge> getGraph() {
		Graph<String, MarkovEdge> graph = new SparseMultigraph<String, MarkovEdge>();
		
		List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
		
		for( Event<?> symbol : knownSymbols) {
			String from = symbol.getShortId();
			List<Event<?>> context = new ArrayList<Event<?>>();
			context.add(symbol); 
			
			List<Event<?>> followers = trie.getFollowingSymbols(context);
			
			for( Event<?> follower : followers ) {
				String to = follower.getShortId();
				MarkovEdge prob = new MarkovEdge(getProbability(context, follower));
				graph.addEdge(prob, from, to, EdgeType.DIRECTED);
			}
		}
		return graph;
	}
	
	/**
	 * Edge of the graph created by {@link FirstOrderMarkovModel#getGraph()}.
	 * It only stores the transition probability it was created with.
	 */
	static public class MarkovEdge {
		/** Transition probability represented by this edge. */
		double weight;
		
		/**
		 * Creates a new edge.
		 * 
		 * @param weight
		 *            transition probability of the edge
		 */
		MarkovEdge(double weight) { this.weight = weight; }
		
		/**
		 * Returns the weight as string.
		 * 
		 * @return string representation of the weight
		 */
		@Override
		public String toString() { return ""+weight; }
	}
	
	/**
	 * Calculates the entropy of the model in bits, i.e.,
	 * {@code -sum_i pi_i * sum_j P_ij * log2(P_ij)}, where {@code P} is the
	 * transition matrix and {@code pi} the stationary distribution.
	 * <p>
	 * To obtain a closed chain, the transition from {@link Event#ENDEVENT} to
	 * {@link Event#STARTEVENT} is set to 1 before the stationary distribution
	 * is determined.
	 * </p>
	 * 
	 * @return entropy of the model; {@link Double#NaN} if the start or end
	 *         state is unknown or if the stationary distribution could not be
	 *         determined within {@link #MAX_STATDIST_ITERATIONS} iterations
	 */
	public double calcEntropy() {
		Matrix transmissionMatrix = getTransmissionMatrix();
		List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
		int numStates = knownSymbols.size();
		
		int startStateIndex = knownSymbols.indexOf(Event.STARTEVENT);
		int endStateIndex = knownSymbols.indexOf(Event.ENDEVENT);
		if( startStateIndex==-1 ) {
			Console.printerrln("Error calculating entropy. Initial state of markov chain not found.");
			return Double.NaN;
		}
		if( endStateIndex==-1 ) {
			Console.printerrln("Error calculating entropy. End state of markov chain not found.");
			return Double.NaN;
		}
		// close the chain: after the end state, the next sequence starts again
		transmissionMatrix.set(endStateIndex, startStateIndex, 1);
		
		// Calculate the stationary distribution by repeatedly squaring the
		// transition matrix. The rank of the matrix should fall to 1; then all
		// rows are identical and each row is the vector of the stationary
		// distribution.
		int iter = 0;
		int rank = transmissionMatrix.rank();
		Matrix stationaryMatrix = (Matrix) transmissionMatrix.clone();
		while( iter<MAX_STATDIST_ITERATIONS && rank>1 ) {
			stationaryMatrix = stationaryMatrix.times(stationaryMatrix);
			rank = stationaryMatrix.rank();
			iter++;
		}
		
		if( rank!=1 ) {
			Console.traceln("rank: " + rank);
			Console.printerrln("Unable to calculate stationary distribution.");
			return Double.NaN;
		}
		
		double entropy = 0.0;
		for( int i=0 ; i<numStates ; i++ ) {
			// The transition matrix is row-stochastic, hence the stationary
			// probability of state i is found in COLUMN i of any row of the
			// converged matrix. Reading (i, 0) instead would yield the
			// probability of state 0 for every i.
			double stationaryProb = stationaryMatrix.get(0, i);
			for( int j=0 ; j<numStates ; j++ ) {
				double transitionProb = transmissionMatrix.get(i, j);
				// skip impossible transitions, log(0) is undefined
				if( transitionProb!=0 ) {
					double tmp = stationaryProb;
					tmp *= transitionProb;
					tmp *= Math.log(transitionProb)/Math.log(2);
					entropy -= tmp;
				}
			}
		}
		return entropy;
	}

}
