1#!/usr/bin/python
2##
3## license:BSD-3-Clause
4## copyright-holders:Vas Crabb
5
6from . import dbaccess
7
8import codecs
9import hashlib
10import os
11import os.path
12import struct
13import sys
14import zlib
15
16
class _Identifier(object):
    """Identify ROM and CHD image files against the database and report matches.

    Plain files are hashed (CRC32 and SHA1).  Files with a ``.chd`` extension
    that carry a recognised CHD header are identified by the SHA1 read from
    that header.  Hashes are looked up through the supplied database cursor.
    Matches are accumulated per machine and per software list entry so that
    printResults() can write an aligned report.
    """

    # Directories nested more than this many levels below a starting path are
    # skipped (guards against deep trees and symlink loops).
    MAX_DEPTH = 5

    # CHD header version -> (required header length, offset of the 20-byte SHA1
    # read by probeChd)
    _CHD_SHA1_LAYOUT = { 3: (120, 80), 4: (108, 48), 5: (124, 84) }

    def __init__(self, dbcurs, **kwargs):
        super(_Identifier, self).__init__(**kwargs)
        self.dbcurs = dbcurs
        # Widest values seen so far, used to align report columns.
        self.shortnamewidth = 0
        self.pathwidth = 0
        self.labelwidth = 0
        # shortname -> (description, [(path, label, bad), ...])
        self.machines = { }
        # softwarelist -> (list description,
        #                  {shortname -> (description, {part -> (part_id, [(path, label, bad), ...])})})
        self.software = { }
        # (path, crc or None, sha1) for files that matched nothing
        self.unmatched = [ ]

    def processPath(self, path, depth=0):
        """Identify a file, or recursively every file below a directory.

        Errors are reported on stderr and do not stop processing of other
        paths.  Directories more than MAX_DEPTH levels below the starting path
        are skipped with a warning.
        """
        try:
            if not os.path.isdir(path):
                self.processFile(path)
            elif depth > self.MAX_DEPTH:
                sys.stderr.write('Not examining \'%s\' - maximum depth exceeded\n' % (path, ))
            else:
                # sorted so the order of the Unmatched report is reproducible
                for name in sorted(os.listdir(path)):
                    self.processPath(os.path.join(path, name), depth + 1)
        except Exception as e:
            # Exception, not BaseException: KeyboardInterrupt and SystemExit
            # must still abort the whole scan rather than one file.
            sys.stderr.write('Error identifying \'%s\': %s\n' % (path, e))

    def printResults(self):
        """Write the accumulated report to stdout.

        Machine matches come first (sorted by short name), then software list
        matches (sorted by list, entry and part), then files that matched
        nothing.  Entries are separated by blank lines.  Each column is as wide
        as the next multiple of four strictly greater than the widest value
        seen.
        """
        nw = self.shortnamewidth - (self.shortnamewidth % 4) + 4
        pw = self.pathwidth - (self.pathwidth % 4) + 4
        lw = self.labelwidth - (self.labelwidth % 4) + 4
        first = True
        for shortname, romset in sorted(self.machines.items()):
            if first:
                first = False
            else:
                sys.stdout.write('\n')
            sys.stdout.write('%-*s%s\n' % (nw, shortname, romset[0]))
            self.printMatches(romset[1], pw, lw)
        for softwarelist, listinfo in sorted(self.software.items()):
            for shortname, softwareinfo in sorted(listinfo[1].items()):
                if first:
                    first = False
                else:
                    sys.stdout.write('\n')
                sys.stdout.write('%-*s%s\n' % (nw, '%s:%s' % (softwarelist, shortname), softwareinfo[0]))
                for part, partinfo in sorted(softwareinfo[1].items()):
                    if partinfo[0] is not None:
                        sys.stdout.write('%-*s%s\n' % (nw, '  ' + part, partinfo[0]))
                    else:
                        sys.stdout.write('  %s\n' % (part, ))
                    self.printMatches(partinfo[1], pw, lw)
        if self.unmatched:
            if not first:
                sys.stdout.write('\n')
            sys.stdout.write('Unmatched\n')
            for path, crc, sha1 in self.unmatched:
                if crc is not None:
                    sys.stdout.write('    %-*sCRC(%08x) SHA1(%s)\n' % (pw, path, crc, sha1))
                else:
                    sys.stdout.write('    %-*sSHA1(%s)\n' % (pw, path, sha1))

    def getMachineMatches(self, shortname, description):
        """Return the match list for a machine, creating its entry on first use.

        The description is recorded only when the entry is first created.
        """
        return self.machines.setdefault(shortname, (description, [ ]))[1]

    def getSoftwareMatches(self, softwarelist, softwarelistdescription, shortname, description, part, part_id):
        """Return the match list for one part of a software list entry.

        Intermediate list/entry/part records are created on first use; the
        descriptions and part_id are recorded only at creation time.
        """
        listinfo = self.software.setdefault(softwarelist, (softwarelistdescription, { }))
        softwareinfo = listinfo[1].setdefault(shortname, (description, { }))
        return softwareinfo[1].setdefault(part, (part_id, [ ]))[1]

    def processRomFile(self, path, f):
        """Hash an open file as a ROM image and record every database match.

        Files with no machine or software match go to self.unmatched together
        with their CRC32 and SHA1.
        """
        crc, sha1 = self.digestRom(f)
        matched = False
        # Each query is fully consumed before the next one is issued; the
        # cursor methods may share one underlying result set (verify in dbaccess).
        for shortname, description, label, bad in self.dbcurs.get_rom_dumps(crc, sha1):
            matched = True
            self.shortnamewidth = max(len(shortname), self.shortnamewidth)
            self.labelwidth = max(len(label), self.labelwidth)
            self.getMachineMatches(shortname, description).append((path, label, bad))
        for softwarelist, softwarelistdescription, shortname, description, part, part_id, label, bad in self.dbcurs.get_software_rom_dumps(crc, sha1):
            matched = True
            self.shortnamewidth = max(len(softwarelist) + 1 + len(shortname), 2 + len(part), self.shortnamewidth)
            self.labelwidth = max(len(label), self.labelwidth)
            self.getSoftwareMatches(softwarelist, softwarelistdescription, shortname, description, part, part_id).append((path, label, bad))
        if not matched:
            self.unmatched.append((path, crc, sha1))

    def processChd(self, path, sha1):
        """Record every database disk match for a CHD identified by its SHA1.

        CHDs with no match go to self.unmatched with a CRC of None.
        """
        matched = False
        for shortname, description, label, bad in self.dbcurs.get_disk_dumps(sha1):
            matched = True
            self.shortnamewidth = max(len(shortname), self.shortnamewidth)
            self.labelwidth = max(len(label), self.labelwidth)
            self.getMachineMatches(shortname, description).append((path, label, bad))
        for softwarelist, softwarelistdescription, shortname, description, part, part_id, label, bad in self.dbcurs.get_software_disk_dumps(sha1):
            matched = True
            self.shortnamewidth = max(len(softwarelist) + 1 + len(shortname), 2 + len(part), self.shortnamewidth)
            self.labelwidth = max(len(label), self.labelwidth)
            self.getSoftwareMatches(softwarelist, softwarelistdescription, shortname, description, part, part_id).append((path, label, bad))
        if not matched:
            self.unmatched.append((path, None, sha1))

    def processFile(self, path):
        """Identify a single file.

        A ``.chd`` file (case-insensitive extension) with a recognised CHD
        header is looked up by its header SHA1.  Any other file, including a
        ``.chd`` without a valid header, is hashed as a ROM image.
        """
        if os.path.splitext(path)[1].lower() != '.chd':
            with open(path, mode='rb', buffering=0) as f:
                self.processRomFile(path, f)
        else:
            with open(path, mode='rb') as f:
                sha1 = self.probeChd(f)
                if sha1 is None:
                    # not a CHD we understand - rewind and hash the raw bytes
                    f.seek(0)
                    self.processRomFile(path, f)
                else:
                    self.processChd(path, sha1)
        self.pathwidth = max(len(path), self.pathwidth)

    @staticmethod
    def iterateBlocks(f, s=65536):
        """Yield successive reads of up to ``s`` bytes until ``f`` is exhausted."""
        while True:
            buf = f.read(s)
            if buf:
                yield buf
            else:
                break

    @staticmethod
    def digestRom(f):
        """Return ``(crc32, sha1_hex)`` of the remaining contents of ``f``.

        The CRC is masked to an unsigned 32-bit value so that Python 2 and 3
        agree.
        """
        crc = zlib.crc32(bytes())
        sha = hashlib.sha1()
        for block in _Identifier.iterateBlocks(f):
            crc = zlib.crc32(block, crc)
            sha.update(block)
        return crc & 0xffffffff, sha.hexdigest()

    @staticmethod
    def probeChd(f):
        """Return the lowercase hex SHA1 from a CHD v3/v4/v5 header, else None.

        ``f`` must be positioned at the start of the file.  None is returned
        in these cases:
        - the magic is wrong;
        - the version is unsupported;
        - the header length does not match the version;
        - the file is too short.
        """
        buf = f.read(16)
        if (len(buf) != 16) or (buf[:8] != b'MComprHD'):
            return None
        headerlen, version = struct.unpack('>II', buf[8:])
        layout = _Identifier._CHD_SHA1_LAYOUT.get(version)
        if (layout is None) or (headerlen != layout[0]):
            return None
        sha1offs = layout[1]
        f.seek(sha1offs)
        if f.tell() != sha1offs:
            return None
        buf = f.read(20)
        if len(buf) != 20:
            return None
        return codecs.getencoder('hex_codec')(buf)[0].decode('ascii')

    @staticmethod
    def printMatches(matches, pathwidth, labelwidth):
        """Write one indented ``path = label`` line per match, tagging bad dumps."""
        for path, label, bad in matches:
            if bad:
                sys.stdout.write('    %-*s= %-*s(BAD)\n' % (pathwidth, path, labelwidth, label))
            else:
                sys.stdout.write('    %-*s= %s\n' % (pathwidth, path, label))
