1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name
18"""QNN dialect operators."""
19
20from __future__ import absolute_import as _abs
21from tvm.relay.expr import Tuple, TupleWrapper
22from tvm.relay.op.nn.util import get_pad_tuple2d
23from . import _make
24
25
def requantize(
    data,
    input_scale,
    input_zero_point,
    output_scale,
    output_zero_point,
    axis=-1,
    rounding="UPWARD",
    out_dtype="int8",
):
    r"""Requantize operator.

    Converts a tensor from one quantized representation (``input_scale``,
    ``input_zero_point``) to another (``output_scale``, ``output_zero_point``)
    according to

        Q_output = zp_output + (scale_input / scale_output) * (Q_input - zp_input)

    Parameters
    ----------
    data : tvm.relay.Expr
        The quantized input tensor.

    input_scale : tvm.relay.Expr
        Quantization scale of the input tensor.

    input_zero_point : tvm.relay.Expr
        Zero point of the input tensor.

    output_scale : tvm.relay.Expr
        Quantization scale of the output tensor.

    output_zero_point : tvm.relay.Expr
        Zero point of the output tensor.

    axis : int
        The channel axis for per-channel quantization. Defaults to -1 (the
        last axis).

    rounding : string, optional
        Rounding direction used when a value falls midway between two
        representable values.

    out_dtype : str, optional
        Data type of the requantized output.

    Returns
    -------
    result : tvm.relay.Expr
        The requantized tensor.
    """
    # Forward directly to the C++ implementation.
    return _make.requantize(
        data, input_scale, input_zero_point, output_scale, output_zero_point,
        axis, rounding, out_dtype,
    )
87
88
def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
    r"""Quantize operator.

    Takes a float32 tensor of any shape and produces a quantized tensor of
    the same shape:

        Q_output = clamp(round(input_tensor / output_scale) + output_zero_point,
                         out_dtype::min,
                         out_dtype::max)

    Parameters
    ----------
    data : tvm.relay.Expr
        The float32 tensor to be quantized.
    output_scale : tvm.relay.Expr
        The output scale.
    output_zero_point : tvm.relay.Expr
        The output zero point.
    axis : int
        The channel axis for per-channel quantization. Defaults to -1 (the
        last axis).
    out_dtype : str, optional
        The data type of the quantized output. Can be [int8, uint8, int32].

    Returns
    -------
    result : tvm.relay.Expr
        The quantized tensor.
    """
    # Forward directly to the C++ implementation.
    return _make.quantize(data, output_scale, output_zero_point, axis, out_dtype)
117
118
def dequantize(data, input_scale, input_zero_point, axis=-1):
    r"""Dequantize operator.

    Takes a quantized int8/uint8 tensor of any shape and produces a float32
    tensor of the same shape.

    Parameters
    ----------
    data : tvm.relay.Expr
        The tensor to be dequantized. Can be of type [int8, uint8].
    input_scale : tvm.relay.Expr
        The input scale.
    input_zero_point : tvm.relay.Expr
        The input zero point.
    axis : int
        The channel axis for per-channel quantization. Defaults to -1 (the
        last axis).

    Returns
    -------
    result : tvm.relay.Expr
        The dequantized float32 tensor.
    """
    # Forward directly to the C++ implementation.
    return _make.dequantize(data, input_scale, input_zero_point, axis)
142
143
def concatenate(data, input_scales, input_zero_points, output_scale, output_zero_point, axis):
    """Concatenate quantized tensors along a given axis.

    Parameters
    ----------
    data : Union(List[relay.Expr], Tuple[relay.Expr], TupleWrapper[relay.Expr])
        The quantized tensors to concatenate.

    input_scales : List[relay.Expr]
        Scale of each input quantized tensor.

    input_zero_points : List[relay.Expr]
        Zero point of each input quantized tensor.

    output_scale : relay.Expr
        Scale of the output quantized tensor.

    output_zero_point : relay.Expr
        Zero point of the output quantized tensor.

    axis : int
        The axis along which the tensors are concatenated.

    Returns
    -------
    result: relay.Expr
        The concatenated quantized tensor.
    """
    # Only a static integer axis is currently supported by the backend op.
    if not isinstance(axis, int):
        raise ValueError("For now, we only support integer axis")

    # Normalize the inputs into a relay Tuple expression.
    if isinstance(data, TupleWrapper):
        tensors = data.tuple_value
    elif isinstance(data, (list, tuple)):
        tensors = Tuple(data)
    else:
        tensors = data

    return _make.concatenate(
        tensors,
        Tuple(list(input_scales)),
        Tuple(list(input_zero_points)),
        output_scale,
        output_zero_point,
        axis,
    )
