1 | package de.ugoe.cs.eventbench.models;
|
---|
2 |
|
---|
3 | import java.util.ArrayList;
|
---|
4 | import java.util.List;
|
---|
5 | import java.util.Random;
|
---|
6 |
|
---|
7 | import de.ugoe.cs.eventbench.data.Event;
|
---|
8 | import de.ugoe.cs.util.StringTools;
|
---|
9 | import de.ugoe.cs.util.console.Console;
|
---|
10 | import edu.uci.ics.jung.graph.Graph;
|
---|
11 | import edu.uci.ics.jung.graph.SparseMultigraph;
|
---|
12 | import edu.uci.ics.jung.graph.util.EdgeType;
|
---|
13 |
|
---|
14 | import Jama.Matrix;
|
---|
15 |
|
---|
16 | public class FirstOrderMarkovModel extends HighOrderMarkovModel implements IDotCompatible {
|
---|
17 |
|
---|
18 | final static int MAX_STATDIST_ITERATIONS = 1000;
|
---|
19 |
|
---|
20 | public FirstOrderMarkovModel(Random r) {
|
---|
21 | super(1, r);
|
---|
22 | }
|
---|
23 |
|
---|
24 | private Matrix getTransmissionMatrix() {
|
---|
25 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
26 | int numStates = knownSymbols.size();
|
---|
27 | Matrix transmissionMatrix = new Matrix(numStates, numStates);
|
---|
28 |
|
---|
29 | for( int i=0 ; i<numStates ; i++ ) {
|
---|
30 | Event<?> currentSymbol = knownSymbols.get(i);
|
---|
31 | List<Event<?>> context = new ArrayList<Event<?>>();
|
---|
32 | context.add(currentSymbol);
|
---|
33 | for( int j=0 ; j<numStates ; j++ ) {
|
---|
34 | Event<?> follower = knownSymbols.get(j);
|
---|
35 | double prob = getProbability(context, follower);
|
---|
36 | transmissionMatrix.set(i, j, prob);
|
---|
37 | }
|
---|
38 | }
|
---|
39 | return transmissionMatrix;
|
---|
40 | }
|
---|
41 |
|
---|
42 | public String getDotRepresentation() {
|
---|
43 | StringBuilder stringBuilder = new StringBuilder();
|
---|
44 | stringBuilder.append("digraph model {" + StringTools.ENDLINE);
|
---|
45 |
|
---|
46 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
47 |
|
---|
48 | for( Event<?> symbol : knownSymbols) {
|
---|
49 | final String thisSaneId = symbol.getShortId().replace("\"", "\\\"").replaceAll("[\r\n]","");
|
---|
50 | stringBuilder.append(" " + symbol.hashCode() + " [label=\""+thisSaneId+"\"];" + StringTools.ENDLINE);
|
---|
51 | List<Event<?>> context = new ArrayList<Event<?>>();
|
---|
52 | context.add(symbol);
|
---|
53 | List<Event<?>> followers = trie.getFollowingSymbols(context);
|
---|
54 | for( Event<?> follower : followers ) {
|
---|
55 | stringBuilder.append(" "+symbol.hashCode()+" -> " + follower.hashCode() + " ");
|
---|
56 | stringBuilder.append("[label=\"" + getProbability(context, follower) + "\"];" + StringTools.ENDLINE);
|
---|
57 | }
|
---|
58 | }
|
---|
59 | stringBuilder.append('}' + StringTools.ENDLINE);
|
---|
60 | return stringBuilder.toString();
|
---|
61 | }
|
---|
62 |
|
---|
63 | public Graph<String, MarkovEdge> getGraph() {
|
---|
64 | Graph<String, MarkovEdge> graph = new SparseMultigraph<String, MarkovEdge>();
|
---|
65 |
|
---|
66 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
67 |
|
---|
68 | for( Event<?> symbol : knownSymbols) {
|
---|
69 | String from = symbol.getShortId();
|
---|
70 | List<Event<?>> context = new ArrayList<Event<?>>();
|
---|
71 | context.add(symbol);
|
---|
72 |
|
---|
73 | List<Event<?>> followers = trie.getFollowingSymbols(context);
|
---|
74 |
|
---|
75 | for( Event<?> follower : followers ) {
|
---|
76 | String to = follower.getShortId();
|
---|
77 | MarkovEdge prob = new MarkovEdge(getProbability(context, follower));
|
---|
78 | graph.addEdge(prob, from, to, EdgeType.DIRECTED);
|
---|
79 | }
|
---|
80 | }
|
---|
81 | return graph;
|
---|
82 | }
|
---|
83 |
|
---|
84 | static public class MarkovEdge {
|
---|
85 | double weight;
|
---|
86 | MarkovEdge(double weight) { this.weight = weight; }
|
---|
87 | public String toString() { return ""+weight; }
|
---|
88 | }
|
---|
89 |
|
---|
90 | public double calcEntropy() {
|
---|
91 | Matrix transmissionMatrix = getTransmissionMatrix();
|
---|
92 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
93 | int numStates = knownSymbols.size();
|
---|
94 |
|
---|
95 | int startStateIndex = knownSymbols.indexOf(Event.STARTEVENT);
|
---|
96 | int endStateIndex = knownSymbols.indexOf(Event.ENDEVENT);
|
---|
97 | if( startStateIndex==-1 ) {
|
---|
98 | Console.printerrln("Error calculating entropy. Initial state of markov chain not found.");
|
---|
99 | return Double.NaN;
|
---|
100 | }
|
---|
101 | if( endStateIndex==-1 ) {
|
---|
102 | Console.printerrln("Error calculating entropy. End state of markov chain not found.");
|
---|
103 | return Double.NaN;
|
---|
104 | }
|
---|
105 | transmissionMatrix.set(endStateIndex, startStateIndex, 1);
|
---|
106 |
|
---|
107 | // Calculate stationary distribution by raising the power of the transmission matrix.
|
---|
108 | // The rank of the matrix should fall to 1 and each two should be the vector of the
|
---|
109 | // stationory distribution.
|
---|
110 | int iter = 0;
|
---|
111 | int rank = transmissionMatrix.rank();
|
---|
112 | Matrix stationaryMatrix = (Matrix) transmissionMatrix.clone();
|
---|
113 | while( iter<MAX_STATDIST_ITERATIONS && rank>1 ) {
|
---|
114 | stationaryMatrix = stationaryMatrix.times(stationaryMatrix);
|
---|
115 | rank = stationaryMatrix.rank();
|
---|
116 | iter++;
|
---|
117 | }
|
---|
118 |
|
---|
119 | if( rank!=1 ) {
|
---|
120 | Console.traceln("rank: " + rank);
|
---|
121 | Console.printerrln("Unable to calculate stationary distribution.");
|
---|
122 | return Double.NaN;
|
---|
123 | }
|
---|
124 |
|
---|
125 | double entropy = 0.0;
|
---|
126 | for( int i=0 ; i<numStates ; i++ ) {
|
---|
127 | for( int j=0 ; j<numStates ; j++ ) {
|
---|
128 | if( transmissionMatrix.get(i,j)!=0 ) {
|
---|
129 | double tmp = stationaryMatrix.get(i, 0);
|
---|
130 | tmp *= transmissionMatrix.get(i, j);
|
---|
131 | tmp *= Math.log(transmissionMatrix.get(i,j))/Math.log(2);
|
---|
132 | entropy -= tmp;
|
---|
133 | }
|
---|
134 | }
|
---|
135 | }
|
---|
136 | return entropy;
|
---|
137 | }
|
---|
138 |
|
---|
139 | }
|
---|