1%%
2%% %CopyrightBegin%
3%%
4%% Copyright Ericsson AB 2019. All Rights Reserved.
5%%
6%% Licensed under the Apache License, Version 2.0 (the "License");
7%% you may not use this file except in compliance with the License.
8%% You may obtain a copy of the License at
9%%
10%%     http://www.apache.org/licenses/LICENSE-2.0
11%%
12%% Unless required by applicable law or agreed to in writing, software
13%% distributed under the License is distributed on an "AS IS" BASIS,
14%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15%% See the License for the specific language governing permissions and
16%% limitations under the License.
17%%
18%% %CopyrightEnd%
19%%
20%% -------------------------------------------------------------------------
21%%
22%% Module for encrypted Erlang protocol - a minimal encrypted
23%% distribution protocol based on only a shared secret
24%% and the crypto application
25%%
26-module(inet_crypto_dist).
27-define(DIST_NAME, inet_crypto).
28-define(DIST_PROTO, crypto).
29-define(DRIVER, inet_tcp).
30-define(FAMILY, inet).
31
32-export([listen/1, accept/1, accept_connection/5,
33	 setup/5, close/1, select/1, is_node_name/1]).
34
35%% Generalized dist API, for sibling IPv6 module inet6_crypto_dist
36-export([gen_listen/2, gen_accept/2, gen_accept_connection/6,
37	 gen_setup/6, gen_close/2, gen_select/2]).
38
39-export([nodelay/0]).
40
41%% Debug
42%%%-compile(export_all).
43-export([dbg/0, test_server/0, test_client/1]).
44
45-include_lib("kernel/include/net_address.hrl").
46-include_lib("kernel/include/dist.hrl").
47-include_lib("kernel/include/dist_util.hrl").
48
49-define(PACKET_SIZE, 65536).
50-define(BUFFER_SIZE, (?PACKET_SIZE bsl 4)).
51
52%% -------------------------------------------------------------------------
53
%% Parameters for one connection.  Only the defaults are visible here;
%% NOTE(review): 'iv', 'key' and 'tag_len' look like sizes in bytes
%% that may later be replaced by actual values - confirm against the
%% code that fills in the record.
-record(params,
        {socket,
         dist_handle,
         hmac_algorithm = sha256,
         aead_cipher = aes_gcm,
         rekey_key,
         iv = 12,
         key = 16,
         tag_len = 16,
         rekey_count = 1 bsl 18, % 262144
         rekey_time = 2 * 60 * 60 * 1000, % 2 hours
         rekey_msg
        }).

%% Create a #params{} for Socket with every other field at its default
params(Socket) ->
    Defaults = #params{},
    Defaults#params{socket = Socket}.
70
71
%% A public/private key pair.  The defaults of 'type' and 'params'
%% are the arguments generate_key_pair/0 gives to crypto:generate_key/2,
%% 'life_time' is used as a timer value (milliseconds) and 'life_count'
%% as a countdown by the key pair server.
-record(key_pair,
        {type = ecdh,
         %% The curve choice greatly affects setup time.
         %% An Edwards curve would be preferred, but that would
         %% require a very new openssl version.
         %% Twisted brainpool curves (*t1) are faster than
         %% non-twisted (*r1), and 256 is much faster than 384
         %% (so brainpoolP384t1 is the slower alternative here).
         params = brainpoolP256t1,
         public,
         private,
         life_time = 60 * 60 * 1000, % 1 hour
         life_count = 256 % Number of connection setups
        }).
87
88
89%% -------------------------------------------------------------------------
90%% Keep the node's public/private key pair in the process state
91%% of a key pair server linked to the acceptor process.
92%% Create the key pair the first time it is needed
93%% so crypto gets time to start first.
94%%
95
%% Spawn the key pair server linked to the calling process,
%% have it register itself under the module name, and hand
%% the new pid to monitor_dist_proc/1.
start_key_pair_server() ->
    ServerFun =
        fun () ->
                register(?MODULE, self()),
                key_pair_server()
        end,
    monitor_dist_proc(spawn_link(ServerFun)).
103
%% Enter the server loop with no key pair, no timer and no counter
key_pair_server() ->
    key_pair_server(undefined, undefined, undefined).