197
198
def do_listfull(options):
    """Print the short name and description of each system matching options.pattern.

    A column header precedes the first row.  If the query yields no rows, a
    diagnostic is written to stderr instead.
    """
    connection = dbaccess.QueryConnection(options.database)
    cursor = connection.cursor()
    found = False
    for name, desc in cursor.listfull(options.pattern):
        if not found:
            found = True
            sys.stdout.write('Name:             Description:\n')
        sys.stdout.write('%-16s  "%s"\n' % (name, desc))
    if not found:
        sys.stderr.write('No matching systems found for \'%s\'\n' % (options.pattern, ))
    cursor.close()
    connection.close()
212
213
def do_listsource(options):
    """Print the source file of each system matching options.pattern.

    Writes a diagnostic to stderr when no row (or a final row with a null
    short name) is produced.
    """
    connection = dbaccess.QueryConnection(options.database)
    cursor = connection.cursor()
    name = None  # stays None only if the loop body never runs
    for name, source in cursor.listsource(options.pattern):
        line = '%-16s %s\n' % (name, source)
        sys.stdout.write(line)
    if name is None:
        sys.stderr.write('No matching systems found for \'%s\'\n' % (options.pattern, ))
    cursor.close()
    connection.close()
224
225
def do_listclones(options):
    """Print each clone matching options.pattern alongside its parent.

    When nothing is printed, report on stderr whether the pattern matched
    systems that are not clones, or matched nothing at all.
    """
    connection = dbaccess.QueryConnection(options.database)
    cursor = connection.cursor()
    header_pending = True
    for name, parent in cursor.listclones(options.pattern):
        if header_pending:
            sys.stdout.write('Name:            Clone of:\n')
            header_pending = False
        sys.stdout.write('%-16s %s\n' % (name, parent))
    if header_pending:
        total = cursor.count_systems(options.pattern).fetchone()[0]
        if not total:
            sys.stderr.write('No matching systems found for \'%s\'\n' % (options.pattern, ))
        else:
            sys.stderr.write('Found %d match(es) for \'%s\' but none were clones\n' % (total, options.pattern))
    cursor.close()
    connection.close()
