import joblib
from logging import getLogger
import os
import shutil
import tarfile

import numpy
import pandas

from chainer.dataset import download

from chainer_chemistry.dataset.parsers.csv_file_parser import CSVFileParser
from chainer_chemistry.dataset.preprocessors.atomic_number_preprocessor import AtomicNumberPreprocessor  # NOQA
from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter
from chainer_chemistry.dataset.splitters.scaffold_splitter import ScaffoldSplitter  # NOQA
from chainer_chemistry.dataset.splitters.deepchem_scaffold_splitter import DeepChemScaffoldSplitter  # NOQA
from chainer_chemistry.dataset.splitters import split_method_dict
from chainer_chemistry.datasets.molnet.molnet_config import molnet_default_config  # NOQA
from chainer_chemistry.datasets.molnet.pdbbind_time import get_pdbbind_time
from chainer_chemistry.datasets.numpy_tuple_dataset import NumpyTupleDataset

_root = 'pfnet/chainer/molnet'


def get_molnet_dataset(dataset_name, preprocessor=None, labels=None,
                       split=None, frac_train=.8, frac_valid=.1,
                       frac_test=.1, seed=777, return_smiles=False,
                       return_pdb_id=False, target_index=None, task_index=0,
                       **kwargs):
    """Downloads, caches and preprocesses a MoleculeNet dataset.

    Args:
        dataset_name (str): MoleculeNet dataset name. For details about
            MoleculeNet, please refer to the
            `official site <http://moleculenet.ai/datasets-1>`_.
            For the list of dataset names available in chainer_chemistry,
            please refer to `molnet_config.py`.
        preprocessor (BasePreprocessor): Preprocessor.
            It should be chosen based on the network to be trained.
            If it is None, the default `AtomicNumberPreprocessor` is used.
        labels (str or list): List of target labels.
        split (str or BaseSplitter or None): How to split the dataset into
            train, validation and test. If `None`, this function uses the
            splitter recommended by MoleculeNet. Alternatively, you can pass
            an instance of BaseSplitter, or choose from 'random', 'stratified'
            and 'scaffold'.
        return_smiles (bool): If set to ``True``,
            the SMILES array is also returned.
        return_pdb_id (bool): If set to ``True``,
            the PDB ID array is also returned.
            This argument is only used when you select 'pdbbind_smiles'.
        target_index (list or None): target index list to partially extract
            the dataset. If `None` (default), all examples are parsed.
        task_index (int): Target task index in the dataset, used for
            stratification. (Stratified Splitter only)
    Returns (dict):
        Dictionary that contains the dataset, already split into train,
        valid and test subsets, and a 1-d numpy array with dtype=object
        (string) which is a vector of SMILES for each example, or `None`.

    """
    if dataset_name not in molnet_default_config:
        raise ValueError("We don't support {} dataset. Please choose from {}".
                         format(dataset_name,
                                list(molnet_default_config.keys())))

    if dataset_name == 'pdbbind_grid':
        pdbbind_subset = kwargs.get('pdbbind_subset')
        return get_pdbbind_grid(pdbbind_subset, split=split,
                                frac_train=frac_train, frac_valid=frac_valid,
                                frac_test=frac_test, task_index=task_index)
    if dataset_name == 'pdbbind_smiles':
        pdbbind_subset = kwargs.get('pdbbind_subset')
        time_list = kwargs.get('time_list')
        return get_pdbbind_smiles(pdbbind_subset, preprocessor=preprocessor,
                                  labels=labels, split=split,
                                  frac_train=frac_train, frac_valid=frac_valid,
                                  frac_test=frac_test,
                                  return_smiles=return_smiles,
                                  return_pdb_id=return_pdb_id,
                                  target_index=target_index,
                                  task_index=task_index,
                                  time_list=time_list)

    dataset_config = molnet_default_config[dataset_name]
    labels = labels or dataset_config['tasks']
    if isinstance(labels, str):
        labels = [labels, ]

    if preprocessor is None:
        preprocessor = AtomicNumberPreprocessor()

    if dataset_config['task_type'] == 'regression':
        def postprocess_label(label_list):
            return numpy.asarray(label_list, dtype=numpy.float32)
    elif dataset_config['task_type'] == 'classification':
        def postprocess_label(label_list):
            # Missing labels (NaN) are encoded as -1 so that they can be
            # represented in an integer array.
            label_list = numpy.asarray(label_list)
            label_list[numpy.isnan(label_list)] = -1
            return label_list.astype(numpy.int32)
    else:
        raise ValueError('task_type={} is not supported'
                         .format(dataset_config['task_type']))
    parser = CSVFileParser(preprocessor, labels=labels,
                           smiles_col=dataset_config['smiles_columns'],
                           postprocess_label=postprocess_label)
    if dataset_config['dataset_type'] == 'one_file_csv':
        split = dataset_config['split'] if split is None else split

        if isinstance(split, str):
            splitter = split_method_dict[split]()
        elif isinstance(split, BaseSplitter):
            splitter = split
        else:
            raise TypeError("split must be None, str or instance of"
                            " BaseSplitter, but got {}".format(type(split)))

        if isinstance(splitter, (ScaffoldSplitter, DeepChemScaffoldSplitter)):
            # Scaffold splitting needs SMILES strings even when the caller
            # did not ask for them.
            get_smiles = True
        else:
            get_smiles = return_smiles

        result = parser.parse(get_molnet_filepath(dataset_name),
                              return_smiles=get_smiles,
                              target_index=target_index, **kwargs)
        dataset = result['dataset']
        smiles = result['smiles']
        train_ind, valid_ind, test_ind = \
            splitter.train_valid_test_split(dataset, smiles_list=smiles,
                                            task_index=task_index,
                                            frac_train=frac_train,
                                            frac_valid=frac_valid,
                                            frac_test=frac_test, **kwargs)
        train = NumpyTupleDataset(*dataset.features[train_ind])
        valid = NumpyTupleDataset(*dataset.features[valid_ind])
        test = NumpyTupleDataset(*dataset.features[test_ind])

        result['dataset'] = (train, valid, test)
        if return_smiles:
            train_smiles = smiles[train_ind]
            valid_smiles = smiles[valid_ind]
            test_smiles = smiles[test_ind]
            result['smiles'] = (train_smiles, valid_smiles, test_smiles)
        else:
            result['smiles'] = None
    elif dataset_config['dataset_type'] == 'separate_csv':
        result = {}
        train_result = parser.parse(get_molnet_filepath(dataset_name, 'train'),
                                    return_smiles=return_smiles,
                                    target_index=target_index)
        valid_result = parser.parse(get_molnet_filepath(dataset_name, 'valid'),
                                    return_smiles=return_smiles,
                                    target_index=target_index)
        test_result = parser.parse(get_molnet_filepath(dataset_name, 'test'),
                                   return_smiles=return_smiles,
                                   target_index=target_index)
        result['dataset'] = (train_result['dataset'], valid_result['dataset'],
                             test_result['dataset'])
        result['smiles'] = (train_result['smiles'], valid_result['smiles'],
                            test_result['smiles'])
    else:
        raise ValueError('dataset_type={} is not supported'
                         .format(dataset_config['dataset_type']))

    return result
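
# A minimal usage sketch for `get_molnet_dataset` (commented out so the module
# stays import-safe). The dataset name 'bbbp' is an assumption; use any key
# present in `molnet_default_config`:
#
#     data = get_molnet_dataset('bbbp', return_smiles=True)
#     train, valid, test = data['dataset']
#     train_smiles, valid_smiles, test_smiles = data['smiles']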


def get_molnet_dataframe(dataset_name, pdbbind_subset=None):
    """Downloads, caches and returns the DataFrame of a MoleculeNet dataset.

    Args:
        dataset_name (str): MoleculeNet dataset name. For details about
            MoleculeNet, please refer to the
            `official site <http://moleculenet.ai/datasets-1>`_.
            For the list of dataset names available in chainer_chemistry,
            please refer to `molnet_config.py`.
        pdbbind_subset (str): PDBbind dataset subset name. For details about
            the subsets, please refer to the `official site
            <http://www.pdbbind.org.cn/download/pdbbind_2017_intro.pdf>`_
    Returns (pandas.DataFrame or tuple):
        DataFrame of the dataset without any preprocessing. When the dataset
        is distributed as separate files, this function returns multiple
        DataFrames.

    """
    if dataset_name not in molnet_default_config:
        raise ValueError("We don't support {} dataset. Please choose from {}".
                         format(dataset_name,
                                list(molnet_default_config.keys())))
    if dataset_name == 'pdbbind_grid':
        raise ValueError('pdbbind_grid dataset is not supported. Please '
                         'choose pdbbind_smiles dataset.')
    dataset_config = molnet_default_config[dataset_name]
    if dataset_config['dataset_type'] == 'one_file_csv':
        df = pandas.read_csv(get_molnet_filepath(
            dataset_name, pdbbind_subset=pdbbind_subset))
        return df
    elif dataset_config['dataset_type'] == 'separate_csv':
        train_df = pandas.read_csv(get_molnet_filepath(dataset_name, 'train'))
        valid_df = pandas.read_csv(get_molnet_filepath(dataset_name, 'valid'))
        test_df = pandas.read_csv(get_molnet_filepath(dataset_name, 'test'))
        return train_df, valid_df, test_df
    else:
        raise ValueError('dataset_type={} is not supported'
                         .format(dataset_config['dataset_type']))
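
# A minimal usage sketch for `get_molnet_dataframe` (commented out; 'bbbp' is
# an assumed dataset name, use any key in `molnet_default_config`):
#
#     df = get_molnet_dataframe('bbbp')
#     print(df.head())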


def get_molnet_filepath(dataset_name, filetype='onefile',
                        download_if_not_exist=True, pdbbind_subset=None):
    """Construct a file path which stores a MoleculeNet dataset.

    This method checks whether the file exists or not, and downloads it if
    necessary.

    Args:
        dataset_name (str): MoleculeNet dataset name.
        filetype (str): either 'onefile', 'train', 'valid' or 'test'
        download_if_not_exist (bool): Download the file if it does not exist.
        pdbbind_subset (str): PDBbind dataset subset name. Only used for the
            'pdbbind_smiles' dataset.

    Returns (str): file path for the specified MoleculeNet dataset

    """
    filetype_supported = ['onefile', 'train', 'valid', 'test']
    if filetype not in filetype_supported:
        raise ValueError("filetype {} not supported, please choose filetype "
                         "from {}".format(filetype, filetype_supported))
    if filetype == 'onefile':
        url_key = 'url'
    else:
        url_key = filetype + '_url'
    if dataset_name == 'pdbbind_smiles':
        file_url = molnet_default_config[dataset_name][url_key][pdbbind_subset]
    else:
        file_url = molnet_default_config[dataset_name][url_key]
    file_name = file_url.split('/')[-1]
    cache_path = _get_molnet_filepath(file_name)
    if not os.path.exists(cache_path):
        if download_if_not_exist:
            is_successful = download_dataset(file_url,
                                             save_filepath=cache_path)
            if not is_successful:
                logger = getLogger(__name__)
                logger.warning('Download failed.')
    return cache_path
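
# A minimal usage sketch for `get_molnet_filepath` (commented out; 'bbbp' is
# an assumed dataset name):
#
#     path = get_molnet_filepath('bbbp')  # downloads and caches on first call
#     df = pandas.read_csv(path)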


def _get_molnet_filepath(file_name):
    """Construct a file path which stores a MoleculeNet dataset in csv

    This method does not check if the file has already been downloaded or not.

    Args:
        file_name (str): file name of the MoleculeNet dataset

    Returns (str): file path for one of the MoleculeNet datasets

    """
    cache_root = download.get_dataset_directory(_root)
    cache_path = os.path.join(cache_root, file_name)
    return cache_path


def download_dataset(dataset_url, save_filepath):
    """Downloads and caches a MoleculeNet dataset.

    Args:
        dataset_url (str): URL of the dataset
        save_filepath (str): file path for the dataset

    Returns (bool): `True` if the download succeeded.

    """
    logger = getLogger(__name__)
    logger.warning('Downloading {} dataset, it takes time...'
                   .format(dataset_url.split('/')[-1]))
    download_file_path = download.cached_download(dataset_url)
    shutil.move(download_file_path, save_filepath)
    # pandas can load gzipped or tarball csv files directly,
    # so no decompression is needed here.
    return True
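
# A minimal usage sketch for `download_dataset` (commented out; the URL is a
# placeholder assumption, not a real endpoint):
#
#     url = 'https://example.com/some_dataset.csv.gz'
#     ok = download_dataset(url, save_filepath='/tmp/some_dataset.csv.gz')
#     assert ok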


def get_pdbbind_smiles(pdbbind_subset, preprocessor=None, labels=None,
                       split=None, frac_train=.8, frac_valid=.1,
                       frac_test=.1, return_smiles=False, return_pdb_id=True,
                       target_index=None, task_index=0, time_list=None,
                       **kwargs):
    """Downloads, caches and preprocesses the PDBbind dataset.

    Args:
        pdbbind_subset (str): PDBbind dataset subset name. For details about
            the subsets, please refer to the `official site
            <http://www.pdbbind.org.cn/download/pdbbind_2017_intro.pdf>`_
        preprocessor (BasePreprocessor): Preprocessor.
            It should be chosen based on the network to be trained.
            If it is None, the default `AtomicNumberPreprocessor` is used.
        labels (str or list): List of target labels.
        split (str or BaseSplitter or None): How to split the dataset into
            train, validation and test. If `None`, this function uses the
            splitter recommended by MoleculeNet. Alternatively, you can pass
            an instance of BaseSplitter, or choose from 'random', 'stratified'
            and 'scaffold'.
        return_smiles (bool): If set to ``True``,
            the SMILES array is also returned.
        return_pdb_id (bool): If set to ``True``,
            the PDB ID array is also returned.
        target_index (list or None): target index list to partially extract
            the dataset. If `None` (default), all examples are parsed.
        task_index (int): Target task index in the dataset, used for
            stratification. (Stratified Splitter only)
    Returns (dict):
        Dictionary that contains the dataset, already split into train,
        valid and test subsets, and 1-d numpy arrays with dtype=object
        (string) which are vectors of SMILES and PDB IDs for each example,
        or `None`.

    """
    config = molnet_default_config['pdbbind_smiles']
    labels = labels or config['tasks']
    if isinstance(labels, str):
        labels = [labels, ]

    if preprocessor is None:
        preprocessor = AtomicNumberPreprocessor()

    def postprocess_label(label_list):
        return numpy.asarray(label_list, dtype=numpy.float32)

    parser = CSVFileParser(preprocessor, labels=labels,
                           smiles_col=config['smiles_columns'],
                           postprocess_label=postprocess_label)
    split = config['split'] if split is None else split
    if isinstance(split, str):
        splitter = split_method_dict[split]()
    elif isinstance(split, BaseSplitter):
        splitter = split
    else:
        raise TypeError("split must be None, str or instance of"
                        " BaseSplitter, but got {}".format(type(split)))

    result = parser.parse(get_molnet_filepath('pdbbind_smiles',
                                              pdbbind_subset=pdbbind_subset),
                          return_smiles=return_smiles,
                          return_is_successful=True,
                          target_index=target_index)
    dataset = result['dataset']
    smiles = result['smiles']
    is_successful = result['is_successful']

    if return_pdb_id:
        # Keep only the PDB IDs of successfully parsed examples and reset the
        # index so that the positional indexing below stays aligned with the
        # dataset.
        df = pandas.read_csv(
            get_molnet_filepath('pdbbind_smiles',
                                pdbbind_subset=pdbbind_subset))
        pdb_id = df['id'][is_successful].reset_index(drop=True)
    else:
        pdb_id = None

    train_ind, valid_ind, test_ind = \
        splitter.train_valid_test_split(dataset, time_list=time_list,
                                        smiles_list=smiles,
                                        task_index=task_index,
                                        frac_train=frac_train,
                                        frac_valid=frac_valid,
                                        frac_test=frac_test, **kwargs)
    train = NumpyTupleDataset(*dataset.features[train_ind])
    valid = NumpyTupleDataset(*dataset.features[valid_ind])
    test = NumpyTupleDataset(*dataset.features[test_ind])

    result['dataset'] = (train, valid, test)

    if return_smiles:
        train_smiles = smiles[train_ind]
        valid_smiles = smiles[valid_ind]
        test_smiles = smiles[test_ind]
        result['smiles'] = (train_smiles, valid_smiles, test_smiles)
    else:
        result['smiles'] = None

    if return_pdb_id:
        train_pdb_id = pdb_id[train_ind]
        valid_pdb_id = pdb_id[valid_ind]
        test_pdb_id = pdb_id[test_ind]
        result['pdb_id'] = (train_pdb_id, valid_pdb_id, test_pdb_id)
    else:
        result['pdb_id'] = None
    return result
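
# A minimal usage sketch for `get_pdbbind_smiles` (commented out; 'core' is
# assumed to be one of the configured PDBbind subsets):
#
#     data = get_pdbbind_smiles('core', return_smiles=True,
#                               return_pdb_id=True)
#     train, valid, test = data['dataset']
#     train_pdb_id, valid_pdb_id, test_pdb_id = data['pdb_id']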


def get_pdbbind_grid(pdbbind_subset, split=None, frac_train=.8, frac_valid=.1,
                     frac_test=.1, task_index=0, **kwargs):
    """Downloads, caches and grid-featurizes the PDBbind dataset.

    Args:
        pdbbind_subset (str): PDBbind dataset subset name. For details about
            the subsets, please refer to the `official site
            <http://www.pdbbind.org.cn/download/pdbbind_2017_intro.pdf>`_
        split (str or BaseSplitter or None): How to split the dataset into
            train, validation and test. If `None`, this function uses the
            splitter recommended by MoleculeNet. Alternatively, you can pass
            an instance of BaseSplitter, or choose from 'random', 'stratified'
            and 'scaffold'.
        task_index (int): Target task index in the dataset, used for
            stratification. (Stratified Splitter only)
    Returns (dict):
        Dictionary that contains the dataset, already split into train,
        valid and test subsets. `result['smiles']` is always `None` for the
        grid-featurized dataset.

    """
    result = {}
    dataset = get_grid_featurized_pdbbind_dataset(pdbbind_subset)
    if split is None:
        split = molnet_default_config['pdbbind_grid']['split']
    if isinstance(split, str):
        splitter = split_method_dict[split]()
    elif isinstance(split, BaseSplitter):
        splitter = split
    else:
        raise TypeError("split must be None, str, or instance of"
                        " BaseSplitter, but got {}".format(type(split)))
    time_list = get_pdbbind_time()
    train_ind, valid_ind, test_ind = \
        splitter.train_valid_test_split(dataset, time_list=time_list,
                                        smiles_list=None,
                                        task_index=task_index,
                                        frac_train=frac_train,
                                        frac_valid=frac_valid,
                                        frac_test=frac_test, **kwargs)
    train = NumpyTupleDataset(*dataset.features[train_ind])
    valid = NumpyTupleDataset(*dataset.features[valid_ind])
    test = NumpyTupleDataset(*dataset.features[test_ind])

    result['dataset'] = (train, valid, test)
    result['smiles'] = None
    return result
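
# A minimal usage sketch for `get_pdbbind_grid` (commented out; 'core' is one
# of the supported subsets):
#
#     data = get_pdbbind_grid('core')
#     train, valid, test = data['dataset']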


def get_grid_featurized_pdbbind_dataset(subset):
    """Downloads and caches the grid featurized PDBBind dataset.

    Args:
        subset (str): subset name of the PDBBind dataset.

    Returns (NumpyTupleDataset):
        grid featurized PDBBind dataset.

    """
    x_path, y_path = get_grid_featurized_pdbbind_filepath(subset)
    # 'i' and 'f' are numpy dtype shorthand for C int and C float
    # (typically int32 and float32).
    x = joblib.load(x_path).astype('i')
    y = joblib.load(y_path).astype('f')
    dataset = NumpyTupleDataset(x, y)
    return dataset
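
# A minimal usage sketch for `get_grid_featurized_pdbbind_dataset` (commented
# out; 'refined' is one of the supported subsets):
#
#     dataset = get_grid_featurized_pdbbind_dataset('refined')
#     x0, y0 = dataset[0]  # features and label of the first example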


def get_grid_featurized_pdbbind_dirpath(subset, download_if_not_exist=True):
    """Construct a file path which stores the grid featurized PDBBind dataset.

    This method checks whether the archive exists or not, and downloads it if
    necessary.

    Args:
        subset (str): subset name of the PDBBind dataset.
        download_if_not_exist (bool): Download the file if it does not exist.

    Returns (str): path of the downloaded archive for the specified subset of
        the PDBBind dataset.

    """
    subset_supported = ['core', 'full', 'refined']
    if subset not in subset_supported:
        raise ValueError("subset {} not supported, please choose subset "
                         "from {}".format(subset, subset_supported))
    file_url = molnet_default_config['pdbbind_grid']['url'][subset]
    file_name = file_url.split('/')[-1]
    cache_path = _get_molnet_filepath(file_name)
    if not os.path.exists(cache_path):
        if download_if_not_exist:
            is_successful = download_dataset(file_url,
                                             save_filepath=cache_path)
            if not is_successful:
                logger = getLogger(__name__)
                logger.warning('Download failed.')
    return cache_path


def get_grid_featurized_pdbbind_filepath(subset):
    """Construct file paths which store the featurized PDBBind dataset in joblib

    This method does not check if the files have already been downloaded.

    Args:
        subset (str): subset name of the PDBBind dataset

    Returns:
        x_path (str): file path for the feature vectors
        y_path (str): file path for -logKd/Ki

    """
    archive_path = get_grid_featurized_pdbbind_dirpath(subset=subset)
    savedir = os.path.dirname(archive_path)
    with tarfile.open(archive_path, 'r:gz') as tar:
        tar.extractall(savedir)
    x_path = os.path.join(savedir, subset + '_grid', 'shard-0-X.joblib')
    y_path = os.path.join(savedir, subset + '_grid', 'shard-0-y.joblib')
    return x_path, y_path
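
# A minimal usage sketch for `get_grid_featurized_pdbbind_filepath` (commented
# out; 'core' is one of the supported subsets):
#
#     x_path, y_path = get_grid_featurized_pdbbind_filepath('core')
#     x = joblib.load(x_path)
#     y = joblib.load(y_path)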