1%%% File    : pgsql_util.erl
2%%% Author  : Christian Sunesson
3%%% Description : utility functions used in implementation of
4%%%               postgresql driver.
5%%% Created : 11 May 2005 by Blah <cos@local>
6
7-module(pgsql_util).
8
9%% Key-Value handling
10-export([option/2]).
11
12%% Networking
13-export([socket/1]).
14-export([send/2, send_int/2, send_msg/3]).
15-export([recv_msg/2, recv_msg/1, recv_byte/2, recv_byte/1]).
16
17%% Protocol packing
18-export([string/1, make_pair/2, split_pair/1]).
19-export([split_pair_rec/1]).
20-export([count_string/1, to_string/1]).
21-export([oids/2, coldescs/2, datacoldescs/3, int16/2]).
22-export([decode_row/2, decode_descs/1]).
23-export([errordesc/1]).
24
25-export([zip/2]).
26
27%% Constructing authentication messages.
28-export([pass_plain/1, pass_md5/3]).
29-import(erlang, [md5/1]).
30-export([hexlist/2]).
31
%% Look up Key in the property list that the calling process keeps in its
%% process dictionary under the key 'options'.
%% Returns the value stored for Key, or Default when the list holds no
%% entry for Key.
option(Key, Default) ->
    proplists:get_value(Key, get(options), Default).
42
43
%% Open a TCP connection to Host:Port.
%% The socket is opened in passive mode ({active, false}), delivers data as
%% binaries and applies no packet framing. The connect attempt is limited
%% to 5000 ms. Returns whatever gen_tcp:connect/4 returns.
socket({tcp, Host, Port}) ->
    SocketOpts = [{active, false}, binary, {packet, raw}],
    ConnectTimeout = 5000,
    gen_tcp:connect(Host, Port, SocketOpts, ConnectTimeout).
47
%% Write Data to Socket unchanged; returns the result of gen_tcp:send/2.
send(Socket, Data) ->
    gen_tcp:send(Socket, Data).
%% Send Int as a 32-bit big-endian integer (the default bit-syntax
%% endianness); returns the result of gen_tcp:send/2.
send_int(Sock, Int) ->
    gen_tcp:send(Sock, <<Int:32/integer>>).
53
%% Send one protocol message: a one-byte message Code, a 32-bit length and
%% the payload Packet (which must be a binary).
%% The length field counts itself (4 bytes) plus the payload, but not the
%% code byte. Returns the result of gen_tcp:send/2.
send_msg(Sock, Code, Packet) when is_binary(Packet) ->
    Len = byte_size(Packet) + 4,
    Msg = <<Code:8/integer, Len:32/integer, Packet/binary>>,
    gen_tcp:send(Sock, Msg).
58
%% Read one protocol message from Sock.
%% A message starts with a 5-byte header: a one-byte code followed by a
%% 32-bit size that includes the 4 size bytes themselves.
%% Returns {ok, Code, Payload}; Payload is <<>> when the size leaves no
%% room for a body. Any failing gen_tcp:recv/3 call crashes the caller with
%% a badmatch.
recv_msg(Sock, Timeout) ->
    {ok, <<Code:8/integer, Size:32/integer>>} = gen_tcp:recv(Sock, 5, Timeout),
    Payload =
        case Size > 4 of
            true ->
                {ok, Body} = gen_tcp:recv(Sock, Size - 4, Timeout),
                Body;
            false ->
                %% Never ask gen_tcp:recv/3 for 0 bytes: a length of 0
                %% means "return whatever is available".
                <<>>
        end,
    {ok, Code, Payload}.

%% Same as recv_msg/2, waiting indefinitely.
recv_msg(Sock) ->
    recv_msg(Sock, infinity).
72
73
%% Read a single byte from Sock, waiting indefinitely.
recv_byte(Sock) ->
    recv_byte(Sock, infinity).

%% Read a single byte from Sock within Timeout.
%% Returns {ok, Byte} with Byte as an integer; a receive error is thrown as
%% {error, Reason}.
recv_byte(Sock, Timeout) ->
    case gen_tcp:recv(Sock, 1, Timeout) of
        {ok, <<Byte:8/integer>>} ->
            {ok, Byte};
        {error, _Reason} = Error ->
            throw(Error)
    end.
