source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/models/Trie.java @ 251

Last change on this file since 251 was 251, checked in by sherbold, 13 years ago
  • the method train() of de.ugoe.cs.eventbench.models.Trie now aborts immediately if the order is less than 1.
  • added method isRoot() to de.ugoe.cs.eventbench.models.TrieNode?
File size: 9.4 KB
Line 
1package de.ugoe.cs.eventbench.models;
2
3import java.io.Serializable;
4import java.util.Collection;
5import java.util.HashSet;
6import java.util.LinkedHashSet;
7import java.util.LinkedList;
8import java.util.List;
9import java.util.Set;
10
11import de.ugoe.cs.util.StringTools;
12
13import edu.uci.ics.jung.graph.DelegateTree;
14import edu.uci.ics.jung.graph.Graph;
15import edu.uci.ics.jung.graph.Tree;
16
17/**
18 * <p>
19 * This class implements a <it>trie</it>, i.e., a tree of sequences that the
20 * occurence of subsequences up to a predefined length. This length is the trie
21 * order.
22 * </p>
23 *
24 * @author Steffen Herbold
25 *
26 * @param <T>
27 *            Type of the symbols that are stored in the trie.
28 *
29 * @see TrieNode
30 */
31public class Trie<T> implements IDotCompatible, Serializable {
32
33        /**
34         * <p>
35         * Id for object serialization.
36         * </p>
37         */
38        private static final long serialVersionUID = 1L;
39
40        /**
41         * <p>
42         * Collection of all symbols occuring in the trie.
43         * </p>
44         */
45        private Collection<T> knownSymbols;
46
47        /**
48         * <p>
49         * Reference to the root of the trie.
50         * </p>
51         */
52        private final TrieNode<T> rootNode;
53
54        /**
55         * <p>
56         * Contructor. Creates a new Trie.
57         * </p>
58         */
59        public Trie() {
60                rootNode = new TrieNode<T>();
61                knownSymbols = new LinkedHashSet<T>();
62        }
63
64        /**
65         * <p>
66         * Returns a collection of all symbols occuring in the trie.
67         * </p>
68         *
69         * @return symbols occuring in the trie
70         */
71        public Collection<T> getKnownSymbols() {
72                return new LinkedHashSet<T>(knownSymbols);
73        }
74
75        /**
76         * <p>
77         * Trains the current trie using the given sequence and adds all subsequence
78         * of length {@code maxOrder}.
79         * </p>
80         *
81         * @param sequence
82         *            sequence whose subsequences are added to the trie
83         * @param maxOrder
84         *            maximum length of the subsequences added to the trie
85         */
86        public void train(List<T> sequence, int maxOrder) {
87                if( maxOrder<1 ) {
88                        return;
89                }
90                IncompleteMemory<T> latestActions = new IncompleteMemory<T>(maxOrder);
91                int i = 0;
92                for (T currentEvent : sequence) {
93                        latestActions.add(currentEvent);
94                        knownSymbols.add(currentEvent);
95                        i++;
96                        if (i >= maxOrder) {
97                                add(latestActions.getLast(maxOrder));
98                        }
99                }
100                int sequenceLength = sequence.size();
101                for (int j = maxOrder - 1; j > 0; j--) {
102                        add(sequence.subList(sequenceLength - j, sequenceLength));
103                }
104        }
105
106        /**
107         * <p>
108         * Adds a given subsequence to the trie and increases the counters
109         * accordingly.
110         * </p>
111         *
112         * @param subsequence
113         *            subsequence whose counters are increased
114         * @see TrieNode#add(List)
115         */
116        protected void add(List<T> subsequence) {
117                if (subsequence != null && !subsequence.isEmpty()) {
118                        knownSymbols.addAll(subsequence);
119                        subsequence = new LinkedList<T>(subsequence); // defensive copy!
120                        T firstSymbol = subsequence.get(0);
121                        TrieNode<T> node = getChildCreate(firstSymbol);
122                        node.add(subsequence);
123                }
124        }
125
126        /**
127         * <p>
128         * Returns the child of the root node associated with the given symbol or
129         * creates it if it does not exist yet.
130         * </p>
131         *
132         * @param symbol
133         *            symbol whose node is required
134         * @return node associated with the symbol
135         * @see TrieNode#getChildCreate(Object)
136         */
137        protected TrieNode<T> getChildCreate(T symbol) {
138                return rootNode.getChildCreate(symbol);
139        }
140
141        /**
142         * <p>
143         * Returns the child of the root node associated with the given symbol or
144         * null if it does not exist.
145         * </p>
146         *
147         * @param symbol
148         *            symbol whose node is required
149         * @return node associated with the symbol; null if no such node exists
150         * @see TrieNode#getChild(Object)
151         */
152        protected TrieNode<T> getChild(T symbol) {
153                return rootNode.getChild(symbol);
154        }
155
156        /**
157         * <p>
158         * Returns the number of occurences of the given sequence.
159         * </p>
160         *
161         * @param sequence
162         *            sequence whose number of occurences is required
163         * @return number of occurences of the sequence
164         */
165        public int getCount(List<T> sequence) {
166                int count = 0;
167                TrieNode<T> node = find(sequence);
168                if (node != null) {
169                        count = node.getCount();
170                }
171                return count;
172        }
173
174        /**
175         * <p>
176         * Returns the number of occurences of the given prefix and a symbol that
177         * follows it.<br>
178         * Convenience function to simplify usage of {@link #getCount(List)}.
179         * </p>
180         *
181         * @param sequence
182         *            prefix of the sequence
183         * @param follower
184         *            suffix of the sequence
185         * @return number of occurences of the sequence
186         * @see #getCount(List)
187         */
188        public int getCount(List<T> sequence, T follower) {
189                List<T> tmpSequence = new LinkedList<T>(sequence);
190                tmpSequence.add(follower);
191                return getCount(tmpSequence);
192
193        }
194
195        /**
196         * <p>
197         * Searches the trie for a given sequence and returns the node associated
198         * with the sequence or null if no such node is found.
199         * </p>
200         *
201         * @param sequence
202         *            sequence that is searched for
203         * @return node associated with the sequence
204         * @see TrieNode#find(List)
205         */
206        public TrieNode<T> find(List<T> sequence) {
207                if (sequence == null || sequence.isEmpty()) {
208                        return rootNode;
209                }
210                List<T> sequenceCopy = new LinkedList<T>(sequence);
211                TrieNode<T> result = null;
212                TrieNode<T> node = getChild(sequenceCopy.get(0));
213                if (node != null) {
214                        sequenceCopy.remove(0);
215                        result = node.find(sequenceCopy);
216                }
217                return result;
218        }
219
220        /**
221         * <p>
222         * Returns a collection of all symbols that follow a given sequence in the
223         * trie. In case the sequence is not found or no symbols follow the sequence
224         * the result will be empty.
225         * </p>
226         *
227         * @param sequence
228         *            sequence whose followers are returned
229         * @return symbols following the given sequence
230         * @see TrieNode#getFollowingSymbols()
231         */
232        public Collection<T> getFollowingSymbols(List<T> sequence) {
233                Collection<T> result = new LinkedList<T>();
234                TrieNode<T> node = find(sequence);
235                if (node != null) {
236                        result = node.getFollowingSymbols();
237                }
238                return result;
239        }
240
241        /**
242         * <p>
243         * Returns the longest suffix of the given context that is contained in the
244         * tree and whose children are leaves.
245         * </p>
246         *
247         * @param context
248         *            context whose suffix is searched for
249         * @return longest suffix of the context
250         */
251        public List<T> getContextSuffix(List<T> context) {
252                List<T> contextSuffix = new LinkedList<T>(context); // defensive copy
253                boolean suffixFound = false;
254
255                while (!suffixFound) {
256                        if (contextSuffix.isEmpty()) {
257                                suffixFound = true; // suffix is the empty word
258                        } else {
259                                TrieNode<T> node = find(contextSuffix);
260                                if (node != null) {
261                                        if (!node.getFollowingSymbols().isEmpty()) {
262                                                suffixFound = true;
263                                        }
264                                }
265                                if (!suffixFound) {
266                                        contextSuffix.remove(0);
267                                }
268                        }
269                }
270
271                return contextSuffix;
272        }
273
274        /**
275         * <p>
276         * Helper class for graph visualization of a trie.
277         * </p>
278         *
279         * @author Steffen Herbold
280         * @version 1.0
281         */
282        static public class Edge {
283        }
284
285        /**
286         * <p>
287         * Helper class for graph visualization of a trie.
288         * </p>
289         *
290         * @author Steffen Herbold
291         * @version 1.0
292         */
293        static public class TrieVertex {
294
295                /**
296                 * <p>
297                 * Id of the vertex.
298                 * </p>
299                 */
300                private String id;
301
302                /**
303                 * <p>
304                 * Contructor. Creates a new TrieVertex.
305                 * </p>
306                 *
307                 * @param id
308                 *            id of the vertex
309                 */
310                protected TrieVertex(String id) {
311                        this.id = id;
312                }
313
314                /**
315                 * <p>
316                 * Returns the id of the vertex.
317                 * </p>
318                 *
319                 * @see java.lang.Object#toString()
320                 */
321                @Override
322                public String toString() {
323                        return id;
324                }
325        }
326
327        /**
328         * <p>
329         * Returns a {@link Graph} representation of the trie.
330         * </p>
331         *
332         * @return {@link Graph} representation of the trie
333         */
334        protected Tree<TrieVertex, Edge> getGraph() {
335                DelegateTree<TrieVertex, Edge> graph = new DelegateTree<TrieVertex, Edge>();
336                rootNode.getGraph(null, graph);
337                return graph;
338        }
339
340        /*
341         * (non-Javadoc)
342         *
343         * @see de.ugoe.cs.eventbench.models.IDotCompatible#getDotRepresentation()
344         */
345        public String getDotRepresentation() {
346                StringBuilder stringBuilder = new StringBuilder();
347                stringBuilder.append("digraph model {" + StringTools.ENDLINE);
348                rootNode.appendDotRepresentation(stringBuilder);
349                stringBuilder.append('}' + StringTools.ENDLINE);
350                return stringBuilder.toString();
351        }
352
353        /**
354         * <p>
355         * Returns the string representation of the root node.
356         * </p>
357         *
358         * @see TrieNode#toString()
359         * @see java.lang.Object#toString()
360         */
361        @Override
362        public String toString() {
363                return rootNode.toString();
364        }
365
366        /**
367         * <p>
368         * Returns the number of symbols contained in the trie.
369         * </p>
370         *
371         * @return number of symbols contained in the trie
372         */
373        public int getNumSymbols() {
374                return knownSymbols.size();
375        }
376
377        /**
378         * <p>
379         * Returns the number of trie nodes that are ancestors of a leaf. This is
380         * the equivalent to the number of states a first-order markov model would
381         * have.
382         * <p>
383         *
384         * @return number of trie nodes that are ancestors of leafs.
385         */
386        public int getNumLeafAncestors() {
387                Set<TrieNode<T>> ancestors = new HashSet<TrieNode<T>>();
388                rootNode.getLeafAncestors(ancestors);
389                return ancestors.size();
390        }
391
392        /**
393         * <p>
394         * Returns the number of trie nodes that are leafs.
395         * </p>
396         *
397         * @return number of leafs in the trie
398         */
399        public int getNumLeafs() {
400                return rootNode.getNumLeafs();
401        }
402}
Note: See TracBrowser for help on using the repository browser.