1 /*
2  *
3  * Copyright 2018 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include <condition_variable>
20 
21 #include <grpcpp/channel.h>
22 
23 #include "src/proto/grpc/testing/echo.grpc.pb.h"
24 #include "test/cpp/util/string_ref_helper.h"
25 
26 #include <gtest/gtest.h>
27 
28 namespace grpc {
29 namespace testing {
30 /* This interceptor does nothing. Just keeps a global count on the number of
31  * times it was invoked. */
32 class DummyInterceptor : public experimental::Interceptor {
33  public:
DummyInterceptor()34   DummyInterceptor() {}
35 
Intercept(experimental::InterceptorBatchMethods * methods)36   void Intercept(experimental::InterceptorBatchMethods* methods) override {
37     if (methods->QueryInterceptionHookPoint(
38             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
39       num_times_run_++;
40     } else if (methods->QueryInterceptionHookPoint(
41                    experimental::InterceptionHookPoints::
42                        POST_RECV_INITIAL_METADATA)) {
43       num_times_run_reverse_++;
44     } else if (methods->QueryInterceptionHookPoint(
45                    experimental::InterceptionHookPoints::PRE_SEND_CANCEL)) {
46       num_times_cancel_++;
47     }
48     methods->Proceed();
49   }
50 
Reset()51   static void Reset() {
52     num_times_run_.store(0);
53     num_times_run_reverse_.store(0);
54     num_times_cancel_.store(0);
55   }
56 
GetNumTimesRun()57   static int GetNumTimesRun() {
58     EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load());
59     return num_times_run_.load();
60   }
61 
GetNumTimesCancel()62   static int GetNumTimesCancel() { return num_times_cancel_.load(); }
63 
64  private:
65   static std::atomic<int> num_times_run_;
66   static std::atomic<int> num_times_run_reverse_;
67   static std::atomic<int> num_times_cancel_;
68 };
69 
70 class DummyInterceptorFactory
71     : public experimental::ClientInterceptorFactoryInterface,
72       public experimental::ServerInterceptorFactoryInterface {
73  public:
CreateClientInterceptor(experimental::ClientRpcInfo *)74   experimental::Interceptor* CreateClientInterceptor(
75       experimental::ClientRpcInfo* /*info*/) override {
76     return new DummyInterceptor();
77   }
78 
CreateServerInterceptor(experimental::ServerRpcInfo *)79   experimental::Interceptor* CreateServerInterceptor(
80       experimental::ServerRpcInfo* /*info*/) override {
81     return new DummyInterceptor();
82   }
83 };
84 
85 /* This interceptor factory returns nullptr on interceptor creation */
86 class NullInterceptorFactory
87     : public experimental::ClientInterceptorFactoryInterface,
88       public experimental::ServerInterceptorFactoryInterface {
89  public:
CreateClientInterceptor(experimental::ClientRpcInfo *)90   experimental::Interceptor* CreateClientInterceptor(
91       experimental::ClientRpcInfo* /*info*/) override {
92     return nullptr;
93   }
94 
CreateServerInterceptor(experimental::ServerRpcInfo *)95   experimental::Interceptor* CreateServerInterceptor(
96       experimental::ServerRpcInfo* /*info*/) override {
97     return nullptr;
98   }
99 };
100 
101 class EchoTestServiceStreamingImpl : public EchoTestService::Service {
102  public:
~EchoTestServiceStreamingImpl()103   ~EchoTestServiceStreamingImpl() override {}
104 
Echo(ServerContext * context,const EchoRequest * request,EchoResponse * response)105   Status Echo(ServerContext* context, const EchoRequest* request,
106               EchoResponse* response) override {
107     auto client_metadata = context->client_metadata();
108     for (const auto& pair : client_metadata) {
109       context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
110     }
111     response->set_message(request->message());
112     return Status::OK;
113   }
114 
BidiStream(ServerContext * context,grpc::ServerReaderWriter<EchoResponse,EchoRequest> * stream)115   Status BidiStream(
116       ServerContext* context,
117       grpc::ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
118     EchoRequest req;
119     EchoResponse resp;
120     auto client_metadata = context->client_metadata();
121     for (const auto& pair : client_metadata) {
122       context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
123     }
124 
125     while (stream->Read(&req)) {
126       resp.set_message(req.message());
127       EXPECT_TRUE(stream->Write(resp, grpc::WriteOptions()));
128     }
129     return Status::OK;
130   }
131 
RequestStream(ServerContext * context,ServerReader<EchoRequest> * reader,EchoResponse * resp)132   Status RequestStream(ServerContext* context,
133                        ServerReader<EchoRequest>* reader,
134                        EchoResponse* resp) override {
135     auto client_metadata = context->client_metadata();
136     for (const auto& pair : client_metadata) {
137       context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
138     }
139 
140     EchoRequest req;
141     string response_str = "";
142     while (reader->Read(&req)) {
143       response_str += req.message();
144     }
145     resp->set_message(response_str);
146     return Status::OK;
147   }
148 
ResponseStream(ServerContext * context,const EchoRequest * req,ServerWriter<EchoResponse> * writer)149   Status ResponseStream(ServerContext* context, const EchoRequest* req,
150                         ServerWriter<EchoResponse>* writer) override {
151     auto client_metadata = context->client_metadata();
152     for (const auto& pair : client_metadata) {
153       context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
154     }
155 
156     EchoResponse resp;
157     resp.set_message(req->message());
158     for (int i = 0; i < 10; i++) {
159       EXPECT_TRUE(writer->Write(resp));
160     }
161     return Status::OK;
162   }
163 };
164 
// Message count for the streaming tests (presumably used by the
// Make*StreamingCall helpers defined elsewhere -- confirm).
constexpr int kNumStreamingMessages{10};
166 
// RPC driver helpers. They are only declared here; their definitions are in
// another translation unit (presumably the matching interceptors_util.cc --
// confirm). Judging only by the names, each issues one kind of RPC (unary,
// client-streaming, server-streaming, bidi) over `channel` using a particular
// client API (sync, async CompletionQueue, or callback). What each one sends
// and checks cannot be seen from this header.
void MakeCall(const std::shared_ptr<Channel>& channel);

