1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  *
22  * \file calibrate.cc
23  *
24  * \brief Create profile graph and calibrate on dataset
25  */
26 #include <tvm/relay/analysis.h>
27 #include <tvm/relay/expr_functor.h>
28 #include "./quantize.h"
29 
30 
31 namespace tvm {
32 namespace relay {
33 namespace quantize {
34 
35 class StatsCollector : private ExprMutator {
36  public:
Collect(const Expr & expr)37   Expr Collect(const Expr& expr) {
38     auto new_e = this->Mutate(expr);
39     const FunctionNode* func = new_e.as<FunctionNode>();
40     CHECK(func) << "Input shoule be Function";
41     Expr new_body = TupleNode::make(std::move(profile_data_));
42     return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
43             func->attrs);
44   }
45 
46  private:
47   Array<Expr> profile_data_;
48 
VisitExpr_(const CallNode * call)49   Expr VisitExpr_(const CallNode* call) {
50     static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
51     Expr new_e = ExprMutator::VisitExpr_(call);
52     const CallNode* new_call = new_e.as<CallNode>();
53     CHECK(new_call);
54     if (new_call->op.same_as(simulated_quantize)) {
55       auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
56       // rewrite the annotation
57       auto new_attrs = make_node<SimulatedQuantizeAttrs>();
58       const Expr& quantize_input = new_call->args[0];  // expression being quantized
59       auto placeholder = MakeConstantScalar(Float(32), 0.);  // unused argument
60       Array<Expr> new_args{quantize_input, placeholder, placeholder, placeholder};
61       new_attrs->kind = QAnnotateKind::kQIdentity;
62       new_attrs->sign = attrs->sign;
63       new_attrs->rounding = attrs->rounding;
64       Expr identity_quantize = CallNode::make(new_call->op, new_args, Attrs{new_attrs}, {});
65 
66       // add non-const expressions to profile data
67       if (attrs->kind != QAnnotateKind::kQWeight) {
68         CHECK(!quantize_input.as<ConstantNode>());
69         profile_data_.push_back(identity_quantize);
70       }
71       return identity_quantize;
72     } else {
73       return new_e;
74     }
75   }
76 };
77 
78 /*
79  * \brief Given an annotated graph, create a profile graph to collect profile data from the
80  * calibration dataset.
81  *
82  * This pass collects simulated_quantize op into a tuple. Simulated_quantize ops are rewritten to
83  * identity mode. The tuple is the output of the profile graph. Both input and output of this pass
84  * are relay::Function.
85  *
86  * \param expr The simulation graph after annotation.
87  * \return The profile graph.
88  */
CreateStatsCollector(const Expr & expr)89 Expr CreateStatsCollector(const Expr& expr) {
90   return StatsCollector().Collect(expr);
91 }
92 
93 TVM_REGISTER_API("relay._quantize.CreateStatsCollector")
94 .set_body_typed(CreateStatsCollector);
95 
96 }  // namespace quantize
97 }  // namespace relay
98 }  // namespace tvm
99