1 // Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
2 // This source code is licensed under both the GPLv2 (found in the
3 // COPYING file in the root directory) and Apache 2.0 License
4 // (found in the LICENSE.Apache file in the root directory).
5 //
6 #include <assert.h>
7 #include <memory>
8 #include <iostream>
9
10 #include "db/db_impl/db_impl.h"
11 #include "db/dbformat.h"
12 #include "db/write_batch_internal.h"
13 #include "port/stack_trace.h"
14 #include "rocksdb/cache.h"
15 #include "rocksdb/comparator.h"
16 #include "rocksdb/db.h"
17 #include "rocksdb/env.h"
18 #include "rocksdb/merge_operator.h"
19 #include "rocksdb/utilities/db_ttl.h"
20 #include "test_util/testharness.h"
21 #include "utilities/merge_operators.h"
22
23 namespace ROCKSDB_NAMESPACE {
24
25 bool use_compression;
26
27 class MergeTest : public testing::Test {};
28
// Incremented by CountMergeOperator::Merge() on every invocation.
size_t num_merge_operator_calls = 0;

// Zeroes the Merge() invocation counter.
void resetNumMergeOperatorCalls() {
  num_merge_operator_calls = size_t{0};
}

// Incremented by CountMergeOperator::PartialMergeMulti() on every invocation.
size_t num_partial_merge_calls = 0;

// Zeroes the PartialMergeMulti() invocation counter.
void resetNumPartialMergeCalls() {
  num_partial_merge_calls = size_t{0};
}
34
35 class CountMergeOperator : public AssociativeMergeOperator {
36 public:
CountMergeOperator()37 CountMergeOperator() {
38 mergeOperator_ = MergeOperators::CreateUInt64AddOperator();
39 }
40
Merge(const Slice & key,const Slice * existing_value,const Slice & value,std::string * new_value,Logger * logger) const41 bool Merge(const Slice& key, const Slice* existing_value, const Slice& value,
42 std::string* new_value, Logger* logger) const override {
43 assert(new_value->empty());
44 ++num_merge_operator_calls;
45 if (existing_value == nullptr) {
46 new_value->assign(value.data(), value.size());
47 return true;
48 }
49
50 return mergeOperator_->PartialMerge(
51 key,
52 *existing_value,
53 value,
54 new_value,
55 logger);
56 }
57
PartialMergeMulti(const Slice & key,const std::deque<Slice> & operand_list,std::string * new_value,Logger * logger) const58 bool PartialMergeMulti(const Slice& key,
59 const std::deque<Slice>& operand_list,
60 std::string* new_value,
61 Logger* logger) const override {
62 assert(new_value->empty());
63 ++num_partial_merge_calls;
64 return mergeOperator_->PartialMergeMulti(key, operand_list, new_value,
65 logger);
66 }
67
Name() const68 const char* Name() const override { return "UInt64AddOperator"; }
69
70 private:
71 std::shared_ptr<MergeOperator> mergeOperator_;
72 };
73
OpenDb(const std::string & dbname,const bool ttl=false,const size_t max_successive_merges=0)74 std::shared_ptr<DB> OpenDb(const std::string& dbname, const bool ttl = false,
75 const size_t max_successive_merges = 0) {
76 DB* db;
77 Options options;
78 options.create_if_missing = true;
79 options.merge_operator = std::make_shared<CountMergeOperator>();
80 options.max_successive_merges = max_successive_merges;
81 Status s;
82 DestroyDB(dbname, Options());
83 // DBWithTTL is not supported in ROCKSDB_LITE
84 #ifndef ROCKSDB_LITE
85 if (ttl) {
86 DBWithTTL* db_with_ttl;
87 s = DBWithTTL::Open(options, dbname, &db_with_ttl);
88 db = db_with_ttl;
89 } else {
90 s = DB::Open(options, dbname, &db);
91 }
92 #else
93 assert(!ttl);
94 s = DB::Open(options, dbname, &db);
95 #endif // !ROCKSDB_LITE
96 if (!s.ok()) {
97 std::cerr << s.ToString() << std::endl;
98 assert(false);
99 }
100 return std::shared_ptr<DB>(db);
101 }
102
103 // Imagine we are maintaining a set of uint64 counters.
104 // Each counter has a distinct name. And we would like
105 // to support four high level operations:
106 // set, add, get and remove
107 // This is a quick implementation without a Merge operation.
108 class Counters {
109
110 protected:
111 std::shared_ptr<DB> db_;
112
113 WriteOptions put_option_;
114 ReadOptions get_option_;
115 WriteOptions delete_option_;
116
117 uint64_t default_;
118
119 public:
Counters(std::shared_ptr<DB> db,uint64_t defaultCount=0)120 explicit Counters(std::shared_ptr<DB> db, uint64_t defaultCount = 0)
121 : db_(db),
122 put_option_(),
123 get_option_(),
124 delete_option_(),
125 default_(defaultCount) {
126 assert(db_);
127 }
128
~Counters()129 virtual ~Counters() {}
130
131 // public interface of Counters.
132 // All four functions return false
133 // if the underlying level db operation failed.
134
135 // mapped to a levedb Put
set(const std::string & key,uint64_t value)136 bool set(const std::string& key, uint64_t value) {
137 // just treat the internal rep of int64 as the string
138 char buf[sizeof(value)];
139 EncodeFixed64(buf, value);
140 Slice slice(buf, sizeof(value));
141 auto s = db_->Put(put_option_, key, slice);
142
143 if (s.ok()) {
144 return true;
145 } else {
146 std::cerr << s.ToString() << std::endl;
147 return false;
148 }
149 }
150
151 // mapped to a rocksdb Delete
remove(const std::string & key)152 bool remove(const std::string& key) {
153 auto s = db_->Delete(delete_option_, key);
154
155 if (s.ok()) {
156 return true;
157 } else {
158 std::cerr << s.ToString() << std::endl;
159 return false;
160 }
161 }
162
163 // mapped to a rocksdb Get
get(const std::string & key,uint64_t * value)164 bool get(const std::string& key, uint64_t* value) {
165 std::string str;
166 auto s = db_->Get(get_option_, key, &str);
167
168 if (s.IsNotFound()) {
169 // return default value if not found;
170 *value = default_;
171 return true;
172 } else if (s.ok()) {
173 // deserialization
174 if (str.size() != sizeof(uint64_t)) {
175 std::cerr << "value corruption\n";
176 return false;
177 }
178 *value = DecodeFixed64(&str[0]);
179 return true;
180 } else {
181 std::cerr << s.ToString() << std::endl;
182 return false;
183 }
184 }
185
186 // 'add' is implemented as get -> modify -> set
187 // An alternative is a single merge operation, see MergeBasedCounters
add(const std::string & key,uint64_t value)188 virtual bool add(const std::string& key, uint64_t value) {
189 uint64_t base = default_;
190 return get(key, &base) && set(key, base + value);
191 }
192
193
194 // convenience functions for testing
assert_set(const std::string & key,uint64_t value)195 void assert_set(const std::string& key, uint64_t value) {
196 assert(set(key, value));
197 }
198
assert_remove(const std::string & key)199 void assert_remove(const std::string& key) { assert(remove(key)); }
200
assert_get(const std::string & key)201 uint64_t assert_get(const std::string& key) {
202 uint64_t value = default_;
203 int result = get(key, &value);
204 assert(result);
205 if (result == 0) exit(1); // Disable unused variable warning.
206 return value;
207 }
208
assert_add(const std::string & key,uint64_t value)209 void assert_add(const std::string& key, uint64_t value) {
210 int result = add(key, value);
211 assert(result);
212 if (result == 0) exit(1); // Disable unused variable warning.
213 }
214 };
215
216 // Implement 'add' directly with the new Merge operation
217 class MergeBasedCounters : public Counters {
218 private:
219 WriteOptions merge_option_; // for merge
220
221 public:
MergeBasedCounters(std::shared_ptr<DB> db,uint64_t defaultCount=0)222 explicit MergeBasedCounters(std::shared_ptr<DB> db, uint64_t defaultCount = 0)
223 : Counters(db, defaultCount),
224 merge_option_() {
225 }
226
227 // mapped to a rocksdb Merge operation
add(const std::string & key,uint64_t value)228 bool add(const std::string& key, uint64_t value) override {
229 char encoded[sizeof(uint64_t)];
230 EncodeFixed64(encoded, value);
231 Slice slice(encoded, sizeof(uint64_t));
232 auto s = db_->Merge(merge_option_, key, slice);
233
234 if (s.ok()) {
235 return true;
236 } else {
237 std::cerr << s.ToString() << std::endl;
238 return false;
239 }
240 }
241 };
242
dumpDb(DB * db)243 void dumpDb(DB* db) {
244 auto it = std::unique_ptr<Iterator>(db->NewIterator(ReadOptions()));
245 for (it->SeekToFirst(); it->Valid(); it->Next()) {
246 //uint64_t value = DecodeFixed64(it->value().data());
247 //std::cout << it->key().ToString() << ": " << value << std::endl;
248 }
249 assert(it->status().ok()); // Check for any errors found during the scan
250 }
251
testCounters(Counters & counters,DB * db,bool test_compaction)252 void testCounters(Counters& counters, DB* db, bool test_compaction) {
253
254 FlushOptions o;
255 o.wait = true;
256
257 counters.assert_set("a", 1);
258
259 if (test_compaction) db->Flush(o);
260
261 assert(counters.assert_get("a") == 1);
262
263 counters.assert_remove("b");
264
265 // defaut value is 0 if non-existent
266 assert(counters.assert_get("b") == 0);
267
268 counters.assert_add("a", 2);
269
270 if (test_compaction) db->Flush(o);
271
272 // 1+2 = 3
273 assert(counters.assert_get("a")== 3);
274
275 dumpDb(db);
276
277 // 1+...+49 = ?
278 uint64_t sum = 0;
279 for (int i = 1; i < 50; i++) {
280 counters.assert_add("b", i);
281 sum += i;
282 }
283 assert(counters.assert_get("b") == sum);
284
285 dumpDb(db);
286
287 if (test_compaction) {
288 db->Flush(o);
289
290 db->CompactRange(CompactRangeOptions(), nullptr, nullptr);
291
292 dumpDb(db);
293
294 assert(counters.assert_get("a")== 3);
295 assert(counters.assert_get("b") == sum);
296 }
297 }
298
testSuccessiveMerge(Counters & counters,size_t max_num_merges,size_t num_merges)299 void testSuccessiveMerge(Counters& counters, size_t max_num_merges,
300 size_t num_merges) {
301
302 counters.assert_remove("z");
303 uint64_t sum = 0;
304
305 for (size_t i = 1; i <= num_merges; ++i) {
306 resetNumMergeOperatorCalls();
307 counters.assert_add("z", i);
308 sum += i;
309
310 if (i % (max_num_merges + 1) == 0) {
311 assert(num_merge_operator_calls == max_num_merges + 1);
312 } else {
313 assert(num_merge_operator_calls == 0);
314 }
315
316 resetNumMergeOperatorCalls();
317 assert(counters.assert_get("z") == sum);
318 assert(num_merge_operator_calls == i % (max_num_merges + 1));
319 }
320 }
321
testPartialMerge(Counters * counters,DB * db,size_t max_merge,size_t min_merge,size_t count)322 void testPartialMerge(Counters* counters, DB* db, size_t max_merge,
323 size_t min_merge, size_t count) {
324 FlushOptions o;
325 o.wait = true;
326
327 // Test case 1: partial merge should be called when the number of merge
328 // operands exceeds the threshold.
329 uint64_t tmp_sum = 0;
330 resetNumPartialMergeCalls();
331 for (size_t i = 1; i <= count; i++) {
332 counters->assert_add("b", i);
333 tmp_sum += i;
334 }
335 db->Flush(o);
336 db->CompactRange(CompactRangeOptions(), nullptr, nullptr);
337 ASSERT_EQ(tmp_sum, counters->assert_get("b"));
338 if (count > max_merge) {
339 // in this case, FullMerge should be called instead.
340 ASSERT_EQ(num_partial_merge_calls, 0U);
341 } else {
342 // if count >= min_merge, then partial merge should be called once.
343 ASSERT_EQ((count >= min_merge), (num_partial_merge_calls == 1));
344 }
345
346 // Test case 2: partial merge should not be called when a put is found.
347 resetNumPartialMergeCalls();
348 tmp_sum = 0;
349 db->Put(ROCKSDB_NAMESPACE::WriteOptions(), "c", "10");
350 for (size_t i = 1; i <= count; i++) {
351 counters->assert_add("c", i);
352 tmp_sum += i;
353 }
354 db->Flush(o);
355 db->CompactRange(CompactRangeOptions(), nullptr, nullptr);
356 ASSERT_EQ(tmp_sum, counters->assert_get("c"));
357 ASSERT_EQ(num_partial_merge_calls, 0U);
358 }
359
testSingleBatchSuccessiveMerge(DB * db,size_t max_num_merges,size_t num_merges)360 void testSingleBatchSuccessiveMerge(DB* db, size_t max_num_merges,
361 size_t num_merges) {
362 assert(num_merges > max_num_merges);
363
364 Slice key("BatchSuccessiveMerge");
365 uint64_t merge_value = 1;
366 char buf[sizeof(merge_value)];
367 EncodeFixed64(buf, merge_value);
368 Slice merge_value_slice(buf, sizeof(merge_value));
369
370 // Create the batch
371 WriteBatch batch;
372 for (size_t i = 0; i < num_merges; ++i) {
373 batch.Merge(key, merge_value_slice);
374 }
375
376 // Apply to memtable and count the number of merges
377 resetNumMergeOperatorCalls();
378 {
379 Status s = db->Write(WriteOptions(), &batch);
380 assert(s.ok());
381 }
382 ASSERT_EQ(
383 num_merge_operator_calls,
384 static_cast<size_t>(num_merges - (num_merges % (max_num_merges + 1))));
385
386 // Get the value
387 resetNumMergeOperatorCalls();
388 std::string get_value_str;
389 {
390 Status s = db->Get(ReadOptions(), key, &get_value_str);
391 assert(s.ok());
392 }
393 assert(get_value_str.size() == sizeof(uint64_t));
394 uint64_t get_value = DecodeFixed64(&get_value_str[0]);
395 ASSERT_EQ(get_value, num_merges * merge_value);
396 ASSERT_EQ(num_merge_operator_calls,
397 static_cast<size_t>((num_merges % (max_num_merges + 1))));
398 }
399
runTest(const std::string & dbname,const bool use_ttl=false)400 void runTest(const std::string& dbname, const bool use_ttl = false) {
401
402 {
403 auto db = OpenDb(dbname, use_ttl);
404
405 {
406 Counters counters(db, 0);
407 testCounters(counters, db.get(), true);
408 }
409
410 {
411 MergeBasedCounters counters(db, 0);
412 testCounters(counters, db.get(), use_compression);
413 }
414 }
415
416 DestroyDB(dbname, Options());
417
418 {
419 size_t max_merge = 5;
420 auto db = OpenDb(dbname, use_ttl, max_merge);
421 MergeBasedCounters counters(db, 0);
422 testCounters(counters, db.get(), use_compression);
423 testSuccessiveMerge(counters, max_merge, max_merge * 2);
424 testSingleBatchSuccessiveMerge(db.get(), 5, 7);
425 DestroyDB(dbname, Options());
426 }
427
428 {
429 size_t max_merge = 100;
430 // Min merge is hard-coded to 2.
431 uint32_t min_merge = 2;
432 for (uint32_t count = min_merge - 1; count <= min_merge + 1; count++) {
433 auto db = OpenDb(dbname, use_ttl, max_merge);
434 MergeBasedCounters counters(db, 0);
435 testPartialMerge(&counters, db.get(), max_merge, min_merge, count);
436 DestroyDB(dbname, Options());
437 }
438 {
439 auto db = OpenDb(dbname, use_ttl, max_merge);
440 MergeBasedCounters counters(db, 0);
441 testPartialMerge(&counters, db.get(), max_merge, min_merge,
442 min_merge * 10);
443 DestroyDB(dbname, Options());
444 }
445 }
446
447 {
448 {
449 auto db = OpenDb(dbname);
450 MergeBasedCounters counters(db, 0);
451 counters.add("test-key", 1);
452 counters.add("test-key", 1);
453 counters.add("test-key", 1);
454 db->CompactRange(CompactRangeOptions(), nullptr, nullptr);
455 }
456
457 DB* reopen_db;
458 ASSERT_OK(DB::Open(Options(), dbname, &reopen_db));
459 std::string value;
460 ASSERT_TRUE(!(reopen_db->Get(ReadOptions(), "test-key", &value).ok()));
461 delete reopen_db;
462 DestroyDB(dbname, Options());
463 }
464
465 /* Temporary remove this test
466 {
467 std::cout << "Test merge-operator not set after reopen (recovery case)\n";
468 {
469 auto db = OpenDb(dbname);
470 MergeBasedCounters counters(db, 0);
471 counters.add("test-key", 1);
472 counters.add("test-key", 1);
473 counters.add("test-key", 1);
474 }
475
476 DB* reopen_db;
477 ASSERT_TRUE(DB::Open(Options(), dbname, &reopen_db).IsInvalidArgument());
478 }
479 */
480 }
481
// Runs the full scenario with runTest()'s default of a non-TTL database.
TEST_F(MergeTest, MergeDbTest) {
  const std::string db_path = test::PerThreadDBPath("merge_testdb");
  runTest(db_path);
}
485
486 #ifndef ROCKSDB_LITE
TEST_F(MergeTest,MergeDbTtlTest)487 TEST_F(MergeTest, MergeDbTtlTest) {
488 runTest(test::PerThreadDBPath("merge_testdbttl"),
489 true); // Run test on TTL database
490 }
491 #endif // !ROCKSDB_LITE
492
493 } // namespace ROCKSDB_NAMESPACE
494
main(int argc,char ** argv)495 int main(int argc, char** argv) {
496 ROCKSDB_NAMESPACE::use_compression = false;
497 if (argc > 1) {
498 ROCKSDB_NAMESPACE::use_compression = true;
499 }
500
501 ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
502 ::testing::InitGoogleTest(&argc, argv);
503 return RUN_ALL_TESTS();
504 }
505