1 /**
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include <arpa/inet.h>  // for htonl
20 #include <memory>
21 
22 #include <zookeeper.h>
23 #include <proto.h>
24 
25 #ifdef THREADED
26 #include "PthreadMocks.h"
27 #endif
28 #include "ZKMocks.h"
29 
30 using namespace std;
31 
32 TestClientId testClientId;
33 const char* TestClientId::PASSWD="1234567890123456";
34 
parse(const std::string & buf)35 HandshakeRequest* HandshakeRequest::parse(const std::string& buf) {
36     unique_ptr<HandshakeRequest> req(new HandshakeRequest);
37 
38     memcpy(&req->protocolVersion,buf.data(), sizeof(req->protocolVersion));
39     req->protocolVersion = htonl(req->protocolVersion);
40 
41     int offset=sizeof(req->protocolVersion);
42 
43     memcpy(&req->lastZxidSeen,buf.data()+offset,sizeof(req->lastZxidSeen));
44     req->lastZxidSeen = zoo_htonll(req->lastZxidSeen);
45     offset+=sizeof(req->lastZxidSeen);
46 
47     memcpy(&req->timeOut,buf.data()+offset,sizeof(req->timeOut));
48     req->timeOut = htonl(req->timeOut);
49     offset+=sizeof(req->timeOut);
50 
51     memcpy(&req->sessionId,buf.data()+offset,sizeof(req->sessionId));
52     req->sessionId = zoo_htonll(req->sessionId);
53     offset+=sizeof(req->sessionId);
54 
55     memcpy(&req->passwd_len,buf.data()+offset,sizeof(req->passwd_len));
56     req->passwd_len = htonl(req->passwd_len);
57     offset+=sizeof(req->passwd_len);
58 
59     memcpy(req->passwd,buf.data()+offset,sizeof(req->passwd));
60     offset+=sizeof(req->passwd);
61 
62     memcpy(&req->readOnly,buf.data()+offset,sizeof(req->readOnly));
63 
64     if(testClientId.client_id==req->sessionId &&
65             !memcmp(testClientId.passwd,req->passwd,sizeof(req->passwd)))
66         return req.release();
67     // the request didn't match -- may not be a handshake request after all
68 
69     return 0;
70 }
71 
72 // *****************************************************************************
73 // watcher action implementation
activeWatcher(zhandle_t * zh,int type,int state,const char * path,void * ctx)74 void activeWatcher(zhandle_t *zh,
75                    int type, int state, const char *path,void* ctx) {
76 
77     if (zh == 0 || ctx == 0)
78       return;
79 
80     WatcherAction* action = (WatcherAction *)ctx;
81 
82     if (type == ZOO_SESSION_EVENT) {
83         if (state == ZOO_EXPIRED_SESSION_STATE)
84             action->onSessionExpired(zh);
85         else if(state == ZOO_CONNECTING_STATE)
86             action->onConnectionLost(zh);
87         else if(state == ZOO_CONNECTED_STATE)
88             action->onConnectionEstablished(zh);
89     } else if (type == ZOO_CHANGED_EVENT)
90         action->onNodeValueChanged(zh,path);
91     else if (type == ZOO_DELETED_EVENT)
92         action->onNodeDeleted(zh,path);
93     else if (type == ZOO_CHILD_EVENT)
94         action->onChildChanged(zh,path);
95 
96     // TODO: implement for the rest of the event types
97 
98     action->setWatcherTriggered();
99 }
100 
isWatcherTriggered() const101 SyncedBoolCondition WatcherAction::isWatcherTriggered() const {
102     return SyncedBoolCondition(triggered_,mx_);
103 }
104 
105 // a set of async completion signatures
106 
asyncCompletion(int rc,ACL_vector * acl,Stat * stat,const void * data)107 void asyncCompletion(int rc, ACL_vector *acl,Stat *stat, const void *data){
108     assert("Completion data is NULL"&&data);
109     static_cast<AsyncCompletion*>((void*)data)->aclCompl(rc,acl,stat);
110 }
111 
asyncCompletion(int rc,const char * value,int len,const Stat * stat,const void * data)112 void asyncCompletion(int rc, const char *value, int len, const Stat *stat,
113         const void *data) {
114     assert("Completion data is NULL"&&data);
115     static_cast<AsyncCompletion*>((void*)data)->dataCompl(rc,value,len,stat);
116 }
117 
asyncCompletion(int rc,const Stat * stat,const void * data)118 void asyncCompletion(int rc, const Stat *stat, const void *data) {
119     assert("Completion data is NULL"&&data);
120     static_cast<AsyncCompletion*>((void*)data)->statCompl(rc,stat);
121 }
122 
asyncCompletion(int rc,const char * value,const void * data)123 void asyncCompletion(int rc, const char *value, const void *data) {
124     assert("Completion data is NULL"&&data);
125     static_cast<AsyncCompletion*>((void*)data)->stringCompl(rc,value);
126 }
127 
asyncCompletion(int rc,const String_vector * strings,const void * data)128 void asyncCompletion(int rc,const String_vector *strings, const void *data) {
129     assert("Completion data is NULL"&&data);
130     static_cast<AsyncCompletion*>((void*)data)->stringsCompl(rc,strings);
131 }
132 
asyncCompletion(int rc,const void * data)133 void asyncCompletion(int rc, const void *data) {
134     assert("Completion data is NULL"&&data);
135     static_cast<AsyncCompletion*>((void*)data)->voidCompl(rc);
136 }
137 
138 // a predicate implementation
operator ()() const139 bool IOThreadStopped::operator()() const{
140 #ifdef THREADED
141     adaptor_threads* adaptor=(adaptor_threads*)zh_->adaptor_priv;
142     return CheckedPthread::isTerminated(adaptor->io);
143 #else
144     assert("IOThreadStopped predicate is only for use with THREADED client" &&
145            false);
146     return false;
147 #endif
148 }
149 
150 //******************************************************************************
151 //
152 DECLARE_WRAPPER(int,flush_send_queue,(zhandle_t*zh, int timeout))
153 {
154     if(!Mock_flush_send_queue::mock_)
155         return CALL_REAL(flush_send_queue,(zh,timeout));
156     return Mock_flush_send_queue::mock_->call(zh,timeout);
157 }
158 
159 Mock_flush_send_queue* Mock_flush_send_queue::mock_=0;
160 
161 //******************************************************************************
162 //
163 DECLARE_WRAPPER(int32_t,get_xid,())
164 {
165     if(!Mock_get_xid::mock_)
166         return CALL_REAL(get_xid,());
167     return Mock_get_xid::mock_->call();
168 }
169 
170 Mock_get_xid* Mock_get_xid::mock_=0;
171 
172 //******************************************************************************
173 // activateWatcher mock
174 
175 DECLARE_WRAPPER(void,activateWatcher,(zhandle_t *zh, watcher_registration_t* reg, int rc))
176 {
177     if(!Mock_activateWatcher::mock_){
178         CALL_REAL(activateWatcher,(zh, reg,rc));
179     }else{
180         Mock_activateWatcher::mock_->call(zh, reg,rc);
181     }
182 }
183 Mock_activateWatcher* Mock_activateWatcher::mock_=0;
184 
185 class ActivateWatcherWrapper: public Mock_activateWatcher{
186 public:
ActivateWatcherWrapper()187     ActivateWatcherWrapper():ctx_(0),activated_(false){}
188 
call(zhandle_t * zh,watcher_registration_t * reg,int rc)189     virtual void call(zhandle_t *zh, watcher_registration_t* reg, int rc){
190         CALL_REAL(activateWatcher,(zh, reg,rc));
191         synchronized(mx_);
192         if(reg->context==ctx_){
193             activated_=true;
194             ctx_=0;
195         }
196     }
197 
setContext(void * ctx)198     void setContext(void* ctx){
199         synchronized(mx_);
200         ctx_=ctx;
201         activated_=false;
202     }
203 
isActivated() const204     SyncedBoolCondition isActivated() const{
205         return SyncedBoolCondition(activated_,mx_);
206     }
207     mutable Mutex mx_;
208     void* ctx_;
209     bool activated_;
210 };
211 
WatcherActivationTracker()212 WatcherActivationTracker::WatcherActivationTracker():
213     wrapper_(new ActivateWatcherWrapper)
214 {
215 }
216 
~WatcherActivationTracker()217 WatcherActivationTracker::~WatcherActivationTracker(){
218     delete wrapper_;
219 }
220 
track(void * ctx)221 void WatcherActivationTracker::track(void* ctx){
222     wrapper_->setContext(ctx);
223 }
224 
isWatcherActivated() const225 SyncedBoolCondition WatcherActivationTracker::isWatcherActivated() const{
226     return wrapper_->isActivated();
227 }
228 
229 //******************************************************************************
230 //
231 DECLARE_WRAPPER(void,deliverWatchers,(zhandle_t* zh,int type,int state, const char* path, watcher_object_list_t **list))
232 {
233     if(!Mock_deliverWatchers::mock_){
234         CALL_REAL(deliverWatchers,(zh,type,state,path, list));
235     }else{
236         Mock_deliverWatchers::mock_->call(zh,type,state,path, list);
237     }
238 }
239 
240 Mock_deliverWatchers* Mock_deliverWatchers::mock_=0;
241 
242 struct RefCounterValue{
RefCounterValueRefCounterValue243     RefCounterValue(zhandle_t* const& zh,int32_t expectedCounter,Mutex& mx):
244         zh_(zh),expectedCounter_(expectedCounter),mx_(mx){}
operator ()RefCounterValue245     bool operator()() const{
246         {
247             synchronized(mx_);
248             if(zh_==0)
249                 return false;
250         }
251         return inc_ref_counter(zh_,0)==expectedCounter_;
252     }
253     zhandle_t* const& zh_;
254     int32_t expectedCounter_;
255     Mutex& mx_;
256 };
257 
258 
259 class DeliverWatchersWrapper: public Mock_deliverWatchers{
260 public:
DeliverWatchersWrapper(int type,int state,bool terminate)261     DeliverWatchersWrapper(int type,int state,bool terminate):
262         type_(type),state_(state),
263         allDelivered_(false),terminate_(terminate),zh_(0),deliveryCounter_(0){}
call(zhandle_t * zh,int type,int state,const char * path,watcher_object_list ** list)264     virtual void call(zhandle_t* zh, int type, int state,
265                       const char* path, watcher_object_list **list) {
266         {
267             synchronized(mx_);
268             zh_=zh;
269             allDelivered_=false;
270         }
271         CALL_REAL(deliverWatchers,(zh,type,state,path, list));
272         if(type_==type && state_==state){
273             if(terminate_){
274                 // prevent zhandle_t from being prematurely distroyed;
275                 // this will also ensure that zookeeper_close() cleanups the
276                 //  thread resources by calling finish_adaptor()
277                 inc_ref_counter(zh,1);
278                 terminateZookeeperThreads(zh);
279             }
280             synchronized(mx_);
281             allDelivered_=true;
282             deliveryCounter_++;
283         }
284     }
isDelivered() const285     SyncedBoolCondition isDelivered() const{
286         if(terminate_){
287             int i=ensureCondition(RefCounterValue(zh_,1,mx_),1000);
288             assert(i<1000);
289         }
290         return SyncedBoolCondition(allDelivered_,mx_);
291     }
resetDeliveryCounter()292     void resetDeliveryCounter(){
293         synchronized(mx_);
294         deliveryCounter_=0;
295     }
deliveryCounterEquals(int expected) const296     SyncedIntegerEqual deliveryCounterEquals(int expected) const{
297         if(terminate_){
298             int i=ensureCondition(RefCounterValue(zh_,1,mx_),1000);
299             assert(i<1000);
300         }
301         return SyncedIntegerEqual(deliveryCounter_,expected,mx_);
302     }
303     int type_;
304     int state_;
305     mutable Mutex mx_;
306     bool allDelivered_;
307     bool terminate_;
308     zhandle_t* zh_;
309     int deliveryCounter_;
310 };
311 
WatcherDeliveryTracker(int type,int state,bool terminateCompletionThread)312 WatcherDeliveryTracker::WatcherDeliveryTracker(
313         int type,int state,bool terminateCompletionThread):
314     deliveryWrapper_(new DeliverWatchersWrapper(
315             type,state,terminateCompletionThread)){
316 }
317 
~WatcherDeliveryTracker()318 WatcherDeliveryTracker::~WatcherDeliveryTracker(){
319     delete deliveryWrapper_;
320 }
321 
isWatcherProcessingCompleted() const322 SyncedBoolCondition WatcherDeliveryTracker::isWatcherProcessingCompleted() const {
323     return deliveryWrapper_->isDelivered();
324 }
325 
resetDeliveryCounter()326 void WatcherDeliveryTracker::resetDeliveryCounter(){
327     deliveryWrapper_->resetDeliveryCounter();
328 }
329 
deliveryCounterEquals(int expected) const330 SyncedIntegerEqual WatcherDeliveryTracker::deliveryCounterEquals(int expected) const {
331     return deliveryWrapper_->deliveryCounterEquals(expected);
332 }
333 
334 //******************************************************************************
335 //
toString() const336 string HandshakeResponse::toString() const {
337     string buf;
338     int32_t tmp=htonl(protocolVersion);
339     buf.append((char*)&tmp,sizeof(tmp));
340     tmp=htonl(timeOut);
341     buf.append((char*)&tmp,sizeof(tmp));
342     int64_t tmp64=zoo_htonll(sessionId);
343     buf.append((char*)&tmp64,sizeof(sessionId));
344     tmp=htonl(passwd_len);
345     buf.append((char*)&tmp,sizeof(tmp));
346     buf.append(passwd,sizeof(passwd));
347     if (!omitReadOnly) {
348         buf.append(&readOnly,sizeof(readOnly));
349     }
350     // finally set the buffer length
351     tmp=htonl(buf.size());
352     buf.insert(0,(char*)&tmp, sizeof(tmp));
353     return buf;
354 }
355 
toString() const356 string ZooGetResponse::toString() const{
357     oarchive* oa=create_buffer_oarchive();
358 
359     ReplyHeader h = {xid_,1,ZOK};
360     serialize_ReplyHeader(oa, "hdr", &h);
361 
362     GetDataResponse resp;
363     char buf[1024];
364     assert("GetDataResponse is too long"&&data_.size()<=sizeof(buf));
365     resp.data.len=data_.size();
366     resp.data.buff=buf;
367     data_.copy(resp.data.buff, data_.size());
368     resp.stat=stat_;
369     serialize_GetDataResponse(oa, "reply", &resp);
370     int32_t len=htonl(get_buffer_len(oa));
371     string res((char*)&len,sizeof(len));
372     res.append(get_buffer(oa),get_buffer_len(oa));
373 
374     close_buffer_oarchive(&oa,1);
375     return res;
376 }
377 
toString() const378 string ZooStatResponse::toString() const{
379     oarchive* oa=create_buffer_oarchive();
380 
381     ReplyHeader h = {xid_,1,rc_};
382     serialize_ReplyHeader(oa, "hdr", &h);
383 
384     SetDataResponse resp;
385     resp.stat=stat_;
386     serialize_SetDataResponse(oa, "reply", &resp);
387     int32_t len=htonl(get_buffer_len(oa));
388     string res((char*)&len,sizeof(len));
389     res.append(get_buffer(oa),get_buffer_len(oa));
390 
391     close_buffer_oarchive(&oa,1);
392     return res;
393 }
394 
toString() const395 string ZooGetChildrenResponse::toString() const{
396     oarchive* oa=create_buffer_oarchive();
397 
398     ReplyHeader h = {xid_,1,rc_};
399     serialize_ReplyHeader(oa, "hdr", &h);
400 
401     GetChildrenResponse resp;
402     // populate the string vector
403     allocate_String_vector(&resp.children,strings_.size());
404     for(int i=0;i<(int)strings_.size();++i)
405         resp.children.data[i]=strdup(strings_[i].c_str());
406     serialize_GetChildrenResponse(oa, "reply", &resp);
407     deallocate_GetChildrenResponse(&resp);
408 
409     int32_t len=htonl(get_buffer_len(oa));
410     string res((char*)&len,sizeof(len));
411     res.append(get_buffer(oa),get_buffer_len(oa));
412 
413     close_buffer_oarchive(&oa,1);
414     return res;
415 }
416 
toString() const417 string ZNodeEvent::toString() const{
418     oarchive* oa=create_buffer_oarchive();
419     struct WatcherEvent evt = {type_,0,(char*)path_.c_str()};
420     struct ReplyHeader h = {WATCHER_EVENT_XID,0,ZOK };
421 
422     serialize_ReplyHeader(oa, "hdr", &h);
423     serialize_WatcherEvent(oa, "event", &evt);
424 
425     int32_t len=htonl(get_buffer_len(oa));
426     string res((char*)&len,sizeof(len));
427     res.append(get_buffer(oa),get_buffer_len(oa));
428 
429     close_buffer_oarchive(&oa,1);
430     return res;
431 }
432 
toString() const433 string PingResponse::toString() const{
434     oarchive* oa=create_buffer_oarchive();
435 
436     ReplyHeader h = {PING_XID,1,ZOK};
437     serialize_ReplyHeader(oa, "hdr", &h);
438 
439     int32_t len=htonl(get_buffer_len(oa));
440     string res((char*)&len,sizeof(len));
441     res.append(get_buffer(oa),get_buffer_len(oa));
442 
443     close_buffer_oarchive(&oa,1);
444     return res;
445 }
446 
447 //******************************************************************************
448 // Zookeeper server simulator
449 //
hasMoreRecv() const450 bool ZookeeperServer::hasMoreRecv() const{
451   return recvHasMore.get()!=0  || connectionLost;
452 }
453 
callRecv(int s,void * buf,size_t len,int flags)454 ssize_t ZookeeperServer::callRecv(int s,void *buf,size_t len,int flags){
455     if(connectionLost){
456         recvReturnBuffer.erase();
457         return 0;
458     }
459     // done transmitting the current buffer?
460     if(recvReturnBuffer.size()==0){
461         synchronized(recvQMx);
462         if(recvQueue.empty()){
463             recvErrno=EAGAIN;
464             return Mock_socket::callRecv(s,buf,len,flags);
465         }
466         --recvHasMore;
467         Element& el=recvQueue.front();
468         if(el.first!=0){
469             recvReturnBuffer=el.first->toString();
470             delete el.first;
471         }
472         recvErrno=el.second;
473         recvQueue.pop_front();
474     }
475     return Mock_socket::callRecv(s,buf,len,flags);
476 }
477 
onMessageReceived(const RequestHeader & rh,iarchive * ia)478 void ZookeeperServer::onMessageReceived(const RequestHeader& rh, iarchive* ia){
479     // no-op by default
480 }
481 
notifyBufferSent(const std::string & buffer)482 void ZookeeperServer::notifyBufferSent(const std::string& buffer){
483     if(HandshakeRequest::isValid(buffer)){
484         // could be a connect request
485         unique_ptr<HandshakeRequest> req(HandshakeRequest::parse(buffer));
486         if(req.get()!=0){
487             // handle the handshake
488             int64_t sessId=sessionExpired?req->sessionId+1:req->sessionId;
489             sessionExpired=false;
490             addRecvResponse(new HandshakeResponse(sessId));
491             return;
492         }
493         // not a connect request -- fall thru
494     }
495     // parse the buffer to extract the request type and its xid
496     iarchive *ia=create_buffer_iarchive((char*)buffer.data(), buffer.size());
497     RequestHeader rh;
498     deserialize_RequestHeader(ia,"hdr",&rh);
499     // notify the "server" a client request has arrived
500     if (rh.xid == -8) {
501         Element e = Element(new ZooStatResponse,0);
502         e.first->setXID(-8);
503         addRecvResponse(e);
504         close_buffer_iarchive(&ia);
505         return;
506     } else {
507         onMessageReceived(rh,ia);
508     }
509     close_buffer_iarchive(&ia);
510     if(rh.type==ZOO_CLOSE_OP){
511         ++closeSent;
512         return; // no reply for close requests
513     }
514     // get the next response from the response queue and append it to the
515     // receive list
516     Element e;
517     {
518         synchronized(respQMx);
519         if(respQueue.empty())
520             return;
521         e=respQueue.front();
522         respQueue.pop_front();
523     }
524     e.first->setXID(rh.xid);
525     addRecvResponse(e);
526 }
527 
forceConnected(zhandle_t * zh)528 void forceConnected(zhandle_t* zh){
529     // simulate connected state
530     zh->state=ZOO_CONNECTED_STATE;
531 
532     // Simulate we're connected to the first host in our host list
533     zh->fd->sock=ZookeeperServer::FD;
534     assert(zh->addrs.count > 0);
535     zh->addr_cur = zh->addrs.data[0];
536     zh->addrs.next++;
537 
538     zh->input_buffer=0;
539     gettimeofday(&zh->last_recv,0);
540     gettimeofday(&zh->last_send,0);
541 }
542 
terminateZookeeperThreads(zhandle_t * zh)543 void terminateZookeeperThreads(zhandle_t* zh){
544     // this will cause the zookeeper threads to terminate
545     zh->close_requested=1;
546 }
547