#!/usr/bin/python
##
## license:BSD-3-Clause
## copyright-holders:Vas Crabb

from . import dbaccess

import codecs
import contextlib
import hashlib
import os
import os.path
import struct
import sys
import zlib


class _Identifier(object):
    """Identify ROM dumps and CHD files against the database.

    Plain files are digested (CRC32 and SHA1); files with a .chd extension
    and a recognised CHD header are identified by the SHA1 stored in the
    header.  Digests are looked up through the supplied cursor, matches are
    grouped by system and by software list/item/part, files with no match
    are collected separately, and printResults() writes the report to
    stdout.
    """

    # Directories nested more than this many levels below a starting path
    # are not descended into.
    MAX_DEPTH = 5

    # CHD header version -> (required header length, offset of the SHA1
    # digest used for identification) for the header versions recognised.
    _CHD_SHA1_LOCATIONS = {
            3: (120, 80),
            4: (108, 48),
            5: (124, 84) }

    def __init__(self, dbcurs, **kwargs):
        super(_Identifier, self).__init__(**kwargs)
        self.dbcurs = dbcurs
        # Column widths for the report, grown as files are processed.
        self.shortnamewidth = 0
        self.pathwidth = 0
        self.labelwidth = 0
        # shortname -> (description, [(path, label, bad), ...])
        self.machines = { }
        # list name -> (list description,
        #               {shortname -> (description,
        #                              {part -> (part_id, [(path, label, bad), ...])})})
        self.software = { }
        # [(path, crc or None, sha1), ...] for files with no database match
        self.unmatched = [ ]

    def processPath(self, path, depth=0):
        """Identify a file, or recursively every file under a directory.

        path:  file or directory to examine.
        depth: nesting level of path below the original starting path.

        Directories more than MAX_DEPTH levels below the starting path are
        skipped with a warning.  Errors are reported to stderr and do not
        stop the remaining paths from being processed.
        """
        try:
            if not os.path.isdir(path):
                self.processFile(path)
            elif depth > self.MAX_DEPTH:
                sys.stderr.write('Not examining \'%s\' - maximum depth exceeded\n' % (path, ))
            else:
                # Sorted so traversal order (and hence the order of
                # unmatched files in the report) is deterministic.
                for name in sorted(os.listdir(path)):
                    self.processPath(os.path.join(path, name), depth + 1)
        except Exception as e:
            # Exception rather than BaseException, so KeyboardInterrupt and
            # SystemExit still abort the scan instead of being swallowed.
            sys.stderr.write('Error identifying \'%s\': %s\n' % (path, e))

    def printResults(self):
        """Write the identification report to stdout.

        Matched systems come first, then software items (as list:shortname
        with an indented line per part), then any unmatched files.  Groups
        are separated by blank lines.
        """
        # Round each column width up to the next multiple of 4, always
        # leaving at least one column of padding.
        nw = self.shortnamewidth - (self.shortnamewidth % 4) + 4
        pw = self.pathwidth - (self.pathwidth % 4) + 4
        lw = self.labelwidth - (self.labelwidth % 4) + 4
        first = True
        for shortname, romset in sorted(self.machines.items()):
            if first:
                first = False
            else:
                sys.stdout.write('\n')
            sys.stdout.write('%-*s%s\n' % (nw, shortname, romset[0]))
            self.printMatches(romset[1], pw, lw)
        for softwarelist, listinfo in sorted(self.software.items()):
            for shortname, softwareinfo in sorted(listinfo[1].items()):
                if first:
                    first = False
                else:
                    sys.stdout.write('\n')
                sys.stdout.write('%-*s%s\n' % (nw, '%s:%s' % (softwarelist, shortname), softwareinfo[0]))
                for part, partinfo in sorted(softwareinfo[1].items()):
                    if partinfo[0] is not None:
                        sys.stdout.write('%-*s%s\n' % (nw, '  ' + part, partinfo[0]))
                    else:
                        sys.stdout.write('  %s\n' % (part, ))
                    self.printMatches(partinfo[1], pw, lw)
        if self.unmatched:
            if first:
                first = False
            else:
                sys.stdout.write('\n')
            sys.stdout.write('Unmatched\n')
            for path, crc, sha1 in self.unmatched:
                # crc is None for files identified by CHD header SHA1 only
                if crc is not None:
                    sys.stdout.write('    %-*sCRC(%08x) SHA1(%s)\n' % (pw, path, crc, sha1))
                else:
                    sys.stdout.write('    %-*sSHA1(%s)\n' % (pw, path, sha1))

    def getMachineMatches(self, shortname, description):
        """Return the list of matches recorded for a system, creating it if needed.

        The description given when the system is first seen is kept.
        """
        return self.machines.setdefault(shortname, (description, [ ]))[1]

    def getSoftwareMatches(self, softwarelist, softwarelistdescription, shortname, description, part, part_id):
        """Return the list of matches recorded for a software part, creating it if needed.

        Intermediate list/item/part entries are created on demand; the
        descriptions given when each entry is first seen are kept.
        """
        listinfo = self.software.setdefault(softwarelist, (softwarelistdescription, { }))
        softwareinfo = listinfo[1].setdefault(shortname, (description, { }))
        partinfo = softwareinfo[1].setdefault(part, (part_id, [ ]))
        return partinfo[1]

    def processRomFile(self, path, f):
        """Digest an open file and record its system/software ROM matches.

        The file is added to the unmatched list if neither lookup finds it.
        """
        crc, sha1 = self.digestRom(f)
        matched = False
        # Each result set is consumed completely before the next query is
        # issued through the same cursor object.
        for shortname, description, label, bad in self.dbcurs.get_rom_dumps(crc, sha1):
            matched = True
            self.shortnamewidth = max(len(shortname), self.shortnamewidth)
            self.labelwidth = max(len(label), self.labelwidth)
            self.getMachineMatches(shortname, description).append((path, label, bad))
        for softwarelist, softwarelistdescription, shortname, description, part, part_id, label, bad in self.dbcurs.get_software_rom_dumps(crc, sha1):
            matched = True
            # Width must fit both "list:shortname" and the indented part line.
            self.shortnamewidth = max(len(softwarelist) + 1 + len(shortname), 2 + len(part), self.shortnamewidth)
            self.labelwidth = max(len(label), self.labelwidth)
            self.getSoftwareMatches(softwarelist, softwarelistdescription, shortname, description, part, part_id).append((path, label, bad))
        if not matched:
            self.unmatched.append((path, crc, sha1))

    def processChd(self, path, sha1):
        """Record system/software disk matches for a CHD identified by its header SHA1.

        The file is added to the unmatched list (with no CRC) if neither
        lookup finds it.
        """
        matched = False
        for shortname, description, label, bad in self.dbcurs.get_disk_dumps(sha1):
            matched = True
            self.shortnamewidth = max(len(shortname), self.shortnamewidth)
            self.labelwidth = max(len(label), self.labelwidth)
            self.getMachineMatches(shortname, description).append((path, label, bad))
        for softwarelist, softwarelistdescription, shortname, description, part, part_id, label, bad in self.dbcurs.get_software_disk_dumps(sha1):
            matched = True
            self.shortnamewidth = max(len(softwarelist) + 1 + len(shortname), 2 + len(part), self.shortnamewidth)
            self.labelwidth = max(len(label), self.labelwidth)
            self.getSoftwareMatches(softwarelist, softwarelistdescription, shortname, description, part, part_id).append((path, label, bad))
        if not matched:
            self.unmatched.append((path, None, sha1))

    def processFile(self, path):
        """Identify a single file.

        Files with a .chd extension whose header is recognised are looked
        up by the header SHA1; everything else (including .chd files with
        an unrecognised header) is digested as a ROM dump.
        """
        if os.path.splitext(path)[1].lower() != '.chd':
            # Unbuffered: digestRom already reads in 64 KiB blocks.
            with open(path, mode='rb', buffering=0) as f:
                self.processRomFile(path, f)
        else:
            with open(path, mode='rb') as f:
                sha1 = self.probeChd(f)
                if sha1 is None:
                    f.seek(0)
                    self.processRomFile(path, f)
                else:
                    self.processChd(path, sha1)
        self.pathwidth = max(len(path), self.pathwidth)

    @staticmethod
    def iterateBlocks(f, s=65536):
        """Yield successive reads of up to s bytes from f until it returns nothing."""
        while True:
            buf = f.read(s)
            if buf:
                yield buf
            else:
                break

    @staticmethod
    def digestRom(f):
        """Return (crc32, sha1_hex) for the remaining contents of f.

        The CRC is masked to an unsigned 32-bit value so the result is the
        same on every Python version (Python 2's zlib.crc32 may be signed).
        """
        crc = zlib.crc32(bytes())
        sha = hashlib.sha1()
        for block in _Identifier.iterateBlocks(f):
            crc = zlib.crc32(block, crc)
            sha.update(block)
        return crc & 0xffffffff, sha.hexdigest()

    @staticmethod
    def probeChd(f):
        """Return the identifying SHA1 from a CHD header as hex, or None.

        None is returned when the file does not start with the 'MComprHD'
        tag, the header version is not 3, 4 or 5, the header length does
        not match that version, or the file is too short.
        """
        buf = f.read(16)
        if (len(buf) != 16) or (buf[:8] != b'MComprHD'):
            return None
        # Big-endian header length and version follow the 8-byte tag.
        headerlen, version = struct.unpack('>II', buf[8:])
        location = _Identifier._CHD_SHA1_LOCATIONS.get(version)
        if (location is None) or (headerlen != location[0]):
            return None
        sha1offs = location[1]
        f.seek(sha1offs)
        if f.tell() != sha1offs:
            return None
        buf = f.read(20)
        if len(buf) != 20:
            return None
        return codecs.getencoder('hex_codec')(buf)[0].decode('ascii')

    @staticmethod
    def printMatches(matches, pathwidth, labelwidth):
        """Write one "path = label" line per match, flagging bad dumps."""
        for path, label, bad in matches:
            if bad:
                sys.stdout.write('    %-*s= %-*s(BAD)\n' % (pathwidth, path, labelwidth, label))
            else:
                sys.stdout.write('    %-*s= %s\n' % (pathwidth, path, label))