%%
%% Enter the server loop with a fresh key pair.  A discard timer
%% is started unless the pair is good for one use only.
key_pair_server(#key_pair{} = KeyPair) ->
    #key_pair{life_time = LifeTime, life_count = LifeCount} = KeyPair,
    %% Presuming: 1 < LifeCount
    Timer =
        if
            LifeCount =:= 1 ->
                undefined;
            true ->
                erlang:start_timer(LifeTime, self(), discard)
        end,
    key_pair_server(KeyPair, Timer, LifeCount - 1).
118%%
%% Server loop: KeyPair is 'undefined' or the current #key_pair{},
%% Timer is 'undefined' or the discard timer reference, and Count
%% is the number of times the current pair may still be handed out.
key_pair_server(_KeyPair, Timer, 0) ->
    %% The pair is used up; forget it and start over
    cancel_timer(Timer),
    key_pair_server();
key_pair_server(KeyPair, Timer, Count) ->
    receive
        {From, Tag, get_key_pair} ->
            case KeyPair of
                undefined ->
                    %% Created on first use
                    NewKeyPair = generate_key_pair(),
                    From ! {Tag, NewKeyPair},
                    key_pair_server(NewKeyPair);
                #key_pair{} ->
                    From ! {Tag, KeyPair},
                    key_pair_server(KeyPair, Timer, Count - 1)
            end;
        {From, Tag, get_new_key_pair} ->
            %% Forced replacement of whatever pair is held
            cancel_timer(Timer),
            NewKeyPair = generate_key_pair(),
            From ! {Tag, NewKeyPair},
            key_pair_server(NewKeyPair);
        {timeout, Timer, discard} when is_reference(Timer) ->
            %% Life time expired
            key_pair_server()
    end.
142
%% Generate a key pair using the 'type' and 'params' defaults of
%% #key_pair{}, and return it in a record that has all other
%% fields at their defaults.
generate_key_pair() ->
    Defaults = #key_pair{},
    {Public, Private} =
        crypto:generate_key(
          Defaults#key_pair.type, Defaults#key_pair.params),
    Defaults#key_pair{public = Public, private = Private}.
148
149
%% Cancel Timer unless there is none ('undefined'); returns 'ok'
cancel_timer(undefined) -> ok;
cancel_timer(TimerRef) -> erlang_cancel_timer(TimerRef).
154
%% Start a timer towards the calling process and return the very
%% message that will arrive when it fires; the message doubles
%% as the handle for cancel_rekey_timer/1.
start_rekey_timer(Time) ->
    {timeout, erlang:start_timer(Time, self(), rekey_time), rekey_time}.
158
%% Cancel a timer given the message returned by start_rekey_timer/1
cancel_rekey_timer({timeout, TimerRef, rekey_time}) ->
    erlang_cancel_timer(TimerRef).
161
%% Cancel the timer and make sure no timeout message from it is
%% left behind: when the timer could not be found, wait for its
%% timeout message and drop it.  Always returns 'ok'.
erlang_cancel_timer(TimerRef) ->
    Result = erlang:cancel_timer(TimerRef),
    if
        Result =:= false ->
            receive
                {timeout, TimerRef, _} -> ok
            end;
        true ->
            %% Result is the remaining time; the timer never fired
            ok
    end.
171
%% Ask the key pair server for the current key pair
get_key_pair() ->
    call_key_pair_server(get_key_pair).

%% Ask the key pair server to generate and return a new key pair
get_new_key_pair() ->
    call_key_pair_server(get_new_key_pair).

%% Send Request to the process registered under the module name and
%% wait for the reply.  The server is monitored during the call; if
%% it goes down the call fails with error(Reason).
call_key_pair_server(Request) ->
    Server = whereis(?MODULE),
    MRef = erlang:monitor(process, Server),
    Server ! {self(), MRef, Request},
    receive
        {MRef, Reply} ->
            erlang:demonitor(MRef, [flush]),
            Reply;
        {'DOWN', MRef, process, Server, Reason} ->
            error(Reason)
    end.
189
%% Combine the private key of our own KeyPair with the other
%% side's public key PubKey through crypto:compute_key/4.
compute_shared_secret(#key_pair{} = KeyPair, PubKey) ->
    #key_pair{type = Type, params = Params, private = PrivKey} = KeyPair,
    crypto:compute_key(Type, PubKey, PrivKey, Params).
197
198%% -------------------------------------------------------------------------
199%% Erlang distribution plugin structure explained to myself
200%% -------
201%% These are the processes involved in the distribution:
202%% * net_kernel
203%% * The Acceptor
204%% * The Controller | Handshaker | Ticker
205%% * The DistCtrl process that may be split into:
206%%   + The Output controller
207%%   + The Input controller
208%%   For the regular inet_tcp_dist distribution module, DistCtrl
209%%   is not one or two processes, but one port - a gen_tcp socket
210%%
211%% When the VM is started with the argument "-proto_dist inet_crypto"
%% net_kernel registers the module inet_crypto_dist as distribution
213%% module.  net_kernel calls listen/1 to create a listen socket
214%% and then accept/1 with the listen socket as argument to spawn
215%% the Acceptor process, which is linked to net_kernel.  Apparently
216%% the listen socket is owned by net_kernel - I wonder if it could
217%% be owned by the Acceptor process instead...
218%%
219%% The Acceptor process calls blocking accept on the listen socket
220%% and when an incoming socket is returned it spawns the DistCtrl
%% process as linked to the Acceptor.  The ownership of the accepted
222%% socket is transferred to the DistCtrl process.
223%% A message is sent to net_kernel to inform it that an incoming
224%% connection has appeared and the Acceptor awaits a reply from net_kernel.
225%%
226%% net_kernel then calls accept_connection/5 to spawn the Controller |
227%% Handshaker | Ticker process that is linked to net_kernel.
228%% The Controller then awaits a message from the Acceptor process.
229%%
230%% When net_kernel has spawned the Controller it replies with a message
%% to the Acceptor that then calls DistCtrl to change its links
232%% so DistCtrl ends up linked to the Controller and not to the Acceptor.
233%% The Acceptor then sends a message to the Controller.  The Controller
%% then changes role into the Handshaker, creates a #hs_data{} record
235%% and calls dist_util:handshake_other_started/1.  After this
236%% the Acceptor goes back into a blocking accept on the listen socket.
237%%
238%% For the regular distribution inet_tcp_dist DistCtrl is a gen_tcp socket
239%% and when it is a process it also acts as a socket.  The #hs_data{}
240%% record used by dist_util presents a set of funs that are used
241%% by dist_util to perform the distribution handshake.  These funs
242%% make sure to transfer the handshake messages through the DistCtrl
243%% "socket".
244%%
245%% When the handshake is finished a fun for this purpose in #hs_data{}
246%% is called, which tells DistCtrl that it does not need to be prepared
247%% for any more #hs_data{} handshake calls.  The DistCtrl process in this
248%% module then spawns the Input controller process that gets ownership
249%% of the connection's gen_tcp socket and changes into {active, N} mode
250%% so now it gets all incoming traffic and delivers that to the VM.
251%% The original DistCtrl process changes role into the Output controller
252%% process and starts asking the VM for outbound messages and transfers
253%% them on the connection socket.
254%%
255%% The Handshaker now changes into the Ticker role, and uses only two
256%% functions in the #hs_data{} record; one to get socket statistics
257%% and one to send a tick.  None of these may block for any reason
258%% in particular not for a congested socket since that would destroy
259%% connection supervision.
260%%
261%%
%% For an outgoing connection net_kernel calls setup/5 which spawns the
%% Controller process as linked to net_kernel.  This Controller process
%% connects to the other node's listen socket and when that is successful
265%% spawns the DistCtrl process as linked to the controller and transfers
266%% socket ownership to it.
267%%
268%% Then the Controller creates the #hs_data{} record and calls
269%% dist_util:handshake_we_started/1 which changes the process role
270%% into Handshaker.
271%%
272%% When the distribution handshake is finished the procedure is just
273%% as for an incoming connection above.
274%%
275%%
276%% To sum it up.
277%%
278%% There is an Acceptor process that is linked to net_kernel and
279%% informs it when new connections arrive.
280%%
281%% net_kernel spawns Controllers for incoming and for outgoing connections.
%% These Controllers use the DistCtrl processes to do distribution
%% handshake and after that become Tickers that supervise the connection.
284%%
285%% The Controller | Handshaker | Ticker is linked to net_kernel, and to
286%% DistCtrl, one or both.  If any of these connection processes would die
287%% all others should be killed by the links.  Therefore none of them may
288%% terminate with reason 'normal'.
289%% -------------------------------------------------------------------------
290
%% Socket options used for listen, connect and the debug sockets:
%% binary passive sockets with a 16 bit packet header, no delay,
%% and all three buffer sizes set to ?BUFFER_SIZE.
-compile({inline, [socket_options/0]}).
socket_options() ->
    BufferOpts = [{Opt, ?BUFFER_SIZE} || Opt <- [sndbuf, recbuf, buffer]],
    [binary, {active, false}, {packet, 2}, {nodelay, true} | BufferOpts].
296
297%% -------------------------------------------------------------------------
298%% select/1 is called by net_kernel to ask if this distribution protocol
299%% is willing to handle Node
300%%
301
%% select/1 for the default driver; see gen_select/2
select(Node) ->
    gen_select(Node, ?DRIVER).
304
%% Return 'true' when Node has a host part that Driver:getaddr/1
%% can resolve, otherwise 'false'.
gen_select(Node, Driver) ->
    case dist_util:split_node(Node) of
        {node, _Name, Host} ->
            case Driver:getaddr(Host) of
                {ok, _Address} ->
                    true;
                _ ->
                    false
            end;
        _NoHostPart ->
            false
    end.
315
316%% -------------------------------------------------------------------------
317
%% Node name check, delegated unchanged to dist_util:is_node_name/1
is_node_name(Node) ->
    dist_util:is_node_name(Node).
320
321%% -------------------------------------------------------------------------
322%% Called by net_kernel to create a listen socket for this
323%% distribution protocol.  This listen socket is used by
324%% the Acceptor process.
325%%
326
%% listen/1 for the default driver; see gen_listen/2
listen(Name) ->
    gen_listen(Name, ?DRIVER).

%% Let inet_tcp_dist create the listen socket for the local host
%% name, then apply this module's socket options to it and mark
%% the returned address with this module's protocol.
%% Anything but a successful result is returned as is.
gen_listen(Name, Driver) ->
    {ok, Host} = inet:gethostname(),
    case inet_tcp_dist:gen_listen(Driver, Name, Host) of
        {ok, {Socket, Address, Creation}} ->
            inet:setopts(Socket, socket_options()),
            CryptoAddress = Address#net_address{protocol = ?DIST_PROTO},
            {ok, {Socket, CryptoAddress, Creation}};
        Other ->
            Other
    end.
340
341%% -------------------------------------------------------------------------
342%% Called by net_kernel to spawn the Acceptor process that awaits
343%% new connection in a blocking accept and informs net_kernel
344%% when a new connection has appeared, and starts the DistCtrl
345%% "socket" process for the connection.
346%%
347
%% accept/1 for the default driver; see gen_accept/2
accept(Listen) ->
    gen_accept(Listen, ?DRIVER).

%% Spawn the Acceptor process, linked to the caller and running
%% at max priority.  It first starts the key pair server and then
%% enters accept_loop/3 with the caller as NetKernel.
gen_accept(Listen, Driver) ->
    NetKernel = self(),
    Acceptor =
        fun () ->
                start_key_pair_server(),
                accept_loop(Listen, Driver, NetKernel)
        end,
    monitor_dist_proc(spawn_opt(Acceptor, [link, {priority, max}])).
363
%% The Acceptor loop.  For every accepted socket: start a DistCtrl
%% process for it, tell NetKernel about the connection and wait
%% for the answer.  When NetKernel names a Controller, DistCtrl is
%% told about it and the Controller gets the socket; when NetKernel
%% says 'unsupported_protocol' the Acceptor exits.
%% A failing accept also makes the Acceptor exit.
accept_loop(Listen, Driver, NetKernel) ->
    case Driver:accept(trace(Listen)) of
        {ok, Socket} ->
            wait_for_code_server(),
            Timeout = net_kernel:connecttime(),
            %% DistCtrl is a "socket"
            DistCtrl = start_dist_ctrl(trace(Socket), Timeout),
            Acceptor = self(),
            AcceptMsg =
                trace(
                  {accept, Acceptor, DistCtrl, Driver:family(), ?DIST_PROTO}),
            NetKernel ! AcceptMsg,
            receive
                {NetKernel, controller, Controller} ->
                    call_dist_ctrl(
                      DistCtrl, {controller, Controller, Acceptor}),
                    Controller ! {Acceptor, controller, Socket};
                {NetKernel, unsupported_protocol} ->
                    exit(unsupported_protocol)
            end,
            accept_loop(Listen, Driver, NetKernel);
        AcceptError ->
            exit({accept, AcceptError})
    end.
385
%% Block until a process is registered as 'code_server',
%% polling every 10 ms.  Returns 'ok'.
wait_for_code_server() ->
    %% This is an ugly hack.  Encryption on a connection cannot
    %% start until the crypto module is loaded.  Loading crypto
    %% runs its on_load function, which uses code:priv_dir/1 to
    %% locate the NIF library.  Distribution is started before
    %% the code server, so if an incoming connection arrives on
    %% the distribution port early enough the code server may not
    %% be running yet, and code:priv_dir/1 might fail.
    %%
    %% A module whose on_load function fails is unloaded, and the
    %% call that triggered the loading fails with 'undef',
    %% which is rather confusing.
    %%
    %% Waiting for the code server to start avoids all that.
    %%
    case whereis(code_server) of
        Pid when is_pid(Pid) ->
            ok;
        undefined ->
            timer:sleep(10),
            wait_for_code_server()
    end.
409
410%% -------------------------------------------------------------------------
411%% Called by net_kernel when a new connection has appeared, to spawn
412%% a Controller process that performs the handshake with the new node,
413%% and then becomes the Ticker connection supervisor.
414%% -------------------------------------------------------------------------
415
%% accept_connection/5 for the default driver;
%% see gen_accept_connection/6
accept_connection(Acceptor, DistCtrl, MyNode, Allowed, SetupTime) ->
    gen_accept_connection(
      Acceptor, DistCtrl, MyNode, Allowed, SetupTime, ?DRIVER).

%% Spawn the Controller/handshaker/ticker process for an incoming
%% connection, linked to the caller and running at max priority.
%% The process runs do_accept/7 with the caller as NetKernel.
gen_accept_connection(
  Acceptor, DistCtrl, MyNode, Allowed, SetupTime, Driver) ->
    NetKernel = self(),
    Controller =
        fun () ->
                do_accept(
                  Acceptor, DistCtrl,
                  trace(MyNode), Allowed, SetupTime, Driver, NetKernel)
        end,
    monitor_dist_proc(spawn_opt(Controller, [link, {priority, max}])).
434
%% Body of the Controller for an incoming connection: wait for the
%% Acceptor to hand over the socket, start the setup timer, build
%% the #hs_data{} record and run the handshake as the side that
%% did not initiate the connection.
do_accept(
  Acceptor, DistCtrl, MyNode, Allowed, SetupTime, Driver, NetKernel) ->
    receive
        {Acceptor, controller, Socket} ->
            Timer = dist_util:start_timer(SetupTime),
            Common =
                hs_data_common(
                  NetKernel, MyNode, DistCtrl, Timer,
                  Socket, Driver:family()),
            HSData =
                Common#hs_data{
                  this_node = MyNode,
                  this_flags = 0,
                  allowed = Allowed},
            dist_util:handshake_other_started(trace(HSData))
    end.
452
453%% -------------------------------------------------------------------------
454%% Called by net_kernel to spawn a Controller process that sets up
455%% a new connection to another Erlang node, performs the handshake
%% with the other node, and then becomes the Ticker process
457%% that supervises the connection.
458%% -------------------------------------------------------------------------
459
%% setup/5 for the default driver; see gen_setup/6
setup(Node, Type, MyNode, LongOrShortNames, SetupTime) ->
    gen_setup(Node, Type, MyNode, LongOrShortNames, SetupTime, ?DRIVER).

%% Spawn the Controller/handshaker/ticker process for an outgoing
%% connection, linked to the caller and running at max priority.
gen_setup(Node, Type, MyNode, LongOrShortNames, SetupTime, Driver) ->
    NetKernel = self(),
    Controller =
        setup_fun(
          Node, Type, MyNode, LongOrShortNames, SetupTime, Driver, NetKernel),
    monitor_dist_proc(spawn_opt(Controller, [link, {priority, max}])).

%% Wrap the call to do_setup/7 in a fun of arity 0 for spawning
-spec setup_fun(_,_,_,_,_,_,_) -> fun(() -> no_return()).
setup_fun(
  Node, Type, MyNode, LongOrShortNames, SetupTime, Driver, NetKernel) ->
    fun () ->
            do_setup(
              trace(Node), Type, MyNode, LongOrShortNames, SetupTime,
              Driver, NetKernel)
    end.
483
%% Body of the Controller for an outgoing connection: validate the
%% node name, resolve the address of the other node (and its port,
%% through the epmd module, when the resolver returned only an
%% address) and continue in do_setup_connect/9.  Any resolve
%% failure is traced and ends the setup with ?shutdown(Node).
-spec do_setup(_,_,_,_,_,_,_) -> no_return().
do_setup(
  Node, Type, MyNode, LongOrShortNames, SetupTime, Driver, NetKernel) ->
    {Name, Address} = split_node(Driver, Node, LongOrShortNames),
    ErlEpmd = net_kernel:epmd_module(),
    {ARMod, ARFun} = get_address_resolver(ErlEpmd, Driver),
    Timer = trace(dist_util:start_timer(SetupTime)),
    case ARMod:ARFun(Name, Address, Driver:family()) of
        {ok, Ip, TcpPort, Version} ->
            %% The resolver knew the port as well
            do_setup_connect(
              Node, Type, MyNode, Timer, Driver, NetKernel,
              Ip, TcpPort, Version);
        {ok, Ip} ->
            case ErlEpmd:port_please(Name, Ip) of
                {port, TcpPort, Version} ->
                    do_setup_connect(
                      Node, Type, MyNode, Timer, Driver, NetKernel,
                      Ip, TcpPort, trace(Version));
                PortError ->
                    _ = trace(
                          {ErlEpmd, port_please, [Name, Ip], PortError}),
                    ?shutdown(Node)
            end;
        AddressError ->
            _ = trace(
                  {ARMod, ARFun, [Name, Address, Driver:family()],
                   AddressError}),
            ?shutdown(Node)
    end.
514
%% Connect to Ip:TcpPort, start a DistCtrl process for the new
%% socket, build the #hs_data{} record and run the handshake as
%% the initiating side.  A failed connect or a failed DistCtrl
%% start is traced and ends the setup with ?shutdown(Node).
-spec do_setup_connect(_,_,_,_,_,_,_,_,_) -> no_return().

do_setup_connect(
  Node, Type, MyNode, Timer, Driver, NetKernel,
  Ip, TcpPort, Version) ->
    dist_util:reset_timer(Timer),
    ConnectOpts = trace(connect_options(socket_options())),
    case Driver:connect(Ip, TcpPort, ConnectOpts) of
        {ok, Socket} ->
            %% DistCtrl is a "socket"
            DistCtrl =
                try start_dist_ctrl(Socket, net_kernel:connecttime())
                catch error : {dist_ctrl, _} = DistCtrlError ->
                        _ = trace(DistCtrlError),
                        ?shutdown(Node)
                end,
            Common =
                hs_data_common(
                  NetKernel, MyNode, DistCtrl, Timer,
                  Socket, Driver:family()),
            HSData =
                Common#hs_data{
                  other_node = Node,
                  this_flags = 0,
                  other_version = Version,
                  request_type = Type},
            dist_util:handshake_we_started(trace(HSData));
        ConnectError ->
            _ = trace(
                  {Driver, connect, [Ip, TcpPort, ConnectOpts],
                   ConnectError}),
            ?shutdown(Node)
    end.
548
549%% -------------------------------------------------------------------------
550%% close/1 is only called by net_kernel on the socket returned by listen/1.
551
%% close/1 for the default driver; see gen_close/2
close(Socket) ->
    gen_close(Socket, ?DRIVER).

%% Close Socket through Driver:close/1, after passing it to trace/1
gen_close(Socket, Driver) ->
    Driver:close(trace(Socket)).
557
558%% -------------------------------------------------------------------------
559
560
%% Build the part of the #hs_data{} record that is the same for
%% incoming and outgoing connections.
%%
%% The 'socket' field is set to DistCtrl, so the handshake process
%% (ticker) calls every fun below with DistCtrl as the S argument,
%% and every fun guards on exactly that.  The handshake funs go
%% through call_dist_ctrl/2; the statistics and option funs work
%% directly on the real Socket.  Family only ends up in the
%% #net_address{} that f_address returns.
hs_data_common(NetKernel, MyNode, DistCtrl, Timer, Socket, Family) ->
    FSend = % -> ok | {error, closed}=>?shutdown()
        fun (S, Packet) when S =:= DistCtrl ->
                try call_dist_ctrl(S, {send, Packet})
                catch error : {dist_ctrl, Reason} ->
                        _ = trace(Reason),
                        {error, closed}
                end
        end,
    FRecv = % -> {ok, List} | Other=>?shutdown()
        fun (S, 0, infinity) when S =:= DistCtrl ->
                try call_dist_ctrl(S, recv) of
                    {ok, Bin} when is_binary(Bin) ->
                        {ok, binary_to_list(Bin)};
                    Error ->
                        Error
                catch error : {dist_ctrl, Reason} ->
                        {error, trace(Reason)}
                end
        end,
    %% Nothing to set before or after nodeup
    FSetoptsNoop =
        fun (S) when S =:= DistCtrl ->
                ok
        end,
    FGetLL =
        fun (S) when S =:= DistCtrl ->
                {ok, S} %% DistCtrl is the distribution port
        end,
    FAddress = % -> #net_address{} | ?shutdown()
        fun (S, Node) when S =:= DistCtrl ->
                try call_dist_ctrl(S, peername) of
                    {ok, Address} ->
                        case dist_util:split_node(Node) of
                            {node, _, Host} ->
                                #net_address{
                                   address = Address,
                                   host = Host,
                                   protocol = ?DIST_PROTO,
                                   family = Family};
                            _ ->
                                ?shutdown(Node)
                        end;
                    Error ->
                        _ = trace(Error),
                        ?shutdown(Node)
                catch error : {dist_ctrl, Reason} ->
                        _ = trace(Reason),
                        ?shutdown(Node)
                end
        end,
    FHandshakeComplete = % -> ok | ?shutdown()
        fun (S, Node, DistHandle) when S =:= DistCtrl ->
                try call_dist_ctrl(S, {handshake_complete, DistHandle})
                catch error : {dist_ctrl, Reason} ->
                        _ = trace(Reason),
                        ?shutdown(Node)
                end
        end,
    %%
    %% mf_tick/1, mf_getstat/1, mf_setopts/2 and mf_getopts/2
    %% are called by the ticker any time after f_handshake_complete/3
    %% so they may not block the caller even for congested socket
    MFTick =
        fun (S) when S =:= DistCtrl ->
                S ! dist_tick
        end,
    MFGetstat = % -> {ok, RecvCnt, SendCnt, SendPend} | Other=>ignore_it
        fun (S) when S =:= DistCtrl ->
                case
                    inet:getstat(Socket, [recv_cnt, send_cnt, send_pend])
                of
                    {ok, Stat} ->
                        split_stat(Stat, 0, 0, 0);
                    Error ->
                        trace(Error)
                end
        end,
    MFSetopts =
        fun (S, Opts) when S =:= DistCtrl ->
                inet:setopts(Socket, setopts_filter(Opts))
        end,
    MFGetopts =
        fun (S, Opts) when S =:= DistCtrl ->
                inet:getopts(Socket, Opts)
        end,
    #hs_data{
       kernel_pid = NetKernel,
       this_node = MyNode,
       socket = DistCtrl,
       timer = Timer,
       f_send = FSend,
       f_recv = FRecv,
       f_setopts_pre_nodeup = FSetoptsNoop,
       f_setopts_post_nodeup = FSetoptsNoop,
       f_getll = FGetLL,
       f_address = FAddress,
       f_handshake_complete = FHandshakeComplete,
       mf_tick = MFTick,
       mf_getstat = MFGetstat,
       mf_setopts = MFSetopts,
       mf_getopts = MFGetopts}.
