1 #include "ClientConnection.hpp"
2 #include "CryptoHandler.hpp"
3 #include "Headers.hpp"
4 #include "LogHandler.hpp"
5 #include "ParseConfigFile.hpp"
6 #include "PortForwardHandler.hpp"
7 #include "PortForwardSourceHandler.hpp"
8 #include "RawSocketUtils.hpp"
9 #include "ServerConnection.hpp"
10 #include "SshSetupHandler.hpp"
11 #include "TcpSocketHandler.hpp"
12
13 #include <errno.h>
14 #include <pwd.h>
15 #include <sys/ioctl.h>
16 #include <sys/types.h>
17 #include <termios.h>
18
19 #include "ETerminal.pb.h"
20
21 using namespace et;
22 namespace google {}
23 namespace gflags {}
24 using namespace google;
25 using namespace gflags;
26
27 const string SYSTEM_SSH_CONFIG_PATH = "/etc/ssh/ssh_config";
28 const string USER_SSH_CONFIG_PATH = "/.ssh/config";
29 const int KEEP_ALIVE_DURATION = 5;
30
31 shared_ptr<ClientConnection> globalClient;
32
33 termios terminal_backup;
34
35 DEFINE_string(u, "", "username to login");
36 DEFINE_string(host, "localhost", "host to join");
37 DEFINE_int32(port, 2022, "port to connect on");
38 DEFINE_string(c, "", "Command to run immediately after connecting");
39 DEFINE_string(
40 prefix, "",
41 "Command prefix to launch etserver/etterminal on the server side");
42 DEFINE_string(t, "",
43 "Array of source:destination ports or "
44 "srcStart-srcEnd:dstStart-dstEnd (inclusive) port ranges (e.g. "
45 "10080:80,10443:443, 10090-10092:8000-8002)");
46 DEFINE_string(rt, "",
47 "Array of source:destination ports or "
48 "srcStart-srcEnd:dstStart-dstEnd (inclusive) port ranges (e.g. "
49 "10080:80,10443:443, 10090-10092:8000-8002)");
50 DEFINE_string(jumphost, "", "jumphost between localhost and destination");
51 DEFINE_int32(jport, 2022, "port to connect on jumphost");
52 DEFINE_bool(x, false, "flag to kill all old sessions belonging to the user");
53 DEFINE_int32(v, 0, "verbose level");
54 DEFINE_bool(logtostdout, false, "log to stdout");
55 DEFINE_bool(silent, false, "If enabled, disable logging");
56 DEFINE_bool(noratelimit, false,
57 "There's 1024 lines/second limit, which can be "
58 "disabled based on different use case.");
59
createClient(string idpasskeypair)60 shared_ptr<ClientConnection> createClient(string idpasskeypair) {
61 string id = "", passkey = "";
62 // Trim whitespace
63 idpasskeypair.erase(idpasskeypair.find_last_not_of(" \n\r\t") + 1);
64 size_t slashIndex = idpasskeypair.find("/");
65 if (slashIndex == string::npos) {
66 LOG(FATAL) << "Invalid idPasskey id/key pair: " << idpasskeypair;
67 } else {
68 id = idpasskeypair.substr(0, slashIndex);
69 passkey = idpasskeypair.substr(slashIndex + 1);
70 LOG(INFO) << "ID PASSKEY: " << id << " " << passkey;
71 }
72 if (passkey.length() != 32) {
73 LOG(FATAL) << "Invalid/missing passkey: " << passkey << " "
74 << passkey.length();
75 }
76
77 InitialPayload payload;
78 if (FLAGS_jumphost.length()) {
79 payload.set_jumphost(true);
80 }
81
82 shared_ptr<SocketHandler> clientSocket(new TcpSocketHandler());
83 shared_ptr<ClientConnection> client =
84 shared_ptr<ClientConnection>(new ClientConnection(
85 clientSocket, SocketEndpoint(FLAGS_host, FLAGS_port), id, passkey));
86
87 int connectFailCount = 0;
88 while (true) {
89 try {
90 if (client->connect()) {
91 client->writeProto(payload);
92 break;
93 } else {
94 LOG(ERROR) << "Connecting to server failed: Connect timeout";
95 connectFailCount++;
96 if (connectFailCount == 3) {
97 throw std::runtime_error("Connect Timeout");
98 }
99 }
100 } catch (const runtime_error& err) {
101 LOG(INFO) << "Could not make initial connection to server";
102 cout << "Could not make initial connection to " << FLAGS_host << ": "
103 << err.what() << endl;
104 exit(1);
105 }
106 break;
107 }
108 VLOG(1) << "Client created with id: " << client->getId();
109
110 return client;
111 };
112
113 int firstWindowChangedCall = 1;
handleWindowChanged(winsize * win)114 void handleWindowChanged(winsize* win) {
115 winsize tmpwin;
116 ioctl(1, TIOCGWINSZ, &tmpwin);
117 if (firstWindowChangedCall || win->ws_row != tmpwin.ws_row ||
118 win->ws_col != tmpwin.ws_col || win->ws_xpixel != tmpwin.ws_xpixel ||
119 win->ws_ypixel != tmpwin.ws_ypixel) {
120 firstWindowChangedCall = 0;
121 *win = tmpwin;
122 LOG(INFO) << "Window size changed: " << win->ws_row << " " << win->ws_col
123 << " " << win->ws_xpixel << " " << win->ws_ypixel;
124 TerminalInfo ti;
125 ti.set_row(win->ws_row);
126 ti.set_column(win->ws_col);
127 ti.set_width(win->ws_xpixel);
128 ti.set_height(win->ws_ypixel);
129 string s(1, (char)et::PacketType::TERMINAL_INFO);
130 globalClient->writeMessage(s);
131 globalClient->writeProto(ti);
132 }
133 }
134
parseRangesToPairs(const string & input)135 vector<pair<int, int>> parseRangesToPairs(const string& input) {
136 vector<pair<int, int>> pairs;
137 auto j = split(input, ',');
138 for (auto& pair : j) {
139 vector<string> sourceDestination = split(pair, ':');
140 try {
141 if (sourceDestination[0].find('-') != string::npos &&
142 sourceDestination[1].find('-') != string::npos) {
143 vector<string> sourcePortRange = split(sourceDestination[0], '-');
144 int sourcePortStart = stoi(sourcePortRange[0]);
145 int sourcePortEnd = stoi(sourcePortRange[1]);
146
147 vector<string> destinationPortRange = split(sourceDestination[1], '-');
148 int destinationPortStart = stoi(destinationPortRange[0]);
149 int destinationPortEnd = stoi(destinationPortRange[1]);
150
151 if (sourcePortEnd - sourcePortStart !=
152 destinationPortEnd - destinationPortStart) {
153 LOG(FATAL) << "source/destination port range mismatch";
154 exit(1);
155 } else {
156 int portRangeLength = sourcePortEnd - sourcePortStart + 1;
157 for (int i = 0; i < portRangeLength; ++i) {
158 pairs.push_back(
159 make_pair(sourcePortStart + i, destinationPortStart + i));
160 }
161 }
162 } else if (sourceDestination[0].find('-') != string::npos ||
163 sourceDestination[1].find('-') != string::npos) {
164 LOG(FATAL) << "Invalid port range syntax: if source is range, "
165 "destination must be range";
166 } else {
167 int sourcePort = stoi(sourceDestination[0]);
168 int destinationPort = stoi(sourceDestination[1]);
169 pairs.push_back(make_pair(sourcePort, destinationPort));
170 }
171 } catch (const std::logic_error& lr) {
172 LOG(FATAL) << "Logic error: " << lr.what();
173 exit(1);
174 }
175 }
176 return pairs;
177 }
178
main(int argc,char ** argv)179 int main(int argc, char** argv) {
180 // Version string need to be set before GFLAGS parse arguments
181 SetVersionString(string(ET_VERSION));
182
183 // Setup easylogging configurations
184 el::Configurations defaultConf = LogHandler::setupLogHandler(&argc, &argv);
185
186 if (FLAGS_logtostdout) {
187 defaultConf.setGlobally(el::ConfigurationType::ToStandardOutput, "true");
188 } else {
189 defaultConf.setGlobally(el::ConfigurationType::ToStandardOutput, "false");
190 // Redirect std streams to a file
191 LogHandler::stderrToFile("/tmp/etclient");
192 }
193
194 // silent Flag, since etclient doesn't read /etc/et.cfg file
195 if (FLAGS_silent) {
196 defaultConf.setGlobally(el::ConfigurationType::Enabled, "false");
197 }
198
199 LogHandler::setupLogFile(&defaultConf,
200 "/tmp/etclient-%datetime{%Y-%M-%d_%H_%m_%s}.log");
201
202 el::Loggers::reconfigureLogger("default", defaultConf);
203 // set thread name
204 el::Helpers::setThreadName("client-main");
205
206 // Install log rotation callback
207 el::Helpers::installPreRollOutCallback(LogHandler::rolloutHandler);
208
209 // Override -h & --help
210 for (int i = 1; i < argc; i++) {
211 string s(argv[i]);
212 if (s == "-h" || s == "--help") {
213 cout << "et (options) [user@]hostname[:port]\n"
214 "Options:\n"
215 "-h Basic usage\n"
216 "-p Port for etserver to run on. Default: 2022\n"
217 "-u Username to connect to ssh & ET\n"
218 "-v=9 verbose log files\n"
219 "-c Initial command to execute upon connecting\n"
220 "-prefix Command prefix to launch etserver/etterminal on the "
221 "server side\n"
222 "-t Map local to remote TCP port (TCP Tunneling)\n"
223 " example: et -t=\"18000:8000\" hostname maps localhost:18000\n"
224 "-rt Map remote to local TCP port (TCP Reverse Tunneling)\n"
225 " example: et -rt=\"18000:8000\" hostname maps hostname:18000\n"
226 "to localhost:8000\n"
227 "-jumphost Jumphost between localhost and destination\n"
228 "-jport Port to connect on jumphost\n"
229 "-x Flag to kill all sessions belongs to the user\n"
230 "-logtostdout Sent log message to stdout\n"
231 "-silent Disable all logs\n"
232 "-noratelimit Disable rate limit"
233 << endl;
234 exit(1);
235 }
236 }
237
238 GOOGLE_PROTOBUF_VERIFY_VERSION;
239 srand(1);
240
241 // Parse command-line argument
242 if (argc > 1) {
243 string arg = string(argv[1]);
244 if (arg.find('@') != string::npos) {
245 int i = arg.find('@');
246 FLAGS_u = arg.substr(0, i);
247 arg = arg.substr(i + 1);
248 }
249 if (arg.find(':') != string::npos) {
250 int i = arg.find(':');
251 FLAGS_port = stoi(arg.substr(i + 1));
252 arg = arg.substr(0, i);
253 }
254 FLAGS_host = arg;
255 }
256
257 Options options = {
258 NULL, // username
259 NULL, // host
260 NULL, // sshdir
261 NULL, // knownhosts
262 NULL, // ProxyCommand
263 NULL, // ProxyJump
264 0, // timeout
265 0, // port
266 0, // StrictHostKeyChecking
267 0, // ssh2
268 0, // ssh1
269 NULL, // gss_server_identity
270 NULL, // gss_client_identity
271 0 // gss_delegate_creds
272 };
273
274 char* home_dir = ssh_get_user_home_dir();
275 string host_alias = FLAGS_host;
276 ssh_options_set(&options, SSH_OPTIONS_HOST, FLAGS_host.c_str());
277 // First parse user-specific ssh config, then system-wide config.
278 parse_ssh_config_file(&options, string(home_dir) + USER_SSH_CONFIG_PATH);
279 parse_ssh_config_file(&options, SYSTEM_SSH_CONFIG_PATH);
280 LOG(INFO) << "Parsed ssh config file, connecting to " << options.host;
281 FLAGS_host = string(options.host);
282
283 // Parse username: cmdline > sshconfig > localuser
284 if (FLAGS_u.empty()) {
285 if (options.username) {
286 FLAGS_u = string(options.username);
287 } else {
288 FLAGS_u = string(ssh_get_local_username());
289 }
290 }
291
292 // Parse jumphost: cmd > sshconfig
293 if (options.ProxyJump && FLAGS_jumphost.length() == 0) {
294 string proxyjump = string(options.ProxyJump);
295 size_t colonIndex = proxyjump.find(":");
296 if (colonIndex != string::npos) {
297 string userhostpair = proxyjump.substr(0, colonIndex);
298 size_t atIndex = userhostpair.find("@");
299 if (atIndex != string::npos) {
300 FLAGS_jumphost = userhostpair.substr(atIndex + 1);
301 }
302 } else {
303 FLAGS_jumphost = proxyjump;
304 }
305 LOG(INFO) << "ProxyJump found for dst in ssh config" << proxyjump;
306 }
307
308 string idpasskeypair = SshSetupHandler::SetupSsh(
309 FLAGS_u, FLAGS_host, host_alias, FLAGS_port, FLAGS_jumphost, FLAGS_jport,
310 FLAGS_x, FLAGS_v, FLAGS_prefix, FLAGS_noratelimit);
311
312 if (!FLAGS_jumphost.empty()) {
313 FLAGS_host = FLAGS_jumphost;
314 FLAGS_port = FLAGS_jport;
315 }
316 globalClient = createClient(idpasskeypair);
317 shared_ptr<TcpSocketHandler> socketHandler =
318 static_pointer_cast<TcpSocketHandler>(globalClient->getSocketHandler());
319
320 PortForwardHandler portForwardHandler(socketHandler);
321
322 // Whether the TE should keep running.
323 bool run = true;
324
325 // TE sends/receives data to/from the shell one char at a time.
326 #define BUF_SIZE (16 * 1024)
327 char b[BUF_SIZE];
328
329 time_t keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
330 bool waitingOnKeepalive = false;
331
332 if (FLAGS_c.length()) {
333 LOG(INFO) << "Got command: " << FLAGS_c;
334 et::TerminalBuffer tb;
335 tb.set_buffer(FLAGS_c + "; exit\n");
336
337 char c = et::PacketType::TERMINAL_BUFFER;
338 string headerString(1, c);
339 globalClient->writeMessage(headerString);
340 globalClient->writeProto(tb);
341 }
342
343 try {
344 if (FLAGS_t.length()) {
345 auto pairs = parseRangesToPairs(FLAGS_t);
346 for (auto& pair : pairs) {
347 PortForwardSourceRequest pfsr;
348 pfsr.set_sourceport(pair.first);
349 pfsr.set_destinationport(pair.second);
350 auto pfsresponse = portForwardHandler.createSource(pfsr);
351 if (pfsresponse.has_error()) {
352 throw std::runtime_error(pfsresponse.error());
353 }
354 }
355 }
356 if (FLAGS_rt.length()) {
357 auto pairs = parseRangesToPairs(FLAGS_rt);
358 for (auto& pair : pairs) {
359 char c = et::PacketType::PORT_FORWARD_SOURCE_REQUEST;
360 string headerString(1, c);
361 PortForwardSourceRequest pfsr;
362 pfsr.set_sourceport(pair.first);
363 pfsr.set_destinationport(pair.second);
364
365 globalClient->writeMessage(headerString);
366 globalClient->writeProto(pfsr);
367 }
368 }
369 } catch (const std::runtime_error& ex) {
370 cerr << "Error establishing port forward: " << ex.what() << endl;
371 LOG(FATAL) << "Error establishing port forward: " << ex.what();
372 }
373
374 winsize win;
375 ioctl(1, TIOCGWINSZ, &win);
376
377 termios terminal_local;
378 tcgetattr(0, &terminal_local);
379 memcpy(&terminal_backup, &terminal_local, sizeof(struct termios));
380 cfmakeraw(&terminal_local);
381 tcsetattr(0, TCSANOW, &terminal_local);
382
383 while (run && !globalClient->isShuttingDown()) {
384 // Data structures needed for select() and
385 // non-blocking I/O.
386 fd_set rfd;
387 timeval tv;
388
389 FD_ZERO(&rfd);
390 int maxfd = STDIN_FILENO;
391 FD_SET(STDIN_FILENO, &rfd);
392 int clientFd = globalClient->getSocketFd();
393 if (clientFd > 0) {
394 FD_SET(clientFd, &rfd);
395 maxfd = max(maxfd, clientFd);
396 }
397 // TODO: set port forward sockets as well for performance reasons.
398 tv.tv_sec = 0;
399 tv.tv_usec = 10000;
400 select(maxfd + 1, &rfd, NULL, NULL, &tv);
401
402 try {
403 // Check for data to send.
404 if (FD_ISSET(STDIN_FILENO, &rfd)) {
405 // Read from stdin and write to our client that will then send it to the
406 // server.
407 VLOG(4) << "Got data from stdin";
408 int rc = read(STDIN_FILENO, b, BUF_SIZE);
409 FATAL_FAIL(rc);
410 if (rc > 0) {
411 // VLOG(1) << "Sending byte: " << int(b) << " " << char(b) << " " <<
412 // globalClient->getWriter()->getSequenceNumber();
413 string s(b, rc);
414 et::TerminalBuffer tb;
415 tb.set_buffer(s);
416
417 char c = et::PacketType::TERMINAL_BUFFER;
418 string headerString(1, c);
419 globalClient->writeMessage(headerString);
420 globalClient->writeProto(tb);
421 keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
422 }
423 }
424
425 if (clientFd > 0 && FD_ISSET(clientFd, &rfd)) {
426 VLOG(4) << "Cliendfd is selected";
427 while (globalClient->hasData()) {
428 VLOG(4) << "GlobalClient has data";
429 string packetTypeString;
430 if (!globalClient->read(&packetTypeString)) {
431 break;
432 }
433 if (packetTypeString.length() != 1) {
434 LOG(FATAL) << "Invalid packet header size: "
435 << packetTypeString.length();
436 }
437 char packetType = packetTypeString[0];
438 if (packetType == et::PacketType::PORT_FORWARD_DATA ||
439 packetType == et::PacketType::PORT_FORWARD_SOURCE_REQUEST ||
440 packetType == et::PacketType::PORT_FORWARD_SOURCE_RESPONSE ||
441 packetType == et::PacketType::PORT_FORWARD_DESTINATION_REQUEST ||
442 packetType == et::PacketType::PORT_FORWARD_DESTINATION_RESPONSE) {
443 keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
444 VLOG(4) << "Got PF packet type " << packetType;
445 portForwardHandler.handlePacket(packetType, globalClient);
446 continue;
447 }
448 switch (packetType) {
449 case et::PacketType::TERMINAL_BUFFER: {
450 VLOG(3) << "Got terminal buffer";
451 // Read from the server and write to our fake terminal
452 et::TerminalBuffer tb =
453 globalClient->readProto<et::TerminalBuffer>();
454 const string& s = tb.buffer();
455 // VLOG(5) << "Got message: " << s;
456 // VLOG(1) << "Got byte: " << int(b) << " " << char(b) << " " <<
457 // globalClient->getReader()->getSequenceNumber();
458 keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
459 RawSocketUtils::writeAll(STDOUT_FILENO, &s[0], s.length());
460 break;
461 }
462 case et::PacketType::KEEP_ALIVE:
463 waitingOnKeepalive = false;
464 // This will fill up log file quickly but is helpful for debugging
465 // latency issues.
466 LOG(INFO) << "Got a keepalive";
467 break;
468 default:
469 LOG(FATAL) << "Unknown packet type: " << int(packetType);
470 }
471 }
472 }
473
474 if (clientFd > 0 && keepaliveTime < time(NULL)) {
475 keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
476 if (waitingOnKeepalive) {
477 LOG(INFO) << "Missed a keepalive, killing connection.";
478 globalClient->closeSocketAndMaybeReconnect();
479 waitingOnKeepalive = false;
480 } else {
481 LOG(INFO) << "Writing keepalive packet";
482 string s(1, (char)et::PacketType::KEEP_ALIVE);
483 globalClient->writeMessage(s);
484 waitingOnKeepalive = true;
485 }
486 }
487 if (clientFd < 0) {
488 // We are disconnected, so stop waiting for keepalive.
489 waitingOnKeepalive = false;
490 }
491
492 handleWindowChanged(&win);
493
494 vector<PortForwardDestinationRequest> requests;
495 vector<PortForwardData> dataToSend;
496 portForwardHandler.update(&requests, &dataToSend);
497 for (auto& pfr : requests) {
498 char c = et::PacketType::PORT_FORWARD_DESTINATION_REQUEST;
499 string headerString(1, c);
500 globalClient->writeMessage(headerString);
501 globalClient->writeProto(pfr);
502 VLOG(4) << "send PF request";
503 keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
504 }
505 for (auto& pwd : dataToSend) {
506 char c = PacketType::PORT_FORWARD_DATA;
507 string headerString(1, c);
508 globalClient->writeMessage(headerString);
509 globalClient->writeProto(pwd);
510 VLOG(4) << "send PF data";
511 keepaliveTime = time(NULL) + KEEP_ALIVE_DURATION;
512 }
513 } catch (const runtime_error& re) {
514 LOG(ERROR) << "Error: " << re.what();
515 tcsetattr(0, TCSANOW, &terminal_backup);
516 cout << "Connection closing because of error: " << re.what() << endl;
517 run = false;
518 }
519 }
520 globalClient.reset();
521 LOG(INFO) << "Client derefernced";
522 tcsetattr(0, TCSANOW, &terminal_backup);
523 cout << "Session terminated" << endl;
524 // Uninstall log rotation callback
525 el::Helpers::uninstallPreRollOutCallback();
526 return 0;
527 }
528