source: branches/ralph/src/main/java/de/ugoe/cs/autoquest/tasktrees/alignment/pal/tree/UPGMAAligningTree.java @ 1612

Last change on this file since 1612 was 1612, checked in by rkrimmel, 10 years ago

Removed parameters from alignmentalgorihm factory constructor and changed interface by adding a new align() method, which now gets all the data via parameter

File size: 9.3 KB
Line 
1// UPGMATree.java
2//
3// (c) 1999-2001 PAL Development Core Team
4//
5// This package may be distributed under the
6// terms of the Lesser GNU General Public License (LGPL)
7
8// Known bugs and limitations:
9// - computational complexity O(numSeqs^3)
10//   (this could be brought down to O(numSeqs^2)
11//   but this needs more clever programming ...)
12
13
14package de.ugoe.cs.autoquest.tasktrees.alignment.pal.tree;
15
16import java.util.ArrayList;
17import java.util.Iterator;
18import java.util.logging.Level;
19
20import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.AlignmentAlgorithm;
21import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.AlignmentAlgorithmFactory;
22import de.ugoe.cs.autoquest.tasktrees.alignment.algorithms.NumberSequence;
23import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.PairwiseAlignmentStorage;
24import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.ObjectDistanceSubstitionMatrix;
25import de.ugoe.cs.autoquest.tasktrees.alignment.matrix.UPGMAMatrix;
26import de.ugoe.cs.autoquest.tasktrees.alignment.pal.misc.Identifier;
27import de.ugoe.cs.util.console.Console;
28
29
30/**
31 * constructs a UPGMA tree from pairwise distances
32 *
33 * @version $Id: UPGMATree.java,v 1.9 2001/07/13 14:39:13 korbinian Exp $
34 *
35 * @author Korbinian Strimmer
36 * @author Alexei Drummond
37 */
38public class UPGMAAligningTree extends SimpleTree
39{
40        //
41        // Public stuff
42        //     
43
44        /**
45         * constructor UPGMA tree
46         * @param numberseqs
47         *
48         * @param m distance matrix
49         */
50        public UPGMAAligningTree(ArrayList<NumberSequence> numberseqs, PairwiseAlignmentStorage alignments, ObjectDistanceSubstitionMatrix submat)
51        {
52                if (alignments.getDistanceMatrix().size() < 2)
53                {
54                        new IllegalArgumentException("LESS THAN 2 TAXA IN DISTANCE MATRIX");
55                }
56       
57                this.numberseqs = numberseqs;
58                this.alignments = alignments;
59                this.submat = submat;
60                init(alignments.getDistanceMatrix());
61
62                while (true)
63                {
64                        findNextPair();
65                        newBranchLengths();
66                       
67                        if (numClusters == 2)
68                        {
69                                break;
70                        }
71                       
72                        newCluster();
73                }
74               
75                finish();
76                createNodeList();
77        }
78
79
80        //
81        // Private stuff
82        //
83        private ArrayList<NumberSequence> numberseqs;
84        private PairwiseAlignmentStorage alignments;
85        private ObjectDistanceSubstitionMatrix submat;
86        private int numClusters;
87        private int besti, abi;
88        private int bestj, abj;
89        private int[] alias;
90        private double[][] distance;
91
92        private double[] height;
93        private int[] oc;
94
95        private double getDist(int a, int b)
96        {
97                return distance[alias[a]][alias[b]];
98        }
99       
100        private void init(UPGMAMatrix m)
101        {
102                numClusters = m.size();
103
104                distance = new double[numClusters][numClusters];
105                for (int i = 0; i < numClusters; i++)
106                {
107                        for (int j = 0; j < numClusters; j++)
108                        {
109                                distance[i][j] = m.get(i,j);
110                        }
111                }
112
113                for (int i = 0; i < numClusters; i++)
114                {
115                        Node tmp = NodeFactory.createNode();
116                        tmp.setIdentifier(new Identifier(Integer.toString(i)));
117                        tmp.setNumber(i);
118                        tmp.addSequence(numberseqs.get(i));
119                        getRoot().addChild(tmp);
120                }
121               
122                alias = new int[numClusters];
123                for (int i = 0; i < numClusters; i++)
124                {
125                        alias[i] = i;
126                }
127                               
128                height = new double[numClusters];
129                oc = new int[numClusters];
130                for (int i = 0; i < numClusters; i++)
131                {
132                        height[i] = 0.0;
133                        oc[i] = 1;
134                }
135        }
136
137        private void finish()
138        {
139                this.getRoot().setSequences(alignSequences(this.getRoot()));
140                distance = null;               
141        }
142
143        private void findNextPair()
144        {
145                besti = 0;
146                bestj = 1;
147                double dmin = getDist(0, 1);
148                for (int i = 0; i < numClusters-1; i++)
149                {
150                        for (int j = i+1; j < numClusters; j++)
151                        {
152                                if (getDist(i, j) < dmin)
153                                {
154                                        dmin = getDist(i, j);
155                                        besti = i;
156                                        bestj = j;
157                                }
158                        }
159                }
160                abi = alias[besti];
161                abj = alias[bestj];
162                //System.out.println("Found best pair: " + abi + "/" +abj + " - "+ besti+ "/"+bestj +" with distance " + dmin);
163               
164        }
165
166        private void newBranchLengths()
167        {
168                double dij = getDist(besti, bestj);
169               
170                getRoot().getChild(besti).setBranchLength(dij/2.0-height[abi]);
171                getRoot().getChild(bestj).setBranchLength(dij/2.0-height[abj]);
172        }
173
174        private void newCluster()
175        {
176                // Update distances
177                for (int k = 0; k < numClusters; k++)
178                {
179                        if (k != besti && k != bestj)
180                        {
181                                int ak = alias[k];
182                                double updated = updatedDistance(besti,bestj,k);
183                                distance[ak][abi] = distance[abi][ak] = updated;
184                        }
185                }
186                distance[abi][abi] = 0.0;
187
188                // Update UPGMA variables
189                height[abi] = getDist(besti, bestj)/2.0;
190                oc[abi] += oc[abj];
191               
192                // Index besti now represent the new cluster
193                Node newNode = getRoot().joinChildren(besti, bestj);
194               
195                if(newNode instanceof FengDoolittleNode) {
196                        newNode.setSequences(alignSequences(newNode));
197                }
198               
199                // Update alias
200                for (int i = bestj; i < numClusters-1; i++)
201                {
202                        alias[i] = alias[i+1];
203                }
204               
205                numClusters--;
206        }
207       
208       
209        public ArrayList<NumberSequence> alignSequences(Node parent) {
210                ArrayList<NumberSequence> alignment = new ArrayList<NumberSequence>();
211                if(parent.getChildCount()<3) {
212                       
213                        Node node1 = parent.getChild(0);
214                        Node node2 = parent.getChild(1);
215                       
216                        int seqCount1 = node1.getSequences().size();
217                        int seqCount2 = node2.getSequences().size();
218
219                       
220                        Console.traceln(Level.INFO,"Merging node " + node1.getIdentifier() + " with " + node2.getIdentifier());
221                        //Console.println("SeqCount1: " + seqCount1 + " seqCount2 " + seqCount2);
222                        //Align 2 sequences
223                        if(seqCount1 == 1 && seqCount2 == 1) {
224                                AlignmentAlgorithm aa = AlignmentAlgorithmFactory.create();
225                                aa.align(node1.getSequence(0).getSequence(), node2.getSequence(0).getSequence(), submat, 5);
226                                alignment = aa.getAlignment();
227                               
228                        }
229                        //Align a sequence to a group
230                        else if( seqCount1 > 1 && seqCount2 == 1) {
231                                alignment.addAll(node1.getSequences());
232                               
233                                PairwiseAlignmentStorage tempStorage = new PairwiseAlignmentStorage(seqCount1,seqCount2);
234                                double maxScore = 0.0;
235                                int maxIndex = 0;
236                                for(int i=0;i<seqCount1;i++){
237                                        AlignmentAlgorithm aa = AlignmentAlgorithmFactory.create();
238                                        aa.align(node1.getSequence(i).getSequence(), node2.getSequence(0).getSequence() , submat, 5);
239                                        tempStorage.set(i, 1, aa);
240                                       
241                                        if(maxScore < tempStorage.get(i, 1).getAlignmentScore()) {
242                                                maxScore = tempStorage.get(i, 1).getAlignmentScore();
243                                                maxIndex = i;
244                                        }
245                                }
246                                //if(maxScore > 0)
247                               
248                                alignment.add(tempStorage.get(maxIndex, 1).getAlignment().get(1));
249                        }
250                        //Align a sequence to a group
251                        else if(seqCount1 == 1 && seqCount2 > 1) {
252                                alignment.addAll(node2.getSequences());
253                                PairwiseAlignmentStorage tempStorage = new PairwiseAlignmentStorage(seqCount1,seqCount2);
254                                double maxScore = 0.0;
255                                int maxIndex = 0;
256                                for(int i=0;i<seqCount2;i++) {
257                                        AlignmentAlgorithm aa = AlignmentAlgorithmFactory.create();
258                                        aa.align(node2.getSequence(i).getSequence(), node1.getSequence(0).getSequence() , submat, 5);
259                                        tempStorage.set(1, i, aa);
260                                        if(maxScore < tempStorage.get(1, i).getAlignmentScore()) {
261                                                maxScore = tempStorage.get(1, i).getAlignmentScore();
262                                                maxIndex = i;
263                                        }
264                                }
265                                //if(maxScore > 0)
266                               
267                                alignment.add(tempStorage.get(1,maxIndex).getAlignment().get(1));
268                        }
269                        //Align 2 groups
270                        else if((seqCount1 > 1) && (seqCount2 > 1)){
271                                        PairwiseAlignmentStorage tempStorage1 = new PairwiseAlignmentStorage(seqCount2,1);
272                                        PairwiseAlignmentStorage tempStorage2 = new PairwiseAlignmentStorage(seqCount1,1);
273                                        double maxScore1 = 0.0;
274                                        double maxScore2 = 0.0;
275                                        int maxIndex1 = 0;
276                                        int maxIndex2 = 0;
277                                        for(int i=0;i<seqCount1;i++) {
278                                                for(int j=0;j<seqCount2;j++) {
279                                                        AlignmentAlgorithm aa =AlignmentAlgorithmFactory.create();
280                                                        aa.align(node1.getSequence(i).getSequence(), node2.getSequence(j).getSequence() , submat, 5);
281                                                        tempStorage1.set(j, 0, aa);
282                                                        if(maxScore1 < tempStorage1.get(j, 0).getAlignmentScore()) {
283                                                                maxScore1 = tempStorage1.get(j, 0).getAlignmentScore();
284                                                                maxIndex1 = j;
285                                                        }
286                                                }
287                                                //if(maxScore1 > 0)
288                                                alignment.add(tempStorage1.get(maxIndex1,0).getAlignment().get(0));
289                                        }
290                                        for(int i=0; i<seqCount2;i++) {
291                                                for (int j=0;j<seqCount1;j++) {
292                                                        AlignmentAlgorithm aa =AlignmentAlgorithmFactory.create();
293                                                        aa.align(node2.getSequence(i).getSequence(),node1.getSequence(j).getSequence(),submat,5);
294                                                        tempStorage2.set(j, 0, aa);
295                                                        if(maxScore2 < tempStorage2.get(j, 0).getAlignmentScore()) {
296                                                                maxScore2 = tempStorage2.get(j, 0).getAlignmentScore();
297                                                                maxIndex2 = j;
298                                                        }
299                                                }
300                                                //if(maxScore2 > 0)
301                                                alignment.add(tempStorage2.get(maxIndex2,0).getAlignment().get(0));
302                                        }
303                                       
304                        }
305                        else {
306                                Console.traceln(Level.WARNING,"No sequences to align while merging " + node1.getIdentifier() + " with " + node2.getIdentifier());
307                        }
308                }
309                else {
310                        Console.traceln(Level.WARNING,"More than 2 children! This should never happen, it's a binary tree.");
311                }
312                return alignment;
313        }
314
315       
316        /**
317         * compute updated distance between the new cluster (i,j)
318         * to any other cluster k
319         */
320        private double updatedDistance(int i, int j, int k)
321        {
322                int ai = alias[i];
323                int aj = alias[j];
324               
325                double ocsum = (double) (oc[ai]+oc[aj]);
326                double idist = getDist(k,i);
327                double jdist = getDist(k,j);
328                //TODO: Dirty hack to deal with infinity, insert proper solution here
329                if(Double.isInfinite(idist)) {
330                        idist = 100;
331                }
332                if(Double.isInfinite(jdist)) {
333                        jdist = 100;
334                }
335               
336                return  (oc[ai]/ocsum)*idist+
337                        (oc[aj]/ocsum)*jdist;
338        }
339
340
341        public void printMultipleAlignment() {
342                for (Iterator<NumberSequence> it =  getRoot().getSequences().iterator(); it.hasNext();) {
343                        NumberSequence tmp  = it.next();
344                        tmp.printSequence();
345                }
346        }
347}
348
349
Note: See TracBrowser for help on using the repository browser.