1import pytest
2import numpy as np
3from gpaw.test import equal
4from gpaw.grid_descriptor import GridDescriptor
5from gpaw.wavefunctions.pw import PWDescriptor
6from gpaw.mpi import world
7
8
@pytest.mark.ci
def test_pw_interpol():
    """Check plane-wave interpolation/restriction between two grids.

    For several (coarse, fine) grid-shape pairs sharing the same cell, verify:

    * for unit "delta" inputs, the value of the interpolated coarse delta at a
      fine point equals the value of the restricted fine delta at the coarse
      point, rescaled by the ratio of grid-point counts; and complex and real
      plane-wave descriptors give the same number;
    * interpolating real (imaginary) data with a complex descriptor gives a
      purely real (imaginary) result that matches the real-descriptor result;
    * interpolation and restriction both preserve ``gd.integrate``;
    * restriction with complex and real descriptors agrees.

    The checks only run when ``world.size == 1``.
    """

    def check_transpose(gd1, gd2, pd1, pd2, R1, R2):
        """Compare interpolation 1 -> 2 with restriction 2 -> 1 on deltas.

        A unit value at grid point ``R1`` of grid 1 is interpolated to grid 2
        and sampled at ``R2``.  A unit value at ``R2`` of grid 2 is restricted
        to grid 1, sampled at ``R1`` and multiplied by N2 / N1 (the ratio of
        array sizes).  The two numbers must agree to 1e-9.  Returns the
        interpolated value.
        """
        delta1 = gd1.zeros(dtype=pd1.dtype)
        delta1[R1] = 1
        interpolated = pd1.interpolate(delta1, pd2)[0][R2]

        delta2 = gd2.zeros(dtype=pd2.dtype)
        delta2[R2] = 1
        # The size ratio compensates for the different number of grid points
        # on the two grids.
        restricted = (pd2.restrict(delta2, pd1)[0][R1] *
                      delta2.size / delta1.size)

        equal(interpolated, restricted, 1e-9)
        return interpolated

    if world.size != 1:
        # NOTE(review): presumably skipped in parallel because the arrays
        # would be domain-decomposed and the direct indexing below assumes a
        # full grid on one rank -- confirm.
        return

    # Fixed seed so that the tight tolerances below are checked on
    # reproducible data; an unseeded generator makes failures unrepeatable.
    rng = np.random.default_rng(42)

    grid_pairs = [
        [(3, 3, 3), (8, 8, 8)],
        [(4, 4, 4), (9, 9, 9)],
        [(2, 4, 4), (5, 9, 9)],
        [(2, 3, 4), (5, 6, 9)],
        [(2, 3, 4), (5, 6, 8)],
        [(4, 4, 4), (8, 8, 8)],
        [(2, 4, 4), (4, 8, 8)],
        [(2, 4, 2), (4, 8, 4)]]

    for size1, size2 in grid_pairs:
        print(size1, size2)
        # Both grids use the same cell (size1); only the point count differs.
        gd1 = GridDescriptor(size1, size1)
        gd2 = GridDescriptor(size2, size1)
        pd1 = PWDescriptor(1, gd1, complex)
        pd2 = PWDescriptor(1, gd2, complex)
        pd1r = PWDescriptor(1, gd1)
        pd2r = PWDescriptor(1, gd2)

        for R1, R2 in [[(0, 0, 0), (0, 0, 0)],
                       [(0, 0, 0), (0, 0, 1)]]:
            x_complex = check_transpose(gd1, gd2, pd1, pd2, R1, R2)
            x_real = check_transpose(gd1, gd2, pd1r, pd2r, R1, R2)
            equal(x_complex, x_real, 1e-9)

        # Interpolation of random real data: coarse grid -> fine grid.
        coarse = rng.random(size1)
        fine = pd1r.interpolate(coarse, pd2r)[0]
        fine_c = pd1.interpolate(coarse + 0.0j, pd2)[0]
        fine_i = pd1.interpolate(coarse * 1.0j, pd2)[0]
        equal(abs(fine_c.imag).max(), 0, 1e-14)
        equal(abs(fine_i.real).max(), 0, 1e-14)
        equal(gd1.integrate(coarse), gd2.integrate(fine), 1e-13)
        equal(abs(fine_c - fine).max(), 0, 1e-14)
        equal(abs(fine_i - fine * 1.0j).max(), 0, 1e-14)

        # Restriction of the interpolated data: fine grid -> coarse grid.
        back = pd2r.restrict(fine, pd1r)[0]
        back_c = pd2.restrict(fine + 0.0j, pd1)[0]
        back_i = pd2.restrict(fine * 1.0j, pd1)[0]
        equal(gd1.integrate(back), gd2.integrate(fine), 1e-13)
        equal(abs(back_c - back).max(), 0, 1e-14)
        equal(abs(back_i - back * 1.0j).max(), 0, 1e-14)
63