source: trunk/autoquest-core-usageprofiles/src/main/java/de/ugoe/cs/autoquest/usageprofiles/TrieBasedModel.java @ 2026

Last change on this file since 2026 was 2026, checked in by sherbold, 9 years ago
  • made generation of random sequences with an expected valid end and a predefined maximum length more robust. the generation now aborts after a user-defined number of attempts to create a valid sequence and returns an empty sequence instead.
  • Property svn:mime-type set to text/plain
File size: 13.6 KB
Line 
1//   Copyright 2012 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.autoquest.usageprofiles;
16
17import java.util.ArrayList;
18import java.util.Collection;
19import java.util.HashSet;
20import java.util.LinkedHashSet;
21import java.util.LinkedList;
22import java.util.List;
23import java.util.Random;
24import java.util.Set;
25
26import de.ugoe.cs.autoquest.eventcore.Event;
27import de.ugoe.cs.autoquest.usageprofiles.Trie.Edge;
28import de.ugoe.cs.autoquest.usageprofiles.Trie.TrieVertex;
29import edu.uci.ics.jung.graph.Tree;
30
31/**
32 * <p>
33 * Implements a skeleton for stochastic processes that can calculate probabilities based on a trie.
34 * The skeleton provides all functionalities of {@link IStochasticProcess} except
35 * {@link IStochasticProcess#getProbability(List, Event)}.
36 * </p>
37 *
38 * @author Steffen Herbold
39 * @version 1.0
40 */
41public abstract class TrieBasedModel implements IStochasticProcess {
42
43    /**
44     * <p>
45     * Id for object serialization.
46     * </p>
47     */
48    private static final long serialVersionUID = 1L;
49
50    /**
51     * <p>
52     * The order of the trie, i.e., the maximum length of subsequences stored in the trie.
53     * </p>
54     */
55    protected int trieOrder;
56
57    /**
58     * <p>
59     * Trie on which the probability calculations are based.
60     * </p>
61     */
62    protected Trie<Event> trie = null;
63
64    /**
65     * <p>
66     * Random number generator used by probabilistic sequence generation methods.
67     * </p>
68     */
69    protected final Random r;
70
71    /**
72     * <p>
73     * Constructor. Creates a new TrieBasedModel that can be used for stochastic processes with a
74     * Markov order less than or equal to {@code markovOrder}.
75     * </p>
76     *
77     * @param markovOrder
78     *            Markov order of the model
79     * @param r
80     *            random number generator used by probabilistic methods of the class
81     * @throws IllegalArgumentException
82     *             thrown if markovOrder is less than 0 or the random number generator r is null
83     */
84    public TrieBasedModel(int markovOrder, Random r) {
85        super();
86        if (markovOrder < 0) {
87            throw new IllegalArgumentException("markov order must not be less than 0");
88        }
89        if (r == null) {
90            throw new IllegalArgumentException("random number generator r must not be null");
91        }
92        this.trieOrder = markovOrder + 1;
93        this.r = r;
94    }
95
96    /**
97     * <p>
98     * Trains the model by generating a trie from which probabilities are calculated. The trie is
99     * newly generated based solely on the passed sequences. If an existing model should only be
100     * updated, use {@link #update(Collection)} instead.
101     * </p>
102     *
103     * @param sequences
104     *            training data
105     * @throws IllegalArgumentException
106     *             thrown is sequences is null
107     */
108    public void train(Collection<List<Event>> sequences) {
109        trie = null;
110        update(sequences);
111    }
112
113    /**
114     * <p>
115     * Trains the model by updating the trie from which the probabilities are calculated. This
116     * function updates an existing trie. In case no trie exists yet, a new trie is generated and
117     * the function behaves like {@link #train(Collection)}.
118     * </p>
119     *
120     * @param sequences
121     *            training data
122     * @throws IllegalArgumentException
123     *             thrown is sequences is null
124     */
125    public void update(Collection<List<Event>> sequences) {
126        if (sequences == null) {
127            throw new IllegalArgumentException("sequences must not be null");
128        }
129        if (trie == null) {
130            trie = new Trie<Event>();
131        }
132        for (List<Event> sequence : sequences) {
133            List<Event> currentSequence = new LinkedList<Event>(sequence); // defensive
134                                                                           // copy
135            currentSequence.add(0, Event.STARTEVENT);
136            currentSequence.add(Event.ENDEVENT);
137
138            trie.train(currentSequence, trieOrder);
139        }
140    }
141
142    /*
143     * (non-Javadoc)
144     *
145     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#randomSequence()
146     */
147    @Override
148    public List<Event> randomSequence() {
149        return randomSequence(Integer.MAX_VALUE, true, 100);
150    }
151
152    /*
153     * (non-Javadoc)
154     *
155     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#randomSequence()
156     */
157    @Override
158    public List<Event> randomSequence(int maxLength, boolean validEnd, long maxIter) {
159        List<Event> sequence = new LinkedList<Event>();
160        int attempts = 0;
161        if (trie != null) {
162            boolean endFound = false;
163            while (!endFound && attempts <= maxIter) { // outer loop for length checking
164                IncompleteMemory<Event> context = new IncompleteMemory<Event>(trieOrder - 1);
165                context.add(Event.STARTEVENT);
166
167                while (!endFound && sequence.size() <= maxLength) {
168                    double randVal = r.nextDouble();
169                    double probSum = 0.0;
170                    List<Event> currentContext = context.getLast(trieOrder);
171                    for (Event symbol : trie.getKnownSymbols()) {
172                        probSum += getProbability(currentContext, symbol);
173                        if (probSum >= randVal) {
174                            if (!(Event.STARTEVENT.equals(symbol) || Event.ENDEVENT.equals(symbol)))
175                            {
176                                // only add the symbol the sequence if it is not
177                                // START or END
178                                context.add(symbol);
179                                sequence.add(symbol);
180                            }
181                            endFound =
182                                (Event.ENDEVENT.equals(symbol)) ||
183                                    (!validEnd && sequence.size() == maxLength);
184                            break;
185                        }
186                    }
187                }
188                if (!endFound) {
189                    sequence = new LinkedList<Event>();
190                }
191                attempts++;
192            }
193        }
194        return sequence;
195    }
196
197    /**
198     * <p>
199     * Returns a Dot representation of the internal trie.
200     * </p>
201     *
202     * @return dot representation of the internal trie
203     */
204    public String getTrieDotRepresentation() {
205        if (trie == null) {
206            return "";
207        }
208        else {
209            return trie.getDotRepresentation();
210        }
211    }
212
213    /**
214     * <p>
215     * Returns a {@link Tree} of the internal trie that can be used for visualization.
216     * </p>
217     *
218     * @return {@link Tree} depicting the internal trie
219     */
220    public Tree<TrieVertex, Edge> getTrieGraph() {
221        if (trie == null) {
222            return null;
223        }
224        else {
225            return trie.getGraph();
226        }
227    }
228
229    /**
230     * <p>
231     * The string representation of the model is {@link Trie#toString()} of {@link #trie}.
232     * </p>
233     *
234     * @see java.lang.Object#toString()
235     */
236    @Override
237    public String toString() {
238        if (trie == null) {
239            return "";
240        }
241        else {
242            return trie.toString();
243        }
244    }
245
246    /*
247     * (non-Javadoc)
248     *
249     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumStates()
250     */
251    @Override
252    public int getNumSymbols() {
253        if (trie == null) {
254            return 0;
255        }
256        else {
257            return trie.getNumSymbols();
258        }
259    }
260
261    /*
262     * (non-Javadoc)
263     *
264     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getStateStrings()
265     */
266    @Override
267    public String[] getSymbolStrings() {
268        if (trie == null) {
269            return new String[0];
270        }
271        String[] stateStrings = new String[getNumSymbols()];
272        int i = 0;
273        for (Event symbol : trie.getKnownSymbols()) {
274            if (symbol.toString() == null) {
275                stateStrings[i] = "null";
276            }
277            else {
278                stateStrings[i] = symbol.toString();
279            }
280            i++;
281        }
282        return stateStrings;
283    }
284
285    /*
286     * (non-Javadoc)
287     *
288     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getEvents()
289     */
290    @Override
291    public Collection<Event> getEvents() {
292        if (trie == null) {
293            return new HashSet<Event>();
294        }
295        else {
296            return trie.getKnownSymbols();
297        }
298    }
299
300    /*
301     * (non-Javadoc)
302     *
303     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateSequences(int)
304     */
305    @Override
306    public Collection<List<Event>> generateSequences(int length) {
307        return generateSequences(length, false);
308    }
309
310    /*
311     * (non-Javadoc)
312     *
313     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateSequences(int, boolean)
314     */
315    @Override
316    public Set<List<Event>> generateSequences(int length, boolean fromStart) {
317        Set<List<Event>> sequenceSet = new LinkedHashSet<List<Event>>();
318        if (length < 1) {
319            throw new IllegalArgumentException(
320                                               "Length of generated subsequences must be at least 1.");
321        }
322        if (length == 1) {
323            if (fromStart) {
324                List<Event> subSeq = new LinkedList<Event>();
325                subSeq.add(Event.STARTEVENT);
326                sequenceSet.add(subSeq);
327            }
328            else {
329                for (Event event : getEvents()) {
330                    List<Event> subSeq = new LinkedList<Event>();
331                    subSeq.add(event);
332                    sequenceSet.add(subSeq);
333                }
334            }
335            return sequenceSet;
336        }
337        Collection<Event> events = getEvents();
338        Collection<List<Event>> seqsShorter = generateSequences(length - 1, fromStart);
339        for (Event event : events) {
340            for (List<Event> seqShorter : seqsShorter) {
341                Event lastEvent = event;
342                if (getProbability(seqShorter, lastEvent) > 0.0) {
343                    List<Event> subSeq = new ArrayList<Event>(seqShorter);
344                    subSeq.add(lastEvent);
345                    sequenceSet.add(subSeq);
346                }
347            }
348        }
349        return sequenceSet;
350    }
351
352    /*
353     * (non-Javadoc)
354     *
355     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateValidSequences (int)
356     */
357    @Override
358    public Collection<List<Event>> generateValidSequences(int length) {
359        // check for min-length implicitly done by generateSequences
360        Collection<List<Event>> allSequences = generateSequences(length, true);
361        Collection<List<Event>> validSequences = new LinkedHashSet<List<Event>>();
362        for (List<Event> sequence : allSequences) {
363            if (sequence.size() == length &&
364                Event.ENDEVENT.equals(sequence.get(sequence.size() - 1)))
365            {
366                validSequences.add(sequence);
367            }
368        }
369        return validSequences;
370    }
371
372    /*
373     * (non-Javadoc)
374     *
375     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getProbability(java.util .List)
376     */
377    @Override
378    public double getProbability(List<Event> sequence) {
379        if (sequence == null) {
380            throw new IllegalArgumentException("sequence must not be null");
381        }
382        double prob = 1.0;
383        List<Event> context = new LinkedList<Event>();
384        for (Event event : sequence) {
385            prob *= getProbability(context, event);
386            context.add(event);
387        }
388        return prob;
389    }
390
391    /*
392     * (non-Javadoc)
393     *
394     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getProbability(java.util .List)
395     */
396    @Override
397    public double getLogSum(List<Event> sequence) {
398        if (sequence == null) {
399            throw new IllegalArgumentException("sequence must not be null");
400        }
401        double odds = 0.0;
402        List<Event> context = new LinkedList<Event>();
403        for (Event event : sequence) {
404            odds += Math.log(getProbability(context, event) + 1);
405            context.add(event);
406        }
407        return odds;
408    }
409
410    /*
411     * (non-Javadoc)
412     *
413     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumFOMStates()
414     */
415    @Override
416    public int getNumFOMStates() {
417        if (trie == null) {
418            return 0;
419        }
420        else {
421            return trie.getNumLeafAncestors();
422        }
423    }
424
425    /*
426     * (non-Javadoc)
427     *
428     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumTransitions()
429     */
430    @Override
431    public int getNumTransitions() {
432        if (trie == null) {
433            return 0;
434        }
435        else {
436            return trie.getNumLeafs();
437        }
438    }
439}
Note: See TracBrowser for help on using the repository browser.