83
%% Convert a string (list or binary) into a zero-terminated binary.
%% Uses the is_list/1 and is_binary/1 guard BIFs instead of the obsolete
%% list/1 and binary/1 type tests.
string(String) when is_list(String) ->
    string(list_to_binary(String));
string(Bin) when is_binary(Bin) ->
    <<Bin/binary, 0/integer>>.
90
%%% Build two consecutive zero-terminated strings: Key, 0, Value, 0.
%%% Key and Value may each independently be an atom, a list or a binary;
%%% any other type fails with function_clause.
make_pair(Key, Value) ->
    BinKey = pair_field(Key),
    BinValue = pair_field(Value),
    <<BinKey/binary, 0/integer,
      BinValue/binary, 0/integer>>.

%% Normalise one pair component to a binary.
pair_field(Field) when is_atom(Field) ->
    list_to_binary(atom_to_list(Field));
pair_field(Field) when is_list(Field) ->
    list_to_binary(Field);
pair_field(Field) when is_binary(Field) ->
    Field.
103
%% Split the first "Key, 0, Value, 0" pair off Bin/Str and return it as
%% {Key, Value} (both as lists). Anything after the first pair is ignored.
%% Crashes if the input does not start with two zero-terminated strings.
split_pair(Bin) when is_binary(Bin) ->
    split_pair(binary_to_list(Bin));
split_pair(Str) ->
    {Pair, _Rest} = take_pair(Str),
    Pair.

%% Split a whole sequence of "Key, 0, Value, 0" pairs into a list of
%% {Key, Value} tuples, in input order. The sequence ends at the end of the
%% input or at a single trailing zero byte.
split_pair_rec(Bin) when is_binary(Bin) ->
    split_pair_rec(binary_to_list(Bin));
split_pair_rec(Arg) ->
    split_pair_rec(Arg, []).

split_pair_rec([], Acc) ->
    lists:reverse(Acc);
split_pair_rec([0], Acc) ->
    lists:reverse(Acc);
split_pair_rec(S, Acc) ->
    {Pair, Tail} = take_pair(S),
    split_pair_rec(Tail, [Pair | Acc]).

%% Consume one key and one value, each terminated by a zero byte, from the
%% front of S. Returns {{Key, Value}, Remaining}.
take_pair(S) ->
    NotNul = fun(C) -> C =/= 0 end,
    {Key, [0 | S1]} = lists:splitwith(NotNul, S),
    {Value, [0 | Tail]} = lists:splitwith(NotNul, S1),
    {{Key, Value}, Tail}.
127
128
%% Count the bytes of Bin that precede the first zero byte.
%% Returns {Count, Rest}, where Rest is what follows the zero byte. When
%% Bin contains no zero byte, Count is the full size and Rest is <<>>.
count_string(Bin) when is_binary(Bin) ->
    count_string(Bin, 0).

%% Worker for count_string/1; N is the number of bytes skipped so far.
count_string(<<>>, N) ->
    {N, <<>>};
count_string(<<0/integer, Rest/binary>>, N) ->
    {N, Rest};
count_string(<<_C/integer, Rest/binary>>, N) ->
    count_string(Rest, N + 1).
138
%% Extract the zero-terminated string at the start of Bin.
%% Returns {String, Count}: the string as a list and its length in bytes,
%% not counting the terminator. Without a terminator the whole of Bin is
%% taken.
to_string(Bin) when is_binary(Bin) ->
    {Count, _Rest} = count_string(Bin, 0),
    <<String:Count/binary, _/binary>> = Bin,
    {binary_to_list(String), Count}.
143
%% Decode a binary of consecutive 32-bit unsigned integers.
%% Acc holds already decoded values in reverse order; the result is in
%% input order. Fails with function_clause if the size is not a multiple
%% of 4 bytes.
oids(<<Oid:32/integer, Tail/binary>>, Acc) ->
    oids(Tail, [Oid | Acc]);
oids(<<>>, Acc) ->
    lists:reverse(Acc).
148
%% Decode a binary of consecutive 16-bit unsigned integers.
%% Acc holds already decoded values in reverse order; the result is in
%% input order. Fails with function_clause if the size is not a multiple
%% of 2 bytes.
int16(<<Val:16/integer, Tail/binary>>, Acc) ->
    int16(Tail, [Val | Acc]);
