1# -*- coding: utf-8 -*-
2
3############################ Copyrights and license ############################
4#                                                                              #
5# Copyright 2012 Andrew Bettison <andrewb@zip.com.au>                          #
6# Copyright 2012 Dima Kukushkin <dima@kukushkin.me>                            #
7# Copyright 2012 Michael Woodworth <mwoodworth@upverter.com>                   #
8# Copyright 2012 Petteri Muilu <pmuilu@xena.(none)>                            #
9# Copyright 2012 Steve English <steve.english@navetas.com>                     #
10# Copyright 2012 Vincent Jacques <vincent@vincent-jacques.net>                 #
11# Copyright 2012 Zearin <zearin@gonk.net>                                      #
12# Copyright 2013 AKFish <akfish@gmail.com>                                     #
13# Copyright 2013 Cameron White <cawhite@pdx.edu>                               #
14# Copyright 2013 Ed Jackson <ed.jackson@gmail.com>                             #
15# Copyright 2013 Jonathan J Hunt <hunt@braincorporation.com>                   #
16# Copyright 2013 Mark Roddy <markroddy@gmail.com>                              #
17# Copyright 2013 Vincent Jacques <vincent@vincent-jacques.net>                 #
18# Copyright 2014 Jimmy Zelinskie <jimmyzelinskie@gmail.com>                    #
19# Copyright 2014 Vincent Jacques <vincent@vincent-jacques.net>                 #
20# Copyright 2015 Brian Eugley <Brian.Eugley@capitalone.com>                    #
21# Copyright 2015 Daniel Pocock <daniel@pocock.pro>                             #
22# Copyright 2015 Jimmy Zelinskie <jimmyzelinskie@gmail.com>                    #
23# Copyright 2016 Denis K <f1nal@cgaming.org>                                   #
24# Copyright 2016 Jared K. Smith <jaredsmith@jaredsmith.net>                    #
25# Copyright 2016 Jimmy Zelinskie <jimmy.zelinskie+git@gmail.com>               #
26# Copyright 2016 Mathieu Mitchell <mmitchell@iweb.com>                         #
27# Copyright 2016 Peter Buckley <dx-pbuckley@users.noreply.github.com>          #
28# Copyright 2017 Chris McBride <thehighlander@users.noreply.github.com>        #
29# Copyright 2017 Hugo <hugovk@users.noreply.github.com>                        #
30# Copyright 2017 Simon <spam@esemi.ru>                                         #
31# Copyright 2018 Dylan <djstein@ncsu.edu>                                      #
32# Copyright 2018 Maarten Fonville <mfonville@users.noreply.github.com>         #
33# Copyright 2018 Mike Miller <github@mikeage.net>                              #
34# Copyright 2018 R1kk3r <R1kk3r@users.noreply.github.com>                      #
35# Copyright 2018 sfdye <tsfdye@gmail.com>                                      #
36#                                                                              #
37# This file is part of PyGithub.                                               #
38# http://pygithub.readthedocs.io/                                              #
39#                                                                              #
40# PyGithub is free software: you can redistribute it and/or modify it under    #
41# the terms of the GNU Lesser General Public License as published by the Free  #
42# Software Foundation, either version 3 of the License, or (at your option)    #
43# any later version.                                                           #
44#                                                                              #
45# PyGithub is distributed in the hope that it will be useful, but WITHOUT ANY  #
46# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS    #
47# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more #
48# details.                                                                     #
49#                                                                              #
50# You should have received a copy of the GNU Lesser General Public License     #
51# along with PyGithub. If not, see <http://www.gnu.org/licenses/>.             #
52#                                                                              #
53################################################################################
54
55import base64
56import json
57import logging
58import mimetypes
59import os
60import re
61import time
62import urllib
63from io import IOBase
64
65import requests
66
67from . import Consts, GithubException
68
69
class RequestsResponse:
    """Adapt a ``requests`` response to the httplib-style response interface.

    Only the members Requester uses are exposed: the ``status`` attribute,
    ``getheaders()`` and ``read()``.
    """

    def __init__(self, r):
        # Copy the three fields out of the requests response up front.
        self.status, self.headers, self.text = r.status_code, r.headers, r.text

    def getheaders(self):
        """Return the response headers as ``(name, value)`` pairs."""
        headers = self.headers
        return headers.items()

    def read(self):
        """Return the response body exactly as captured from ``r.text``."""
        body = self.text
        return body
