#include "SSH.h"
#include "Malloc.cpp"

#include <cstdlib>
#include <cstring>
3 
4 namespace Upp {
5 
// Trace helpers for this file.
// LLOG(x): when SSH::sTrace is set, logs x prefixed with the object's name
//          (SSH::GetName(ssh->otype, ssh->oid)); it therefore needs an `ssh`
//          pointer in scope at the point of use.
// LDUMPHEX(x): when SSH::sTraceVerbose is set, hex-dumps x via RDUMPHEX.
#define LLOG(x)       do { if(SSH::sTrace) RLOG(SSH::GetName(ssh->otype, ssh->oid) << x); } while(false)
#define LDUMPHEX(x)   do { if(SSH::sTraceVerbose) RDUMPHEX(x); } while(false)
8 
9 // ssh_keyboard_callback: Authenticates a session, using keyboard-interactive authentication.
10 
ssh_keyboard_callback(const char * name,int name_len,const char * instruction,int instruction_len,int num_prompts,const LIBSSH2_USERAUTH_KBDINT_PROMPT * prompts,LIBSSH2_USERAUTH_KBDINT_RESPONSE * responses,void ** abstract)11 static void ssh_keyboard_callback(const char *name, int name_len, const char *instruction,
12 	int instruction_len, int num_prompts, const LIBSSH2_USERAUTH_KBDINT_PROMPT *prompts,
13 	LIBSSH2_USERAUTH_KBDINT_RESPONSE *responses, void **abstract)
14 {
15 	SshSession *session = static_cast<SshSession*>(*abstract);
16 	for(auto i = 0; i < num_prompts; i++) {
17 		auto response = session->WhenKeyboard(
18 			String(name, name_len),
19 			String(instruction, instruction_len),
20 			String(prompts[i].text, prompts[i].length)
21 		);
22 #ifdef UPP_HEAP
23 		auto *r = (char*) ssh_malloc(response.GetLength(), abstract);
24 		memcpy(r, response.Begin(), response.GetLength());
25 #else
26 		auto *r = strdup(~response);
27 #endif
28 		if(r) {
29 			responses[i].text   = r;
30 			responses[i].length = response.GetLength();
31 		}
32 	}
33 }
34 
35 // ssh_password_change: Requests that the client's password be changed.
36 
ssh_password_change(LIBSSH2_SESSION * session,char ** pwd,int * len,void ** abstract)37 static void ssh_password_change(LIBSSH2_SESSION *session, char **pwd, int *len, void **abstract)
38 {
39 	String newpwd = static_cast<SshSession*>(*abstract)->WhenPasswordChange();
40 #ifdef UPP_HEAP
41 		*pwd = (char*) ssh_malloc(newpwd.GetLength(), abstract);
42 		memcpy(*pwd, ~newpwd, newpwd.GetLength());
43 #else
44 		*pwd = strdup(~newpwd);
45 #endif
46 }
47 
48 // ssh_x11_request: Dispatches incoming X11 requests.
49 
ssh_x11_request(LIBSSH2_SESSION * session,LIBSSH2_CHANNEL * channel,char * shost,int sport,void ** abstract)50 static void ssh_x11_request(LIBSSH2_SESSION *session, LIBSSH2_CHANNEL *channel, char *shost, int sport, void **abstract)
51 {
52 	static_cast<SshSession*>(*abstract)->WhenX11((SshX11Handle) channel);
53 }
54 
// ssh_session_libtrace: Trace handler registered with libssh2 when built with the
// LIBSSH2TRACE flag. While verbose tracing (SSH::sTraceVerbose) is on, each libssh2
// diagnostic message is written to the log, prefixed with the owning session's name.

#ifdef flagLIBSSH2TRACE
static void ssh_session_libtrace(LIBSSH2_SESSION *session, void *context, const char *data, size_t length)
{
	if(session && SSH::sTraceVerbose) {
		auto *owner = static_cast<SshSession*>(context);
		RLOG(SSH::GetName(owner->GetType(), owner->GetId()) << String(data, int64(length)));
	}
}
#endif
66 
Exit()67 void SshSession::Exit()
68 {
69 	if(!session)
70 		return;
71 
72 	Run([=]() mutable {
73 		if(!ssh->session)
74 			return true;
75 		int rc = libssh2_session_disconnect(ssh->session, "Disconnecting...");
76 		if(WouldBlock(rc))
77 			return false;
78 		LLOG("Successfully disconnected from the server.");
79 		return true;
80 	});
81 
82 	Run([=]() mutable {
83 		if(!ssh->session)
84 			return true;
85 		int rc = libssh2_session_free(ssh->session);
86 		if(WouldBlock(rc))
87 			return false;
88 		ssh->init    = false;
89 		ssh->socket  = nullptr;
90 		ssh->session = nullptr;
91 		session->connected = false;
92 		LLOG("Session handles freed.");
93 		return true;
94 	});
95 }
96 
Connect(const String & url)97 bool SshSession::Connect(const String& url)
98 {
99 	UrlInfo u(url);
100 
101 	auto b = u.scheme == "ssh"   ||
102              u.scheme == "scp"   ||
103              u.scheme == "sftp"  ||
104              u.scheme == "exec"  ||
105              (u.scheme.IsEmpty()  && !u.host.IsEmpty());
106 	int port = (u.port.IsEmpty() || !b) ? 22 : StrInt(u.port);
107 
108 	return b ? Connect(u.host, port, u.username, u.password)
109 	         : Run([=]{ SetError(-1, "Malformed secure shell URL."); return false; });
110 }
111 
// Connect: Opens an SSH session to host:port and authenticates `user`.
//
// The work is a sequence of steps, each wrapped in Run(). Within a step, returning
// false is used for "not finished yet" (DNS still running, socket not connected,
// libssh2 would block) and returning true for "step done"; Run() returning false
// is treated as failure. NOTE(review): Run()'s exact retry/timeout semantics are
// defined elsewhere — confirm.
// SetError() is called without a following return, so it is presumably expected
// to abort the current step (e.g. by throwing) — confirm in the Ssh base class.
//
// Steps:
//   1. Check the host, reset the session pointer and socket timeout, and start a
//      local DNS lookup unless a WhenProxy handler is installed.
//   2a. No proxy: wait for DNS, start the TCP connect, wait for it to complete.
//   2b. Proxy: let WhenProxy() establish the connection.
//   3. Create the libssh2 session (non-blocking), run WhenConfig, and apply the
//      compression flag.
//   4. Apply every queued transport method preference from session->iomethods.
//   5. Perform the SSH handshake over the connected socket.
//   6. Store the host key fingerprint (deprecated hashtype path) and let
//      WhenVerify reject the host.
//   7. Fetch the server's auth method list; if the server already reports the
//      session as authenticated, skip to Finalize.
//   8. Authenticate using session->authmethod.
// On success the X11 callback is installed (POSIX only) and true is returned;
// any failed step jumps to Bailout, which calls Exit() and returns false.
//
// ipinfo lives on this stack frame and is captured by reference, which assumes
// Run() invokes its lambda synchronously before returning.
bool SshSession::Connect(const String& host, int port, const String& user, const String& password)
{
	IpAddrInfo ipinfo;

	// Step 1: validate input, reset state, kick off DNS (or announce proxy use).
	if(!Run([=, &ipinfo] () mutable {
		if(host.IsEmpty())
			SetError(-1, "Host is not specified.");
		ssh->session = nullptr;
		session->socket.Timeout(0);
		if(!WhenProxy) {
			ipinfo.Start(host, port);
			LLOG(Format("Starting DNS sequence locally for %s:%d", host, port));
		}
		else
			LLOG("Proxy plugin found. Attempting to connect via proxy...");
		WhenPhase(WhenProxy ? PHASE_CONNECTION : PHASE_DNS);
		return true;
	})) goto Bailout;

	if(!WhenProxy) {
		// Step 2a: poll until the DNS lookup finishes.
		if(!Run([=, &ipinfo] () mutable {
			if(ipinfo.InProgress())
				return false;
			if(!ipinfo.GetResult())
				SetError(-1, "DNS lookup failed.");
			WhenPhase(PHASE_CONNECTION);
			return true;
		})) goto Bailout;

		// Start the TCP connection; the resolved address info is released afterwards.
		if(!Run([=, &ipinfo] () mutable {
			if(!session->socket.Connect(ipinfo))
				return false;
			ipinfo.Clear();
			return true;
		})) goto Bailout;

		// Poll until the TCP connection is established.
		if(!Run([=, &ipinfo] () mutable {
			if(!session->socket.WaitConnect())
				return false;
			LLOG("Successfully connected to " << host <<":" << port);
			return true;
		})) goto Bailout;
	}
	else {
		// Step 2b: the user-supplied proxy handler performs the connection.
		if(!Run([=] () mutable {
			if(!WhenProxy())
				SetError(-1, "Proxy connection attempt failed.");
			LLOG("Proxy connection to " << host << ":" << port << " is successful.");
			return true;
		})) goto Bailout;
	}

	// Step 3: create and configure the libssh2 session object.
	if(!Run([=] () mutable {
#ifdef UPP_HEAP
			LLOG("Using Upp's memory managers.");
			ssh->session = libssh2_session_init_ex((*ssh_malloc), (*ssh_free), (*ssh_realloc), this);
#else
			LLOG("Using libssh2's memory managers.");
			ssh->session = libssh2_session_init_ex(nullptr, nullptr, nullptr, this);
#endif
			if(!ssh->session)
				SetError(-1, "Failed to initalize libssh2 session.");
#ifdef flagLIBSSH2TRACE
			if(libssh2_trace_sethandler(ssh->session, this, &ssh_session_libtrace))
				LLOG("Warning: Unable to set trace (debug) handler for libssh2.");
			else {
				libssh2_trace(ssh->session, SSH::sTraceVerbose);
				LLOG("Verbose debugging mode enabled.");
			}
#endif
			// Non-blocking mode: later libssh2 calls may report EAGAIN, which the
			// steps below detect with WouldBlock() and retry.
			libssh2_session_set_blocking(ssh->session, 0);
			ssh->socket = &session->socket;
			LLOG("Session successfully initialized.");
			WhenConfig();
			libssh2_session_flag(ssh->session, LIBSSH2_FLAG_COMPRESS, (int) session->compression);
			LLOG("Compression is " << (session->compression ? "enabled." : "disabled."));
			WhenPhase(PHASE_HANDSHAKE);
			return true;
	})) goto Bailout;

	// Step 4: apply queued method preferences; each successful pass consumes
	// the first entry of session->iomethods until none are left.
	while(!session->iomethods.IsEmpty()) {
		if(!Run([=] () mutable {
			int    method = session->iomethods.GetKey(0);
			String mnames = GetMethodNames(method);
			int rc = libssh2_session_method_pref(ssh->session, method, ~mnames);
			if(!WouldBlock(rc) && rc < 0) SetError(rc);
			if(!rc && !session->iomethods.IsEmpty()) {
				LLOG("Transport method: #" << method << " is set to [" << mnames << "]");
				session->iomethods.Remove(0);
			}
			return !rc;
		})) goto Bailout;
	}

	// Step 5: SSH protocol handshake over the connected socket.
	if(!Run([=] () mutable {
			int rc = libssh2_session_handshake(ssh->session, session->socket.GetSOCKET());
			if(!WouldBlock(rc) && rc < 0) SetError(rc);
			if(!rc) {
				LLOG("Handshake successful.");
				WhenPhase(PHASE_AUTHORIZATION);
			}
			return !rc;
	})) goto Bailout;

	// Step 6: record the host key fingerprint and let the user verify the host.
	if(!Run([=] () mutable {
			switch(session->hashtype) {  // TODO: Remove this block along with the deprecated Hashtype()
			case HASH_MD5:               //       and  GetFingerprint() methods, in the future versions.
				session->fingerprint = GetMD5Fingerprint();
				LLOG("MD5 fingerprint of " << host << ": " << HexString(session->fingerprint, 1, ':'));
				break;
			case HASH_SHA1:
				session->fingerprint = GetSHA1Fingerprint();
				LLOG("SHA1 fingerprint of " << host << ": " << HexString(session->fingerprint, 1, ':'));
				break;
			case HASH_SHA256:
				session->fingerprint = GetSHA256Fingerprint();
				LLOG("SHA256 fingerprint of " << host << ": " << Base64Encode(session->fingerprint));
				break;
			default:
				break;
			}
			if(WhenVerify && !WhenVerify(host, port))
				SetError(-1);
			return true;
	})) goto Bailout;

	// Step 7: query the server's authentication methods. An empty result means
	// either "already authenticated" (no auth required), "would block" (retry),
	// or an error.
	if(!Run([=] () mutable {
			session->authmethods = libssh2_userauth_list(ssh->session, user, user.GetLength());
			if(IsNull(session->authmethods)) {
				if(libssh2_userauth_authenticated(ssh->session)) {
					LLOG("Server @" << host << " does not require authentication!");
					WhenPhase(PHASE_SUCCESS);
					session->connected = true;
					return session->connected;
				}
				else
				if(!WouldBlock())
					SetError(-1);
				return false;
			}
			LLOG("Authentication methods list successfully retrieved: [" << session->authmethods << "]");
			WhenAuth();
			return true;
	})) goto Bailout;

	if(session->connected)
		goto Finalize;

	// Step 8: authenticate with the configured method. The step completes only
	// once libssh2 reports the session as authenticated.
	if(!Run([=] () mutable {
			int rc = -1;
			switch(session->authmethod) {
				case PASSWORD:
					// The password-change callback is installed only if the user set one.
					rc = libssh2_userauth_password_ex(
							ssh->session,
							~user,
							 user.GetLength(),
							~password,
							 password.GetLength(),
							 WhenPasswordChange
								? &ssh_password_change
									: nullptr);
					break;
				case PUBLICKEY:
					// session->keyfile selects between key file paths and in-memory keys.
					rc = session->keyfile
					?	libssh2_userauth_publickey_fromfile(
							ssh->session,
							~user,
							~session->pubkey,
							~session->prikey,
							~session->phrase)
					:	libssh2_userauth_publickey_frommemory(
							ssh->session,
							~user,
							 user.GetLength(),
							~session->pubkey,
							 session->pubkey.GetLength(),
							~session->prikey,
							 session->prikey.GetLength(),
							~session->phrase);
					break;
				case HOSTBASED:
					// Only file-based keys are supported for host-based authentication.
					if(!session->keyfile)
						SetError(-1, "Keys cannot be loaded from memory.");
					else
					rc = libssh2_userauth_hostbased_fromfile(
							ssh->session,
							~user,
							~session->pubkey,
							~session->prikey,
							~session->phrase,
							~host);
					break;
				case KEYBOARD:
					rc = libssh2_userauth_keyboard_interactive(
						ssh->session,
						~user,
						&ssh_keyboard_callback);
					break;
				case SSHAGENT:
					rc = TryAgent(user);
					break;
				default:
					NEVER();

			}
			if(rc != 0 && !WouldBlock(rc))
				SetError(rc);
			if(rc == 0 && libssh2_userauth_authenticated(ssh->session)) {
				LLOG("Client succesfully authenticated.");
				WhenPhase(PHASE_SUCCESS);
				session->connected = true;
			}
			return	session->connected;
	})) goto Bailout;

Finalize:
	// Route incoming X11 forwarding requests to ssh_x11_request (POSIX builds only).
#ifdef PLATFORM_POSIX
	libssh2_session_callback_set(ssh->session, LIBSSH2_CALLBACK_X11, (void*) ssh_x11_request);
	LLOG("X11 dispatcher is set.");
#endif
	return true;

Bailout:
	// Any failed step ends up here: release whatever was set up so far.
	LLOG("Connection attempt failed. Bailing out...");
	Exit();
	return false;
}
339 
Disconnect()340 void SshSession::Disconnect()
341 {
342 	Exit();
343 }
344 
CreateSFtp()345 SFtp SshSession::CreateSFtp()
346 {
347 	ASSERT(ssh && ssh->session);
348 	return pick(SFtp(*this));
349 }
350 
CreateChannel()351 SshChannel SshSession::CreateChannel()
352 {
353 	ASSERT(ssh && ssh->session);
354 	return pick(SshChannel(*this));
355 }
356 
CreateExec()357 SshExec SshSession::CreateExec()
358 {
359 	ASSERT(ssh && ssh->session);
360 	return pick(SshExec(*this));
361 }
362 
CreateScp()363 Scp SshSession::CreateScp()
364 {
365 	ASSERT(ssh && ssh->session);
366 	return pick(Scp(*this));
367 }
368 
CreateTunnel()369 SshTunnel SshSession::CreateTunnel()
370 {
371 	ASSERT(ssh && ssh->session);
372 	return pick(SshTunnel(*this));
373 }
374 
CreateShell()375 SshShell SshSession::CreateShell()
376 {
377 	ASSERT(ssh && ssh->session);
378 	return pick(SshShell(*this));
379 }
380 
GetHostKeyHash(int type,int length) const381 String SshSession::GetHostKeyHash(int type, int length) const
382 {
383 	String hash;
384 	if(ssh->session) {
385 		hash = libssh2_hostkey_hash(ssh->session, type);
386 		if(hash.GetLength() > length)
387 			hash.TrimLast(hash.GetLength() - length);
388 	}
389 	return hash;
390 }
391 
GetMethods() const392 ValueMap SshSession::GetMethods() const
393 {
394 	ValueMap methods;
395 	if(ssh->session) {
396 		for(int i = METHOD_EXCHANGE; i <= METHOD_SLANGUAGE; i++) {
397 			const char **p = nullptr;
398 			auto rc = libssh2_session_supported_algs(ssh->session, i, &p);
399 			if(rc > 0) {
400 				auto& v = methods(i);
401 				for(int j = 0; j < rc; j++) {
402 					v << p[j];
403 				}
404 				libssh2_free(ssh->session, p);
405 			}
406 		}
407 	}
408 	return pick(methods);
409 }
410 
GetMethodNames(int type) const411 String SshSession::GetMethodNames(int type) const
412 {
413 	String names;
414 	const Value& v = session->iomethods[type];
415 	if(IsValueArray(v)) {
416 		for(int i = 0; i < v.GetCount(); i++)
417 			names << v[i].To<String>() << (i < v.GetCount() - 1 ? "," : "");
418 	}
419 	else names << v;
420 	return pick(names);
421 }
422 
// TryAgent: Authenticates `username` through a running ssh-agent by trying each
// identity the agent offers, in order, until one is accepted.
// Returns 0 once an identity succeeds. Every failure path (agent init, connect,
// identity listing or retrieval, all identities rejected) releases the agent and
// calls SetError(). The code does not return after SetError(), so SetError() is
// presumably expected not to return (e.g. it throws) — confirm in the Ssh base class.
//
// NOTE(review): the session is put into non-blocking mode in Connect(). In that
// mode libssh2_agent_userauth() may report LIBSSH2_ERROR_EAGAIN. Here any
// non-zero result is treated as "this key failed" and the next identity is
// tried, so a slow server could cause usable keys to be skipped. Verify.
int SshSession::TryAgent(const String& username)
{
	LLOG("Attempting to authenticate via ssh-agent...");
	auto agent = libssh2_agent_init(ssh->session);
	if(!agent)
		SetError(-1, "Couldn't initialize ssh-agent support.");
	if(libssh2_agent_connect(agent)) {
		// Not connected, so only the handle needs freeing (no disconnect).
		libssh2_agent_free(agent);
		SetError(-1, "Couldn't connect to ssh-agent.");
	}
	if(libssh2_agent_list_identities(agent)) {
		FreeAgent(agent);
		SetError(-1, "Couldn't request identities to ssh-agent.");
	}
	libssh2_agent_publickey *id = nullptr, *previd = nullptr;

	// Walk the identity list: rc < 0 is an error, rc == 1 means the list is
	// exhausted, otherwise `id` holds the next identity to try.
	for(;;) {
		auto rc = libssh2_agent_get_identity(agent, &id, previd);
		if(rc < 0) {
			FreeAgent(agent);
			SetError(-1, "Unable to obtain identity from ssh-agent.");
		}
		if(rc != 1) {
			if(libssh2_agent_userauth(agent, ~username, id)) {
				LLOG(Format("Authentication with username %s and public key %s failed.",
							username, id->comment));
			}
			else {
				LLOG(Format("Authentication with username %s and public key %s succeesful.",
							username, id->comment));
				break;
			}
		}
		else {
			// No identity left and none was accepted.
			FreeAgent(agent);
			SetError(-1, "Couldn't authenticate via ssh-agent");
		}
		previd = id;
	}
	FreeAgent(agent);
	return 0;
}
465 
FreeAgent(SshAgent * agent)466 void SshSession::FreeAgent(SshAgent* agent)
467 {
468 	libssh2_agent_disconnect(agent);
469 	libssh2_agent_free(agent);
470 }
471 
472 
Keys(const String & prikey,const String & pubkey,const String & phrase,bool fromfile)473 SshSession& SshSession::Keys(const String& prikey, const String& pubkey, const String& phrase, bool fromfile)
474 {
475     session->prikey  = prikey;
476     session->pubkey  = pubkey;
477     session->phrase  = phrase;
478     session->keyfile = fromfile;
479     return *this;
480 }
481 
SshSession()482 SshSession::SshSession()
483 : Ssh()
484 {
485     session.Create();
486     ssh->otype           = SESSION;
487     ssh->whenwait        = Proxy(WhenWait);
488     session->authmethod  = PASSWORD;
489     session->connected   = false;
490     session->keyfile     = true;
491     session->compression = false;
492     session->hashtype    = HASH_SHA256;
493  }
494 
~SshSession()495 SshSession::~SshSession()
496 {
497 	Exit();
498 }
499 }