1 /*
2  *
3  * Copyright 2016 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 #include <grpc/support/port_platform.h>
19 
20 #include "src/core/lib/iomgr/port.h"
21 #if GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER)
22 
23 #include <string.h>
24 
25 #include <ares.h>
26 
27 #include "absl/strings/str_format.h"
28 
29 #include <grpc/support/alloc.h>
30 #include <grpc/support/log.h>
31 #include <grpc/support/log_windows.h>
32 #include <grpc/support/string_util.h>
33 #include <grpc/support/time.h>
34 
35 #include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h"
36 #include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h"
37 #include "src/core/lib/address_utils/sockaddr_utils.h"
38 #include "src/core/lib/gpr/string.h"
39 #include "src/core/lib/gprpp/memory.h"
40 #include "src/core/lib/iomgr/iocp_windows.h"
41 #include "src/core/lib/iomgr/sockaddr_windows.h"
42 #include "src/core/lib/iomgr/socket_windows.h"
43 #include "src/core/lib/iomgr/tcp_windows.h"
44 #include "src/core/lib/iomgr/work_serializer.h"
45 #include "src/core/lib/slice/slice_internal.h"
46 
47 /* TODO(apolcyn): remove this hack after fixing upstream.
48  * Our grpc/c-ares code on Windows uses the ares_set_socket_functions API,
49  * which uses "struct iovec" type, which on Windows is defined inside of
50  * a c-ares header that is not public.
51  * See https://github.com/c-ares/c-ares/issues/206. */
// Scatter/gather buffer descriptor handed to the virtual "sendv" socket
// function; declared here because the c-ares definition is not public on
// Windows.
struct iovec {
  void* iov_base;  // First byte of the buffer.
  size_t iov_len;  // Number of bytes available at iov_base.
};
56 
57 namespace grpc_core {
58 
59 /* c-ares reads and takes action on the error codes of the
60  * "virtual socket operations" in this file, via the WSAGetLastError
61  * APIs. If code in this file wants to set a specific WSA error that
62  * c-ares should read, it must do so by calling SetWSAError() on the
63  * WSAErrorContext instance passed to it. A WSAErrorContext must only be
64  * instantiated at the top of the virtual socket function callstack. */
65 class WSAErrorContext {
66  public:
WSAErrorContext()67   explicit WSAErrorContext(){};
68 
~WSAErrorContext()69   ~WSAErrorContext() {
70     if (error_ != 0) {
71       WSASetLastError(error_);
72     }
73   }
74 
75   /* Disallow copy and assignment operators */
76   WSAErrorContext(const WSAErrorContext&) = delete;
77   WSAErrorContext& operator=(const WSAErrorContext&) = delete;
78 
SetWSAError(int error)79   void SetWSAError(int error) { error_ = error; }
80 
81  private:
82   int error_ = 0;
83 };
84 
/* c-ares creates its own sockets and is meant to read them when readable and
 * write them when writeable. To fit this socket usage model into the grpc
 * windows poller (which gives notifications when attempted reads and writes are
 * actually fulfilled rather than possible), this GrpcPolledFdWindows class
 * takes advantage of the ares_set_socket_functions API and acts as a virtual
 * socket. It holds its own read and write buffers, which exchange data with
 * c-ares and are used with the grpc windows poller. It also manufactures
 * virtual socket error codes when, for example, it needs to tell the c-ares
 * library to wait for an async read. */
94 class GrpcPolledFdWindows {
95  public:
96   enum WriteState {
97     WRITE_IDLE,
98     WRITE_REQUESTED,
99     WRITE_PENDING,
100     WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY,
101   };
102 
GrpcPolledFdWindows(ares_socket_t as,std::shared_ptr<WorkSerializer> work_serializer,int address_family,int socket_type)103   GrpcPolledFdWindows(ares_socket_t as,
104                       std::shared_ptr<WorkSerializer> work_serializer,
105                       int address_family, int socket_type)
106       : work_serializer_(std::move(work_serializer)),
107         read_buf_(grpc_empty_slice()),
108         write_buf_(grpc_empty_slice()),
109         tcp_write_state_(WRITE_IDLE),
110         name_(absl::StrFormat("c-ares socket: %" PRIdPTR, as)),
111         gotten_into_driver_list_(false),
112         address_family_(address_family),
113         socket_type_(socket_type) {
114     // Closure Initialization
115     GRPC_CLOSURE_INIT(&outer_read_closure_,
116                       &GrpcPolledFdWindows::OnIocpReadable, this,
117                       grpc_schedule_on_exec_ctx);
118     GRPC_CLOSURE_INIT(&outer_write_closure_,
119                       &GrpcPolledFdWindows::OnIocpWriteable, this,
120                       grpc_schedule_on_exec_ctx);
121     GRPC_CLOSURE_INIT(&on_tcp_connect_locked_,
122                       &GrpcPolledFdWindows::OnTcpConnect, this,
123                       grpc_schedule_on_exec_ctx);
124     winsocket_ = grpc_winsocket_create(as, name_.c_str());
125   }
126 
~GrpcPolledFdWindows()127   ~GrpcPolledFdWindows() {
128     grpc_slice_unref_internal(read_buf_);
129     grpc_slice_unref_internal(write_buf_);
130     GPR_ASSERT(read_closure_ == nullptr);
131     GPR_ASSERT(write_closure_ == nullptr);
132     grpc_winsocket_destroy(winsocket_);
133   }
134 
ScheduleAndNullReadClosure(grpc_error_handle error)135   void ScheduleAndNullReadClosure(grpc_error_handle error) {
136     grpc_core::ExecCtx::Run(DEBUG_LOCATION, read_closure_, error);
137     read_closure_ = nullptr;
138   }
139 
ScheduleAndNullWriteClosure(grpc_error_handle error)140   void ScheduleAndNullWriteClosure(grpc_error_handle error) {
141     grpc_core::ExecCtx::Run(DEBUG_LOCATION, write_closure_, error);
142     write_closure_ = nullptr;
143   }
144 
RegisterForOnReadableLocked(grpc_closure * read_closure)145   void RegisterForOnReadableLocked(grpc_closure* read_closure) {
146     GPR_ASSERT(read_closure_ == nullptr);
147     read_closure_ = read_closure;
148     GPR_ASSERT(GRPC_SLICE_LENGTH(read_buf_) == 0);
149     grpc_slice_unref_internal(read_buf_);
150     GPR_ASSERT(!read_buf_has_data_);
151     read_buf_ = GRPC_SLICE_MALLOC(4192);
152     if (connect_done_) {
153       work_serializer_->Run([this]() { ContinueRegisterForOnReadableLocked(); },
154                             DEBUG_LOCATION);
155     } else {
156       GPR_ASSERT(pending_continue_register_for_on_readable_locked_ == false);
157       pending_continue_register_for_on_readable_locked_ = true;
158     }
159   }
160 
ContinueRegisterForOnReadableLocked()161   void ContinueRegisterForOnReadableLocked() {
162     GRPC_CARES_TRACE_LOG(
163         "fd:|%s| InnerContinueRegisterForOnReadableLocked "
164         "wsa_connect_error_:%d",
165         GetName(), wsa_connect_error_);
166     GPR_ASSERT(connect_done_);
167     if (wsa_connect_error_ != 0) {
168       ScheduleAndNullReadClosure(GRPC_WSA_ERROR(wsa_connect_error_, "connect"));
169       return;
170     }
171     WSABUF buffer;
172     buffer.buf = (char*)GRPC_SLICE_START_PTR(read_buf_);
173     buffer.len = GRPC_SLICE_LENGTH(read_buf_);
174     memset(&winsocket_->read_info.overlapped, 0, sizeof(OVERLAPPED));
175     recv_from_source_addr_len_ = sizeof(recv_from_source_addr_);
176     DWORD flags = 0;
177     if (WSARecvFrom(grpc_winsocket_wrapped_socket(winsocket_), &buffer, 1,
178                     nullptr, &flags, (sockaddr*)recv_from_source_addr_,
179                     &recv_from_source_addr_len_,
180                     &winsocket_->read_info.overlapped, nullptr)) {
181       int wsa_last_error = WSAGetLastError();
182       char* msg = gpr_format_message(wsa_last_error);
183       GRPC_CARES_TRACE_LOG(
184           "fd:|%s| RegisterForOnReadableLocked WSARecvFrom error code:|%d| "
185           "msg:|%s|",
186           GetName(), wsa_last_error, msg);
187       gpr_free(msg);
188       if (wsa_last_error != WSA_IO_PENDING) {
189         ScheduleAndNullReadClosure(
190             GRPC_WSA_ERROR(wsa_last_error, "WSARecvFrom"));
191         return;
192       }
193     }
194     grpc_socket_notify_on_read(winsocket_, &outer_read_closure_);
195   }
196 
RegisterForOnWriteableLocked(grpc_closure * write_closure)197   void RegisterForOnWriteableLocked(grpc_closure* write_closure) {
198     if (socket_type_ == SOCK_DGRAM) {
199       GRPC_CARES_TRACE_LOG("fd:|%s| RegisterForOnWriteableLocked called",
200                            GetName());
201     } else {
202       GPR_ASSERT(socket_type_ == SOCK_STREAM);
203       GRPC_CARES_TRACE_LOG(
204           "fd:|%s| RegisterForOnWriteableLocked called tcp_write_state_: %d",
205           GetName(), tcp_write_state_);
206     }
207     GPR_ASSERT(write_closure_ == nullptr);
208     write_closure_ = write_closure;
209     if (connect_done_) {
210       work_serializer_->Run(
211           [this]() { ContinueRegisterForOnWriteableLocked(); }, DEBUG_LOCATION);
212     } else {
213       GPR_ASSERT(pending_continue_register_for_on_writeable_locked_ == false);
214       pending_continue_register_for_on_writeable_locked_ = true;
215     }
216   }
217 
ContinueRegisterForOnWriteableLocked()218   void ContinueRegisterForOnWriteableLocked() {
219     GRPC_CARES_TRACE_LOG(
220         "fd:|%s| InnerContinueRegisterForOnWriteableLocked "
221         "wsa_connect_error_:%d",
222         GetName(), wsa_connect_error_);
223     GPR_ASSERT(connect_done_);
224     if (wsa_connect_error_ != 0) {
225       ScheduleAndNullWriteClosure(
226           GRPC_WSA_ERROR(wsa_connect_error_, "connect"));
227       return;
228     }
229     if (socket_type_ == SOCK_DGRAM) {
230       ScheduleAndNullWriteClosure(GRPC_ERROR_NONE);
231     } else {
232       GPR_ASSERT(socket_type_ == SOCK_STREAM);
233       int wsa_error_code = 0;
234       switch (tcp_write_state_) {
235         case WRITE_IDLE:
236           ScheduleAndNullWriteClosure(GRPC_ERROR_NONE);
237           break;
238         case WRITE_REQUESTED:
239           tcp_write_state_ = WRITE_PENDING;
240           if (SendWriteBuf(nullptr, &winsocket_->write_info.overlapped,
241                            &wsa_error_code) != 0) {
242             ScheduleAndNullWriteClosure(
243                 GRPC_WSA_ERROR(wsa_error_code, "WSASend (overlapped)"));
244           } else {
245             grpc_socket_notify_on_write(winsocket_, &outer_write_closure_);
246           }
247           break;
248         case WRITE_PENDING:
249         case WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY:
250           abort();
251       }
252     }
253   }
254 
IsFdStillReadableLocked()255   bool IsFdStillReadableLocked() { return read_buf_has_data_; }
256 
ShutdownLocked(grpc_error_handle error)257   void ShutdownLocked(grpc_error_handle error) {
258     grpc_winsocket_shutdown(winsocket_);
259   }
260 
GetWrappedAresSocketLocked()261   ares_socket_t GetWrappedAresSocketLocked() {
262     return grpc_winsocket_wrapped_socket(winsocket_);
263   }
264 
GetName()265   const char* GetName() { return name_.c_str(); }
266 
RecvFrom(WSAErrorContext * wsa_error_ctx,void * data,ares_socket_t data_len,int flags,struct sockaddr * from,ares_socklen_t * from_len)267   ares_ssize_t RecvFrom(WSAErrorContext* wsa_error_ctx, void* data,
268                         ares_socket_t data_len, int flags,
269                         struct sockaddr* from, ares_socklen_t* from_len) {
270     GRPC_CARES_TRACE_LOG(
271         "fd:|%s| RecvFrom called read_buf_has_data:%d Current read buf "
272         "length:|%d|",
273         GetName(), read_buf_has_data_, GRPC_SLICE_LENGTH(read_buf_));
274     if (!read_buf_has_data_) {
275       wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
276       return -1;
277     }
278     ares_ssize_t bytes_read = 0;
279     for (size_t i = 0; i < GRPC_SLICE_LENGTH(read_buf_) && i < data_len; i++) {
280       ((char*)data)[i] = GRPC_SLICE_START_PTR(read_buf_)[i];
281       bytes_read++;
282     }
283     read_buf_ = grpc_slice_sub_no_ref(read_buf_, bytes_read,
284                                       GRPC_SLICE_LENGTH(read_buf_));
285     if (GRPC_SLICE_LENGTH(read_buf_) == 0) {
286       read_buf_has_data_ = false;
287     }
288     /* c-ares overloads this recv_from virtual socket function to receive
289      * data on both UDP and TCP sockets, and from is nullptr for TCP. */
290     if (from != nullptr) {
291       GPR_ASSERT(*from_len <= recv_from_source_addr_len_);
292       memcpy(from, &recv_from_source_addr_, recv_from_source_addr_len_);
293       *from_len = recv_from_source_addr_len_;
294     }
295     return bytes_read;
296   }
297 
FlattenIovec(const struct iovec * iov,int iov_count)298   grpc_slice FlattenIovec(const struct iovec* iov, int iov_count) {
299     int total = 0;
300     for (int i = 0; i < iov_count; i++) {
301       total += iov[i].iov_len;
302     }
303     grpc_slice out = GRPC_SLICE_MALLOC(total);
304     size_t cur = 0;
305     for (int i = 0; i < iov_count; i++) {
306       for (int k = 0; k < iov[i].iov_len; k++) {
307         GRPC_SLICE_START_PTR(out)[cur++] = ((char*)iov[i].iov_base)[k];
308       }
309     }
310     return out;
311   }
312 
SendWriteBuf(LPDWORD bytes_sent_ptr,LPWSAOVERLAPPED overlapped,int * wsa_error_code)313   int SendWriteBuf(LPDWORD bytes_sent_ptr, LPWSAOVERLAPPED overlapped,
314                    int* wsa_error_code) {
315     WSABUF buf;
316     buf.len = GRPC_SLICE_LENGTH(write_buf_);
317     buf.buf = (char*)GRPC_SLICE_START_PTR(write_buf_);
318     DWORD flags = 0;
319     int out = WSASend(grpc_winsocket_wrapped_socket(winsocket_), &buf, 1,
320                       bytes_sent_ptr, flags, overlapped, nullptr);
321     *wsa_error_code = WSAGetLastError();
322     GRPC_CARES_TRACE_LOG(
323         "fd:|%s| SendWriteBuf WSASend buf.len:%d *bytes_sent_ptr:%d "
324         "overlapped:%p "
325         "return:%d *wsa_error_code:%d",
326         GetName(), buf.len, bytes_sent_ptr != nullptr ? *bytes_sent_ptr : 0,
327         overlapped, out, *wsa_error_code);
328     return out;
329   }
330 
SendV(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)331   ares_ssize_t SendV(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
332                      int iov_count) {
333     GRPC_CARES_TRACE_LOG(
334         "fd:|%s| SendV called connect_done_:%d wsa_connect_error_:%d",
335         GetName(), connect_done_, wsa_connect_error_);
336     if (!connect_done_) {
337       wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
338       return -1;
339     }
340     if (wsa_connect_error_ != 0) {
341       wsa_error_ctx->SetWSAError(wsa_connect_error_);
342       return -1;
343     }
344     switch (socket_type_) {
345       case SOCK_DGRAM:
346         return SendVUDP(wsa_error_ctx, iov, iov_count);
347       case SOCK_STREAM:
348         return SendVTCP(wsa_error_ctx, iov, iov_count);
349       default:
350         abort();
351     }
352   }
353 
SendVUDP(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)354   ares_ssize_t SendVUDP(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
355                         int iov_count) {
356     // c-ares doesn't handle retryable errors on writes of UDP sockets.
357     // Therefore, the sendv handler for UDP sockets must only attempt
358     // to write everything inline.
359     GRPC_CARES_TRACE_LOG("fd:|%s| SendVUDP called", GetName());
360     GPR_ASSERT(GRPC_SLICE_LENGTH(write_buf_) == 0);
361     grpc_slice_unref_internal(write_buf_);
362     write_buf_ = FlattenIovec(iov, iov_count);
363     DWORD bytes_sent = 0;
364     int wsa_error_code = 0;
365     if (SendWriteBuf(&bytes_sent, nullptr, &wsa_error_code) != 0) {
366       grpc_slice_unref_internal(write_buf_);
367       write_buf_ = grpc_empty_slice();
368       wsa_error_ctx->SetWSAError(wsa_error_code);
369       char* msg = gpr_format_message(wsa_error_code);
370       GRPC_CARES_TRACE_LOG(
371           "fd:|%s| SendVUDP SendWriteBuf error code:%d msg:|%s|", GetName(),
372           wsa_error_code, msg);
373       gpr_free(msg);
374       return -1;
375     }
376     write_buf_ = grpc_slice_sub_no_ref(write_buf_, bytes_sent,
377                                        GRPC_SLICE_LENGTH(write_buf_));
378     return bytes_sent;
379   }
380 
SendVTCP(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)381   ares_ssize_t SendVTCP(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
382                         int iov_count) {
383     // The "sendv" handler on TCP sockets buffers up write
384     // requests and returns an artificial WSAEWOULDBLOCK. Writing that buffer
385     // out in the background, and making further send progress in general, will
386     // happen as long as c-ares continues to show interest in writeability on
387     // this fd.
388     GRPC_CARES_TRACE_LOG("fd:|%s| SendVTCP called tcp_write_state_:%d",
389                          GetName(), tcp_write_state_);
390     switch (tcp_write_state_) {
391       case WRITE_IDLE:
392         tcp_write_state_ = WRITE_REQUESTED;
393         GPR_ASSERT(GRPC_SLICE_LENGTH(write_buf_) == 0);
394         grpc_slice_unref_internal(write_buf_);
395         write_buf_ = FlattenIovec(iov, iov_count);
396         wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
397         return -1;
398       case WRITE_REQUESTED:
399       case WRITE_PENDING:
400         wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
401         return -1;
402       case WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY:
403         // c-ares is retrying a send on data that we previously returned
404         // WSAEWOULDBLOCK for, but then subsequently wrote out in the
405         // background. Right now, we assume that c-ares is retrying the same
406         // send again. If c-ares still needs to send even more data, we'll get
407         // to it eventually.
408         grpc_slice currently_attempted = FlattenIovec(iov, iov_count);
409         GPR_ASSERT(GRPC_SLICE_LENGTH(currently_attempted) >=
410                    GRPC_SLICE_LENGTH(write_buf_));
411         ares_ssize_t total_sent = 0;
412         for (size_t i = 0; i < GRPC_SLICE_LENGTH(write_buf_); i++) {
413           GPR_ASSERT(GRPC_SLICE_START_PTR(currently_attempted)[i] ==
414                      GRPC_SLICE_START_PTR(write_buf_)[i]);
415           total_sent++;
416         }
417         grpc_slice_unref_internal(currently_attempted);
418         tcp_write_state_ = WRITE_IDLE;
419         return total_sent;
420     }
421     abort();
422   }
423 
OnTcpConnect(void * arg,grpc_error_handle error)424   static void OnTcpConnect(void* arg, grpc_error_handle error) {
425     GrpcPolledFdWindows* grpc_polled_fd =
426         static_cast<GrpcPolledFdWindows*>(arg);
427     (void)GRPC_ERROR_REF(error);  // ref owned by lambda
428     grpc_polled_fd->work_serializer_->Run(
429         [grpc_polled_fd, error]() {
430           grpc_polled_fd->OnTcpConnectLocked(error);
431         },
432         DEBUG_LOCATION);
433   }
434 
OnTcpConnectLocked(grpc_error_handle error)435   void OnTcpConnectLocked(grpc_error_handle error) {
436     GRPC_CARES_TRACE_LOG(
437         "fd:%s InnerOnTcpConnectLocked error:|%s| "
438         "pending_register_for_readable:%d"
439         " pending_register_for_writeable:%d",
440         GetName(), grpc_error_std_string(error).c_str(),
441         pending_continue_register_for_on_readable_locked_,
442         pending_continue_register_for_on_writeable_locked_);
443     GPR_ASSERT(!connect_done_);
444     connect_done_ = true;
445     GPR_ASSERT(wsa_connect_error_ == 0);
446     if (error == GRPC_ERROR_NONE) {
447       DWORD transferred_bytes = 0;
448       DWORD flags;
449       BOOL wsa_success =
450           WSAGetOverlappedResult(grpc_winsocket_wrapped_socket(winsocket_),
451                                  &winsocket_->write_info.overlapped,
452                                  &transferred_bytes, FALSE, &flags);
453       GPR_ASSERT(transferred_bytes == 0);
454       if (!wsa_success) {
455         wsa_connect_error_ = WSAGetLastError();
456         char* msg = gpr_format_message(wsa_connect_error_);
457         GRPC_CARES_TRACE_LOG(
458             "fd:%s InnerOnTcpConnectLocked WSA overlapped result code:%d "
459             "msg:|%s|",
460             GetName(), wsa_connect_error_, msg);
461         gpr_free(msg);
462       }
463     } else {
464       // Spoof up an error code that will cause any future c-ares operations on
465       // this fd to abort.
466       wsa_connect_error_ = WSA_OPERATION_ABORTED;
467     }
468     if (pending_continue_register_for_on_readable_locked_) {
469       work_serializer_->Run([this]() { ContinueRegisterForOnReadableLocked(); },
470                             DEBUG_LOCATION);
471     }
472     if (pending_continue_register_for_on_writeable_locked_) {
473       work_serializer_->Run(
474           [this]() { ContinueRegisterForOnWriteableLocked(); }, DEBUG_LOCATION);
475     }
476     GRPC_ERROR_UNREF(error);
477   }
478 
Connect(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)479   int Connect(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
480               ares_socklen_t target_len) {
481     switch (socket_type_) {
482       case SOCK_DGRAM:
483         return ConnectUDP(wsa_error_ctx, target, target_len);
484       case SOCK_STREAM:
485         return ConnectTCP(wsa_error_ctx, target, target_len);
486       default:
487         abort();
488     }
489   }
490 
ConnectUDP(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)491   int ConnectUDP(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
492                  ares_socklen_t target_len) {
493     GRPC_CARES_TRACE_LOG("fd:%s ConnectUDP", GetName());
494     GPR_ASSERT(!connect_done_);
495     GPR_ASSERT(wsa_connect_error_ == 0);
496     SOCKET s = grpc_winsocket_wrapped_socket(winsocket_);
497     int out =
498         WSAConnect(s, target, target_len, nullptr, nullptr, nullptr, nullptr);
499     wsa_connect_error_ = WSAGetLastError();
500     wsa_error_ctx->SetWSAError(wsa_connect_error_);
501     connect_done_ = true;
502     char* msg = gpr_format_message(wsa_connect_error_);
503     GRPC_CARES_TRACE_LOG("fd:%s WSAConnect error code:|%d| msg:|%s|", GetName(),
504                          wsa_connect_error_, msg);
505     gpr_free(msg);
506     // c-ares expects a posix-style connect API
507     return out == 0 ? 0 : -1;
508   }
509 
ConnectTCP(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)510   int ConnectTCP(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
511                  ares_socklen_t target_len) {
512     GRPC_CARES_TRACE_LOG("fd:%s ConnectTCP", GetName());
513     LPFN_CONNECTEX ConnectEx;
514     GUID guid = WSAID_CONNECTEX;
515     DWORD ioctl_num_bytes;
516     SOCKET s = grpc_winsocket_wrapped_socket(winsocket_);
517     if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
518                  &ConnectEx, sizeof(ConnectEx), &ioctl_num_bytes, nullptr,
519                  nullptr) != 0) {
520       int wsa_last_error = WSAGetLastError();
521       wsa_error_ctx->SetWSAError(wsa_last_error);
522       char* msg = gpr_format_message(wsa_last_error);
523       GRPC_CARES_TRACE_LOG(
524           "fd:%s WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER) error code:%d "
525           "msg:|%s|",
526           GetName(), wsa_last_error, msg);
527       gpr_free(msg);
528       connect_done_ = true;
529       wsa_connect_error_ = wsa_last_error;
530       return -1;
531     }
532     grpc_resolved_address wildcard4_addr;
533     grpc_resolved_address wildcard6_addr;
534     grpc_sockaddr_make_wildcards(0, &wildcard4_addr, &wildcard6_addr);
535     grpc_resolved_address* local_address = nullptr;
536     if (address_family_ == AF_INET) {
537       local_address = &wildcard4_addr;
538     } else {
539       local_address = &wildcard6_addr;
540     }
541     if (bind(s, (struct sockaddr*)local_address->addr,
542              (int)local_address->len) != 0) {
543       int wsa_last_error = WSAGetLastError();
544       wsa_error_ctx->SetWSAError(wsa_last_error);
545       char* msg = gpr_format_message(wsa_last_error);
546       GRPC_CARES_TRACE_LOG("fd:%s bind error code:%d msg:|%s|", GetName(),
547                            wsa_last_error, msg);
548       gpr_free(msg);
549       connect_done_ = true;
550       wsa_connect_error_ = wsa_last_error;
551       return -1;
552     }
553     int out = 0;
554     if (ConnectEx(s, target, target_len, nullptr, 0, nullptr,
555                   &winsocket_->write_info.overlapped) == 0) {
556       out = -1;
557       int wsa_last_error = WSAGetLastError();
558       wsa_error_ctx->SetWSAError(wsa_last_error);
559       char* msg = gpr_format_message(wsa_last_error);
560       GRPC_CARES_TRACE_LOG("fd:%s ConnectEx error code:%d msg:|%s|", GetName(),
561                            wsa_last_error, msg);
562       gpr_free(msg);
563       if (wsa_last_error == WSA_IO_PENDING) {
564         // c-ares only understands WSAEINPROGRESS and EWOULDBLOCK error codes on
565         // connect, but an async connect on IOCP socket will give
566         // WSA_IO_PENDING, so we need to convert.
567         wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
568       } else {
569         // By returning a non-retryable error to c-ares at this point,
570         // we're aborting the possibility of any future operations on this fd.
571         connect_done_ = true;
572         wsa_connect_error_ = wsa_last_error;
573         return -1;
574       }
575     }
576     grpc_socket_notify_on_write(winsocket_, &on_tcp_connect_locked_);
577     return out;
578   }
579 
OnIocpReadable(void * arg,grpc_error_handle error)580   static void OnIocpReadable(void* arg, grpc_error_handle error) {
581     GrpcPolledFdWindows* polled_fd = static_cast<GrpcPolledFdWindows*>(arg);
582     (void)GRPC_ERROR_REF(error);  // ref owned by lambda
583     polled_fd->work_serializer_->Run(
584         [polled_fd, error]() { polled_fd->OnIocpReadableLocked(error); },
585         DEBUG_LOCATION);
586   }
587 
588   // TODO(apolcyn): improve this error handling to be less conversative.
589   // An e.g. ECONNRESET error here should result in errors when
590   // c-ares reads from this socket later, but it shouldn't necessarily cancel
591   // the entire resolution attempt. Doing so will allow the "inject broken
592   // nameserver list" test to pass on Windows.
OnIocpReadableLocked(grpc_error_handle error)593   void OnIocpReadableLocked(grpc_error_handle error) {
594     if (error == GRPC_ERROR_NONE) {
595       if (winsocket_->read_info.wsa_error != 0) {
596         /* WSAEMSGSIZE would be due to receiving more data
597          * than our read buffer's fixed capacity. Assume that
598          * the connection is TCP and read the leftovers
599          * in subsequent c-ares reads. */
600         if (winsocket_->read_info.wsa_error != WSAEMSGSIZE) {
601           error = GRPC_WSA_ERROR(winsocket_->read_info.wsa_error,
602                                  "OnIocpReadableInner");
603           GRPC_CARES_TRACE_LOG(
604               "fd:|%s| OnIocpReadableInner winsocket_->read_info.wsa_error "
605               "code:|%d| msg:|%s|",
606               GetName(), winsocket_->read_info.wsa_error,
607               grpc_error_std_string(error).c_str());
608         }
609       }
610     }
611     if (error == GRPC_ERROR_NONE) {
612       read_buf_ = grpc_slice_sub_no_ref(
613           read_buf_, 0, winsocket_->read_info.bytes_transferred);
614       read_buf_has_data_ = true;
615     } else {
616       grpc_slice_unref_internal(read_buf_);
617       read_buf_ = grpc_empty_slice();
618     }
619     GRPC_CARES_TRACE_LOG(
620         "fd:|%s| OnIocpReadable finishing. read buf length now:|%d|", GetName(),
621         GRPC_SLICE_LENGTH(read_buf_));
622     ScheduleAndNullReadClosure(error);
623   }
624 
OnIocpWriteable(void * arg,grpc_error_handle error)625   static void OnIocpWriteable(void* arg, grpc_error_handle error) {
626     GrpcPolledFdWindows* polled_fd = static_cast<GrpcPolledFdWindows*>(arg);
627     (void)GRPC_ERROR_REF(error);  // error owned by lambda
628     polled_fd->work_serializer_->Run(
629         [polled_fd, error]() { polled_fd->OnIocpWriteableLocked(error); },
630         DEBUG_LOCATION);
631   }
632 
OnIocpWriteableLocked(grpc_error_handle error)633   void OnIocpWriteableLocked(grpc_error_handle error) {
634     GRPC_CARES_TRACE_LOG("OnIocpWriteableInner. fd:|%s|", GetName());
635     GPR_ASSERT(socket_type_ == SOCK_STREAM);
636     if (error == GRPC_ERROR_NONE) {
637       if (winsocket_->write_info.wsa_error != 0) {
638         error = GRPC_WSA_ERROR(winsocket_->write_info.wsa_error,
639                                "OnIocpWriteableInner");
640         GRPC_CARES_TRACE_LOG(
641             "fd:|%s| OnIocpWriteableInner. winsocket_->write_info.wsa_error "
642             "code:|%d| msg:|%s|",
643             GetName(), winsocket_->write_info.wsa_error,
644             grpc_error_std_string(error).c_str());
645       }
646     }
647     GPR_ASSERT(tcp_write_state_ == WRITE_PENDING);
648     if (error == GRPC_ERROR_NONE) {
649       tcp_write_state_ = WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY;
650       write_buf_ = grpc_slice_sub_no_ref(
651           write_buf_, 0, winsocket_->write_info.bytes_transferred);
652       GRPC_CARES_TRACE_LOG("fd:|%s| OnIocpWriteableInner. bytes transferred:%d",
653                            GetName(), winsocket_->write_info.bytes_transferred);
654     } else {
655       grpc_slice_unref_internal(write_buf_);
656       write_buf_ = grpc_empty_slice();
657     }
658     ScheduleAndNullWriteClosure(error);
659   }
660 
gotten_into_driver_list() const661   bool gotten_into_driver_list() const { return gotten_into_driver_list_; }
set_gotten_into_driver_list()662   void set_gotten_into_driver_list() { gotten_into_driver_list_ = true; }
663 
664  private:
665   std::shared_ptr<WorkSerializer> work_serializer_;
666   char recv_from_source_addr_[200];
667   ares_socklen_t recv_from_source_addr_len_;
668   grpc_slice read_buf_;
669   bool read_buf_has_data_ = false;
670   grpc_slice write_buf_;
671   grpc_closure* read_closure_ = nullptr;
672   grpc_closure* write_closure_ = nullptr;
673   grpc_closure outer_read_closure_;
674   grpc_closure outer_write_closure_;
675   grpc_winsocket* winsocket_;
676   // tcp_write_state_ is only used on TCP GrpcPolledFds
677   WriteState tcp_write_state_;
678   std::string name_;
679   bool gotten_into_driver_list_;
680   int address_family_;
681   int socket_type_;
682   grpc_closure on_tcp_connect_locked_;
683   bool connect_done_ = false;
684   int wsa_connect_error_ = 0;
685   // We don't run register_for_{readable,writeable} logic until
686   // a socket is connected. In the interim, we queue readable/writeable
687   // registrations with the following state.
688   bool pending_continue_register_for_on_readable_locked_ = false;
689   bool pending_continue_register_for_on_writeable_locked_ = false;
690 };
691 
692 struct SockToPolledFdEntry {
SockToPolledFdEntrygrpc_core::SockToPolledFdEntry693   SockToPolledFdEntry(SOCKET s, GrpcPolledFdWindows* fd)
694       : socket(s), polled_fd(fd) {}
695   SOCKET socket;
696   GrpcPolledFdWindows* polled_fd;
697   SockToPolledFdEntry* next = nullptr;
698 };
699 
/* A SockToPolledFdMap maps ares_socket_t values (SOCKETs on Windows)
 * to GrpcPolledFdWindows objects, and is used to find the appropriate
 * GrpcPolledFdWindows to handle a virtual socket call when c-ares makes that
 * socket call on the ares_socket_t type. Instances are owned by and one-to-one
 * with a GrpcPolledFdWindows factory and event driver. */