185
186
def conv2d(
    data,
    kernel,
    input_zero_point,
    kernel_zero_point,
    input_scale,
    kernel_scale,
    kernel_size,
    channels,
    strides=(1, 1),
    padding=(0, 0),
    dilation=(1, 1),
    groups=1,
    data_layout="NCHW",
    kernel_layout="OIHW",
    out_layout="",
    out_dtype="int32",
):
    r"""Quantized 2D convolution.

    Convolves quantized data with a quantized kernel. The scale of the
    output tensor is the product of ``input_scale`` and ``kernel_scale``,
    and its zero point is 0; by default the output dtype is int32. See the
    Requantize operator for scaling the int32 result back to (u)int8.

    Parameters
    ----------
    data : tvm.relay.Expr
        The quantized input data.

    kernel : tvm.relay.Expr
        The quantized kernel.

    input_zero_point : tvm.relay.Expr
        Zero point of the data distribution.

    kernel_zero_point : tvm.relay.Expr
        Zero point of the quantized kernel distribution.

    input_scale : tvm.relay.Expr
        Scale of the input tensor. Carried here purely for convenience;
        not consumed by the lowering itself.

    kernel_scale : tvm.relay.Expr
        Scale of the weight tensor. Stored so it is accessible during
        relay passes; not needed once qnn.conv2d is lowered to nn.conv2d
        steps. See also input_scale in Requantize.

    kernel_size : tuple of int
        Spatial width and height of the convolution kernel.

    channels : int
        Number of output channels.

    strides : tuple of int, optional
        Convolution strides.

    padding : tuple of int, optional
        Padding applied to the input before convolution.

    dilation : tuple of int, optional
        Dilation rate for dilated convolution.

    groups : int, optional
        Number of groups for grouped convolution.

    data_layout : str, optional
        Layout of the input.

    kernel_layout : str, optional
        Layout of the kernel.

    out_layout : str, optional
        Layout of the output; defaults to ``data_layout`` when empty.

    out_dtype : str, optional
        Output data type for mixed precision conv2d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    # TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged.
    # Expand 2-way padding into the 4-way form expected by the backend.
    padding = get_pad_tuple2d(padding)
    return _make.conv2d(
        data, kernel, input_zero_point, kernel_zero_point, input_scale,
        kernel_scale, strides, padding, dilation, groups, channels,
        kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
    )
295
296
def add(
    lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
):
    """Quantized addition with numpy-style broadcasting.

    Parameters
    ----------
    lhs : relay.Expr
        Left-hand-side quantized input.

    rhs : relay.Expr
        Right-hand-side quantized input.

    lhs_scale : relay.Expr
        Scale of the lhs quantized expr.

    lhs_zero_point : relay.Expr
        Zero point of the lhs quantized expr.

    rhs_scale : relay.Expr
        Scale of the rhs quantized expr.

    rhs_zero_point : relay.Expr
        Zero point of the rhs quantized expr.

    output_scale : relay.Expr
        Scale of the output quantized expr.

    output_zero_point : relay.Expr
        Zero point of the output quantized expr.

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    # Forward directly to the C++ implementation.
    return _make.add(
        lhs, rhs, lhs_scale, lhs_zero_point,
        rhs_scale, rhs_zero_point, output_scale, output_zero_point,
    )
344
345
def dense(
    data,
    weight,
    input_zero_point,
    kernel_zero_point,
    input_scale,
    kernel_scale,
    units,
    out_dtype="int32",
):
    """QNN dense operator.

    Applies a quantized linear transformation

     .. math::

     `Y = X * W`

    Parameters
    ----------
    data : tvm.relay.Expr
        The quantized input data.
    weight : tvm.relay.Expr
        The quantized weight expression.
    input_zero_point : tvm.relay.Expr
        The input zero point.
    kernel_zero_point : tvm.relay.Expr
        The kernel zero point.
    input_scale : tvm.relay.Expr
        Scale of the input tensor.
    kernel_scale : tvm.relay.Expr
        Scale of the weight tensor. Stored so it is accessible during
        relay passes; not needed once the op is lowered. See also
        input_scale in Requantize.
    units : int
        Number of hidden units of the dense transformation.
    out_dtype : str, optional
        Output data type for mixed precision dense; can be int32 or int16.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    # Forward directly to the C++ implementation.
    return _make.dense(
        data, weight, input_zero_point, kernel_zero_point,
        input_scale, kernel_scale, units, out_dtype,
    )
401
402
def mul(
    lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
):
    """Quantized multiplication with numpy-style broadcasting.

    Parameters
    ----------
    lhs : relay.Expr
        Left-hand-side quantized input.

    rhs : relay.Expr
        Right-hand-side quantized input.

    lhs_scale : relay.Expr
        Scale of the lhs quantized expr.

    lhs_zero_point : relay.Expr
        Zero point of the lhs quantized expr.

    rhs_scale : relay.Expr
        Scale of the rhs quantized expr.

    rhs_zero_point : relay.Expr
        Zero point of the rhs quantized expr.

    output_scale : relay.Expr
        Scale of the output quantized expr.

    output_zero_point : relay.Expr
        Zero point of the output quantized expr.

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    # Forward directly to the C++ implementation.
    return _make.mul(
        lhs, rhs, lhs_scale, lhs_zero_point,
        rhs_scale, rhs_zero_point, output_scale, output_zero_point,
    )
450
451
def subtract(
    lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
):
    """Quantized subtraction with numpy-style broadcasting.

    Parameters
    ----------
    lhs : relay.Expr
        Left-hand-side quantized input.

    rhs : relay.Expr
        Right-hand-side quantized input.

    lhs_scale : relay.Expr
        Scale of the lhs quantized expr.

    lhs_zero_point : relay.Expr
        Zero point of the lhs quantized expr.

    rhs_scale : relay.Expr
        Scale of the rhs quantized expr.

    rhs_zero_point : relay.Expr
        Zero point of the rhs quantized expr.

    output_scale : relay.Expr
        Scale of the output quantized expr.

    output_zero_point : relay.Expr
        Zero point of the output quantized expr.

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    # Forward directly to the C++ implementation.
    return _make.subtract(
        lhs, rhs, lhs_scale, lhs_zero_point,
        rhs_scale, rhs_zero_point, output_scale, output_zero_point,
    )
499