1 //===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/BuiltinTypes.h"
10 #include "mlir-c/AffineMap.h"
11 #include "mlir-c/IR.h"
12 #include "mlir/CAPI/AffineMap.h"
13 #include "mlir/CAPI/IR.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Types.h"
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // Integer types.
22 //===----------------------------------------------------------------------===//
23 
mlirTypeIsAInteger(MlirType type)24 bool mlirTypeIsAInteger(MlirType type) {
25   return unwrap(type).isa<IntegerType>();
26 }
27 
mlirIntegerTypeGet(MlirContext ctx,unsigned bitwidth)28 MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
29   return wrap(IntegerType::get(unwrap(ctx), bitwidth));
30 }
31 
mlirIntegerTypeSignedGet(MlirContext ctx,unsigned bitwidth)32 MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
33   return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed));
34 }
35 
mlirIntegerTypeUnsignedGet(MlirContext ctx,unsigned bitwidth)36 MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
37   return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Unsigned));
38 }
39 
mlirIntegerTypeGetWidth(MlirType type)40 unsigned mlirIntegerTypeGetWidth(MlirType type) {
41   return unwrap(type).cast<IntegerType>().getWidth();
42 }
43 
mlirIntegerTypeIsSignless(MlirType type)44 bool mlirIntegerTypeIsSignless(MlirType type) {
45   return unwrap(type).cast<IntegerType>().isSignless();
46 }
47 
mlirIntegerTypeIsSigned(MlirType type)48 bool mlirIntegerTypeIsSigned(MlirType type) {
49   return unwrap(type).cast<IntegerType>().isSigned();
50 }
51 
mlirIntegerTypeIsUnsigned(MlirType type)52 bool mlirIntegerTypeIsUnsigned(MlirType type) {
53   return unwrap(type).cast<IntegerType>().isUnsigned();
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // Index type.
58 //===----------------------------------------------------------------------===//
59 
mlirTypeIsAIndex(MlirType type)60 bool mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa<IndexType>(); }
61 
mlirIndexTypeGet(MlirContext ctx)62 MlirType mlirIndexTypeGet(MlirContext ctx) {
63   return wrap(IndexType::get(unwrap(ctx)));
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // Floating-point types.
68 //===----------------------------------------------------------------------===//
69 
mlirTypeIsABF16(MlirType type)70 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
71 
mlirBF16TypeGet(MlirContext ctx)72 MlirType mlirBF16TypeGet(MlirContext ctx) {
73   return wrap(FloatType::getBF16(unwrap(ctx)));
74 }
75 
mlirTypeIsAF16(MlirType type)76 bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
77 
mlirF16TypeGet(MlirContext ctx)78 MlirType mlirF16TypeGet(MlirContext ctx) {
79   return wrap(FloatType::getF16(unwrap(ctx)));
80 }
81 
mlirTypeIsAF32(MlirType type)82 bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
83 
mlirF32TypeGet(MlirContext ctx)84 MlirType mlirF32TypeGet(MlirContext ctx) {
85   return wrap(FloatType::getF32(unwrap(ctx)));
86 }
87 
mlirTypeIsAF64(MlirType type)88 bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
89 
mlirF64TypeGet(MlirContext ctx)90 MlirType mlirF64TypeGet(MlirContext ctx) {
91   return wrap(FloatType::getF64(unwrap(ctx)));
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // None type.
96 //===----------------------------------------------------------------------===//
97 
mlirTypeIsANone(MlirType type)98 bool mlirTypeIsANone(MlirType type) { return unwrap(type).isa<NoneType>(); }
99 
mlirNoneTypeGet(MlirContext ctx)100 MlirType mlirNoneTypeGet(MlirContext ctx) {
101   return wrap(NoneType::get(unwrap(ctx)));
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // Complex type.
106 //===----------------------------------------------------------------------===//
107 
mlirTypeIsAComplex(MlirType type)108 bool mlirTypeIsAComplex(MlirType type) {
109   return unwrap(type).isa<ComplexType>();
110 }
111 
mlirComplexTypeGet(MlirType elementType)112 MlirType mlirComplexTypeGet(MlirType elementType) {
113   return wrap(ComplexType::get(unwrap(elementType)));
114 }
115 
mlirComplexTypeGetElementType(MlirType type)116 MlirType mlirComplexTypeGetElementType(MlirType type) {
117   return wrap(unwrap(type).cast<ComplexType>().getElementType());
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // Shaped type.
122 //===----------------------------------------------------------------------===//
123 
mlirTypeIsAShaped(MlirType type)124 bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa<ShapedType>(); }
125 
mlirShapedTypeGetElementType(MlirType type)126 MlirType mlirShapedTypeGetElementType(MlirType type) {
127   return wrap(unwrap(type).cast<ShapedType>().getElementType());
128 }
129 
mlirShapedTypeHasRank(MlirType type)130 bool mlirShapedTypeHasRank(MlirType type) {
131   return unwrap(type).cast<ShapedType>().hasRank();
132 }
133 
mlirShapedTypeGetRank(MlirType type)134 int64_t mlirShapedTypeGetRank(MlirType type) {
135   return unwrap(type).cast<ShapedType>().getRank();
136 }
137 
mlirShapedTypeHasStaticShape(MlirType type)138 bool mlirShapedTypeHasStaticShape(MlirType type) {
139   return unwrap(type).cast<ShapedType>().hasStaticShape();
140 }
141 
mlirShapedTypeIsDynamicDim(MlirType type,intptr_t dim)142 bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
143   return unwrap(type).cast<ShapedType>().isDynamicDim(
144       static_cast<unsigned>(dim));
145 }
146 
mlirShapedTypeGetDimSize(MlirType type,intptr_t dim)147 int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
148   return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
149 }
150 
mlirShapedTypeIsDynamicSize(int64_t size)151 bool mlirShapedTypeIsDynamicSize(int64_t size) {
152   return ShapedType::isDynamic(size);
153 }
154 
mlirShapedTypeIsDynamicStrideOrOffset(int64_t val)155 bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
156   return ShapedType::isDynamicStrideOrOffset(val);
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // Vector type.
161 //===----------------------------------------------------------------------===//
162 
mlirTypeIsAVector(MlirType type)163 bool mlirTypeIsAVector(MlirType type) { return unwrap(type).isa<VectorType>(); }
164 
mlirVectorTypeGet(intptr_t rank,const int64_t * shape,MlirType elementType)165 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
166                            MlirType elementType) {
167   return wrap(
168       VectorType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
169                       unwrap(elementType)));
170 }
171 
mlirVectorTypeGetChecked(intptr_t rank,const int64_t * shape,MlirType elementType,MlirLocation loc)172 MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape,
173                                   MlirType elementType, MlirLocation loc) {
174   return wrap(VectorType::getChecked(
175       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
176       unwrap(elementType)));
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // Ranked / Unranked tensor type.
181 //===----------------------------------------------------------------------===//
182 
mlirTypeIsATensor(MlirType type)183 bool mlirTypeIsATensor(MlirType type) { return unwrap(type).isa<TensorType>(); }
184 
mlirTypeIsARankedTensor(MlirType type)185 bool mlirTypeIsARankedTensor(MlirType type) {
186   return unwrap(type).isa<RankedTensorType>();
187 }
188 
mlirTypeIsAUnrankedTensor(MlirType type)189 bool mlirTypeIsAUnrankedTensor(MlirType type) {
190   return unwrap(type).isa<UnrankedTensorType>();
191 }
192 
mlirRankedTensorTypeGet(intptr_t rank,const int64_t * shape,MlirType elementType)193 MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
194                                  MlirType elementType) {
195   return wrap(RankedTensorType::get(
196       llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
197       unwrap(elementType)));
198 }
199 
mlirRankedTensorTypeGetChecked(intptr_t rank,const int64_t * shape,MlirType elementType,MlirLocation loc)200 MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, const int64_t *shape,
201                                         MlirType elementType,
202                                         MlirLocation loc) {
203   return wrap(RankedTensorType::getChecked(
204       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
205       unwrap(elementType)));
206 }
207 
mlirUnrankedTensorTypeGet(MlirType elementType)208 MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
209   return wrap(UnrankedTensorType::get(unwrap(elementType)));
210 }
211 
mlirUnrankedTensorTypeGetChecked(MlirType elementType,MlirLocation loc)212 MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
213                                           MlirLocation loc) {
214   return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // Ranked / Unranked MemRef type.
219 //===----------------------------------------------------------------------===//
220 
mlirTypeIsAMemRef(MlirType type)221 bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
222 
mlirMemRefTypeGet(MlirType elementType,intptr_t rank,const int64_t * shape,intptr_t numMaps,MlirAffineMap const * affineMaps,unsigned memorySpace)223 MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
224                            const int64_t *shape, intptr_t numMaps,
225                            MlirAffineMap const *affineMaps,
226                            unsigned memorySpace) {
227   SmallVector<AffineMap, 1> maps;
228   (void)unwrapList(numMaps, affineMaps, maps);
229   return wrap(
230       MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
231                       unwrap(elementType), maps, memorySpace));
232 }
233 
mlirMemRefTypeGetChecked(MlirType elementType,intptr_t rank,const int64_t * shape,intptr_t numMaps,MlirAffineMap const * affineMaps,unsigned memorySpace,MlirLocation loc)234 MlirType mlirMemRefTypeGetChecked(MlirType elementType, intptr_t rank,
235                                   const int64_t *shape, intptr_t numMaps,
236                                   MlirAffineMap const *affineMaps,
237                                   unsigned memorySpace, MlirLocation loc) {
238   SmallVector<AffineMap, 1> maps;
239   (void)unwrapList(numMaps, affineMaps, maps);
240   return wrap(MemRefType::getChecked(
241       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
242       unwrap(elementType), maps, memorySpace));
243 }
244 
mlirMemRefTypeContiguousGet(MlirType elementType,intptr_t rank,const int64_t * shape,unsigned memorySpace)245 MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
246                                      const int64_t *shape,
247                                      unsigned memorySpace) {
248   return wrap(
249       MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
250                       unwrap(elementType), llvm::None, memorySpace));
251 }
252 
mlirMemRefTypeContiguousGetChecked(MlirType elementType,intptr_t rank,const int64_t * shape,unsigned memorySpace,MlirLocation loc)253 MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
254                                             const int64_t *shape,
255                                             unsigned memorySpace,
256                                             MlirLocation loc) {
257   return wrap(MemRefType::getChecked(
258       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
259       unwrap(elementType), llvm::None, memorySpace));
260 }
261 
mlirMemRefTypeGetNumAffineMaps(MlirType type)262 intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
263   return static_cast<intptr_t>(
264       unwrap(type).cast<MemRefType>().getAffineMaps().size());
265 }
266 
mlirMemRefTypeGetAffineMap(MlirType type,intptr_t pos)267 MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) {
268   return wrap(unwrap(type).cast<MemRefType>().getAffineMaps()[pos]);
269 }
270 
mlirMemRefTypeGetMemorySpace(MlirType type)271 unsigned mlirMemRefTypeGetMemorySpace(MlirType type) {
272   return unwrap(type).cast<MemRefType>().getMemorySpace();
273 }
274 
mlirTypeIsAUnrankedMemRef(MlirType type)275 bool mlirTypeIsAUnrankedMemRef(MlirType type) {
276   return unwrap(type).isa<UnrankedMemRefType>();
277 }
278 
mlirUnrankedMemRefTypeGet(MlirType elementType,unsigned memorySpace)279 MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
280   return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace));
281 }
282 
mlirUnrankedMemRefTypeGetChecked(MlirType elementType,unsigned memorySpace,MlirLocation loc)283 MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
284                                           unsigned memorySpace,
285                                           MlirLocation loc) {
286   return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType),
287                                              memorySpace));
288 }
289 
mlirUnrankedMemrefGetMemorySpace(MlirType type)290 unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
291   return unwrap(type).cast<UnrankedMemRefType>().getMemorySpace();
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // Tuple type.
296 //===----------------------------------------------------------------------===//
297 
mlirTypeIsATuple(MlirType type)298 bool mlirTypeIsATuple(MlirType type) { return unwrap(type).isa<TupleType>(); }
299 
mlirTupleTypeGet(MlirContext ctx,intptr_t numElements,MlirType const * elements)300 MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
301                           MlirType const *elements) {
302   SmallVector<Type, 4> types;
303   ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
304   return wrap(TupleType::get(unwrap(ctx), typeRef));
305 }
306 
mlirTupleTypeGetNumTypes(MlirType type)307 intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
308   return unwrap(type).cast<TupleType>().size();
309 }
310 
mlirTupleTypeGetType(MlirType type,intptr_t pos)311 MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
312   return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // Function type.
317 //===----------------------------------------------------------------------===//
318 
mlirTypeIsAFunction(MlirType type)319 bool mlirTypeIsAFunction(MlirType type) {
320   return unwrap(type).isa<FunctionType>();
321 }
322 
mlirFunctionTypeGet(MlirContext ctx,intptr_t numInputs,MlirType const * inputs,intptr_t numResults,MlirType const * results)323 MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
324                              MlirType const *inputs, intptr_t numResults,
325                              MlirType const *results) {
326   SmallVector<Type, 4> inputsList;
327   SmallVector<Type, 4> resultsList;
328   (void)unwrapList(numInputs, inputs, inputsList);
329   (void)unwrapList(numResults, results, resultsList);
330   return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList));
331 }
332 
mlirFunctionTypeGetNumInputs(MlirType type)333 intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
334   return unwrap(type).cast<FunctionType>().getNumInputs();
335 }
336 
mlirFunctionTypeGetNumResults(MlirType type)337 intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
338   return unwrap(type).cast<FunctionType>().getNumResults();
339 }
340 
mlirFunctionTypeGetInput(MlirType type,intptr_t pos)341 MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
342   assert(pos >= 0 && "pos in array must be positive");
343   return wrap(
344       unwrap(type).cast<FunctionType>().getInput(static_cast<unsigned>(pos)));
345 }
346 
mlirFunctionTypeGetResult(MlirType type,intptr_t pos)347 MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
348   assert(pos >= 0 && "pos in array must be positive");
349   return wrap(
350       unwrap(type).cast<FunctionType>().getResult(static_cast<unsigned>(pos)));
351 }
352