Changeset 559 for trunk/quest-core-usageprofiles/src/main/java/de/ugoe/cs/quest/usageprofiles/TrieBasedModel.java
- Timestamp:
- 08/17/12 09:05:19 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/quest-core-usageprofiles/src/main/java/de/ugoe/cs/quest/usageprofiles/TrieBasedModel.java
r547 r559 1 1 2 package de.ugoe.cs.quest.usageprofiles; 2 3 … … 18 19 /** 19 20 * <p> 20 * Implements a skeleton for stochastic processes that can calculate 21 * probabilities based on a trie. The skeleton provides all functionalities of 22 * {@link IStochasticProcess} except 21 * Implements a skeleton for stochastic processes that can calculate probabilities based on a trie. 22 * The skeleton provides all functionalities of {@link IStochasticProcess} except 23 23 * {@link IStochasticProcess#getProbability(List, Event)}. 24 24 * </p> … … 29 29 public abstract class TrieBasedModel implements IStochasticProcess { 30 30 31 /** 32 * <p> 33 * Id for object serialization. 34 * </p> 35 */ 36 private static final long serialVersionUID = 1L; 37 38 /** 39 * <p> 40 * The order of the trie, i.e., the maximum length of subsequences stored in 41 * the trie. 42 * </p> 43 */ 44 protected int trieOrder; 45 46 /** 47 * <p> 48 * Trie on which the probability calculations are based. 49 * </p> 50 */ 51 protected Trie<Event> trie = null; 52 53 /** 54 * <p> 55 * Random number generator used by probabilistic sequence generation 56 * methods. 57 * </p> 58 */ 59 protected final Random r; 60 61 /** 62 * <p> 63 * Constructor. Creates a new TrieBasedModel that can be used for stochastic 64 * processes with a Markov order less than or equal to {@code markovOrder}. 65 * </p> 66 * 67 * @param markovOrder 68 * Markov order of the model 69 * @param r 70 * random number generator used by probabilistic methods of the 71 * class 72 * @throws InvalidParameterException 73 * thrown if markovOrder is less than 0 or the random number 74 * generator r is null 75 */ 76 public TrieBasedModel(int markovOrder, Random r) { 77 super(); 78 if (markovOrder < 0) { 79 throw new InvalidParameterException( 80 "markov order must not be less than 0"); 81 } 82 if (r == null) { 83 throw new InvalidParameterException( 84 "random number generator r must not be null"); 85 } 86 this.trieOrder = markovOrder + 1; 87 this.r = r; 88 } 89 90 /** 91 * <p> 92 * Trains the model by generating a trie from which probabilities are 93 * calculated. The trie is newly generated based solely on the passed 94 * sequences. If an existing model should only be updated, use 95 * {@link #update(Collection)} instead. 96 * </p> 97 * 98 * @param sequences 99 * training data 100 * @throws InvalidParameterException 101 * thrown is sequences is null 102 */ 103 public void train(Collection<List<Event>> sequences) { 104 trie = null; 105 update(sequences); 106 } 107 108 /** 109 * <p> 110 * Trains the model by updating the trie from which the probabilities are 111 * calculated. This function updates an existing trie. In case no trie 112 * exists yet, a new trie is generated and the function behaves like 113 * {@link #train(Collection)}. 114 * </p> 115 * 116 * @param sequences 117 * training data 118 * @throws InvalidParameterException 119 * thrown is sequences is null 120 */ 121 public void update(Collection<List<Event>> sequences) { 122 if (sequences == null) { 123 throw new InvalidParameterException("sequences must not be null"); 124 } 125 if (trie == null) { 126 trie = new Trie<Event>(); 127 } 128 for (List<Event> sequence : sequences) { 129 List<Event> currentSequence = new LinkedList<Event>(sequence); // defensive 130 // copy 131 currentSequence.add(0, Event.STARTEVENT); 132 currentSequence.add(Event.ENDEVENT); 133 134 trie.train(currentSequence, trieOrder); 135 } 136 } 137 138 /* 139 * (non-Javadoc) 140 * 141 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#randomSequence() 142 */ 143 @Override 144 public List<Event> randomSequence() { 145 return randomSequence(Integer.MAX_VALUE, true); 146 } 147 148 /* 149 * (non-Javadoc) 150 * 151 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#randomSequence() 152 */ 153 @Override 154 public List<Event> randomSequence(int maxLength, 155 boolean validEnd) { 156 List<Event> sequence = new LinkedList<Event>(); 157 if (trie != null) { 158 boolean endFound = false; 159 while (!endFound) { // outer loop for length checking 160 sequence = new LinkedList<Event>(); 161 IncompleteMemory<Event> context = new IncompleteMemory<Event>( 162 trieOrder - 1); 163 context.add(Event.STARTEVENT); 164 165 while (!endFound && sequence.size() <= maxLength) { 166 double randVal = r.nextDouble(); 167 double probSum = 0.0; 168 List<Event> currentContext = context.getLast(trieOrder); 169 for (Event symbol : trie.getKnownSymbols()) { 170 probSum += getProbability(currentContext, symbol); 171 if (probSum >= randVal) { 172 if (!(Event.STARTEVENT.equals(symbol) || Event.ENDEVENT 173 .equals(symbol))) { 174 // only add the symbol the sequence if it is not 175 // START or END 176 context.add(symbol); 177 sequence.add(symbol); 178 } 179 endFound = (Event.ENDEVENT.equals(symbol)) 180 || (!validEnd && sequence.size() == maxLength); 181 break; 182 } 183 } 184 } 185 } 186 } 187 return sequence; 188 } 189 190 /** 191 * <p> 192 * Returns a Dot representation of the internal trie. 193 * </p> 194 * 195 * @return dot representation of the internal trie 196 */ 197 public String getTrieDotRepresentation() { 198 if (trie == null) { 199 return ""; 200 } else { 201 return trie.getDotRepresentation(); 202 } 203 } 204 205 /** 206 * <p> 207 * Returns a {@link Tree} of the internal trie that can be used for 208 * visualization. 209 * </p> 210 * 211 * @return {@link Tree} depicting the internal trie 212 */ 213 public Tree<TrieVertex, Edge> getTrieGraph() { 214 if (trie == null) { 215 return null; 216 } else { 217 return trie.getGraph(); 218 } 219 } 220 221 /** 222 * <p> 223 * The string representation of the model is {@link Trie#toString()} of 224 * {@link #trie}. 225 * </p> 226 * 227 * @see java.lang.Object#toString() 228 */ 229 @Override 230 public String toString() { 231 if (trie == null) { 232 return ""; 233 } else { 234 return trie.toString(); 235 } 236 } 237 238 /* 239 * (non-Javadoc) 240 * 241 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getNumStates() 242 */ 243 @Override 244 public int getNumSymbols() { 245 if (trie == null) { 246 return 0; 247 } else { 248 return trie.getNumSymbols(); 249 } 250 } 251 252 /* 253 * (non-Javadoc) 254 * 255 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getStateStrings() 256 */ 257 @Override 258 public String[] getSymbolStrings() { 259 if (trie == null) { 260 return new String[0]; 261 } 262 String[] stateStrings = new String[getNumSymbols()]; 263 int i = 0; 264 for (Event symbol : trie.getKnownSymbols()) { 265 if (symbol.toString() == null) { 266 stateStrings[i] = "null"; 267 } else { 268 stateStrings[i] = symbol.toString(); 269 } 270 i++; 271 } 272 return stateStrings; 273 } 274 275 /* 276 * (non-Javadoc) 277 * 278 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getEvents() 279 */ 280 @Override 281 public Collection<Event> getEvents() { 282 if (trie == null) { 283 return new HashSet<Event>(); 284 } else { 285 return trie.getKnownSymbols(); 286 } 287 } 288 289 /* 290 * (non-Javadoc) 291 * 292 * @see 293 * de.ugoe.cs.quest.usageprofiles.IStochasticProcess#generateSequences(int) 294 */ 295 @Override 296 public Collection<List<Event>> generateSequences(int length) { 297 return generateSequences(length, false); 298 } 299 300 /* 301 * (non-Javadoc) 302 * 303 * @see 304 * de.ugoe.cs.quest.usageprofiles.IStochasticProcess#generateSequences(int, 305 * boolean) 306 */ 307 @Override 308 public Set<List<Event>> generateSequences(int length, 309 boolean fromStart) { 310 Set<List<Event>> sequenceSet = new LinkedHashSet<List<Event>>(); 311 if (length < 1) { 312 throw new InvalidParameterException( 313 "Length of generated subsequences must be at least 1."); 314 } 315 if (length == 1) { 316 if (fromStart) { 317 List<Event> subSeq = new LinkedList<Event>(); 318 subSeq.add(Event.STARTEVENT); 319 sequenceSet.add(subSeq); 320 } else { 321 for (Event event : getEvents()) { 322 List<Event> subSeq = new LinkedList<Event>(); 323 subSeq.add(event); 324 sequenceSet.add(subSeq); 325 } 326 } 327 return sequenceSet; 328 } 329 Collection<Event> events = getEvents(); 330 Collection<List<Event>> seqsShorter = generateSequences( 331 length - 1, fromStart); 332 for (Event event : events) { 333 for (List<Event> seqShorter : seqsShorter) { 334 Event lastEvent = event; 335 if (getProbability(seqShorter, lastEvent) > 0.0) { 336 List<Event> subSeq = new ArrayList<Event>(seqShorter); 337 subSeq.add(lastEvent); 338 sequenceSet.add(subSeq); 339 } 340 } 341 } 342 return sequenceSet; 343 } 344 345 /* 346 * (non-Javadoc) 347 * 348 * @see 349 * de.ugoe.cs.quest.usageprofiles.IStochasticProcess#generateValidSequences 350 * (int) 351 */ 352 @Override 353 public Collection<List<Event>> generateValidSequences( 354 int length) { 355 // check for min-length implicitly done by generateSequences 356 Collection<List<Event>> allSequences = generateSequences( 357 length, true); 358 Collection<List<Event>> validSequences = new LinkedHashSet<List<Event>>(); 359 for (List<Event> sequence : allSequences) { 360 if (sequence.size() == length 361 && Event.ENDEVENT.equals(sequence.get(sequence.size() - 1))) { 362 validSequences.add(sequence); 363 } 364 } 365 return validSequences; 366 } 367 368 /* 369 * (non-Javadoc) 370 * 371 * @see 372 * de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getProbability(java.util 373 * .List) 374 */ 375 @Override 376 public double getProbability(List<Event> sequence) { 377 if (sequence == null) { 378 throw new InvalidParameterException("sequence must not be null"); 379 } 380 double prob = 1.0; 381 List<Event> context = new LinkedList<Event>(); 382 for (Event event : sequence) { 383 prob *= getProbability(context, event); 384 context.add(event); 385 } 386 return prob; 387 } 388 389 /* 390 * (non-Javadoc) 391 * 392 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getNumFOMStates() 393 */ 394 @Override 395 public int getNumFOMStates() { 396 if (trie == null) { 397 return 0; 398 } else { 399 return trie.getNumLeafAncestors(); 400 } 401 } 402 403 /* 404 * (non-Javadoc) 405 * 406 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getNumTransitions() 407 */ 408 @Override 409 public int getNumTransitions() { 410 if (trie == null) { 411 return 0; 412 } else { 413 return trie.getNumLeafs(); 414 } 415 } 31 /** 32 * <p> 33 * Id for object serialization. 34 * </p> 35 */ 36 private static final long serialVersionUID = 1L; 37 38 /** 39 * <p> 40 * The order of the trie, i.e., the maximum length of subsequences stored in the trie. 41 * </p> 42 */ 43 protected int trieOrder; 44 45 /** 46 * <p> 47 * Trie on which the probability calculations are based. 48 * </p> 49 */ 50 protected Trie<Event> trie = null; 51 52 /** 53 * <p> 54 * Random number generator used by probabilistic sequence generation methods. 55 * </p> 56 */ 57 protected final Random r; 58 59 /** 60 * <p> 61 * Constructor. Creates a new TrieBasedModel that can be used for stochastic processes with a 62 * Markov order less than or equal to {@code markovOrder}. 63 * </p> 64 * 65 * @param markovOrder 66 * Markov order of the model 67 * @param r 68 * random number generator used by probabilistic methods of the class 69 * @throws InvalidParameterException 70 * thrown if markovOrder is less than 0 or the random number generator r is null 71 */ 72 public TrieBasedModel(int markovOrder, Random r) { 73 super(); 74 if (markovOrder < 0) { 75 throw new InvalidParameterException("markov order must not be less than 0"); 76 } 77 if (r == null) { 78 throw new InvalidParameterException("random number generator r must not be null"); 79 } 80 this.trieOrder = markovOrder + 1; 81 this.r = r; 82 } 83 84 /** 85 * <p> 86 * Trains the model by generating a trie from which probabilities are calculated. The trie is 87 * newly generated based solely on the passed sequences. If an existing model should only be 88 * updated, use {@link #update(Collection)} instead. 89 * </p> 90 * 91 * @param sequences 92 * training data 93 * @throws InvalidParameterException 94 * thrown is sequences is null 95 */ 96 public void train(Collection<List<Event>> sequences) { 97 trie = null; 98 update(sequences); 99 } 100 101 /** 102 * <p> 103 * Trains the model by updating the trie from which the probabilities are calculated. This 104 * function updates an existing trie. In case no trie exists yet, a new trie is generated and 105 * the function behaves like {@link #train(Collection)}. 106 * </p> 107 * 108 * @param sequences 109 * training data 110 * @throws InvalidParameterException 111 * thrown is sequences is null 112 */ 113 public void update(Collection<List<Event>> sequences) { 114 if (sequences == null) { 115 throw new InvalidParameterException("sequences must not be null"); 116 } 117 if (trie == null) { 118 trie = new Trie<Event>(); 119 } 120 for (List<Event> sequence : sequences) { 121 List<Event> currentSequence = new LinkedList<Event>(sequence); // defensive 122 // copy 123 currentSequence.add(0, Event.STARTEVENT); 124 currentSequence.add(Event.ENDEVENT); 125 126 trie.train(currentSequence, trieOrder); 127 } 128 } 129 130 /* 131 * (non-Javadoc) 132 * 133 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#randomSequence() 134 */ 135 @Override 136 public List<Event> randomSequence() { 137 return randomSequence(Integer.MAX_VALUE, true); 138 } 139 140 /* 141 * (non-Javadoc) 142 * 143 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#randomSequence() 144 */ 145 @Override 146 public List<Event> randomSequence(int maxLength, boolean validEnd) { 147 List<Event> sequence = new LinkedList<Event>(); 148 if (trie != null) { 149 boolean endFound = false; 150 while (!endFound) { // outer loop for length checking 151 sequence = new LinkedList<Event>(); 152 IncompleteMemory<Event> context = new IncompleteMemory<Event>(trieOrder - 1); 153 context.add(Event.STARTEVENT); 154 155 while (!endFound && sequence.size() <= maxLength) { 156 double randVal = r.nextDouble(); 157 double probSum = 0.0; 158 List<Event> currentContext = context.getLast(trieOrder); 159 for (Event symbol : trie.getKnownSymbols()) { 160 probSum += getProbability(currentContext, symbol); 161 if (probSum >= randVal) { 162 if (!(Event.STARTEVENT.equals(symbol) || Event.ENDEVENT.equals(symbol))) 163 { 164 // only add the symbol the sequence if it is not 165 // START or END 166 context.add(symbol); 167 sequence.add(symbol); 168 } 169 endFound = 170 (Event.ENDEVENT.equals(symbol)) || 171 (!validEnd && sequence.size() == maxLength); 172 break; 173 } 174 } 175 } 176 } 177 } 178 return sequence; 179 } 180 181 /** 182 * <p> 183 * Returns a Dot representation of the internal trie. 184 * </p> 185 * 186 * @return dot representation of the internal trie 187 */ 188 public String getTrieDotRepresentation() { 189 if (trie == null) { 190 return ""; 191 } 192 else { 193 return trie.getDotRepresentation(); 194 } 195 } 196 197 /** 198 * <p> 199 * Returns a {@link Tree} of the internal trie that can be used for visualization. 200 * </p> 201 * 202 * @return {@link Tree} depicting the internal trie 203 */ 204 public Tree<TrieVertex, Edge> getTrieGraph() { 205 if (trie == null) { 206 return null; 207 } 208 else { 209 return trie.getGraph(); 210 } 211 } 212 213 /** 214 * <p> 215 * The string representation of the model is {@link Trie#toString()} of {@link #trie}. 216 * </p> 217 * 218 * @see java.lang.Object#toString() 219 */ 220 @Override 221 public String toString() { 222 if (trie == null) { 223 return ""; 224 } 225 else { 226 return trie.toString(); 227 } 228 } 229 230 /* 231 * (non-Javadoc) 232 * 233 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getNumStates() 234 */ 235 @Override 236 public int getNumSymbols() { 237 if (trie == null) { 238 return 0; 239 } 240 else { 241 return trie.getNumSymbols(); 242 } 243 } 244 245 /* 246 * (non-Javadoc) 247 * 248 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getStateStrings() 249 */ 250 @Override 251 public String[] getSymbolStrings() { 252 if (trie == null) { 253 return new String[0]; 254 } 255 String[] stateStrings = new String[getNumSymbols()]; 256 int i = 0; 257 for (Event symbol : trie.getKnownSymbols()) { 258 if (symbol.toString() == null) { 259 stateStrings[i] = "null"; 260 } 261 else { 262 stateStrings[i] = symbol.toString(); 263 } 264 i++; 265 } 266 return stateStrings; 267 } 268 269 /* 270 * (non-Javadoc) 271 * 272 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getEvents() 273 */ 274 @Override 275 public Collection<Event> getEvents() { 276 if (trie == null) { 277 return new HashSet<Event>(); 278 } 279 else { 280 return trie.getKnownSymbols(); 281 } 282 } 283 284 /* 285 * (non-Javadoc) 286 * 287 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#generateSequences(int) 288 */ 289 @Override 290 public Collection<List<Event>> generateSequences(int length) { 291 return generateSequences(length, false); 292 } 293 294 /* 295 * (non-Javadoc) 296 * 297 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#generateSequences(int, boolean) 298 */ 299 @Override 300 public Set<List<Event>> generateSequences(int length, boolean fromStart) { 301 Set<List<Event>> sequenceSet = new LinkedHashSet<List<Event>>(); 302 if (length < 1) { 303 throw new InvalidParameterException( 304 "Length of generated subsequences must be at least 1."); 305 } 306 if (length == 1) { 307 if (fromStart) { 308 List<Event> subSeq = new LinkedList<Event>(); 309 subSeq.add(Event.STARTEVENT); 310 sequenceSet.add(subSeq); 311 } 312 else { 313 for (Event event : getEvents()) { 314 List<Event> subSeq = new LinkedList<Event>(); 315 subSeq.add(event); 316 sequenceSet.add(subSeq); 317 } 318 } 319 return sequenceSet; 320 } 321 Collection<Event> events = getEvents(); 322 Collection<List<Event>> seqsShorter = generateSequences(length - 1, fromStart); 323 for (Event event : events) { 324 for (List<Event> seqShorter : seqsShorter) { 325 Event lastEvent = event; 326 if (getProbability(seqShorter, lastEvent) > 0.0) { 327 List<Event> subSeq = new ArrayList<Event>(seqShorter); 328 subSeq.add(lastEvent); 329 sequenceSet.add(subSeq); 330 } 331 } 332 } 333 return sequenceSet; 334 } 335 336 /* 337 * (non-Javadoc) 338 * 339 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#generateValidSequences (int) 340 */ 341 @Override 342 public Collection<List<Event>> generateValidSequences(int length) { 343 // check for min-length implicitly done by generateSequences 344 Collection<List<Event>> allSequences = generateSequences(length, true); 345 Collection<List<Event>> validSequences = new LinkedHashSet<List<Event>>(); 346 for (List<Event> sequence : allSequences) { 347 if (sequence.size() == length && 348 Event.ENDEVENT.equals(sequence.get(sequence.size() - 1))) 349 { 350 validSequences.add(sequence); 351 } 352 } 353 return validSequences; 354 } 355 356 /* 357 * (non-Javadoc) 358 * 359 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getProbability(java.util .List) 360 */ 361 @Override 362 public double getProbability(List<Event> sequence) { 363 if (sequence == null) { 364 throw new InvalidParameterException("sequence must not be null"); 365 } 366 double prob = 1.0; 367 List<Event> context = new LinkedList<Event>(); 368 for (Event event : sequence) { 369 prob *= getProbability(context, event); 370 context.add(event); 371 } 372 return prob; 373 } 374 375 /* 376 * (non-Javadoc) 377 * 378 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getNumFOMStates() 379 */ 380 @Override 381 public int getNumFOMStates() { 382 if (trie == null) { 383 return 0; 384 } 385 else { 386 return trie.getNumLeafAncestors(); 387 } 388 } 389 390 /* 391 * (non-Javadoc) 392 * 393 * @see de.ugoe.cs.quest.usageprofiles.IStochasticProcess#getNumTransitions() 394 */ 395 @Override 396 public int getNumTransitions() { 397 if (trie == null) { 398 return 0; 399 } 400 else { 401 return trie.getNumLeafs(); 402 } 403 } 416 404 }
Note: See TracChangeset
for help on using the changeset viewer.