1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=import-error,too-many-locals,too-many-statements,too-many-branches,unused-variable
18"""Dynamic programming tuner."""
19import sys
20import numpy as np
21
22from ._base import MAX_OUTPUT_NODES
23from .base_graph_tuner import BaseGraphTuner
24from .dynamic_programming_stage import DPStage
25from .utils import has_multiple_inputs, is_boundary_node
26
27if sys.version_info[0] == 3:
28    import queue
29else:
30    import Queue as queue
31
class DPTuner(BaseGraphTuner):
    """Tuner which uses dynamic programming to solve MDP problem.

    Note: currently dynamic programming is used to solve this MDP problem. However,
    this problem is intrinsically non-polynomial, so DP does not scale to more
    complicated models, such as networks with many element-wise sum operators. In
    that case, switch to a heuristic algorithm such as the PBQP tuner.
    """
    def __init__(self, *args, **kwargs):
        """Create a dynamic programming tuner.

        All arguments are forwarded unchanged to ``BaseGraphTuner.__init__``.
        """
        super(DPTuner, self).__init__(*args, **kwargs)
        # Running number of DP states and its optional upper bound; both are
        # (re)initialized by ``run`` before the forward pass starts.
        self._num_states = self._max_num_states = None
        # Node index -> DPStage, filled by ``_forward``.
        self._stage_dict = {}
        self._dep_dict = {}
        self._counted_nodes_set = set()

        # Keyword arguments handed to every DPStage. The containers are passed
        # by reference, so all stages share the same dict/set objects.
        self._global_data_dict = {
            "dtype": self._dtype,
            "counted_nodes_set": self._counted_nodes_set,
            "stage_dict": self._stage_dict,
            "in_nodes_dict": self._in_nodes_dict,
            "out_nodes_dict": self._out_nodes_dict,
            "dep_dict": self._dep_dict,
            "node_list": self._node_list,
            "input_shapes": self._input_shapes,
            "layout_transform_interlayer_cost": self._layout_transform_interlayer_cost
        }

    def _check_num_states(self, num_states):
        """Add ``num_states`` to the running state count and enforce the limit.

        Must be called after ``run`` has reset the counter to 0.

        Parameters
        ----------
        num_states : int
            Number of newly created states.

        Raises
        ------
        RuntimeError
            If a limit was given via ``run(max_num_states=...)`` and the running
            total exceeds it.
        """
        self._num_states += num_states
        if self._max_num_states is not None and self._num_states > self._max_num_states:
            raise RuntimeError("Too many states detected while running dynamic "
                               "programming: got %d states but upper limit is %d." %
                               (self._num_states, self._max_num_states))

    def _forward(self):
        """Forward pass in DP to generate states for all stages.

        Builds one DPStage per node of ``self._in_nodes_dict``, in ascending node
        index order, and counts each stage's full state size against the limit.
        """
        self._logger.info("Start forward pass...")
        for node_idx in sorted(self._in_nodes_dict):
            stage = DPStage(idx=node_idx, target_ops=self._target_ops,
                            **self._global_data_dict)
            self._check_num_states(stage.full_states.size)
            self._stage_dict[node_idx] = stage
        self._logger.info("Finished forward pass.")

    def _backward(self):
        """Backward pass in DP to generate optimal solution.

        Output nodes (nodes without consumers) are assigned the jointly cheapest
        schedules first. Choices are then propagated breadth-first towards the
        graph inputs. The result (node index -> schedule index) is stored in
        ``self._optimal_record_dict``.

        If the graph has more than ``MAX_OUTPUT_NODES`` outputs, a warning is
        logged and every node in ``self._in_nodes_dict`` gets schedule 0.
        """
        self._logger.info("Start backward pass...")
        input_names = self._input_shapes.keys()
        output_idx_list = [key for key, val in self._out_nodes_dict.items() if not val]

        # Restrict number of output nodes to avoid numpy reshape error
        if len(output_idx_list) > MAX_OUTPUT_NODES:
            msg = ("The number of outputs in graph is larger than upper "
                   "limit: %s vs %s. Usually this is caused by too many "
                   "LAYOUT_FIXED_OP in graph. Switch to greedily select schedule. "
                   "No action required at this moment. We will continuously improve graph tuner"
                   % (len(output_idx_list), MAX_OUTPUT_NODES))
            self._logger.warning(msg)
            self._optimal_record_dict = {key: 0 for key in self._in_nodes_dict}
            return

        optimal_record_dict = self._select_output_schedules(output_idx_list)
        self._propagate_schedules(output_idx_list, input_names, optimal_record_dict)
        self._optimal_record_dict = optimal_record_dict
        self._logger.info("Finished backward pass...")

    def _select_output_schedules(self, output_idx_list):
        """Jointly pick schedules for the output nodes and their dependencies.

        Parameters
        ----------
        output_idx_list : list of int
            Indices of the graph's output nodes.

        Returns
        -------
        optimal_record_dict : dict of int to int
            Schedule index for every node returned by ``DPStage.align_states``.
        """
        states_list, aligned_node_list = DPStage.align_states(output_idx_list, self._stage_dict,
                                                              self._node_list)
        num_states = states_list[0][3].size
        self._check_num_states(num_states * len(output_idx_list))
        aligned_node_shape = states_list[0][3].shape

        # Total cost of every joint assignment: element-wise sum of the aligned
        # per-output state arrays, flattened in row-major order. np.argmin picks
        # the first minimum, so min_pos is always a valid flat index.
        total_time = sum(states[3].flatten() for states in states_list)
        min_pos = int(np.argmin(total_time))

        optimal_record_dict = {}
        for i, states in enumerate(states_list):
            major_axis, multiplier = states[1], states[2]
            optimal_record_dict[aligned_node_list[i]] = \
                (min_pos % (multiplier * aligned_node_shape[major_axis])) // multiplier
        # Pick optimal schedule for dependencies of output nodes: they follow the
        # outputs in aligned_node_list, so decode their row-major coordinate.
        for i in range(len(states_list), len(aligned_node_list)):
            multiplier = 1
            for j in range(i + 1, len(aligned_node_list)):
                multiplier *= aligned_node_shape[j]
            optimal_record_dict[aligned_node_list[i]] = \
                min_pos // multiplier % aligned_node_shape[i]
        return optimal_record_dict

    def _propagate_schedules(self, output_idx_list, input_names, optimal_record_dict):
        """Propagate schedule choices from the outputs back towards the inputs.

        Breadth-first traversal starting at ``output_idx_list``. For each
        non-boundary node, producers without a recorded schedule get the one that
        minimizes the node's full states given the choices already made. For a
        multi-input node, its first input is always set to the node's own schedule.

        Parameters
        ----------
        output_idx_list : list of int
            BFS start nodes.
        input_names : iterable of str
            Graph input names, used to detect boundary nodes.
        optimal_record_dict : dict of int to int
            Node index -> schedule index; updated in place.
        """
        bfs_q = queue.Queue()
        visited = set()
        for out_idx in output_idx_list:
            bfs_q.put(out_idx)
        while not bfs_q.empty():
            node_idx = bfs_q.get()
            visited.add(node_idx)
            node = self._node_list[node_idx]
            if is_boundary_node(node, input_names):
                continue
            optimal_sch_idx = optimal_record_dict[node_idx]
            full_states = self._stage_dict[node_idx].full_states
            if not has_multiple_inputs(self._node_list, node_idx, input_names):
                input_idx = self._in_nodes_dict[node_idx][0]
                input_node = self._node_list[input_idx]
                if is_boundary_node(input_node, input_names):
                    continue
                if input_idx not in visited:
                    bfs_q.put(input_idx)
                    if input_idx not in optimal_record_dict:
                        dep_list = self._stage_dict[node_idx].dep
                        dep_idx = tuple(optimal_record_dict[item] for item in dep_list)
                        # Reduce over axis 1 (the input's schedules, judging by the
                        # indexing below), then fix this node's and deps' choices.
                        best_input = np.argmin(full_states, axis=1)
                        optimal_record_dict[input_idx] = best_input[(optimal_sch_idx,) + dep_idx]
            else:
                input_idx_list = self._in_nodes_dict[node_idx]
                optimal_record_dict[input_idx_list[0]] = optimal_sch_idx
                full_states_idx = self._stage_dict[node_idx].full_states_idx
                # Fixing axis 0 (this node) shifts axis k of full_states to k - 1.
                tmp = full_states[optimal_sch_idx]
                new_states_idx, new_states_pos = [], []
                visited_states_idx, visited_states_pos = [], []
                for pos, state_idx in enumerate(full_states_idx[1:]):
                    if state_idx in optimal_record_dict:
                        visited_states_idx.append(state_idx)
                        visited_states_pos.append(pos)
                    else:
                        new_states_idx.append(state_idx)
                        new_states_pos.append(pos)
                if visited_states_idx:
                    # Move the already-decided axes to the front and fix them.
                    tmp = np.transpose(tmp, tuple(visited_states_pos + new_states_pos))
                    tmp = tmp[tuple(optimal_record_dict[idx] for idx in visited_states_idx)]
                # Remaining axes are the undecided nodes, in new_states_pos order.
                best_coords = np.unravel_index(np.argmin(tmp), np.shape(tmp))
                for idx, coord in zip(new_states_idx, best_coords):
                    optimal_record_dict[idx] = coord
                for input_idx in input_idx_list:
                    if input_idx not in visited:
                        bfs_q.put(input_idx)

    def run(self, **kwargs):
        """Run dynamic programming solver.

        Parameters
        ----------
        max_num_states : int, optional
            Upper bound on the total number of DP states. When exceeded,
            ``RuntimeError`` is raised. No limit when omitted or None.
        """
        self._num_states = 0
        self._max_num_states = kwargs.get("max_num_states")
        self._logger.info("Start to run dynamic programming algorithm...")
        self._forward()
        self._backward()
        self._logger.info("Finished DPExecutor run.")
205