1%%% Copyright (C) 2009 - Will Glozer.  All rights reserved.
2%%% Copyright (C) 2011 - Anton Lebedevich.  All rights reserved.
3
4%%% @doc GenServer holding all connection state (including socket).
5%%%
6%%% See https://www.postgresql.org/docs/current/static/protocol-flow.html
%%% Commands in PostgreSQL are pipelined: you don't need to wait for a reply
%%% before sending the next command.
%%% Commands are processed (and responses to them are generated) in FIFO order.
%%% E.g., if you execute two SimpleQuery commands, #1 and #2, you first get all
%%% response packets for #1 and then all of those for #2:
12%%% > SQuery #1
13%%% > SQuery #2
14%%% < RowDescription #1
15%%% < DataRow #1
16%%% < CommandComplete #1
17%%% < RowDescription #2
18%%% < DataRow #2
19%%% < CommandComplete #2
20%%%
21%%% See epgsql_cmd_connect for network connection and authentication setup
22
23
24-module(epgsql_sock).
25
26-behavior(gen_server).
27
28-export([start_link/0,
29         close/1,
30         sync_command/3,
31         async_command/4,
32         get_parameter/2,
33         set_notice_receiver/2,
34         get_cmd_status/1,
35         cancel/1]).
36
37-export([handle_call/3, handle_cast/2, handle_info/2]).
38-export([init/1, code_change/3, terminate/2]).
39
%% Packet handler callbacks, invoked from loop/1 as ?MODULE:Handler/3
41-export([on_message/3, on_replication/3]).
42
%% API for command modules (operate on the connection state)
44-export([set_net_socket/3, init_replication_state/1, set_attr/3, get_codec/1,
45         get_rows/1, get_results/1, notify/2, send/2, send/3, send_multi/2,
46         get_parameter_internal/2,
47         get_replication_state/1, set_packet_handler/2]).
48
49-export_type([transport/0, pg_sock/0]).
50
51-include("epgsql.hrl").
52-include("protocol.hrl").
53-include("epgsql_replication.hrl").
54
%% How a command's outcome is delivered to its caller (see reply/3):
%% - `{call, From}': synchronous gen_server:call, answered via gen_server:reply/2;
%% - `{cast, Pid, Ref}': a single `{ConnPid, Ref, Result}' message when done;
%% - `{incremental, Pid, Ref}': rows and intermediate results are streamed to
%%   `Pid' as `{ConnPid, Ref, ...}' messages while the command runs.
-type transport() :: {call, any()}
                  | {cast, pid(), reference()}
                  | {incremental, pid(), reference()}.

-type tcp_socket() :: port(). %gen_tcp:socket() isn't exported prior to erl 18
-type repl_state() :: #repl{}.

%% Connection process state.
%% - mod / sock: transport module and socket; undefined until
%%   set_net_socket/3, and sock is cleared when the peer closes it;
%% - data: received bytes not yet decoded into a complete backend message;
%% - backend: {Pid, Key} used to build cancel requests (see handle_cast/2);
%% - handler: packet handler loop/1 dispatches to (on_message/on_replication);
%% - codec: type codec, set by commands via set_attr/3;
%% - queue: commands waiting behind the current one, as
%%   {Command, CmdState, Transport};
%% - current_cmd / current_cmd_state / current_cmd_transport: the command
%%   whose responses are currently being received;
%% - async: receiver of asynchronous notices and notifications;
%% - parameters: server parameters reported via ParameterStatus;
%% - rows / results: accumulated newest-first for the current command;
%% - sync_required: set after a protocol error; while true only
%%   epgsql_cmd_sync may execute (see command_exec/4);
%% - txstatus: status byte from the last ReadyForQuery;
%% - complete_status: decoded status of the last CommandComplete;
%% - repl: streaming replication state (replication mode only).
-record(state, {mod :: gen_tcp | ssl | undefined,
               sock :: tcp_socket() | ssl:sslsocket() | undefined,
               data = <<>>,
               backend :: {Pid :: integer(), Key :: integer()} | undefined,
               handler = on_message :: on_message | on_replication | undefined,
               codec :: epgsql_binary:codec() | undefined,
               queue = queue:new() :: queue:queue({epgsql_command:command(), any(), transport()}),
               current_cmd :: epgsql_command:command() | undefined,
               current_cmd_state :: any() | undefined,
               current_cmd_transport :: transport() | undefined,
               async :: undefined | atom() | pid(),
               parameters = [] :: [{Key :: binary(), Value :: binary()}],
               rows = [] :: [tuple()],
               results = [],
               sync_required :: boolean() | undefined,
               txstatus :: byte() | undefined,  % $I | $T | $E,
               complete_status :: atom() | {atom(), integer()} | undefined,
               repl :: repl_state() | undefined}).

