# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for running predictions.

Includes (from the Cloud ML SDK):
- _predict_lib

Important changes:
- Remove interfaces for TensorFlowModel (they don't change behavior).
- Set from_client(skip_preprocessing=True) and remove the pre-processing code.
"""
from . import custom_code_utils
from . import prediction_utils


# --------------------------
# prediction.prediction_lib
# --------------------------
def _unsupported_framework_error(framework):
  """Returns the ValueError to raise when `framework` matches no known name.

  Without an explicit error, the framework dispatch in create_model and
  create_client would fall through every branch and fail later with a
  confusing UnboundLocalError.

  Args:
    framework: The unrecognized framework value supplied by the caller.

  Returns:
    A ValueError describing the bad value and the supported framework names.
  """
  supported = (
      prediction_utils.TENSORFLOW_FRAMEWORK_NAME,
      prediction_utils.SCIKIT_LEARN_FRAMEWORK_NAME,
      prediction_utils.XGBOOST_FRAMEWORK_NAME,
  )
  return ValueError(
      "Unsupported framework {!r}; expected one of: {}.".format(
          framework, ", ".join(repr(name) for name in supported)))


def create_model(client, model_path, framework=None, **unused_kwargs):
  """Creates and returns the appropriate model.

  Creates and returns a Model if no user specified model is
  provided. Otherwise, the user specified model is imported, created, and
  returned.

  Args:
    client: An instance of PredictionClient for performing prediction.
    model_path: The path to the exported model (e.g. session_bundle or
      SavedModel)
    framework: The framework used to train the model. Falls back to the
      TensorFlow framework name when falsy. Not consulted when a user
      specified model is found.
    **unused_kwargs: Accepted for call-site compatibility and ignored.

  Returns:
    An instance of the appropriate model class.

  Raises:
    ValueError: If no user specified model is found and `framework` is not
      one of the supported framework names.
  """
  custom_model = custom_code_utils.create_user_model(model_path, None)
  if custom_model:
    return custom_model

  framework = framework or prediction_utils.TENSORFLOW_FRAMEWORK_NAME

  # Framework modules are imported inside each branch, presumably so that only
  # the selected framework's dependencies need to be importable.
  if framework == prediction_utils.TENSORFLOW_FRAMEWORK_NAME:
    from .frameworks import tf_prediction_lib  # pylint: disable=g-import-not-at-top
    model_cls = tf_prediction_lib.TensorFlowModel
  elif framework == prediction_utils.SCIKIT_LEARN_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    model_cls = sk_xg_prediction_lib.SklearnModel
  elif framework == prediction_utils.XGBOOST_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    model_cls = sk_xg_prediction_lib.XGBoostModel
  else:
    raise _unsupported_framework_error(framework)

  return model_cls(client)


def create_client(framework, model_path, **kwargs):
  """Creates and returns the appropriate prediction client.

  Creates and returns a PredictionClient based on the provided framework.

  Args:
    framework: The framework used to train the model. Falls back to the
      TensorFlow framework name when falsy.
    model_path: The path to the exported model (e.g. session_bundle or
      SavedModel)
    **kwargs: Optional additional params to pass to the client constructor (such
      as TF tags).

  Returns:
    An instance of the appropriate PredictionClient.

  Raises:
    ValueError: If `framework` is not one of the supported framework names.
  """
  framework = framework or prediction_utils.TENSORFLOW_FRAMEWORK_NAME

  # Framework modules are imported inside each branch, presumably so that only
  # the selected framework's dependencies need to be importable.
  if framework == prediction_utils.TENSORFLOW_FRAMEWORK_NAME:
    from .frameworks import tf_prediction_lib  # pylint: disable=g-import-not-at-top
    create_client_fn = tf_prediction_lib.create_tf_session_client
  elif framework == prediction_utils.SCIKIT_LEARN_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    create_client_fn = sk_xg_prediction_lib.create_sklearn_client
  elif framework == prediction_utils.XGBOOST_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    create_client_fn = sk_xg_prediction_lib.create_xgboost_client
  else:
    raise _unsupported_framework_error(framework)

  return create_client_fn(model_path, **kwargs)


def local_predict(model_dir=None, signature_name=None, instances=None,
                  framework=None, **kwargs):
  """Run a prediction locally.

  Builds a prediction client and a model for `model_dir`, optionally
  base64-decodes the instances, and runs the model's predict method.

  Args:
    model_dir: Path to the exported model. Used both to build the client and
      to look up a user specified model.
    signature_name: Passed to model.predict as `signature_name`, and to
      prediction_utils.should_base64_decode.
    instances: The instances to predict on. They are passed through
      prediction_utils.decode_base64 when
      prediction_utils.should_base64_decode returns a truthy value.
    framework: The framework used to train the model. Falls back to the
      TensorFlow framework name when falsy.
    **kwargs: Extra arguments forwarded to the prediction client constructor.

  Returns:
    A dict of the form {"predictions": [...]}, where the value is the model's
    predictions materialized as a list.

  Raises:
    ValueError: If `framework` is not one of the supported framework names.
  """
  framework = framework or prediction_utils.TENSORFLOW_FRAMEWORK_NAME
  client = create_client(framework, model_dir, **kwargs)
  model = create_model(client, model_dir, framework)
  if prediction_utils.should_base64_decode(framework, model, signature_name):
    instances = prediction_utils.decode_base64(instances)
  predictions = model.predict(instances, signature_name=signature_name)
  return {"predictions": list(predictions)}