// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// UNSUPPORTED: c++98, c++03, c++11

#include <experimental/coroutine>
#include <vector>
#include <cassert>
#include <utility>

#include "test_macros.h"

using namespace std::experimental;

// This file tests a one-shot, movable, std::function-like wrapper that uses a
// coroutine for compile-time type erasure and un-erasure.

/// A one-shot, move-only callable wrapper returning R.
///
/// Constructing a func from a callable F creates the coroutine Create<F>, whose
/// frame holds F; that frame is the type-erased storage. The coroutine suspends
/// immediately (initial_suspend), so the callable does not run until
/// operator() resumes it. operator() then reads the co_returned value from the
/// promise and destroys the frame, leaving the func empty.
template <typename R> struct func {
  struct promise_type {
    // Value-initialized so operator() never reads an indeterminate value,
    // e.g. when the callable throws and unhandled_exception() swallows it.
    R result{};

    func get_return_object() { return {this}; }

    // Suspend before the body runs: the callable only executes on operator().
    suspend_always initial_suspend() { return {}; }

    // Suspend at the end so the frame (and `result`) stays alive until
    // operator() destroys it explicitly. Must be noexcept: the coroutine rules
    // require `co_await promise.final_suspend()` to be non-throwing.
    suspend_always final_suspend() noexcept { return {}; }

    void return_value(R v) { result = v; }

    // Exceptions from the callable are swallowed; operator() then returns a
    // value-initialized R.
    void unhandled_exception() {}
  };

  /// Runs the wrapped callable once and returns its result.
  ///
  /// One shot: the coroutine frame is destroyed afterwards and this func
  /// becomes empty. Calling it on an empty func (default-constructed,
  /// moved-from, or already invoked) is a precondition violation.
  R operator()() {
    assert(h && "func invoked while empty");
    h.resume();
    R result = h.promise().result;
    h.destroy();
    h = nullptr;
    return result;
  }

  /// Constructs an empty func holding no coroutine.
  func() {}

  /// Takes ownership of rhs's coroutine frame, leaving rhs empty.
  func(func &&rhs) noexcept : h(rhs.h) { rhs.h = nullptr; }

  func(func const &) = delete;

  /// Destroys any frame this func owns, then takes ownership of rhs's frame.
  func &operator=(func &&rhs) noexcept {
    if (this != &rhs) {
      if (h)
        h.destroy();
      h = rhs.h;
      rhs.h = nullptr;
    }
    return *this;
  }

  /// Coroutine factory: stores f in the coroutine frame and, when resumed,
  /// co_returns f() into the promise.
  template <typename F> static func Create(F f) { co_return f(); }

  /// Wraps an arbitrary callable F. F is moved (not copied again) into the
  /// coroutine frame.
  template <typename F> func(F f) : func(Create(std::move(f))) {}

  ~func() {
    if (h)
      h.destroy();
  }

private:
  // Used only by promise_type::get_return_object() to bind to the new frame.
  func(promise_type *promise)
      : h(coroutine_handle<promise_type>::from_promise(*promise)) {}

  coroutine_handle<promise_type> h;
};
std::vector<int> yielded_values = {}; yield(int x)71int yield(int x) { yielded_values.push_back(x); return x + 1; } fyield(int x)72float fyield(int x) { yielded_values.push_back(x); return static_cast<float>(x + 2); } 73 Do1(func<int> f)74void Do1(func<int> f) { yield(f()); } Do2(func<double> f)75void Do2(func<double> f) { yield(static_cast<int>(f())); } 76 main(int,char **)77int main(int, char**) { 78 Do1([] { return yield(43); }); 79 assert((yielded_values == std::vector<int>{43, 44})); 80 81 yielded_values = {}; 82 Do2([] { return fyield(44); }); 83 assert((yielded_values == std::vector<int>{44, 46})); 84 85 return 0; 86 } 87