1 /*
2  *
3  * Copyright 2016 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 #include <grpc/support/port_platform.h>
19 
20 #include "src/core/lib/iomgr/port.h"
21 #if GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER)
22 
23 #include "absl/strings/str_format.h"
24 
25 #include <ares.h>
26 
27 #include <grpc/support/alloc.h>
28 #include <grpc/support/log.h>
29 #include <grpc/support/log_windows.h>
30 #include <grpc/support/string_util.h>
31 #include <grpc/support/time.h>
32 #include <string.h>
33 #include "src/core/lib/gpr/string.h"
34 #include "src/core/lib/gprpp/memory.h"
35 #include "src/core/lib/iomgr/iocp_windows.h"
36 #include "src/core/lib/iomgr/sockaddr_utils.h"
37 #include "src/core/lib/iomgr/sockaddr_windows.h"
38 #include "src/core/lib/iomgr/socket_windows.h"
39 #include "src/core/lib/iomgr/tcp_windows.h"
40 #include "src/core/lib/iomgr/work_serializer.h"
41 #include "src/core/lib/slice/slice_internal.h"
42 
43 #include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h"
44 #include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h"
45 
/* TODO(apolcyn): remove this hack after fixing upstream.
 * Our grpc/c-ares code on Windows uses the ares_set_socket_functions API,
 * which uses the "struct iovec" type. On Windows, that type is defined inside
 * a c-ares header that is not public.
 * See https://github.com/c-ares/c-ares/issues/206. */
