1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import abc
16import contextlib
17import importlib
18import os
19from os import path
20import pkgutil
21import platform
22import shutil
23import sys
24import tempfile
25import unittest
26
27import six
28
29import grpc
30from grpc_tools import protoc
31from tests.unit import test_common
32
# Import line as it appears in the source test protos; _massage_proto_content
# rewrites it to point at the messages proto's actual relative location.
_MESSAGES_IMPORT = b'import "messages.proto";'
# Package declarations that may appear in the source test protos;
# _massage_proto_content replaces either one with a per-test package name
# (presumably so generated code from different test classes does not clash).
_SPLIT_NAMESPACE = b'package grpc_protoc_plugin.invocation_testing.split;'
_COMMON_NAMESPACE = b'package grpc_protoc_plugin.invocation_testing;'

# Subdirectories of each test's temporary directory: where the test protos
# are written, and where protoc writes the generated Python modules.
_RELATIVE_PROTO_PATH = 'relative_proto_path'
_RELATIVE_PYTHON_OUT = 'relative_python_out'
39
40
@contextlib.contextmanager
def _system_path(path_insertion):
    """Temporarily insert entries into ``sys.path``.

    The entries are placed immediately after the first existing entry, ahead
    of all the others. The previous ``sys.path`` is restored on exit, even
    when the ``with`` body raises (e.g. a failing import).

    Args:
      path_insertion: a list of directory names to insert.
    """
    old_system_path = sys.path[:]
    sys.path = sys.path[0:1] + path_insertion + sys.path[1:]
    try:
        yield
    finally:
        # Without the finally, an exception in the with-body would leave the
        # inserted directories on sys.path for the rest of the process.
        sys.path = old_system_path
47
48
# NOTE(nathaniel): https://twitter.com/exoplaneteer/status/677259364256747520
# Life lesson "just always default to idempotence" reinforced.
def _create_directory_tree(root, path_components_sequence):
    """Create nested directories under *root*.

    Each element of *path_components_sequence* is a sequence of directory
    names; every prefix of it is created as a directory beneath *root*.
    Prefixes shared between elements are only created once per call. An
    empty element creates nothing.

    Args:
      root: existing directory to create the tree under.
      path_components_sequence: iterable of sequences of directory names.
    """
    already_made = set()
    for components in path_components_sequence:
        prefix = ''
        for component in components:
            prefix = path.join(prefix, component)
            if prefix in already_made:
                continue
            os.makedirs(path.join(root, prefix))
            already_made.add(prefix)
61
62
def _massage_proto_content(proto_content, test_name_bytes,
                           messages_proto_relative_file_name_bytes):
    """Rewrite a test proto's package and messages import for one test case.

    Both the common and the split package declarations are replaced by a
    package whose last component is the test name. The
    ``import "messages.proto";`` line is redirected to the messages proto's
    actual location.

    Args:
      proto_content: raw proto source, as bytes.
      test_name_bytes: the test's name, as bytes.
      messages_proto_relative_file_name_bytes: forward-slash path of the
        messages proto relative to the proto path, as bytes.

    Returns:
      The rewritten proto source, as bytes.
    """
    package_declaration = (
        b'package grpc_protoc_plugin.invocation_testing.' + test_name_bytes +
        b';')
    messages_import = (
        b'import "' + messages_proto_relative_file_name_bytes + b'";')
    return (proto_content
            .replace(_COMMON_NAMESPACE, package_declaration)
            .replace(_SPLIT_NAMESPACE, package_declaration)
            .replace(_MESSAGES_IMPORT, messages_import))
75
76
def _packagify(directory):
    """Make *directory* and every directory below it an importable package.

    An empty ``__init__.py`` is written into each directory, truncating any
    ``__init__.py`` that is already there.
    """
    for dirpath, _dirnames, _filenames in os.walk(directory):
        # Opening in 'wb' creates (or truncates) the file; nothing is written.
        with open(path.join(dirpath, '__init__.py'), 'wb'):
            pass
82
83
class _Servicer(object):
    """Minimal TestService implementation used by ``test_call``.

    Args:
      response_class: zero-argument callable (the generated ``Response``
        message class in practice) used to build every reply.
    """

    def __init__(self, response_class):
        self._response_class = response_class

    def Call(self, request, context):
        """Ignore the request and return a freshly constructed response."""
        make_response = self._response_class
        return make_response()
