source: branches/ralph/src/main/java/de/ugoe/cs/autoquest/tasktrees/alignment/pal/tree/UPGMAAligningTree.java @ 1630

Last change on this file since 1630 was 1630, checked in by rkrimmel, 10 years ago

Generated Model of matches.

File size: 9.1 KB
Line 
1// UPGMATree.java
2//
3// (c) 1999-2001 PAL Development Core Team
4//
5// This package may be distributed under the
6// terms of the Lesser GNU General Public License (LGPL)
7
8// Known bugs and limitations:
9// - computational complexity O(numSeqs^3)
10//   (this could be brought down to O(numSeqs^2)
11//   but this needs more clever programming ...)
12
13
14package de.ugoe.cs.autoquest.tasktrees.alignment.pal.tree;
15
16import java.util.ArrayList;
17import java.util.Iterator;
18import java.util.logging.Level;
19
20import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.AlignmentAlgorithm;
21import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.AlignmentAlgorithmFactory;
22import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.NumberSequence;
23import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.PairwiseAlignmentStorage;
24import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.ObjectDistanceSubstitionMatrix;
25import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.UPGMAMatrix;
26import de.ugoe.cs.autoquest.tasktrees.alignment.pal.misc.Identifier;
27import de.ugoe.cs.util.console.Console;
28
29
30/**
31 * constructs a UPGMA tree from pairwise distances
32 *
33 * @version $Id: UPGMATree.java,v 1.9 2001/07/13 14:39:13 korbinian Exp $
34 *
35 * @author Korbinian Strimmer
36 * @author Alexei Drummond
37 */
38public class UPGMAAligningTree extends SimpleTree
39{
40        //
41        // Public stuff
42        //     
43
44        /**
45         * constructor UPGMA tree
46         * @param numberseqs
47         *
48         * @param m distance matrix
49         */
50        public UPGMAAligningTree(ArrayList<NumberSequence> numberseqs, PairwiseAlignmentStorage alignments, ObjectDistanceSubstitionMatrix submat)
51        {
52                if (alignments.getDistanceMatrix().size() < 2)
53                {
54                        new IllegalArgumentException("LESS THAN 2 TAXA IN DISTANCE MATRIX");
55                }
56       
57                this.numberseqs = numberseqs;
58                this.submat = submat;
59                init(alignments.getDistanceMatrix());
60
61                while (true)
62                {
63                        findNextPair();
64                        newBranchLengths();
65                       
66                        if (numClusters == 2)
67                        {
68                                break;
69                        }
70                       
71                        newCluster();
72                }
73               
74                finish();
75                createNodeList();
76        }
77
78
79        //
80        // Private stuff
81        //
82        private ArrayList<NumberSequence> numberseqs;
83        private ObjectDistanceSubstitionMatrix submat;
84        private int numClusters;
85        private int besti, abi;
86        private int bestj, abj;
87        private int[] alias;
88        private double[][] distance;
89
90        private double[] height;
91        private int[] oc;
92
93        private double getDist(int a, int b)
94        {
95                return distance[alias[a]][alias[b]];
96        }
97       
98        private void init(UPGMAMatrix m)
99        {
100                numClusters = m.size();
101
102                distance = new double[numClusters][numClusters];
103                for (int i = 0; i < numClusters; i++)
104                {
105                        for (int j = 0; j < numClusters; j++)
106                        {
107                                distance[i][j] = m.get(i,j);
108                        }
109                }
110
111                for (int i = 0; i < numClusters; i++)
112                {
113                        Node tmp = NodeFactory.createNode();
114                        tmp.setIdentifier(new Identifier(Integer.toString(i)));
115                        tmp.setNumber(i);
116                        tmp.addSequence(numberseqs.get(i));
117                        getRoot().addChild(tmp);
118                }
119               
120                alias = new int[numClusters];
121                for (int i = 0; i < numClusters; i++)
122                {
123                        alias[i] = i;
124                }
125                               
126                height = new double[numClusters];
127                oc = new int[numClusters];
128                for (int i = 0; i < numClusters; i++)
129                {
130                        height[i] = 0.0;
131                        oc[i] = 1;
132                }
133        }
134
135        private void finish()
136        {
137                this.getRoot().setSequences(alignSequences(this.getRoot()));
138                distance = null;               
139        }
140
141        private void findNextPair()
142        {
143                besti = 0;
144                bestj = 1;
145                double dmin = getDist(0, 1);
146                for (int i = 0; i < numClusters-1; i++)
147                {
148                        for (int j = i+1; j < numClusters; j++)
149                        {
150                                if (getDist(i, j) < dmin)
151                                {
152                                        dmin = getDist(i, j);
153                                        besti = i;
154                                        bestj = j;
155                                }
156                        }
157                }
158                abi = alias[besti];
159                abj = alias[bestj];
160                //System.out.println("Found best pair: " + abi + "/" +abj + " - "+ besti+ "/"+bestj +" with distance " + dmin);
161               
162        }
163
164        private void newBranchLengths()
165        {
166                double dij = getDist(besti, bestj);
167               
168                getRoot().getChild(besti).setBranchLength(dij/2.0-height[abi]);
169                getRoot().getChild(bestj).setBranchLength(dij/2.0-height[abj]);
170        }
171
172        private void newCluster()
173        {
174                // Update distances
175                for (int k = 0; k < numClusters; k++)
176                {
177                        if (k != besti && k != bestj)
178                        {
179                                int ak = alias[k];
180                                double updated = updatedDistance(besti,bestj,k);
181                                distance[ak][abi] = distance[abi][ak] = updated;
182                        }
183                }
184                distance[abi][abi] = 0.0;
185
186                // Update UPGMA variables
187                height[abi] = getDist(besti, bestj)/2.0;
188                oc[abi] += oc[abj];
189               
190                // Index besti now represent the new cluster
191                Node newNode = getRoot().joinChildren(besti, bestj);
192               
193                if(newNode instanceof FengDoolittleNode) {
194                        newNode.setSequences(alignSequences(newNode));
195                }
196               
197                // Update alias
198                for (int i = bestj; i < numClusters-1; i++)
199                {
200                        alias[i] = alias[i+1];
201                }
202               
203                numClusters--;
204        }
205       
206       
207        public ArrayList<NumberSequence> alignSequences(Node parent) {
208                ArrayList<NumberSequence> alignment = new ArrayList<NumberSequence>();
209                if(parent.getChildCount()<3) {
210                       
211                        Node node1 = parent.getChild(0);
212                        Node node2 = parent.getChild(1);
213                       
214                        int seqCount1 = node1.getSequences().size();
215                        int seqCount2 = node2.getSequences().size();
216
217                       
218                        Console.traceln(Level.INFO,"Merging node " + node1.getIdentifier() + " with " + node2.getIdentifier());
219                        //Console.println("SeqCount1: " + seqCount1 + " seqCount2 " + seqCount2);
220                        //Align 2 sequences
221                        if(seqCount1 == 1 && seqCount2 == 1) {
222                                AlignmentAlgorithm aa = AlignmentAlgorithmFactory.create();
223                                aa.align(node1.getSequence(0), node2.getSequence(0), submat, 5);
224                                alignment = aa.getAlignment();
225                               
226                        }
227                        //Align a sequence to a group
228                        else if( seqCount1 > 1 && seqCount2 == 1) {
229                                alignment.addAll(node1.getSequences());
230                               
231                                PairwiseAlignmentStorage tempStorage = new PairwiseAlignmentStorage(seqCount1,seqCount2);
232                                double maxScore = 0.0;
233                                int maxIndex = 0;
234                                for(int i=0;i<seqCount1;i++){
235                                        AlignmentAlgorithm aa = AlignmentAlgorithmFactory.create();
236                                        aa.align(node1.getSequence(i), node2.getSequence(0) , submat, 5);
237                                        tempStorage.set(i, 1, aa);
238                                       
239                                        if(maxScore < tempStorage.get(i, 1).getAlignmentScore()) {
240                                                maxScore = tempStorage.get(i, 1).getAlignmentScore();
241                                                maxIndex = i;
242                                        }
243                                }
244                                //if(maxScore > 0)
245                               
246                                alignment.add(tempStorage.get(maxIndex, 1).getAlignment().get(1));
247                        }
248                        //Align a sequence to a group
249                        else if(seqCount1 == 1 && seqCount2 > 1) {
250                                alignment.addAll(node2.getSequences());
251                                PairwiseAlignmentStorage tempStorage = new PairwiseAlignmentStorage(seqCount1,seqCount2);
252                                double maxScore = 0.0;
253                                int maxIndex = 0;
254                                for(int i=0;i<seqCount2;i++) {
255                                        AlignmentAlgorithm aa = AlignmentAlgorithmFactory.create();
256                                        aa.align(node2.getSequence(i), node1.getSequence(0) , submat, 5);
257                                        tempStorage.set(1, i, aa);
258                                        if(maxScore < tempStorage.get(1, i).getAlignmentScore()) {
259                                                maxScore = tempStorage.get(1, i).getAlignmentScore();
260                                                maxIndex = i;
261                                        }
262                                }
263                                //if(maxScore > 0)
264                               
265                                alignment.add(tempStorage.get(1,maxIndex).getAlignment().get(1));
266                        }
267                        //Align 2 groups
268                        else if((seqCount1 > 1) && (seqCount2 > 1)){
269                                        PairwiseAlignmentStorage tempStorage1 = new PairwiseAlignmentStorage(seqCount2,1);
270                                        PairwiseAlignmentStorage tempStorage2 = new PairwiseAlignmentStorage(seqCount1,1);
271                                        double maxScore1 = 0.0;
272                                        double maxScore2 = 0.0;
273                                        int maxIndex1 = 0;
274                                        int maxIndex2 = 0;
275                                        for(int i=0;i<seqCount1;i++) {
276                                                for(int j=0;j<seqCount2;j++) {
277                                                        AlignmentAlgorithm aa =AlignmentAlgorithmFactory.create();
278                                                        aa.align(node1.getSequence(i), node2.getSequence(j), submat, 5);
279                                                        tempStorage1.set(j, 0, aa);
280                                                        if(maxScore1 < tempStorage1.get(j, 0).getAlignmentScore()) {
281                                                                maxScore1 = tempStorage1.get(j, 0).getAlignmentScore();
282                                                                maxIndex1 = j;
283                                                        }
284                                                }
285                                                //if(maxScore1 > 0)
286                                                alignment.add(tempStorage1.get(maxIndex1,0).getAlignment().get(0));
287                                        }
288                                        for(int i=0; i<seqCount2;i++) {
289                                                for (int j=0;j<seqCount1;j++) {
290                                                        AlignmentAlgorithm aa =AlignmentAlgorithmFactory.create();
291                                                        aa.align(node2.getSequence(i),node1.getSequence(j),submat,5);
292                                                        tempStorage2.set(j, 0, aa);
293                                                        if(maxScore2 < tempStorage2.get(j, 0).getAlignmentScore()) {
294                                                                maxScore2 = tempStorage2.get(j, 0).getAlignmentScore();
295                                                                maxIndex2 = j;
296                                                        }
297                                                }
298                                                //if(maxScore2 > 0)
299                                                alignment.add(tempStorage2.get(maxIndex2,0).getAlignment().get(0));
300                                        }
301                                       
302                        }
303                        else {
304                                Console.traceln(Level.WARNING,"No sequences to align while merging " + node1.getIdentifier() + " with " + node2.getIdentifier());
305                        }
306                }
307                else {
308                        Console.traceln(Level.WARNING,"More than 2 children! This should never happen, it's a binary tree.");
309                }
310                return alignment;
311        }
312
313       
314        /**
315         * compute updated distance between the new cluster (i,j)
316         * to any other cluster k
317         */
318        private double updatedDistance(int i, int j, int k)
319        {
320                int ai = alias[i];
321                int aj = alias[j];
322               
323                double ocsum = (double) (oc[ai]+oc[aj]);
324                double idist = getDist(k,i);
325                double jdist = getDist(k,j);
326                //TODO: Dirty hack to deal with infinity, insert proper solution here
327                if(Double.isInfinite(idist)) {
328                        idist = 100;
329                }
330                if(Double.isInfinite(jdist)) {
331                        jdist = 100;
332                }
333               
334                return  (oc[ai]/ocsum)*idist+
335                        (oc[aj]/ocsum)*jdist;
336        }
337
338
339        public void printMultipleAlignment() {
340                for (Iterator<NumberSequence> it =  getRoot().getSequences().iterator(); it.hasNext();) {
341                        NumberSequence tmp  = it.next();
342                        tmp.printSequence();
343                }
344        }
345}
346
347
Note: See TracBrowser for help on using the repository browser.