1 /*
2  * %CopyrightBegin%
3  *
4  * Copyright Ericsson AB 2006-2016. All Rights Reserved.
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  *
18  * %CopyrightEnd%
19  */
20 /*
21  * Purpose: A driver using libpq to connect to Postgres
22  * from erlang, a sample for the driver documentation
23  */
24 
25 #include <erl_driver.h>
26 
27 #include <libpq-fe.h>
28 
29 #include <ei.h>
30 
31 #include <stdlib.h>
32 #include <stdio.h>
33 #include <string.h>
34 
35 #include "pg_encode.h"
36 
37 /* Driver interface declarations */
38 static ErlDrvData start(ErlDrvPort port, char *command);
39 static void stop(ErlDrvData drv_data);
40 static int control(ErlDrvData drv_data, unsigned int command, char *buf,
41 		   int len, char **rbuf, int rlen);
42 static void ready_io(ErlDrvData drv_data, ErlDrvEvent event);
43 
44 static ErlDrvEntry pq_driver_entry = {
45     NULL,			/* init */
46     start,
47     stop,
48     NULL,			/* output */
49     ready_io,			/* ready_input */
50     ready_io,			/* ready_output */
51     "pg_async",                 /* the name of the driver */
52     NULL,			/* finish */
53     NULL,			/* handle */
54     control,
55     NULL,			/* timeout */
56     NULL,			/* outputv */
57     NULL,			/* ready_async */
58     NULL,			/* flush */
59     NULL,			/* call */
60     NULL			/* event */
61 };
62 
63 typedef struct our_data_t {
64     PGconn* conn;
65     ErlDrvPort port;
66     int socket;
67     int connecting;
68 } our_data_t;
69 
70 /* Keep the following definitions in alignment with the FUNC_LIST
71  * in erl_pq_sync.erl
72  */
73 
/* Command codes understood by control(). */
#define DRV_CONNECT     'C'     /* handled by do_connect()    */
#define DRV_DISCONNECT  'D'     /* handled by do_disconnect() */
#define DRV_SELECT      'S'     /* handled by do_select()     */
77 
78 /* #define L     fprintf(stderr, "%d\r\n", __LINE__) */
79 
80 /* INITIALIZATION AFTER LOADING */
81 
/*
 * This is the init function called after this driver has been loaded.
 * It must *not* be declared static, and it must return the address of
 * the driver entry.
 */
DRIVER_INIT(pq_drv)87 DRIVER_INIT(pq_drv)
88 {
89     return &pq_driver_entry;
90 }
91 
92 static char* get_s(const char* buf, int len);
93 static int do_connect(const char *s, our_data_t* data);
94 static int do_disconnect(our_data_t* data);
95 static int do_select(const char* s, our_data_t* data);
96 
97 /* DRIVER INTERFACE */
start(ErlDrvPort port,char * command)98 static ErlDrvData start(ErlDrvPort port, char *command)
99 {
100     our_data_t* data = driver_alloc(sizeof(our_data_t));
101     data->port = port;
102     data->conn = NULL;
103     return (ErlDrvData)data;
104 }
105 
/*
 * stop callback: tear down the connection (if any) and release the
 * per-port state that start() obtained from driver_alloc().
 *
 * drv_data - the value returned by start().
 */
static void stop(ErlDrvData drv_data)
{
    our_data_t *data = (our_data_t *)drv_data;

    do_disconnect(data);
    /* Nothing else owns the state; without this it leaks on every close. */
    driver_free(data);
}
110 
/*
 * control callback: dispatch one DRV_* command.
 *
 * drv_data - state returned by start().
 * command  - DRV_CONNECT, DRV_DISCONNECT or DRV_SELECT.
 * buf/len  - command argument (connection string or query text); copied
 *            into a NUL-terminated string by get_s().
 * rbuf/rlen - not used; replies are sent with driver_output() by the
 *            individual handlers.
 *
 * Returns the handler's result, or -1 for an unknown command or for a
 * connect request without a usable argument.
 */
static int control(ErlDrvData drv_data, unsigned int command, char *buf,
                   int len, char **rbuf, int rlen)
{
    our_data_t *data = (our_data_t *)drv_data;
    char *arg = get_s(buf, len);    /* NULL if len is outside get_s()'s range */
    int rc;

    (void)rbuf;
    (void)rlen;

    switch (command) {
    case DRV_CONNECT:
        /* Never hand a NULL connection string to libpq. */
        rc = (arg != NULL) ? do_connect(arg, data) : -1;
        break;
    case DRV_DISCONNECT:
        rc = do_disconnect(data);
        break;
    case DRV_SELECT:
        rc = do_select(arg, data);
        break;
    default:
        rc = -1;
        break;
    }

    /* driver_free() is not documented as accepting NULL, so guard it. */
    if (arg != NULL) {
        driver_free(arg);
    }
    return rc;
}
126 
/*
 * Begin a non-blocking connection attempt.
 *
 * s    - connection string passed to PQconnectStart().
 * data - per-port state; on success conn, socket and connecting are set
 *        and the socket is selected for both read and write so that
 *        ready_io() can drive the rest of the handshake.
 *
 * On failure an error term is sent to the port, the connection object is
 * released and data->conn is left NULL. Always returns 0.
 */
