1%%
2%% %CopyrightBegin%
3%%
4%% Copyright Ericsson AB 2007-2020. All Rights Reserved.
5%%
6%% Licensed under the Apache License, Version 2.0 (the "License");
7%% you may not use this file except in compliance with the License.
8%% You may obtain a copy of the License at
9%%
10%%     http://www.apache.org/licenses/LICENSE-2.0
11%%
12%% Unless required by applicable law or agreed to in writing, software
13%% distributed under the License is distributed on an "AS IS" BASIS,
14%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15%% See the License for the specific language governing permissions and
16%% limitations under the License.
17%%
18%% %CopyrightEnd%
19%%
20
21%%----------------------------------------------------------------------
22%% Purpose: Handle client side TLS-1.3 session ticket storage
23%%----------------------------------------------------------------------
24
25-module(tls_client_ticket_store).
26-behaviour(gen_server).
27
28-include("ssl_internal.hrl").
29-include("tls_handshake_1_3.hrl").
30
31%% API
32-export([find_ticket/5,
33         get_tickets/2,
34         lock_tickets/2,
35         remove_tickets/1,
36         start_link/2,
37         store_ticket/4,
38         unlock_tickets/2,
39         update_ticket/2]).
40
41%% gen_server callbacks
42-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
43         terminate/2, code_change/3, format_status/2]).
44
%% Registered-name alias for the server; the code in this module uses
%% ?MODULE directly.
-define(SERVER, ?MODULE).

%% Server state.
-record(state, {
          db,       %% gb_trees tree mapping Key => #data{} (key format: see store_ticket/5)
          lifetime, %% ticket validity in seconds; also the cleanup-sweep interval
          max       %% number of tickets at which the oldest one is evicted
         }).

%% One stored session ticket.
-record(data, {
          pos,          %% set via update_ticket/2, copied into #ticket_data.pos
          cipher_suite, %% {Cipher, Hash}
          sni,          %% SNI value the ticket was stored with
          psk,          %% pre-shared key handed out together with the ticket
          timestamp,    %% erlang:system_time(seconds) at the time it was stored
          ticket,       %% #new_session_ticket{} record
          lock          %% 'undefined', or the Pid that locked the ticket
         }).
62
63%%%===================================================================
64%%% API
65%%%===================================================================
-spec start_link(integer(), integer()) -> {ok, Pid :: pid()} |
                      {error, Error :: {already_started, pid()}} |
                      {error, Error :: term()} |
                      ignore.
%% Start the locally registered ticket store linked to the caller.
%% Max is the number of tickets at which the oldest is evicted; Lifetime is
%% the ticket validity in seconds.
start_link(Max, Lifetime) ->
    InitArgs = [Max, Lifetime],
    gen_server:start_link({local, ?MODULE}, ?MODULE, InitArgs, []).
72
%% Ask the store for tickets usable by Pid. The reply is a pair
%% {EarlyDataKey, ResumptionKey}; either element may be 'undefined'.
find_ticket(Pid, Ciphers, HashAlgos, SNI, EarlyDataSize) ->
    Request = {find_ticket, Pid, Ciphers, HashAlgos, SNI, EarlyDataSize},
    gen_server:call(?MODULE, Request, infinity).
75
%% Fetch #ticket_data{} records for Keys on behalf of Pid. Only tickets that
%% are unlocked or locked by Pid are returned. The reply is a list, or
%% 'undefined' when nothing matched.
get_tickets(Pid, Keys) ->
    Request = {get_tickets, Pid, Keys},
    gen_server:call(?MODULE, Request, infinity).
78
%% Lock Keys on behalf of Pid so that only Pid can select them.
%% 'undefined' Keys is a no-op and does not contact the server.
lock_tickets(_Pid, undefined) ->
    ok;
lock_tickets(Pid, Keys) ->
    Request = {lock, Pid, Keys},
    gen_server:call(?MODULE, Request, infinity).
83
%% Asynchronously delete Keys from the store. An empty list sends nothing.
remove_tickets([]) ->
    ok;
remove_tickets(Keys) ->
    Msg = {remove_tickets, Keys},
    gen_server:cast(?MODULE, Msg).
88
%% Store a received ticket together with its cipher suite, SNI and PSK.
store_ticket(Ticket, CipherSuite, SNI, PSK) ->
    Request = {store_ticket, Ticket, CipherSuite, SNI, PSK},
    gen_server:call(?MODULE, Request, infinity).
91
%% Release the locks Pid holds on Keys.
%% Mirrors lock_tickets/2: 'undefined' Keys is a no-op. Without this clause,
%% 'undefined' would reach set_lock/4, which only accepts lists, and crash
%% the store with function_clause.
unlock_tickets(_Pid, undefined) ->
    ok;
