%%% Copyright (C) 2009 - Will Glozer.  All rights reserved.
%%% Copyright (C) 2011 - Anton Lebedevich.  All rights reserved.

%%% @doc GenServer holding all connection state (including socket).
%%%
%%% See https://www.postgresql.org/docs/current/static/protocol-flow.html
%%% Commands in PostgreSQL are pipelined: you don't need to wait for a reply
%%% to be able to send the next command.
%%% Commands are processed (and responses to them are generated) in FIFO order.
%%% E.g., if you execute 2 SimpleQuery commands, #1 and #2, you first get all
%%% response packets for #1 and then all for #2:
%%% > SQuery #1
%%% > SQuery #2
%%% < RowDescription #1
%%% < DataRow #1
%%% < CommandComplete #1
%%% < RowDescription #2
%%% < DataRow #2
%%% < CommandComplete #2
%%%
%%% See epgsql_cmd_connect for network connection and authentication setup


-module(epgsql_sock).

-behavior(gen_server).

%% Client interface
-export([start_link/0,
         close/1,
         sync_command/3,
         async_command/4,
         get_parameter/2,
         set_notice_receiver/2,
         get_cmd_status/1,
         cancel/1]).

%% gen_server callbacks
-export([handle_call/3, handle_cast/2, handle_info/2]).
-export([init/1, code_change/3, terminate/2]).

%% Packet handlers; loop/1 invokes them dynamically as ?MODULE:Handler(...)
-export([on_message/3, on_replication/3]).

%% APIs used by command modules
-export([set_net_socket/3, init_replication_state/1, set_attr/3, get_codec/1,
         get_rows/1, get_results/1, notify/2, send/2, send/3, send_multi/2,
         get_parameter_internal/2,
         get_replication_state/1, set_packet_handler/2]).

-export_type([transport/0, pg_sock/0]).

-include("epgsql.hrl").
-include("protocol.hrl").
-include("epgsql_replication.hrl").

%% How the outcome of a command is delivered to whoever issued it:
%%  {call, From}            - issued via gen_server:call/3, answered with
%%                            gen_server:reply/2
%%  {cast, Pid, Ref}        - final result is sent to Pid as {self(), Ref, Result}
%%  {incremental, Pid, Ref} - rows, partial results and notices are sent to
%%                            Pid one by one as {self(), Ref, Item}
-type transport() :: {call, any()}
                   | {cast, pid(), reference()}
                   | {incremental, pid(), reference()}.

-type tcp_socket() :: port(). %gen_tcp:socket() isn't exported prior to erl 18
-type repl_state() :: #repl{}.
%% State of one connection process.
-record(state, {%% module owning `sock': selects gen_tcp/inet or ssl calls
                mod :: gen_tcp | ssl | undefined,
                sock :: tcp_socket() | ssl:sslsocket() | undefined,
                %% bytes received from the socket but not yet decoded
                data = <<>>,
                %% backend process id and secret key, used by cancel/1
                backend :: {Pid :: integer(), Key :: integer()} | undefined,
                %% name of the function in this module that receives each decoded packet
                handler = on_message :: on_message | on_replication | undefined,
                codec :: epgsql_binary:codec() | undefined,
                %% commands accepted while `current_cmd' is still in progress
                queue = queue:new() :: queue:queue({epgsql_command:command(), any(), transport()}),
                current_cmd :: epgsql_command:command() | undefined,
                current_cmd_state :: any() | undefined,
                current_cmd_transport :: transport() | undefined,
                %% receiver of asynchronous notices / notifications
                async :: undefined | atom() | pid(),
                %% values reported by the server in ParameterStatus packets
                parameters = [] :: [{Key :: binary(), Value :: binary()}],
                %% accumulated in reverse order; see get_rows/1 and get_results/1
                rows = [] :: [tuple()],
                results = [],
                %% when true, only epgsql_cmd_sync is accepted
                sync_required :: boolean() | undefined,
                txstatus :: byte() | undefined,  % $I | $T | $E,
                %% decoded payload of the last CommandComplete packet
                complete_status :: atom() | {atom(), integer()} | undefined,
                repl :: repl_state() | undefined}).

