1%%
2%% %CopyrightBegin%
3%%
4%% Copyright Ericsson AB 2010-2020. All Rights Reserved.
5%%
6%% Licensed under the Apache License, Version 2.0 (the "License");
7%% you may not use this file except in compliance with the License.
8%% You may obtain a copy of the License at
9%%
10%%     http://www.apache.org/licenses/LICENSE-2.0
11%%
12%% Unless required by applicable law or agreed to in writing, software
13%% distributed under the License is distributed on an "AS IS" BASIS,
14%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15%% See the License for the specific language governing permissions and
16%% limitations under the License.
17%%
18%% %CopyrightEnd%
19%%
20
21-module(diameter_tcp).
22
23-dialyzer(no_improper_lists).
24
25-behaviour(gen_server).
26
27%% interface
28-export([start/3]).
29
30%% child start from supervisor
31-export([start_link/1]).
32
33%% child start from here
34-export([init/1]).
35
36%% gen_server callbacks
37-export([handle_call/3,
38         handle_cast/2,
39         handle_info/2,
40         code_change/3,
41         terminate/2]).
42
43-export([listener/1,%% diameter_sync callback
44         info/1]).  %% service_info callback
45
46-export([ports/0,
47         ports/1]).
48
49-export_type([connect_option/0,
50              listen_option/0]).
51
52-include_lib("diameter/include/diameter.hrl").
53
%% Keys into process dictionary. Values are stored under {?MODULE, Key}
%% by putr/2 and read back by getr/1.
%%   INFO_KEY      - {Mod, Sock} published for info/1
%%   REF_KEY       - transport reference, used to register the ssl socket
%%                   after a TLS upgrade
%%   TRANSPORT_KEY - in a monitor process, the pid of its transport
-define(INFO_KEY, info).
-define(REF_KEY,  ref).
-define(TRANSPORT_KEY, transport).

%% Raise an error tagged with the module and source line.
-define(ERROR(T), erlang:error({T, ?MODULE, ?LINE})).

-define(DEFAULT_PORT, 3868).  %% RFC 3588, ch 2.1
%% Default fragment timeout, in milliseconds (passed to erlang:start_timer/3).
-define(DEFAULT_FRAGMENT_TIMEOUT, 1000).

%% Guard-safe tests: a non-negative integer below 2^32, or such an
%% integer or infinity.
-define(IS_UINT32(N), (is_integer(N) andalso 0 =< N andalso 0 == N bsr 32)).
-define(IS_TIMEOUT(N), (infinity == N orelse ?IS_UINT32(N))).

%% cb_info passed to ssl: run TLS over the gen_tcp-like module Mod,
%% with messages tagged as for gen_tcp.
-define(TCP_CB(Mod), {Mod, tcp, tcp_closed, tcp_error}).
69
70%% The same gen_server implementation supports three different kinds
71%% of processes: an actual transport process, one that will club it to
72%% death should the parent die before a connection is established, and
73%% a process owning the listening port. The monitor process
74%% historically died after connection establishment, but can now live
75%% on as the sender of outgoing messages, so that a blocking send
76%% doesn't prevent messages from being received.
77
%% Listener process state.
%%   socket  - the listening socket returned by Mod:listen/2
%%   module  - the gen_tcp-like module that owns the socket
%%   service - service process, monitored once an accepting transport
%%             has registered it with {accept, SvcPid}; false until then
-record(listener, {socket :: inet:socket(),
                   module :: module(),
                   service = false :: false | pid()}). %% service process

%% Monitor process state.
%%   parent    - the parent pid when the transport creates the record,
%%               then the reference monitoring the parent, then false
%%               once the transport has made this process its sender
%%   transport - the transport process: the default self() is evaluated
%%               by the transport process that constructs the record
%%   ack       - whether sent messages are passed back to the transport
%%               (true when a message_cb is configured)
%%   socket, module - the socket to send on, set when sending starts
%%               and again after a TLS upgrade
-record(monitor,
        {parent :: reference() | false | pid(),
         transport = self() :: pid(),
         ack = false :: boolean(),
         socket :: inet:socket() | ssl:sslsocket() | undefined,
         module :: module() | undefined}).
90
%% The 24-bit Message Length field of a Diameter header.
-type length() :: 0..16#FFFFFF. %% message length from Diameter header
%% A partially received message: a binary while fewer than 4 bytes have
%% been received (so the length is still unknown), or an improper list
%% [Len | Bin] once the length has been read from the header.
-type frag()   :: maybe_improper_list(length(), binary())
                | binary().
94
%% Options to a connecting transport: the remote address and port, the
%% options common to both transport types, ssl client options and
%% gen_tcp connect options.
-type connect_option() :: {raddr, inet:ip_address()}
                        | {rport, pos_integer()}
                        | {ssl_options, true | [ssl:tls_client_option()]}
                        | option()
                        | ssl:tls_client_option()
                        | gen_tcp:connect_option().

%% A peer address match for an accepting transport. NOTE(review): the
%% meaning of a string() is decided by diameter_peer:match/2 — confirm.
-type match() :: inet:ip_address()
               | string()
               | [match()].

%% Options to an accepting transport: {accept, Match} restricts the
%% peer addresses that are accepted (see accept_peer/3).
-type listen_option() :: {accept, match()}
                       | {ssl_options, true | [ssl:tls_server_option()]}
                       | option()
                       | ssl:tls_server_option()
                       | gen_tcp:listen_option().