struct iovec {
  void* iov_base;  // Start of one scatter/gather buffer.
  size_t iov_len;  // Number of bytes in that buffer.
};
55 
56 namespace grpc_core {
57 
58 /* c-ares reads and takes action on the error codes of the
59  * "virtual socket operations" in this file, via the WSAGetLastError
60  * APIs. If code in this file wants to set a specific WSA error that
61  * c-ares should read, it must do so by calling SetWSAError() on the
62  * WSAErrorContext instance passed to it. A WSAErrorContext must only be
63  * instantiated at the top of the virtual socket function callstack. */
64 class WSAErrorContext {
65  public:
WSAErrorContext()66   explicit WSAErrorContext(){};
67 
~WSAErrorContext()68   ~WSAErrorContext() {
69     if (error_ != 0) {
70       WSASetLastError(error_);
71     }
72   }
73 
74   /* Disallow copy and assignment operators */
75   WSAErrorContext(const WSAErrorContext&) = delete;
76   WSAErrorContext& operator=(const WSAErrorContext&) = delete;
77 
SetWSAError(int error)78   void SetWSAError(int error) { error_ = error; }
79 
80  private:
81   int error_ = 0;
82 };
83 
/* c-ares creates its own sockets and is meant to read them when readable and
 * write them when writeable. To fit this socket usage model into the grpc
 * windows poller (which gives notifications when attempted reads and writes are
 * actually fulfilled rather than possible), this GrpcPolledFdWindows class
 * takes advantage of the ares_set_socket_functions API and acts as a virtual
 * socket. It holds its own read and write buffers, which c-ares reads from and
 * writes to, and which are used with the grpc windows poller. It also
 * manufactures virtual socket error codes, e.g. when it needs to tell the
 * c-ares library to wait for an async read. */
93 class GrpcPolledFdWindows {
94  public:
95   enum WriteState {
96     WRITE_IDLE,
97     WRITE_REQUESTED,
98     WRITE_PENDING,
99     WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY,
100   };
101 
GrpcPolledFdWindows(ares_socket_t as,std::shared_ptr<WorkSerializer> work_serializer,int address_family,int socket_type)102   GrpcPolledFdWindows(ares_socket_t as,
103                       std::shared_ptr<WorkSerializer> work_serializer,
104                       int address_family, int socket_type)
105       : work_serializer_(std::move(work_serializer)),
106         read_buf_(grpc_empty_slice()),
107         write_buf_(grpc_empty_slice()),
108         tcp_write_state_(WRITE_IDLE),
109         name_(absl::StrFormat("c-ares socket: %" PRIdPTR, as)),
110         gotten_into_driver_list_(false),
111         address_family_(address_family),
112         socket_type_(socket_type) {
113     // Closure Initialization
114     GRPC_CLOSURE_INIT(&outer_read_closure_,
115                       &GrpcPolledFdWindows::OnIocpReadable, this,
116                       grpc_schedule_on_exec_ctx);
117     GRPC_CLOSURE_INIT(&outer_write_closure_,
118                       &GrpcPolledFdWindows::OnIocpWriteable, this,
119                       grpc_schedule_on_exec_ctx);
120     GRPC_CLOSURE_INIT(&on_tcp_connect_locked_,
121                       &GrpcPolledFdWindows::OnTcpConnect, this,
122                       grpc_schedule_on_exec_ctx);
123     winsocket_ = grpc_winsocket_create(as, name_.c_str());
124   }
125 
~GrpcPolledFdWindows()126   ~GrpcPolledFdWindows() {
127     grpc_slice_unref_internal(read_buf_);
128     grpc_slice_unref_internal(write_buf_);
129     GPR_ASSERT(read_closure_ == nullptr);
130     GPR_ASSERT(write_closure_ == nullptr);
131     grpc_winsocket_destroy(winsocket_);
132   }
133 
ScheduleAndNullReadClosure(grpc_error * error)134   void ScheduleAndNullReadClosure(grpc_error* error) {
135     grpc_core::ExecCtx::Run(DEBUG_LOCATION, read_closure_, error);
136     read_closure_ = nullptr;
137   }
138 
ScheduleAndNullWriteClosure(grpc_error * error)139   void ScheduleAndNullWriteClosure(grpc_error* error) {
140     grpc_core::ExecCtx::Run(DEBUG_LOCATION, write_closure_, error);
141     write_closure_ = nullptr;
142   }
143 
RegisterForOnReadableLocked(grpc_closure * read_closure)144   void RegisterForOnReadableLocked(grpc_closure* read_closure) {
145     GPR_ASSERT(read_closure_ == nullptr);
146     read_closure_ = read_closure;
147     GPR_ASSERT(GRPC_SLICE_LENGTH(read_buf_) == 0);
148     grpc_slice_unref_internal(read_buf_);
149     GPR_ASSERT(!read_buf_has_data_);
150     read_buf_ = GRPC_SLICE_MALLOC(4192);
151     if (connect_done_) {
152       work_serializer_->Run([this]() { ContinueRegisterForOnReadableLocked(); },
153                             DEBUG_LOCATION);
154     } else {
155       GPR_ASSERT(pending_continue_register_for_on_readable_locked_ == false);
156       pending_continue_register_for_on_readable_locked_ = true;
157     }
158   }
159 
ContinueRegisterForOnReadableLocked()160   void ContinueRegisterForOnReadableLocked() {
161     GRPC_CARES_TRACE_LOG(
162         "fd:|%s| InnerContinueRegisterForOnReadableLocked "
163         "wsa_connect_error_:%d",
164         GetName(), wsa_connect_error_);
165     GPR_ASSERT(connect_done_);
166     if (wsa_connect_error_ != 0) {
167       ScheduleAndNullReadClosure(GRPC_WSA_ERROR(wsa_connect_error_, "connect"));
168       return;
169     }
170     WSABUF buffer;
171     buffer.buf = (char*)GRPC_SLICE_START_PTR(read_buf_);
172     buffer.len = GRPC_SLICE_LENGTH(read_buf_);
173     memset(&winsocket_->read_info.overlapped, 0, sizeof(OVERLAPPED));
174     recv_from_source_addr_len_ = sizeof(recv_from_source_addr_);
175     DWORD flags = 0;
176     if (WSARecvFrom(grpc_winsocket_wrapped_socket(winsocket_), &buffer, 1,
177                     nullptr, &flags, (sockaddr*)recv_from_source_addr_,
178                     &recv_from_source_addr_len_,
179                     &winsocket_->read_info.overlapped, nullptr)) {
180       int wsa_last_error = WSAGetLastError();
181       char* msg = gpr_format_message(wsa_last_error);
182       GRPC_CARES_TRACE_LOG(
183           "fd:|%s| RegisterForOnReadableLocked WSARecvFrom error code:|%d| "
184           "msg:|%s|",
185           GetName(), wsa_last_error, msg);
186       gpr_free(msg);
187       if (wsa_last_error != WSA_IO_PENDING) {
188         ScheduleAndNullReadClosure(
189             GRPC_WSA_ERROR(wsa_last_error, "WSARecvFrom"));
190         return;
191       }
192     }
193     grpc_socket_notify_on_read(winsocket_, &outer_read_closure_);
194   }
195 
RegisterForOnWriteableLocked(grpc_closure * write_closure)196   void RegisterForOnWriteableLocked(grpc_closure* write_closure) {
197     if (socket_type_ == SOCK_DGRAM) {
198       GRPC_CARES_TRACE_LOG("fd:|%s| RegisterForOnWriteableLocked called",
199                            GetName());
200     } else {
201       GPR_ASSERT(socket_type_ == SOCK_STREAM);
202       GRPC_CARES_TRACE_LOG(
203           "fd:|%s| RegisterForOnWriteableLocked called tcp_write_state_: %d",
204           GetName(), tcp_write_state_);
205     }
206     GPR_ASSERT(write_closure_ == nullptr);
207     write_closure_ = write_closure;
208     if (connect_done_) {
209       work_serializer_->Run(
210           [this]() { ContinueRegisterForOnWriteableLocked(); }, DEBUG_LOCATION);
211     } else {
212       GPR_ASSERT(pending_continue_register_for_on_writeable_locked_ == false);
213       pending_continue_register_for_on_writeable_locked_ = true;
214     }
215   }
216 
ContinueRegisterForOnWriteableLocked()217   void ContinueRegisterForOnWriteableLocked() {
218     GRPC_CARES_TRACE_LOG(
219         "fd:|%s| InnerContinueRegisterForOnWriteableLocked "
220         "wsa_connect_error_:%d",
221         GetName(), wsa_connect_error_);
222     GPR_ASSERT(connect_done_);
223     if (wsa_connect_error_ != 0) {
224       ScheduleAndNullWriteClosure(
225           GRPC_WSA_ERROR(wsa_connect_error_, "connect"));
226       return;
227     }
228     if (socket_type_ == SOCK_DGRAM) {
229       ScheduleAndNullWriteClosure(GRPC_ERROR_NONE);
230     } else {
231       GPR_ASSERT(socket_type_ == SOCK_STREAM);
232       int wsa_error_code = 0;
233       switch (tcp_write_state_) {
234         case WRITE_IDLE:
235           ScheduleAndNullWriteClosure(GRPC_ERROR_NONE);
236           break;
237         case WRITE_REQUESTED:
238           tcp_write_state_ = WRITE_PENDING;
239           if (SendWriteBuf(nullptr, &winsocket_->write_info.overlapped,
240                            &wsa_error_code) != 0) {
241             ScheduleAndNullWriteClosure(
242                 GRPC_WSA_ERROR(wsa_error_code, "WSASend (overlapped)"));
243           } else {
244             grpc_socket_notify_on_write(winsocket_, &outer_write_closure_);
245           }
246           break;
247         case WRITE_PENDING:
248         case WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY:
249           abort();
250       }
251     }
252   }
253 
IsFdStillReadableLocked()254   bool IsFdStillReadableLocked() { return GRPC_SLICE_LENGTH(read_buf_) > 0; }
255 
ShutdownLocked(grpc_error * error)256   void ShutdownLocked(grpc_error* error) {
257     grpc_winsocket_shutdown(winsocket_);
258   }
259 
GetWrappedAresSocketLocked()260   ares_socket_t GetWrappedAresSocketLocked() {
261     return grpc_winsocket_wrapped_socket(winsocket_);
262   }
263 
GetName()264   const char* GetName() { return name_.c_str(); }
265 
RecvFrom(WSAErrorContext * wsa_error_ctx,void * data,ares_socket_t data_len,int flags,struct sockaddr * from,ares_socklen_t * from_len)266   ares_ssize_t RecvFrom(WSAErrorContext* wsa_error_ctx, void* data,
267                         ares_socket_t data_len, int flags,
268                         struct sockaddr* from, ares_socklen_t* from_len) {
269     GRPC_CARES_TRACE_LOG(
270         "fd:|%s| RecvFrom called read_buf_has_data:%d Current read buf "
271         "length:|%d|",
272         GetName(), read_buf_has_data_, GRPC_SLICE_LENGTH(read_buf_));
273     if (!read_buf_has_data_) {
274       wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
275       return -1;
276     }
277     ares_ssize_t bytes_read = 0;
278     for (size_t i = 0; i < GRPC_SLICE_LENGTH(read_buf_) && i < data_len; i++) {
279       ((char*)data)[i] = GRPC_SLICE_START_PTR(read_buf_)[i];
280       bytes_read++;
281     }
282     read_buf_ = grpc_slice_sub_no_ref(read_buf_, bytes_read,
283                                       GRPC_SLICE_LENGTH(read_buf_));
284     if (GRPC_SLICE_LENGTH(read_buf_) == 0) {
285       read_buf_has_data_ = false;
286     }
287     /* c-ares overloads this recv_from virtual socket function to receive
288      * data on both UDP and TCP sockets, and from is nullptr for TCP. */
289     if (from != nullptr) {
290       GPR_ASSERT(*from_len <= recv_from_source_addr_len_);
291       memcpy(from, &recv_from_source_addr_, recv_from_source_addr_len_);
292       *from_len = recv_from_source_addr_len_;
293     }
294     return bytes_read;
295   }
296 
FlattenIovec(const struct iovec * iov,int iov_count)297   grpc_slice FlattenIovec(const struct iovec* iov, int iov_count) {
298     int total = 0;
299     for (int i = 0; i < iov_count; i++) {
300       total += iov[i].iov_len;
301     }
302     grpc_slice out = GRPC_SLICE_MALLOC(total);
303     size_t cur = 0;
304     for (int i = 0; i < iov_count; i++) {
305       for (int k = 0; k < iov[i].iov_len; k++) {
306         GRPC_SLICE_START_PTR(out)[cur++] = ((char*)iov[i].iov_base)[k];
307       }
308     }
309     return out;
310   }
311 
SendWriteBuf(LPDWORD bytes_sent_ptr,LPWSAOVERLAPPED overlapped,int * wsa_error_code)312   int SendWriteBuf(LPDWORD bytes_sent_ptr, LPWSAOVERLAPPED overlapped,
313                    int* wsa_error_code) {
314     WSABUF buf;
315     buf.len = GRPC_SLICE_LENGTH(write_buf_);
316     buf.buf = (char*)GRPC_SLICE_START_PTR(write_buf_);
317     DWORD flags = 0;
318     int out = WSASend(grpc_winsocket_wrapped_socket(winsocket_), &buf, 1,
319                       bytes_sent_ptr, flags, overlapped, nullptr);
320     *wsa_error_code = WSAGetLastError();
321     GRPC_CARES_TRACE_LOG(
322         "fd:|%s| SendWriteBuf WSASend buf.len:%d *bytes_sent_ptr:%d "
323         "overlapped:%p "
324         "return:%d *wsa_error_code:%d",
325         GetName(), buf.len, bytes_sent_ptr != nullptr ? *bytes_sent_ptr : 0,
326         overlapped, out, *wsa_error_code);
327     return out;
328   }
329 
SendV(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)330   ares_ssize_t SendV(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
331                      int iov_count) {
332     GRPC_CARES_TRACE_LOG(
333         "fd:|%s| SendV called connect_done_:%d wsa_connect_error_:%d",
334         GetName(), connect_done_, wsa_connect_error_);
335     if (!connect_done_) {
336       wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
337       return -1;
338     }
339     if (wsa_connect_error_ != 0) {
340       wsa_error_ctx->SetWSAError(wsa_connect_error_);
341       return -1;
342     }
343     switch (socket_type_) {
344       case SOCK_DGRAM:
345         return SendVUDP(wsa_error_ctx, iov, iov_count);
346       case SOCK_STREAM:
347         return SendVTCP(wsa_error_ctx, iov, iov_count);
348       default:
349         abort();
350     }
351   }
352 
SendVUDP(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)353   ares_ssize_t SendVUDP(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
354                         int iov_count) {
355     // c-ares doesn't handle retryable errors on writes of UDP sockets.
356     // Therefore, the sendv handler for UDP sockets must only attempt
357     // to write everything inline.
358     GRPC_CARES_TRACE_LOG("fd:|%s| SendVUDP called", GetName());
359     GPR_ASSERT(GRPC_SLICE_LENGTH(write_buf_) == 0);
360     grpc_slice_unref_internal(write_buf_);
361     write_buf_ = FlattenIovec(iov, iov_count);
362     DWORD bytes_sent = 0;
363     int wsa_error_code = 0;
364     if (SendWriteBuf(&bytes_sent, nullptr, &wsa_error_code) != 0) {
365       wsa_error_ctx->SetWSAError(wsa_error_code);
366       char* msg = gpr_format_message(wsa_error_code);
367       GRPC_CARES_TRACE_LOG(
368           "fd:|%s| SendVUDP SendWriteBuf error code:%d msg:|%s|", GetName(),
369           wsa_error_code, msg);
370       gpr_free(msg);
371       return -1;
372     }
373     write_buf_ = grpc_slice_sub_no_ref(write_buf_, bytes_sent,
374                                        GRPC_SLICE_LENGTH(write_buf_));
375     return bytes_sent;
376   }
377 
SendVTCP(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)378   ares_ssize_t SendVTCP(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
379                         int iov_count) {
380     // The "sendv" handler on TCP sockets buffers up write
381     // requests and returns an artificial WSAEWOULDBLOCK. Writing that buffer
382     // out in the background, and making further send progress in general, will
383     // happen as long as c-ares continues to show interest in writeability on
384     // this fd.
385     GRPC_CARES_TRACE_LOG("fd:|%s| SendVTCP called tcp_write_state_:%d",
386                          GetName(), tcp_write_state_);
387     switch (tcp_write_state_) {
388       case WRITE_IDLE:
389         tcp_write_state_ = WRITE_REQUESTED;
390         GPR_ASSERT(GRPC_SLICE_LENGTH(write_buf_) == 0);
391         grpc_slice_unref_internal(write_buf_);
392         write_buf_ = FlattenIovec(iov, iov_count);
393         wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
394         return -1;
395       case WRITE_REQUESTED:
396       case WRITE_PENDING:
397         wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
398         return -1;
399       case WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY:
400         // c-ares is retrying a send on data that we previously returned
401         // WSAEWOULDBLOCK for, but then subsequently wrote out in the
402         // background. Right now, we assume that c-ares is retrying the same
403         // send again. If c-ares still needs to send even more data, we'll get
404         // to it eventually.
405         grpc_slice currently_attempted = FlattenIovec(iov, iov_count);
406         GPR_ASSERT(GRPC_SLICE_LENGTH(currently_attempted) >=
407                    GRPC_SLICE_LENGTH(write_buf_));
408         ares_ssize_t total_sent = 0;
409         for (size_t i = 0; i < GRPC_SLICE_LENGTH(write_buf_); i++) {
410           GPR_ASSERT(GRPC_SLICE_START_PTR(currently_attempted)[i] ==
411                      GRPC_SLICE_START_PTR(write_buf_)[i]);
412           total_sent++;
413         }
414         grpc_slice_unref_internal(currently_attempted);
415         tcp_write_state_ = WRITE_IDLE;
416         return total_sent;
417     }
418     abort();
419   }
420 
OnTcpConnect(void * arg,grpc_error * error)421   static void OnTcpConnect(void* arg, grpc_error* error) {
422     GrpcPolledFdWindows* grpc_polled_fd =
423         static_cast<GrpcPolledFdWindows*>(arg);
424     GRPC_ERROR_REF(error);  // ref owned by lambda
425     grpc_polled_fd->work_serializer_->Run(
426         [grpc_polled_fd, error]() {
427           grpc_polled_fd->OnTcpConnectLocked(error);
428         },
429         DEBUG_LOCATION);
430   }
431 
OnTcpConnectLocked(grpc_error * error)432   void OnTcpConnectLocked(grpc_error* error) {
433     GRPC_CARES_TRACE_LOG(
434         "fd:%s InnerOnTcpConnectLocked error:|%s| "
435         "pending_register_for_readable:%d"
436         " pending_register_for_writeable:%d",
437         GetName(), grpc_error_string(error),
438         pending_continue_register_for_on_readable_locked_,
439         pending_continue_register_for_on_writeable_locked_);
440     GPR_ASSERT(!connect_done_);
441     connect_done_ = true;
442     GPR_ASSERT(wsa_connect_error_ == 0);
443     if (error == GRPC_ERROR_NONE) {
444       DWORD transferred_bytes = 0;
445       DWORD flags;
446       BOOL wsa_success =
447           WSAGetOverlappedResult(grpc_winsocket_wrapped_socket(winsocket_),
448                                  &winsocket_->write_info.overlapped,
449                                  &transferred_bytes, FALSE, &flags);
450       GPR_ASSERT(transferred_bytes == 0);
451       if (!wsa_success) {
452         wsa_connect_error_ = WSAGetLastError();
453         char* msg = gpr_format_message(wsa_connect_error_);
454         GRPC_CARES_TRACE_LOG(
455             "fd:%s InnerOnTcpConnectLocked WSA overlapped result code:%d "
456             "msg:|%s|",
457             GetName(), wsa_connect_error_, msg);
458         gpr_free(msg);
459       }
460     } else {
461       // Spoof up an error code that will cause any future c-ares operations on
462       // this fd to abort.
463       wsa_connect_error_ = WSA_OPERATION_ABORTED;
464     }
465     if (pending_continue_register_for_on_readable_locked_) {
466       work_serializer_->Run([this]() { ContinueRegisterForOnReadableLocked(); },
467                             DEBUG_LOCATION);
468     }
469     if (pending_continue_register_for_on_writeable_locked_) {
470       work_serializer_->Run(
471           [this]() { ContinueRegisterForOnWriteableLocked(); }, DEBUG_LOCATION);
472     }
473     GRPC_ERROR_UNREF(error);
474   }
475 
Connect(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)476   int Connect(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
477               ares_socklen_t target_len) {
478     switch (socket_type_) {
479       case SOCK_DGRAM:
480         return ConnectUDP(wsa_error_ctx, target, target_len);
481       case SOCK_STREAM:
482         return ConnectTCP(wsa_error_ctx, target, target_len);
483       default:
484         abort();
485     }
486   }
487 
ConnectUDP(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)488   int ConnectUDP(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
489                  ares_socklen_t target_len) {
490     GRPC_CARES_TRACE_LOG("fd:%s ConnectUDP", GetName());
491     GPR_ASSERT(!connect_done_);
492     GPR_ASSERT(wsa_connect_error_ == 0);
493     SOCKET s = grpc_winsocket_wrapped_socket(winsocket_);
494     int out =
495         WSAConnect(s, target, target_len, nullptr, nullptr, nullptr, nullptr);
496     wsa_connect_error_ = WSAGetLastError();
497     wsa_error_ctx->SetWSAError(wsa_connect_error_);
498     connect_done_ = true;
499     char* msg = gpr_format_message(wsa_connect_error_);
500     GRPC_CARES_TRACE_LOG("fd:%s WSAConnect error code:|%d| msg:|%s|", GetName(),
501                          wsa_connect_error_, msg);
502     gpr_free(msg);
503     // c-ares expects a posix-style connect API
504     return out == 0 ? 0 : -1;
505   }
506 
ConnectTCP(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)507   int ConnectTCP(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
508                  ares_socklen_t target_len) {
509     GRPC_CARES_TRACE_LOG("fd:%s ConnectTCP", GetName());
510     LPFN_CONNECTEX ConnectEx;
511     GUID guid = WSAID_CONNECTEX;
512     DWORD ioctl_num_bytes;
513     SOCKET s = grpc_winsocket_wrapped_socket(winsocket_);
514     if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
515                  &ConnectEx, sizeof(ConnectEx), &ioctl_num_bytes, nullptr,
516                  nullptr) != 0) {
517       int wsa_last_error = WSAGetLastError();
518       wsa_error_ctx->SetWSAError(wsa_last_error);
519       char* msg = gpr_format_message(wsa_last_error);
520       GRPC_CARES_TRACE_LOG(
521           "fd:%s WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER) error code:%d "
522           "msg:|%s|",
523           GetName(), wsa_last_error, msg);
524       gpr_free(msg);
525       connect_done_ = true;
526       wsa_connect_error_ = wsa_last_error;
527       return -1;
528     }
529     grpc_resolved_address wildcard4_addr;
530     grpc_resolved_address wildcard6_addr;
531     grpc_sockaddr_make_wildcards(0, &wildcard4_addr, &wildcard6_addr);
532     grpc_resolved_address* local_address = nullptr;
533     if (address_family_ == AF_INET) {
534       local_address = &wildcard4_addr;
535     } else {
536       local_address = &wildcard6_addr;
537     }
538     if (bind(s, (struct sockaddr*)local_address->addr,
539              (int)local_address->len) != 0) {
540       int wsa_last_error = WSAGetLastError();
541       wsa_error_ctx->SetWSAError(wsa_last_error);
542       char* msg = gpr_format_message(wsa_last_error);
543       GRPC_CARES_TRACE_LOG("fd:%s bind error code:%d msg:|%s|", GetName(),
544                            wsa_last_error, msg);
545       gpr_free(msg);
546       connect_done_ = true;
547       wsa_connect_error_ = wsa_last_error;
548       return -1;
549     }
550     int out = 0;
551     if (ConnectEx(s, target, target_len, nullptr, 0, nullptr,
552                   &winsocket_->write_info.overlapped) == 0) {
553       out = -1;
554       int wsa_last_error = WSAGetLastError();
555       wsa_error_ctx->SetWSAError(wsa_last_error);
556       char* msg = gpr_format_message(wsa_last_error);
557       GRPC_CARES_TRACE_LOG("fd:%s ConnectEx error code:%d msg:|%s|", GetName(),
558                            wsa_last_error, msg);
559       gpr_free(msg);
560       if (wsa_last_error == WSA_IO_PENDING) {
561         // c-ares only understands WSAEINPROGRESS and EWOULDBLOCK error codes on
562         // connect, but an async connect on IOCP socket will give
563         // WSA_IO_PENDING, so we need to convert.
564         wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
565       } else {
566         // By returning a non-retryable error to c-ares at this point,
567         // we're aborting the possibility of any future operations on this fd.
568         connect_done_ = true;
569         wsa_connect_error_ = wsa_last_error;
570         return -1;
571       }
572     }
573     grpc_socket_notify_on_write(winsocket_, &on_tcp_connect_locked_);
574     return out;
575   }
576 
OnIocpReadable(void * arg,grpc_error * error)577   static void OnIocpReadable(void* arg, grpc_error* error) {
578     GrpcPolledFdWindows* polled_fd = static_cast<GrpcPolledFdWindows*>(arg);
579     GRPC_ERROR_REF(error);  // ref owned by lambda
580     polled_fd->work_serializer_->Run(
581         [polled_fd, error]() { polled_fd->OnIocpReadableLocked(error); },
582         DEBUG_LOCATION);
583   }
584 
585   // TODO(apolcyn): improve this error handling to be less conversative.
586   // An e.g. ECONNRESET error here should result in errors when
587   // c-ares reads from this socket later, but it shouldn't necessarily cancel
588   // the entire resolution attempt. Doing so will allow the "inject broken
589   // nameserver list" test to pass on Windows.
OnIocpReadableLocked(grpc_error * error)590   void OnIocpReadableLocked(grpc_error* error) {
591     if (error == GRPC_ERROR_NONE) {
592       if (winsocket_->read_info.wsa_error != 0) {
593         /* WSAEMSGSIZE would be due to receiving more data
594          * than our read buffer's fixed capacity. Assume that
595          * the connection is TCP and read the leftovers
596          * in subsequent c-ares reads. */
597         if (winsocket_->read_info.wsa_error != WSAEMSGSIZE) {
598           error = GRPC_WSA_ERROR(winsocket_->read_info.wsa_error,
599                                  "OnIocpReadableInner");
600           GRPC_CARES_TRACE_LOG(
601               "fd:|%s| OnIocpReadableInner winsocket_->read_info.wsa_error "
602               "code:|%d| msg:|%s|",
603               GetName(), winsocket_->read_info.wsa_error,
604               grpc_error_string(error));
605         }
606       }
607     }
608     if (error == GRPC_ERROR_NONE) {
609       read_buf_ = grpc_slice_sub_no_ref(
610           read_buf_, 0, winsocket_->read_info.bytes_transferred);
611       read_buf_has_data_ = true;
612     } else {
613       grpc_slice_unref_internal(read_buf_);
614       read_buf_ = grpc_empty_slice();
615     }
616     GRPC_CARES_TRACE_LOG(
617         "fd:|%s| OnIocpReadable finishing. read buf length now:|%d|", GetName(),
618         GRPC_SLICE_LENGTH(read_buf_));
619     ScheduleAndNullReadClosure(error);
620   }
621 
OnIocpWriteable(void * arg,grpc_error * error)622   static void OnIocpWriteable(void* arg, grpc_error* error) {
623     GrpcPolledFdWindows* polled_fd = static_cast<GrpcPolledFdWindows*>(arg);
624     GRPC_ERROR_REF(error);  // error owned by lambda
625     polled_fd->work_serializer_->Run(
626         [polled_fd, error]() { polled_fd->OnIocpWriteableLocked(error); },
627         DEBUG_LOCATION);
628   }
629 
OnIocpWriteableLocked(grpc_error * error)630   void OnIocpWriteableLocked(grpc_error* error) {
631     GRPC_CARES_TRACE_LOG("OnIocpWriteableInner. fd:|%s|", GetName());
632     GPR_ASSERT(socket_type_ == SOCK_STREAM);
633     if (error == GRPC_ERROR_NONE) {
634       if (winsocket_->write_info.wsa_error != 0) {
635         error = GRPC_WSA_ERROR(winsocket_->write_info.wsa_error,
636                                "OnIocpWriteableInner");
637         GRPC_CARES_TRACE_LOG(
638             "fd:|%s| OnIocpWriteableInner. winsocket_->write_info.wsa_error "
639             "code:|%d| msg:|%s|",
640             GetName(), winsocket_->write_info.wsa_error,
641             grpc_error_string(error));
642       }
643     }
644     GPR_ASSERT(tcp_write_state_ == WRITE_PENDING);
645     if (error == GRPC_ERROR_NONE) {
646       tcp_write_state_ = WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY;
647       write_buf_ = grpc_slice_sub_no_ref(
648           write_buf_, 0, winsocket_->write_info.bytes_transferred);
649       GRPC_CARES_TRACE_LOG("fd:|%s| OnIocpWriteableInner. bytes transferred:%d",
650                            GetName(), winsocket_->write_info.bytes_transferred);
651     } else {
652       grpc_slice_unref_internal(write_buf_);
653       write_buf_ = grpc_empty_slice();
654     }
655     ScheduleAndNullWriteClosure(error);
656   }
657 
gotten_into_driver_list() const658   bool gotten_into_driver_list() const { return gotten_into_driver_list_; }
set_gotten_into_driver_list()659   void set_gotten_into_driver_list() { gotten_into_driver_list_ = true; }
660 
661  private:
662   std::shared_ptr<WorkSerializer> work_serializer_;
663   char recv_from_source_addr_[200];
664   ares_socklen_t recv_from_source_addr_len_;
665   grpc_slice read_buf_;
666   bool read_buf_has_data_ = false;
667   grpc_slice write_buf_;
668   grpc_closure* read_closure_ = nullptr;
669   grpc_closure* write_closure_ = nullptr;
670   grpc_closure outer_read_closure_;
671   grpc_closure outer_write_closure_;
672   grpc_winsocket* winsocket_;
673   // tcp_write_state_ is only used on TCP GrpcPolledFds
674   WriteState tcp_write_state_;
675   std::string name_;
676   bool gotten_into_driver_list_;
677   int address_family_;
678   int socket_type_;
679   grpc_closure on_tcp_connect_locked_;
680   bool connect_done_ = false;
681   int wsa_connect_error_ = 0;
682   // We don't run register_for_{readable,writeable} logic until
683   // a socket is connected. In the interim, we queue readable/writeable
684   // registrations with the following state.
685   bool pending_continue_register_for_on_readable_locked_ = false;
686   bool pending_continue_register_for_on_writeable_locked_ = false;
687 };
688 
689 struct SockToPolledFdEntry {
SockToPolledFdEntrygrpc_core::SockToPolledFdEntry690   SockToPolledFdEntry(SOCKET s, GrpcPolledFdWindows* fd)
691       : socket(s), polled_fd(fd) {}
692   SOCKET socket;
693   GrpcPolledFdWindows* polled_fd;
694   SockToPolledFdEntry* next = nullptr;
695 };
696 
/* A SockToPolledFdMap maps ares_socket_t values (SOCKETs on Windows)
 * to GrpcPolledFdWindows objects. It is used to find the GrpcPolledFdWindows
 * that should handle a virtual socket call when c-ares makes that call on an
 * ares_socket_t. Instances are owned by, and one-to-one with, a
 * GrpcPolledFdWindows factory and event driver. */
