1# Copyright 2018 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#    http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Utilities for running predictions.
15
16Includes (from the Cloud ML SDK):
17- _predict_lib
18
19Important changes:
20- Remove interfaces for TensorFlowModel (they don't change behavior).
21- Set from_client(skip_preprocessing=True) and remove the pre-processing code.
22"""
23from . import custom_code_utils
24from . import prediction_utils
25
26
27# --------------------------
28# prediction.prediction_lib
29# --------------------------
def create_model(client, model_path, framework=None, **unused_kwargs):
  """Creates and returns the appropriate model.

  A user-specified (custom) model takes precedence: if
  custom_code_utils.create_user_model returns a truthy object for
  `model_path`, that object is returned as-is and `framework` is not
  consulted. Otherwise a built-in model class is selected by `framework` and
  instantiated with `client`.

  Args:
    client: An instance of PredictionClient for performing prediction.
    model_path: The path to the exported model (e.g. session_bundle or
      SavedModel)
    framework: The framework used to train the model. A falsy value (e.g.
      None) selects prediction_utils.TENSORFLOW_FRAMEWORK_NAME.
    **unused_kwargs: Accepted for call-site compatibility and ignored.

  Returns:
    An instance of the appropriate model class.

  Raises:
    ValueError: If no custom model is found and `framework` is not one of the
      supported framework names.
  """
  custom_model = custom_code_utils.create_user_model(model_path, None)
  if custom_model:
    return custom_model

  framework = framework or prediction_utils.TENSORFLOW_FRAMEWORK_NAME

  # Framework libraries are imported inside the matching branch so that only
  # the selected framework's module gets loaded.
  if framework == prediction_utils.TENSORFLOW_FRAMEWORK_NAME:
    from .frameworks import tf_prediction_lib  # pylint: disable=g-import-not-at-top
    model_cls = tf_prediction_lib.TensorFlowModel
  elif framework == prediction_utils.SCIKIT_LEARN_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    model_cls = sk_xg_prediction_lib.SklearnModel
  elif framework == prediction_utils.XGBOOST_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    model_cls = sk_xg_prediction_lib.XGBoostModel
  else:
    # Without this branch `model_cls` would be unbound and the return below
    # would fail with an uninformative UnboundLocalError.
    raise ValueError(
        "Unsupported framework for model creation: {!r}".format(framework))

  return model_cls(client)
63
64
def create_client(framework, model_path, **kwargs):
  """Creates and returns the appropriate prediction client.

  Creates and returns a PredictionClient based on the provided framework.

  Args:
    framework: The framework used to train the model. A falsy value (e.g.
      None) selects prediction_utils.TENSORFLOW_FRAMEWORK_NAME.
    model_path: The path to the exported model (e.g. session_bundle or
      SavedModel)
    **kwargs: Optional additional params to pass to the client constructor (such
      as TF tags).

  Returns:
    An instance of the appropriate PredictionClient.

  Raises:
    ValueError: If `framework` is not one of the supported framework names.
  """
  framework = framework or prediction_utils.TENSORFLOW_FRAMEWORK_NAME

  # Framework libraries are imported inside the matching branch so that only
  # the selected framework's module gets loaded.
  if framework == prediction_utils.TENSORFLOW_FRAMEWORK_NAME:
    from .frameworks import tf_prediction_lib  # pylint: disable=g-import-not-at-top
    create_client_fn = tf_prediction_lib.create_tf_session_client
  elif framework == prediction_utils.SCIKIT_LEARN_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    create_client_fn = sk_xg_prediction_lib.create_sklearn_client
  elif framework == prediction_utils.XGBOOST_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    create_client_fn = sk_xg_prediction_lib.create_xgboost_client
  else:
    # Without this branch `create_client_fn` would be unbound and the return
    # below would fail with an uninformative UnboundLocalError.
    raise ValueError(
        "Unsupported framework for client creation: {!r}".format(framework))

  return create_client_fn(model_path, **kwargs)
92
93
def local_predict(model_dir=None, signature_name=None, instances=None,
                  framework=None, **kwargs):
  """Runs a prediction locally and returns the results in a dict.

  Builds a prediction client and a model for `model_dir`, base64-decodes the
  instances when prediction_utils.should_base64_decode says so, and then asks
  the model to predict.

  Args:
    model_dir: Path to the exported model; passed to both create_client and
      create_model.
    signature_name: Forwarded to should_base64_decode and to model.predict.
    instances: The inputs to predict on.
    framework: The framework used to train the model. A falsy value selects
      prediction_utils.TENSORFLOW_FRAMEWORK_NAME.
    **kwargs: Extra keyword arguments forwarded to create_client.

  Returns:
    A dict with a single key, "predictions", mapping to a list of the model's
    predictions.
  """
  if not framework:
    framework = prediction_utils.TENSORFLOW_FRAMEWORK_NAME

  # The client must exist before the model, which is constructed around it.
  prediction_client = create_client(framework, model_dir, **kwargs)
  predictor = create_model(prediction_client, model_dir, framework)

  needs_decoding = prediction_utils.should_base64_decode(
      framework, predictor, signature_name)
  payload = (
      prediction_utils.decode_base64(instances) if needs_decoding
      else instances)

  results = predictor.predict(payload, signature_name=signature_name)
  return {"predictions": list(results)}
104