// Copyright 2012 Georg-August-Universität Göttingen, Germany
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package de.ugoe.cs.autoquest.usageprofiles;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import de.ugoe.cs.autoquest.eventcore.Event;
import de.ugoe.cs.autoquest.usageprofiles.Trie.Edge;
import de.ugoe.cs.autoquest.usageprofiles.Trie.TrieVertex;
import edu.uci.ics.jung.graph.Tree;
/**
*
* Implements a skeleton for stochastic processes that can calculate probabilities based on a trie.
* The skeleton provides all functionalities of {@link IStochasticProcess} except
* {@link IStochasticProcess#getProbability(List, Event)}.
*
*
* @author Steffen Herbold
* @version 1.0
*/
public abstract class TrieBasedModel implements IStochasticProcess {
/**
*
* Id for object serialization.
*
*/
private static final long serialVersionUID = 1L;
/**
*
* The order of the trie, i.e., the maximum length of subsequences stored in the trie.
*
*/
protected int trieOrder;
/**
*
* Trie on which the probability calculations are based.
*
*/
protected Trie trie = null;
/**
*
* Random number generator used by probabilistic sequence generation methods.
*
*/
protected final Random r;
/**
*
* Constructor. Creates a new TrieBasedModel that can be used for stochastic processes with a
* Markov order less than or equal to {@code markovOrder}.
*
*
* @param markovOrder
* Markov order of the model
* @param r
* random number generator used by probabilistic methods of the class
* @throws IllegalArgumentException
* thrown if markovOrder is less than 0 or the random number generator r is null
*/
public TrieBasedModel(int markovOrder, Random r) {
super();
if (markovOrder < 0) {
throw new IllegalArgumentException("markov order must not be less than 0");
}
if (r == null) {
throw new IllegalArgumentException("random number generator r must not be null");
}
this.trieOrder = markovOrder + 1;
this.r = r;
}
/**
*
* Trains the model by generating a trie from which probabilities are calculated. The trie is
* newly generated based solely on the passed sequences. If an existing model should only be
* updated, use {@link #update(Collection)} instead.
*
*
* @param sequences
* training data
* @throws IllegalArgumentException
* thrown is sequences is null
*/
public void train(Collection> sequences) {
trie = null;
update(sequences);
}
/**
*
* Trains the model by updating the trie from which the probabilities are calculated. This
* function updates an existing trie. In case no trie exists yet, a new trie is generated and
* the function behaves like {@link #train(Collection)}.
*
*
* @param sequences
* training data
* @throws IllegalArgumentException
* thrown is sequences is null
*/
public void update(Collection> sequences) {
if (sequences == null) {
throw new IllegalArgumentException("sequences must not be null");
}
if (trie == null) {
trie = new Trie();
}
for (List sequence : sequences) {
List currentSequence = new LinkedList(sequence); // defensive
// copy
currentSequence.add(0, Event.STARTEVENT);
currentSequence.add(Event.ENDEVENT);
trie.train(currentSequence, trieOrder);
}
}
/*
* (non-Javadoc)
*
* @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#randomSequence()
*/
@Override
public List randomSequence() {
return randomSequence(Integer.MAX_VALUE, true, 100);
}
/*
* (non-Javadoc)
*
* @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#randomSequence()
*/
@Override
public List randomSequence(int maxLength, boolean validEnd, long maxIter) {
List sequence = new LinkedList();
int attempts = 0;
if (trie != null) {
boolean endFound = false;
while (!endFound && attempts <= maxIter) { // outer loop for length checking
IncompleteMemory context = new IncompleteMemory(trieOrder - 1);
context.add(Event.STARTEVENT);
while (!endFound && sequence.size() <= maxLength) {
double randVal = r.nextDouble();
double probSum = 0.0;
List currentContext = context.getLast(trieOrder);
for (Event symbol : trie.getKnownSymbols()) {
probSum += getProbability(currentContext, symbol);
if (probSum >= randVal) {
if (!(Event.STARTEVENT.equals(symbol) || Event.ENDEVENT.equals(symbol)))
{
// only add the symbol the sequence if it is not
// START or END
context.add(symbol);
sequence.add(symbol);
}
endFound =
(Event.ENDEVENT.equals(symbol)) ||
(!validEnd && sequence.size() == maxLength);
break;
}
}
}
if (!endFound) {
sequence = new LinkedList();
}
attempts++;
}
}
return sequence;
}
/**
*
* Returns a Dot representation of the internal trie.
*
*
* @return dot representation of the internal trie
*/
public String getTrieDotRepresentation() {
if (trie == null) {
return "";
}
else {
return trie.getDotRepresentation();
}
}
/**
*
* Returns a {@link Tree} of the internal trie that can be used for visualization.
*
*
* @return {@link Tree} depicting the internal trie
*/
public Tree getTrieGraph() {
if (trie == null) {
return null;
}
else {
return trie.getGraph();
}
}
/**
*
* The string representation of the model is {@link Trie#toString()} of {@link #trie}.
*
*
* @see java.lang.Object#toString()
*/
@Override
public String toString() {
if (trie == null) {
return "";
}
else {
return trie.toString();
}
}
/*
* (non-Javadoc)
*
* @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumStates()
*/
@Override
public int getNumSymbols() {
if (trie == null) {
return 0;
}
else {
return trie.getNumSymbols();
}
}
/*
* (non-Javadoc)
*
* @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getStateStrings()
*/
@Override
public String[] getSymbolStrings() {
if (trie == null) {
return new String[0];
}
String[] stateStrings = new String[getNumSymbols()];
int i = 0;
for (Event symbol : trie.getKnownSymbols()) {
if (symbol.toString() == null) {
stateStrings[i] = "null";
}
else {
stateStrings[i] = symbol.toString();
}
i++;
}
return stateStrings;
}
/*
* (non-Javadoc)
*
* @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getEvents()
*/
@Override
public Collection getEvents() {
if (trie == null) {
return new HashSet();
}
else {
return trie.getKnownSymbols();
}
}
/*
* (non-Javadoc)
*
* @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateSequences(int)
*/
@Override
public Collection> generateSequences(int length) {
return generateSequences(length, false);
}
/*
* (non-Javadoc)
*
* @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateSequences(int, boolean)
*/
@Override
public Set> generateSequences(int length, boolean fromStart) {
Set> sequenceSet = new LinkedHashSet>();
if (length < 1) {
throw new IllegalArgumentException(
"Length of generated subsequences must be at least 1.");
}
if (length == 1) {
if (fromStart) {
List subSeq = new LinkedList();
subSeq.add(Event.STARTEVENT);
sequenceSet.add(subSeq);
}
else {
for (Event event : getEvents()) {
List subSeq = new LinkedList();
subSeq.add(event);
sequenceSet.add(subSeq);
}
}
return sequenceSet;
}
Collection events = getEvents();
Collection> seqsShorter = generateSequences(length - 1, fromStart);
for (Event event : events) {
for (List seqShorter : seqsShorter) {
Event lastEvent = event;
if (getProbability(seqShorter, lastEvent) > 0.0) {
List subSeq = new ArrayList(seqShorter);
subSeq.add(lastEvent);
sequenceSet.add(subSeq);
}
}
}
return sequenceSet;
}
/*
* (non-Javadoc)
*
* @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateValidSequences (int)
*/
@Override
public Collection> generateValidSequences(int length) {
// check for min-length implicitly done by generateSequences
Collection> allSequences = generateSequences(length, true);
Collection> validSequences = new LinkedHashSet>();
for (List sequence : allSequences) {
if (sequence.size() == length &&
Event.ENDEVENT.equals(sequence.get(sequence.size() - 1)))
{
validSequences.add(sequence);
}
}
return validSequences;
}
/*
* (non-Javadoc)
*
* @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getProbability(java.util .List)
*/
@Override
public double getProbability(List sequence) {
if (sequence == null) {
throw new IllegalArgumentException("sequence must not be null");
}
double prob = 1.0;
List context = new LinkedList();
for (Event event : sequence) {
prob *= getProbability(context, event);
context.add(event);
}
return prob;
}
/*
* (non-Javadoc)
*
* @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getProbability(java.util .List)
*/
@Override
public double getLogSum(List sequence) {
if (sequence == null) {
throw new IllegalArgumentException("sequence must not be null");
}
double odds = 0.0;
List context = new LinkedList();
for (Event event : sequence) {
odds += Math.log(getProbability(context, event) + 1);
context.add(event);
}
return odds;
}
/*
* (non-Javadoc)
*
* @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumFOMStates()
*/
@Override
public int getNumFOMStates() {
if (trie == null) {
return 0;
}
else {
return trie.getNumLeafAncestors();
}
}
/*
* (non-Javadoc)
*
* @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumTransitions()
*/
@Override
public int getNumTransitions() {
if (trie == null) {
return 0;
}
else {
return trie.getNumLeafs();
}
}
}