//  Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
//  This source code is licensed under both the GPLv2 (found in the
//  COPYING file in the root directory) and Apache 2.0 License
//  (found in the LICENSE.Apache file in the root directory).

#ifndef ROCKSDB_LITE

#include <atomic>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "db/db_impl/db_impl.h"
#include "db/write_callback.h"
#include "port/port.h"
#include "rocksdb/db.h"
#include "rocksdb/write_batch.h"
#include "test_util/sync_point.h"
#include "test_util/testharness.h"
#include "util/random.h"

using std::string;

namespace ROCKSDB_NAMESPACE {

class WriteCallbackTest : public testing::Test {
 public:
  string dbname;

  WriteCallbackTest() {
    dbname = test::PerThreadDBPath("write_callback_testdb");
  }
};

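// Callback that verifies the DB handle is a DBImpl and then allows the write,
// recording that it ran.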
class WriteCallbackTestWriteCallback1 : public WriteCallback {
 public:
  bool was_called = false;

  Status Callback(DB* db) override {
    was_called = true;

    // Make sure db is a DBImpl.
    DBImpl* db_impl = dynamic_cast<DBImpl*>(db);
    if (db_impl == nullptr) {
      return Status::InvalidArgument("");
    }

    return Status::OK();
  }

  bool AllowWriteBatching() override { return true; }
};

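// Callback that rejects every write by returning Status::Busy().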
class WriteCallbackTestWriteCallback2 : public WriteCallback {
 public:
  Status Callback(DB* /*db*/) override { return Status::Busy(); }
  bool AllowWriteBatching() override { return true; }
};

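// Configurable callback used by the parameterized test below: it can be set
// to fail (return Status::Busy()) and to allow or disallow write batching,
// and it records atomically whether it was invoked.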
class MockWriteCallback : public WriteCallback {
 public:
  bool should_fail_ = false;
  bool allow_batching_ = false;
  std::atomic<bool> was_called_{false};

  MockWriteCallback() {}

  MockWriteCallback(const MockWriteCallback& other) {
    should_fail_ = other.should_fail_;
    allow_batching_ = other.allow_batching_;
    was_called_.store(other.was_called_.load());
  }

  Status Callback(DB* /*db*/) override {
    was_called_.store(true);
    if (should_fail_) {
      return Status::Busy();
    } else {
      return Status::OK();
    }
  }

  bool AllowWriteBatching() override { return allow_batching_; }
};

#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
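// Parameterized fixture covering all combinations of (unordered_write,
// seq_per_batch, two_queues, allow_parallel, allow_batching, enable_WAL,
// enable_pipelined_write); unsupported combinations are skipped in the test.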
class WriteCallbackPTest
    : public WriteCallbackTest,
      public ::testing::WithParamInterface<
          std::tuple<bool, bool, bool, bool, bool, bool, bool>> {
 public:
  WriteCallbackPTest() {
    std::tie(unordered_write_, seq_per_batch_, two_queues_, allow_parallel_,
             allow_batching_, enable_WAL_, enable_pipelined_write_) =
        GetParam();
  }

 protected:
  bool unordered_write_;
  bool seq_per_batch_;
  bool two_queues_;
  bool allow_parallel_;
  bool allow_batching_;
  bool enable_WAL_;
  bool enable_pipelined_write_;
};

TEST_P(WriteCallbackPTest, WriteWithCallbackTest) {
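  // Bundles one writer thread's state: its write batch, the key-value pairs
  // it expects to have written, and its callback.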
  struct WriteOP {
    WriteOP(bool should_fail = false) { callback_.should_fail_ = should_fail; }

    void Put(const string& key, const string& val) {
      kvs_.push_back(std::make_pair(key, val));
      ASSERT_OK(write_batch_.Put(key, val));
    }

    void Clear() {
      kvs_.clear();
      write_batch_.Clear();
      callback_.was_called_.store(false);
    }

    MockWriteCallback callback_;
    WriteBatch write_batch_;
    std::vector<std::pair<string, string>> kvs_;
  };

  // In each scenario we'll launch multiple threads to write.
  // The size of each array equals the number of threads, and each
  // boolean in it denotes whether the corresponding thread's callback
  // should fail (true) or succeed (false).
  std::vector<std::vector<WriteOP>> write_scenarios = {
      {true},
      {false},
      {false, false},
      {true, true},
      {true, false},
      {false, true},
      {false, false, false},
      {true, true, true},
      {false, true, false},
      {true, false, true},
      {true, false, false, false, false},
      {false, false, false, false, true},
      {false, false, true, false, true},
  };

  for (auto& write_group : write_scenarios) {
    Options options;
    options.create_if_missing = true;
    options.unordered_write = unordered_write_;
    options.allow_concurrent_memtable_write = allow_parallel_;
    options.enable_pipelined_write = enable_pipelined_write_;
    options.two_write_queues = two_queues_;
    // Skip unsupported combinations.
    if (options.enable_pipelined_write && seq_per_batch_) {
      continue;
    }
    if (options.enable_pipelined_write && options.two_write_queues) {
      continue;
    }
    if (options.unordered_write && !options.allow_concurrent_memtable_write) {
      continue;
    }
    if (options.unordered_write && options.enable_pipelined_write) {
      continue;
    }

    ReadOptions read_options;
    DB* db;
    DBImpl* db_impl;

    DestroyDB(dbname, options);

    DBOptions db_options(options);
    ColumnFamilyOptions cf_options(options);
    std::vector<ColumnFamilyDescriptor> column_families;
    column_families.push_back(
        ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
    std::vector<ColumnFamilyHandle*> handles;
    auto open_s = DBImpl::Open(db_options, dbname, column_families, &handles,
                               &db, seq_per_batch_, true /* batch_per_txn */);
    ASSERT_OK(open_s);
    assert(handles.size() == 1);
    delete handles[0];

    db_impl = dynamic_cast<DBImpl*>(db);
    ASSERT_TRUE(db_impl);

    // Writers that have called JoinBatchGroup.
    std::atomic<uint64_t> threads_joining(0);
    // Writers that have linked to the queue.
    std::atomic<uint64_t> threads_linked(0);
    // Writers that pass the WriteThread::JoinBatchGroup:Wait sync-point.
    std::atomic<uint64_t> threads_verified(0);

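    // Expected last visible sequence number; the writer threads below bump it
    // for each key (or each batch, with seq_per_batch_) they successfully
    // write, and it is checked against the DB at the end of the scenario.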
    std::atomic<uint64_t> seq(db_impl->GetLatestSequenceNumber());
    ASSERT_EQ(db_impl->GetLatestSequenceNumber(), 0);

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:Start", [&](void*) {
          uint64_t cur_threads_joining = threads_joining.fetch_add(1);
          // Wait for the last joined writer to link to the queue.
          // In this way the writers link to the queue one by one.
          // This allows us to confidently detect the first writer
          // who increases threads_linked as the leader.
          while (threads_linked.load() < cur_threads_joining) {
          }
        });

    // Verification once writers call JoinBatchGroup.
    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:Wait", [&](void* arg) {
          uint64_t cur_threads_linked = threads_linked.fetch_add(1);
          bool is_leader = false;
          bool is_last = false;

          // who am i
          is_leader = (cur_threads_linked == 0);
          is_last = (cur_threads_linked == write_group.size() - 1);

          // check my state
          auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);

          if (is_leader) {
            ASSERT_TRUE(writer->state ==
                        WriteThread::State::STATE_GROUP_LEADER);
          } else {
            ASSERT_TRUE(writer->state == WriteThread::State::STATE_INIT);
          }

          // (meta test) the first WriteOP should indeed be the first
          // and the last should be the last (all others can be out of
          // order)
          if (is_leader) {
            ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
                        !write_group.front().callback_.should_fail_);
          } else if (is_last) {
            ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
                        !write_group.back().callback_.should_fail_);
          }

          threads_verified.fetch_add(1);
          // Wait here until all writers have finished verification in this
          // sync-point callback.
          while (threads_verified.load() < write_group.size()) {
          }
        });

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:DoneWaiting", [&](void* arg) {
          // check my state
          auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);

          if (!allow_batching_) {
            // no batching so everyone should be a leader
            ASSERT_TRUE(writer->state ==
                        WriteThread::State::STATE_GROUP_LEADER);
          } else if (!allow_parallel_) {
            ASSERT_TRUE(writer->state == WriteThread::State::STATE_COMPLETED ||
                        (enable_pipelined_write_ &&
                         writer->state ==
                             WriteThread::State::STATE_MEMTABLE_WRITER_LEADER));
          }
        });

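    // thread_num hands each writer a unique index into write_group;
    // dummy_key hands out distinct key bytes so the writers never collide.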
    std::atomic<uint32_t> thread_num(0);
    std::atomic<char> dummy_key(0);

    // Each writer thread creates a random write batch and writes it to the
    // DB with a write callback.
    std::function<void()> write_with_callback_func = [&]() {
      uint32_t i = thread_num.fetch_add(1);
      Random rnd(i);

      // leaders gotta lead
      while (i > 0 && threads_verified.load() < 1) {
      }

      // loser has to lose
      while (i == write_group.size() - 1 &&
             threads_verified.load() < write_group.size() - 1) {
      }

      auto& write_op = write_group.at(i);
      write_op.Clear();
      write_op.callback_.allow_batching_ = allow_batching_;

      // insert some keys
      for (uint32_t j = 0; j < rnd.Next() % 50; j++) {
        // grab unique key
        char my_key = dummy_key.fetch_add(1);

        string skey(5, my_key);
        string sval(10, my_key);
        write_op.Put(skey, sval);

        if (!write_op.callback_.should_fail_ && !seq_per_batch_) {
          seq.fetch_add(1);
        }
      }
      if (!write_op.callback_.should_fail_ && seq_per_batch_) {
        seq.fetch_add(1);
      }

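      // With the WAL enabled we also sync it, so the test exercises the
      // synced-WAL write path; with it disabled, writes skip the WAL.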
      WriteOptions woptions;
      woptions.disableWAL = !enable_WAL_;
      woptions.sync = enable_WAL_;
      Status s;
      if (seq_per_batch_) {
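        // A pre-release callback that publishes the last sequence number; it
        // is passed to WriteImpl below only when two_queues_ is set.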
        class PublishSeqCallback : public PreReleaseCallback {
         public:
          PublishSeqCallback(DBImpl* db_impl_in) : db_impl_(db_impl_in) {}
          Status Callback(SequenceNumber last_seq, bool /*not used*/, uint64_t,
                          size_t /*index*/, size_t /*total*/) override {
            db_impl_->SetLastPublishedSequence(last_seq);
            return Status::OK();
          }
          DBImpl* db_impl_;
        } publish_seq_callback(db_impl);
        // seq_per_batch_ requires a natural batch separator or Noop.
        ASSERT_OK(WriteBatchInternal::InsertNoop(&write_op.write_batch_));
        const size_t ONE_BATCH = 1;
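        // Call WriteImpl directly so the test can pass the write callback, an
        // explicit batch count (ONE_BATCH), and, in two-queue mode, the
        // publishing pre-release callback.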
        s = db_impl->WriteImpl(woptions, &write_op.write_batch_,
                               &write_op.callback_, nullptr, 0, false, nullptr,
                               ONE_BATCH,
                               two_queues_ ? &publish_seq_callback : nullptr);
      } else {
        s = db_impl->WriteWithCallback(woptions, &write_op.write_batch_,
                                       &write_op.callback_);
      }

      if (write_op.callback_.should_fail_) {
        ASSERT_TRUE(s.IsBusy());
      } else {
        ASSERT_OK(s);
      }
    };

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();

    // do all the writes
    std::vector<port::Thread> threads;
    for (uint32_t i = 0; i < write_group.size(); i++) {
      threads.emplace_back(write_with_callback_func);
    }
    for (auto& t : threads) {
      t.join();
    }

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();

    // check for keys
    string value;
    for (auto& w : write_group) {
      ASSERT_TRUE(w.callback_.was_called_.load());
      for (auto& kvp : w.kvs_) {
        if (w.callback_.should_fail_) {
          ASSERT_TRUE(db->Get(read_options, kvp.first, &value).IsNotFound());
        } else {
          ASSERT_OK(db->Get(read_options, kvp.first, &value));
          ASSERT_EQ(value, kvp.second);
        }
      }
    }

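    // Every successful write bumped `seq`, so the expected count must match
    // the last sequence number the DB actually made visible.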
    ASSERT_EQ(seq.load(), db_impl->TEST_GetLastVisibleSequence());

    delete db;
    DestroyDB(dbname, options);
  }
}

INSTANTIATE_TEST_CASE_P(WriteCallbackPTest, WriteCallbackPTest,
                        ::testing::Combine(::testing::Bool(), ::testing::Bool(),
                                           ::testing::Bool(), ::testing::Bool(),
                                           ::testing::Bool(), ::testing::Bool(),
                                           ::testing::Bool()));
#endif  // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)

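// Basic single-threaded coverage: a plain Write, a WriteWithCallback whose
// callback succeeds, and one whose callback fails and must leave the DB
// unchanged.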
TEST_F(WriteCallbackTest, WriteCallBackTest) {
  Options options;
  WriteOptions write_options;
  ReadOptions read_options;
  string value;
  DB* db;
  DBImpl* db_impl;

  DestroyDB(dbname, options);

  options.create_if_missing = true;
  Status s = DB::Open(options, dbname, &db);
  ASSERT_OK(s);

  db_impl = dynamic_cast<DBImpl*>(db);
  ASSERT_TRUE(db_impl);

  WriteBatch wb;

  ASSERT_OK(wb.Put("a", "value.a"));
  ASSERT_OK(wb.Delete("x"));

  // Test a simple Write.
  s = db->Write(write_options, &wb);
  ASSERT_OK(s);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a", value);

  // Test WriteWithCallback.
  WriteCallbackTestWriteCallback1 callback1;
  WriteBatch wb2;

  ASSERT_OK(wb2.Put("a", "value.a2"));

  s = db_impl->WriteWithCallback(write_options, &wb2, &callback1);
  ASSERT_OK(s);
  ASSERT_TRUE(callback1.was_called);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  // Test WriteWithCallback for a callback that fails.
  WriteCallbackTestWriteCallback2 callback2;
  WriteBatch wb3;

  ASSERT_OK(wb3.Put("a", "value.a3"));

  s = db_impl->WriteWithCallback(write_options, &wb3, &callback2);
  ASSERT_NOK(s);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  delete db;
  DestroyDB(dbname, options);
}

}  // namespace ROCKSDB_NAMESPACE

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}

#else
#include <stdio.h>

int main(int /*argc*/, char** /*argv*/) {
  fprintf(stderr,
          "SKIPPED as WriteWithCallback is not supported in ROCKSDB_LITE\n");
  return 0;
}

#endif  // !ROCKSDB_LITE