1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/IR.h"
10 #include "mlir-c/Support.h"
11
12 #include "mlir/CAPI/IR.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/Dialect.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/Types.h"
20 #include "mlir/IR/Verifier.h"
21 #include "mlir/Interfaces/InferTypeOpInterface.h"
22 #include "mlir/Parser.h"
23
24 #include "llvm/Support/Debug.h"
25
26 using namespace mlir;
27
28 //===----------------------------------------------------------------------===//
29 // Context API.
30 //===----------------------------------------------------------------------===//
31
mlirContextCreate()32 MlirContext mlirContextCreate() {
33 auto *context = new MLIRContext;
34 return wrap(context);
35 }
36
mlirContextEqual(MlirContext ctx1,MlirContext ctx2)37 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
38 return unwrap(ctx1) == unwrap(ctx2);
39 }
40
mlirContextDestroy(MlirContext context)41 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
42
mlirContextSetAllowUnregisteredDialects(MlirContext context,bool allow)43 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
44 unwrap(context)->allowUnregisteredDialects(allow);
45 }
46
mlirContextGetAllowUnregisteredDialects(MlirContext context)47 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
48 return unwrap(context)->allowsUnregisteredDialects();
49 }
mlirContextGetNumRegisteredDialects(MlirContext context)50 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
51 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
52 }
53
54 // TODO: expose a cheaper way than constructing + sorting a vector only to take
55 // its size.
mlirContextGetNumLoadedDialects(MlirContext context)56 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
57 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
58 }
59
mlirContextGetOrLoadDialect(MlirContext context,MlirStringRef name)60 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
61 MlirStringRef name) {
62 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
63 }
64
mlirContextIsRegisteredOperation(MlirContext context,MlirStringRef name)65 bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
66 return unwrap(context)->isOperationRegistered(unwrap(name));
67 }
68
mlirContextEnableMultithreading(MlirContext context,bool enable)69 void mlirContextEnableMultithreading(MlirContext context, bool enable) {
70 return unwrap(context)->enableMultithreading(enable);
71 }
72
73 //===----------------------------------------------------------------------===//
74 // Dialect API.
75 //===----------------------------------------------------------------------===//
76
mlirDialectGetContext(MlirDialect dialect)77 MlirContext mlirDialectGetContext(MlirDialect dialect) {
78 return wrap(unwrap(dialect)->getContext());
79 }
80
mlirDialectEqual(MlirDialect dialect1,MlirDialect dialect2)81 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
82 return unwrap(dialect1) == unwrap(dialect2);
83 }
84
mlirDialectGetNamespace(MlirDialect dialect)85 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
86 return wrap(unwrap(dialect)->getNamespace());
87 }
88
89 //===----------------------------------------------------------------------===//
90 // Printing flags API.
91 //===----------------------------------------------------------------------===//
92
mlirOpPrintingFlagsCreate()93 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
94 return wrap(new OpPrintingFlags());
95 }
96
mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)97 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
98 delete unwrap(flags);
99 }
100
mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,intptr_t largeElementLimit)101 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
102 intptr_t largeElementLimit) {
103 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit);
104 }
105
mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,bool prettyForm)106 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
107 bool prettyForm) {
108 unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm);
109 }
110
mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)111 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
112 unwrap(flags)->printGenericOpForm();
113 }
114
mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)115 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
116 unwrap(flags)->useLocalScope();
117 }
118
119 //===----------------------------------------------------------------------===//
120 // Location API.
121 //===----------------------------------------------------------------------===//
122
mlirLocationFileLineColGet(MlirContext context,MlirStringRef filename,unsigned line,unsigned col)123 MlirLocation mlirLocationFileLineColGet(MlirContext context,
124 MlirStringRef filename, unsigned line,
125 unsigned col) {
126 return wrap(Location(
127 FileLineColLoc::get(unwrap(context), unwrap(filename), line, col)));
128 }
129
mlirLocationCallSiteGet(MlirLocation callee,MlirLocation caller)130 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
131 return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
132 }
133
mlirLocationUnknownGet(MlirContext context)134 MlirLocation mlirLocationUnknownGet(MlirContext context) {
135 return wrap(Location(UnknownLoc::get(unwrap(context))));
136 }
137
mlirLocationEqual(MlirLocation l1,MlirLocation l2)138 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
139 return unwrap(l1) == unwrap(l2);
140 }
141
mlirLocationGetContext(MlirLocation location)142 MlirContext mlirLocationGetContext(MlirLocation location) {
143 return wrap(unwrap(location).getContext());
144 }
145
mlirLocationPrint(MlirLocation location,MlirStringCallback callback,void * userData)146 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
147 void *userData) {
148 detail::CallbackOstream stream(callback, userData);
149 unwrap(location).print(stream);
150 }
151
152 //===----------------------------------------------------------------------===//
153 // Module API.
154 //===----------------------------------------------------------------------===//
155
mlirModuleCreateEmpty(MlirLocation location)156 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
157 return wrap(ModuleOp::create(unwrap(location)));
158 }
159
mlirModuleCreateParse(MlirContext context,MlirStringRef module)160 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
161 OwningModuleRef owning = parseSourceString(unwrap(module), unwrap(context));
162 if (!owning)
163 return MlirModule{nullptr};
164 return MlirModule{owning.release().getOperation()};
165 }
166
mlirModuleGetContext(MlirModule module)167 MlirContext mlirModuleGetContext(MlirModule module) {
168 return wrap(unwrap(module).getContext());
169 }
170
mlirModuleGetBody(MlirModule module)171 MlirBlock mlirModuleGetBody(MlirModule module) {
172 return wrap(unwrap(module).getBody());
173 }
174
mlirModuleDestroy(MlirModule module)175 void mlirModuleDestroy(MlirModule module) {
176 // Transfer ownership to an OwningModuleRef so that its destructor is called.
177 OwningModuleRef(unwrap(module));
178 }
179
mlirModuleGetOperation(MlirModule module)180 MlirOperation mlirModuleGetOperation(MlirModule module) {
181 return wrap(unwrap(module).getOperation());
182 }
183
mlirModuleFromOperation(MlirOperation op)184 MlirModule mlirModuleFromOperation(MlirOperation op) {
185 return wrap(dyn_cast<ModuleOp>(unwrap(op)));
186 }
187
188 //===----------------------------------------------------------------------===//
189 // Operation state API.
190 //===----------------------------------------------------------------------===//
191
mlirOperationStateGet(MlirStringRef name,MlirLocation loc)192 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
193 MlirOperationState state;
194 state.name = name;
195 state.location = loc;
196 state.nResults = 0;
197 state.results = nullptr;
198 state.nOperands = 0;
199 state.operands = nullptr;
200 state.nRegions = 0;
201 state.regions = nullptr;
202 state.nSuccessors = 0;
203 state.successors = nullptr;
204 state.nAttributes = 0;
205 state.attributes = nullptr;
206 state.enableResultTypeInference = false;
207 return state;
208 }
209
// Appends the `n` elements of the array `elemName` to the heap array
// `state->elemName` (grown with realloc) and adds `n` to `state->sizeName`.
// `state` and `n` must be in scope at the expansion site.
//
// The expansion is a single statement (do/while), and it does nothing when
// `n` is not positive so that realloc and memcpy are never invoked with a
// zero or negative element count or with a null source array. The realloc
// result is checked before it replaces the previous pointer.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n > 0) {                                                               \
      type *grown = static_cast<type *>(realloc(                               \
          state->elemName, (state->sizeName + n) * sizeof(type)));             \
      assert(grown && "failed to grow operation state storage");               \
      memcpy(grown + state->sizeName, elemName, n * sizeof(type));             \
      state->elemName = grown;                                                 \
      state->sizeName += n;                                                    \
    }                                                                          \
  } while (false)
