1 //===- ir.c - Simple test of C APIs ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 /* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
11  */
12 
13 #include "mlir-c/IR.h"
14 #include "mlir-c/AffineExpr.h"
15 #include "mlir-c/AffineMap.h"
16 #include "mlir-c/BuiltinAttributes.h"
17 #include "mlir-c/BuiltinTypes.h"
18 #include "mlir-c/Diagnostics.h"
19 #include "mlir-c/Dialect/Standard.h"
20 #include "mlir-c/IntegerSet.h"
21 #include "mlir-c/Registration.h"
22 
23 #include <assert.h>
24 #include <inttypes.h>
25 #include <math.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 
populateLoopBody(MlirContext ctx,MlirBlock loopBody,MlirLocation location,MlirBlock funcBody)30 void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
31                       MlirLocation location, MlirBlock funcBody) {
32   MlirValue iv = mlirBlockGetArgument(loopBody, 0);
33   MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
34   MlirValue funcArg1 = mlirBlockGetArgument(funcBody, 1);
35   MlirType f32Type =
36       mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("f32"));
37 
38   MlirOperationState loadLHSState = mlirOperationStateGet(
39       mlirStringRefCreateFromCString("memref.load"), location);
40   MlirValue loadLHSOperands[] = {funcArg0, iv};
41   mlirOperationStateAddOperands(&loadLHSState, 2, loadLHSOperands);
42   mlirOperationStateAddResults(&loadLHSState, 1, &f32Type);
43   MlirOperation loadLHS = mlirOperationCreate(&loadLHSState);
44   mlirBlockAppendOwnedOperation(loopBody, loadLHS);
45 
46   MlirOperationState loadRHSState = mlirOperationStateGet(
47       mlirStringRefCreateFromCString("memref.load"), location);
48   MlirValue loadRHSOperands[] = {funcArg1, iv};
49   mlirOperationStateAddOperands(&loadRHSState, 2, loadRHSOperands);
50   mlirOperationStateAddResults(&loadRHSState, 1, &f32Type);
51   MlirOperation loadRHS = mlirOperationCreate(&loadRHSState);
52   mlirBlockAppendOwnedOperation(loopBody, loadRHS);
53 
54   MlirOperationState addState = mlirOperationStateGet(
55       mlirStringRefCreateFromCString("std.addf"), location);
56   MlirValue addOperands[] = {mlirOperationGetResult(loadLHS, 0),
57                              mlirOperationGetResult(loadRHS, 0)};
58   mlirOperationStateAddOperands(&addState, 2, addOperands);
59   mlirOperationStateAddResults(&addState, 1, &f32Type);
60   MlirOperation add = mlirOperationCreate(&addState);
61   mlirBlockAppendOwnedOperation(loopBody, add);
62 
63   MlirOperationState storeState = mlirOperationStateGet(
64       mlirStringRefCreateFromCString("memref.store"), location);
65   MlirValue storeOperands[] = {mlirOperationGetResult(add, 0), funcArg0, iv};
66   mlirOperationStateAddOperands(&storeState, 3, storeOperands);
67   MlirOperation store = mlirOperationCreate(&storeState);
68   mlirBlockAppendOwnedOperation(loopBody, store);
69 
70   MlirOperationState yieldState = mlirOperationStateGet(
71       mlirStringRefCreateFromCString("scf.yield"), location);
72   MlirOperation yield = mlirOperationCreate(&yieldState);
73   mlirBlockAppendOwnedOperation(loopBody, yield);
74 }
75 
makeAndDumpAdd(MlirContext ctx,MlirLocation location)76 MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
77   MlirModule moduleOp = mlirModuleCreateEmpty(location);
78   MlirBlock moduleBody = mlirModuleGetBody(moduleOp);
79 
80   MlirType memrefType =
81       mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("memref<?xf32>"));
82   MlirType funcBodyArgTypes[] = {memrefType, memrefType};
83   MlirRegion funcBodyRegion = mlirRegionCreate();
84   MlirBlock funcBody = mlirBlockCreate(
85       sizeof(funcBodyArgTypes) / sizeof(MlirType), funcBodyArgTypes);
86   mlirRegionAppendOwnedBlock(funcBodyRegion, funcBody);
87 
88   MlirAttribute funcTypeAttr = mlirAttributeParseGet(
89       ctx,
90       mlirStringRefCreateFromCString("(memref<?xf32>, memref<?xf32>) -> ()"));
91   MlirAttribute funcNameAttr =
92       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("\"add\""));
93   MlirNamedAttribute funcAttrs[] = {
94       mlirNamedAttributeGet(
95           mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("type")),
96           funcTypeAttr),
97       mlirNamedAttributeGet(
98           mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("sym_name")),
99           funcNameAttr)};
100   MlirOperationState funcState =
101       mlirOperationStateGet(mlirStringRefCreateFromCString("func"), location);
102   mlirOperationStateAddAttributes(&funcState, 2, funcAttrs);
103   mlirOperationStateAddOwnedRegions(&funcState, 1, &funcBodyRegion);
104   MlirOperation func = mlirOperationCreate(&funcState);
105   mlirBlockInsertOwnedOperation(moduleBody, 0, func);
106 
107   MlirType indexType =
108       mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
109   MlirAttribute indexZeroLiteral =
110       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
111   MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
112       mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
113       indexZeroLiteral);
114   MlirOperationState constZeroState = mlirOperationStateGet(
115       mlirStringRefCreateFromCString("std.constant"), location);
116   mlirOperationStateAddResults(&constZeroState, 1, &indexType);
117   mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
118   MlirOperation constZero = mlirOperationCreate(&constZeroState);
119   mlirBlockAppendOwnedOperation(funcBody, constZero);
120 
121   MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
122   MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
123   MlirValue dimOperands[] = {funcArg0, constZeroValue};
124   MlirOperationState dimState = mlirOperationStateGet(
125       mlirStringRefCreateFromCString("memref.dim"), location);
126   mlirOperationStateAddOperands(&dimState, 2, dimOperands);
127   mlirOperationStateAddResults(&dimState, 1, &indexType);
128   MlirOperation dim = mlirOperationCreate(&dimState);
129   mlirBlockAppendOwnedOperation(funcBody, dim);
130 
131   MlirRegion loopBodyRegion = mlirRegionCreate();
132   MlirBlock loopBody = mlirBlockCreate(0, NULL);
133   mlirBlockAddArgument(loopBody, indexType);
134   mlirRegionAppendOwnedBlock(loopBodyRegion, loopBody);
135 
136   MlirAttribute indexOneLiteral =
137       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
138   MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
139       mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
140       indexOneLiteral);
141   MlirOperationState constOneState = mlirOperationStateGet(
142       mlirStringRefCreateFromCString("std.constant"), location);
143   mlirOperationStateAddResults(&constOneState, 1, &indexType);
144   mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
145   MlirOperation constOne = mlirOperationCreate(&constOneState);
146   mlirBlockAppendOwnedOperation(funcBody, constOne);
147 
148   MlirValue dimValue = mlirOperationGetResult(dim, 0);
149   MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
150   MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
151   MlirOperationState loopState = mlirOperationStateGet(
152       mlirStringRefCreateFromCString("scf.for"), location);
153   mlirOperationStateAddOperands(&loopState, 3, loopOperands);
154   mlirOperationStateAddOwnedRegions(&loopState, 1, &loopBodyRegion);
155   MlirOperation loop = mlirOperationCreate(&loopState);
156   mlirBlockAppendOwnedOperation(funcBody, loop);
157 
158   populateLoopBody(ctx, loopBody, location, funcBody);
159 
160   MlirOperationState retState = mlirOperationStateGet(
161       mlirStringRefCreateFromCString("std.return"), location);
162   MlirOperation ret = mlirOperationCreate(&retState);
163   mlirBlockAppendOwnedOperation(funcBody, ret);
164 
165   MlirOperation module = mlirModuleGetOperation(moduleOp);
166   mlirOperationDump(module);
167   // clang-format off
168   // CHECK: module {
169   // CHECK:   func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>) {
170   // CHECK:     %[[C0:.*]] = constant 0 : index
171   // CHECK:     %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xf32>
172   // CHECK:     %[[C1:.*]] = constant 1 : index
173   // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
174   // CHECK:       %[[LHS:.*]] = memref.load %[[ARG0]][%[[I]]] : memref<?xf32>
175   // CHECK:       %[[RHS:.*]] = memref.load %[[ARG1]][%[[I]]] : memref<?xf32>
176   // CHECK:       %[[SUM:.*]] = addf %[[LHS]], %[[RHS]] : f32
177   // CHECK:       memref.store %[[SUM]], %[[ARG0]][%[[I]]] : memref<?xf32>
178   // CHECK:     }
179   // CHECK:     return
180   // CHECK:   }
181   // CHECK: }
182   // clang-format on
183 
184   return moduleOp;
185 }
186 
187 struct OpListNode {
188   MlirOperation op;
189   struct OpListNode *next;
190 };
191 typedef struct OpListNode OpListNode;
192 
/// Running totals gathered while walking an operation tree in collectStats.
typedef struct ModuleStats {
  unsigned numOperations;     // Operations visited, including the root.
  unsigned numAttributes;     // Attributes attached to those operations.
  unsigned numBlocks;         // Blocks found in all visited regions.
  unsigned numRegions;        // Regions owned by visited operations.
  unsigned numValues;         // Op results plus block arguments.
  unsigned numBlockArguments; // Block arguments that passed validation.
  unsigned numOpResults;      // Op results that passed validation.
} ModuleStats;
203 
collectStatsSingle(OpListNode * head,ModuleStats * stats)204 int collectStatsSingle(OpListNode *head, ModuleStats *stats) {
205   MlirOperation operation = head->op;
206   stats->numOperations += 1;
207   stats->numValues += mlirOperationGetNumResults(operation);
208   stats->numAttributes += mlirOperationGetNumAttributes(operation);
209 
210   unsigned numRegions = mlirOperationGetNumRegions(operation);
211 
212   stats->numRegions += numRegions;
213 
214   intptr_t numResults = mlirOperationGetNumResults(operation);
215   for (intptr_t i = 0; i < numResults; ++i) {
216     MlirValue result = mlirOperationGetResult(operation, i);
217     if (!mlirValueIsAOpResult(result))
218       return 1;
219     if (mlirValueIsABlockArgument(result))
220       return 2;
221     if (!mlirOperationEqual(operation, mlirOpResultGetOwner(result)))
222       return 3;
223     if (i != mlirOpResultGetResultNumber(result))
224       return 4;
225     ++stats->numOpResults;
226   }
227 
228   for (unsigned i = 0; i < numRegions; ++i) {
229     MlirRegion region = mlirOperationGetRegion(operation, i);
230     for (MlirBlock block = mlirRegionGetFirstBlock(region);
231          !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) {
232       ++stats->numBlocks;
233       intptr_t numArgs = mlirBlockGetNumArguments(block);
234       stats->numValues += numArgs;
235       for (intptr_t j = 0; j < numArgs; ++j) {
236         MlirValue arg = mlirBlockGetArgument(block, j);
237         if (!mlirValueIsABlockArgument(arg))
238           return 5;
239         if (mlirValueIsAOpResult(arg))
240           return 6;
241         if (!mlirBlockEqual(block, mlirBlockArgumentGetOwner(arg)))
242           return 7;
243         if (j != mlirBlockArgumentGetArgNumber(arg))
244           return 8;
245         ++stats->numBlockArguments;
246       }
247 
248       for (MlirOperation child = mlirBlockGetFirstOperation(block);
249            !mlirOperationIsNull(child);
250            child = mlirOperationGetNextInBlock(child)) {
251         OpListNode *node = malloc(sizeof(OpListNode));
252         node->op = child;
253         node->next = head->next;
254         head->next = node;
255       }
256     }
257   }
258   return 0;
259 }
260 
collectStats(MlirOperation operation)261 int collectStats(MlirOperation operation) {
262   OpListNode *head = malloc(sizeof(OpListNode));
263   head->op = operation;
264   head->next = NULL;
265 
266   ModuleStats stats;
267   stats.numOperations = 0;
268   stats.numAttributes = 0;
269   stats.numBlocks = 0;
270   stats.numRegions = 0;
271   stats.numValues = 0;
272   stats.numBlockArguments = 0;
273   stats.numOpResults = 0;
274 
275   do {
276     int retval = collectStatsSingle(head, &stats);
277     if (retval)
278       return retval;
279     OpListNode *next = head->next;
280     free(head);
281     head = next;
282   } while (head);
283 
284   if (stats.numValues != stats.numBlockArguments + stats.numOpResults)
285     return 100;
286 
287   fprintf(stderr, "@stats\n");
288   fprintf(stderr, "Number of operations: %u\n", stats.numOperations);
289   fprintf(stderr, "Number of attributes: %u\n", stats.numAttributes);
290   fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
291   fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
292   fprintf(stderr, "Number of values: %u\n", stats.numValues);
293   fprintf(stderr, "Number of block arguments: %u\n", stats.numBlockArguments);
294   fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
295   // clang-format off
296   // CHECK-LABEL: @stats
297   // CHECK: Number of operations: 12
298   // CHECK: Number of attributes: 4
299   // CHECK: Number of blocks: 3
300   // CHECK: Number of regions: 3
301   // CHECK: Number of values: 9
302   // CHECK: Number of block arguments: 3
303   // CHECK: Number of op results: 6
304   // clang-format on
305   return 0;
306 }
307 
printToStderr(MlirStringRef str,void * userData)308 static void printToStderr(MlirStringRef str, void *userData) {
309   (void)userData;
310   fwrite(str.data, 1, str.length, stderr);
311 }
312 
printFirstOfEach(MlirContext ctx,MlirOperation operation)313 static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
314   // Assuming we are given a module, go to the first operation of the first
315   // function.
316   MlirRegion region = mlirOperationGetRegion(operation, 0);
317   MlirBlock block = mlirRegionGetFirstBlock(region);
318   operation = mlirBlockGetFirstOperation(block);
319   region = mlirOperationGetRegion(operation, 0);
320   MlirOperation parentOperation = operation;
321   block = mlirRegionGetFirstBlock(region);
322   operation = mlirBlockGetFirstOperation(block);
323   assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));
324 
325   // Verify that parent operation and block report correctly.
326   fprintf(stderr, "Parent operation eq: %d\n",
327           mlirOperationEqual(mlirOperationGetParentOperation(operation),
328                              parentOperation));
329   fprintf(stderr, "Block eq: %d\n",
330           mlirBlockEqual(mlirOperationGetBlock(operation), block));
331   // CHECK: Parent operation eq: 1
332   // CHECK: Block eq: 1
333 
334   // In the module we created, the first operation of the first function is
335   // an "memref.dim", which has an attribute and a single result that we can
336   // use to test the printing mechanism.
337   mlirBlockPrint(block, printToStderr, NULL);
338   fprintf(stderr, "\n");
339   fprintf(stderr, "First operation: ");
340   mlirOperationPrint(operation, printToStderr, NULL);
341   fprintf(stderr, "\n");
342   // clang-format off
343   // CHECK:   %[[C0:.*]] = constant 0 : index
344   // CHECK:   %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32>
345   // CHECK:   %[[C1:.*]] = constant 1 : index
346   // CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
347   // CHECK:     %[[LHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32>
348   // CHECK:     %[[RHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32>
349   // CHECK:     %[[SUM:.*]] = addf %[[LHS]], %[[RHS]] : f32
350   // CHECK:     memref.store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
351   // CHECK:   }
352   // CHECK: return
353   // CHECK: First operation: {{.*}} = constant 0 : index
354   // clang-format on
355 
356   // Get the operation name and print it.
357   MlirIdentifier ident = mlirOperationGetName(operation);
358   MlirStringRef identStr = mlirIdentifierStr(ident);
359   fprintf(stderr, "Operation name: '");
360   for (size_t i = 0; i < identStr.length; ++i)
361     fputc(identStr.data[i], stderr);
362   fprintf(stderr, "'\n");
363   // CHECK: Operation name: 'std.constant'
364 
365   // Get the identifier again and verify equal.
366   MlirIdentifier identAgain = mlirIdentifierGet(ctx, identStr);
367   fprintf(stderr, "Identifier equal: %d\n",
368           mlirIdentifierEqual(ident, identAgain));
369   // CHECK: Identifier equal: 1
370 
371   // Get the block terminator and print it.
372   MlirOperation terminator = mlirBlockGetTerminator(block);
373   fprintf(stderr, "Terminator: ");
374   mlirOperationPrint(terminator, printToStderr, NULL);
375   fprintf(stderr, "\n");
376   // CHECK: Terminator: return
377 
378   // Get the attribute by index.
379   MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
380   fprintf(stderr, "Get attr 0: ");
381   mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
382   fprintf(stderr, "\n");
383   // CHECK: Get attr 0: 0 : index
384 
385   // Now re-get the attribute by name.
386   MlirAttribute attr0ByName = mlirOperationGetAttributeByName(
387       operation, mlirIdentifierStr(namedAttr0.name));
388   fprintf(stderr, "Get attr 0 by name: ");
389   mlirAttributePrint(attr0ByName, printToStderr, NULL);
390   fprintf(stderr, "\n");
391   // CHECK: Get attr 0 by name: 0 : index
392 
393   // Get a non-existing attribute and assert that it is null (sanity).
394   fprintf(stderr, "does_not_exist is null: %d\n",
395           mlirAttributeIsNull(mlirOperationGetAttributeByName(
396               operation, mlirStringRefCreateFromCString("does_not_exist"))));
397   // CHECK: does_not_exist is null: 1
398 
399   // Get result 0 and its type.
400   MlirValue value = mlirOperationGetResult(operation, 0);
401   fprintf(stderr, "Result 0: ");
402   mlirValuePrint(value, printToStderr, NULL);
403   fprintf(stderr, "\n");
404   fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value));
405   // CHECK: Result 0: {{.*}} = constant 0 : index
406   // CHECK: Value is null: 0
407 
408   MlirType type = mlirValueGetType(value);
409   fprintf(stderr, "Result 0 type: ");
410   mlirTypePrint(type, printToStderr, NULL);
411   fprintf(stderr, "\n");
412   // CHECK: Result 0 type: index
413 
414   // Set a custom attribute.
415   mlirOperationSetAttributeByName(operation,
416                                   mlirStringRefCreateFromCString("custom_attr"),
417                                   mlirBoolAttrGet(ctx, 1));
418   fprintf(stderr, "Op with set attr: ");
419   mlirOperationPrint(operation, printToStderr, NULL);
420   fprintf(stderr, "\n");
421   // CHECK: Op with set attr: {{.*}} {custom_attr = true}
422 
423   // Remove the attribute.
424   fprintf(stderr, "Remove attr: %d\n",
425           mlirOperationRemoveAttributeByName(
426               operation, mlirStringRefCreateFromCString("custom_attr")));
427   fprintf(stderr, "Remove attr again: %d\n",
428           mlirOperationRemoveAttributeByName(
429               operation, mlirStringRefCreateFromCString("custom_attr")));
430   fprintf(stderr, "Removed attr is null: %d\n",
431           mlirAttributeIsNull(mlirOperationGetAttributeByName(
432               operation, mlirStringRefCreateFromCString("custom_attr"))));
433   // CHECK: Remove attr: 1
434   // CHECK: Remove attr again: 0
435   // CHECK: Removed attr is null: 1
436 
437   // Add a large attribute to verify printing flags.
438   int64_t eltsShape[] = {4};
439   int32_t eltsData[] = {1, 2, 3, 4};
440   mlirOperationSetAttributeByName(
441       operation, mlirStringRefCreateFromCString("elts"),
442       mlirDenseElementsAttrInt32Get(
443           mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
444                                   mlirAttributeGetNull()), 4, eltsData));
445   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
446   mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
447   mlirOpPrintingFlagsPrintGenericOpForm(flags);
448   mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/0);
449   mlirOpPrintingFlagsUseLocalScope(flags);
450   fprintf(stderr, "Op print with all flags: ");
451   mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL);
452   fprintf(stderr, "\n");
453   // clang-format off
454   // CHECK: Op print with all flags: %{{.*}} = "std.constant"() {elts = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown)
455   // clang-format on
456 
457   mlirOpPrintingFlagsDestroy(flags);
458 }
459 
constructAndTraverseIr(MlirContext ctx)460 static int constructAndTraverseIr(MlirContext ctx) {
461   MlirLocation location = mlirLocationUnknownGet(ctx);
462 
463   MlirModule moduleOp = makeAndDumpAdd(ctx, location);
464   MlirOperation module = mlirModuleGetOperation(moduleOp);
465   assert(!mlirModuleIsNull(mlirModuleFromOperation(module)));
466 
467   int errcode = collectStats(module);
468   if (errcode)
469     return errcode;
470 
471   printFirstOfEach(ctx, module);
472 
473   mlirModuleDestroy(moduleOp);
474   return 0;
475 }
476 
477 /// Creates an operation with a region containing multiple blocks with
478 /// operations and dumps it. The blocks and operations are inserted using
479 /// block/operation-relative API and their final order is checked.
buildWithInsertionsAndPrint(MlirContext ctx)480 static void buildWithInsertionsAndPrint(MlirContext ctx) {
481   MlirLocation loc = mlirLocationUnknownGet(ctx);
482   mlirContextSetAllowUnregisteredDialects(ctx, true);
483 
484   MlirRegion owningRegion = mlirRegionCreate();
485   MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion);
486   MlirOperationState state = mlirOperationStateGet(
487       mlirStringRefCreateFromCString("insertion.order.test"), loc);
488   mlirOperationStateAddOwnedRegions(&state, 1, &owningRegion);
489   MlirOperation op = mlirOperationCreate(&state);
490   MlirRegion region = mlirOperationGetRegion(op, 0);
491 
492   // Use integer types of different bitwidth as block arguments in order to
493   // differentiate blocks.
494   MlirType i1 = mlirIntegerTypeGet(ctx, 1);
495   MlirType i2 = mlirIntegerTypeGet(ctx, 2);
496   MlirType i3 = mlirIntegerTypeGet(ctx, 3);
497   MlirType i4 = mlirIntegerTypeGet(ctx, 4);
498   MlirBlock block1 = mlirBlockCreate(1, &i1);
499   MlirBlock block2 = mlirBlockCreate(1, &i2);
500   MlirBlock block3 = mlirBlockCreate(1, &i3);
501   MlirBlock block4 = mlirBlockCreate(1, &i4);
502   // Insert blocks so as to obtain the 1-2-3-4 order,
503   mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3);
504   mlirRegionInsertOwnedBlockBefore(region, block3, block2);
505   mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
506   mlirRegionInsertOwnedBlockAfter(region, block3, block4);
507 
508   MlirOperationState op1State =
509       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op1"), loc);
510   MlirOperationState op2State =
511       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
512   MlirOperationState op3State =
513       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op3"), loc);
514   MlirOperationState op4State =
515       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op4"), loc);
516   MlirOperationState op5State =
517       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op5"), loc);
518   MlirOperationState op6State =
519       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op6"), loc);
520   MlirOperationState op7State =
521       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op7"), loc);
522   MlirOperation op1 = mlirOperationCreate(&op1State);
523   MlirOperation op2 = mlirOperationCreate(&op2State);
524   MlirOperation op3 = mlirOperationCreate(&op3State);
525   MlirOperation op4 = mlirOperationCreate(&op4State);
526   MlirOperation op5 = mlirOperationCreate(&op5State);
527   MlirOperation op6 = mlirOperationCreate(&op6State);
528   MlirOperation op7 = mlirOperationCreate(&op7State);
529 
530   // Insert operations in the first block so as to obtain the 1-2-3-4 order.
531   MlirOperation nullOperation = mlirBlockGetFirstOperation(block1);
532   assert(mlirOperationIsNull(nullOperation));
533   mlirBlockInsertOwnedOperationBefore(block1, nullOperation, op3);
534   mlirBlockInsertOwnedOperationBefore(block1, op3, op2);
535   mlirBlockInsertOwnedOperationAfter(block1, nullOperation, op1);
536   mlirBlockInsertOwnedOperationAfter(block1, op3, op4);
537 
538   // Append operations to the rest of blocks to make them non-empty and thus
539   // printable.
540   mlirBlockAppendOwnedOperation(block2, op5);
541   mlirBlockAppendOwnedOperation(block3, op6);
542   mlirBlockAppendOwnedOperation(block4, op7);
543 
544   mlirOperationDump(op);
545   mlirOperationDestroy(op);
546   mlirContextSetAllowUnregisteredDialects(ctx, false);
547   // clang-format off
548   // CHECK-LABEL:  "insertion.order.test"
549   // CHECK:      ^{{.*}}(%{{.*}}: i1
550   // CHECK:        "dummy.op1"
551   // CHECK-NEXT:   "dummy.op2"
552   // CHECK-NEXT:   "dummy.op3"
553   // CHECK-NEXT:   "dummy.op4"
554   // CHECK:      ^{{.*}}(%{{.*}}: i2
555   // CHECK:        "dummy.op5"
556   // CHECK:      ^{{.*}}(%{{.*}}: i3
557   // CHECK:        "dummy.op6"
558   // CHECK:      ^{{.*}}(%{{.*}}: i4
559   // CHECK:        "dummy.op7"
560   // clang-format on
561 }
562 
563 /// Creates operations with type inference and tests various failure modes.
createOperationWithTypeInference(MlirContext ctx)564 static int createOperationWithTypeInference(MlirContext ctx) {
565   MlirLocation loc = mlirLocationUnknownGet(ctx);
566   MlirAttribute iAttr = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 4);
567 
568   // The shape.const_size op implements result type inference and is only used
569   // for that reason.
570   MlirOperationState state = mlirOperationStateGet(
571       mlirStringRefCreateFromCString("shape.const_size"), loc);
572   MlirNamedAttribute valueAttr = mlirNamedAttributeGet(
573       mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), iAttr);
574   mlirOperationStateAddAttributes(&state, 1, &valueAttr);
575   mlirOperationStateEnableResultTypeInference(&state);
576 
577   // Expect result type inference to succeed.
578   MlirOperation op = mlirOperationCreate(&state);
579   if (mlirOperationIsNull(op)) {
580     fprintf(stderr, "ERROR: Result type inference unexpectedly failed");
581     return 1;
582   }
583 
584   // CHECK: RESULT_TYPE_INFERENCE: !shape.size
585   fprintf(stderr, "RESULT_TYPE_INFERENCE: ");
586   mlirTypeDump(mlirValueGetType(mlirOperationGetResult(op, 0)));
587   fprintf(stderr, "\n");
588   mlirOperationDestroy(op);
589   return 0;
590 }
591 
/// Constructs one instance of each builtin type through the C API, checks it
/// with the matching mlirTypeIsA* predicate and type-specific accessors, and
/// dumps it to stderr (after an "@types" label line). Together these calls
/// cover the C API for builtin types.
///
/// Returns 0 when every check passes; otherwise the non-zero code of the first
/// failing check (1-24; code 8 is currently not used).
static int printBuiltinTypes(MlirContext ctx) {
  // Integer types. The signed and unsigned variants must be distinct from the
  // signless i32 while reporting the same bit width (checked for si32).
  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
  MlirType si32 = mlirIntegerTypeSignedGet(ctx, 32);
  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
  if (!mlirTypeIsAInteger(i32) || mlirTypeIsAF32(i32))
    return 1;
  if (!mlirTypeIsAInteger(si32) || !mlirIntegerTypeIsSigned(si32))
    return 2;
  if (!mlirTypeIsAInteger(ui32) || !mlirIntegerTypeIsUnsigned(ui32))
    return 3;
  if (mlirTypeEqual(i32, ui32) || mlirTypeEqual(i32, si32))
    return 4;
  if (mlirIntegerTypeGetWidth(i32) != mlirIntegerTypeGetWidth(si32))
    return 5;
  fprintf(stderr, "@types\n");
  mlirTypeDump(i32);
  fprintf(stderr, "\n");
  mlirTypeDump(si32);
  fprintf(stderr, "\n");
  mlirTypeDump(ui32);
  fprintf(stderr, "\n");
  // CHECK-LABEL: @types
  // CHECK: i32
  // CHECK: si32
  // CHECK: ui32

  // Index type.
  MlirType index = mlirIndexTypeGet(ctx);
  if (!mlirTypeIsAIndex(index))
    return 6;
  mlirTypeDump(index);
  fprintf(stderr, "\n");
  // CHECK: index

  // Floating-point types.
  MlirType bf16 = mlirBF16TypeGet(ctx);
  MlirType f16 = mlirF16TypeGet(ctx);
  MlirType f32 = mlirF32TypeGet(ctx);
  MlirType f64 = mlirF64TypeGet(ctx);
  if (!mlirTypeIsABF16(bf16))
    return 7;
  if (!mlirTypeIsAF16(f16))
    return 9;
  if (!mlirTypeIsAF32(f32))
    return 10;
  if (!mlirTypeIsAF64(f64))
    return 11;
  mlirTypeDump(bf16);
  fprintf(stderr, "\n");
  mlirTypeDump(f16);
  fprintf(stderr, "\n");
  mlirTypeDump(f32);
  fprintf(stderr, "\n");
  mlirTypeDump(f64);
  fprintf(stderr, "\n");
  // CHECK: bf16
  // CHECK: f16
  // CHECK: f32
  // CHECK: f64

  // None type.
  MlirType none = mlirNoneTypeGet(ctx);
  if (!mlirTypeIsANone(none))
    return 12;
  mlirTypeDump(none);
  fprintf(stderr, "\n");
  // CHECK: none

  // Complex type; its element type must round-trip.
  MlirType cplx = mlirComplexTypeGet(f32);
  if (!mlirTypeIsAComplex(cplx) ||
      !mlirTypeEqual(mlirComplexTypeGetElementType(cplx), f32))
    return 13;
  mlirTypeDump(cplx);
  fprintf(stderr, "\n");
  // CHECK: complex<f32>

  // Vector (and Shaped) type. ShapedType is a common base class for vectors,
  // memrefs and tensors, one cannot create instances of this class so it is
  // tested on an instance of vector type.
  int64_t shape[] = {2, 3};
  MlirType vector =
      mlirVectorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
  if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector))
    return 14;
  if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) ||
      !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 ||
      mlirShapedTypeGetDimSize(vector, 0) != 2 ||
      mlirShapedTypeIsDynamicDim(vector, 0) ||
      mlirShapedTypeGetDimSize(vector, 1) != 3 ||
      !mlirShapedTypeHasStaticShape(vector))
    return 15;
  mlirTypeDump(vector);
  fprintf(stderr, "\n");
  // CHECK: vector<2x3xf32>

  // Ranked tensor type, created without an encoding (null attribute), which
  // must read back as null.
  MlirType rankedTensor = mlirRankedTensorTypeGet(
      sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
  if (!mlirTypeIsATensor(rankedTensor) ||
      !mlirTypeIsARankedTensor(rankedTensor) ||
      !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
    return 16;
  mlirTypeDump(rankedTensor);
  fprintf(stderr, "\n");
  // CHECK: tensor<2x3xf32>

  // Unranked tensor type.
  MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32);
  if (!mlirTypeIsATensor(unrankedTensor) ||
      !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
      mlirShapedTypeHasRank(unrankedTensor))
    return 17;
  mlirTypeDump(unrankedTensor);
  fprintf(stderr, "\n");
  // CHECK: tensor<*xf32>

  // MemRef type in memory space 2; the contiguous variant is expected to carry
  // no affine maps.
  MlirAttribute memSpace2 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 2);
  MlirType memRef = mlirMemRefTypeContiguousGet(
      f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
  if (!mlirTypeIsAMemRef(memRef) ||
      mlirMemRefTypeGetNumAffineMaps(memRef) != 0 ||
      !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
    return 18;
  mlirTypeDump(memRef);
  fprintf(stderr, "\n");
  // CHECK: memref<2x3xf32, 2>

  // Unranked MemRef type in memory space 4; it must not be classified as a
  // ranked memref.
  MlirAttribute memSpace4 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 4);
  MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, memSpace4);
  if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) ||
      mlirTypeIsAMemRef(unrankedMemRef) ||
      !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
                          memSpace4))
    return 19;
  mlirTypeDump(unrankedMemRef);
  fprintf(stderr, "\n");
  // CHECK: memref<*xf32, 4>

  // Tuple type; element count and element types must round-trip.
  MlirType types[] = {unrankedMemRef, f32};
  MlirType tuple = mlirTupleTypeGet(ctx, 2, types);
  if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
    return 20;
  mlirTypeDump(tuple);
  fprintf(stderr, "\n");
  // CHECK: tuple<memref<*xf32, 4>, f32>

  // Function type with two inputs and three results; counts and each
  // input/result type must round-trip.
  MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, 1)};
  MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, 16),
                             mlirIntegerTypeGet(ctx, 32),
                             mlirIntegerTypeGet(ctx, 64)};
  MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
  if (mlirFunctionTypeGetNumInputs(funcType) != 2)
    return 21;
  if (mlirFunctionTypeGetNumResults(funcType) != 3)
    return 22;
  if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
      !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
    return 23;
  if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
      !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
      !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
    return 24;
  mlirTypeDump(funcType);
  fprintf(stderr, "\n");
  // CHECK: (index, i1) -> (i16, i32, i64)

  return 0;
}
773 
/// String callback that copies exactly `len` bytes of `data` into the
/// caller-provided buffer passed as `userData`.
///
/// `data` is a length-delimited chunk that need not be NUL-terminated, so it
/// is copied with memcpy: strncpy would stop at an embedded NUL and zero-pad
/// the remainder instead of copying the bytes verbatim. No terminator is
/// written; the caller must supply at least `len` bytes of storage and
/// terminate the buffer itself if it needs a C string. Every call writes from
/// the start of the buffer, so presumably callers expect the whole string to
/// arrive in a single invocation.
void callbackSetFixedLengthString(const char *data, intptr_t len,
                                  void *userData) {
  // A negative length would convert to an enormous size_t; treat it (and an
  // empty chunk) as nothing to copy.
  if (len <= 0)
    return;
  memcpy(userData, data, (size_t)len);
}
778 
/// Compares the NUL-terminated C string `lhs` with the length-delimited
/// string `rhs`. The two are equal only when `lhs` has exactly `rhs.length`
/// characters and those characters match `rhs.data` byte for byte.
bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
  const size_t lhsLength = strlen(lhs);
  const bool sameLength = (lhsLength == rhs.length);
  // Only compare contents once the lengths are known to agree.
  return sameLength && strncmp(lhs, rhs.data, rhs.length) == 0;
}
785 
/// Creates one attribute of each builtin kind exercised through the C API:
/// float, integer, bool, opaque, string, symbol references, type, unit, dense
/// and splat elements, and sparse elements. Checks the accessors on each and
/// dumps them to stderr for FileCheck. Returns 0 on success or a distinct
/// non-zero code identifying the first failing check.
int printBuiltinAttributes(MlirContext ctx) {
  MlirAttribute floating =
      mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
  if (!mlirAttributeIsAFloat(floating) ||
      fabs(mlirFloatAttrGetValueDouble(floating) - 2.0) > 1E-6)
    return 1;
  fprintf(stderr, "@attrs\n");
  mlirAttributeDump(floating);
  // CHECK-LABEL: @attrs
  // CHECK: 2.000000e+00 : f64

  // Exercise mlirAttributeGetType() just for the first one.
  MlirType floatingType = mlirAttributeGetType(floating);
  mlirTypeDump(floatingType);
  // CHECK: f64

  MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
  if (!mlirAttributeIsAInteger(integer) ||
      mlirIntegerAttrGetValueInt(integer) != 42)
    return 2;
  mlirAttributeDump(integer);
  // CHECK: 42 : i32

  MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
  if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))
    return 3;
  mlirAttributeDump(boolean);
  // CHECK: true

  // Shared character pool: the opaque, string and symbol-reference
  // attributes below are each built from a different slice of it.
  const char data[] = "abcdefghijklmnopqestuvwxyz";
  // Opaque attribute in the "std" namespace whose payload is "abc".
  MlirAttribute opaque =
      mlirOpaqueAttrGet(ctx, mlirStringRefCreateFromCString("std"), 3, data,
                        mlirNoneTypeGet(ctx));
  if (!mlirAttributeIsAOpaque(opaque) ||
      !stringIsEqual("std", mlirOpaqueAttrGetDialectNamespace(opaque)))
    return 4;

  MlirStringRef opaqueData = mlirOpaqueAttrGetData(opaque);
  if (opaqueData.length != 3 ||
      strncmp(data, opaqueData.data, opaqueData.length))
    return 5;
  mlirAttributeDump(opaque);
  // CHECK: #std.abc

  // String attribute holding "de" (data[3..4]).
  MlirAttribute string =
      mlirStringAttrGet(ctx, mlirStringRefCreate(data + 3, 2));
  if (!mlirAttributeIsAString(string))
    return 6;

  MlirStringRef stringValue = mlirStringAttrGetValue(string);
  if (stringValue.length != 2 ||
      strncmp(data + 3, stringValue.data, stringValue.length))
    return 7;
  mlirAttributeDump(string);
  // CHECK: "de"

  // Flat symbol reference @fgh (data[5..7]).
  MlirAttribute flatSymbolRef =
      mlirFlatSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 5, 3));
  if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
    return 8;

  MlirStringRef flatSymbolRefValue =
      mlirFlatSymbolRefAttrGetValue(flatSymbolRef);
  if (flatSymbolRefValue.length != 3 ||
      strncmp(data + 5, flatSymbolRefValue.data, flatSymbolRefValue.length))
    return 9;
  mlirAttributeDump(flatSymbolRef);
  // CHECK: @fgh

  // Nested symbol reference with root @ij (data[8..9]) followed by two
  // copies of @fgh. The leaf is therefore "fgh" and the root is "ij".
  MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
  MlirAttribute symbolRef =
      mlirSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 8, 2), 2, symbols);
  if (!mlirAttributeIsASymbolRef(symbolRef) ||
      mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),
                          flatSymbolRef) ||
      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 1),
                          flatSymbolRef))
    return 10;

  MlirStringRef symbolRefLeaf = mlirSymbolRefAttrGetLeafReference(symbolRef);
  MlirStringRef symbolRefRoot = mlirSymbolRefAttrGetRootReference(symbolRef);
  if (symbolRefLeaf.length != 3 ||
      strncmp(data + 5, symbolRefLeaf.data, symbolRefLeaf.length) ||
      symbolRefRoot.length != 2 ||
      strncmp(data + 8, symbolRefRoot.data, symbolRefRoot.length))
    return 11;
  mlirAttributeDump(symbolRef);
  // CHECK: @ij::@fgh::@fgh

  MlirAttribute type = mlirTypeAttrGet(mlirF32TypeGet(ctx));
  if (!mlirAttributeIsAType(type) ||
      !mlirTypeEqual(mlirF32TypeGet(ctx), mlirTypeAttrGetValue(type)))
    return 12;
  mlirAttributeDump(type);
  // CHECK: f32

  MlirAttribute unit = mlirUnitAttrGet(ctx);
  if (!mlirAttributeIsAUnit(unit))
    return 13;
  mlirAttributeDump(unit);
  // CHECK: unit

  // All dense, splat and sparse element attributes below use a 1x2 tensor.
  int64_t shape[] = {1, 2};

  // Two-element payloads {0, 1} for each supported element type.
  int bools[] = {0, 1};
  uint8_t uints8[] = {0u, 1u};
  int8_t ints8[] = {0, 1};
  uint32_t uints32[] = {0u, 1u};
  int32_t ints32[] = {0, 1};
  uint64_t uints64[] = {0u, 1u};
  int64_t ints64[] = {0, 1};
  float floats[] = {0.0f, 1.0f};
  double doubles[] = {0.0, 1.0};
  // Tensor types are created without an encoding attribute.
  MlirAttribute encoding = mlirAttributeGetNull();
  MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
      2, bools);
  MlirAttribute uint8Elements = mlirDenseElementsAttrUInt8Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
                              encoding),
      2, uints8);
  MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
      2, ints8);
  MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
      mlirRankedTensorTypeGet(2, shape,
                              mlirIntegerTypeUnsignedGet(ctx, 32), encoding),
      2, uints32);
  MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
      2, ints32);
  MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
      mlirRankedTensorTypeGet(2, shape,
                              mlirIntegerTypeUnsignedGet(ctx, 64), encoding),
      2, uints64);
  MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      2, ints64);
  MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
      2, floats);
  MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding),
      2, doubles);

  if (!mlirAttributeIsADenseElements(boolElements) ||
      !mlirAttributeIsADenseElements(uint8Elements) ||
      !mlirAttributeIsADenseElements(int8Elements) ||
      !mlirAttributeIsADenseElements(uint32Elements) ||
      !mlirAttributeIsADenseElements(int32Elements) ||
      !mlirAttributeIsADenseElements(uint64Elements) ||
      !mlirAttributeIsADenseElements(int64Elements) ||
      !mlirAttributeIsADenseElements(floatElements) ||
      !mlirAttributeIsADenseElements(doubleElements))
    return 14;

  // Read back the second element (value 1) through the typed accessors.
  if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt64Value(int64Elements, 1) != 1 ||
      fabsf(mlirDenseElementsAttrGetFloatValue(floatElements, 1) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseElementsAttrGetDoubleValue(doubleElements, 1) - 1.0) > 1E-6)
    return 15;

  mlirAttributeDump(boolElements);
  mlirAttributeDump(uint8Elements);
  mlirAttributeDump(int8Elements);
  mlirAttributeDump(uint32Elements);
  mlirAttributeDump(int32Elements);
  mlirAttributeDump(uint64Elements);
  mlirAttributeDump(int64Elements);
  mlirAttributeDump(floatElements);
  mlirAttributeDump(doubleElements);
  // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>

  // Splat variants: every element of the 1x2 tensor is 1 (or true).
  MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
      1);
  MlirAttribute splatUInt8 = mlirDenseElementsAttrUInt8SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
                              encoding),
      1);
  MlirAttribute splatInt8 = mlirDenseElementsAttrInt8SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
      1);
  MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                              encoding),
      1);
  MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
      1);
  MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
                              encoding),
      1);
  MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      1);
  MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 1.0f);
  MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 1.0);

  if (!mlirAttributeIsADenseElements(splatBool) ||
      !mlirDenseElementsAttrIsSplat(splatBool) ||
      !mlirAttributeIsADenseElements(splatUInt8) ||
      !mlirDenseElementsAttrIsSplat(splatUInt8) ||
      !mlirAttributeIsADenseElements(splatInt8) ||
      !mlirDenseElementsAttrIsSplat(splatInt8) ||
      !mlirAttributeIsADenseElements(splatUInt32) ||
      !mlirDenseElementsAttrIsSplat(splatUInt32) ||
      !mlirAttributeIsADenseElements(splatInt32) ||
      !mlirDenseElementsAttrIsSplat(splatInt32) ||
      !mlirAttributeIsADenseElements(splatUInt64) ||
      !mlirDenseElementsAttrIsSplat(splatUInt64) ||
      !mlirAttributeIsADenseElements(splatInt64) ||
      !mlirDenseElementsAttrIsSplat(splatInt64) ||
      !mlirAttributeIsADenseElements(splatFloat) ||
      !mlirDenseElementsAttrIsSplat(splatFloat) ||
      !mlirAttributeIsADenseElements(splatDouble) ||
      !mlirDenseElementsAttrIsSplat(splatDouble))
    return 16;

  if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
      mlirDenseElementsAttrGetUInt8SplatValue(splatUInt8) != 1 ||
      mlirDenseElementsAttrGetInt8SplatValue(splatInt8) != 1 ||
      mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
      mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
      mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
      mlirDenseElementsAttrGetInt64SplatValue(splatInt64) != 1 ||
      fabsf(mlirDenseElementsAttrGetFloatSplatValue(splatFloat) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
    return 17;

  // Inspect the underlying storage of the non-splat attributes. The buffers
  // are only read, so keep the pointers const-qualified instead of casting
  // the raw-data pointer to mutable element pointers.
  const uint8_t *uint8RawData = mlirDenseElementsAttrGetRawData(uint8Elements);
  const int8_t *int8RawData = mlirDenseElementsAttrGetRawData(int8Elements);
  const uint32_t *uint32RawData =
      mlirDenseElementsAttrGetRawData(uint32Elements);
  const int32_t *int32RawData = mlirDenseElementsAttrGetRawData(int32Elements);
  const uint64_t *uint64RawData =
      mlirDenseElementsAttrGetRawData(uint64Elements);
  const int64_t *int64RawData = mlirDenseElementsAttrGetRawData(int64Elements);
  const float *floatRawData = mlirDenseElementsAttrGetRawData(floatElements);
  const double *doubleRawData =
      mlirDenseElementsAttrGetRawData(doubleElements);
  // Exact float comparisons are safe here: 0 and 1 are exactly representable.
  if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
      int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
      int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
      uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
      floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
      doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
    return 18;

  mlirAttributeDump(splatBool);
  mlirAttributeDump(splatUInt8);
  mlirAttributeDump(splatInt8);
  mlirAttributeDump(splatUInt32);
  mlirAttributeDump(splatInt32);
  mlirAttributeDump(splatUInt64);
  mlirAttributeDump(splatInt64);
  mlirAttributeDump(splatFloat);
  mlirAttributeDump(splatDouble);
  // CHECK: dense<true> : tensor<1x2xi1>
  // CHECK: dense<1> : tensor<1x2xui8>
  // CHECK: dense<1> : tensor<1x2xi8>
  // CHECK: dense<1> : tensor<1x2xui32>
  // CHECK: dense<1> : tensor<1x2xi32>
  // CHECK: dense<1> : tensor<1x2xui64>
  // CHECK: dense<1> : tensor<1x2xi64>
  // CHECK: dense<1.000000e+00> : tensor<1x2xf32>
  // CHECK: dense<1.000000e+00> : tensor<1x2xf64>

  // Fetch element [0, 1] (reusing `uints64` = {0, 1} as the index list),
  // i.e. the second element, 1.0, of each 1x2 tensor.
  mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
  mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
  // CHECK: 1.000000e+00 : f32
  // CHECK: 1.000000e+00 : f64

  // Sparse attribute: the indices and values are each given as 1-D dense
  // tensors of two elements.
  int64_t indices[] = {4, 7};
  int64_t two = 2;
  MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
      mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64), encoding),
      2, indices);
  MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding),
      2, floats);
  MlirAttribute sparseAttr = mlirSparseElementsAttribute(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
      indicesAttr, valuesAttr);
  mlirAttributeDump(sparseAttr);
  // CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32>

  return 0;
}
1097 
/// Builds a representative set of affine maps and checks the C API predicates
/// and accessors on each. The maps are: empty, zero-result with 3 dims and
/// 2 symbols, constant, multi-dim identity, minor identity and permutation.
/// They are also dumped for FileCheck, and three sub-maps are extracted from
/// the identity map and dumped. Returns 0 on success or a distinct non-zero
/// code for the first failing check.
int printAffineMap(MlirContext ctx) {
  MlirAffineMap emptyAffineMap = mlirAffineMapEmptyGet(ctx);
  MlirAffineMap affineMap = mlirAffineMapZeroResultGet(ctx, 3, 2);
  MlirAffineMap constAffineMap = mlirAffineMapConstantGet(ctx, 2);
  MlirAffineMap multiDimIdentityAffineMap =
      mlirAffineMapMultiDimIdentityGet(ctx, 3);
  MlirAffineMap minorIdentityAffineMap =
      mlirAffineMapMinorIdentityGet(ctx, 3, 2);
  unsigned permutation[] = {1, 2, 0};
  MlirAffineMap permutationAffineMap = mlirAffineMapPermutationGet(
      ctx, sizeof(permutation) / sizeof(unsigned), permutation);

  fprintf(stderr, "@affineMap\n");
  mlirAffineMapDump(emptyAffineMap);
  mlirAffineMapDump(affineMap);
  mlirAffineMapDump(constAffineMap);
  mlirAffineMapDump(multiDimIdentityAffineMap);
  mlirAffineMapDump(minorIdentityAffineMap);
  mlirAffineMapDump(permutationAffineMap);
  // CHECK-LABEL: @affineMap
  // CHECK: () -> ()
  // CHECK: (d0, d1, d2)[s0, s1] -> ()
  // CHECK: () -> (2)
  // CHECK: (d0, d1, d2) -> (d0, d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2, d0)

  if (!mlirAffineMapIsIdentity(emptyAffineMap) ||
      mlirAffineMapIsIdentity(affineMap) ||
      mlirAffineMapIsIdentity(constAffineMap) ||
      !mlirAffineMapIsIdentity(multiDimIdentityAffineMap) ||
      mlirAffineMapIsIdentity(minorIdentityAffineMap) ||
      mlirAffineMapIsIdentity(permutationAffineMap))
    return 1;

  // Like the other predicate groups, cover all six maps. The constant map
  // () -> (2) has no dims to select, so it is not a minor identity.
  if (!mlirAffineMapIsMinorIdentity(emptyAffineMap) ||
      mlirAffineMapIsMinorIdentity(affineMap) ||
      mlirAffineMapIsMinorIdentity(constAffineMap) ||
      !mlirAffineMapIsMinorIdentity(multiDimIdentityAffineMap) ||
      !mlirAffineMapIsMinorIdentity(minorIdentityAffineMap) ||
      mlirAffineMapIsMinorIdentity(permutationAffineMap))
    return 2;

  if (!mlirAffineMapIsEmpty(emptyAffineMap) ||
      mlirAffineMapIsEmpty(affineMap) || mlirAffineMapIsEmpty(constAffineMap) ||
      mlirAffineMapIsEmpty(multiDimIdentityAffineMap) ||
      mlirAffineMapIsEmpty(minorIdentityAffineMap) ||
      mlirAffineMapIsEmpty(permutationAffineMap))
    return 3;

  if (mlirAffineMapIsSingleConstant(emptyAffineMap) ||
      mlirAffineMapIsSingleConstant(affineMap) ||
      !mlirAffineMapIsSingleConstant(constAffineMap) ||
      mlirAffineMapIsSingleConstant(multiDimIdentityAffineMap) ||
      mlirAffineMapIsSingleConstant(minorIdentityAffineMap) ||
      mlirAffineMapIsSingleConstant(permutationAffineMap))
    return 4;

  if (mlirAffineMapGetSingleConstantResult(constAffineMap) != 2)
    return 5;

  if (mlirAffineMapGetNumDims(emptyAffineMap) != 0 ||
      mlirAffineMapGetNumDims(affineMap) != 3 ||
      mlirAffineMapGetNumDims(constAffineMap) != 0 ||
      mlirAffineMapGetNumDims(multiDimIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumDims(minorIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumDims(permutationAffineMap) != 3)
    return 6;

  if (mlirAffineMapGetNumSymbols(emptyAffineMap) != 0 ||
      mlirAffineMapGetNumSymbols(affineMap) != 2 ||
      mlirAffineMapGetNumSymbols(constAffineMap) != 0 ||
      mlirAffineMapGetNumSymbols(multiDimIdentityAffineMap) != 0 ||
      mlirAffineMapGetNumSymbols(minorIdentityAffineMap) != 0 ||
      mlirAffineMapGetNumSymbols(permutationAffineMap) != 0)
    return 7;

  if (mlirAffineMapGetNumResults(emptyAffineMap) != 0 ||
      mlirAffineMapGetNumResults(affineMap) != 0 ||
      mlirAffineMapGetNumResults(constAffineMap) != 1 ||
      mlirAffineMapGetNumResults(multiDimIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumResults(minorIdentityAffineMap) != 2 ||
      mlirAffineMapGetNumResults(permutationAffineMap) != 3)
    return 8;

  // The input count is dims + symbols, e.g. 3 + 2 = 5 for `affineMap`.
  if (mlirAffineMapGetNumInputs(emptyAffineMap) != 0 ||
      mlirAffineMapGetNumInputs(affineMap) != 5 ||
      mlirAffineMapGetNumInputs(constAffineMap) != 0 ||
      mlirAffineMapGetNumInputs(multiDimIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumInputs(minorIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumInputs(permutationAffineMap) != 3)
    return 9;

  if (!mlirAffineMapIsProjectedPermutation(emptyAffineMap) ||
      !mlirAffineMapIsPermutation(emptyAffineMap) ||
      mlirAffineMapIsProjectedPermutation(affineMap) ||
      mlirAffineMapIsPermutation(affineMap) ||
      mlirAffineMapIsProjectedPermutation(constAffineMap) ||
      mlirAffineMapIsPermutation(constAffineMap) ||
      !mlirAffineMapIsProjectedPermutation(multiDimIdentityAffineMap) ||
      !mlirAffineMapIsPermutation(multiDimIdentityAffineMap) ||
      !mlirAffineMapIsProjectedPermutation(minorIdentityAffineMap) ||
      mlirAffineMapIsPermutation(minorIdentityAffineMap) ||
      !mlirAffineMapIsProjectedPermutation(permutationAffineMap) ||
      !mlirAffineMapIsPermutation(permutationAffineMap))
    return 10;

  // Extract sub-maps of the identity map: result #1 only, the first result
  // (major), and the last result (minor).
  intptr_t sub[] = {1};

  MlirAffineMap subMap = mlirAffineMapGetSubMap(
      multiDimIdentityAffineMap, sizeof(sub) / sizeof(intptr_t), sub);
  MlirAffineMap majorSubMap =
      mlirAffineMapGetMajorSubMap(multiDimIdentityAffineMap, 1);
  MlirAffineMap minorSubMap =
      mlirAffineMapGetMinorSubMap(multiDimIdentityAffineMap, 1);

  mlirAffineMapDump(subMap);
  mlirAffineMapDump(majorSubMap);
  mlirAffineMapDump(minorSubMap);
  // CHECK: (d0, d1, d2) -> (d1)
  // CHECK: (d0, d1, d2) -> (d0)
  // CHECK: (d0, d1, d2) -> (d2)

  return 0;
}
1222 
/// Exercises the affine-expression C API. It builds d5, s5, the constant 5
/// and one `d5 <op> s5` expression per binary operator, dumps them for
/// FileCheck, then checks every query accessor. Each failing check returns
/// its own code (1-19); 0 means every check passed.
int printAffineExpr(MlirContext ctx) {
  MlirAffineExpr dimExpr = mlirAffineDimExprGet(ctx, 5);
  MlirAffineExpr symExpr = mlirAffineSymbolExprGet(ctx, 5);
  MlirAffineExpr cstExpr = mlirAffineConstantExprGet(ctx, 5);
  MlirAffineExpr addExpr = mlirAffineAddExprGet(dimExpr, symExpr);
  MlirAffineExpr mulExpr = mlirAffineMulExprGet(dimExpr, symExpr);
  MlirAffineExpr modExpr = mlirAffineModExprGet(dimExpr, symExpr);
  MlirAffineExpr floorDivExpr = mlirAffineFloorDivExprGet(dimExpr, symExpr);
  MlirAffineExpr ceilDivExpr = mlirAffineCeilDivExprGet(dimExpr, symExpr);

  // Every expression under test. The expectation tables further down are
  // indexed in the same order.
  MlirAffineExpr exprs[] = {dimExpr, symExpr,      cstExpr,    addExpr,
                            mulExpr, modExpr, floorDivExpr, ceilDivExpr};
  const size_t numExprs = sizeof(exprs) / sizeof(exprs[0]);

  // Tests mlirAffineExprDump, one expression per line, in table order.
  fprintf(stderr, "@affineExpr\n");
  for (size_t i = 0; i < numExprs; ++i)
    mlirAffineExprDump(exprs[i]);
  // CHECK-LABEL: @affineExpr
  // CHECK: d5
  // CHECK: s5
  // CHECK: 5
  // CHECK: d5 + s5
  // CHECK: d5 * s5
  // CHECK: d5 mod s5
  // CHECK: d5 floordiv s5
  // CHECK: d5 ceildiv s5

  // Binary-operation operand accessors, exercised on the add expression.
  MlirAffineExpr addLhs = mlirAffineBinaryOpExprGetLHS(addExpr);
  MlirAffineExpr addRhs = mlirAffineBinaryOpExprGetRHS(addExpr);
  mlirAffineExprDump(addLhs);
  mlirAffineExprDump(addRhs);
  // CHECK: d5
  // CHECK: s5

  // Leaf accessors: dim position, symbol position and constant value.
  if (mlirAffineDimExprGetPosition(dimExpr) != 5)
    return 1;
  if (mlirAffineSymbolExprGetPosition(symExpr) != 5)
    return 2;
  if (mlirAffineConstantExprGetValue(cstExpr) != 5)
    return 3;

  // Expected query results, one entry per element of `exprs`:
  //                   dim    sym    cst    add    mul    mod    fdiv   cdiv
  static const bool expectSymbolicOrConstant[] = {
                       false, true,  true,  false, false, false, false, false};
  static const bool expectPureAffine[] = {
                       true,  true,  true,  true,  false, false, false, false};
  static const int64_t expectDivisor[] = {
                       1,     1,     5,     1,     1,     1,     1,     1};
  static const bool expectFunctionOfDim5[] = {
                       true,  false, false, true,  true,  true,  true,  true};

  for (size_t i = 0; i < numExprs; ++i)
    if (mlirAffineExprIsSymbolicOrConstant(exprs[i]) !=
        expectSymbolicOrConstant[i])
      return 4;

  for (size_t i = 0; i < numExprs; ++i)
    if (mlirAffineExprIsPureAffine(exprs[i]) != expectPureAffine[i])
      return 5;

  for (size_t i = 0; i < numExprs; ++i)
    if (mlirAffineExprGetLargestKnownDivisor(exprs[i]) != expectDivisor[i])
      return 6;

  // Each expression must be a multiple of its own largest known divisor.
  for (size_t i = 0; i < numExprs; ++i)
    if (!mlirAffineExprIsMultipleOf(exprs[i], expectDivisor[i]))
      return 7;

  for (size_t i = 0; i < numExprs; ++i)
    if (mlirAffineExprIsFunctionOfDim(exprs[i], 5) != expectFunctionOfDim5[i])
      return 8;

  // Kind predicates for the binary expressions.
  if (!mlirAffineExprIsAAdd(addExpr))
    return 9;
  if (!mlirAffineExprIsAMul(mulExpr))
    return 10;
  if (!mlirAffineExprIsAMod(modExpr))
    return 11;
  if (!mlirAffineExprIsAFloorDiv(floorDivExpr))
    return 12;
  if (!mlirAffineExprIsACeilDiv(ceilDivExpr))
    return 13;
  if (!mlirAffineExprIsABinary(addExpr))
    return 14;

  // Kind predicates for the leaf expressions.
  if (!mlirAffineExprIsAConstant(cstExpr))
    return 15;
  if (!mlirAffineExprIsADim(dimExpr))
    return 16;
  if (!mlirAffineExprIsASymbol(symExpr))
    return 17;

  // Re-requesting d5 must yield an equal expression, and it must be non-null.
  MlirAffineExpr dimExprAgain = mlirAffineDimExprGet(ctx, 5);
  if (!mlirAffineExprEqual(dimExpr, dimExprAgain))
    return 18;
  if (mlirAffineExprIsNull(dimExpr))
    return 19;

  return 0;
}
1367 
/// Builds the map (d0, d1, d2)[s0, s1, s2] -> (d0, s1) from explicit result
/// expressions and dumps it for FileCheck. Then checks that the results read
/// back are the expressions passed in. Returns 0 on success or a non-zero
/// code identifying the failing check.
int affineMapFromExprs(MlirContext ctx) {
  MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, 0);
  MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, 1);
  MlirAffineExpr exprs[] = {affineDimExpr, affineSymbolExpr};
  // Derive the result count from the array rather than repeating a literal.
  const intptr_t numExprs = (intptr_t)(sizeof(exprs) / sizeof(exprs[0]));
  MlirAffineMap map = mlirAffineMapGet(ctx, 3, 3, numExprs, exprs);

  // CHECK-LABEL: @affineMapFromExprs
  // End the label line like the other sections do, so the dump below starts
  // on its own line.
  fprintf(stderr, "@affineMapFromExprs\n");
  // CHECK: (d0, d1, d2)[s0, s1, s2] -> (d0, s1)
  mlirAffineMapDump(map);

  if (mlirAffineMapGetNumResults(map) != numExprs)
    return 1;

  if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 0), affineDimExpr))
    return 2;

  if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 1), affineSymbolExpr))
    return 3;

  return 0;
}
1390 
/// Exercises the integer-set C API: the canonical empty set, a set built from
/// explicit constraints, and dim/symbol replacement. Dumps the sets for
/// FileCheck and returns 0 on success or a non-zero code identifying the
/// first failing check.
int printIntegerSet(MlirContext ctx) {
  MlirIntegerSet emptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);

  // CHECK-LABEL: @printIntegerSet
  // End the label line like the other sections do, so the dump below starts
  // on its own line.
  fprintf(stderr, "@printIntegerSet\n");

  // CHECK: (d0, d1)[s0] : (1 == 0)
  mlirIntegerSetDump(emptySet);

  if (!mlirIntegerSetIsCanonicalEmpty(emptySet))
    return 1;

  MlirIntegerSet anotherEmptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
  if (!mlirIntegerSetEqual(emptySet, anotherEmptySet))
    return 2;

  // Construct a set constrained by:
  //   d0 - s0 == 0,
  //   d1 - 42 >= 0.
  MlirAffineExpr negOne = mlirAffineConstantExprGet(ctx, -1);
  MlirAffineExpr negFortyTwo = mlirAffineConstantExprGet(ctx, -42);
  MlirAffineExpr d0 = mlirAffineDimExprGet(ctx, 0);
  MlirAffineExpr d1 = mlirAffineDimExprGet(ctx, 1);
  MlirAffineExpr s0 = mlirAffineSymbolExprGet(ctx, 0);
  MlirAffineExpr negS0 = mlirAffineMulExprGet(negOne, s0);
  MlirAffineExpr d0minusS0 = mlirAffineAddExprGet(d0, negS0);
  MlirAffineExpr d1minus42 = mlirAffineAddExprGet(d1, negFortyTwo);
  MlirAffineExpr constraints[] = {d0minusS0, d1minus42};
  // One flag per constraint: true marks an equality (== 0), false an
  // inequality (>= 0).
  bool flags[] = {true, false};

  MlirIntegerSet set = mlirIntegerSetGet(ctx, 2, 1, 2, constraints, flags);
  // CHECK: (d0, d1)[s0] : (
  // CHECK-DAG: d0 - s0 == 0
  // CHECK-DAG: d1 - 42 >= 0
  mlirIntegerSetDump(set);

  // Transform d1 into the new symbol s1, keeping d0 and s0 as they are. The
  // result has 1 dim and 2 symbols.
  MlirAffineExpr s1 = mlirAffineSymbolExprGet(ctx, 1);
  MlirAffineExpr repl[] = {d0, s1};
  MlirIntegerSet replaced = mlirIntegerSetReplaceGet(set, repl, &s0, 1, 2);
  // CHECK: (d0)[s0, s1] : (
  // CHECK-DAG: d0 - s0 == 0
  // CHECK-DAG: s1 - 42 >= 0
  mlirIntegerSetDump(replaced);

  if (mlirIntegerSetGetNumDims(set) != 2)
    return 3;
  if (mlirIntegerSetGetNumDims(replaced) != 1)
    return 4;

  if (mlirIntegerSetGetNumSymbols(set) != 1)
    return 5;
  if (mlirIntegerSetGetNumSymbols(replaced) != 2)
    return 6;

  if (mlirIntegerSetGetNumInputs(set) != 3)
    return 7;

  if (mlirIntegerSetGetNumConstraints(set) != 2)
    return 8;

  if (mlirIntegerSetGetNumEqualities(set) != 1)
    return 9;

  if (mlirIntegerSetGetNumInequalities(set) != 1)
    return 10;

  // Do not assume the constraints come back in construction order. Instead,
  // pick the expected expression from whether each one is an equality.
  MlirAffineExpr cstr1 = mlirIntegerSetGetConstraint(set, 0);
  MlirAffineExpr cstr2 = mlirIntegerSetGetConstraint(set, 1);
  bool isEq1 = mlirIntegerSetIsConstraintEq(set, 0);
  bool isEq2 = mlirIntegerSetIsConstraintEq(set, 1);
  if (!mlirAffineExprEqual(cstr1, isEq1 ? d0minusS0 : d1minus42))
    return 11;
  if (!mlirAffineExprEqual(cstr2, isEq2 ? d0minusS0 : d1minus42))
    return 12;

  return 0;
}
1469 
registerOnlyStd()1470 int registerOnlyStd() {
1471   MlirContext ctx = mlirContextCreate();
1472   // The built-in dialect is always loaded.
1473   if (mlirContextGetNumLoadedDialects(ctx) != 1)
1474     return 1;
1475 
1476   MlirDialectHandle stdHandle = mlirGetDialectHandle__std__();
1477 
1478   MlirDialect std = mlirContextGetOrLoadDialect(
1479       ctx, mlirDialectHandleGetNamespace(stdHandle));
1480   if (!mlirDialectIsNull(std))
1481     return 2;
1482 
1483   mlirDialectHandleRegisterDialect(stdHandle, ctx);
1484 
1485   std = mlirContextGetOrLoadDialect(ctx,
1486                                     mlirDialectHandleGetNamespace(stdHandle));
1487   if (mlirDialectIsNull(std))
1488     return 3;
1489 
1490   MlirDialect alsoStd = mlirDialectHandleLoadDialect(stdHandle, ctx);
1491   if (!mlirDialectEqual(std, alsoStd))
1492     return 4;
1493 
1494   MlirStringRef stdNs = mlirDialectGetNamespace(std);
1495   MlirStringRef alsoStdNs = mlirDialectHandleGetNamespace(stdHandle);
1496   if (stdNs.length != alsoStdNs.length ||
1497       strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
1498     return 5;
1499 
1500   fprintf(stderr, "@registration\n");
1501   // CHECK-LABEL: @registration
1502 
1503   // CHECK: std.cond_br is_registered: 1
1504   fprintf(stderr, "std.cond_br is_registered: %d\n",
1505           mlirContextIsRegisteredOperation(
1506               ctx, mlirStringRefCreateFromCString("std.cond_br")));
1507 
1508   // CHECK: std.not_existing_op is_registered: 0
1509   fprintf(stderr, "std.not_existing_op is_registered: %d\n",
1510           mlirContextIsRegisteredOperation(
1511               ctx, mlirStringRefCreateFromCString("std.not_existing_op")));
1512 
1513   // CHECK: not_existing_dialect.not_existing_op is_registered: 0
1514   fprintf(stderr, "not_existing_dialect.not_existing_op is_registered: %d\n",
1515           mlirContextIsRegisteredOperation(
1516               ctx, mlirStringRefCreateFromCString(
1517                        "not_existing_dialect.not_existing_op")));
1518 
1519   return 0;
1520 }
1521 
1522 /// Tests backreference APIs
testBackreferences()1523 static int testBackreferences() {
1524   fprintf(stderr, "@test_backreferences\n");
1525 
1526   MlirContext ctx = mlirContextCreate();
1527   mlirContextSetAllowUnregisteredDialects(ctx, true);
1528   MlirLocation loc = mlirLocationUnknownGet(ctx);
1529 
1530   MlirOperationState opState =
1531       mlirOperationStateGet(mlirStringRefCreateFromCString("invalid.op"), loc);
1532   MlirRegion region = mlirRegionCreate();
1533   MlirBlock block = mlirBlockCreate(0, NULL);
1534   mlirRegionAppendOwnedBlock(region, block);
1535   mlirOperationStateAddOwnedRegions(&opState, 1, &region);
1536   MlirOperation op = mlirOperationCreate(&opState);
1537   MlirIdentifier ident =
1538       mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("identifier"));
1539 
1540   if (!mlirContextEqual(ctx, mlirOperationGetContext(op))) {
1541     fprintf(stderr, "ERROR: Getting context from operation failed\n");
1542     return 1;
1543   }
1544   if (!mlirOperationEqual(op, mlirBlockGetParentOperation(block))) {
1545     fprintf(stderr, "ERROR: Getting parent operation from block failed\n");
1546     return 2;
1547   }
1548   if (!mlirContextEqual(ctx, mlirIdentifierGetContext(ident))) {
1549     fprintf(stderr, "ERROR: Getting context from identifier failed\n");
1550     return 3;
1551   }
1552 
1553   mlirOperationDestroy(op);
1554   mlirContextDestroy(ctx);
1555 
1556   // CHECK-LABEL: @test_backreferences
1557   return 0;
1558 }
1559 
1560 /// Tests operand APIs.
testOperands()1561 int testOperands() {
1562   fprintf(stderr, "@testOperands\n");
1563   // CHECK-LABEL: @testOperands
1564 
1565   MlirContext ctx = mlirContextCreate();
1566   mlirRegisterAllDialects(ctx);
1567   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test"));
1568   MlirLocation loc = mlirLocationUnknownGet(ctx);
1569   MlirType indexType = mlirIndexTypeGet(ctx);
1570 
1571   // Create some constants to use as operands.
1572   MlirAttribute indexZeroLiteral =
1573       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
1574   MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
1575       mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
1576       indexZeroLiteral);
1577   MlirOperationState constZeroState = mlirOperationStateGet(
1578       mlirStringRefCreateFromCString("std.constant"), loc);
1579   mlirOperationStateAddResults(&constZeroState, 1, &indexType);
1580   mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
1581   MlirOperation constZero = mlirOperationCreate(&constZeroState);
1582   MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
1583 
1584   MlirAttribute indexOneLiteral =
1585       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
1586   MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
1587       mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
1588       indexOneLiteral);
1589   MlirOperationState constOneState = mlirOperationStateGet(
1590       mlirStringRefCreateFromCString("std.constant"), loc);
1591   mlirOperationStateAddResults(&constOneState, 1, &indexType);
1592   mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
1593   MlirOperation constOne = mlirOperationCreate(&constOneState);
1594   MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
1595 
1596   // Create the operation under test.
1597   mlirContextSetAllowUnregisteredDialects(ctx, true);
1598   MlirOperationState opState =
1599       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
1600   MlirValue initialOperands[] = {constZeroValue};
1601   mlirOperationStateAddOperands(&opState, 1, initialOperands);
1602   MlirOperation op = mlirOperationCreate(&opState);
1603 
1604   // Test operand APIs.
1605   intptr_t numOperands = mlirOperationGetNumOperands(op);
1606   fprintf(stderr, "Num Operands: %" PRIdPTR "\n", numOperands);
1607   // CHECK: Num Operands: 1
1608 
1609   MlirValue opOperand = mlirOperationGetOperand(op, 0);
1610   fprintf(stderr, "Original operand: ");
1611   mlirValuePrint(opOperand, printToStderr, NULL);
1612   // CHECK: Original operand: {{.+}} constant 0 : index
1613 
1614   mlirOperationSetOperand(op, 0, constOneValue);
1615   opOperand = mlirOperationGetOperand(op, 0);
1616   fprintf(stderr, "Updated operand: ");
1617   mlirValuePrint(opOperand, printToStderr, NULL);
1618   // CHECK: Updated operand: {{.+}} constant 1 : index
1619 
1620   mlirOperationDestroy(op);
1621   mlirOperationDestroy(constZero);
1622   mlirOperationDestroy(constOne);
1623   mlirContextDestroy(ctx);
1624 
1625   return 0;
1626 }
1627 
1628 /// Tests clone APIs.
testClone()1629 int testClone() {
1630   fprintf(stderr, "@testClone\n");
1631   // CHECK-LABEL: @testClone
1632 
1633   MlirContext ctx = mlirContextCreate();
1634   mlirRegisterAllDialects(ctx);
1635   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("std"));
1636   MlirLocation loc = mlirLocationUnknownGet(ctx);
1637   MlirType indexType = mlirIndexTypeGet(ctx);
1638   MlirStringRef valueStringRef =  mlirStringRefCreateFromCString("value");
1639 
1640   MlirAttribute indexZeroLiteral =
1641       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
1642   MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
1643   MlirOperationState constZeroState = mlirOperationStateGet(
1644       mlirStringRefCreateFromCString("std.constant"), loc);
1645   mlirOperationStateAddResults(&constZeroState, 1, &indexType);
1646   mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
1647   MlirOperation constZero = mlirOperationCreate(&constZeroState);
1648 
1649   MlirAttribute indexOneLiteral =
1650       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
1651   MlirOperation constOne = mlirOperationClone(constZero);
1652   mlirOperationSetAttributeByName(constOne, valueStringRef, indexOneLiteral);
1653 
1654   mlirOperationPrint(constZero, printToStderr, NULL);
1655   mlirOperationPrint(constOne, printToStderr, NULL);
1656   // CHECK: constant 0 : index
1657   // CHECK: constant 1 : index
1658 
1659   return 0;
1660 }
1661 
1662 // Wraps a diagnostic into additional text we can match against.
errorHandler(MlirDiagnostic diagnostic,void * userData)1663 MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
1664   fprintf(stderr, "processing diagnostic (userData: %" PRIdPTR ") <<\n",
1665           (intptr_t)userData);
1666   mlirDiagnosticPrint(diagnostic, printToStderr, NULL);
1667   fprintf(stderr, "\n");
1668   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
1669   mlirLocationPrint(loc, printToStderr, NULL);
1670   assert(mlirDiagnosticGetNumNotes(diagnostic) == 0);
1671   fprintf(stderr, "\n>> end of diagnostic (userData: %" PRIdPTR ")\n",
1672           (intptr_t)userData);
1673   return mlirLogicalResultSuccess();
1674 }
1675 
// User-data cleanup callback paired with errorHandler in testDiagnostics.
// The user data there is an integer tag cast to a pointer, so nothing is
// freed. The call is only logged, so the output shows when it happens.
static void deleteUserData(void *userData) {
  const intptr_t tag = (intptr_t)userData;
  fprintf(stderr, "deleting user data (userData: %" PRIdPTR ")\n", tag);
}
1681 
testDiagnostics()1682 void testDiagnostics() {
1683   MlirContext ctx = mlirContextCreate();
1684   MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
1685       ctx, errorHandler, (void *)42, deleteUserData);
1686   fprintf(stderr, "@test_diagnostics\n");
1687   MlirLocation unknownLoc = mlirLocationUnknownGet(ctx);
1688   mlirEmitError(unknownLoc, "test diagnostics");
1689   MlirLocation fileLineColLoc = mlirLocationFileLineColGet(
1690       ctx, mlirStringRefCreateFromCString("file.c"), 1, 2);
1691   mlirEmitError(fileLineColLoc, "test diagnostics");
1692   MlirLocation callSiteLoc = mlirLocationCallSiteGet(
1693       mlirLocationFileLineColGet(
1694           ctx, mlirStringRefCreateFromCString("other-file.c"), 2, 3),
1695       fileLineColLoc);
1696   mlirEmitError(callSiteLoc, "test diagnostics");
1697   mlirContextDetachDiagnosticHandler(ctx, id);
1698   mlirEmitError(unknownLoc, "more test diagnostics");
1699   // CHECK-LABEL: @test_diagnostics
1700   // CHECK: processing diagnostic (userData: 42) <<
1701   // CHECK:   test diagnostics
1702   // CHECK:   loc(unknown)
1703   // CHECK: >> end of diagnostic (userData: 42)
1704   // CHECK: processing diagnostic (userData: 42) <<
1705   // CHECK:   test diagnostics
1706   // CHECK:   loc("file.c":1:2)
1707   // CHECK: >> end of diagnostic (userData: 42)
1708   // CHECK: processing diagnostic (userData: 42) <<
1709   // CHECK:   test diagnostics
1710   // CHECK:   loc(callsite("other-file.c":2:3 at "file.c":1:2))
1711   // CHECK: >> end of diagnostic (userData: 42)
1712   // CHECK: deleting user data (userData: 42)
1713   // CHECK-NOT: processing diagnostic
1714   // CHECK:     more test diagnostics
1715 }
1716 
main()1717 int main() {
1718   MlirContext ctx = mlirContextCreate();
1719   mlirRegisterAllDialects(ctx);
1720   if (constructAndTraverseIr(ctx))
1721     return 1;
1722   buildWithInsertionsAndPrint(ctx);
1723   if (createOperationWithTypeInference(ctx))
1724     return 2;
1725 
1726   if (printBuiltinTypes(ctx))
1727     return 3;
1728   if (printBuiltinAttributes(ctx))
1729     return 4;
1730   if (printAffineMap(ctx))
1731     return 5;
1732   if (printAffineExpr(ctx))
1733     return 6;
1734   if (affineMapFromExprs(ctx))
1735     return 7;
1736   if (printIntegerSet(ctx))
1737     return 8;
1738   if (registerOnlyStd())
1739     return 9;
1740   if (testBackreferences())
1741     return 10;
1742   if (testOperands())
1743     return 11;
1744   if (testClone())
1745     return 12;
1746 
1747   mlirContextDestroy(ctx);
1748 
1749   testDiagnostics();
1750   return 0;
1751 }
1752