1#!/usr/bin/env python
2###############################################################################
3# $Id$
4#
5# Project:  Adapted from GDAL/OGR Test Suite
6# Purpose:  Fake HTTP server
7# Author:   Even Rouault <even dot rouault at spatialys.com>
8#
9###############################################################################
10# Copyright (c) 2010-2020, Even Rouault <even dot rouault at spatialys.com>
11#
12# Permission is hereby granted, free of charge, to any person obtaining a
13# copy of this software and associated documentation files (the "Software"),
14# to deal in the Software without restriction, including without limitation
15# the rights to use, copy, modify, merge, publish, distribute, sublicense,
16# and/or sell copies of the Software, and to permit persons to whom the
17# Software is furnished to do so, subject to the following conditions:
18#
19# The above copyright notice and this permission notice shall be included
20# in all copies or substantial portions of the Software.
21#
22# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
23# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
25# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
27# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
28# DEALINGS IN THE SOFTWARE.
29###############################################################################
30
import contextlib
import sys
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from threading import Event, Thread
37
# When True, DispatcherHttpHandler appends one "<METHOD> <path>" line per
# request to /tmp/log.txt.
do_log: bool = False
# Object whose do_<METHOD>(request) methods DispatcherHttpHandler forwards
# requests to; set and cleared by install_http_handler(). None = not installed.
custom_handler = None
40
41
@contextlib.contextmanager
def install_http_handler(handler_instance):
    """Route dispatched requests to *handler_instance* for a with-block.

    While the block runs, the module-level ``custom_handler`` is
    *handler_instance*. On exit, handler_instance.final_check() is called so
    its expectations get verified. The handler is then uninstalled. This
    happens even when final_check() raises, so a failing check cannot leave
    the handler installed for whatever runs next.
    """
    global custom_handler
    custom_handler = handler_instance
    try:
        yield
    finally:
        try:
            handler_instance.final_check()
        finally:
            # Always uninstall, even if final_check() raised.
            custom_handler = None
51
52
class RequestResponse(object):
    """One scripted HTTP exchange: the request to expect and the reply to send.

    :param method: HTTP verb the incoming request must use (e.g. 'GET').
    :param path: path the incoming request must target.
    :param code: status code of the reply.
    :param headers: reply headers to send; defaults to an empty dict.
    :param body: reply body (bytes, or a str that is sent ASCII-encoded);
        None for no body.
    :param custom_method: optional callable invoked with the request handler
        instead of the built-in checks and reply.
    :param expected_headers: request headers that must be present with exactly
        these values; defaults to an empty dict.
    :param expected_body: exact request body that must be received.
    :param add_content_length_header: whether a Content-Length header is added
        to the reply.
    :param unexpected_headers: names of request headers that must be absent;
        defaults to a fresh empty list.
    """

    def __init__(self, method, path, code, headers=None, body=None, custom_method=None, expected_headers=None, expected_body=None, add_content_length_header=True, unexpected_headers=None):
        self.method = method
        self.path = path
        self.code = code
        self.headers = {} if headers is None else headers
        self.body = body
        self.custom_method = custom_method
        self.expected_headers = {} if expected_headers is None else expected_headers
        self.expected_body = expected_body
        self.add_content_length_header = add_content_length_header
        # None sentinel instead of a [] default: a list default is created once
        # and would be shared (and aliased) by every instance that omits it.
        self.unexpected_headers = [] if unexpected_headers is None else unexpected_headers
