1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file tvm/relay/quantize.h 22 * \brief Header of definitions for quantization 23 */ 24 #ifndef TVM_RELAY_QUANTIZE_QUANTIZE_H_ 25 #define TVM_RELAY_QUANTIZE_QUANTIZE_H_ 26 27 #include <tvm/relay/expr.h> 28 #include <tvm/relay/op.h> 29 30 #include <string> 31 32 #include "../transforms/pattern_util.h" 33 34 namespace tvm { 35 namespace relay { 36 namespace quantize { 37 38 /*! \brief Kind of annotate field */ 39 enum QAnnotateKind : int { kQIdentity = 0, kQInput = 1, kQWeight = 2, kQActivation = 3 }; 40 41 /*! \brief Attribute for simulated quantize operator */ 42 struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> { 43 int kind; 44 bool sign; 45 std::string rounding; 46 47 TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { 48 TVM_ATTR_FIELD(kind).describe("kind of field, hint for nbit/dtype configuration."); 49 TVM_ATTR_FIELD(sign).set_default(true).describe("whether to use signed data type."); 50 TVM_ATTR_FIELD(rounding).set_default("round").describe( 51 "rounding mode. Can be 'floor', 'ceil', 'round'"); 52 } 53 }; 54 55 class QConfig; 56 /*! 57 * \brief Container for build configuration options 58 */ 59 class QConfigNode : public Object { 60 public: 61 int nbit_input = 8; 62 int nbit_weight = 8; 63 int nbit_activation = 32; 64 DataType dtype_input = DataType::Int(8); 65 DataType dtype_weight = DataType::Int(8); 66 DataType dtype_activation = DataType::Int(32); 67 std::string calibrate_mode = "global_scale"; 68 double global_scale = 8.0; 69 std::string weight_scale = "power2"; 70 bool skip_dense_layer = true; 71 Array<Expr> skip_conv_layers = Array<Expr>(ObjectPtr<Object>(nullptr)); 72 bool do_simulation = false; 73 bool round_for_shift = true; 74 Array<Expr> debug_enabled_ops = Array<Expr>(ObjectPtr<Object>(nullptr)); 75 std::string rounding = "UPWARD"; 76 int calibrate_chunk_by = -1; 77 std::string partition_conversions = "disabled"; 78 VisitAttrs(AttrVisitor * v)79 void VisitAttrs(AttrVisitor* v) { 80 v->Visit("nbit_input", &nbit_input); 81 v->Visit("nbit_weight", &nbit_weight); 82 v->Visit("nbit_activation", &nbit_activation); 83 v->Visit("dtype_input", &dtype_input); 84 v->Visit("dtype_weight", &dtype_weight); 85 v->Visit("dtype_activation", &dtype_activation); 86 v->Visit("calibrate_mode", &calibrate_mode); 87 v->Visit("global_scale", &global_scale); 88 v->Visit("weight_scale", &weight_scale); 89 v->Visit("skip_dense_layer", &skip_dense_layer); 90 v->Visit("skip_conv_layers", &skip_conv_layers); 91 v->Visit("do_simulation", &do_simulation); 92 v->Visit("round_for_shift", &round_for_shift); 93 v->Visit("debug_enabled_ops", &debug_enabled_ops); 94 v->Visit("rounding", &rounding); 95 v->Visit("calibrate_chunk_by", &calibrate_chunk_by); 96 v->Visit("partition_conversions", &partition_conversions); 97 } 98 99 static constexpr const char* _type_key = "relay.quantize.QConfig"; 100 TVM_DECLARE_FINAL_OBJECT_INFO(QConfigNode, Object); 101 }; 102 103 /*! 104 * \brief Container for build configuration options 105 */ 106 class QConfig : public ObjectRef { 107 public: QConfig()108 QConfig() {} QConfig(ObjectPtr<Object> n)109 explicit QConfig(ObjectPtr<Object> n) : ObjectRef(n) {} 110 111 const QConfigNode* operator->() const { return static_cast<const QConfigNode*>(get()); } 112 113 QConfigNode* operator->() { return static_cast<QConfigNode*>(get_mutable()); } 114 115 /*! 116 * \brief Push a new BuildConfig context onto the thread local stack. 117 * \param build_config The configuration to set as the current context. 118 */ 119 static void EnterQConfigScope(const QConfig& qconfig); 120 121 /*! 122 * \brief Pop a build config off the thread local context stack, restoring the previous 123 * configuration as the current context. 124 */ 125 static void ExitQConfigScope(); 126 127 /*! 128 * \brief Get the current BuildConfig context from thread local storage, or a default 129 * configuration if a BuildConfig scope has not been entered. 130 * \return The configuration that is the current context. 131 */ 132 static QConfig& Current(); 133 134 using ContainerType = QConfigNode; 135 }; 136 137 /*! 138 * \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the 139 * context stack when constructed, and pops it when destructed. 140 */ 141 struct QConfigContext { 142 /*! 143 * \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current 144 * context. When the BuildConfigContext is destructed, the previous context is restored. 145 * \param build_config The BuildConfig to set as the new current context. 146 */ QConfigContextQConfigContext147 explicit QConfigContext(const QConfig& qconfig) { QConfig::EnterQConfigScope(qconfig); } 148 149 /*! \brief Destructor. Pops the context off the thread local stack. */ ~QConfigContextQConfigContext150 ~QConfigContext() { QConfig::ExitQConfigScope(); } 151 }; 152 153 } // namespace quantize 154 } // namespace relay 155 } // namespace tvm 156 #endif // TVM_RELAY_QUANTIZE_QUANTIZE_H_ 157