-opaque pg_sock() :: #state{}.

%% -- client interface --

%% Starts an unregistered connection process; the socket is attached later.
-spec start_link() -> {ok, pid()} | ignore | {error, term()}.
start_link() ->
    gen_server:start_link(?MODULE, [], []).

%% Asks the connection process to stop. Always returns `ok'.
%% gen_server:cast/2 never raises for a pid destination (it returns `ok' even
%% when the process is already gone), so no `catch' is needed around it.
-spec close(pid()) -> ok.
close(C) when is_pid(C) ->
    gen_server:cast(C, stop).

%% Runs Command and blocks the caller, without a timeout, until it is answered.
-spec sync_command(epgsql:connection(), epgsql_command:command(), any()) -> any().
sync_command(C, Command, Args) ->
    gen_server:call(C, {command, Command, Args}, infinity).

%% Submits Command without waiting for it. Replies are sent to the calling
%% process as `{ConnectionPid, Ref, Reply}', where Ref is the returned reference.
-spec async_command(epgsql:connection(), cast | incremental,
                    epgsql_command:command(), any()) -> reference().
async_command(C, Transport, Command, Args) ->
    Ref = make_ref(),
    Pid = self(),
    ok = gen_server:cast(C, {{Transport, Pid, Ref}, Command, Args}),
    Ref.

%% Looks up a server parameter by name; the value is `undefined' when the
%% connection has no parameter stored under that name.
-spec get_parameter(epgsql:connection(), binary() | list()) ->
                           {ok, binary() | undefined}.
get_parameter(C, Name) ->
    gen_server:call(C, {get_parameter, to_binary(Name)}, infinity).

%% Replaces the receiver of asynchronous messages; returns the previous one.
-spec set_notice_receiver(epgsql:connection(), pid() | atom()) ->
                                 {ok, Previous :: pid() | atom()}.
set_notice_receiver(C, PidOrName) when is_pid(PidOrName);
                                       is_atom(PidOrName) ->
    gen_server:call(C, {set_async_receiver, PidOrName}, infinity).

%% Returns the stored `complete_status' of the connection.
-spec get_cmd_status(epgsql:connection()) ->
                            {ok, atom() | {atom(), integer()}}.
get_cmd_status(C) ->
    gen_server:call(C, get_cmd_status, infinity).

%% Asynchronously requests cancellation of whatever the backend is executing.
-spec cancel(epgsql:connection()) -> ok.
cancel(S) ->
    gen_server:cast(S, cancel).
%% -- command APIs --

%% Accessors and mutators of the opaque pg_sock() state for use by command
%% modules. The packet-sending part of this API (send/2, send/3, send_multi/2)
%% is defined further down in this module.

%% Stores the transport module and socket and switches the socket to active
%% mode, so incoming data is delivered to this process as messages.
%% The result of setopts/2 is not checked.
-spec set_net_socket(gen_tcp | ssl, tcp_socket() | ssl:sslsocket(), pg_sock()) -> pg_sock().
set_net_socket(Mod, Socket, State) ->
    State1 = State#state{mod = Mod, sock = Socket},
    setopts(State1, [{active, true}]),
    State1.

%% Installs a fresh replication state record.
-spec init_replication_state(pg_sock()) -> pg_sock().
init_replication_state(State) ->
    State#state{repl = #repl{}}.

