1 // Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
2 // This source code is licensed under both the GPLv2 (found in the
3 // COPYING file in the root directory) and Apache 2.0 License
4 // (found in the LICENSE.Apache file in the root directory).
5 
6 #pragma once
7 
8 #ifndef ROCKSDB_LITE
9 
#include <memory>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include "db/write_batch_internal.h"
#include "rocksdb/db.h"
#include "rocksdb/slice.h"
#include "rocksdb/snapshot.h"
#include "rocksdb/status.h"
#include "rocksdb/types.h"
#include "rocksdb/utilities/transaction.h"
#include "rocksdb/utilities/transaction_db.h"
#include "rocksdb/utilities/write_batch_with_index.h"
#include "util/autovector.h"
#include "utilities/transactions/lock/lock_tracker.h"
#include "utilities/transactions/transaction_util.h"
26 
27 namespace ROCKSDB_NAMESPACE {
28 
29 class TransactionBaseImpl : public Transaction {
30  public:
31   TransactionBaseImpl(DB* db, const WriteOptions& write_options,
32                       const LockTrackerFactory& lock_tracker_factory);
33 
34   virtual ~TransactionBaseImpl();
35 
36   // Remove pending operations queued in this transaction.
37   virtual void Clear();
38 
39   void Reinitialize(DB* db, const WriteOptions& write_options);
40 
41   // Called before executing Put, Merge, Delete, and GetForUpdate.  If TryLock
42   // returns non-OK, the Put/Merge/Delete/GetForUpdate will be failed.
43   // do_validate will be false if called from PutUntracked, DeleteUntracked,
44   // MergeUntracked, or GetForUpdate(do_validate=false)
45   virtual Status TryLock(ColumnFamilyHandle* column_family, const Slice& key,
46                          bool read_only, bool exclusive,
47                          const bool do_validate = true,
48                          const bool assume_tracked = false) = 0;
49 
50   void SetSavePoint() override;
51 
52   Status RollbackToSavePoint() override;
53 
54   Status PopSavePoint() override;
55 
56   using Transaction::Get;
57   Status Get(const ReadOptions& options, ColumnFamilyHandle* column_family,
58              const Slice& key, std::string* value) override;
59 
60   Status Get(const ReadOptions& options, ColumnFamilyHandle* column_family,
61              const Slice& key, PinnableSlice* value) override;
62 
Get(const ReadOptions & options,const Slice & key,std::string * value)63   Status Get(const ReadOptions& options, const Slice& key,
64              std::string* value) override {
65     return Get(options, db_->DefaultColumnFamily(), key, value);
66   }
67 
68   using Transaction::GetForUpdate;
69   Status GetForUpdate(const ReadOptions& options,
70                       ColumnFamilyHandle* column_family, const Slice& key,
71                       std::string* value, bool exclusive,
72                       const bool do_validate) override;
73 
74   Status GetForUpdate(const ReadOptions& options,
75                       ColumnFamilyHandle* column_family, const Slice& key,
76                       PinnableSlice* pinnable_val, bool exclusive,
77                       const bool do_validate) override;
78 
GetForUpdate(const ReadOptions & options,const Slice & key,std::string * value,bool exclusive,const bool do_validate)79   Status GetForUpdate(const ReadOptions& options, const Slice& key,
80                       std::string* value, bool exclusive,
81                       const bool do_validate) override {
82     return GetForUpdate(options, db_->DefaultColumnFamily(), key, value,
83                         exclusive, do_validate);
84   }
85 
86   using Transaction::MultiGet;
87   std::vector<Status> MultiGet(
88       const ReadOptions& options,
89       const std::vector<ColumnFamilyHandle*>& column_family,
90       const std::vector<Slice>& keys,
91       std::vector<std::string>* values) override;
92 
MultiGet(const ReadOptions & options,const std::vector<Slice> & keys,std::vector<std::string> * values)93   std::vector<Status> MultiGet(const ReadOptions& options,
94                                const std::vector<Slice>& keys,
95                                std::vector<std::string>* values) override {
96     return MultiGet(options, std::vector<ColumnFamilyHandle*>(
97                                  keys.size(), db_->DefaultColumnFamily()),
98                     keys, values);
99   }
100 
101   void MultiGet(const ReadOptions& options, ColumnFamilyHandle* column_family,
102                 const size_t num_keys, const Slice* keys, PinnableSlice* values,
103                 Status* statuses, const bool sorted_input = false) override;
104 
105   using Transaction::MultiGetForUpdate;
106   std::vector<Status> MultiGetForUpdate(
107       const ReadOptions& options,
108       const std::vector<ColumnFamilyHandle*>& column_family,
109       const std::vector<Slice>& keys,
110       std::vector<std::string>* values) override;
111 
MultiGetForUpdate(const ReadOptions & options,const std::vector<Slice> & keys,std::vector<std::string> * values)112   std::vector<Status> MultiGetForUpdate(
113       const ReadOptions& options, const std::vector<Slice>& keys,
114       std::vector<std::string>* values) override {
115     return MultiGetForUpdate(options,
116                              std::vector<ColumnFamilyHandle*>(
117                                  keys.size(), db_->DefaultColumnFamily()),
118                              keys, values);
119   }
120 
121   Iterator* GetIterator(const ReadOptions& read_options) override;
122   Iterator* GetIterator(const ReadOptions& read_options,
123                         ColumnFamilyHandle* column_family) override;
124 
125   Status Put(ColumnFamilyHandle* column_family, const Slice& key,
126              const Slice& value, const bool assume_tracked = false) override;
Put(const Slice & key,const Slice & value)127   Status Put(const Slice& key, const Slice& value) override {
128     return Put(nullptr, key, value);
129   }
130 
131   Status Put(ColumnFamilyHandle* column_family, const SliceParts& key,
132              const SliceParts& value,
133              const bool assume_tracked = false) override;
Put(const SliceParts & key,const SliceParts & value)134   Status Put(const SliceParts& key, const SliceParts& value) override {
135     return Put(nullptr, key, value);
136   }
137 
138   Status Merge(ColumnFamilyHandle* column_family, const Slice& key,
139                const Slice& value, const bool assume_tracked = false) override;
Merge(const Slice & key,const Slice & value)140   Status Merge(const Slice& key, const Slice& value) override {
141     return Merge(nullptr, key, value);
142   }
143 
144   Status Delete(ColumnFamilyHandle* column_family, const Slice& key,
145                 const bool assume_tracked = false) override;
Delete(const Slice & key)146   Status Delete(const Slice& key) override { return Delete(nullptr, key); }
147   Status Delete(ColumnFamilyHandle* column_family, const SliceParts& key,
148                 const bool assume_tracked = false) override;
Delete(const SliceParts & key)149   Status Delete(const SliceParts& key) override { return Delete(nullptr, key); }
150 
151   Status SingleDelete(ColumnFamilyHandle* column_family, const Slice& key,
152                       const bool assume_tracked = false) override;
SingleDelete(const Slice & key)153   Status SingleDelete(const Slice& key) override {
154     return SingleDelete(nullptr, key);
155   }
156   Status SingleDelete(ColumnFamilyHandle* column_family, const SliceParts& key,
157                       const bool assume_tracked = false) override;
SingleDelete(const SliceParts & key)158   Status SingleDelete(const SliceParts& key) override {
159     return SingleDelete(nullptr, key);
160   }
161 
162   Status PutUntracked(ColumnFamilyHandle* column_family, const Slice& key,
163                       const Slice& value) override;
PutUntracked(const Slice & key,const Slice & value)164   Status PutUntracked(const Slice& key, const Slice& value) override {
165     return PutUntracked(nullptr, key, value);
166   }
167 
168   Status PutUntracked(ColumnFamilyHandle* column_family, const SliceParts& key,
169                       const SliceParts& value) override;
PutUntracked(const SliceParts & key,const SliceParts & value)170   Status PutUntracked(const SliceParts& key, const SliceParts& value) override {
171     return PutUntracked(nullptr, key, value);
172   }
173 
174   Status MergeUntracked(ColumnFamilyHandle* column_family, const Slice& key,
175                         const Slice& value) override;
MergeUntracked(const Slice & key,const Slice & value)176   Status MergeUntracked(const Slice& key, const Slice& value) override {
177     return MergeUntracked(nullptr, key, value);
178   }
179 
180   Status DeleteUntracked(ColumnFamilyHandle* column_family,
181                          const Slice& key) override;
DeleteUntracked(const Slice & key)182   Status DeleteUntracked(const Slice& key) override {
183     return DeleteUntracked(nullptr, key);
184   }
185   Status DeleteUntracked(ColumnFamilyHandle* column_family,
186                          const SliceParts& key) override;
DeleteUntracked(const SliceParts & key)187   Status DeleteUntracked(const SliceParts& key) override {
188     return DeleteUntracked(nullptr, key);
189   }
190 
191   Status SingleDeleteUntracked(ColumnFamilyHandle* column_family,
192                                const Slice& key) override;
SingleDeleteUntracked(const Slice & key)193   Status SingleDeleteUntracked(const Slice& key) override {
194     return SingleDeleteUntracked(nullptr, key);
195   }
196 
197   void PutLogData(const Slice& blob) override;
198 
199   WriteBatchWithIndex* GetWriteBatch() override;
200 
SetLockTimeout(int64_t)201   virtual void SetLockTimeout(int64_t /*timeout*/) override { /* Do nothing */
202   }
203 
GetSnapshot()204   const Snapshot* GetSnapshot() const override {
205     return snapshot_ ? snapshot_.get() : nullptr;
206   }
207 
208   virtual void SetSnapshot() override;
209   void SetSnapshotOnNextOperation(
210       std::shared_ptr<TransactionNotifier> notifier = nullptr) override;
211 
ClearSnapshot()212   void ClearSnapshot() override {
213     snapshot_.reset();
214     snapshot_needed_ = false;
215     snapshot_notifier_ = nullptr;
216   }
217 
DisableIndexing()218   void DisableIndexing() override { indexing_enabled_ = false; }
219 
EnableIndexing()220   void EnableIndexing() override { indexing_enabled_ = true; }
221 
222   uint64_t GetElapsedTime() const override;
223 
224   uint64_t GetNumPuts() const override;
225 
226   uint64_t GetNumDeletes() const override;
227 
228   uint64_t GetNumMerges() const override;
229 
230   uint64_t GetNumKeys() const override;
231 
232   void UndoGetForUpdate(ColumnFamilyHandle* column_family,
233                         const Slice& key) override;
UndoGetForUpdate(const Slice & key)234   void UndoGetForUpdate(const Slice& key) override {
235     return UndoGetForUpdate(nullptr, key);
236   };
237 
GetWriteOptions()238   WriteOptions* GetWriteOptions() override { return &write_options_; }
239 
SetWriteOptions(const WriteOptions & write_options)240   void SetWriteOptions(const WriteOptions& write_options) override {
241     write_options_ = write_options;
242   }
243 
244   // Used for memory management for snapshot_
245   void ReleaseSnapshot(const Snapshot* snapshot, DB* db);
246 
247   // iterates over the given batch and makes the appropriate inserts.
248   // used for rebuilding prepared transactions after recovery.
249   virtual Status RebuildFromWriteBatch(WriteBatch* src_batch) override;
250 
251   WriteBatch* GetCommitTimeWriteBatch() override;
252 
GetTrackedLocks()253   LockTracker& GetTrackedLocks() { return *tracked_locks_; }
254 
255  protected:
256   // Add a key to the list of tracked keys.
257   //
258   // seqno is the earliest seqno this key was involved with this transaction.
259   // readonly should be set to true if no data was written for this key
260   void TrackKey(uint32_t cfh_id, const std::string& key, SequenceNumber seqno,
261                 bool readonly, bool exclusive);
262 
263   // Called when UndoGetForUpdate determines that this key can be unlocked.
264   virtual void UnlockGetForUpdate(ColumnFamilyHandle* column_family,
265                                   const Slice& key) = 0;
266 
267   // Sets a snapshot if SetSnapshotOnNextOperation() has been called.
268   void SetSnapshotIfNeeded();
269 
270   // Initialize write_batch_ for 2PC by inserting Noop.
271   inline void InitWriteBatch(bool clear = false) {
272     if (clear) {
273       write_batch_.Clear();
274     }
275     assert(write_batch_.GetDataSize() == WriteBatchInternal::kHeader);
276     auto s = WriteBatchInternal::InsertNoop(write_batch_.GetWriteBatch());
277     assert(s.ok());
278   }
279 
280   DB* db_;
281   DBImpl* dbimpl_;
282 
283   WriteOptions write_options_;
284 
285   const Comparator* cmp_;
286 
287   const LockTrackerFactory& lock_tracker_factory_;
288 
289   // Stores that time the txn was constructed, in microseconds.
290   uint64_t start_time_;
291 
292   // Stores the current snapshot that was set by SetSnapshot or null if
293   // no snapshot is currently set.
294   std::shared_ptr<const Snapshot> snapshot_;
295 
296   // Count of various operations pending in this transaction
297   uint64_t num_puts_ = 0;
298   uint64_t num_deletes_ = 0;
299   uint64_t num_merges_ = 0;
300 
301   struct SavePoint {
302     std::shared_ptr<const Snapshot> snapshot_;
303     bool snapshot_needed_ = false;
304     std::shared_ptr<TransactionNotifier> snapshot_notifier_;
305     uint64_t num_puts_ = 0;
306     uint64_t num_deletes_ = 0;
307     uint64_t num_merges_ = 0;
308 
309     // Record all locks tracked since the last savepoint
310     std::shared_ptr<LockTracker> new_locks_;
311 
SavePointSavePoint312     SavePoint(std::shared_ptr<const Snapshot> snapshot, bool snapshot_needed,
313               std::shared_ptr<TransactionNotifier> snapshot_notifier,
314               uint64_t num_puts, uint64_t num_deletes, uint64_t num_merges,
315               const LockTrackerFactory& lock_tracker_factory)
316         : snapshot_(snapshot),
317           snapshot_needed_(snapshot_needed),
318           snapshot_notifier_(snapshot_notifier),
319           num_puts_(num_puts),
320           num_deletes_(num_deletes),
321           num_merges_(num_merges),
322           new_locks_(lock_tracker_factory.Create()) {}
323 
SavePointSavePoint324     explicit SavePoint(const LockTrackerFactory& lock_tracker_factory)
325         : new_locks_(lock_tracker_factory.Create()) {}
326   };
327 
328   // Records writes pending in this transaction
329   WriteBatchWithIndex write_batch_;
330 
331   // For Pessimistic Transactions this is the set of acquired locks.
332   // Optimistic Transactions will keep note the requested locks (not actually
333   // locked), and do conflict checking until commit time based on the tracked
334   // lock requests.
335   std::unique_ptr<LockTracker> tracked_locks_;
336 
337   // Stack of the Snapshot saved at each save point. Saved snapshots may be
338   // nullptr if there was no snapshot at the time SetSavePoint() was called.
339   std::unique_ptr<std::stack<TransactionBaseImpl::SavePoint,
340                              autovector<TransactionBaseImpl::SavePoint>>>
341       save_points_;
342 
343  private:
344   friend class WritePreparedTxn;
345   // Extra data to be persisted with the commit. Note this is only used when
346   // prepare phase is not skipped.
347   WriteBatch commit_time_batch_;
348 
349   // If true, future Put/Merge/Deletes will be indexed in the
350   // WriteBatchWithIndex.
351   // If false, future Put/Merge/Deletes will be inserted directly into the
352   // underlying WriteBatch and not indexed in the WriteBatchWithIndex.
353   bool indexing_enabled_;
354 
355   // SetSnapshotOnNextOperation() has been called and the snapshot has not yet
356   // been reset.
357   bool snapshot_needed_ = false;
358 
359   // SetSnapshotOnNextOperation() has been called and the caller would like
360   // a notification through the TransactionNotifier interface
361   std::shared_ptr<TransactionNotifier> snapshot_notifier_ = nullptr;
362 
363   Status TryLock(ColumnFamilyHandle* column_family, const SliceParts& key,
364                  bool read_only, bool exclusive, const bool do_validate = true,
365                  const bool assume_tracked = false);
366 
367   WriteBatchBase* GetBatchForWrite();
368   void SetSnapshotInternal(const Snapshot* snapshot);
369 };
370 
371 }  // namespace ROCKSDB_NAMESPACE
372 
373 #endif  // ROCKSDB_LITE
374