1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "base/threading/thread_task_runner_handle.h"
6 
7 #include <memory>
8 #include <utility>
9 
10 #include "base/bind.h"
11 #include "base/callback_helpers.h"
12 #include "base/lazy_instance.h"
13 #include "base/logging.h"
14 #include "base/run_loop.h"
15 #include "base/threading/thread_local.h"
16 
17 namespace base {
18 
19 namespace {
20 
21 base::LazyInstance<base::ThreadLocalPointer<ThreadTaskRunnerHandle>>::Leaky
22     thread_task_runner_tls = LAZY_INSTANCE_INITIALIZER;
23 
24 }  // namespace
25 
26 // static
Get()27 const scoped_refptr<SingleThreadTaskRunner>& ThreadTaskRunnerHandle::Get() {
28   const ThreadTaskRunnerHandle* current =
29       thread_task_runner_tls.Pointer()->Get();
30   CHECK(current)
31       << "Error: This caller requires a single-threaded context (i.e. the "
32          "current task needs to run from a SingleThreadTaskRunner). If you're "
33          "in a test refer to //docs/threading_and_tasks_testing.md.";
34   return current->task_runner_;
35 }
36 
37 // static
IsSet()38 bool ThreadTaskRunnerHandle::IsSet() {
39   return !!thread_task_runner_tls.Pointer()->Get();
40 }
41 
42 #if defined(OS_BSD)
43 // static
OverrideForTesting(scoped_refptr<SingleThreadTaskRunner> overriding_task_runner)44 ScopedClosureRunner ThreadTaskRunnerHandle::OverrideForTesting(
45     scoped_refptr<SingleThreadTaskRunner> overriding_task_runner) {
46   // OverrideForTesting() is not compatible with a SequencedTaskRunnerHandle
47   // already being set on this thread (except when it's set by the current
48   // ThreadTaskRunnerHandle).
49   DCHECK(!SequencedTaskRunnerHandle::IsSet() || IsSet());
50 
51   if (!IsSet()) {
52     auto top_level_ttrh = std::make_unique<ThreadTaskRunnerHandle>(
53         std::move(overriding_task_runner));
54     return ScopedClosureRunner(base::BindOnce(
55         [](std::unique_ptr<ThreadTaskRunnerHandle> ttrh_to_release) {},
56         std::move(top_level_ttrh)));
57   }
58 
59   ThreadTaskRunnerHandle* ttrh = thread_task_runner_tls.Pointer()->Get();
60   // Swap the two (and below bind |overriding_task_runner|, which is now the
61   // previous one, as the |task_runner_to_restore|).
62   ttrh->sequenced_task_runner_handle_.task_runner_ = overriding_task_runner;
63   ttrh->task_runner_.swap(overriding_task_runner);
64 
65   auto no_running_during_override =
66       std::make_unique<RunLoop::ScopedDisallowRunningForTesting>();
67 
68   return ScopedClosureRunner(base::BindOnce(
69       [](scoped_refptr<SingleThreadTaskRunner> task_runner_to_restore,
70          SingleThreadTaskRunner* expected_task_runner_before_restore,
71          std::unique_ptr<RunLoop::ScopedDisallowRunningForTesting>
72              no_running_during_override) {
73         ThreadTaskRunnerHandle* ttrh = thread_task_runner_tls.Pointer()->Get();
74 
75         DCHECK_EQ(expected_task_runner_before_restore, ttrh->task_runner_.get())
76             << "Nested overrides must expire their ScopedClosureRunners "
77                "in LIFO order.";
78 
79         ttrh->sequenced_task_runner_handle_.task_runner_ =
80             task_runner_to_restore;
81         ttrh->task_runner_.swap(task_runner_to_restore);
82       },
83       std::move(overriding_task_runner),
84       base::Unretained(ttrh->task_runner_.get()),
85       std::move(no_running_during_override)));
86 }
87 #endif
88 
ThreadTaskRunnerHandle(scoped_refptr<SingleThreadTaskRunner> task_runner)89 ThreadTaskRunnerHandle::ThreadTaskRunnerHandle(
90     scoped_refptr<SingleThreadTaskRunner> task_runner)
91     : task_runner_(std::move(task_runner)),
92       sequenced_task_runner_handle_(task_runner_) {
93   DCHECK(task_runner_->BelongsToCurrentThread());
94   DCHECK(!thread_task_runner_tls.Pointer()->Get());
95   thread_task_runner_tls.Pointer()->Set(this);
96 }
97 
~ThreadTaskRunnerHandle()98 ThreadTaskRunnerHandle::~ThreadTaskRunnerHandle() {
99   DCHECK(task_runner_->BelongsToCurrentThread());
100   DCHECK_EQ(thread_task_runner_tls.Pointer()->Get(), this);
101   thread_task_runner_tls.Pointer()->Set(nullptr);
102 }
103 
104 }  // namespace base
105