1 // Compile with /home/llozano/local2/proj/vtable/gcc-root/usr/local/bin/g++ -m32 -fvtable-verify=std -fpic -rdynamic -Wl,-R,/home/llozano/local2/proj/vtable/gcc-root/usr/local/lib32:./lib32 -I/home/llozano/local2/proj/vtable/vt2/gcc-4_6-mobile-vtable-security//libstdc++-v3/libsupc++ temp_deriv.cc -O0 -ldl -lpthread -Wl,--whole-archive,-lvtv_init,--no-whole-archive,-z,relro -DTPID=0 -g
2 // Look at assembly with: objdump -drl a.out
3 
4 #include <dlfcn.h>
5 #include <assert.h>
6 #include <stdlib.h>
7 
8 extern "C" int printf(const char *, ...);
9 
// Accumulator that every inc() override adds to.
static int counter = 0;

// TPID is normally supplied on the command line (-DTPID=...). Fall back to 0,
// the value used in the compile command at the top of this file, so the test
// still builds when the macro is not passed.
#ifndef TPID
#define TPID 0
#endif

// Base increment read by the inc() overrides (scaled by 1, 10 or 20).
int i = TPID;
13 struct base
14 {
incbase15   virtual void inc() { counter += i; }
16 };
17 
18 struct derived: public base
19 {
incderived20   virtual void inc() { counter += (10*i); }
21 };
22 
23 // We don't use this class. It is just here so that the
24 // compiler does not devirtualize calls to derived::inc()
25 struct derived2: public derived
26 {
incderived227   virtual void inc() { counter += (20*i); }
28 };
29 
30 static base * bp = new base();
31 static derived * dp = new derived();
32 static base * dbp = new derived();
33 
// Opaque stand-in for an object's vtable pointer.
typedef void *vtptr;

// Return the first pointer-sized word stored at `object_ptr`.
// NOTE(review): this treats that word as the vtable pointer, which assumes
// the vptr lives at offset 0 (true for these single-inheritance classes
// under the Itanium C++ ABI used by g++).
vtptr get_vtptr(void * object_ptr)
{
  return *(vtptr *)object_ptr;
}
41 
set_vptr(void * object_ptr,vtptr vtp)42 void set_vptr(void * object_ptr, vtptr vtp)
43 {
44   vtptr * object_vtptr_ptr = (vtptr *)object_ptr;
45   *object_vtptr_ptr = vtp;
46 }
47 
48 // Given 2 pointers to C++ objects (non PODs), exchange the pointers to vtable
exchange_vtptr(void * object1_ptr,void * object2_ptr)49 void exchange_vtptr(void * object1_ptr, void * object2_ptr)
50 {
51   vtptr object1_vtptr = get_vtptr(object1_ptr);
52   vtptr object2_vtptr = get_vtptr(object2_ptr);
53   set_vptr(object1_ptr, object2_vtptr);
54   set_vptr(object2_ptr, object1_vtptr);
55 }
56 
main()57 main()
58 {
59   int prev_counter;
60 
61   counter = 0;
62   bp->inc();
63   dp->inc();
64   dbp->inc();
65   assert(counter == (TPID + 10*TPID + 10*TPID));
66 
67   prev_counter = counter;
68   printf("before ex bp vptr=%x dp vptr=%x\n", get_vtptr(bp), get_vtptr(dp));
69   exchange_vtptr(bp, dp);
70   printf("after ex bp vptr=%x dp vptr=%x\n", get_vtptr(bp), get_vtptr(dp));
71   bp->inc(); // This one should not abort but it is calling the wrong member
72   assert(counter == (prev_counter + 10*TPID));
73   printf("Pass first attack! Expected!\n");
74   printf("TPDI=%d counter %d\n", TPID, counter);
75   dp->inc();
76   printf("Pass second attack! SHOULD NOT BE HERE!\n");
77   printf("TPDI=%d counter %d\n", TPID, counter);
78   exit(1);
79 }
80