111
%% Options common to connecting and listening transports.
%%
%% fragment_timer is the time in milliseconds after which a partially
%% received message is delivered as is. infinity is accepted by the
%% ?IS_TIMEOUT validation in i/1, and means never.
-type option() :: {port, non_neg_integer()}
                | {sender, boolean()}
                | sender
                | {message_cb, false | diameter:eval()}
                | {fragment_timer, infinity | 0..16#FFFFFFFF}.
117
%% Accepting/connecting transport process state.
%%
%% ssl is a list of ssl options while capabilities exchange has yet to
%% decide on TLS, and a boolean once it's known whether or not the
%% connection runs over TLS. send is the monitor process when it has
%% been configured as the sender of outgoing messages, false otherwise.
-record(transport,
        {socket  :: inet:socket() | ssl:sslsocket(), %% accept/connect socket
         active = false :: boolean(),           %% is socket active?
         recv   = true  :: boolean(),           %% should it be active?
         parent  :: pid(),          %% of process that started us
         module  :: module(),       %% gen_tcp-like module
         ssl     :: [term()] | boolean(),       %% ssl options, ssl or not
         frag = <<>> :: frag(),                 %% message fragment
         timeout :: infinity | 0..16#FFFFFFFF,  %% fragment timeout
         tref = false  :: false | reference(),  %% fragment timer reference
         flush = false :: boolean(),            %% flush fragment at timeout?
         message_cb  :: false | diameter:eval(), %% callback, or none
         send        :: pid() | false}).         %% sending process
132
133%% The usual transport using gen_tcp can be replaced by anything
134%% sufficiently gen_tcp-like by passing a 'module' option as the first
135%% (for simplicity) transport option. The transport_module diameter_etcp
136%% uses this to set itself as the module to call, its start/3 just
137%% calling start/3 here with the option set.
138
139%% ---------------------------------------------------------------------------
140%% # start/3
141%% ---------------------------------------------------------------------------
142
%% Start an accepting or connecting transport process under the tcp
%% supervisor. A leading {module, Mod} option replaces gen_tcp.
-spec start({accept, Ref}, #diameter_service{}, [listen_option()])
   -> {ok, pid(), [inet:ip_address()]}
 when Ref :: diameter:transport_ref();
           ({connect, Ref}, #diameter_service{}, [connect_option()])
   -> {ok, pid()}
 when Ref :: diameter:transport_ref().

start({Type, Ref}, Svc, Opts) ->
    #diameter_service{capabilities = Caps, pid = SvcPid} = Svc,
    %% The tcp supervisors are only started once a transport needs them.
    diameter_tcp_sup:start(),
    {Mod, TransportOpts} = split(Opts),
    LocalAddrs = Caps#diameter_caps.host_ip_address,
    diameter_tcp_sup:start_child({Type, Ref, Mod, self(), TransportOpts,
                                  LocalAddrs, SvcPid}).
160
%% Separate an optional leading {module, Mod} from the remaining
%% options, defaulting the module to gen_tcp.
split(Opts) ->
    case Opts of
        [{module, M} | Rest] ->
            {M, Rest};
        _ ->
            {gen_tcp, Opts}
    end.
165
166%% start_link/1
167
%% Start a transport, monitor or listener process: the new process runs
%% init/1, which enters the gen_server loop.
start_link(T) ->
    SpawnOpts = diameter_lib:spawn_opts(server, []),
    proc_lib:start_link(?MODULE, init, [T], infinity, SpawnOpts).
174
175%% ---------------------------------------------------------------------------
176%% # info/1
177%% ---------------------------------------------------------------------------
178
%% Socket information for the service_info callback. The argument is
%% the {Mod, Sock} pair published under INFO_KEY.
%%
%% Each {Key, Fun} lookup contributes {Key, V} only if Fun(Mod, Sock)
%% returns {ok, V}. The ssl entry is a ready-made value rather than a
%% fun, so it's appended after the lookups instead of being passed to
%% info/3, which would raise badfun when applying the list.
info({Mod, Sock}) ->
    lists:flatmap(fun(K) -> info(Mod, K, Sock) end,
                  [{socket, fun sockname/2},
                   {peer, fun peername/2},
                   {statistics, fun getstat/2}])
        ++ ssl_info(Mod, Sock).

info(Mod, {K,F}, Sock) ->
    case F(Mod, Sock) of
        {ok, V} ->
            [{K,V}];
        _ ->
            []
    end.

%% Extra information for a TLS socket: [{ssl, [{peercert, Cert}]}], the
%% inner list being empty if ssl:peercert/1 doesn't return {ok, Cert}.
ssl_info(ssl = M, Sock) ->
    [{M, ssl_info(Sock)}];
ssl_info(_, _) ->
    [].

ssl_info(Sock) ->
    [{peercert, C} || {ok, C} <- [ssl:peercert(Sock)]].
201
202%% ---------------------------------------------------------------------------
203%% # init/1
204%% ---------------------------------------------------------------------------
205
%% Initialise the process state for the kind of process being started,
%% then become a gen_server.
init(T) ->
    State = i(T),
    gen_server:enter_loop(?MODULE, [], State).
208
209%% i/1
210
211%% A transport process.
i({T, Ref, Mod, Pid, Opts, Addrs, SvcPid})
  when T == accept;
       T == connect ->
    %% A DOWN from the parent stops the transport (see transition/2).
    monitor(process, Pid),
    %% Since accept/connect might block indefinitely, spawn a process
    %% that kills us with the parent until call returns, and then
    %% sends outgoing messages.
    {[SO|TO], Rest} = proplists:split(Opts, [ssl_options,
                                             sender,
                                             message_cb,
                                             fragment_timer]),
    SslOpts = ssl_opts(SO),
    OwnOpts = lists:append(TO),
    Tmo = proplists:get_value(fragment_timer,
                              OwnOpts,
                              ?DEFAULT_FRAGMENT_TIMEOUT),
    [CB, Sender] = [proplists:get_value(K, OwnOpts, false)
                    || K <- [message_cb, sender]],
    ?IS_TIMEOUT(Tmo) orelse ?ERROR({fragment_timer, Tmo}),
    {ok, MPid} = diameter_tcp_sup:start_child(#monitor{parent = Pid}),
    Sock = init(T, Ref, Mod, Pid, SslOpts, Rest, Addrs, SvcPid),
    %% ssl if TLS was established before capabilities exchange
    %% ({ssl_options, true}). An options list fails the guard, leaving
    %% Mod until a possible TLS upgrade (see tls_handshake/3).
    M = if SslOpts -> ssl; true -> Mod end,
    %% Watch the monitor only when it's going to send for us.
    Sender andalso monitor(process, MPid),
    %% NOTE(review): presumably tells the parent to expect ack callbacks
    %% when a message_cb is configured — confirm in diameter_peer.
    false == CB orelse (Pid ! {diameter, ack}),
    %% The monitor either becomes the sender or exits (see m/2).
    MPid ! {start, self(), Sender andalso {Sock, M}, false /= CB},
    %% Put the reference in the process dictionary since we now use it
    %% to advertise the ssl socket after TLS upgrade.
    putr(?REF_KEY, Ref),
    setopts(#transport{parent = Pid,
                       module = M,
                       socket = Sock,
                       ssl = SslOpts,
                       message_cb = CB,
                       timeout = Tmo,
                       send = Sender andalso MPid});

%% A monitor process to kill the transport if the parent dies.
i(#monitor{parent = Pid, transport = TPid} = S) ->
    putr(?TRANSPORT_KEY, TPid),
    proc_lib:init_ack({ok, self()}),
    monitor(process, TPid),
    %% parent now holds the monitor reference, not the pid.
    S#monitor{parent = monitor(process, Pid)};
%% In principle a link between the transport and killer processes
%% could do the same thing: have the accepting/connecting process be
%% killed when the killer process dies as a consequence of parent
%% death. However, a link can be unlinked and this is exactly what
%% gen_tcp seems to do. Links should be left to supervisors.

%% A process owning the listening socket for a transport reference.
i({listen, Ref, {Mod, Opts, Addrs}}) ->
    [_] = diameter_config:subscribe(Ref, transport), %% assert existence
    {[LP], Rest} = proplists:split(Opts, [port]),
    {ok, LSock} = Mod:listen(get_port(LP), gen_opts(Addrs, Rest)),
    {ok, {LAddr, _}} = sockname(Mod, LSock),
    %% Registration is what listener/1 looks up to reuse this listener.
    true = diameter_reg:add_new({?MODULE, listener, {Ref, {LAddr, LSock}}}),
    proc_lib:init_ack({ok, self(), {LAddr, LSock}}),
    #listener{socket = LSock,
              module = Mod}.
269
%% Interpret the (at most one) ssl_options option: false when absent,
%% true to establish TLS at connection, or a list of ssl options for
%% an upgrade after capabilities exchange. Anything else is an error.
ssl_opts(SslOptions) ->
    case SslOptions of
        [] ->
            false;
        [{ssl_options, true}] ->
            true;
        [{ssl_options, Opts}] when is_list(Opts) ->
            Opts;
        _ ->
            ?ERROR({ssl_options, SslOptions})
    end.
279
280%% init/8
281
%% Establish the connection, running ssl over Mod from the start when
%% ssl_options is true.
init(Type, Ref, Mod, Pid, SslOpts, Opts, Addrs, SvcPid) ->
    case SslOpts of
        true ->
            %% TLS before capabilities exchange ...
            TlsOpts = [{cb_info, ?TCP_CB(Mod)} | Opts],
            init(Type, Ref, ssl, Pid, TlsOpts, Addrs, SvcPid);
        _ ->
            %% ... or not.
            init(Type, Ref, Mod, Pid, Opts, Addrs, SvcPid)
    end.
289
290%% init/7
291
init(accept = T, Ref, Mod, Pid, Opts, Addrs, SvcPid) ->
    {[Matches], Rest} = proplists:split(Opts, [accept]),
    %% Reuse or start the single listening process for this reference.
    {ok, LPid, {LAddr, LSock}} = listener(Ref, {Mod, Rest, Addrs}),
    %% Have the listener monitor the service (see handle_call/3).
    ok = gen_server:call(LPid, {accept, SvcPid}, infinity),
    %% Ack the starting process before blocking in accept, so that
    %% proc_lib:start_link/5 returns without waiting for a peer.
    proc_lib:init_ack({ok, self(), [LAddr]}),
    Sock = ok(accept(Mod, LSock)),
    %% Exit unless the peer address satisfies the accept options.
    ok = accept_peer(Mod, Sock, accept(Matches)),
    publish(Mod, T, Ref, Sock),
    diameter_peer:up(Pid),
    Sock;

init(connect = T, Ref, Mod, Pid, Opts, Addrs, _SvcPid) ->
    {[RA, RP], Rest} = proplists:split(Opts, [raddr, rport]),
    RAddr = get_addr(RA),
    RPort = get_port(RP),
    %% Ack before a possibly blocking connect.
    proc_lib:init_ack({ok, self()}),
    Sock = ok(connect(Mod, RAddr, RPort, gen_opts(Addrs, Rest))),
    publish(Mod, T, Ref, Sock),
    up(Pid, {RAddr, RPort}, Mod, Sock),
    Sock.
312
%% Tell the parent that the connection is up, with the remote endpoint
%% and the local address the socket is bound to.
up(Pid, Remote, Mod, Sock) ->
    {LocalAddr, _LocalPort} = ok(sockname(Mod, Sock)),
    diameter_peer:up(Pid, Remote, [LocalAddr]).

%% Register the connected socket under the transport reference and
%% remember it for info/1.
publish(Mod, T, Ref, Sock) ->
    RegKey = {?MODULE, T, {Ref, Sock}},
    true = diameter_reg:add_new(RegKey),
    putr(?INFO_KEY, {Mod, Sock}).
320
%% Unwrap {ok, T}, exiting with {shutdown, Result} on anything else.
ok(Result) ->
    case Result of
        {ok, T} ->
            T;
        _ ->
            x(Result)
    end.

%% Exit with a shutdown reason, so that the exit isn't treated as a crash.
x(Reason) ->
    exit({shutdown, Reason}).
328
329%% accept_peer/3
330
%% With no accept options every peer is accepted; otherwise the remote
%% address must match, or the transport exits.
accept_peer(_Mod, _Sock, []) ->
    ok;
accept_peer(Mod, Sock, Matches) ->
    {PeerAddr, _PeerPort} = ok(peername(Mod, Sock)),
    Accepted = diameter_peer:match([PeerAddr], Matches),
    Accepted orelse x({accept, PeerAddr, Matches}),
    ok.
339
340%% accept/1
341
%% Collect the values of all {accept, Match} options, each wrapped in a
%% singleton list; other elements are ignored.
accept(Opts) ->
    lists:filtermap(fun({accept, Match}) -> {true, [Match]};
                       (_) -> false
                    end,
                    Opts).
344
345%% listener/2
346
347%% Accepting processes can be started concurrently: ensure only one
348%% listener is started.
listener(Ref, T) ->
    SyncKey = {?MODULE, listener, Ref},
    Callback = {?MODULE, listener, [{Ref, T, self()}]},
    diameter_sync:call(SyncKey, Callback, infinity, infinity).

%% listener/1
%%
%% diameter_sync callback: look for a registered listener for Ref.

listener({Ref, T, _TPid}) ->
    Registered = diameter_reg:match({?MODULE, listener, {Ref, '_'}}),
    l(Registered, Ref, T).

%% l/3

%% No listening process yet: start one ...
l([], Ref, T) ->
    diameter_tcp_sup:start_child({listen, Ref, T});

%% ... or reuse the existing one.
l([{{?MODULE, listener, {_, AddrSock}}, LPid}], _Ref, _T) ->
    {ok, LPid, AddrSock}.
369
370%% addrs/2
371%%
%% Take the first address from the service if several are specified
%% and no address is configured.
374
addrs(Addrs, Opts) ->
    %% Normalise ip/ifaddr options to {ip, _}, collecting the addresses.
    {Converted, Configured} = lists:mapfoldr(fun ipaddr/2, [], Opts),
    case Configured of
        [] ->
            Default = case Addrs of
                          [First | _] -> [{ip, First}];
                          _ -> []
                      end,
            Opts ++ Default;
        [_] ->
            Converted;
        _ ->
            ?ERROR({invalid_addrs, Configured, Addrs})
    end.
384
%% mapfoldr fun: rewrite ip/ifaddr options as {ip, Addr}, accumulating
%% the configured addresses; pass anything else through unchanged.
ipaddr({Key, Addr}, Acc)
  when Key == ip;
       Key == ifaddr ->
    {{ip, ipaddr(Addr)}, [Addr | Acc]};
ipaddr(Opt, Acc) ->
    {Opt, Acc}.

%% The special addresses any and loopback are kept as atoms; anything
%% else is converted by diameter_lib.
ipaddr(Addr)
  when Addr == any;
       Addr == loopback ->
    Addr;
ipaddr(Addr) ->
    diameter_lib:ipaddr(Addr).
398
399%% get_addr/1
400
%% Exactly one remote address must be configured.
get_addr(Addrs) ->
    case Addrs of
        [{_, Addr}] ->
            diameter_lib:ipaddr(Addr);
        _ ->
            ?ERROR({invalid_addrs, Addrs})
    end.

%% get_port/1
%%
%% At most one port: the Diameter default when none is configured.

get_port(Ports) ->
    case Ports of
        [] ->
            ?DEFAULT_PORT;
        [{_, Port}] ->
            Port;
        _ ->
            ?ERROR({invalid_ports, Ports})
    end.
414
415%% gen_opts/2
416
gen_opts(Addrs, Opts) ->
    WithAddr = addrs(Addrs, Opts),
    gen_opts(WithAddr).

%% gen_opts/1
%%
%% Prepend the socket options the transport relies on, refusing any
%% user option that would override them.

gen_opts(Opts) ->
    {Reserved, _} = proplists:split(Opts, [binary, packet, active]),
    lists:all(fun(Found) -> Found == [] end, Reserved)
        orelse ?ERROR({reserved_options, Opts}),
    [binary, {packet, 0}, {active, false} | Opts].
426
427%% ---------------------------------------------------------------------------
428%% # ports/1
429%% ---------------------------------------------------------------------------
430
%% All registered sockets, as {Type, PortOrSocket, Pid}.
ports() ->
    Regs = diameter_reg:match({?MODULE, '_', '_'}),
    [port_info(T, S, Pid) || {{?MODULE, T, {_, S}}, Pid} <- Regs].

%% The registered sockets of a single transport reference.
ports(Ref) ->
    Regs = diameter_reg:match({?MODULE, '_', {Ref, '_'}}),
    [port_info(T, S, Pid) || {{?MODULE, T, {R, S}}, Pid} <- Regs,
                             R == Ref].

%% One entry in the result of ports/0,1.
port_info(T, S, Pid) ->
    {type(T), resolve(T, S), Pid}.
439
%% Registration type as presented by ports/0,1.
type(T) ->
    case T of
        listener -> listen;
        _ -> T
    end.

%% The socket in a registration: listener entries hold {LAddr, LSock}.
sock(Type, S) ->
    case {Type, S} of
        {listener, {_LAddr, LSock}} -> LSock;
        _ -> S
    end.
449
%% Local port number of a registered socket, or the socket itself if
%% the port can't be determined.
resolve(Type, S) ->
    Sock = sock(Type, S),
    try ok(portnr(Sock)) of
        PortNr -> PortNr
    catch
        _:_ -> Sock
    end.

%% A plain port is a gen_tcp socket; anything else is taken to be ssl.
portnr(Sock) ->
    Mod = case is_port(Sock) of
              true -> gen_tcp;
              false -> ssl
          end,
    portnr(Mod, Sock).
463
464%% ---------------------------------------------------------------------------
465%% # handle_call/3
466%% ---------------------------------------------------------------------------
467
%% An accepting transport registers its service: the first pid sent is
%% monitored, so that the listener closes when the service dies.
handle_call({accept, SvcPid}, _From, #listener{service = P} = S) ->
    NewS = case not is_pid(P) andalso is_pid(SvcPid) of
               true ->
                   monitor(process, SvcPid),
                   S#listener{service = SvcPid};
               false ->
                   S
           end,
    {reply, ok, NewS};

%% Transport is telling us of parent death.
handle_call({stop, _Pid} = Reason, _From, #monitor{} = S) ->
    {stop, {shutdown, Reason}, ok, S};

handle_call(_Request, _From, State) ->
    {reply, nok, State}.
482
483%% ---------------------------------------------------------------------------
484%% # handle_cast/2
485%% ---------------------------------------------------------------------------
486
%% Casts aren't used: ignore them.
handle_cast(_Request, State) ->
    {noreply, State}.
489
490%% ---------------------------------------------------------------------------
491%% # handle_info/2
492%% ---------------------------------------------------------------------------
493
%% Dispatch on the kind of process, asserting that each transition
%% leaves the state of the same kind.
handle_info(Msg, #transport{} = S) ->
    #transport{} = NewS = t(Msg, S),
    {noreply, NewS};

handle_info(Msg, #listener{} = S) ->
    #listener{} = NewS = l(Msg, S),
    {noreply, NewS};

handle_info(Msg, #monitor{} = S) ->
    #monitor{} = NewS = m(Msg, S),
    {noreply, NewS}.
502
503%% ---------------------------------------------------------------------------
504%% # code_change/3
505%% ---------------------------------------------------------------------------
506
%% No state conversion is needed.
code_change(_OldVsn, State, _Extra) ->
    {ok, State}.

%% ---------------------------------------------------------------------------
%% # terminate/2
%% ---------------------------------------------------------------------------

%% Nothing to clean up.
terminate(_Reason, _State) ->
    ok.
516
517
518%% ---------------------------------------------------------------------------
519
%% Process dictionary access, with keys scoped to this module.
putr(Key, Val) ->
    erlang:put({?MODULE, Key}, Val).

getr(Key) ->
    erlang:get({?MODULE, Key}).
525
526%% m/2
527%%
528%% Transition monitor state.
529
%% Outgoing message: the monitor is acting as the sender.
m(Msg, S)
  when is_record(Msg, diameter_packet);
       is_binary(Msg) ->
    send(Msg, S),
    S;

%% Transport has established a connection. Stop monitoring on the
%% parent so as not to die before a send from the transport. If the
%% transport isn't using a sender (T == false) then this process has
%% nothing more to do and exits.
m({start, TPid, T, Ack} = M, #monitor{transport = TPid} = S) ->
    case T of
        {Sock, Mod} ->
            demonitor(S#monitor.parent, [flush]),
            S#monitor{parent = false,
                      socket = Sock,
                      module = Mod,
                      ack = Ack};
        false ->  %% monitor not sending
            x(M)
    end;

%% Transport is telling us to die.
m({stop, TPid} = T, #monitor{transport = TPid}) ->
    x(T);

%% Transport is telling us that TLS has been negotiated after
%% capabilities exchange.
m({tls, SSock}, S) ->
    S#monitor{socket = SSock,
              module = ssl};

%% Transport or parent has died.
m({'DOWN', M, process, P, _} = T, #monitor{parent = MRef,
                                           transport = TPid})
  when M == MRef;
       P == TPid ->
    x(T).
571
572%% l/2
573%%
574%% Transition listener state. Or not anymore since any message causes
575%% the process to exit.
576
-spec l(tuple(), #listener{})
   -> no_return().

%% The service process has died ...
l({'DOWN', _, process, Pid, _} = Reason, #listener{service = Pid} = S) ->
    close_and_exit(Reason, S);

%% ... or the transport has been removed: either way, close the
%% listening socket and exit.
l({transport, remove, _} = Reason, #listener{} = S) ->
    close_and_exit(Reason, S).

-spec close_and_exit(tuple(), #listener{})
   -> no_return().

close_and_exit(Reason, #listener{socket = LSock, module = Mod}) ->
    Mod:close(LSock),
    x(Reason).
592
593%% t/2
594%%
595%% Transition transport state.
596
%% Apply a transition: ok keeps the state, a record replaces it, and
%% stop exits with the message as the shutdown reason.
t(Msg, S) ->
    case transition(Msg, S) of
        stop ->
            x(Msg);
        ok ->
            S;
        #transport{} = NewS ->
            NewS
    end.
606
607%% transition/2
608
%% Incoming packets: ssl data only once TLS is known to be in use, tcp
%% data in all other cases.
transition({P, Sock, Bin}, #transport{socket = Sock,
                                      ssl = B,
                                      frag = Frag}
                           = S)
  when P == ssl, true == B;
       P == tcp ->
    recv(acc(Frag, Bin), S);

%% Capabilities exchange has decided on whether or not to run over TLS.
transition({diameter, {tls, Ref, Type, B}}, #transport{parent = Pid}
                                            = S) ->
    true = is_boolean(B),  %% assert
    #transport{}
        = NS
        = tls_handshake(Type, B, S),
    Pid ! {diameter, {tls, Ref}},
    NS#transport{ssl = B};

%% Connection closed by the peer. Until capabilities exchange has
%% decided on TLS, ssl may still be a list of ssl options on a plain
%% TCP socket, so compare with true: not/1 would fail the guard on a
%% list and turn an ordinary disconnect into a function_clause crash.
transition({C, Sock}, #transport{socket = Sock,
                                 ssl = B})
  when C == tcp_closed, true /= B;
       C == ssl_closed, true == B ->
    stop;

%% Socket error, matched on the same terms as a close.
transition({E, Sock, _Reason} = T, #transport{socket = Sock,
                                              ssl = B}
                                   = S)
  when E == tcp_error, true /= B;
       E == ssl_error, true == B ->
    ?ERROR({T,S});

%% Outgoing message.
transition({diameter, {send, Msg}}, #transport{} = S) ->
    message(send, Msg, S);

%% Monitor has sent an outgoing message.
transition(Msg, S)
  when is_record(Msg, diameter_packet);
       is_binary(Msg) ->
    message(ack, Msg, S);

%% Deferred actions from a message_cb.
transition({actions, Dir, Acts}, S) ->
    setopts(actions(Acts, Dir, S));

%% Request to close the transport connection.
transition({diameter, {close, Pid}}, #transport{parent = Pid,
                                                socket = Sock,
                                                module = M}) ->
    M:close(Sock),
    stop;

%% Timeout for reception of outstanding packets.
transition({timeout, TRef, flush}, #transport{tref = TRef} = S) ->
    flush(S#transport{tref = false});

%% Request for the local port number.
transition({resolve_port, Pid}, #transport{socket = Sock,
                                           module = M})
  when is_pid(Pid) ->
    Pid ! portnr(M, Sock),
    ok;

%% Parent process has died: call the monitor to not close the socket
%% during an ongoing send, but don't let it take forever. The exit
%% raised by gen_server:call/3 on timeout (or on a monitor that's
%% already gone) is caught, so that the monitor is killed instead of
%% the transport crashing.
transition({'DOWN', _, process, Pid, _}, #transport{parent = Pid,
                                                    send = MPid}) ->
    case MPid of
        false ->
            ok;
        _ ->
            Reply = try
                        gen_server:call(MPid, {stop, self()}, 1000)
                    catch
                        exit: Reason -> {error, Reason}
                    end,
            ok == Reply orelse exit(MPid, {shutdown, parent})
    end,
    stop;

%% Monitor process has died.
transition({'DOWN', _, process, MPid, _}, #transport{send = MPid})
  when is_pid(MPid) ->
    stop.
686
687%% Crash on anything unexpected.
688
689%% tls_handshake/3
690%%
691%% In the case that no tls message is received (eg. the service hasn't
692%% been configured to advertise TLS support) we will simply never ask
693%% for another TCP message, which will force the watchdog to
694%% eventually take us down.
695
%% TLS has already been established with the connection.
tls_handshake(_, _, #transport{ssl = true} = S) ->
    S;

%% Capabilities exchange negotiated TLS but transport was not
%% configured with an options list.
tls_handshake(_, true, #transport{ssl = false}) ->
    ?ERROR(no_ssl_options);

%% Capabilities exchange negotiated TLS: upgrade the connection using
%% the configured ssl options (ssl is a list here, the boolean cases
%% having been handled by the clauses above).
tls_handshake(Type, true, #transport{socket = Sock,
                                     module = M,
                                     ssl = Opts,
                                     send = MPid}
                          = S) ->
    %% Run ssl over the existing gen_tcp-like socket.
    {ok, SSock} = tls(Type, Sock, [{cb_info, ?TCP_CB(M)} | Opts]),
    %% Register the ssl socket under the reference stored by i/1, making
    %% it visible to ports/0,1.
    %% NOTE(review): the info/1 key still refers to the TCP socket
    %% published before the upgrade — confirm that this is intended.
    Ref = getr(?REF_KEY),
    true = diameter_reg:add_new({?MODULE, Type, {Ref, SSock}}),
    false == MPid orelse (MPid ! {tls, SSock}), %% tell the sender process
    S#transport{socket = SSock,
                module = ssl};

%% Capabilities exchange has not negotiated TLS.
tls_handshake(_, false, S) ->
    S.
721
%% Upgrade a connected socket to TLS: server side when accepting, client
%% side when connecting.
tls(accept, Socket, SslOpts) ->
    ssl:handshake(Socket, SslOpts);  %% assume no handshake option
tls(connect, Socket, SslOpts) ->
    ssl:connect(Socket, SslOpts).
726
727%% recv/2
728%%
%% Reassemble fragmented messages and extract multiple messages sent
%% using Nagle.

%% Receive packets until a full message is received.
733
recv({Msg, Rest}, S) ->  %% have a complete message ...
    S1 = message(recv, Msg, S),
    recv(acc(Rest), S1);

recv(Frag, #transport{recv = Recv,
                      socket = Sock,
                      module = M}
           = S) ->       %% or not
    %% Only ask for more data if the message_cb hasn't said otherwise.
    Recv andalso setopts(M, Sock),
    NewS = S#transport{frag = Frag,
                       flush = false,
                       active = Recv},
    start_fragment_timer(NewS).
745
746%% acc/2
747
%% The message length is already known: append the new bytes and see
%% whether the message is complete.
acc([Len | Prev], New) ->
    acc1(Len, <<Prev/binary, New/binary>>);

%% The length isn't known yet: append and try to read it.
acc(Prev, New) ->
    acc(<<Prev/binary, New/binary>>).

%% acc1/2

%% Split off a message for which all bytes have arrived, returning
%% {Msg, Rest} ...
acc1(Len, Bin)
  when byte_size(Bin) >= Len ->
    split_binary(Bin, Len);

%% ... or remember the length while waiting for more packets.
acc1(Len, Bin) ->
    [Len | Bin].

%% acc/1

%% Don't match on Bin since this results in it being copied at the
%% next append according to the Efficiency Guide. This is also the
%% reason that the Len is extracted and maintained when accumulating
%% messages. The simplest implementation is just to accumulate a
%% binary and match <<_, Len:24, _/binary>> each time the length is
%% required, but the performance of this decays quadratically with the
%% message length, since the binary is then copied with each append of
%% additional bytes from gen_tcp.

acc(Bin)
  when byte_size(Bin) >= 4 ->
    {Head, _} = split_binary(Bin, 4),
    [_, L2, L1, L0] = binary_to_list(Head),
    case (L2 bsl 16) bor (L1 bsl 8) bor L0 of
        Len when Len < 20 ->
            %% Message length isn't sufficient for a Diameter Header.
            %% Chances are things will go south from here but if we're
            %% lucky then the bytes we have extend to an intended
            %% message boundary and we can recover by simply receiving
            %% them. Make it so.
            {Bin, <<>>};
        Len ->
            acc1(Len, Bin)
    end;

%% Not even 4 bytes yet.
acc(Bin) ->
    Bin.

%% bin/1
%%
%% The bytes of a fragment, without any recorded length.
bin(Frag) ->
    case Frag of
        [_Len | Bytes] -> Bytes;
        Bytes -> Bytes
    end.
805
806%% flush/1
807
808%% An erroneously large message length may leave us with a fragment
809%% that lingers if the peer doesn't have anything more to send. Start
810%% a timer to force reception if an incoming message doesn't arrive
%% first. This won't, however, stop a peer from sending a large bogus
%% value and following it up, but such a state of affairs can only go
%% on for so long since an unanswered DWR will eventually be the result.
814%%
815%% An erroneously small message length causes problems as well but
816%% since all messages with length problems are discarded this should
817%% also eventually lead to watchdog failover.
818
%% Handle expiry of the fragment timer.
flush(#transport{frag = Frag, flush = Flush} = S) ->
    if Frag == <<>> ->
            %% Nothing to flush.
            S;
       Flush == false ->
            %% Bytes have arrived since the timer was started: give the
            %% peer another timeout period to complete the message.
            start_fragment_timer(S#transport{flush = true});
       true ->
            %% Nothing received for a whole period: deliver the
            %% fragment as is and start afresh.
            message(recv, bin(Frag), S#transport{frag = <<>>})
    end.
830
831%% start_fragment_timer/1
832%%
833%% Start a timer only if there's none running and a message to flush.
834
start_fragment_timer(#transport{frag = B, tref = TRef} = S)
  when B == <<>>;
       TRef /= false ->
    S;

%% An infinite fragment timeout means never to force reception. It's
%% accepted as an option value, but erlang:start_timer/3 doesn't accept
%% infinity as a time, so don't start a timer at all.
start_fragment_timer(#transport{timeout = infinity} = S) ->
    S;

start_fragment_timer(#transport{timeout = Tmo} = S) ->
    S#transport{tref = erlang:start_timer(Tmo, self(), flush)}.
842
843%% accept/2
844
%% Accept a connection on a listening socket. With ssl, accept the
%% transport connection and then perform the TLS handshake.
accept(ssl, LSock) ->
    Accepted = ssl:transport_accept(LSock),
    case Accepted of
        {error, _} ->
            Accepted;
        {ok, Sock} ->
            ssl:handshake(Sock)
    end;
accept(Mod, LSock) ->
    Mod:accept(LSock).
854
855%% connect/4
856
%% Connect to the remote endpoint with the gen_tcp-like module.
connect(Mod, RAddr, RPort, SockOpts) ->
    Mod:connect(RAddr, RPort, SockOpts).
859
860%% send/2
861
%% A transport without a sending process writes to the socket itself
%% and acknowledges through the message_cb ...
send(Msg, #transport{send = false, socket = Sock, module = M} = S) ->
    send1(M, Sock, Msg),
    message(ack, Msg, S);

%% ... otherwise it hands the message to the monitor process, to avoid
%% deadlock if both the receiver and the peer were to block in send.
send(Msg, #transport{send = Sender} = S) ->
    Sender ! Msg,
    S;

%% The monitor process sending: pass the message back to the transport
%% if acknowledgements are wanted.
send(Msg, #monitor{socket = Sock, module = M, transport = TPid, ack = Ack}) ->
    send1(M, Sock, Msg),
    Ack andalso (TPid ! Msg).
875
876%% send1/3
877
%% Send the bytes of a packet record, or a binary, exiting on failure.
send1(Mod, Sock, #diameter_packet{bin = Bin}) ->
    send1(Mod, Sock, Bin);

send1(Mod, Sock, Bin) ->
    case send(Mod, Sock, Bin) of
        {error, Reason} ->
            x({send, Reason});
        ok ->
            ok
    end.
888
889%% send/3
890
%% Send on a socket of the given module.
send(Mod, Sock, Bin) ->
    case Mod of
        ssl -> ssl:send(Sock, Bin);
        gen_tcp -> gen_tcp:send(Sock, Bin);
        _ -> Mod:send(Sock, Bin)
    end.
897
898%% setopts/3
899
%% Set socket options: gen_tcp sockets through inet.
setopts(Mod, Sock, Opts) ->
    case Mod of
        ssl -> ssl:setopts(Sock, Opts);
        gen_tcp -> inet:setopts(Sock, Opts);
        _ -> Mod:setopts(Sock, Opts)
    end.
906
907%% setopts/1
908
%% Reactivate the socket when messages should be read but it isn't
%% currently active; otherwise leave the state as is.
setopts(#transport{recv = true,
                   active = false,
                   socket = Sock,
                   module = M}
        = S) ->
    setopts(M, Sock),
    S#transport{active = true};

setopts(S) ->
    S.
920
921%% setopts/2
922
%% Ask for one more message, exiting if that fails (possibly on peer
%% disconnect).
setopts(M, Sock) ->
    Result = setopts(M, Sock, [{active, once}]),
    Result == ok orelse x({setopts, Sock, M, Result}),
    ok.
928
929%% portnr/2
930
portnr(gen_tcp, Sock) ->
    inet:port(Sock);
portnr(M, Sock) ->
    case M:sockname(Sock) of
        {error, _} = No ->
            No;
        {ok, {_Addr, PortNr}} ->
            {ok, PortNr}
    end.

%% sockname/2

sockname(Mod, Sock) ->
    case Mod of
        gen_tcp -> inet:sockname(Sock);
        _ -> Mod:sockname(Sock)
    end.

%% peername/2

peername(Mod, Sock) ->
    case Mod of
        gen_tcp -> inet:peername(Sock);
        _ -> Mod:peername(Sock)
    end.

%% getstat/2

getstat(Mod, Sock) ->
    case Mod of
        gen_tcp -> inet:getstat(Sock);
        _ -> Mod:getstat(Sock)
    end.
961%% Note that ssl:getstat/1 doesn't yet exist in R15B01.
962
963%% A message_cb is invoked whenever a message is sent or received, or
964%% to provide acknowledgement of a completed send or discarded
965%% request. Ignoring possible extra arguments, calls are of the
966%% following form.
967%%
968%% cb(recv, Msg)          Receive a message into diameter?
969%% cb(send, Msg)          Send a message on the socket?
970%% cb(ack,  Msg)          Acknowledgement of a completed send.
971%% cb(ack,  false)        Acknowledgement of a discarded request.
972%%
973%% Msg will be binary() in a recv callback, but can be a
974%% diameter_packet record in a send/ack callback if a recv/send
975%% callback returns a record. Callbacks return a list of the following
976%% form.
977%%
978%%   [boolean() | send | recv | binary() | #diameter_packet{}]
979%%
980%% The atoms are meaningless by themselves, but say whether subsequent
981%% messages are to be sent or received. A boolean says whether or not
982%% to continue reading on the socket. Messages can be received even
983%% after false is returned if these arrived in the same packet. A
984%% leading recv or send is implicit on the corresponding callbacks. A
985%% new callback can be returned as the tail of a returned list: any
986%% value not of the aforementioned list type is interpreted as a
987%% callback.
988
989%% message/3
990
%% Sending false means that a request has been discarded: pass it on
%% as an acknowledgement, cb(ack, false).
message(send, false, S) ->
    message(ack, false, S);

%% Without a message_cb there's nothing to acknowledge.
message(ack, _, #transport{message_cb = false} = S) ->
    S;

%% Ask the callback (or the default, with no callback) what to do, then
%% act on the answer.
message(Dir, Msg, #transport{message_cb = CB} = S) ->
    Actions = cb(CB, Dir, Msg),
    setopts(actions(Actions, Dir, S)).
999
1000%% actions/3
1001
%% No more actions.
actions([], _Dir, S) ->
    S;

%% A boolean says whether or not to keep reading from the socket.
actions([Recv | Rest], Dir, S)
  when is_boolean(Recv) ->
    actions(Rest, Dir, S#transport{recv = Recv});

%% send or recv sets the direction of the messages that follow.
actions([NewDir | Rest], _Dir, S)
  when NewDir == send;
       NewDir == recv ->
    actions(Rest, NewDir, S);

%% A message to write to the socket ...
actions([Msg | Rest], send, S)
  when is_binary(Msg);
       is_record(Msg, diameter_packet) ->
    actions(Rest, send, send(Msg, S));

%% ... or to pass up to the parent.
actions([Msg | Rest], recv, #transport{parent = Pid} = S)
  when is_binary(Msg);
       is_record(Msg, diameter_packet) ->
    diameter_peer:recv(Pid, Msg),
    actions(Rest, recv, S);

%% Actions to perform after a delay, delivered to ourselves as an
%% {actions, Dir, Acts} message.
actions([{defer, Tmo, Acts} | Rest], Dir, S) ->
    erlang:send_after(Tmo, self(), {actions, Dir, Acts}),
    actions(Rest, Dir, S);

%% Anything else is a new callback.
actions(NewCB, _Dir, S) ->
    S#transport{message_cb = NewCB}.
1031
1032%% cb/3
1033
%% Evaluate the message_cb, or with none just pass the message through.
cb(CB, Dir, Msg) ->
    case CB of
        false ->
            [Msg];
        _ ->
            diameter_lib:eval([CB, Dir, Msg])
    end.
1039