1import argparse 2import collections 3import inspect 4import json 5import logging 6from abc import abstractmethod 7from typing import List, Tuple, Set, Dict, Optional 8 9import numpy as np 10import pymc3 as pm 11import scipy.sparse as sp 12import theano as th 13import theano.sparse as tst 14import theano.tensor as tt 15from pymc3 import Normal, Deterministic, DensityDist, Lognormal, Exponential 16 17from . import commons 18from .dists import HalfFlat 19from .fancy_model import GeneralizedContinuousModel 20from .theano_hmm import TheanoForwardBackward 21from .. import config, types 22from ..structs.interval import Interval, GCContentAnnotation 23from ..structs.metadata import SampleMetadataCollection 24from ..tasks.inference_task_base import HybridInferenceParameters 25 26_logger = logging.getLogger(__name__) 27 28_eps = commons.eps 29 30 31class DenoisingModelConfig: 32 """Configuration for the coverage denoising model, including hyper-parameters, model feature selection, 33 and choice of approximation schemes.""" 34 35 # approximation schemes for calculating expectations with respect to copy number posteriors 36 _q_c_expectation_modes = ['map', 'exact', 'hybrid'] 37 38 def __init__(self, 39 max_bias_factors: int = 5, 40 mapping_error_rate: float = 0.01, 41 psi_t_scale: float = 0.001, 42 psi_s_scale: float = 0.0001, 43 depth_correction_tau: float = 10000.0, 44 log_mean_bias_std: float = 0.1, 45 init_ard_rel_unexplained_variance: float = 0.1, 46 num_gc_bins: int = 20, 47 gc_curve_sd: float = 1.0, 48 q_c_expectation_mode: str = 'hybrid', 49 active_class_padding_hybrid_mode: int = 50000, 50 enable_bias_factors: bool = True, 51 enable_explicit_gc_bias_modeling: bool = False, 52 disable_bias_factors_in_active_class: bool = False): 53 """See `expose_args` for the description of arguments""" 54 self.max_bias_factors = max_bias_factors 55 self.mapping_error_rate = mapping_error_rate 56 self.psi_t_scale = psi_t_scale 57 self.psi_s_scale = psi_s_scale 58 self.depth_correction_tau = depth_correction_tau 59 self.log_mean_bias_std = log_mean_bias_std 60 self.init_ard_rel_unexplained_variance = init_ard_rel_unexplained_variance 61 self.num_gc_bins = num_gc_bins 62 self.gc_curve_sd = gc_curve_sd 63 self.q_c_expectation_mode = q_c_expectation_mode 64 self.active_class_padding_hybrid_mode = active_class_padding_hybrid_mode 65 self.enable_bias_factors = enable_bias_factors 66 self.enable_explicit_gc_bias_modeling = enable_explicit_gc_bias_modeling 67 self.disable_bias_factors_in_active_class = disable_bias_factors_in_active_class 68 69 @staticmethod 70 def expose_args(args: argparse.ArgumentParser, 71 hide: Set[str] = None): 72 """Exposes arguments of `__init__` to a given instance of `ArgumentParser`. 73 74 Args: 75 args: an instance of `ArgumentParser` 76 hide: a set of arguments not to expose 77 78 Returns: 79 None 80 """ 81 group = args.add_argument_group(title="Coverage denoising model parameters") 82 if hide is None: 83 hide = set() 84 85 initializer_params = inspect.signature(DenoisingModelConfig.__init__).parameters 86 valid_args = {"--" + arg for arg in initializer_params.keys()} 87 for hidden_arg in hide: 88 assert hidden_arg in valid_args, \ 89 "Initializer argument to be hidden {0} is not a valid initializer arguments; possible " \ 90 "choices are: {1}".format(hidden_arg, valid_args) 91 92 def process_and_maybe_add(arg, **kwargs): 93 full_arg = "--" + arg 94 if full_arg in hide: 95 return 96 kwargs['default'] = initializer_params[arg].default 97 group.add_argument(full_arg, **kwargs) 98 99 def str_to_bool(value: str): 100 if value.lower() in ('yes', 'true', 't', 'y', '1'): 101 return True 102 elif value.lower() in ('no', 'false', 'f', 'n', '0'): 103 return False 104 else: 105 raise argparse.ArgumentTypeError('Boolean value expected.') 106 107 process_and_maybe_add("max_bias_factors", 108 type=int, 109 help="Maximum number of bias factors") 110 111 process_and_maybe_add("mapping_error_rate", 112 type=float, 113 help="Typical mapping error rate") 114 115 process_and_maybe_add("psi_t_scale", 116 type=float, 117 help="Typical scale of interval-specific unexplained variance") 118 119 process_and_maybe_add("psi_s_scale", 120 type=float, 121 help="Typical scale of sample-specific unexplained variance") 122 123 process_and_maybe_add("depth_correction_tau", 124 type=float, 125 help="Precision of pinning read-depth in the coverage denoising model " 126 "to its globally determined value") 127 128 process_and_maybe_add("log_mean_bias_std", 129 type=float, 130 help="Standard deviation of mean bias in log space") 131 132 process_and_maybe_add("init_ard_rel_unexplained_variance", 133 type=float, 134 help="Initial value of automatic relevance determination (ARD) precisions relative " 135 "to the typical interval-specific unexplained variance scale") 136 137 process_and_maybe_add("num_gc_bins", 138 type=int, 139 help="Number of knobs on the GC curve") 140 141 process_and_maybe_add("gc_curve_sd", 142 type=float, 143 help="Prior standard deviation of the GC curve from a flat curve") 144 145 process_and_maybe_add("q_c_expectation_mode", 146 type=str, 147 choices=DenoisingModelConfig._q_c_expectation_modes, 148 help="The strategy for calculating copy number posterior expectations in the denoising " 149 "model. Choices: \"exact\": summation over all states, \"map\": drop all terms " 150 "except for the maximum a posteriori (MAP) copy number estimate, \"hybrid\": " 151 "use MAP strategy in silent regions and exact strategy in active regions.") 152 153 process_and_maybe_add("active_class_padding_hybrid_mode", 154 type=int, 155 help="If q_c_expectation_mode is set to \"hybrid\", the active intervals " 156 "will be further padded by this value (in the units of bp) in order to achieve " 157 "higher sensitivity in detecting common CNVs and to avoid boundary artifacts") 158 159 process_and_maybe_add("enable_bias_factors", 160 type=str_to_bool, 161 help="Enable discovery of novel bias factors") 162 163 process_and_maybe_add("enable_explicit_gc_bias_modeling", 164 type=str_to_bool, 165 help="Enable explicit modeling of GC bias (if enabled, the provided modeling interval " 166 "list of contain a column for {0} values)".format(GCContentAnnotation.get_key())) 167 168 process_and_maybe_add("disable_bias_factors_in_active_class", 169 type=str_to_bool, 170 help="Disable novel bias factor discovery CNV-active regions") 171 172 @staticmethod 173 def from_args_dict(args_dict: Dict) -> 'DenoisingModelConfig': 174 """Initialize an instance of `DenoisingModelConfig` from a dictionary of arguments. 175 176 Args: 177 args_dict: a dictionary of arguments; the keys must match argument names in 178 `DenoisingModelConfig.__init__` 179 180 Returns: 181 an instance of `DenoisingModelConfig` 182 """ 183 relevant_keys = set(inspect.getfullargspec(DenoisingModelConfig.__init__).args) 184 relevant_kwargs = {k: v for k, v in args_dict.items() if k in relevant_keys} 185 return DenoisingModelConfig(**relevant_kwargs) 186 187 @staticmethod 188 def from_json_file(json_file: str) -> 'DenoisingModelConfig': 189 with open(json_file, 'r') as fp: 190 imported_denoising_config_dict = json.load(fp) 191 return DenoisingModelConfig.from_args_dict(imported_denoising_config_dict) 192 193 194class CopyNumberCallingConfig: 195 """Configuration of the copy number caller.""" 196 def __init__(self, 197 p_alt: float = 1e-6, 198 p_active: float = 1e-3, 199 cnv_coherence_length: float = 10000.0, 200 class_coherence_length: float = 10000.0, 201 max_copy_number: int = 5, 202 num_calling_processes: int = 1): 203 """See `expose_args` for the description of arguments""" 204 assert 0.0 <= p_alt <= 1.0 205 assert 0.0 <= p_active <= 1.0 206 assert cnv_coherence_length > 0.0 207 assert class_coherence_length > 0.0 208 assert max_copy_number > 0 209 assert max_copy_number * p_alt < 1.0 210 assert num_calling_processes > 0 211 212 self.p_alt = p_alt 213 self.p_active = p_active 214 self.cnv_coherence_length = cnv_coherence_length 215 self.class_coherence_length = class_coherence_length 216 self.max_copy_number = max_copy_number 217 self.num_calling_processes = num_calling_processes 218 219 self.num_copy_number_states = max_copy_number + 1 220 self.num_copy_number_classes = 2 221 222 @staticmethod 223 def expose_args(args: argparse.ArgumentParser, hide: Set[str] = None): 224 """Exposes arguments of `__init__` to a given instance of `ArgumentParser`. 225 226 Args: 227 args: an instance of `ArgumentParser` 228 hide: a set of arguments not to expose 229 230 Returns: 231 None 232 """ 233 group = args.add_argument_group(title="Copy number calling parameters") 234 if hide is None: 235 hide = set() 236 237 initializer_params = inspect.signature(CopyNumberCallingConfig.__init__).parameters 238 valid_args = {"--" + arg for arg in initializer_params.keys()} 239 for hidden_arg in hide: 240 assert hidden_arg in valid_args, \ 241 "Initializer argument to be hidden {0} is not a valid initializer arguments; possible " \ 242 "choices are: {1}".format(hidden_arg, valid_args) 243 244 def process_and_maybe_add(arg, **kwargs): 245 full_arg = "--" + arg 246 if full_arg in hide: 247 return 248 kwargs['default'] = initializer_params[arg].default 249 group.add_argument(full_arg, **kwargs) 250 251 def str_to_bool(value: str): 252 if value.lower() in ('yes', 'true', 't', 'y', '1'): 253 return True 254 elif value.lower() in ('no', 'false', 'f', 'n', '0'): 255 return False 256 else: 257 raise argparse.ArgumentTypeError('Boolean value expected.') 258 259 process_and_maybe_add("p_alt", 260 type=float, 261 help="Prior probability of alternate copy number with respect to contig baseline " 262 "state in CNV-silent intervals") 263 264 process_and_maybe_add("p_active", 265 type=float, 266 help="Prior probability of treating an interval as CNV-active") 267 268 process_and_maybe_add("cnv_coherence_length", 269 type=float, 270 help="Coherence length of CNV events (in the units of bp)") 271 272 process_and_maybe_add("class_coherence_length", 273 type=float, 274 help="Coherence length of CNV-silent and CNV-active domains (in the units of bp)") 275 276 process_and_maybe_add("max_copy_number", 277 type=int, 278 help="Highest called copy number state") 279 280 process_and_maybe_add("num_calling_processes", 281 type=int, 282 help="Number of concurrent forward-backward threads (not implemented yet)") 283 284 @staticmethod 285 def from_args_dict(args_dict: Dict): 286 """Initialize an instance of `CopyNumberCallingConfig` from a dictionary of arguments. 287 288 Args: 289 args_dict: a dictionary of arguments; the keys must match argument names in 290 `CopyNumberCallingConfig.__init__` 291 292 Returns: 293 an instance of `CopyNumberCallingConfig` 294 """ 295 relevant_keys = set(inspect.getfullargspec(CopyNumberCallingConfig.__init__).args) 296 relevant_kwargs = {k: v for k, v in args_dict.items() if k in relevant_keys} 297 return CopyNumberCallingConfig(**relevant_kwargs) 298 299 @staticmethod 300 def from_json_file(json_file: str) -> 'CopyNumberCallingConfig': 301 with open(json_file, 'r') as fp: 302 imported_calling_config_dict = json.load(fp) 303 return CopyNumberCallingConfig.from_args_dict(imported_calling_config_dict) 304 305 306class PosteriorInitializer: 307 """Base class for posterior initializers.""" 308 @staticmethod 309 @abstractmethod 310 def initialize_posterior(denoising_config: DenoisingModelConfig, 311 calling_config: CopyNumberCallingConfig, 312 shared_workspace: 'DenoisingCallingWorkspace') -> None: 313 raise NotImplementedError 314 315 316class TrivialPosteriorInitializer(PosteriorInitializer): 317 """Initialize posteriors to reasonable values based on priors.""" 318 @staticmethod 319 def initialize_posterior(denoising_config: DenoisingModelConfig, 320 calling_config: CopyNumberCallingConfig, 321 shared_workspace: 'DenoisingCallingWorkspace'): 322 # interval class log posterior probs 323 class_probs_k = np.asarray([1.0 - calling_config.p_active, calling_config.p_active], dtype=types.floatX) 324 log_q_tau_tk = np.tile(np.log(class_probs_k), (shared_workspace.num_intervals, 1)) 325 shared_workspace.log_q_tau_tk = th.shared(log_q_tau_tk, name="log_q_tau_tk", borrow=config.borrow_numpy) 326 327 # copy number log posterior probs 328 log_q_c_stc = np.zeros((shared_workspace.num_samples, shared_workspace.num_intervals, 329 calling_config.num_copy_number_states), dtype=types.floatX) 330 t_to_j_map = shared_workspace.t_to_j_map.get_value(borrow=True) 331 for si in range(shared_workspace.num_samples): 332 sample_baseline_copy_number_j = shared_workspace.baseline_copy_number_sj[si, :] 333 sample_pi_jkc = HHMMClassAndCopyNumberBasicCaller.get_copy_number_prior_for_sample_jkc( 334 calling_config.num_copy_number_states, 335 calling_config.p_alt, 336 sample_baseline_copy_number_j) 337 sample_log_pi_jc = np.log(np.sum(sample_pi_jkc * class_probs_k[np.newaxis, :, np.newaxis], axis=1)) 338 for ti in range(shared_workspace.num_intervals): 339 log_q_c_stc[si, ti, :] = sample_log_pi_jc[t_to_j_map[ti], :] 340 shared_workspace.log_q_c_stc = th.shared(log_q_c_stc, name="log_q_c_stc", borrow=config.borrow_numpy) 341 342 343class DenoisingCallingWorkspace: 344 """This class contains objects (numpy arrays, theano tensors, etc) shared between the denoising model 345 and the copy number caller.""" 346 def __init__(self, 347 denoising_config: DenoisingModelConfig, 348 calling_config: CopyNumberCallingConfig, 349 interval_list: List[Interval], 350 n_st: np.ndarray, 351 sample_names: List[str], 352 sample_metadata_collection: SampleMetadataCollection, 353 posterior_initializer: Optional[PosteriorInitializer] = TrivialPosteriorInitializer): 354 self.denoising_config = denoising_config 355 self.calling_config = calling_config 356 self.interval_list = interval_list 357 self.sample_names = sample_names 358 359 assert n_st.ndim == 2, "Read counts matrix must be a 2-dim ndarray with shape (num_samples, num_intervals)" 360 361 self.num_samples: int = n_st.shape[0] 362 self.num_intervals: int = n_st.shape[1] 363 364 assert self.num_intervals >= 2, "At least two intervals must be provided" 365 assert len(interval_list) == self.num_intervals,\ 366 "The length of the interval list is incompatible with the shape of the read counts matrix" 367 368 # a list of unique contigs appearing in the interval list; the ordering is arbitrary and 369 # is only used internally 370 # Note: j is the index subscript used for contig index hereafter 371 self.contig_list = list({interval.contig for interval in interval_list}) 372 self.num_contigs = len(self.contig_list) 373 contig_to_j_map = {contig: self.contig_list.index(contig) for contig in self.contig_list} 374 t_to_j_map = np.asarray([contig_to_j_map[interval.contig] for interval in interval_list], 375 dtype=types.small_uint) 376 self.t_to_j_map: types.TensorSharedVariable = th.shared( 377 t_to_j_map, name="t_to_j_map", borrow=config.borrow_numpy) 378 379 self.global_read_depth_s, average_ploidy_s, self.baseline_copy_number_sj = \ 380 DenoisingCallingWorkspace._get_baseline_copy_number_and_read_depth( 381 sample_metadata_collection, sample_names, self.contig_list) 382 383 max_baseline_copy_number = np.max(self.baseline_copy_number_sj) 384 assert max_baseline_copy_number <= calling_config.max_copy_number, \ 385 "The highest contig ploidy ({0}) must be smaller or equal to the highest copy number state ({1})".format( 386 max_baseline_copy_number, calling_config.max_copy_number) 387 388 # shared theano tensors from the input data 389 self.n_st: types.TensorSharedVariable = th.shared( 390 n_st.astype(types.med_uint), name="n_st", borrow=config.borrow_numpy) 391 self.average_ploidy_s: types.TensorSharedVariable = th.shared( 392 average_ploidy_s.astype(types.floatX), name="average_ploidy_s", borrow=config.borrow_numpy) 393 394 # copy-number event stay probability 395 self.dist_t = np.asarray([self.interval_list[ti + 1].distance(self.interval_list[ti]) 396 for ti in range(self.num_intervals - 1)]) 397 cnv_stay_prob_t = np.exp(-self.dist_t / calling_config.cnv_coherence_length) 398 self.cnv_stay_prob_t = th.shared(cnv_stay_prob_t, name='cnv_stay_prob_t', borrow=config.borrow_numpy) 399 400 # copy number values for each copy number state 401 copy_number_values_c = np.arange(0, calling_config.num_copy_number_states, dtype=types.small_uint) 402 self.copy_number_values_c = th.shared(copy_number_values_c, name='copy_number_values_c', 403 borrow=config.borrow_numpy) 404 405 # copy number log posterior and derived quantities (to be initialized by `PosteriorInitializer`) 406 self.log_q_c_stc: Optional[types.TensorSharedVariable] = None 407 408 # latest MAP estimate of integer copy number (to be initialized and periodically updated by 409 # `DenoisingCallingWorkspace.update_auxiliary_vars) 410 self.c_map_st: Optional[types.TensorSharedVariable] = None 411 412 # latest bitmask of CNV-active intervals (to be initialized and periodically updated by 413 # `DenoisingCallingWorkspace.update_auxiliary_vars if q_c_expectation_mode == 'hybrid') 414 self.active_class_bitmask_t: Optional[types.TensorSharedVariable] = None 415 416 # copy number emission log posterior 417 log_copy_number_emission_stc = np.zeros( 418 (self.num_samples, self.num_intervals, calling_config.num_copy_number_states), dtype=types.floatX) 419 self.log_copy_number_emission_stc: types.TensorSharedVariable = th.shared( 420 log_copy_number_emission_stc, name="log_copy_number_emission_stc", borrow=config.borrow_numpy) 421 422 # class log posterior (to be initialized by `PosteriorInitializer`) 423 self.log_q_tau_tk: Optional[types.TensorSharedVariable] = None 424 425 # class emission log posterior 426 # (to be initialized by calling `initialize_copy_number_class_inference_vars`) 427 self.log_class_emission_tk: Optional[types.TensorSharedVariable] = None 428 429 # class assignment prior probabilities 430 # (to be initialized by calling `initialize_copy_number_class_inference_vars`) 431 self.class_probs_k: Optional[types.TensorSharedVariable] = None 432 433 # class Markov chain log prior (initialized here and remains constant throughout) 434 # (to be initialized by calling `initialize_copy_number_class_inference_vars`) 435 self.log_prior_k: Optional[np.ndarray] = None 436 437 # class Markov chain log transition (initialized here and remains constant throughout) 438 # (to be initialized by calling `initialize_copy_number_class_inference_vars`) 439 self.log_trans_tkk: Optional[np.ndarray] = None 440 441 # GC bias factors 442 # (to be initialized by calling `initialize_bias_inference_vars`) 443 self.W_gc_tg: Optional[tst.SparseConstant] = None 444 445 # auxiliary data structures for hybrid q_c_expectation_mode calculation 446 # (to be initialized by calling `initialize_bias_inference_vars`) 447 self.interval_neighbor_index_list: Optional[List[List[int]]] = None 448 449 # denoised copy ratios 450 denoised_copy_ratio_st = np.zeros((self.num_samples, self.num_intervals), dtype=types.floatX) 451 self.denoised_copy_ratio_st: types.TensorSharedVariable = th.shared( 452 denoised_copy_ratio_st, name="denoised_copy_ratio_st", borrow=config.borrow_numpy) 453 454 # initialize posterior 455 posterior_initializer.initialize_posterior(denoising_config, calling_config, self) 456 self.initialize_bias_inference_vars() 457 self.update_auxiliary_vars() 458 459 def initialize_copy_number_class_inference_vars(self): 460 """Initializes members required for copy number class inference (must be called in the cohort mode). 461 The following members are initialized: 462 - `DenoisingCallingWorkspace.log_class_emission_tk` 463 - `DenoisingCallingWorkspace.class_probs_k` 464 - `DenoisingCallingWorkspace.log_prior_k` 465 - `DenoisingCallingWorkspace.log_trans_tkk` 466 """ 467 # class emission log posterior 468 log_class_emission_tk = np.zeros( 469 (self.num_intervals, self.calling_config.num_copy_number_classes), dtype=types.floatX) 470 self.log_class_emission_tk: types.TensorSharedVariable = th.shared( 471 log_class_emission_tk, name="log_class_emission_tk", borrow=True) 472 473 # class assignment prior probabilities 474 # Note: 475 # The first class is the CNV-silent class (highly biased toward the baseline copy number) 476 # The second class is a CNV-active class (all copy number states are equally probable) 477 class_probs_k = np.asarray([1.0 - self.calling_config.p_active, self.calling_config.p_active], 478 dtype=types.floatX) 479 self.class_probs_k: types.TensorSharedVariable = th.shared( 480 class_probs_k, name='class_probs_k', borrow=config.borrow_numpy) 481 482 # class Markov chain log prior (initialized here and remains constant throughout) 483 self.log_prior_k: np.ndarray = np.log(class_probs_k) 484 485 # class Markov chain log transition (initialized here and remains constant throughout) 486 self.log_trans_tkk: np.ndarray = self._get_log_trans_tkk( 487 self.dist_t, 488 self.calling_config.class_coherence_length, 489 self.calling_config.num_copy_number_classes, 490 class_probs_k) 491 492 def initialize_bias_inference_vars(self): 493 """Initializes `DenoisingCallingWorkspace.W_gc_tg` and `DenoisingCallingWorkspace.interval_neighbor_index_list` 494 if required by the model configuration.""" 495 if self.denoising_config.enable_explicit_gc_bias_modeling: 496 self.W_gc_tg = self._create_sparse_gc_bin_tensor_tg( 497 self.interval_list, self.denoising_config.num_gc_bins) 498 499 if self.denoising_config.q_c_expectation_mode == 'hybrid': 500 self.interval_neighbor_index_list = self._get_interval_neighbor_index_list( 501 self.interval_list, self.denoising_config.active_class_padding_hybrid_mode) 502 else: 503 self.interval_neighbor_index_list = None 504 505 def update_auxiliary_vars(self): 506 """Updates `DenoisingCallingWorkspace.c_map_st' and `DenoisingCallingWorkspace.active_class_bitmask_t`.""" 507 # MAP copy number call 508 if self.c_map_st is None: 509 c_map_st = np.zeros((self.num_samples, self.num_intervals), dtype=types.small_uint) 510 self.c_map_st = th.shared(c_map_st, name="c_map_st", borrow=config.borrow_numpy) 511 self.c_map_st.set_value( 512 np.argmax(self.log_q_c_stc.get_value(borrow=True), axis=2).astype(types.small_uint), 513 borrow=config.borrow_numpy) 514 515 if self.denoising_config.q_c_expectation_mode == 'hybrid': 516 _logger.debug("Updating CNV-active class bitmask...") 517 if self.active_class_bitmask_t is None: 518 active_class_bitmask_t = np.zeros((self.num_intervals,), dtype=bool) 519 self.active_class_bitmask_t = th.shared( 520 active_class_bitmask_t, name="active_class_bitmask_t", borrow=config.borrow_numpy) 521 522 # bitmask for intervals of which the probability of being in the silent class is below 0.5 523 active_class_bitmask_t: np.ndarray = \ 524 self.log_q_tau_tk.get_value(borrow=True)[:, 0] < -np.log(2) 525 padded_active_class_bitmask_t = np.zeros_like(active_class_bitmask_t) 526 for ti, neighbor_index_list in enumerate(self.interval_neighbor_index_list): 527 padded_active_class_bitmask_t[ti] = np.any(active_class_bitmask_t[neighbor_index_list]) 528 self.active_class_bitmask_t.set_value( 529 padded_active_class_bitmask_t, borrow=config.borrow_numpy) 530 531 @staticmethod 532 def _get_interval_neighbor_index_list(interval_list: List[Interval], 533 maximum_neighbor_distance: int) -> List[List[int]]: 534 """Pads a given interval list, finds the index of overlapping neighbors, and returns a list of indices of 535 overlapping neighbors. 536 537 Note: 538 It is assumed that the `interval_list` is sorted (this is not asserted). 539 540 Args: 541 interval_list: list of intervals 542 maximum_neighbor_distance: Maximum distance between intervals to be considered neighbors 543 544 Returns: 545 A list of indices of overlapping neighbors with the same length as `interval_list`. Each element 546 in a variable-length list, depending on the number of neighbors. 547 """ 548 assert maximum_neighbor_distance >= 0 549 num_intervals = len(interval_list) 550 padded_interval_list = [interval.get_padded(maximum_neighbor_distance) for interval in interval_list] 551 interval_neighbor_index_list = [] 552 for ti, padded_interval in enumerate(padded_interval_list): 553 overlapping_interval_indices = [ti] 554 right_ti = ti 555 while right_ti < num_intervals - 1: 556 right_ti += 1 557 if interval_list[right_ti].overlaps_with(padded_interval): 558 overlapping_interval_indices.append(right_ti) 559 else: 560 break 561 left_ti = ti 562 while left_ti > 0: 563 left_ti -= 1 564 if interval_list[left_ti].overlaps_with(padded_interval): 565 overlapping_interval_indices.append(left_ti) 566 else: 567 break 568 interval_neighbor_index_list.append(overlapping_interval_indices) 569 return interval_neighbor_index_list 570 571 @staticmethod 572 def _get_log_trans_tkk(dist_t: np.ndarray, 573 class_coherence_length: float, 574 num_copy_number_classes: int, 575 class_probs_k: np.ndarray) -> np.ndarray: 576 """Calculates the log transition probability between copy number classes.""" 577 class_stay_prob_t = np.exp(-dist_t / class_coherence_length) 578 class_not_stay_prob_t = np.ones_like(class_stay_prob_t) - class_stay_prob_t 579 delta_kl = np.eye(num_copy_number_classes, dtype=types.floatX) 580 trans_tkl = (class_not_stay_prob_t[:, None, None] * class_probs_k[None, None, :] 581 + class_stay_prob_t[:, None, None] * delta_kl[None, :, :]) 582 return np.log(trans_tkl) 583 584 @staticmethod 585 def _create_sparse_gc_bin_tensor_tg(interval_list: List[Interval], num_gc_bins: int) -> tst.SparseConstant: 586 """Creates a sparse 2d theano tensor with shape (num_intervals, gc_bin). The sparse 587 tensor represents a 1-hot mapping of each interval to its GC bin index. The range [0, 1] 588 is uniformly divided into num_gc_bins. 589 """ 590 assert all([GCContentAnnotation.get_key() in interval.annotations.keys() for interval in interval_list]), \ 591 "Explicit GC bias modeling is enabled, however, some or all intervals lack \"{0}\" annotation".format( 592 GCContentAnnotation.get_key()) 593 594 def get_gc_bin_idx(gc_content): 595 return min(int(gc_content * num_gc_bins), num_gc_bins - 1) 596 597 num_intervals = len(interval_list) 598 data = np.ones((num_intervals,)) 599 indices = [get_gc_bin_idx(interval.get_annotation(GCContentAnnotation.get_key())) 600 for interval in interval_list] 601 indptr = np.arange(0, num_intervals + 1) 602 scipy_gc_matrix = sp.csr_matrix((data, indices, indptr), shape=(num_intervals, num_gc_bins), 603 dtype=types.small_uint) 604 theano_gc_matrix: tst.SparseConstant = tst.as_sparse(scipy_gc_matrix) 605 return theano_gc_matrix 606 607 @staticmethod 608 def _get_baseline_copy_number_and_read_depth(sample_metadata_collection: SampleMetadataCollection, 609 sample_names: List[str], 610 contig_list: List[str]) \ 611 -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 612 """Generates global read depth array, average ploidy array, and baseline copy numbers for all 613 samples. 614 615 Args: 616 sample_metadata_collection: a instance of `SampleMetadataCollection` containing required metadata 617 for all samples in `sample_names` 618 sample_names: list of sample names 619 contig_list: list of contigs appearing in the modeling interval list 620 621 Returns: 622 global read depth, average ploudy, baseline copy number 623 """ 624 assert sample_metadata_collection.all_samples_have_read_depth_metadata(sample_names), \ 625 "Some samples do not have read depth metadata" 626 assert sample_metadata_collection.all_samples_have_ploidy_metadata(sample_names), \ 627 "Some samples do not have ploidy metadata" 628 num_samples = len(sample_names) 629 num_contigs = len(contig_list) 630 631 global_read_depth_s = np.zeros((num_samples,), dtype=types.floatX) 632 average_ploidy_s = np.zeros((num_samples,), dtype=types.floatX) 633 baseline_copy_number_sj = np.zeros((num_samples, num_contigs), dtype=types.small_uint) 634 635 for si, sample_name in enumerate(sample_names): 636 sample_read_depth_metadata = sample_metadata_collection.get_sample_read_depth_metadata(sample_name) 637 sample_ploidy_metadata = sample_metadata_collection.get_sample_ploidy_metadata(sample_name) 638 639 global_read_depth_s[si] = sample_read_depth_metadata.global_read_depth 640 average_ploidy_s[si] = sample_read_depth_metadata.average_ploidy 641 sample_baseline_copy_number_j = np.asarray([sample_ploidy_metadata.get_contig_ploidy(contig) 642 for contig in contig_list], dtype=types.small_uint) 643 baseline_copy_number_sj[si, :] = sample_baseline_copy_number_j[:] 644 645 return global_read_depth_s, average_ploidy_s, baseline_copy_number_sj 646 647 648class InitialModelParametersSupplier: 649 """Base class for suppliers of initial global model parameters""" 650 def __init__(self, 651 denoising_model_config: DenoisingModelConfig, 652 calling_config: CopyNumberCallingConfig, 653 shared_workspace: DenoisingCallingWorkspace): 654 self.denoising_model_config = denoising_model_config 655 self.calling_config = calling_config 656 self.shared_workspace = shared_workspace 657 658 @abstractmethod 659 def get_init_psi_t(self) -> np.ndarray: 660 """Initial interval-specific unexplained variance.""" 661 raise NotImplementedError 662 663 @abstractmethod 664 def get_init_log_mean_bias_t(self) -> np.ndarray: 665 """Initial mean bias in log space.""" 666 raise NotImplementedError 667 668 @abstractmethod 669 def get_init_ard_u(self) -> np.ndarray: 670 """Initial ARD prior precisions.""" 671 raise NotImplementedError 672 673 674class TrivialInitialModelParametersSupplier(InitialModelParametersSupplier): 675 """Trivial initial model supplier.""" 676 def __init__(self, 677 denoising_model_config: DenoisingModelConfig, 678 calling_config: CopyNumberCallingConfig, 679 shared_workspace: DenoisingCallingWorkspace): 680 super().__init__(denoising_model_config, calling_config, shared_workspace) 681 682 def get_init_psi_t(self) -> np.ndarray: 683 return self.denoising_model_config.psi_t_scale * np.ones( 684 (self.shared_workspace.num_intervals,), dtype=types.floatX) 685 686 def get_init_log_mean_bias_t(self) -> np.ndarray: 687 return np.zeros((self.shared_workspace.num_intervals,), dtype=types.floatX) 688 689 def get_init_ard_u(self) -> np.ndarray: 690 fact = self.denoising_model_config.psi_t_scale * self.denoising_model_config.init_ard_rel_unexplained_variance 691 return fact * np.ones((self.denoising_model_config.max_bias_factors,), dtype=types.floatX) 692 693 694class DenoisingModel(GeneralizedContinuousModel): 695 """The gCNV coverage denoising model declaration (continuous RVs only; discrete posteriors are assumed 696 to be given).""" 697 def __init__(self, 698 denoising_model_config: DenoisingModelConfig, 699 shared_workspace: DenoisingCallingWorkspace, 700 initial_model_parameters_supplier: InitialModelParametersSupplier): 701 super().__init__() 702 self.shared_workspace = shared_workspace 703 register_as_global = self.register_as_global 704 register_as_sample_specific = self.register_as_sample_specific 705 706 eps_mapping = denoising_model_config.mapping_error_rate 707 708 # interval-specific unexplained variance 709 psi_t = Exponential(name='psi_t', lam=1.0 / denoising_model_config.psi_t_scale, 710 shape=(shared_workspace.num_intervals,), 711 broadcastable=(False,)) 712 register_as_global(psi_t) 713 714 # sample-specific unexplained variance 715 psi_s = Exponential(name='psi_s', lam=1.0 / denoising_model_config.psi_s_scale, 716 shape=(shared_workspace.num_samples,), 717 broadcastable=(False,)) 718 register_as_sample_specific(psi_s, sample_axis=0) 719 720 # convert "unexplained variance" to negative binomial over-dispersion 721 alpha_st = tt.maximum(tt.inv(tt.exp(psi_t.dimshuffle('x', 0) + psi_s.dimshuffle(0, 'x')) - 1.0), 722 _eps) 723 724 # interval-specific mean log bias 725 log_mean_bias_t = Normal(name='log_mean_bias_t', mu=0.0, sd=denoising_model_config.log_mean_bias_std, 726 shape=(shared_workspace.num_intervals,), 727 broadcastable=(False,), 728 testval=initial_model_parameters_supplier.get_init_log_mean_bias_t()) 729 register_as_global(log_mean_bias_t) 730 731 # log-normal read depth centered at the global read depth 732 read_depth_mu_s = (np.log(shared_workspace.global_read_depth_s) 733 - 0.5 / denoising_model_config.depth_correction_tau) 734 read_depth_s = Lognormal(name='read_depth_s', 735 mu=read_depth_mu_s, 736 tau=denoising_model_config.depth_correction_tau, 737 shape=(shared_workspace.num_samples,), 738 broadcastable=(False,), 739 testval=shared_workspace.global_read_depth_s) 740 register_as_sample_specific(read_depth_s, sample_axis=0) 741 742 # log bias modelling, starting with the log mean bias 743 log_bias_st = tt.tile(log_mean_bias_t, (shared_workspace.num_samples, 1)) 744 745 if denoising_model_config.enable_bias_factors: 746 # ARD prior precisions 747 ard_u = HalfFlat(name='ard_u', 748 shape=(denoising_model_config.max_bias_factors,), 749 broadcastable=(False,), 750 testval=initial_model_parameters_supplier.get_init_ard_u()) 751 register_as_global(ard_u) 752 753 # bias factors 754 W_tu = Normal(name='W_tu', mu=0.0, tau=ard_u.dimshuffle('x', 0), 755 shape=(shared_workspace.num_intervals, denoising_model_config.max_bias_factors), 756 broadcastable=(False, False)) 757 register_as_global(W_tu) 758 759 # sample-specific bias factor loadings 760 z_su = Normal(name='z_su', mu=0.0, sd=1.0, 761 shape=(shared_workspace.num_samples, denoising_model_config.max_bias_factors), 762 broadcastable=(False, False)) 763 register_as_sample_specific(z_su, sample_axis=0) 764 765 # add contribution to total log bias 766 if denoising_model_config.disable_bias_factors_in_active_class: 767 prob_silent_class_t = tt.exp(shared_workspace.log_q_tau_tk[:, 0]) 768 log_bias_st += (prob_silent_class_t.dimshuffle('x', 0) * tt.dot(W_tu, z_su.T).T) 769 else: 770 log_bias_st += tt.dot(W_tu, z_su.T).T 771 772 # GC bias 773 if denoising_model_config.enable_explicit_gc_bias_modeling: 774 # sample-specific GC bias factor loadings 775 z_sg = Normal(name='z_sg', mu=0.0, sd=denoising_model_config.gc_curve_sd, 776 shape=(shared_workspace.num_samples, denoising_model_config.num_gc_bins), 777 broadcastable=(False, False)) 778 register_as_sample_specific(z_sg, sample_axis=0) 779 780 # add contribution to total log bias 781 log_bias_st += tst.dot(shared_workspace.W_gc_tg, z_sg.T).T 782 783 # useful expressions 784 bias_st = tt.exp(log_bias_st) 785 786 # the expected number of erroneously mapped reads 787 mean_mapping_error_correction_s = eps_mapping * read_depth_s * shared_workspace.average_ploidy_s 788 789 denoised_copy_ratio_st = ((shared_workspace.n_st - mean_mapping_error_correction_s.dimshuffle(0, 'x')) 790 / ((1.0 - eps_mapping) * read_depth_s.dimshuffle(0, 'x') * bias_st)) 791 792 Deterministic(name='denoised_copy_ratio_st', var=denoised_copy_ratio_st) 793 794 mu_stc = ((1.0 - eps_mapping) * read_depth_s.dimshuffle(0, 'x', 'x') 795 * bias_st.dimshuffle(0, 1, 'x') 796 * shared_workspace.copy_number_values_c.dimshuffle('x', 'x', 0) 797 + mean_mapping_error_correction_s.dimshuffle(0, 'x', 'x')) 798 799 Deterministic(name='log_copy_number_emission_stc', 800 var=commons.negative_binomial_logp( 801 mu_stc, alpha_st.dimshuffle(0, 1, 'x'), shared_workspace.n_st.dimshuffle(0, 1, 'x'))) 802 803 # n_st (observed) 804 if denoising_model_config.q_c_expectation_mode == 'map': 805 def _copy_number_emission_logp(_n_st): 806 mu_st = ((1.0 - eps_mapping) * read_depth_s.dimshuffle(0, 'x') * bias_st 807 * shared_workspace.c_map_st + mean_mapping_error_correction_s.dimshuffle(0, 'x')) 808 log_copy_number_emission_st = commons.negative_binomial_logp( 809 mu_st, alpha_st, _n_st) 810 return log_copy_number_emission_st 811 812 elif denoising_model_config.q_c_expectation_mode == 'exact': 813 def _copy_number_emission_logp(_n_st): 814 _log_copy_number_emission_stc = commons.negative_binomial_logp( 815 mu_stc, 816 alpha_st.dimshuffle(0, 1, 'x'), 817 _n_st.dimshuffle(0, 1, 'x')) 818 log_q_c_stc = shared_workspace.log_q_c_stc 819 q_c_stc = tt.exp(log_q_c_stc) 820 return tt.sum(q_c_stc * (_log_copy_number_emission_stc - log_q_c_stc), axis=2) 821 822 elif denoising_model_config.q_c_expectation_mode == 'hybrid': 823 def _copy_number_emission_logp(_n_st): 824 active_class_bitmask_t = self.shared_workspace.active_class_bitmask_t 825 active_class_indices = active_class_bitmask_t.nonzero()[0] 826 silent_class_indices = (1 - active_class_bitmask_t).nonzero()[0] 827 828 # for CNV-active classes, calculate exact posterior expectation 829 mu_active_stc = ((1.0 - eps_mapping) * read_depth_s.dimshuffle(0, 'x', 'x') 830 * bias_st.dimshuffle(0, 1, 'x')[:, active_class_indices, :] 831 * shared_workspace.copy_number_values_c.dimshuffle('x', 'x', 0) 832 + mean_mapping_error_correction_s.dimshuffle(0, 'x', 'x')) 833 alpha_active_stc = tt.maximum(tt.inv((tt.exp(psi_t.dimshuffle('x', 0)[:, active_class_indices] 834 + psi_s.dimshuffle(0, 'x')) - 1.0)).dimshuffle(0, 1, 'x'), 835 _eps) 836 n_active_stc = _n_st.dimshuffle(0, 1, 'x')[:, active_class_indices, :] 837 active_class_logp_stc = commons.negative_binomial_logp(mu_active_stc, alpha_active_stc, n_active_stc) 838 log_q_c_active_stc = shared_workspace.log_q_c_stc[:, active_class_indices, :] 839 q_c_active_stc = tt.exp(log_q_c_active_stc) 840 active_class_logp = tt.sum(q_c_active_stc * (active_class_logp_stc - log_q_c_active_stc)) 841 842 # for CNV-silent classes, use MAP copy number state 843 mu_silent_st = ((1.0 - eps_mapping) * read_depth_s.dimshuffle(0, 'x') * bias_st[:, silent_class_indices] 844 * shared_workspace.c_map_st[:, silent_class_indices] 845 + mean_mapping_error_correction_s.dimshuffle(0, 'x')) 846 alpha_silent_st = alpha_st[:, silent_class_indices] 847 n_silent_st = _n_st[:, silent_class_indices] 848 silent_class_logp = tt.sum(commons.negative_binomial_logp(mu_silent_st, alpha_silent_st, n_silent_st)) 849 850 return active_class_logp + silent_class_logp 851 852 elif denoising_model_config.q_c_expectation_mode == 'marginalize': 853 def _copy_number_emission_logp(_n_st): 854 _log_copy_number_emission_stc = commons.negative_binomial_logp( 855 mu_stc, 856 alpha_st.dimshuffle(0, 1, 'x'), 857 _n_st.dimshuffle(0, 1, 'x')) 858 return pm.math.logsumexp(shared_workspace.log_q_c_stc + _log_copy_number_emission_stc, axis=2) 859 860 else: 861 raise Exception("Unknown q_c expectation mode; an exception should have been raised earlier") 862 863 DensityDist(name='n_st_obs', 864 logp=_copy_number_emission_logp, 865 observed=shared_workspace.n_st) 866 867 868class CopyNumberEmissionBasicSampler: 869 """Draws posterior samples from log copy number emission probabilities for a given variational 870 approximation to the denoising model continuous RVs.""" 871 def __init__(self, 872 denoising_model_config: DenoisingModelConfig, 873 calling_config: CopyNumberCallingConfig, 874 inference_params: HybridInferenceParameters, 875 shared_workspace: DenoisingCallingWorkspace, 876 denoising_model: DenoisingModel): 877 self.model_config = denoising_model_config 878 self.calling_config = calling_config 879 self.inference_params = inference_params 880 self.shared_workspace = shared_workspace 881 self.denoising_model = denoising_model 882 self._simultaneous_log_copy_number_emission_sampler = None 883 884 def update_approximation(self, approx: pm.approximations.MeanField): 885 """Generates a new compiled sampler based on a given approximation. 886 Args: 887 approx: an instance of PyMC3 mean-field approximation 888 889 Returns: 890 None 891 """ 892 self._simultaneous_log_copy_number_emission_sampler = \ 893 self._get_compiled_simultaneous_log_copy_number_emission_sampler(approx) 894 895 @property 896 def is_sampler_initialized(self): 897 return self._simultaneous_log_copy_number_emission_sampler is not None 898 899 def draw(self) -> np.ndarray: 900 assert self.is_sampler_initialized, "Posterior approximation is not provided yet" 901 return self._simultaneous_log_copy_number_emission_sampler() 902 903 @th.configparser.change_flags(compute_test_value="off") 904 def _get_compiled_simultaneous_log_copy_number_emission_sampler(self, approx: pm.approximations.MeanField): 905 """For a given variational approximation, returns a compiled theano function that draws posterior samples 906 from log copy number emission probabilities.""" 907 log_copy_number_emission_stc = commons.stochastic_node_mean_symbolic( 908 approx, self.denoising_model['log_copy_number_emission_stc'], 909 size=self.inference_params.log_emission_samples_per_round) 910 return th.function(inputs=[], outputs=log_copy_number_emission_stc) 911 912 913class HHMMClassAndCopyNumberBasicCaller: 914 """This class updates copy number and interval class posteriors according to the following hierarchical 915 hidden Markov model: 916 917 class_prior_k --> (tau_1) --> (tau_2) --> (tau_3) --> ... 918 | | | 919 | | | 920 v v v 921 (c_s1) --> (c_s2) --> (c_s3) --> ... 922 | | | 923 | | | 924 v v v 925 n_s1 n_s2 n_s3 926 927 The posterior probability of `tau` and `c_s`, q(tau) and q(c_s) respectively, are obtained via 928 the following variational ansatz: 929 930 \prod_s p(tau, c_s | n) ~ q(tau) \prod_s q(c_s), 931 932 where correlations between intervals are preserved in both chains, however, cross-correlations 933 between `tau` and `c` are neglected, including correlations induced between copy numbers of 934 different samples. As usual, the posteriors are determined by minimizing the KL divergence w.r.t. 935 the true posterior resulting in the following iterative scheme: 936 937 - Given q(tau), the effective copy number prior for the first interval and the effective copy number 938 transition probabilities are determined (see _get_update_copy_number_hmm_specs_compiled_function). 939 Along with the given emission probabilities to sample read counts, q(c_s) is updated using the 940 forward-backward algorithm for each sample (see _update_copy_number_log_posterior) 941 942 - Given q(c_s), the emission probability of each copy number class (tau) is determined 943 (see _get_update_log_class_emission_tk_theano_func). The class prior and transition probabilities 944 are fixed hyperparameters. Therefore, q(tau) can be updated immediately using a single run 945 of forward-backward algorithm (see _update_class_log_posterior). 946 """ 947 CopyNumberForwardBackwardResult = collections.namedtuple( 948 'CopyNumberForwardBackwardResult', 949 'sample_index, new_log_posterior_tc, copy_number_update_size, log_likelihood') 950 951 def __init__(self, 952 calling_config: CopyNumberCallingConfig, 953 inference_params: HybridInferenceParameters, 954 shared_workspace: DenoisingCallingWorkspace, 955 disable_class_update: bool, 956 temperature: types.TensorSharedVariable): 957 self.calling_config = calling_config 958 self.inference_params = inference_params 959 self.shared_workspace = shared_workspace 960 self.disable_class_update = disable_class_update 961 self.temperature = temperature 962 963 # generate the 2-class inventory of copy number priors (CNV-silent, CNV-active) for all samples 964 # according to their respective germline contig ploidies 965 pi_sjkc = np.zeros((shared_workspace.num_samples, 966 shared_workspace.num_contigs, 967 calling_config.num_copy_number_classes, 968 calling_config.num_copy_number_states), dtype=types.floatX) 969 for si in range(shared_workspace.num_samples): 970 pi_sjkc[si, :, :, :] = self.get_copy_number_prior_for_sample_jkc( 971 calling_config.num_copy_number_states, 972 calling_config.p_alt, 973 shared_workspace.baseline_copy_number_sj[si, :])[:, :, :] 974 self.pi_sjkc: types.TensorSharedVariable = th.shared(pi_sjkc, name='pi_sjkc', borrow=config.borrow_numpy) 975 976 # compiled function for forward-backward updates of copy number posterior 977 self._hmm_q_copy_number = TheanoForwardBackward( 978 log_posterior_probs_output_tc=None, 979 resolve_nans=False, 980 do_thermalization=True, 981 do_admixing=True, 982 include_update_size_output=True, 983 include_alpha_beta_output=False) 984 985 if not disable_class_update: 986 # compiled function for forward-backward update of class posterior 987 # Note: 988 # if p_active == 0, we have to deal with inf - inf expressions properly. 989 # setting resolve_nans = True takes care of such ambiguities. 990 self._hmm_q_class = TheanoForwardBackward( 991 log_posterior_probs_output_tc=shared_workspace.log_q_tau_tk, 992 resolve_nans=(calling_config.p_active == 0), 993 do_thermalization=True, 994 do_admixing=True, 995 include_update_size_output=True, 996 include_alpha_beta_output=False) 997 998 # compiled function for update of class log emission 999 self._update_log_class_emission_tk_theano_func = self._get_update_log_class_emission_tk_theano_func() 1000 else: 1001 self._hmm_q_class: Optional[TheanoForwardBackward] = None 1002 self._update_log_class_emission_tk_theano_func = None 1003 1004 # compiled function for variational update of copy number HMM specs 1005 self._get_copy_number_hmm_specs_theano_func = self.get_compiled_copy_number_hmm_specs_theano_func() 1006 1007 @staticmethod 1008 def get_copy_number_prior_for_sample_jkc(num_copy_number_states: int, 1009 p_alt: float, 1010 baseline_copy_number_j: np.ndarray) -> np.ndarray: 1011 """Returns copy-number prior probabilities for each contig (j) and class (k) as a 3d ndarray. 1012 1013 Args: 1014 num_copy_number_states: total number of copy-number states 1015 p_alt: total probability of alt copy-number states 1016 baseline_copy_number_j: baseline copy-number state for each contig 1017 1018 Returns: 1019 a 3d ndarray 1020 """ 1021 p_baseline = 1.0 - (num_copy_number_states - 1) * p_alt 1022 pi_jkc = np.zeros((len(baseline_copy_number_j), 2, num_copy_number_states), dtype=types.floatX) 1023 for j, baseline_state in enumerate(baseline_copy_number_j): 1024 # the silent class 1025 pi_jkc[j, 0, :] = p_alt 1026 pi_jkc[j, 0, baseline_state] = p_baseline 1027 # the active class 1028 pi_jkc[j, 1, :] = 1.0 / num_copy_number_states 1029 1030 return pi_jkc 1031 1032 def call(self, 1033 copy_number_update_summary_statistic_reducer, 1034 class_update_summary_statistic_reducer) -> Tuple[np.ndarray, np.ndarray, float, float]: 1035 """Perform a round of update of q(tau) and q(c) 1036 1037 Note: 1038 This function must be called until q(tau) and q(c) converge to a self-consistent solution. 1039 1040 Args: 1041 copy_number_update_summary_statistic_reducer: a function that reduces vectors to scalars and 1042 is used to compile a summary of copy number posterior updates across intervals for each sample 1043 class_update_summary_statistic_reducer: a function that reduces vectors to scalars and 1044 is used to compile a summary of interval class posterior updates across intervals 1045 1046 Returns: 1047 copy number update summary (ndarray of size `num_samples`), 1048 copy number Markov chain log likelihoods (ndarray of size `num_samples`), 1049 interval class update summary, 1050 interval class Markov chain log likelihood 1051 """ 1052 # copy number posterior update 1053 copy_number_update_s, copy_number_log_likelihoods_s = self._update_copy_number_log_posterior( 1054 copy_number_update_summary_statistic_reducer) 1055 1056 if not self.disable_class_update: 1057 # class posterior update 1058 self._update_log_class_emission_tk() 1059 class_update, class_log_likelihood = self._update_class_log_posterior( 1060 class_update_summary_statistic_reducer) 1061 else: 1062 class_update = None 1063 class_log_likelihood = None 1064 1065 return copy_number_update_s, copy_number_log_likelihoods_s, class_update, class_log_likelihood 1066 1067 def _update_copy_number_log_posterior(self, copy_number_update_summary_statistic_reducer) \ 1068 -> Tuple[np.ndarray, np.ndarray]: 1069 ws = self.shared_workspace 1070 copy_number_update_s = np.zeros((ws.num_samples,), dtype=types.floatX) 1071 copy_number_log_likelihoods_s = np.zeros((ws.num_samples,), dtype=types.floatX) 1072 num_calling_processes = self.calling_config.num_calling_processes 1073 1074 def _run_single_sample_fb(_sample_index: int): 1075 # step 1. calculate copy-number HMM log prior and log transition matrix 1076 pi_jkc = self.pi_sjkc.get_value(borrow=True)[_sample_index, ...] 1077 cnv_stay_prob_t = self.shared_workspace.cnv_stay_prob_t.get_value(borrow=True) 1078 log_q_tau_tk = self.shared_workspace.log_q_tau_tk.get_value(borrow=True) 1079 t_to_j_map = self.shared_workspace.t_to_j_map.get_value(borrow=True) 1080 hmm_spec = self._get_copy_number_hmm_specs_theano_func(pi_jkc, cnv_stay_prob_t, log_q_tau_tk, t_to_j_map) 1081 log_prior_c = hmm_spec[0] 1082 log_trans_tcc = hmm_spec[1] 1083 1084 prev_log_posterior_tc = ws.log_q_c_stc.get_value(borrow=True)[_sample_index, ...] 1085 log_copy_number_emission_tc = ws.log_copy_number_emission_stc.get_value(borrow=True)[_sample_index, ...] 1086 1087 # step 2. run forward-backward and update copy-number posteriors 1088 _fb_result = self._hmm_q_copy_number.perform_forward_backward( 1089 log_prior_c, log_trans_tcc, log_copy_number_emission_tc, 1090 prev_log_posterior_tc=prev_log_posterior_tc, 1091 admixing_rate=self.inference_params.caller_internal_admixing_rate, 1092 temperature=self.temperature.get_value()[0]) 1093 new_log_posterior_tc = _fb_result.log_posterior_probs_tc 1094 copy_number_update_size = copy_number_update_summary_statistic_reducer(_fb_result.update_norm_t) 1095 log_likelihood = float(_fb_result.log_data_likelihood) 1096 1097 return self.CopyNumberForwardBackwardResult( 1098 _sample_index, new_log_posterior_tc, copy_number_update_size, log_likelihood) 1099 1100 def _update_log_q_c_stc_inplace(log_q_c_stc, _sample_index, new_log_posterior_tc): 1101 log_q_c_stc[_sample_index, :, :] = new_log_posterior_tc[:, :] 1102 return log_q_c_stc 1103 1104 max_chunks = ws.num_samples // num_calling_processes + 1 1105 for chunk_index in range(max_chunks): 1106 begin_index = chunk_index * num_calling_processes 1107 end_index = min((chunk_index + 1) * num_calling_processes, ws.num_samples) 1108 if begin_index >= ws.num_samples: 1109 break 1110 # todo multiprocessing 1111 # with mp.Pool(processes=num_calling_processes) as pool: 1112 # for fb_result in pool.map(_run_single_sample_fb, range(begin_index, end_index)): 1113 for fb_result in [_run_single_sample_fb(sample_index) 1114 for sample_index in range(begin_index, end_index)]: 1115 # update log posterior in the workspace 1116 ws.log_q_c_stc.set_value( 1117 _update_log_q_c_stc_inplace( 1118 ws.log_q_c_stc.get_value(borrow=True), 1119 fb_result.sample_index, fb_result.new_log_posterior_tc), 1120 borrow=True) 1121 # update summary stats 1122 copy_number_update_s[fb_result.sample_index] = fb_result.copy_number_update_size 1123 copy_number_log_likelihoods_s[fb_result.sample_index] = fb_result.log_likelihood 1124 1125 return copy_number_update_s, copy_number_log_likelihoods_s 1126 1127 def _update_log_class_emission_tk(self): 1128 self._update_log_class_emission_tk_theano_func() 1129 1130 def _update_class_log_posterior(self, class_update_summary_statistic_reducer) -> Tuple[float, float]: 1131 fb_result = self._hmm_q_class.perform_forward_backward( 1132 self.shared_workspace.log_prior_k, 1133 self.shared_workspace.log_trans_tkk, 1134 self.shared_workspace.log_class_emission_tk.get_value(borrow=True), 1135 prev_log_posterior_tc=self.shared_workspace.log_q_tau_tk.get_value(borrow=True), 1136 admixing_rate=self.inference_params.caller_internal_admixing_rate, 1137 temperature=self.temperature.get_value()[0]) 1138 class_update_size = class_update_summary_statistic_reducer(fb_result.update_norm_t) 1139 log_likelihood = float(fb_result.log_data_likelihood) 1140 return class_update_size, log_likelihood 1141 1142 def update_auxiliary_vars(self): 1143 self.shared_workspace.update_auxiliary_vars() 1144 1145 @staticmethod 1146 @th.configparser.change_flags(compute_test_value="off") 1147 def get_compiled_copy_number_hmm_specs_theano_func() -> th.compile.function_module.Function: 1148 """Returns a compiled function that calculates the interval-class-averaged and probability-sum-normalized 1149 log copy number transition matrix and log copy number prior for the first interval 1150 1151 Returned theano function inputs: 1152 pi_jkc: a 3d tensor containing copy-number priors for each contig (j) and each class (k) 1153 cnv_stay_prob_t: probability of staying on the same copy-number state at interval `t` 1154 log_q_tau_tk: log probability of copy-number classes at interval `t` 1155 t_to_j_map: a mapping from interval indices (t) to contig indices (j); it is used to unpack 1156 `pi_jkc` to `pi_tkc` (see below) 1157 1158 Returned theano function outputs: 1159 log_prior_c_first_interval: log probability of copy-number states for the first interval 1160 log_trans_tab: log transition probability matrix from interval `t` to interval `t+1` 1161 1162 Note: 1163 In the following, we use "a" and "b" subscripts in the variable names to refer to the departure 1164 and destination states, respectively. Like before, "t" and "k" denote interval and class, and "j" 1165 refers to contig index. 1166 """ 1167 # shorthands 1168 pi_jkc = tt.tensor3(name='pi_jkc') 1169 cnv_stay_prob_t = tt.vector(name='cnv_stay_prob_t') 1170 log_q_tau_tk = tt.matrix(name='log_q_tau_tk') 1171 t_to_j_map = tt.vector(name='t_to_j_map', dtype=tt.scal.uint32) 1172 1173 # log prior probability for the first interval 1174 log_prior_c_first_interval = tt.dot(tt.log(pi_jkc[t_to_j_map[0], :, :].T), tt.exp(log_q_tau_tk[0, :])) 1175 log_prior_c_first_interval -= pm.logsumexp(log_prior_c_first_interval) 1176 1177 # log transition matrix 1178 cnv_not_stay_prob_t = tt.ones_like(cnv_stay_prob_t) - cnv_stay_prob_t 1179 num_copy_number_states = pi_jkc.shape[2] 1180 delta_ab = tt.eye(num_copy_number_states) 1181 1182 # map contig to interval and obtain pi_tkc for the rest of the targets 1183 pi_tkc = pi_jkc[t_to_j_map[1:], :, :] 1184 1185 # calculate normalized log transition matrix 1186 # todo use logaddexp 1187 log_trans_tkab = tt.log(cnv_not_stay_prob_t.dimshuffle(0, 'x', 'x', 'x') * pi_tkc.dimshuffle(0, 1, 'x', 2) 1188 + cnv_stay_prob_t.dimshuffle(0, 'x', 'x', 'x') * delta_ab.dimshuffle('x', 'x', 0, 1)) 1189 q_tau_tkab = tt.exp(log_q_tau_tk[1:, :]).dimshuffle(0, 1, 'x', 'x') 1190 log_trans_tab = tt.sum(q_tau_tkab * log_trans_tkab, axis=1) 1191 log_trans_tab -= pm.logsumexp(log_trans_tab, axis=2) 1192 1193 inputs = [pi_jkc, cnv_stay_prob_t, log_q_tau_tk, t_to_j_map] 1194 outputs = [log_prior_c_first_interval, log_trans_tab] 1195 1196 return th.function(inputs=inputs, outputs=outputs) 1197 1198 @th.configparser.change_flags(compute_test_value="off") 1199 def _get_update_log_class_emission_tk_theano_func(self) -> th.compile.function_module.Function: 1200 """Returns a compiled function that calculates the log interval class emission probability and 1201 directly updates `log_class_emission_tk` in the workspace. 1202 1203 Note: 1204 In the following, 1205 1206 xi_tab ~ posterior copy number probability of two subsequent intervals. 1207 1208 We ignore correlations, i.e. we assume: 1209 1210 xi_st(a, b) \equiv q_c(c_{s,t} = a, c_{s,t+1} = b) 1211 \approx q_c(c_{s,t} = a) q_c(c_{s,t+1} = b) 1212 1213 If needed, xi can be calculated exactly from the forward-backward tables. 1214 """ 1215 # shorthands 1216 cnv_stay_prob_t = self.shared_workspace.cnv_stay_prob_t 1217 q_c_stc = tt.exp(self.shared_workspace.log_q_c_stc) 1218 pi_sjkc = self.pi_sjkc 1219 t_to_j_map = self.shared_workspace.t_to_j_map 1220 num_copy_number_states = self.calling_config.num_copy_number_states 1221 1222 # log copy number transition matrix for each class 1223 cnv_not_stay_prob_t = tt.ones_like(cnv_stay_prob_t) - cnv_stay_prob_t 1224 delta_ab = tt.eye(num_copy_number_states) 1225 1226 # calculate log class emission by reducing over samples; see below 1227 log_class_emission_cum_sum_tk = tt.zeros((self.shared_workspace.num_intervals - 1, 1228 self.calling_config.num_copy_number_classes), 1229 dtype=types.floatX) 1230 1231 def inc_log_class_emission_tk_except_for_first_interval(pi_jkc, q_c_tc, cum_sum_tk): 1232 """Adds the contribution of a given sample to the log class emission (symbolically). 1233 1234 Args: 1235 pi_jkc: copy number prior inventory for the sample 1236 q_c_tc: copy number posteriors for the sample 1237 cum_sum_tk: current cumulative sum of log class emission 1238 1239 Returns: 1240 Symbolically updated cumulative sum of log class emission 1241 """ 1242 # map contigs to targets (starting from the second interval) 1243 pi_tkc = pi_jkc[t_to_j_map[1:], :, :] 1244 1245 # todo use logaddexp 1246 log_trans_tkab = tt.log( 1247 cnv_not_stay_prob_t.dimshuffle(0, 'x', 'x', 'x') * pi_tkc.dimshuffle(0, 1, 'x', 2) 1248 + cnv_stay_prob_t.dimshuffle(0, 'x', 'x', 'x') * delta_ab.dimshuffle('x', 'x', 0, 1)) 1249 xi_tab = q_c_tc[:-1, :].dimshuffle(0, 1, 'x') * q_c_tc[1:, :].dimshuffle(0, 'x', 1) 1250 current_log_class_emission_tk = tt.sum(tt.sum( 1251 xi_tab.dimshuffle(0, 'x', 1, 2) * log_trans_tkab, axis=-1), axis=-1) 1252 return cum_sum_tk + current_log_class_emission_tk 1253 1254 reduce_output = th.reduce(inc_log_class_emission_tk_except_for_first_interval, 1255 sequences=[pi_sjkc, q_c_stc], 1256 outputs_info=[log_class_emission_cum_sum_tk]) 1257 log_class_emission_tk_except_for_first_interval = reduce_output[0] 1258 1259 # the first interval 1260 pi_skc_first = pi_sjkc[:, t_to_j_map[0], :, :] 1261 q_skc_first = q_c_stc[:, 0, :].dimshuffle(0, 'x', 1) 1262 log_class_emission_k_first = tt.sum(tt.sum(tt.log(pi_skc_first) * q_skc_first, axis=0), axis=-1) 1263 1264 # concatenate first and rest 1265 log_class_emission_tk = tt.concatenate((log_class_emission_k_first.dimshuffle('x', 0), 1266 log_class_emission_tk_except_for_first_interval)) 1267 1268 return th.function(inputs=[], outputs=[], updates=[ 1269 (self.shared_workspace.log_class_emission_tk, log_class_emission_tk)]) 1270