1 #include <sys/types.h>
2 #include <sys/stat.h>
3 #include <sys/socket.h>
4 #include <sys/un.h>
5 #include <stdlib.h>
6 #include <stdio.h>
7 #include <unistd.h>
8 #include <string.h>
9 #include <errno.h>
10 
11 #include "common.h"
12 #include "sqlcached_client.h"
13 
/*
 * Growable buffer used to read one protocol line at a time.
 *   buf      - heap buffer; sc_line_init() allocates it when it is NULL
 *   buf_len  - number of characters stored in buf so far
 *   buf_size - number of bytes allocated for buf
 */
struct sc_line {
    char *buf;
    unsigned int buf_len, buf_size;
};

/* Lifecycle helpers for struct sc_line (defined later in this file). */
static void sc_line_init(struct sc_line *line);
static void sc_line_read(struct sc_line *line, FILE *f);
static void sc_line_done(struct sc_line *line);
static void sc_line_forget(struct sc_line *line);

/* Initial allocation and growth step, in bytes, for sc_line buffers. */
#define BUF_SIZE 200
26 
27 
sc_connect_unix(char * unix_socket_name,struct sc_client_conn * sc,char ** errormsg)28 int sc_connect_unix(char *unix_socket_name, struct sc_client_conn *sc, char **errormsg) {
29     struct sockaddr_un sun;
30     char sig[20];
31 
32     bzero(sc, sizeof(*sc));
33 
34     sun.sun_family = AF_LOCAL;
35 #ifdef __BSD_VISIBLE
36     sun.sun_len = strlen(unix_socket_name);
37 #endif
38     strcpy(sun.sun_path, unix_socket_name);
39 
40     if ((sc->sock_server = socket(AF_LOCAL, SOCK_STREAM, 0)) < 0) {
41         asprintf(errormsg, "Cannot create socket: %s", strerror(errno));
42         return SC_NET_ERROR;
43     }
44 
45     if (connect(sc->sock_server, (struct sockaddr*)&sun, SUN_LEN(&sun)) < 0) {
46         asprintf(errormsg, "Cannot connect to %s: %s", unix_socket_name, strerror(errno));
47         return SC_NET_ERROR;
48     }
49 
50     sc->server = fdopen(sc->sock_server, "r+");
51     fprintf(sc->server, "VER %s\r\n", SC_VER_SIG);
52     if (fscanf(sc->server, "+VER %15[^\r\n]%*2[\r\n]", sig) != 1) {
53         asprintf(errormsg, "Handshake error");
54         return SC_HANDSHAKE_ERROR;
55     }
56     if (strncmp(sig, SC_VER_SIG, 15) != 0) {
57         asprintf(errormsg, "Expecting version signature %s, got %s\n", SC_VER_SIG, sig);
58         return SC_SIG_ERROR;
59     }
60 
61     return SC_OK;
62 }
63 
64 
sc_get_ver(struct sc_client_conn * sc,char * ver)65 int sc_get_ver(struct sc_client_conn *sc, char *ver) {
66     fprintf(sc->server, "VER %s\r\n", SC_VER_SIG);
67     if (fscanf(sc->server, "+VER %15[^\r\n]%*2[\r\n]", ver) != 1)
68         return SC_HANDSHAKE_ERROR;
69     return SC_OK;
70 }
71 
72 
sc_exec_query(struct sc_client_conn * sc,char * sql,struct sc_result_set * rs,char ** errormsg)73 int sc_exec_query(struct sc_client_conn *sc, char *sql, struct sc_result_set *rs, char **errormsg) {
74     char resp[104], tmpmsg[104];;
75     int n_rows = -1, n_cols = -1;
76     fprintf(sc->server, "SQL %s\r\n", sql);
77     if (fscanf(sc->server, "%100[^\r\n]%*2[\r\n]", resp) != 1) {
78         if (errormsg != NULL)
79             asprintf(errormsg, "Error reading response line\n");
80         return SC_NET_ERROR;
81     }
82 
83     /*fprintf(stderr, "resp: %s\n", resp);*/
84     if (rs != NULL)
85         bzero(rs, sizeof *rs);
86 
87     if (sscanf(resp, "+REC %d, %d", &n_rows, &n_cols) == 2) {
88         int i, j;
89         struct sc_line line;
90 
91         /*printf("n_rows: %d, n_cols: %d\n", n_rows, n_cols);*/
92         if (rs != NULL) {
93             rs->n_rows = n_rows;
94             rs->n_cols = n_cols;
95             rs->names = malloc(n_cols * sizeof(char*));
96         }
97         bzero(&line, sizeof line);
98 
99         for (i = 0; i < n_cols; i++) { /* read header line (column names) */
100             sc_line_init(&line);
101             sc_line_read(&line, sc->server);
102             if (rs != NULL) {
103                 rs->names[i] = line.buf;
104                 sc_line_forget(&line);
105             } else
106                 sc_line_done(&line);
107         }
108         if (rs != NULL)
109             rs->rows = malloc(n_rows * sizeof(char**));
110         for (i = 0; i < n_rows; i++) { /* read rows */
111             if (rs != NULL)
112                 rs->rows[i] = malloc(n_cols * sizeof(char*));
113             for (j = 0; j < n_cols; j++) {
114                 sc_line_init(&line);
115                 sc_line_read(&line, sc->server);
116                 if (rs != NULL) {
117                     rs->rows[i][j] = line.buf;
118                     sc_line_forget(&line);
119                 } else
120                     sc_line_done(&line);
121             }
122         }
123         if (errormsg != NULL)
124             *errormsg = NULL;
125         return SC_OK;
126     } else if (sscanf(resp, "+NREC %100[^\r\n]", tmpmsg) == 1) {
127         if (errormsg != NULL)
128             *errormsg = NULL;
129         if (rs != NULL)
130             rs->n_rows = atoi(tmpmsg);
131         return SC_OK;
132     }
133     if (errormsg != NULL)
134         *errormsg = strdup(resp);
135     return SC_SQL_ERROR;
136 }
137 
138 
sc_free_result(struct sc_client_conn * sc,struct sc_result_set * rs)139 void sc_free_result(struct sc_client_conn *sc, struct sc_result_set *rs) {
140     int i, j;
141 
142     for (i = 0; i < rs->n_cols; i++)
143         free(rs->names[i]);
144 
145     for (i = 0; i < rs->n_rows; i++)
146         for (j = 0; j < rs->n_cols; j++)
147             free(rs->rows[i][j]);
148 }
149 
150 
sc_line_init(struct sc_line * line)151 static void sc_line_init(struct sc_line *line) {
152     if (line->buf == NULL) {
153         line->buf = malloc(BUF_SIZE);
154         line->buf_size = BUF_SIZE;
155     }
156     line->buf_len = 0;
157     bzero(line->buf, line->buf_size);
158 }
159 
160 
161 /* TODO: replace bzero() calls with on-the-fly \0-s */
sc_line_read(struct sc_line * line,FILE * f)162 static void sc_line_read(struct sc_line *line, FILE *f) {
163     int ch;
164     while (1) {
165         ch = getc_unlocked(f);
166         if (ch < 0)
167             return;
168         if (ch == '\n')
169             return;
170         if (ch == '\r') {
171             ch = getc_unlocked(f);
172             if (ch != '\n')
173                 ungetc(ch, f);
174             return;
175         }
176         if (line->buf_len >= line->buf_size) { /* resize buffer */
177             line->buf = realloc(line->buf, line->buf_size + BUF_SIZE);
178             if (line->buf == NULL) {
179                 line->buf_len = 0;
180                 return;
181             }
182             bzero(line->buf + line->buf_size, BUF_SIZE);
183             line->buf_size += BUF_SIZE;
184         }
185         line->buf[line->buf_len++] = (char)ch;
186     }
187 }
188 
189 
sc_line_done(struct sc_line * line)190 static void sc_line_done(struct sc_line *line) {
191     free(line->buf);
192     bzero(line, sizeof *line);
193 }
194 
195 
sc_line_forget(struct sc_line * line)196 static void sc_line_forget(struct sc_line *line) {
197     bzero(line, sizeof *line);
198 }
199