1 // Copyright 2021 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/threading_strategy.h"
16 
17 #include <memory>
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/strings/str_cat.h"
22 #include "gtest/gtest.h"
23 #include "src/frame_scratch_buffer.h"
24 #include "src/obu_parser.h"
25 #include "src/utils/constants.h"
26 #include "src/utils/threadpool.h"
27 #include "src/utils/types.h"
28 
29 namespace libgav1 {
30 namespace {
31 
32 class ThreadingStrategyTest : public testing::Test {
33  protected:
34   ThreadingStrategy strategy_;
35   ObuFrameHeader frame_header_ = {};
36 };
37 
// With 32 tiles and a thread count of 32, a tile thread pool and a post filter
// thread pool are expected, but no row thread pool for any of the 32 tiles.
TEST_F(ThreadingStrategyTest, MaxThreadEnforced) {
  constexpr int kTileCount = 32;
  constexpr int kThreadCount = 32;
  frame_header_.tile_info.tile_count = kTileCount;
  ASSERT_TRUE(strategy_.Reset(frame_header_, kThreadCount));

  EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
  for (int tile = 0; tile < kTileCount; ++tile) {
    EXPECT_EQ(strategy_.row_thread_pool(tile), nullptr);
  }
  EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
}
47 
// With 8 tiles and a thread count of 8, a tile thread pool and a post filter
// thread pool are expected, and none of the 8 tiles gets a row thread pool.
TEST_F(ThreadingStrategyTest, UseAllThreadsForTiles) {
  constexpr int kTileCount = 8;
  constexpr int kThreadCount = 8;
  frame_header_.tile_info.tile_count = kTileCount;
  ASSERT_TRUE(strategy_.Reset(frame_header_, kThreadCount));

  EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
  for (int tile = 0; tile < kTileCount; ++tile) {
    EXPECT_EQ(strategy_.row_thread_pool(tile), nullptr);
  }
  EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
}
57 
// With 2 tiles and a thread count of 8, tile, row and post filter thread pools
// are all expected to be available.
TEST_F(ThreadingStrategyTest, RowThreads) {
  constexpr int kTileCount = 2;
  constexpr int kThreadCount = 8;
  frame_header_.tile_info.tile_count = kTileCount;
  ASSERT_TRUE(strategy_.Reset(frame_header_, kThreadCount));

  EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
  // Each tile should get 3 threads each, so both tiles have a row thread pool.
  for (int tile = 0; tile < kTileCount; ++tile) {
    EXPECT_NE(strategy_.row_thread_pool(tile), nullptr);
  }
  EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
}
68 
// With 2 tiles and a thread count of 9, both tiles are still expected to have
// a row thread pool, alongside the tile and post filter thread pools.
TEST_F(ThreadingStrategyTest, RowThreadsUnequal) {
  constexpr int kTileCount = 2;
  constexpr int kThreadCount = 9;
  frame_header_.tile_info.tile_count = kTileCount;
  ASSERT_TRUE(strategy_.Reset(frame_header_, kThreadCount));

  EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
  for (int tile = 0; tile < kTileCount; ++tile) {
    EXPECT_NE(strategy_.row_thread_pool(tile), nullptr);
  }
  EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
}
78 
79 // Test a random combination of tile_count and thread_count.
TEST_F(ThreadingStrategyTest,MultipleCalls)80 TEST_F(ThreadingStrategyTest, MultipleCalls) {
81   frame_header_.tile_info.tile_count = 2;
82   ASSERT_TRUE(strategy_.Reset(frame_header_, 8));
83   EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
84   for (int i = 0; i < 2; ++i) {
85     EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
86   }
87   EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
88 
89   frame_header_.tile_info.tile_count = 8;
90   ASSERT_TRUE(strategy_.Reset(frame_header_, 8));
91   EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
92   // Row threads must have been reset.
93   for (int i = 0; i < 8; ++i) {
94     EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
95   }
96   EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
97 
98   frame_header_.tile_info.tile_count = 8;
99   ASSERT_TRUE(strategy_.Reset(frame_header_, 16));
100   EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
101   for (int i = 0; i < 8; ++i) {
102     EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
103   }
104   EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
105 
106   frame_header_.tile_info.tile_count = 4;
107   ASSERT_TRUE(strategy_.Reset(frame_header_, 16));
108   EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
109   for (int i = 0; i < 4; ++i) {
110     EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
111   }
112   // All the other row threads must be reset.
113   for (int i = 4; i < 8; ++i) {
114     EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
115   }
116   EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
117 
118   frame_header_.tile_info.tile_count = 4;
119   ASSERT_TRUE(strategy_.Reset(frame_header_, 6));
120   EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
121   // First two tiles will get 1 thread each.
122   for (int i = 0; i < 2; ++i) {
123     EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
124   }
125   // All the other row threads must be reset.
126   for (int i = 2; i < 8; ++i) {
127     EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
128   }
129   EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
130 
131   ASSERT_TRUE(strategy_.Reset(frame_header_, 1));
132   EXPECT_EQ(strategy_.tile_thread_pool(), nullptr);
133   for (int i = 0; i < 8; ++i) {
134     EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
135   }
136   EXPECT_EQ(strategy_.post_filter_thread_pool(), nullptr);
137 }
138 
139 // Tests the following order of calls (with thread count fixed at 4):
140 //  * 1 Tile - 2 Tiles - 1 Tile.
TEST_F(ThreadingStrategyTest,MultipleCalls2)141 TEST_F(ThreadingStrategyTest, MultipleCalls2) {
142   frame_header_.tile_info.tile_count = 1;
143   ASSERT_TRUE(strategy_.Reset(frame_header_, 4));
144   // When there is only one tile, tile thread pool must be nullptr.
145   EXPECT_EQ(strategy_.tile_thread_pool(), nullptr);
146   EXPECT_NE(strategy_.row_thread_pool(0), nullptr);
147   for (int i = 1; i < 8; ++i) {
148     EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
149   }
150   EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
151 
152   frame_header_.tile_info.tile_count = 2;
153   ASSERT_TRUE(strategy_.Reset(frame_header_, 4));
154   EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
155   for (int i = 0; i < 2; ++i) {
156     EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
157   }
158   for (int i = 2; i < 8; ++i) {
159     EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
160   }
161   EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
162 
163   frame_header_.tile_info.tile_count = 1;
164   ASSERT_TRUE(strategy_.Reset(frame_header_, 4));
165   EXPECT_EQ(strategy_.tile_thread_pool(), nullptr);
166   EXPECT_NE(strategy_.row_thread_pool(0), nullptr);
167   for (int i = 1; i < 8; ++i) {
168     EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
169   }
170   EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
171 }
172 
VerifyFrameParallel(int thread_count,int tile_count,int tile_columns,int expected_frame_threads,const std::vector<int> & expected_tile_threads)173 void VerifyFrameParallel(int thread_count, int tile_count, int tile_columns,
174                          int expected_frame_threads,
175                          const std::vector<int>& expected_tile_threads) {
176   ASSERT_EQ(expected_frame_threads, expected_tile_threads.size());
177   ASSERT_GT(thread_count, 1);
178   std::unique_ptr<ThreadPool> frame_thread_pool;
179   FrameScratchBufferPool frame_scratch_buffer_pool;
180   ASSERT_TRUE(InitializeThreadPoolsForFrameParallel(
181       thread_count, tile_count, tile_columns, &frame_thread_pool,
182       &frame_scratch_buffer_pool));
183   if (expected_frame_threads == 0) {
184     EXPECT_EQ(frame_thread_pool, nullptr);
185     return;
186   }
187   EXPECT_NE(frame_thread_pool.get(), nullptr);
188   EXPECT_EQ(frame_thread_pool->num_threads(), expected_frame_threads);
189   std::vector<std::unique_ptr<FrameScratchBuffer>> frame_scratch_buffers;
190   int actual_thread_count = frame_thread_pool->num_threads();
191   for (int i = 0; i < expected_frame_threads; ++i) {
192     SCOPED_TRACE(absl::StrCat("i: ", i));
193     frame_scratch_buffers.push_back(frame_scratch_buffer_pool.Get());
194     ThreadPool* const thread_pool =
195         frame_scratch_buffers.back()->threading_strategy.thread_pool();
196     if (expected_tile_threads[i] > 0) {
197       EXPECT_NE(thread_pool, nullptr);
198       EXPECT_EQ(thread_pool->num_threads(), expected_tile_threads[i]);
199       actual_thread_count += thread_pool->num_threads();
200     } else {
201       EXPECT_EQ(thread_pool, nullptr);
202     }
203   }
204   EXPECT_EQ(thread_count, actual_thread_count);
205   for (auto& frame_scratch_buffer : frame_scratch_buffers) {
206     frame_scratch_buffer_pool.Release(std::move(frame_scratch_buffer));
207   }
208 }
209 
// Runs VerifyFrameParallel() over a fixed set of thread/tile configurations:
// first the ones that must not produce any frame threads, then the ones that
// split the threads into frame threads and per-frame tile threads.
TEST(FrameParallelStrategyTest, FrameParallel) {
  // This loop has thread_count <= 3 * tile count. So there should be no frame
  // threads irrespective of the number of tile columns.
  constexpr int kTileCount = 2;
  for (int thread_count = 2; thread_count <= 6; ++thread_count) {
    for (int tile_columns = 1; tile_columns <= 2; ++tile_columns) {
      VerifyFrameParallel(thread_count, kTileCount, tile_columns,
                          /*expected_frame_threads=*/0,
                          /*expected_tile_threads=*/{});
    }
  }

  // Parameters and expectations for one VerifyFrameParallel() call.
  struct FrameParallelCase {
    int thread_count;
    int tile_count;
    int tile_columns;
    int expected_frame_threads;
    std::vector<int> expected_tile_threads;
  };
  const FrameParallelCase test_cases[] = {
      // Equal number of tile threads for each frame thread.
      {8, 1, 1, 4, {1, 1, 1, 1}},
      {12, 2, 2, 4, {2, 2, 2, 2}},
      {18, 2, 2, 6, {2, 2, 2, 2, 2, 2}},
      {16, 3, 3, 4, {3, 3, 3, 3}},
      // Unequal number of tile threads for each frame thread.
      {7, 1, 1, 3, {2, 1, 1}},
      {14, 2, 2, 4, {3, 3, 2, 2}},
      {20, 2, 2, 6, {3, 3, 2, 2, 2, 2}},
      {17, 3, 3, 4, {4, 3, 3, 3}},
  };
  for (const FrameParallelCase& test_case : test_cases) {
    VerifyFrameParallel(test_case.thread_count, test_case.tile_count,
                        test_case.tile_columns,
                        test_case.expected_frame_threads,
                        test_case.expected_tile_threads);
  }
}
252 
// Requests kMaxThreads + 10 threads and checks that the frame threads plus the
// threads of every frame scratch buffer's thread pool do not add up to more
// than kMaxThreads.
TEST(FrameParallelStrategyTest, ThreadCountDoesNotExceedkMaxThreads) {
  std::unique_ptr<ThreadPool> frame_thread_pool;
  FrameScratchBufferPool frame_scratch_buffer_pool;
  ASSERT_TRUE(InitializeThreadPoolsForFrameParallel(
      /*thread_count=*/kMaxThreads + 10, /*tile_count=*/2, /*tile_columns=*/2,
      &frame_thread_pool, &frame_scratch_buffer_pool));
  // Fatal: |frame_thread_pool| is dereferenced right below, so a null pool
  // must end the test instead of crashing it.
  ASSERT_NE(frame_thread_pool.get(), nullptr);
  std::vector<std::unique_ptr<FrameScratchBuffer>> frame_scratch_buffers;
  int actual_thread_count = frame_thread_pool->num_threads();
  for (int i = 0; i < frame_thread_pool->num_threads(); ++i) {
    SCOPED_TRACE(absl::StrCat("i: ", i));
    frame_scratch_buffers.push_back(frame_scratch_buffer_pool.Get());
    // Fatal: the buffer is dereferenced right below.
    ASSERT_NE(frame_scratch_buffers.back().get(), nullptr);
    ThreadPool* const thread_pool =
        frame_scratch_buffers.back()->threading_strategy.thread_pool();
    if (thread_pool != nullptr) {
      actual_thread_count += thread_pool->num_threads();
    }
  }
  // In this case, the exact number of frame threads and tile threads depend on
  // the value of kMaxThreads. So simply ensure that the total number of threads
  // does not exceed kMaxThreads.
  EXPECT_LE(actual_thread_count, kMaxThreads);
  // Hand the scratch buffers back to the pool they were taken from.
  for (auto& frame_scratch_buffer : frame_scratch_buffers) {
    frame_scratch_buffer_pool.Release(std::move(frame_scratch_buffer));
  }
}
279 
280 }  // namespace
281 }  // namespace libgav1
282