%% Opaque to callers: command modules manipulate it only via the API above.
-opaque pg_sock() :: #state{}.
82
83%% -- client interface --
84
%% @doc Start a connection process linked to the caller.
%% It starts without a socket (see init/1); one is attached later through
%% set_net_socket/3.
start_link() ->
    InitArgs = [],
    gen_server:start_link(?MODULE, InitArgs, []).
87
%% @doc Ask connection process `C' to stop, without waiting for it.
%%
%% The `stop' cast makes the process fail every pending command with
%% `{error, closed}' and terminate normally. Always returns `ok'.
close(C) when is_pid(C) ->
    %% gen_server:cast/2 never raises (it swallows send failures itself and
    %% returns ok), so no exception handling is needed around it.
    ok = gen_server:cast(C, stop).
91
%% @doc Run `Command' with `Args' on connection `C' and wait, without a
%% timeout, for its result. The reply is produced when the command finishes.
-spec sync_command(epgsql:connection(), epgsql_command:command(), any()) -> any().
sync_command(C, Command, Args) ->
    Request = {command, Command, Args},
    gen_server:call(C, Request, infinity).
95
%% @doc Submit `Command' to connection `C' without waiting for its result.
%%
%% Returns a fresh reference. Results are later sent to the calling process
%% as `{ConnPid, Ref, Message}'. With `incremental' transport, rows and
%% intermediate results are streamed while the command runs.
-spec async_command(epgsql:connection(), cast | incremental,
                    epgsql_command:command(), any()) -> reference().
async_command(C, Transport, Command, Args) ->
    Tag = make_ref(),
    Caller = self(),
    Request = {{Transport, Caller, Tag}, Command, Args},
    ok = gen_server:cast(C, Request),
    Tag.
103
%% @doc Return `{ok, Value}' for the server parameter `Name' (a string or a
%% binary). `Value' is the value last reported by the backend through
%% ParameterStatus, or `undefined' if the backend has not reported it.
get_parameter(C, Name) ->
    Key = to_binary(Name),
    gen_server:call(C, {get_parameter, Key}, infinity).
106
%% @doc Install `PidOrName' (a pid or registered name) as the receiver of
%% asynchronous notices and notifications, which are delivered as
%% `{epgsql, ConnPid, Msg}'. Returns `{ok, Previous}'.
set_notice_receiver(C, PidOrName) when is_atom(PidOrName);
                                       is_pid(PidOrName) ->
    Request = {set_async_receiver, PidOrName},
    gen_server:call(C, Request, infinity).
110
%% @doc Return `{ok, Status}'. `Status' is the decoded CommandComplete of the
%% most recent statement, or `undefined' if none has completed since the
%% last command was submitted.
get_cmd_status(C) ->
    Timeout = infinity,
    gen_server:call(C, get_cmd_status, Timeout).
113
%% @doc Ask connection `S' to cancel its running query. This is
%% fire-and-forget: the connection process sends a cancel request over a
%% separate TCP connection. Always returns `ok'.
cancel(S) ->
    Request = cancel,
    gen_server:cast(S, Request).
116
117
118%% -- command APIs --
119
%% Socket write helpers used by commands (send/2, send/3, send_multi/2) are
%% defined with the internal functions further below.
122
%% @doc Attach an established socket (plain `gen_tcp' or `ssl') to the state
%% and switch it to active mode. Incoming data then arrives as messages
%% handled by handle_info/2.
%% NOTE(review): the result of setopts/2 is ignored; if it fails the socket
%% may not be active and no data would be received.
-spec set_net_socket(gen_tcp | ssl, tcp_socket() | ssl:sslsocket(), pg_sock()) -> pg_sock().
set_net_socket(Mod, Socket, State) ->
    Attached = State#state{mod = Mod, sock = Socket},
    setopts(Attached, [{active, true}]),
    Attached.