static int do_connect(const char *s, our_data_t* data)
{
    PGconn *conn;
    ei_x_buff x;
    int fd;

    /* Release a connection left over from an earlier connect command. */
    if (data->conn != NULL) {
        driver_select(data->port, (ErlDrvEvent)data->socket, DO_READ, 0);
        driver_select(data->port, (ErlDrvEvent)data->socket, DO_WRITE, 0);
        PQfinish(data->conn);
        data->conn = NULL;
    }
    data->connecting = 0;

    conn = PQconnectStart(s);
    if (PQstatus(conn) != CONNECTION_BAD) {
        PQconnectPoll(conn);
        fd = PQsocket(conn);
        if (fd >= 0) {
            data->conn = conn;
            data->socket = fd;
            data->connecting = 1;
            driver_select(data->port, (ErlDrvEvent)fd, DO_READ, 1);
            driver_select(data->port, (ErlDrvEvent)fd, DO_WRITE, 1);
            return 0;
        }
    }

    /*
     * The attempt failed: report it and stop here. Nothing may be polled
     * or selected once the connection object has been released.
     */
    ei_x_new_with_version(&x);
    encode_error(&x, conn);
    PQfinish(conn);
    driver_output(data->port, x.buff, x.index);
    ei_x_free(&x);
    return 0;
}
148 
/*
 * Close the connection, if there is one, and acknowledge with an ok term.
 *
 * data - per-port state; conn is reset to NULL and connecting to 0.
 *
 * Always returns 0.
 */
static int do_disconnect(our_data_t* data)
{
    ei_x_buff x;

    /*
     * data->socket is only meaningful while a connection exists, so the
     * deselect calls must not run without one.
     */
    if (data->conn != NULL) {
        driver_select(data->port, (ErlDrvEvent)data->socket, DO_READ, 0);
        driver_select(data->port, (ErlDrvEvent)data->socket, DO_WRITE, 0);
        PQfinish(data->conn);
        data->conn = NULL;
    }
    data->connecting = 0;

    ei_x_new_with_version(&x);
    encode_ok(&x);
    driver_output(data->port, x.buff, x.index);
    ei_x_free(&x);
    return 0;
}
162 
/*
 * Submit a query without waiting for its result.
 *
 * s    - query text handed to PQsendQuery().
 * data - per-port state; connecting is cleared so that ready_io() takes
 *        its result-reading branch from now on.
 *
 * If the query cannot be dispatched an error term is sent to the port
 * immediately; otherwise the reply is produced later by ready_io().
 * Always returns 0.
 */
static int do_select(const char* s, our_data_t* data)
{
    ei_x_buff err;

    data->connecting = 0;

    if (PQsendQuery(data->conn, s) != 0) {
        /* Dispatched: the result is picked up in ready_io(). */
        return 0;
    }

    ei_x_new_with_version(&err);
    encode_error(&err, data->conn);
    driver_output(data->port, err.buff, err.index);
    ei_x_free(&err);
    return 0;
}
178 
/*
 * ready_input / ready_output callback for the connection socket.
 *
 * drv_data - state returned by start().
 * event    - the event that fired (not inspected).
 *
 * While data->connecting is set, this advances the handshake with
 * PQconnectPoll() and sends ok or an error term once the status is
 * CONNECTION_OK or CONNECTION_BAD. Otherwise it consumes input and, once
 * libpq is no longer busy, encodes the first result, sends it, and
 * discards any further results.
 */
static void ready_io(ErlDrvData drv_data, ErlDrvEvent event)
{
    our_data_t *data = (our_data_t *)drv_data;
    PGconn *conn = data->conn;
    ei_x_buff x;

    (void)event;

    if (data->connecting) {
        ConnStatusType status;

        ei_x_new_with_version(&x);
        PQconnectPoll(conn);
        status = PQstatus(conn);
        if (status == CONNECTION_OK) {
            encode_ok(&x);
        } else if (status == CONNECTION_BAD) {
            encode_error(&x, conn);
        }
    } else {
        PGresult *res;

        PQconsumeInput(conn);
        if (PQisBusy(conn)) {
            /* Result not complete yet; x is deliberately not allocated
             * before this point so that nothing leaks on this path. */
            return;
        }
        ei_x_new_with_version(&x);
        res = PQgetResult(conn);
        encode_result(&x, res, conn);
        PQclear(res);
        /* Drain whatever results remain. */
        while ((res = PQgetResult(conn)) != NULL) {
            PQclear(res);
        }
    }

    /* An index above 1 means something was encoded after the initial
     * version byte. */
    if (x.index > 1) {
        driver_output(data->port, x.buff, x.index);
        if (data->connecting) {
            driver_select(data->port, (ErlDrvEvent)data->socket, DO_WRITE, 0);
        }
    }
    ei_x_free(&x);
}
215 
216 /* utilities */
/*
 * Copy a length-delimited buffer into a freshly allocated,
 * NUL-terminated string.
 *
 * buf - source bytes (need not be NUL-terminated).
 * len - number of bytes in buf; accepted range is 1..1000.
 *
 * Returns a string allocated with driver_alloc() that the caller must
 * release with driver_free(), or NULL if buf is NULL, len is out of
 * range, or the allocation fails.
 */
static char* get_s(const char* buf, int len)
{
    char *copy;

    if (buf == NULL || len < 1 || len > 1000) {
        return NULL;
    }

    copy = driver_alloc(len + 1);
    if (copy == NULL) {
        return NULL;
    }

    memcpy(copy, buf, len);
    copy[len] = '\0';
    return copy;
}
226