'''This is a simple script that converts a pickled XGBoost
Scikit-Learn interface object from 0.90 to a native model.  Pickle
format is not stable as it's a direct serialization of Python objects.
We advise against using it when stability is needed.
5
6'''
7import pickle
8import json
9import os
10import argparse
11import numpy as np
12import xgboost
13import warnings
14
15
def save_label_encoder(le):
    '''Return the attributes of the XGBClassifier label encoder as a dict.

    NumPy arrays are turned into plain Python lists (via ``tolist``) so
    they can later be written out with ``json``; every other attribute
    value is copied unchanged.  Keys keep the order of ``le.__dict__``.
    '''
    def _plain(value):
        # Arrays become (nested) lists; anything else passes through as-is.
        return value.tolist() if isinstance(value, np.ndarray) else value

    return {name: _plain(value) for name, value in le.__dict__.items()}
25
26
def _collect_sklearn_meta(old):
    '''Return the JSON-serializable Scikit-Learn attributes of ``old``.

    ``_le`` goes through ``save_label_encoder``, ``classes_`` is converted
    to a list, ``_Booster`` is skipped (it is saved natively), and any
    other attribute that ``json.dumps`` rejects is dropped with a warning.
    '''
    meta = {}
    for k, v in old.__dict__.items():
        if k == '_Booster':
            # The booster itself is written as the native model, not as meta.
            continue
        if k == '_le':
            meta[k] = save_label_encoder(v)
        elif k == 'classes_':
            meta[k] = v.tolist() if isinstance(v, np.ndarray) else v
        else:
            try:
                json.dumps({k: v})
            except (TypeError, ValueError):
                # TypeError: unsupported type; ValueError: circular reference.
                warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
            else:
                meta[k] = v
    return meta


def _unique_output_path(skl_model):
    '''Return the first non-existing output path derived from ``skl_model``.

    The file is named ``xgboost_native_model_from_<name>-<i>.bin`` and is
    placed in the same directory as ``skl_model``.  Only the base name is
    used in the prefix.  Concatenating the whole path would point into a
    non-existent ``xgboost_native_model_from_<dir>/`` directory.
    '''
    directory, filename = os.path.split(skl_model)
    i = 0
    while True:
        path = os.path.join(
            directory,
            'xgboost_native_model_from_' + filename + '-' + str(i) + '.bin')
        if not os.path.exists(path):
            return path
        i += 1


def xgboost_skl_90to100(skl_model):
    '''Extract the model and related metadata in SKL model.

    Loads the pickled Scikit-Learn interface object at ``skl_model``. Its
    Scikit-Learn attributes are stored as a JSON string in the booster
    attribute ``scikit_learn``. The booster is then saved next to the input
    file under a name that does not overwrite an existing file.

    Parameters
    ----------
    skl_model : str
        Path to the pickled ``xgboost.XGBModel`` (or subclass) object.

    Returns
    -------
    str
        Path of the written native model.

    Raises
    ------
    TypeError
        If the unpickled object is not an ``xgboost.XGBModel``.
    '''
    # NOTE: pickle.load can execute arbitrary code; only convert trusted files.
    with open(skl_model, 'rb') as fd:
        old = pickle.load(fd)
    if not isinstance(old, xgboost.XGBModel):
        raise TypeError(
            'The script only handles Scikit-Learn interface object')

    # Save Scikit-Learn specific Python attributes into a JSON document.
    model = _collect_sklearn_meta(old)
    booster = old.get_booster()
    # Store the JSON serialization as an attribute
    booster.set_attr(scikit_learn=json.dumps(model))

    # Save it into a native model.
    path = _unique_output_path(skl_model)
    booster.save_model(path)
    return path
63
64
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description=('A simple script to convert pickle generated by'
                     ' XGBoost 0.90 to XGBoost 1.0.0 model (not pickle).'))
    parser.add_argument(
        '--old-pickle',
        type=str,
        help='Path to old pickle file of Scikit-Learn interface object.  '
        'Will output a native model converted from this pickle file',
        required=True)
    # Parse first so that ``--help`` works regardless of installed version.
    args = parser.parse_args()

    # The conversion must run under the XGBoost that wrote the pickle.  The
    # previous guard rejected only the exact string '1.0.0'.  Later releases
    # are presumably unable to load 0.90 pickles either (NOTE(review):
    # confirm), so reject every major version >= 1.  Use an explicit exit
    # rather than ``assert``, which is stripped under ``python -O``.
    if int(xgboost.__version__.split('.')[0]) >= 1:
        raise SystemExit('Please use the XGBoost version'
                         ' that generates this pickle.')

    xgboost_skl_90to100(args.old_pickle)
80