243
244
def do_listbrothers(options):
    """Print source file, short name and parent for systems matching options.pattern.

    A header row precedes the first result.  A falsy parent is shown as an
    empty column.  If nothing matches, a diagnostic goes to stderr.
    """
    connection = dbaccess.QueryConnection(options.database)
    cursor = connection.cursor()
    row_format = '%-20s %-16s %s\n'
    seen_any = False
    for source, name, parent in cursor.listbrothers(options.pattern):
        if not seen_any:
            sys.stdout.write(row_format % ('Source file:', 'Name:', 'Parent:'))
            seen_any = True
        sys.stdout.write(row_format % (source, name, parent if parent else ''))
    if not seen_any:
        sys.stderr.write('No matching systems found for \'%s\'\n' % (options.pattern, ))
    cursor.close()
    connection.close()
258
259
def do_listaffected(options):
    """Print short name and description of each system returned by listaffected().

    options.database names the database.  options.pattern is a sequence whose
    items are passed positionally to the cursor's listaffected() query
    (presumably source file patterns - confirm against dbaccess).  If the
    query yields no rows, a diagnostic listing every pattern goes to stderr.
    """
    dbconn = dbaccess.QueryConnection(options.database)
    dbcurs = dbconn.cursor()
    first = True
    for shortname, description in dbcurs.listaffected(*options.pattern):
        if first:
            sys.stdout.write('Name:             Description:\n')
            first = False
        sys.stdout.write('%-16s  "%s"\n' % (shortname, description))
    if first:
        # options.pattern is a sequence (it is star-unpacked above): quote each
        # pattern instead of embedding the sequence's repr in the message.
        patterns = ', '.join('\'%s\'' % (pattern, ) for pattern in options.pattern)
        sys.stderr.write('No matching systems found for %s\n' % (patterns, ))
    dbcurs.close()
    dbconn.close()
273
274
def do_romident(options):
    """Identify every file or directory named in options.path and print a report.

    Each path is handed to _Identifier.processPath(), and the combined results
    are written by printResults() once all paths have been examined.
    """
    connection = dbaccess.QueryConnection(options.database)
    cursor = connection.cursor()
    identifier = _Identifier(cursor)
    for target in options.path:
        identifier.processPath(target)
    identifier.printResults()
    cursor.close()
    connection.close()
284