import joblib
from logging import getLogger
import os
import shutil
import tarfile

import numpy
import pandas

from chainer.dataset import download

from chainer_chemistry.dataset.parsers.csv_file_parser import CSVFileParser
from chainer_chemistry.dataset.preprocessors.atomic_number_preprocessor import AtomicNumberPreprocessor  # NOQA
from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter
from chainer_chemistry.dataset.splitters.scaffold_splitter import ScaffoldSplitter  # NOQA
from chainer_chemistry.dataset.splitters.deepchem_scaffold_splitter import DeepChemScaffoldSplitter  # NOQA
from chainer_chemistry.dataset.splitters import split_method_dict
from chainer_chemistry.datasets.molnet.molnet_config import molnet_default_config  # NOQA
from chainer_chemistry.datasets.molnet.pdbbind_time import get_pdbbind_time
from chainer_chemistry.datasets.numpy_tuple_dataset import NumpyTupleDataset

# Sub-directory (under the chainer dataset root) used to cache downloads.
_root = 'pfnet/chainer/molnet'


def get_molnet_dataset(dataset_name, preprocessor=None, labels=None,
                       split=None, frac_train=.8, frac_valid=.1,
                       frac_test=.1, seed=777, return_smiles=False,
                       return_pdb_id=False, target_index=None, task_index=0,
                       **kwargs):
    """Downloads, caches and preprocess MoleculeNet dataset.

    Args:
        dataset_name (str): MoleculeNet dataset name. If you want to know the
            detail of MoleculeNet, please refer to
            `official site <http://moleculenet.ai/datasets-1>`_
            If you would like to know what dataset_name is available for
            chainer_chemistry, please refer to `molnet_config.py`.
        preprocessor (BasePreprocessor): Preprocessor.
            It should be chosen based on the network to be trained.
            If it is None, default `AtomicNumberPreprocessor` is used.
        labels (str or list): List of target labels.
        split (str or BaseSplitter or None): How to split dataset into train,
            validation and test. If `None`, this function uses the splitter
            that is recommended by MoleculeNet. Additionally, you can use an
            instance of BaseSplitter or choose it from 'random', 'stratified'
            and 'scaffold'.
        seed (int): Random seed. Currently unused here; kept for backward
            compatibility of the signature.
        return_smiles (bool): If set to ``True``,
            smiles array is also returned.
        return_pdb_id (bool): If set to ``True``,
            PDB ID array is also returned.
            This argument is only used when you select 'pdbbind_smiles'.
        target_index (list or None): target index list to partially extract
            dataset. If `None` (default), all examples are parsed.
        task_index (int): Target task index in dataset for stratification.
            (Stratified Splitter only)
    Returns (dict):
        Dictionary that contains dataset that is already split into train,
        valid and test dataset and 1-d numpy array with dtype=object(string)
        which is a vector of smiles for each example or `None`.

    """
    if dataset_name not in molnet_default_config:
        raise ValueError("We don't support {} dataset. Please choose from {}".
                         format(dataset_name,
                                list(molnet_default_config.keys())))

    # The two PDBbind flavors need dedicated loaders; delegate early.
    if dataset_name == 'pdbbind_grid':
        pdbbind_subset = kwargs.get('pdbbind_subset')
        return get_pdbbind_grid(pdbbind_subset, split=split,
                                frac_train=frac_train, frac_valid=frac_valid,
                                frac_test=frac_test, task_index=task_index)
    if dataset_name == 'pdbbind_smiles':
        pdbbind_subset = kwargs.get('pdbbind_subset')
        time_list = kwargs.get('time_list')
        return get_pdbbind_smiles(pdbbind_subset, preprocessor=preprocessor,
                                  labels=labels, split=split,
                                  frac_train=frac_train, frac_valid=frac_valid,
                                  frac_test=frac_test,
                                  return_smiles=return_smiles,
                                  return_pdb_id=return_pdb_id,
                                  target_index=target_index,
                                  task_index=task_index,
                                  time_list=time_list)

    dataset_config = molnet_default_config[dataset_name]
    labels = labels or dataset_config['tasks']
    if isinstance(labels, str):
        labels = [labels, ]

    if preprocessor is None:
        preprocessor = AtomicNumberPreprocessor()

    if dataset_config['task_type'] == 'regression':
        def postprocess_label(label_list):
            return numpy.asarray(label_list, dtype=numpy.float32)
    elif dataset_config['task_type'] == 'classification':
        def postprocess_label(label_list):
            # Missing labels (NaN) are encoded as -1 so they can be
            # distinguished from valid integer class labels downstream.
            label_list = numpy.asarray(label_list)
            label_list[numpy.isnan(label_list)] = -1
            return label_list.astype(numpy.int32)
    else:
        # Fail fast with a clear message instead of a NameError on
        # `postprocess_label` below for unknown task types.
        raise ValueError('task_type={} is not supported'
                         .format(dataset_config['task_type']))

    parser = CSVFileParser(preprocessor, labels=labels,
                           smiles_col=dataset_config['smiles_columns'],
                           postprocess_label=postprocess_label)
    if dataset_config['dataset_type'] == 'one_file_csv':
        split = dataset_config['split'] if split is None else split

        if isinstance(split, str):
            splitter = split_method_dict[split]()
        elif isinstance(split, BaseSplitter):
            splitter = split
        else:
            raise TypeError("split must be None, str or instance of"
                            " BaseSplitter, but got {}".format(type(split)))

        # Scaffold-based splitters need the SMILES strings to compute
        # scaffolds even when the caller did not ask for them back.
        if isinstance(splitter, (ScaffoldSplitter, DeepChemScaffoldSplitter)):
            get_smiles = True
        else:
            get_smiles = return_smiles

        result = parser.parse(get_molnet_filepath(dataset_name),
                              return_smiles=get_smiles,
                              target_index=target_index, **kwargs)
        dataset = result['dataset']
        smiles = result['smiles']
        train_ind, valid_ind, test_ind = \
            splitter.train_valid_test_split(dataset, smiles_list=smiles,
                                            task_index=task_index,
                                            frac_train=frac_train,
                                            frac_valid=frac_valid,
                                            frac_test=frac_test, **kwargs)
        train = NumpyTupleDataset(*dataset.features[train_ind])
        valid = NumpyTupleDataset(*dataset.features[valid_ind])
        test = NumpyTupleDataset(*dataset.features[test_ind])

        result['dataset'] = (train, valid, test)
        if return_smiles:
            train_smiles = smiles[train_ind]
            valid_smiles = smiles[valid_ind]
            test_smiles = smiles[test_ind]
            result['smiles'] = (train_smiles, valid_smiles, test_smiles)
        else:
            result['smiles'] = None
    elif dataset_config['dataset_type'] == 'separate_csv':
        # Dataset ships pre-split; parse each file, no splitter needed.
        result = {}
        train_result = parser.parse(get_molnet_filepath(dataset_name, 'train'),
                                    return_smiles=return_smiles,
                                    target_index=target_index)
        valid_result = parser.parse(get_molnet_filepath(dataset_name, 'valid'),
                                    return_smiles=return_smiles,
                                    target_index=target_index)
        test_result = parser.parse(get_molnet_filepath(dataset_name, 'test'),
                                   return_smiles=return_smiles,
                                   target_index=target_index)
        result['dataset'] = (train_result['dataset'], valid_result['dataset'],
                             test_result['dataset'])
        result['smiles'] = (train_result['smiles'], valid_result['smiles'],
                            test_result['smiles'])
    else:
        raise ValueError('dataset_type={} is not supported'
                         .format(dataset_config['dataset_type']))

    return result


