1 /*
2 * Secure sscanf - sscanf with an additional size argument for string
3 * arguments. All format specifiers should work as in the standard
4 * scanf - except for those writing to a string buffer provided by the
5 * caller. These specifiers take an additional argument of type size_t
6 * that specifies the size of the buffer.
7 *
8 ** This program is free software; you can redistribute it and/or modify
9 ** it under the terms of the GNU General Public License as published by
10 ** the Free Software Foundation; either version 3 of the License, or
11 ** (at your option) any later version.
12 **
13 ** This program is distributed in the hope that it will be useful,
14 ** but WITHOUT ANY WARRANTY; without even the implied warranty of
15 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 ** GNU General Public License for more details.
17 **
18 ** You should have received a copy of the GNU General Public License
19 ** along with GNU gv; see the file COPYING. If not, write to
20 ** the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
21 ** Boston, MA 02111-1307, USA.
22 *
23 * Copyright (C) 2002, Olaf Kirch <okir@suse.de>
24 */
25
26 #define _GNU_SOURCE
27 #include <ac_config.h>
28
29 #include <sys/param.h>
30 #include <stdio.h>
31 #include <stdlib.h>
32 #include <stdarg.h>
33 #include <string.h>
34 #include <ctype.h>
35 #include <inttypes.h>
36 #include "secscanf.h"
37
/*
 * Length of the string at s, looking at no more than len bytes.
 *
 * Returns the number of bytes that precede the first NUL, or len when
 * no NUL occurs within the first len bytes.
 */
static size_t GNU_strnlen(const char *s, size_t len)
{
    const char *cursor = s;

    while (len != 0 && *cursor != '\0') {
        cursor++;
        len--;
    }

    return (size_t) (cursor - s);
}
44
/*
 * Duplicate at most n bytes of s into freshly malloc()ed storage.
 *
 * The copy is always NUL-terminated.  Returns NULL when malloc() fails;
 * otherwise the caller owns the returned buffer and must free() it.
 */
static char* GNU_strndup (char const *s, size_t n)
{
    size_t copy_len = GNU_strnlen (s, n);
    char *dup = malloc (copy_len + 1);

    if (dup != NULL) {
        memcpy (dup, s, copy_len);
        dup[copy_len] = '\0';
    }

    return dup;
}
56
/*
 * Conversion class of a % directive.  sec_sscanf() starts every directive
 * at CONV_ANY and narrows it through set_conv_type() as flags and the
 * conversion character are parsed; the final class selects how the
 * scanned value is stored.
 */
enum {
    CONV_ANY,       /* nothing decided yet; any class may replace it */
    CONV_STR,       /* %s, %[ and %c, and the 'a' (allocate) flag */
    CONV_NUM,       /* set by 'l', 'L', 'q'; may narrow to INTEGER, FLOAT or POINTER */
    CONV_INTEGER,   /* %d %i %o %u %x %X %n, and the 'h' modifier */
    CONV_FLOAT,     /* %f %g %e %E */
    CONV_POINTER,   /* %p, and the result of an allocating string conversion */
};
65
/*
 * Size modifier of a % directive; picks the destination type when
 * sec_sscanf() stores an integer or floating-point result.
 */
enum {
    SIZE_ANY,       /* no modifier: int / float */
    SIZE_SHORT,     /* 'h': short (integers only) */
    SIZE_LONG,      /* 'l': long / double */
    SIZE_QUAD,      /* 'll', 'L' or 'q': long long / long double */
};
72
/*
 * Scratch storage for the value produced by one conversion, before it is
 * written to the caller's argument.
 */