91
92
def _protoc(proto_path, python_out, grpc_python_out_flag, grpc_python_out,
            absolute_proto_file_names):
    """Invoke ``grpc_tools.protoc.main`` with the given output options.

    Args:
      proto_path: directory passed as ``--proto_path``.
      python_out: directory for ``--python_out``, or None to omit that flag.
      grpc_python_out_flag: option placed before the colon in
        ``--grpc_python_out`` (e.g. ``'grpc_1_0'`` or ``'grpc_2_0'``).
      grpc_python_out: directory for ``--grpc_python_out``, or None to omit
        that flag.
      absolute_proto_file_names: iterable of proto files to compile.

    Returns:
      The value returned by ``protoc.main``; callers treat 0 as success.
    """
    # The leading empty string stands in for argv[0].
    argv = ['', '--proto_path={}'.format(proto_path)]
    if python_out is not None:
        argv += ['--python_out={}'.format(python_out)]
    if grpc_python_out is not None:
        argv += [
            '--grpc_python_out={}:{}'.format(grpc_python_out_flag,
                                             grpc_python_out)
        ]
    argv += list(absolute_proto_file_names)
    return protoc.main(argv)
106
107
class _Mid2016ProtocStyle(object):
    """A single protoc run using the ``grpc_1_0`` plugin option.

    With this style, gRPC service code is expected in the ``*_pb2`` modules
    as well as in the ``*_pb2_grpc`` modules.
    """

    def name(self):
        """Return the label used in generated test-class names."""
        return 'Mid2016ProtocStyle'

    def grpc_in_pb2_expected(self):
        """Return True: this style also places gRPC code in ``*_pb2``."""
        return True

    def protoc(self, proto_path, python_out, absolute_proto_file_names):
        """Run protoc once with both outputs directed at *python_out*.

        Returns:
          A one-element tuple holding protoc's exit code.
        """
        exit_code = _protoc(proto_path, python_out, 'grpc_1_0', python_out,
                            absolute_proto_file_names)
        return (exit_code,)
119
120
class _SingleProtocExecutionProtocStyle(object):
    """A single protoc run using the ``grpc_2_0`` plugin option.

    With this style, gRPC service code is expected only in the
    ``*_pb2_grpc`` modules.
    """

    def name(self):
        """Return the label used in generated test-class names."""
        return 'SingleProtocExecutionProtocStyle'

    def grpc_in_pb2_expected(self):
        """Return False: gRPC code lives only in ``*_pb2_grpc``."""
        return False

    def protoc(self, proto_path, python_out, absolute_proto_file_names):
        """Run protoc once with both outputs directed at *python_out*.

        Returns:
          A one-element tuple holding protoc's exit code.
        """
        exit_code = _protoc(proto_path, python_out, 'grpc_2_0', python_out,
                            absolute_proto_file_names)
        return (exit_code,)
132
133
class _ProtoBeforeGrpcProtocStyle(object):
    """Two protoc runs: ``--python_out`` first, then ``--grpc_python_out``.

    With this style, gRPC service code is expected only in the
    ``*_pb2_grpc`` modules.
    """

    def name(self):
        """Return the label used in generated test-class names."""
        return 'ProtoBeforeGrpcProtocStyle'

    def grpc_in_pb2_expected(self):
        """Return False: gRPC code lives only in ``*_pb2_grpc``."""
        return False

    def protoc(self, proto_path, python_out, absolute_proto_file_names):
        """Run the pb2-only protoc pass, then the gRPC-only pass.

        Returns:
          A pair of exit codes, in execution order.
        """
        # First pass: --python_out only.
        messages_only = _protoc(proto_path, python_out, None, None,
                                absolute_proto_file_names)
        # Second pass: --grpc_python_out only.
        services_only = _protoc(proto_path, None, 'grpc_2_0', python_out,
                                absolute_proto_file_names)
        return messages_only, services_only
148
149
class _GrpcBeforeProtoProtocStyle(object):
    """Two protoc runs: ``--grpc_python_out`` first, then ``--python_out``.

    With this style, gRPC service code is expected only in the
    ``*_pb2_grpc`` modules.
    """

    def name(self):
        """Return the label used in generated test-class names."""
        return 'GrpcBeforeProtoProtocStyle'

    def grpc_in_pb2_expected(self):
        """Return False: gRPC code lives only in ``*_pb2_grpc``."""
        return False

    def protoc(self, proto_path, python_out, absolute_proto_file_names):
        """Run the gRPC-only protoc pass, then the pb2-only pass.

        Returns:
          A pair of exit codes, in execution order.
        """
        # First pass: --grpc_python_out only.
        services_only = _protoc(proto_path, None, 'grpc_2_0', python_out,
                                absolute_proto_file_names)
        # Second pass: --python_out only.
        messages_only = _protoc(proto_path, python_out, None, None,
                                absolute_proto_file_names)
        return services_only, messages_only
