1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file graph_fuse.h
22  * \brief Definition of structs used by graph fusion
23 */
24 #ifndef NNVM_COMPILER_GRAPH_FUSE_H_
25 #define NNVM_COMPILER_GRAPH_FUSE_H_
26 
#include <nnvm/graph.h>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

#include "compile_engine.h"
32 
33 namespace nnvm {
34 namespace compiler {
35 
/*!
 * \brief The single fuse rule.
 *
 * Enumerator values are spelled out explicitly. They are identical to the
 * implicit 0, 1, 2 numbering of the original declaration order.
 * NOTE(review): `kUknown` is a misspelling of "unknown". It is kept as-is
 * for source compatibility with existing users of this enum.
 */
enum class FuseRule {
  kUknown = 0,
  kFuseToMaster = 1,
  kRealize = 2
};
42 
43 /*!
44  * \brief Get DLDataType from dtype flag.
45  *
46  * \param type_flag The data type flag
47  * \return corresponding DLDataType
48  */
GetDLType(int type_flag)49 inline DLDataType GetDLType(int type_flag) {
50   return tvm::Type2TVMType(GetTVMType(type_flag));
51 }
52 
53 struct INodeEntryHash {
operatorINodeEntryHash54   size_t operator()(const IndexedGraph::NodeEntry& e) const {
55     return e.node_id;
56   }
57 };
58 
59 struct INodeEntryEqual {
operatorINodeEntryEqual60   size_t operator()(const IndexedGraph::NodeEntry &a,
61                     const IndexedGraph::NodeEntry &b) const {
62     return a.node_id == b.node_id && a.index == b.index;
63   }
64 };
65 
66 // Auxiliary data structure for representing fused op.
67 struct FuseEntry {
68   // Subgraph of the fragment
69   Graph subgraph;
70   // The input map
71   std::unordered_map<IndexedGraph::NodeEntry, nnvm::NodeEntry, INodeEntryHash,
72                      INodeEntryEqual>
73       imap;
74   // Reverse map to the old input entry
75   std::unordered_map<const Node *, IndexedGraph::NodeEntry> reverse_imap;
76   // TVM Placeholder for inputs
77   std::unordered_map<const Node *, Tensor> input_info;
78   // Whether we can flatten data
79   bool flatten_data;
80   // The corresponding function.
81   GraphFunc compiled_func;
82 };
83 
84 // GroupVec stores the root node ids of the fused nodes.
85 using GroupVec = std::vector<int>;
86 
87 // MasterVec stores master node ids of fused groups.
88 using MasterVec = std::vector<int>;
89 
90 // FuseVec stores fused entries.
91 using FuseEntryVec = std::vector<FuseEntry>;
92 
93 // PatternVec stores operator patterns.
94 using PatternVec = std::vector<TOpPattern>;
95 
96 }  // namespace compiler
97 }  // namespace nnvm
98 
99 #endif  // NNVM_COMPILER_GRAPH_FUSE_H_
100