65
66
class SequentialHandler(object):
    """Serve a fixed, ordered script of HTTP exchanges for tests.

    Exchanges are registered with add() and must arrive in exactly that order,
    matching method and path. A request that does not match the next entry is
    answered with a 500. A request that matches but fails its header or body
    expectations is answered with a 400. Either case sets ``unexpected``,
    which makes final_check() fail afterwards.
    """

    def __init__(self):
        self.req_count = 0  # number of scripted entries consumed so far
        self.req_resp = []  # scripted RequestResponse entries, in expected order
        self.unexpected = False  # set once any request failed to match the script

    def final_check(self):
        """Assert that no request misbehaved and every entry was consumed."""
        assert not self.unexpected
        assert self.req_count == len(self.req_resp), (self.req_count, len(self.req_resp))

    def add(self, method, path, code=None, headers=None, body=None, custom_method=None, expected_headers=None, expected_body=None, add_content_length_header=True, unexpected_headers=None):
        """Append the next expected request together with its scripted reply.

        The arguments are stored on a new RequestResponse entry, which is
        appended to the script and returned so the caller can adjust it.
        """
        hdrs = {} if headers is None else headers
        expected_hdrs = {} if expected_headers is None else expected_headers
        # None default (not []) so entries never share one mutable list.
        unexpected_hdrs = [] if unexpected_headers is None else unexpected_headers
        req = RequestResponse(method, path, code, hdrs, body, custom_method, expected_hdrs, expected_body, add_content_length_header, unexpected_hdrs)
        self.req_resp.append(req)
        return req

    def _reject(self, request, message):
        """Write *message* to stderr, reply 400 with an empty body, flag failure."""
        sys.stderr.write(message)
        request.send_response(400)
        request.send_header('Content-Length', 0)
        request.end_headers()
        self.unexpected = True

    def _process_req_resp(self, req_resp, request):
        """Validate *request* against *req_resp* and send the scripted reply.

        If the entry has a custom_method, that callable handles the request
        entirely. Otherwise, three checks run in order, and the first failure
        gets a 400 reply:

        1. expected headers;
        2. forbidden (unexpected) headers;
        3. the expected body.

        If all pass, the scripted status, headers and body are sent.
        """
        if req_resp.custom_method:
            req_resp.custom_method(request)
            return

        if req_resp.expected_headers:
            for k in req_resp.expected_headers:
                if k not in request.headers or request.headers[k] != req_resp.expected_headers[k]:
                    self._reject(request, 'Did not get expected headers: %s\n' % str(request.headers))
                    return

        for k in req_resp.unexpected_headers:
            if k in request.headers:
                self._reject(request, 'Did not expect header: %s\n' % k)
                return

        if req_resp.expected_body:
            # A request without Content-Length has no body we can read; treat
            # it as empty (a mismatch) instead of crashing on int(None).
            content_length = request.headers.get('Content-Length')
            content = request.rfile.read(int(content_length)) if content_length is not None else b''
            if content != req_resp.expected_body:
                self._reject(request, 'Did not get expected content: %s\n' % content)
                return

        body = req_resp.body
        if isinstance(body, str):
            # wfile only accepts bytes; str bodies are sent ASCII-encoded.
            body = body.encode('ascii')

        request.send_response(req_resp.code)
        for k in req_resp.headers:
            request.send_header(k, req_resp.headers[k])
        if req_resp.add_content_length_header:
            if body:
                request.send_header('Content-Length', len(body))
            elif 'Content-Length' not in req_resp.headers:
                request.send_header('Content-Length', '0')
        request.end_headers()
        if body:
            request.wfile.write(body)

    def process(self, method, request):
        """Serve *request* if it matches the next scripted entry, else reply 500."""
        if self.req_count < len(self.req_resp):
            req_resp = self.req_resp[self.req_count]
            if method == req_resp.method and request.path == req_resp.path:
                self.req_count += 1
                self._process_req_resp(req_resp, request)
                return

        request.send_error(500, 'Unexpected %s request for %s, req_count = %d' % (method, request.path, self.req_count))
        self.unexpected = True

    def do_HEAD(self, request):
        self.process('HEAD', request)

    def do_GET(self, request):
        self.process('GET', request)

    def do_POST(self, request):
        self.process('POST', request)

    def do_PUT(self, request):
        self.process('PUT', request)

    def do_DELETE(self, request):
        self.process('DELETE', request)
