1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 module thrift_test_client;
20 
21 import std.conv;
22 import std.datetime;
23 import std.exception : enforce;
24 import std.getopt;
25 import std.stdio;
26 import std.string;
27 import std.traits;
28 import thrift.base;
29 import thrift.codegen.client;
30 import thrift.protocol.base;
31 import thrift.protocol.binary;
32 import thrift.protocol.compact;
33 import thrift.protocol.json;
34 import thrift.transport.base;
35 import thrift.transport.buffered;
36 import thrift.transport.framed;
37 import thrift.transport.http;
38 import thrift.transport.socket;
39 import thrift.transport.ssl;
40 import thrift.util.hashset;
41 
42 import thrift_test_common;
43 import thrift.test.ThriftTest;
44 import thrift.test.ThriftTest_types;
45 
/// Transport layering placed on top of the socket; chosen with the
/// `--transport` command line option and dispatched on in main().
enum TransportType {
  buffered, /// Socket wrapped in a TBufferedTransport.
  framed,   /// Socket wrapped in a TFramedTransport.
  http,     /// Socket wrapped in a TClientHttpTransport.
  raw,      /// Socket used directly, without a wrapping transport.
}
52 
/**
 * Wraps a transport in the protocol implementation selected by type.
 *
 * Params:
 *   trans = Transport the new protocol is constructed on.
 *   type = Which of the supported protocols to instantiate.
 *
 * Returns: The newly created protocol, typed as TProtocol.
 */
TProtocol createProtocol(T)(T trans, ProtocolType type) {
  TProtocol result;
  // `final switch` makes the compiler reject this function if a ProtocolType
  // member is ever added without a matching case here.
  final switch (type) {
    case ProtocolType.binary:
      result = tBinaryProtocol(trans);
      break;
    case ProtocolType.compact:
      result = tCompactProtocol(trans);
      break;
    case ProtocolType.json:
      result = tJsonProtocol(trans);
      break;
  }
  return result;
}
63 
/**
 * Entry point of the ThriftTest client.
 *
 * Parses the command line, builds the socket/transport/protocol stack, and
 * then runs the sequence of ThriftTest calls `numTests` times. Each reply is
 * checked with enforce(); a failed check or an unexpected exception
 * propagates out of main. When more than one iteration is run, the minimum,
 * maximum and average iteration times are printed at the end.
 *
 * Recognized options (defaults in parentheses):
 *   --numTests|-n  number of iterations (1)
 *   --protocol     ProtocolType member name
 *   --ssl          use a TSSLSocket instead of a plain TSocket
 *   --transport    TransportType member name
 *   --trace        print each call and its result
 *   --port         server port (9090)
 *   --host         server host, optionally as host:port ("localhost")
 *
 * Params:
 *   args = Command line arguments, including the program name.
 */
