1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import tvm
19import numpy as np
20from tvm import relay
21from tvm.contrib import graph_runtime
22
# Rounding modes that each test in this file iterates over.
roundings = [
    "UPWARD",
    "TONEAREST",
]
24
def verify(mod, goldens):
    """Build ``mod`` for the "llvm" target, run it on CPU 0 and check output 0.

    Parameters
    ----------
    mod
        Module handed to ``relay.build``.
    goldens
        Pair ``(golden_data, golden_output)``. ``golden_data`` is bound to the
        input named "quantized_data"; ``golden_output`` is the array that
        output 0 is compared against.

    Raises
    ------
    AssertionError
        Raised by ``np.testing.assert_equal`` when the result differs from
        ``golden_output``.
    """
    with relay.build_config(opt_level=3):
        graph_json, compiled_lib, built_params = relay.build(
            mod, "llvm", params=None)
        golden_data, golden_output = goldens

        executor = graph_runtime.create(graph_json, compiled_lib,
                                        ctx=tvm.cpu(0))
        executor.set_input("quantized_data", golden_data)
        # Parameters returned by the build are passed as keyword inputs.
        executor.set_input(**built_params)
        executor.run()

        actual = executor.get_output(0).asnumpy()
        np.testing.assert_equal(actual, golden_output)
35
def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
        input_zero_point=0, output_zero_point=0, rounding="TONEAREST"):
    """Create a Relay module whose body is one ``qnn.requantize`` call.

    The call is applied to a variable named "quantized_data" with the given
    ``data_shape`` and ``data_dtype``; the remaining arguments are forwarded
    to ``relay.qnn.op.requantize`` under the same keyword names.

    Returns
    -------
    The module produced by ``relay.Module.from_expr`` for a function that
    wraps the requantize expression.
    """
    data_var = relay.var("quantized_data", shape=data_shape,
                         dtype=data_dtype)
    requantized = relay.qnn.op.requantize(
        data_var,
        input_scale=input_scale,
        input_zero_point=input_zero_point,
        output_scale=output_scale,
        output_zero_point=output_zero_point,
        rounding=rounding,
        out_dtype=out_dtype)

    # The function's parameters are the free variables of the expression.
    func = relay.Function(relay.analysis.free_vars(requantized), requantized)
    return relay.Module.from_expr(func)
52
def test_same_scale():
    """Equal input and output scales: in-range values are expected unchanged."""
    # Every value lies within [-100, 99], so the expected output is the
    # input itself.
    golden_data = np.arange(-100, 100, 1).astype('int32')
    goldens = (golden_data, golden_data)

    for mode in roundings:
        mod = get_mod(data_shape=(200, ),
                      data_dtype='int32',
                      out_dtype="int8",
                      input_scale=0.5,
                      output_scale=0.5,
                      rounding=mode)
        # The textual form of the module must not mention right_shift.
        assert 'right_shift' not in mod.astext()
        verify(mod, goldens)
67
def test_downscale():
    """Check requantize when the output scale is larger than the input scale.

    For each rounding mode in ``roundings`` the following cases are verified:
    int32 -> int8 with scales 1 -> 16 (positive and negative inputs),
    int32 -> int8 with scales 1 -> 4 (positive and negative inputs),
    int32 -> uint8 with scales 1 -> 16, and uint8 -> uint8 with scales
    1 -> 16. Exact ties on negative inputs are the only place where the
    expected values differ between "UPWARD" and "TONEAREST".
    """
    for rounding in roundings:
        mod = get_mod(data_shape=(32, ),
                      data_dtype='int32',
                      out_dtype='int8',
                      input_scale=1,
                      output_scale=16,
                      rounding=rounding)

        # Try positive values
        # 8 corresponds to 0.5, resulting in 1
        golden_data = np.arange(0, 32, 1).astype('int32')
        golden_output = np.repeat([0, 1, 2], [8, 16, 8])
        verify(mod, (golden_data, golden_output))

        # Try negative values
        # -8 corresponds to -0.5. For UPWARD this is expected to be 0,
        # otherwise -1 (likewise -24 -> -1 for UPWARD, -2 otherwise).
        golden_data = np.arange(0, -32, -1).astype('int32')
        if rounding == "UPWARD":
            golden_output = np.repeat([0, -1, -2], [9, 16, 7])
        else:
            golden_output = np.repeat([0, -1, -2], [8, 16, 8])
        verify(mod, (golden_data, golden_output))

        # Try a different scale
        mod = get_mod(data_shape=(32, ),
                      data_dtype='int32',
                      out_dtype="int8",
                      input_scale=1,
                      output_scale=4,
                      rounding=rounding)

        # Try positive values
        # 2 corresponds to 0.5, resulting in 1
        golden_data = np.arange(0, 32, 1).astype('int32')
        golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8],
                                  [2, 4, 4, 4, 4, 4, 4, 4, 2])
        verify(mod, (golden_data, golden_output))

        # Try negative values
        # -2 corresponds to -0.5. For UPWARD this is expected to be 0,
        # otherwise -1.
        golden_data = np.arange(0, -32, -1).astype('int32')
        if rounding == "UPWARD":
            golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
                                      [3, 4, 4, 4, 4, 4, 4, 4, 1])
        else:
            golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
                                      [2, 4, 4, 4, 4, 4, 4, 4, 2])
        verify(mod, (golden_data, golden_output))

        # Try uint8 out_dtype
        mod = get_mod(data_shape=(32, ),
                      data_dtype='int32',
                      out_dtype='uint8',
                      input_scale=1,
                      output_scale=16,
                      rounding=rounding)

        # Try positive values
        # 8 corresponds to 0.5, resulting in 1
        golden_data = np.arange(0, 32, 1).astype('int32')
        golden_output = np.repeat([0, 1, 2], [8, 16, 8])
        verify(mod, (golden_data, golden_output))

        # Try uint8 in_dtype and uint8 out_dtype
        mod = get_mod(data_shape=(32, ),
                      data_dtype='uint8',
                      out_dtype='uint8',
                      input_scale=1,
                      output_scale=16,
                      rounding=rounding)

        # Try positive values
        # 8 corresponds to 0.5, resulting in 1
        # The input array is created as uint8 so that it matches the
        # data_dtype declared for "quantized_data" instead of depending on
        # any dtype conversion performed when the input is set.
        golden_data = np.arange(0, 32, 1).astype('uint8')
        golden_output = np.repeat([0, 1, 2], [8, 16, 8])
        verify(mod, (golden_data, golden_output))