705 class SockToPolledFdMap {
706  public:
SockToPolledFdMap(std::shared_ptr<WorkSerializer> work_serializer)707   explicit SockToPolledFdMap(std::shared_ptr<WorkSerializer> work_serializer)
708       : work_serializer_(std::move(work_serializer)) {}
709 
~SockToPolledFdMap()710   ~SockToPolledFdMap() { GPR_ASSERT(head_ == nullptr); }
711 
AddNewSocket(SOCKET s,GrpcPolledFdWindows * polled_fd)712   void AddNewSocket(SOCKET s, GrpcPolledFdWindows* polled_fd) {
713     SockToPolledFdEntry* new_node = new SockToPolledFdEntry(s, polled_fd);
714     new_node->next = head_;
715     head_ = new_node;
716   }
717 
LookupPolledFd(SOCKET s)718   GrpcPolledFdWindows* LookupPolledFd(SOCKET s) {
719     for (SockToPolledFdEntry* node = head_; node != nullptr;
720          node = node->next) {
721       if (node->socket == s) {
722         GPR_ASSERT(node->polled_fd != nullptr);
723         return node->polled_fd;
724       }
725     }
726     abort();
727   }
728 
RemoveEntry(SOCKET s)729   void RemoveEntry(SOCKET s) {
730     GPR_ASSERT(head_ != nullptr);
731     SockToPolledFdEntry** prev = &head_;
732     for (SockToPolledFdEntry* node = head_; node != nullptr;
733          node = node->next) {
734       if (node->socket == s) {
735         *prev = node->next;
736         delete node;
737         return;
738       }
739       prev = &node->next;
740     }
741     abort();
742   }
743 
744   /* These virtual socket functions are called from within the c-ares
745    * library. These methods generally dispatch those socket calls to the
746    * appropriate methods. The virtual "socket" and "close" methods are
747    * special and instead create/add and remove/destroy GrpcPolledFdWindows
748    * objects.
749    */
Socket(int af,int type,int protocol,void * user_data)750   static ares_socket_t Socket(int af, int type, int protocol, void* user_data) {
751     if (type != SOCK_DGRAM && type != SOCK_STREAM) {
752       GRPC_CARES_TRACE_LOG("Socket called with invalid socket type:%d", type);
753       return INVALID_SOCKET;
754     }
755     SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
756     SOCKET s = WSASocket(af, type, protocol, nullptr, 0,
757                          grpc_get_default_wsa_socket_flags());
758     if (s == INVALID_SOCKET) {
759       GRPC_CARES_TRACE_LOG(
760           "WSASocket failed with params af:%d type:%d protocol:%d", af, type,
761           protocol);
762       return s;
763     }
764     grpc_tcp_set_non_block(s);
765     GrpcPolledFdWindows* polled_fd =
766         new GrpcPolledFdWindows(s, map->work_serializer_, af, type);
767     GRPC_CARES_TRACE_LOG(
768         "fd:|%s| created with params af:%d type:%d protocol:%d",
769         polled_fd->GetName(), af, type, protocol);
770     map->AddNewSocket(s, polled_fd);
771     return s;
772   }
773 
Connect(ares_socket_t as,const struct sockaddr * target,ares_socklen_t target_len,void * user_data)774   static int Connect(ares_socket_t as, const struct sockaddr* target,
775                      ares_socklen_t target_len, void* user_data) {
776     WSAErrorContext wsa_error_ctx;
777     SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
778     GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as);
779     return polled_fd->Connect(&wsa_error_ctx, target, target_len);
780   }
781 
SendV(ares_socket_t as,const struct iovec * iov,int iovec_count,void * user_data)782   static ares_ssize_t SendV(ares_socket_t as, const struct iovec* iov,
783                             int iovec_count, void* user_data) {
784     WSAErrorContext wsa_error_ctx;
785     SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
786     GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as);
787     return polled_fd->SendV(&wsa_error_ctx, iov, iovec_count);
788   }
789 
RecvFrom(ares_socket_t as,void * data,size_t data_len,int flags,struct sockaddr * from,ares_socklen_t * from_len,void * user_data)790   static ares_ssize_t RecvFrom(ares_socket_t as, void* data, size_t data_len,
791                                int flags, struct sockaddr* from,
792                                ares_socklen_t* from_len, void* user_data) {
793     WSAErrorContext wsa_error_ctx;
794     SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
795     GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as);
796     return polled_fd->RecvFrom(&wsa_error_ctx, data, data_len, flags, from,
797                                from_len);
798   }
799 
CloseSocket(SOCKET s,void * user_data)800   static int CloseSocket(SOCKET s, void* user_data) {
801     SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
802     GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(s);
803     map->RemoveEntry(s);
804     // See https://github.com/grpc/grpc/pull/20284, this trace log is
805     // intentionally placed to attempt to trigger a crash in case of a
806     // use after free on polled_fd.
807     GRPC_CARES_TRACE_LOG("CloseSocket called for socket: %s",
808                          polled_fd->GetName());
809     // If a gRPC polled fd has not made it in to the driver's list yet, then
810     // the driver has not and will never see this socket.
811     if (!polled_fd->gotten_into_driver_list()) {
812       polled_fd->ShutdownLocked(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
813           "Shut down c-ares fd before without it ever having made it into the "
814           "driver's list"));
815     }
816     delete polled_fd;
817     return 0;
818   }
819 
820  private:
821   SockToPolledFdEntry* head_ = nullptr;
822   std::shared_ptr<WorkSerializer> work_serializer_;
823 };
824 
825 const struct ares_socket_functions custom_ares_sock_funcs = {
826     &SockToPolledFdMap::Socket /* socket */,
827     &SockToPolledFdMap::CloseSocket /* close */,
828     &SockToPolledFdMap::Connect /* connect */,
829     &SockToPolledFdMap::RecvFrom /* recvfrom */,
830     &SockToPolledFdMap::SendV /* sendv */,
831 };
832 
/* A thin wrapper over a GrpcPolledFdWindows object but with a shorter
   lifetime. This object releases its GrpcPolledFdWindows upon destruction,
   so that c-ares can close it via usual socket teardown. */
