// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).

#ifndef ROCKSDB_LITE

#include <atomic>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "db/db_impl/db_impl.h"
#include "db/write_callback.h"
#include "port/port.h"
#include "rocksdb/db.h"
#include "rocksdb/write_batch.h"
#include "test_util/sync_point.h"
#include "test_util/testharness.h"
#include "util/random.h"

using std::string;

namespace ROCKSDB_NAMESPACE {

27 class WriteCallbackTest : public testing::Test {
28 public:
29 string dbname;
30
WriteCallbackTest()31 WriteCallbackTest() {
32 dbname = test::PerThreadDBPath("write_callback_testdb");
33 }
34 };
35
36 class WriteCallbackTestWriteCallback1 : public WriteCallback {
37 public:
38 bool was_called = false;
39
Callback(DB * db)40 Status Callback(DB *db) override {
41 was_called = true;
42
43 // Make sure db is a DBImpl
44 DBImpl* db_impl = dynamic_cast<DBImpl*> (db);
45 if (db_impl == nullptr) {
46 return Status::InvalidArgument("");
47 }
48
49 return Status::OK();
50 }
51
AllowWriteBatching()52 bool AllowWriteBatching() override { return true; }
53 };
54
55 class WriteCallbackTestWriteCallback2 : public WriteCallback {
56 public:
Callback(DB *)57 Status Callback(DB* /*db*/) override { return Status::Busy(); }
AllowWriteBatching()58 bool AllowWriteBatching() override { return true; }
59 };
60
61 class MockWriteCallback : public WriteCallback {
62 public:
63 bool should_fail_ = false;
64 bool allow_batching_ = false;
65 std::atomic<bool> was_called_{false};
66
MockWriteCallback()67 MockWriteCallback() {}
68
MockWriteCallback(const MockWriteCallback & other)69 MockWriteCallback(const MockWriteCallback& other) {
70 should_fail_ = other.should_fail_;
71 allow_batching_ = other.allow_batching_;
72 was_called_.store(other.was_called_.load());
73 }
74
Callback(DB *)75 Status Callback(DB* /*db*/) override {
76 was_called_.store(true);
77 if (should_fail_) {
78 return Status::Busy();
79 } else {
80 return Status::OK();
81 }
82 }
83
AllowWriteBatching()84 bool AllowWriteBatching() override { return allow_batching_; }
85 };
86
#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
88 class WriteCallbackPTest
89 : public WriteCallbackTest,
90 public ::testing::WithParamInterface<
91 std::tuple<bool, bool, bool, bool, bool, bool, bool>> {
92 public:
WriteCallbackPTest()93 WriteCallbackPTest() {
94 std::tie(unordered_write_, seq_per_batch_, two_queues_, allow_parallel_,
95 allow_batching_, enable_WAL_, enable_pipelined_write_) =
96 GetParam();
97 }
98
99 protected:
100 bool unordered_write_;
101 bool seq_per_batch_;
102 bool two_queues_;
103 bool allow_parallel_;
104 bool allow_batching_;
105 bool enable_WAL_;
106 bool enable_pipelined_write_;
107 };
// Stress test: several writer threads join one write group; sync points pin
// down which writer becomes leader/last, and callbacks selectively fail
// individual writes. Afterwards the test verifies that failed writes left no
// data behind and that the sequence number advanced exactly as expected.
TEST_P(WriteCallbackPTest, WriteWithCallbackTest) {
  // One writer thread's unit of work: the batch to write, the callback that
  // decides whether the write goes through, and the kv pairs staged so the
  // outcome can be verified. The ctor is deliberately implicit so the
  // scenario table below can be spelled as lists of bools.
  struct WriteOP {
    WriteOP(bool should_fail = false) { callback_.should_fail_ = should_fail; }

    // Record the pair locally and stage it in the write batch.
    void Put(const string& key, const string& val) {
      kvs_.push_back(std::make_pair(key, val));
      ASSERT_OK(write_batch_.Put(key, val));
    }

    // Reset per-run state so the op can be reused across scenarios.
    void Clear() {
      kvs_.clear();
      write_batch_.Clear();
      callback_.was_called_.store(false);
    }

    MockWriteCallback callback_;
    WriteBatch write_batch_;
    std::vector<std::pair<string, string>> kvs_;
  };

  // In each scenario we'll launch multiple threads to write.
  // The size of each array equals to number of threads, and
  // each boolean in it denote whether callback of corresponding
  // thread should succeed or fail.
  std::vector<std::vector<WriteOP>> write_scenarios = {
      {true},
      {false},
      {false, false},
      {true, true},
      {true, false},
      {false, true},
      {false, false, false},
      {true, true, true},
      {false, true, false},
      {true, false, true},
      {true, false, false, false, false},
      {false, false, false, false, true},
      {false, false, true, false, true},
  };

  for (auto& write_group : write_scenarios) {
    Options options;
    options.create_if_missing = true;
    options.unordered_write = unordered_write_;
    options.allow_concurrent_memtable_write = allow_parallel_;
    options.enable_pipelined_write = enable_pipelined_write_;
    options.two_write_queues = two_queues_;
    // Skip unsupported combinations
    if (options.enable_pipelined_write && seq_per_batch_) {
      continue;
    }
    if (options.enable_pipelined_write && options.two_write_queues) {
      continue;
    }
    if (options.unordered_write && !options.allow_concurrent_memtable_write) {
      continue;
    }
    if (options.unordered_write && options.enable_pipelined_write) {
      continue;
    }

    ReadOptions read_options;
    DB* db;
    DBImpl* db_impl;

    DestroyDB(dbname, options);

    // Open via DBImpl::Open (rather than DB::Open) so seq_per_batch_ can be
    // passed explicitly.
    DBOptions db_options(options);
    ColumnFamilyOptions cf_options(options);
    std::vector<ColumnFamilyDescriptor> column_families;
    column_families.push_back(
        ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
    std::vector<ColumnFamilyHandle*> handles;
    auto open_s = DBImpl::Open(db_options, dbname, column_families, &handles,
                               &db, seq_per_batch_, true /* batch_per_txn */);
    ASSERT_OK(open_s);
    assert(handles.size() == 1);
    delete handles[0];

    db_impl = dynamic_cast<DBImpl*>(db);
    ASSERT_TRUE(db_impl);

    // Writers that have called JoinBatchGroup.
    std::atomic<uint64_t> threads_joining(0);
    // Writers that have linked to the queue
    std::atomic<uint64_t> threads_linked(0);
    // Writers that pass WriteThread::JoinBatchGroup:Wait sync-point.
    std::atomic<uint64_t> threads_verified(0);

    // Expected final sequence number; bumped below for each successful write
    // and checked against the DB at the end of the scenario.
    std::atomic<uint64_t> seq(db_impl->GetLatestSequenceNumber());
    ASSERT_EQ(db_impl->GetLatestSequenceNumber(), 0);

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:Start", [&](void*) {
          uint64_t cur_threads_joining = threads_joining.fetch_add(1);
          // Wait for the last joined writer to link to the queue.
          // In this way the writers link to the queue one by one.
          // This allows us to confidently detect the first writer
          // who increases threads_linked as the leader.
          while (threads_linked.load() < cur_threads_joining) {
          }
        });

    // Verification once writers call JoinBatchGroup.
    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:Wait", [&](void* arg) {
          uint64_t cur_threads_linked = threads_linked.fetch_add(1);
          bool is_leader = false;
          bool is_last = false;

          // who am i
          is_leader = (cur_threads_linked == 0);
          is_last = (cur_threads_linked == write_group.size() - 1);

          // check my state
          auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);

