1from __future__ import absolute_import, division, print_function, unicode_literals
2
3import torch
4
5from tests.utils import jitVsGlow
6import unittest
7
8
class TestExp(unittest.TestCase):
    """Tests for lowering PyTorch's exp operator to Glow."""

    def test_exp_basic(self):
        """Test of the PyTorch exp Node on Glow.

        Two exp calls are chained so the traced graph contains more than one
        aten::exp node. That graph and a random 4-element input are passed to
        jitVsGlow, which is told to expect aten::exp among the fused ops.
        """

        def exp_twice(inp):
            # The outer exp consumes the inner exp's result: exp(exp(inp)).
            return torch.exp(torch.exp(inp))

        inputs = torch.randn(4)

        jitVsGlow(exp_twice, inputs, expected_fused_ops={"aten::exp"})
20