source: trunk/autoquest-core-usageprofiles/src/main/java/de/ugoe/cs/autoquest/usageprofiles/TrieBasedModel.java @ 1762

Last change on this file since 1762 was 1762, checked in by sherbold, 10 years ago
  • adapted getLogSum method to use log(x+1) instead of -log for the summation of the probabilities
  • Property svn:mime-type set to text/plain
File size: 13.5 KB
Line 
1//   Copyright 2012 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.autoquest.usageprofiles;
16
17import java.util.ArrayList;
18import java.util.Collection;
19import java.util.HashSet;
20import java.util.LinkedHashSet;
21import java.util.LinkedList;
22import java.util.List;
23import java.util.Random;
24import java.util.Set;
25
26import de.ugoe.cs.autoquest.eventcore.Event;
27import de.ugoe.cs.autoquest.usageprofiles.Trie.Edge;
28import de.ugoe.cs.autoquest.usageprofiles.Trie.TrieVertex;
29import edu.uci.ics.jung.graph.Tree;
30
31/**
32 * <p>
33 * Implements a skeleton for stochastic processes that can calculate probabilities based on a trie.
34 * The skeleton provides all functionalities of {@link IStochasticProcess} except
35 * {@link IStochasticProcess#getProbability(List, Event)}.
36 * </p>
37 *
38 * @author Steffen Herbold
39 * @version 1.0
40 */
41public abstract class TrieBasedModel implements IStochasticProcess {
42
43    /**
44     * <p>
45     * Id for object serialization.
46     * </p>
47     */
48    private static final long serialVersionUID = 1L;
49
50    /**
51     * <p>
52     * The order of the trie, i.e., the maximum length of subsequences stored in the trie.
53     * </p>
54     */
55    protected int trieOrder;
56
57    /**
58     * <p>
59     * Trie on which the probability calculations are based.
60     * </p>
61     */
62    protected Trie<Event> trie = null;
63
64    /**
65     * <p>
66     * Random number generator used by probabilistic sequence generation methods.
67     * </p>
68     */
69    protected final Random r;
70
71    /**
72     * <p>
73     * Constructor. Creates a new TrieBasedModel that can be used for stochastic processes with a
74     * Markov order less than or equal to {@code markovOrder}.
75     * </p>
76     *
77     * @param markovOrder
78     *            Markov order of the model
79     * @param r
80     *            random number generator used by probabilistic methods of the class
81     * @throws IllegalArgumentException
82     *             thrown if markovOrder is less than 0 or the random number generator r is null
83     */
84    public TrieBasedModel(int markovOrder, Random r) {
85        super();
86        if (markovOrder < 0) {
87            throw new IllegalArgumentException("markov order must not be less than 0");
88        }
89        if (r == null) {
90            throw new IllegalArgumentException("random number generator r must not be null");
91        }
92        this.trieOrder = markovOrder + 1;
93        this.r = r;
94    }
95
96    /**
97     * <p>
98     * Trains the model by generating a trie from which probabilities are calculated. The trie is
99     * newly generated based solely on the passed sequences. If an existing model should only be
100     * updated, use {@link #update(Collection)} instead.
101     * </p>
102     *
103     * @param sequences
104     *            training data
105     * @throws IllegalArgumentException
106     *             thrown is sequences is null
107     */
108    public void train(Collection<List<Event>> sequences) {
109        trie = null;
110        update(sequences);
111    }
112
113    /**
114     * <p>
115     * Trains the model by updating the trie from which the probabilities are calculated. This
116     * function updates an existing trie. In case no trie exists yet, a new trie is generated and
117     * the function behaves like {@link #train(Collection)}.
118     * </p>
119     *
120     * @param sequences
121     *            training data
122     * @throws IllegalArgumentException
123     *             thrown is sequences is null
124     */
125    public void update(Collection<List<Event>> sequences) {
126        if (sequences == null) {
127            throw new IllegalArgumentException("sequences must not be null");
128        }
129        if (trie == null) {
130            trie = new Trie<Event>();
131        }
132        for (List<Event> sequence : sequences) {
133            List<Event> currentSequence = new LinkedList<Event>(sequence); // defensive
134                                                                           // copy
135            currentSequence.add(0, Event.STARTEVENT);
136            currentSequence.add(Event.ENDEVENT);
137
138            trie.train(currentSequence, trieOrder);
139        }
140    }
141
142    /*
143     * (non-Javadoc)
144     *
145     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#randomSequence()
146     */
147    @Override
148    public List<Event> randomSequence() {
149        return randomSequence(Integer.MAX_VALUE, true);
150    }
151
152    /*
153     * (non-Javadoc)
154     *
155     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#randomSequence()
156     */
157    @Override
158    public List<Event> randomSequence(int maxLength, boolean validEnd) {
159        List<Event> sequence = new LinkedList<Event>();
160        if (trie != null) {
161            boolean endFound = false;
162            while (!endFound) { // outer loop for length checking
163                sequence = new LinkedList<Event>();
164                IncompleteMemory<Event> context = new IncompleteMemory<Event>(trieOrder - 1);
165                context.add(Event.STARTEVENT);
166
167                while (!endFound && sequence.size() <= maxLength) {
168                    double randVal = r.nextDouble();
169                    double probSum = 0.0;
170                    List<Event> currentContext = context.getLast(trieOrder);
171                    for (Event symbol : trie.getKnownSymbols()) {
172                        probSum += getProbability(currentContext, symbol);
173                        if (probSum >= randVal) {
174                            if (!(Event.STARTEVENT.equals(symbol) || Event.ENDEVENT.equals(symbol)))
175                            {
176                                // only add the symbol the sequence if it is not
177                                // START or END
178                                context.add(symbol);
179                                sequence.add(symbol);
180                            }
181                            endFound =
182                                (Event.ENDEVENT.equals(symbol)) ||
183                                    (!validEnd && sequence.size() == maxLength);
184                            break;
185                        }
186                    }
187                }
188            }
189        }
190        return sequence;
191    }
192
193    /**
194     * <p>
195     * Returns a Dot representation of the internal trie.
196     * </p>
197     *
198     * @return dot representation of the internal trie
199     */
200    public String getTrieDotRepresentation() {
201        if (trie == null) {
202            return "";
203        }
204        else {
205            return trie.getDotRepresentation();
206        }
207    }
208
209    /**
210     * <p>
211     * Returns a {@link Tree} of the internal trie that can be used for visualization.
212     * </p>
213     *
214     * @return {@link Tree} depicting the internal trie
215     */
216    public Tree<TrieVertex, Edge> getTrieGraph() {
217        if (trie == null) {
218            return null;
219        }
220        else {
221            return trie.getGraph();
222        }
223    }
224
225    /**
226     * <p>
227     * The string representation of the model is {@link Trie#toString()} of {@link #trie}.
228     * </p>
229     *
230     * @see java.lang.Object#toString()
231     */
232    @Override
233    public String toString() {
234        if (trie == null) {
235            return "";
236        }
237        else {
238            return trie.toString();
239        }
240    }
241
242    /*
243     * (non-Javadoc)
244     *
245     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumStates()
246     */
247    @Override
248    public int getNumSymbols() {
249        if (trie == null) {
250            return 0;
251        }
252        else {
253            return trie.getNumSymbols();
254        }
255    }
256
257    /*
258     * (non-Javadoc)
259     *
260     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getStateStrings()
261     */
262    @Override
263    public String[] getSymbolStrings() {
264        if (trie == null) {
265            return new String[0];
266        }
267        String[] stateStrings = new String[getNumSymbols()];
268        int i = 0;
269        for (Event symbol : trie.getKnownSymbols()) {
270            if (symbol.toString() == null) {
271                stateStrings[i] = "null";
272            }
273            else {
274                stateStrings[i] = symbol.toString();
275            }
276            i++;
277        }
278        return stateStrings;
279    }
280
281    /*
282     * (non-Javadoc)
283     *
284     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getEvents()
285     */
286    @Override
287    public Collection<Event> getEvents() {
288        if (trie == null) {
289            return new HashSet<Event>();
290        }
291        else {
292            return trie.getKnownSymbols();
293        }
294    }
295
296    /*
297     * (non-Javadoc)
298     *
299     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateSequences(int)
300     */
301    @Override
302    public Collection<List<Event>> generateSequences(int length) {
303        return generateSequences(length, false);
304    }
305
306    /*
307     * (non-Javadoc)
308     *
309     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateSequences(int, boolean)
310     */
311    @Override
312    public Set<List<Event>> generateSequences(int length, boolean fromStart) {
313        Set<List<Event>> sequenceSet = new LinkedHashSet<List<Event>>();
314        if (length < 1) {
315            throw new IllegalArgumentException(
316                                               "Length of generated subsequences must be at least 1.");
317        }
318        if (length == 1) {
319            if (fromStart) {
320                List<Event> subSeq = new LinkedList<Event>();
321                subSeq.add(Event.STARTEVENT);
322                sequenceSet.add(subSeq);
323            }
324            else {
325                for (Event event : getEvents()) {
326                    List<Event> subSeq = new LinkedList<Event>();
327                    subSeq.add(event);
328                    sequenceSet.add(subSeq);
329                }
330            }
331            return sequenceSet;
332        }
333        Collection<Event> events = getEvents();
334        Collection<List<Event>> seqsShorter = generateSequences(length - 1, fromStart);
335        for (Event event : events) {
336            for (List<Event> seqShorter : seqsShorter) {
337                Event lastEvent = event;
338                if (getProbability(seqShorter, lastEvent) > 0.0) {
339                    List<Event> subSeq = new ArrayList<Event>(seqShorter);
340                    subSeq.add(lastEvent);
341                    sequenceSet.add(subSeq);
342                }
343            }
344        }
345        return sequenceSet;
346    }
347
348    /*
349     * (non-Javadoc)
350     *
351     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateValidSequences (int)
352     */
353    @Override
354    public Collection<List<Event>> generateValidSequences(int length) {
355        // check for min-length implicitly done by generateSequences
356        Collection<List<Event>> allSequences = generateSequences(length, true);
357        Collection<List<Event>> validSequences = new LinkedHashSet<List<Event>>();
358        for (List<Event> sequence : allSequences) {
359            if (sequence.size() == length &&
360                Event.ENDEVENT.equals(sequence.get(sequence.size() - 1)))
361            {
362                validSequences.add(sequence);
363            }
364        }
365        return validSequences;
366    }
367
368    /*
369     * (non-Javadoc)
370     *
371     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getProbability(java.util .List)
372     */
373    @Override
374    public double getProbability(List<Event> sequence) {
375        if (sequence == null) {
376            throw new IllegalArgumentException("sequence must not be null");
377        }
378        double prob = 1.0;
379        List<Event> context = new LinkedList<Event>();
380        for (Event event : sequence) {
381            prob *= getProbability(context, event);
382            context.add(event);
383        }
384        return prob;
385    }
386
387    /*
388     * (non-Javadoc)
389     *
390     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getProbability(java.util .List)
391     */
392    @Override
393    public double getLogSum(List<Event> sequence) {
394        if (sequence == null) {
395            throw new IllegalArgumentException("sequence must not be null");
396        }
397        double odds = 0.0;
398        List<Event> context = new LinkedList<Event>();
399        for (Event event : sequence) {
400            odds = Math.log(getProbability(context, event)+1);
401            context.add(event);
402        }
403        return odds;
404    }
405
406    /*
407     * (non-Javadoc)
408     *
409     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumFOMStates()
410     */
411    @Override
412    public int getNumFOMStates() {
413        if (trie == null) {
414            return 0;
415        }
416        else {
417            return trie.getNumLeafAncestors();
418        }
419    }
420
421    /*
422     * (non-Javadoc)
423     *
424     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumTransitions()
425     */
426    @Override
427    public int getNumTransitions() {
428        if (trie == null) {
429            return 0;
430        }
431        else {
432            return trie.getNumLeafs();
433        }
434    }
435}
Note: See TracBrowser for help on using the repository browser.