1 //===-- runtime/matmul.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // Implements all forms of MATMUL (Fortran 2018 16.9.124)
10 //
11 // There are two main entry points; one establishes a descriptor for the
12 // result and allocates it, and the other expects a result descriptor that
13 // points to existing storage.
14 //
15 // This implementation must handle all combinations of numeric types and
16 // kinds (100 - 165 cases depending on the target), plus all combinations
17 // of logical kinds (16).  A single template undergoes many instantiations
18 // to cover all of the valid possibilities.
19 //
20 // Places where BLAS routines could be called are marked as TODO items.
21 
22 #include "matmul.h"
23 #include "cpp-type.h"
24 #include "descriptor.h"
25 #include "terminator.h"
26 #include "tools.h"
27 
28 namespace Fortran::runtime {
29 
30 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
31 class Accumulator {
32 public:
33   // Accumulate floating-point results in (at least) double precision
34   using Result = CppTypeFor<RCAT,
35       RCAT == TypeCategory::Real || RCAT == TypeCategory::Complex
36           ? std::max(RKIND, static_cast<int>(sizeof(double)))
37           : RKIND>;
Accumulator(const Descriptor & x,const Descriptor & y)38   Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
Accumulate(const SubscriptValue xAt[],const SubscriptValue yAt[])39   void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) {
40     if constexpr (RCAT == TypeCategory::Logical) {
41       sum_ = sum_ ||
42           (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt));
43     } else {
44       sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) *
45           static_cast<Result>(*y_.Element<YT>(yAt));
46     }
47   }
GetResult() const48   Result GetResult() const { return sum_; }
49 
50 private:
51   const Descriptor &x_, &y_;
52   Result sum_{};
53 };
54 
55 // Implements an instance of MATMUL for given argument types.
56 template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
57     typename YT>
DoMatmul(std::conditional_t<IS_ALLOCATING,Descriptor,const Descriptor> & result,const Descriptor & x,const Descriptor & y,Terminator & terminator)58 static inline void DoMatmul(
59     std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
60     const Descriptor &x, const Descriptor &y, Terminator &terminator) {
61   int xRank{x.rank()};
62   int yRank{y.rank()};
63   int resRank{xRank + yRank - 2};
64   if (xRank * yRank != 2 * resRank) {
65     terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
66   }
67   SubscriptValue extent[2]{
68       xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(),
69       resRank == 2 ? y.GetDimension(1).Extent() : 0};
70   if constexpr (IS_ALLOCATING) {
71     result.Establish(
72         RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
73     for (int j{0}; j < resRank; ++j) {
74       result.GetDimension(j).SetBounds(1, extent[j]);
75     }
76     if (int stat{result.Allocate()}) {
77       terminator.Crash(
78           "MATMUL: could not allocate memory for result; STAT=%d", stat);
79     }
80   } else {
81     RUNTIME_CHECK(terminator, resRank == result.rank());
82     RUNTIME_CHECK(terminator, result.type() == (TypeCode{RCAT, RKIND}));
83     RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
84     RUNTIME_CHECK(terminator,
85         resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
86   }
87   using WriteResult =
88       CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
89           RKIND>;
90   SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
91   if (n != y.GetDimension(0).Extent()) {
92     terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)",
93         static_cast<std::intmax_t>(n),
94         static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
95   }
96   SubscriptValue xAt[2], yAt[2], resAt[2];
97   x.GetLowerBounds(xAt);
98   y.GetLowerBounds(yAt);
99   result.GetLowerBounds(resAt);
100   if (resRank == 2) { // M*M -> M
101     if constexpr (std::is_same_v<XT, YT>) {
102       if constexpr (std::is_same_v<XT, float>) {
103         // TODO: call BLAS-3 SGEMM
104       } else if constexpr (std::is_same_v<XT, double>) {
105         // TODO: call BLAS-3 DGEMM
106       } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
107         // TODO: call BLAS-3 CGEMM
108       } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
109         // TODO: call BLAS-3 ZGEMM
110       }
111     }
112     SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
113     for (SubscriptValue i{0}; i < extent[0]; ++i) {
114       for (SubscriptValue j{0}; j < extent[1]; ++j) {
115         Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
116         yAt[1] = y1 + j;
117         for (SubscriptValue k{0}; k < n; ++k) {
118           xAt[1] = x1 + k;
119           yAt[0] = y0 + k;
120           accumulator.Accumulate(xAt, yAt);
121         }
122         resAt[1] = res1 + j;
123         *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
124       }
125       ++resAt[0];
126       ++xAt[0];
127     }
128   } else {
129     if constexpr (std::is_same_v<XT, YT>) {
130       if constexpr (std::is_same_v<XT, float>) {
131         // TODO: call BLAS-2 SGEMV
132       } else if constexpr (std::is_same_v<XT, double>) {
133         // TODO: call BLAS-2 DGEMV
134       } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
135         // TODO: call BLAS-2 CGEMV
136       } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
137         // TODO: call BLAS-2 ZGEMV
138       }
139     }
140     if (xRank == 2) { // M*V -> V
141       SubscriptValue x1{xAt[1]}, y0{yAt[0]};
142       for (SubscriptValue j{0}; j < extent[0]; ++j) {
143         Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
144         for (SubscriptValue k{0}; k < n; ++k) {
145           xAt[1] = x1 + k;
146           yAt[0] = y0 + k;
147           accumulator.Accumulate(xAt, yAt);
148         }
149         *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
150         ++resAt[0];
151         ++xAt[0];
152       }
153     } else { // V*M -> V
154       SubscriptValue x0{xAt[0]}, y0{yAt[0]};
155       for (SubscriptValue j{0}; j < extent[0]; ++j) {
156         Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
157         for (SubscriptValue k{0}; k < n; ++k) {
158           xAt[0] = x0 + k;
159           yAt[0] = y0 + k;
160           accumulator.Accumulate(xAt, yAt);
161         }
162         *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
163         ++resAt[0];
164         ++yAt[1];
165       }
166     }
167   }
168 }
169 
170 // Maps the dynamic type information from the arguments' descriptors
171 // to the right instantiation of DoMatmul() for valid combinations of
172 // types.
173 template <bool IS_ALLOCATING> struct Matmul {
174   using ResultDescriptor =
175       std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
176   template <TypeCategory XCAT, int XKIND> struct MM1 {
177     template <TypeCategory YCAT, int YKIND> struct MM2 {
operator ()Fortran::runtime::Matmul::MM1::MM2178       void operator()(ResultDescriptor &result, const Descriptor &x,
179           const Descriptor &y, Terminator &terminator) const {
180         if constexpr (constexpr auto resultType{
181                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
182           if constexpr (common::IsNumericTypeCategory(resultType->first) ||
183               resultType->first == TypeCategory::Logical) {
184             return DoMatmul<IS_ALLOCATING, resultType->first,
185                 resultType->second, CppTypeFor<XCAT, XKIND>,
186                 CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
187           }
188         }
189         terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
190             static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
191       }
192     };
operator ()Fortran::runtime::Matmul::MM1193     void operator()(ResultDescriptor &result, const Descriptor &x,
194         const Descriptor &y, Terminator &terminator, TypeCategory yCat,
195         int yKind) const {
196       ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
197     }
198   };
operator ()Fortran::runtime::Matmul199   void operator()(ResultDescriptor &result, const Descriptor &x,
200       const Descriptor &y, const char *sourceFile, int line) const {
201     Terminator terminator{sourceFile, line};
202     auto xCatKind{x.type().GetCategoryAndKind()};
203     auto yCatKind{y.type().GetCategoryAndKind()};
204     RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
205     ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
206         x, y, terminator, yCatKind->first, yCatKind->second);
207   }
208 };
209 
210 extern "C" {
RTNAME(Matmul)211 void RTNAME(Matmul)(Descriptor &result, const Descriptor &x,
212     const Descriptor &y, const char *sourceFile, int line) {
213   Matmul<true>{}(result, x, y, sourceFile, line);
214 }
RTNAME(MatmulDirect)215 void RTNAME(MatmulDirect)(const Descriptor &result, const Descriptor &x,
216     const Descriptor &y, const char *sourceFile, int line) {
217   Matmul<false>{}(result, x, y, sourceFile, line);
218 }
219 } // extern "C"
220 } // namespace Fortran::runtime
221