source: trunk/autoquest-core-usageprofiles/src/main/java/de/ugoe/cs/autoquest/usageprofiles/Trie.java @ 1110

Last change on this file since 1110 was 1110, checked in by pharms, 11 years ago
  • allowed to search for all sub sequences with a dedicated number of occurrence
  • Property svn:mime-type set to text/plain
File size: 20.0 KB
Line 
1//   Copyright 2012 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.autoquest.usageprofiles;
16
17import java.io.Serializable;
18import java.util.Collection;
19import java.util.HashSet;
20import java.util.LinkedHashSet;
21import java.util.LinkedList;
22import java.util.List;
23
24import de.ugoe.cs.util.StringTools;
25
26import edu.uci.ics.jung.graph.DelegateTree;
27import edu.uci.ics.jung.graph.Graph;
28import edu.uci.ics.jung.graph.Tree;
29
30/**
31 * <p>
32 * This class implements a <it>trie</it>, i.e., a tree of sequences that the occurence of
33 * subsequences up to a predefined length. This length is the trie order.
34 * </p>
35 *
36 * @author Steffen Herbold, Patrick Harms
37 *
38 * @param <T>
39 *            Type of the symbols that are stored in the trie.
40 *
41 * @see TrieNode
42 */
43public class Trie<T> implements IDotCompatible, Serializable {
44
45    /**
46     * <p>
47     * Id for object serialization.
48     * </p>
49     */
50    private static final long serialVersionUID = 1L;
51
52    /**
53     * <p>
54     * Collection of all symbols occuring in the trie.
55     * </p>
56     */
57    private Collection<T> knownSymbols;
58
59    /**
60     * <p>
61     * Reference to the root of the trie.
62     * </p>
63     */
64    private final TrieNode<T> rootNode;
65
66    /**
67     * <p>
68     * Comparator to be used for comparing the symbols with each other
69     * </p>
70     */
71    private SymbolComparator<T> comparator;
72
73    /**
74     * <p>
75     * Contructor. Creates a new Trie with a {@link DefaultSymbolComparator}.
76     * </p>
77     */
78    public Trie() {
79        this(new DefaultSymbolComparator<T>());
80    }
81
82    /**
83     * <p>
84     * Contructor. Creates a new Trie with that uses a specific {@link SymbolComparator}.
85     * </p>
86     */
87    public Trie(SymbolComparator<T> comparator) {
88        this.comparator = comparator;
89        rootNode = new TrieNode<T>(comparator);
90        knownSymbols = new LinkedHashSet<T>();
91    }
92
93    /**
94     * <p>
95     * Copy-Constructor. Creates a new Trie as the copy of other. The other trie must not be null.
96     * </p>
97     *
98     * @param other
99     *            Trie that is copied
100     */
101    public Trie(Trie<T> other) {
102        if (other == null) {
103            throw new IllegalArgumentException("other trie must not be null");
104        }
105        rootNode = new TrieNode<T>(other.rootNode);
106        knownSymbols = new LinkedHashSet<T>(other.knownSymbols);
107        comparator = other.comparator;
108    }
109
110    /**
111     * <p>
112     * Returns a collection of all symbols occuring in the trie.
113     * </p>
114     *
115     * @return symbols occuring in the trie
116     */
117    public Collection<T> getKnownSymbols() {
118        return new LinkedHashSet<T>(knownSymbols);
119    }
120
121    /**
122     * <p>
123     * Trains the current trie using the given sequence and adds all subsequence of length
124     * {@code maxOrder}.
125     * </p>
126     *
127     * @param sequence
128     *            sequence whose subsequences are added to the trie
129     * @param maxOrder
130     *            maximum length of the subsequences added to the trie
131     */
132    public void train(List<T> sequence, int maxOrder) {
133        if (maxOrder < 1) {
134            return;
135        }
136        IncompleteMemory<T> latestActions = new IncompleteMemory<T>(maxOrder);
137        int i = 0;
138        for (T currentEvent : sequence) {
139            latestActions.add(currentEvent);
140            addToKnownSymbols(currentEvent);
141            i++;
142            if (i >= maxOrder) {
143                add(latestActions.getLast(maxOrder));
144            }
145        }
146        int sequenceLength = sequence.size();
147        for (int j = maxOrder - 1; j > 0; j--) {
148            add(sequence.subList(sequenceLength - j, sequenceLength));
149        }
150    }
151
152    /**
153     * <p>
154     * Adds a given subsequence to the trie and increases the counters accordingly.
155     * </p>
156     *
157     * @param subsequence
158     *            subsequence whose counters are increased
159     * @see TrieNode#add(List)
160     */
161    protected void add(List<T> subsequence) {
162        if (subsequence != null && !subsequence.isEmpty()) {
163            addToKnownSymbols(subsequence);
164            subsequence = new LinkedList<T>(subsequence); // defensive copy!
165            T firstSymbol = subsequence.get(0);
166            TrieNode<T> node = getChildCreate(firstSymbol);
167            node.add(subsequence);
168        }
169    }
170
171    /**
172     * <p>
173     * Returns the child of the root node associated with the given symbol or creates it if it does
174     * not exist yet.
175     * </p>
176     *
177     * @param symbol
178     *            symbol whose node is required
179     * @return node associated with the symbol
180     * @see TrieNode#getChildCreate(Object)
181     */
182    protected TrieNode<T> getChildCreate(T symbol) {
183        return rootNode.getChildCreate(symbol);
184    }
185
186    /**
187     * <p>
188     * Returns the child of the root node associated with the given symbol or null if it does not
189     * exist.
190     * </p>
191     *
192     * @param symbol
193     *            symbol whose node is required
194     * @return node associated with the symbol; null if no such node exists
195     * @see TrieNode#getChild(Object)
196     */
197    protected TrieNode<T> getChild(T symbol) {
198        return rootNode.getChild(symbol);
199    }
200
201    /**
202     * <p>
203     * Returns the number of occurences of the given sequence.
204     * </p>
205     *
206     * @param sequence
207     *            sequence whose number of occurences is required
208     * @return number of occurences of the sequence
209     */
210    public int getCount(List<T> sequence) {
211        int count = 0;
212        TrieNode<T> node = find(sequence);
213        if (node != null) {
214            count = node.getCount();
215        }
216        return count;
217    }
218
219    /**
220     * <p>
221     * Returns the number of occurences of the given prefix and a symbol that follows it.<br>
222     * Convenience function to simplify usage of {@link #getCount(List)}.
223     * </p>
224     *
225     * @param sequence
226     *            prefix of the sequence
227     * @param follower
228     *            suffix of the sequence
229     * @return number of occurences of the sequence
230     * @see #getCount(List)
231     */
232    public int getCount(List<T> sequence, T follower) {
233        List<T> tmpSequence = new LinkedList<T>(sequence);
234        tmpSequence.add(follower);
235        return getCount(tmpSequence);
236
237    }
238
239    /**
240     * <p>
241     * Searches the trie for a given sequence and returns the node associated with the sequence or
242     * null if no such node is found.
243     * </p>
244     *
245     * @param sequence
246     *            sequence that is searched for
247     * @return node associated with the sequence
248     * @see TrieNode#find(List)
249     */
250    public TrieNode<T> find(List<T> sequence) {
251        if (sequence == null || sequence.isEmpty()) {
252            return rootNode;
253        }
254        List<T> sequenceCopy = new LinkedList<T>(sequence);
255        TrieNode<T> result = null;
256        TrieNode<T> node = getChild(sequenceCopy.get(0));
257        if (node != null) {
258            sequenceCopy.remove(0);
259            result = node.find(sequenceCopy);
260        }
261        return result;
262    }
263
264    /**
265     * <p>
266     * Returns a collection of all symbols that follow a given sequence in the trie. In case the
267     * sequence is not found or no symbols follow the sequence the result will be empty.
268     * </p>
269     *
270     * @param sequence
271     *            sequence whose followers are returned
272     * @return symbols following the given sequence
273     * @see TrieNode#getFollowingSymbols()
274     */
275    public Collection<T> getFollowingSymbols(List<T> sequence) {
276        Collection<T> result = new LinkedList<T>();
277        TrieNode<T> node = find(sequence);
278        if (node != null) {
279            result = node.getFollowingSymbols();
280        }
281        return result;
282    }
283
284    /**
285     * <p>
286     * Returns the longest suffix of the given context that is contained in the tree and whose
287     * children are leaves.
288     * </p>
289     *
290     * @param context
291     *            context whose suffix is searched for
292     * @return longest suffix of the context
293     */
294    public List<T> getContextSuffix(List<T> context) {
295        List<T> contextSuffix;
296        if (context != null) {
297            contextSuffix = new LinkedList<T>(context); // defensive copy
298        }
299        else {
300            contextSuffix = new LinkedList<T>();
301        }
302        boolean suffixFound = false;
303
304        while (!suffixFound) {
305            if (contextSuffix.isEmpty()) {
306                suffixFound = true; // suffix is the empty word
307            }
308            else {
309                TrieNode<T> node = find(contextSuffix);
310                if (node != null) {
311                    if (!node.getFollowingSymbols().isEmpty()) {
312                        suffixFound = true;
313                    }
314                }
315                if (!suffixFound) {
316                    contextSuffix.remove(0);
317                }
318            }
319        }
320
321        return contextSuffix;
322    }
323
324    /**
325     * <p>
326     * returns a list of symbol sequences which have a minimal length and that occurred as often
327     * as defined by the given occurrence count. If the given occurrence count is smaller 1 then
328     * those sequences are returned, that occur most often. The resulting list is empty, if there
329     * is no symbol sequence with the minimal length or the provided number of occurrences.
330     * </p>
331     *
332     * @param minimalLength   the minimal length of the returned sequences
333     * @param occurrenceCount the number of occurrences of the returned sequences
334     *
335     * @return as described
336     */
337    public Collection<List<T>> getSequencesWithOccurrenceCount(int minimalLength,
338                                                               int occurrenceCount)
339    {
340        LinkedList<TrieNode<T>> context = new LinkedList<TrieNode<T>>();
341        Collection<List<TrieNode<T>>> paths = new LinkedList<List<TrieNode<T>>>();
342       
343        context.push(rootNode);
344       
345        // traverse the trie and determine all sequences, which have the provided number of
346        // occurrences and a minimal length.
347       
348        // minimalLength + 1 because we denote the depth including the root node
349        determineLongPathsWithMostOccurrences(minimalLength + 1, occurrenceCount, paths, context);
350       
351        Collection<List<T>> resultingPaths = new LinkedList<List<T>>();
352        List<T> resultingPath;
353       
354        if (paths.size() > 0) {
355           
356            for (List<TrieNode<T>> path : paths) {
357                resultingPath = new LinkedList<T>();
358               
359                for (TrieNode<T> node : path) {
360                    if (node.getSymbol() != null) {
361                        resultingPath.add(node.getSymbol());
362                    }
363                }
364               
365                resultingPaths.add(resultingPath);
366            }
367        }
368       
369        return resultingPaths;
370    }
371
372    /**
373     * <p>
374     * Traverses the trie to collect all sequences with a defined number of occurrences and with
375     * a minimal length. If the given occurrence count is smaller 1 then those sequences are
376     * searched that occur most often. The length of the sequences is encoded in the provided
377     * recursion depth.
378     * </p>
379     *
380     * @param minimalDepth    the minimal recursion depth to be done
381     * @param occurrenceCount the number of occurrences of the returned sequences
382     * @param paths           the paths through the trie that all occurred with the same amount
383     *                        (if occurrence count is smaller 1, the paths which occurred most
384     *                        often) and that have the so far found matching number of occurrences
385     *                        (is updated each time a further path with the same number of
386     *                        occurrences is found; if occurrence count is smaller 1
387     *                        it is replaced if a path with more occurrences is found)
388     * @param context         the path through the trie, that is analyzed by the recursive call
389     */
390    private void determineLongPathsWithMostOccurrences(int                           minimalDepth,
391                                                       int                           occurrenceCount,
392                                                       Collection<List<TrieNode<T>>> paths,
393                                                       LinkedList<TrieNode<T>>       context)
394    {
395        int envisagedCount = occurrenceCount;
396
397        // only if we already reached the depth to be achieved, we check if the paths have the
398        // required number of occurrences
399        if (context.size() >= minimalDepth) {
400           
401            if (envisagedCount < 1) {
402                // try to determine the maximum number of occurrences so far, if any
403                if (paths.size() > 0) {
404                    List<TrieNode<T>> path = paths.iterator().next();
405                    envisagedCount = path.get(path.size() - 1).getCount();
406                }
407
408                // if the current path has a higher number of occurrences than all so far, clear
409                // the paths collected so far and set the new number of occurrences as new maximum
410                if (context.getLast().getCount() > envisagedCount) {
411                    paths.clear();
412                    envisagedCount = context.getLast().getCount();
413                }
414            }
415           
416            // if the path matches the current maximal number of occurrences, add it to the list
417            // of collected paths with these number of occurrences
418            if (context.getLast().getCount() == envisagedCount) {
419                paths.add(new LinkedList<TrieNode<T>>(context));
420            }
421        }
422       
423        // perform the trie traversal
424        for (TrieNode<T> child : context.getLast().getChildren()) {
425            if (child.getCount() >= envisagedCount) {
426                context.add(child);
427                determineLongPathsWithMostOccurrences
428                    (minimalDepth, occurrenceCount, paths, context);
429                context.removeLast();
430            }
431        }
432    }
433   
434    /**
435     * <p>
436     * adds a new symbol to the collection of known symbols if this symbol is not already
437     * contained. The symbols are compared using the comparator.
438     * </p>
439     *
440     * @param symbol the symbol to be added to the known symbols
441     */
442    private void addToKnownSymbols(T symbol) {
443        for (T knownSymbol : knownSymbols) {
444            if (comparator.equals(knownSymbol, symbol)) {
445                return;
446            }
447        }
448       
449        knownSymbols.add(symbol);
450    }
451
452    /**
453     * <p>
454     * adds a list of new symbols to the collection of known symbols. Uses the
455     * {@link #addToKnownSymbols(Object)} method for each element of the provided list.
456     * </p>
457     *
458     * @param symbols the list of symbols to be added to the known symbols
459     */
460    private void addToKnownSymbols(List<T> symbols) {
461        for (T symbol : symbols) {
462            addToKnownSymbols(symbol);
463        }
464    }
465
466    /**
467     * <p>
468     * Helper class for graph visualization of a trie.
469     * </p>
470     *
471     * @author Steffen Herbold
472     * @version 1.0
473     */
474    static public class Edge {}
475
476    /**
477     * <p>
478     * Helper class for graph visualization of a trie.
479     * </p>
480     *
481     * @author Steffen Herbold
482     * @version 1.0
483     */
484    static public class TrieVertex {
485
486        /**
487         * <p>
488         * Id of the vertex.
489         * </p>
490         */
491        private String id;
492
493        /**
494         * <p>
495         * Contructor. Creates a new TrieVertex.
496         * </p>
497         *
498         * @param id
499         *            id of the vertex
500         */
501        protected TrieVertex(String id) {
502            this.id = id;
503        }
504
505        /**
506         * <p>
507         * Returns the id of the vertex.
508         * </p>
509         *
510         * @see java.lang.Object#toString()
511         */
512        @Override
513        public String toString() {
514            return id;
515        }
516    }
517
518    /**
519     * <p>
520     * Returns a {@link Graph} representation of the trie.
521     * </p>
522     *
523     * @return {@link Graph} representation of the trie
524     */
525    protected Tree<TrieVertex, Edge> getGraph() {
526        DelegateTree<TrieVertex, Edge> graph = new DelegateTree<TrieVertex, Edge>();
527        rootNode.getGraph(null, graph);
528        return graph;
529    }
530
531    /*
532     * (non-Javadoc)
533     *
534     * @see de.ugoe.cs.autoquest.usageprofiles.IDotCompatible#getDotRepresentation()
535     */
536    public String getDotRepresentation() {
537        StringBuilder stringBuilder = new StringBuilder();
538        stringBuilder.append("digraph model {" + StringTools.ENDLINE);
539        rootNode.appendDotRepresentation(stringBuilder);
540        stringBuilder.append('}' + StringTools.ENDLINE);
541        return stringBuilder.toString();
542    }
543
544    /**
545     * <p>
546     * Returns the string representation of the root node.
547     * </p>
548     *
549     * @see TrieNode#toString()
550     * @see java.lang.Object#toString()
551     */
552    @Override
553    public String toString() {
554        return rootNode.toString();
555    }
556
557    /**
558     * <p>
559     * Returns the number of symbols contained in the trie.
560     * </p>
561     *
562     * @return number of symbols contained in the trie
563     */
564    public int getNumSymbols() {
565        return knownSymbols.size();
566    }
567
568    /**
569     * <p>
570     * Returns the number of trie nodes that are ancestors of a leaf. This is the equivalent to the
571     * number of states a first-order markov model would have.
572     * <p>
573     *
574     * @return number of trie nodes that are ancestors of leafs.
575     */
576    public int getNumLeafAncestors() {
577        List<TrieNode<T>> ancestors = new LinkedList<TrieNode<T>>();
578        rootNode.getLeafAncestors(ancestors);
579        return ancestors.size();
580    }
581
582    /**
583     * <p>
584     * Returns the number of trie nodes that are leafs.
585     * </p>
586     *
587     * @return number of leafs in the trie
588     */
589    public int getNumLeafs() {
590        return rootNode.getNumLeafs();
591    }
592
593    /**
594     * <p>
595     * Updates the list of known symbols by replacing it with all symbols that are found in the
596     * child nodes of the root node. This should be the same as all symbols that are contained in
597     * the trie.
598     * </p>
599     */
600    public void updateKnownSymbols() {
601        knownSymbols = new HashSet<T>();
602        for (TrieNode<T> node : rootNode.getChildren()) {
603            addToKnownSymbols(node.getSymbol());
604        }
605    }
606
607    /**
608     * <p>
609     * Two Tries are defined as equal, if their {@link #rootNode} are equal.
610     * </p>
611     *
612     * @see java.lang.Object#equals(java.lang.Object)
613     */
614    @SuppressWarnings("rawtypes")
615    @Override
616    public boolean equals(Object other) {
617        if (other == this) {
618            return true;
619        }
620        if (other instanceof Trie) {
621            return rootNode.equals(((Trie) other).rootNode);
622        }
623        return false;
624    }
625
626    /*
627     * (non-Javadoc)
628     *
629     * @see java.lang.Object#hashCode()
630     */
631    @Override
632    public int hashCode() {
633        int multiplier = 17;
634        int hash = 42;
635        if (rootNode != null) {
636            hash = multiplier * hash + rootNode.hashCode();
637        }
638        return hash;
639    }
640
641}
Note: See TracBrowser for help on using the repository browser.