164
165
# One instance of every protoc invocation strategy under test;
# _create_test_case_classes pairs each with both proto layouts.
_PROTOC_STYLES = tuple(
    protoc_style_class() for protoc_style_class in (
        _Mid2016ProtocStyle,
        _SingleProtocExecutionProtocStyle,
        _ProtoBeforeGrpcProtocStyle,
        _GrpcBeforeProtoProtocStyle,
    ))
172
173
@unittest.skipIf(platform.python_implementation() == 'PyPy',
                 'Skip test if run with PyPy!')
class _Test(six.with_metaclass(abc.ABCMeta, unittest.TestCase)):
    """Generates Python code from the test protos and exercises it.

    This base class is not run directly: ``load_tests`` only loads the
    subclasses built by ``_create_test_case_class``, which supply these
    class attributes:

      NAME: unique test name; used as the temp-directory suffix and as the
        last component of the generated proto package.
      MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES, MESSAGES_PROTO_FILE_NAME:
        location of the messages proto beneath the proto path.
      SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES, SERVICES_PROTO_FILE_NAME:
        location of the services proto beneath the proto path.
      EXPECTED_MESSAGES_PB2, EXPECTED_SERVICES_PB2,
        EXPECTED_SERVICES_PB2_GRPC: dotted names of the generated modules.
      PROTOC_STYLE: one of the ``_PROTOC_STYLES`` strategy objects.
    """

    def setUp(self):
        """Write this test's package-renamed protos into a fresh temp tree."""
        self._directory = tempfile.mkdtemp(suffix=self.NAME, dir='.')
        self._proto_path = path.join(self._directory, _RELATIVE_PROTO_PATH)
        self._python_out = path.join(self._directory, _RELATIVE_PYTHON_OUT)

        os.makedirs(self._proto_path)
        os.makedirs(self._python_out)

        # A set: in the same-proto layout both entries are identical, so the
        # single proto file is written (and compiled) only once.
        proto_directories_and_names = {
            (
                self.MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES,
                self.MESSAGES_PROTO_FILE_NAME,
            ),
            (
                self.SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES,
                self.SERVICES_PROTO_FILE_NAME,
            ),
        }
        # This path is written into a proto import statement, hence '/'
        # rather than path.join.
        messages_proto_relative_file_name_forward_slashes = '/'.join(
            self.MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES +
            (self.MESSAGES_PROTO_FILE_NAME,))
        _create_directory_tree(self._proto_path,
                               (relative_proto_directory_names
                                for relative_proto_directory_names, _ in
                                proto_directories_and_names))
        self._absolute_proto_file_names = set()
        for relative_directory_names, file_name in proto_directories_and_names:
            absolute_proto_file_name = path.join(
                self._proto_path, *relative_directory_names + (file_name,))
            # The source protos are read as package data of the tests package.
            raw_proto_content = pkgutil.get_data(
                'tests.protoc_plugin.protos.invocation_testing',
                path.join(*relative_directory_names + (file_name,)))
            massaged_proto_content = _massage_proto_content(
                raw_proto_content, self.NAME.encode(),
                messages_proto_relative_file_name_forward_slashes.encode())
            with open(absolute_proto_file_name, 'wb') as proto_file:
                proto_file.write(massaged_proto_content)
            self._absolute_proto_file_names.add(absolute_proto_file_name)

    def tearDown(self):
        """Delete the temp tree holding the protos and generated code."""
        shutil.rmtree(self._directory)

    def _protoc(self):
        """Run protoc per PROTOC_STYLE, assert success, import the output.

        Sets ``_messages_pb2``, ``_services_pb2`` and ``_services_pb2_grpc``.
        """
        protoc_exit_codes = self.PROTOC_STYLE.protoc(
            self._proto_path, self._python_out, self._absolute_proto_file_names)
        for protoc_exit_code in protoc_exit_codes:
            self.assertEqual(0, protoc_exit_code)

        # Generated subdirectories need __init__.py to be importable.
        _packagify(self._python_out)

        generated_modules = {}
        expected_generated_full_module_names = {
            self.EXPECTED_MESSAGES_PB2,
            self.EXPECTED_SERVICES_PB2,
            self.EXPECTED_SERVICES_PB2_GRPC,
        }
        # NOTE(review): importlib.import_module returns any module already in
        # sys.modules under the same name, so a later test class importing
        # e.g. 'same_pb2' may get the module generated by an earlier class.
        # TODO confirm this is acceptable (evicting would re-register a
        # same-named .proto file with different content).
        with _system_path([self._python_out]):
            for full_module_name in expected_generated_full_module_names:
                module = importlib.import_module(full_module_name)
                generated_modules[full_module_name] = module

        self._messages_pb2 = generated_modules[self.EXPECTED_MESSAGES_PB2]
        self._services_pb2 = generated_modules[self.EXPECTED_SERVICES_PB2]
        self._services_pb2_grpc = generated_modules[
            self.EXPECTED_SERVICES_PB2_GRPC]

    def _services_modules(self):
        """Return the generated modules expected to contain gRPC service code.

        That is both ``*_pb2`` and ``*_pb2_grpc`` when the protoc style puts
        gRPC code in ``*_pb2``, otherwise only ``*_pb2_grpc``.
        """
        if self.PROTOC_STYLE.grpc_in_pb2_expected():
            return self._services_pb2, self._services_pb2_grpc
        else:
            return (self._services_pb2_grpc,)

    def test_imported_attributes(self):
        # Bare attribute lookups: each raises (failing the test) if the
        # generated code lacks the expected name.
        self._protoc()

        self._messages_pb2.Request
        self._messages_pb2.Response
        self._services_pb2.DESCRIPTOR.services_by_name['TestService']
        for services_module in self._services_modules():
            services_module.TestServiceStub
            services_module.TestServiceServicer
            services_module.add_TestServiceServicer_to_server

    def test_call(self):
        self._protoc()

        for services_module in self._services_modules():
            server = test_common.test_server()
            services_module.add_TestServiceServicer_to_server(
                _Servicer(self._messages_pb2.Response), server)
            # Port 0 lets the server pick any free port.
            port = server.add_insecure_port('[::]:0')
            server.start()
            # Stop the server even if the RPC or the assertion fails, so a
            # failing iteration does not leave a running server behind.
            try:
                channel = grpc.insecure_channel('localhost:{}'.format(port))
                stub = services_module.TestServiceStub(channel)
                response = stub.Call(self._messages_pb2.Request())
                self.assertEqual(self._messages_pb2.Response(), response)
            finally:
                server.stop(None)
