1"""Implements the functionality of a directed graph."""
2
3import copy
4
5from ..exceptions import NonexistentNodeError, NonexistentEdgeError
6
7
class DirectedGraph(object):
    """A directed multigraph stored as two dictionaries keyed by integer id.

    Each node object is a dict with the keys ``'id'``, ``'edges'`` (a list of
    the ids of the edges that *originate* at the node) and ``'data'``.
    Each edge object is a dict with the keys ``'id'``, ``'vertices'`` (a
    ``(source_id, target_id)`` tuple), ``'cost'`` and ``'data'``.

    Several edges may connect the same ordered pair of nodes.
    """

    # Class-level defaults. "nodes" and "edges" are replaced by per-instance
    # dicts in __init__; the id counters become instance attributes the first
    # time an instance increments them.
    nodes = None
    edges = None
    next_node_id = 1
    next_edge_id = 1

    def __init__(self):
        """Creates an empty graph."""
        self.nodes = {}
        self.edges = {}
        self._num_nodes = 0
        self._num_edges = 0

    def __deepcopy__(self, memo=None):
        """Returns an independent copy of the graph.

        "memo" is the bookkeeping dict supplied by copy.deepcopy(); it is
        forwarded to the nested copies so that objects shared between node
        and edge data stay shared in the copy.
        """
        if memo is None:
            memo = {}

        graph = DirectedGraph()
        # Register the copy before descending, so that any reference back to
        # this graph from within the node/edge data resolves to the copy.
        memo[id(self)] = graph

        graph.nodes = copy.deepcopy(self.nodes, memo)
        graph.edges = copy.deepcopy(self.edges, memo)
        graph.next_node_id = self.next_node_id
        graph.next_edge_id = self.next_edge_id
        graph._num_nodes = self._num_nodes
        graph._num_edges = self._num_edges
        return graph

    def num_nodes(self):
        """Returns the current number of nodes in the graph."""
        return self._num_nodes

    def num_edges(self):
        """Returns the current number of edges in the graph."""
        return self._num_edges

    def generate_node_id(self):
        """Returns the next unused node id and advances the node id counter."""
        node_id = self.next_node_id
        self.next_node_id += 1
        return node_id

    def generate_edge_id(self):
        """Returns the next unused edge id and advances the edge id counter."""
        edge_id = self.next_edge_id
        self.next_edge_id += 1
        return edge_id

    def new_node(self):
        """Adds a new, blank node to the graph.
        Returns the node id of the new node."""
        node_id = self.generate_node_id()

        node = {'id': node_id,
                'edges': [],
                'data': {}
        }

        self.nodes[node_id] = node

        self._num_nodes += 1

        return node_id

    def new_edge(self, node_a, node_b, cost=1):
        """Adds a new edge from node_a to node_b that has a cost.
        Returns the edge id of the new edge.
        Raises NonexistentNodeError if either node is not in the graph."""

        # Verify that both nodes exist in the graph before an id is consumed
        source_node = self.get_node(node_a)
        self.get_node(node_b)

        # Create the new edge
        edge_id = self.generate_edge_id()

        edge = {'id': edge_id,
                'vertices': (node_a, node_b),
                'cost': cost,
                'data': {}
        }

        self.edges[edge_id] = edge
        # Only the source node tracks the edge; the target is found via 'vertices'
        source_node['edges'].append(edge_id)

        self._num_edges += 1

        return edge_id

    def neighbors(self, node_id):
        """Find all the nodes where there is an edge from the specified node to that node.
        Returns a list of node ids; a node appears once per connecting edge.
        Raises NonexistentNodeError if node_id is not in the graph."""
        node = self.get_node(node_id)
        return [self.get_edge(edge_id)['vertices'][1] for edge_id in node['edges']]

    def adjacent(self, node_a, node_b):
        """Determines whether there is an edge from node_a to node_b.
        Returns True if such an edge exists, otherwise returns False."""
        return node_b in self.neighbors(node_a)

    def edge_cost(self, node_a, node_b):
        """Returns the cost of moving between the edge that connects node_a to node_b.
        If several such edges exist, the cost of the first one is returned.
        Returns +inf if no such edge exists."""
        wanted = (node_a, node_b)
        for edge_id in self.get_node(node_a)['edges']:
            edge = self.get_edge(edge_id)
            if edge['vertices'] == wanted:
                return edge['cost']
        return float('inf')

    def get_node(self, node_id):
        """Returns the node object identified by "node_id".
        Raises NonexistentNodeError if there is no such node."""
        try:
            node_object = self.nodes[node_id]
        except KeyError:
            raise NonexistentNodeError(node_id)
        return node_object

    def get_all_node_ids(self):
        """Returns a list of all the node ids in the graph."""
        return list(self.nodes.keys())

    def get_all_node_objects(self):
        """Returns a list of all the node objects in the graph."""
        return list(self.nodes.values())

    def get_edge(self, edge_id):
        """Returns the edge object identified by "edge_id".
        Raises NonexistentEdgeError if there is no such edge."""
        try:
            edge_object = self.edges[edge_id]
        except KeyError:
            raise NonexistentEdgeError(edge_id)
        return edge_object

    def get_all_edge_ids(self):
        """Returns a list of all the edge ids in the graph."""
        return list(self.edges.keys())

    def get_all_edge_objects(self):
        """Returns a list of all the edge objects in the graph."""
        return list(self.edges.values())

    def delete_edge_by_id(self, edge_id):
        """Removes the edge identified by "edge_id" from the graph.
        Raises NonexistentEdgeError if there is no such edge."""
        edge = self.get_edge(edge_id)

        # Remove the edge from the "from node"
        from_node = self.get_node(edge['vertices'][0])
        from_node['edges'].remove(edge_id)

        # Remove the edge from the edge list
        del self.edges[edge_id]

        self._num_edges -= 1

    def delete_edge_by_nodes(self, node_a, node_b):
        """Removes all the edges from node_a to node_b from the graph."""
        node = self.get_node(node_a)

        # Collect the ids first: deleting an edge mutates node['edges'],
        # so it must not be iterated while the deletions happen.
        edge_ids = [e_id for e_id in node['edges']
                    if self.get_edge(e_id)['vertices'][1] == node_b]

        # Delete the edges
        for e_id in edge_ids:
            self.delete_edge_by_id(e_id)

    def delete_node(self, node_id):
        """Removes the node identified by node_id from the graph,
        together with every edge that starts or ends at it.
        Raises NonexistentNodeError if there is no such node."""
        node = self.get_node(node_id)

        # Remove all edges from the node.
        # Iterate over a snapshot: delete_edge_by_id() removes entries from
        # node['edges'], and looping over the live list would skip edges.
        for e_id in list(node['edges']):
            self.delete_edge_by_id(e_id)

        # Remove all edges to the node (the id list is fully built before
        # any deletion takes place)
        incoming = [e_id for e_id, edge in self.edges.items()
                    if edge['vertices'][1] == node_id]
        for e_id in incoming:
            self.delete_edge_by_id(e_id)

        # Remove the node from the node list
        del self.nodes[node_id]

        self._num_nodes -= 1

    def move_edge_source(self, edge_id, node_a, node_b):
        """Moves an edge originating from node_a so that it originates from node_b.
        Raises NonexistentEdgeError / NonexistentNodeError for unknown ids, and
        ValueError if the edge is not in node_a's edge list. The graph is left
        untouched when any of these is raised."""
        # Look everything up before mutating, so a failure cannot leave the
        # edge pointing at one node while being listed under another
        edge = self.get_edge(edge_id)
        old_source = self.get_node(node_a)
        new_source = self.get_node(node_b)

        # Remove the edge from node_a
        old_source['edges'].remove(edge_id)

        # Alter the vertices
        edge['vertices'] = (node_b, edge['vertices'][1])

        # Add the edge to node_b
        new_source['edges'].append(edge_id)

    def move_edge_target(self, edge_id, node_a):
        """Moves an edge so that it targets node_a.
        Raises NonexistentEdgeError if the edge does not exist and
        NonexistentNodeError if node_a does not exist."""
        # Grab the edge
        edge = self.get_edge(edge_id)

        # Refuse to create an edge that dangles off a missing node
        self.get_node(node_a)

        # Alter the vertices
        edge['vertices'] = (edge['vertices'][0], node_a)

    def get_edge_ids_by_node_ids(self, node_a, node_b):
        """Returns a list of edge ids connecting node_a to node_b.
        The list is empty when there is no such edge."""
        # A single pass over node_a's outgoing edges answers both "are they
        # adjacent?" and "which edges connect them?"
        node = self.get_node(node_a)
        return [edge_id for edge_id in node['edges']
                if self.get_edge(edge_id)['vertices'][1] == node_b]

    def get_first_edge_id_by_node_ids(self, node_a, node_b):
        """Returns the first (and possibly only) edge connecting node_a and node_b.
        Returns None if there is no such edge."""
        edge_ids = self.get_edge_ids_by_node_ids(node_a, node_b)
        if not edge_ids:
            return None
        return edge_ids[0]
242