1%% Copyright (c) 2011-2021, Loïc Hoguin <essen@ninenines.eu>
2%% Copyright (c) 2021, Maria Scott <maria-12648430@hnc-agency.org>
3%%
4%% Permission to use, copy, modify, and/or distribute this software for any
5%% purpose with or without fee is hereby granted, provided that the above
6%% copyright notice and this permission notice appear in all copies.
7%%
8%% THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9%% WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10%% MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11%% ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12%% WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13%% ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14%% OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16%% Make sure to never reload this module outside a release upgrade,
17%% as calling l(ranch_conns_sup) twice will kill the process and all
18%% the currently open connections.
19-module(ranch_conns_sup).
20
21%% API.
22-export([start_link/6]).
23-export([start_protocol/3]).
24-export([active_connections/1]).
25
26%% Supervisor internals.
27-export([init/7]).
28-export([system_continue/3]).
29-export([system_terminate/4]).
30-export([system_code_change/4]).
31
%% How connection processes are reported through which_children and
%% count_children, and whether a three-element {ok, SupPid, ProtocolPid}
%% return from Protocol:start_link/3 is accepted (supervisor only).
-type conn_type() :: worker | supervisor.
%% How connection processes are stopped when this supervisor terminates:
%% brutal_kill kills them outright; otherwise they are sent a shutdown
%% exit signal and given this many milliseconds (or infinity) to exit.
-type shutdown() :: brutal_kill | timeout().
34
%% Internal state of the connections supervisor loop. The connection
%% count, the child count and the list of sleeping acceptors are passed
%% to loop/4 as separate arguments rather than stored here.
-record(state, {
	%% Parent process; an 'EXIT' from it terminates this supervisor.
	parent = undefined :: pid(),
	ref :: ranch:ref(),
	%% Identifier of this connections supervisor; selects its slots in
	%% the stats counters (see inc_accept/3 and inc_terminate/3).
	id :: pos_integer(),
	conn_type :: conn_type(),
	shutdown :: shutdown(),
	transport = undefined :: module(),
	protocol = undefined :: module(),
	%% Protocol options passed to Protocol:start_link/3.
	opts :: any(),
	%% Timeout forwarded to the connection process in the handshake message.
	handshake_timeout :: timeout(),
	max_conns = undefined :: ranch:max_conns(),
	stats_counters_ref :: counters:counters_ref(),
	%% Alarm name => {AlarmOpts, CooldownTimer}; the timer is undefined
	%% while the alarm is not cooling down.
	alarms = #{} :: #{term() => {map(), undefined | reference()}},
	logger = undefined :: module()
}).
50
51%% API.
52
%% Start the connections supervisor with proc_lib: init/7 runs in the new
%% process, which acknowledges the caller with {ok, Pid} once ready.
-spec start_link(ranch:ref(), pos_integer(), module(), any(), module(), module()) -> {ok, pid()}.
start_link(Ref, Id, Transport, TransOpts, Protocol, Logger) ->
	InitArgs = [self(), Ref, Id, Transport, TransOpts, Protocol, Logger],
	proc_lib:start_link(?MODULE, init, InitArgs).
57
%% Hand an accepted socket over to the connections supervisor and block
%% until it allows this acceptor to continue.
%%
%% The caller is assumed to be on the same node as the supervisor. It
%% passes MonitorRef, presumably from a monitor it holds on SupPid; the
%% 'DOWN' message for that monitor is the only way out besides the
%% supervisor's answer, so no timeout is used. Once the request is sent:
%%  *  The supervisor may die. We raise its exit reason, and the
%%     rest_for_one strategy takes all acceptors down anyway.
%%  *  There may be too many connections. The supervisor only resumes
%%     the acceptor once the count drops below the limit again.
%%  *  The supervisor may be overloaded, because there are too many
%%     acceptors or max_connections is too large. Not accepting more
%%     connections while waiting leaves room for this to resolve.
%%
%% The supervisor signals that the acceptor may continue by sending its
%% own pid; there is no other reply payload.
-spec start_protocol(pid(), reference(), inet:socket()) -> ok.
start_protocol(SupPid, MonitorRef, Socket) ->
	Request = {?MODULE, start_protocol, self(), Socket},
	SupPid ! Request,
	receive
		{'DOWN', MonitorRef, process, SupPid, Reason} ->
			error(Reason);
		SupPid ->
			ok
	end.