660
%% Remove the socket options that this module must stay in control
%% of from a user supplied option list; all other options are kept
%% in their original order.
%%
%% Removed are:
%% * {active, _}, {deliver, _} and {packet, _}
%% * the delivery mode, in both spellings: list | binary
%%   and {mode, _}.  The sockets are opened as 'binary', and
%%   {mode, list} would switch that off just as 'list' does,
%%   so it has to be filtered as well.
%% * the address family: inet | inet6
setopts_filter(Opts) ->
    [Opt ||
        Opt <- Opts,
        case Opt of
            {K, _} when K =:= active; K =:= deliver; K =:= packet -> false;
            {mode, _} -> false;
            K when K =:= list; K =:= binary -> false;
            K when K =:= inet; K =:= inet6 -> false;
            _ -> true
        end].
670
%% Pick recv_cnt, send_cnt and send_pend out of a statistics list.
%% R, W and P are the values to use for items that are not in the
%% list; for an item that occurs more than once the last one wins.
%% Returns {ok, RecvCnt, SendCnt, SendPend}.
split_stat(Stat, R, W, P) ->
    Update =
        fun ({recv_cnt, Recv}, {_, Sent, Pend}) -> {Recv, Sent, Pend};
            ({send_cnt, Sent}, {Recv, _, Pend}) -> {Recv, Sent, Pend};
            ({send_pend, Pend}, {Recv, Sent, _}) -> {Recv, Sent, Pend}
        end,
    {RecvCnt, SendCnt, SendPend} = lists:foldl(Update, {R, W, P}, Stat),
    {ok, RecvCnt, SendCnt, SendPend}.
