1 #include <dlfcn.h>
2 #include <assert.h>
3 #include <unistd.h>
4 #include <vtv_fail.h>
5 
6 extern "C" int printf(const char *, ...);
7 extern "C" int sprintf(char *, const char*, ...);
8 
// Running total updated by the virtual inc() implementations in this file.
// It has internal linkage, so it is private to this translation unit.
static int counter = 0;

// Declared here, defined in another translation unit. so_entry only checks
// that its value grows across a call that is expected to fail verification.
extern int failures;
11 
12 template <int i> struct base
13 {
whoamibase14   virtual char * whoami() {
15     static char sl[100];
16     sprintf(sl, "I am base %d", i);
17     return sl;
18   }
incbase19   virtual void inc() { counter += i; }
20 };
21 
22 template <int i> struct derived: base<i>
23 {
whoamiderived24   virtual char * whoami() {
25     static char sl[100];
26     sprintf(sl, "I am derived %d", i);
27     return sl;
28   }
incderived29   virtual void inc() { counter += (10*i); }
30 };
31 
32 // We don't use this class. It is just here so that the
33 // compiler does not devirtualize calls to derived::inc()
34 template <int i> struct derived2: derived<i>
35 {
incderived236   virtual void inc() { counter += (20*i); }
37 };
38 
39 static base<TPID> * bp = new base<TPID>();
40 static derived<TPID> * dp = new derived<TPID>();
41 static base<TPID> * dbp = new derived<TPID>();
42 
43 
// Given 2 pointers to C++ objects (non PODs), exchange their pointers to
// vtable.
//
// Only the first pointer-sized word of each object is swapped; every other
// byte is left untouched. This relies on the vtable pointer being stored at
// offset 0 of the object, which is an ABI assumption rather than something
// the language guarantees. Passing the same object twice leaves it unchanged.
static void exchange_vtptr(void * object1_ptr, void * object2_ptr)
{
  void ** first_slot = (void **)object1_ptr;
  void ** second_slot = (void **)object2_ptr;

  void * saved = *first_slot;
  *first_slot = *second_slot;
  *second_slot = saved;
}
54 
// BUILD_NAME pastes its two arguments together exactly as written.
// EXPAND adds one level of indirection so that a macro passed as X is
// replaced by its value before the paste happens; calling BUILD_NAME
// directly would paste the macro's name instead.
#define BUILD_NAME(NAME,ID) NAME##ID
#define EXPAND(NAME,X) BUILD_NAME(NAME,X)
EXPAND(so_entry_,TPID)57 extern "C" void EXPAND(so_entry_,TPID)(void)
58 {
59   int prev_counter;
60   int prev_failures;
61 
62   counter = 0;
63   bp->inc();
64   dp->inc();
65   dbp->inc();
66   assert(counter == (TPID + 10*TPID + 10*TPID));
67 
68   prev_counter = counter;
69   exchange_vtptr(bp, dp);
70   bp->inc(); // This one should succeed but it is calling the wrong member
71   if (counter != (prev_counter + 10*TPID))
72   {
73     printf("TPID=%d whoami=%s wrong counter value prev_counter=%d counter=%d\n", TPID, bp->whoami(), prev_counter, counter);
74     sleep(2);
75   }
76   assert(counter == (prev_counter + 10*TPID));
77   //  printf("Pass first attack!\n");
78 
79  // This one should fail verification!. So it should jump to __vtv_verify_fail above.
80   prev_failures = failures;
81   dp->inc();
82   // this code may be executed by multiple threads at the same time. So, just verify the number of failures has
83   // increased as opposed to check for increase by 1.
84   assert(failures > prev_failures);
85   assert(counter == (prev_counter + 10*TPID + TPID));
86   //  printf("TPDI=%d counter %d\n", TPID, counter);
87   //  printf("Pass second attack!\n");
88 
89   // restore the vtable pointers to the original state.
90   // This is very important. For some reason the dlclose is not "really" closing the library so when we reopen it we are
91   // getting the old memory state.
92   exchange_vtptr(bp, dp);
93 }
94