source: trunk/autoquest-core-usageprofiles/src/main/java/de/ugoe/cs/autoquest/usageprofiles/TrieBasedModel.java @ 922

Last change on this file since 922 was 922, checked in by sherbold, 12 years ago
  • renaming of packages from de.ugoe.cs.quest to de.ugoe.cs.autoquest
  • Property svn:mime-type set to text/plain
File size: 12.3 KB
Line 
1package de.ugoe.cs.autoquest.usageprofiles;
2
3import java.util.ArrayList;
4import java.util.Collection;
5import java.util.HashSet;
6import java.util.LinkedHashSet;
7import java.util.LinkedList;
8import java.util.List;
9import java.util.Random;
10import java.util.Set;
11
12import de.ugoe.cs.autoquest.eventcore.Event;
13import de.ugoe.cs.autoquest.usageprofiles.Trie.Edge;
14import de.ugoe.cs.autoquest.usageprofiles.Trie.TrieVertex;
15import edu.uci.ics.jung.graph.Tree;
16
17/**
18 * <p>
19 * Implements a skeleton for stochastic processes that can calculate probabilities based on a trie.
20 * The skeleton provides all functionalities of {@link IStochasticProcess} except
21 * {@link IStochasticProcess#getProbability(List, Event)}.
22 * </p>
23 *
24 * @author Steffen Herbold
25 * @version 1.0
26 */
27public abstract class TrieBasedModel implements IStochasticProcess {
28
29    /**
30     * <p>
31     * Id for object serialization.
32     * </p>
33     */
34    private static final long serialVersionUID = 1L;
35
36    /**
37     * <p>
38     * The order of the trie, i.e., the maximum length of subsequences stored in the trie.
39     * </p>
40     */
41    protected int trieOrder;
42
43    /**
44     * <p>
45     * Trie on which the probability calculations are based.
46     * </p>
47     */
48    protected Trie<Event> trie = null;
49
50    /**
51     * <p>
52     * Random number generator used by probabilistic sequence generation methods.
53     * </p>
54     */
55    protected final Random r;
56
57    /**
58     * <p>
59     * Constructor. Creates a new TrieBasedModel that can be used for stochastic processes with a
60     * Markov order less than or equal to {@code markovOrder}.
61     * </p>
62     *
63     * @param markovOrder
64     *            Markov order of the model
65     * @param r
66     *            random number generator used by probabilistic methods of the class
67     * @throws IllegalArgumentException
68     *             thrown if markovOrder is less than 0 or the random number generator r is null
69     */
70    public TrieBasedModel(int markovOrder, Random r) {
71        super();
72        if (markovOrder < 0) {
73            throw new IllegalArgumentException("markov order must not be less than 0");
74        }
75        if (r == null) {
76            throw new IllegalArgumentException("random number generator r must not be null");
77        }
78        this.trieOrder = markovOrder + 1;
79        this.r = r;
80    }
81
82    /**
83     * <p>
84     * Trains the model by generating a trie from which probabilities are calculated. The trie is
85     * newly generated based solely on the passed sequences. If an existing model should only be
86     * updated, use {@link #update(Collection)} instead.
87     * </p>
88     *
89     * @param sequences
90     *            training data
91     * @throws IllegalArgumentException
92     *             thrown is sequences is null
93     */
94    public void train(Collection<List<Event>> sequences) {
95        trie = null;
96        update(sequences);
97    }
98
99    /**
100     * <p>
101     * Trains the model by updating the trie from which the probabilities are calculated. This
102     * function updates an existing trie. In case no trie exists yet, a new trie is generated and
103     * the function behaves like {@link #train(Collection)}.
104     * </p>
105     *
106     * @param sequences
107     *            training data
108     * @throws IllegalArgumentException
109     *             thrown is sequences is null
110     */
111    public void update(Collection<List<Event>> sequences) {
112        if (sequences == null) {
113            throw new IllegalArgumentException("sequences must not be null");
114        }
115        if (trie == null) {
116            trie = new Trie<Event>();
117        }
118        for (List<Event> sequence : sequences) {
119            List<Event> currentSequence = new LinkedList<Event>(sequence); // defensive
120                                                                           // copy
121            currentSequence.add(0, Event.STARTEVENT);
122            currentSequence.add(Event.ENDEVENT);
123
124            trie.train(currentSequence, trieOrder);
125        }
126    }
127
128    /*
129     * (non-Javadoc)
130     *
131     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#randomSequence()
132     */
133    @Override
134    public List<Event> randomSequence() {
135        return randomSequence(Integer.MAX_VALUE, true);
136    }
137
138    /*
139     * (non-Javadoc)
140     *
141     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#randomSequence()
142     */
143    @Override
144    public List<Event> randomSequence(int maxLength, boolean validEnd) {
145        List<Event> sequence = new LinkedList<Event>();
146        if (trie != null) {
147            boolean endFound = false;
148            while (!endFound) { // outer loop for length checking
149                sequence = new LinkedList<Event>();
150                IncompleteMemory<Event> context = new IncompleteMemory<Event>(trieOrder - 1);
151                context.add(Event.STARTEVENT);
152
153                while (!endFound && sequence.size() <= maxLength) {
154                    double randVal = r.nextDouble();
155                    double probSum = 0.0;
156                    List<Event> currentContext = context.getLast(trieOrder);
157                    for (Event symbol : trie.getKnownSymbols()) {
158                        probSum += getProbability(currentContext, symbol);
159                        if (probSum >= randVal) {
160                            if (!(Event.STARTEVENT.equals(symbol) || Event.ENDEVENT.equals(symbol)))
161                            {
162                                // only add the symbol the sequence if it is not
163                                // START or END
164                                context.add(symbol);
165                                sequence.add(symbol);
166                            }
167                            endFound =
168                                (Event.ENDEVENT.equals(symbol)) ||
169                                    (!validEnd && sequence.size() == maxLength);
170                            break;
171                        }
172                    }
173                }
174            }
175        }
176        return sequence;
177    }
178
179    /**
180     * <p>
181     * Returns a Dot representation of the internal trie.
182     * </p>
183     *
184     * @return dot representation of the internal trie
185     */
186    public String getTrieDotRepresentation() {
187        if (trie == null) {
188            return "";
189        }
190        else {
191            return trie.getDotRepresentation();
192        }
193    }
194
195    /**
196     * <p>
197     * Returns a {@link Tree} of the internal trie that can be used for visualization.
198     * </p>
199     *
200     * @return {@link Tree} depicting the internal trie
201     */
202    public Tree<TrieVertex, Edge> getTrieGraph() {
203        if (trie == null) {
204            return null;
205        }
206        else {
207            return trie.getGraph();
208        }
209    }
210
211    /**
212     * <p>
213     * The string representation of the model is {@link Trie#toString()} of {@link #trie}.
214     * </p>
215     *
216     * @see java.lang.Object#toString()
217     */
218    @Override
219    public String toString() {
220        if (trie == null) {
221            return "";
222        }
223        else {
224            return trie.toString();
225        }
226    }
227
228    /*
229     * (non-Javadoc)
230     *
231     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumStates()
232     */
233    @Override
234    public int getNumSymbols() {
235        if (trie == null) {
236            return 0;
237        }
238        else {
239            return trie.getNumSymbols();
240        }
241    }
242
243    /*
244     * (non-Javadoc)
245     *
246     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getStateStrings()
247     */
248    @Override
249    public String[] getSymbolStrings() {
250        if (trie == null) {
251            return new String[0];
252        }
253        String[] stateStrings = new String[getNumSymbols()];
254        int i = 0;
255        for (Event symbol : trie.getKnownSymbols()) {
256            if (symbol.toString() == null) {
257                stateStrings[i] = "null";
258            }
259            else {
260                stateStrings[i] = symbol.toString();
261            }
262            i++;
263        }
264        return stateStrings;
265    }
266
267    /*
268     * (non-Javadoc)
269     *
270     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getEvents()
271     */
272    @Override
273    public Collection<Event> getEvents() {
274        if (trie == null) {
275            return new HashSet<Event>();
276        }
277        else {
278            return trie.getKnownSymbols();
279        }
280    }
281
282    /*
283     * (non-Javadoc)
284     *
285     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateSequences(int)
286     */
287    @Override
288    public Collection<List<Event>> generateSequences(int length) {
289        return generateSequences(length, false);
290    }
291
292    /*
293     * (non-Javadoc)
294     *
295     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateSequences(int, boolean)
296     */
297    @Override
298    public Set<List<Event>> generateSequences(int length, boolean fromStart) {
299        Set<List<Event>> sequenceSet = new LinkedHashSet<List<Event>>();
300        if (length < 1) {
301            throw new IllegalArgumentException(
302                                                "Length of generated subsequences must be at least 1.");
303        }
304        if (length == 1) {
305            if (fromStart) {
306                List<Event> subSeq = new LinkedList<Event>();
307                subSeq.add(Event.STARTEVENT);
308                sequenceSet.add(subSeq);
309            }
310            else {
311                for (Event event : getEvents()) {
312                    List<Event> subSeq = new LinkedList<Event>();
313                    subSeq.add(event);
314                    sequenceSet.add(subSeq);
315                }
316            }
317            return sequenceSet;
318        }
319        Collection<Event> events = getEvents();
320        Collection<List<Event>> seqsShorter = generateSequences(length - 1, fromStart);
321        for (Event event : events) {
322            for (List<Event> seqShorter : seqsShorter) {
323                Event lastEvent = event;
324                if (getProbability(seqShorter, lastEvent) > 0.0) {
325                    List<Event> subSeq = new ArrayList<Event>(seqShorter);
326                    subSeq.add(lastEvent);
327                    sequenceSet.add(subSeq);
328                }
329            }
330        }
331        return sequenceSet;
332    }
333
334    /*
335     * (non-Javadoc)
336     *
337     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#generateValidSequences (int)
338     */
339    @Override
340    public Collection<List<Event>> generateValidSequences(int length) {
341        // check for min-length implicitly done by generateSequences
342        Collection<List<Event>> allSequences = generateSequences(length, true);
343        Collection<List<Event>> validSequences = new LinkedHashSet<List<Event>>();
344        for (List<Event> sequence : allSequences) {
345            if (sequence.size() == length &&
346                Event.ENDEVENT.equals(sequence.get(sequence.size() - 1)))
347            {
348                validSequences.add(sequence);
349            }
350        }
351        return validSequences;
352    }
353
354    /*
355     * (non-Javadoc)
356     *
357     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getProbability(java.util .List)
358     */
359    @Override
360    public double getProbability(List<Event> sequence) {
361        if (sequence == null) {
362            throw new IllegalArgumentException("sequence must not be null");
363        }
364        double prob = 1.0;
365        List<Event> context = new LinkedList<Event>();
366        for (Event event : sequence) {
367            prob *= getProbability(context, event);
368            context.add(event);
369        }
370        return prob;
371    }
372
373    /*
374     * (non-Javadoc)
375     *
376     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumFOMStates()
377     */
378    @Override
379    public int getNumFOMStates() {
380        if (trie == null) {
381            return 0;
382        }
383        else {
384            return trie.getNumLeafAncestors();
385        }
386    }
387
388    /*
389     * (non-Javadoc)
390     *
391     * @see de.ugoe.cs.autoquest.usageprofiles.IStochasticProcess#getNumTransitions()
392     */
393    @Override
394    public int getNumTransitions() {
395        if (trie == null) {
396            return 0;
397        }
398        else {
399            return trie.getNumLeafs();
400        }
401    }
402}
Note: See TracBrowser for help on using the repository browser.