1import json
2import os
3from abc import ABC, abstractmethod
4from unittest.mock import ANY, patch
5
6import pytest
7from async_generator import async_generator, asynccontextmanager, yield_
8
9from geopy import exc
10from geopy.adapters import BaseAsyncAdapter
11from geopy.location import Location
12
# Geocoder credentials: taken from a local `.test_keys` JSON file (looked up
# relative to the current working directory) when it can be opened, otherwise
# copied from the process environment.
_env = {}
try:
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default
    # encoding, which differs between platforms.
    with open(".test_keys", encoding="utf-8") as fp:
        # presumably a JSON object mapping credential names to values
        _env.update(json.load(fp))
except OSError:
    _env.update(os.environ)
19
20
class SkipIfMissingEnv(dict):
    """Credential mapping that tolerates missing keys during tests.

    Looking up an absent key either skips the current test (when internet
    access is allowed) or hands back the placeholder ``"dummy"`` (when it
    is not, since no network request will be made). The
    ``is_internet_access_allowed`` flag must be set before any lookup.
    """

    def __init__(self, env):
        super().__init__(env)
        # None means "not configured yet"; lookups refuse to run until a
        # boolean is assigned.
        self.is_internet_access_allowed = None

    def __getitem__(self, key):
        assert self.is_internet_access_allowed is not None
        if key in self:
            return super().__getitem__(key)
        if not self.is_internet_access_allowed:
            # Offline run: no request is sent, so any token will do.
            return "dummy"
        # Online run without the credential: the test cannot work.
        pytest.skip("Missing geocoder credential: %s" % (key,))
