1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "marl/debug.h"
16 #include "marl/memory.h"
17 
18 #include <functional>
19 #include <memory>
20 
21 #define WIN32_LEAN_AND_MEAN 1
22 #include <Windows.h>
23 
24 namespace marl {
25 
26 class OSFiber {
27  public:
28   inline ~OSFiber();
29 
30   // createFiberFromCurrentThread() returns a fiber created from the current
31   // thread.
32   static inline Allocator::unique_ptr<OSFiber> createFiberFromCurrentThread(
33       Allocator* allocator);
34 
35   // createFiber() returns a new fiber with the given stack size that will
36   // call func when switched to. func() must end by switching back to another
37   // fiber, and must not return.
38   static inline Allocator::unique_ptr<OSFiber> createFiber(
39       Allocator* allocator,
40       size_t stackSize,
41       const std::function<void()>& func);
42 
43   // switchTo() immediately switches execution to the given fiber.
44   // switchTo() must be called on the currently executing fiber.
45   inline void switchTo(OSFiber*);
46 
47  private:
48   static inline void WINAPI run(void* self);
49   LPVOID fiber = nullptr;
50   bool isFiberFromThread = false;
51   std::function<void()> target;
52 };
53 
~OSFiber()54 OSFiber::~OSFiber() {
55   if (fiber != nullptr) {
56     if (isFiberFromThread) {
57       ConvertFiberToThread();
58     } else {
59       DeleteFiber(fiber);
60     }
61   }
62 }
63 
createFiberFromCurrentThread(Allocator * allocator)64 Allocator::unique_ptr<OSFiber> OSFiber::createFiberFromCurrentThread(
65     Allocator* allocator) {
66   auto out = allocator->make_unique<OSFiber>();
67   out->fiber = ConvertThreadToFiberEx(nullptr, FIBER_FLAG_FLOAT_SWITCH);
68   out->isFiberFromThread = true;
69   MARL_ASSERT(out->fiber != nullptr,
70               "ConvertThreadToFiberEx() failed with error 0x%x",
71               int(GetLastError()));
72   return out;
73 }
74 
createFiber(Allocator * allocator,size_t stackSize,const std::function<void ()> & func)75 Allocator::unique_ptr<OSFiber> OSFiber::createFiber(
76     Allocator* allocator,
77     size_t stackSize,
78     const std::function<void()>& func) {
79   auto out = allocator->make_unique<OSFiber>();
80   // stackSize is rounded up to the system's allocation granularity (typically
81   // 64 KB).
82   out->fiber = CreateFiberEx(stackSize - 1, stackSize, FIBER_FLAG_FLOAT_SWITCH,
83                              &OSFiber::run, out.get());
84   out->target = func;
85   MARL_ASSERT(out->fiber != nullptr, "CreateFiberEx() failed with error 0x%x",
86               int(GetLastError()));
87   return out;
88 }
89 
switchTo(OSFiber * to)90 void OSFiber::switchTo(OSFiber* to) {
91   SwitchToFiber(to->fiber);
92 }
93 
run(void * self)94 void WINAPI OSFiber::run(void* self) {
95   std::function<void()> func;
96   std::swap(func, reinterpret_cast<OSFiber*>(self)->target);
97   func();
98 }
99 
100 }  // namespace marl
101