          if (is_leader) {
            ASSERT_TRUE(writer->state ==
                        WriteThread::State::STATE_GROUP_LEADER);
          } else {
            ASSERT_TRUE(writer->state == WriteThread::State::STATE_INIT);
          }

          // (meta test) the first WriteOP should indeed be the first
          // and the last should be the last (all others can be out of
          // order)
          if (is_leader) {
            ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
                        !write_group.front().callback_.should_fail_);
          } else if (is_last) {
            ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
                        !write_group.back().callback_.should_fail_);
          }

          threads_verified.fetch_add(1);
          // Wait here until all verification in this sync-point
          // callback finish for all writers.
          while (threads_verified.load() < write_group.size()) {
          }
        });

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:DoneWaiting", [&](void* arg) {
          // check my state
          auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);

          if (!allow_batching_) {
            // no batching so everyone should be a leader
            ASSERT_TRUE(writer->state ==
                        WriteThread::State::STATE_GROUP_LEADER);
          } else if (!allow_parallel_) {
            ASSERT_TRUE(writer->state == WriteThread::State::STATE_COMPLETED ||
                        (enable_pipelined_write_ &&
                         writer->state ==
                             WriteThread::State::STATE_MEMTABLE_WRITER_LEADER));
          }
        });

    std::atomic<uint32_t> thread_num(0);
    std::atomic<char> dummy_key(0);

    // Each write thread create a random write batch and write to DB
    // with a write callback.
    std::function<void()> write_with_callback_func = [&]() {
      uint32_t i = thread_num.fetch_add(1);
      Random rnd(i);

      // leaders gotta lead
      while (i > 0 && threads_verified.load() < 1) {
      }

      // loser has to lose
      while (i == write_group.size() - 1 &&
             threads_verified.load() < write_group.size() - 1) {
      }

      auto& write_op = write_group.at(i);
      write_op.Clear();
      write_op.callback_.allow_batching_ = allow_batching_;

      // insert some keys
      for (uint32_t j = 0; j < rnd.Next() % 50; j++) {
        // grab unique key
        char my_key = dummy_key.fetch_add(1);

        string skey(5, my_key);
        string sval(10, my_key);
        write_op.Put(skey, sval);

        // In non-seq_per_batch mode every key of a successful write consumes
        // one sequence number.
        if (!write_op.callback_.should_fail_ && !seq_per_batch_) {
          seq.fetch_add(1);
        }
      }
      // In seq_per_batch mode the whole batch consumes a single sequence
      // number, regardless of how many keys it holds.
      if (!write_op.callback_.should_fail_ && seq_per_batch_) {
        seq.fetch_add(1);
      }

      WriteOptions woptions;
      woptions.disableWAL = !enable_WAL_;
      woptions.sync = enable_WAL_;
      Status s;
      if (seq_per_batch_) {
        // With two queues the published sequence must be advanced manually
        // after the write; this pre-release callback does so.
        class PublishSeqCallback : public PreReleaseCallback {
         public:
          PublishSeqCallback(DBImpl* db_impl_in) : db_impl_(db_impl_in) {}
          Status Callback(SequenceNumber last_seq, bool /*not used*/, uint64_t,
                          size_t /*index*/, size_t /*total*/) override {
            db_impl_->SetLastPublishedSequence(last_seq);
            return Status::OK();
          }
          DBImpl* db_impl_;
        } publish_seq_callback(db_impl);
        // seq_per_batch_ requires a natural batch separator or Noop
        ASSERT_OK(WriteBatchInternal::InsertNoop(&write_op.write_batch_));
        const size_t ONE_BATCH = 1;
        s = db_impl->WriteImpl(woptions, &write_op.write_batch_,
                               &write_op.callback_, nullptr, 0, false, nullptr,
                               ONE_BATCH,
                               two_queues_ ? &publish_seq_callback : nullptr);
      } else {
        s = db_impl->WriteWithCallback(woptions, &write_op.write_batch_,
                                       &write_op.callback_);
      }

      // A failing callback surfaces as Status::Busy to the caller.
      if (write_op.callback_.should_fail_) {
        ASSERT_TRUE(s.IsBusy());
      } else {
        ASSERT_OK(s);
      }
    };

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();

    // do all the writes
    std::vector<port::Thread> threads;
    for (uint32_t i = 0; i < write_group.size(); i++) {
      threads.emplace_back(write_with_callback_func);
    }
    for (auto& t : threads) {
      t.join();
    }

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();

    // check for keys
    string value;
    for (auto& w : write_group) {
      // Every callback must have been invoked, even the failing ones.
      ASSERT_TRUE(w.callback_.was_called_.load());
      for (auto& kvp : w.kvs_) {
        if (w.callback_.should_fail_) {
          ASSERT_TRUE(db->Get(read_options, kvp.first, &value).IsNotFound());
        } else {
          ASSERT_OK(db->Get(read_options, kvp.first, &value));
          ASSERT_EQ(value, kvp.second);
        }
      }
    }

    ASSERT_EQ(seq.load(), db_impl->TEST_GetLastVisibleSequence());

    delete db;
    DestroyDB(dbname, options);
  }
}

