package de.ugoe.cs.quest.usageprofiles;

import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;

import de.ugoe.cs.quest.eventcore.Event;
import de.ugoe.cs.quest.usageprofiles.Trie.Edge;
import de.ugoe.cs.quest.usageprofiles.Trie.TrieVertex;
import edu.uci.ics.jung.graph.Tree;

/**
 * <p>
 * Implements a skeleton for stochastic processes that can calculate probabilities based on a trie.
 * The skeleton provides all functionalities of {@link IStochasticProcess} except
 * {@link IStochasticProcess#getProbability(List, Event)}.
 * </p>
 * 
 * @author Steffen Herbold
 * @version 1.0
 */
public abstract class TrieBasedModel implements IStochasticProcess {

    /**
     * <p>
     * Id for object serialization.
     * </p>
     */
    private static final long serialVersionUID = 1L;

    /**
     * <p>
     * The order of the trie, i.e., the maximum length of subsequences stored in the trie.
     * </p>
     */
    protected int trieOrder;

    /**
     * <p>
     * Trie on which the probability calculations are based.
     * </p>
     */
    protected Trie<Event> trie = null;

    /**
     * <p>
     * Random number generator used by probabilistic sequence generation methods.
     * </p>
     */
    protected final Random r;

    /**
     * <p>
     * Constructor. Creates a new TrieBasedModel that can be used for stochastic processes with a
     * Markov order less than or equal to {@code markovOrder}.
     * </p>
     * 
     * @param markovOrder
     *            Markov order of the model
     * @param r
     *            random number generator used by probabilistic methods of the class
     * @throws InvalidParameterException
     *             thrown if markovOrder is less than 0 or the random number generator r is null
     */
    public TrieBasedModel(int markovOrder, Random r) {
        super();
        if (markovOrder < 0) {
            throw new InvalidParameterException("markov order must not be less than 0");
        }
        if (r == null) {
            throw new InvalidParameterException("random number generator r must not be null");
        }
        this.trieOrder = markovOrder + 1;
        this.r = r;
    }

    /**
     * <p>
     * Trains the model by generating a trie from which probabilities are calculated. The trie is
     * newly generated based solely on the passed sequences. If an existing model should only be
     * updated, use {@link #update(Collection)} instead.
     * </p>
     * 
     * @param sequences
     *            training data
     * @throws InvalidParameterException
     *             thrown is sequences is null
     */
    public void train(Collection<List<Event>> sequences) {
        trie = null;
        update(sequences);
    }

    /**
     * <p>
     * Trains the model by updating the trie from which the probabilities are calculated. This
     * function updates an existing trie. In case no trie exists yet, a new trie is generated and
     * the function behaves like {@link #train(Collection)}.
     * </p>
     * 
     * @param sequences
     *            training data
     * @throws InvalidParameterException
     *             thrown is sequences is null
     */
    public void update(Collection<List<Event>> sequences) {
        if (sequences == null) {
            throw new InvalidParameterException("sequences must not be null");
        }
        if (trie == null) {
            trie = new Trie<Event>();
        }
        for (List<Event> sequence : sequences) {
            List<Event> currentSequence = new LinkedList<Event>(sequence); // defensive
                                                                           // copy
            currentSequence.add(0, Event.STARTEVENT);
            currentSequence.add(Event.ENDEVENT);

            trie.train(currentSequence, trieOrder);
        }
    }

    /*
     * (non-Javadoc)
     * 
     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#randomSequence()
     */
    @Override
    public List<Event> randomSequence() {
        return randomSequence(Integer.MAX_VALUE, true);
    }

    /*
     * (non-Javadoc)
     * 
     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#randomSequence()
     */
    @Override
    public List<Event> randomSequence(int maxLength, boolean validEnd) {
        List<Event> sequence = new LinkedList<Event>();
        if (trie != null) {
            boolean endFound = false;
            while (!endFound) { // outer loop for length checking
                sequence = new LinkedList<Event>();
                IncompleteMemory<Event> context = new IncompleteMemory<Event>(trieOrder - 1);
                context.add(Event.STARTEVENT);

                while (!endFound && sequence.size() <= maxLength) {
                    double randVal = r.nextDouble();
                    double probSum = 0.0;
                    List<Event> currentContext = context.getLast(trieOrder);
                    for (Event symbol : trie.getKnownSymbols()) {
                        probSum += getProbability(currentContext, symbol);
                        if (probSum >= randVal) {
                            if (!(Event.STARTEVENT.equals(symbol) || Event.ENDEVENT.equals(symbol)))
                            {
                                // only add the symbol the sequence if it is not
                                // START or END
                                context.add(symbol);
                                sequence.add(symbol);
                            }
                            endFound =
                                (Event.ENDEVENT.equals(symbol)) ||
                                    (!validEnd && sequence.size() == maxLength);
                            break;
                        }
                    }
                }
            }
        }
        return sequence;
    }

