1 // Copyright 2017 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chromeos/components/tether/message_transfer_operation.h"
6 
7 #include <memory>
8 
9 #include "base/memory/ptr_util.h"
10 #include "base/timer/mock_timer.h"
11 #include "chromeos/components/multidevice/remote_device_test_util.h"
12 #include "chromeos/components/tether/message_wrapper.h"
13 #include "chromeos/components/tether/proto_test_util.h"
14 #include "chromeos/components/tether/test_timer_factory.h"
15 #include "chromeos/services/device_sync/public/cpp/fake_device_sync_client.h"
16 #include "chromeos/services/secure_channel/public/cpp/client/fake_client_channel.h"
17 #include "chromeos/services/secure_channel/public/cpp/client/fake_connection_attempt.h"
18 #include "chromeos/services/secure_channel/public/cpp/client/fake_secure_channel_client.h"
19 #include "testing/gmock/include/gmock/gmock.h"
20 #include "testing/gtest/include/gtest/gtest.h"
21 
22 namespace chromeos {
23 
24 namespace tether {
25 
26 namespace {
27 
28 // Arbitrarily chosen value. The MessageType used in this test does not matter
29 // except that it must be consistent throughout the test.
30 const MessageType kTestMessageType = MessageType::TETHER_AVAILABILITY_REQUEST;
31 
32 const uint32_t kTestTimeoutSeconds = 5;
33 
34 const char kTetherFeature[] = "magic_tether";
35 
36 // A test double for MessageTransferOperation is needed because
37 // MessageTransferOperation has pure virtual methods which must be overridden in
38 // order to create a concrete instantiation of the class.
39 class TestOperation : public MessageTransferOperation {
40  public:
TestOperation(const multidevice::RemoteDeviceRefList & devices_to_connect,device_sync::DeviceSyncClient * device_sync_client,secure_channel::SecureChannelClient * secure_channel_client)41   TestOperation(const multidevice::RemoteDeviceRefList& devices_to_connect,
42                 device_sync::DeviceSyncClient* device_sync_client,
43                 secure_channel::SecureChannelClient* secure_channel_client)
44       : MessageTransferOperation(devices_to_connect,
45                                  secure_channel::ConnectionPriority::kLow,
46                                  device_sync_client,
47                                  secure_channel_client) {}
48   ~TestOperation() override = default;
49 
HasDeviceAuthenticated(multidevice::RemoteDeviceRef remote_device)50   bool HasDeviceAuthenticated(multidevice::RemoteDeviceRef remote_device) {
51     const auto iter = device_map_.find(remote_device);
52     if (iter == device_map_.end())
53       return false;
54 
55     return iter->second.has_device_authenticated;
56   }
57 
GetReceivedMessages(multidevice::RemoteDeviceRef remote_device)58   std::vector<std::shared_ptr<MessageWrapper>> GetReceivedMessages(
59       multidevice::RemoteDeviceRef remote_device) {
60     const auto iter = device_map_.find(remote_device);
61     if (iter == device_map_.end())
62       return std::vector<std::shared_ptr<MessageWrapper>>();
63 
64     return iter->second.received_messages;
65   }
66 
67   // MessageTransferOperation:
OnDeviceAuthenticated(multidevice::RemoteDeviceRef remote_device)68   void OnDeviceAuthenticated(
69       multidevice::RemoteDeviceRef remote_device) override {
70     device_map_[remote_device].has_device_authenticated = true;
71   }
72 
OnMessageReceived(std::unique_ptr<MessageWrapper> message_wrapper,multidevice::RemoteDeviceRef remote_device)73   void OnMessageReceived(std::unique_ptr<MessageWrapper> message_wrapper,
74                          multidevice::RemoteDeviceRef remote_device) override {
75     device_map_[remote_device].received_messages.push_back(
76         std::move(message_wrapper));
77 
78     if (should_unregister_device_on_message_received_)
79       UnregisterDevice(remote_device);
80   }
81 
OnOperationStarted()82   void OnOperationStarted() override { has_operation_started_ = true; }
83 
OnOperationFinished()84   void OnOperationFinished() override { has_operation_finished_ = true; }
85 
GetMessageTypeForConnection()86   MessageType GetMessageTypeForConnection() override {
87     return kTestMessageType;
88   }
89 
OnMessageSent(int sequence_number)90   void OnMessageSent(int sequence_number) override {
91     last_sequence_number_ = sequence_number;
92   }
93 
GetMessageTimeoutSeconds()94   uint32_t GetMessageTimeoutSeconds() override { return timeout_seconds_; }
95 
set_timeout_seconds(uint32_t timeout_seconds)96   void set_timeout_seconds(uint32_t timeout_seconds) {
97     timeout_seconds_ = timeout_seconds;
98   }
99 
set_should_unregister_device_on_message_received(bool should_unregister_device_on_message_received)100   void set_should_unregister_device_on_message_received(
101       bool should_unregister_device_on_message_received) {
102     should_unregister_device_on_message_received_ =
103         should_unregister_device_on_message_received;
104   }
105 
has_operation_started()106   bool has_operation_started() { return has_operation_started_; }
107 
has_operation_finished()108   bool has_operation_finished() { return has_operation_finished_; }
109 
last_sequence_number()110   base::Optional<int> last_sequence_number() { return last_sequence_number_; }
111 
112  private:
113   struct DeviceMapValue {
114     DeviceMapValue() = default;
115     ~DeviceMapValue() = default;
116 
117     bool has_device_authenticated;
118     std::vector<std::shared_ptr<MessageWrapper>> received_messages;
119   };
120 
121   base::flat_map<multidevice::RemoteDeviceRef, DeviceMapValue> device_map_;
122 
123   uint32_t timeout_seconds_ = kTestTimeoutSeconds;
124   bool should_unregister_device_on_message_received_ = false;
125   bool has_operation_started_ = false;
126   bool has_operation_finished_ = false;
127   base::Optional<int> last_sequence_number_;
128 };
129 
CreateTetherAvailabilityResponse()130 TetherAvailabilityResponse CreateTetherAvailabilityResponse() {
131   TetherAvailabilityResponse response;
132   response.set_response_code(
133       TetherAvailabilityResponse_ResponseCode::
134           TetherAvailabilityResponse_ResponseCode_TETHER_AVAILABLE);
135   response.mutable_device_status()->CopyFrom(
136       CreateDeviceStatusWithFakeFields());
137   return response;
138 }
139 
140 }  // namespace
141 
142 class MessageTransferOperationTest : public testing::Test {
143  protected:
MessageTransferOperationTest()144   MessageTransferOperationTest()
145       : test_local_device_(multidevice::RemoteDeviceRefBuilder()
146                                .SetPublicKey("local device")
147                                .Build()),
148         test_devices_(multidevice::CreateRemoteDeviceRefListForTest(4)) {}
149 
SetUp()150   void SetUp() override {
151     fake_device_sync_client_ =
152         std::make_unique<device_sync::FakeDeviceSyncClient>();
153     fake_device_sync_client_->set_local_device_metadata(test_local_device_);
154     fake_secure_channel_client_ =
155         std::make_unique<secure_channel::FakeSecureChannelClient>();
156   }
157 
ConstructOperation(multidevice::RemoteDeviceRefList remote_devices)158   void ConstructOperation(multidevice::RemoteDeviceRefList remote_devices) {
159     test_timer_factory_ = new TestTimerFactory();
160 
161     for (auto remote_device : remote_devices) {
162       // Prepare for connection timeout timers to be made for each remote
163       // device.
164       test_timer_factory_->set_device_id_for_next_timer(
165           remote_device.GetDeviceId());
166 
167       auto fake_connection_attempt =
168           std::make_unique<secure_channel::FakeConnectionAttempt>();
169       remote_device_to_fake_connection_attempt_map_[remote_device] =
170           fake_connection_attempt.get();
171       fake_secure_channel_client_->set_next_listen_connection_attempt(
172           remote_device, test_local_device_,
173           std::move(fake_connection_attempt));
174     }
175 
176     operation_ = base::WrapUnique(
177         new TestOperation(remote_devices, fake_device_sync_client_.get(),
178                           fake_secure_channel_client_.get()));
179     operation_->SetTimerFactoryForTest(base::WrapUnique(test_timer_factory_));
180     VerifyOperationStartedAndFinished(false /* has_started */,
181                                       false /* has_finished */);
182   }
183 
InitializeOperation()184   void InitializeOperation() {
185     VerifyOperationStartedAndFinished(false /* has_started */,
186                                       false /* has_finished */);
187     operation_->Initialize();
188 
189     for (const auto* arguments :
190          fake_secure_channel_client_
191              ->last_listen_for_connection_request_arguments_list()) {
192       EXPECT_EQ(kTetherFeature, arguments->feature);
193     }
194 
195     VerifyOperationStartedAndFinished(true /* has_started */,
196                                       false /* has_finished */);
197 
198     for (const auto& remote_device : operation_->remote_devices())
199       VerifyConnectionTimerCreatedForDevice(remote_device);
200   }
201 
VerifyOperationStartedAndFinished(bool has_started,bool has_finished)202   void VerifyOperationStartedAndFinished(bool has_started, bool has_finished) {
203     EXPECT_EQ(has_started, operation_->has_operation_started());
204     EXPECT_EQ(has_finished, operation_->has_operation_finished());
205   }
206 
CreateAuthenticatedChannelForDevice(multidevice::RemoteDeviceRef remote_device)207   void CreateAuthenticatedChannelForDevice(
208       multidevice::RemoteDeviceRef remote_device) {
209     test_timer_factory_->set_device_id_for_next_timer(
210         remote_device.GetDeviceId());
211 
212     auto fake_client_channel =
213         std::make_unique<secure_channel::FakeClientChannel>();
214     remote_device_to_fake_client_channel_map_[remote_device] =
215         fake_client_channel.get();
216     remote_device_to_fake_connection_attempt_map_[remote_device]
217         ->NotifyConnection(std::move(fake_client_channel));
218   }
219 
GetTimerForDevice(multidevice::RemoteDeviceRef remote_device)220   base::MockOneShotTimer* GetTimerForDevice(
221       multidevice::RemoteDeviceRef remote_device) {
222     return test_timer_factory_->GetTimerForDeviceId(
223         remote_device.GetDeviceId());
224   }
225 
VerifyDefaultTimerCreatedForDevice(multidevice::RemoteDeviceRef remote_device)226   void VerifyDefaultTimerCreatedForDevice(
227       multidevice::RemoteDeviceRef remote_device) {
228     VerifyTimerCreatedForDevice(remote_device, kTestTimeoutSeconds);
229   }
230 
VerifyConnectionTimerCreatedForDevice(multidevice::RemoteDeviceRef remote_device)231   void VerifyConnectionTimerCreatedForDevice(
232       multidevice::RemoteDeviceRef remote_device) {
233     VerifyTimerCreatedForDevice(
234         remote_device, MessageTransferOperation::kConnectionTimeoutSeconds);
235   }
236 
VerifyTimerCreatedForDevice(multidevice::RemoteDeviceRef remote_device,uint32_t timeout_seconds)237   void VerifyTimerCreatedForDevice(multidevice::RemoteDeviceRef remote_device,
238                                    uint32_t timeout_seconds) {
239     EXPECT_TRUE(GetTimerForDevice(remote_device));
240     EXPECT_EQ(base::TimeDelta::FromSeconds(timeout_seconds),
241               GetTimerForDevice(remote_device)->GetCurrentDelay());
242   }
243 
SendMessageToDevice(multidevice::RemoteDeviceRef remote_device,std::unique_ptr<MessageWrapper> message_wrapper)244   int SendMessageToDevice(multidevice::RemoteDeviceRef remote_device,
245                           std::unique_ptr<MessageWrapper> message_wrapper) {
246     return operation_->SendMessageToDevice(test_devices_[0],
247                                            std::move(message_wrapper));
248   }
249 
250   const multidevice::RemoteDeviceRef test_local_device_;
251   const multidevice::RemoteDeviceRefList test_devices_;
252 
253   base::flat_map<multidevice::RemoteDeviceRef,
254                  secure_channel::FakeConnectionAttempt*>
255       remote_device_to_fake_connection_attempt_map_;
256   base::flat_map<multidevice::RemoteDeviceRef,
257                  secure_channel::FakeClientChannel*>
258       remote_device_to_fake_client_channel_map_;
259 
260   std::unique_ptr<device_sync::FakeDeviceSyncClient> fake_device_sync_client_;
261   std::unique_ptr<secure_channel::FakeSecureChannelClient>
262       fake_secure_channel_client_;
263   TestTimerFactory* test_timer_factory_;
264   std::unique_ptr<TestOperation> operation_;
265 
266  private:
267   DISALLOW_COPY_AND_ASSIGN(MessageTransferOperationTest);
268 };
269 
TEST_F(MessageTransferOperationTest, TestFailedConnection) {
  const multidevice::RemoteDeviceRef& remote_device = test_devices_[0];
  ConstructOperation(multidevice::RemoteDeviceRefList{remote_device});
  InitializeOperation();

  // An authentication failure on the only device's connection attempt is
  // expected to finish the operation without authenticating the device.
  secure_channel::FakeConnectionAttempt* connection_attempt =
      remote_device_to_fake_connection_attempt_map_[remote_device];
  connection_attempt->NotifyConnectionAttemptFailure(
      secure_channel::mojom::ConnectionAttemptFailureReason::
          AUTHENTICATION_ERROR);

  VerifyOperationStartedAndFinished(true /* has_started */,
                                    true /* has_finished */);
  EXPECT_FALSE(operation_->HasDeviceAuthenticated(remote_device));
  EXPECT_TRUE(operation_->GetReceivedMessages(remote_device).empty());
}
284 
TEST_F(MessageTransferOperationTest,
       TestSuccessfulConnectionSendAndReceiveMessage) {
  ConstructOperation(multidevice::RemoteDeviceRefList{test_devices_[0]});
  InitializeOperation();

  // Simulate how subclasses behave after a successful response: unregister the
  // device.
  operation_->set_should_unregister_device_on_message_received(true);

  CreateAuthenticatedChannelForDevice(test_devices_[0]);
  EXPECT_TRUE(operation_->HasDeviceAuthenticated(test_devices_[0]));
  VerifyDefaultTimerCreatedForDevice(test_devices_[0]);

  auto message_wrapper =
      std::make_unique<MessageWrapper>(TetherAvailabilityRequest());
  std::string expected_payload = message_wrapper->ToRawMessage();
  int sequence_number =
      SendMessageToDevice(test_devices_[0], std::move(message_wrapper));
  std::vector<std::pair<std::string, base::OnceClosure>>& sent_messages =
      remote_device_to_fake_client_channel_map_[test_devices_[0]]
          ->sent_messages();
  // ASSERT (not EXPECT): the lines below index sent_messages[0].
  ASSERT_EQ(1u, sent_messages.size());
  EXPECT_EQ(expected_payload, sent_messages[0].first);

  // OnMessageSent() must not fire until the channel runs the sent-callback.
  EXPECT_FALSE(operation_->last_sequence_number());
  std::move(sent_messages[0].second).Run();
  EXPECT_EQ(sequence_number, operation_->last_sequence_number());

  remote_device_to_fake_client_channel_map_[test_devices_[0]]
      ->NotifyMessageReceived(
          MessageWrapper(CreateTetherAvailabilityResponse()).ToRawMessage());

  const std::vector<std::shared_ptr<MessageWrapper>> received_messages =
      operation_->GetReceivedMessages(test_devices_[0]);
  // ASSERT (not EXPECT): the lines below index received_messages[0].
  ASSERT_EQ(1u, received_messages.size());
  const std::shared_ptr<MessageWrapper>& message = received_messages[0];
  EXPECT_EQ(MessageType::TETHER_AVAILABILITY_RESPONSE,
            message->GetMessageType());
  EXPECT_EQ(CreateTetherAvailabilityResponse().SerializeAsString(),
            message->GetProto()->SerializeAsString());
}
325 
TEST_F(MessageTransferOperationTest, TestTimesOutBeforeAuthentication) {
  ConstructOperation(multidevice::RemoteDeviceRefList{test_devices_[0]});
  InitializeOperation();

  // Fire the connection timer created during initialization. Check it first so
  // a missing timer fails the test instead of crashing it.
  base::MockOneShotTimer* connection_timer =
      GetTimerForDevice(test_devices_[0]);
  ASSERT_TRUE(connection_timer);
  connection_timer->Fire();

  EXPECT_TRUE(operation_->has_operation_finished());
}
333 
TEST_F(MessageTransferOperationTest, TestAuthenticatesButThenTimesOut) {
  ConstructOperation(multidevice::RemoteDeviceRefList{test_devices_[0]});
  InitializeOperation();

  CreateAuthenticatedChannelForDevice(test_devices_[0]);
  EXPECT_TRUE(operation_->HasDeviceAuthenticated(test_devices_[0]));
  VerifyDefaultTimerCreatedForDevice(test_devices_[0]);

  // Fire the message timer created after authentication. Check it first so a
  // missing timer fails the test instead of crashing it.
  base::MockOneShotTimer* message_timer = GetTimerForDevice(test_devices_[0]);
  ASSERT_TRUE(message_timer);
  message_timer->Fire();

  EXPECT_TRUE(operation_->has_operation_finished());
}
346 
TEST_F(MessageTransferOperationTest, TestRepeatedInputDevice) {
  // Construct with two copies of the same device.
  ConstructOperation(
      multidevice::RemoteDeviceRefList{test_devices_[0], test_devices_[0]});
  InitializeOperation();

  CreateAuthenticatedChannelForDevice(test_devices_[0]);
  EXPECT_TRUE(operation_->HasDeviceAuthenticated(test_devices_[0]));
  VerifyDefaultTimerCreatedForDevice(test_devices_[0]);

  remote_device_to_fake_client_channel_map_[test_devices_[0]]
      ->NotifyMessageReceived(
          MessageWrapper(CreateTetherAvailabilityResponse()).ToRawMessage());

  // Should still have received only one message even though the device was
  // repeated twice in the constructor. ASSERT (not EXPECT) because the lines
  // below index received_messages[0].
  const std::vector<std::shared_ptr<MessageWrapper>> received_messages =
      operation_->GetReceivedMessages(test_devices_[0]);
  ASSERT_EQ(1u, received_messages.size());
  const std::shared_ptr<MessageWrapper>& message = received_messages[0];
  EXPECT_EQ(MessageType::TETHER_AVAILABILITY_RESPONSE,
            message->GetMessageType());
  EXPECT_EQ(CreateTetherAvailabilityResponse().SerializeAsString(),
            message->GetProto()->SerializeAsString());
}
371 
TEST_F(MessageTransferOperationTest, MultipleDevices) {
  ConstructOperation(test_devices_);
  InitializeOperation();

  // Discard the connection timers from initialization so the checks below only
  // observe timers created after each device's connection outcome.
  for (const auto& remote_device : test_devices_)
    test_timer_factory_->ClearTimerForDeviceId(remote_device.GetDeviceId());

  // Authenticates |remote_device|'s channel and expects the default message
  // timer to be started for it.
  auto authenticate = [this](multidevice::RemoteDeviceRef remote_device) {
    CreateAuthenticatedChannelForDevice(remote_device);
    EXPECT_TRUE(operation_->HasDeviceAuthenticated(remote_device));
    VerifyDefaultTimerCreatedForDevice(remote_device);
  };

  // Fails |remote_device|'s connection attempt and expects the device to stay
  // unauthenticated with no new timer.
  auto fail_to_connect = [this](multidevice::RemoteDeviceRef remote_device) {
    test_timer_factory_->set_device_id_for_next_timer(
        remote_device.GetDeviceId());
    remote_device_to_fake_connection_attempt_map_[remote_device]
        ->NotifyConnectionAttemptFailure(
            secure_channel::mojom::ConnectionAttemptFailureReason::
                GATT_CONNECTION_ERROR);
    EXPECT_FALSE(operation_->HasDeviceAuthenticated(remote_device));
    EXPECT_FALSE(GetTimerForDevice(remote_device));
  };

  authenticate(test_devices_[0]);
  fail_to_connect(test_devices_[1]);
  authenticate(test_devices_[2]);
  fail_to_connect(test_devices_[3]);
}
409 
410 }  // namespace tether
411 
412 }  // namespace chromeos
413