1 use crate::durability::Durability;
2 use crate::plumbing::CycleDetected;
3 use crate::revision::{AtomicRevision, Revision};
4 use crate::{CycleError, Database, DatabaseKeyIndex, Event, EventKind};
5 use log::debug;
6 use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive};
7 use parking_lot::{Mutex, RwLock};
8 use rustc_hash::{FxHashMap, FxHasher};
9 use smallvec::SmallVec;
10 use std::hash::{BuildHasherDefault, Hash};
11 use std::sync::atomic::{AtomicUsize, Ordering};
12 use std::sync::Arc;
13 
14 pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, BuildHasherDefault<FxHasher>>;
15 pub(crate) type FxIndexMap<K, V> = indexmap::IndexMap<K, V, BuildHasherDefault<FxHasher>>;
16 
17 mod local_state;
18 use local_state::LocalState;
19 
20 /// The salsa runtime stores the storage for all queries as well as
21 /// tracking the query stack and dependencies between cycles.
22 ///
23 /// Each new runtime you create (e.g., via `Runtime::new` or
24 /// `Runtime::default`) will have an independent set of query storage
25 /// associated with it. Normally, therefore, you only do this once, at
26 /// the start of your application.
27 pub struct Runtime {
28     /// Our unique runtime id.
29     id: RuntimeId,
30 
31     /// If this is a "forked" runtime, then the `revision_guard` will
32     /// be `Some`; this guard holds a read-lock on the global query
33     /// lock.
34     revision_guard: Option<RevisionGuard>,
35 
36     /// Local state that is specific to this runtime (thread).
37     local_state: LocalState,
38 
39     /// Shared state that is accessible via all runtimes.
40     shared_state: Arc<SharedState>,
41 }
42 
43 impl Default for Runtime {
default() -> Self44     fn default() -> Self {
45         Runtime {
46             id: RuntimeId { counter: 0 },
47             revision_guard: None,
48             shared_state: Default::default(),
49             local_state: Default::default(),
50         }
51     }
52 }
53 
54 impl std::fmt::Debug for Runtime {
fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result55     fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56         fmt.debug_struct("Runtime")
57             .field("id", &self.id())
58             .field("forked", &self.revision_guard.is_some())
59             .field("shared_state", &self.shared_state)
60             .finish()
61     }
62 }
63 
64 impl Runtime {
65     /// Create a new runtime; equivalent to `Self::default`. This is
66     /// used when creating a new database.
new() -> Self67     pub fn new() -> Self {
68         Self::default()
69     }
70 
71     /// See [`crate::storage::Storage::snapshot`].
snapshot(&self) -> Self72     pub(crate) fn snapshot(&self) -> Self {
73         if self.local_state.query_in_progress() {
74             panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)");
75         }
76 
77         let revision_guard = RevisionGuard::new(&self.shared_state);
78 
79         let id = RuntimeId {
80             counter: self.shared_state.next_id.fetch_add(1, Ordering::SeqCst),
81         };
82 
83         Runtime {
84             id,
85             revision_guard: Some(revision_guard),
86             shared_state: self.shared_state.clone(),
87             local_state: Default::default(),
88         }
89     }
90 
91     /// A "synthetic write" causes the system to act *as though* some
92     /// input of durability `durability` has changed. This is mostly
93     /// useful for profiling scenarios, but it also has interactions
94     /// with garbage collection. In general, a synthetic write to
95     /// durability level D will cause the system to fully trace all
96     /// queries of durability level D and below. When running a GC, then:
97     ///
98     /// - Synthetic writes will cause more derived values to be
99     ///   *retained*.  This is because derived values are only
100     ///   retained if they are traced, and a synthetic write can cause
101     ///   more things to be traced.
102     /// - Synthetic writes can cause more interned values to be
103     ///   *collected*. This is because interned values can only be
104     ///   collected if they were not yet traced in the current
105     ///   revision. Therefore, if you issue a synthetic write, execute
106     ///   some query Q, and then start collecting interned values, you
107     ///   will be able to recycle interned values not used in Q.
108     ///
109     /// In general, then, one can do a "full GC" that retains only
110     /// those things that are used by some query Q by (a) doing a
111     /// synthetic write at `Durability::HIGH`, (b) executing the query
112     /// Q and then (c) doing a sweep.
113     ///
114     /// **WARNING:** Just like an ordinary write, this method triggers
115     /// cancellation. If you invoke it while a snapshot exists, it
116     /// will block until that snapshot is dropped -- if that snapshot
117     /// is owned by the current thread, this could trigger deadlock.
synthetic_write(&mut self, durability: Durability)118     pub fn synthetic_write(&mut self, durability: Durability) {
119         self.with_incremented_revision(&mut |_next_revision| Some(durability));
120     }
121 
122     /// The unique identifier attached to this `SalsaRuntime`. Each
123     /// snapshotted runtime has a distinct identifier.
124     #[inline]
id(&self) -> RuntimeId125     pub fn id(&self) -> RuntimeId {
126         self.id
127     }
128 
129     /// Returns the database-key for the query that this thread is
130     /// actively executing (if any).
active_query(&self) -> Option<DatabaseKeyIndex>131     pub fn active_query(&self) -> Option<DatabaseKeyIndex> {
132         self.local_state.active_query()
133     }
134 
135     /// Read current value of the revision counter.
136     #[inline]
current_revision(&self) -> Revision137     pub(crate) fn current_revision(&self) -> Revision {
138         self.shared_state.revisions[0].load()
139     }
140 
141     /// The revision in which values with durability `d` may have last
142     /// changed.  For D0, this is just the current revision. But for
143     /// higher levels of durability, this value may lag behind the
144     /// current revision. If we encounter a value of durability Di,
145     /// then, we can check this function to get a "bound" on when the
146     /// value may have changed, which allows us to skip walking its
147     /// dependencies.
148     #[inline]
last_changed_revision(&self, d: Durability) -> Revision149     pub(crate) fn last_changed_revision(&self, d: Durability) -> Revision {
150         self.shared_state.revisions[d.index()].load()
151     }
152 
153     /// Read current value of the revision counter.
154     #[inline]
pending_revision(&self) -> Revision155     fn pending_revision(&self) -> Revision {
156         self.shared_state.pending_revision.load()
157     }
158 
159     /// Check if the current revision is canceled. If this method ever
160     /// returns true, the currently executing query is also marked as
161     /// having an *untracked read* -- this means that, in the next
162     /// revision, we will always recompute its value "as if" some
163     /// input had changed. This means that, if your revision is
164     /// canceled (which indicates that current query results will be
165     /// ignored) your query is free to shortcircuit and return
166     /// whatever it likes.
167     ///
168     /// This method is useful for implementing cancellation of queries.
169     /// You can do it in one of two ways, via `Result`s or via unwinding.
170     ///
171     /// The `Result` approach looks like this:
172     ///
173     ///   * Some queries invoke `is_current_revision_canceled` and
174     ///     return a special value, like `Err(Canceled)`, if it returns
175     ///     `true`.
176     ///   * Other queries propagate the special value using `?` operator.
177     ///   * API around top-level queries checks if the result is `Ok` or
178     ///     `Err(Canceled)`.
179     ///
180     /// The `panic` approach works in a similar way:
181     ///
182     ///   * Some queries invoke `is_current_revision_canceled` and
183     ///     panic with a special value, like `Canceled`, if it returns
184     ///     true.
185     ///   * The implementation of `Database` trait overrides
186     ///     `on_propagated_panic` to throw this special value as well.
187     ///     This way, panic gets propagated naturally through dependant
188     ///     queries, even across the threads.
189     ///   * API around top-level queries converts a `panic` into `Result` by
190     ///     catching the panic (using either `std::panic::catch_unwind` or
191     ///     threads) and downcasting the payload to `Canceled` (re-raising
192     ///     panic if downcast fails).
193     ///
194     /// Note that salsa is explicitly designed to be panic-safe, so cancellation
195     /// via unwinding is 100% valid approach to cancellation.
196     #[inline]
is_current_revision_canceled(&self) -> bool197     pub fn is_current_revision_canceled(&self) -> bool {
198         let current_revision = self.current_revision();
199         let pending_revision = self.pending_revision();
200         debug!(
201             "is_current_revision_canceled: current_revision={:?}, pending_revision={:?}",
202             current_revision, pending_revision
203         );
204         if pending_revision > current_revision {
205             self.report_untracked_read();
206             true
207         } else {
208             // Subtle: If the current revision is not canceled, we
209             // still report an **anonymous** read, which will bump up
210             // the revision number to be at least the last
211             // non-canceled revision. This is needed to ensure
212             // deterministic reads and avoid salsa-rs/salsa#66. The
213             // specific scenario we are trying to avoid is tested by
214             // `no_back_dating_in_cancellation`; it works like
215             // this. Imagine we have 3 queries, where Query3 invokes
216             // Query2 which invokes Query1. Then:
217             //
218             // - In Revision R1:
219             //   - Query1: Observes cancelation and returns sentinel S.
220             //     - Recorded inputs: Untracked, because we observed cancelation.
221             //   - Query2: Reads Query1 and propagates sentinel S.
222             //     - Recorded inputs: Query1, changed-at=R1
223             //   - Query3: Reads Query2 and propagates sentinel S. (Inputs = Query2, ChangedAt R1)
224             //     - Recorded inputs: Query2, changed-at=R1
225             // - In Revision R2:
226             //   - Query1: Observes no cancelation. All of its inputs last changed in R0,
227             //     so it returns a valid value with "changed at" of R0.
228             //     - Recorded inputs: ..., changed-at=R0
229             //   - Query2: Recomputes its value and returns correct result.
230             //     - Recorded inputs: Query1, changed-at=R0 <-- key problem!
231             //   - Query3: sees that Query2's result last changed in R0, so it thinks it
232             //     can re-use its value from R1 (which is the sentinel value).
233             //
234             // The anonymous read here prevents that scenario: Query1
235             // winds up with a changed-at setting of R2, which is the
236             // "pending revision", and hence Query2 and Query3
237             // are recomputed.
238             assert_eq!(pending_revision, current_revision);
239             self.report_anon_read(pending_revision);
240             false
241         }
242     }
243 
244     /// Acquires the **global query write lock** (ensuring that no queries are
245     /// executing) and then increments the current revision counter; invokes
246     /// `op` with the global query write lock still held.
247     ///
248     /// While we wait to acquire the global query write lock, this method will
249     /// also increment `pending_revision_increments`, thus signalling to queries
250     /// that their results are "canceled" and they should abort as expeditiously
251     /// as possible.
252     ///
253     /// The `op` closure should actually perform the writes needed. It is given
254     /// the new revision as an argument, and its return value indicates whether
255     /// any pre-existing value was modified:
256     ///
257     /// - returning `None` means that no pre-existing value was modified (this
258     ///   could occur e.g. when setting some key on an input that was never set
259     ///   before)
260     /// - returning `Some(d)` indicates that a pre-existing value was modified
261     ///   and it had the durability `d`. This will update the records for when
262     ///   values with each durability were modified.
263     ///
264     /// Note that, given our writer model, we can assume that only one thread is
265     /// attempting to increment the global revision at a time.
with_incremented_revision( &mut self, op: &mut dyn FnMut(Revision) -> Option<Durability>, )266     pub(crate) fn with_incremented_revision(
267         &mut self,
268         op: &mut dyn FnMut(Revision) -> Option<Durability>,
269     ) {
270         log::debug!("increment_revision()");
271 
272         if !self.permits_increment() {
273             panic!("increment_revision invoked during a query computation");
274         }
275 
276         // Set the `pending_revision` field so that people
277         // know current revision is canceled.
278         let current_revision = self.shared_state.pending_revision.fetch_then_increment();
279 
280         // To modify the revision, we need the lock.
281         let shared_state = self.shared_state.clone();
282         let _lock = shared_state.query_lock.write();
283 
284         let old_revision = self.shared_state.revisions[0].fetch_then_increment();
285         assert_eq!(current_revision, old_revision);
286 
287         let new_revision = current_revision.next();
288 
289         debug!("increment_revision: incremented to {:?}", new_revision);
290 
291         if let Some(d) = op(new_revision) {
292             for rev in &self.shared_state.revisions[1..=d.index()] {
293                 rev.store(new_revision);
294             }
295         }
296     }
297 
permits_increment(&self) -> bool298     pub(crate) fn permits_increment(&self) -> bool {
299         self.revision_guard.is_none() && !self.local_state.query_in_progress()
300     }
301 
execute_query_implementation<DB, V>( &self, db: &DB, database_key_index: DatabaseKeyIndex, execute: impl FnOnce() -> V, ) -> ComputedQueryResult<V> where DB: ?Sized + Database,302     pub(crate) fn execute_query_implementation<DB, V>(
303         &self,
304         db: &DB,
305         database_key_index: DatabaseKeyIndex,
306         execute: impl FnOnce() -> V,
307     ) -> ComputedQueryResult<V>
308     where
309         DB: ?Sized + Database,
310     {
311         debug!(
312             "{:?}: execute_query_implementation invoked",
313             database_key_index
314         );
315 
316         db.salsa_event(Event {
317             runtime_id: self.id(),
318             kind: EventKind::WillExecute {
319                 database_key: database_key_index,
320             },
321         });
322 
323         // Push the active query onto the stack.
324         let max_durability = Durability::MAX;
325         let active_query = self
326             .local_state
327             .push_query(database_key_index, max_durability);
328 
329         // Execute user's code, accumulating inputs etc.
330         let value = execute();
331 
332         // Extract accumulated inputs.
333         let ActiveQuery {
334             dependencies,
335             changed_at,
336             durability,
337             cycle,
338             ..
339         } = active_query.complete();
340 
341         ComputedQueryResult {
342             value,
343             durability,
344             changed_at,
345             dependencies,
346             cycle,
347         }
348     }
349 
350     /// Reports that the currently active query read the result from
351     /// another query.
352     ///
353     /// # Parameters
354     ///
355     /// - `database_key`: the query whose result was read
356     /// - `changed_revision`: the last revision in which the result of that
357     ///   query had changed
report_query_read<'hack>( &self, input: DatabaseKeyIndex, durability: Durability, changed_at: Revision, )358     pub(crate) fn report_query_read<'hack>(
359         &self,
360         input: DatabaseKeyIndex,
361         durability: Durability,
362         changed_at: Revision,
363     ) {
364         self.local_state
365             .report_query_read(input, durability, changed_at);
366     }
367 
368     /// Reports that the query depends on some state unknown to salsa.
369     ///
370     /// Queries which report untracked reads will be re-executed in the next
371     /// revision.
report_untracked_read(&self)372     pub fn report_untracked_read(&self) {
373         self.local_state
374             .report_untracked_read(self.current_revision());
375     }
376 
377     /// Acts as though the current query had read an input with the given durability; this will force the current query's durability to be at most `durability`.
378     ///
379     /// This is mostly useful to control the durability level for [on-demand inputs](https://salsa-rs.github.io/salsa/common_patterns/on_demand_inputs.html).
report_synthetic_read(&self, durability: Durability)380     pub fn report_synthetic_read(&self, durability: Durability) {
381         self.local_state.report_synthetic_read(durability);
382     }
383 
384     /// An "anonymous" read is a read that doesn't come from executing
385     /// a query, but from some other internal operation. It just
386     /// modifies the "changed at" to be at least the given revision.
387     /// (It also does not disqualify a query from being considered
388     /// constant, since it is used for queries that don't give back
389     /// actual *data*.)
390     ///
391     /// This is used when queries check if they have been canceled.
report_anon_read(&self, revision: Revision)392     fn report_anon_read(&self, revision: Revision) {
393         self.local_state.report_anon_read(revision)
394     }
395 
396     /// Obviously, this should be user configurable at some point.
report_unexpected_cycle( &self, database_key_index: DatabaseKeyIndex, error: CycleDetected, changed_at: Revision, ) -> crate::CycleError<DatabaseKeyIndex>397     pub(crate) fn report_unexpected_cycle(
398         &self,
399         database_key_index: DatabaseKeyIndex,
400         error: CycleDetected,
401         changed_at: Revision,
402     ) -> crate::CycleError<DatabaseKeyIndex> {
403         debug!(
404             "report_unexpected_cycle(database_key={:?})",
405             database_key_index
406         );
407 
408         let mut query_stack = self.local_state.borrow_query_stack_mut();
409 
410         if error.from == error.to {
411             // All queries in the cycle is local
412             let start_index = query_stack
413                 .iter()
414                 .rposition(|active_query| active_query.database_key_index == database_key_index)
415                 .unwrap();
416             let mut cycle = Vec::new();
417             let cycle_participants = &mut query_stack[start_index..];
418             for active_query in &mut *cycle_participants {
419                 cycle.push(active_query.database_key_index);
420             }
421 
422             assert!(!cycle.is_empty());
423 
424             for active_query in cycle_participants {
425                 active_query.cycle = cycle.clone();
426             }
427 
428             crate::CycleError {
429                 cycle,
430                 changed_at,
431                 durability: Durability::MAX,
432             }
433         } else {
434             // Part of the cycle is on another thread so we need to lock and inspect the shared
435             // state
436             let dependency_graph = self.shared_state.dependency_graph.lock();
437 
438             let mut cycle = Vec::new();
439             dependency_graph.push_cycle_path(
440                 database_key_index,
441                 error.to,
442                 query_stack.iter().map(|query| query.database_key_index),
443                 &mut cycle,
444             );
445             cycle.push(database_key_index);
446 
447             assert!(!cycle.is_empty());
448 
449             for active_query in query_stack
450                 .iter_mut()
451                 .filter(|query| cycle.iter().any(|key| *key == query.database_key_index))
452             {
453                 active_query.cycle = cycle.clone();
454             }
455 
456             crate::CycleError {
457                 cycle,
458                 changed_at,
459                 durability: Durability::MAX,
460             }
461         }
462     }
463 
mark_cycle_participants(&self, err: &CycleError<DatabaseKeyIndex>)464     pub(crate) fn mark_cycle_participants(&self, err: &CycleError<DatabaseKeyIndex>) {
465         for active_query in self
466             .local_state
467             .borrow_query_stack_mut()
468             .iter_mut()
469             .rev()
470             .take_while(|active_query| {
471                 err.cycle
472                     .iter()
473                     .any(|e| *e == active_query.database_key_index)
474             })
475         {
476             active_query.cycle = err.cycle.clone();
477         }
478     }
479 
480     /// Try to make this runtime blocked on `other_id`. Returns true
481     /// upon success or false if `other_id` is already blocked on us.
try_block_on(&self, database_key: DatabaseKeyIndex, other_id: RuntimeId) -> bool482     pub(crate) fn try_block_on(&self, database_key: DatabaseKeyIndex, other_id: RuntimeId) -> bool {
483         self.shared_state.dependency_graph.lock().add_edge(
484             self.id(),
485             database_key,
486             other_id,
487             self.local_state
488                 .borrow_query_stack()
489                 .iter()
490                 .map(|query| query.database_key_index),
491         )
492     }
493 
unblock_queries_blocked_on_self(&self, database_key_index: DatabaseKeyIndex)494     pub(crate) fn unblock_queries_blocked_on_self(&self, database_key_index: DatabaseKeyIndex) {
495         self.shared_state
496             .dependency_graph
497             .lock()
498             .remove_edge(database_key_index, self.id())
499     }
500 }
501 
502 /// State that will be common to all threads (when we support multiple threads)
503 struct SharedState {
504     /// Stores the next id to use for a snapshotted runtime (starts at 1).
505     next_id: AtomicUsize,
506 
507     /// Whenever derived queries are executing, they acquire this lock
508     /// in read mode. Mutating inputs (and thus creating a new
509     /// revision) requires a write lock (thus guaranteeing that no
510     /// derived queries are in progress). Note that this is not needed
511     /// to prevent **race conditions** -- the revision counter itself
512     /// is stored in an `AtomicUsize` so it can be cheaply read
513     /// without acquiring the lock.  Rather, the `query_lock` is used
514     /// to ensure a higher-level consistency property.
515     query_lock: RwLock<()>,
516 
517     /// This is typically equal to `revision` -- set to `revision+1`
518     /// when a new revision is pending (which implies that the current
519     /// revision is canceled).
520     pending_revision: AtomicRevision,
521 
522     /// Stores the "last change" revision for values of each duration.
523     /// This vector is always of length at least 1 (for Durability 0)
524     /// but its total length depends on the number of durations. The
525     /// element at index 0 is special as it represents the "current
526     /// revision".  In general, we have the invariant that revisions
527     /// in here are *declining* -- that is, `revisions[i] >=
528     /// revisions[i + 1]`, for all `i`. This is because when you
529     /// modify a value with durability D, that implies that values
530     /// with durability less than D may have changed too.
531     revisions: Vec<AtomicRevision>,
532 
533     /// The dependency graph tracks which runtimes are blocked on one
534     /// another, waiting for queries to terminate.
535     dependency_graph: Mutex<DependencyGraph<DatabaseKeyIndex>>,
536 }
537 
538 impl SharedState {
with_durabilities(durabilities: usize) -> Self539     fn with_durabilities(durabilities: usize) -> Self {
540         SharedState {
541             next_id: AtomicUsize::new(1),
542             query_lock: Default::default(),
543             revisions: (0..durabilities).map(|_| AtomicRevision::start()).collect(),
544             pending_revision: AtomicRevision::start(),
545             dependency_graph: Default::default(),
546         }
547     }
548 }
549 
550 impl std::panic::RefUnwindSafe for SharedState {}
551 
552 impl Default for SharedState {
default() -> Self553     fn default() -> Self {
554         Self::with_durabilities(Durability::LEN)
555     }
556 }
557 
558 impl std::fmt::Debug for SharedState {
fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result559     fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
560         let query_lock = if self.query_lock.try_write().is_some() {
561             "<unlocked>"
562         } else if self.query_lock.try_read().is_some() {
563             "<rlocked>"
564         } else {
565             "<wlocked>"
566         };
567         fmt.debug_struct("SharedState")
568             .field("query_lock", &query_lock)
569             .field("revisions", &self.revisions)
570             .field("pending_revision", &self.pending_revision)
571             .finish()
572     }
573 }
574 
575 struct ActiveQuery {
576     /// What query is executing
577     database_key_index: DatabaseKeyIndex,
578 
579     /// Minimum durability of inputs observed so far.
580     durability: Durability,
581 
582     /// Maximum revision of all inputs observed. If we observe an
583     /// untracked read, this will be set to the most recent revision.
584     changed_at: Revision,
585 
586     /// Set of subqueries that were accessed thus far, or `None` if
587     /// there was an untracked the read.
588     dependencies: Option<FxIndexSet<DatabaseKeyIndex>>,
589 
590     /// Stores the entire cycle, if one is found and this query is part of it.
591     cycle: Vec<DatabaseKeyIndex>,
592 }
593 
594 pub(crate) struct ComputedQueryResult<V> {
595     /// Final value produced
596     pub(crate) value: V,
597 
598     /// Minimum durability of inputs observed so far.
599     pub(crate) durability: Durability,
600 
601     /// Maximum revision of all inputs observed. If we observe an
602     /// untracked read, this will be set to the most recent revision.
603     pub(crate) changed_at: Revision,
604 
605     /// Complete set of subqueries that were accessed, or `None` if
606     /// there was an untracked the read.
607     pub(crate) dependencies: Option<FxIndexSet<DatabaseKeyIndex>>,
608 
609     /// The cycle if one occured while computing this value
610     pub(crate) cycle: Vec<DatabaseKeyIndex>,
611 }
612 
613 impl ActiveQuery {
new(database_key_index: DatabaseKeyIndex, max_durability: Durability) -> Self614     fn new(database_key_index: DatabaseKeyIndex, max_durability: Durability) -> Self {
615         ActiveQuery {
616             database_key_index,
617             durability: max_durability,
618             changed_at: Revision::start(),
619             dependencies: Some(FxIndexSet::default()),
620             cycle: Vec::new(),
621         }
622     }
623 
add_read(&mut self, input: DatabaseKeyIndex, durability: Durability, revision: Revision)624     fn add_read(&mut self, input: DatabaseKeyIndex, durability: Durability, revision: Revision) {
625         if let Some(set) = &mut self.dependencies {
626             set.insert(input);
627         }
628 
629         self.durability = self.durability.min(durability);
630         self.changed_at = self.changed_at.max(revision);
631     }
632 
add_untracked_read(&mut self, changed_at: Revision)633     fn add_untracked_read(&mut self, changed_at: Revision) {
634         self.dependencies = None;
635         self.durability = Durability::LOW;
636         self.changed_at = changed_at;
637     }
638 
add_synthetic_read(&mut self, durability: Durability)639     fn add_synthetic_read(&mut self, durability: Durability) {
640         self.durability = self.durability.min(durability);
641     }
642 
add_anon_read(&mut self, changed_at: Revision)643     fn add_anon_read(&mut self, changed_at: Revision) {
644         self.changed_at = self.changed_at.max(changed_at);
645     }
646 }
647 
/// Identifies one runtime. A root runtime created by `Runtime::default` uses
/// counter 0; every snapshot takes the next value of the shared,
/// monotonically increasing `SharedState::next_id` counter.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct RuntimeId {
    counter: usize,
}
655 
656 #[derive(Clone, Debug)]
657 pub(crate) struct StampedValue<V> {
658     pub(crate) value: V,
659     pub(crate) durability: Durability,
660     pub(crate) changed_at: Revision,
661 }
662 
663 #[derive(Debug)]
664 struct Edge<K> {
665     id: RuntimeId,
666     path: Vec<K>,
667 }
668 
669 #[derive(Debug)]
670 struct DependencyGraph<K: Hash + Eq> {
671     /// A `(K -> V)` pair in this map indicates that the the runtime
672     /// `K` is blocked on some query executing in the runtime `V`.
673     /// This encodes a graph that must be acyclic (or else deadlock
674     /// will result).
675     edges: FxHashMap<RuntimeId, Edge<K>>,
676     labels: FxHashMap<K, SmallVec<[RuntimeId; 4]>>,
677 }
678 
679 impl<K> Default for DependencyGraph<K>
680 where
681     K: Hash + Eq,
682 {
default() -> Self683     fn default() -> Self {
684         DependencyGraph {
685             edges: Default::default(),
686             labels: Default::default(),
687         }
688     }
689 }
690 
691 impl<K> DependencyGraph<K>
692 where
693     K: Hash + Eq + Clone,
694 {
695     /// Attempt to add an edge `from_id -> to_id` into the result graph.
add_edge( &mut self, from_id: RuntimeId, database_key: K, to_id: RuntimeId, path: impl IntoIterator<Item = K>, ) -> bool696     fn add_edge(
697         &mut self,
698         from_id: RuntimeId,
699         database_key: K,
700         to_id: RuntimeId,
701         path: impl IntoIterator<Item = K>,
702     ) -> bool {
703         assert_ne!(from_id, to_id);
704         debug_assert!(!self.edges.contains_key(&from_id));
705 
706         // First: walk the chain of things that `to_id` depends on,
707         // looking for us.
708         let mut p = to_id;
709         while let Some(q) = self.edges.get(&p).map(|edge| edge.id) {
710             if q == from_id {
711                 return false;
712             }
713 
714             p = q;
715         }
716 
717         self.edges.insert(
718             from_id,
719             Edge {
720                 id: to_id,
721                 path: path.into_iter().chain(Some(database_key.clone())).collect(),
722             },
723         );
724         self.labels
725             .entry(database_key.clone())
726             .or_default()
727             .push(from_id);
728         true
729     }
730 
remove_edge(&mut self, database_key: K, to_id: RuntimeId)731     fn remove_edge(&mut self, database_key: K, to_id: RuntimeId) {
732         let vec = self.labels.remove(&database_key).unwrap_or_default();
733 
734         for from_id in &vec {
735             let to_id1 = self.edges.remove(from_id).map(|edge| edge.id);
736             assert_eq!(Some(to_id), to_id1);
737         }
738     }
739 
push_cycle_path<'a>( &'a self, database_key: K, to: RuntimeId, local_path: impl IntoIterator<Item = K>, output: &mut Vec<K>, ) where K: std::fmt::Debug,740     fn push_cycle_path<'a>(
741         &'a self,
742         database_key: K,
743         to: RuntimeId,
744         local_path: impl IntoIterator<Item = K>,
745         output: &mut Vec<K>,
746     ) where
747         K: std::fmt::Debug,
748     {
749         let mut current = Some((to, std::slice::from_ref(&database_key)));
750         let mut last = None;
751         let mut local_path = Some(local_path);
752 
753         loop {
754             match current.take() {
755                 Some((id, path)) => {
756                     let link_key = path.last().unwrap();
757 
758                     output.extend(path.iter().cloned());
759 
760                     current = self.edges.get(&id).map(|edge| {
761                         let i = edge.path.iter().rposition(|p| p == link_key).unwrap();
762                         (edge.id, &edge.path[i + 1..])
763                     });
764 
765                     if current.is_none() {
766                         last = local_path.take().map(|local_path| {
767                             local_path
768                                 .into_iter()
769                                 .skip_while(move |p| *p != *link_key)
770                                 .skip(1)
771                         });
772                     }
773                 }
774                 None => break,
775             }
776         }
777 
778         if let Some(iter) = &mut last {
779             output.extend(iter);
780         }
781     }
782 }
783 
/// Represents one shared (read) acquisition of `shared_state.query_lock`.
///
/// The lock is taken through the raw lock API in `RevisionGuard::new` and
/// released in this type's `Drop` impl, not through an RAII guard object.
/// A forked `Runtime` stores one of these in its `revision_guard` field.
struct RevisionGuard {
    /// Shared runtime state. Holding this `Arc` keeps `query_lock` alive for
    /// as long as the guard exists.
    shared_state: Arc<SharedState>,
}
787 
788 impl RevisionGuard {
new(shared_state: &Arc<SharedState>) -> Self789     fn new(shared_state: &Arc<SharedState>) -> Self {
790         // Subtle: we use a "recursive" lock here so that it is not an
791         // error to acquire a read-lock when one is already held (this
792         // happens when a query uses `snapshot` to spawn off parallel
793         // workers, for example).
794         //
795         // This has the side-effect that we are responsible to ensure
796         // that people contending for the write lock do not starve,
797         // but this is what we achieve via the cancellation mechanism.
798         //
799         // (In particular, since we only ever have one "mutating
800         // handle" to the database, the only contention for the global
801         // query lock occurs when there are "futures" evaluating
802         // queries in parallel, and those futures hold a read-lock
803         // already, so the starvation problem is more about them bring
804         // themselves to a close, versus preventing other people from
805         // *starting* work).
806         unsafe {
807             shared_state.query_lock.raw().lock_shared_recursive();
808         }
809 
810         Self {
811             shared_state: shared_state.clone(),
812         }
813     }
814 }
815 
816 impl Drop for RevisionGuard {
drop(&mut self)817     fn drop(&mut self) {
818         // Release our read-lock without using RAII. As documented in
819         // `Snapshot::new` above, this requires the unsafe keyword.
820         unsafe {
821             self.shared_state.query_lock.raw().unlock_shared();
822         }
823     }
824 }
825 
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks a single edge `a -> b`: the reverse edge is refused, and the
    /// cycle path is reconstructed from `a`.
    #[test]
    fn dependency_graph_path1() {
        let mut graph = DependencyGraph::default();
        let a = RuntimeId { counter: 0 };
        let b = RuntimeId { counter: 1 };
        assert!(graph.add_edge(a, 2, b, vec![1]));
        // `b -> a` would close a cycle, so it is rejected and the graph is unchanged.
        assert!(!graph.add_edge(b, 1, a, vec![3, 2]));
        let mut v = vec![];
        graph.push_cycle_path(1, a, vec![3, 2], &mut v);
        assert_eq!(v, vec![1, 2]);
    }

    /// Checks the chain `a -> b -> c`: the closing edge `c -> a` is refused,
    /// and the cycle path runs through every hop.
    #[test]
    fn dependency_graph_path2() {
        let mut graph = DependencyGraph::default();
        let a = RuntimeId { counter: 0 };
        let b = RuntimeId { counter: 1 };
        let c = RuntimeId { counter: 2 };
        assert!(graph.add_edge(a, 3, b, vec![1]));
        assert!(graph.add_edge(b, 4, c, vec![2, 3]));
        // `c -> a` would close the cycle `a -> b -> c -> a`, so it is rejected.
        assert!(!graph.add_edge(c, 1, a, vec![5, 6, 4, 7]));
        let mut v = vec![];
        graph.push_cycle_path(1, a, vec![5, 6, 4, 7], &mut v);
        assert_eq!(v, vec![1, 3, 4, 7]);
    }

    /// Checks that removing a label drops its edges, so an edge that was
    /// previously cycle-closing becomes acceptable.
    #[test]
    fn dependency_graph_remove_edge() {
        let mut graph = DependencyGraph::default();
        let a = RuntimeId { counter: 0 };
        let b = RuntimeId { counter: 1 };
        assert!(graph.add_edge(a, 2, b, vec![1]));
        graph.remove_edge(2, b);
        assert!(graph.edges.is_empty());
        assert!(graph.labels.is_empty());
        // With `a -> b` gone, `b -> a` no longer closes a cycle.
        assert!(graph.add_edge(b, 1, a, vec![3]));
    }
}
856