"""Command-line driver for gCNV case-mode denoising and copy-number calling.

The script parses its arguments, loads the interval list and denoising
configuration stored in a previously trained model directory, runs the
inference task on the provided read count files, and writes the calls,
the optimizer state (optional), and the ELBO history.
"""
import os
import sys

# set theano flags
# NOTE: THEANO_FLAGS is written to the environment *before* `gcnvkernel` is
# imported below; keep this ordering. User-supplied flags are appended after
# the defaults.
user_theano_flags = os.environ.get("THEANO_FLAGS")
default_theano_flags = (
    "device=cpu,floatX=float64,optimizer=fast_run,compute_test_value=ignore,"
    "openmp=true,blas.ldflags=-lmkl_rt,openmp_elemwise_minsize=10"
)
theano_flags = default_theano_flags + ("" if user_theano_flags is None else "," + user_theano_flags)
os.environ["THEANO_FLAGS"] = theano_flags

import logging
import argparse
import gcnvkernel
import shutil
import json
from typing import Dict, Any

logger = logging.getLogger("case_denoising_calling")

parser = argparse.ArgumentParser(description="gCNV case calling tool based on a previously trained model",
                                 formatter_class=gcnvkernel.cli_commons.GCNVHelpFormatter)

# logging args
gcnvkernel.cli_commons.add_logging_args_to_argparse(parser)

# add tool-specific args
# Every argument uses `default=argparse.SUPPRESS`, so an option that is not
# given on the command line is absent from the parsed namespace; the driver
# below relies on `hasattr(args, ...)` to detect optional arguments.
group = parser.add_argument_group(title="Required arguments")

group.add_argument("--input_model_path",
                   type=str,
                   required=True,
                   default=argparse.SUPPRESS,
                   help="Path to denoising model parameters")

group.add_argument("--read_count_tsv_files",
                   type=str,
                   required=True,
                   nargs='+',  # one or more
                   default=argparse.SUPPRESS,
                   help="List of read count files in the cohort (in .tsv format; must include sample name header)")

group.add_argument("--ploidy_calls_path",
                   type=str,
                   required=True,
                   default=argparse.SUPPRESS,
                   help="The path to the results of ploidy determination tool")

group.add_argument("--output_calls_path",
                   type=str,
                   required=True,
                   default=argparse.SUPPRESS,
                   help="Output path to write CNV calls")

group.add_argument("--output_opt_path",
                   type=str,
                   required=False,
                   default=argparse.SUPPRESS,
                   help="(advanced) Output path to write the latest optimizer state")

group.add_argument("--output_tracking_path",
                   type=str,
                   required=True,
                   default=argparse.SUPPRESS,
                   help="Output path to write tracked parameters, ELBO, etc.")

group.add_argument("--input_calls_path",
                   type=str,
                   required=False,
                   default=argparse.SUPPRESS,
                   help="Path to previously obtained calls to take as starting point")

group.add_argument("--input_opt_path",
                   type=str,
                   required=False,
                   default=argparse.SUPPRESS,
                   help="(advanced) Path to saved optimizer state to take as the starting point")

# add denoising config args
# Note: we are hiding parameters that are either set by the model or are irrelevant to the case calling task
gcnvkernel.DenoisingModelConfig.expose_args(
    parser,
    hide={
        "--max_bias_factors",
        "--psi_t_scale",
        "--log_mean_bias_std",
        "--init_ard_rel_unexplained_variance",
        "--enable_bias_factors",
        "--enable_explicit_gc_bias_modeling",
        "--disable_bias_factors_in_active_class",
        "--num_gc_bins",
        "--gc_curve_sd",
    })

# add calling config args
# Note: we are hiding parameters that are either set by the model or are irrelevant to the case calling task
gcnvkernel.CopyNumberCallingConfig.expose_args(
    parser,
    hide={
        '--p_active',
        '--class_coherence_length'
    })

# override some inference parameters
gcnvkernel.HybridInferenceParameters.expose_args(parser)


def update_args_dict_from_saved_model(input_model_path: str,
                                      _args_dict: Dict[str, Any]):
    """Copy model-determined denoising settings from a saved model into an args dict.

    Reads ``denoising_config.json`` from ``input_model_path`` and overwrites the
    corresponding entries of ``_args_dict`` in place, then logs the values that
    were loaded.

    Args:
        input_model_path: directory containing ``denoising_config.json``
        _args_dict: dictionary of parsed arguments; modified in place

    Raises:
        KeyError: if the saved configuration lacks one of the expected keys
    """
    logger.info("Loading denoising model configuration from the provided model...")
    with open(os.path.join(input_model_path, "denoising_config.json"), 'r') as fp:
        loaded_denoising_config_dict = json.load(fp)

    # keys are copied in a fixed order: boolean flags, bias factor related, gc-related
    model_determined_keys = (
        'enable_bias_factors',
        'enable_explicit_gc_bias_modeling',
        'disable_bias_factors_in_active_class',
        'max_bias_factors',
        'num_gc_bins',
        'gc_curve_sd',
    )
    for key in model_determined_keys:
        _args_dict[key] = loaded_denoising_config_dict[key]

    logger.info("- bias factors enabled: "
                + repr(_args_dict['enable_bias_factors']))
    logger.info("- explicit GC bias modeling enabled: "
                + repr(_args_dict['enable_explicit_gc_bias_modeling']))
    logger.info("- bias factors in active classes disabled: "
                + repr(_args_dict['disable_bias_factors_in_active_class']))

    # the remaining settings are only reported when the feature they belong to is enabled
    if _args_dict['enable_bias_factors']:
        logger.info("- maximum number of bias factors: "
                    + repr(_args_dict['max_bias_factors']))

    if _args_dict['enable_explicit_gc_bias_modeling']:
        logger.info("- number of GC curve knobs: "
                    + repr(_args_dict['num_gc_bins']))
        logger.info("- GC curve prior standard deviation: "
                    + repr(_args_dict['gc_curve_sd']))


