1 // Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
2 //  This source code is licensed under both the GPLv2 (found in the
3 //  COPYING file in the root directory) and Apache 2.0 License
4 //  (found in the LICENSE.Apache file in the root directory).
5 
6 #ifndef ROCKSDB_LITE
7 
8 #include "utilities/transactions/transaction_base.h"
9 
10 #include <cinttypes>
11 
12 #include "db/column_family.h"
13 #include "db/db_impl/db_impl.h"
14 #include "rocksdb/comparator.h"
15 #include "rocksdb/db.h"
16 #include "rocksdb/status.h"
17 #include "util/cast_util.h"
18 #include "util/string_util.h"
19 
20 namespace ROCKSDB_NAMESPACE {
21 
TransactionBaseImpl(DB * db,const WriteOptions & write_options)22 TransactionBaseImpl::TransactionBaseImpl(DB* db,
23                                          const WriteOptions& write_options)
24     : db_(db),
25       dbimpl_(static_cast_with_check<DBImpl, DB>(db)),
26       write_options_(write_options),
27       cmp_(GetColumnFamilyUserComparator(db->DefaultColumnFamily())),
28       start_time_(db_->GetEnv()->NowMicros()),
29       write_batch_(cmp_, 0, true, 0),
30       indexing_enabled_(true) {
31   assert(dynamic_cast<DBImpl*>(db_) != nullptr);
32   log_number_ = 0;
33   if (dbimpl_->allow_2pc()) {
34     InitWriteBatch();
35   }
36 }
37 
~TransactionBaseImpl()38 TransactionBaseImpl::~TransactionBaseImpl() {
39   // Release snapshot if snapshot is set
40   SetSnapshotInternal(nullptr);
41 }
42 
Clear()43 void TransactionBaseImpl::Clear() {
44   save_points_.reset(nullptr);
45   write_batch_.Clear();
46   commit_time_batch_.Clear();
47   tracked_keys_.clear();
48   num_puts_ = 0;
49   num_deletes_ = 0;
50   num_merges_ = 0;
51 
52   if (dbimpl_->allow_2pc()) {
53     InitWriteBatch();
54   }
55 }
56 
Reinitialize(DB * db,const WriteOptions & write_options)57 void TransactionBaseImpl::Reinitialize(DB* db,
58                                        const WriteOptions& write_options) {
59   Clear();
60   ClearSnapshot();
61   id_ = 0;
62   db_ = db;
63   name_.clear();
64   log_number_ = 0;
65   write_options_ = write_options;
66   start_time_ = db_->GetEnv()->NowMicros();
67   indexing_enabled_ = true;
68   cmp_ = GetColumnFamilyUserComparator(db_->DefaultColumnFamily());
69 }
70 
SetSnapshot()71 void TransactionBaseImpl::SetSnapshot() {
72   const Snapshot* snapshot = dbimpl_->GetSnapshotForWriteConflictBoundary();
73   SetSnapshotInternal(snapshot);
74 }
75 
SetSnapshotInternal(const Snapshot * snapshot)76 void TransactionBaseImpl::SetSnapshotInternal(const Snapshot* snapshot) {
77   // Set a custom deleter for the snapshot_ SharedPtr as the snapshot needs to
78   // be released, not deleted when it is no longer referenced.
79   snapshot_.reset(snapshot, std::bind(&TransactionBaseImpl::ReleaseSnapshot,
80                                       this, std::placeholders::_1, db_));
81   snapshot_needed_ = false;
82   snapshot_notifier_ = nullptr;
83 }
84 
SetSnapshotOnNextOperation(std::shared_ptr<TransactionNotifier> notifier)85 void TransactionBaseImpl::SetSnapshotOnNextOperation(
86     std::shared_ptr<TransactionNotifier> notifier) {
87   snapshot_needed_ = true;
88   snapshot_notifier_ = notifier;
89 }
90 
SetSnapshotIfNeeded()91 void TransactionBaseImpl::SetSnapshotIfNeeded() {
92   if (snapshot_needed_) {
93     std::shared_ptr<TransactionNotifier> notifier = snapshot_notifier_;
94     SetSnapshot();
95     if (notifier != nullptr) {
96       notifier->SnapshotCreated(GetSnapshot());
97     }
98   }
99 }
100 
TryLock(ColumnFamilyHandle * column_family,const SliceParts & key,bool read_only,bool exclusive,const bool do_validate,const bool assume_tracked)101 Status TransactionBaseImpl::TryLock(ColumnFamilyHandle* column_family,
102                                     const SliceParts& key, bool read_only,
103                                     bool exclusive, const bool do_validate,
104                                     const bool assume_tracked) {
105   size_t key_size = 0;
106   for (int i = 0; i < key.num_parts; ++i) {
107     key_size += key.parts[i].size();
108   }
109 
110   std::string str;
111   str.reserve(key_size);
112 
113   for (int i = 0; i < key.num_parts; ++i) {
114     str.append(key.parts[i].data(), key.parts[i].size());
115   }
116 
117   return TryLock(column_family, str, read_only, exclusive, do_validate,
118                  assume_tracked);
119 }
120 
SetSavePoint()121 void TransactionBaseImpl::SetSavePoint() {
122   if (save_points_ == nullptr) {
123     save_points_.reset(new std::stack<TransactionBaseImpl::SavePoint, autovector<TransactionBaseImpl::SavePoint>>());
124   }
125   save_points_->emplace(snapshot_, snapshot_needed_, snapshot_notifier_,
126                         num_puts_, num_deletes_, num_merges_);
127   write_batch_.SetSavePoint();
128 }
129 
RollbackToSavePoint()130 Status TransactionBaseImpl::RollbackToSavePoint() {
131   if (save_points_ != nullptr && save_points_->size() > 0) {
132     // Restore saved SavePoint
133     TransactionBaseImpl::SavePoint& save_point = save_points_->top();
134     snapshot_ = save_point.snapshot_;
135     snapshot_needed_ = save_point.snapshot_needed_;
136     snapshot_notifier_ = save_point.snapshot_notifier_;
137     num_puts_ = save_point.num_puts_;
138     num_deletes_ = save_point.num_deletes_;
139     num_merges_ = save_point.num_merges_;
140 
141     // Rollback batch
142     Status s = write_batch_.RollbackToSavePoint();
143     assert(s.ok());
144 
145     // Rollback any keys that were tracked since the last savepoint
146     const TransactionKeyMap& key_map = save_point.new_keys_;
147     for (const auto& key_map_iter : key_map) {
148       uint32_t column_family_id = key_map_iter.first;
149       auto& keys = key_map_iter.second;
150 
151       auto& cf_tracked_keys = tracked_keys_[column_family_id];
152 
153       for (const auto& key_iter : keys) {
154         const std::string& key = key_iter.first;
155         uint32_t num_reads = key_iter.second.num_reads;
156         uint32_t num_writes = key_iter.second.num_writes;
157 
158         auto tracked_keys_iter = cf_tracked_keys.find(key);
159         assert(tracked_keys_iter != cf_tracked_keys.end());
160 
161         // Decrement the total reads/writes of this key by the number of
162         // reads/writes done since the last SavePoint.
163         if (num_reads > 0) {
164           assert(tracked_keys_iter->second.num_reads >= num_reads);
165           tracked_keys_iter->second.num_reads -= num_reads;
166         }
167         if (num_writes > 0) {
168           assert(tracked_keys_iter->second.num_writes >= num_writes);
169           tracked_keys_iter->second.num_writes -= num_writes;
170         }
171         if (tracked_keys_iter->second.num_reads == 0 &&
172             tracked_keys_iter->second.num_writes == 0) {
173           cf_tracked_keys.erase(tracked_keys_iter);
174         }
175       }
176     }
177 
178     save_points_->pop();
179 
180     return s;
181   } else {
182     assert(write_batch_.RollbackToSavePoint().IsNotFound());
183     return Status::NotFound();
184   }
185 }
186 
PopSavePoint()187 Status TransactionBaseImpl::PopSavePoint() {
188   if (save_points_ == nullptr ||
189       save_points_->empty()) {
190     // No SavePoint yet.
191     assert(write_batch_.PopSavePoint().IsNotFound());
192     return Status::NotFound();
193   }
194 
195   assert(!save_points_->empty());
196   // If there is another savepoint A below the current savepoint B, then A needs
197   // to inherit tracked_keys in B so that if we rollback to savepoint A, we
198   // remember to unlock keys in B. If there is no other savepoint below, then we
199   // can safely discard savepoint info.
200   if (save_points_->size() == 1) {
201     save_points_->pop();
202   } else {
203     TransactionBaseImpl::SavePoint top;
204     std::swap(top, save_points_->top());
205     save_points_->pop();
206 
207     const TransactionKeyMap& curr_cf_key_map = top.new_keys_;
208     TransactionKeyMap& prev_cf_key_map = save_points_->top().new_keys_;
209 
210     for (const auto& curr_cf_key_iter : curr_cf_key_map) {
211       uint32_t column_family_id = curr_cf_key_iter.first;
212       const std::unordered_map<std::string, TransactionKeyMapInfo>& curr_keys =
213           curr_cf_key_iter.second;
214 
215       // If cfid was not previously tracked, just copy everything over.
216       auto prev_keys_iter = prev_cf_key_map.find(column_family_id);
217       if (prev_keys_iter == prev_cf_key_map.end()) {
218         prev_cf_key_map.emplace(curr_cf_key_iter);
219       } else {
220         std::unordered_map<std::string, TransactionKeyMapInfo>& prev_keys =
221             prev_keys_iter->second;
222         for (const auto& key_iter : curr_keys) {
223           const std::string& key = key_iter.first;
224           const TransactionKeyMapInfo& info = key_iter.second;
225           // If key was not previously tracked, just copy the whole struct over.
226           // Otherwise, some merging needs to occur.
227           auto prev_info = prev_keys.find(key);
228           if (prev_info == prev_keys.end()) {
229             prev_keys.emplace(key_iter);
230           } else {
231             prev_info->second.Merge(info);
232           }
233         }
234       }
235     }
236   }
237 
238   return write_batch_.PopSavePoint();
239 }
240 
Get(const ReadOptions & read_options,ColumnFamilyHandle * column_family,const Slice & key,std::string * value)241 Status TransactionBaseImpl::Get(const ReadOptions& read_options,
242                                 ColumnFamilyHandle* column_family,
243                                 const Slice& key, std::string* value) {
244   assert(value != nullptr);
245   PinnableSlice pinnable_val(value);
246   assert(!pinnable_val.IsPinned());
247   auto s = Get(read_options, column_family, key, &pinnable_val);
248   if (s.ok() && pinnable_val.IsPinned()) {
249     value->assign(pinnable_val.data(), pinnable_val.size());
250   }  // else value is already assigned
251   return s;
252 }
253 
Get(const ReadOptions & read_options,ColumnFamilyHandle * column_family,const Slice & key,PinnableSlice * pinnable_val)254 Status TransactionBaseImpl::Get(const ReadOptions& read_options,
255                                 ColumnFamilyHandle* column_family,
256                                 const Slice& key, PinnableSlice* pinnable_val) {
257   return write_batch_.GetFromBatchAndDB(db_, read_options, column_family, key,
258                                         pinnable_val);
259 }
260 
GetForUpdate(const ReadOptions & read_options,ColumnFamilyHandle * column_family,const Slice & key,std::string * value,bool exclusive,const bool do_validate)261 Status TransactionBaseImpl::GetForUpdate(const ReadOptions& read_options,
262                                          ColumnFamilyHandle* column_family,
263                                          const Slice& key, std::string* value,
264                                          bool exclusive,
265                                          const bool do_validate) {
266   if (!do_validate && read_options.snapshot != nullptr) {
267     return Status::InvalidArgument(
268         "If do_validate is false then GetForUpdate with snapshot is not "
269         "defined.");
270   }
271   Status s =
272       TryLock(column_family, key, true /* read_only */, exclusive, do_validate);
273 
274   if (s.ok() && value != nullptr) {
275     assert(value != nullptr);
276     PinnableSlice pinnable_val(value);
277     assert(!pinnable_val.IsPinned());
278     s = Get(read_options, column_family, key, &pinnable_val);
279     if (s.ok() && pinnable_val.IsPinned()) {
280       value->assign(pinnable_val.data(), pinnable_val.size());
281     }  // else value is already assigned
282   }
283   return s;
284 }
285 
GetForUpdate(const ReadOptions & read_options,ColumnFamilyHandle * column_family,const Slice & key,PinnableSlice * pinnable_val,bool exclusive,const bool do_validate)286 Status TransactionBaseImpl::GetForUpdate(const ReadOptions& read_options,
287                                          ColumnFamilyHandle* column_family,
288                                          const Slice& key,
289                                          PinnableSlice* pinnable_val,
290                                          bool exclusive,
291                                          const bool do_validate) {
292   if (!do_validate && read_options.snapshot != nullptr) {
293     return Status::InvalidArgument(
294         "If do_validate is false then GetForUpdate with snapshot is not "
295         "defined.");
296   }
297   Status s =
298       TryLock(column_family, key, true /* read_only */, exclusive, do_validate);
299 
300   if (s.ok() && pinnable_val != nullptr) {
301     s = Get(read_options, column_family, key, pinnable_val);
302   }
303   return s;
304 }
305 
MultiGet(const ReadOptions & read_options,const std::vector<ColumnFamilyHandle * > & column_family,const std::vector<Slice> & keys,std::vector<std::string> * values)306 std::vector<Status> TransactionBaseImpl::MultiGet(
307     const ReadOptions& read_options,
308     const std::vector<ColumnFamilyHandle*>& column_family,
309     const std::vector<Slice>& keys, std::vector<std::string>* values) {
310   size_t num_keys = keys.size();
311   values->resize(num_keys);
312 
313   std::vector<Status> stat_list(num_keys);
314   for (size_t i = 0; i < num_keys; ++i) {
315     std::string* value = values ? &(*values)[i] : nullptr;
316     stat_list[i] = Get(read_options, column_family[i], keys[i], value);
317   }
318 
319   return stat_list;
320 }
321 
MultiGet(const ReadOptions & read_options,ColumnFamilyHandle * column_family,const size_t num_keys,const Slice * keys,PinnableSlice * values,Status * statuses,const bool sorted_input)322 void TransactionBaseImpl::MultiGet(const ReadOptions& read_options,
323                                    ColumnFamilyHandle* column_family,
324                                    const size_t num_keys, const Slice* keys,
325                                    PinnableSlice* values, Status* statuses,
326                                    const bool sorted_input) {
327   write_batch_.MultiGetFromBatchAndDB(db_, read_options, column_family,
328                                       num_keys, keys, values, statuses,
329                                       sorted_input);
330 }
331 
MultiGetForUpdate(const ReadOptions & read_options,const std::vector<ColumnFamilyHandle * > & column_family,const std::vector<Slice> & keys,std::vector<std::string> * values)332 std::vector<Status> TransactionBaseImpl::MultiGetForUpdate(
333     const ReadOptions& read_options,
334     const std::vector<ColumnFamilyHandle*>& column_family,
335     const std::vector<Slice>& keys, std::vector<std::string>* values) {
336   // Regardless of whether the MultiGet succeeded, track these keys.
337   size_t num_keys = keys.size();
338   values->resize(num_keys);
339 
340   // Lock all keys
341   for (size_t i = 0; i < num_keys; ++i) {
342     Status s = TryLock(column_family[i], keys[i], true /* read_only */,
343                        true /* exclusive */);
344     if (!s.ok()) {
345       // Fail entire multiget if we cannot lock all keys
346       return std::vector<Status>(num_keys, s);
347     }
348   }
349 
350   // TODO(agiardullo): optimize multiget?
351   std::vector<Status> stat_list(num_keys);
352   for (size_t i = 0; i < num_keys; ++i) {
353     std::string* value = values ? &(*values)[i] : nullptr;
354     stat_list[i] = Get(read_options, column_family[i], keys[i], value);
355   }
356 
357   return stat_list;
358 }
359 
GetIterator(const ReadOptions & read_options)360 Iterator* TransactionBaseImpl::GetIterator(const ReadOptions& read_options) {
361   Iterator* db_iter = db_->NewIterator(read_options);
362   assert(db_iter);
363 
364   return write_batch_.NewIteratorWithBase(db_iter);
365 }
366 
GetIterator(const ReadOptions & read_options,ColumnFamilyHandle * column_family)367 Iterator* TransactionBaseImpl::GetIterator(const ReadOptions& read_options,
368                                            ColumnFamilyHandle* column_family) {
369   Iterator* db_iter = db_->NewIterator(read_options, column_family);
370   assert(db_iter);
371 
372   return write_batch_.NewIteratorWithBase(column_family, db_iter,
373                                           &read_options);
374 }
375 
Put(ColumnFamilyHandle * column_family,const Slice & key,const Slice & value,const bool assume_tracked)376 Status TransactionBaseImpl::Put(ColumnFamilyHandle* column_family,
377                                 const Slice& key, const Slice& value,
378                                 const bool assume_tracked) {
379   const bool do_validate = !assume_tracked;
380   Status s = TryLock(column_family, key, false /* read_only */,
381                      true /* exclusive */, do_validate, assume_tracked);
382 
383   if (s.ok()) {
384     s = GetBatchForWrite()->Put(column_family, key, value);
385     if (s.ok()) {
386       num_puts_++;
387     }
388   }
389 
390   return s;
391 }
392 
Put(ColumnFamilyHandle * column_family,const SliceParts & key,const SliceParts & value,const bool assume_tracked)393 Status TransactionBaseImpl::Put(ColumnFamilyHandle* column_family,
394                                 const SliceParts& key, const SliceParts& value,
395                                 const bool assume_tracked) {
396   const bool do_validate = !assume_tracked;
397   Status s = TryLock(column_family, key, false /* read_only */,
398                      true /* exclusive */, do_validate, assume_tracked);
399 
400   if (s.ok()) {
401     s = GetBatchForWrite()->Put(column_family, key, value);
402     if (s.ok()) {
403       num_puts_++;
404     }
405   }
406 
407   return s;
408 }
409 
Merge(ColumnFamilyHandle * column_family,const Slice & key,const Slice & value,const bool assume_tracked)410 Status TransactionBaseImpl::Merge(ColumnFamilyHandle* column_family,
411                                   const Slice& key, const Slice& value,
412                                   const bool assume_tracked) {
413   const bool do_validate = !assume_tracked;
414   Status s = TryLock(column_family, key, false /* read_only */,
415                      true /* exclusive */, do_validate, assume_tracked);
416 
417   if (s.ok()) {
418     s = GetBatchForWrite()->Merge(column_family, key, value);
419     if (s.ok()) {
420       num_merges_++;
421     }
422   }
423 
424   return s;
425 }
426 
Delete(ColumnFamilyHandle * column_family,const Slice & key,const bool assume_tracked)427 Status TransactionBaseImpl::Delete(ColumnFamilyHandle* column_family,
428                                    const Slice& key,
429                                    const bool assume_tracked) {
430   const bool do_validate = !assume_tracked;
431   Status s = TryLock(column_family, key, false /* read_only */,
432                      true /* exclusive */, do_validate, assume_tracked);
433 
434   if (s.ok()) {
435     s = GetBatchForWrite()->Delete(column_family, key);
436     if (s.ok()) {
437       num_deletes_++;
438     }
439   }
440 
441   return s;
442 }
443 
Delete(ColumnFamilyHandle * column_family,const SliceParts & key,const bool assume_tracked)444 Status TransactionBaseImpl::Delete(ColumnFamilyHandle* column_family,
445                                    const SliceParts& key,
446                                    const bool assume_tracked) {
447   const bool do_validate = !assume_tracked;
448   Status s = TryLock(column_family, key, false /* read_only */,
449                      true /* exclusive */, do_validate, assume_tracked);
450 
451   if (s.ok()) {
452     s = GetBatchForWrite()->Delete(column_family, key);
453     if (s.ok()) {
454       num_deletes_++;
455     }
456   }
457 
458   return s;
459 }
460 
SingleDelete(ColumnFamilyHandle * column_family,const Slice & key,const bool assume_tracked)461 Status TransactionBaseImpl::SingleDelete(ColumnFamilyHandle* column_family,
462                                          const Slice& key,
463                                          const bool assume_tracked) {
464   const bool do_validate = !assume_tracked;
465   Status s = TryLock(column_family, key, false /* read_only */,
466                      true /* exclusive */, do_validate, assume_tracked);
467 
468   if (s.ok()) {
469     s = GetBatchForWrite()->SingleDelete(column_family, key);
470     if (s.ok()) {
471       num_deletes_++;
472     }
473   }
474 
475   return s;
476 }
477 
SingleDelete(ColumnFamilyHandle * column_family,const SliceParts & key,const bool assume_tracked)478 Status TransactionBaseImpl::SingleDelete(ColumnFamilyHandle* column_family,
479                                          const SliceParts& key,
480                                          const bool assume_tracked) {
481   const bool do_validate = !assume_tracked;
482   Status s = TryLock(column_family, key, false /* read_only */,
483                      true /* exclusive */, do_validate, assume_tracked);
484 
485   if (s.ok()) {
486     s = GetBatchForWrite()->SingleDelete(column_family, key);
487     if (s.ok()) {
488       num_deletes_++;
489     }
490   }
491 
492   return s;
493 }
494 
PutUntracked(ColumnFamilyHandle * column_family,const Slice & key,const Slice & value)495 Status TransactionBaseImpl::PutUntracked(ColumnFamilyHandle* column_family,
496                                          const Slice& key, const Slice& value) {
497   Status s = TryLock(column_family, key, false /* read_only */,
498                      true /* exclusive */, false /* do_validate */);
499 
500   if (s.ok()) {
501     s = GetBatchForWrite()->Put(column_family, key, value);
502     if (s.ok()) {
503       num_puts_++;
504     }
505   }
506 
507   return s;
508 }
509 
PutUntracked(ColumnFamilyHandle * column_family,const SliceParts & key,const SliceParts & value)510 Status TransactionBaseImpl::PutUntracked(ColumnFamilyHandle* column_family,
511                                          const SliceParts& key,
512                                          const SliceParts& value) {
513   Status s = TryLock(column_family, key, false /* read_only */,
514                      true /* exclusive */, false /* do_validate */);
515 
516   if (s.ok()) {
517     s = GetBatchForWrite()->Put(column_family, key, value);
518     if (s.ok()) {
519       num_puts_++;
520     }
521   }
522 
523   return s;
524 }
525 
MergeUntracked(ColumnFamilyHandle * column_family,const Slice & key,const Slice & value)526 Status TransactionBaseImpl::MergeUntracked(ColumnFamilyHandle* column_family,
527                                            const Slice& key,
528                                            const Slice& value) {
529   Status s = TryLock(column_family, key, false /* read_only */,
530                      true /* exclusive */, false /* do_validate */);
531 
532   if (s.ok()) {
533     s = GetBatchForWrite()->Merge(column_family, key, value);
534     if (s.ok()) {
535       num_merges_++;
536     }
537   }
538 
539   return s;
540 }
541 
DeleteUntracked(ColumnFamilyHandle * column_family,const Slice & key)542 Status TransactionBaseImpl::DeleteUntracked(ColumnFamilyHandle* column_family,
543                                             const Slice& key) {
544   Status s = TryLock(column_family, key, false /* read_only */,
545                      true /* exclusive */, false /* do_validate */);
546 
547   if (s.ok()) {
548     s = GetBatchForWrite()->Delete(column_family, key);
549     if (s.ok()) {
550       num_deletes_++;
551     }
552   }
553 
554   return s;
555 }
556 
DeleteUntracked(ColumnFamilyHandle * column_family,const SliceParts & key)557 Status TransactionBaseImpl::DeleteUntracked(ColumnFamilyHandle* column_family,
558                                             const SliceParts& key) {
559   Status s = TryLock(column_family, key, false /* read_only */,
560                      true /* exclusive */, false /* do_validate */);
561 
562   if (s.ok()) {
563     s = GetBatchForWrite()->Delete(column_family, key);
564     if (s.ok()) {
565       num_deletes_++;
566     }
567   }
568 
569   return s;
570 }
571 
SingleDeleteUntracked(ColumnFamilyHandle * column_family,const Slice & key)572 Status TransactionBaseImpl::SingleDeleteUntracked(
573     ColumnFamilyHandle* column_family, const Slice& key) {
574   Status s = TryLock(column_family, key, false /* read_only */,
575                      true /* exclusive */, false /* do_validate */);
576 
577   if (s.ok()) {
578     s = GetBatchForWrite()->SingleDelete(column_family, key);
579     if (s.ok()) {
580       num_deletes_++;
581     }
582   }
583 
584   return s;
585 }
586 
PutLogData(const Slice & blob)587 void TransactionBaseImpl::PutLogData(const Slice& blob) {
588   write_batch_.PutLogData(blob);
589 }
590 
GetWriteBatch()591 WriteBatchWithIndex* TransactionBaseImpl::GetWriteBatch() {
592   return &write_batch_;
593 }
594 
GetElapsedTime() const595 uint64_t TransactionBaseImpl::GetElapsedTime() const {
596   return (db_->GetEnv()->NowMicros() - start_time_) / 1000;
597 }
598 
GetNumPuts() const599 uint64_t TransactionBaseImpl::GetNumPuts() const { return num_puts_; }
600 
GetNumDeletes() const601 uint64_t TransactionBaseImpl::GetNumDeletes() const { return num_deletes_; }
602 
GetNumMerges() const603 uint64_t TransactionBaseImpl::GetNumMerges() const { return num_merges_; }
604 
GetNumKeys() const605 uint64_t TransactionBaseImpl::GetNumKeys() const {
606   uint64_t count = 0;
607 
608   // sum up locked keys in all column families
609   for (const auto& key_map_iter : tracked_keys_) {
610     const auto& keys = key_map_iter.second;
611     count += keys.size();
612   }
613 
614   return count;
615 }
616 
TrackKey(uint32_t cfh_id,const std::string & key,SequenceNumber seq,bool read_only,bool exclusive)617 void TransactionBaseImpl::TrackKey(uint32_t cfh_id, const std::string& key,
618                                    SequenceNumber seq, bool read_only,
619                                    bool exclusive) {
620   // Update map of all tracked keys for this transaction
621   TrackKey(&tracked_keys_, cfh_id, key, seq, read_only, exclusive);
622 
623   if (save_points_ != nullptr && !save_points_->empty()) {
624     // Update map of tracked keys in this SavePoint
625     TrackKey(&save_points_->top().new_keys_, cfh_id, key, seq, read_only,
626              exclusive);
627   }
628 }
629 
630 // Add a key to the given TransactionKeyMap
631 // seq for pessimistic transactions is the sequence number from which we know
632 // there has not been a concurrent update to the key.
TrackKey(TransactionKeyMap * key_map,uint32_t cfh_id,const std::string & key,SequenceNumber seq,bool read_only,bool exclusive)633 void TransactionBaseImpl::TrackKey(TransactionKeyMap* key_map, uint32_t cfh_id,
634                                    const std::string& key, SequenceNumber seq,
635                                    bool read_only, bool exclusive) {
636   auto& cf_key_map = (*key_map)[cfh_id];
637 #ifdef __cpp_lib_unordered_map_try_emplace
638   // use c++17's try_emplace if available, to avoid rehashing the key
639   // in case it is not already in the map
640   auto result = cf_key_map.try_emplace(key, seq);
641   auto iter = result.first;
642   if (!result.second && seq < iter->second.seq) {
643     // Now tracking this key with an earlier sequence number
644     iter->second.seq = seq;
645   }
646 #else
647   auto iter = cf_key_map.find(key);
648   if (iter == cf_key_map.end()) {
649     auto result = cf_key_map.emplace(key, TransactionKeyMapInfo(seq));
650     iter = result.first;
651   } else if (seq < iter->second.seq) {
652     // Now tracking this key with an earlier sequence number
653     iter->second.seq = seq;
654   }
655 #endif
656   // else we do not update the seq. The smaller the tracked seq, the stronger it
657   // the guarantee since it implies from the seq onward there has not been a
658   // concurrent update to the key. So we update the seq if it implies stronger
659   // guarantees, i.e., if it is smaller than the existing tracked seq.
660 
661   if (read_only) {
662     iter->second.num_reads++;
663   } else {
664     iter->second.num_writes++;
665   }
666   iter->second.exclusive |= exclusive;
667 }
668 
669 std::unique_ptr<TransactionKeyMap>
GetTrackedKeysSinceSavePoint()670 TransactionBaseImpl::GetTrackedKeysSinceSavePoint() {
671   if (save_points_ != nullptr && !save_points_->empty()) {
672     // Examine the number of reads/writes performed on all keys written
673     // since the last SavePoint and compare to the total number of reads/writes
674     // for each key.
675     TransactionKeyMap* result = new TransactionKeyMap();
676     for (const auto& key_map_iter : save_points_->top().new_keys_) {
677       uint32_t column_family_id = key_map_iter.first;
678       auto& keys = key_map_iter.second;
679 
680       auto& cf_tracked_keys = tracked_keys_[column_family_id];
681 
682       for (const auto& key_iter : keys) {
683         const std::string& key = key_iter.first;
684         uint32_t num_reads = key_iter.second.num_reads;
685         uint32_t num_writes = key_iter.second.num_writes;
686 
687         auto total_key_info = cf_tracked_keys.find(key);
688         assert(total_key_info != cf_tracked_keys.end());
689         assert(total_key_info->second.num_reads >= num_reads);
690         assert(total_key_info->second.num_writes >= num_writes);
691 
692         if (total_key_info->second.num_reads == num_reads &&
693             total_key_info->second.num_writes == num_writes) {
694           // All the reads/writes to this key were done in the last savepoint.
695           bool read_only = (num_writes == 0);
696           TrackKey(result, column_family_id, key, key_iter.second.seq,
697                    read_only, key_iter.second.exclusive);
698         }
699       }
700     }
701     return std::unique_ptr<TransactionKeyMap>(result);
702   }
703 
704   // No SavePoint
705   return nullptr;
706 }
707 
708 // Gets the write batch that should be used for Put/Merge/Deletes.
709 //
710 // Returns either a WriteBatch or WriteBatchWithIndex depending on whether
711 // DisableIndexing() has been called.
GetBatchForWrite()712 WriteBatchBase* TransactionBaseImpl::GetBatchForWrite() {
713   if (indexing_enabled_) {
714     // Use WriteBatchWithIndex
715     return &write_batch_;
716   } else {
717     // Don't use WriteBatchWithIndex. Return base WriteBatch.
718     return write_batch_.GetWriteBatch();
719   }
720 }
721 
ReleaseSnapshot(const Snapshot * snapshot,DB * db)722 void TransactionBaseImpl::ReleaseSnapshot(const Snapshot* snapshot, DB* db) {
723   if (snapshot != nullptr) {
724     ROCKS_LOG_DETAILS(dbimpl_->immutable_db_options().info_log,
725                       "ReleaseSnapshot %" PRIu64 " Set",
726                       snapshot->GetSequenceNumber());
727     db->ReleaseSnapshot(snapshot);
728   }
729 }
730 
// Reverses one earlier GetForUpdate() of `key` in `column_family`.
//
// With a SavePoint active, the read is only un-counted if the current
// SavePoint's new_keys_ recorded a read of this key. A read that happened
// before the latest SavePoint is left untouched. Without a SavePoint the
// transaction-wide read count is always decremented. If the key's total
// read and write counts both reach zero, it stops being tracked and
// UnlockGetForUpdate() is invoked.
void TransactionBaseImpl::UndoGetForUpdate(ColumnFamilyHandle* column_family,
                                           const Slice& key) {
  uint32_t column_family_id = GetColumnFamilyID(column_family);
  // operator[] inserts an empty per-CF map if this CF is not tracked yet.
  auto& cf_tracked_keys = tracked_keys_[column_family_id];
  std::string key_str = key.ToString();
  bool can_decrement = false;
  // Only read by the assert below; marked unused for NDEBUG builds.
  bool can_unlock __attribute__((__unused__)) = false;

  if (save_points_ != nullptr && !save_points_->empty()) {
    // Check if this key was fetched ForUpdate in this SavePoint
    auto& cf_savepoint_keys = save_points_->top().new_keys_[column_family_id];

    auto savepoint_iter = cf_savepoint_keys.find(key_str);
    if (savepoint_iter != cf_savepoint_keys.end()) {
      if (savepoint_iter->second.num_reads > 0) {
        savepoint_iter->second.num_reads--;
        can_decrement = true;

        if (savepoint_iter->second.num_reads == 0 &&
            savepoint_iter->second.num_writes == 0) {
          // No other GetForUpdates or write on this key in this SavePoint
          cf_savepoint_keys.erase(savepoint_iter);
          can_unlock = true;
        }
      }
    }
  } else {
    // No SavePoint set
    can_decrement = true;
    can_unlock = true;
  }

  // We can only decrement the read count for this key if we were able to
  // decrement the read count in the current SavePoint, OR if there is no
  // SavePoint set.
  if (can_decrement) {
    auto key_iter = cf_tracked_keys.find(key_str);

    if (key_iter != cf_tracked_keys.end()) {
      if (key_iter->second.num_reads > 0) {
        key_iter->second.num_reads--;

        if (key_iter->second.num_reads == 0 &&
            key_iter->second.num_writes == 0) {
          // No other GetForUpdates or writes on this key
          assert(can_unlock);
          cf_tracked_keys.erase(key_iter);
          UnlockGetForUpdate(column_family, key);
        }
      }
    }
  }
}
784 
RebuildFromWriteBatch(WriteBatch * src_batch)785 Status TransactionBaseImpl::RebuildFromWriteBatch(WriteBatch* src_batch) {
786   struct IndexedWriteBatchBuilder : public WriteBatch::Handler {
787     Transaction* txn_;
788     DBImpl* db_;
789     IndexedWriteBatchBuilder(Transaction* txn, DBImpl* db)
790         : txn_(txn), db_(db) {
791       assert(dynamic_cast<TransactionBaseImpl*>(txn_) != nullptr);
792     }
793 
794     Status PutCF(uint32_t cf, const Slice& key, const Slice& val) override {
795       return txn_->Put(db_->GetColumnFamilyHandle(cf), key, val);
796     }
797 
798     Status DeleteCF(uint32_t cf, const Slice& key) override {
799       return txn_->Delete(db_->GetColumnFamilyHandle(cf), key);
800     }
801 
802     Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
803       return txn_->SingleDelete(db_->GetColumnFamilyHandle(cf), key);
804     }
805 
806     Status MergeCF(uint32_t cf, const Slice& key, const Slice& val) override {
807       return txn_->Merge(db_->GetColumnFamilyHandle(cf), key, val);
808     }
809 
810     // this is used for reconstructing prepared transactions upon
811     // recovery. there should not be any meta markers in the batches
812     // we are processing.
813     Status MarkBeginPrepare(bool) override { return Status::InvalidArgument(); }
814 
815     Status MarkEndPrepare(const Slice&) override {
816       return Status::InvalidArgument();
817     }
818 
819     Status MarkCommit(const Slice&) override {
820       return Status::InvalidArgument();
821     }
822 
823     Status MarkRollback(const Slice&) override {
824       return Status::InvalidArgument();
825     }
826   };
827 
828   IndexedWriteBatchBuilder copycat(this, dbimpl_);
829   return src_batch->Iterate(&copycat);
830 }
831 
GetCommitTimeWriteBatch()832 WriteBatch* TransactionBaseImpl::GetCommitTimeWriteBatch() {
833   return &commit_time_batch_;
834 }
835 }  // namespace ROCKSDB_NAMESPACE
836 
837 #endif  // ROCKSDB_LITE
838