unlock_tickets(Pid, Keys) ->
    gen_server:call(?MODULE, {unlock, Pid, Keys}, infinity).
94
%% Record position Pos for the ticket stored under Key.
update_ticket(Key, Pos) ->
    Request = {update_ticket, Key, Pos},
    gen_server:call(?MODULE, Request, infinity).
97
98%%%===================================================================
99%%% gen_server callbacks
100%%%===================================================================
101
-spec init(Args :: term()) -> {ok, State :: term()}.
%% Trap exits, then build the initial state. inital_state/1 also arms the
%% first periodic cleanup timer.
init(Args) ->
    process_flag(trap_exit, true),
    {ok, inital_state(Args)}.
108
-spec handle_call(Request :: term(), From :: {pid(), term()}, State :: term()) ->
                         {reply, Reply :: term(), NewState :: term()} .
%% Synchronous requests. The queries (find_ticket, get_tickets) reply with
%% their result and keep the state; the other requests update the state and
%% reply 'ok'.
handle_call({find_ticket, Pid, Ciphers, HashAlgos, SNI, EarlyDataSize}, _From, State) ->
    {reply, do_find_ticket(State, Pid, Ciphers, HashAlgos, SNI, EarlyDataSize), State};
handle_call({get_tickets, Pid, Keys}, _From, State) ->
    {reply, get_tickets(State, Pid, Keys), State};
handle_call({lock, Pid, Keys}, _From, State) ->
    {reply, ok, lock_tickets(State, Pid, Keys)};
handle_call({store_ticket, Ticket, CipherSuite, SNI, PSK}, _From, State) ->
    {reply, ok, store_ticket(State, Ticket, CipherSuite, SNI, PSK)};
handle_call({unlock, Pid, Keys}, _From, State) ->
    {reply, ok, unlock_tickets(State, Pid, Keys)};
handle_call({update_ticket, Key, Pos}, _From, State) ->
    {reply, ok, update_ticket(State, Key, Pos)}.
129
-spec handle_cast(Request :: term(), State :: term()) ->
                         {noreply, NewState :: term()}.
%% Asynchronous requests: only ticket removal is understood; anything else
%% is ignored.
handle_cast({remove_tickets, Keys}, State) ->
    {noreply, remove_tickets(State, Keys)};
handle_cast(_Ignored, State) ->
    {noreply, State}.
137
-spec handle_info(Info :: timeout() | term(), State :: term()) ->
                         {noreply, NewState :: term()}.
%% The periodic cleanup timer delivers 'remove_invalid_tickets'; the sweep
%% itself re-arms the timer. Other messages are dropped.
handle_info(remove_invalid_tickets, State) ->
    {noreply, remove_invalid_tickets(State)};
handle_info(_Unexpected, State) ->
    {noreply, State}.
145
-spec terminate(Reason :: normal | shutdown | {shutdown, term()} | term(),
                State :: term()) -> any().
%% Nothing to release: the tickets live only in the process state.
terminate(_Why, _State) ->
    ok.
150
-spec code_change(OldVsn :: term() | {down, term()},
                  State :: term(),
                  Extra :: term()) -> {ok, NewState :: term()} |
                                      {error, Reason :: term()}.
%% The state is carried over unchanged across code changes.
code_change(_FromVsn, State, _Extra) ->
    {ok, State}.
157
158
-spec format_status(Opt :: normal | terminate,
                    Status :: list()) -> Status :: term().
