1%%
2%% %CopyrightBegin%
3%%
4%% Copyright Ericsson AB 2004-2020. All Rights Reserved.
5%%
6%% The contents of this file are subject to the Erlang Public License,
7%% Version 1.1, (the "License"); you may not use this file except in
8%% compliance with the License. You should have received a copy of the
9%% Erlang Public License along with this software. If not, it can be
10%% retrieved online at http://www.erlang.org/.
11%%
12%% Software distributed under the License is distributed on an "AS IS"
13%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
14%% the License for the specific language governing rights and limitations
15%% under the License.
16%%
17%% %CopyrightEnd%
18%%
19%%
20
21-module(ssh_trpt_test_lib).
22
23-export([exec/1, exec/2,
24	 instantiate/2,
25	 format_msg/1,
26	 server_host_port/1
27	]
28       ).
29
30-include_lib("common_test/include/ct.hrl").
31-include("ssh.hrl").		% ?UINT32, ?BYTE, #ssh{} ...
32-include("ssh_transport.hrl").
33-include("ssh_auth.hrl").
34
35%%%----------------------------------------------------------------
%% State threaded through every op executed by exec/1,2.
-record(s, {
	  socket,				% gen_tcp socket, set by accept/connect
	  listen_socket,			% gen_tcp listen socket (server side only)
	  opts = [],				% set by {set_options,Opts}; read by opt/3 and exec/2
	  timeout = 5000,			% ms
	  seen_hello = false,			% false | {more,PartialLine} | true
	  ssh = #ssh{},				% #ssh{}
	  alg_neg = {undefined,undefined},      % {own_kexinit, peer_kexinit}
	  alg,                                  % #alg{}
	  vars = dict:new(),			% dict of '$Name' pattern variable bindings
	  reply = [],				% Some reply msgs are generated hidden in ssh_transport :[
	  prints = [],				% pending trace entries {Fmt,Args}, newest first
	  return_value,				% result of the latest op; referenced as '$$'

          %% Packet retrieval and decryption
          decrypted_data_buffer     = <<>>,
          encrypted_data_buffer     = <<>>,	% received bytes not yet consumed
          aead_data                 = <<>>,
          undecrypted_packet_length
         }).
56
%% The connection role stored in the #ssh{} of state S: client, server,
%% or (as the listen/connect guards expect for a fresh state) undefined.
-define(role(S), ((S#s.ssh)#ssh.role) ).
58
59
%% Return {Host, Port} for the address the server state's listen socket
%% is bound to, with the host converted by host/1.
server_host_port(S = #s{}) ->
    ListenSock = S#s.listen_socket,
    {Addr, PortNo} = ok(inet:sockname(ListenSock)),
    {host(Addr), PortNo}.
63
64
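%%% Options, set with the {set_options,Opts} op
%%% (default first, then the values that enable output):
%%%          {print_messages, false}  true|detail
%%%          {print_seqnums,false}    true
%%%          {print_ops,false}        true
%%%          {silent,false}           true  (no trace output for successful ops)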
68
%% Run one op, or a list of ops, starting from a fresh state.
exec(Ops) ->
    exec(Ops, #s{}).
70
%% Execute one op, or a list of ops in order, against state S.
%%
%% Returns {ok,State} on success. When an op fails via fail/2,3 the
%% result is {error,{Op,Reason,State}}. Once an error has been returned,
%% the remaining ops of a list are skipped and the error is passed
%% through unchanged. Other throws, errors and exits are traced and
%% re-raised with the failing Op attached.
exec(L, S) when is_list(L) -> lists:foldl(fun exec/2, S, L);

exec(Op, S0=#s{}) ->
    S1 = init_op_traces(Op, S0),
    try seqnum_trace(
          op(Op, S1))
    of
        S = #s{} ->
            case proplists:get_value(silent,S#s.opts) of
                true -> ok;
                _ -> print_traces(S)
            end,
            {ok,S}
    catch
        {fail,Reason,Se} ->
            report_trace('', Reason, Se),
            {error,{Op,Reason,Se}};

        throw:Term ->
            report_trace(throw, Term, S1),
            throw({Term,Op});

        error:Error ->
            report_trace(error, Error, S1),
            error({Error,Op});

        %% throw, error and exit are the only exception classes, so no
        %% further catch-all clause is needed.
        exit:Exit ->
            report_trace(exit, Exit, S1),
            exit({Exit,Op})
    end;
exec(Op, {ok,S=#s{}}) -> exec(Op, S);
exec(_, Error) -> Error.
106
107
108%%%---- Server ops
%% Execute a single op and return the new state.
%%
%%   server:  listen | {listen,Port} | {accept,Opts}
%%   client:  {connect,Host,Port,Opts}
%%   both:    close_socket | {set_options,Opts} | {send,X} |
%%            receive_hello | receive_msg | {expect,timeout,Op} |
%%            {match,Pattern,OpOrTerm} | {print,OpOrTerm} |
%%            print_state | '$$'
%%
%% Most ops leave their result in S#s.return_value. Failures are thrown
%% via fail/2,3. An unknown op, or one not allowed in the current role
%% or state, raises function_clause (op_val/2 relies on this).
op(listen, S) when ?role(S) == undefined -> op({listen,0}, S);

op({listen,Port}, S) when ?role(S) == undefined ->
    S#s{listen_socket = ok(gen_tcp:listen(Port, mangle_opts([]))),
        ssh = (S#s.ssh)#ssh{role=server}
       };

op({accept,Opts}, S) when ?role(S) == server ->
    {ok,Socket} = gen_tcp:accept(S#s.listen_socket, S#s.timeout),
    {Host,_Port} = ok(inet:sockname(Socket)),
    S#s{socket = Socket,
        ssh = init_ssh(server, Socket, host(Host), Opts),
        return_value = ok};

%%%---- Client ops
op({connect,Host,Port,Opts}, S) when ?role(S) == undefined ->
    Socket = ok(gen_tcp:connect(host(Host), Port, mangle_opts([]))),
    S#s{socket = Socket,
        ssh = init_ssh(client, Socket, host(Host), Opts),
        return_value = ok};

%%%---- ops for both client and server
op(close_socket, S) ->
    %% Best effort: either socket may be undefined or already closed.
    catch gen_tcp:close(S#s.socket),
    catch gen_tcp:close(S#s.listen_socket),
    S#s{socket = undefined,
        listen_socket = undefined,
        return_value = ok};

op({set_options,Opts}, S) ->
    S#s{opts = Opts};

op({send,X}, S) ->
    send(S, instantiate(X,S));

op(receive_hello, S0) when S0#s.seen_hello =/= true ->
    case recv(S0) of
        S1=#s{return_value={hello,_}} ->
            S1;
        S1=#s{return_value={info,_}} ->
            %% A complete non-hello line was consumed. The rest of the
            %% buffer may already contain the hello, so scan it again
            %% before blocking on the socket.
            op(receive_hello, S1);
        S1=#s{} ->
            %% Only a partial line so far: wait for more bytes.
            op(receive_hello, receive_wait(S1))
    end;

op(receive_msg, S) when S#s.seen_hello == true ->
    try recv(S)
    catch
        {tcp,Exc} ->
            S1 = opt(print_messages, S,
                     fun(X) when X==true;X==detail -> {"Recv~n~p~n",[Exc]} end),
            S1#s{return_value=Exc}
    end;


op({expect,timeout,E}, S0) ->
    try op(E, S0)
    of
        S=#s{} -> fail({expected,timeout,S#s.return_value}, S)
    catch
        %% receive_wait/1,2 signal a timeout with fail(receive_timeout,S),
        %% i.e. by throwing {fail,receive_timeout,S}.
        {fail,receive_timeout,_} -> S0#s{return_value=timeout}
    end;

op({match,M,E}, S0) ->
    {Val,S2} = op_val(E, S0),
    case match(M, Val, S2) of
        {true,S3} ->
            opt(print_ops,S3,
                fun(true) ->
                        case dict:fold(
                               fun(K,V,Acc) ->
                                       case dict:find(K,S0#s.vars) of
                                           error -> [{K,V}|Acc];
                                           _ -> Acc
                                       end
                               end, [], S3#s.vars)
                        of
                            [] -> {"Matches! No new bindings.~n",[]};
                            New ->
                                Width = lists:max([length(atom_to_list(K)) || {K,_} <- New]),
                                %% The bindings are passed as an argument rather
                                %% than spliced into the format string, so a '~'
                                %% in a bound value cannot break io_lib:format/2
                                %% in print_traces/1.
                                {"Matches! New bindings:~n~s",
                                 [[io_lib:format(" ~*s = ~p~n",[Width,K,V]) || {K,V}<-New]]}
                        end
                end);
        false ->
            fail({expected,M,Val},
                 opt(print_ops,S2,fun(true) -> {"nomatch!!~n",[]} end)
                )
    end;

op({print,E}, S0) ->
    {Val,S} = op_val(E, S0),
    io:format("Result of ~p ~p =~n~s~n",[?role(S0),E,format_msg(Val)]),
    S;

op(print_state, S) ->
    io:format("State(~p)=~n~s~n",[?role(S), format_msg(S)]),
    S;

op('$$', S) ->
    %% For matching etc
    S.
209
210
%% Evaluate E either as an op or, if E is not an op, as a plain term.
%%
%% Returns {Value, State}. For an op, these are its return_value and the
%% resulting state. For a non-op, they are instantiate(E,S0) and the
%% unchanged S0. A non-op is recognised by op/2 raising function_clause
%% for exactly these arguments. Every failure signalled with fail/2,3 is
%% re-thrown so that exec/2 reports it as {error,...} instead of this
%% case crashing with case_clause.
op_val(E, S0) ->
    case catch op(E, S0) of
        {'EXIT',{function_clause,[{?MODULE,op,[E,S0],_}|_]}} ->
            {instantiate(E,S0), S0};
        S=#s{} ->
            {S#s.return_value, S};
        F={fail,_Reason,_St} ->
            throw(F)
    end.
220
221
%% Like fail/2, but first adds the {Fmt,Args} trace entry to S.
fail(Reason, {Fmt,Args} = TraceEntry, S) when is_list(Fmt), is_list(Args) ->
    Traced = save_prints(TraceEntry, S),
    fail(Reason, Traced).
224
%% Abort the current op by throwing {fail,Reason,S}; exec/2 turns this
%% into {error,{Op,Reason,S}}.
fail(Reason, S) ->
    erlang:throw({fail, Reason, S}).
227
228%%%----------------------------------------------------------------
229%% No optimizations :)
230
%% Match value V against pattern P, possibly binding variables in S.
%%
%%   '$$'          - match against the previous op's return_value
%%   '_'           - matches anything
%%   {'or',Ps}     - matches if any pattern in the list Ps matches
%%   '$Name' atom  - variable: bound to V on first use (S#s.vars),
%%                   afterwards V must match the bound value
%%   tuple / list  - matched element by element
%%   other         - must compare equal (==) to V
%%
%% Returns {true,NewS} on a match, otherwise false.
match('$$', V, S) ->
    match(S#s.return_value, V, S);

match('_', _, S) ->
    {true, S};

match({'or',[P]}, V, S) -> match(P,V,S);
match({'or',[Ph|Pt]}, V, S) ->
    case match(Ph,V,S) of
        false -> match({'or',Pt}, V, S);
        %% The alternative may have bound new variables, so the returned
        %% state can differ from S and must not be matched against it.
        {true,S1} -> {true,S1}
    end;

match(P, V, S) when is_atom(P) ->
    case atom_to_list(P) of
        "$"++_ ->
            %% Variable
            case dict:find(P,S#s.vars) of
                {ok,Val} -> match(Val, V, S);
                error -> {true,S#s{vars = dict:store(P,V,S#s.vars)}}
            end;
        _ when P==V ->
            {true,S};
        _ ->
            false
    end;

match(P, V, S) when P==V ->
    {true, S};

match(P, V, S) when is_tuple(P),
                    is_tuple(V) ->
    match(tuple_to_list(P), tuple_to_list(V), S);

match([Hp|Tp], [Hv|Tv], S0) ->
    case match(Hp, Hv, S0) of
        {true,S} -> match(Tp, Tv, S);
        false -> false
    end;

match(_, _, _) ->
    false.
273
274
275
%% Substitute bound pattern variables in a term.
%%
%% '$$' becomes the previous op's return_value. A '$Name' atom is replaced
%% by its binding in S#s.vars, or {unbound,Name} is thrown. Tuples and
%% lists are handled element-wise, and any other term is returned as is.
instantiate('$$', S) ->
    %% FIXME: What if $$ or $... in return_value?
    S#s.return_value;

instantiate(Term, S) when is_atom(Term) ->
    case atom_to_list(Term) of
        "$" ++ _VarName ->
            case dict:find(Term, S#s.vars) of
                {ok, Bound} -> Bound;   % FIXME: What if $$ or $... in Bound?
                error -> throw({unbound, Term})
            end;
        _PlainAtom ->
            Term
    end;

instantiate(Term, S) when is_tuple(Term) ->
    Elements = instantiate(tuple_to_list(Term), S),
    list_to_tuple(Elements);

instantiate([Head | Tail], S) ->
    [instantiate(Head, S) | instantiate(Tail, S)];

instantiate(Other, _S) ->
    Other.
299
300%%%================================================================
301%%%
%% Build the #ssh{} record for Role on Socket. The library's default
%% options are listed in front of the caller's UserOptions0, and Host is
%% stored as an internal option. The handled options are then passed to
%% ssh_connection_handler:init_ssh_record/3.
init_ssh(Role, Socket, Host, UserOptions0) ->
    LibDefaults = [{user_interaction, false},
                   {vsn, {2,0}},
                   {id_string, "ErlangTestLib"}],
    Handled = ssh_options:handle_options(Role, LibDefaults ++ UserOptions0),
    Opts = ?PUT_INTERNAL_OPT({host,Host}, Handled),
    ssh_connection_handler:init_ssh_record(Role, Socket, Opts).
310
%% Socket options for gen_tcp: the forced options this library needs,
%% followed by Options with the first entry for each forced key removed.
mangle_opts(Options) ->
    Forced = [{reuseaddr, true},
              {active, false},
              {mode, binary}],
    Kept = lists:foldl(fun({Key, _Val}, Acc) -> lists:keydelete(Key, 1, Acc) end,
                       Options,
                       Forced),
    lists:append(Forced, Kept).
319
%% Normalise a host address with ssh_test_lib:mangle_connect_address/1 and
%% convert it with ssh_test_lib:ntoa/1 (presumably to an address string).
host(H) ->
    Mangled = ssh_test_lib:mangle_connect_address(H),
    ssh_test_lib:ntoa(Mangled).
321
322%%%----------------------------------------------------------------
%% Send X (already instantiated by op({send,X},_)) to the peer and return
%% the updated state. Supported forms, in clause order:
%%   hello                   - our version string (c_version as client,
%%                             s_version as server) followed by "\r\n"
%%   ssh_msg_kexinit         - a kexinit built by ssh_transport, then sent
%%                             through one of the tuple clauses below
%%   #ssh_msg_kexinit{}      - only while we have not sent a kexinit; if the
%%                             peer's kexinit is known, negotiation is also
%%                             run (a failure is only traced)
%%   ssh_msg_kexdh_init      - client only; the packet produced by
%%                             ssh_transport:handle_kexinit_msg/3
%%   ssh_msg_kexdh_reply     - the reply bytes stored in S#s.reply by recv/1
%%   Binary                  - raw bytes, sent unchanged
%%   {special,Msg,PacketFun} - Msg encoded by PacketFun(Msg, Ssh)
%%   #ssh_msg_newkeys{}      - encoded by ssh_transport:new_keys_message/1
%%   any other tuple         - encoded by ssh_transport:ssh_packet/2
send(S=#s{ssh=C}, hello) ->
    Hello = case ?role(S) of
		client -> C#ssh.c_version;
		server -> C#ssh.s_version
	    end ++ "\r\n",
    send(S, list_to_binary(Hello));

send(S0, ssh_msg_kexinit) ->
    %% Only the generated message is used; the #ssh{} returned here is
    %% discarded.
    {Msg, _Bytes, _C0} = ssh_transport:key_exchange_init_msg(S0#s.ssh),
    send(S0, Msg);

send(S0=#s{alg_neg={undefined,PeerMsg}}, Msg=#ssh_msg_kexinit{}) ->
    S1 = opt(print_messages, S0,
	     fun(X) when X==true;X==detail -> {"Send~n~s~n",[format_msg(Msg)]} end),
    S2 = case PeerMsg of
	     #ssh_msg_kexinit{} ->
		 %% The peer's kexinit arrived first: negotiate now.
		 try ssh_transport:handle_kexinit_msg(PeerMsg, Msg, S1#s.ssh) of
		     {ok,Cx} when ?role(S1) == server ->
			 S1#s{alg = Cx#ssh.algorithms};
		     {ok,_NextKexMsgBin,Cx} when ?role(S1) == client ->
			 S1#s{alg = Cx#ssh.algorithms}
		 catch
		     Class:Exc ->
			 save_prints({"Algoritm negotiation failed at line ~p:~p~n~p:~s~nPeer: ~s~n Own: ~s~n",
				      [?MODULE,?LINE,Class,format_msg(Exc),format_msg(PeerMsg),format_msg(Msg)]},
				     S1)
		 end;
	     undefined ->
		 S1
	 end,
    {Bytes, C} = ssh_transport:ssh_packet(Msg, S2#s.ssh),
    send_bytes(Bytes, S2#s{return_value = Msg,
			  alg_neg = {Msg,PeerMsg},
			  ssh = C});

send(S0, ssh_msg_kexdh_init) when ?role(S0) == client ->
    %% Re-run the negotiation to get the kexdh_init packet to send and the
    %% #ssh{} holding our key exchange key.
    {OwnMsg, PeerMsg} = S0#s.alg_neg,
    {ok, NextKexMsgBin, C} =
	try ssh_transport:handle_kexinit_msg(PeerMsg, OwnMsg, S0#s.ssh)
	catch
	    Class:Exc ->
		fail("Algoritm negotiation failed!",
		     {"Algoritm negotiation failed at line ~p:~p~n~p:~s~nPeer: ~s~n Own: ~s",
		      [?MODULE,?LINE,Class,format_msg(Exc),format_msg(PeerMsg),format_msg(OwnMsg)]},
		     S0)
	end,
    %% The message below is rebuilt only for the trace output; the bytes
    %% actually sent are NextKexMsgBin.
    S = opt(print_messages, S0,
	    fun(X) when X==true;X==detail ->
		    #ssh{keyex_key = {{_Private, Public}, {_G, _P}}} = C,
		    Msg = #ssh_msg_kexdh_init{e = Public},
		    {"Send (reconstructed)~n~s~n",[format_msg(Msg)]}
	    end),
    send_bytes(NextKexMsgBin, S#s{ssh = C});

send(S0, ssh_msg_kexdh_reply) ->
    %% The reply bytes were stored by recv/1 when the kexdh_init arrived.
    %% The traced message is a rebuilt placeholder, not what is sent.
    Bytes = proplists:get_value(ssh_msg_kexdh_reply, S0#s.reply),
    S = opt(print_messages, S0,
	    fun(X) when X==true;X==detail ->
		    {{_Private, Public}, _} = (S0#s.ssh)#ssh.keyex_key,
		    Msg = #ssh_msg_kexdh_reply{public_host_key = 'Key',
					       f = Public,
					       h_sig = 'H_SIG'
					      },
		    {"Send (reconstructed)~n~s~n",[format_msg(Msg)]}
	    end),
    send_bytes(Bytes, S#s{return_value = Bytes});

send(S0, Line) when is_binary(Line) ->
    S = opt(print_messages, S0,
	    fun(X) when X==true;X==detail -> {"Send line~n~p~n",[Line]} end),
    send_bytes(Line, S#s{return_value = Line});

send(S0, {special,Msg,PacketFun}) when is_tuple(Msg),
				       is_function(PacketFun,2) ->
    S = opt(print_messages, S0,
	    fun(X) when X==true;X==detail -> {"Send~n~s~n",[format_msg(Msg)]} end),
    {Packet, C} = PacketFun(Msg, S#s.ssh),
    send_bytes(Packet, S#s{ssh = C, %%inc_send_seq_num(C),
			   return_value = Msg});

send(S0, #ssh_msg_newkeys{} = Msg) ->
    %% NOTE(review): unlike the other clauses, return_value is not updated
    %% here - confirm whether that is intended.
    S = opt(print_messages, S0,
	    fun(X) when X==true;X==detail -> {"Send~n~s~n",[format_msg(Msg)]} end),
    {ok, Packet, C} = ssh_transport:new_keys_message(S#s.ssh),
    send_bytes(Packet, S#s{ssh = C});

send(S0, Msg) when is_tuple(Msg) ->
    S = opt(print_messages, S0,
	    fun(X) when X==true;X==detail -> {"Send~n~s~n",[format_msg(Msg)]} end),
    {Packet, C} = ssh_transport:ssh_packet(Msg, S#s.ssh),
    send_bytes(Packet, S#s{ssh = C, %%inc_send_seq_num(C),
			   return_value = Msg}).
415
%% Write Bytes to the socket, crashing on a send error. With
%% print_messages set to 'detail' the raw bytes are traced first.
send_bytes(Bytes, State0) ->
    State = opt(print_messages, State0,
                fun(detail) -> {"Send bytes~n~p~n", [Bytes]} end),
    ok(gen_tcp:send(State#s.socket, Bytes)),
    State.
420
421%%%----------------------------------------------------------------
%% Receive the next item from the peer and return the updated state.
%%
%% Before the peer's hello has been seen, input is split into lines by
%% try_find_crlf/2 (return_value {more,_}, {info,_} or {hello,_}).
%% After that, one binary SSH message is decoded by
%% receive_binary_msg/1. Key exchange messages also update the
%% negotiation state:
%%   kexinit       - stored, or negotiated against our own kexinit
%%   kexdh_init    - (server) the reply is computed and stored in
%%                   S#s.reply for a later send of ssh_msg_kexdh_reply
%%   kexdh_reply   - only shared_secret, exchanged_hash and session_id
%%                   are copied into our #ssh{}
%%   newkeys       - handled by ssh_transport:handle_new_keys/2
recv(S0 = #s{}) ->
    S1 = receive_poll(S0),
    case S1#s.seen_hello of
	{more,Seen} ->
	    %% Has received parts of a line. Has not seen a complete hello.
	    try_find_crlf(Seen, S1);
	false ->
	    %% Must see hello before binary messages
	    try_find_crlf(<<>>, S1);
	true ->
	    %% Has seen hello, therefore no more crlf-messages are allowed.
	    S = receive_binary_msg(S1),
	    case PeerMsg = S#s.return_value of
		#ssh_msg_kexinit{} ->
		    case S#s.alg_neg of
			{undefined,undefined} ->
			    %% Peer's kexinit first: negotiation happens when
			    %% we send ours.
			    S#s{alg_neg = {undefined,PeerMsg}};

			{undefined,_} ->
			    fail("2 kexint received!!", S);

			{OwnMsg, _} ->
			    %% Our kexinit was already sent: negotiate now. The
			    %% client keeps its #ssh{} unchanged here; the
			    %% kexdh_init is produced later by send/2.
			    try ssh_transport:handle_kexinit_msg(PeerMsg, OwnMsg, S#s.ssh) of
				{ok,C} when ?role(S) == server ->
				    S#s{alg_neg = {OwnMsg, PeerMsg},
					alg = C#ssh.algorithms,
					ssh = C};
				{ok,_NextKexMsgBin,C} when ?role(S) == client ->
				    S#s{alg_neg = {OwnMsg, PeerMsg},
					alg = C#ssh.algorithms}
			    catch
				Class:Exc ->
				    save_prints({"Algoritm negotiation failed at line ~p:~p~n~p:~s~nPeer: ~s~n Own: ~s~n",
						 [?MODULE,?LINE,Class,format_msg(Exc),format_msg(PeerMsg),format_msg(OwnMsg)]},
						S#s{alg_neg = {OwnMsg, PeerMsg}})
			    end
		    end;

		#ssh_msg_kexdh_init{} -> % Always the server
		    {ok, Reply, C} = ssh_transport:handle_kexdh_init(PeerMsg, S#s.ssh),
		    S#s{ssh = C,
			reply = [{ssh_msg_kexdh_reply,Reply} | S#s.reply]
		       };
		#ssh_msg_kexdh_reply{} ->
		    {ok, _NewKeys, C} = ssh_transport:handle_kexdh_reply(PeerMsg, S#s.ssh),
                    S#s{ssh = (S#s.ssh)#ssh{shared_secret = C#ssh.shared_secret,
                                            exchanged_hash = C#ssh.exchanged_hash,
                                            session_id = C#ssh.session_id}};
		    %%%S#s{ssh=C#ssh{send_sequence=S#s.ssh#ssh.send_sequence}}; % Back the number
		#ssh_msg_newkeys{} ->
		    {ok, C} = ssh_transport:handle_new_keys(PeerMsg, S#s.ssh),
		    S#s{ssh=C};
		_ ->
		    S
	    end
    end.
478
479%%%================================================================
%% Look for a complete newline-terminated line in the received bytes.
%% Seen is the start of a line consumed by an earlier call.
%%
%% The returned state's return_value is one of:
%%   {more,Partial} - no complete line yet; all buffered bytes are moved
%%                    into seen_hello = {more,Partial}
%%   {info,Line}    - a complete line that is not a version line
%%   {hello,Line}   - the peer's version line; seen_hello becomes true
try_find_crlf(Seen, S0) ->
    Buffered = S0#s.encrypted_data_buffer,
    case erlang:decode_packet(line, Buffered, []) of
        {more, _} ->
            %% No complete line: every buffered byte belongs to the
            %% partial line, so nothing is left to scan.
            Partial = <<Seen/binary, Buffered/binary>>,
            S0#s{seen_hello = {more, Partial},
                 encrypted_data_buffer = <<>>,
                 return_value = {more, Partial}};
        {ok, LinePart, Rest} ->
            Line = <<Seen/binary, LinePart/binary>>,
            case handle_hello(Line, S0) of
                false ->
                    Traced = opt(print_messages, S0,
                                 fun(Lvl) when Lvl =:= true; Lvl =:= detail ->
                                         {"Recv info~n~p~n", [Line]}
                                 end),
                    Traced#s{seen_hello = false,
                             encrypted_data_buffer = Rest,
                             return_value = {info, Line}};
                S1 = #s{} ->
                    Traced = opt(print_messages, S1,
                                 fun(Lvl) when Lvl =:= true; Lvl =:= detail ->
                                         {"Recv hello~n~p~n", [Line]}
                                 end),
                    Traced#s{seen_hello = true,
                             encrypted_data_buffer = Rest,
                             return_value = {hello, Line}}
            end
    end.
506
507
%% Try to parse Bin as the peer's SSH version line.
%% Returns false when ssh_transport:handle_hello_version/1 does not
%% recognise it. Otherwise returns S with the peer's version stored: the
%% server fields when we are the client, the client fields when we are
%% the server.
handle_hello(Bin, S = #s{ssh = Ssh0}) ->
    Parsed = ssh_transport:handle_hello_version(binary_to_list(Bin)),
    case {Parsed, ?role(S)} of
        {{undefined, _}, _} ->
            false;
        {{ProtoVsn, VsnStr}, client} ->
            S#s{ssh = Ssh0#ssh{s_vsn = ProtoVsn, s_version = VsnStr}};
        {{ProtoVsn, VsnStr}, server} ->
            S#s{ssh = Ssh0#ssh{c_vsn = ProtoVsn, c_version = VsnStr}}
    end.
516
%% Receive and decode one binary SSH message.
%%
%% Keeps calling ssh_transport:handle_packet_part/5, waiting for more
%% socket data while it answers get_more. The decoded message becomes
%% return_value, and the receive sequence number is advanced once.
%% A decode error fails the op with decode_failed.
receive_binary_msg(S0=#s{}) ->
     case ssh_transport:handle_packet_part(
           S0#s.decrypted_data_buffer,
           S0#s.encrypted_data_buffer,
           S0#s.aead_data,
           S0#s.undecrypted_packet_length,
           S0#s.ssh)
     of
         {packet_decrypted, DecryptedBytes, EncryptedDataRest, Ssh1} ->
             %% S1 (with an advanced seqnum) is only used for decoding and for
             %% the decode_failed report. The returned state is built from
             %% Ssh1 and advanced by inc_recv_seq_num/1 below, so the net
             %% effect is a single increment.
             S1 = S0#s{ssh = Ssh1#ssh{recv_sequence = ssh_transport:next_seqnum(Ssh1#ssh.recv_sequence)},
                       decrypted_data_buffer = <<>>,
                       undecrypted_packet_length = undefined,
                       aead_data = <<>>,
                       encrypted_data_buffer = EncryptedDataRest},
             case
                 catch ssh_message:decode(set_prefix_if_trouble(DecryptedBytes,S1))
             of
                 {'EXIT',_} -> fail(decode_failed,S1);

                 Msg ->
                     %% A received kexinit is passed to key_init/3 together with
                     %% its raw bytes, for the peer's role.
                     Ssh2 = case Msg of
                              #ssh_msg_kexinit{} ->
                                  ssh_transport:key_init(opposite_role(Ssh1), Ssh1, DecryptedBytes);
                              _ ->
                                  Ssh1
			 end,
                     S2 = opt(print_messages, S1,
                              fun(X) when X==true;X==detail -> {"Recv~n~s~n",[format_msg(Msg)]} end),
                     S3 = opt(print_messages, S2,
                              fun(detail) -> {"decrypted bytes ~p~n",[DecryptedBytes]} end),
                     S3#s{ssh = inc_recv_seq_num(Ssh2),
                          return_value = Msg
                         }
             end;

         {get_more, DecryptedBytes, EncryptedDataRest, AeadData, TotalNeeded, Ssh1} ->
             %% Here we know that there are not enough bytes in
             %% EncryptedDataRest to use. We must wait for more.
             %% NOTE(review): 8 appears to be a minimum read while the total
             %% packet length is still unknown - confirm.
             Remaining = case TotalNeeded of
                             undefined -> 8;
                             _ -> TotalNeeded - size(DecryptedBytes) - size(EncryptedDataRest)
                         end,
             receive_binary_msg(
               receive_wait(Remaining,
                            S0#s{encrypted_data_buffer = EncryptedDataRest,
                                 decrypted_data_buffer = DecryptedBytes,
                                 undecrypted_packet_length = TotalNeeded,
                                 aead_data = AeadData,
                                 ssh = Ssh1}
                           ))
     end.
568
569
570
%% SSH message numbers 30 and 31 are key-exchange-method specific, so the
%% same number has different layouts in different kex families. For
%% those messages, prepend a tag ("ecdh", "dh_gex" or "dh") derived from
%% the negotiated kex algorithm before the bytes go to
%% ssh_message:decode/1. Any other message, or a state with no negotiated
%% #alg{}, is returned unchanged.
set_prefix_if_trouble(Msg = <<?BYTE(Op),_/binary>>, #s{alg=#alg{kex=Kex}})
  when Op == 30;
       Op == 31
       ->
    case catch atom_to_list(Kex) of
        "ecdh-sha2-" ++ _ ->
            <<"ecdh",Msg/binary>>;
        %% curve25519/curve448 key exchange reuses the ECDH message format
        %% (RFC 8731) although its name does not start with "ecdh-sha2-".
        "curve25519-" ++ _ ->
            <<"ecdh",Msg/binary>>;
        "curve448-" ++ _ ->
            <<"ecdh",Msg/binary>>;
        %% Must be tested before the shorter "diffie-hellman-group" prefix.
        "diffie-hellman-group-exchange-" ++ _ ->
            <<"dh_gex",Msg/binary>>;
        "diffie-hellman-group" ++ _ ->
            <<"dh",Msg/binary>>;
        _ ->
            Msg
    end;
set_prefix_if_trouble(Msg, _) ->
    Msg.
587
588
%% Without waiting, append every TCP packet that is already available to
%% the encrypted_data_buffer. A closed or failed socket is reported by
%% throwing {tcp,tcp_closed} or {tcp,{tcp_error,Reason}}.
receive_poll(S = #s{socket = Sock}) ->
    inet:setopts(Sock, [{active, once}]),
    receive
        {tcp, Sock, Data} ->
            Buffered = S#s.encrypted_data_buffer,
            receive_poll(S#s{encrypted_data_buffer = <<Buffered/binary, Data/binary>>});
        {tcp_closed, Sock} ->
            throw({tcp, tcp_closed});
        {tcp_error, Sock, Reason} ->
            throw({tcp, {tcp_error, Reason}})
    after 0 ->
        S
    end.
601
%% Block until one TCP packet arrives and append it to the
%% encrypted_data_buffer. If nothing arrives within S#s.timeout ms, the
%% op fails with receive_timeout. Socket close/error are thrown as
%% {tcp,...}.
receive_wait(S = #s{socket = Sock, timeout = Timeout}) ->
    inet:setopts(Sock, [{active, once}]),
    receive
        {tcp, Sock, Data} ->
            Buffered = S#s.encrypted_data_buffer,
            S#s{encrypted_data_buffer = <<Buffered/binary, Data/binary>>};
        {tcp_closed, Sock} ->
            throw({tcp, tcp_closed});
        {tcp_error, Sock, Reason} ->
            throw({tcp, {tcp_error, Reason}})
    after Timeout ->
        fail(receive_timeout, S)
    end.
615
%% While N > 0, wait for packets one at a time, appending each to the
%% encrypted_data_buffer and reducing N by its size. Each wait may take
%% at most S#s.timeout ms before the op fails with receive_timeout.
%% Otherwise S is returned unchanged.
receive_wait(N, S = #s{socket = Sock,
                       timeout = Timeout,
                       encrypted_data_buffer = Enc0}) when N > 0 ->
    inet:setopts(Sock, [{active, once}]),
    receive
        {tcp, Sock, Data} ->
            StillNeeded = N - byte_size(Data),
            receive_wait(StillNeeded, S#s{encrypted_data_buffer = <<Enc0/binary, Data/binary>>});
        {tcp_closed, Sock} ->
            throw({tcp, tcp_closed});
        {tcp_error, Sock, Reason} ->
            throw({tcp, {tcp_error, Reason}})
    after Timeout ->
        fail(receive_timeout, S)
    end;
receive_wait(_N, S) ->
    S.
632
633%% random_padding_len(PaddingLen1, ChunkSize) ->
%%     MaxAdditionalRandomPaddingLen = 		% max 255 bytes padding total
635%% 	(255 - PaddingLen1) - ((255 - PaddingLen1) rem ChunkSize),
636%%     AddLen0 = crypto:rand_uniform(0,MaxAdditionalRandomPaddingLen),
637%%     AddLen0 - (AddLen0 rem ChunkSize).		% preserve the blocking
638
%% Advance the receive sequence number, wrapping around at 2^32.
inc_recv_seq_num(Ssh = #ssh{recv_sequence = Seq}) ->
    Ssh#ssh{recv_sequence = (Seq + 1) band 16#ffffffff}.
640%%%inc_send_seq_num(C=#ssh{send_sequence=N}) -> C#ssh{send_sequence=(N+1) band 16#ffffffff}.
641
%% The other side's role, given a role atom or an #ssh{} record.
opposite_role(#ssh{role = Role}) -> opposite_role(Role);
opposite_role(server)            -> client;
opposite_role(client)            -> server.
645
%% Unwrap an ok-result: ok stays ok, {ok,Value} gives Value, and
%% {error,Reason} raises error(Reason).
ok(ok) ->
    ok;
ok({ok, Value}) ->
    Value;
ok({error, Reason}) ->
    erlang:error(Reason).
649
650
651%%%================================================================
652%%%
653%%% Formating of records
654%%%
655
%% Pretty-print M as chardata, starting at column 0.
format_msg(Msg) ->
    format_msg(Msg, 0).
657
%% Format M as chardata. Records known to fields/1 are printed one field
%% per line, aligned after the '#name{' head; nested records are indented
%% further. Anything else is printed with ~p. Indent is the column at
%% which M's text starts.
format_msg(M, Indent) ->
    case fields(M) of
        undefined ->
            io_lib:format('~p', [M]);
        FieldNames ->
            [RecName | Values] = tuple_to_list(M),
            Head = io_lib:format('#~p{', [RecName]),
            FieldCol = Indent + lists:flatlength(Head),
            NL = io_lib:format('~n~*c', [FieldCol, $\s]),
            Sep = io_lib:format(',~n~*c', [FieldCol, $\s]),
            FormatField =
                fun({FieldName, Value}) ->
                        Label = io_lib:format('~p = ', [FieldName]),
                        ValueCol = FieldCol + lists:flatlength(Label),
                        [Label, format_msg(Value, ValueCol)]
                end,
            Tail = lists:map(FormatField, lists:zip(FieldNames, Values)),
            [[Head | string:join(Tail, Sep)], NL, "}"]
    end.
675
%% Field names of the records format_msg/2 knows how to print, or
%% undefined for any other term.
fields(#ssh_msg_debug{})                     -> record_info(fields, ssh_msg_debug);
fields(#ssh_msg_disconnect{})                -> record_info(fields, ssh_msg_disconnect);
fields(#ssh_msg_ignore{})                    -> record_info(fields, ssh_msg_ignore);
fields(#ssh_msg_kex_dh_gex_group{})          -> record_info(fields, ssh_msg_kex_dh_gex_group);
fields(#ssh_msg_kex_dh_gex_init{})           -> record_info(fields, ssh_msg_kex_dh_gex_init);
fields(#ssh_msg_kex_dh_gex_reply{})          -> record_info(fields, ssh_msg_kex_dh_gex_reply);
fields(#ssh_msg_kex_dh_gex_request{})        -> record_info(fields, ssh_msg_kex_dh_gex_request);
fields(#ssh_msg_kex_dh_gex_request_old{})    -> record_info(fields, ssh_msg_kex_dh_gex_request_old);
fields(#ssh_msg_kexdh_init{})                -> record_info(fields, ssh_msg_kexdh_init);
fields(#ssh_msg_kexdh_reply{})               -> record_info(fields, ssh_msg_kexdh_reply);
fields(#ssh_msg_kexinit{})                   -> record_info(fields, ssh_msg_kexinit);
fields(#ssh_msg_newkeys{})                   -> record_info(fields, ssh_msg_newkeys);
fields(#ssh_msg_service_accept{})            -> record_info(fields, ssh_msg_service_accept);
fields(#ssh_msg_service_request{})           -> record_info(fields, ssh_msg_service_request);
fields(#ssh_msg_unimplemented{})             -> record_info(fields, ssh_msg_unimplemented);
fields(#ssh_msg_userauth_request{})          -> record_info(fields, ssh_msg_userauth_request);
fields(#ssh_msg_userauth_failure{})          -> record_info(fields, ssh_msg_userauth_failure);
fields(#ssh_msg_userauth_success{})          -> record_info(fields, ssh_msg_userauth_success);
fields(#ssh_msg_userauth_banner{})           -> record_info(fields, ssh_msg_userauth_banner);
fields(#ssh_msg_userauth_passwd_changereq{}) -> record_info(fields, ssh_msg_userauth_passwd_changereq);
fields(#ssh_msg_userauth_pk_ok{})            -> record_info(fields, ssh_msg_userauth_pk_ok);
fields(#ssh_msg_userauth_info_request{})     -> record_info(fields, ssh_msg_userauth_info_request);
fields(#ssh_msg_userauth_info_response{})    -> record_info(fields, ssh_msg_userauth_info_response);
fields(#s{})                                 -> record_info(fields, s);
fields(#ssh{})                               -> record_info(fields, ssh);
fields(#alg{})                               -> record_info(fields, alg);
fields(_)                                    -> undefined.
706
707%%%================================================================
708%%%
709%%% Trace handling
710%%%
711
%% Start a fresh trace for Op: clear previous prints and, when print_ops
%% is true, record a "-- [Role] Op" header line.
init_op_traces(Op, S0) ->
    Header =
        fun(true) ->
                case ?role(S0) of
                    undefined -> {"-- ~p~n", [Op]};
                    Role -> {"-- ~p ~p~n", [Role, Op]}
                end
        end,
    opt(print_ops, S0#s{prints = []}, Header).
721
%% Print the collected trace for a failed op, adding "Class Term" at the
%% end when print_ops is true.
report_trace(Class, Term, S) ->
    Traced = opt(print_ops, S, fun(true) -> {"~s ~p", [Class, Term]} end),
    print_traces(Traced).
727
%% When the print_seqnums option is true, record the SSH send and
%% receive sequence numbers of the state produced by an op.
%%
%% Each guard previously compared a sequence number with itself
%% (X =/= X), so no clause could ever fire and the option was a silent
%% no-op. Only the post-op state is available here, so the current
%% values are traced.
seqnum_trace(S) ->
    opt(print_seqnums, S,
        fun(true) ->
                #ssh{send_sequence = Send, recv_sequence = Recv} = S#s.ssh,
                {"~p seq num: send ~p,  recv ~p~n", [?role(S), Send, Recv]}
        end).
746
%% Log the collected trace entries with ct:log, oldest first. Every entry
%% after the first is preceded by its sequence number and a separator.
%% Returns S when there is nothing to print.
print_traces(S) when S#s.prints == [] -> S;
print_traces(S) ->
    Chronological = lists:reverse(S#s.prints),
    Numbered = lists:zip(lists:seq(0, length(Chronological) - 1), Chronological),
    Lines = [case Idx of
                 0 ->
                     io_lib:format(Fmt, Args);
                 _ ->
                     io_lib:format(lists:concat(['~p --------~n', Fmt]), [Idx | Args])
             end
             || {Idx, {Fmt, Args}} <- Numbered],
    ct:log("~s", [Lines]).
762
%% Call Fun with the value of option Flag. If it returns {Fmt,Args}, that
%% entry is added to the trace. If Fun raises anything (typically
%% function_clause for an option value it does not handle), S is
%% returned unchanged.
opt(Flag, S, Fun) when is_function(Fun, 1) ->
    try
        Fun(proplists:get_value(Flag, S#s.opts))
    of
        {Fmt, Args} = Entry when is_list(Fmt), is_list(Args) ->
            save_prints(Entry, S)
    catch
        _:_ ->
            S
    end.
770
%% Push a {Fmt,Args} trace entry onto S's print list (newest first).
save_prints({Fmt, Args}, S) ->
    Earlier = S#s.prints,
    S#s{prints = [{Fmt, Args} | Earlier]}.
773