1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, unused-argument
18"""Arm Compute Library supported operators."""
19import tvm
20from tvm.relay.expr import const
21from tvm.relay import transform
22from tvm.relay.build_module import bind_params_by_name
23
24from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr
25from .register import register_pattern_table
26
27
def is_arm_compute_runtime_enabled():
    """Check if the ACL graph runtime is present.

    Returns
    -------
    ret: bool
        True if present, False if not.
    """
    # The global func is only registered when TVM was built with ACL support;
    # passing allow_missing=True yields None instead of raising.
    check_func = tvm.get_global_func("relay.op.is_arm_compute_runtime_enabled", True)
    return check_func() if check_func else False
40
41
def partition_for_arm_compute_lib(mod, params=None):
    """Partition the graph greedily offloading supported
    operators to Arm Compute Library.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    ret : annotated and partitioned module.
    """
    # Bind constants first so pattern matching sees them as constants.
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    acl_passes = [
        transform.MergeComposite(arm_compute_lib_pattern_table()),
        transform.AnnotateTarget("arm_compute_lib"),
        transform.PartitionGraph(),
    ]
    return tvm.transform.Sequential(acl_passes)(mod)
69
70
@register_pattern_table("arm_compute_lib")
def arm_compute_lib_pattern_table():
    """Get the ACL pattern table.

    Returns
    -------
    List of (composite name, pattern, predicate) tuples consumed by
    MergeComposite during partitioning.
    """

    def conv_pattern():
        """Create a convolution pattern.

        Returns
        -------
        pattern : dataflow_pattern.AltPattern
            Denotes the convolution pattern.
        """
        pattern = is_op("nn.pad")(wildcard()) | wildcard()
        pattern = is_op("nn.conv2d")(pattern, is_constant())
        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
        pattern = pattern.optional(is_op("nn.relu"))
        return pattern

    def qnn_conv_pattern():
        """Create a quantized convolution pattern.

        Returns
        -------
        pattern : dataflow_pattern.AltPattern
            Denotes the quantized convolution pattern.
        """
        pattern = is_op("nn.pad")(wildcard()) | wildcard()
        pattern = is_op("qnn.conv2d")(
            pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
        )
        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
        pattern = pattern.optional(is_op("nn.relu"))
        pattern = is_op("qnn.requantize")(
            pattern, wildcard(), wildcard(), is_constant(), is_constant()
        )
        return pattern

    def dense_pattern():
        """Create a dense (fully-connected) pattern.

        Returns
        -------
        pattern : dataflow_pattern.AltPattern
            Denotes the dense pattern.
        """
        pattern = is_op("nn.dense")(wildcard(), is_constant())
        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
        return pattern

    def qnn_dense_pattern():
        """Create a quantized dense (fully-connected) pattern.

        Returns
        -------
        pattern : dataflow_pattern.AltPattern
            Denotes the quantized dense pattern.
        """
        pattern = is_op("qnn.dense")(
            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
        )
        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
        pattern = is_op("qnn.requantize")(
            pattern, wildcard(), wildcard(), is_constant(), is_constant()
        )
        return pattern

    def avg_pool2d_pattern():
        """Creates a pattern that matches either quantized
        avg_pool2d or quantized global_avg_pool2d.

        Returns
        -------
        pattern : dataflow_pattern.AltPattern
            Denotes the quantized average pooling pattern.
        """
        pattern = is_op("cast")(wildcard())
        pattern = is_op("nn.avg_pool2d")(pattern) | is_op("nn.global_avg_pool2d")(pattern)
        pattern = is_op("cast")(pattern)
        return pattern

    def l2_pool2d_pattern():
        """Create an l2 pooling pattern from equivalent relay operators.

        Returns
        -------
        pattern : dataflow_pattern.AltPattern
            Denotes the l2 pooling pattern.
        """
        pattern = is_op("power")(wildcard(), is_expr(const(2.0)))
        pattern = is_op("nn.avg_pool2d")(pattern)
        pattern = is_op("sqrt")(pattern)
        return pattern

    def check_conv(extract):
        """Check conv pattern is supported by ACL."""
        call = extract
        # Walk past the optional relu/bias_add wrappers down to the conv call.
        while call.op.name != "nn.conv2d":
            call = call.args[0]
        return conv2d(call.attrs, call.args)

    def check_qnn_conv(extract):
        """Check qnn conv pattern is supported by ACL."""
        # extract is the trailing qnn.requantize call of the pattern.
        if extract.attrs.out_dtype != "uint8":
            return False
        call = extract
        while call.op.name != "qnn.conv2d":
            call = call.args[0]
        return qnn_conv2d(call.attrs, call.args)

    def check_dense(extract):
        """Check dense pattern is supported by ACL."""
        call = extract
        while call.op.name != "nn.dense":
            call = call.args[0]
        return dense(call.attrs, call.args)

    def check_qnn_dense(extract):
        """Check qnn dense pattern is supported by ACL."""
        # extract is the trailing qnn.requantize call of the pattern.
        if extract.attrs.out_dtype != "uint8":
            return False
        call = extract
        while call.op.name != "qnn.dense":
            call = call.args[0]
        return qnn_dense(call.attrs, call.args)

    def check_avg_pool2d(extract):
        """Check average pool2d pattern is supported by ACL."""
        # extract is the outer cast; the pooling op is its argument.
        if extract.attrs.dtype != "uint8":
            return False
        pool = extract.args[0]
        if pool.args[0].attrs.dtype != "int32":
            return False
        return avg_pool2d(pool.attrs, pool.args, from_quantized_composite=True)

    def check_l2_pool2d(extract):
        """Check l2 pool2d pattern is supported by ACL."""
        # extract is the sqrt; the avg_pool2d op is its argument.
        pool = extract.args[0]
        return avg_pool2d(pool.attrs, pool.args)

    # Fix: the qnn_conv2d entry was previously registered twice; the
    # duplicate has been removed.
    return [
        ("arm_compute_lib.conv2d", conv_pattern(), check_conv),
        ("arm_compute_lib.qnn_conv2d", qnn_conv_pattern(), check_qnn_conv),
        ("arm_compute_lib.dense", dense_pattern(), check_dense),
        ("arm_compute_lib.qnn_dense", qnn_dense_pattern(), check_qnn_dense),
        ("arm_compute_lib.avg_pool2d", avg_pool2d_pattern(), check_avg_pool2d),
        ("arm_compute_lib.l2_pool2d", l2_pool2d_pattern(), check_l2_pool2d),
    ]
