# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable-msg=too-many-arguments, too-many-locals, assignment-from-no-return
""" Conv Int8 functional and performance testing"""
import sys
import logging
import numpy as np
import tvm
from tvm import te
from tvm import topi

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
LOGGER = logging.getLogger("test_conv_int8_arm")
LOGGER.disabled = False

# All the WORKLOADS from ResNet except the first layer.
# Each workload is (height, width, in_filter, out_filter,
#                   hkernel, wkernel, hpad, wpad, hstride, wstride);
# see the worked example after the list.
WORKLOADS = [
    (56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
    (56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
    (56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
    (56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
    (28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
    (28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
    (28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
    (14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
    (14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
    (14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
    (7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
    (56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
    (56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
    (56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
    (28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
    (56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
    (28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
    (28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
    (14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
    (28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
    (14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
    (14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
    (7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
    (14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
    (7, 7, 2048, 512, 1, 1, 0, 0, 1, 1),
]
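
# Worked example (illustration only): the first workload above unpacks
# positionally into the fields named in the comment preceding the list, in the
# same order in which run_inference receives them via *wkl in the __main__ loop:
#   im_height, im_width, in_filter, out_filter = 56, 56, 64, 64
#   k_h, k_w, hpad, wpad, hstride, wstride = 3, 3, 1, 1, 1, 1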


TARGET_NAME = "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod"
NUM_VEC_LANES = 16
CTX = tvm.context(TARGET_NAME, 0)


def get_shape(
    im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad, hstride, wstride, out_dtype
):
70    """
71    Finds out the shape of all data structures
72    """
    data_shape = (1, in_filter // NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)

    if out_dtype in ("int32", "uint32"):
        kernel_shape = (
            out_filter // NUM_VEC_LANES,
            in_filter // NUM_VEC_LANES,
            k_h,
            k_w,
            NUM_VEC_LANES // 4,
            NUM_VEC_LANES,
            4,
        )
    elif out_dtype == "float32":
        kernel_shape = (
            out_filter // NUM_VEC_LANES,
            in_filter // NUM_VEC_LANES,
            k_h,
            k_w,
            NUM_VEC_LANES,
            NUM_VEC_LANES,
        )
    out_height = (im_height + 2 * hpad - k_h) // hstride + 1
    out_width = (im_width + 2 * wpad - k_w) // wstride + 1
    o_shape = (1, out_filter // NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
    return (data_shape, kernel_shape, o_shape)
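
# A minimal worked example of get_shape (comments only, not executed): for the
# first ResNet workload (56, 56, 64, 64, 3, 3, 1, 1, 1, 1) with NUM_VEC_LANES = 16
# and out_dtype "int32", the returned shapes are
#   data_shape   = (1, 4, 56, 56, 16)      # NCHWc: 64 channels split into 4 blocks of 16
#   kernel_shape = (4, 4, 3, 3, 4, 16, 4)  # the last three dims are (16 // 4, 16, 4) from above
#   o_shape      = (1, 4, 56, 56, 16)      # (56 + 2*1 - 3) // 1 + 1 = 56, spatial size preserved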


def run_inference(
    data_dtype,
    kernel_dtype,
    out_dtype,
    im_height,
    im_width,
    in_filter,
    out_filter,
    k_h,
    k_w,
    hpad,
    wpad,
    hstride,
    wstride,
):
115    """
116    Runs the inference and checks the functional correctness between
117    compute and schedule outputs
118    """
    (data_shape, kernel_shape, o_shape) = get_shape(
        im_height,
        im_width,
        in_filter,
        out_filter,
        k_h,
        k_w,
        hpad,
        wpad,
        hstride,
        wstride,
        out_dtype,
    )

    # Create TVM placeholders
    data = te.placeholder(data_shape, name="data", dtype=data_dtype)
    kernel = te.placeholder(kernel_shape, name="kernel", dtype=kernel_dtype)

    # Create the numpy arrays to be used for executing conv models
    if data_dtype == "float32":
        data_array = tvm.nd.array(np.random.rand(*data_shape).astype(dtype=data_dtype), CTX)
        kernel_array = tvm.nd.array(np.random.rand(*kernel_shape).astype(dtype=kernel_dtype), CTX)
    else:
        data_array = tvm.nd.array(np.random.randint(100, size=data_shape).astype(data_dtype))
        kernel_array = tvm.nd.array(np.random.randint(100, size=kernel_shape).astype(kernel_dtype))

    # c_orig holds the output of the default (unoptimized) schedule;
    # c_sch holds the output of the optimized schedule
    c_orig = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
    c_sch = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)

    with tvm.target.Target(TARGET_NAME):
        if out_dtype == "float32":
            conv = topi.nn.conv2d_NCHWc(
                data,
                kernel,
                stride=hstride,
                padding=hpad,
                dilation=(1, 1),
                layout="NCHWc",
                out_layout="NCHWc",
                out_dtype=out_dtype,
            )
        else:
            conv = topi.nn.conv2d_NCHWc_int8(
                data,
                kernel,
                strides=hstride,
                padding=hpad,
                dilation=(1, 1),
                layout="NCHWc",
                out_layout="NCHWc",
                out_dtype=out_dtype,
            )
        out = topi.nn.relu(conv)
        sch = te.create_schedule(out.op)
        func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name="out")
        func(data_array, kernel_array, c_orig)
        LOGGER.debug(tvm.lower(sch, [data, kernel, out], simple_mode=True))

        # Generate and run the optimized schedule
        if out_dtype == "float32":
            sconv = topi.generic.nn.schedule_conv2d_NCHWc(outs=[out])
        else:
            sconv = topi.generic.nn.schedule_conv2d_NCHWc_int8(outs=[out])
        func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name="conv")
        func(data_array, kernel_array, c_sch)

        # Functional check
        if data_dtype == "uint8":
            np.testing.assert_equal(c_orig.asnumpy(), c_sch.asnumpy())
        else:
            assert np.allclose(c_orig.asnumpy(), c_sch.asnumpy())

        evaluator = func.time_evaluator(func.entry_name, CTX, number=1000)
        LOGGER.debug(tvm.lower(sconv, [data, kernel, out], simple_mode=True))
        return evaluator(data_array, kernel_array, c_sch).mean
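
# A standalone usage sketch (illustration only; it mirrors the __main__ loop below).
# Timing a single workload in uint8 against its float32 baseline might look like:
#
#   wkl = WORKLOADS[0]  # (56, 56, 64, 64, 3, 3, 1, 1, 1, 1)
#   fp32_time = run_inference("float32", "float32", "float32", *wkl)
#   int8_time = run_inference("uint8", "uint8", "uint32", *wkl)
#   print("speedup:", fp32_time / int8_time)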


if __name__ == "__main__":
    LOGGER.info("Workload, Kernel_size, FP32_time, INT8_time, Speedup")
    SPEEDUP_ARRAY = []
    for i, wkl in enumerate(WORKLOADS):
        for dtype in ["uint", "int"]:
            fp32_time = run_inference("float32", "float32", "float32", *wkl)
            int8_time = run_inference("%s8" % dtype, "%s8" % dtype, "%s32" % dtype, *wkl)
            kernel_h = wkl[4]
            kernel_w = wkl[5]
            LOGGER.info(
                "[%s] Workload#%s, %sx%s, %s, %s, %s",
                dtype,
                i,
                kernel_h,
                kernel_w,
                fp32_time,
                int8_time,
                fp32_time / int8_time,
            )

            SPEEDUP_ARRAY.append(fp32_time / int8_time)
    LOGGER.info("Average speedup --> %s", sum(SPEEDUP_ARRAY) / float(len(SPEEDUP_ARRAY)))