1 //  Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
2 //  This source code is licensed under both the GPLv2 (found in the
3 //  COPYING file in the root directory) and Apache 2.0 License
4 //  (found in the LICENSE.Apache file in the root directory).
5 //
6 #include <assert.h>
7 #include <memory>
8 #include <iostream>
9 
10 #include "db/db_impl/db_impl.h"
11 #include "db/dbformat.h"
12 #include "db/write_batch_internal.h"
13 #include "port/stack_trace.h"
14 #include "rocksdb/cache.h"
15 #include "rocksdb/comparator.h"
16 #include "rocksdb/db.h"
17 #include "rocksdb/env.h"
18 #include "rocksdb/merge_operator.h"
19 #include "rocksdb/utilities/db_ttl.h"
20 #include "test_util/testharness.h"
21 #include "utilities/merge_operators.h"
22 
23 namespace ROCKSDB_NAMESPACE {
24 
// Set by main() from the command line. Despite its name, runTest() passes it
// to testCounters() as the `test_compaction` flag, so it toggles the extra
// flush/compaction steps rather than any compression setting.
bool use_compression = false;
26 
27 class MergeTest : public testing::Test {};
28 
// Incremented by CountMergeOperator::Merge; tests reset it before each check.
size_t num_merge_operator_calls;
void resetNumMergeOperatorCalls() {
  num_merge_operator_calls = size_t{0};
}

// Incremented by CountMergeOperator::PartialMergeMulti; reset before checks.
size_t num_partial_merge_calls;
void resetNumPartialMergeCalls() {
  num_partial_merge_calls = size_t{0};
}
34 
35 class CountMergeOperator : public AssociativeMergeOperator {
36  public:
CountMergeOperator()37   CountMergeOperator() {
38     mergeOperator_ = MergeOperators::CreateUInt64AddOperator();
39   }
40 
Merge(const Slice & key,const Slice * existing_value,const Slice & value,std::string * new_value,Logger * logger) const41   bool Merge(const Slice& key, const Slice* existing_value, const Slice& value,
42              std::string* new_value, Logger* logger) const override {
43     assert(new_value->empty());
44     ++num_merge_operator_calls;
45     if (existing_value == nullptr) {
46       new_value->assign(value.data(), value.size());
47       return true;
48     }
49 
50     return mergeOperator_->PartialMerge(
51         key,
52         *existing_value,
53         value,
54         new_value,
55         logger);
56   }
57 
PartialMergeMulti(const Slice & key,const std::deque<Slice> & operand_list,std::string * new_value,Logger * logger) const58   bool PartialMergeMulti(const Slice& key,
59                          const std::deque<Slice>& operand_list,
60                          std::string* new_value,
61                          Logger* logger) const override {
62     assert(new_value->empty());
63     ++num_partial_merge_calls;
64     return mergeOperator_->PartialMergeMulti(key, operand_list, new_value,
65                                              logger);
66   }
67 
Name() const68   const char* Name() const override { return "UInt64AddOperator"; }
69 
70  private:
71   std::shared_ptr<MergeOperator> mergeOperator_;
72 };
73 
OpenDb(const std::string & dbname,const bool ttl=false,const size_t max_successive_merges=0)74 std::shared_ptr<DB> OpenDb(const std::string& dbname, const bool ttl = false,
75                            const size_t max_successive_merges = 0) {
76   DB* db;
77   Options options;
78   options.create_if_missing = true;
79   options.merge_operator = std::make_shared<CountMergeOperator>();
80   options.max_successive_merges = max_successive_merges;
81   Status s;
82   DestroyDB(dbname, Options());
83 // DBWithTTL is not supported in ROCKSDB_LITE
84 #ifndef ROCKSDB_LITE
85   if (ttl) {
86     DBWithTTL* db_with_ttl;
87     s = DBWithTTL::Open(options, dbname, &db_with_ttl);
88     db = db_with_ttl;
89   } else {
90     s = DB::Open(options, dbname, &db);
91   }
92 #else
93   assert(!ttl);
94   s = DB::Open(options, dbname, &db);
95 #endif  // !ROCKSDB_LITE
96   if (!s.ok()) {
97     std::cerr << s.ToString() << std::endl;
98     assert(false);
99   }
100   return std::shared_ptr<DB>(db);
101 }
102 
103 // Imagine we are maintaining a set of uint64 counters.
104 // Each counter has a distinct name. And we would like
105 // to support four high level operations:
106 // set, add, get and remove
107 // This is a quick implementation without a Merge operation.
108 class Counters {
109 
110  protected:
111   std::shared_ptr<DB> db_;
112 
113   WriteOptions put_option_;
114   ReadOptions get_option_;
115   WriteOptions delete_option_;
116 
117   uint64_t default_;
118 
119  public:
Counters(std::shared_ptr<DB> db,uint64_t defaultCount=0)120   explicit Counters(std::shared_ptr<DB> db, uint64_t defaultCount = 0)
121       : db_(db),
122         put_option_(),
123         get_option_(),
124         delete_option_(),
125         default_(defaultCount) {
126     assert(db_);
127   }
128 
~Counters()129   virtual ~Counters() {}
130 
131   // public interface of Counters.
132   // All four functions return false
133   // if the underlying level db operation failed.
134 
135   // mapped to a levedb Put
set(const std::string & key,uint64_t value)136   bool set(const std::string& key, uint64_t value) {
137     // just treat the internal rep of int64 as the string
138     char buf[sizeof(value)];
139     EncodeFixed64(buf, value);
140     Slice slice(buf, sizeof(value));
141     auto s = db_->Put(put_option_, key, slice);
142 
143     if (s.ok()) {
144       return true;
145     } else {
146       std::cerr << s.ToString() << std::endl;
147       return false;
148     }
149   }
150 
151   // mapped to a rocksdb Delete
remove(const std::string & key)152   bool remove(const std::string& key) {
153     auto s = db_->Delete(delete_option_, key);
154 
155     if (s.ok()) {
156       return true;
157     } else {
158       std::cerr << s.ToString() << std::endl;
159       return false;
160     }
161   }
162 
163   // mapped to a rocksdb Get
get(const std::string & key,uint64_t * value)164   bool get(const std::string& key, uint64_t* value) {
165     std::string str;
166     auto s = db_->Get(get_option_, key, &str);
167 
168     if (s.IsNotFound()) {
169       // return default value if not found;
170       *value = default_;
171       return true;
172     } else if (s.ok()) {
173       // deserialization
174       if (str.size() != sizeof(uint64_t)) {
175         std::cerr << "value corruption\n";
176         return false;
177       }
178       *value = DecodeFixed64(&str[0]);
179       return true;
180     } else {
181       std::cerr << s.ToString() << std::endl;
182       return false;
183     }
184   }
185 
186   // 'add' is implemented as get -> modify -> set
187   // An alternative is a single merge operation, see MergeBasedCounters
add(const std::string & key,uint64_t value)188   virtual bool add(const std::string& key, uint64_t value) {
189     uint64_t base = default_;
190     return get(key, &base) && set(key, base + value);
191   }
192 
193 
194   // convenience functions for testing
assert_set(const std::string & key,uint64_t value)195   void assert_set(const std::string& key, uint64_t value) {
196     assert(set(key, value));
197   }
198 
assert_remove(const std::string & key)199   void assert_remove(const std::string& key) { assert(remove(key)); }
200 
assert_get(const std::string & key)201   uint64_t assert_get(const std::string& key) {
202     uint64_t value = default_;
203     int result = get(key, &value);
204     assert(result);
205     if (result == 0) exit(1); // Disable unused variable warning.
206     return value;
207   }
208 
assert_add(const std::string & key,uint64_t value)209   void assert_add(const std::string& key, uint64_t value) {
210     int result = add(key, value);
211     assert(result);
212     if (result == 0) exit(1); // Disable unused variable warning.
213   }
214 };
215 
216 // Implement 'add' directly with the new Merge operation
217 class MergeBasedCounters : public Counters {
218  private:
219   WriteOptions merge_option_; // for merge
220 
221  public:
MergeBasedCounters(std::shared_ptr<DB> db,uint64_t defaultCount=0)222   explicit MergeBasedCounters(std::shared_ptr<DB> db, uint64_t defaultCount = 0)
223       : Counters(db, defaultCount),
224         merge_option_() {
225   }
226 
227   // mapped to a rocksdb Merge operation
add(const std::string & key,uint64_t value)228   bool add(const std::string& key, uint64_t value) override {
229     char encoded[sizeof(uint64_t)];
230     EncodeFixed64(encoded, value);
231     Slice slice(encoded, sizeof(uint64_t));
232     auto s = db_->Merge(merge_option_, key, slice);
233 
234     if (s.ok()) {
235       return true;
236     } else {
237       std::cerr << s.ToString() << std::endl;
238       return false;
239     }
240   }
241 };
242 
dumpDb(DB * db)243 void dumpDb(DB* db) {
244   auto it = std::unique_ptr<Iterator>(db->NewIterator(ReadOptions()));
245   for (it->SeekToFirst(); it->Valid(); it->Next()) {
246     //uint64_t value = DecodeFixed64(it->value().data());
247     //std::cout << it->key().ToString() << ": " << value << std::endl;
248   }
249   assert(it->status().ok());  // Check for any errors found during the scan
250 }
251 
testCounters(Counters & counters,DB * db,bool test_compaction)252 void testCounters(Counters& counters, DB* db, bool test_compaction) {
253 
254   FlushOptions o;
255   o.wait = true;
256 
257   counters.assert_set("a", 1);
258 
259   if (test_compaction) db->Flush(o);
260 
261   assert(counters.assert_get("a") == 1);
262 
263   counters.assert_remove("b");
264 
265   // defaut value is 0 if non-existent
266   assert(counters.assert_get("b") == 0);
267 
268   counters.assert_add("a", 2);
269 
270   if (test_compaction) db->Flush(o);
271 
272   // 1+2 = 3
273   assert(counters.assert_get("a")== 3);
274 
275   dumpDb(db);
276 
277   // 1+...+49 = ?
278   uint64_t sum = 0;
279   for (int i = 1; i < 50; i++) {
280     counters.assert_add("b", i);
281     sum += i;
282   }
283   assert(counters.assert_get("b") == sum);
284 
285   dumpDb(db);
286 
287   if (test_compaction) {
288     db->Flush(o);
289 
290     db->CompactRange(CompactRangeOptions(), nullptr, nullptr);
291 
292     dumpDb(db);
293 
294     assert(counters.assert_get("a")== 3);
295     assert(counters.assert_get("b") == sum);
296   }
297 }
298 
testSuccessiveMerge(Counters & counters,size_t max_num_merges,size_t num_merges)299 void testSuccessiveMerge(Counters& counters, size_t max_num_merges,
300                          size_t num_merges) {
301 
302   counters.assert_remove("z");
303   uint64_t sum = 0;
304 
305   for (size_t i = 1; i <= num_merges; ++i) {
306     resetNumMergeOperatorCalls();
307     counters.assert_add("z", i);
308     sum += i;
309 
310     if (i % (max_num_merges + 1) == 0) {
311       assert(num_merge_operator_calls == max_num_merges + 1);
312     } else {
313       assert(num_merge_operator_calls == 0);
314     }
315 
316     resetNumMergeOperatorCalls();
317     assert(counters.assert_get("z") == sum);
318     assert(num_merge_operator_calls == i % (max_num_merges + 1));
319   }
320 }
321 
testPartialMerge(Counters * counters,DB * db,size_t max_merge,size_t min_merge,size_t count)322 void testPartialMerge(Counters* counters, DB* db, size_t max_merge,
323                       size_t min_merge, size_t count) {
324   FlushOptions o;
325   o.wait = true;
326 
327   // Test case 1: partial merge should be called when the number of merge
328   //              operands exceeds the threshold.
329   uint64_t tmp_sum = 0;
330   resetNumPartialMergeCalls();
331   for (size_t i = 1; i <= count; i++) {
332     counters->assert_add("b", i);
333     tmp_sum += i;
334   }
335   db->Flush(o);
336   db->CompactRange(CompactRangeOptions(), nullptr, nullptr);
337   ASSERT_EQ(tmp_sum, counters->assert_get("b"));
338   if (count > max_merge) {
339     // in this case, FullMerge should be called instead.
340     ASSERT_EQ(num_partial_merge_calls, 0U);
341   } else {
342     // if count >= min_merge, then partial merge should be called once.
343     ASSERT_EQ((count >= min_merge), (num_partial_merge_calls == 1));
344   }
345 
346   // Test case 2: partial merge should not be called when a put is found.
347   resetNumPartialMergeCalls();
348   tmp_sum = 0;
349   db->Put(ROCKSDB_NAMESPACE::WriteOptions(), "c", "10");
350   for (size_t i = 1; i <= count; i++) {
351     counters->assert_add("c", i);
352     tmp_sum += i;
353   }
354   db->Flush(o);
355   db->CompactRange(CompactRangeOptions(), nullptr, nullptr);
356   ASSERT_EQ(tmp_sum, counters->assert_get("c"));
357   ASSERT_EQ(num_partial_merge_calls, 0U);
358 }
359 
testSingleBatchSuccessiveMerge(DB * db,size_t max_num_merges,size_t num_merges)360 void testSingleBatchSuccessiveMerge(DB* db, size_t max_num_merges,
361                                     size_t num_merges) {
362   assert(num_merges > max_num_merges);
363 
364   Slice key("BatchSuccessiveMerge");
365   uint64_t merge_value = 1;
366   char buf[sizeof(merge_value)];
367   EncodeFixed64(buf, merge_value);
368   Slice merge_value_slice(buf, sizeof(merge_value));
369 
370   // Create the batch
371   WriteBatch batch;
372   for (size_t i = 0; i < num_merges; ++i) {
373     batch.Merge(key, merge_value_slice);
374   }
375 
376   // Apply to memtable and count the number of merges
377   resetNumMergeOperatorCalls();
378   {
379     Status s = db->Write(WriteOptions(), &batch);
380     assert(s.ok());
381   }
382   ASSERT_EQ(
383       num_merge_operator_calls,
384       static_cast<size_t>(num_merges - (num_merges % (max_num_merges + 1))));
385 
386   // Get the value
387   resetNumMergeOperatorCalls();
388   std::string get_value_str;
389   {
390     Status s = db->Get(ReadOptions(), key, &get_value_str);
391     assert(s.ok());
392   }
393   assert(get_value_str.size() == sizeof(uint64_t));
394   uint64_t get_value = DecodeFixed64(&get_value_str[0]);
395   ASSERT_EQ(get_value, num_merges * merge_value);
396   ASSERT_EQ(num_merge_operator_calls,
397             static_cast<size_t>((num_merges % (max_num_merges + 1))));
398 }
399 
runTest(const std::string & dbname,const bool use_ttl=false)400 void runTest(const std::string& dbname, const bool use_ttl = false) {
401 
402   {
403     auto db = OpenDb(dbname, use_ttl);
404 
405     {
406       Counters counters(db, 0);
407       testCounters(counters, db.get(), true);
408     }
409 
410     {
411       MergeBasedCounters counters(db, 0);
412       testCounters(counters, db.get(), use_compression);
413     }
414   }
415 
416   DestroyDB(dbname, Options());
417 
418   {
419     size_t max_merge = 5;
420     auto db = OpenDb(dbname, use_ttl, max_merge);
421     MergeBasedCounters counters(db, 0);
422     testCounters(counters, db.get(), use_compression);
423     testSuccessiveMerge(counters, max_merge, max_merge * 2);
424     testSingleBatchSuccessiveMerge(db.get(), 5, 7);
425     DestroyDB(dbname, Options());
426   }
427 
428   {
429     size_t max_merge = 100;
430     // Min merge is hard-coded to 2.
431     uint32_t min_merge = 2;
432     for (uint32_t count = min_merge - 1; count <= min_merge + 1; count++) {
433       auto db = OpenDb(dbname, use_ttl, max_merge);
434       MergeBasedCounters counters(db, 0);
435       testPartialMerge(&counters, db.get(), max_merge, min_merge, count);
436       DestroyDB(dbname, Options());
437     }
438     {
439       auto db = OpenDb(dbname, use_ttl, max_merge);
440       MergeBasedCounters counters(db, 0);
441       testPartialMerge(&counters, db.get(), max_merge, min_merge,
442                        min_merge * 10);
443       DestroyDB(dbname, Options());
444     }
445   }
446 
447   {
448     {
449       auto db = OpenDb(dbname);
450       MergeBasedCounters counters(db, 0);
451       counters.add("test-key", 1);
452       counters.add("test-key", 1);
453       counters.add("test-key", 1);
454       db->CompactRange(CompactRangeOptions(), nullptr, nullptr);
455     }
456 
457     DB* reopen_db;
458     ASSERT_OK(DB::Open(Options(), dbname, &reopen_db));
459     std::string value;
460     ASSERT_TRUE(!(reopen_db->Get(ReadOptions(), "test-key", &value).ok()));
461     delete reopen_db;
462     DestroyDB(dbname, Options());
463   }
464 
465   /* Temporary remove this test
466   {
467     std::cout << "Test merge-operator not set after reopen (recovery case)\n";
468     {
469       auto db = OpenDb(dbname);
470       MergeBasedCounters counters(db, 0);
471       counters.add("test-key", 1);
472       counters.add("test-key", 1);
473       counters.add("test-key", 1);
474     }
475 
476     DB* reopen_db;
477     ASSERT_TRUE(DB::Open(Options(), dbname, &reopen_db).IsInvalidArgument());
478   }
479   */
480 }
481 
// Runs the full merge scenario set on a plain (non-TTL) database.
TEST_F(MergeTest, MergeDbTest) {
  const std::string db_path = test::PerThreadDBPath("merge_testdb");
  runTest(db_path);
}
485 
486 #ifndef ROCKSDB_LITE
TEST_F(MergeTest,MergeDbTtlTest)487 TEST_F(MergeTest, MergeDbTtlTest) {
488   runTest(test::PerThreadDBPath("merge_testdbttl"),
489           true);  // Run test on TTL database
490 }
491 #endif  // !ROCKSDB_LITE
492 
493 }  // namespace ROCKSDB_NAMESPACE
494 
main(int argc,char ** argv)495 int main(int argc, char** argv) {
496   ROCKSDB_NAMESPACE::use_compression = false;
497   if (argc > 1) {
498     ROCKSDB_NAMESPACE::use_compression = true;
499   }
500 
501   ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
502   ::testing::InitGoogleTest(&argc, argv);
503   return RUN_ALL_TESTS();
504 }
505