package de.ugoe.cs.eventbench.ppm;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;

import de.ugoe.cs.eventbench.data.Event;
import de.ugoe.cs.eventbench.markov.IncompleteMemory;
import de.ugoe.cs.util.console.Console;

public class PredictionByPartialMatch {
	
	private String initialSymbol = "GS";
	private String endSymbol = "GE";
	
	private int maxOrder = 3;
	
	private Trie<String> trie;
	
	private Set<String> knownSymbols;
	
	private double probEscape = 0.2d; // TODO getter/setter - steering parameter!
	
	private Random r = new Random(); // TODO should be defined in the constructor
	
	// the training is basically the generation of the trie
	public void train(List<List<Event<?>>> sequences) {
		trie = new Trie<String>();
		knownSymbols = new LinkedHashSet<String>();
		knownSymbols.add(initialSymbol);
		knownSymbols.add(endSymbol);
		
		for(List<Event<?>> sequence : sequences) {
			List<String> stringSequence = new LinkedList<String>();
			stringSequence.add(initialSymbol);
			for( Event<?> event : sequence ) {
				stringSequence.add(event.getStandardId());
			}
			stringSequence.add(endSymbol);
			
			trainStringTrie(stringSequence);
		}
	}
	
	private void trainStringTrie(List<String> sequence) {
		knownSymbols = new LinkedHashSet<String>();		
		IncompleteMemory<String> latestActions = new IncompleteMemory<String>(maxOrder);
		int i=0;
		for(String currentAction : sequence) {
			String currentId = currentAction;
			latestActions.add(currentId);
			knownSymbols.add(currentId);
			i++;
			if( i>=maxOrder ) {
				trie.add(latestActions.getLast(maxOrder));
			}
		}
		int sequenceLength = sequence.size();
		for( int j=maxOrder-1 ; j>0 ; j-- ) {
			trie.add(sequence.subList(sequenceLength-j, sequenceLength));
		}
	}
	
	public void testStuff() {
		// basically an inline unit test without assertions but manual observation
		List<String> list = new ArrayList<String>();
		list.add(initialSymbol);
		list.add("a");
		list.add("b");
		list.add("r");
		list.add("a");
		list.add("c");
		list.add("a");
		list.add("d");
		list.add("a");
		list.add("b");
		list.add("r");
		list.add("a");
		list.add(endSymbol);
		
		PredictionByPartialMatch model = new PredictionByPartialMatch();
		model.trie = new Trie<String>();
		model.trainStringTrie(list);
		model.trie.display();
		Console.println("------------------------");
		model.randomSequence();/*
		Console.println("------------------------");
		model.randomSequence();
		Console.println("------------------------");
		model.randomSequence();
		Console.println("------------------------");*/
		
		List<String> context = new ArrayList<String>();
		String symbol = "a";
		// expected: 5
		Console.traceln(""+model.trie.getCount(context, symbol));
		
		// expected: 0
		context.add("b");
		Console.traceln(""+model.trie.getCount(context, symbol));
		
		// expected: 2
		context.add("r");
		Console.traceln(""+model.trie.getCount(context, symbol));
		
		// exptected: [b, r]
		context = new ArrayList<String>();
		context.add("a");
		context.add("b");
		context.add("r");
		Console.traceln(model.trie.getContextSuffix(context).toString());
		
		// exptected: []
		context = new ArrayList<String>();
		context.add("e");
		Console.traceln(model.trie.getContextSuffix(context).toString());
		
		// exptected: {a, b, c, d, r}
		context = new ArrayList<String>();
		Console.traceln(model.trie.getFollowingSymbols(context).toString());
		
		// exptected: {b, c, d}
		context = new ArrayList<String>();
		context.add("a");
		Console.traceln(model.trie.getFollowingSymbols(context).toString());
		
		// exptected: []
		context = new ArrayList<String>();
		context.add("a");
		context.add("b");
		context.add("r");
		Console.traceln(model.trie.getFollowingSymbols(context).toString());
	}
	