145
def test_upscale():
    """Input scale 2, output scale 1: every value is expected to double."""
    for mode in roundings:
        mod = get_mod(data_shape=(32, ),
                      data_dtype='int32',
                      out_dtype="int8",
                      input_scale=2,
                      output_scale=1,
                      rounding=mode)

        # First the range 0..31, then the range 0..-31; the expected output
        # is twice the input in both cases and for both rounding modes.
        for stop, step in ((32, 1), (-32, -1)):
            golden_data = np.arange(0, stop, step).astype('int32')
            golden_output = 2 * golden_data
            verify(mod, (golden_data, golden_output))
166
def test_saturation():
    """Values beyond the int8 range are expected to be clamped to 127 / -128.

    Input and output scales are equal, so the expected output is the input
    limited to [-128, 127].
    """
    for mode in roundings:
        mod = get_mod(data_shape=(16, ),
                      data_dtype='int32',
                      out_dtype="int8",
                      input_scale=0.5,
                      output_scale=0.5,
                      rounding=mode)

        # 120..135: everything above 127 is expected to become 127.
        golden_data = np.arange(120, 136, 1).astype('int32')
        golden_output = np.clip(golden_data, -128, 127)
        verify(mod, (golden_data, golden_output))

        # -120..-135: everything below -128 is expected to become -128.
        golden_data = np.arange(-120, -136, -1).astype('int32')
        golden_output = np.clip(golden_data, -128, 127)
        verify(mod, (golden_data, golden_output))
189
def test_zero_point():
    """Check requantize (scales 1 -> 16) with non-zero zero points.

    The first loop sets ``output_zero_point=1``; the second sets
    ``input_zero_point=16``. Both loops run every mode in ``roundings`` and
    verify one increasing and one decreasing input range.
    """
    # Output zero point: expected values are the zero-offset results plus 1.
    for mode in roundings:
        mod = get_mod(data_shape=(32, ),
                      data_dtype='int32',
                      out_dtype='int8',
                      input_scale=1,
                      output_scale=16,
                      output_zero_point=1,
                      rounding=mode)

        # Inputs 0..31
        golden_data = np.arange(0, 32, 1).astype('int32')
        golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + 1
        verify(mod, (golden_data, golden_output))

        # Inputs -32..-63; the group sizes differ for UPWARD.
        golden_data = np.arange(-32, -64, -1).astype('int32')
        counts = [9, 16, 7] if mode == "UPWARD" else [8, 16, 8]
        golden_output = np.repeat([-2, -3, -4], counts) + 1
        verify(mod, (golden_data, golden_output))

    # Input zero point: expected values are shifted down by 1.
    for mode in roundings:
        mod = get_mod(data_shape=(32, ),
                      data_dtype='int32',
                      out_dtype='int8',
                      input_scale=1,
                      output_scale=16,
                      input_zero_point=16,
                      rounding=mode)

        # Inputs 32..63
        golden_data = np.arange(32, 64, 1).astype('int32')
        golden_output = np.repeat([2, 3, 4], [8, 16, 8]) - 1
        verify(mod, (golden_data, golden_output))

        # Inputs -32..-63; the group sizes differ for UPWARD.
        golden_data = np.arange(-32, -64, -1).astype('int32')
        counts = [9, 16, 7] if mode == "UPWARD" else [8, 16, 8]
        golden_output = np.repeat([-2, -3, -4], counts) - 1
        verify(mod, (golden_data, golden_output))
242
if __name__ == "__main__":
    # Run every test in this file, in definition order, when executed as a
    # script.
    for test_fn in (test_same_scale,
                    test_downscale,
                    test_upscale,
                    test_saturation,
                    test_zero_point):
        test_fn()
249