1 //===- Configuration.h - Configuration object base classes ------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The quantizer is relatively agnostic to source and target dialects, with
// the specifics represented by configuration policy objects derived from
// classes in this file.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
16 #define MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
17 
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

#include "mlir/Dialect/QuantOps/QuantTypes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "mlir/Quantizer/Support/Rules.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringSet.h"
28 
29 namespace mlir {
30 class Operation;
31 
32 namespace quantizer {
33 
34 class CAGSlice;
35 
36 /// Defines quantization configuration for the target.
37 /// The settings here depend on a variety of details about the deployment
38 /// environment, although, where we have control over such things, we do
39 /// try to standardize as possible.
40 ///
41 /// Non-const methods are used to setup the configuration. It is expected that
42 /// const instances/references are used post-build.
43 class TargetConfiguration {
44 public:
45   static constexpr size_t MaxSchemeIndex = 31;
46   using OpHandlerFn = std::function<void(Operation *op, CAGSlice &cag)>;
47 
48   TargetConfiguration(SolverContext &context);
49   virtual ~TargetConfiguration() = default;
50 
51   /// Adds a candidate type, returning its ordinal.
addCandidateType(quant::AnyQuantizedType quantizedType,CandidateQuantizedType::Scheme scheme)52   unsigned addCandidateType(quant::AnyQuantizedType quantizedType,
53                             CandidateQuantizedType::Scheme scheme) {
54     unsigned ordinal = candidateTypes.size();
55     assert(allCandidateTypesMask.size() == ordinal);
56     CandidateQuantizedType ct{ordinal, quantizedType, scheme};
57     candidateTypes.push_back(ct);
58     allCandidateTypesMask.push_back(true);
59     return ordinal;
60   }
61 
62   /// Gets a prototype scheme by index.
getCandidateType(unsigned index)63   const CandidateQuantizedType &getCandidateType(unsigned index) const {
64     assert(index < candidateTypes.size());
65     return candidateTypes[index];
66   }
67 
getCandidateTypes()68   ArrayRef<CandidateQuantizedType> getCandidateTypes() const {
69     return candidateTypes;
70   }
71 
72   /// Gets a mask of all enabled candidate types by ordinal.
getAllCandidateTypesMask()73   llvm::SmallBitVector getAllCandidateTypesMask() const {
74     return allCandidateTypesMask;
75   }
76 
77   /// Gets a mask with every candidate type except those in the given mask.
78   llvm::SmallBitVector
getCandidateTypeDisabledExceptMask(ArrayRef<unsigned> exceptOrdinals)79   getCandidateTypeDisabledExceptMask(ArrayRef<unsigned> exceptOrdinals) const {
80     llvm::SmallBitVector disabled(allCandidateTypesMask);
81     for (unsigned ordinal : exceptOrdinals) {
82       disabled.reset(ordinal);
83     }
84     return disabled;
85   }
86 
87   /// Adds an op handler.
88   template <typename OpTy>
addOpHandler(OpHandlerFn fn)89   void addOpHandler(OpHandlerFn fn) {
90     addOpHandlerByName(OpTy::getOperationName(), fn);
91   }
92 
93   /// Adds an operation which requires statistics at its result nodes for
94   /// best quantization performance. Note that the opName StringRef is
95   /// expected to come from getOperationName() and be static.
96   template <typename OpTy>
addRequireStatsOp()97   void addRequireStatsOp() {
98     addRequireStatsOpByName(OpTy::getOperationName());
99   }
100 
101   /// Returns whether opName is a RequireStatsOp.
102   bool isRequireStatsOp(Operation *op) const;
103 
104   /// Adds an op which does not mutate its values but may mutate its shape
105   /// or combine its operands in an arbitrary way.
106   /// Such ops are expected to have the same types for operands and results
107   /// and must be capable of operating on storage types.
108   template <typename OpTy>
addValueIdentityOp()109   void addValueIdentityOp() {
110     addValueIdentityOpByName(OpTy::getOperationName());
111   }
112 
113   /// Handles the operation if a handler is defined for it.
114   void handleOp(Operation *op, CAGSlice &cag) const;
115 
116   /// Finalizes the CAG after all anchors have been added.
finalizeAnchors(CAGSlice & cag)117   virtual void finalizeAnchors(CAGSlice &cag) const {}
118 
119   /// Whether an operand or result type is subject to analysis by this config.
120   virtual bool isHandledType(Type t) const = 0;
121 
122 protected:
123   virtual void addValueIdentityOpByName(StringRef opName) = 0;
124   void addOpHandlerByName(StringRef name, OpHandlerFn fn);
125 
126 private:
127   void addRequireStatsOpByName(StringRef opName);
128 
129   /// Vector of all candidate type constraints, indexed by ordinal.
130   std::vector<CandidateQuantizedType> candidateTypes;
131 
132   // A SmallBoolVector with bits set for all known candidate types.
133   llvm::SmallBitVector allCandidateTypesMask;
134 
135   /// Map of all op handlers.
136   llvm::StringMap<OpHandlerFn> opHandlers;
137 
138   /// Names of operations which should have their results annotated with
139   /// statistics.
140   llvm::StringSet<> requireStatsOpNames;
141 };
142 
143 } // namespace quantizer
144 } // namespace mlir
145 
146 #endif // MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
147