1 /**
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef PTHREADMOCKS_H_
20 #define PTHREADMOCKS_H_
21 
22 #include <pthread.h>
23 #include <string.h>
24 #include <errno.h>
25 
26 #include "src/zk_adaptor.h"
27 
28 #include "Util.h"
29 #include "MocksBase.h"
30 #include "LibCSymTable.h"
31 #include "ThreadingUtil.h"
32 
33 // an ABC for pthreads
34 class MockPthreadsBase: public Mock
35 {
36 public:
MockPthreadsBase()37     MockPthreadsBase(){mock_=this;}
~MockPthreadsBase()38     virtual ~MockPthreadsBase(){mock_=0;}
39 
40     virtual int pthread_create(pthread_t * t, const pthread_attr_t *a,
41             void *(*f)(void *), void *d) =0;
42     virtual int pthread_join(pthread_t t, void ** r) =0;
43     virtual int pthread_detach(pthread_t t) =0;
44     virtual int pthread_cond_broadcast(pthread_cond_t *c) =0;
45     virtual int pthread_cond_destroy(pthread_cond_t *c) =0;
46     virtual int pthread_cond_init(pthread_cond_t *c, const pthread_condattr_t *a) =0;
47     virtual int pthread_cond_signal(pthread_cond_t *c) =0;
48     virtual int pthread_cond_timedwait(pthread_cond_t *c,
49             pthread_mutex_t *m, const struct timespec *t) =0;
50     virtual int pthread_cond_wait(pthread_cond_t *c, pthread_mutex_t *m) =0;
51     virtual int pthread_mutex_destroy(pthread_mutex_t *m) =0;
52     virtual int pthread_mutex_init(pthread_mutex_t *m, const pthread_mutexattr_t *a) =0;
53     virtual int pthread_mutex_lock(pthread_mutex_t *m) =0;
54     virtual int pthread_mutex_trylock(pthread_mutex_t *m) =0;
55     virtual int pthread_mutex_unlock(pthread_mutex_t *m) =0;
56 
57     static MockPthreadsBase* mock_;
58 };
59 
60 // all pthread functions simply return an error code
61 // and increment their invocation counter. No actual threads are spawned.
62 class MockPthreadsNull: public MockPthreadsBase
63 {
64 public:
MockPthreadsNull()65     MockPthreadsNull():
66     pthread_createReturns(0),pthread_createCounter(0),
67     pthread_joinReturns(0),pthread_joinCounter(0),pthread_joinResultReturn(0),
68     pthread_detachReturns(0),pthread_detachCounter(0),
69     pthread_cond_broadcastReturns(0),pthread_cond_broadcastCounter(0),
70     pthread_cond_destroyReturns(0),pthread_cond_destroyCounter(0),
71     pthread_cond_initReturns(0),pthread_cond_initCounter(0),
72     pthread_cond_signalReturns(0),pthread_cond_signalCounter(0),
73     pthread_cond_timedwaitReturns(0),pthread_cond_timedwaitCounter(0),
74     pthread_cond_waitReturns(0),pthread_cond_waitCounter(0),
75     pthread_mutex_destroyReturns(0),pthread_mutex_destroyCounter(0),
76     pthread_mutex_initReturns(0),pthread_mutex_initCounter(0),
77     pthread_mutex_lockReturns(0),pthread_mutex_lockCounter(0),
78     pthread_mutex_trylockReturns(0),pthread_mutex_trylockCounter(0),
79     pthread_mutex_unlockReturns(0),pthread_mutex_unlockCounter(0)
80     {
81         memset(threads,0,sizeof(threads));
82     }
83 
84     short threads[512];
85 
86     int pthread_createReturns;
87     int pthread_createCounter;
pthread_create(pthread_t * t,const pthread_attr_t * a,void * (* f)(void *),void * d)88     virtual int pthread_create(pthread_t * t, const pthread_attr_t *a,
89             void *(*f)(void *), void *d){
90         char* p=(char*)&threads[pthread_createCounter++];
91         p[0]='i'; // mark as created
92         *t=(pthread_t)p;
93         return pthread_createReturns;
94     }
95     int pthread_joinReturns;
96     int pthread_joinCounter;
97     void* pthread_joinResultReturn;
pthread_join(pthread_t t,void ** r)98     virtual int pthread_join(pthread_t t, void ** r){
99         pthread_joinCounter++;
100         if(r!=0)
101             *r=pthread_joinResultReturn;
102         char* p=(char*)t;
103         p[0]='x';p[1]+=1;
104         return pthread_joinReturns;
105     }
106     int pthread_detachReturns;
107     int pthread_detachCounter;
pthread_detach(pthread_t t)108     virtual int pthread_detach(pthread_t t){
109         pthread_detachCounter++;
110         char* p=(char*)t;
111         p[0]='x';p[1]+=1;
112         return pthread_detachReturns;
113     }
114 
115     template<class T>
isInitialized(const T & t)116     static bool isInitialized(const T& t){
117         return ((char*)t)[0]=='i';
118     }
119     template<class T>
isDestroyed(const T & t)120     static bool isDestroyed(const T& t){
121         return ((char*)t)[0]=='x';
122     }
123     template<class T>
getDestroyCounter(const T & t)124     static int getDestroyCounter(const T& t){
125         return ((char*)t)[1];
126     }
127     template<class T>
getInvalidAccessCounter(const T & t)128     static int getInvalidAccessCounter(const T& t){
129         return ((char*)t)[2];
130     }
131     int pthread_cond_broadcastReturns;
132     int pthread_cond_broadcastCounter;
pthread_cond_broadcast(pthread_cond_t * c)133     virtual int pthread_cond_broadcast(pthread_cond_t *c){
134         pthread_cond_broadcastCounter++;
135         if(isDestroyed(c))((char*)c)[2]++;
136         return pthread_cond_broadcastReturns;
137     }
138     int pthread_cond_destroyReturns;
139     int pthread_cond_destroyCounter;
pthread_cond_destroy(pthread_cond_t * c)140     virtual int pthread_cond_destroy(pthread_cond_t *c){
141         pthread_cond_destroyCounter++;
142         char* p=(char*)c;
143         p[0]='x';p[1]+=1;
144         return pthread_cond_destroyReturns;
145     }
146     int pthread_cond_initReturns;
147     int pthread_cond_initCounter;
pthread_cond_init(pthread_cond_t * c,const pthread_condattr_t * a)148     virtual int pthread_cond_init(pthread_cond_t *c, const pthread_condattr_t *a){
149         pthread_cond_initCounter++;
150         char* p=(char*)c;
151         p[0]='i'; // mark as created
152         p[1]=0;   // destruction counter
153         p[2]=0;   // access after destruction counter
154         return pthread_cond_initReturns;
155     }
156     int pthread_cond_signalReturns;
157     int pthread_cond_signalCounter;
pthread_cond_signal(pthread_cond_t * c)158     virtual int pthread_cond_signal(pthread_cond_t *c){
159         pthread_cond_signalCounter++;
160         if(isDestroyed(c))((char*)c)[2]++;
161         return pthread_cond_signalReturns;
162     }
163     int pthread_cond_timedwaitReturns;
164     int pthread_cond_timedwaitCounter;
pthread_cond_timedwait(pthread_cond_t * c,pthread_mutex_t * m,const struct timespec * t)165     virtual int pthread_cond_timedwait(pthread_cond_t *c,
166             pthread_mutex_t *m, const struct timespec *t){
167         pthread_cond_timedwaitCounter++;
168         if(isDestroyed(c))((char*)c)[2]++;
169         return pthread_cond_timedwaitReturns;
170     }
171     int pthread_cond_waitReturns;
172     int pthread_cond_waitCounter;
pthread_cond_wait(pthread_cond_t * c,pthread_mutex_t * m)173     virtual int pthread_cond_wait(pthread_cond_t *c, pthread_mutex_t *m){
174         pthread_cond_waitCounter++;
175         if(isDestroyed(c))((char*)c)[2]++;
176         return pthread_cond_waitReturns;
177     }
178     int pthread_mutex_destroyReturns;
179     int pthread_mutex_destroyCounter;
pthread_mutex_destroy(pthread_mutex_t * m)180     virtual int pthread_mutex_destroy(pthread_mutex_t *m){
181         pthread_mutex_destroyCounter++;
182         char* p=(char*)m;
183         p[0]='x';p[1]+=1;
184         return pthread_mutex_destroyReturns;
185     }
186     int pthread_mutex_initReturns;
187     int pthread_mutex_initCounter;
pthread_mutex_init(pthread_mutex_t * m,const pthread_mutexattr_t * a)188     virtual int pthread_mutex_init(pthread_mutex_t *m, const pthread_mutexattr_t *a){
189         pthread_mutex_initCounter++;
190         char* p=(char*)m;
191         p[0]='i'; // mark as created
192         p[1]=0;   // destruction counter
193         p[2]=0;   // access after destruction counter
194         return pthread_mutex_initReturns;
195     }
196     int pthread_mutex_lockReturns;
197     int pthread_mutex_lockCounter;
pthread_mutex_lock(pthread_mutex_t * m)198     virtual int pthread_mutex_lock(pthread_mutex_t *m){
199         pthread_mutex_lockCounter++;
200         if(isDestroyed(m))((char*)m)[2]++;
201         return pthread_mutex_lockReturns;
202     }
203     int pthread_mutex_trylockReturns;
204     int pthread_mutex_trylockCounter;
pthread_mutex_trylock(pthread_mutex_t * m)205     virtual int pthread_mutex_trylock(pthread_mutex_t *m){
206         pthread_mutex_trylockCounter++;
207         if(isDestroyed(m))((char*)m)[2]++;
208         return pthread_mutex_trylockReturns;
209     }
210     int pthread_mutex_unlockReturns;
211     int pthread_mutex_unlockCounter;
pthread_mutex_unlock(pthread_mutex_t * m)212     virtual int pthread_mutex_unlock(pthread_mutex_t *m){
213         pthread_mutex_unlockCounter++;
214         if(isDestroyed(m))((char*)m)[2]++;
215         return pthread_mutex_unlockReturns;
216     }
217 };
218 
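// Simulates the way zookeeper's adaptor threads use api_prolog/api_epilog:
// the mocked pthread_create calls api_prolog() on the zhandle_t passed as the
// thread argument and decrements its threadsToWait, and pthread_join calls
// api_epilog() for the handle recorded at creation. No real threads are run.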
221 class MockPthreadZKNull: public MockPthreadsNull
222 {
223     typedef std::map<pthread_t,zhandle_t*> Map;
224     Map map_;
225 public:
pthread_create(pthread_t * t,const pthread_attr_t * a,void * (* f)(void *),void * d)226     virtual int pthread_create(pthread_t * t, const pthread_attr_t *a,
227             void *(*f)(void *), void *d){
228         int ret=MockPthreadsNull::pthread_create(t,a,f,d);
229         zhandle_t* zh=(zhandle_t*)d;
230         adaptor_threads* ad=(adaptor_threads*)zh->adaptor_priv;
231         api_prolog(zh);
232         ad->threadsToWait--;
233         putValue(map_,*t,zh);
234         return ret;
235     }
pthread_join(pthread_t t,void ** r)236     virtual int pthread_join(pthread_t t, void ** r){
237         zhandle_t* zh=0;
238         if(getValue(map_,t,zh))
239             api_epilog(zh,0);
240         return MockPthreadsNull::pthread_join(t,r);
241     }
242 };
243 
// Bookkeeping record kept by CheckedPthread for each thread, mutex and
// condition variable it has seen. The mutators return *this so updates can
// be chained onto a copy, e.g. Entry(e).incDestroyed().
struct ThreadInfo{
    enum ThreadState {RUNNING,TERMINATED};

    // Number of destroy/join/detach calls recorded for the object.
    int destructionCounter_;
    // Number of uses recorded after the object was destroyed.
    int invalidAccessCounter_;
    // Lifecycle state; only updated for thread entries.
    ThreadState state_;

    ThreadInfo()
        : destructionCounter_(0)
        , invalidAccessCounter_(0)
        , state_(RUNNING)
    {}

    ThreadInfo& incDestroyed(){ ++destructionCounter_; return *this; }
    ThreadInfo& incInvalidAccess(){ ++invalidAccessCounter_; return *this; }
    ThreadInfo& setTerminated(){ state_=TERMINATED; return *this; }
};
268 
269 class CheckedPthread: public MockPthreadsBase
270 {
271     // first => destruction counter
272     // second => invalid access counter
273     //typedef std::pair<int,int> Entry;
274     typedef ThreadInfo Entry;
275     typedef std::map<pthread_t,Entry> ThreadMap;
276     static ThreadMap tmap_;
getMap(const TypeOp<pthread_t>::BareT &)277     static ThreadMap& getMap(const TypeOp<pthread_t>::BareT&){return tmap_;}
278     typedef std::map<pthread_mutex_t*,Entry> MutexMap;
279     static MutexMap mmap_;
getMap(const TypeOp<pthread_mutex_t>::BareT &)280     static MutexMap& getMap(const TypeOp<pthread_mutex_t>::BareT&){return mmap_;}
281     typedef std::map<pthread_cond_t*,Entry> CVMap;
282     static CVMap cvmap_;
getMap(const TypeOp<pthread_cond_t>::BareT &)283     static CVMap& getMap(const TypeOp<pthread_cond_t>::BareT&){return cvmap_;}
284 
285     static Mutex mx;
286 
287     template<class T>
markDestroyed(T & t)288     static void markDestroyed(T& t){
289         typedef typename TypeOp<T>::BareT Type;
290         Entry e;
291         synchronized(mx);
292         if(getValue(getMap(Type()),t,e)){
293             putValue(getMap(Type()),t,Entry(e).incDestroyed());
294         }else{
295             putValue(getMap(Type()),t,Entry().incDestroyed());
296         }
297     }
298     template<class T>
markCreated(T & t)299     static void markCreated(T& t){
300         typedef typename TypeOp<T>::BareT Type;
301         Entry e;
302         synchronized(mx);
303         if(!getValue(getMap(Type()),t,e))
304             putValue(getMap(Type()),t,Entry());
305     }
306     template<class T>
checkAccessed(T & t)307     static void checkAccessed(T& t){
308         typedef typename TypeOp<T>::BareT Type;
309         Entry e;
310         synchronized(mx);
311         if(getValue(getMap(Type()),t,e) && e.destructionCounter_>0)
312             putValue(getMap(Type()),t,Entry(e).incInvalidAccess());
313     }
setTerminated(pthread_t t)314     static void setTerminated(pthread_t t){
315         Entry e;
316         synchronized(mx);
317         if(getValue(tmap_,t,e))
318             putValue(tmap_,t,Entry(e).setTerminated());
319     }
320 public:
321     bool verbose;
CheckedPthread()322     CheckedPthread():verbose(false){
323         tmap_.clear();
324         mmap_.clear();
325         cvmap_.clear();
326         mx.release();
327     }
328     template <class T>
isInitialized(const T & t)329     static bool isInitialized(const T& t){
330         typedef typename TypeOp<T>::BareT Type;
331         Entry e;
332         synchronized(mx);
333         return getValue(getMap(Type()),t,e) && e.destructionCounter_==0;
334     }
335     template <class T>
isDestroyed(const T & t)336     static bool isDestroyed(const T& t){
337         typedef typename TypeOp<T>::BareT Type;
338         Entry e;
339         synchronized(mx);
340         return getValue(getMap(Type()),t,e) && e.destructionCounter_>0;
341     }
isTerminated(pthread_t t)342     static bool isTerminated(pthread_t t){
343         Entry e;
344         synchronized(mx);
345         return getValue(tmap_,t,e) && e.state_==ThreadInfo::TERMINATED;
346     }
347     template <class T>
getDestroyCounter(const T & t)348     static int getDestroyCounter(const T& t){
349         typedef typename TypeOp<T>::BareT Type;
350         Entry e;
351         synchronized(mx);
352         return getValue(getMap(Type()),t,e)?e.destructionCounter_:-1;
353     }
354     template<class T>
getInvalidAccessCounter(const T & t)355     static int getInvalidAccessCounter(const T& t){
356         typedef typename TypeOp<T>::BareT Type;
357         Entry e;
358         synchronized(mx);
359         return getValue(getMap(Type()),t,e)?e.invalidAccessCounter_:-1;
360     }
361 
362     struct ThreadContext{
363         typedef void *(*ThreadFunc)(void *);
364 
ThreadContextThreadContext365         ThreadContext(ThreadFunc func,void* param):func_(func),param_(param){}
366         ThreadFunc func_;
367         void* param_;
368     };
threadFuncWrapper(void * v)369     static void* threadFuncWrapper(void* v){
370         ThreadContext* ctx=(ThreadContext*)v;
371         pthread_t t=pthread_self();
372         markCreated(t);
373         void* res=ctx->func_(ctx->param_);
374         setTerminated(pthread_self());
375         delete ctx;
376         return res;
377     }
pthread_create(pthread_t * t,const pthread_attr_t * a,void * (* f)(void *),void * d)378     virtual int pthread_create(pthread_t * t, const pthread_attr_t *a,
379             void *(*f)(void *), void *d)
380     {
381         int ret=LIBC_SYMBOLS.pthread_create(t,a,threadFuncWrapper,
382                 new ThreadContext(f,d));
383         if(verbose)
384             TEST_TRACE("thread created %p",*t);
385         return ret;
386     }
pthread_join(pthread_t t,void ** r)387     virtual int pthread_join(pthread_t t, void ** r){
388         if(verbose) TEST_TRACE("thread joined %p",t);
389         int ret=LIBC_SYMBOLS.pthread_join(t,r);
390         if(ret==0)
391             markDestroyed(t);
392         return ret;
393     }
pthread_detach(pthread_t t)394     virtual int pthread_detach(pthread_t t){
395         if(verbose) TEST_TRACE("thread detached %p",t);
396         int ret=LIBC_SYMBOLS.pthread_detach(t);
397         if(ret==0)
398             markDestroyed(t);
399         return ret;
400     }
pthread_cond_broadcast(pthread_cond_t * c)401     virtual int pthread_cond_broadcast(pthread_cond_t *c){
402         checkAccessed(c);
403         return LIBC_SYMBOLS.pthread_cond_broadcast(c);
404     }
pthread_cond_destroy(pthread_cond_t * c)405     virtual int pthread_cond_destroy(pthread_cond_t *c){
406         markDestroyed(c);
407         return LIBC_SYMBOLS.pthread_cond_destroy(c);
408     }
pthread_cond_init(pthread_cond_t * c,const pthread_condattr_t * a)409     virtual int pthread_cond_init(pthread_cond_t *c, const pthread_condattr_t *a){
410         markCreated(c);
411         return LIBC_SYMBOLS.pthread_cond_init(c,a);
412     }
pthread_cond_signal(pthread_cond_t * c)413     virtual int pthread_cond_signal(pthread_cond_t *c){
414         checkAccessed(c);
415         return LIBC_SYMBOLS.pthread_cond_signal(c);
416     }
pthread_cond_timedwait(pthread_cond_t * c,pthread_mutex_t * m,const struct timespec * t)417     virtual int pthread_cond_timedwait(pthread_cond_t *c,
418             pthread_mutex_t *m, const struct timespec *t){
419         checkAccessed(c);
420         return LIBC_SYMBOLS.pthread_cond_timedwait(c,m,t);
421     }
pthread_cond_wait(pthread_cond_t * c,pthread_mutex_t * m)422     virtual int pthread_cond_wait(pthread_cond_t *c, pthread_mutex_t *m){
423         checkAccessed(c);
424         return LIBC_SYMBOLS.pthread_cond_wait(c,m);
425     }
pthread_mutex_destroy(pthread_mutex_t * m)426     virtual int pthread_mutex_destroy(pthread_mutex_t *m){
427         markDestroyed(m);
428         return LIBC_SYMBOLS.pthread_mutex_destroy(m);
429     }
pthread_mutex_init(pthread_mutex_t * m,const pthread_mutexattr_t * a)430     virtual int pthread_mutex_init(pthread_mutex_t *m, const pthread_mutexattr_t *a){
431         markCreated(m);
432         return LIBC_SYMBOLS.pthread_mutex_init(m,a);
433     }
pthread_mutex_lock(pthread_mutex_t * m)434     virtual int pthread_mutex_lock(pthread_mutex_t *m){
435         checkAccessed(m);
436         return LIBC_SYMBOLS.pthread_mutex_lock(m);
437     }
pthread_mutex_trylock(pthread_mutex_t * m)438     virtual int pthread_mutex_trylock(pthread_mutex_t *m){
439         checkAccessed(m);
440         return LIBC_SYMBOLS.pthread_mutex_trylock(m);
441     }
pthread_mutex_unlock(pthread_mutex_t * m)442     virtual int pthread_mutex_unlock(pthread_mutex_t *m){
443         checkAccessed(m);
444         return LIBC_SYMBOLS.pthread_mutex_unlock(m);
445     }
446 };
447 
448 #endif /*PTHREADMOCKS_H_*/
449 
450