1 //===-- KindMapping.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/Support/KindMapping.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "llvm/Support/CommandLine.h"
16 
/// Command-line options that let the user override the FIR intrinsic type
/// kind-value to LLVM type mappings and the default kind values. Note that
/// these are not mappings from kind values to any other MLIR dialect, only to
/// LLVM IR. The default values follow the f18 front-end kind mappings.

using Bitsize = fir::KindMapping::Bitsize;
using KindTy = fir::KindMapping::KindTy;
using LLVMTypeID = fir::KindMapping::LLVMTypeID;
using MatchResult = fir::KindMapping::MatchResult;

/// `-kind-mapping`: the kind map string used by the KindMapping constructors
/// that do not take an explicit map (it is handed to KindMapping::parse()).
/// The string is a comma-separated list of `<code><kind>:<bits>` entries for
/// codes 'a', 'i', 'l' and `<code><kind>:<keyword>` entries for codes 'r',
/// 'c', e.g. "i8:64,r10:X86_FP80". Initialized from
/// KindMapping::getDefaultMap().
static llvm::cl::opt<std::string>
    clKindMapping("kind-mapping",
                  llvm::cl::desc("kind mapping string to set kind precision"),
                  llvm::cl::value_desc("kind-mapping-string"),
                  llvm::cl::init(fir::KindMapping::getDefaultMap()));

/// `-default-kinds`: a string of `<code><kind>` pairs for the codes 'a', 'c',
/// 'd', 'i', 'l', and 'r' (e.g. "a1c4d8i4l4r4"). toDefaultKinds() falls back
/// to this option when it is given an empty string. Initialized from
/// KindMapping::getDefaultKinds().
static llvm::cl::opt<std::string>
    clDefaultKinds("default-kinds",
                   llvm::cl::desc("string to set default kind values"),
                   llvm::cl::value_desc("default-kind-string"),
                   llvm::cl::init(fir::KindMapping::getDefaultKinds()));
38 
// Spellings of the floating-point type keywords that parseTypeID() accepts in
// a kind map string and that mapToString() emits. They follow LLVM's names for
// the corresponding IR types.

