source: trunk/autoquest-core-usageprofiles/src/main/java/de/ugoe/cs/autoquest/usageprofiles/Trie.java @ 1189

Last change on this file since 1189 was 1189, checked in by pharms, 11 years ago
  • remove a find bugs warning
  • Property svn:mime-type set to text/plain
File size: 21.9 KB
Line 
1//   Copyright 2012 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.autoquest.usageprofiles;
16
17import java.io.Serializable;
18import java.util.Collection;
19import java.util.HashSet;
20import java.util.LinkedHashSet;
21import java.util.LinkedList;
22import java.util.List;
23
24import de.ugoe.cs.util.StringTools;
25
26import edu.uci.ics.jung.graph.DelegateTree;
27import edu.uci.ics.jung.graph.Graph;
28import edu.uci.ics.jung.graph.Tree;
29
30/**
31 * <p>
 * This class implements a <i>trie</i>, i.e., a tree of sequences that stores the occurrence
 * counts of subsequences up to a predefined length. This length is the trie order.
34 * </p>
35 *
36 * @author Steffen Herbold, Patrick Harms
37 *
38 * @param <T>
39 *            Type of the symbols that are stored in the trie.
40 *
41 * @see TrieNode
42 */
43public class Trie<T> implements IDotCompatible, Serializable {
44
45    /**
46     * <p>
47     * Id for object serialization.
48     * </p>
49     */
50    private static final long serialVersionUID = 1L;
51
52    /**
53     * <p>
54     * Collection of all symbols occuring in the trie.
55     * </p>
56     */
57    private Collection<T> knownSymbols;
58
59    /**
60     * <p>
61     * Reference to the root of the trie.
62     * </p>
63     */
64    private final TrieNode<T> rootNode;
65
66    /**
67     * <p>
68     * Comparator to be used for comparing the symbols with each other
69     * </p>
70     */
71    private SymbolComparator<T> comparator;
72
73    /**
74     * <p>
75     * Contructor. Creates a new Trie with a {@link DefaultSymbolComparator}.
76     * </p>
77     */
78    public Trie() {
79        this(new DefaultSymbolComparator<T>());
80    }
81
82    /**
83     * <p>
84     * Contructor. Creates a new Trie with that uses a specific {@link SymbolComparator}.
85     * </p>
86     */
87    public Trie(SymbolComparator<T> comparator) {
88        this.comparator = comparator;
89        rootNode = new TrieNode<T>(comparator);
90        knownSymbols = new LinkedHashSet<T>();
91    }
92
93    /**
94     * <p>
95     * Copy-Constructor. Creates a new Trie as the copy of other. The other trie must not be null.
96     * </p>
97     *
98     * @param other
99     *            Trie that is copied
100     */
101    public Trie(Trie<T> other) {
102        if (other == null) {
103            throw new IllegalArgumentException("other trie must not be null");
104        }
105        rootNode = new TrieNode<T>(other.rootNode);
106        knownSymbols = new LinkedHashSet<T>(other.knownSymbols);
107        comparator = other.comparator;
108    }
109
110    /**
111     * <p>
112     * Returns a collection of all symbols occuring in the trie.
113     * </p>
114     *
115     * @return symbols occuring in the trie
116     */
117    public Collection<T> getKnownSymbols() {
118        return new LinkedHashSet<T>(knownSymbols);
119    }
120
121    /**
122     * <p>
123     * Trains the current trie using the given sequence and adds all subsequences of length
124     * {@code maxOrder}.
125     * </p>
126     *
127     * @param sequence
128     *            sequence whose subsequences are added to the trie
129     * @param maxOrder
130     *            maximum length of the subsequences added to the trie
131     */
132    public void train(List<T> sequence, int maxOrder) {
133        if (maxOrder < 1) {
134            return;
135        }
136        IncompleteMemory<T> latestActions = new IncompleteMemory<T>(maxOrder);
137        int i = 0;
138        for (T currentEvent : sequence) {
139            latestActions.add(currentEvent);
140            addToKnownSymbols(currentEvent);
141            i++;
142            if (i >= maxOrder) {
143                add(latestActions.getLast(maxOrder));
144            }
145        }
146        int sequenceLength = sequence.size();
147        int startIndex = Math.max(0, sequenceLength - maxOrder + 1);
148        for (int j = startIndex; j < sequenceLength; j++) {
149            add(sequence.subList(j, sequenceLength));
150        }
151    }
152
153    /**
154     * <p>
155     * Adds a given subsequence to the trie and increases the counters accordingly.
156     * </p>
157     *
158     * @param subsequence
159     *            subsequence whose counters are increased
160     * @see TrieNode#add(List)
161     */
162    protected void add(List<T> subsequence) {
163        if (subsequence != null && !subsequence.isEmpty()) {
164            addToKnownSymbols(subsequence);
165            subsequence = new LinkedList<T>(subsequence); // defensive copy!
166            T firstSymbol = subsequence.get(0);
167            TrieNode<T> node = getChildCreate(firstSymbol);
168            node.add(subsequence);
169        }
170    }
171
172    /**
173     * <p>
174     * Returns the child of the root node associated with the given symbol or creates it if it does
175     * not exist yet.
176     * </p>
177     *
178     * @param symbol
179     *            symbol whose node is required
180     * @return node associated with the symbol
181     * @see TrieNode#getChildCreate(Object)
182     */
183    protected TrieNode<T> getChildCreate(T symbol) {
184        return rootNode.getChildCreate(symbol);
185    }
186
187    /**
188     * <p>
189     * Returns the child of the root node associated with the given symbol or null if it does not
190     * exist.
191     * </p>
192     *
193     * @param symbol
194     *            symbol whose node is required
195     * @return node associated with the symbol; null if no such node exists
196     * @see TrieNode#getChild(Object)
197     */
198    protected TrieNode<T> getChild(T symbol) {
199        return rootNode.getChild(symbol);
200    }
201
202    /**
203     * <p>
204     * Returns the number of occurences of the given sequence.
205     * </p>
206     *
207     * @param sequence
208     *            sequence whose number of occurences is required
209     * @return number of occurences of the sequence
210     */
211    public int getCount(List<T> sequence) {
212        int count = 0;
213        TrieNode<T> node = find(sequence);
214        if (node != null) {
215            count = node.getCount();
216        }
217        return count;
218    }
219
220    /**
221     * <p>
222     * Returns the number of occurences of the given prefix and a symbol that follows it.<br>
223     * Convenience function to simplify usage of {@link #getCount(List)}.
224     * </p>
225     *
226     * @param sequence
227     *            prefix of the sequence
228     * @param follower
229     *            suffix of the sequence
230     * @return number of occurences of the sequence
231     * @see #getCount(List)
232     */
233    public int getCount(List<T> sequence, T follower) {
234        List<T> tmpSequence = new LinkedList<T>(sequence);
235        tmpSequence.add(follower);
236        return getCount(tmpSequence);
237
238    }
239
240    /**
241     * <p>
242     * Searches the trie for a given sequence and returns the node associated with the sequence or
243     * null if no such node is found.
244     * </p>
245     *
246     * @param sequence
247     *            sequence that is searched for
248     * @return node associated with the sequence
249     * @see TrieNode#find(List)
250     */
251    public TrieNode<T> find(List<T> sequence) {
252        if (sequence == null || sequence.isEmpty()) {
253            return rootNode;
254        }
255        List<T> sequenceCopy = new LinkedList<T>(sequence);
256        TrieNode<T> result = null;
257        TrieNode<T> node = getChild(sequenceCopy.get(0));
258        if (node != null) {
259            sequenceCopy.remove(0);
260            result = node.find(sequenceCopy);
261        }
262        return result;
263    }
264
265    /**
266     * <p>
267     * Returns a collection of all symbols that follow a given sequence in the trie. In case the
268     * sequence is not found or no symbols follow the sequence the result will be empty.
269     * </p>
270     *
271     * @param sequence
272     *            sequence whose followers are returned
273     * @return symbols following the given sequence
274     * @see TrieNode#getFollowingSymbols()
275     */
276    public Collection<T> getFollowingSymbols(List<T> sequence) {
277        Collection<T> result = new LinkedList<T>();
278        TrieNode<T> node = find(sequence);
279        if (node != null) {
280            result = node.getFollowingSymbols();
281        }
282        return result;
283    }
284
285    /**
286     * <p>
287     * Returns the longest suffix of the given context that is contained in the tree and whose
288     * children are leaves.
289     * </p>
290     *
291     * @param context
292     *            context whose suffix is searched for
293     * @return longest suffix of the context
294     */
295    public List<T> getContextSuffix(List<T> context) {
296        List<T> contextSuffix;
297        if (context != null) {
298            contextSuffix = new LinkedList<T>(context); // defensive copy
299        }
300        else {
301            contextSuffix = new LinkedList<T>();
302        }
303        boolean suffixFound = false;
304
305        while (!suffixFound) {
306            if (contextSuffix.isEmpty()) {
307                suffixFound = true; // suffix is the empty word
308            }
309            else {
310                TrieNode<T> node = find(contextSuffix);
311                if (node != null) {
312                    if (!node.getFollowingSymbols().isEmpty()) {
313                        suffixFound = true;
314                    }
315                }
316                if (!suffixFound) {
317                    contextSuffix.remove(0);
318                }
319            }
320        }
321
322        return contextSuffix;
323    }
324   
325    /**
326     * <p>
327     * used to recursively process the trie. The provided processor will be called for any path
328     * through the tree. The processor may abort the processing through returns values of its
329     * {@link TrieProcessor#process(List, int)} method.
330     * </p>
331     *
332     * @param processor the processor to process the tree
333     */
334    public void process(TrieProcessor<T> processor) {
335        LinkedList<T> context = new LinkedList<T>();
336       
337        for (TrieNode<T> child : rootNode.getChildren()) {
338            if (!process(context, child, processor)) {
339                break;
340            }
341        }
342    }
343
344    /**
345     * <p>
346     * processes a specific path by calling the provided processor. Furthermore, the method
347     * calls itself recursively for further subpaths.
348     * </p>
349     *
350     * @param context   the context of the currently processed trie node, i.e. the preceeding
351     *                  symbols
352     * @param child     the processed trie node
353     * @param processor the processor used for processing the trie
354     *
355     * @return true, if processing shall continue, false else
356     */
357    private boolean process(LinkedList<T>    context,
358                            TrieNode<T>      node,
359                            TrieProcessor<T> processor)
360    {
361        context.add(node.getSymbol());
362       
363        TrieProcessor.Result result = processor.process(context, node.getCount());
364       
365        if (result == TrieProcessor.Result.CONTINUE) {
366            for (TrieNode<T> child : node.getChildren()) {
367                if (!process(context, child, processor)) {
368                    break;
369                }
370            }
371        }
372       
373        context.removeLast();
374       
375        return result != TrieProcessor.Result.BREAK;
376    }
377
378    /**
379     * <p>
380     * returns a list of symbol sequences which have a minimal length and that occurred as often
381     * as defined by the given occurrence count. If the given occurrence count is smaller 1 then
382     * those sequences are returned, that occur most often. The resulting list is empty, if there
383     * is no symbol sequence with the minimal length or the provided number of occurrences.
384     * </p>
385     *
386     * @param minimalLength   the minimal length of the returned sequences
387     * @param occurrenceCount the number of occurrences of the returned sequences
388     *
389     * @return as described
390     */
391    public Collection<List<T>> getSequencesWithOccurrenceCount(int minimalLength,
392                                                               int occurrenceCount)
393    {
394        LinkedList<TrieNode<T>> context = new LinkedList<TrieNode<T>>();
395        Collection<List<TrieNode<T>>> paths = new LinkedList<List<TrieNode<T>>>();
396       
397        context.push(rootNode);
398       
399        // traverse the trie and determine all sequences, which have the provided number of
400        // occurrences and a minimal length.
401       
402        // minimalLength + 1 because we denote the depth including the root node
403        determineLongPathsWithMostOccurrences(minimalLength + 1, occurrenceCount, paths, context);
404       
405        Collection<List<T>> resultingPaths = new LinkedList<List<T>>();
406        List<T> resultingPath;
407       
408        if (paths.size() > 0) {
409           
410            for (List<TrieNode<T>> path : paths) {
411                resultingPath = new LinkedList<T>();
412               
413                for (TrieNode<T> node : path) {
414                    if (node.getSymbol() != null) {
415                        resultingPath.add(node.getSymbol());
416                    }
417                }
418               
419                resultingPaths.add(resultingPath);
420            }
421        }
422       
423        return resultingPaths;
424    }
425
426    /**
427     * <p>
428     * Traverses the trie to collect all sequences with a defined number of occurrences and with
429     * a minimal length. If the given occurrence count is smaller 1 then those sequences are
430     * searched that occur most often. The length of the sequences is encoded in the provided
431     * recursion depth.
432     * </p>
433     *
434     * @param minimalDepth    the minimal recursion depth to be done
435     * @param occurrenceCount the number of occurrences of the returned sequences
436     * @param paths           the paths through the trie that all occurred with the same amount
437     *                        (if occurrence count is smaller 1, the paths which occurred most
438     *                        often) and that have the so far found matching number of occurrences
439     *                        (is updated each time a further path with the same number of
440     *                        occurrences is found; if occurrence count is smaller 1
441     *                        it is replaced if a path with more occurrences is found)
442     * @param context         the path through the trie, that is analyzed by the recursive call
443     */
444    private void determineLongPathsWithMostOccurrences(int                           minimalDepth,
445                                                       int                           occurrenceCount,
446                                                       Collection<List<TrieNode<T>>> paths,
447                                                       LinkedList<TrieNode<T>>       context)
448    {
449        int envisagedCount = occurrenceCount;
450
451        // only if we already reached the depth to be achieved, we check if the paths have the
452        // required number of occurrences
453        if (context.size() >= minimalDepth) {
454           
455            if (envisagedCount < 1) {
456                // try to determine the maximum number of occurrences so far, if any
457                if (paths.size() > 0) {
458                    List<TrieNode<T>> path = paths.iterator().next();
459                    envisagedCount = path.get(path.size() - 1).getCount();
460                }
461
462                // if the current path has a higher number of occurrences than all so far, clear
463                // the paths collected so far and set the new number of occurrences as new maximum
464                if (context.getLast().getCount() > envisagedCount) {
465                    paths.clear();
466                    envisagedCount = context.getLast().getCount();
467                }
468            }
469           
470            // if the path matches the current maximal number of occurrences, add it to the list
471            // of collected paths with these number of occurrences
472            if (context.getLast().getCount() == envisagedCount) {
473                paths.add(new LinkedList<TrieNode<T>>(context));
474            }
475        }
476       
477        // perform the trie traversal
478        for (TrieNode<T> child : context.getLast().getChildren()) {
479            if (child.getCount() >= envisagedCount) {
480                context.add(child);
481                determineLongPathsWithMostOccurrences
482                    (minimalDepth, occurrenceCount, paths, context);
483                context.removeLast();
484            }
485        }
486    }
487   
488    /**
489     * <p>
490     * adds a new symbol to the collection of known symbols if this symbol is not already
491     * contained. The symbols are compared using the comparator.
492     * </p>
493     *
494     * @param symbol the symbol to be added to the known symbols
495     */
496    private void addToKnownSymbols(T symbol) {
497        for (T knownSymbol : knownSymbols) {
498            if (comparator.equals(knownSymbol, symbol)) {
499                return;
500            }
501        }
502       
503        knownSymbols.add(symbol);
504    }
505
506    /**
507     * <p>
508     * adds a list of new symbols to the collection of known symbols. Uses the
509     * {@link #addToKnownSymbols(Object)} method for each element of the provided list.
510     * </p>
511     *
512     * @param symbols the list of symbols to be added to the known symbols
513     */
514    private void addToKnownSymbols(List<T> symbols) {
515        for (T symbol : symbols) {
516            addToKnownSymbols(symbol);
517        }
518    }
519
520    /**
521     * <p>
522     * Helper class for graph visualization of a trie.
523     * </p>
524     *
525     * @author Steffen Herbold
526     * @version 1.0
527     */
528    static public class Edge {}
529
530    /**
531     * <p>
532     * Helper class for graph visualization of a trie.
533     * </p>
534     *
535     * @author Steffen Herbold
536     * @version 1.0
537     */
538    static public class TrieVertex {
539
540        /**
541         * <p>
542         * Id of the vertex.
543         * </p>
544         */
545        private String id;
546
547        /**
548         * <p>
549         * Contructor. Creates a new TrieVertex.
550         * </p>
551         *
552         * @param id
553         *            id of the vertex
554         */
555        protected TrieVertex(String id) {
556            this.id = id;
557        }
558
559        /**
560         * <p>
561         * Returns the id of the vertex.
562         * </p>
563         *
564         * @see java.lang.Object#toString()
565         */
566        @Override
567        public String toString() {
568            return id;
569        }
570    }
571
572    /**
573     * <p>
574     * Returns a {@link Graph} representation of the trie.
575     * </p>
576     *
577     * @return {@link Graph} representation of the trie
578     */
579    protected Tree<TrieVertex, Edge> getGraph() {
580        DelegateTree<TrieVertex, Edge> graph = new DelegateTree<TrieVertex, Edge>();
581        rootNode.getGraph(null, graph);
582        return graph;
583    }
584
585    /*
586     * (non-Javadoc)
587     *
588     * @see de.ugoe.cs.autoquest.usageprofiles.IDotCompatible#getDotRepresentation()
589     */
590    public String getDotRepresentation() {
591        StringBuilder stringBuilder = new StringBuilder();
592        stringBuilder.append("digraph model {" + StringTools.ENDLINE);
593        rootNode.appendDotRepresentation(stringBuilder);
594        stringBuilder.append('}' + StringTools.ENDLINE);
595        return stringBuilder.toString();
596    }
597
598    /**
599     * <p>
600     * Returns the string representation of the root node.
601     * </p>
602     *
603     * @see TrieNode#toString()
604     * @see java.lang.Object#toString()
605     */
606    @Override
607    public String toString() {
608        return rootNode.toString();
609    }
610
611    /**
612     * <p>
613     * Returns the number of symbols contained in the trie.
614     * </p>
615     *
616     * @return number of symbols contained in the trie
617     */
618    public int getNumSymbols() {
619        return knownSymbols.size();
620    }
621
622    /**
623     * <p>
624     * Returns the number of trie nodes that are ancestors of a leaf. This is the equivalent to the
625     * number of states a first-order markov model would have.
626     * <p>
627     *
628     * @return number of trie nodes that are ancestors of leafs.
629     */
630    public int getNumLeafAncestors() {
631        List<TrieNode<T>> ancestors = new LinkedList<TrieNode<T>>();
632        rootNode.getLeafAncestors(ancestors);
633        return ancestors.size();
634    }
635
636    /**
637     * <p>
638     * Returns the number of trie nodes that are leafs.
639     * </p>
640     *
641     * @return number of leafs in the trie
642     */
643    public int getNumLeafs() {
644        return rootNode.getNumLeafs();
645    }
646
647    /**
648     * <p>
649     * Updates the list of known symbols by replacing it with all symbols that are found in the
650     * child nodes of the root node. This should be the same as all symbols that are contained in
651     * the trie.
652     * </p>
653     */
654    public void updateKnownSymbols() {
655        knownSymbols = new HashSet<T>();
656        for (TrieNode<T> node : rootNode.getChildren()) {
657            addToKnownSymbols(node.getSymbol());
658        }
659    }
660
661    /**
662     * <p>
663     * Two Tries are defined as equal, if their {@link #rootNode} are equal.
664     * </p>
665     *
666     * @see java.lang.Object#equals(java.lang.Object)
667     */
668    @SuppressWarnings("rawtypes")
669    @Override
670    public boolean equals(Object other) {
671        if (other == this) {
672            return true;
673        }
674        if (other instanceof Trie) {
675            return rootNode.equals(((Trie) other).rootNode);
676        }
677        return false;
678    }
679
680    /*
681     * (non-Javadoc)
682     *
683     * @see java.lang.Object#hashCode()
684     */
685    @Override
686    public int hashCode() {
687        int multiplier = 17;
688        int hash = 42;
689        if (rootNode != null) {
690            hash = multiplier * hash + rootNode.hashCode();
691        }
692        return hash;
693    }
694
695}
Note: See TracBrowser for help on using the repository browser.