836 class GrpcPolledFdWindowsWrapper : public GrpcPolledFd {
837  public:
GrpcPolledFdWindowsWrapper(GrpcPolledFdWindows * wrapped)838   explicit GrpcPolledFdWindowsWrapper(GrpcPolledFdWindows* wrapped)
839       : wrapped_(wrapped) {}
840 
~GrpcPolledFdWindowsWrapper()841   ~GrpcPolledFdWindowsWrapper() {}
842 
RegisterForOnReadableLocked(grpc_closure * read_closure)843   void RegisterForOnReadableLocked(grpc_closure* read_closure) override {
844     wrapped_->RegisterForOnReadableLocked(read_closure);
845   }
846 
RegisterForOnWriteableLocked(grpc_closure * write_closure)847   void RegisterForOnWriteableLocked(grpc_closure* write_closure) override {
848     wrapped_->RegisterForOnWriteableLocked(write_closure);
849   }
850 
IsFdStillReadableLocked()851   bool IsFdStillReadableLocked() override {
852     return wrapped_->IsFdStillReadableLocked();
853   }
854 
ShutdownLocked(grpc_error_handle error)855   void ShutdownLocked(grpc_error_handle error) override {
856     wrapped_->ShutdownLocked(error);
857   }
858 
GetWrappedAresSocketLocked()859   ares_socket_t GetWrappedAresSocketLocked() override {
860     return wrapped_->GetWrappedAresSocketLocked();
861   }
862 
GetName()863   const char* GetName() override { return wrapped_->GetName(); }
864 
865  private:
866   GrpcPolledFdWindows* wrapped_;
867 };
868 
869 class GrpcPolledFdFactoryWindows : public GrpcPolledFdFactory {
870  public:
GrpcPolledFdFactoryWindows(std::shared_ptr<WorkSerializer> work_serializer)871   explicit GrpcPolledFdFactoryWindows(
872       std::shared_ptr<WorkSerializer> work_serializer)
873       : sock_to_polled_fd_map_(std::move(work_serializer)) {}
874 
NewGrpcPolledFdLocked(ares_socket_t as,grpc_pollset_set * driver_pollset_set,std::shared_ptr<WorkSerializer> work_serializer)875   GrpcPolledFd* NewGrpcPolledFdLocked(
876       ares_socket_t as, grpc_pollset_set* driver_pollset_set,
877       std::shared_ptr<WorkSerializer> work_serializer) override {
878     GrpcPolledFdWindows* polled_fd = sock_to_polled_fd_map_.LookupPolledFd(as);
879     // Set a flag so that the virtual socket "close" method knows it
880     // doesn't need to call ShutdownLocked, since now the driver will.
881     polled_fd->set_gotten_into_driver_list();
882     return new GrpcPolledFdWindowsWrapper(polled_fd);
883   }
884 
ConfigureAresChannelLocked(ares_channel channel)885   void ConfigureAresChannelLocked(ares_channel channel) override {
886     ares_set_socket_functions(channel, &custom_ares_sock_funcs,
887                               &sock_to_polled_fd_map_);
888   }
889 
890  private:
891   SockToPolledFdMap sock_to_polled_fd_map_;
892 };
893 
NewGrpcPolledFdFactory(std::shared_ptr<WorkSerializer> work_serializer)894 std::unique_ptr<GrpcPolledFdFactory> NewGrpcPolledFdFactory(
895     std::shared_ptr<WorkSerializer> work_serializer) {
896   return absl::make_unique<GrpcPolledFdFactoryWindows>(
897       std::move(work_serializer));
898 }
899 
900 }  // namespace grpc_core
901 
#endif /* GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER) */
903