def get_molnet_dataframe(dataset_name, pdbbind_subset=None):
    """Downloads, caches and get the dataframe of MoleculeNet dataset.

    Args:
        dataset_name (str): MoleculeNet dataset name. If you want to know the
            detail of MoleculeNet, please refer to
            `official site <http://moleculenet.ai/datasets-1>`_
            If you would like to know what dataset_name is available for
            chainer_chemistry, please refer to `molnet_config.py`.
        pdbbind_subset (str): PDBbind dataset subset name. If you want to know
            the detail of subset, please refer to `official site
            <http://www.pdbbind.org.cn/download/pdbbind_2017_intro.pdf>`
    Returns (pandas.DataFrame or tuple):
        DataFrame of dataset without any preprocessing. When the files of
        dataset are separated, this function returns multiple DataFrame.

    """
    if dataset_name not in molnet_default_config:
        raise ValueError("We don't support {} dataset. Please choose from {}".
                         format(dataset_name,
                                list(molnet_default_config.keys())))
    if dataset_name == 'pdbbind_grid':
        # NOTE: was `ValueError('... Please ', 'choose ...')` -- two args,
        # which rendered the message as a tuple. Use one joined string.
        raise ValueError('pdbbind_grid dataset is not supported. Please '
                         'choose pdbbind_smiles dataset.')
    dataset_config = molnet_default_config[dataset_name]
    if dataset_config['dataset_type'] == 'one_file_csv':
        df = pandas.read_csv(get_molnet_filepath(
            dataset_name, pdbbind_subset=pdbbind_subset))
        return df
    elif dataset_config['dataset_type'] == 'separate_csv':
        train_df = pandas.read_csv(get_molnet_filepath(dataset_name, 'train'))
        valid_df = pandas.read_csv(get_molnet_filepath(dataset_name, 'valid'))
        test_df = pandas.read_csv(get_molnet_filepath(dataset_name, 'test'))
        return train_df, valid_df, test_df
    else:
        raise ValueError('dataset_type={} is not supported'
                         .format(dataset_config['dataset_type']))