83
%% Return the number of connections currently counted by the supervisor.
%%
%% Unlike start_protocol/3, this may be called from any process, possibly
%% on another node. It therefore monitors the supervisor, sends without
%% auto-connecting (noconnect) and gives up with exit(timeout) after
%% 5000 milliseconds.
-spec active_connections(pid()) -> non_neg_integer().
active_connections(SupPid) ->
	Tag = erlang:monitor(process, SupPid),
	Request = {?MODULE, active_connections, self(), Tag},
	catch erlang:send(SupPid, Request, [noconnect]),
	receive
		{Tag, Count} ->
			erlang:demonitor(Tag, [flush]),
			Count;
		{'DOWN', Tag, _, _, DownReason} ->
			ExitReason = case DownReason of
				noconnection -> {nodedown, node(SupPid)};
				_ -> DownReason
			end,
			exit(ExitReason)
	after 5000 ->
		erlang:demonitor(Tag, [flush]),
		exit(timeout)
	end.
103
104%% Supervisor internals.
105
%% Entry point of the supervisor process, run through proc_lib.
%% It traps exits, registers itself with ranch_server, gathers its
%% configuration, acknowledges the parent and enters the main loop with
%% no connections, no children and no sleeping acceptors.
-spec init(pid(), ranch:ref(), pos_integer(), module(), any(), module(), module()) -> no_return().
init(Parent, Ref, Id, Transport, TransOpts, Protocol, Logger) ->
	process_flag(trap_exit, true),
	ok = ranch_server:set_connections_sup(Ref, Id, self()),
	MaxConnections = ranch_server:get_max_connections(Ref),
	%% Settings read from the transport options, with their defaults.
	AlarmSpecs = get_alarms(TransOpts),
	ConnectionType = maps:get(connection_type, TransOpts, worker),
	ShutdownPolicy = maps:get(shutdown, TransOpts, 5000),
	HsTimeout = maps:get(handshake_timeout, TransOpts, 5000),
	ProtocolOpts = ranch_server:get_protocol_options(Ref),
	Counters = ranch_server:get_stats_counters(Ref),
	InitialState = #state{
		parent = Parent,
		ref = Ref,
		id = Id,
		conn_type = ConnectionType,
		shutdown = ShutdownPolicy,
		transport = Transport,
		protocol = Protocol,
		opts = ProtocolOpts,
		stats_counters_ref = Counters,
		handshake_timeout = HsTimeout,
		max_conns = MaxConnections,
		alarms = AlarmSpecs,
		logger = Logger
	},
	ok = proc_lib:init_ack(Parent, {ok, self()}),
	loop(InitialState, 0, 0, []).