215
mlirOperationStateAddResults(MlirOperationState * state,intptr_t n,MlirType const * results)216 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
217 MlirType const *results) {
218 APPEND_ELEMS(MlirType, nResults, results);
219 }
220
mlirOperationStateAddOperands(MlirOperationState * state,intptr_t n,MlirValue const * operands)221 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
222 MlirValue const *operands) {
223 APPEND_ELEMS(MlirValue, nOperands, operands);
224 }
mlirOperationStateAddOwnedRegions(MlirOperationState * state,intptr_t n,MlirRegion const * regions)225 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
226 MlirRegion const *regions) {
227 APPEND_ELEMS(MlirRegion, nRegions, regions);
228 }
mlirOperationStateAddSuccessors(MlirOperationState * state,intptr_t n,MlirBlock const * successors)229 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
230 MlirBlock const *successors) {
231 APPEND_ELEMS(MlirBlock, nSuccessors, successors);
232 }
mlirOperationStateAddAttributes(MlirOperationState * state,intptr_t n,MlirNamedAttribute const * attributes)233 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
234 MlirNamedAttribute const *attributes) {
235 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
236 }
237
mlirOperationStateEnableResultTypeInference(MlirOperationState * state)238 void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
239 state->enableResultTypeInference = true;
240 }
241
242 //===----------------------------------------------------------------------===//
243 // Operation API.
244 //===----------------------------------------------------------------------===//
245
inferOperationTypes(OperationState & state)246 static LogicalResult inferOperationTypes(OperationState &state) {
247 MLIRContext *context = state.getContext();
248 const AbstractOperation *abstractOp =
249 AbstractOperation::lookup(state.name.getStringRef(), context);
250 if (!abstractOp) {
251 emitError(state.location)
252 << "type inference was requested for the operation " << state.name
253 << ", but the operation was not registered. Ensure that the dialect "
254 "containing the operation is linked into MLIR and registered with "
255 "the context";
256 return failure();
257 }
258
259 // Fallback to inference via an op interface.
260 auto *inferInterface = abstractOp->getInterface<InferTypeOpInterface>();
261 if (!inferInterface) {
262 emitError(state.location)
263 << "type inference was requested for the operation " << state.name
264 << ", but the operation does not support type inference. Result "
265 "types must be specified explicitly.";
266 return failure();
267 }
268
269 if (succeeded(inferInterface->inferReturnTypes(
270 context, state.location, state.operands,
271 state.attributes.getDictionary(context), state.regions, state.types)))
272 return success();
273
274 // Diagnostic emitted by interface.
275 return failure();
276 }
277
mlirOperationCreate(MlirOperationState * state)278 MlirOperation mlirOperationCreate(MlirOperationState *state) {
279 assert(state);
280 OperationState cppState(unwrap(state->location), unwrap(state->name));
281 SmallVector<Type, 4> resultStorage;
282 SmallVector<Value, 8> operandStorage;
283 SmallVector<Block *, 2> successorStorage;
284 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
285 cppState.addOperands(
286 unwrapList(state->nOperands, state->operands, operandStorage));
287 cppState.addSuccessors(
288 unwrapList(state->nSuccessors, state->successors, successorStorage));
289
290 cppState.attributes.reserve(state->nAttributes);
291 for (intptr_t i = 0; i < state->nAttributes; ++i)
292 cppState.addAttribute(unwrap(state->attributes[i].name),
293 unwrap(state->attributes[i].attribute));
294
295 for (intptr_t i = 0; i < state->nRegions; ++i)
296 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
297
298 free(state->results);
299 free(state->operands);
300 free(state->successors);
301 free(state->regions);
302 free(state->attributes);
303
304 // Infer result types.
305 if (state->enableResultTypeInference) {
306 assert(cppState.types.empty() &&
307 "result type inference enabled and result types provided");
308 if (failed(inferOperationTypes(cppState)))
309 return {nullptr};
310 }
311
312 MlirOperation result = wrap(Operation::create(cppState));
313 return result;
314 }
315
mlirOperationClone(MlirOperation op)316 MlirOperation mlirOperationClone(MlirOperation op) {
317 return wrap(unwrap(op)->clone());
318 }
319
mlirOperationDestroy(MlirOperation op)320 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
321
mlirOperationEqual(MlirOperation op,MlirOperation other)322 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
323 return unwrap(op) == unwrap(other);
324 }
325
mlirOperationGetContext(MlirOperation op)326 MlirContext mlirOperationGetContext(MlirOperation op) {
327 return wrap(unwrap(op)->getContext());
328 }
329
mlirOperationGetName(MlirOperation op)330 MlirIdentifier mlirOperationGetName(MlirOperation op) {
331 return wrap(unwrap(op)->getName().getIdentifier());
332 }
333
mlirOperationGetBlock(MlirOperation op)334 MlirBlock mlirOperationGetBlock(MlirOperation op) {
335 return wrap(unwrap(op)->getBlock());
336 }
337
mlirOperationGetParentOperation(MlirOperation op)338 MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
339 return wrap(unwrap(op)->getParentOp());
340 }
341
mlirOperationGetNumRegions(MlirOperation op)342 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
343 return static_cast<intptr_t>(unwrap(op)->getNumRegions());
344 }
345
mlirOperationGetRegion(MlirOperation op,intptr_t pos)346 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
347 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
348 }
349
mlirOperationGetNextInBlock(MlirOperation op)350 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
351 return wrap(unwrap(op)->getNextNode());
352 }
353
mlirOperationGetNumOperands(MlirOperation op)354 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
355 return static_cast<intptr_t>(unwrap(op)->getNumOperands());
356 }
357
mlirOperationGetOperand(MlirOperation op,intptr_t pos)358 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
359 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
360 }
361
mlirOperationSetOperand(MlirOperation op,intptr_t pos,MlirValue newValue)362 void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
363 MlirValue newValue) {
364 unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
365 }
366
mlirOperationGetNumResults(MlirOperation op)367 intptr_t mlirOperationGetNumResults(MlirOperation op) {
368 return static_cast<intptr_t>(unwrap(op)->getNumResults());
369 }
370
mlirOperationGetResult(MlirOperation op,intptr_t pos)371 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
372 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
373 }
374
mlirOperationGetNumSuccessors(MlirOperation op)375 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
376 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
377 }
378
mlirOperationGetSuccessor(MlirOperation op,intptr_t pos)379 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
380 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
381 }
382
mlirOperationGetNumAttributes(MlirOperation op)383 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
384 return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
385 }
386
mlirOperationGetAttribute(MlirOperation op,intptr_t pos)387 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
388 NamedAttribute attr = unwrap(op)->getAttrs()[pos];
389 return MlirNamedAttribute{wrap(attr.first), wrap(attr.second)};
390 }
391
mlirOperationGetAttributeByName(MlirOperation op,MlirStringRef name)392 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
393 MlirStringRef name) {
394 return wrap(unwrap(op)->getAttr(unwrap(name)));
395 }
396
mlirOperationSetAttributeByName(MlirOperation op,MlirStringRef name,MlirAttribute attr)397 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
398 MlirAttribute attr) {
399 unwrap(op)->setAttr(unwrap(name), unwrap(attr));
400 }
401
mlirOperationRemoveAttributeByName(MlirOperation op,MlirStringRef name)402 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
403 return !!unwrap(op)->removeAttr(unwrap(name));
404 }
405
mlirOperationPrint(MlirOperation op,MlirStringCallback callback,void * userData)406 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
407 void *userData) {
408 detail::CallbackOstream stream(callback, userData);
409 unwrap(op)->print(stream);
410 }
411
mlirOperationPrintWithFlags(MlirOperation op,MlirOpPrintingFlags flags,MlirStringCallback callback,void * userData)412 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
413 MlirStringCallback callback, void *userData) {
414 detail::CallbackOstream stream(callback, userData);
415 unwrap(op)->print(stream, *unwrap(flags));
416 }
417
mlirOperationDump(MlirOperation op)418 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
419
mlirOperationVerify(MlirOperation op)420 bool mlirOperationVerify(MlirOperation op) {
421 return succeeded(verify(unwrap(op)));
422 }
423
424 //===----------------------------------------------------------------------===//
425 // Region API.
426 //===----------------------------------------------------------------------===//
427
mlirRegionCreate()428 MlirRegion mlirRegionCreate() { return wrap(new Region); }
429
mlirRegionGetFirstBlock(MlirRegion region)430 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
431 Region *cppRegion = unwrap(region);
432 if (cppRegion->empty())
433 return wrap(static_cast<Block *>(nullptr));
434 return wrap(&cppRegion->front());
435 }
436
mlirRegionAppendOwnedBlock(MlirRegion region,MlirBlock block)437 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
438 unwrap(region)->push_back(unwrap(block));
439 }
440
mlirRegionInsertOwnedBlock(MlirRegion region,intptr_t pos,MlirBlock block)441 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
442 MlirBlock block) {
443 auto &blockList = unwrap(region)->getBlocks();
444 blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
445 }
446
mlirRegionInsertOwnedBlockAfter(MlirRegion region,MlirBlock reference,MlirBlock block)447 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
448 MlirBlock block) {
449 Region *cppRegion = unwrap(region);
450 if (mlirBlockIsNull(reference)) {
451 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
452 return;
453 }
454
455 assert(unwrap(reference)->getParent() == unwrap(region) &&
456 "expected reference block to belong to the region");
457 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
458 unwrap(block));
459 }
460
mlirRegionInsertOwnedBlockBefore(MlirRegion region,MlirBlock reference,MlirBlock block)461 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
462 MlirBlock block) {
463 if (mlirBlockIsNull(reference))
464 return mlirRegionAppendOwnedBlock(region, block);
465
466 assert(unwrap(reference)->getParent() == unwrap(region) &&
467 "expected reference block to belong to the region");
468 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
469 unwrap(block));
470 }
471
mlirRegionDestroy(MlirRegion region)472 void mlirRegionDestroy(MlirRegion region) {
473 delete static_cast<Region *>(region.ptr);
474 }
475
476 //===----------------------------------------------------------------------===//
477 // Block API.
478 //===----------------------------------------------------------------------===//
479
mlirBlockCreate(intptr_t nArgs,MlirType const * args)480 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args) {
481 Block *b = new Block;
482 for (intptr_t i = 0; i < nArgs; ++i)
483 b->addArgument(unwrap(args[i]));
484 return wrap(b);
485 }
486
mlirBlockEqual(MlirBlock block,MlirBlock other)487 bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
488 return unwrap(block) == unwrap(other);
489 }
490
mlirBlockGetParentOperation(MlirBlock block)491 MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
492 return wrap(unwrap(block)->getParentOp());
493 }
494
mlirBlockGetNextInRegion(MlirBlock block)495 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
496 return wrap(unwrap(block)->getNextNode());
497 }
498
mlirBlockGetFirstOperation(MlirBlock block)499 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
500 Block *cppBlock = unwrap(block);
501 if (cppBlock->empty())
502 return wrap(static_cast<Operation *>(nullptr));
503 return wrap(&cppBlock->front());
504 }
505
mlirBlockGetTerminator(MlirBlock block)506 MlirOperation mlirBlockGetTerminator(MlirBlock block) {
507 Block *cppBlock = unwrap(block);
508 if (cppBlock->empty())
509 return wrap(static_cast<Operation *>(nullptr));
510 Operation &back = cppBlock->back();
511 if (!back.hasTrait<OpTrait::IsTerminator>())
512 return wrap(static_cast<Operation *>(nullptr));
513 return wrap(&back);
514 }
515
mlirBlockAppendOwnedOperation(MlirBlock block,MlirOperation operation)516 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
517 unwrap(block)->push_back(unwrap(operation));
518 }
519
mlirBlockInsertOwnedOperation(MlirBlock block,intptr_t pos,MlirOperation operation)520 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
521 MlirOperation operation) {
522 auto &opList = unwrap(block)->getOperations();
523 opList.insert(std::next(opList.begin(), pos), unwrap(operation));
524 }
525
mlirBlockInsertOwnedOperationAfter(MlirBlock block,MlirOperation reference,MlirOperation operation)526 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
527 MlirOperation reference,
528 MlirOperation operation) {
529 Block *cppBlock = unwrap(block);
530 if (mlirOperationIsNull(reference)) {
531 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
532 return;
533 }
534
535 assert(unwrap(reference)->getBlock() == unwrap(block) &&
536 "expected reference operation to belong to the block");
537 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
538 unwrap(operation));
539 }
540
mlirBlockInsertOwnedOperationBefore(MlirBlock block,MlirOperation reference,MlirOperation operation)541 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
542 MlirOperation reference,
543 MlirOperation operation) {
544 if (mlirOperationIsNull(reference))
545 return mlirBlockAppendOwnedOperation(block, operation);
546
547 assert(unwrap(reference)->getBlock() == unwrap(block) &&
548 "expected reference operation to belong to the block");
549 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
550 unwrap(operation));
551 }
552
mlirBlockDestroy(MlirBlock block)553 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
554
mlirBlockGetNumArguments(MlirBlock block)555 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
556 return static_cast<intptr_t>(unwrap(block)->getNumArguments());
557 }
558
mlirBlockAddArgument(MlirBlock block,MlirType type)559 MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type) {
560 return wrap(unwrap(block)->addArgument(unwrap(type)));
561 }
562
mlirBlockGetArgument(MlirBlock block,intptr_t pos)563 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
564 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
565 }
566
mlirBlockPrint(MlirBlock block,MlirStringCallback callback,void * userData)567 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
568 void *userData) {
569 detail::CallbackOstream stream(callback, userData);
570 unwrap(block)->print(stream);
571 }
572
573 //===----------------------------------------------------------------------===//
574 // Value API.
575 //===----------------------------------------------------------------------===//
576
mlirValueEqual(MlirValue value1,MlirValue value2)577 bool mlirValueEqual(MlirValue value1, MlirValue value2) {
578 return unwrap(value1) == unwrap(value2);
579 }
580
mlirValueIsABlockArgument(MlirValue value)581 bool mlirValueIsABlockArgument(MlirValue value) {
582 return unwrap(value).isa<BlockArgument>();
583 }
584
mlirValueIsAOpResult(MlirValue value)585 bool mlirValueIsAOpResult(MlirValue value) {
586 return unwrap(value).isa<OpResult>();
587 }
588
mlirBlockArgumentGetOwner(MlirValue value)589 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
590 return wrap(unwrap(value).cast<BlockArgument>().getOwner());
591 }
592
mlirBlockArgumentGetArgNumber(MlirValue value)593 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
594 return static_cast<intptr_t>(
595 unwrap(value).cast<BlockArgument>().getArgNumber());
596 }
597
mlirBlockArgumentSetType(MlirValue value,MlirType type)598 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
599 unwrap(value).cast<BlockArgument>().setType(unwrap(type));
600 }
601
mlirOpResultGetOwner(MlirValue value)602 MlirOperation mlirOpResultGetOwner(MlirValue value) {
603 return wrap(unwrap(value).cast<OpResult>().getOwner());
604 }
605
mlirOpResultGetResultNumber(MlirValue value)606 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
607 return static_cast<intptr_t>(
608 unwrap(value).cast<OpResult>().getResultNumber());
609 }
610
mlirValueGetType(MlirValue value)611 MlirType mlirValueGetType(MlirValue value) {
612 return wrap(unwrap(value).getType());
613 }
614
mlirValueDump(MlirValue value)615 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
616
mlirValuePrint(MlirValue value,MlirStringCallback callback,void * userData)617 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
618 void *userData) {
619 detail::CallbackOstream stream(callback, userData);
620 unwrap(value).print(stream);
621 }
622
623 //===----------------------------------------------------------------------===//
624 // Type API.
625 //===----------------------------------------------------------------------===//
626
mlirTypeParseGet(MlirContext context,MlirStringRef type)627 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
628 return wrap(mlir::parseType(unwrap(type), unwrap(context)));
629 }
630
mlirTypeGetContext(MlirType type)631 MlirContext mlirTypeGetContext(MlirType type) {
632 return wrap(unwrap(type).getContext());
633 }
634
mlirTypeEqual(MlirType t1,MlirType t2)635 bool mlirTypeEqual(MlirType t1, MlirType t2) {
636 return unwrap(t1) == unwrap(t2);
637 }
638
mlirTypePrint(MlirType type,MlirStringCallback callback,void * userData)639 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
640 detail::CallbackOstream stream(callback, userData);
641 unwrap(type).print(stream);
642 }
643
mlirTypeDump(MlirType type)644 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
645
646 //===----------------------------------------------------------------------===//
647 // Attribute API.
648 //===----------------------------------------------------------------------===//
649
mlirAttributeParseGet(MlirContext context,MlirStringRef attr)650 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
651 return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context)));
652 }
653
mlirAttributeGetContext(MlirAttribute attribute)654 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
655 return wrap(unwrap(attribute).getContext());
656 }
657
mlirAttributeGetType(MlirAttribute attribute)658 MlirType mlirAttributeGetType(MlirAttribute attribute) {
659 return wrap(unwrap(attribute).getType());
660 }
661
mlirAttributeEqual(MlirAttribute a1,MlirAttribute a2)662 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
663 return unwrap(a1) == unwrap(a2);
664 }
665
mlirAttributePrint(MlirAttribute attr,MlirStringCallback callback,void * userData)666 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
667 void *userData) {
668 detail::CallbackOstream stream(callback, userData);
669 unwrap(attr).print(stream);
670 }
671
mlirAttributeDump(MlirAttribute attr)672 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
673
mlirNamedAttributeGet(MlirIdentifier name,MlirAttribute attr)674 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name,
675 MlirAttribute attr) {
676 return MlirNamedAttribute{name, attr};
677 }
678
679 //===----------------------------------------------------------------------===//
680 // Identifier API.
681 //===----------------------------------------------------------------------===//
682
mlirIdentifierGet(MlirContext context,MlirStringRef str)683 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
684 return wrap(Identifier::get(unwrap(str), unwrap(context)));
685 }
686
mlirIdentifierGetContext(MlirIdentifier ident)687 MlirContext mlirIdentifierGetContext(MlirIdentifier ident) {
688 return wrap(unwrap(ident).getContext());
689 }
690
mlirIdentifierEqual(MlirIdentifier ident,MlirIdentifier other)691 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
692 return unwrap(ident) == unwrap(other);
693 }
694
mlirIdentifierStr(MlirIdentifier ident)695 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
696 return wrap(unwrap(ident).strref());
697 }
698