# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-error,too-many-locals,too-many-statements,too-many-branches,unused-variable
"""Dynamic programming tuner."""
import sys
import numpy as np

from ._base import MAX_OUTPUT_NODES
from .base_graph_tuner import BaseGraphTuner
from .dynamic_programming_stage import DPStage
from .utils import has_multiple_inputs, is_boundary_node

if sys.version_info[0] == 3:
    import queue
else:
    import Queue as queue


class DPTuner(BaseGraphTuner):
    """Tuner which uses dynamic programming to solve MDP problem.

    Note: currently dynamic programming is used to solve this MDP problem. However,
    this problem is intrinsically non-polynomial. DP can't apply for more complicated
    models, such as networks with many element-wise sum operators. In this case, switch
    to heuristic algorithm such as PBQP tuner.
    """
    def __init__(self, *args, **kwargs):
        """Create a dynamic programming tuner.

        All positional and keyword arguments are forwarded unchanged to
        ``BaseGraphTuner.__init__``.
        """
        super(DPTuner, self).__init__(*args, **kwargs)
        # Running state count and its optional upper limit; both are
        # (re)initialized by ``run`` before the forward pass starts.
        self._num_states = self._max_num_states = None
        # node index -> DPStage, filled by the forward pass.
        self._stage_dict = {}
        self._dep_dict = {}
        self._counted_nodes_set = set()

        # Keyword arguments handed to every DPStage. The mutable containers
        # are shared by reference with this tuner (not copied).
        self._global_data_dict = {
            "dtype": self._dtype,
            "counted_nodes_set": self._counted_nodes_set,
            "stage_dict": self._stage_dict,
            "in_nodes_dict": self._in_nodes_dict,
            "out_nodes_dict": self._out_nodes_dict,
            "dep_dict": self._dep_dict,
            "node_list": self._node_list,
            "input_shapes": self._input_shapes,
            "layout_transform_interlayer_cost": self._layout_transform_interlayer_cost
        }

    def _check_num_states(self, num_states):
        """Add ``num_states`` to the running state count.

        Raises
        ------
        RuntimeError
            If an upper limit was passed to ``run`` (``max_num_states``) and
            the running count now exceeds it.
        """
        self._num_states += num_states
        if self._max_num_states is not None:
            if self._num_states > self._max_num_states:
                raise RuntimeError("Too many states detected while running dynamic "
                                   "programming: got %d states but upper limit is %d." %
                                   (self._num_states, self._max_num_states))

    def _forward(self):
        """Forward pass in DP to generate states for all stages.

        Builds one DPStage per node of ``self._in_nodes_dict`` in ascending
        node-index order and records it in ``self._stage_dict``.
        """
        self._logger.info("Start forward pass...")
        for node_idx in sorted(self._in_nodes_dict):
            stage = DPStage(idx=node_idx, target_ops=self._target_ops,
                            **self._global_data_dict)
            self._check_num_states(stage.full_states.size)
            self._stage_dict[node_idx] = stage
        self._logger.info("Finished forward pass.")

    def _backward(self):
        """Backward pass in DP to generate optimal solution.

        First choose the combination of schedules for the output nodes (nodes
        with no consumers) that minimizes their summed cost. Then walk the
        graph breadth-first from the outputs towards the inputs, fixing each
        producer's schedule from the schedules already chosen downstream.
        The result is stored in ``self._optimal_record_dict``
        (node index -> schedule index).
        """
        self._logger.info("Start backward pass...")
        input_names = self._input_shapes.keys()
        optimal_record_dict = {}
        # Pick optimal schedule for output nodes
        output_idx_list = [key for key, val in self._out_nodes_dict.items() if not val]

        # Restrict number of output nodes to avoid numpy reshape error
        if len(output_idx_list) > MAX_OUTPUT_NODES:
            msg = ("The number of outputs in graph is larger than upper "
                   "limit: %s vs %s. Usually this is caused by too many "
                   "LAYOUT_FIXED_OP in graph. Switch to greedily select schedule. "
                   "No action required at this moment. We will continuously "
                   "improve graph tuner." % (len(output_idx_list), MAX_OUTPUT_NODES))
            self._logger.warning(msg)
            # Fall back to schedule index 0 for every node.
            self._optimal_record_dict = {key: 0 for key in self._in_nodes_dict}
            return

        states_list, aligned_node_list = DPStage.align_states(output_idx_list, self._stage_dict,
                                                              self._node_list)
        num_states = states_list[0][3].size
        self._check_num_states(num_states * len(output_idx_list))
        aligned_node_shape = states_list[0][3].shape
        # Cost of every aligned state combination summed over all output nodes.
        # Builtin sum adds elementwise in the same order as a scalar loop
        # (0 + s0 + s1 + ...). Assumes every states[3] has the aligned shape.
        total_time = sum(states[3].ravel() for states in states_list)
        # argmin returns the first minimal position. This replaces a Python loop
        # whose -1 sentinel survived when no combination beat the sum of maxima.
        min_pos = int(np.argmin(total_time))
        for node_idx, states in zip(aligned_node_list, states_list):
            # states[1] is used as this output's axis in aligned_node_shape and
            # states[2] as that axis' divisor in the flat position -- presumably
            # its row-major stride; TODO confirm against DPStage.align_states.
            current_major_axis = states[1]
            current_sch_idx = (min_pos % (states[2] *
                                          aligned_node_shape[current_major_axis])) // states[2]
            optimal_record_dict[node_idx] = current_sch_idx
        # Pick optimal schedule for dependencies of output nodes
        for i in range(len(states_list), len(aligned_node_list)):
            multiplier = 1
            for j in range(i + 1, len(aligned_node_list)):
                multiplier *= aligned_node_shape[j]
            optimal_record_dict[aligned_node_list[i]] = \
                min_pos // multiplier % aligned_node_shape[i]

        # Backward pass to get optimal schedules for other nodes
        bfs_q = queue.Queue()
        visited = set()
        for out_idx in output_idx_list:
            bfs_q.put(out_idx)
        while not bfs_q.empty():
            node_idx = bfs_q.get()
            visited.add(node_idx)
            node = self._node_list[node_idx]
            if is_boundary_node(node, input_names):
                continue
            optimal_sch_idx = optimal_record_dict[node_idx]
            full_states = self._stage_dict[node_idx].full_states
            if not has_multiple_inputs(self._node_list, node_idx, input_names):
                # Single producer: choose the producer schedule that minimizes
                # this stage's cost given the already-fixed schedules of this
                # node (axis 0) and of its dependencies (axes 2+).
                input_idx = self._in_nodes_dict[node_idx][0]
                input_node = self._node_list[input_idx]
                if is_boundary_node(input_node, input_names):
                    continue
                if input_idx not in visited:
                    bfs_q.put(input_idx)
                if input_idx not in optimal_record_dict:
                    dep_list = self._stage_dict[node_idx].dep
                    dep_idx = tuple(optimal_record_dict[item] for item in dep_list)
                    # Same result as np.argmin(full_states, axis=1)[(sch,) + dep_idx],
                    # but reduces only the one slice that is actually read.
                    input_states = full_states[(optimal_sch_idx, slice(None)) + dep_idx]
                    optimal_record_dict[input_idx] = np.argmin(input_states)
            else:
                # Multiple producers: the first input takes this node's
                # schedule; all still-undecided state axes are then chosen
                # jointly by minimizing over this node's cost slice.
                input_idx_list = self._in_nodes_dict[node_idx]
                optimal_record_dict[input_idx_list[0]] = optimal_sch_idx
                full_states_idx = self._stage_dict[node_idx].full_states_idx
                sub_states = full_states[optimal_sch_idx]
                new_states_idx, new_states_pos = [], []
                visited_states_idx, visited_states_pos = [], []
                # full_states axis k (k >= 1) belongs to full_states_idx[k] and
                # is axis k - 1 of sub_states.
                for i in range(1, len(full_states_idx)):
                    if full_states_idx[i] in optimal_record_dict:
                        visited_states_idx.append(full_states_idx[i])
                        visited_states_pos.append(i - 1)
                    else:
                        new_states_idx.append(full_states_idx[i])
                        new_states_pos.append(i - 1)
                if visited_states_idx:
                    # Move decided axes to the front and fix them to their schedules.
                    sub_states = np.transpose(sub_states,
                                              tuple(visited_states_pos + new_states_pos))
                    fixed = tuple(optimal_record_dict[idx] for idx in visited_states_idx)
                    sub_states = sub_states[fixed]
                min_pos = np.argmin(sub_states)
                # Decode the flat (row-major) argmin into one schedule per
                # undecided axis.
                multiplier = 1
                for pos in new_states_pos:
                    multiplier *= full_states.shape[pos + 1]
                for pos, idx in zip(new_states_pos, new_states_idx):
                    multiplier //= full_states.shape[pos + 1]
                    optimal_record_dict[idx] = min_pos // multiplier
                    min_pos %= multiplier
                for input_idx in input_idx_list:
                    if input_idx not in visited:
                        bfs_q.put(input_idx)

        self._optimal_record_dict = optimal_record_dict
        self._logger.info("Finished backward pass...")

    def run(self, **kwargs):
        """Run dynamic programming solver.

        Parameters
        ----------
        max_num_states : int, optional
            Upper limit on the total number of DP states. Once exceeded,
            ``RuntimeError`` is raised. Omitted or ``None`` means no limit.
        """
        self._num_states = 0
        self._max_num_states = kwargs.get("max_num_states")
        self._logger.info("Start to run dynamic programming algorithm...")
        self._forward()
        self._backward()
        self._logger.info("Finished DPExecutor run.")