//  Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
//  This source code is licensed under both the GPLv2 (found in the
//  COPYING file in the root directory) and Apache 2.0 License
//  (found in the LICENSE.Apache file in the root directory).

#ifndef ROCKSDB_LITE

#include <atomic>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "db/db_impl/db_impl.h"
#include "db/write_batch_internal.h"
#include "db/write_callback.h"
#include "port/port.h"
#include "rocksdb/db.h"
#include "rocksdb/write_batch.h"
#include "test_util/sync_point.h"
#include "test_util/testharness.h"
#include "util/random.h"

using std::string;

namespace ROCKSDB_NAMESPACE {

class WriteCallbackTest : public testing::Test {
 public:
  string dbname;

  WriteCallbackTest() {
    dbname = test::PerThreadDBPath("write_callback_testdb");
  }
};

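// A callback that records that it ran and succeeds only if the DB handle
// it is given is actually a DBImpl.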
class WriteCallbackTestWriteCallback1 : public WriteCallback {
 public:
  bool was_called = false;

  Status Callback(DB* db) override {
    was_called = true;

    // Make sure db is a DBImpl
    DBImpl* db_impl = dynamic_cast<DBImpl*>(db);
    if (db_impl == nullptr) {
      return Status::InvalidArgument("");
    }

    return Status::OK();
  }

  bool AllowWriteBatching() override { return true; }
};

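// A callback that always fails with Status::Busy(), so any write carrying
// it must be rejected.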
class WriteCallbackTestWriteCallback2 : public WriteCallback {
 public:
  Status Callback(DB* /*db*/) override { return Status::Busy(); }
  bool AllowWriteBatching() override { return true; }
};

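// A configurable callback for the parameterized test below: it can be set
// up to fail or to allow batching, and records (thread-safely) whether it
// was invoked.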
class MockWriteCallback : public WriteCallback {
 public:
  bool should_fail_ = false;
  bool allow_batching_ = false;
  std::atomic<bool> was_called_{false};

  MockWriteCallback() {}

  MockWriteCallback(const MockWriteCallback& other) {
    should_fail_ = other.should_fail_;
    allow_batching_ = other.allow_batching_;
    was_called_.store(other.was_called_.load());
  }

  Status Callback(DB* /*db*/) override {
    was_called_.store(true);
    if (should_fail_) {
      return Status::Busy();
    } else {
      return Status::OK();
    }
  }

  bool AllowWriteBatching() override { return allow_batching_; }
};

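// The seven parameters, in tuple order: unordered_write, seq_per_batch,
// two_write_queues, allow_concurrent_memtable_write, allow_batching,
// enable_WAL and enable_pipelined_write.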
class WriteCallbackPTest
    : public WriteCallbackTest,
      public ::testing::WithParamInterface<
          std::tuple<bool, bool, bool, bool, bool, bool, bool>> {
 public:
  WriteCallbackPTest() {
    std::tie(unordered_write_, seq_per_batch_, two_queues_, allow_parallel_,
             allow_batching_, enable_WAL_, enable_pipelined_write_) =
        GetParam();
  }

 protected:
  bool unordered_write_;
  bool seq_per_batch_;
  bool two_queues_;
  bool allow_parallel_;
  bool allow_batching_;
  bool enable_WAL_;
  bool enable_pipelined_write_;
};

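// Launch a group of writer threads per scenario, use sync points to pin
// down which writer becomes the group leader, and verify that a failing
// callback aborts exactly that writer's batch.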
TEST_P(WriteCallbackPTest, WriteWithCallbackTest) {
  struct WriteOP {
    WriteOP(bool should_fail = false) { callback_.should_fail_ = should_fail; }

    void Put(const string& key, const string& val) {
      kvs_.push_back(std::make_pair(key, val));
      ASSERT_OK(write_batch_.Put(key, val));
    }

    void Clear() {
      kvs_.clear();
      write_batch_.Clear();
      callback_.was_called_.store(false);
    }

    MockWriteCallback callback_;
    WriteBatch write_batch_;
    std::vector<std::pair<string, string>> kvs_;
  };

  // In each scenario we'll launch multiple threads to write.
  // The size of each array equals the number of threads, and
  // each boolean in it denotes whether the callback of the
  // corresponding thread should succeed or fail.
  std::vector<std::vector<WriteOP>> write_scenarios = {
      {true},
      {false},
      {false, false},
      {true, true},
      {true, false},
      {false, true},
      {false, false, false},
      {true, true, true},
      {false, true, false},
      {true, false, true},
      {true, false, false, false, false},
      {false, false, false, false, true},
      {false, false, true, false, true},
  };

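  // Run every scenario under the current parameter combination, skipping
  // combinations that the DB does not support.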
  for (auto& write_group : write_scenarios) {
    Options options;
    options.create_if_missing = true;
    options.unordered_write = unordered_write_;
    options.allow_concurrent_memtable_write = allow_parallel_;
    options.enable_pipelined_write = enable_pipelined_write_;
    options.two_write_queues = two_queues_;
    // Skip unsupported combinations
    if (options.enable_pipelined_write && seq_per_batch_) {
      continue;
    }
    if (options.enable_pipelined_write && options.two_write_queues) {
      continue;
    }
    if (options.unordered_write && !options.allow_concurrent_memtable_write) {
      continue;
    }
    if (options.unordered_write && options.enable_pipelined_write) {
      continue;
    }

    ReadOptions read_options;
    DB* db;
    DBImpl* db_impl;

    ASSERT_OK(DestroyDB(dbname, options));

    DBOptions db_options(options);
    ColumnFamilyOptions cf_options(options);
    std::vector<ColumnFamilyDescriptor> column_families;
    column_families.push_back(
        ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
    std::vector<ColumnFamilyHandle*> handles;
    auto open_s = DBImpl::Open(db_options, dbname, column_families, &handles,
                               &db, seq_per_batch_, true /* batch_per_txn */);
    ASSERT_OK(open_s);
    assert(handles.size() == 1);
    delete handles[0];

    db_impl = dynamic_cast<DBImpl*>(db);
    ASSERT_TRUE(db_impl);

    // Writers that have called JoinBatchGroup.
    std::atomic<uint64_t> threads_joining(0);
    // Writers that have linked to the queue.
    std::atomic<uint64_t> threads_linked(0);
    // Writers that pass the WriteThread::JoinBatchGroup:Wait sync point.
    std::atomic<uint64_t> threads_verified(0);

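    // Expected last visible sequence number: advanced once per successfully
    // written key, or once per batch when seq_per_batch_ is set.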
    std::atomic<uint64_t> seq(db_impl->GetLatestSequenceNumber());
    ASSERT_EQ(db_impl->GetLatestSequenceNumber(), 0);

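    // Stall each newly joining writer until the previous one has linked to
    // the queue, so the leader/follower roles asserted below are
    // deterministic.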
    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:Start", [&](void*) {
          uint64_t cur_threads_joining = threads_joining.fetch_add(1);
          // Wait for the last joined writer to link to the queue.
          // In this way the writers link to the queue one by one.
          // This allows us to confidently detect the first writer
          // who increases threads_linked as the leader.
          while (threads_linked.load() < cur_threads_joining) {
          }
        });

    // Verification once writers call JoinBatchGroup.
    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:Wait", [&](void* arg) {
          uint64_t cur_threads_linked = threads_linked.fetch_add(1);
          bool is_leader = false;
          bool is_last = false;

          // Who am I?
          is_leader = (cur_threads_linked == 0);
          is_last = (cur_threads_linked == write_group.size() - 1);

          // Check my state.
          auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);

          if (is_leader) {
            ASSERT_TRUE(writer->state ==
                        WriteThread::State::STATE_GROUP_LEADER);
          } else {
            ASSERT_TRUE(writer->state == WriteThread::State::STATE_INIT);
          }

          // (meta test) The first WriteOP should indeed be the first
          // and the last should be the last (all others can be out of
          // order).
          if (is_leader) {
            ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
                        !write_group.front().callback_.should_fail_);
          } else if (is_last) {
            ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
                        !write_group.back().callback_.should_fail_);
          }

          threads_verified.fetch_add(1);
          // Wait here until the verification in this sync-point
          // callback has finished for all writers.
          while (threads_verified.load() < write_group.size()) {
          }
        });

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:DoneWaiting", [&](void* arg) {
          // Check my state.
          auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);

          if (!allow_batching_) {
            // No batching, so everyone should be a leader.
            ASSERT_TRUE(writer->state ==
                        WriteThread::State::STATE_GROUP_LEADER);
          } else if (!allow_parallel_) {
            ASSERT_TRUE(writer->state == WriteThread::State::STATE_COMPLETED ||
                        (enable_pipelined_write_ &&
                         writer->state ==
                             WriteThread::State::STATE_MEMTABLE_WRITER_LEADER));
          }
        });

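    // thread_num hands each writer thread a unique index into write_group;
    // dummy_key hands out a distinct one-character key per Put.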
    std::atomic<uint32_t> thread_num(0);
    std::atomic<char> dummy_key(0);

    // Each write thread creates a random write batch and writes it to the
    // DB with a write callback.
    std::function<void()> write_with_callback_func = [&]() {
      uint32_t i = thread_num.fetch_add(1);
      Random rnd(i);

      // leaders gotta lead
      while (i > 0 && threads_verified.load() < 1) {
      }

      // loser has to lose
      while (i == write_group.size() - 1 &&
             threads_verified.load() < write_group.size() - 1) {
      }

      auto& write_op = write_group.at(i);
      write_op.Clear();
      write_op.callback_.allow_batching_ = allow_batching_;

      // Insert some keys.
      for (uint32_t j = 0; j < rnd.Next() % 50; j++) {
        // Grab a unique key.
        char my_key = dummy_key.fetch_add(1);

        string skey(5, my_key);
        string sval(10, my_key);
        write_op.Put(skey, sval);

        if (!write_op.callback_.should_fail_ && !seq_per_batch_) {
          seq.fetch_add(1);
        }
      }
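      // In seq_per_batch mode the expected sequence advances once per batch
      // rather than once per key.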
      if (!write_op.callback_.should_fail_ && seq_per_batch_) {
        seq.fetch_add(1);
      }

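      // Sync the WAL when it is enabled; otherwise bypass it entirely.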
      WriteOptions woptions;
      woptions.disableWAL = !enable_WAL_;
      woptions.sync = enable_WAL_;
      Status s;
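      // When two_write_queues is set, publish the sequence number from a
      // PreReleaseCallback so that the writes become visible to reads.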
      if (seq_per_batch_) {
        class PublishSeqCallback : public PreReleaseCallback {
         public:
          explicit PublishSeqCallback(DBImpl* db_impl_in)
              : db_impl_(db_impl_in) {}
          Status Callback(SequenceNumber last_seq, bool /*not used*/, uint64_t,
                          size_t /*index*/, size_t /*total*/) override {
            db_impl_->SetLastPublishedSequence(last_seq);
            return Status::OK();
          }
          DBImpl* db_impl_;
        } publish_seq_callback(db_impl);
        // seq_per_batch_ requires a natural batch separator or Noop.
        ASSERT_OK(WriteBatchInternal::InsertNoop(&write_op.write_batch_));
        const size_t ONE_BATCH = 1;
        s = db_impl->WriteImpl(woptions, &write_op.write_batch_,
                               &write_op.callback_, nullptr, 0, false, nullptr,
                               ONE_BATCH,
                               two_queues_ ? &publish_seq_callback : nullptr);
      } else {
        s = db_impl->WriteWithCallback(woptions, &write_op.write_batch_,
                                       &write_op.callback_);
      }

      if (write_op.callback_.should_fail_) {
        ASSERT_TRUE(s.IsBusy());
      } else {
        ASSERT_OK(s);
      }
    };

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();

    // Do all the writes.
    std::vector<port::Thread> threads;
    for (uint32_t i = 0; i < write_group.size(); i++) {
      threads.emplace_back(write_with_callback_func);
    }
    for (auto& t : threads) {
      t.join();
    }

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();

    // Check for keys.
    string value;
    for (auto& w : write_group) {
      ASSERT_TRUE(w.callback_.was_called_.load());
      for (auto& kvp : w.kvs_) {
        if (w.callback_.should_fail_) {
          ASSERT_TRUE(db->Get(read_options, kvp.first, &value).IsNotFound());
        } else {
          ASSERT_OK(db->Get(read_options, kvp.first, &value));
          ASSERT_EQ(value, kvp.second);
        }
      }
    }

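    // Each successful write advanced the expected sequence count; it must
    // match what the DB reports as the last visible sequence.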
    ASSERT_EQ(seq.load(), db_impl->TEST_GetLastVisibleSequence());

    delete db;
    ASSERT_OK(DestroyDB(dbname, options));
  }
}

INSTANTIATE_TEST_CASE_P(WriteCallbackPTest, WriteCallbackPTest,
                        ::testing::Combine(::testing::Bool(), ::testing::Bool(),
                                           ::testing::Bool(), ::testing::Bool(),
                                           ::testing::Bool(), ::testing::Bool(),
                                           ::testing::Bool()));

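// Basic single-threaded coverage: a plain Write, a successful
// WriteWithCallback, and a failing callback that must leave the DB unchanged.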
TEST_F(WriteCallbackTest, WriteCallBackTest) {
  Options options;
  WriteOptions write_options;
  ReadOptions read_options;
  string value;
  DB* db;
  DBImpl* db_impl;

  ASSERT_OK(DestroyDB(dbname, options));

  options.create_if_missing = true;
  Status s = DB::Open(options, dbname, &db);
  ASSERT_OK(s);

  db_impl = dynamic_cast<DBImpl*>(db);
  ASSERT_TRUE(db_impl);

  WriteBatch wb;

  ASSERT_OK(wb.Put("a", "value.a"));
  ASSERT_OK(wb.Delete("x"));

  // Test a simple Write
  s = db->Write(write_options, &wb);
  ASSERT_OK(s);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a", value);

  // Test WriteWithCallback
  WriteCallbackTestWriteCallback1 callback1;
  WriteBatch wb2;

  ASSERT_OK(wb2.Put("a", "value.a2"));

  s = db_impl->WriteWithCallback(write_options, &wb2, &callback1);
  ASSERT_OK(s);
  ASSERT_TRUE(callback1.was_called);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  // Test WriteWithCallback for a callback that fails
  WriteCallbackTestWriteCallback2 callback2;
  WriteBatch wb3;

  ASSERT_OK(wb3.Put("a", "value.a3"));

  s = db_impl->WriteWithCallback(write_options, &wb3, &callback2);
  ASSERT_NOK(s);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  delete db;
  ASSERT_OK(DestroyDB(dbname, options));
}

}  // namespace ROCKSDB_NAMESPACE

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}

#else
#include <stdio.h>

int main(int /*argc*/, char** /*argv*/) {
  fprintf(stderr,
          "SKIPPED as WriteWithCallback is not supported in ROCKSDB_LITE\n");
  return 0;
}

#endif  // !ROCKSDB_LITE