1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "base/task/post_task.h"
6 
7 #include <utility>
8 
9 #include "base/logging.h"
10 #include "base/task/scoped_set_task_priority_for_current_thread.h"
11 #include "base/task/task_executor.h"
12 #include "base/task/thread_pool/thread_pool_impl.h"
13 #include "base/task/thread_pool/thread_pool_instance.h"
14 #include "base/threading/post_task_and_reply_impl.h"
15 
16 namespace base {
17 
18 namespace {
19 
20 class PostTaskAndReplyWithTraitsTaskRunner
21     : public internal::PostTaskAndReplyImpl {
22  public:
PostTaskAndReplyWithTraitsTaskRunner(const TaskTraits & traits)23   explicit PostTaskAndReplyWithTraitsTaskRunner(const TaskTraits& traits)
24       : traits_(traits) {}
25 
26  private:
PostTask(const Location & from_here,OnceClosure task)27   bool PostTask(const Location& from_here, OnceClosure task) override {
28     ::base::PostTask(from_here, traits_, std::move(task));
29     return true;
30   }
31 
32   const TaskTraits traits_;
33 };
34 
GetTaskExecutorForTraits(const TaskTraits & traits)35 TaskExecutor* GetTaskExecutorForTraits(const TaskTraits& traits) {
36   const bool has_extension =
37       traits.extension_id() != TaskTraitsExtensionStorage::kInvalidExtensionId;
38   DCHECK(has_extension ^ traits.use_thread_pool())
39       << "A destination (e.g. ThreadPool or BrowserThread) must be specified "
40          "to use the post_task.h API. However, you should prefer the direct "
41          "thread_pool.h or browser_thread.h APIs in new code.";
42 
43   if (traits.use_thread_pool()) {
44     DCHECK(ThreadPoolInstance::Get())
45         << "Ref. Prerequisite section of post_task.h for base::ThreadPool "
46            "usage.\n"
47            "Hint: if this is in a unit test, you're likely merely missing a "
48            "base::test::TaskEnvironment member in your fixture (or your "
49            "fixture is using a base::test::SingleThreadTaskEnvironment and now "
50            "needs a full base::test::TaskEnvironment).\n";
51     return static_cast<internal::ThreadPoolImpl*>(ThreadPoolInstance::Get());
52   }
53 
54   // Assume |has_extension| per above invariant.
55   TaskExecutor* executor = GetRegisteredTaskExecutorForTraits(traits);
56   DCHECK(executor)
57       << "A TaskExecutor wasn't yet registered for this extension.\n"
58          "Hint: if this is in a unit test, you're likely missing a "
59          "content::BrowserTaskEnvironment member in your fixture.";
60   return executor;
61 }
62 
63 }  // namespace
64 
PostTask(const Location & from_here,OnceClosure task)65 bool PostTask(const Location& from_here, OnceClosure task) {
66   // TODO(skyostil): Make task traits required here too.
67   return PostDelayedTask(from_here, {ThreadPool()}, std::move(task),
68                          TimeDelta());
69 }
70 
PostTaskAndReply(const Location & from_here,OnceClosure task,OnceClosure reply)71 bool PostTaskAndReply(const Location& from_here,
72                       OnceClosure task,
73                       OnceClosure reply) {
74   return PostTaskAndReply(from_here, {ThreadPool()}, std::move(task),
75                           std::move(reply));
76 }
77 
PostTask(const Location & from_here,const TaskTraits & traits,OnceClosure task)78 bool PostTask(const Location& from_here,
79               const TaskTraits& traits,
80               OnceClosure task) {
81   return PostDelayedTask(from_here, traits, std::move(task), TimeDelta());
82 }
83 
PostDelayedTask(const Location & from_here,const TaskTraits & traits,OnceClosure task,TimeDelta delay)84 bool PostDelayedTask(const Location& from_here,
85                      const TaskTraits& traits,
86                      OnceClosure task,
87                      TimeDelta delay) {
88   return GetTaskExecutorForTraits(traits)->PostDelayedTask(
89       from_here, traits, std::move(task), delay);
90 }
91 
PostTaskAndReply(const Location & from_here,const TaskTraits & traits,OnceClosure task,OnceClosure reply)92 bool PostTaskAndReply(const Location& from_here,
93                       const TaskTraits& traits,
94                       OnceClosure task,
95                       OnceClosure reply) {
96   return PostTaskAndReplyWithTraitsTaskRunner(traits).PostTaskAndReply(
97       from_here, std::move(task), std::move(reply));
98 }
99 
CreateTaskRunner(const TaskTraits & traits)100 scoped_refptr<TaskRunner> CreateTaskRunner(const TaskTraits& traits) {
101   return GetTaskExecutorForTraits(traits)->CreateTaskRunner(traits);
102 }
103 
CreateSequencedTaskRunner(const TaskTraits & traits)104 scoped_refptr<SequencedTaskRunner> CreateSequencedTaskRunner(
105     const TaskTraits& traits) {
106   return GetTaskExecutorForTraits(traits)->CreateSequencedTaskRunner(traits);
107 }
108 
109 scoped_refptr<UpdateableSequencedTaskRunner>
CreateUpdateableSequencedTaskRunner(const TaskTraits & traits)110 CreateUpdateableSequencedTaskRunner(const TaskTraits& traits) {
111   DCHECK(ThreadPoolInstance::Get())
112       << "Ref. Prerequisite section of post_task.h.\n\n"
113          "Hint: if this is in a unit test, you're likely merely missing a "
114          "base::test::TaskEnvironment member in your fixture.\n";
115   DCHECK(traits.use_thread_pool())
116       << "The base::UseThreadPool() trait is mandatory with "
117          "CreateUpdateableSequencedTaskRunner().";
118   CHECK_EQ(traits.extension_id(),
119            TaskTraitsExtensionStorage::kInvalidExtensionId)
120       << "Extension traits cannot be used with "
121          "CreateUpdateableSequencedTaskRunner().";
122   return static_cast<internal::ThreadPoolImpl*>(ThreadPoolInstance::Get())
123       ->CreateUpdateableSequencedTaskRunner(traits);
124 }
125 
CreateSingleThreadTaskRunner(const TaskTraits & traits,SingleThreadTaskRunnerThreadMode thread_mode)126 scoped_refptr<SingleThreadTaskRunner> CreateSingleThreadTaskRunner(
127     const TaskTraits& traits,
128     SingleThreadTaskRunnerThreadMode thread_mode) {
129   return GetTaskExecutorForTraits(traits)->CreateSingleThreadTaskRunner(
130       traits, thread_mode);
131 }
132 
133 #if defined(OS_WIN)
CreateCOMSTATaskRunner(const TaskTraits & traits,SingleThreadTaskRunnerThreadMode thread_mode)134 scoped_refptr<SingleThreadTaskRunner> CreateCOMSTATaskRunner(
135     const TaskTraits& traits,
136     SingleThreadTaskRunnerThreadMode thread_mode) {
137   return GetTaskExecutorForTraits(traits)->CreateCOMSTATaskRunner(traits,
138                                                                   thread_mode);
139 }
140 #endif  // defined(OS_WIN)
141 
142 }  // namespace base
143