1 // Copyright (c) 1999-2018 David Muse
2 // See the file COPYING for more information
3 
4 #include <rudiments/commandline.h>
5 #include <rudiments/tls.h>
6 #include <rudiments/inetsocketclient.h>
7 #include <rudiments/charstring.h>
8 #include <rudiments/bytebuffer.h>
9 #include <rudiments/error.h>
10 #include <rudiments/process.h>
11 #include <rudiments/stdio.h>
12 
usage()13 static void usage() {
14 	stdoutput.printf("tlsclient "
15 			"[-host host] [-port port] "
16 			"[-version version] [-cert cert] [-ciphers ciphers] "
17 			"[-validate (yes|no)] [-depth depth] [-ca ca] "
18 			"[-commonname name] "
19 			"[-ccount count] [-mcount count] [-dcount count]\n");
20 }
21 
main(int argc,const char ** argv)22 int main(int argc, const char **argv) {
23 
24 	// process the command line
25 	commandline	cmdl(argc,argv);
26 	if (cmdl.found("help")) {
27 		usage();
28 		process::exit(0);
29 	}
30 	const char	*host="127.0.0.1";
31 	if (cmdl.found("host")) {
32 		host=cmdl.getValue("host");
33 	}
34 	uint16_t	port=9000;
35 	if (cmdl.found("port")) {
36 		port=charstring::toUnsignedInteger(cmdl.getValue("port"));
37 	}
38 	const char	*version=NULL;
39 	if (cmdl.found("version")) {
40 		version=cmdl.getValue("version");
41 	}
42 	const char	*cert=NULL;
43 	if (cmdl.found("cert")) {
44 		cert=cmdl.getValue("cert");
45 	}
46 	const char	*ciphers=NULL;
47 	if (cmdl.found("ciphers")) {
48 		ciphers=cmdl.getValue("ciphers");
49 	}
50 	bool	validate=true;
51 	if (cmdl.found("validate")) {
52 		validate=charstring::compare(cmdl.getValue("validate"),"no");
53 	}
54 	uint16_t	depth=9;
55 	if (cmdl.found("depth")) {
56 		depth=charstring::toUnsignedInteger(cmdl.getValue("depth"));
57 	}
58 	const char	*ca=NULL;
59 	if (cmdl.found("ca")) {
60 		ca=cmdl.getValue("ca");
61 	}
62 	const char	*commonname="server.localdomain";
63 	if (cmdl.found("commonname")) {
64 		commonname=cmdl.getValue("commonname");
65 	}
66 	const char	*msg="hello";
67 	if (cmdl.found("message")) {
68 		msg=cmdl.getValue("message");
69 	}
70 	int64_t	ccount=charstring::toInteger(cmdl.getValue("ccount"));
71 	if (ccount<0) {
72 		usage();
73 		process::exit(1);
74 	} else if (!ccount) {
75 		ccount=1;
76 	}
77 	int64_t	mcount=charstring::toInteger(cmdl.getValue("mcount"));
78 	if (mcount<0) {
79 		usage();
80 		process::exit(1);
81 	} else if (!mcount) {
82 		mcount=1;
83 	}
84 	int64_t	dcount=charstring::toInteger(cmdl.getValue("dcount"));
85 	if (dcount<0) {
86 		usage();
87 		process::exit(1);
88 	} else if (!dcount) {
89 		dcount=1;
90 	}
91 
92 	// duplicate the message the
93 	// specified number of times...
94 	bytebuffer	msgbuf;
95 	for (int64_t i=0; i<dcount; i++) {
96 		msgbuf.append(msg)->append(' ');
97 	}
98 
99 {
100 	// configure the security context
101 	tlscontext	ctx;
102 	ctx.setProtocolVersion(version);
103 	ctx.setCertificateChainFile(cert);
104 	ctx.setPrivateKeyPassword("password");
105 	ctx.setCiphers(ciphers);
106 	ctx.setValidatePeer(validate);
107 	ctx.setValidationDepth(depth);
108 	ctx.setCertificateAuthority(ca);
109 
110 	// create an inet socket client
111 	inetsocketclient	fd;
112 	fd.setWriteBufferSize(65536);
113 	fd.setReadBufferSize(65536);
114 
115 	// attach the security context
116 	fd.setSecurityContext(&ctx);
117 
118 	// loop, having sessions with the server
119 	for (int64_t i=0; i<ccount; i++) {
120 
121 		// connect
122 		if (fd.connect(host,port,-1,-1,1,1)!=RESULT_SUCCESS) {
123 			if (error::getErrorNumber()) {
124 				stdoutput.printf("connect failed (1): %s\n",
125 							error::getErrorString());
126 			} else {
127 				stdoutput.printf("connect failed (2): %s\n",
128 							ctx.getErrorString());
129 			}
130 			continue;
131 		}
132 
133 		// make sure the server sent a certificate
134 		if (validate) {
135 			tlscertificate	*pcert=ctx.getPeerCertificate();
136 			if (!pcert) {
137 				stdoutput.printf(
138 					"peer sent no certificate\n%s\n",
139 					ctx.getErrorString());
140 				fd.close();
141 				delete pcert;
142 				continue;
143 			}
144 
145 			// Make sure the commonname in the certificate
146 			// is the one we expect it to be.
147 			const char	*cn=pcert->getCommonName();
148 			if (charstring::compareIgnoringCase(cn,commonname)) {
149 				stdoutput.printf("%s!=%s\n",cn,commonname);
150 				fd.close();
151 				delete pcert;
152 				continue;
153 			}
154 
155 			stdoutput.printf("server certificate {\n");
156 			stdoutput.printf("  version: %d\n",
157 					pcert->getVersion());
158 			stdoutput.printf("  serial number: %lld\n",
159 					pcert->getSerialNumber());
160 			stdoutput.printf("  signature algorithm: %s\n",
161 					pcert->getSignatureAlgorithm());
162 			stdoutput.printf("  issuer: %s\n",
163 					pcert->getIssuer());
164 			stdoutput.printf("  valid-from: %s\n",
165 					pcert->getValidFrom()->getString());
166 			stdoutput.printf("  valid-to: %s\n",
167 					pcert->getValidTo()->getString());
168 			stdoutput.printf("  subject: %s\n",
169 					pcert->getSubject());
170 			stdoutput.printf("  public key algorithm: %s\n",
171 					pcert->getPublicKeyAlgorithm());
172 			stdoutput.printf("  public key: ");
173 			stdoutput.safePrint(pcert->getPublicKey(),
174 					(pcert->getPublicKeyByteLength()<5)?
175 					pcert->getPublicKeyByteLength():5);
176 			stdoutput.printf("...\n");
177 			stdoutput.printf("  public key length: %lld\n",
178 					pcert->getPublicKeyByteLength());
179 			stdoutput.printf("  public key bits: %lld\n",
180 					pcert->getPublicKeyBitLength());
181 			stdoutput.printf("  common name: %s\n",
182 					pcert->getCommonName());
183 			stdoutput.printf("  subject alternate names:\n");
184 			for (linkedlistnode< char * > *node=
185 				pcert->getSubjectAlternateNames()->getFirst();
186 				node; node=node->getNext()) {
187 				stdoutput.printf("    %s\n",node->getValue());
188 			}
189 			stdoutput.printf("}\n");
190 
191 			delete pcert;
192 		}
193 
194 		stdoutput.printf("serverSession {\n");
195 
196 		// write the message to the server,
197 		// the specified number of times
198 		for (int64_t j=0; j<mcount; j++) {
199 
200 			// write size
201 			ssize_t	sizewritten=fd.write((uint64_t)
202 						msgbuf.getSize());
203 			if (sizewritten<=0) {
204 				if (sizewritten==0) {
205 					stdoutput.printf(
206 						"  write() size failed (0): "
207 						"eof\n");
208 					break;
209 				} else if (error::getErrorNumber()) {
210 					stdoutput.printf(
211 						"  write() size failed (1): "
212 						"%s\n",
213 						error::getErrorString());
214 					break;
215 				} else {
216 					stdoutput.printf(
217 						"  write() size failed (2): "
218 						"%s\n",
219 						ctx.getErrorString());
220 					break;
221 				}
222 			} else if (sizewritten!=sizeof(uint64_t)) {
223 				stdoutput.printf(
224 					"  write() size failed (3): %s\n",
225 					ctx.getErrorString());
226 				break;
227 			}
228 
229 			// write message
230 			sizewritten=fd.write(msgbuf.getBuffer(),
231 						msgbuf.getSize());
232 			if (sizewritten<=0) {
233 				if (sizewritten==0) {
234 					stdoutput.printf(
235 						"  write() msg failed (0): "
236 						"eof\n");
237 					break;
238 				} else if (error::getErrorNumber()) {
239 					stdoutput.printf(
240 						"  write() msg failed (1): "
241 						"%s\n",
242 						error::getErrorString());
243 					break;
244 				} else {
245 					stdoutput.printf(
246 						"  write() msg failed (2): "
247 						"%s\n",
248 						ctx.getErrorString());
249 					break;
250 				}
251 			} else if (sizewritten!=(ssize_t)msgbuf.getSize()) {
252 				stdoutput.printf(
253 					"  write() msg failed (3): %s\n",
254 					ctx.getErrorString());
255 				break;
256 			}
257 
258 			// flush write buffer
259 			if (!fd.flushWriteBuffer(-1,-1)) {
260 				stdoutput.printf("flushWriteBuffer() failed\n");
261 				break;
262 			}
263 
264 			stdoutput.printf("\n  Sent message... "
265 					"(size=%d):\n  ",msgbuf.getSize());
266 			stdoutput.safePrint(msgbuf.getBuffer(),
267 				(msgbuf.getSize()<=80)?msgbuf.getSize():80);
268 			if (msgbuf.getSize()>80) {
269 				stdoutput.write("...");
270 			}
271 			stdoutput.write('\n');
272 			stdoutput.printf("\n  Receiving response...");
273 
274 			// read size
275 			uint64_t	msgsize;
276 			ssize_t	sizeread=fd.read(&msgsize);
277 			if (sizeread<=0) {
278 				if (sizeread==0) {
279 					stdoutput.printf(
280 						"  read() size failed (0): "
281 						"eof\n");
282 					break;
283 				} else if (error::getErrorNumber()) {
284 					stdoutput.printf(
285 						"  read() size failed (1): "
286 						"%s\n",
287 						error::getErrorString());
288 					break;
289 				} else {
290 					stdoutput.printf(
291 						"  read() size failed (2): "
292 						"%s\n",
293 						ctx.getErrorString());
294 					break;
295 				}
296 			} else if (sizeread!=sizeof(uint64_t)) {
297 				stdoutput.printf(
298 					"  read() size failed (3): %s\n",
299 					ctx.getErrorString());
300 				break;
301 			}
302 
303 			// read message
304 			unsigned char	*msg=new unsigned char[msgsize];
305 			sizeread=fd.read(msg,msgsize);
306 			if (sizeread<=0) {
307 				if (sizeread==0) {
308 					stdoutput.printf(
309 						"  read() msg failed (0): "
310 						"eof\n");
311 					delete[] msg;
312 					break;
313 				} else if (error::getErrorNumber()) {
314 					stdoutput.printf(
315 						"  read() msg failed (1): "
316 						"%s\n",
317 						error::getErrorString());
318 					delete[] msg;
319 					break;
320 				} else {
321 					stdoutput.printf(
322 						"  read() msg failed (2): "
323 						"%s\n",
324 						ctx.getErrorString());
325 					delete[] msg;
326 					break;
327 				}
328 			} else if (sizeread!=(ssize_t)msgsize) {
329 				stdoutput.printf(
330 					"  read() msg failed (3): %s\n",
331 					ctx.getErrorString());
332 				delete[] msg;
333 				break;
334 			}
335 
336 			stdoutput.printf("  success\n");
337 
338 			delete[] msg;
339 		}
340 
341 		stdoutput.printf("}\n");
342 
343 		// close the connection to the server
344 		fd.close();
345 	}
346 }
347 
348 	process::exit(0);
349 }
350