<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

Custom Graph Pass Example and Tutorial
=======================================

## Introduction

Adding custom graph passes in MXNet used to require deep understanding of the MXNet backend, including nnvm pass registration and other internal classes, followed by recompiling MXNet from source. This feature allows adding custom graph passes by dynamically loading external libraries at runtime.

This custom graph pass feature enables users to write custom model modification strategies without compiling against all of MXNet header files and dependencies. When a library containing custom passes is loaded dynamically, the components found in the library will be registered in MXNet so that users can use those natively just like other built-in components.

## Getting Started

### Have MXNet Ready

To run the following example, the build type of MXNet doesn’t matter since the custom pass doesn’t interact with the execution of other native MXNet features.
Note that if you want to use your custom pass with models running on GPU, you still need an MXNet CUDA build.

### Run An Example

You can start getting familiar with custom passes by running an example provided in the **example/extensions/lib_pass** directory. The `myPass` example just prints out the graph. Go to the **lib_pass** directory and follow these steps:

1. Run `make`. The Makefile will generate the dynamic library **libpass_lib.so**, which is compiled from the `pass_lib.cc` file. This is the library you are going to load that contains everything for the custom pass.
2. Run `python test_pass.py`. The script first loads the above library, finds the components, and registers them in the MXNet backend. It then executes the pass on the model, runs the operators like regular MXNet operators, and outputs the result. Below is the output when running the `python test_pass.py` command. Notice that it loads 1 pass: `myPass`.

```
[10:38:03] src/c_api/c_api.cc:286: Found 0 operators in library
[10:38:03] src/c_api/c_api.cc:785: Found 0 partitioners in library
[07:14:00] src/c_api/c_api.cc:887: Found 1 graph passes in library
[07:14:00] src/c_api/c_api.cc:902: Graph Pass [0] myPass
```

### Basic Files For Custom Pass Library

* **lib_pass/pass_lib.cc**: This file contains a source code implementation of all the components required to make a custom pass. It also shows how to register them so that they can be loaded by MXNet.
* **lib_pass/Makefile**: This file compiles the source code to a dynamic shared library, with a header file `include/mxnet/lib_api.h` from MXNet source code. Currently the custom pass is compatible with C++11 and above.
* **lib_pass/test_pass.py**: This file calls `mx.library.load('libpass_lib.so')` to load the library containing the custom components, executes the pass on the model using the `optimize_for` API, and prints outputs of the forward passes.
The outputs should match those of a regular MXNet forward pass run without the custom pass.
* **include/mxnet/lib_api.h**: This file from MXNet source code is the single header file needed to include all necessary data types and function prototypes for writing a custom library. You can either specify the include path in the `Makefile`, or copy the header file over to the `example/extensions/lib_pass` folder. Note that apart from this header, the custom library is independent of MXNet source.

## Writing Custom Pass Library

To build your own library containing a custom pass, compose a C++ source file like `mypass_lib.cc`, include the `lib_api.h` header file, and write your custom pass with these essential functions:

- `initialize` - Library Initialization Function
- `REGISTER_PASS` - Pass Registration Macro
- `graphPass` - Pass Implementation

Then compile it to the `libmypass_lib.so` dynamic library using the following command:

```bash
g++ -shared -fPIC -std=c++11 mypass_lib.cc -o libmypass_lib.so -I ../../../include/mxnet
```

Finally, you can write a Python script to load the library and execute your pass on a model:

```python
import mxnet as mx
from mxnet.gluon import nn

mx.library.load('libmypass_lib.so')
sym, _, _ = mx.model.load_checkpoint('mymodel', 0)
# Symbol/Module flow
sym2 = sym.optimize_for("myPass")
# Gluon flow 1 (`inputs` stands for the model's input symbols, for example mx.sym.var('data'))
sym_block = nn.SymbolBlock(sym, inputs)
sym_block.hybridize(backend='myPass')
# Gluon flow 2 (`x` stands for a sample input NDArray)
sym_block = nn.SymbolBlock(sym, inputs)
sym_block.optimize_for(x, backend='myPass')
```

### Using a Custom Pass Library

Custom passes can be run through both the Symbol and Gluon APIs. For the Symbol API, `optimize_for` can be called on Symbol objects to run the graph pass and return a new Symbol.

```python
sym.optimize_for(backend, args=None, aux=None, ctx=None, **kwargs)
```

The `optimize_for` API takes at least 1 argument, `backend`, which is a string that identifies which backend to use to optimize the model.
The `args` and `aux` arguments are optional and take a list of NDArray or dict of str to NDArray. They are used to infer shapes and types before executing the graph pass. The `ctx` argument is optional and takes a device context to infer storage types. `optimize_for` also accepts any other user-specified options as keyword arguments, which are passed to the backend APIs.

For the Gluon API, `hybridize` can be called on HybridBlocks to execute a graph pass on the internal CachedOp Symbol.

```python
block.hybridize(backend=None, **kwargs)
```

The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which pass will be executed on the model. `**kwargs` might contain other user-specified options that will be passed to the backend APIs. The actual pass runs once, just before the first forward pass.

If you just want to run a graph pass on the HybridBlock but not run a complete forward pass, you can use the `optimize_for` API that combines the work done in the `hybridize` API with part of the work done in the forward pass.

```python
block.optimize_for(x, backend=None, **kwargs)
```

When the `optimize_for` API is called on a HybridBlock, it runs the graph pass immediately. This lets users export the modified model without running a complete forward pass.

```python
block.optimize_for(x, backend='myPass')
block.export('optimized')
```

You can also use `optimize_for` in place of `hybridize` and run inference immediately afterwards.

```python
block.optimize_for(x, backend='myPass')
block(x)
```

### Writing A Custom Graph Pass

There are several essential building blocks for making a custom pass:

* [initialize](./pass_lib.cc#44):
    * This function is the library initialization function necessary for any dynamic libraries. It lets you check if the user is using a compatible version of MXNet.
Note that this `version` parameter is passed from MXNet when the library is loaded.
```c++
MXReturnValue initialize(int version)
```
* [graphPass](./pass_lib.cc#31):
    * This function receives a copy of the model graph and any options specified by the user.
```c++
MXReturnValue graphPass(
    mxnet::ext::Graph *g,
    const std::unordered_map<std::string, std::string>& options)
```
* [REGISTER_PASS(my_pass_name)](./pass_lib.cc#L41):
    * This macro registers the custom pass and its properties to MXNet by its name. The argument to `setBody` is the `graphPass` function.
```c++
REGISTER_PASS(my_pass_name)
.setBody(graphPass);
```

Let’s take a closer look at those registry functions:

* **graphPass**: This function takes two arguments. The first argument is the Graph of the model architecture, where nodes are operators or inputs/params/weights and edges are data dependencies. The second argument is the map of options specified by the user. Any custom options that users pass to the pass arrive in this function through the `options` map.

### Graph representation

The `Graph` class represents the model's architecture. Each `Node` in the graph represents an operator or weight (i.e. args/aux param). Since an operator in MXNet can take multiple inputs and produce multiple outputs, each input/output is represented by a `NodeEntry`. A `Node` contains the following:
- `op` - [string] operator name
- `name` - [string] unique node name
- `inputs` - [vector of NodeEntry] set of inputs to the node
- `outputs` - [vector of NodeEntry] set of outputs from the node
- `subgraph` - [vector of Graph] set of subgraphs in the node
- `attrs` - [map of string to string] set of attributes for the node

The `inputs` are a set of `NodeEntry` where each contains a pointer to a `Node` that produces the data, and an `entry` that is the index of the output on the other `Node`.
Conversely, the `outputs` are a set of `NodeEntry` where each contains a pointer to a `Node` that consumes the data, and an `entry` that is the index of the input on the other `Node`. This bidirectional dependency lets you traverse the graph easily in either direction.

A `Graph` contains the following:
- `nodes` - [vector of Node] set of nodes in the graph
- `inputs` - [vector of Node] set of inputs to the graph
- `outputs` - [vector of NodeEntry] set of outputs from the graph
- `attrs` - [map of string to JSON object] set of attributes for the graph

The `nodes` are all the nodes in the graph (superset). The `inputs` are only those nodes that are model inputs (i.e. input image) or weights (i.e. arg/aux params). The `outputs` are the outputs from the operators in the model that are true outputs of the model (i.e. prediction results).

Here's an example creating a new node and adding it to the graph:
```c++
g->addNode("myConv","Convolution");
```
Here's an example creating an edge between two nodes:
```c++
n1->outputs.push_back({n2,1});
n2->inputs.push_back({n1,0});
```
Here node `n1` produces an output at index 0 that is consumed by node `n2` on the input at index 1.

![example connection](example_connection.png)

Some graph passes require allocating new NDArrays to add/replace model params. The `alloc_arg` and `alloc_aux` APIs enable allocating new NDArrays and integrating them with the model args and aux params. Both APIs have the following signature:

```c++
MXTensor* alloc_xxx(const std::vector<int64_t>& shapes,
                    const MXContext &ctx,
                    MXDType dtype)
```

This function can be called on a node in the graph to allocate a tensor for that node like:

```c++
node->alloc_arg({1},MXContext::CPU(0),kFloat32);
```
It adds a new param to the appropriate arg/aux set when the graph pass returns.
If you wish to remove an existing param, just remove the node in the graph corresponding to that param. It will be deleted after the pass completes and removed from the dictionary of args or aux (whichever it is a member of).

### Parsing a JSON string

To simplify custom libraries, basic JSON parsing utility functions have been implemented in the `lib_api.h` header file. You parse a JSON string by calling the static `JsonVal::parse` function like:

```c++
JsonVal json_val = JsonVal::parse(json);
```

A `JsonVal` is a class that represents the nodes in a JSON structure. You can check the type of a node (num, str, list, or map) by comparing the `JsonVal.type` to `STR`, `NUM`, `LIST`, or `MAP`. Then you can get that value from the node like:

```c++
// Each case is wrapped in braces so the local variable it declares is scoped to that case.
switch (json_val.type) {
  case STR: {
    std::string str = json_val.str;
    break;
  }
  case NUM: {
    int num = json_val.num;
    break;
  }
  case LIST: {
    std::vector<JsonVal> list = json_val.list;
    break;
  }
  case MAP: {
    std::map<JsonVal, JsonVal> map = json_val.map;
    break;
  }
  default:
    // error
    break;
}
```

You call the `dump` function on a `JsonVal` object like `json_val.dump()` to get a JSON-compatible string. There are also convenience constructors for creating `JsonVal` objects for strings and numbers like `JsonVal("myKey")` or `JsonVal(42)`. This makes it easy to get specific keys from a map like `json_val.map[JsonVal("nodes")]`.
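
The type check described above extends naturally to nested values: lists and maps contain further JSON values, so a recursive function can walk the whole structure. The following is a minimal, self-contained sketch of that pattern. `MiniJson`, `MiniType`, `makeStr`, `makeNum`, and `countLeaves` are hypothetical stand-ins written for this illustration; they are not part of `lib_api.h`. The stand-in mirrors the `type`, `num`, `str`, `list`, and `map` fields described above, but uses plain string keys for its map as a simplification.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for JsonVal; NOT the type defined in lib_api.h.
enum MiniType { STR, NUM, LIST, MAP };

struct MiniJson {
  MiniType type = NUM;
  int num = 0;
  std::string str;
  std::vector<MiniJson> list;
  std::map<std::string, MiniJson> map;  // simplification: string keys
};

// Helper constructors for leaf values (hypothetical names).
MiniJson makeStr(const std::string& s) {
  MiniJson v;
  v.type = STR;
  v.str = s;
  return v;
}

MiniJson makeNum(int n) {
  MiniJson v;
  v.type = NUM;
  v.num = n;
  return v;
}

// Recursively counts the string and number leaves in a JSON-like tree.
std::size_t countLeaves(const MiniJson& v) {
  switch (v.type) {
    case STR:
    case NUM:
      return 1;
    case LIST: {
      std::size_t total = 0;
      for (const MiniJson& item : v.list) total += countLeaves(item);
      return total;
    }
    case MAP: {
      std::size_t total = 0;
      for (const auto& kv : v.map) total += countLeaves(kv.second);
      return total;
    }
  }
  assert(false && "unknown MiniType");
  return 0;
}
```

With the real `JsonVal`, the same recursion would presumably switch on `json_val.type` and iterate over `json_val.list` and `json_val.map`, whose keys are `JsonVal` objects rather than strings.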