1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""External function interface to MPS libraries."""
18import tvm
19from tvm import te
20
21
22# pylint: disable=C0103,W0612
23
24
def matmul(lhs, rhs, transa=False, transb=False):
    """Create an extern op that computes the matrix product of lhs and rhs with MPS.

    This function serves as an example of how to call external libraries.

    Parameters
    ----------
    lhs : Tensor
        The left matrix operand.
    rhs : Tensor
        The right matrix operand.
    transa : bool
        Whether to transpose lhs.
    transb : bool
        Whether to transpose rhs.

    Returns
    -------
    C : Tensor
        The result tensor of shape (m, n), where m is the row count of
        lhs (or of its transpose when transa is set) and n is the column
        count of rhs (or of its transpose when transb is set).
    """
    # Output rows come from lhs (columns of lhs if it is transposed); output
    # columns come from rhs (rows of rhs if it is transposed).
    m = lhs.shape[1] if transa else lhs.shape[0]
    n = rhs.shape[0] if transb else rhs.shape[1]
    return te.extern(
        (m, n),
        [lhs, rhs],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb
        ),
        name="C",
    )
60
61
def conv2d(data, weight, pad="SAME", stride=1):
    """
    Create an extern op that computes the convolution of data with weight.

    Parameters
    ----------
    data : Tensor
        The input data, format NHWC.
    weight : Tensor
        The conv weight, format output_feature * kH * kW * input_feature.
    pad : str
        Padding method, 'SAME' or 'VALID'.
    stride : int
        Convolution stride, applied to both spatial dimensions.

    Returns
    -------
    output : Tensor
        The result tensor, shape (n, ho, wo, output_feature).

    Raises
    ------
    ValueError
        If pad is neither 'SAME' nor 'VALID'.
    """
    if pad not in ("SAME", "VALID"):
        raise ValueError("pad must be 'SAME' or 'VALID', got {!r}".format(pad))

    n, hi, wi, _ = data.shape
    co, kh, kw, _ = weight.shape
    # The packed function receives the padding mode as an int: 0 = SAME, 1 = VALID.
    padding = 0 if pad == "SAME" else 1
    if pad == "SAME":
        # SAME padding yields ceil(input / stride) positions per spatial dim.
        ho = (hi + stride - 1) // stride
        wo = (wi + stride - 1) // stride
    else:
        # VALID padding keeps only positions where the whole kernel fits.
        ho = (hi - kh) // stride + 1
        wo = (wi - kw) // stride + 1

    return te.extern(
        (n, ho, wo, co),
        [data, weight],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride
        ),
        name="C",
    )
96