1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/browser/optimization_guide/prediction/remote_decision_tree_predictor.h"
6 
#include <string>
#include <utility>
#include <vector>

#include "base/containers/flat_set.h"
10 
11 namespace optimization_guide {
12 
RemoteDecisionTreePredictor(const proto::PredictionModel & model)13 RemoteDecisionTreePredictor::RemoteDecisionTreePredictor(
14     const proto::PredictionModel& model) {
15   // The Decision Tree model type is currently the only supported model type.
16   DCHECK(model.model_info().supported_model_types(0) ==
17          optimization_guide::proto::ModelType::MODEL_TYPE_DECISION_TREE);
18 
19   version_ = model.model_info().version();
20   model_features_.reserve(
21       model.model_info().supported_model_features_size() +
22       model.model_info().supported_host_model_features_size());
23   // Insert all the client model features for the owned |model_|.
24   for (const auto& client_model_feature :
25        model.model_info().supported_model_features()) {
26     model_features_.emplace(
27         proto::ClientModelFeature_Name(client_model_feature));
28   }
29   // Insert all the host model features for the owned |model_|.
30   for (const auto& host_model_feature :
31        model.model_info().supported_host_model_features()) {
32     model_features_.emplace(host_model_feature);
33   }
34 }
35 
// Defaulted out of line in the .cc file rather than in the header; member
// destruction (including |remote_|) is left entirely to the compiler.
RemoteDecisionTreePredictor::~RemoteDecisionTreePredictor() = default;
37 
38 machine_learning::mojom::DecisionTreePredictorProxy*
Get() const39 RemoteDecisionTreePredictor::Get() const {
40   if (!remote_)
41     return nullptr;
42 
43   return remote_.get();
44 }
45 
IsConnected() const46 bool RemoteDecisionTreePredictor::IsConnected() const {
47   return remote_.is_connected();
48 }
49 
FlushForTesting()50 void RemoteDecisionTreePredictor::FlushForTesting() {
51   remote_.FlushForTesting();
52 }
53 
54 mojo::PendingReceiver<machine_learning::mojom::DecisionTreePredictor>
BindNewPipeAndPassReceiver()55 RemoteDecisionTreePredictor::BindNewPipeAndPassReceiver() {
56   return remote_.BindNewPipeAndPassReceiver();
57 }
58 
model_features() const59 const base::flat_set<std::string>& RemoteDecisionTreePredictor::model_features()
60     const {
61   return model_features_;
62 }
63 
version() const64 int64_t RemoteDecisionTreePredictor::version() const {
65   return version_;
66 }
67 
68 }  // namespace optimization_guide
69