%%
%%
%% Copyright WhatsApp Inc. and its affiliates. All rights reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%%     http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%
%%
%%-------------------------------------------------------------------
%%
%% @author Maxim Fedorov <maximfca@gmail.com>
%% Process Groups with eventually consistent membership.
%%
%% Differences (compared to pg2):
%%  * non-existent and empty groups are treated the same (an empty list
%%     of pids), thus create/1 and delete/1 have no effect (and are not
%%     implemented); which_groups() returns only non-empty groups
%%  * no cluster lock required, and no dependency on global
%%  * all join/leave operations require a local process (it is not possible
%%     to join a process from a different node)
%%  * multi-join: join/leave several processes with a single call
%%
%% Why empty groups are not supported:
%%  Unlike a process, a group does not have an originating node, so it is
%%  possible that during a net split one node deletes a group that still
%%  exists in another partition. pg2 would recover the group as soon as the
%%  net split heals, which is quite unexpected.
%%
%% Exchange protocol (an example message flow follows the list below):
%%  * when the pg process starts, it broadcasts a
%%     'discover' message to all nodes in the cluster
%%  * when a pg server receives 'discover', it responds with a 'sync' message
%%     containing the list of groups with all local processes, and starts to
%%     monitor the process that sent the 'discover' message (assuming it is
%%     a part of an overlay network)
%%  * every pg process monitors 'nodeup' messages to attempt discovery for
%%     nodes that are (re)joining the cluster
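%%
%% Example flow (a sketch; node names 'a@host'/'b@host' and pids PidA/PidB
%%  are assumed for illustration only). When a scope process PidA starts on
%%  node a@host while PidB already runs the same scope on b@host:
%%   1. PidA sends {discover, PidA} to {Scope, 'b@host'}
%%   2. PidB casts {sync, PidB, [{Group, LocalPids}, ...]} back to PidA,
%%      starts monitoring PidA, and replies {discover, PidB}
%%   3. PidA handles that 'discover' symmetrically: it casts its own 'sync'
%%      and starts monitoring PidB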
%%
%% Leave/join operations:
%%  * processes joining a group are monitored on the local node
%%  * when a process exits (without leaving its groups prior to exit), the
%%     local instance of the pg scoped process detects this and sends 'leave'
%%     to all nodes in the overlay network (no remote monitoring is done)
%%  * all leave/join operations are serialised through the pg server process
%%
-module(pg).

%% API: default scope
-export([
    start_link/0,

    join/2,
    leave/2,
    get_members/1,
    get_local_members/1,
    which_groups/0,
    which_local_groups/0
]).

%% Scoped API: overlay networks
-export([
    start/1,
    start_link/1,

    join/3,
    leave/3,
    get_members/2,
    get_local_members/2,
    which_groups/1,
    which_local_groups/1
]).

%% gen_server callbacks
-export([
    init/1,
    handle_call/3,
    handle_cast/2,
    handle_info/2,
    terminate/2
]).

%% Types
-type group() :: any().

%% Default scope started by kernel app
-define(DEFAULT_SCOPE, ?MODULE).

%%--------------------------------------------------------------------
%% @doc
%% Starts the server and links it to the calling process.
%% Uses the default scope, which is the same as the module name.
-spec start_link() -> {ok, pid()} | {error, any()}.
start_link() ->
    start_link(?DEFAULT_SCOPE).

%% @doc
%% Starts the server outside of a supervision hierarchy.
-spec start(Scope :: atom()) -> {ok, pid()} | {error, any()}.
start(Scope) when is_atom(Scope) ->
    gen_server:start({local, Scope}, ?MODULE, [Scope], []).

%% @doc
%% Starts the server and links it to the calling process.
%% The scope name is supplied by the caller.
-spec start_link(Scope :: atom()) -> {ok, pid()} | {error, any()}.
start_link(Scope) when is_atom(Scope) ->
    gen_server:start_link({local, Scope}, ?MODULE, [Scope], []).
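
%% Usage sketch (the scope name 'my_scope' below is assumed for illustration):
%%  a scoped pg server is typically placed under a supervisor, e.g. with a
%%  child spec such as
%%   #{id => my_scope, start => {pg, start_link, [my_scope]}}
%%  Once the scope runs on each node, pg:join(my_scope, Group, Pid) and the
%%  other scoped calls operate on that overlay network only.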

%%--------------------------------------------------------------------
%% @doc
%% Joins a single process or a list of processes to a group.
%% The group is created automatically.
%% Processes must be local to this node.
-spec join(Group :: group(), PidOrPids :: pid() | [pid()]) -> ok.
join(Group, PidOrPids) ->
    join(?DEFAULT_SCOPE, Group, PidOrPids).

-spec join(Scope :: atom(), Group :: group(), PidOrPids :: pid() | [pid()]) -> ok.
join(Scope, Group, PidOrPids) when is_pid(PidOrPids); is_list(PidOrPids) ->
    ok = ensure_local(PidOrPids),
    gen_server:call(Scope, {join_local, Group, PidOrPids}, infinity).
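
%% Example (the group name 'workers' and pids P1, P2 are illustrative):
%%   ok = pg:join(workers, self()),             %% default scope, single pid
%%   ok = pg:join(my_scope, workers, [P1, P2]). %% scoped multi-join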

%%--------------------------------------------------------------------
%% @doc
%% Makes a single process or a list of processes leave the group.
%% Processes must be local to this node.
-spec leave(Group :: group(), PidOrPids :: pid() | [pid()]) -> ok | not_joined.
leave(Group, PidOrPids) ->
    leave(?DEFAULT_SCOPE, Group, PidOrPids).

-spec leave(Scope :: atom(), Group :: group(), PidOrPids :: pid() | [pid()]) -> ok | not_joined.
leave(Scope, Group, PidOrPids) when is_pid(PidOrPids); is_list(PidOrPids) ->
    ok = ensure_local(PidOrPids),
    gen_server:call(Scope, {leave_local, Group, PidOrPids}, infinity).
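
%% Example (continuing the illustrative 'workers' group):
%%   ok = pg:leave(workers, self()),
%%   not_joined = pg:leave(workers, self()).  %% already left by now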

%%--------------------------------------------------------------------
%% @doc
%% Returns all processes in a group.
-spec get_members(Group :: group()) -> [pid()].
get_members(Group) ->
    get_members(?DEFAULT_SCOPE, Group).

-spec get_members(Scope :: atom(), Group :: group()) -> [pid()].
get_members(Scope, Group) ->
    try
        ets:lookup_element(Scope, Group, 2)
    catch
        error:badarg ->
            []
    end.

%%--------------------------------------------------------------------
%% @doc
%% Returns the processes in a group that are running on the local node.
-spec get_local_members(Group :: group()) -> [pid()].
get_local_members(Group) ->
    get_local_members(?DEFAULT_SCOPE, Group).

-spec get_local_members(Scope :: atom(), Group :: group()) -> [pid()].
get_local_members(Scope, Group) ->
    try
        ets:lookup_element(Scope, Group, 3)
    catch
        error:badarg ->
            []
    end.
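
%% Both lookups read the scope ETS table directly, so they are cheap and never
%%  call into the pg server. A sketch (pids are illustrative):
%%   [P1, P2, P3] = pg:get_members(workers),       %% cluster-wide view
%%   [P1]         = pg:get_local_members(workers). %% local members only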

%%--------------------------------------------------------------------
%% @doc
%% Returns a list of all known groups.
-spec which_groups() -> [Group :: group()].
which_groups() ->
    which_groups(?DEFAULT_SCOPE).

-spec which_groups(Scope :: atom()) -> [Group :: group()].
which_groups(Scope) when is_atom(Scope) ->
    [G || [G] <- ets:match(Scope, {'$1', '_', '_'})].

%%--------------------------------------------------------------------
%% @doc
%% Returns a list of groups that have any local processes joined.
-spec which_local_groups() -> [Group :: group()].
which_local_groups() ->
    which_local_groups(?DEFAULT_SCOPE).

-spec which_local_groups(Scope :: atom()) -> [Group :: group()].
which_local_groups(Scope) when is_atom(Scope) ->
    ets:select(Scope, [{{'$1', '_', '$2'}, [{'=/=', '$2', []}], ['$1']}]).
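
%% The match spec above binds the group name to '$1' and the local-members
%%  list to '$2', keeping only rows where '$2' is non-empty; it is equivalent
%%  to this (less efficient) comprehension, shown only for clarity:
%%   [G || {G, _All, Local} <- ets:tab2list(Scope), Local =/= []].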

%%--------------------------------------------------------------------
%% Internal implementation

%% gen_server implementation
-record(state, {
    %% ETS table name, and also the registered process name (self())
    scope :: atom(),
    %% monitored local processes and groups they joined
    monitors = #{} :: #{pid() => {MRef :: reference(), Groups :: [group()]}},
    %% remote node: scope process monitor, and a map of groups to pids for the fast sync routine
    nodes = #{} :: #{pid() => {reference(), #{group() => [pid()]}}}
}).
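
%% Each non-empty group is one row in the scope ETS table:
%%   {Group, AllMembers :: [pid()], LocalMembers :: [pid()]}
%% get_members/2 and get_local_members/2 read elements 2 and 3 of this row
%%  directly via ets:lookup_element/3.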

-type state() :: #state{}.

-spec init([Scope :: atom()]) -> {ok, state()}.
init([Scope]) ->
    ok = net_kernel:monitor_nodes(true),
    %% discover all nodes in the cluster
    broadcast([{Scope, Node} || Node <- nodes()], {discover, self()}),
    Scope = ets:new(Scope, [set, protected, named_table, {read_concurrency, true}]),
    {ok, #state{scope = Scope}}.

-spec handle_call(Call :: {join_local, Group :: group(), Pid :: pid()}
                        | {leave_local, Group :: group(), Pid :: pid()},
                  From :: {pid(), Tag :: any()},
                  State :: state()) -> {reply, ok | not_joined, state()}.

handle_call({join_local, Group, PidOrPids}, _From, #state{scope = Scope, monitors = Monitors, nodes = Nodes} = State) ->
    NewMons = join_monitors(PidOrPids, Group, Monitors),
    join_local_group(Scope, Group, PidOrPids),
    broadcast(maps:keys(Nodes), {join, self(), Group, PidOrPids}),
    {reply, ok, State#state{monitors = NewMons}};

handle_call({leave_local, Group, PidOrPids}, _From, #state{scope = Scope, monitors = Monitors, nodes = Nodes} = State) ->
    case leave_monitors(PidOrPids, Group, Monitors) of
        Monitors ->
            {reply, not_joined, State};
        NewMons ->
            leave_local_group(Scope, Group, PidOrPids),
            broadcast(maps:keys(Nodes), {leave, self(), PidOrPids, [Group]}),
            {reply, ok, State#state{monitors = NewMons}}
    end;

handle_call(_Request, _From, _S) ->
    error(badarg).

-spec handle_cast(
    {sync, Peer :: pid(), Groups :: [{group(), [pid()]}]},
    State :: state()) -> {noreply, state()}.

handle_cast({sync, Peer, Groups}, #state{scope = Scope, nodes = Nodes} = State) ->
    {noreply, State#state{nodes = handle_sync(Scope, Peer, Nodes, Groups)}};

handle_cast(_, _State) ->
    error(badarg).

-spec handle_info(
    {discover, Peer :: pid()} |
    {join, Peer :: pid(), group(), pid() | [pid()]} |
    {leave, Peer :: pid(), pid() | [pid()], [group()]} |
    {'DOWN', reference(), process, pid(), term()} |
    {nodedown, node()} | {nodeup, node()}, State :: state()) -> {noreply, state()}.

%% remote pid or several pids joining the group
handle_info({join, Peer, Group, PidOrPids}, #state{scope = Scope, nodes = Nodes} = State) ->
    case maps:get(Peer, Nodes, []) of
        {MRef, RemoteGroups} ->
            join_remote(Scope, Group, PidOrPids),
            %% store the remote group => pids map for the fast sync operation
            NewRemoteGroups = join_remote_map(Group, PidOrPids, RemoteGroups),
            {noreply, State#state{nodes = Nodes#{Peer => {MRef, NewRemoteGroups}}}};
        [] ->
            %% handles a possible race condition: when the remote node flickers
            %%  up/down, a remote join can arrive after the node has already
            %%  left the overlay network.
            %% It also handles the case when a node outside of the overlay
            %%  network sends an unexpected join request.
            {noreply, State}
    end;

%% remote pid leaving (multiple groups at once)
handle_info({leave, Peer, PidOrPids, Groups}, #state{scope = Scope, nodes = Nodes} = State) ->
    case maps:get(Peer, Nodes, []) of
        {MRef, RemoteMap} ->
            _ = leave_remote(Scope, PidOrPids, Groups),
            NewRemoteMap = leave_update_remote_map(PidOrPids, RemoteMap, Groups),
            {noreply, State#state{nodes = Nodes#{Peer => {MRef, NewRemoteMap}}}};
        [] ->
            %% Handles a race condition: the remote node disconnected while its
            %%  scope process was just about to send a 'leave' message. In this
            %%  case, the local node handles 'DOWN' first, but then the
            %%  connection is restored, and the 'leave' message is delivered
            %%  when it is no longer expected.
            %% It also handles the case when a node outside of the overlay
            %%  network sends an unexpected leave request.
            {noreply, State}
    end;

%% we're being discovered, let's exchange!
handle_info({discover, Peer}, #state{nodes = Nodes, monitors = Monitors} = State) ->
    gen_server:cast(Peer, {sync, self(), all_local_pids(Monitors)}),
    %% do we know who is looking for us?
    case maps:is_key(Peer, Nodes) of
        true ->
            {noreply, State};
        false ->
            MRef = monitor(process, Peer),
            erlang:send(Peer, {discover, self()}, [noconnect]),
            {noreply, State#state{nodes = Nodes#{Peer => {MRef, #{}}}}}
    end;

%% handle local process exit
handle_info({'DOWN', MRef, process, Pid, _Info}, #state{scope = Scope, monitors = Monitors, nodes = Nodes} = State) when node(Pid) =:= node() ->
    case maps:take(Pid, Monitors) of
        error ->
            %% this can only happen when a leave request and the 'DOWN' signal
            %%  are both already in the pg server's message queue
            {noreply, State};
        {{MRef, Groups}, NewMons} ->
            [leave_local_group(Scope, Group, Pid) || Group <- Groups],
            %% send the update to all nodes
            broadcast(maps:keys(Nodes), {leave, self(), Pid, Groups}),
            {noreply, State#state{monitors = NewMons}}
    end;

%% handle remote node down or leaving the overlay network
handle_info({'DOWN', MRef, process, Pid, _Info}, #state{scope = Scope, nodes = Nodes} = State) ->
    {{MRef, RemoteMap}, NewNodes} = maps:take(Pid, Nodes),
    maps:foreach(fun (Group, Pids) -> leave_remote(Scope, Pids, [Group]) end, RemoteMap),
    {noreply, State#state{nodes = NewNodes}};

%% nodedown: ignore, and wait for the 'DOWN' signal of the monitored process
handle_info({nodedown, _Node}, State) ->
    {noreply, State};

%% nodeup: discover if the remote node participates in the overlay network
handle_info({nodeup, Node}, State) when Node =:= node() ->
    {noreply, State};
handle_info({nodeup, Node}, #state{scope = Scope} = State) ->
    {Scope, Node} ! {discover, self()},
    {noreply, State};

handle_info(_Info, _State) ->
    error(badarg).

-spec terminate(Reason :: any(), State :: state()) -> true.
terminate(_Reason, #state{scope = Scope}) ->
    true = ets:delete(Scope).

%%--------------------------------------------------------------------
%% Internal implementation

%% Ensures the argument is either a node-local pid or a list of such,
%%  otherwise throws an error
ensure_local(Pid) when is_pid(Pid), node(Pid) =:= node() ->
    ok;
ensure_local(Pids) when is_list(Pids) ->
    lists:foreach(
        fun
            (Pid) when is_pid(Pid), node(Pid) =:= node() ->
                ok;
            (Bad) ->
                error({nolocal, Bad})
        end, Pids);
ensure_local(Bad) ->
    error({nolocal, Bad}).

%% Overrides all knowledge of the remote node with the information it sends
%%  to the local node. The current implementation must do a full table scan
%%  to remove stale pids (just as for 'nodedown').
handle_sync(Scope, Peer, Nodes, Groups) ->
    %% can't use maps:get/3 here, because its default value is evaluated
    %%   eagerly, and in this case the monitor/2 call has a side effect
    {MRef, RemoteGroups} =
        case maps:find(Peer, Nodes) of
            error ->
                {monitor(process, Peer), #{}};
            {ok, MRef0} ->
                MRef0
        end,
    %% sync RemoteMap and transform the ETS table
    _ = sync_groups(Scope, RemoteGroups, Groups),
    Nodes#{Peer => {MRef, maps:from_list(Groups)}}.

sync_groups(Scope, RemoteGroups, []) ->
    %% leave all missing groups
    [leave_remote(Scope, Pids, [Group]) || {Group, Pids} <- maps:to_list(RemoteGroups)];
sync_groups(Scope, RemoteGroups, [{Group, Pids} | Tail]) ->
    case maps:take(Group, RemoteGroups) of
        {Pids, NewRemoteGroups} ->
            sync_groups(Scope, NewRemoteGroups, Tail);
        {OldPids, NewRemoteGroups} ->
            [{Group, AllOldPids, LocalPids}] = ets:lookup(Scope, Group),
            %% should be really rare...
            AllNewPids = Pids ++ AllOldPids -- OldPids,
            true = ets:insert(Scope, {Group, AllNewPids, LocalPids}),
            sync_groups(Scope, NewRemoteGroups, Tail);
        error ->
            join_remote(Scope, Group, Pids),
            sync_groups(Scope, RemoteGroups, Tail)
    end.
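
%% A small worked example of the merge above (group and pid names are
%%  illustrative): if the peer previously reported #{g1 => [P1], g2 => [P2]}
%%  and the new 'sync' carries [{g1, [P1, P3]}], the first clause does not
%%  match (the lists differ), so g1's 'all' list becomes
%%  [P1, P3] ++ (AllOldPids -- [P1]); g2 is absent from the new list, so all
%%  of its remote pids are left via the base clause.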

join_monitors(Pid, Group, Monitors) when is_pid(Pid) ->
    case maps:find(Pid, Monitors) of
        {ok, {MRef, Groups}} ->
            maps:put(Pid, {MRef, [Group | Groups]}, Monitors);
        error ->
            MRef = erlang:monitor(process, Pid),
            Monitors#{Pid => {MRef, [Group]}}
    end;
join_monitors([], _Group, Monitors) ->
    Monitors;
join_monitors([Pid | Tail], Group, Monitors) ->
    join_monitors(Tail, Group, join_monitors(Pid, Group, Monitors)).

join_local_group(Scope, Group, Pid) when is_pid(Pid) ->
    case ets:lookup(Scope, Group) of
        [{Group, All, Local}] ->
            ets:insert(Scope, {Group, [Pid | All], [Pid | Local]});
        [] ->
            ets:insert(Scope, {Group, [Pid], [Pid]})
    end;
join_local_group(Scope, Group, Pids) ->
    case ets:lookup(Scope, Group) of
        [{Group, All, Local}] ->
            ets:insert(Scope, {Group, Pids ++ All, Pids ++ Local});
        [] ->
            ets:insert(Scope, {Group, Pids, Pids})
    end.

join_remote(Scope, Group, Pid) when is_pid(Pid) ->
    case ets:lookup(Scope, Group) of
        [{Group, All, Local}] ->
            ets:insert(Scope, {Group, [Pid | All], Local});
        [] ->
            ets:insert(Scope, {Group, [Pid], []})
    end;
join_remote(Scope, Group, Pids) ->
    case ets:lookup(Scope, Group) of
        [{Group, All, Local}] ->
            ets:insert(Scope, {Group, Pids ++ All, Local});
        [] ->
            ets:insert(Scope, {Group, Pids, []})
    end.

join_remote_map(Group, Pid, RemoteGroups) when is_pid(Pid) ->
    maps:update_with(Group, fun (List) -> [Pid | List] end, [Pid], RemoteGroups);
join_remote_map(Group, Pids, RemoteGroups) ->
    maps:update_with(Group, fun (List) -> Pids ++ List end, Pids, RemoteGroups).

leave_monitors(Pid, Group, Monitors) when is_pid(Pid) ->
    case maps:find(Pid, Monitors) of
        {ok, {MRef, [Group]}} ->
            erlang:demonitor(MRef),
            maps:remove(Pid, Monitors);
        {ok, {MRef, Groups}} ->
            case lists:member(Group, Groups) of
                true ->
                    maps:put(Pid, {MRef, lists:delete(Group, Groups)}, Monitors);
                false ->
                    Monitors
            end;
        _ ->
            Monitors
    end;
leave_monitors([], _Group, Monitors) ->
    Monitors;
leave_monitors([Pid | Tail], Group, Monitors) ->
    leave_monitors(Tail, Group, leave_monitors(Pid, Group, Monitors)).

leave_local_group(Scope, Group, Pid) when is_pid(Pid) ->
    case ets:lookup(Scope, Group) of
        [{Group, [Pid], [Pid]}] ->
            ets:delete(Scope, Group);
        [{Group, All, Local}] ->
            ets:insert(Scope, {Group, lists:delete(Pid, All), lists:delete(Pid, Local)});
        [] ->
            %% a rare race condition: the 'DOWN' from the monitor is still in
            %%  the message queue while the process is leaving
            true
    end;
leave_local_group(Scope, Group, Pids) ->
    case ets:lookup(Scope, Group) of
        [{Group, All, Local}] ->
            case All -- Pids of
                [] ->
                    ets:delete(Scope, Group);
                NewAll ->
                    ets:insert(Scope, {Group, NewAll, Local -- Pids})
            end;
        [] ->
            true
    end.

leave_remote(Scope, Pid, Groups) when is_pid(Pid) ->
    _ = [
        case ets:lookup(Scope, Group) of
            [{Group, [Pid], []}] ->
                ets:delete(Scope, Group);
            [{Group, All, Local}] ->
                ets:insert(Scope, {Group, lists:delete(Pid, All), Local});
            [] ->
                true
        end ||
        Group <- Groups];
leave_remote(Scope, Pids, Groups) ->
    _ = [
        case ets:lookup(Scope, Group) of
            [{Group, All, Local}] ->
                case All -- Pids of
                    [] when Local =:= [] ->
                        ets:delete(Scope, Group);
                    NewAll ->
                        ets:insert(Scope, {Group, NewAll, Local})
                end;
            [] ->
                true
        end ||
        Group <- Groups].

leave_update_remote_map(Pid, RemoteMap, Groups) when is_pid(Pid) ->
    leave_update_remote_map([Pid], RemoteMap, Groups);
leave_update_remote_map(Pids, RemoteMap, Groups) ->
    lists:foldl(
        fun (Group, Acc) ->
            case maps:get(Group, Acc) -- Pids of
                [] ->
                    maps:remove(Group, Acc);
                Remaining ->
                    Acc#{Group => Remaining}
            end
        end, RemoteMap, Groups).

all_local_pids(Monitors) ->
    maps:to_list(maps:fold(
        fun(Pid, {_Ref, Groups}, Acc) ->
            lists:foldl(
                fun(Group, Acc1) ->
                    Acc1#{Group => [Pid | maps:get(Group, Acc1, [])]}
                end,
                Acc,
                Groups
            )
        end,
        #{},
        Monitors
    )).
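
%% Shape example (an illustrative monitors map): folding over
%%   #{P1 => {R1, [g1, g2]}, P2 => {R2, [g1]}}
%%  yields [{g1, [P2, P1]}, {g2, [P1]}] up to ordering; map folds do not
%%  guarantee iteration order, and the receiving peer handles any ordering.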

%% Works like gen_server:abcast(), but accepts a list of processes
%%   instead of a list of nodes.
broadcast([], _Msg) ->
    ok;
broadcast([Dest | Tail], Msg) ->
    %% do not use 'nosuspend' here, as that would lead to missing
    %%   join/leave messages when the dist buffer is full
    erlang:send(Dest, Msg, [noconnect]),
    broadcast(Tail, Msg).