702 class SockToPolledFdMap {
703  public:
SockToPolledFdMap(std::shared_ptr<WorkSerializer> work_serializer)704   explicit SockToPolledFdMap(std::shared_ptr<WorkSerializer> work_serializer)
705       : work_serializer_(std::move(work_serializer)) {}
706 
~SockToPolledFdMap()707   ~SockToPolledFdMap() { GPR_ASSERT(head_ == nullptr); }
708 
AddNewSocket(SOCKET s,GrpcPolledFdWindows * polled_fd)709   void AddNewSocket(SOCKET s, GrpcPolledFdWindows* polled_fd) {
710     SockToPolledFdEntry* new_node = new SockToPolledFdEntry(s, polled_fd);
711     new_node->next = head_;
712     head_ = new_node;
713   }
714 
LookupPolledFd(SOCKET s)715   GrpcPolledFdWindows* LookupPolledFd(SOCKET s) {
716     for (SockToPolledFdEntry* node = head_; node != nullptr;
717          node = node->next) {
718       if (node->socket == s) {
719         GPR_ASSERT(node->polled_fd != nullptr);
720         return node->polled_fd;
721       }
722     }
723     abort();
724   }
725 
RemoveEntry(SOCKET s)726   void RemoveEntry(SOCKET s) {
727     GPR_ASSERT(head_ != nullptr);
728     SockToPolledFdEntry** prev = &head_;
729     for (SockToPolledFdEntry* node = head_; node != nullptr;
730          node = node->next) {
731       if (node->socket == s) {
732         *prev = node->next;
733         delete node;
734         return;
735       }
736       prev = &node->next;
737     }
738     abort();
739   }
740 
741   /* These virtual socket functions are called from within the c-ares
742    * library. These methods generally dispatch those socket calls to the
743    * appropriate methods. The virtual "socket" and "close" methods are
744    * special and instead create/add and remove/destroy GrpcPolledFdWindows
745    * objects.
746    */
Socket(int af,int type,int protocol,void * user_data)747   static ares_socket_t Socket(int af, int type, int protocol, void* user_data) {
748     if (type != SOCK_DGRAM && type != SOCK_STREAM) {
749       GRPC_CARES_TRACE_LOG("Socket called with invalid socket type:%d", type);
750       return INVALID_SOCKET;
751     }
752     SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
753     SOCKET s = WSASocket(af, type, protocol, nullptr, 0,
754                          grpc_get_default_wsa_socket_flags());
755     if (s == INVALID_SOCKET) {
756       GRPC_CARES_TRACE_LOG(
757           "WSASocket failed with params af:%d type:%d protocol:%d", af, type,
758           protocol);
759       return s;
760     }
761     grpc_tcp_set_non_block(s);
762     GrpcPolledFdWindows* polled_fd =
763         new GrpcPolledFdWindows(s, map->work_serializer_, af, type);
764     GRPC_CARES_TRACE_LOG(
765         "fd:|%s| created with params af:%d type:%d protocol:%d",
766         polled_fd->GetName(), af, type, protocol);
767     map->AddNewSocket(s, polled_fd);
768     return s;
769   }
770 
Connect(ares_socket_t as,const struct sockaddr * target,ares_socklen_t target_len,void * user_data)771   static int Connect(ares_socket_t as, const struct sockaddr* target,
772                      ares_socklen_t target_len, void* user_data) {
773     WSAErrorContext wsa_error_ctx;
774     SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
775     GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as);
776     return polled_fd->Connect(&wsa_error_ctx, target, target_len);
777   }
778 
SendV(ares_socket_t as,const struct iovec * iov,int iovec_count,void * user_data)779   static ares_ssize_t SendV(ares_socket_t as, const struct iovec* iov,
780                             int iovec_count, void* user_data) {
781     WSAErrorContext wsa_error_ctx;
782     SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
783     GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as);
784     return polled_fd->SendV(&wsa_error_ctx, iov, iovec_count);
785   }
786 
RecvFrom(ares_socket_t as,void * data,size_t data_len,int flags,struct sockaddr * from,ares_socklen_t * from_len,void * user_data)787   static ares_ssize_t RecvFrom(ares_socket_t as, void* data, size_t data_len,
788                                int flags, struct sockaddr* from,
789                                ares_socklen_t* from_len, void* user_data) {
790     WSAErrorContext wsa_error_ctx;
791     SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
792     GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as);
793     return polled_fd->RecvFrom(&wsa_error_ctx, data, data_len, flags, from,
794                                from_len);
795   }
796 
CloseSocket(SOCKET s,void * user_data)797   static int CloseSocket(SOCKET s, void* user_data) {
798     SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
799     GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(s);
800     map->RemoveEntry(s);
801     // See https://github.com/grpc/grpc/pull/20284, this trace log is
802     // intentionally placed to attempt to trigger a crash in case of a
803     // use after free on polled_fd.
804     GRPC_CARES_TRACE_LOG("CloseSocket called for socket: %s",
805                          polled_fd->GetName());
806     // If a gRPC polled fd has not made it in to the driver's list yet, then
807     // the driver has not and will never see this socket.
808     if (!polled_fd->gotten_into_driver_list()) {
809       polled_fd->ShutdownLocked(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
810           "Shut down c-ares fd before without it ever having made it into the "
811           "driver's list"));
812     }
813     delete polled_fd;
814     return 0;
815   }
816 
817  private:
818   SockToPolledFdEntry* head_ = nullptr;
819   std::shared_ptr<WorkSerializer> work_serializer_;
820 };
821 
822 const struct ares_socket_functions custom_ares_sock_funcs = {
823     &SockToPolledFdMap::Socket /* socket */,
824     &SockToPolledFdMap::CloseSocket /* close */,
825     &SockToPolledFdMap::Connect /* connect */,
826     &SockToPolledFdMap::RecvFrom /* recvfrom */,
827     &SockToPolledFdMap::SendV /* sendv */,
828 };
829 
830 /* A thin wrapper over a GrpcPolledFdWindows object but with a shorter
831    lifetime. This object releases it's GrpcPolledFdWindows upon destruction,
832    so that c-ares can close it via usual socket teardown. */
833 class GrpcPolledFdWindowsWrapper : public GrpcPolledFd {
834  public:
GrpcPolledFdWindowsWrapper(GrpcPolledFdWindows * wrapped)835   explicit GrpcPolledFdWindowsWrapper(GrpcPolledFdWindows* wrapped)
836       : wrapped_(wrapped) {}
837 
~GrpcPolledFdWindowsWrapper()838   ~GrpcPolledFdWindowsWrapper() {}
839 
RegisterForOnReadableLocked(grpc_closure * read_closure)840   void RegisterForOnReadableLocked(grpc_closure* read_closure) override {
841     wrapped_->RegisterForOnReadableLocked(read_closure);
842   }
843 
RegisterForOnWriteableLocked(grpc_closure * write_closure)844   void RegisterForOnWriteableLocked(grpc_closure* write_closure) override {
845     wrapped_->RegisterForOnWriteableLocked(write_closure);
846   }
847 
IsFdStillReadableLocked()848   bool IsFdStillReadableLocked() override {
849     return wrapped_->IsFdStillReadableLocked();
850   }
851 
ShutdownLocked(grpc_error * error)852   void ShutdownLocked(grpc_error* error) override {
853     wrapped_->ShutdownLocked(error);
854   }
855 
GetWrappedAresSocketLocked()856   ares_socket_t GetWrappedAresSocketLocked() override {
857     return wrapped_->GetWrappedAresSocketLocked();
858   }
859 
GetName()860   const char* GetName() override { return wrapped_->GetName(); }
861 
862  private:
863   GrpcPolledFdWindows* wrapped_;
864 };
865 
866 class GrpcPolledFdFactoryWindows : public GrpcPolledFdFactory {
867  public:
GrpcPolledFdFactoryWindows(std::shared_ptr<WorkSerializer> work_serializer)868   explicit GrpcPolledFdFactoryWindows(
869       std::shared_ptr<WorkSerializer> work_serializer)
870       : sock_to_polled_fd_map_(std::move(work_serializer)) {}
871 
NewGrpcPolledFdLocked(ares_socket_t as,grpc_pollset_set * driver_pollset_set,std::shared_ptr<WorkSerializer> work_serializer)872   GrpcPolledFd* NewGrpcPolledFdLocked(
873       ares_socket_t as, grpc_pollset_set* driver_pollset_set,
874       std::shared_ptr<WorkSerializer> work_serializer) override {
875     GrpcPolledFdWindows* polled_fd = sock_to_polled_fd_map_.LookupPolledFd(as);
876     // Set a flag so that the virtual socket "close" method knows it
877     // doesn't need to call ShutdownLocked, since now the driver will.
878     polled_fd->set_gotten_into_driver_list();
879     return new GrpcPolledFdWindowsWrapper(polled_fd);
880   }
881 
ConfigureAresChannelLocked(ares_channel channel)882   void ConfigureAresChannelLocked(ares_channel channel) override {
883     ares_set_socket_functions(channel, &custom_ares_sock_funcs,
884                               &sock_to_polled_fd_map_);
885   }
886 
887  private:
888   SockToPolledFdMap sock_to_polled_fd_map_;
889 };
890 
NewGrpcPolledFdFactory(std::shared_ptr<WorkSerializer> work_serializer)891 std::unique_ptr<GrpcPolledFdFactory> NewGrpcPolledFdFactory(
892     std::shared_ptr<WorkSerializer> work_serializer) {
893   return absl::make_unique<GrpcPolledFdFactoryWindows>(
894       std::move(work_serializer));
895 }
896 
897 }  // namespace grpc_core
898 
#endif /* GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER) */
900