union scan_value {
    const char *        v_string;   /* start of the matched text inside the input (%s, %[, %c) */
    long long           v_signed;   /* result of %d / %i, and the offset recorded by %n */
    unsigned long long  v_integer;  /* result of %o / %u / %x / %X; also the member read when any integer is stored */
    long double         v_double;   /* result of %f / %g / %e / %E */
    void *              v_pointer;  /* result of %p, or the malloc()ed copy made for the 'a' flag */
};
80
81
82 static int process_number(union scan_value *vp, const char **sp, char fmt);
83 static int process_char_class(const char **, const char **, int);
84
85 static inline int
set_conv_type(int * type,int new_type)86 set_conv_type(int *type, int new_type)
87 {
88 switch (*type) {
89 case CONV_ANY:
90 break;
91 case CONV_NUM:
92 if (new_type == CONV_INTEGER
93 || new_type == CONV_FLOAT
94 || new_type == CONV_POINTER)
95 break;
96 /* fallthru */
97 default:
98 if (*type != new_type)
99 return 0;
100 break;
101 }
102
103 *type = new_type;
104 return 1;
105 }
106
107 int
sec_sscanf(const char * s,const char * fmt,...)108 sec_sscanf(const char *s, const char *fmt, ...)
109 {
110 const char *begin = s;
111 int num_fields = 0, fmt_empty = 1;
112 va_list ap;
113
114 va_start(ap, fmt);
115 while (*fmt) {
116 union scan_value value;
117 const char *pre_space_skip,
118 *value_begin;
119 int assign = 1, allocate = 0,
120 conv_type = CONV_ANY,
121 conv_size = SIZE_ANY,
122 field_width = -1,
123 nul_terminated = 1;
124 char c;
125
126 c = *fmt++;
127 if (isspace(c)) {
128 while (isspace(*s))
129 s++;
130 continue;
131 }
132
133 fmt_empty = 0;
134 if (c != '%') {
135 if (c != *s)
136 goto stop;
137 s++;
138 continue;
139 }
140
141 /* Each % directive implicitly skips white space
142 * except for the %c case */
143 pre_space_skip = s;
144 while (isspace(*s))
145 s++;
146
147 while (1) {
148 int type = CONV_ANY, size = SIZE_ANY;
149
150 switch (*fmt) {
151 case '*':
152 assign = 0;
153 break;
154 case 'a':
155 type = CONV_STR;
156 allocate = 1;
157 break;
158 case 'h':
159 type = CONV_INTEGER;
160 size = SIZE_SHORT;
161 break;
162 case 'l':
163 type = CONV_NUM;
164 size = SIZE_LONG;
165 break;
166 case 'L':
167 case 'q':
168 type = CONV_NUM;
169 size = SIZE_QUAD;
170 break;
171 case '0': case '1': case '2': case '3': case '4':
172 case '5': case '6': case '7': case '8': case '9':
173 field_width = strtol(fmt, (char **) &fmt, 10);
174 fmt--;
175 break;
176 default:
177 goto flags_done;
178 }
179
180 if (!set_conv_type(&conv_type, type))
181 goto stop;
182
183 if (size != SIZE_ANY) {
184 if (size == SIZE_LONG && conv_size == SIZE_LONG)
185 conv_size = SIZE_QUAD;
186 else
187 conv_size = size;
188 }
189
190 fmt++;
191 }
192
193 flags_done:
194 value_begin = s;
195
196 switch (*fmt++) {
197 case '%':
198 if (*s == '\0')
199 goto eof;
200 if (*s != '%')
201 goto stop;
202 continue;
203 case '[':
204 value.v_string = s;
205 if (!set_conv_type(&conv_type, CONV_STR)
206 || !process_char_class(&fmt, &s, field_width))
207 goto stop;
208 break;
209 case 's':
210 value.v_string = s;
211 if (!set_conv_type(&conv_type, CONV_STR))
212 goto stop;
213 while (*s && !isspace(*s) && field_width-- != 0)
214 s++;
215 break;
216 case 'c':
217 if (!set_conv_type(&conv_type, CONV_STR))
218 goto stop;
219 value.v_string = s = value_begin = pre_space_skip;
220
221 if (field_width < 0)
222 s++;
223 else while (*s && field_width--)
224 s++;
225 nul_terminated = 0;
226 break;
227 case 'd':
228 case 'i':
229 case 'o':
230 case 'u':
231 case 'x':
232 case 'X':
233 if (!set_conv_type(&conv_type, CONV_INTEGER)
234 || !process_number(&value, &s, fmt[-1]))
235 goto stop;
236 break;
237 case 'p':
238 if (!set_conv_type(&conv_type, CONV_POINTER)
239 || !process_number(&value, &s, fmt[-1]))
240 goto stop;
241 break;
242 case 'f':
243 case 'g':
244 case 'e':
245 case 'E':
246 if (!set_conv_type(&conv_type, CONV_FLOAT)
247 || !process_number(&value, &s, fmt[-1]))
248 goto stop;
249 break;
250 case 'n':
251 if (!set_conv_type(&conv_type, CONV_INTEGER))
252 goto stop;
253 value.v_signed = (s - begin);
254 break;
255 default:
256 goto stop;
257 }
258
259 /* We've consumed what we need to consume. Now copy */
260 if (!assign)
261 continue;
262
263 /* Make sure we've consumed at least *something* */
264 if (s == value_begin)
265 goto eof;
266
267 /* Deal with a conversion flag */
268 if (conv_type == CONV_STR && allocate) {
269 value.v_pointer = GNU_strndup(value.v_string, s - value.v_string);
270 conv_type = CONV_POINTER;
271 allocate = 0;
272 }
273
274 switch (conv_type) {
275 case CONV_STR:
276 {
277 const char *string = value.v_string;
278 char *buf;
279 size_t size;
280
281 if (string == NULL)
282 goto stop;
283 buf = va_arg(ap, char *);
284 size = va_arg(ap, size_t) - nul_terminated;
285 if (size > (size_t)(s - string))
286 size = s - string;
287 strncpy(buf, string, size);
288 if (nul_terminated)
289 buf[size] = '\0';
290 }
291 break;
292
293 case CONV_POINTER:
294 {
295 void **ptr;
296
297 ptr = va_arg(ap, void **);
298 *ptr = value.v_pointer;
299 }
300 break;
301 case CONV_INTEGER:
302 {
303 void *ptr;
304
305 ptr = va_arg(ap, void *);
306 switch (conv_size) {
307 case SIZE_SHORT:
308 *(short *) ptr = value.v_integer;
309 break;
310 case SIZE_ANY:
311 *(int *) ptr = value.v_integer;
312 break;
313 case SIZE_LONG:
314 *(long *) ptr = value.v_integer;
315 break;
316 case SIZE_QUAD:
317 *(long long *) ptr = value.v_integer;
318 break;
319 default:
320 goto stop;
321 }
322 }
323 break;
324 case CONV_FLOAT:
325 {
326 void *ptr;
327
328 ptr = va_arg(ap, void *);
329 switch (conv_size) {
330 case SIZE_ANY:
331 *(float *) ptr = value.v_double;
332 break;
333 case SIZE_LONG:
334 *(double *) ptr = value.v_double;
335 break;
336 case SIZE_QUAD:
337 *(long double *) ptr = value.v_double;
338 break;
339 default:
340 goto stop;
341 }
342 }
343 break;
344 default:
345 goto stop;
346 }
347
348 num_fields++;
349 }
350
351 stop: return num_fields;
352
353 eof: if (num_fields)
354 return num_fields;
355 return EOF;
356 }
357
358 static int
process_number(union scan_value * vp,const char ** sp,char fmt)359 process_number(union scan_value *vp, const char **sp, char fmt)
360 {
361 const char *s = *sp;
362
363 switch (fmt) {
364 case 'd':
365 vp->v_signed = strtoll(s, (char **) sp, 10);
366 break;
367 case 'i':
368 vp->v_signed = strtoll(s, (char **) sp, 0);
369 break;
370 case 'o':
371 vp->v_integer = strtoull(s, (char **) sp, 8);
372 break;
373 case 'u':
374 vp->v_integer = strtoull(s, (char **) sp, 10);
375 break;
376 case 'x':
377 case 'X':
378 vp->v_integer = strtoull(s, (char **) sp, 16);
379 break;
380 case 'p':
381 vp->v_pointer = (void *)(intptr_t) strtoull(s, (char **) sp, 0);
382 break;
383 case 'f':
384 case 'g':
385 case 'e':
386 case 'E':
387 vp->v_double = strtold(s, (char **) sp);
388 break;
389 default:
390 return 0;
391 }
392
393 return 1;
394 }
395
/*
 * Parse a %[...] scan set and consume the matching input.
 *
 * fmt    in: points just past the '[' in the format string;
 *        out: points just past the closing ']'.
 * sp     in: start of the input to match;
 *        out: first input character that is not in the set (or where
 *        the width ran out).
 * width  maximum number of characters to consume; a negative value
 *        means no limit.
 *
 * A leading '^' negates the set, a ']' directly after the opening
 * bracket (or after '^') is a member of the set, and "x-y" adds the
 * range from x up to y.
 *
 * Returns 1 when the set was well-formed (even if no input matched),
 * 0 when the format ended before the closing ']' -- in that case *fmt
 * and *sp are left unchanged.
 */
