1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
 * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 module async_test;
20 
21 import core.atomic;
22 import core.sync.condition : Condition;
23 import core.sync.mutex : Mutex;
24 import core.thread : dur, Thread, ThreadGroup;
25 import std.conv : text;
26 import std.datetime;
27 import std.getopt;
28 import std.exception : collectException, enforce;
29 import std.parallelism : TaskPool;
30 import std.stdio;
31 import std.string;
32 import std.variant : Variant;
33 import thrift.base;
34 import thrift.async.base;
35 import thrift.async.libevent;
36 import thrift.async.socket;
37 import thrift.async.ssl;
38 import thrift.codegen.async_client;
39 import thrift.codegen.async_client_pool;
40 import thrift.codegen.base;
41 import thrift.codegen.processor;
42 import thrift.protocol.base;
43 import thrift.protocol.binary;
44 import thrift.server.base;
45 import thrift.server.simple;
46 import thrift.server.transport.socket;
47 import thrift.server.transport.ssl;
48 import thrift.transport.base;
49 import thrift.transport.buffered;
50 import thrift.transport.ssl;
51 import thrift.util.cancellation;
52 
version (Posix) {
  import core.stdc.signal;
  import core.sys.posix.signal;

  // The SSL server may still write to a socket after the client on the other
  // end has disconnected (see the TSSLSocket docs). SIGPIPE is ignored for
  // the whole process so that such writes do not take the test down.
  shared static this() {
    // The previously installed handler is not needed; discard it explicitly.
    cast(void) signal(SIGPIPE, SIG_IGN);
  }
}
63 
/**
 * Thrift service exercised by this test.
 *
 * TServiceProcessor (server side, in ServerThread) and TAsyncClient (client
 * side, in ClientsThread) are instantiated directly with this interface.
 */
interface AsyncTest {
  /// The test handler returns `value` unchanged.
  string echo(string value);
  /// Like echo(), but the handler sleeps `milliseconds` before returning.
  string delayedEcho(string value, long milliseconds);

  /// The test handler always throws an AsyncTestException with `reason` set.
  void fail(string reason);
  /// Like fail(), but the handler sleeps `milliseconds` before throwing.
  void delayedFail(string reason, long milliseconds);

  // Extra Thrift metadata: fail() and delayedFail() declare
  // AsyncTestException as a possible exception (name "ate", field id 1).
  // NOTE(review): methods not listed here presumably fall back to codegen
  // defaults — confirm against thrift.codegen.base.
  enum methodMeta = [
    TMethodMeta("fail", [], [TExceptionMeta("ate", 1, "AsyncTestException")]),
    TMethodMeta("delayedFail", [], [TExceptionMeta("ate", 1, "AsyncTestException")])
  ];
  // Makes the module-level exception class reachable by the name used in
  // methodMeta from within this interface's scope (presumably how the codegen
  // resolves the "AsyncTestException" type string — TODO confirm).
  alias .AsyncTestException AsyncTestException;
}
77 
/**
 * Exception thrown by the test handler's fail() and delayedFail()
 * implementations and checked by the clients.
 *
 * TStructHelpers mixes in the Thrift struct boilerplate for the fields below
 * (see thrift.codegen.base for exactly what it generates).
 */
class AsyncTestException : TException {
  /// The reason string passed to fail()/delayedFail().
  string reason;
  mixin TStructHelpers!();
}
82 
/**
 * Entry point of the async client/server test.
 *
 * Starts `--managers` TLibeventAsyncManager instances and, for each of them,
 * `--servers-per-manager` TSimpleServer threads on consecutive ports starting
 * at `--port`. Every server is paired with a ClientsThread that talks to it
 * through the owning async manager. main() returns once all clients threads
 * have finished; the servers are then cancelled via scope(exit).
 *
 * Options (defaults in parentheses): --iterations (10), --managers (2),
 * --port (9090), --servers-per-manager (5), --ssl (false),
 * --threads-per-server (10), --trace (false).
 *
 * Throws: if the requested port range does not fit into a ushort. Failures
 * in the clients threads propagate through ThreadGroup.joinAll().
 */