128
%% @doc Reset replication bookkeeping to a fresh `#repl{}' with default fields.
-spec init_replication_state(pg_sock()) -> pg_sock().
init_replication_state(State) ->
    Fresh = #repl{},
    State#state{repl = Fresh}.
132
%% @doc Set one field of the connection state on behalf of a command.
%% Supported attributes are `backend' (must be a 2-tuple), `async',
%% `txstatus', `codec', `sync_required' and `replication_state'. Any other
%% attribute raises function_clause.
-spec set_attr(atom(), any(), pg_sock()) -> pg_sock().
set_attr(codec, TypeCodec, State) ->
    State#state{codec = TypeCodec};
set_attr(async, Receiver, State) ->
    State#state{async = Receiver};
set_attr(backend, {_, _} = PidAndKey, State) ->
    State#state{backend = PidAndKey};
set_attr(replication_state, ReplState, State) ->
    State#state{repl = ReplState};
set_attr(sync_required, Flag, State) ->
    State#state{sync_required = Flag};
set_attr(txstatus, TxStatus, State) ->
    State#state{txstatus = TxStatus}.
146
%% @doc Replace the packet handler that loop/1 invokes for every decoded
%% backend message, as `?MODULE:Handler(Type, Payload, State)'.
%% XXX: be careful! The handler must be an exported arity-3 function of this
%% module (on_message or on_replication). Any other atom makes the next
%% loop/1 call fail with `undef'.
-spec set_packet_handler(atom(), pg_sock()) -> pg_sock().
set_packet_handler(Handler, State) ->
    NewState = State#state{handler = Handler},
    NewState.
151
%% @doc Return the type codec stored in the connection state.
-spec get_codec(pg_sock()) -> epgsql_binary:codec().
get_codec(#state{codec = TypeCodec}) ->
    TypeCodec.
155
%% @doc Return the replication bookkeeping record.
-spec get_replication_state(pg_sock()) -> repl_state().
get_replication_state(#state{repl = ReplState}) ->
    ReplState.
159
%% @doc Rows collected for the current command, in arrival order.
%% add_row/2 stores them newest-first.
-spec get_rows(pg_sock()) -> [tuple()].
get_rows(#state{rows = Collected}) ->
    lists:reverse(Collected, []).
163
%% @doc Results collected for the current command, in arrival order.
%% add_result/3 stores them newest-first.
-spec get_results(pg_sock()) -> [any()].
get_results(#state{results = Collected}) ->
    lists:reverse(Collected, []).
167
%% @doc Look up server parameter `Name' among those the backend reported
%% through ParameterStatus messages.
%% Returns the value binary, or `undefined' if the parameter was not reported.
-spec get_parameter_internal(binary(), pg_sock()) -> binary() | undefined.
get_parameter_internal(Name, #state{parameters = Parameters}) ->
    %% lists:keyfind/3 is the idiomatic replacement for the legacy
    %% lists:keysearch/3, which exists only for backward compatibility.
    case lists:keyfind(Name, 1, Parameters) of
        {Name, Value} -> Value;
        false         -> undefined
    end.
174
175
176%% -- gen_server implementation --
177
%% @doc gen_server callback. Starts with an empty state: no socket and no
%% commands. The socket is attached later by set_net_socket/3.
init([]) ->
    Empty = #state{},
    {ok, Empty}.
180
%% @doc gen_server callback for synchronous requests:
%% - `{command, Command, Args}': start a command. There is no immediate
%%   reply; the caller is answered with gen_server:reply/2 when the command
%%   finishes;
%% - `{get_parameter, Name}': `{ok, Value | undefined}';
%% - `{set_async_receiver, PidOrName}': install a new notice receiver and
%%   return `{ok, Previous}';
%% - `get_cmd_status': `{ok, Status}' of the last CommandComplete;
%% - `{standby_status_update, FlushedLSN, AppliedLSN}': replication mode
%%   only. Sends a standby status update and records the new LSNs.
handle_call({command, Command, Args}, From, State) ->
    command_new({call, From}, Command, Args, State);
handle_call({get_parameter, Name}, _From, State) ->
    Value = get_parameter_internal(Name, State),
    {reply, {ok, Value}, State};
handle_call({set_async_receiver, NewReceiver}, _From,
            #state{async = OldReceiver} = State) ->
    {reply, {ok, OldReceiver}, State#state{async = NewReceiver}};
handle_call(get_cmd_status, _From, #state{complete_status = Status} = State) ->
    {reply, {ok, Status}, State};
handle_call({standby_status_update, FlushedLSN, AppliedLSN}, _From,
            #state{handler = on_replication,
                   repl = #repl{last_received_lsn = ReceivedLSN} = Repl} = State) ->
    Update = epgsql_wire:encode_standby_status_update(ReceivedLSN, FlushedLSN, AppliedLSN),
    send(State, ?COPY_DATA, Update),
    NewRepl = Repl#repl{last_flushed_lsn = FlushedLSN,
                        last_applied_lsn = AppliedLSN},
    {reply, ok, State#state{repl = NewRepl}}.
200
%% @doc gen_server callback for asynchronous requests:
%% - `{{cast | incremental, Pid, Ref}, Command, Args}': start a command whose
%%   results go to `Pid', tagged with `Ref' (see async_command/4);
%% - `stop': fail every pending command with `{error, closed}' and stop
%%   normally;
%% - `cancel': ask the server to cancel the running query. A cancel request
%%   is sent over a separate plain TCP connection to the same address. Any
%%   failure on the way crashes this process (assertive matches).
handle_cast({{Method, From, Ref} = Transport, Command, Args}, State)
  when (Method =:= cast orelse Method =:= incremental),
       is_pid(From),
       is_reference(Ref) ->
    command_new(Transport, Command, Args, State);
handle_cast(stop, State) ->
    {stop, normal, flush_queue(State, {error, closed})};
handle_cast(cancel, State = #state{backend = {Pid, Key},
                                   sock = TimedOutSock}) ->
    {ok, {Addr, Port}} = case State#state.mod of
                             gen_tcp -> inet:peername(TimedOutSock);
                             ssl -> ssl:peername(TimedOutSock)
                         end,
    SockOpts = [{active, false}, {packet, raw}, binary],
    %% TODO timeout: gen_tcp:connect/3 has no explicit timeout, and this
    %% process (with every queued command) is blocked while it waits.
    {ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts),
    %% CancelRequest packet:
    %% - length 16;
    %% - cancel request code 80877102 (1234 in the high 16 bits, 5678 in the
    %%   low 16 bits);
    %% - the backend process ID and secret key stored in `backend'.
    Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
    try
        ok = gen_tcp:send(Sock, Msg)
    after
        %% Close the side connection even when the send fails.
        gen_tcp:close(Sock)
    end,
    {noreply, State}.
