1# -*- coding: utf-8 -*-
2import requests
3from packaging.specifiers import SpecifierSet
4from .errors import DatabaseFetchError, InvalidKeyError, DatabaseFileNotFoundError
5from .constants import OPEN_MIRRORS, API_MIRRORS, REQUEST_TIMEOUT, CACHE_VALID_SECONDS, CACHE_FILE
6from collections import namedtuple
7import os
8import json
9import time
10import errno
11
12
class Vulnerability(namedtuple("Vulnerability", "name spec version advisory vuln_id")):
    """Immutable record for one vulnerability matched against an installed package.

    Fields (as populated by ``check``):
        name: normalized package name (lowercase, underscores turned into dashes).
        spec: the vulnerable version specifier string that matched.
        version: the installed version of the package.
        advisory: advisory text taken from the full database entry.
        vuln_id: vulnerability identifier with any "pyup.io-" prefix removed.
    """
16
17
def get_from_cache(db_name):
    """Return the cached copy of *db_name* if a fresh one exists, else ``False``.

    The cache file (``CACHE_FILE``) is a JSON object keyed by database file
    name. Each entry holds ``cached_at`` (a ``time.time()`` timestamp) and
    ``db`` (the cached payload). An entry is fresh while
    ``cached_at + CACHE_VALID_SECONDS`` is still in the future.

    A missing, unreadable or malformed cache counts as a cache miss rather
    than an error, so the caller can always fall back to fetching the data.
    """
    try:
        with open(CACHE_FILE, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: the file is absent (possibly removed between checks) or
        # unreadable. ValueError covers json.JSONDecodeError and
        # UnicodeDecodeError from a truncated or corrupted file.
        return False

    entry = data.get(db_name) if isinstance(data, dict) else None
    if not isinstance(entry, dict) or "db" not in entry:
        return False

    cached_at = entry.get("cached_at")
    if not isinstance(cached_at, (int, float)):
        # Missing or non-numeric timestamp: the freshness cannot be judged.
        return False

    if cached_at + CACHE_VALID_SECONDS > time.time():
        return entry["db"]
    return False
30
31
def write_to_cache(db_name, data):
    """Store *data* under *db_name* in the JSON cache file with the current time.

    The cache file (``CACHE_FILE``) holds one entry per database file name::

        {
          "insecure.json": {"cached_at": 12345678, "db": {}},
          "insecure_full.json": {"cached_at": 12345678, "db": {}},
        }

    Entries for other database names are preserved. A missing, unreadable or
    malformed cache file is replaced with a fresh one instead of raising.
    """
    cache_dir = os.path.dirname(CACHE_FILE)
    if cache_dir:
        # exist_ok avoids the check-then-create race when several processes
        # initialise the cache directory at the same time. A bare filename has
        # no directory component, so there is nothing to create.
        os.makedirs(cache_dir, exist_ok=True)

    try:
        with open(CACHE_FILE, "r", encoding="utf-8") as f:
            cache = json.load(f)
    except (OSError, ValueError):
        # The file may not exist yet (e.g. the directory survived but the file
        # was deleted) or may contain partial/corrupted JSON.
        cache = {}
    if not isinstance(cache, dict):
        cache = {}

    cache[db_name] = {
        "cached_at": time.time(),
        "db": data,
    }
    # Serialise before opening for write, so a serialisation failure cannot
    # leave behind a truncated cache file.
    payload = json.dumps(cache)
    with open(CACHE_FILE, "w", encoding="utf-8") as f:
        f.write(payload)
66
67
def fetch_database_url(mirror, db_name, key, cached, proxy):
    """Download *db_name* from the HTTP(S) *mirror*, optionally through the cache.

    Args:
        mirror: base URL; *db_name* is appended to it verbatim.
        db_name: database file name, e.g. "insecure.json".
        key: API key sent as the ``X-Api-Key`` header when truthy.
        cached: when true, serve a fresh cached copy if one exists and store
            successful downloads in the cache.
        proxy: proxies mapping handed to ``requests.get``.

    Returns:
        The decoded JSON body on HTTP 200. Returns ``None`` for any status
        other than 200 or 403; ``fetch_database`` treats that as a failed mirror.

    Raises:
        InvalidKeyError: when the server answers 403.
    """
    if cached:
        hit = get_from_cache(db_name=db_name)
        if hit:
            return hit

    request_headers = {"X-Api-Key": key} if key else {}
    response = requests.get(
        url=mirror + db_name,
        timeout=REQUEST_TIMEOUT,
        headers=request_headers,
        proxies=proxy,
    )

    if response.status_code == 403:
        raise InvalidKeyError()
    if response.status_code != 200:
        return None

    payload = response.json()
    if cached:
        write_to_cache(db_name, payload)
    return payload
87
88
def fetch_database_file(path, db_name):
    """Load the JSON database *db_name* from the local directory *path*.

    Returns:
        The decoded JSON content of ``path/db_name``.

    Raises:
        DatabaseFileNotFoundError: if ``path/db_name`` does not exist.
    """
    full_path = os.path.join(path, db_name)
    if not os.path.exists(full_path):
        raise DatabaseFileNotFoundError()
    # JSON is UTF-8 by specification (RFC 8259). Without an explicit encoding,
    # open() uses the locale encoding (e.g. cp1252 on Windows), which can
    # mis-decode or reject non-ASCII text in the database.
    with open(full_path, encoding="utf-8") as f:
        return json.load(f)
95
96
def fetch_database(full=False, key=False, db=False, cached=False, proxy=None):
    """Return the vulnerability database from the first mirror that yields data.

    Args:
        full: fetch "insecure_full.json" instead of "insecure.json".
        key: API key. When truthy and *db* is not given, API_MIRRORS are used
            instead of OPEN_MIRRORS.
        db: a single mirror (URL or local directory) that replaces the defaults.
        cached: let URL mirrors read from and write to the local cache.
        proxy: proxies mapping for URL mirrors. ``None`` (the default) means
            no proxies.

    Raises:
        DatabaseFetchError: if no mirror returns non-empty data. Errors raised
            by an individual mirror (e.g. InvalidKeyError,
            DatabaseFileNotFoundError) propagate immediately.
    """
    # A shared mutable {} default is an anti-pattern; build a fresh mapping per call.
    if proxy is None:
        proxy = {}

    mirrors = [db] if db else (API_MIRRORS if key else OPEN_MIRRORS)
    db_name = "insecure_full.json" if full else "insecure.json"
    for mirror in mirrors:
        # mirror can either be a local path or a URL
        if mirror.startswith(("http://", "https://")):
            data = fetch_database_url(mirror, db_name=db_name, key=key, cached=cached, proxy=proxy)
        else:
            data = fetch_database_file(mirror, db_name=db_name)
        if data:
            return data
    raise DatabaseFetchError()
114
115
def get_vulnerabilities(pkg, spec, db):
    """Lazily yield every entry of ``db[pkg]`` whose "specs" list contains *spec*.

    An entry is yielded once for each element of its "specs" that equals
    *spec*, so duplicated specs produce duplicated results. Looking up
    ``db[pkg]`` happens on first iteration and raises KeyError if *pkg* is
    absent.
    """
    yield from (
        advisory
        for advisory in db[pkg]
        for candidate in advisory["specs"]
        if candidate == spec
    )
121
122
def check(packages, key, db_mirror, cached, ignore_ids, proxy):
    """Return the known vulnerabilities that affect *packages*.

    Args:
        packages: iterable of installed packages. Each must expose ``.key``
            (project name) and ``.version`` (version string); presumably
            pkg_resources distributions, so verify against the caller.
        key: API key. When falsy, the SAFETY_API_KEY environment variable is used.
        db_mirror: optional mirror (URL or local directory) passed to fetch_database.
        cached: allow use of the local database cache.
        ignore_ids: vulnerability ids (without the "pyup.io-" prefix) to skip.
            ``None`` is treated as "ignore nothing".
        proxy: proxies mapping for URL mirrors.

    Returns:
        A list of Vulnerability records, in package order.
    """
    key = key if key else os.environ.get("SAFETY_API_KEY", False)
    ignore_ids = ignore_ids or ()
    db = fetch_database(key=key, db=db_mirror, cached=cached, proxy=proxy)
    # The full database (with advisories) is only downloaded once some package
    # actually matches a vulnerable specifier, then reused for later matches.
    db_full = None
    vulnerable = []
    for pkg in packages:
        # normalize the package name, the safety-db is converting underscores to dashes and uses
        # lowercase
        name = pkg.key.replace("_", "-").lower()
        if name not in db:
            continue

        # we have a candidate here, test each vulnerable specifier against it
        for specifier in db[name]:
            if not SpecifierSet(specifiers=specifier).contains(pkg.version):
                continue
            if db_full is None:
                db_full = fetch_database(full=True, key=key, db=db_mirror, cached=cached, proxy=proxy)
            for data in get_vulnerabilities(pkg=name, spec=specifier, db=db_full):
                raw_id = data.get("id")
                if not raw_id:
                    # Without an id the entry can be neither reported nor
                    # ignored; previously this crashed with AttributeError.
                    continue
                vuln_id = raw_id.replace("pyup.io-", "")
                if vuln_id and vuln_id not in ignore_ids:
                    vulnerable.append(
                        Vulnerability(
                            name=name,
                            spec=specifier,
                            version=pkg.version,
                            advisory=data.get("advisory"),
                            vuln_id=vuln_id,
                        )
                    )
    return vulnerable
154
155
def review(vulnerabilities):
    """Rebuild Vulnerability records from plain indexable sequences.

    Each item is read positionally as (name, spec, version, advisory,
    vuln_id); elements beyond index 4 are ignored. An item shorter than
    five elements raises IndexError.
    """
    return [
        Vulnerability(
            name=item[0],
            spec=item[1],
            version=item[2],
            advisory=item[3],
            vuln_id=item[4],
        )
        for item in vulnerabilities
    ]
170