1%%
2%% %CopyrightBegin%
3%%
4%% Copyright Ericsson AB 2004-2017. All Rights Reserved.
5%%
6%% The contents of this file are subject to the Erlang Public License,
7%% Version 1.1, (the "License"); you may not use this file except in
8%% compliance with the License. You should have received a copy of the
9%% Erlang Public License along with this software. If not, it can be
10%% retrieved online at http://www.erlang.org/.
11%%
12%% Software distributed under the License is distributed on an "AS IS"
13%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
14%% the License for the specific language governing rights and limitations
15%% under the License.
16%%
17%% %CopyrightEnd%
18%%
19%%
20
21-module(ssh_trpt_test_lib).
22
23%%-compile(export_all).
24
25-export([exec/1, exec/2,
26	 instantiate/2,
27	 format_msg/1,
28	 server_host_port/1
29	]
30       ).
31
32-include_lib("common_test/include/ct.hrl").
33-include_lib("ssh/src/ssh.hrl").		% ?UINT32, ?BYTE, #ssh{} ...
34-include_lib("ssh/src/ssh_transport.hrl").
35-include_lib("ssh/src/ssh_auth.hrl").
36
37%%%----------------------------------------------------------------
%% Test-script state threaded through exec/1,2 and every op/2 clause.
-record(s, {
	  socket,				% gen_tcp socket; set by {connect,...} or {accept,...}
	  listen_socket,			% gen_tcp listen socket; set by {listen,Port}
	  opts = [],				% set by {set_options,Opts}: print_messages, print_ops, print_seqnums, silent
	  timeout = 5000,			% ms; used by gen_tcp:accept and receive_wait
	  seen_hello = false,			% false | {more,PartialLine} | true
	  ssh = #ssh{},				% #ssh{}
	  alg_neg = {undefined,undefined},      % {own_kexinit, peer_kexinit}
	  alg,                                  % #alg{} (negotiated algorithms)
	  vars = dict:new(),			% '$Var' bindings made by match/3
	  reply = [],				% Some reply msgs are generated hidden in ssh_transport :[
	  prints = [],				% pending {Fmt,Args} trace entries, newest first
	  return_value,				% result of the last op; referred to as '$$'

          %% Packet retrieval and decryption state (see receive_binary_msg/1)
          decrypted_data_buffer     = <<>>,
          encrypted_data_buffer     = <<>>,
          aead_data                 = <<>>,
          undecrypted_packet_length
         }).

%% The ssh role of state S (client | server; the listen/connect ops require
%% it to still be undefined).  Expands to a record access, usable in guards.
-define(role(S), ((S#s.ssh)#ssh.role) ).
60
61
%%% server_host_port(S) -> {Host, Port}
%%% Returns the local address of the listen socket in S, with the address
%%% passed through host/1.  Crashes if inet:sockname/1 returns an error.
server_host_port(S = #s{listen_socket = LSock}) ->
    {Addr, PortNo} = ok(inet:sockname(LSock)),
    {host(Addr), PortNo}.
65
66
%%% exec(Ops)
%%% Runs Ops (a single operation or a list of operations) starting from a
%%% fresh #s{} state.
%%%
%%% Options read from the state (set with the {set_options,Opts} op):
%%%   {print_messages, false}   true | detail
%%%   {print_seqnums,  false}   true
%%%   {print_ops,      false}   true
exec(Ops) ->
    Fresh = #s{},
    exec(Ops, Fresh).
72
%%% exec(Ops, S) -> {ok,S'} | {error,{Op,Reason,State}}
%%%
%%% Ops is a single operation or a list of operations.  A list is folded left
%%% to right; once a step returns {error,...} the remaining steps pass it on
%%% unchanged.  S is a #s{} state or the result of a previous exec call.
%%%
%%% A {fail,Reason,State} thrown by an operation is traced and returned as
%%% {error,{Op,Reason,State}}.  Any other exception is traced and re-raised in
%%% the same class with the failing Op attached.  On success the collected
%%% traces are logged unless the option {silent,true} is set.
exec(L, S) when is_list(L) -> lists:foldl(fun exec/2, S, L);

exec(Op, S0=#s{}) ->
    S1 = init_op_traces(Op, S0),
    try seqnum_trace(
          op(Op, S1))
    of
        S = #s{} ->
            case proplists:get_value(silent,S#s.opts) of
                true -> ok;
                _ -> print_traces(S)
            end,
            {ok,S}
    catch
        {fail,Reason,Se} ->
            report_trace('', Reason, Se),
            {error,{Op,Reason,Se}};

        throw:Term ->
            report_trace(throw, Term, S1),
            throw({Term,Op});

        error:Error ->
            report_trace(error, Error, S1),
            error({Error,Op});

        exit:Exit ->
            %% throw, error and exit are the only exception classes, so
            %% these clauses are exhaustive; no catch-all is needed.
            report_trace(exit, Exit, S1),
            exit({Exit,Op})
    end;
exec(Op, {ok,S=#s{}}) -> exec(Op, S);
exec(_, Error) -> Error.
108
109
%%% op(Op, S) -> S'
%%% Executes a single test-script operation.  Most operations leave their
%%% result in S'#s.return_value, which later operations can refer to as '$$'.
%%% A check that does not hold throws {fail,Reason,State} (see fail/2).

%%%---- Server ops
op(listen, S) when ?role(S) == undefined -> op({listen,0}, S);

op({listen,Port}, S) when ?role(S) == undefined ->
    S#s{listen_socket = ok(gen_tcp:listen(Port, mangle_opts([]))),
        ssh = (S#s.ssh)#ssh{role=server}
       };

op({accept,Opts}, S) when ?role(S) == server ->
    {ok,Socket} = gen_tcp:accept(S#s.listen_socket, S#s.timeout),
    {Host,_Port} = ok(inet:sockname(Socket)),
    S#s{socket = Socket,
        ssh = init_ssh(server, Socket, host(Host), Opts),
        return_value = ok};

%%%---- Client ops
op({connect,Host,Port,Opts}, S) when ?role(S) == undefined ->
    Socket = ok(gen_tcp:connect(host(Host), Port, mangle_opts([]))),
    S#s{socket = Socket,
        ssh = init_ssh(client, Socket, host(Host), Opts),
        return_value = ok};

%%%---- ops for both client and server
op(close_socket, S) ->
    %% Best effort: either socket may already be closed or undefined.
    catch gen_tcp:close(S#s.socket),
    catch gen_tcp:close(S#s.listen_socket),
    S#s{socket = undefined,
        listen_socket = undefined,
        return_value = ok};

op({set_options,Opts}, S) ->
    S#s{opts = Opts};

op({send,X}, S) ->
    send(S, instantiate(X,S));

op(receive_hello, S0) when S0#s.seen_hello =/= true ->
    %% Keep reading lines until one is accepted as the peer's hello.
    case recv(S0) of
        S1=#s{return_value={hello,_}} -> S1;
        S1=#s{} -> op(receive_hello, receive_wait(S1))
    end;

op(receive_msg, S) when S#s.seen_hello == true ->
    try recv(S)
    catch
        {tcp,Exc} ->
            %% A closed/broken socket is a legitimate result to match on.
            S1 = opt(print_messages, S,
                     fun(X) when X==true;X==detail -> {"Recv~n~p~n",[Exc]} end),
            S1#s{return_value=Exc}
    end;


op({expect,timeout,E}, S0) ->
    %% Succeeds only if E times out.  receive_wait/1,2 report a timeout
    %% through fail(receive_timeout, S), i.e. by throwing
    %% {fail,receive_timeout,S}; a bare {receive_timeout,_} is also accepted.
    try op(E, S0)
    of
        S=#s{} -> fail({expected,timeout,S#s.return_value}, S)
    catch
        {fail,receive_timeout,_} -> S0#s{return_value=timeout};
        {receive_timeout,_} -> S0#s{return_value=timeout}
    end;

op({match,M,E}, S0) ->
    {Val,S2} = op_val(E, S0),
    case match(M, Val, S2) of
        {true,S3} ->
            opt(print_ops,S3,
                fun(true) ->
                        case dict:fold(
                               fun(K,V,Acc) ->
                                       case dict:find(K,S0#s.vars) of
                                           error -> [{K,V}|Acc];
                                           _ -> Acc
                                       end
                               end, [], S3#s.vars)
                        of
                            [] -> {"Matches! No new bindings.",[]};
                            New ->
                                Width = lists:max([length(atom_to_list(K)) || {K,_} <- New]),
                                %% The rendered bindings are passed as an
                                %% argument, never as part of the format
                                %% string: a bound value containing '~'
                                %% would otherwise break io_lib:format/2
                                %% when the trace is printed.
                                {"Matches! New bindings:~n~s",
                                 [lists:flatten(
                                    [io_lib:format(" ~*s = ~p~n",[Width,K,V]) || {K,V}<-New])]}
                        end
                end);
        false ->
            fail({expected,M,Val},
                 opt(print_ops,S2,fun(true) -> {"nomatch!!~n",[]} end)
                )
    end;

op({print,E}, S0) ->
    {Val,S} = op_val(E, S0),
    io:format("Result of ~p ~p =~n~s~n",[?role(S0),E,format_msg(Val)]),
    S;

op(print_state, S) ->
    io:format("State(~p)=~n~s~n",[?role(S), format_msg(S)]),
    S;

op('$$', S) ->
    %% For matching etc
    S.
211
212
%%% op_val(E, S0) -> {Value, S}
%%% If E is an operation (op/2 has a clause for exactly these arguments), it
%%% is executed and its return_value is returned with the new state.
%%% Otherwise E is treated as a literal term and instantiated ('$Var' atoms
%%% replaced by their bindings) against the unchanged state.
%%% Any {fail,Reason,State} thrown by the operation is re-thrown unchanged,
%%% so exec/2 reports it as a test failure rather than as a case_clause crash.
op_val(E, S0) ->
    case catch op(E, S0) of
        {'EXIT',{function_clause,[{ssh_trpt_test_lib,op,[E,S0],_}|_]}} ->
            {instantiate(E,S0), S0};
        S=#s{} ->
            {S#s.return_value, S};
        F={fail,_Reason,_St} ->
            throw(F)
    end.
222
223
%%% fail(Reason, {Fmt,Args}, S)
%%% Records the trace entry {Fmt,Args} in S, then fails with Reason.
fail(Reason, {Fmt,Args} = Trace, S) when is_list(Fmt), is_list(Args) ->
    Traced = save_prints(Trace, S),
    fail(Reason, Traced).

%%% fail(Reason, S)
%%% Aborts the current operation by throwing {fail, Reason, S}; exec/2 turns
%%% this into an {error,...} result.
fail(Reason, S) ->
    erlang:throw({fail, Reason, S}).
229
230%%%----------------------------------------------------------------
231%% No optimizations :)
232
%%% match(Pattern, Value, S) -> {true, S'} | false
%%% Structural matching used by the {match,...} op.
%%%   '$$'            - matched as the current return_value
%%%   '_'             - matches anything
%%%   {'or',[P,...]}  - first alternative that matches wins
%%%   '$Name' atoms   - variables: bound on first use, compared afterwards;
%%%                     new bindings are returned in S'#s.vars
%%% Tuples and lists are matched element by element.  No optimizations :)
match('$$', V, S) ->
    match(S#s.return_value, V, S);

match('_', _, S) ->
    {true, S};

match({'or',[P]}, V, S) -> match(P,V,S);
match({'or',[Ph|Pt]}, V, S) ->
    case match(Ph,V,S) of
        false -> match({'or',Pt}, V, S);
        %% The alternative may have bound new variables, so return its
        %% (possibly updated) state instead of requiring it to equal S.
        {true,S1} -> {true,S1}
    end;

match(P, V, S) when is_atom(P) ->
    case atom_to_list(P) of
        "$"++_ ->
            %% Variable
            case dict:find(P,S#s.vars) of
                {ok,Val} -> match(Val, V, S);
                error -> {true,S#s{vars = dict:store(P,V,S#s.vars)}}
            end;
        _ when P==V ->
            {true,S};
        _ ->
            false
    end;

match(P, V, S) when P==V ->
    {true, S};

match(P, V, S) when is_tuple(P),
                    is_tuple(V) ->
    match(tuple_to_list(P), tuple_to_list(V), S);

match([Hp|Tp], [Hv|Tv], S0) ->
    case match(Hp, Hv, S0) of
        {true,S} -> match(Tp, Tv, S);
        false -> false
    end;

match(_, _, _) ->
    false.
275
276
277
%%% instantiate(Term, S) -> Term'
%%% Replaces variables in Term by their values in S:
%%%   '$$'          -> S#s.return_value
%%%   '$Name' atom  -> its binding in S#s.vars (throws {unbound,'$Name'} if none)
%%% Tuples and lists are walked recursively; any other term is returned as is.
instantiate(Term, S) ->
    case Term of
        '$$' ->
            %% FIXME: What if $$ or $... in return_value?
            S#s.return_value;
        _ when is_atom(Term) ->
            case atom_to_list(Term) of
                [$$ | _] ->
                    case dict:find(Term, S#s.vars) of
                        {ok, Bound} -> Bound;   % FIXME: What if $$ or $... in Bound?
                        error -> throw({unbound, Term})
                    end;
                _ ->
                    Term
            end;
        _ when is_tuple(Term) ->
            list_to_tuple(instantiate(tuple_to_list(Term), S));
        [Head | Tail] ->
            [instantiate(Head, S) | instantiate(Tail, S)];
        _ ->
            Term
    end.
301
302%%%================================================================
303%%%
%%% init_ssh(Role, Socket, Host, UserOptions0) -> #ssh{}
%%% Builds the #ssh{} record for this end of the connection.  Test-library
%%% defaults (no user interaction, protocol 2.0, id string "ErlangTestLib")
%%% are placed before the caller's options, and {host,Host} is stored as an
%%% internal option.
init_ssh(Role, Socket, Host, UserOptions0) ->
    Defaults = [{user_interaction, false},
                {vsn, {2,0}},
                {id_string, "ErlangTestLib"}],
    Handled = ssh_options:handle_options(Role, Defaults ++ UserOptions0),
    Opts = ?PUT_INTERNAL_OPT({host,Host}, Handled),
    ssh_connection_handler:init_ssh_record(Role, Socket, Opts).
312
%%% mangle_opts(Options) -> SocketOptions
%%% Returns the gen_tcp options used for every socket in this library: the
%%% mandatory SysOpts first, followed by Options with EVERY tuple option whose
%%% key is also set in SysOpts removed.  Removing all occurrences (not just
%%% the first, as lists:keydelete/3 would) guarantees a caller cannot
%%% override e.g. {active,false} or {mode,binary} with a repeated option.
mangle_opts(Options) ->
    SysOpts = [{reuseaddr, true},
               {active, false},
               {mode, binary}
              ],
    SysKeys = [K || {K,_} <- SysOpts],
    SysOpts ++ lists:filter(
                 fun(Opt) when is_tuple(Opt), tuple_size(Opt) >= 1 ->
                         not lists:member(element(1, Opt), SysKeys);
                    (_) ->
                         %% Non-tuple options (e.g. inet6) are kept as is.
                         true
                 end, Options).
321
%%% host(H) -> Host
%%% Passes H through ssh_test_lib:mangle_connect_address/1 and converts the
%%% result with ssh_test_lib:ntoa/1.
host(H) ->
    Mangled = ssh_test_lib:mangle_connect_address(H),
    ssh_test_lib:ntoa(Mangled).
323
324%%%----------------------------------------------------------------
%%% send(S, What) -> S'
%%% Encodes What, writes it to the socket with send_bytes/2 and returns the
%%% updated state.  What may be:
%%%   hello               - our version line (c_version or s_version) + "\r\n"
%%%   ssh_msg_kexinit     - a fresh kexinit from ssh_transport, sent as below
%%%   #ssh_msg_kexinit{}  - only while no own kexinit is recorded in alg_neg;
%%%                         recorded as own msg (later kexinits fall through
%%%                         to the generic tuple clause)
%%%   ssh_msg_kexdh_init  - (client) computed from the kexinits in alg_neg
%%%   ssh_msg_kexdh_reply - the reply bytes stored in S#s.reply by recv/1
%%%   Binary              - raw bytes, sent unchanged
%%%   {special,Msg,Fun}   - Msg encoded by Fun(Msg, Ssh) -> {Packet, Ssh'}
%%%   #ssh_msg_newkeys{}  - packet built by ssh_transport:new_keys_message/1
%%%   any other tuple     - encoded by ssh_transport:ssh_packet/2
send(S=#s{ssh=C}, hello) ->
    Hello = case ?role(S) of
		client -> C#ssh.c_version;
		server -> C#ssh.s_version
	    end ++ "\r\n",
    send(S, list_to_binary(Hello));

send(S0, ssh_msg_kexinit) ->
    {Msg, _Bytes, _C0} = ssh_transport:key_exchange_init_msg(S0#s.ssh),
    send(S0, Msg);

send(S0=#s{alg_neg={undefined,PeerMsg}}, Msg=#ssh_msg_kexinit{}) ->
    S1 = opt(print_messages, S0,
	     fun(X) when X==true;X==detail -> {"Send~n~s~n",[format_msg(Msg)]} end),
    %% If the peer's kexinit has already arrived, run the negotiation now to
    %% learn the algorithms; a negotiation failure is only traced here.
    S2 = case PeerMsg of
	     #ssh_msg_kexinit{} ->
		 try ssh_transport:handle_kexinit_msg(PeerMsg, Msg, S1#s.ssh) of
		     {ok,Cx} when ?role(S1) == server ->
			 S1#s{alg = Cx#ssh.algorithms};
		     {ok,_NextKexMsgBin,Cx} when ?role(S1) == client ->
			 S1#s{alg = Cx#ssh.algorithms}
		 catch
		     Class:Exc ->
			 save_prints({"Algoritm negotiation failed at line ~p:~p~n~p:~s~nPeer: ~s~n Own: ~s~n",
				      [?MODULE,?LINE,Class,format_msg(Exc),format_msg(PeerMsg),format_msg(Msg)]},
				     S1)
		 end;
	     undefined ->
		 S1
	 end,
    {Bytes, C} = ssh_transport:ssh_packet(Msg, S2#s.ssh),
    send_bytes(Bytes, S2#s{return_value = Msg,
			  alg_neg = {Msg,PeerMsg},
			  ssh = C});

send(S0, ssh_msg_kexdh_init) when ?role(S0) == client ->
    %% The kexdh_init packet is produced by ssh_transport as a side result
    %% of processing the two kexinit messages.
    {OwnMsg, PeerMsg} = S0#s.alg_neg,
    {ok, NextKexMsgBin, C} =
	try ssh_transport:handle_kexinit_msg(PeerMsg, OwnMsg, S0#s.ssh)
	catch
	    Class:Exc ->
		fail("Algoritm negotiation failed!",
		     {"Algoritm negotiation failed at line ~p:~p~n~p:~s~nPeer: ~s~n Own: ~s",
		      [?MODULE,?LINE,Class,format_msg(Exc),format_msg(PeerMsg),format_msg(OwnMsg)]},
		     S0)
	end,
    %% Only the bytes are available, so the traced message is rebuilt from
    %% the public key stored in the new #ssh{}.
    S = opt(print_messages, S0,
	    fun(X) when X==true;X==detail ->
		    #ssh{keyex_key = {{_Private, Public}, {_G, _P}}} = C,
		    Msg = #ssh_msg_kexdh_init{e = Public},
		    {"Send (reconstructed)~n~s~n",[format_msg(Msg)]}
	    end),
    send_bytes(NextKexMsgBin, S#s{ssh = C});

send(S0, ssh_msg_kexdh_reply) ->
    %% The reply bytes were generated by recv/1 when kexdh_init arrived.
    Bytes = proplists:get_value(ssh_msg_kexdh_reply, S0#s.reply),
    %% Trace a reconstructed message; host key and signature are shown as
    %% the placeholders 'Key' and 'H_SIG'.
    S = opt(print_messages, S0,
	    fun(X) when X==true;X==detail ->
		    {{_Private, Public}, _} = (S0#s.ssh)#ssh.keyex_key,
		    Msg = #ssh_msg_kexdh_reply{public_host_key = 'Key',
					       f = Public,
					       h_sig = 'H_SIG'
					      },
		    {"Send (reconstructed)~n~s~n",[format_msg(Msg)]}
	    end),
    send_bytes(Bytes, S#s{return_value = Bytes});

send(S0, Line) when is_binary(Line) ->
    %% Raw bytes (e.g. the hello line) are sent without packet framing.
    S = opt(print_messages, S0,
	    fun(X) when X==true;X==detail -> {"Send line~n~p~n",[Line]} end),
    send_bytes(Line, S#s{return_value = Line});

send(S0, {special,Msg,PacketFun}) when is_tuple(Msg),
				       is_function(PacketFun,2) ->
    %% Caller-supplied encoder, e.g. to produce deliberately malformed packets.
    S = opt(print_messages, S0,
	    fun(X) when X==true;X==detail -> {"Send~n~s~n",[format_msg(Msg)]} end),
    {Packet, C} = PacketFun(Msg, S#s.ssh),
    send_bytes(Packet, S#s{ssh = C, %%inc_send_seq_num(C),
			   return_value = Msg});

send(S0, #ssh_msg_newkeys{} = Msg) ->
    S = opt(print_messages, S0,
	    fun(X) when X==true;X==detail -> {"Send~n~s~n",[format_msg(Msg)]} end),
    {ok, Packet, C} = ssh_transport:new_keys_message(S#s.ssh),
    send_bytes(Packet, S#s{ssh = C});

send(S0, Msg) when is_tuple(Msg) ->
    S = opt(print_messages, S0,
	    fun(X) when X==true;X==detail -> {"Send~n~s~n",[format_msg(Msg)]} end),
    {Packet, C} = ssh_transport:ssh_packet(Msg, S#s.ssh),
    send_bytes(Packet, S#s{ssh = C, %%inc_send_seq_num(C),
			   return_value = Msg}).
417
%%% send_bytes(B, S0) -> S
%%% Writes B to the socket, crashing if gen_tcp:send/2 returns an error.
%%% With {print_messages,detail}, first records a "Send bytes" trace entry.
send_bytes(B, S0) ->
    Traced = opt(print_messages, S0,
                 fun(detail) -> {"Send bytes~n~p~n",[B]} end),
    Sock = Traced#s.socket,
    ok(gen_tcp:send(Sock, B)),
    Traced.
422
423%%%----------------------------------------------------------------
%%% recv(S0) -> S
%%% Reads whatever is available on the socket (receive_poll/1) and decodes
%%% one item, leaving it in S#s.return_value.
%%% Until the hello line has been seen, CRLF-terminated lines are searched
%%% for (try_find_crlf/2).  After that, one binary packet is decoded, and
%%% the key-exchange state is updated for kexinit, kexdh_init, kexdh_reply
%%% and newkeys messages.
recv(S0 = #s{}) ->
    S1 = receive_poll(S0),
    case S1#s.seen_hello of
	{more,Seen} ->
	    %% Has received parts of a line. Has not seen a complete hello.
	    try_find_crlf(Seen, S1);
	false ->
	    %% Must see hello before binary messages
	    try_find_crlf(<<>>, S1);
	true ->
	    %% Has seen hello, therefore no more crlf-messages are allowed.
	    S = receive_binary_msg(S1),
	    case PeerMsg = S#s.return_value of
		#ssh_msg_kexinit{} ->
		    case S#s.alg_neg of
			%% Peer spoke first: remember its kexinit; negotiation
			%% happens when we send ours.
			{undefined,undefined} ->
			    S#s{alg_neg = {undefined,PeerMsg}};

			{undefined,_} ->
			    fail("2 kexint received!!", S);

			%% We already sent our kexinit: negotiate now.  A failure
			%% is only traced; the alg_neg pair is still recorded.
			{OwnMsg, _} ->
			    try ssh_transport:handle_kexinit_msg(PeerMsg, OwnMsg, S#s.ssh) of
				{ok,C} when ?role(S) == server ->
				    S#s{alg_neg = {OwnMsg, PeerMsg},
					alg = C#ssh.algorithms,
					ssh = C};
				%% The client keeps its #ssh{}; the kexdh_init bytes
				%% are regenerated by send(S, ssh_msg_kexdh_init).
				{ok,_NextKexMsgBin,C} when ?role(S) == client ->
				    S#s{alg_neg = {OwnMsg, PeerMsg},
					alg = C#ssh.algorithms}
			    catch
				Class:Exc ->
				    save_prints({"Algoritm negotiation failed at line ~p:~p~n~p:~s~nPeer: ~s~n Own: ~s~n",
						 [?MODULE,?LINE,Class,format_msg(Exc),format_msg(PeerMsg),format_msg(OwnMsg)]},
						S#s{alg_neg = {OwnMsg, PeerMsg}})
			    end
		    end;

		%% The reply is prepared now but only sent by
		%% send(S, ssh_msg_kexdh_reply).
		#ssh_msg_kexdh_init{} -> % Always the server
		    {ok, Reply, C} = ssh_transport:handle_kexdh_init(PeerMsg, S#s.ssh),
		    S#s{ssh = C,
			reply = [{ssh_msg_kexdh_reply,Reply} | S#s.reply]
		       };
		%% Only the derived secret, exchange hash and session id are
		%% copied from the returned #ssh{}; the rest of S#s.ssh is kept.
		#ssh_msg_kexdh_reply{} ->
		    {ok, _NewKeys, C} = ssh_transport:handle_kexdh_reply(PeerMsg, S#s.ssh),
                    S#s{ssh = (S#s.ssh)#ssh{shared_secret = C#ssh.shared_secret,
                                            exchanged_hash = C#ssh.exchanged_hash,
                                            session_id = C#ssh.session_id}};
		    %%%S#s{ssh=C#ssh{send_sequence=S#s.ssh#ssh.send_sequence}}; % Back the number
		#ssh_msg_newkeys{} ->
		    {ok, C} = ssh_transport:handle_new_keys(PeerMsg, S#s.ssh),
		    S#s{ssh=C};
		_ ->
		    S
	    end
    end.
480
481%%%================================================================
%%% try_find_crlf(Seen, S0) -> S
%%% Looks for a complete line in encrypted_data_buffer, prefixed with the
%%% partial line Seen collected by earlier calls.
%%%   - no complete line yet: the whole text is kept as seen_hello =
%%%     {more,Line}, the buffer is emptied, return_value = {more,Line};
%%%   - a line accepted by handle_hello/2: seen_hello = true,
%%%     return_value = {hello,Line};
%%%   - any other line: seen_hello = false, return_value = {info,Line}.
%%% In the last two cases the bytes after the line stay in the buffer.
try_find_crlf(Seen, S0) ->
    Buffer = S0#s.encrypted_data_buffer,
    case erlang:decode_packet(line, Buffer, []) of
        {more, _} ->
            %% Nothing in Buffer completes a line, so it need not be
            %% scanned again: move it all into the pending partial line.
            Partial = <<Seen/binary, Buffer/binary>>,
            S0#s{seen_hello = {more, Partial},
                 encrypted_data_buffer = <<>>,
                 return_value = {more, Partial}};
        {ok, Used, Rest} ->
            Line = <<Seen/binary, Used/binary>>,
            case handle_hello(Line, S0) of
                false ->
                    Info = opt(print_messages, S0,
                               fun(X) when X==true;X==detail -> {"Recv info~n~p~n",[Line]} end),
                    Info#s{seen_hello = false,
                           encrypted_data_buffer = Rest,
                           return_value = {info, Line}};
                Hello = #s{} ->
                    Traced = opt(print_messages, Hello,
                                 fun(X) when X==true;X==detail -> {"Recv hello~n~p~n",[Line]} end),
                    Traced#s{seen_hello = true,
                             encrypted_data_buffer = Rest,
                             return_value = {hello, Line}}
            end
    end.
508
509
%%% handle_hello(Bin, S) -> S' | false
%%% Parses Bin with ssh_transport:handle_hello_version/1.  Returns false if
%%% it yields {undefined,_}.  Otherwise the peer's version is stored in the
%%% #ssh{} record: the s_* fields when we are the client, c_* when we are
%%% the server.
handle_hello(Bin, S=#s{ssh=C}) ->
    Role = ?role(S),
    case ssh_transport:handle_hello_version(binary_to_list(Bin)) of
        {undefined, _} ->
            false;
        {Vp, Vs} when Role == client ->
            S#s{ssh = C#ssh{s_vsn=Vp, s_version=Vs}};
        {Vp, Vs} when Role == server ->
            S#s{ssh = C#ssh{c_vsn=Vp, c_version=Vs}}
    end.
518
%%% receive_binary_msg(S0) -> S
%%% Decodes one binary SSH packet from the buffered data, blocking
%%% (receive_wait/2) until enough bytes have arrived.  The partial-decryption
%%% state is kept in the decrypted_data_buffer, encrypted_data_buffer,
%%% aead_data and undecrypted_packet_length fields between attempts.
%%% On success the decoded message is left in S#s.return_value; a payload
%%% that ssh_message:decode/1 rejects fails with decode_failed.
receive_binary_msg(S0=#s{}) ->
     case ssh_transport:handle_packet_part(
           S0#s.decrypted_data_buffer,
           S0#s.encrypted_data_buffer,
           S0#s.aead_data,
           S0#s.undecrypted_packet_length,
           S0#s.ssh)
     of
         {packet_decrypted, DecryptedBytes, EncryptedDataRest, Ssh1} ->
             %% NOTE(review): the incremented recv_sequence stored in S1 only
             %% survives in the fail(decode_failed, S1) path.  On success, S3
             %% below replaces ssh with inc_recv_seq_num(Ssh2), where Ssh2
             %% derives from the un-incremented Ssh1, so the net effect is one
             %% increment either way.
             S1 = S0#s{ssh = Ssh1#ssh{recv_sequence = ssh_transport:next_seqnum(Ssh1#ssh.recv_sequence)},
                       decrypted_data_buffer = <<>>,
                       undecrypted_packet_length = undefined,
                       aead_data = <<>>,
                       encrypted_data_buffer = EncryptedDataRest},
             case
                 catch ssh_message:decode(set_prefix_if_trouble(DecryptedBytes,S1))
             of
                 {'EXIT',_} -> fail(decode_failed,S1);

                 Msg ->
                     %% A peer kexinit also initialises key-exchange state
                     %% in ssh_transport.
                     Ssh2 = case Msg of
                              #ssh_msg_kexinit{} ->
                                  ssh_transport:key_init(opposite_role(Ssh1), Ssh1, DecryptedBytes);
                              _ ->
                                  Ssh1
			 end,
                     S2 = opt(print_messages, S1,
                              fun(X) when X==true;X==detail -> {"Recv~n~s~n",[format_msg(Msg)]} end),
                     S3 = opt(print_messages, S2,
                              fun(detail) -> {"decrypted bytes ~p~n",[DecryptedBytes]} end),
                     S3#s{ssh = inc_recv_seq_num(Ssh2),
                          return_value = Msg
                         }
             end;

         {get_more, DecryptedBytes, EncryptedDataRest, AeadData, TotalNeeded, Ssh1} ->
             %% Here we know that there are not enough bytes in
             %% EncryptedDataRest to use. We must wait for more.
             %% With an unknown packet length, at least 8 more bytes are
             %% waited for; otherwise exactly the missing byte count.
             Remaining = case TotalNeeded of
                             undefined -> 8;
                             _ -> TotalNeeded - size(DecryptedBytes) - size(EncryptedDataRest)
                         end,
             receive_binary_msg(
               receive_wait(Remaining,
                            S0#s{encrypted_data_buffer = EncryptedDataRest,
                                 decrypted_data_buffer = DecryptedBytes,
                                 undecrypted_packet_length = TotalNeeded,
                                 aead_data = AeadData,
                                 ssh = Ssh1}
                           ))
     end.
570
571
572
%%% old_receive_binary_msg(S0) -> S
%%% Legacy packet reader.  NOTE(review): nothing in this module calls it
%%% (receive_binary_msg/1 uses ssh_transport:handle_packet_part/5 instead);
%%% it appears to be kept for reference only.
%%%
%%% It waits until at least one cipher block (minimum 8 bytes) is buffered and
%%% decrypts it to read packet_length.  It then waits for the rest of the
%%% packet plus MAC, decrypts the whole packet, checks the MAC, decompresses
%%% and decodes the payload.
old_receive_binary_msg(S0=#s{ssh=C0=#ssh{decrypt_block_size = BlockSize,
                                         recv_mac_size = MacSize
                                        }
                            }) ->
    case size(S0#s.encrypted_data_buffer) >= max(8,BlockSize) of
        false ->
            %% Need more bytes to decode the packet_length field.  Retry with
            %% this same decoder: receive_binary_msg/1 keeps its progress in
            %% other #s{} fields and is not a drop-in continuation.
            Remaining = max(8,BlockSize) - size(S0#s.encrypted_data_buffer),
            old_receive_binary_msg( receive_wait(Remaining, S0) );
        true ->
            %% Has enough bytes to decode the packet_length field
            {_, <<?UINT32(PacketLen), _/binary>>, _} =
                ssh_transport:decrypt_blocks(S0#s.encrypted_data_buffer, BlockSize, C0), % FIXME: BlockSize should be at least 4

            S1 = if
                     PacketLen > ?SSH_MAX_PACKET_SIZE ->
                         fail({too_large_message,PacketLen},S0);        % FIXME: disconnect

                     ((4+PacketLen) rem BlockSize) =/= 0 ->
                         fail(bad_packet_length_modulo, S0); % FIXME: disconnect

                     size(S0#s.encrypted_data_buffer) >= (4 + PacketLen + MacSize) ->
                         %% has the whole packet
                         S0;

                     true ->
                         %% need more bytes to have the whole packet
                         Remaining = (4 + PacketLen + MacSize) - size(S0#s.encrypted_data_buffer),
                         receive_wait(Remaining, S0)
                 end,

            %% Decrypt all, including the packet_length part (re-use the initial #ssh{})
            {C1, SshPacket = <<?UINT32(_),?BYTE(PadLen),Tail/binary>>, EncRest} =
                ssh_transport:decrypt_blocks(S1#s.encrypted_data_buffer, PacketLen+4, C0),

            PayloadLen = PacketLen - 1 - PadLen,
            <<CompressedPayload:PayloadLen/binary, _Padding:PadLen/binary>> = Tail,

            {C2, Payload} = ssh_transport:decompress(C1, CompressedPayload),

            <<Mac:MacSize/binary, Rest/binary>> = EncRest,

            case {ssh_transport:is_valid_mac(Mac, SshPacket, C2),
                  catch ssh_message:decode(set_prefix_if_trouble(Payload,S1))}
            of
                {false,         _} -> fail(bad_mac,S1);
                {_,    {'EXIT',_}} -> fail(decode_failed,S1);

                {true, Msg} ->
                    C3 = case Msg of
                             #ssh_msg_kexinit{} ->
                                 ssh_transport:key_init(opposite_role(C2), C2, Payload);
                             _ ->
                                 C2
                         end,
                    S2 = opt(print_messages, S1,
                             fun(X) when X==true;X==detail -> {"Recv~n~s~n",[format_msg(Msg)]} end),
                    S3 = opt(print_messages, S2,
                             fun(detail) -> {"decrypted bytes ~p~n",[SshPacket]} end),
                    S3#s{ssh = inc_recv_seq_num(C3),
                         encrypted_data_buffer = Rest,
                         return_value = Msg
                        }
            end
    end.
640
641
%%% set_prefix_if_trouble(Msg, S) -> Msg'
%%% Message numbers 30 and 31 are shared by several key-exchange methods.
%%% For such messages, a tag naming the negotiated kex family (taken from
%%% S#s.alg) is prepended so ssh_message:decode/1 picks the right record.
%%% All other messages, and states whose kex is not an atom, are returned
%%% unchanged.  The longer "diffie-hellman-group-exchange-" prefix must be
%%% tested before "diffie-hellman-group".
set_prefix_if_trouble(Msg = <<?BYTE(Op),_/binary>>, #s{alg=#alg{kex=Kex}})
  when (Op == 30 orelse Op == 31), is_atom(Kex) ->
    KexName = atom_to_list(Kex),
    Families = [{"ecdh-sha2-",                     <<"ecdh">>},
                {"diffie-hellman-group-exchange-", <<"dh_gex">>},
                {"diffie-hellman-group",           <<"dh">>}],
    case [Tag || {Prefix, Tag} <- Families, lists:prefix(Prefix, KexName)] of
        [Tag | _] -> <<Tag/binary, Msg/binary>>;
        [] -> Msg
    end;
set_prefix_if_trouble(Msg, _) ->
    Msg.
658
659
%%% receive_poll(S) -> S'
%%% Appends all data that is already available on the socket to
%%% encrypted_data_buffer without waiting.  A closed or failed socket
%%% throws {tcp,tcp_closed} or {tcp,{tcp_error,Reason}}.
receive_poll(S=#s{socket=Sock}) ->
    inet:setopts(Sock, [{active,once}]),
    receive
        {tcp, Sock, Chunk} ->
            Buffered = S#s.encrypted_data_buffer,
            receive_poll(S#s{encrypted_data_buffer = <<Buffered/binary, Chunk/binary>>});
        {tcp_closed, Sock} ->
            throw({tcp, tcp_closed});
        {tcp_error, Sock, Why} ->
            throw({tcp, {tcp_error, Why}})
    after 0 ->
            S
    end.
672
%%% receive_wait(S) -> S'
%%% Waits at most S#s.timeout ms for one chunk of data and appends it to
%%% encrypted_data_buffer.  Throws {tcp,tcp_closed} or
%%% {tcp,{tcp_error,Reason}} on socket problems, and fails with
%%% receive_timeout (fail/2) if nothing arrives in time.
receive_wait(S=#s{socket=Sock,
                  timeout=Timeout}) ->
    inet:setopts(Sock, [{active,once}]),
    receive
        {tcp, Sock, Chunk} ->
            Buffered = S#s.encrypted_data_buffer,
            S#s{encrypted_data_buffer = <<Buffered/binary, Chunk/binary>>};
        {tcp_closed, Sock} ->
            throw({tcp, tcp_closed});
        {tcp_error, Sock, Why} ->
            throw({tcp, {tcp_error, Why}})
    after Timeout ->
            fail(receive_timeout, S)
    end.
686
%%% receive_wait(N, S) -> S'
%%% Keeps reading from the socket until at least N more bytes have been
%%% appended to encrypted_data_buffer.  Each chunk gets its own timeout of
%%% S#s.timeout ms.  Socket errors are thrown and timeouts fail, as in
%%% receive_wait/1.  When N is not > 0, S is returned untouched.
receive_wait(N, S=#s{socket=Sock,
                     timeout=Timeout,
                     encrypted_data_buffer=Enc0}) when N>0 ->
    inet:setopts(Sock, [{active,once}]),
    receive
        {tcp, Sock, Chunk} ->
            StillMissing = N - size(Chunk),
            Grown = S#s{encrypted_data_buffer = <<Enc0/binary, Chunk/binary>>},
            receive_wait(StillMissing, Grown);
        {tcp_closed, Sock} ->
            throw({tcp, tcp_closed});
        {tcp_error, Sock, Why} ->
            throw({tcp, {tcp_error, Why}})
    after Timeout ->
            fail(receive_timeout, S)
    end;
receive_wait(_Satisfied, S) ->
    S.
703
704%% random_padding_len(PaddingLen1, ChunkSize) ->
%%     MaxAdditionalRandomPaddingLen = 		% max 255 bytes padding total
706%% 	(255 - PaddingLen1) - ((255 - PaddingLen1) rem ChunkSize),
707%%     AddLen0 = crypto:rand_uniform(0,MaxAdditionalRandomPaddingLen),
708%%     AddLen0 - (AddLen0 rem ChunkSize).		% preserve the blocking
709
%%% Advances the receive sequence number, wrapping at 32 bits.
inc_recv_seq_num(C=#ssh{recv_sequence=Seq}) ->
    Next = (Seq + 1) band 16#ffffffff,
    C#ssh{recv_sequence = Next}.
711%%%inc_send_seq_num(C=#ssh{send_sequence=N}) -> C#ssh{send_sequence=(N+1) band 16#ffffffff}.
712
%%% Returns the peer's role: client <-> server.  Also accepts an #ssh{}
%%% record, in which case its role field is used.
opposite_role(client) -> server;
opposite_role(server) -> client;
opposite_role(#ssh{role = Own}) -> opposite_role(Own).
716
%%% Asserts success: ok stays ok, {ok,Value} unwraps to Value, and
%%% {error,Reason} raises error Reason.
ok({ok, Value}) -> Value;
ok({error, Reason}) -> erlang:error(Reason);
ok(ok) -> ok.
720
721
722%%%================================================================
723%%%
724%%% Formating of records
725%%%
726
%%% format_msg(Term) -> iolist()
%%% Pretty-prints Term.  Records known to fields/1 are rendered as
%%% "#name{field = value, ...}" with one field per line and nested records
%%% indented below their field name; anything else uses ~p.
format_msg(M) -> format_msg(M, 0).

%%% format_msg(Term, Indent0) -> iolist()
%%% As format_msg/1, with Indent0 columns already used on the current line.
format_msg(Term, Indent0) ->
    case fields(Term) of
        undefined ->
            io_lib:format('~p',[Term]);
        FieldNames ->
            [RecName | Values] = tuple_to_list(Term),
            Open = io_lib:format('#~p{',[RecName]),
            Indent = Indent0 + lists:flatlength(Open),
            Newline = io_lib:format('~n~*c',[Indent,$ ]),
            Separator = io_lib:format(',~n~*c',[Indent,$ ]),
            Entries = lists:zipwith(
                        fun(Name, Value) ->
                                Lhs = io_lib:format('~p = ',[Name]),
                                [Lhs, format_msg(Value, Indent + lists:flatlength(Lhs))]
                        end, FieldNames, Values),
            [[Open | string:join(Entries, Separator)], Newline, "}"]
    end.
746
%%% fields(Term) -> [FieldName] | undefined
%%% Field names of the records format_msg/2 knows how to pretty-print, or
%%% undefined for any other term.
fields(#ssh_msg_debug{}) -> record_info(fields, ssh_msg_debug);
fields(#ssh_msg_disconnect{}) -> record_info(fields, ssh_msg_disconnect);
fields(#ssh_msg_ignore{}) -> record_info(fields, ssh_msg_ignore);
fields(#ssh_msg_kex_dh_gex_group{}) -> record_info(fields, ssh_msg_kex_dh_gex_group);
fields(#ssh_msg_kex_dh_gex_init{}) -> record_info(fields, ssh_msg_kex_dh_gex_init);
fields(#ssh_msg_kex_dh_gex_reply{}) -> record_info(fields, ssh_msg_kex_dh_gex_reply);
fields(#ssh_msg_kex_dh_gex_request{}) -> record_info(fields, ssh_msg_kex_dh_gex_request);
fields(#ssh_msg_kex_dh_gex_request_old{}) -> record_info(fields, ssh_msg_kex_dh_gex_request_old);
fields(#ssh_msg_kexdh_init{}) -> record_info(fields, ssh_msg_kexdh_init);
fields(#ssh_msg_kexdh_reply{}) -> record_info(fields, ssh_msg_kexdh_reply);
fields(#ssh_msg_kexinit{}) -> record_info(fields, ssh_msg_kexinit);
fields(#ssh_msg_newkeys{}) -> record_info(fields, ssh_msg_newkeys);
fields(#ssh_msg_service_accept{}) -> record_info(fields, ssh_msg_service_accept);
fields(#ssh_msg_service_request{}) -> record_info(fields, ssh_msg_service_request);
fields(#ssh_msg_unimplemented{}) -> record_info(fields, ssh_msg_unimplemented);
fields(#ssh_msg_userauth_request{}) -> record_info(fields, ssh_msg_userauth_request);
fields(#ssh_msg_userauth_failure{}) -> record_info(fields, ssh_msg_userauth_failure);
fields(#ssh_msg_userauth_success{}) -> record_info(fields, ssh_msg_userauth_success);
fields(#ssh_msg_userauth_banner{}) -> record_info(fields, ssh_msg_userauth_banner);
fields(#ssh_msg_userauth_passwd_changereq{}) -> record_info(fields, ssh_msg_userauth_passwd_changereq);
fields(#ssh_msg_userauth_pk_ok{}) -> record_info(fields, ssh_msg_userauth_pk_ok);
fields(#ssh_msg_userauth_info_request{}) -> record_info(fields, ssh_msg_userauth_info_request);
fields(#ssh_msg_userauth_info_response{}) -> record_info(fields, ssh_msg_userauth_info_response);
fields(#s{}) -> record_info(fields, s);
fields(#ssh{}) -> record_info(fields, ssh);
fields(#alg{}) -> record_info(fields, alg);
fields(_) -> undefined.
777
778%%%================================================================
779%%%
780%%% Trace handling
781%%%
782
%%% init_op_traces(Op, S0) -> S
%%% Clears the pending traces.  With {print_ops,true}, it then records a
%%% "-- Op" header line, prefixed by the role once the role is known.
init_op_traces(Op, S0) ->
    opt(print_ops, S0#s{prints = []},
        fun(true) when ?role(S0) == undefined ->
                {"-- ~p~n", [Op]};
           (true) ->
                {"-- ~p ~p~n", [?role(S0), Op]}
        end).
792
%%% report_trace(Class, Term, S)
%%% Logs the traces collected in S.  With {print_ops,true}, a final
%%% "Class Term" line describing the failure is added first.
report_trace(Class, Term, S) ->
    Traced = opt(print_ops, S,
                 fun(true) -> {"~s ~p",[Class,Term]} end),
    print_traces(Traced).
798
%%% seqnum_trace(S) -> S'
%%% With the option {print_seqnums,true}, records a trace line with the
%%% current send and receive sequence numbers of S.  Otherwise S is returned
%%% unchanged.
%%%
%%% Only the post-operation state is available here, so the numbers are
%%% printed as-is rather than as a before->after difference.  A self
%%% comparison like S#s.ssh#ssh.send_sequence =/= S#s.ssh#ssh.send_sequence
%%% is always false and would never print anything.
seqnum_trace(S) ->
    opt(print_seqnums, S,
        fun(true) ->
                #ssh{send_sequence = Send, recv_sequence = Recv} = S#s.ssh,
                {"~p seq num: send ~p,  recv ~p~n", [?role(S), Send, Recv]}
        end).
817
%%% print_traces(S)
%%% Writes all pending trace entries to the Common Test log in one call,
%%% oldest first.  Every entry except the oldest is preceded by a
%%% "N --------" separator, where N is its position (0 = oldest).  Returns S
%%% unchanged when there is nothing to print.
print_traces(S) when S#s.prints == [] -> S;
print_traces(S) ->
    %% prints is kept newest-first (see save_prints/2).
    Oldest = lists:reverse(S#s.prints),
    Numbered = lists:zip(lists:seq(0, length(Oldest) - 1), Oldest),
    Text = lists:map(
             fun({0, {Fmt, Args}}) ->
                     io_lib:format(Fmt, Args);
                ({N, {Fmt, Args}}) ->
                     io_lib:format(lists:concat(['~p --------~n', Fmt]), [N | Args])
             end, Numbered),
    ct:log("~s", [Text]).
833
%%% opt(Flag, S, Fun) -> S'
%%% Calls Fun with the value of option Flag in S#s.opts (undefined if unset).
%%% If Fun returns a {Fmt,Args} pair of lists, it is saved as a trace entry.
%%% Any exception, typically a function_clause because the flag has a value
%%% Fun does not handle, means "no trace" and S is returned unchanged.
opt(Flag, S, Fun) when is_function(Fun,1) ->
    try
        Fun(proplists:get_value(Flag, S#s.opts))
    of
        {Fmt, Args} = Entry when is_list(Fmt), is_list(Args) ->
            save_prints(Entry, S)
    catch
        _:_ -> S
    end.
841
%%% Pushes the trace entry {Fmt,Args} onto S#s.prints (newest first).
save_prints({Fmt,Args}, S) ->
    Entry = {Fmt, Args},
    Older = S#s.prints,
    S#s{prints = [Entry | Older]}.
844