1%%
2%% %CopyrightBegin%
3%%
4%% Copyright Ericsson AB 2007-2020. All Rights Reserved.
5%%
6%% Licensed under the Apache License, Version 2.0 (the "License");
7%% you may not use this file except in compliance with the License.
8%% You may obtain a copy of the License at
9%%
10%%     http://www.apache.org/licenses/LICENSE-2.0
11%%
12%% Unless required by applicable law or agreed to in writing, software
13%% distributed under the License is distributed on an "AS IS" BASIS,
14%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15%% See the License for the specific language governing permissions and
16%% limitations under the License.
17%%
18%% %CopyrightEnd%
19%%
20%%
21%%----------------------------------------------------------------------
%% Purpose: Handles an SSL/TLS connection, i.e. both the setup
%% (the SSL-Handshake, SSL-Alert and SSL-Cipher protocols) and delivering
%% data to the application. All data on the connection is received and
%% sent according to the SSL-record protocol.
26%%----------------------------------------------------------------------
27
28-module(tls_connection).
29
30-behaviour(gen_statem).
31
32-include("tls_connection.hrl").
33-include("tls_handshake.hrl").
34-include("tls_handshake_1_3.hrl").
35-include("ssl_alert.hrl").
36-include("tls_record.hrl").
37-include("ssl_cipher.hrl").
38-include("ssl_api.hrl").
39-include("ssl_internal.hrl").
40-include("ssl_srp.hrl").
41-include_lib("public_key/include/public_key.hrl").
42-include_lib("kernel/include/logger.hrl").
43
44%% Internal application API
45
46%% Setup
47-export([start_fsm/8, start_link/8, init/1, pids/1]).
48
49%% State transition handling
50-export([next_event/3, next_event/4,
51         handle_protocol_record/3]).
52
53%% Handshake handling
54-export([renegotiation/2, renegotiate/2, send_handshake/2,
55         send_handshake_flight/1,
56	 queue_handshake/2, queue_change_cipher/2,
57	 reinit/1, reinit_handshake_data/1, select_sni_extension/1,
58         empty_connection_state/2]).
59
60%% Alert and close handling
61-export([send_alert/2, send_alert_in_connection/2,
62         send_sync_alert/2,
63         close/5, protocol_name/0]).
64
65%% Data handling
66-export([socket/4, setopts/3, getopts/3]).
67
68%% gen_statem state functions
69-export([init/3, error/3, downgrade/3, %% Initiation and take down states
70	 hello/3, user_hello/3, certify/3, cipher/3, abbreviated/3, %% Handshake states
71	 connection/3]).
72%% TLS 1.3 state functions (server)
73-export([start/3,         %% common state with client
74         negotiated/3,
75         recvd_ch/3,
76         wait_cert/3,     %% common state with client
77         wait_cv/3,       %% common state with client
78         wait_eoed/3,
79         wait_finished/3, %% common state with client
80         wait_flight2/3,
81         connected/3      %% common state with client
82        ]).
83%% TLS 1.3 state functions (client)
84-export([wait_cert_cr/3,
85         wait_ee/3,
86         wait_sh/3
87        ]).
88%% gen_statem callbacks
89-export([callback_mode/0, terminate/3, code_change/4, format_status/2]).
90
91-export([encode_handshake/4, send_key_update/2, update_cipher_key/2]).
92
93-define(DIST_CNTRL_SPAWN_OPTS, [{priority, max}]).
94
95%%====================================================================
96%% Internal application API
97%%====================================================================
98%%====================================================================
99%% Setup
100%%====================================================================
%% Starts a TLS connection. It spawns the tls_sender process and the
%% connection process, hands Socket over to them via
%% ssl_connection:socket_control/5, and runs the initial handshake with
%% Timeout.
%% Distribution connections (erl_dist = true) get a sender spawned with
%% ?DIST_CNTRL_SPAWN_OPTS and a connection process started through
%% tls_connection_sup:start_child_dist/1.
start_fsm(Role, Host, Port, Socket, {#{erl_dist := false}, _, Trackers} = Opts,
          User, {CbModule, _, _, _, _} = CbInfo,
          Timeout) ->
    start_connection_processes(
      fun() -> tls_sender:start() end,
      fun tls_connection_sup:start_child/1,
      fun(Sender) -> [Role, Sender, Host, Port, Socket, Opts, User, CbInfo] end,
      Socket, CbModule, Trackers, Timeout);
start_fsm(Role, Host, Port, Socket, {#{erl_dist := true}, _, Trackers} = Opts,
          User, {CbModule, _, _, _, _} = CbInfo,
          Timeout) ->
    start_connection_processes(
      fun() -> tls_sender:start([{spawn_opt, ?DIST_CNTRL_SPAWN_OPTS}]) end,
      fun tls_connection_sup:start_child_dist/1,
      fun(Sender) -> [Role, Sender, Host, Port, Socket, Opts, User, CbInfo] end,
      Socket, CbModule, Trackers, Timeout).

%% Shared start-up sequence of start_fsm/8.
%% If a step returns {error, _}, the assertive match raises
%% {badmatch, {error, _}}, which is converted back into that {error, _}
%% return value. All other exceptions propagate.
start_connection_processes(StartSender, StartChild, ChildArgs,
                           Socket, CbModule, Trackers, Timeout) ->
    try
        {ok, Sender} = StartSender(),
        {ok, Pid} = StartChild(ChildArgs(Sender)),
        {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, [Pid, Sender],
                                                        CbModule, Trackers),
        ssl_connection:handshake(SslSocket, Timeout)
    catch
        error:{badmatch, {error, _} = Error} ->
            Error
    end.
128
%%--------------------------------------------------------------------
-spec start_link(atom(), pid(), ssl:host(), inet:port_number(), port(), tuple(), pid(), tuple()) ->
    {ok, pid()} | ignore |  {error, reason()}.
%%
%% Description: Spawns the connection process, linked to the caller. The
%% process runs init/1, which initializes the state and enters the
%% gen_statem loop.
%% Options is the {OptionsMap, _, Trackers} tuple that init/1 and
%% start_fsm/8 pattern-match on. It is not a list, which is why the spec
%% says tuple().
%% As written, this function only ever returns {ok, Pid}.
%%--------------------------------------------------------------------
start_link(Role, Sender, Host, Port, Socket, Options, User, CbInfo) ->
    InitArgs = [Role, Sender, Host, Port, Socket, Options, User, CbInfo],
    Pid = proc_lib:spawn_link(?MODULE, init, [InitArgs]),
    {ok, Pid}.
138
%% Body of the connection process spawned by start_link/8.
%% Steps:
%% 1. Trap exits and link to the tls_sender process.
%% 2. For Erlang distribution connections, raise the process priority
%%    to max.
%% 3. Build the initial state.
%% 4. Configure SSL, initialize the sender, and enter the gen_statem loop
%%    in the init state.
%% If step 4 throws, the loop is entered in the error state instead. The
%% thrown term is stored under the error key of protocol_specific.
init([Role, Sender, Host, Port, Socket, {#{erl_dist := ErlDist}, _, _} = Options, User, CbInfo]) ->
    process_flag(trap_exit, true),
    link(Sender),
    if
        ErlDist =:= true ->
            process_flag(priority, max);
        true ->
            ok
    end,
    #state{protocol_specific = ProtocolSpecific} = InitialState =
        initial_state(Role, Sender, Host, Port, Socket, Options, User, CbInfo),
    try
        Configured = ssl_connection:ssl_config(InitialState#state.ssl_options, Role, InitialState),
        initialize_tls_sender(Configured),
        gen_statem:enter_loop(?MODULE, [], init, Configured)
    catch
        throw:Reason ->
            ErrorState = InitialState#state{protocol_specific = ProtocolSpecific#{error => Reason}},
            gen_statem:enter_loop(?MODULE, [], error, ErrorState)
    end.
158
%% Returns the pids that make up this connection: the connection process
%% itself (self()) followed by its tls_sender process.
pids(#state{protocol_specific = #{sender := SenderPid}}) ->
    Self = self(),
    [Self, SenderPid].
161
162%%====================================================================
163%% State transition handling
164%%====================================================================
%% Produces the next record to process, as {Record, State}, or
%% {no_record, State}.
%% Clause order matters:
%%  1. unprocessed_handshake_events > 0: handshake messages were already
%%     inserted as events. Decrement the counter and report no_record.
%%  2. Buffered cipher texts exist: decrypt them with next_record/4.
%%  3. No buffered cipher texts and active_n_toggle is true (the socket
%%     must be re-armed with {active, N}):
%%     - in the connection state, flow control decides;
%%     - in any other state, the socket is re-activated directly.
%%  4. Otherwise: no_record.
next_record(_, #state{handshake_env =
                          #handshake_env{unprocessed_handshake_events = Pending} = HsEnv}
            = State) when Pending > 0 ->
    NewHsEnv = HsEnv#handshake_env{unprocessed_handshake_events = Pending - 1},
    {no_record, State#state{handshake_env = NewHsEnv}};
next_record(_, #state{protocol_buffers =
                          #protocol_buffers{tls_cipher_texts = [_|_] = CipherTexts},
                      connection_states = ConnectionStates,
                      ssl_options = #{padding_check := Check}} = State) ->
    next_record(State, CipherTexts, ConnectionStates, Check);
next_record(StateName, #state{protocol_buffers = #protocol_buffers{tls_cipher_texts = []},
                              protocol_specific = #{active_n_toggle := true}
                             } = State) ->
    case StateName of
        connection ->
            %% If the ssl application user is not reading data, wait
            %% before activating the socket.
            flow_ctrl(State);
        _ ->
            activate_socket(State)
    end;
next_record(_, State) ->
    {no_record, State}.
187
%%% Flow control in the connection state when no cipher texts are buffered.
%%% Decides whether to re-activate the socket (activate_socket/1) or to
%%% wait ({no_record, State}).
%%%
%%% bytes_to_read equals the integer Length argument of ssl:recv. Its
%%% actual value only matters for packet = raw | 0. bytes_to_read =
%%% undefined means no recv call is ongoing.
flow_ctrl(#state{user_data_buffer = {_, Size, _},
                 socket_options = #socket_options{active = false},
                 bytes_to_read = undefined} = State) when Size =/= 0 ->
    %% Passive mode with data buffered and no recv pending. Wait for a new
    %% recv request or socket activation. Not reading from the socket yet
    %% preserves some TCP back pressure.
    {no_record, State};
flow_ctrl(#state{socket_options = #socket_options{active = false,
                                                  packet = Packet}} = State)
  when Packet =/= 0, Packet =/= raw ->
    %% Passive mode with a packet mode set: more data is needed to
    %% complete the packet.
    activate_socket(State);
flow_ctrl(#state{user_data_buffer = {_, Size, _},
                 socket_options = #socket_options{active = false},
                 bytes_to_read = 0} = State) ->
    %% Passive mode, no packet mode, and a recv with Length 0 is pending.
    if
        Size == 0 ->
            %% Nothing is buffered, so get some bytes.
            activate_socket(State);
        true ->
            %% Deliver the buffered data.
            {no_record, State}
    end;
flow_ctrl(#state{user_data_buffer = {_, Size, _},
                 socket_options = #socket_options{active = false},
                 bytes_to_read = BytesToRead} = State)
  when BytesToRead > 0, Size >= BytesToRead ->
    %% Enough data is buffered to satisfy the pending recv.
    {no_record, State};
flow_ctrl(State) ->
    %% Active mode, or more data is needed to complete the delivery of
    %% BytesToRead bytes. bytes_to_read = undefined with an empty buffer
    %% also ends up here: an atom compares greater than any number, so no
    %% integer Size is >= undefined.
    activate_socket(State).
227
228
%% Re-arms the socket with {active, N} and clears active_n_toggle.
%% If setopts fails, the message {CloseTag, Socket} is sent to self(), so
%% the failure is handled like a transport close. The state is then left
%% unchanged. Always returns {no_record, State}.
activate_socket(#state{protocol_specific = #{active_n_toggle := true, active_n := ActiveN} = Spec,
                       static_env = #static_env{socket = Sock,
                                                close_tag = Tag,
                                                transport_cb = Cb}
                      } = State0) ->
    State =
        case tls_socket:setopts(Cb, Sock, [{active, ActiveN}]) of
            ok ->
                State0#state{protocol_specific = Spec#{active_n_toggle => false}};
            _ ->
                self() ! {Tag, Sock},
                State0
        end,
    {no_record, State}.
241
%% Decipher the next record and concatenate consecutive ?APPLICATION_DATA
%% records into one.
%% Returns {Record, State} (via next_record_done/4) or an #alert{} from
%% tls_record:decode_cipher_text/4.
%%
%% next_record/4 starts with an empty fragment accumulator.
next_record(State, CipherTexts, ConnectionStates, Check) ->
    next_record(State, CipherTexts, ConnectionStates, Check, []).
%%
%% next_record/5: Acc holds the plaintext fragments of the application
%% data records decrypted so far, in reverse order.
%%
%% TLS 1.3 ({3,4}): every cipher text is decrypted first, and the
%% decrypted record's type decides what happens.
next_record(#state{connection_env = #connection_env{negotiated_version = {3,4} = Version}} = State,
            [CT|CipherTexts], ConnectionStates0, Check, Acc) ->
    case tls_record:decode_cipher_text(Version, CT, ConnectionStates0, Check) of
        {#ssl_tls{type = ?APPLICATION_DATA, fragment = Fragment}, ConnectionStates} ->
            case CipherTexts of
                [] ->
                    %% End of cipher texts - build and deliver an ?APPLICATION_DATA record
                    %% from the accumulated fragments
                    next_record_done(State, [], ConnectionStates,
                                     #ssl_tls{type = ?APPLICATION_DATA,
                                              fragment = iolist_to_binary(lists:reverse(Acc, [Fragment]))});
                [_|_] ->
                    next_record(State, CipherTexts, ConnectionStates, Check, [Fragment|Acc])
            end;
        {Record, ConnectionStates} when Acc =:= [] ->
            %% Singleton non-?APPLICATION_DATA record - deliver
            next_record_done(State, CipherTexts, ConnectionStates, Record);
        {_Record, _ConnectionStates_to_forget} ->
            %% Not ?APPLICATION_DATA but we have accumulated fragments
            %% -> build an ?APPLICATION_DATA record with concatenated fragments
            %%    and forget about decrypting this record - we'll decrypt it again next time
            %%    (CT is put back in the buffer and ConnectionStates0, the state
            %%    from before this decryption, is kept)
            %% Will not work for stream ciphers
            next_record_done(State, [CT|CipherTexts], ConnectionStates0,
                             #ssl_tls{type = ?APPLICATION_DATA, fragment = iolist_to_binary(lists:reverse(Acc))});
        #alert{} = Alert ->
            %% Any fragments accumulated so far are dropped with the alert
            Alert
    end;
%% Pre-TLS 1.3: the record type is visible before decryption. This clause
%% only decrypts and accumulates cipher texts of type ?APPLICATION_DATA.
next_record(#state{connection_env = #connection_env{negotiated_version = Version}} = State,
            [#ssl_tls{type = ?APPLICATION_DATA} = CT |CipherTexts], ConnectionStates0, Check, Acc) ->
    case tls_record:decode_cipher_text(Version, CT, ConnectionStates0, Check) of
        {#ssl_tls{type = ?APPLICATION_DATA, fragment = Fragment}, ConnectionStates} ->
            case CipherTexts of
                [] ->
                    %% End of cipher texts - build and deliver an ?APPLICATION_DATA record
                    %% from the accumulated fragments
                    next_record_done(State, [], ConnectionStates,
                                     #ssl_tls{type = ?APPLICATION_DATA,
                                              fragment = iolist_to_binary(lists:reverse(Acc, [Fragment]))});
                [_|_] ->
                    next_record(State, CipherTexts, ConnectionStates, Check, [Fragment|Acc])
            end;
        #alert{} = Alert ->
            Alert
    end;
%% Pre-TLS 1.3: the next cipher text is not application data, but fragments
%% have been accumulated. Deliver them as one ?APPLICATION_DATA record and
%% leave the remaining cipher texts (undecrypted) in the buffer.
next_record(State, CipherTexts, ConnectionStates, _, [_|_] = Acc) ->
    next_record_done(State, CipherTexts, ConnectionStates,
                     #ssl_tls{type = ?APPLICATION_DATA,
                              fragment = iolist_to_binary(lists:reverse(Acc))});
%% Pre-TLS 1.3: nothing accumulated - decrypt and deliver a single record.
next_record(#state{connection_env = #connection_env{negotiated_version = Version}} = State,
            [CT|CipherTexts], ConnectionStates0, Check, []) ->
    case tls_record:decode_cipher_text(Version, CT, ConnectionStates0, Check) of
        {Record, ConnectionStates} ->
            %% Singleton non-?APPLICATION_DATA record - deliver
            next_record_done(State, CipherTexts, ConnectionStates, Record);
        #alert{} = Alert ->
            Alert
    end.
304
%% Stores the cipher texts still to be processed and the updated connection
%% states in State, and pairs the result with the decoded Record.
next_record_done(#state{protocol_buffers = Buffers} = State, CipherTexts, ConnectionStates, Record) ->
    NewBuffers = Buffers#protocol_buffers{tls_cipher_texts = CipherTexts},
    NewState = State#state{protocol_buffers = NewBuffers,
                           connection_states = ConnectionStates},
    {Record, NewState}.
309
%% Feeds the next record into the state machine, keeping StateName.
%% Cases:
%% - no_record: fetch a record with next_record/2.
%%   - If none is available, return the result of
%%     ssl_connection:hibernate_after/3.
%%   - A fetched record is dispatched as below.
%%   - An #alert{} (tagged with our role) is passed to
%%     ssl_connection:handle_normal_shutdown/3, and the process stops with
%%     {shutdown, own_alert}.
%% - #ssl_tls{}: inserted as an internal {protocol_record, Record} event.
%% - #alert{}: inserted as an internal event.
%% In every case the inserted event precedes the caller's Actions.
next_event(StateName, Record, State) ->
    next_event(StateName, Record, State, []).
%%
next_event(StateName, no_record, #state{static_env = #static_env{role = Role}} = State0, Actions) ->
    case next_record(StateName, State0) of
        {no_record, State} ->
            ssl_connection:hibernate_after(StateName, State, Actions);
        {Record, State} ->
            next_event(StateName, Record, State, Actions);
        #alert{} = Alert ->
            OwnAlert = Alert#alert{role = Role},
            ssl_connection:handle_normal_shutdown(OwnAlert, StateName, State0),
            {stop, {shutdown, own_alert}, State0}
    end;
next_event(StateName, #ssl_tls{} = Record, State, Actions) ->
    RecordEvent = {next_event, internal, {protocol_record, Record}},
    {next_state, StateName, State, [RecordEvent | Actions]};
next_event(StateName, #alert{} = Alert, State, Actions) ->
    AlertEvent = {next_event, internal, Alert},
    {next_state, StateName, State, [AlertEvent | Actions]}.
327
%%% Dispatches a decrypted TLS record according to its content type and
%%% returns a gen_statem state function result.
%%%
%%% TLS record protocol level application data messages.
%%% This clause covers passive mode (active = false) while a caller is
%%% waiting (start_or_recv_from =/= undefined). When read_application_data/2
%%% leaves start_or_recv_from = undefined, the passive recv is complete. The
%%% generic {timeout, recv} timer is then cancelled by setting it to
%%% infinity.
handle_protocol_record(#ssl_tls{type = ?APPLICATION_DATA, fragment = Data}, StateName,
                       #state{start_or_recv_from = From,
                              socket_options = #socket_options{active = false}} = State0) when From =/= undefined ->
    case ssl_connection:read_application_data(Data, State0) of
       {stop, _, _} = Stop->
            Stop;
        {Record, #state{start_or_recv_from = Caller} = State} ->
            TimerAction = case Caller of
                              undefined -> %% Passive recv complete cancel timer
                                  [{{timeout, recv}, infinity, timeout}];
                              _ ->
                                  []
                          end,
            next_event(StateName, Record, State, TimerAction)
    end;
%%% Application data in all other cases (active mode, or no waiting caller).
handle_protocol_record(#ssl_tls{type = ?APPLICATION_DATA, fragment = Data}, StateName, State0) ->
    case ssl_connection:read_application_data(Data, State0) of
	{stop, _, _} = Stop->
            Stop;
	{Record, State} ->
            next_event(StateName, Record, State)
    end;
%%% TLS record protocol level handshake messages.
%%% tls_handshake:get_tls_handshake/4 combines Data with the buffered
%%% handshake bytes. It returns the complete messages (Packets) and the new
%%% buffer.
%%% - No complete message yet: read the next record.
%%% - Otherwise: the messages are inserted as events. Outside the
%%%   connection state, the number of events is also stored in
%%%   unprocessed_handshake_events. next_record/2 then returns no_record
%%%   (decrementing the counter) instead of reading further records.
%%% A thrown #alert{} is handled as an alert generated by us.
handle_protocol_record(#ssl_tls{type = ?HANDSHAKE, fragment = Data},
		    StateName, #state{protocol_buffers =
					  #protocol_buffers{tls_handshake_buffer = Buf0} = Buffers,
                                      connection_env = #connection_env{negotiated_version = Version},
                                      static_env = #static_env{role = Role},
				      ssl_options = Options} = State0) ->
    try
	%% Calculate the effective version that should be used when decoding an incoming handshake
	%% message.
	EffectiveVersion = effective_version(Version, Options, Role),
	{Packets, Buf} = tls_handshake:get_tls_handshake(EffectiveVersion,Data,Buf0, Options),
	State =
	    State0#state{protocol_buffers =
			     Buffers#protocol_buffers{tls_handshake_buffer = Buf}},
	case Packets of
            [] ->
                assert_buffer_sanity(Buf, Options),
                next_event(StateName, no_record, State);
            _ ->
                Events = tls_handshake_events(Packets),
                case StateName of
                    connection ->
                        ssl_connection:hibernate_after(StateName, State, Events);
                    _ ->
                        HsEnv = State#state.handshake_env,
                        {next_state, StateName,
                         State#state{handshake_env =
                                         HsEnv#handshake_env{unprocessed_handshake_events
                                                             = unprocessed_events(Events)}}, Events}
                end
        end
    catch throw:#alert{} = Alert ->
            ssl_connection:handle_own_alert(Alert, Version, StateName, State0)
    end;
%%% TLS record protocol level change cipher messages
%%% (inserted as an internal #change_cipher_spec{} event)
handle_protocol_record(#ssl_tls{type = ?CHANGE_CIPHER_SPEC, fragment = Data}, StateName, State) ->
    {next_state, StateName, State, [{next_event, internal, #change_cipher_spec{type = Data}}]};
%%% TLS record protocol level Alert messages.
%%% Only decode_alerts/1 is protected by the try; the 'of' clauses are not.
%%% An empty alert list or a decoding crash is reported as a fatal
%%% handshake_failure alert of our own.
handle_protocol_record(#ssl_tls{type = ?ALERT, fragment = EncAlerts}, StateName,
                       #state{connection_env = #connection_env{negotiated_version = Version}} = State) ->
    try decode_alerts(EncAlerts) of
	Alerts = [_|_] ->
	    handle_alerts(Alerts,  {next_state, StateName, State});
	[] ->
	    ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, empty_alert),
					    Version, StateName, State);
        #alert{} = Alert ->
            ssl_connection:handle_own_alert(Alert, Version, StateName, State)
    catch
	_:_ ->
	    ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, alert_decode_error),
					    Version, StateName, State)

    end;
%% Ignore unknown TLS record level protocol messages
handle_protocol_record(#ssl_tls{type = _Unknown}, StateName, State) ->
    {next_state, StateName, State, []}.
409%%====================================================================
410%% Handshake handling
411%%====================================================================
%% Asks the connection process Pid to start a user-initiated renegotiation.
%% Blocks until the gen_statem call returns and returns its reply.
renegotiation(Pid, WriteState) ->
    Request = {user_renegotiate, WriteState},
    gen_statem:call(Pid, Request).
414
%% Starts a renegotiation on an established connection.
%% - Client: reset the handshake history and inject an internal
%%   #hello_request{} event, i.e. proceed as if the server had requested
%%   the renegotiation.
%% - Server: send a HelloRequest to the peer, reset the handshake history,
%%   and wait for the peer in the hello state. The HelloRequest itself is
%%   not added to the new handshake history.
%% Actions are passed on to the resulting state transition.
renegotiate(#state{static_env = #static_env{role = client},
                   handshake_env = HsEnv} = State, Actions) ->
    %% Handle same way as if server requested
    %% the renegotiation
    Hs0 = ssl_handshake:init_handshake_history(),
    {next_state, connection, State#state{handshake_env = HsEnv#handshake_env{tls_handshake_history = Hs0}},
     [{next_event, internal, #hello_request{}} | Actions]};
renegotiate(#state{static_env = #static_env{role = server,
                                            socket = Socket,
                                            transport_cb = Transport},
                   handshake_env = HsEnv,
                   connection_env = #connection_env{negotiated_version = Version},
                   ssl_options = #{log_level := LogLevel},
                   connection_states = ConnectionStates0} = State0, Actions) ->
    HelloRequest = ssl_handshake:hello_request(),
    Frag = tls_handshake:encode_handshake(HelloRequest, Version),
    Hs0 = ssl_handshake:init_handshake_history(),
    {BinMsg, ConnectionStates} =
        tls_record:encode_handshake(Frag, Version, ConnectionStates0),
    tls_socket:send(Transport, Socket, BinMsg),
    %% Log the outbound HelloRequest like the other outbound messages in
    %% this module (ClientHello in init/3, queue_handshake/2, send_alert/2).
    %% Previously this message was never logged.
    ssl_logger:debug(LogLevel, outbound, 'handshake', HelloRequest),
    ssl_logger:debug(LogLevel, outbound, 'record', BinMsg),
    State = State0#state{connection_states =
                             ConnectionStates,
                         handshake_env = HsEnv#handshake_env{tls_handshake_history = Hs0}},
    next_event(hello, no_record, State, Actions).
438
%% Encodes Handshake, appends it to the flight buffer, and immediately sends
%% the whole flight. Returns the {State, []} result of
%% send_handshake_flight/1.
send_handshake(Handshake, State0) ->
    Queued = queue_handshake(Handshake, State0),
    send_handshake_flight(Queued).
441
%% Encodes Handshake for the negotiated version and updates the handshake
%% history and connection states. The encoded message is appended to the
%% flight buffer without being sent. Both the message and the encoded
%% record are logged at debug level.
queue_handshake(Handshake, #state{handshake_env = #handshake_env{tls_handshake_history = History0} = HsEnv,
                                  connection_env = #connection_env{negotiated_version = Version},
                                  flight_buffer = Flight,
                                  ssl_options = #{log_level := LogLevel},
                                  connection_states = ConnStates0} = State) ->
    {Encoded, ConnStates, History} =
        encode_handshake(Handshake, Version, ConnStates0, History0),
    ssl_logger:debug(LogLevel, outbound, 'handshake', Handshake),
    ssl_logger:debug(LogLevel, outbound, 'record', Encoded),
    NewHsEnv = HsEnv#handshake_env{tls_handshake_history = History},
    State#state{connection_states = ConnStates,
                handshake_env = NewHsEnv,
                flight_buffer = Flight ++ [Encoded]}.
455
456
%% Sends all buffered handshake messages in one transport send and empties
%% the flight buffer. The send result is not checked. Returns {State, []}.
send_handshake_flight(#state{static_env = #static_env{socket = Sock,
                                                      transport_cb = Cb},
                             flight_buffer = Buffered} = State) ->
    tls_socket:send(Cb, Sock, Buffered),
    Emptied = State#state{flight_buffer = []},
    {Emptied, []}.
462
463
%% Encodes Msg with encode_change_cipher/3 for the negotiated version and
%% updates the connection states. The encoded record is appended to the
%% flight buffer without being sent, and logged at debug level.
queue_change_cipher(Msg, #state{connection_env = #connection_env{negotiated_version = Version},
                                flight_buffer = Flight,
                                ssl_options = #{log_level := LogLevel},
                                connection_states = ConnStates0} = State) ->
    {Encoded, ConnStates} = encode_change_cipher(Msg, Version, ConnStates0),
    ssl_logger:debug(LogLevel, outbound, 'record', Encoded),
    State#state{connection_states = ConnStates,
                flight_buffer = Flight ++ [Encoded]}.
473
%% Pushes the current write connection state and the negotiated version to
%% the tls_sender process, then clears handshake-only data with
%% reinit_handshake_data/1. The result of the sender update is not checked.
reinit(#state{protocol_specific = #{sender := SenderPid},
              connection_env = #connection_env{negotiated_version = NegotiatedVersion},
              connection_states = #{current_write := WriteState}} = State) ->
    tls_sender:update_connection_state(SenderPid, WriteState, NegotiatedVersion),
    reinit_handshake_data(State).
479
%% Resets tls_handshake_history and clears public_key_info and
%% premaster_secret. These are only needed during the handshake phase;
%% reinitializing them reduces the memory footprint of the connection.
reinit_handshake_data(#state{handshake_env = HsEnv0} = State) ->
    FreshHistory = ssl_handshake:init_handshake_history(),
    HsEnv = HsEnv0#handshake_env{tls_handshake_history = FreshHistory,
                                 public_key_info = undefined,
                                 premaster_secret = undefined},
    State#state{handshake_env = HsEnv}.
489
%% Returns the value stored under sni in a ClientHello's extensions map.
%% Returns undefined if the argument is not a #client_hello{} or has no
%% sni key.
select_sni_extension(Hello) ->
    case Hello of
        #client_hello{extensions = #{sni := Sni}} ->
            Sni;
        _ ->
            undefined
    end.
494
%% TLS has no extra per-connection-state data; this delegates directly to
%% ssl_record:empty_connection_state/2.
empty_connection_state(ConnEnd, BeastMitigationMode) ->
    ssl_record:empty_connection_state(ConnEnd, BeastMitigationMode).
497
498%%====================================================================
499%% Alert and close handling
500%%====================================================================
501
%%--------------------------------------------------------------------
-spec encode_alert(#alert{}, ssl_record:ssl_version(), ssl_record:connection_states()) ->
                          {iolist(), ssl_record:connection_states()}.
%%
%% Description: Encodes Alert as an alert record for Version. Returns the
%% encoded record and the updated connection states, as produced by
%% tls_record:encode_alert_record/3.
%%--------------------------------------------------------------------
encode_alert(#alert{} = Alert, Version, ConnStates) ->
    tls_record:encode_alert_record(Alert, Version, ConnStates).
510
%% Encodes Alert and sends it directly on the transport socket, without
%% going through the tls_sender process. The send result is not checked.
%% The encoded record is logged at debug level. Returns the state with the
%% updated connection states.
send_alert(Alert, #state{static_env = #static_env{socket = Sock,
                                                  transport_cb = Cb},
                         connection_env = #connection_env{negotiated_version = Version},
                         ssl_options = #{log_level := LogLevel},
                         connection_states = ConnStates0} = State) ->
    {Encoded, ConnStates} = encode_alert(Alert, Version, ConnStates0),
    tls_socket:send(Cb, Sock, Encoded),
    ssl_logger:debug(LogLevel, outbound, 'record', Encoded),
    State#state{connection_states = ConnStates}.
521
%% Sends Alert while in the connection state.
%% Fatal alerts and close_notify end the TLS connection, so they go through
%% send_sync_alert/2, which synchronizes with the tls_sender process. This
%% way the alert is sent, if the sender is not blocked, before the
%% connection process terminates and thereby closes the transport socket.
%% All other alerts are handed to the tls_sender process asynchronously.
send_alert_in_connection(#alert{level = Level, description = Description} = Alert, State)
  when Level =:= ?FATAL; Description =:= ?CLOSE_NOTIFY ->
    send_sync_alert(Alert, State);
send_alert_in_connection(Alert,
                         #state{protocol_specific = #{sender := Sender}}) ->
    tls_sender:send_alert(Sender, Alert).
%% Sends Alert through the tls_sender process and waits for the
%% acknowledgement. Any exception from that call is turned into
%% throw({stop, {shutdown, own_alert}, State}).
send_sync_alert(Alert, #state{protocol_specific = #{sender := SenderPid}} = State) ->
    try
        tls_sender:send_and_ack_alert(SenderPid, Alert)
    catch
        _Class:_Reason ->
            throw({stop, {shutdown, own_alert}, State})
    end.
541
%% Closes the transport when the connection terminates.
%% For gen_tcp, the close is graceful: switch to passive mode, shut down the
%% write side, and call recv/3, which returns when the other side has
%% closed or after Timeout. This is a standard trick to maximise the odds
%% that the peer application gets all data sent on the connection (e.g. a
%% TLS alert we generated), not only a transport close. It also avoids
%% hanging if something goes wrong with the network.
%%
%% Reasons:
%%   {close, Timeout}              - user close, or the recursive calls below
%%   {shutdown, transport_closed}  - peer closed the socket: do not wait
%%   {shutdown, own_alert}         - we generated a fatal alert: wait up to
%%                                   ?DEFAULT_TIMEOUT
%%   downgrade                     - leave the socket open
%%   anything else (or not gen_tcp) - close the socket directly
close({close, Timeout}, Socket, gen_tcp = Transport, _ConnectionStates, _Check) ->
    tls_socket:setopts(Transport, Socket, [{active, false}]),
    Transport:shutdown(Socket, write),
    _ = Transport:recv(Socket, 0, Timeout),
    ok;
close({shutdown, transport_closed}, Socket, gen_tcp = Transport, ConnectionStates, Check) ->
    close({close, 0}, Socket, Transport, ConnectionStates, Check);
close({shutdown, own_alert}, Socket, gen_tcp = Transport, ConnectionStates, Check) ->
    close({close, ?DEFAULT_TIMEOUT}, Socket, Transport, ConnectionStates, Check);
close(downgrade, _Socket, _Transport, _ConnectionStates, _Check) ->
    ok;
close(_Reason, Socket, Transport, _ConnectionStates, _Check) ->
    tls_socket:close(Transport, Socket).
%% Human-readable name of the protocol implemented by this module.
protocol_name() ->
    Name = "TLS",
    Name.
569
570%%====================================================================
571%% Data handling
572%%====================================================================
573
%% Builds the user-visible socket for the given pids via tls_socket:socket/5,
%% with this module as the connection callback module.
socket(Pids, Transport, Socket, Trackers) ->
    ConnectionCb = ?MODULE,
    tls_socket:socket(Pids, Transport, Socket, ConnectionCb, Trackers).
576
%% Sets options on the transport socket. Delegates to tls_socket:setopts/3.
setopts(TransportCb, Sock, Opts) ->
    tls_socket:setopts(TransportCb, Sock, Opts).
579
%% Reads options from the transport socket. Delegates to tls_socket:getopts/3.
getopts(TransportCb, Sock, OptTags) ->
    tls_socket:getopts(TransportCb, Sock, OptTags).
582
583%%--------------------------------------------------------------------
584%% State functions
585%%--------------------------------------------------------------------
%%--------------------------------------------------------------------
-spec init(gen_statem:event_type(),
	   {start, timeout()} | term(), #state{}) ->
		   gen_statem:state_function_result().
%%--------------------------------------------------------------------

%% Initial state.
%% Client, on a {start, Timeout} call:
%% - build the first ClientHello (session from
%%   ssl_session:client_select_session/4, optional key shares and session
%%   ticket data) and send it directly on the socket;
%% - store the caller in start_or_recv_from;
%% - move to the hello state with a {timeout, handshake} timer of Timeout
%%   whose event content is close.
%% Every other event (including a start call on the server side) is passed
%% to gen_handshake/4.
init({call, From}, {start, Timeout},
     #state{static_env = #static_env{role = client,
                                     host = Host,
                                     port = Port,
                                     transport_cb = Transport,
                                     socket = Socket,
                                     session_cache = Cache,
                                     session_cache_cb = CacheCb},
            handshake_env = #handshake_env{renegotiation = {Renegotiation, _}} = HsEnv,
            connection_env = CEnv,
	    ssl_options = #{log_level := LogLevel,
                            %% Use highest version in initial ClientHello.
                            %% Versions is a descending list of supported versions.
                            versions := [HelloVersion|_] = Versions,
                            session_tickets := SessionTickets} = SslOpts,
	    session = NewSession,
	    connection_states = ConnectionStates0
	   } = State0) ->
    KeyShare = maybe_generate_client_shares(SslOpts),
    Session = ssl_session:client_select_session({Host, Port, SslOpts}, Cache, CacheCb, NewSession),
    %% Update UseTicket in case of automatic session resumption
    {UseTicket, State1} = tls_handshake_1_3:maybe_automatic_session_resumption(State0),
    TicketData = tls_handshake_1_3:get_ticket_data(self(), SessionTickets, UseTicket),
    Hello = tls_handshake:client_hello(Host, Port, ConnectionStates0, SslOpts,
                                       Session#session.session_id,
                                       Renegotiation,
                                       Session#session.own_certificate,
                                       KeyShare,
                                       TicketData),

    %% The ClientHello starts a fresh handshake history
    Handshake0 = ssl_handshake:init_handshake_history(),

    %% Update pre_shared_key extension with binders (TLS 1.3)
    Hello1 = tls_handshake_1_3:maybe_add_binders(Hello, TicketData, HelloVersion),

    {BinMsg, ConnectionStates, Handshake} =
        encode_handshake(Hello1,  HelloVersion, ConnectionStates0, Handshake0),

    %% Sent directly (not queued in the flight buffer); the result is not checked
    tls_socket:send(Transport, Socket, BinMsg),
    ssl_logger:debug(LogLevel, outbound, 'handshake', Hello1),
    ssl_logger:debug(LogLevel, outbound, 'record', BinMsg),

    %% RequestedVersion is used as the legacy record protocol version and shall be
    %% {3,3} in case of TLS 1.2 and higher. In all other cases it defaults to the
    %% lowest supported protocol version.
    %%
    %% negotiated_version is also used by the TLS 1.3 state machine and is set after
    %% ServerHello is processed.
    RequestedVersion = tls_record:hello_version(Versions),
    State = State1#state{connection_states = ConnectionStates,
                         connection_env = CEnv#connection_env{
                                            negotiated_version = RequestedVersion},
                         session = Session,
                         handshake_env = HsEnv#handshake_env{tls_handshake_history = Handshake},
                         start_or_recv_from = From,
                         key_share = KeyShare},
    next_event(hello, no_record, State, [{{timeout, handshake}, Timeout, close}]);

init(Type, Event, State) ->
    gen_handshake(?FUNCTION_NAME, Type, Event, State).
652
653%%--------------------------------------------------------------------
-spec error(gen_statem:event_type(),
            {start, timeout()} | term(), #state{}) ->
                   gen_statem:state_function_result().
%%--------------------------------------------------------------------
%% error/3: state used when an error reason has been stored under the
%% `error` key of protocol_specific.
%%
%%  - A pending {start, Timeout} call is answered with {error, Reason} and
%%    the process stops with {shutdown, normal}.
%%  - Any other synchronous call is handed to gen_handshake/4.
%%  - All remaining events are postponed (gen_statem retries them after the
%%    next state change).
%%--------------------------------------------------------------------
error({call, Caller}, {start, _}, #state{protocol_specific = #{error := Reason}} = Data) ->
    Reply = {reply, Caller, {error, Reason}},
    {stop_and_reply, {shutdown, normal}, [Reply], Data};
error({call, _} = EventType, Content, Data) ->
    gen_handshake(?FUNCTION_NAME, EventType, Content, Data);
error(_EventType, _Content, _Data) ->
    {keep_state_and_data, [postpone]}.
667
668%%--------------------------------------------------------------------
-spec hello(gen_statem:event_type(),
	    #hello_request{} | #client_hello{} | #server_hello{} | term(),
	    #state{}) ->
		   gen_statem:state_function_result().
%%--------------------------------------------------------------------
%% hello/3: state in which the peer's hello message is processed.
%%
%% Clause order matters:
%%  1. With the `handshake := hello` SSL option, a received ClientHello or
%%     ServerHello is stored in handshake_env, its extensions are replied
%%     to the caller recorded in start_or_recv_from, and the connection
%%     moves to `user_hello`.
%%  2. Otherwise a ClientHello is checked with choose_tls_version/2:
%%     TLS 1.3 continues in the `start` state with the same hello re-posted,
%%     TLS 1.2 (and older) is negotiated here via tls_handshake:hello/4.
%%  3. A ServerHello (client role only) is processed by
%%     tls_handshake:hello/4; legacy versions continue via
%%     ssl_connection:handle_session/7, TLS 1.3 continues in `wait_sh`.
%%  4. `info` messages go to handle_info/3, all other events to
%%     gen_handshake/4.
%%--------------------------------------------------------------------
hello(internal, #client_hello{extensions = Extensions} = Hello,
      #state{ssl_options = #{handshake := hello},
             handshake_env = HsEnv,
             start_or_recv_from = From} = State) ->
    %% Pause the handshake: the caller receives the ClientHello extensions.
    {next_state, user_hello, State#state{start_or_recv_from = undefined,
                                         handshake_env = HsEnv#handshake_env{hello = Hello}},
     [{reply, From, {ok, Extensions}}]};
hello(internal, #server_hello{extensions = Extensions} = Hello,
      #state{ssl_options = #{handshake := hello},
             handshake_env = HsEnv,
             start_or_recv_from = From} = State) ->
    %% Same as above, for the ServerHello extensions.
    {next_state, user_hello, State#state{start_or_recv_from = undefined,
                                         handshake_env = HsEnv#handshake_env{hello = Hello}},
     [{reply, From, {ok, Extensions}}]};

hello(internal, #client_hello{client_version = ClientVersion} = Hello,
      #state{connection_states = ConnectionStates0,
             static_env = #static_env{
                             port = Port,
                             session_cache = Cache,
                             session_cache_cb = CacheCb},
             handshake_env = #handshake_env{kex_algorithm = KeyExAlg,
                                            renegotiation = {Renegotiation, _},
                                            negotiated_protocol = CurrentProtocol} = HsEnv,
             connection_env = CEnv,
             session = #session{own_certificate = Cert} = Session0,
	     ssl_options = SslOpts} = State) ->

    %% choose_tls_version/2 (defined elsewhere in this module) decides
    %% whether this ClientHello is handled by the TLS 1.3 state machine.
    case choose_tls_version(SslOpts, Hello) of
        'tls_v1.3' ->
            %% Continue in TLS 1.3 'start' state
            {next_state, start, State, [{next_event, internal, Hello}]};
        'tls_v1.2' ->
            case tls_handshake:hello(Hello,
                                     SslOpts,
                                     {Port, Session0, Cache, CacheCb,
                                      ConnectionStates0, Cert, KeyExAlg},
                                     Renegotiation) of
                #alert{} = Alert ->
                    %% The alert is reported with the client's offered version
                    %% as the negotiated version.
                    ssl_connection:handle_own_alert(Alert, ClientVersion, hello,
                                                    State#state{connection_env = CEnv#connection_env{negotiated_version
                                                                                                     = ClientVersion}});
                {Version, {Type, Session},
                 ConnectionStates, Protocol0, ServerHelloExt, HashSign} ->
                    %% Keep the protocol already stored in handshake_env when
                    %% this hello did not produce one (Protocol0 = undefined).
                    Protocol = case Protocol0 of
                                   undefined -> CurrentProtocol;
                                   _ -> Protocol0
                               end,
                    %% Store the negotiated parameters and let gen_handshake/4
                    %% continue with the common ClientHello processing.
                    gen_handshake(?FUNCTION_NAME,
                                  internal,
                                  {common_client_hello, Type, ServerHelloExt},
                                  State#state{connection_states  = ConnectionStates,
                                              connection_env = CEnv#connection_env{negotiated_version = Version},
                                              handshake_env = HsEnv#handshake_env{
                                                                hashsign_algorithm = HashSign,
                                                                client_hello_version = ClientVersion,
                                                                negotiated_protocol = Protocol},
                                              session = Session
                                             })
            end

    end;