if __name__ == "__main__":

    # parse arguments
    args = parser.parse_args()
    gcnvkernel.cli_commons.set_logging_config_from_args(args)

    logger.info("THEANO_FLAGS environment variable has been set to: {theano_flags}".format(theano_flags=theano_flags))

    # check gcnvkernel version in the input model path
    gcnvkernel.io_commons.check_gcnvkernel_version_from_path(args.input_model_path)

    # copy the intervals to the calls path
    # (we do this early to avoid inadvertent cleanup of temporary files)
    gcnvkernel.io_commons.assert_output_path_writable(args.output_calls_path)
    shutil.copy(os.path.join(args.input_model_path, gcnvkernel.io_consts.default_interval_list_filename),
                os.path.join(args.output_calls_path, gcnvkernel.io_consts.default_interval_list_filename))

    # load modeling interval list from the model
    logger.info("Loading modeling interval list from the provided model...")
    modeling_interval_list = gcnvkernel.io_intervals_and_counts.load_interval_list_tsv_file(
        os.path.join(args.input_model_path, gcnvkernel.io_consts.default_interval_list_filename))
    contigs_set = {target.contig for target in modeling_interval_list}
    logger.info("The model contains {0} intervals and {1} contig(s)".format(
        len(modeling_interval_list), len(contigs_set)))

    # load sample names, truncated counts, and interval list from the sample read counts table
    logger.info("Loading {0} read counts file(s)...".format(len(args.read_count_tsv_files)))
    sample_names, n_st = gcnvkernel.io_intervals_and_counts.load_counts_in_the_modeling_zone(
        args.read_count_tsv_files, modeling_interval_list)

    # load read depth and ploidy metadata
    sample_metadata_collection: gcnvkernel.SampleMetadataCollection = gcnvkernel.SampleMetadataCollection()
    gcnvkernel.io_metadata.update_sample_metadata_collection_from_ploidy_determination_calls(
        sample_metadata_collection, args.ploidy_calls_path)

    # setup the inference task
    # (vars() returns the namespace's own attribute dict, so the update below is visible through `args` too)
    args_dict = vars(args)

    # read model configuration and update args dict
    update_args_dict_from_saved_model(args.input_model_path, args_dict)

    # instantiate config classes
    denoising_config = gcnvkernel.DenoisingModelConfig.from_args_dict(args_dict)
    calling_config = gcnvkernel.CopyNumberCallingConfig.from_args_dict(args_dict)
    inference_params = gcnvkernel.HybridInferenceParameters.from_args_dict(args_dict)

    # instantiate and initialize the workspace
    shared_workspace = gcnvkernel.DenoisingCallingWorkspace(
        denoising_config, calling_config, modeling_interval_list,
        n_st, sample_names, sample_metadata_collection)

    initial_params_supplier = gcnvkernel.DefaultDenoisingModelInitializer(
        denoising_config, calling_config, shared_workspace)

    task = gcnvkernel.CaseDenoisingCallingTask(
        denoising_config, calling_config, inference_params,
        shared_workspace, initial_params_supplier, args.input_model_path)

    if hasattr(args, 'input_calls_path'):
        logger.info("A call path was provided to use as starting point...")
        gcnvkernel.io_denoising_calling.SampleDenoisingAndCallingPosteriorsReader(
            shared_workspace, task.continuous_model, task.continuous_model_approx,
            args.input_calls_path)()

    if hasattr(args, 'input_opt_path'):
        logger.info("A saved optimizer state was provided to use as starting point...")
        task.fancy_opt.load(args.input_opt_path)

    try:
        # go!
        task.engage()
        task.disengage()
    except gcnvkernel.ConvergenceError as err:
        # Python 3 exceptions carry no `message` attribute unless the class sets
        # one; fall back to str(err) so that a missing attribute cannot mask the
        # exit code below.
        logger.info(getattr(err, "message", str(err)))
        # if inference diverged, pass an exit code to the Java side indicating that restart is needed
        sys.exit(gcnvkernel.io_consts.diverged_inference_exit_code)

    # save calls
    gcnvkernel.io_denoising_calling.SampleDenoisingAndCallingPosteriorsWriter(
        denoising_config, calling_config, shared_workspace, task.continuous_model, task.continuous_model_approx,
        args.output_calls_path)()

    # save optimizer state
    if hasattr(args, 'output_opt_path'):
        task.fancy_opt.save(args.output_opt_path)

    # save ELBO history
    if hasattr(args, 'output_tracking_path'):
        gcnvkernel.io_commons.assert_output_path_writable(args.output_tracking_path)

        elbo_hist_file = os.path.join(args.output_tracking_path, "elbo_history.tsv")
        task.save_elbo_history(elbo_hist_file)