275
276
def _create_test_case_class(split_proto, protoc_style):
    """Build a concrete ``_Test`` subclass for one layout / protoc style pair.

    Args:
      split_proto: if true, messages and services come from separate protos
        (``split_messages/sub/messages.proto`` and
        ``split_services/services.proto``). Otherwise both come from a single
        ``same.proto`` at the proto-path root.
      protoc_style: one of the ``_PROTOC_STYLES`` strategy objects.

    Returns:
      A new class named ``<Layout><StyleName>Test`` whose class attributes
      configure ``_Test``.
    """
    if split_proto:
        layout_label = 'SplitProto'
        layout_attributes = {
            'MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES': (
                'split_messages',
                'sub',
            ),
            'MESSAGES_PROTO_FILE_NAME': 'messages.proto',
            'SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES': ('split_services',),
            'SERVICES_PROTO_FILE_NAME': 'services.proto',
            'EXPECTED_MESSAGES_PB2': 'split_messages.sub.messages_pb2',
            'EXPECTED_SERVICES_PB2': 'split_services.services_pb2',
            'EXPECTED_SERVICES_PB2_GRPC': 'split_services.services_pb2_grpc',
        }
    else:
        layout_label = 'SameProto'
        layout_attributes = {
            'MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES': (),
            'MESSAGES_PROTO_FILE_NAME': 'same.proto',
            'SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES': (),
            'SERVICES_PROTO_FILE_NAME': 'same.proto',
            'EXPECTED_MESSAGES_PB2': 'same_pb2',
            'EXPECTED_SERVICES_PB2': 'same_pb2',
            'EXPECTED_SERVICES_PB2_GRPC': 'same_pb2_grpc',
        }

    test_name = '{}{}'.format(layout_label, protoc_style.name())
    class_attributes = dict(layout_attributes)
    class_attributes['NAME'] = test_name
    class_attributes['PROTOC_STYLE'] = protoc_style
    # Report the generated class as belonging to this module.
    class_attributes['__module__'] = _Test.__module__
    return type('{}Test'.format(test_name), (_Test,), class_attributes)
311
312
def _create_test_case_classes():
    """Lazily produce one test class per (proto layout, protoc style) pair.

    Same-proto classes come first, then split-proto classes, each group in
    ``_PROTOC_STYLES`` order.
    """
    return (_create_test_case_class(split_proto, protoc_style)
            for split_proto in (False, True)
            for protoc_style in _PROTOC_STYLES)
320
321
def load_tests(loader, tests, pattern):
    """unittest ``load_tests`` hook: return only the generated test classes.

    The default-discovered *tests* and the *pattern* are ignored, so the
    unconfigured ``_Test`` base class is never run on its own.
    """
    suite = unittest.TestSuite()
    for test_case_class in _create_test_case_classes():
        suite.addTest(loader.loadTestsFromTestCase(test_case_class))
    return suite
327
328
if __name__ == '__main__':
    # unittest.main loads this module's tests through its load_tests hook.
    unittest.main(verbosity=2)
331