# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""External function interface to MPS libraries."""
import tvm
from tvm import te


# pylint: disable=C0103,W0612


def matmul(lhs, rhs, transa=False, transb=False):
    """Create an extern op that computes the matrix product of lhs and rhs with MPS.

    This function serves as an example of how to call external libraries.

    Parameters
    ----------
    lhs : Tensor
        The left matrix operand (indexed as 2-D via ``shape[0]``/``shape[1]``).
    rhs : Tensor
        The right matrix operand (indexed as 2-D via ``shape[0]``/``shape[1]``).
    transa : bool
        Whether to transpose lhs.
    transb : bool
        Whether to transpose rhs.

    Returns
    -------
    C : Tensor
        The result tensor of shape ``(m, n)``, where ``m`` is the row count
        of (optionally transposed) lhs and ``n`` is the column count of
        (optionally transposed) rhs.
    """
    # Output rows come from op(lhs) and output columns from op(rhs).  Use a
    # truthiness test (not `is False`) so the shape choice agrees with how the
    # flag is interpreted when forwarded to the packed function below.
    m = lhs.shape[1] if transa else lhs.shape[0]
    n = rhs.shape[0] if transb else rhs.shape[1]
    return te.extern(
        (m, n),
        [lhs, rhs],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb
        ),
        name="C",
    )


def conv2d(data, weight, pad="SAME", stride=1):
    """Create an extern op that convolves data with weight via MPS.

    Parameters
    ----------
    data : Tensor
        The input data, format NHWC.
    weight : Tensor
        The conv weight, format output_feature * kH * kW * input_feature.
    pad : str
        Padding method, 'SAME' or 'VALID'.
    stride : int
        Convolution stride (applied to both spatial dimensions).

    Returns
    -------
    output : Tensor
        The result tensor, format NHWC with ``co`` output channels.  Spatial
        sizes are ``ceil(H / stride)`` for 'SAME' and
        ``(H - kH) // stride + 1`` for 'VALID' (likewise for W / kW).

    Raises
    ------
    ValueError
        If ``pad`` is neither 'SAME' nor 'VALID'.
    """
    if pad not in ("SAME", "VALID"):
        raise ValueError("pad must be 'SAME' or 'VALID', got {!r}".format(pad))

    n, hi, wi, _ = data.shape
    co, kh, kw, _ = weight.shape

    if pad == "SAME":
        # Mode 0 is passed to the native side for 'SAME'.  The output keeps
        # the input extent divided by the stride, rounded up.
        padding = 0
        ho = (hi + stride - 1) // stride
        wo = (wi + stride - 1) // stride
    else:
        # Mode 1 is passed to the native side for 'VALID'.  Only positions
        # where the whole kernel fits inside the input produce output.
        padding = 1
        ho = (hi - kh) // stride + 1
        wo = (wi - kw) // stride + 1

    return te.extern(
        (n, ho, wo, co),
        [data, weight],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride
        ),
        name="C",
    )