1 // Copyright 2017 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "services/network/udp_socket_test_util.h"
6 
7 #include <utility>
8 
9 #include "base/run_loop.h"
10 #include "base/test/bind_test_util.h"
11 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
12 #include "testing/gtest/include/gtest/gtest.h"
13 
14 namespace network {
15 
16 namespace test {
17 
UDPSocketTestHelper(mojo::Remote<mojom::UDPSocket> * socket)18 UDPSocketTestHelper::UDPSocketTestHelper(mojo::Remote<mojom::UDPSocket>* socket)
19     : socket_(socket) {}
20 
~UDPSocketTestHelper()21 UDPSocketTestHelper::~UDPSocketTestHelper() {}
22 
ConnectSync(const net::IPEndPoint & remote_addr,mojom::UDPSocketOptionsPtr options,net::IPEndPoint * local_addr_out)23 int UDPSocketTestHelper::ConnectSync(const net::IPEndPoint& remote_addr,
24                                      mojom::UDPSocketOptionsPtr options,
25                                      net::IPEndPoint* local_addr_out) {
26   base::RunLoop run_loop;
27   int net_error = net::ERR_FAILED;
28   socket_->get()->Connect(
29       remote_addr, std::move(options),
30       base::BindLambdaForTesting(
31           [&](int result, const base::Optional<net::IPEndPoint>& local_addr) {
32             net_error = result;
33             if (local_addr) {
34               *local_addr_out = local_addr.value();
35             }
36             run_loop.Quit();
37           }));
38   run_loop.Run();
39   return net_error;
40 }
41 
BindSync(const net::IPEndPoint & local_addr,mojom::UDPSocketOptionsPtr options,net::IPEndPoint * local_addr_out)42 int UDPSocketTestHelper::BindSync(const net::IPEndPoint& local_addr,
43                                   mojom::UDPSocketOptionsPtr options,
44                                   net::IPEndPoint* local_addr_out) {
45   base::RunLoop run_loop;
46   int net_error = net::ERR_FAILED;
47   socket_->get()->Bind(
48       local_addr, std::move(options),
49       base::BindLambdaForTesting(
50           [&](int result, const base::Optional<net::IPEndPoint>& local_addr) {
51             net_error = result;
52             if (local_addr) {
53               *local_addr_out = local_addr.value();
54             }
55             run_loop.Quit();
56           }));
57   run_loop.Run();
58   return net_error;
59 }
60 
SendToSync(const net::IPEndPoint & remote_addr,const std::vector<uint8_t> & input)61 int UDPSocketTestHelper::SendToSync(const net::IPEndPoint& remote_addr,
62                                     const std::vector<uint8_t>& input) {
63   base::RunLoop run_loop;
64   int net_error = net::ERR_FAILED;
65   socket_->get()->SendTo(
66       remote_addr, input,
67       net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
68       base::BindLambdaForTesting([&](int result) {
69         net_error = result;
70         run_loop.Quit();
71       }));
72   run_loop.Run();
73   return net_error;
74 }
75 
SendSync(const std::vector<uint8_t> & input)76 int UDPSocketTestHelper::SendSync(const std::vector<uint8_t>& input) {
77   base::RunLoop run_loop;
78   int net_error = net::ERR_FAILED;
79   socket_->get()->Send(
80       input,
81       net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
82       base::BindLambdaForTesting([&](int result) {
83         net_error = result;
84         run_loop.Quit();
85       }));
86   run_loop.Run();
87   return net_error;
88 }
89 
SetBroadcastSync(bool broadcast)90 int UDPSocketTestHelper::SetBroadcastSync(bool broadcast) {
91   base::RunLoop run_loop;
92   int net_error = net::ERR_FAILED;
93   socket_->get()->SetBroadcast(broadcast,
94                                base::BindLambdaForTesting([&](int result) {
95                                  net_error = result;
96                                  run_loop.Quit();
97                                }));
98   run_loop.Run();
99   return net_error;
100 }
101 
SetSendBufferSizeSync(int send_buffer_size)102 int UDPSocketTestHelper::SetSendBufferSizeSync(int send_buffer_size) {
103   base::RunLoop run_loop;
104   int net_error = net::ERR_FAILED;
105   socket_->get()->SetSendBufferSize(send_buffer_size,
106                                     base::BindLambdaForTesting([&](int result) {
107                                       net_error = result;
108                                       run_loop.Quit();
109                                     }));
110   run_loop.Run();
111   return net_error;
112 }
113 
SetReceiveBufferSizeSync(int receive_buffer_size)114 int UDPSocketTestHelper::SetReceiveBufferSizeSync(int receive_buffer_size) {
115   base::RunLoop run_loop;
116   int net_error = net::ERR_FAILED;
117   socket_->get()->SetReceiveBufferSize(
118       receive_buffer_size, base::BindLambdaForTesting([&](int result) {
119         net_error = result;
120         run_loop.Quit();
121       }));
122   run_loop.Run();
123   return net_error;
124 }
125 
JoinGroupSync(const net::IPAddress & group_address)126 int UDPSocketTestHelper::JoinGroupSync(const net::IPAddress& group_address) {
127   base::RunLoop run_loop;
128   int net_error = net::ERR_FAILED;
129   socket_->get()->JoinGroup(group_address,
130                             base::BindLambdaForTesting([&](int result) {
131                               net_error = result;
132                               run_loop.Quit();
133                             }));
134   run_loop.Run();
135   return net_error;
136 }
137 
LeaveGroupSync(const net::IPAddress & group_address)138 int UDPSocketTestHelper::LeaveGroupSync(const net::IPAddress& group_address) {
139   base::RunLoop run_loop;
140   int net_error = net::ERR_FAILED;
141   socket_->get()->LeaveGroup(group_address,
142                              base::BindLambdaForTesting([&](int result) {
143                                net_error = result;
144                                run_loop.Quit();
145                              }));
146   run_loop.Run();
147   return net_error;
148 }
149 
ReceivedResult(int net_error_arg,const base::Optional<net::IPEndPoint> & src_addr_arg,base::Optional<std::vector<uint8_t>> data_arg)150 UDPSocketListenerImpl::ReceivedResult::ReceivedResult(
151     int net_error_arg,
152     const base::Optional<net::IPEndPoint>& src_addr_arg,
153     base::Optional<std::vector<uint8_t>> data_arg)
154     : net_error(net_error_arg),
155       src_addr(src_addr_arg),
156       data(std::move(data_arg)) {}
157 
158 UDPSocketListenerImpl::ReceivedResult::ReceivedResult(
159     const ReceivedResult& other) = default;
160 
~ReceivedResult()161 UDPSocketListenerImpl::ReceivedResult::~ReceivedResult() {}
162 
UDPSocketListenerImpl()163 UDPSocketListenerImpl::UDPSocketListenerImpl()
164     : run_loop_(std::make_unique<base::RunLoop>()),
165       expected_receive_count_(0) {}
166 
~UDPSocketListenerImpl()167 UDPSocketListenerImpl::~UDPSocketListenerImpl() {}
168 
WaitForReceivedResults(size_t count)169 void UDPSocketListenerImpl::WaitForReceivedResults(size_t count) {
170   DCHECK_LE(results_.size(), count);
171   DCHECK_EQ(0u, expected_receive_count_);
172 
173   if (results_.size() == count)
174     return;
175 
176   expected_receive_count_ = count;
177   run_loop_->Run();
178   run_loop_ = std::make_unique<base::RunLoop>();
179 }
180 
OnReceived(int32_t result,const base::Optional<net::IPEndPoint> & src_addr,base::Optional<base::span<const uint8_t>> data)181 void UDPSocketListenerImpl::OnReceived(
182     int32_t result,
183     const base::Optional<net::IPEndPoint>& src_addr,
184     base::Optional<base::span<const uint8_t>> data) {
185   // OnReceive() API contracts specifies that this method will not be called
186   // with a |result| that is > 0.
187   DCHECK_GE(0, result);
188   DCHECK(result < 0 || data);
189 
190   results_.emplace_back(result, src_addr,
191                         data ? base::make_optional(std::vector<uint8_t>(
192                                    data.value().begin(), data.value().end()))
193                              : base::nullopt);
194   if (results_.size() == expected_receive_count_) {
195     expected_receive_count_ = 0;
196     run_loop_->Quit();
197   }
198 }
199 
200 }  // namespace test
201 
202 }  // namespace network
203