1# Copyright (C) 2005 Martin v. Löwis
2# Licensed to PSF under a Contributor Agreement.
3from _msi import *
4import fnmatch
5import os
6import re
7import string
8import sys
9
10AMD64 = "AMD64" in sys.version
11# Keep msilib.Win64 around to preserve backwards compatibility.
12Win64 = AMD64
13
14# Partially taken from Wine
15datasizemask=      0x00ff
16type_valid=        0x0100
17type_localizable=  0x0200
18
19typemask=          0x0c00
20type_long=         0x0000
21type_short=        0x0400
22type_string=       0x0c00
23type_binary=       0x0800
24
25type_nullable=     0x1000
26type_key=          0x2000
27# XXX temporary, localizable?
28knownbits = datasizemask | type_valid | type_localizable | \
29            typemask | type_nullable | type_key
30
class Table:
    """Schema description of one MSI table.

    Columns are registered with add_field(); sql() renders them as a
    CREATE TABLE statement, which create() executes against a database.
    """

    def __init__(self, name):
        self.name = name
        # (index, name, type) tuples; index is the 1-based column position.
        self.fields = []

    def add_field(self, index, name, type):
        """Register column *name* at 1-based position *index* with type bits *type*."""
        self.fields.append((index, name, type))

    def sql(self):
        """Return the CREATE TABLE statement for this table.

        Each column's type bits are decoded with the module-level masks:
        the data type selects CHAR/SHORT/LONG/OBJECT, a missing nullable bit
        adds NOT NULL, the localizable bit adds LOCALIZABLE, and columns with
        the key bit are listed in the PRIMARY KEY clause.  Unknown bits are
        reported on stdout.  Note that self.fields is sorted in place.
        """
        self.fields.sort()
        # Definitions are placed by their 1-based index, so the indices are
        # expected to be exactly 1..len(self.fields).
        fields = [None] * len(self.fields)
        keys = []
        for index, name, type in self.fields:
            unk = type & ~knownbits
            if unk:
                print("%s.%s unknown bits %x" % (self.name, name, unk))
            size = type & datasizemask
            dtype = type & typemask
            if dtype == type_string:
                tname = "CHAR(%d)" % size if size else "CHAR"
            elif dtype == type_short:
                assert size == 2
                tname = "SHORT"
            elif dtype == type_long:
                assert size == 4
                tname = "LONG"
            elif dtype == type_binary:
                assert size == 0
                tname = "OBJECT"
            else:
                tname = "unknown"
                print("%s.%s unknown integer type %d" % (self.name, name, size))
            flags = "" if type & type_nullable else " NOT NULL"
            if type & type_localizable:
                flags += " LOCALIZABLE"
            fields[index - 1] = "`%s` %s%s" % (name, tname, flags)
            if type & type_key:
                keys.append("`%s`" % name)
        return "CREATE TABLE %s (%s PRIMARY KEY %s)" % (
            self.name, ", ".join(fields), ", ".join(keys))

    def create(self, db):
        """Execute this table's CREATE TABLE statement on *db*."""
        v = db.OpenView(self.sql())
        try:
            v.Execute(None)
        finally:
            # Release the view even if the statement fails.
            v.Close()
85
class _Unspecified:
    """Sentinel default meaning "argument not supplied" for change_sequence."""
    pass

def change_sequence(seq, action, seqno=_Unspecified, cond=_Unspecified):
    """Change the sequence number of an action in a sequence list.

    *seq* is a list of (action, condition, sequence number) tuples and is
    modified in place: the first entry whose action equals *action* is
    replaced.  *seqno* and *cond* default to the entry's existing values.

    Raises ValueError if no entry has the given action.
    """
    for i, entry in enumerate(seq):
        if entry[0] == action:
            if cond is _Unspecified:
                cond = entry[1]
            if seqno is _Unspecified:
                seqno = entry[2]
            seq[i] = (action, cond, seqno)
            return
    raise ValueError("Action not found in sequence")
