1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include <cstring>
#include <vector>

#include <fuzzer/FuzzedDataProvider.h>

#include "base/test/task_environment.h"
#include "media/learning/impl/learning_task_controller_impl.h"
9 
10 using media::learning::FeatureValue;
11 using media::learning::FeatureVector;
12 using media::learning::LearningTask;
13 using ValueDescription = media::learning::LearningTask::ValueDescription;
14 using media::learning::LearningTaskControllerImpl;
15 using media::learning::ObservationCompletion;
16 using media::learning::TargetValue;
17 
ConsumeValueDescription(FuzzedDataProvider * provider)18 ValueDescription ConsumeValueDescription(FuzzedDataProvider* provider) {
19   ValueDescription desc;
20   desc.name = provider->ConsumeRandomLengthString(100);
21   desc.ordering = provider->ConsumeEnum<LearningTask::Ordering>();
22   desc.privacy_mode = provider->ConsumeEnum<LearningTask::PrivacyMode>();
23   return desc;
24 }
25 
ConsumeDouble(FuzzedDataProvider * provider)26 double ConsumeDouble(FuzzedDataProvider* provider) {
27   std::vector<uint8_t> v = provider->ConsumeBytes<uint8_t>(sizeof(double));
28   if (v.size() == sizeof(double))
29     return reinterpret_cast<double*>(v.data())[0];
30 
31   return 0;
32 }
33 
ConsumeFeatureVector(FuzzedDataProvider * provider)34 FeatureVector ConsumeFeatureVector(FuzzedDataProvider* provider) {
35   FeatureVector features;
36   int n = provider->ConsumeIntegralInRange(0, 100);
37   while (n-- > 0)
38     features.push_back(FeatureValue(ConsumeDouble(provider)));
39 
40   return features;
41 }
42 
LLVMFuzzerTestOneInput(const uint8_t * data,size_t size)43 extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
44   base::test::TaskEnvironment task_environment;
45   FuzzedDataProvider provider(data, size);
46 
47   LearningTask task;
48   task.name = provider.ConsumeRandomLengthString(100);
49   task.model = provider.ConsumeEnum<LearningTask::Model>();
50   task.use_one_hot_conversion = provider.ConsumeBool();
51   task.uma_hacky_aggregate_confusion_matrix = provider.ConsumeBool();
52   task.uma_hacky_by_training_weight_confusion_matrix = provider.ConsumeBool();
53   task.uma_hacky_by_feature_subset_confusion_matrix = provider.ConsumeBool();
54   int n_features = provider.ConsumeIntegralInRange(0, 100);
55   int subset_size = provider.ConsumeIntegralInRange<uint8_t>(0, n_features);
56   if (subset_size)
57     task.feature_subset_size = subset_size;
58   for (int i = 0; i < n_features; i++)
59     task.feature_descriptions.push_back(ConsumeValueDescription(&provider));
60   task.target_description = ConsumeValueDescription(&provider);
61 
62   LearningTaskControllerImpl controller(task);
63 
64   // Build random examples.
65   while (provider.remaining_bytes() > 0) {
66     base::UnguessableToken id = base::UnguessableToken::Create();
67     base::Optional<TargetValue> default_target;
68     if (provider.ConsumeBool())
69       default_target = TargetValue(ConsumeDouble(&provider));
70     controller.BeginObservation(id, ConsumeFeatureVector(&provider),
71                                 default_target, base::nullopt);
72     controller.CompleteObservation(
73         id, ObservationCompletion(TargetValue(ConsumeDouble(&provider)),
74                                   ConsumeDouble(&provider)));
75     task_environment.RunUntilIdle();
76   }
77 
78   return 0;
79 }
80