1 /*
2 * %CopyrightBegin%
3 *
4 * Copyright Ericsson AB 2006-2016. All Rights Reserved.
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 *
18 * %CopyrightEnd%
19 */
/*
 * Purpose: A driver using libpq to connect to PostgreSQL
 * from Erlang; a sample for the driver documentation.
 */
24
25 #include <erl_driver.h>
26
27 #include <libpq-fe.h>
28
29 #include <ei.h>
30
31 #include <stdlib.h>
32 #include <stdio.h>
33 #include <string.h>
34
35 #include "pg_encode.h"
36
/*
 * Driver interface declarations: the callbacks installed in
 * pq_driver_entry (ready_io serves as both ready_input and ready_output).
 */
static ErlDrvData start(ErlDrvPort port, char *command);
static void stop(ErlDrvData drv_data);
static int control(ErlDrvData drv_data, unsigned int command, char *buf,
                   int len, char **rbuf, int rlen);
static void ready_io(ErlDrvData drv_data, ErlDrvEvent event);
43
/*
 * Driver entry table.  It is initialised positionally, so the entries must
 * follow the field order of ErlDrvEntry; the trailing comments name the
 * slot each entry fills.  The same ready_io callback handles both socket
 * input and socket output readiness.  Fields after `event` are not listed
 * and are therefore zero-initialised.
 *
 * NOTE(review): none of the extended-interface fields (extended_marker,
 * major/minor version, driver_flags, ...) are set here.  Newer ERTS releases
 * may refuse to load a driver that does not declare ERL_DRV_EXTENDED_MARKER,
 * and the extended interface expects control() to take ErlDrvSizeT and
 * return ErlDrvSSizeT.  Confirm against the targeted OTP release.
 */
static ErlDrvEntry pq_driver_entry = {
    NULL,        /* init */
    start,
    stop,
    NULL,        /* output */
    ready_io,    /* ready_input */
    ready_io,    /* ready_output */
    "pg_async",  /* the name of the driver */
    NULL,        /* finish */
    NULL,        /* handle */
    control,
    NULL,        /* timeout */
    NULL,        /* outputv */
    NULL,        /* ready_async */
    NULL,        /* flush */
    NULL,        /* call */
    NULL         /* event */
};
62
63 typedef struct our_data_t {
64 PGconn* conn;
65 ErlDrvPort port;
66 int socket;
67 int connecting;
68 } our_data_t;
69
/*
 * Command codes dispatched by control().  Keep these values in alignment
 * with the FUNC_LIST in erl_pq_sync.erl.
 */
enum {
    DRV_CONNECT    = 'C',
    DRV_DISCONNECT = 'D',
    DRV_SELECT     = 'S'
};
79
/* INITIALIZATION AFTER LOADING */

/*
 * Init function called after this driver has been loaded.  The
 * DRIVER_INIT macro (from erl_driver.h) produces the exported entry point,
 * so it must *not* be declared static.  It returns the address of the
 * driver entry table, pq_driver_entry.
 */
DRIVER_INIT(pq_drv)
{
    return &pq_driver_entry;
}
91
/* Helpers used by control(); each do_* function handles one command code. */
static char* get_s(const char* buf, int len);             /* payload -> NUL-terminated copy */
static int do_connect(const char *s, our_data_t* data);   /* DRV_CONNECT */
static int do_disconnect(our_data_t* data);               /* DRV_DISCONNECT */
static int do_select(const char* s, our_data_t* data);    /* DRV_SELECT */
96
97 /* DRIVER INTERFACE */
start(ErlDrvPort port,char * command)98 static ErlDrvData start(ErlDrvPort port, char *command)
99 {
100 our_data_t* data = driver_alloc(sizeof(our_data_t));
101 data->port = port;
102 data->conn = NULL;
103 return (ErlDrvData)data;
104 }
105
/*
 * stop callback: the port is being closed.
 *
 * Closes the database connection if one is open (via do_disconnect) and
 * then releases the per-port state allocated in start().  The original
 * version never freed that state, leaking it on every port close.
 */
static void stop(ErlDrvData drv_data)
{
    our_data_t* data = (our_data_t*)drv_data;

    /* Without a connection there is no socket to deselect and nothing to finish. */
    if (data->conn != NULL)
        do_disconnect(data);
    driver_free(data);
}
110
/*
 * control callback: dispatch one command sent with port_control.
 *
 * `command` is one of DRV_CONNECT, DRV_DISCONNECT or DRV_SELECT.  For
 * connect and select, buf/len carry the connection string or the query
 * text.  Replies are sent to the port owner with driver_output() by the
 * do_* helpers, so nothing is written to *rbuf.
 *
 * Returns 0 on success.  Returns -1 for an unknown command, or when a
 * command that needs a payload gets one that get_s() rejects (empty, too
 * long, or allocation failure); previously such a NULL string was passed
 * straight into libpq.
 */