static int
process_char_class(const char **fmt, const char **sp, int width)
{
    const unsigned char *s;
    unsigned char c, prev_char = 0;
    /* One membership flag per possible unsigned char value (0..255). */
    unsigned char table[256];
    int val = 1;

    s = (const unsigned char *) *fmt;
    if (*s == '^') {
        /* Negated set: start with everything, then knock members out. */
        memset(table, 1, sizeof(table));
        val = 0;
        s++;
    } else {
        memset(table, 0, sizeof(table));
        val = 1;
    }
    /* First character in set is closing bracket means add it to the
     * set of characters */
    if ((c = *s) == ']') {
        table[c] = val;
        prev_char = c;
        s++;
    }

    /* Any other closing bracket finishes off the set */
    while ((c = *s++) != ']') {
        if (prev_char) {
            if (c == '-' && *s != '\0' && *s != ']') {
                /* "prev-c" range: take the upper bound now. */
                c = *s++;
            } else {
                /* Not a range; forget the previous character. */
                prev_char = '\0';
            }
        }

        /* Format ended before the closing bracket. */
        if (c == '\0')
            return 0;

        if (prev_char) {
            while (prev_char < c)
                table[prev_char++] = val;
        }
        table[c] = val;
        prev_char = c;
    }
    *fmt = (const char *) s;

    /* Consume input while it belongs to the set and width allows. */
    s = (const unsigned char *) *sp;
    while ((c = *s) != '\0' && table[c] && width--)
        s++;

    *sp = (const char *) s;
    return 1;
}
462
463 #ifdef TEST
464 static int verify(const char *fmt, const char *s);
465 static int verify_s(const char *fmt, const char *s);
466
/* NOTE(review): none of these constants is referenced in the visible
 * part of this file; presumably left over from an earlier test harness. */