98
def add_data(db, table, values):
    """Insert every row of *values* into *table* of database *db*.

    Each row must have exactly as many fields as the table has columns.
    Fields may be int, str, None (left empty), or Binary (stored as a stream
    read from Binary.name).

    Raises TypeError for an unsupported field type and MSIError (chained to
    the underlying error) if a row cannot be inserted.  The view is closed
    in all cases.
    """
    v = db.OpenView("SELECT * FROM `%s`" % table)
    try:
        count = v.GetColumnInfo(MSICOLINFO_NAMES).GetFieldCount()
        r = CreateRecord(count)
        for value in values:
            assert len(value) == count, value
            for i in range(count):
                field = value[i]
                # Record fields are 1-based.
                if isinstance(field, int):
                    r.SetInteger(i+1, field)
                elif isinstance(field, str):
                    r.SetString(i+1, field)
                elif field is None:
                    pass
                elif isinstance(field, Binary):
                    r.SetStream(i+1, field.name)
                else:
                    raise TypeError("Unsupported type %s" % field.__class__.__name__)
            try:
                v.Modify(MSIMODIFY_INSERT, r)
            except Exception as e:
                # Report the row that failed, and keep the original cause.
                raise MSIError("Could not insert "+repr(value)+" into "+table) from e
            # Reuse the same record for the next row.
            r.ClearData()
    finally:
        v.Close()
124
125
def add_stream(db, name, path):
    """Store the contents of the file at *path* as stream *name* in db's _Streams table.

    The view is closed even if the insert fails.
    """
    # NOTE(review): *name* is interpolated into the SQL text unescaped, so it
    # must not contain a single quote.
    v = db.OpenView("INSERT INTO _Streams (Name, Data) VALUES ('%s', ?)" % name)
    try:
        r = CreateRecord(1)
        r.SetStream(1, path)
        v.Execute(r)
    finally:
        v.Close()
132
def init_database(name, schema,
                  ProductName, ProductCode, ProductVersion,
                  Manufacturer):
    """Create a new MSI database file *name* from *schema* and return it.

    Any existing file at *name* is removed first (errors doing so are
    ignored).  All tables in schema.tables are created, the _Validation
    table is filled from schema._Validation_records, the summary
    information is written, and the basic product properties are added to
    the Property table.  The database is committed before it is returned.
    """
    try:
        os.unlink(name)
    except OSError:
        pass
    ProductCode = ProductCode.upper()

    db = OpenDatabase(name, MSIDBOPEN_CREATE)
    for table in schema.tables:
        table.create(db)
    add_data(db, "_Validation", schema._Validation_records)

    # Summary information stream, with room for at most 20 properties.
    si = db.GetSummaryInformation(20)
    template = "x64;1033" if AMD64 else "Intel;1033"
    summary = [
        (PID_TITLE, "Installation Database"),
        (PID_SUBJECT, ProductName),
        (PID_AUTHOR, Manufacturer),
        (PID_TEMPLATE, template),
        (PID_REVNUMBER, gen_uuid()),
        # long file names, compressed, original media
        (PID_WORDCOUNT, 2),
        (PID_PAGECOUNT, 200),
        (PID_APPNAME, "Python MSI Library"),
    ]
    for pid, value in summary:
        si.SetProperty(pid, value)
    # XXX more properties
    si.Persist()

    properties = [
        ("ProductName", ProductName),
        ("ProductCode", ProductCode),
        ("ProductVersion", ProductVersion),
        ("Manufacturer", Manufacturer),
        ("ProductLanguage", "1033"),
    ]
    add_data(db, "Property", properties)
    db.Commit()
    return db
171
def add_tables(db, module):
    """For each table name in module.tables, insert the rows held in the same-named module attribute."""
    for table_name in module.tables:
        rows = getattr(module, table_name)
        add_data(db, table_name, rows)
175
def make_id(str):
    """Return *str* converted into a valid MSI identifier.

    Every character other than an ASCII letter, digit, '.' or '_' becomes
    '_'.  If the result is empty or starts with a digit or '.', it is
    prefixed with '_'.
    """
    # The parameter name shadows the builtin but is kept for callers using
    # it as a keyword argument.
    identifier_chars = string.ascii_letters + string.digits + "._"
    str = "".join([c if c in identifier_chars else "_" for c in str])
    if not str or str[0] in (string.digits + "."):
        str = "_" + str
    assert re.match("^[A-Za-z_][A-Za-z0-9_.]*$", str), "FILE"+str
    return str
183
def gen_uuid():
    """Return a new UUID in upper case, wrapped in curly braces."""
    return "{%s}" % UuidCreate().upper()
