1 /**
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  * SPDX-License-Identifier: Apache-2.0.
4  */
#include <aws/crt/Api.h>
#include <aws/crt/crypto/Hash.h>
#include <aws/crt/http/HttpConnection.h>
#include <aws/crt/http/HttpRequestResponse.h>
#include <aws/crt/io/Uri.h>

#include <aws/common/command_line_parser.h>

#include <condition_variable>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <future>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
17 
18 using namespace Aws::Crt;
19 
20 #define ELASTICURL_VERSION "0.0.1"
21 
22 struct ElasticurlCtx
23 {
24     Allocator *allocator = nullptr;
25     const char *verb = "GET";
26     Io::Uri uri;
27     bool ResponseCodeWritten = false;
28     const char *CaCert = nullptr;
29     const char *CaPath = nullptr;
30     const char *Cert = nullptr;
31     const char *Key = nullptr;
32     int ConnectTimeout = 3000;
33     Vector<const char *> HeaderLines;
34     const char *Alpn = "h2;http/1.1";
35     bool IncludeHeaders = false;
36     bool Insecure = false;
37 
38     const char *TraceFile = nullptr;
39     Aws::Crt::LogLevel LogLevel = Aws::Crt::LogLevel::None;
40     Http::HttpVersion RequiredHttpVersion = Http::HttpVersion::Unknown;
41 
42     std::shared_ptr<Io::IStream> InputBody = nullptr;
43     std::ofstream Output;
44 };
45 
/**
 * Prints the command-line help text to stderr and terminates the process.
 *
 * @param exit_code status handed to exit(): 0 when help was explicitly requested,
 *                  non-zero when usage is shown because of a command-line error.
 */