hello(internal, #server_hello{} = Hello,
      #state{connection_states = ConnectionStates0,
             connection_env = #connection_env{negotiated_version = ReqVersion} = CEnv,
	     static_env = #static_env{role = client},
             handshake_env = #handshake_env{renegotiation = {Renegotiation, _}},
	     ssl_options = SslOptions} = State) ->
    %% ReqVersion is the version stored when the ClientHello was sent.
    case tls_handshake:hello(Hello, SslOptions, ConnectionStates0, Renegotiation) of
	#alert{} = Alert -> %%TODO
	    %% NOTE(review): CEnv already holds negotiated_version = ReqVersion,
	    %% so the record update below is a no-op.
	    ssl_connection:handle_own_alert(Alert, ReqVersion, hello,
                                            State#state{connection_env =
                                                            CEnv#connection_env{negotiated_version = ReqVersion}});
        %% Legacy TLS 1.2 and older
	{Version, NewId, ConnectionStates, ProtoExt, Protocol} ->
	    ssl_connection:handle_session(Hello,
					  Version, NewId, ConnectionStates, ProtoExt, Protocol, State);
        %% TLS 1.3
        {next_state, wait_sh, SelectedVersion} ->
            %% Continue in TLS 1.3 'wait_sh' state
            {next_state, wait_sh,
             State#state{
               connection_env = CEnv#connection_env{negotiated_version = SelectedVersion}},
             [{next_event, internal, Hello}]}
    end;
hello(info, Event, State) ->
    handle_info(Event, ?FUNCTION_NAME, State);
hello(Type, Event, State) ->
    gen_handshake(?FUNCTION_NAME, Type, Event, State).
763
%% user_hello/3: entered from hello/3 when the `handshake := hello` option
%% is set (the received hello is kept in handshake_env). Every event is
%% handled by gen_handshake/4.
user_hello(EventType, Content, Data) ->
    gen_handshake(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec abbreviated(gen_statem:event_type(), term(), #state{}) ->
                         gen_statem:state_function_result().
%%--------------------------------------------------------------------
%% abbreviated/3, certify/3 and cipher/3 share one shape: `info` messages
%% are handled by gen_info/3, every other event by gen_handshake/4.
abbreviated(info, Msg, Data) ->
    gen_info(Msg, ?FUNCTION_NAME, Data);
abbreviated(EventType, Content, Data) ->
    gen_handshake(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec certify(gen_statem:event_type(), term(), #state{}) ->
                     gen_statem:state_function_result().
%%--------------------------------------------------------------------
certify(info, Msg, Data) ->
    gen_info(Msg, ?FUNCTION_NAME, Data);
certify(EventType, Content, Data) ->
    gen_handshake(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec cipher(gen_statem:event_type(), term(), #state{}) ->
                    gen_statem:state_function_result().
%%--------------------------------------------------------------------
cipher(info, Msg, Data) ->
    gen_info(Msg, ?FUNCTION_NAME, Data);
cipher(EventType, Content, Data) ->
    gen_handshake(?FUNCTION_NAME, EventType, Content, Data).
793
794%%--------------------------------------------------------------------
-spec connection(gen_statem:event_type(),
                 #hello_request{} | #client_hello{}| term(), #state{}) ->
                        gen_statem:state_function_result().
%%--------------------------------------------------------------------
%% connection/3: state of an established connection.
%%
%% Clauses, in order (later clauses are fall-backs for earlier ones):
%%  - `info` messages are handled by gen_info/3;
%%  - {user_renegotiate, WriteState}: install the given write state and
%%    re-dispatch the call as `renegotiate`;
%%  - {close, {Pid, Timeout}}: downgrade to a plain transport socket that
%%    will be handed to Pid (two variants, see the clause comments);
%%  - #hello_request{} (client role): start a renegotiation handshake;
%%  - #client_hello{} (server role): accept or refuse a client-initiated
%%    renegotiation depending on allow_renegotiate;
%%  - #new_session_ticket{} / #key_update{}: TLS 1.3 post-handshake messages;
%%  - anything else: ssl_connection:connection/4.
%%--------------------------------------------------------------------
connection(info, Event, State) ->
    gen_info(Event, ?FUNCTION_NAME, State);
connection({call, From}, {user_renegotiate, WriteState},
           #state{connection_states = ConnectionStates} = State) ->
    {next_state, ?FUNCTION_NAME,
     State#state{connection_states = ConnectionStates#{current_write => WriteState}},
     [{next_event, {call, From}, renegotiate}]};
connection({call, From}, {close, {Pid, _Timeout}},
           #state{connection_env = #connection_env{terminated = closed} = CEnv,
                  protocol_specific = PS} = State) ->
    %% The connection is already marked `terminated = closed`: no
    %% close_notify is sent. Instead a CLOSE_NOTIFY alert is injected as an
    %% internal event so downgrade/3 completes the hand-over right away.
    {next_state, downgrade,
     State#state{connection_env = CEnv#connection_env{terminated = true,
                                                      downgrade = {Pid, From}},
                 protocol_specific = PS#{active_n_toggle => true,
                                         active_n => 1}},
     [{next_event, internal, ?ALERT_REC(?WARNING, ?CLOSE_NOTIFY)}]};
connection({call, From}, {close, {Pid, Timeout}},
           #state{connection_states = ConnectionStates,
                  protocol_specific = #{sender := Sender} = PS,
                  connection_env = CEnv} = State0) ->
    case tls_sender:downgrade(Sender, Timeout) of
        {ok, Write} ->
            %% User downgrades the connection. When downgrading a TLS
            %% connection to a transport connection the close alert must be
            %% received from the peer before the transport socket is released.
            State = send_alert(?ALERT_REC(?WARNING, ?CLOSE_NOTIFY),
                               State0#state{connection_states =
                                                ConnectionStates#{current_write => Write}}),
            {next_state, downgrade,
             State#state{connection_env = CEnv#connection_env{downgrade = {Pid, From},
                                                              terminated = true},
                         protocol_specific = PS#{active_n_toggle => true,
                                                 active_n => 1}},
             %% Event timeout handled by downgrade(timeout, downgrade, _).
             [{timeout, Timeout, downgrade}]};
        {error, timeout} ->
            {stop_and_reply, {shutdown, downgrade_fail}, [{reply, From, {error, timeout}}]}
    end;
connection(internal, #hello_request{},
           #state{static_env = #static_env{role = client,
                                           host = Host,
                                           port = Port,
                                           session_cache = Cache,
                                           session_cache_cb = CacheCb},
                  handshake_env = #handshake_env{renegotiation = {Renegotiation, peer}},
                  session = #session{own_certificate = Cert} = Session0,
                  ssl_options = SslOpts,
                  protocol_specific = #{sender := Pid},
                  connection_states = ConnectionStates} = State0) ->
    %% Renegotiation tagged `peer`: the sender hands back its write state
    %% first. If that call fails, the connection is stopped.
    try tls_sender:peer_renegotiate(Pid) of
        {ok, Write} ->
            Session = ssl_session:client_select_session({Host, Port, SslOpts}, Cache, CacheCb, Session0),
            Hello = tls_handshake:client_hello(Host, Port, ConnectionStates, SslOpts,
                                               Session#session.session_id,
                                               Renegotiation, Cert, undefined,
                                               undefined),
            {State, Actions} = send_handshake(Hello, State0#state{connection_states = ConnectionStates#{current_write => Write},
                                                                  session = Session}),
            next_event(hello, no_record, State, Actions)
    catch
        _:_ ->
            {stop, {shutdown, sender_blocked}, State0}
    end;