void main(string[] args) {
  ushort port = 9090;
  ushort managerCount = 2;
  ushort serversPerManager = 5;
  ushort threadsPerServer = 10;
  uint iterations = 10;
  bool ssl;
  bool trace;

  getopt(args,
    "iterations", &iterations,
    "managers", &managerCount,
    "port", &port,
    "servers-per-manager", &serversPerManager,
    "ssl", &ssl,
    "threads-per-server", &threadsPerServer,
    "trace", &trace,
  );

  // Servers listen on consecutive ports starting at `port`. Reject ranges
  // that would otherwise silently wrap around in the ushort cast below.
  immutable ulong serverCount = cast(ulong) managerCount * serversPerManager;
  enforce(serverCount == 0 || port + serverCount - 1 <= ushort.max,
    text("Port range starting at ", port, " for ", serverCount,
      " servers exceeds ", ushort.max, "."));

  TTransportFactory clientTransportFactory;
  TSSLContext serverSSLContext;
  if (ssl) {
    auto clientSSLContext = new TSSLContext();
    with (clientSSLContext) {
      authenticate = true;
      ciphers = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";
      loadTrustedCertificates("../../../test/keys/CA.pem");
    }
    clientTransportFactory = new TAsyncSSLSocketFactory(clientSSLContext);

    serverSSLContext = new TSSLContext();
    with (serverSSLContext) {
      serverSide = true;
      ciphers = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";
      loadCertificate("../../../test/keys/server.crt");
      loadPrivateKey("../../../test/keys/server.key");
    }
  } else {
    clientTransportFactory = new TBufferedTransportFactory;
  }


  // Registered before the managers' scope(exit) below, so it runs after it
  // (scope guards execute in reverse order of declaration).
  auto serverCancel = new TCancellationOrigin;
  scope(exit) {
    writeln("Triggering server shutdown...");
    serverCancel.trigger();
    writeln("done.");
  }

  auto managers = new TLibeventAsyncManager[managerCount];
  scope (exit) foreach (ref m; managers) destroy(m);

  auto clientsThreads = new ThreadGroup;
  foreach (managerIndex, ref manager; managers) {
    manager = new TLibeventAsyncManager;
    foreach (serverIndex; 0 .. serversPerManager) {
      // Cannot wrap: the port range was validated above.
      auto currentPort = cast(ushort)
        (port + managerIndex * serversPerManager + serverIndex);

      // Start the server and wait until it is up and running. The mutex is
      // held while the thread is started, so preServe() cannot notify before
      // wait() has released it.
      // NOTE(review): wait() has no predicate, so a spurious wakeup would let
      // us continue before the server is serving.
      auto servingMutex = new Mutex;
      auto servingCondition = new Condition(servingMutex);
      auto handler = new PreServeNotifyHandler(servingMutex, servingCondition);
      synchronized (servingMutex) {
        (new ServerThread!TSimpleServer(currentPort, serverSSLContext, trace,
          serverCancel, handler)).start();
        servingCondition.wait();
      }

      // We only run the timing tests for the first server on each async
      // manager, so that we don't get spurious timing errors because of
      // ordering issues.
      auto runTimingTests = (serverIndex == 0);

      auto c = new ClientsThread(manager, currentPort, clientTransportFactory,
        threadsPerServer, iterations, runTimingTests, trace);
      clientsThreads.add(c);
      c.start();
    }
  }
  clientsThreads.joinAll();
}
165 
/**
 * Server-side implementation of the AsyncTest service.
 *
 * echo()/delayedEcho() return their argument, and fail()/delayedFail() throw
 * an AsyncTestException carrying the given reason. The delayed variants
 * first sleep on the calling thread for the given number of milliseconds.
 * If `trace` is set, every call is logged to stdout.
 */
class AsyncTestHandler : AsyncTest {
  this(bool trace) {
    trace_ = trace;
  }

  override string echo(string value) {
    if (trace_) writefln(`echo("%s")`, value);
    return value;
  }

  override string delayedEcho(string value, long milliseconds) {
    if (trace_) writef(`delayedEcho("%s", %s ms)... `, value, milliseconds);
    Thread.sleep(dur!"msecs"(milliseconds));
    if (trace_) writeln("returning.");

    return value;
  }

  override void fail(string reason) {
    if (trace_) writefln(`fail("%s")`, reason);
    throw makeException(reason);
  }

