1 //===- Configuration.h - Configuration object base classes ------*- C++ -*-===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The quantizer is relatively agnostic to source and target dialects, with 10 // the specific represented by configuration policy objects derived from 11 // classes in this file. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H 16 #define MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H 17 18 #include <functional> 19 20 #include "mlir/Dialect/QuantOps/QuantTypes.h" 21 #include "mlir/IR/Identifier.h" 22 #include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h" 23 #include "mlir/Quantizer/Support/Metadata.h" 24 #include "mlir/Quantizer/Support/Rules.h" 25 #include "llvm/ADT/DenseMap.h" 26 #include "llvm/ADT/SmallBitVector.h" 27 #include "llvm/ADT/StringSet.h" 28 29 namespace mlir { 30 class Operation; 31 32 namespace quantizer { 33 34 class CAGSlice; 35 36 /// Defines quantization configuration for the target. 37 /// The settings here depend on a variety of details about the deployment 38 /// environment, although, where we have control over such things, we do 39 /// try to standardize as possible. 40 /// 41 /// Non-const methods are used to setup the configuration. It is expected that 42 /// const instances/references are used post-build. 43 class TargetConfiguration { 44 public: 45 static constexpr size_t MaxSchemeIndex = 31; 46 using OpHandlerFn = std::function<void(Operation *op, CAGSlice &cag)>; 47 48 TargetConfiguration(SolverContext &context); 49 virtual ~TargetConfiguration() = default; 50 51 /// Adds a candidate type, returning its ordinal. addCandidateType(quant::AnyQuantizedType quantizedType,CandidateQuantizedType::Scheme scheme)52 unsigned addCandidateType(quant::AnyQuantizedType quantizedType, 53 CandidateQuantizedType::Scheme scheme) { 54 unsigned ordinal = candidateTypes.size(); 55 assert(allCandidateTypesMask.size() == ordinal); 56 CandidateQuantizedType ct{ordinal, quantizedType, scheme}; 57 candidateTypes.push_back(ct); 58 allCandidateTypesMask.push_back(true); 59 return ordinal; 60 } 61 62 /// Gets a prototype scheme by index. getCandidateType(unsigned index)63 const CandidateQuantizedType &getCandidateType(unsigned index) const { 64 assert(index < candidateTypes.size()); 65 return candidateTypes[index]; 66 } 67 getCandidateTypes()68 ArrayRef<CandidateQuantizedType> getCandidateTypes() const { 69 return candidateTypes; 70 } 71 72 /// Gets a mask of all enabled candidate types by ordinal. getAllCandidateTypesMask()73 llvm::SmallBitVector getAllCandidateTypesMask() const { 74 return allCandidateTypesMask; 75 } 76 77 /// Gets a mask with every candidate type except those in the given mask. 78 llvm::SmallBitVector getCandidateTypeDisabledExceptMask(ArrayRef<unsigned> exceptOrdinals)79 getCandidateTypeDisabledExceptMask(ArrayRef<unsigned> exceptOrdinals) const { 80 llvm::SmallBitVector disabled(allCandidateTypesMask); 81 for (unsigned ordinal : exceptOrdinals) { 82 disabled.reset(ordinal); 83 } 84 return disabled; 85 } 86 87 /// Adds an op handler. 88 template <typename OpTy> addOpHandler(OpHandlerFn fn)89 void addOpHandler(OpHandlerFn fn) { 90 addOpHandlerByName(OpTy::getOperationName(), fn); 91 } 92 93 /// Adds an operation which requires statistics at its result nodes for 94 /// best quantization performance. Note that the opName StringRef is 95 /// expected to come from getOperationName() and be static. 96 template <typename OpTy> addRequireStatsOp()97 void addRequireStatsOp() { 98 addRequireStatsOpByName(OpTy::getOperationName()); 99 } 100 101 /// Returns whether opName is a RequireStatsOp. 102 bool isRequireStatsOp(Operation *op) const; 103 104 /// Adds an op which does not mutate its values but may mutate its shape 105 /// or combine its operands in an arbitrary way. 106 /// Such ops are expected to have the same types for operands and results 107 /// and must be capable of operating on storage types. 108 template <typename OpTy> addValueIdentityOp()109 void addValueIdentityOp() { 110 addValueIdentityOpByName(OpTy::getOperationName()); 111 } 112 113 /// Handles the operation if a handler is defined for it. 114 void handleOp(Operation *op, CAGSlice &cag) const; 115 116 /// Finalizes the CAG after all anchors have been added. finalizeAnchors(CAGSlice & cag)117 virtual void finalizeAnchors(CAGSlice &cag) const {} 118 119 /// Whether an operand or result type is subject to analysis by this config. 120 virtual bool isHandledType(Type t) const = 0; 121 122 protected: 123 virtual void addValueIdentityOpByName(StringRef opName) = 0; 124 void addOpHandlerByName(StringRef name, OpHandlerFn fn); 125 126 private: 127 void addRequireStatsOpByName(StringRef opName); 128 129 /// Vector of all candidate type constraints, indexed by ordinal. 130 std::vector<CandidateQuantizedType> candidateTypes; 131 132 // A SmallBoolVector with bits set for all known candidate types. 133 llvm::SmallBitVector allCandidateTypesMask; 134 135 /// Map of all op handlers. 136 llvm::StringMap<OpHandlerFn> opHandlers; 137 138 /// Names of operations which should have their results annotated with 139 /// statistics. 140 llvm::StringSet<> requireStatsOpNames; 141 }; 142 143 } // namespace quantizer 144 } // namespace mlir 145 146 #endif // MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H 147