1import argparse
2import collections
3import inspect
4import json
5import logging
6from abc import abstractmethod
7from typing import List, Tuple, Set, Dict, Optional
8
9import numpy as np
10import pymc3 as pm
11import scipy.sparse as sp
12import theano as th
13import theano.sparse as tst
14import theano.tensor as tt
15from pymc3 import Normal, Deterministic, DensityDist, Lognormal, Exponential
16
17from . import commons
18from .dists import HalfFlat
19from .fancy_model import GeneralizedContinuousModel
20from .theano_hmm import TheanoForwardBackward
21from .. import config, types
22from ..structs.interval import Interval, GCContentAnnotation
23from ..structs.metadata import SampleMetadataCollection
24from ..tasks.inference_task_base import HybridInferenceParameters
25
26_logger = logging.getLogger(__name__)
27
28_eps = commons.eps
29
30
31class DenoisingModelConfig:
32    """Configuration for the coverage denoising model, including hyper-parameters, model feature selection,
33    and choice of approximation schemes."""
34
35    # approximation schemes for calculating expectations with respect to copy number posteriors
36    _q_c_expectation_modes = ['map', 'exact', 'hybrid']
37
38    def __init__(self,
39                 max_bias_factors: int = 5,
40                 mapping_error_rate: float = 0.01,
41                 psi_t_scale: float = 0.001,
42                 psi_s_scale: float = 0.0001,
43                 depth_correction_tau: float = 10000.0,
44                 log_mean_bias_std: float = 0.1,
45                 init_ard_rel_unexplained_variance: float = 0.1,
46                 num_gc_bins: int = 20,
47                 gc_curve_sd: float = 1.0,
48                 q_c_expectation_mode: str = 'hybrid',
49                 active_class_padding_hybrid_mode: int = 50000,
50                 enable_bias_factors: bool = True,
51                 enable_explicit_gc_bias_modeling: bool = False,
52                 disable_bias_factors_in_active_class: bool = False):
53        """See `expose_args` for the description of arguments"""
54        self.max_bias_factors = max_bias_factors
55        self.mapping_error_rate = mapping_error_rate
56        self.psi_t_scale = psi_t_scale
57        self.psi_s_scale = psi_s_scale
58        self.depth_correction_tau = depth_correction_tau
59        self.log_mean_bias_std = log_mean_bias_std
60        self.init_ard_rel_unexplained_variance = init_ard_rel_unexplained_variance
61        self.num_gc_bins = num_gc_bins
62        self.gc_curve_sd = gc_curve_sd
63        self.q_c_expectation_mode = q_c_expectation_mode
64        self.active_class_padding_hybrid_mode = active_class_padding_hybrid_mode
65        self.enable_bias_factors = enable_bias_factors
66        self.enable_explicit_gc_bias_modeling = enable_explicit_gc_bias_modeling
67        self.disable_bias_factors_in_active_class = disable_bias_factors_in_active_class
68
69    @staticmethod
70    def expose_args(args: argparse.ArgumentParser,
71                    hide: Set[str] = None):
72        """Exposes arguments of `__init__` to a given instance of `ArgumentParser`.
73
74        Args:
75            args: an instance of `ArgumentParser`
76            hide: a set of arguments not to expose
77
78        Returns:
79            None
80        """
81        group = args.add_argument_group(title="Coverage denoising model parameters")
82        if hide is None:
83            hide = set()
84
85        initializer_params = inspect.signature(DenoisingModelConfig.__init__).parameters
86        valid_args = {"--" + arg for arg in initializer_params.keys()}
87        for hidden_arg in hide:
88            assert hidden_arg in valid_args, \
89                "Initializer argument to be hidden {0} is not a valid initializer arguments; possible " \
90                "choices are: {1}".format(hidden_arg, valid_args)
91
92        def process_and_maybe_add(arg, **kwargs):
93            full_arg = "--" + arg
94            if full_arg in hide:
95                return
96            kwargs['default'] = initializer_params[arg].default
97            group.add_argument(full_arg, **kwargs)
98
99        def str_to_bool(value: str):
100            if value.lower() in ('yes', 'true', 't', 'y', '1'):
101                return True
102            elif value.lower() in ('no', 'false', 'f', 'n', '0'):
103                return False
104            else:
105                raise argparse.ArgumentTypeError('Boolean value expected.')
106
107        process_and_maybe_add("max_bias_factors",
108                              type=int,
109                              help="Maximum number of bias factors")
110
111        process_and_maybe_add("mapping_error_rate",
112                              type=float,
113                              help="Typical mapping error rate")
114
115        process_and_maybe_add("psi_t_scale",
116                              type=float,
117                              help="Typical scale of interval-specific unexplained variance")
118
119        process_and_maybe_add("psi_s_scale",
120                              type=float,
121                              help="Typical scale of sample-specific unexplained variance")
122
123        process_and_maybe_add("depth_correction_tau",
124                              type=float,
125                              help="Precision of pinning read-depth in the coverage denoising model "
126                                   "to its globally determined value")
127
128        process_and_maybe_add("log_mean_bias_std",
129                              type=float,
130                              help="Standard deviation of mean bias in log space")
131
132        process_and_maybe_add("init_ard_rel_unexplained_variance",
133                              type=float,
134                              help="Initial value of automatic relevance determination (ARD) precisions relative "
135                                   "to the typical interval-specific unexplained variance scale")
136
137        process_and_maybe_add("num_gc_bins",
138                              type=int,
139                              help="Number of knobs on the GC curve")
140
141        process_and_maybe_add("gc_curve_sd",
142                              type=float,
143                              help="Prior standard deviation of the GC curve from a flat curve")
144
145        process_and_maybe_add("q_c_expectation_mode",
146                              type=str,
147                              choices=DenoisingModelConfig._q_c_expectation_modes,
148                              help="The strategy for calculating copy number posterior expectations in the denoising "
149                                   "model. Choices: \"exact\": summation over all states, \"map\": drop all terms "
150                                   "except for the maximum a posteriori (MAP) copy number estimate, \"hybrid\": "
151                                   "use MAP strategy in silent regions and exact strategy in active regions.")
152
153        process_and_maybe_add("active_class_padding_hybrid_mode",
154                              type=int,
155                              help="If q_c_expectation_mode is set to \"hybrid\", the active intervals "
156                                   "will be further padded by this value (in the units of bp) in order to achieve "
157                                   "higher sensitivity in detecting common CNVs and to avoid boundary artifacts")
158
159        process_and_maybe_add("enable_bias_factors",
160                              type=str_to_bool,
161                              help="Enable discovery of novel bias factors")
162
163        process_and_maybe_add("enable_explicit_gc_bias_modeling",
164                              type=str_to_bool,
165                              help="Enable explicit modeling of GC bias (if enabled, the provided modeling interval "
166                                   "list of contain a column for {0} values)".format(GCContentAnnotation.get_key()))
167
168        process_and_maybe_add("disable_bias_factors_in_active_class",
169                              type=str_to_bool,
170                              help="Disable novel bias factor discovery CNV-active regions")
171
172    @staticmethod
173    def from_args_dict(args_dict: Dict) -> 'DenoisingModelConfig':
174        """Initialize an instance of `DenoisingModelConfig` from a dictionary of arguments.
175
176        Args:
177            args_dict: a dictionary of arguments; the keys must match argument names in
178                `DenoisingModelConfig.__init__`
179
180        Returns:
181            an instance of `DenoisingModelConfig`
182        """
183        relevant_keys = set(inspect.getfullargspec(DenoisingModelConfig.__init__).args)
184        relevant_kwargs = {k: v for k, v in args_dict.items() if k in relevant_keys}
185        return DenoisingModelConfig(**relevant_kwargs)
186
187    @staticmethod
188    def from_json_file(json_file: str) -> 'DenoisingModelConfig':
189        with open(json_file, 'r') as fp:
190            imported_denoising_config_dict = json.load(fp)
191        return DenoisingModelConfig.from_args_dict(imported_denoising_config_dict)
192
193
194class CopyNumberCallingConfig:
195    """Configuration of the copy number caller."""
196    def __init__(self,
197                 p_alt: float = 1e-6,
198                 p_active: float = 1e-3,
199                 cnv_coherence_length: float = 10000.0,
200                 class_coherence_length: float = 10000.0,
201                 max_copy_number: int = 5,
202                 num_calling_processes: int = 1):
203        """See `expose_args` for the description of arguments"""
204        assert 0.0 <= p_alt <= 1.0
205        assert 0.0 <= p_active <= 1.0
206        assert cnv_coherence_length > 0.0
207        assert class_coherence_length > 0.0
208        assert max_copy_number > 0
209        assert max_copy_number * p_alt < 1.0
210        assert num_calling_processes > 0
211
212        self.p_alt = p_alt
213        self.p_active = p_active
214        self.cnv_coherence_length = cnv_coherence_length
215        self.class_coherence_length = class_coherence_length
216        self.max_copy_number = max_copy_number
217        self.num_calling_processes = num_calling_processes
218
219        self.num_copy_number_states = max_copy_number + 1
220        self.num_copy_number_classes = 2
221
222    @staticmethod
223    def expose_args(args: argparse.ArgumentParser, hide: Set[str] = None):
224        """Exposes arguments of `__init__` to a given instance of `ArgumentParser`.
225
226        Args:
227            args: an instance of `ArgumentParser`
228            hide: a set of arguments not to expose
229
230        Returns:
231            None
232        """
233        group = args.add_argument_group(title="Copy number calling parameters")
234        if hide is None:
235            hide = set()
236
237        initializer_params = inspect.signature(CopyNumberCallingConfig.__init__).parameters
238        valid_args = {"--" + arg for arg in initializer_params.keys()}
239        for hidden_arg in hide:
240            assert hidden_arg in valid_args, \
241                "Initializer argument to be hidden {0} is not a valid initializer arguments; possible " \
242                "choices are: {1}".format(hidden_arg, valid_args)
243
244        def process_and_maybe_add(arg, **kwargs):
245            full_arg = "--" + arg
246            if full_arg in hide:
247                return
248            kwargs['default'] = initializer_params[arg].default
249            group.add_argument(full_arg, **kwargs)
250
251        def str_to_bool(value: str):
252            if value.lower() in ('yes', 'true', 't', 'y', '1'):
253                return True
254            elif value.lower() in ('no', 'false', 'f', 'n', '0'):
255                return False
256            else:
257                raise argparse.ArgumentTypeError('Boolean value expected.')
258
259        process_and_maybe_add("p_alt",
260                              type=float,
261                              help="Prior probability of alternate copy number with respect to contig baseline "
262                                   "state in CNV-silent intervals")
263
264        process_and_maybe_add("p_active",
265                              type=float,
266                              help="Prior probability of treating an interval as CNV-active")
267
268        process_and_maybe_add("cnv_coherence_length",
269                              type=float,
270                              help="Coherence length of CNV events (in the units of bp)")
271
272        process_and_maybe_add("class_coherence_length",
273                              type=float,
274                              help="Coherence length of CNV-silent and CNV-active domains (in the units of bp)")
275
276        process_and_maybe_add("max_copy_number",
277                              type=int,
278                              help="Highest called copy number state")
279
280        process_and_maybe_add("num_calling_processes",
281                              type=int,
282                              help="Number of concurrent forward-backward threads (not implemented yet)")
283
284    @staticmethod
285    def from_args_dict(args_dict: Dict):
286        """Initialize an instance of `CopyNumberCallingConfig` from a dictionary of arguments.
287
288        Args:
289            args_dict: a dictionary of arguments; the keys must match argument names in
290                `CopyNumberCallingConfig.__init__`
291
292        Returns:
293            an instance of `CopyNumberCallingConfig`
294        """
295        relevant_keys = set(inspect.getfullargspec(CopyNumberCallingConfig.__init__).args)
296        relevant_kwargs = {k: v for k, v in args_dict.items() if k in relevant_keys}
297        return CopyNumberCallingConfig(**relevant_kwargs)
298
299    @staticmethod
300    def from_json_file(json_file: str) -> 'CopyNumberCallingConfig':
301        with open(json_file, 'r') as fp:
302            imported_calling_config_dict = json.load(fp)
303        return CopyNumberCallingConfig.from_args_dict(imported_calling_config_dict)
304
305
306class PosteriorInitializer:
307    """Base class for posterior initializers."""
308    @staticmethod
309    @abstractmethod
310    def initialize_posterior(denoising_config: DenoisingModelConfig,
311                             calling_config: CopyNumberCallingConfig,
312                             shared_workspace: 'DenoisingCallingWorkspace') -> None:
313        raise NotImplementedError
314
315
316class TrivialPosteriorInitializer(PosteriorInitializer):
317    """Initialize posteriors to reasonable values based on priors."""
318    @staticmethod
319    def initialize_posterior(denoising_config: DenoisingModelConfig,
320                             calling_config: CopyNumberCallingConfig,
321                             shared_workspace: 'DenoisingCallingWorkspace'):
322        # interval class log posterior probs
323        class_probs_k = np.asarray([1.0 - calling_config.p_active, calling_config.p_active], dtype=types.floatX)
324        log_q_tau_tk = np.tile(np.log(class_probs_k), (shared_workspace.num_intervals, 1))
325        shared_workspace.log_q_tau_tk = th.shared(log_q_tau_tk, name="log_q_tau_tk", borrow=config.borrow_numpy)
326
327        # copy number log posterior probs
328        log_q_c_stc = np.zeros((shared_workspace.num_samples, shared_workspace.num_intervals,
329                                calling_config.num_copy_number_states), dtype=types.floatX)
330        t_to_j_map = shared_workspace.t_to_j_map.get_value(borrow=True)
331        for si in range(shared_workspace.num_samples):
332            sample_baseline_copy_number_j = shared_workspace.baseline_copy_number_sj[si, :]
333            sample_pi_jkc = HHMMClassAndCopyNumberBasicCaller.get_copy_number_prior_for_sample_jkc(
334                calling_config.num_copy_number_states,
335                calling_config.p_alt,
336                sample_baseline_copy_number_j)
337            sample_log_pi_jc = np.log(np.sum(sample_pi_jkc * class_probs_k[np.newaxis, :, np.newaxis], axis=1))
338            for ti in range(shared_workspace.num_intervals):
339                log_q_c_stc[si, ti, :] = sample_log_pi_jc[t_to_j_map[ti], :]
340        shared_workspace.log_q_c_stc = th.shared(log_q_c_stc, name="log_q_c_stc", borrow=config.borrow_numpy)
341
342
343class DenoisingCallingWorkspace:
344    """This class contains objects (numpy arrays, theano tensors, etc) shared between the denoising model
345    and the copy number caller."""
346    def __init__(self,
347                 denoising_config: DenoisingModelConfig,
348                 calling_config: CopyNumberCallingConfig,
349                 interval_list: List[Interval],
350                 n_st: np.ndarray,
351                 sample_names: List[str],
352                 sample_metadata_collection: SampleMetadataCollection,
353                 posterior_initializer: Optional[PosteriorInitializer] = TrivialPosteriorInitializer):
354        self.denoising_config = denoising_config
355        self.calling_config = calling_config
356        self.interval_list = interval_list
357        self.sample_names = sample_names
358
359        assert n_st.ndim == 2, "Read counts matrix must be a 2-dim ndarray with shape (num_samples, num_intervals)"
360
361        self.num_samples: int = n_st.shape[0]
362        self.num_intervals: int = n_st.shape[1]
363
364        assert self.num_intervals >= 2, "At least two intervals must be provided"
365        assert len(interval_list) == self.num_intervals,\
366            "The length of the interval list is incompatible with the shape of the read counts matrix"
367
368        # a list of unique contigs appearing in the interval list; the ordering is arbitrary and
369        # is only used internally
370        # Note: j is the index subscript used for contig index hereafter
371        self.contig_list = list({interval.contig for interval in interval_list})
372        self.num_contigs = len(self.contig_list)
373        contig_to_j_map = {contig: self.contig_list.index(contig) for contig in self.contig_list}
374        t_to_j_map = np.asarray([contig_to_j_map[interval.contig] for interval in interval_list],
375                                dtype=types.small_uint)
376        self.t_to_j_map: types.TensorSharedVariable = th.shared(
377            t_to_j_map, name="t_to_j_map", borrow=config.borrow_numpy)
378
379        self.global_read_depth_s, average_ploidy_s, self.baseline_copy_number_sj = \
380            DenoisingCallingWorkspace._get_baseline_copy_number_and_read_depth(
381                sample_metadata_collection, sample_names, self.contig_list)
382
383        max_baseline_copy_number = np.max(self.baseline_copy_number_sj)
384        assert max_baseline_copy_number <= calling_config.max_copy_number, \
385            "The highest contig ploidy ({0}) must be smaller or equal to the highest copy number state ({1})".format(
386                max_baseline_copy_number, calling_config.max_copy_number)
387
388        # shared theano tensors from the input data
389        self.n_st: types.TensorSharedVariable = th.shared(
390            n_st.astype(types.med_uint), name="n_st", borrow=config.borrow_numpy)
391        self.average_ploidy_s: types.TensorSharedVariable = th.shared(
392            average_ploidy_s.astype(types.floatX), name="average_ploidy_s", borrow=config.borrow_numpy)
393
394        # copy-number event stay probability
395        self.dist_t = np.asarray([self.interval_list[ti + 1].distance(self.interval_list[ti])
396                                  for ti in range(self.num_intervals - 1)])
397        cnv_stay_prob_t = np.exp(-self.dist_t / calling_config.cnv_coherence_length)
398        self.cnv_stay_prob_t = th.shared(cnv_stay_prob_t, name='cnv_stay_prob_t', borrow=config.borrow_numpy)
399
400        # copy number values for each copy number state
401        copy_number_values_c = np.arange(0, calling_config.num_copy_number_states, dtype=types.small_uint)
402        self.copy_number_values_c = th.shared(copy_number_values_c, name='copy_number_values_c',
403                                              borrow=config.borrow_numpy)
404
405        # copy number log posterior and derived quantities (to be initialized by `PosteriorInitializer`)
406        self.log_q_c_stc: Optional[types.TensorSharedVariable] = None
407
408        # latest MAP estimate of integer copy number (to be initialized and periodically updated by
409        #   `DenoisingCallingWorkspace.update_auxiliary_vars)
410        self.c_map_st: Optional[types.TensorSharedVariable] = None
411
412        # latest bitmask of CNV-active intervals (to be initialized and periodically updated by
413        #   `DenoisingCallingWorkspace.update_auxiliary_vars if q_c_expectation_mode == 'hybrid')
414        self.active_class_bitmask_t: Optional[types.TensorSharedVariable] = None
415
416        # copy number emission log posterior
417        log_copy_number_emission_stc = np.zeros(
418            (self.num_samples, self.num_intervals, calling_config.num_copy_number_states), dtype=types.floatX)
419        self.log_copy_number_emission_stc: types.TensorSharedVariable = th.shared(
420            log_copy_number_emission_stc, name="log_copy_number_emission_stc", borrow=config.borrow_numpy)
421
422        # class log posterior (to be initialized by `PosteriorInitializer`)
423        self.log_q_tau_tk: Optional[types.TensorSharedVariable] = None
424
425        # class emission log posterior
426        # (to be initialized by calling `initialize_copy_number_class_inference_vars`)
427        self.log_class_emission_tk: Optional[types.TensorSharedVariable] = None
428
429        # class assignment prior probabilities
430        # (to be initialized by calling `initialize_copy_number_class_inference_vars`)
431        self.class_probs_k: Optional[types.TensorSharedVariable] = None
432
433        # class Markov chain log prior (initialized here and remains constant throughout)
434        # (to be initialized by calling `initialize_copy_number_class_inference_vars`)
435        self.log_prior_k: Optional[np.ndarray] = None
436
437        # class Markov chain log transition (initialized here and remains constant throughout)
438        # (to be initialized by calling `initialize_copy_number_class_inference_vars`)
439        self.log_trans_tkk: Optional[np.ndarray] = None
440
441        # GC bias factors
442        # (to be initialized by calling `initialize_bias_inference_vars`)
443        self.W_gc_tg: Optional[tst.SparseConstant] = None
444
445        # auxiliary data structures for hybrid q_c_expectation_mode calculation
446        # (to be initialized by calling `initialize_bias_inference_vars`)
447        self.interval_neighbor_index_list: Optional[List[List[int]]] = None
448
449        # denoised copy ratios
450        denoised_copy_ratio_st = np.zeros((self.num_samples, self.num_intervals), dtype=types.floatX)
451        self.denoised_copy_ratio_st: types.TensorSharedVariable = th.shared(
452            denoised_copy_ratio_st, name="denoised_copy_ratio_st", borrow=config.borrow_numpy)
453
454        # initialize posterior
455        posterior_initializer.initialize_posterior(denoising_config, calling_config, self)
456        self.initialize_bias_inference_vars()
457        self.update_auxiliary_vars()
458
459    def initialize_copy_number_class_inference_vars(self):
460        """Initializes members required for copy number class inference (must be called in the cohort mode).
461        The following members are initialized:
462            - `DenoisingCallingWorkspace.log_class_emission_tk`
463            - `DenoisingCallingWorkspace.class_probs_k`
464            - `DenoisingCallingWorkspace.log_prior_k`
465            - `DenoisingCallingWorkspace.log_trans_tkk`
466        """
467        # class emission log posterior
468        log_class_emission_tk = np.zeros(
469            (self.num_intervals, self.calling_config.num_copy_number_classes), dtype=types.floatX)
470        self.log_class_emission_tk: types.TensorSharedVariable = th.shared(
471            log_class_emission_tk, name="log_class_emission_tk", borrow=True)
472
473        # class assignment prior probabilities
474        # Note:
475        #   The first class is the CNV-silent class (highly biased toward the baseline copy number)
476        #   The second class is a CNV-active class (all copy number states are equally probable)
477        class_probs_k = np.asarray([1.0 - self.calling_config.p_active, self.calling_config.p_active],
478                                   dtype=types.floatX)
479        self.class_probs_k: types.TensorSharedVariable = th.shared(
480            class_probs_k, name='class_probs_k', borrow=config.borrow_numpy)
481
482        # class Markov chain log prior (initialized here and remains constant throughout)
483        self.log_prior_k: np.ndarray = np.log(class_probs_k)
484
485        # class Markov chain log transition (initialized here and remains constant throughout)
486        self.log_trans_tkk: np.ndarray = self._get_log_trans_tkk(
487            self.dist_t,
488            self.calling_config.class_coherence_length,
489            self.calling_config.num_copy_number_classes,
490            class_probs_k)
491
492    def initialize_bias_inference_vars(self):
493        """Initializes `DenoisingCallingWorkspace.W_gc_tg` and `DenoisingCallingWorkspace.interval_neighbor_index_list`
494        if required by the model configuration."""
495        if self.denoising_config.enable_explicit_gc_bias_modeling:
496            self.W_gc_tg = self._create_sparse_gc_bin_tensor_tg(
497                self.interval_list, self.denoising_config.num_gc_bins)
498
499        if self.denoising_config.q_c_expectation_mode == 'hybrid':
500            self.interval_neighbor_index_list = self._get_interval_neighbor_index_list(
501                self.interval_list, self.denoising_config.active_class_padding_hybrid_mode)
502        else:
503            self.interval_neighbor_index_list = None
504
505    def update_auxiliary_vars(self):
506        """Updates `DenoisingCallingWorkspace.c_map_st' and `DenoisingCallingWorkspace.active_class_bitmask_t`."""
507        # MAP copy number call
508        if self.c_map_st is None:
509            c_map_st = np.zeros((self.num_samples, self.num_intervals), dtype=types.small_uint)
510            self.c_map_st = th.shared(c_map_st, name="c_map_st", borrow=config.borrow_numpy)
511        self.c_map_st.set_value(
512            np.argmax(self.log_q_c_stc.get_value(borrow=True), axis=2).astype(types.small_uint),
513            borrow=config.borrow_numpy)
514
515        if self.denoising_config.q_c_expectation_mode == 'hybrid':
516            _logger.debug("Updating CNV-active class bitmask...")
517            if self.active_class_bitmask_t is None:
518                active_class_bitmask_t = np.zeros((self.num_intervals,), dtype=bool)
519                self.active_class_bitmask_t = th.shared(
520                    active_class_bitmask_t, name="active_class_bitmask_t", borrow=config.borrow_numpy)
521
522            # bitmask for intervals of which the probability of being in the silent class is below 0.5
523            active_class_bitmask_t: np.ndarray = \
524                self.log_q_tau_tk.get_value(borrow=True)[:, 0] < -np.log(2)
525            padded_active_class_bitmask_t = np.zeros_like(active_class_bitmask_t)
526            for ti, neighbor_index_list in enumerate(self.interval_neighbor_index_list):
527                padded_active_class_bitmask_t[ti] = np.any(active_class_bitmask_t[neighbor_index_list])
528            self.active_class_bitmask_t.set_value(
529                padded_active_class_bitmask_t, borrow=config.borrow_numpy)
530
531    @staticmethod
532    def _get_interval_neighbor_index_list(interval_list: List[Interval],
533                                          maximum_neighbor_distance: int) -> List[List[int]]:
534        """Pads a given interval list, finds the index of overlapping neighbors, and returns a list of indices of
535        overlapping neighbors.
536
537        Note:
538            It is assumed that the `interval_list` is sorted (this is not asserted).
539
540        Args:
541            interval_list: list of intervals
542            maximum_neighbor_distance: Maximum distance between intervals to be considered neighbors
543
544        Returns:
545            A list of indices of overlapping neighbors with the same length as `interval_list`. Each element
546            in a variable-length list, depending on the number of neighbors.
547        """
548        assert maximum_neighbor_distance >= 0
549        num_intervals = len(interval_list)
550        padded_interval_list = [interval.get_padded(maximum_neighbor_distance) for interval in interval_list]
551        interval_neighbor_index_list = []
552        for ti, padded_interval in enumerate(padded_interval_list):
553            overlapping_interval_indices = [ti]
554            right_ti = ti
555            while right_ti < num_intervals - 1:
556                right_ti += 1
557                if interval_list[right_ti].overlaps_with(padded_interval):
558                    overlapping_interval_indices.append(right_ti)
559                else:
560                    break
561            left_ti = ti
562            while left_ti > 0:
563                left_ti -= 1
564                if interval_list[left_ti].overlaps_with(padded_interval):
565                    overlapping_interval_indices.append(left_ti)
566                else:
567                    break
568            interval_neighbor_index_list.append(overlapping_interval_indices)
569        return interval_neighbor_index_list
570
571    @staticmethod
572    def _get_log_trans_tkk(dist_t: np.ndarray,
573                           class_coherence_length: float,
574                           num_copy_number_classes: int,
575                           class_probs_k: np.ndarray) -> np.ndarray:
576        """Calculates the log transition probability between copy number classes."""
577        class_stay_prob_t = np.exp(-dist_t / class_coherence_length)
578        class_not_stay_prob_t = np.ones_like(class_stay_prob_t) - class_stay_prob_t
579        delta_kl = np.eye(num_copy_number_classes, dtype=types.floatX)
580        trans_tkl = (class_not_stay_prob_t[:, None, None] * class_probs_k[None, None, :]
581                     + class_stay_prob_t[:, None, None] * delta_kl[None, :, :])
582        return np.log(trans_tkl)
583
584    @staticmethod
585    def _create_sparse_gc_bin_tensor_tg(interval_list: List[Interval], num_gc_bins: int) -> tst.SparseConstant:
586        """Creates a sparse 2d theano tensor with shape (num_intervals, gc_bin). The sparse
587        tensor represents a 1-hot mapping of each interval to its GC bin index. The range [0, 1]
588        is uniformly divided into num_gc_bins.
589        """
590        assert all([GCContentAnnotation.get_key() in interval.annotations.keys() for interval in interval_list]), \
591            "Explicit GC bias modeling is enabled, however, some or all intervals lack \"{0}\" annotation".format(
592                GCContentAnnotation.get_key())
593
594        def get_gc_bin_idx(gc_content):
595            return min(int(gc_content * num_gc_bins), num_gc_bins - 1)
596
597        num_intervals = len(interval_list)
598        data = np.ones((num_intervals,))
599        indices = [get_gc_bin_idx(interval.get_annotation(GCContentAnnotation.get_key()))
600                   for interval in interval_list]
601        indptr = np.arange(0, num_intervals + 1)
602        scipy_gc_matrix = sp.csr_matrix((data, indices, indptr), shape=(num_intervals, num_gc_bins),
603                                        dtype=types.small_uint)
604        theano_gc_matrix: tst.SparseConstant = tst.as_sparse(scipy_gc_matrix)
605        return theano_gc_matrix
606
607    @staticmethod
608    def _get_baseline_copy_number_and_read_depth(sample_metadata_collection: SampleMetadataCollection,
609                                                 sample_names: List[str],
610                                                 contig_list: List[str]) \
611            -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
612        """Generates global read depth array, average ploidy array, and baseline copy numbers for all
613        samples.
614
615        Args:
616            sample_metadata_collection: a instance of `SampleMetadataCollection` containing required metadata
617                for all samples in `sample_names`
618            sample_names: list of sample names
619            contig_list: list of contigs appearing in the modeling interval list
620
621        Returns:
622            global read depth, average ploudy, baseline copy number
623        """
624        assert sample_metadata_collection.all_samples_have_read_depth_metadata(sample_names), \
625            "Some samples do not have read depth metadata"
626        assert sample_metadata_collection.all_samples_have_ploidy_metadata(sample_names), \
627            "Some samples do not have ploidy metadata"
628        num_samples = len(sample_names)
629        num_contigs = len(contig_list)
630
631        global_read_depth_s = np.zeros((num_samples,), dtype=types.floatX)
632        average_ploidy_s = np.zeros((num_samples,), dtype=types.floatX)
633        baseline_copy_number_sj = np.zeros((num_samples, num_contigs), dtype=types.small_uint)
634
635        for si, sample_name in enumerate(sample_names):
636            sample_read_depth_metadata = sample_metadata_collection.get_sample_read_depth_metadata(sample_name)
637            sample_ploidy_metadata = sample_metadata_collection.get_sample_ploidy_metadata(sample_name)
638
639            global_read_depth_s[si] = sample_read_depth_metadata.global_read_depth
640            average_ploidy_s[si] = sample_read_depth_metadata.average_ploidy
641            sample_baseline_copy_number_j = np.asarray([sample_ploidy_metadata.get_contig_ploidy(contig)
642                                                        for contig in contig_list], dtype=types.small_uint)
643            baseline_copy_number_sj[si, :] = sample_baseline_copy_number_j[:]
644
645        return global_read_depth_s, average_ploidy_s, baseline_copy_number_sj
646
647
648class InitialModelParametersSupplier:
649    """Base class for suppliers of initial global model parameters"""
650    def __init__(self,
651                 denoising_model_config: DenoisingModelConfig,
652                 calling_config: CopyNumberCallingConfig,
653                 shared_workspace: DenoisingCallingWorkspace):
654        self.denoising_model_config = denoising_model_config
655        self.calling_config = calling_config
656        self.shared_workspace = shared_workspace
657
658    @abstractmethod
659    def get_init_psi_t(self) -> np.ndarray:
660        """Initial interval-specific unexplained variance."""
661        raise NotImplementedError
662
663    @abstractmethod
664    def get_init_log_mean_bias_t(self) -> np.ndarray:
665        """Initial mean bias in log space."""
666        raise NotImplementedError
667
668    @abstractmethod
669    def get_init_ard_u(self) -> np.ndarray:
670        """Initial ARD prior precisions."""
671        raise NotImplementedError
672
673
674class TrivialInitialModelParametersSupplier(InitialModelParametersSupplier):
675    """Trivial initial model supplier."""
676    def __init__(self,
677                 denoising_model_config: DenoisingModelConfig,
678                 calling_config: CopyNumberCallingConfig,
679                 shared_workspace: DenoisingCallingWorkspace):
680        super().__init__(denoising_model_config, calling_config, shared_workspace)
681
682    def get_init_psi_t(self) -> np.ndarray:
683        return self.denoising_model_config.psi_t_scale * np.ones(
684            (self.shared_workspace.num_intervals,), dtype=types.floatX)
685
686    def get_init_log_mean_bias_t(self) -> np.ndarray:
687        return np.zeros((self.shared_workspace.num_intervals,), dtype=types.floatX)
688
689    def get_init_ard_u(self) -> np.ndarray:
690        fact = self.denoising_model_config.psi_t_scale * self.denoising_model_config.init_ard_rel_unexplained_variance
691        return fact * np.ones((self.denoising_model_config.max_bias_factors,), dtype=types.floatX)
692
693
694class DenoisingModel(GeneralizedContinuousModel):
695    """The gCNV coverage denoising model declaration (continuous RVs only; discrete posteriors are assumed
696    to be given)."""
697    def __init__(self,
698                 denoising_model_config: DenoisingModelConfig,
699                 shared_workspace: DenoisingCallingWorkspace,
700                 initial_model_parameters_supplier: InitialModelParametersSupplier):
701        super().__init__()
702        self.shared_workspace = shared_workspace
703        register_as_global = self.register_as_global
704        register_as_sample_specific = self.register_as_sample_specific
705
706        eps_mapping = denoising_model_config.mapping_error_rate
707
708        # interval-specific unexplained variance
709        psi_t = Exponential(name='psi_t', lam=1.0 / denoising_model_config.psi_t_scale,
710                            shape=(shared_workspace.num_intervals,),
711                            broadcastable=(False,))
712        register_as_global(psi_t)
713
714        # sample-specific unexplained variance
715        psi_s = Exponential(name='psi_s', lam=1.0 / denoising_model_config.psi_s_scale,
716                            shape=(shared_workspace.num_samples,),
717                            broadcastable=(False,))
718        register_as_sample_specific(psi_s, sample_axis=0)
719
720        # convert "unexplained variance" to negative binomial over-dispersion
721        alpha_st = tt.maximum(tt.inv(tt.exp(psi_t.dimshuffle('x', 0) + psi_s.dimshuffle(0, 'x')) - 1.0),
722                              _eps)
723
724        # interval-specific mean log bias
725        log_mean_bias_t = Normal(name='log_mean_bias_t', mu=0.0, sd=denoising_model_config.log_mean_bias_std,
726                                 shape=(shared_workspace.num_intervals,),
727                                 broadcastable=(False,),
728                                 testval=initial_model_parameters_supplier.get_init_log_mean_bias_t())
729        register_as_global(log_mean_bias_t)
730
731        # log-normal read depth centered at the global read depth
732        read_depth_mu_s = (np.log(shared_workspace.global_read_depth_s)
733                           - 0.5 / denoising_model_config.depth_correction_tau)
734        read_depth_s = Lognormal(name='read_depth_s',
735                                 mu=read_depth_mu_s,
736                                 tau=denoising_model_config.depth_correction_tau,
737                                 shape=(shared_workspace.num_samples,),
738                                 broadcastable=(False,),
739                                 testval=shared_workspace.global_read_depth_s)
740        register_as_sample_specific(read_depth_s, sample_axis=0)
741
742        # log bias modelling, starting with the log mean bias
743        log_bias_st = tt.tile(log_mean_bias_t, (shared_workspace.num_samples, 1))
744
745        if denoising_model_config.enable_bias_factors:
746            # ARD prior precisions
747            ard_u = HalfFlat(name='ard_u',
748                             shape=(denoising_model_config.max_bias_factors,),
749                             broadcastable=(False,),
750                             testval=initial_model_parameters_supplier.get_init_ard_u())
751            register_as_global(ard_u)
752
753            # bias factors
754            W_tu = Normal(name='W_tu', mu=0.0, tau=ard_u.dimshuffle('x', 0),
755                          shape=(shared_workspace.num_intervals, denoising_model_config.max_bias_factors),
756                          broadcastable=(False, False))
757            register_as_global(W_tu)
758
759            # sample-specific bias factor loadings
760            z_su = Normal(name='z_su', mu=0.0, sd=1.0,
761                          shape=(shared_workspace.num_samples, denoising_model_config.max_bias_factors),
762                          broadcastable=(False, False))
763            register_as_sample_specific(z_su, sample_axis=0)
764
765            # add contribution to total log bias
766            if denoising_model_config.disable_bias_factors_in_active_class:
767                prob_silent_class_t = tt.exp(shared_workspace.log_q_tau_tk[:, 0])
768                log_bias_st += (prob_silent_class_t.dimshuffle('x', 0) * tt.dot(W_tu, z_su.T).T)
769            else:
770                log_bias_st += tt.dot(W_tu, z_su.T).T
771
772        # GC bias
773        if denoising_model_config.enable_explicit_gc_bias_modeling:
774            # sample-specific GC bias factor loadings
775            z_sg = Normal(name='z_sg', mu=0.0, sd=denoising_model_config.gc_curve_sd,
776                          shape=(shared_workspace.num_samples, denoising_model_config.num_gc_bins),
777                          broadcastable=(False, False))
778            register_as_sample_specific(z_sg, sample_axis=0)
779
780            # add contribution to total log bias
781            log_bias_st += tst.dot(shared_workspace.W_gc_tg, z_sg.T).T
782
783        # useful expressions
784        bias_st = tt.exp(log_bias_st)
785
786        # the expected number of erroneously mapped reads
787        mean_mapping_error_correction_s = eps_mapping * read_depth_s * shared_workspace.average_ploidy_s
788
789        denoised_copy_ratio_st = ((shared_workspace.n_st - mean_mapping_error_correction_s.dimshuffle(0, 'x'))
790                                  / ((1.0 - eps_mapping) * read_depth_s.dimshuffle(0, 'x') * bias_st))
791
792        Deterministic(name='denoised_copy_ratio_st', var=denoised_copy_ratio_st)
793
794        mu_stc = ((1.0 - eps_mapping) * read_depth_s.dimshuffle(0, 'x', 'x')
795                  * bias_st.dimshuffle(0, 1, 'x')
796                  * shared_workspace.copy_number_values_c.dimshuffle('x', 'x', 0)
797                  + mean_mapping_error_correction_s.dimshuffle(0, 'x', 'x'))
798
799        Deterministic(name='log_copy_number_emission_stc',
800                      var=commons.negative_binomial_logp(
801                          mu_stc, alpha_st.dimshuffle(0, 1, 'x'), shared_workspace.n_st.dimshuffle(0, 1, 'x')))
802
803        # n_st (observed)
804        if denoising_model_config.q_c_expectation_mode == 'map':
805            def _copy_number_emission_logp(_n_st):
806                mu_st = ((1.0 - eps_mapping) * read_depth_s.dimshuffle(0, 'x') * bias_st
807                         * shared_workspace.c_map_st + mean_mapping_error_correction_s.dimshuffle(0, 'x'))
808                log_copy_number_emission_st = commons.negative_binomial_logp(
809                    mu_st, alpha_st, _n_st)
810                return log_copy_number_emission_st
811
812        elif denoising_model_config.q_c_expectation_mode == 'exact':
813            def _copy_number_emission_logp(_n_st):
814                _log_copy_number_emission_stc = commons.negative_binomial_logp(
815                    mu_stc,
816                    alpha_st.dimshuffle(0, 1, 'x'),
817                    _n_st.dimshuffle(0, 1, 'x'))
818                log_q_c_stc = shared_workspace.log_q_c_stc
819                q_c_stc = tt.exp(log_q_c_stc)
820                return tt.sum(q_c_stc * (_log_copy_number_emission_stc - log_q_c_stc), axis=2)
821
822        elif denoising_model_config.q_c_expectation_mode == 'hybrid':
823            def _copy_number_emission_logp(_n_st):
824                active_class_bitmask_t = self.shared_workspace.active_class_bitmask_t
825                active_class_indices = active_class_bitmask_t.nonzero()[0]
826                silent_class_indices = (1 - active_class_bitmask_t).nonzero()[0]
827
828                # for CNV-active classes, calculate exact posterior expectation
829                mu_active_stc = ((1.0 - eps_mapping) * read_depth_s.dimshuffle(0, 'x', 'x')
830                                 * bias_st.dimshuffle(0, 1, 'x')[:, active_class_indices, :]
831                                 * shared_workspace.copy_number_values_c.dimshuffle('x', 'x', 0)
832                                 + mean_mapping_error_correction_s.dimshuffle(0, 'x', 'x'))
833                alpha_active_stc = tt.maximum(tt.inv((tt.exp(psi_t.dimshuffle('x', 0)[:, active_class_indices]
834                                                             + psi_s.dimshuffle(0, 'x')) - 1.0)).dimshuffle(0, 1, 'x'),
835                                              _eps)
836                n_active_stc = _n_st.dimshuffle(0, 1, 'x')[:, active_class_indices, :]
837                active_class_logp_stc = commons.negative_binomial_logp(mu_active_stc, alpha_active_stc, n_active_stc)
838                log_q_c_active_stc = shared_workspace.log_q_c_stc[:, active_class_indices, :]
839                q_c_active_stc = tt.exp(log_q_c_active_stc)
840                active_class_logp = tt.sum(q_c_active_stc * (active_class_logp_stc - log_q_c_active_stc))
841
842                # for CNV-silent classes, use MAP copy number state
843                mu_silent_st = ((1.0 - eps_mapping) * read_depth_s.dimshuffle(0, 'x') * bias_st[:, silent_class_indices]
844                                * shared_workspace.c_map_st[:, silent_class_indices]
845                                + mean_mapping_error_correction_s.dimshuffle(0, 'x'))
846                alpha_silent_st = alpha_st[:, silent_class_indices]
847                n_silent_st = _n_st[:, silent_class_indices]
848                silent_class_logp = tt.sum(commons.negative_binomial_logp(mu_silent_st, alpha_silent_st, n_silent_st))
849
850                return active_class_logp + silent_class_logp
851
852        elif denoising_model_config.q_c_expectation_mode == 'marginalize':
853            def _copy_number_emission_logp(_n_st):
854                _log_copy_number_emission_stc = commons.negative_binomial_logp(
855                    mu_stc,
856                    alpha_st.dimshuffle(0, 1, 'x'),
857                    _n_st.dimshuffle(0, 1, 'x'))
858                return pm.math.logsumexp(shared_workspace.log_q_c_stc + _log_copy_number_emission_stc, axis=2)
859
860        else:
861            raise Exception("Unknown q_c expectation mode; an exception should have been raised earlier")
862
863        DensityDist(name='n_st_obs',
864                    logp=_copy_number_emission_logp,
865                    observed=shared_workspace.n_st)
866
867
868class CopyNumberEmissionBasicSampler:
869    """Draws posterior samples from log copy number emission probabilities for a given variational
870    approximation to the denoising model continuous RVs."""
871    def __init__(self,
872                 denoising_model_config: DenoisingModelConfig,
873                 calling_config: CopyNumberCallingConfig,
874                 inference_params: HybridInferenceParameters,
875                 shared_workspace: DenoisingCallingWorkspace,
876                 denoising_model: DenoisingModel):
877        self.model_config = denoising_model_config
878        self.calling_config = calling_config
879        self.inference_params = inference_params
880        self.shared_workspace = shared_workspace
881        self.denoising_model = denoising_model
882        self._simultaneous_log_copy_number_emission_sampler = None
883
884    def update_approximation(self, approx: pm.approximations.MeanField):
885        """Generates a new compiled sampler based on a given approximation.
886        Args:
887            approx: an instance of PyMC3 mean-field approximation
888
889        Returns:
890            None
891        """
892        self._simultaneous_log_copy_number_emission_sampler = \
893            self._get_compiled_simultaneous_log_copy_number_emission_sampler(approx)
894
895    @property
896    def is_sampler_initialized(self):
897        return self._simultaneous_log_copy_number_emission_sampler is not None
898
899    def draw(self) -> np.ndarray:
900        assert self.is_sampler_initialized, "Posterior approximation is not provided yet"
901        return self._simultaneous_log_copy_number_emission_sampler()
902
903    @th.configparser.change_flags(compute_test_value="off")
904    def _get_compiled_simultaneous_log_copy_number_emission_sampler(self, approx: pm.approximations.MeanField):
905        """For a given variational approximation, returns a compiled theano function that draws posterior samples
906        from log copy number emission probabilities."""
907        log_copy_number_emission_stc = commons.stochastic_node_mean_symbolic(
908            approx, self.denoising_model['log_copy_number_emission_stc'],
909            size=self.inference_params.log_emission_samples_per_round)
910        return th.function(inputs=[], outputs=log_copy_number_emission_stc)
911
912
913class HHMMClassAndCopyNumberBasicCaller:
914    """This class updates copy number and interval class posteriors according to the following hierarchical
915    hidden Markov model:
916
917        class_prior_k --> (tau_1) --> (tau_2) --> (tau_3) --> ...
918                             |           |           |
919                             |           |           |
920                             v           v           v
921                           (c_s1) -->  (c_s2) -->  (c_s3) --> ...
922                             |           |           |
923                             |           |           |
924                             v           v           v
925                            n_s1        n_s2        n_s3
926
927        The posterior probability of `tau` and `c_s`, q(tau) and q(c_s) respectively, are obtained via
928        the following variational ansatz:
929
930            \prod_s p(tau, c_s | n) ~ q(tau) \prod_s q(c_s),
931
932        where correlations between intervals are preserved in both chains, however, cross-correlations
933        between `tau` and `c` are neglected, including correlations induced between copy numbers of
934        different samples. As usual, the posteriors are determined by minimizing the KL divergence w.r.t.
935        the true posterior resulting in the following iterative scheme:
936
937        - Given q(tau), the effective copy number prior for the first interval and the effective copy number
938          transition probabilities are determined (see _get_update_copy_number_hmm_specs_compiled_function).
939          Along with the given emission probabilities to sample read counts, q(c_s) is updated using the
940          forward-backward algorithm for each sample (see _update_copy_number_log_posterior)
941
942        - Given q(c_s), the emission probability of each copy number class (tau) is determined
943          (see _get_update_log_class_emission_tk_theano_func). The class prior and transition probabilities
944          are fixed hyperparameters. Therefore, q(tau) can be updated immediately using a single run
945          of forward-backward algorithm (see _update_class_log_posterior).
946    """
947    CopyNumberForwardBackwardResult = collections.namedtuple(
948        'CopyNumberForwardBackwardResult',
949        'sample_index, new_log_posterior_tc, copy_number_update_size, log_likelihood')
950
951    def __init__(self,
952                 calling_config: CopyNumberCallingConfig,
953                 inference_params: HybridInferenceParameters,
954                 shared_workspace: DenoisingCallingWorkspace,
955                 disable_class_update: bool,
956                 temperature: types.TensorSharedVariable):
957        self.calling_config = calling_config
958        self.inference_params = inference_params
959        self.shared_workspace = shared_workspace
960        self.disable_class_update = disable_class_update
961        self.temperature = temperature
962
963        # generate the 2-class inventory of copy number priors (CNV-silent, CNV-active) for all samples
964        # according to their respective germline contig ploidies
965        pi_sjkc = np.zeros((shared_workspace.num_samples,
966                            shared_workspace.num_contigs,
967                            calling_config.num_copy_number_classes,
968                            calling_config.num_copy_number_states), dtype=types.floatX)
969        for si in range(shared_workspace.num_samples):
970            pi_sjkc[si, :, :, :] = self.get_copy_number_prior_for_sample_jkc(
971                calling_config.num_copy_number_states,
972                calling_config.p_alt,
973                shared_workspace.baseline_copy_number_sj[si, :])[:, :, :]
974        self.pi_sjkc: types.TensorSharedVariable = th.shared(pi_sjkc, name='pi_sjkc', borrow=config.borrow_numpy)
975
976        # compiled function for forward-backward updates of copy number posterior
977        self._hmm_q_copy_number = TheanoForwardBackward(
978            log_posterior_probs_output_tc=None,
979            resolve_nans=False,
980            do_thermalization=True,
981            do_admixing=True,
982            include_update_size_output=True,
983            include_alpha_beta_output=False)
984
985        if not disable_class_update:
986            # compiled function for forward-backward update of class posterior
987            # Note:
988            #   if p_active == 0, we have to deal with inf - inf expressions properly.
989            #   setting resolve_nans = True takes care of such ambiguities.
990            self._hmm_q_class = TheanoForwardBackward(
991                log_posterior_probs_output_tc=shared_workspace.log_q_tau_tk,
992                resolve_nans=(calling_config.p_active == 0),
993                do_thermalization=True,
994                do_admixing=True,
995                include_update_size_output=True,
996                include_alpha_beta_output=False)
997
998            # compiled function for update of class log emission
999            self._update_log_class_emission_tk_theano_func = self._get_update_log_class_emission_tk_theano_func()
1000        else:
1001            self._hmm_q_class: Optional[TheanoForwardBackward] = None
1002            self._update_log_class_emission_tk_theano_func = None
1003
1004        # compiled function for variational update of copy number HMM specs
1005        self._get_copy_number_hmm_specs_theano_func = self.get_compiled_copy_number_hmm_specs_theano_func()
1006
1007    @staticmethod
1008    def get_copy_number_prior_for_sample_jkc(num_copy_number_states: int,
1009                                             p_alt: float,
1010                                             baseline_copy_number_j: np.ndarray) -> np.ndarray:
1011        """Returns copy-number prior probabilities for each contig (j) and class (k) as a 3d ndarray.
1012
1013        Args:
1014            num_copy_number_states: total number of copy-number states
1015            p_alt: total probability of alt copy-number states
1016            baseline_copy_number_j: baseline copy-number state for each contig
1017
1018        Returns:
1019            a 3d ndarray
1020        """
1021        p_baseline = 1.0 - (num_copy_number_states - 1) * p_alt
1022        pi_jkc = np.zeros((len(baseline_copy_number_j), 2, num_copy_number_states), dtype=types.floatX)
1023        for j, baseline_state in enumerate(baseline_copy_number_j):
1024            # the silent class
1025            pi_jkc[j, 0, :] = p_alt
1026            pi_jkc[j, 0, baseline_state] = p_baseline
1027            # the active class
1028            pi_jkc[j, 1, :] = 1.0 / num_copy_number_states
1029
1030        return pi_jkc
1031
1032    def call(self,
1033             copy_number_update_summary_statistic_reducer,
1034             class_update_summary_statistic_reducer) -> Tuple[np.ndarray, np.ndarray, float, float]:
1035        """Perform a round of update of q(tau) and q(c)
1036
1037        Note:
1038            This function must be called until q(tau) and q(c) converge to a self-consistent solution.
1039
1040        Args:
1041            copy_number_update_summary_statistic_reducer: a function that reduces vectors to scalars and
1042                is used to compile a summary of copy number posterior updates across intervals for each sample
1043            class_update_summary_statistic_reducer: a function that reduces vectors to scalars and
1044                is used to compile a summary of interval class posterior updates across intervals
1045
1046        Returns:
1047            copy number update summary (ndarray of size `num_samples`),
1048            copy number Markov chain log likelihoods (ndarray of size `num_samples`),
1049            interval class update summary,
1050            interval class Markov chain log likelihood
1051        """
1052        # copy number posterior update
1053        copy_number_update_s, copy_number_log_likelihoods_s = self._update_copy_number_log_posterior(
1054            copy_number_update_summary_statistic_reducer)
1055
1056        if not self.disable_class_update:
1057            # class posterior update
1058            self._update_log_class_emission_tk()
1059            class_update, class_log_likelihood = self._update_class_log_posterior(
1060                class_update_summary_statistic_reducer)
1061        else:
1062            class_update = None
1063            class_log_likelihood = None
1064
1065        return copy_number_update_s, copy_number_log_likelihoods_s, class_update, class_log_likelihood
1066
1067    def _update_copy_number_log_posterior(self, copy_number_update_summary_statistic_reducer) \
1068            -> Tuple[np.ndarray, np.ndarray]:
1069        ws = self.shared_workspace
1070        copy_number_update_s = np.zeros((ws.num_samples,), dtype=types.floatX)
1071        copy_number_log_likelihoods_s = np.zeros((ws.num_samples,), dtype=types.floatX)
1072        num_calling_processes = self.calling_config.num_calling_processes
1073
1074        def _run_single_sample_fb(_sample_index: int):
1075            # step 1. calculate copy-number HMM log prior and log transition matrix
1076            pi_jkc = self.pi_sjkc.get_value(borrow=True)[_sample_index, ...]
1077            cnv_stay_prob_t = self.shared_workspace.cnv_stay_prob_t.get_value(borrow=True)
1078            log_q_tau_tk = self.shared_workspace.log_q_tau_tk.get_value(borrow=True)
1079            t_to_j_map = self.shared_workspace.t_to_j_map.get_value(borrow=True)
1080            hmm_spec = self._get_copy_number_hmm_specs_theano_func(pi_jkc, cnv_stay_prob_t, log_q_tau_tk, t_to_j_map)
1081            log_prior_c = hmm_spec[0]
1082            log_trans_tcc = hmm_spec[1]
1083
1084            prev_log_posterior_tc = ws.log_q_c_stc.get_value(borrow=True)[_sample_index, ...]
1085            log_copy_number_emission_tc = ws.log_copy_number_emission_stc.get_value(borrow=True)[_sample_index, ...]
1086
1087            # step 2. run forward-backward and update copy-number posteriors
1088            _fb_result = self._hmm_q_copy_number.perform_forward_backward(
1089                log_prior_c, log_trans_tcc, log_copy_number_emission_tc,
1090                prev_log_posterior_tc=prev_log_posterior_tc,
1091                admixing_rate=self.inference_params.caller_internal_admixing_rate,
1092                temperature=self.temperature.get_value()[0])
1093            new_log_posterior_tc = _fb_result.log_posterior_probs_tc
1094            copy_number_update_size = copy_number_update_summary_statistic_reducer(_fb_result.update_norm_t)
1095            log_likelihood = float(_fb_result.log_data_likelihood)
1096
1097            return self.CopyNumberForwardBackwardResult(
1098                _sample_index, new_log_posterior_tc, copy_number_update_size, log_likelihood)
1099
1100        def _update_log_q_c_stc_inplace(log_q_c_stc, _sample_index, new_log_posterior_tc):
1101            log_q_c_stc[_sample_index, :, :] = new_log_posterior_tc[:, :]
1102            return log_q_c_stc
1103
1104        max_chunks = ws.num_samples // num_calling_processes + 1
1105        for chunk_index in range(max_chunks):
1106            begin_index = chunk_index * num_calling_processes
1107            end_index = min((chunk_index + 1) * num_calling_processes, ws.num_samples)
1108            if begin_index >= ws.num_samples:
1109                break
1110            # todo multiprocessing
1111            # with mp.Pool(processes=num_calling_processes) as pool:
1112            #     for fb_result in pool.map(_run_single_sample_fb, range(begin_index, end_index)):
1113            for fb_result in [_run_single_sample_fb(sample_index)
1114                              for sample_index in range(begin_index, end_index)]:
1115                # update log posterior in the workspace
1116                ws.log_q_c_stc.set_value(
1117                    _update_log_q_c_stc_inplace(
1118                        ws.log_q_c_stc.get_value(borrow=True),
1119                        fb_result.sample_index, fb_result.new_log_posterior_tc),
1120                    borrow=True)
1121                # update summary stats
1122                copy_number_update_s[fb_result.sample_index] = fb_result.copy_number_update_size
1123                copy_number_log_likelihoods_s[fb_result.sample_index] = fb_result.log_likelihood
1124
1125        return copy_number_update_s, copy_number_log_likelihoods_s
1126
1127    def _update_log_class_emission_tk(self):
1128        self._update_log_class_emission_tk_theano_func()
1129
1130    def _update_class_log_posterior(self, class_update_summary_statistic_reducer) -> Tuple[float, float]:
1131        fb_result = self._hmm_q_class.perform_forward_backward(
1132            self.shared_workspace.log_prior_k,
1133            self.shared_workspace.log_trans_tkk,
1134            self.shared_workspace.log_class_emission_tk.get_value(borrow=True),
1135            prev_log_posterior_tc=self.shared_workspace.log_q_tau_tk.get_value(borrow=True),
1136            admixing_rate=self.inference_params.caller_internal_admixing_rate,
1137            temperature=self.temperature.get_value()[0])
1138        class_update_size = class_update_summary_statistic_reducer(fb_result.update_norm_t)
1139        log_likelihood = float(fb_result.log_data_likelihood)
1140        return class_update_size, log_likelihood
1141
1142    def update_auxiliary_vars(self):
1143        self.shared_workspace.update_auxiliary_vars()
1144
1145    @staticmethod
1146    @th.configparser.change_flags(compute_test_value="off")
1147    def get_compiled_copy_number_hmm_specs_theano_func() -> th.compile.function_module.Function:
1148        """Returns a compiled function that calculates the interval-class-averaged and probability-sum-normalized
1149        log copy number transition matrix and log copy number prior for the first interval
1150
1151        Returned theano function inputs:
1152            pi_jkc: a 3d tensor containing copy-number priors for each contig (j) and each class (k)
1153            cnv_stay_prob_t: probability of staying on the same copy-number state at interval `t`
1154            log_q_tau_tk: log probability of copy-number classes at interval `t`
1155            t_to_j_map: a mapping from interval indices (t) to contig indices (j); it is used to unpack
1156                `pi_jkc` to `pi_tkc` (see below)
1157
1158        Returned theano function outputs:
1159            log_prior_c_first_interval: log probability of copy-number states for the first interval
1160            log_trans_tab: log transition probability matrix from interval `t` to interval `t+1`
1161
1162        Note:
1163            In the following, we use "a" and "b" subscripts in the variable names to refer to the departure
1164            and destination states, respectively. Like before, "t" and "k" denote interval and class, and "j"
1165            refers to contig index.
1166        """
1167        # shorthands
1168        pi_jkc = tt.tensor3(name='pi_jkc')
1169        cnv_stay_prob_t = tt.vector(name='cnv_stay_prob_t')
1170        log_q_tau_tk = tt.matrix(name='log_q_tau_tk')
1171        t_to_j_map = tt.vector(name='t_to_j_map', dtype=tt.scal.uint32)
1172
1173        # log prior probability for the first interval
1174        log_prior_c_first_interval = tt.dot(tt.log(pi_jkc[t_to_j_map[0], :, :].T), tt.exp(log_q_tau_tk[0, :]))
1175        log_prior_c_first_interval -= pm.logsumexp(log_prior_c_first_interval)
1176
1177        # log transition matrix
1178        cnv_not_stay_prob_t = tt.ones_like(cnv_stay_prob_t) - cnv_stay_prob_t
1179        num_copy_number_states = pi_jkc.shape[2]
1180        delta_ab = tt.eye(num_copy_number_states)
1181
1182        # map contig to interval and obtain pi_tkc for the rest of the targets
1183        pi_tkc = pi_jkc[t_to_j_map[1:], :, :]
1184
1185        # calculate normalized log transition matrix
1186        # todo use logaddexp
1187        log_trans_tkab = tt.log(cnv_not_stay_prob_t.dimshuffle(0, 'x', 'x', 'x') * pi_tkc.dimshuffle(0, 1, 'x', 2)
1188                                + cnv_stay_prob_t.dimshuffle(0, 'x', 'x', 'x') * delta_ab.dimshuffle('x', 'x', 0, 1))
1189        q_tau_tkab = tt.exp(log_q_tau_tk[1:, :]).dimshuffle(0, 1, 'x', 'x')
1190        log_trans_tab = tt.sum(q_tau_tkab * log_trans_tkab, axis=1)
1191        log_trans_tab -= pm.logsumexp(log_trans_tab, axis=2)
1192
1193        inputs = [pi_jkc, cnv_stay_prob_t, log_q_tau_tk, t_to_j_map]
1194        outputs = [log_prior_c_first_interval, log_trans_tab]
1195
1196        return th.function(inputs=inputs, outputs=outputs)
1197
1198    @th.configparser.change_flags(compute_test_value="off")
1199    def _get_update_log_class_emission_tk_theano_func(self) -> th.compile.function_module.Function:
1200        """Returns a compiled function that calculates the log interval class emission probability and
1201        directly updates `log_class_emission_tk` in the workspace.
1202
1203        Note:
1204            In the following,
1205
1206                xi_tab ~ posterior copy number probability of two subsequent intervals.
1207
1208            We ignore correlations, i.e. we assume:
1209
1210              xi_st(a, b) \equiv q_c(c_{s,t} = a, c_{s,t+1} = b)
1211                          \approx q_c(c_{s,t} = a) q_c(c_{s,t+1} = b)
1212
1213            If needed, xi can be calculated exactly from the forward-backward tables.
1214        """
1215        # shorthands
1216        cnv_stay_prob_t = self.shared_workspace.cnv_stay_prob_t
1217        q_c_stc = tt.exp(self.shared_workspace.log_q_c_stc)
1218        pi_sjkc = self.pi_sjkc
1219        t_to_j_map = self.shared_workspace.t_to_j_map
1220        num_copy_number_states = self.calling_config.num_copy_number_states
1221
1222        # log copy number transition matrix for each class
1223        cnv_not_stay_prob_t = tt.ones_like(cnv_stay_prob_t) - cnv_stay_prob_t
1224        delta_ab = tt.eye(num_copy_number_states)
1225
1226        # calculate log class emission by reducing over samples; see below
1227        log_class_emission_cum_sum_tk = tt.zeros((self.shared_workspace.num_intervals - 1,
1228                                                  self.calling_config.num_copy_number_classes),
1229                                                 dtype=types.floatX)
1230
1231        def inc_log_class_emission_tk_except_for_first_interval(pi_jkc, q_c_tc, cum_sum_tk):
1232            """Adds the contribution of a given sample to the log class emission (symbolically).
1233
1234            Args:
1235                pi_jkc: copy number prior inventory for the sample
1236                q_c_tc: copy number posteriors for the sample
1237                cum_sum_tk: current cumulative sum of log class emission
1238
1239            Returns:
1240                Symbolically updated cumulative sum of log class emission
1241            """
1242            # map contigs to targets (starting from the second interval)
1243            pi_tkc = pi_jkc[t_to_j_map[1:], :, :]
1244
1245            # todo use logaddexp
1246            log_trans_tkab = tt.log(
1247                cnv_not_stay_prob_t.dimshuffle(0, 'x', 'x', 'x') * pi_tkc.dimshuffle(0, 1, 'x', 2)
1248                + cnv_stay_prob_t.dimshuffle(0, 'x', 'x', 'x') * delta_ab.dimshuffle('x', 'x', 0, 1))
1249            xi_tab = q_c_tc[:-1, :].dimshuffle(0, 1, 'x') * q_c_tc[1:, :].dimshuffle(0, 'x', 1)
1250            current_log_class_emission_tk = tt.sum(tt.sum(
1251                xi_tab.dimshuffle(0, 'x', 1, 2) * log_trans_tkab, axis=-1), axis=-1)
1252            return cum_sum_tk + current_log_class_emission_tk
1253
1254        reduce_output = th.reduce(inc_log_class_emission_tk_except_for_first_interval,
1255                                  sequences=[pi_sjkc, q_c_stc],
1256                                  outputs_info=[log_class_emission_cum_sum_tk])
1257        log_class_emission_tk_except_for_first_interval = reduce_output[0]
1258
1259        # the first interval
1260        pi_skc_first = pi_sjkc[:, t_to_j_map[0], :, :]
1261        q_skc_first = q_c_stc[:, 0, :].dimshuffle(0, 'x', 1)
1262        log_class_emission_k_first = tt.sum(tt.sum(tt.log(pi_skc_first) * q_skc_first, axis=0), axis=-1)
1263
1264        # concatenate first and rest
1265        log_class_emission_tk = tt.concatenate((log_class_emission_k_first.dimshuffle('x', 0),
1266                                                log_class_emission_tk_except_for_first_interval))
1267
1268        return th.function(inputs=[], outputs=[], updates=[
1269            (self.shared_workspace.log_class_emission_tk, log_class_emission_tk)])
1270