1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the operation details (custom parsing, printing, and
// verification) for the NVVM IR dialect in MLIR, which extends the LLVM IR
// dialect.  It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18 
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "llvm/AsmParser/Parser.h"
26 #include "llvm/IR/Attributes.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/Type.h"
29 #include "llvm/Support/SourceMgr.h"
30 
31 using namespace mlir;
32 using namespace NVVM;
33 
34 //===----------------------------------------------------------------------===//
35 // Printing/parsing for NVVM ops
36 //===----------------------------------------------------------------------===//
37 
printNVVMIntrinsicOp(OpAsmPrinter & p,Operation * op)38 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
39   p << op->getName() << " " << op->getOperands();
40   if (op->getNumResults() > 0)
41     p << " : " << op->getResultTypes();
42 }
43 
44 // <operation> ::= `llvm.nvvm.XYZ` : type
parseNVVMSpecialRegisterOp(OpAsmParser & parser,OperationState & result)45 static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser &parser,
46                                               OperationState &result) {
47   Type type;
48   if (parser.parseOptionalAttrDict(result.attributes) ||
49       parser.parseColonType(type))
50     return failure();
51 
52   result.addTypes(type);
53   return success();
54 }
55 
getLlvmDialect(OpAsmParser & parser)56 static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
57   return parser.getBuilder()
58       .getContext()
59       ->getRegisteredDialect<LLVM::LLVMDialect>();
60 }
61 
62 // <operation> ::=
63 //     `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
64 //      ({return_value_and_is_valid})? : result_type
parseNVVMShflSyncBflyOp(OpAsmParser & parser,OperationState & result)65 static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
66                                            OperationState &result) {
67   SmallVector<OpAsmParser::OperandType, 8> ops;
68   Type resultType;
69   if (parser.parseOperandList(ops) ||
70       parser.parseOptionalAttrDict(result.attributes) ||
71       parser.parseColonType(resultType) ||
72       parser.addTypeToList(resultType, result.types))
73     return failure();
74 
75   auto type = resultType.cast<LLVM::LLVMType>();
76   for (auto &attr : result.attributes) {
77     if (attr.first != "return_value_and_is_valid")
78       continue;
79     if (type.isStructTy() && type.getStructNumElements() > 0)
80       type = type.getStructElementType(0);
81     break;
82   }
83 
84   auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
85   return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
86                                 parser.getNameLoc(), result.operands);
87 }
88 
89 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
parseNVVMVoteBallotOp(OpAsmParser & parser,OperationState & result)90 static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
91                                          OperationState &result) {
92   auto llvmDialect = getLlvmDialect(parser);
93   auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
94   auto int1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
95 
96   SmallVector<OpAsmParser::OperandType, 8> ops;
97   Type type;
98   return failure(parser.parseOperandList(ops) ||
99                  parser.parseOptionalAttrDict(result.attributes) ||
100                  parser.parseColonType(type) ||
101                  parser.addTypeToList(type, result.types) ||
102                  parser.resolveOperands(ops, {int32Ty, int1Ty},
103                                         parser.getNameLoc(), result.operands));
104 }
105 
106 // <operation> ::= `llvm.nvvm.mma.sync %lhs... %rhs... %acc...`
107 //                 : signature_type
parseNVVMMmaOp(OpAsmParser & parser,OperationState & result)108 static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) {
109   SmallVector<OpAsmParser::OperandType, 12> ops;
110   Type type;
111   llvm::SMLoc typeLoc;
112   if (parser.parseOperandList(ops) ||
113       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
114       parser.getCurrentLocation(&typeLoc) || parser.parseType(type)) {
115     return failure();
116   }
117 
118   auto signature = type.dyn_cast<FunctionType>();
119   if (!signature) {
120     return parser.emitError(
121         typeLoc, "expected the type to be the full list of input and output");
122   }
123 
124   if (signature.getNumResults() != 1) {
125     return parser.emitError(typeLoc, "expected single result");
126   }
127 
128   return failure(parser.addTypeToList(signature.getResult(0), result.types) ||
129                  parser.resolveOperands(ops, signature.getInputs(),
130                                         parser.getNameLoc(), result.operands));
131 }
132 
printNVVMMmaOp(OpAsmPrinter & p,MmaOp & op)133 static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) {
134   p << op.getOperationName() << " " << op.getOperands();
135   p.printOptionalAttrDict(op.getAttrs());
136   p << " : "
137     << FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()),
138                          op.getType(), op.getContext());
139 }
140 
verify(MmaOp op)141 static LogicalResult verify(MmaOp op) {
142   auto dialect = op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
143   auto f16Ty = LLVM::LLVMType::getHalfTy(dialect);
144   auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
145   auto f32Ty = LLVM::LLVMType::getFloatTy(dialect);
146   auto f16x2x4StructTy = LLVM::LLVMType::getStructTy(
147       dialect, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
148   auto f32x8StructTy = LLVM::LLVMType::getStructTy(
149       dialect, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
150 
151   SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
152                                       op.getOperandTypes().end());
153   if (operand_types != SmallVector<Type, 8>(8, f16x2Ty) &&
154       operand_types != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
155                                              f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
156                                              f32Ty, f32Ty, f32Ty}) {
157     return op.emitOpError(
158         "expected operands to be 4 <halfx2>s followed by either "
159         "4 <halfx2>s or 8 floats");
160   }
161   if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) {
162     return op.emitOpError("expected result type to be a struct of either 4 "
163                           "<halfx2>s or 8 floats");
164   }
165 
166   auto alayout = op.getAttrOfType<StringAttr>("alayout");
167   auto blayout = op.getAttrOfType<StringAttr>("blayout");
168 
169   if (!(alayout && blayout) ||
170       !(alayout.getValue() == "row" || alayout.getValue() == "col") ||
171       !(blayout.getValue() == "row" || blayout.getValue() == "col")) {
172     return op.emitOpError(
173         "alayout and blayout attributes must be set to either "
174         "\"row\" or \"col\"");
175   }
176 
177   if (operand_types == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
178                                              f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
179                                              f32Ty, f32Ty, f32Ty} &&
180       op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
181       blayout.getValue() == "row") {
182     return success();
183   }
184   return op.emitOpError("unimplemented mma.sync variant");
185 }
186 
187 //===----------------------------------------------------------------------===//
188 // NVVMDialect initialization, type parsing, and registration.
189 //===----------------------------------------------------------------------===//
190 
// TODO(herhut): This should be the llvm.nvvm dialect once this is supported.
//
// Constructs the dialect under the "nvvm" namespace in `context` and registers
// the operations named by the GET_OP_LIST section of NVVMOps.cpp.inc.
NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) {
  // The include expands to the template argument list of addOperations<>.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are registered.
  allowUnknownOperations();
}
201 
namespace mlir {
namespace NVVM {
// Expands the GET_OP_CLASSES section of NVVMOps.cpp.inc inside mlir::NVVM
// (presumably the generated op class definitions).
// NOTE(review): the static parse/print/verify helpers defined earlier in this
// file are presumably referenced from the included code; confirm before
// moving this include above them.
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
} // namespace NVVM
} // namespace mlir
208 
// File-local registration object for NVVMDialect; presumably its construction
// during static initialization is what registers the dialect (see the file
// header: "It also registers the dialect").
static DialectRegistration<NVVMDialect> nvvmDialect;
210