1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <atomic>
20 #include <type_traits>
21 
22 #include <folly/experimental/coro/Coroutine.h>
23 #include <folly/experimental/coro/GtestHelpers.h>
24 #include <folly/experimental/coro/Result.h>
25 #include <folly/experimental/coro/Task.h>
26 #include <folly/portability/GMock.h>
27 
28 #if FOLLY_HAS_COROUTINES
29 
30 namespace folly {
31 namespace coro {
32 namespace gmock_helpers {
33 
34 // This helper function is intended for use in GMock implementations where the
35 // implementation of the method is a coroutine lambda.
36 //
37 // The GMock framework internally always takes a copy of an action/lambda
38 // before invoking it to prevent cases where invoking the method might end
39 // up destroying itself.
40 //
41 // However, this is problematic for coroutine-lambdas-with-captures as the
42 // return-value from invoking a coroutine lambda will typically capture a
43 // reference to the copy of the lambda which will immediately become a dangling
44 // reference as soon as the mocking framework returns that value to the caller.
45 //
46 // Use this action-factory instead of Invoke() when passing coroutine-lambdas
47 // to mock definitions to ensure that a copy of the lambda is kept alive until
48 // the coroutine completes. It does this by invoking the lambda using the
49 // folly::coro::co_invoke() helper instead of directly invoking the lambda.
50 //
51 //
52 // Example:
//   using namespace ::testing;
54 //   using namespace folly::coro::gmock_helpers;
55 //
56 //   MockFoo mock;
57 //   int fooCallCount = 0;
58 //
59 //   EXPECT_CALL(mock, foo(_))
60 //     .WillRepeatedly(CoInvoke(
61 //         [&](int x) -> folly::coro::Task<int> {
62 //           ++fooCallCount;
63 //           co_return x + 1;
64 //         }));
65 //
66 template <typename F>
CoInvoke(F && f)67 auto CoInvoke(F&& f) {
68   return ::testing::Invoke([f = static_cast<F&&>(f)](auto&&... a) {
69     return co_invoke(f, static_cast<decltype(a)>(a)...);
70   });
71 }
72 
73 // Member function overload
74 template <class Class, typename MethodPtr>
CoInvoke(Class * obj_ptr,MethodPtr method_ptr)75 auto CoInvoke(Class* obj_ptr, MethodPtr method_ptr) {
76   return ::testing::Invoke([=](auto&&... a) {
77     return co_invoke(method_ptr, obj_ptr, static_cast<decltype(a)>(a)...);
78   });
79 }
80 
81 // CoInvoke variant that does not pass arguments to callback function.
82 //
83 // Example:
//   using namespace ::testing;
85 //   using namespace folly::coro::gmock_helpers;
86 //
87 //   MockFoo mock;
88 //   int fooCallCount = 0;
89 //
90 //   EXPECT_CALL(mock, foo(_))
91 //     .WillRepeatedly(CoInvokeWithoutArgs(
92 //         [&]() -> folly::coro::Task<int> {
93 //           ++fooCallCount;
94 //           co_return 42;
95 //         }));
96 template <typename F>
CoInvokeWithoutArgs(F && f)97 auto CoInvokeWithoutArgs(F&& f) {
98   return ::testing::InvokeWithoutArgs(
99       [f = static_cast<F&&>(f)]() { return co_invoke(f); });
100 }
101 
102 // Member function overload
103 template <class Class, typename MethodPtr>
CoInvokeWithoutArgs(Class * obj_ptr,MethodPtr method_ptr)104 auto CoInvokeWithoutArgs(Class* obj_ptr, MethodPtr method_ptr) {
105   return ::testing::InvokeWithoutArgs(
106       [=]() { return co_invoke(method_ptr, obj_ptr); });
107 }
108 
109 namespace detail {
110 template <typename Fn>
makeCoAction(Fn && fn)111 auto makeCoAction(Fn&& fn) {
112   static_assert(
113       std::is_copy_constructible_v<remove_cvref_t<Fn>>,
114       "Fn should be copyable to allow calling mocked call multiple times.");
115 
116   using Ret = std::invoke_result_t<remove_cvref_t<Fn>&&>;
117   return ::testing::InvokeWithoutArgs(
118       [fn = std::forward<Fn>(fn)]() mutable -> Ret { return co_invoke(fn); });
119 }
120 
121 // Helper class to capture a ByMove return value for mocked coroutine function.
122 // Adds a test failure if it is moved twice like:
123 //    .WillRepeatedly(CoReturnByMove...)
124 template <typename R>
125 struct OnceForwarder {
126   static_assert(std::is_reference_v<R>);
127   using V = remove_cvref_t<R>;
128 
OnceForwarderOnceForwarder129   explicit OnceForwarder(R r) noexcept(std::is_nothrow_constructible_v<V>)
130       : val_(static_cast<R>(r)) {}
131 
operatorOnceForwarder132   R operator()() noexcept {
133     auto performedPreviously =
134         performed_.exchange(true, std::memory_order_relaxed);
135     if (performedPreviously) {
136       terminate_with<std::runtime_error>(
137           "a CoReturnByMove action must be performed only once");
138     }
139     return static_cast<R>(val_);
140   }
141 
142  private:
143   V val_;
144   std::atomic<bool> performed_ = false;
145 };
146 
147 // Allow to return a value by providing a convertible value.
148 // This works similarly to Return(x):
149 // MOCK_METHOD1(Method, T(U));
150 // EXPECT_CALL(mock, Method(_)).WillOnce(Return(F()));
151 // should work as long as F is convertible to T.
152 template <typename T>
153 class CoReturnImpl {
154  public:
CoReturnImpl(T && value)155   explicit CoReturnImpl(T&& value) : value_(std::move(value)) {}
156 
157   template <typename Result, typename ArgumentTuple>
Perform(const ArgumentTuple &)158   Result Perform(const ArgumentTuple& /* unused */) const {
159     return [](T value) -> Result { co_return value; }(value_);
160   }
161 
162  private:
163   T value_;
164 };
165 
166 template <typename T>
167 class CoReturnByMoveImpl {
168  public:
CoReturnByMoveImpl(std::shared_ptr<OnceForwarder<T &&>> forwarder)169   explicit CoReturnByMoveImpl(std::shared_ptr<OnceForwarder<T&&>> forwarder)
170       : forwarder_(std::move(forwarder)) {}
171 
172   template <typename Result, typename ArgumentTuple>
Perform(const ArgumentTuple &)173   Result Perform(const ArgumentTuple& /* unused */) const {
174     return [](std::shared_ptr<OnceForwarder<T&&>> forwarder) -> Result {
175       co_return (*forwarder)();
176     }(forwarder_);
177   }
178 
179  private:
180   std::shared_ptr<OnceForwarder<T&&>> forwarder_;
181 };
182 
183 } // namespace detail
184 
// Helper functions that allow coroutine-enabled functions to be mocked using
// gMock. CoReturn and CoThrow are gMock actions that mirror the Return and
// Throw actions used in EXPECT_CALL|ON_CALL invocations.
188 //
189 // Example:
//   using namespace ::testing;
191 //   using namespace folly::coro::gmock_helpers;
192 //
193 //   MockFoo mock;
194 //   std::string result = "abc";
195 //
196 //   EXPECT_CALL(mock, co_foo(_))
197 //     .WillRepeatedly(CoReturn(result));
198 //
199 //   // For Task<void> return types.
200 //   EXPECT_CALL(mock, co_bar(_))
201 //     .WillRepeatedly(CoReturn());
202 //
//   // For returning by move. A CoReturnByMove action must be performed only
//   // once, so pair it with WillOnce rather than WillRepeatedly.
//   EXPECT_CALL(mock, co_bar(_))
//     .WillOnce(CoReturnByMove(std::move(result)));
//
//   // For returning a move-only value.
//   EXPECT_CALL(mock, co_bar(_))
//     .WillOnce(CoReturnByMove(std::make_unique<std::string>(result)));
//
//   // For failing with an exception from a method returning Task<std::string>.
//   EXPECT_CALL(mock, co_foo(_))
//     .WillRepeatedly(CoThrow<std::string>(std::runtime_error("error")));
214 template <typename T>
CoReturn(T ret)215 auto CoReturn(T ret) {
216   return ::testing::MakePolymorphicAction(
217       detail::CoReturnImpl<T>(std::move(ret)));
218 }
219 
CoReturn()220 inline auto CoReturn() {
221   return ::testing::InvokeWithoutArgs([]() -> Task<> { co_return; });
222 }
223 
224 template <typename T>
CoReturnByMove(T && ret)225 auto CoReturnByMove(T&& ret) {
226   static_assert(
227       !std::is_lvalue_reference_v<decltype(ret)>,
228       "the argument must be passed as non-const rvalue-ref");
229   static_assert(
230       !std::is_const_v<T>,
231       "the argument must be passed as non-const rvalue-ref");
232 
233   auto ptr = std::make_shared<detail::OnceForwarder<T&&>>(std::move(ret));
234 
235   return ::testing::MakePolymorphicAction(
236       detail::CoReturnByMoveImpl<T>(std::move(ptr)));
237 }
238 
239 template <typename T, typename Ex>
CoThrow(Ex && e)240 auto CoThrow(Ex&& e) {
241   return detail::makeCoAction(
242       [ex = std::forward<Ex>(e)]() -> Task<T> { co_yield co_error(ex); });
243 }
244 
245 } // namespace gmock_helpers
246 } // namespace coro
247 } // namespace folly
248 
// Checks `value` against the gMock `matcher` by expanding to
// CO_ASSERT_PRED_FORMAT1 with a predicate-formatter built from the matcher.
// NOTE(review): CO_ASSERT_PRED_FORMAT1 is presumably provided by
// GtestHelpers.h and intended for use inside coroutine test bodies — confirm.
#define CO_ASSERT_THAT(value, matcher) \
  CO_ASSERT_PRED_FORMAT1(              \
      ::testing::internal::MakePredicateFormatterFromMatcher(matcher), value)
252 
253 #endif // FOLLY_HAS_COROUTINES
254