679
680%% ------------------------------------------------------------
%% Determine if EPMD module supports address resolving. Default
%% is to use erl_epmd:address_please/3.
683%% ------------------------------------------------------------
%% Return {Module, address_please}, where Module is EpmdModule when
%% it exports address_please/3 and erl_epmd otherwise.
%% The driver argument is not used.
get_address_resolver(EpmdModule, _Driver) ->
    Resolver =
        case erlang:function_exported(EpmdModule, address_please, 3) of
            true -> EpmdModule;
            _ -> erl_epmd
        end,
    {Resolver, address_please}.
689
690
691%% If Node is illegal terminate the connection setup!!
split_node(Driver, Node, LongOrShortNames) ->
    %% Only a node name with both a name and a host part is passed
    %% on to check_node/5, which produces the {Name, Host} result.
    %% Anything else is logged and shuts the setup down.
    case dist_util:split_node(Node) of
        {node, Name, Host} ->
            check_node(Driver, Node, Name, Host, LongOrShortNames);
        {host, _} ->
            error_logger:error_msg(
              "** Nodename ~p illegal, no '@' character **~n",
              [Node]),
            ?shutdown2(Node, trace({illegal_node_n@me, Node}));
        _ ->
            error_logger:error_msg(
              "** Nodename ~p illegal **~n", [Node]),
            ?shutdown2(Node, trace({illegal_node_name, Node}))
    end.
706
%% Check that Host fits the naming mode and return {Name, Host}.
%% With 'longnames' a host without any dot is only accepted when
%% Driver:parse_address/1 accepts it; with 'shortnames' a host
%% with a dot in it is rejected.  A rejected host is logged and
%% shuts the setup down.
check_node(Driver, Node, Name, Host, LongOrShortNames) ->
    case {string:split(Host, ".", all), LongOrShortNames} of
        {[_], longnames} ->
            case Driver:parse_address(Host) of
                {ok, _} ->
                    {Name, Host};
                _ ->
                    error_logger:error_msg(
                      "** System running to use "
                      "fully qualified hostnames **~n"
                      "** Hostname ~s is illegal **~n",
                      [Host]),
                    ?shutdown2(Node, trace({not_longnames, Host}))
            end;
        {[_, _|_], shortnames} ->
            error_logger:error_msg(
              "** System NOT running to use "
              "fully qualified hostnames **~n"
              "** Hostname ~s is illegal **~n",
              [Host]),
            ?shutdown2(Node, trace({not_shortnames, Host}));
        _ ->
            {Name, Host}
    end.
731
732%% -------------------------------------------------------------------------
733
%% Append the kernel application's inet_dist_connect_options, after
%% running them through setopts_filter/1, to Opts.  Opts is
%% returned as is when that parameter is not set.
connect_options(Opts) ->
    case application:get_env(kernel, inet_dist_connect_options) of
        {ok, UserOpts} ->
            Opts ++ setopts_filter(UserOpts);
        _ ->
            Opts
    end.
741
742%% we may not always want the nodelay behaviour
743%% for performance reasons
nodelay() ->
    %% Only an explicit 'false' in the kernel parameter dist_nodelay
    %% turns nodelay off; unset, 'true' or any other value gives
    %% {nodelay, true}.
    Enabled =
        case application:get_env(kernel, dist_nodelay) of
            {ok, false} -> false;
            _ -> true
        end,
    {nodelay, Enabled}.
755
756%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
757%%
758%% The DistCtrl process(es).
759%%
760%% At net_kernel handshake_complete spawns off the input controller that
761%% takes over the socket ownership, and itself becomes the output controller
762%%
763%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
764
%%% XXX Missing to be "productified":
766%%% * Cryptoanalysis by experts, this is crypto amateur work.
767%%% * Is it useful over inet_tls_dist; i.e to not have to bother
768%%%   with certificates but instead manage a secret cluster cookie?
769%%% * An application to belong to (kernel)
770%%% * Restart and/or code reload policy (not needed in kernel)
771%%% * Fitting into the epmd/Erlang distro protocol version framework
772%%%   (something needs to be created for multiple protocols, epmd,
773%%%    multiple address families, fallback to previous version, etc)
774
775
776%% Debug client and server
777
%% Debug server: listen on an ephemeral port, print the
%% test_client/1 call that connects to it, accept one connection
%% and run test/1 on the accepted socket.
%%
%% The listen socket is needed for that single accept only, so it
%% is closed as soon as the accept has returned (or failed)
%% instead of being left open for as long as the caller lives.
test_server() ->
    {ok, Listen} = gen_tcp:listen(0, socket_options()),
    Socket =
        try
            {ok, Port} = inet:port(Listen),
            io:format(?MODULE_STRING":test_client(~w).~n", [Port]),
            {ok, Accepted} = gen_tcp:accept(Listen),
            Accepted
        after
            gen_tcp:close(Listen)
        end,
    test(Socket).
784
%% Debug client: connect to Port on localhost, the port number
%% being the one printed by test_server/0, and run test/1
%% on the connected socket.
test_client(Port) ->
    {ok, Socket} = gen_tcp:connect(localhost, Port, socket_options()),
    test(Socket).

%% Start a DistCtrl process for Socket with a 10 s start timeout
test(Socket) ->
    StartTimeout = 10000,
    start_dist_ctrl(Socket, StartTimeout).
791
792%% -------------------------------------------------------------------------
793
%% Start the DistCtrl process for Socket and return its pid.
%%
%% The process is spawned linked to the caller, which it remembers
%% as Controller, and is made the owner of Socket.  It then waits
%% for the 'start' request, runs init/2 with the node's cookie as
%% the secret, answers the request with its own pid and goes on
%% into handshake/5.  The caller waits at most Timeout for that
%% answer.
start_dist_ctrl(Socket, Timeout) ->
    Secret = atom_to_binary(auth:get_cookie(), latin1),
    Controller = self(),
    ServerFun =
        fun () ->
                receive
                    {?MODULE, From, start} ->
                        {SendParams, RecvParams} = init(Socket, Secret),
                        reply(From, self()),
                        handshake(SendParams, 1, RecvParams, 1, Controller)
                end
        end,
    SpawnOpts =
        [link,
         {priority, max},
         {message_queue_data, off_heap},
         {fullsweep_after, 0}],
    Server = monitor_dist_proc(spawn_opt(ServerFun, SpawnOpts)),
    ok = gen_tcp:controlling_process(Socket, Server),
    call_dist_ctrl(Server, start, Timeout).
815
816
%% Call the DistCtrl process Server with Msg and no time limit
call_dist_ctrl(Server, Msg) ->
    call_dist_ctrl(Server, Msg, infinity).
%%
%% Send {?MODULE, {Ref, self()}, Msg} to Server and return what it
%% answers with reply/2.  Server is monitored during the call:
%% * if it goes down the call fails with error({dist_ctrl, Reason})
%% * if Timeout passes the call fails with error({dist_ctrl, timeout}),
%%   but only once Server has gone down; a late answer is dropped.
%%
%% NOTE(review): nothing in here makes Server terminate after
%% the timeout, so that path waits without limit for the 'DOWN'
%% message - presumably a link takes care of that; confirm.
call_dist_ctrl(Server, Msg, Timeout) ->
    MRef = erlang:monitor(process, Server),
    Server ! {?MODULE, {MRef, self()}, Msg},
    receive
        {MRef, Result} ->
            erlang:demonitor(MRef, [flush]),
            Result;
        {'DOWN', MRef, process, Server, Reason} ->
            error({dist_ctrl, Reason})
    after Timeout -> % Timeout < infinity is only used by start_dist_ctrl/2
            receive
                {'DOWN', MRef, process, Server, _} ->
                    receive
                        {MRef, _} -> ok
                    after 0 ->
                            ok
                    end,
                    error({dist_ctrl, timeout})
                    %% Server will be killed by link
            end
    end.
837
%% Answer a call_dist_ctrl/2,3 request: {Ref, Pid} is the From
%% part of the request and Msg the result to deliver.  Returns 'ok'.
reply({Ref, Pid}, Msg) ->
    _ = erlang:send(Pid, {Ref, Msg}),
    ok.
841
842%% -------------------------------------------------------------------------
843
%% The N in the socket option {active, N} used by the input handler;
%% the handler re-arms the socket when it gets {tcp_passive, Socket}
-define(TCP_ACTIVE, 16).
%% Largest amount of queued output data dequeued into one data chunk
-define(CHUNK_SIZE, (?PACKET_SIZE - 512)).