186
class CAB:
    """Collects files for a cabinet and embeds the built cabinet in an MSI database."""

    def __init__(self, name):
        # Stream name under which commit() stores the cabinet.
        self.name = name
        # (source path, logical id) pairs in the order they were appended.
        self.files = []
        # Logical ids already handed out by gen_id().
        self.filenames = set()
        # Number of files appended; append() returns it as the file's
        # sequence and commit() writes it into the Media row.
        self.index = 0

    def gen_id(self, file):
        """Return a logical id for *file* not yet issued by this CAB.

        The id is make_id(file); on collision a ".1", ".2", ... suffix is added.
        """
        logical = _logical = make_id(file)
        pos = 1
        while logical in self.filenames:
            logical = "%s.%d" % (_logical, pos)
            pos += 1
        self.filenames.add(logical)
        return logical

    def append(self, full, file, logical):
        """Queue the file at path *full* for the cabinet.

        Returns (sequence, logical) where *logical* is generated from *file*
        when not given.  Directories are skipped and None is returned.
        """
        if os.path.isdir(full):
            return
        if not logical:
            logical = self.gen_id(file)
        self.index += 1
        self.files.append((full, logical))
        return self.index, logical

    def commit(self, db):
        """Build the cabinet, add its Media row and stream to *db*, and commit."""
        from tempfile import mkstemp
        # mkstemp creates the file atomically, avoiding the race inherent in
        # mktemp(); close our handle so FCICreate can write to the path.
        fd, filename = mkstemp()
        os.close(fd)
        try:
            FCICreate(filename, self.files)
            add_data(db, "Media",
                    [(1, self.index, None, "#"+self.name, None, None)])
            add_stream(db, self.name, filename)
        finally:
            # Never leave the temporary cabinet behind, even on failure.
            os.unlink(filename)
        db.Commit()
221
_directories = set()
# Characters that must not appear in a long file name (checked by
# Directory.add_file).  The double quote belongs inside the character class;
# with it outside, the pattern only matched a forbidden character that was
# immediately followed by a quote.
_LONG_NAME_INVALID = re.compile(r'[\?|><:/*"]')

