'''This is a simple script that converts a pickled XGBoost
Scikit-Learn interface object from 0.90 to a native model.  Pickle
format is not stable as it's a direct serialization of Python object.
We advise not to use it when stability is needed.

'''
import pickle
import json
import os
import argparse
import numpy as np
import xgboost
import warnings


def save_label_encoder(le):
    '''Save the label encoder in XGBClassifier.

    Returns a new dict holding every instance attribute of `le`.  NumPy
    arrays are converted to plain lists because the result is later
    embedded in a JSON document; all other values are copied as-is.

    '''
    meta = dict()
    for k, v in le.__dict__.items():
        if isinstance(v, np.ndarray):
            meta[k] = v.tolist()
        else:
            meta[k] = v
    return meta


def xgboost_skl_90to100(skl_model):
    '''Extract the model and related metadata in SKL model.

    The pickled Scikit-Learn wrapper is loaded, its JSON-serializable
    Python attributes are stored on the underlying booster as the
    `scikit_learn` attribute, and the booster is saved as
    `xgboost_native_model_from_<file name>-<i>.bin` in the current
    working directory, where `<i>` is the smallest non-negative integer
    for which no such file exists yet.

    Parameters
    ----------
    skl_model : str
        Path to the pickle file of a Scikit-Learn interface object.

    Raises
    ------
    TypeError
        If the unpickled object is not an `xgboost.XGBModel`.

    '''
    model = {}
    # NOTE: unpickling executes arbitrary code embedded in the file.  Only
    # run this script on pickle files coming from a trusted source.
    with open(skl_model, 'rb') as fd:
        old = pickle.load(fd)
    if not isinstance(old, xgboost.XGBModel):
        raise TypeError(
            'The script only handles Scikit-Learn interface object')

    # Save Scikit-Learn specific Python attributes into a JSON document.
    for k, v in old.__dict__.items():
        if k == '_le':
            model[k] = save_label_encoder(v)
        elif k == 'classes_':
            model[k] = v.tolist()
        elif k == '_Booster':
            # The booster itself is written by `save_model` below.
            continue
        else:
            try:
                # Probe serializability; attributes that JSON cannot
                # represent are dropped with a warning instead of aborting.
                json.dumps({k: v})
                model[k] = v
            except TypeError:
                warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
    booster = old.get_booster()
    # Store the JSON serialization as an attribute
    booster.set_attr(scikit_learn=json.dumps(model))

    # Save it into a native model.  Only the file name of the input is used:
    # embedding a path with directory components into the output name would
    # point into a directory that does not exist.
    prefix = 'xgboost_native_model_from_' + os.path.basename(skl_model)
    i = 0
    path = prefix + '-' + str(i) + '.bin'
    # Never overwrite an existing file; bump the counter until a name is free.
    while os.path.exists(path):
        i += 1
        path = prefix + '-' + str(i) + '.bin'
    booster.save_model(path)


if __name__ == '__main__':
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if xgboost.__version__ == '1.0.0':
        raise RuntimeError('Please use the XGBoost version'
                           ' that generates this pickle.')
    parser = argparse.ArgumentParser(
        description=('A simple script to convert pickle generated by'
                     ' XGBoost 0.90 to XGBoost 1.0.0 model (not pickle).'))
    parser.add_argument(
        '--old-pickle',
        type=str,
        help='Path to old pickle file of Scikit-Learn interface object. '
        'Will output a native model converted from this pickle file',
        required=True)
    args = parser.parse_args()

    xgboost_skl_90to100(args.old_pickle)