source: trunk/autoquest-core-usageprofiles/src/main/java/de/ugoe/cs/autoquest/usageprofiles/TrieBasedModel.java @ 2026

Last change on this file since 2026 was 2026, checked in by sherbold, 9 years ago
  • Made generation of random sequences with an expected valid end and a predefined maximum length more robust. The generation now aborts after a user-defined number of attempts to create a valid sequence and returns an empty sequence instead.
  • Property svn:mime-type set to text/plain
File size: 13.6 KB
RevLine 
[927]1//   Copyright 2012 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
[922]15package de.ugoe.cs.autoquest.usageprofiles;
[518]16
17import java.util.ArrayList;
18import java.util.Collection;
19import java.util.HashSet;
20import java.util.LinkedHashSet;
21import java.util.LinkedList;
22import java.util.List;
23import java.util.Random;
24import java.util.Set;
25
[922]26import de.ugoe.cs.autoquest.eventcore.Event;
27import de.ugoe.cs.autoquest.usageprofiles.Trie.Edge;
28import de.ugoe.cs.autoquest.usageprofiles.Trie.TrieVertex;
[518]29import edu.uci.ics.jung.graph.Tree;
30
31/**
32 * <p>
[559]33 * Implements a skeleton for stochastic processes that can calculate probabilities based on a trie.
34 * The skeleton provides all functionalities of {@link IStochasticProcess} except
[518]35 * {@link IStochasticProcess#getProbability(List, Event)}.
36 * </p>
37 *
38 * @author Steffen Herbold
39 * @version 1.0
40 */
41public abstract class TrieBasedModel implements IStochasticProcess {
42
[559]43    /**
44     * <p>
45     * Id for object serialization.
46     * </p>
47     */
48    private static final long serialVersionUID = 1L;
[518]49
[559]50    /**
51     * <p>
52     * The order of the trie, i.e., the maximum length of subsequences stored in the trie.
53     * </p>
54     */
55    protected int trieOrder;
[518]56
[559]57    /**
58     * <p>
59     * Trie on which the probability calculations are based.
60     * </p>
61     */
62    protected Trie<Event> trie = null;
[518]63
[559]64    /**
65     * <p>
66     * Random number generator used by probabilistic sequence generation methods.
67     * </p>
68     */
69    protected final Random r;
[518]70
[559]71    /**
72     * <p>
73     * Constructor. Creates a new TrieBasedModel that can be used for stochastic processes with a
74     * Markov order less than or equal to {@code markovOrder}.
75     * </p>
76     *
77     * @param markovOrder
78     *            Markov order of the model
79     * @param r
80     *            random number generator used by probabilistic methods of the class
[766]81     * @throws IllegalArgumentException
[559]82     *             thrown if markovOrder is less than 0 or the random number generator r is null
83     */
84    public TrieBasedModel(int markovOrder, Random r) {
85        super();
86        if (markovOrder < 0) {
[766]87            throw new IllegalArgumentException("markov order must not be less than 0");
[559]88        }
89        if (r == null) {
[766]90            throw new IllegalArgumentException("random number generator r must not be null");
[559]91        }
92        this.trieOrder = markovOrder + 1;
93        this.r = r;
94    }
[518]95
[559]96    /**
97     * <p>
98     * Trains the model by generating a trie from which probabilities are calculated. The trie is
99     * newly generated based solely on the passed sequences. If an existing model should only be
100     * updated, use {@link #update(Collection)} instead.
101     * </p>
102     *
103     * @param sequences
104     *            training data
[766]105     * @throws IllegalArgumentException
[559]106     *             thrown is sequences is null
107     */
108    public void train(Collection<List<Event>> sequences) {
109        trie = null;
110        update(sequences);
111    }
[518]112
[559]113    /**
114     * <p>
115     * Trains the model by updating the trie from which the probabilities are calculated. This
116     * function updates an existing trie. In case no trie exists yet, a new trie is generated and
117     * the function behaves like {@link #train(Collection)}.
118     * </p>
119     *
120     * @param sequences
121     *            training data
[766]122     * @throws IllegalArgumentException
[559]123     *             thrown is sequences is null
124     */
125    public void update(Collection<List<Event>> sequences) {
126        if (sequences == null) {
[766]127            throw new IllegalArgumentException("sequences must not be null");
[559]128        }
129        if (trie == null) {
130            trie = new Trie<Event>();
131        }
132        for (List<Event> sequence : sequences) {
133            List<Event> currentSequence = new LinkedList<Event>(sequence); // defensive
134                                                                           // copy
135            currentSequence.add(0, Event.STARTEVENT);
136            currentSequence.add(Event.ENDEVENT);
[518]137
[559]138            trie.train(currentSequence, trieOrder);
139        }
140    }
[518]141
[559]142    /*
143     * (non-Javadoc)
144     *
[922]145     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#randomSequence()
[559]146     */
147    @Override
148    public List<Event> randomSequence() {
[2026]149        return randomSequence(Integer.MAX_VALUE, true, 100);
[559]150    }
[518]151
[559]152    /*
153     * (non-Javadoc)
154     *
[922]155     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#randomSequence()
[559]156     */
157    @Override
[2026]158    public List<Event> randomSequence(int maxLength, boolean validEnd, long maxIter) {
[559]159        List<Event> sequence = new LinkedList<Event>();
[2026]160        int attempts = 0;
[559]161        if (trie != null) {
162            boolean endFound = false;
[2026]163            while (!endFound && attempts <= maxIter) { // outer loop for length checking
[559]164                IncompleteMemory<Event> context = new IncompleteMemory<Event>(trieOrder - 1);
165                context.add(Event.STARTEVENT);
[518]166
[559]167                while (!endFound && sequence.size() <= maxLength) {
168                    double randVal = r.nextDouble();
169                    double probSum = 0.0;
170                    List<Event> currentContext = context.getLast(trieOrder);
171                    for (Event symbol : trie.getKnownSymbols()) {
172                        probSum += getProbability(currentContext, symbol);
173                        if (probSum >= randVal) {
174                            if (!(Event.STARTEVENT.equals(symbol) || Event.ENDEVENT.equals(symbol)))
175                            {
176                                // only add the symbol the sequence if it is not
177                                // START or END
178                                context.add(symbol);
179                                sequence.add(symbol);
180                            }
181                            endFound =
182                                (Event.ENDEVENT.equals(symbol)) ||
183                                    (!validEnd && sequence.size() == maxLength);
184                            break;
185                        }
186                    }
187                }
[2026]188                if (!endFound) {
189                    sequence = new LinkedList<Event>();
190                }
191                attempts++;
[559]192            }
193        }
194        return sequence;
195    }
[518]196
[559]197    /**
198     * <p>
199     * Returns a Dot representation of the internal trie.
200     * </p>
201     *
202     * @return dot representation of the internal trie
203     */
204    public String getTrieDotRepresentation() {
205        if (trie == null) {
206            return "";
207        }
208        else {
209            return trie.getDotRepresentation();
210        }
211    }
[518]212
[559]213    /**
214     * <p>
215     * Returns a {@link Tree} of the internal trie that can be used for visualization.
216     * </p>
217     *
218     * @return {@link Tree} depicting the internal trie
219     */
220    public Tree<TrieVertex, Edge> getTrieGraph() {
221        if (trie == null) {
222            return null;
223        }
224        else {
225            return trie.getGraph();
226        }
227    }
[518]228
[559]229    /**
230     * <p>
231     * The string representation of the model is {@link Trie#toString()} of {@link #trie}.
232     * </p>
233     *
234     * @see java.lang.Object#toString()
235     */
236    @Override
237    public String toString() {
238        if (trie == null) {
239            return "";
240        }
241        else {
242            return trie.toString();
243        }
244    }
[518]245
[559]246    /*
247     * (non-Javadoc)
248     *
[922]249     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumStates()
[559]250     */
251    @Override
252    public int getNumSymbols() {
253        if (trie == null) {
254            return 0;
255        }
256        else {
257            return trie.getNumSymbols();
258        }
259    }
[518]260
[559]261    /*
262     * (non-Javadoc)
263     *
[922]264     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getStateStrings()
[559]265     */
266    @Override
267    public String[] getSymbolStrings() {
268        if (trie == null) {
269            return new String[0];
270        }
271        String[] stateStrings = new String[getNumSymbols()];
272        int i = 0;
273        for (Event symbol : trie.getKnownSymbols()) {
274            if (symbol.toString() == null) {
275                stateStrings[i] = "null";
276            }
277            else {
278                stateStrings[i] = symbol.toString();
279            }
280            i++;
281        }
282        return stateStrings;
283    }
[518]284
[559]285    /*
286     * (non-Javadoc)
287     *
[922]288     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getEvents()
[559]289     */
290    @Override
291    public Collection<Event> getEvents() {
292        if (trie == null) {
293            return new HashSet<Event>();
294        }
295        else {
296            return trie.getKnownSymbols();
297        }
298    }
[518]299
[559]300    /*
301     * (non-Javadoc)
302     *
[922]303     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateSequences(int)
[559]304     */
305    @Override
306    public Collection<List<Event>> generateSequences(int length) {
307        return generateSequences(length, false);
308    }
[518]309
[559]310    /*
311     * (non-Javadoc)
312     *
[922]313     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateSequences(int, boolean)
[559]314     */
315    @Override
316    public Set<List<Event>> generateSequences(int length, boolean fromStart) {
317        Set<List<Event>> sequenceSet = new LinkedHashSet<List<Event>>();
318        if (length < 1) {
[766]319            throw new IllegalArgumentException(
[1641]320                                               "Length of generated subsequences must be at least 1.");
[559]321        }
322        if (length == 1) {
323            if (fromStart) {
324                List<Event> subSeq = new LinkedList<Event>();
325                subSeq.add(Event.STARTEVENT);
326                sequenceSet.add(subSeq);
327            }
328            else {
329                for (Event event : getEvents()) {
330                    List<Event> subSeq = new LinkedList<Event>();
331                    subSeq.add(event);
332                    sequenceSet.add(subSeq);
333                }
334            }
335            return sequenceSet;
336        }
337        Collection<Event> events = getEvents();
338        Collection<List<Event>> seqsShorter = generateSequences(length - 1, fromStart);
339        for (Event event : events) {
340            for (List<Event> seqShorter : seqsShorter) {
341                Event lastEvent = event;
342                if (getProbability(seqShorter, lastEvent) > 0.0) {
343                    List<Event> subSeq = new ArrayList<Event>(seqShorter);
344                    subSeq.add(lastEvent);
345                    sequenceSet.add(subSeq);
346                }
347            }
348        }
349        return sequenceSet;
350    }
[518]351
[559]352    /*
353     * (non-Javadoc)
354     *
[922]355     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateValidSequences (int)
[559]356     */
357    @Override
358    public Collection<List<Event>> generateValidSequences(int length) {
359        // check for min-length implicitly done by generateSequences
360        Collection<List<Event>> allSequences = generateSequences(length, true);
361        Collection<List<Event>> validSequences = new LinkedHashSet<List<Event>>();
362        for (List<Event> sequence : allSequences) {
363            if (sequence.size() == length &&
364                Event.ENDEVENT.equals(sequence.get(sequence.size() - 1)))
365            {
366                validSequences.add(sequence);
367            }
368        }
369        return validSequences;
370    }
[518]371
[559]372    /*
373     * (non-Javadoc)
374     *
[922]375     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getProbability(java.util .List)
[559]376     */
377    @Override
378    public double getProbability(List<Event> sequence) {
379        if (sequence == null) {
[766]380            throw new IllegalArgumentException("sequence must not be null");
[559]381        }
382        double prob = 1.0;
383        List<Event> context = new LinkedList<Event>();
384        for (Event event : sequence) {
385            prob *= getProbability(context, event);
386            context.add(event);
387        }
388        return prob;
389    }
[518]390
[559]391    /*
392     * (non-Javadoc)
393     *
[1641]394     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getProbability(java.util .List)
395     */
396    @Override
397    public double getLogSum(List<Event> sequence) {
398        if (sequence == null) {
399            throw new IllegalArgumentException("sequence must not be null");
400        }
401        double odds = 0.0;
402        List<Event> context = new LinkedList<Event>();
403        for (Event event : sequence) {
[2026]404            odds += Math.log(getProbability(context, event) + 1);
[1641]405            context.add(event);
406        }
407        return odds;
408    }
409
410    /*
411     * (non-Javadoc)
412     *
[922]413     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumFOMStates()
[559]414     */
415    @Override
416    public int getNumFOMStates() {
417        if (trie == null) {
418            return 0;
419        }
420        else {
421            return trie.getNumLeafAncestors();
422        }
423    }
[518]424
[559]425    /*
426     * (non-Javadoc)
427     *
[922]428     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumTransitions()
[559]429     */
430    @Override
431    public int getNumTransitions() {
432        if (trie == null) {
433            return 0;
434        }
435        else {
436            return trie.getNumLeafs();
437        }
438    }
439}
Note: See TracBrowser for help on using the repository browser.