1import os
2import sys
3
# Set THEANO_FLAGS before gcnvkernel is imported further down (presumably gcnvkernel imports theano,
# which reads this variable at import time -- confirm). The tool's defaults come first and any
# user-supplied flags are appended after them, presumably so that user settings override the defaults.
default_theano_flags = ("device=cpu,floatX=float64,optimizer=fast_run,compute_test_value=ignore,"
                        "openmp=true,blas.ldflags=-lmkl_rt,openmp_elemwise_minsize=10")
user_theano_flags = os.environ.get("THEANO_FLAGS")
if user_theano_flags is not None:
    theano_flags = default_theano_flags + "," + user_theano_flags
else:
    theano_flags = default_theano_flags
os.environ["THEANO_FLAGS"] = theano_flags
10
11import logging
12import argparse
13import gcnvkernel
14import shutil
15import json
16from typing import Dict, Any
17
logger = logging.getLogger("case_denoising_calling")

parser = argparse.ArgumentParser(description="gCNV case calling tool based on a previously trained model",
                                 formatter_class=gcnvkernel.cli_commons.GCNVHelpFormatter)

# logging args
gcnvkernel.cli_commons.add_logging_args_to_argparse(parser)

# add tool-specific args
# Required and optional tool arguments are registered in separate help groups so that "--help" does not
# list optional arguments under "Required arguments".
# Note: the optional arguments use default=argparse.SUPPRESS, so the attribute is absent from the parsed
# namespace unless the argument is given on the command line; the main script tests for them with hasattr().
group = parser.add_argument_group(title="Required arguments")
optional_group = parser.add_argument_group(title="Optional arguments")

group.add_argument("--input_model_path",
                   type=str,
                   required=True,
                   default=argparse.SUPPRESS,
                   help="Path to denoising model parameters")

group.add_argument("--read_count_tsv_files",
                   type=str,
                   required=True,
                   nargs='+',  # one or more
                   default=argparse.SUPPRESS,
                   help="List of read count files in the cohort (in .tsv format; must include sample name header)")

group.add_argument("--ploidy_calls_path",
                   type=str,
                   required=True,
                   default=argparse.SUPPRESS,
                   help="The path to the results of ploidy determination tool")

group.add_argument("--output_calls_path",
                   type=str,
                   required=True,
                   default=argparse.SUPPRESS,
                   help="Output path to write CNV calls")

group.add_argument("--output_tracking_path",
                   type=str,
                   required=True,
                   default=argparse.SUPPRESS,
                   help="Output path to write tracked parameters, ELBO, etc.")

optional_group.add_argument("--output_opt_path",
                            type=str,
                            required=False,
                            default=argparse.SUPPRESS,
                            help="(advanced) Output path to write the latest optimizer state")

optional_group.add_argument("--input_calls_path",
                            type=str,
                            required=False,
                            default=argparse.SUPPRESS,
                            help="Path to previously obtained calls to take as starting point")

optional_group.add_argument("--input_opt_path",
                            type=str,
                            required=False,
                            default=argparse.SUPPRESS,
                            help="(advanced) Path to saved optimizer state to take as the starting point")

# add denoising config args
# Note: we are hiding parameters that are either set by the model or are irrelevant to the case calling task
gcnvkernel.DenoisingModelConfig.expose_args(
    parser,
    hide={
        "--max_bias_factors",
        "--psi_t_scale",
        "--log_mean_bias_std",
        "--init_ard_rel_unexplained_variance",
        "--enable_bias_factors",
        "--enable_explicit_gc_bias_modeling",
        "--disable_bias_factors_in_active_class",
        "--num_gc_bins",
        "--gc_curve_sd",
    })

# add calling config args
# Note: we are hiding parameters that are either set by the model or are irrelevant to the case calling task
gcnvkernel.CopyNumberCallingConfig.expose_args(
    parser,
    hide={
        '--p_active',
        '--class_coherence_length'
    })

# override some inference parameters
gcnvkernel.HybridInferenceParameters.expose_args(parser)
105
106
def update_args_dict_from_saved_model(input_model_path: str,
                                      _args_dict: Dict[str, Any]) -> None:
    """Overwrite model-determined denoising settings in ``_args_dict`` with the values saved in the model.

    Reads ``denoising_config.json`` from ``input_model_path``. It copies the bias-factor and GC-bias
    settings into ``_args_dict`` in place, then logs the settings that were loaded.

    Args:
        input_model_path: directory of the previously trained model (must contain ``denoising_config.json``)
        _args_dict: parsed command-line arguments as a mutable dict; updated in place

    Raises:
        OSError: if ``denoising_config.json`` cannot be opened
        ValueError: if the file is not valid JSON (``json.JSONDecodeError``)
        KeyError: if one of the expected keys is missing from the saved configuration
    """
    logger.info("Loading denoising model configuration from the provided model...")
    config_path = os.path.join(input_model_path, "denoising_config.json")
    with open(config_path, 'r', encoding='utf-8') as fp:
        loaded_denoising_config_dict = json.load(fp)

    # These settings are fixed by the trained model (they are hidden from the command line), so the
    # saved values always take precedence over whatever is in the args dict.
    for key in ('enable_bias_factors',                   # boolean flags
                'enable_explicit_gc_bias_modeling',
                'disable_bias_factors_in_active_class',
                'max_bias_factors',                      # bias factor related
                'num_gc_bins',                           # gc-related
                'gc_curve_sd'):
        _args_dict[key] = loaded_denoising_config_dict[key]

    logger.info("- bias factors enabled: "
                + repr(_args_dict['enable_bias_factors']))
    logger.info("- explicit GC bias modeling enabled: "
                + repr(_args_dict['enable_explicit_gc_bias_modeling']))
    logger.info("- bias factors in active classes disabled: "
                + repr(_args_dict['disable_bias_factors_in_active_class']))

    # the remaining settings only matter when the corresponding feature is enabled
    if _args_dict['enable_bias_factors']:
        logger.info("- maximum number of bias factors: "
                    + repr(_args_dict['max_bias_factors']))

    if _args_dict['enable_explicit_gc_bias_modeling']:
        logger.info("- number of GC curve knobs: "
                    + repr(_args_dict['num_gc_bins']))
        logger.info("- GC curve prior standard deviation: "
                    + repr(_args_dict['gc_curve_sd']))