124
%% Main receive loop of the connections supervisor.
%%
%% Besides State, the loop carries:
%%  *  CurConns: connections counted against max_connections, i.e.
%%     processes marked active in the process dictionary.
%%  *  NbChildren: live connection processes, marked active or removed.
%%  *  Sleepers: acceptors waiting to be resumed (by sending them this
%%     process' pid) because the connection limit was reached.
loop(State=#state{parent=Parent, ref=Ref, id=Id, conn_type=ConnType,
		transport=Transport, protocol=Protocol, opts=Opts, stats_counters_ref=StatsCounters,
		alarms=Alarms, max_conns=MaxConns, logger=Logger}, CurConns, NbChildren, Sleepers) ->
	receive
		%% An acceptor hands over a socket: start a connection process for it.
		{?MODULE, start_protocol, To, Socket} ->
			try Protocol:start_link(Ref, Transport, Opts) of
				{ok, Pid} ->
					inc_accept(StatsCounters, Id, 1),
					handshake(State, CurConns, NbChildren, Sleepers, To, Socket, Pid, Pid);
				{ok, SupPid, ProtocolPid} when ConnType =:= supervisor ->
					inc_accept(StatsCounters, Id, 1),
					handshake(State, CurConns, NbChildren, Sleepers, To, Socket, SupPid, ProtocolPid);
				Ret ->
					To ! self(),
					ranch:log(error,
						"Ranch listener ~p connection process start failure; "
						"~p:start_link/3 returned: ~999999p~n",
						[Ref, Protocol, Ret], Logger),
					Transport:close(Socket),
					loop(State, CurConns, NbChildren, Sleepers)
			catch Class:Reason ->
				To ! self(),
				ranch:log(error,
					"Ranch listener ~p connection process start failure; "
					"~p:start_link/3 crashed with reason: ~p:~999999p~n",
					[Ref, Protocol, Class, Reason], Logger),
				Transport:close(Socket),
				loop(State, CurConns, NbChildren, Sleepers)
			end;
		{?MODULE, active_connections, To, Tag} ->
			To ! {Tag, CurConns},
			loop(State, CurConns, NbChildren, Sleepers);
		%% Remove a connection from the count of connections. The process
		%% stays supervised (marked removed) until it exits.
		{remove_connection, Ref, Pid} ->
			case put(Pid, removed) of
				active when Sleepers =:= [] ->
					loop(State, CurConns - 1, NbChildren, Sleepers);
				%% Still at or above the limit after this removal, for example
				%% because max_connections was lowered: keep the acceptors
				%% asleep, as the 'EXIT' clause below does.
				active when CurConns > MaxConns ->
					loop(State, CurConns - 1, NbChildren, Sleepers);
				active ->
					[To|Sleepers2] = Sleepers,
					To ! self(),
					loop(State, CurConns - 1, NbChildren, Sleepers2);
				removed ->
					loop(State, CurConns, NbChildren, Sleepers);
				%% Not one of our children: undo the put/2 above.
				undefined ->
					_ = erase(Pid),
					loop(State, CurConns, NbChildren, Sleepers)
			end;
		%% Upgrade the max number of connections allowed concurrently.
		%% We resume all sleeping acceptors if this number increases.
		{set_max_conns, MaxConns2} when MaxConns2 > MaxConns ->
			_ = [To ! self() || To <- Sleepers],
			loop(State#state{max_conns=MaxConns2},
				CurConns, NbChildren, []);
		{set_max_conns, MaxConns2} ->
			loop(State#state{max_conns=MaxConns2},
				CurConns, NbChildren, Sleepers);
		%% Upgrade the transport options.
		{set_transport_options, TransOpts} ->
			set_transport_options(State, CurConns, NbChildren, Sleepers, TransOpts);
		%% Upgrade the protocol options.
		{set_protocol_options, Opts2} ->
			loop(State#state{opts=Opts2},
				CurConns, NbChildren, Sleepers);
		%% The cooldown of a configured alarm is over: re-arm it and
		%% check it against the current connection count right away.
		{timeout, _, {activate_alarm, AlarmName}} when is_map_key(AlarmName, Alarms) ->
			{AlarmOpts, _} = maps:get(AlarmName, Alarms),
			NewAlarm = trigger_alarm(Ref, AlarmName, {AlarmOpts, undefined}, CurConns),
			loop(State#state{alarms=Alarms#{AlarmName => NewAlarm}}, CurConns, NbChildren, Sleepers);
		%% Cooldown timer of an alarm that is no longer configured.
		{timeout, _, {activate_alarm, _}} ->
			loop(State, CurConns, NbChildren, Sleepers);
		{'EXIT', Parent, Reason} ->
			terminate(State, Reason, NbChildren);
		{'EXIT', Pid, Reason} when Sleepers =:= [] ->
			case erase(Pid) of
				active ->
					inc_terminate(StatsCounters, Id, 1),
					report_error(Logger, Ref, Protocol, Pid, Reason),
					loop(State, CurConns - 1, NbChildren - 1, Sleepers);
				removed ->
					inc_terminate(StatsCounters, Id, 1),
					report_error(Logger, Ref, Protocol, Pid, Reason),
					loop(State, CurConns, NbChildren - 1, Sleepers);
				undefined ->
					loop(State, CurConns, NbChildren, Sleepers)
			end;
		%% Resume a sleeping acceptor if needed.
		{'EXIT', Pid, Reason} ->
			case erase(Pid) of
				active when CurConns > MaxConns ->
					inc_terminate(StatsCounters, Id, 1),
					report_error(Logger, Ref, Protocol, Pid, Reason),
					loop(State, CurConns - 1, NbChildren - 1, Sleepers);
				active ->
					inc_terminate(StatsCounters, Id, 1),
					report_error(Logger, Ref, Protocol, Pid, Reason),
					[To|Sleepers2] = Sleepers,
					To ! self(),
					loop(State, CurConns - 1, NbChildren - 1, Sleepers2);
				removed ->
					inc_terminate(StatsCounters, Id, 1),
					report_error(Logger, Ref, Protocol, Pid, Reason),
					loop(State, CurConns, NbChildren - 1, Sleepers);
				undefined ->
					loop(State, CurConns, NbChildren, Sleepers)
			end;
		{system, From, Request} ->
			sys:handle_system_msg(Request, From, Parent, ?MODULE, [],
				{State, CurConns, NbChildren, Sleepers});
		%% Calls from the supervisor module.
		{'$gen_call', {To, Tag}, which_children} ->
			Children = [{Protocol, Pid, ConnType, [Protocol]}
				|| {Pid, Type} <- get(),
				Type =:= active orelse Type =:= removed],
			To ! {Tag, Children},
			loop(State, CurConns, NbChildren, Sleepers);
		{'$gen_call', {To, Tag}, count_children} ->
			Counts = case ConnType of
				worker -> [{supervisors, 0}, {workers, NbChildren}];
				supervisor -> [{supervisors, NbChildren}, {workers, 0}]
			end,
			Counts2 = [{specs, 1}, {active, NbChildren}|Counts],
			To ! {Tag, Counts2},
			loop(State, CurConns, NbChildren, Sleepers);
		{'$gen_call', {To, Tag}, _} ->
			To ! {Tag, {error, ?MODULE}},
			loop(State, CurConns, NbChildren, Sleepers);
		Msg ->
			ranch:log(error,
				"Ranch listener ~p received unexpected message ~p~n",
				[Ref, Msg], Logger),
			loop(State, CurConns, NbChildren, Sleepers)
	end.
256
%% Give the accepted socket to the new connection process and start
%% tracking it.
%%
%% On success the connection process is sent the handshake message and
%% the supervised pid is marked active. The acceptor is then resumed if
%% we are still below max_connections, or put to sleep otherwise, and
%% alarms are re-evaluated.
%%
%% If the socket cannot be handed over, it is closed and the supervised
%% pid is killed. The caller has already counted an accept for it, and
%% its 'EXIT' will be ignored by loop/4 because it was never marked
%% active, so it is counted as terminated here.
handshake(State=#state{ref=Ref, id=Id, transport=Transport,
		handshake_timeout=HandshakeTimeout, max_conns=MaxConns,
		stats_counters_ref=StatsCounters, alarms=Alarms0},
		CurConns, NbChildren, Sleepers, To, Socket, SupPid, ProtocolPid) ->
	case Transport:controlling_process(Socket, ProtocolPid) of
		ok ->
			ProtocolPid ! {handshake, Ref, Transport, Socket, HandshakeTimeout},
			put(SupPid, active),
			CurConns2 = CurConns + 1,
			%% Resume the acceptor right away while below the limit,
			%% otherwise let it sleep until a connection goes away.
			Sleepers2 = if CurConns2 < MaxConns ->
					To ! self(),
					Sleepers;
				true ->
					[To|Sleepers]
			end,
			Alarms1 = trigger_alarms(Ref, Alarms0, CurConns2),
			loop(State#state{alarms=Alarms1}, CurConns2, NbChildren + 1, Sleepers2);
		{error, _} ->
			Transport:close(Socket),
			%% Only kill the supervised pid, because the connection's pid,
			%% when different, is supposed to be sitting under it and linked.
			exit(SupPid, kill),
			%% Keep the accept/terminate stats counters balanced: this
			%% process was counted as accepted but will never be seen as
			%% an active child exiting.
			inc_terminate(StatsCounters, Id, 1),
			To ! self(),
			loop(State, CurConns, NbChildren, Sleepers)
	end.
280
%% Check every configured alarm against the current number of
%% connections and return the updated alarms map.
trigger_alarms(Ref, Alarms, CurConns) ->
	Check = fun(AlarmName, Alarm) ->
		trigger_alarm(Ref, AlarmName, Alarm, CurConns)
	end,
	maps:map(Check, Alarms).
289
%% Fire an alarm that is not cooling down once the connection count
%% reaches its treshold. The callback, either {Mod, Fun} or a fun of
%% arity 4, runs in a separate unlinked process. It receives the
%% listener ref, the alarm name, this supervisor's pid and the pids
%% currently marked active. The alarm then starts its cooldown.
%% Any other alarm (below treshold, cooling down, or incomplete
%% options) is returned unchanged.
trigger_alarm(Ref, AlarmName, {Opts=#{treshold := Treshold, callback := Callback}, undefined}, CurConns)
		when CurConns >= Treshold ->
	SupPid = self(),
	ActivePids = [Pid || {Pid, active} <- get()],
	_ = case Callback of
		{Mod, Fun} ->
			spawn(Mod, Fun, [Ref, AlarmName, SupPid, ActivePids]);
		_ ->
			spawn(fun () -> Callback(Ref, AlarmName, SupPid, ActivePids) end)
	end,
	CooldownTimer = schedule_activate_alarm(AlarmName, Opts),
	{Opts, CooldownTimer};
trigger_alarm(_, _, Alarm, _) ->
	Alarm.
302
%% Start the cooldown timer of an alarm. When it expires, this process
%% receives {timeout, Timer, {activate_alarm, AlarmName}}. Returns
%% undefined when the options carry no positive cooldown.
schedule_activate_alarm(AlarmName, Opts) ->
	case Opts of
		#{cooldown := Cooldown} when Cooldown > 0 ->
			erlang:start_timer(Cooldown, self(), {activate_alarm, AlarmName});
		_ ->
			undefined
	end.
307
%% Extract the num_connections alarms from the transport options as a
%% map of Name => {AlarmOpts, undefined}, i.e. not cooling down. A
%% missing cooldown defaults to 5000. Alarms of any other type, and an
%% absent or non-map alarms option, are ignored.
get_alarms(#{alarms := Alarms}) when is_map(Alarms) ->
	maps:from_list([
		{Name, {maps:merge(#{cooldown => 5000}, AlarmOpts), undefined}}
		|| {Name, AlarmOpts = #{type := num_connections}} <- maps:to_list(Alarms)
	]);
get_alarms(_) ->
	#{}.
322
%% Apply new transport options at runtime, then continue in the main loop.
%% Options applied: max_connections (default 1024), handshake_timeout and
%% shutdown (default 5000 each), and alarms. All sleeping acceptors are
%% resumed when the connection limit is raised.
set_transport_options(State=#state{max_conns=OldMax}, CurConns, NbChildren, Sleepers, TransOpts) ->
	NewMax = maps:get(max_connections, TransOpts, 1024),
	NewHsTimeout = maps:get(handshake_timeout, TransOpts, 5000),
	NewShutdown = maps:get(shutdown, TransOpts, 5000),
	StillSleeping = if
		NewMax > OldMax ->
			lists:foreach(fun(Acceptor) -> Acceptor ! self() end, Sleepers),
			[];
		true ->
			Sleepers
	end,
	WithAlarms = set_alarm_option(State, TransOpts, CurConns),
	NewState = WithAlarms#state{max_conns=NewMax,
		handshake_timeout=NewHsTimeout, shutdown=NewShutdown},
	loop(NewState, CurConns, NbChildren, StillSleeping).
337
%% Replace the alarms with the ones configured in TransOpts, carrying over
%% or cancelling the cooldown timers of existing alarms, then check every
%% alarm against the current number of connections.
set_alarm_option(State=#state{ref=Ref, alarms=Current}, TransOpts, CurConns) ->
	Merged = merge_alarms(Current, get_alarms(TransOpts)),
	State#state{alarms=trigger_alarms(Ref, Merged, CurConns)}.
343
%% Merge the running alarms (Old) with a newly configured set (New); both
%% are maps of Name => {AlarmOpts, Timer}. Alarms present in both take
%% the new options and keep an adapted cooldown timer. Alarms only in New
%% are taken as is. Alarms only in Old are dropped, and their cooldown
%% timer, if any, is cancelled.
merge_alarms(Old, New) ->
	OldList = lists:sort(maps:to_list(Old)),
	NewList = lists:sort(maps:to_list(New)),
	Merged = merge_alarms(OldList, NewList, []),
	maps:from_list(Merged).

%% Walk both name-sorted lists in step, accumulating the merged entries.
merge_alarms([], News, Acc) ->
	News ++ Acc;
merge_alarms([{_, {_, Timer}}|Olds], [], Acc) ->
	ok = drop_alarm_timer(Timer),
	merge_alarms(Olds, [], Acc);
merge_alarms([{Name, {OldOpts, Timer}}|Olds], [{Name, {NewOpts, _}}|News], Acc) ->
	merge_alarms(Olds, News, [{Name, {NewOpts, adapt_alarm_timer(Name, Timer, OldOpts, NewOpts)}}|Acc]);
merge_alarms([{OldName, {_, Timer}}|Olds], News=[{NewName, _}|_], Acc) when OldName < NewName ->
	ok = drop_alarm_timer(Timer),
	merge_alarms(Olds, News, Acc);
merge_alarms(Olds, [New|News], Acc) ->
	merge_alarms(Olds, News, [New|Acc]).

%% Cancel the cooldown timer of an alarm that is being dropped. An alarm
%% that is not cooling down has no timer (undefined); handing that to
%% erlang:cancel_timer/1 would raise badarg and crash this supervisor.
drop_alarm_timer(undefined) ->
	ok;
drop_alarm_timer(Timer) ->
	_ = cancel_alarm_reactivation_timer(Timer),
	ok.
364
%% Carry an alarm's cooldown timer over a change of its options, and
%% return the timer to keep: the same one, a new one, or undefined.
%%
%% Not in cooldown: nothing to adapt.
adapt_alarm_timer(_, undefined, _, _) ->
	undefined;
%% Same cooldown as before: keep the running timer.
adapt_alarm_timer(_, Timer, #{cooldown := Cooldown}, #{cooldown := Cooldown}) ->
	Timer;
%% Cooldown switched off: cancel the timer.
adapt_alarm_timer(_, Timer, _, #{cooldown := 0}) ->
	_ = cancel_alarm_reactivation_timer(Timer),
	undefined;
%% Cooldown changed: cancel the running timer and start one for the time
%% still left under the new cooldown, given the time already elapsed.
%% If no time is left, the alarm is not cooling down anymore.
adapt_alarm_timer(Name, Timer, #{cooldown := OldCooldown}, #{cooldown := NewCooldown}) ->
	Elapsed = OldCooldown - cancel_alarm_reactivation_timer(Timer),
	Remaining = NewCooldown - Elapsed,
	if
		Remaining > 0 ->
			erlang:start_timer(Remaining, self(), {activate_alarm, Name});
		true ->
			undefined
	end.
384
%% Cancel an alarm cooldown timer and return the number of milliseconds
%% that were left on it. Returns 0 if the timer had already expired,
%% after flushing its pending reactivation message, and 0 when there is
%% no timer at all (undefined).
cancel_alarm_reactivation_timer(undefined) ->
	%% Alarms that are not cooling down have no timer; passing undefined
	%% to erlang:cancel_timer/1 would raise badarg.
	0;
cancel_alarm_reactivation_timer(Timer) ->
	case erlang:cancel_timer(Timer) of
		%% Timer had already expired when we tried to cancel it, so we flush the
		%% reactivation message it sent and return 0 as remaining time.
		false ->
			ok = receive {timeout, Timer, {activate_alarm, _}} -> ok after 0 -> ok end,
			0;
		%% Timer has not yet expired, we return the amount of time that was remaining.
		TimeLeft ->
			TimeLeft
	end.
396
%% Stop every connection process still supervised (marked active or
%% removed), count NbChildren terminations in the stats counters and exit
%% with Reason.
%% With brutal_kill the children are killed outright. Otherwise they are
%% sent a shutdown exit signal, and we wait for them to go down. Unless
%% Shutdown is infinity, a kill message arrives after Shutdown
%% milliseconds and the remaining children are then killed.
-spec terminate(#state{}, any(), non_neg_integer()) -> no_return().
terminate(#state{shutdown=Shutdown, id=Id,
		stats_counters_ref=StatsCounters}, Reason, NbChildren) ->
	Children = get_keys(active) ++ get_keys(removed),
	case Shutdown of
		brutal_kill ->
			kill_children(Children);
		infinity ->
			shutdown_children(Children),
			wait_children(NbChildren);
		_ ->
			shutdown_children(Children),
			_ = erlang:send_after(Shutdown, self(), kill),
			wait_children(NbChildren)
	end,
	inc_terminate(StatsCounters, Id, NbChildren),
	exit(Reason).
418
%% Add N to the accept count of connections supervisor Id. Each Id owns
%% two consecutive counter slots; accepts use the odd one (2*Id-1).
inc_accept(StatsCounters, Id, N) ->
	AcceptIndex = Id * 2 - 1,
	counters:add(StatsCounters, AcceptIndex, N).
422
%% Add N to the terminate count of connections supervisor Id. Each Id
%% owns two consecutive counter slots; terminates use the even one (2*Id).
inc_terminate(StatsCounters, Id, N) ->
	TerminateIndex = Id * 2,
	counters:add(StatsCounters, TerminateIndex, N).
426
%% Kill the given processes outright. Each one is unlinked first, so that
%% no 'EXIT' message is received for it. Always returns ok.
kill_children(Pids) ->
	lists:foreach(fun(Pid) ->
		unlink(Pid),
		exit(Pid, kill)
	end, Pids).
435
%% Ask the given processes to shut down. Each one is monitored first, so
%% we learn when it is gone, and unlinked, so no extra 'EXIT' message
%% arrives. It is then sent a shutdown exit signal. Always returns ok.
shutdown_children(Pids) ->
	lists:foreach(fun(Pid) ->
		_ = monitor(process, Pid),
		unlink(Pid),
		exit(Pid, shutdown)
	end, Pids).
446
%% Wait until NbChildren supervised processes (marked active or removed)
%% have gone down, forgetting each one as its 'DOWN' arrives. 'DOWN'
%% messages for other processes are ignored. If a kill message arrives
%% first, every process still marked active or removed is killed.
wait_children(0) ->
	ok;
wait_children(NbChildren) ->
	receive
		{'DOWN', _, process, Pid, _} ->
			Remaining = case erase(Pid) of
				active -> NbChildren - 1;
				removed -> NbChildren - 1;
				_ -> NbChildren
			end,
			wait_children(Remaining);
		kill ->
			Leftovers = get_keys(active) ++ get_keys(removed),
			lists:foreach(fun(Child) -> exit(Child, kill) end, Leftovers)
	end.
464
%% sys callback: resume the main loop with the saved loop arguments.
-spec system_continue(_, _, any()) -> no_return().
system_continue(_Parent, _Debug, {LoopState, Conns, Children, Waiting}) ->
	loop(LoopState, Conns, Children, Waiting).
468
%% sys callback: terminate with Reason, shutting down the live children.
-spec system_terminate(any(), _, _, _) -> no_return().
system_terminate(Reason, _Parent, _Debug, {LoopState, _Conns, Children, _Waiting}) ->
	terminate(LoopState, Reason, Children).
472
%% sys callback used during release upgrades and downgrades of this process.
%%
%% Downgrade ({down, _}): convert the current #state{} into an 11-element
%% {state, ...} tuple without id, stats_counters_ref and alarms,
%% presumably the layout of the previous release (confirm against it).
%% Upgrade: convert that 11-element tuple back into the current record.
%% id is recovered by finding this process in
%% ranch_server:get_connections_sups/1, the stats counters are fetched
%% again, and alarms start from the record default (#{}).
%% Any other term is returned unchanged.
%% NOTE(review): cooldown timers running at downgrade time will still
%% deliver {timeout, _, {activate_alarm, _}} messages to the older code.
-spec system_code_change(any(), _, _, _) -> {ok, any()}.
system_code_change({#state{parent=Parent, ref=Ref, conn_type=ConnType,
		shutdown=Shutdown, transport=Transport, protocol=Protocol,
		opts=Opts, handshake_timeout=HandshakeTimeout,
		max_conns=MaxConns, logger=Logger}, CurConns, NbChildren,
		Sleepers}, _, {down, _}, _) ->
	{ok, {{state, Parent, Ref, ConnType, Shutdown, Transport, Protocol,
		Opts, HandshakeTimeout, MaxConns, Logger}, CurConns, NbChildren,
		Sleepers}};
system_code_change({{state, Parent, Ref, ConnType, Shutdown, Transport, Protocol,
		Opts, HandshakeTimeout, MaxConns, Logger}, CurConns, NbChildren,
		Sleepers}, _, _, _) ->
	Self = self(),
	%% Crashes with badmatch unless exactly one entry belongs to this process.
	[Id] = [Id || {Id, Pid} <- ranch_server:get_connections_sups(Ref), Pid=:=Self],
	StatsCounters = ranch_server:get_stats_counters(Ref),
	{ok, {#state{parent=Parent, ref=Ref, id=Id, conn_type=ConnType, shutdown=Shutdown,
		transport=Transport, protocol=Protocol, opts=Opts,
		handshake_timeout=HandshakeTimeout, max_conns=MaxConns,
		stats_counters_ref=StatsCounters,
		logger=Logger}, CurConns, NbChildren, Sleepers}};
system_code_change(Misc, _, _, _) ->
	{ok, Misc}.
495
%% Log the abnormal exit of a connection process. The exit reasons
%% normal, shutdown and {shutdown, _} are expected and not logged.
%% We use ~999999p here instead of ~w because the latter doesn't
%% support printable strings.
report_error(Logger, Ref, Protocol, Pid, Reason) ->
	case Reason of
		normal ->
			ok;
		shutdown ->
			ok;
		{shutdown, _} ->
			ok;
		_ ->
			ranch:log(error,
				"Ranch listener ~p had connection process started with "
				"~p:start_link/3 at ~p exit with reason: ~999999p~n",
				[Ref, Protocol, Pid, Reason], Logger)
	end.
509