1from __future__ import absolute_import, division, print_function, unicode_literals 2 3import torch 4import torch.nn.functional as F 5 6from tests.utils import jitVsGlow 7import unittest 8 9 10class TestBatchNorm(unittest.TestCase): 11 def test_batchnorm_basic(self): 12 """Basic test of the PyTorch batchnorm Node on Glow.""" 13 14 def test_f(inputs, running_mean, running_var): 15 return F.batch_norm(inputs, running_mean, running_var) 16 17 inputs = torch.randn(1, 4, 5, 5) 18 running_mean = torch.rand(4) 19 running_var = torch.rand(4) 20 21 jitVsGlow( 22 test_f, 23 inputs, 24 running_mean, 25 running_var, 26 expected_fused_ops={"aten::batch_norm"}, 27 ) 28 29 def test_batchnorm_with_weights(self): 30 """Test of the PyTorch batchnorm Node with weights and biases on Glow.""" 31 32 def test_f(inputs, weight, bias, running_mean, running_var): 33 return F.batch_norm( 34 inputs, running_mean, running_var, weight=weight, bias=bias 35 ) 36 37 inputs = torch.randn(1, 4, 5, 5) 38 weight = torch.rand(4) 39 bias = torch.rand(4) 40 running_mean = torch.rand(4) 41 running_var = torch.rand(4) 42 43 jitVsGlow( 44 test_f, 45 inputs, 46 weight, 47 bias, 48 running_mean, 49 running_var, 50 expected_fused_ops={"aten::batch_norm"}, 51 ) 52