1"""
2Test the nthderiv module.
3
Note that some of the numdifftools documentation is out of date.
To figure out the right way to use numdifftools I had to inspect the source
directly:
http://code.google.com/p/numdifftools/source/browse/trunk/numdifftools/core.py
(Google Code has since been shut down; numdifftools is now developed at
https://github.com/pbrod/numdifftools.)
Also note that comparisons using relative tolerances,
such as those used by default in assert_allclose,
do not work well when one of the compared values is zero.
10"""
11
12import functools
13import warnings
14
15import numpy as np
16np.random.seed(0)
17
18import numpy.testing
19from numpy.testing import assert_allclose, assert_equal
20
21from algopy import nthderiv
22
23try:
24    import mpmath
25except ImportError as e:
26    warnings.warn('some tests require the mpmath package')
27    mpmath = None
28
29try:
30    import sympy
31except ImportError as e:
32    warnings.warn('some tests require the sympy package')
33    sympy = None
34
35try:
36    import numdifftools
37except ImportError as e:
38    warnings.warn('some tests require the numdifftools package')
39    numdifftools = None
40
# Scalar sample points at which every nthderiv function is exercised.
g_simple_xs = [0.135, -0.567, 1.1234, -1.23]

# The scalar sample points followed by a 2x2 array, used to check that
# functions preserve the shape of array inputs.  This is a new list, so
# g_simple_xs itself is left untouched.
g_complicated_xs = list(g_simple_xs)
g_complicated_xs.append(np.array([[0.123, 0.2], [0.93, 0.44]]))
53
def assert_allclose_or_small(a, b, rtol=1e-7, zerotol=1e-7):
    """
    Assert that a and b are close, treating values near zero leniently.

    Relative tolerances break down when a compared value is (near) zero,
    so the comparison works in two stages:

    * if every element of both a and b has magnitude <= zerotol, the
      inputs are considered equal and nothing is checked;
    * otherwise numpy.testing.assert_allclose is applied with the given
      rtol and an absolute tolerance of zerotol, so that individual
      near-zero elements of an otherwise large array do not fail a
      purely relative test.

    NaN is never considered small, so a NaN on one side cannot slip
    through when the other side is near zero.

    Raises
    ------
    AssertionError
        If a and b are not close in the sense described above.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # np.abs(nan) <= zerotol is False, so any NaN forces a real comparison;
    # np.all of an empty array is True, so empty inputs are accepted
    # instead of raising as np.amax would.
    both_small = np.all(np.abs(a) <= zerotol) and np.all(np.abs(b) <= zerotol)
    if not both_small:
        numpy.testing.assert_allclose(a, b, rtol=rtol, atol=zerotol)
57
def gen_named_functions():
    """
    Yield (name, obj) for each member of the nthderiv module that defines
    both a non-None ``domain`` attribute and a non-None ``extras`` attribute.

    The module namespace is snapshotted before iteration starts.
    """
    members = list(nthderiv.__dict__.items())
    for member_name, member in members:
        if getattr(member, 'domain', None) is None:
            continue
        if getattr(member, 'extras', None) is None:
            continue
        yield member_name, member
64
65
class TestAuto(numpy.testing.TestCase):
    """
    Generic checks applied to every function yielded by gen_named_functions.

    Each function is evaluated only at the module-level sample points for
    which its ``domain`` predicate holds. Every auxiliary (``extras``)
    argument is set to 1.
    """

    # np.ComplexWarning moved to np.exceptions in NumPy 1.25 and was removed
    # from the top-level namespace in NumPy 2.0; older NumPy versions have no
    # np.exceptions, in which case the attribute is looked up on np itself.
    _ComplexWarning = getattr(np, 'exceptions', np).ComplexWarning

    def _test_syntax_helper(self, f, x):
        """
        Check the calling conventions of f at x for n = 0..3.

        For each n, the out-of-place result must have the same shape as x,
        and the result written into a preallocated ``out`` array must equal
        the out-of-place result.
        """
        args = [1] * f.extras + [x]
        for n in range(4):
            ya = f(*args, n=n)
            # the output shape should match the input shape
            assert_equal(np.shape(x), np.shape(ya))
            yb = np.empty_like(x)
            f(*args, out=yb, n=n)
            # the inplace and out-of-place modes should give the same output
            assert_equal(ya, yb)

    def test_syntax(self):
        with warnings.catch_warnings():
            # presumably raised when complex intermediates are cast to real
            # output -- TODO confirm
            warnings.simplefilter('ignore', self._ComplexWarning)
            # only shapes and in/out-of-place agreement are checked here, so
            # floating-point divide/invalid warnings are silenced
            with np.errstate(divide='ignore', invalid='ignore'):
                for _name, f in gen_named_functions():
                    for x in g_complicated_xs:
                        if np.all(f.domain(x)):
                            self._test_syntax_helper(f, x)

    def _test_numdifftools_helper(self, f, x):
        """
        Compare derivatives of order 1..4 of f at x with numdifftools.

        The numdifftools values are numerical estimates, so only gross
        errors are detected (rtol and zerotol of 1e-2).
        """
        extra_args = [1] * f.extras
        args = extra_args + [x]
        # the partial does not depend on n, so build it once
        f_part = functools.partial(f, *extra_args)
        for n in range(1, 5):
            ya = f(*args, n=n)
            yb = numdifftools.Derivative(f_part, n=n)(x)
            # detect only gross errors
            assert_allclose_or_small(ya, yb, rtol=1e-2, zerotol=1e-2)

    def test_numdifftools(self):
        # The nose-based numpy.testing.decorators.skipif has been removed
        # from NumPy, so skip through the unittest API instead.
        if numdifftools is None:
            self.skipTest('numdifftools is not installed')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', self._ComplexWarning)
            with np.errstate(divide='ignore', invalid='ignore'):
                for _name, f in gen_named_functions():
                    for x in g_simple_xs:
                        # same domain test as test_syntax
                        if np.all(f.domain(x)):
                            self._test_numdifftools_helper(f, x)
119
120
class TestExtras(numpy.testing.TestCase):
    """
    Test nth derivatives of scalar functions that take auxiliary arguments.

    nthderiv.clip receives the lower and upper bounds as its two leading
    arguments, followed by the evaluation points.
    """

    @staticmethod
    def _clip_case():
        # Bounds [-2, 4] and a fresh list of points: one below the lower
        # bound, two inside the interval, and one above the upper bound.
        return -2, 4, [-3.3, -1, 6, 3.9]

    def test_clip_n0(self):
        lo, hi, xs = self._clip_case()
        actual = nthderiv.clip(lo, hi, xs, n=0)
        assert_allclose(actual, np.clip(xs, lo, hi))

    def test_clip_n1(self):
        lo, hi, xs = self._clip_case()
        actual = nthderiv.clip(lo, hi, xs, n=1)
        # slope 1 inside the bounds, 0 where the value is clipped
        assert_allclose(actual, [0, 1, 0, 1])

    def test_clip_n2(self):
        lo, hi, xs = self._clip_case()
        actual = nthderiv.clip(lo, hi, xs, n=2)
        # the expected second derivative is zero at every point
        assert_allclose(actual, [0, 0, 0, 0])
149
150
class TestLog(numpy.testing.TestCase):
    """
    Compare nthderiv.log and nthderiv.log1p with their closed-form
    derivatives of order 0 through 3 at a single positive point.
    """

    # evaluation point; positive, so it lies inside the domain of both log
    # and log1p
    point = 0.123

    def _check(self, func, n, expected):
        """Assert that func(point, n=n) is close to expected."""
        assert_allclose(func(self.point, n=n), expected)

    def test_log_n0(self):
        self._check(nthderiv.log, 0, np.log(self.point))

    def test_log_n1(self):
        self._check(nthderiv.log, 1, 1 / self.point)

    def test_log_n2(self):
        self._check(nthderiv.log, 2, -1 / np.square(self.point))

    def test_log_n3(self):
        self._check(nthderiv.log, 3, 2 / np.power(self.point, 3))

    def test_log1p_n0(self):
        self._check(nthderiv.log1p, 0, np.log1p(self.point))

    def test_log1p_n1(self):
        self._check(nthderiv.log1p, 1, 1 / (self.point + 1))

    def test_log1p_n2(self):
        self._check(nthderiv.log1p, 2, -1 / np.square(self.point + 1))

    def test_log1p_n3(self):
        self._check(nthderiv.log1p, 3, 2 / np.power(self.point + 1, 3))
200
201
202
class TestMisc(numpy.testing.TestCase):
    """
    Spot checks of assorted nthderiv functions at a single point.

    Each test compares the nthderiv value (n=0) or first derivative (n=1)
    with the corresponding NumPy expression.
    """

    def test_reciprocal(self):
        x = 0.123
        a = nthderiv.reciprocal(x, n=0)
        b = np.reciprocal(x)
        assert_allclose(a, b)

    def test_cos(self):
        x = 0.123
        a = nthderiv.cos(x, n=0)
        b = np.cos(x)
        assert_allclose(a, b)

    def test_tan(self):
        # The nose-based numpy.testing.decorators.skipif has been removed
        # from NumPy, so skip through the unittest API instead. The mpmath
        # requirement is kept from the original suite; presumably
        # nthderiv.tan depends on mpmath -- TODO confirm.
        if mpmath is None:
            self.skipTest('mpmath is not installed')
        x = 0.123
        a = nthderiv.tan(x, n=0)
        b = np.tan(x)
        assert_allclose(a, b)

    def test_arctan_n0(self):
        x = 0.123
        a = nthderiv.arctan(x, n=0)
        b = np.arctan(x)
        assert_allclose(a, b)

    def test_arctan_n1(self):
        x = 0.123
        a = nthderiv.arctan(x, n=1)
        # d/dx arctan(x) = 1 / (1 + x**2)
        b = 1 / (x*x + 1)
        assert_allclose(a, b)

    def test_cosh(self):
        x = 0.123
        a = nthderiv.cosh(x, n=0)
        b = np.cosh(x)
        assert_allclose(a, b)

    def test_sqrt_n0(self):
        x = 0.123
        a = nthderiv.sqrt(x, n=0)
        b = np.sqrt(x)
        assert_allclose(a, b)

    def test_sqrt_n1(self):
        x = 0.123
        a = nthderiv.sqrt(x, n=1)
        # d/dx sqrt(x) = 1 / (2 sqrt(x))
        b = 1 / (2 * np.sqrt(x))
        assert_allclose(a, b)

    def test_arccosh_n0(self):
        # arccosh is only real for x >= 1, hence a different point
        x = 2.345
        a = nthderiv.arccosh(x, n=0)
        b = np.arccosh(x)
        assert_allclose(a, b)

    def test_arccosh_n1(self):
        x = 2.345
        a = nthderiv.arccosh(x, n=1)
        # d/dx arccosh(x) = 1 / sqrt(x**2 - 1)
        b = 1 / np.sqrt(x*x - 1)
        assert_allclose(a, b)
265
266
267
if __name__ == '__main__':
    # numpy.testing.run_module_suite depended on the unmaintained nose runner
    # and is not available in current NumPy. The test classes derive from
    # numpy.testing.TestCase (a re-export of unittest.TestCase), so the
    # standard-library runner can discover and run them directly.
    import unittest
    unittest.main()
270