class Directory:
    """A row in the Directory table plus the components and files placed in it.

    Logical directory names are kept unique across all Directory instances
    through the module-level _directories set.
    """

    def __init__(self, db, cab, basedir, physical, _logical, default, componentflags=None):
        """Create a new directory in the Directory table. There is a current component
        at each point in time for the directory, which is either explicitly created
        through start_component, or implicitly when files are added for the first
        time. Files are added into the current component, and into the cab file.
        To create a directory, a base directory object needs to be specified (can be
        None), the path to the physical directory, and a logical directory name.
        Default specifies the DefaultDir slot in the directory table. componentflags
        specifies the default flags that new components get."""
        index = 1
        _logical = make_id(_logical)
        logical = _logical
        # Append a counter until the logical name is unused by any Directory.
        while logical in _directories:
            logical = "%s%d" % (_logical, index)
            index += 1
        _directories.add(logical)
        self.db = db
        self.cab = cab
        self.basedir = basedir
        self.physical = physical
        self.logical = logical
        self.component = None     # name of the current component, if any
        self.short_names = set()  # short (8.3-style) names already used here
        self.ids = set()          # logical File ids added to this directory
        self.keyfiles = {}        # keyfile name -> logical id reserved for it
        self.componentflags = componentflags
        if basedir:
            self.absolute = os.path.join(basedir.absolute, physical)
            blogical = basedir.logical
        else:
            self.absolute = physical
            blogical = None
        add_data(db, "Directory", [(logical, blogical, default)])

    def start_component(self, component = None, feature = None, flags = None, keyfile = None, uuid=None):
        """Add an entry to the Component table, and make this component the current for this
        directory. If no component name is given, the directory name is used. If no feature
        is given, the current feature is used. If no flags are given, the directory's default
        flags are used, or 0 if the directory has none. If no keyfile is given, the KeyPath
        is left null in the Component table."""
        if flags is None:
            flags = self.componentflags
        if flags is None:
            # Fall back to 0 (the value add_file passes) so the AMD64 bit can
            # be OR-ed in below and the Attributes field is never left empty.
            flags = 0
        if uuid is None:
            uuid = gen_uuid()
        else:
            uuid = uuid.upper()
        if component is None:
            component = self.logical
        self.component = component
        if AMD64:
            # presumably msidbComponentAttributes64bit (0x100); set only on AMD64
            flags |= 256
        if keyfile:
            # Reserve the keyfile's logical id now; add_file reuses it.
            keyid = self.cab.gen_id(keyfile)
            self.keyfiles[keyfile] = keyid
        else:
            keyid = None
        add_data(self.db, "Component",
                        [(component, uuid, self.logical, flags, None, keyid)])
        if feature is None:
            feature = current_feature
        add_data(self.db, "FeatureComponents",
                        [(feature.id, component)])

    def make_short(self, file):
        """Return a short (8.3-style) name for *file* that is unique in this directory.

        The name is upper-cased; when *file* does not already fit (too long,
        more than one dot, or containing characters that had to be removed or
        replaced) a "PREFIX~N.EXT" form is generated.  The result is recorded
        in self.short_names.
        """
        oldfile = file
        file = file.replace('+', '_')
        file = ''.join(c for c in file if c not in r' "/\[]:;=,')
        parts = file.split(".")
        if len(parts) > 1:
            prefix = "".join(parts[:-1]).upper()
            suffix = parts[-1].upper()
            if not prefix:
                prefix = suffix
                suffix = None
        else:
            prefix = file.upper()
            suffix = None
        if len(parts) < 3 and len(prefix) <= 8 and file == oldfile and (
                                                not suffix or len(suffix) <= 3):
            if suffix:
                file = prefix+"."+suffix
            else:
                file = prefix
        else:
            file = None
        if file is None or file in self.short_names:
            prefix = prefix[:6]
            if suffix:
                suffix = suffix[:3]
            pos = 1
            while 1:
                if suffix:
                    file = "%s~%d.%s" % (prefix, pos, suffix)
                else:
                    file = "%s~%d" % (prefix, pos)
                if file not in self.short_names: break
                pos += 1
                assert pos < 10000
                # Shorten the prefix as the counter gains digits.
                if pos in (10, 100, 1000):
                    prefix = prefix[:-1]
        self.short_names.add(file)
        assert not re.search(r'[\?|><:/*"+,;=\[\]]', file) # restrictions on short names
        return file

    def add_file(self, file, src=None, version=None, language=None):
        """Add a file to the current component of the directory, starting a new one
        if there is no current component. By default, the file name in the source
        and the file table will be identical. If the src file is specified, it is
        interpreted relative to the current directory. Optionally, a version and a
        language can be specified for the entry in the File table.

        Returns the logical id of the new File entry."""
        if not self.component:
            self.start_component(self.logical, current_feature, 0)
        if not src:
            # Allow relative paths for file if src is not specified
            src = file
            file = os.path.basename(file)
        absolute = os.path.join(self.absolute, src)
        assert not _LONG_NAME_INVALID.search(file) # restrictions on long names
        if file in self.keyfiles:
            logical = self.keyfiles[file]
        else:
            logical = None
        sequence, logical = self.cab.append(absolute, file, logical)
        assert logical not in self.ids
        self.ids.add(logical)
        short = self.make_short(file)
        full = "%s|%s" % (short, file)
        filesize = os.stat(absolute).st_size
        # constants.msidbFileAttributesVital
        # Compressed omitted, since it is the database default
        # could add r/o, system, hidden
        attributes = 512
        add_data(self.db, "File",
                        [(logical, self.component, full, filesize, version,
                         language, attributes, sequence)])
        #if not version:
        #    # Add hash if the file is not versioned
        #    filehash = FileHash(absolute, 0)
        #    add_data(self.db, "MsiFileHash",
        #             [(logical, 0, filehash.IntegerData(1),
        #               filehash.IntegerData(2), filehash.IntegerData(3),
        #               filehash.IntegerData(4))])
        # Automatically remove .pyc files on uninstall (2)
        # XXX: adding so many RemoveFile entries makes installer unbelievably
        # slow. So instead, we have to use wildcard remove entries
        if file.endswith(".py"):
            add_data(self.db, "RemoveFile",
                      [(logical+"c", self.component, "%sC|%sc" % (short, file),
                        self.logical, 2),
                       (logical+"o", self.component, "%sO|%so" % (short, file),
                        self.logical, 2)])
        return logical

    def glob(self, pattern, exclude = None):
        """Add a list of files to the current component as specified in the
        glob pattern. Individual files can be excluded in the exclude list.

        Names starting with '.' are only considered when the pattern itself
        starts with '.'.  Returns the list of matching names (including any
        skipped because of *exclude*), or [] if the directory cannot be
        listed."""
        try:
            files = os.listdir(self.absolute)
        except OSError:
            return []
        if pattern[:1] != '.':
            files = (f for f in files if f[0] != '.')
        files = fnmatch.filter(files, pattern)
        for f in files:
            if exclude and f in exclude: continue
            self.add_file(f)
        return files

    def remove_pyc(self):
        "Remove .pyc files on uninstall"
        # Adds a single wildcard RemoveFile entry for the current component.
        add_data(self.db, "RemoveFile",
                 [(self.component+"c", self.component, "*.pyc", self.logical, 2)])