147
148
if __name__ == "__main__":

    # parse arguments
    args = parser.parse_args()
    gcnvkernel.cli_commons.set_logging_config_from_args(args)

    logger.info("THEANO_FLAGS environment variable has been set to: {theano_flags}".format(theano_flags=theano_flags))

    # check gcnvkernel version in the input model path
    gcnvkernel.io_commons.check_gcnvkernel_version_from_path(args.input_model_path)

    # copy the intervals to the calls path
    # (we do this early to avoid inadvertent cleanup of temporary files)
    gcnvkernel.io_commons.assert_output_path_writable(args.output_calls_path)
    shutil.copy(os.path.join(args.input_model_path, gcnvkernel.io_consts.default_interval_list_filename),
                os.path.join(args.output_calls_path, gcnvkernel.io_consts.default_interval_list_filename))

    # load modeling interval list from the model
    # (all logging goes through the module logger, consistently with the rest of the script)
    logger.info("Loading modeling interval list from the provided model...")
    modeling_interval_list = gcnvkernel.io_intervals_and_counts.load_interval_list_tsv_file(
        os.path.join(args.input_model_path, gcnvkernel.io_consts.default_interval_list_filename))
    contigs_set = {target.contig for target in modeling_interval_list}
    logger.info("The model contains {0} intervals and {1} contig(s)".format(
        len(modeling_interval_list), len(contigs_set)))

    # load sample names, truncated counts, and interval list from the sample read counts table
    logger.info("Loading {0} read counts file(s)...".format(len(args.read_count_tsv_files)))
    sample_names, n_st = gcnvkernel.io_intervals_and_counts.load_counts_in_the_modeling_zone(
        args.read_count_tsv_files, modeling_interval_list)

    # load read depth and ploidy metadata
    sample_metadata_collection: gcnvkernel.SampleMetadataCollection = gcnvkernel.SampleMetadataCollection()
    gcnvkernel.io_metadata.update_sample_metadata_collection_from_ploidy_determination_calls(
        sample_metadata_collection, args.ploidy_calls_path)

    # setup the inference task
    # (args_dict is the namespace's own attribute dict, so the updates below are also visible through args)
    args_dict = args.__dict__

    # read model configuration and update args dict
    update_args_dict_from_saved_model(args.input_model_path, args_dict)

    # instantiate config classes
    denoising_config = gcnvkernel.DenoisingModelConfig.from_args_dict(args_dict)
    calling_config = gcnvkernel.CopyNumberCallingConfig.from_args_dict(args_dict)
    inference_params = gcnvkernel.HybridInferenceParameters.from_args_dict(args_dict)

    # instantiate and initialize the workspace
    shared_workspace = gcnvkernel.DenoisingCallingWorkspace(
        denoising_config, calling_config, modeling_interval_list,
        n_st, sample_names, sample_metadata_collection)

    initial_params_supplier = gcnvkernel.DefaultDenoisingModelInitializer(
        denoising_config, calling_config, shared_workspace)

    task = gcnvkernel.CaseDenoisingCallingTask(
        denoising_config, calling_config, inference_params,
        shared_workspace, initial_params_supplier, args.input_model_path)

    # optional arguments default to argparse.SUPPRESS, so they are attributes of args only when provided
    if hasattr(args, 'input_calls_path'):
        logger.info("A call path was provided to use as starting point...")
        gcnvkernel.io_denoising_calling.SampleDenoisingAndCallingPosteriorsReader(
            shared_workspace, task.continuous_model, task.continuous_model_approx,
            args.input_calls_path)()

    if hasattr(args, 'input_opt_path'):
        logger.info("A saved optimizer state was provided to use as starting point...")
        task.fancy_opt.load(args.input_opt_path)

    try:
        # go!
        task.engage()
        task.disengage()
    except gcnvkernel.ConvergenceError as err:
        # Python 3 exceptions have no built-in `message` attribute; use it if ConvergenceError
        # defines one, otherwise fall back to the exception's string form
        logger.info(getattr(err, 'message', str(err)))
        # if inference diverged, pass an exit code to the Java side indicating that restart is needed
        sys.exit(gcnvkernel.io_consts.diverged_inference_exit_code)

    # save calls
    gcnvkernel.io_denoising_calling.SampleDenoisingAndCallingPosteriorsWriter(
        denoising_config, calling_config, shared_workspace, task.continuous_model, task.continuous_model_approx,
        args.output_calls_path)()

    # save optimizer state
    if hasattr(args, 'output_opt_path'):
        task.fancy_opt.save(args.output_opt_path)

    # save ELBO history
    if hasattr(args, 'output_tracking_path'):
        gcnvkernel.io_commons.assert_output_path_writable(args.output_tracking_path)

        elbo_hist_file = os.path.join(args.output_tracking_path, "elbo_history.tsv")
        task.save_elbo_history(elbo_hist_file)
242