1 //===-- flang/unittests/RuntimeGTest/Matmul.cpp---- -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "../../runtime/matmul.h"
10 #include "gtest/gtest.h"
11 #include "tools.h"
12 #include "../../runtime/allocatable.h"
13 #include "../../runtime/cpp-type.h"
14 #include "../../runtime/descriptor.h"
15 #include "../../runtime/type-code.h"
16 
17 using namespace Fortran::runtime;
18 using Fortran::common::TypeCategory;
19 
TEST(Matmul,Basic)20 TEST(Matmul, Basic) {
21   // X 0 2 4   Y 6  9   V -1 -2
22   //   1 3 5     7 10
23   //             8 11
24   auto x{MakeArray<TypeCategory::Integer, 4>(
25       std::vector<int>{2, 3}, std::vector<std::int32_t>{0, 1, 2, 3, 4, 5})};
26   auto y{MakeArray<TypeCategory::Integer, 2>(
27       std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})};
28   auto v{MakeArray<TypeCategory::Integer, 8>(
29       std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};
30   StaticDescriptor<2, true> statDesc;
31   Descriptor &result{statDesc.descriptor()};
32 
33   RTNAME(Matmul)(result, *x, *y, __FILE__, __LINE__);
34   ASSERT_EQ(result.rank(), 2);
35   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
36   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
37   EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
38   EXPECT_EQ(result.GetDimension(1).Extent(), 2);
39   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
40   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
41   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
42   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
43   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
44 
45   std::memset(
46       result.raw().base_addr, 0, result.Elements() * result.ElementBytes());
47   result.GetDimension(0).SetLowerBound(0);
48   result.GetDimension(1).SetLowerBound(2);
49   RTNAME(MatmulDirect)(result, *x, *y, __FILE__, __LINE__);
50   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
51   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
52   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
53   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
54   result.Destroy();
55 
56   RTNAME(Matmul)(result, *v, *x, __FILE__, __LINE__);
57   ASSERT_EQ(result.rank(), 1);
58   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
59   EXPECT_EQ(result.GetDimension(0).Extent(), 3);
60   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
61   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
62   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
63   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
64   result.Destroy();
65 
66   RTNAME(Matmul)(result, *y, *v, __FILE__, __LINE__);
67   ASSERT_EQ(result.rank(), 1);
68   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
69   EXPECT_EQ(result.GetDimension(0).Extent(), 3);
70   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
71   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
72   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
73   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
74   result.Destroy();
75 
76   // X F F T  Y F T
77   //   F T T    F T
78   //            F F
79   auto xLog{MakeArray<TypeCategory::Logical, 1>(std::vector<int>{2, 3},
80       std::vector<std::uint8_t>{false, false, false, true, true, false})};
81   auto yLog{MakeArray<TypeCategory::Logical, 2>(std::vector<int>{3, 2},
82       std::vector<std::uint16_t>{false, false, false, true, true, false})};
83   RTNAME(Matmul)(result, *xLog, *yLog, __FILE__, __LINE__);
84   ASSERT_EQ(result.rank(), 2);
85   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
86   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
87   EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
88   EXPECT_EQ(result.GetDimension(1).Extent(), 2);
89   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
90   EXPECT_FALSE(
91       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
92   EXPECT_FALSE(
93       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
94   EXPECT_FALSE(
95       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
96   EXPECT_TRUE(
97       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
98 }
99