source: trunk/autoquest-core-usageprofiles/src/main/java/de/ugoe/cs/autoquest/usageprofiles/Trie.java @ 1251

Last change on this file since 1251 was 1251, checked in by pharms, 11 years ago
  • improved performance of the trie for large alphabets by using the symbol map. This is an improved list of symbols that allows a more efficient lookup of symbols by using buckets of symbols as an initial search order
  • Property svn:mime-type set to text/plain
File size: 21.5 KB
Line 
1//   Copyright 2012 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.autoquest.usageprofiles;
16
17import java.io.Serializable;
18import java.util.Collection;
19import java.util.LinkedHashSet;
20import java.util.LinkedList;
21import java.util.List;
22
23import de.ugoe.cs.util.StringTools;
24
25import edu.uci.ics.jung.graph.DelegateTree;
26import edu.uci.ics.jung.graph.Graph;
27import edu.uci.ics.jung.graph.Tree;
28
29/**
30 * <p>
31 * This class implements a <it>trie</it>, i.e., a tree of sequences that the occurence of
32 * subsequences up to a predefined length. This length is the trie order.
33 * </p>
34 *
35 * @author Steffen Herbold, Patrick Harms
36 *
37 * @param <T>
38 *            Type of the symbols that are stored in the trie.
39 *
40 * @see TrieNode
41 */
42public class Trie<T> implements IDotCompatible, Serializable {
43
44    /**
45     * <p>
46     * Id for object serialization.
47     * </p>
48     */
49    private static final long serialVersionUID = 1L;
50
51    /**
52     * <p>
53     * Collection of all symbols occuring in the trie.
54     * </p>
55     */
56    private SymbolMap<T, T> knownSymbols;
57
58    /**
59     * <p>
60     * Reference to the root of the trie.
61     * </p>
62     */
63    private final TrieNode<T> rootNode;
64
65    /**
66     * <p>
67     * Comparator to be used for comparing the symbols with each other
68     * </p>
69     */
70    private SymbolComparator<T> comparator;
71
72    /**
73     * <p>
74     * Contructor. Creates a new Trie with a {@link DefaultSymbolComparator}.
75     * </p>
76     */
77    public Trie() {
78        this(new DefaultSymbolComparator<T>());
79    }
80
81    /**
82     * <p>
83     * Contructor. Creates a new Trie with that uses a specific {@link SymbolComparator}.
84     * </p>
85     */
86    public Trie(SymbolComparator<T> comparator) {
87        this.comparator = comparator;
88        rootNode = new TrieNode<T>(comparator);
89        knownSymbols = new SymbolMap<T, T>(this.comparator);
90    }
91
92    /**
93     * <p>
94     * Copy-Constructor. Creates a new Trie as the copy of other. The other trie must not be null.
95     * </p>
96     *
97     * @param other
98     *            Trie that is copied
99     */
100    public Trie(Trie<T> other) {
101        if (other == null) {
102            throw new IllegalArgumentException("other trie must not be null");
103        }
104        rootNode = new TrieNode<T>(other.rootNode);
105        knownSymbols = new SymbolMap<T, T>(other.knownSymbols);
106        comparator = other.comparator;
107    }
108
109    /**
110     * <p>
111     * Returns a collection of all symbols occuring in the trie.
112     * </p>
113     *
114     * @return symbols occuring in the trie
115     */
116    public Collection<T> getKnownSymbols() {
117        return new LinkedHashSet<T>(knownSymbols.getSymbols());
118    }
119
120    /**
121     * <p>
122     * Trains the current trie using the given sequence and adds all subsequences of length
123     * {@code maxOrder}.
124     * </p>
125     *
126     * @param sequence
127     *            sequence whose subsequences are added to the trie
128     * @param maxOrder
129     *            maximum length of the subsequences added to the trie
130     */
131    public void train(List<T> sequence, int maxOrder) {
132        if (maxOrder < 1) {
133            return;
134        }
135        IncompleteMemory<T> latestActions = new IncompleteMemory<T>(maxOrder);
136        int i = 0;
137        for (T currentEvent : sequence) {
138            latestActions.add(currentEvent);
139            addToKnownSymbols(currentEvent);
140            i++;
141            if (i >= maxOrder) {
142                add(latestActions.getLast(maxOrder));
143            }
144        }
145        int sequenceLength = sequence.size();
146        int startIndex = Math.max(0, sequenceLength - maxOrder + 1);
147        for (int j = startIndex; j < sequenceLength; j++) {
148            add(sequence.subList(j, sequenceLength));
149        }
150    }
151
152    /**
153     * <p>
154     * Adds a given subsequence to the trie and increases the counters accordingly. NOTE: This
155     * method does not add the symbols to the list of known symbols. This is only ensured using
156     * the method {@link #train(List, int)}.
157     * </p>
158     *
159     * @param subsequence
160     *            subsequence whose counters are increased
161     * @see TrieNode#add(List)
162     */
163    protected void add(List<T> subsequence) {
164        if (subsequence != null && !subsequence.isEmpty()) {
165            subsequence = new LinkedList<T>(subsequence); // defensive copy!
166            T firstSymbol = subsequence.get(0);
167            TrieNode<T> node = getChildCreate(firstSymbol);
168            node.add(subsequence);
169        }
170    }
171
172    /**
173     * <p>
174     * Returns the child of the root node associated with the given symbol or creates it if it does
175     * not exist yet.
176     * </p>
177     *
178     * @param symbol
179     *            symbol whose node is required
180     * @return node associated with the symbol
181     * @see TrieNode#getChildCreate(Object)
182     */
183    protected TrieNode<T> getChildCreate(T symbol) {
184        return rootNode.getChildCreate(symbol);
185    }
186
187    /**
188     * <p>
189     * Returns the child of the root node associated with the given symbol or null if it does not
190     * exist.
191     * </p>
192     *
193     * @param symbol
194     *            symbol whose node is required
195     * @return node associated with the symbol; null if no such node exists
196     * @see TrieNode#getChild(Object)
197     */
198    protected TrieNode<T> getChild(T symbol) {
199        return rootNode.getChild(symbol);
200    }
201
202    /**
203     * <p>
204     * Returns the number of occurences of the given sequence.
205     * </p>
206     *
207     * @param sequence
208     *            sequence whose number of occurences is required
209     * @return number of occurences of the sequence
210     */
211    public int getCount(List<T> sequence) {
212        int count = 0;
213        TrieNode<T> node = find(sequence);
214        if (node != null) {
215            count = node.getCount();
216        }
217        return count;
218    }
219
220    /**
221     * <p>
222     * Returns the number of occurences of the given prefix and a symbol that follows it.<br>
223     * Convenience function to simplify usage of {@link #getCount(List)}.
224     * </p>
225     *
226     * @param sequence
227     *            prefix of the sequence
228     * @param follower
229     *            suffix of the sequence
230     * @return number of occurences of the sequence
231     * @see #getCount(List)
232     */
233    public int getCount(List<T> sequence, T follower) {
234        List<T> tmpSequence = new LinkedList<T>(sequence);
235        tmpSequence.add(follower);
236        return getCount(tmpSequence);
237
238    }
239
240    /**
241     * <p>
242     * Searches the trie for a given sequence and returns the node associated with the sequence or
243     * null if no such node is found.
244     * </p>
245     *
246     * @param sequence
247     *            sequence that is searched for
248     * @return node associated with the sequence
249     * @see TrieNode#find(List)
250     */
251    public TrieNode<T> find(List<T> sequence) {
252        if (sequence == null || sequence.isEmpty()) {
253            return rootNode;
254        }
255        List<T> sequenceCopy = new LinkedList<T>(sequence);
256        TrieNode<T> result = null;
257        TrieNode<T> node = getChild(sequenceCopy.get(0));
258        if (node != null) {
259            sequenceCopy.remove(0);
260            result = node.find(sequenceCopy);
261        }
262        return result;
263    }
264
265    /**
266     * <p>
267     * Returns a collection of all symbols that follow a given sequence in the trie. In case the
268     * sequence is not found or no symbols follow the sequence the result will be empty.
269     * </p>
270     *
271     * @param sequence
272     *            sequence whose followers are returned
273     * @return symbols following the given sequence
274     * @see TrieNode#getFollowingSymbols()
275     */
276    public Collection<T> getFollowingSymbols(List<T> sequence) {
277        Collection<T> result = new LinkedList<T>();
278        TrieNode<T> node = find(sequence);
279        if (node != null) {
280            result = node.getFollowingSymbols();
281        }
282        return result;
283    }
284
285    /**
286     * <p>
287     * Returns the longest suffix of the given context that is contained in the tree and whose
288     * children are leaves.
289     * </p>
290     *
291     * @param context
292     *            context whose suffix is searched for
293     * @return longest suffix of the context
294     */
295    public List<T> getContextSuffix(List<T> context) {
296        List<T> contextSuffix;
297        if (context != null) {
298            contextSuffix = new LinkedList<T>(context); // defensive copy
299        }
300        else {
301            contextSuffix = new LinkedList<T>();
302        }
303        boolean suffixFound = false;
304
305        while (!suffixFound) {
306            if (contextSuffix.isEmpty()) {
307                suffixFound = true; // suffix is the empty word
308            }
309            else {
310                TrieNode<T> node = find(contextSuffix);
311                if (node != null) {
312                    if (!node.getFollowingSymbols().isEmpty()) {
313                        suffixFound = true;
314                    }
315                }
316                if (!suffixFound) {
317                    contextSuffix.remove(0);
318                }
319            }
320        }
321
322        return contextSuffix;
323    }
324   
325    /**
326     * <p>
327     * used to recursively process the trie. The provided processor will be called for any path
328     * through the tree. The processor may abort the processing through returns values of its
329     * {@link TrieProcessor#process(List, int)} method.
330     * </p>
331     *
332     * @param processor the processor to process the tree
333     */
334    public void process(TrieProcessor<T> processor) {
335        LinkedList<T> context = new LinkedList<T>();
336       
337        for (TrieNode<T> child : rootNode.getChildren()) {
338            if (!process(context, child, processor)) {
339                break;
340            }
341        }
342    }
343
344    /**
345     * <p>
346     * processes a specific path by calling the provided processor. Furthermore, the method
347     * calls itself recursively for further subpaths.
348     * </p>
349     *
350     * @param context   the context of the currently processed trie node, i.e. the preceeding
351     *                  symbols
352     * @param child     the processed trie node
353     * @param processor the processor used for processing the trie
354     *
355     * @return true, if processing shall continue, false else
356     */
357    private boolean process(LinkedList<T>    context,
358                            TrieNode<T>      node,
359                            TrieProcessor<T> processor)
360    {
361        context.add(node.getSymbol());
362       
363        TrieProcessor.Result result = processor.process(context, node.getCount());
364       
365        if (result == TrieProcessor.Result.CONTINUE) {
366            for (TrieNode<T> child : node.getChildren()) {
367                if (!process(context, child, processor)) {
368                    break;
369                }
370            }
371        }
372       
373        context.removeLast();
374       
375        return result != TrieProcessor.Result.BREAK;
376    }
377
378    /**
379     * <p>
380     * returns a list of symbol sequences which have a minimal length and that occurred as often
381     * as defined by the given occurrence count. If the given occurrence count is smaller 1 then
382     * those sequences are returned, that occur most often. The resulting list is empty, if there
383     * is no symbol sequence with the minimal length or the provided number of occurrences.
384     * </p>
385     *
386     * @param minimalLength   the minimal length of the returned sequences
387     * @param occurrenceCount the number of occurrences of the returned sequences
388     *
389     * @return as described
390     */
391    public Collection<List<T>> getSequencesWithOccurrenceCount(int minimalLength,
392                                                               int occurrenceCount)
393    {
394        LinkedList<TrieNode<T>> context = new LinkedList<TrieNode<T>>();
395        Collection<List<TrieNode<T>>> paths = new LinkedList<List<TrieNode<T>>>();
396       
397        context.push(rootNode);
398       
399        // traverse the trie and determine all sequences, which have the provided number of
400        // occurrences and a minimal length.
401       
402        // minimalLength + 1 because we denote the depth including the root node
403        determineLongPathsWithMostOccurrences(minimalLength + 1, occurrenceCount, paths, context);
404       
405        Collection<List<T>> resultingPaths = new LinkedList<List<T>>();
406        List<T> resultingPath;
407       
408        if (paths.size() > 0) {
409           
410            for (List<TrieNode<T>> path : paths) {
411                resultingPath = new LinkedList<T>();
412               
413                for (TrieNode<T> node : path) {
414                    if (node.getSymbol() != null) {
415                        resultingPath.add(node.getSymbol());
416                    }
417                }
418               
419                resultingPaths.add(resultingPath);
420            }
421        }
422       
423        return resultingPaths;
424    }
425
426    /**
427     * <p>
428     * Traverses the trie to collect all sequences with a defined number of occurrences and with
429     * a minimal length. If the given occurrence count is smaller 1 then those sequences are
430     * searched that occur most often. The length of the sequences is encoded in the provided
431     * recursion depth.
432     * </p>
433     *
434     * @param minimalDepth    the minimal recursion depth to be done
435     * @param occurrenceCount the number of occurrences of the returned sequences
436     * @param paths           the paths through the trie that all occurred with the same amount
437     *                        (if occurrence count is smaller 1, the paths which occurred most
438     *                        often) and that have the so far found matching number of occurrences
439     *                        (is updated each time a further path with the same number of
440     *                        occurrences is found; if occurrence count is smaller 1
441     *                        it is replaced if a path with more occurrences is found)
442     * @param context         the path through the trie, that is analyzed by the recursive call
443     */
444    private void determineLongPathsWithMostOccurrences(int                           minimalDepth,
445                                                       int                           occurrenceCount,
446                                                       Collection<List<TrieNode<T>>> paths,
447                                                       LinkedList<TrieNode<T>>       context)
448    {
449        int envisagedCount = occurrenceCount;
450
451        // only if we already reached the depth to be achieved, we check if the paths have the
452        // required number of occurrences
453        if (context.size() >= minimalDepth) {
454           
455            if (envisagedCount < 1) {
456                // try to determine the maximum number of occurrences so far, if any
457                if (paths.size() > 0) {
458                    List<TrieNode<T>> path = paths.iterator().next();
459                    envisagedCount = path.get(path.size() - 1).getCount();
460                }
461
462                // if the current path has a higher number of occurrences than all so far, clear
463                // the paths collected so far and set the new number of occurrences as new maximum
464                if (context.getLast().getCount() > envisagedCount) {
465                    paths.clear();
466                    envisagedCount = context.getLast().getCount();
467                }
468            }
469           
470            // if the path matches the current maximal number of occurrences, add it to the list
471            // of collected paths with these number of occurrences
472            if (context.getLast().getCount() == envisagedCount) {
473                paths.add(new LinkedList<TrieNode<T>>(context));
474            }
475        }
476       
477        // perform the trie traversal
478        for (TrieNode<T> child : context.getLast().getChildren()) {
479            if (child.getCount() >= envisagedCount) {
480                context.add(child);
481                determineLongPathsWithMostOccurrences
482                    (minimalDepth, occurrenceCount, paths, context);
483                context.removeLast();
484            }
485        }
486    }
487   
488    /**
489     * <p>
490     * adds a new symbol to the collection of known symbols if this symbol is not already
491     * contained. The symbols are compared using the comparator.
492     * </p>
493     *
494     * @param symbol the symbol to be added to the known symbols
495     */
496    private void addToKnownSymbols(T symbol) {
497        if (!knownSymbols.containsSymbol(symbol)) {
498            knownSymbols.addSymbol(symbol, symbol);
499        }
500    }
501
502    /**
503     * <p>
504     * Helper class for graph visualization of a trie.
505     * </p>
506     *
507     * @author Steffen Herbold
508     * @version 1.0
509     */
510    static public class Edge {}
511
512    /**
513     * <p>
514     * Helper class for graph visualization of a trie.
515     * </p>
516     *
517     * @author Steffen Herbold
518     * @version 1.0
519     */
520    static public class TrieVertex {
521
522        /**
523         * <p>
524         * Id of the vertex.
525         * </p>
526         */
527        private String id;
528
529        /**
530         * <p>
531         * Contructor. Creates a new TrieVertex.
532         * </p>
533         *
534         * @param id
535         *            id of the vertex
536         */
537        protected TrieVertex(String id) {
538            this.id = id;
539        }
540
541        /**
542         * <p>
543         * Returns the id of the vertex.
544         * </p>
545         *
546         * @see java.lang.Object#toString()
547         */
548        @Override
549        public String toString() {
550            return id;
551        }
552    }
553
554    /**
555     * <p>
556     * Returns a {@link Graph} representation of the trie.
557     * </p>
558     *
559     * @return {@link Graph} representation of the trie
560     */
561    protected Tree<TrieVertex, Edge> getGraph() {
562        DelegateTree<TrieVertex, Edge> graph = new DelegateTree<TrieVertex, Edge>();
563        rootNode.getGraph(null, graph);
564        return graph;
565    }
566
567    /*
568     * (non-Javadoc)
569     *
570     * @see de.ugoe.cs.autoquest.usageprofiles.IDotCompatible#getDotRepresentation()
571     */
572    public String getDotRepresentation() {
573        StringBuilder stringBuilder = new StringBuilder();
574        stringBuilder.append("digraph model {" + StringTools.ENDLINE);
575        rootNode.appendDotRepresentation(stringBuilder);
576        stringBuilder.append('}' + StringTools.ENDLINE);
577        return stringBuilder.toString();
578    }
579
580    /**
581     * <p>
582     * Returns the string representation of the root node.
583     * </p>
584     *
585     * @see TrieNode#toString()
586     * @see java.lang.Object#toString()
587     */
588    @Override
589    public String toString() {
590        return rootNode.toString();
591    }
592
593    /**
594     * <p>
595     * Returns the number of symbols contained in the trie.
596     * </p>
597     *
598     * @return number of symbols contained in the trie
599     */
600    public int getNumSymbols() {
601        return knownSymbols.size();
602    }
603
604    /**
605     * <p>
606     * Returns the number of trie nodes that are ancestors of a leaf. This is the equivalent to the
607     * number of states a first-order markov model would have.
608     * <p>
609     *
610     * @return number of trie nodes that are ancestors of leafs.
611     */
612    public int getNumLeafAncestors() {
613        List<TrieNode<T>> ancestors = new LinkedList<TrieNode<T>>();
614        rootNode.getLeafAncestors(ancestors);
615        return ancestors.size();
616    }
617
618    /**
619     * <p>
620     * Returns the number of trie nodes that are leafs.
621     * </p>
622     *
623     * @return number of leafs in the trie
624     */
625    public int getNumLeafs() {
626        return rootNode.getNumLeafs();
627    }
628
629    /**
630     * <p>
631     * Updates the list of known symbols by replacing it with all symbols that are found in the
632     * child nodes of the root node. This should be the same as all symbols that are contained in
633     * the trie.
634     * </p>
635     */
636    public void updateKnownSymbols() {
637        knownSymbols = new SymbolMap<T, T>(this.comparator);
638        for (TrieNode<T> node : rootNode.getChildren()) {
639            addToKnownSymbols(node.getSymbol());
640        }
641    }
642
643    /**
644     * <p>
645     * Two Tries are defined as equal, if their {@link #rootNode} are equal.
646     * </p>
647     *
648     * @see java.lang.Object#equals(java.lang.Object)
649     */
650    @SuppressWarnings("rawtypes")
651    @Override
652    public boolean equals(Object other) {
653        if (other == this) {
654            return true;
655        }
656        if (other instanceof Trie) {
657            return rootNode.equals(((Trie) other).rootNode);
658        }
659        return false;
660    }
661
662    /*
663     * (non-Javadoc)
664     *
665     * @see java.lang.Object#hashCode()
666     */
667    @Override
668    public int hashCode() {
669        int multiplier = 17;
670        int hash = 42;
671        if (rootNode != null) {
672            hash = multiplier * hash + rootNode.hashCode();
673        }
674        return hash;
675    }
676
677}
Note: See TracBrowser for help on using the repository browser.