1 // Copyright 2021 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/threading_strategy.h"
16
17 #include <memory>
18 #include <utility>
19 #include <vector>
20
21 #include "absl/strings/str_cat.h"
22 #include "gtest/gtest.h"
23 #include "src/frame_scratch_buffer.h"
24 #include "src/obu_parser.h"
25 #include "src/utils/constants.h"
26 #include "src/utils/threadpool.h"
27 #include "src/utils/types.h"
28
29 namespace libgav1 {
30 namespace {
31
32 class ThreadingStrategyTest : public testing::Test {
33 protected:
34 ThreadingStrategy strategy_;
35 ObuFrameHeader frame_header_ = {};
36 };
37
// 32 tiles decoded with 32 threads: a tile thread pool and a post filter thread
// pool are expected, but none of the tiles should get a row thread pool.
TEST_F(ThreadingStrategyTest, MaxThreadEnforced) {
  const int num_tiles = 32;
  const int num_threads = 32;
  frame_header_.tile_info.tile_count = num_tiles;
  ASSERT_TRUE(strategy_.Reset(frame_header_, num_threads));
  EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
  for (int tile = 0; tile < num_tiles; ++tile) {
    EXPECT_EQ(strategy_.row_thread_pool(tile), nullptr);
  }
  EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
}
47
// Thread count equal to tile count (8 each): a tile thread pool and a post
// filter thread pool are expected, and no tile should get a row thread pool.
TEST_F(ThreadingStrategyTest, UseAllThreadsForTiles) {
  const int num_tiles = 8;
  const int num_threads = 8;
  frame_header_.tile_info.tile_count = num_tiles;
  ASSERT_TRUE(strategy_.Reset(frame_header_, num_threads));
  EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
  for (int tile = 0; tile < num_tiles; ++tile) {
    EXPECT_EQ(strategy_.row_thread_pool(tile), nullptr);
  }
  EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
}
57
// More threads (8) than tiles (2): in addition to the tile and post filter
// thread pools, every tile is expected to receive a row thread pool.
TEST_F(ThreadingStrategyTest, RowThreads) {
  const int num_tiles = 2;
  const int num_threads = 8;
  frame_header_.tile_info.tile_count = num_tiles;
  ASSERT_TRUE(strategy_.Reset(frame_header_, num_threads));
  EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
  // The expected split gives each tile 3 row threads; only the presence of
  // each tile's row thread pool is checked here.
  for (int tile = 0; tile < num_tiles; ++tile) {
    EXPECT_NE(strategy_.row_thread_pool(tile), nullptr);
  }
  EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
}
68
// 9 threads for 2 tiles, so the threads cannot be split evenly between the
// tiles. Both tiles must still receive a row thread pool.
TEST_F(ThreadingStrategyTest, RowThreadsUnequal) {
  const int num_threads = 9;
  frame_header_.tile_info.tile_count = 2;
  ASSERT_TRUE(strategy_.Reset(frame_header_, num_threads));
  EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
  for (int tile : {0, 1}) {
    EXPECT_NE(strategy_.row_thread_pool(tile), nullptr);
  }
  EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
}
78
79 // Test a random combination of tile_count and thread_count.
TEST_F(ThreadingStrategyTest,MultipleCalls)80 TEST_F(ThreadingStrategyTest, MultipleCalls) {
81 frame_header_.tile_info.tile_count = 2;
82 ASSERT_TRUE(strategy_.Reset(frame_header_, 8));
83 EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
84 for (int i = 0; i < 2; ++i) {
85 EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
86 }
87 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
88
89 frame_header_.tile_info.tile_count = 8;
90 ASSERT_TRUE(strategy_.Reset(frame_header_, 8));
91 EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
92 // Row threads must have been reset.
93 for (int i = 0; i < 8; ++i) {
94 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
95 }
96 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
97
98 frame_header_.tile_info.tile_count = 8;
99 ASSERT_TRUE(strategy_.Reset(frame_header_, 16));
100 EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
101 for (int i = 0; i < 8; ++i) {
102 EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
103 }
104 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
105
106 frame_header_.tile_info.tile_count = 4;
107 ASSERT_TRUE(strategy_.Reset(frame_header_, 16));
108 EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
109 for (int i = 0; i < 4; ++i) {
110 EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
111 }
112 // All the other row threads must be reset.
113 for (int i = 4; i < 8; ++i) {
114 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
115 }
116 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
117
118 frame_header_.tile_info.tile_count = 4;
119 ASSERT_TRUE(strategy_.Reset(frame_header_, 6));
120 EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
121 // First two tiles will get 1 thread each.
122 for (int i = 0; i < 2; ++i) {
123 EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
124 }
125 // All the other row threads must be reset.
126 for (int i = 2; i < 8; ++i) {
127 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
128 }
129 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
130
131 ASSERT_TRUE(strategy_.Reset(frame_header_, 1));
132 EXPECT_EQ(strategy_.tile_thread_pool(), nullptr);
133 for (int i = 0; i < 8; ++i) {
134 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
135 }
136 EXPECT_EQ(strategy_.post_filter_thread_pool(), nullptr);
137 }
138
139 // Tests the following order of calls (with thread count fixed at 4):
140 // * 1 Tile - 2 Tiles - 1 Tile.
TEST_F(ThreadingStrategyTest,MultipleCalls2)141 TEST_F(ThreadingStrategyTest, MultipleCalls2) {
142 frame_header_.tile_info.tile_count = 1;
143 ASSERT_TRUE(strategy_.Reset(frame_header_, 4));
144 // When there is only one tile, tile thread pool must be nullptr.
145 EXPECT_EQ(strategy_.tile_thread_pool(), nullptr);
146 EXPECT_NE(strategy_.row_thread_pool(0), nullptr);
147 for (int i = 1; i < 8; ++i) {
148 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
149 }
150 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
151
152 frame_header_.tile_info.tile_count = 2;
153 ASSERT_TRUE(strategy_.Reset(frame_header_, 4));
154 EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
155 for (int i = 0; i < 2; ++i) {
156 EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
157 }
158 for (int i = 2; i < 8; ++i) {
159 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
160 }
161 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
162
163 frame_header_.tile_info.tile_count = 1;
164 ASSERT_TRUE(strategy_.Reset(frame_header_, 4));
165 EXPECT_EQ(strategy_.tile_thread_pool(), nullptr);
166 EXPECT_NE(strategy_.row_thread_pool(0), nullptr);
167 for (int i = 1; i < 8; ++i) {
168 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
169 }
170 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
171 }
172
VerifyFrameParallel(int thread_count,int tile_count,int tile_columns,int expected_frame_threads,const std::vector<int> & expected_tile_threads)173 void VerifyFrameParallel(int thread_count, int tile_count, int tile_columns,
174 int expected_frame_threads,
175 const std::vector<int>& expected_tile_threads) {
176 ASSERT_EQ(expected_frame_threads, expected_tile_threads.size());
177 ASSERT_GT(thread_count, 1);
178 std::unique_ptr<ThreadPool> frame_thread_pool;
179 FrameScratchBufferPool frame_scratch_buffer_pool;
180 ASSERT_TRUE(InitializeThreadPoolsForFrameParallel(
181 thread_count, tile_count, tile_columns, &frame_thread_pool,
182 &frame_scratch_buffer_pool));
183 if (expected_frame_threads == 0) {
184 EXPECT_EQ(frame_thread_pool, nullptr);
185 return;
186 }
187 EXPECT_NE(frame_thread_pool.get(), nullptr);
188 EXPECT_EQ(frame_thread_pool->num_threads(), expected_frame_threads);
189 std::vector<std::unique_ptr<FrameScratchBuffer>> frame_scratch_buffers;
190 int actual_thread_count = frame_thread_pool->num_threads();
191 for (int i = 0; i < expected_frame_threads; ++i) {
192 SCOPED_TRACE(absl::StrCat("i: ", i));
193 frame_scratch_buffers.push_back(frame_scratch_buffer_pool.Get());
194 ThreadPool* const thread_pool =
195 frame_scratch_buffers.back()->threading_strategy.thread_pool();
196 if (expected_tile_threads[i] > 0) {
197 EXPECT_NE(thread_pool, nullptr);
198 EXPECT_EQ(thread_pool->num_threads(), expected_tile_threads[i]);
199 actual_thread_count += thread_pool->num_threads();
200 } else {
201 EXPECT_EQ(thread_pool, nullptr);
202 }
203 }
204 EXPECT_EQ(thread_count, actual_thread_count);
205 for (auto& frame_scratch_buffer : frame_scratch_buffers) {
206 frame_scratch_buffer_pool.Release(std::move(frame_scratch_buffer));
207 }
208 }
209
// Table-driven check of how InitializeThreadPoolsForFrameParallel() divides
// threads between frame threads and per-frame tile threads.
TEST(FrameParallelStrategyTest, FrameParallel) {
  // Every iteration here has thread_count <= 3 * tile_count, so no frame
  // threads are expected regardless of the number of tile columns.
  for (int thread_count = 2; thread_count <= 6; ++thread_count) {
    for (int tile_columns = 1; tile_columns <= 2; ++tile_columns) {
      VerifyFrameParallel(thread_count, /*tile_count=*/2, tile_columns,
                          /*expected_frame_threads=*/0,
                          /*expected_tile_threads=*/{});
    }
  }

  struct FrameParallelCase {
    int thread_count;
    int tile_count;
    int tile_columns;
    int expected_frame_threads;
    std::vector<int> expected_tile_threads;
  };
  const FrameParallelCase kCases[] = {
      // Every frame thread gets the same number of tile threads.
      {8, 1, 1, 4, {1, 1, 1, 1}},
      {12, 2, 2, 4, {2, 2, 2, 2}},
      {18, 2, 2, 6, {2, 2, 2, 2, 2, 2}},
      {16, 3, 3, 4, {3, 3, 3, 3}},
      // The tile threads do not divide evenly among the frame threads.
      {7, 1, 1, 3, {2, 1, 1}},
      {14, 2, 2, 4, {3, 3, 2, 2}},
      {20, 2, 2, 6, {3, 3, 2, 2, 2, 2}},
      {17, 3, 3, 4, {4, 3, 3, 3}},
  };
  for (const FrameParallelCase& test_case : kCases) {
    VerifyFrameParallel(test_case.thread_count, test_case.tile_count,
                        test_case.tile_columns,
                        test_case.expected_frame_threads,
                        test_case.expected_tile_threads);
  }
}
252
// Requests more than kMaxThreads threads. Checks that the frame thread pool plus
// the thread pools of the per-frame threading strategies never use more than
// kMaxThreads threads in total.
TEST(FrameParallelStrategyTest, ThreadCountDoesNotExceedkMaxThreads) {
  std::unique_ptr<ThreadPool> frame_thread_pool;
  FrameScratchBufferPool frame_scratch_buffer_pool;
  ASSERT_TRUE(InitializeThreadPoolsForFrameParallel(
      /*thread_count=*/kMaxThreads + 10, /*tile_count=*/2, /*tile_columns=*/2,
      &frame_thread_pool, &frame_scratch_buffer_pool));
  // ASSERT rather than EXPECT: the pool is dereferenced below, so a null pool
  // must fail the test cleanly instead of crashing the test binary.
  ASSERT_NE(frame_thread_pool.get(), nullptr);
  const int num_frame_threads = frame_thread_pool->num_threads();
  std::vector<std::unique_ptr<FrameScratchBuffer>> frame_scratch_buffers;
  int actual_thread_count = num_frame_threads;
  for (int i = 0; i < num_frame_threads; ++i) {
    SCOPED_TRACE(absl::StrCat("i: ", i));
    frame_scratch_buffers.push_back(frame_scratch_buffer_pool.Get());
    ASSERT_NE(frame_scratch_buffers.back().get(), nullptr);
    ThreadPool* const thread_pool =
        frame_scratch_buffers.back()->threading_strategy.thread_pool();
    if (thread_pool != nullptr) {
      actual_thread_count += thread_pool->num_threads();
    }
  }
  // The exact split between frame threads and tile threads depends on the
  // value of kMaxThreads, so only the upper bound on the total is checked.
  EXPECT_LE(actual_thread_count, kMaxThreads);
  // Hand the buffers back to the pool they were obtained from.
  for (auto& frame_scratch_buffer : frame_scratch_buffers) {
    frame_scratch_buffer_pool.Release(std::move(frame_scratch_buffer));
  }
}
279
280 } // namespace
281 } // namespace libgav1
282