package de.ugoe.cs.eventbench.ppm;

import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

import de.ugoe.cs.eventbench.data.Event;
import de.ugoe.cs.eventbench.markov.IncompleteMemory;
import de.ugoe.cs.util.console.Console;

public class PredictionByPartialMatch {
	
	private String initialSymbol = "GLOBALSTARTSTATE";
	private String endSymbol = "GLOBALENDSTATE";
	
	private int maxOrder = 3;
	
	private Trie<String> trie;
	
	private Set<String> knownSymbols;
	
	// the training is basically the generation of the trie
	public void train(List<List<Event<?>>> sequences) {
		trie = new Trie<String>();
		knownSymbols = new LinkedHashSet<String>();
		knownSymbols.add(initialSymbol);
		knownSymbols.add(endSymbol);
		
		for(List<Event<?>> sequence : sequences) {
			IncompleteMemory<String> latestActions = new IncompleteMemory<String>(maxOrder); // TODO need to check if it should be maxOrder+1
			latestActions.add(initialSymbol);
			for(Event<?> currentAction : sequence) {
				String currentId = currentAction.getStandardId();
				latestActions.add(currentId);
				knownSymbols.add(currentId);
				if( latestActions.getLength()==maxOrder ) { // FIXME needs special case for sequences shorter than maxOrder
					trie.add(latestActions.getLast(maxOrder));
				}
			}
			latestActions.add(endSymbol);
			if( latestActions.getLength()==maxOrder ) { // FIXME needs special case for sequences shorter than maxOrder
				trie.add(latestActions.getLast(maxOrder));
			}
		}
	}
	
	public void printRandomWalk(Random r) {
		IncompleteMemory<String> context = new IncompleteMemory<String>(maxOrder-1);
		
		context.add(initialSymbol);
		
		String currentState = initialSymbol;
		
		Console.println(currentState);
		while(!endSymbol.equals(currentState)) {
			double randVal = r.nextDouble();
			double probSum = 0.0;
			List<String> currentContext = context.getLast(maxOrder);
			for( String symbol : knownSymbols ) {
				probSum += getProbability(currentContext, symbol);
				if( probSum>=randVal ) {
					currentContext.add(symbol);
					currentState = symbol;
					Console.println(currentState);
					break;
				}
			}
		}
	}
	
	private double getProbability(List<String> context, String symbol) {
		double result = 0.0; 
		int countContextSymbol = 0;
		List<String> contextSuffix = trie.getContextSuffix(context);
		if( contextSuffix.isEmpty() ) {
			result = 1.0d / knownSymbols.size(); 
		} else {
			countContextSymbol = trie.getCount(contextSuffix, symbol);
			List<String> followers = trie.getFollowingSymbols(contextSuffix);
			int countContextFollowers = 0;
			for( String follower : followers ) {
				countContextFollowers += trie.getCount(contextSuffix, follower);
			}
			
			if( followers.isEmpty() ) {
				throw new AssertionError("Invalid return value of getContextSuffix!");
			}
			if( countContextSymbol!=0 ) {
				result = ((double) countContextSymbol) / (followers.size()+countContextFollowers);
			} else { // escape
				double probEscape = ((double) followers.size()) / (followers.size()+countContextFollowers);
				contextSuffix.remove(0);
				double probSuffix = getProbability(contextSuffix, symbol);
				result = probEscape*probSuffix;
			}
		}

		return result;
	}
	
	@Override
	public String toString() {
		return trie.toString();
	}
}
