1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 *
22 * \file calibrate.cc
23 *
24 * \brief Create profile graph and calibrate on dataset
25 */
26 #include <tvm/relay/analysis.h>
27 #include <tvm/relay/expr_functor.h>
28 #include "./quantize.h"
29
30
31 namespace tvm {
32 namespace relay {
33 namespace quantize {
34
35 class StatsCollector : private ExprMutator {
36 public:
Collect(const Expr & expr)37 Expr Collect(const Expr& expr) {
38 auto new_e = this->Mutate(expr);
39 const FunctionNode* func = new_e.as<FunctionNode>();
40 CHECK(func) << "Input shoule be Function";
41 Expr new_body = TupleNode::make(std::move(profile_data_));
42 return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
43 func->attrs);
44 }
45
46 private:
47 Array<Expr> profile_data_;
48
VisitExpr_(const CallNode * call)49 Expr VisitExpr_(const CallNode* call) {
50 static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
51 Expr new_e = ExprMutator::VisitExpr_(call);
52 const CallNode* new_call = new_e.as<CallNode>();
53 CHECK(new_call);
54 if (new_call->op.same_as(simulated_quantize)) {
55 auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
56 // rewrite the annotation
57 auto new_attrs = make_node<SimulatedQuantizeAttrs>();
58 const Expr& quantize_input = new_call->args[0]; // expression being quantized
59 auto placeholder = MakeConstantScalar(Float(32), 0.); // unused argument
60 Array<Expr> new_args{quantize_input, placeholder, placeholder, placeholder};
61 new_attrs->kind = QAnnotateKind::kQIdentity;
62 new_attrs->sign = attrs->sign;
63 new_attrs->rounding = attrs->rounding;
64 Expr identity_quantize = CallNode::make(new_call->op, new_args, Attrs{new_attrs}, {});
65
66 // add non-const expressions to profile data
67 if (attrs->kind != QAnnotateKind::kQWeight) {
68 CHECK(!quantize_input.as<ConstantNode>());
69 profile_data_.push_back(identity_quantize);
70 }
71 return identity_quantize;
72 } else {
73 return new_e;
74 }
75 }
76 };
77
/*!
 * \brief Given an annotated graph, create a profile graph to collect profile data from the
 * calibration dataset.
 *
 * This pass rewrites every simulated_quantize op to identity mode. It collects the outputs
 * of those ops that do not quantize weights into a tuple, and the tuple is the output of the
 * profile graph. Both input and output of this pass are relay::Function.
 *
 * \param expr The simulation graph after annotation.
 * \return The profile graph.
 */
CreateStatsCollector(const Expr & expr)89 Expr CreateStatsCollector(const Expr& expr) {
90 return StatsCollector().Collect(expr);
91 }
92
// Register CreateStatsCollector in the global function registry under the name
// "relay._quantize.CreateStatsCollector" (presumably looked up from the Python
// relay.quantize module — confirm against the frontend).
TVM_REGISTER_API("relay._quantize.CreateStatsCollector")
.set_body_typed(CreateStatsCollector);
95
96 } // namespace quantize
97 } // namespace relay
98 } // namespace tvm
99