int16(<<>>, Acc) ->
    lists:reverse(Acc).
153
%% Decode a sequence of column descriptions.
%% Each description is a zero-terminated column name followed by: table OID
%% (32 bit), column number (16 bit), type id (32 bit), type size (16 bit,
%% signed), type modifier (32 bit, signed) and format code (16 bit).
%% Returns, in input order, one tuple per column:
%%   {Name, Format, ColumnNumber, TypeId, TypeSize, TypeMod, TableOID}
%% where Format is 'text' (code 0) or 'binary' (code 1). Acc holds already
%% decoded descriptions in reverse order.
coldescs(<<>>, Acc) ->
    lists:reverse(Acc);
coldescs(Bin, Acc) ->
    {Name, NameLen} = to_string(Bin),
    <<_NameBytes:NameLen/binary, 0/integer,
      TableOID:32/integer,
      ColumnNumber:16/integer,
      TypeId:32/integer,
      TypeSize:16/integer-signed,
      TypeMod:32/integer-signed,
      FormatCode:16/integer,
      Tail/binary>> = Bin,
    Format =
        case FormatCode of
            0 -> text;
            1 -> binary
        end,
    coldescs(Tail,
             [{Name, Format, ColumnNumber, TypeId, TypeSize, TypeMod, TableOID} | Acc]).
