1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/browser/chromeos/policy/auto_enrollment_client_impl.h"
6 
7 #include <stdint.h>
8 
9 #include "base/bind.h"
10 #include "base/guid.h"
11 #include "base/location.h"
12 #include "base/logging.h"
13 #include "base/memory/ptr_util.h"
14 #include "base/metrics/histogram_functions.h"
15 #include "base/metrics/histogram_macros.h"
16 #include "base/optional.h"
17 #include "base/strings/string_number_conversions.h"
18 #include "base/threading/thread_task_runner_handle.h"
19 #include "chrome/browser/chromeos/login/enrollment/auto_enrollment_controller.h"
20 #include "chrome/browser/chromeos/policy/server_backed_device_state.h"
21 #include "chrome/common/chrome_content_client.h"
22 #include "chrome/common/pref_names.h"
23 #include "components/policy/core/common/cloud/device_management_service.h"
24 #include "components/policy/core/common/cloud/dm_auth.h"
25 #include "components/policy/core/common/cloud/dmserver_job_configurations.h"
26 #include "components/policy/core/common/cloud/enterprise_metrics.h"
27 #include "components/policy/proto/device_management_backend.pb.h"
28 #include "components/prefs/pref_registry_simple.h"
29 #include "components/prefs/pref_service.h"
30 #include "components/prefs/scoped_user_pref_update.h"
31 #include "content/public/browser/browser_thread.h"
32 #include "content/public/browser/network_service_instance.h"
33 #include "crypto/sha2.h"
34 #include "services/network/public/cpp/shared_url_loader_factory.h"
35 #include "third_party/private_membership/src/private_membership_rlwe_client.h"
36 #include "url/gurl.h"
37 
38 using content::BrowserThread;
39 
40 namespace psm_rlwe = private_membership::rlwe;
41 namespace em = enterprise_management;
42 
43 namespace policy {
44 
45 namespace {
46 
47 using EnrollmentCheckType =
48     em::DeviceAutoEnrollmentRequest::EnrollmentCheckType;
49 
50 // Timeout for running private set membership protocol.
51 constexpr base::TimeDelta kPrivateSetMembershipTimeout =
52     base::TimeDelta::FromSeconds(15);
53 
54 // Returns the power of the next power-of-2 starting at |value|.
NextPowerOf2(int64_t value)55 int NextPowerOf2(int64_t value) {
56   for (int i = 0; i <= AutoEnrollmentClient::kMaximumPower; ++i) {
57     if ((INT64_C(1) << i) >= value)
58       return i;
59   }
60   // No other value can be represented in an int64_t.
61   return AutoEnrollmentClient::kMaximumPower + 1;
62 }
63 
64 // Sets or clears a value in a dictionary.
UpdateDict(base::DictionaryValue * dict,const char * pref_path,bool set_or_clear,std::unique_ptr<base::Value> value)65 void UpdateDict(base::DictionaryValue* dict,
66                 const char* pref_path,
67                 bool set_or_clear,
68                 std::unique_ptr<base::Value> value) {
69   if (set_or_clear)
70     dict->Set(pref_path, std::move(value));
71   else
72     dict->Remove(pref_path, NULL);
73 }
74 
75 // Converts a restore mode enum value from the DM protocol for FRE into the
76 // corresponding prefs string constant.
ConvertRestoreMode(em::DeviceStateRetrievalResponse::RestoreMode restore_mode)77 std::string ConvertRestoreMode(
78     em::DeviceStateRetrievalResponse::RestoreMode restore_mode) {
79   switch (restore_mode) {
80     case em::DeviceStateRetrievalResponse::RESTORE_MODE_NONE:
81       return std::string();
82     case em::DeviceStateRetrievalResponse::RESTORE_MODE_REENROLLMENT_REQUESTED:
83       return kDeviceStateRestoreModeReEnrollmentRequested;
84     case em::DeviceStateRetrievalResponse::RESTORE_MODE_REENROLLMENT_ENFORCED:
85       return kDeviceStateRestoreModeReEnrollmentEnforced;
86     case em::DeviceStateRetrievalResponse::RESTORE_MODE_DISABLED:
87       return kDeviceStateModeDisabled;
88     case em::DeviceStateRetrievalResponse::RESTORE_MODE_REENROLLMENT_ZERO_TOUCH:
89       return kDeviceStateRestoreModeReEnrollmentZeroTouch;
90   }
91 
92   // Return is required to avoid compiler warning.
93   NOTREACHED() << "Bad restore_mode=" << restore_mode << ".";
94   return std::string();
95 }
96 
97 // Converts an initial enrollment mode enum value from the DM protocol for
98 // initial enrollment into the corresponding prefs string constant.
ConvertInitialEnrollmentMode(em::DeviceInitialEnrollmentStateResponse::InitialEnrollmentMode initial_enrollment_mode)99 std::string ConvertInitialEnrollmentMode(
100     em::DeviceInitialEnrollmentStateResponse::InitialEnrollmentMode
101         initial_enrollment_mode) {
102   switch (initial_enrollment_mode) {
103     case em::DeviceInitialEnrollmentStateResponse::INITIAL_ENROLLMENT_MODE_NONE:
104       return std::string();
105     case em::DeviceInitialEnrollmentStateResponse::
106         INITIAL_ENROLLMENT_MODE_ENROLLMENT_ENFORCED:
107       return kDeviceStateInitialModeEnrollmentEnforced;
108     case em::DeviceInitialEnrollmentStateResponse::
109         INITIAL_ENROLLMENT_MODE_ZERO_TOUCH_ENFORCED:
110       return kDeviceStateInitialModeEnrollmentZeroTouch;
111     case em::DeviceInitialEnrollmentStateResponse::
112         INITIAL_ENROLLMENT_MODE_DISABLED:
113       return kDeviceStateModeDisabled;
114   }
115 }
116 
117 }  // namespace
118 
ConstructDeviceRlweId(const std::string & device_serial_number,const std::string & device_rlz_brand_code)119 psm_rlwe::RlwePlaintextId ConstructDeviceRlweId(
120     const std::string& device_serial_number,
121     const std::string& device_rlz_brand_code) {
122   psm_rlwe::RlwePlaintextId rlwe_id;
123 
124   std::string rlz_brand_code_hex = base::HexEncode(
125       device_rlz_brand_code.data(), device_rlz_brand_code.size());
126 
127   rlwe_id.set_sensitive_id(rlz_brand_code_hex + "/" + device_serial_number);
128   return rlwe_id;
129 }
130 
131 // Subclasses of this class provide an identifier and specify the identifier
132 // set for the DeviceAutoEnrollmentRequest,
133 class AutoEnrollmentClientImpl::DeviceIdentifierProvider {
134  public:
~DeviceIdentifierProvider()135   virtual ~DeviceIdentifierProvider() {}
136 
137   // Should return the EnrollmentCheckType to be used in the
138   // DeviceAutoEnrollmentRequest. This specifies the identifier set used on
139   // the server.
140   virtual enterprise_management::DeviceAutoEnrollmentRequest::
141       EnrollmentCheckType
142       GetEnrollmentCheckType() const = 0;
143 
144   // Should return the hash of this device's identifier. The
145   // DeviceAutoEnrollmentRequest exchange will check if this hash is in the
146   // server-side identifier set specified by |GetEnrollmentCheckType()|
147   virtual const std::string& GetIdHash() const = 0;
148 };
149 
150 // Subclasses of this class generate the request to download the device state
151 // (after determining that there is server-side device state) and parse the
152 // response.
153 class AutoEnrollmentClientImpl::StateDownloadMessageProcessor {
154  public:
~StateDownloadMessageProcessor()155   virtual ~StateDownloadMessageProcessor() {}
156 
157   // Parsed fields of DeviceManagementResponse.
158   struct ParsedResponse {
159     std::string restore_mode;
160     base::Optional<std::string> management_domain;
161     base::Optional<std::string> disabled_message;
162     base::Optional<bool> is_license_packaged_with_device;
163   };
164 
165   // Returns the request job type. This must match the request filled in
166   // |FillRequest|.
167   virtual DeviceManagementService::JobConfiguration::JobType GetJobType()
168       const = 0;
169 
170   // Fills the specific request type in |request|.
171   virtual void FillRequest(
172       enterprise_management::DeviceManagementRequest* request) = 0;
173 
174   // Parses the |response|. If it is valid, returns a ParsedResponse struct
175   // instance. If it is invalid, returns nullopt.
176   virtual base::Optional<ParsedResponse> ParseResponse(
177       const enterprise_management::DeviceManagementResponse& response) = 0;
178 };
179 
180 class PrivateSetMembershipHelper {
181  public:
182   // Callback will be triggered after completing the protocol, in case of a
183   // successful determination or stopping due to an error. Also, the bool result
184   // is ignored.
185   using CompletionCallback = base::OnceCallback<bool()>;
186 
187   // The PrivateSetMembershipHelper doesn't take ownership of
188   // |device_management_service| and |local_state|. Also, both must not be
189   // nullptr. The |device_management_service| and |local_state| must outlive
190   // PrivateSetMembershipHelper.
PrivateSetMembershipHelper(DeviceManagementService * device_management_service,scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,PrefService * local_state,psm_rlwe::RlwePlaintextId psm_rlwe_id)191   PrivateSetMembershipHelper(
192       DeviceManagementService* device_management_service,
193       scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
194       PrefService* local_state,
195       psm_rlwe::RlwePlaintextId psm_rlwe_id)
196       : random_device_id_(base::GenerateGUID()),
197         url_loader_factory_(url_loader_factory),
198         device_management_service_(device_management_service),
199         local_state_(local_state),
200         psm_rlwe_id_(std::move(psm_rlwe_id)) {
201     CHECK(device_management_service);
202     DCHECK(local_state_);
203 
204     // Create PSM client for |psm_rlwe_id_| with use case as CROS_DEVICE_STATE.
205     std::vector<psm_rlwe::RlwePlaintextId> psm_ids = {psm_rlwe_id_};
206     auto status_or_client = psm_rlwe::PrivateMembershipRlweClient::Create(
207         psm_rlwe::RlweUseCase::CROS_DEVICE_STATE, psm_ids);
208     if (!status_or_client.ok()) {
209       // If the private set membership RLWE client hasn't been created
210       // successfully, then report the error and don't run the protocol.
211       LOG(ERROR)
212           << "PSM error: unexpected internal logic error during creating "
213              "PSM RLWE client";
214       has_private_set_membership_error_ = true;
215       return;
216     }
217 
218     private_set_membership_rlwe_client_ = std::move(status_or_client).value();
219   }
220 
221   // Disallow copy constructor and assignment operator.
222   PrivateSetMembershipHelper(const PrivateSetMembershipHelper&) = delete;
223   PrivateSetMembershipHelper& operator=(const PrivateSetMembershipHelper&) =
224       delete;
225 
226   // Cancels the ongoing private set membership operation, if any (without
227   // calling the operation's callbacks).
~PrivateSetMembershipHelper()228   ~PrivateSetMembershipHelper() {
229     DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
230   }
231 
232   // Determines the private set membership for the |psm_rlwe_id_|. Then, will
233   // call |callback| upon completing the protocol, whether it finished with a
234   // successful determination or stopped in case of errors. Also, the |callback|
235   // has to be non-null. In case a request is already in progress, the callback
236   // is called immediately.
CheckMembership(CompletionCallback callback)237   void CheckMembership(CompletionCallback callback) {
238     DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
239     DCHECK(callback);
240 
241     // Ignore new calls and execute their completion |callback|, if any error
242     // occurred while running private set membership previously, or in case the
243     // requests from previous call didn't finish yet.
244     if (has_private_set_membership_error_ || psm_request_job_) {
245       std::move(callback).Run();
246       return;
247     }
248 
249     // Report the psm attempt and start the timer to measure successful private
250     // set membership requests.
251     base::UmaHistogramEnumeration(kUMAPrivateSetMembershipRequestStatus,
252                                   PrivateSetMembershipStatus::kAttempt);
253     time_start_ = base::TimeTicks::Now();
254 
255     on_completion_callback_ = std::move(callback);
256 
257     // Start the protocol and its timeout timer.
258     private_set_membership_timeout_.Start(
259         FROM_HERE, kPrivateSetMembershipTimeout,
260         base::BindOnce(&PrivateSetMembershipHelper::OnTimeout,
261                        base::Unretained(this)));
262     SendPrivateSetMembershipRlweOprfRequest();
263   }
264 
265   // Sets the |private_set_membership_rlwe_client_| and |psm_rlwe_id_| for
266   // testing.
SetRlweClientAndIdForTesting(std::unique_ptr<psm_rlwe::PrivateMembershipRlweClient> private_set_membership_rlwe_client,psm_rlwe::RlwePlaintextId psm_rlwe_id)267   void SetRlweClientAndIdForTesting(
268       std::unique_ptr<psm_rlwe::PrivateMembershipRlweClient>
269           private_set_membership_rlwe_client,
270       psm_rlwe::RlwePlaintextId psm_rlwe_id) {
271     private_set_membership_rlwe_client_ =
272         std::move(private_set_membership_rlwe_client);
273     psm_rlwe_id_ = std::move(psm_rlwe_id);
274   }
275 
276   // Tries to load the result of a previous execution of the private set
277   // memberhsip protocol from local state. Returns decision value if it has been
278   // made and is valid, otherwise nullopt.
GetPrivateSetMembershipCachedDecision() const279   base::Optional<bool> GetPrivateSetMembershipCachedDecision() const {
280     const PrefService::Preference* has_psm_server_state_pref =
281         local_state_->FindPreference(prefs::kShouldRetrieveDeviceState);
282 
283     if (!has_psm_server_state_pref ||
284         has_psm_server_state_pref->IsDefaultValue() ||
285         !has_psm_server_state_pref->GetValue()->is_bool()) {
286       return base::nullopt;
287     }
288 
289     return has_psm_server_state_pref->GetValue()->GetBool();
290   }
291 
292   // Indicate whether an error occurred while executing the private set
293   // membership protocol.
HasPrivateSetMembershipError() const294   bool HasPrivateSetMembershipError() const {
295     DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
296     return has_private_set_membership_error_;
297   }
298 
299   // Returns true if the private set membership protocol is still running,
300   // otherwise false.
IsCheckMembershipInProgress() const301   bool IsCheckMembershipInProgress() const {
302     DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
303     return psm_request_job_ != nullptr;
304   }
305 
306  private:
OnTimeout()307   void OnTimeout() {
308     base::UmaHistogramEnumeration(kUMAPrivateSetMembershipRequestStatus,
309                                   PrivateSetMembershipStatus::kTimeout);
310     StoreErrorAndStop();
311   }
312 
StoreErrorAndStop()313   void StoreErrorAndStop() {
314     // Record the error. Note that a timeout is also recorded as error.
315     base::UmaHistogramEnumeration(kUMAPrivateSetMembershipRequestStatus,
316                                   PrivateSetMembershipStatus::kError);
317 
318     // Stop the private set membership timer.
319     private_set_membership_timeout_.Stop();
320 
321     // Stop the current |psm_request_job_|.
322     psm_request_job_.reset();
323 
324     has_private_set_membership_error_ = true;
325     std::move(on_completion_callback_).Run();
326   }
327 
328   // Constructs and sends the private set membership RLWE OPRF request.
SendPrivateSetMembershipRlweOprfRequest()329   void SendPrivateSetMembershipRlweOprfRequest() {
330     DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
331 
332     // Create RLWE OPRF request.
333     const auto status_or_oprf_request =
334         private_set_membership_rlwe_client_->CreateOprfRequest();
335     if (!status_or_oprf_request.ok()) {
336       // If the RLWE OPRF request hasn't been created successfully, then report
337       // the error and stop the protocol.
338       LOG(ERROR)
339           << "PSM error: unexpected internal logic error during creating "
340              "RLWE OPRF request";
341       StoreErrorAndStop();
342       return;
343     }
344 
345     LOG(WARNING) << "PSM: prepare and send out the RLWE OPRF request";
346 
347     // Prepare the RLWE OPRF request job.
348     // The passed callback will not be called if |psm_request_job_| is
349     // destroyed, so it's safe to use base::Unretained.
350     std::unique_ptr<DMServerJobConfiguration> config =
351         CreatePsmRequestJobConfiguration(base::BindOnce(
352             &PrivateSetMembershipHelper::OnRlweOprfRequestCompletion,
353             base::Unretained(this)));
354 
355     em::DeviceManagementRequest* request = config->request();
356     em::PrivateSetMembershipRlweRequest* psm_rlwe_request =
357         request->mutable_private_set_membership_request()
358             ->mutable_rlwe_request();
359 
360     *psm_rlwe_request->mutable_oprf_request() = status_or_oprf_request.value();
361     psm_request_job_ = device_management_service_->CreateJob(std::move(config));
362   }
363 
364   // If the completion was successful, then it makes another request to
365   // DMServer for performing phase two.
OnRlweOprfRequestCompletion(DeviceManagementService::Job * job,DeviceManagementStatus status,int net_error,const em::DeviceManagementResponse & response)366   void OnRlweOprfRequestCompletion(
367       DeviceManagementService::Job* job,
368       DeviceManagementStatus status,
369       int net_error,
370       const em::DeviceManagementResponse& response) {
371     DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
372 
373     switch (status) {
374       case DM_STATUS_SUCCESS: {
375         // Check if the RLWE OPRF response is empty.
376         if (!response.private_set_membership_response().has_rlwe_response() ||
377             !response.private_set_membership_response()
378                  .rlwe_response()
379                  .has_oprf_response()) {
380           LOG(ERROR) << "PSM error: empty OPRF RLWE response";
381           StoreErrorAndStop();
382           return;
383         }
384 
385         LOG(WARNING) << "PSM RLWE OPRF request completed successfully";
386         SendPrivateSetMembershipRlweQueryRequest(
387             response.private_set_membership_response());
388         return;
389       }
390       case DM_STATUS_REQUEST_FAILED: {
391         LOG(ERROR)
392             << "PSM error: RLWE OPRF request failed due to connection error";
393         StoreErrorAndStop();
394         return;
395       }
396       default: {
397         LOG(ERROR) << "PSM error: RLWE OPRF request failed due to server error";
398         StoreErrorAndStop();
399         return;
400       }
401     }
402   }
403 
404   // Constructs and sends the private set membership RLWE Query request.
SendPrivateSetMembershipRlweQueryRequest(const em::PrivateSetMembershipResponse & private_set_membership_response)405   void SendPrivateSetMembershipRlweQueryRequest(
406       const em::PrivateSetMembershipResponse& private_set_membership_response) {
407     // Extract the oprf_response from |private_set_membership_response|.
408     const psm_rlwe::PrivateMembershipRlweOprfResponse oprf_response =
409         private_set_membership_response.rlwe_response().oprf_response();
410 
411     const auto status_or_query_request =
412         private_set_membership_rlwe_client_->CreateQueryRequest(oprf_response);
413 
414     // Create RLWE query request.
415     if (!status_or_query_request.ok()) {
416       // If the RLWE query request hasn't been created successfully, then report
417       // the error and stop the protocol.
418       LOG(ERROR)
419           << "PSM error: unexpected internal logic error during creating "
420              "RLWE query request";
421       StoreErrorAndStop();
422       return;
423     }
424 
425     LOG(WARNING) << "PSM: prepare and send out the RLWE query request";
426 
427     // Prepare the RLWE query request job.
428     std::unique_ptr<DMServerJobConfiguration> config =
429         CreatePsmRequestJobConfiguration(base::BindOnce(
430             &PrivateSetMembershipHelper::OnRlweQueryRequestCompletion,
431             base::Unretained(this), oprf_response));
432 
433     em::DeviceManagementRequest* request = config->request();
434     em::PrivateSetMembershipRlweRequest* psm_rlwe_request =
435         request->mutable_private_set_membership_request()
436             ->mutable_rlwe_request();
437 
438     *psm_rlwe_request->mutable_query_request() =
439         status_or_query_request.value();
440     psm_request_job_ = device_management_service_->CreateJob(std::move(config));
441   }
442 
443   // If the completion was successful, then it will parse the result and call
444   // the |on_completion_callback_| for |psm_id_|.
OnRlweQueryRequestCompletion(const psm_rlwe::PrivateMembershipRlweOprfResponse & oprf_response,DeviceManagementService::Job * job,DeviceManagementStatus status,int net_error,const em::DeviceManagementResponse & response)445   void OnRlweQueryRequestCompletion(
446       const psm_rlwe::PrivateMembershipRlweOprfResponse& oprf_response,
447       DeviceManagementService::Job* job,
448       DeviceManagementStatus status,
449       int net_error,
450       const em::DeviceManagementResponse& response) {
451     DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
452 
453     switch (status) {
454       case DM_STATUS_SUCCESS: {
455         // Check if the RLWE query response is empty.
456         if (!response.private_set_membership_response().has_rlwe_response() ||
457             !response.private_set_membership_response()
458                  .rlwe_response()
459                  .has_query_response()) {
460           LOG(ERROR) << "PSM error: empty query RLWE response";
461           StoreErrorAndStop();
462           return;
463         }
464 
465         const psm_rlwe::PrivateMembershipRlweQueryResponse query_response =
466             response.private_set_membership_response()
467                 .rlwe_response()
468                 .query_response();
469 
470         auto status_or_responses =
471             private_set_membership_rlwe_client_->ProcessResponse(
472                 query_response);
473 
474         if (!status_or_responses.ok()) {
475           // If the RLWE query response hasn't processed successfully, then
476           // report the error and stop the protocol.
477           LOG(ERROR) << "PSM error: unexpected internal logic error during "
478                         "processing the "
479                         "RLWE query response";
480           StoreErrorAndStop();
481           return;
482         }
483 
484         LOG(WARNING) << "PSM query request completed successfully";
485 
486         base::UmaHistogramEnumeration(
487             kUMAPrivateSetMembershipRequestStatus,
488             PrivateSetMembershipStatus::kSuccessfulDetermination);
489         RecordPrivateSetMembershipSuccessTimeHistogram();
490 
491         // The RLWE query response has been processed successfully. Extract
492         // the membership response, and report the result.
493         psm_rlwe::MembershipResponseMap membership_responses_map =
494             std::move(status_or_responses).value();
495         private_membership::MembershipResponse membership_response =
496             membership_responses_map.Get(psm_rlwe_id_);
497 
498         LOG(WARNING) << "PSM determination successful. Identifier "
499                      << (membership_response.is_member() ? "" : "not ")
500                      << "present on the server";
501 
502         // Reset the |psm_request_job_| to allow another call to
503         // CheckMembership.
504         psm_request_job_.reset();
505 
506         // Stop the private set membership timer.
507         private_set_membership_timeout_.Stop();
508 
509         // Cache the decision in local_state, so that it is reused in case
510         // the device reboots before completing OOBE.
511         local_state_->SetBoolean(prefs::kShouldRetrieveDeviceState,
512                                  membership_response.is_member());
513         local_state_->CommitPendingWrite();
514 
515         std::move(on_completion_callback_).Run();
516         return;
517       }
518       case DM_STATUS_REQUEST_FAILED: {
519         LOG(ERROR)
520             << "PSM error: RLWE query request failed due to connection error";
521         StoreErrorAndStop();
522         return;
523       }
524       default: {
525         LOG(ERROR)
526             << "PSM error: RLWE query request failed due to server error";
527         StoreErrorAndStop();
528         return;
529       }
530     }
531   }
532 
533   // Returns a job config that has TYPE_PSM_REQUEST as job type and |callback|
534   // will be executed on completion.
CreatePsmRequestJobConfiguration(DMServerJobConfiguration::Callback callback)535   std::unique_ptr<DMServerJobConfiguration> CreatePsmRequestJobConfiguration(
536       DMServerJobConfiguration::Callback callback) {
537     return std::make_unique<DMServerJobConfiguration>(
538         device_management_service_,
539         DeviceManagementService::JobConfiguration::
540             TYPE_PSM_HAS_DEVICE_STATE_REQUEST,
541         random_device_id_,
542         /*critical=*/true, DMAuth::NoAuth(),
543         /*oauth_token=*/base::nullopt, url_loader_factory_,
544         std::move(callback));
545   }
546 
547   // Record UMA histogram for timing of successful private set membership
548   // request.
RecordPrivateSetMembershipSuccessTimeHistogram()549   void RecordPrivateSetMembershipSuccessTimeHistogram() {
550     // These values determine bucketing of the histogram, they should not be
551     // changed.
552     static const base::TimeDelta kMin = base::TimeDelta::FromMilliseconds(1);
553     static const base::TimeDelta kMax = base::TimeDelta::FromSeconds(25);
554     static const int kBuckets = 50;
555 
556     base::TimeTicks now = base::TimeTicks::Now();
557     if (!time_start_.is_null()) {
558       base::TimeDelta delta = now - time_start_;
559       base::UmaHistogramCustomTimes(kUMAPrivateSetMembershipSuccessTime, delta,
560                                     kMin, kMax, kBuckets);
561     }
562   }
563 
564   // Private Set Membership RLWE client, used for preparing PSM requests and
565   // parsing PSM responses.
566   std::unique_ptr<psm_rlwe::PrivateMembershipRlweClient>
567       private_set_membership_rlwe_client_;
568 
569   // Randomly generated device id for the private set membership requests.
570   std::string random_device_id_;
571 
572   // The loader factory to use to perform private set membership requests.
573   scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
574 
575   // Unowned by PrivateSetMembershipHelper. Its used to communicate with the
576   // device management service.
577   DeviceManagementService* device_management_service_;
578 
579   // Its being used for both private set membership requests e.g. RLWE OPRF
580   // request and RLWE query request.
581   std::unique_ptr<DeviceManagementService::Job> psm_request_job_;
582 
583   // Callback will be triggered upon completing of the protocol.
584   CompletionCallback on_completion_callback_;
585 
586   // PrefService where the private set membership protocol result is cached.
587   PrefService* const local_state_;
588 
589   // Private Set Membership identifier, which is going to be used while
590   // preparing the private set membership requests.
591   psm_rlwe::RlwePlaintextId psm_rlwe_id_;
592 
593   // Indicates whether there was previously any error occurred while running
594   // private set membership protocol.
595   bool has_private_set_membership_error_ = false;
596 
597   // A timer that puts a hard limit on the maximum time to wait for private set
598   // membership protocol.
599   base::OneShotTimer private_set_membership_timeout_;
600 
601   // The time when the private set membership request started.
602   base::TimeTicks time_start_;
603 
604   // A sequence checker to prevent the race condition of having the possibility
605   // of the destructor being called and any of the callbacks.
606   SEQUENCE_CHECKER(sequence_checker_);
607 };
608 
609 namespace {
610 
611 // Provides device identifier for Forced Re-Enrollment (FRE), where the
612 // server-backed state key is used.
613 class DeviceIdentifierProviderFRE
614     : public AutoEnrollmentClientImpl::DeviceIdentifierProvider {
615  public:
DeviceIdentifierProviderFRE(const std::string & server_backed_state_key)616   explicit DeviceIdentifierProviderFRE(
617       const std::string& server_backed_state_key) {
618     CHECK(!server_backed_state_key.empty());
619     server_backed_state_key_hash_ =
620         crypto::SHA256HashString(server_backed_state_key);
621   }
622 
GetEnrollmentCheckType() const623   EnrollmentCheckType GetEnrollmentCheckType() const override {
624     return em::DeviceAutoEnrollmentRequest::ENROLLMENT_CHECK_TYPE_FRE;
625   }
626 
GetIdHash() const627   const std::string& GetIdHash() const override {
628     return server_backed_state_key_hash_;
629   }
630 
631  private:
632   // SHA-256 digest of the stable identifier.
633   std::string server_backed_state_key_hash_;
634 };
635 
636 // Provides device identifier for Forced Initial Enrollment, where the brand
637 // code and serial number is used.
638 class DeviceIdentifierProviderInitialEnrollment
639     : public AutoEnrollmentClientImpl::DeviceIdentifierProvider {
640  public:
DeviceIdentifierProviderInitialEnrollment(const std::string & device_serial_number,const std::string & device_brand_code)641   DeviceIdentifierProviderInitialEnrollment(
642       const std::string& device_serial_number,
643       const std::string& device_brand_code) {
644     CHECK(!device_serial_number.empty());
645     CHECK(!device_brand_code.empty());
646     // The hash for initial enrollment is the first 8 bytes of
647     // SHA256(<brnad_code>_<serial_number>).
648     id_hash_ =
649         crypto::SHA256HashString(device_brand_code + "_" + device_serial_number)
650             .substr(0, 8);
651   }
652 
GetEnrollmentCheckType() const653   EnrollmentCheckType GetEnrollmentCheckType() const override {
654     return em::DeviceAutoEnrollmentRequest::
655         ENROLLMENT_CHECK_TYPE_FORCED_ENROLLMENT;
656   }
657 
GetIdHash() const658   const std::string& GetIdHash() const override { return id_hash_; }
659 
660  private:
661   // 8-byte Hash built from serial number and brand code passed to the
662   // constructor.
663   std::string id_hash_;
664 };
665 
666 // Handles DeviceInitialEnrollmentStateRequest /
667 // DeviceInitialEnrollmentStateResponse for Forced Initial Enrollment.
668 class StateDownloadMessageProcessorInitialEnrollment
669     : public AutoEnrollmentClientImpl::StateDownloadMessageProcessor {
670  public:
StateDownloadMessageProcessorInitialEnrollment(const std::string & device_serial_number,const std::string & device_brand_code)671   StateDownloadMessageProcessorInitialEnrollment(
672       const std::string& device_serial_number,
673       const std::string& device_brand_code)
674       : device_serial_number_(device_serial_number),
675         device_brand_code_(device_brand_code) {}
676 
GetJobType() const677   DeviceManagementService::JobConfiguration::JobType GetJobType()
678       const override {
679     return DeviceManagementService::JobConfiguration::
680         TYPE_INITIAL_ENROLLMENT_STATE_RETRIEVAL;
681   }
682 
FillRequest(em::DeviceManagementRequest * request)683   void FillRequest(em::DeviceManagementRequest* request) override {
684     auto* inner_request =
685         request->mutable_device_initial_enrollment_state_request();
686     inner_request->set_brand_code(device_brand_code_);
687     inner_request->set_serial_number(device_serial_number_);
688   }
689 
ParseResponse(const em::DeviceManagementResponse & response)690   base::Optional<ParsedResponse> ParseResponse(
691       const em::DeviceManagementResponse& response) override {
692     if (!response.has_device_initial_enrollment_state_response()) {
693       LOG(ERROR) << "Server failed to provide initial enrollment response.";
694       return base::nullopt;
695     }
696 
697     return ParseInitialEnrollmentStateResponse(
698         response.device_initial_enrollment_state_response());
699   }
700 
ParseInitialEnrollmentStateResponse(const em::DeviceInitialEnrollmentStateResponse & state_response)701   static base::Optional<ParsedResponse> ParseInitialEnrollmentStateResponse(
702       const em::DeviceInitialEnrollmentStateResponse& state_response) {
703     StateDownloadMessageProcessor::ParsedResponse parsed_response;
704 
705     if (state_response.has_initial_enrollment_mode()) {
706       parsed_response.restore_mode = ConvertInitialEnrollmentMode(
707           state_response.initial_enrollment_mode());
708     } else {
709       // Unknown initial enrollment mode - treat as no enrollment.
710       parsed_response.restore_mode.clear();
711     }
712 
713     if (state_response.has_management_domain())
714       parsed_response.management_domain = state_response.management_domain();
715 
716     if (state_response.has_is_license_packaged_with_device()) {
717       parsed_response.is_license_packaged_with_device =
718           state_response.is_license_packaged_with_device();
719     }
720 
721     if (state_response.has_disabled_state()) {
722       parsed_response.disabled_message =
723           state_response.disabled_state().message();
724     }
725 
726     // Logging as "WARNING" to make sure it's preserved in the logs.
727     LOG(WARNING) << "Received initial_enrollment_mode="
728                  << state_response.initial_enrollment_mode() << " ("
729                  << parsed_response.restore_mode << "). "
730                  << (state_response.is_license_packaged_with_device()
731                          ? "Device has a packaged license for management."
732                          : "No packaged license.");
733 
734     return parsed_response;
735   }
736 
737  private:
738   // Serial number of the device.
739   std::string device_serial_number_;
740   // 4-character brand code of the device.
741   std::string device_brand_code_;
742 };
743 
744 // Handles DeviceStateRetrievalRequest / DeviceStateRetrievalResponse for
745 // Forced Re-Enrollment (FRE).
746 class StateDownloadMessageProcessorFRE
747     : public AutoEnrollmentClientImpl::StateDownloadMessageProcessor {
748  public:
StateDownloadMessageProcessorFRE(const std::string & server_backed_state_key)749   explicit StateDownloadMessageProcessorFRE(
750       const std::string& server_backed_state_key)
751       : server_backed_state_key_(server_backed_state_key) {}
752 
GetJobType() const753   DeviceManagementService::JobConfiguration::JobType GetJobType()
754       const override {
755     return DeviceManagementService::JobConfiguration::
756         TYPE_DEVICE_STATE_RETRIEVAL;
757   }
758 
FillRequest(em::DeviceManagementRequest * request)759   void FillRequest(em::DeviceManagementRequest* request) override {
760     request->mutable_device_state_retrieval_request()
761         ->set_server_backed_state_key(server_backed_state_key_);
762   }
763 
ParseResponse(const em::DeviceManagementResponse & response)764   base::Optional<ParsedResponse> ParseResponse(
765       const em::DeviceManagementResponse& response) override {
766     if (!response.has_device_state_retrieval_response()) {
767       LOG(ERROR) << "Server failed to provide auto-enrollment response.";
768       return base::nullopt;
769     }
770 
771     const em::DeviceStateRetrievalResponse& state_response =
772         response.device_state_retrieval_response();
773     const auto restore_mode = state_response.restore_mode();
774 
775     if (restore_mode == em::DeviceStateRetrievalResponse::RESTORE_MODE_NONE &&
776         state_response.has_initial_state_response()) {
777       // Logging as "WARNING" to make sure it's preserved in the logs.
778       LOG(WARNING) << "Received restore_mode=" << restore_mode << " ("
779                    << ConvertRestoreMode(restore_mode) << ")"
780                    << " . Parsing included initial state response.";
781 
782       return StateDownloadMessageProcessorInitialEnrollment::
783           ParseInitialEnrollmentStateResponse(
784               state_response.initial_state_response());
785     } else {
786       StateDownloadMessageProcessor::ParsedResponse parsed_response;
787 
788       parsed_response.restore_mode = ConvertRestoreMode(restore_mode);
789 
790       if (state_response.has_management_domain())
791         parsed_response.management_domain = state_response.management_domain();
792 
793       if (state_response.has_disabled_state()) {
794         parsed_response.disabled_message =
795             state_response.disabled_state().message();
796       }
797 
798       // Package license is not available during the re-enrollment
799       parsed_response.is_license_packaged_with_device.reset();
800 
801       // Logging as "WARNING" to make sure it's preserved in the logs.
802       LOG(WARNING) << "Received restore_mode=" << restore_mode << " ("
803                    << parsed_response.restore_mode << ").";
804 
805       return parsed_response;
806     }
807   }
808 
809  private:
810   // Stable state key.
811   std::string server_backed_state_key_;
812 };
813 
814 }  // namespace
815 
FactoryImpl()816 AutoEnrollmentClientImpl::FactoryImpl::FactoryImpl() {}
~FactoryImpl()817 AutoEnrollmentClientImpl::FactoryImpl::~FactoryImpl() {}
818 
819 std::unique_ptr<AutoEnrollmentClient>
CreateForFRE(const ProgressCallback & progress_callback,DeviceManagementService * device_management_service,PrefService * local_state,scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,const std::string & server_backed_state_key,int power_initial,int power_limit)820 AutoEnrollmentClientImpl::FactoryImpl::CreateForFRE(
821     const ProgressCallback& progress_callback,
822     DeviceManagementService* device_management_service,
823     PrefService* local_state,
824     scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
825     const std::string& server_backed_state_key,
826     int power_initial,
827     int power_limit) {
828   return base::WrapUnique(new AutoEnrollmentClientImpl(
829       progress_callback, device_management_service, local_state,
830       url_loader_factory,
831       std::make_unique<DeviceIdentifierProviderFRE>(server_backed_state_key),
832       std::make_unique<StateDownloadMessageProcessorFRE>(
833           server_backed_state_key),
834       power_initial, power_limit,
835       /*power_outdated_server_detect=*/base::nullopt, kUMAHashDanceSuffixFRE,
836       /*private_set_membership_helper=*/nullptr));
837 }
838 
839 std::unique_ptr<AutoEnrollmentClient>
CreateForInitialEnrollment(const ProgressCallback & progress_callback,DeviceManagementService * device_management_service,PrefService * local_state,scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,const std::string & device_serial_number,const std::string & device_brand_code,int power_initial,int power_limit,int power_outdated_server_detect)840 AutoEnrollmentClientImpl::FactoryImpl::CreateForInitialEnrollment(
841     const ProgressCallback& progress_callback,
842     DeviceManagementService* device_management_service,
843     PrefService* local_state,
844     scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
845     const std::string& device_serial_number,
846     const std::string& device_brand_code,
847     int power_initial,
848     int power_limit,
849     int power_outdated_server_detect) {
850   return base::WrapUnique(new AutoEnrollmentClientImpl(
851       progress_callback, device_management_service, local_state,
852       url_loader_factory,
853       std::make_unique<DeviceIdentifierProviderInitialEnrollment>(
854           device_serial_number, device_brand_code),
855       std::make_unique<StateDownloadMessageProcessorInitialEnrollment>(
856           device_serial_number, device_brand_code),
857       power_initial, power_limit,
858       base::make_optional(power_outdated_server_detect),
859       kUMAHashDanceSuffixInitialEnrollment,
860       chromeos::AutoEnrollmentController::IsPrivateSetMembershipEnabled()
861           ? std::make_unique<PrivateSetMembershipHelper>(
862                 device_management_service, url_loader_factory, local_state,
863                 ConstructDeviceRlweId(device_serial_number, device_brand_code))
864           : nullptr));
865 }
866 
~AutoEnrollmentClientImpl()867 AutoEnrollmentClientImpl::~AutoEnrollmentClientImpl() {
868   content::GetNetworkConnectionTracker()->RemoveNetworkConnectionObserver(this);
869 }
870 
871 // static
RegisterPrefs(PrefRegistrySimple * registry)872 void AutoEnrollmentClientImpl::RegisterPrefs(PrefRegistrySimple* registry) {
873   registry->RegisterBooleanPref(prefs::kShouldAutoEnroll, false);
874   registry->RegisterIntegerPref(prefs::kAutoEnrollmentPowerLimit, -1);
875   registry->RegisterBooleanPref(prefs::kShouldRetrieveDeviceState, false);
876 }
877 
Start()878 void AutoEnrollmentClientImpl::Start() {
879   // (Re-)register the network change observer.
880   content::GetNetworkConnectionTracker()->RemoveNetworkConnectionObserver(this);
881   content::GetNetworkConnectionTracker()->AddNetworkConnectionObserver(this);
882 
883   // Drop the previous job and reset state.
884   request_job_.reset();
885   hash_dance_time_start_ = base::TimeTicks();
886   state_ = AUTO_ENROLLMENT_STATE_PENDING;
887   modulus_updates_received_ = 0;
888   has_server_state_ = false;
889   device_state_available_ = false;
890 
891   NextStep();
892 }
893 
Retry()894 void AutoEnrollmentClientImpl::Retry() {
895   RetryStep();
896 }
897 
CancelAndDeleteSoon()898 void AutoEnrollmentClientImpl::CancelAndDeleteSoon() {
899   // Check if neither Hash dance request i.e. DeviceAutoEnrollmentRequest nor
900   // DeviceStateRetrievalRequest is in progress.
901   if (!request_job_) {
902     // The client isn't running, just delete it.
903     delete this;
904   } else {
905     // Client still running, but our owner isn't interested in the result
906     // anymore. Wait until the protocol completes to measure the extra time
907     // needed.
908     time_extra_start_ = base::TimeTicks::Now();
909     progress_callback_.Reset();
910   }
911 }
912 
device_id() const913 std::string AutoEnrollmentClientImpl::device_id() const {
914   return device_id_;
915 }
916 
state() const917 AutoEnrollmentState AutoEnrollmentClientImpl::state() const {
918   return state_;
919 }
920 
OnConnectionChanged(network::mojom::ConnectionType type)921 void AutoEnrollmentClientImpl::OnConnectionChanged(
922     network::mojom::ConnectionType type) {
923   if (type != network::mojom::ConnectionType::CONNECTION_NONE &&
924       !progress_callback_.is_null()) {
925     RetryStep();
926   }
927 }
928 
AutoEnrollmentClientImpl(const ProgressCallback & callback,DeviceManagementService * service,PrefService * local_state,scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,std::unique_ptr<DeviceIdentifierProvider> device_identifier_provider,std::unique_ptr<StateDownloadMessageProcessor> state_download_message_processor,int power_initial,int power_limit,base::Optional<int> power_outdated_server_detect,std::string uma_suffix,std::unique_ptr<PrivateSetMembershipHelper> private_set_membership_helper)929 AutoEnrollmentClientImpl::AutoEnrollmentClientImpl(
930     const ProgressCallback& callback,
931     DeviceManagementService* service,
932     PrefService* local_state,
933     scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
934     std::unique_ptr<DeviceIdentifierProvider> device_identifier_provider,
935     std::unique_ptr<StateDownloadMessageProcessor>
936         state_download_message_processor,
937     int power_initial,
938     int power_limit,
939     base::Optional<int> power_outdated_server_detect,
940     std::string uma_suffix,
941     std::unique_ptr<PrivateSetMembershipHelper> private_set_membership_helper)
942     : progress_callback_(callback),
943       state_(AUTO_ENROLLMENT_STATE_IDLE),
944       has_server_state_(false),
945       device_state_available_(false),
946       device_id_(base::GenerateGUID()),
947       current_power_(power_initial),
948       power_limit_(power_limit),
949       power_outdated_server_detect_(power_outdated_server_detect),
950       modulus_updates_received_(0),
951       device_management_service_(service),
952       local_state_(local_state),
953       url_loader_factory_(url_loader_factory),
954       device_identifier_provider_(std::move(device_identifier_provider)),
955       state_download_message_processor_(
956           std::move(state_download_message_processor)),
957       private_set_membership_helper_(std::move(private_set_membership_helper)),
958       uma_suffix_(uma_suffix),
959       recorded_psm_hash_dance_comparison_(false) {
960   DCHECK_LE(current_power_, power_limit_);
961   DCHECK(!progress_callback_.is_null());
962 }
963 
GetCachedDecision()964 bool AutoEnrollmentClientImpl::GetCachedDecision() {
965   const PrefService::Preference* has_server_state_pref =
966       local_state_->FindPreference(prefs::kShouldAutoEnroll);
967   const PrefService::Preference* previous_limit_pref =
968       local_state_->FindPreference(prefs::kAutoEnrollmentPowerLimit);
969   bool has_server_state = false;
970   int previous_limit = -1;
971 
972   if (!has_server_state_pref || has_server_state_pref->IsDefaultValue() ||
973       !has_server_state_pref->GetValue()->GetAsBoolean(&has_server_state) ||
974       !previous_limit_pref || previous_limit_pref->IsDefaultValue() ||
975       !previous_limit_pref->GetValue()->GetAsInteger(&previous_limit) ||
976       power_limit_ > previous_limit) {
977     return false;
978   }
979 
980   has_server_state_ = has_server_state;
981   return true;
982 }
983 
RetryStep()984 bool AutoEnrollmentClientImpl::RetryStep() {
985   if (PrivateSetMembershipRetryStep())
986     return true;
987 
988   // If there is a pending request job, let it finish.
989   if (request_job_)
990     return true;
991 
992   if (GetCachedDecision()) {
993     VLOG(1) << "Cached: has_state=" << has_server_state_;
994     // The bucket download check has completed already. If it came back
995     // positive, then device state should be (re-)downloaded.
996     if (has_server_state_) {
997       if (!device_state_available_) {
998         SendDeviceStateRequest();
999         return true;
1000       }
1001     }
1002   } else {
1003     // Start bucket download.
1004     SendBucketDownloadRequest();
1005     return true;
1006   }
1007 
1008   return false;
1009 }
1010 
PrivateSetMembershipRetryStep()1011 bool AutoEnrollmentClientImpl::PrivateSetMembershipRetryStep() {
1012   // Don't retry if the protocol is disabled, or an error occurred while
1013   // executing the protocol.
1014   if (!private_set_membership_helper_ ||
1015       private_set_membership_helper_->HasPrivateSetMembershipError()) {
1016     return false;
1017   }
1018 
1019   // If the private set membership protocol is in progress, signal to the caller
1020   // that nothing else needs to be done.
1021   if (private_set_membership_helper_->IsCheckMembershipInProgress())
1022     return true;
1023 
1024   const base::Optional<bool> private_set_membership_server_state =
1025       private_set_membership_helper_->GetPrivateSetMembershipCachedDecision();
1026 
1027   if (private_set_membership_server_state.has_value()) {
1028     LOG(WARNING) << "PSM Cached: psm_server_state="
1029                  << private_set_membership_server_state.value();
1030     return false;
1031   } else {
1032     private_set_membership_helper_->CheckMembership(base::BindOnce(
1033         &AutoEnrollmentClientImpl::RetryStep, base::Unretained(this)));
1034     return true;
1035   }
1036 }
1037 
SetPrivateSetMembershipRlweClientForTesting(std::unique_ptr<psm_rlwe::PrivateMembershipRlweClient> private_set_membership_rlwe_client,const psm_rlwe::RlwePlaintextId & psm_rlwe_id)1038 void AutoEnrollmentClientImpl::SetPrivateSetMembershipRlweClientForTesting(
1039     std::unique_ptr<psm_rlwe::PrivateMembershipRlweClient>
1040         private_set_membership_rlwe_client,
1041     const psm_rlwe::RlwePlaintextId& psm_rlwe_id) {
1042   if (!private_set_membership_helper_)
1043     return;
1044 
1045   DCHECK(private_set_membership_rlwe_client);
1046   private_set_membership_helper_->SetRlweClientAndIdForTesting(
1047       std::move(private_set_membership_rlwe_client), std::move(psm_rlwe_id));
1048 }
1049 
ReportProgress(AutoEnrollmentState state)1050 void AutoEnrollmentClientImpl::ReportProgress(AutoEnrollmentState state) {
1051   state_ = state;
1052   // If hash dance finished with an error or result, record comparison with
1053   // private set membership. Note that hash dance might be retried but for
1054   // recording we only care about the first attempt.
1055   // If |private_set_membership_helper_| is non-null, a private set membership
1056   // request has been made at this point because it is executed before hash
1057   // dance.
1058   const bool has_hash_dance_result = (state != AUTO_ENROLLMENT_STATE_IDLE &&
1059                                       state != AUTO_ENROLLMENT_STATE_PENDING);
1060   if (private_set_membership_helper_ && !recorded_psm_hash_dance_comparison_ &&
1061       has_hash_dance_result) {
1062     RecordPrivateSetMembershipHashDanceComparison();
1063   }
1064   if (progress_callback_.is_null()) {
1065     base::ThreadTaskRunnerHandle::Get()->DeleteSoon(FROM_HERE, this);
1066   } else {
1067     progress_callback_.Run(state_);
1068   }
1069 }
1070 
NextStep()1071 void AutoEnrollmentClientImpl::NextStep() {
1072   if (RetryStep())
1073     return;
1074 
1075   // Protocol finished successfully, report result.
1076   const DeviceStateMode device_state_mode = GetDeviceStateMode();
1077   switch (device_state_mode) {
1078     case RESTORE_MODE_NONE:
1079       ReportProgress(AUTO_ENROLLMENT_STATE_NO_ENROLLMENT);
1080       break;
1081     case RESTORE_MODE_DISABLED:
1082       ReportProgress(AUTO_ENROLLMENT_STATE_DISABLED);
1083       break;
1084     case RESTORE_MODE_REENROLLMENT_REQUESTED:
1085     case RESTORE_MODE_REENROLLMENT_ENFORCED:
1086     case INITIAL_MODE_ENROLLMENT_ENFORCED:
1087       ReportProgress(AUTO_ENROLLMENT_STATE_TRIGGER_ENROLLMENT);
1088       break;
1089     case RESTORE_MODE_REENROLLMENT_ZERO_TOUCH:
1090     case INITIAL_MODE_ENROLLMENT_ZERO_TOUCH:
1091       ReportProgress(AUTO_ENROLLMENT_STATE_TRIGGER_ZERO_TOUCH);
1092       break;
1093   }
1094 }
1095 
SendBucketDownloadRequest()1096 void AutoEnrollmentClientImpl::SendBucketDownloadRequest() {
1097   // Start the Hash dance timer during the first attempt.
1098   if (hash_dance_time_start_.is_null())
1099     hash_dance_time_start_ = base::TimeTicks::Now();
1100 
1101   std::string id_hash = device_identifier_provider_->GetIdHash();
1102   // Currently AutoEnrollmentClientImpl supports working with hashes that are at
1103   // least 8 bytes long. If this is reduced, the computation of the remainder
1104   // must also be adapted to handle the case of a shorter hash gracefully.
1105   DCHECK_GE(id_hash.size(), 8u);
1106 
1107   uint64_t remainder = 0;
1108   const size_t last_byte_index = id_hash.size() - 1;
1109   for (int i = 0; 8 * i < current_power_; ++i) {
1110     uint64_t byte = id_hash[last_byte_index - i] & 0xff;
1111     remainder = remainder | (byte << (8 * i));
1112   }
1113   remainder = remainder & ((UINT64_C(1) << current_power_) - 1);
1114 
1115   ReportProgress(AUTO_ENROLLMENT_STATE_PENDING);
1116 
1117   // Record the time when the bucket download request is started. Note that the
1118   // time may be set multiple times. This is fine, only the last request is the
1119   // one where the hash bucket is actually downloaded.
1120   time_start_bucket_download_ = base::TimeTicks::Now();
1121 
1122   VLOG(1) << "Request bucket #" << remainder;
1123   std::unique_ptr<DMServerJobConfiguration> config = std::make_unique<
1124       DMServerJobConfiguration>(
1125       device_management_service_,
1126       policy::DeviceManagementService::JobConfiguration::TYPE_AUTO_ENROLLMENT,
1127       device_id_,
1128       /*critical=*/false, DMAuth::NoAuth(),
1129       /*oauth_token=*/base::nullopt, url_loader_factory_,
1130       base::BindOnce(
1131           &AutoEnrollmentClientImpl::HandleRequestCompletion,
1132           base::Unretained(this),
1133           &AutoEnrollmentClientImpl::OnBucketDownloadRequestCompletion));
1134 
1135   em::DeviceAutoEnrollmentRequest* request =
1136       config->request()->mutable_auto_enrollment_request();
1137   request->set_remainder(remainder);
1138   request->set_modulus(INT64_C(1) << current_power_);
1139   request->set_enrollment_check_type(
1140       device_identifier_provider_->GetEnrollmentCheckType());
1141 
1142   request_job_ = device_management_service_->CreateJob(std::move(config));
1143 }
1144 
SendDeviceStateRequest()1145 void AutoEnrollmentClientImpl::SendDeviceStateRequest() {
1146   ReportProgress(AUTO_ENROLLMENT_STATE_PENDING);
1147 
1148   std::unique_ptr<DMServerJobConfiguration> config =
1149       std::make_unique<DMServerJobConfiguration>(
1150           device_management_service_,
1151           state_download_message_processor_->GetJobType(), device_id_,
1152           /*critical=*/false, DMAuth::NoAuth(),
1153           /*oauth_token=*/base::nullopt, url_loader_factory_,
1154           base::BindRepeating(
1155               &AutoEnrollmentClientImpl::HandleRequestCompletion,
1156               base::Unretained(this),
1157               &AutoEnrollmentClientImpl::OnDeviceStateRequestCompletion));
1158 
1159   state_download_message_processor_->FillRequest(config->request());
1160   request_job_ = device_management_service_->CreateJob(std::move(config));
1161 }
1162 
HandleRequestCompletion(RequestCompletionHandler handler,policy::DeviceManagementService::Job * job,DeviceManagementStatus status,int net_error,const em::DeviceManagementResponse & response)1163 void AutoEnrollmentClientImpl::HandleRequestCompletion(
1164     RequestCompletionHandler handler,
1165     policy::DeviceManagementService::Job* job,
1166     DeviceManagementStatus status,
1167     int net_error,
1168     const em::DeviceManagementResponse& response) {
1169   base::UmaHistogramSparse(kUMAHashDanceRequestStatus + uma_suffix_, status);
1170   if (status != DM_STATUS_SUCCESS) {
1171     LOG(ERROR) << "Auto enrollment error: " << status;
1172     if (status == DM_STATUS_REQUEST_FAILED)
1173       base::UmaHistogramSparse(kUMAHashDanceNetworkErrorCode + uma_suffix_,
1174                                -net_error);
1175     request_job_.reset();
1176 
1177     // Abort if CancelAndDeleteSoon has been called meanwhile.
1178     if (progress_callback_.is_null()) {
1179       base::ThreadTaskRunnerHandle::Get()->DeleteSoon(FROM_HERE, this);
1180     } else {
1181       ReportProgress(status == DM_STATUS_REQUEST_FAILED
1182                          ? AUTO_ENROLLMENT_STATE_CONNECTION_ERROR
1183                          : AUTO_ENROLLMENT_STATE_SERVER_ERROR);
1184     }
1185     return;
1186   }
1187 
1188   bool progress =
1189       (this->*handler)(request_job_.get(), status, net_error, response);
1190   request_job_.reset();
1191   if (progress)
1192     NextStep();
1193   else
1194     ReportProgress(AUTO_ENROLLMENT_STATE_SERVER_ERROR);
1195 }
1196 
// Handles the response to a DeviceAutoEnrollmentRequest (hash dance bucket
// download). Called via HandleRequestCompletion() only for DM_STATUS_SUCCESS,
// so |job|, |status| and |net_error| are not inspected here.
// Returns true when the caller should continue with NextStep(), and false
// when the response is unusable (the caller then reports a server error).
// Outcomes:
//  - the server asks for a different modulus: |current_power_| is updated so
//    the next attempt uses it (only one such update is accepted);
//  - the server is detected as outdated: a "no server state" decision is
//    cached in local state;
//  - the server returns the hashes of our bucket: whether our id hash is
//    among them is cached in local state.
bool AutoEnrollmentClientImpl::OnBucketDownloadRequestCompletion(
    policy::DeviceManagementService::Job* job,
    DeviceManagementStatus status,
    int net_error,
    const em::DeviceManagementResponse& response) {
  bool progress = false;
  // Bound before the presence check below; only used once that check passed.
  const em::DeviceAutoEnrollmentResponse& enrollment_response =
      response.auto_enrollment_response();
  if (!response.has_auto_enrollment_response()) {
    LOG(ERROR) << "Server failed to provide auto-enrollment response.";
  } else if (enrollment_response.has_expected_modulus()) {
    // Server is asking us to retry with a different modulus.
    modulus_updates_received_++;

    int64_t modulus = enrollment_response.expected_modulus();
    int power = NextPowerOf2(modulus);
    if ((INT64_C(1) << power) != modulus) {
      LOG(ERROR) << "Auto enrollment: the server didn't ask for a power-of-2 "
                 << "modulus. Using the closest power-of-2 instead "
                 << "(" << modulus << " vs 2^" << power << ")";
    }
    if (modulus_updates_received_ >= 2) {
      // Second modulus update in one run: give up (returns false below).
      LOG(ERROR) << "Auto enrollment error: already retried with an updated "
                 << "modulus but the server asked for a new one again: "
                 << power;
    } else if (power_outdated_server_detect_.has_value() &&
               power >= power_outdated_server_detect_.value()) {
      LOG(ERROR) << "Skipping auto enrollment: The server was detected as "
                 << "outdated (power=" << power
                 << ", power_outdated_server_detect="
                 << power_outdated_server_detect_.value() << ").";
      has_server_state_ = false;
      // Cache the decision in local_state, so that it is reused in case
      // the device reboots before completing OOBE. Note that this does not
      // disable Forced Re-Enrollment for this device, because local state will
      // be empty after the device is wiped.
      local_state_->SetBoolean(prefs::kShouldAutoEnroll, false);
      local_state_->SetInteger(prefs::kAutoEnrollmentPowerLimit, power_limit_);
      local_state_->CommitPendingWrite();
      // Returns before the timing histograms at the end are updated.
      return true;
    } else if (power > power_limit_) {
      LOG(ERROR) << "Auto enrollment error: the server asked for a larger "
                 << "modulus than the client accepts (" << power << " vs "
                 << power_limit_ << ").";
    } else {
      // Retry at most once with the modulus that the server requested.
      if (power <= current_power_) {
        LOG(WARNING) << "Auto enrollment: the server asked to use a modulus ("
                     << power << ") that isn't larger than the first used ("
                     << current_power_ << "). Retrying anyway.";
      }
      // Remember this value, so that eventual retries start with the correct
      // modulus.
      current_power_ = power;
      // Returns before the timing histograms at the end are updated; the
      // caller's NextStep() then issues a new bucket download.
      return true;
    }
  } else {
    // Server should have sent down a list of hashes to try.
    has_server_state_ = IsIdHashInProtobuf(enrollment_response.hashes());
    // Cache the current decision in local_state, so that it is reused in case
    // the device reboots before enrolling.
    local_state_->SetBoolean(prefs::kShouldAutoEnroll, has_server_state_);
    local_state_->SetInteger(prefs::kAutoEnrollmentPowerLimit, power_limit_);
    local_state_->CommitPendingWrite();
    VLOG(1) << "Received has_state=" << has_server_state_;
    progress = true;
    // Report timing if hash dance finished successfully and if the caller is
    // still interested in the result.
    if (!progress_callback_.is_null())
      RecordHashDanceSuccessTimeHistogram();
  }

  // Bucket download done, update UMA.
  UpdateBucketDownloadTimingHistograms();
  return progress;
}
1273 
OnDeviceStateRequestCompletion(policy::DeviceManagementService::Job * job,DeviceManagementStatus status,int net_error,const em::DeviceManagementResponse & response)1274 bool AutoEnrollmentClientImpl::OnDeviceStateRequestCompletion(
1275     policy::DeviceManagementService::Job* job,
1276     DeviceManagementStatus status,
1277     int net_error,
1278     const em::DeviceManagementResponse& response) {
1279   base::Optional<StateDownloadMessageProcessor::ParsedResponse>
1280       parsed_response_opt;
1281 
1282   parsed_response_opt =
1283       state_download_message_processor_->ParseResponse(response);
1284   if (!parsed_response_opt)
1285     return false;
1286 
1287   StateDownloadMessageProcessor::ParsedResponse parsed_response =
1288       std::move(parsed_response_opt.value());
1289   {
1290     DictionaryPrefUpdate dict(local_state_, prefs::kServerBackedDeviceState);
1291     UpdateDict(dict.Get(), kDeviceStateManagementDomain,
1292                parsed_response.management_domain.has_value(),
1293                std::make_unique<base::Value>(
1294                    parsed_response.management_domain.value_or(std::string())));
1295 
1296     UpdateDict(dict.Get(), kDeviceStateMode,
1297                !parsed_response.restore_mode.empty(),
1298                std::make_unique<base::Value>(parsed_response.restore_mode));
1299 
1300     UpdateDict(dict.Get(), kDeviceStateDisabledMessage,
1301                parsed_response.disabled_message.has_value(),
1302                std::make_unique<base::Value>(
1303                    parsed_response.disabled_message.value_or(std::string())));
1304 
1305     UpdateDict(
1306         dict.Get(), kDeviceStatePackagedLicense,
1307         parsed_response.is_license_packaged_with_device.has_value(),
1308         std::make_unique<base::Value>(
1309             parsed_response.is_license_packaged_with_device.value_or(false)));
1310   }
1311   local_state_->CommitPendingWrite();
1312   device_state_available_ = true;
1313   return true;
1314 }
1315 
IsIdHashInProtobuf(const google::protobuf::RepeatedPtrField<std::string> & hashes)1316 bool AutoEnrollmentClientImpl::IsIdHashInProtobuf(
1317     const google::protobuf::RepeatedPtrField<std::string>& hashes) {
1318   std::string id_hash = device_identifier_provider_->GetIdHash();
1319   for (int i = 0; i < hashes.size(); ++i) {
1320     if (hashes.Get(i) == id_hash)
1321       return true;
1322   }
1323   return false;
1324 }
1325 
UpdateBucketDownloadTimingHistograms()1326 void AutoEnrollmentClientImpl::UpdateBucketDownloadTimingHistograms() {
1327   // These values determine bucketing of the histogram, they should not be
1328   // changed.
1329   // The minimum time can't be 0, must be at least 1.
1330   static const base::TimeDelta kMin = base::TimeDelta::FromMilliseconds(1);
1331   static const base::TimeDelta kMax = base::TimeDelta::FromMinutes(5);
1332   // However, 0 can still be sampled.
1333   static const base::TimeDelta kZero = base::TimeDelta::FromMilliseconds(0);
1334   static const int kBuckets = 50;
1335 
1336   base::TimeTicks now = base::TimeTicks::Now();
1337   if (!hash_dance_time_start_.is_null()) {
1338     base::TimeDelta delta = now - hash_dance_time_start_;
1339     base::UmaHistogramCustomTimes(kUMAHashDanceProtocolTime + uma_suffix_,
1340                                   delta, kMin, kMax, kBuckets);
1341   }
1342   if (!time_start_bucket_download_.is_null()) {
1343     base::TimeDelta delta = now - time_start_bucket_download_;
1344     base::UmaHistogramCustomTimes(kUMAHashDanceBucketDownloadTime + uma_suffix_,
1345                                   delta, kMin, kMax, kBuckets);
1346   }
1347   base::TimeDelta delta = kZero;
1348   if (!time_extra_start_.is_null())
1349     delta = now - time_extra_start_;
1350   // This samples |kZero| when there was no need for extra time, so that we can
1351   // measure the ratio of users that succeeded without needing a delay to the
1352   // total users going through OOBE.
1353   base::UmaHistogramCustomTimes(kUMAHashDanceExtraTime + uma_suffix_, delta,
1354                                 kMin, kMax, kBuckets);
1355 }
1356 
RecordHashDanceSuccessTimeHistogram()1357 void AutoEnrollmentClientImpl::RecordHashDanceSuccessTimeHistogram() {
1358   // These values determine bucketing of the histogram, they should not be
1359   // changed.
1360   static const base::TimeDelta kMin = base::TimeDelta::FromMilliseconds(1);
1361   static const base::TimeDelta kMax = base::TimeDelta::FromSeconds(25);
1362   static const int kBuckets = 50;
1363 
1364   base::TimeTicks now = base::TimeTicks::Now();
1365   if (!hash_dance_time_start_.is_null()) {
1366     base::TimeDelta delta = now - hash_dance_time_start_;
1367     base::UmaHistogramCustomTimes(kUMAHashDanceSuccessTime + uma_suffix_, delta,
1368                                   kMin, kMax, kBuckets);
1369   }
1370 }
1371 
RecordPrivateSetMembershipHashDanceComparison()1372 void AutoEnrollmentClientImpl::RecordPrivateSetMembershipHashDanceComparison() {
1373   // Private set membership timeout is enforced in the helper class. This method
1374   // should only be called after private set membership request finished or ran
1375   // into timeout.
1376   DCHECK(private_set_membership_helper_);
1377   DCHECK(!private_set_membership_helper_->IsCheckMembershipInProgress());
1378 
1379   // Make sure to only record once per instance.
1380   recorded_psm_hash_dance_comparison_ = true;
1381 
1382   bool private_set_membership_error =
1383       private_set_membership_helper_->HasPrivateSetMembershipError();
1384 
1385   bool hash_dance_decision = has_server_state_;
1386   bool hash_dance_error = false;
1387   switch (state_) {
1388     case AUTO_ENROLLMENT_STATE_TRIGGER_ENROLLMENT:
1389     case AUTO_ENROLLMENT_STATE_NO_ENROLLMENT:
1390     case AUTO_ENROLLMENT_STATE_TRIGGER_ZERO_TOUCH:
1391     case AUTO_ENROLLMENT_STATE_DISABLED:
1392       hash_dance_error = false;
1393       break;
1394     case AUTO_ENROLLMENT_STATE_CONNECTION_ERROR:
1395     case AUTO_ENROLLMENT_STATE_SERVER_ERROR:
1396       hash_dance_error = true;
1397       break;
1398     // This method should only be called if hash dance finished.
1399     case AUTO_ENROLLMENT_STATE_IDLE:
1400     case AUTO_ENROLLMENT_STATE_PENDING:
1401     default:
1402       NOTREACHED();
1403   }
1404 
1405   auto comparison = PrivateSetMembershipHashDanceComparison::kEqualResults;
1406   if (!hash_dance_error && !private_set_membership_error) {
1407     base::Optional<bool> private_set_membership_decision =
1408         private_set_membership_helper_->GetPrivateSetMembershipCachedDecision();
1409 
1410     // There was no error and this function is only invoked after PSM has been
1411     // performed, so there must be a decision.
1412     DCHECK(private_set_membership_decision.has_value());
1413 
1414     comparison =
1415         (hash_dance_decision == private_set_membership_decision.value())
1416             ? PrivateSetMembershipHashDanceComparison::kEqualResults
1417             : PrivateSetMembershipHashDanceComparison::kDifferentResults;
1418   } else if (hash_dance_error && !private_set_membership_error) {
1419     comparison =
1420         PrivateSetMembershipHashDanceComparison::kPSMSuccessHashDanceError;
1421   } else if (!hash_dance_error && private_set_membership_error) {
1422     comparison =
1423         PrivateSetMembershipHashDanceComparison::kPSMErrorHashDanceSuccess;
1424   } else {
1425     comparison = PrivateSetMembershipHashDanceComparison::kBothError;
1426   }
1427 
1428   base::UmaHistogramEnumeration(kUMAPrivateSetMembershipHashDanceComparison,
1429                                 comparison);
1430 }
1431 
1432 }  // namespace policy
1433