1 // Copyright 2020 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "chrome/browser/optimization_guide/prediction/remote_decision_tree_predictor.h" 6 7 #include <string> 8 9 #include "base/containers/flat_set.h" 10 11 namespace optimization_guide { 12 RemoteDecisionTreePredictor(const proto::PredictionModel & model)13RemoteDecisionTreePredictor::RemoteDecisionTreePredictor( 14 const proto::PredictionModel& model) { 15 // The Decision Tree model type is currently the only supported model type. 16 DCHECK(model.model_info().supported_model_types(0) == 17 optimization_guide::proto::ModelType::MODEL_TYPE_DECISION_TREE); 18 19 version_ = model.model_info().version(); 20 model_features_.reserve( 21 model.model_info().supported_model_features_size() + 22 model.model_info().supported_host_model_features_size()); 23 // Insert all the client model features for the owned |model_|. 24 for (const auto& client_model_feature : 25 model.model_info().supported_model_features()) { 26 model_features_.emplace( 27 proto::ClientModelFeature_Name(client_model_feature)); 28 } 29 // Insert all the host model features for the owned |model_|. 30 for (const auto& host_model_feature : 31 model.model_info().supported_host_model_features()) { 32 model_features_.emplace(host_model_feature); 33 } 34 } 35 36 RemoteDecisionTreePredictor::~RemoteDecisionTreePredictor() = default; 37 38 machine_learning::mojom::DecisionTreePredictorProxy* Get() const39RemoteDecisionTreePredictor::Get() const { 40 if (!remote_) 41 return nullptr; 42 43 return remote_.get(); 44 } 45 IsConnected() const46bool RemoteDecisionTreePredictor::IsConnected() const { 47 return remote_.is_connected(); 48 } 49 FlushForTesting()50void RemoteDecisionTreePredictor::FlushForTesting() { 51 remote_.FlushForTesting(); 52 } 53 54 mojo::PendingReceiver<machine_learning::mojom::DecisionTreePredictor> BindNewPipeAndPassReceiver()55RemoteDecisionTreePredictor::BindNewPipeAndPassReceiver() { 56 return remote_.BindNewPipeAndPassReceiver(); 57 } 58 model_features() const59const base::flat_set<std::string>& RemoteDecisionTreePredictor::model_features() 60 const { 61 return model_features_; 62 } 63 version() const64int64_t RemoteDecisionTreePredictor::version() const { 65 return version_; 66 } 67 68 } // namespace optimization_guide 69