1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <folly/io/async/AsyncSocketException.h>
20 #include <folly/io/async/AsyncTransport.h>
21 #include <folly/portability/GMock.h>
22 
23 namespace folly {
24 namespace test {
25 
26 class MockAsyncTransportLifecycleObserver
27     : public AsyncTransport::LifecycleObserver {
28  public:
29   using AsyncTransport::LifecycleObserver::LifecycleObserver;
30   MOCK_METHOD1(observerAttachMock, void(AsyncTransport*));
31   MOCK_METHOD1(observerDetachMock, void(AsyncTransport*));
32   MOCK_METHOD1(destroyMock, void(AsyncTransport*));
33   MOCK_METHOD1(closeMock, void(AsyncTransport*));
34   MOCK_METHOD1(connectAttemptMock, void(AsyncTransport*));
35   MOCK_METHOD1(connectSuccessMock, void(AsyncTransport*));
36   MOCK_METHOD2(
37       connectErrorMock, void(AsyncTransport*, const AsyncSocketException&));
38   MOCK_METHOD2(evbAttachMock, void(AsyncTransport*, EventBase*));
39   MOCK_METHOD2(evbDetachMock, void(AsyncTransport*, EventBase*));
40   MOCK_METHOD2(
41       byteEventMock, void(AsyncTransport*, const AsyncTransport::ByteEvent&));
42   MOCK_METHOD1(byteEventsEnabledMock, void(AsyncTransport*));
43   MOCK_METHOD2(
44       byteEventsUnavailableMock,
45       void(AsyncTransport*, const AsyncSocketException&));
46   MOCK_METHOD2(
47       prewriteMock, PrewriteRequest(AsyncTransport*, const PrewriteState&));
48 
49  private:
observerAttach(AsyncTransport * trans)50   void observerAttach(AsyncTransport* trans) noexcept override {
51     observerAttachMock(trans);
52   }
observerDetach(AsyncTransport * trans)53   void observerDetach(AsyncTransport* trans) noexcept override {
54     observerDetachMock(trans);
55   }
destroy(AsyncTransport * trans)56   void destroy(AsyncTransport* trans) noexcept override { destroyMock(trans); }
close(AsyncTransport * trans)57   void close(AsyncTransport* trans) noexcept override { closeMock(trans); }
connectAttempt(AsyncTransport * trans)58   void connectAttempt(AsyncTransport* trans) noexcept override {
59     connectAttemptMock(trans);
60   }
connectSuccess(AsyncTransport * trans)61   void connectSuccess(AsyncTransport* trans) noexcept override {
62     connectSuccessMock(trans);
63   }
connectError(AsyncTransport * trans,const AsyncSocketException & ex)64   void connectError(
65       AsyncTransport* trans, const AsyncSocketException& ex) noexcept override {
66     connectErrorMock(trans, ex);
67   }
evbAttach(AsyncTransport * trans,EventBase * eb)68   void evbAttach(AsyncTransport* trans, EventBase* eb) noexcept override {
69     evbAttachMock(trans, eb);
70   }
evbDetach(AsyncTransport * trans,EventBase * eb)71   void evbDetach(AsyncTransport* trans, EventBase* eb) noexcept override {
72     evbDetachMock(trans, eb);
73   }
byteEvent(AsyncTransport * trans,const AsyncTransport::ByteEvent & ev)74   void byteEvent(
75       AsyncTransport* trans,
76       const AsyncTransport::ByteEvent& ev) noexcept override {
77     byteEventMock(trans, ev);
78   }
byteEventsEnabled(AsyncTransport * trans)79   void byteEventsEnabled(AsyncTransport* trans) noexcept override {
80     byteEventsEnabledMock(trans);
81   }
byteEventsUnavailable(AsyncTransport * trans,const AsyncSocketException & ex)82   void byteEventsUnavailable(
83       AsyncTransport* trans, const AsyncSocketException& ex) noexcept override {
84     byteEventsUnavailableMock(trans, ex);
85   }
prewrite(AsyncTransport * trans,const PrewriteState & state)86   PrewriteRequest prewrite(
87       AsyncTransport* trans, const PrewriteState& state) noexcept override {
88     return prewriteMock(trans, state);
89   }
90 };
91 
92 /**
93  * Extends mock class to simplify ByteEvents tests.
94  */
95 class MockAsyncTransportObserverForByteEvents
96     : public MockAsyncTransportLifecycleObserver {
97  public:
MockAsyncTransportObserverForByteEvents(AsyncTransport * transport,const MockAsyncTransportObserverForByteEvents::Config & observerConfig)98   MockAsyncTransportObserverForByteEvents(
99       AsyncTransport* transport,
100       const MockAsyncTransportObserverForByteEvents::Config& observerConfig)
101       : MockAsyncTransportLifecycleObserver(observerConfig),
102         transport_(transport) {
103     ON_CALL(*this, byteEventMock(testing::_, testing::_))
104         .WillByDefault(
105             testing::Invoke([this](
106                                 AsyncTransport* transport,
107                                 const AsyncTransport::ByteEvent& event) {
108               CHECK_EQ(this->transport_, transport);
109               byteEvents_.emplace_back(event);
110             }));
111     ON_CALL(*this, byteEventsEnabledMock(testing::_))
112         .WillByDefault(testing::Invoke([this](AsyncTransport* transport) {
113           CHECK_EQ(this->transport_, transport);
114           byteEventsEnabledCalled_++;
115         }));
116 
117     ON_CALL(*this, byteEventsUnavailableMock(testing::_, testing::_))
118         .WillByDefault(testing::Invoke(
119             [this](AsyncTransport* transport, const AsyncSocketException& ex) {
120               CHECK_EQ(this->transport_, transport);
121               byteEventsUnavailableCalled_++;
122               byteEventsUnavailableCalledEx_.emplace(ex);
123             }));
124     transport->addLifecycleObserver(this);
125   }
126 
getByteEvents()127   const std::vector<AsyncTransport::ByteEvent>& getByteEvents() {
128     return byteEvents_;
129   }
130 
getByteEventReceivedWithOffset(const uint64_t offset,const AsyncTransport::ByteEvent::Type type)131   folly::Optional<AsyncTransport::ByteEvent> getByteEventReceivedWithOffset(
132       const uint64_t offset, const AsyncTransport::ByteEvent::Type type) {
133     for (const auto& byteEvent : byteEvents_) {
134       if (type == byteEvent.type && offset == byteEvent.offset) {
135         return byteEvent;
136       }
137     }
138     return folly::none;
139   }
140 
maxOffsetForByteEventReceived(const AsyncTransport::ByteEvent::Type type)141   folly::Optional<uint64_t> maxOffsetForByteEventReceived(
142       const AsyncTransport::ByteEvent::Type type) {
143     folly::Optional<uint64_t> maybeMaxOffset;
144     for (const auto& byteEvent : byteEvents_) {
145       if (type == byteEvent.type &&
146           (!maybeMaxOffset.has_value() ||
147            maybeMaxOffset.value() <= byteEvent.offset)) {
148         maybeMaxOffset = byteEvent.offset;
149       }
150     }
151     return maybeMaxOffset;
152   }
153 
checkIfByteEventReceived(const AsyncTransport::ByteEvent::Type type,const uint64_t offset)154   bool checkIfByteEventReceived(
155       const AsyncTransport::ByteEvent::Type type, const uint64_t offset) {
156     for (const auto& byteEvent : byteEvents_) {
157       if (type == byteEvent.type && offset == byteEvent.offset) {
158         return true;
159       }
160     }
161     return false;
162   }
163 
waitForByteEvent(const AsyncTransport::ByteEvent::Type type,const uint64_t offset)164   void waitForByteEvent(
165       const AsyncTransport::ByteEvent::Type type, const uint64_t offset) {
166     while (!checkIfByteEventReceived(type, offset)) {
167       transport_->getEventBase()->loopOnce();
168     }
169   }
170 
171   // Exposed ByteEvent helper fields with const
172   const uint32_t& byteEventsEnabledCalled{byteEventsEnabledCalled_};
173   const uint32_t& byteEventsUnavailableCalled{byteEventsUnavailableCalled_};
174   const folly::Optional<AsyncSocketException>& byteEventsUnavailableCalledEx{
175       byteEventsUnavailableCalledEx_};
176   const std::vector<AsyncTransport::ByteEvent>& byteEvents{byteEvents_};
177 
178  private:
179   const AsyncTransport* transport_;
180 
181   // ByteEvents helpers
182   uint32_t byteEventsEnabledCalled_{0};
183   uint32_t byteEventsUnavailableCalled_{0};
184   folly::Optional<AsyncSocketException> byteEventsUnavailableCalledEx_;
185   std::vector<AsyncTransport::ByteEvent> byteEvents_;
186 };
187 
188 } // namespace test
189 } // namespace folly
190