1 // -*- C++ -*-
2 //===----------------------------------------------------------------------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 // UNSUPPORTED: c++98, c++03, c++11
11 
#include <experimental/coroutine>
#include <cassert>
#include <utility>
#include <vector>
15 
16 #include "test_macros.h"
17 
18 using namespace std::experimental;
19 
// This file tests a one-shot, movable, std::function-like wrapper that uses a
// coroutine for compile-time type erasure and unerasure.
22 
// A one-shot, move-only, std::function-like wrapper around a nullary callable
// returning R. The callable's concrete type is erased by capturing it in the
// frame of a coroutine (see Create); calling the wrapper resumes that
// coroutine once, reads the co_returned value, and destroys the frame.
template <typename R> struct func {
  struct promise_type {
    // Value-initialized so it is never read while indeterminate, e.g. when the
    // wrapped callable throws and unhandled_exception() swallows the error.
    R result{};

    func get_return_object() { return {this}; }
    // Start suspended: the callable only runs when operator() resumes it.
    suspend_always initial_suspend() { return {}; }
    // Stay suspended at the end so operator() can still read `result` before
    // destroying the frame. The coroutine machinery requires final_suspend()
    // to be non-throwing, hence noexcept.
    suspend_always final_suspend() noexcept { return {}; }
    void return_value(R v) { result = std::move(v); }
    // Exceptions escaping the callable are deliberately swallowed; `result`
    // then keeps its value-initialized state.
    void unhandled_exception() {}
  };

  // Runs the wrapped callable and returns its result. One-shot: the frame is
  // destroyed afterwards and this object becomes empty. Calling an empty func
  // (default-constructed, moved-from, or already called) is a usage error.
  R operator()() {
    assert(h && "func called while empty");
    h.resume();
    R result = std::move(h.promise().result);
    h.destroy();
    h = nullptr;
    return result;
  }

  func() = default;
  // Moves transfer ownership of the coroutine frame and leave `rhs` empty.
  func(func &&rhs) noexcept : h(std::exchange(rhs.h, nullptr)) {}
  func(func const &) = delete;

  func &operator=(func &&rhs) noexcept {
    if (this != &rhs) {
      if (h)
        h.destroy();
      h = std::exchange(rhs.h, nullptr);
    }
    return *this;
  }

  // Coroutine factory: the callable is stored in the coroutine frame and is
  // only invoked once the returned func is resumed.
  template <typename F> static func Create(F f) { co_return f(); }

  // Intentionally implicit so callables (e.g. lambdas) convert to func<R>
  // at call sites.
  template <typename F> func(F f) : func(Create(std::move(f))) {}

  ~func() {
    if (h)
      h.destroy();
  }

private:
  func(promise_type *promise)
      : h(coroutine_handle<promise_type>::from_promise(*promise)) {}
  coroutine_handle<promise_type> h;
};
69 
// Log of every argument passed to yield()/fyield(), in call order.
std::vector<int> yielded_values{};

// Records `x` in yielded_values and returns x + 1.
int yield(int x) {
  yielded_values.push_back(x);
  const int next = x + 1;
  return next;
}

// Records `x` in yielded_values and returns x + 2 converted to float.
float fyield(int x) {
  yielded_values.push_back(x);
  const int bumped = x + 2;
  return static_cast<float>(bumped);
}
73 
Do1(func<int> f)74 void Do1(func<int> f) { yield(f()); }
Do2(func<double> f)75 void Do2(func<double> f) { yield(static_cast<int>(f())); }
76 
main(int,char **)77 int main(int, char**) {
78   Do1([] { return yield(43); });
79   assert((yielded_values == std::vector<int>{43, 44}));
80 
81   yielded_values = {};
82   Do2([] { return fyield(44); });
83   assert((yielded_values == std::vector<int>{44, 46}));
84 
85   return 0;
86 }
87