1 package org.broadinstitute.hellbender.tools.sv;
2 
3 import htsjdk.samtools.SAMSequenceDictionary;
4 import htsjdk.samtools.util.Locatable;
5 import org.broadinstitute.hellbender.utils.*;
6 import scala.Tuple2;
7 
8 import java.util.*;
9 import java.util.stream.Collectors;
10 import java.util.stream.IntStream;
11 
/**
 * Base class for engines that cluster items with genomic coordinates ({@link Locatable}).
 * Items are added one contig at a time; clusters are finalized when a new item starts past
 * their clustering interval, and results are retrieved with {@code getOutput()}.
 * Supports single-linkage and max-clique clustering.
 *
 * @param <T> type of the items being clustered
 */
public abstract class LocatableClusterEngine<T extends Locatable> {

    // Maps each coverage interval to its index in coverageIntervals; null when no intervals were supplied
    protected final TreeMap<GenomeLoc, Integer> genomicToBinMap;
    // Coverage intervals backing genomicToBinMap; null when not supplied
    protected final List<GenomeLoc> coverageIntervals;
    final GenomeLocParser parser;

    /** Supported clustering algorithms. */
    public enum CLUSTERING_TYPE {
        SINGLE_LINKAGE,
        MAX_CLIQUE
    }

    protected final SAMSequenceDictionary dictionary;
    private final List<Tuple2<SimpleInterval, List<Long>>> currentClusters; // Pairs of cluster start interval with item IDs
    private final Map<Long,T> idToItemMap; // item id -> item, for items still referenced by an active cluster
    private final List<T> outputBuffer; // flattened clusters awaiting getOutput()
    private final CLUSTERING_TYPE clusteringType;
    private long currentItemId; // id that will be assigned to the next added item
    private String currentContig; // contig of the items currently being clustered; null when engine is empty

LocatableClusterEngine(final SAMSequenceDictionary dictionary, final CLUSTERING_TYPE clusteringType, final List<GenomeLoc> coverageIntervals)32     public LocatableClusterEngine(final SAMSequenceDictionary dictionary, final CLUSTERING_TYPE clusteringType, final List<GenomeLoc> coverageIntervals) {
33         this.dictionary = dictionary;
34         this.clusteringType = clusteringType;
35         this.currentClusters = new LinkedList<>();
36         this.idToItemMap = new HashMap<>();
37         this.outputBuffer = new ArrayList<>();
38         currentItemId = 0;
39         currentContig = null;
40 
41         parser = new GenomeLocParser(this.dictionary);
42         if (coverageIntervals != null) {
43             this.coverageIntervals = coverageIntervals;
44             genomicToBinMap = new TreeMap<>();
45             for (int i = 0; i < coverageIntervals.size(); i++) {
46                 genomicToBinMap.put(coverageIntervals.get(i),i);
47             }
48         } else {
49             genomicToBinMap = null;
50             this.coverageIntervals = null;
51         }
52     }
53 
    /** Returns true if items {@code a} and {@code b} may belong to the same cluster. */
    abstract protected boolean clusterTogether(final T a, final T b);
    /**
     * Returns the interval in which future items must start in order to cluster with {@code item},
     * possibly extending {@code currentClusterInterval} (which may be null for a new cluster).
     */
    abstract protected SimpleInterval getClusteringInterval(final T item, final SimpleInterval currentClusterInterval);
    /** Collapses a collection of mutually identical items into a single representative item. */
    abstract protected T deduplicateIdenticalItems(final Collection<T> items);
    /** Returns true if {@code a} and {@code b} are duplicates of one another. */
    abstract protected boolean itemsAreIdentical(final T a, final T b);
    /** Collapses a completed cluster of items into a single output item. */
    abstract protected T flattenCluster(final Collection<T> cluster);
59 
getOutput()60     public List<T> getOutput() {
61         flushClusters();
62         final List<T> output;
63         if (clusteringType == CLUSTERING_TYPE.MAX_CLIQUE) {
64             output = deduplicateItems(outputBuffer);
65         } else {
66             output = new ArrayList<>(outputBuffer);
67         }
68         outputBuffer.clear();
69         return output;
70     }
71 
resetItemIds()72     private void resetItemIds() {
73         Utils.validate(currentClusters.isEmpty(), "Current cluster collection not empty");
74         currentItemId = 0;
75         idToItemMap.clear();
76     }
77 
isEmpty()78     public boolean isEmpty() {
79         return currentContig == null;
80     }
81 
add(final T item)82     public void add(final T item) {
83 
84         // Start a new cluster if on a new contig
85         if (!item.getContig().equals(currentContig)) {
86             flushClusters();
87             currentContig = item.getContig();
88             idToItemMap.put(currentItemId, item);
89             seedCluster(currentItemId);
90             currentItemId++;
91             return;
92         }
93 
94         // Keep track of a unique id for each item
95         idToItemMap.put(currentItemId, item);
96         final List<Integer> clusterIdsToProcess = cluster(item);
97         processFinalizedClusters(clusterIdsToProcess);
98         deleteRedundantClusters();
99         currentItemId++;
100     }
101 
    /** Returns the contig currently being clustered, or null if the engine is empty. */
    public String getCurrentContig() {
        return currentContig;
    }
105 
deduplicateItems(final List<T> items)106     public List<T> deduplicateItems(final List<T> items) {
107         final List<T> sortedItems = IntervalUtils.sortLocatablesBySequenceDictionary(items, dictionary);
108         final List<T> deduplicatedList = new ArrayList<>();
109         int i = 0;
110         while (i < sortedItems.size()) {
111             final T record = sortedItems.get(i);
112             int j = i + 1;
113             final Collection<Integer> identicalItemIndexes = new ArrayList<>();
114             while (j < sortedItems.size() && record.getStart() == sortedItems.get(j).getStart()) {
115                 final T other = sortedItems.get(j);
116                 if (itemsAreIdentical(record, other)) {
117                     identicalItemIndexes.add(j);
118                 }
119                 j++;
120             }
121             if (identicalItemIndexes.isEmpty()) {
122                 deduplicatedList.add(record);
123                 i++;
124             } else {
125                 identicalItemIndexes.add(i);
126                 final List<T> identicalItems = identicalItemIndexes.stream().map(sortedItems::get).collect(Collectors.toList());
127                 deduplicatedList.add(deduplicateIdenticalItems(identicalItems));
128                 i = j;
129             }
130         }
131         return deduplicatedList;
132     }
133 
134     /**
135      * Add a new {@param <T>} to the current clusters and determine which are complete
136      * @param item to be added
137      * @return the IDs for clusters that are complete and ready for processing
138      */
cluster(final T item)139     private List<Integer> cluster(final T item) {
140         // Get list of item IDs from active clusters that cluster with this item
141         final Set<Long> linkedItemIds = idToItemMap.entrySet().stream()
142                 .filter(other -> other.getKey().intValue() != currentItemId && clusterTogether(item, other.getValue()))
143                 .map(Map.Entry::getKey)
144                 .collect(Collectors.toCollection(LinkedHashSet::new));
145 
146         // Find clusters to which this item belongs, and which active clusters we're definitely done with
147         int clusterIndex = 0;
148         final List<Integer> clusterIdsToProcess = new ArrayList<>();
149         final List<Integer> clustersToAdd = new ArrayList<>();
150         final List<Integer> clustersToSeedWith = new ArrayList<>();
151         for (final Tuple2<SimpleInterval, List<Long>> cluster : currentClusters) {
152             final SimpleInterval clusterInterval = cluster._1;
153             final List<Long> clusterItemIds = cluster._2;
154             if (getClusteringInterval(item, null).getStart() > clusterInterval.getEnd()) {
155                 clusterIdsToProcess.add(clusterIndex);  //this cluster is complete -- process it when we're done
156             } else {
157                 if (clusteringType.equals(CLUSTERING_TYPE.MAX_CLIQUE)) {
158                     final int n = (int) clusterItemIds.stream().filter(linkedItemIds::contains).count();
159                     if (n == clusterItemIds.size()) {
160                         clustersToAdd.add(clusterIndex);
161                     } else if (n > 0) {
162                         clustersToSeedWith.add(clusterIndex);
163                     }
164                 } else if (clusteringType.equals(CLUSTERING_TYPE.SINGLE_LINKAGE)) {
165                     final boolean matchesCluster = clusterItemIds.stream().anyMatch(linkedItemIds::contains);
166                     if (matchesCluster) {
167                         clustersToAdd.add(clusterIndex);
168                     }
169                 } else {
170                     throw new IllegalArgumentException("Clustering algorithm for type " + clusteringType.name() + " not implemented");
171                 }
172             }
173             clusterIndex++;
174         }
175 
176         // Add to item clusters
177         for (final int index : clustersToAdd) {
178             addToCluster(index, currentItemId);
179         }
180         // Create new clusters/cliques
181         for (final int index : clustersToSeedWith) {
182             seedWithExistingCluster(currentItemId, index, linkedItemIds);
183         }
184         // If there weren't any matches, create a new singleton cluster
185         if (clustersToAdd.isEmpty() && clustersToSeedWith.isEmpty()) {
186             seedCluster(currentItemId);
187         }
188         return clusterIdsToProcess;
189     }
190 
processCluster(final int clusterIndex)191     private void processCluster(final int clusterIndex) {
192         final Tuple2<SimpleInterval, List<Long>> cluster = validateClusterIndex(clusterIndex);
193         final List<Long> clusterItemIds = cluster._2;
194         currentClusters.remove(clusterIndex);
195         final List<T> clusterItems = clusterItemIds.stream().map(idToItemMap::get).collect(Collectors.toList());
196         outputBuffer.add(flattenCluster(clusterItems));
197     }
198 
    /**
     * Finalizes the given clusters and evicts from the id table any item that is no longer
     * referenced by a surviving active cluster.
     *
     * @param clusterIdsToProcess indexes of finished clusters, sorted ascending
     *                            (as produced by {@code cluster()})
     */
    private void processFinalizedClusters(final List<Integer> clusterIdsToProcess) {
        final Set<Integer> activeClusterIds = IntStream.range(0, currentClusters.size()).boxed().collect(Collectors.toSet());
        activeClusterIds.removeAll(clusterIdsToProcess);
        // Items still referenced by a surviving cluster must remain in the id table
        final Set<Long> activeClusterItemIds = activeClusterIds.stream().flatMap(i -> currentClusters.get(i)._2.stream()).collect(Collectors.toSet());
        final Set<Long> finalizedItemIds = clusterIdsToProcess.stream()
                .flatMap(i -> currentClusters.get(i)._2.stream())
                .filter(i -> !activeClusterItemIds.contains(i))
                .collect(Collectors.toSet());
        // Process in descending index order so the remaining indexes stay valid as clusters are removed
        for (int i = clusterIdsToProcess.size() - 1; i >= 0; i--) {
            processCluster(clusterIdsToProcess.get(i));
        }
        finalizedItemIds.stream().forEach(idToItemMap::remove);
    }
212 
deleteRedundantClusters()213     private void deleteRedundantClusters() {
214         final Set<Integer> redundantClusterSet = new HashSet<>();
215         for (int i = 0; i < currentClusters.size(); i++) {
216             final Set<Long> clusterSetA = new HashSet<>(currentClusters.get(i)._2);
217             for (int j = 0; j < i; j++) {
218                 final Set<Long> clusterSetB = new HashSet<>(currentClusters.get(j)._2);
219                 if (clusterSetA.containsAll(clusterSetB)) {
220                     redundantClusterSet.add(j);
221                 } else if (clusterSetA.size() != clusterSetB.size() && clusterSetB.containsAll(clusterSetA)) {
222                     redundantClusterSet.add(i);
223                 }
224             }
225         }
226         final List<Integer> redundantClustersList = new ArrayList<>(redundantClusterSet);
227         redundantClustersList.sort(Comparator.naturalOrder());
228         for (int i = redundantClustersList.size() - 1; i >= 0; i--) {
229             currentClusters.remove((int)redundantClustersList.get(i));
230         }
231     }
232 
flushClusters()233     private void flushClusters() {
234         while (!currentClusters.isEmpty()) {
235             processCluster(0);
236         }
237         resetItemIds();
238     }
239 
seedCluster(final long seedId)240     private void seedCluster(final long seedId) {
241         final T seed = validateItemIndex(seedId);
242         final List<Long> newCluster = new ArrayList<>(1);
243         newCluster.add(seedId);
244         currentClusters.add(new Tuple2<>(getClusteringInterval(seed, null), newCluster));
245     }
246 
247     /**
248      * Create a new cluster
249      * @param seedId    itemId
250      * @param existingClusterIndex
251      * @param clusteringIds
252      */
seedWithExistingCluster(final Long seedId, final int existingClusterIndex, final Set<Long> clusteringIds)253     private void seedWithExistingCluster(final Long seedId, final int existingClusterIndex, final Set<Long> clusteringIds) {
254         final T seed = validateItemIndex(seedId);
255         final List<Long> existingCluster = currentClusters.get(existingClusterIndex)._2;
256         final List<Long> validClusterIds = existingCluster.stream().filter(clusteringIds::contains).collect(Collectors.toList());
257         final List<Long> newCluster = new ArrayList<>(1 + validClusterIds.size());
258         newCluster.addAll(validClusterIds);
259         newCluster.add(seedId);
260         currentClusters.add(new Tuple2<>(getClusteringInterval(seed, currentClusters.get(existingClusterIndex)._1), newCluster));
261     }
262 
validateItemIndex(final long index)263     private T validateItemIndex(final long index) {
264         final T item = idToItemMap.get(index);
265         if (item == null) {
266             throw new IllegalArgumentException("Item id " + index + " not found in table");
267         }
268         if (!currentContig.equals(item.getContig())) {
269             throw new IllegalArgumentException("Attempted to seed new cluster with item on contig " + item.getContig() + " but the current contig is " + currentContig);
270         }
271         return item;
272     }
273 
validateClusterIndex(final int index)274     private Tuple2<SimpleInterval, List<Long>> validateClusterIndex(final int index) {
275         if (index < 0 || index >= currentClusters.size()) {
276             throw new IllegalArgumentException("Specified cluster index " + index + " is out of range.");
277         }
278         final Tuple2<SimpleInterval, List<Long>> cluster = currentClusters.get(index);
279         final List<Long> clusterItemIds = cluster._2;
280         if (clusterItemIds.isEmpty()) {
281             throw new IllegalArgumentException("Encountered empty cluster");
282         }
283         return cluster;
284     }
285 
286     /**
287      * Add the item specified by {@param itemId} to the cluster specified by {@param clusterIndex}
288      * and expand the clustering interval
289      * @param clusterIndex
290      * @param itemId
291      */
addToCluster(final int clusterIndex, final long itemId)292     private void addToCluster(final int clusterIndex, final long itemId) {
293         final T item = idToItemMap.get(itemId);
294         if (item == null) {
295             throw new IllegalArgumentException("Item id " + item + " not found in table");
296         }
297         if (!currentContig.equals(item.getContig())) {
298             throw new IllegalArgumentException("Attempted to add new item on contig " + item.getContig() + " but the current contig is " + currentContig);
299         }
300         if (clusterIndex >= currentClusters.size()) {
301             throw new IllegalArgumentException("Specified cluster index " + clusterIndex + " is greater than the largest index.");
302         }
303         final Tuple2<SimpleInterval, List<Long>> cluster = currentClusters.get(clusterIndex);
304         final SimpleInterval clusterInterval = cluster._1;
305         final List<Long> clusterItems = cluster._2;
306         clusterItems.add(itemId);
307         final SimpleInterval clusteringStartInterval = getClusteringInterval(item, clusterInterval);
308         if (clusteringStartInterval.getStart() != clusterInterval.getStart() || clusteringStartInterval.getEnd() != clusterInterval.getEnd()) {
309             currentClusters.remove(clusterIndex);
310             currentClusters.add(clusterIndex, new Tuple2<>(clusteringStartInterval, clusterItems));
311         }
312     }
313 }
314