1 /*
2  * Copyright (C) 2009 Dan Carpenter.
3  *
4  * This program is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU General Public License
6  * as published by the Free Software Foundation; either version 2
7  * of the License, or (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, see http://www.gnu.org/copyleft/gpl.txt
16  */
17 
/*
 * Given an expression, work out what its C type is.  This file also holds
 * assorted helpers for querying, comparing and printing types.
 */
22 
23 #include "smatch.h"
24 #include "smatch_slist.h"
25 
26 struct symbol *get_real_base_type(struct symbol *sym)
27 {
28 	struct symbol *ret;
29 
30 	if (!sym)
31 		return NULL;
32 	ret = get_base_type(sym);
33 	if (!ret)
34 		return NULL;
35 	if (ret->type == SYM_RESTRICT || ret->type == SYM_NODE)
36 		return get_real_base_type(ret);
37 	return ret;
38 }
39 
40 int type_bytes(struct symbol *type)
41 {
42 	int bits;
43 
44 	if (type && type->type == SYM_ARRAY)
45 		return array_bytes(type);
46 
47 	bits = type_bits(type);
48 	if (bits < 0)
49 		return 0;
50 	return bits_to_bytes(bits);
51 }
52 
53 int array_bytes(struct symbol *type)
54 {
55 	if (!type || type->type != SYM_ARRAY)
56 		return 0;
57 	return bits_to_bytes(type->bit_size);
58 }
59 
60 static struct symbol *get_binop_type(struct expression *expr)
61 {
62 	struct symbol *left, *right;
63 
64 	left = get_type(expr->left);
65 	if (!left)
66 		return NULL;
67 
68 	if (expr->op == SPECIAL_LEFTSHIFT ||
69 	    expr->op == SPECIAL_RIGHTSHIFT) {
70 		if (type_positive_bits(left) < 31)
71 			return &int_ctype;
72 		return left;
73 	}
74 	right = get_type(expr->right);
75 	if (!right)
76 		return NULL;
77 
78 	if (left->type == SYM_PTR || left->type == SYM_ARRAY)
79 		return left;
80 	if (right->type == SYM_PTR || right->type == SYM_ARRAY)
81 		return right;
82 
83 	if (type_positive_bits(left) < 31 && type_positive_bits(right) < 31)
84 		return &int_ctype;
85 
86 	if (type_positive_bits(left) > type_positive_bits(right))
87 		return left;
88 	return right;
89 }
90 
91 static struct symbol *get_type_symbol(struct expression *expr)
92 {
93 	if (!expr || expr->type != EXPR_SYMBOL || !expr->symbol)
94 		return NULL;
95 
96 	return get_real_base_type(expr->symbol);
97 }
98 
99 static struct symbol *get_member_symbol(struct symbol_list *symbol_list, struct ident *member)
100 {
101 	struct symbol *tmp, *sub;
102 
103 	FOR_EACH_PTR(symbol_list, tmp) {
104 		if (!tmp->ident) {
105 			sub = get_real_base_type(tmp);
106 			sub = get_member_symbol(sub->symbol_list, member);
107 			if (sub)
108 				return sub;
109 			continue;
110 		}
111 		if (tmp->ident == member)
112 			return tmp;
113 	} END_FOR_EACH_PTR(tmp);
114 
115 	return NULL;
116 }
117 
118 static struct symbol *get_symbol_from_deref(struct expression *expr)
119 {
120 	struct ident *member;
121 	struct symbol *sym;
122 
123 	if (!expr || expr->type != EXPR_DEREF)
124 		return NULL;
125 
126 	member = expr->member;
127 	sym = get_type(expr->deref);
128 	if (!sym) {
129 		// sm_msg("could not find struct type");
130 		return NULL;
131 	}
132 	if (sym->type == SYM_PTR)
133 		sym = get_real_base_type(sym);
134 	sym = get_member_symbol(sym->symbol_list, member);
135 	if (!sym)
136 		return NULL;
137 	return get_real_base_type(sym);
138 }
139 
140 static struct symbol *get_return_type(struct expression *expr)
141 {
142 	struct symbol *tmp;
143 
144 	tmp = get_type(expr->fn);
145 	if (!tmp)
146 		return NULL;
147 	/* this is to handle __builtin_constant_p() */
148 	if (tmp->type != SYM_FN)
149 		tmp = get_base_type(tmp);
150 	return get_real_base_type(tmp);
151 }
152 
153 static struct symbol *get_expr_stmt_type(struct statement *stmt)
154 {
155 	if (stmt->type != STMT_COMPOUND)
156 		return NULL;
157 	stmt = last_ptr_list((struct ptr_list *)stmt->stmts);
158 	if (stmt->type == STMT_LABEL)
159 		stmt = stmt->label_statement;
160 	if (stmt->type != STMT_EXPRESSION)
161 		return NULL;
162 	return get_type(stmt->expression);
163 }
164 
165 static struct symbol *get_select_type(struct expression *expr)
166 {
167 	struct symbol *one, *two;
168 
169 	one = get_type(expr->cond_true);
170 	two = get_type(expr->cond_false);
171 	if (!one || !two)
172 		return NULL;
173 	/*
174 	 * This is a hack.  If the types are not equiv then we
175 	 * really don't know the type.  But I think guessing is
176 	 *  probably Ok here.
177 	 */
178 	if (type_positive_bits(one) > type_positive_bits(two))
179 		return one;
180 	return two;
181 }
182 
183 struct symbol *get_pointer_type(struct expression *expr)
184 {
185 	struct symbol *sym;
186 
187 	sym = get_type(expr);
188 	if (!sym)
189 		return NULL;
190 	if (sym->type == SYM_NODE) {
191 		sym = get_real_base_type(sym);
192 		if (!sym)
193 			return NULL;
194 	}
195 	if (sym->type != SYM_PTR && sym->type != SYM_ARRAY)
196 		return NULL;
197 	return get_real_base_type(sym);
198 }
199 
200 static struct symbol *fake_pointer_sym(struct expression *expr)
201 {
202 	struct symbol *sym;
203 	struct symbol *base;
204 
205 	sym = alloc_symbol(expr->pos, SYM_PTR);
206 	expr = expr->unop;
207 	base = get_type(expr);
208 	if (!base)
209 		return NULL;
210 	sym->ctype.base_type = base;
211 	return sym;
212 }
213 
214 static struct symbol *get_type_helper(struct expression *expr)
215 {
216 	struct symbol *ret;
217 
218 	expr = strip_parens(expr);
219 	if (!expr)
220 		return NULL;
221 
222 	if (expr->ctype)
223 		return expr->ctype;
224 
225 	switch (expr->type) {
226 	case EXPR_STRING:
227 		ret = &string_ctype;
228 		break;
229 	case EXPR_SYMBOL:
230 		ret = get_type_symbol(expr);
231 		break;
232 	case EXPR_DEREF:
233 		ret = get_symbol_from_deref(expr);
234 		break;
235 	case EXPR_PREOP:
236 	case EXPR_POSTOP:
237 		if (expr->op == '&')
238 			ret = fake_pointer_sym(expr);
239 		else if (expr->op == '*')
240 			ret = get_pointer_type(expr->unop);
241 		else
242 			ret = get_type(expr->unop);
243 		break;
244 	case EXPR_ASSIGNMENT:
245 		ret = get_type(expr->left);
246 		break;
247 	case EXPR_CAST:
248 	case EXPR_FORCE_CAST:
249 	case EXPR_IMPLIED_CAST:
250 		ret = get_real_base_type(expr->cast_type);
251 		break;
252 	case EXPR_COMPARE:
253 	case EXPR_BINOP:
254 		ret = get_binop_type(expr);
255 		break;
256 	case EXPR_CALL:
257 		ret = get_return_type(expr);
258 		break;
259 	case EXPR_STATEMENT:
260 		ret = get_expr_stmt_type(expr->statement);
261 		break;
262 	case EXPR_CONDITIONAL:
263 	case EXPR_SELECT:
264 		ret = get_select_type(expr);
265 		break;
266 	case EXPR_SIZEOF:
267 		ret = &ulong_ctype;
268 		break;
269 	case EXPR_LOGICAL:
270 		ret = &int_ctype;
271 		break;
272 	default:
273 		return NULL;
274 	}
275 
276 	if (ret && ret->type == SYM_TYPEOF)
277 		ret = get_type(ret->initializer);
278 
279 	expr->ctype = ret;
280 	return ret;
281 }
282 
283 static struct symbol *get_final_type_helper(struct expression *expr)
284 {
285 	/*
286 	 * I'm not totally positive I understand types...
287 	 *
288 	 * So, when you're doing pointer math, and you do a subtraction, then
289 	 * the sval_binop() and whatever need to know the type of the pointer
290 	 * so they can figure out the alignment.  But the result is going to be
291 	 * and ssize_t.  So get_operation_type() gives you the pointer type
292 	 * and get_type() gives you ssize_t.
293 	 *
294 	 * Most of the time the operation type and the final type are the same
295 	 * but this just handles the few places where they are different.
296 	 *
297 	 */
298 
299 	expr = strip_parens(expr);
300 	if (!expr)
301 		return NULL;
302 
303 	switch (expr->type) {
304 	case EXPR_COMPARE:
305 		return &int_ctype;
306 	case EXPR_BINOP: {
307 		struct symbol *left, *right;
308 
309 		if (expr->op != '-')
310 			return NULL;
311 
312 		left = get_type(expr->left);
313 		right = get_type(expr->right);
314 		if (type_is_ptr(left) || type_is_ptr(right))
315 			return ssize_t_ctype;
316 		}
317 	}
318 
319 	return NULL;
320 }
321 
/* Public entry point: the type of @expr, or NULL if it is unknown. */
struct symbol *get_type(struct expression *expr)
{
	struct symbol *type;