%% Chunk types; the first byte of the cleartext in an encrypted chunk
-define(HANDSHAKE_CHUNK, 1).
-define(DATA_CHUNK, 2).
-define(TICK_CHUNK, 3).
-define(REKEY_CHUNK, 4).
851
852%% -------------------------------------------------------------------------
853%% Crypto strategy
854%% -------
855%% The crypto strategy is as simple as possible to get an encrypted
856%% connection as benchmark reference.  It is geared around AEAD
857%% ciphers in particular AES-GCM.
858%%
859%% The init message and the start message must fit in the TCP buffers
%% since both sides start with sending the init message, wait
%% for the other end's init message, send the start message
%% and wait for the other end's start message.  So if the send
863%% blocks we have a deadlock.
864%%
865%% The init + start sequence tries to implement Password Encrypted
866%% Key Exchange using a node public/private key pair and the
867%% shared secret (the Cookie) to create session encryption keys
%% that can not be re-created if the shared secret is compromised,
869%% which should create forward secrecy.  You need both nodes'
870%% key pairs and the shared secret to decrypt the traffic
871%% between the nodes.
872%%
%% All exchanged messages use {packet, 2} i.e. a 16 bit size header.
874%%
875%% The init message contains a random number and encrypted: the public key
876%% and two random numbers.  The encryption is done with Key and IV hashed
877%% from the unencrypted random number and the shared secret.
878%%
879%% The other node's public key is used with the own node's private
880%% key to create a shared key that is hashed with one of the encrypted
881%% random numbers from each side to create Key and IV for the session.
882%%
883%% The start message contains the two encrypted random numbers
884%% this time encrypted with the session keys for verification
885%% by the other side, plus the rekey count.  The rekey count
886%% is just there to get an early check for if the other side's
887%% maximum rekey count is acceptable, it is just an embryo
888%% of some better check.  Any side may rekey earlier but if the
889%% rekey count is exceeded the connection fails.  Rekey is also
890%% triggered by a timer.
891%%
%% Subsequent encrypted messages have the sequence number and the length
%% of the message as AAD data, and an incrementing IV.  These messages
%% have a message type that differentiates data from ticks and rekeys.
895%% Ticks have a random size in an attempt to make them less obvious to spot.
896%%
897%% Rekeying is done by the sender that creates a new key pair and
898%% a new shared secret from the other end's public key and with
899%% this and the current key and iv hashes a new key and iv.
900%% The new public key is sent to the other end that uses it
901%% and its old private key to create the same new shared
902%% secret and from that a new key and iv.
903%% So the receiver keeps its private key, and the sender keeps
904%% the receivers public key for the connection's life time.
905%% While the sender generates a new key pair at every rekey,
906%% which changes the shared secret at every rekey.
907%%
%% The only reaction to errors is to crash noisily (?) which will bring
909%% down the connection and hopefully produce something useful
910%% in the local log, but all the other end sees is a closed connection.
911%% -------------------------------------------------------------------------
912
%% Start the init + start message exchange on Socket:
%% send our own init message and let init_recv/5 do the rest.
%% Returns what init_recv/5 returns; {SendParams, RecvParams}.
init(Socket, Secret) ->
    KeyPair = get_key_pair(),
    #key_pair{public = PubKey} = KeyPair,
    Params = params(Socket),
    {R2, R3, InitMsg} = init_msg(Params, PubKey, Secret),
    ok = gen_tcp:send(Socket, InitMsg),
    init_recv(Params, Secret, KeyPair, R2, R3).
919
%% Receive the other end's init message, send our start message
%% and verify the other end's start message.
%%
%% Returns {SendParams, RecvParams} where the iv field of both has
%% been converted from a binary into {Salt, Number}; the number is
%% the last 48 bits of the binary IV.  Any error exception after
%% the init message has been received is passed to trace/1 and
%% converted into exit(connection_closed).
init_recv(
  #params{socket = Socket, iv = IVLen} = Params, Secret, KeyPair, R2, R3) ->
    {ok, PeerInitMsg} = gen_tcp:recv(Socket, 0),
    SaltLen = IVLen - 6,
    try
        case init_msg(Params, Secret, KeyPair, R2, R3, PeerInitMsg) of
            {#params{iv = <<SendSalt:SaltLen/binary, SendNo:48>>} =
                 SendParams,
             RecvParams, StartMsg} ->
                ok = gen_tcp:send(Socket, StartMsg),
                {ok, PeerStartMsg} = gen_tcp:recv(Socket, 0),
                RecvParams_1 =
                    start_msg(RecvParams, R2, R3, PeerStartMsg),
                #params{iv = <<RecvSalt:SaltLen/binary, RecvNo:48>>} =
                    RecvParams_1,
                {SendParams#params{iv = {SendSalt, SendNo}},
                 RecvParams_1#params{iv = {RecvSalt, RecvNo}}}
        end
    catch
        error : Reason : Stacktrace ->
            _ = trace({Reason, Stacktrace}),
            exit(connection_closed)
    end.
944
945
946
%% Create our init message.
%%
%% Three random numbers of KeyLen + IVLen bytes are generated.
%% The first goes in the clear and is hashed with Secret into the
%% key and IV that encrypt the other two plus our public key.
%% Returns {R2A, R3A, Msg} with Msg as an iolist: [R1A, Tag, Ciphertext].
init_msg(
  #params{
     hmac_algorithm = HmacAlgo,
     aead_cipher = AeadCipher,
     key = KeyLen,
     iv = IVLen,
     tag_len = TagLen}, PubKeyA, Secret) ->
    RLen = KeyLen + IVLen,
    Random = crypto:strong_rand_bytes(3 * RLen),
    <<R1A:RLen/binary, R2A:RLen/binary, R3A:RLen/binary>> = Random,
    {Key1A, IV1A} = hmac_key_iv(HmacAlgo, R1A, Secret, KeyLen, IVLen),
    Cleartext = [R2A, R3A, PubKeyA],
    %% The AAD covers the total message length and the clear random
    MsgLen = RLen + TagLen + iolist_size(Cleartext),
    AAD = [<<MsgLen:32>>, R1A],
    {Ciphertext, Tag} =
        crypto:block_encrypt(
          AeadCipher, Key1A, IV1A, {AAD, Cleartext, TagLen}),
    {R2A, R3A, [R1A, Tag, Ciphertext]}.
966%%
%% Handle the other end's init message Msg and create our start message.
%%
%% Msg is split into the clear random number R1B, the tag and the
%% ciphertext.  Key and IV for decryption are hashed from R1B and Secret.
%% The cleartext holds the other end's random numbers R2B and R3B
%% and its public key PubKeyB.
%%
%% Returns {SendParams, RecvParams, StartMsg} where both params records
%% have session key and IV (as a binary) filled in, and StartMsg
%% is the iolist [StartTag, StartCiphertext].  A message that does
%% not match, or does not decrypt, raises a case_clause error.
init_msg(
  #params{
     hmac_algorithm = HmacAlgo,
     aead_cipher = AeadCipher,
     key = KeyLen,
     iv = IVLen,
     tag_len = TagLen,
     rekey_count = RekeyCount} = Params,
  Secret, KeyPair, R2A, R3A, Msg) ->
    %%
    RLen = KeyLen + IVLen,
    case Msg of
        <<R1B:RLen/binary, Tag:TagLen/binary, Ciphertext/binary>> ->
            {Key1B, IV1B} = hmac_key_iv(HmacAlgo, R1B, Secret, KeyLen, IVLen),
            %% Same AAD as the sender built; total length and R1
            MsgLen = byte_size(Msg),
            AAD = [<<MsgLen:32>>, R1B],
            case
                crypto:block_decrypt(
                  AeadCipher, Key1B, IV1B, {AAD, Ciphertext, Tag})
            of
                <<R2B:RLen/binary, R3B:RLen/binary, PubKeyB/binary>> ->
                    SharedSecret = compute_shared_secret(KeyPair, PubKeyB),
                    %%
                    %% Sending direction; hashed from our R2 and
                    %% the other end's R3.  The other end's public key
                    %% is kept for rekeying.
                    {Key2A, IV2A} =
                        hmac_key_iv(
                          HmacAlgo, SharedSecret, [R2A, R3B], KeyLen, IVLen),
                    SendParams =
                        Params#params{
                          rekey_key = PubKeyB,
                          key = Key2A, iv = IV2A},
                    %%
                    %% The start message echoes the other end's
                    %% random numbers, encrypted with our sending key,
                    %% and carries our rekey count
                    StartCleartext = [R2B, R3B, <<RekeyCount:32>>],
                    StartMsgLen = TagLen + iolist_size(StartCleartext),
                    StartAAD = <<StartMsgLen:32>>,
                    {StartCiphertext, StartTag} =
                        crypto:block_encrypt(
                          AeadCipher, Key2A, IV2A,
                          {StartAAD, StartCleartext, TagLen}),
                    StartMsg = [StartTag, StartCiphertext],
                    %%
                    %% Receiving direction; hashed from the other end's R2
                    %% and our R3.  Our own key pair is kept for rekeying.
                    {Key2B, IV2B} =
                        hmac_key_iv(
                          HmacAlgo, SharedSecret, [R2B, R3A], KeyLen, IVLen),
                    RecvParams =
                        Params#params{
                          rekey_key = KeyPair,
                          key = Key2B, iv = IV2B},
                    %%
                    {SendParams, RecvParams, StartMsg}
            end
    end.
