1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, which builds on top of the LLVM IR dialect. It also registers the
// dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "llvm/AsmParser/Parser.h"
26 #include "llvm/IR/Attributes.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/Type.h"
29 #include "llvm/Support/SourceMgr.h"
30
31 using namespace mlir;
32 using namespace NVVM;
33
34 //===----------------------------------------------------------------------===//
35 // Printing/parsing for NVVM ops
36 //===----------------------------------------------------------------------===//
37
printNVVMIntrinsicOp(OpAsmPrinter & p,Operation * op)38 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
39 p << op->getName() << " " << op->getOperands();
40 if (op->getNumResults() > 0)
41 p << " : " << op->getResultTypes();
42 }
43
44 // <operation> ::= `llvm.nvvm.XYZ` : type
parseNVVMSpecialRegisterOp(OpAsmParser & parser,OperationState & result)45 static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser &parser,
46 OperationState &result) {
47 Type type;
48 if (parser.parseOptionalAttrDict(result.attributes) ||
49 parser.parseColonType(type))
50 return failure();
51
52 result.addTypes(type);
53 return success();
54 }
55
getLlvmDialect(OpAsmParser & parser)56 static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
57 return parser.getBuilder()
58 .getContext()
59 ->getRegisteredDialect<LLVM::LLVMDialect>();
60 }
61
62 // <operation> ::=
63 // `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
64 // ({return_value_and_is_valid})? : result_type
parseNVVMShflSyncBflyOp(OpAsmParser & parser,OperationState & result)65 static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
66 OperationState &result) {
67 SmallVector<OpAsmParser::OperandType, 8> ops;
68 Type resultType;
69 if (parser.parseOperandList(ops) ||
70 parser.parseOptionalAttrDict(result.attributes) ||
71 parser.parseColonType(resultType) ||
72 parser.addTypeToList(resultType, result.types))
73 return failure();
74
75 auto type = resultType.cast<LLVM::LLVMType>();
76 for (auto &attr : result.attributes) {
77 if (attr.first != "return_value_and_is_valid")
78 continue;
79 if (type.isStructTy() && type.getStructNumElements() > 0)
80 type = type.getStructElementType(0);
81 break;
82 }
83
84 auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
85 return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
86 parser.getNameLoc(), result.operands);
87 }
88
89 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
parseNVVMVoteBallotOp(OpAsmParser & parser,OperationState & result)90 static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
91 OperationState &result) {
92 auto llvmDialect = getLlvmDialect(parser);
93 auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
94 auto int1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
95
96 SmallVector<OpAsmParser::OperandType, 8> ops;
97 Type type;
98 return failure(parser.parseOperandList(ops) ||
99 parser.parseOptionalAttrDict(result.attributes) ||
100 parser.parseColonType(type) ||
101 parser.addTypeToList(type, result.types) ||
102 parser.resolveOperands(ops, {int32Ty, int1Ty},
103 parser.getNameLoc(), result.operands));
104 }
105
106 // <operation> ::= `llvm.nvvm.mma.sync %lhs... %rhs... %acc...`
107 // : signature_type
parseNVVMMmaOp(OpAsmParser & parser,OperationState & result)108 static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) {
109 SmallVector<OpAsmParser::OperandType, 12> ops;
110 Type type;
111 llvm::SMLoc typeLoc;
112 if (parser.parseOperandList(ops) ||
113 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
114 parser.getCurrentLocation(&typeLoc) || parser.parseType(type)) {
115 return failure();
116 }
117
118 auto signature = type.dyn_cast<FunctionType>();
119 if (!signature) {
120 return parser.emitError(
121 typeLoc, "expected the type to be the full list of input and output");
122 }
123
124 if (signature.getNumResults() != 1) {
125 return parser.emitError(typeLoc, "expected single result");
126 }
127
128 return failure(parser.addTypeToList(signature.getResult(0), result.types) ||
129 parser.resolveOperands(ops, signature.getInputs(),
130 parser.getNameLoc(), result.operands));
131 }
132
printNVVMMmaOp(OpAsmPrinter & p,MmaOp & op)133 static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) {
134 p << op.getOperationName() << " " << op.getOperands();
135 p.printOptionalAttrDict(op.getAttrs());
136 p << " : "
137 << FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()),
138 op.getType(), op.getContext());
139 }
140
verify(MmaOp op)141 static LogicalResult verify(MmaOp op) {
142 auto dialect = op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
143 auto f16Ty = LLVM::LLVMType::getHalfTy(dialect);
144 auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
145 auto f32Ty = LLVM::LLVMType::getFloatTy(dialect);
146 auto f16x2x4StructTy = LLVM::LLVMType::getStructTy(
147 dialect, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
148 auto f32x8StructTy = LLVM::LLVMType::getStructTy(
149 dialect, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
150
151 SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
152 op.getOperandTypes().end());
153 if (operand_types != SmallVector<Type, 8>(8, f16x2Ty) &&
154 operand_types != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
155 f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
156 f32Ty, f32Ty, f32Ty}) {
157 return op.emitOpError(
158 "expected operands to be 4 <halfx2>s followed by either "
159 "4 <halfx2>s or 8 floats");
160 }
161 if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) {
162 return op.emitOpError("expected result type to be a struct of either 4 "
163 "<halfx2>s or 8 floats");
164 }
165
166 auto alayout = op.getAttrOfType<StringAttr>("alayout");
167 auto blayout = op.getAttrOfType<StringAttr>("blayout");
168
169 if (!(alayout && blayout) ||
170 !(alayout.getValue() == "row" || alayout.getValue() == "col") ||
171 !(blayout.getValue() == "row" || blayout.getValue() == "col")) {
172 return op.emitOpError(
173 "alayout and blayout attributes must be set to either "
174 "\"row\" or \"col\"");
175 }
176
177 if (operand_types == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
178 f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
179 f32Ty, f32Ty, f32Ty} &&
180 op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
181 blayout.getValue() == "row") {
182 return success();
183 }
184 return op.emitOpError("unimplemented mma.sync variant");
185 }
186
187 //===----------------------------------------------------------------------===//
188 // NVVMDialect initialization, type parsing, and registration.
189 //===----------------------------------------------------------------------===//
190
191 // TODO(herhut): This should be the llvm.nvvm dialect once this is supported.
NVVMDialect(MLIRContext * context)192 NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) {
193 addOperations<
194 #define GET_OP_LIST
195 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
196 >();
197
198 // Support unknown operations because not all NVVM operations are registered.
199 allowUnknownOperations();
200 }
201
// Op class definitions generated from the NVVM ODS file (NVVMOps.cpp.inc),
// emitted inside mlir::NVVM. NOTE(review): presumably the generated code refers
// to the static parse/print/verify helpers defined earlier in this file, which
// is why this include comes after them; verify against NVVMOps.td.
namespace mlir {
namespace NVVM {
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
} // namespace NVVM
} // namespace mlir

// File-scope static DialectRegistration object for NVVMDialect.
// NOTE(review): presumably registers the dialect with MLIR's global registry
// during static initialization; confirm against DialectRegistration docs.
static DialectRegistration<NVVMDialect> nvvmDialect;
210