36
37
# Shared credential lookup for geocoder tests; missing keys skip the test or
# yield a dummy value depending on `is_internet_access_allowed`.
env = SkipIfMissingEnv(env=_env)
39
40
class BaseTestGeocoder(ABC):
    """
    Base for geocoder-specific test cases.

    Subclasses implement :meth:`make_geocoder`. A class-scoped autouse
    fixture builds one geocoder per test class and stores it in the
    ``geocoder`` class attribute. Test methods then call the ``*_run``
    helpers, which perform the request, turn transient service errors
    into skips and verify the result.
    """

    # Geocoder under test for the current class; set by `class_geocoder`
    # and temporarily replaced by `inject_geocoder`.
    geocoder = None
    # Default absolute tolerance for latitude/longitude comparisons.
    delta = 0.5

    @pytest.fixture(scope='class', autouse=True)
    @async_generator
    async def class_geocoder(_, request, patch_adapter, is_internet_access_allowed):
        """Prepare a class-level Geocoder instance.

        Records the internet-access setting on the module-level ``env``
        (so credential lookups inside ``make_geocoder`` know whether to
        skip or use a dummy value), builds the geocoder and stores it on
        the test class. Geocoders with an async adapter are entered as
        async context managers for the lifetime of the class.

        NOTE(review): ``patch_adapter`` is not used in the body; it is
        presumably requested so that its fixture runs before the geocoder
        is built — it is defined outside this file.
        """
        cls = request.cls
        env.is_internet_access_allowed = is_internet_access_allowed

        geocoder = cls.make_geocoder()
        cls.geocoder = geocoder

        run_async = isinstance(geocoder.adapter, BaseAsyncAdapter)
        if run_async:
            async with geocoder:
                await yield_(geocoder)
        else:
            await yield_(geocoder)

    @classmethod
    @asynccontextmanager
    @async_generator
    async def inject_geocoder(cls, geocoder):
        """An async context manager allowing to inject a custom
        geocoder instance in a single test method which will
        be used by the `geocode_run`/`reverse_run` methods.

        The class attribute ``geocoder`` is patched for the duration of
        the ``async with`` block and restored afterwards; async-adapter
        geocoders are also entered as async context managers.
        """
        with patch.object(cls, 'geocoder', geocoder):
            run_async = isinstance(geocoder.adapter, BaseAsyncAdapter)
            if run_async:
                async with geocoder:
                    await yield_(geocoder)
            else:
                await yield_(geocoder)

    @pytest.fixture(autouse=True)
    def ensure_no_geocoder_assignment(self):
        """After each test, fail if ``self.geocoder`` was assigned directly.

        An instance attribute would shadow the class-level geocoder and
        bypass the async-adapter handling of `inject_geocoder`.
        """
        yield
        assert self.geocoder is type(self).geocoder, (
            "Detected `self.geocoder` assignment. "
            "Please use `async with inject_geocoder(my_geocoder):` "
            "instead, which supports async adapters."
        )

    @classmethod
    @abstractmethod
    def make_geocoder(cls, **kwargs):  # pragma: no cover
        """Return the geocoder instance to test; implemented by subclasses.

        NOTE(review): ``kwargs`` are presumably forwarded to the geocoder
        constructor by implementations — confirm in subclasses.
        """
        pass

    async def geocode_run(
        self, payload, expected,
        *,
        skiptest_on_errors=True,
        expect_failure=False,
        skiptest_on_failure=False
    ):
        """
        Calls geocoder.geocode(**payload), then checks against `expected`.

        See :meth:`_run_and_verify` for the keyword arguments and the
        return value.
        """
        return await self._run_and_verify(
            'geocode', payload, expected,
            skiptest_on_errors=skiptest_on_errors,
            expect_failure=expect_failure,
            skiptest_on_failure=skiptest_on_failure,
        )

    async def reverse_run(
        self, payload, expected,
        *,
        skiptest_on_errors=True,
        expect_failure=False,
        skiptest_on_failure=False
    ):
        """
        Calls geocoder.reverse(**payload), then checks against `expected`.

        See :meth:`_run_and_verify` for the keyword arguments and the
        return value.
        """
        return await self._run_and_verify(
            'reverse', payload, expected,
            skiptest_on_errors=skiptest_on_errors,
            expect_failure=expect_failure,
            skiptest_on_failure=skiptest_on_failure,
        )

    async def _run_and_verify(
        self, method, payload, expected,
        *,
        skiptest_on_errors,
        expect_failure,
        skiptest_on_failure
    ):
        """Call ``self.geocoder.<method>(**payload)`` and verify the result.

        Shared implementation of `geocode_run` and `reverse_run`.

        :param method: Name of the geocoder method to call.
        :param payload: Keyword arguments for the call; its ``exactly_one``
            key (default ``True``) selects whether a single ``Location`` or
            a list is expected.
        :param expected: Keyword arguments forwarded to `_verify_request`
            (``latitude``, ``longitude``, ``address``, ``delta``).
        :param skiptest_on_errors: Skip the test on rate limiting, exceeded
            quota, timeout or unavailability instead of raising.
        :param expect_failure: Assert that the call returns ``None`` and
            stop without further checks.
        :param skiptest_on_failure: Skip, rather than fail, the test when
            the call returns ``None``.
        :return: The geocoder result, or ``None`` when `expect_failure`.
        """
        name = type(self).__name__
        result = await self._make_request(
            self.geocoder, method,
            skiptest_on_errors=skiptest_on_errors,
            **payload,
        )
        if expect_failure:
            assert result is None
            return None
        if result is None:
            if skiptest_on_failure:
                pytest.skip('%s: Skipping test due to empty result' % name)
            else:
                pytest.fail('%s: No result found' % name)
        if result == []:
            # "No results" must be signalled with None, not an empty list.
            pytest.fail('%s returned an empty list instead of None' % name)
        self._verify_request(result, exactly_one=payload.get('exactly_one', True),
                             **expected)
        return result

    async def reverse_timezone_run(self, payload, expected, *, skiptest_on_errors=True):
        """
        Calls geocoder.reverse_timezone(**payload), then checks that the
        result's ``pytz_timezone`` equals `expected` (or that the result is
        ``None`` when `expected` is ``None``). Returns the result.
        """
        timezone = await self._make_request(
            self.geocoder, 'reverse_timezone',
            skiptest_on_errors=skiptest_on_errors,
            **payload,
        )
        if expected is None:
            assert timezone is None
        elif timezone is None:
            # Fail with a clear message instead of an AttributeError below.
            pytest.fail('%s: No timezone found' % type(self).__name__)
        else:
            assert timezone.pytz_timezone == expected

        return timezone

    async def _make_request(self, geocoder, method, *, skiptest_on_errors, **kwargs):
        """Call ``geocoder.<method>(**kwargs)`` and return its result.

        The call is awaited when the geocoder uses an async adapter.
        Rate limiting, exceeded quota, timeouts and service unavailability
        skip the current test, unless `skiptest_on_errors` is false, in
        which case the exception propagates.
        """
        cls = type(self)
        call = getattr(geocoder, method)
        run_async = isinstance(geocoder.adapter, BaseAsyncAdapter)
        try:
            if run_async:
                result = await call(**kwargs)
            else:
                result = call(**kwargs)
        # NOTE(review): GeocoderRateLimited is handled before
        # GeocoderQuotaExceeded; presumably it subclasses it (check
        # geopy.exc), so this order keeps its dedicated message reachable.
        except exc.GeocoderRateLimited as e:
            if not skiptest_on_errors:
                raise
            pytest.skip(
                "%s: Rate-limited, retry-after %s" % (cls.__name__, e.retry_after)
            )
        except exc.GeocoderQuotaExceeded:
            if not skiptest_on_errors:
                raise
            pytest.skip("%s: Quota exceeded" % cls.__name__)
        except exc.GeocoderTimedOut:
            if not skiptest_on_errors:
                raise
            pytest.skip("%s: Service timed out" % cls.__name__)
        except exc.GeocoderUnavailable:
            if not skiptest_on_errors:
                raise
            pytest.skip("%s: Service unavailable" % cls.__name__)
        return result

    def _verify_request(
            self,
            result,
            latitude=ANY,
            longitude=ANY,
            address=ANY,
            exactly_one=True,
            delta=None,
    ):
        """Assert that `result` matches the expected coordinates/address.

        :param result: A ``Location`` when `exactly_one`, else a list whose
            first item is checked.
        :param latitude: Expected latitude, or ``ANY`` to skip the check.
        :param longitude: Expected longitude, or ``ANY`` to skip the check.
        :param address: Expected address, or ``ANY`` to skip the check.
        :param exactly_one: Whether a single ``Location`` is expected.
        :param delta: Absolute tolerance for the coordinates; ``None`` uses
            the class default ``self.delta``.
        """
        if exactly_one:
            assert isinstance(result, Location)
        else:
            assert isinstance(result, list)

        item = result if exactly_one else result[0]
        # `is None` rather than `or`, so an explicit delta=0 (exact match)
        # is honoured instead of silently falling back to the class default.
        if delta is None:
            delta = self.delta

        expected = (
            pytest.approx(latitude, abs=delta) if latitude is not ANY else ANY,
            pytest.approx(longitude, abs=delta) if longitude is not ANY else ANY,
            address,
        )
        received = (
            item.latitude,
            item.longitude,
            item.address,
        )
        assert received == expected
226