1 //===- Rules.h - Helpers for declaring facts and rules ----------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines helper classes and functions for managing state (facts),
10 // merging and tracking modification for various data types important for
11 // quantization.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_QUANTIZER_SUPPORT_RULES_H
16 #define MLIR_QUANTIZER_SUPPORT_RULES_H
17 
#include "llvm/ADT/Optional.h"

#include <algorithm>
#include <cstdint>
#include <limits>
#include <utility>
23 
24 namespace mlir {
25 namespace quantizer {
26 
/// Typed indicator of whether a mutator produces a modification.
///
/// Results combine with `|` / `|=`: the combination is Modified if either
/// operand is Modified, and Retained otherwise.
struct ModificationResult {
  enum ModificationEnum { Retained, Modified } value;

  // Implicit, so an enumerator can be used where a ModificationResult is
  // expected.
  ModificationResult(ModificationEnum v) : value(v) {}

  /// Returns Modified if either this result or `other` is Modified.
  /// Const so it can be applied to const results.
  ModificationResult operator|(ModificationResult other) const {
    return ModificationResult(
        (value == Modified || other.value == Modified) ? Modified : Retained);
  }

  /// Accumulates `other` into this result. Returns a reference to *this
  /// (conventional compound-assignment semantics) so chained `|=` updates
  /// this object rather than a temporary copy.
  ModificationResult &operator|=(ModificationResult other) {
    *this = *this | other;
    return *this;
  }
};

/// Builds a ModificationResult from a flag: Modified when `isModified` is
/// true (the default), Retained otherwise.
inline ModificationResult modify(bool isModified = true) {
  return ModificationResult{isModified ? ModificationResult::Modified
                                       : ModificationResult::Retained};
}

/// Returns true if `m` indicates a modification.
inline bool modified(ModificationResult m) {
  return m.value == ModificationResult::Modified;
}
55 
56 /// A fact that can converge through forward propagation alone without the
57 /// need to track ownership or individual assertions. In practice, this works
58 /// for static assertions that are either minimized or maximized and do not
59 /// vary dynamically.
60 ///
61 /// It is expected that ValueTy is appropriate to pass by value and has an
62 /// operator==. The BinaryReducer type should have two static methods:
63 ///   using ValueTy : Type of the value.
64 ///   ValueTy initialValue() : Returns the initial value of the fact.
65 ///   ValueTy reduce(ValueTy lhs, ValueTy rhs) : Reduces two values.
66 template <typename BinaryReducer>
67 class BasePropagatedFact {
68 public:
69   using ValueTy = typename BinaryReducer::ValueTy;
70   using ThisTy = BasePropagatedFact<BinaryReducer>;
BasePropagatedFact()71   BasePropagatedFact()
72       : value(BinaryReducer::initialValue()),
73         salience(std::numeric_limits<int>::min()) {}
74 
getSalience()75   int getSalience() const { return salience; }
hasValue()76   bool hasValue() const { return salience != std::numeric_limits<int>::min(); }
getValue()77   ValueTy getValue() const { return value; }
assertValue(int assertSalience,ValueTy assertValue)78   ModificationResult assertValue(int assertSalience, ValueTy assertValue) {
79     if (assertSalience > salience) {
80       // New salience band.
81       value = assertValue;
82       salience = assertSalience;
83       return modify(true);
84     } else if (assertSalience < salience) {
85       // Lower salience - ignore.
86       return modify(false);
87     }
88     // Merge within same salience band.
89     ValueTy updatedValue = BinaryReducer::reduce(value, assertValue);
90     auto mod = modify(value != updatedValue);
91     value = updatedValue;
92     return mod;
93   }
mergeFrom(const ThisTy & other)94   ModificationResult mergeFrom(const ThisTy &other) {
95     if (other.hasValue()) {
96       return assertValue(other.getSalience(), other.getValue());
97     }
98     return modify(false);
99   }
100 
101 private:
102   ValueTy value;
103   int salience;
104 };
105 
/// Binary reducer over a (min, max) pair of doubles. Reducing two ranges
/// yields the smallest range that covers both, so repeated reduction only
/// ever widens the result. The initial value (+Inf, -Inf) is an empty range
/// that leaves any non-NaN range unchanged when reduced with it.
struct ExpandingMinMaxReducer {
  using ValueTy = std::pair<double, double>;

  static ValueTy initialValue() {
    const double inf = std::numeric_limits<double>::infinity();
    return ValueTy(inf, -inf);
  }

  static ValueTy reduce(ValueTy lhs, ValueTy rhs) {
    const double lo = std::min(lhs.first, rhs.first);
    const double hi = std::max(lhs.second, rhs.second);
    return ValueTy(lo, hi);
  }
};
/// Fact whose value is the (min, max) range covering every range asserted
/// within the highest salience band seen so far.
using ExpandingMinMaxFact = BasePropagatedFact<ExpandingMinMaxReducer>;
121 
/// A binary reducer that minimizes a numeric type. The initial value is the
/// identity for std::min: +infinity when T has one, otherwise the largest
/// finite value of T.
template <typename T>
struct MinimizingNumericReducer {
  using ValueTy = T;
  static ValueTy initialValue() {
    // Note: `has_infinity` is a static constexpr data member, not a function.
    if (std::numeric_limits<T>::has_infinity) {
      return std::numeric_limits<T>::infinity();
    }
    return std::numeric_limits<T>::max();
  }
  static ValueTy reduce(ValueTy lhs, ValueTy rhs) { return std::min(lhs, rhs); }
};
/// Facts holding the minimum value asserted within the highest salience band
/// seen so far.
using MinimizingDoubleFact =
    BasePropagatedFact<MinimizingNumericReducer<double>>;
using MinimizingIntFact = BasePropagatedFact<MinimizingNumericReducer<int>>;
138 
/// A binary reducer that maximizes a numeric type. The initial value is the
/// identity for std::max: -infinity when T has one, otherwise the most
/// negative finite value of T.
template <typename T>
struct MaximizingNumericReducer {
  using ValueTy = T;
  static ValueTy initialValue() {
    // Note: `has_infinity` is a static constexpr data member, not a function.
    if (std::numeric_limits<T>::has_infinity) {
      return -std::numeric_limits<T>::infinity();
    }
    // lowest() is the most negative finite value. min() would be wrong for
    // floating-point types, where it is the smallest positive normal value.
    return std::numeric_limits<T>::lowest();
  }
  static ValueTy reduce(ValueTy lhs, ValueTy rhs) { return std::max(lhs, rhs); }
};
/// Facts holding the maximum value asserted within the highest salience band
/// seen so far.
using MaximizingDoubleFact =
    BasePropagatedFact<MaximizingNumericReducer<double>>;
using MaximizingIntFact = BasePropagatedFact<MaximizingNumericReducer<int>>;
155 
156 /// A fact and reducer for tracking agreement of discrete values. The value
157 /// type consists of a |T| value and a flag indicating whether there is a
158 /// conflict (in which case, the preserved value is arbitrary).
159 template <typename T>
160 struct DiscreteReducer {
161   struct ValueTy {
ValueTyDiscreteReducer::ValueTy162     ValueTy() : conflict(false) {}
ValueTyDiscreteReducer::ValueTy163     ValueTy(T value) : value(value), conflict(false) {}
ValueTyDiscreteReducer::ValueTy164     ValueTy(T value, bool conflict) : value(value), conflict(conflict) {}
165     llvm::Optional<T> value;
166     bool conflict;
167     bool operator==(const ValueTy &other) const {
168       if (conflict != other.conflict)
169         return false;
170       if (value && other.value) {
171         return *value == *other.value;
172       } else {
173         return !value && !other.value;
174       }
175     }
176     bool operator!=(const ValueTy &other) const { return !(*this == other); }
177   };
initialValueDiscreteReducer178   static ValueTy initialValue() { return ValueTy(); }
reduceDiscreteReducer179   static ValueTy reduce(ValueTy lhs, ValueTy rhs) {
180     if (!lhs.value && !rhs.value)
181       return lhs;
182     else if (!lhs.value)
183       return rhs;
184     else if (!rhs.value)
185       return lhs;
186     else
187       return ValueTy(*lhs.value, *lhs.value != *rhs.value);
188   }
189 };
190 
191 template <typename T>
192 using DiscreteFact = BasePropagatedFact<DiscreteReducer<T>>;
193 
194 /// Discrete scale/zeroPoint fact.
195 using DiscreteScaleZeroPointFact = DiscreteFact<std::pair<double, int64_t>>;
196 
197 } // end namespace quantizer
198 } // end namespace mlir
199 
200 #endif // MLIR_QUANTIZER_SUPPORT_RULES_H
201