//  Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
//  This source code is licensed under both the GPLv2 (found in the
//  COPYING file in the root directory) and Apache 2.0 License
//  (found in the LICENSE.Apache file in the root directory).

#ifndef ROCKSDB_LITE

#include <atomic>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "db/db_impl/db_impl.h"
#include "db/write_batch_internal.h"
#include "db/write_callback.h"
#include "port/port.h"
#include "rocksdb/db.h"
#include "rocksdb/write_batch.h"
#include "test_util/sync_point.h"
#include "test_util/testharness.h"
#include "util/random.h"

using std::string;

namespace ROCKSDB_NAMESPACE {

class WriteCallbackTest : public testing::Test {
 public:
  string dbname;

  WriteCallbackTest() {
    dbname = test::PerThreadDBPath("write_callback_testdb");
  }
};

class WriteCallbackTestWriteCallback1 : public WriteCallback {
 public:
  bool was_called = false;

  Status Callback(DB* db) override {
    was_called = true;

    // Make sure db is a DBImpl
    DBImpl* db_impl = dynamic_cast<DBImpl*>(db);
    if (db_impl == nullptr) {
      return Status::InvalidArgument("");
    }

    return Status::OK();
  }

  bool AllowWriteBatching() override { return true; }
};

class WriteCallbackTestWriteCallback2 : public WriteCallback {
 public:
  Status Callback(DB* /*db*/) override { return Status::Busy(); }
  bool AllowWriteBatching() override { return true; }
};

class MockWriteCallback : public WriteCallback {
 public:
  bool should_fail_ = false;
  bool allow_batching_ = false;
  std::atomic<bool> was_called_{false};

  MockWriteCallback() {}

  MockWriteCallback(const MockWriteCallback& other) {
    should_fail_ = other.should_fail_;
    allow_batching_ = other.allow_batching_;
    was_called_.store(other.was_called_.load());
  }

  Status Callback(DB* /*db*/) override {
    was_called_.store(true);
    if (should_fail_) {
      return Status::Busy();
    } else {
      return Status::OK();
    }
  }

  bool AllowWriteBatching() override { return allow_batching_; }
};

TEST_F(WriteCallbackTest, WriteWithCallbackTest) {
  struct WriteOP {
    WriteOP(bool should_fail = false) { callback_.should_fail_ = should_fail; }

    void Put(const string& key, const string& val) {
      kvs_.push_back(std::make_pair(key, val));
      write_batch_.Put(key, val);
    }

    void Clear() {
      kvs_.clear();
      write_batch_.Clear();
      callback_.was_called_.store(false);
    }

    MockWriteCallback callback_;
    WriteBatch write_batch_;
    std::vector<std::pair<string, string>> kvs_;
  };

  // In each scenario we'll launch multiple threads to write.
  // The size of each vector equals the number of threads, and each
  // boolean denotes whether the corresponding thread's callback should
  // fail (true) or succeed (false).
  std::vector<std::vector<WriteOP>> write_scenarios = {
      {true},
      {false},
      {false, false},
      {true, true},
      {true, false},
      {false, true},
      {false, false, false},
      {true, true, true},
      {false, true, false},
      {true, false, true},
      {true, false, false, false, false},
      {false, false, false, false, true},
      {false, false, true, false, true},
  };

  for (auto& unordered_write : {true, false}) {
  for (auto& seq_per_batch : {true, false}) {
  for (auto& two_queues : {true, false}) {
    for (auto& allow_parallel : {true, false}) {
      for (auto& allow_batching : {true, false}) {
        for (auto& enable_WAL : {true, false}) {
          for (auto& enable_pipelined_write : {true, false}) {
            for (auto& write_group : write_scenarios) {
              Options options;
              options.create_if_missing = true;
              options.unordered_write = unordered_write;
              options.allow_concurrent_memtable_write = allow_parallel;
              options.enable_pipelined_write = enable_pipelined_write;
              options.two_write_queues = two_queues;
              // Skip unsupported combinations
              if (options.enable_pipelined_write && seq_per_batch) {
                continue;
              }
              if (options.enable_pipelined_write && options.two_write_queues) {
                continue;
              }
              if (options.unordered_write &&
                  !options.allow_concurrent_memtable_write) {
                continue;
              }
              if (options.unordered_write && options.enable_pipelined_write) {
                continue;
              }

              ReadOptions read_options;
              DB* db;
              DBImpl* db_impl;

              DestroyDB(dbname, options);

              DBOptions db_options(options);
              ColumnFamilyOptions cf_options(options);
              std::vector<ColumnFamilyDescriptor> column_families;
              column_families.push_back(
                  ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
              std::vector<ColumnFamilyHandle*> handles;
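              // Open through DBImpl::Open (rather than DB::Open) so the test
              // can toggle seq_per_batch, which DB::Open does not expose.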
              auto open_s =
                  DBImpl::Open(db_options, dbname, column_families, &handles,
                               &db, seq_per_batch, true /* batch_per_txn */);
              ASSERT_OK(open_s);
              assert(handles.size() == 1);
              delete handles[0];

              db_impl = dynamic_cast<DBImpl*>(db);
              ASSERT_TRUE(db_impl);

              // Writers that have called JoinBatchGroup.
              std::atomic<uint64_t> threads_joining(0);
              // Writers that have linked to the queue
              std::atomic<uint64_t> threads_linked(0);
              // Writers that pass WriteThread::JoinBatchGroup:Wait sync-point.
              std::atomic<uint64_t> threads_verified(0);

              std::atomic<uint64_t> seq(db_impl->GetLatestSequenceNumber());
              ASSERT_EQ(db_impl->GetLatestSequenceNumber(), 0);

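              // The three sync-point callbacks below serialize how writers
              // join the batch group, so the test can deterministically tell
              // the group leader and the last writer apart.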
              ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
                  "WriteThread::JoinBatchGroup:Start", [&](void*) {
                    uint64_t cur_threads_joining = threads_joining.fetch_add(1);
                    // Wait for the last joined writer to link to the queue.
                    // In this way the writers link to the queue one by one.
                    // This lets us reliably identify the first writer that
                    // increments threads_linked as the leader.
                    while (threads_linked.load() < cur_threads_joining) {
                    }
                  });

              // Verification once writers call JoinBatchGroup.
              ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
                  "WriteThread::JoinBatchGroup:Wait", [&](void* arg) {
                    uint64_t cur_threads_linked = threads_linked.fetch_add(1);
                    bool is_leader = false;
                    bool is_last = false;

                    // who am i
                    is_leader = (cur_threads_linked == 0);
                    is_last = (cur_threads_linked == write_group.size() - 1);

                    // check my state
                    auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);

                    if (is_leader) {
                      ASSERT_TRUE(writer->state ==
                                  WriteThread::State::STATE_GROUP_LEADER);
                    } else {
                      ASSERT_TRUE(writer->state ==
                                  WriteThread::State::STATE_INIT);
                    }

                    // (meta test) the first WriteOP should indeed be the first
                    // and the last should be the last (all others can be out of
                    // order)
                    if (is_leader) {
                      ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
                                  !write_group.front().callback_.should_fail_);
                    } else if (is_last) {
                      ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
                                  !write_group.back().callback_.should_fail_);
                    }

                    threads_verified.fetch_add(1);
                    // Wait here until the verification in this sync-point
                    // callback finishes for all writers.
                    while (threads_verified.load() < write_group.size()) {
                    }
                  });

              ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
                  "WriteThread::JoinBatchGroup:DoneWaiting", [&](void* arg) {
                    // check my state
                    auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);

                    if (!allow_batching) {
                      // no batching so everyone should be a leader
                      ASSERT_TRUE(writer->state ==
                                  WriteThread::State::STATE_GROUP_LEADER);
                    } else if (!allow_parallel) {
                      ASSERT_TRUE(writer->state ==
                                      WriteThread::State::STATE_COMPLETED ||
                                  (enable_pipelined_write &&
                                   writer->state ==
                                       WriteThread::State::
                                           STATE_MEMTABLE_WRITER_LEADER));
                    }
                  });

              std::atomic<uint32_t> thread_num(0);
              std::atomic<char> dummy_key(0);

              // Each write thread creates a random write batch and writes it
              // to the DB with a write callback.
              std::function<void()> write_with_callback_func = [&]() {
                uint32_t i = thread_num.fetch_add(1);
                Random rnd(i);

                // leaders gotta lead: everyone else waits until thread 0 has
                // joined and been verified as the group leader
                while (i > 0 && threads_verified.load() < 1) {
                }

                // loser has to lose: the last thread waits until all other
                // writers have been verified, so it is guaranteed to join last
                while (i == write_group.size() - 1 &&
                       threads_verified.load() < write_group.size() - 1) {
                }

                auto& write_op = write_group.at(i);
                write_op.Clear();
                write_op.callback_.allow_batching_ = allow_batching;

                // insert some keys
                for (uint32_t j = 0; j < rnd.Next() % 50; j++) {
                  // grab unique key
                  char my_key = dummy_key.fetch_add(1);

                  string skey(5, my_key);
                  string sval(10, my_key);
                  write_op.Put(skey, sval);

                  // Without seq_per_batch, every successfully written key
                  // consumes one sequence number.
                  if (!write_op.callback_.should_fail_ && !seq_per_batch) {
                    seq.fetch_add(1);
                  }
                }
                // With seq_per_batch, each successful batch consumes exactly
                // one sequence number regardless of how many keys it holds.
                if (!write_op.callback_.should_fail_ && seq_per_batch) {
                  seq.fetch_add(1);
                }

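                // enable_WAL toggles both writing to the WAL and syncing it,
                // so both the durable and the WAL-less paths get exercised.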
                WriteOptions woptions;
                woptions.disableWAL = !enable_WAL;
                woptions.sync = enable_WAL;
                Status s;
                if (seq_per_batch) {
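                  // With two write queues, the last sequence is not published
                  // automatically, so a PreReleaseCallback is needed to
                  // publish it once the write completes.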
                  class PublishSeqCallback : public PreReleaseCallback {
                   public:
                    PublishSeqCallback(DBImpl* db_impl_in)
                        : db_impl_(db_impl_in) {}
                    Status Callback(SequenceNumber last_seq, bool /*not used*/,
                                    uint64_t, size_t /*index*/,
                                    size_t /*total*/) override {
                      db_impl_->SetLastPublishedSequence(last_seq);
                      return Status::OK();
                    }
                    DBImpl* db_impl_;
                  } publish_seq_callback(db_impl);
                  // seq_per_batch requires a natural batch separator or a
                  // Noop marker
                  WriteBatchInternal::InsertNoop(&write_op.write_batch_);
                  const size_t ONE_BATCH = 1;
                  s = db_impl->WriteImpl(
                      woptions, &write_op.write_batch_, &write_op.callback_,
                      nullptr /* log_used */, 0 /* log_ref */,
                      false /* disable_memtable */, nullptr /* seq_used */,
                      ONE_BATCH /* batch_cnt */,
                      two_queues ? &publish_seq_callback : nullptr);
                } else {
                  s = db_impl->WriteWithCallback(
                      woptions, &write_op.write_batch_, &write_op.callback_);
                }

                if (write_op.callback_.should_fail_) {
                  ASSERT_TRUE(s.IsBusy());
                } else {
                  ASSERT_OK(s);
                }
              };

              ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();

              // do all the writes
              std::vector<port::Thread> threads;
              for (uint32_t i = 0; i < write_group.size(); i++) {
                threads.emplace_back(write_with_callback_func);
              }
              for (auto& t : threads) {
                t.join();
              }

              ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();

              // check for keys
              string value;
              for (auto& w : write_group) {
                ASSERT_TRUE(w.callback_.was_called_.load());
                for (auto& kvp : w.kvs_) {
                  if (w.callback_.should_fail_) {
                    ASSERT_TRUE(
                        db->Get(read_options, kvp.first, &value).IsNotFound());
                  } else {
                    ASSERT_OK(db->Get(read_options, kvp.first, &value));
                    ASSERT_EQ(value, kvp.second);
                  }
                }
              }

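              // The sequence number tracked above should match what the DB
              // reports as its last visible sequence.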
              ASSERT_EQ(seq.load(), db_impl->TEST_GetLastVisibleSequence());

              delete db;
              DestroyDB(dbname, options);
            }
          }
        }
      }
    }
  }
  }
  }
}

TEST_F(WriteCallbackTest, WriteCallBackTest) {
  Options options;
  WriteOptions write_options;
  ReadOptions read_options;
  string value;
  DB* db;
  DBImpl* db_impl;

  DestroyDB(dbname, options);

  options.create_if_missing = true;
  Status s = DB::Open(options, dbname, &db);
  ASSERT_OK(s);

  db_impl = dynamic_cast<DBImpl*>(db);
  ASSERT_TRUE(db_impl);

  WriteBatch wb;

  wb.Put("a", "value.a");
  wb.Delete("x");

  // Test a simple Write
  s = db->Write(write_options, &wb);
  ASSERT_OK(s);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a", value);

  // Test WriteWithCallback
  WriteCallbackTestWriteCallback1 callback1;
  WriteBatch wb2;

  wb2.Put("a", "value.a2");

  s = db_impl->WriteWithCallback(write_options, &wb2, &callback1);
  ASSERT_OK(s);
  ASSERT_TRUE(callback1.was_called);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  // Test WriteWithCallback for a callback that fails
  WriteCallbackTestWriteCallback2 callback2;
  WriteBatch wb3;

  wb3.Put("a", "value.a3");

  s = db_impl->WriteWithCallback(write_options, &wb3, &callback2);
  ASSERT_NOK(s);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  delete db;
  DestroyDB(dbname, options);
}

}  // namespace ROCKSDB_NAMESPACE

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}

#else
#include <stdio.h>

int main(int /*argc*/, char** /*argv*/) {
  fprintf(stderr,
          "SKIPPED as WriteWithCallback is not supported in ROCKSDB_LITE\n");
  return 0;
}

#endif  // !ROCKSDB_LITE