@contextlib.contextmanager
def _open_cursor(database):
    """Open a query connection and cursor, closing both (cursor first) on exit."""
    dbconn = dbaccess.QueryConnection(database)
    try:
        dbcurs = dbconn.cursor()
        try:
            yield dbcurs
        finally:
            dbcurs.close()
    finally:
        dbconn.close()


def do_listfull(options):
    """List systems matching options.pattern with their descriptions."""
    with _open_cursor(options.database) as dbcurs:
        first = True
        for shortname, description in dbcurs.listfull(options.pattern):
            if first:
                sys.stdout.write('Name:             Description:\n')
                first = False
            sys.stdout.write('%-16s  "%s"\n' % (shortname, description))
        if first:
            sys.stderr.write('No matching systems found for \'%s\'\n' % (options.pattern, ))


def do_listsource(options):
    """List systems matching options.pattern with their source files."""
    with _open_cursor(options.database) as dbcurs:
        shortname = None
        for shortname, sourcefile in dbcurs.listsource(options.pattern):
            sys.stdout.write('%-16s %s\n' % (shortname, sourcefile))
        if shortname is None:
            sys.stderr.write('No matching systems found for \'%s\'\n' % (options.pattern, ))


def do_listclones(options):
    """List clones among systems matching options.pattern with their parents."""
    with _open_cursor(options.database) as dbcurs:
        first = True
        for shortname, parent in dbcurs.listclones(options.pattern):
            if first:
                sys.stdout.write('Name:            Clone of:\n')
                first = False
            sys.stdout.write('%-16s %s\n' % (shortname, parent))
        if first:
            # Distinguish "no such systems" from "systems exist but none are clones".
            count = dbcurs.count_systems(options.pattern).fetchone()[0]
            if count:
                sys.stderr.write('Found %d match(es) for \'%s\' but none were clones\n' % (count, options.pattern))
            else:
                sys.stderr.write('No matching systems found for \'%s\'\n' % (options.pattern, ))