static int control(ErlDrvData drv_data, unsigned int command, char *buf,
                   int len, char **rbuf, int rlen)
{
    our_data_t* data = (our_data_t*)drv_data;
    char* s = get_s(buf, len);
    int r = -1;

    (void)rbuf; /* replies go through driver_output(), not the control buffer */
    (void)rlen;

    switch (command) {
    case DRV_CONNECT:
        if (s != NULL)
            r = do_connect(s, data);
        break;
    case DRV_DISCONNECT:
        r = do_disconnect(data);
        break;
    case DRV_SELECT:
        if (s != NULL)
            r = do_select(s, data);
        break;
    default:
        break;
    }

    if (s != NULL)
        driver_free(s);
    return r;
}
126
/*
 * DRV_CONNECT: start a non-blocking connection using the libpq connection
 * string `s`.
 *
 * A connection already held by this port is closed first, without sending
 * a reply, so its PGconn and socket selection are not leaked.  If libpq
 * rejects the attempt immediately, an error term is sent to the port owner
 * and the port stays disconnected.  Previously this path fell through and
 * called PQconnectPoll(NULL) and driver_select() on fd -1.  Otherwise the
 * connection's socket is selected for input and output and
 * data->connecting is set, so ready_io() drives the rest of the handshake.
 *
 * Returns 0 (nothing in the control reply buffer), or -1 if `s` is NULL.
 */
static int do_connect(const char *s, our_data_t* data)
{
    PGconn* conn;
    int sock = -1;

    if (s == NULL)
        return -1;

    if (data->conn != NULL) {
        driver_select(data->port, (ErlDrvEvent)data->socket, DO_READ, 0);
        driver_select(data->port, (ErlDrvEvent)data->socket, DO_WRITE, 0);
        PQfinish(data->conn);
        data->conn = NULL;
    }
    data->socket = -1;
    data->connecting = 0;

    /* PQconnectStart returns NULL only if libpq cannot allocate a PGconn;
     * PQstatus(NULL) reports CONNECTION_BAD, so both cases take the error path. */
    conn = PQconnectStart(s);
    if (PQstatus(conn) != CONNECTION_BAD) {
        PQconnectPoll(conn);    /* begin the handshake; ready_io() continues it */
        sock = PQsocket(conn);
    }

    if (PQstatus(conn) == CONNECTION_BAD || sock < 0) {
        ei_x_buff x;
        ei_x_new_with_version(&x);
        encode_error(&x, conn);     /* must run before PQfinish() frees conn */
        driver_output(data->port, x.buff, x.index);
        ei_x_free(&x);
        PQfinish(conn);
        return 0;
    }

    data->conn = conn;
    data->socket = sock;
    data->connecting = 1;
    driver_select(data->port, (ErlDrvEvent)sock, DO_READ, 1);
    driver_select(data->port, (ErlDrvEvent)sock, DO_WRITE, 1);
    return 0;
}
148
/*
 * DRV_DISCONNECT: close the database connection, if any, and reply with
 * the term built by encode_ok().
 *
 * The socket is deselected and PQfinish() called only when a connection
 * exists.  Previously driver_select() could run on an uninitialised or
 * stale socket when nothing was connected.  The socket and connecting
 * fields are reset so the port is back in its idle state.
 *
 * Always returns 0.
 */
static int do_disconnect(our_data_t* data)
{
    ei_x_buff x;

    if (data->conn != NULL) {
        driver_select(data->port, (ErlDrvEvent)data->socket, DO_READ, 0);
        driver_select(data->port, (ErlDrvEvent)data->socket, DO_WRITE, 0);
        PQfinish(data->conn);
        data->conn = NULL;
    }
    data->socket = -1;
    data->connecting = 0;

    ei_x_new_with_version(&x);
    encode_ok(&x);
    driver_output(data->port, x.buff, x.index);
    ei_x_free(&x);
    return 0;
}
162
/*
 * DRV_SELECT: send query `s` asynchronously with PQsendQuery().
 *
 * If libpq refuses to send (for example no connection, or the connect
 * handshake is still in progress), the error is reported to the port owner
 * immediately and the driver state is left untouched.  In particular,
 * data->connecting is no longer cleared on failure: doing so while the
 * handshake was running made ready_io() stop polling the connection,
 * leaving DO_WRITE selected on an unfinished connection.
 *
 * On success, socket events now belong to this query; ready_io() collects
 * and sends its result.  Always returns 0.
 */
