1// Copyright ©2014 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import (
8	"gonum.org/v1/gonum/graph"
9	"gonum.org/v1/gonum/graph/iterator"
10	"gonum.org/v1/gonum/graph/simple"
11)
12
// GraphNode is a graph node that also stores the links used to walk the
// graph it belongs to. Its neighbors are the nodes reported by From and
// consulted by EdgeBetween. Its roots are additional nodes that the search
// methods (Node, Nodes, From, EdgeBetween) traverse but that From does not
// report as neighbors.
type GraphNode struct {
	// id is the identifier returned by ID.
	id int64
	// neighbors are the nodes adjacent to this one; entries that are
	// *GraphNode values are also searched recursively.
	neighbors []graph.Node
	// roots are further nodes visited by the search methods before neighbors.
	roots []*GraphNode
}
18
19func (g *GraphNode) Node(id int64) graph.Node {
20	if id == g.id {
21		return g
22	}
23
24	visited := map[int64]struct{}{g.id: {}}
25	for _, root := range g.roots {
26		if root.ID() == id {
27			return root
28		}
29
30		if root.has(id, visited) {
31			return root
32		}
33	}
34
35	for _, neigh := range g.neighbors {
36		if neigh.ID() == id {
37			return neigh
38		}
39
40		if gn, ok := neigh.(*GraphNode); ok {
41			if gn.has(id, visited) {
42				return gn
43			}
44		}
45	}
46
47	return nil
48}
49
50func (g *GraphNode) has(id int64, visited map[int64]struct{}) bool {
51	for _, root := range g.roots {
52		if _, ok := visited[root.ID()]; ok {
53			continue
54		}
55
56		visited[root.ID()] = struct{}{}
57		if root.ID() == id {
58			return true
59		}
60
61		if root.has(id, visited) {
62			return true
63		}
64
65	}
66
67	for _, neigh := range g.neighbors {
68		if _, ok := visited[neigh.ID()]; ok {
69			continue
70		}
71
72		visited[neigh.ID()] = struct{}{}
73		if neigh.ID() == id {
74			return true
75		}
76
77		if gn, ok := neigh.(*GraphNode); ok {
78			if gn.has(id, visited) {
79				return true
80			}
81		}
82	}
83
84	return false
85}
86
87func (g *GraphNode) Nodes() graph.Nodes {
88	toReturn := []graph.Node{g}
89	visited := map[int64]struct{}{g.id: {}}
90
91	for _, root := range g.roots {
92		toReturn = append(toReturn, root)
93		visited[root.ID()] = struct{}{}
94
95		toReturn = root.nodes(toReturn, visited)
96	}
97
98	for _, neigh := range g.neighbors {
99		toReturn = append(toReturn, neigh)
100		visited[neigh.ID()] = struct{}{}
101
102		if gn, ok := neigh.(*GraphNode); ok {
103			toReturn = gn.nodes(toReturn, visited)
104		}
105	}
106
107	return iterator.NewOrderedNodes(toReturn)
108}
109
110func (g *GraphNode) nodes(list []graph.Node, visited map[int64]struct{}) []graph.Node {
111	for _, root := range g.roots {
112		if _, ok := visited[root.ID()]; ok {
113			continue
114		}
115		visited[root.ID()] = struct{}{}
116		list = append(list, graph.Node(root))
117
118		list = root.nodes(list, visited)
119	}
120
121	for _, neigh := range g.neighbors {
122		if _, ok := visited[neigh.ID()]; ok {
123			continue
124		}
125
126		list = append(list, neigh)
127		if gn, ok := neigh.(*GraphNode); ok {
128			list = gn.nodes(list, visited)
129		}
130	}
131
132	return list
133}
134
135func (g *GraphNode) From(id int64) graph.Nodes {
136	if id == g.ID() {
137		return iterator.NewOrderedNodes(g.neighbors)
138	}
139
140	visited := map[int64]struct{}{g.id: {}}
141	for _, root := range g.roots {
142		visited[root.ID()] = struct{}{}
143
144		if result := root.findNeighbors(id, visited); result != nil {
145			return iterator.NewOrderedNodes(result)
146		}
147	}
148
149	for _, neigh := range g.neighbors {
150		visited[neigh.ID()] = struct{}{}
151
152		if gn, ok := neigh.(*GraphNode); ok {
153			if result := gn.findNeighbors(id, visited); result != nil {
154				return iterator.NewOrderedNodes(result)
155			}
156		}
157	}
158
159	return nil
160}
161
162func (g *GraphNode) findNeighbors(id int64, visited map[int64]struct{}) []graph.Node {
163	if id == g.ID() {
164		return g.neighbors
165	}
166
167	for _, root := range g.roots {
168		if _, ok := visited[root.ID()]; ok {
169			continue
170		}
171		visited[root.ID()] = struct{}{}
172
173		if result := root.findNeighbors(id, visited); result != nil {
174			return result
175		}
176	}
177
178	for _, neigh := range g.neighbors {
179		if _, ok := visited[neigh.ID()]; ok {
180			continue
181		}
182		visited[neigh.ID()] = struct{}{}
183
184		if gn, ok := neigh.(*GraphNode); ok {
185			if result := gn.findNeighbors(id, visited); result != nil {
186				return result
187			}
188		}
189	}
190
191	return nil
192}
193
194func (g *GraphNode) HasEdgeBetween(uid, vid int64) bool {
195	return g.EdgeBetween(uid, vid) != nil
196}
197
198func (g *GraphNode) Edge(uid, vid int64) graph.Edge {
199	return g.EdgeBetween(uid, vid)
200}
201
202func (g *GraphNode) EdgeBetween(uid, vid int64) graph.Edge {
203	if uid == g.id || vid == g.id {
204		for _, neigh := range g.neighbors {
205			if neigh.ID() == uid || neigh.ID() == vid {
206				return simple.Edge{F: g, T: neigh}
207			}
208		}
209		return nil
210	}
211
212	visited := map[int64]struct{}{g.id: {}}
213	for _, root := range g.roots {
214		visited[root.ID()] = struct{}{}
215		if result := root.edgeBetween(uid, vid, visited); result != nil {
216			return result
217		}
218	}
219
220	for _, neigh := range g.neighbors {
221		visited[neigh.ID()] = struct{}{}
222		if gn, ok := neigh.(*GraphNode); ok {
223			if result := gn.edgeBetween(uid, vid, visited); result != nil {
224				return result
225			}
226		}
227	}
228
229	return nil
230}
231
232func (g *GraphNode) edgeBetween(uid, vid int64, visited map[int64]struct{}) graph.Edge {
233	if uid == g.id || vid == g.id {
234		for _, neigh := range g.neighbors {
235			if neigh.ID() == uid || neigh.ID() == vid {
236				return simple.Edge{F: g, T: neigh}
237			}
238		}
239		return nil
240	}
241
242	for _, root := range g.roots {
243		if _, ok := visited[root.ID()]; ok {
244			continue
245		}
246		visited[root.ID()] = struct{}{}
247		if result := root.edgeBetween(uid, vid, visited); result != nil {
248			return result
249		}
250	}
251
252	for _, neigh := range g.neighbors {
253		if _, ok := visited[neigh.ID()]; ok {
254			continue
255		}
256
257		visited[neigh.ID()] = struct{}{}
258		if gn, ok := neigh.(*GraphNode); ok {
259			if result := gn.edgeBetween(uid, vid, visited); result != nil {
260				return result
261			}
262		}
263	}
264
265	return nil
266}
267
268func (g *GraphNode) ID() int64 {
269	return g.id
270}
271
272func (g *GraphNode) AddNeighbor(n *GraphNode) {
273	g.neighbors = append(g.neighbors, graph.Node(n))
274}
275
276func (g *GraphNode) AddRoot(n *GraphNode) {
277	g.roots = append(g.roots, n)
278}
279
280func NewGraphNode(id int64) *GraphNode {
281	return &GraphNode{id: id}
282}
283