1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  *
22  * \file partition.cc
23  *
24  * \brief Partition a graph into sections for quantization.
25  */
26 
27 #include <tvm/relay/transform.h>
28 #include "../pattern_util.h"
29 #include "./quantize.h"
30 
31 namespace tvm {
32 namespace relay {
33 namespace quantize {
34 
35 using namespace relay::transform;
36 
37 
38 class QPartitionExpr;
39 class QPartitionExprNode : public TempExprNode {
40  public:
41   /*! \brief The original expression */
42   Expr expr;
43 
VisitAttrs(tvm::AttrVisitor * v)44   void VisitAttrs(tvm::AttrVisitor* v) {
45     v->Visit("expr", &expr);
46   }
47 
48   TVM_DLL static QPartitionExpr make(Expr expr);
49 
50   Expr Realize() const final;
51 
52   static constexpr const char* _type_key = "relay.QPartitionExpr";
53   TVM_DECLARE_NODE_TYPE_INFO(QPartitionExprNode, TempExprNode);
54 };
55 
56 RELAY_DEFINE_NODE_REF(QPartitionExpr, QPartitionExprNode, TempExpr);
57 
58 
Realize() const59 Expr QPartitionExprNode::Realize() const {
60   // insert cast hint and stop fusion
61   const QConfig& cfg = QConfig::Current();
62   Expr ret = CastHint(this->expr, cfg->dtype_input);
63   return StopFusion(ret);
64 }
65 
make(Expr expr)66 QPartitionExpr QPartitionExprNode::make(Expr expr) {
67   auto rnode = make_node<QPartitionExprNode>();
68   rnode->expr = expr;
69   return QPartitionExpr(rnode);
70 }
71 
72 TVM_REGISTER_API("relay._quantize.make_partition_expr")
__anon8ca4605c0102(TVMArgs args, TVMRetValue *ret) 73 .set_body([](TVMArgs args,  TVMRetValue *ret) {
74     *ret = QPartitionExprNode::make(args[0]);
75   });
76 
QuantizePartition()77 Pass QuantizePartition() {
78   runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
79     [=](Function f, Module m, PassContext pc) {
80       auto ret = Downcast<Function>(
81           ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr));
82       return ret;
83   };
84   return CreateFunctionPass(pass_func, 1, "QuantizePartition", {});
85 }
86 
87 TVM_REGISTER_API("relay._quantize.QuantizePartition")
88 .set_body_typed(QuantizePartition);
89 
90 TVM_REGISTER_NODE_TYPE(QPartitionExprNode);
91 
92 }  // namespace quantize
93 }  // namespace relay
94 }  // namespace tvm
95