# -*- coding: utf-8 -*-
"""Fetch the vulnerability database (with optional on-disk caching) and
check a set of installed packages against it."""
from collections import namedtuple
import errno
import json
import os
import time

import requests
from packaging.specifiers import SpecifierSet

from .constants import OPEN_MIRRORS, API_MIRRORS, REQUEST_TIMEOUT, CACHE_VALID_SECONDS, CACHE_FILE
from .errors import DatabaseFetchError, InvalidKeyError, DatabaseFileNotFoundError


class Vulnerability(namedtuple("Vulnerability",
                               ["name", "spec", "version", "advisory", "vuln_id"])):
    """A single match between a package version and a database entry.

    Fields: the normalized package ``name``, the matching ``spec`` string,
    the checked ``version``, the ``advisory`` text and the ``vuln_id``.
    """
    pass


def _read_cache_file():
    """Return the parsed content of CACHE_FILE as a dict.

    A missing file, a file that is not valid JSON, or a JSON document that is
    not an object all yield an empty dict. Other I/O errors propagate.
    """
    try:
        with open(CACHE_FILE) as f:
            content = json.loads(f.read())
    except (IOError, OSError) as exc:
        # Only "file does not exist" means "no cache yet"; anything else
        # (permissions, ...) is a real problem the caller should see.
        if exc.errno != errno.ENOENT:
            raise
        return {}
    except json.JSONDecodeError:
        return {}
    return content if isinstance(content, dict) else {}


def get_from_cache(db_name):
    """Return the cached database stored under ``db_name``.

    Returns ``False`` when there is no usable entry: no cache file, an
    unreadable cache, no entry for ``db_name``, an entry without a
    ``cached_at`` timestamp or ``db`` payload, or an entry older than
    CACHE_VALID_SECONDS.
    """
    entry = _read_cache_file().get(db_name)
    if not isinstance(entry, dict) or "cached_at" not in entry:
        return False
    if entry["cached_at"] + CACHE_VALID_SECONDS > time.time():
        return entry.get("db", False)
    return False


def write_to_cache(db_name, data):
    """Store ``data`` under ``db_name`` in CACHE_FILE, stamped with the current time.

    Entries for other database names already in the cache file are kept.
    The directory holding CACHE_FILE is created when it does not exist.
    """
    # cache is in: ~/safety/cache.json
    # and has the following form:
    # {
    #   "insecure.json": {
    #       "cached_at": 12345678,
    #       "db": {}
    #   },
    #   "insecure_full.json": {
    #       "cached_at": 12345678,
    #       "db": {}
    #   }
    # }
    cache_dir = os.path.dirname(CACHE_FILE)
    if cache_dir and not os.path.exists(cache_dir):
        try:
            os.makedirs(cache_dir)
        except OSError as exc:  # Guard against race condition
            if exc.errno != errno.EEXIST:
                raise

    # Tolerates a missing or corrupt cache file, so this also works when the
    # directory already existed but the file was never written.
    cache = _read_cache_file()
    cache[db_name] = {
        "cached_at": time.time(),
        "db": data
    }
    # Serialize before opening for writing so that a serialization error
    # cannot leave a truncated cache file behind.
    payload = json.dumps(cache)
    with open(CACHE_FILE, "w") as f:
        f.write(payload)


def fetch_database_url(mirror, db_name, key, cached, proxy):
    """Download ``db_name`` from the URL prefix ``mirror``.

    When ``cached`` is truthy, a valid cache entry is returned without any
    request, and a freshly downloaded database is written to the cache.
    ``key``, when truthy, is sent as the ``X-Api-Key`` header; ``proxy`` is
    handed to ``requests.get`` as ``proxies``.

    Returns the decoded JSON body on HTTP 200 and ``None`` for any status
    other than 200 or 403.

    Raises:
        InvalidKeyError: the server answered with HTTP 403.
    """
    headers = {}
    if key:
        headers["X-Api-Key"] = key

    if cached:
        cached_data = get_from_cache(db_name=db_name)
        if cached_data:
            return cached_data

    url = mirror + db_name
    r = requests.get(url=url, timeout=REQUEST_TIMEOUT, headers=headers, proxies=proxy)
    if r.status_code == 403:
        raise InvalidKeyError()
    if r.status_code != 200:
        # The caller treats a falsy result as "try the next mirror".
        return None

    data = r.json()
    if cached:
        write_to_cache(db_name, data)
    return data