def get_molnet_filepath(dataset_name, filetype='onefile',
                        download_if_not_exist=True, pdbbind_subset=None):
    """Construct a file path which stores MoleculeNet dataset.

    This method check whether the file exist or not, and downloaded it if
    necessary.

    Args:
        dataset_name (str): MoleculeNet dataset name.
        filetype (str): either 'onefile', 'train', 'valid', 'test'
        download_if_not_exist (bool): Download a file if it does not exist.
        pdbbind_subset (str or None): PDBbind subset name; only used when
            ``dataset_name`` is 'pdbbind_smiles'.

    Returns (str): filepath for specific MoleculeNet dataset

    """
    filetype_supported = ['onefile', 'train', 'valid', 'test']
    if filetype not in filetype_supported:
        raise ValueError("filetype {} not supported, please choose filetype "
                         "from {}".format(filetype, filetype_supported))
    if filetype == 'onefile':
        url_key = 'url'
    else:
        url_key = filetype + '_url'
    if dataset_name == 'pdbbind_smiles':
        # pdbbind_smiles keeps one URL per subset.
        file_url = molnet_default_config[dataset_name][url_key][pdbbind_subset]
    else:
        file_url = molnet_default_config[dataset_name][url_key]
    file_name = file_url.split('/')[-1]
    cache_path = _get_molnet_filepath(file_name)
    if not os.path.exists(cache_path):
        if download_if_not_exist:
            is_successful = download_dataset(file_url,
                                             save_filepath=cache_path)
            if not is_successful:
                logger = getLogger(__name__)
                logger.warning('Download failed.')
    return cache_path


def _get_molnet_filepath(file_name):
    """Construct a filepath which stores MoleculeNet dataset in csv

    This method does not check if the file is already downloaded or not.

    Args:
        file_name (str): file name of MoleculeNet dataset

    Returns (str): filepath for one of MoleculeNet dataset

    """
    cache_root = download.get_dataset_directory(_root)
    cache_path = os.path.join(cache_root, file_name)
    return cache_path


def download_dataset(dataset_url, save_filepath):
    """Download and caches MoleculeNet Dataset

    Args:
        dataset_url (str): URL of dataset
        save_filepath (str): filepath for dataset

    Returns (bool): If success downloading, returning `True`.

    """
    logger = getLogger(__name__)
    logger.warning('Downloading {} dataset, it takes time...'
                   .format(dataset_url.split('/')[-1]))
    download_file_path = download.cached_download(dataset_url)
    # Move (not copy) out of chainer's shared download cache into our own
    # cache location; the original cached file is consumed.
    shutil.move(download_file_path, save_filepath)
    # pandas can load gzipped or tarball csv file
    return True