connection(internal, #hello_request{},
           #state{static_env = #static_env{role = client,
                                           host = Host,
                                           port = Port},
                  handshake_env = #handshake_env{renegotiation = {Renegotiation, _}},
                  session = #session{own_certificate = Cert},
                  ssl_options = SslOpts,
                  connection_states = ConnectionStates} = State0) ->
    %% Any other renegotiation: send a ClientHello with an empty session id.
    Hello = tls_handshake:client_hello(Host, Port, ConnectionStates, SslOpts,
                                       <<>>, Renegotiation, Cert, undefined,
                                       undefined),
    {State, Actions} = send_handshake(Hello, State0),
    next_event(hello, no_record, State, Actions);
connection(internal, #client_hello{} = Hello,
           #state{static_env = #static_env{role = server},
                  handshake_env = #handshake_env{allow_renegotiate = true} = HsEnv,
                  connection_states = CS,
                  protocol_specific = #{sender := Sender}
                 } = State) ->
    %% Mitigate Computational DoS attack
    %% http://www.educatedguesswork.org/2011/10/ssltls_and_computational_dos.html
    %% http://www.thc.org/thc-ssl-dos/ Rather than disabling client
    %% initiated renegotiation we will disallow many client initiated
    %% renegotiations immediately after each other.
    erlang:send_after(?WAIT_TO_ALLOW_RENEGOTIATION, self(), allow_renegotiate),
    {ok, Write} = tls_sender:renegotiate(Sender),
    next_event(hello, no_record,
               State#state{connection_states = CS#{current_write => Write},
                           handshake_env = HsEnv#handshake_env{renegotiation = {true, peer},
                                                               allow_renegotiate = false}},
               [{next_event, internal, Hello}]);