82
83
class _RequestsConnectionBase(object):
    """Shared implementation of the httplib-style connection shims below.

    Mimics the subset of an httplib connection object that Requester uses:
    ``request()`` only records the call, ``getresponse()`` performs it through
    a ``requests.Session`` and wraps the result in ``RequestsResponse``, and
    ``close()`` is a no-op. Subclasses set ``protocol`` and ``default_port``.
    """

    protocol = None
    default_port = None

    def __init__(
        self, host, port=None, strict=False, timeout=None, retry=None, **kwargs
    ):
        # ``strict`` is accepted only for signature compatibility and is unused.
        self.port = port if port else self.default_port
        self.host = host
        self.timeout = timeout
        self.verify = kwargs.get("verify", True)
        # Always defined (None when retries are not configured).
        self.retry = retry
        self.session = requests.Session()
        # Code to support retries: retry may be an int or a urllib3 Retry
        # object; HTTPAdapter(max_retries=...) accepts either.
        if retry:
            self.adapter = requests.adapters.HTTPAdapter(max_retries=retry)
            self.session.mount(self.protocol + "://", self.adapter)

    def request(self, verb, url, input, headers):
        """Record the request; it is actually sent by ``getresponse()``."""
        self.verb = verb
        self.url = url
        self.input = input
        self.headers = headers

    def _full_url(self):
        """Build the absolute URL for the recorded (path-only) request URL."""
        return "%s://%s:%s%s" % (self.protocol, self.host, self.port, self.url)

    def getresponse(self):
        """Send the recorded request without following redirects."""
        send = getattr(self.session, self.verb.lower())
        r = send(
            self._full_url(),
            headers=self.headers,
            data=self.input,
            timeout=self.timeout,
            verify=self.verify,
            allow_redirects=False,
        )
        return RequestsResponse(r)

    def close(self):
        # Intentionally a no-op: the Session is left open so this object can
        # be reused for later requests.
        return


class HTTPSRequestsConnectionClass(_RequestsConnectionBase):
    # mimic the httplib connection object (https, default port 443)
    protocol = "https"
    default_port = 443


class HTTPRequestsConnectionClass(_RequestsConnectionBase):
    # mimic the httplib connection object (http, default port 80)
    protocol = "http"
    default_port = 80