223
%% gen_server callback for socket and port messages.
%% Clause order matters: the final `{_, Sock, Data2}' clause matches any
%% 3-tuple carrying our socket. It would also swallow tcp_error/ssl_error
%% and inet_reply messages if those clauses came later.

%% The peer closed the socket: fail pending commands and stop. `sock' is
%% cleared first so terminate/2 does not try to close it again.
handle_info({Closed, Sock}, #state{sock = Sock} = State)
  when Closed == tcp_closed; Closed == ssl_closed ->
    {stop, sock_closed, flush_queue(State#state{sock = undefined}, {error, sock_closed})};

%% Socket error: fail pending commands and stop; terminate/2 closes the socket.
handle_info({Error, Sock, Reason}, #state{sock = Sock} = State)
  when Error == tcp_error; Error == ssl_error ->
    Why = {sock_error, Reason},
    {stop, Why, flush_queue(State, {error, Why})};

%% Completion notice for a write made with erlang:port_command/2 in do_send/3.
handle_info({inet_reply, _, ok}, State) ->
    {noreply, State};

%% A failed port write: fail pending commands and stop with that status.
handle_info({inet_reply, _, Status}, State) ->
    {stop, Status, flush_queue(State, {error, Status})};

%% Incoming data on our socket (e.g. `{tcp, Sock, Bin}' or `{ssl, Sock, Bin}').
%% Append it to the receive buffer and decode as many complete messages as
%% possible.
%% NOTE(review): there is no catch-all clause, so any other message
%% (including data from a different socket) crashes with function_clause.
handle_info({_, Sock, Data2}, #state{data = Data, sock = Sock} = State) ->
    loop(State#state{data = <<Data/binary, Data2/binary>>}).
241
%% @doc gen_server callback. Closes the socket, if one is still attached,
%% through the transport module it was opened with.
terminate(_Why, #state{sock = undefined}) ->
    ok;
terminate(_Why, #state{mod = gen_tcp, sock = Socket}) ->
    gen_tcp:close(Socket);
terminate(_Why, #state{mod = ssl, sock = Socket}) ->
    ssl:close(Socket).
245
%% @doc gen_server callback. Hot code upgrades keep the state as it is.
code_change(_FromVsn, CurrentState, _Extra) ->
    {ok, CurrentState}.
248
249%% -- internal functions --
250
%% Build the command's private state from `Args' and execute it.
-spec command_new(transport(), epgsql_command:command(), any(), pg_sock()) ->
                         Result when
      Result :: {noreply, pg_sock()}
              | {stop, Reason :: any(), pg_sock()}.
command_new(Transport, Command, Args, State) ->
    command_exec(Transport, Command, epgsql_command:init(Command, Args), State).
258
%% Execute a command (send its packets), then make it current or queue it.
%%
%% While `sync_required' is set, only epgsql_cmd_sync may run. Any other
%% command is answered with `{error, sync_required}' immediately.
%% If another command (for example a pending Sync) is already in flight, the
%% rejection must leave the current-command slot alone. Overwriting that slot
%% would drop the in-flight command's transport, and its caller would never
%% get a reply.
-spec command_exec(transport(), epgsql_command:command(), any(), pg_sock()) ->
                          Result when
      Result :: {noreply, pg_sock()}
              | {stop, Reason :: any(), pg_sock()}.
command_exec(Transport, Command, _, State = #state{sync_required = true,
                                                   current_cmd = undefined})
  when Command /= epgsql_cmd_sync ->
    %% Nothing in flight: route the rejection through finish/2 so the command
    %% slot and per-command buffers are reset as usual.
    {noreply,
     finish(State#state{current_cmd = Command,
                        current_cmd_transport = Transport},
            {error, sync_required})};
command_exec(Transport, Command, _, State = #state{sync_required = true})
  when Command /= epgsql_cmd_sync ->
    %% A command is in flight: reject only the newcomer, keep the state intact.
    reply(Transport, {error, sync_required}, {error, sync_required}),
    {noreply, State};
command_exec(Transport, Command, CmdState, State) ->
    case epgsql_command:execute(Command, State, CmdState) of
        {ok, State1, CmdState1} ->
            {noreply, command_enqueue(Transport, Command, CmdState1, State1)};
        {stop, StopReason, Response, State1} ->
            reply(Transport, Response, Response),
            {stop, StopReason, State1}
    end.
277
%% Make an executed command current if nothing is running, else append it to
%% the queue. Either way the previous command's completion status is cleared.
-spec command_enqueue(transport(), epgsql_command:command(), epgsql_command:state(), pg_sock()) -> pg_sock().
command_enqueue(Transport, Command, CmdState,
                #state{current_cmd = Running, queue = Pending} = State) ->
    Reset = State#state{complete_status = undefined},
    case Running of
        undefined ->
            Reset#state{current_cmd = Command,
                        current_cmd_state = CmdState,
                        current_cmd_transport = Transport};
        _ ->
            Reset#state{queue = queue:in({Command, CmdState, Transport}, Pending)}
    end.
287
%% Pass one backend message to the current command's handle_message and
%% apply the action it returns:
%% - add_row / add_result: accumulate (or stream, for incremental transports)
%%   and keep the command current;
%% - finish: reply to the caller and advance to the next queued command;
%% - noaction: only update the connection (and optionally command) state;
%% - requeue: execute the same command again with new command state, keeping
%%   its transport;
%% - stop: reply to the caller, then stop the connection process;
%% - sync_required: protocol error. Mark the connection as needing a Sync and
%%   fail the pending commands;
%% - unknown: the command did not expect this message; stop with an error.
-spec command_handle_message(byte(), binary() | epgsql:query_error(), pg_sock()) ->
                                   {noreply, pg_sock()}
                                 | {stop, any(), pg_sock()}.
command_handle_message(Msg, Payload,
                      #state{current_cmd = Command,
                             current_cmd_state = CmdState} = State) ->
    case epgsql_command:handle_message(Command, Msg, Payload, State, CmdState) of
        {add_row, Row, State1, CmdState1} ->
            {noreply, add_row(State1#state{current_cmd_state = CmdState1}, Row)};
        {add_result, Result, Notice, State1, CmdState1} ->
            {noreply,
             add_result(State1#state{current_cmd_state = CmdState1},
                        Notice, Result)};
        {finish, Result, Notice, State1} ->
            {noreply, finish(State1, Notice, Result)};
        {noaction, State1} ->
            {noreply, State1};
        {noaction, State1, CmdState1} ->
            {noreply, State1#state{current_cmd_state = CmdState1}};
        {requeue, State1, CmdState1} ->
            %% Clearing current_cmd lets command_enqueue/4 reinstall this
            %% command as the current one after it is executed again.
            Transport = State1#state.current_cmd_transport,
            command_exec(Transport, Command, CmdState1,
                         State1#state{current_cmd = undefined});
        {stop, Reason, Response, State1} ->
            {stop, Reason, finish(State1, Response)};
        {sync_required, Why} ->
            %% Protocol error. Finish and flush all pending commands.
            %% Note: this continues from the state before the dispatch.
            {noreply, sync_required(finish(State#state{sync_required = true}, Why))};
        unknown ->
            {stop, {error, {unexpected_message, Msg, Command, CmdState}}, State}
    end.
319
%% Retire the current command and make the next queued one current, or go
%% idle when the queue is empty. The per-command rows/results buffers are
%% reset. Only valid while a command is current.
command_next(#state{current_cmd = PrevCmd, queue = Pending} = State)
  when PrevCmd =/= undefined ->
    Cleared = State#state{rows = [], results = []},
    case queue:out(Pending) of
        {{value, {Command, CmdState, Transport}}, Rest} ->
            Cleared#state{current_cmd = Command,
                          current_cmd_state = CmdState,
                          current_cmd_transport = Transport,
                          queue = Rest};
        {empty, _} ->
            Cleared#state{current_cmd = undefined,
                          current_cmd_state = undefined,
                          current_cmd_transport = undefined}
    end.
337
338
%% Apply socket options through the transport module the socket belongs to.
setopts(#state{mod = Transport, sock = Socket}, Options) ->
    case Transport of
        ssl     -> ssl:setopts(Socket, Options);
        gen_tcp -> inet:setopts(Socket, Options)
    end.
344
%% Encode `Data' with epgsql_wire:encode_command/1 (no message type byte)
%% and write it to the socket. Only used during connection initiation, to
%% send the client's `StartupMessage' and `SSLRequest' packets.
-spec send(pg_sock(), iodata()) -> ok | {error, any()}.
send(#state{mod = Mod, sock = Sock}, Data) ->
    Packet = epgsql_wire:encode_command(Data),
    do_send(Mod, Sock, Packet).
350
%% Encode `Data' as a frontend message of type `Type' and write it to the socket.
-spec send(pg_sock(), byte(), iodata()) -> ok | {error, any()}.
send(#state{mod = Mod, sock = Sock}, Type, Data) ->
    Packet = epgsql_wire:encode_command(Type, Data),
    do_send(Mod, Sock, Packet).
354
%% Encode several `{Type, Data}' messages and write them with one do_send/3
%% call, passing the encoded packets as a single iolist.
-spec send_multi(pg_sock(), [{byte(), iodata()}]) -> ok | {error, any()}.
send_multi(#state{mod = Mod, sock = Sock}, List) ->
    Encode = fun({Type, Data}) -> epgsql_wire:encode_command(Type, Data) end,
    Packets = lists:map(Encode, List),
    do_send(Mod, Sock, Packets).
360
%% Write already-encoded iodata to the socket.
%%
%% For gen_tcp this deliberately bypasses gen_tcp:send/2 and calls
%% erlang:port_command/2 directly; see
%% https://github.com/rabbitmq/rabbitmq-common/blob/v3.7.4/src/rabbit_writer.erl#L367-L384
%% The port later reports the write outcome as an `inet_reply' message,
%% which handle_info/2 handles. A port_command error becomes `{error, einval}'.
do_send(ssl, Sock, Bin) ->
    ssl:send(Sock, Bin);
do_send(gen_tcp, Sock, Bin) ->
    try erlang:port_command(Sock, Bin) of
        true -> ok
    catch
        error:_ -> {error, einval}
    end.
374
%% Decode and dispatch every complete backend message in the receive buffer.
%% Each message goes to the current packet handler (`?MODULE:Handler/3'),
%% and processing stops early if a handler returns `{stop, ...}'. Once no
%% complete message is left, replication mode sends any pending feedback.
loop(#state{data = Buffer, handler = Handler} = State) ->
    case epgsql_wire:decode_message(Buffer) of
        {Type, Payload, Rest} ->
            case ?MODULE:Handler(Type, Payload, State#state{data = Rest}) of
                {noreply, NextState} -> loop(NextState);
                {stop, _, _} = Stop  -> Stop
            end;
        _Incomplete ->
            maybe_send_feedback(State)
    end.

%% In replication mode, after a batch of messages: if feedback is pending,
%% send a standby status update with the current LSNs and clear the flag.
%% Otherwise leave the state untouched.
maybe_send_feedback(#state{repl = #repl{feedback_required = true,
                                        last_received_lsn = ReceivedLSN,
                                        last_flushed_lsn = FlushedLSN,
                                        last_applied_lsn = AppliedLSN} = Repl} = State) ->
    send(State, ?COPY_DATA,
         epgsql_wire:encode_standby_status_update(ReceivedLSN, FlushedLSN, AppliedLSN)),
    {noreply, State#state{repl = Repl#repl{feedback_required = false}}};
maybe_send_feedback(State) ->
    {noreply, State}.
398
%% Complete the current command: reply to its caller, then advance to the
%% next queued command. finish/2 uses one term as both notice and result.
finish(State, Outcome) ->
    finish(State, Outcome, Outcome).

finish(State = #state{current_cmd_transport = ReplyTo}, Notice, Result) ->
    reply(ReplyTo, Notice, Result),
    command_next(State).
405
%% Deliver a command outcome according to its transport:
%% - `call': gen_server:reply/2 with `Result';
%% - `cast': send `{ConnPid, Ref, Result}' to the caller;
%% - `incremental': send `{ConnPid, Ref, Notice}' to the caller.
reply({call, From}, _Notice, Result) ->
    gen_server:reply(From, Result);
reply({cast, Pid, Ref}, _Notice, Result) ->
    Pid ! {self(), Ref, Result};
reply({incremental, Pid, Ref}, Notice, _Result) ->
    Pid ! {self(), Ref, Notice}.
412
%% Record an intermediate result of the current command. Incremental
%% transports get `Notice' streamed to the caller right away; other
%% transports accumulate `Result' (newest first). The row buffer is cleared
%% in both cases.
add_result(#state{current_cmd_transport = {incremental, Pid, Ref}} = State, Notice, _Result) ->
    Pid ! {self(), Ref, Notice},
    State#state{rows = []};
add_result(#state{results = Acc} = State, _Notice, Result) ->
    State#state{rows = [], results = [Result | Acc]}.
423
%% Record a data row of the current command. Incremental transports stream
%% it to the caller as `{data, Row}'; others accumulate it newest-first.
add_row(#state{current_cmd_transport = {incremental, Pid, Ref}} = State, Data) ->
    Pid ! {self(), Ref, {data, Data}},
    State;
add_row(#state{rows = Acc} = State, Data) ->
    State#state{rows = [Data | Acc]}.
433
%% Stream `Notice' to the current command's caller, but only when that
%% command uses an incremental transport. Otherwise this does nothing.
%% Returns the state unchanged.
notify(State, Notice) ->
    case State of
        #state{current_cmd_transport = {incremental, Pid, Ref}} ->
            Pid ! {self(), Ref, Notice};
        _ ->
            ok
    end,
    State.
439
%% Forward an asynchronous server message (notice / notification) to the
%% configured receiver as `{epgsql, ConnPid, Msg}'.
%% Returns `true' when sent. Returns `false' when no receiver is configured,
%% or when the configured name is not currently registered.
notify_async(#state{async = undefined}, _Msg) ->
    false;
notify_async(#state{async = Receiver}, Msg) ->
    try
        Receiver ! {epgsql, self(), Msg},
        true
    catch
        error:badarg ->
            %% no process registered under this name
            false
    end.
450
%% Fail pending commands with `{error, sync_required}' one by one, stopping
%% at one of two points:
%% - an epgsql_cmd_sync becomes current: it is left to run;
%% - the queue is exhausted: the connection is flagged as requiring a Sync.
sync_required(#state{current_cmd = Current} = State) ->
    case Current of
        epgsql_cmd_sync -> State;
        undefined       -> State#state{sync_required = true};
        _               -> sync_required(finish(State, {error, sync_required}))
    end.
457
%% Fail the current command and every queued command with `Error'.
flush_queue(State, Error) ->
    case State of
        #state{current_cmd = undefined} -> State;
        _ -> flush_queue(finish(State, Error), Error)
    end.
462
%% Normalize a name given as a string (list) or a binary to a binary.
to_binary(Name) when is_list(Name)   -> list_to_binary(Name);
to_binary(Name) when is_binary(Name) -> Name.
465
466
467%% -- backend message handling --
468
%% Default packet handler (invoked from loop/1).
%% Two kinds of messages are handled here directly:
%% - messages that update connection-level state (command status,
%%   transaction status, parameters);
%% - purely asynchronous messages (notices, notifications).
%% Everything else goes to the current command via command_handle_message/3.
%% Clause order matters: the last clause is a catch-all.

%% CommandComplete: remember the decoded status (see get_cmd_status/1), then
%% let the current command handle it as well.
on_message(?COMMAND_COMPLETE = Msg, Bin, State) ->
    Complete = epgsql_wire:decode_complete(Bin),
    command_handle_message(Msg, Bin, State#state{complete_status = Complete});

%% ReadyForQuery: record the one-byte transaction status, then pass it on.
on_message(?READY_FOR_QUERY = Msg, <<Status:8>> = Bin, State) ->
    command_handle_message(Msg, Bin, State#state{txstatus = Status});

%% Error: with no command in flight the error is unsolicited, so the process
%% stops with `{shutdown, Reason}'. Otherwise the decoded error goes to the
%% current command.
on_message(?ERROR = Msg, Err, #state{current_cmd = CurrentCmd} = State) ->
    Reason = epgsql_wire:decode_error(Err),
    case CurrentCmd of
        undefined ->
            %% Message generated by server asynchronously
            {stop, {shutdown, Reason}, State};
        _ ->
            command_handle_message(Msg, Reason, State)
    end;

%% NoticeResponse: forwarded to the async receiver (if any) as
%% `{notice, Decoded}'; the current command is not involved.
on_message(?NOTICE, Data, State) ->
    notify_async(State, {notice, epgsql_wire:decode_error(Data)}),
    {noreply, State};

%% ParameterStatus: insert or replace the parameter (see get_parameter/2).
on_message(?PARAMETER_STATUS, Data, State) ->
    [Name, Value] = epgsql_wire:decode_strings(Data),
    Parameters2 = lists:keystore(Name, 1, State#state.parameters,
                                 {Name, Value}),
    {noreply, State#state{parameters = Parameters2}};

%% NotificationResponse: forwarded to the async receiver as
%% `{notification, Channel, Pid, Payload}'; a missing payload becomes <<>>.
on_message(?NOTIFICATION, <<Pid:?int32, Strings/binary>>, State) ->
    {Channel1, Payload1} = case epgsql_wire:decode_strings(Strings) of
        [Channel, Payload] -> {Channel, Payload};
        [Channel]          -> {Channel, <<>>}
    end,
    notify_async(State, {notification, Channel1, Pid, Payload1}),
    {noreply, State};

%% All remaining messages belong to the current command, e.g.:
%% ParseComplete
%% ParameterDescription
%% RowDescription
%% NoData
%% BindComplete
%% CloseComplete
%% DataRow
%% PortalSuspended
%% EmptyQueryResponse
%% CopyData
%% CopyBothResponse
on_message(Msg, Payload, State) ->
    command_handle_message(Msg, Payload, State).
523
524
%% Packet handler used in streaming replication mode (invoked from loop/1).

%% CopyData carrying a primary keepalive: record the server's LSN.
%% - ReplyRequired = 1: send a standby status update immediately.
%% - Otherwise: mark feedback as pending, so loop/1 sends one after the
%%   current batch of messages.
on_replication(?COPY_DATA, <<?PRIMARY_KEEPALIVE_MESSAGE:8, LSN:?int64, _Timestamp:?int64, ReplyRequired:8>>,
               #state{repl = #repl{last_flushed_lsn = LastFlushedLSN,
                                   last_applied_lsn = LastAppliedLSN} = Repl} = State) ->
    Repl1 =
        case ReplyRequired of
            1 ->
                send(State, ?COPY_DATA,
                     epgsql_wire:encode_standby_status_update(LSN, LastFlushedLSN, LastAppliedLSN)),
                Repl#repl{feedback_required = false,
                          last_received_lsn = LSN};
            _ ->
                Repl#repl{feedback_required = true,
                          last_received_lsn = LSN}
        end,
    {noreply, State#state{repl = Repl1}};

%% CopyData carrying XLogData: deliver the WAL record via handle_xlog_data/4.
on_replication(?COPY_DATA, <<?X_LOG_DATA, StartLSN:?int64, EndLSN:?int64,
                            _Timestamp:?int64, WALRecord/binary>>,
               #state{repl = Repl} = State) ->
    Repl1 = handle_xlog_data(StartLSN, EndLSN, WALRecord, Repl),
    {noreply, State#state{repl = Repl1}};
%% ErrorResponse: stop the connection process with `{error, Reason}'.
%% NOTE(review): unlike the socket-error paths, pending commands are not
%% explicitly replied to here.
on_replication(?ERROR, Err, State) ->
    Reason = epgsql_wire:decode_error(Err),
    {stop, {error, Reason}, State};
%% Notices, notifications and parameter changes are handled exactly as in
%% normal mode. The third argument is the connection state despite its name.
%% Any other message type raises function_clause.
on_replication(M, Data, Sock) when M == ?NOTICE;
                                  M == ?NOTIFICATION;
                                  M == ?PARAMETER_STATUS ->
    on_message(M, Data, Sock).
555
556
%% Deliver one XLogData WAL record. The #repl record selects the mode:
%% - no callback module: send `{epgsql, ConnPid, {x_log_data, StartLSN,
%%   EndLSN, WALRecord}}' to the receiver process;
%% - callback module and no receiver: call epgsql:handle_x_log_data/5, which
%%   returns the new flushed/applied LSNs and callback state.
%% In both modes feedback becomes pending, and EndLSN becomes the last
%% received LSN.
handle_xlog_data(StartLSN, EndLSN, WALRecord,
                 #repl{cbmodule = undefined, receiver = Receiver} = Repl) ->
    %% with async messages
    Receiver ! {epgsql, self(), {x_log_data, StartLSN, EndLSN, WALRecord}},
    Repl#repl{feedback_required = true, last_received_lsn = EndLSN};
handle_xlog_data(StartLSN, EndLSN, WALRecord,
                 #repl{cbmodule = Callback, cbstate = CallbackState,
                       receiver = undefined} = Repl) ->
    %% with callback method
    {ok, Flushed, Applied, NextCallbackState} =
        epgsql:handle_x_log_data(Callback, StartLSN, EndLSN, WALRecord, CallbackState),
    Repl#repl{feedback_required = true,
              last_received_lsn = EndLSN,
              last_flushed_lsn = Flushed,
              last_applied_lsn = Applied,
              cbstate = NextCallbackState}.
573