def fetch_database_file(path, db_name):
    """Load ``db_name`` from the local directory ``path`` and return the parsed JSON.

    Raises:
        DatabaseFileNotFoundError: ``path/db_name`` does not exist.
    """
    full_path = os.path.join(path, db_name)
    if not os.path.exists(full_path):
        raise DatabaseFileNotFoundError()
    with open(full_path) as f:
        return json.loads(f.read())


def fetch_database(full=False, key=False, db=False, cached=False, proxy=None):
    """Return the vulnerability database from the first source that yields data.

    ``db``, when truthy, is the only source used; otherwise API_MIRRORS is
    used when ``key`` is truthy and OPEN_MIRRORS when it is not. A source
    starting with ``http://`` or ``https://`` is fetched over the network,
    anything else is treated as a local directory. ``full`` selects
    ``insecure_full.json`` instead of ``insecure.json``.

    Raises:
        DatabaseFetchError: no source returned data.
    """
    # Avoid a shared mutable default; an empty dict is what the old default was.
    if proxy is None:
        proxy = {}

    if db:
        mirrors = [db]
    else:
        mirrors = API_MIRRORS if key else OPEN_MIRRORS

    db_name = "insecure_full.json" if full else "insecure.json"
    for mirror in mirrors:
        # mirror can either be a local path or a URL
        if mirror.startswith(("http://", "https://")):
            data = fetch_database_url(mirror, db_name=db_name, key=key, cached=cached, proxy=proxy)
        else:
            data = fetch_database_file(mirror, db_name=db_name)
        if data:
            return data
    raise DatabaseFetchError()


def get_vulnerabilities(pkg, spec, db):
    """Yield entries of ``db[pkg]`` once for every item of ``entry["specs"]`` equal to ``spec``.

    ``db[pkg]`` is looked up directly, so a package missing from ``db``
    raises ``KeyError`` rather than silently producing no results.
    """
    for entry in db[pkg]:
        for entry_spec in entry["specs"]:
            if entry_spec == spec:
                yield entry


def check(packages, key, db_mirror, cached, ignore_ids, proxy):
    """Return a list of ``Vulnerability`` tuples for the given packages.

    Each item of ``packages`` must expose ``key`` (its name) and ``version``.
    ``key`` falls back to the ``SAFETY_API_KEY`` environment variable when
    falsy. ``db_mirror``, ``cached`` and ``proxy`` are forwarded to
    ``fetch_database``. Entries whose id (with the ``pyup.io-`` prefix
    removed) is in ``ignore_ids``, or that have no id, are left out.
    """
    key = key if key else os.environ.get("SAFETY_API_KEY", False)
    db = fetch_database(key=key, db=db_mirror, cached=cached, proxy=proxy)
    db_full = None  # fetched lazily, only once a package actually matches
    vulnerable = []
    for pkg in packages:
        # normalize the package name, the safety-db is converting underscores to dashes and uses
        # lowercase
        name = pkg.key.replace("_", "-").lower()
        if name not in db:
            continue

        # we have a candidate here, build the spec set
        for specifier in db[name]:
            if not SpecifierSet(specifiers=specifier).contains(pkg.version):
                continue
            if db_full is None:
                db_full = fetch_database(full=True, key=key, db=db_mirror, cached=cached, proxy=proxy)
            for data in get_vulnerabilities(pkg=name, spec=specifier, db=db_full):
                # An entry without an id yields an empty string here and is
                # skipped below instead of raising AttributeError.
                vuln_id = (data.get("id") or "").replace("pyup.io-", "")
                if vuln_id and vuln_id not in ignore_ids:
                    vulnerable.append(
                        Vulnerability(
                            name=name,
                            spec=specifier,
                            version=pkg.version,
                            advisory=data.get("advisory"),
                            vuln_id=vuln_id
                        )
                    )
    return vulnerable


def review(vulnerabilities):
    """Convert indexable records into a list of ``Vulnerability`` tuples.

    Each record is read by position: name, spec, version, advisory, vuln_id.
    """
    return [
        Vulnerability(
            name=vuln[0],
            spec=vuln[1],
            version=vuln[2],
            advisory=vuln[3],
            vuln_id=vuln[4],
        )
        for vuln in vulnerabilities
    ]