396
class Binary:
    """Reference to a file whose contents add_data stores as a stream field."""

    def __init__(self, fname):
        # Path of the file; add_data passes it to SetStream.
        self.name = fname

    def __repr__(self):
        template = 'msilib.Binary(os.path.join(dirname,"%s"))'
        return template % self.name
402
class Feature:
    """A row in the Feature table; can be made the module's current feature."""

    def __init__(self, db, id, title, desc, display, level = 1,
                 parent=None, directory = None, attributes=0):
        self.id = id
        # A truthy parent is a Feature and is stored by its id.
        parent_id = parent.id if parent else parent
        row = (id, parent_id, title, desc, display,
               level, directory, attributes)
        add_data(db, "Feature", [row])

    def set_current(self):
        """Make this feature the default used when Directory components are created."""
        global current_feature
        current_feature = self
415
class Control:
    """Handle for a control on a Dialog; methods add rows keyed by dialog and control name."""

    def __init__(self, dlg, name):
        self.dlg = dlg
        self.name = name

    def event(self, event, argument, condition = "1", ordering = None):
        """Add a ControlEvent row (dialog, control, event, argument, condition, ordering)."""
        add_data(self.dlg.db, "ControlEvent",
                 [(self.dlg.name, self.name, event, argument,
                   condition, ordering)])

    def mapping(self, event, attribute):
        """Add an EventMapping row (dialog, control, event, attribute)."""
        add_data(self.dlg.db, "EventMapping",
                 [(self.dlg.name, self.name, event, attribute)])

    def condition(self, action, condition):
        """Add a ControlCondition row (dialog, control, action, condition)."""
        add_data(self.dlg.db, "ControlCondition",
                 [(self.dlg.name, self.name, action, condition)])

class RadioButtonGroup(Control):
    """A RadioButtonGroup control whose buttons are bound to *property*."""

    def __init__(self, dlg, name, property):
        super().__init__(dlg, name)
        self.property = property
        # Order number given to the next button added.
        self.index = 1

    def add(self, name, x, y, w, h, text, value = None):
        """Add a RadioButton row for this group; *value* defaults to *name*."""
        if value is None:
            value = name
        add_data(self.dlg.db, "RadioButton",
                 [(self.property, self.index, value,
                   x, y, w, h, text, None)])
        self.index += 1
448
class Dialog:
    """A row in the Dialog table, with helpers that add Control rows to it."""

    def __init__(self, db, name, x, y, w, h, attr, title, first, default, cancel):
        self.db = db
        self.name = name
        self.x, self.y, self.w, self.h = x, y, w, h
        row = (name, x, y, w, h, attr, title, first, default, cancel)
        add_data(db, "Dialog", [row])

    def control(self, name, type, x, y, w, h, attr, prop, text, next, help):
        """Add a Control row of *type* to this dialog and return a Control handle for it."""
        row = (self.name, name, type, x, y, w, h, attr, prop, text, next, help)
        add_data(self.db, "Control", [row])
        return Control(self, name)

    def text(self, name, x, y, w, h, attr, text):
        """Add a Text control."""
        return self.control(name, "Text", x, y, w, h, attr,
                            None, text, None, None)

    def bitmap(self, name, x, y, w, h, text):
        """Add a Bitmap control (attributes fixed to 1)."""
        return self.control(name, "Bitmap", x, y, w, h, 1,
                            None, text, None, None)

    def line(self, name, x, y, w, h):
        """Add a Line control (attributes fixed to 1)."""
        return self.control(name, "Line", x, y, w, h, 1,
                            None, None, None, None)

    def pushbutton(self, name, x, y, w, h, attr, text, next):
        """Add a PushButton control."""
        return self.control(name, "PushButton", x, y, w, h, attr,
                            None, text, next, None)

    def radiogroup(self, name, x, y, w, h, attr, prop, text, next):
        """Add a RadioButtonGroup control bound to *prop*.

        Returns a RadioButtonGroup handle used to add the individual buttons.
        """
        self.control(name, "RadioButtonGroup", x, y, w, h, attr,
                     prop, text, next, None)
        return RadioButtonGroup(self, name, prop)

    def checkbox(self, name, x, y, w, h, attr, prop, text, next):
        """Add a CheckBox control bound to *prop*."""
        return self.control(name, "CheckBox", x, y, w, h, attr,
                            prop, text, next, None)
482