    /**
     * <p>
     * Returns a Dot representation of the internal trie.
     * </p>
     * 
     * @return dot representation of the internal trie
     */
    public String getTrieDotRepresentation() {
        if (trie == null) {
            return "";
        }
        else {
            return trie.getDotRepresentation();
        }
    }

    /**
     * <p>
     * Returns a {@link Tree} of the internal trie that can be used for visualization.
     * </p>
     * 
     * @return {@link Tree} depicting the internal trie
     */
    public Tree<TrieVertex, Edge> getTrieGraph() {
        if (trie == null) {
            return null;
        }
        else {
            return trie.getGraph();
        }
    }

    /**
     * <p>
     * The string representation of the model is {@link Trie#toString()} of {@link #trie}.
     * </p>
     * 
     * @see java.lang.Object#toString()
     */
    @Override
    public String toString() {
        if (trie == null) {
            return "";
        }
        else {
            return trie.toString();
        }
    }

    /*
     * (non-Javadoc)
     * 
     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getNumStates()
     */
    @Override
    public int getNumSymbols() {
        if (trie == null) {
            return 0;
        }
        else {
            return trie.getNumSymbols();
        }
    }

    /*
     * (non-Javadoc)
     * 
     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getStateStrings()
     */
    @Override
    public String[] getSymbolStrings() {
        if (trie == null) {
            return new String[0];
        }
        String[] stateStrings = new String[getNumSymbols()];
        int i = 0;
        for (Event symbol : trie.getKnownSymbols()) {
            if (symbol.toString() == null) {
                stateStrings[i] = "null";
            }
            else {
                stateStrings[i] = symbol.toString();
            }
            i++;
        }
        return stateStrings;
    }

    /*
     * (non-Javadoc)
     * 
     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getEvents()
     */
    @Override
    public Collection<Event> getEvents() {
        if (trie == null) {
            return new HashSet<Event>();
        }
        else {
            return trie.getKnownSymbols();
        }
    }

    /*
     * (non-Javadoc)
     * 
     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#generateSequences(int)
     */
    @Override
    public Collection<List<Event>> generateSequences(int length) {
        return generateSequences(length, false);
    }

    /*
     * (non-Javadoc)
     * 
     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#generateSequences(int, boolean)
     */
    @Override
    public Set<List<Event>> generateSequences(int length, boolean fromStart) {
        Set<List<Event>> sequenceSet = new LinkedHashSet<List<Event>>();
        if (length < 1) {
            throw new InvalidParameterException(
                                                "Length of generated subsequences must be at least 1.");
        }
        if (length == 1) {
            if (fromStart) {
                List<Event> subSeq = new LinkedList<Event>();
                subSeq.add(Event.STARTEVENT);
                sequenceSet.add(subSeq);
            }
            else {
                for (Event event : getEvents()) {
                    List<Event> subSeq = new LinkedList<Event>();
                    subSeq.add(event);
                    sequenceSet.add(subSeq);
                }
            }
            return sequenceSet;
        }
        Collection<Event> events = getEvents();
        Collection<List<Event>> seqsShorter = generateSequences(length - 1, fromStart);
        for (Event event : events) {
            for (List<Event> seqShorter : seqsShorter) {
                Event lastEvent = event;
                if (getProbability(seqShorter, lastEvent) > 0.0) {
                    List<Event> subSeq = new ArrayList<Event>(seqShorter);
                    subSeq.add(lastEvent);
                    sequenceSet.add(subSeq);
                }
            }
        }
        return sequenceSet;
    }

    /*
     * (non-Javadoc)
     * 
     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#generateValidSequences (int)
     */
    @Override
    public Collection<List<Event>> generateValidSequences(int length) {
        // check for min-length implicitly done by generateSequences
        Collection<List<Event>> allSequences = generateSequences(length, true);
        Collection<List<Event>> validSequences = new LinkedHashSet<List<Event>>();
        for (List<Event> sequence : allSequences) {
            if (sequence.size() == length &&
                Event.ENDEVENT.equals(sequence.get(sequence.size() - 1)))
            {
                validSequences.add(sequence);
            }
        }
        return validSequences;
    }

    /*
     * (non-Javadoc)
     * 
     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getProbability(java.util .List)
     */
    @Override
    public double getProbability(List<Event> sequence) {
        if (sequence == null) {
            throw new InvalidParameterException("sequence must not be null");
        }
        double prob = 1.0;
        List<Event> context = new LinkedList<Event>();
        for (Event event : sequence) {
            prob *= getProbability(context, event);
            context.add(event);
        }
        return prob;
    }

    /*
     * (non-Javadoc)
     * 
     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getNumFOMStates()
     */
    @Override
    public int getNumFOMStates() {
        if (trie == null) {
            return 0;
        }
        else {
            return trie.getNumLeafAncestors();
        }
    }

    /*
     * (non-Javadoc)
     * 
     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getNumTransitions()
     */
    @Override
    public int getNumTransitions() {
        if (trie == null) {
            return 0;
        }
        else {
            return trie.getNumLeafs();
        }
    }
}
