1 #include <cstdio>
2 #include <cstring>
3 #include <stdlib.h>
4 #include "GetTime.h"
5 #include "Rand.h"
6 #include "RSACrypt.h"
7 #include "DataBlockEncryptor.h"
8 #include "Rand.h"
9 #include "RakPeerInterface.h"
10 #include "MessageIdentifiers.h"
11 #include "RakNetworkFactory.h"
12 #include "RakNetTypes.h"
13 #include <assert.h>
14 #include "RakSleep.h"
15 #include "BitStream.h"
16 
// Prints the main menu, one option per line.
// Called at start-up, after each completed menu option, and on (H)elp.
void PrintOptions(void)
{
	static const char *const menuLines[] =
	{
		"1. Generate RSA keys and save to disk.\n",
		"2. Load RSA keys from disk.\n",
		"3. Test peers with key.\n",
		"(H)elp.\n",
		"(Q)uit.\n"
	};
	const size_t lineCount = sizeof(menuLines) / sizeof(menuLines[0]);

	for (size_t index = 0; index < lineCount; ++index)
		printf("%s", menuLines[index]);
}
25 
26 RakPeerInterface *rakPeer1, *rakPeer2;
27 
PrintPacketHeader(Packet * packet)28 void PrintPacketHeader(Packet *packet)
29 {
30 	switch (packet->data[0])
31 	{
32 		case ID_RSA_PUBLIC_KEY_MISMATCH:
33 			printf("Public key mismatch.\nThe connecting system's public key does not\nmatch what the sender sent.\n");
34 			break;
35 		case ID_CONNECTION_REQUEST_ACCEPTED:
36 			printf("Connection request accepted.\n");
37 			break;
38 		case ID_NEW_INCOMING_CONNECTION:
39 			{
40 				printf("New incoming connection.\n");
41 				RakNet::BitStream testBlockLargerThanMTU;
42 				testBlockLargerThanMTU.Write((MessageID) ID_USER_PACKET_ENUM);
43 				testBlockLargerThanMTU.PadWithZeroToByteLength(10000);
44 				rakPeer2->Send(&testBlockLargerThanMTU, HIGH_PRIORITY, RELIABLE_ORDERED, 0, packet->systemAddress, false);
45 			}
46 			break;
47 		case ID_MODIFIED_PACKET:
48 			printf("Packet checksum invalid.  Either RSA decrypt function gave the wrong value\nor the packet was tampered with.\n");
49 			break;
50 		case ID_USER_PACKET_ENUM:
51 			printf("Got test message\n");
52 			break;
53 		default:
54 			printf("%s\n", packet->data);
55 			break;
56 	}
57 }
58 
main(void)59 int main(void)
60 {
61 	char str[256];
62 	bool keyLoaded; // Does D,E,N have values?
63 
64 	// RSACrypt is a class that handles RSA encryption/decryption internally
65 	RSACrypt rsacrypt;
66 
67 	uint32_t e;
68 	uint32_t modulus[RAKNET_RSA_FACTOR_LIMBS];
69 	// e and modulus form the public key
70 
71 	// p,q is the private key
72 	uint32_t p[RAKNET_RSA_FACTOR_LIMBS/2],q[RAKNET_RSA_FACTOR_LIMBS/2];
73 
74 	/*
75 	// RSACrypt is a class that handles RSA encryption/decryption internally
76 	big::RSACrypt<RSA_BIT_SIZE> rsacrypt;
77 
78 	// These are the sizes necessary for e,n,p,q
79 	// e,n is the public key
80 	// p,q is the private key
81 	big::u32 e;
82 	RSA_BIT_SIZE n;
83 	BIGHALFSIZE(RSA_BIT_SIZE, p);
84 	BIGHALFSIZE(RSA_BIT_SIZE, q);
85 	*/
86 
87 	FILE *fp;
88 	RakNetTime time;
89 	rakPeer1=RakNetworkFactory::GetRakPeerInterface();
90 	rakPeer2=RakNetworkFactory::GetRakPeerInterface();
91 	Packet *packet;
92 	bool peer1GotMessage, peer2GotMessage;
93 
94 	seedMT((unsigned int) RakNet::GetTimeMS());
95 
96 	keyLoaded=false;
97 
98 	printf("Demonstrates how to setup RakNet to use secure connections\n");
99 	printf("Also shows how to read and write RSA keys to and from disk\n");
100 	printf("Difficulty: Intermediate\n\n");
101 
102 	printf("Select option:\n");
103 	PrintOptions();
104 
105 	while (1)
106 	{
107 		gets(str);
108 
109 		if (str[0]=='1')
110 		{
111 			printf("Generating %i bit key. This will take a while...\n", RAKNET_RSA_FACTOR_LIMBS*32);
112 			rsacrypt.generatePrivateKey(RAKNET_RSA_FACTOR_LIMBS);
113 			e=rsacrypt.getPublicExponent();
114 			rsacrypt.getPublicModulus(modulus);
115 			rsacrypt.getPrivateP(p);
116 			rsacrypt.getPrivateQ(q);
117 
118 
119 			/*
120             printf("Generating %i byte key.  This will take a while...\n", sizeof(RSA_BIT_SIZE));
121 			rsacrypt.generateKeys();
122 			rsacrypt.getPublicKey(e,n);
123 			rsacrypt.getPrivateKey(p,q);
124 			*/
125 			keyLoaded=true;
126 			printf("Key generated.  Save to disk? (y/n)\n");
127 			gets(str);
128 			if (str[0]=='y' || str[0]=='Y')
129 			{
130 				printf("Enter filename to save public keys to: ");
131 				gets(str);
132                 if (str[0])
133 				{
134 					printf("Writing public key... ");
135 					fp=fopen(str, "wb");
136 					fwrite((char*)&e, sizeof(e), 1, fp);
137 					fwrite((char*)modulus, sizeof(modulus), 1, fp);
138 					//fwrite((char*)n, sizeof(n), 1, fp);
139 					fclose(fp);
140 					printf("Done.\n");
141 				}
142 				else
143 					printf("\nKey not written.\n");
144 
145 				printf("Enter filename to save private key to: ");
146 				gets(str);
147 				if (str[0])
148 				{
149 					printf("Writing private key... ");
150 					fp=fopen(str, "wb");
151 					fwrite(p, sizeof(p),1,fp);
152 					fwrite(q, sizeof(q), 1, fp);
153 					//fwrite(p, sizeof(RSA_BIT_SIZE)/2,1,fp);
154 					//fwrite(q, sizeof(RSA_BIT_SIZE)/2, 1, fp);
155 					fclose(fp);
156 					printf("Done.\n");
157 				}
158 				else
159 					printf("\nKey not written.\n");
160 			}
161 			PrintOptions();
162 		}
163 		else if (str[0]=='2')
164 		{
165 			printf("Enter filename to load public keys from: ");
166 			gets(str);
167 			if (str[0])
168 			{
169 				fp=fopen(str, "rb");
170 				if (fp)
171 				{
172 					printf("Loading public keys... ");
173 					fread((char*)(&e), sizeof(e), 1, fp);
174 					fread((char*)(modulus), sizeof(modulus), 1, fp);
175 					fclose(fp);
176 					printf("Done.\n");
177 
178 					printf("Enter filename to load private key from: ");
179 					gets(str);
180 					if (str[0])
181 					{
182 						fp=fopen(str, "rb");
183 						if (fp)
184 						{
185 							printf("Loading private key... ");
186 							fread(p, sizeof(p), 1, fp);
187 							fread(q, sizeof(q), 1, fp);
188 							//fread(p, sizeof(RSA_BIT_SIZE)/2, 1, fp);
189 							//fread(q, sizeof(RSA_BIT_SIZE)/2, 1, fp);
190 							fclose(fp);
191 							printf("Done.\n");
192 							keyLoaded=true;
193 						}
194 						else
195 						{
196 							printf("Failed to open %s.\n", str);
197 						}
198 					}
199 					else
200 						printf("Not loading private key.\n");
201 				}
202 				else
203 				{
204 					printf("Failed to open %s.\n", str);
205 				}
206 			}
207 			else
208 				printf("Not loading public keys.\n");
209 
210 			PrintOptions();
211 		}
212 		else if (str[0]=='3')
213 		{
214 			if (keyLoaded)
215 			{
216 				printf("(G)enerate new keys automatically or use (e)xisting?\n");
217 				gets(str);
218 				if (str[0]=='g' || str[0]=='G')
219 				{
220 					printf("Generating 32 byte keys.  Please wait.\n");
221 					rakPeer1->InitializeSecurity(0,0,0,0);
222 					printf("Keys generated.\n");
223 				}
224 				else
225 				{
226 					rakPeer1->InitializeSecurity(0,0,(char*)p, (char*)q);
227 					printf("Tell the connecting system the public keys in advance?\n(Y)es, better security.\n(N)o, worse security but everything works automatically.\n");
228 					gets(str);
229 					if (str[0]=='y' || str[0]=='Y')
230 					{
231 						printf("Using preloaded keys for the connecting system.\n");
232 						//rakPeer2->InitializeSecurity((char*)&e, (char*)n, 0, 0);
233 						rakPeer2->InitializeSecurity((char*)&e, (char*)modulus, 0, 0);
234 					}
235 					else
236 					{
237 						printf("Relying on server to transmit public keys to the connecting system.\n");
238 
239 						// Clear out any old saved public keys
240 						rakPeer2->DisableSecurity();
241 					}
242 				}
243 			}
244 			else
245 			{
246 				printf("Generating key automatically on host.  Please wait.\n");
247 				rakPeer1->InitializeSecurity(0, 0, 0, 0);
248 
249 				// Clear out any old saved public keys
250 				rakPeer2->DisableSecurity();
251 				printf("Key generation complete.\n");
252 			}
253 
254 			printf("Initializing peers.\n");
255 			SocketDescriptor socketDescriptor(1234,0);
256 			rakPeer1->Startup(8,0,&socketDescriptor, 1);
257 			rakPeer1->SetMaximumIncomingConnections(8);
258 			socketDescriptor.port=0;
259 			rakPeer2->Startup(1,0,&socketDescriptor, 1);
260 			rakPeer2->Connect("127.0.0.1", 1234, 0, 0);
261 			printf("Running connection for 5 seconds.\n");
262 
263 			peer1GotMessage=false;
264 			peer2GotMessage=false;
265 			time = RakNet::GetTime() + 5000;
266 			while (RakNet::GetTime() < time)
267 			{
268 				packet=rakPeer1->Receive();
269 				if (packet)
270 				{
271 					peer1GotMessage=true;
272 					printf("Host got: ");
273 					PrintPacketHeader(packet);
274 					rakPeer1->DeallocatePacket(packet);
275 				}
276 				packet=rakPeer2->Receive();
277 				if (packet)
278 				{
279 					peer2GotMessage=true;
280 					printf("Connecting system got: ");
281 					PrintPacketHeader(packet);
282 					rakPeer2->DeallocatePacket(packet);
283 				}
284 
285 				RakSleep(30);
286 			}
287 
288 			if (peer1GotMessage==false)
289 				printf("Error, host got no packets.\n");
290 			if (peer2GotMessage==false)
291 				printf("Error, connecting system got no packets.\n");
292 
293 			if (peer1GotMessage && peer2GotMessage)
294 				printf("Test successful as long as you got no error messages.\n");
295 			rakPeer2->Shutdown(0);
296 			rakPeer1->Shutdown(0);
297 			PrintOptions();
298 		}
299 		else if (str[0]=='h' || str[0]=='H')
300 		{
301 			PrintOptions();
302 		}
303 		else if (str[0]=='q' || str[0]=='Q')
304 			break;
305 
306 		str[0]=0;
307 	}
308 
309 	RakNetworkFactory::DestroyRakPeerInterface(rakPeer1);
310 	RakNetworkFactory::DestroyRakPeerInterface(rakPeer2);
311 }
312