1from __future__ import absolute_import, division, print_function, unicode_literals 2 3import torch 4 5from tests.utils import jitVsGlow 6import unittest 7 8 9class TestQuantizedLinear(unittest.TestCase): 10 def test_quantized_linear_packed(self): 11 """Basic test of the PyTorch quantized::linear Node on Glow.""" 12 13 q = torch.nn.quantized.Quantize( 14 scale=1 / 25, zero_point=17, dtype=torch.quint8) 15 dq = torch.nn.quantized.DeQuantize() 16 17 linear = torch.nn.Linear(5, 5) 18 19 linear.weight.data.fill_(1.2) 20 linear.bias.data.fill_(3.0) 21 22 model = torch.nn.Sequential(q, linear, dq) 23 model.qconfig = torch.quantization.get_default_qconfig("fbgemm") 24 torch.quantization.prepare(model, inplace=True) 25 torch.quantization.convert(model, inplace=True) 26 27 x = torch.tensor(range(5), dtype=torch.float) 28 x = torch.cat((x, x, x, x, x)) 29 x = torch.reshape(x, [5, 5]) 30 31 jitVsGlow( 32 model, 33 x, 34 expected_fused_ops={ 35 "aten::quantize_per_tensor", 36 "quantized::linear", 37 "aten::dequantize", 38 }, 39 ) 40 41 def test_quantized_linear_packed_dq_cut(self): 42 """Basic test of the PyTorch quantized::linear Node on Glow, with dequantize excluded. """ 43 44 q = torch.nn.quantized.Quantize( 45 scale=1 / 25, zero_point=17, dtype=torch.quint8) 46 dq = torch.nn.quantized.DeQuantize() 47 48 linear = torch.nn.Linear(5, 5) 49 50 linear.weight.data.fill_(1.2) 51 linear.bias.data.fill_(3.0) 52 53 model = torch.nn.Sequential(q, linear, dq) 54 model.qconfig = torch.quantization.get_default_qconfig("fbgemm") 55 torch.quantization.prepare(model, inplace=True) 56 torch.quantization.convert(model, inplace=True) 57 58 x = torch.tensor(range(5), dtype=torch.float) 59 x = torch.cat((x, x, x, x, x)) 60 x = torch.reshape(x, [5, 5]) 61 62 jitVsGlow( 63 model, 64 x, 65 expected_fused_ops={ 66 "aten::quantize_per_tensor", 67 "quantized::linear", 68 }, 69 black_list=[ 70 "aten::dequantize", 71 ] 72 ) 73 74 @unittest.skip(reason="random input could cause flaky") 75 def test_quantized_linear_random_input(self): 76 """Basic test of the PyTorch quantized::linear Node on Glow.""" 77 78 def test_f(inputs, weights, bias=None): 79 q_int = torch.nn.quantized.Quantize( 80 scale=1 / 13, zero_point=0, dtype=torch.qint8 81 ) 82 q_uint = torch.nn.quantized.Quantize( 83 scale=1 / 13, zero_point=10, dtype=torch.quint8 84 ) 85 86 dq = torch.nn.quantized.DeQuantize() 87 88 q_inputs = q_uint(inputs) 89 q_weights = q_int(weights) 90 91 return dq(torch.nn.quantized.functional.linear(q_inputs, q_weights, bias)) 92 93 for _ in range(100): 94 inputs = torch.randn(7, 7) 95 weights = torch.randn(7, 7) 96 97 bias = torch.tensor([1, 1, 1, 1, 1, 1, 1], dtype=torch.float) * 0.1 98 99 jitVsGlow( 100 test_f, 101 inputs, 102 weights, 103 bias, 104 expected_fused_ops={ 105 "glow::unpacked_quantized_linear", 106 "aten::quantize_per_tensor", 107 "aten::dequantize", 108 }, 109 ) 110 111 def test_quantized_linear_packed_rowwise(self): 112 """Basic test of the PyTorch quantized::linear Node with rowwise quantized 113 packed weights on Glow.""" 114 115 linear = torch.nn.Linear(6, 5) 116 linear.weight.data.random_(0, 100) 117 linear.bias.data.random_(0, 10) 118 119 x = torch.tensor(range(30), dtype=torch.float) 120 x = torch.reshape(x, [5, 6]) 121 122 model = torch.quantization.QuantWrapper(linear) 123 model.qconfig = torch.quantization.get_default_qconfig('fbgemm') 124 torch.quantization.prepare(model, inplace=True) 125 torch.quantization.convert(model, inplace=True) 126 127 jitVsGlow(model, x, expected_fused_ops={"aten::quantize_per_tensor", 128 "quantized::linear", 129 "aten::dequantize"}) 130