static constexpr char kwHalf[] = "Half";
static constexpr char kwBFloat[] = "BFloat";
static constexpr char kwFloat[] = "Float";
static constexpr char kwDouble[] = "Double";
static constexpr char kwX86FP80[] = "X86_FP80";
static constexpr char kwFP128[] = "FP128";
static constexpr char kwPPCFP128[] = "PPC_FP128";
48 
49 /// Integral types default to the kind value being the size of the value in
50 /// bytes. The default is to scale from bytes to bits.
defaultScalingKind(KindTy kind)51 static Bitsize defaultScalingKind(KindTy kind) {
52   const unsigned bitsInByte = 8;
53   return kind * bitsInByte;
54 }
55 
56 /// Floating-point types default to the kind value being the size of the value
57 /// in bytes. The default is to translate kinds of 2, 3, 4, 8, 10, and 16 to a
58 /// valid llvm::Type::TypeID value. Otherwise, the default is FloatTyID.
defaultRealKind(KindTy kind)59 static LLVMTypeID defaultRealKind(KindTy kind) {
60   switch (kind) {
61   case 2:
62     return LLVMTypeID::HalfTyID;
63   case 3:
64     return LLVMTypeID::BFloatTyID;
65   case 4:
66     return LLVMTypeID::FloatTyID;
67   case 8:
68     return LLVMTypeID::DoubleTyID;
69   case 10:
70     return LLVMTypeID::X86_FP80TyID;
71   case 16:
72     return LLVMTypeID::FP128TyID;
73   default:
74     return LLVMTypeID::FloatTyID;
75   }
76 }
77 
78 // lookup the kind-value given the defaults, the mappings, and a KIND key
79 template <typename RT, char KEY>
doLookup(std::function<RT (KindTy)> def,const llvm::DenseMap<std::pair<char,KindTy>,RT> & map,KindTy kind)80 static RT doLookup(std::function<RT(KindTy)> def,
81                    const llvm::DenseMap<std::pair<char, KindTy>, RT> &map,
82                    KindTy kind) {
83   std::pair<char, KindTy> key{KEY, kind};
84   auto iter = map.find(key);
85   if (iter != map.end())
86     return iter->second;
87   return def(kind);
88 }
89 
90 // do a lookup for INTERGER, LOGICAL, or CHARACTER
91 template <char KEY, typename MAP>
getIntegerLikeBitsize(KindTy kind,const MAP & map)92 static Bitsize getIntegerLikeBitsize(KindTy kind, const MAP &map) {
93   return doLookup<Bitsize, KEY>(defaultScalingKind, map, kind);
94 }
95 
96 // do a lookup for REAL or COMPLEX
97 template <char KEY, typename MAP>
getFloatLikeTypeID(KindTy kind,const MAP & map)98 static LLVMTypeID getFloatLikeTypeID(KindTy kind, const MAP &map) {
99   return doLookup<LLVMTypeID, KEY>(defaultRealKind, map, kind);
100 }
101 
102 template <char KEY, typename MAP>
getFloatSemanticsOfKind(KindTy kind,const MAP & map)103 static const llvm::fltSemantics &getFloatSemanticsOfKind(KindTy kind,
104                                                          const MAP &map) {
105   switch (doLookup<LLVMTypeID, KEY>(defaultRealKind, map, kind)) {
106   case LLVMTypeID::HalfTyID:
107     return llvm::APFloat::IEEEhalf();
108   case LLVMTypeID::BFloatTyID:
109     return llvm::APFloat::BFloat();
110   case LLVMTypeID::FloatTyID:
111     return llvm::APFloat::IEEEsingle();
112   case LLVMTypeID::DoubleTyID:
113     return llvm::APFloat::IEEEdouble();
114   case LLVMTypeID::X86_FP80TyID:
115     return llvm::APFloat::x87DoubleExtended();
116   case LLVMTypeID::FP128TyID:
117     return llvm::APFloat::IEEEquad();
118   case LLVMTypeID::PPC_FP128TyID:
119     return llvm::APFloat::PPCDoubleDouble();
120   default:
121     llvm_unreachable("Invalid floating type");
122   }
123 }
124 
125 /// Parse an intrinsic type code. The codes are ('a', CHARACTER), ('c',
126 /// COMPLEX), ('i', INTEGER), ('l', LOGICAL), and ('r', REAL).
parseCode(char & code,const char * & ptr,const char * endPtr)127 static MatchResult parseCode(char &code, const char *&ptr, const char *endPtr) {
128   if (ptr >= endPtr)
129     return mlir::failure();
130   if (*ptr != 'a' && *ptr != 'c' && *ptr != 'i' && *ptr != 'l' && *ptr != 'r')
131     return mlir::failure();
132   code = *ptr++;
133   return mlir::success();
134 }
135 
136 /// Same as `parseCode` but adds the ('d', DOUBLE PRECISION) code.
parseDefCode(char & code,const char * & ptr,const char * endPtr)137 static MatchResult parseDefCode(char &code, const char *&ptr,
138                                 const char *endPtr) {
139   if (ptr >= endPtr)
140     return mlir::failure();
141   if (*ptr == 'd') {
142     code = *ptr++;
143     return mlir::success();
144   }
145   return parseCode(code, ptr, endPtr);
146 }
147 
148 template <char ch>
parseSingleChar(const char * & ptr,const char * endPtr)149 static MatchResult parseSingleChar(const char *&ptr, const char *endPtr) {
150   if (ptr >= endPtr || *ptr != ch)
151     return mlir::failure();
152   ++ptr;
153   return mlir::success();
154 }
155 
parseColon(const char * & ptr,const char * endPtr)156 static MatchResult parseColon(const char *&ptr, const char *endPtr) {
157   return parseSingleChar<':'>(ptr, endPtr);
158 }
159 
parseComma(const char * & ptr,const char * endPtr)160 static MatchResult parseComma(const char *&ptr, const char *endPtr) {
161   return parseSingleChar<','>(ptr, endPtr);
162 }
163 
164 /// Recognize and parse an unsigned integer.
parseInt(unsigned & result,const char * & ptr,const char * endPtr)165 static MatchResult parseInt(unsigned &result, const char *&ptr,
166                             const char *endPtr) {
167   const char *beg = ptr;
168   while (ptr < endPtr && *ptr >= '0' && *ptr <= '9')
169     ptr++;
170   if (beg == ptr)
171     return mlir::failure();
172   llvm::StringRef ref(beg, ptr - beg);
173   int temp;
174   if (ref.consumeInteger(10, temp))
175     return mlir::failure();
176   result = temp;
177   return mlir::success();
178 }
179 
matchString(const char * & ptr,const char * endPtr,llvm::StringRef literal)180 static mlir::LogicalResult matchString(const char *&ptr, const char *endPtr,
181                                        llvm::StringRef literal) {
182   llvm::StringRef s(ptr, endPtr - ptr);
183   if (s.startswith(literal)) {
184     ptr += literal.size();
185     return mlir::success();
186   }
187   return mlir::failure();
188 }
189 
190 /// Recognize and parse the various floating-point keywords. These follow the
191 /// LLVM naming convention.
parseTypeID(LLVMTypeID & result,const char * & ptr,const char * endPtr)192 static MatchResult parseTypeID(LLVMTypeID &result, const char *&ptr,
193                                const char *endPtr) {
194   if (mlir::succeeded(matchString(ptr, endPtr, kwHalf))) {
195     result = LLVMTypeID::HalfTyID;
196     return mlir::success();
197   }
198   if (mlir::succeeded(matchString(ptr, endPtr, kwBFloat))) {
199     result = LLVMTypeID::BFloatTyID;
200     return mlir::success();
201   }
202   if (mlir::succeeded(matchString(ptr, endPtr, kwFloat))) {
203     result = LLVMTypeID::FloatTyID;
204     return mlir::success();
205   }
206   if (mlir::succeeded(matchString(ptr, endPtr, kwDouble))) {
207     result = LLVMTypeID::DoubleTyID;
208     return mlir::success();
209   }
210   if (mlir::succeeded(matchString(ptr, endPtr, kwX86FP80))) {
211     result = LLVMTypeID::X86_FP80TyID;
212     return mlir::success();
213   }
214   if (mlir::succeeded(matchString(ptr, endPtr, kwFP128))) {
215     result = LLVMTypeID::FP128TyID;
216     return mlir::success();
217   }
218   if (mlir::succeeded(matchString(ptr, endPtr, kwPPCFP128))) {
219     result = LLVMTypeID::PPC_FP128TyID;
220     return mlir::success();
221   }
222   return mlir::failure();
223 }
224 
KindMapping(mlir::MLIRContext * context,llvm::StringRef map,llvm::ArrayRef<KindTy> defs)225 fir::KindMapping::KindMapping(mlir::MLIRContext *context, llvm::StringRef map,
226                               llvm::ArrayRef<KindTy> defs)
227     : context{context} {
228   if (mlir::failed(setDefaultKinds(defs)))
229     llvm::report_fatal_error("bad default kinds");
230   if (mlir::failed(parse(map)))
231     llvm::report_fatal_error("could not parse kind map");
232 }
233 
KindMapping(mlir::MLIRContext * context,llvm::ArrayRef<KindTy> defs)234 fir::KindMapping::KindMapping(mlir::MLIRContext *context,
235                               llvm::ArrayRef<KindTy> defs)
236     : KindMapping{context, clKindMapping, defs} {}
237 
KindMapping(mlir::MLIRContext * context)238 fir::KindMapping::KindMapping(mlir::MLIRContext *context)
239     : KindMapping{context, clKindMapping, clDefaultKinds} {}
240 
badMapString(const llvm::Twine & ptr)241 MatchResult fir::KindMapping::badMapString(const llvm::Twine &ptr) {
242   auto unknown = mlir::UnknownLoc::get(context);
243   mlir::emitError(unknown, ptr);
244   return mlir::failure();
245 }
246 
parse(llvm::StringRef kindMap)247 MatchResult fir::KindMapping::parse(llvm::StringRef kindMap) {
248   if (kindMap.empty())
249     return mlir::success();
250   const char *srcPtr = kindMap.begin();
251   const char *endPtr = kindMap.end();
252   while (true) {
253     char code = '\0';
254     KindTy kind = 0;
255     if (parseCode(code, srcPtr, endPtr) || parseInt(kind, srcPtr, endPtr))
256       return badMapString(srcPtr);
257     if (code == 'a' || code == 'i' || code == 'l') {
258       Bitsize bits = 0;
259       if (parseColon(srcPtr, endPtr) || parseInt(bits, srcPtr, endPtr))
260         return badMapString(srcPtr);
261       intMap[std::pair<char, KindTy>{code, kind}] = bits;
262     } else if (code == 'r' || code == 'c') {
263       LLVMTypeID id{};
264       if (parseColon(srcPtr, endPtr) || parseTypeID(id, srcPtr, endPtr))
265         return badMapString(srcPtr);
266       floatMap[std::pair<char, KindTy>{code, kind}] = id;
267     } else {
268       return badMapString(srcPtr);
269     }
270     if (parseComma(srcPtr, endPtr))
271       break;
272   }
273   if (srcPtr > endPtr)
274     return badMapString(srcPtr);
275   return mlir::success();
276 }
277 
getCharacterBitsize(KindTy kind) const278 Bitsize fir::KindMapping::getCharacterBitsize(KindTy kind) const {
279   return getIntegerLikeBitsize<'a'>(kind, intMap);
280 }
281 
getIntegerBitsize(KindTy kind) const282 Bitsize fir::KindMapping::getIntegerBitsize(KindTy kind) const {
283   return getIntegerLikeBitsize<'i'>(kind, intMap);
284 }
285 
getLogicalBitsize(KindTy kind) const286 Bitsize fir::KindMapping::getLogicalBitsize(KindTy kind) const {
287   return getIntegerLikeBitsize<'l'>(kind, intMap);
288 }
289 
getRealTypeID(KindTy kind) const290 LLVMTypeID fir::KindMapping::getRealTypeID(KindTy kind) const {
291   return getFloatLikeTypeID<'r'>(kind, floatMap);
292 }
293 
getComplexTypeID(KindTy kind) const294 LLVMTypeID fir::KindMapping::getComplexTypeID(KindTy kind) const {
295   return getFloatLikeTypeID<'c'>(kind, floatMap);
296 }
297 
getRealBitsize(KindTy kind) const298 Bitsize fir::KindMapping::getRealBitsize(KindTy kind) const {
299   auto typeId = getFloatLikeTypeID<'r'>(kind, floatMap);
300   llvm::LLVMContext llCtxt; // FIXME
301   return llvm::Type::getPrimitiveType(llCtxt, typeId)->getPrimitiveSizeInBits();
302 }
303 
304 const llvm::fltSemantics &
getFloatSemantics(KindTy kind) const305 fir::KindMapping::getFloatSemantics(KindTy kind) const {
306   return getFloatSemanticsOfKind<'r'>(kind, floatMap);
307 }
308 
mapToString() const309 std::string fir::KindMapping::mapToString() const {
310   std::string result;
311   bool addComma = false;
312   for (auto [k, v] : intMap) {
313     if (addComma)
314       result.append(",");
315     else
316       addComma = true;
317     result += k.first + std::to_string(k.second) + ":" + std::to_string(v);
318   }
319   for (auto [k, v] : floatMap) {
320     if (addComma)
321       result.append(",");
322     else
323       addComma = true;
324     result.append(k.first + std::to_string(k.second) + ":");
325     switch (v) {
326     default:
327       llvm_unreachable("unhandled type-id");
328     case LLVMTypeID::HalfTyID:
329       result.append(kwHalf);
330       break;
331     case LLVMTypeID::BFloatTyID:
332       result.append(kwBFloat);
333       break;
334     case LLVMTypeID::FloatTyID:
335       result.append(kwFloat);
336       break;
337     case LLVMTypeID::DoubleTyID:
338       result.append(kwDouble);
339       break;
340     case LLVMTypeID::X86_FP80TyID:
341       result.append(kwX86FP80);
342       break;
343     case LLVMTypeID::FP128TyID:
344       result.append(kwFP128);
345       break;
346     case LLVMTypeID::PPC_FP128TyID:
347       result.append(kwPPCFP128);
348       break;
349     }
350   }
351   return result;
352 }
353 
354 mlir::LogicalResult
setDefaultKinds(llvm::ArrayRef<KindTy> defs)355 fir::KindMapping::setDefaultKinds(llvm::ArrayRef<KindTy> defs) {
356   if (defs.empty()) {
357     // generic front-end defaults
358     const KindTy genericKind = 4;
359     defaultMap.insert({'a', 1});
360     defaultMap.insert({'c', genericKind});
361     defaultMap.insert({'d', 2 * genericKind});
362     defaultMap.insert({'i', genericKind});
363     defaultMap.insert({'l', genericKind});
364     defaultMap.insert({'r', genericKind});
365     return mlir::success();
366   }
367   if (defs.size() != 6)
368     return mlir::failure();
369 
370   // defaults determined after command-line processing
371   defaultMap.insert({'a', defs[0]});
372   defaultMap.insert({'c', defs[1]});
373   defaultMap.insert({'d', defs[2]});
374   defaultMap.insert({'i', defs[3]});
375   defaultMap.insert({'l', defs[4]});
376   defaultMap.insert({'r', defs[5]});
377   return mlir::success();
378 }
379 
defaultsToString() const380 std::string fir::KindMapping::defaultsToString() const {
381   return std::string("a") + std::to_string(defaultMap.find('a')->second) +
382          std::string("c") + std::to_string(defaultMap.find('c')->second) +
383          std::string("d") + std::to_string(defaultMap.find('d')->second) +
384          std::string("i") + std::to_string(defaultMap.find('i')->second) +
385          std::string("l") + std::to_string(defaultMap.find('l')->second) +
386          std::string("r") + std::to_string(defaultMap.find('r')->second);
387 }
388 
389 /// Convert a default intrinsic code into the proper position in the array. The
390 /// default kinds have a precise ordering.
codeToIndex(char code)391 static int codeToIndex(char code) {
392   switch (code) {
393   case 'a':
394     return 0;
395   case 'c':
396     return 1;
397   case 'd':
398     return 2;
399   case 'i':
400     return 3;
401   case 'l':
402     return 4;
403   case 'r':
404     return 5;
405   }
406   llvm_unreachable("invalid default kind intrinsic code");
407 }
408 
toDefaultKinds(llvm::StringRef defs)409 std::vector<KindTy> fir::KindMapping::toDefaultKinds(llvm::StringRef defs) {
410   std::vector<KindTy> result(6);
411   char code;
412   KindTy kind;
413   if (defs.empty())
414     defs = clDefaultKinds;
415   const char *srcPtr = defs.begin();
416   const char *endPtr = defs.end();
417   while (srcPtr < endPtr) {
418     if (parseDefCode(code, srcPtr, endPtr) || parseInt(kind, srcPtr, endPtr))
419       llvm::report_fatal_error("invalid default kind code");
420     result[codeToIndex(code)] = kind;
421   }
422   assert(srcPtr == endPtr);
423   return result;
424 }
425 
defaultCharacterKind() const426 KindTy fir::KindMapping::defaultCharacterKind() const {
427   auto iter = defaultMap.find('a');
428   assert(iter != defaultMap.end());
429   return iter->second;
430 }
431 
defaultComplexKind() const432 KindTy fir::KindMapping::defaultComplexKind() const {
433   auto iter = defaultMap.find('c');
434   assert(iter != defaultMap.end());
435   return iter->second;
436 }
437 
defaultDoubleKind() const438 KindTy fir::KindMapping::defaultDoubleKind() const {
439   auto iter = defaultMap.find('d');
440   assert(iter != defaultMap.end());
441   return iter->second;
442 }
443 
defaultIntegerKind() const444 KindTy fir::KindMapping::defaultIntegerKind() const {
445   auto iter = defaultMap.find('i');
446   assert(iter != defaultMap.end());
447   return iter->second;
448 }
449 
defaultLogicalKind() const450 KindTy fir::KindMapping::defaultLogicalKind() const {
451   auto iter = defaultMap.find('l');
452   assert(iter != defaultMap.end());
453   return iter->second;
454 }
455 
defaultRealKind() const456 KindTy fir::KindMapping::defaultRealKind() const {
457   auto iter = defaultMap.find('r');
458   assert(iter != defaultMap.end());
459   return iter->second;
460 }
461