1 use std::collections::hash_map::Entry::{Occupied, Vacant};
2 use std::collections::{BinaryHeap, HashMap};
3 
4 use std::hash::Hash;
5 
6 use super::visit::{EdgeRef, GraphBase, IntoEdges, VisitMap, Visitable};
7 use crate::scored::MinScored;
8 
9 use crate::algo::Measure;
10 
11 /// \[Generic\] A* shortest path algorithm.
12 ///
13 /// Computes the shortest path from `start` to `finish`, including the total path cost.
14 ///
15 /// `finish` is implicitly given via the `is_goal` callback, which should return `true` if the
16 /// given node is the finish node.
17 ///
18 /// The function `edge_cost` should return the cost for a particular edge. Edge costs must be
19 /// non-negative.
20 ///
21 /// The function `estimate_cost` should return the estimated cost to the finish for a particular
22 /// node. For the algorithm to find the actual shortest path, it should be admissible, meaning that
23 /// it should never overestimate the actual cost to get to the nearest goal node. Estimate costs
24 /// must also be non-negative.
25 ///
26 /// The graph should be `Visitable` and implement `IntoEdges`.
27 ///
28 /// # Example
29 /// ```
30 /// use petgraph::Graph;
31 /// use petgraph::algo::astar;
32 ///
33 /// let mut g = Graph::new();
34 /// let a = g.add_node((0., 0.));
35 /// let b = g.add_node((2., 0.));
36 /// let c = g.add_node((1., 1.));
37 /// let d = g.add_node((0., 2.));
38 /// let e = g.add_node((3., 3.));
39 /// let f = g.add_node((4., 2.));
40 /// g.extend_with_edges(&[
41 ///     (a, b, 2),
42 ///     (a, d, 4),
43 ///     (b, c, 1),
44 ///     (b, f, 7),
45 ///     (c, e, 5),
46 ///     (e, f, 1),
47 ///     (d, e, 1),
48 /// ]);
49 ///
50 /// // Graph represented with the weight of each edge
51 /// // Edges with '*' are part of the optimal path.
52 /// //
53 /// //     2       1
54 /// // a ----- b ----- c
55 /// // | 4*    | 7     |
56 /// // d       f       | 5
57 /// // | 1*    | 1*    |
58 /// // \------ e ------/
59 ///
60 /// let path = astar(&g, a, |finish| finish == f, |e| *e.weight(), |_| 0);
61 /// assert_eq!(path, Some((6, vec![a, d, e, f])));
62 /// ```
63 ///
64 /// Returns the total cost + the path of subsequent `NodeId` from start to finish, if one was
65 /// found.
astar<G, F, H, K, IsGoal>( graph: G, start: G::NodeId, mut is_goal: IsGoal, mut edge_cost: F, mut estimate_cost: H, ) -> Option<(K, Vec<G::NodeId>)> where G: IntoEdges + Visitable, IsGoal: FnMut(G::NodeId) -> bool, G::NodeId: Eq + Hash, F: FnMut(G::EdgeRef) -> K, H: FnMut(G::NodeId) -> K, K: Measure + Copy,66 pub fn astar<G, F, H, K, IsGoal>(
67     graph: G,
68     start: G::NodeId,
69     mut is_goal: IsGoal,
70     mut edge_cost: F,
71     mut estimate_cost: H,
72 ) -> Option<(K, Vec<G::NodeId>)>
73 where
74     G: IntoEdges + Visitable,
75     IsGoal: FnMut(G::NodeId) -> bool,
76     G::NodeId: Eq + Hash,
77     F: FnMut(G::EdgeRef) -> K,
78     H: FnMut(G::NodeId) -> K,
79     K: Measure + Copy,
80 {
81     let mut visited = graph.visit_map();
82     let mut visit_next = BinaryHeap::new();
83     let mut scores = HashMap::new();
84     let mut path_tracker = PathTracker::<G>::new();
85 
86     let zero_score = K::default();
87     scores.insert(start, zero_score);
88     visit_next.push(MinScored(estimate_cost(start), start));
89 
90     while let Some(MinScored(_, node)) = visit_next.pop() {
91         if is_goal(node) {
92             let path = path_tracker.reconstruct_path_to(node);
93             let cost = scores[&node];
94             return Some((cost, path));
95         }
96 
97         // Don't visit the same node several times, as the first time it was visited it was using
98         // the shortest available path.
99         if !visited.visit(node) {
100             continue;
101         }
102 
103         // This lookup can be unwrapped without fear of panic since the node was necessarily scored
104         // before adding him to `visit_next`.
105         let node_score = scores[&node];
106 
107         for edge in graph.edges(node) {
108             let next = edge.target();
109             if visited.is_visited(&next) {
110                 continue;
111             }
112 
113             let mut next_score = node_score + edge_cost(edge);
114 
115             match scores.entry(next) {
116                 Occupied(ent) => {
117                     let old_score = *ent.get();
118                     if next_score < old_score {
119                         *ent.into_mut() = next_score;
120                         path_tracker.set_predecessor(next, node);
121                     } else {
122                         next_score = old_score;
123                     }
124                 }
125                 Vacant(ent) => {
126                     ent.insert(next_score);
127                     path_tracker.set_predecessor(next, node);
128                 }
129             }
130 
131             let next_estimate_score = next_score + estimate_cost(next);
132             visit_next.push(MinScored(next_estimate_score, next));
133         }
134     }
135 
136     None
137 }
138 
139 struct PathTracker<G>
140 where
141     G: GraphBase,
142     G::NodeId: Eq + Hash,
143 {
144     came_from: HashMap<G::NodeId, G::NodeId>,
145 }
146 
147 impl<G> PathTracker<G>
148 where
149     G: GraphBase,
150     G::NodeId: Eq + Hash,
151 {
new() -> PathTracker<G>152     fn new() -> PathTracker<G> {
153         PathTracker {
154             came_from: HashMap::new(),
155         }
156     }
157 
set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId)158     fn set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId) {
159         self.came_from.insert(node, previous);
160     }
161 
reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId>162     fn reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId> {
163         let mut path = vec![last];
164 
165         let mut current = last;
166         while let Some(&previous) = self.came_from.get(&current) {
167             path.push(previous);
168             current = previous;
169         }
170 
171         path.reverse();
172 
173         path
174     }
175 }
176