enum { S, I, L, F, D, P };
468
/*
 * Self-test driver (built only with -DTEST): runs every format/input
 * pair through verify() or verify_s(), each of which compares
 * sec_sscanf() against the C library's sscanf() and prints a FAILED
 * report on any mismatch.  Always exits with status 0.
 */
int
main(int argc, char **argv)
{
    struct scan_case {
        const char *fmt;
        const char *input;
    };
    /* Numeric, pointer and assignment-suppressed conversions. */
    static const struct scan_case value_cases[] = {
        { "%d %d", "12 13" },
        { "%d-%d", "12 13" },
        { "%d-%d", "12-13" },
        { "%u %u", "12 13" },
        { "%o %o", "12 13" },
        { "%x %x", "12 13" },
        { "%X %X", "12 13" },
        { "%hd %hd", "12 13" },
        { "%ld %ld", "12 13" },
        { "%lld %lld", "12 13" },
        { "%Ld %Ld", "12 13" },
        { "%qd %qd", "12 13" },
        { "%f %f", "12 13" },
        { "%lf %lf", "12 13" },
        { "%Lf %Lf", "12 13" },
        { "%qf %qf", "12 13" },
        { "%*d-%d", "12-13" },
        { "%*s %d", "12 13" },
        { "%p", "0xdeadbeef" },
        { "%*[a-e] %x", "deadbeef feeb" },
        { "%*[a-f] %x", "deadbeef feeb" },
        { "%*[^g-z] %x", "deadbeef feeb" },
        { "%*[^ g-z] %x", "deadbeef feeb" },
        { "%*[^ g-z-] %x", "dead-beef feeb" },
        { "%*5s %d", "toast123 456" },
        { "", "lalla" },
        { "%u", "" },
    };
    /* Conversions that write into character buffers. */
    static const struct scan_case string_cases[] = {
        { "%s", "aa bb" },
        { "%s %s", "aa bb" },
        { "%[a-z] %s", "aa bb" },
        { "%c %s", "aa bb" },
        { "%2c %s", " aa bb" },
        { "%20c %s", " aa bb" },
    };
    size_t idx;

    (void) argc;
    (void) argv;

    for (idx = 0; idx < sizeof(value_cases) / sizeof(value_cases[0]); idx++)
        verify(value_cases[idx].fmt, value_cases[idx].input);

    for (idx = 0; idx < sizeof(string_cases) / sizeof(string_cases[0]); idx++)
        verify_s(string_cases[idx].fmt, string_cases[idx].input);

    return 0;
}
509
510 static int
verify(const char * fmt,const char * s)511 verify(const char *fmt, const char *s)
512 {
513 union scan_value vals[5], vals_ref[5], *v;
514 int n, m;
515
516 memset(vals, 0xe5, sizeof(vals));
517 memset(vals_ref, 0xe5, sizeof(vals_ref));
518
519 v = vals;
520 n = sec_sscanf(s, fmt, v + 0, v + 1, v + 2, v + 3, v + 4);
521
522 v = vals_ref;
523 m = sscanf(s, fmt, v + 0, v + 1, v + 2, v + 3, v + 4);
524
525 if (m != n) {
526 printf("FAILED: fmt=\"%s\"\n"
527 " str=\"%s\"\n"
528 " sec_scanf returns %d, sscanf returns %d\n",
529 fmt, s, n, m);
530 return 0;
531 }
532
533 if (memcmp(vals, vals_ref, sizeof(vals))) {
534 printf("FAILED: fmt=\"%s\"\n"
535 " str=\"%s\"\n"
536 " data differs!\n",
537 fmt, s);
538 printf("0x%Lx != 0x%Lx\n", vals[0].v_integer, vals_ref[0].v_integer);
539 return 0;
540 }
541
542 return 1;
543 }
544
/*
 * Run one string test case through both sec_sscanf() and the C
 * library's sscanf() and compare the outcome.
 *
 * fmt  format string containing up to three string conversions.
 * s    input string to scan.
 *
 * sec_sscanf() is given each 256-byte buffer together with its size;
 * sscanf() gets the buffers only.  Both buffer sets are pre-filled with
 * 0xe5 so that untouched bytes compare equal.  Prints a FAILED report
 * and returns 0 when the return values or the buffer contents differ;
 * returns 1 when they agree.
 */
