1import logging
2import os
3from typing import List, Optional
4
5import numpy as np
6import pymc3 as pm
7
8from . import io_commons
9from . import io_consts
10from . import io_intervals_and_counts
11from .. import config
12from ..models import commons
13from ..models.model_denoising_calling import CopyNumberCallingConfig, DenoisingModelConfig
14from ..models.model_denoising_calling import DenoisingCallingWorkspace, DenoisingModel
15from ..utils import math
16
17_logger = logging.getLogger(__name__)
18
19
class DenoisingModelWriter:
    """Saves the global (non sample-specific) parts of a denoising model to an output directory.

    Calling an instance writes, in order: the gcnvkernel version, the denoising config (JSON),
    the calling config (JSON), the workspace's class log posterior (TSV), and the global
    parameters of the mean-field approximation.
    """
    def __init__(self,
                 denoising_config: DenoisingModelConfig,
                 calling_config: CopyNumberCallingConfig,
                 denoising_calling_workspace: DenoisingCallingWorkspace,
                 denoising_model: DenoisingModel,
                 denoising_model_approx: pm.MeanField,
                 output_path: str):
        # validate the destination before storing anything
        io_commons.assert_output_path_writable(output_path)
        self.output_path = output_path
        self.denoising_config = denoising_config
        self.calling_config = calling_config
        self.denoising_calling_workspace = denoising_calling_workspace
        self.denoising_model = denoising_model
        self.denoising_model_approx = denoising_model_approx

    @staticmethod
    def _write_class_log_posterior(output_path, log_q_tau_tk):
        """Writes ``log_q_tau_tk`` as a TSV file under ``output_path``."""
        class_log_posterior_tsv_file = os.path.join(
            output_path, io_consts.default_class_log_posterior_tsv_filename)
        io_commons.write_ndarray_to_tsv(class_log_posterior_tsv_file, log_q_tau_tk)

    def __call__(self):
        """Writes all global model outputs to ``self.output_path``."""
        out_dir = self.output_path

        def _dump_config_as_json(json_filename, config_obj):
            # serialize the instance attributes of a config object to a JSON file in the output directory
            io_commons.write_dict_to_json_file(
                os.path.join(out_dir, json_filename), config_obj.__dict__, set())

        # gcnvkernel version
        io_commons.write_gcnvkernel_version(out_dir)

        # denoising config, then calling config
        _dump_config_as_json(io_consts.default_denoising_config_json_filename, self.denoising_config)
        _dump_config_as_json(io_consts.default_calling_config_json_filename, self.calling_config)

        # global variables held by the workspace
        log_q_tau_tk = self.denoising_calling_workspace.log_q_tau_tk.get_value(borrow=True)
        self._write_class_log_posterior(out_dir, log_q_tau_tk)

        # global variables held by the posterior approximation
        io_commons.write_mean_field_global_params(
            out_dir, self.denoising_model_approx, self.denoising_model)
63
64
class DenoisingModelReader:
    """Loads the global (non sample-specific) parts of a denoising model from an input directory.

    Calling an instance checks the gcnvkernel version found at ``input_path``, loads the class log
    posterior into the workspace, and loads the global parameters of the mean-field approximation.
    """
    def __init__(self,
                 denoising_config: DenoisingModelConfig,
                 calling_config: CopyNumberCallingConfig,
                 denoising_calling_workspace: DenoisingCallingWorkspace,
                 denoising_model: DenoisingModel,
                 denoising_model_approx: pm.MeanField,
                 input_path: str):
        self.input_path = input_path
        self.denoising_config = denoising_config
        self.calling_config = calling_config
        self.denoising_calling_workspace = denoising_calling_workspace
        self.denoising_model = denoising_model
        self.denoising_model_approx = denoising_model_approx

    def __call__(self):
        """Reads the global model data from ``self.input_path`` into the workspace and approximation."""
        in_dir = self.input_path

        # version check of the stored model against the running gcnvkernel
        io_commons.check_gcnvkernel_version_from_path(in_dir)

        # global workspace variables: class log posterior
        log_q_tau_tk_shared = self.denoising_calling_workspace.log_q_tau_tk
        class_log_posterior_tsv_file = os.path.join(
            in_dir, io_consts.default_class_log_posterior_tsv_filename)
        loaded_log_q_tau_tk = io_commons.read_ndarray_from_tsv(class_log_posterior_tsv_file)
        log_q_tau_tk_shared.set_value(loaded_log_q_tau_tk, borrow=config.borrow_numpy)

        # global posterior parameters
        io_commons.read_mean_field_global_params(
            in_dir, self.denoising_model_approx, self.denoising_model)
