1 //===- ir.c - Simple test of C APIs ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9
10 /* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
11 */
12
13 #include "mlir-c/IR.h"
14 #include "mlir-c/AffineExpr.h"
15 #include "mlir-c/AffineMap.h"
16 #include "mlir-c/BuiltinAttributes.h"
17 #include "mlir-c/BuiltinTypes.h"
18 #include "mlir-c/Diagnostics.h"
19 #include "mlir-c/Dialect/Standard.h"
20 #include "mlir-c/IntegerSet.h"
21 #include "mlir-c/Registration.h"
22
23 #include <assert.h>
24 #include <inttypes.h>
25 #include <math.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29
/// Fills `loopBody` (the body block of an scf.for) with
///   %lhs = memref.load %arg0[%iv]
///   %rhs = memref.load %arg1[%iv]
///   %sum = std.addf %lhs, %rhs : f32
///   memref.store %sum, %arg0[%iv]
///   scf.yield
/// where %iv is argument 0 of `loopBody` and %arg0 / %arg1 are arguments 0
/// and 1 of `funcBody`. Every created operation is given `location`.
void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
                      MlirLocation location, MlirBlock funcBody) {
  MlirValue inductionVar = mlirBlockGetArgument(loopBody, 0);
  MlirValue lhsBuffer = mlirBlockGetArgument(funcBody, 0);
  MlirValue rhsBuffer = mlirBlockGetArgument(funcBody, 1);
  MlirType elemType =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("f32"));

  // Read the left-hand element.
  MlirOperationState lhsLoadState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.load"), location);
  MlirValue lhsIndexing[] = {lhsBuffer, inductionVar};
  mlirOperationStateAddOperands(&lhsLoadState, 2, lhsIndexing);
  mlirOperationStateAddResults(&lhsLoadState, 1, &elemType);
  MlirOperation lhsLoad = mlirOperationCreate(&lhsLoadState);
  mlirBlockAppendOwnedOperation(loopBody, lhsLoad);
  MlirValue lhsElem = mlirOperationGetResult(lhsLoad, 0);

  // Read the right-hand element.
  MlirOperationState rhsLoadState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.load"), location);
  MlirValue rhsIndexing[] = {rhsBuffer, inductionVar};
  mlirOperationStateAddOperands(&rhsLoadState, 2, rhsIndexing);
  mlirOperationStateAddResults(&rhsLoadState, 1, &elemType);
  MlirOperation rhsLoad = mlirOperationCreate(&rhsLoadState);
  mlirBlockAppendOwnedOperation(loopBody, rhsLoad);
  MlirValue rhsElem = mlirOperationGetResult(rhsLoad, 0);

  // Add the two elements.
  MlirOperationState sumState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.addf"), location);
  MlirValue summands[] = {lhsElem, rhsElem};
  mlirOperationStateAddOperands(&sumState, 2, summands);
  mlirOperationStateAddResults(&sumState, 1, &elemType);
  MlirOperation sum = mlirOperationCreate(&sumState);
  mlirBlockAppendOwnedOperation(loopBody, sum);

  // Write the sum back into the left-hand buffer.
  MlirOperationState storeState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.store"), location);
  MlirValue storeArgs[] = {mlirOperationGetResult(sum, 0), lhsBuffer,
                           inductionVar};
  mlirOperationStateAddOperands(&storeState, 3, storeArgs);
  mlirBlockAppendOwnedOperation(loopBody, mlirOperationCreate(&storeState));

  // Terminate the loop body.
  MlirOperationState yieldState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("scf.yield"), location);
  mlirBlockAppendOwnedOperation(loopBody, mlirOperationCreate(&yieldState));
}
75
/// Builds a module containing
///   func @add(%arg0: memref<?xf32>, %arg1: memref<?xf32>)
/// whose body loops %i from 0 to dim(%arg0, 0) with step 1 and stores
/// %arg0[%i] + %arg1[%i] back into %arg0[%i], then dumps the module with
/// mlirOperationDump.
///
/// Every operation is created generically from an MlirOperationState (name,
/// operands, result types, attributes, regions) rather than through dialect
/// builders. The returned module is owned by the caller, which must release
/// it with mlirModuleDestroy.
///
/// NOTE(review): the loop bound is taken from %arg0 only, so %arg1 is assumed
/// to be at least as long as %arg0.
MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
  MlirModule moduleOp = mlirModuleCreateEmpty(location);
  MlirBlock moduleBody = mlirModuleGetBody(moduleOp);

  // Function body: a single block taking two memref<?xf32> arguments.
  MlirType memrefType =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("memref<?xf32>"));
  MlirType funcBodyArgTypes[] = {memrefType, memrefType};
  MlirRegion funcBodyRegion = mlirRegionCreate();
  MlirBlock funcBody = mlirBlockCreate(
      sizeof(funcBodyArgTypes) / sizeof(MlirType), funcBodyArgTypes);
  mlirRegionAppendOwnedBlock(funcBodyRegion, funcBody);

  // The signature and symbol name are ordinary attributes ("type" and
  // "sym_name") on the generic "func" operation.
  MlirAttribute funcTypeAttr = mlirAttributeParseGet(
      ctx,
      mlirStringRefCreateFromCString("(memref<?xf32>, memref<?xf32>) -> ()"));
  MlirAttribute funcNameAttr =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("\"add\""));
  MlirNamedAttribute funcAttrs[] = {
      mlirNamedAttributeGet(
          mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("type")),
          funcTypeAttr),
      mlirNamedAttributeGet(
          mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("sym_name")),
          funcNameAttr)};
  // The body region is handed over to the operation state; the resulting
  // function is inserted at position 0 of the module body.
  MlirOperationState funcState =
      mlirOperationStateGet(mlirStringRefCreateFromCString("func"), location);
  mlirOperationStateAddAttributes(&funcState, 2, funcAttrs);
  mlirOperationStateAddOwnedRegions(&funcState, 1, &funcBodyRegion);
  MlirOperation func = mlirOperationCreate(&funcState);
  mlirBlockInsertOwnedOperation(moduleBody, 0, func);

  // %c0 = std.constant 0 : index (loop lower bound and dim index).
  MlirType indexType =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
  MlirAttribute indexZeroLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
      indexZeroLiteral);
  MlirOperationState constZeroState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.constant"), location);
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
  MlirOperation constZero = mlirOperationCreate(&constZeroState);
  mlirBlockAppendOwnedOperation(funcBody, constZero);

  // %dim = memref.dim %arg0, %c0 (loop upper bound).
  MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
  MlirValue dimOperands[] = {funcArg0, constZeroValue};
  MlirOperationState dimState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.dim"), location);
  mlirOperationStateAddOperands(&dimState, 2, dimOperands);
  mlirOperationStateAddResults(&dimState, 1, &indexType);
  MlirOperation dim = mlirOperationCreate(&dimState);
  mlirBlockAppendOwnedOperation(funcBody, dim);

  // Loop body: one block whose single index argument is the induction
  // variable. Its operations are added by populateLoopBody once the loop
  // operation exists.
  MlirRegion loopBodyRegion = mlirRegionCreate();
  MlirBlock loopBody = mlirBlockCreate(0, NULL);
  mlirBlockAddArgument(loopBody, indexType);
  mlirRegionAppendOwnedBlock(loopBodyRegion, loopBody);

  // %c1 = std.constant 1 : index (loop step).
  MlirAttribute indexOneLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
      indexOneLiteral);
  MlirOperationState constOneState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.constant"), location);
  mlirOperationStateAddResults(&constOneState, 1, &indexType);
  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
  MlirOperation constOne = mlirOperationCreate(&constOneState);
  mlirBlockAppendOwnedOperation(funcBody, constOne);

  // scf.for %i = %c0 to %dim step %c1, taking ownership of the loop region.
  MlirValue dimValue = mlirOperationGetResult(dim, 0);
  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
  MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
  MlirOperationState loopState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("scf.for"), location);
  mlirOperationStateAddOperands(&loopState, 3, loopOperands);
  mlirOperationStateAddOwnedRegions(&loopState, 1, &loopBodyRegion);
  MlirOperation loop = mlirOperationCreate(&loopState);
  mlirBlockAppendOwnedOperation(funcBody, loop);

  populateLoopBody(ctx, loopBody, location, funcBody);

  // Terminate the function body.
  MlirOperationState retState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.return"), location);
  MlirOperation ret = mlirOperationCreate(&retState);
  mlirBlockAppendOwnedOperation(funcBody, ret);

  // Dump the whole module; the expected text follows.
  MlirOperation module = mlirModuleGetOperation(moduleOp);
  mlirOperationDump(module);
  // clang-format off
  // CHECK: module {
  // CHECK:   func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>) {
  // CHECK:     %[[C0:.*]] = constant 0 : index
  // CHECK:     %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xf32>
  // CHECK:     %[[C1:.*]] = constant 1 : index
  // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
  // CHECK:       %[[LHS:.*]] = memref.load %[[ARG0]][%[[I]]] : memref<?xf32>
  // CHECK:       %[[RHS:.*]] = memref.load %[[ARG1]][%[[I]]] : memref<?xf32>
  // CHECK:       %[[SUM:.*]] = addf %[[LHS]], %[[RHS]] : f32
  // CHECK:       memref.store %[[SUM]], %[[ARG0]][%[[I]]] : memref<?xf32>
  // CHECK:     }
  // CHECK:     return
  // CHECK:   }
  // CHECK: }
  // clang-format on

  return moduleOp;
}
186
187 struct OpListNode {
188 MlirOperation op;
189 struct OpListNode *next;
190 };
191 typedef struct OpListNode OpListNode;
192
/// Counters accumulated while walking every operation nested in a module.
/// `numValues` counts op results plus block arguments, so it is expected to
/// equal `numBlockArguments + numOpResults` after a full walk.
typedef struct ModuleStats {
  unsigned numOperations;
  unsigned numAttributes;
  unsigned numBlocks;
  unsigned numRegions;
  unsigned numValues;
  unsigned numBlockArguments;
  unsigned numOpResults;
} ModuleStats;
203
/// Processes the operation stored in `head`. It accumulates the operation's
/// counts into `stats` and validates the op-result and block-argument
/// accessors. It then pushes every operation found directly in the
/// operation's blocks onto the work list immediately after `head`, so the
/// caller visits them later.
///
/// Returns 0 on success, or a non-zero code identifying the failed check:
///   1-4  an op result did not describe itself correctly,
///   5-8  a block argument did not describe itself correctly,
///   9    allocating a work-list node failed.
/// On a non-zero return, nodes pushed so far remain linked after `head` and
/// are still owned by the caller.
int collectStatsSingle(OpListNode *head, ModuleStats *stats) {
  MlirOperation operation = head->op;
  stats->numOperations += 1;
  stats->numValues += (unsigned)mlirOperationGetNumResults(operation);
  stats->numAttributes += (unsigned)mlirOperationGetNumAttributes(operation);

  // The C API reports counts as intptr_t; keep that type for loop bounds so
  // the comparisons below are not signed/unsigned mixes.
  intptr_t numRegions = mlirOperationGetNumRegions(operation);
  stats->numRegions += (unsigned)numRegions;

  intptr_t numResults = mlirOperationGetNumResults(operation);
  for (intptr_t i = 0; i < numResults; ++i) {
    MlirValue result = mlirOperationGetResult(operation, i);
    if (!mlirValueIsAOpResult(result))
      return 1;
    if (mlirValueIsABlockArgument(result))
      return 2;
    if (!mlirOperationEqual(operation, mlirOpResultGetOwner(result)))
      return 3;
    if (i != mlirOpResultGetResultNumber(result))
      return 4;
    ++stats->numOpResults;
  }

  for (intptr_t i = 0; i < numRegions; ++i) {
    MlirRegion region = mlirOperationGetRegion(operation, i);
    for (MlirBlock block = mlirRegionGetFirstBlock(region);
         !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) {
      ++stats->numBlocks;
      intptr_t numArgs = mlirBlockGetNumArguments(block);
      stats->numValues += (unsigned)numArgs;
      for (intptr_t j = 0; j < numArgs; ++j) {
        MlirValue arg = mlirBlockGetArgument(block, j);
        if (!mlirValueIsABlockArgument(arg))
          return 5;
        if (mlirValueIsAOpResult(arg))
          return 6;
        if (!mlirBlockEqual(block, mlirBlockArgumentGetOwner(arg)))
          return 7;
        if (j != mlirBlockArgumentGetArgNumber(arg))
          return 8;
        ++stats->numBlockArguments;
      }

      // Queue nested operations right after the current node.
      for (MlirOperation child = mlirBlockGetFirstOperation(block);
           !mlirOperationIsNull(child);
           child = mlirOperationGetNextInBlock(child)) {
        OpListNode *node = malloc(sizeof *node);
        if (!node)
          return 9;
        node->op = child;
        node->next = head->next;
        head->next = node;
      }
    }
  }
  return 0;
}
260
/// Walks every operation nested under `operation` (inclusive) using a
/// heap-allocated work list and prints the accumulated ModuleStats to stderr.
///
/// Returns 0 on success. Otherwise it returns the non-zero code from
/// collectStatsSingle, 100 if the value count does not equal block arguments
/// plus op results, or 101 if the initial work-list node cannot be allocated.
/// The work list is fully released on every return path.
int collectStats(MlirOperation operation) {
  OpListNode *head = malloc(sizeof *head);
  if (!head)
    return 101;
  head->op = operation;
  head->next = NULL;

  ModuleStats stats = {0};

  while (head) {
    int retval = collectStatsSingle(head, &stats);
    if (retval) {
      // Release the not-yet-visited remainder of the work list before
      // reporting the failure.
      while (head) {
        OpListNode *next = head->next;
        free(head);
        head = next;
      }
      return retval;
    }
    OpListNode *next = head->next;
    free(head);
    head = next;
  }

  if (stats.numValues != stats.numBlockArguments + stats.numOpResults)
    return 100;

  fprintf(stderr, "@stats\n");
  fprintf(stderr, "Number of operations: %u\n", stats.numOperations);
  fprintf(stderr, "Number of attributes: %u\n", stats.numAttributes);
  fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
  fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
  fprintf(stderr, "Number of values: %u\n", stats.numValues);
  fprintf(stderr, "Number of block arguments: %u\n", stats.numBlockArguments);
  fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
  // clang-format off
  // CHECK-LABEL: @stats
  // CHECK: Number of operations: 12
  // CHECK: Number of attributes: 4
  // CHECK: Number of blocks: 3
  // CHECK: Number of regions: 3
  // CHECK: Number of values: 9
  // CHECK: Number of block arguments: 3
  // CHECK: Number of op results: 6
  // clang-format on
  return 0;
}
307
/// String callback for the MLIR print APIs: writes each chunk it receives
/// verbatim (exactly `str.length` bytes, no terminator needed) to stderr.
/// `userData` is ignored.
static void printToStderr(MlirStringRef str, void *userData) {
  (void)userData;
  const char *chunk = str.data;
  size_t chunkLen = str.length;
  fwrite(chunk, sizeof(char), chunkLen, stderr);
}
312
/// Navigates from the module operation `operation` to the first operation in
/// the body of its first nested operation (the function built by
/// makeAndDumpAdd). It then exercises the operation, block, identifier,
/// attribute, value and type inspection/printing APIs on that operation,
/// writing everything to stderr.
///
/// Side effects on the IR: "custom_attr" is added and then removed again, and
/// an "elts" dense elements attribute is added and left in place.
static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
  // Assuming we are given a module, go to the first operation of the first
  // function.
  MlirRegion region = mlirOperationGetRegion(operation, 0);
  MlirBlock block = mlirRegionGetFirstBlock(region);
  operation = mlirBlockGetFirstOperation(block);
  region = mlirOperationGetRegion(operation, 0);
  MlirOperation parentOperation = operation;
  block = mlirRegionGetFirstBlock(region);
  operation = mlirBlockGetFirstOperation(block);
  // An operation nested inside the function is not itself a module.
  assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));

  // Verify that parent operation and block report correctly.
  fprintf(stderr, "Parent operation eq: %d\n",
          mlirOperationEqual(mlirOperationGetParentOperation(operation),
                             parentOperation));
  fprintf(stderr, "Block eq: %d\n",
          mlirBlockEqual(mlirOperationGetBlock(operation), block));
  // CHECK: Parent operation eq: 1
  // CHECK: Block eq: 1

  // In the module we created, the first operation of the first function is
  // a "std.constant" producing `0 : index`. It has a single attribute
  // ("value") and a single result, which we use to test the printing
  // mechanism. The whole function block is printed first.
  mlirBlockPrint(block, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "First operation: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK:   %[[C0:.*]] = constant 0 : index
  // CHECK:   %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32>
  // CHECK:   %[[C1:.*]] = constant 1 : index
  // CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
  // CHECK:     %[[LHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:     %[[RHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:     %[[SUM:.*]] = addf %[[LHS]], %[[RHS]] : f32
  // CHECK:     memref.store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:   }
  // CHECK: return
  // CHECK: First operation: {{.*}} = constant 0 : index
  // clang-format on

  // Get the operation name and print it. The identifier's string is a
  // length-delimited MlirStringRef, so it is emitted character by character.
  MlirIdentifier ident = mlirOperationGetName(operation);
  MlirStringRef identStr = mlirIdentifierStr(ident);
  fprintf(stderr, "Operation name: '");
  for (size_t i = 0; i < identStr.length; ++i)
    fputc(identStr.data[i], stderr);
  fprintf(stderr, "'\n");
  // CHECK: Operation name: 'std.constant'

  // Get the identifier again and verify equal.
  MlirIdentifier identAgain = mlirIdentifierGet(ctx, identStr);
  fprintf(stderr, "Identifier equal: %d\n",
          mlirIdentifierEqual(ident, identAgain));
  // CHECK: Identifier equal: 1

  // Get the block terminator and print it.
  MlirOperation terminator = mlirBlockGetTerminator(block);
  fprintf(stderr, "Terminator: ");
  mlirOperationPrint(terminator, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Terminator: return

  // Get the attribute by index.
  MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
  fprintf(stderr, "Get attr 0: ");
  mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr 0: 0 : index

  // Now re-get the attribute by name.
  MlirAttribute attr0ByName = mlirOperationGetAttributeByName(
      operation, mlirIdentifierStr(namedAttr0.name));
  fprintf(stderr, "Get attr 0 by name: ");
  mlirAttributePrint(attr0ByName, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr 0 by name: 0 : index

  // Get a non-existing attribute and assert that it is null (sanity).
  fprintf(stderr, "does_not_exist is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetAttributeByName(
              operation, mlirStringRefCreateFromCString("does_not_exist"))));
  // CHECK: does_not_exist is null: 1

  // Get result 0 and its type.
  MlirValue value = mlirOperationGetResult(operation, 0);
  fprintf(stderr, "Result 0: ");
  mlirValuePrint(value, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value));
  // CHECK: Result 0: {{.*}} = constant 0 : index
  // CHECK: Value is null: 0

  MlirType type = mlirValueGetType(value);
  fprintf(stderr, "Result 0 type: ");
  mlirTypePrint(type, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Result 0 type: index

  // Set a custom attribute.
  mlirOperationSetAttributeByName(operation,
                                  mlirStringRefCreateFromCString("custom_attr"),
                                  mlirBoolAttrGet(ctx, 1));
  fprintf(stderr, "Op with set attr: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Op with set attr: {{.*}} {custom_attr = true}

  // Remove the attribute. The first removal reports 1 (removed); the second
  // reports 0 because the attribute is already gone.
  fprintf(stderr, "Remove attr: %d\n",
          mlirOperationRemoveAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Remove attr again: %d\n",
          mlirOperationRemoveAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Removed attr is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr"))));
  // CHECK: Remove attr: 1
  // CHECK: Remove attr again: 0
  // CHECK: Removed attr is null: 1

  // Add a large attribute to verify printing flags. With an elision limit of
  // 2, the 4-element tensor<4xi32> is printed in elided (opaque) form below.
  int64_t eltsShape[] = {4};
  int32_t eltsData[] = {1, 2, 3, 4};
  mlirOperationSetAttributeByName(
      operation, mlirStringRefCreateFromCString("elts"),
      mlirDenseElementsAttrInt32Get(
          mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
                                  mlirAttributeGetNull()), 4, eltsData));
  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
  mlirOpPrintingFlagsPrintGenericOpForm(flags);
  mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/0);
  mlirOpPrintingFlagsUseLocalScope(flags);
  fprintf(stderr, "Op print with all flags: ");
  mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK: Op print with all flags: %{{.*}} = "std.constant"() {elts = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown)
  // clang-format on

  // The flags object was created above and is released here.
  mlirOpPrintingFlagsDestroy(flags);
}
459
/// Builds the "add" module, checks and prints its statistics, then exercises
/// the inspection APIs on its first nested operation.
///
/// Returns 0 on success or the non-zero error code from collectStats. The
/// module is destroyed before returning on both the success and the error
/// path, so it is never leaked.
static int constructAndTraverseIr(MlirContext ctx) {
  MlirLocation location = mlirLocationUnknownGet(ctx);

  MlirModule moduleOp = makeAndDumpAdd(ctx, location);
  MlirOperation module = mlirModuleGetOperation(moduleOp);
  // The top-level operation must round-trip back to a module.
  assert(!mlirModuleIsNull(mlirModuleFromOperation(module)));

  int errcode = collectStats(module);
  if (!errcode)
    printFirstOfEach(ctx, module);

  mlirModuleDestroy(moduleOp);
  return errcode;
}
476
/// Creates an operation with a region containing multiple blocks with
/// operations and dumps it. The blocks and operations are inserted using the
/// block/operation-relative API, and their final order is verified by the
/// expected output below.
///
/// Unregistered dialects are allowed only for the duration of this function,
/// because "insertion.order.test" and "dummy.*" are not registered ops.
static void buildWithInsertionsAndPrint(MlirContext ctx) {
  MlirLocation loc = mlirLocationUnknownGet(ctx);
  mlirContextSetAllowUnregisteredDialects(ctx, true);

  // The region is still empty here, so its "first block" is a null block;
  // it serves as the null reference for the relative insertions below.
  MlirRegion owningRegion = mlirRegionCreate();
  MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion);
  MlirOperationState state = mlirOperationStateGet(
      mlirStringRefCreateFromCString("insertion.order.test"), loc);
  mlirOperationStateAddOwnedRegions(&state, 1, &owningRegion);
  MlirOperation op = mlirOperationCreate(&state);
  MlirRegion region = mlirOperationGetRegion(op, 0);

  // Use integer types of different bitwidth as block arguments in order to
  // differentiate blocks.
  MlirType i1 = mlirIntegerTypeGet(ctx, 1);
  MlirType i2 = mlirIntegerTypeGet(ctx, 2);
  MlirType i3 = mlirIntegerTypeGet(ctx, 3);
  MlirType i4 = mlirIntegerTypeGet(ctx, 4);
  MlirBlock block1 = mlirBlockCreate(1, &i1);
  MlirBlock block2 = mlirBlockCreate(1, &i2);
  MlirBlock block3 = mlirBlockCreate(1, &i3);
  MlirBlock block4 = mlirBlockCreate(1, &i4);
  // Insert blocks so as to obtain the 1-2-3-4 order. The expected order
  // relies on a null reference meaning "append" for *Before and "prepend"
  // for *After: [3] -> [2,3] -> [1,2,3] -> [1,2,3,4].
  mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3);
  mlirRegionInsertOwnedBlockBefore(region, block3, block2);
  mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
  mlirRegionInsertOwnedBlockAfter(region, block3, block4);

  MlirOperationState op1State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op1"), loc);
  MlirOperationState op2State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
  MlirOperationState op3State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op3"), loc);
  MlirOperationState op4State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op4"), loc);
  MlirOperationState op5State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op5"), loc);
  MlirOperationState op6State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op6"), loc);
  MlirOperationState op7State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op7"), loc);
  MlirOperation op1 = mlirOperationCreate(&op1State);
  MlirOperation op2 = mlirOperationCreate(&op2State);
  MlirOperation op3 = mlirOperationCreate(&op3State);
  MlirOperation op4 = mlirOperationCreate(&op4State);
  MlirOperation op5 = mlirOperationCreate(&op5State);
  MlirOperation op6 = mlirOperationCreate(&op6State);
  MlirOperation op7 = mlirOperationCreate(&op7State);

  // Insert operations in the first block so as to obtain the 1-2-3-4 order,
  // using the same null-reference pattern as for the blocks above.
  MlirOperation nullOperation = mlirBlockGetFirstOperation(block1);
  assert(mlirOperationIsNull(nullOperation));
  mlirBlockInsertOwnedOperationBefore(block1, nullOperation, op3);
  mlirBlockInsertOwnedOperationBefore(block1, op3, op2);
  mlirBlockInsertOwnedOperationAfter(block1, nullOperation, op1);
  mlirBlockInsertOwnedOperationAfter(block1, op3, op4);

  // Append operations to the rest of blocks to make them non-empty and thus
  // printable.
  mlirBlockAppendOwnedOperation(block2, op5);
  mlirBlockAppendOwnedOperation(block3, op6);
  mlirBlockAppendOwnedOperation(block4, op7);

  // Dump, then destroy the top-level op (which owns the region, blocks and
  // nested ops), and restore the context setting.
  mlirOperationDump(op);
  mlirOperationDestroy(op);
  mlirContextSetAllowUnregisteredDialects(ctx, false);
  // clang-format off
  // CHECK-LABEL:  "insertion.order.test"
  // CHECK:      ^{{.*}}(%{{.*}}: i1
  // CHECK:        "dummy.op1"
  // CHECK-NEXT:   "dummy.op2"
  // CHECK-NEXT:   "dummy.op3"
  // CHECK-NEXT:   "dummy.op4"
  // CHECK:      ^{{.*}}(%{{.*}}: i2
  // CHECK:        "dummy.op5"
  // CHECK:      ^{{.*}}(%{{.*}}: i3
  // CHECK:        "dummy.op6"
  // CHECK:      ^{{.*}}(%{{.*}}: i4
  // CHECK:        "dummy.op7"
  // clang-format on
}
562
563 /// Creates operations with type inference and tests various failure modes.
createOperationWithTypeInference(MlirContext ctx)564 static int createOperationWithTypeInference(MlirContext ctx) {
565 MlirLocation loc = mlirLocationUnknownGet(ctx);
566 MlirAttribute iAttr = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 4);
567
568 // The shape.const_size op implements result type inference and is only used
569 // for that reason.
570 MlirOperationState state = mlirOperationStateGet(
571 mlirStringRefCreateFromCString("shape.const_size"), loc);
572 MlirNamedAttribute valueAttr = mlirNamedAttributeGet(
573 mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), iAttr);
574 mlirOperationStateAddAttributes(&state, 1, &valueAttr);
575 mlirOperationStateEnableResultTypeInference(&state);
576
577 // Expect result type inference to succeed.
578 MlirOperation op = mlirOperationCreate(&state);
579 if (mlirOperationIsNull(op)) {
580 fprintf(stderr, "ERROR: Result type inference unexpectedly failed");
581 return 1;
582 }
583
584 // CHECK: RESULT_TYPE_INFERENCE: !shape.size
585 fprintf(stderr, "RESULT_TYPE_INFERENCE: ");
586 mlirTypeDump(mlirValueGetType(mlirOperationGetResult(op, 0)));
587 fprintf(stderr, "\n");
588 mlirOperationDestroy(op);
589 return 0;
590 }
591
/// Dumps instances of all builtin types to check that the C API works
/// correctly. Additionally, performs simple identity checks that a builtin
/// type constructed with the C API can be inspected and has the expected
/// type. The latter achieves full coverage of the C API for builtin types.
/// Returns 0 on success and a non-zero error code on failure. Each failing
/// check has its own code (1-7, 9-24); code 8 is not used.
static int printBuiltinTypes(MlirContext ctx) {
  // Integer types: signless, signed and unsigned 32-bit. They must be
  // distinct types that still share the same width.
  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
  MlirType si32 = mlirIntegerTypeSignedGet(ctx, 32);
  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
  if (!mlirTypeIsAInteger(i32) || mlirTypeIsAF32(i32))
    return 1;
  if (!mlirTypeIsAInteger(si32) || !mlirIntegerTypeIsSigned(si32))
    return 2;
  if (!mlirTypeIsAInteger(ui32) || !mlirIntegerTypeIsUnsigned(ui32))
    return 3;
  if (mlirTypeEqual(i32, ui32) || mlirTypeEqual(i32, si32))
    return 4;
  if (mlirIntegerTypeGetWidth(i32) != mlirIntegerTypeGetWidth(si32))
    return 5;
  fprintf(stderr, "@types\n");
  mlirTypeDump(i32);
  fprintf(stderr, "\n");
  mlirTypeDump(si32);
  fprintf(stderr, "\n");
  mlirTypeDump(ui32);
  fprintf(stderr, "\n");
  // CHECK-LABEL: @types
  // CHECK: i32
  // CHECK: si32
  // CHECK: ui32

  // Index type.
  MlirType index = mlirIndexTypeGet(ctx);
  if (!mlirTypeIsAIndex(index))
    return 6;
  mlirTypeDump(index);
  fprintf(stderr, "\n");
  // CHECK: index

  // Floating-point types. Return code 8 is intentionally skipped here.
  MlirType bf16 = mlirBF16TypeGet(ctx);
  MlirType f16 = mlirF16TypeGet(ctx);
  MlirType f32 = mlirF32TypeGet(ctx);
  MlirType f64 = mlirF64TypeGet(ctx);
  if (!mlirTypeIsABF16(bf16))
    return 7;
  if (!mlirTypeIsAF16(f16))
    return 9;
  if (!mlirTypeIsAF32(f32))
    return 10;
  if (!mlirTypeIsAF64(f64))
    return 11;
  mlirTypeDump(bf16);
  fprintf(stderr, "\n");
  mlirTypeDump(f16);
  fprintf(stderr, "\n");
  mlirTypeDump(f32);
  fprintf(stderr, "\n");
  mlirTypeDump(f64);
  fprintf(stderr, "\n");
  // CHECK: bf16
  // CHECK: f16
  // CHECK: f32
  // CHECK: f64

  // None type.
  MlirType none = mlirNoneTypeGet(ctx);
  if (!mlirTypeIsANone(none))
    return 12;
  mlirTypeDump(none);
  fprintf(stderr, "\n");
  // CHECK: none

  // Complex type; its element type must round-trip.
  MlirType cplx = mlirComplexTypeGet(f32);
  if (!mlirTypeIsAComplex(cplx) ||
      !mlirTypeEqual(mlirComplexTypeGetElementType(cplx), f32))
    return 13;
  mlirTypeDump(cplx);
  fprintf(stderr, "\n");
  // CHECK: complex<f32>

  // Vector (and Shaped) type. ShapedType is a common base class for vectors,
  // memrefs and tensors, one cannot create instances of this class so it is
  // tested on an instance of vector type.
  int64_t shape[] = {2, 3};
  MlirType vector =
      mlirVectorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
  if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector))
    return 14;
  if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) ||
      !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 ||
      mlirShapedTypeGetDimSize(vector, 0) != 2 ||
      mlirShapedTypeIsDynamicDim(vector, 0) ||
      mlirShapedTypeGetDimSize(vector, 1) != 3 ||
      !mlirShapedTypeHasStaticShape(vector))
    return 15;
  mlirTypeDump(vector);
  fprintf(stderr, "\n");
  // CHECK: vector<2x3xf32>

  // Ranked tensor type, created without an encoding attribute.
  MlirType rankedTensor = mlirRankedTensorTypeGet(
      sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
  if (!mlirTypeIsATensor(rankedTensor) ||
      !mlirTypeIsARankedTensor(rankedTensor) ||
      !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
    return 16;
  mlirTypeDump(rankedTensor);
  fprintf(stderr, "\n");
  // CHECK: tensor<2x3xf32>

  // Unranked tensor type.
  MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32);
  if (!mlirTypeIsATensor(unrankedTensor) ||
      !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
      mlirShapedTypeHasRank(unrankedTensor))
    return 17;
  mlirTypeDump(unrankedTensor);
  fprintf(stderr, "\n");
  // CHECK: tensor<*xf32>

  // MemRef type with memory space 2 (given as an i64 attribute) and no
  // affine maps.
  MlirAttribute memSpace2 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 2);
  MlirType memRef = mlirMemRefTypeContiguousGet(
      f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
  if (!mlirTypeIsAMemRef(memRef) ||
      mlirMemRefTypeGetNumAffineMaps(memRef) != 0 ||
      !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
    return 18;
  mlirTypeDump(memRef);
  fprintf(stderr, "\n");
  // CHECK: memref<2x3xf32, 2>

  // Unranked MemRef type with memory space 4. It must not be classified as a
  // (ranked) memref.
  MlirAttribute memSpace4 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 4);
  MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, memSpace4);
  if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) ||
      mlirTypeIsAMemRef(unrankedMemRef) ||
      !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
                          memSpace4))
    return 19;
  mlirTypeDump(unrankedMemRef);
  fprintf(stderr, "\n");
  // CHECK: memref<*xf32, 4>

  // Tuple type; its element types must round-trip in order.
  MlirType types[] = {unrankedMemRef, f32};
  MlirType tuple = mlirTupleTypeGet(ctx, 2, types);
  if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
    return 20;
  mlirTypeDump(tuple);
  fprintf(stderr, "\n");
  // CHECK: tuple<memref<*xf32, 4>, f32>

  // Function type with 2 inputs and 3 results; each must round-trip.
  MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, 1)};
  MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, 16),
                             mlirIntegerTypeGet(ctx, 32),
                             mlirIntegerTypeGet(ctx, 64)};
  MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
  if (mlirFunctionTypeGetNumInputs(funcType) != 2)
    return 21;
  if (mlirFunctionTypeGetNumResults(funcType) != 3)
    return 22;
  if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
      !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
    return 23;
  if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
      !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
      !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
    return 24;
  mlirTypeDump(funcType);
  fprintf(stderr, "\n");
  // CHECK: (index, i1) -> (i16, i32, i64)

  return 0;
}
773
/// Copies `data` into the caller-provided character buffer `userData` using
/// strncpy semantics:
///   - at most `len` bytes are written, and copying stops at a NUL in `data`;
///   - any remaining bytes up to `len` are zero-filled;
///   - no terminator is added when `data` has no NUL in its first `len`
///     bytes.
/// Every call writes from the start of the buffer, so it overwrites any
/// earlier output.
/// NOTE(review): assumes the buffer holds at least `len` bytes, and that it
/// is zero-initialized by the caller when a terminated string is needed.
void callbackSetFixedLengthString(const char *data, intptr_t len,
                                  void *userData) {
  char *destination = (char *)userData;
  size_t count = (size_t)len;
  strncpy(destination, data, count);
}
778
/// Returns true iff the NUL-terminated string `lhs` consists of exactly the
/// same characters as the length-delimited `rhs`: the lengths must match
/// first, and only then are the characters compared.
bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
  size_t lhsLength = strlen(lhs);
  bool sameLength = lhsLength == rhs.length;
  return sameLength && strncmp(lhs, rhs.data, rhs.length) == 0;
}
785
printBuiltinAttributes(MlirContext ctx)786 int printBuiltinAttributes(MlirContext ctx) {
787 MlirAttribute floating =
788 mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
789 if (!mlirAttributeIsAFloat(floating) ||
790 fabs(mlirFloatAttrGetValueDouble(floating) - 2.0) > 1E-6)
791 return 1;
792 fprintf(stderr, "@attrs\n");
793 mlirAttributeDump(floating);
794 // CHECK-LABEL: @attrs
795 // CHECK: 2.000000e+00 : f64
796
797 // Exercise mlirAttributeGetType() just for the first one.
798 MlirType floatingType = mlirAttributeGetType(floating);
799 mlirTypeDump(floatingType);
800 // CHECK: f64
801
802 MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
803 if (!mlirAttributeIsAInteger(integer) ||
804 mlirIntegerAttrGetValueInt(integer) != 42)
805 return 2;
806 mlirAttributeDump(integer);
807 // CHECK: 42 : i32
808
809 MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
810 if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))
811 return 3;
812 mlirAttributeDump(boolean);
813 // CHECK: true
814
815 const char data[] = "abcdefghijklmnopqestuvwxyz";
816 MlirAttribute opaque =
817 mlirOpaqueAttrGet(ctx, mlirStringRefCreateFromCString("std"), 3, data,
818 mlirNoneTypeGet(ctx));
819 if (!mlirAttributeIsAOpaque(opaque) ||
820 !stringIsEqual("std", mlirOpaqueAttrGetDialectNamespace(opaque)))
821 return 4;
822
823 MlirStringRef opaqueData = mlirOpaqueAttrGetData(opaque);
824 if (opaqueData.length != 3 ||
825 strncmp(data, opaqueData.data, opaqueData.length))
826 return 5;
827 mlirAttributeDump(opaque);
828 // CHECK: #std.abc
829
830 MlirAttribute string =
831 mlirStringAttrGet(ctx, mlirStringRefCreate(data + 3, 2));
832 if (!mlirAttributeIsAString(string))
833 return 6;
834
835 MlirStringRef stringValue = mlirStringAttrGetValue(string);
836 if (stringValue.length != 2 ||
837 strncmp(data + 3, stringValue.data, stringValue.length))
838 return 7;
839 mlirAttributeDump(string);
840 // CHECK: "de"
841
842 MlirAttribute flatSymbolRef =
843 mlirFlatSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 5, 3));
844 if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
845 return 8;
846
847 MlirStringRef flatSymbolRefValue =
848 mlirFlatSymbolRefAttrGetValue(flatSymbolRef);
849 if (flatSymbolRefValue.length != 3 ||
850 strncmp(data + 5, flatSymbolRefValue.data, flatSymbolRefValue.length))
851 return 9;
852 mlirAttributeDump(flatSymbolRef);
853 // CHECK: @fgh
854
855 MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
856 MlirAttribute symbolRef =
857 mlirSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 8, 2), 2, symbols);
858 if (!mlirAttributeIsASymbolRef(symbolRef) ||
859 mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
860 !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),
861 flatSymbolRef) ||
862 !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 1),
863 flatSymbolRef))
864 return 10;
865
866 MlirStringRef symbolRefLeaf = mlirSymbolRefAttrGetLeafReference(symbolRef);
867 MlirStringRef symbolRefRoot = mlirSymbolRefAttrGetRootReference(symbolRef);
868 if (symbolRefLeaf.length != 3 ||
869 strncmp(data + 5, symbolRefLeaf.data, symbolRefLeaf.length) ||
870 symbolRefRoot.length != 2 ||
871 strncmp(data + 8, symbolRefRoot.data, symbolRefRoot.length))
872 return 11;
873 mlirAttributeDump(symbolRef);
874 // CHECK: @ij::@fgh::@fgh
875
876 MlirAttribute type = mlirTypeAttrGet(mlirF32TypeGet(ctx));
877 if (!mlirAttributeIsAType(type) ||
878 !mlirTypeEqual(mlirF32TypeGet(ctx), mlirTypeAttrGetValue(type)))
879 return 12;
880 mlirAttributeDump(type);
881 // CHECK: f32
882
883 MlirAttribute unit = mlirUnitAttrGet(ctx);
884 if (!mlirAttributeIsAUnit(unit))
885 return 13;
886 mlirAttributeDump(unit);
887 // CHECK: unit
888
889 int64_t shape[] = {1, 2};
890
891 int bools[] = {0, 1};
892 uint8_t uints8[] = {0u, 1u};
893 int8_t ints8[] = {0, 1};
894 uint32_t uints32[] = {0u, 1u};
895 int32_t ints32[] = {0, 1};
896 uint64_t uints64[] = {0u, 1u};
897 int64_t ints64[] = {0, 1};
898 float floats[] = {0.0f, 1.0f};
899 double doubles[] = {0.0, 1.0};
900 MlirAttribute encoding = mlirAttributeGetNull();
901 MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
902 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
903 2, bools);
904 MlirAttribute uint8Elements = mlirDenseElementsAttrUInt8Get(
905 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
906 encoding),
907 2, uints8);
908 MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
909 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
910 2, ints8);
911 MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
912 mlirRankedTensorTypeGet(2, shape,
913 mlirIntegerTypeUnsignedGet(ctx, 32), encoding),
914 2, uints32);
915 MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
916 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
917 2, ints32);
918 MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
919 mlirRankedTensorTypeGet(2, shape,
920 mlirIntegerTypeUnsignedGet(ctx, 64), encoding),
921 2, uints64);
922 MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
923 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
924 2, ints64);
925 MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
926 mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
927 2, floats);
928 MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
929 mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding),
930 2, doubles);
931
932 if (!mlirAttributeIsADenseElements(boolElements) ||
933 !mlirAttributeIsADenseElements(uint8Elements) ||
934 !mlirAttributeIsADenseElements(int8Elements) ||
935 !mlirAttributeIsADenseElements(uint32Elements) ||
936 !mlirAttributeIsADenseElements(int32Elements) ||
937 !mlirAttributeIsADenseElements(uint64Elements) ||
938 !mlirAttributeIsADenseElements(int64Elements) ||
939 !mlirAttributeIsADenseElements(floatElements) ||
940 !mlirAttributeIsADenseElements(doubleElements))
941 return 14;
942
943 if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
944 mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
945 mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
946 mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
947 mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
948 mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
949 mlirDenseElementsAttrGetInt64Value(int64Elements, 1) != 1 ||
950 fabsf(mlirDenseElementsAttrGetFloatValue(floatElements, 1) - 1.0f) >
951 1E-6f ||
952 fabs(mlirDenseElementsAttrGetDoubleValue(doubleElements, 1) - 1.0) > 1E-6)
953 return 15;
954
955 mlirAttributeDump(boolElements);
956 mlirAttributeDump(uint8Elements);
957 mlirAttributeDump(int8Elements);
958 mlirAttributeDump(uint32Elements);
959 mlirAttributeDump(int32Elements);
960 mlirAttributeDump(uint64Elements);
961 mlirAttributeDump(int64Elements);
962 mlirAttributeDump(floatElements);
963 mlirAttributeDump(doubleElements);
964 // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
965 // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
966 // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
967 // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
968 // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
969 // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
970 // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
971 // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
972 // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
973
974 MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
975 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
976 1);
977 MlirAttribute splatUInt8 = mlirDenseElementsAttrUInt8SplatGet(
978 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
979 encoding),
980 1);
981 MlirAttribute splatInt8 = mlirDenseElementsAttrInt8SplatGet(
982 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
983 1);
984 MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
985 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
986 encoding),
987 1);
988 MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
989 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
990 1);
991 MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
992 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
993 encoding),
994 1);
995 MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
996 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
997 1);
998 MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
999 mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 1.0f);
1000 MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
1001 mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 1.0);
1002
1003 if (!mlirAttributeIsADenseElements(splatBool) ||
1004 !mlirDenseElementsAttrIsSplat(splatBool) ||
1005 !mlirAttributeIsADenseElements(splatUInt8) ||
1006 !mlirDenseElementsAttrIsSplat(splatUInt8) ||
1007 !mlirAttributeIsADenseElements(splatInt8) ||
1008 !mlirDenseElementsAttrIsSplat(splatInt8) ||
1009 !mlirAttributeIsADenseElements(splatUInt32) ||
1010 !mlirDenseElementsAttrIsSplat(splatUInt32) ||
1011 !mlirAttributeIsADenseElements(splatInt32) ||
1012 !mlirDenseElementsAttrIsSplat(splatInt32) ||
1013 !mlirAttributeIsADenseElements(splatUInt64) ||
1014 !mlirDenseElementsAttrIsSplat(splatUInt64) ||
1015 !mlirAttributeIsADenseElements(splatInt64) ||
1016 !mlirDenseElementsAttrIsSplat(splatInt64) ||
1017 !mlirAttributeIsADenseElements(splatFloat) ||
1018 !mlirDenseElementsAttrIsSplat(splatFloat) ||
1019 !mlirAttributeIsADenseElements(splatDouble) ||
1020 !mlirDenseElementsAttrIsSplat(splatDouble))
1021 return 16;
1022
1023 if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
1024 mlirDenseElementsAttrGetUInt8SplatValue(splatUInt8) != 1 ||
1025 mlirDenseElementsAttrGetInt8SplatValue(splatInt8) != 1 ||
1026 mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
1027 mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
1028 mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
1029 mlirDenseElementsAttrGetInt64SplatValue(splatInt64) != 1 ||
1030 fabsf(mlirDenseElementsAttrGetFloatSplatValue(splatFloat) - 1.0f) >
1031 1E-6f ||
1032 fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
1033 return 17;
1034
1035 uint8_t *uint8RawData =
1036 (uint8_t *)mlirDenseElementsAttrGetRawData(uint8Elements);
1037 int8_t *int8RawData = (int8_t *)mlirDenseElementsAttrGetRawData(int8Elements);
1038 uint32_t *uint32RawData =
1039 (uint32_t *)mlirDenseElementsAttrGetRawData(uint32Elements);
1040 int32_t *int32RawData =
1041 (int32_t *)mlirDenseElementsAttrGetRawData(int32Elements);
1042 uint64_t *uint64RawData =
1043 (uint64_t *)mlirDenseElementsAttrGetRawData(uint64Elements);
1044 int64_t *int64RawData =
1045 (int64_t *)mlirDenseElementsAttrGetRawData(int64Elements);
1046 float *floatRawData = (float *)mlirDenseElementsAttrGetRawData(floatElements);
1047 double *doubleRawData =
1048 (double *)mlirDenseElementsAttrGetRawData(doubleElements);
1049 if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
1050 int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
1051 int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
1052 uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
1053 floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
1054 doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
1055 return 18;
1056
1057 mlirAttributeDump(splatBool);
1058 mlirAttributeDump(splatUInt8);
1059 mlirAttributeDump(splatInt8);
1060 mlirAttributeDump(splatUInt32);
1061 mlirAttributeDump(splatInt32);
1062 mlirAttributeDump(splatUInt64);
1063 mlirAttributeDump(splatInt64);
1064 mlirAttributeDump(splatFloat);
1065 mlirAttributeDump(splatDouble);
1066 // CHECK: dense<true> : tensor<1x2xi1>
1067 // CHECK: dense<1> : tensor<1x2xui8>
1068 // CHECK: dense<1> : tensor<1x2xi8>
1069 // CHECK: dense<1> : tensor<1x2xui32>
1070 // CHECK: dense<1> : tensor<1x2xi32>
1071 // CHECK: dense<1> : tensor<1x2xui64>
1072 // CHECK: dense<1> : tensor<1x2xi64>
1073 // CHECK: dense<1.000000e+00> : tensor<1x2xf32>
1074 // CHECK: dense<1.000000e+00> : tensor<1x2xf64>
1075
1076 mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
1077 mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
1078 // CHECK: 1.000000e+00 : f32
1079 // CHECK: 1.000000e+00 : f64
1080
1081 int64_t indices[] = {4, 7};
1082 int64_t two = 2;
1083 MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
1084 mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64), encoding),
1085 2, indices);
1086 MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
1087 mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding),
1088 2, floats);
1089 MlirAttribute sparseAttr = mlirSparseElementsAttribute(
1090 mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
1091 indicesAttr, valuesAttr);
1092 mlirAttributeDump(sparseAttr);
1093 // CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32>
1094
1095 return 0;
1096 }
1097
printAffineMap(MlirContext ctx)1098 int printAffineMap(MlirContext ctx) {
1099 MlirAffineMap emptyAffineMap = mlirAffineMapEmptyGet(ctx);
1100 MlirAffineMap affineMap = mlirAffineMapZeroResultGet(ctx, 3, 2);
1101 MlirAffineMap constAffineMap = mlirAffineMapConstantGet(ctx, 2);
1102 MlirAffineMap multiDimIdentityAffineMap =
1103 mlirAffineMapMultiDimIdentityGet(ctx, 3);
1104 MlirAffineMap minorIdentityAffineMap =
1105 mlirAffineMapMinorIdentityGet(ctx, 3, 2);
1106 unsigned permutation[] = {1, 2, 0};
1107 MlirAffineMap permutationAffineMap = mlirAffineMapPermutationGet(
1108 ctx, sizeof(permutation) / sizeof(unsigned), permutation);
1109
1110 fprintf(stderr, "@affineMap\n");
1111 mlirAffineMapDump(emptyAffineMap);
1112 mlirAffineMapDump(affineMap);
1113 mlirAffineMapDump(constAffineMap);
1114 mlirAffineMapDump(multiDimIdentityAffineMap);
1115 mlirAffineMapDump(minorIdentityAffineMap);
1116 mlirAffineMapDump(permutationAffineMap);
1117 // CHECK-LABEL: @affineMap
1118 // CHECK: () -> ()
1119 // CHECK: (d0, d1, d2)[s0, s1] -> ()
1120 // CHECK: () -> (2)
1121 // CHECK: (d0, d1, d2) -> (d0, d1, d2)
1122 // CHECK: (d0, d1, d2) -> (d1, d2)
1123 // CHECK: (d0, d1, d2) -> (d1, d2, d0)
1124
1125 if (!mlirAffineMapIsIdentity(emptyAffineMap) ||
1126 mlirAffineMapIsIdentity(affineMap) ||
1127 mlirAffineMapIsIdentity(constAffineMap) ||
1128 !mlirAffineMapIsIdentity(multiDimIdentityAffineMap) ||
1129 mlirAffineMapIsIdentity(minorIdentityAffineMap) ||
1130 mlirAffineMapIsIdentity(permutationAffineMap))
1131 return 1;
1132
1133 if (!mlirAffineMapIsMinorIdentity(emptyAffineMap) ||
1134 mlirAffineMapIsMinorIdentity(affineMap) ||
1135 !mlirAffineMapIsMinorIdentity(multiDimIdentityAffineMap) ||
1136 !mlirAffineMapIsMinorIdentity(minorIdentityAffineMap) ||
1137 mlirAffineMapIsMinorIdentity(permutationAffineMap))
1138 return 2;
1139
1140 if (!mlirAffineMapIsEmpty(emptyAffineMap) ||
1141 mlirAffineMapIsEmpty(affineMap) || mlirAffineMapIsEmpty(constAffineMap) ||
1142 mlirAffineMapIsEmpty(multiDimIdentityAffineMap) ||
1143 mlirAffineMapIsEmpty(minorIdentityAffineMap) ||
1144 mlirAffineMapIsEmpty(permutationAffineMap))
1145 return 3;
1146
1147 if (mlirAffineMapIsSingleConstant(emptyAffineMap) ||
1148 mlirAffineMapIsSingleConstant(affineMap) ||
1149 !mlirAffineMapIsSingleConstant(constAffineMap) ||
1150 mlirAffineMapIsSingleConstant(multiDimIdentityAffineMap) ||
1151 mlirAffineMapIsSingleConstant(minorIdentityAffineMap) ||
1152 mlirAffineMapIsSingleConstant(permutationAffineMap))
1153 return 4;
1154
1155 if (mlirAffineMapGetSingleConstantResult(constAffineMap) != 2)
1156 return 5;
1157
1158 if (mlirAffineMapGetNumDims(emptyAffineMap) != 0 ||
1159 mlirAffineMapGetNumDims(affineMap) != 3 ||
1160 mlirAffineMapGetNumDims(constAffineMap) != 0 ||
1161 mlirAffineMapGetNumDims(multiDimIdentityAffineMap) != 3 ||
1162 mlirAffineMapGetNumDims(minorIdentityAffineMap) != 3 ||
1163 mlirAffineMapGetNumDims(permutationAffineMap) != 3)
1164 return 6;
1165
1166 if (mlirAffineMapGetNumSymbols(emptyAffineMap) != 0 ||
1167 mlirAffineMapGetNumSymbols(affineMap) != 2 ||
1168 mlirAffineMapGetNumSymbols(constAffineMap) != 0 ||
1169 mlirAffineMapGetNumSymbols(multiDimIdentityAffineMap) != 0 ||
1170 mlirAffineMapGetNumSymbols(minorIdentityAffineMap) != 0 ||
1171 mlirAffineMapGetNumSymbols(permutationAffineMap) != 0)
1172 return 7;
1173
1174 if (mlirAffineMapGetNumResults(emptyAffineMap) != 0 ||
1175 mlirAffineMapGetNumResults(affineMap) != 0 ||
1176 mlirAffineMapGetNumResults(constAffineMap) != 1 ||
1177 mlirAffineMapGetNumResults(multiDimIdentityAffineMap) != 3 ||
1178 mlirAffineMapGetNumResults(minorIdentityAffineMap) != 2 ||
1179 mlirAffineMapGetNumResults(permutationAffineMap) != 3)
1180 return 8;
1181
1182 if (mlirAffineMapGetNumInputs(emptyAffineMap) != 0 ||
1183 mlirAffineMapGetNumInputs(affineMap) != 5 ||
1184 mlirAffineMapGetNumInputs(constAffineMap) != 0 ||
1185 mlirAffineMapGetNumInputs(multiDimIdentityAffineMap) != 3 ||
1186 mlirAffineMapGetNumInputs(minorIdentityAffineMap) != 3 ||
1187 mlirAffineMapGetNumInputs(permutationAffineMap) != 3)
1188 return 9;
1189
1190 if (!mlirAffineMapIsProjectedPermutation(emptyAffineMap) ||
1191 !mlirAffineMapIsPermutation(emptyAffineMap) ||
1192 mlirAffineMapIsProjectedPermutation(affineMap) ||
1193 mlirAffineMapIsPermutation(affineMap) ||
1194 mlirAffineMapIsProjectedPermutation(constAffineMap) ||
1195 mlirAffineMapIsPermutation(constAffineMap) ||
1196 !mlirAffineMapIsProjectedPermutation(multiDimIdentityAffineMap) ||
1197 !mlirAffineMapIsPermutation(multiDimIdentityAffineMap) ||
1198 !mlirAffineMapIsProjectedPermutation(minorIdentityAffineMap) ||
1199 mlirAffineMapIsPermutation(minorIdentityAffineMap) ||
1200 !mlirAffineMapIsProjectedPermutation(permutationAffineMap) ||
1201 !mlirAffineMapIsPermutation(permutationAffineMap))
1202 return 10;
1203
1204 intptr_t sub[] = {1};
1205
1206 MlirAffineMap subMap = mlirAffineMapGetSubMap(
1207 multiDimIdentityAffineMap, sizeof(sub) / sizeof(intptr_t), sub);
1208 MlirAffineMap majorSubMap =
1209 mlirAffineMapGetMajorSubMap(multiDimIdentityAffineMap, 1);
1210 MlirAffineMap minorSubMap =
1211 mlirAffineMapGetMinorSubMap(multiDimIdentityAffineMap, 1);
1212
1213 mlirAffineMapDump(subMap);
1214 mlirAffineMapDump(majorSubMap);
1215 mlirAffineMapDump(minorSubMap);
1216 // CHECK: (d0, d1, d2) -> (d1)
1217 // CHECK: (d0, d1, d2) -> (d0)
1218 // CHECK: (d0, d1, d2) -> (d2)
1219
1220 return 0;
1221 }
1222
printAffineExpr(MlirContext ctx)1223 int printAffineExpr(MlirContext ctx) {
1224 MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, 5);
1225 MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, 5);
1226 MlirAffineExpr affineConstantExpr = mlirAffineConstantExprGet(ctx, 5);
1227 MlirAffineExpr affineAddExpr =
1228 mlirAffineAddExprGet(affineDimExpr, affineSymbolExpr);
1229 MlirAffineExpr affineMulExpr =
1230 mlirAffineMulExprGet(affineDimExpr, affineSymbolExpr);
1231 MlirAffineExpr affineModExpr =
1232 mlirAffineModExprGet(affineDimExpr, affineSymbolExpr);
1233 MlirAffineExpr affineFloorDivExpr =
1234 mlirAffineFloorDivExprGet(affineDimExpr, affineSymbolExpr);
1235 MlirAffineExpr affineCeilDivExpr =
1236 mlirAffineCeilDivExprGet(affineDimExpr, affineSymbolExpr);
1237
1238 // Tests mlirAffineExprDump.
1239 fprintf(stderr, "@affineExpr\n");
1240 mlirAffineExprDump(affineDimExpr);
1241 mlirAffineExprDump(affineSymbolExpr);
1242 mlirAffineExprDump(affineConstantExpr);
1243 mlirAffineExprDump(affineAddExpr);
1244 mlirAffineExprDump(affineMulExpr);
1245 mlirAffineExprDump(affineModExpr);
1246 mlirAffineExprDump(affineFloorDivExpr);
1247 mlirAffineExprDump(affineCeilDivExpr);
1248 // CHECK-LABEL: @affineExpr
1249 // CHECK: d5
1250 // CHECK: s5
1251 // CHECK: 5
1252 // CHECK: d5 + s5
1253 // CHECK: d5 * s5
1254 // CHECK: d5 mod s5
1255 // CHECK: d5 floordiv s5
1256 // CHECK: d5 ceildiv s5
1257
1258 // Tests methods of affine binary operation expression, takes add expression
1259 // as an example.
1260 mlirAffineExprDump(mlirAffineBinaryOpExprGetLHS(affineAddExpr));
1261 mlirAffineExprDump(mlirAffineBinaryOpExprGetRHS(affineAddExpr));
1262 // CHECK: d5
1263 // CHECK: s5
1264
1265 // Tests methods of affine dimension expression.
1266 if (mlirAffineDimExprGetPosition(affineDimExpr) != 5)
1267 return 1;
1268
1269 // Tests methods of affine symbol expression.
1270 if (mlirAffineSymbolExprGetPosition(affineSymbolExpr) != 5)
1271 return 2;
1272
1273 // Tests methods of affine constant expression.
1274 if (mlirAffineConstantExprGetValue(affineConstantExpr) != 5)
1275 return 3;
1276
1277 // Tests methods of affine expression.
1278 if (mlirAffineExprIsSymbolicOrConstant(affineDimExpr) ||
1279 !mlirAffineExprIsSymbolicOrConstant(affineSymbolExpr) ||
1280 !mlirAffineExprIsSymbolicOrConstant(affineConstantExpr) ||
1281 mlirAffineExprIsSymbolicOrConstant(affineAddExpr) ||
1282 mlirAffineExprIsSymbolicOrConstant(affineMulExpr) ||
1283 mlirAffineExprIsSymbolicOrConstant(affineModExpr) ||
1284 mlirAffineExprIsSymbolicOrConstant(affineFloorDivExpr) ||
1285 mlirAffineExprIsSymbolicOrConstant(affineCeilDivExpr))
1286 return 4;
1287
1288 if (!mlirAffineExprIsPureAffine(affineDimExpr) ||
1289 !mlirAffineExprIsPureAffine(affineSymbolExpr) ||
1290 !mlirAffineExprIsPureAffine(affineConstantExpr) ||
1291 !mlirAffineExprIsPureAffine(affineAddExpr) ||
1292 mlirAffineExprIsPureAffine(affineMulExpr) ||
1293 mlirAffineExprIsPureAffine(affineModExpr) ||
1294 mlirAffineExprIsPureAffine(affineFloorDivExpr) ||
1295 mlirAffineExprIsPureAffine(affineCeilDivExpr))
1296 return 5;
1297
1298 if (mlirAffineExprGetLargestKnownDivisor(affineDimExpr) != 1 ||
1299 mlirAffineExprGetLargestKnownDivisor(affineSymbolExpr) != 1 ||
1300 mlirAffineExprGetLargestKnownDivisor(affineConstantExpr) != 5 ||
1301 mlirAffineExprGetLargestKnownDivisor(affineAddExpr) != 1 ||
1302 mlirAffineExprGetLargestKnownDivisor(affineMulExpr) != 1 ||
1303 mlirAffineExprGetLargestKnownDivisor(affineModExpr) != 1 ||
1304 mlirAffineExprGetLargestKnownDivisor(affineFloorDivExpr) != 1 ||
1305 mlirAffineExprGetLargestKnownDivisor(affineCeilDivExpr) != 1)
1306 return 6;
1307
1308 if (!mlirAffineExprIsMultipleOf(affineDimExpr, 1) ||
1309 !mlirAffineExprIsMultipleOf(affineSymbolExpr, 1) ||
1310 !mlirAffineExprIsMultipleOf(affineConstantExpr, 5) ||
1311 !mlirAffineExprIsMultipleOf(affineAddExpr, 1) ||
1312 !mlirAffineExprIsMultipleOf(affineMulExpr, 1) ||
1313 !mlirAffineExprIsMultipleOf(affineModExpr, 1) ||
1314 !mlirAffineExprIsMultipleOf(affineFloorDivExpr, 1) ||
1315 !mlirAffineExprIsMultipleOf(affineCeilDivExpr, 1))
1316 return 7;
1317
1318 if (!mlirAffineExprIsFunctionOfDim(affineDimExpr, 5) ||
1319 mlirAffineExprIsFunctionOfDim(affineSymbolExpr, 5) ||
1320 mlirAffineExprIsFunctionOfDim(affineConstantExpr, 5) ||
1321 !mlirAffineExprIsFunctionOfDim(affineAddExpr, 5) ||
1322 !mlirAffineExprIsFunctionOfDim(affineMulExpr, 5) ||
1323 !mlirAffineExprIsFunctionOfDim(affineModExpr, 5) ||
1324 !mlirAffineExprIsFunctionOfDim(affineFloorDivExpr, 5) ||
1325 !mlirAffineExprIsFunctionOfDim(affineCeilDivExpr, 5))
1326 return 8;
1327
1328 // Tests 'IsA' methods of affine binary operation expression.
1329 if (!mlirAffineExprIsAAdd(affineAddExpr))
1330 return 9;
1331
1332 if (!mlirAffineExprIsAMul(affineMulExpr))
1333 return 10;
1334
1335 if (!mlirAffineExprIsAMod(affineModExpr))
1336 return 11;
1337
1338 if (!mlirAffineExprIsAFloorDiv(affineFloorDivExpr))
1339 return 12;
1340
1341 if (!mlirAffineExprIsACeilDiv(affineCeilDivExpr))
1342 return 13;
1343
1344 if (!mlirAffineExprIsABinary(affineAddExpr))
1345 return 14;
1346
1347 // Test other 'IsA' method on affine expressions.
1348 if (!mlirAffineExprIsAConstant(affineConstantExpr))
1349 return 15;
1350
1351 if (!mlirAffineExprIsADim(affineDimExpr))
1352 return 16;
1353
1354 if (!mlirAffineExprIsASymbol(affineSymbolExpr))
1355 return 17;
1356
1357 // Test equality and nullity.
1358 MlirAffineExpr otherDimExpr = mlirAffineDimExprGet(ctx, 5);
1359 if (!mlirAffineExprEqual(affineDimExpr, otherDimExpr))
1360 return 18;
1361
1362 if (mlirAffineExprIsNull(affineDimExpr))
1363 return 19;
1364
1365 return 0;
1366 }
1367
affineMapFromExprs(MlirContext ctx)1368 int affineMapFromExprs(MlirContext ctx) {
1369 MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, 0);
1370 MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, 1);
1371 MlirAffineExpr exprs[] = {affineDimExpr, affineSymbolExpr};
1372 MlirAffineMap map = mlirAffineMapGet(ctx, 3, 3, 2, exprs);
1373
1374 // CHECK-LABEL: @affineMapFromExprs
1375 fprintf(stderr, "@affineMapFromExprs");
1376 // CHECK: (d0, d1, d2)[s0, s1, s2] -> (d0, s1)
1377 mlirAffineMapDump(map);
1378
1379 if (mlirAffineMapGetNumResults(map) != 2)
1380 return 1;
1381
1382 if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 0), affineDimExpr))
1383 return 2;
1384
1385 if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 1), affineSymbolExpr))
1386 return 3;
1387
1388 return 0;
1389 }
1390
printIntegerSet(MlirContext ctx)1391 int printIntegerSet(MlirContext ctx) {
1392 MlirIntegerSet emptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
1393
1394 // CHECK-LABEL: @printIntegerSet
1395 fprintf(stderr, "@printIntegerSet");
1396
1397 // CHECK: (d0, d1)[s0] : (1 == 0)
1398 mlirIntegerSetDump(emptySet);
1399
1400 if (!mlirIntegerSetIsCanonicalEmpty(emptySet))
1401 return 1;
1402
1403 MlirIntegerSet anotherEmptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
1404 if (!mlirIntegerSetEqual(emptySet, anotherEmptySet))
1405 return 2;
1406
1407 // Construct a set constrained by:
1408 // d0 - s0 == 0,
1409 // d1 - 42 >= 0.
1410 MlirAffineExpr negOne = mlirAffineConstantExprGet(ctx, -1);
1411 MlirAffineExpr negFortyTwo = mlirAffineConstantExprGet(ctx, -42);
1412 MlirAffineExpr d0 = mlirAffineDimExprGet(ctx, 0);
1413 MlirAffineExpr d1 = mlirAffineDimExprGet(ctx, 1);
1414 MlirAffineExpr s0 = mlirAffineSymbolExprGet(ctx, 0);
1415 MlirAffineExpr negS0 = mlirAffineMulExprGet(negOne, s0);
1416 MlirAffineExpr d0minusS0 = mlirAffineAddExprGet(d0, negS0);
1417 MlirAffineExpr d1minus42 = mlirAffineAddExprGet(d1, negFortyTwo);
1418 MlirAffineExpr constraints[] = {d0minusS0, d1minus42};
1419 bool flags[] = {true, false};
1420
1421 MlirIntegerSet set = mlirIntegerSetGet(ctx, 2, 1, 2, constraints, flags);
1422 // CHECK: (d0, d1)[s0] : (
1423 // CHECK-DAG: d0 - s0 == 0
1424 // CHECK-DAG: d1 - 42 >= 0
1425 mlirIntegerSetDump(set);
1426
1427 // Transform d1 into s0.
1428 MlirAffineExpr s1 = mlirAffineSymbolExprGet(ctx, 1);
1429 MlirAffineExpr repl[] = {d0, s1};
1430 MlirIntegerSet replaced = mlirIntegerSetReplaceGet(set, repl, &s0, 1, 2);
1431 // CHECK: (d0)[s0, s1] : (
1432 // CHECK-DAG: d0 - s0 == 0
1433 // CHECK-DAG: s1 - 42 >= 0
1434 mlirIntegerSetDump(replaced);
1435
1436 if (mlirIntegerSetGetNumDims(set) != 2)
1437 return 3;
1438 if (mlirIntegerSetGetNumDims(replaced) != 1)
1439 return 4;
1440
1441 if (mlirIntegerSetGetNumSymbols(set) != 1)
1442 return 5;
1443 if (mlirIntegerSetGetNumSymbols(replaced) != 2)
1444 return 6;
1445
1446 if (mlirIntegerSetGetNumInputs(set) != 3)
1447 return 7;
1448
1449 if (mlirIntegerSetGetNumConstraints(set) != 2)
1450 return 8;
1451
1452 if (mlirIntegerSetGetNumEqualities(set) != 1)
1453 return 9;
1454
1455 if (mlirIntegerSetGetNumInequalities(set) != 1)
1456 return 10;
1457
1458 MlirAffineExpr cstr1 = mlirIntegerSetGetConstraint(set, 0);
1459 MlirAffineExpr cstr2 = mlirIntegerSetGetConstraint(set, 1);
1460 bool isEq1 = mlirIntegerSetIsConstraintEq(set, 0);
1461 bool isEq2 = mlirIntegerSetIsConstraintEq(set, 1);
1462 if (!mlirAffineExprEqual(cstr1, isEq1 ? d0minusS0 : d1minus42))
1463 return 11;
1464 if (!mlirAffineExprEqual(cstr2, isEq2 ? d0minusS0 : d1minus42))
1465 return 12;
1466
1467 return 0;
1468 }
1469
registerOnlyStd()1470 int registerOnlyStd() {
1471 MlirContext ctx = mlirContextCreate();
1472 // The built-in dialect is always loaded.
1473 if (mlirContextGetNumLoadedDialects(ctx) != 1)
1474 return 1;
1475
1476 MlirDialectHandle stdHandle = mlirGetDialectHandle__std__();
1477
1478 MlirDialect std = mlirContextGetOrLoadDialect(
1479 ctx, mlirDialectHandleGetNamespace(stdHandle));
1480 if (!mlirDialectIsNull(std))
1481 return 2;
1482
1483 mlirDialectHandleRegisterDialect(stdHandle, ctx);
1484
1485 std = mlirContextGetOrLoadDialect(ctx,
1486 mlirDialectHandleGetNamespace(stdHandle));
1487 if (mlirDialectIsNull(std))
1488 return 3;
1489
1490 MlirDialect alsoStd = mlirDialectHandleLoadDialect(stdHandle, ctx);
1491 if (!mlirDialectEqual(std, alsoStd))
1492 return 4;
1493
1494 MlirStringRef stdNs = mlirDialectGetNamespace(std);
1495 MlirStringRef alsoStdNs = mlirDialectHandleGetNamespace(stdHandle);
1496 if (stdNs.length != alsoStdNs.length ||
1497 strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
1498 return 5;
1499
1500 fprintf(stderr, "@registration\n");
1501 // CHECK-LABEL: @registration
1502
1503 // CHECK: std.cond_br is_registered: 1
1504 fprintf(stderr, "std.cond_br is_registered: %d\n",
1505 mlirContextIsRegisteredOperation(
1506 ctx, mlirStringRefCreateFromCString("std.cond_br")));
1507
1508 // CHECK: std.not_existing_op is_registered: 0
1509 fprintf(stderr, "std.not_existing_op is_registered: %d\n",
1510 mlirContextIsRegisteredOperation(
1511 ctx, mlirStringRefCreateFromCString("std.not_existing_op")));
1512
1513 // CHECK: not_existing_dialect.not_existing_op is_registered: 0
1514 fprintf(stderr, "not_existing_dialect.not_existing_op is_registered: %d\n",
1515 mlirContextIsRegisteredOperation(
1516 ctx, mlirStringRefCreateFromCString(
1517 "not_existing_dialect.not_existing_op")));
1518
1519 return 0;
1520 }
1521
1522 /// Tests backreference APIs
testBackreferences()1523 static int testBackreferences() {
1524 fprintf(stderr, "@test_backreferences\n");
1525
1526 MlirContext ctx = mlirContextCreate();
1527 mlirContextSetAllowUnregisteredDialects(ctx, true);
1528 MlirLocation loc = mlirLocationUnknownGet(ctx);
1529
1530 MlirOperationState opState =
1531 mlirOperationStateGet(mlirStringRefCreateFromCString("invalid.op"), loc);
1532 MlirRegion region = mlirRegionCreate();
1533 MlirBlock block = mlirBlockCreate(0, NULL);
1534 mlirRegionAppendOwnedBlock(region, block);
1535 mlirOperationStateAddOwnedRegions(&opState, 1, ®ion);
1536 MlirOperation op = mlirOperationCreate(&opState);
1537 MlirIdentifier ident =
1538 mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("identifier"));
1539
1540 if (!mlirContextEqual(ctx, mlirOperationGetContext(op))) {
1541 fprintf(stderr, "ERROR: Getting context from operation failed\n");
1542 return 1;
1543 }
1544 if (!mlirOperationEqual(op, mlirBlockGetParentOperation(block))) {
1545 fprintf(stderr, "ERROR: Getting parent operation from block failed\n");
1546 return 2;
1547 }
1548 if (!mlirContextEqual(ctx, mlirIdentifierGetContext(ident))) {
1549 fprintf(stderr, "ERROR: Getting context from identifier failed\n");
1550 return 3;
1551 }
1552
1553 mlirOperationDestroy(op);
1554 mlirContextDestroy(ctx);
1555
1556 // CHECK-LABEL: @test_backreferences
1557 return 0;
1558 }
1559
1560 /// Tests operand APIs.
testOperands()1561 int testOperands() {
1562 fprintf(stderr, "@testOperands\n");
1563 // CHECK-LABEL: @testOperands
1564
1565 MlirContext ctx = mlirContextCreate();
1566 mlirRegisterAllDialects(ctx);
1567 mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test"));
1568 MlirLocation loc = mlirLocationUnknownGet(ctx);
1569 MlirType indexType = mlirIndexTypeGet(ctx);
1570
1571 // Create some constants to use as operands.
1572 MlirAttribute indexZeroLiteral =
1573 mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
1574 MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
1575 mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
1576 indexZeroLiteral);
1577 MlirOperationState constZeroState = mlirOperationStateGet(
1578 mlirStringRefCreateFromCString("std.constant"), loc);
1579 mlirOperationStateAddResults(&constZeroState, 1, &indexType);
1580 mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
1581 MlirOperation constZero = mlirOperationCreate(&constZeroState);
1582 MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
1583
1584 MlirAttribute indexOneLiteral =
1585 mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
1586 MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
1587 mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
1588 indexOneLiteral);
1589 MlirOperationState constOneState = mlirOperationStateGet(
1590 mlirStringRefCreateFromCString("std.constant"), loc);
1591 mlirOperationStateAddResults(&constOneState, 1, &indexType);
1592 mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
1593 MlirOperation constOne = mlirOperationCreate(&constOneState);
1594 MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
1595
1596 // Create the operation under test.
1597 mlirContextSetAllowUnregisteredDialects(ctx, true);
1598 MlirOperationState opState =
1599 mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
1600 MlirValue initialOperands[] = {constZeroValue};
1601 mlirOperationStateAddOperands(&opState, 1, initialOperands);
1602 MlirOperation op = mlirOperationCreate(&opState);
1603
1604 // Test operand APIs.
1605 intptr_t numOperands = mlirOperationGetNumOperands(op);
1606 fprintf(stderr, "Num Operands: %" PRIdPTR "\n", numOperands);
1607 // CHECK: Num Operands: 1
1608
1609 MlirValue opOperand = mlirOperationGetOperand(op, 0);
1610 fprintf(stderr, "Original operand: ");
1611 mlirValuePrint(opOperand, printToStderr, NULL);
1612 // CHECK: Original operand: {{.+}} constant 0 : index
1613
1614 mlirOperationSetOperand(op, 0, constOneValue);
1615 opOperand = mlirOperationGetOperand(op, 0);
1616 fprintf(stderr, "Updated operand: ");
1617 mlirValuePrint(opOperand, printToStderr, NULL);
1618 // CHECK: Updated operand: {{.+}} constant 1 : index
1619
1620 mlirOperationDestroy(op);
1621 mlirOperationDestroy(constZero);
1622 mlirOperationDestroy(constOne);
1623 mlirContextDestroy(ctx);
1624
1625 return 0;
1626 }
1627
1628 /// Tests clone APIs.
testClone()1629 int testClone() {
1630 fprintf(stderr, "@testClone\n");
1631 // CHECK-LABEL: @testClone
1632
1633 MlirContext ctx = mlirContextCreate();
1634 mlirRegisterAllDialects(ctx);
1635 mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("std"));
1636 MlirLocation loc = mlirLocationUnknownGet(ctx);
1637 MlirType indexType = mlirIndexTypeGet(ctx);
1638 MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
1639
1640 MlirAttribute indexZeroLiteral =
1641 mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
1642 MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
1643 MlirOperationState constZeroState = mlirOperationStateGet(
1644 mlirStringRefCreateFromCString("std.constant"), loc);
1645 mlirOperationStateAddResults(&constZeroState, 1, &indexType);
1646 mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
1647 MlirOperation constZero = mlirOperationCreate(&constZeroState);
1648
1649 MlirAttribute indexOneLiteral =
1650 mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
1651 MlirOperation constOne = mlirOperationClone(constZero);
1652 mlirOperationSetAttributeByName(constOne, valueStringRef, indexOneLiteral);
1653
1654 mlirOperationPrint(constZero, printToStderr, NULL);
1655 mlirOperationPrint(constOne, printToStderr, NULL);
1656 // CHECK: constant 0 : index
1657 // CHECK: constant 1 : index
1658
1659 return 0;
1660 }
1661
1662 // Wraps a diagnostic into additional text we can match against.
errorHandler(MlirDiagnostic diagnostic,void * userData)1663 MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
1664 fprintf(stderr, "processing diagnostic (userData: %" PRIdPTR ") <<\n",
1665 (intptr_t)userData);
1666 mlirDiagnosticPrint(diagnostic, printToStderr, NULL);
1667 fprintf(stderr, "\n");
1668 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
1669 mlirLocationPrint(loc, printToStderr, NULL);
1670 assert(mlirDiagnosticGetNumNotes(diagnostic) == 0);
1671 fprintf(stderr, "\n>> end of diagnostic (userData: %" PRIdPTR ")\n",
1672 (intptr_t)userData);
1673 return mlirLogicalResultSuccess();
1674 }
1675
// User-data deleter paired with errorHandler. It frees nothing; it only
// reports that it was invoked, so the test output shows when the handler's
// user data is released.
static void deleteUserData(void *userData) {
  const intptr_t tag = (intptr_t)userData;
  fprintf(stderr, "deleting user data (userData: %" PRIdPTR ")\n", tag);
}
1681
testDiagnostics()1682 void testDiagnostics() {
1683 MlirContext ctx = mlirContextCreate();
1684 MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
1685 ctx, errorHandler, (void *)42, deleteUserData);
1686 fprintf(stderr, "@test_diagnostics\n");
1687 MlirLocation unknownLoc = mlirLocationUnknownGet(ctx);
1688 mlirEmitError(unknownLoc, "test diagnostics");
1689 MlirLocation fileLineColLoc = mlirLocationFileLineColGet(
1690 ctx, mlirStringRefCreateFromCString("file.c"), 1, 2);
1691 mlirEmitError(fileLineColLoc, "test diagnostics");
1692 MlirLocation callSiteLoc = mlirLocationCallSiteGet(
1693 mlirLocationFileLineColGet(
1694 ctx, mlirStringRefCreateFromCString("other-file.c"), 2, 3),
1695 fileLineColLoc);
1696 mlirEmitError(callSiteLoc, "test diagnostics");
1697 mlirContextDetachDiagnosticHandler(ctx, id);
1698 mlirEmitError(unknownLoc, "more test diagnostics");
1699 // CHECK-LABEL: @test_diagnostics
1700 // CHECK: processing diagnostic (userData: 42) <<
1701 // CHECK: test diagnostics
1702 // CHECK: loc(unknown)
1703 // CHECK: >> end of diagnostic (userData: 42)
1704 // CHECK: processing diagnostic (userData: 42) <<
1705 // CHECK: test diagnostics
1706 // CHECK: loc("file.c":1:2)
1707 // CHECK: >> end of diagnostic (userData: 42)
1708 // CHECK: processing diagnostic (userData: 42) <<
1709 // CHECK: test diagnostics
1710 // CHECK: loc(callsite("other-file.c":2:3 at "file.c":1:2))
1711 // CHECK: >> end of diagnostic (userData: 42)
1712 // CHECK: deleting user data (userData: 42)
1713 // CHECK-NOT: processing diagnostic
1714 // CHECK: more test diagnostics
1715 }
1716
main()1717 int main() {
1718 MlirContext ctx = mlirContextCreate();
1719 mlirRegisterAllDialects(ctx);
1720 if (constructAndTraverseIr(ctx))
1721 return 1;
1722 buildWithInsertionsAndPrint(ctx);
1723 if (createOperationWithTypeInference(ctx))
1724 return 2;
1725
1726 if (printBuiltinTypes(ctx))
1727 return 3;
1728 if (printBuiltinAttributes(ctx))
1729 return 4;
1730 if (printAffineMap(ctx))
1731 return 5;
1732 if (printAffineExpr(ctx))
1733 return 6;
1734 if (affineMapFromExprs(ctx))
1735 return 7;
1736 if (printIntegerSet(ctx))
1737 return 8;
1738 if (registerOnlyStd())
1739 return 9;
1740 if (testBackreferences())
1741 return 10;
1742 if (testOperands())
1743 return 11;
1744 if (testClone())
1745 return 12;
1746
1747 mlirContextDestroy(ctx);
1748
1749 testDiagnostics();
1750 return 0;
1751 }
1752