1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17# pylint: disable=unused-argument 18"""A Relay implementation of graph packing.""" 19 20from tvm import relay 21from tvm.relay import op, transform 22from tvm.relay import ExprMutator 23 24def run_opt_pass(expr, opt_pass): 25 """Exectue a relay pass.""" 26 assert isinstance(opt_pass, transform.Pass) 27 mod = relay.Module.from_expr(expr) 28 mod = opt_pass(mod) 29 entry = mod["main"] 30 return entry if isinstance(expr, relay.Function) else entry.body 31 32def _to_shape(shape): 33 return tuple(int(sh) for sh in shape) 34 35def _pack_batch_channel(data, dshape, bfactor, cfactor): 36 """Pack the data channel dimension. 37 """ 38 assert int(dshape[0]) % bfactor == 0 39 assert int(dshape[1]) % cfactor == 0 40 data = op.reshape(data, 41 newshape=(int(dshape[0]) // bfactor, bfactor, 42 int(dshape[1]) // cfactor, cfactor, 43 int(dshape[2]), int(dshape[3]))) 44 data = op.transpose( 45 data, axes=(0, 2, 4, 5, 1, 3)) 46 return data 47 48 49def _unpack_batch_channel(data, old_shape): 50 """Unpack the data channel dimension. 51 """ 52 data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3)) 53 data = op.reshape(data, newshape=old_shape) 54 return data 55 56 57def _pack_weight(data, dshape, cfactor): 58 """Pack the weight into packed format. 59 """ 60 assert len(dshape) == 4 61 assert int(dshape[0]) % cfactor == 0 62 assert int(dshape[1]) % cfactor == 0 63 data = op.reshape(data, 64 newshape=(int(dshape[0]) // cfactor, cfactor, 65 int(dshape[1]) // cfactor, cfactor, 66 int(dshape[2]), int(dshape[3]))) 67 data = op.transpose( 68 data, axes=(0, 2, 4, 5, 1, 3)) 69 return data 70 71 72def _pack_weight_conv2d_transpose(data, dshape, cfactor): 73 """Pack the weight into packed format. 74 """ 75 dshape = _to_shape(dshape) 76 assert len(dshape) == 4 77 assert dshape[0] % cfactor == 0 78 assert dshape[1] % cfactor == 0 79 data = op.reshape(data, 80 newshape=(dshape[0] // cfactor, cfactor, 81 dshape[1] // cfactor, cfactor, 82 dshape[2], dshape[3])) 83 data = op.transpose( 84 data, axes=(2, 0, 4, 5, 3, 1)) 85 return data 86 87 88def _pack_const(data, dshape, dtype, bfactor, cfactor): 89 """Pack a constant parameter. 90 """ 91 dshape = _to_shape(dshape) 92 assert len(dshape) == 3 93 assert dshape[0] % cfactor == 0 94 data = op.reshape(data, 95 newshape=(dshape[0] // cfactor, 96 cfactor, dshape[1], 97 dshape[2], 1)) 98 data = op.transpose( 99 data, axes=(0, 2, 3, 4, 1)) 100 101 # broadcast batch dimension to bfactor 102 data = op.broadcast_to( 103 data, 104 shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor)) 105 return data 106 107 108def _get_shape(node): 109 """Get the shape of a node. 110 """ 111 return _to_shape(node.checked_type.shape) 112 113class ExprPack(ExprMutator): 114 """Visitor to perform graph packing on an AST. 115 """ 116 def __init__(self, bfactor, cfactor, weight_bits): 117 self.bfactor = bfactor 118 self.cfactor = cfactor 119 self.weight_bits = weight_bits 120 self.start_pack = False 121 # Cache Operator the algorithm matches against. 122 self.bitpack_start = op.op.get('annotation.bitpack_start') 123 self.bitpack_end = op.op.get('annotation.bitpack_end') 124 self.conv2d = op.op.get("nn.conv2d") 125 self.conv2d_transpose = op.op.get("nn.conv2d_transpose") 126 self.add = op.op.get("add") 127 self.multiply = op.op.get("multiply") 128 self.bias_add = op.op.get("nn.bias_add") 129 self.number_of_conv2d = 0 130 super().__init__() 131 132 def visit_call(self, call): 133 """ Visit the children. """ 134 # First visit the children. 135 oshape = _get_shape(call) 136 odtype = call.checked_type.dtype 137 input_types = [arg.checked_type for arg in call.args] 138 args = [self.visit(arg) for arg in call.args] 139 140 # Start and stop cases. 141 if call.op == self.bitpack_start: 142 assert not self.start_pack 143 self.start_pack = True 144 return _pack_batch_channel(args[0], oshape, self.bfactor, self.cfactor) 145 elif call.op == self.bitpack_end: 146 if self.start_pack: 147 self.start_pack = False 148 data = args[0] 149 data_shape = _get_shape(call.args[0]) 150 return _unpack_batch_channel(data, data_shape) 151 else: 152 pass 153 if self.start_pack: 154 # Operator cases 155 if call.op == self.conv2d and odtype == 'int32': 156 self.number_of_conv2d += 1 157 assert 8 % self.weight_bits == 0 158 w_lanes = 8 // self.weight_bits 159 data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor) 160 kernel_layout = "OIHW%do%di" % (self.cfactor, self.cfactor) 161 data, weight = args 162 data_shape = _to_shape(input_types[0].shape) 163 kernel_shape = _to_shape(input_types[1].shape) 164 kernel = _pack_weight(weight, kernel_shape, self.cfactor) 165 # insert bit packing when necessary 166 if w_lanes != 1: 167 assert 8 % w_lanes == 0 168 kernel = op.bitpack(kernel, lanes=w_lanes) 169 conv2d = op.nn.conv2d( 170 data, 171 kernel, 172 strides=call.attrs.strides, 173 padding=call.attrs.padding, 174 dilation=call.attrs.dilation, 175 groups=call.attrs.groups, 176 channels=call.attrs.channels, 177 kernel_size=call.attrs.kernel_size, 178 data_layout=data_layout, 179 kernel_layout=kernel_layout, 180 out_dtype=call.attrs.out_dtype) 181 return conv2d 182 elif call.op == self.conv2d_transpose and odtype == 'int32': 183 self.number_of_conv2d += 1 184 assert 8 % self.weight_bits == 0 185 w_lanes = 8 // self.weight_bits 186 if self.start_pack: 187 data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor) 188 kernel_layout = "IOHW%di%do" % (self.cfactor, self.cfactor) 189 data, weight = args 190 data_shape = _to_shape(input_types[0].shape) 191 kernel_shape = _to_shape(input_types[1].shape) 192 kernel = _pack_weight_conv2d_transpose(weight, kernel_shape, self.cfactor) 193 conv2d = op.nn.conv2d_transpose( 194 data, 195 kernel, 196 strides=call.attrs.strides, 197 padding=call.attrs.padding, 198 dilation=call.attrs.dilation, 199 groups=call.attrs.groups, 200 channels=call.attrs.channels, 201 kernel_size=call.attrs.kernel_size, 202 data_layout=data_layout, 203 kernel_layout=kernel_layout, 204 output_padding=call.attrs.output_padding, 205 out_dtype=call.attrs.out_dtype) 206 return conv2d 207 elif call.op == self.add and \ 208 tuple(input_types[0].shape) == tuple(input_types[1].shape): 209 pass 210 elif call.op == self.add and len(input_types[1].shape) == 3: 211 data, const = args 212 const = _pack_const(const, 213 _to_shape(input_types[1].shape), 214 input_types[1].dtype, 215 self.bfactor, 216 self.cfactor) 217 return relay.Call(self.add, [data, const]) 218 elif call.op == self.multiply and \ 219 tuple(input_types[0].shape) == tuple(input_types[1].shape): 220 pass 221 elif call.op == self.multiply and len(input_types[1].shape) == 3: 222 data, const = args 223 const = _pack_const(const, 224 _to_shape(input_types[1].shape), 225 input_types[1].dtype, 226 self.bfactor, 227 self.cfactor) 228 return relay.Call(self.multiply, [data, const]) 229 elif self.start_pack and call.op == self.bias_add: 230 data, bias = args 231 bias = _pack_const(bias, 232 _to_shape(input_types[1].shape), 233 input_types[1].dtype, 234 self.bfactor, 235 self.cfactor) 236 return relay.Call(self.add, [data, bias]) 237 elif self.start_pack and call.op == op.op.get('cast') and \ 238 input_types[0].dtype == 'int32': 239 cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs) 240 return relay.Call(op.op.get('copy'), [cast]) 241 242 return relay.Call( 243 self.visit(call.op), 244 args, 245 call.attrs) 246 247class BT(Exception): 248 pass 249def get_subgraph(expr, start_name, stop_name): 250 """ We assume stop_name only appears once for simplicity. 251 This constraint will be lifted in the future. 252 bitpack_start and bitpack_end are both inclusive. 253 """ 254 bitpack_start = op.op.get('annotation.bitpack_start') 255 bitpack_end = op.op.get('annotation.bitpack_end') 256 anf = run_opt_pass(expr, transform.ToANormalForm()) 257 def _recursion(anf, start_found, stop_found): 258 """ Helper to obtain the subgraph. 259 """ 260 if isinstance(anf, relay.expr.Function): 261 return relay.expr.Function(anf.params, 262 _recursion(anf.body, start_found, stop_found), 263 anf.ret_type, anf.type_params, anf.attrs) 264 elif isinstance(anf, relay.expr.Let): 265 value = anf.value 266 if isinstance(value, relay.expr.Call): 267 if isinstance(value.op, relay.op.Op): 268 if value.op.name == start_name and not start_found: 269 value = relay.expr.Call(bitpack_start, [value]) 270 start_found = True 271 elif value.op.name == stop_name: 272 raise BT() 273 try: 274 return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found)) 275 except BT: 276 assert start_found 277 assert not stop_found 278 stop_found = True 279 value = relay.expr.Call(bitpack_end, [value]) 280 # todo: check anf.body has no more stop_name beside that one 281 return relay.expr.Let(anf.var, value, anf.body) 282 else: 283 assert start_found 284 assert stop_found 285 return anf 286 annotated = _recursion(anf, False, False) 287 return run_opt_pass(annotated, transform.ToGraphNormalForm()) 288 289def graph_pack(expr, 290 bfactor, 291 cfactor, 292 weight_bits, 293 start_name="nn.max_pool2d", 294 stop_name="nn.global_avg_pool2d"): 295 """Pack the graph into batch&channel packed format. 296 297 Parameters 298 ---------- 299 expr : relay.Expr 300 The input program. 301 302 bfactor : int 303 The packing factor in batch 304 305 cfactor : int 306 The packing factor in channel 307 308 weight_bits: int 309 The bit-width of the weights. 310 311 start_name: str, optional 312 Start packing from certain known node. 313 314 stop_name: str, optional 315 Stop packing from certain known node. 316 317 Returns 318 ------- 319 expr : Expr 320 The transformed expression. 321 """ 322 assert isinstance(expr, relay.Function) 323 expr = get_subgraph(expr, start_name, stop_name) 324 expr = run_opt_pass(expr, transform.InferType()) 325 packer = ExprPack( 326 bfactor, cfactor, 327 weight_bits) 328 expr = packer.visit(expr) 329 assert not packer.start_pack 330 return run_opt_pass(expr, transform.InferType()) 331