174
%% Split the body of a data row into its column values.
%% Each column is a 32-bit length followed by that many bytes of data. A
%% length of -1 (16#FFFFFFFF when read unsigned) marks an SQL NULL, which
%% carries no data bytes and is returned as the atom 'null'. Without the
%% dedicated NULL clause the length-prefixed pattern cannot match such a
%% column and the row would be cut short at the first NULL.
%% N counts down once per decoded column; decoding also stops as soon as
%% the binary no longer starts with a complete column.
%% Descs holds already decoded values in reverse order; the result is in
%% input order.
datacoldescs(N,
             <<16#FFFFFFFF:32/integer, Rest/binary>>,
             Descs) when N >= 0 ->
    datacoldescs(N - 1, Rest, [null | Descs]);
datacoldescs(N,
             <<Len:32/integer, Data:Len/binary, Rest/binary>>,
             Descs) when N >= 0 ->
    datacoldescs(N - 1, Rest, [Data | Descs]);
datacoldescs(_N, _, Descs) ->
    lists:reverse(Descs).
181
%% Translate raw column descriptions into descriptions that carry a type
%% name instead of a numeric type OID.
%% The name is fetched from the dict stored in the process dictionary under
%% 'oidmap'; an OID missing from that dict makes dict:fetch/2 crash.
%% The size, modifier and table OID slots of each result tuple are set
%% to []. Returns {ok, Descs} in input order.
decode_descs(Cols) ->
    {ok, lists:map(fun decode_desc/1, Cols)}.

%% Translate a single column description.
decode_desc(Col) ->
    OidMap = get(oidmap),
    {Name, Format, ColNumber, Oid, _, _, _} = Col,
    OidName = dict:fetch(Oid, OidMap),
    {Name, Format, ColNumber, OidName, [], [], []}.
191
%% Decode the values of one row.
%% Types and Values are walked pairwise and each pair is handed to
%% decode_col/2. Returns {ok, DecodedValues} in column order; lists of
%% different lengths cause a crash.
decode_row(Types, Values) ->
    {ok, lists:zipwith(fun decode_col/2, Types, Values)}.
199
%% Decode one column value according to its description tuple
%% {Name, Format, ColNumber, TypeName, Size, Modifier, TableOID}.
%%   * the atom 'null' (SQL NULL) is passed through unchanged;
%%   * text-format values and varchar values become lists;
%%   * binary int4 values become integers, read as signed 32-bit values so
%%     that negative numbers decode correctly;
%%   * any other type is returned undecoded as {TypeName, Value}.
decode_col(_Desc, null) ->
    null;
decode_col({_, text, _, _, _, _, _}, Value) ->
    binary_to_list(Value);
decode_col({_Name, _Format, _ColNumber, varchar, _Size, _Modifier, _TableOID}, Value) ->
    binary_to_list(Value);
decode_col({_Name, _Format, _ColNumber, int4, _Size, _Modifier, _TableOID}, Value) ->
    <<Int4:32/signed-integer>> = Value,
    Int4;
decode_col({_Name, _Format, _ColNumber, Oid, _Size, _Modifier, _TableOID}, Value) ->
    {Oid, Value}.
209
%% Decode the body of an error/notice message into a list of tagged fields.
%% The body is a sequence of fields, each a one-byte field code followed by
%% a zero-terminated string; a zero byte in place of a field code ends the
%% sequence. Fields are returned in message order.
errordesc(Bin) ->
    errordesc(Bin, []).

errordesc(<<0/integer, _Rest/binary>>, Fields) ->
    lists:reverse(Fields);
errordesc(<<Code/integer, Rest/binary>>, Fields) ->
    {String, Count} = to_string(Rest),
    <<_:Count/binary, 0, Rest1/binary>> = Rest,
    errordesc(Rest1, [error_field(Code, String) | Fields]).

%% Turn one field code and its string into a tagged tuple. Codes that are
%% not listed are returned as {Code, String}.
%% NOTE(review): the severity string comes from the server and is turned
%% into an atom with list_to_atom/1; atoms are never garbage collected, so
%% this is only safe as long as the server is trusted.
error_field($S, String) -> {severity, list_to_atom(String)};
error_field($C, String) -> {code, String};
error_field($M, String) -> {message, String};
error_field($D, String) -> {detail, String};
error_field($H, String) -> {hint, String};
error_field($P, String) -> {position, list_to_integer(String)};
error_field($p, String) -> {internal_position, list_to_integer(String)};
error_field($W, String) -> {where, String};
error_field($F, String) -> {file, String};
error_field($L, String) -> {line, list_to_integer(String)};
error_field($R, String) -> {routine, String};
error_field(Unknown, String) -> {Unknown, String}.
245
%%% Pair up the elements of two lists: [{A1, B1}, {A2, B2}, ...].
%%% The result is as long as the shorter list; surplus elements of the
%%% longer list are dropped.
zip(List1, List2) ->
    zip(List1, List2, []).

zip([], _List2, Pairs) ->
    lists:reverse(Pairs);
zip(_List1, [], Pairs) ->
    lists:reverse(Pairs);
zip([H1 | T1], [H2 | T2], Pairs) ->
    zip(T1, T2, [{H1, H2} | Pairs]).
254
255%%% Authentication utils
256
%% Build the payload for clear-text password authentication: the password
%% (list or binary) followed by a terminating zero byte.
pass_plain(Password) ->
    list_to_binary([Password, 0]).
260
261%% MD5 authentication patch from
262%%    Juhani Rankimies <juhani@juranki.com>
263%% (patch slightly rewritten, new bugs are mine :] /Christian Sunesson)
264
265%%
266%% MD5(MD5(password + user) + salt)
267%%
268
%% Build the payload for MD5 password authentication:
%%   "md5" ++ hex(md5(hex(md5(Password ++ User)) ++ Salt)) ++ [0]
%% User, Password and Salt may be anything erlang:md5/1 accepts as iodata.
pass_md5(User, Password, Salt) ->
    Inner = hex(erlang:md5([Password, User])),
    Outer = hex(erlang:md5([Inner, Salt])),
    list_to_binary(["md5", Outer, 0]).

%% Render a binary as a lower-case hexadecimal string (two digits per byte).
hex(B) when is_binary(B) ->
    hexlist(binary_to_list(B), []).
277
%% Convert a list of bytes into a string of lower-case hexadecimal digits,
%% two digits per byte with the high nibble first.
%% Acc collects digits in reverse order and is reversed at the end, so a
%% non-empty Acc ends up reversed in front of the new digits.
hexlist([], Acc) ->
    lists:reverse(Acc);
hexlist([Byte | Rest], Acc) ->
    Hi = hexdigit((Byte band 16#f0) bsr 4),
    Lo = hexdigit(Byte band 16#0f),
    hexlist(Rest, [Lo, Hi | Acc]).

%% Map a nibble (integer 0..15) to its lower-case hexadecimal character.
%% Any other argument fails with function_clause.
hexdigit(N) when is_integer(N), N >= 0, N =< 9 ->
    $0 + N;
hexdigit(N) when is_integer(N), N >= 10, N =< 15 ->
    $a + (N - 10).
301