def do_listbrothers(options):
    """List systems returned by listbrothers for options.pattern, with source file and parent."""
    with _open_cursor(options.database) as dbcurs:
        first = True
        for sourcefile, shortname, parent in dbcurs.listbrothers(options.pattern):
            if first:
                sys.stdout.write('%-20s %-16s %s\n' % ('Source file:', 'Name:', 'Parent:'))
                first = False
            sys.stdout.write('%-20s %-16s %s\n' % (sourcefile, shortname, parent or ''))
        if first:
            sys.stderr.write('No matching systems found for \'%s\'\n' % (options.pattern, ))


def do_listaffected(options):
    """List systems returned by listaffected for the patterns in options.pattern."""
    with _open_cursor(options.database) as dbcurs:
        first = True
        for shortname, description in dbcurs.listaffected(*options.pattern):
            if first:
                sys.stdout.write('Name:             Description:\n')
                first = False
            sys.stdout.write('%-16s  "%s"\n' % (shortname, description))
        if first:
            # options.pattern is a sequence of patterns; join it rather than
            # printing the container's repr.
            sys.stderr.write('No matching systems found for \'%s\'\n' % (' '.join(options.pattern), ))


def do_romident(options):
    """Identify every file under the paths in options.path and print a report."""
    with _open_cursor(options.database) as dbcurs:
        ident = _Identifier(dbcurs)
        for path in options.path:
            ident.processPath(path)
        ident.printResults()