94
95
def get_sample_posterior_path(calls_path: str, sample_index: int):
    """Returns the path of the sample-specific sub-folder of ``calls_path`` for ``sample_index``.

    The sub-folder name is ``io_consts.sample_folder_prefix`` followed by ``repr(sample_index)``.
    """
    sample_folder_name = io_consts.sample_folder_prefix + repr(sample_index)
    return os.path.join(calls_path, sample_folder_name)
98
99
class SampleDenoisingAndCallingPosteriorsWriter:
    """Writes sample-specific model parameters and associated workspace variables to disk.

    Calling an instance writes the gcnvkernel version and both configs to ``output_path`` and then,
    for every sample in the workspace, creates a sample sub-folder (see ``get_sample_posterior_path``)
    holding that sample's posteriors.
    """
    def __init__(self,
                 denoising_config: DenoisingModelConfig,
                 calling_config: CopyNumberCallingConfig,
                 denoising_calling_workspace: DenoisingCallingWorkspace,
                 denoising_model: DenoisingModel,
                 denoising_model_approx: pm.MeanField,
                 output_path: str):
        io_commons.assert_output_path_writable(output_path)
        self.denoising_config = denoising_config
        self.calling_config = calling_config
        self.denoising_calling_workspace = denoising_calling_workspace
        self.denoising_model = denoising_model
        self.denoising_model_approx = denoising_model_approx
        self.output_path = output_path

    @staticmethod
    def write_ndarray_tc_with_copy_number_header(sample_posterior_path: str,
                                                 ndarray_tc: np.ndarray,
                                                 output_file_name: str,
                                                 comment=io_consts.default_comment_char,
                                                 delimiter=io_consts.default_delimiter_char,
                                                 extra_comment_lines: Optional[List[str]] = None):
        """Writes a dim-2 ndarray to a delimited text file with one header column per copy-number state.

        The header columns are ``io_consts.copy_number_column_prefix`` followed by the column index.
        Each row of the array is written as one line.

        Args:
            sample_posterior_path: directory to write the file to
            ndarray_tc: dim-2 ndarray; the second axis is indexed by copy-number state
            output_file_name: name of the file to create inside ``sample_posterior_path``
            comment: prefix prepended to each extra comment line
            delimiter: column delimiter
            extra_comment_lines: optional lines written (with the comment prefix) before the header

        Raises:
            AssertionError: if ``ndarray_tc`` is not a dim-2 numpy ndarray
        """
        assert isinstance(ndarray_tc, np.ndarray), "ndarray_tc must be a numpy ndarray"
        assert ndarray_tc.ndim == 2, "ndarray_tc must be a dim-2 ndarray"
        num_copy_number_states = ndarray_tc.shape[1]
        copy_number_header_columns = [io_consts.copy_number_column_prefix + str(cn)
                                      for cn in range(num_copy_number_states)]
        with open(os.path.join(sample_posterior_path, output_file_name), 'w') as f:
            if extra_comment_lines is not None:
                for comment_line in extra_comment_lines:
                    f.write(comment + comment_line + '\n')
            f.write(delimiter.join(copy_number_header_columns) + '\n')
            for row_c in ndarray_tc:
                # str() rather than repr(): with NumPy >= 2 the repr of a scalar includes its type
                # (e.g. "np.float64(0.5)"), which is not a parsable table entry; str() yields the
                # plain shortest round-trip representation
                f.write(delimiter.join([str(x) for x in row_c]) + '\n')

    def __call__(self):
        """Writes the shared metadata and the per-sample posteriors to ``self.output_path``."""
        workspace = self.denoising_calling_workspace

        # write gcnvkernel version
        io_commons.write_gcnvkernel_version(self.output_path)

        # write denoising config
        io_commons.write_dict_to_json_file(
            os.path.join(self.output_path, io_consts.default_denoising_config_json_filename),
            self.denoising_config.__dict__, set())

        # write calling config
        io_commons.write_dict_to_json_file(
            os.path.join(self.output_path, io_consts.default_calling_config_json_filename),
            self.calling_config.__dict__, set())

        # extract mean-field parameters
        approx_var_set, approx_mu_map, approx_std_map = io_commons.extract_mean_field_posterior_parameters(
            self.denoising_model_approx)

        # compute approximate denoised copy ratios
        _logger.info("Sampling and approximating posteriors for denoised copy ratios...")
        denoising_copy_ratios_st_approx_generator = commons.get_sampling_generator_for_model_approximation(
            model_approx=self.denoising_model_approx, node=self.denoising_model['denoised_copy_ratio_st'])
        mu_denoised_copy_ratio_st, var_denoised_copy_ratio_st =\
            math.calculate_mean_and_variance_online(denoising_copy_ratios_st_approx_generator)
        std_denoised_copy_ratio_st = np.sqrt(var_denoised_copy_ratio_st)

        # fetch the workspace arrays once; nothing below modifies them
        log_q_c_stc = workspace.log_q_c_stc.get_value(borrow=True)
        log_copy_number_emission_stc = workspace.log_copy_number_emission_stc.get_value(borrow=True)
        t_to_j_map = workspace.t_to_j_map.get_value(borrow=True)

        for si, sample_name in enumerate(workspace.sample_names):
            sample_name_comment_line = [io_consts.sample_name_sam_header_prefix + sample_name]
            sample_posterior_path = get_sample_posterior_path(self.output_path, si)
            _logger.info("Saving posteriors for sample \"{0}\" in \"{1}\"...".format(
                sample_name, sample_posterior_path))
            io_commons.assert_output_path_writable(sample_posterior_path, try_creating_output_path=True)

            # write sample-specific posteriors in the approximation
            io_commons.write_mean_field_sample_specific_params(
                si, sample_posterior_path, approx_var_set, approx_mu_map, approx_std_map,
                self.denoising_model, sample_name_comment_line)

            # write sample name
            io_commons.write_sample_name_to_txt_file(sample_posterior_path, sample_name)

            # write copy number log posterior
            self.write_ndarray_tc_with_copy_number_header(
                sample_posterior_path,
                log_q_c_stc[si, ...],
                io_consts.default_copy_number_log_posterior_tsv_filename,
                extra_comment_lines=sample_name_comment_line)

            # write copy number log emission
            self.write_ndarray_tc_with_copy_number_header(
                sample_posterior_path,
                log_copy_number_emission_stc[si, ...],
                io_consts.default_copy_number_log_emission_tsv_filename,
                extra_comment_lines=sample_name_comment_line)

            # write baseline copy numbers
            baseline_copy_number_t = workspace.baseline_copy_number_sj[si, t_to_j_map]
            io_commons.write_ndarray_to_tsv(
                os.path.join(sample_posterior_path, io_consts.default_baseline_copy_number_tsv_filename),
                baseline_copy_number_t,
                extra_comment_lines=sample_name_comment_line,
                column_name_str=io_consts.baseline_copy_number_column_name)

            # write denoised copy ratio means
            mu_denoised_copy_ratio_t = mu_denoised_copy_ratio_st[si, :]
            io_commons.write_ndarray_to_tsv(
                os.path.join(sample_posterior_path, io_consts.default_denoised_copy_ratios_mean_tsv_filename),
                mu_denoised_copy_ratio_t,
                extra_comment_lines=sample_name_comment_line
            )

            # write denoised copy ratio standard deviations
            std_denoised_copy_ratio_t = std_denoised_copy_ratio_st[si, :]
            io_commons.write_ndarray_to_tsv(
                os.path.join(sample_posterior_path, io_consts.default_denoised_copy_ratios_std_tsv_filename),
                std_denoised_copy_ratio_t,
                extra_comment_lines=sample_name_comment_line
            )
