1 // Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
2 //  This source code is licensed under both the GPLv2 (found in the
3 //  COPYING file in the root directory) and Apache 2.0 License
4 //  (found in the LICENSE.Apache file in the root directory).
5 
6 #pragma once
7 
8 #ifndef ROCKSDB_LITE
9 
#include <cassert>
#include <cstdint>
#include <memory>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include "db/write_batch_internal.h"
#include "rocksdb/db.h"
#include "rocksdb/slice.h"
#include "rocksdb/snapshot.h"
#include "rocksdb/status.h"
#include "rocksdb/types.h"
#include "rocksdb/utilities/transaction.h"
#include "rocksdb/utilities/transaction_db.h"
#include "rocksdb/utilities/write_batch_with_index.h"
#include "util/autovector.h"
#include "utilities/transactions/transaction_util.h"
25 
26 namespace ROCKSDB_NAMESPACE {
27 
28 class TransactionBaseImpl : public Transaction {
29  public:
30   TransactionBaseImpl(DB* db, const WriteOptions& write_options);
31 
32   virtual ~TransactionBaseImpl();
33 
34   // Remove pending operations queued in this transaction.
35   virtual void Clear();
36 
37   void Reinitialize(DB* db, const WriteOptions& write_options);
38 
39   // Called before executing Put, Merge, Delete, and GetForUpdate.  If TryLock
40   // returns non-OK, the Put/Merge/Delete/GetForUpdate will be failed.
41   // do_validate will be false if called from PutUntracked, DeleteUntracked,
42   // MergeUntracked, or GetForUpdate(do_validate=false)
43   virtual Status TryLock(ColumnFamilyHandle* column_family, const Slice& key,
44                          bool read_only, bool exclusive,
45                          const bool do_validate = true,
46                          const bool assume_tracked = false) = 0;
47 
48   void SetSavePoint() override;
49 
50   Status RollbackToSavePoint() override;
51 
52   Status PopSavePoint() override;
53 
54   using Transaction::Get;
55   Status Get(const ReadOptions& options, ColumnFamilyHandle* column_family,
56              const Slice& key, std::string* value) override;
57 
58   Status Get(const ReadOptions& options, ColumnFamilyHandle* column_family,
59              const Slice& key, PinnableSlice* value) override;
60 
Get(const ReadOptions & options,const Slice & key,std::string * value)61   Status Get(const ReadOptions& options, const Slice& key,
62              std::string* value) override {
63     return Get(options, db_->DefaultColumnFamily(), key, value);
64   }
65 
66   using Transaction::GetForUpdate;
67   Status GetForUpdate(const ReadOptions& options,
68                       ColumnFamilyHandle* column_family, const Slice& key,
69                       std::string* value, bool exclusive,
70                       const bool do_validate) override;
71 
72   Status GetForUpdate(const ReadOptions& options,
73                       ColumnFamilyHandle* column_family, const Slice& key,
74                       PinnableSlice* pinnable_val, bool exclusive,
75                       const bool do_validate) override;
76 
GetForUpdate(const ReadOptions & options,const Slice & key,std::string * value,bool exclusive,const bool do_validate)77   Status GetForUpdate(const ReadOptions& options, const Slice& key,
78                       std::string* value, bool exclusive,
79                       const bool do_validate) override {
80     return GetForUpdate(options, db_->DefaultColumnFamily(), key, value,
81                         exclusive, do_validate);
82   }
83 
84   using Transaction::MultiGet;
85   std::vector<Status> MultiGet(
86       const ReadOptions& options,
87       const std::vector<ColumnFamilyHandle*>& column_family,
88       const std::vector<Slice>& keys,
89       std::vector<std::string>* values) override;
90 
MultiGet(const ReadOptions & options,const std::vector<Slice> & keys,std::vector<std::string> * values)91   std::vector<Status> MultiGet(const ReadOptions& options,
92                                const std::vector<Slice>& keys,
93                                std::vector<std::string>* values) override {
94     return MultiGet(options, std::vector<ColumnFamilyHandle*>(
95                                  keys.size(), db_->DefaultColumnFamily()),
96                     keys, values);
97   }
98 
99   void MultiGet(const ReadOptions& options, ColumnFamilyHandle* column_family,
100                 const size_t num_keys, const Slice* keys, PinnableSlice* values,
101                 Status* statuses, const bool sorted_input = false) override;
102 
103   using Transaction::MultiGetForUpdate;
104   std::vector<Status> MultiGetForUpdate(
105       const ReadOptions& options,
106       const std::vector<ColumnFamilyHandle*>& column_family,
107       const std::vector<Slice>& keys,
108       std::vector<std::string>* values) override;
109 
MultiGetForUpdate(const ReadOptions & options,const std::vector<Slice> & keys,std::vector<std::string> * values)110   std::vector<Status> MultiGetForUpdate(
111       const ReadOptions& options, const std::vector<Slice>& keys,
112       std::vector<std::string>* values) override {
113     return MultiGetForUpdate(options,
114                              std::vector<ColumnFamilyHandle*>(
115                                  keys.size(), db_->DefaultColumnFamily()),
116                              keys, values);
117   }
118 
119   Iterator* GetIterator(const ReadOptions& read_options) override;
120   Iterator* GetIterator(const ReadOptions& read_options,
121                         ColumnFamilyHandle* column_family) override;
122 
123   Status Put(ColumnFamilyHandle* column_family, const Slice& key,
124              const Slice& value, const bool assume_tracked = false) override;
Put(const Slice & key,const Slice & value)125   Status Put(const Slice& key, const Slice& value) override {
126     return Put(nullptr, key, value);
127   }
128 
129   Status Put(ColumnFamilyHandle* column_family, const SliceParts& key,
130              const SliceParts& value,
131              const bool assume_tracked = false) override;
Put(const SliceParts & key,const SliceParts & value)132   Status Put(const SliceParts& key, const SliceParts& value) override {
133     return Put(nullptr, key, value);
134   }
135 
136   Status Merge(ColumnFamilyHandle* column_family, const Slice& key,
137                const Slice& value, const bool assume_tracked = false) override;
Merge(const Slice & key,const Slice & value)138   Status Merge(const Slice& key, const Slice& value) override {
139     return Merge(nullptr, key, value);
140   }
141 
142   Status Delete(ColumnFamilyHandle* column_family, const Slice& key,
143                 const bool assume_tracked = false) override;
Delete(const Slice & key)144   Status Delete(const Slice& key) override { return Delete(nullptr, key); }
145   Status Delete(ColumnFamilyHandle* column_family, const SliceParts& key,
146                 const bool assume_tracked = false) override;
Delete(const SliceParts & key)147   Status Delete(const SliceParts& key) override { return Delete(nullptr, key); }
148 
149   Status SingleDelete(ColumnFamilyHandle* column_family, const Slice& key,
150                       const bool assume_tracked = false) override;
SingleDelete(const Slice & key)151   Status SingleDelete(const Slice& key) override {
152     return SingleDelete(nullptr, key);
153   }
154   Status SingleDelete(ColumnFamilyHandle* column_family, const SliceParts& key,
155                       const bool assume_tracked = false) override;
SingleDelete(const SliceParts & key)156   Status SingleDelete(const SliceParts& key) override {
157     return SingleDelete(nullptr, key);
158   }
159 
160   Status PutUntracked(ColumnFamilyHandle* column_family, const Slice& key,
161                       const Slice& value) override;
PutUntracked(const Slice & key,const Slice & value)162   Status PutUntracked(const Slice& key, const Slice& value) override {
163     return PutUntracked(nullptr, key, value);
164   }
165 
166   Status PutUntracked(ColumnFamilyHandle* column_family, const SliceParts& key,
167                       const SliceParts& value) override;
PutUntracked(const SliceParts & key,const SliceParts & value)168   Status PutUntracked(const SliceParts& key, const SliceParts& value) override {
169     return PutUntracked(nullptr, key, value);
170   }
171 
172   Status MergeUntracked(ColumnFamilyHandle* column_family, const Slice& key,
173                         const Slice& value) override;
MergeUntracked(const Slice & key,const Slice & value)174   Status MergeUntracked(const Slice& key, const Slice& value) override {
175     return MergeUntracked(nullptr, key, value);
176   }
177 
178   Status DeleteUntracked(ColumnFamilyHandle* column_family,
179                          const Slice& key) override;
DeleteUntracked(const Slice & key)180   Status DeleteUntracked(const Slice& key) override {
181     return DeleteUntracked(nullptr, key);
182   }
183   Status DeleteUntracked(ColumnFamilyHandle* column_family,
184                          const SliceParts& key) override;
DeleteUntracked(const SliceParts & key)185   Status DeleteUntracked(const SliceParts& key) override {
186     return DeleteUntracked(nullptr, key);
187   }
188 
189   Status SingleDeleteUntracked(ColumnFamilyHandle* column_family,
190                                const Slice& key) override;
SingleDeleteUntracked(const Slice & key)191   Status SingleDeleteUntracked(const Slice& key) override {
192     return SingleDeleteUntracked(nullptr, key);
193   }
194 
195   void PutLogData(const Slice& blob) override;
196 
197   WriteBatchWithIndex* GetWriteBatch() override;
198 
SetLockTimeout(int64_t)199   virtual void SetLockTimeout(int64_t /*timeout*/) override { /* Do nothing */
200   }
201 
GetSnapshot()202   const Snapshot* GetSnapshot() const override {
203     return snapshot_ ? snapshot_.get() : nullptr;
204   }
205 
206   virtual void SetSnapshot() override;
207   void SetSnapshotOnNextOperation(
208       std::shared_ptr<TransactionNotifier> notifier = nullptr) override;
209 
ClearSnapshot()210   void ClearSnapshot() override {
211     snapshot_.reset();
212     snapshot_needed_ = false;
213     snapshot_notifier_ = nullptr;
214   }
215 
DisableIndexing()216   void DisableIndexing() override { indexing_enabled_ = false; }
217 
EnableIndexing()218   void EnableIndexing() override { indexing_enabled_ = true; }
219 
220   uint64_t GetElapsedTime() const override;
221 
222   uint64_t GetNumPuts() const override;
223 
224   uint64_t GetNumDeletes() const override;
225 
226   uint64_t GetNumMerges() const override;
227 
228   uint64_t GetNumKeys() const override;
229 
230   void UndoGetForUpdate(ColumnFamilyHandle* column_family,
231                         const Slice& key) override;
UndoGetForUpdate(const Slice & key)232   void UndoGetForUpdate(const Slice& key) override {
233     return UndoGetForUpdate(nullptr, key);
234   };
235 
236   // Get list of keys in this transaction that must not have any conflicts
237   // with writes in other transactions.
GetTrackedKeys()238   const TransactionKeyMap& GetTrackedKeys() const { return tracked_keys_; }
239 
GetWriteOptions()240   WriteOptions* GetWriteOptions() override { return &write_options_; }
241 
SetWriteOptions(const WriteOptions & write_options)242   void SetWriteOptions(const WriteOptions& write_options) override {
243     write_options_ = write_options;
244   }
245 
246   // Used for memory management for snapshot_
247   void ReleaseSnapshot(const Snapshot* snapshot, DB* db);
248 
249   // iterates over the given batch and makes the appropriate inserts.
250   // used for rebuilding prepared transactions after recovery.
251   virtual Status RebuildFromWriteBatch(WriteBatch* src_batch) override;
252 
253   WriteBatch* GetCommitTimeWriteBatch() override;
254 
255  protected:
256   // Add a key to the list of tracked keys.
257   //
258   // seqno is the earliest seqno this key was involved with this transaction.
259   // readonly should be set to true if no data was written for this key
260   void TrackKey(uint32_t cfh_id, const std::string& key, SequenceNumber seqno,
261                 bool readonly, bool exclusive);
262 
263   // Helper function to add a key to the given TransactionKeyMap
264   static void TrackKey(TransactionKeyMap* key_map, uint32_t cfh_id,
265                        const std::string& key, SequenceNumber seqno,
266                        bool readonly, bool exclusive);
267 
268   // Called when UndoGetForUpdate determines that this key can be unlocked.
269   virtual void UnlockGetForUpdate(ColumnFamilyHandle* column_family,
270                                   const Slice& key) = 0;
271 
272   std::unique_ptr<TransactionKeyMap> GetTrackedKeysSinceSavePoint();
273 
274   // Sets a snapshot if SetSnapshotOnNextOperation() has been called.
275   void SetSnapshotIfNeeded();
276 
277   // Initialize write_batch_ for 2PC by inserting Noop.
278   inline void InitWriteBatch(bool clear = false) {
279     if (clear) {
280       write_batch_.Clear();
281     }
282     assert(write_batch_.GetDataSize() == WriteBatchInternal::kHeader);
283     WriteBatchInternal::InsertNoop(write_batch_.GetWriteBatch());
284   }
285 
286   DB* db_;
287   DBImpl* dbimpl_;
288 
289   WriteOptions write_options_;
290 
291   const Comparator* cmp_;
292 
293   // Stores that time the txn was constructed, in microseconds.
294   uint64_t start_time_;
295 
296   // Stores the current snapshot that was set by SetSnapshot or null if
297   // no snapshot is currently set.
298   std::shared_ptr<const Snapshot> snapshot_;
299 
300   // Count of various operations pending in this transaction
301   uint64_t num_puts_ = 0;
302   uint64_t num_deletes_ = 0;
303   uint64_t num_merges_ = 0;
304 
305   struct SavePoint {
306     std::shared_ptr<const Snapshot> snapshot_;
307     bool snapshot_needed_ = false;
308     std::shared_ptr<TransactionNotifier> snapshot_notifier_;
309     uint64_t num_puts_ = 0;
310     uint64_t num_deletes_ = 0;
311     uint64_t num_merges_ = 0;
312 
313     // Record all keys tracked since the last savepoint
314     TransactionKeyMap new_keys_;
315 
SavePointSavePoint316     SavePoint(std::shared_ptr<const Snapshot> snapshot, bool snapshot_needed,
317               std::shared_ptr<TransactionNotifier> snapshot_notifier,
318               uint64_t num_puts, uint64_t num_deletes, uint64_t num_merges)
319         : snapshot_(snapshot),
320           snapshot_needed_(snapshot_needed),
321           snapshot_notifier_(snapshot_notifier),
322           num_puts_(num_puts),
323           num_deletes_(num_deletes),
324           num_merges_(num_merges) {}
325 
326     SavePoint() = default;
327   };
328 
329   // Records writes pending in this transaction
330   WriteBatchWithIndex write_batch_;
331 
332   // Map from column_family_id to map of keys that are involved in this
333   // transaction.
334   // For Pessimistic Transactions this is the list of locked keys.
335   // Optimistic Transactions will wait till commit time to do conflict checking.
336   TransactionKeyMap tracked_keys_;
337 
338   // Stack of the Snapshot saved at each save point. Saved snapshots may be
339   // nullptr if there was no snapshot at the time SetSavePoint() was called.
340   std::unique_ptr<std::stack<TransactionBaseImpl::SavePoint,
341                              autovector<TransactionBaseImpl::SavePoint>>>
342       save_points_;
343 
344  private:
345   friend class WritePreparedTxn;
346   // Extra data to be persisted with the commit. Note this is only used when
347   // prepare phase is not skipped.
348   WriteBatch commit_time_batch_;
349 
350   // If true, future Put/Merge/Deletes will be indexed in the
351   // WriteBatchWithIndex.
352   // If false, future Put/Merge/Deletes will be inserted directly into the
353   // underlying WriteBatch and not indexed in the WriteBatchWithIndex.
354   bool indexing_enabled_;
355 
356   // SetSnapshotOnNextOperation() has been called and the snapshot has not yet
357   // been reset.
358   bool snapshot_needed_ = false;
359 
360   // SetSnapshotOnNextOperation() has been called and the caller would like
361   // a notification through the TransactionNotifier interface
362   std::shared_ptr<TransactionNotifier> snapshot_notifier_ = nullptr;
363 
364   Status TryLock(ColumnFamilyHandle* column_family, const SliceParts& key,
365                  bool read_only, bool exclusive, const bool do_validate = true,
366                  const bool assume_tracked = false);
367 
368   WriteBatchBase* GetBatchForWrite();
369   void SetSnapshotInternal(const Snapshot* snapshot);
370 };
371 
372 }  // namespace ROCKSDB_NAMESPACE
373 
374 #endif  // ROCKSDB_LITE
375