<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements.  See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership.  The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License.  You may obtain a copy of the License at -->

<!---   http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied.  See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

Custom Graph Pass Example and Tutorial
=======================================

## Introduction

Adding custom graph passes in MXNet used to require a deep understanding of the MXNet backend, including nnvm pass registration and other internal classes, followed by recompiling MXNet from source. This feature allows adding custom graph passes by dynamically loading external libraries at runtime.

This custom graph pass feature enables users to write custom model modification strategies without compiling against all of the MXNet header files and dependencies. When a library containing custom passes is loaded dynamically, the components found in the library are registered in MXNet so that users can use them natively, just like built-in components.

## Getting Started

### Have MXNet Ready

To run the following example, the build type of MXNet doesn't matter, since the custom pass doesn't interact with the execution of other native MXNet features. Note that if you want to use your custom pass with models running on a GPU, you still need a CUDA build of MXNet.

### Run An Example

You can start getting familiar with custom passes by running the example provided in the **example/extensions/lib_pass** directory. The `myPass` example simply prints out the graph. Go to the **lib_pass** directory and follow these steps:

1. Run `make`. The Makefile generates the dynamic library **libpass_lib.so**, compiled from the `pass_lib.cc` file. This is the library you are going to load, and it contains everything needed for the custom pass.
2. Run `python test_pass.py`. It first loads the above library, finds the components, registers them in the MXNet backend, then executes the pass on the model, runs the operators as in a regular MXNet forward pass, and outputs the result. Below is the output of the `python test_pass.py` command. Notice that it loads 1 pass: `myPass`.

```
[10:38:03] src/c_api/c_api.cc:286: Found 0 operators in library
[10:38:03] src/c_api/c_api.cc:785: Found 0 partitioners in library
[07:14:00] src/c_api/c_api.cc:887: Found 1 graph passes in library
[07:14:00] src/c_api/c_api.cc:902:       Graph Pass [0] myPass
```

### Basic Files For Custom Pass Library
* **lib_pass/pass_lib.cc**: This file contains the source code implementation of all components required to make a custom pass. It also shows how to register them so that they can be loaded by MXNet.
* **lib_pass/Makefile**: This file compiles the source code into a dynamic shared library, using the header file `include/mxnet/lib_api.h` from the MXNet source code. Currently the custom pass is compatible with C++11 and above.
* **lib_pass/test_pass.py**: This file calls `mx.library.load('libpass_lib.so')` to load the library containing the custom components, executes the pass on the model using the `optimize_for` API, and prints the outputs of the forward passes. The outputs should be the same as those of a regular MXNet forward pass without running the pass.
* **include/mxnet/lib_api.h**: This file from the MXNet source code is the single header file needed to include all necessary data types and function prototypes for writing a custom library. You can either specify the include path in the `Makefile`, or copy the header file over to the `example/extensions/lib_pass` folder. Note that apart from this header, the custom library is independent of the MXNet source.

## Writing Custom Pass Library

To build your own library containing a custom pass, compose a C++ source file like `mypass_lib.cc`, include the `lib_api.h` header file, and write your custom pass with these essential functions:

- `initialize` - Library Initialization Function
- `REGISTER_PASS` - Pass Registration Macro
- `graphPass` - Pass Implementation

Then compile it to the `libmypass_lib.so` dynamic library using the following command:

```bash
g++ -shared -fPIC -std=c++11 mypass_lib.cc -o libmypass_lib.so -I ../../../include/mxnet
```

Finally, you can write a Python script to load the library and execute your pass on a model:

```python
import mxnet as mx
from mxnet.gluon import nn

mx.library.load('libmypass_lib.so')
sym, _, _ = mx.model.load_checkpoint('mymodel', 0)

# Symbol/Module flow
sym2 = sym.optimize_for("myPass")

# Gluon flow 1
sym_block = nn.SymbolBlock(sym, inputs)
sym_block.hybridize(backend='myPass')

# Gluon flow 2
sym_block = nn.SymbolBlock(sym, inputs)
sym_block.optimize_for(x, backend='myPass')
```

### Using a Custom Pass Library

The custom pass APIs are available in both the Symbol and Gluon APIs. For the Symbol API, `optimize_for` can be called on Symbol objects to run the graph pass and return a new Symbol.

```python
sym.optimize_for(backend, args=None, aux=None, ctx=None, **kwargs)
```

The `optimize_for` API takes at least one argument, `backend`, which is a string that identifies which backend to use to optimize the model. The `args` and `aux` arguments are optional and take a list of NDArray or a dict of str to NDArray. They are used to infer shapes and types before executing the graph pass. The `ctx` argument is optional and takes a device context to infer storage types. The API also accepts any other user-specified options, which are passed on to the backend APIs.

For the Gluon API, `hybridize` can be called on HybridBlocks to execute a graph pass on the internal CachedOp Symbol.

```python
block.hybridize(backend=None, **kwargs)
```

The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which pass will be executed on the model. `**kwargs` may contain other user-specified options that will be passed to the backend APIs. The actual pass runs once, just before the first forward pass.

If you just want to run a graph pass on the HybridBlock without running a complete forward pass, you can use the `optimize_for` API, which combines the work done in the `hybridize` API with part of the work done in the forward pass.

```python
block.optimize_for(x, backend=None, **kwargs)
```

When the `optimize_for` API is called on a HybridBlock, it runs the graph pass immediately. This lets users export the modified model without running a complete forward pass.

```python
block.optimize_for(x, backend='myPass')
block.export('optimized')
```

You can also use `optimize_for` in place of `hybridize` and run inference immediately afterwards.

```python
block.optimize_for(x, backend='myPass')
block(x)
```

### Writing A Custom Graph Pass

There are several essential building blocks for making a custom pass:

* [initialize](./pass_lib.cc#L44):
    * This function is the library initialization function necessary for any dynamic library. It lets you check that the user is running a compatible version of MXNet. Note that the `version` parameter is passed from MXNet when the library is loaded.
```c++
MXReturnValue initialize(int version)
```
* [graphPass](./pass_lib.cc#L31):
    * This function provides a copy of the model graph, along with any pass-specific options from the user.
```c++
MXReturnValue graphPass(
    mxnet::ext::Graph *g,
    const std::unordered_map<std::string, std::string>& options)
```
* [REGISTER_PASS(my_pass_name)](./pass_lib.cc#L41):
    * This macro registers the custom pass and its properties with MXNet by name. The argument to `setBody` is the `graphPass` function.
```c++
REGISTER_PASS(my_pass_name)
.setBody(graphPass);
```

Let's take a closer look at those registry functions:

* **graphPass**: This function takes two arguments. The first argument is the `Graph` of the model architecture, where nodes are operators or weights and edges are data dependencies. The second argument is the map of options specified by the user; any custom options passed to the pass arrive in this `options` map.
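
Since the options arrive as a plain `std::unordered_map<std::string, std::string>`, a pass body typically looks them up with a fallback default. The following standalone sketch illustrates this; the `getOption` helper and the `verbose` option name are hypothetical, not part of `lib_api.h`:

```c++
#include <string>
#include <unordered_map>

// Hypothetical helper (not part of lib_api.h): look up a string option,
// falling back to a default when the user did not supply it.
std::string getOption(const std::unordered_map<std::string, std::string>& options,
                      const std::string& key,
                      const std::string& fallback) {
  auto it = options.find(key);
  return it == options.end() ? fallback : it->second;
}
```

From Python, a keyword argument such as `sym.optimize_for("myPass", verbose="1")` would arrive in this map as the string pair `{"verbose", "1"}`.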

### Graph representation

The `Graph` class represents the model's architecture. Each `Node` in the graph represents an operator or weight (i.e. an arg/aux param). Since an operator in MXNet can take multiple inputs and produce multiple outputs, each input/output is represented by a `NodeEntry`. A `Node` contains the following:
- `op` - [string] operator name
- `name` - [string] unique node name
- `inputs` - [vector of NodeEntry] set of inputs to the node
- `outputs` - [vector of NodeEntry] set of outputs from the node
- `subgraph` - [vector of Graph] set of subgraphs in the node
- `attrs` - [map of string to string] set of attributes for the node

The `inputs` are a set of `NodeEntry` where each contains a pointer to the `Node` that produces the data and an `entry` that is the index of the output on that `Node`. Conversely, the `outputs` are a set of `NodeEntry` where each contains a pointer to the `Node` that consumes the data and an `entry` that is the index of the input on that `Node`. This bidirectional dependency makes it easy to traverse the graph in either direction.

A `Graph` contains the following:
- `nodes` - [vector of Node] set of nodes in the graph
- `inputs` - [vector of Node] set of inputs to the graph
- `outputs` - [vector of NodeEntry] set of outputs from the graph
- `attrs` - [map of string to JSON object] set of attributes for the graph

The `nodes` are all the nodes in the graph (a superset). The `inputs` are only those nodes that are model inputs (i.e. the input image) or weights (i.e. arg/aux params). The `outputs` are the outputs from the operators in the model that are true outputs of the model (i.e. prediction results).

Here's an example of creating a new node and adding it to the graph:
```c++
g->addNode("myConv","Convolution");
```
Here's an example of creating an edge between two nodes:
```c++
n1->outputs.push_back({n2,1});
n2->inputs.push_back({n1,0});
```
Here node `n1` produces an output at index 0 that is consumed by node `n2` on the input at index 1.

![example connection](example_connection.png)
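
The bidirectional bookkeeping above can be modeled with a standalone sketch. The `Node`, `NodeEntry`, and `connect` types below are simplified stand-ins for illustration only, not the real `mxnet::ext` classes:

```c++
#include <string>
#include <vector>

struct Node;

// One endpoint of an edge: the node on the other side, plus the index of
// the output (for inputs) or input (for outputs) on that node.
struct NodeEntry {
  Node* node;
  int entry;
};

struct Node {
  std::string op;                  // operator name
  std::string name;                // unique node name
  std::vector<NodeEntry> inputs;   // edges from producers
  std::vector<NodeEntry> outputs;  // edges to consumers
};

// Connect producer n1 (output index out_idx) to consumer n2 (input index in_idx),
// recording the edge on both sides as the tutorial's push_back example does.
void connect(Node* n1, Node* n2, int out_idx, int in_idx) {
  n1->outputs.push_back({n2, in_idx});
  n2->inputs.push_back({n1, out_idx});
}
```

With this, `connect(&n1, &n2, 0, 1)` reproduces the two `push_back` calls from the example above in one step.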

Some graph passes require allocating new NDArrays to add or replace model params. The `alloc_arg` and `alloc_aux` APIs allocate new NDArrays and integrate them with the model's args and aux params. Both APIs have the following signature:

```c++
MXTensor* alloc_xxx(const std::vector<int64_t>& shapes,
                    const MXContext &ctx,
                    MXDType dtype)
```

This function can be called on a node in the graph to allocate a tensor for that node like:

```c++
node->alloc_arg({1}, MXContext::CPU(0), kFloat32);
```

It adds a new param to the appropriate arg/aux set when the graph pass returns. If you wish to remove an existing param, just remove the corresponding node from the graph. It will be deleted after the pass completes and removed from the dictionary of args or aux (whichever it is a member of).
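
The removal direction can be pictured with an equally simplified sketch. A real pass would delete the node through the `Graph` object itself; the `dropParam` helper and node names below are hypothetical, illustrating only that a param removed from the node list disappears from args/aux after the pass:

```c++
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical sketch: if a graph's params were a plain list of names,
// removing one would be the classic erase-remove idiom.
std::vector<std::string> dropParam(std::vector<std::string> nodes,
                                   const std::string& param_name) {
  nodes.erase(std::remove(nodes.begin(), nodes.end(), param_name), nodes.end());
  return nodes;
}
```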

### Parsing a JSON string

To simplify custom libraries, basic JSON parsing utilities are implemented in the `lib_api.h` header file. You can parse a JSON string by calling the static `JsonVal::parse` function like:

```c++
JsonVal json_val = JsonVal::parse(json);
```

A `JsonVal` is a class that represents a node in a JSON structure. You can check the type of a node (number, string, list, or map) by comparing `JsonVal.type` to `STR`, `NUM`, `LIST`, or `MAP`. Then you can get the value out of the node like:

```c++
switch(json_val.type) {
  case STR: {
    std::string str = json_val.str;
    break;
  }
  case NUM: {
    int num = json_val.num;
    break;
  }
  case LIST: {
    std::vector<JsonVal> list = json_val.list;
    break;
  }
  case MAP: {
    std::map<JsonVal, JsonVal> map = json_val.map;
    break;
  }
  default:
    // error
    break;
}
```

You can call the `dump` function on a `JsonVal` object like `json_val.dump()` to get a JSON-compatible string. There are also convenience constructors for creating `JsonVal` objects for strings and numbers, like `JsonVal("myKey")` or `JsonVal(42)`. This makes it easy to look up specific keys in a map, like `json_val.map[JsonVal("nodes")]`.