source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/models/Trie.java @ 351

Last change on this file since 351 was 316, checked in by sherbold, 13 years ago
  • fixed bug in de.ugoe.cs.eventbench.models.Trie.getContextSuffix() which now handles null values correctly. Test case de.ugoe.cs.eventbench.models.TrieTest?.testGetContextSuffix_5() should now run without failures
File size: 9.9 KB
Line 
1package de.ugoe.cs.eventbench.models;
2
3import java.io.Serializable;
4import java.util.Collection;
5import java.util.HashSet;
6import java.util.LinkedHashSet;
7import java.util.LinkedList;
8import java.util.List;
9import java.util.Set;
10
11import de.ugoe.cs.util.StringTools;
12
13import edu.uci.ics.jung.graph.DelegateTree;
14import edu.uci.ics.jung.graph.Graph;
15import edu.uci.ics.jung.graph.Tree;
16
17/**
18 * <p>
19 * This class implements a <it>trie</it>, i.e., a tree of sequences that the
20 * occurence of subsequences up to a predefined length. This length is the trie
21 * order.
22 * </p>
23 *
24 * @author Steffen Herbold
25 *
26 * @param <T>
27 *            Type of the symbols that are stored in the trie.
28 *
29 * @see TrieNode
30 */
31public class Trie<T> implements IDotCompatible, Serializable {
32
33        /**
34         * <p>
35         * Id for object serialization.
36         * </p>
37         */
38        private static final long serialVersionUID = 1L;
39
40        /**
41         * <p>
42         * Collection of all symbols occuring in the trie.
43         * </p>
44         */
45        private Collection<T> knownSymbols;
46
47        /**
48         * <p>
49         * Reference to the root of the trie.
50         * </p>
51         */
52        private final TrieNode<T> rootNode;
53
54        /**
55         * <p>
56         * Contructor. Creates a new Trie.
57         * </p>
58         */
59        public Trie() {
60                rootNode = new TrieNode<T>();
61                knownSymbols = new LinkedHashSet<T>();
62        }
63
64        /**
65         * <p>
66         * Returns a collection of all symbols occuring in the trie.
67         * </p>
68         *
69         * @return symbols occuring in the trie
70         */
71        public Collection<T> getKnownSymbols() {
72                return new LinkedHashSet<T>(knownSymbols);
73        }
74
75        /**
76         * <p>
77         * Trains the current trie using the given sequence and adds all subsequence
78         * of length {@code maxOrder}.
79         * </p>
80         *
81         * @param sequence
82         *            sequence whose subsequences are added to the trie
83         * @param maxOrder
84         *            maximum length of the subsequences added to the trie
85         */
86        public void train(List<T> sequence, int maxOrder) {
87                if (maxOrder < 1) {
88                        return;
89                }
90                IncompleteMemory<T> latestActions = new IncompleteMemory<T>(maxOrder);
91                int i = 0;
92                for (T currentEvent : sequence) {
93                        latestActions.add(currentEvent);
94                        knownSymbols.add(currentEvent);
95                        i++;
96                        if (i >= maxOrder) {
97                                add(latestActions.getLast(maxOrder));
98                        }
99                }
100                int sequenceLength = sequence.size();
101                for (int j = maxOrder - 1; j > 0; j--) {
102                        add(sequence.subList(sequenceLength - j, sequenceLength));
103                }
104        }
105
106        /**
107         * <p>
108         * Adds a given subsequence to the trie and increases the counters
109         * accordingly.
110         * </p>
111         *
112         * @param subsequence
113         *            subsequence whose counters are increased
114         * @see TrieNode#add(List)
115         */
116        protected void add(List<T> subsequence) {
117                if (subsequence != null && !subsequence.isEmpty()) {
118                        knownSymbols.addAll(subsequence);
119                        subsequence = new LinkedList<T>(subsequence); // defensive copy!
120                        T firstSymbol = subsequence.get(0);
121                        TrieNode<T> node = getChildCreate(firstSymbol);
122                        node.add(subsequence);
123                }
124        }
125
126        /**
127         * <p>
128         * Returns the child of the root node associated with the given symbol or
129         * creates it if it does not exist yet.
130         * </p>
131         *
132         * @param symbol
133         *            symbol whose node is required
134         * @return node associated with the symbol
135         * @see TrieNode#getChildCreate(Object)
136         */
137        protected TrieNode<T> getChildCreate(T symbol) {
138                return rootNode.getChildCreate(symbol);
139        }
140
141        /**
142         * <p>
143         * Returns the child of the root node associated with the given symbol or
144         * null if it does not exist.
145         * </p>
146         *
147         * @param symbol
148         *            symbol whose node is required
149         * @return node associated with the symbol; null if no such node exists
150         * @see TrieNode#getChild(Object)
151         */
152        protected TrieNode<T> getChild(T symbol) {
153                return rootNode.getChild(symbol);
154        }
155
156        /**
157         * <p>
158         * Returns the number of occurences of the given sequence.
159         * </p>
160         *
161         * @param sequence
162         *            sequence whose number of occurences is required
163         * @return number of occurences of the sequence
164         */
165        public int getCount(List<T> sequence) {
166                int count = 0;
167                TrieNode<T> node = find(sequence);
168                if (node != null) {
169                        count = node.getCount();
170                }
171                return count;
172        }
173
174        /**
175         * <p>
176         * Returns the number of occurences of the given prefix and a symbol that
177         * follows it.<br>
178         * Convenience function to simplify usage of {@link #getCount(List)}.
179         * </p>
180         *
181         * @param sequence
182         *            prefix of the sequence
183         * @param follower
184         *            suffix of the sequence
185         * @return number of occurences of the sequence
186         * @see #getCount(List)
187         */
188        public int getCount(List<T> sequence, T follower) {
189                List<T> tmpSequence = new LinkedList<T>(sequence);
190                tmpSequence.add(follower);
191                return getCount(tmpSequence);
192
193        }
194
195        /**
196         * <p>
197         * Searches the trie for a given sequence and returns the node associated
198         * with the sequence or null if no such node is found.
199         * </p>
200         *
201         * @param sequence
202         *            sequence that is searched for
203         * @return node associated with the sequence
204         * @see TrieNode#find(List)
205         */
206        public TrieNode<T> find(List<T> sequence) {
207                if (sequence == null || sequence.isEmpty()) {
208                        return rootNode;
209                }
210                List<T> sequenceCopy = new LinkedList<T>(sequence);
211                TrieNode<T> result = null;
212                TrieNode<T> node = getChild(sequenceCopy.get(0));
213                if (node != null) {
214                        sequenceCopy.remove(0);
215                        result = node.find(sequenceCopy);
216                }
217                return result;
218        }
219
220        /**
221         * <p>
222         * Returns a collection of all symbols that follow a given sequence in the
223         * trie. In case the sequence is not found or no symbols follow the sequence
224         * the result will be empty.
225         * </p>
226         *
227         * @param sequence
228         *            sequence whose followers are returned
229         * @return symbols following the given sequence
230         * @see TrieNode#getFollowingSymbols()
231         */
232        public Collection<T> getFollowingSymbols(List<T> sequence) {
233                Collection<T> result = new LinkedList<T>();
234                TrieNode<T> node = find(sequence);
235                if (node != null) {
236                        result = node.getFollowingSymbols();
237                }
238                return result;
239        }
240
241        /**
242         * <p>
243         * Returns the longest suffix of the given context that is contained in the
244         * tree and whose children are leaves.
245         * </p>
246         *
247         * @param context
248         *            context whose suffix is searched for
249         * @return longest suffix of the context
250         */
251        public List<T> getContextSuffix(List<T> context) {
252                List<T> contextSuffix;
253                if( context!=null ) {
254                        contextSuffix = new LinkedList<T>(context); // defensive copy
255                } else {
256                        contextSuffix = new LinkedList<T>();
257                }
258                boolean suffixFound = false;
259
260                while (!suffixFound) {
261                        if (contextSuffix.isEmpty()) {
262                                suffixFound = true; // suffix is the empty word
263                        } else {
264                                TrieNode<T> node = find(contextSuffix);
265                                if (node != null) {
266                                        if (!node.getFollowingSymbols().isEmpty()) {
267                                                suffixFound = true;
268                                        }
269                                }
270                                if (!suffixFound) {
271                                        contextSuffix.remove(0);
272                                }
273                        }
274                }
275
276                return contextSuffix;
277        }
278
279        /**
280         * <p>
281         * Helper class for graph visualization of a trie.
282         * </p>
283         *
284         * @author Steffen Herbold
285         * @version 1.0
286         */
287        static public class Edge {
288        }
289
290        /**
291         * <p>
292         * Helper class for graph visualization of a trie.
293         * </p>
294         *
295         * @author Steffen Herbold
296         * @version 1.0
297         */
298        static public class TrieVertex {
299
300                /**
301                 * <p>
302                 * Id of the vertex.
303                 * </p>
304                 */
305                private String id;
306
307                /**
308                 * <p>
309                 * Contructor. Creates a new TrieVertex.
310                 * </p>
311                 *
312                 * @param id
313                 *            id of the vertex
314                 */
315                protected TrieVertex(String id) {
316                        this.id = id;
317                }
318
319                /**
320                 * <p>
321                 * Returns the id of the vertex.
322                 * </p>
323                 *
324                 * @see java.lang.Object#toString()
325                 */
326                @Override
327                public String toString() {
328                        return id;
329                }
330        }
331
332        /**
333         * <p>
334         * Returns a {@link Graph} representation of the trie.
335         * </p>
336         *
337         * @return {@link Graph} representation of the trie
338         */
339        protected Tree<TrieVertex, Edge> getGraph() {
340                DelegateTree<TrieVertex, Edge> graph = new DelegateTree<TrieVertex, Edge>();
341                rootNode.getGraph(null, graph);
342                return graph;
343        }
344
345        /*
346         * (non-Javadoc)
347         *
348         * @see de.ugoe.cs.eventbench.models.IDotCompatible#getDotRepresentation()
349         */
350        public String getDotRepresentation() {
351                StringBuilder stringBuilder = new StringBuilder();
352                stringBuilder.append("digraph model {" + StringTools.ENDLINE);
353                rootNode.appendDotRepresentation(stringBuilder);
354                stringBuilder.append('}' + StringTools.ENDLINE);
355                return stringBuilder.toString();
356        }
357
358        /**
359         * <p>
360         * Returns the string representation of the root node.
361         * </p>
362         *
363         * @see TrieNode#toString()
364         * @see java.lang.Object#toString()
365         */
366        @Override
367        public String toString() {
368                return rootNode.toString();
369        }
370
371        /**
372         * <p>
373         * Returns the number of symbols contained in the trie.
374         * </p>
375         *
376         * @return number of symbols contained in the trie
377         */
378        public int getNumSymbols() {
379                return knownSymbols.size();
380        }
381
382        /**
383         * <p>
384         * Returns the number of trie nodes that are ancestors of a leaf. This is
385         * the equivalent to the number of states a first-order markov model would
386         * have.
387         * <p>
388         *
389         * @return number of trie nodes that are ancestors of leafs.
390         */
391        public int getNumLeafAncestors() {
392                Set<TrieNode<T>> ancestors = new HashSet<TrieNode<T>>();
393                rootNode.getLeafAncestors(ancestors);
394                return ancestors.size();
395        }
396
397        /**
398         * <p>
399         * Returns the number of trie nodes that are leafs.
400         * </p>
401         *
402         * @return number of leafs in the trie
403         */
404        public int getNumLeafs() {
405                return rootNode.getNumLeafs();
406        }
407
408        /**
409         * <p>
410         * Updates the list of known symbols by replacing it with all symbols that
411         * are found in the child nodes of the root node. This should be the same as
412         * all symbols that are contained in the trie.
413         * </p>
414         */
415        public void updateKnownSymbols() {
416                knownSymbols = new HashSet<T>();
417                for (TrieNode<T> node : rootNode.getChildren()) {
418                        knownSymbols.add(node.getSymbol());
419                }
420        }
421}
Note: See TracBrowser for help on using the repository browser.