1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
#include "arrow/util/task_group.h"

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <utility>

#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
#include "arrow/util/thread_pool.h"
30
31 namespace arrow {
32 namespace internal {
33
34 ////////////////////////////////////////////////////////////////////////
35 // Serial TaskGroup implementation
36
37 class SerialTaskGroup : public TaskGroup {
38 public:
AppendReal(std::function<Status ()> task)39 void AppendReal(std::function<Status()> task) override {
40 DCHECK(!finished_);
41 if (status_.ok()) {
42 status_ &= task();
43 }
44 }
45
current_status()46 Status current_status() override { return status_; }
47
ok() const48 bool ok() const override { return status_.ok(); }
49
Finish()50 Status Finish() override {
51 if (!finished_) {
52 finished_ = true;
53 }
54 return status_;
55 }
56
FinishAsync()57 Future<> FinishAsync() override { return Future<>::MakeFinished(Finish()); }
58
parallelism()59 int parallelism() override { return 1; }
60
61 Status status_;
62 bool finished_ = false;
63 };
64
65 ////////////////////////////////////////////////////////////////////////
66 // Threaded TaskGroup implementation
67
68 class ThreadedTaskGroup : public TaskGroup {
69 public:
ThreadedTaskGroup(Executor * executor)70 explicit ThreadedTaskGroup(Executor* executor)
71 : executor_(executor), nremaining_(0), ok_(true) {}
72
~ThreadedTaskGroup()73 ~ThreadedTaskGroup() override {
74 // Make sure all pending tasks are finished, so that dangling references
75 // to this don't persist.
76 ARROW_UNUSED(Finish());
77 }
78
AppendReal(std::function<Status ()> task)79 void AppendReal(std::function<Status()> task) override {
80 DCHECK(!finished_);
81 // The hot path is unlocked thanks to atomics
82 // Only if an error occurs is the lock taken
83 if (ok_.load(std::memory_order_acquire)) {
84 nremaining_.fetch_add(1, std::memory_order_acquire);
85
86 auto self = checked_pointer_cast<ThreadedTaskGroup>(shared_from_this());
87 Status st = executor_->Spawn(std::bind(
88 [](const std::shared_ptr<ThreadedTaskGroup>& self,
89 const std::function<Status()>& task) {
90 if (self->ok_.load(std::memory_order_acquire)) {
91 // XXX what about exceptions?
92 Status st = task();
93 self->UpdateStatus(std::move(st));
94 }
95 self->OneTaskDone();
96 },
97 std::move(self), std::move(task)));
98 UpdateStatus(std::move(st));
99 }
100 }
101
current_status()102 Status current_status() override {
103 std::lock_guard<std::mutex> lock(mutex_);
104 return status_;
105 }
106
ok() const107 bool ok() const override { return ok_.load(); }
108
Finish()109 Status Finish() override {
110 std::unique_lock<std::mutex> lock(mutex_);
111 if (!finished_) {
112 cv_.wait(lock, [&]() { return nremaining_.load() == 0; });
113 // Current tasks may start other tasks, so only set this when done
114 finished_ = true;
115 }
116 return status_;
117 }
118
FinishAsync()119 Future<> FinishAsync() override {
120 std::lock_guard<std::mutex> lock(mutex_);
121 if (!completion_future_.has_value()) {
122 if (nremaining_.load() == 0) {
123 completion_future_ = Future<>::MakeFinished(status_);
124 } else {
125 completion_future_ = Future<>::Make();
126 }
127 }
128 return *completion_future_;
129 }
130
parallelism()131 int parallelism() override { return executor_->GetCapacity(); }
132
133 protected:
UpdateStatus(Status && st)134 void UpdateStatus(Status&& st) {
135 // Must be called unlocked, only locks on error
136 if (ARROW_PREDICT_FALSE(!st.ok())) {
137 std::lock_guard<std::mutex> lock(mutex_);
138 ok_.store(false, std::memory_order_release);
139 status_ &= std::move(st);
140 }
141 }
142
OneTaskDone()143 void OneTaskDone() {
144 // Can be called unlocked thanks to atomics
145 auto nremaining = nremaining_.fetch_sub(1, std::memory_order_release) - 1;
146 DCHECK_GE(nremaining, 0);
147 if (nremaining == 0) {
148 // Take the lock so that ~ThreadedTaskGroup cannot destroy cv
149 // before cv.notify_one() has returned
150 std::unique_lock<std::mutex> lock(mutex_);
151 cv_.notify_one();
152 if (completion_future_.has_value()) {
153 // MarkFinished could be slow. We don't want to call it while we are holding
154 // the lock.
155 auto& future = *completion_future_;
156 const auto finished = completion_future_->is_finished();
157 const auto& status = status_;
158 // This will be redundant if the user calls Finish and not FinishAsync
159 if (!finished && !finished_) {
160 finished_ = true;
161 lock.unlock();
162 future.MarkFinished(status);
163 } else {
164 lock.unlock();
165 }
166 }
167 }
168 }
169
170 // These members are usable unlocked
171 Executor* executor_;
172 std::atomic<int32_t> nremaining_;
173 std::atomic<bool> ok_;
174
175 // These members use locking
176 std::mutex mutex_;
177 std::condition_variable cv_;
178 Status status_;
179 bool finished_ = false;
180 util::optional<Future<>> completion_future_;
181 };
182
MakeSerial()183 std::shared_ptr<TaskGroup> TaskGroup::MakeSerial() {
184 return std::shared_ptr<TaskGroup>(new SerialTaskGroup);
185 }
186
MakeThreaded(Executor * thread_pool)187 std::shared_ptr<TaskGroup> TaskGroup::MakeThreaded(Executor* thread_pool) {
188 return std::shared_ptr<TaskGroup>(new ThreadedTaskGroup(thread_pool));
189 }
190
191 } // namespace internal
192 } // namespace arrow
193