source: branches/ralph/src/main/java/de/ugoe/cs/autoquest/tasktrees/alignment/pal/tree/UPGMAAligningTree.java @ 1589

Last change on this file since 1589 was 1589, checked in by rkrimmel, 10 years ago

Refactoring and code cleanup

File size: 8.9 KB
Line 
1// UPGMATree.java
2//
3// (c) 1999-2001 PAL Development Core Team
4//
5// This package may be distributed under the
6// terms of the Lesser GNU General Public License (LGPL)
7
8// Known bugs and limitations:
9// - computational complexity O(numSeqs^3)
10//   (this could be brought down to O(numSeqs^2)
11//   but this needs more clever programming ...)
12
13
14package de.ugoe.cs.autoquest.tasktrees.alignment.pal.tree;
15
16import java.util.ArrayList;
17import java.util.Iterator;
18import java.util.logging.Level;
19
20import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.AlignmentAlgorithmFactory;
21import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.NumberSequence;
22import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.PairwiseAlignmentStorage;
23import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.ObjectDistanceSubstitionMatrix;
24import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.UPGMAMatrix;
25import de.ugoe.cs.autoquest.tasktrees.alignment.pal.misc.Identifier;
26import de.ugoe.cs.util.console.Console;
27
28
29/**
30 * constructs a UPGMA tree from pairwise distances
31 *
32 * @version $Id: UPGMATree.java,v 1.9 2001/07/13 14:39:13 korbinian Exp $
33 *
34 * @author Korbinian Strimmer
35 * @author Alexei Drummond
36 */
37public class UPGMAAligningTree extends SimpleTree
38{
39        //
40        // Public stuff
41        //     
42
43        /**
44         * constructor UPGMA tree
45         * @param numberseqs
46         *
47         * @param m distance matrix
48         */
49        public UPGMAAligningTree(ArrayList<NumberSequence> numberseqs, PairwiseAlignmentStorage alignments, ObjectDistanceSubstitionMatrix submat)
50        {
51                if (alignments.getDistanceMatrix().size() < 2)
52                {
53                        new IllegalArgumentException("LESS THAN 2 TAXA IN DISTANCE MATRIX");
54                }
55       
56                this.numberseqs = numberseqs;
57                this.alignments = alignments;
58                this.submat = submat;
59                init(alignments.getDistanceMatrix());
60
61                while (true)
62                {
63                        findNextPair();
64                        newBranchLengths();
65                       
66                        if (numClusters == 2)
67                        {
68                                break;
69                        }
70                       
71                        newCluster();
72                }
73               
74                finish();
75                createNodeList();
76        }
77
78
79        //
80        // Private stuff
81        //
82        private ArrayList<NumberSequence> numberseqs;
83        private PairwiseAlignmentStorage alignments;
84        private ObjectDistanceSubstitionMatrix submat;
85        private int numClusters;
86        private int besti, abi;
87        private int bestj, abj;
88        private int[] alias;
89        private double[][] distance;
90
91        private double[] height;
92        private int[] oc;
93
94        private double getDist(int a, int b)
95        {
96                return distance[alias[a]][alias[b]];
97        }
98       
99        private void init(UPGMAMatrix m)
100        {
101                numClusters = m.size();
102
103                distance = new double[numClusters][numClusters];
104                for (int i = 0; i < numClusters; i++)
105                {
106                        for (int j = 0; j < numClusters; j++)
107                        {
108                                distance[i][j] = m.get(i,j);
109                        }
110                }
111
112                for (int i = 0; i < numClusters; i++)
113                {
114                        Node tmp = NodeFactory.createNode();
115                        tmp.setIdentifier(new Identifier(Integer.toString(i)));
116                        tmp.setNumber(i);
117                        tmp.addSequence(numberseqs.get(i));
118                        getRoot().addChild(tmp);
119                }
120               
121                alias = new int[numClusters];
122                for (int i = 0; i < numClusters; i++)
123                {
124                        alias[i] = i;
125                }
126                               
127                height = new double[numClusters];
128                oc = new int[numClusters];
129                for (int i = 0; i < numClusters; i++)
130                {
131                        height[i] = 0.0;
132                        oc[i] = 1;
133                }
134        }
135
136        private void finish()
137        {
138                this.getRoot().setSequences(alignSequences(this.getRoot()));
139                distance = null;               
140        }
141
142        private void findNextPair()
143        {
144                besti = 0;
145                bestj = 1;
146                double dmin = getDist(0, 1);
147                for (int i = 0; i < numClusters-1; i++)
148                {
149                        for (int j = i+1; j < numClusters; j++)
150                        {
151                                if (getDist(i, j) < dmin)
152                                {
153                                        dmin = getDist(i, j);
154                                        besti = i;
155                                        bestj = j;
156                                }
157                        }
158                }
159                abi = alias[besti];
160                abj = alias[bestj];
161                //System.out.println("Found best pair: " + abi + "/" +abj + " - "+ besti+ "/"+bestj +" with distance " + dmin);
162               
163        }
164
165        private void newBranchLengths()
166        {
167                double dij = getDist(besti, bestj);
168               
169                getRoot().getChild(besti).setBranchLength(dij/2.0-height[abi]);
170                getRoot().getChild(bestj).setBranchLength(dij/2.0-height[abj]);
171        }
172
173        private void newCluster()
174        {
175                // Update distances
176                for (int k = 0; k < numClusters; k++)
177                {
178                        if (k != besti && k != bestj)
179                        {
180                                int ak = alias[k];
181                                double updated = updatedDistance(besti,bestj,k);
182                                distance[ak][abi] = distance[abi][ak] = updated;
183                        }
184                }
185                distance[abi][abi] = 0.0;
186
187                // Update UPGMA variables
188                height[abi] = getDist(besti, bestj)/2.0;
189                oc[abi] += oc[abj];
190               
191                // Index besti now represent the new cluster
192                Node newNode = getRoot().joinChildren(besti, bestj);
193               
194                if(newNode instanceof FengDoolittleNode) {
195                        newNode.setSequences(alignSequences(newNode));
196                }
197               
198                // Update alias
199                for (int i = bestj; i < numClusters-1; i++)
200                {
201                        alias[i] = alias[i+1];
202                }
203               
204                numClusters--;
205        }
206       
207       
208        public ArrayList<NumberSequence> alignSequences(Node parent) {
209                ArrayList<NumberSequence> alignment = new ArrayList<NumberSequence>();
210                if(parent.getChildCount()<3) {
211                       
212                        Node node1 = parent.getChild(0);
213                        Node node2 = parent.getChild(1);
214                       
215                        int seqCount1 = node1.getSequences().size();
216                        int seqCount2 = node2.getSequences().size();
217
218                       
219                        Console.traceln(Level.INFO,"Merging node " + node1.getIdentifier() + " with " + node2.getIdentifier());
220                        //Console.println("SeqCount1: " + seqCount1 + " seqCount2 " + seqCount2);
221                        //Align 2 sequences
222                        if(seqCount1 == 1 && seqCount2 == 1) {
223                                alignment = (alignments.get(node1.getNumber(), node2.getNumber())).getAlignment();
224                               
225                        }
226                        //Align a sequence to a group
227                        else if( seqCount1 > 1 && seqCount2 == 1) {
228                                alignment.addAll(node1.getSequences());
229                               
230                                PairwiseAlignmentStorage tempStorage = new PairwiseAlignmentStorage(seqCount1,seqCount2);
231                                double maxScore = 0.0;
232                                int maxIndex = 0;
233                                for(int i=0;i<seqCount1;i++) {
234                                        tempStorage.set(i, 1, AlignmentAlgorithmFactory.create(node1.getSequence(i).getSequence(), node2.getSequence(0).getSequence() , submat, 5));
235                                        if(maxScore < tempStorage.get(i, 1).getAlignmentScore()) {
236                                                maxScore = tempStorage.get(i, 1).getAlignmentScore();
237                                                maxIndex = i;
238                                        }
239                                }
240                                //if(maxScore > 0)
241                               
242                                alignment.add(tempStorage.get(maxIndex, 1).getAlignment().get(1));
243                        }
244                        //Align a sequence to a group
245                        else if(seqCount1 == 1 && seqCount2 > 1) {
246                                alignment.addAll(node2.getSequences());
247                                PairwiseAlignmentStorage tempStorage = new PairwiseAlignmentStorage(seqCount1,seqCount2);
248                                double maxScore = 0.0;
249                                int maxIndex = 0;
250                                for(int i=0;i<seqCount2;i++) {
251                                        tempStorage.set(1, i, AlignmentAlgorithmFactory.create(node2.getSequence(i).getSequence(), node1.getSequence(0).getSequence() , submat, 5));
252                                        if(maxScore < tempStorage.get(1, i).getAlignmentScore()) {
253                                                maxScore = tempStorage.get(1, i).getAlignmentScore();
254                                                maxIndex = i;
255                                        }
256                                }
257                                //if(maxScore > 0)
258                               
259                                alignment.add(tempStorage.get(1,maxIndex).getAlignment().get(1));
260                        }
261                        //Align 2 groups
262                        else if((seqCount1 > 1) && (seqCount2 > 1)){
263                                        PairwiseAlignmentStorage tempStorage1 = new PairwiseAlignmentStorage(seqCount2,1);
264                                        PairwiseAlignmentStorage tempStorage2 = new PairwiseAlignmentStorage(seqCount1,1);
265                                        double maxScore1 = 0.0;
266                                        double maxScore2 = 0.0;
267                                        int maxIndex1 = 0;
268                                        int maxIndex2 = 0;
269                                        for(int i=0;i<seqCount1;i++) {
270                                                for(int j=0;j<seqCount2;j++) {
271                                                        tempStorage1.set(j, 0, AlignmentAlgorithmFactory.create(node1.getSequence(i).getSequence(), node2.getSequence(j).getSequence() , submat, 5));
272                                                        if(maxScore1 < tempStorage1.get(j, 0).getAlignmentScore()) {
273                                                                maxScore1 = tempStorage1.get(j, 0).getAlignmentScore();
274                                                                maxIndex1 = j;
275                                                        }
276                                                }
277                                                //if(maxScore1 > 0)
278                                                alignment.add(tempStorage1.get(maxIndex1,0).getAlignment().get(0));
279                                        }
280                                        for(int i=0; i<seqCount2;i++) {
281                                                for (int j=0;j<seqCount1;j++) {
282                                                        tempStorage2.set(j, 0, AlignmentAlgorithmFactory.create(node2.getSequence(i).getSequence(),node1.getSequence(j).getSequence(),submat,5));
283                                                        if(maxScore2 < tempStorage2.get(j, 0).getAlignmentScore()) {
284                                                                maxScore2 = tempStorage2.get(j, 0).getAlignmentScore();
285                                                                maxIndex2 = j;
286                                                        }
287                                                }
288                                                //if(maxScore2 > 0)
289                                                alignment.add(tempStorage2.get(maxIndex2,0).getAlignment().get(0));
290                                        }
291                                       
292                        }
293                        else {
294                                Console.traceln(Level.WARNING,"No sequences to align while merging " + node1.getIdentifier() + " with " + node2.getIdentifier());
295                        }
296                }
297                else {
298                        Console.traceln(Level.WARNING,"More than 2 children! This should never happen, it's a binary tree.");
299                }
300                return alignment;
301        }
302
303       
304        /**
305         * compute updated distance between the new cluster (i,j)
306         * to any other cluster k
307         */
308        private double updatedDistance(int i, int j, int k)
309        {
310                int ai = alias[i];
311                int aj = alias[j];
312               
313                double ocsum = (double) (oc[ai]+oc[aj]);
314                double idist = getDist(k,i);
315                double jdist = getDist(k,j);
316                //TODO: Dirty hack to deal with infinity, insert proper solution here
317                if(Double.isInfinite(idist)) {
318                        idist = 100;
319                }
320                if(Double.isInfinite(jdist)) {
321                        jdist = 100;
322                }
323               
324                return  (oc[ai]/ocsum)*idist+
325                        (oc[aj]/ocsum)*jdist;
326        }
327
328
329        public void printMultipleAlignment() {
330                for (Iterator<NumberSequence> it =  getRoot().getSequences().iterator(); it.hasNext();) {
331                        NumberSequence tmp  = it.next();
332                        tmp.printSequence();
333                }
334        }
335}
336
337
Note: See TracBrowser for help on using the repository browser.