// Compile with /home/llozano/local2/proj/vtable/gcc-root/usr/local/bin/g++ -m32 -fvtable-verify=std -fpic -rdynamic -Wl,-R,/home/llozano/local2/proj/vtable/gcc-root/usr/local/lib32:./lib32 -I/home/llozano/local2/proj/vtable/vt2/gcc-4_6-mobile-vtable-security//libstdc++-v3/libsupc++ temp_deriv.cc -O0 -ldl -lpthread -Wl,--whole-archive,-lvtv_init,--no-whole-archive,-z,relro -DTPID=0 -g
// Look at assembly with: objdump -drl a.out
3 
4 #include <dlfcn.h>
5 #include <assert.h>
6 
7 extern "C" int printf(const char *, ...);
8 
9 static int counter = 0;
10 
11 int i = TPID;
12 struct base
13 {
incbase14   virtual void inc() { counter += i; }
15 };
16 
17 struct derived: public base
18 {
incderived19   virtual void inc() { counter += (10*i); }
20 };
21 
22 // We don't use this class. It is just here so that the
23 // compiler does not devirtualize calls to derived::inc()
24 struct derived2: public derived
25 {
incderived226   virtual void inc() { counter += (20*i); }
27 };
28 
29 static base * bp = new base();
30 static derived * dp = new derived();
31 static base * dbp = new derived();
32 
// Given 2 pointers to C++ objects (non PODs), exchange the pointers to vtable.
//
// The function swaps the first pointer-sized word of each object in place.
// NOTE(review): this assumes the vtable pointer is stored at offset 0 of the
// object, which holds for the single-inheritance classes in this test under
// the ABI it targets -- confirm before reusing elsewhere.
//
// object1_ptr, object2_ptr: addresses of the two objects; both must be
// non-null and writable.
void exchange_vtptr(void * object1_ptr, void * object2_ptr)
{
  typedef void * vtptr;

  // View the start of each object as a slot that holds one pointer.
  vtptr * slot1 = (vtptr *)object1_ptr;
  vtptr * slot2 = (vtptr *)object2_ptr;

  // Plain three-step swap of the two slots.
  vtptr saved = *slot1;
  *slot1 = *slot2;
  *slot2 = saved;
}
44 
main()45 main()
46 {
47   int prev_counter;
48 
49   exchange_vtptr(bp, dp);
50   exchange_vtptr(bp, dp);
51   exchange_vtptr(bp, dbp);
52   exchange_vtptr(bp, dbp);
53 
54   counter = 0;
55   bp->inc();
56   dp->inc();
57   dbp->inc();
58   assert(counter == (TPID + 10*TPID + 10*TPID));
59 
60   prev_counter = counter;
61   exchange_vtptr(bp, dp);
62   bp->inc(); // This one should succeed but it is calling the wrong member
63   assert(counter == (prev_counter + 10*TPID));
64   printf("Pass first attack!\n");
65   dp->inc();
66   printf("TPDI=%d counter %d\n", TPID, counter);
67   printf("Pass second attack!\n");
68 
69 }
70