source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/models/TrieBasedModel.java @ 176

Last change on this file since 176 was 176, checked in by sherbold, 13 years ago
  • fixed minor code smells
  • Property svn:mime-type set to text/plain
File size: 7.9 KB
Line 
1package de.ugoe.cs.eventbench.models;
2
3import java.security.InvalidParameterException;
4import java.util.ArrayList;
5import java.util.Collection;
6import java.util.LinkedHashSet;
7import java.util.LinkedList;
8import java.util.List;
9import java.util.Random;
10import java.util.Set;
11
12import de.ugoe.cs.eventbench.data.Event;
13import de.ugoe.cs.eventbench.models.Trie.Edge;
14import de.ugoe.cs.eventbench.models.Trie.TrieVertex;
15import edu.uci.ics.jung.graph.Tree;
16
17/**
18 * <p>
19 * Implements a skeleton for stochastic processes that can calculate
20 * probabilities based on a trie. The skeleton provides all functionalities of
21 * {@link IStochasticProcess} except
22 * {@link IStochasticProcess#getProbability(List, Event)}.
23 * </p>
24 *
25 * @author Steffen Herbold
26 * @version 1.0
27 */
28public abstract class TrieBasedModel implements IStochasticProcess {
29
30        /**
31         * <p>
32         * Id for object serialization.
33         * </p>
34         */
35        private static final long serialVersionUID = 1L;
36
37        /**
38         * <p>
39         * The order of the trie, i.e., the maximum length of subsequences stored in
40         * the trie.
41         * </p>
42         */
43        protected int trieOrder;
44
45        /**
46         * <p>
47         * Trie on which the probability calculations are based.
48         * </p>
49         */
50        protected Trie<Event<?>> trie;
51
52        /**
53         * <p>
54         * Random number generator used by probabilistic sequence generation
55         * methods.
56         * </p>
57         */
58        protected final Random r;
59
60        /**
61         * <p>
62         * Constructor. Creates a new TrieBasedModel that can be used for stochastic
63         * processes with a Markov order less than or equal to {@code markovOrder}.
64         * </p>
65         *
66         * @param markovOrder
67         *            Markov order of the model
68         * @param r
69         *            random number generator used by probabilistic methods of the
70         *            class
71         */
72        public TrieBasedModel(int markovOrder, Random r) {
73                super();
74                this.trieOrder = markovOrder + 1;
75                this.r = r;
76        }
77
78        /**
79         * <p>
80         * Trains the model by generating a trie from which probabilities are
81         * calculated.
82         * </p>
83         *
84         * @param sequences
85         *            training data
86         */
87        public void train(Collection<List<Event<?>>> sequences) {
88                trie = new Trie<Event<?>>();
89
90                for (List<Event<?>> sequence : sequences) {
91                        List<Event<?>> currentSequence = new LinkedList<Event<?>>(sequence); // defensive
92                                                                                                                                                                        // copy
93                        currentSequence.add(0, Event.STARTEVENT);
94                        currentSequence.add(Event.ENDEVENT);
95
96                        trie.train(currentSequence, trieOrder);
97                }
98        }
99
100        /*
101         * (non-Javadoc)
102         *
103         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#randomSequence()
104         */
105        @Override
106        public List<? extends Event<?>> randomSequence() {
107                List<Event<?>> sequence = new LinkedList<Event<?>>();
108
109                IncompleteMemory<Event<?>> context = new IncompleteMemory<Event<?>>(
110                                trieOrder - 1);
111                context.add(Event.STARTEVENT);
112
113                Event<?> currentState = Event.STARTEVENT;
114
115                boolean endFound = false;
116
117                while (!endFound) {
118                        double randVal = r.nextDouble();
119                        double probSum = 0.0;
120                        List<Event<?>> currentContext = context.getLast(trieOrder);
121                        for (Event<?> symbol : trie.getKnownSymbols()) {
122                                probSum += getProbability(currentContext, symbol);
123                                if (probSum >= randVal) {
124                                        endFound = (symbol == Event.ENDEVENT);
125                                        if (!(symbol == Event.STARTEVENT || symbol == Event.ENDEVENT)) {
126                                                // only add the symbol the sequence if it is not START
127                                                // or END
128                                                context.add(symbol);
129                                                currentState = symbol;
130                                                sequence.add(currentState);
131                                        }
132                                        break;
133                                }
134                        }
135                }
136                return sequence;
137        }
138
139        /**
140         * <p>
141         * Returns a Dot representation of the internal trie.
142         * </p>
143         *
144         * @return dot representation of the internal trie
145         */
146        public String getTrieDotRepresentation() {
147                return trie.getDotRepresentation();
148        }
149
150        /**
151         * <p>
152         * Returns a {@link Tree} of the internal trie that can be used for
153         * visualization.
154         * </p>
155         *
156         * @return {@link Tree} depicting the internal trie
157         */
158        public Tree<TrieVertex, Edge> getTrieGraph() {
159                return trie.getGraph();
160        }
161
162        /**
163         * <p>
164         * The string representation of the model is {@link Trie#toString()} of
165         * {@link #trie}.
166         * </p>
167         *
168         * @see java.lang.Object#toString()
169         */
170        @Override
171        public String toString() {
172                return trie.toString();
173        }
174
175        /*
176         * (non-Javadoc)
177         *
178         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getNumStates()
179         */
180        @Override
181        public int getNumSymbols() {
182                return trie.getNumSymbols();
183        }
184
185        /*
186         * (non-Javadoc)
187         *
188         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getStateStrings()
189         */
190        @Override
191        public String[] getSymbolStrings() {
192                String[] stateStrings = new String[getNumSymbols()];
193                int i = 0;
194                for (Event<?> symbol : trie.getKnownSymbols()) {
195                        stateStrings[i] = symbol.toString();
196                        i++;
197                }
198                return stateStrings;
199        }
200
201        /*
202         * (non-Javadoc)
203         *
204         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getEvents()
205         */
206        @Override
207        public Collection<? extends Event<?>> getEvents() {
208                return trie.getKnownSymbols();
209        }
210
211        /*
212         * (non-Javadoc)
213         *
214         * @see
215         * de.ugoe.cs.eventbench.models.IStochasticProcess#generateSequences(int)
216         */
217        @Override
218        public Collection<List<? extends Event<?>>> generateSequences(int length) {
219                return generateSequences(length, false);
220        }
221
222        /*
223         * (non-Javadoc)
224         *
225         * @see
226         * de.ugoe.cs.eventbench.models.IStochasticProcess#generateSequences(int,
227         * boolean)
228         */
229        @Override
230        public Set<List<? extends Event<?>>> generateSequences(int length,
231                        boolean fromStart) {
232                Set<List<? extends Event<?>>> sequenceSet = new LinkedHashSet<List<? extends Event<?>>>();
233                if (length < 1) {
234                        throw new InvalidParameterException(
235                                        "Length of generated subsequences must be at least 1.");
236                }
237                if (length == 1) {
238                        if (fromStart) {
239                                List<Event<?>> subSeq = new LinkedList<Event<?>>();
240                                subSeq.add(Event.STARTEVENT);
241                                sequenceSet.add(subSeq);
242                        } else {
243                                for (Event<?> event : getEvents()) {
244                                        List<Event<?>> subSeq = new LinkedList<Event<?>>();
245                                        subSeq.add(event);
246                                        sequenceSet.add(subSeq);
247                                }
248                        }
249                        return sequenceSet;
250                }
251                Collection<? extends Event<?>> events = getEvents();
252                Collection<List<? extends Event<?>>> seqsShorter = generateSequences(
253                                length - 1, fromStart);
254                for (Event<?> event : events) {
255                        for (List<? extends Event<?>> seqShorter : seqsShorter) {
256                                Event<?> lastEvent = event;
257                                if (getProbability(seqShorter, lastEvent) > 0.0) {
258                                        List<Event<?>> subSeq = new ArrayList<Event<?>>(seqShorter);
259                                        subSeq.add(lastEvent);
260                                        sequenceSet.add(subSeq);
261                                }
262                        }
263                }
264                return sequenceSet;
265        }
266
267        /*
268         * (non-Javadoc)
269         *
270         * @see
271         * de.ugoe.cs.eventbench.models.IStochasticProcess#generateValidSequences
272         * (int)
273         */
274        @Override
275        public Collection<List<? extends Event<?>>> generateValidSequences(
276                        int length) {
277                // check for min-length implicitly done by generateSequences
278                Collection<List<? extends Event<?>>> allSequences = generateSequences(
279                                length, true);
280                Collection<List<? extends Event<?>>> validSequences = new LinkedHashSet<List<? extends Event<?>>>();
281                for (List<? extends Event<?>> sequence : allSequences) {
282                        if (sequence.size() == length
283                                        && Event.ENDEVENT.equals(sequence.get(sequence.size() - 1))) {
284                                validSequences.add(sequence);
285                        }
286                }
287                return validSequences;
288        }
289
290        /*
291         * (non-Javadoc)
292         *
293         * @see
294         * de.ugoe.cs.eventbench.models.IStochasticProcess#getProbability(java.util
295         * .List)
296         */
297        @Override
298        public double getProbability(List<? extends Event<?>> sequence) {
299                double prob = 1.0;
300                if (sequence != null) {
301                        List<Event<?>> context = new LinkedList<Event<?>>();
302                        for (Event<?> event : sequence) {
303                                prob *= getProbability(context, event);
304                                context.add(event);
305                        }
306                }
307                return prob;
308        }
309
310        /* (non-Javadoc)
311         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getNumFOMStates()
312         */
313        @Override
314        public int getNumFOMStates() {
315                return trie.getNumLeafAncestors();
316        }
317}
Note: See TracBrowser for help on using the repository browser.