import uuid

from .plugin import SchedulerPlugin


class GraphLayout(SchedulerPlugin):
    """Dynamic graph layout during computation

    This assigns (x, y) locations to all tasks quickly and dynamically as new
    tasks are added. This scales to a few thousand nodes.

    It is commonly used with distributed/dashboard/components/scheduler.py::TaskGraph, which
    is rendered at /graph on the diagnostic dashboard.
    """

    def __init__(self, scheduler):
        self.name = f"graph-layout-{uuid.uuid4()}"
        # Layout coordinates, keyed by task key.
        self.x = {}
        self.y = {}
        # Maps an occupied (x, y) slot to the most recent position handed out
        # for that slot, so nodes landing on the same slot are offset in y.
        self.collision = {}
        self.scheduler = scheduler
        # Dense integer ids for nodes (by key) and edges (by (dep, key) pair).
        self.index = {}
        self.index_edge = {}
        self.next_y = 0
        self.next_index = 0
        self.next_edge_index = 0
        # Queues of pending changes. This module only appends to them (and
        # clears them in reset_index); presumably they are drained by the
        # dashboard TaskGraph -- NOTE(review): confirm against the consumer.
        self.new = []
        self.new_edges = []
        self.state_updates = []
        self.visible_updates = []
        self.visible_edge_updates = []

        if self.scheduler.tasks:
            # Lay out tasks that already existed before this plugin was added.
            dependencies = {
                k: [ds.key for ds in ts.dependencies]
                for k, ts in scheduler.tasks.items()
            }
            priority = {k: ts.priority for k, ts in scheduler.tasks.items()}
            self.update_graph(
                self.scheduler,
                tasks=self.scheduler.tasks,
                dependencies=dependencies,
                priority=priority,
            )

    def update_graph(
        self, scheduler, dependencies=None, priority=None, tasks=None, **kwargs
    ):
        """Assign (x, y) positions and indices to newly added tasks.

        Parameters
        ----------
        scheduler
            The scheduler. Its ``tasks`` mapping is used to check that a key
            exists and to count each dependency's dependents.
        dependencies : mapping, optional
            Maps each key to the keys it depends on. When omitted, the
            dependencies recorded in ``scheduler.tasks`` are used.
        priority : mapping, optional
            Maps keys to priorities. Lower priorities are laid out first. Keys
            with no (or a ``None``) priority are laid out before prioritized
            ones.
        tasks : iterable, optional
            Keys to lay out. Keys that are already laid out or unknown to the
            scheduler are skipped, as is any key whose dependencies cannot
            all be laid out.
        """
        priority = priority or {}

        def by_priority(k):
            # Never compare a placeholder against real priorities: scheduler
            # priorities are tuples, and ``0 < (...)`` raises TypeError.
            p = priority.get(k)
            return (0, ()) if p is None else (1, p)

        def deps_of(key):
            if dependencies is not None:
                return dependencies.get(key, ())
            return [ts.key for ts in scheduler.tasks[key].dependencies]

        # Keys whose missing dependencies have already been pushed once. The
        # dependencies sit above the key on the stack, so they are fully
        # resolved before the key is popped again. If any are still unplaced
        # by then, they never will be (unknown to the scheduler, or a cycle).
        # In that case give up on the key instead of looping forever.
        expanded = set()
        stack = sorted(tasks or (), key=by_priority, reverse=True)
        while stack:
            key = stack.pop()
            if key in self.x or key not in scheduler.tasks:
                continue
            deps = list(deps_of(key))
            if deps:
                missing = [dep for dep in deps if dep not in self.y]
                if missing:
                    if key not in expanded:
                        expanded.add(key)
                        # Revisit this key after its dependencies are placed.
                        stack.append(key)
                        stack.extend(sorted(missing, key=by_priority, reverse=True))
                    continue
                # One column right of the rightmost dependency, at the mean of
                # the dependencies' y-positions weighted by dependent counts.
                total_deps = sum(len(scheduler.tasks[dep].dependents) for dep in deps)
                if total_deps:
                    y = sum(
                        self.y[dep] * len(scheduler.tasks[dep].dependents) / total_deps
                        for dep in deps
                    )
                else:
                    # Dependents not recorded yet: fall back to a plain mean.
                    y = sum(self.y[dep] for dep in deps) / len(deps)
                x = max(self.x[dep] for dep in deps) + 1
            else:
                # Root tasks start a new row in the first column.
                x = 0
                y = self.next_y
                self.next_y += 1

            if (x, y) in self.collision:
                # Slot taken: offset 0.1 in y from the last node placed for
                # this slot, and record that as the slot's latest position.
                old_x, old_y = x, y
                x, y = self.collision[(x, y)]
                y += 0.1
                self.collision[old_x, old_y] = (x, y)
            else:
                self.collision[(x, y)] = (x, y)

            self.x[key] = x
            self.y[key] = y
            self.index[key] = self.next_index
            self.next_index += 1
            self.new.append(key)
            for dep in deps:
                edge = (dep, key)
                self.index_edge[edge] = self.next_edge_index
                self.next_edge_index += 1
                self.new_edges.append(edge)

    def transition(self, key, start, finish, *args, **kwargs):
        """Record a state change of ``key`` for the graph display.

        Any transition other than to ``"forgotten"`` queues a
        ``(node index, finish)`` state update. A transition to ``"forgotten"``
        queues "hide" updates for the node and its edges, and drops the key's
        layout data. Keys that were never laid out are ignored.
        """
        index = self.index.get(key)
        if index is None:
            # Never laid out (e.g. skipped by update_graph): nothing to update.
            return
        if finish != "forgotten":
            self.state_updates.append((index, finish))
            return

        self.visible_updates.append((index, "False"))
        task = self.scheduler.tasks[key]
        edges = [(key, dep.key) for dep in task.dependents]
        edges += [(dep.key, key) for dep in task.dependencies]
        for edge in edges:
            # An edge is absent if its other endpoint was never laid out.
            # Tolerate that so the cleanup below always completes.
            edge_index = self.index_edge.pop(edge, None)
            if edge_index is not None:
                self.visible_edge_updates.append((edge_index, "False"))

        self.collision.pop((self.x[key], self.y[key]), None)

        for collection in (self.x, self.y, self.index):
            del collection[key]

    def reset_index(self):
        """Reset the index and refill new and new_edges

        From time to time TaskGraph wants to remove invisible nodes and reset
        all of its indices. This helps.

        Every laid-out key is re-indexed and re-queued in ``new``. Each edge
        whose two endpoints are both laid out is re-indexed and re-queued in
        ``new_edges``. All other pending update queues are cleared.
        """
        self.new = []
        self.new_edges = []
        self.visible_updates = []
        self.state_updates = []
        self.visible_edge_updates = []

        self.index = {}
        self.next_index = 0
        self.index_edge = {}
        self.next_edge_index = 0

        for key in self.x:
            self.index[key] = self.next_index
            self.next_index += 1
            self.new.append(key)
            for dep in self.scheduler.tasks[key].dependencies:
                if dep.key not in self.x:
                    # Skip edges to nodes that have no position/index.
                    continue
                edge = (dep.key, key)
                self.index_edge[edge] = self.next_edge_index
                self.next_edge_index += 1
                self.new_edges.append(edge)