%% For a gen_server, Status is [PDict, State]. The stored PSKs are secret key
%% material, so each ticket's psk field is replaced by the atom 'redacted'
%% before the state is shown. This covers Opt = normal (e.g.
%% sys:get_status/1) and Opt = terminate (crash reports). Any other Status
%% shape is returned unchanged.
format_status(_Opt, [PDict, #state{db = Db} = State]) ->
    Redacted = gb_trees:map(fun(_Key, Data) -> Data#data{psk = redacted} end, Db),
    [PDict, State#state{db = Redacted}];
format_status(_Opt, Status) ->
    Status.
163%%%===================================================================
164%%% Internal functions
165%%%===================================================================
166
%% Build the empty store and arm the first cleanup sweep, which fires after
%% Lifetime seconds; remove_invalid_tickets/1 re-arms it after each sweep.
inital_state([Max, Lifetime]) ->
    Empty = #state{max = Max,
                   lifetime = Lifetime,
                   db = gb_trees:empty()},
    erlang:send_after(Lifetime * 1000, self(), remove_invalid_tickets),
    Empty.
173
%% Try each hash algorithm in HashAlgos in order. The first {Key, Front}
%% result that carries an early-data-capable Key is returned as is.
%% Otherwise the result is {undefined, Front}: Front is the resumption-only
%% key found for the earliest hash algorithm that produced one, or
%% 'undefined'.
do_find_ticket(State, Pid, Ciphers, HashAlgos, SNI, EarlyDataSize) ->
    do_find_ticket(State, Pid, Ciphers, HashAlgos, SNI, EarlyDataSize, []).
%%
do_find_ticket(_State, _Pid, _Ciphers, [], _SNI, _EarlyDataSize, Fallbacks) ->
    %% Fallbacks is newest-first, so its last element came from the earliest hash.
    {undefined, last_elem(Fallbacks)};
do_find_ticket(#state{db = Db, lifetime = Lifetime} = State, Pid, Ciphers,
               [Hash | Hashes], SNI, EarlyDataSize, Fallbacks) ->
    Iter = gb_trees:iterator(Db),
    case iterate_tickets(Iter, Pid, Ciphers, Hash, SNI, Lifetime, EarlyDataSize) of
        {undefined, undefined} ->
            do_find_ticket(State, Pid, Ciphers, Hashes, SNI, EarlyDataSize, Fallbacks);
        {undefined, Front} ->
            do_find_ticket(State, Pid, Ciphers, Hashes, SNI, EarlyDataSize,
                           [Front | Fallbacks]);
        Found ->
            Found
    end.
191
%% Scan the store in key order (keys grow with insertion time, so oldest
%% first). Only tickets issued with hash algorithm Hash and not locked by
%% another process are considered.
%%
%% Returns {Key, Front} as soon as a ticket usable for early data is found.
%% Front is the oldest resumption-only key seen before it, or 'undefined'.
%% If no early-data ticket exists, returns {undefined, Front}.
iterate_tickets(Iter0, Pid, Ciphers, Hash, SNI, Lifetime, EarlyDataSize) ->
    iterate_tickets(Iter0, Pid, Ciphers, Hash, SNI, Lifetime, EarlyDataSize, []).
%%
%% Acc holds resumption-only keys newest first, so its last element is the
%% oldest one.
iterate_tickets(Iter0, Pid, Ciphers, Hash, SNI, Lifetime, EarlyDataSize, Acc) ->
    case gb_trees:next(Iter0) of
        {Key, #data{cipher_suite = {Cipher, Hash},
                    sni = TicketSNI,
                    ticket = #new_session_ticket{
                                extensions = Extensions},
                    timestamp = Timestamp,
                    lock = Lock}, Iter} when Lock =:= undefined orelse
                                             Lock =:= Pid ->
            MaxEarlyData = tls_handshake_1_3:get_max_early_data(Extensions),
            Age = erlang:system_time(second) - Timestamp,
            Usability = ticket_usability(Age < Lifetime,
                                         verify_ticket_sni(SNI, TicketSNI),
                                         Cipher, Ciphers,
                                         EarlyDataSize, MaxEarlyData),
            case Usability of
                early_data ->
                    %% 'Key' can be used with early_data as both block cipher
                    %% and hash algorithm match; 'Front' only for resumption.
                    {Key, last_elem(Acc)};
                resumption ->
                    iterate_tickets(Iter, Pid, Ciphers, Hash, SNI,
                                    Lifetime, EarlyDataSize, [Key | Acc]);
                skip ->
                    iterate_tickets(Iter, Pid, Ciphers, Hash, SNI,
                                    Lifetime, EarlyDataSize, Acc)
            end;
        {_, _, Iter} ->
            %% Different hash algorithm, or locked by another process.
            iterate_tickets(Iter, Pid, Ciphers, Hash, SNI, Lifetime, EarlyDataSize, Acc);
        none ->
            {undefined, last_elem(Acc)}
    end.

%% Classify a ticket whose hash algorithm and lock already match:
%%   early_data - Cipher is in Ciphers and EarlyDataSize is 'undefined' or
%%                fits within the ticket's MaxEarlyData
%%   resumption - otherwise, the ticket is only usable for session resumption
%%   skip       - the ticket is expired or was stored for a different SNI
ticket_usability(false, _SNIMatch, _Cipher, _Ciphers, _EarlyDataSize, _MaxEarlyData) ->
    skip;
ticket_usability(true, nomatch, _Cipher, _Ciphers, _EarlyDataSize, _MaxEarlyData) ->
    skip;
ticket_usability(true, match, Cipher, Ciphers, EarlyDataSize, MaxEarlyData) ->
    case lists:member(Cipher, Ciphers) andalso
        (EarlyDataSize =:= undefined orelse EarlyDataSize =< MaxEarlyData) of
        true -> early_data;
        false -> resumption
    end.
240
%% Return the final element of a list, or 'undefined' for the empty list.
last_elem([]) ->
    undefined;
last_elem([_ | _] = NonEmpty) ->
    lists:last(NonEmpty).
245
%% A ticket is acceptable when the connection requests no SNI, or when the
%% requested SNI is exactly (=:=) the one the ticket was stored with.
verify_ticket_sni(RequestedSNI, TicketSNI) when RequestedSNI =:= undefined;
                                               RequestedSNI =:= TicketSNI ->
    match;
verify_ticket_sni(_RequestedSNI, _TicketSNI) ->
    nomatch.
252
%% Get tickets that are not locked by another process.
%% Returns a list of #ticket_data{} (in reverse order of Keys), or
%% 'undefined' when none of Keys refers to a usable ticket. Unknown keys and
%% tickets locked by another process are skipped.
get_tickets(State, Pid, Keys) ->
    get_tickets(State, Pid, Keys, []).
%%
get_tickets(_, _, [], []) ->
    undefined; %% No tickets found
get_tickets(_, _, [], Acc) ->
    Acc;
get_tickets(#state{db = Db} = State, Pid, [Key|T], Acc) ->
    %% A plain case is used on purpose: with 'try ... of Pattern when Guard',
    %% a ticket failing the lock guard raises try_clause, which the try's own
    %% catch section does not handle.
    case gb_trees:lookup(Key, Db) of
        {value, #data{lock = Lock} = Data} when Lock =:= undefined orelse
                                                 Lock =:= Pid ->
            get_tickets(State, Pid, T, [make_ticket_data(Key, Data)|Acc]);
        _UnknownOrLockedByOther ->
            get_tickets(State, Pid, T, Acc)
    end.

%% Build the #ticket_data{} offered as a PSK for the ticket stored under Key.
make_ticket_data(Key, #data{pos = Pos,
                            cipher_suite = CipherSuite,
                            psk = PSK,
                            timestamp = Timestamp,
                            ticket = #new_session_ticket{
                                        ticket_age_add = AgeAdd,
                                        ticket_nonce = Nonce,
                                        ticket = Ticket,
                                        extensions = Extensions}}) ->
    %% RFC 8446, Section 4.2.11: the ticket age is in milliseconds. Timestamp
    %% is stored with second resolution, so scale it before subtracting.
    TicketAge = erlang:system_time(millisecond) - Timestamp * 1000,
    ObfuscatedTicketAge = obfuscate_ticket_age(TicketAge, AgeAdd),
    Identity = #psk_identity{
                  identity = Ticket,
                  obfuscated_ticket_age = ObfuscatedTicketAge},
    MaxEarlyData = tls_handshake_1_3:get_max_early_data(Extensions),
    #ticket_data{
       key = Key,
       pos = Pos,
       identity = Identity,
       psk = PSK,
       nonce = Nonce,
       cipher_suite = CipherSuite,
       max_size = MaxEarlyData}.