%% Sets one attribute of the state. Only the attribute names listed below are
%% supported; any other name fails with `function_clause'.
-spec set_attr(atom(), any(), pg_sock()) -> pg_sock().
set_attr(backend, {_Pid, _Key} = Backend, State) ->
    State#state{backend = Backend};
set_attr(async, Async, State) ->
    State#state{async = Async};
set_attr(txstatus, Status, State) ->
    State#state{txstatus = Status};
set_attr(codec, Codec, State) ->
    State#state{codec = Codec};
set_attr(sync_required, Value, State) ->
    State#state{sync_required = Value};
set_attr(replication_state, Value, State) ->
    State#state{repl = Value}.

%% XXX: be careful!
%% Handler must name a function exported from this module, because loop/1
%% calls it as ?MODULE:Handler(Type, Payload, State).
-spec set_packet_handler(atom(), pg_sock()) -> pg_sock().
set_packet_handler(Handler, State) ->
    State#state{handler = Handler}.

-spec get_codec(pg_sock()) -> epgsql_binary:codec().
get_codec(#state{codec = Codec}) ->
    Codec.

-spec get_replication_state(pg_sock()) -> repl_state().
get_replication_state(#state{repl = Repl}) ->
    Repl.

%% Rows are accumulated newest-first, so reverse them into arrival order.
-spec get_rows(pg_sock()) -> [tuple()].
get_rows(#state{rows = Rows}) ->
    lists:reverse(Rows).

%% Results are accumulated newest-first, so reverse them into arrival order.
-spec get_results(pg_sock()) -> [any()].
get_results(#state{results = Results}) ->
    lists:reverse(Results).

%% Returns the stored value of parameter Name, or `undefined' if there is none.
-spec get_parameter_internal(binary(), pg_sock()) -> binary() | undefined.
get_parameter_internal(Name, #state{parameters = Parameters}) ->
    case lists:keyfind(Name, 1, Parameters) of
        {Name, Value} -> Value;
        false         -> undefined
    end.


%% -- gen_server implementation --

init([]) ->
    {ok, #state{}}.
%% Upper bound, in milliseconds, on how long a cancel request may block this
%% process while it opens its short-lived secondary connection.
-define(CANCEL_CONNECT_TIMEOUT, 5000).

%% Parameter lookup; see get_parameter/2.
handle_call({get_parameter, Name}, _From, State) ->
    {reply, {ok, get_parameter_internal(Name, State)}, State};

%% Swap the receiver of asynchronous messages, replying with the previous one.
handle_call({set_async_receiver, PidOrName}, _From, #state{async = Previous} = State) ->
    {reply, {ok, Previous}, State#state{async = PidOrName}};

handle_call(get_cmd_status, _From, #state{complete_status = Status} = State) ->
    {reply, {ok, Status}, State};

%% Only accepted while the replication packet handler is installed: reports
%% the given flushed/applied LSNs together with the last received LSN and
%% remembers them in the replication state.
handle_call({standby_status_update, FlushedLSN, AppliedLSN}, _From,
            #state{handler = on_replication,
                   repl = #repl{last_received_lsn = ReceivedLSN} = Repl} = State) ->
    send(State, ?COPY_DATA, epgsql_wire:encode_standby_status_update(ReceivedLSN, FlushedLSN, AppliedLSN)),
    Repl1 = Repl#repl{last_flushed_lsn = FlushedLSN,
                      last_applied_lsn = AppliedLSN},
    {reply, ok, State#state{repl = Repl1}};
%% Synchronous command: the reply is sent later through the {call, From} transport.
handle_call({command, Command, Args}, From, State) ->
    Transport = {call, From},
    command_new(Transport, Command, Args, State).

%% Asynchronous command submitted by async_command/4.
handle_cast({{Method, From, Ref} = Transport, Command, Args}, State)
  when (Method =:= cast orelse Method =:= incremental),
       is_pid(From),
       is_reference(Ref) ->
    command_new(Transport, Command, Args, State);

%% Answer every in-progress and queued command with {error, closed}, then stop.
handle_cast(stop, State) ->
    {stop, normal, flush_queue(State, {error, closed})};

%% Best-effort cancellation of the running query. The request is sent over a
%% separate plain TCP connection to the peer address of the main socket.
%% A failure to resolve the peer, to connect or to send only means that the
%% cancel request is not delivered; it must not take this connection down.
handle_cast(cancel, State = #state{backend = {Pid, Key},
                                   mod = Mod,
                                   sock = TimedOutSock}) ->
    PeerName = case Mod of
                   gen_tcp -> inet:peername(TimedOutSock);
                   ssl -> ssl:peername(TimedOutSock)
               end,
    case PeerName of
        {ok, {Addr, Port}} ->
            %% 16 is the total length of this message in bytes
            Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
            send_cancel_request(Addr, Port, Msg);
        {error, _Reason} ->
            ok
    end,
    {noreply, State}.

%% Opens a throw-away connection, writes Msg to it and closes it again.
%% The connect is bounded by ?CANCEL_CONNECT_TIMEOUT so that an unreachable
%% server cannot block the connection process indefinitely.
send_cancel_request(Addr, Port, Msg) ->
    SockOpts = [{active, false}, {packet, raw}, binary],
    case gen_tcp:connect(Addr, Port, SockOpts, ?CANCEL_CONNECT_TIMEOUT) of
        {ok, Sock} ->
            %% The send result is deliberately ignored: there is nobody to
            %% report it to, and the socket has to be closed either way.
            _ = gen_tcp:send(Sock, Msg),
            gen_tcp:close(Sock);
        {error, _Reason} ->
            ok
    end.
%% The socket was closed: forget it (so terminate/2 does not close it again),
%% fail every pending command and stop.
handle_info({Tag, Sock}, #state{sock = Sock} = State)
  when Tag =:= tcp_closed orelse Tag =:= ssl_closed ->
    Detached = State#state{sock = undefined},
    {stop, sock_closed, flush_queue(Detached, {error, sock_closed})};

%% Socket-level error: fail every pending command with it and stop.
handle_info({Tag, Sock, Reason}, #state{sock = Sock} = State)
  when Tag =:= tcp_error orelse Tag =:= ssl_error ->
    StopReason = {sock_error, Reason},
    {stop, StopReason, flush_queue(State, {error, StopReason})};

%% Outcome of an earlier erlang:port_command/2 write (see do_send/3): success.
handle_info({inet_reply, _Port, ok}, State) ->
    {noreply, State};

%% Outcome of an earlier erlang:port_command/2 write: failure.
handle_info({inet_reply, _Port, Failure}, State) ->
    {stop, Failure, flush_queue(State, {error, Failure})};

%% Any other 3-tuple carrying our socket is taken as incoming data: append it
%% to the buffer and process every complete packet.
handle_info({_Tag, Sock, Chunk}, #state{data = Buffered, sock = Sock} = State) ->
    loop(State#state{data = <<Buffered/binary, Chunk/binary>>}).

%% Close the socket, if there still is one, through the module that owns it.
terminate(_Reason, #state{sock = undefined}) ->
    ok;
terminate(_Reason, #state{mod = gen_tcp, sock = Socket}) ->
    gen_tcp:close(Socket);
terminate(_Reason, #state{mod = ssl, sock = Socket}) ->
    ssl:close(Socket).

code_change(_OldVsn, State, _Extra) ->
    {ok, State}.

%% -- internal functions --

%% Entry point for a freshly received command: build its initial state and
%% try to execute it.
-spec command_new(transport(), epgsql_command:command(), any(), pg_sock()) ->
                         Result when
      Result :: {noreply, pg_sock()}
              | {stop, Reason :: any(), pg_sock()}.
command_new(Transport, Command, Args, State) ->
    command_exec(Transport, Command, epgsql_command:init(Command, Args), State).

-spec command_exec(transport(), epgsql_command:command(), any(), pg_sock()) ->
                          Result when
      Result :: {noreply, pg_sock()}
              | {stop, Reason :: any(), pg_sock()}.
%% While a sync is required, every command except epgsql_cmd_sync is rejected
%% right away with {error, sync_required}.
command_exec(Transport, Command, _CmdState, #state{sync_required = true} = State)
  when Command =/= epgsql_cmd_sync ->
    Rejected = State#state{current_cmd = Command,
                           current_cmd_transport = Transport},
    {noreply, finish(Rejected, {error, sync_required})};
%% Otherwise let the command run; on success it is registered as the current
%% command or queued behind it.
command_exec(Transport, Command, CmdState, State) ->
    case epgsql_command:execute(Command, State, CmdState) of
        {stop, StopReason, Response, StoppedState} ->
            reply(Transport, Response, Response),
            {stop, StopReason, StoppedState};
        {ok, NextState, NextCmdState} ->
            {noreply, command_enqueue(Transport, Command, NextCmdState, NextState)}
    end.

%% Makes the command current when nothing is in progress, otherwise appends
%% it to the queue. Both paths reset `complete_status'.
-spec command_enqueue(transport(), epgsql_command:command(), epgsql_command:state(), pg_sock()) -> pg_sock().
command_enqueue(Transport, Command, CmdState, #state{current_cmd = undefined} = Idle) ->
    Idle#state{complete_status = undefined,
               current_cmd = Command,
               current_cmd_state = CmdState,
               current_cmd_transport = Transport};
command_enqueue(Transport, Command, CmdState, #state{queue = Pending} = Busy) ->
    Pending1 = queue:in({Command, CmdState, Transport}, Pending),
    Busy#state{complete_status = undefined,
               queue = Pending1}.

-spec command_handle_message(byte(), binary() | epgsql:query_error(), pg_sock()) ->
                                    {noreply, pg_sock()}
                                  | {stop, any(), pg_sock()}.
%% Hands one backend packet to the current command and applies the action the
%% command asks for.
command_handle_message(Msg, Payload,
                       #state{current_cmd = Command,
                              current_cmd_state = CmdState} = State) ->
    Outcome = epgsql_command:handle_message(Command, Msg, Payload, State, CmdState),
    case Outcome of
        {add_row, Row, NewState, NewCmdState} ->
            Updated = NewState#state{current_cmd_state = NewCmdState},
            {noreply, add_row(Updated, Row)};
        {add_result, Result, Notice, NewState, NewCmdState} ->
            Updated = NewState#state{current_cmd_state = NewCmdState},
            {noreply, add_result(Updated, Notice, Result)};
        {finish, Result, Notice, NewState} ->
            {noreply, finish(NewState, Notice, Result)};
        {noaction, NewState} ->
            {noreply, NewState};
        {noaction, NewState, NewCmdState} ->
            {noreply, NewState#state{current_cmd_state = NewCmdState}};
        {requeue, NewState, NewCmdState} ->
            %% Execute the same command again, keeping its transport.
            Transport = NewState#state.current_cmd_transport,
            Vacated = NewState#state{current_cmd = undefined},
            command_exec(Transport, Command, NewCmdState, Vacated);
        {stop, Reason, Response, NewState} ->
            {stop, Reason, finish(NewState, Response)};
        {sync_required, Why} ->
            %% Protocol error. Finish and flush all pending commands.
            Flagged = State#state{sync_required = true},
            {noreply, sync_required(finish(Flagged, Why))};
        unknown ->
            {stop, {error, {unexpected_message, Msg, Command, CmdState}}, State}
    end.

%% Drops the finished command and promotes the next queued one, if any.
%% Accumulated rows and results are cleared in both cases.
command_next(#state{current_cmd = Finished,
                    queue = Pending} = State) when Finished =/= undefined ->
    Cleared = State#state{rows = [], results = []},
    case queue:out(Pending) of
        {{value, {NextCmd, NextCmdState, NextTransport}}, Remaining} ->
            Cleared#state{current_cmd = NextCmd,
                          current_cmd_state = NextCmdState,
                          current_cmd_transport = NextTransport,
                          queue = Remaining};
        {empty, _} ->
            Cleared#state{current_cmd = undefined,
                          current_cmd_state = undefined,
                          current_cmd_transport = undefined}
    end.


%% Sets socket options through the API that matches the socket's module.
setopts(#state{mod = Module, sock = Socket}, Opts) ->
    case Module of
        ssl -> ssl:setopts(Socket, Opts);
        gen_tcp -> inet:setopts(Socket, Opts)
    end.
%% Sends a packet that has no type byte. Only used during connection
%% initiation, for the client's `StartupMessage' and `SSLRequest' packets.
-spec send(pg_sock(), iodata()) -> ok | {error, any()}.
send(#state{mod = Mod, sock = Sock}, Data) ->
    Packet = epgsql_wire:encode_command(Data),
    do_send(Mod, Sock, Packet).

%% Sends a single packet of the given type.
-spec send(pg_sock(), byte(), iodata()) -> ok | {error, any()}.
send(#state{mod = Mod, sock = Sock}, Type, Data) ->
    Packet = epgsql_wire:encode_command(Type, Data),
    do_send(Mod, Sock, Packet).

%% Encodes several {Type, Data} packets and writes them with a single send.
-spec send_multi(pg_sock(), [{byte(), iodata()}]) -> ok | {error, any()}.
send_multi(#state{mod = Mod, sock = Sock}, List) ->
    Encode = fun({Type, Data}) -> epgsql_wire:encode_command(Type, Data) end,
    do_send(Mod, Sock, lists:map(Encode, List)).

do_send(ssl, Sock, Bin) ->
    ssl:send(Sock, Bin);
do_send(gen_tcp, Sock, Bin) ->
    %% erlang:port_command/2 is used instead of gen_tcp:send/2; for the reasoning see
    %% https://github.com/rabbitmq/rabbitmq-common/blob/v3.7.4/src/rabbit_writer.erl#L367-L384
    %% This is also why handle_info/2 has `{inet_reply, ...}' clauses.
    try erlang:port_command(Sock, Bin) of
        true -> ok
    catch
        error:_ -> {error, einval}
    end.

%% Decodes packets from the receive buffer one at a time and passes each to
%% the active handler, until the handler stops the server or the buffer no
%% longer holds a complete packet.
loop(#state{data = Buffer, handler = Handler} = State) ->
    case epgsql_wire:decode_message(Buffer) of
        {Type, Payload, Rest} ->
            case ?MODULE:Handler(Type, Payload, State#state{data = Rest}) of
                {noreply, NextState} ->
                    loop(NextState);
                {stop, _Reason, _FinalState} = Stop ->
                    Stop
            end;
        _ ->
            {noreply, send_pending_feedback(State)}
    end.

%% In replication mode feedback is sent after each batch of packets, but only
%% when the replication state has it flagged as required.
send_pending_feedback(#state{repl = Repl} = State) ->
    case (Repl =/= undefined) andalso (Repl#repl.feedback_required) of
        true ->
            #repl{last_received_lsn = Received,
                  last_flushed_lsn = Flushed,
                  last_applied_lsn = Applied} = Repl,
            StatusUpdate = epgsql_wire:encode_standby_status_update(Received, Flushed, Applied),
            send(State, ?COPY_DATA, StatusUpdate),
            State#state{repl = Repl#repl{feedback_required = false}};
        _ ->
            State
    end.
%% Completes the current command, using Result as the reply for every
%% transport kind.
finish(State, Result) ->
    finish(State, Result, Result).

%% Completes the current command: replies over its transport and moves on to
%% the next queued command.
finish(#state{current_cmd_transport = ReplyTo} = State, Notice, Result) ->
    reply(ReplyTo, Notice, Result),
    command_next(State).

%% Delivers the final reply. `incremental' receivers get Notice, the other
%% transports get Result.
reply({call, Caller}, _Notice, Result) ->
    gen_server:reply(Caller, Result);
reply({cast, Receiver, Ref}, _Notice, Result) ->
    Receiver ! {self(), Ref, Result};
reply({incremental, Receiver, Ref}, Notice, _Result) ->
    Receiver ! {self(), Ref, Notice}.

%% Records a partial result and clears the accumulated rows. With an
%% incremental transport Notice is streamed to the receiver and nothing is
%% accumulated.
add_result(#state{current_cmd_transport = {incremental, Receiver, Ref}} = State,
           Notice, _Result) ->
    Receiver ! {self(), Ref, Notice},
    State#state{rows = []};
add_result(#state{results = Acc} = State, _Notice, Result) ->
    State#state{rows = [],
                results = [Result | Acc]}.

%% Records a data row; with an incremental transport the row is streamed to
%% the receiver instead of being accumulated.
add_row(#state{current_cmd_transport = {incremental, Receiver, Ref}} = State, Data) ->
    Receiver ! {self(), Ref, {data, Data}},
    State;
add_row(#state{rows = Acc} = State, Data) ->
    State#state{rows = [Data | Acc]}.

%% Forwards Notice to the receiver of an incremental command; a no-op for
%% every other state. The state is returned unchanged.
notify(State, Notice) ->
    case State of
        #state{current_cmd_transport = {incremental, Receiver, Ref}} ->
            Receiver ! {self(), Ref, Notice},
            State;
        _ ->
            State
    end.

%% Send asynchronous messages (notice / notification).
%% Returns whether the message was handed over to a receiver.
notify_async(#state{async = undefined}, _Msg) ->
    false;
notify_async(#state{async = Receiver}, Msg) ->
    try
        Receiver ! {epgsql, self(), Msg},
        true
    catch
        error:badarg ->
            %% no process registered under this name
            false
    end.

%% Fails current and queued commands with {error, sync_required} until either
%% an epgsql_cmd_sync becomes current (it is left to run) or nothing is left,
%% in which case the state is flagged as requiring a sync.
sync_required(#state{current_cmd = Current} = State) ->
    case Current of
        epgsql_cmd_sync ->
            State;
        undefined ->
            State#state{sync_required = true};
        _ ->
            sync_required(finish(State, {error, sync_required}))
    end.

%% Fails the current command and every queued one with Error.
flush_queue(State, Error) ->
    case State of
        #state{current_cmd = undefined} ->
            State;
        _ ->
            flush_queue(finish(State, Error), Error)
    end.
%% Normalises a parameter name given as a list or a binary to a binary.
to_binary(Name) when is_list(Name) ->
    list_to_binary(Name);
to_binary(Name) when is_binary(Name) ->
    Name.


%% -- backend message handling --

%% CommandComplete: remember the decoded status, then let the command see it
on_message(?COMMAND_COMPLETE = Msg, Bin, State) ->
    WithStatus = State#state{complete_status = epgsql_wire:decode_complete(Bin)},
    command_handle_message(Msg, Bin, WithStatus);

%% ReadyForQuery: remember the transaction status byte
on_message(?READY_FOR_QUERY = Msg, <<TxStatus:8>> = Bin, State) ->
    command_handle_message(Msg, Bin, State#state{txstatus = TxStatus});

%% Error with no command in progress:
%% message generated by server asynchronously
on_message(?ERROR, Err, #state{current_cmd = undefined} = State) ->
    {stop, {shutdown, epgsql_wire:decode_error(Err)}, State};

%% Error belonging to the current command; it receives the decoded error
on_message(?ERROR = Msg, Err, #state{} = State) ->
    command_handle_message(Msg, epgsql_wire:decode_error(Err), State);

%% NoticeResponse
on_message(?NOTICE, Data, State) ->
    Notice = {notice, epgsql_wire:decode_error(Data)},
    notify_async(State, Notice),
    {noreply, State};

%% ParameterStatus: add the parameter or replace its stored value
on_message(?PARAMETER_STATUS, Data, #state{} = State) ->
    [Name, Value] = epgsql_wire:decode_strings(Data),
    Known = State#state.parameters,
    {noreply, State#state{parameters = lists:keystore(Name, 1, Known, {Name, Value})}};

%% NotificationResponse: when only one string is decoded, the payload is
%% reported as an empty binary
on_message(?NOTIFICATION, <<Pid:?int32, Strings/binary>>, State) ->
    Notification =
        case epgsql_wire:decode_strings(Strings) of
            [Channel, Payload] -> {notification, Channel, Pid, Payload};
            [Channel] -> {notification, Channel, Pid, <<>>}
        end,
    notify_async(State, Notification),
    {noreply, State};

%% Everything else goes to the current command untouched:
%% ParseComplete
%% ParameterDescription
%% RowDescription
%% NoData
%% BindComplete
%% CloseComplete
%% DataRow
%% PortalSuspended
%% EmptyQueryResponse
%% CopyData
%% CopyBothResponse
on_message(Msg, Payload, State) ->
    command_handle_message(Msg, Payload, State).
%% CopyData for Replication mode: primary keepalive message.
%% When the reply-required byte is 1 a status update is sent immediately;
%% otherwise feedback is flagged as required so that loop/1 sends it later.
on_replication(?COPY_DATA,
               <<?PRIMARY_KEEPALIVE_MESSAGE:8, KeepaliveLSN:?int64, _Clock:?int64, ReplyNow:8>>,
               #state{repl = #repl{last_flushed_lsn = Flushed,
                                   last_applied_lsn = Applied} = Repl} = State) ->
    FeedbackPending =
        case ReplyNow of
            1 ->
                StatusUpdate =
                    epgsql_wire:encode_standby_status_update(KeepaliveLSN, Flushed, Applied),
                send(State, ?COPY_DATA, StatusUpdate),
                false;
            _ ->
                true
        end,
    Repl1 = Repl#repl{feedback_required = FeedbackPending,
                      last_received_lsn = KeepaliveLSN},
    {noreply, State#state{repl = Repl1}};

%% CopyData for Replication mode: XLogData carrying a WAL record
on_replication(?COPY_DATA,
               <<?X_LOG_DATA, StartLSN:?int64, EndLSN:?int64,
                 _Clock:?int64, WALRecord/binary>>,
               #state{repl = Repl} = State) ->
    {noreply, State#state{repl = handle_xlog_data(StartLSN, EndLSN, WALRecord, Repl)}};

%% An error in replication mode stops the connection process
on_replication(?ERROR, Err, State) ->
    {stop, {error, epgsql_wire:decode_error(Err)}, State};

%% Notices, notifications and parameter updates are handled as in normal mode
on_replication(M, Data, Sock)
  when M == ?NOTICE orelse M == ?NOTIFICATION orelse M == ?PARAMETER_STATUS ->
    on_message(M, Data, Sock).


%% Delivers one WAL record to its consumer and returns the updated
%% replication state; feedback is flagged as required in both modes.
handle_xlog_data(StartLSN, EndLSN, WALRecord,
                 #repl{cbmodule = undefined, receiver = Receiver} = Repl) ->
    %% with async messages
    XLogData = {x_log_data, StartLSN, EndLSN, WALRecord},
    Receiver ! {epgsql, self(), XLogData},
    Repl#repl{last_received_lsn = EndLSN,
              feedback_required = true};
handle_xlog_data(StartLSN, EndLSN, WALRecord,
                 #repl{cbmodule = CbModule, cbstate = CbState, receiver = undefined} = Repl) ->
    %% with callback method: the callback reports how far it flushed and applied
    {ok, Flushed, Applied, CbState1} =
        epgsql:handle_x_log_data(CbModule, StartLSN, EndLSN, WALRecord, CbState),
    Repl#repl{last_received_lsn = EndLSN,
              last_flushed_lsn = Flushed,
              last_applied_lsn = Applied,
              feedback_required = true,
              cbstate = CbState1}.