1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include "mps_utils.h"
21
22namespace tvm {
23namespace contrib {
24
25using namespace runtime;
26
// Packed function "tvm.contrib.mps.matmul": C = op(A) * op(B) using MPSMatrixMultiplication,
// where op(X) is X or its transpose depending on the matching flag.
//
// Packed arguments:
//   args[0]  DLTensor* A   left operand
//   args[1]  DLTensor* B   right operand
//   args[2]  DLTensor* C   result
//   args[3]  bool transa   use A in transposed form
//   args[4]  bool transb   use B in transposed form
//
// All three tensors must be 2-D, float32 and have no explicit strides; their `data` pointers are
// reinterpreted as id<MTLBuffer>. Any violated precondition aborts through CHECK. The kernel is
// encoded into a new command buffer taken from the queue of A's context and committed; this
// function does not wait for the command buffer to finish.
TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* A = args[0];
  DLTensor* B = args[1];
  DLTensor* C = args[2];
  bool transa = args[3];
  bool transb = args[4];
  // call gemm for simple compact code.
  CHECK_EQ(A->ndim, 2);
  CHECK_EQ(B->ndim, 2);
  CHECK_EQ(C->ndim, 2);
  CHECK(C->strides == nullptr);
  CHECK(B->strides == nullptr);
  CHECK(A->strides == nullptr);
  CHECK(TypeMatch(A->dtype, kDLFloat, 32));
  CHECK(TypeMatch(B->dtype, kDLFloat, 32));
  CHECK(TypeMatch(C->dtype, kDLFloat, 32));
  // Get Metal device API
  MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
  // NOTE(review): B and C are assumed to live on the same context as A; this is not verified.
  // CHECK_EQ(A->ctx, B->ctx);
  // CHECK_EQ(A->ctx, C->ctx);
  id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
  id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
  id<MTLCommandBuffer> cb = [queue commandBuffer];

  // Shapes exactly as laid out in memory (row-major, compact).
  const NSUInteger rowsA = static_cast<NSUInteger>(A->shape[0]);
  const NSUInteger colsA = static_cast<NSUInteger>(A->shape[1]);
  const NSUInteger rowsB = static_cast<NSUInteger>(B->shape[0]);
  const NSUInteger colsB = static_cast<NSUInteger>(B->shape[1]);
  const NSUInteger rowsC = static_cast<NSUInteger>(C->shape[0]);
  const NSUInteger colsC = static_cast<NSUInteger>(C->shape[1]);

  // Logical GEMM sizes after the optional transposes: op(A) is M x K, op(B) is K x N.
  NSUInteger M = transa ? colsA : rowsA;
  NSUInteger N = transb ? rowsB : colsB;
  NSUInteger K = transb ? colsB : rowsB;

  CHECK_EQ(transa ? rowsA : colsA, K);
  // The result buffer must be exactly M x N, otherwise the kernel would write out of bounds.
  CHECK_EQ(rowsC, M);
  CHECK_EQ(colsC, N);

  // Every tensor was checked to be float32 above, so one element is sizeof(float) bytes.
  const NSUInteger elem_bytes = sizeof(float);
  MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);

  // The descriptors describe the buffers as they are stored; the transposition itself is applied
  // by the kernel through transposeLeft / transposeRight.
  // mps a
  MPSMatrixDescriptor* descA =
      [MPSMatrixDescriptor matrixDescriptorWithDimensions:rowsA
                                                  columns:colsA
                                                 rowBytes:colsA * elem_bytes
                                                 dataType:dtype];
  id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
  MPSMatrix* matrixA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
  // mps b
  MPSMatrixDescriptor* descB =
      [MPSMatrixDescriptor matrixDescriptorWithDimensions:rowsB
                                                  columns:colsB
                                                 rowBytes:colsB * elem_bytes
                                                 dataType:dtype];
  id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
  MPSMatrix* matrixB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
  // mps c
  MPSMatrixDescriptor* descC =
      [MPSMatrixDescriptor matrixDescriptorWithDimensions:rowsC
                                                  columns:colsC
                                                 rowBytes:colsC * elem_bytes
                                                 dataType:dtype];
  id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
  MPSMatrix* matrixC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];

  // kernel: allocate once and run the designated initializer directly (no prior -init call).
  MPSMatrixMultiplication* sgemm =
      [[MPSMatrixMultiplication alloc] initWithDevice:dev
                                        transposeLeft:transa
                                       transposeRight:transb
                                           resultRows:M
                                        resultColumns:N
                                      interiorColumns:K
                                                alpha:1.0f
                                                 beta:0.0f];
  CHECK(sgemm != nil);
  [sgemm encodeToCommandBuffer:cb leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC];
  [cb commit];
});
93
94}  // namespace contrib
95}  // namespace tvm
96