# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""A Relay implementation of graph packing."""

from tvm import relay
from tvm.relay import op, transform
from tvm.relay import ExprMutator

def run_opt_pass(expr, opt_pass):
    """Execute a relay pass."""
    assert isinstance(opt_pass, transform.Pass)
    mod = relay.Module.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body

def _to_shape(shape):
    return tuple(int(sh) for sh in shape)

def _pack_batch_channel(data, dshape, bfactor, cfactor):
    """Pack the batch and channel dimensions of the data.
    """
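    # Layout illustration (not exhaustive): with bfactor=1 and cfactor=16,
    # an NCHW tensor of shape (1, 64, 56, 56) is reshaped to
    # (1, 1, 4, 16, 56, 56) and transposed to (1, 4, 56, 56, 1, 16),
    # i.e. the NCHW%dn%dc layout the packed conv2d expects.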
    assert int(dshape[0]) % bfactor == 0
    assert int(dshape[1]) % cfactor == 0
    data = op.reshape(data,
                      newshape=(int(dshape[0]) // bfactor, bfactor,
                                int(dshape[1]) // cfactor, cfactor,
                                int(dshape[2]), int(dshape[3])))
    data = op.transpose(
        data, axes=(0, 2, 4, 5, 1, 3))
    return data


def _unpack_batch_channel(data, old_shape):
    """Unpack the batch and channel dimensions of the data.
    """
    data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
    data = op.reshape(data, newshape=old_shape)
    return data


def _pack_weight(data, dshape, cfactor):
    """Pack a conv2d weight into the packed layout.
    """
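    # Layout illustration (not exhaustive): with cfactor=16, an OIHW weight
    # of shape (64, 64, 3, 3) is reshaped to (4, 16, 4, 16, 3, 3) and
    # transposed to (4, 4, 3, 3, 16, 16), matching the OIHW%do%di layout
    # used for the packed conv2d kernel.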
    assert len(dshape) == 4
    assert int(dshape[0]) % cfactor == 0
    assert int(dshape[1]) % cfactor == 0
    data = op.reshape(data,
                      newshape=(int(dshape[0]) // cfactor, cfactor,
                                int(dshape[1]) // cfactor, cfactor,
                                int(dshape[2]), int(dshape[3])))
    data = op.transpose(
        data, axes=(0, 2, 4, 5, 1, 3))
    return data


def _pack_weight_conv2d_transpose(data, dshape, cfactor):
    """Pack a conv2d_transpose weight into the packed layout.
    """
    dshape = _to_shape(dshape)
    assert len(dshape) == 4
    assert dshape[0] % cfactor == 0
    assert dshape[1] % cfactor == 0
    data = op.reshape(data,
                      newshape=(dshape[0] // cfactor, cfactor,
                                dshape[1] // cfactor, cfactor,
                                dshape[2], dshape[3]))
    data = op.transpose(
        data, axes=(2, 0, 4, 5, 3, 1))
    return data


def _pack_const(data, dshape, dtype, bfactor, cfactor):
    """Pack a constant parameter.
    """
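    # Layout illustration (not exhaustive): with cfactor=16, a per-channel
    # constant of shape (64, 1, 1) is reshaped to (4, 16, 1, 1, 1) and
    # transposed to (4, 1, 1, 1, 16); the broadcast below then repeats it
    # along the batch dimension, giving (4, 1, 1, bfactor, 16).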
    dshape = _to_shape(dshape)
    assert len(dshape) == 3
    assert dshape[0] % cfactor == 0
    data = op.reshape(data,
                      newshape=(dshape[0] // cfactor,
                                cfactor, dshape[1],
                                dshape[2], 1))
    data = op.transpose(
        data, axes=(0, 2, 3, 4, 1))

    # broadcast batch dimension to bfactor
    data = op.broadcast_to(
        data,
        shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor))
    return data


def _get_shape(node):
    """Get the shape of a node.
    """
    return _to_shape(node.checked_type.shape)

class ExprPack(ExprMutator):
    """Visitor to perform graph packing on an AST.
    """
    def __init__(self, bfactor, cfactor, weight_bits):
        self.bfactor = bfactor
        self.cfactor = cfactor
        self.weight_bits = weight_bits
        self.start_pack = False
        # Cache the operators the algorithm matches against.
        self.bitpack_start = op.op.get('annotation.bitpack_start')
        self.bitpack_end = op.op.get('annotation.bitpack_end')
        self.conv2d = op.op.get("nn.conv2d")
        self.conv2d_transpose = op.op.get("nn.conv2d_transpose")
        self.add = op.op.get("add")
        self.multiply = op.op.get("multiply")
        self.bias_add = op.op.get("nn.bias_add")
        self.number_of_conv2d = 0
        super().__init__()

    def visit_call(self, call):
        """Rewrite a call node, packing it when inside a packing region."""
        # First visit the children.
        oshape = _get_shape(call)
        odtype = call.checked_type.dtype
        input_types = [arg.checked_type for arg in call.args]
        args = [self.visit(arg) for arg in call.args]

        # Start and stop cases.
        if call.op == self.bitpack_start:
            assert not self.start_pack
            self.start_pack = True
            return _pack_batch_channel(args[0], oshape, self.bfactor, self.cfactor)
        elif call.op == self.bitpack_end:
            if self.start_pack:
                self.start_pack = False
                data = args[0]
                data_shape = _get_shape(call.args[0])
                return _unpack_batch_channel(data, data_shape)
        if self.start_pack:
            # Operator cases
            if call.op == self.conv2d and odtype == 'int32':
                self.number_of_conv2d += 1
                assert 8 % self.weight_bits == 0
                w_lanes = 8 // self.weight_bits
                data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
                kernel_layout = "OIHW%do%di" % (self.cfactor, self.cfactor)
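                # For example, with bfactor=1 and cfactor=16 these evaluate to
                # data_layout="NCHW1n16c" and kernel_layout="OIHW16o16i".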
                data, weight = args
                data_shape = _to_shape(input_types[0].shape)
                kernel_shape = _to_shape(input_types[1].shape)
                kernel = _pack_weight(weight, kernel_shape, self.cfactor)
                # insert bit packing when necessary
                if w_lanes != 1:
                    assert 8 % w_lanes == 0
                    kernel = op.bitpack(kernel, lanes=w_lanes)
                conv2d = op.nn.conv2d(
                    data,
                    kernel,
                    strides=call.attrs.strides,
                    padding=call.attrs.padding,
                    dilation=call.attrs.dilation,
                    groups=call.attrs.groups,
                    channels=call.attrs.channels,
                    kernel_size=call.attrs.kernel_size,
                    data_layout=data_layout,
                    kernel_layout=kernel_layout,
                    out_dtype=call.attrs.out_dtype)
                return conv2d
            elif call.op == self.conv2d_transpose and odtype == 'int32':
                self.number_of_conv2d += 1
                assert 8 % self.weight_bits == 0
                w_lanes = 8 // self.weight_bits
                data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
                kernel_layout = "IOHW%di%do" % (self.cfactor, self.cfactor)
                data, weight = args
                data_shape = _to_shape(input_types[0].shape)
                kernel_shape = _to_shape(input_types[1].shape)
                kernel = _pack_weight_conv2d_transpose(weight, kernel_shape, self.cfactor)
                conv2d = op.nn.conv2d_transpose(
                    data,
                    kernel,
                    strides=call.attrs.strides,
                    padding=call.attrs.padding,
                    dilation=call.attrs.dilation,
                    groups=call.attrs.groups,
                    channels=call.attrs.channels,
                    kernel_size=call.attrs.kernel_size,
                    data_layout=data_layout,
                    kernel_layout=kernel_layout,
                    output_padding=call.attrs.output_padding,
                    out_dtype=call.attrs.out_dtype)
                return conv2d
            elif call.op == self.add and \
                    tuple(input_types[0].shape) == tuple(input_types[1].shape):
                pass
            elif call.op == self.add and len(input_types[1].shape) == 3:
                data, const = args
                const = _pack_const(const,
                                    _to_shape(input_types[1].shape),
                                    input_types[1].dtype,
                                    self.bfactor,
                                    self.cfactor)
                return relay.Call(self.add, [data, const])
            elif call.op == self.multiply and \
                    tuple(input_types[0].shape) == tuple(input_types[1].shape):
                pass
            elif call.op == self.multiply and len(input_types[1].shape) == 3:
                data, const = args
                const = _pack_const(const,
                                    _to_shape(input_types[1].shape),
                                    input_types[1].dtype,
                                    self.bfactor,
                                    self.cfactor)
                return relay.Call(self.multiply, [data, const])
            elif self.start_pack and call.op == self.bias_add:
                data, bias = args
                bias = _pack_const(bias,
                                   _to_shape(input_types[1].shape),
                                   input_types[1].dtype,
                                   self.bfactor,
                                   self.cfactor)
                return relay.Call(self.add, [data, bias])
            elif self.start_pack and call.op == op.op.get('cast') and \
                    input_types[0].dtype == 'int32':
                cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs)
                return relay.Call(op.op.get('copy'), [cast])

        return relay.Call(
            self.visit(call.op),
            args,
            call.attrs)

class BT(Exception):
    """Exception used to backtrack one binding when the stop operator is found."""


def get_subgraph(expr, start_name, stop_name):
    """Annotate the subgraph between start_name and stop_name for packing.

    We assume stop_name only appears once, for simplicity;
    this constraint will be lifted in the future.
    bitpack_start and bitpack_end are both inclusive.
    """
    bitpack_start = op.op.get('annotation.bitpack_start')
    bitpack_end = op.op.get('annotation.bitpack_end')
    anf = run_opt_pass(expr, transform.ToANormalForm())
    def _recursion(anf, start_found, stop_found):
        """Helper to obtain the subgraph."""
        if isinstance(anf, relay.expr.Function):
            return relay.expr.Function(anf.params,
                                       _recursion(anf.body, start_found, stop_found),
                                       anf.ret_type, anf.type_params, anf.attrs)
        elif isinstance(anf, relay.expr.Let):
            value = anf.value
            if isinstance(value, relay.expr.Call):
                if isinstance(value.op, relay.op.Op):
                    if value.op.name == start_name and not start_found:
                        value = relay.expr.Call(bitpack_start, [value])
                        start_found = True
                    elif value.op.name == stop_name:
                        raise BT()
            try:
                return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found))
            except BT:
                assert start_found
                assert not stop_found
                stop_found = True
                value = relay.expr.Call(bitpack_end, [value])
                # TODO: check that anf.body contains no other occurrence of stop_name.
                return relay.expr.Let(anf.var, value, anf.body)
        else:
            assert start_found
            assert stop_found
            return anf
    annotated = _recursion(anf, False, False)
    return run_opt_pass(annotated, transform.ToGraphNormalForm())

def graph_pack(expr,
               bfactor,
               cfactor,
               weight_bits,
               start_name="nn.max_pool2d",
               stop_name="nn.global_avg_pool2d"):
    """Pack the graph into a batch- and channel-packed format.

    Parameters
    ----------
    expr : relay.Expr
        The input program.

    bfactor : int
        The packing factor in the batch dimension.

    cfactor : int
        The packing factor in the channel dimension.

    weight_bits : int
        The bit-width of the weights.

    start_name : str, optional
        The name of the operator at which packing starts.

    stop_name : str, optional
        The name of the operator at which packing stops.

    Returns
    -------
    expr : Expr
        The transformed expression.
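
    Examples
    --------
    A minimal usage sketch, assuming a VTA environment object ``env`` that
    exposes ``BATCH``, ``BLOCK_OUT`` and ``WGT_WIDTH`` (names illustrative):

    .. code-block:: python

        relay_prog = graph_pack(mod["main"],
                                bfactor=env.BATCH,
                                cfactor=env.BLOCK_OUT,
                                weight_bits=env.WGT_WIDTH,
                                start_name="nn.max_pool2d",
                                stop_name="nn.global_avg_pool2d")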
    """
    assert isinstance(expr, relay.Function)
    expr = get_subgraph(expr, start_name, stop_name)
    expr = run_opt_pass(expr, transform.InferType())
    packer = ExprPack(
        bfactor, cfactor,
        weight_bits)
    expr = packer.visit(expr)
    assert not packer.start_pack
    return run_opt_pass(expr, transform.InferType())