1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/IR.h"
10 #include "mlir-c/Support.h"
11 
12 #include "mlir/CAPI/IR.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Module.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/Types.h"
20 #include "mlir/Parser.h"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Context API.
26 //===----------------------------------------------------------------------===//
27 
mlirContextCreate()28 MlirContext mlirContextCreate() {
29   auto *context = new MLIRContext;
30   return wrap(context);
31 }
32 
mlirContextEqual(MlirContext ctx1,MlirContext ctx2)33 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
34   return unwrap(ctx1) == unwrap(ctx2);
35 }
36 
mlirContextDestroy(MlirContext context)37 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
38 
mlirContextSetAllowUnregisteredDialects(MlirContext context,int allow)39 void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) {
40   unwrap(context)->allowUnregisteredDialects(allow);
41 }
42 
mlirContextGetAllowUnregisteredDialects(MlirContext context)43 int mlirContextGetAllowUnregisteredDialects(MlirContext context) {
44   return unwrap(context)->allowsUnregisteredDialects();
45 }
mlirContextGetNumRegisteredDialects(MlirContext context)46 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
47   return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
48 }
49 
50 // TODO: expose a cheaper way than constructing + sorting a vector only to take
51 // its size.
mlirContextGetNumLoadedDialects(MlirContext context)52 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
53   return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
54 }
55 
mlirContextGetOrLoadDialect(MlirContext context,MlirStringRef name)56 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
57                                         MlirStringRef name) {
58   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // Dialect API.
63 //===----------------------------------------------------------------------===//
64 
mlirDialectGetContext(MlirDialect dialect)65 MlirContext mlirDialectGetContext(MlirDialect dialect) {
66   return wrap(unwrap(dialect)->getContext());
67 }
68 
mlirDialectEqual(MlirDialect dialect1,MlirDialect dialect2)69 int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
70   return unwrap(dialect1) == unwrap(dialect2);
71 }
72 
mlirDialectGetNamespace(MlirDialect dialect)73 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
74   return wrap(unwrap(dialect)->getNamespace());
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Printing flags API.
79 //===----------------------------------------------------------------------===//
80 
mlirOpPrintingFlagsCreate()81 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
82   return wrap(new OpPrintingFlags());
83 }
84 
mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)85 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
86   delete unwrap(flags);
87 }
88 
mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,intptr_t largeElementLimit)89 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
90                                                 intptr_t largeElementLimit) {
91   unwrap(flags)->elideLargeElementsAttrs(largeElementLimit);
92 }
93 
mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,int prettyForm)94 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
95                                         int prettyForm) {
96   unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm);
97 }
98 
mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)99 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
100   unwrap(flags)->printGenericOpForm();
101 }
102 
mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)103 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
104   unwrap(flags)->useLocalScope();
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // Location API.
109 //===----------------------------------------------------------------------===//
110 
mlirLocationFileLineColGet(MlirContext context,const char * filename,unsigned line,unsigned col)111 MlirLocation mlirLocationFileLineColGet(MlirContext context,
112                                         const char *filename, unsigned line,
113                                         unsigned col) {
114   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
115 }
116 
mlirLocationUnknownGet(MlirContext context)117 MlirLocation mlirLocationUnknownGet(MlirContext context) {
118   return wrap(UnknownLoc::get(unwrap(context)));
119 }
120 
mlirLocationGetContext(MlirLocation location)121 MlirContext mlirLocationGetContext(MlirLocation location) {
122   return wrap(unwrap(location).getContext());
123 }
124 
mlirLocationPrint(MlirLocation location,MlirStringCallback callback,void * userData)125 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
126                        void *userData) {
127   detail::CallbackOstream stream(callback, userData);
128   unwrap(location).print(stream);
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // Module API.
133 //===----------------------------------------------------------------------===//
134 
mlirModuleCreateEmpty(MlirLocation location)135 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
136   return wrap(ModuleOp::create(unwrap(location)));
137 }
138 
mlirModuleCreateParse(MlirContext context,const char * module)139 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
140   OwningModuleRef owning = parseSourceString(module, unwrap(context));
141   if (!owning)
142     return MlirModule{nullptr};
143   return MlirModule{owning.release().getOperation()};
144 }
145 
mlirModuleGetContext(MlirModule module)146 MlirContext mlirModuleGetContext(MlirModule module) {
147   return wrap(unwrap(module).getContext());
148 }
149 
mlirModuleGetBody(MlirModule module)150 MlirBlock mlirModuleGetBody(MlirModule module) {
151   return wrap(unwrap(module).getBody());
152 }
153 
mlirModuleDestroy(MlirModule module)154 void mlirModuleDestroy(MlirModule module) {
155   // Transfer ownership to an OwningModuleRef so that its destructor is called.
156   OwningModuleRef(unwrap(module));
157 }
158 
mlirModuleGetOperation(MlirModule module)159 MlirOperation mlirModuleGetOperation(MlirModule module) {
160   return wrap(unwrap(module).getOperation());
161 }
162 
163 //===----------------------------------------------------------------------===//
164 // Operation state API.
165 //===----------------------------------------------------------------------===//
166 
mlirOperationStateGet(const char * name,MlirLocation loc)167 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
168   MlirOperationState state;
169   state.name = name;
170   state.location = loc;
171   state.nResults = 0;
172   state.results = nullptr;
173   state.nOperands = 0;
174   state.operands = nullptr;
175   state.nRegions = 0;
176   state.regions = nullptr;
177   state.nSuccessors = 0;
178   state.successors = nullptr;
179   state.nAttributes = 0;
180   state.attributes = nullptr;
181   return state;
182 }
183 
// Appends the `n` elements of the array `elemName` (a local at the expansion
// site) to the array of the same name stored in `*state`, growing it with
// realloc and increasing `state->sizeName` by `n`. Expects `state` and `n` to
// be in scope where the macro is expanded.
//
// The body is wrapped in do/while so that the expansion is a single statement
// (safe under an unbraced if/else), and it does nothing for a non-positive `n`
// so that memcpy is never called with a null source or destination.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    state->elemName = static_cast<type *>(                                     \
        realloc(state->elemName, (state->sizeName + n) * sizeof(type)));       \
    memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));     \
    state->sizeName += n;                                                      \
  } while (false)
