source: trunk/quest-core-usageprofiles/src/main/java/de/ugoe/cs/quest/usageprofiles/TrieBasedModel.java @ 655

Last change on this file since 655 was 655, checked in by pharms, 12 years ago
  • removed old copyright file header
  • Property svn:mime-type set to text/plain
File size: 12.3 KB
Line 
1package de.ugoe.cs.quest.usageprofiles;
2
3import java.security.InvalidParameterException;
4import java.util.ArrayList;
5import java.util.Collection;
6import java.util.HashSet;
7import java.util.LinkedHashSet;
8import java.util.LinkedList;
9import java.util.List;
10import java.util.Random;
11import java.util.Set;
12
13import de.ugoe.cs.quest.eventcore.Event;
14import de.ugoe.cs.quest.usageprofiles.Trie.Edge;
15import de.ugoe.cs.quest.usageprofiles.Trie.TrieVertex;
16import edu.uci.ics.jung.graph.Tree;
17
18/**
19 * <p>
20 * Implements a skeleton for stochastic processes that can calculate probabilities based on a trie.
21 * The skeleton provides all functionalities of {@link IStochasticProcess} except
22 * {@link IStochasticProcess#getProbability(List, Event)}.
23 * </p>
24 *
25 * @author Steffen Herbold
26 * @version 1.0
27 */
28public abstract class TrieBasedModel implements IStochasticProcess {
29
30    /**
31     * <p>
32     * Id for object serialization.
33     * </p>
34     */
35    private static final long serialVersionUID = 1L;
36
37    /**
38     * <p>
39     * The order of the trie, i.e., the maximum length of subsequences stored in the trie.
40     * </p>
41     */
42    protected int trieOrder;
43
44    /**
45     * <p>
46     * Trie on which the probability calculations are based.
47     * </p>
48     */
49    protected Trie<Event> trie = null;
50
51    /**
52     * <p>
53     * Random number generator used by probabilistic sequence generation methods.
54     * </p>
55     */
56    protected final Random r;
57
58    /**
59     * <p>
60     * Constructor. Creates a new TrieBasedModel that can be used for stochastic processes with a
61     * Markov order less than or equal to {@code markovOrder}.
62     * </p>
63     *
64     * @param markovOrder
65     *            Markov order of the model
66     * @param r
67     *            random number generator used by probabilistic methods of the class
68     * @throws InvalidParameterException
69     *             thrown if markovOrder is less than 0 or the random number generator r is null
70     */
71    public TrieBasedModel(int markovOrder, Random r) {
72        super();
73        if (markovOrder < 0) {
74            throw new InvalidParameterException("markov order must not be less than 0");
75        }
76        if (r == null) {
77            throw new InvalidParameterException("random number generator r must not be null");
78        }
79        this.trieOrder = markovOrder + 1;
80        this.r = r;
81    }
82
83    /**
84     * <p>
85     * Trains the model by generating a trie from which probabilities are calculated. The trie is
86     * newly generated based solely on the passed sequences. If an existing model should only be
87     * updated, use {@link #update(Collection)} instead.
88     * </p>
89     *
90     * @param sequences
91     *            training data
92     * @throws InvalidParameterException
93     *             thrown is sequences is null
94     */
95    public void train(Collection<List<Event>> sequences) {
96        trie = null;
97        update(sequences);
98    }
99
100    /**
101     * <p>
102     * Trains the model by updating the trie from which the probabilities are calculated. This
103     * function updates an existing trie. In case no trie exists yet, a new trie is generated and
104     * the function behaves like {@link #train(Collection)}.
105     * </p>
106     *
107     * @param sequences
108     *            training data
109     * @throws InvalidParameterException
110     *             thrown is sequences is null
111     */
112    public void update(Collection<List<Event>> sequences) {
113        if (sequences == null) {
114            throw new InvalidParameterException("sequences must not be null");
115        }
116        if (trie == null) {
117            trie = new Trie<Event>();
118        }
119        for (List<Event> sequence : sequences) {
120            List<Event> currentSequence = new LinkedList<Event>(sequence); // defensive
121                                                                           // copy
122            currentSequence.add(0, Event.STARTEVENT);
123            currentSequence.add(Event.ENDEVENT);
124
125            trie.train(currentSequence, trieOrder);
126        }
127    }
128
129    /*
130     * (non-Javadoc)
131     *
132     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#randomSequence()
133     */
134    @Override
135    public List<Event> randomSequence() {
136        return randomSequence(Integer.MAX_VALUE, true);
137    }
138
139    /*
140     * (non-Javadoc)
141     *
142     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#randomSequence()
143     */
144    @Override
145    public List<Event> randomSequence(int maxLength, boolean validEnd) {
146        List<Event> sequence = new LinkedList<Event>();
147        if (trie != null) {
148            boolean endFound = false;
149            while (!endFound) { // outer loop for length checking
150                sequence = new LinkedList<Event>();
151                IncompleteMemory<Event> context = new IncompleteMemory<Event>(trieOrder - 1);
152                context.add(Event.STARTEVENT);
153
154                while (!endFound && sequence.size() <= maxLength) {
155                    double randVal = r.nextDouble();
156                    double probSum = 0.0;
157                    List<Event> currentContext = context.getLast(trieOrder);
158                    for (Event symbol : trie.getKnownSymbols()) {
159                        probSum += getProbability(currentContext, symbol);
160                        if (probSum >= randVal) {
161                            if (!(Event.STARTEVENT.equals(symbol) || Event.ENDEVENT.equals(symbol)))
162                            {
163                                // only add the symbol the sequence if it is not
164                                // START or END
165                                context.add(symbol);
166                                sequence.add(symbol);
167                            }
168                            endFound =
169                                (Event.ENDEVENT.equals(symbol)) ||
170                                    (!validEnd && sequence.size() == maxLength);
171                            break;
172                        }
173                    }
174                }
175            }
176        }
177        return sequence;
178    }
179
180    /**
181     * <p>
182     * Returns a Dot representation of the internal trie.
183     * </p>
184     *
185     * @return dot representation of the internal trie
186     */
187    public String getTrieDotRepresentation() {
188        if (trie == null) {
189            return "";
190        }
191        else {
192            return trie.getDotRepresentation();
193        }
194    }
195
196    /**
197     * <p>
198     * Returns a {@link Tree} of the internal trie that can be used for visualization.
199     * </p>
200     *
201     * @return {@link Tree} depicting the internal trie
202     */
203    public Tree<TrieVertex, Edge> getTrieGraph() {
204        if (trie == null) {
205            return null;
206        }
207        else {
208            return trie.getGraph();
209        }
210    }
211
212    /**
213     * <p>
214     * The string representation of the model is {@link Trie#toString()} of {@link #trie}.
215     * </p>
216     *
217     * @see java.lang.Object#toString()
218     */
219    @Override
220    public String toString() {
221        if (trie == null) {
222            return "";
223        }
224        else {
225            return trie.toString();
226        }
227    }
228
229    /*
230     * (non-Javadoc)
231     *
232     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getNumStates()
233     */
234    @Override
235    public int getNumSymbols() {
236        if (trie == null) {
237            return 0;
238        }
239        else {
240            return trie.getNumSymbols();
241        }
242    }
243
244    /*
245     * (non-Javadoc)
246     *
247     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getStateStrings()
248     */
249    @Override
250    public String[] getSymbolStrings() {
251        if (trie == null) {
252            return new String[0];
253        }
254        String[] stateStrings = new String[getNumSymbols()];
255        int i = 0;
256        for (Event symbol : trie.getKnownSymbols()) {
257            if (symbol.toString() == null) {
258                stateStrings[i] = "null";
259            }
260            else {
261                stateStrings[i] = symbol.toString();
262            }
263            i++;
264        }
265        return stateStrings;
266    }
267
268    /*
269     * (non-Javadoc)
270     *
271     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getEvents()
272     */
273    @Override
274    public Collection<Event> getEvents() {
275        if (trie == null) {
276            return new HashSet<Event>();
277        }
278        else {
279            return trie.getKnownSymbols();
280        }
281    }
282
283    /*
284     * (non-Javadoc)
285     *
286     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#generateSequences(int)
287     */
288    @Override
289    public Collection<List<Event>> generateSequences(int length) {
290        return generateSequences(length, false);
291    }
292
293    /*
294     * (non-Javadoc)
295     *
296     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#generateSequences(int, boolean)
297     */
298    @Override
299    public Set<List<Event>> generateSequences(int length, boolean fromStart) {
300        Set<List<Event>> sequenceSet = new LinkedHashSet<List<Event>>();
301        if (length < 1) {
302            throw new InvalidParameterException(
303                                                "Length of generated subsequences must be at least 1.");
304        }
305        if (length == 1) {
306            if (fromStart) {
307                List<Event> subSeq = new LinkedList<Event>();
308                subSeq.add(Event.STARTEVENT);
309                sequenceSet.add(subSeq);
310            }
311            else {
312                for (Event event : getEvents()) {
313                    List<Event> subSeq = new LinkedList<Event>();
314                    subSeq.add(event);
315                    sequenceSet.add(subSeq);
316                }
317            }
318            return sequenceSet;
319        }
320        Collection<Event> events = getEvents();
321        Collection<List<Event>> seqsShorter = generateSequences(length - 1, fromStart);
322        for (Event event : events) {
323            for (List<Event> seqShorter : seqsShorter) {
324                Event lastEvent = event;
325                if (getProbability(seqShorter, lastEvent) > 0.0) {
326                    List<Event> subSeq = new ArrayList<Event>(seqShorter);
327                    subSeq.add(lastEvent);
328                    sequenceSet.add(subSeq);
329                }
330            }
331        }
332        return sequenceSet;
333    }
334
335    /*
336     * (non-Javadoc)
337     *
338     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#generateValidSequences (int)
339     */
340    @Override
341    public Collection<List<Event>> generateValidSequences(int length) {
342        // check for min-length implicitly done by generateSequences
343        Collection<List<Event>> allSequences = generateSequences(length, true);
344        Collection<List<Event>> validSequences = new LinkedHashSet<List<Event>>();
345        for (List<Event> sequence : allSequences) {
346            if (sequence.size() == length &&
347                Event.ENDEVENT.equals(sequence.get(sequence.size() - 1)))
348            {
349                validSequences.add(sequence);
350            }
351        }
352        return validSequences;
353    }
354
355    /*
356     * (non-Javadoc)
357     *
358     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getProbability(java.util .List)
359     */
360    @Override
361    public double getProbability(List<Event> sequence) {
362        if (sequence == null) {
363            throw new InvalidParameterException("sequence must not be null");
364        }
365        double prob = 1.0;
366        List<Event> context = new LinkedList<Event>();
367        for (Event event : sequence) {
368            prob *= getProbability(context, event);
369            context.add(event);
370        }
371        return prob;
372    }
373
374    /*
375     * (non-Javadoc)
376     *
377     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getNumFOMStates()
378     */
379    @Override
380    public int getNumFOMStates() {
381        if (trie == null) {
382            return 0;
383        }
384        else {
385            return trie.getNumLeafAncestors();
386        }
387    }
388
389    /*
390     * (non-Javadoc)
391     *
392     * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getNumTransitions()
393     */
394    @Override
395    public int getNumTransitions() {
396        if (trie == null) {
397            return 0;
398        }
399        else {
400            return trie.getNumLeafs();
401        }
402    }
403}
Note: See TracBrowser for help on using the repository browser.