1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "marl/debug.h"
16 #include "marl/memory.h"
17
18 #include <functional>
19 #include <memory>
20
21 #define WIN32_LEAN_AND_MEAN 1
22 #include <Windows.h>
23
24 namespace marl {
25
26 class OSFiber {
27 public:
28 inline ~OSFiber();
29
30 // createFiberFromCurrentThread() returns a fiber created from the current
31 // thread.
32 static inline Allocator::unique_ptr<OSFiber> createFiberFromCurrentThread(
33 Allocator* allocator);
34
35 // createFiber() returns a new fiber with the given stack size that will
36 // call func when switched to. func() must end by switching back to another
37 // fiber, and must not return.
38 static inline Allocator::unique_ptr<OSFiber> createFiber(
39 Allocator* allocator,
40 size_t stackSize,
41 const std::function<void()>& func);
42
43 // switchTo() immediately switches execution to the given fiber.
44 // switchTo() must be called on the currently executing fiber.
45 inline void switchTo(OSFiber*);
46
47 private:
48 static inline void WINAPI run(void* self);
49 LPVOID fiber = nullptr;
50 bool isFiberFromThread = false;
51 std::function<void()> target;
52 };
53
~OSFiber()54 OSFiber::~OSFiber() {
55 if (fiber != nullptr) {
56 if (isFiberFromThread) {
57 ConvertFiberToThread();
58 } else {
59 DeleteFiber(fiber);
60 }
61 }
62 }
63
createFiberFromCurrentThread(Allocator * allocator)64 Allocator::unique_ptr<OSFiber> OSFiber::createFiberFromCurrentThread(
65 Allocator* allocator) {
66 auto out = allocator->make_unique<OSFiber>();
67 out->fiber = ConvertThreadToFiberEx(nullptr, FIBER_FLAG_FLOAT_SWITCH);
68 out->isFiberFromThread = true;
69 MARL_ASSERT(out->fiber != nullptr,
70 "ConvertThreadToFiberEx() failed with error 0x%x",
71 int(GetLastError()));
72 return out;
73 }
74
createFiber(Allocator * allocator,size_t stackSize,const std::function<void ()> & func)75 Allocator::unique_ptr<OSFiber> OSFiber::createFiber(
76 Allocator* allocator,
77 size_t stackSize,
78 const std::function<void()>& func) {
79 auto out = allocator->make_unique<OSFiber>();
80 // stackSize is rounded up to the system's allocation granularity (typically
81 // 64 KB).
82 out->fiber = CreateFiberEx(stackSize - 1, stackSize, FIBER_FLAG_FLOAT_SWITCH,
83 &OSFiber::run, out.get());
84 out->target = func;
85 MARL_ASSERT(out->fiber != nullptr, "CreateFiberEx() failed with error 0x%x",
86 int(GetLastError()));
87 return out;
88 }
89
switchTo(OSFiber * to)90 void OSFiber::switchTo(OSFiber* to) {
91 SwitchToFiber(to->fiber);
92 }
93
run(void * self)94 void WINAPI OSFiber::run(void* self) {
95 std::function<void()> func;
96 std::swap(func, reinterpret_cast<OSFiber*>(self)->target);
97 func();
98 }
99
100 } // namespace marl
101