[16] | 1 | package de.ugoe.cs.eventbench.models;
|
---|
| 2 |
|
---|
| 3 | import java.util.ArrayList;
|
---|
| 4 | import java.util.List;
|
---|
| 5 | import java.util.Random;
|
---|
| 6 |
|
---|
| 7 | import de.ugoe.cs.eventbench.data.Event;
|
---|
[29] | 8 | import de.ugoe.cs.util.StringTools;
|
---|
[16] | 9 | import de.ugoe.cs.util.console.Console;
|
---|
| 10 | import edu.uci.ics.jung.graph.Graph;
|
---|
| 11 | import edu.uci.ics.jung.graph.SparseMultigraph;
|
---|
| 12 | import edu.uci.ics.jung.graph.util.EdgeType;
|
---|
| 13 |
|
---|
| 14 | import Jama.Matrix;
|
---|
| 15 |
|
---|
[25] | 16 | public class FirstOrderMarkovModel extends HighOrderMarkovModel implements IDotCompatible {
|
---|
[16] | 17 |
|
---|
| 18 | final static int MAX_STATDIST_ITERATIONS = 1000;
|
---|
| 19 |
|
---|
| 20 | public FirstOrderMarkovModel(Random r) {
|
---|
| 21 | super(1, r);
|
---|
| 22 | }
|
---|
| 23 |
|
---|
| 24 | private Matrix getTransmissionMatrix() {
|
---|
| 25 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
| 26 | int numStates = knownSymbols.size();
|
---|
| 27 | Matrix transmissionMatrix = new Matrix(numStates, numStates);
|
---|
| 28 |
|
---|
| 29 | for( int i=0 ; i<numStates ; i++ ) {
|
---|
| 30 | Event<?> currentSymbol = knownSymbols.get(i);
|
---|
| 31 | List<Event<?>> context = new ArrayList<Event<?>>();
|
---|
| 32 | context.add(currentSymbol);
|
---|
| 33 | for( int j=0 ; j<numStates ; j++ ) {
|
---|
| 34 | Event<?> follower = knownSymbols.get(j);
|
---|
| 35 | double prob = getProbability(context, follower);
|
---|
| 36 | transmissionMatrix.set(i, j, prob);
|
---|
| 37 | }
|
---|
| 38 | }
|
---|
| 39 | return transmissionMatrix;
|
---|
| 40 | }
|
---|
| 41 |
|
---|
[25] | 42 | public String getDotRepresentation() {
|
---|
| 43 | StringBuilder stringBuilder = new StringBuilder();
|
---|
[29] | 44 | stringBuilder.append("digraph model {" + StringTools.ENDLINE);
|
---|
[16] | 45 |
|
---|
| 46 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
| 47 |
|
---|
| 48 | for( Event<?> symbol : knownSymbols) {
|
---|
| 49 | final String thisSaneId = symbol.getShortId().replace("\"", "\\\"").replaceAll("[\r\n]","");
|
---|
[29] | 50 | stringBuilder.append(" " + symbol.hashCode() + " [label=\""+thisSaneId+"\"];" + StringTools.ENDLINE);
|
---|
[16] | 51 | List<Event<?>> context = new ArrayList<Event<?>>();
|
---|
| 52 | context.add(symbol);
|
---|
| 53 | List<Event<?>> followers = trie.getFollowingSymbols(context);
|
---|
| 54 | for( Event<?> follower : followers ) {
|
---|
[29] | 55 | stringBuilder.append(" "+symbol.hashCode()+" -> " + follower.hashCode() + " ");
|
---|
| 56 | stringBuilder.append("[label=\"" + getProbability(context, follower) + "\"];" + StringTools.ENDLINE);
|
---|
[16] | 57 | }
|
---|
| 58 | }
|
---|
[29] | 59 | stringBuilder.append('}' + StringTools.ENDLINE);
|
---|
[25] | 60 | return stringBuilder.toString();
|
---|
[16] | 61 | }
|
---|
| 62 |
|
---|
| 63 | public Graph<String, MarkovEdge> getGraph() {
|
---|
| 64 | Graph<String, MarkovEdge> graph = new SparseMultigraph<String, MarkovEdge>();
|
---|
| 65 |
|
---|
| 66 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
| 67 |
|
---|
| 68 | for( Event<?> symbol : knownSymbols) {
|
---|
| 69 | String from = symbol.getShortId();
|
---|
| 70 | List<Event<?>> context = new ArrayList<Event<?>>();
|
---|
| 71 | context.add(symbol);
|
---|
| 72 |
|
---|
| 73 | List<Event<?>> followers = trie.getFollowingSymbols(context);
|
---|
| 74 |
|
---|
| 75 | for( Event<?> follower : followers ) {
|
---|
| 76 | String to = follower.getShortId();
|
---|
| 77 | MarkovEdge prob = new MarkovEdge(getProbability(context, follower));
|
---|
| 78 | graph.addEdge(prob, from, to, EdgeType.DIRECTED);
|
---|
| 79 | }
|
---|
| 80 | }
|
---|
| 81 | return graph;
|
---|
| 82 | }
|
---|
| 83 |
|
---|
| 84 | static public class MarkovEdge {
|
---|
| 85 | double weight;
|
---|
| 86 | MarkovEdge(double weight) { this.weight = weight; }
|
---|
| 87 | public String toString() { return ""+weight; }
|
---|
| 88 | }
|
---|
| 89 |
|
---|
| 90 | public double calcEntropy() {
|
---|
| 91 | Matrix transmissionMatrix = getTransmissionMatrix();
|
---|
| 92 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
| 93 | int numStates = knownSymbols.size();
|
---|
| 94 |
|
---|
| 95 | int startStateIndex = knownSymbols.indexOf(Event.STARTEVENT);
|
---|
| 96 | int endStateIndex = knownSymbols.indexOf(Event.ENDEVENT);
|
---|
| 97 | if( startStateIndex==-1 ) {
|
---|
| 98 | Console.printerrln("Error calculating entropy. Initial state of markov chain not found.");
|
---|
| 99 | return Double.NaN;
|
---|
| 100 | }
|
---|
| 101 | if( endStateIndex==-1 ) {
|
---|
| 102 | Console.printerrln("Error calculating entropy. End state of markov chain not found.");
|
---|
| 103 | return Double.NaN;
|
---|
| 104 | }
|
---|
| 105 | transmissionMatrix.set(endStateIndex, startStateIndex, 1);
|
---|
| 106 |
|
---|
| 107 | // Calculate stationary distribution by raising the power of the transmission matrix.
|
---|
| 108 | // The rank of the matrix should fall to 1 and each two should be the vector of the
|
---|
| 109 | // stationory distribution.
|
---|
| 110 | int iter = 0;
|
---|
| 111 | int rank = transmissionMatrix.rank();
|
---|
| 112 | Matrix stationaryMatrix = (Matrix) transmissionMatrix.clone();
|
---|
| 113 | while( iter<MAX_STATDIST_ITERATIONS && rank>1 ) {
|
---|
| 114 | stationaryMatrix = stationaryMatrix.times(stationaryMatrix);
|
---|
| 115 | rank = stationaryMatrix.rank();
|
---|
| 116 | iter++;
|
---|
| 117 | }
|
---|
| 118 |
|
---|
| 119 | if( rank!=1 ) {
|
---|
| 120 | Console.traceln("rank: " + rank);
|
---|
| 121 | Console.printerrln("Unable to calculate stationary distribution.");
|
---|
| 122 | return Double.NaN;
|
---|
| 123 | }
|
---|
| 124 |
|
---|
| 125 | double entropy = 0.0;
|
---|
| 126 | for( int i=0 ; i<numStates ; i++ ) {
|
---|
| 127 | for( int j=0 ; j<numStates ; j++ ) {
|
---|
| 128 | if( transmissionMatrix.get(i,j)!=0 ) {
|
---|
| 129 | double tmp = stationaryMatrix.get(i, 0);
|
---|
| 130 | tmp *= transmissionMatrix.get(i, j);
|
---|
| 131 | tmp *= Math.log(transmissionMatrix.get(i,j))/Math.log(2);
|
---|
| 132 | entropy -= tmp;
|
---|
| 133 | }
|
---|
| 134 | }
|
---|
| 135 | }
|
---|
| 136 | return entropy;
|
---|
| 137 | }
|
---|
| 138 |
|
---|
| 139 | }
|
---|