package de.ugoe.cs.eventbench.models; import java.security.InvalidParameterException; import java.util.ArrayList; import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; import java.util.Random; import java.util.Set; import de.ugoe.cs.eventbench.data.Event; import de.ugoe.cs.eventbench.models.Trie.Edge; import de.ugoe.cs.eventbench.models.Trie.TrieVertex; import edu.uci.ics.jung.graph.Tree; public abstract class TrieBasedModel implements IStochasticProcess { /** * Id for object serialization. */ private static final long serialVersionUID = 1L; protected int trieOrder; protected Trie> trie; protected final Random r; public TrieBasedModel(int markovOrder, Random r) { super(); this.trieOrder = markovOrder+1; this.r = r; } public void train(List>> sequences) { trie = new Trie>(); for(List> sequence : sequences) { List> currentSequence = new LinkedList>(sequence); // defensive copy currentSequence.add(0, Event.STARTEVENT); currentSequence.add(Event.ENDEVENT); trie.train(currentSequence, trieOrder); } } /* (non-Javadoc) * @see de.ugoe.cs.eventbench.models.IStochasticProcess#randomSequence() */ @Override public List> randomSequence() { List> sequence = new LinkedList>(); IncompleteMemory> context = new IncompleteMemory>(trieOrder-1); context.add(Event.STARTEVENT); Event currentState = Event.STARTEVENT; boolean endFound = false; while(!endFound) { double randVal = r.nextDouble(); double probSum = 0.0; List> currentContext = context.getLast(trieOrder); for( Event symbol : trie.getKnownSymbols() ) { probSum += getProbability(currentContext, symbol); if( probSum>=randVal ) { endFound = (symbol==Event.ENDEVENT); if( !(symbol==Event.STARTEVENT || symbol==Event.ENDEVENT) ) { // only add the symbol the sequence if it is not START or END context.add(symbol); currentState = symbol; sequence.add(currentState); } break; } } } return sequence; } public String getTrieDotRepresentation() { return trie.getDotRepresentation(); } public Tree getTrieGraph() { return trie.getGraph(); } @Override public String toString() { return trie.toString(); } public int getNumStates() { return trie.getNumSymbols(); } public String[] getStateStrings() { String[] stateStrings = new String[getNumStates()]; int i=0; for( Event symbol : trie.getKnownSymbols() ) { stateStrings[i] = symbol.toString(); i++; } return stateStrings; } public Set> getEvents() { return trie.getKnownSymbols(); } public Set>> generateSequences(int length) { return generateSequences(length, false); /*Set>> sequenceSet = new LinkedHashSet>>();; if( length<1 ) { throw new InvalidParameterException("Length of generated subsequences must be at least 1."); } if( length==1 ) { for( Event event : getEvents() ) { List> subSeq = new LinkedList>(); subSeq.add(event); sequenceSet.add(subSeq); } return sequenceSet; } Set> events = getEvents(); Set>> seqsShorter = generateSequences(length-1); for( Event event : events ) { for( List> seqShorter : seqsShorter ) { Event lastEvent = event; if( getProbability(seqShorter, lastEvent)>0.0 ) { List> subSeq = new ArrayList>(seqShorter); subSeq.add(lastEvent); sequenceSet.add(subSeq); } } } return sequenceSet; */ } // if startValid, all sequences will start in Event.STARTEVENT public Set>> generateSequences(int length, boolean fromStart) { Set>> sequenceSet = new LinkedHashSet>>();; if( length<1 ) { throw new InvalidParameterException("Length of generated subsequences must be at least 1."); } if( length==1 ) { if( fromStart ) { List> subSeq = new LinkedList>(); subSeq.add(Event.STARTEVENT); sequenceSet.add(subSeq); } else { for( Event event : getEvents() ) { List> subSeq = new LinkedList>(); subSeq.add(event); sequenceSet.add(subSeq); } } return sequenceSet; } Set> events = getEvents(); Set>> seqsShorter = generateSequences(length-1, fromStart); for( Event event : events ) { for( List> seqShorter : seqsShorter ) { Event lastEvent = event; if( getProbability(seqShorter, lastEvent)>0.0 ) { List> subSeq = new ArrayList>(seqShorter); subSeq.add(lastEvent); sequenceSet.add(subSeq); } } } return sequenceSet; } // sequences from start to end public Set>> generateValidSequences(int length) { // check for min-length implicitly done by generateSequences Set>> allSequences = generateSequences(length, true); Set>> validSequences = new LinkedHashSet>>(); for( List> sequence : allSequences ) { if( sequence.size()==length && Event.ENDEVENT.equals(sequence.get(sequence.size()-1)) ) { validSequences.add(sequence); } } return validSequences; } }