  override void delayedFail(string reason, long milliseconds) {
    if (trace_) writef(`delayedFail("%s", %s ms)... `, reason, milliseconds);
    Thread.sleep(dur!"msecs"(milliseconds));
    if (trace_) writeln("returning.");

    throw makeException(reason);
  }

private:
  /// Builds the exception thrown by fail() and delayedFail().
  static AsyncTestException makeException(string reason) {
    auto ate = new AsyncTestException;
    ate.reason = reason;
    return ate;
  }

  bool trace_;
}
205 
/**
 * Server event handler that wakes up every thread waiting on
 * `servingCondition` when the server calls preServe().
 *
 * All other TServerEventHandler callbacks are no-ops.
 */
class PreServeNotifyHandler : TServerEventHandler {
  this(Mutex servingMutex, Condition servingCondition) {
    servingCondition_ = servingCondition;
    servingMutex_ = servingMutex;
  }

  void preServe() {
    // Notify while holding the mutex the waiter is associated with.
    servingMutex_.lock();
    scope (exit) servingMutex_.unlock();
    servingCondition_.notifyAll();
  }

  Variant createContext(TProtocol input, TProtocol output) {
    Variant noContext;
    return noContext;
  }

  void deleteContext(Variant serverContext, TProtocol input, TProtocol output) {
    // Nothing to clean up; createContext() hands out no context.
  }

  void preProcess(Variant serverContext, TTransport transport) {
    // Nothing to do before processing a request.
  }

private:
  Condition servingCondition_;
  Mutex servingMutex_;
}
225 
/**
 * Thread that runs one Thrift server of type `ServerType` for the AsyncTest
 * service on `port`.
 *
 * An SSL server socket is used when `sslContext` is non-null, and a plain
 * TServerSocket otherwise. Transports are buffered and the binary protocol is
 * used. `eventHandler` is installed on the server, and `cancellation` is
 * passed to serve().
 */
class ServerThread(ServerType) : Thread {
  this(ushort port, TSSLContext sslContext, bool trace,
    TCancellation cancellation, TServerEventHandler eventHandler
  ) {
    port_ = port;
    sslContext_ = sslContext;
    trace_ = trace;
    cancellation_ = cancellation;
    eventHandler_ = eventHandler;

    super(&run);
  }

  void run() {
    TServerSocket listener = sslContext_
      ? new TSSLServerSocket(port_, sslContext_)
      : new TServerSocket(port_);

    auto server = new ServerType(
      new TServiceProcessor!AsyncTest(new AsyncTestHandler(trace_)),
      listener,
      new TBufferedTransportFactory,
      new TBinaryProtocolFactory!()
    );
    server.eventHandler = eventHandler_;

    writefln("Starting server on port %s...", port_);
    server.serve(cancellation_);
    writefln("Server thread on port %s done.", port_);
  }

private:
  TServerEventHandler eventHandler_;
  TCancellation cancellation_;
  TSSLContext sslContext_;
  bool trace_;
  ushort port_;
}
266 
/**
 * Drives the server on `port` through a TAsyncSocket bound to `manager`.
 *
 * run() first spawns `threads` client threads. They share a single
 * TAsyncClient, and each performs `iterations` rounds of echo() and fail()
 * calls, checking the results. If `runTimingTests` is set, it then reconnects
 * `iterations` times. Each time it checks three things: delayedEcho() and
 * delayedFail() results are not available after 1 µs but are within 200 ms,
 * and a 50 ms receive timeout on a 100 ms call fails with TIMED_OUT.
 */
class ClientsThread : Thread {
  this(TAsyncSocketManager manager, ushort port, TTransportFactory tf,
    ushort threads, uint iterations, bool runTimingTests, bool trace
  ) {
    manager_ = manager;
    port_ = port;
    transportFactory_ = tf;
    threads_ = threads;
    iterations_ = iterations;
    runTimingTests_ = runTimingTests;
    trace_ = trace;
    super(&run);
  }