1018
%% Verify the other end's start message Msg.
%%
%% The message is decrypted with the receive key and IV and must
%% contain exactly our own random numbers R2A and R3A followed by
%% the other end's rekey count, which is accepted only if neither
%% side's count is more than 4 times the other's.
%% Returns RecvParams with rekey_count set to the other end's count.
%% Anything else raises a case_clause error.
start_msg(
  #params{
     aead_cipher = AeadCipher,
     key = Key2B,
     iv = IV2B,
     tag_len = TagLen,
     rekey_count = RekeyCountA} = RecvParams, R2A, R3A, Msg) ->
    case Msg of
        <<Tag:TagLen/binary, Ciphertext/binary>> ->
            RLen = byte_size(Key2B) + byte_size(IV2B),
            AAD = <<(byte_size(Msg)):32>>,
            Cleartext =
                crypto:block_decrypt(
                  AeadCipher, Key2B, IV2B, {AAD, Ciphertext, Tag}),
            case Cleartext of
                <<R2A:RLen/binary, R3A:RLen/binary, RekeyCountB:32>>
                  when RekeyCountA =< (RekeyCountB bsl 2),
                       RekeyCountB =< (RekeyCountA bsl 2) ->
                    RecvParams#params{rekey_count = RekeyCountB}
            end
    end.
1044
%% Hash Data with MacKey into KeyLen + IVLen bytes and split
%% the result into {Key, IV}
hmac_key_iv(HmacAlgo, MacKey, Data, KeyLen, IVLen) ->
    Mac = crypto:hmac(HmacAlgo, MacKey, Data, KeyLen + IVLen),
    <<Key:KeyLen/binary, IV:IVLen/binary>> = Mac,
    {Key, IV}.