296
%% RFC 8446, Section 4.2.11: the "obfuscated_ticket_age" field of each
%% PskIdentity contains an obfuscated version of the ticket age formed by
%% taking the age in milliseconds and adding the "ticket_age_add" value that
%% was included with the ticket (see RFC 8446, Section 4.6.1), modulo 2^32.
%%
%% Per the RFC, TicketAge should be given in milliseconds. Masking with
%% 2^32 - 1 is the modulo for non-negative sums. It also keeps the result a
%% valid uint32 when the sum is negative (e.g. after the wall clock steps
%% backwards), where 'rem' would return a negative number.
obfuscate_ticket_age(TicketAge, AgeAdd) ->
    (TicketAge + AgeAdd) band 16#FFFFFFFF.
304
305
%% Delete every key in Keys from the store; unknown keys are ignored.
remove_tickets(State, Keys) ->
    lists:foldl(fun(Key, Acc) -> remove_ticket(Acc, Key) end, State, Keys).
310
311
%% Drop a single key from the store (no-op when the key is absent).
remove_ticket(#state{db = Tree} = State, Key) ->
    State#state{db = gb_trees:delete_any(Key, Tree)}.
315
316
%% Periodic sweep: drop expired, unlocked tickets, then schedule the next
%% sweep Lifetime seconds from now.
remove_invalid_tickets(#state{db = Tree, lifetime = Lifetime} = OldState) ->
    Expired = collect_invalid_tickets(gb_trees:iterator(Tree), Lifetime),
    NewState = remove_tickets(OldState, Expired),
    erlang:send_after(Lifetime * 1000, self(), remove_invalid_tickets),
    NewState.
323
324
%% Return the keys of unlocked tickets that are at least Lifetime seconds
%% old (newest key first). Locked tickets are never collected.
collect_invalid_tickets(Iter, Lifetime) ->
    collect_invalid_tickets(Iter, Lifetime, []).
%%
collect_invalid_tickets(Iter0, Lifetime, Expired) ->
    case gb_trees:next(Iter0) of
        none ->
            Expired;
        {Key, #data{lock = undefined, timestamp = Timestamp}, Iter} ->
            Age = erlang:system_time(seconds) - Timestamp,
            Expired1 = case Age >= Lifetime of
                           true -> [Key | Expired];
                           false -> Expired
                       end,
            collect_invalid_tickets(Iter, Lifetime, Expired1);
        {_Key, _LockedOrOther, Iter} ->
            collect_invalid_tickets(Iter, Lifetime, Expired)
    end.
343
344
%% Insert a new ticket, first evicting the oldest one when the store already
%% holds Max tickets. Keys are {MonotonicTime, UniqueMonotonicInteger}, so
%% they strictly increase with insertion and the smallest key is the oldest
%% ticket.
store_ticket(#state{db = Db0, max = Max} = State, Ticket, CipherSuite, SNI, PSK) ->
    Timestamp = erlang:system_time(second),
    %% '>=' rather than '=:=': with '=:=' a degenerate limit (Max = 0, or a
    %% float such as 10.0) never matches the integer size, and the store would
    %% grow without bound.
    Db1 = case gb_trees:size(Db0) >= Max of
              true -> delete_oldest(Db0);
              false -> Db0
          end,
    Key = {erlang:monotonic_time(), erlang:unique_integer([monotonic])},
    Data = #data{cipher_suite = CipherSuite,
                 sni = SNI,
                 psk = PSK,
                 timestamp = Timestamp,
                 ticket = Ticket},
    State#state{db = gb_trees:insert(Key, Data, Db1)}.
362
363
%% Record Pos for the ticket stored under Key; an unknown Key leaves the
%% state unchanged. gb_trees:lookup/2 replaces a catch-all try around
%% gb_trees:get/2, so only the "key absent" case is tolerated.
update_ticket(#state{db = Db0} = State, Key, Pos) ->
    case gb_trees:lookup(Key, Db0) of
        {value, Data} ->
            State#state{db = gb_trees:update(Key, Data#data{pos = Pos}, Db0)};
        none ->
            State
    end.
373
374
%% Remove the entry with the smallest key, which is the oldest ticket because
%% keys increase with insertion time. An empty store is returned unchanged.
%% The empty case is checked explicitly instead of catching every exception
%% from gb_trees:take_smallest/1.
delete_oldest(Db0) ->
    case gb_trees:is_empty(Db0) of
        true ->
            Db0;
        false ->
            {_Key, _Value, Db} = gb_trees:take_smallest(Db0),
            Db
    end.
383
384
%% Mark every existing key in Keys as locked by Owner.
lock_tickets(State0, Owner, Keys) ->
    set_lock(State0, Owner, Keys, lock).
387
388
%% Clear the lock on every key in Keys that is currently held by Owner.
unlock_tickets(State0, Owner, Keys) ->
    set_lock(State0, Owner, Keys, unlock).
391
392
%% Apply Cmd (lock | unlock) on behalf of Pid to every key in Keys.
%% Keys that are not in the store are skipped. gb_trees:lookup/2 replaces a
%% catch-all try around gb_trees:get/2, so unrelated errors are no longer
%% masked.
set_lock(State, _, [], _) ->
    State;
set_lock(#state{db = Db0} = State, Pid, [Key|T], Cmd) ->
    case gb_trees:lookup(Key, Db0) of
        {value, Value} ->
            Db = gb_trees:update(Key, update_data_lock(Value, Pid, Cmd), Db0),
            set_lock(State#state{db = Db}, Pid, T, Cmd);
        none ->
            set_lock(State, Pid, T, Cmd)
    end.
404
405
%% 'lock' always assigns the ticket to Owner, replacing any previous holder.
%% 'unlock' clears the lock only when Owner holds it; otherwise the ticket is
%% returned untouched.
update_data_lock(Data, Owner, Cmd) ->
    case {Cmd, Data} of
        {lock, _} ->
            Data#data{lock = Owner};
        {unlock, #data{lock = Owner}} ->
            Data#data{lock = undefined};
        _ ->
            Data
    end.
412