219
220
def _register_external_op_helper(op_name, supported=True):
    """Register op_name for ACL offload with a fixed supported/unsupported verdict."""

    @tvm.ir.register_op_attr(op_name, "target.arm_compute_lib")
    def _func_wrapper(attrs, args):
        # Verdict is constant regardless of attributes or arguments.
        return supported

    return _func_wrapper


# reshape is always offloadable to ACL.
_register_external_op_helper("reshape")
230
231
@tvm.ir.register_op_attr("nn.conv2d", "target.arm_compute_lib")
def conv2d(attrs, args):
    """Check if the external ACL codegen for conv2d should be used."""
    # Only ungrouped NHWC fp32 convolutions are supported.
    if attrs.groups != 1 or attrs.data_layout != "NHWC":
        return False
    if attrs.out_dtype not in ("float32", ""):
        return False
    data_typ = args[0].checked_type
    if not (len(data_typ.shape) == 4 and data_typ.shape[0] == 1 and data_typ.dtype == "float32"):
        return False
    kernel_typ = args[1].checked_type
    return len(kernel_typ.shape) == 4 and kernel_typ.dtype == "float32"
248
249
def qnn_conv2d(attrs, args):
    """Check if the external ACL codegen for qnn.conv2d should be used."""
    # Only ungrouped NHWC uint8 convolutions accumulating into int32 qualify.
    if attrs.groups != 1 or attrs.data_layout != "NHWC":
        return False
    if attrs.out_dtype not in ("int32", ""):
        return False
    data_typ = args[0].checked_type
    if not (len(data_typ.shape) == 4 and data_typ.shape[0] == 1 and data_typ.dtype == "uint8"):
        return False
    kernel_typ = args[1].checked_type
    return len(kernel_typ.shape) == 4 and kernel_typ.dtype == "uint8"
265
266
@tvm.ir.register_op_attr("nn.dense", "target.arm_compute_lib")
def dense(attrs, args):
    """Check if the external ACL codegen for dense should be used."""
    # fp32 data, 2-D fp32 weights, fp32 (or unspecified) output only.
    if args[0].checked_type.dtype != "float32":
        return False
    kernel_typ = args[1].checked_type
    if len(kernel_typ.shape) != 2 or kernel_typ.dtype != "float32":
        return False
    return attrs.out_dtype in ("float32", "")
279
280
def qnn_dense(attrs, args):
    """Check if the external ACL codegen for qnn.dense should be used."""
    # uint8 data and 2-D uint8 weights accumulating into int32 only.
    if args[0].checked_type.dtype != "uint8":
        return False
    kernel_typ = args[1].checked_type
    if len(kernel_typ.shape) != 2 or kernel_typ.dtype != "uint8":
        return False
    return attrs.out_dtype == "int32"
292
293
@tvm.ir.register_op_attr("nn.max_pool2d", "target.arm_compute_lib")
def max_pool2d(attrs, args):
    """Check if the external ACL codegen for maxpool2d should be used."""
    # NHWC layout with fp32 or uint8 input only.
    if attrs.layout != "NHWC":
        return False
    return args[0].checked_type.dtype in ["float32", "uint8"]
303
304
@tvm.ir.register_op_attr("nn.avg_pool2d", "target.arm_compute_lib")
def avg_pool2d(attrs, args, from_quantized_composite=False):
    """Check if the external ACL codegen for avgpool2d should be used."""
    # Inside a quantized composite the pool operates on the int32
    # widened values; standalone it must be fp32.
    expected_dtype = "int32" if from_quantized_composite else "float32"
    if args[0].checked_type.dtype != expected_dtype:
        return False
    return attrs.layout == "NHWC"
318
319
@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.arm_compute_lib")
def global_max_pool2d(attrs, args):
    """Check if the external ACL codegen for global_maxpool2d should be used."""
    # fp32 or uint8 input with NHWC layout only.
    if args[0].checked_type.dtype not in ["float32", "uint8"]:
        return False
    return attrs.layout == "NHWC"
329
330
@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.arm_compute_lib")
def global_avg_pool2d(attrs, args):
    """Check if the external ACL codegen for global_avgpool2d should be used."""
    # fp32 input with NHWC layout only.
    if args[0].checked_type.dtype not in ["float32"]:
        return False
    return attrs.layout == "NHWC"
340
341
@tvm.ir.register_op_attr("maximum", "target.arm_compute_lib")
def maximum(attrs, args):
    """Check if the external ACL codegen for maximum should be used."""
    type_a = args[0].checked_type
    # Fix: previously read args[0] twice, so the second operand's dtype
    # was never actually checked.
    type_b = args[1].checked_type
    return (type_a.dtype == "float32") and (type_b.dtype == "float32")
348