1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file graph_fuse.h
22 * \brief Definition of structs used by graph fusion
23 */
24 #ifndef NNVM_COMPILER_GRAPH_FUSE_H_
25 #define NNVM_COMPILER_GRAPH_FUSE_H_
26
27 #include <nnvm/graph.h>
28 #include <vector>
29 #include <unordered_map>
30
31 #include "compile_engine.h"
32
33 namespace nnvm {
34 namespace compiler {
35
// Fuse rule tags.
// The numeric values are written out explicitly; they are the same values the
// compiler assigns implicitly (0, 1, 2). The spelling of kUknown is kept as-is
// because it is a public enumerator name.
enum class FuseRule {
  kUknown = 0,
  kFuseToMaster = 1,
  kRealize = 2
};
42
43 /*!
44 * \brief Get DLDataType from dtype flag.
45 *
46 * \param type_flag The data type flag
47 * \return corresponding DLDataType
48 */
GetDLType(int type_flag)49 inline DLDataType GetDLType(int type_flag) {
50 return tvm::Type2TVMType(GetTVMType(type_flag));
51 }
52
53 struct INodeEntryHash {
operatorINodeEntryHash54 size_t operator()(const IndexedGraph::NodeEntry& e) const {
55 return e.node_id;
56 }
57 };
58
59 struct INodeEntryEqual {
operatorINodeEntryEqual60 size_t operator()(const IndexedGraph::NodeEntry &a,
61 const IndexedGraph::NodeEntry &b) const {
62 return a.node_id == b.node_id && a.index == b.index;
63 }
64 };
65
66 // Auxiliary data structure for representing fused op.
67 struct FuseEntry {
68 // Subgraph of the fragment
69 Graph subgraph;
70 // The input map
71 std::unordered_map<IndexedGraph::NodeEntry, nnvm::NodeEntry, INodeEntryHash,
72 INodeEntryEqual>
73 imap;
74 // Reverse map to the old input entry
75 std::unordered_map<const Node *, IndexedGraph::NodeEntry> reverse_imap;
76 // TVM Placeholder for inputs
77 std::unordered_map<const Node *, Tensor> input_info;
78 // Whether we can flatten data
79 bool flatten_data;
80 // The corresponding function.
81 GraphFunc compiled_func;
82 };
83
// GroupVec stores the root node ids of the fused nodes.
// NOTE(review): presumably indexed by node id -- confirm against the fusion pass.
using GroupVec = std::vector<int>;

// MasterVec stores master node ids of fused groups.
using MasterVec = std::vector<int>;

// FuseEntryVec stores fused entries (one FuseEntry per element).
using FuseEntryVec = std::vector<FuseEntry>;

// PatternVec stores operator patterns (TOpPattern values).
using PatternVec = std::vector<TOpPattern>;
95
96 } // namespace compiler
97 } // namespace nnvm
98
99 #endif // NNVM_COMPILER_GRAPH_FUSE_H_
100