1%%
2%% %CopyrightBegin%
3%%
4%% Copyright Ericsson AB 2010-2020. All Rights Reserved.
5%%
6%% Licensed under the Apache License, Version 2.0 (the "License");
7%% you may not use this file except in compliance with the License.
8%% You may obtain a copy of the License at
9%%
10%%     http://www.apache.org/licenses/LICENSE-2.0
11%%
12%% Unless required by applicable law or agreed to in writing, software
13%% distributed under the License is distributed on an "AS IS" BASIS,
14%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15%% See the License for the specific language governing permissions and
16%% limitations under the License.
17%%
18%% %CopyrightEnd%
19%%
20
21%%
22%% This module implements (as a process) the RFC 3588/6733 Peer State
23%% Machine modulo the necessity of adapting the peer election to the
24%% fact that we don't know the identity of a peer until we've received
25%% a CER/CEA from it.
26%%
27
28-module(diameter_peer_fsm).
29-behaviour(gen_server).
30
31%% Interface towards diameter_watchdog.
32-export([start/3,
33         result_code/2]).
34
35%% Interface towards diameter.
36-export([find/1]).
37
38%% gen_server callbacks
39-export([init/1,
40         handle_call/3,
41         handle_cast/2,
42         handle_info/2,
43         terminate/2,
44         code_change/3]).
45
46%% diameter_peer_fsm_sup callback
47-export([start_link/1]).
48
49%% internal callbacks
50-export([match/1]).
51
52-include_lib("diameter/include/diameter.hrl").
53-include("diameter_internal.hrl").
54
%% Values of Disconnect-Cause in DPR.
-define(GOAWAY, 2).  %% DO_NOT_WANT_TO_TALK_TO_YOU
-define(BUSY,   1).  %% BUSY
-define(REBOOT, 0).  %% REBOOTING

%% Values of Inband-Security-Id.
-define(NO_INBAND_SECURITY, 0).
-define(TLS, 1).

%% Note that a common dictionary hrl is purposely not included since
%% the common dictionary is an argument to start/3.

%% Keys in process dictionary. Each is stored wrapped as {?MODULE, Key}
%% by putr/2, getr/1 and eraser/1.
-define(CB_KEY, cb).         %% capabilities callback
-define(DPR_KEY, dpr).       %% disconnect callback
-define(DPA_KEY, dpa).       %% timeout for incoming DPA, or shutdown after
                             %% outgoing DPA
-define(REF_KEY, ref).       %% transport_ref()
-define(Q_KEY, q).           %% transport start queue
-define(START_KEY, start).   %% start of connected transport
-define(SEQUENCE_KEY, mask). %% mask for sequence numbers
-define(RESTRICT_KEY, restrict). %% nodes for connection check

%% The default sequence mask.
-define(NOMASK, {0,32}).

%% A 2xxx series Result-Code. Not necessarily 2001.
-define(IS_SUCCESS(N), 2 == (N) div 1000).

%% Guards.
%% IS_UINT32: a non-negative integer with no bits set at or above bit 32.
-define(IS_UINT32(N), (is_integer(N) andalso 0 =< N andalso 0 == N bsr 32)).
-define(IS_TIMEOUT(N), ?IS_UINT32(N)).
%% IS_CAUSE: a Disconnect-Cause given either numerically or as an atom.
%% It expands to several ';'-separated guard sequences, so it can only
%% be used as a complete guard.
-define(IS_CAUSE(N), N == ?REBOOT; N == rebooting;
                     N == ?GOAWAY; N == goaway;
                     N == ?BUSY;   N == busy).

%% RFC 6733:
%%
%%   Timeout        An application-defined timer has expired while waiting
%%                  for some event.
%%

%% Default timeout for reception of CER/CEA.
-define(CAPX_TIMEOUT, 10000).

%% Default timeout for DPA to be received in response to an outgoing
%% DPR. A bit short but the timeout used to be hardcoded. (So it could
%% be worse.)
-define(DPA_TIMEOUT, 1000).

%% Default timeout for the connection to be closed by the peer
%% following an outgoing DPA in response to an incoming DPR. It's the
%% recipient of DPA that should close the connection according to the
%% RFC.
-define(DPR_TIMEOUT, 5000).
110
-type uint32() :: diameter:'Unsigned32'().

%% Process state. Further per-connection data is kept in the process
%% dictionary under the keys defined above.
-record(state,
        {state %% of RFC 3588 Peer State Machine
              :: {'Wait-Conn-Ack', uint32()}       %% capx timeout
               | recv_CER
               | {'Wait-CEA', uint32(), uint32()}  %% Hid/Eid of sent CER
               | 'Open',
         %% accept or connect, from the type passed to start/3; connect
         %% becomes {connect, Remote} when the transport connects.
         %% NOTE(review): Remote is whatever the transport reports in
         %% its connected message; confirm it's always a reference().
         mode :: accept | connect | {connect, reference()},
         parent       :: pid(),     %% watchdog process
         transport    :: pid(),     %% transport process
         dictionary   :: module(),  %% common dictionary
         service      :: #diameter_service{} | undefined,
         dpr = false  :: false
                       | true  %% DPR received, DPA sent
                       | {boolean(), uint32(), uint32()},
                       %% hop by hop and end to end identifiers in
                       %% outgoing DPR; boolean says whether or not
                       %% the request was sent explicitly with
                       %% diameter:call/4.
         %% options passed to diameter_codec encode/decode
         codec :: #{decode_format := diameter:decode_format(),
                    string_decode := boolean(),
                    strict_mbit := boolean(),
                    rfc := 3588 | 6733,
                    ordered_encode := false},
         %% strict_capx option: when false, anything but the expected
         %% CER/CEA is ignored during capabilities exchange
         strict :: boolean(),
         %% true once the transport has asked for acknowledgements; it
         %% is then told (with false) of requests that won't be answered
         ack = false :: boolean(),
         %% length_errors option: what to do with a message whose header
         %% is truncated or whose length doesn't check out
         length_errors :: exit | handle | discard,
         %% size bound on incoming messages, beyond which they're discarded
         incoming_maxlen :: integer() | infinity}).
