# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Arm Compute Library supported operators."""
import tvm
from tvm.relay.expr import const
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name

from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr
from .register import register_pattern_table


def is_arm_compute_runtime_enabled():
    """Check if the ACL graph runtime is present.

    Returns
    -------
    ret: bool
        True if present, False if not.
    """
    # The global func is only registered when TVM was built with ACL
    # runtime support; `True` makes lookup return None instead of raising.
    check_enabled = tvm.get_global_func("relay.op.is_arm_compute_runtime_enabled", True)
    if check_enabled:
        return check_enabled()
    return False


def partition_for_arm_compute_lib(mod, params=None):
    """Partition the graph greedily offloading supported
    operators to Arm Compute Library.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    ret : annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    seq = tvm.transform.Sequential(
        [
            transform.MergeComposite(arm_compute_lib_pattern_table()),
            transform.AnnotateTarget("arm_compute_lib"),
            transform.PartitionGraph(),
        ]
    )

    return seq(mod)


@register_pattern_table("arm_compute_lib")
def arm_compute_lib_pattern_table():
    """Get the ACL pattern table."""

    def conv_pattern():
        """Create a convolution pattern.

        Returns
        -------
        pattern : dataflow_pattern.AltPattern
            Denotes the convolution pattern.
        """
        pattern = is_op("nn.pad")(wildcard()) | wildcard()
        pattern = is_op("nn.conv2d")(pattern, is_constant())
        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
        pattern = pattern.optional(is_op("nn.relu"))
        return pattern

    def qnn_conv_pattern():
        """Create a quantized convolution pattern.

        Returns
        -------
        pattern : dataflow_pattern.AltPattern
            Denotes the quantized convolution pattern.
        """
        pattern = is_op("nn.pad")(wildcard()) | wildcard()
        pattern = is_op("qnn.conv2d")(
            pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
        )
        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
        pattern = pattern.optional(is_op("nn.relu"))
        pattern = is_op("qnn.requantize")(
            pattern, wildcard(), wildcard(), is_constant(), is_constant()
        )
        return pattern

    def dense_pattern():
        """Create a dense (fully-connected) pattern.

        Returns
        -------
        pattern : dataflow_pattern.AltPattern
            Denotes the dense pattern.
        """
        pattern = is_op("nn.dense")(wildcard(), is_constant())
        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
        return pattern

    def qnn_dense_pattern():
        """Create a quantized dense (fully-connected) pattern.

        Returns
        -------
        pattern : dataflow_pattern.AltPattern
            Denotes the quantized dense pattern.
        """
        pattern = is_op("qnn.dense")(
            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
        )
        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
        pattern = is_op("qnn.requantize")(
            pattern, wildcard(), wildcard(), is_constant(), is_constant()
        )
        return pattern

    def avg_pool2d_pattern():
        """Creates a pattern that matches either quantized
        avg_pool2d or quantized global_avg_pool2d.

        Returns
        -------
        pattern : dataflow_pattern.AltPattern
            Denotes the quantized average pooling pattern.
        """
        pattern = is_op("cast")(wildcard())
        pattern = is_op("nn.avg_pool2d")(pattern) | is_op("nn.global_avg_pool2d")(pattern)
        pattern = is_op("cast")(pattern)
        return pattern

    def l2_pool2d_pattern():
        """Create an l2 pooling pattern from equivalent relay operators.

        Returns
        -------
        pattern : dataflow_pattern.AltPattern
            Denotes the l2 pooling pattern.
        """
        pattern = is_op("power")(wildcard(), is_expr(const(2.0)))
        pattern = is_op("nn.avg_pool2d")(pattern)
        pattern = is_op("sqrt")(pattern)
        return pattern

    def check_conv(extract):
        """Check conv pattern is supported by ACL."""
        call = extract
        # Walk down composite (relu/bias_add wrappers) to the conv2d call.
        while call.op.name != "nn.conv2d":
            call = call.args[0]
        return conv2d(call.attrs, call.args)

    def check_qnn_conv(extract):
        """Check qnn conv pattern is supported by ACL."""
        if extract.attrs.out_dtype != "uint8":
            return False
        call = extract
        while call.op.name != "qnn.conv2d":
            call = call.args[0]
        return qnn_conv2d(call.attrs, call.args)

    def check_dense(extract):
        """Check dense pattern is supported by ACL."""
        call = extract
        while call.op.name != "nn.dense":
            call = call.args[0]
        return dense(call.attrs, call.args)

    def check_qnn_dense(extract):
        """Check qnn dense pattern is supported by ACL."""
        if extract.attrs.out_dtype != "uint8":
            return False
        call = extract
        while call.op.name != "qnn.dense":
            call = call.args[0]
        return qnn_dense(call.attrs, call.args)

    def check_avg_pool2d(extract):
        """Check average pool2d pattern is supported by ACL."""
        # Outer cast must produce uint8, inner input must be int32 (the
        # quantized avg_pool2d composite casts uint8 -> int32 -> uint8).
        if extract.attrs.dtype != "uint8":
            return False
        pool = extract.args[0]
        if pool.args[0].attrs.dtype != "int32":
            return False
        return avg_pool2d(pool.attrs, pool.args, from_quantized_composite=True)

    def check_l2_pool2d(extract):
        """Check l2 pool2d pattern is supported by ACL."""
        pool = extract.args[0]
        return avg_pool2d(pool.attrs, pool.args)

    # NOTE: the original table registered the qnn_conv2d entry twice;
    # the redundant duplicate has been removed.
    return [
        ("arm_compute_lib.conv2d", conv_pattern(), check_conv),
        ("arm_compute_lib.qnn_conv2d", qnn_conv_pattern(), check_qnn_conv),
        ("arm_compute_lib.dense", dense_pattern(), check_dense),
        ("arm_compute_lib.qnn_dense", qnn_dense_pattern(), check_qnn_dense),
        ("arm_compute_lib.avg_pool2d", avg_pool2d_pattern(), check_avg_pool2d),
        ("arm_compute_lib.l2_pool2d", l2_pool2d_pattern(), check_l2_pool2d),
    ]


def _register_external_op_helper(op_name, supported=True):
    """Register an op as (un)conditionally supported by the ACL target."""

    @tvm.ir.register_op_attr(op_name, "target.arm_compute_lib")
    def _func_wrapper(attrs, args):
        return supported

    return _func_wrapper


_register_external_op_helper("reshape")


@tvm.ir.register_op_attr("nn.conv2d", "target.arm_compute_lib")
def conv2d(attrs, args):
    """Check if the external ACL codegen for conv2d should be used."""
    if attrs.groups != 1:
        return False
    if attrs.data_layout != "NHWC":
        return False
    if attrs.out_dtype != "float32" and attrs.out_dtype != "":
        return False
    data_typ = args[0].checked_type
    # ACL only supports batch size 1, 4-D float32 NHWC input.
    if len(data_typ.shape) != 4 or data_typ.shape[0] != 1 or data_typ.dtype != "float32":
        return False
    kernel_typ = args[1].checked_type
    if len(kernel_typ.shape) != 4 or kernel_typ.dtype != "float32":
        return False
    return True


def qnn_conv2d(attrs, args):
    """Check if the external ACL codegen for qnn.conv2d should be used."""
    if attrs.groups != 1:
        return False
    if attrs.data_layout != "NHWC":
        return False
    if attrs.out_dtype != "int32" and attrs.out_dtype != "":
        return False
    data_typ = args[0].checked_type
    if len(data_typ.shape) != 4 or data_typ.shape[0] != 1 or data_typ.dtype != "uint8":
        return False
    kernel_typ = args[1].checked_type
    if len(kernel_typ.shape) != 4 or kernel_typ.dtype != "uint8":
        return False
    return True


@tvm.ir.register_op_attr("nn.dense", "target.arm_compute_lib")
def dense(attrs, args):
    """Check if the external ACL codegen for dense should be used."""
    data_typ = args[0].checked_type
    if data_typ.dtype != "float32":
        return False
    kernel_typ = args[1].checked_type
    if len(kernel_typ.shape) != 2 or kernel_typ.dtype != "float32":
        return False
    if attrs.out_dtype != "float32" and attrs.out_dtype != "":
        return False
    return True


def qnn_dense(attrs, args):
    """Check if the external ACL codegen for qnn.dense should be used."""
    data_typ = args[0].checked_type
    if data_typ.dtype != "uint8":
        return False
    kernel_typ = args[1].checked_type
    if len(kernel_typ.shape) != 2 or kernel_typ.dtype != "uint8":
        return False
    if attrs.out_dtype != "int32":
        return False
    return True


@tvm.ir.register_op_attr("nn.max_pool2d", "target.arm_compute_lib")
def max_pool2d(attrs, args):
    """Check if the external ACL codegen for maxpool2d should be used."""
    if attrs.layout != "NHWC":
        return False
    typ = args[0].checked_type
    if typ.dtype not in ["float32", "uint8"]:
        return False
    return True


@tvm.ir.register_op_attr("nn.avg_pool2d", "target.arm_compute_lib")
def avg_pool2d(attrs, args, from_quantized_composite=False):
    """Check if the external ACL codegen for avgpool2d should be used."""
    typ = args[0].checked_type
    if from_quantized_composite:
        # Inside the quantized composite the pooling runs on the int32
        # intermediate produced by the surrounding casts.
        if typ.dtype != "int32":
            return False
    else:
        if typ.dtype not in ["float32"]:
            return False
    if attrs.layout != "NHWC":
        return False
    return True


@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.arm_compute_lib")
def global_max_pool2d(attrs, args):
    """Check if the external ACL codegen for global_maxpool2d should be used."""
    typ = args[0].checked_type
    if typ.dtype not in ["float32", "uint8"]:
        return False
    if attrs.layout != "NHWC":
        return False
    return True


@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.arm_compute_lib")
def global_avg_pool2d(attrs, args):
    """Check if the external ACL codegen for global_avgpool2d should be used."""
    typ = args[0].checked_type
    if typ.dtype not in ["float32"]:
        return False
    if attrs.layout != "NHWC":
        return False
    return True


@tvm.ir.register_op_attr("maximum", "target.arm_compute_lib")
def maximum(attrs, args):
    """Check if the external ACL codegen for maximum should be used."""
    type_a = args[0].checked_type
    # Fix: the original read args[0] twice, so the second operand's
    # dtype was never checked.
    type_b = args[1].checked_type
    return (type_a.dtype == "float32") and (type_b.dtype == "float32")