# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-variable,invalid-name
"""
Decorator and utilities for the integration with TOPI and NNVM

"""
import threading
import warnings
import logging


from .task import create
from .topi_integration import TaskExtractEnv

logger = logging.getLogger('autotvm')


def _extract_tasks(graphs, shapes, dtypes, target, symbols, params, target_host):
    """Shared implementation behind the public ``extract_from_*`` functions.

    Builds every graph while the task extraction environment is active and
    turns the recorded (task name, args) pairs into autotvm tasks.

    Parameters
    ----------
    graphs : List of Graph
        The graphs to build.
    shapes : List of dict of str to tuple
        The input shapes, one entry per graph.
    dtypes : List of str or dict of str to str
        The input types, one entry per graph.
    target: tvm.target.Target
        The compilation target
    symbols : Array of nnvm.symbol
        Array of nnvm symbols to be tuned
    params : dict of str to NDArray
        The parameter dictionary, passed unchanged to every build.
    target_host: tvm.target.Target
        The host compilation target

    Returns
    -------
    task: Array of autotvm.task.Task
        collected tasks
    """
    # Imported lazily, as in the public entry points, so that importing this
    # module does not require nnvm/topi to be importable.
    import nnvm.compiler
    import nnvm
    import topi

    env = TaskExtractEnv.get()

    # NOTE: To add more symbols, you only need to change the following lists
    # nnvm symbol -> topi compute
    SYMBOL2TOPI = {
        nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
                          topi.nn.group_conv2d_nchw],
        nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
        nnvm.sym.dense: [topi.nn.dense],
    }

    topi_funcs = []
    for sym_name in symbols:
        if sym_name in SYMBOL2TOPI:
            topi_funcs.extend(SYMBOL2TOPI[sym_name])
        else:
            warnings.warn("Symbol %s is not tunable, ignored" % sym_name)

    # run compiler to collect all TOPI calls during compilation
    env.reset(topi_funcs)
    with env:
        # disable logger temporarily
        old_state = logger.disabled
        logger.disabled = True
        try:
            # graphs/shapes/dtypes are consumed in lockstep; zip stops at the
            # shortest of the three sequences.
            for graph, shape, dtype in zip(graphs, shapes, dtypes):
                nnvm.compiler.engine.clear_cache()
                # wrap build call in thread to avoid multiprocessing problems
                # NOTE: an exception raised inside the build thread is not
                # re-raised here; the thread simply terminates.
                build_thread = threading.Thread(target=nnvm.compiler.build,
                                                args=(graph,
                                                      target,
                                                      shape,
                                                      dtype,
                                                      params,
                                                      target_host))
                build_thread.start()
                build_thread.join()
        finally:
            # Always restore the logger, even if tracing fails part-way,
            # so the 'autotvm' logger is never left permanently muted.
            logger.disabled = old_state

    # create tasks for target
    tasks = []
    for task_name, args in env.get_tasks():
        try:
            tsk = create(task_name, args,
                         target=target, target_host=target_host,
                         template_key='direct')
        except topi.InvalidShapeError:
            print("[Warning] Invalid shape during AutoTVM task creation")
        else:
            tasks.append(tsk)

    return tasks


def extract_from_graph(graph, shape, dtype, target, symbols, params=None, target_host=None):
    """ Extract tuning tasks from a nnvm graph.

    This function collects tuning tasks by building the graph
    and tracing all the calls to topi.

    Parameters
    ----------
    graph : Graph
        The graph to tune
    shape : dict of str to tuple
        The input shape to the graph
    dtype : str or dict of str to str
        The input types to the graph
    target: tvm.target.Target
        The compilation target
    symbols : Array of nnvm.symbol
        Array of nnvm symbols to be tuned
    params : dict of str to NDArray
        The parameter dictionary.
    target_host: tvm.target.Target
        The host compilation target

    Returns
    -------
    task: Array of autotvm.task.Task
        collected tasks
    """
    # A single graph is the one-element case of the multi-graph extraction.
    return _extract_tasks([graph], [shape], [dtype],
                          target, symbols, params, target_host)


def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, params, target_host=None):
    """ Extract tuning tasks from multiple nnvm graphs.

    This function is the multiple graph version of extract_from_graph

    Parameters
    ----------
    graphs : List of Graph
        The list of graphs to tune
    shapes : List of dict of str to tuple
        The input shapes to the graphs, one entry per graph
    dtypes : List of str or dict of str to str
        The input types to the graphs, one entry per graph
    target: tvm.target.Target
        The compilation target
    symbols : Array of nnvm.symbol
        Array of nnvm symbols to be tuned
    params : dict of str to NDArray
        The parameter dictionary. The same dictionary is used for every graph.
    target_host: tvm.target.Target
        The host compilation target

    Returns
    -------
    task: Array of autotvm.task.Task
        collected tasks
    """
    return _extract_tasks(graphs, shapes, dtypes,
                          target, symbols, params, target_host)