1 //===-- runtime/matmul.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 // Implements all forms of MATMUL (Fortran 2018 16.9.124)
10 //
11 // There are two main entry points; one establishes a descriptor for the
12 // result and allocates it, and the other expects a result descriptor that
13 // points to existing storage.
14 //
15 // This implementation must handle all combinations of numeric types and
16 // kinds (100 - 165 cases depending on the target), plus all combinations
17 // of logical kinds (16). A single template undergoes many instantiations
18 // to cover all of the valid possibilities.
19 //
20 // Places where BLAS routines could be called are marked as TODO items.
21
22 #include "matmul.h"
23 #include "cpp-type.h"
24 #include "descriptor.h"
25 #include "terminator.h"
26 #include "tools.h"
27
28 namespace Fortran::runtime {
29
30 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
31 class Accumulator {
32 public:
33 // Accumulate floating-point results in (at least) double precision
34 using Result = CppTypeFor<RCAT,
35 RCAT == TypeCategory::Real || RCAT == TypeCategory::Complex
36 ? std::max(RKIND, static_cast<int>(sizeof(double)))
37 : RKIND>;
Accumulator(const Descriptor & x,const Descriptor & y)38 Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
Accumulate(const SubscriptValue xAt[],const SubscriptValue yAt[])39 void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) {
40 if constexpr (RCAT == TypeCategory::Logical) {
41 sum_ = sum_ ||
42 (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt));
43 } else {
44 sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) *
45 static_cast<Result>(*y_.Element<YT>(yAt));
46 }
47 }
GetResult() const48 Result GetResult() const { return sum_; }
49
50 private:
51 const Descriptor &x_, &y_;
52 Result sum_{};
53 };
54
55 // Implements an instance of MATMUL for given argument types.
56 template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
57 typename YT>
DoMatmul(std::conditional_t<IS_ALLOCATING,Descriptor,const Descriptor> & result,const Descriptor & x,const Descriptor & y,Terminator & terminator)58 static inline void DoMatmul(
59 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
60 const Descriptor &x, const Descriptor &y, Terminator &terminator) {
61 int xRank{x.rank()};
62 int yRank{y.rank()};
63 int resRank{xRank + yRank - 2};
64 if (xRank * yRank != 2 * resRank) {
65 terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
66 }
67 SubscriptValue extent[2]{
68 xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(),
69 resRank == 2 ? y.GetDimension(1).Extent() : 0};
70 if constexpr (IS_ALLOCATING) {
71 result.Establish(
72 RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
73 for (int j{0}; j < resRank; ++j) {
74 result.GetDimension(j).SetBounds(1, extent[j]);
75 }
76 if (int stat{result.Allocate()}) {
77 terminator.Crash(
78 "MATMUL: could not allocate memory for result; STAT=%d", stat);
79 }
80 } else {
81 RUNTIME_CHECK(terminator, resRank == result.rank());
82 RUNTIME_CHECK(terminator, result.type() == (TypeCode{RCAT, RKIND}));
83 RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
84 RUNTIME_CHECK(terminator,
85 resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
86 }
87 using WriteResult =
88 CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
89 RKIND>;
90 SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
91 if (n != y.GetDimension(0).Extent()) {
92 terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)",
93 static_cast<std::intmax_t>(n),
94 static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
95 }
96 SubscriptValue xAt[2], yAt[2], resAt[2];
97 x.GetLowerBounds(xAt);
98 y.GetLowerBounds(yAt);
99 result.GetLowerBounds(resAt);
100 if (resRank == 2) { // M*M -> M
101 if constexpr (std::is_same_v<XT, YT>) {
102 if constexpr (std::is_same_v<XT, float>) {
103 // TODO: call BLAS-3 SGEMM
104 } else if constexpr (std::is_same_v<XT, double>) {
105 // TODO: call BLAS-3 DGEMM
106 } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
107 // TODO: call BLAS-3 CGEMM
108 } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
109 // TODO: call BLAS-3 ZGEMM
110 }
111 }
112 SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
113 for (SubscriptValue i{0}; i < extent[0]; ++i) {
114 for (SubscriptValue j{0}; j < extent[1]; ++j) {
115 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
116 yAt[1] = y1 + j;
117 for (SubscriptValue k{0}; k < n; ++k) {
118 xAt[1] = x1 + k;
119 yAt[0] = y0 + k;
120 accumulator.Accumulate(xAt, yAt);
121 }
122 resAt[1] = res1 + j;
123 *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
124 }
125 ++resAt[0];
126 ++xAt[0];
127 }
128 } else {
129 if constexpr (std::is_same_v<XT, YT>) {
130 if constexpr (std::is_same_v<XT, float>) {
131 // TODO: call BLAS-2 SGEMV
132 } else if constexpr (std::is_same_v<XT, double>) {
133 // TODO: call BLAS-2 DGEMV
134 } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
135 // TODO: call BLAS-2 CGEMV
136 } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
137 // TODO: call BLAS-2 ZGEMV
138 }
139 }
140 if (xRank == 2) { // M*V -> V
141 SubscriptValue x1{xAt[1]}, y0{yAt[0]};
142 for (SubscriptValue j{0}; j < extent[0]; ++j) {
143 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
144 for (SubscriptValue k{0}; k < n; ++k) {
145 xAt[1] = x1 + k;
146 yAt[0] = y0 + k;
147 accumulator.Accumulate(xAt, yAt);
148 }
149 *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
150 ++resAt[0];
151 ++xAt[0];
152 }
153 } else { // V*M -> V
154 SubscriptValue x0{xAt[0]}, y0{yAt[0]};
155 for (SubscriptValue j{0}; j < extent[0]; ++j) {
156 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
157 for (SubscriptValue k{0}; k < n; ++k) {
158 xAt[0] = x0 + k;
159 yAt[0] = y0 + k;
160 accumulator.Accumulate(xAt, yAt);
161 }
162 *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
163 ++resAt[0];
164 ++yAt[1];
165 }
166 }
167 }
168 }
169
170 // Maps the dynamic type information from the arguments' descriptors
171 // to the right instantiation of DoMatmul() for valid combinations of
172 // types.
173 template <bool IS_ALLOCATING> struct Matmul {
174 using ResultDescriptor =
175 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
176 template <TypeCategory XCAT, int XKIND> struct MM1 {
177 template <TypeCategory YCAT, int YKIND> struct MM2 {
operator ()Fortran::runtime::Matmul::MM1::MM2178 void operator()(ResultDescriptor &result, const Descriptor &x,
179 const Descriptor &y, Terminator &terminator) const {
180 if constexpr (constexpr auto resultType{
181 GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
182 if constexpr (common::IsNumericTypeCategory(resultType->first) ||
183 resultType->first == TypeCategory::Logical) {
184 return DoMatmul<IS_ALLOCATING, resultType->first,
185 resultType->second, CppTypeFor<XCAT, XKIND>,
186 CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
187 }
188 }
189 terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
190 static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
191 }
192 };
operator ()Fortran::runtime::Matmul::MM1193 void operator()(ResultDescriptor &result, const Descriptor &x,
194 const Descriptor &y, Terminator &terminator, TypeCategory yCat,
195 int yKind) const {
196 ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
197 }
198 };
operator ()Fortran::runtime::Matmul199 void operator()(ResultDescriptor &result, const Descriptor &x,
200 const Descriptor &y, const char *sourceFile, int line) const {
201 Terminator terminator{sourceFile, line};
202 auto xCatKind{x.type().GetCategoryAndKind()};
203 auto yCatKind{y.type().GetCategoryAndKind()};
204 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
205 ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
206 x, y, terminator, yCatKind->first, yCatKind->second);
207 }
208 };
209
210 extern "C" {
RTNAME(Matmul)211 void RTNAME(Matmul)(Descriptor &result, const Descriptor &x,
212 const Descriptor &y, const char *sourceFile, int line) {
213 Matmul<true>{}(result, x, y, sourceFile, line);
214 }
RTNAME(MatmulDirect)215 void RTNAME(MatmulDirect)(const Descriptor &result, const Descriptor &x,
216 const Descriptor &y, const char *sourceFile, int line) {
217 Matmul<false>{}(result, x, y, sourceFile, line);
218 }
219 } // extern "C"
220 } // namespace Fortran::runtime
221