1 //===-- X86InstrFMA3Info.cpp - X86 FMA3 Instruction Information -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the classes providing information
10 // about existing X86 FMA3 opcodes, classifying and grouping them.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "X86InstrFMA3Info.h"
#include "X86InstrInfo.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Threading.h"
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstdint>
20 
21 using namespace llvm;
22 
// Expands to one X86InstrFMA3Group initializer (including the trailing comma)
// holding the 132, 213 and 231 forms of X86::<Name><form><Suf>, in that order,
// together with the attribute flags Attrs. getFMA3Group indexes this opcode
// triple with its form index: 0 = 132, 1 = 213, 2 = 231.
#define FMA3GROUP(Name, Suf, Attrs) \
  { { X86::Name##132##Suf, X86::Name##213##Suf, X86::Name##231##Suf }, Attrs },

// Expands to three groups: the unmasked variant, the "k" variant tagged
// KMergeMasked and the "kz" variant tagged KZeroMasked, in that order.
#define FMA3GROUP_MASKED(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##k, Attrs | X86InstrFMA3Group::KMergeMasked) \
  FMA3GROUP(Name, Suf##kz, Attrs | X86InstrFMA3Group::KZeroMasked)
30 
// Expands to the groups for every variant of a packed FMA3 instruction with
// element-type suffix Suf. The variants are, in order:
//   - the Y forms;
//   - the Z128, Z256 and Z forms, each with masked variants;
//   - the unsuffixed forms.
// Each comes in an "m" and an "r" flavour.
// The tables built from these macros must stay sorted by opcode (verifyTables
// asserts this), so do not reorder the expansion.
#define FMA3GROUP_PACKED_WIDTHS(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##Ym, Attrs) \
  FMA3GROUP(Name, Suf##Yr, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z128m, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z128r, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z256m, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z256r, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zm, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zr, Attrs) \
  FMA3GROUP(Name, Suf##m, Attrs) \
  FMA3GROUP(Name, Suf##r, Attrs)

// Expands to the PD variants followed by the PS variants of a packed FMA3
// instruction.
#define FMA3GROUP_PACKED(Name, Attrs) \
  FMA3GROUP_PACKED_WIDTHS(Name, PD, Attrs) \
  FMA3GROUP_PACKED_WIDTHS(Name, PS, Attrs)
46 
// Expands to the groups for every variant of a scalar FMA3 instruction with
// element-type suffix Suf. Only the "_Int" variants are tagged
// X86InstrFMA3Group::Intrinsic, and of those only the Z-suffixed ones also get
// masked variants. Keep the expansion order: the tables must stay sorted.
#define FMA3GROUP_SCALAR_WIDTHS(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##Zm, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zm_Int, Attrs | X86InstrFMA3Group::Intrinsic) \
  FMA3GROUP(Name, Suf##Zr, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zr_Int, Attrs | X86InstrFMA3Group::Intrinsic) \
  FMA3GROUP(Name, Suf##m, Attrs) \
  FMA3GROUP(Name, Suf##m_Int, Attrs | X86InstrFMA3Group::Intrinsic) \
  FMA3GROUP(Name, Suf##r, Attrs) \
  FMA3GROUP(Name, Suf##r_Int, Attrs | X86InstrFMA3Group::Intrinsic)

// Expands to the SD variants followed by the SS variants of a scalar FMA3
// instruction.
#define FMA3GROUP_SCALAR(Name, Attrs) \
  FMA3GROUP_SCALAR_WIDTHS(Name, SD, Attrs) \
  FMA3GROUP_SCALAR_WIDTHS(Name, SS, Attrs)

// Expands to every packed group followed by every scalar group of an FMA3
// instruction.
#define FMA3GROUP_FULL(Name, Attrs) \
  FMA3GROUP_PACKED(Name, Attrs) \
  FMA3GROUP_SCALAR(Name, Attrs)
64 
// FMA3 groups that getFMA3Group searches when neither EVEX_RC nor EVEX_B is set
// in the opcode's TSFlags. VFMADDSUB and VFMSUBADD contribute packed forms only.
// getFMA3Group binary-searches this table and verifyTables asserts that it is
// sorted, so the rows (and their macro expansions) must stay in opcode order.
static const X86InstrFMA3Group Groups[] = {
  FMA3GROUP_FULL(VFMADD, 0)
  FMA3GROUP_PACKED(VFMADDSUB, 0)
  FMA3GROUP_FULL(VFMSUB, 0)
  FMA3GROUP_PACKED(VFMSUBADD, 0)
  FMA3GROUP_FULL(VFNMADD, 0)
  FMA3GROUP_FULL(VFNMSUB, 0)
};
73 
74 #define FMA3GROUP_PACKED_AVX512_WIDTHS(Name, Type, Suf, Attrs) \
75   FMA3GROUP_MASKED(Name, Type##Z128##Suf, Attrs) \
76   FMA3GROUP_MASKED(Name, Type##Z256##Suf, Attrs) \
77   FMA3GROUP_MASKED(Name, Type##Z##Suf, Attrs)
78 
79 #define FMA3GROUP_PACKED_AVX512(Name, Suf, Attrs) \
80   FMA3GROUP_PACKED_AVX512_WIDTHS(Name, PD, Suf, Attrs) \
81   FMA3GROUP_PACKED_AVX512_WIDTHS(Name, PS, Suf, Attrs)
82 
83 #define FMA3GROUP_PACKED_AVX512_ROUND(Name, Suf, Attrs) \
84   FMA3GROUP_MASKED(Name, PDZ##Suf, Attrs) \
85   FMA3GROUP_MASKED(Name, PSZ##Suf, Attrs)
86 
87 #define FMA3GROUP_SCALAR_AVX512_ROUND(Name, Suf, Attrs) \
88   FMA3GROUP(Name, SDZ##Suf, Attrs) \
89   FMA3GROUP_MASKED(Name, SDZ##Suf##_Int, Attrs) \
90   FMA3GROUP(Name, SSZ##Suf, Attrs) \
91   FMA3GROUP_MASKED(Name, SSZ##Suf##_Int, Attrs)
92 
93 static const X86InstrFMA3Group BroadcastGroups[] = {
94   FMA3GROUP_PACKED_AVX512(VFMADD, mb, 0)
95   FMA3GROUP_PACKED_AVX512(VFMADDSUB, mb, 0)
96   FMA3GROUP_PACKED_AVX512(VFMSUB, mb, 0)
97   FMA3GROUP_PACKED_AVX512(VFMSUBADD, mb, 0)
98   FMA3GROUP_PACKED_AVX512(VFNMADD, mb, 0)
99   FMA3GROUP_PACKED_AVX512(VFNMSUB, mb, 0)
100 };
101 
102 static const X86InstrFMA3Group RoundGroups[] = {
103   FMA3GROUP_PACKED_AVX512_ROUND(VFMADD, rb, 0)
104   FMA3GROUP_SCALAR_AVX512_ROUND(VFMADD, rb, X86InstrFMA3Group::Intrinsic)
105   FMA3GROUP_PACKED_AVX512_ROUND(VFMADDSUB, rb, 0)
106   FMA3GROUP_PACKED_AVX512_ROUND(VFMSUB, rb, 0)
107   FMA3GROUP_SCALAR_AVX512_ROUND(VFMSUB, rb, X86InstrFMA3Group::Intrinsic)
108   FMA3GROUP_PACKED_AVX512_ROUND(VFMSUBADD, rb, 0)
109   FMA3GROUP_PACKED_AVX512_ROUND(VFNMADD, rb, 0)
110   FMA3GROUP_SCALAR_AVX512_ROUND(VFNMADD, rb, X86InstrFMA3Group::Intrinsic)
111   FMA3GROUP_PACKED_AVX512_ROUND(VFNMSUB, rb, 0)
112   FMA3GROUP_SCALAR_AVX512_ROUND(VFNMSUB, rb, X86InstrFMA3Group::Intrinsic)
113 };
114 
verifyTables()115 static void verifyTables() {
116 #ifndef NDEBUG
117   static std::atomic<bool> TableChecked(false);
118   if (!TableChecked.load(std::memory_order_relaxed)) {
119     assert(std::is_sorted(std::begin(Groups), std::end(Groups)) &&
120            std::is_sorted(std::begin(RoundGroups), std::end(RoundGroups)) &&
121            std::is_sorted(std::begin(BroadcastGroups),
122                           std::end(BroadcastGroups)) &&
123            "FMA3 tables not sorted!");
124     TableChecked.store(true, std::memory_order_relaxed);
125   }
126 #endif
127 }
128 
129 /// Returns a reference to a group of FMA3 opcodes to where the given
130 /// \p Opcode is included. If the given \p Opcode is not recognized as FMA3
131 /// and not included into any FMA3 group, then nullptr is returned.
getFMA3Group(unsigned Opcode,uint64_t TSFlags)132 const X86InstrFMA3Group *llvm::getFMA3Group(unsigned Opcode, uint64_t TSFlags) {
133 
134   // FMA3 instructions have a well defined encoding pattern we can exploit.
135   uint8_t BaseOpcode = X86II::getBaseOpcodeFor(TSFlags);
136   bool IsFMA3 = ((TSFlags & X86II::EncodingMask) == X86II::VEX ||
137                  (TSFlags & X86II::EncodingMask) == X86II::EVEX) &&
138                 (TSFlags & X86II::OpMapMask) == X86II::T8 &&
139                 (TSFlags & X86II::OpPrefixMask) == X86II::PD &&
140                 ((BaseOpcode >= 0x96 && BaseOpcode <= 0x9F) ||
141                  (BaseOpcode >= 0xA6 && BaseOpcode <= 0xAF) ||
142                  (BaseOpcode >= 0xB6 && BaseOpcode <= 0xBF));
143   if (!IsFMA3)
144     return nullptr;
145 
146   verifyTables();
147 
148   ArrayRef<X86InstrFMA3Group> Table;
149   if (TSFlags & X86II::EVEX_RC)
150     Table = makeArrayRef(RoundGroups);
151   else if (TSFlags & X86II::EVEX_B)
152     Table = makeArrayRef(BroadcastGroups);
153   else
154     Table = makeArrayRef(Groups);
155 
156   // FMA 132 instructions have an opcode of 0x96-0x9F
157   // FMA 213 instructions have an opcode of 0xA6-0xAF
158   // FMA 231 instructions have an opcode of 0xB6-0xBF
159   unsigned FormIndex = ((BaseOpcode - 0x90) >> 4) & 0x3;
160 
161   auto I = partition_point(Table, [=](const X86InstrFMA3Group &Group) {
162     return Group.Opcodes[FormIndex] < Opcode;
163   });
164   assert(I != Table.end() && I->Opcodes[FormIndex] == Opcode &&
165          "Couldn't find FMA3 opcode!");
166   return I;
167 }
168