# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tensor intrinsics on CUDA."""
#pylint: disable=invalid-name
import tvm


def dp4a(x_scope='local', y_scope='local', z_scope='local'):
    """
    Build a TensorIntrin that computes a 4-element int8 dot product via __dp4a.

    The declared computation multiplies two length-4 int8 vectors element-wise
    after widening to int32, and sums the products into a single int32 value.
    The lowered implementation loads each operand as one packed ``int8x4``
    value and hands both to the CUDA ``__dp4a`` extern together with an
    accumulator seed.

    Parameters
    ----------
    x_scope : str, optional
        Storage scope bound to the left-hand-side operand buffer.
    y_scope : str, optional
        Storage scope bound to the right-hand-side operand buffer.
    z_scope : str, optional
        Storage scope bound to the int32 result buffer.

    Returns
    -------
    intrin : TensorIntrin
        Intrinsic suitable for use with ``tensorize`` in a schedule.
    """
    lanes = 4  # __dp4a consumes its operands packed four int8 values at a time
    lhs = tvm.placeholder((lanes,), name='x', dtype='int8')
    rhs = tvm.placeholder((lanes,), name='y', dtype='int8')
    red = tvm.reduce_axis((0, lanes), name='rc')

    # Widen before multiplying so the reduction is carried out in int32.
    widened = lhs[red].astype('int32') * rhs[red].astype('int32')
    out = tvm.compute((1,), lambda i: tvm.sum(widened, axis=[red]))

    def _intrin_func(ins, outs):
        lhs_buf, rhs_buf = ins
        acc_buf = outs[0]

        def _emit_dp4a(accumulate):
            # Emit acc = __dp4a(lhs, rhs, seed); the seed is the current
            # accumulator value when updating, or the constant 0 otherwise.
            builder = tvm.ir_builder.create()
            packed_lhs = lhs_buf.vload(0, dtype='int8x4')
            packed_rhs = rhs_buf.vload(0, dtype='int8x4')
            seed = acc_buf.vload(0) if accumulate else 0
            total = tvm.call_pure_extern('int32', '__dp4a',
                                         packed_lhs, packed_rhs, seed)
            builder.emit(acc_buf.vstore(0, total))
            return builder.get()

        body = _emit_dp4a(accumulate=False)
        reset = acc_buf.vstore(0, 0)  # zero the accumulator
        update = _emit_dp4a(accumulate=True)
        return body, reset, update

    with tvm.build_config(data_alignment=4, offset_factor=1) as cfg:
        binds = {}
        for tensor, scope in ((lhs, x_scope), (rhs, y_scope), (out, z_scope)):
            binds[tensor] = tvm.decl_buffer(
                tensor.shape, tensor.dtype, tensor.op.name,
                data_alignment=cfg.data_alignment,
                offset_factor=cfg.offset_factor,
                scope=scope)
        return tvm.decl_tensor_intrin(out.op, _intrin_func, binds=binds)