1 /*
2  *
3  * Copyright 2016 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
#include "test/cpp/end2end/test_service_impl.h"

#include <sstream>
#include <string>
#include <thread>

#include <grpc/support/log.h>
#include <grpcpp/security/credentials.h>
#include <grpcpp/server_context.h>

#include "src/proto/grpc/testing/echo.grpc.pb.h"
#include "test/cpp/util/string_ref_helper.h"

#include <gtest/gtest.h>
32 
33 using std::chrono::system_clock;
34 
35 namespace grpc {
36 namespace testing {
37 namespace {
38 
39 // When echo_deadline is requested, deadline seen in the ServerContext is set in
40 // the response in seconds.
MaybeEchoDeadline(ServerContext * context,const EchoRequest * request,EchoResponse * response)41 void MaybeEchoDeadline(ServerContext* context, const EchoRequest* request,
42                        EchoResponse* response) {
43   if (request->has_param() && request->param().echo_deadline()) {
44     gpr_timespec deadline = gpr_inf_future(GPR_CLOCK_REALTIME);
45     if (context->deadline() != system_clock::time_point::max()) {
46       Timepoint2Timespec(context->deadline(), &deadline);
47     }
48     response->mutable_param()->set_request_deadline(deadline.tv_sec);
49   }
50 }
51 
CheckServerAuthContext(const ServerContext * context,const grpc::string & expected_transport_security_type,const grpc::string & expected_client_identity)52 void CheckServerAuthContext(
53     const ServerContext* context,
54     const grpc::string& expected_transport_security_type,
55     const grpc::string& expected_client_identity) {
56   std::shared_ptr<const AuthContext> auth_ctx = context->auth_context();
57   std::vector<grpc::string_ref> tst =
58       auth_ctx->FindPropertyValues("transport_security_type");
59   EXPECT_EQ(1u, tst.size());
60   EXPECT_EQ(expected_transport_security_type, ToString(tst[0]));
61   if (expected_client_identity.empty()) {
62     EXPECT_TRUE(auth_ctx->GetPeerIdentityPropertyName().empty());
63     EXPECT_TRUE(auth_ctx->GetPeerIdentity().empty());
64     EXPECT_FALSE(auth_ctx->IsPeerAuthenticated());
65   } else {
66     auto identity = auth_ctx->GetPeerIdentity();
67     EXPECT_TRUE(auth_ctx->IsPeerAuthenticated());
68     EXPECT_EQ(1u, identity.size());
69     EXPECT_EQ(expected_client_identity, identity[0]);
70   }
71 }
72 
73 // Returns the number of pairs in metadata that exactly match the given
74 // key-value pair. Returns -1 if the pair wasn't found.
MetadataMatchCount(const std::multimap<grpc::string_ref,grpc::string_ref> & metadata,const grpc::string & key,const grpc::string & value)75 int MetadataMatchCount(
76     const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
77     const grpc::string& key, const grpc::string& value) {
78   int count = 0;
79   for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator iter =
80            metadata.begin();
81        iter != metadata.end(); ++iter) {
82     if (ToString(iter->first) == key && ToString(iter->second) == value) {
83       count++;
84     }
85   }
86   return count;
87 }
88 }  // namespace
89 
90 namespace {
// Looks up `key` in `metadata` and parses the value of the first matching
// entry as an int. Returns that int when it parses. Returns `default_value`
// when the key is absent or its value is not a valid int. When the key is
// present, the resulting value is logged at INFO level.
int GetIntValueFromMetadataHelper(
    const char* key,
    const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
    int default_value) {
  const auto iter = metadata.find(key);
  if (iter != metadata.end()) {
    std::istringstream iss(ToString(iter->second));
    int parsed_value = 0;
    // Parse into a temporary. Since C++11 a failed extraction stores 0 (or
    // INT_MIN/INT_MAX on overflow) into its target, which would otherwise
    // silently overwrite default_value.
    if (iss >> parsed_value) {
      default_value = parsed_value;
    }
    gpr_log(GPR_INFO, "%s : %d", key, default_value);
  }

  return default_value;
}
103 
// Reads an integer-valued client metadata entry named `key`, falling back to
// `default_value`. All lookup and parsing is delegated to
// GetIntValueFromMetadataHelper.
int GetIntValueFromMetadata(
    const char* key,
    const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
    int default_value) {
  const int result =
      GetIntValueFromMetadataHelper(key, metadata, default_value);
  return result;
}
110 
ServerTryCancel(ServerContext * context)111 void ServerTryCancel(ServerContext* context) {
112   EXPECT_FALSE(context->IsCancelled());
113   context->TryCancel();
114   gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request");
115   // Now wait until it's really canceled
116   while (!context->IsCancelled()) {
117     gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
118                                  gpr_time_from_micros(1000, GPR_TIMESPAN)));
119   }
120 }
121 
ServerTryCancelNonblocking(ServerContext * context)122 void ServerTryCancelNonblocking(ServerContext* context) {
123   EXPECT_FALSE(context->IsCancelled());
124   context->TryCancel();
125   gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request");
126 }
127 
LoopUntilCancelled(Alarm * alarm,ServerContext * context,experimental::ServerCallbackRpcController * controller,int loop_delay_us)128 void LoopUntilCancelled(Alarm* alarm, ServerContext* context,
129                         experimental::ServerCallbackRpcController* controller,
130                         int loop_delay_us) {
131   if (!context->IsCancelled()) {
132     alarm->experimental().Set(
133         gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
134                      gpr_time_from_micros(loop_delay_us, GPR_TIMESPAN)),
135         [alarm, context, controller, loop_delay_us](bool) {
136           LoopUntilCancelled(alarm, context, controller, loop_delay_us);
137         });
138   } else {
139     controller->Finish(Status::CANCELLED);
140   }
141 }
142 }  // namespace
143 
Echo(ServerContext * context,const EchoRequest * request,EchoResponse * response)144 Status TestServiceImpl::Echo(ServerContext* context, const EchoRequest* request,
145                              EchoResponse* response) {
146   // A bit of sleep to make sure that short deadline tests fail
147   if (request->has_param() && request->param().server_sleep_us() > 0) {
148     gpr_sleep_until(
149         gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
150                      gpr_time_from_micros(request->param().server_sleep_us(),
151                                           GPR_TIMESPAN)));
152   }
153 
154   if (request->has_param() && request->param().server_die()) {
155     gpr_log(GPR_ERROR, "The request should not reach application handler.");
156     GPR_ASSERT(0);
157   }
158   if (request->has_param() && request->param().has_expected_error()) {
159     const auto& error = request->param().expected_error();
160     return Status(static_cast<StatusCode>(error.code()), error.error_message(),
161                   error.binary_error_details());
162   }
163   int server_try_cancel = GetIntValueFromMetadata(
164       kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
165   if (server_try_cancel > DO_NOT_CANCEL) {
166     // Since this is a unary RPC, by the time this server handler is called,
167     // the 'request' message is already read from the client. So the scenarios
168     // in server_try_cancel don't make much sense. Just cancel the RPC as long
169     // as server_try_cancel is not DO_NOT_CANCEL
170     ServerTryCancel(context);
171     return Status::CANCELLED;
172   }
173 
174   response->set_message(request->message());
175   MaybeEchoDeadline(context, request, response);
176   if (host_) {
177     response->mutable_param()->set_host(*host_);
178   }
179   if (request->has_param() && request->param().client_cancel_after_us()) {
180     {
181       std::unique_lock<std::mutex> lock(mu_);
182       signal_client_ = true;
183     }
184     while (!context->IsCancelled()) {
185       gpr_sleep_until(gpr_time_add(
186           gpr_now(GPR_CLOCK_REALTIME),
187           gpr_time_from_micros(request->param().client_cancel_after_us(),
188                                GPR_TIMESPAN)));
189     }
190     return Status::CANCELLED;
191   } else if (request->has_param() &&
192              request->param().server_cancel_after_us()) {
193     gpr_sleep_until(gpr_time_add(
194         gpr_now(GPR_CLOCK_REALTIME),
195         gpr_time_from_micros(request->param().server_cancel_after_us(),
196                              GPR_TIMESPAN)));
197     return Status::CANCELLED;
198   } else if (!request->has_param() ||
199              !request->param().skip_cancelled_check()) {
200     EXPECT_FALSE(context->IsCancelled());
201   }
202 
203   if (request->has_param() && request->param().echo_metadata_initially()) {
204     const std::multimap<grpc::string_ref, grpc::string_ref>& client_metadata =
205         context->client_metadata();
206     for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator
207              iter = client_metadata.begin();
208          iter != client_metadata.end(); ++iter) {
209       context->AddInitialMetadata(ToString(iter->first),
210                                   ToString(iter->second));
211     }
212   }
213 
214   if (request->has_param() && request->param().echo_metadata()) {
215     const std::multimap<grpc::string_ref, grpc::string_ref>& client_metadata =
216         context->client_metadata();
217     for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator
218              iter = client_metadata.begin();
219          iter != client_metadata.end(); ++iter) {
220       context->AddTrailingMetadata(ToString(iter->first),
221                                    ToString(iter->second));
222     }
223     // Terminate rpc with error and debug info in trailer.
224     if (request->param().debug_info().stack_entries_size() ||
225         !request->param().debug_info().detail().empty()) {
226       grpc::string serialized_debug_info =
227           request->param().debug_info().SerializeAsString();
228       context->AddTrailingMetadata(kDebugInfoTrailerKey, serialized_debug_info);
229       return Status::CANCELLED;
230     }
231   }
232   if (request->has_param() &&
233       (request->param().expected_client_identity().length() > 0 ||
234        request->param().check_auth_context())) {
235     CheckServerAuthContext(context,
236                            request->param().expected_transport_security_type(),
237                            request->param().expected_client_identity());
238   }
239   if (request->has_param() && request->param().response_message_length() > 0) {
240     response->set_message(
241         grpc::string(request->param().response_message_length(), '\0'));
242   }
243   if (request->has_param() && request->param().echo_peer()) {
244     response->mutable_param()->set_peer(context->peer());
245   }
246   return Status::OK;
247 }
248 
CheckClientInitialMetadata(ServerContext * context,const SimpleRequest * request,SimpleResponse * response)249 Status TestServiceImpl::CheckClientInitialMetadata(ServerContext* context,
250                                                    const SimpleRequest* request,
251                                                    SimpleResponse* response) {
252   EXPECT_EQ(MetadataMatchCount(context->client_metadata(),
253                                kCheckClientInitialMetadataKey,
254                                kCheckClientInitialMetadataVal),
255             1);
256   EXPECT_EQ(1u,
257             context->client_metadata().count(kCheckClientInitialMetadataKey));
258   return Status::OK;
259 }
260 
Echo(ServerContext * context,const EchoRequest * request,EchoResponse * response,experimental::ServerCallbackRpcController * controller)261 void CallbackTestServiceImpl::Echo(
262     ServerContext* context, const EchoRequest* request, EchoResponse* response,
263     experimental::ServerCallbackRpcController* controller) {
264   CancelState* cancel_state = new CancelState;
265   int server_use_cancel_callback =
266       GetIntValueFromMetadata(kServerUseCancelCallback,
267                               context->client_metadata(), DO_NOT_USE_CALLBACK);
268   if (server_use_cancel_callback != DO_NOT_USE_CALLBACK) {
269     controller->SetCancelCallback([cancel_state] {
270       EXPECT_FALSE(cancel_state->callback_invoked.exchange(
271           true, std::memory_order_relaxed));
272     });
273     if (server_use_cancel_callback == MAYBE_USE_CALLBACK_EARLY_CANCEL) {
274       EXPECT_TRUE(context->IsCancelled());
275       EXPECT_TRUE(
276           cancel_state->callback_invoked.load(std::memory_order_relaxed));
277     } else {
278       EXPECT_FALSE(context->IsCancelled());
279       EXPECT_FALSE(
280           cancel_state->callback_invoked.load(std::memory_order_relaxed));
281     }
282   }
283   // A bit of sleep to make sure that short deadline tests fail
284   if (request->has_param() && request->param().server_sleep_us() > 0) {
285     // Set an alarm for that much time
286     alarm_.experimental().Set(
287         gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
288                      gpr_time_from_micros(request->param().server_sleep_us(),
289                                           GPR_TIMESPAN)),
290         [this, context, request, response, controller, cancel_state](bool) {
291           EchoNonDelayed(context, request, response, controller, cancel_state);
292         });
293   } else {
294     EchoNonDelayed(context, request, response, controller, cancel_state);
295   }
296 }
297 
CheckClientInitialMetadata(ServerContext * context,const SimpleRequest * request,SimpleResponse * response,experimental::ServerCallbackRpcController * controller)298 void CallbackTestServiceImpl::CheckClientInitialMetadata(
299     ServerContext* context, const SimpleRequest* request,
300     SimpleResponse* response,
301     experimental::ServerCallbackRpcController* controller) {
302   EXPECT_EQ(MetadataMatchCount(context->client_metadata(),
303                                kCheckClientInitialMetadataKey,
304                                kCheckClientInitialMetadataVal),
305             1);
306   EXPECT_EQ(1u,
307             context->client_metadata().count(kCheckClientInitialMetadataKey));
308   controller->Finish(Status::OK);
309 }
310 
EchoNonDelayed(ServerContext * context,const EchoRequest * request,EchoResponse * response,experimental::ServerCallbackRpcController * controller,CancelState * cancel_state)311 void CallbackTestServiceImpl::EchoNonDelayed(
312     ServerContext* context, const EchoRequest* request, EchoResponse* response,
313     experimental::ServerCallbackRpcController* controller,
314     CancelState* cancel_state) {
315   int server_use_cancel_callback =
316       GetIntValueFromMetadata(kServerUseCancelCallback,
317                               context->client_metadata(), DO_NOT_USE_CALLBACK);
318 
319   // Safe to clear cancel callback even if it wasn't set
320   controller->ClearCancelCallback();
321   if (server_use_cancel_callback == MAYBE_USE_CALLBACK_EARLY_CANCEL ||
322       server_use_cancel_callback == MAYBE_USE_CALLBACK_LATE_CANCEL) {
323     EXPECT_TRUE(context->IsCancelled());
324     EXPECT_TRUE(cancel_state->callback_invoked.load(std::memory_order_relaxed));
325     delete cancel_state;
326     controller->Finish(Status::CANCELLED);
327     return;
328   }
329 
330   EXPECT_FALSE(cancel_state->callback_invoked.load(std::memory_order_relaxed));
331   delete cancel_state;
332 
333   if (request->has_param() && request->param().server_die()) {
334     gpr_log(GPR_ERROR, "The request should not reach application handler.");
335     GPR_ASSERT(0);
336   }
337   if (request->has_param() && request->param().has_expected_error()) {
338     const auto& error = request->param().expected_error();
339     controller->Finish(Status(static_cast<StatusCode>(error.code()),
340                               error.error_message(),
341                               error.binary_error_details()));
342     return;
343   }
344   int server_try_cancel = GetIntValueFromMetadata(
345       kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
346   if (server_try_cancel > DO_NOT_CANCEL) {
347     // Since this is a unary RPC, by the time this server handler is called,
348     // the 'request' message is already read from the client. So the scenarios
349     // in server_try_cancel don't make much sense. Just cancel the RPC as long
350     // as server_try_cancel is not DO_NOT_CANCEL
351     EXPECT_FALSE(context->IsCancelled());
352     context->TryCancel();
353     gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request");
354 
355     if (server_use_cancel_callback == DO_NOT_USE_CALLBACK) {
356       // Now wait until it's really canceled
357       LoopUntilCancelled(&alarm_, context, controller, 1000);
358     }
359     return;
360   }
361 
362   gpr_log(GPR_DEBUG, "Request message was %s", request->message().c_str());
363   response->set_message(request->message());
364   MaybeEchoDeadline(context, request, response);
365   if (host_) {
366     response->mutable_param()->set_host(*host_);
367   }
368   if (request->has_param() && request->param().client_cancel_after_us()) {
369     {
370       std::unique_lock<std::mutex> lock(mu_);
371       signal_client_ = true;
372     }
373     if (server_use_cancel_callback == DO_NOT_USE_CALLBACK) {
374       // Now wait until it's really canceled
375       LoopUntilCancelled(&alarm_, context, controller,
376                          request->param().client_cancel_after_us());
377     }
378     return;
379   } else if (request->has_param() &&
380              request->param().server_cancel_after_us()) {
381     alarm_.experimental().Set(
382         gpr_time_add(
383             gpr_now(GPR_CLOCK_REALTIME),
384             gpr_time_from_micros(request->param().server_cancel_after_us(),
385                                  GPR_TIMESPAN)),
386         [controller](bool) { controller->Finish(Status::CANCELLED); });
387     return;
388   } else if (!request->has_param() ||
389              !request->param().skip_cancelled_check()) {
390     EXPECT_FALSE(context->IsCancelled());
391   }
392 
393   if (request->has_param() && request->param().echo_metadata_initially()) {
394     const std::multimap<grpc::string_ref, grpc::string_ref>& client_metadata =
395         context->client_metadata();
396     for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator
397              iter = client_metadata.begin();
398          iter != client_metadata.end(); ++iter) {
399       context->AddInitialMetadata(ToString(iter->first),
400                                   ToString(iter->second));
401     }
402     controller->SendInitialMetadata([](bool ok) { EXPECT_TRUE(ok); });
403   }
404 
405   if (request->has_param() && request->param().echo_metadata()) {
406     const std::multimap<grpc::string_ref, grpc::string_ref>& client_metadata =
407         context->client_metadata();
408     for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator
409              iter = client_metadata.begin();
410          iter != client_metadata.end(); ++iter) {
411       context->AddTrailingMetadata(ToString(iter->first),
412                                    ToString(iter->second));
413     }
414     // Terminate rpc with error and debug info in trailer.
415     if (request->param().debug_info().stack_entries_size() ||
416         !request->param().debug_info().detail().empty()) {
417       grpc::string serialized_debug_info =
418           request->param().debug_info().SerializeAsString();
419       context->AddTrailingMetadata(kDebugInfoTrailerKey, serialized_debug_info);
420       controller->Finish(Status::CANCELLED);
421       return;
422     }
423   }
424   if (request->has_param() &&
425       (request->param().expected_client_identity().length() > 0 ||
426        request->param().check_auth_context())) {
427     CheckServerAuthContext(context,
428                            request->param().expected_transport_security_type(),
429                            request->param().expected_client_identity());
430   }
431   if (request->has_param() && request->param().response_message_length() > 0) {
432     response->set_message(
433         grpc::string(request->param().response_message_length(), '\0'));
434   }
435   if (request->has_param() && request->param().echo_peer()) {
436     response->mutable_param()->set_peer(context->peer());
437   }
438   controller->Finish(Status::OK);
439 }
440 
441 // Unimplemented is left unimplemented to test the returned error.
442 
RequestStream(ServerContext * context,ServerReader<EchoRequest> * reader,EchoResponse * response)443 Status TestServiceImpl::RequestStream(ServerContext* context,
444                                       ServerReader<EchoRequest>* reader,
445                                       EchoResponse* response) {
446   // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
447   // the server by calling ServerContext::TryCancel() depending on the value:
448   //   CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server reads
449   //   any message from the client
450   //   CANCEL_DURING_PROCESSING: The RPC is cancelled while the server is
451   //   reading messages from the client
452   //   CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads
453   //   all the messages from the client
454   int server_try_cancel = GetIntValueFromMetadata(
455       kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
456 
457   EchoRequest request;
458   response->set_message("");
459 
460   if (server_try_cancel == CANCEL_BEFORE_PROCESSING) {
461     ServerTryCancel(context);
462     return Status::CANCELLED;
463   }
464 
465   std::thread* server_try_cancel_thd = nullptr;
466   if (server_try_cancel == CANCEL_DURING_PROCESSING) {
467     server_try_cancel_thd =
468         new std::thread([context] { ServerTryCancel(context); });
469   }
470 
471   int num_msgs_read = 0;
472   while (reader->Read(&request)) {
473     response->mutable_message()->append(request.message());
474   }
475   gpr_log(GPR_INFO, "Read: %d messages", num_msgs_read);
476 
477   if (server_try_cancel_thd != nullptr) {
478     server_try_cancel_thd->join();
479     delete server_try_cancel_thd;
480     return Status::CANCELLED;
481   }
482 
483   if (server_try_cancel == CANCEL_AFTER_PROCESSING) {
484     ServerTryCancel(context);
485     return Status::CANCELLED;
486   }
487 
488   return Status::OK;
489 }
490 
491 // Return 'kNumResponseStreamMsgs' messages.
492 // TODO(yangg) make it generic by adding a parameter into EchoRequest
ResponseStream(ServerContext * context,const EchoRequest * request,ServerWriter<EchoResponse> * writer)493 Status TestServiceImpl::ResponseStream(ServerContext* context,
494                                        const EchoRequest* request,
495                                        ServerWriter<EchoResponse>* writer) {
496   // If server_try_cancel is set in the metadata, the RPC is cancelled by the
497   // server by calling ServerContext::TryCancel() depending on the value:
498   //   CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server writes
499   //   any messages to the client
500   //   CANCEL_DURING_PROCESSING: The RPC is cancelled while the server is
501   //   writing messages to the client
502   //   CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server writes
503   //   all the messages to the client
504   int server_try_cancel = GetIntValueFromMetadata(
505       kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
506 
507   int server_coalescing_api = GetIntValueFromMetadata(
508       kServerUseCoalescingApi, context->client_metadata(), 0);
509 
510   int server_responses_to_send = GetIntValueFromMetadata(
511       kServerResponseStreamsToSend, context->client_metadata(),
512       kServerDefaultResponseStreamsToSend);
513 
514   if (server_try_cancel == CANCEL_BEFORE_PROCESSING) {
515     ServerTryCancel(context);
516     return Status::CANCELLED;
517   }
518 
519   EchoResponse response;
520   std::thread* server_try_cancel_thd = nullptr;
521   if (server_try_cancel == CANCEL_DURING_PROCESSING) {
522     server_try_cancel_thd =
523         new std::thread([context] { ServerTryCancel(context); });
524   }
525 
526   for (int i = 0; i < server_responses_to_send; i++) {
527     response.set_message(request->message() + grpc::to_string(i));
528     if (i == server_responses_to_send - 1 && server_coalescing_api != 0) {
529       writer->WriteLast(response, WriteOptions());
530     } else {
531       writer->Write(response);
532     }
533   }
534 
535   if (server_try_cancel_thd != nullptr) {
536     server_try_cancel_thd->join();
537     delete server_try_cancel_thd;
538     return Status::CANCELLED;
539   }
540 
541   if (server_try_cancel == CANCEL_AFTER_PROCESSING) {
542     ServerTryCancel(context);
543     return Status::CANCELLED;
544   }
545 
546   return Status::OK;
547 }
548 
BidiStream(ServerContext * context,ServerReaderWriter<EchoResponse,EchoRequest> * stream)549 Status TestServiceImpl::BidiStream(
550     ServerContext* context,
551     ServerReaderWriter<EchoResponse, EchoRequest>* stream) {
552   // If server_try_cancel is set in the metadata, the RPC is cancelled by the
553   // server by calling ServerContext::TryCancel() depending on the value:
554   //   CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server reads/
555   //   writes any messages from/to the client
556   //   CANCEL_DURING_PROCESSING: The RPC is cancelled while the server is
557   //   reading/writing messages from/to the client
558   //   CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server
559   //   reads/writes all messages from/to the client
560   int server_try_cancel = GetIntValueFromMetadata(
561       kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
562 
563   EchoRequest request;
564   EchoResponse response;
565 
566   if (server_try_cancel == CANCEL_BEFORE_PROCESSING) {
567     ServerTryCancel(context);
568     return Status::CANCELLED;
569   }
570 
571   std::thread* server_try_cancel_thd = nullptr;
572   if (server_try_cancel == CANCEL_DURING_PROCESSING) {
573     server_try_cancel_thd =
574         new std::thread([context] { ServerTryCancel(context); });
575   }
576 
577   // kServerFinishAfterNReads suggests after how many reads, the server should
578   // write the last message and send status (coalesced using WriteLast)
579   int server_write_last = GetIntValueFromMetadata(
580       kServerFinishAfterNReads, context->client_metadata(), 0);
581 
582   int read_counts = 0;
583   while (stream->Read(&request)) {
584     read_counts++;
585     gpr_log(GPR_INFO, "recv msg %s", request.message().c_str());
586     response.set_message(request.message());
587     if (read_counts == server_write_last) {
588       stream->WriteLast(response, WriteOptions());
589     } else {
590       stream->Write(response);
591     }
592   }
593 
594   if (server_try_cancel_thd != nullptr) {
595     server_try_cancel_thd->join();
596     delete server_try_cancel_thd;
597     return Status::CANCELLED;
598   }
599 
600   if (server_try_cancel == CANCEL_AFTER_PROCESSING) {
601     ServerTryCancel(context);
602     return Status::CANCELLED;
603   }
604 
605   return Status::OK;
606 }
607 
608 experimental::ServerReadReactor<EchoRequest, EchoResponse>*
RequestStream()609 CallbackTestServiceImpl::RequestStream() {
610   class Reactor : public ::grpc::experimental::ServerReadReactor<EchoRequest,
611                                                                  EchoResponse> {
612    public:
613     Reactor() {}
614     void OnStarted(ServerContext* context, EchoResponse* response) override {
615       // Assign ctx_ and response_ as late as possible to increase likelihood of
616       // catching any races
617 
618       // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
619       // the server by calling ServerContext::TryCancel() depending on the
620       // value:
621       //   CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server
622       //   reads any message from the client CANCEL_DURING_PROCESSING: The RPC
623       //   is cancelled while the server is reading messages from the client
624       //   CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads
625       //   all the messages from the client
626       server_try_cancel_ = GetIntValueFromMetadata(
627           kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
628 
629       response->set_message("");
630 
631       if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) {
632         ServerTryCancelNonblocking(context);
633         ctx_ = context;
634       } else {
635         if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
636           context->TryCancel();
637           // Don't wait for it here
638         }
639         ctx_ = context;
640         response_ = response;
641         StartRead(&request_);
642       }
643 
644       on_started_done_ = true;
645     }
646     void OnDone() override { delete this; }
647     void OnCancel() override {
648       EXPECT_TRUE(on_started_done_);
649       EXPECT_TRUE(ctx_->IsCancelled());
650       FinishOnce(Status::CANCELLED);
651     }
652     void OnReadDone(bool ok) override {
653       if (ok) {
654         response_->mutable_message()->append(request_.message());
655         num_msgs_read_++;
656         StartRead(&request_);
657       } else {
658         gpr_log(GPR_INFO, "Read: %d messages", num_msgs_read_);
659 
660         if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
661           // Let OnCancel recover this
662           return;
663         }
664         if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) {
665           ServerTryCancelNonblocking(ctx_);
666           return;
667         }
668         FinishOnce(Status::OK);
669       }
670     }
671 
672    private:
673     void FinishOnce(const Status& s) {
674       std::lock_guard<std::mutex> l(finish_mu_);
675       if (!finished_) {
676         Finish(s);
677         finished_ = true;
678       }
679     }
680 
681     ServerContext* ctx_;
682     EchoResponse* response_;
683     EchoRequest request_;
684     int num_msgs_read_{0};
685     int server_try_cancel_;
686     std::mutex finish_mu_;
687     bool finished_{false};
688     bool on_started_done_{false};
689   };
690 
691   return new Reactor;
692 }
693 
694 // Return 'kNumResponseStreamMsgs' messages.
695 // TODO(yangg) make it generic by adding a parameter into EchoRequest
696 experimental::ServerWriteReactor<EchoRequest, EchoResponse>*
ResponseStream()697 CallbackTestServiceImpl::ResponseStream() {
698   class Reactor
699       : public ::grpc::experimental::ServerWriteReactor<EchoRequest,
700                                                         EchoResponse> {
701    public:
702     Reactor() {}
703     void OnStarted(ServerContext* context,
704                    const EchoRequest* request) override {
705       // Assign ctx_ and request_ as late as possible to increase likelihood of
706       // catching any races
707 
708       // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
709       // the server by calling ServerContext::TryCancel() depending on the
710       // value:
711       //   CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server
712       //   reads any message from the client CANCEL_DURING_PROCESSING: The RPC
713       //   is cancelled while the server is reading messages from the client
714       //   CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads
715       //   all the messages from the client
716       server_try_cancel_ = GetIntValueFromMetadata(
717           kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
718       server_coalescing_api_ = GetIntValueFromMetadata(
719           kServerUseCoalescingApi, context->client_metadata(), 0);
720       server_responses_to_send_ = GetIntValueFromMetadata(
721           kServerResponseStreamsToSend, context->client_metadata(),
722           kServerDefaultResponseStreamsToSend);
723       if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) {
724         ServerTryCancelNonblocking(context);
725         ctx_ = context;
726       } else {
727         if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
728           context->TryCancel();
729         }
730         ctx_ = context;
731         request_ = request;
732         if (num_msgs_sent_ < server_responses_to_send_) {
733           NextWrite();
734         }
735       }
736       on_started_done_ = true;
737     }
738     void OnDone() override { delete this; }
739     void OnCancel() override {
740       EXPECT_TRUE(on_started_done_);
741       EXPECT_TRUE(ctx_->IsCancelled());
742       FinishOnce(Status::CANCELLED);
743     }
744     void OnWriteDone(bool ok) override {
745       if (num_msgs_sent_ < server_responses_to_send_) {
746         NextWrite();
747       } else if (server_coalescing_api_ != 0) {
748         // We would have already done Finish just after the WriteLast
749       } else if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
750         // Let OnCancel recover this
751       } else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) {
752         ServerTryCancelNonblocking(ctx_);
753       } else {
754         FinishOnce(Status::OK);
755       }
756     }
757 
758    private:
759     void FinishOnce(const Status& s) {
760       std::lock_guard<std::mutex> l(finish_mu_);
761       if (!finished_) {
762         Finish(s);
763         finished_ = true;
764       }
765     }
766 
767     void NextWrite() {
768       response_.set_message(request_->message() +
769                             grpc::to_string(num_msgs_sent_));
770       if (num_msgs_sent_ == server_responses_to_send_ - 1 &&
771           server_coalescing_api_ != 0) {
772         num_msgs_sent_++;
773         StartWriteLast(&response_, WriteOptions());
774         // If we use WriteLast, we shouldn't wait before attempting Finish
775         FinishOnce(Status::OK);
776       } else {
777         num_msgs_sent_++;
778         StartWrite(&response_);
779       }
780     }
781     ServerContext* ctx_;
782     const EchoRequest* request_;
783     EchoResponse response_;
784     int num_msgs_sent_{0};
785     int server_try_cancel_;
786     int server_coalescing_api_;
787     int server_responses_to_send_;
788     std::mutex finish_mu_;
789     bool finished_{false};
790     bool on_started_done_{false};
791   };
792   return new Reactor;
793 }
794 
795 experimental::ServerBidiReactor<EchoRequest, EchoResponse>*
BidiStream()796 CallbackTestServiceImpl::BidiStream() {
797   class Reactor : public ::grpc::experimental::ServerBidiReactor<EchoRequest,
798                                                                  EchoResponse> {
799    public:
800     Reactor() {}
801     void OnStarted(ServerContext* context) override {
802       // Assign ctx_ as late as possible to increase likelihood of catching any
803       // races
804 
805       // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
806       // the server by calling ServerContext::TryCancel() depending on the
807       // value:
808       //   CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server
809       //   reads any message from the client CANCEL_DURING_PROCESSING: The RPC
810       //   is cancelled while the server is reading messages from the client
811       //   CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads
812       //   all the messages from the client
813       server_try_cancel_ = GetIntValueFromMetadata(
814           kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
815       server_write_last_ = GetIntValueFromMetadata(
816           kServerFinishAfterNReads, context->client_metadata(), 0);
817       if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) {
818         ServerTryCancelNonblocking(context);
819         ctx_ = context;
820       } else {
821         if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
822           context->TryCancel();
823         }
824         ctx_ = context;
825         StartRead(&request_);
826       }
827       on_started_done_ = true;
828     }
829     void OnDone() override { delete this; }
830     void OnCancel() override {
831       EXPECT_TRUE(on_started_done_);
832       EXPECT_TRUE(ctx_->IsCancelled());
833       FinishOnce(Status::CANCELLED);
834     }
835     void OnReadDone(bool ok) override {
836       if (ok) {
837         num_msgs_read_++;
838         gpr_log(GPR_INFO, "recv msg %s", request_.message().c_str());
839         response_.set_message(request_.message());
840         if (num_msgs_read_ == server_write_last_) {
841           StartWriteLast(&response_, WriteOptions());
842           // If we use WriteLast, we shouldn't wait before attempting Finish
843         } else {
844           StartWrite(&response_);
845           return;
846         }
847       }
848 
849       if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
850         // Let OnCancel handle this
851       } else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) {
852         ServerTryCancelNonblocking(ctx_);
853       } else {
854         FinishOnce(Status::OK);
855       }
856     }
857     void OnWriteDone(bool ok) override {
858       std::lock_guard<std::mutex> l(finish_mu_);
859       if (!finished_) {
860         StartRead(&request_);
861       }
862     }
863 
864    private:
865     void FinishOnce(const Status& s) {
866       std::lock_guard<std::mutex> l(finish_mu_);
867       if (!finished_) {
868         Finish(s);
869         finished_ = true;
870       }
871     }
872 
873     ServerContext* ctx_;
874     EchoRequest request_;
875     EchoResponse response_;
876     int num_msgs_read_{0};
877     int server_try_cancel_;
878     int server_write_last_;
879     std::mutex finish_mu_;
880     bool finished_{false};
881     bool on_started_done_{false};
882   };
883 
884   return new Reactor;
885 }
886 
887 }  // namespace testing
888 }  // namespace grpc
889