1package dag
2
3import (
4	"fmt"
5	"sort"
6	"strings"
7
8	"github.com/hashicorp/terraform-plugin-sdk/internal/tfdiags"
9
10	"github.com/hashicorp/go-multierror"
11)
12
13// AcyclicGraph is a specialization of Graph that cannot have cycles. With
14// this property, we get the property of sane graph traversal.
15type AcyclicGraph struct {
16	Graph
17}
18
19// WalkFunc is the callback used for walking the graph.
20type WalkFunc func(Vertex) tfdiags.Diagnostics
21
22// DepthWalkFunc is a walk function that also receives the current depth of the
23// walk as an argument
24type DepthWalkFunc func(Vertex, int) error
25
26func (g *AcyclicGraph) DirectedGraph() Grapher {
27	return g
28}
29
30// Returns a Set that includes every Vertex yielded by walking down from the
31// provided starting Vertex v.
32func (g *AcyclicGraph) Ancestors(v Vertex) (*Set, error) {
33	s := new(Set)
34	start := AsVertexList(g.DownEdges(v))
35	memoFunc := func(v Vertex, d int) error {
36		s.Add(v)
37		return nil
38	}
39
40	if err := g.DepthFirstWalk(start, memoFunc); err != nil {
41		return nil, err
42	}
43
44	return s, nil
45}
46
47// Returns a Set that includes every Vertex yielded by walking up from the
48// provided starting Vertex v.
49func (g *AcyclicGraph) Descendents(v Vertex) (*Set, error) {
50	s := new(Set)
51	start := AsVertexList(g.UpEdges(v))
52	memoFunc := func(v Vertex, d int) error {
53		s.Add(v)
54		return nil
55	}
56
57	if err := g.ReverseDepthFirstWalk(start, memoFunc); err != nil {
58		return nil, err
59	}
60
61	return s, nil
62}
63
64// Root returns the root of the DAG, or an error.
65//
66// Complexity: O(V)
67func (g *AcyclicGraph) Root() (Vertex, error) {
68	roots := make([]Vertex, 0, 1)
69	for _, v := range g.Vertices() {
70		if g.UpEdges(v).Len() == 0 {
71			roots = append(roots, v)
72		}
73	}
74
75	if len(roots) > 1 {
76		// TODO(mitchellh): make this error message a lot better
77		return nil, fmt.Errorf("multiple roots: %#v", roots)
78	}
79
80	if len(roots) == 0 {
81		return nil, fmt.Errorf("no roots found")
82	}
83
84	return roots[0], nil
85}
86
// TransitiveReduction performs the transitive reduction of graph g in place.
// The transitive reduction of a graph is a graph with as few edges as
// possible with the same reachability as the original graph. This means
// that if there are three nodes A => B => C, and A connects to both
// B and C, and B connects to C, then the transitive reduction is the
// same graph with only a single edge between A and B, and a single edge
// between B and C.
//
// The graph must be valid for this operation to behave properly. If
// Validate() returns an error, the behavior is undefined and the results
// will likely be unexpected.
//
// Complexity: O(V(V+E)), or asymptotically O(VE)
func (g *AcyclicGraph) TransitiveReduction() {
	// For each vertex u in graph g, do a DFS starting from each vertex
	// v such that the edge (u,v) exists (v is a direct descendant of u).
	//
	// For each v-prime reachable from v, remove the edge (u, v-prime).
	defer g.debug.BeginOperation("TransitiveReduction", "").End("")

	for _, u := range g.Vertices() {
		// NOTE(review): if DownEdges returns g's live internal set (not
		// visible here; confirm against Graph.DownEdges), the RemoveEdge
		// calls below shrink uTargets while the walk runs. That is harmless:
		// an edge that has already been removed needs no further removal.
		uTargets := g.DownEdges(u)
		// Seed the walk from a list copy of u's direct targets.
		vs := AsVertexList(g.DownEdges(u))

		// The walk is unsorted (sorted=false) to avoid per-vertex sorting;
		// see depthFirstWalk. In an acyclic graph the visit order does not
		// change which edges are removed: only edges leaving u are removed,
		// and u is not reachable from its own targets.
		// The callback never returns an error, so the walk's error result
		// is deliberately ignored.
		g.depthFirstWalk(vs, false, func(v Vertex, d int) error {
			// Any direct target of u that is also a direct target of v
			// (itself reachable from u) is redundant as an edge from u.
			shared := uTargets.Intersection(g.DownEdges(v))
			for _, vPrime := range AsVertexList(shared) {
				g.RemoveEdge(BasicEdge(u, vPrime))
			}

			return nil
		})
	}
}
121
122// Validate validates the DAG. A DAG is valid if it has a single root
123// with no cycles.
124func (g *AcyclicGraph) Validate() error {
125	if _, err := g.Root(); err != nil {
126		return err
127	}
128
129	// Look for cycles of more than 1 component
130	var err error
131	cycles := g.Cycles()
132	if len(cycles) > 0 {
133		for _, cycle := range cycles {
134			cycleStr := make([]string, len(cycle))
135			for j, vertex := range cycle {
136				cycleStr[j] = VertexName(vertex)
137			}
138
139			err = multierror.Append(err, fmt.Errorf(
140				"Cycle: %s", strings.Join(cycleStr, ", ")))
141		}
142	}
143
144	// Look for cycles to self
145	for _, e := range g.Edges() {
146		if e.Source() == e.Target() {
147			err = multierror.Append(err, fmt.Errorf(
148				"Self reference: %s", VertexName(e.Source())))
149		}
150	}
151
152	return err
153}
154
155func (g *AcyclicGraph) Cycles() [][]Vertex {
156	var cycles [][]Vertex
157	for _, cycle := range StronglyConnected(&g.Graph) {
158		if len(cycle) > 1 {
159			cycles = append(cycles, cycle)
160		}
161	}
162	return cycles
163}
164
165// Walk walks the graph, calling your callback as each node is visited.
166// This will walk nodes in parallel if it can. The resulting diagnostics
167// contains problems from all graphs visited, in no particular order.
168func (g *AcyclicGraph) Walk(cb WalkFunc) tfdiags.Diagnostics {
169	defer g.debug.BeginOperation(typeWalk, "").End("")
170
171	w := &Walker{Callback: cb, Reverse: true}
172	w.Update(g)
173	return w.Wait()
174}
175
176// simple convenience helper for converting a dag.Set to a []Vertex
177func AsVertexList(s *Set) []Vertex {
178	rawList := s.List()
179	vertexList := make([]Vertex, len(rawList))
180	for i, raw := range rawList {
181		vertexList[i] = raw.(Vertex)
182	}
183	return vertexList
184}
185
// vertexAtDepth is one entry on the explicit stack used by depthFirstWalk
// and ReverseDepthFirstWalk: a vertex paired with the depth it was reached at.
type vertexAtDepth struct {
	// Vertex is the vertex waiting to be visited.
	Vertex Vertex
	// Depth is the number of edges followed from a starting vertex (0 for
	// the starting vertices themselves).
	Depth  int
}
190
191// depthFirstWalk does a depth-first walk of the graph starting from
192// the vertices in start.
193func (g *AcyclicGraph) DepthFirstWalk(start []Vertex, f DepthWalkFunc) error {
194	return g.depthFirstWalk(start, true, f)
195}
196
197// This internal method provides the option of not sorting the vertices during
198// the walk, which we use for the Transitive reduction.
199// Some configurations can lead to fully-connected subgraphs, which makes our
200// transitive reduction algorithm O(n^3). This is still passable for the size
201// of our graphs, but the additional n^2 sort operations would make this
202// uncomputable in a reasonable amount of time.
203func (g *AcyclicGraph) depthFirstWalk(start []Vertex, sorted bool, f DepthWalkFunc) error {
204	defer g.debug.BeginOperation(typeDepthFirstWalk, "").End("")
205
206	seen := make(map[Vertex]struct{})
207	frontier := make([]*vertexAtDepth, len(start))
208	for i, v := range start {
209		frontier[i] = &vertexAtDepth{
210			Vertex: v,
211			Depth:  0,
212		}
213	}
214	for len(frontier) > 0 {
215		// Pop the current vertex
216		n := len(frontier)
217		current := frontier[n-1]
218		frontier = frontier[:n-1]
219
220		// Check if we've seen this already and return...
221		if _, ok := seen[current.Vertex]; ok {
222			continue
223		}
224		seen[current.Vertex] = struct{}{}
225
226		// Visit the current node
227		if err := f(current.Vertex, current.Depth); err != nil {
228			return err
229		}
230
231		// Visit targets of this in a consistent order.
232		targets := AsVertexList(g.DownEdges(current.Vertex))
233
234		if sorted {
235			sort.Sort(byVertexName(targets))
236		}
237
238		for _, t := range targets {
239			frontier = append(frontier, &vertexAtDepth{
240				Vertex: t,
241				Depth:  current.Depth + 1,
242			})
243		}
244	}
245
246	return nil
247}
248
249// reverseDepthFirstWalk does a depth-first walk _up_ the graph starting from
250// the vertices in start.
251func (g *AcyclicGraph) ReverseDepthFirstWalk(start []Vertex, f DepthWalkFunc) error {
252	defer g.debug.BeginOperation(typeReverseDepthFirstWalk, "").End("")
253
254	seen := make(map[Vertex]struct{})
255	frontier := make([]*vertexAtDepth, len(start))
256	for i, v := range start {
257		frontier[i] = &vertexAtDepth{
258			Vertex: v,
259			Depth:  0,
260		}
261	}
262	for len(frontier) > 0 {
263		// Pop the current vertex
264		n := len(frontier)
265		current := frontier[n-1]
266		frontier = frontier[:n-1]
267
268		// Check if we've seen this already and return...
269		if _, ok := seen[current.Vertex]; ok {
270			continue
271		}
272		seen[current.Vertex] = struct{}{}
273
274		// Add next set of targets in a consistent order.
275		targets := AsVertexList(g.UpEdges(current.Vertex))
276		sort.Sort(byVertexName(targets))
277		for _, t := range targets {
278			frontier = append(frontier, &vertexAtDepth{
279				Vertex: t,
280				Depth:  current.Depth + 1,
281			})
282		}
283
284		// Visit the current node
285		if err := f(current.Vertex, current.Depth); err != nil {
286			return err
287		}
288	}
289
290	return nil
291}
292
293// byVertexName implements sort.Interface so a list of Vertices can be sorted
294// consistently by their VertexName
295type byVertexName []Vertex
296
297func (b byVertexName) Len() int      { return len(b) }
298func (b byVertexName) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
299func (b byVertexName) Less(i, j int) bool {
300	return VertexName(b[i]) < VertexName(b[j])
301}
302