from __future__ import absolute_import, division, print_function, unicode_literals

import unittest

import torch

from tests.utils import jitVsGlow

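
# jitVsGlow (from tests.utils) is expected to run the model both through the
# stock PyTorch JIT and through the Glow fuser, compare the numerical results,
# and assert that the ops listed in expected_fused_ops ended up in the fused
# Glow subgraph.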
class TestQuantizedLinear(unittest.TestCase):
    def test_quantized_linear_packed(self):
        """Basic test of the PyTorch quantized::linear Node on Glow."""

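        # Fixed scale and zero_point keep the quantization parameters
        # deterministic across runs.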
        q = torch.nn.quantized.Quantize(
            scale=1 / 25, zero_point=17, dtype=torch.quint8)
        dq = torch.nn.quantized.DeQuantize()

        linear = torch.nn.Linear(5, 5)

        linear.weight.data.fill_(1.2)
        linear.bias.data.fill_(3.0)

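        # Standard eager-mode post-training quantization: prepare() inserts
        # observers, and convert() swaps nn.Linear for its quantized
        # counterpart with fbgemm-packed weights.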
        model = torch.nn.Sequential(q, linear, dq)
        model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
        torch.quantization.prepare(model, inplace=True)
        torch.quantization.convert(model, inplace=True)

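        # 5x5 input whose rows are all [0, 1, 2, 3, 4].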
        x = torch.tensor(range(5), dtype=torch.float)
        x = torch.cat((x, x, x, x, x))
        x = torch.reshape(x, [5, 5])

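        # The whole quantize -> linear -> dequantize chain should be fused
        # into the Glow subgraph.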
        jitVsGlow(
            model,
            x,
            expected_fused_ops={
                "aten::quantize_per_tensor",
                "quantized::linear",
                "aten::dequantize",
            },
        )

    def test_quantized_linear_packed_dq_cut(self):
        """Basic test of the PyTorch quantized::linear Node on Glow, with the
        final dequantize excluded from fusion."""

        q = torch.nn.quantized.Quantize(
            scale=1 / 25, zero_point=17, dtype=torch.quint8)
        dq = torch.nn.quantized.DeQuantize()

        linear = torch.nn.Linear(5, 5)

        linear.weight.data.fill_(1.2)
        linear.bias.data.fill_(3.0)

        model = torch.nn.Sequential(q, linear, dq)
        model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
        torch.quantization.prepare(model, inplace=True)
        torch.quantization.convert(model, inplace=True)

        x = torch.tensor(range(5), dtype=torch.float)
        x = torch.cat((x, x, x, x, x))
        x = torch.reshape(x, [5, 5])

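        # black_list keeps aten::dequantize out of the fused Glow subgraph,
        # so it should run on the PyTorch side and must not appear in
        # expected_fused_ops.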
        jitVsGlow(
            model,
            x,
            expected_fused_ops={
                "aten::quantize_per_tensor",
                "quantized::linear",
            },
            black_list=[
                "aten::dequantize",
            ],
        )

    @unittest.skip(reason="random inputs can produce flaky results")
    def test_quantized_linear_random_input(self):
        """Basic test of the PyTorch quantized::linear Node on Glow with
        random inputs."""

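        # Activations are quantized as quint8 and weights as qint8, the
        # combination fbgemm expects. The weights are handed to the
        # functional linear unpacked, which is why the fused op checked
        # below is glow::unpacked_quantized_linear.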
        def test_f(inputs, weights, bias=None):
            q_int = torch.nn.quantized.Quantize(
                scale=1 / 13, zero_point=0, dtype=torch.qint8
            )
            q_uint = torch.nn.quantized.Quantize(
                scale=1 / 13, zero_point=10, dtype=torch.quint8
            )

            dq = torch.nn.quantized.DeQuantize()

            q_inputs = q_uint(inputs)
            q_weights = q_int(weights)

            return dq(torch.nn.quantized.functional.linear(q_inputs, q_weights, bias))

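        # Run many random trials; any single draw may land on values that
        # quantize poorly, which is presumably why the test is skipped above.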
        for _ in range(100):
            inputs = torch.randn(7, 7)
            weights = torch.randn(7, 7)

            bias = torch.tensor([1, 1, 1, 1, 1, 1, 1], dtype=torch.float) * 0.1

            jitVsGlow(
                test_f,
                inputs,
                weights,
                bias,
                expected_fused_ops={
                    "glow::unpacked_quantized_linear",
                    "aten::quantize_per_tensor",
                    "aten::dequantize",
                },
            )

    def test_quantized_linear_packed_rowwise(self):
        """Basic test of the PyTorch quantized::linear Node with rowwise
        quantized packed weights on Glow."""

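        # The fbgemm default qconfig uses a per-channel (rowwise) weight
        # observer, so convert() below should produce rowwise-quantized
        # packed weights.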
        linear = torch.nn.Linear(6, 5)
        linear.weight.data.random_(0, 100)
        linear.bias.data.random_(0, 10)

        x = torch.tensor(range(30), dtype=torch.float)
        x = torch.reshape(x, [5, 6])

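        # QuantWrapper adds QuantStub/DeQuantStub around the module, giving
        # the same quantize -> linear -> dequantize structure as the
        # Sequential models above.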
        model = torch.quantization.QuantWrapper(linear)
        model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
        torch.quantization.prepare(model, inplace=True)
        torch.quantization.convert(model, inplace=True)

        jitVsGlow(
            model,
            x,
            expected_fused_ops={
                "aten::quantize_per_tensor",
                "quantized::linear",
                "aten::dequantize",
            },
        )