static int
verify_s(const char *fmt, const char *s)
{
    char buf[3][256], buf_ref[3][256];
    int n, m;

    memset(buf, 0xe5, sizeof(buf));
    memset(buf_ref, 0xe5, sizeof(buf_ref));

    n = sec_sscanf(s, fmt,
                   buf[0], sizeof(buf[0]),
                   buf[1], sizeof(buf[1]),
                   buf[2], sizeof(buf[2]));

    m = sscanf(s, fmt, buf_ref[0], buf_ref[1], buf_ref[2]);

    if (m != n) {
        printf("FAILED: fmt=\"%s\"\n"
               " str=\"%s\"\n"
               " sec_scanf returns %d, sscanf returns %d\n",
               fmt, s, n, m);
        return 0;
    }

    if (memcmp(buf, buf_ref, sizeof(buf))) {
        printf("FAILED: fmt=\"%s\"\n"
               " str=\"%s\"\n"
               " data differs!\n",
               fmt, s);
        /* %c results are not NUL-terminated, so bound the output to
         * the first buffer instead of trusting a terminator. */
        printf("%.*s != %.*s\n",
               (int) sizeof(buf[0]), buf[0],
               (int) sizeof(buf_ref[0]), buf_ref[0]);
        return 0;
    }

    return 1;
}
577 #endif
578