/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "mps_utils.h"

namespace tvm {
namespace contrib {

using namespace runtime;

// Wraps the Metal buffer behind a compact, row-major 2-D DLTensor in an MPSMatrix.
//
// The descriptor describes the tensor exactly as it is stored (shape[0] rows by shape[1]
// columns). Any logical transpose is requested from MPSMatrixMultiplication through its
// transposeLeft/transposeRight flags, not by reshaping the descriptor.
//
// NOTE(review): assumes the Metal runtime stores the id<MTLBuffer> itself in
// DLTensor::data (the same assumption the bridge cast below relies on).
static MPSMatrix* WrapAsMPSMatrix(DLTensor* tensor, MPSDataType dtype) {
  // initWithBuffer:descriptor: always starts reading at buffer offset 0, so a non-zero
  // byte_offset would silently read the wrong data.
  CHECK(tensor->byte_offset == 0) << "tvm.contrib.mps.matmul does not support tensors with a "
                                     "non-zero byte_offset";
  NSUInteger rows = static_cast<NSUInteger>(tensor->shape[0]);
  NSUInteger cols = static_cast<NSUInteger>(tensor->shape[1]);
  // Bytes per element derived from the DLPack dtype (4 for the float32 tensors the caller
  // accepts). sizeof(MPSDataType) would be the size of the enum, not of an element.
  NSUInteger elem_bytes = static_cast<NSUInteger>(tensor->dtype.bits / 8);
  MPSMatrixDescriptor* desc = [MPSMatrixDescriptor matrixDescriptorWithDimensions:rows
                                                                          columns:cols
                                                                         rowBytes:cols * elem_bytes
                                                                         dataType:dtype];
  // Non-owning bridge: the DLTensor keeps ownership of the buffer.
  id<MTLBuffer> buf = (__bridge id<MTLBuffer>)(tensor->data);
  return [[MPSMatrix alloc] initWithBuffer:buf descriptor:desc];
}

// Packed function "tvm.contrib.mps.matmul": computes C = op(A) * op(B) with MPS, where
// op(X) is X, or X transposed when the corresponding flag is set.
//
// Positional arguments:
//   args[0] A      - compact 2-D float32 tensor, stored M x K (K x M when transa is true).
//   args[1] B      - compact 2-D float32 tensor, stored K x N (N x K when transb is true).
//   args[2] C      - compact 2-D float32 output tensor, M x N; overwritten (beta = 0).
//   args[3] transa - use A transposed.
//   args[4] transb - use B transposed.
//
// The multiplication is encoded into a fresh command buffer on A's command queue and
// committed; this function does not wait for it to complete.
TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* A = args[0];
  DLTensor* B = args[1];
  DLTensor* C = args[2];
  bool transa = args[3];
  bool transb = args[4];
  // Only compact (stride-free) 2-D float32 matrices are supported.
  CHECK_EQ(A->ndim, 2);
  CHECK_EQ(B->ndim, 2);
  CHECK_EQ(C->ndim, 2);
  CHECK(C->strides == nullptr);
  CHECK(B->strides == nullptr);
  CHECK(A->strides == nullptr);
  CHECK(TypeMatch(A->dtype, kDLFloat, 32));
  CHECK(TypeMatch(B->dtype, kDLFloat, 32));
  CHECK(TypeMatch(C->dtype, kDLFloat, 32));
  // Every buffer is bound to A's device/queue below, so all operands must live there.
  CHECK(A->ctx.device_type == B->ctx.device_type && A->ctx.device_id == B->ctx.device_id)
      << "tvm.contrib.mps.matmul: A and B must be on the same device";
  CHECK(A->ctx.device_type == C->ctx.device_type && A->ctx.device_id == C->ctx.device_id)
      << "tvm.contrib.mps.matmul: A and C must be on the same device";

  // Get Metal device API
  MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
  id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
  id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
  id<MTLCommandBuffer> cb = [queue commandBuffer];

  // Logical GEMM sizes: op(A) is M x K, op(B) is K x N, C is M x N.
  int64_t m = A->shape[transa ? 1 : 0];
  int64_t n = B->shape[transb ? 0 : 1];
  int64_t k = B->shape[transb ? 1 : 0];
  CHECK_EQ(A->shape[transa ? 0 : 1], k);
  CHECK_EQ(C->shape[0], m);
  CHECK_EQ(C->shape[1], n);
  NSUInteger M = static_cast<NSUInteger>(m);
  NSUInteger N = static_cast<NSUInteger>(n);
  NSUInteger K = static_cast<NSUInteger>(k);

  // Descriptors describe the stored layouts; MPS applies the transposes itself.
  MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
  MPSMatrix* matrixA = WrapAsMPSMatrix(A, dtype);
  MPSMatrix* matrixB = WrapAsMPSMatrix(B, dtype);
  MPSMatrix* matrixC = WrapAsMPSMatrix(C, dtype);

  // kernel: C = 1.0 * op(A) * op(B) + 0.0 * C
  MPSMatrixMultiplication* sgemm = [[MPSMatrixMultiplication alloc] initWithDevice:dev
                                                                     transposeLeft:transa
                                                                    transposeRight:transb
                                                                        resultRows:M
                                                                     resultColumns:N
                                                                   interiorColumns:K
                                                                             alpha:1.0f
                                                                              beta:0.0f];
  CHECK(sgemm != nil);
  [sgemm encodeToCommandBuffer:cb leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC];
  [cb commit];
});

}  // namespace contrib
}  // namespace tvm