1 //===- InstructionCost.h ----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file defines an InstructionCost class that is used when calculating
10 /// the cost of an instruction, or a group of instructions. In addition to a
11 /// numeric value representing the cost the class also contains a state that
12 /// can be used to encode particular properties, such as a cost being invalid.
/// Operations on InstructionCost implement saturation arithmetic, so that
/// accumulating large cost values does not overflow.
15 ///
16 //===----------------------------------------------------------------------===//
17 
18 #ifndef LLVM_SUPPORT_INSTRUCTIONCOST_H
19 #define LLVM_SUPPORT_INSTRUCTIONCOST_H
20 
21 #include "llvm/Support/MathExtras.h"
22 #include <limits>
23 #include <optional>
24 
25 namespace llvm {
26 
27 class raw_ostream;
28 
29 class InstructionCost {
30 public:
31   using CostType = int64_t;
32 
33   /// CostState describes the state of a cost.
34   enum CostState {
35     Valid,  /// < The cost value represents a valid cost, even when the
36             /// cost-value is large.
37     Invalid /// < Invalid indicates there is no way to represent the cost as a
38             /// numeric value. This state exists to represent a possible issue,
39             /// e.g. if the cost-model knows the operation cannot be expanded
40             /// into a valid code-sequence by the code-generator.  While some
41             /// passes may assert that the calculated cost must be valid, it is
42             /// up to individual passes how to interpret an Invalid cost. For
43             /// example, a transformation pass could choose not to perform a
44             /// transformation if the resulting cost would end up Invalid.
45             /// Because some passes may assert a cost is Valid, it is not
46             /// recommended to use Invalid costs to model 'Unknown'.
47             /// Note that Invalid is semantically different from a (very) high,
48             /// but valid cost, which intentionally indicates no issue, but
49             /// rather a strong preference not to select a certain operation.
50   };
51 
52 private:
53   CostType Value = 0;
54   CostState State = Valid;
55 
propagateState(const InstructionCost & RHS)56   void propagateState(const InstructionCost &RHS) {
57     if (RHS.State == Invalid)
58       State = Invalid;
59   }
60 
getMaxValue()61   static CostType getMaxValue() { return std::numeric_limits<CostType>::max(); }
getMinValue()62   static CostType getMinValue() { return std::numeric_limits<CostType>::min(); }
63 
64 public:
65   // A default constructed InstructionCost is a valid zero cost
66   InstructionCost() = default;
67 
68   InstructionCost(CostState) = delete;
InstructionCost(CostType Val)69   InstructionCost(CostType Val) : Value(Val), State(Valid) {}
70 
getMax()71   static InstructionCost getMax() { return getMaxValue(); }
getMin()72   static InstructionCost getMin() { return getMinValue(); }
73   static InstructionCost getInvalid(CostType Val = 0) {
74     InstructionCost Tmp(Val);
75     Tmp.setInvalid();
76     return Tmp;
77   }
78 
isValid()79   bool isValid() const { return State == Valid; }
setValid()80   void setValid() { State = Valid; }
setInvalid()81   void setInvalid() { State = Invalid; }
getState()82   CostState getState() const { return State; }
83 
84   /// This function is intended to be used as sparingly as possible, since the
85   /// class provides the full range of operator support required for arithmetic
86   /// and comparisons.
getValue()87   std::optional<CostType> getValue() const {
88     if (isValid())
89       return Value;
90     return std::nullopt;
91   }
92 
93   /// For all of the arithmetic operators provided here any invalid state is
94   /// perpetuated and cannot be removed. Once a cost becomes invalid it stays
95   /// invalid, and it also inherits any invalid state from the RHS.
96   /// Arithmetic work on the actual values is implemented with saturation,
97   /// to avoid overflow when using more extreme cost values.
98 
99   InstructionCost &operator+=(const InstructionCost &RHS) {
100     propagateState(RHS);
101 
102     // Saturating addition.
103     InstructionCost::CostType Result;
104     if (AddOverflow(Value, RHS.Value, Result))
105       Result = RHS.Value > 0 ? getMaxValue() : getMinValue();
106 
107     Value = Result;
108     return *this;
109   }
110 
111   InstructionCost &operator+=(const CostType RHS) {
112     InstructionCost RHS2(RHS);
113     *this += RHS2;
114     return *this;
115   }
116 
117   InstructionCost &operator-=(const InstructionCost &RHS) {
118     propagateState(RHS);
119 
120     // Saturating subtract.
121     InstructionCost::CostType Result;
122     if (SubOverflow(Value, RHS.Value, Result))
123       Result = RHS.Value > 0 ? getMinValue() : getMaxValue();
124     Value = Result;
125     return *this;
126   }
127 
128   InstructionCost &operator-=(const CostType RHS) {
129     InstructionCost RHS2(RHS);
130     *this -= RHS2;
131     return *this;
132   }
133 
134   InstructionCost &operator*=(const InstructionCost &RHS) {
135     propagateState(RHS);
136 
137     // Saturating multiply.
138     InstructionCost::CostType Result;
139     if (MulOverflow(Value, RHS.Value, Result)) {
140       if ((Value > 0 && RHS.Value > 0) || (Value < 0 && RHS.Value < 0))
141         Result = getMaxValue();
142       else
143         Result = getMinValue();
144     }
145 
146     Value = Result;
147     return *this;
148   }
149 
150   InstructionCost &operator*=(const CostType RHS) {
151     InstructionCost RHS2(RHS);
152     *this *= RHS2;
153     return *this;
154   }
155 
156   InstructionCost &operator/=(const InstructionCost &RHS) {
157     propagateState(RHS);
158     Value /= RHS.Value;
159     return *this;
160   }
161 
162   InstructionCost &operator/=(const CostType RHS) {
163     InstructionCost RHS2(RHS);
164     *this /= RHS2;
165     return *this;
166   }
167 
168   InstructionCost &operator++() {
169     *this += 1;
170     return *this;
171   }
172 
173   InstructionCost operator++(int) {
174     InstructionCost Copy = *this;
175     ++*this;
176     return Copy;
177   }
178 
179   InstructionCost &operator--() {
180     *this -= 1;
181     return *this;
182   }
183 
184   InstructionCost operator--(int) {
185     InstructionCost Copy = *this;
186     --*this;
187     return Copy;
188   }
189 
190   /// For the comparison operators we have chosen to use lexicographical
191   /// ordering where valid costs are always considered to be less than invalid
192   /// costs. This avoids having to add asserts to the comparison operators that
193   /// the states are valid and users can test for validity of the cost
194   /// explicitly.
195   bool operator<(const InstructionCost &RHS) const {
196     if (State != RHS.State)
197       return State < RHS.State;
198     return Value < RHS.Value;
199   }
200 
201   // Implement in terms of operator< to ensure that the two comparisons stay in
202   // sync
203   bool operator==(const InstructionCost &RHS) const {
204     return !(*this < RHS) && !(RHS < *this);
205   }
206 
207   bool operator!=(const InstructionCost &RHS) const { return !(*this == RHS); }
208 
209   bool operator==(const CostType RHS) const {
210     InstructionCost RHS2(RHS);
211     return *this == RHS2;
212   }
213 
214   bool operator!=(const CostType RHS) const { return !(*this == RHS); }
215 
216   bool operator>(const InstructionCost &RHS) const { return RHS < *this; }
217 
218   bool operator<=(const InstructionCost &RHS) const { return !(RHS < *this); }
219 
220   bool operator>=(const InstructionCost &RHS) const { return !(*this < RHS); }
221 
222   bool operator<(const CostType RHS) const {
223     InstructionCost RHS2(RHS);
224     return *this < RHS2;
225   }
226 
227   bool operator>(const CostType RHS) const {
228     InstructionCost RHS2(RHS);
229     return *this > RHS2;
230   }
231 
232   bool operator<=(const CostType RHS) const {
233     InstructionCost RHS2(RHS);
234     return *this <= RHS2;
235   }
236 
237   bool operator>=(const CostType RHS) const {
238     InstructionCost RHS2(RHS);
239     return *this >= RHS2;
240   }
241 
242   void print(raw_ostream &OS) const;
243 
244   template <class Function>
245   auto map(const Function &F) const -> InstructionCost {
246     if (isValid())
247       return F(Value);
248     return getInvalid();
249   }
250 };
251 
252 inline InstructionCost operator+(const InstructionCost &LHS,
253                                  const InstructionCost &RHS) {
254   InstructionCost LHS2(LHS);
255   LHS2 += RHS;
256   return LHS2;
257 }
258 
259 inline InstructionCost operator-(const InstructionCost &LHS,
260                                  const InstructionCost &RHS) {
261   InstructionCost LHS2(LHS);
262   LHS2 -= RHS;
263   return LHS2;
264 }
265 
266 inline InstructionCost operator*(const InstructionCost &LHS,
267                                  const InstructionCost &RHS) {
268   InstructionCost LHS2(LHS);
269   LHS2 *= RHS;
270   return LHS2;
271 }
272 
273 inline InstructionCost operator/(const InstructionCost &LHS,
274                                  const InstructionCost &RHS) {
275   InstructionCost LHS2(LHS);
276   LHS2 /= RHS;
277   return LHS2;
278 }
279 
280 inline raw_ostream &operator<<(raw_ostream &OS, const InstructionCost &V) {
281   V.print(OS);
282   return OS;
283 }
284 
285 } // namespace llvm
286 
287 #endif
288