void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeAsyncCQCall(const std::shared_ptr<Channel>& channel);

void MakeAsyncCQClientStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeAsyncCQServerStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeAsyncCQBidiStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeCallbackCall(const std::shared_ptr<Channel>& channel);

// Metadata lookup helpers. There are two overloads: one for string_ref
// multimaps (the client_metadata() shape used above) and one for
// std::string multimaps. Presumably each returns whether `map` holds an entry
// `key` -> `value`; the definitions are not visible here, so confirm.
bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
                   const string& key, const string& value);

bool CheckMetadata(const std::multimap<std::string, std::string>& map,
                   const string& key, const string& value);

// Returns a list of owned client interceptor factories. Judging by the name,
// they are presumably DummyInterceptorFactory instances; how many is not
// visible from this header.
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors();
193 
tag(int i)194 inline void* tag(int i) { return (void*)static_cast<intptr_t>(i); }
detag(void * p)195 inline int detag(void* p) {
196   return static_cast<int>(reinterpret_cast<intptr_t>(p));
197 }
198 
199 class Verifier {
200  public:
Verifier()201   Verifier() : lambda_run_(false) {}
202   // Expect sets the expected ok value for a specific tag
Expect(int i,bool expect_ok)203   Verifier& Expect(int i, bool expect_ok) {
204     return ExpectUnless(i, expect_ok, false);
205   }
206   // ExpectUnless sets the expected ok value for a specific tag
207   // unless the tag was already marked seen (as a result of ExpectMaybe)
ExpectUnless(int i,bool expect_ok,bool seen)208   Verifier& ExpectUnless(int i, bool expect_ok, bool seen) {
209     if (!seen) {
210       expectations_[tag(i)] = expect_ok;
211     }
212     return *this;
213   }
214   // ExpectMaybe sets the expected ok value for a specific tag, but does not
215   // require it to appear
216   // If it does, sets *seen to true
ExpectMaybe(int i,bool expect_ok,bool * seen)217   Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) {
218     if (!*seen) {
219       maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen};
220     }
221     return *this;
222   }
223 
224   // Next waits for 1 async tag to complete, checks its
225   // expectations, and returns the tag
Next(CompletionQueue * cq,bool ignore_ok)226   int Next(CompletionQueue* cq, bool ignore_ok) {
227     bool ok;
228     void* got_tag;
229     EXPECT_TRUE(cq->Next(&got_tag, &ok));
230     GotTag(got_tag, ok, ignore_ok);
231     return detag(got_tag);
232   }
233 
234   template <typename T>
DoOnceThenAsyncNext(CompletionQueue * cq,void ** got_tag,bool * ok,T deadline,std::function<void (void)> lambda)235   CompletionQueue::NextStatus DoOnceThenAsyncNext(
236       CompletionQueue* cq, void** got_tag, bool* ok, T deadline,
237       std::function<void(void)> lambda) {
238     if (lambda_run_) {
239       return cq->AsyncNext(got_tag, ok, deadline);
240     } else {
241       lambda_run_ = true;
242       return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline);
243     }
244   }
245 
246   // Verify keeps calling Next until all currently set
247   // expected tags are complete
Verify(CompletionQueue * cq)248   void Verify(CompletionQueue* cq) { Verify(cq, false); }
249 
250   // This version of Verify allows optionally ignoring the
251   // outcome of the expectation
Verify(CompletionQueue * cq,bool ignore_ok)252   void Verify(CompletionQueue* cq, bool ignore_ok) {
253     GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty());
254     while (!expectations_.empty()) {
255       Next(cq, ignore_ok);
256     }
257   }
258 
259   // This version of Verify stops after a certain deadline, and uses the
260   // DoThenAsyncNext API
261   // to call the lambda
Verify(CompletionQueue * cq,std::chrono::system_clock::time_point deadline,const std::function<void (void)> & lambda)262   void Verify(CompletionQueue* cq,
263               std::chrono::system_clock::time_point deadline,
264               const std::function<void(void)>& lambda) {
265     if (expectations_.empty()) {
266       bool ok;
267       void* got_tag;
268       EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
269                 CompletionQueue::TIMEOUT);
270     } else {
271       while (!expectations_.empty()) {
272         bool ok;
273         void* got_tag;
274         EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
275                   CompletionQueue::GOT_EVENT);
276         GotTag(got_tag, ok, false);
277       }
278     }
279   }
280 
281  private:
GotTag(void * got_tag,bool ok,bool ignore_ok)282   void GotTag(void* got_tag, bool ok, bool ignore_ok) {
283     auto it = expectations_.find(got_tag);
284     if (it != expectations_.end()) {
285       if (!ignore_ok) {
286         EXPECT_EQ(it->second, ok);
287       }
288       expectations_.erase(it);
289     } else {
290       auto it2 = maybe_expectations_.find(got_tag);
291       if (it2 != maybe_expectations_.end()) {
292         if (it2->second.seen != nullptr) {
293           EXPECT_FALSE(*it2->second.seen);
294           *it2->second.seen = true;
295         }
296         if (!ignore_ok) {
297           EXPECT_EQ(it2->second.ok, ok);
298         }
299       } else {
300         gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag);
301         abort();
302       }
303     }
304   }
305 
306   struct MaybeExpect {
307     bool ok;
308     bool* seen;
309   };
310 
311   std::map<void*, bool> expectations_;
312   std::map<void*, MaybeExpect> maybe_expectations_;
313   bool lambda_run_;
314 };
315 
316 }  // namespace testing
317 }  // namespace grpc
318