1 // Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
2 // This source code is licensed under both the GPLv2 (found in the
3 // COPYING file in the root directory) and Apache 2.0 License
4 // (found in the LICENSE.Apache file in the root directory).
5
6 #ifndef ROCKSDB_LITE
7
8 #include <atomic>
9 #include <functional>
10 #include <string>
11 #include <utility>
12 #include <vector>
13
14 #include "db/db_impl/db_impl.h"
15 #include "db/write_callback.h"
16 #include "port/port.h"
17 #include "rocksdb/db.h"
18 #include "rocksdb/write_batch.h"
19 #include "test_util/sync_point.h"
20 #include "test_util/testharness.h"
21 #include "util/random.h"
22
23 using std::string;
24
25 namespace rocksdb {
26
27 class WriteCallbackTest : public testing::Test {
28 public:
29 string dbname;
30
WriteCallbackTest()31 WriteCallbackTest() {
32 dbname = test::PerThreadDBPath("write_callback_testdb");
33 }
34 };
35
36 class WriteCallbackTestWriteCallback1 : public WriteCallback {
37 public:
38 bool was_called = false;
39
Callback(DB * db)40 Status Callback(DB *db) override {
41 was_called = true;
42
43 // Make sure db is a DBImpl
44 DBImpl* db_impl = dynamic_cast<DBImpl*> (db);
45 if (db_impl == nullptr) {
46 return Status::InvalidArgument("");
47 }
48
49 return Status::OK();
50 }
51
AllowWriteBatching()52 bool AllowWriteBatching() override { return true; }
53 };
54
55 class WriteCallbackTestWriteCallback2 : public WriteCallback {
56 public:
Callback(DB *)57 Status Callback(DB* /*db*/) override { return Status::Busy(); }
AllowWriteBatching()58 bool AllowWriteBatching() override { return true; }
59 };
60
61 class MockWriteCallback : public WriteCallback {
62 public:
63 bool should_fail_ = false;
64 bool allow_batching_ = false;
65 std::atomic<bool> was_called_{false};
66
MockWriteCallback()67 MockWriteCallback() {}
68
MockWriteCallback(const MockWriteCallback & other)69 MockWriteCallback(const MockWriteCallback& other) {
70 should_fail_ = other.should_fail_;
71 allow_batching_ = other.allow_batching_;
72 was_called_.store(other.was_called_.load());
73 }
74
Callback(DB *)75 Status Callback(DB* /*db*/) override {
76 was_called_.store(true);
77 if (should_fail_) {
78 return Status::Busy();
79 } else {
80 return Status::OK();
81 }
82 }
83
AllowWriteBatching()84 bool AllowWriteBatching() override { return allow_batching_; }
85 };
86
// Writes through DBImpl with a WriteCallback attached, for every supported
// combination of the write-path options below and for several multi-threaded
// scenarios.
//
// Each scenario starts one thread per WriteOP. SyncPoint callbacks force the
// writers to link into the write queue one at a time, so the first writer to
// link can be identified as the group leader. The test then checks writer
// states at the sync points, the status each write returns, which keys are
// readable afterwards, and the last visible sequence number.
TEST_F(WriteCallbackTest, WriteWithCallbackTest) {
  // One writer's batch, the callback attached to it, and a copy of the
  // key/value pairs placed in the batch (used for verification afterwards).
  struct WriteOP {
    // Deliberately not `explicit`: the scenario table below is written as
    // lists of bools that convert to WriteOP.
    WriteOP(bool should_fail = false) { callback_.should_fail_ = should_fail; }

    // Adds the pair to the batch and remembers it for later verification.
    void Put(const std::string& key, const std::string& val) {
      kvs_.emplace_back(key, val);
      write_batch_.Put(key, val);
    }

    // Resets the op so it can be reused by the next option combination.
    // `should_fail_` is intentionally left untouched.
    void Clear() {
      kvs_.clear();
      write_batch_.Clear();
      callback_.was_called_.store(false);
    }

    MockWriteCallback callback_;
    WriteBatch write_batch_;
    std::vector<std::pair<std::string, std::string>> kvs_;
  };

  // In each scenario we'll launch multiple threads to write.
  // The size of each array equals the number of threads, and each boolean
  // initializes the corresponding WriteOP's `should_fail_`, i.e. `true`
  // means that thread's callback returns Busy and its write must fail.
  std::vector<std::vector<WriteOP>> write_scenarios = {
      {true},
      {false},
      {false, false},
      {true, true},
      {true, false},
      {false, true},
      {false, false, false},
      {true, true, true},
      {false, true, false},
      {true, false, true},
      {true, false, false, false, false},
      {false, false, false, false, true},
      {false, false, true, false, true},
  };

  for (const bool unordered_write : {true, false}) {
    for (const bool seq_per_batch : {true, false}) {
      for (const bool two_queues : {true, false}) {
        for (const bool allow_parallel : {true, false}) {
          for (const bool allow_batching : {true, false}) {
            for (const bool enable_WAL : {true, false}) {
              for (const bool enable_pipelined_write : {true, false}) {
                for (auto& write_group : write_scenarios) {
                  Options options;
                  options.create_if_missing = true;
                  options.unordered_write = unordered_write;
                  options.allow_concurrent_memtable_write = allow_parallel;
                  options.enable_pipelined_write = enable_pipelined_write;
                  options.two_write_queues = two_queues;
                  // Skip unsupported combinations
                  if (options.enable_pipelined_write && seq_per_batch) {
                    continue;
                  }
                  if (options.enable_pipelined_write &&
                      options.two_write_queues) {
                    continue;
                  }
                  if (options.unordered_write &&
                      !options.allow_concurrent_memtable_write) {
                    continue;
                  }
                  if (options.unordered_write &&
                      options.enable_pipelined_write) {
                    continue;
                  }

                  ReadOptions read_options;
                  DB* db = nullptr;
                  DBImpl* db_impl = nullptr;

                  // Start from a clean slate; a failure here would make the
                  // sequence-number checks below meaningless.
                  ASSERT_OK(DestroyDB(dbname, options));

                  DBOptions db_options(options);
                  ColumnFamilyOptions cf_options(options);
                  std::vector<ColumnFamilyDescriptor> column_families;
                  column_families.push_back(ColumnFamilyDescriptor(
                      kDefaultColumnFamilyName, cf_options));
                  std::vector<ColumnFamilyHandle*> handles;
                  auto open_s = DBImpl::Open(db_options, dbname,
                                             column_families, &handles, &db,
                                             seq_per_batch,
                                             true /* batch_per_txn */);
                  ASSERT_OK(open_s);
                  // Checked with a test assertion (not assert()) so that the
                  // handles[0] access stays guarded in NDEBUG builds.
                  ASSERT_EQ(handles.size(), 1U);
                  delete handles[0];

                  db_impl = dynamic_cast<DBImpl*>(db);
                  ASSERT_TRUE(db_impl);

                  // Writers that have called JoinBatchGroup.
                  std::atomic<uint64_t> threads_joining(0);
                  // Writers that have linked to the queue
                  std::atomic<uint64_t> threads_linked(0);
                  // Writers that pass WriteThread::JoinBatchGroup:Wait
                  // sync-point.
                  std::atomic<uint64_t> threads_verified(0);

                  // Expected last visible sequence number; the writer threads
                  // advance it for every write that is expected to succeed.
                  std::atomic<uint64_t> seq(
                      db_impl->GetLatestSequenceNumber());
                  ASSERT_EQ(db_impl->GetLatestSequenceNumber(), 0);

                  rocksdb::SyncPoint::GetInstance()->SetCallBack(
                      "WriteThread::JoinBatchGroup:Start", [&](void*) {
                        const uint64_t cur_threads_joining =
                            threads_joining.fetch_add(1);
                        // Wait for the last joined writer to link to the
                        // queue. In this way the writers link to the queue
                        // one by one. This allows us to confidently detect
                        // the first writer who increases threads_linked as
                        // the leader.
                        while (threads_linked.load() < cur_threads_joining) {
                        }
                      });

                  // Verification once writers call JoinBatchGroup.
                  rocksdb::SyncPoint::GetInstance()->SetCallBack(
                      "WriteThread::JoinBatchGroup:Wait", [&](void* arg) {
                        const uint64_t cur_threads_linked =
                            threads_linked.fetch_add(1);

                        // who am i
                        const bool is_leader = (cur_threads_linked == 0);
                        const bool is_last =
                            (cur_threads_linked == write_group.size() - 1);

                        // check my state
                        auto* writer =
                            static_cast<WriteThread::Writer*>(arg);

                        if (is_leader) {
                          ASSERT_TRUE(
                              writer->state ==
                              WriteThread::State::STATE_GROUP_LEADER);
                        } else {
                          ASSERT_TRUE(writer->state ==
                                      WriteThread::State::STATE_INIT);
                        }

                        // (meta test) the first WriteOP should indeed be the
                        // first and the last should be the last (all others
                        // can be out of order)
                        if (is_leader) {
                          ASSERT_TRUE(
                              writer->callback->Callback(nullptr).ok() ==
                              !write_group.front().callback_.should_fail_);
                        } else if (is_last) {
                          ASSERT_TRUE(
                              writer->callback->Callback(nullptr).ok() ==
                              !write_group.back().callback_.should_fail_);
                        }

                        threads_verified.fetch_add(1);
                        // Wait here until all verification in this sync-point
                        // callback finish for all writers.
                        while (threads_verified.load() <
                               write_group.size()) {
                        }
                      });

                  rocksdb::SyncPoint::GetInstance()->SetCallBack(
                      "WriteThread::JoinBatchGroup:DoneWaiting",
                      [&](void* arg) {
                        // check my state
                        auto* writer =
                            static_cast<WriteThread::Writer*>(arg);

                        if (!allow_batching) {
                          // no batching so everyone should be a leader
                          ASSERT_TRUE(
                              writer->state ==
                              WriteThread::State::STATE_GROUP_LEADER);
                        } else if (!allow_parallel) {
                          ASSERT_TRUE(
                              writer->state ==
                                  WriteThread::State::STATE_COMPLETED ||
                              (enable_pipelined_write &&
                               writer->state ==
                                   WriteThread::State::
                                       STATE_MEMTABLE_WRITER_LEADER));
                        }
                      });

                  // Hands out the index of the WriteOP each thread drives.
                  std::atomic<uint32_t> thread_num(0);
                  // Source of key bytes; every Put takes the next value.
                  std::atomic<char> dummy_key(0);

                  // Each write thread create a random write batch and write
                  // to DB with a write callback.
                  std::function<void()> write_with_callback_func = [&]() {
                    const uint32_t i = thread_num.fetch_add(1);
                    Random rnd(i);

                    // leaders gotta lead
                    while (i > 0 && threads_verified.load() < 1) {
                    }

                    // loser has to lose
                    while (i == write_group.size() - 1 &&
                           threads_verified.load() <
                               write_group.size() - 1) {
                    }

                    auto& write_op = write_group.at(i);
                    write_op.Clear();
                    write_op.callback_.allow_batching_ = allow_batching;

                    // insert some keys
                    // NOTE: the bound is re-drawn from `rnd` on every
                    // iteration; the loop stops the first time the fresh
                    // draw is not above `j`.
                    for (uint32_t j = 0; j < rnd.Next() % 50; j++) {
                      // grab unique key
                      const char my_key = dummy_key.fetch_add(1);

                      const std::string skey(5, my_key);
                      const std::string sval(10, my_key);
                      write_op.Put(skey, sval);

                      // Without seq_per_batch every key of a successful
                      // write is expected to take one sequence number.
                      if (!write_op.callback_.should_fail_ &&
                          !seq_per_batch) {
                        seq.fetch_add(1);
                      }
                    }
                    // With seq_per_batch a successful write is expected to
                    // take exactly one sequence number.
                    if (!write_op.callback_.should_fail_ && seq_per_batch) {
                      seq.fetch_add(1);
                    }

                    WriteOptions woptions;
                    woptions.disableWAL = !enable_WAL;
                    woptions.sync = enable_WAL;
                    Status s;
                    if (seq_per_batch) {
                      // Publishes the last sequence number handed to it by
                      // the write path.
                      class PublishSeqCallback : public PreReleaseCallback {
                       public:
                        explicit PublishSeqCallback(DBImpl* db_impl_in)
                            : db_impl_(db_impl_in) {}
                        Status Callback(SequenceNumber last_seq,
                                        bool /*not used*/, uint64_t,
                                        size_t /*index*/,
                                        size_t /*total*/) override {
                          db_impl_->SetLastPublishedSequence(last_seq);
                          return Status::OK();
                        }
                        DBImpl* db_impl_;
                      } publish_seq_callback(db_impl);
                      // seq_per_batch requires a natural batch separator or
                      // Noop
                      WriteBatchInternal::InsertNoop(&write_op.write_batch_);
                      const size_t ONE_BATCH = 1;
                      s = db_impl->WriteImpl(
                          woptions, &write_op.write_batch_,
                          &write_op.callback_, nullptr, 0, false, nullptr,
                          ONE_BATCH,
                          two_queues ? &publish_seq_callback : nullptr);
                    } else {
                      s = db_impl->WriteWithCallback(woptions,
                                                     &write_op.write_batch_,
                                                     &write_op.callback_);
                    }

                    if (write_op.callback_.should_fail_) {
                      ASSERT_TRUE(s.IsBusy());
                    } else {
                      ASSERT_OK(s);
                    }
                  };

                  rocksdb::SyncPoint::GetInstance()->EnableProcessing();

                  // do all the writes
                  std::vector<port::Thread> threads;
                  for (uint32_t i = 0; i < write_group.size(); i++) {
                    threads.emplace_back(write_with_callback_func);
                  }
                  for (auto& t : threads) {
                    t.join();
                  }

                  rocksdb::SyncPoint::GetInstance()->DisableProcessing();
                  // The callbacks capture this iteration's locals by
                  // reference; drop them so nothing dangling stays
                  // registered once those locals go out of scope.
                  rocksdb::SyncPoint::GetInstance()->ClearAllCallBacks();

                  // check for keys
                  std::string value;
                  for (auto& w : write_group) {
                    ASSERT_TRUE(w.callback_.was_called_.load());
                    for (auto& kvp : w.kvs_) {
                      if (w.callback_.should_fail_) {
                        ASSERT_TRUE(db->Get(read_options, kvp.first, &value)
                                        .IsNotFound());
                      } else {
                        ASSERT_OK(db->Get(read_options, kvp.first, &value));
                        ASSERT_EQ(value, kvp.second);
                      }
                    }
                  }

                  ASSERT_EQ(seq.load(),
                            db_impl->TEST_GetLastVisibleSequence());

                  delete db;
                  ASSERT_OK(DestroyDB(dbname, options));
                }
              }
            }
          }
        }
      }
    }
  }
}
374
// Single-threaded smoke test for WriteWithCallback:
//   1. a plain Write succeeds and is readable;
//   2. a write whose callback returns OK is applied and the callback runs;
//   3. a write whose callback returns Busy is rejected and leaves the
//      previously written value in place.
TEST_F(WriteCallbackTest, WriteCallBackTest) {
  Options options;
  WriteOptions write_options;
  ReadOptions read_options;
  std::string value;
  DB* db = nullptr;

  // Remove anything left under dbname; the status is checked rather than
  // silently dropped.
  ASSERT_OK(DestroyDB(dbname, options));

  options.create_if_missing = true;
  Status s = DB::Open(options, dbname, &db);
  ASSERT_OK(s);

  // WriteWithCallback is called on DBImpl, so the opened DB must be one.
  DBImpl* db_impl = dynamic_cast<DBImpl*>(db);
  ASSERT_TRUE(db_impl);

  WriteBatch wb;

  wb.Put("a", "value.a");
  wb.Delete("x");

  // Test a simple Write
  s = db->Write(write_options, &wb);
  ASSERT_OK(s);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a", value);

  // Test WriteWithCallback
  WriteCallbackTestWriteCallback1 callback1;
  WriteBatch wb2;

  wb2.Put("a", "value.a2");

  s = db_impl->WriteWithCallback(write_options, &wb2, &callback1);
  ASSERT_OK(s);
  ASSERT_TRUE(callback1.was_called);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  // Test WriteWithCallback for a callback that fails
  WriteCallbackTestWriteCallback2 callback2;
  WriteBatch wb3;

  wb3.Put("a", "value.a3");

  s = db_impl->WriteWithCallback(write_options, &wb3, &callback2);
  // callback2 returns Status::Busy(); require that exact status, as the
  // multi-threaded test in this file does for failing callbacks.
  ASSERT_TRUE(s.IsBusy());

  // The rejected batch must not have changed the stored value.
  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  delete db;
  ASSERT_OK(DestroyDB(dbname, options));
}
435
436 } // namespace rocksdb
437
main(int argc,char ** argv)438 int main(int argc, char** argv) {
439 ::testing::InitGoogleTest(&argc, argv);
440 return RUN_ALL_TESTS();
441 }
442
443 #else
444 #include <stdio.h>
445
// ROCKSDB_LITE build: there is nothing to run, so print a notice to stderr
// and report success.
int main(int /*argc*/, char** /*argv*/) {
  const char* const skip_notice =
      "SKIPPED as WriteWithCallback is not supported in ROCKSDB_LITE\n";
  fprintf(stderr, "%s", skip_notice);
  return 0;
}
451
452 #endif // !ROCKSDB_LITE
453