source: trunk/autoquest-core-usageprofiles/src/main/java/de/ugoe/cs/autoquest/usageprofiles/Trie.java @ 1060

Last change on this file since 1060 was 1060, checked in by pharms, 11 years ago
  • extended Trie implementation to be able to use different compare strategies
  • implemented a method to determine subsequences of a minimal length that occur most often
  • Property svn:mime-type set to text/plain
File size: 19.0 KB
Line 
1//   Copyright 2012 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.autoquest.usageprofiles;
16
17import java.io.Serializable;
18import java.util.Collection;
19import java.util.HashSet;
20import java.util.LinkedHashSet;
21import java.util.LinkedList;
22import java.util.List;
23
24import de.ugoe.cs.util.StringTools;
25
26import edu.uci.ics.jung.graph.DelegateTree;
27import edu.uci.ics.jung.graph.Graph;
28import edu.uci.ics.jung.graph.Tree;
29
30/**
31 * <p>
32 * This class implements a <it>trie</it>, i.e., a tree of sequences that the occurence of
33 * subsequences up to a predefined length. This length is the trie order.
34 * </p>
35 *
36 * @author Steffen Herbold, Patrick Harms
37 *
38 * @param <T>
39 *            Type of the symbols that are stored in the trie.
40 *
41 * @see TrieNode
42 */
43public class Trie<T> implements IDotCompatible, Serializable {
44
45    /**
46     * <p>
47     * Id for object serialization.
48     * </p>
49     */
50    private static final long serialVersionUID = 1L;
51
52    /**
53     * <p>
54     * Collection of all symbols occuring in the trie.
55     * </p>
56     */
57    private Collection<T> knownSymbols;
58
59    /**
60     * <p>
61     * Reference to the root of the trie.
62     * </p>
63     */
64    private final TrieNode<T> rootNode;
65
66    /**
67     * <p>
68     * Comparator to be used for comparing the symbols with each other
69     * </p>
70     */
71    private SymbolComparator<T> comparator;
72
73    /**
74     * <p>
75     * Contructor. Creates a new Trie with a {@link DefaultSymbolComparator}.
76     * </p>
77     */
78    public Trie() {
79        this(new DefaultSymbolComparator<T>());
80    }
81
82    /**
83     * <p>
84     * Contructor. Creates a new Trie with that uses a specific {@link SymbolComparator}.
85     * </p>
86     */
87    public Trie(SymbolComparator<T> comparator) {
88        this.comparator = comparator;
89        rootNode = new TrieNode<T>(comparator);
90        knownSymbols = new LinkedHashSet<T>();
91    }
92
93    /**
94     * <p>
95     * Copy-Constructor. Creates a new Trie as the copy of other. The other trie must not be null.
96     * </p>
97     *
98     * @param other
99     *            Trie that is copied
100     */
101    public Trie(Trie<T> other) {
102        if (other == null) {
103            throw new IllegalArgumentException("other trie must not be null");
104        }
105        rootNode = new TrieNode<T>(other.rootNode);
106        knownSymbols = new LinkedHashSet<T>(other.knownSymbols);
107        comparator = other.comparator;
108    }
109
110    /**
111     * <p>
112     * Returns a collection of all symbols occuring in the trie.
113     * </p>
114     *
115     * @return symbols occuring in the trie
116     */
117    public Collection<T> getKnownSymbols() {
118        return new LinkedHashSet<T>(knownSymbols);
119    }
120
121    /**
122     * <p>
123     * Trains the current trie using the given sequence and adds all subsequence of length
124     * {@code maxOrder}.
125     * </p>
126     *
127     * @param sequence
128     *            sequence whose subsequences are added to the trie
129     * @param maxOrder
130     *            maximum length of the subsequences added to the trie
131     */
132    public void train(List<T> sequence, int maxOrder) {
133        if (maxOrder < 1) {
134            return;
135        }
136        IncompleteMemory<T> latestActions = new IncompleteMemory<T>(maxOrder);
137        int i = 0;
138        for (T currentEvent : sequence) {
139            latestActions.add(currentEvent);
140            addToKnownSymbols(currentEvent);
141            i++;
142            if (i >= maxOrder) {
143                add(latestActions.getLast(maxOrder));
144            }
145        }
146        int sequenceLength = sequence.size();
147        for (int j = maxOrder - 1; j > 0; j--) {
148            add(sequence.subList(sequenceLength - j, sequenceLength));
149        }
150    }
151
152    /**
153     * <p>
154     * Adds a given subsequence to the trie and increases the counters accordingly.
155     * </p>
156     *
157     * @param subsequence
158     *            subsequence whose counters are increased
159     * @see TrieNode#add(List)
160     */
161    protected void add(List<T> subsequence) {
162        if (subsequence != null && !subsequence.isEmpty()) {
163            addToKnownSymbols(subsequence);
164            subsequence = new LinkedList<T>(subsequence); // defensive copy!
165            T firstSymbol = subsequence.get(0);
166            TrieNode<T> node = getChildCreate(firstSymbol);
167            node.add(subsequence);
168        }
169    }
170
171    /**
172     * <p>
173     * Returns the child of the root node associated with the given symbol or creates it if it does
174     * not exist yet.
175     * </p>
176     *
177     * @param symbol
178     *            symbol whose node is required
179     * @return node associated with the symbol
180     * @see TrieNode#getChildCreate(Object)
181     */
182    protected TrieNode<T> getChildCreate(T symbol) {
183        return rootNode.getChildCreate(symbol);
184    }
185
186    /**
187     * <p>
188     * Returns the child of the root node associated with the given symbol or null if it does not
189     * exist.
190     * </p>
191     *
192     * @param symbol
193     *            symbol whose node is required
194     * @return node associated with the symbol; null if no such node exists
195     * @see TrieNode#getChild(Object)
196     */
197    protected TrieNode<T> getChild(T symbol) {
198        return rootNode.getChild(symbol);
199    }
200
201    /**
202     * <p>
203     * Returns the number of occurences of the given sequence.
204     * </p>
205     *
206     * @param sequence
207     *            sequence whose number of occurences is required
208     * @return number of occurences of the sequence
209     */
210    public int getCount(List<T> sequence) {
211        int count = 0;
212        TrieNode<T> node = find(sequence);
213        if (node != null) {
214            count = node.getCount();
215        }
216        return count;
217    }
218
219    /**
220     * <p>
221     * Returns the number of occurences of the given prefix and a symbol that follows it.<br>
222     * Convenience function to simplify usage of {@link #getCount(List)}.
223     * </p>
224     *
225     * @param sequence
226     *            prefix of the sequence
227     * @param follower
228     *            suffix of the sequence
229     * @return number of occurences of the sequence
230     * @see #getCount(List)
231     */
232    public int getCount(List<T> sequence, T follower) {
233        List<T> tmpSequence = new LinkedList<T>(sequence);
234        tmpSequence.add(follower);
235        return getCount(tmpSequence);
236
237    }
238
239    /**
240     * <p>
241     * Searches the trie for a given sequence and returns the node associated with the sequence or
242     * null if no such node is found.
243     * </p>
244     *
245     * @param sequence
246     *            sequence that is searched for
247     * @return node associated with the sequence
248     * @see TrieNode#find(List)
249     */
250    public TrieNode<T> find(List<T> sequence) {
251        if (sequence == null || sequence.isEmpty()) {
252            return rootNode;
253        }
254        List<T> sequenceCopy = new LinkedList<T>(sequence);
255        TrieNode<T> result = null;
256        TrieNode<T> node = getChild(sequenceCopy.get(0));
257        if (node != null) {
258            sequenceCopy.remove(0);
259            result = node.find(sequenceCopy);
260        }
261        return result;
262    }
263
264    /**
265     * <p>
266     * Returns a collection of all symbols that follow a given sequence in the trie. In case the
267     * sequence is not found or no symbols follow the sequence the result will be empty.
268     * </p>
269     *
270     * @param sequence
271     *            sequence whose followers are returned
272     * @return symbols following the given sequence
273     * @see TrieNode#getFollowingSymbols()
274     */
275    public Collection<T> getFollowingSymbols(List<T> sequence) {
276        Collection<T> result = new LinkedList<T>();
277        TrieNode<T> node = find(sequence);
278        if (node != null) {
279            result = node.getFollowingSymbols();
280        }
281        return result;
282    }
283
284    /**
285     * <p>
286     * Returns the longest suffix of the given context that is contained in the tree and whose
287     * children are leaves.
288     * </p>
289     *
290     * @param context
291     *            context whose suffix is searched for
292     * @return longest suffix of the context
293     */
294    public List<T> getContextSuffix(List<T> context) {
295        List<T> contextSuffix;
296        if (context != null) {
297            contextSuffix = new LinkedList<T>(context); // defensive copy
298        }
299        else {
300            contextSuffix = new LinkedList<T>();
301        }
302        boolean suffixFound = false;
303
304        while (!suffixFound) {
305            if (contextSuffix.isEmpty()) {
306                suffixFound = true; // suffix is the empty word
307            }
308            else {
309                TrieNode<T> node = find(contextSuffix);
310                if (node != null) {
311                    if (!node.getFollowingSymbols().isEmpty()) {
312                        suffixFound = true;
313                    }
314                }
315                if (!suffixFound) {
316                    contextSuffix.remove(0);
317                }
318            }
319        }
320
321        return contextSuffix;
322    }
323
324    /**
325     * <p>
326     * returns a list of symbol sequences which have a minimal length and that occurred most often
327     * with the same number of occurrences. The resulting list is empty, if there is no symbol
328     * sequence with the minimal length.
329     * </p>
330     *
331     * @param minimalLength the minimal length of the returned sequences
332     *
333     * @return as described
334     */
335    public Collection<List<T>> getSequencesWithMostOccurrences(int minimalLength) {
336        LinkedList<TrieNode<T>> context = new LinkedList<TrieNode<T>>();
337        Collection<List<TrieNode<T>>> paths = new LinkedList<List<TrieNode<T>>>();
338       
339        context.push(rootNode);
340       
341        // traverse the trie and determine all sequences, which have the maximum number of
342        // occurrences and a minimal length.
343       
344        // minimalLength + 1 because we denote the depth including the root node
345        determineLongPathsWithMostOccurrences(minimalLength + 1, paths, context);
346       
347        Collection<List<T>> resultingPaths = new LinkedList<List<T>>();
348        List<T> resultingPath;
349       
350        if (paths.size() > 0) {
351           
352            for (List<TrieNode<T>> path : paths) {
353                resultingPath = new LinkedList<T>();
354               
355                for (TrieNode<T> node : path) {
356                    if (node.getSymbol() != null) {
357                        resultingPath.add(node.getSymbol());
358                    }
359                }
360               
361                resultingPaths.add(resultingPath);
362            }
363        }
364       
365        return resultingPaths;
366    }
367
368    /**
369     * <p>
370     * Traverses the trie to collect all sequences with a maximum number of occurrences and with
371     * a minimal length. The length is encoded in the provided recursion depth.
372     * </p>
373     *
374     * @param minimalDepth the minimal recursion depth to be done
375     * @param paths        the paths through the trie that all occurred with the same amount and
376     *                     that have the so far found maximum of occurrences (is updated each
377     *                     time a further path with the same number of occurrences is found; is
378     *                     replaced if a path with more occurrences is found)
379     * @param context      the path through the trie, that is analyzed by the recursive call
380     */
381    private void determineLongPathsWithMostOccurrences(int                           minimalDepth,
382                                                       Collection<List<TrieNode<T>>> paths,
383                                                       LinkedList<TrieNode<T>>       context)
384    {
385        int maxCount = 0;
386
387        // only if we already reached the depth to be achieved, we check if the paths have the
388        // maximum number of occurrences
389        if (context.size() >= minimalDepth) {
390           
391            // try to determine the maximum number of occurrences so far, if any
392            if (paths.size() > 0) {
393                List<TrieNode<T>> path = paths.iterator().next();
394                maxCount = path.get(path.size() - 1).getCount();
395            }
396
397            // if the current path has a higher number of occurrences than all so far, clear
398            // the paths collected so far and set the new number of occurrences as new maximum
399            if (context.getLast().getCount() > maxCount) {
400                paths.clear();
401                maxCount = context.getLast().getCount();
402            }
403           
404            // if the path matches the current maximal number of occurrences, add it to the list
405            // of collected paths with these number of occurrences
406            if (context.getLast().getCount() == maxCount) {
407                paths.add(new LinkedList<TrieNode<T>>(context));
408            }
409        }
410       
411        // perform the trie traversal
412        for (TrieNode<T> child : context.getLast().getChildren()) {
413            if (child.getCount() >= maxCount) {
414                context.add(child);
415                determineLongPathsWithMostOccurrences(minimalDepth, paths, context);
416                context.removeLast();
417            }
418        }
419    }
420   
421    /**
422     * <p>
423     * adds a new symbol to the collection of known symbols if this symbol is not already
424     * contained. The symbols are compared using the comparator.
425     * </p>
426     *
427     * @param symbol the symbol to be added to the known symbols
428     */
429    private void addToKnownSymbols(T symbol) {
430        for (T knownSymbol : knownSymbols) {
431            if (comparator.equals(knownSymbol, symbol)) {
432                return;
433            }
434        }
435       
436        knownSymbols.add(symbol);
437    }
438
439    /**
440     * <p>
441     * adds a list of new symbols to the collection of known symbols. Uses the
442     * {@link #addToKnownSymbols(Object)} method for each element of the provided list.
443     * </p>
444     *
445     * @param symbols the list of symbols to be added to the known symbols
446     */
447    private void addToKnownSymbols(List<T> symbols) {
448        for (T symbol : symbols) {
449            addToKnownSymbols(symbol);
450        }
451    }
452
453    /**
454     * <p>
455     * Helper class for graph visualization of a trie.
456     * </p>
457     *
458     * @author Steffen Herbold
459     * @version 1.0
460     */
461    static public class Edge {}
462
463    /**
464     * <p>
465     * Helper class for graph visualization of a trie.
466     * </p>
467     *
468     * @author Steffen Herbold
469     * @version 1.0
470     */
471    static public class TrieVertex {
472
473        /**
474         * <p>
475         * Id of the vertex.
476         * </p>
477         */
478        private String id;
479
480        /**
481         * <p>
482         * Contructor. Creates a new TrieVertex.
483         * </p>
484         *
485         * @param id
486         *            id of the vertex
487         */
488        protected TrieVertex(String id) {
489            this.id = id;
490        }
491
492        /**
493         * <p>
494         * Returns the id of the vertex.
495         * </p>
496         *
497         * @see java.lang.Object#toString()
498         */
499        @Override
500        public String toString() {
501            return id;
502        }
503    }
504
505    /**
506     * <p>
507     * Returns a {@link Graph} representation of the trie.
508     * </p>
509     *
510     * @return {@link Graph} representation of the trie
511     */
512    protected Tree<TrieVertex, Edge> getGraph() {
513        DelegateTree<TrieVertex, Edge> graph = new DelegateTree<TrieVertex, Edge>();
514        rootNode.getGraph(null, graph);
515        return graph;
516    }
517
518    /*
519     * (non-Javadoc)
520     *
521     * @see de.ugoe.cs.autoquest.usageprofiles.IDotCompatible#getDotRepresentation()
522     */
523    public String getDotRepresentation() {
524        StringBuilder stringBuilder = new StringBuilder();
525        stringBuilder.append("digraph model {" + StringTools.ENDLINE);
526        rootNode.appendDotRepresentation(stringBuilder);
527        stringBuilder.append('}' + StringTools.ENDLINE);
528        return stringBuilder.toString();
529    }
530
531    /**
532     * <p>
533     * Returns the string representation of the root node.
534     * </p>
535     *
536     * @see TrieNode#toString()
537     * @see java.lang.Object#toString()
538     */
539    @Override
540    public String toString() {
541        return rootNode.toString();
542    }
543
544    /**
545     * <p>
546     * Returns the number of symbols contained in the trie.
547     * </p>
548     *
549     * @return number of symbols contained in the trie
550     */
551    public int getNumSymbols() {
552        return knownSymbols.size();
553    }
554
555    /**
556     * <p>
557     * Returns the number of trie nodes that are ancestors of a leaf. This is the equivalent to the
558     * number of states a first-order markov model would have.
559     * <p>
560     *
561     * @return number of trie nodes that are ancestors of leafs.
562     */
563    public int getNumLeafAncestors() {
564        List<TrieNode<T>> ancestors = new LinkedList<TrieNode<T>>();
565        rootNode.getLeafAncestors(ancestors);
566        return ancestors.size();
567    }
568
569    /**
570     * <p>
571     * Returns the number of trie nodes that are leafs.
572     * </p>
573     *
574     * @return number of leafs in the trie
575     */
576    public int getNumLeafs() {
577        return rootNode.getNumLeafs();
578    }
579
580    /**
581     * <p>
582     * Updates the list of known symbols by replacing it with all symbols that are found in the
583     * child nodes of the root node. This should be the same as all symbols that are contained in
584     * the trie.
585     * </p>
586     */
587    public void updateKnownSymbols() {
588        knownSymbols = new HashSet<T>();
589        for (TrieNode<T> node : rootNode.getChildren()) {
590            addToKnownSymbols(node.getSymbol());
591        }
592    }
593
594    /**
595     * <p>
596     * Two Tries are defined as equal, if their {@link #rootNode} are equal.
597     * </p>
598     *
599     * @see java.lang.Object#equals(java.lang.Object)
600     */
601    @SuppressWarnings("rawtypes")
602    @Override
603    public boolean equals(Object other) {
604        if (other == this) {
605            return true;
606        }
607        if (other instanceof Trie) {
608            return rootNode.equals(((Trie) other).rootNode);
609        }
610        return false;
611    }
612
613    /*
614     * (non-Javadoc)
615     *
616     * @see java.lang.Object#hashCode()
617     */
618    @Override
619    public int hashCode() {
620        int multiplier = 17;
621        int hash = 42;
622        if (rootNode != null) {
623            hash = multiplier * hash + rootNode.hashCode();
624        }
625        return hash;
626    }
627
628}
Note: See TracBrowser for help on using the repository browser.