1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Unit tests for the Bring Your Own Datatype framework.
18
19TODO(@gussmith23 @hypercubestart) link to documentation"""
20import tvm
21import tvm.topi.testing
22import numpy as np
23import pytest
24from numpy.random import MT19937, RandomState, SeedSequence
25from tvm import relay
26from tvm.relay.testing.layers import batch_norm_infer
27from tvm.target.datatype import (
28    register,
29    register_min_func,
30    register_op,
31    create_lower_func,
32    lower_ite,
33    lower_call_pure_extern,
34    create_min_lower_func,
35)
36from tvm.tir.op import call_pure_extern
37
38# note: we can't use relay.testing models because params are randomly initialized,
39# which lead the output to have the same values
40# get mobilenet model from Gluon CV
41# because: https://discuss.tvm.apache.org/t/mobilenet-intermediate-values-are-0/7812
42def get_mobilenet():
43    dshape = (1, 3, 224, 224)
44    from mxnet.gluon.model_zoo.vision import get_model
45
46    block = get_model("mobilenet0.25", pretrained=True)
47    shape_dict = {"data": dshape}
48    return relay.frontend.from_mxnet(block, shape_dict)
49
50
51# use real image instead of random data for end-to-end model training
52# or else output would all be around the same value
53def get_cat_image(dimensions):
54    from tvm.contrib.download import download_testdata
55    from PIL import Image
56
57    url = "https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png"
58    dst = "cat.png"
59    real_dst = download_testdata(url, dst, module="data")
60    img = Image.open(real_dst).resize(dimensions)
61    # CoreML's standard model image format is BGR
62    img_bgr = np.array(img)[:, :, ::-1]
63    img = np.transpose(img_bgr, (2, 0, 1))[np.newaxis, :]
64    return np.asarray(img, dtype="float32")
65
66
67# we use a random seed to generate input_data
68# to guarantee stable tests
69rs = RandomState(MT19937(SeedSequence(123456789)))
70
71
72def convert_ndarray(dst_dtype, array):
73    """Converts NDArray(s) into the specified datatype"""
74    x = relay.var("x", shape=array.shape, dtype=str(array.dtype))
75    cast = relay.Function([x], x.astype(dst_dtype))
76    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
77        return relay.create_executor("graph").evaluate(cast)(array)
78
79
80def change_dtype(src, dst, module, params):
81    """Convert constants and functions in module from src type to dst type.
82    Returns changed module and converted params of type dst_type.
83    """
84    module = relay.frontend.ChangeDatatype(src, dst)(module)
85    module = relay.transform.InferType()(module)
86    params = {k: convert_ndarray(dst, v) for k, v in params.items()}
87    return module, params
88
89
90def compare(module, input, src_dtype, dst_dtype, rtol, atol, params={}, target="llvm"):
91    module = relay.transform.SimplifyInference()(module)
92    ex = relay.create_executor("graph", mod=module)
93
94    correct = ex.evaluate()(*input, **params)
95    module, converted_params = change_dtype(src_dtype, dst_dtype, module, params)
96    ex = relay.create_executor("graph", mod=module, target=target)
97    # converts all inputs to dst_dtype
98    x_converted = [convert_ndarray(dst_dtype, arr) for arr in input]
99
100    # Vectorization is not implemented with custom datatypes
101    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
102        maybe_correct = ex.evaluate()(*x_converted, **converted_params)
103        # currently this only works for comparing single output
104        maybe_correct_converted = convert_ndarray(src_dtype, maybe_correct)
105    np.testing.assert_allclose(
106        maybe_correct_converted.asnumpy(), correct.asnumpy(), rtol=rtol, atol=atol
107    )
108
109
110def setup_myfloat():
111    """Set up tests for myfloat (a custom datatype that under the hood is float)
112
113    Currently, this registers some custom datatypes using the Bring Your
114    Own Datatypes framework.
115    """
116
117    # To use datatype operations in an external library, you should first load
118    # the library containing the datatype implementation:
119    # CDLL("libposit.so", RTLD_GLOBAL)
120    # In this case, the datatype library we are using is built right into TVM,
121    # so we do not need to explicitly load any library.
122
123    # You can pick a code for your datatype arbitrarily, as long as it is
124    # greater than 128 and has not already been chosen.
125    register("myfloat", 131)
126
127    register_op(
128        create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm", "float", "myfloat"
129    )
130    register_op(
131        create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm", "myfloat", "float"
132    )
133    register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm", "myfloat")
134    register_op(
135        create_lower_func(
136            {
137                32: "Custom32Sub",
138            }
139        ),
140        "Sub",
141        "llvm",
142        "myfloat",
143    )
144    register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm", "myfloat")
145    register_op(
146        create_lower_func(
147            {
148                32: "FloatToCustom32",
149            }
150        ),
151        "FloatImm",
152        "llvm",
153        "myfloat",
154    )
155    register_op(
156        create_lower_func(
157            {
158                32: "Custom32Div",
159            }
160        ),
161        "Div",
162        "llvm",
163        "myfloat",
164    )
165    register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm", "myfloat")
166    register_op(
167        create_lower_func({32: "Custom32Sqrt"}),
168        "Call",
169        "llvm",
170        "myfloat",
171        intrinsic_name="tir.sqrt",
172    )
173    register_op(
174        create_lower_func({32: "Custom32Exp"}), "Call", "llvm", "myfloat", intrinsic_name="tir.exp"
175    )
176    register_op(
177        create_lower_func({32: "Custom32Log"}), "Call", "llvm", "myfloat", intrinsic_name="tir.log"
178    )
179    register_op(
180        create_lower_func({32: "Custom32Sigmoid"}),
181        "Call",
182        "llvm",
183        "myfloat",
184        intrinsic_name="tir.sigmoid",
185    )
186    register_op(
187        create_lower_func({32: "Custom32Tanh"}),
188        "Call",
189        "llvm",
190        "myfloat",
191        intrinsic_name="tir.tanh",
192    )
193    register_op(lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else")
194    register_op(
195        lower_call_pure_extern, "Call", "llvm", "myfloat", intrinsic_name="tir.call_pure_extern"
196    )
197
198    register_min_func(create_min_lower_func({32: "MinCustom32"}, "myfloat"), "myfloat")
199
200
201def setup_posites2():
202    """Set up tests for posites2
203    Currently, this registers some custom datatypes using the Bring Your
204    Own Datatypes framework.
205    """
206
207    # To use datatype operations in an external library, you should first load
208    # the library containing the datatype implementation:
209    # CDLL("libposit.so", RTLD_GLOBAL)
210    # In this case, the datatype library we are using is built right into TVM,
211    # so we do not need to explicitly load any library.
212
213    # You can pick a code for your datatype arbitrarily, as long as it is
214    # greater than 128 and has not already been chosen.
215
216    register("posites2", 132)
217
218    register_op(
219        create_lower_func(
220            {
221                (32, 32): "FloatToPosit32es2",
222                (32, 16): "FloatToPosit16es2",
223                (32, 8): "FloatToPosit8es2",
224            }
225        ),
226        "Cast",
227        "llvm",
228        "float",
229        "posites2",
230    )
231    register_op(
232        create_lower_func(
233            {
234                (32, 32): "Posit32es2ToFloat",
235                (16, 32): "Posit16es2ToFloat",
236                (8, 32): "Posit8es2ToFloat",
237            }
238        ),
239        "Cast",
240        "llvm",
241        "posites2",
242        "float",
243    )
244    register_op(
245        create_lower_func({32: "Posit32es2Add", 16: "Posit16es2Add", 8: "Posit8es2Add"}),
246        "Add",
247        "llvm",
248        "posites2",
249    )
250    register_op(
251        create_lower_func({32: "Posit32es2Sub", 16: "Posit16es2Sub", 8: "Posit8es2Sub"}),
252        "Sub",
253        "llvm",
254        "posites2",
255    )
256    register_op(
257        create_lower_func(
258            {32: "FloatToPosit32es2", 16: "FloatToPosit16es2", 8: "FloatToPosit8es2"}
259        ),
260        "FloatImm",
261        "llvm",
262        "posites2",
263    )
264    register_op(
265        create_lower_func({32: "Posit32es2Mul", 16: "Posit16es2Mul", 8: "Posit8es2Mul"}),
266        "Mul",
267        "llvm",
268        "posites2",
269    )
270    register_op(
271        create_lower_func({32: "Posit32es2Div", 16: "Posit16es2Div", 8: "Posit8es2Div"}),
272        "Div",
273        "llvm",
274        "posites2",
275    )
276    register_op(
277        create_lower_func({32: "Posit32es2Max", 16: "Posit16es2Max", 8: "Posit8es2Max"}),
278        "Max",
279        "llvm",
280        "posites2",
281    )
282    register_op(
283        create_lower_func({32: "Posit32es2Sqrt", 16: "Posit16es2Sqrt", 8: "Posit8es2Sqrt"}),
284        "Call",
285        "llvm",
286        "posites2",
287        intrinsic_name="tir.sqrt",
288    )
289    register_op(lower_ite, "Call", "llvm", "posites2", intrinsic_name="tir.if_then_else")
290    register_op(
291        lower_call_pure_extern, "Call", "llvm", "posites2", intrinsic_name="tir.call_pure_extern"
292    )
293    register_op(
294        create_lower_func({32: "Posit32es2Exp", 16: "Posit16es2Exp", 8: "Posit8es2Exp"}),
295        "Call",
296        "llvm",
297        "posites2",
298        intrinsic_name="tir.exp",
299    )
300    register_op(
301        create_lower_func({32: "Posit32es2Log", 16: "Posit16es2Log", 8: "Posit8es2Log"}),
302        "Call",
303        "llvm",
304        "posites2",
305        intrinsic_name="tir.log",
306    )
307    register_op(
308        create_lower_func(
309            {32: "Posit32es2Sigmoid", 16: "Posit16es2Sigmoid", 8: "Posit8es2Sigmoid"}
310        ),
311        "Call",
312        "llvm",
313        "posites2",
314        intrinsic_name="tir.sigmoid",
315    )
316    register_op(
317        create_lower_func({32: "Posit32es2Tanh", 16: "Posit16es2Tanh", 8: "Posit8es2Tanh"}),
318        "Call",
319        "llvm",
320        "posites2",
321        intrinsic_name="tir.tanh",
322    )
323
324    register_min_func(
325        create_min_lower_func(
326            {32: "MinPosit32es2", 16: "MinPosit16es2", 8: "MinPosit8es2"}, "posites2"
327        ),
328        "posites2",
329    )
330
331
332def run_ops(src_dtype, dst_dtype, rtol=1e-7, atol=1e-7):
333    """Run the same op, but with two different datatypes"""
334    # used for unary ops, first shape in binary ops
335    shape1 = (5, 10, 5)
336    # second shape for binary ops
337    shape2 = (5,)
338
339    def check_unary_op(op, src_dtype, dst_dtype, shape):
340        t1 = relay.TensorType(shape, src_dtype)
341        x = relay.var("x", t1)
342        z = op(x)
343        x_data = rs.rand(*shape).astype(t1.dtype)
344
345        module = tvm.IRModule.from_expr(relay.Function([x], z))
346
347        compare(module, (x_data,), src_dtype, dst_dtype, rtol, atol)
348
349    # test unary ops
350    for op in [
351        relay.nn.softmax,
352        tvm.relay.log,
353        tvm.relay.exp,
354        tvm.relay.sqrt,
355        tvm.relay.rsqrt,
356        tvm.relay.sigmoid,
357        tvm.relay.tanh,
358        relay.nn.relu,
359        relay.nn.batch_flatten,
360    ]:
361        check_unary_op(op, src_dtype, dst_dtype, shape1)
362
363    # test unary ops over 4d data
364    for op in [relay.nn.max_pool2d, relay.nn.avg_pool2d, relay.nn.global_avg_pool2d]:
365        shape_2d = (3, 32, 32, 32)
366        check_unary_op(op, src_dtype, dst_dtype, shape_2d)
367
368    def check_binary_op(opfunc, src_dtype, dst_dtype):
369        t1 = relay.TensorType(shape1, src_dtype)
370        t2 = relay.TensorType(shape2, src_dtype)
371        x = relay.var("x", t1)
372        y = relay.var("y", t2)
373        z = opfunc(x, y)
374        x_data = rs.rand(*shape1).astype(t1.dtype)
375        y_data = rs.rand(*shape2).astype(t2.dtype)
376        module = tvm.IRModule.from_expr(relay.Function([x, y], z))
377
378        compare(module, (x_data, y_data), src_dtype, dst_dtype, rtol, atol)
379
380    for op in [
381        relay.add,
382        relay.subtract,
383        relay.divide,
384        relay.multiply,
385    ]:
386        check_binary_op(op, src_dtype, dst_dtype)
387
388    # we would like to test tvm_if_then_else
389    # but Relay.IfNode is not lowered to this intrinsic,
390    # so to keep our tests consistent with relay, we decide to not unit test
391    # Note: tvm_if_then_else is tested as part of the mobile_net model
392
393
394def run_model(get_workload, input, src_dtype, dst_dtype, rtol=1e-4, atol=1e-4):
395    module, params = get_workload()
396
397    # we don't generate random data here
398    # because then the output data would all be around the same value
399    compare(module, input, src_dtype, dst_dtype, rtol, atol, params)
400
401
402def run_conv2d(src_dtype, dst_dtype, rtol=1e-7, atol=1e-4):
403    def run_test_conv2d(
404        src_dtype,
405        dst_dtype,
406        scale,
407        dshape,
408        kshape,
409        padding=(1, 1),
410        groups=1,
411        dilation=(1, 1),
412        **attrs,
413    ):
414        x = relay.var("x", shape=dshape, dtype=src_dtype)
415        w = relay.var("w", shape=kshape, dtype=src_dtype)
416        y = relay.nn.conv2d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs)
417        module = tvm.IRModule.from_expr(relay.Function([x, w], y))
418        data = rs.uniform(-scale, scale, size=dshape).astype(src_dtype)
419        kernel = rs.uniform(-scale, scale, size=kshape).astype(src_dtype)
420
421        compare(module, (data, kernel), src_dtype, dst_dtype, rtol, atol)
422
423    # depthwise conv2d
424    dshape = (1, 32, 18, 18)
425    kshape = (32, 1, 3, 3)
426    run_test_conv2d(
427        src_dtype,
428        dst_dtype,
429        1,
430        dshape,
431        kshape,
432        padding=(1, 1),
433        channels=32,
434        groups=32,
435        kernel_size=(3, 3),
436    )
437
438    # CUDA is disabled for 'direct' schedule:
439    # https://github.com/dmlc/tvm/pull/3070#issuecomment-486597553
440    # group conv2d
441    dshape = (1, 32, 18, 18)
442    kshape = (32, 4, 3, 3)
443    run_test_conv2d(
444        src_dtype,
445        dst_dtype,
446        1,
447        dshape,
448        kshape,
449        padding=(1, 1),
450        channels=32,
451        groups=8,
452        kernel_size=(3, 3),
453    )
454    # also group conv2d
455    dshape = (1, 32, 18, 18)
456    kshape = (64, 1, 3, 3)
457    run_test_conv2d(
458        src_dtype,
459        dst_dtype,
460        1,
461        dshape,
462        kshape,
463        padding=(1, 1),
464        channels=64,
465        groups=32,
466        kernel_size=(3, 3),
467    )
468
469    # normal conv2d
470    dshape = (1, 3, 224, 224)
471    kshape = (10, 3, 3, 3)
472    run_test_conv2d(
473        src_dtype, dst_dtype, 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(3, 3)
474    )
475
476    # dilated conv2d
477    dshape = (1, 3, 18, 18)
478    kshape = (10, 3, 3, 3)
479    run_test_conv2d(
480        src_dtype,
481        dst_dtype,
482        1,
483        dshape,
484        kshape,
485        padding=(1, 1),
486        channels=10,
487        kernel_size=(3, 3),
488        dilation=(3, 3),
489    )
490
491
492def run_batchnorm(src_dtype, dst_dtype, rtol=1e-6, atol=1e-6):
493    shape = (3, 32, 32)
494    t = relay.TensorType(shape, src_dtype)
495    x = relay.var("x", t)
496    bn = batch_norm_infer(data=x, epsilon=2e-5, scale=False, name="bn_x")
497    f = relay.Function(relay.analysis.free_vars(bn), bn)
498
499    x_data = rs.rand(*shape).astype(t.dtype)
500    module = tvm.IRModule.from_expr(f)
501
502    zero_data = np.zeros((32), "float32")
503    compare(
504        module,
505        (x_data, zero_data, zero_data, zero_data, zero_data),
506        src_dtype,
507        dst_dtype,
508        rtol,
509        atol,
510    )
511
512
513def test_myfloat():
514    setup_myfloat()
515    run_ops("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
516    run_conv2d("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
517    run_batchnorm("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
518
519    # mxnet python package not available
520    # run_model(get_mobilenet, (get_cat_image((224, 224)), ),
521    #           'float32',
522    #           'custom[myfloat]32')
523
524
525def _has_posit():
526    return tvm.support.libinfo()["USE_BYODT_POSIT"] == "ON"
527
528
529@pytest.mark.skipif(not _has_posit(), reason="compiled with USE_BYODT_POSIT flag OFF")
530def test_posites2():
531    setup_posites2()
532    run_ops("float32", "custom[posites2]8", rtol=1, atol=1)
533    run_ops("float32", "custom[posites2]16", rtol=0.01, atol=1)
534    run_ops("float32", "custom[posites2]32", rtol=1e-6, atol=1e-6)
535
536    run_conv2d("float32", "custom[posites2]8", rtol=1, atol=1)
537    run_conv2d("float32", "custom[posites2]16", rtol=0.01, atol=1)
538    run_conv2d("float32", "custom[posites2]32")
539
540    run_batchnorm("float32", "custom[posites2]8", rtol=1, atol=1)
541    run_batchnorm("float32", "custom[posites2]16", rtol=0.01, atol=1)
542    run_batchnorm("float32", "custom[posites2]32", rtol=1e-4, atol=1e-4)
543    # Expected posit8 might be faster, but it's not.
544    # run_model(get_mobilenet, (get_cat_image((224, 224)), ), 'float32', 'custom[posit8]8')
545    # run_model(get_mobilenet, (get_cat_image((224, 224)), ), 'float32', 'custom[posit32]32')
546    # run_model(get_inception, (get_cat_image((229, 229)), ), 'float32', 'custom[posit32]32')
547    # run_model(get_resnet, (get_cat_image((224, 224)), ), 'float32', 'custom[posit32]32')
548
549    # can't run cifar-10 sizes because dimensions
550    # don't match pretrained weights
551
552    # runs on the order of minutes...
553    # run_model(get_inception, (get_cat_image((229, 229)), ),
554    #           'float32',
555    #           'custom[posites2]32')
556    # run_model(get_resnet, (get_cat_image((224, 224)), ),
557    #           'float32',
558    #           'custom[posites2]32')
559
560
561if __name__ == "__main__":
562    pytest.main([__file__])
563