[[noreturn]] static void s_Usage(int exit_code)
{

    std::cerr << "usage: elasticurl [options] url\n";
    std::cerr << " url: url to make a request to. The default is a GET request.\n";
    std::cerr << "\n Options:\n\n";
    std::cerr << "      --cacert FILE: path to a CA certificate file.\n";
    std::cerr << "      --capath PATH: path to a directory containing CA files.\n";
    std::cerr << "      --cert FILE: path to a PEM encoded certificate to use with mTLS\n";
    std::cerr << "      --key FILE: Path to a PEM encoded private key that matches cert.\n";
    std::cerr << "      --connect-timeout INT: time in milliseconds to wait for a connection.\n";
    std::cerr << "  -H, --header LINE: line to send as a header in format [header-key]: [header-value]\n";
    std::cerr << "  -d, --data STRING: Data to POST or PUT\n";
    std::cerr << "      --data-file FILE: File to read and POST or PUT\n";
    std::cerr << "  -M, --method STRING: Http Method verb to use for the request\n";
    std::cerr << "  -G, --get: uses GET for the verb.\n";
    std::cerr << "  -P, --post: uses POST for the verb.\n";
    std::cerr << "  -I, --head: uses HEAD for the verb.\n";
    std::cerr << "  -i, --include: includes headers in output.\n";
    std::cerr << "  -k, --insecure: turns off SSL/TLS validation.\n";
    std::cerr << "  -o, --output FILE: dumps content-body to FILE instead of stdout.\n";
    std::cerr << "  -t, --trace FILE: dumps logs to FILE instead of stderr.\n";
    std::cerr << "  -v, --verbose: ERROR|INFO|DEBUG|TRACE: log level to configure. Default is none.\n";
    std::cerr << "      --version: print the version of elasticurl.\n";
    std::cerr << "      --http2: HTTP/2 connection required\n";
    std::cerr << "      --http1_1: HTTP/1.1 connection required\n";
    std::cerr << "  -h, --help\n";
    std::cerr << "            Display this message and quit.\n";
    exit(exit_code);
}
76 
77 static struct aws_cli_option s_LongOptions[] = {
78     {"cacert", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, nullptr, 'a'},
79     {"capath", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, nullptr, 'b'},
80     {"cert", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, nullptr, 'c'},
81     {"key", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, nullptr, 'e'},
82     {"connect-timeout", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, nullptr, 'f'},
83     {"header", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, nullptr, 'H'},
84     {"data", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, nullptr, 'd'},
85     {"data-file", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, nullptr, 'g'},
86     {"method", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, nullptr, 'M'},
87     {"get", AWS_CLI_OPTIONS_NO_ARGUMENT, nullptr, 'G'},
88     {"post", AWS_CLI_OPTIONS_NO_ARGUMENT, nullptr, 'P'},
89     {"head", AWS_CLI_OPTIONS_NO_ARGUMENT, nullptr, 'I'},
90     {"include", AWS_CLI_OPTIONS_NO_ARGUMENT, nullptr, 'i'},
91     {"insecure", AWS_CLI_OPTIONS_NO_ARGUMENT, nullptr, 'k'},
92     {"output", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, nullptr, 'o'},
93     {"trace", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, nullptr, 't'},
94     {"verbose", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, nullptr, 'v'},
95     {"version", AWS_CLI_OPTIONS_NO_ARGUMENT, nullptr, 'V'},
96     {"http2", AWS_CLI_OPTIONS_NO_ARGUMENT, nullptr, 'w'},
97     {"http1_1", AWS_CLI_OPTIONS_NO_ARGUMENT, nullptr, 'W'},
98     {"help", AWS_CLI_OPTIONS_NO_ARGUMENT, nullptr, 'h'},
99     /* Per getopt(3) the last element of the array has to be filled with all zeros */
100     {nullptr, AWS_CLI_OPTIONS_NO_ARGUMENT, nullptr, 0},
101 };
102 
s_ParseOptions(int argc,char ** argv,ElasticurlCtx & ctx)103 static void s_ParseOptions(int argc, char **argv, ElasticurlCtx &ctx)
104 {
105     while (true)
106     {
107         int option_index = 0;
108         int c = aws_cli_getopt_long(argc, argv, "a:b:c:e:f:H:d:g:M:GPHiko:t:v:VwWh", s_LongOptions, &option_index);
109         if (c == -1)
110         {
111             /* finished parsing */
112             break;
113         }
114 
115         switch (c)
116         {
117             case 0:
118                 /* getopt_long() returns 0 if an option.flag is non-null */
119                 break;
120             case 'a':
121                 ctx.CaCert = aws_cli_optarg;
122                 break;
123             case 'b':
124                 ctx.CaPath = aws_cli_optarg;
125                 break;
126             case 'c':
127                 ctx.Cert = aws_cli_optarg;
128                 break;
129             case 'e':
130                 ctx.Key = aws_cli_optarg;
131                 break;
132             case 'f':
133                 ctx.ConnectTimeout = atoi(aws_cli_optarg);
134                 break;
135             case 'H':
136                 ctx.HeaderLines.push_back(aws_cli_optarg);
137                 break;
138             case 'd':
139             {
140                 ctx.InputBody = std::make_shared<std::stringstream>(aws_cli_optarg);
141                 break;
142             }
143             case 'g':
144             {
145                 ctx.InputBody = std::make_shared<std::ifstream>(aws_cli_optarg, std::ios::in);
146                 if (!ctx.InputBody->good())
147                 {
148                     std::cerr << "unable to open file " << aws_cli_optarg << std::endl;
149                     s_Usage(1);
150                 }
151                 break;
152             }
153             case 'M':
154                 ctx.verb = aws_cli_optarg;
155                 break;
156             case 'G':
157                 ctx.verb = "GET";
158                 break;
159             case 'P':
160                 ctx.verb = "POST";
161                 break;
162             case 'I':
163                 ctx.verb = "HEAD";
164                 break;
165             case 'i':
166                 ctx.IncludeHeaders = true;
167                 break;
168             case 'k':
169                 ctx.Insecure = true;
170                 break;
171             case 'o':
172                 ctx.Output.open(aws_cli_optarg, std::ios::out | std::ios::binary);
173                 break;
174             case 't':
175                 ctx.TraceFile = aws_cli_optarg;
176                 break;
177             case 'v':
178                 if (!strcmp(aws_cli_optarg, "TRACE"))
179                 {
180                     ctx.LogLevel = Aws::Crt::LogLevel::Trace;
181                 }
182                 else if (!strcmp(aws_cli_optarg, "INFO"))
183                 {
184                     ctx.LogLevel = Aws::Crt::LogLevel::Info;
185                 }
186                 else if (!strcmp(aws_cli_optarg, "DEBUG"))
187                 {
188                     ctx.LogLevel = Aws::Crt::LogLevel::Debug;
189                 }
190                 else if (!strcmp(aws_cli_optarg, "ERROR"))
191                 {
192                     ctx.LogLevel = Aws::Crt::LogLevel::Error;
193                 }
194                 else
195                 {
196                     std::cerr << "unsupported log level " << aws_cli_optarg << std::endl;
197                     s_Usage(1);
198                 }
199                 break;
200             case 'V':
201                 std::cerr << "elasticurl " << ELASTICURL_VERSION << std::endl;
202                 exit(0);
203             case 'w':
204                 ctx.Alpn = "h2";
205                 ctx.RequiredHttpVersion = Http::HttpVersion::Http2;
206                 break;
207             case 'W':
208                 ctx.Alpn = "http/1.1";
209                 ctx.RequiredHttpVersion = Http::HttpVersion::Http1_1;
210                 break;
211             case 'h':
212                 s_Usage(0);
213                 break;
214             default:
215                 std::cerr << "Unknown option\n";
216                 s_Usage(1);
217         }
218     }
219 
220     if (ctx.InputBody == nullptr)
221     {
222         ctx.InputBody = std::make_shared<std::stringstream>("");
223     }
224 
225     if (aws_cli_optind < argc)
226     {
227         struct aws_byte_cursor uri_cursor = aws_byte_cursor_from_c_str(argv[aws_cli_optind++]);
228 
229         ctx.uri = Io::Uri(uri_cursor, ctx.allocator);
230         if (!ctx.uri)
231         {
232             std::cerr << "Failed to parse uri" << (char *)uri_cursor.ptr << "with error "
233                       << aws_error_debug_str(ctx.uri.LastError()) << std::endl;
234             s_Usage(1);
235         };
236     }
237     else
238     {
239         std::cerr << "A URI for the request must be supplied.\n";
240         s_Usage(1);
241     }
242 }
243 
main(int argc,char ** argv)244 int main(int argc, char **argv)
245 {
246     struct aws_allocator *allocator = aws_default_allocator();
247 
248     struct ElasticurlCtx appCtx;
249     appCtx.allocator = allocator;
250 
251     s_ParseOptions(argc, argv, appCtx);
252     ApiHandle apiHandle(allocator);
253     if (appCtx.TraceFile)
254     {
255         apiHandle.InitializeLogging(appCtx.LogLevel, appCtx.TraceFile);
256     }
257     else
258     {
259         apiHandle.InitializeLogging(appCtx.LogLevel, stderr);
260     }
261     bool useTls = true;
262     uint16_t port = 443;
263     if (!appCtx.uri.GetScheme().len && (appCtx.uri.GetPort() == 80 || appCtx.uri.GetPort() == 8080))
264     {
265         useTls = false;
266     }
267     else
268     {
269         ByteCursor scheme = appCtx.uri.GetScheme();
270         if (aws_byte_cursor_eq_c_str_ignore_case(&scheme, "http"))
271         {
272             useTls = false;
273         }
274     }
275 
276     auto hostName = appCtx.uri.GetHostName();
277 
278     Io::TlsContextOptions tlsCtxOptions;
279     Io::TlsContext tlsContext;
280     Io::TlsConnectionOptions tlsConnectionOptions;
281     if (useTls)
282     {
283         if (appCtx.Cert && appCtx.Key)
284         {
285             tlsCtxOptions = Io::TlsContextOptions::InitClientWithMtls(appCtx.Cert, appCtx.Key);
286             if (!tlsCtxOptions)
287             {
288                 std::cerr << "Failed to load " << appCtx.Cert << " and " << appCtx.Key << " with error "
289                           << aws_error_debug_str(tlsCtxOptions.LastError()) << std::endl;
290                 exit(1);
291             }
292         }
293         else
294         {
295             tlsCtxOptions = Io::TlsContextOptions::InitDefaultClient();
296             if (!tlsCtxOptions)
297             {
298                 std::cerr << "Failed to create a default tlsCtxOptions with error "
299                           << aws_error_debug_str(tlsCtxOptions.LastError()) << std::endl;
300                 exit(1);
301             }
302         }
303 
304         if (appCtx.CaPath || appCtx.CaCert)
305         {
306             if (!tlsCtxOptions.OverrideDefaultTrustStore(appCtx.CaPath, appCtx.CaCert))
307             {
308                 std::cerr << "Failed to load " << appCtx.CaPath << " and " << appCtx.CaCert << " with error "
309                           << aws_error_debug_str(tlsCtxOptions.LastError()) << std::endl;
310                 exit(1);
311             }
312         }
313         if (appCtx.Insecure)
314         {
315             tlsCtxOptions.SetVerifyPeer(false);
316         }
317 
318         tlsContext = Io::TlsContext(tlsCtxOptions, Io::TlsMode::CLIENT, allocator);
319 
320         tlsConnectionOptions = tlsContext.NewConnectionOptions();
321 
322         if (!tlsConnectionOptions.SetServerName(hostName))
323         {
324             std::cerr << "Failed to set servername with error " << aws_error_debug_str(tlsConnectionOptions.LastError())
325                       << std::endl;
326             exit(1);
327         }
328         if (!tlsConnectionOptions.SetAlpnList(appCtx.Alpn))
329         {
330             std::cerr << "Failed to load alpn list with error " << aws_error_debug_str(tlsConnectionOptions.LastError())
331                       << std::endl;
332             exit(1);
333         }
334     }
335     else
336     {
337         if (appCtx.RequiredHttpVersion == Http::HttpVersion::Http2)
338         {
339             std::cerr << "Error, we don't support h2c, please use TLS for HTTP/2 connection" << std::endl;
340             exit(1);
341         }
342         port = 80;
343         if (appCtx.uri.GetPort())
344         {
345             port = appCtx.uri.GetPort();
346         }
347     }
348 
349     Io::SocketOptions socketOptions;
350     socketOptions.SetConnectTimeoutMs(appCtx.ConnectTimeout);
351 
352     Io::EventLoopGroup eventLoopGroup(0, allocator);
353     if (!eventLoopGroup)
354     {
355         std::cerr << "Failed to create evenloop group with error " << aws_error_debug_str(eventLoopGroup.LastError())
356                   << std::endl;
357         exit(1);
358     }
359 
360     Io::DefaultHostResolver defaultHostResolver(eventLoopGroup, 8, 30, allocator);
361     if (!defaultHostResolver)
362     {
363         std::cerr << "Failed to create host resolver with error "
364                   << aws_error_debug_str(defaultHostResolver.LastError()) << std::endl;
365         exit(1);
366     }
367 
368     Io::ClientBootstrap clientBootstrap(eventLoopGroup, defaultHostResolver, allocator);
369     if (!clientBootstrap)
370     {
371         std::cerr << "Failed to create client bootstrap with error " << aws_error_debug_str(clientBootstrap.LastError())
372                   << std::endl;
373         exit(1);
374     }
375     clientBootstrap.EnableBlockingShutdown();
376 
377     std::promise<std::shared_ptr<Http::HttpClientConnection>> connectionPromise;
378     std::promise<void> shutdownPromise;
379 
380     auto onConnectionSetup = [&appCtx, &connectionPromise](const std::shared_ptr<Http::HttpClientConnection> &newConnection, int errorCode) {
381         if (!errorCode)
382         {
383             if (appCtx.RequiredHttpVersion != Http::HttpVersion::Unknown)
384             {
385                 if (newConnection->GetVersion() != appCtx.RequiredHttpVersion)
386                 {
387                     std::cerr << "Error. The requested HTTP version, " << appCtx.Alpn
388                               << ", is not supported by the peer." << std::endl;
389                     exit(1);
390                 }
391             }
392         }
393         else
394         {
395             std::cerr << "Connection failed with error " << aws_error_debug_str(errorCode) << std::endl;
396             exit(1);
397         }
398         connectionPromise.set_value(newConnection);
399     };
400 
401     auto onConnectionShutdown = [&shutdownPromise](Http::HttpClientConnection &newConnection, int errorCode) {
402         (void)newConnection;
403         if (errorCode)
404         {
405             std::cerr << "Connection shutdown with error " << aws_error_debug_str(errorCode) << std::endl;
406             exit(1);
407         }
408 
409         shutdownPromise.set_value();
410     };
411 
412     Http::HttpClientConnectionOptions httpClientConnectionOptions;
413     httpClientConnectionOptions.Bootstrap = &clientBootstrap;
414     httpClientConnectionOptions.OnConnectionSetupCallback = onConnectionSetup;
415     httpClientConnectionOptions.OnConnectionShutdownCallback = onConnectionShutdown;
416     httpClientConnectionOptions.SocketOptions = socketOptions;
417     if (useTls)
418     {
419         httpClientConnectionOptions.TlsOptions = tlsConnectionOptions;
420     }
421     httpClientConnectionOptions.HostName = String((const char *)hostName.ptr, hostName.len);
422     httpClientConnectionOptions.Port = port;
423 
424     Http::HttpClientConnection::CreateConnection(httpClientConnectionOptions, allocator);
425 
426     std::shared_ptr<Http::HttpClientConnection> connection = connectionPromise.get_future().get();
427     /* Send request */
428     int responseCode = 0;
429 
430     Http::HttpRequest request;
431     Http::HttpRequestOptions requestOptions;
432     requestOptions.request = &request;
433     std::promise<void> streamCompletePromise;
434 
435     requestOptions.onStreamComplete = [&streamCompletePromise](Http::HttpStream &stream, int errorCode) {
436         (void)stream;
437         if (errorCode)
438         {
439             std::cerr << "Stream completed with error " << aws_error_debug_str(errorCode) << std::endl;
440             exit(1);
441         }
442         streamCompletePromise.set_value();
443     };
444     requestOptions.onIncomingHeadersBlockDone = nullptr;
445     requestOptions.onIncomingHeaders = [&](Http::HttpStream &stream,
446                                            enum aws_http_header_block header_block,
447                                            const Http::HttpHeader *header,
448                                            std::size_t len) {
449         /* Ignore informational headers */
450         if (header_block == AWS_HTTP_HEADER_BLOCK_INFORMATIONAL)
451         {
452             return;
453         }
454 
455         if (appCtx.IncludeHeaders)
456         {
457             if (!appCtx.ResponseCodeWritten)
458             {
459                 responseCode = stream.GetResponseStatusCode();
460                 std::cout << "Response Status: " << responseCode << std::endl;
461                 appCtx.ResponseCodeWritten = true;
462             }
463 
464             for (size_t i = 0; i < len; ++i)
465             {
466                 std::cout.write((char *)header[i].name.ptr, header[i].name.len);
467                 std::cout << ": ";
468                 std::cout.write((char *)header[i].value.ptr, header[i].value.len);
469                 std::cout << std::endl;
470             }
471         }
472     };
473     requestOptions.onIncomingBody = [&appCtx](Http::HttpStream &, const ByteCursor &data) {
474         if (appCtx.Output.is_open())
475         {
476             appCtx.Output.write((char *)data.ptr, data.len);
477         }
478         else
479         {
480             std::cout.write((char *)data.ptr, data.len);
481         }
482     };
483 
484     request.SetMethod(ByteCursorFromCString(appCtx.verb));
485     request.SetPath(appCtx.uri.GetPathAndQuery());
486 
487     Http::HttpHeader hostHeader;
488     hostHeader.name = ByteCursorFromCString("host");
489     hostHeader.value = appCtx.uri.GetHostName();
490     request.AddHeader(hostHeader);
491 
492     Http::HttpHeader userAgentHeader;
493     userAgentHeader.name = ByteCursorFromCString("user-agent");
494     userAgentHeader.value = ByteCursorFromCString("elasticurl_cpp 1.0, Powered by the AWS Common Runtime.");
495     request.AddHeader(userAgentHeader);
496 
497     std::shared_ptr<Io::StdIOStreamInputStream> bodyStream =
498         MakeShared<Io::StdIOStreamInputStream>(allocator, appCtx.InputBody, allocator);
499     int64_t dataLen;
500     if (!bodyStream->GetLength(dataLen))
501     {
502         std::cerr << "failed to get length of input stream.\n";
503         exit(1);
504     }
505     if (dataLen > 0)
506     {
507         std::string contentLength = std::to_string(dataLen);
508         Http::HttpHeader contentLengthHeader;
509         contentLengthHeader.name = ByteCursorFromCString("content-length");
510         contentLengthHeader.value = ByteCursorFromCString(contentLength.c_str());
511         request.AddHeader(contentLengthHeader);
512         request.SetBody(bodyStream);
513     }
514 
515     for (auto headerLine : appCtx.HeaderLines)
516     {
517         char *delimiter = (char *)memchr(headerLine, ':', strlen(headerLine));
518 
519         if (!delimiter)
520         {
521             std::cerr << "invalid header line " << headerLine << " configured." << std::endl;
522             exit(1);
523         }
524 
525         Http::HttpHeader userHeader;
526         userHeader.name = ByteCursorFromArray((uint8_t *)headerLine, delimiter - headerLine);
527         userHeader.value = ByteCursorFromCString(delimiter + 1);
528         request.AddHeader(userHeader);
529     }
530 
531     auto stream = connection->NewClientStream(requestOptions);
532     stream->Activate();
533 
534     streamCompletePromise.get_future().wait(); // wait for connection shutdown to complete
535 
536     connection->Close();
537     shutdownPromise.get_future().wait(); // wait for connection shutdown to complete
538 
539     return 0;
540 }
541