1 | package de.ugoe.cs.eventbench.models;
|
---|
2 |
|
---|
3 | import java.security.InvalidParameterException;
|
---|
4 | import java.util.ArrayList;
|
---|
5 | import java.util.LinkedHashSet;
|
---|
6 | import java.util.LinkedList;
|
---|
7 | import java.util.List;
|
---|
8 | import java.util.Random;
|
---|
9 | import java.util.Set;
|
---|
10 |
|
---|
11 | import de.ugoe.cs.eventbench.data.Event;
|
---|
12 | import de.ugoe.cs.eventbench.models.Trie.Edge;
|
---|
13 | import de.ugoe.cs.eventbench.models.Trie.TrieVertex;
|
---|
14 | import edu.uci.ics.jung.graph.Tree;
|
---|
15 |
|
---|
16 | public abstract class TrieBasedModel implements IStochasticProcess {
|
---|
17 |
|
---|
18 | /**
|
---|
19 | * Id for object serialization.
|
---|
20 | */
|
---|
21 | private static final long serialVersionUID = 1L;
|
---|
22 |
|
---|
23 | protected int trieOrder;
|
---|
24 |
|
---|
25 | protected Trie<Event<?>> trie;
|
---|
26 | protected final Random r;
|
---|
27 |
|
---|
28 |
|
---|
29 | public TrieBasedModel(int markovOrder, Random r) {
|
---|
30 | super();
|
---|
31 | this.trieOrder = markovOrder+1;
|
---|
32 | this.r = r;
|
---|
33 | }
|
---|
34 |
|
---|
35 | public void train(List<List<Event<?>>> sequences) {
|
---|
36 | trie = new Trie<Event<?>>();
|
---|
37 |
|
---|
38 | for(List<Event<?>> sequence : sequences) {
|
---|
39 | List<Event<?>> currentSequence = new LinkedList<Event<?>>(sequence); // defensive copy
|
---|
40 | currentSequence.add(0, Event.STARTEVENT);
|
---|
41 | currentSequence.add(Event.ENDEVENT);
|
---|
42 |
|
---|
43 | trie.train(currentSequence, trieOrder);
|
---|
44 | }
|
---|
45 | }
|
---|
46 |
|
---|
47 | /* (non-Javadoc)
|
---|
48 | * @see de.ugoe.cs.eventbench.models.IStochasticProcess#randomSequence()
|
---|
49 | */
|
---|
50 | @Override
|
---|
51 | public List<? extends Event<?>> randomSequence() {
|
---|
52 | List<Event<?>> sequence = new LinkedList<Event<?>>();
|
---|
53 |
|
---|
54 | IncompleteMemory<Event<?>> context = new IncompleteMemory<Event<?>>(trieOrder-1);
|
---|
55 | context.add(Event.STARTEVENT);
|
---|
56 |
|
---|
57 | Event<?> currentState = Event.STARTEVENT;
|
---|
58 |
|
---|
59 | boolean endFound = false;
|
---|
60 |
|
---|
61 | while(!endFound) {
|
---|
62 | double randVal = r.nextDouble();
|
---|
63 | double probSum = 0.0;
|
---|
64 | List<Event<?>> currentContext = context.getLast(trieOrder);
|
---|
65 | for( Event<?> symbol : trie.getKnownSymbols() ) {
|
---|
66 | probSum += getProbability(currentContext, symbol);
|
---|
67 | if( probSum>=randVal ) {
|
---|
68 | endFound = (symbol==Event.ENDEVENT);
|
---|
69 | if( !(symbol==Event.STARTEVENT || symbol==Event.ENDEVENT) ) {
|
---|
70 | // only add the symbol the sequence if it is not START or END
|
---|
71 | context.add(symbol);
|
---|
72 | currentState = symbol;
|
---|
73 | sequence.add(currentState);
|
---|
74 | }
|
---|
75 | break;
|
---|
76 | }
|
---|
77 | }
|
---|
78 | }
|
---|
79 | return sequence;
|
---|
80 | }
|
---|
81 |
|
---|
82 | public String getTrieDotRepresentation() {
|
---|
83 | return trie.getDotRepresentation();
|
---|
84 | }
|
---|
85 |
|
---|
86 | public Tree<TrieVertex, Edge> getTrieGraph() {
|
---|
87 | return trie.getGraph();
|
---|
88 | }
|
---|
89 |
|
---|
90 | @Override
|
---|
91 | public String toString() {
|
---|
92 | return trie.toString();
|
---|
93 | }
|
---|
94 |
|
---|
95 | public int getNumStates() {
|
---|
96 | return trie.getNumSymbols();
|
---|
97 | }
|
---|
98 |
|
---|
99 | public String[] getStateStrings() {
|
---|
100 | String[] stateStrings = new String[getNumStates()];
|
---|
101 | int i=0;
|
---|
102 | for( Event<?> symbol : trie.getKnownSymbols() ) {
|
---|
103 | stateStrings[i] = symbol.toString();
|
---|
104 | i++;
|
---|
105 | }
|
---|
106 | return stateStrings;
|
---|
107 | }
|
---|
108 |
|
---|
109 | public Set<? extends Event<?>> getEvents() {
|
---|
110 | return trie.getKnownSymbols();
|
---|
111 | }
|
---|
112 |
|
---|
113 | public Set<List<? extends Event<?>>> generateSequences(int length) {
|
---|
114 | return generateSequences(length, false);
|
---|
115 | /*Set<List<? extends Event<?>>> sequenceSet = new LinkedHashSet<List<? extends Event<?>>>();;
|
---|
116 | if( length<1 ) {
|
---|
117 | throw new InvalidParameterException("Length of generated subsequences must be at least 1.");
|
---|
118 | }
|
---|
119 | if( length==1 ) {
|
---|
120 | for( Event<?> event : getEvents() ) {
|
---|
121 | List<Event<?>> subSeq = new LinkedList<Event<?>>();
|
---|
122 | subSeq.add(event);
|
---|
123 | sequenceSet.add(subSeq);
|
---|
124 | }
|
---|
125 | return sequenceSet;
|
---|
126 | }
|
---|
127 | Set<? extends Event<?>> events = getEvents();
|
---|
128 | Set<List<? extends Event<?>>> seqsShorter = generateSequences(length-1);
|
---|
129 | for( Event<?> event : events ) {
|
---|
130 | for( List<? extends Event<?>> seqShorter : seqsShorter ) {
|
---|
131 | Event<?> lastEvent = event;
|
---|
132 | if( getProbability(seqShorter, lastEvent)>0.0 ) {
|
---|
133 | List<Event<?>> subSeq = new ArrayList<Event<?>>(seqShorter);
|
---|
134 | subSeq.add(lastEvent);
|
---|
135 | sequenceSet.add(subSeq);
|
---|
136 | }
|
---|
137 | }
|
---|
138 | }
|
---|
139 | return sequenceSet;
|
---|
140 | */
|
---|
141 | }
|
---|
142 |
|
---|
143 | // if startValid, all sequences will start in Event.STARTEVENT
|
---|
144 | public Set<List<? extends Event<?>>> generateSequences(int length, boolean fromStart) {
|
---|
145 | Set<List<? extends Event<?>>> sequenceSet = new LinkedHashSet<List<? extends Event<?>>>();;
|
---|
146 | if( length<1 ) {
|
---|
147 | throw new InvalidParameterException("Length of generated subsequences must be at least 1.");
|
---|
148 | }
|
---|
149 | if( length==1 ) {
|
---|
150 | if( fromStart ) {
|
---|
151 | List<Event<?>> subSeq = new LinkedList<Event<?>>();
|
---|
152 | subSeq.add(Event.STARTEVENT);
|
---|
153 | sequenceSet.add(subSeq);
|
---|
154 | } else {
|
---|
155 | for( Event<?> event : getEvents() ) {
|
---|
156 | List<Event<?>> subSeq = new LinkedList<Event<?>>();
|
---|
157 | subSeq.add(event);
|
---|
158 | sequenceSet.add(subSeq);
|
---|
159 | }
|
---|
160 | }
|
---|
161 | return sequenceSet;
|
---|
162 | }
|
---|
163 | Set<? extends Event<?>> events = getEvents();
|
---|
164 | Set<List<? extends Event<?>>> seqsShorter = generateSequences(length-1, fromStart);
|
---|
165 | for( Event<?> event : events ) {
|
---|
166 | for( List<? extends Event<?>> seqShorter : seqsShorter ) {
|
---|
167 | Event<?> lastEvent = event;
|
---|
168 | if( getProbability(seqShorter, lastEvent)>0.0 ) {
|
---|
169 | List<Event<?>> subSeq = new ArrayList<Event<?>>(seqShorter);
|
---|
170 | subSeq.add(lastEvent);
|
---|
171 | sequenceSet.add(subSeq);
|
---|
172 | }
|
---|
173 | }
|
---|
174 | }
|
---|
175 | return sequenceSet;
|
---|
176 | }
|
---|
177 |
|
---|
178 | // sequences from start to end
|
---|
179 | public Set<List<? extends Event<?>>> generateValidSequences(int length) {
|
---|
180 | // check for min-length implicitly done by generateSequences
|
---|
181 | Set<List<? extends Event<?>>> allSequences = generateSequences(length, true);
|
---|
182 | Set<List<? extends Event<?>>> validSequences = new LinkedHashSet<List<? extends Event<?>>>();
|
---|
183 | for( List<? extends Event<?>> sequence : allSequences ) {
|
---|
184 | if( sequence.size()==length && Event.ENDEVENT.equals(sequence.get(sequence.size()-1)) ) {
|
---|
185 | validSequences.add(sequence);
|
---|
186 | }
|
---|
187 | }
|
---|
188 | return validSequences;
|
---|
189 | }
|
---|
190 |
|
---|
191 | } |
---|