source: trunk/autoquest-core-usageprofiles/src/main/java/de/ugoe/cs/autoquest/usageprofiles/Trie.java @ 1118

Last change on this file since 1118 was 1118, checked in by pharms, 11 years ago
  • added possibility to perform effective trie processing
  • removed a bug that occurred if a trained sequence was shorter than the provided max order
  • Property svn:mime-type set to text/plain
File size: 21.2 KB
Line 
1//   Copyright 2012 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.autoquest.usageprofiles;
16
17import java.io.Serializable;
18import java.util.Collection;
19import java.util.HashSet;
20import java.util.LinkedHashSet;
21import java.util.LinkedList;
22import java.util.List;
23
24import de.ugoe.cs.util.StringTools;
25
26import edu.uci.ics.jung.graph.DelegateTree;
27import edu.uci.ics.jung.graph.Graph;
28import edu.uci.ics.jung.graph.Tree;
29
30/**
31 * <p>
32 * This class implements a <it>trie</it>, i.e., a tree of sequences that the occurence of
33 * subsequences up to a predefined length. This length is the trie order.
34 * </p>
35 *
36 * @author Steffen Herbold, Patrick Harms
37 *
38 * @param <T>
39 *            Type of the symbols that are stored in the trie.
40 *
41 * @see TrieNode
42 */
43public class Trie<T> implements IDotCompatible, Serializable {
44
45    /**
46     * <p>
47     * Id for object serialization.
48     * </p>
49     */
50    private static final long serialVersionUID = 1L;
51
52    /**
53     * <p>
54     * Collection of all symbols occuring in the trie.
55     * </p>
56     */
57    private Collection<T> knownSymbols;
58
59    /**
60     * <p>
61     * Reference to the root of the trie.
62     * </p>
63     */
64    private final TrieNode<T> rootNode;
65
66    /**
67     * <p>
68     * Comparator to be used for comparing the symbols with each other
69     * </p>
70     */
71    private SymbolComparator<T> comparator;
72
73    /**
74     * <p>
75     * Contructor. Creates a new Trie with a {@link DefaultSymbolComparator}.
76     * </p>
77     */
78    public Trie() {
79        this(new DefaultSymbolComparator<T>());
80    }
81
82    /**
83     * <p>
84     * Contructor. Creates a new Trie with that uses a specific {@link SymbolComparator}.
85     * </p>
86     */
87    public Trie(SymbolComparator<T> comparator) {
88        this.comparator = comparator;
89        rootNode = new TrieNode<T>(comparator);
90        knownSymbols = new LinkedHashSet<T>();
91    }
92
93    /**
94     * <p>
95     * Copy-Constructor. Creates a new Trie as the copy of other. The other trie must not be null.
96     * </p>
97     *
98     * @param other
99     *            Trie that is copied
100     */
101    public Trie(Trie<T> other) {
102        if (other == null) {
103            throw new IllegalArgumentException("other trie must not be null");
104        }
105        rootNode = new TrieNode<T>(other.rootNode);
106        knownSymbols = new LinkedHashSet<T>(other.knownSymbols);
107        comparator = other.comparator;
108    }
109
110    /**
111     * <p>
112     * Returns a collection of all symbols occuring in the trie.
113     * </p>
114     *
115     * @return symbols occuring in the trie
116     */
117    public Collection<T> getKnownSymbols() {
118        return new LinkedHashSet<T>(knownSymbols);
119    }
120
121    /**
122     * <p>
123     * Trains the current trie using the given sequence and adds all subsequences of length
124     * {@code maxOrder}.
125     * </p>
126     *
127     * @param sequence
128     *            sequence whose subsequences are added to the trie
129     * @param maxOrder
130     *            maximum length of the subsequences added to the trie
131     */
132    public void train(List<T> sequence, int maxOrder) {
133        if (maxOrder < 1) {
134            return;
135        }
136        IncompleteMemory<T> latestActions = new IncompleteMemory<T>(maxOrder);
137        int i = 0;
138        for (T currentEvent : sequence) {
139            latestActions.add(currentEvent);
140            addToKnownSymbols(currentEvent);
141            i++;
142            if (i >= maxOrder) {
143                add(latestActions.getLast(maxOrder));
144            }
145        }
146        int sequenceLength = sequence.size();
147        int startIndex = Math.max(0, sequenceLength - maxOrder + 1);
148        for (int j = startIndex; j < sequenceLength; j++) {
149            add(sequence.subList(j, sequenceLength));
150        }
151    }
152
153    /**
154     * <p>
155     * Adds a given subsequence to the trie and increases the counters accordingly.
156     * </p>
157     *
158     * @param subsequence
159     *            subsequence whose counters are increased
160     * @see TrieNode#add(List)
161     */
162    protected void add(List<T> subsequence) {
163        if (subsequence != null && !subsequence.isEmpty()) {
164            addToKnownSymbols(subsequence);
165            subsequence = new LinkedList<T>(subsequence); // defensive copy!
166            T firstSymbol = subsequence.get(0);
167            TrieNode<T> node = getChildCreate(firstSymbol);
168            node.add(subsequence);
169        }
170    }
171
172    /**
173     * <p>
174     * Returns the child of the root node associated with the given symbol or creates it if it does
175     * not exist yet.
176     * </p>
177     *
178     * @param symbol
179     *            symbol whose node is required
180     * @return node associated with the symbol
181     * @see TrieNode#getChildCreate(Object)
182     */
183    protected TrieNode<T> getChildCreate(T symbol) {
184        return rootNode.getChildCreate(symbol);
185    }
186
187    /**
188     * <p>
189     * Returns the child of the root node associated with the given symbol or null if it does not
190     * exist.
191     * </p>
192     *
193     * @param symbol
194     *            symbol whose node is required
195     * @return node associated with the symbol; null if no such node exists
196     * @see TrieNode#getChild(Object)
197     */
198    protected TrieNode<T> getChild(T symbol) {
199        return rootNode.getChild(symbol);
200    }
201
202    /**
203     * <p>
204     * Returns the number of occurences of the given sequence.
205     * </p>
206     *
207     * @param sequence
208     *            sequence whose number of occurences is required
209     * @return number of occurences of the sequence
210     */
211    public int getCount(List<T> sequence) {
212        int count = 0;
213        TrieNode<T> node = find(sequence);
214        if (node != null) {
215            count = node.getCount();
216        }
217        return count;
218    }
219
220    /**
221     * <p>
222     * Returns the number of occurences of the given prefix and a symbol that follows it.<br>
223     * Convenience function to simplify usage of {@link #getCount(List)}.
224     * </p>
225     *
226     * @param sequence
227     *            prefix of the sequence
228     * @param follower
229     *            suffix of the sequence
230     * @return number of occurences of the sequence
231     * @see #getCount(List)
232     */
233    public int getCount(List<T> sequence, T follower) {
234        List<T> tmpSequence = new LinkedList<T>(sequence);
235        tmpSequence.add(follower);
236        return getCount(tmpSequence);
237
238    }
239
240    /**
241     * <p>
242     * Searches the trie for a given sequence and returns the node associated with the sequence or
243     * null if no such node is found.
244     * </p>
245     *
246     * @param sequence
247     *            sequence that is searched for
248     * @return node associated with the sequence
249     * @see TrieNode#find(List)
250     */
251    public TrieNode<T> find(List<T> sequence) {
252        if (sequence == null || sequence.isEmpty()) {
253            return rootNode;
254        }
255        List<T> sequenceCopy = new LinkedList<T>(sequence);
256        TrieNode<T> result = null;
257        TrieNode<T> node = getChild(sequenceCopy.get(0));
258        if (node != null) {
259            sequenceCopy.remove(0);
260            result = node.find(sequenceCopy);
261        }
262        return result;
263    }
264
265    /**
266     * <p>
267     * Returns a collection of all symbols that follow a given sequence in the trie. In case the
268     * sequence is not found or no symbols follow the sequence the result will be empty.
269     * </p>
270     *
271     * @param sequence
272     *            sequence whose followers are returned
273     * @return symbols following the given sequence
274     * @see TrieNode#getFollowingSymbols()
275     */
276    public Collection<T> getFollowingSymbols(List<T> sequence) {
277        Collection<T> result = new LinkedList<T>();
278        TrieNode<T> node = find(sequence);
279        if (node != null) {
280            result = node.getFollowingSymbols();
281        }
282        return result;
283    }
284
285    /**
286     * <p>
287     * Returns the longest suffix of the given context that is contained in the tree and whose
288     * children are leaves.
289     * </p>
290     *
291     * @param context
292     *            context whose suffix is searched for
293     * @return longest suffix of the context
294     */
295    public List<T> getContextSuffix(List<T> context) {
296        List<T> contextSuffix;
297        if (context != null) {
298            contextSuffix = new LinkedList<T>(context); // defensive copy
299        }
300        else {
301            contextSuffix = new LinkedList<T>();
302        }
303        boolean suffixFound = false;
304
305        while (!suffixFound) {
306            if (contextSuffix.isEmpty()) {
307                suffixFound = true; // suffix is the empty word
308            }
309            else {
310                TrieNode<T> node = find(contextSuffix);
311                if (node != null) {
312                    if (!node.getFollowingSymbols().isEmpty()) {
313                        suffixFound = true;
314                    }
315                }
316                if (!suffixFound) {
317                    contextSuffix.remove(0);
318                }
319            }
320        }
321
322        return contextSuffix;
323    }
324   
325    /**
326     *
327     */
328    public void process(TrieProcessor<T> processor) {
329        LinkedList<T> context = new LinkedList<T>();
330       
331        for (TrieNode<T> child : rootNode.getChildren()) {
332            if (!process(context, child, processor)) {
333                break;
334            }
335        }
336    }
337
338    /**
339     * <p>
340     * TODO: comment
341     * </p>
342     * @param context
343     *
344     * @param child
345     * @param processor
346     */
347    private boolean process(LinkedList<T>    context,
348                            TrieNode<T>      node,
349                            TrieProcessor<T> processor)
350    {
351        context.add(node.getSymbol());
352       
353        TrieProcessor.Result result = processor.process(context, node.getCount());
354       
355        if (result == TrieProcessor.Result.CONTINUE) {
356            for (TrieNode<T> child : node.getChildren()) {
357                if (!process(context, child, processor)) {
358                    break;
359                }
360            }
361        }
362       
363        context.removeLast();
364       
365        return result != TrieProcessor.Result.BREAK;
366    }
367
368    /**
369     * <p>
370     * returns a list of symbol sequences which have a minimal length and that occurred as often
371     * as defined by the given occurrence count. If the given occurrence count is smaller 1 then
372     * those sequences are returned, that occur most often. The resulting list is empty, if there
373     * is no symbol sequence with the minimal length or the provided number of occurrences.
374     * </p>
375     *
376     * @param minimalLength   the minimal length of the returned sequences
377     * @param occurrenceCount the number of occurrences of the returned sequences
378     *
379     * @return as described
380     */
381    public Collection<List<T>> getSequencesWithOccurrenceCount(int minimalLength,
382                                                               int occurrenceCount)
383    {
384        LinkedList<TrieNode<T>> context = new LinkedList<TrieNode<T>>();
385        Collection<List<TrieNode<T>>> paths = new LinkedList<List<TrieNode<T>>>();
386       
387        context.push(rootNode);
388       
389        // traverse the trie and determine all sequences, which have the provided number of
390        // occurrences and a minimal length.
391       
392        // minimalLength + 1 because we denote the depth including the root node
393        determineLongPathsWithMostOccurrences(minimalLength + 1, occurrenceCount, paths, context);
394       
395        Collection<List<T>> resultingPaths = new LinkedList<List<T>>();
396        List<T> resultingPath;
397       
398        if (paths.size() > 0) {
399           
400            for (List<TrieNode<T>> path : paths) {
401                resultingPath = new LinkedList<T>();
402               
403                for (TrieNode<T> node : path) {
404                    if (node.getSymbol() != null) {
405                        resultingPath.add(node.getSymbol());
406                    }
407                }
408               
409                resultingPaths.add(resultingPath);
410            }
411        }
412       
413        return resultingPaths;
414    }
415
416    /**
417     * <p>
418     * Traverses the trie to collect all sequences with a defined number of occurrences and with
419     * a minimal length. If the given occurrence count is smaller 1 then those sequences are
420     * searched that occur most often. The length of the sequences is encoded in the provided
421     * recursion depth.
422     * </p>
423     *
424     * @param minimalDepth    the minimal recursion depth to be done
425     * @param occurrenceCount the number of occurrences of the returned sequences
426     * @param paths           the paths through the trie that all occurred with the same amount
427     *                        (if occurrence count is smaller 1, the paths which occurred most
428     *                        often) and that have the so far found matching number of occurrences
429     *                        (is updated each time a further path with the same number of
430     *                        occurrences is found; if occurrence count is smaller 1
431     *                        it is replaced if a path with more occurrences is found)
432     * @param context         the path through the trie, that is analyzed by the recursive call
433     */
434    private void determineLongPathsWithMostOccurrences(int                           minimalDepth,
435                                                       int                           occurrenceCount,
436                                                       Collection<List<TrieNode<T>>> paths,
437                                                       LinkedList<TrieNode<T>>       context)
438    {
439        int envisagedCount = occurrenceCount;
440
441        // only if we already reached the depth to be achieved, we check if the paths have the
442        // required number of occurrences
443        if (context.size() >= minimalDepth) {
444           
445            if (envisagedCount < 1) {
446                // try to determine the maximum number of occurrences so far, if any
447                if (paths.size() > 0) {
448                    List<TrieNode<T>> path = paths.iterator().next();
449                    envisagedCount = path.get(path.size() - 1).getCount();
450                }
451
452                // if the current path has a higher number of occurrences than all so far, clear
453                // the paths collected so far and set the new number of occurrences as new maximum
454                if (context.getLast().getCount() > envisagedCount) {
455                    paths.clear();
456                    envisagedCount = context.getLast().getCount();
457                }
458            }
459           
460            // if the path matches the current maximal number of occurrences, add it to the list
461            // of collected paths with these number of occurrences
462            if (context.getLast().getCount() == envisagedCount) {
463                paths.add(new LinkedList<TrieNode<T>>(context));
464            }
465        }
466       
467        // perform the trie traversal
468        for (TrieNode<T> child : context.getLast().getChildren()) {
469            if (child.getCount() >= envisagedCount) {
470                context.add(child);
471                determineLongPathsWithMostOccurrences
472                    (minimalDepth, occurrenceCount, paths, context);
473                context.removeLast();
474            }
475        }
476    }
477   
478    /**
479     * <p>
480     * adds a new symbol to the collection of known symbols if this symbol is not already
481     * contained. The symbols are compared using the comparator.
482     * </p>
483     *
484     * @param symbol the symbol to be added to the known symbols
485     */
486    private void addToKnownSymbols(T symbol) {
487        for (T knownSymbol : knownSymbols) {
488            if (comparator.equals(knownSymbol, symbol)) {
489                return;
490            }
491        }
492       
493        knownSymbols.add(symbol);
494    }
495
496    /**
497     * <p>
498     * adds a list of new symbols to the collection of known symbols. Uses the
499     * {@link #addToKnownSymbols(Object)} method for each element of the provided list.
500     * </p>
501     *
502     * @param symbols the list of symbols to be added to the known symbols
503     */
504    private void addToKnownSymbols(List<T> symbols) {
505        for (T symbol : symbols) {
506            addToKnownSymbols(symbol);
507        }
508    }
509
510    /**
511     * <p>
512     * Helper class for graph visualization of a trie.
513     * </p>
514     *
515     * @author Steffen Herbold
516     * @version 1.0
517     */
518    static public class Edge {}
519
520    /**
521     * <p>
522     * Helper class for graph visualization of a trie.
523     * </p>
524     *
525     * @author Steffen Herbold
526     * @version 1.0
527     */
528    static public class TrieVertex {
529
530        /**
531         * <p>
532         * Id of the vertex.
533         * </p>
534         */
535        private String id;
536
537        /**
538         * <p>
539         * Contructor. Creates a new TrieVertex.
540         * </p>
541         *
542         * @param id
543         *            id of the vertex
544         */
545        protected TrieVertex(String id) {
546            this.id = id;
547        }
548
549        /**
550         * <p>
551         * Returns the id of the vertex.
552         * </p>
553         *
554         * @see java.lang.Object#toString()
555         */
556        @Override
557        public String toString() {
558            return id;
559        }
560    }
561
562    /**
563     * <p>
564     * Returns a {@link Graph} representation of the trie.
565     * </p>
566     *
567     * @return {@link Graph} representation of the trie
568     */
569    protected Tree<TrieVertex, Edge> getGraph() {
570        DelegateTree<TrieVertex, Edge> graph = new DelegateTree<TrieVertex, Edge>();
571        rootNode.getGraph(null, graph);
572        return graph;
573    }
574
575    /*
576     * (non-Javadoc)
577     *
578     * @see de.ugoe.cs.autoquest.usageprofiles.IDotCompatible#getDotRepresentation()
579     */
580    public String getDotRepresentation() {
581        StringBuilder stringBuilder = new StringBuilder();
582        stringBuilder.append("digraph model {" + StringTools.ENDLINE);
583        rootNode.appendDotRepresentation(stringBuilder);
584        stringBuilder.append('}' + StringTools.ENDLINE);
585        return stringBuilder.toString();
586    }
587
588    /**
589     * <p>
590     * Returns the string representation of the root node.
591     * </p>
592     *
593     * @see TrieNode#toString()
594     * @see java.lang.Object#toString()
595     */
596    @Override
597    public String toString() {
598        return rootNode.toString();
599    }
600
601    /**
602     * <p>
603     * Returns the number of symbols contained in the trie.
604     * </p>
605     *
606     * @return number of symbols contained in the trie
607     */
608    public int getNumSymbols() {
609        return knownSymbols.size();
610    }
611
612    /**
613     * <p>
614     * Returns the number of trie nodes that are ancestors of a leaf. This is the equivalent to the
615     * number of states a first-order markov model would have.
616     * <p>
617     *
618     * @return number of trie nodes that are ancestors of leafs.
619     */
620    public int getNumLeafAncestors() {
621        List<TrieNode<T>> ancestors = new LinkedList<TrieNode<T>>();
622        rootNode.getLeafAncestors(ancestors);
623        return ancestors.size();
624    }
625
626    /**
627     * <p>
628     * Returns the number of trie nodes that are leafs.
629     * </p>
630     *
631     * @return number of leafs in the trie
632     */
633    public int getNumLeafs() {
634        return rootNode.getNumLeafs();
635    }
636
637    /**
638     * <p>
639     * Updates the list of known symbols by replacing it with all symbols that are found in the
640     * child nodes of the root node. This should be the same as all symbols that are contained in
641     * the trie.
642     * </p>
643     */
644    public void updateKnownSymbols() {
645        knownSymbols = new HashSet<T>();
646        for (TrieNode<T> node : rootNode.getChildren()) {
647            addToKnownSymbols(node.getSymbol());
648        }
649    }
650
651    /**
652     * <p>
653     * Two Tries are defined as equal, if their {@link #rootNode} are equal.
654     * </p>
655     *
656     * @see java.lang.Object#equals(java.lang.Object)
657     */
658    @SuppressWarnings("rawtypes")
659    @Override
660    public boolean equals(Object other) {
661        if (other == this) {
662            return true;
663        }
664        if (other instanceof Trie) {
665            return rootNode.equals(((Trie) other).rootNode);
666        }
667        return false;
668    }
669
670    /*
671     * (non-Javadoc)
672     *
673     * @see java.lang.Object#hashCode()
674     */
675    @Override
676    public int hashCode() {
677        int multiplier = 17;
678        int hash = 42;
679        if (rootNode != null) {
680            hash = multiplier * hash + rootNode.hashCode();
681        }
682        return hash;
683    }
684
685}
Note: See TracBrowser for help on using the repository browser.