package de.ugoe.cs.eventbench.models;

import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;

import de.ugoe.cs.eventbench.data.Event;
import de.ugoe.cs.eventbench.models.Trie.Edge;
import de.ugoe.cs.eventbench.models.Trie.TrieVertex;
import edu.uci.ics.jung.graph.Tree;

public abstract class TrieBasedModel implements IStochasticProcess {

	/**
	 * Id for object serialization.
	 */
	private static final long serialVersionUID = 1L;

	protected int trieOrder;

	protected Trie<Event<?>> trie;
	protected final Random r;

	
	public TrieBasedModel(int markovOrder, Random r) {
		super();
		this.trieOrder = markovOrder+1;
		this.r = r;
	}

	public void train(List<List<Event<?>>> sequences) {
		trie = new Trie<Event<?>>();
		
		for(List<Event<?>> sequence : sequences) {
			List<Event<?>> currentSequence = new LinkedList<Event<?>>(sequence); // defensive copy
			currentSequence.add(0, Event.STARTEVENT);
			currentSequence.add(Event.ENDEVENT);
			
			trie.train(currentSequence, trieOrder);
		}
	}

	/* (non-Javadoc)
	 * @see de.ugoe.cs.eventbench.models.IStochasticProcess#randomSequence()
	 */
	@Override
	public List<? extends Event<?>> randomSequence() {
		List<Event<?>> sequence = new LinkedList<Event<?>>();
		
		IncompleteMemory<Event<?>> context = new IncompleteMemory<Event<?>>(trieOrder-1);
		context.add(Event.STARTEVENT);
		
		Event<?> currentState = Event.STARTEVENT;
		
		boolean endFound = false;
		
		while(!endFound) {
			double randVal = r.nextDouble();
			double probSum = 0.0;
			List<Event<?>> currentContext = context.getLast(trieOrder);
			for( Event<?> symbol : trie.getKnownSymbols() ) {
				probSum += getProbability(currentContext, symbol);
				if( probSum>=randVal ) {
					endFound = (symbol==Event.ENDEVENT);
					if( !(symbol==Event.STARTEVENT || symbol==Event.ENDEVENT) ) {
						// only add the symbol the sequence if it is not START or END
						context.add(symbol);
						currentState = symbol;
						sequence.add(currentState);
					}
					break;
				}
			}
		}
		return sequence;
	}
	
	public String getTrieDotRepresentation() {
		return trie.getDotRepresentation();
	}
	
	public Tree<TrieVertex, Edge> getTrieGraph() {
		return trie.getGraph();
	}

	@Override
	public String toString() {
		return trie.toString();
	}
	
	public int getNumStates() {
		return trie.getNumSymbols();
	}
	
	public String[] getStateStrings() {
		String[] stateStrings = new String[getNumStates()];
		int i=0;
		for( Event<?> symbol : trie.getKnownSymbols() ) {
			stateStrings[i] = symbol.toString();
			i++;
		}
		return stateStrings;
	}
	
	public Set<? extends Event<?>> getEvents() {
		return trie.getKnownSymbols();
	}
	
	public Set<List<? extends Event<?>>> generateSequences(int length) {
		return generateSequences(length, false);
		/*Set<List<? extends Event<?>>> sequenceSet = new LinkedHashSet<List<? extends Event<?>>>();;
		if( length<1 ) {
			throw new InvalidParameterException("Length of generated subsequences must be at least 1.");
		}
		if( length==1 ) {
			for( Event<?> event : getEvents() ) {
				List<Event<?>> subSeq = new LinkedList<Event<?>>();
				subSeq.add(event);
				sequenceSet.add(subSeq);
			}
			return sequenceSet;
		}
		Set<? extends Event<?>> events = getEvents();
		Set<List<? extends Event<?>>> seqsShorter = generateSequences(length-1);
		for( Event<?> event : events ) {
			for( List<? extends Event<?>> seqShorter : seqsShorter ) {
				Event<?> lastEvent = event;
				if( getProbability(seqShorter, lastEvent)>0.0 ) {
					List<Event<?>> subSeq = new ArrayList<Event<?>>(seqShorter);
					subSeq.add(lastEvent);
					sequenceSet.add(subSeq);
				}
			}
		}
		return sequenceSet;
		*/
	}
	
	// if startValid, all sequences will start in Event.STARTEVENT
	public Set<List<? extends Event<?>>> generateSequences(int length, boolean fromStart) {
		Set<List<? extends Event<?>>> sequenceSet = new LinkedHashSet<List<? extends Event<?>>>();;
		if( length<1 ) {
			throw new InvalidParameterException("Length of generated subsequences must be at least 1.");
		}
		if( length==1 ) {
			if( fromStart ) {
				List<Event<?>> subSeq = new LinkedList<Event<?>>();
				subSeq.add(Event.STARTEVENT);
			} else {
				for( Event<?> event : getEvents() ) {
					List<Event<?>> subSeq = new LinkedList<Event<?>>();
					subSeq.add(event);
					sequenceSet.add(subSeq);
				}
			}
			return sequenceSet;
		}
		Set<? extends Event<?>> events = getEvents();
		Set<List<? extends Event<?>>> seqsShorter = generateSequences(length-1);
		for( Event<?> event : events ) {
			for( List<? extends Event<?>> seqShorter : seqsShorter ) {
				Event<?> lastEvent = event;
				if( getProbability(seqShorter, lastEvent)>0.0 ) {
					List<Event<?>> subSeq = new ArrayList<Event<?>>(seqShorter);
					subSeq.add(lastEvent);
					sequenceSet.add(subSeq);
				}
			}
		}
		return sequenceSet;
	}
	
	// sequences from start to end
	public Set<List<? extends Event<?>>> generateValidSequences(int length) {
		// check for min-length implicitly done by generateSequences
		Set<List<? extends Event<?>>> validSequences = generateSequences(length, true);
		for( List<? extends Event<?>> sequence : validSequences ) {
			if( sequence.size()!=length ) {
				validSequences.remove(sequence);
			} else {
				if( !Event.ENDEVENT.equals(sequence.get(sequence.size()-1)) ) {
					validSequences.remove(sequence);
				}
			}
		}
		return validSequences;
	}

}