140
141%% There are non-3588 states possible as a consequence of 5.6.1 of the
%% standard and the corresponding problem for incoming CEAs: we don't
143%% know who we're talking to until either a CER or CEA has been
144%% received. The CEA problem in particular makes it impossible to
145%% follow the state machine exactly as documented in 3588: there can
146%% be no election until the CEA arrives and we have an Origin-Host to
147%% elect.
148
149%%
150%% Once upon a time start/2 started a process akin to that started by
151%% start/3 below, which in turn started a watchdog/transport process
152%% with the result that the watchdog could send DWR/DWA regardless of
153%% whether or not the corresponding Peer State Machine was in its open
154%% state; that is, before capabilities exchange had taken place. This
%% is not what RFCs 3588 and 3539 say (albeit not very clearly).
156%% Watchdog messages are only exchanged on *open* connections, so the
157%% 3539 state machine is more naturally placed on top of the 3588 Peer
158%% State Machine rather than closer to the transport. This is what we
159%% now do below: connect/accept call diameter_watchdog and return the
160%% pid of the watchdog process, and the watchdog in turn calls start/3
161%% below to start the process implementing the Peer State Machine.
162%%
163
164%% ---------------------------------------------------------------------------
165%% # start/3
166%% ---------------------------------------------------------------------------
167
-spec start(T, [Opt], {map(), [node()], module(), #diameter_service{}})
   -> {reference(), pid()}
 when T   :: {connect|accept, diameter:transport_ref()},
      Opt :: diameter:transport_opt().

%% Start a peer_fsm process under diameter_peer_fsm_sup on behalf of
%% the calling (watchdog) process. Returns a monitor reference on the
%% new process together with its pid.
%%
%% The new process blocks until it receives the Ack reference. Ack is
%% sent in an after clause, once the monitor is in place, so that the
%% caller is sure to see the exit reason of the new process.
%%
%% NOTE(review): a comment here used to say that this function checks
%% that the service's list of applications is still non-empty after
%% being constrained by the transport. No such check is made in this
%% function; confirm where (if anywhere) it happens.

start({_,_} = Type, Opts, S) ->
    AckRef = make_ref(),
    {ok, Child} = diameter_peer_fsm_sup:start_child({AckRef, self(), Type, Opts, S}),
    try
        MonRef = erlang:monitor(process, Child),
        {MonRef, Child}
    after
        Child ! AckRef
    end.
187
%% start_link/1
%%
%% diameter_peer_fsm_sup callback. Spawns init/1 with the server spawn
%% options from diameter_lib, waiting indefinitely for the
%% acknowledgement and asserting that it's {ok, Pid}.
start_link(T) ->
    SpawnOpts = diameter_lib:spawn_opts(server, []),
    {ok, _Pid} = Ret = proc_lib:start_link(?MODULE, init, [T], infinity, SpawnOpts),
    Ret.
194
%% find/1
%%
%% Identify both pids of a peer_fsm/transport pair given either of
%% them. The lookup uses the {?MODULE, PeerFsmPid, TransportPid}
%% registration made in i/1. Pid is first looked up as a transport
%% pid, then as a peer_fsm pid. Returns {PeerFsmPid, TransportPid},
%% or false if neither lookup yields exactly one registration.

find(Pid) ->
    AsTransport = {?MODULE, '_', Pid},
    AsPeerFsm = {?MODULE, Pid, '_'},
    findl([AsTransport, AsPeerFsm]).

%% findl/1
%%
%% Try each diameter_reg pattern in turn. A pattern succeeds if it
%% matches exactly one registration made by the pid in the key's
%% second position. This presumes diameter_reg:match/1 returns
%% [{Key, RegisteringPid}]; confirm. Any error, including a badmatch,
%% moves on to the next pattern.

findl([Pat | Pats]) ->
    try
        [{{_, FsmPid, TPid}, FsmPid}] = diameter_reg:match(Pat),
        {FsmPid, TPid}
    catch
        error: _ ->
            findl(Pats)
    end;
findl([]) ->
    false.
213
214%% ---------------------------------------------------------------------------
215%% ---------------------------------------------------------------------------
216
%% init/1
%%
%% proc_lib entry point. The start is acknowledged before the state is
%% built. This ordering is required: i/1 blocks until start/3 has
%% monitored this process and sent its Ack, and start/3 only does that
%% after diameter_peer_fsm_sup:start_child/1 has returned.

init(T) ->
    proc_lib:init_ack({ok, self()}),
    State = i(T),
    gen_server:enter_loop(?MODULE, [], State).
222
%% i/1
%%
%% Build the initial state from the tuple constructed by start/3:
%%
%%   - the Ack reference and pid of the starting watchdog;
%%   - the {accept|connect, Ref} type;
%%   - the transport options;
%%   - the {SvcOpts, Nodes, Dict0, Svc} service configuration.
%%
%% Steps, in order:
%%
%%   1. Monitor the watchdog and wait for its Ack before doing anything
%%      else, then register Ref with diameter_stats.
%%   2. Store per-connection configuration in the process dictionary.
%%   3. Start the transport, passing it the options that remain after
%%      the capabilities/disconnect callbacks have been split out.
%%   4. Register the peer_fsm/transport pair so that find/1 can
%%      discover it.
i({Ack, WPid, {M, Ref} = T, Opts, {SvcOpts, Nodes, Dict0, Svc}}) ->
    erlang:monitor(process, WPid),
    wait(Ack, WPid),
    diameter_stats:reg(Ref),

    #{sequence := Mask, incoming_maxlen := Maxlen}
        = SvcOpts,

    {[Cs,Ds], Rest} = proplists:split(Opts, [capabilities_cb, disconnect_cb]),
    putr(?CB_KEY, {Ref, [F || {_,F} <- Cs]}),
    putr(?DPR_KEY, [F || {_, F} <- Ds]),
    putr(?REF_KEY, Ref),
    putr(?SEQUENCE_KEY, Mask),
    putr(?RESTRICT_KEY, Nodes),
    %% {DprTimeout, DpaTimeout}
    putr(?DPA_KEY, {proplists:get_value(dpr_timeout, Opts, ?DPR_TIMEOUT),
                    proplists:get_value(dpa_timeout, Opts, ?DPA_TIMEOUT)}),

    Tmo = proplists:get_value(capx_timeout, Opts, ?CAPX_TIMEOUT),
    Strict = proplists:get_value(strict_capx, Opts, true),
    LengthErr = proplists:get_value(length_errors, Opts, exit),

    {TPid, Addrs} = start_transport(T, Rest, Svc),

    diameter_reg:add({?MODULE, self(), TPid}),  %% lets pairs be discovered

    #state{state = {'Wait-Conn-Ack', Tmo},
           parent = WPid,
           transport = TPid,
           dictionary = Dict0,
           mode = M,
           service = svc(Svc, Addrs),
           length_errors = LengthErr,
           strict = Strict,
           incoming_maxlen = Maxlen,
           %% codec options are taken from the service options, with
           %% ordered_encode forced to false
           codec = maps:with([decode_format,
                              string_decode,
                              strict_mbit,
                              rfc,
                              ordered_encode],
                             SvcOpts#{ordered_encode => false})}.
263%% The transport returns its local ip addresses so that different
264%% transports on the same service can use different local addresses.
265%% The local addresses are put into Host-IP-Address avps here when
266%% sending capabilities exchange messages.
267
%% wait/2
%%
%% Block until the caller of start/3 has monitored us and sent Ack.
%% This avoids a race with our death, since the exit reason is used in
%% diameter_service. If the caller dies first, exit with
%% {shutdown, Down}, Down being its 'DOWN' message.
wait(Ack, WPid) ->
    receive
        {'DOWN', _MRef, process, WPid, _Info} = Down ->
            x(Down);
        Ack ->
            ok
    end.
277
%% x/1
%%
%% Terminate the calling process with reason {shutdown, Reason}.
x(Reason) ->
    ShutdownReason = {shutdown, Reason},
    exit(ShutdownReason).
280
%% start_transport/3
%%
%% Start the initial transport. The service's configured
%% Host-IP-Address list (possibly []) is passed along so that it can be
%% restored when an alternate transport is started (start_next/1).
start_transport(T, Opts, #diameter_service{capabilities = LCaps} = Svc) ->
    start_transport(LCaps#diameter_caps.host_ip_address, {T, Opts, Svc}).

%% start_transport/2
%%
%% Start a transport with diameter_peer:start/1.
%%
%% On success: monitor the transport, arm its connection timer, save
%% the alternates (q_next/4) and return {TPid, LocalAddrs}.
%%
%% On failure: exit with {shutdown, {no_connection, Reason}}.
start_transport(Addrs0, T) ->
    Result = diameter_peer:start(T),
    case Result of
        {error, Reason} ->
            x({no_connection, Reason});
        {TPid, LAddrs, Tmo, Data} ->
            erlang:monitor(process, TPid),
            q_next(TPid, Addrs0, Tmo, Data),
            {TPid, LAddrs}
    end.
294
%% svc/2
%%
%% If no Host-IP-Address is configured on the service, use the local
%% addresses Addrs reported by the transport. This lets different
%% transports on the same service use different local addresses.
%% Configured addresses are kept as they are.
svc(#diameter_service{capabilities = LCaps0} = Svc, Addrs) ->
    #diameter_caps{host_ip_address = Configured} = LCaps0,
    case Configured of
        [_|_] ->
            Svc;
        [] ->
            Svc#diameter_service{capabilities =
                                     LCaps0#diameter_caps{host_ip_address = Addrs}}
    end.
305
%% readdr/2
%%
%% Unconditionally replace the service's Host-IP-Address list with Addrs.
readdr(#diameter_service{capabilities = Caps} = Svc, Addrs) ->
    Svc#diameter_service{capabilities = Caps#diameter_caps{host_ip_address = Addrs}}.
309
%% q_next/4
%%
%% Arm the connection timer for TPid (no timer if Tmo is infinity).
%% Then save what's needed to start an alternate should the connection
%% not be established in time:
%%
%%   - the originally configured addresses;
%%   - the timeout;
%%   - the 4-tuple Data returned from diameter_peer:start/1, which
%%     identifies the transport module/config used to start TPid as
%%     well as any alternates.
q_next(TPid, Addrs0, Tmo, {_,_,_,_} = Data) ->
    Timeout = {connection_timeout, TPid},
    send_after(Tmo, Timeout),
    putr(?Q_KEY, {Addrs0, Tmo, Data}).
317
%% keep_transport/1
%%
%% The connection has been established: drop the queue of alternates
%% and retain the started pid/module/config under ?START_KEY.
%%
%% This is part of the interface defined by this module: it lets the
%% transport pid be found when constructing service_info, in order to
%% extract further information from it.
keep_transport(TPid) ->
    {_Addrs0, _Tmo, Data} = eraser(?Q_KEY),
    {{_,_,_} = Started, _, _, _} = Data,
    putr(?START_KEY, {TPid, Started}).
326
%% send_after/2
%%
%% Send Msg to ourselves after Tmo milliseconds, returning the timer
%% reference. If Tmo is infinity, nothing is sent and ok is returned.
send_after(Tmo, Msg) ->
    case Tmo of
        infinity ->
            ok;
        _ ->
            erlang:send_after(Tmo, self(), Msg)
    end.
331
%% handle_call/3
%%
%% No calls are supported: every request is answered with nok.
handle_call(_Req, _From, State) ->
    Reply = nok,
    {reply, Reply, State}.

%% handle_cast/2
%%
%% Casts are ignored.
handle_cast(_Msg, State) ->
    {noreply, State}.
341
%% handle_info/2
%%
%% Every message is dispatched to transition/2, whose result decides
%% what happens next:
%%
%%   ok             - no change;
%%   #state{}       - continue with the new state, logging any change
%%                    of peer state;
%%   {stop, Reason} - stop with {shutdown, Reason};
%%   stop           - stop with {shutdown, T}, T being the message.
%%
%% Two exceptions also result in a {shutdown, _} stop:
%%
%%   - an encode failure, signalled by diameter_codec as
%%     exit({diameter_codec, encode, E}), which is also counted as a
%%     send error;
%%   - a throw of the form {?MODULE, Tag, Reason}. The form of this
%%     throw is historical. It's significant that it's not a 2-tuple,
%%     as in ?FAILURE(Reason), since those are caught elsewhere.
%%
%% Note that there's no guarantee that the service and transport
%% capabilities are good enough to build a CER/CEA that can be
%% successfully encoded. It's not checked at diameter:add_transport/2
%% since this can be called before creating the service.

handle_info(T, #state{} = State) ->
    try transition(T, State) of
        ok ->
            {noreply, State};
        #state{state = X} = S ->
            ?LOGC(X /= State#state.state, transition, X),
            {noreply, S};
        {stop, Reason} ->
            ?LOG(stop, Reason),
            {stop, {shutdown, Reason}, State};
        stop ->
            ?LOG(stop, truncate(T)),
            {stop, {shutdown, T}, State}
    catch
        %% E must be a fresh variable: matching on T, which is bound to
        %% the incoming message, would make this clause unreachable and
        %% let encode failures crash the process.
        %% NOTE(review): confirm diameter_traffic:incr_error/4 accepts
        %% the detail E carried by diameter_codec's encode exit.
        exit: {diameter_codec, encode, E} = Reason ->
            incr_error(send, E, State#state.dictionary),
            ?LOG(stop, Reason),
            {stop, {shutdown, Reason}, State};
        {?MODULE, Tag, Reason}  ->
            ?LOG(stop, Tag),
            {stop, {shutdown, Reason}, State}
    end.
374
%% terminate/2
%%
%% Nothing to clean up.
terminate(_Reason, _State) ->
    ok.

%% code_change/3
%%
%% The state is carried over unchanged.
code_change(_OldVsn, State, _Extra) ->
    {ok, State}.
384
385%% ---------------------------------------------------------------------------
386%% ---------------------------------------------------------------------------
387
%% truncate/1
%%
%% Shorten a process 'DOWN' message to {'DOWN', Pid} for logging; any
%% other term is returned as is.
truncate(Msg) ->
    case Msg of
        {'DOWN', _MRef, process, Pid, _Info} ->
            {'DOWN', Pid};
        _ ->
            Msg
    end.
392
%% putr/2, getr/1, eraser/1
%%
%% Process dictionary accessors for this module's keys, which are
%% stored wrapped as {?MODULE, Key}.
putr(Key, Val) ->
    K = {?MODULE, Key},
    put(K, Val).

getr(Key) ->
    K = {?MODULE, Key},
    get(K).

eraser(Key) ->
    K = {?MODULE, Key},
    erase(K).
401
%% transition/2
%%
%% Handle a message received by the process, as dispatched by
%% handle_info/2. Returns one of:
%%
%%   - ok (no state change);
%%   - a new #state{};
%%   - {stop, Reason};
%%   - stop.
%%
%% Clause order matters: more specific forms of a message precede the
%% catch-alls.

%% Connection to peer.
transition({diameter, {TPid, connected, Remote}},
           #state{transport = TPid,
                  state = PS,
                  mode = M}
           = S) ->
    {'Wait-Conn-Ack', _} = PS,  %% assert
    connect = M,                %%
    keep_transport(TPid),
    send_CER(S#state{mode = {M, Remote}});

%% As above, with the transport also reporting its local addresses:
%% these become Host-IP-Address unless some are configured (svc/2).
transition({diameter, {TPid, connected, Remote, LAddrs}},
           #state{transport = TPid,
                  service = Svc}
           = S) ->
    transition({diameter, {TPid, connected, Remote}},
               S#state{service = svc(Svc, LAddrs)});

%% Connection from peer.
transition({diameter, {TPid, connected}},
           #state{transport = TPid,
                  state = PS,
                  mode = M,
                  parent = Pid}
           = S) ->
    {'Wait-Conn-Ack', Tmo} = PS,  %% assert
    accept = M,                   %%
    keep_transport(TPid),
    Pid ! {accepted, self()},
    start_timer(Tmo, S#state{state = recv_CER});

%% Connection established after receiving a connection_timeout
%% message. This may be followed by an incoming message which arrived
%% before the transport was killed and this can't be distinguished
%% from one from the transport that's been started to replace it.
%% (Connected messages from the current transport are matched by the
%% clauses above, so only other transports reach this one.)
transition({diameter, T}, _)
  when tuple_size(T) < 5, connected == element(2,T) ->
    {stop, connection_timeout};

%% Connection has timed out: start an alternate.
transition({connection_timeout = T, TPid},
           #state{transport = TPid,
                  state = {'Wait-Conn-Ack', _}}
           = S) ->
    exit(TPid, {shutdown, T}),
    start_next(S);

%% Connect timeout after connection or alternate start: ignore.
transition({connection_timeout, _}, _) ->
    ok;

%% Requests for acknowledgements to the transport.
transition({diameter, ack}, S) ->
    S#state{ack = true};

%% Incoming message from the transport.
transition({diameter, {recv, Msg}}, S) ->
    incoming(recv(Msg, S), S);

%% Handler of an incoming request is telling of its existence.
transition({handler, Pid}, _) ->
    put_route(Pid),
    ok;

%% Timeout when still in the same state ...
transition({timeout = T, PS}, #state{state = PS}) ->
    {stop, {capx(PS), T}};

%% ... or not.
transition({timeout, _}, _) ->
    ok;

%% Outgoing message.
transition({send, Msg}, S) ->
    outgoing(Msg, S);
transition({send, Msg, Route}, S) ->
    route_outgoing(Route),
    outgoing(Msg, S);

%% Request for graceful shutdown at remove_transport, stop_service or
%% application shutdown.
transition({shutdown, Pid, Reason}, #state{parent = Pid, dpr = false} = S) ->
    dpr(Reason, S);
transition({shutdown, Pid, _}, #state{parent = Pid}) ->
    ok;

%% DPA reception has timed out, or peer has not closed the connection
%% as a result of outgoing DPA.
transition(dpa_timeout, _) ->
    stop;

%% Someone wants to know a resolved port: forward to the transport process.
transition({resolve_port, _Pid} = T, #state{transport = TPid}) ->
    TPid ! T,
    ok;

%% Parent has died.
transition({'DOWN', _, process, WPid, _},
           #state{parent = WPid}) ->
    stop;

%% Transport has died before connection timeout.
transition({'DOWN', _, process, TPid, _},
           #state{transport = TPid}
           = S) ->
    start_next(S#state{ack = false});

%% Transport has died after connection timeout, or handler process has
%% died. A monitor reference from erase_route/1 means a handler died
%% with an incoming request outstanding: tell the transport.
transition({'DOWN', _, process, Pid, _}, #state{transport = TPid}) ->
    is_reference(erase_route(Pid))
        andalso send(TPid, false),  %% answer not forthcoming
    ok;

%% State query.
transition({state, Pid}, #state{state = S, transport = TPid}) ->
    Pid ! {self(), [S, TPid]},
    ok.

%% Crash on anything unexpected: the resulting function_clause error
%% isn't caught by handle_info/2.
524
%% route_outgoing/1
%%
%% Pid: an outgoing answer. Remove the mapping made by put_route/1 for
%% the corresponding incoming request, and drop its monitor if there
%% was one.
%%
%% {Pid, Ref, Seqs}: an outgoing request. Monitor the requesting
%% process and map both Pid -> Seqs and Seqs -> {Pid, Ref, MRef}, so
%% that the requester can be found from the sequence numbers when the
%% answer is received (get_route/3).
route_outgoing(Pid)
  when is_pid(Pid) ->  %% answer
    case erase_route(Pid) of
        undefined ->
            true;
        MRef ->
            demonitor(MRef)
    end;

route_outgoing({Pid, Ref, Seqs}) ->  %% request
    MRef = monitor(process, Pid),
    put(Pid, Seqs),
    put(Seqs, {Pid, Ref, MRef}).
539
%% put_route/1
%%
%% A handler process for an incoming request has announced itself.
%% Monitor it and map its pid to the monitor reference. Returns the
%% previous value stored for the pid.
put_route(Pid) ->
    put(Pid, monitor(process, Pid)).
546
%% get_route/3
%%
%% The route to pass to the watchdog along with an incoming message:
%%
%%  - an answer: {Pid, Ref, self()} for the process that sent the
%%    corresponding request. The mapping and monitor made by
%%    route_outgoing/1 are removed. Returns false if the request is
%%    unknown.
%%  - CER or DPR: false, since these requests are answered here;
%%  - anything else: Ack, as given.

get_route(_, _, #diameter_packet{header = #diameter_header{is_request = false}}
                = Pkt) ->
    Key = diameter_codec:sequence_numbers(Pkt),
    case erase(Key) of
        undefined ->  %% request unknown
            false;
        {Pid, Ref, MRef} ->
            demonitor(MRef),
            erase(Pid),
            {Pid, Ref, self()}
    end;

get_route(_, 'CER', _) ->
    false;
get_route(_, 'DPR', _) ->
    false;

get_route(Ack, _, _) ->
    Ack.
571
%% erase_route/1
%%
%% Remove the process dictionary entry for Pid.
%%
%% If Pid maps to a 2-tuple of sequence numbers (an outgoing request),
%% also remove that tuple's entry and return its value.
%%
%% Otherwise, return whatever Pid mapped to: a monitor reference for
%% an incoming request, or undefined.
erase_route(Pid) ->
    Prev = erase(Pid),
    if is_tuple(Prev), tuple_size(Prev) == 2 ->
            erase(Prev);
       true ->
            Prev
    end.
581
%% capx/1
%%
%% The capabilities exchange message awaited in the given peer state,
%% used to describe a capx timeout.
capx({'Wait-CEA', _Hid, _Eid}) ->
    'CEA';
capx(recv_CER) ->
    'CER'.
588
%% start_next/1
%%
%% Start the alternate transport saved by q_next/4. The service's
%% originally configured Host-IP-Address list is restored first.
%% Returns the updated state, or stop if no queue entry exists
%% (keep_transport/1 erases it once connected).
start_next(#state{service = Svc0} = S) ->
    case getr(?Q_KEY) of
        undefined ->
            stop;
        {Addrs0, Tmo, Data} ->
            Svc = readdr(Svc0, Addrs0),
            {TPid, LAddrs} = start_transport(Addrs0, {Svc, Tmo, Data}),
            S#state{transport = TPid, service = svc(Svc, LAddrs)}
    end.
601
%% send_CER/1
%%
%% Called when the connecting side's transport has connected (peer
%% state 'Wait-Conn-Ack', mode {connect, Remote}).
%%
%%   1. Register the connection with req_send_CER/2, calling close/1
%%      if that fails.
%%   2. Build, encode, count and send a CER.
%%   3. Move to {'Wait-CEA', Hid, Eid}, recording the CER's identifiers
%%      so that handle/3 can match the CEA against them.
%%   4. Start a capx timer for that state.

send_CER(#state{state = {'Wait-Conn-Ack', Tmo},
                mode = {connect, Remote},
                service = #diameter_service{capabilities = LCaps},
                transport = TPid,
                dictionary = Dict,
                codec = Opts}
         = S) ->
    OH = LCaps#diameter_caps.origin_host,
    req_send_CER(OH, Remote)
        orelse
        close({already_connected, Remote, LCaps}),
    CER = build_CER(S),
    #diameter_packet{header = #diameter_header{end_to_end_id = Eid,
                                               hop_by_hop_id = Hid}}
        = Pkt
        = encode(CER, Opts, Dict),
    incr(send, Pkt, Dict),
    send(TPid, Pkt),
    ?LOG(send, 'CER'),
    start_timer(Tmo, S#state{state = {'Wait-CEA', Hid, Eid}}).
624
%% req_send_CER/2
%%
%% Register ourselves, with register_everywhere/1, as connecting to
%% Remote from OriginHost. If this returns false, the caller
%% (send_CER/1) closes the connection.
%%
%% This isn't strictly necessary, since a peer implementing the 3588
%% Peer State Machine should reject duplicate connections from the
%% same peer. But there's little point in setting up a duplicate
%% connection in the first place.
%%
%% The key could also include the transport protocol being used.
%% Since we're blind to transport, just avoid duplicate connections to
%% the same host/port.
req_send_CER(OriginHost, Remote) ->
    Key = {?MODULE, connection, OriginHost, {remote, Remote}},
    register_everywhere(Key).
634
%% start_timer/2
%%
%% Arm a timer that sends {timeout, PS} to ourselves after Tmo
%% milliseconds, PS being the current peer state. transition/2 only
%% acts on it if we're still in that state. Returns the state unchanged.
start_timer(Tmo, #state{state = PS} = S) ->
    Msg = {timeout, PS},
    erlang:send_after(Tmo, self(), Msg),
    S.
640
%% build_CER/1
%%
%% Build a CER from the local capabilities using the common
%% dictionary. Fails with a badmatch if diameter_capx can't build one.
build_CER(#state{service = #diameter_service{capabilities = Caps},
                 dictionary = Dict0}) ->
    {ok, Rec} = diameter_capx:build_CER(Caps, Dict0),
    Rec.
647
%% encode/3
%%
%% Encode a message record with a fresh header of version
%% ?DIAMETER_VERSION. The hop-by-hop and end-to-end identifiers are
%% both set to the value diameter_session:sequence/1 returns for the
%% configured sequence mask.
encode(Rec, Opts, Dict) ->
    {_,_} = Mask = getr(?SEQUENCE_KEY),
    Seq = diameter_session:sequence(Mask),
    Pkt = #diameter_packet{header = #diameter_header{version = ?DIAMETER_VERSION,
                                                     end_to_end_id = Seq,
                                                     hop_by_hop_id = Seq},
                           msg = Rec},
    diameter_codec:encode(Dict, Opts, Pkt).
657
%% incoming/2
%%
%% Post-process the result of recv/2.
%%
%% A header or a binary means the message was discarded. If the
%% transport has asked for acknowledgements (ack = true) and the
%% message was a request, tell the transport with false that no answer
%% is forthcoming. For a binary, the R bit is the first bit following
%% the 8-bit version and 24-bit length.
%%
%% Any other result is returned as is.
incoming(#diameter_header{is_request = IsReq},
         #state{transport = TPid, ack = Ack}) ->
    IsReq andalso Ack andalso send(TPid, false),
    ok;

incoming(<<_:32, 1:1, _/bits>>, #state{ack = true, transport = TPid}) ->
    send(TPid, false),
    ok;

incoming(Bin, _) when is_bitstring(Bin) ->
    ok;

incoming(Other, _) ->
    Other.
674
%% recv/2
%%
%% An incoming message from the transport, either as a
%% #diameter_packet{} (whose bin field holds the received bytes) or as
%% a binary. Returns one of:
%%
%%   - the binary or header of a discarded message (see incoming/2);
%%   - the result of recv1/4.

recv(#diameter_packet{bin = Bin} = Pkt, S) ->
    recv(Bin, Pkt, S);

recv(Bin, S) ->
    recv(Bin, Bin, S).

%% recv/3
%%
%% Decode the header of the received bytes.

recv(Bin, Msg, S) ->
    recv(diameter_codec:decode_header(Bin), Bin, Msg, S).

%% recv/4
%%
%% Check the header and message length. A false header means the
%% header couldn't be decoded (counted as truncated_header).

recv(false, Bin, _, #state{length_errors = E}) ->
    invalid(E, truncated_header, Bin),
    Bin;

%% Length errors are configured to be handled, or the length is a
%% multiple of 4 that matches the received bytes and is within
%% incoming_maxlen (an integer, or infinity, which compares greater
%% than any integer): process the message.
%% NOTE(review): with length_errors = handle this guard also bypasses
%% the incoming_maxlen check below; confirm that's intended.
recv(#diameter_header{length = Len} = H, Bin, Msg, #state{length_errors = E,
                                                          incoming_maxlen = M,
                                                          dictionary = Dict0}
                                                   = S)
  when E == handle;
       0 == Len rem 4, bit_size(Bin) == 8*Len, size(Bin) =< M ->
    recv1(diameter_codec:msg_name(Dict0, H), H, Msg, S);

%% Too large: discard without closing the connection, whatever
%% length_errors says.
recv(H, Bin, _, #state{incoming_maxlen = M})
  when M < size(Bin) ->
    invalid(false, incoming_maxlen_exceeded, {size(Bin), H}),
    H;

%% Length mismatch: discard, closing the connection if length_errors
%% is exit.
recv(H, Bin, _, #state{length_errors = E}) ->
    T = {size(Bin), bit_size(Bin) rem 8, H},
    invalid(E, message_length_mismatch, T),
    H.
711
%% recv1/4
%%
%% Dispatch a message that has passed the header checks in recv/4.
%% Name is its message name according to the common dictionary.
%% Discarded messages return their header H (see incoming/2); others
%% return the result of handle/3.

%% Ignore anything but an expected CER/CEA if so configured. This is
%% non-standard behaviour.
recv1(Name, H, _, #state{state = {'Wait-CEA', _, _},
                         strict = false})
  when Name /= 'CEA' ->
    H;
recv1(Name, H, _, #state{state = recv_CER,
                         strict = false})
  when Name /= 'CER' ->
    H;

%% Incoming request after outgoing DPR: discard. Don't discard DPR, so
%% both ends don't do so when sending simultaneously.
recv1(Name, #diameter_header{is_request = true} = H, _, #state{dpr = {_,_,_}})
  when Name /= 'DPR' ->
    invalid(false, recv_after_outgoing_dpr, H),
    H;

%% Incoming request after incoming DPR: discard.
recv1(_, #diameter_header{is_request = true} = H, _, #state{dpr = true}) ->
    invalid(false, recv_after_incoming_dpr, H),
    H;

%% DPA with identifier mismatch, or in response to a DPR initiated by
%% the service. This DPA is not forwarded to the watchdog; only
%% handle/3 sees it.
recv1('DPA' = Name,
      #diameter_header{hop_by_hop_id = Hid, end_to_end_id = Eid}
      = H,
      Msg,
      #state{dpr = {X,HI,EI}}
      = S)
  when HI /= Hid;
       EI /= Eid;
       not X ->
    Pkt = pkt(H, Msg),
    handle(Name, Pkt, S);

%% Any other message with a header and no length errors: forward it to
%% the watchdog, together with its route (get_route/3), then handle it.
recv1(Name, H, Msg, #state{parent = Pid, ack = Ack} = S) ->
    Pkt = pkt(H, Msg),
    Pid ! {recv, self(), get_route(Ack, Name, Pkt), Name, Pkt},
    handle(Name, Pkt, S).
756
%% pkt/2
%%
%% A #diameter_packet{} carrying the decoded header H. A binary from
%% the transport is wrapped in a new packet; a packet received as such
%% just has its header set.
pkt(H, Msg) ->
    if is_binary(Msg) ->
            #diameter_packet{header = H, bin = Msg};
       true ->
            Msg#diameter_packet{header = H}
    end.
766
%% invalid/3
%%
%% Account for a discarded message:
%%
%%   1. bump the Reason counter (these counters only count discarded
%%      messages);
%%   2. call close/1 if E is exit;
%%   3. log.
%%
%% Returns ok.
invalid(E, Reason, T) ->
    diameter_stats:incr(Reason),
    case E of
        exit ->
            close({Reason, T});
        _ ->
            false
    end,
    ?LOG(Reason, T),
    ok.
775
%% handle/3
%%
%% Act on an incoming message that recv1/4 has let through, given its
%% name, its packet and the current state. In most cases recv1/4 has
%% already forwarded the message to the watchdog.
%%
%% Returns one of:
%%   - ok;
%%   - {stop, Reason};
%%   - the result of handle_CEA/2 or handle_request/3.

%% Incoming CEA.
handle('CEA' = N,
       #diameter_packet{header = #diameter_header{end_to_end_id = Eid,
                                                  hop_by_hop_id = Hid}}
       = Pkt,
       #state{state = {'Wait-CEA', Hid, Eid}}
       = S) ->
    ?LOG(recv, N),
    handle_CEA(Pkt, S);

%% Incoming CER
handle('CER' = N, Pkt, #state{state = recv_CER} = S) ->
    handle_request(N, Pkt, S);

%% Anything but CER/CEA in a non-Open state is an error, as is
%% CER/CEA in anything but recv_CER/Wait-CEA.
handle(Name, _, #state{state = PS})
  when PS /= 'Open';
       Name == 'CER';
       Name == 'CEA' ->
    {stop, {Name, PS}};

handle('DPR' = N, Pkt, S) ->
    handle_request(N, Pkt, S);

%% DPA in response to DPR, with the expected identifiers.
handle('DPA' = N,
       #diameter_packet{header = #diameter_header{end_to_end_id = Eid,
                                                  hop_by_hop_id = Hid}
                               = H}
       = Pkt,
    #state{dictionary = Dict0,
           transport = TPid,
           dpr = {X, Hid, Eid},
           codec = Opts}) ->
    ?LOG(recv, N),
    X orelse begin
                 %% Only count DPA in response to a DPR sent by the
                 %% service: explicit DPR is counted in the same way
                 %% as other explicitly sent requests.
                 incr(recv, H, Dict0),
                 {_, RecPkt} = decode(Dict0, Opts, Pkt),
                 incr_rc(recv, RecPkt, Dict0)
             end,
    diameter_peer:close(TPid),
    {stop, N};

%% Ignore an unsolicited DPA in particular. Note that dpa_timeout
%% deals with the case in which the peer sends the wrong identifiers
%% in DPA.
handle('DPA' = N, #diameter_packet{header = H}, _) ->
    ?LOG(ignored, N),
    %% Note that these aren't counted in the normal recv counter.
    diameter_stats:incr({diameter_codec:msg_id(H), recv, ignored}),
    ok;

%% Anything else in the Open state: nothing more to do here.
handle(_, _, _) ->
    ok.
836
%% incr/3
%%
%% Delegate to diameter_traffic:incr/4 on behalf of this process.
incr(Direction, Msg, Dict) ->
    Self = self(),
    diameter_traffic:incr(Direction, Msg, Self, Dict).

%% incr_rc/3
%%
%% Delegate to diameter_traffic:incr_rc/4 on behalf of this process.
incr_rc(Direction, Pkt, Dict) ->
    Self = self(),
    diameter_traffic:incr_rc(Direction, Pkt, Self, Dict).

%% incr_error/3
%%
%% Delegate to diameter_traffic:incr_error/4 on behalf of this process.
incr_error(Direction, Pkt, Dict) ->
    Self = self(),
    diameter_traffic:incr_error(Direction, Pkt, Self, Dict).
851
%% send/2
%%
%% Hand Msg to the transport process with diameter_peer:send/2.
%%
%% Msg is either a #diameter_packet{} or a binary, depending on who's
%% sending. In particular, the watchdog sends DWR as a binary, while
%% messages coming from clients are #diameter_packet{}. The atom false
%% is also sent, to tell the transport that an answer is not
%% forthcoming.
send(TPid, Msg) ->
    diameter_peer:send(TPid, Msg).
860
%% outgoing/2
%%
%% Send a message from the watchdog or a client, unless it must be
%% discarded:
%%
%%   - an explicit CER or DWR is always discarded, since these are
%%     sent by us;
%%   - once a DPR has been sent or received, only answers are sent.
%%
%% Returns ok or the result of send_dpr/4.

%% Explicit DPR.
outgoing(#diameter_packet{header = #diameter_header{application_id = 0,
                                                    cmd_code = 282,
                                                    is_request = true}
                                 = H}
         = Pkt,
         #state{dpr = T,
                parent = Pid}
         = S) ->
    if T == false ->
            inform_dpr(Pid),
            send_dpr(true, Pkt, dpa_timeout(), S);
       T == true ->
            invalid(false, dpr_after_dpa, H);  %% DPA sent: discard
       true ->
            invalid(false, dpr_after_dpr, H)   %% DPR sent: discard
    end;

%% Explicit CER or DWR: discard. These are sent by us.
outgoing(#diameter_packet{header = #diameter_header{application_id = 0,
                                                    cmd_code = C,
                                                    is_request = true}
                                 = H},
         _)
  when 257 == C;    %% CER
       280 == C ->  %% DWR
    invalid(false, invalid_request, H);

%% DPR not sent: send.
outgoing(Msg, #state{transport = TPid, dpr = false}) ->
    send(TPid, Msg),
    ok;

%% Outgoing answer: send.
outgoing(#diameter_packet{header = #diameter_header{is_request = false}}
         = Pkt,
         #state{transport = TPid}) ->
    send(TPid, Pkt),
    ok;

%% Outgoing request: discard.
outgoing(Msg, #state{}) ->
    invalid(false, send_after_dpr, header(Msg)).
906
%% header/1
%%
%% The header of an outgoing message. It is taken from a
%% #diameter_packet{}, or decoded from a binary (such as a DWR from
%% the watchdog).
header(Msg) ->
    case Msg of
        #diameter_packet{header = H} ->
            H;
        Bin ->
            diameter_codec:decode_header(Bin)
    end.
911
%% handle_request/3
%%
%% Incoming CER or DPR: log and count it, decode it, then answer it
%% from here (send_answer/3).
handle_request(Name,
               #diameter_packet{header = H} = Pkt,
               #state{dictionary = Dict0, codec = Opts} = S) ->
    ?LOG(recv, Name),
    incr(recv, H, Dict0),
    Decoded = decode(Dict0, Opts, Pkt),
    send_answer(Name, Decoded, S).
925
%% decode/3
%%
%% Decode Pkt twice, returning {DecPkt, RecPkt}:
%%
%%   - DecPkt: in the configured decode_format, for events;
%%   - RecPkt: with decode_format forced to record, for diameter_capx
%%     and counters.
decode(Dict0, Opts, Pkt) ->
    RecOpts = Opts#{decode_format := record},
    {diameter_codec:decode(Dict0, Opts, Pkt),
     diameter_codec:decode(Dict0, RecOpts, Pkt)}.
934
%% send_answer/3
%%
%% Answer an incoming CER or DPR (Type), given its decoded packets
%% {DecPkt, RecPkt} from decode/3.
%%
%%   1. Count any errors in the request.
%%   2. Build the answer with build_answer/4.
%%   3. Encode, count and send the answer to the transport.
%%   4. Evaluate the post-send action returned by build_answer/4
%%      (eval/2).

send_answer(Type, {DecPkt, RecPkt}, #state{transport = TPid,
                                           dictionary = Dict,
                                           codec = Opts}
                                    = S) ->
    incr_error(recv, RecPkt, Dict),

    #diameter_packet{header = H,
                     transport_data = TD}
        = RecPkt,

    {Msg, PostF} = build_answer(Type, DecPkt, RecPkt, S),

    %% An answer message clears the R and T flags and retains the P
    %% flag. The E flag is set at encode.
    Pkt = #diameter_packet{header
                           = H#diameter_header{version = ?DIAMETER_VERSION,
                                               is_request = false,
                                               is_error = undefined,
                                               is_retransmitted = false},
                           msg = Msg,
                           transport_data = TD},

    AnsPkt = diameter_codec:encode(Dict, Opts, Pkt),

    incr(send, AnsPkt, Dict),
    incr_rc(send, AnsPkt, Dict),
    send(TPid, AnsPkt),
    ?LOG(send, ans(Type)),
    eval(PostF, S).
966
%% ans/1
%%
%% Name of the answer to a given request, used when logging the send.

ans('DPR') -> 'DPA';
ans('CER') -> 'CEA'.
969
%% eval/2
%%
%% Run the post-answer action. A list [Fun | Args] is applied with the
%% state appended as the final argument. Any other term is passed to
%% close/1 as the reason.

eval([Fun | Args], S) ->
    erlang:apply(Fun, Args ++ [S]);
eval(Reason, _S) ->
    close(Reason).
974
975%% build_answer/4
976
%% Build the answer to an incoming CER or DPR. Returns {Msg, PostF}:
%% Msg is the answer message as a list, and PostF is what send_answer/3
%% hands to eval/2 once the answer has been sent. PostF is either
%% [Fun | Args], applied with the state appended, or a term that
%% eval/2 passes to close/1.
%%
%% This first clause handles a CER with the supported version, the E
%% bit not set, and no decode errors.
build_answer('CER',
             DecPkt,
             #diameter_packet{msg = CER,
                              header = #diameter_header{version
                                                        = ?DIAMETER_VERSION,
                                                        is_error = false},
                              errors = []},
             #state{dictionary = Dict0}
             = S) ->
    %% Let diameter_capx process the CER; recv_CER/2 closes the
    %% connection if it returns an error.
    {SupportedApps, RCaps, CEA} = recv_CER(CER, S),

    %% Result-Code and Inband-Security-Id as set in the CEA that
    %% diameter_capx built.
    [RC, IS] = Dict0:'#get-'(['Result-Code', 'Inband-Security-Id'], CEA),

    %% Pair local and remote capabilities, so that OH/DH are the local
    %% and remote Origin-Host respectively.
    #diameter_caps{origin_host = {OH, DH}}
        = Caps
        = capz(caps(S), RCaps),

    %% Accept only if all of the following hold:
    %%   - diameter_capx answered with success;
    %%   - no connection to the same (local, remote) Origin-Host pair is
    %%     registered;
    %%   - no capabilities callback rejects.
    try
        2001 == RC  %% DIAMETER_SUCCESS
            orelse ?THROW(RC),
        register_everywhere({?MODULE, connection, OH, DH})
            orelse ?THROW(4003),  %% DIAMETER_ELECTION_LOST
        caps_cb(Caps)
    of
        %% N is ok or a 2xxx result code from a callback (see cea/3).
        %% Send the CEA, then transition to Open via open/5.
        N -> {cea(CEA, N, Dict0), [fun open/5, DecPkt,
                                               SupportedApps,
                                               Caps,
                                               {accept, inband_security(IS)}]}
    catch
        %% Rejected: rejected/3 returns an error answer whose PostF is a
        %% tuple, so eval/2 closes the connection once the answer has
        %% been sent. For discard, rejected/3 closes without answering.
        ?FAILURE(Reason) ->
            rejected(Reason, {'CER', Reason, Caps, DecPkt}, S)
    end;

%% The error checks below are similar to those in diameter_traffic for
%% other messages. Should factor out the commonality.

%% Every DPR, and any CER the first clause didn't match (wrong version,
%% E bit set or decode errors). The answer carries the result code from
%% result_code/3, and the follow-up action comes from post/4.
build_answer(Type,
             DecPkt,
             #diameter_packet{header = H,
                              errors = Es},
             S) ->
    {RC, FailedAVP} = result_code(Type, H, Es),
    {answer(Type, RC, FailedAVP, S), post(Type, RC, DecPkt, S)}.
1020
%% inband_security/1
%%
%% Unwrap the Inband-Security-Id list read from the CEA. A single value
%% is returned as is; an empty list means NO_INBAND_SECURITY.

inband_security([Id]) ->
    Id;
inband_security([]) ->
    ?NO_INBAND_SECURITY.
1025
%% cea/3
%%
%% Finalise the CEA from the result of the capabilities callbacks. If
%% they returned ok or 2001, the CEA is used as built. Any other
%% (2xxx) value replaces the Result-Code.

cea(CEA, RC, _Dict0)
  when RC =:= ok;
       RC =:= 2001 ->
    CEA;
cea(CEA, RC, Dict0) ->
    Dict0:'#set-'({'Result-Code', RC}, CEA).
1032
%% post/4
%%
%% Action to take once an answer built by the error path of
%% build_answer/4 has been sent:
%%   - CER: a tuple, which eval/2 turns into a close.
%%   - DPR: a fun that starts the DPR timer, informs the parent and
%%     records in the state that a DPR was received.

post('CER', RC, Pkt, S) ->
    {'CER', caps(S), {RC, Pkt}};
post('DPR', _RC, _Pkt, #state{parent = Parent}) ->
    [fun(S) ->
             dpr_timer(),
             inform_dpr(Parent),
             dpr(S)
     end].
1037
%% dpr/1
%%
%% Record that a DPR has been received. The dpr field becomes true only
%% if it was false (no DPR sent or received yet). Otherwise the state
%% is returned unchanged.

dpr(#state{dpr = false} = State) ->
    State#state{dpr = true};
dpr(State) ->
    State.
1042
%% inform_dpr/1
%%
%% Tell the parent (watchdog) that a DPR has arrived so that it dies
%% along with this process. Returns the message sent.

inform_dpr(WatchdogPid) ->
    erlang:send(WatchdogPid, {'DPR', self()}).
1045
%% rejected/3
%%
%% Turn a CER rejection reason into {AnswerMsg, T}, where T becomes the
%% post-send close reason. The exception is discard, which closes
%% immediately without an answer. A {RC, Errors} reason also gets a
%% Failed-AVP for the first error carrying RC.

rejected({capabilities_cb, _Fun, Why}, T, S) ->
    rejected(Why, T, S);
rejected(discard, T, _S) ->
    close(T);
rejected({RC, Errors}, T, S) ->
    FailedAvp = failed_avp(RC, Errors),
    {answer('CER', RC, FailedAvp, S), T};
rejected(RC, T, S) ->
    {answer('CER', RC, [], S), T}.
1055
%% failed_avp/2
%%
%% Find the first {RC, Avp} error whose result code equals RC and wrap
%% its AVP as a Failed-AVP. Returns [] when no such error exists.

failed_avp(_RC, []) ->
    [];
failed_avp(RC, [{RC, Avp} | _]) ->
    [{'Failed-AVP', [[{'AVP', [Avp]}]]}];
failed_avp(RC, [_Other | Rest]) ->
    failed_avp(RC, Rest).
1062
%% answer/4
%%
%% Answer message with the given Result-Code, followed by any
%% Failed-AVP.

answer(Type, RC, FailedAVP, S) ->
    Ans = answer(Type, RC, S),
    set(Ans, FailedAVP).
1065
%% answer/3
%%
%% Template answer for Type, completed with the given Result-Code.

answer(Type, RC, S) ->
    Template = answer(Type, S),
    answer_message(Template, RC).
1068
1069%% answer_message/2
1070
%% For a 3xxx (protocol error) result code, replace the message with
%% an answer-message. It keeps only the Origin-* AVPs, after the
%% Result-Code. For any other code, append the Result-Code to the
%% message.
answer_message([_Name | Avps], RC)
  when RC >= 3000, RC < 4000 ->
    Origin = [Avp || Avp <- Avps, is_origin(Avp)],
    ['answer-message', {'Result-Code', RC} | Origin];

answer_message(Msg, RC) ->
    lists:append(Msg, [{'Result-Code', RC}]).
1078
%% is_origin/1
%%
%% True for an {Name, Value} AVP pair named Origin-Host, Origin-Realm
%% or Origin-State-Id.

is_origin({Name, _Value}) ->
    lists:member(Name, ['Origin-Host', 'Origin-Realm', 'Origin-State-Id']).
1083
1084%% set/2
1085
%% Append a (possibly empty) Failed-AVP list to an answer. An
%% answer-message gets it wrapped in an 'AVP' element; any other answer
%% has the list appended directly.
set(Msg, []) ->
    Msg;
set(['answer-message' | _] = Msg, Failed) ->
    Msg ++ [{'AVP', [Failed]}];
set([_Name | _] = Msg, Failed) ->
    lists:append(Msg, Failed).
1092
1093%% result_code/3
1094
%% Be lenient with errors in DPR since there's no reason to be
%% otherwise. Rejecting may cause the peer to misinterpret the error
%% as meaning that the connection should not be closed, which may well
%% lead to more problems than any errors in the DPR.
1099
%% A CER's result code depends on its header and decode errors. A DPR
%% is always accepted.
result_code('CER', Hdr, Errors) ->
    result_code(Hdr, Errors);
result_code('DPR', _Hdr, _Errors) ->
    {2001, []}.  %% DIAMETER_SUCCESS
1105
1106%% result_code/2
1107
%% Exported towards diameter_watchdog. Map a request header and its
%% decode errors to {ResultCode, FailedAvp}:
%%   - E bit set: 3008.
%%   - Supported version: determined by the errors (rc/1).
%%   - Anything else: 5011.
result_code(#diameter_header{is_error = true}, _Errors) ->
    {3008, []};  %% DIAMETER_INVALID_HDR_BITS

result_code(#diameter_header{version = ?DIAMETER_VERSION}, Errors) ->
    rc(Errors);

result_code(_Hdr, _Errors) ->
    {5011, []}.  %% DIAMETER_UNSUPPORTED_VERSION
1116
1117%% rc/1
1118
%% Result code from a list of decode errors. Only the first error
%% counts, and a Failed-AVP is included when it carries an AVP.
rc([]) ->
    {2001, []};  %% DIAMETER_SUCCESS
rc([{RC, _Avp} | _] = Errors) ->
    {RC, failed_avp(RC, Errors)};
rc([RC | _]) ->
    {RC, []}.
1125
1126%%   DIAMETER_INVALID_HDR_BITS          3008
1127%%      A request was received whose bits in the Diameter header were
1128%%      either set to an invalid combination, or to a value that is
1129%%      inconsistent with the command code's definition.
1130
1131%%   DIAMETER_INVALID_AVP_BITS          3009
1132%%      A request was received that included an AVP whose flag bits are
1133%%      set to an unrecognized value, or that is inconsistent with the
1134%%      AVP's definition.
1135
%%   DIAMETER_ELECTION_LOST             4003
%%      The peer has determined that it has lost the election process and
%%      has therefore disconnected the transport connection.
1139
1140%%   DIAMETER_NO_COMMON_APPLICATION     5010
1141%%      This error is returned when a CER message is received, and there
1142%%      are no common applications supported between the peers.
1143
1144%%   DIAMETER_UNSUPPORTED_VERSION       5011
1145%%      This error is returned when a request was received, whose version
1146%%      number is unsupported.
1147
1148%% answer/2
1149
%% Template answer (without Result-Code) for the named request, built
%% from the capabilities in the state's service record.
answer(Type, #state{service = #diameter_service{capabilities = LCaps}}) ->
    a(Type, LCaps).
1152
%% a/2
%%
%% Template answers. A CEA is built from plain capability values. A DPA
%% expects local/remote pairs and uses the local Origin-Host and
%% Origin-Realm.

a('CER', #diameter_caps{origin_host = OriginHost,
                        origin_realm = OriginRealm,
                        host_ip_address = HostIPs,
                        vendor_id = VendorId,
                        product_name = Product,
                        origin_state_id = StateId}) ->
    ['CEA', {'Origin-Host', OriginHost},
            {'Origin-Realm', OriginRealm},
            {'Host-IP-Address', HostIPs},
            {'Vendor-Id', VendorId},
            {'Product-Name', Product},
            {'Origin-State-Id', StateId}];

a('DPR', #diameter_caps{origin_host = {LocalHost, _},
                        origin_realm = {LocalRealm, _}}) ->
    ['DPA', {'Origin-Host', LocalHost},
            {'Origin-Realm', LocalRealm}].
1170
1171%% recv_CER/2
1172
%% Hand the CER to diameter_capx. On success, return its result. On
%% error, close the connection.
recv_CER(CER, #state{service = Svc, dictionary = Dict}) ->
    case diameter_capx:recv_CER(CER, Svc, Dict) of
        {error, Why} ->
            close({'CER', CER, Svc, Dict, Why});
        {ok, Result} ->
            Result
    end.
1180
1181%% handle_CEA/2
1182
%% Process an incoming CEA:
%%   1. Count and decode it.
%%   2. Check its result code and the common applications and security.
%%   3. Register the connection.
%%   4. Run the capabilities callbacks.
%%   5. If all pass, transition to Open via open/5.
%% Any failure closes the connection.
handle_CEA(#diameter_packet{header = H}
           = Pkt,
           #state{dictionary = Dict0,
                  service = #diameter_service{capabilities = LCaps},
                  codec = Opts}
           = S) ->
    incr(recv, H, Dict0),

    %% DecPkt is in the configured decode format; RecPkt is a record
    %% for diameter_capx.
    {DecPkt, RecPkt} = decode(Dict0, Opts, Pkt),

    %% RC is the Result-Code value. result_code/1 returns undefined if
    %% it isn't given a {'Result-Code', N} pair.
    RC = result_code(incr_rc(recv, RecPkt, Dict0)),
    %% Common applications, a list of common Inband-Security-Id values
    %% and remote capabilities. recv_CEA/2 closes on a bad packet or a
    %% diameter_capx error.
    {SApps, IS, RCaps} = recv_CEA(RecPkt, S),

    #diameter_caps{origin_host = {OH, DH}}
        = Caps
        = capz(LCaps, RCaps),

    %% Ensure that we don't already have a connection to the peer in
    %% question. This isn't the peer election of 3588 except in the
    %% sense that, since we don't know who we're talking to until we
    %% receive a CER/CEA, the first that arrives wins the right to a
    %% connection with the peer.

    try
        %% Require an integer 2xxx result code.
        is_integer(RC) andalso ?IS_SUCCESS(RC)
            orelse ?THROW(RC),
        [] == SApps
            andalso ?THROW(no_common_application),
        [] == IS
            andalso ?THROW(no_common_security),
        register_everywhere({?MODULE, connection, OH, DH})
            orelse ?THROW(election_lost),
        caps_cb(Caps)
    of
        %% NOTE: this clause is outside the try's protection, so an IS
        %% holding more than one value raises badmatch here rather than
        %% closing via close/1.
        _ -> open(DecPkt, SApps, Caps, {connect, hd([_] = IS)}, S)
    catch
        ?FAILURE(Reason) -> close({'CEA', Reason, Caps, DecPkt})
    end.
1221%% Check more than the result code since the peer could send success
1222%% regardless. If not 2001 then a peer_up callback could do anything
1223%% required. It's not unimaginable that a peer agreeing to TLS after
1224%% capabilities exchange could send DIAMETER_LIMITED_SUCCESS = 2002,
1225%% even if this isn't required by RFC 3588.
1226
%% result_code/1
%%
%% Value of a {'Result-Code', N} pair; undefined for anything else.

result_code(T) ->
    case T of
        {'Result-Code', N} -> N;
        _Other             -> undefined
    end.
1231
1232%% recv_CEA/2
1233
%% A CEA with the supported version, the E bit not set and no decode
%% errors goes to diameter_capx. Anything else, or a diameter_capx
%% error, closes the connection.
recv_CEA(#diameter_packet{header = #diameter_header{version
                                                    = ?DIAMETER_VERSION,
                                                    is_error = false},
                          msg = CEA,
                          errors = []},
         #state{service = Svc,
                dictionary = Dict}) ->
    case diameter_capx:recv_CEA(CEA, Svc, Dict) of
        {error, Why} ->
            close({'CEA', CEA, Svc, Dict, Why});
        {ok, Result} ->
            Result
    end;

recv_CEA(BadPkt, S) ->
    close({'CEA', caps(S), BadPkt}).
1250
%% caps/1
%%
%% Capabilities of a service record, or of the service held in the
%% state.

caps(#state{service = Svc}) ->
    caps(Svc);
caps(#diameter_service{capabilities = Caps}) ->
    Caps.
1255
1256%% caps_cb/1
1257
%% Fetch the reference and the capabilities callbacks stored under
%% ?CB_KEY (via eraser/1). Run the callbacks on [Ref, Caps].
caps_cb(Caps) ->
    {Ref, Funs} = eraser(?CB_KEY),
    caps_cb(Funs, [Ref, Caps]).
1261
%% Evaluate each callback with Args appended. The outcome of each:
%%   - ok: continue with the next callback.
%%   - 2xxx: stop and return it.
%%   - anything else: throw the rejection.
%% Returns ok once all callbacks have returned ok.
caps_cb([Fun | Funs], Args) ->
    case diameter_lib:eval([Fun | Args]) of
        ok ->
            caps_cb(Funs, Args);
        RC when ?IS_SUCCESS(RC) ->
            RC;
        Other ->
            ?THROW({capabilities_cb, Fun, rejected(Other)})
    end;
caps_cb([], _Args) ->
    ok.
1273%% Note that returning 2xxx causes the capabilities exchange to be
1274%% accepted directly, without further callbacks.
1275
%% rejected/1
%%
%% Normalise a callback rejection. discard is kept, unknown becomes
%% 3010, and an integer result code is kept. Anything else raises
%% function_clause.

rejected(discard) ->
    discard;
rejected(unknown) ->
    3010;  %% DIAMETER_UNKNOWN_PEER
rejected(RC) when is_integer(RC) ->
    RC.
1283
1284%% open/5
1285
%% Transition to Open:
%%   1. Complete the TLS handshoff with the transport if TLS was
%%      advertised locally.
%%   2. Notify the parent.
%%   3. Store the paired capabilities in the service.
open(Pkt, SupportedApps, Caps, {Type, IS}, #state{parent = Parent,
                                                  service = Svc}
                                           = S) ->
    #diameter_caps{origin_host = {_, _} = HostPair,
                   inband_security_id = {LocalSec, _}}
        = Caps,

    AdvertisedTLS = lists:member(?TLS, LocalSec),
    tls_ack(AdvertisedTLS, Caps, Type, IS, S),
    Parent ! {open, self(), HostPair, {Caps, SupportedApps, Pkt}},

    %% The service now holds local/remote capability pairs.
    S#state{state = 'Open',
            service = Svc#diameter_service{capabilities = Caps}}.
1299
1300%% We've advertised TLS support: tell the transport the result
1301%% and expect a reply when the handshake is complete.
tls_ack(true, Caps, Type, IS, #state{transport = TPid}) ->
    AckRef = make_ref(),
    TPid ! {diameter, {tls, AckRef, Type, IS == ?TLS}},
    %% Wait for the transport to acknowledge. A 'DOWN' for the transport
    %% (presumably monitored elsewhere) aborts the wait with a close.
    receive
        {diameter, {tls, AckRef}} ->
            ok;
        {'DOWN', _MRef, process, TPid, Why} ->
            close({tls_ack, Why, Caps})
    end;

%% TLS wasn't advertised: send the transport nothing, so transports
%% without TLS support need no changes.
tls_ack(false, _Caps, _Type, _IS, _S) ->
    ok.
1316
%% capz/2
%%
%% Combine local and remote capabilities field by field. Each field of
%% the result is the {Local, Remote} pair of the corresponding fields.

capz(#diameter_caps{} = L, #diameter_caps{} = R) ->
    [diameter_caps | LocalFields] = tuple_to_list(L),
    [diameter_caps | RemoteFields] = tuple_to_list(R),
    Pairs = lists:zip(LocalFields, RemoteFields),
    #diameter_caps{} = list_to_tuple([diameter_caps | Pairs]).
1321
1322%% close/1
1323%%
1324%% A good function to trace on in case of problems with capabilities
1325%% exchange.
1326
close(Reason) ->
    %% Non-local exit, tagged with this module, carrying the reason.
    erlang:throw({?MODULE, close, Reason}).
1329
1330%% dpr/2
1331%%
%% The RFC isn't clear on whether DPR should be sent in a non-Open
%% state. The Peer State Machine transitions it documents aren't
%% exhaustive (there is no Stop in Wait-I-CEA, for example), so assume
%% it's up to the implementation and transition to Closed (i.e. die) if
%% we haven't yet reached Open.
1337
1338%% Connection is open, DPR has not been sent.
dpr(Reason, #state{state = 'Open',
                   dpr = false,
                   service = #diameter_service{capabilities = Caps}}
            = S) ->
    %% Offer the disconnect callbacks [Reason, Ref, {self(), Caps}].
    Callbacks = getr(?DPR_KEY),
    Ref = getr(?REF_KEY),
    dpr(Callbacks, [Reason, Ref, {self(), Caps}], S);

%% Connection is open, DPR already sent or received: nothing to do.
dpr(_Reason, #state{state = 'Open'}) ->
    ok;

%% Connection not (yet) open.
dpr(_Reason, _S) ->
    stop.
1355
1356%% dpr/3
1357%%
%% Note that an implementation that wants to do something
%% transport_module-specific can look up the pid of the transport
%% process and contact it (e.g. with diameter:service_info/2).
1361
%% Ask each disconnect callback in turn:
%%   - dpr or {dpr, Opts}: send a DPR.
%%   - close: stop without sending one.
%%   - ignore: defer to the next callback.
%%   - anything else: an error.
%% Once the callbacks are exhausted, send a DPR with default options.
dpr([], [Reason | _], S) ->
    send_dpr(Reason, [], S);

dpr([CB | CBs], [Reason | _] = Args, S) ->
    case diameter_lib:eval([CB | Args]) of
        dpr ->
            send_dpr(Reason, [], S);
        {dpr, DprOpts} when is_list(DprOpts) ->
            send_dpr(Reason, DprOpts, S);
        ignore ->
            dpr(CBs, Args, S);
        close ->
            {stop, {disconnect_cb, close}};
        Other ->
            ?ERROR({disconnect_cb, CB, Args, Other})
    end.
1378
%% Options for an outgoing DPR, as given by a disconnect callback's
%% {dpr, Opts} return and folded in by opt/2.
-record(opts, {cause, timeout}).

%% send_dpr/3
%%
%% Build and send a service-originated DPR.
%%
%% Disconnect-Cause defaults to GOAWAY when Reason is transport and to
%% REBOOT otherwise. The DPA timeout defaults to dpa_timeout(). DprOpts
%% can override either.

send_dpr(Reason, DprOpts, #state{dictionary = Dict,
                                 service = #diameter_service{capabilities = Caps},
                                 codec = Opts}
                       = S) ->
    DefaultCause = case Reason of
                       transport -> ?GOAWAY;
                       _         -> ?REBOOT
                   end,
    Defaults = #opts{cause = DefaultCause, timeout = dpa_timeout()},
    #opts{cause = Cause, timeout = Tmo}
        = lists:foldl(fun opt/2, Defaults, DprOpts),

    #diameter_caps{origin_host = {LocalHost, _},
                   origin_realm = {LocalRealm, _}}
        = Caps,

    DPR = ['DPR', {'Origin-Host', LocalHost},
                  {'Origin-Realm', LocalRealm},
                  {'Disconnect-Cause', Cause}],
    Pkt = encode(DPR, Opts, Dict),
    send_dpr(false, Pkt, Tmo, S).
1403
1404%% send_dpr/4
1405
%% Send an encoded DPR and start the DPA timer. The state records
%% {Explicit, HopByHopId, EndToEndId} so the DPA can be matched.
send_dpr(Explicit,
         #diameter_packet{header = #diameter_header{end_to_end_id = E2EId,
                                                    hop_by_hop_id = HbHId}}
         = Pkt,
         Tmo,
         #state{transport = TPid, dictionary = Dict} = S) ->
    %% Only count a DPR the service itself sends (Explicit = false). An
    %% explicitly sent DPR is counted like any other explicitly sent
    %% request.
    Explicit orelse incr(send, Pkt, Dict),
    send(TPid, Pkt),
    dpa_timer(Tmo),
    ?LOG(send, 'DPR'),
    S#state{dpr = {Explicit, HbHId, E2EId}}.
1421
1422%% opt/2
1423
%% Fold one DPR option into the #opts{} accumulator. Invalid options
%% are an error.
opt({timeout, Tmo}, Acc)
  when ?IS_TIMEOUT(Tmo) ->
    Acc#opts{timeout = Tmo};
opt({cause, Cause}, Acc)
  when ?IS_CAUSE(Cause) ->
    Acc#opts{cause = cause(Cause)};
opt(Bad, _Acc) ->
    ?ERROR({invalid_option, Bad}).
1432
%% cause/1
%%
%% Map a Disconnect-Cause given as an atom to its value:
%%   - rebooting: ?REBOOT
%%   - goaway: ?GOAWAY
%%   - busy: ?BUSY
%% Otherwise the value itself, if ?IS_CAUSE accepts it.

cause(rebooting) -> ?REBOOT;
cause(goaway)    -> ?GOAWAY;
cause(busy)      -> ?BUSY;
cause(Code)
  when ?IS_CAUSE(Code) ->
    Code;
cause(Bad) ->
    ?ERROR({invalid_cause, Bad}).
1441
%% dpa_timer/1
%%
%% Arrange for dpa_timeout to be sent to this process after Tmo
%% milliseconds. Returns the timer reference.

dpa_timer(Tmo) ->
    Self = self(),
    erlang:send_after(Tmo, Self, dpa_timeout).
1444
%% dpa_timeout/0
%%
%% DPA timeout: the second element of the {DprTmo, DpaTmo} pair stored
%% under ?DPA_KEY.

dpa_timeout() ->
    {_DprTmo, DpaTmo} = getr(?DPA_KEY),
    DpaTmo.
1448
%% dpr_timer/0
%%
%% After receiving a DPR, start the same dpa_timeout timer, using the
%% DPR timeout.

dpr_timer() ->
    Tmo = dpr_timeout(),
    dpa_timer(Tmo).
1451
%% dpr_timeout/0
%%
%% DPR timeout: the first element of the {DprTmo, DpaTmo} pair stored
%% under ?DPA_KEY.

dpr_timeout() ->
    {DprTmo, _DpaTmo} = getr(?DPA_KEY),
    DprTmo.
1455
1456%% register_everywhere/1
1457%%
%% Register a term and ensure it's not registered elsewhere. Note that
%% two processes that simultaneously register the same term may well
%% both fail to do so, so this isn't foolproof.
1461%%
1462%% Everywhere is no longer everywhere, it's where a
1463%% restrict_connections service_opt() specifies.
1464
register_everywhere(T) ->
    %% Nodes configured for connection restriction.
    Nodes = getr(?RESTRICT_KEY),
    reg(Nodes, T).
1467
%% reg/2
%%
%% Register T locally. Whether add_new or add is used depends on
%% whether this node is in Nodes. Then check that no other node in
%% Nodes has it registered.

reg(Nodes, T) ->
    IsLocal = lists:member(node(), Nodes),
    add(IsLocal, T) andalso unregistered(Nodes, T).
1470
%% add/2
%%
%% Register T with diameter_reg. When the local node is among the
%% restricted nodes, use add_new/1 (presumably failing if T is already
%% registered); otherwise use add/1.

add(false, T) ->
    diameter_reg:add(T);
add(true, T) ->
    diameter_reg:add_new(T).
1475
1476%% unregistered
1477%%
1478%% Ensure that the term in question isn't registered on other nodes.
1479
unregistered(Nodes, T) ->
    %% Nodes that fail to answer are ignored.
    {Replies, _BadNodes} =
        rpc:multicall(Nodes, ?MODULE, match, [{node(), T}]),
    lists:all(fun nomatch/1, Replies).
1483
%% nomatch/1
%%
%% True if a remote match/1 reply shows no registration: an empty list,
%% or undef because diameter isn't loaded on the remote node.

nomatch({badrpc, {'EXIT', {undef, _}}}) ->
    true;
nomatch(Reply) ->
    Reply =:= [].
1488
1489%% match/1
1490
%% Called via rpc from unregistered/2. The calling node itself reports
%% no match. Any other node reports its diameter_reg matches for T, or
%% [] if the lookup raises anything.
match({Node, _T}) when Node =:= node() ->
    [];
match({_Node, T}) ->
    try diameter_reg:match(T)
    catch
        _:_ -> []
    end.
1500