source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/models/TrieBasedModel.java @ 325

Last change on this file since 325 was 325, checked in by sherbold, 13 years ago
  • changed signature of de.ugoe.cs.util.eventbench.models.TrieBasedModel?.train() and update() from using Collection<List<Event<?>>> to Collection<List<? extends Event<?>>>
  • Property svn:mime-type set to text/plain
File size: 9.3 KB
RevLine 
[13]1package de.ugoe.cs.eventbench.models;
[12]2
[93]3import java.security.InvalidParameterException;
4import java.util.ArrayList;
[102]5import java.util.Collection;
[252]6import java.util.HashSet;
[93]7import java.util.LinkedHashSet;
[12]8import java.util.LinkedList;
9import java.util.List;
10import java.util.Random;
[80]11import java.util.Set;
[12]12
13import de.ugoe.cs.eventbench.data.Event;
[23]14import de.ugoe.cs.eventbench.models.Trie.Edge;
15import de.ugoe.cs.eventbench.models.Trie.TrieVertex;
16import edu.uci.ics.jung.graph.Tree;
[12]17
[100]18/**
19 * <p>
20 * Implements a skeleton for stochastic processes that can calculate
21 * probabilities based on a trie. The skeleton provides all functionalities of
22 * {@link IStochasticProcess} except
23 * {@link IStochasticProcess#getProbability(List, Event)}.
24 * </p>
25 *
26 * @author Steffen Herbold
27 * @version 1.0
28 */
[17]29public abstract class TrieBasedModel implements IStochasticProcess {
[12]30
[86]31        /**
[100]32         * <p>
[86]33         * Id for object serialization.
[100]34         * </p>
[86]35         */
36        private static final long serialVersionUID = 1L;
37
[100]38        /**
39         * <p>
40         * The order of the trie, i.e., the maximum length of subsequences stored in
41         * the trie.
42         * </p>
43         */
[16]44        protected int trieOrder;
[12]45
[100]46        /**
47         * <p>
48         * Trie on which the probability calculations are based.
49         * </p>
50         */
[182]51        protected Trie<Event<?>> trie = null;
[100]52
53        /**
54         * <p>
55         * Random number generator used by probabilistic sequence generation
56         * methods.
57         * </p>
58         */
[12]59        protected final Random r;
60
[100]61        /**
62         * <p>
63         * Constructor. Creates a new TrieBasedModel that can be used for stochastic
64         * processes with a Markov order less than or equal to {@code markovOrder}.
65         * </p>
66         *
67         * @param markovOrder
68         *            Markov order of the model
69         * @param r
70         *            random number generator used by probabilistic methods of the
71         *            class
72         */
[16]73        public TrieBasedModel(int markovOrder, Random r) {
[12]74                super();
[100]75                this.trieOrder = markovOrder + 1;
[12]76                this.r = r;
77        }
78
[100]79        /**
80         * <p>
81         * Trains the model by generating a trie from which probabilities are
[182]82         * calculated. The trie is newly generated based solely on the passed
83         * sequences. If an existing model should only be updated, use
84         * {@link #update(Collection)} instead.
[100]85         * </p>
86         *
87         * @param sequences
88         *            training data
89         */
[325]90        public void train(Collection<List<? extends Event<?>>> sequences) {
[182]91                trie = null;
92                update(sequences);
93        }
[100]94
[182]95        /**
96         * <p>
97         * Trains the model by updating the trie from which the probabilities are
98         * calculated. This function updates an existing trie. In case no trie
99         * exists yet, a new trie is generated and the function behaves like
100         * {@link #train(Collection)}.
101         * </p>
102         *
103         * @param sequences
104         *            training data
105         */
[325]106        public void update(Collection<List<? extends Event<?>>> sequences) {
[252]107                if (sequences == null) {
108                        return;
109                }
[182]110                if (trie == null) {
111                        trie = new Trie<Event<?>>();
112                }
[325]113                for (List<? extends Event<?>> sequence : sequences) {
[100]114                        List<Event<?>> currentSequence = new LinkedList<Event<?>>(sequence); // defensive
115                                                                                                                                                                        // copy
[12]116                        currentSequence.add(0, Event.STARTEVENT);
117                        currentSequence.add(Event.ENDEVENT);
[100]118
[16]119                        trie.train(currentSequence, trieOrder);
[12]120                }
121        }
122
[100]123        /*
124         * (non-Javadoc)
125         *
[17]126         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#randomSequence()
127         */
128        @Override
[12]129        public List<? extends Event<?>> randomSequence() {
130                List<Event<?>> sequence = new LinkedList<Event<?>>();
[252]131                if( trie!=null ) {
132                        IncompleteMemory<Event<?>> context = new IncompleteMemory<Event<?>>(
133                                        trieOrder - 1);
134                        context.add(Event.STARTEVENT);
135       
136                        Event<?> currentState = Event.STARTEVENT;
137       
138                        boolean endFound = false;
139       
140                        while (!endFound) {
141                                double randVal = r.nextDouble();
142                                double probSum = 0.0;
143                                List<Event<?>> currentContext = context.getLast(trieOrder);
144                                for (Event<?> symbol : trie.getKnownSymbols()) {
145                                        probSum += getProbability(currentContext, symbol);
146                                        if (probSum >= randVal) {
147                                                endFound = (symbol == Event.ENDEVENT);
148                                                if (!(symbol == Event.STARTEVENT || symbol == Event.ENDEVENT)) {
149                                                        // only add the symbol the sequence if it is not START
150                                                        // or END
151                                                        context.add(symbol);
152                                                        currentState = symbol;
153                                                        sequence.add(currentState);
154                                                }
155                                                break;
[12]156                                        }
157                                }
158                        }
159                }
160                return sequence;
161        }
[100]162
163        /**
164         * <p>
165         * Returns a Dot representation of the internal trie.
166         * </p>
167         *
168         * @return dot representation of the internal trie
169         */
[30]170        public String getTrieDotRepresentation() {
[252]171                if (trie == null) {
172                        return "";
173                } else {
174                        return trie.getDotRepresentation();
175                }
[30]176        }
[100]177
178        /**
179         * <p>
180         * Returns a {@link Tree} of the internal trie that can be used for
181         * visualization.
182         * </p>
183         *
184         * @return {@link Tree} depicting the internal trie
185         */
[23]186        public Tree<TrieVertex, Edge> getTrieGraph() {
[252]187                if (trie == null) {
188                        return null;
189                } else {
190                        return trie.getGraph();
191                }
[23]192        }
[12]193
[100]194        /**
195         * <p>
196         * The string representation of the model is {@link Trie#toString()} of
197         * {@link #trie}.
198         * </p>
199         *
200         * @see java.lang.Object#toString()
201         */
[12]202        @Override
203        public String toString() {
[252]204                if (trie == null) {
205                        return "";
206                } else {
207                        return trie.toString();
208                }
[12]209        }
[100]210
211        /*
212         * (non-Javadoc)
213         *
214         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getNumStates()
215         */
216        @Override
[129]217        public int getNumSymbols() {
[252]218                if (trie == null) {
219                        return 0;
220                } else {
221                        return trie.getNumSymbols();
222                }
[66]223        }
[100]224
225        /*
226         * (non-Javadoc)
227         *
228         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getStateStrings()
229         */
230        @Override
[129]231        public String[] getSymbolStrings() {
[252]232                if (trie == null) {
233                        return new String[0];
234                }
[129]235                String[] stateStrings = new String[getNumSymbols()];
[100]236                int i = 0;
237                for (Event<?> symbol : trie.getKnownSymbols()) {
[70]238                        stateStrings[i] = symbol.toString();
239                        i++;
240                }
241                return stateStrings;
242        }
[100]243
244        /*
245         * (non-Javadoc)
246         *
247         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getEvents()
248         */
249        @Override
[102]250        public Collection<? extends Event<?>> getEvents() {
[252]251                if (trie == null) {
252                        return new HashSet<Event<?>>();
253                } else {
254                        return trie.getKnownSymbols();
255                }
[80]256        }
[100]257
258        /*
259         * (non-Javadoc)
260         *
261         * @see
262         * de.ugoe.cs.eventbench.models.IStochasticProcess#generateSequences(int)
263         */
264        @Override
[102]265        public Collection<List<? extends Event<?>>> generateSequences(int length) {
[94]266                return generateSequences(length, false);
[93]267        }
[100]268
269        /*
270         * (non-Javadoc)
271         *
272         * @see
273         * de.ugoe.cs.eventbench.models.IStochasticProcess#generateSequences(int,
274         * boolean)
275         */
276        @Override
277        public Set<List<? extends Event<?>>> generateSequences(int length,
278                        boolean fromStart) {
279                Set<List<? extends Event<?>>> sequenceSet = new LinkedHashSet<List<? extends Event<?>>>();
280                if (length < 1) {
281                        throw new InvalidParameterException(
282                                        "Length of generated subsequences must be at least 1.");
[94]283                }
[100]284                if (length == 1) {
285                        if (fromStart) {
[94]286                                List<Event<?>> subSeq = new LinkedList<Event<?>>();
287                                subSeq.add(Event.STARTEVENT);
[95]288                                sequenceSet.add(subSeq);
[94]289                        } else {
[100]290                                for (Event<?> event : getEvents()) {
[94]291                                        List<Event<?>> subSeq = new LinkedList<Event<?>>();
292                                        subSeq.add(event);
293                                        sequenceSet.add(subSeq);
294                                }
295                        }
296                        return sequenceSet;
297                }
[102]298                Collection<? extends Event<?>> events = getEvents();
299                Collection<List<? extends Event<?>>> seqsShorter = generateSequences(
[100]300                                length - 1, fromStart);
301                for (Event<?> event : events) {
302                        for (List<? extends Event<?>> seqShorter : seqsShorter) {
[94]303                                Event<?> lastEvent = event;
[100]304                                if (getProbability(seqShorter, lastEvent) > 0.0) {
[94]305                                        List<Event<?>> subSeq = new ArrayList<Event<?>>(seqShorter);
306                                        subSeq.add(lastEvent);
307                                        sequenceSet.add(subSeq);
308                                }
309                        }
310                }
311                return sequenceSet;
312        }
[100]313
314        /*
315         * (non-Javadoc)
316         *
317         * @see
318         * de.ugoe.cs.eventbench.models.IStochasticProcess#generateValidSequences
319         * (int)
320         */
321        @Override
[118]322        public Collection<List<? extends Event<?>>> generateValidSequences(
323                        int length) {
[94]324                // check for min-length implicitly done by generateSequences
[118]325                Collection<List<? extends Event<?>>> allSequences = generateSequences(
326                                length, true);
[102]327                Collection<List<? extends Event<?>>> validSequences = new LinkedHashSet<List<? extends Event<?>>>();
[100]328                for (List<? extends Event<?>> sequence : allSequences) {
329                        if (sequence.size() == length
330                                        && Event.ENDEVENT.equals(sequence.get(sequence.size() - 1))) {
[95]331                                validSequences.add(sequence);
[94]332                        }
333                }
334                return validSequences;
335        }
[12]336
[118]337        /*
338         * (non-Javadoc)
339         *
340         * @see
341         * de.ugoe.cs.eventbench.models.IStochasticProcess#getProbability(java.util
342         * .List)
343         */
344        @Override
345        public double getProbability(List<? extends Event<?>> sequence) {
346                double prob = 1.0;
347                if (sequence != null) {
348                        List<Event<?>> context = new LinkedList<Event<?>>();
349                        for (Event<?> event : sequence) {
350                                prob *= getProbability(context, event);
351                                context.add(event);
352                        }
353                }
354                return prob;
355        }
356
[182]357        /*
358         * (non-Javadoc)
359         *
[129]360         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getNumFOMStates()
361         */
362        @Override
363        public int getNumFOMStates() {
[252]364                if (trie == null) {
365                        return 0;
366                } else {
367                        return trie.getNumLeafAncestors();
368                }
[129]369        }
[248]370
371        /*
372         * (non-Javadoc)
373         *
374         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getNumTransitions()
375         */
376        @Override
377        public int getNumTransitions() {
[252]378                if (trie == null) {
379                        return 0;
380                } else {
381                        return trie.getNumLeafs();
382                }
[248]383        }
[12]384}
Note: See TracBrowser for help on using the repository browser.