def get_pdbbind_smiles(pdbbind_subset, preprocessor=None, labels=None,
                       split=None, frac_train=.8, frac_valid=.1,
                       frac_test=.1, return_smiles=False, return_pdb_id=True,
                       target_index=None, task_index=0, time_list=None,
                       **kwargs):
    """Downloads, caches and preprocess PDBbind dataset.

    Args:
        pdbbind_subset (str): PDBbind dataset subset name. If you want to know
            the detail of subset, please refer to `official site
            <http://www.pdbbind.org.cn/download/pdbbind_2017_intro.pdf>`
        preprocessor (BasePreprocessor): Preprocessor.
            It should be chosen based on the network to be trained.
            If it is None, default `AtomicNumberPreprocessor` is used.
        labels (str or list): List of target labels.
        split (str or BaseSplitter or None): How to split dataset into train,
            validation and test. If `None`, this function uses the splitter
            that is recommended by MoleculeNet. Additionally, you can use an
            instance of BaseSplitter or choose it from 'random', 'stratified'
            and 'scaffold'.
        return_smiles (bool): If set to ``True``,
            smiles array is also returned.
        return_pdb_id (bool): If set to ``True``,
            PDB ID array is also returned.
            This argument is only used when you select 'pdbbind_smiles'.
        target_index (list or None): target index list to partially extract
            dataset. If `None` (default), all examples are parsed.
        task_index (int): Target task index in dataset for stratification.
            (Stratified Splitter only)
        time_list (list or None): Timestamps forwarded to the splitter
            (used by time-based splitting).
    Returns (dict):
        Dictionary that contains dataset that is already split into train,
        valid and test dataset and 1-d numpy arrays with dtype=object(string)
        which are vectors of smiles and pdb_id for each example or `None`.

    """
    config = molnet_default_config['pdbbind_smiles']
    labels = labels or config['tasks']
    if isinstance(labels, str):
        labels = [labels, ]

    if preprocessor is None:
        preprocessor = AtomicNumberPreprocessor()

    def postprocess_label(label_list):
        # PDBbind targets (-logKd/Ki) are regression values.
        return numpy.asarray(label_list, dtype=numpy.float32)

    parser = CSVFileParser(preprocessor, labels=labels,
                           smiles_col=config['smiles_columns'],
                           postprocess_label=postprocess_label)
    split = config['split'] if split is None else split
    if isinstance(split, str):
        splitter = split_method_dict[split]()
    elif isinstance(split, BaseSplitter):
        splitter = split
    else:
        raise TypeError("split must be None, str or instance of"
                        " BaseSplitter, but got {}".format(type(split)))

    result = parser.parse(get_molnet_filepath('pdbbind_smiles',
                                              pdbbind_subset=pdbbind_subset),
                          return_smiles=return_smiles,
                          return_is_successful=True,
                          target_index=target_index)
    dataset = result['dataset']
    smiles = result['smiles']
    is_successful = result['is_successful']

    if return_pdb_id:
        # Re-read the CSV to recover PDB IDs; keep only rows that were
        # successfully parsed so indices line up with `dataset`.
        df = pandas.read_csv(
            get_molnet_filepath('pdbbind_smiles',
                                pdbbind_subset=pdbbind_subset))
        pdb_id = df['id'][is_successful]
    else:
        pdb_id = None

    train_ind, valid_ind, test_ind = \
        splitter.train_valid_test_split(dataset, time_list=time_list,
                                        smiles_list=smiles,
                                        task_index=task_index,
                                        frac_train=frac_train,
                                        frac_valid=frac_valid,
                                        frac_test=frac_test, **kwargs)
    train = NumpyTupleDataset(*dataset.features[train_ind])
    valid = NumpyTupleDataset(*dataset.features[valid_ind])
    test = NumpyTupleDataset(*dataset.features[test_ind])

    result['dataset'] = (train, valid, test)

    if return_smiles:
        train_smiles = smiles[train_ind]
        valid_smiles = smiles[valid_ind]
        test_smiles = smiles[test_ind]
        result['smiles'] = (train_smiles, valid_smiles, test_smiles)
    else:
        result['smiles'] = None

    if return_pdb_id:
        train_pdb_id = pdb_id[train_ind]
        valid_pdb_id = pdb_id[valid_ind]
        test_pdb_id = pdb_id[test_ind]
        result['pdb_id'] = (train_pdb_id, valid_pdb_id, test_pdb_id)
    else:
        result['pdb_id'] = None
    return result