void main(string[] args) {
  string host = "localhost";
  ushort port = 9090;
  uint numTests = 1;
  bool ssl;
  ProtocolType protocolType;
  TransportType transportType;
  bool trace;

  getopt(args,
    "numTests|n", &numTests,
    "protocol", &protocolType,
    "ssl", &ssl,
    "transport", &transportType,
    "trace", &trace,
    "port", &port,
    "host", (string _, string value) {
      auto parts = split(value, ":");
      if (parts.length > 1) {
        // IPv6 addresses can contain colons, so take the last part for the
        // port.
        // NOTE: this means an IPv6 address given without a port has its last
        // group interpreted as the port.
        host = join(parts[0 .. $ - 1], ":");
        port = to!ushort(parts[$ - 1]);
      } else {
        host = value;
      }
    }
  );

  TSocket socket;
  if (ssl) {
    auto sslContext = new TSSLContext();
    sslContext.ciphers = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";
    sslContext.authenticate = true;
    sslContext.loadTrustedCertificates("../../../test/keys/CA.pem");
    socket = new TSSLSocket(sslContext, host, port);
  } else {
    socket = new TSocket(host, port);
  }

  TProtocol protocol;
  final switch (transportType) {
    case TransportType.buffered:
      protocol = createProtocol(new TBufferedTransport(socket), protocolType);
      break;
    case TransportType.framed:
      protocol = createProtocol(new TFramedTransport(socket), protocolType);
      break;
    case TransportType.http:
      protocol = createProtocol(
        new TClientHttpTransport(socket, host, "/service"), protocolType);
      break;
    case TransportType.raw:
      protocol = createProtocol(socket, protocolType);
      break;
  }

  auto client = tClient!ThriftTest(protocol);

  // Per-iteration timing statistics, in microseconds. time_min == 0 doubles
  // as the "not yet set" marker.
  ulong time_min;
  ulong time_max;
  ulong time_tot;

  StopWatch sw;
  foreach(test; 0 .. numTests) {
    sw.start();

    // The connection is opened and closed once per iteration, so connect
    // time is part of the measured iteration time.
    protocol.transport.open();

    if (trace) writefln("Test #%s, connect %s:%s", test + 1, host, port);

    if (trace) write("testVoid()");
    client.testVoid();
    if (trace) writeln(" = void");

    if (trace) write("testString(\"Test\")");
    string s = client.testString("Test");
    if (trace) writefln(" = \"%s\"", s);
    enforce(s == "Test");

    if (trace) write("testByte(1)");
    byte u8 = client.testByte(1);
    if (trace) writefln(" = %s", u8);
    enforce(u8 == 1);

    if (trace) write("testI32(-1)");
    int i32 = client.testI32(-1);
    if (trace) writefln(" = %s", i32);
    enforce(i32 == -1);

    if (trace) write("testI64(-34359738368)");
    long i64 = client.testI64(-34359738368L);
    if (trace) writefln(" = %s", i64);
    enforce(i64 == -34359738368L);

    if (trace) write("testDouble(-5.2098523)");
    double dub = client.testDouble(-5.2098523);
    if (trace) writefln(" = %s", dub);
    enforce(dub == -5.2098523);

    // TODO: add testBinary() call

    Xtruct out1;
    out1.string_thing = "Zero";
    out1.byte_thing = 1;
    out1.i32_thing = -3;
    out1.i64_thing = -5;
    if (trace) writef("testStruct(%s)", out1);
    auto in1 = client.testStruct(out1);
    if (trace) writefln(" = %s", in1);
    enforce(in1 == out1);

    if (trace) write("testNest({1, {\"Zero\", 1, -3, -5}), 5}");
    Xtruct2 out2;
    out2.byte_thing = 1;
    out2.struct_thing = out1;
    out2.i32_thing = 5;
    auto in2 = client.testNest(out2);
    in1 = in2.struct_thing;
    if (trace) writefln(" = {%s, {\"%s\", %s, %s, %s}, %s}", in2.byte_thing,
      in1.string_thing, in1.byte_thing, in1.i32_thing, in1.i64_thing,
      in2.i32_thing);
    enforce(in2 == out2);

    int[int] mapout;
    for (int i = 0; i < 5; ++i) {
      mapout[i] = i - 10;
    }
    if (trace) writef("testMap({%s})", mapout);
    auto mapin = client.testMap(mapout);
    if (trace) writefln(" = {%s}", mapin);
    enforce(mapin == mapout);

    auto setout = new HashSet!int;
    for (int i = -2; i < 3; ++i) {
      setout ~= i;
    }
    if (trace) writef("testSet(%s)", setout);
    auto setin = client.testSet(setout);
    if (trace) writefln(" = %s", setin);
    enforce(setin == setout);

    int[] listout;
    for (int i = -2; i < 3; ++i) {
      listout ~= i;
    }
    if (trace) writef("testList(%s)", listout);
    auto listin = client.testList(listout);
    if (trace) writefln(" = %s", listin);
    enforce(listin == listout);

    {
      if (trace) write("testEnum(ONE)");
      auto ret = client.testEnum(Numberz.ONE);
      if (trace) writefln(" = %s", ret);
      enforce(ret == Numberz.ONE);

      if (trace) write("testEnum(TWO)");
      ret = client.testEnum(Numberz.TWO);
      if (trace) writefln(" = %s", ret);
      enforce(ret == Numberz.TWO);

      if (trace) write("testEnum(THREE)");
      ret = client.testEnum(Numberz.THREE);
      if (trace) writefln(" = %s", ret);
      enforce(ret == Numberz.THREE);

      if (trace) write("testEnum(FIVE)");
      ret = client.testEnum(Numberz.FIVE);
      if (trace) writefln(" = %s", ret);
      enforce(ret == Numberz.FIVE);

      if (trace) write("testEnum(EIGHT)");
      ret = client.testEnum(Numberz.EIGHT);
      if (trace) writefln(" = %s", ret);
      enforce(ret == Numberz.EIGHT);
    }

    if (trace) write("testTypedef(309858235082523)");
    UserId uid = client.testTypedef(309858235082523L);
    if (trace) writefln(" = %s", uid);
    enforce(uid == 309858235082523L);

    if (trace) write("testMapMap(1)");
    auto mm = client.testMapMap(1);
    if (trace) writefln(" = {%s}", mm);
    // Simply doing == doesn't seem to work for nested AAs, so compare the
    // two maps entry by entry in both directions.
    foreach (key, value; mm) {
      enforce(testMapMapReturn[key] == value);
    }
    foreach (key, value; testMapMapReturn) {
      enforce(mm[key] == value);
    }

    Insanity insane;
    insane.userMap[Numberz.FIVE] = 5000;
    Xtruct truck;
    truck.string_thing = "Truck";
    truck.byte_thing = 8;
    truck.i32_thing = 8;
    truck.i64_thing = 8;
    insane.xtructs ~= truck;
    if (trace) write("testInsanity()");
    auto whoa = client.testInsanity(insane);
    if (trace) writefln(" = %s", whoa);

    // Commented for now, this is cumbersome to write without opEqual getting
    // called on AA comparison.
    // enforce(whoa == testInsanityReturn);

    {
      // In the two "must throw" cases below, the plain Exception raised when
      // the call returns normally is not matched by the catch clause and so
      // aborts the test run.
      try {
        if (trace) write("client.testException(\"Xception\") =>");
        client.testException("Xception");
        if (trace) writeln("  void\nFAILURE");
        throw new Exception("testException failed.");
      } catch (Xception e) {
        if (trace) writefln("  {%s, \"%s\"}", e.errorCode, e.message);
      }

      try {
        if (trace) write("client.testException(\"TException\") =>");
        // Pass "TException" (not "Xception") so that this case actually
        // exercises the argument announced in the trace line above.
        client.testException("TException");
        if (trace) writeln("  void\nFAILURE");
        throw new Exception("testException failed.");
      } catch (TException e) {
        if (trace) writefln("  {%s}", e.msg);
      }

      try {
        if (trace) write("client.testException(\"success\") =>");
        client.testException("success");
        if (trace) writeln("  void");
      } catch (Exception e) {
        if (trace) writeln("  exception\nFAILURE");
        throw new Exception("testException failed.");
      }
    }

    {
      try {
        if (trace) write("client.testMultiException(\"Xception\", \"test 1\") =>");
        auto result = client.testMultiException("Xception", "test 1");
        if (trace) writeln("  result\nFAILURE");
        throw new Exception("testMultiException failed.");
      } catch (Xception e) {
        if (trace) writefln("  {%s, \"%s\"}", e.errorCode, e.message);
      }

      try {
        if (trace) write("client.testMultiException(\"Xception2\", \"test 2\") =>");
        auto result = client.testMultiException("Xception2", "test 2");
        if (trace) writeln("  result\nFAILURE");
        throw new Exception("testMultiException failed.");
      } catch (Xception2 e) {
        if (trace) writefln("  {%s, {\"%s\"}}",
          e.errorCode, e.struct_thing.string_thing);
      }

      try {
        if (trace) writef("client.testMultiException(\"success\", \"test 3\") =>");
        auto result = client.testMultiException("success", "test 3");
        if (trace) writefln("  {{\"%s\"}}", result.string_thing);
      } catch (Exception e) {
        if (trace) writeln("  exception\nFAILURE");
        throw new Exception("testMultiException failed.");
      }
    }

    // Do not run oneway test when doing multiple iterations, as it blocks the
    // server for three seconds.
    if (numTests == 1) {
      if (trace) writef("client.testOneway(3) =>");
      auto onewayWatch = StopWatch(AutoStart.yes);
      client.testOneway(3);
      onewayWatch.stop();
      // A oneway call must return without waiting for the server; anything
      // above 200 ms is treated as the client having blocked.
      if (onewayWatch.peek().msecs > 200) {
        if (trace) {
          writefln("  FAILURE - took %s ms", onewayWatch.peek().usecs / 1000.0);
        }
        throw new Exception("testOneway failed.");
      } else {
        if (trace) {
          writefln("  success - took %s ms", onewayWatch.peek().usecs / 1000.0);
        }
      }

      // Redo a simple test after the oneway to make sure we aren't "off by
      // one", which would be the case if the server treated oneway methods
      // like normal ones.
      if (trace) write("re-test testI32(-1)");
      i32 = client.testI32(-1);
      if (trace) writefln(" = %s", i32);
      // Without this check a mismatched reply would go unnoticed.
      enforce(i32 == -1);
    }

    // Time metering.
    sw.stop();

    immutable tot = sw.peek().usecs;
    if (trace) writefln("Total time: %s us\n", tot);

    time_tot += tot;
    if (time_min == 0 || tot < time_min) {
      time_min = tot;
    }
    if (tot > time_max) {
      time_max = tot;
    }
    protocol.transport.close();

    sw.reset();
  }

  writeln("All tests done.");

  if (numTests > 1) {
    auto time_avg = time_tot / numTests;
    writefln("Min time: %s us", time_min);
    writefln("Max time: %s us", time_max);
    writefln("Avg time: %s us", time_avg);
  }
}
387