1 //===- InferTypeOpInterfaceTest.cpp - Unit Test for type interface --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/InferTypeOpInterface.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/ImplicitLocOpBuilder.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/Parser.h"
19 
20 #include <gtest/gtest.h>
21 
22 using namespace mlir;
23 
24 class ValueShapeRangeTest : public testing::Test {
25 protected:
SetUp()26   void SetUp() override {
27     const char *ir = R"MLIR(
28       func @map(%arg : tensor<1xi64>) {
29         %0 = constant dense<[10]> : tensor<1xi64>
30         %1 = addi %arg, %0 : tensor<1xi64>
31         return
32       }
33     )MLIR";
34 
35     registry.insert<StandardOpsDialect>();
36     ctx.appendDialectRegistry(registry);
37     module = parseSourceString(ir, &ctx);
38     mapFn = cast<FuncOp>(module->front());
39   }
40 
41   // Create ValueShapeRange on the addi operation.
addiRange()42   ValueShapeRange addiRange() {
43     auto &fnBody = mapFn.body();
44     return std::next(fnBody.front().begin())->getOperands();
45   }
46 
47   DialectRegistry registry;
48   MLIRContext ctx;
49   OwningModuleRef module;
50   FuncOp mapFn;
51 };
52 
TEST_F(ValueShapeRangeTest,ShapesFromValues)53 TEST_F(ValueShapeRangeTest, ShapesFromValues) {
54   ValueShapeRange range = addiRange();
55 
56   EXPECT_FALSE(range.getValueAsShape(0));
57   ASSERT_TRUE(range.getValueAsShape(1));
58   EXPECT_TRUE(range.getValueAsShape(1).hasRank());
59   EXPECT_EQ(range.getValueAsShape(1).getRank(), 1);
60   EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10);
61   EXPECT_EQ(range.getShape(1).getRank(), 1);
62   EXPECT_EQ(range.getShape(1).getDimSize(0), 1);
63 }
64 
TEST_F(ValueShapeRangeTest,MapValuesToShapes)65 TEST_F(ValueShapeRangeTest, MapValuesToShapes) {
66   ValueShapeRange range = addiRange();
67   ShapedTypeComponents fixed(SmallVector<int64_t>{30});
68   auto mapping = [&](Value val) -> ShapeAdaptor {
69     if (val == mapFn.getArgument(0))
70       return &fixed;
71     return nullptr;
72   };
73   range.setValueToShapeMapping(mapping);
74 
75   ASSERT_TRUE(range.getValueAsShape(0));
76   EXPECT_TRUE(range.getValueAsShape(0).hasRank());
77   EXPECT_EQ(range.getValueAsShape(0).getRank(), 1);
78   EXPECT_EQ(range.getValueAsShape(0).getDimSize(0), 30);
79   ASSERT_TRUE(range.getValueAsShape(1));
80   EXPECT_TRUE(range.getValueAsShape(1).hasRank());
81   EXPECT_EQ(range.getValueAsShape(1).getRank(), 1);
82   EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10);
83 }
84 
TEST_F(ValueShapeRangeTest,SettingShapes)85 TEST_F(ValueShapeRangeTest, SettingShapes) {
86   ShapedTypeComponents shape(SmallVector<int64_t>{10, 20});
87   ValueShapeRange range = addiRange();
88   auto mapping = [&](Value val) -> ShapeAdaptor {
89     if (val == mapFn.getArgument(0))
90       return &shape;
91     return nullptr;
92   };
93   range.setOperandShapeMapping(mapping);
94 
95   ASSERT_TRUE(range.getShape(0));
96   EXPECT_EQ(range.getShape(0).getRank(), 2);
97   EXPECT_EQ(range.getShape(0).getDimSize(0), 10);
98   EXPECT_EQ(range.getShape(0).getDimSize(1), 20);
99   ASSERT_TRUE(range.getShape(1));
100   EXPECT_EQ(range.getShape(1).getRank(), 1);
101   EXPECT_EQ(range.getShape(1).getDimSize(0), 1);
102   EXPECT_FALSE(range.getShape(2));
103 }
104