189 
mlirOperationStateAddResults(MlirOperationState * state,intptr_t n,MlirType * results)190 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
191                                   MlirType *results) {
192   APPEND_ELEMS(MlirType, nResults, results);
193 }
194 
mlirOperationStateAddOperands(MlirOperationState * state,intptr_t n,MlirValue * operands)195 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
196                                    MlirValue *operands) {
197   APPEND_ELEMS(MlirValue, nOperands, operands);
198 }
mlirOperationStateAddOwnedRegions(MlirOperationState * state,intptr_t n,MlirRegion * regions)199 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
200                                        MlirRegion *regions) {
201   APPEND_ELEMS(MlirRegion, nRegions, regions);
202 }
mlirOperationStateAddSuccessors(MlirOperationState * state,intptr_t n,MlirBlock * successors)203 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
204                                      MlirBlock *successors) {
205   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
206 }
mlirOperationStateAddAttributes(MlirOperationState * state,intptr_t n,MlirNamedAttribute * attributes)207 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
208                                      MlirNamedAttribute *attributes) {
209   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // Operation API.
214 //===----------------------------------------------------------------------===//
215 
mlirOperationCreate(const MlirOperationState * state)216 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
217   assert(state);
218   OperationState cppState(unwrap(state->location), state->name);
219   SmallVector<Type, 4> resultStorage;
220   SmallVector<Value, 8> operandStorage;
221   SmallVector<Block *, 2> successorStorage;
222   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
223   cppState.addOperands(
224       unwrapList(state->nOperands, state->operands, operandStorage));
225   cppState.addSuccessors(
226       unwrapList(state->nSuccessors, state->successors, successorStorage));
227 
228   cppState.attributes.reserve(state->nAttributes);
229   for (intptr_t i = 0; i < state->nAttributes; ++i)
230     cppState.addAttribute(state->attributes[i].name,
231                           unwrap(state->attributes[i].attribute));
232 
233   for (intptr_t i = 0; i < state->nRegions; ++i)
234     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
235 
236   MlirOperation result = wrap(Operation::create(cppState));
237   free(state->results);
238   free(state->operands);
239   free(state->successors);
240   free(state->regions);
241   free(state->attributes);
242   return result;
243 }
244 
mlirOperationDestroy(MlirOperation op)245 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
246 
mlirOperationEqual(MlirOperation op,MlirOperation other)247 int mlirOperationEqual(MlirOperation op, MlirOperation other) {
248   return unwrap(op) == unwrap(other);
249 }
250 
mlirOperationGetName(MlirOperation op)251 MlirIdentifier mlirOperationGetName(MlirOperation op) {
252   return wrap(unwrap(op)->getName().getIdentifier());
253 }
254 
mlirOperationGetBlock(MlirOperation op)255 MlirBlock mlirOperationGetBlock(MlirOperation op) {
256   return wrap(unwrap(op)->getBlock());
257 }
258 
mlirOperationGetParentOperation(MlirOperation op)259 MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
260   return wrap(unwrap(op)->getParentOp());
261 }
262 
mlirOperationGetNumRegions(MlirOperation op)263 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
264   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
265 }
266 
mlirOperationGetRegion(MlirOperation op,intptr_t pos)267 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
268   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
269 }
270 
mlirOperationGetNextInBlock(MlirOperation op)271 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
272   return wrap(unwrap(op)->getNextNode());
273 }
274 
mlirOperationGetNumOperands(MlirOperation op)275 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
276   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
277 }
278 
mlirOperationGetOperand(MlirOperation op,intptr_t pos)279 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
280   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
281 }
282 
mlirOperationGetNumResults(MlirOperation op)283 intptr_t mlirOperationGetNumResults(MlirOperation op) {
284   return static_cast<intptr_t>(unwrap(op)->getNumResults());
285 }
286 
mlirOperationGetResult(MlirOperation op,intptr_t pos)287 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
288   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
289 }
290 
mlirOperationGetNumSuccessors(MlirOperation op)291 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
292   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
293 }
294 
mlirOperationGetSuccessor(MlirOperation op,intptr_t pos)295 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
296   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
297 }
298 
mlirOperationGetNumAttributes(MlirOperation op)299 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
300   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
301 }
302 
mlirOperationGetAttribute(MlirOperation op,intptr_t pos)303 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
304   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
305   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
306 }
307 
mlirOperationGetAttributeByName(MlirOperation op,const char * name)308 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
309                                               const char *name) {
310   return wrap(unwrap(op)->getAttr(name));
311 }
312 
mlirOperationSetAttributeByName(MlirOperation op,const char * name,MlirAttribute attr)313 void mlirOperationSetAttributeByName(MlirOperation op, const char *name,
314                                      MlirAttribute attr) {
315   unwrap(op)->setAttr(name, unwrap(attr));
316 }
317 
mlirOperationRemoveAttributeByName(MlirOperation op,const char * name)318 int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) {
319   auto removeResult = unwrap(op)->removeAttr(name);
320   return removeResult == MutableDictionaryAttr::RemoveResult::Removed;
321 }
322 
mlirOperationPrint(MlirOperation op,MlirStringCallback callback,void * userData)323 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
324                         void *userData) {
325   detail::CallbackOstream stream(callback, userData);
326   unwrap(op)->print(stream);
327 }
328 
mlirOperationPrintWithFlags(MlirOperation op,MlirOpPrintingFlags flags,MlirStringCallback callback,void * userData)329 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
330                                  MlirStringCallback callback, void *userData) {
331   detail::CallbackOstream stream(callback, userData);
332   unwrap(op)->print(stream, *unwrap(flags));
333 }
334 
mlirOperationDump(MlirOperation op)335 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
336 
337 //===----------------------------------------------------------------------===//
338 // Region API.
339 //===----------------------------------------------------------------------===//
340 
mlirRegionCreate()341 MlirRegion mlirRegionCreate() { return wrap(new Region); }
342 
mlirRegionGetFirstBlock(MlirRegion region)343 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
344   Region *cppRegion = unwrap(region);
345   if (cppRegion->empty())
346     return wrap(static_cast<Block *>(nullptr));
347   return wrap(&cppRegion->front());
348 }
349 
mlirRegionAppendOwnedBlock(MlirRegion region,MlirBlock block)350 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
351   unwrap(region)->push_back(unwrap(block));
352 }
353 
mlirRegionInsertOwnedBlock(MlirRegion region,intptr_t pos,MlirBlock block)354 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
355                                 MlirBlock block) {
356   auto &blockList = unwrap(region)->getBlocks();
357   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
358 }
359 
mlirRegionInsertOwnedBlockAfter(MlirRegion region,MlirBlock reference,MlirBlock block)360 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
361                                      MlirBlock block) {
362   Region *cppRegion = unwrap(region);
363   if (mlirBlockIsNull(reference)) {
364     cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
365     return;
366   }
367 
368   assert(unwrap(reference)->getParent() == unwrap(region) &&
369          "expected reference block to belong to the region");
370   cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
371                                      unwrap(block));
372 }
373 
mlirRegionInsertOwnedBlockBefore(MlirRegion region,MlirBlock reference,MlirBlock block)374 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
375                                       MlirBlock block) {
376   if (mlirBlockIsNull(reference))
377     return mlirRegionAppendOwnedBlock(region, block);
378 
379   assert(unwrap(reference)->getParent() == unwrap(region) &&
380          "expected reference block to belong to the region");
381   unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
382                                      unwrap(block));
383 }
384 
mlirRegionDestroy(MlirRegion region)385 void mlirRegionDestroy(MlirRegion region) {
386   delete static_cast<Region *>(region.ptr);
387 }
388 
389 //===----------------------------------------------------------------------===//
390 // Block API.
391 //===----------------------------------------------------------------------===//
392 
mlirBlockCreate(intptr_t nArgs,MlirType * args)393 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
394   Block *b = new Block;
395   for (intptr_t i = 0; i < nArgs; ++i)
396     b->addArgument(unwrap(args[i]));
397   return wrap(b);
398 }
399 
mlirBlockEqual(MlirBlock block,MlirBlock other)400 int mlirBlockEqual(MlirBlock block, MlirBlock other) {
401   return unwrap(block) == unwrap(other);
402 }
403 
mlirBlockGetNextInRegion(MlirBlock block)404 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
405   return wrap(unwrap(block)->getNextNode());
406 }
407 
mlirBlockGetFirstOperation(MlirBlock block)408 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
409   Block *cppBlock = unwrap(block);
410   if (cppBlock->empty())
411     return wrap(static_cast<Operation *>(nullptr));
412   return wrap(&cppBlock->front());
413 }
414 
mlirBlockGetTerminator(MlirBlock block)415 MlirOperation mlirBlockGetTerminator(MlirBlock block) {
416   Block *cppBlock = unwrap(block);
417   if (cppBlock->empty())
418     return wrap(static_cast<Operation *>(nullptr));
419   Operation &back = cppBlock->back();
420   if (!back.isKnownTerminator())
421     return wrap(static_cast<Operation *>(nullptr));
422   return wrap(&back);
423 }
424 
mlirBlockAppendOwnedOperation(MlirBlock block,MlirOperation operation)425 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
426   unwrap(block)->push_back(unwrap(operation));
427 }
428 
mlirBlockInsertOwnedOperation(MlirBlock block,intptr_t pos,MlirOperation operation)429 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
430                                    MlirOperation operation) {
431   auto &opList = unwrap(block)->getOperations();
432   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
433 }
434 
mlirBlockInsertOwnedOperationAfter(MlirBlock block,MlirOperation reference,MlirOperation operation)435 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
436                                         MlirOperation reference,
437                                         MlirOperation operation) {
438   Block *cppBlock = unwrap(block);
439   if (mlirOperationIsNull(reference)) {
440     cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
441     return;
442   }
443 
444   assert(unwrap(reference)->getBlock() == unwrap(block) &&
445          "expected reference operation to belong to the block");
446   cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
447                                         unwrap(operation));
448 }
449 
mlirBlockInsertOwnedOperationBefore(MlirBlock block,MlirOperation reference,MlirOperation operation)450 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
451                                          MlirOperation reference,
452                                          MlirOperation operation) {
453   if (mlirOperationIsNull(reference))
454     return mlirBlockAppendOwnedOperation(block, operation);
455 
456   assert(unwrap(reference)->getBlock() == unwrap(block) &&
457          "expected reference operation to belong to the block");
458   unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
459                                         unwrap(operation));
460 }
461 
mlirBlockDestroy(MlirBlock block)462 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
463 
mlirBlockGetNumArguments(MlirBlock block)464 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
465   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
466 }
467 
mlirBlockGetArgument(MlirBlock block,intptr_t pos)468 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
469   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
470 }
471 
mlirBlockPrint(MlirBlock block,MlirStringCallback callback,void * userData)472 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
473                     void *userData) {
474   detail::CallbackOstream stream(callback, userData);
475   unwrap(block)->print(stream);
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // Value API.
480 //===----------------------------------------------------------------------===//
481 
mlirValueEqual(MlirValue value1,MlirValue value2)482 int mlirValueEqual(MlirValue value1, MlirValue value2) {
483   return unwrap(value1) == unwrap(value2);
484 }
485 
mlirValueIsABlockArgument(MlirValue value)486 int mlirValueIsABlockArgument(MlirValue value) {
487   return unwrap(value).isa<BlockArgument>();
488 }
489 
mlirValueIsAOpResult(MlirValue value)490 int mlirValueIsAOpResult(MlirValue value) {
491   return unwrap(value).isa<OpResult>();
492 }
493 
mlirBlockArgumentGetOwner(MlirValue value)494 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
495   return wrap(unwrap(value).cast<BlockArgument>().getOwner());
496 }
497 
mlirBlockArgumentGetArgNumber(MlirValue value)498 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
499   return static_cast<intptr_t>(
500       unwrap(value).cast<BlockArgument>().getArgNumber());
501 }
502 
mlirBlockArgumentSetType(MlirValue value,MlirType type)503 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
504   unwrap(value).cast<BlockArgument>().setType(unwrap(type));
505 }
506 
mlirOpResultGetOwner(MlirValue value)507 MlirOperation mlirOpResultGetOwner(MlirValue value) {
508   return wrap(unwrap(value).cast<OpResult>().getOwner());
509 }
510 
mlirOpResultGetResultNumber(MlirValue value)511 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
512   return static_cast<intptr_t>(
513       unwrap(value).cast<OpResult>().getResultNumber());
514 }
515 
mlirValueGetType(MlirValue value)516 MlirType mlirValueGetType(MlirValue value) {
517   return wrap(unwrap(value).getType());
518 }
519 
mlirValueDump(MlirValue value)520 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
521 
mlirValuePrint(MlirValue value,MlirStringCallback callback,void * userData)522 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
523                     void *userData) {
524   detail::CallbackOstream stream(callback, userData);
525   unwrap(value).print(stream);
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // Type API.
530 //===----------------------------------------------------------------------===//
531 
mlirTypeParseGet(MlirContext context,const char * type)532 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
533   return wrap(mlir::parseType(type, unwrap(context)));
534 }
535 
mlirTypeGetContext(MlirType type)536 MlirContext mlirTypeGetContext(MlirType type) {
537   return wrap(unwrap(type).getContext());
538 }
539 
mlirTypeEqual(MlirType t1,MlirType t2)540 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
541 
mlirTypePrint(MlirType type,MlirStringCallback callback,void * userData)542 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
543   detail::CallbackOstream stream(callback, userData);
544   unwrap(type).print(stream);
545 }
546 
mlirTypeDump(MlirType type)547 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
548 
549 //===----------------------------------------------------------------------===//
550 // Attribute API.
551 //===----------------------------------------------------------------------===//
552 
mlirAttributeParseGet(MlirContext context,const char * attr)553 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
554   return wrap(mlir::parseAttribute(attr, unwrap(context)));
555 }
556 
mlirAttributeGetContext(MlirAttribute attribute)557 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
558   return wrap(unwrap(attribute).getContext());
559 }
560 
mlirAttributeGetType(MlirAttribute attribute)561 MlirType mlirAttributeGetType(MlirAttribute attribute) {
562   return wrap(unwrap(attribute).getType());
563 }
564 
mlirAttributeEqual(MlirAttribute a1,MlirAttribute a2)565 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
566   return unwrap(a1) == unwrap(a2);
567 }
568 
mlirAttributePrint(MlirAttribute attr,MlirStringCallback callback,void * userData)569 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
570                         void *userData) {
571   detail::CallbackOstream stream(callback, userData);
572   unwrap(attr).print(stream);
573 }
574 
mlirAttributeDump(MlirAttribute attr)575 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
576 
mlirNamedAttributeGet(const char * name,MlirAttribute attr)577 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
578   return MlirNamedAttribute{name, attr};
579 }
580 
581 //===----------------------------------------------------------------------===//
582 // Identifier API.
583 //===----------------------------------------------------------------------===//
584 
mlirIdentifierGet(MlirContext context,MlirStringRef str)585 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
586   return wrap(Identifier::get(unwrap(str), unwrap(context)));
587 }
588 
mlirIdentifierEqual(MlirIdentifier ident,MlirIdentifier other)589 int mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
590   return unwrap(ident) == unwrap(other);
591 }
592 
mlirIdentifierStr(MlirIdentifier ident)593 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
594   return wrap(unwrap(ident).strref());
595 }
596