  void run() {
    auto transport = new TAsyncSocket(manager_, "localhost", port_);

    {
      auto client = new TAsyncClient!AsyncTest(
        transport,
        transportFactory_,
        new TBinaryProtocolFactory!()
      );
      transport.open();
      // Close the connection even if a client thread failed and joinAll()
      // rethrows its exception.
      scope (exit) transport.close();

      auto clientThreads = new ThreadGroup;
      foreach (clientId; 0 .. threads_) {
        // The outer delegate is invoked immediately so that each thread body
        // captures its own copy (`c`) of the loop variable.
        clientThreads.create({
          auto c = clientId;
          return {
            foreach (i; 0 .. iterations_) {
              immutable id = text(port_, ":", c, ":", i);

              {
                if (trace_) writefln(`Calling echo("%s")... `, id);
                auto a = client.echo(id);
                enforce(a == id);
                if (trace_) writefln(`echo("%s") done.`, id);
              }

              {
                if (trace_) writefln(`Calling fail("%s")... `, id);
                auto a = cast(AsyncTestException)collectException(client.fail(id).waitGet());
                enforce(a && a.reason == id);
                if (trace_) writefln(`fail("%s") done.`, id);
              }
            }
          };
        }());
      }
      clientThreads.joinAll();
    }

    if (runTimingTests_) {
      auto client = new TAsyncClient!AsyncTest(
        transport,
        transportFactory_,
        new TBinaryProtocolFactory!TBufferedTransport
      );

      // Temporarily redirect error logs to stdout, as SSL errors on the server
      // side are expected when the client terminates abruptly (as is the case
      // in the timeout test).
      auto oldErrorLogSink = g_errorLogSink;
      g_errorLogSink = g_infoLogSink;
      scope (exit) g_errorLogSink = oldErrorLogSink;

      foreach (i; 0 .. iterations_) {
        transport.open();
        // Per-iteration connection; closed even if one of the checks throws.
        scope (exit) transport.close();

        immutable id = text(port_, ":", i);

        {
          if (trace_) writefln(`Calling delayedEcho("%s", 100 ms)...`, id);
          auto a = client.delayedEcho(id, 100);
          enforce(!a.completion.wait(dur!"usecs"(1)),
            text("wait() succeeded early (", a.get(), ", ", id, ")."));
          enforce(!a.completion.wait(dur!"usecs"(1)),
            text("wait() succeeded early (", a.get(), ", ", id, ")."));
          enforce(a.completion.wait(dur!"msecs"(200)),
            text("wait() didn't succeed as expected (", id, ")."));
          enforce(a.get() == id);
          if (trace_) writefln(`... delayedEcho("%s") done.`, id);
        }

        {
          if (trace_) writefln(`Calling delayedFail("%s", 100 ms)... `, id);
          auto a = client.delayedFail(id, 100);
          enforce(!a.completion.wait(dur!"usecs"(1)),
            text("wait() succeeded early (", id, ", ", collectException(a.get()), ")."));
          enforce(!a.completion.wait(dur!"usecs"(1)),
            text("wait() succeeded early (", id, ", ", collectException(a.get()), ")."));
          enforce(a.completion.wait(dur!"msecs"(200)),
            text("wait() didn't succeed as expected (", id, ")."));
          auto e = cast(AsyncTestException)collectException(a.get());
          enforce(e && e.reason == id);
          if (trace_) writefln(`... delayedFail("%s") done.`, id);
        }

        {
          transport.recvTimeout = dur!"msecs"(50);

          if (trace_) write(`Calling delayedEcho("socketTimeout", 100 ms)... `);
          auto a = client.delayedEcho("socketTimeout", 100);
          auto e = cast(TTransportException)collectException(a.waitGet());
          enforce(e, text("Operation didn't fail as expected (", id, ")."));
          enforce(e.type == TTransportException.Type.TIMED_OUT,
            text("Wrong timeout exception type (", id, "): ", e));
          if (trace_) writeln(`timed out as expected.`);

          // Wait until the server thread has reset before the next iteration.
          Thread.sleep(dur!"msecs"(50));
          transport.recvTimeout = dur!"hnsecs"(0);
        }
      }
    }

    writefln("Clients thread for port %s done.", port_);
  }

private:
  TAsyncSocketManager manager_;
  ushort port_;
  TTransportFactory transportFactory_;
  ushort threads_;
  uint iterations_;
  bool runTimingTests_;
  bool trace_;
}
397