1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/relay/quantize.h
22  * \brief Header of definitions for quantization
23  */
24 #ifndef TVM_RELAY_QUANTIZE_QUANTIZE_H_
25 #define TVM_RELAY_QUANTIZE_QUANTIZE_H_
26 
27 #include <tvm/relay/expr.h>
28 #include <tvm/relay/op.h>
29 
30 #include <string>
31 
32 #include "../transforms/pattern_util.h"
33 
34 namespace tvm {
35 namespace relay {
36 namespace quantize {
37 
/*!
 * \brief Kind of annotate field.
 *
 * Unscoped enumeration with `int` as its underlying type; every enumerator
 * carries an explicit value.
 */
enum QAnnotateKind : int {
  kQIdentity = 0,
  kQInput = 1,
  kQWeight = 2,
  kQActivation = 3,
};
40 
41 /*! \brief Attribute for simulated quantize operator */
42 struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
43   int kind;
44   bool sign;
45   std::string rounding;
46 
47   TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
48     TVM_ATTR_FIELD(kind).describe("kind of field, hint for nbit/dtype configuration.");
49     TVM_ATTR_FIELD(sign).set_default(true).describe("whether to use signed data type.");
50     TVM_ATTR_FIELD(rounding).set_default("round").describe(
51         "rounding mode. Can be 'floor', 'ceil', 'round'");
52   }
53 };
54 
55 class QConfig;
56 /*!
57  * \brief Container for build configuration options
58  */
59 class QConfigNode : public Object {
60  public:
61   int nbit_input = 8;
62   int nbit_weight = 8;
63   int nbit_activation = 32;
64   DataType dtype_input = DataType::Int(8);
65   DataType dtype_weight = DataType::Int(8);
66   DataType dtype_activation = DataType::Int(32);
67   std::string calibrate_mode = "global_scale";
68   double global_scale = 8.0;
69   std::string weight_scale = "power2";
70   bool skip_dense_layer = true;
71   Array<Expr> skip_conv_layers = Array<Expr>(ObjectPtr<Object>(nullptr));
72   bool do_simulation = false;
73   bool round_for_shift = true;
74   Array<Expr> debug_enabled_ops = Array<Expr>(ObjectPtr<Object>(nullptr));
75   std::string rounding = "UPWARD";
76   int calibrate_chunk_by = -1;
77   std::string partition_conversions = "disabled";
78 
VisitAttrs(AttrVisitor * v)79   void VisitAttrs(AttrVisitor* v) {
80     v->Visit("nbit_input", &nbit_input);
81     v->Visit("nbit_weight", &nbit_weight);
82     v->Visit("nbit_activation", &nbit_activation);
83     v->Visit("dtype_input", &dtype_input);
84     v->Visit("dtype_weight", &dtype_weight);
85     v->Visit("dtype_activation", &dtype_activation);
86     v->Visit("calibrate_mode", &calibrate_mode);
87     v->Visit("global_scale", &global_scale);
88     v->Visit("weight_scale", &weight_scale);
89     v->Visit("skip_dense_layer", &skip_dense_layer);
90     v->Visit("skip_conv_layers", &skip_conv_layers);
91     v->Visit("do_simulation", &do_simulation);
92     v->Visit("round_for_shift", &round_for_shift);
93     v->Visit("debug_enabled_ops", &debug_enabled_ops);
94     v->Visit("rounding", &rounding);
95     v->Visit("calibrate_chunk_by", &calibrate_chunk_by);
96     v->Visit("partition_conversions", &partition_conversions);
97   }
98 
99   static constexpr const char* _type_key = "relay.quantize.QConfig";
100   TVM_DECLARE_FINAL_OBJECT_INFO(QConfigNode, Object);
101 };
102 
103 /*!
104  * \brief Container for build configuration options
105  */
106 class QConfig : public ObjectRef {
107  public:
QConfig()108   QConfig() {}
QConfig(ObjectPtr<Object> n)109   explicit QConfig(ObjectPtr<Object> n) : ObjectRef(n) {}
110 
111   const QConfigNode* operator->() const { return static_cast<const QConfigNode*>(get()); }
112 
113   QConfigNode* operator->() { return static_cast<QConfigNode*>(get_mutable()); }
114 
115   /*!
116    * \brief Push a new BuildConfig context onto the thread local stack.
117    * \param build_config The configuration to set as the current context.
118    */
119   static void EnterQConfigScope(const QConfig& qconfig);
120 
121   /*!
122    * \brief Pop a build config off the thread local context stack, restoring the previous
123    * configuration as the current context.
124    */
125   static void ExitQConfigScope();
126 
127   /*!
128    * \brief Get the current BuildConfig context from thread local storage, or a default
129    * configuration if a BuildConfig scope has not been entered.
130    * \return The configuration that is the current context.
131    */
132   static QConfig& Current();
133 
134   using ContainerType = QConfigNode;
135 };
136 
137 /*!
138  * \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
139  * context stack when constructed, and pops it when destructed.
140  */
141 struct QConfigContext {
142   /*!
143    * \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
144    * context. When the BuildConfigContext is destructed, the previous context is restored.
145    * \param build_config The BuildConfig to set as the new current context.
146    */
QConfigContextQConfigContext147   explicit QConfigContext(const QConfig& qconfig) { QConfig::EnterQConfigScope(qconfig); }
148 
149   /*! \brief Destructor. Pops the context off the thread local stack. */
~QConfigContextQConfigContext150   ~QConfigContext() { QConfig::ExitQConfigScope(); }
151 };
152 
153 }  // namespace quantize
154 }  // namespace relay
155 }  // namespace tvm
156 #endif  // TVM_RELAY_QUANTIZE_QUANTIZE_H_
157