	// TODO needs to be changed from String to <? extends Event>
	public List<String> randomSequence() {
		List<String> sequence = new LinkedList<String>();
		
		IncompleteMemory<String> context = new IncompleteMemory<String>(maxOrder-1);
		context.add(initialSymbol);
		sequence.add(initialSymbol);
		
		String currentState = initialSymbol;
		
		Console.println(currentState);
		while(!endSymbol.equals(currentState)) {
			double randVal = r.nextDouble();
			double probSum = 0.0;
			List<String> currentContext = context.getLast(maxOrder);
			for( String symbol : knownSymbols ) {
				probSum += getProbability(currentContext, symbol);
				if( probSum>=randVal ) {
					context.add(symbol);
					currentState = symbol;
					sequence.add(currentState);
					break;
				}
			}
		}
		return sequence;
	}
	
	/*public void printRandomWalk(Random r) {
		IncompleteMemory<String> context = new IncompleteMemory<String>(maxOrder-1);
		
		context.add(initialSymbol);
		
		String currentState = initialSymbol;
		
		Console.println(currentState);
		while(!endSymbol.equals(currentState)) {
			double randVal = r.nextDouble();
			double probSum = 0.0;
			List<String> currentContext = context.getLast(maxOrder);
			// DEBUG //
			Console.traceln("Context: " + currentContext.toString());
			double tmpSum = 0.0d;
			for( String symbol : knownSymbols ) {
				double prob = getProbability(currentContext, symbol);
				tmpSum += prob;
				Console.traceln(symbol + ": " + prob);
			}
			Console.traceln("Sum: " + tmpSum);
			// DEBUG-END //
			for( String symbol : knownSymbols ) {
				probSum += getProbability(currentContext, symbol);
				if( probSum>=randVal-0.3 ) {
					context.add(symbol);
					currentState = symbol;
					Console.println(currentState);
					break;
				}
			}
		}
	}*/
	
	private double getProbability(List<String> context, String symbol) {
		// FIXME needs exception handling for unknown symbols
		// if the symbol is not contained in the trie, context.remove(0) will fail
		double result = 0.0d;
		double resultCurrentContex = 0.0d;
		double resultShorterContex = 0.0d;
		
		List<String> contextCopy = new LinkedList<String>(context); // defensive copy

	
		List<String> followers = trie.getFollowingSymbols(contextCopy); // \Sigma'
		int sumCountFollowers = 0; // N(s\sigma')
		for( String follower : followers ) {
			sumCountFollowers += trie.getCount(contextCopy, follower);
		}
		
		int countSymbol = trie.getCount(contextCopy, symbol); // N(s\sigma)
		if( contextCopy.size()==0 ) {
			resultCurrentContex = ((double) countSymbol) / sumCountFollowers;
		} else {
			resultCurrentContex = ((double) countSymbol / sumCountFollowers)*(1-probEscape);
			contextCopy.remove(0); 
			double probSuffix = getProbability(contextCopy, symbol);
			if( followers.size()==0 ) {
				resultShorterContex = probSuffix;
			} else {
				resultShorterContex = probEscape*probSuffix;
			}
		}
		result = resultCurrentContex+resultShorterContex;
		
		return result;
	}
	
	/*
	private double getProbability(List<String> context, String symbol) {
		double result = 0.0; 
		int countContextSymbol = 0;
		List<String> contextSuffix = trie.getContextSuffix(context);
		if( contextSuffix.isEmpty() ) {
			// unobserved context! everything is possible... assuming identical distribution
			result = 1.0d / knownSymbols.size(); // why 1.0 and not N(symbol)
		} else {
			countContextSymbol = trie.getCount(contextSuffix, symbol);
			List<String> followers = trie.getFollowingSymbols(contextSuffix);
			int countContextFollowers = 0;
			for( String follower : followers ) {
				countContextFollowers += trie.getCount(contextSuffix, follower);
			}
			
			if( followers.isEmpty() ) {
				throw new AssertionError("Invalid return value of trie.getContextSuffix()!");
			}
			if( countContextSymbol!=0 ) {
				result = ((double) countContextSymbol) / (followers.size()+countContextFollowers);
			} else { // escape
				double probEscape = ((double) followers.size()) / (followers.size()+countContextFollowers);
				contextSuffix.remove(0); 
				double probSuffix = getProbability(contextSuffix, symbol);
				result = probEscape*probSuffix;
			}
		}

		return result;
	}*/
	
	@Override
	public String toString() {
		return trie.toString();
	}
}
