1from __future__ import absolute_import, division, print_function, unicode_literals
2
3import torch
4import torch.nn.functional as F
5
6from tests.utils import jitVsGlow
7import unittest
8
9
class TestBatchNorm(unittest.TestCase):
    """Check that F.batch_norm is fused onto Glow and agrees with the JIT run."""

    def test_batchnorm_basic(self):
        """Batch norm with running statistics only (no weight / bias tensors).

        The traced function is compared between the JIT and Glow via
        ``jitVsGlow``, which is also asked to confirm that
        ``aten::batch_norm`` was fused.
        """

        def plain_batch_norm(x, mean, var):
            # weight and bias are omitted, so F.batch_norm's defaults apply.
            return F.batch_norm(x, mean, var)

        num_channels = 4
        # Input is (N=1, C=num_channels, 5, 5); the per-channel running
        # stats therefore have length num_channels.
        x = torch.randn(1, num_channels, 5, 5)
        mean = torch.rand(num_channels)
        var = torch.rand(num_channels)

        call_args = (x, mean, var)
        jitVsGlow(
            plain_batch_norm,
            *call_args,
            expected_fused_ops={"aten::batch_norm"},
        )

    def test_batchnorm_with_weights(self):
        """Batch norm that also supplies per-channel weight and bias tensors.

        Same JIT-vs-Glow comparison as the basic case, with the affine
        parameters passed to F.batch_norm as keyword arguments.
        """

        def affine_batch_norm(x, gamma, beta, mean, var):
            return F.batch_norm(x, mean, var, weight=gamma, bias=beta)

        num_channels = 4
        # Tensors are created in the same order as the argument list so the
        # RNG draw sequence is input, weight, bias, mean, var.
        x = torch.randn(1, num_channels, 5, 5)
        gamma = torch.rand(num_channels)
        beta = torch.rand(num_channels)
        mean = torch.rand(num_channels)
        var = torch.rand(num_channels)

        call_args = (x, gamma, beta, mean, var)
        jitVsGlow(
            affine_batch_norm,
            *call_args,
            expected_fused_ops={"aten::batch_norm"},
        )
52