1 //===- Rules.h - Helpers for declaring facts and rules ----------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines helper classes and functions for managing state (facts),
10 // merging and tracking modification for various data types important for
11 // quantization.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #ifndef MLIR_QUANTIZER_SUPPORT_RULES_H
16 #define MLIR_QUANTIZER_SUPPORT_RULES_H
17
#include "llvm/ADT/Optional.h"

#include <algorithm>
#include <cstdint>
#include <limits>
#include <utility>
23
24 namespace mlir {
25 namespace quantizer {
26
/// Typed indicator of whether a mutator produces a modification.
///
/// Results combine with `|` / `|=`: the combination is Modified if either
/// operand is Modified, and Retained only when both operands are Retained.
struct ModificationResult {
  enum ModificationEnum { Retained, Modified } value;

  // Non-explicit: a bare enumerator converts to a ModificationResult.
  ModificationResult(ModificationEnum v) : value(v) {}

  /// Returns the combination of this result and `other` without mutating
  /// either operand (usable on const results).
  ModificationResult operator|(ModificationResult other) const {
    return ModificationResult(
        (value == Modified || other.value == Modified) ? Modified : Retained);
  }

  /// Folds `other` into this result in place. Returns a reference to `*this`
  /// so compound assignments chain like the built-in operator.
  ModificationResult &operator|=(ModificationResult other) {
    *this = *this | other;
    return *this;
  }
};
46
47 inline ModificationResult modify(bool isModified = true) {
48 return ModificationResult{isModified ? ModificationResult::Modified
49 : ModificationResult::Retained};
50 }
51
modified(ModificationResult m)52 inline bool modified(ModificationResult m) {
53 return m.value == ModificationResult::Modified;
54 }
55
56 /// A fact that can converge through forward propagation alone without the
57 /// need to track ownership or individual assertions. In practice, this works
58 /// for static assertions that are either minimized or maximized and do not
59 /// vary dynamically.
60 ///
61 /// It is expected that ValueTy is appropriate to pass by value and has an
62 /// operator==. The BinaryReducer type should have two static methods:
63 /// using ValueTy : Type of the value.
64 /// ValueTy initialValue() : Returns the initial value of the fact.
65 /// ValueTy reduce(ValueTy lhs, ValueTy rhs) : Reduces two values.
66 template <typename BinaryReducer>
67 class BasePropagatedFact {
68 public:
69 using ValueTy = typename BinaryReducer::ValueTy;
70 using ThisTy = BasePropagatedFact<BinaryReducer>;
BasePropagatedFact()71 BasePropagatedFact()
72 : value(BinaryReducer::initialValue()),
73 salience(std::numeric_limits<int>::min()) {}
74
getSalience()75 int getSalience() const { return salience; }
hasValue()76 bool hasValue() const { return salience != std::numeric_limits<int>::min(); }
getValue()77 ValueTy getValue() const { return value; }
assertValue(int assertSalience,ValueTy assertValue)78 ModificationResult assertValue(int assertSalience, ValueTy assertValue) {
79 if (assertSalience > salience) {
80 // New salience band.
81 value = assertValue;
82 salience = assertSalience;
83 return modify(true);
84 } else if (assertSalience < salience) {
85 // Lower salience - ignore.
86 return modify(false);
87 }
88 // Merge within same salience band.
89 ValueTy updatedValue = BinaryReducer::reduce(value, assertValue);
90 auto mod = modify(value != updatedValue);
91 value = updatedValue;
92 return mod;
93 }
mergeFrom(const ThisTy & other)94 ModificationResult mergeFrom(const ThisTy &other) {
95 if (other.hasValue()) {
96 return assertValue(other.getSalience(), other.getValue());
97 }
98 return modify(false);
99 }
100
101 private:
102 ValueTy value;
103 int salience;
104 };
105
/// A binary reducer over a (min, max) range stored as a pair of doubles.
/// Reducing two ranges yields the smallest range covering both, so repeated
/// reduction expands the range to cover every input.
/// The initial value is the empty range (+Inf, -Inf), which acts as the
/// identity for reduce() on non-NaN inputs.
struct ExpandingMinMaxReducer {
  using ValueTy = std::pair<double, double>;

  static ValueTy initialValue() {
    const double inf = std::numeric_limits<double>::infinity();
    return ValueTy(inf, -inf);
  }

  static ValueTy reduce(ValueTy lhs, ValueTy rhs) {
    const double lo = std::min(lhs.first, rhs.first);
    const double hi = std::max(lhs.second, rhs.second);
    return ValueTy(lo, hi);
  }
};
/// Propagated fact whose value is a (min, max) range of doubles; same-salience
/// assertions are merged with ExpandingMinMaxReducer::reduce.
using ExpandingMinMaxFact = BasePropagatedFact<ExpandingMinMaxReducer>;
121
/// A binary reducer that minimizes a numeric type T.
///
/// The initial value is the identity for std::min: +infinity when T has one,
/// otherwise the largest finite value of T.
template <typename T>
struct MinimizingNumericReducer {
  using ValueTy = T;
  static ValueTy initialValue() {
    // has_infinity is a static data member, not a function. Both arms are
    // valid for any arithmetic T (infinity() returns T() when unsupported),
    // so a plain conditional works without `if constexpr`.
    return std::numeric_limits<T>::has_infinity
               ? std::numeric_limits<T>::infinity()
               : std::numeric_limits<T>::max();
  }
  static ValueTy reduce(ValueTy lhs, ValueTy rhs) { return std::min(lhs, rhs); }
};
/// Propagated facts whose same-salience assertions are merged with
/// MinimizingNumericReducer (i.e. by std::min), for double and int values.
using MinimizingDoubleFact =
    BasePropagatedFact<MinimizingNumericReducer<double>>;
using MinimizingIntFact = BasePropagatedFact<MinimizingNumericReducer<int>>;
138
/// A binary reducer that maximizes a numeric type T.
///
/// The initial value is the identity for std::max: -infinity when T has one,
/// otherwise the lowest finite value of T.
template <typename T>
struct MaximizingNumericReducer {
  using ValueTy = T;
  static ValueTy initialValue() {
    // has_infinity is a static data member, not a function. Use lowest(),
    // not min(): for floating-point types min() is the smallest *positive*
    // normalized value, whereas lowest() is the most negative finite value
    // for every arithmetic type.
    return std::numeric_limits<T>::has_infinity
               ? -std::numeric_limits<T>::infinity()
               : std::numeric_limits<T>::lowest();
  }
  static ValueTy reduce(ValueTy lhs, ValueTy rhs) { return std::max(lhs, rhs); }
};
/// Propagated facts whose same-salience assertions are merged with
/// MaximizingNumericReducer (i.e. by std::max), for double and int values.
using MaximizingDoubleFact =
    BasePropagatedFact<MaximizingNumericReducer<double>>;
using MaximizingIntFact = BasePropagatedFact<MaximizingNumericReducer<int>>;
155
156 /// A fact and reducer for tracking agreement of discrete values. The value
157 /// type consists of a |T| value and a flag indicating whether there is a
158 /// conflict (in which case, the preserved value is arbitrary).
159 template <typename T>
160 struct DiscreteReducer {
161 struct ValueTy {
ValueTyDiscreteReducer::ValueTy162 ValueTy() : conflict(false) {}
ValueTyDiscreteReducer::ValueTy163 ValueTy(T value) : value(value), conflict(false) {}
ValueTyDiscreteReducer::ValueTy164 ValueTy(T value, bool conflict) : value(value), conflict(conflict) {}
165 llvm::Optional<T> value;
166 bool conflict;
167 bool operator==(const ValueTy &other) const {
168 if (conflict != other.conflict)
169 return false;
170 if (value && other.value) {
171 return *value == *other.value;
172 } else {
173 return !value && !other.value;
174 }
175 }
176 bool operator!=(const ValueTy &other) const { return !(*this == other); }
177 };
initialValueDiscreteReducer178 static ValueTy initialValue() { return ValueTy(); }
reduceDiscreteReducer179 static ValueTy reduce(ValueTy lhs, ValueTy rhs) {
180 if (!lhs.value && !rhs.value)
181 return lhs;
182 else if (!lhs.value)
183 return rhs;
184 else if (!rhs.value)
185 return lhs;
186 else
187 return ValueTy(*lhs.value, *lhs.value != *rhs.value);
188 }
189 };
190
191 template <typename T>
192 using DiscreteFact = BasePropagatedFact<DiscreteReducer<T>>;
193
194 /// Discrete scale/zeroPoint fact.
195 using DiscreteScaleZeroPointFact = DiscreteFact<std::pair<double, int64_t>>;
196
197 } // end namespace quantizer
198 } // end namespace mlir
199
200 #endif // MLIR_QUANTIZER_SUPPORT_RULES_H
201