	type = get_type_helper(expr);
	return type;
}
326 
/*
 * Like get_type(), except that where the result type of an expression
 * differs from the type its operation is performed in (as decided by
 * get_final_type_helper()), the result type is returned.
 */
struct symbol *get_final_type(struct expression *expr)
{
	struct symbol *result = get_final_type_helper(expr);

	return result ? result : get_type_helper(expr);
}
336 
337 struct symbol *get_promoted_type(struct symbol *left, struct symbol *right)
338 {
339 	struct symbol *ret = &int_ctype;
340 
341 	if (type_positive_bits(left) > type_positive_bits(ret))
342 		ret = left;
343 	if (type_positive_bits(right) > type_positive_bits(ret))
344 		ret = right;
345 
346 	if (type_is_ptr(left))
347 		ret = left;
348 	if (type_is_ptr(right))
349 		ret = right;
350 
351 	return ret;
352 }
353 
354 int type_signed(struct symbol *base_type)
355 {
356 	if (!base_type)
357 		return 0;
358 	if (base_type->ctype.modifiers & MOD_SIGNED)
359 		return 1;
360 	return 0;
361 }
362 
/* 1 if the type of @expr is known and type_unsigned() says so, else 0. */
int expr_unsigned(struct expression *expr)
{
	struct symbol *type = get_type(expr);

	return type && type_unsigned(type);
}
374 
/* 1 if the type of @expr is known and type_signed() says so, else 0. */
int expr_signed(struct expression *expr)
{
	struct symbol *type = get_type(expr);

	return type && type_signed(type);
}
386 
387 int returns_unsigned(struct symbol *sym)
388 {
389 	if (!sym)
390 		return 0;
391 	sym = get_base_type(sym);
392 	if (!sym || sym->type != SYM_FN)
393 		return 0;
394 	sym = get_base_type(sym);
395 	return type_unsigned(sym);
396 }
397 
398 int is_pointer(struct expression *expr)
399 {
400 	struct symbol *sym;
401 
402 	sym = get_type(expr);
403 	if (!sym)
404 		return 0;
405 	if (sym == &string_ctype)
406 		return 0;
407 	if (sym->type == SYM_PTR)
408 		return 1;
409 	return 0;
410 }
411 
412 int returns_pointer(struct symbol *sym)
413 {
414 	if (!sym)
415 		return 0;
416 	sym = get_base_type(sym);
417 	if (!sym || sym->type != SYM_FN)
418 		return 0;
419 	sym = get_base_type(sym);
420 	if (sym->type == SYM_PTR)
421 		return 1;
422 	return 0;
423 }
424 
425 sval_t sval_type_max(struct symbol *base_type)
426 {
427 	sval_t ret;
428 
429 	if (!base_type || !type_bits(base_type))
430 		base_type = &llong_ctype;
431 	ret.type = base_type;
432 
433 	ret.value = (~0ULL) >> (64 - type_positive_bits(base_type));
434 	return ret;
435 }
436 
437 sval_t sval_type_min(struct symbol *base_type)
438 {
439 	sval_t ret;
440 
441 	if (!base_type || !type_bits(base_type))
442 		base_type = &llong_ctype;
443 	ret.type = base_type;
444 
445 	if (type_unsigned(base_type)) {
446 		ret.value = 0;
447 		return ret;
448 	}
449 
450 	ret.value = (~0ULL) << type_positive_bits(base_type);
451 
452 	return ret;
453 }
454 
/* Bit width (type_bits()) of the type of @expr, or 0 if the type is unknown. */
int nr_bits(struct expression *expr)
{
	struct symbol *type = get_type(expr);

	return type ? type_bits(type) : 0;
}
464 
465 int is_void_pointer(struct expression *expr)
466 {
467 	struct symbol *type;
468 
469 	type = get_type(expr);
470 	if (!type || type->type != SYM_PTR)
471 		return 0;
472 	type = get_real_base_type(type);
473 	if (type == &void_ctype)
474 		return 1;
475 	return 0;
476 }
477 
478 int is_char_pointer(struct expression *expr)
479 {
480 	struct symbol *type;
481 
482 	type = get_type(expr);
483 	if (!type || type->type != SYM_PTR)
484 		return 0;
485 	type = get_real_base_type(type);
486 	if (type == &char_ctype)
487 		return 1;
488 	return 0;
489 }
490 
491 int is_string(struct expression *expr)
492 {
493 	expr = strip_expr(expr);
494 	if (!expr || expr->type != EXPR_STRING)
495 		return 0;
496 	if (expr->string)
497 		return 1;
498 	return 0;
499 }
500 
501 int is_static(struct expression *expr)
502 {
503 	char *name;
504 	struct symbol *sym;
505 	int ret = 0;
506 
507 	name = expr_to_str_sym(expr, &sym);
508 	if (!name || !sym)
509 		goto free;
510 
511 	if (sym->ctype.modifiers & MOD_STATIC)
512 		ret = 1;
513 free:
514 	free_string(name);
515 	return ret;
516 }
517 
518 int is_local_variable(struct expression *expr)
519 {
520 	struct symbol *sym;
521 	char *name;
522 
523 	name = expr_to_var_sym(expr, &sym);
524 	free_string(name);
525 	if (!sym || !sym->scope || !sym->scope->token || !cur_func_sym)
526 		return 0;
527 	if (cmp_pos(sym->scope->token->pos, cur_func_sym->pos) < 0)
528 		return 0;
529 	if (is_static(expr))
530 		return 0;
531 	return 1;
532 }
533 
534 int types_equiv(struct symbol *one, struct symbol *two)
535 {
536 	if (!one && !two)
537 		return 1;
538 	if (!one || !two)
539 		return 0;
540 	if (one->type != two->type)
541 		return 0;
542 	if (one->type == SYM_PTR)
543 		return types_equiv(get_real_base_type(one), get_real_base_type(two));
544 	if (type_positive_bits(one) != type_positive_bits(two))
545 		return 0;
546 	return 1;
547 }
548 
549 int fn_static(void)
550 {
551 	return !!(cur_func_sym->ctype.modifiers & MOD_STATIC);
552 }
553 
554 const char *global_static(void)
555 {
556 	if (cur_func_sym->ctype.modifiers & MOD_STATIC)
557 		return "static";
558 	else
559 		return "global";
560 }
561 
562 struct symbol *cur_func_return_type(void)
563 {
564 	struct symbol *sym;
565 
566 	sym = get_real_base_type(cur_func_sym);
567 	if (!sym || sym->type != SYM_FN)
568 		return NULL;
569 	sym = get_real_base_type(sym);
570 	return sym;
571 }
572 
573 struct symbol *get_arg_type(struct expression *fn, int arg)
574 {
575 	struct symbol *fn_type;
576 	struct symbol *tmp;
577 	struct symbol *arg_type;
578 	int i;
579 
580 	fn_type = get_type(fn);
581 	if (!fn_type)
582 		return NULL;
583 	if (fn_type->type == SYM_PTR)
584 		fn_type = get_real_base_type(fn_type);
585 	if (fn_type->type != SYM_FN)
586 		return NULL;
587 
588 	i = 0;
589 	FOR_EACH_PTR(fn_type->arguments, tmp) {
590 		arg_type = get_real_base_type(tmp);
591 		if (i == arg) {
592 			return arg_type;
593 		}
594 		i++;
595 	} END_FOR_EACH_PTR(tmp);
596 
597 	return NULL;
598 }
599 
600 static struct symbol *get_member_from_string(struct symbol_list *symbol_list, const char *name)
601 {
602 	struct symbol *tmp, *sub;
603 	int chunk_len;
604 
605 	if (strncmp(name, ".", 1) == 0)
606 		name += 1;
607 	if (strncmp(name, "->", 2) == 0)
608 		name += 2;
609 
610 	FOR_EACH_PTR(symbol_list, tmp) {
611 		if (!tmp->ident) {
612 			sub = get_real_base_type(tmp);
613 			sub = get_member_from_string(sub->symbol_list, name);
614 			if (sub)
615 				return sub;
616 			continue;
617 		}
618 
619 		if (strcmp(tmp->ident->name, name) == 0)
620 			return tmp;
621 
622 		chunk_len = strlen(tmp->ident->name);
623 		if (strncmp(tmp->ident->name, name, chunk_len) == 0 &&
624 		    (name[chunk_len] == '.' || name[chunk_len] == '-')) {
625 			sub = get_real_base_type(tmp);
626 			return get_member_from_string(sub->symbol_list, name + chunk_len);
627 		}
628 
629 	} END_FOR_EACH_PTR(tmp);
630 
631 	return NULL;
632 }
633 
634 struct symbol *get_member_type_from_key(struct expression *expr, const char *key)
635 {
636 	struct symbol *sym;
637 
638 	if (strcmp(key, "$") == 0)
639 		return get_type(expr);
640 
641 	if (strcmp(key, "*$") == 0) {
642 		sym = get_type(expr);
643 		if (!sym || sym->type != SYM_PTR)
644 			return NULL;
645 		return get_real_base_type(sym);
646 	}
647 
648 	sym = get_type(expr);
649 	if (!sym)
650 		return NULL;
651 	if (sym->type == SYM_PTR)
652 		sym = get_real_base_type(sym);
653 
654 	key = key + 1;
655 	sym = get_member_from_string(sym->symbol_list, key);
656 	if (!sym)
657 		return NULL;
658 	return get_real_base_type(sym);
659 }
660 
661 struct symbol *get_arg_type_from_key(struct expression *fn, int param, struct expression *arg, const char *key)
662 {
663 	struct symbol *type;
664 
665 	if (!key)
666 		return NULL;
667 	if (strcmp(key, "$") == 0)
668 		return get_arg_type(fn, param);
669 	if (strcmp(key, "*$") == 0) {
670 		type = get_arg_type(fn, param);
671 		if (!type || type->type != SYM_PTR)
672 			return NULL;
673 		return get_real_base_type(type);
674 	}
675 	return get_member_type_from_key(arg, key);
676 }
677 
678 int is_struct(struct expression *expr)
679 {
680 	struct symbol *type;
681 
682 	type = get_type(expr);
683 	if (type && type->type == SYM_STRUCT)
684 		return 1;
685 	return 0;
686 }
687 
688 static struct {
689 	struct symbol *sym;
690 	const char *name;
691 } base_types[] = {
692 	{&bool_ctype, "bool"},
693 	{&void_ctype, "void"},
694 	{&type_ctype, "type"},
695 	{&char_ctype, "char"},
696 	{&schar_ctype, "schar"},
697 	{&uchar_ctype, "uchar"},
698 	{&short_ctype, "short"},
699 	{&sshort_ctype, "sshort"},
700 	{&ushort_ctype, "ushort"},
701 	{&int_ctype, "int"},
702 	{&sint_ctype, "sint"},
703 	{&uint_ctype, "uint"},
704 	{&long_ctype, "long"},
705 	{&slong_ctype, "slong"},
706 	{&ulong_ctype, "ulong"},
707 	{&llong_ctype, "llong"},
708 	{&sllong_ctype, "sllong"},
709 	{&ullong_ctype, "ullong"},
710 	{&lllong_ctype, "lllong"},
711 	{&slllong_ctype, "slllong"},
712 	{&ulllong_ctype, "ulllong"},
713 	{&float_ctype, "float"},
714 	{&double_ctype, "double"},
715 	{&ldouble_ctype, "ldouble"},
716 	{&string_ctype, "string"},
717 	{&ptr_ctype, "ptr"},
718 	{&lazy_ptr_ctype, "lazy_ptr"},
719 	{&incomplete_ctype, "incomplete"},
720 	{&label_ctype, "label"},
721 	{&bad_ctype, "bad"},
722 	{&null_ctype, "null"},
723 };
724 
725 static const char *base_type_str(struct symbol *sym)
726 {
727 	int i;
728 
729 	for (i = 0; i < ARRAY_SIZE(base_types); i++) {
730 		if (sym == base_types[i].sym)
731 			return base_types[i].name;
732 	}
733 	return "<unknown>";
734 }
735 
736 static int type_str_helper(char *buf, int size, struct symbol *type)
737 {
738 	int n;
739 
740 	if (!type)
741 		return snprintf(buf, size, "<unknown>");
742 
743 	if (type->type == SYM_BASETYPE) {
744 		return snprintf(buf, size, base_type_str(type));
745 	} else if (type->type == SYM_PTR) {
746 		type = get_real_base_type(type);
747 		n = type_str_helper(buf, size, type);
748 		if (n > size)
749 			return n;
750 		return n + snprintf(buf + n, size - n, "*");
751 	} else if (type->type == SYM_ARRAY) {
752 		type = get_real_base_type(type);
753 		n = type_str_helper(buf, size, type);
754 		if (n > size)
755 			return n;
756 		return n + snprintf(buf + n, size - n, "[]");
757 	} else if (type->type == SYM_STRUCT) {
758 		return snprintf(buf, size, "struct %s", type->ident ? type->ident->name : "");
759 	} else if (type->type == SYM_UNION) {
760 		if (type->ident)
761 			return snprintf(buf, size, "union %s", type->ident->name);
762 		else
763 			return snprintf(buf, size, "anonymous union");
764 	} else if (type->type == SYM_FN) {
765 		struct symbol *arg, *return_type, *arg_type;
766 		int i;
767 
768 		return_type = get_real_base_type(type);
769 		n = type_str_helper(buf, size, return_type);
770 		if (n > size)
771 			return n;
772 		n += snprintf(buf + n, size - n, "(*)(");
773 		if (n > size)
774 			return n;
775 
776 		i = 0;
777 		FOR_EACH_PTR(type->arguments, arg) {
778 			if (i++)
779 				n += snprintf(buf + n, size - n, ", ");
780 			if (n > size)
781 				return n;
782 			arg_type = get_real_base_type(arg);
783 			n += type_str_helper(buf + n, size - n, arg_type);
784 			if (n > size)
785 				return n;
786 		} END_FOR_EACH_PTR(arg);
787 
788 		return n + snprintf(buf + n, size - n, ")");
789 	} else if (type->type == SYM_NODE) {
790 		n = snprintf(buf, size, "node {");
791 		if (n > size)
792 			return n;
793 		type = get_real_base_type(type);
794 		n += type_str_helper(buf + n, size - n, type);
795 		if (n > size)
796 			return n;
797 		return n + snprintf(buf + n, size - n, "}");
798 	} else {
799 		return snprintf(buf, size, "<type %d>", type->type);
800 	}
801 }
802 
/*
 * Render @type as a human-readable string.  The result lives in a static
 * 256-byte buffer that the next call overwrites, so callers must copy it
 * if they need to keep it; this also makes the function non-reentrant.
 */
char *type_to_str(struct symbol *type)
{
	static char buf[256];

	*buf = '\0';
	type_str_helper(buf, (int)sizeof(buf), type);
	return buf;
}
811