# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-locals
"""Partitioned Boolean Quadratic Programming Tuner"""
from ._base import INVALID_LAYOUT_TIME
from .base_graph_tuner import BaseGraphTuner
from .utils import is_boundary_node, has_multiple_inputs


class PBQPTuner(BaseGraphTuner):
    """An approximation method for intractably large
    graph tuning problems.

    This graph coloring algorithm mainly comes from:

    Lang Hames and Bernhard Scholz.
    Nearly optimal register allocation with PBQP. JMLC 2006.
    LNCS, vol. 4228, pp. 346-361, 2006.
    """
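    # Rough usage sketch (constructor/method names taken from the TVM graph tuner
    # tutorial; treat them as assumptions and verify against your TVM version):
    #   tuner = PBQPTuner(graph, {"data": data_shape}, records, target_ops, target)
    #   tuner.benchmark_layout_transform()
    #   tuner.run()
    #   tuner.write_opt_sch2record_file("graph_opt_tuning.log")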

    def __init__(self, *args, **kwargs):
        """Create a partitioned boolean quadratic programming tuner."""
        super(PBQPTuner, self).__init__(*args, **kwargs)

        # Remove input and ruled-out (boundary) nodes from their consumers' in-edge lists.
        input_names = self._input_shapes.keys()
        for node_idx in self._out_nodes_dict:
            node = self._node_list[node_idx]
            if is_boundary_node(node, input_names):
                for out_node_idx in self._out_nodes_dict[node_idx]:
                    self._in_nodes_dict[out_node_idx].remove(node_idx)

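        # Adjacency list per node: the union of its in-edges and out-edges,
        # so edges can be walked from either endpoint during reduction.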
        self._adj_dict = {}
        for node_idx in self._in_nodes_dict:
            self._adj_dict[node_idx] = list(self._in_nodes_dict[node_idx]) + list(
                self._out_nodes_dict[node_idx]
            )

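        # Cost vector per node: the execution cost of each candidate record.
        # The forward reductions fold edge (layout transformation) costs into
        # these vectors.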
        self._record_cost_dict = {}
        for key in self._in_nodes_dict:
            self._record_cost_dict[key] = []
            for record in self._node_list[key]["record_candidates"]:
                self._record_cost_dict[key].append(record[1].costs[0])

        self._max_degree = -1
        self._node_degree_dict = {}
        for node_idx in self._in_nodes_dict:
            node_degree = self._get_degree(node_idx)
            self._node_degree_dict[node_idx] = node_degree
            self._max_degree = max(self._max_degree, node_degree)

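        # Buckets group the remaining nodes by current degree; the stack records
        # the order in which nodes are reduced and is replayed in reverse by the
        # backward pass.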
        self._stack = []
        self._buckets = [[] for _ in range(self._max_degree + 2)]
        for node_idx in sorted(self._in_nodes_dict):
            node_degree = self._get_degree(node_idx)
            self._buckets[node_degree].append(node_idx)

        self._is_optimal = True

    def _get_degree(self, node_idx):
        """Get node degree."""
        return len(self._adj_dict[node_idx])

    def _reorder_adj_nodes(self, node_idx):
        """Move the neighbors of node_idx into the buckets matching their updated degrees."""
        for adj_node in self._adj_dict[node_idx]:
            current_degree = self._get_degree(adj_node)
            prev_degree = self._node_degree_dict[adj_node]
            if prev_degree != current_degree:
                self._buckets[prev_degree].remove(adj_node)
                self._buckets[current_degree].insert(0, adj_node)
                self._node_degree_dict[adj_node] = current_degree

    def _remove_node(self, node_idx):
        """Remove node from graph. Update adjacency list accordingly."""
        node_degree = self._get_degree(node_idx)
        self._buckets[node_degree].remove(node_idx)
        for adj_node in self._adj_dict[node_idx]:
            self._adj_dict[adj_node].remove(node_idx)

    def _insert_edge(self, node_x, node_y, adj_cost_matrix):
        """Insert an edge between two nodes."""
        self._layout_transform_interlayer_cost[(node_x, node_y)] = adj_cost_matrix
        self._layout_transform_interlayer_cost[(node_y, node_x)] = []
        for i in range(len(adj_cost_matrix[0])):
            self._layout_transform_interlayer_cost[(node_y, node_x)].append([])
            for cost_vec in adj_cost_matrix:
                self._layout_transform_interlayer_cost[(node_y, node_x)][i].append(cost_vec[i])

        self._adj_dict[node_x].append(node_y)
        self._adj_dict[node_y].append(node_x)

    def _backward_insert_node(self, node_idx):
        """Reinsert node in backward pass."""
        for adj_node in self._adj_dict[node_idx]:
            self._adj_dict[adj_node].append(node_idx)

    def _RI_reduction(self, node_idx):
        """Reduce nodes with degree 1."""
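        # RI rule: fold the degree-1 node's costs into its single neighbor. For each
        # candidate i of the neighbor, add min over j of
        # (transition_cost[i][j] + node_cost[j]).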
        adj_node = self._adj_dict[node_idx][0]
        ltf_matrix = self._layout_transform_interlayer_cost[(adj_node, node_idx)]
        for i, cost_vec in enumerate(ltf_matrix):
            min_cost = INVALID_LAYOUT_TIME
            for j, cost in enumerate(cost_vec):
                min_cost = min(min_cost, cost + self._record_cost_dict[node_idx][j])
            self._record_cost_dict[adj_node][i] += min_cost
        self._remove_node(node_idx)
        self._reorder_adj_nodes(node_idx)
        self._stack.append(node_idx)

    def _RII_reduction(self, node_idx):
        """Reduce nodes with degree 2."""
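        # RII rule: fold the degree-2 node into an edge between its two neighbors,
        # building delta_matrix[i][j] = min over k of
        # (cost_x[i][k] + cost_y[j][k] + node_cost[k]), then merge the delta into an
        # existing edge or insert a new one.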
        adj_node_x, adj_node_y = self._adj_dict[node_idx]
        ltf_matrix_x = self._layout_transform_interlayer_cost[(adj_node_x, node_idx)]
        ltf_matrix_y = self._layout_transform_interlayer_cost[(adj_node_y, node_idx)]
        delta_matrix = [[] for _ in range(len(ltf_matrix_x))]
        for i, cost_vec_x in enumerate(ltf_matrix_x):
            for j, cost_vec_y in enumerate(ltf_matrix_y):
                min_cost = INVALID_LAYOUT_TIME
                for k in range(len(self._record_cost_dict[node_idx])):
                    min_cost = min(
                        min_cost,
                        cost_vec_x[k] + cost_vec_y[k] + self._record_cost_dict[node_idx][k],
                    )
                delta_matrix[i].append(min_cost)

        if adj_node_x == adj_node_y:
            for i, delta_row in enumerate(delta_matrix):
                self._record_cost_dict[adj_node_x][i] += delta_row[i]
        elif adj_node_x in self._adj_dict[adj_node_y]:
            for i, _ in enumerate(delta_matrix):
                for j, delta in enumerate(delta_matrix[i]):
                    self._layout_transform_interlayer_cost[(adj_node_x, adj_node_y)][i][j] += delta
                    self._layout_transform_interlayer_cost[(adj_node_y, adj_node_x)][j][i] += delta
        else:
            self._insert_edge(adj_node_x, adj_node_y, delta_matrix)

        self._remove_node(node_idx)
        self._reorder_adj_nodes(node_idx)
        self._stack.append(node_idx)

    def _RN_reduction(self, node_idx):
        """Reduce nodes with degree greater than 2."""
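        # RN rule (heuristic): no exact local reduction exists for degree >= 3, so
        # greedily pick the record that looks cheapest given the current neighbor
        # cost vectors. This step is what can make the final solution sub-optimal.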
        min_cost = INVALID_LAYOUT_TIME
        record_idx = -1

        for i, record_cost in enumerate(self._record_cost_dict[node_idx]):
            current_cost = record_cost
            for adj_node in self._adj_dict[node_idx]:
                ltf_matrix = self._layout_transform_interlayer_cost[(node_idx, adj_node)]
                adj_record_cost = list(self._record_cost_dict[adj_node])
                for j, ltf_cost in enumerate(ltf_matrix[i]):
                    adj_record_cost[j] += ltf_cost
                current_cost += min(adj_record_cost)
            if current_cost < min_cost:
                min_cost = current_cost
                record_idx = i

        if record_idx < 0:
            raise RuntimeError(
                "Can't find a solution for node %d when applying RN reduction" % node_idx
            )
        self._optimal_record_dict[node_idx] = record_idx
        self._is_optimal = False

        for adj_node in self._adj_dict[node_idx]:
            ltf_matrix = self._layout_transform_interlayer_cost[(node_idx, adj_node)]
            for i, ltf_cost in enumerate(ltf_matrix[record_idx]):
                self._record_cost_dict[adj_node][i] += ltf_cost

        self._remove_node(node_idx)
        self._reorder_adj_nodes(node_idx)
        self._stack.append(node_idx)

    def _forward(self):
        """Forward pass in PBQP to reduce nodes."""
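        # Reduction priority: RI (degree 1), then RII (degree 2), then RN on the
        # highest-degree node. The loop terminates when only degree-0 nodes remain.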
        while True:
            if self._buckets[1]:
                node_idx = self._buckets[1][0]
                self._RI_reduction(node_idx)
            elif self._max_degree >= 2 and self._buckets[2]:
                node_idx = self._buckets[2][0]
                self._RII_reduction(node_idx)
            elif self._max_degree >= 3:
                max_degree_node = -1
                for i in range(self._max_degree, 2, -1):
                    if self._buckets[i]:
                        max_degree_node = self._buckets[i][0]
                        self._RN_reduction(max_degree_node)
                        break
                if max_degree_node < 0:
                    break
            else:
                break

    def _backward(self):
        """Backward pass in PBQP to generate optimal solution."""
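        # Nodes are reinserted in the reverse of their reduction order, so every
        # neighbor that outlived a node in the forward pass already has a record
        # chosen when that node picks its own.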
        # Solve nodes left in the forward graph
        for node_idx in self._buckets[0]:
            record_costs = self._record_cost_dict[node_idx]
            min_cost = min(record_costs)
            self._optimal_record_dict[node_idx] = record_costs.index(min_cost)

        # Solve RI/RII-reduced nodes; RN-reduced nodes already have a record chosen
        for node_idx in reversed(self._stack):
            self._backward_insert_node(node_idx)
            if node_idx not in self._optimal_record_dict:
                record_costs = list(self._record_cost_dict[node_idx])
                for adj_node in self._adj_dict[node_idx]:
                    adj_optimal_idx = self._optimal_record_dict[adj_node]
                    for i, _ in enumerate(record_costs):
                        record_costs[i] += self._layout_transform_interlayer_cost[
                            (node_idx, adj_node)
                        ][i][adj_optimal_idx]
                min_cost = min(record_costs)
                self._optimal_record_dict[node_idx] = record_costs.index(min_cost)

    def run(self, **kwargs):
        """Run partitioned boolean quadratic programming tuner."""
        self._logger.info("Start to run PBQP algorithm...")
        # Define virtual record lists and layout transformation matrices
        # for multi-input nodes.
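        # A multi-input node (e.g. an elementwise op) must see all of its inputs in
        # the same layout, so the first non-boundary input is tied to the node with
        # an identity-like matrix (0 on the diagonal, INVALID_LAYOUT_TIME elsewhere),
        # and the remaining inputs reuse their transformation costs to that input.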
        input_names = self._input_shapes.keys()
        temp = {}
        for key, val in self._in_nodes_dict.items():
            target_input_idx = -1
            target_input_pos = -1
            if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op):
                for i, item in enumerate(val):
                    node = self._node_list[item]
                    if not is_boundary_node(node, input_names):
                        target_input_idx = item
                        target_input_pos = i
                        break

                # Skip if no non-boundary input is found
                if target_input_idx < 0:
                    continue

                temp[(target_input_idx, key)] = []
                record_candidates = self._node_list[target_input_idx]["record_candidates"]
                for j in range(len(record_candidates)):
                    temp[(target_input_idx, key)].append([])
                    for k in range(len(record_candidates)):
                        temp[(target_input_idx, key)][j].append(
                            0 if j == k else INVALID_LAYOUT_TIME
                        )

                for j in range(target_input_pos + 1, len(val)):
                    input_idx = val[j]
                    input_node = self._node_list[input_idx]
                    if is_boundary_node(input_node, input_names):
                        continue
                    temp[(input_idx, key)] = self._layout_transform_interlayer_cost[
                        (input_idx, target_input_idx)
                    ]
        self._layout_transform_interlayer_cost.update(temp)

        # Create reverse layout transformation matrices
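        # The reverse matrix is the transpose of the forward one, so transition
        # costs can be looked up from either endpoint of an edge.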
        temp = {}
        for idx_pair, ltf_matrix in self._layout_transform_interlayer_cost.items():
            reverse_key = (idx_pair[1], idx_pair[0])
            reverse_matrix = [[] for _ in range(len(ltf_matrix[0]))]
            for i, _ in enumerate(ltf_matrix):
                for j, ltf in enumerate(ltf_matrix[i]):
                    reverse_matrix[j].append(ltf)
            temp[reverse_key] = reverse_matrix
        self._layout_transform_interlayer_cost.update(temp)

        self._forward()
        self._backward()
        is_optimal = "optimal" if self._is_optimal else "sub-optimal"
        msg = "Finished PBQP tuner run. Got %s solution." % is_optimal
        self._logger.info(msg)