1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h"
6
7 #include <utility>
8
9 namespace media {
10 namespace learning {
11
MojoLearningTaskController(const LearningTask & task,mojo::Remote<mojom::LearningTaskController> controller)12 MojoLearningTaskController::MojoLearningTaskController(
13 const LearningTask& task,
14 mojo::Remote<mojom::LearningTaskController> controller)
15 : task_(task), controller_(std::move(controller)) {}
16
17 MojoLearningTaskController::~MojoLearningTaskController() = default;
18
BeginObservation(base::UnguessableToken id,const FeatureVector & features,const base::Optional<TargetValue> & default_target,const base::Optional<ukm::SourceId> & source_id)19 void MojoLearningTaskController::BeginObservation(
20 base::UnguessableToken id,
21 const FeatureVector& features,
22 const base::Optional<TargetValue>& default_target,
23 const base::Optional<ukm::SourceId>& source_id) {
24 // We don't need to keep track of in-flight observations, since the service
25 // side handles it for us. Also note that |source_id| is ignored; the service
26 // has no reason to trust it. It will fill it in for us. DCHECK in case
27 // somebody actually tries to send us one, expecting it to be used.
28 DCHECK(!source_id);
29 controller_->BeginObservation(id, features, default_target);
30 }
31
CompleteObservation(base::UnguessableToken id,const ObservationCompletion & completion)32 void MojoLearningTaskController::CompleteObservation(
33 base::UnguessableToken id,
34 const ObservationCompletion& completion) {
35 controller_->CompleteObservation(id, completion);
36 }
37
CancelObservation(base::UnguessableToken id)38 void MojoLearningTaskController::CancelObservation(base::UnguessableToken id) {
39 controller_->CancelObservation(id);
40 }
41
UpdateDefaultTarget(base::UnguessableToken id,const base::Optional<TargetValue> & default_target)42 void MojoLearningTaskController::UpdateDefaultTarget(
43 base::UnguessableToken id,
44 const base::Optional<TargetValue>& default_target) {
45 controller_->UpdateDefaultTarget(id, default_target);
46 }
47
GetLearningTask()48 const LearningTask& MojoLearningTaskController::GetLearningTask() {
49 return task_;
50 }
51
PredictDistribution(const FeatureVector & features,PredictionCB callback)52 void MojoLearningTaskController::PredictDistribution(
53 const FeatureVector& features,
54 PredictionCB callback) {
55 controller_->PredictDistribution(features, std::move(callback));
56 }
57
58 } // namespace learning
59 } // namespace media
60