158
159
class DispatcherHttpHandler(BaseHTTPRequestHandler):
    """Request handler that forwards every request to the installed handler.

    The real behaviour lives in the module-level ``custom_handler`` object
    (see install_http_handler()). This class optionally logs the request line
    to /tmp/log.txt when ``do_log`` is true. It then calls the matching
    ``do_<METHOD>(request)`` method on that object, passing itself as the
    request.
    """

    # protocol_version = 'HTTP/1.1'

    def log_request(self, code='-', size='-'):
        """Suppress BaseHTTPRequestHandler's default per-request access log."""
        # pylint: disable=unused-argument
        pass

    def _dispatch(self, method):
        """Optionally log this request, then call custom_handler.do_<method>(self).

        Replies 500 when no handler is installed. Previously this raised
        AttributeError on None inside the server thread, and the client got
        no response at all.
        """
        if do_log:
            with open('/tmp/log.txt', 'a') as f:
                f.write('%s %s\n' % (method, self.path))

        # Read the global once: install_http_handler() may reset it from
        # another thread between a check and the call.
        handler = custom_handler
        if handler is None:
            self.send_error(500, 'No HTTP handler installed')
            return
        getattr(handler, 'do_' + method)(self)

    def do_HEAD(self):
        self._dispatch('HEAD')

    def do_DELETE(self):
        self._dispatch('DELETE')

    def do_POST(self):
        self._dispatch('POST')

    def do_PUT(self):
        self._dispatch('PUT')

    def do_GET(self):
        self._dispatch('GET')
212
213
class ThreadedHttpServer(Thread):
    """Run an HTTPServer on an OS-chosen local port in a background thread.

    The listening socket is created in the constructor on port 0, so the OS
    picks a free port (see getPort()). start_and_wait_ready() launches the
    serving loop; stop() shuts the loop down and closes the socket.
    """

    def __init__(self, handlerClass):
        Thread.__init__(self)
        self.server = HTTPServer(('', 0), handlerClass)
        self.running = False  # True once run() is about to enter serve_forever()
        # Set by run() so start_and_wait_ready() can block without sleep-polling.
        self._serving_started = Event()

    def getPort(self):
        """Return the TCP port the server is bound to."""
        return self.server.server_address[1]

    def run(self):
        try:
            self.running = True
            self._serving_started.set()
            self.server.serve_forever()
        except KeyboardInterrupt:
            print('^C received, shutting down server')
            self.stop()

    def start_and_wait_ready(self):
        """Start the thread and return as soon as its serving loop is starting.

        Waits on an event rather than polling with time.sleep(1). The polling
        added up to one second of latency to every server launch.
        """
        self.start()
        self._serving_started.wait()

    def stop(self):
        """Stop the serving loop (if it was ever started) and close the socket."""
        # BaseServer.shutdown() waits for serve_forever() to exit, so calling it
        # on a server whose loop never ran would block forever.
        if self.running:
            self.server.shutdown()
        self.server.server_close()
240
241
def launch(handler=None):
    """Start a background test HTTP server and wait until it is serving.

    :param handler: request handler class; DispatcherHttpHandler when None.
    :return: ``(server, port)``. The caller must eventually call
        ``server.stop()``.
    """
    handler_class = DispatcherHttpHandler if handler is None else handler
    http_server = ThreadedHttpServer(handler_class)
    http_server.start_and_wait_ready()
    port = http_server.getPort()
    return http_server, port
248
249
@contextlib.contextmanager
def install_http_server(handler=None):
    """Run a test HTTP server for the duration of a with-block.

    :param handler: request handler class forwarded to launch().
    :yields: the TCP port the server listens on.

    The server is stopped on exit, whether the block completes or raises.
    """
    http_server, listen_port = launch(handler)
    with contextlib.ExitStack() as cleanup:
        cleanup.callback(http_server.stop)
        yield listen_port
257