1import uuid
2
3from .plugin import SchedulerPlugin
4
5
class GraphLayout(SchedulerPlugin):
    """Dynamic graph layout during computation

    This assigns (x, y) locations to all tasks quickly and dynamically as new
    tasks are added.  This scales to a few thousand nodes.

    It is commonly used with distributed/dashboard/components/scheduler.py::TaskGraph, which
    is rendered at /graph on the diagnostic dashboard.

    State is exposed as plain attributes.  ``x``, ``y`` and ``index`` are
    keyed by task key, ``index_edge`` by ``(dependency_key, dependent_key)``.
    ``new``, ``new_edges``, ``state_updates``, ``visible_updates`` and
    ``visible_edge_updates`` are append-only queues; nothing in this class
    drains them except :meth:`reset_index`.
    """

    def __init__(self, scheduler):
        """Create the layout and lay out any tasks the scheduler already has.

        Parameters
        ----------
        scheduler
            Object exposing a ``tasks`` mapping of key -> task state, where
            each task state has ``key``, ``priority``, ``dependencies`` and
            ``dependents`` attributes.
        """
        # Unique per instance, so several layouts never share a name.
        self.name = f"graph-layout-{uuid.uuid4()}"
        self.x = {}  # key -> column
        self.y = {}  # key -> row (may be fractional)
        # requested (x, y) -> last position handed out for that request
        self.collision = {}
        self.scheduler = scheduler
        self.index = {}  # key -> node index
        self.index_edge = {}  # (dep key, key) -> edge index
        self.next_y = 0  # row given to the next dependency-free task
        self.next_index = 0
        self.next_edge_index = 0
        self.new = []  # keys laid out since the last reset
        self.new_edges = []  # edges registered since the last reset
        self.state_updates = []  # (node index, new state)
        self.visible_updates = []  # (node index, "False")
        self.visible_edge_updates = []  # (edge index, "False")

        if self.scheduler.tasks:
            # Replay the scheduler's current graph through the normal hook so
            # a layout attached late starts out populated.
            dependencies = {
                k: [ds.key for ds in ts.dependencies]
                for k, ts in scheduler.tasks.items()
            }
            priority = {k: ts.priority for k, ts in scheduler.tasks.items()}
            self.update_graph(
                self.scheduler,
                tasks=self.scheduler.tasks,
                dependencies=dependencies,
                priority=priority,
            )

    def update_graph(
        self, scheduler, dependencies=None, priority=None, tasks=None, **kwargs
    ):
        """Assign positions and indices to tasks that are not laid out yet.

        Parameters
        ----------
        scheduler
            Object whose ``tasks`` mapping holds the task states; keys that
            are missing from it are skipped.
        dependencies : mapping, optional
            key -> iterable of dependency keys.  Missing or ``None`` means
            "no dependencies".
        priority : mapping, optional
            key -> sortable priority; keys without an entry sort as ``0``.
        tasks : iterable, optional
            Keys to lay out (iterating a mapping yields its keys).
        **kwargs
            Ignored.

        A task without (placeable) dependencies goes to column 0 on the next
        free row.  Any other task goes one column right of its right-most
        dependency, at the average row of its dependencies weighted by each
        dependency's number of dependents.
        """
        # The defaults are None; treat an omitted argument as empty instead
        # of failing on ``sorted(None)`` / ``None.get``.
        dependencies = {} if dependencies is None else dependencies
        priority = {} if priority is None else priority
        tasks = () if tasks is None else tasks

        def by_priority(keys):
            # Highest priority value first, so that ``stack.pop()`` (which
            # takes from the end) visits the lowest value first.
            return sorted(keys, key=lambda k: priority.get(k, 0), reverse=True)

        stack = by_priority(tasks)
        while stack:
            key = stack.pop()
            if key in self.x or key not in scheduler.tasks:
                continue

            deps = dependencies.get(key, ())
            if deps:
                # Wait only for dependencies that can still be placed.  One
                # that is unknown to the scheduler is skipped when popped, so
                # waiting for it would re-push ``key`` forever.
                if any(
                    dep not in self.y and dep in scheduler.tasks for dep in deps
                ):
                    stack.append(key)
                    stack.extend(by_priority(deps))
                    continue
                # Everything placeable is placed; drop the rest.
                deps = [dep for dep in deps if dep in self.y]

            if deps:
                weights = [len(scheduler.tasks[dep].dependents) for dep in deps]
                total_deps = sum(weights)
                if total_deps:
                    y = sum(
                        self.y[dep] * weight / total_deps
                        for dep, weight in zip(deps, weights)
                    )
                else:
                    # No dependency reports any dependents: fall back to a
                    # plain average instead of dividing by zero.
                    y = sum(self.y[dep] for dep in deps) / len(deps)
                x = max(self.x[dep] for dep in deps) + 1
            else:
                x = 0
                y = self.next_y
                self.next_y += 1

            if (x, y) in self.collision:
                # Spot already requested: continue 0.1 below the position
                # most recently handed out for this same request.
                old_x, old_y = x, y
                x, y = self.collision[(x, y)]
                y += 0.1
                self.collision[old_x, old_y] = (x, y)
            else:
                self.collision[(x, y)] = (x, y)

            self.x[key] = x
            self.y[key] = y
            self.index[key] = self.next_index
            self.next_index = self.next_index + 1
            self.new.append(key)
            for dep in deps:
                edge = (dep, key)
                self.index_edge[edge] = self.next_edge_index
                self.next_edge_index += 1
                self.new_edges.append(edge)

    def transition(self, key, start, finish, *args, **kwargs):
        """Record a task state change.

        Parameters
        ----------
        key
            Key of the task that changed; it must already be laid out,
            otherwise ``KeyError`` is raised.
        start
            Previous state (unused).
        finish
            New state.  ``"forgotten"`` removes the task from the layout;
            any other value is queued in ``state_updates``.
        *args, **kwargs
            Ignored.
        """
        if finish != "forgotten":
            self.state_updates.append((self.index[key], finish))
        else:
            # Note: the queued flag is the string "False", not the boolean.
            self.visible_updates.append((self.index[key], "False"))
            task = self.scheduler.tasks[key]
            # Hide every edge touching this task, outgoing then incoming.
            for dep in task.dependents:
                edge = (key, dep.key)
                self.visible_edge_updates.append((self.index_edge.pop(edge), "False"))
            for dep in task.dependencies:
                self.visible_edge_updates.append(
                    (self.index_edge.pop((dep.key, key)), "False")
                )

            # Best effort: a task that was shifted by collision handling sits
            # at a position that was never stored as a ``collision`` key.
            try:
                del self.collision[(self.x[key], self.y[key])]
            except KeyError:
                pass

            for collection in [self.x, self.y, self.index]:
                del collection[key]

    def reset_index(self):
        """Reset the index and refill new and new_edges

        From time to time TaskGraph wants to remove invisible nodes and reset
        all of its indices.  This helps.

        All pending update queues are dropped.  Every key that still has a
        position is renumbered from zero and re-queued in ``new``; its edges
        are rebuilt from the scheduler's current dependencies and re-queued
        in ``new_edges``.  Positions (``x``/``y``) are left untouched.
        """
        self.new = []
        self.new_edges = []
        self.visible_updates = []
        self.state_updates = []
        self.visible_edge_updates = []

        self.index = {}
        self.next_index = 0
        self.index_edge = {}
        self.next_edge_index = 0

        for key in self.x:
            self.index[key] = self.next_index
            self.next_index += 1
            self.new.append(key)
            for dep in self.scheduler.tasks[key].dependencies:
                edge = (dep.key, key)
                self.index_edge[edge] = self.next_edge_index
                self.next_edge_index += 1
                self.new_edges.append(edge)
143