//===-- flang/unittests/RuntimeGTest/Matmul.cpp -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include "../../runtime/matmul.h"
#include "gtest/gtest.h"
#include "tools.h"
#include "../../runtime/allocatable.h"
#include "../../runtime/cpp-type.h"
#include "../../runtime/descriptor.h"
#include "../../runtime/type-code.h"
#include <cstring>
16
17 using namespace Fortran::runtime;
18 using Fortran::common::TypeCategory;
19
// Exercises the MATMUL runtime on:
//  - integer matrix * matrix with mixed operand kinds (4 and 2 -> 4),
//  - the same product through MatmulDirect into an already-allocated result,
//  - integer vector * matrix and matrix * vector (kinds 8 with 4/2 -> 8),
//  - logical matrix * matrix with mixed kinds (1 and 2 -> 2).
TEST(Matmul, Basic) {
  // The data vectors are laid out in column-major (Fortran) element order,
  // so they describe these arrays:
  //                X 0 2 4  Y 6  9  V -1 -2
  //                  1 3 5    7 10
  //                           8 11
  auto x{MakeArray<TypeCategory::Integer, 4>(
      std::vector<int>{2, 3}, std::vector<std::int32_t>{0, 1, 2, 3, 4, 5})};
  auto y{MakeArray<TypeCategory::Integer, 2>(
      std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})};
  auto v{MakeArray<TypeCategory::Integer, 8>(
      std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};
  StaticDescriptor<2, true> statDesc;
  Descriptor &result{statDesc.descriptor()};

  // X (2x3) * Y (3x2) -> 2x2 result; expected values are listed in
  // column-major order: [[46, 64], [67, 94]].
  RTNAME(Matmul)(result, *x, *y, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 2);
  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);

  // Recompute the same product with MatmulDirect, reusing the storage that
  // Matmul allocated above. Zero it first so values left over from the
  // previous call cannot satisfy the checks, and give the result
  // non-default lower bounds, which must not change the stored elements.
  std::memset(
      result.raw().base_addr, 0, result.Elements() * result.ElementBytes());
  result.GetDimension(0).SetLowerBound(0);
  result.GetDimension(1).SetLowerBound(2);
  RTNAME(MatmulDirect)(result, *x, *y, __FILE__, __LINE__);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
  result.Destroy();

  // V (2) * X (2x3) -> rank-1 result of extent 3: [-2, -8, -14].
  RTNAME(Matmul)(result, *v, *x, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 1);
  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
  result.Destroy();

  // Y (3x2) * V (2) -> rank-1 result of extent 3: [-24, -27, -30].
  RTNAME(Matmul)(result, *y, *v, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 1);
  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
  result.Destroy();

  // Logical operands, again in column-major order:
  //                X F F T  Y F T
  //                  F T F    F T
  //                           F F
  // The logical product ORs together element-wise ANDs, so only
  // element (2,2) is true: X(2,2) .AND. Y(2,2).
  auto xLog{MakeArray<TypeCategory::Logical, 1>(std::vector<int>{2, 3},
      std::vector<std::uint8_t>{false, false, false, true, true, false})};
  auto yLog{MakeArray<TypeCategory::Logical, 2>(std::vector<int>{3, 2},
      std::vector<std::uint16_t>{false, false, false, true, true, false})};
  RTNAME(Matmul)(result, *xLog, *yLog, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 2);
  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
  EXPECT_FALSE(
      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
  EXPECT_FALSE(
      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
  EXPECT_FALSE(
      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
  EXPECT_TRUE(
      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
  // Release the logical result as every earlier section does; it was
  // previously leaked.
  result.Destroy();
}
99