162
163
class Requester:
    """Send authenticated requests to the GitHub API and decode the replies.

    A Requester holds the ``Authorization`` header derived from the supplied
    credentials, the base URL split into scheme/host/port/path prefix, and the
    timeout, retry and TLS-verification settings. The ``request*`` methods
    return the raw ``(status, responseHeaders, body)`` triple; the
    ``request*AndCheck`` variants decode the body and raise a
    ``GithubException`` subclass when the status is >= 400.

    Connection classes and the logger are class attributes so they can be
    replaced via ``injectConnectionClasses`` / ``injectLogger``.
    """

    __httpConnectionClass = HTTPRequestsConnectionClass
    __httpsConnectionClass = HTTPSRequestsConnectionClass
    __connection = None
    __persist = True
    __logger = None

    @classmethod
    def injectConnectionClasses(cls, httpConnectionClass, httpsConnectionClass):
        """Use custom connection classes; a new connection is then made per request."""
        cls.__persist = False
        cls.__httpConnectionClass = httpConnectionClass
        cls.__httpsConnectionClass = httpsConnectionClass

    @classmethod
    def resetConnectionClasses(cls):
        """Restore the default requests-based connection classes and reuse."""
        cls.__persist = True
        cls.__httpConnectionClass = HTTPRequestsConnectionClass
        cls.__httpsConnectionClass = HTTPSRequestsConnectionClass

    @classmethod
    def injectLogger(cls, logger):
        """Send request logs to ``logger`` instead of this module's logger."""
        cls.__logger = logger

    @classmethod
    def resetLogger(cls):
        """Go back to ``logging.getLogger(__name__)`` for request logs."""
        cls.__logger = None

    #############################################################
    # For Debug
    @classmethod
    def setDebugFlag(cls, flag):
        cls.DEBUG_FLAG = flag

    @classmethod
    def setOnCheckMe(cls, onCheckMe):
        cls.ON_CHECK_ME = onCheckMe

    DEBUG_FLAG = False

    DEBUG_FRAME_BUFFER_SIZE = 1024

    DEBUG_HEADER_KEY = "DEBUG_FRAME"

    ON_CHECK_ME = None

    def NEW_DEBUG_FRAME(self, requestHeader):
        """
        Initialize a debug frame with requestHeader
        Frame count is updated and will be attached to respond header
        The structure of a frame: [requestHeader, statusCode, responseHeader, raw_data]
        Some of them may be None

        Frames are kept in a ring buffer of at most DEBUG_FRAME_BUFFER_SIZE
        entries; once full, the oldest slot is overwritten. ``_frameCount``
        always ends up as the index of the frame just written, which is the
        slot DEBUG_ON_RESPONSE fills in.
        """
        if self.DEBUG_FLAG:  # pragma no branch (Flag always set in tests)
            new_frame = [requestHeader, None, None, None]
            if (
                len(self._frameBuffer) < self.DEBUG_FRAME_BUFFER_SIZE
            ):  # pragma no branch (Should be covered)
                self._frameBuffer.append(new_frame)
                self._frameCount = len(self._frameBuffer) - 1
            else:  # pragma no cover (Should be covered)
                # Advance round-robin so the response is attached to this
                # frame rather than to a stale one.
                self._frameCount = (
                    self._frameCount + 1
                ) % self.DEBUG_FRAME_BUFFER_SIZE
                self._frameBuffer[self._frameCount] = new_frame

    def DEBUG_ON_RESPONSE(self, statusCode, responseHeader, data):
        """
        Update current frame with response
        Current frame index will be attached to responseHeader
        """
        if self.DEBUG_FLAG:  # pragma no branch (Flag always set in tests)
            self._frameBuffer[self._frameCount][1:4] = [
                statusCode,
                responseHeader,
                data,
            ]
            responseHeader[self.DEBUG_HEADER_KEY] = self._frameCount

    def check_me(self, obj):
        """Call ON_CHECK_ME(obj, frame) with the debug frame recorded for obj."""
        if (
            self.DEBUG_FLAG and self.ON_CHECK_ME is not None
        ):  # pragma no branch (Flag always set in tests)
            frame = None
            if self.DEBUG_HEADER_KEY in obj._headers:
                frame_index = obj._headers[self.DEBUG_HEADER_KEY]
                frame = self._frameBuffer[frame_index]
            self.ON_CHECK_ME(obj, frame)

    def _initializeDebugFeature(self):
        self._frameCount = 0
        self._frameBuffer = []

    #############################################################

    def __init__(
        self,
        login_or_token,
        password,
        jwt,
        base_url,
        timeout,
        client_id,
        client_secret,
        user_agent,
        per_page,
        verify,
        retry,
    ):
        """
        :param login_or_token: login when ``password`` is given, else an OAuth token
        :param password: password for HTTP Basic auth, or None
        :param jwt: sent as a Bearer token when neither password nor login/token is given
        :param base_url: API root URL; its scheme must be ``http`` or ``https``
        :param timeout: passed to every connection this Requester creates
        :param client_id: OAuth app id, added as a query parameter when set
        :param client_secret: OAuth app secret, added alongside ``client_id``
        :param user_agent: value of the User-Agent header; must not be None
        :param per_page: stored as ``self.per_page`` (presumably the page size
            used by paginated lists — confirm with callers)
        :param verify: TLS verification setting passed to connections
        :param retry: int or urllib3 Retry object passed to connections
        :raises AssertionError: for an unknown URL scheme or a None user agent
        """
        self._initializeDebugFeature()

        if password is not None:
            login = login_or_token
            self.__authorizationHeader = "Basic " + base64.b64encode(
                (login + ":" + password).encode("utf-8")
            ).decode("utf-8").replace("\n", "")
        elif login_or_token is not None:
            token = login_or_token
            self.__authorizationHeader = "token " + token
        elif jwt is not None:
            self.__authorizationHeader = "Bearer " + jwt
        else:
            self.__authorizationHeader = None

        self.__base_url = base_url
        o = urllib.parse.urlparse(base_url)
        self.__hostname = o.hostname
        self.__port = o.port
        self.__prefix = o.path
        self.__timeout = timeout
        self.__retry = retry  # NOTE: retry can be either int or an urllib3 Retry object
        self.__scheme = o.scheme
        if o.scheme == "https":
            self.__connectionClass = self.__httpsConnectionClass
        elif o.scheme == "http":
            self.__connectionClass = self.__httpConnectionClass
        else:
            # Raised explicitly (not ``assert``) so the check survives ``python -O``.
            raise AssertionError("Unknown URL scheme")
        self.rate_limiting = (-1, -1)
        self.rate_limiting_resettime = 0
        self.FIX_REPO_GET_GIT_REF = True
        self.per_page = per_page

        self.oauth_scopes = None

        self.__clientId = client_id
        self.__clientSecret = client_secret

        if user_agent is None:
            raise AssertionError(
                "github now requires a user-agent. "
                "See http://developer.github.com/v3/#user-agent-required"
            )
        self.__userAgent = user_agent
        self.__verify = verify

    def requestJsonAndCheck(self, verb, url, parameters=None, headers=None, input=None):
        """Send a JSON request; return ``(headers, decoded body)`` or raise on status >= 400."""
        return self.__check(
            *self.requestJson(
                verb, url, parameters, headers, input, self.__customConnection(url)
            )
        )

    def requestMultipartAndCheck(
        self, verb, url, parameters=None, headers=None, input=None
    ):
        """Send a multipart/form-data request; return ``(headers, body)`` or raise."""
        return self.__check(
            *self.requestMultipart(
                verb, url, parameters, headers, input, self.__customConnection(url)
            )
        )

    def requestBlobAndCheck(self, verb, url, parameters=None, headers=None, input=None):
        """Upload the file at path ``input``; return ``(headers, body)`` or raise."""
        return self.__check(
            *self.requestBlob(
                verb, url, parameters, headers, input, self.__customConnection(url)
            )
        )

    def __check(self, status, responseHeaders, output):
        output = self.__structuredFromJson(output)
        if status >= 400:
            raise self.__createException(status, responseHeaders, output)
        return responseHeaders, output

    def __customConnection(self, url):
        """Return a one-off connection when ``url`` targets another origin, else None.

        Relative URLs always use the shared connection. An https URL against an
        http base on the same host/port also uses it (issue80).
        """
        cnx = None
        if not url.startswith("/"):
            o = urllib.parse.urlparse(url)
            if (
                o.hostname != self.__hostname
                or (o.port and o.port != self.__port)
                or (
                    o.scheme != self.__scheme
                    and not (o.scheme == "https" and self.__scheme == "http")
                )
            ):  # issue80
                # Forward the same timeout/TLS settings as __createConnection,
                # otherwise e.g. verify=False is silently ignored here.
                kwds = {"timeout": self.__timeout, "verify": self.__verify}
                if o.scheme == "http":
                    cnx = self.__httpConnectionClass(
                        o.hostname, o.port, retry=self.__retry, **kwds
                    )
                elif o.scheme == "https":
                    cnx = self.__httpsConnectionClass(
                        o.hostname, o.port, retry=self.__retry, **kwds
                    )
        return cnx

    def __createException(self, status, headers, output):
        """Map a failed response to the most specific GithubException subclass."""
        # ``output`` comes from __structuredFromJson: None for an empty body,
        # any JSON value, or {"data": text}. Only a dict can carry a message.
        message = output.get("message") if isinstance(output, dict) else None
        if not isinstance(message, str):
            message = ""
        if status == 401 and message == "Bad credentials":
            cls = GithubException.BadCredentialsException
        elif (
            status == 401
            and Consts.headerOTP in headers
            and re.match(r".*required.*", headers[Consts.headerOTP])
        ):
            cls = GithubException.TwoFactorException
        elif status == 403 and message.startswith(
            "Missing or invalid User Agent string"
        ):
            cls = GithubException.BadUserAgentException
        elif status == 403 and (
            message.lower().startswith("api rate limit exceeded")
            or message.lower().endswith(
                "please wait a few minutes before you try again."
            )
        ):
            cls = GithubException.RateLimitExceededException
        elif status == 404 and message == "Not Found":
            cls = GithubException.UnknownObjectException
        else:
            cls = GithubException.GithubException
        return cls(status, output)

    def __structuredFromJson(self, data):
        """Decode a body: None if empty, parsed JSON if valid, else ``{"data": text}``."""
        if not data:
            return None
        if isinstance(data, bytes):
            data = data.decode("utf-8")
        try:
            return json.loads(data)
        except ValueError:
            return {"data": data}

    def requestJson(
        self, verb, url, parameters=None, headers=None, input=None, cnx=None
    ):
        """Send ``input`` (if any) serialized with ``json.dumps``."""

        def encode(input):
            return "application/json", json.dumps(input)

        return self.__requestEncode(cnx, verb, url, parameters, headers, input, encode)

    def requestMultipart(
        self, verb, url, parameters=None, headers=None, input=None, cnx=None
    ):
        """Send ``input`` (a dict of str values) as multipart/form-data."""

        def encode(input):
            boundary = "----------------------------3c3ba8b523b2"
            eol = "\r\n"

            encoded_input = ""
            for name, value in input.items():
                encoded_input += "--" + boundary + eol
                encoded_input += (
                    'Content-Disposition: form-data; name="' + name + '"' + eol
                )
                encoded_input += eol
                encoded_input += value + eol
            encoded_input += "--" + boundary + "--" + eol
            return "multipart/form-data; boundary=" + boundary, encoded_input

        return self.__requestEncode(cnx, verb, url, parameters, headers, input, encode)

    def requestBlob(
        self, verb, url, parameters=None, headers=None, input=None, cnx=None
    ):
        """Upload the file at local path ``input`` as the request body.

        The Content-Type is taken from ``headers`` when present, otherwise
        guessed from the file name (falling back to Consts.defaultMediaType).
        ``headers`` is updated in place, as for the other request methods.
        """
        # None defaults (instead of shared {} literals) so headers added for
        # one upload never leak into the next one.
        if headers is None:
            headers = {}

        def encode(local_path):
            if "Content-Type" in headers:
                mime_type = headers["Content-Type"]
            else:
                guessed_type = mimetypes.guess_type(local_path)[0]
                mime_type = (
                    guessed_type
                    if guessed_type is not None
                    else Consts.defaultMediaType
                )
            # Closed by __requestRaw once the request has been sent.
            f = open(local_path, "rb")
            return mime_type, f

        if input:
            headers["Content-Length"] = str(os.path.getsize(input))
        return self.__requestEncode(cnx, verb, url, parameters, headers, input, encode)

    def requestMemoryBlobAndCheck(
        self, verb, url, parameters, headers, file_like, cnx=None
    ):
        """Send ``file_like`` as the body using ``headers["Content-Type"]``."""

        # The expected signature of encode means that the argument is ignored.
        def encode(_):
            return headers["Content-Type"], file_like

        if not cnx:
            cnx = self.__customConnection(url)
        return self.__check(
            *self.__requestEncode(
                cnx, verb, url, parameters, headers, file_like, encode
            )
        )

    def __requestEncode(
        self, cnx, verb, url, parameters, requestHeaders, input, encode
    ):
        """Authenticate, encode and send a request; update rate-limit/scope info."""
        assert verb in ["HEAD", "GET", "POST", "PATCH", "PUT", "DELETE"]
        if parameters is None:
            parameters = dict()
        if requestHeaders is None:
            requestHeaders = dict()

        self.__authenticate(url, requestHeaders, parameters)
        requestHeaders["User-Agent"] = self.__userAgent

        url = self.__makeAbsoluteUrl(url)
        url = self.__addParametersToUrl(url, parameters)

        encoded_input = None
        if input is not None:
            requestHeaders["Content-Type"], encoded_input = encode(input)

        self.NEW_DEBUG_FRAME(requestHeaders)

        status, responseHeaders, output = self.__requestRaw(
            cnx, verb, url, requestHeaders, encoded_input
        )

        if (
            Consts.headerRateRemaining in responseHeaders
            and Consts.headerRateLimit in responseHeaders
        ):
            self.rate_limiting = (
                int(responseHeaders[Consts.headerRateRemaining]),
                int(responseHeaders[Consts.headerRateLimit]),
            )
        if Consts.headerRateReset in responseHeaders:
            self.rate_limiting_resettime = int(responseHeaders[Consts.headerRateReset])

        if Consts.headerOAuthScopes in responseHeaders:
            self.oauth_scopes = responseHeaders[Consts.headerOAuthScopes].split(", ")

        self.DEBUG_ON_RESPONSE(status, responseHeaders, output)

        return status, responseHeaders, output

    def __requestRaw(self, cnx, verb, url, requestHeaders, input):
        """Send one request; re-issue it on 202 (GET/HEAD only) and on 301.

        Uses ``cnx`` when given, otherwise the shared connection. File-like
        ``input`` is closed after the response is read. Header names in the
        returned dict are lower-cased.
        """
        original_cnx = cnx
        if cnx is None:
            cnx = self.__createConnection()
        cnx.request(verb, url, input, requestHeaders)
        response = cnx.getresponse()

        status = response.status
        responseHeaders = dict((k.lower(), v) for k, v in response.getheaders())
        output = response.read()

        cnx.close()
        if input:
            if isinstance(input, IOBase):
                input.close()

        self.__log(verb, url, requestHeaders, input, status, responseHeaders, output)

        if status == 202 and (
            verb == "GET" or verb == "HEAD"
        ):  # only for requests that are considered 'safe' in RFC 2616
            time.sleep(Consts.PROCESSING_202_WAIT_TIME)
            return self.__requestRaw(original_cnx, verb, url, requestHeaders, input)

        if status == 301 and "location" in responseHeaders:
            o = urllib.parse.urlparse(responseHeaders["location"])
            # Keep the Location's query string (e.g. page parameters).
            redirect_url = o.path
            if o.query:
                redirect_url += "?" + o.query
            return self.__requestRaw(
                original_cnx, verb, redirect_url, requestHeaders, input
            )

        return status, responseHeaders, output

    def __authenticate(self, url, requestHeaders, parameters):
        """Add OAuth app query parameters and the Authorization header, if configured."""
        if self.__clientId and self.__clientSecret and "client_id=" not in url:
            parameters["client_id"] = self.__clientId
            parameters["client_secret"] = self.__clientSecret
        if self.__authorizationHeader is not None:
            requestHeaders["Authorization"] = self.__authorizationHeader

    def __makeAbsoluteUrl(self, url):
        # URLs generated locally will be relative to __base_url
        # URLs returned from the server will start with __base_url
        if url.startswith("/"):
            url = self.__prefix + url
        else:
            o = urllib.parse.urlparse(url)
            assert o.hostname in [
                self.__hostname,
                "uploads.github.com",
                "status.github.com",
                "github.com",
            ], o.hostname
            assert o.path.startswith((self.__prefix, "/api/"))
            assert o.port == self.__port
            url = o.path
            if o.query != "":
                url += "?" + o.query
        return url

    def __addParametersToUrl(self, url, parameters):
        """Append ``parameters`` as an encoded query string to ``url``."""
        if len(parameters) == 0:
            return url
        # __makeAbsoluteUrl may already have kept a query string on the URL;
        # a second "?" would produce a malformed URL.
        separator = "&" if "?" in url else "?"
        return url + separator + urllib.parse.urlencode(parameters)

    def __createConnection(self):
        """Return the shared connection, creating it first if needed.

        When connection classes were injected (``__persist`` is False) a new
        connection is created on every call.
        """
        if self.__persist and self.__connection is not None:
            return self.__connection

        self.__connection = self.__connectionClass(
            self.__hostname,
            self.__port,
            retry=self.__retry,
            timeout=self.__timeout,
            verify=self.__verify,
        )

        return self.__connection

    def __log(self, verb, url, requestHeaders, input, status, responseHeaders, output):
        """Log one request/response at DEBUG level with credentials redacted.

        Redaction is done on a copy: ``requestHeaders`` is the same dict that
        __requestRaw re-sends on 202/301, so it must keep the real value.
        """
        logger = self.__logger
        if logger is None:
            # Looked up per call rather than cached on the instance, so later
            # injectLogger()/resetLogger() calls still take effect.
            logger = logging.getLogger(__name__)
        if not logger.isEnabledFor(logging.DEBUG):
            return
        loggedHeaders = dict(requestHeaders)
        authorization = loggedHeaders.get("Authorization")
        if authorization is not None:
            if authorization.startswith("Basic"):
                loggedHeaders["Authorization"] = "Basic (login and password removed)"
            elif authorization.startswith("token"):
                loggedHeaders["Authorization"] = "token (oauth token removed)"
            elif authorization.startswith("Bearer"):
                loggedHeaders["Authorization"] = "Bearer (jwt removed)"
            else:  # pragma no cover (Cannot happen, but could if we add an authentication method => be prepared)
                loggedHeaders[
                    "Authorization"
                ] = "(unknown auth removed)"  # pragma no cover (Cannot happen, but could if we add an authentication method => be prepared)
        logger.debug(
            "%s %s://%s%s %s %s ==> %i %s %s",
            verb,
            self.__scheme,
            self.__hostname,
            url,
            loggedHeaders,
            input,
            status,
            responseHeaders,
            output,
        )
614