1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=unused-variable,invalid-name
18"""
19Decorator and utilities for the integration with TOPI and NNVM
20
21"""
22import threading
23import warnings
24import logging
25
26
27from .task import create
28from .topi_integration import TaskExtractEnv
29
# Shared 'autotvm' logger. The extraction functions below set its `disabled`
# flag while they run the nnvm build and then put the previous value back.
logger = logging.getLogger('autotvm')
31
32
def extract_from_graph(graph, shape, dtype, target, symbols, params=None, target_host=None):
    """Extract tuning tasks from a nnvm graph.

    This function collects tuning tasks by building the graph
    and tracing all the calls to topi.

    Parameters
    ----------
    graph : Graph
        The graph to tune
    shape : dict of str to tuple
        The input shape to the graph
    dtype : str or dict of str to str
        The input types to the graph
    target: tvm.target.Target
        The compilation target
    symbols : Array of nnvm.symbol
        Array of nnvm symbols to be tuned. Symbols that have no entry in
        the symbol -> topi mapping below trigger a warning and are ignored.
    params : dict of str to NDArray
        The parameter dictionary.
    target_host: tvm.target.Target
        The host compilation target

    Returns
    -------
    task: Array of autotvm.task.Task
        collected tasks. A collected call whose task creation raises
        ``topi.InvalidShapeError`` is skipped after printing a warning.

    Notes
    -----
    The build runs in a worker thread, so an exception raised inside
    ``nnvm.compiler.build`` is not re-raised in the calling thread.
    """
    import nnvm.compiler
    import nnvm
    import topi

    env = TaskExtractEnv.get()

    # NOTE: To add more symbols, you only need to change the following lists
    # nnvm symbol -> topi compute
    SYMBOL2TOPI = {
        nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
                          topi.nn.group_conv2d_nchw],
        nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
        nnvm.sym.dense: [topi.nn.dense],
    }

    topi_funcs = []
    for sym_name in symbols:
        if sym_name in SYMBOL2TOPI:
            topi_funcs.extend(SYMBOL2TOPI[sym_name])
        else:
            warnings.warn("Symbol %s is not tunable, ignored" % sym_name)

    # run compiler to collect all TOPI calls during compilation
    env.reset(topi_funcs)
    with env:
        # disable logger temporarily
        old_state = logger.disabled
        logger.disabled = True
        try:
            nnvm.compiler.engine.clear_cache()
            # wrap build call in thread to avoid multiprocessing problems
            build_thread = threading.Thread(target=nnvm.compiler.build,
                                            args=(graph,
                                                  target,
                                                  shape,
                                                  dtype,
                                                  params,
                                                  target_host))
            build_thread.start()
            build_thread.join()
        finally:
            # always restore the logger, even if clearing the cache or
            # starting the thread fails; otherwise it would stay disabled
            logger.disabled = old_state

    # create tasks for target
    tasks = []
    for task_name, args in env.get_tasks():
        try:
            tsk = create(task_name, args,
                         target=target, target_host=target_host,
                         template_key='direct')
            tasks.append(tsk)
        except topi.InvalidShapeError:
            print("[Warning] Invalid shape during AutoTVM task creation")

    return tasks
116
117
def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, params, target_host=None):
    """Extract tuning tasks from multiple nnvm graphs.

    This function is the multiple graph version of extract_from_graph.
    Every graph is built with the same target, params and target_host,
    and the topi calls traced across all builds are collected together.

    Parameters
    ----------
    graphs : List of Graph
        The list of graphs to tune
    shapes : List of dict of str to tuple
        The input shapes, one entry per graph
    dtypes : List of str or dict of str to str
        The input types, one entry per graph
    target: tvm.target.Target
        The compilation target
    symbols : Array of nnvm.symbol
        Array of nnvm symbols to be tuned. Symbols that have no entry in
        the symbol -> topi mapping below trigger a warning and are ignored.
    params : dict of str to NDArray
        The parameter dictionary.
    target_host: tvm.target.Target
        The host compilation target

    Returns
    -------
    task: Array of autotvm.task.Task
        collected tasks. A collected call whose task creation raises
        ``topi.InvalidShapeError`` is skipped after printing a warning.

    Notes
    -----
    ``graphs``, ``shapes`` and ``dtypes`` are iterated together with
    ``zip``, so iteration stops at the shortest of the three.
    Each build runs in a worker thread, so an exception raised inside
    ``nnvm.compiler.build`` is not re-raised in the calling thread.
    """
    import nnvm.compiler
    import nnvm
    import topi

    env = TaskExtractEnv.get()

    # NOTE: To add more symbols, you only need to change the following lists
    # nnvm symbol -> topi compute
    SYMBOL2TOPI = {
        nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
                          topi.nn.group_conv2d_nchw],
        nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
        nnvm.sym.dense: [topi.nn.dense],
    }

    topi_funcs = []
    for sym_name in symbols:
        if sym_name in SYMBOL2TOPI:
            topi_funcs.extend(SYMBOL2TOPI[sym_name])
        else:
            warnings.warn("Symbol %s is not tunable, ignored" % sym_name)

    # run compiler to collect all TOPI calls during compilation
    env.reset(topi_funcs)
    with env:
        # disable logger temporarily
        old_state = logger.disabled
        logger.disabled = True
        try:
            for graph, shape, dtype in zip(graphs, shapes, dtypes):
                nnvm.compiler.engine.clear_cache()
                # wrap build call in thread to avoid multiprocessing problems
                build_thread = threading.Thread(target=nnvm.compiler.build,
                                                args=(graph,
                                                      target,
                                                      shape,
                                                      dtype,
                                                      params,
                                                      target_host))
                build_thread.start()
                build_thread.join()
        finally:
            # always restore the logger, even if one of the iterations
            # fails; otherwise it would stay disabled
            logger.disabled = old_state

    # create tasks for target
    tasks = []
    for task_name, args in env.get_tasks():
        try:
            tsk = create(task_name, args,
                         target=target, target_host=target_host,
                         template_key='direct')
            tasks.append(tsk)
        except topi.InvalidShapeError:
            print("[Warning] Invalid shape during AutoTVM task creation")

    return tasks
201