def get_pdbbind_grid(pdbbind_subset, split=None, frac_train=.8, frac_valid=.1,
                     frac_test=.1, task_index=0, **kwargs):
    """Downloads, caches and grid-featurize PDBbind dataset.

    Args:
        pdbbind_subset (str): PDBbind dataset subset name. If you want to know
            the detail of subset, please refer to `official site
            <http://www.pdbbind.org.cn/download/pdbbind_2017_intro.pdf>`
        split (str or BaseSplitter or None): How to split dataset into train,
            validation and test. If `None`, this function uses the splitter
            that is recommended by MoleculeNet. Additionally, you can use an
            instance of BaseSplitter or choose it from 'random', 'stratified'
            and 'scaffold'.
        task_index (int): Target task index in dataset for stratification.
            (Stratified Splitter only)
    Returns (dict):
        Dictionary that contains dataset that is already split into train,
        valid and test dataset and 1-d numpy arrays with dtype=object(string)
        which are vectors of smiles and pdb_id for each example or `None`.

    """
    result = {}
    dataset = get_grid_featurized_pdbbind_dataset(pdbbind_subset)
    if split is None:
        split = molnet_default_config['pdbbind_grid']['split']
    if isinstance(split, str):
        splitter = split_method_dict[split]()
    elif isinstance(split, BaseSplitter):
        splitter = split
    else:
        raise TypeError("split must be None, str, or instance of"
                        " BaseSplitter, but got {}".format(type(split)))
    time_list = get_pdbbind_time()
    train_ind, valid_ind, test_ind = \
        splitter.train_valid_test_split(dataset, time_list=time_list,
                                        smiles_list=None,
                                        task_index=task_index,
                                        frac_train=frac_train,
                                        frac_valid=frac_valid,
                                        frac_test=frac_test, **kwargs)
    train = NumpyTupleDataset(*dataset.features[train_ind])
    valid = NumpyTupleDataset(*dataset.features[valid_ind])
    test = NumpyTupleDataset(*dataset.features[test_ind])

    result['dataset'] = (train, valid, test)
    # Grid features carry no SMILES; keep the key for interface parity.
    result['smiles'] = None
    return result


def get_grid_featurized_pdbbind_dataset(subset):
    """Downloads and caches grid featurized PDBBind dataset.

    Args:
        subset (str): subset name of PDBBind dataset.

    Returns (NumpyTupleDataset):
        grid featurized PDBBind dataset.

    """
    x_path, y_path = get_grid_featurized_pdbbind_filepath(subset)
    # 'i' -> int32 features, 'f' -> float32 targets.
    x = joblib.load(x_path).astype('i')
    y = joblib.load(y_path).astype('f')
    dataset = NumpyTupleDataset(x, y)
    return dataset


def get_grid_featurized_pdbbind_dirpath(subset, download_if_not_exist=True):
    """Construct a directory path which stores grid featurized PDBBind dataset.

    This method check whether the file exist or not, and downloaded it if
    necessary.

    Args:
        subset (str): subset name of PDBBind dataset.
        download_if_not_exist (bool): Download a file if it does not exist.

    Returns (str): directory path for specific subset of PDBBind dataset.

    """
    subset_supported = ['core', 'full', 'refined']
    if subset not in subset_supported:
        raise ValueError("subset {} not supported, please choose filetype "
                         "from {}".format(subset, subset_supported))
    file_url = \
        molnet_default_config['pdbbind_grid']['url'][subset]
    file_name = file_url.split('/')[-1]
    cache_path = _get_molnet_filepath(file_name)
    if not os.path.exists(cache_path):
        if download_if_not_exist:
            is_successful = download_dataset(file_url,
                                             save_filepath=cache_path)
            if not is_successful:
                logger = getLogger(__name__)
                logger.warning('Download failed.')
    return cache_path


def get_grid_featurized_pdbbind_filepath(subset):
    """Construct a filepath which stores featurized PDBBind dataset in joblib

    This method does not check if the file is already downloaded or not.

    Args:
        subset (str): subset name of PDBBind dataset

    Returns:
        x_path (str): filepath for feature vectors
        y_path (str): filepath for -logKd/Ki

    """
    dirpath = get_grid_featurized_pdbbind_dirpath(subset=subset)
    savedir = '/'.join(dirpath.split('/')[:-1]) + '/'
    # NOTE(review): extractall on an archive from a remote URL is a path
    # traversal risk; the archive comes from the project's own mirror, but
    # consider tarfile's `filter='data'` (Python 3.12+) when available.
    with tarfile.open(dirpath, 'r:gz') as tar:
        tar.extractall(savedir)
    x_path = savedir + subset + '_grid/shard-0-X.joblib'
    y_path = savedir + subset + '_grid/shard-0-y.joblib'
    return x_path, y_path