1 // Copyright 2020 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "components/optimization_guide/prediction_model.h" 6 7 #include <utility> 8 9 #include "components/optimization_guide/decision_tree_prediction_model.h" 10 11 namespace optimization_guide { 12 13 // static Create(std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model)14std::unique_ptr<PredictionModel> PredictionModel::Create( 15 std::unique_ptr<optimization_guide::proto::PredictionModel> 16 prediction_model) { 17 // TODO(crbug/1009123): Add a histogram to record if the provided model is 18 // constructed successfully or not. 19 // TODO(crbug/1009123): Adding timing metrics around initialization due to 20 // potential validation overhead. 21 if (!prediction_model->has_model()) 22 return nullptr; 23 24 if (!prediction_model->has_model_info()) 25 return nullptr; 26 27 if (!prediction_model->model_info().has_version()) 28 return nullptr; 29 30 // Enforce that only one ModelType is specified for the PredictionModel. 31 if (prediction_model->model_info().supported_model_types_size() != 1) { 32 return nullptr; 33 } 34 35 // Check that the client supports this type of model and is not an unknown 36 // type. 37 if (!optimization_guide::proto::ModelType_IsValid( 38 prediction_model->model_info().supported_model_types(0)) || 39 prediction_model->model_info().supported_model_types(0) == 40 optimization_guide::proto::ModelType::MODEL_TYPE_UNKNOWN) { 41 return nullptr; 42 } 43 44 // Check that the client supports the model features for |prediction model|. 45 for (const auto& model_feature : 46 prediction_model->model_info().supported_model_features()) { 47 if (!optimization_guide::proto::ClientModelFeature_IsValid(model_feature) || 48 model_feature == optimization_guide::proto::ClientModelFeature:: 49 CLIENT_MODEL_FEATURE_UNKNOWN) 50 return nullptr; 51 } 52 53 std::unique_ptr<PredictionModel> model; 54 // The Decision Tree model type is currently the only supported model type. 55 if (prediction_model->model_info().supported_model_types(0) != 56 optimization_guide::proto::ModelType::MODEL_TYPE_DECISION_TREE) { 57 return nullptr; 58 } 59 model = std::make_unique<DecisionTreePredictionModel>( 60 std::move(prediction_model)); 61 62 // Any constructed model must be validated for correctness according to its 63 // model type before being returned. 64 if (!model->ValidatePredictionModel()) 65 return nullptr; 66 67 return model; 68 } 69 PredictionModel(std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model)70PredictionModel::PredictionModel( 71 std::unique_ptr<optimization_guide::proto::PredictionModel> 72 prediction_model) { 73 version_ = prediction_model->model_info().version(); 74 model_features_.reserve( 75 prediction_model->model_info().supported_model_features_size() + 76 prediction_model->model_info().supported_host_model_features_size()); 77 // Insert all the client model features for the owned |model_|. 78 for (const auto& client_model_feature : 79 prediction_model->model_info().supported_model_features()) { 80 model_features_.emplace(optimization_guide::proto::ClientModelFeature_Name( 81 client_model_feature)); 82 } 83 // Insert all the host model features for the owned |model_|. 84 for (const auto& host_model_feature : 85 prediction_model->model_info().supported_host_model_features()) { 86 model_features_.emplace(host_model_feature); 87 } 88 model_ = std::make_unique<optimization_guide::proto::Model>( 89 prediction_model->model()); 90 } 91 GetVersion() const92int64_t PredictionModel::GetVersion() const { 93 SEQUENCE_CHECKER(sequence_checker_); 94 return version_; 95 } 96 GetModelFeatures() const97base::flat_set<std::string> PredictionModel::GetModelFeatures() const { 98 SEQUENCE_CHECKER(sequence_checker_); 99 return model_features_; 100 } 101 102 PredictionModel::~PredictionModel() = default; 103 104 } // namespace optimization_guide 105