216
217
class SampleDenoisingAndCallingPosteriorsReader:
    """Reads sample-specific model parameters and associated workspace variables from disk.

    Calling an instance verifies that the interval list stored in ``input_calls_path`` equals the
    workspace interval list, then loads the posteriors of every sample into the mean-field
    approximation and the workspace, and finally refreshes the auxiliary workspace variables.
    """
    def __init__(self,
                 denoising_calling_workspace: DenoisingCallingWorkspace,
                 denoising_model: DenoisingModel,
                 denoising_model_approx: pm.MeanField,
                 input_calls_path: str):
        self.denoising_calling_workspace = denoising_calling_workspace
        self.denoising_model = denoising_model
        self.denoising_model_approx = denoising_model_approx
        self.input_calls_path = input_calls_path

    @staticmethod
    def read_ndarray_tc_with_copy_number_header(sample_posterior_path: str,
                                                input_file_name: str,
                                                comment=io_consts.default_comment_char,
                                                delimiter=io_consts.default_delimiter_char) -> np.ndarray:
        """Reads a TSV-formatted dim-2 (intervals x copy-number) ndarray from a sample posterior path.

        Args:
            sample_posterior_path: directory containing the file
            input_file_name: name of the file inside ``sample_posterior_path``
            comment: comment character passed to the table reader
            delimiter: column delimiter passed to the table reader

        Returns:
            the table values as a dim-2 ndarray

        Raises:
            AssertionError: if the header columns are not ``io_consts.copy_number_column_prefix``
                followed by 0, 1, 2, ... in order, or if the loaded values are not dim-2
        """
        ndarray_tc_tsv_file = os.path.join(sample_posterior_path, input_file_name)
        # NOTE(review): `io_commons.read_csv` presumably returns a pandas DataFrame -- confirm
        ndarray_tc_pd = io_commons.read_csv(ndarray_tc_tsv_file, comment=comment, delimiter=delimiter)
        read_columns = [str(column_name) for column_name in ndarray_tc_pd.columns.values]
        expected_copy_number_header_columns =\
            [io_consts.copy_number_column_prefix + str(cn) for cn in range(len(read_columns))]
        assert read_columns == expected_copy_number_header_columns, \
            "Unexpected copy-number header in \"{0}\": expected columns {1}, found {2}".format(
                ndarray_tc_tsv_file, expected_copy_number_header_columns, read_columns)
        read_ndarray_tc = ndarray_tc_pd.values
        assert read_ndarray_tc.ndim == 2, \
            "The table read from \"{0}\" is not a dim-2 array".format(ndarray_tc_tsv_file)
        return read_ndarray_tc

    def _read_ndarray_tc_with_expected_shape(self,
                                             sample_posterior_path: str,
                                             input_file_name: str,
                                             comment,
                                             delimiter) -> np.ndarray:
        """Reads an (intervals x copy-number) table and checks its shape against the workspace."""
        read_ndarray_tc = self.read_ndarray_tc_with_copy_number_header(
            sample_posterior_path,
            input_file_name,
            delimiter=delimiter,
            comment=comment)
        expected_shape = (self.denoising_calling_workspace.num_intervals,
                          self.denoising_calling_workspace.calling_config.num_copy_number_states)
        assert read_ndarray_tc.shape == expected_shape, \
            "Unexpected shape of the array read from \"{0}\" in \"{1}\": expected {2}, found {3}".format(
                input_file_name, sample_posterior_path, expected_shape, read_ndarray_tc.shape)
        return read_ndarray_tc

    def _read_sample_copy_number_log_posterior(self,
                                               sample_posterior_path: str,
                                               comment=io_consts.default_comment_char,
                                               delimiter=io_consts.default_delimiter_char) -> np.ndarray:
        """Reads the copy number log posterior table of a sample.

        Raises:
            AssertionError: if the table is malformed or its shape is not
                (workspace number of intervals, configured number of copy-number states)
        """
        return self._read_ndarray_tc_with_expected_shape(
            sample_posterior_path,
            io_consts.default_copy_number_log_posterior_tsv_filename,
            comment,
            delimiter)

    def _read_sample_copy_number_log_emission(self,
                                              sample_posterior_path: str,
                                              comment=io_consts.default_comment_char,
                                              delimiter=io_consts.default_delimiter_char) -> np.ndarray:
        """Reads the copy number log emission table of a sample.

        Raises:
            AssertionError: if the table is malformed or its shape is not
                (workspace number of intervals, configured number of copy-number states)
        """
        return self._read_ndarray_tc_with_expected_shape(
            sample_posterior_path,
            io_consts.default_copy_number_log_emission_tsv_filename,
            comment,
            delimiter)

    def __call__(self):
        """Loads the posteriors of all samples from ``self.input_calls_path``.

        Raises:
            AssertionError: if the interval list file or a sample folder is missing, if the stored
                interval list differs from the workspace interval list, or if a table is malformed
        """
        workspace = self.denoising_calling_workspace

        # assert that the interval list is the same
        interval_list_tsv_file = os.path.join(self.input_calls_path, io_consts.default_interval_list_filename)
        assert os.path.exists(interval_list_tsv_file), \
            "Interval list file \"{0}\" does not exist".format(interval_list_tsv_file)
        read_interval_list = io_intervals_and_counts.load_interval_list_tsv_file(interval_list_tsv_file)
        assert read_interval_list == workspace.interval_list, \
            "The interval list in \"{0}\" does not match the interval list of the workspace".format(
                interval_list_tsv_file)

        for si in range(workspace.num_samples):
            sample_posterior_path = get_sample_posterior_path(self.input_calls_path, si)
            assert os.path.exists(sample_posterior_path), \
                "Posterior path \"{0}\" for sample index {1} does not exist".format(sample_posterior_path, si)

            # read sample-specific posteriors and update approximation
            io_commons.read_mean_field_sample_specific_params(
                sample_posterior_path, si, workspace.sample_names[si],
                self.denoising_model_approx, self.denoising_model)

            # read copy number posterior and emission
            log_q_c_tc = self._read_sample_copy_number_log_posterior(sample_posterior_path)
            log_copy_number_emission_tc = self._read_sample_copy_number_log_emission(sample_posterior_path)

            # update the workspace: overwrite the slice of this sample in the array obtained from the
            # shared variable and hand the array back to the shared variable
            log_q_c_stc = workspace.log_q_c_stc.get_value(borrow=True)
            log_q_c_stc[si, ...] = log_q_c_tc[...]
            workspace.log_q_c_stc.set_value(log_q_c_stc, borrow=True)

            log_copy_number_emission_stc = workspace.log_copy_number_emission_stc.get_value(borrow=True)
            log_copy_number_emission_stc[si, ...] = log_copy_number_emission_tc[...]
            workspace.log_copy_number_emission_stc.set_value(log_copy_number_emission_stc, borrow=True)

        # update auxiliary workspace variables
        workspace.update_auxiliary_vars()
313