connection(internal, #client_hello{},
           #state{static_env = #static_env{role = server},
                  handshake_env = #handshake_env{allow_renegotiate = false}} = State0) ->
    %% Renegotiation currently not allowed: warn the client and stay in
    %% the connection state.
    Alert = ?ALERT_REC(?WARNING, ?NO_RENEGOTIATION),
    send_alert_in_connection(Alert, State0),
    State = reinit_handshake_data(State0),
    next_event(?FUNCTION_NAME, no_record, State);

connection(internal, #new_session_ticket{} = NewSessionTicket, State) ->
    %% TLS 1.3
    handle_new_session_ticket(NewSessionTicket, State),
    next_event(?FUNCTION_NAME, no_record, State);

connection(internal, #key_update{} = KeyUpdate, State0) ->
    %% TLS 1.3
    case handle_key_update(KeyUpdate, State0) of
        {ok, State} ->
            next_event(?FUNCTION_NAME, no_record, State);
        {error, State, Alert} ->
            %% handle_own_alert/4 is called for its side effect (sending and
            %% reporting the alert). Its value is the gen_statem result that
            %% every other call site in this module returns. It must be
            %% returned here too. Discarding it and continuing with
            %% next_event/3 would keep the connection running after the
            %% alert has already been sent.
            ssl_connection:handle_own_alert(Alert, {3,4}, connection, State)
    end;

connection(Type, Event, State) ->
    ssl_connection:?FUNCTION_NAME(Type, Event, State, ?MODULE).
923
924%%--------------------------------------------------------------------
-spec downgrade(gen_statem:event_type(), term(), #state{}) ->
                       gen_statem:state_function_result().
%%--------------------------------------------------------------------
%% downgrade/3: waiting for the peer's close_notify after a downgrade
%% request ({close, {Pid, Timeout}} in connection/3).
%%
%%  - close_notify received: the socket is switched to passive/raw/binary
%%    mode, its controlling process becomes the requesting Pid, and the
%%    caller gets {ok, Socket}.
%%  - downgrade event timeout: the caller gets {error, timeout}.
%%  - transport closed: the caller gets {error, CloseTag}.
%%  - other info messages go to handle_info/3, everything else to
%%    ssl_connection:downgrade/4.
%%--------------------------------------------------------------------
downgrade(internal, #alert{description = ?CLOSE_NOTIFY},
          #state{static_env = #static_env{transport_cb = Transport,
                                          socket = Socket},
                 connection_env = #connection_env{downgrade = {NewOwner, Caller}}} = Data) ->
    tls_socket:setopts(Transport, Socket, [{active, false}, {packet, 0}, {mode, binary}]),
    Transport:controlling_process(Socket, NewOwner),
    {stop_and_reply, {shutdown, downgrade}, [{reply, Caller, {ok, Socket}}], Data};
downgrade(timeout, downgrade,
          #state{connection_env = #connection_env{downgrade = {_, Caller}}} = Data) ->
    {stop_and_reply, {shutdown, normal}, [{reply, Caller, {error, timeout}}], Data};
downgrade(info, {CloseTag, Socket},
          #state{static_env = #static_env{socket = Socket,
                                          close_tag = CloseTag},
                 connection_env = #connection_env{downgrade = {_, Caller}}} = Data) ->
    {stop_and_reply, {shutdown, normal}, [{reply, Caller, {error, CloseTag}}], Data};
downgrade(info, Msg, Data) ->
    handle_info(Msg, ?FUNCTION_NAME, Data);
downgrade(EventType, Content, Data) ->
    ssl_connection:?FUNCTION_NAME(EventType, Content, Data, ?MODULE).
947
948%%--------------------------------------------------------------------
949%% TLS 1.3 state functions
950%%--------------------------------------------------------------------
951%%--------------------------------------------------------------------
-spec start(gen_statem:event_type(), term(), #state{}) ->
                   gen_statem:state_function_result().
%%--------------------------------------------------------------------
%% Every TLS 1.3 state function below has the same two clauses: `info`
%% messages are handled by gen_info_1_3/3, all other events by
%% gen_handshake_1_3/4, which receives the current state name.
%% NOTE(review): the state names appear to mirror the TLS 1.3 handshake
%% state machine of RFC 8446 Appendix A -- confirm before relying on it.
%%
%% start: entered from hello/3 when a ClientHello selects TLS 1.3.
%%--------------------------------------------------------------------
start(info, Msg, Data) ->
    gen_info_1_3(Msg, ?FUNCTION_NAME, Data);
start(EventType, Content, Data) ->
    gen_handshake_1_3(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec negotiated(gen_statem:event_type(), term(), #state{}) ->
                        gen_statem:state_function_result().
%%--------------------------------------------------------------------
negotiated(info, Msg, Data) ->
    gen_info_1_3(Msg, ?FUNCTION_NAME, Data);
negotiated(EventType, Content, Data) ->
    gen_handshake_1_3(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec recvd_ch(gen_statem:event_type(), term(), #state{}) ->
                      gen_statem:state_function_result().
%%--------------------------------------------------------------------
recvd_ch(info, Msg, Data) ->
    gen_info_1_3(Msg, ?FUNCTION_NAME, Data);
recvd_ch(EventType, Content, Data) ->
    gen_handshake_1_3(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec wait_cert(gen_statem:event_type(), term(), #state{}) ->
                       gen_statem:state_function_result().
%%--------------------------------------------------------------------
wait_cert(info, Msg, Data) ->
    gen_info_1_3(Msg, ?FUNCTION_NAME, Data);
wait_cert(EventType, Content, Data) ->
    gen_handshake_1_3(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec wait_cv(gen_statem:event_type(), term(), #state{}) ->
                     gen_statem:state_function_result().
%%--------------------------------------------------------------------
wait_cv(info, Msg, Data) ->
    gen_info_1_3(Msg, ?FUNCTION_NAME, Data);
wait_cv(EventType, Content, Data) ->
    gen_handshake_1_3(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec wait_eoed(gen_statem:event_type(), term(), #state{}) ->
                       gen_statem:state_function_result().
%%--------------------------------------------------------------------
wait_eoed(info, Msg, Data) ->
    gen_info_1_3(Msg, ?FUNCTION_NAME, Data);
wait_eoed(EventType, Content, Data) ->
    gen_handshake_1_3(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec wait_finished(gen_statem:event_type(), term(), #state{}) ->
                           gen_statem:state_function_result().
%%--------------------------------------------------------------------
wait_finished(info, Msg, Data) ->
    gen_info_1_3(Msg, ?FUNCTION_NAME, Data);
wait_finished(EventType, Content, Data) ->
    gen_handshake_1_3(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec wait_flight2(gen_statem:event_type(), term(), #state{}) ->
                          gen_statem:state_function_result().
%%--------------------------------------------------------------------
wait_flight2(info, Msg, Data) ->
    gen_info_1_3(Msg, ?FUNCTION_NAME, Data);
wait_flight2(EventType, Content, Data) ->
    gen_handshake_1_3(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec connected(gen_statem:event_type(), term(), #state{}) ->
                       gen_statem:state_function_result().
%%--------------------------------------------------------------------
connected(info, Msg, Data) ->
    gen_info_1_3(Msg, ?FUNCTION_NAME, Data);
connected(EventType, Content, Data) ->
    gen_handshake_1_3(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec wait_cert_cr(gen_statem:event_type(), term(), #state{}) ->
                          gen_statem:state_function_result().
%%--------------------------------------------------------------------
wait_cert_cr(info, Msg, Data) ->
    gen_info_1_3(Msg, ?FUNCTION_NAME, Data);
wait_cert_cr(EventType, Content, Data) ->
    gen_handshake_1_3(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec wait_ee(gen_statem:event_type(), term(), #state{}) ->
                     gen_statem:state_function_result().
%%--------------------------------------------------------------------
wait_ee(info, Msg, Data) ->
    gen_info_1_3(Msg, ?FUNCTION_NAME, Data);
wait_ee(EventType, Content, Data) ->
    gen_handshake_1_3(?FUNCTION_NAME, EventType, Content, Data).

%%--------------------------------------------------------------------
-spec wait_sh(gen_statem:event_type(), term(), #state{}) ->
                     gen_statem:state_function_result().
%%--------------------------------------------------------------------
%% wait_sh: entered from hello/3 when a received ServerHello selects
%% TLS 1.3; the ServerHello is re-posted as an internal event.
wait_sh(info, Msg, Data) ->
    gen_info_1_3(Msg, ?FUNCTION_NAME, Data);
wait_sh(EventType, Content, Data) ->
    gen_handshake_1_3(?FUNCTION_NAME, EventType, Content, Data).
1058
1059%--------------------------------------------------------------------
1060%% gen_statem callbacks
1061%%--------------------------------------------------------------------
1062callback_mode() ->
1063    state_functions.
1064
1065terminate({shutdown, {sender_died, Reason}}, _StateName,
1066          #state{static_env = #static_env{socket = Socket,
1067                                          transport_cb = Transport}}
1068          = State) ->
1069    ssl_connection:handle_trusted_certs_db(State),
1070    close(Reason, Socket, Transport, undefined, undefined);
1071terminate(Reason, StateName, State) ->
1072    catch ssl_connection:terminate(Reason, StateName, State),
1073    ensure_sender_terminate(Reason, State).
1074
1075format_status(Type, Data) ->
1076    ssl_connection:format_status(Type, Data).
1077
1078code_change(_OldVsn, StateName, State, _) ->
1079    {ok, StateName, State}.
1080
1081%%--------------------------------------------------------------------
1082%%% Internal functions
1083%%--------------------------------------------------------------------
%% initial_state/8: build the initial #state{} of a TLS connection process.
%%
%% The options tuple is {SSLOptions, SocketOptions, Trackers} and the
%% transport tuple is {CbModule, DataTag, CloseTag, ErrorTag, PassiveTag}.
%% The ssl application environment may override the session cache callback
%% (`session_cb`, must be an atom) and the internal active_n value
%% (`internal_active_n`, must be an integer; ignored when erl_dist is true).
%% The User process is monitored and the monitor reference is stored in
%% connection_env.
initial_state(Role, Sender, Host, Port, Socket, {SSLOptions, SocketOptions, Trackers}, User,
              {CbModule, DataTag, CloseTag, ErrorTag, PassiveTag}) ->
    #{beast_mitigation := BeastMitigation,
      erl_dist := IsErlDist,
      client_renegotiation := ClientRenegotiation} = SSLOptions,
    ConnStates = tls_record:init_connection_states(Role, BeastMitigation),
    CacheCb =
        case application:get_env(ssl, session_cb) of
            {ok, Mod} when is_atom(Mod) -> Mod;
            _ -> ssl_session_cache
        end,
    ActiveN =
        case application:get_env(ssl, internal_active_n) of
            {ok, N} when is_integer(N), not IsErlDist -> N;
            _ -> ?INTERNAL_ACTIVE_N
        end,
    Monitor = erlang:monitor(process, User),
    StaticEnv = #static_env{role = Role,
                            transport_cb = CbModule,
                            protocol_cb = ?MODULE,
                            data_tag = DataTag,
                            close_tag = CloseTag,
                            error_tag = ErrorTag,
                            passive_tag = PassiveTag,
                            host = Host,
                            port = Port,
                            socket = Socket,
                            session_cache_cb = CacheCb,
                            trackers = Trackers},
    HsEnv = #handshake_env{tls_handshake_history = ssl_handshake:init_handshake_history(),
                           renegotiation = {false, first},
                           allow_renegotiate = ClientRenegotiation},
    #state{static_env = StaticEnv,
           handshake_env = HsEnv,
           connection_env = #connection_env{user_application = {Monitor, User}},
           socket_options = SocketOptions,
           ssl_options = SSLOptions,
           session = #session{is_resumable = new},
           connection_states = ConnStates,
           protocol_buffers = #protocol_buffers{},
           user_data_buffer = {[], 0, []},
           start_or_recv_from = undefined,
           flight_buffer = [],
           protocol_specific = #{sender => Sender,
                                 active_n => ActiveN,
                                 active_n_toggle => true}}.
1138
%% initialize_tls_sender/1: pass the sender process (protocol_specific key
%% `sender`) its configuration: role, socket, transport callback, socket
%% options, trackers, the current write connection state, the negotiated
%% version and the renegotiate_at / key_update_at / log_level SSL options.
%% Returns whatever tls_sender:initialize/2 returns.
initialize_tls_sender(#state{static_env = #static_env{role = Role,
                                                      transport_cb = Transport,
                                                      socket = Socket,
                                                      trackers = Trackers},
                             connection_env = #connection_env{negotiated_version = Version},
                             socket_options = SockOpts,
                             ssl_options = #{renegotiate_at := RenegotiateAt,
                                             key_update_at := KeyUpdateAt,
                                             log_level := LogLevel},
                             connection_states = #{current_write := WriteState},
                             protocol_specific = #{sender := Sender}}) ->
    tls_sender:initialize(Sender,
                          #{role => Role,
                            socket => Socket,
                            transport_cb => Transport,
                            socket_options => SockOpts,
                            trackers => Trackers,
                            current_write => WriteState,
                            negotiated_version => Version,
                            renegotiate_at => RenegotiateAt,
                            key_update_at => KeyUpdateAt,
                            log_level => LogLevel}).
1163
%% next_tls_record/3: feed newly received socket Data into the record layer.
%% Complete records are appended to the queued cipher texts and
%% next_record/2 is called. If tls_record:get_tls_records/4 returns an
%% alert, it is passed through handle_record_alert/2.
next_tls_record(Data, StateName,
                #state{protocol_buffers =
                           #protocol_buffers{tls_record_buffer = Buf0,
                                             tls_cipher_texts = CT0} = Buffers,
                       ssl_options = SslOpts} = State0) ->
    %% TLSPlaintext.legacy_record_version is ignored in TLS 1.3, so every
    %% record version is accepted while the initial ClientHello/ServerHello
    %% is received. This happens in state 'hello' for all TLS versions and
    %% in state 'start' when TLS 1.3 is negotiated. Once the version is
    %% negotiated, all subsequent TLS records must carry the proper
    %% legacy_record_version (= negotiated_version).
    %% Note: TLS record version {3,4} is used internally in TLS 1.3 and at
    %% this point it is the same as the negotiated protocol version.
    %% TODO: Refactor state machine and introduce a record_protocol_version
    %% beside the negotiated_version.
    AcceptedVersions =
        case lists:member(StateName, [hello, start]) of
            true ->
                [tls_record:protocol_version(Vsn) || Vsn <- ?ALL_AVAILABLE_VERSIONS];
            false ->
                State0#state.connection_env#connection_env.negotiated_version
        end,
    case tls_record:get_tls_records(Data, AcceptedVersions, Buf0, SslOpts) of
        {Records, Buf} ->
            Queued = CT0 ++ Records,
            NewBuffers = Buffers#protocol_buffers{tls_record_buffer = Buf,
                                                  tls_cipher_texts = Queued},
            next_record(StateName, State0#state{protocol_buffers = NewBuffers});
        #alert{} = Alert ->
            handle_record_alert(Alert, State0)
    end.
1196
1197
%% handle_record_alert/2: record-layer alerts are returned unchanged; the
%% state argument is not used.
handle_record_alert(Alert, _State) ->
    Alert.

%% tls_handshake_events/1: wrap each handshake packet as an internal
%% gen_statem next_event action, preserving order.
tls_handshake_events(Packets) ->
    [{next_event, internal, {handshake, Packet}} || Packet <- Packets].
1205
%% handle_info/3 handles raw (non gen_statem event) messages delivered to the
%% connection process. The clauses are tried in order, so their order matters.
%%
%% Raw data from the socket: unpack it into TLS records. next_tls_record/3
%% either yields {Record, State}, which is fed on as the next event, or an
%% #alert{}, which is reported as our own alert using the original State0.
handle_info({Protocol, _, Data}, StateName,
            #state{static_env = #static_env{data_tag = Protocol},
                   connection_env = #connection_env{negotiated_version = Version}} = State0) ->
    case next_tls_record(Data, StateName, State0) of
	{Record, State} ->
	    next_event(StateName, Record, State);
	#alert{} = Alert ->
	    ssl_connection:handle_own_alert(Alert, Version, StateName, State0)
    end;
%% Passive notification for our socket (message tagged with passive_tag,
%% presumably sent when the socket's {active, N} count runs out; confirm).
%% active_n_toggle is set in protocol_specific in both branches. If a caller
%% is waiting (start_or_recv_from is set) and no cipher texts are buffered,
%% the socket is re-activated via activate_socket/1 right away. Otherwise
%% processing continues with no_record.
handle_info({PassiveTag, Socket},  StateName,
            #state{static_env = #static_env{socket = Socket,
                                            passive_tag = PassiveTag},
                   start_or_recv_from = From,
                   protocol_buffers = #protocol_buffers{tls_cipher_texts = CTs},
                   protocol_specific = PS
                  } = State0) ->
    case (From =/= undefined) andalso (CTs == []) of
        true ->
            {Record, State} = activate_socket(State0#state{protocol_specific = PS#{active_n_toggle => true}}),
            next_event(StateName, Record, State);
        false ->
            next_event(StateName, no_record,
                       State0#state{protocol_specific = PS#{active_n_toggle => true}})
    end;
%% Transport closed while in any state other than connection: let
%% ssl_connection:maybe_invalidate_session/6 decide about the session, report
%% a fatal close_notify alert (reason transport_closed) and stop.
handle_info({CloseTag, Socket}, StateName,
            #state{static_env = #static_env{
                                   role = Role,
                                   host = Host,
                                   port = Port,
                                   socket = Socket,
                                   close_tag = CloseTag},
                   handshake_env = #handshake_env{renegotiation = Type},
                   connection_env = #connection_env{negotiated_version = Version},
                   session = Session} = State) when  StateName =/= connection ->
    ssl_connection:maybe_invalidate_session(Version, Type, Role, Host, Port, Session),
    Alert = ?ALERT_REC(?FATAL, ?CLOSE_NOTIFY, transport_closed),
    ssl_connection:handle_normal_shutdown(Alert#alert{role = Role}, StateName, State),
    {stop, {shutdown, transport_closed}, State};
%% Transport closed, typically in the connection state (the previous clause
%% handles the other states). Stop at once, unless the socket is
%% {active, false} and cipher texts or user data are still buffered. In that
%% case keep running so the application can still read the buffered data.
handle_info({CloseTag, Socket}, StateName,
            #state{static_env = #static_env{
                                   role = Role,
                                   socket = Socket,
                                   close_tag = CloseTag},
                   socket_options = #socket_options{active = Active},
                   protocol_buffers = #protocol_buffers{tls_cipher_texts = CTs},
                   user_data_buffer = {_,BufferSize,_},
                   protocol_specific = PS} = State) ->

    %% Note that as of TLS 1.1,
    %% failure to properly close a connection no longer requires that a
    %% session not be resumed.  This is a change from TLS 1.0 to conform
    %% with widespread implementation practice.

    case (Active == false) andalso ((CTs =/= []) or (BufferSize =/= 0)) of
        false ->
            %% Invalidating sessions here causes performance issues, so we
            %% follow the widespread implementation practice and go against
            %% the spec. The spec-conforming code would have been:
            %% case Version of
            %%     {3, N} when N >= 1 ->
            %%         ok;
            %%     _ ->
            %%         invalidate_session(Role, Host, Port, Session)
            %%         ok
            %% end,
            Alert = ?ALERT_REC(?FATAL, ?CLOSE_NOTIFY, transport_closed),
            ssl_connection:handle_normal_shutdown(Alert#alert{role = Role}, StateName, State),
            {stop, {shutdown, transport_closed}, State};
        true ->
            %% Fixes non-delivery of the final TLS record in {active, once}.
            %% This gives the application the chance to set {active, once}
            %% again and then receive the final message. The intent is that
            %% the socket close message is still sent if there is not enough
            %% data to deliver.
            %% NOTE(review): an earlier wording said "set internal active_n
            %% to zero", but the code below sets active_n_toggle => true;
            %% confirm how active_n_toggle is interpreted.
            next_event(StateName, no_record, State#state{protocol_specific = PS#{active_n_toggle => true}})
    end;
%% The sender process recorded in protocol_specific exited: stop the
%% connection with {shutdown, {sender_died, Reason}}.
handle_info({'EXIT', Sender, Reason}, _,
            #state{protocol_specific = #{sender := Sender}} = State) ->
    {stop, {shutdown, {sender_died, Reason}}, State};
%% Anything else is delegated to the ssl_connection state function of the
%% same name.
handle_info(Msg, StateName, State) ->
    ssl_connection:StateName(info, Msg, State, ?MODULE).
1287
%% Process a list of decoded alerts in order. The gen_statem result is
%% threaded through ssl_connection:handle_alert/3, and the final result is
%% returned.
%%  - An empty list returns the accumulated Result.
%%  - Processing stops as soon as a {stop, _, _} result has been produced.
%%  - A warning close_notify in the connection state, with the socket in
%%    {active, false} and cipher texts or user data still buffered, does not
%%    close the connection. It only sets connection_env terminated = true,
%%    presumably so the buffered data can still be read. Any remaining alerts
%%    are not processed.
handle_alerts([], Result) ->
    Result;
handle_alerts(_, {stop, _, _} = Stop) ->
    Stop;
handle_alerts([#alert{level = ?WARNING, description = ?CLOSE_NOTIFY} | _Alerts],
              {next_state, connection = StateName, #state{connection_env = CEnv,
                                                          socket_options = #socket_options{active = false},
                                                          user_data_buffer = {_,BufferSize,_},
                                                          protocol_buffers = #protocol_buffers{tls_cipher_texts = CTs}} =
                   State}) when (BufferSize =/= 0) orelse
                                (CTs =/= []) ->
    {next_state, StateName, State#state{connection_env = CEnv#connection_env{terminated = true}}};
handle_alerts([Alert | Alerts], {next_state, StateName, State}) ->
     handle_alerts(Alerts, ssl_connection:handle_alert(Alert, StateName, State));
%% NOTE(review): this clause discards _Actions from the previous result
%% before handling the next alert. Confirm that dropping those actions is
%% intended.
handle_alerts([Alert | Alerts], {next_state, StateName, State, _Actions}) ->
     handle_alerts(Alerts, ssl_connection:handle_alert(Alert, StateName, State)).
1304
%% Encode a handshake message for Version. The encoded fragment is appended
%% to the handshake history and then wrapped into TLS record(s).
%% Returns {EncodedRecords, NewConnectionStates, NewHistory}.
encode_handshake(Handshake, Version, ConnStates0, History0) ->
    HsBin = tls_handshake:encode_handshake(Handshake, Version),
    History = ssl_handshake:update_handshake_history(History0, HsBin),
    {Records, ConnStates} = tls_record:encode_handshake(HsBin, Version, ConnStates0),
    {Records, ConnStates, History}.
1311
%% Encode a ChangeCipherSpec record. The message carries no data of its own,
%% so only the version and connection states are passed on.
encode_change_cipher(#change_cipher_spec{}, Version, ConnStates) ->
    tls_record:encode_change_cipher_spec(Version, ConnStates).
1314
%% Decode alert record payload(s); delegates to ssl_alert:decode/1.
decode_alerts(AlertBin) ->
    ssl_alert:decode(AlertBin).
1317
%% Run the ssl_connection state function named StateName for a handshake
%% event. Any exception raised while handling it becomes a fatal
%% handshake_failure alert (reason malformed_handshake_data).
gen_handshake(StateName, Type, Event,
              #state{connection_env = #connection_env{negotiated_version = Version}} = State) ->
    try
        ssl_connection:StateName(Type, Event, State, ?MODULE)
    catch
        _Class:_Reason ->
            Alert = ?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, malformed_handshake_data),
            ssl_connection:handle_own_alert(Alert, Version, StateName, State)
    end.
1329
1330
%% Run the tls_connection_1_3 state function named StateName for a handshake
%% event. Any exception raised while handling it becomes a fatal
%% handshake_failure alert (reason malformed_handshake_data).
gen_handshake_1_3(StateName, Type, Event,
                  #state{connection_env = #connection_env{negotiated_version = Version}} = State) ->
    try
        tls_connection_1_3:StateName(Type, Event, State, ?MODULE)
    catch
        _Class:_Reason ->
            Alert = ?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, malformed_handshake_data),
            ssl_connection:handle_own_alert(Alert, Version, StateName, State)
    end.
1342
1343
%% Run handle_info/3 for an info event and turn any exception into our own
%% fatal alert. In the connection state this is internal_error (reason
%% malformed_data); in every other state it is handshake_failure (reason
%% malformed_handshake_data).
gen_info(Event, StateName, #state{connection_env = #connection_env{negotiated_version = Version}} = State) ->
    try
        handle_info(Event, StateName, State)
    catch
        _Class:_Reason ->
            Alert = case StateName of
                        connection ->
                            ?ALERT_REC(?FATAL, ?INTERNAL_ERROR, malformed_data);
                        _ ->
                            ?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, malformed_handshake_data)
                    end,
            ssl_connection:handle_own_alert(Alert, Version, StateName, State)
    end.
1365
%% Run handle_info/3 for an info event (TLS 1.3 variant) and turn any
%% exception into our own fatal alert. In the state named `connected` this is
%% internal_error (reason malformed_data); in every other state it is
%% handshake_failure (reason malformed_handshake_data).
%% NOTE(review): gen_info/3 and handle_info/3 use the atom `connection` for
%% the established state, not `connected`. Confirm which state name can
%% reach this function.
gen_info_1_3(Event, StateName, #state{connection_env = #connection_env{negotiated_version = Version}} = State) ->
    try
        handle_info(Event, StateName, State)
    catch
        _Class:_Reason ->
            Alert = case StateName of
                        connected ->
                            ?ALERT_REC(?FATAL, ?INTERNAL_ERROR, malformed_data);
                        _ ->
                            ?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, malformed_handshake_data)
                    end,
            ssl_connection:handle_own_alert(Alert, Version, StateName, State)
    end.
1387
1388
unprocessed_events(Events) ->
    %% Number of handshake events still pending after the first one. The
    %% first event is queued first and handled immediately, so Count - 1
    %% events remain before more TLS records from the socket should be
    %% processed. An empty list yields -1.
    Count = erlang:length(Events),
    Count - 1.
1396
1397
%% Sanity-check a handshake buffer that still holds unprocessed bytes.
%% Returns true when Bin looks like an incomplete handshake message (more
%% data is needed). Otherwise throws a fatal handshake_failure alert.
%% NOTE(review): presumably called only when no complete handshake message
%% could be extracted from the buffer; confirm against the caller.
%%
%% A handshake message header is 4 bytes: msg_type (1 byte) followed by a
%% uint24 length. Handshake messages may be fragmented across records at any
%% byte, so the buffer can legitimately end inside that 4-byte header.
assert_buffer_sanity(<<?BYTE(_Type), ?UINT24(Length), Rest/binary>>,
                     #{max_handshake_size := Max}) when
      Length =< Max ->
    case byte_size(Rest) of
        N when N < Length ->
            %% Body not complete yet: wait for more data.
            true;
        N when N > Length ->
            throw(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE,
                             too_big_handshake_data));
        _ ->
            throw(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE,
                             malformed_handshake_data))
    end;
assert_buffer_sanity(Bin, _) ->
    case byte_size(Bin) of
        N when N < 4 ->
            %% Partial 4-byte header (0..3 bytes): wait for more data.
            %% The previous bound (N < 3) rejected a valid 3-byte fragment.
            true;
        _ ->
            %% A full header is present but the first clause did not match,
            %% e.g. the announced length exceeds max_handshake_size.
            throw(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE,
                             malformed_handshake_data))
    end.
1419
%% Make sure the tls_sender process dies when the connection process is
%% terminated normally. This is needed if the sender is blocked in
%% prim_inet:send. The kill is issued after 5000 ms from a separate process
%% started with plain spawn/1. That process is unlinked, so it outlives the
%% caller. Returns that process's pid, or ok during downgrade.
ensure_sender_terminate(downgrade, _) ->
    %% The sender must stay alive during the downgrade phase.
    ok;
ensure_sender_terminate(_, #state{protocol_specific = #{sender := Sender}}) ->
    spawn(fun() ->
                  receive
                  after 5000 ->
                          catch exit(Sender, kill)
                  end
          end).
1432
%% Generate a key_share for the ClientHello when the most preferred
%% configured version is TLS 1.3 ({3,4}). Only the most preferred supported
%% group gets an entry. Returns undefined in all other cases.
maybe_generate_client_shares(#{versions := [{3,4} | _],
                               supported_groups :=
                                   #supported_groups{
                                      supported_groups = [Preferred | _]}}) ->
    ssl_cipher:generate_client_shares([Preferred]);
maybe_generate_client_shares(_) ->
    undefined.
1442
%% Choose which handshake implementation handles an incoming ClientHello.
%% 'tls_v1.3' is returned only when the ClientHello carries
%% client_hello_versions and ssl_handshake:select_supported_version/2
%% picks {3,4}. Every other case yields 'tls_v1.2'.
choose_tls_version(#{versions := Versions},
                   #client_hello{
                      extensions = #{client_hello_versions :=
                                         #client_hello_versions{versions = ClientVersions}}}) ->
    Selected = ssl_handshake:select_supported_version(ClientVersions, Versions),
    if
        Selected =:= {3,4} -> 'tls_v1.3';
        true -> 'tls_v1.2'
    end;
choose_tls_version(_, _) ->
    'tls_v1.2'.
1457
1458
%% Version used to interpret handshake data.
%%
%% The first element of the configured `versions` list is treated as the
%% highest supported version (Highest). Highest is returned when either:
%%  - nothing has been negotiated yet (undefined): startup, any role; or
%%  - we are a client, the negotiated version is {3,3} and Highest is
%%    {3,4} or newer. In the shared 'init' state the legacy part of the state
%%    machine expects negotiated_version to hold the requested legacy
%%    version ({3,3} when TLS 1.3 is preferred). The new TLS 1.3 extensions
%%    can only be processed with the effective version {3,4}.
%% Otherwise the negotiated version is returned unchanged.
effective_version(Negotiated, #{versions := [Highest | _]}, Role)
  when Negotiated =:= undefined;
       Negotiated =:= {3,3}, Role =:= client, Highest >= {3,4} ->
    Highest;
effective_version(Negotiated, _, _) ->
    Negotiated.
1473
1474
%% Handle a received NewSessionTicket according to the session_tickets option:
%%   disabled -> ignored (returns ok);
%%   manual   -> handed to the user process via send_ticket_data/5;
%%   auto     -> stored with tls_client_ticket_store:store_ticket/4.
handle_new_session_ticket(_, #state{ssl_options = #{session_tickets := disabled}}) ->
    ok;
handle_new_session_ticket(#new_session_ticket{ticket_nonce = Nonce} = Ticket,
                          #state{connection_states = ConnStates,
                                 ssl_options = #{session_tickets := manual,
                                                 server_name_indication := SNI},
                                 connection_env = #connection_env{user_application = {_, User}}}) ->
    {HKDF, PSK} = new_session_ticket_psk(Nonce, ConnStates),
    send_ticket_data(User, Ticket, HKDF, SNI, PSK);
handle_new_session_ticket(#new_session_ticket{ticket_nonce = Nonce} = Ticket,
                          #state{connection_states = ConnStates,
                                 ssl_options = #{session_tickets := auto,
                                                 server_name_indication := SNI}}) ->
    {HKDF, PSK} = new_session_ticket_psk(Nonce, ConnStates),
    tls_client_ticket_store:store_ticket(Ticket, HKDF, SNI, PSK).

%% Derive the PSK for a ticket nonce from the resumption master secret of the
%% current read connection state. Returns {HKDF, PSK}, where HKDF is the
%% prf_algorithm of that state.
new_session_ticket_psk(Nonce, ConnStates) ->
    #{security_parameters := SecParams} =
        ssl_record:current_connection_state(ConnStates, read),
    HKDF = SecParams#security_parameters.prf_algorithm,
    RMS = SecParams#security_parameters.resumption_master_secret,
    {HKDF, tls_v1:pre_shared_key(RMS, Nonce, HKDF)}.
1500
1501
%% Process a KeyUpdate received from the peer. The read key is always
%% rotated. If the peer requested an update, we answer with our own
%% KeyUpdate (update_not_requested) through the sender process, which also
%% updates the sender's write key.
%% Returns {ok, State} or {error, State, Alert}.
handle_key_update(#key_update{request_update = update_not_requested}, State) ->
    NewState = update_cipher_key(current_read, State),
    {ok, NewState};
handle_key_update(#key_update{request_update = update_requested},
                  #state{protocol_specific = #{sender := SenderPid}} = State) ->
    NewState = update_cipher_key(current_read, State),
    SendResult = send_key_update(SenderPid, update_not_requested),
    case SendResult of
        {error, Why} ->
            {error, NewState, ?ALERT_REC(?FATAL, ?INTERNAL_ERROR, Why)};
        ok ->
            {ok, NewState}
    end.
1516
1517
%% Rotate the application traffic secret of the connection state stored
%% under ConnStateName (e.g. current_read). New traffic key and IV are derived
%% from the updated secret, and that state's sequence number is reset to 0.
%% Accepts either a #state{} record (returns the updated record) or the
%% connection-states map itself (returns the updated map).
update_cipher_key(ConnStateName, #state{connection_states = ConnStates} = State) ->
    State#state{connection_states = update_cipher_key(ConnStateName, ConnStates)};
update_cipher_key(ConnStateName, ConnStates) ->
    OldConnState = maps:get(ConnStateName, ConnStates),
    #{security_parameters := OldSecParams,
      cipher_state := OldCipherState} = OldConnState,
    HKDF = OldSecParams#security_parameters.prf_algorithm,
    Suite = OldSecParams#security_parameters.cipher_suite,
    OldSecret = OldSecParams#security_parameters.application_traffic_secret,
    NewSecret = tls_v1:update_traffic_secret(HKDF, OldSecret),

    %% Derive the traffic key/IV for the suite's bulk cipher.
    #{cipher := Cipher} = ssl_cipher_format:suite_bin_to_map(Suite),
    {Key, IV} = tls_v1:calculate_traffic_keys(HKDF, Cipher, NewSecret),

    NewConnState =
        OldConnState#{security_parameters =>
                          OldSecParams#security_parameters{application_traffic_secret = NewSecret},
                      cipher_state =>
                          OldCipherState#cipher_state{key = Key, iv = IV},
                      sequence_number => 0},
    ConnStates#{ConnStateName => NewConnState}.
1539
1540
%% Build a KeyUpdate handshake message of the given Type and pass it to the
%% sender process as a post-handshake message. Returns the sender's reply.
send_key_update(Sender, Type) ->
    tls_sender:send_post_handshake(Sender, tls_handshake_1_3:key_update(Type)).
1544
1545
%% Send ticket data to the user process as an opaque binary (used when
%% session_tickets is manual). The ticket is bundled with its HKDF algorithm,
%% SNI, PSK and a timestamp in seconds, encoded with term_to_binary/1, and
%% sent as {ssl, session_ticket, {SNI, Binary}}. Returns the sent message.
send_ticket_data(User, NewSessionTicket, HKDF, SNI, PSK) ->
    %% `second` is the current time-unit name; `seconds` is a deprecated alias.
    Timestamp = erlang:system_time(second),
    TicketData = #{hkdf => HKDF,
                   sni => SNI,
                   psk => PSK,
                   timestamp => Timestamp,
                   ticket => NewSessionTicket},
    User ! {ssl, session_ticket, {SNI, erlang:term_to_binary(TicketData)}}.
1555