static int do_select(const char* s, our_data_t* data)
{
    PGconn* conn = data->conn;

    if (PQsendQuery(conn, s) == 0) {
        ei_x_buff x;
        ei_x_new_with_version(&x);
        encode_error(&x, conn);
        driver_output(data->port, x.buff, x.index);
        ei_x_free(&x);
        return 0;
    }

    /* Query is on its way: ready_io() must now read results, not poll the connect. */
    data->connecting = 0;
    return 0;
}
178
/*
 * ready_input / ready_output callback.  It is called when the socket
 * selected with driver_select() becomes readable or writable.
 *
 * Connecting phase (data->connecting != 0): the event advances the
 * handshake with PQconnectPoll(), and the return value decides what
 * happens next:
 *   - On completion, the encode_ok() term (success) or an error term
 *     (failure) is sent to the port owner and data->connecting is cleared.
 *   - On failure, the socket is also deselected entirely.
 *   - While libpq only waits for input, write interest is dropped; an
 *     always-writable socket would otherwise make this callback spin.
 *
 * Query phase: pending input is consumed.  Once libpq is no longer busy,
 * the first result is encoded and sent, and any further results are
 * discarded.
 *
 * Fixed relative to the earlier version:
 *   - the ei buffer leaked on the "still busy" return;
 *   - connecting was never cleared, so later read events re-ran
 *     PQconnectPoll() and sent duplicate replies;
 *   - a failed connection's socket stayed selected.
 */
static void ready_io(ErlDrvData drv_data, ErlDrvEvent event)
{
    our_data_t* data = (our_data_t*)drv_data;
    PGconn* conn = data->conn;
    ErlDrvEvent sock = (ErlDrvEvent)data->socket;
    ei_x_buff x;

    (void)event; /* only data->socket is ever selected for this port */

    /* Nothing to service without a connection. */
    if (conn == NULL)
        return;

    ei_x_new_with_version(&x);
    if (data->connecting) {
        /* NOTE(review): libpq documents that PQsocket() may change between
         * PQconnectPoll() calls (e.g. when trying several addresses); this
         * code keeps the socket captured in do_connect(). */
        switch (PQconnectPoll(conn)) {
        case PGRES_POLLING_OK:
            encode_ok(&x);
            driver_select(data->port, sock, DO_WRITE, 0);
            data->connecting = 0;
            break;
        case PGRES_POLLING_FAILED:
            encode_error(&x, conn);
            driver_select(data->port, sock, DO_READ, 0);
            driver_select(data->port, sock, DO_WRITE, 0);
            data->connecting = 0;
            break;
        case PGRES_POLLING_READING:
            driver_select(data->port, sock, DO_WRITE, 0);
            break;
        case PGRES_POLLING_WRITING:
            driver_select(data->port, sock, DO_WRITE, 1);
            break;
        default:
            break;
        }
    } else {
        PGresult* res;

        PQconsumeInput(conn);
        if (PQisBusy(conn)) {
            /* Result not complete yet; wait for the next input event. */
            ei_x_free(&x);
            return;
        }
        res = PQgetResult(conn);
        encode_result(&x, res, conn);
        PQclear(res);
        /* Drain remaining results so the connection can accept the next query. */
        while ((res = PQgetResult(conn)) != NULL)
            PQclear(res);
    }

    /* index 1 means only the version byte was written: nothing to report. */
    if (x.index > 1)
        driver_output(data->port, x.buff, x.index);
    ei_x_free(&x);
}
215
216 /* utilities */
/*
 * Copy the `len` bytes at `buf` into a newly driver_alloc()ed,
 * NUL-terminated string.
 *
 * Returns NULL when `len` is outside 1..MAX_LEN or the allocation fails.
 * Previously a failed allocation was written through.  Otherwise the
 * caller owns the result and must release it with driver_free().
 */
static char* get_s(const char* buf, int len)
{
    enum { MAX_LEN = 1000 }; /* largest control payload accepted */
    char* result;

    if (len < 1 || len > MAX_LEN)
        return NULL;
    result = driver_alloc((size_t)len + 1);
    if (result == NULL)
        return NULL;
    memcpy(result, buf, (size_t)len);
    result[len] = '\0';
    return result;
}
226