1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file feature.cc
22  * \brief Detect features used in Expr/Module
23  */
24 #include <tvm/relay/feature.h>
25 #include <tvm/relay/analysis.h>
26 #include <tvm/relay/expr.h>
27 #include <tvm/relay/expr_functor.h>
28 #include <tvm/relay/module.h>
29 #include "pass_util.h"
30 
31 namespace tvm {
32 namespace relay {
33 
DetectFeature(const Expr & expr)34 FeatureSet DetectFeature(const Expr& expr) {
35   if (!expr.defined()) {
36     return FeatureSet::No();
37   }
38   struct FeatureDetector : ExprVisitor {
39     std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
40     FeatureSet fs = FeatureSet::No();
41 
42     void VisitExpr(const Expr& expr) final {
43       if (visited_.count(expr) == 0) {
44         visited_.insert(expr);
45         ExprVisitor::VisitExpr(expr);
46       } else {
47         if (!IsAtomic(expr)) {
48           fs += fGraph;
49         }
50       }
51     }
52 #define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT)            \
53   void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \
54     STMT                                                  \
55     fs += f##CONSTRUCT_NAME;                              \
56   }
57 #define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, { \
58     ExprVisitor::VisitExpr_(op);                                                    \
59   })
60     DETECT_DEFAULT_CONSTRUCT(Var)
61     DETECT_DEFAULT_CONSTRUCT(GlobalVar)
62     DETECT_DEFAULT_CONSTRUCT(Constant)
63     DETECT_DEFAULT_CONSTRUCT(Tuple)
64     DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
65     DETECT_CONSTRUCT(Function, {
66         if (!op->IsPrimitive()) {
67           ExprVisitor::VisitExpr_(op);
68         }
69       })
70     DETECT_DEFAULT_CONSTRUCT(Op)
71     DETECT_DEFAULT_CONSTRUCT(Call)
72     DETECT_CONSTRUCT(Let, {
73         for (const Var& v : FreeVars(op->value)) {
74           if (op->var == v) {
75             fs += fLetRec;
76           }
77         }
78         ExprVisitor::VisitExpr_(op);
79       })
80     DETECT_DEFAULT_CONSTRUCT(If)
81     DETECT_DEFAULT_CONSTRUCT(RefCreate)
82     DETECT_DEFAULT_CONSTRUCT(RefRead)
83     DETECT_DEFAULT_CONSTRUCT(RefWrite)
84     DETECT_DEFAULT_CONSTRUCT(Constructor)
85     DETECT_DEFAULT_CONSTRUCT(Match)
86 #undef DETECT_DEFAULT_CONSTRUCT
87   } fd;
88   fd(expr);
89   return fd.fs;
90 }
91 
DetectFeature(const Module & mod)92 FeatureSet DetectFeature(const Module& mod) {
93   FeatureSet fs = FeatureSet::No();
94   if (mod.defined()) {
95     for (const auto& f : mod->functions) {
96       fs += DetectFeature(f.second);
97     }
98   }
99   return fs;
100 }
101 
PyDetectFeature(const Expr & expr,const Module & mod)102 Array<Integer> PyDetectFeature(const Expr& expr, const Module& mod) {
103   FeatureSet fs = DetectFeature(expr) + DetectFeature(mod);
104   return static_cast<Array<Integer>>(fs);
105 }
106 
// Register PyDetectFeature in the global function registry under
// "relay._analysis.detect_feature" so it can be looked up by that name
// (e.g. from the Python frontend).
TVM_REGISTER_API("relay._analysis.detect_feature")
.set_body_typed(PyDetectFeature);
109 
110 }  // namespace relay
111 }  // namespace tvm
112