1 // Copyright (c) 1999-2018 David Muse
2 // See the file COPYING for more information
3 
4 #include <rudiments/tls.h>
5 #include <rudiments/commandline.h>
6 #include <rudiments/permissions.h>
7 #include <rudiments/inetsocketserver.h>
8 #include <rudiments/charstring.h>
9 #include <rudiments/error.h>
10 #include <rudiments/file.h>
11 #include <rudiments/process.h>
12 #include <rudiments/stdio.h>
13 
main(int argc,const char ** argv)14 int main(int argc, const char **argv) {
15 
16 	// process the command line
17 	commandline	cmdl(argc,argv);
18 
19 	if (cmdl.found("help") || !cmdl.found("cert")) {
20 		stdoutput.printf("tlsserver [-port port] [-version version] -cert cert [-ciphers ciphers] [-validate (yes|no)] [-depth depth] [-ca ca] [-commonname name]\n");
21 		process::exit(0);
22 	}
23 
24 	uint16_t	port=9000;
25 	if (cmdl.found("port")) {
26 		port=charstring::toUnsignedInteger(cmdl.getValue("port"));
27 	}
28 	const char	*version=NULL;
29 	if (cmdl.found("version")) {
30 		version=cmdl.getValue("version");
31 	}
32 	const char	*cert=cmdl.getValue("cert");
33 	const char	*ciphers=NULL;
34 	if (cmdl.found("ciphers")) {
35 		ciphers=cmdl.getValue("ciphers");
36 	}
37 	bool	validate=false;
38 	if (cmdl.found("validate")) {
39 		validate=!charstring::compare(cmdl.getValue("validate"),"yes");
40 	}
41 	uint16_t	depth=9;
42 	if (cmdl.found("depth")) {
43 		depth=charstring::toUnsignedInteger(cmdl.getValue("depth"));
44 	}
45 	const char	*ca=NULL;
46 	if (cmdl.found("ca")) {
47 		ca=cmdl.getValue("ca");
48 	}
49 	const char	*commonname="client.localdomain";
50 	if (cmdl.found("commonname")) {
51 		commonname=cmdl.getValue("commonname");
52 	}
53 
54 	// configure the security context
55 	tlscontext	ctx;
56 	ctx.setProtocolVersion(version);
57 	ctx.setCertificateChainFile(cert);
58 	ctx.setPrivateKeyPassword("password");
59 	ctx.setCiphers(ciphers);
60 	ctx.setValidatePeer(validate);
61 	ctx.setValidationDepth(depth);
62 	ctx.setCertificateAuthority(ca);
63 
64 	// create an inet socket server
65 	inetsocketserver	iss;
66 
67 	// attach the security context
68 	iss.setSecurityContext(&ctx);
69 
70 	// listen
71 	if (iss.listen(NULL,port,0)) {
72 
73 		for (;;) {
74 
75 			// accept (will also accept context)
76 			filedescriptor	*fd=iss.accept();
77 			if (!fd) {
78 				if (error::getErrorNumber()) {
79 					stdoutput.printf(
80 						"accept failed (1): %s\n",
81 						error::getErrorString());
82 				} else {
83 					stdoutput.printf(
84 						"accept failed (2): %s\n",
85 						ctx.getErrorString());
86 				}
87 				continue;
88 			}
89 			fd->setWriteBufferSize(65536);
90 			fd->setReadBufferSize(65536);
91 
92 			if (validate) {
93 
94 				// make sure the client sent a certificate
95 				tlscertificate	*pcert=ctx.getPeerCertificate();
96 				if (!pcert) {
97 					stdoutput.printf("peer sent no "
98 							"certificate\n%s\n",
99 							ctx.getErrorString());
100 					fd->close();
101 					delete fd;
102 					delete pcert;
103 					continue;
104 				}
105 
106 				// Make sure the commonname in the certificate
107 				// is the one we expect it to be
108 				const char	*cn=pcert->getCommonName();
109 				if (charstring::compareIgnoringCase(
110 							cn,commonname)) {
111 					stdoutput.printf("%s!=%s\n",
112 							cn,commonname);
113 					fd->close();
114 					delete fd;
115 					delete pcert;
116 					continue;
117 				}
118 
119 				stdoutput.printf("client certificate {\n");
120 				stdoutput.printf("  version: %d\n",
121 						pcert->getVersion());
122 				stdoutput.printf("  serial number: %lld\n",
123 						pcert->getSerialNumber());
124 				stdoutput.printf("  signature algorithm: %s\n",
125 						pcert->getSignatureAlgorithm());
126 				stdoutput.printf("  issuer: %s\n",
127 						pcert->getIssuer());
128 				stdoutput.printf("  valid-from: %s\n",
129 						pcert->getValidFrom()->
130 							getString());
131 				stdoutput.printf("  valid-to: %s\n",
132 						pcert->getValidTo()->
133 							getString());
134 				stdoutput.printf("  subject: %s\n",
135 						pcert->getSubject());
136 				stdoutput.printf(
137 					"  public key algorithm: %s\n",
138 					pcert->getPublicKeyAlgorithm());
139 				stdoutput.printf("  public key: ");
140 				stdoutput.safePrint(
141 					pcert->getPublicKey(),
142 					(pcert->getPublicKeyByteLength()<5)?
143 					pcert->getPublicKeyByteLength():5);
144 				stdoutput.printf("...\n");
145 				stdoutput.printf(
146 					"  public key length: %lld\n",
147 					pcert->getPublicKeyByteLength());
148 				stdoutput.printf("  public key bits: %lld\n",
149 					pcert->getPublicKeyBitLength());
150 				stdoutput.printf("  common name: %s\n",
151 					pcert->getCommonName());
152 				for (linkedlistnode< char * > *node=
153 					pcert->getSubjectAlternateNames()->
154 					getFirst();
155 					node; node=node->getNext()) {
156 					stdoutput.printf("    %s\n",
157 							node->getValue());
158 				}
159 				stdoutput.printf("}\n");
160 
161 				delete pcert;
162 			}
163 
164 			stdoutput.printf("clientSession {\n");
165 
166 			// read messages from the client...
167 			for (;;) {
168 
169 				uint64_t	msgsize;
170 				ssize_t	sizeread=fd->read(&msgsize);
171 				if (sizeread<=0) {
172 					if (sizeread==0) {
173 						stdoutput.printf(
174 						"  read() size failed (0): "
175 						"eof\n");
176 					} else if (error::getErrorNumber()) {
177 						stdoutput.printf(
178 						"  read() size failed (1): "
179 						"%s\n",error::getErrorString());
180 					} else {
181 						stdoutput.printf(
182 						"  read() size failed (2): "
183 						"%s\n",ctx.getErrorString());
184 					}
185 					break;
186 				} else if (sizeread!=sizeof(uint64_t)) {
187 					stdoutput.printf(
188 						"  read() size failed (3): "
189 						"%s\n",ctx.getErrorString());
190 					break;
191 				}
192 
193 				unsigned char	*msg=new unsigned char[msgsize];
194 				sizeread=fd->read(msg,msgsize);
195 				if (sizeread<=0) {
196 					if (sizeread==0) {
197 						stdoutput.printf(
198 						"  read() size failed (0): "
199 						"eof\n");
200 					} else if (error::getErrorNumber()) {
201 						stdoutput.printf(
202 						"  read() size failed (1): "
203 						"%s\n",error::getErrorString());
204 					} else {
205 						stdoutput.printf(
206 						"  read() size failed (2): "
207 						"%s\n",ctx.getErrorString());
208 					}
209 					delete[] msg;
210 					break;
211 				} else if (sizeread!=(ssize_t)msgsize) {
212 					stdoutput.printf(
213 						"  read() size failed (3): "
214 						"%s\n",ctx.getErrorString());
215 					delete[] msg;
216 					break;
217 				}
218 
219 				stdoutput.printf("\nReceived message... "
220 						"(size=%d):\n",msgsize);
221 				stdoutput.safePrint(msg,
222 					(msgsize<=70)?msgsize:70);
223 				if (msgsize>70) {
224 					stdoutput.write("...");
225 				}
226 				stdoutput.write('\n');
227 				stdoutput.printf("\n  Sending response...");
228 
229 				ssize_t	sizewritten=fd->write(msgsize);
230 				if (sizewritten<=0) {
231 					if (sizewritten==0) {
232 						stdoutput.printf(
233 						"  write() size failed (0): "
234 						"eof\n");
235 					} else if (error::getErrorNumber()) {
236 						stdoutput.printf(
237 						"  write() size failed (1): "
238 						"%s\n",error::getErrorString());
239 					} else {
240 						stdoutput.printf(
241 						"  write() size failed (2): "
242 						"%s\n",ctx.getErrorString());
243 					}
244 					delete[] msg;
245 					break;
246 				} else if (sizewritten!=sizeof(uint64_t)) {
247 					stdoutput.printf(
248 						"  write() size failed (3): "
249 						"%s\n",ctx.getErrorString());
250 					delete[] msg;
251 					break;
252 				}
253 
254 				sizewritten=fd->write(msg,msgsize);
255 				if (sizewritten<=0) {
256 					if (sizewritten==0) {
257 						stdoutput.printf(
258 						"  write() size failed (0): "
259 						"eof\n");
260 					} else if (error::getErrorNumber()) {
261 						stdoutput.printf(
262 						"  write() size failed (1): "
263 						"%s\n",error::getErrorString());
264 					} else {
265 						stdoutput.printf(
266 						"  write() size failed (2): "
267 						"%s\n",ctx.getErrorString());
268 					}
269 					delete[] msg;
270 					break;
271 				} else if (sizewritten!=(ssize_t)msgsize) {
272 					stdoutput.printf(
273 						"  write() size failed (3): "
274 						"%s\n",ctx.getErrorString());
275 					delete[] msg;
276 					break;
277 				}
278 
279 				delete[] msg;
280 
281 				if (!fd->flushWriteBuffer(-1,-1)) {
282 					stdoutput.printf(
283 						"\n  flushWriteBuffer() msg "
284 						"failed\n");
285 					break;
286 				}
287 
288 				stdoutput.printf("  success\n");
289 			}
290 
291 			stdoutput.printf("}\n");
292 
293 			// close and delete the client socket
294 			delete fd;
295 
296 			// FIXME: necessary?
297 			ctx.close();
298 		}
299 	}
300 
301 	stdoutput.printf("error listening on port %d\n%s\n",
302 				port,error::getErrorString());
303 	process::exit(1);
304 }
305