package org.broadinstitute.hellbender.tools.sv;

import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.util.Locatable;
import org.broadinstitute.hellbender.utils.*;
import scala.Tuple2;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Base class for engines that incrementally group {@link Locatable} items into clusters as they are streamed in via
 * {@link #add}.
 *
 * <p>Each added item receives a numeric id and is linked to previously seen items with {@link #clusterTogether}.
 * Clusters are either single-linkage groups or maximal cliques, depending on {@link CLUSTERING_TYPE}. A cluster is
 * considered complete once a newly added item's clustering interval starts after the cluster's interval ends. This
 * presumably assumes items are added in coordinate-sorted order (TODO confirm with callers). Seeing an item on a new
 * contig flushes all open clusters. Completed clusters are collapsed with {@link #flattenCluster} and buffered until
 * {@link #getOutput()} is called.</p>
 *
 * @param <T> type of item being clustered
 */
public abstract class LocatableClusterEngine<T extends Locatable> {

    protected final TreeMap<GenomeLoc, Integer> genomicToBinMap;
    protected final List<GenomeLoc> coverageIntervals;
    final GenomeLocParser parser;

    public enum CLUSTERING_TYPE {
        SINGLE_LINKAGE,
        MAX_CLIQUE
    }

    protected final SAMSequenceDictionary dictionary;
    // Pairs of cluster clustering-interval with the item IDs it contains. Accessed mostly by index, hence ArrayList.
    private final List<Tuple2<SimpleInterval, List<Long>>> currentClusters;
    private final Map<Long, T> idToItemMap;
    private final List<T> outputBuffer;
    private final CLUSTERING_TYPE clusteringType;
    private long currentItemId;
    private String currentContig;

    /**
     * @param dictionary        sequence dictionary used for sorting output and building the {@link GenomeLocParser}
     * @param clusteringType    clustering algorithm to apply
     * @param coverageIntervals optional intervals (may be null); when given, each is mapped to its list index in
     *                          {@link #genomicToBinMap}
     */
    public LocatableClusterEngine(final SAMSequenceDictionary dictionary, final CLUSTERING_TYPE clusteringType, final List<GenomeLoc> coverageIntervals) {
        this.dictionary = dictionary;
        this.clusteringType = clusteringType;
        this.currentClusters = new ArrayList<>();
        this.idToItemMap = new HashMap<>();
        this.outputBuffer = new ArrayList<>();
        currentItemId = 0;
        currentContig = null;

        parser = new GenomeLocParser(this.dictionary);
        if (coverageIntervals != null) {
            this.coverageIntervals = coverageIntervals;
            genomicToBinMap = new TreeMap<>();
            for (int i = 0; i < coverageIntervals.size(); i++) {
                genomicToBinMap.put(coverageIntervals.get(i), i);
            }
        } else {
            genomicToBinMap = null;
            this.coverageIntervals = null;
        }
    }

    /** Returns true if items {@code a} and {@code b} should be linked in the same cluster. */
    abstract protected boolean clusterTogether(final T a, final T b);

    /**
     * Returns the interval used to decide when a cluster is complete. Called with {@code currentClusterInterval == null}
     * when seeding a new singleton cluster. Called with the existing cluster's interval when the item joins or seeds
     * from that cluster; the result then becomes that cluster's interval.
     */
    abstract protected SimpleInterval getClusteringInterval(final T item, final SimpleInterval currentClusterInterval);

    /** Collapses a group of items judged identical by {@link #itemsAreIdentical} into a single representative. */
    abstract protected T deduplicateIdenticalItems(final Collection<T> items);

    /** Returns true if {@code a} and {@code b} are duplicates of each other (used by {@link #deduplicateItems}). */
    abstract protected boolean itemsAreIdentical(final T a, final T b);

    /** Collapses the items of one completed cluster into a single output item. */
    abstract protected T flattenCluster(final Collection<T> cluster);

    /**
     * Flushes all open clusters and returns every flattened cluster produced since the last call.
     * For {@link CLUSTERING_TYPE#MAX_CLIQUE} the output is sorted and deduplicated; otherwise it is returned in buffer
     * order. The internal output buffer is cleared.
     */
    public List<T> getOutput() {
        flushClusters();
        final List<T> output;
        if (clusteringType == CLUSTERING_TYPE.MAX_CLIQUE) {
            output = deduplicateItems(outputBuffer);
        } else {
            output = new ArrayList<>(outputBuffer);
        }
        outputBuffer.clear();
        return output;
    }

    /** Resets id assignment and forgets all items; only valid when no clusters are open. */
    private void resetItemIds() {
        Utils.validate(currentClusters.isEmpty(), "Current cluster collection not empty");
        currentItemId = 0;
        idToItemMap.clear();
    }

    /**
     * Returns true if no item has ever been added (no contig has been seen). Note that {@link #getOutput()} does not
     * reset the current contig.
     */
    public boolean isEmpty() {
        return currentContig == null;
    }

    /**
     * Adds an item to the engine, finalizing (and buffering) any clusters that can no longer grow.
     * An item on a different contig than the current one first flushes all open clusters.
     */
    public void add(final T item) {

        // Start a new cluster if on a new contig
        if (!item.getContig().equals(currentContig)) {
            flushClusters();
            currentContig = item.getContig();
            idToItemMap.put(currentItemId, item);
            seedCluster(currentItemId);
            currentItemId++;
            return;
        }

        // Keep track of a unique id for each item
        idToItemMap.put(currentItemId, item);
        final List<Integer> clusterIdsToProcess = cluster(item);
        processFinalizedClusters(clusterIdsToProcess);
        deleteRedundantClusters();
        currentItemId++;
    }

    public String getCurrentContig() {
        return currentContig;
    }

    /**
     * Sorts {@code items} by the sequence dictionary and merges groups of items that share a start position and are
     * identical according to {@link #itemsAreIdentical} (via {@link #deduplicateIdenticalItems}). Items sharing a start
     * but not identical to one another are all retained.
     *
     * @param items items to deduplicate (not modified)
     * @return sorted, deduplicated items
     */
    public List<T> deduplicateItems(final List<T> items) {
        final List<T> sortedItems = IntervalUtils.sortLocatablesBySequenceDictionary(items, dictionary);
        final List<T> deduplicatedList = new ArrayList<>();
        // Marks items already merged into an earlier group so they are neither dropped nor emitted twice
        final boolean[] consumed = new boolean[sortedItems.size()];
        for (int i = 0; i < sortedItems.size(); i++) {
            if (consumed[i]) {
                continue;
            }
            final T record = sortedItems.get(i);
            final List<T> identicalItems = new ArrayList<>();
            // Only the run of following items with the same start is compared (relies on the sort order above)
            for (int j = i + 1; j < sortedItems.size() && record.getStart() == sortedItems.get(j).getStart(); j++) {
                if (!consumed[j] && itemsAreIdentical(record, sortedItems.get(j))) {
                    identicalItems.add(sortedItems.get(j));
                    consumed[j] = true;
                }
            }
            if (identicalItems.isEmpty()) {
                deduplicatedList.add(record);
            } else {
                // Same ordering as before: matches first, then the record they matched
                identicalItems.add(record);
                deduplicatedList.add(deduplicateIdenticalItems(identicalItems));
            }
        }
        return deduplicatedList;
    }

    /**
     * Add a new {@param <T>} to the current clusters and determine which are complete
     * @param item to be added (must already be registered in the id table under {@code currentItemId})
     * @return the indices (ascending) of clusters that are complete and ready for processing
     */
    private List<Integer> cluster(final T item) {
        // Get list of item IDs from active clusters that cluster with this item (excluding the item itself)
        final Set<Long> linkedItemIds = idToItemMap.entrySet().stream()
                .filter(other -> other.getKey().longValue() != currentItemId && clusterTogether(item, other.getValue()))
                .map(Map.Entry::getKey)
                .collect(Collectors.toCollection(LinkedHashSet::new));

        // Loop-invariant: where this item's clustering interval begins
        final int itemClusteringStart = getClusteringInterval(item, null).getStart();

        // Find clusters to which this item belongs, and which active clusters we're definitely done with
        final List<Integer> clusterIdsToProcess = new ArrayList<>();
        final List<Integer> clustersToAdd = new ArrayList<>();
        final List<Integer> clustersToSeedWith = new ArrayList<>();
        for (int clusterIndex = 0; clusterIndex < currentClusters.size(); clusterIndex++) {
            final Tuple2<SimpleInterval, List<Long>> cluster = currentClusters.get(clusterIndex);
            final SimpleInterval clusterInterval = cluster._1;
            final List<Long> clusterItemIds = cluster._2;
            if (itemClusteringStart > clusterInterval.getEnd()) {
                clusterIdsToProcess.add(clusterIndex); //this cluster is complete -- process it when we're done
            } else if (clusteringType == CLUSTERING_TYPE.MAX_CLIQUE) {
                final int n = (int) clusterItemIds.stream().filter(linkedItemIds::contains).count();
                if (n == clusterItemIds.size()) {
                    // Linked to every member: the clique can be extended with this item
                    clustersToAdd.add(clusterIndex);
                } else if (n > 0) {
                    // Linked to some members: start a new clique from that subset plus this item
                    clustersToSeedWith.add(clusterIndex);
                }
            } else if (clusteringType == CLUSTERING_TYPE.SINGLE_LINKAGE) {
                if (clusterItemIds.stream().anyMatch(linkedItemIds::contains)) {
                    clustersToAdd.add(clusterIndex);
                }
            } else {
                throw new IllegalArgumentException("Clustering algorithm for type " + clusteringType.name() + " not implemented");
            }
        }

        // Add to item clusters
        for (final int index : clustersToAdd) {
            addToCluster(index, currentItemId);
        }
        // Create new clusters/cliques (appended at the end, so existing indices stay valid)
        for (final int index : clustersToSeedWith) {
            seedWithExistingCluster(currentItemId, index, linkedItemIds);
        }
        // If there weren't any matches, create a new singleton cluster
        if (clustersToAdd.isEmpty() && clustersToSeedWith.isEmpty()) {
            seedCluster(currentItemId);
        }
        return clusterIdsToProcess;
    }

    /** Removes the cluster at {@code clusterIndex} and buffers its flattened form. */
    private void processCluster(final int clusterIndex) {
        final Tuple2<SimpleInterval, List<Long>> cluster = validateClusterIndex(clusterIndex);
        currentClusters.remove(clusterIndex);
        bufferFlattenedCluster(cluster._2);
    }

    /** Looks up the items for {@code clusterItemIds}, flattens them and appends the result to the output buffer. */
    private void bufferFlattenedCluster(final List<Long> clusterItemIds) {
        final List<T> clusterItems = clusterItemIds.stream().map(idToItemMap::get).collect(Collectors.toList());
        outputBuffer.add(flattenCluster(clusterItems));
    }

    /**
     * Processes the given completed clusters and forgets items that no longer belong to any open cluster.
     * @param clusterIdsToProcess cluster indices in ascending order (as produced by {@link #cluster})
     */
    private void processFinalizedClusters(final List<Integer> clusterIdsToProcess) {
        final Set<Integer> activeClusterIds = IntStream.range(0, currentClusters.size()).boxed().collect(Collectors.toSet());
        activeClusterIds.removeAll(clusterIdsToProcess);
        final Set<Long> activeClusterItemIds = activeClusterIds.stream().flatMap(i -> currentClusters.get(i)._2.stream()).collect(Collectors.toSet());
        final Set<Long> finalizedItemIds = clusterIdsToProcess.stream()
                .flatMap(i -> currentClusters.get(i)._2.stream())
                .filter(id -> !activeClusterItemIds.contains(id))
                .collect(Collectors.toSet());
        // Descending order so that each removal does not shift the indices still to be processed
        for (int i = clusterIdsToProcess.size() - 1; i >= 0; i--) {
            processCluster(clusterIdsToProcess.get(i));
        }
        idToItemMap.keySet().removeAll(finalizedItemIds);
    }

    /**
     * Removes clusters whose item set is contained in another cluster's item set. Among clusters with identical item
     * sets, only the one with the highest index survives.
     */
    private void deleteRedundantClusters() {
        // Build each cluster's item set once rather than once per pairwise comparison
        final List<Set<Long>> clusterItemSets = new ArrayList<>(currentClusters.size());
        for (final Tuple2<SimpleInterval, List<Long>> cluster : currentClusters) {
            clusterItemSets.add(new HashSet<>(cluster._2));
        }
        final TreeSet<Integer> redundantClusterSet = new TreeSet<>();
        for (int i = 0; i < clusterItemSets.size(); i++) {
            final Set<Long> clusterSetA = clusterItemSets.get(i);
            for (int j = 0; j < i; j++) {
                final Set<Long> clusterSetB = clusterItemSets.get(j);
                if (clusterSetA.containsAll(clusterSetB)) {
                    redundantClusterSet.add(j);
                } else if (clusterSetA.size() != clusterSetB.size() && clusterSetB.containsAll(clusterSetA)) {
                    redundantClusterSet.add(i);
                }
            }
        }
        // Remove from the highest index down so earlier indices stay valid
        for (final int index : redundantClusterSet.descendingSet()) {
            currentClusters.remove(index);
        }
    }

    /** Flattens and buffers every open cluster (in index order), then resets the id table. */
    private void flushClusters() {
        for (int i = 0; i < currentClusters.size(); i++) {
            bufferFlattenedCluster(validateClusterIndex(i)._2);
        }
        currentClusters.clear();
        resetItemIds();
    }

    /** Opens a new singleton cluster containing only {@code seedId}. */
    private void seedCluster(final long seedId) {
        final T seed = validateItemIndex(seedId);
        final List<Long> newCluster = new ArrayList<>(1);
        newCluster.add(seedId);
        currentClusters.add(new Tuple2<>(getClusteringInterval(seed, null), newCluster));
    }

    /**
     * Create a new cluster from the members of an existing cluster that link to the seed, plus the seed itself.
     * @param seedId id of the item seeding the new cluster
     * @param existingClusterIndex index of the cluster whose linked members are copied
     * @param clusteringIds ids of items linked to the seed
     */
    private void seedWithExistingCluster(final Long seedId, final int existingClusterIndex, final Set<Long> clusteringIds) {
        final T seed = validateItemIndex(seedId);
        final Tuple2<SimpleInterval, List<Long>> existingCluster = currentClusters.get(existingClusterIndex);
        final List<Long> validClusterIds = existingCluster._2.stream().filter(clusteringIds::contains).collect(Collectors.toList());
        final List<Long> newCluster = new ArrayList<>(1 + validClusterIds.size());
        newCluster.addAll(validClusterIds);
        newCluster.add(seedId);
        currentClusters.add(new Tuple2<>(getClusteringInterval(seed, existingCluster._1), newCluster));
    }

    /** Returns the item with id {@code index}, checking that it exists and lies on the current contig. */
    private T validateItemIndex(final long index) {
        final T item = idToItemMap.get(index);
        if (item == null) {
            throw new IllegalArgumentException("Item id " + index + " not found in table");
        }
        if (!currentContig.equals(item.getContig())) {
            throw new IllegalArgumentException("Attempted to seed new cluster with item on contig " + item.getContig() + " but the current contig is " + currentContig);
        }
        return item;
    }

    /** Returns the cluster at {@code index}, checking that the index is in range and the cluster is non-empty. */
    private Tuple2<SimpleInterval, List<Long>> validateClusterIndex(final int index) {
        if (index < 0 || index >= currentClusters.size()) {
            throw new IllegalArgumentException("Specified cluster index " + index + " is out of range.");
        }
        final Tuple2<SimpleInterval, List<Long>> cluster = currentClusters.get(index);
        final List<Long> clusterItemIds = cluster._2;
        if (clusterItemIds.isEmpty()) {
            throw new IllegalArgumentException("Encountered empty cluster");
        }
        return cluster;
    }

    /**
     * Add the item specified by {@param itemId} to the cluster specified by {@param clusterIndex}
     * and expand the clustering interval
     * @param clusterIndex index of the target cluster
     * @param itemId id of the item to add
     */
    private void addToCluster(final int clusterIndex, final long itemId) {
        final T item = idToItemMap.get(itemId);
        if (item == null) {
            throw new IllegalArgumentException("Item id " + itemId + " not found in table");
        }
        if (!currentContig.equals(item.getContig())) {
            throw new IllegalArgumentException("Attempted to add new item on contig " + item.getContig() + " but the current contig is " + currentContig);
        }
        if (clusterIndex < 0 || clusterIndex >= currentClusters.size()) {
            throw new IllegalArgumentException("Specified cluster index " + clusterIndex + " is out of range.");
        }
        final Tuple2<SimpleInterval, List<Long>> cluster = currentClusters.get(clusterIndex);
        final SimpleInterval clusterInterval = cluster._1;
        final List<Long> clusterItems = cluster._2;
        clusterItems.add(itemId);
        final SimpleInterval clusteringStartInterval = getClusteringInterval(item, clusterInterval);
        if (clusteringStartInterval.getStart() != clusterInterval.getStart() || clusteringStartInterval.getEnd() != clusterInterval.getEnd()) {
            // Tuple2 is immutable, so replace the entry in place with the updated interval
            currentClusters.set(clusterIndex, new Tuple2<>(clusteringStartInterval, clusterItems));
        }
    }
}