1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file feature.cc
22 * \brief Detect features used in Expr/Module
23 */
24 #include <tvm/relay/feature.h>
25 #include <tvm/relay/analysis.h>
26 #include <tvm/relay/expr.h>
27 #include <tvm/relay/expr_functor.h>
28 #include <tvm/relay/module.h>
29 #include "pass_util.h"
30
31 namespace tvm {
32 namespace relay {
33
DetectFeature(const Expr & expr)34 FeatureSet DetectFeature(const Expr& expr) {
35 if (!expr.defined()) {
36 return FeatureSet::No();
37 }
38 struct FeatureDetector : ExprVisitor {
39 std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
40 FeatureSet fs = FeatureSet::No();
41
42 void VisitExpr(const Expr& expr) final {
43 if (visited_.count(expr) == 0) {
44 visited_.insert(expr);
45 ExprVisitor::VisitExpr(expr);
46 } else {
47 if (!IsAtomic(expr)) {
48 fs += fGraph;
49 }
50 }
51 }
52 #define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \
53 void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \
54 STMT \
55 fs += f##CONSTRUCT_NAME; \
56 }
57 #define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, { \
58 ExprVisitor::VisitExpr_(op); \
59 })
60 DETECT_DEFAULT_CONSTRUCT(Var)
61 DETECT_DEFAULT_CONSTRUCT(GlobalVar)
62 DETECT_DEFAULT_CONSTRUCT(Constant)
63 DETECT_DEFAULT_CONSTRUCT(Tuple)
64 DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
65 DETECT_CONSTRUCT(Function, {
66 if (!op->IsPrimitive()) {
67 ExprVisitor::VisitExpr_(op);
68 }
69 })
70 DETECT_DEFAULT_CONSTRUCT(Op)
71 DETECT_DEFAULT_CONSTRUCT(Call)
72 DETECT_CONSTRUCT(Let, {
73 for (const Var& v : FreeVars(op->value)) {
74 if (op->var == v) {
75 fs += fLetRec;
76 }
77 }
78 ExprVisitor::VisitExpr_(op);
79 })
80 DETECT_DEFAULT_CONSTRUCT(If)
81 DETECT_DEFAULT_CONSTRUCT(RefCreate)
82 DETECT_DEFAULT_CONSTRUCT(RefRead)
83 DETECT_DEFAULT_CONSTRUCT(RefWrite)
84 DETECT_DEFAULT_CONSTRUCT(Constructor)
85 DETECT_DEFAULT_CONSTRUCT(Match)
86 #undef DETECT_DEFAULT_CONSTRUCT
87 } fd;
88 fd(expr);
89 return fd.fs;
90 }
91
DetectFeature(const Module & mod)92 FeatureSet DetectFeature(const Module& mod) {
93 FeatureSet fs = FeatureSet::No();
94 if (mod.defined()) {
95 for (const auto& f : mod->functions) {
96 fs += DetectFeature(f.second);
97 }
98 }
99 return fs;
100 }
101
PyDetectFeature(const Expr & expr,const Module & mod)102 Array<Integer> PyDetectFeature(const Expr& expr, const Module& mod) {
103 FeatureSet fs = DetectFeature(expr) + DetectFeature(mod);
104 return static_cast<Array<Integer>>(fs);
105 }
106
// Registers PyDetectFeature as the typed body of the API entry named
// "relay._analysis.detect_feature".
107 TVM_REGISTER_API("relay._analysis.detect_feature")
108 .set_body_typed(PyDetectFeature);
109
110 } // namespace relay
111 } // namespace tvm
112