1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include "arrow/util/task_group.h"
19 
20 #include <atomic>
21 #include <condition_variable>
22 #include <cstdint>
23 #include <functional>
24 #include <mutex>
25 #include <utility>
26 
27 #include "arrow/util/checked_cast.h"
28 #include "arrow/util/logging.h"
29 #include "arrow/util/thread_pool.h"
30 
31 namespace arrow {
32 namespace internal {
33 
34 ////////////////////////////////////////////////////////////////////////
35 // Serial TaskGroup implementation
36 
37 class SerialTaskGroup : public TaskGroup {
38  public:
AppendReal(std::function<Status ()> task)39   void AppendReal(std::function<Status()> task) override {
40     DCHECK(!finished_);
41     if (status_.ok()) {
42       status_ &= task();
43     }
44   }
45 
current_status()46   Status current_status() override { return status_; }
47 
ok() const48   bool ok() const override { return status_.ok(); }
49 
Finish()50   Status Finish() override {
51     if (!finished_) {
52       finished_ = true;
53     }
54     return status_;
55   }
56 
FinishAsync()57   Future<> FinishAsync() override { return Future<>::MakeFinished(Finish()); }
58 
parallelism()59   int parallelism() override { return 1; }
60 
61   Status status_;
62   bool finished_ = false;
63 };
64 
65 ////////////////////////////////////////////////////////////////////////
66 // Threaded TaskGroup implementation
67 
68 class ThreadedTaskGroup : public TaskGroup {
69  public:
ThreadedTaskGroup(Executor * executor)70   explicit ThreadedTaskGroup(Executor* executor)
71       : executor_(executor), nremaining_(0), ok_(true) {}
72 
~ThreadedTaskGroup()73   ~ThreadedTaskGroup() override {
74     // Make sure all pending tasks are finished, so that dangling references
75     // to this don't persist.
76     ARROW_UNUSED(Finish());
77   }
78 
AppendReal(std::function<Status ()> task)79   void AppendReal(std::function<Status()> task) override {
80     DCHECK(!finished_);
81     // The hot path is unlocked thanks to atomics
82     // Only if an error occurs is the lock taken
83     if (ok_.load(std::memory_order_acquire)) {
84       nremaining_.fetch_add(1, std::memory_order_acquire);
85 
86       auto self = checked_pointer_cast<ThreadedTaskGroup>(shared_from_this());
87       Status st = executor_->Spawn(std::bind(
88           [](const std::shared_ptr<ThreadedTaskGroup>& self,
89              const std::function<Status()>& task) {
90             if (self->ok_.load(std::memory_order_acquire)) {
91               // XXX what about exceptions?
92               Status st = task();
93               self->UpdateStatus(std::move(st));
94             }
95             self->OneTaskDone();
96           },
97           std::move(self), std::move(task)));
98       UpdateStatus(std::move(st));
99     }
100   }
101 
current_status()102   Status current_status() override {
103     std::lock_guard<std::mutex> lock(mutex_);
104     return status_;
105   }
106 
ok() const107   bool ok() const override { return ok_.load(); }
108 
Finish()109   Status Finish() override {
110     std::unique_lock<std::mutex> lock(mutex_);
111     if (!finished_) {
112       cv_.wait(lock, [&]() { return nremaining_.load() == 0; });
113       // Current tasks may start other tasks, so only set this when done
114       finished_ = true;
115     }
116     return status_;
117   }
118 
FinishAsync()119   Future<> FinishAsync() override {
120     std::lock_guard<std::mutex> lock(mutex_);
121     if (!completion_future_.has_value()) {
122       if (nremaining_.load() == 0) {
123         completion_future_ = Future<>::MakeFinished(status_);
124       } else {
125         completion_future_ = Future<>::Make();
126       }
127     }
128     return *completion_future_;
129   }
130 
parallelism()131   int parallelism() override { return executor_->GetCapacity(); }
132 
133  protected:
UpdateStatus(Status && st)134   void UpdateStatus(Status&& st) {
135     // Must be called unlocked, only locks on error
136     if (ARROW_PREDICT_FALSE(!st.ok())) {
137       std::lock_guard<std::mutex> lock(mutex_);
138       ok_.store(false, std::memory_order_release);
139       status_ &= std::move(st);
140     }
141   }
142 
OneTaskDone()143   void OneTaskDone() {
144     // Can be called unlocked thanks to atomics
145     auto nremaining = nremaining_.fetch_sub(1, std::memory_order_release) - 1;
146     DCHECK_GE(nremaining, 0);
147     if (nremaining == 0) {
148       // Take the lock so that ~ThreadedTaskGroup cannot destroy cv
149       // before cv.notify_one() has returned
150       std::unique_lock<std::mutex> lock(mutex_);
151       cv_.notify_one();
152       if (completion_future_.has_value()) {
153         // MarkFinished could be slow.  We don't want to call it while we are holding
154         // the lock.
155         auto& future = *completion_future_;
156         const auto finished = completion_future_->is_finished();
157         const auto& status = status_;
158         // This will be redundant if the user calls Finish and not FinishAsync
159         if (!finished && !finished_) {
160           finished_ = true;
161           lock.unlock();
162           future.MarkFinished(status);
163         } else {
164           lock.unlock();
165         }
166       }
167     }
168   }
169 
170   // These members are usable unlocked
171   Executor* executor_;
172   std::atomic<int32_t> nremaining_;
173   std::atomic<bool> ok_;
174 
175   // These members use locking
176   std::mutex mutex_;
177   std::condition_variable cv_;
178   Status status_;
179   bool finished_ = false;
180   util::optional<Future<>> completion_future_;
181 };
182 
MakeSerial()183 std::shared_ptr<TaskGroup> TaskGroup::MakeSerial() {
184   return std::shared_ptr<TaskGroup>(new SerialTaskGroup);
185 }
186 
MakeThreaded(Executor * thread_pool)187 std::shared_ptr<TaskGroup> TaskGroup::MakeThreaded(Executor* thread_pool) {
188   return std::shared_ptr<TaskGroup>(new ThreadedTaskGroup(thread_pool));
189 }
190 
191 }  // namespace internal
192 }  // namespace arrow
193