1 #include "ClientConnection.hpp"
2 #include "CryptoHandler.hpp"
3 #include "Headers.hpp"
4 #include "LogHandler.hpp"
5 #include "ParseConfigFile.hpp"
6 #include "PortForwardHandler.hpp"
7 #include "PortForwardSourceHandler.hpp"
8 #include "RawSocketUtils.hpp"
9 #include "ServerConnection.hpp"
10 #include "SshSetupHandler.hpp"
11 #include "TcpSocketHandler.hpp"
12 
13 #include <errno.h>
14 #include <pwd.h>
15 #include <sys/ioctl.h>
16 #include <sys/types.h>
17 #include <termios.h>
18 
19 #include "ETerminal.pb.h"
20 
21 using namespace et;
22 namespace google {}
23 namespace gflags {}
24 using namespace google;
25 using namespace gflags;
26 
27 const string SYSTEM_SSH_CONFIG_PATH = "/etc/ssh/ssh_config";
28 const string USER_SSH_CONFIG_PATH = "/.ssh/config";
29 const int KEEP_ALIVE_DURATION = 5;
30 
31 shared_ptr<ClientConnection> globalClient;
32 
33 termios terminal_backup;
34 
35 DEFINE_string(u, "", "username to login");
36 DEFINE_string(host, "localhost", "host to join");
37 DEFINE_int32(port, 2022, "port to connect on");
38 DEFINE_string(c, "", "Command to run immediately after connecting");
39 DEFINE_string(
40     prefix, "",
41     "Command prefix to launch etserver/etterminal on the server side");
42 DEFINE_string(t, "",
43               "Array of source:destination ports or "
44               "srcStart-srcEnd:dstStart-dstEnd (inclusive) port ranges (e.g. "
45               "10080:80,10443:443, 10090-10092:8000-8002)");
46 DEFINE_string(rt, "",
47               "Array of source:destination ports or "
48               "srcStart-srcEnd:dstStart-dstEnd (inclusive) port ranges (e.g. "
49               "10080:80,10443:443, 10090-10092:8000-8002)");
50 DEFINE_string(jumphost, "", "jumphost between localhost and destination");
51 DEFINE_int32(jport, 2022, "port to connect on jumphost");
52 DEFINE_bool(x, false, "flag to kill all old sessions belonging to the user");
53 DEFINE_int32(v, 0, "verbose level");
54 DEFINE_bool(logtostdout, false, "log to stdout");
55 DEFINE_bool(silent, false, "If enabled, disable logging");
56 DEFINE_bool(noratelimit, false,
57             "There's 1024 lines/second limit, which can be "
58             "disabled based on different use case.");
59 
createClient(string idpasskeypair)60 shared_ptr<ClientConnection> createClient(string idpasskeypair) {
61   string id = "", passkey = "";
62   // Trim whitespace
63   idpasskeypair.erase(idpasskeypair.find_last_not_of(" \n\r\t") + 1);
64   size_t slashIndex = idpasskeypair.find("/");
65   if (slashIndex == string::npos) {
66     LOG(FATAL) << "Invalid idPasskey id/key pair: " << idpasskeypair;
67   } else {
68     id = idpasskeypair.substr(0, slashIndex);
69     passkey = idpasskeypair.substr(slashIndex + 1);
70     LOG(INFO) << "ID PASSKEY: " << id << " " << passkey;
71   }
72   if (passkey.length() != 32) {
73     LOG(FATAL) << "Invalid/missing passkey: " << passkey << " "
74                << passkey.length();
75   }
76 
77   InitialPayload payload;
78   if (FLAGS_jumphost.length()) {
79     payload.set_jumphost(true);
80   }
81 
82   shared_ptr<SocketHandler> clientSocket(new TcpSocketHandler());
83   shared_ptr<ClientConnection> client =
84       shared_ptr<ClientConnection>(new ClientConnection(
85           clientSocket, SocketEndpoint(FLAGS_host, FLAGS_port), id, passkey));
86 
87   int connectFailCount = 0;
88   while (true) {
89     try {
90       if (client->connect()) {
91         client->writeProto(payload);
92         break;
93       } else {
94         LOG(ERROR) << "Connecting to server failed: Connect timeout";
95         connectFailCount++;
96         if (connectFailCount == 3) {
97           throw std::runtime_error("Connect Timeout");
98         }
99       }
100     } catch (const runtime_error& err) {
101       LOG(INFO) << "Could not make initial connection to server";
102       cout << "Could not make initial connection to " << FLAGS_host << ": "
103            << err.what() << endl;
104       exit(1);
105     }
106     break;
107   }
108   VLOG(1) << "Client created with id: " << client->getId();
109 
110   return client;
111 };
112 
113 int firstWindowChangedCall = 1;
handleWindowChanged(winsize * win)114 void handleWindowChanged(winsize* win) {
115   winsize tmpwin;
116   ioctl(1, TIOCGWINSZ, &tmpwin);
117   if (firstWindowChangedCall || win->ws_row != tmpwin.ws_row ||
118       win->ws_col != tmpwin.ws_col || win->ws_xpixel != tmpwin.ws_xpixel ||
119       win->ws_ypixel != tmpwin.ws_ypixel) {
120     firstWindowChangedCall = 0;
121     *win = tmpwin;
122     LOG(INFO) << "Window size changed: " << win->ws_row << " " << win->ws_col
123               << " " << win->ws_xpixel << " " << win->ws_ypixel;
124     TerminalInfo ti;
125     ti.set_row(win->ws_row);
126     ti.set_column(win->ws_col);
127     ti.set_width(win->ws_xpixel);
128     ti.set_height(win->ws_ypixel);
129     string s(1, (char)et::PacketType::TERMINAL_INFO);
130     globalClient->writeMessage(s);
131     globalClient->writeProto(ti);
132   }
133 }
134 
parseRangesToPairs(const string & input)135 vector<pair<int, int>> parseRangesToPairs(const string& input) {
136   vector<pair<int, int>> pairs;
137   auto j = split(input, ',');
138   for (auto& pair : j) {
139     vector<string> sourceDestination = split(pair, ':');
140     try {
141       if (sourceDestination[0].find('-') != string::npos &&
142           sourceDestination[1].find('-') != string::npos) {
143         vector<string> sourcePortRange = split(sourceDestination[0], '-');
144         int sourcePortStart = stoi(sourcePortRange[0]);
145         int sourcePortEnd = stoi(sourcePortRange[1]);
146 
147         vector<string> destinationPortRange = split(sourceDestination[1], '-');
148         int destinationPortStart = stoi(destinationPortRange[0]);
149         int destinationPortEnd = stoi(destinationPortRange[1]);
150 
151         if (sourcePortEnd - sourcePortStart !=
152             destinationPortEnd - destinationPortStart) {
153           LOG(FATAL) << "source/destination port range mismatch";
154           exit(1);
155         } else {
156           int portRangeLength = sourcePortEnd - sourcePortStart + 1;
157           for (int i = 0; i < portRangeLength; ++i) {
158             pairs.push_back(
159                 make_pair(sourcePortStart + i, destinationPortStart + i));
160           }
161         }
162       } else if (sourceDestination[0].find('-') != string::npos ||
163                  sourceDestination[1].find('-') != string::npos) {
164         LOG(FATAL) << "Invalid port range syntax: if source is range, "
165                       "destination must be range";
166       } else {
167         int sourcePort = stoi(sourceDestination[0]);
168         int destinationPort = stoi(sourceDestination[1]);
169         pairs.push_back(make_pair(sourcePort, destinationPort));
170       }
171     } catch (const std::logic_error& lr) {
172       LOG(FATAL) << "Logic error: " << lr.what();
173       exit(1);
174     }
175   }
176   return pairs;
177 }
178 
main(int argc,char ** argv)179 int main(int argc, char** argv) {
180   // Version string need to be set before GFLAGS parse arguments
181   SetVersionString(string(ET_VERSION));
182 
183   // Setup easylogging configurations
184   el::Configurations defaultConf = LogHandler::setupLogHandler(&argc, &argv);
185 
186   if (FLAGS_logtostdout) {
187     defaultConf.setGlobally(el::ConfigurationType::ToStandardOutput, "true");
188   } else {
189     defaultConf.setGlobally(el::ConfigurationType::ToStandardOutput, "false");
190     // Redirect std streams to a file
191     LogHandler::stderrToFile("/tmp/etclient");
192   }
193 
194   // silent Flag, since etclient doesn't read /etc/et.cfg file
195   if (FLAGS_silent) {
196     defaultConf.setGlobally(el::ConfigurationType::Enabled, "false");
197   }
198 
199   LogHandler::setupLogFile(&defaultConf,
200                            "/tmp/etclient-%datetime{%Y-%M-%d_%H_%m_%s}.log");
201 
202   el::Loggers::reconfigureLogger("default", defaultConf);
203   // set thread name
204   el::Helpers::setThreadName("client-main");
205 
206   // Install log rotation callback
207   el::Helpers::installPreRollOutCallback(LogHandler::rolloutHandler);
208 
209   // Override -h & --help
210   for (int i = 1; i < argc; i++) {
211     string s(argv[i]);
212     if (s == "-h" || s == "--help") {
213       cout << "et (options) [user@]hostname[:port]\n"
214               "Options:\n"
215               "-h Basic usage\n"
216               "-p Port for etserver to run on.  Default: 2022\n"
217               "-u Username to connect to ssh & ET\n"
218               "-v=9 verbose log files\n"
219               "-c Initial command to execute upon connecting\n"
220               "-prefix Command prefix to launch etserver/etterminal on the "
221               "server side\n"
222               "-t Map local to remote TCP port (TCP Tunneling)\n"
223               "   example: et -t=\"18000:8000\" hostname maps localhost:18000\n"
224               "-rt Map remote to local TCP port (TCP Reverse Tunneling)\n"
225               "   example: et -rt=\"18000:8000\" hostname maps hostname:18000\n"
226               "to localhost:8000\n"
227               "-jumphost Jumphost between localhost and destination\n"
228               "-jport Port to connect on jumphost\n"
229               "-x Flag to kill all sessions belongs to the user\n"
230               "-logtostdout Sent log message to stdout\n"
231               "-silent Disable all logs\n"
232               "-noratelimit Disable rate limit"
233            << endl;
234       exit(1);
235     }
236   }
237 
238   GOOGLE_PROTOBUF_VERIFY_VERSION;
239   srand(1);
240 
241   // Parse command-line argument
242   if (argc > 1) {
243     string arg = string(argv[1]);
244     if (arg.find('@') != string::npos) {
245       int i = arg.find('@');
246       FLAGS_u = arg.substr(0, i);
247       arg = arg.substr(i + 1);
248     }
249     if (arg.find(':') != string::npos) {
250       int i = arg.find(':');
251       FLAGS_port = stoi(arg.substr(i + 1));
252       arg = arg.substr(0, i);
253     }
254     FLAGS_host = arg;
255   }
256 
257   Options options = {
258       NULL,  // username
259       NULL,  // host
260       NULL,  // sshdir
261       NULL,  // knownhosts
262       NULL,  // ProxyCommand
263       NULL,  // ProxyJump
264       0,     // timeout
265       0,     // port
266       0,     // StrictHostKeyChecking
267       0,     // ssh2
268       0,     // ssh1
269       NULL,  // gss_server_identity
270       NULL,  // gss_client_identity
271       0      // gss_delegate_creds
272   };
273 
274   char* home_dir = ssh_get_user_home_dir();
275   string host_alias = FLAGS_host;
276   ssh_options_set(&options, SSH_OPTIONS_HOST, FLAGS_host.c_str());
277   // First parse user-specific ssh config, then system-wide config.
278   parse_ssh_config_file(&options, string(home_dir) + USER_SSH_CONFIG_PATH);
279   parse_ssh_config_file(&options, SYSTEM_SSH_CONFIG_PATH);
280   LOG(INFO) << "Parsed ssh config file, connecting to " << options.host;
281   FLAGS_host = string(options.host);
282 
283   // Parse username: cmdline > sshconfig > localuser
284   if (FLAGS_u.empty()) {
285     if (options.username) {
286       FLAGS_u = string(options.username);
287     } else {
288       FLAGS_u = string(ssh_get_local_username());
289     }
290   }
291 
292   // Parse jumphost: cmd > sshconfig
293   if (options.ProxyJump && FLAGS_jumphost.length() == 0) {
294     string proxyjump = string(options.ProxyJump);
295     size_t colonIndex = proxyjump.find(":");
296     if (colonIndex != string::npos) {
297       string userhostpair = proxyjump.substr(0, colonIndex);
298       size_t atIndex = userhostpair.find("@");
299       if (atIndex != string::npos) {
300         FLAGS_jumphost = userhostpair.substr(atIndex + 1);
301       }
302     } else {
303       FLAGS_jumphost = proxyjump;
304     }
305     LOG(INFO) << "ProxyJump found for dst in ssh config" << proxyjump;
306   }
307 
308   string idpasskeypair = SshSetupHandler::SetupSsh(
309       FLAGS_u, FLAGS_host, host_alias, FLAGS_port, FLAGS_jumphost, FLAGS_jport,
310       FLAGS_x, FLAGS_v, FLAGS_prefix, FLAGS_noratelimit);
311 
312   if (!FLAGS_jumphost.empty()) {
313     FLAGS_host = FLAGS_jumphost;
314     FLAGS_port = FLAGS_jport;
315   }
316   globalClient = createClient(idpasskeypair);
317   shared_ptr<TcpSocketHandler> socketHandler =
318       static_pointer_cast<TcpSocketHandler>(globalClient->getSocketHandler());
319 
320   PortForwardHandler portForwardHandler(socketHandler);
321 
322   // Whether the TE should keep running.
323   bool run = true;
324 
325 // TE sends/receives data to/from the shell one char at a time.
326 #define BUF_SIZE (16 * 1024)
327   char b[BUF_SIZE];
328 
329   time_t keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
330   bool waitingOnKeepalive = false;
331 
332   if (FLAGS_c.length()) {
333     LOG(INFO) << "Got command: " << FLAGS_c;
334     et::TerminalBuffer tb;
335     tb.set_buffer(FLAGS_c + "; exit\n");
336 
337     char c = et::PacketType::TERMINAL_BUFFER;
338     string headerString(1, c);
339     globalClient->writeMessage(headerString);
340     globalClient->writeProto(tb);
341   }
342 
343   try {
344     if (FLAGS_t.length()) {
345       auto pairs = parseRangesToPairs(FLAGS_t);
346       for (auto& pair : pairs) {
347         PortForwardSourceRequest pfsr;
348         pfsr.set_sourceport(pair.first);
349         pfsr.set_destinationport(pair.second);
350         auto pfsresponse = portForwardHandler.createSource(pfsr);
351         if (pfsresponse.has_error()) {
352           throw std::runtime_error(pfsresponse.error());
353         }
354       }
355     }
356     if (FLAGS_rt.length()) {
357       auto pairs = parseRangesToPairs(FLAGS_rt);
358       for (auto& pair : pairs) {
359         char c = et::PacketType::PORT_FORWARD_SOURCE_REQUEST;
360         string headerString(1, c);
361         PortForwardSourceRequest pfsr;
362         pfsr.set_sourceport(pair.first);
363         pfsr.set_destinationport(pair.second);
364 
365         globalClient->writeMessage(headerString);
366         globalClient->writeProto(pfsr);
367       }
368     }
369   } catch (const std::runtime_error& ex) {
370     cerr << "Error establishing port forward: " << ex.what() << endl;
371     LOG(FATAL) << "Error establishing port forward: " << ex.what();
372   }
373 
374   winsize win;
375   ioctl(1, TIOCGWINSZ, &win);
376 
377   termios terminal_local;
378   tcgetattr(0, &terminal_local);
379   memcpy(&terminal_backup, &terminal_local, sizeof(struct termios));
380   cfmakeraw(&terminal_local);
381   tcsetattr(0, TCSANOW, &terminal_local);
382 
383   while (run && !globalClient->isShuttingDown()) {
384     // Data structures needed for select() and
385     // non-blocking I/O.
386     fd_set rfd;
387     timeval tv;
388 
389     FD_ZERO(&rfd);
390     int maxfd = STDIN_FILENO;
391     FD_SET(STDIN_FILENO, &rfd);
392     int clientFd = globalClient->getSocketFd();
393     if (clientFd > 0) {
394       FD_SET(clientFd, &rfd);
395       maxfd = max(maxfd, clientFd);
396     }
397     // TODO: set port forward sockets as well for performance reasons.
398     tv.tv_sec = 0;
399     tv.tv_usec = 10000;
400     select(maxfd + 1, &rfd, NULL, NULL, &tv);
401 
402     try {
403       // Check for data to send.
404       if (FD_ISSET(STDIN_FILENO, &rfd)) {
405         // Read from stdin and write to our client that will then send it to the
406         // server.
407         VLOG(4) << "Got data from stdin";
408         int rc = read(STDIN_FILENO, b, BUF_SIZE);
409         FATAL_FAIL(rc);
410         if (rc > 0) {
411           // VLOG(1) << "Sending byte: " << int(b) << " " << char(b) << " " <<
412           // globalClient->getWriter()->getSequenceNumber();
413           string s(b, rc);
414           et::TerminalBuffer tb;
415           tb.set_buffer(s);
416 
417           char c = et::PacketType::TERMINAL_BUFFER;
418           string headerString(1, c);
419           globalClient->writeMessage(headerString);
420           globalClient->writeProto(tb);
421           keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
422         }
423       }
424 
425       if (clientFd > 0 && FD_ISSET(clientFd, &rfd)) {
426         VLOG(4) << "Cliendfd is selected";
427         while (globalClient->hasData()) {
428           VLOG(4) << "GlobalClient has data";
429           string packetTypeString;
430           if (!globalClient->read(&packetTypeString)) {
431             break;
432           }
433           if (packetTypeString.length() != 1) {
434             LOG(FATAL) << "Invalid packet header size: "
435                        << packetTypeString.length();
436           }
437           char packetType = packetTypeString[0];
438           if (packetType == et::PacketType::PORT_FORWARD_DATA ||
439               packetType == et::PacketType::PORT_FORWARD_SOURCE_REQUEST ||
440               packetType == et::PacketType::PORT_FORWARD_SOURCE_RESPONSE ||
441               packetType == et::PacketType::PORT_FORWARD_DESTINATION_REQUEST ||
442               packetType == et::PacketType::PORT_FORWARD_DESTINATION_RESPONSE) {
443             keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
444             VLOG(4) << "Got PF packet type " << packetType;
445             portForwardHandler.handlePacket(packetType, globalClient);
446             continue;
447           }
448           switch (packetType) {
449             case et::PacketType::TERMINAL_BUFFER: {
450               VLOG(3) << "Got terminal buffer";
451               // Read from the server and write to our fake terminal
452               et::TerminalBuffer tb =
453                   globalClient->readProto<et::TerminalBuffer>();
454               const string& s = tb.buffer();
455               // VLOG(5) << "Got message: " << s;
456               // VLOG(1) << "Got byte: " << int(b) << " " << char(b) << " " <<
457               // globalClient->getReader()->getSequenceNumber();
458               keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
459               RawSocketUtils::writeAll(STDOUT_FILENO, &s[0], s.length());
460               break;
461             }
462             case et::PacketType::KEEP_ALIVE:
463               waitingOnKeepalive = false;
464               // This will fill up log file quickly but is helpful for debugging
465               // latency issues.
466               LOG(INFO) << "Got a keepalive";
467               break;
468             default:
469               LOG(FATAL) << "Unknown packet type: " << int(packetType);
470           }
471         }
472       }
473 
474       if (clientFd > 0 && keepaliveTime < time(NULL)) {
475         keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
476         if (waitingOnKeepalive) {
477           LOG(INFO) << "Missed a keepalive, killing connection.";
478           globalClient->closeSocketAndMaybeReconnect();
479           waitingOnKeepalive = false;
480         } else {
481           LOG(INFO) << "Writing keepalive packet";
482           string s(1, (char)et::PacketType::KEEP_ALIVE);
483           globalClient->writeMessage(s);
484           waitingOnKeepalive = true;
485         }
486       }
487       if (clientFd < 0) {
488         // We are disconnected, so stop waiting for keepalive.
489         waitingOnKeepalive = false;
490       }
491 
492       handleWindowChanged(&win);
493 
494       vector<PortForwardDestinationRequest> requests;
495       vector<PortForwardData> dataToSend;
496       portForwardHandler.update(&requests, &dataToSend);
497       for (auto& pfr : requests) {
498         char c = et::PacketType::PORT_FORWARD_DESTINATION_REQUEST;
499         string headerString(1, c);
500         globalClient->writeMessage(headerString);
501         globalClient->writeProto(pfr);
502         VLOG(4) << "send PF request";
503         keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
504       }
505       for (auto& pwd : dataToSend) {
506         char c = PacketType::PORT_FORWARD_DATA;
507         string headerString(1, c);
508         globalClient->writeMessage(headerString);
509         globalClient->writeProto(pwd);
510         VLOG(4) << "send PF data";
511         keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
512       }
513     } catch (const runtime_error& re) {
514       LOG(ERROR) << "Error: " << re.what();
515       tcsetattr(0, TCSANOW, &terminal_backup);
516       cout << "Connection closing because of error: " << re.what() << endl;
517       run = false;
518     }
519   }
520   globalClient.reset();
521   LOG(INFO) << "Client derefernced";
522   tcsetattr(0, TCSANOW, &terminal_backup);
523   cout << "Session terminated" << endl;
524   // Uninstall log rotation callback
525   el::Helpers::uninstallPreRollOutCallback();
526   return 0;
527 }
528