// Run WriteWithCallbackTest over every combination of the seven boolean
// parameters: (unordered_write, seq_per_batch, two_queues, allow_parallel,
// allow_batching, enable_WAL, enable_pipelined_write) — 128 instantiations
// (unsupported combinations are skipped inside the test body).
INSTANTIATE_TEST_CASE_P(WriteCallbackPTest, WriteCallbackPTest,
                        ::testing::Combine(::testing::Bool(), ::testing::Bool(),
                                           ::testing::Bool(), ::testing::Bool(),
                                           ::testing::Bool(), ::testing::Bool(),
                                           ::testing::Bool()));
#endif  // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)

// Exercises the basic WriteWithCallback contract on a real DB:
//  - a plain Write still works,
//  - a callback returning OK lets the write through (and is invoked),
//  - a callback returning non-OK aborts the write, leaving prior data intact.
TEST_F(WriteCallbackTest, WriteCallBackTest) {
  Options options;
  WriteOptions write_options;
  ReadOptions read_options;
  string value;
  DB* db;
  DBImpl* db_impl;

  // Start from a clean slate. Check the status (the rest of the file asserts
  // on every Status); DestroyDB succeeds even if the DB does not exist yet.
  ASSERT_OK(DestroyDB(dbname, options));

  options.create_if_missing = true;
  Status s = DB::Open(options, dbname, &db);
  ASSERT_OK(s);

  // WriteWithCallback lives on DBImpl, not on the public DB interface.
  db_impl = dynamic_cast<DBImpl*>(db);
  ASSERT_TRUE(db_impl);

  WriteBatch wb;

  ASSERT_OK(wb.Put("a", "value.a"));
  ASSERT_OK(wb.Delete("x"));

  // Test a simple Write
  s = db->Write(write_options, &wb);
  ASSERT_OK(s);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a", value);

  // Test WriteWithCallback: callback succeeds, so the write must land.
  WriteCallbackTestWriteCallback1 callback1;
  WriteBatch wb2;

  ASSERT_OK(wb2.Put("a", "value.a2"));

  s = db_impl->WriteWithCallback(write_options, &wb2, &callback1);
  ASSERT_OK(s);
  ASSERT_TRUE(callback1.was_called);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  // Test WriteWithCallback for a callback that fails: the write must be
  // rejected and the previously written value must survive.
  WriteCallbackTestWriteCallback2 callback2;
  WriteBatch wb3;

  ASSERT_OK(wb3.Put("a", "value.a3"));

  s = db_impl->WriteWithCallback(write_options, &wb3, &callback2);
  ASSERT_NOK(s);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  delete db;
  ASSERT_OK(DestroyDB(dbname, options));
}

} // namespace ROCKSDB_NAMESPACE

// Standard gtest entry point: register command-line flags, then run every
// test (including all parameterized instantiations) in this binary.
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}

#else
#include <stdio.h>

// ROCKSDB_LITE builds compile out WriteWithCallback/DBImpl internals, so the
// test binary degrades to a stub that reports the skip and exits successfully.
int main(int /*argc*/, char** /*argv*/) {
  fprintf(stderr,
          "SKIPPED as WriteWithCallback is not supported in ROCKSDB_LITE\n");
  return 0;
}

#endif // !ROCKSDB_LITE