1049
1050%% -------------------------------------------------------------------------
1051%% net_kernel distribution handshake in progress
1052%%
1053
%% Loop of the dist controller process while the distribution
%% handshake is in progress.  Serves requests of the form
%% {?MODULE, From, Request}, answered through reply/2:
%%
%%   {controller, Controller_1, Parent} - link to Controller_1,
%%       unlink from Parent, and continue with Controller_1
%%   {handshake_complete, DistHandle} - spawn the input handler,
%%       hand over the socket to it, and become the output handler
%%   {send, Data} - encrypt and send Data as a handshake chunk
%%   recv - receive and decrypt one handshake chunk
%%   peername - reply with the result of inet:peername/1
%%
%% All other messages are dropped.
handshake(
  SendParams, SendSeq,
  #params{socket = Socket} = RecvParams, RecvSeq, Controller) ->
    receive
        {?MODULE, From, {controller, Controller_1, Parent}} ->
            Result = link(Controller_1),
            true = unlink(Parent),
            reply(From, Result),
            handshake(SendParams, SendSeq, RecvParams, RecvSeq, Controller_1);
        {?MODULE, From, {handshake_complete, DistHandle}} ->
            InputHandler =
                monitor_dist_proc(
                  spawn_opt(
                    fun () ->
                            link(Controller),
                            %% Wait for the go-ahead that is sent when
                            %% this process has become socket owner
                            %% and registered input handler
                            receive
                                DistHandle ->
                                    ok =
                                        inet:setopts(
                                          Socket,
                                          [{active, ?TCP_ACTIVE},
                                           nodelay()]),
                                    input_handler(
                                      RecvParams#params{
                                        dist_handle = DistHandle},
                                      RecvSeq, empty_q(), infinity)
                            end
                    end,
                    [link,
                     {priority, normal},
                     {message_queue_data, off_heap},
                     {fullsweep_after, 0}])),
            _ = monitor(process, InputHandler), % For the benchmark test
            ok = gen_tcp:controlling_process(Socket, InputHandler),
            ok = erlang:dist_ctrl_input_handler(DistHandle, InputHandler),
            InputHandler ! DistHandle,
            crypto:rand_seed_alg(crypto_cache),
            reply(From, ok),
            %% This process was spawned with priority max
            %% by start_dist_ctrl/2
            process_flag(priority, normal),
            erlang:dist_ctrl_get_data_notification(DistHandle),
            output_handler(
              SendParams#params{
                dist_handle = DistHandle,
                rekey_msg = start_rekey_timer(SendParams#params.rekey_time)},
              SendSeq);
        %%
        {?MODULE, From, {send, Data}} ->
            case
                encrypt_and_send_chunk(
                  SendParams, SendSeq, [?HANDSHAKE_CHUNK, Data])
            of
                {SendParams_1, SendSeq_1, ok} ->
                    reply(From, ok),
                    handshake(
                      SendParams_1, SendSeq_1, RecvParams, RecvSeq,
                      Controller);
                {_, _, Error} ->
                    %% The caller always gets {error, closed};
                    %% the actual error goes to trace/1 and
                    %% into the exit reason
                    reply(From, {error, closed}),
                    death_row({send, trace(Error)})
            end;
        {?MODULE, From, recv} ->
            case recv_and_decrypt_chunk(RecvParams, RecvSeq) of
                {RecvParams_1, RecvSeq_1, {ok, _} = Reply} ->
                    reply(From, Reply),
                    handshake(
                      SendParams, SendSeq, RecvParams_1, RecvSeq_1,
                      Controller);
                {_, _, Error} ->
                    reply(From, Error),
                    death_row({recv, trace(Error)})
            end;
        {?MODULE, From, peername} ->
            reply(From, inet:peername(Socket)),
            handshake(SendParams, SendSeq, RecvParams, RecvSeq, Controller);
        %%
        _Alien ->
            handshake(SendParams, SendSeq, RecvParams, RecvSeq, Controller)
    end.
1132
%% Receive one packet from the socket and decrypt it as a chunk
%% during the handshake.
%%
%% Returns {RecvParams_1, RecvSeq_1, Result} where Result is
%% {ok, Cleartext} for a handshake chunk, {error, decrypt_error}
%% for any other chunk type or a failed decryption, or the error
%% return from gen_tcp:recv/2.  A rekey chunk is handled by
%% receiving the next packet with the new parameters and
%% sequence number 0.
recv_and_decrypt_chunk(#params{socket = Socket} = RecvParams, RecvSeq) ->
    case gen_tcp:recv(Socket, 0) of
        {ok, Chunk} ->
            handshake_chunk(
              RecvParams, RecvSeq,
              decrypt_chunk(RecvParams, RecvSeq, Chunk));
        RecvError ->
            {RecvParams, RecvSeq, RecvError}
    end.

%% Dispatch on the result from decrypt_chunk/3
handshake_chunk(
  RecvParams, RecvSeq, <<?HANDSHAKE_CHUNK, Cleartext/binary>>) ->
    {RecvParams, RecvSeq + 1, {ok, Cleartext}};
handshake_chunk(RecvParams, RecvSeq, OtherChunk)
  when is_binary(OtherChunk) ->
    {RecvParams, RecvSeq + 1, {error, decrypt_error}};
handshake_chunk(_RecvParams, _RecvSeq, #params{} = RecvParams_1) ->
    recv_and_decrypt_chunk(RecvParams_1, 0);
handshake_chunk(RecvParams, RecvSeq, error) ->
    {RecvParams, RecvSeq, {error, decrypt_error}}.
1149
1150%% -------------------------------------------------------------------------
1151%% Output handler process
1152%%
1153%% The game here is to flush all dist_data and dist_tick messages,
1154%% prioritize dist_data over dist_tick, and to not use selective receive
1155
%% Idle state of the output handler; wait for any message.
%% The last clause matches everything, so the first message
%% in the queue is always the one taken.
output_handler(Params, Seq) ->
    receive
        dist_data ->
            output_handler_data(Params, Seq);
        dist_tick ->
            output_handler_tick(Params, Seq);
        RekeyMsg when RekeyMsg =:= Params#params.rekey_msg ->
            %% A rekey restarts the sequence numbering
            output_handler(output_handler_rekey(Params, Seq), 0);
        Unknown ->
            %% Ignore
            _ = trace(Unknown),
            output_handler(Params, Seq)
    end.
1173
%% A dist_data message has been seen; flush the message queue
%% and then, when it is empty, fetch and send all available data.
%% Ticks seen here are dropped since data is about to be sent.
output_handler_data(Params, Seq) ->
    receive
        dist_data ->
            output_handler_data(Params, Seq);
        dist_tick ->
            output_handler_data(Params, Seq);
        RekeyMsg when RekeyMsg =:= Params#params.rekey_msg ->
            output_handler_data(output_handler_rekey(Params, Seq), 0);
        Unknown ->
            %% Ignore
            _ = trace(Unknown),
            output_handler_data(Params, Seq)
    after 0 ->
            #params{dist_handle = DistHandle} = Params,
            OutQ = get_data(DistHandle, empty_q()),
            {Params_1, Seq_1} = output_handler_send(Params, Seq, OutQ),
            erlang:dist_ctrl_get_data_notification(DistHandle),
            output_handler(Params_1, Seq_1)
    end.
1197
%% A dist_tick message has been seen; flush the message queue
%% and then, when it is empty, send a tick chunk of 8..63 zero bytes.
%% If a dist_data message shows up the tick is abandoned in favour
%% of sending data.
output_handler_tick(Params, Seq) ->
    receive
        dist_data ->
            output_handler_data(Params, Seq);
        dist_tick ->
            output_handler_tick(Params, Seq);
        RekeyMsg when RekeyMsg =:= Params#params.rekey_msg ->
            output_handler(output_handler_rekey(Params, Seq), 0);
        Unknown ->
            %% Ignore
            _ = trace(Unknown),
            output_handler_tick(Params, Seq)
    after 0 ->
            TickSize = 7 + rand:uniform(56),
            TickChunk = [?TICK_CHUNK, binary:copy(<<0>>, TickSize)],
            case encrypt_and_send_chunk(Params, Seq, TickChunk) of
                {Params_1, Seq_1, ok} ->
                    output_handler(Params_1, Seq_1);
                {_, _, SendError} ->
                    _ = trace(SendError),
                    death_row()
            end
    end.
1227
%% Send a rekey chunk and return the new send parameters,
%% or trace the send error and go to death row
output_handler_rekey(Params, Seq) ->
    Result = encrypt_and_send_rekey_chunk(Params, Seq),
    if
        is_record(Result, params) ->
            Result;
        true ->
            _ = trace(Result),
            death_row()
    end.
1236
%% Send the queued data in Q as data chunks.
%%
%% While the queue holds more than ?CHUNK_SIZE bytes a full size
%% chunk is sent.  Otherwise more data is requested first, and
%% when there is no more to get, what is in the queue is sent
%% as a final smaller chunk.  Returns {Params, Seq} when
%% the queue is empty.
output_handler_send(Params, Seq, {_, Size, _} = Q)
  when ?CHUNK_SIZE < Size ->
    output_handler_deq_send(Params, Seq, Q, ?CHUNK_SIZE);
output_handler_send(Params, Seq, {_, Size, _} = Q) ->
    case get_data(Params#params.dist_handle, Q) of
        {_, 0, _} ->
            {Params, Seq};
        {_, Size, _} = SameQ -> % Got no more
            output_handler_deq_send(Params, Seq, SameQ, Size);
        GrownQ ->
            output_handler_send(Params, Seq, GrownQ)
    end.
1251
%% Take Size bytes from the queue and send them as one data chunk,
%% then go on with the rest of the queue.  A send error is traced
%% and takes the process to death row.
output_handler_deq_send(Params, Seq, Q, Size) ->
    {Cleartext, RestQ} = deq_iovec(Size, Q),
    SendResult =
        encrypt_and_send_chunk(Params, Seq, [?DATA_CHUNK, Cleartext]),
    case SendResult of
        {Params_1, Seq_1, ok} ->
            output_handler_send(Params_1, Seq_1, RestQ);
        {_, _, SendError} ->
            _ = trace(SendError),
            death_row()
    end.
1263
1264%% -------------------------------------------------------------------------
1265%% Input handler process
1266%%
1267%% Here is T = 0|infinity to steer if we should try to receive
1268%% more data or not; start with infinity, and when we get some
1269%% data try with 0 to see if more is waiting
1270
%% Loop of the input handler process.
%%
%% Q holds decrypted data that is not yet delivered, and T is
%% the receive timeout; 0 when there is data in Q to deliver as
%% soon as the message queue is empty, infinity otherwise.
%% The last clause matches everything, so the first message
%% in the queue is always the one taken.
input_handler(#params{socket = Socket} = Params, Seq, Q, T) ->
    receive
        {tcp_passive, Socket} ->
            %% Re-arm the socket before doing anything else
            ok = inet:setopts(Socket, [{active, ?TCP_ACTIVE}]),
            Q_1 =
                case T of
                    0 ->
                        deliver_data(Params#params.dist_handle, Q);
                    infinity ->
                        Q
                end,
            input_handler(Params, Seq, Q_1, infinity);
        {tcp, Socket, Chunk} ->
            input_chunk(Params, Seq, Q, T, Chunk);
        {tcp_closed, Socket} ->
            exit(connection_closed);
        Unknown ->
            %% Ignore...
            _ = trace(Unknown),
            input_handler(Params, Seq, Q, T)
    after T ->
            DeliveredQ = deliver_data(Params#params.dist_handle, Q),
            input_handler(Params, Seq, DeliveredQ, infinity)
    end.
1298
%% Decrypt a received chunk and act on its type:
%% data is queued and the timeout set to 0, a tick is dropped,
%% a rekey switches parameters and restarts the sequence numbering,
%% and anything else closes the connection
input_chunk(Params, Seq, Q, T, Chunk) ->
    Decrypted = decrypt_chunk(Params, Seq, Chunk),
    case Decrypted of
        <<?DATA_CHUNK, Cleartext/binary>> ->
            Q_1 = enq_binary(Cleartext, Q),
            input_handler(Params, Seq + 1, Q_1, 0);
        <<?TICK_CHUNK, _/binary>> ->
            input_handler(Params, Seq + 1, Q, T);
        _ when is_binary(Decrypted) ->
            _ = trace(invalid_chunk),
            exit(connection_closed);
        #params{} ->
            input_handler(Decrypted, 0, Q, T);
        error ->
            _ = trace(decrypt_error),
            exit(connection_closed)
    end.
1314
1315%% -------------------------------------------------------------------------
1316%% erlang:dist_ctrl_* helpers
1317
%% Fetch all data the VM has for sending and put it in the queue.
%%
%% Every fetched packet is preceded in the queue by a 32 bit
%% length header.  The rear of the queue is kept in reverse order.
get_data(DistHandle, {Front, Size, Rear}) ->
    get_data(DistHandle, Front, Size, Rear).
%%
get_data(DistHandle, Front, Size, Rear) ->
    case erlang:dist_ctrl_get_data(DistHandle) of
        none ->
            {Front, Size, Rear};
        Data ->
            {Len, Rear_1} = push_packet(Data, Rear),
            get_data(DistHandle, Front, Size + 4 + Len, Rear_1)
    end.

%% Push a length header and the packet onto the reversed rear;
%% return the packet length and the new rear
push_packet(Bin, Rear) when is_binary(Bin) ->
    Len = byte_size(Bin),
    {Len, [Bin, <<Len:32>>|Rear]};
push_packet([Bin1, Bin2], Rear) ->
    Len = byte_size(Bin1) + byte_size(Bin2),
    {Len, [Bin2, Bin1, <<Len:32>>|Rear]};
push_packet(Iovec, Rear) ->
    Len = iolist_size(Iovec),
    {Len, lists:reverse(Iovec, [<<Len:32>>|Rear])}.
1343
%% Unpack the received data in the queue and deliver all
%% complete packets to the VM; return the remaining queue
%%
deliver_data(_DistHandle, {[], Size, []} = Q) ->
    Size = 0, % Assert
    Q;
deliver_data(DistHandle, {[], Size, Rear}) ->
    %% The front is empty; turn the rear into a new front
    [Bin|Front] = lists:reverse(Rear),
    deliver_data(DistHandle, Front, Size, [], Bin);
deliver_data(DistHandle, {[Bin|Front], Size, Rear}) ->
    deliver_data(DistHandle, Front, Size, Rear, Bin).
1357%%
%% Deliver packets starting in Bin, which is the first binary
%% of the queue, taken out of Front.
%%
%% Size is the number of bytes in Bin plus what is left in Front
%% and Rear.  Each packet is a 32 bit length header followed by
%% that many bytes of data.  Returns the queue that remains when
%% no more complete packet can be extracted.
deliver_data(DistHandle, Front, Size, Rear, Bin) ->
    case Bin of
        %% Two complete packets in Bin
        <<DataSizeA:32, DataA:DataSizeA/binary,
          DataSizeB:32, DataB:DataSizeB/binary, Rest/binary>> ->
            erlang:dist_ctrl_put_data(DistHandle, DataA),
            erlang:dist_ctrl_put_data(DistHandle, DataB),
            deliver_data(
              DistHandle,
              Front, Size - (4 + DataSizeA + 4 + DataSizeB), Rear,
              Rest);
        %% One complete packet in Bin
        <<DataSize:32, Data:DataSize/binary, Rest/binary>> ->
            erlang:dist_ctrl_put_data(DistHandle, Data),
            deliver_data(DistHandle, Front, Size - (4 + DataSize), Rear, Rest);
        %% Complete header, but the data continues beyond Bin
        <<DataSize:32, FirstData/binary>> ->
            TotalSize = 4 + DataSize,
            if
                TotalSize =< Size ->
                    %% The whole packet is in the queue;
                    %% collect the rest of it as an iovec
                    BinSize = byte_size(Bin),
                    {MoreData, Q} =
                        deq_iovec(
                          TotalSize - BinSize,
                          Front, Size - BinSize, Rear),
                    erlang:dist_ctrl_put_data(DistHandle, [FirstData|MoreData]),
                    deliver_data(DistHandle, Q);
                true -> % Incomplete data
                    {[Bin|Front], Size, Rear}
            end;
        %% Less than 4 bytes in Bin
        <<_/binary>> ->
            BinSize = byte_size(Bin),
            if
                4 =< Size -> % Fragmented header - extract a header bin
                    {RestHeader, {Front_1, _Size_1, Rear_1}} =
                        deq_iovec(4 - BinSize, Front, Size - BinSize, Rear),
                    Header = iolist_to_binary([Bin|RestHeader]),
                    %% Size is unchanged since Header is still
                    %% part of the queue
                    deliver_data(DistHandle, Front_1, Size, Rear_1, Header);
                true -> % Incomplete header
                    {[Bin|Front], Size, Rear}
            end
    end.
1397
1398%% -------------------------------------------------------------------------
1399%% Encryption and decryption helpers
1400
%% Encrypt Cleartext as chunk number Seq and send it on the socket.
%%
%% Returns {Params_1, NextSeq, Result} where Result is the return
%% value from gen_tcp:send/2, or from the failed send of
%% a rekey chunk.
%%
%% The first clause is selected when Seq has reached rekey_count;
%% then the rekey timer is cancelled and a rekey chunk is sent
%% before the chunk itself.
encrypt_and_send_chunk(
  #params{
     socket = Socket, rekey_count = Seq, rekey_msg = RekeyMsg} = Params,
  Seq, Cleartext) ->
    %%
    cancel_rekey_timer(RekeyMsg),
    case encrypt_and_send_rekey_chunk(Params, Seq) of
        #params{} = Params_1 ->
            %% The receiver switches to the new key and IV and
            %% restarts at sequence number 0 as soon as it has
            %% handled the rekey chunk, so this chunk has to be
            %% encrypted with the new parameters, not the old ones
            Result =
                gen_tcp:send(Socket, encrypt_chunk(Params_1, 0, Cleartext)),
            {Params_1, 1, Result};
        SendError ->
            {Params, Seq + 1, SendError}
    end;
encrypt_and_send_chunk(#params{socket = Socket} = Params, Seq, Cleartext) ->
    Result = gen_tcp:send(Socket, encrypt_chunk(Params, Seq, Cleartext)),
    {Params, Seq + 1, Result}.
1418
%% Generate a new key pair and send its public key in a rekey chunk,
%% encrypted with the current parameters as chunk number Seq.
%%
%% When the send succeeds a new key and IV are hashed from the new
%% shared secret and the current key and IV, and the updated
%% parameters are returned, with a fresh rekey timer started.
%% Otherwise the return value from gen_tcp:send/2 is returned.
encrypt_and_send_rekey_chunk(
  #params{
     socket = Socket,
     rekey_key = PubKeyB,
     key = Key,
     iv = {IVSalt, IVNo},
     hmac_algorithm = HmacAlgo} = Params,
  Seq) ->
    KeyLen = byte_size(Key),
    IVSaltLen = byte_size(IVSalt),
    NewKeyPair = get_new_key_pair(),
    #key_pair{public = PubKeyA} = NewKeyPair,
    RekeyChunk = encrypt_chunk(Params, Seq, [?REKEY_CHUNK, PubKeyA]),
    case gen_tcp:send(Socket, RekeyChunk) of
        ok ->
            SharedSecret = compute_shared_secret(NewKeyPair, PubKeyB),
            %% Hash over the key and the IV used for the rekey chunk
            CurrentIVNo = <<(IVNo + Seq):48>>,
            {Key_1, IV_1} =
                hmac_key_iv(
                  HmacAlgo, SharedSecret, [Key, IVSalt, CurrentIVNo],
                  KeyLen, IVSaltLen + 6),
            <<IVSalt_1:IVSaltLen/binary, IVNo_1:48>> = IV_1,
            Params#params{
              key = Key_1, iv = {IVSalt_1, IVNo_1},
              rekey_msg = start_rekey_timer(Params#params.rekey_time)};
        SendError ->
            SendError
    end.
1448
%% Encrypt Cleartext as chunk number Seq.
%%
%% The IV is the salt followed by the IV number plus Seq in 48 bits,
%% and the AAD is Seq and the resulting chunk length.
%% Returns the chunk as an iolist; ciphertext followed by tag.
encrypt_chunk(
  #params{
     aead_cipher = AeadCipher,
     iv = {IVSalt, IVNo}, key = Key, tag_len = TagLen}, Seq, Cleartext) ->
    ChunkLen = iolist_size(Cleartext) + TagLen,
    {Ciphertext, CipherTag} =
        crypto:block_encrypt(
          AeadCipher, Key,
          <<IVSalt/binary, (IVNo + Seq):48>>,
          {<<Seq:32, ChunkLen:32>>, Cleartext, TagLen}),
    [Ciphertext, CipherTag].
1461
%% Decrypt a received Chunk that should have number Seq.
%%
%% The chunk is ciphertext followed by a tag of TagLen bytes.
%% Returns error for a chunk that is shorter than the tag,
%% otherwise what block_decrypt/6 returns.
decrypt_chunk(
  #params{
     aead_cipher = AeadCipher,
     iv = {IVSalt, IVNo}, key = Key, tag_len = TagLen} = Params, Seq, Chunk) ->
    ChunkLen = byte_size(Chunk),
    case ChunkLen < TagLen of
        true ->
            error;
        false ->
            CiphertextLen = ChunkLen - TagLen,
            case Chunk of
                <<Ciphertext:CiphertextLen/binary,
                  CipherTag:TagLen/binary>> ->
                    block_decrypt(
                      Params, Seq, AeadCipher, Key,
                      <<IVSalt/binary, (IVNo + Seq):48>>,
                      {<<Seq:32, ChunkLen:32>>, Ciphertext, CipherTag});
                _ ->
                    error
            end
    end.
1485
%% Decrypt chunk data and handle a rekey chunk.
%%
%% Returns the cleartext for an ordinary chunk, updated parameters
%% with new key and IV for a rekey chunk, or error when the
%% decryption fails, when a rekey chunk does not hold a public key
%% of the same size as our own, or when an ordinary chunk arrives
%% with a sequence number equal to the rekey count.
block_decrypt(
  #params{
     rekey_key = #key_pair{public = PubKeyA} = KeyPair,
     rekey_count = RekeyCount} = Params,
  Seq, AeadCipher, Key, IV, Data) ->
    case crypto:block_decrypt(AeadCipher, Key, IV, Data) of
        <<?REKEY_CHUNK, Rest/binary>> ->
            PubKeyLen = byte_size(PubKeyA),
            case Rest of
                <<PubKeyB:PubKeyLen/binary>> ->
                    SharedSecret = compute_shared_secret(KeyPair, PubKeyB),
                    IVLen = byte_size(IV),
                    SaltLen = IVLen - 6,
                    {NewKey, NewIV} =
                        hmac_key_iv(
                          Params#params.hmac_algorithm,
                          SharedSecret, [Key, IV], byte_size(Key), IVLen),
                    <<NewSalt:SaltLen/binary, NewNo:48>> = NewIV,
                    Params#params{iv = {NewSalt, NewNo}, key = NewKey};
                _ ->
                    error
            end;
        Chunk when is_binary(Chunk), Seq =:= RekeyCount ->
            %% This was one chunk too many without rekeying
            error;
        Chunk when is_binary(Chunk) ->
            Chunk;
        error ->
            error
    end.
1520
1521%% -------------------------------------------------------------------------
1522%% Queue of binaries i.e an iovec queue
1523
%% An empty queue; {Front, Size, Rear} with Size in bytes
empty_q() ->
    Front = [],
    Rear = [],
    {Front, 0, Rear}.
1526
%% Put a binary last in the queue; the rear is in reverse order
enq_binary(Bin, {Front, Size, Rear}) ->
    NewSize = Size + byte_size(Bin),
    {Front, NewSize, [Bin|Rear]}.
1529
%% Take GetSize bytes from the front of the queue.
%%
%% Returns {Iovec, Q_1} where Iovec is a list of binaries with
%% a total size of GetSize bytes and Q_1 is the remaining queue.
%% The last binary taken is split if only a part of it is needed.
%% The caller has to ensure that the queue holds GetSize bytes.
deq_iovec(GetSize, {Front, Size, Rear}) when GetSize =< Size ->
    deq_iovec(GetSize, Front, Size, Rear, []).
%%
deq_iovec(GetSize, Front, Size, Rear) ->
    deq_iovec(GetSize, Front, Size, Rear, []).
%%
deq_iovec(GetSize, [], Size, Rear, Acc) ->
    %% Out of front; turn the rear into a new front
    deq_iovec(GetSize, lists:reverse(Rear), Size, [], Acc);
deq_iovec(GetSize, [Bin|Front], Size, Rear, Acc) ->
    case byte_size(Bin) of
        GetSize -> % Exact fit
            {lists:reverse(Acc, [Bin]), {Front, Size - GetSize, Rear}};
        BinSize when BinSize < GetSize -> % Need more than this binary
            deq_iovec(
              GetSize - BinSize, Front, Size - BinSize, Rear, [Bin|Acc]);
        _Larger -> % Need just the first part of this binary
            {Head, Tail} = erlang:split_binary(Bin, GetSize),
            {lists:reverse(Acc, [Head]),
             {[Tail|Front], Size - GetSize, Rear}}
    end.
1550
1551%% -------------------------------------------------------------------------
1552
%% Linger for 5000 ms and then exit with Reason; the reason
%% normal is replaced by connection_closed, which is also
%% the default reason
death_row() -> death_row(connection_closed).
%%
death_row(Reason) ->
    ExitReason =
        case Reason of
            normal -> connection_closed;
            _ -> Reason
        end,
    receive after 5000 -> exit(ExitReason) end.
1557
1558%% -------------------------------------------------------------------------
1559
%% Trace point; returns its argument unchanged.
%% dbg/0 sets a trace pattern on this function.
trace(Traced) -> Traced.
1562
%% Keep an eye on this Pid (debug)
%%
%% Two variants that both return Pid.  The first one, which does
%% nothing else, is the one compiled unless a macro named
%% 'undefined' is defined.  The second one also spawns a process
%% that monitors Pid and writes a report to the error logger
%% when Pid goes down.
-ifndef(undefined).
monitor_dist_proc(Pid) ->
    Pid.
-else.
monitor_dist_proc(Pid) ->
    spawn(
      fun () ->
              MRef = erlang:monitor(process, Pid),
              receive
                  %% Exit reason normal gives an error report ...
                  {'DOWN', MRef, _, _, normal} ->
                      error_logger:error_report(
                        [dist_proc_died,
                         {reason, normal},
                         {pid, Pid}]);
                  %% ... and all other exit reasons an info report
                  {'DOWN', MRef, _, _, Reason} ->
                      error_logger:info_report(
                        [dist_proc_died,
                         {reason, Reason},
                         {pid, Pid}])
              end
      end),
    Pid.
-endif.
1587
%% Debug helper: restart the dbg tracer, enable call trace for
%% all processes, and set trace patterns on trace/1 in this module
%% and on the erlang:dist_ctrl_* functions for data transfer
dbg() ->
    dbg:stop(),
    dbg:tracer(),
    dbg:p(all, c),
    dbg:tpl(?MODULE, trace, cx),
    lists:foreach(
      fun (Function) ->
              dbg:tpl(erlang, Function, cx)
      end,
      [dist_ctrl_get_data_notification,
       dist_ctrl_get_data,
       dist_ctrl_put_data]),
    ok.
1597