1import sys
2import marshal
3import contextlib
4import dis
5from distutils.version import StrictVersion
6
7from ._imp import find_module, PY_COMPILED, PY_FROZEN, PY_SOURCE
8from . import _imp
9
10
# Public API of this module.  Kept as a mutable list because
# _update_globals() below removes 'extract_constant' and
# 'get_module_constant' from it on 'java'/'cli' platforms.
__all__ = [
    'Require', 'find_module', 'get_module_constant', 'extract_constant'
]
14
15
class Require:
    """A prerequisite to building or installing a distribution"""

    def __init__(
            self, name, requested_version, module, homepage='',
            attribute=None, format=None):
        """Describe a required module.

        :param name: human-readable name of the requirement.
        :param requested_version: minimum acceptable version, or None.
        :param module: name of the module to look up via ``find_module``.
        :param homepage: informational URL for the requirement.
        :param attribute: module-level name holding the version string;
            defaults to ``'__version__'`` whenever a version format is used.
        :param format: callable that parses version values; defaults to
            ``StrictVersion`` when ``requested_version`` is given.
        """
        if format is None and requested_version is not None:
            format = StrictVersion

        if format is not None:
            requested_version = format(requested_version)
            if attribute is None:
                attribute = '__version__'

        # Explicit assignment (instead of ``self.__dict__.update(locals())``
        # followed by ``del self.self``) keeps the attribute set visible.
        self.name = name
        self.requested_version = requested_version
        self.module = module
        self.homepage = homepage
        self.attribute = attribute
        self.format = format

    def full_name(self):
        """Return full package/distribution name, w/version"""
        if self.requested_version is not None:
            return '%s-%s' % (self.name, self.requested_version)
        return self.name

    def version_ok(self, version):
        """Is 'version' sufficiently up-to-date?

        Always true when no version attribute or format is configured,
        because there is nothing to compare.  Otherwise 'version' must not
        be the "unknown" placeholder and must be >= the requested version.
        """
        if self.attribute is None or self.format is None:
            return True
        return str(version) != "unknown" and version >= self.requested_version

    def get_version(self, paths=None, default="unknown"):
        """Get version number of installed module, 'None', or 'default'

        Search 'paths' for module.  If not found, return 'None'.  If found,
        return the extracted version attribute, or 'default' if no version
        attribute was specified, or the value cannot be determined without
        importing the module.  The version is formatted according to the
        requirement's version format (if any), unless it is 'None' or the
        supplied 'default'.
        """

        if self.attribute is None:
            try:
                f, _path, _info = find_module(self.module, paths)
                # Only existence matters here; release any opened file.
                if f:
                    f.close()
                return default
            except ImportError:
                return None

        v = get_module_constant(self.module, self.attribute, default, paths)

        if v is not None and v is not default and self.format is not None:
            return self.format(v)

        return v

    def is_present(self, paths=None):
        """Return true if dependency is present on 'paths'"""
        return self.get_version(paths) is not None

    def is_current(self, paths=None):
        """Return true if dependency is present and up-to-date on 'paths'"""
        version = self.get_version(paths)
        if version is None:
            return False
        return self.version_ok(version)
82
83
def maybe_close(f):
    """Wrap *f* so that it is closed when the ``with`` block exits.

    When *f* is falsy (e.g. ``None`` because no file was opened), a no-op
    context manager is returned instead; its ``as`` target is ``None``.
    """
    if f:
        return contextlib.closing(f)

    @contextlib.contextmanager
    def _nothing_to_close():
        yield

    return _nothing_to_close()
93
94
def get_module_constant(module, symbol, default=-1, paths=None):
    """Find 'module' by searching 'paths', and extract 'symbol'

    Return 'None' if 'module' does not exist on 'paths', or it does not define
    'symbol'.  If the module defines 'symbol' as a constant, return the
    constant.  Otherwise, return 'default'.

    Source, compiled (.pyc) and frozen modules are inspected without being
    executed.  Any other kind of module is imported and the attribute is read
    directly from it ('None' if it is missing).
    """

    try:
        f, path, (suffix, mode, kind) = info = find_module(module, paths)
    except ImportError:
        # Module doesn't exist
        return None

    with maybe_close(f):
        if kind == PY_COMPILED:
            # Skip the .pyc header that precedes the marshalled code object.
            # Its size depends on the interpreter version (PEP 552):
            # magic + mtime (8 bytes) before 3.3, + source size (12 bytes)
            # in 3.3-3.6, + flags (16 bytes) from 3.7 on.  A fixed 8-byte
            # skip leaves marshal.load() positioned mid-header on Python 3.
            if sys.version_info >= (3, 7):
                header_size = 16
            elif sys.version_info >= (3, 3):
                header_size = 12
            else:
                header_size = 8
            f.read(header_size)
            code = marshal.load(f)
        elif kind == PY_FROZEN:
            code = _imp.get_frozen_object(module, paths)
        elif kind == PY_SOURCE:
            code = compile(f.read(), path, 'exec')
        else:
            # Not something we can parse; we'll have to import it.  :(
            imported = _imp.get_module(module, paths, info)
            return getattr(imported, symbol, None)

    return extract_constant(code, symbol, default)
122
123
def extract_constant(code, symbol, default=-1):
    """Extract the constant value of 'symbol' from 'code'

    If the name 'symbol' is bound to a constant value by the Python code
    object 'code', return that value.  If 'symbol' is bound to an expression,
    return 'default'.  Otherwise, return 'None'.

    Return value is based on the first assignment to 'symbol'.  'symbol' must
    be a global, or at least a non-"fast" local in the code block.  That is,
    only 'STORE_NAME' and 'STORE_GLOBAL' opcodes are checked, and 'symbol'
    must be present in 'code.co_names'.
    """
    if symbol not in code.co_names:
        # name's not there, can't possibly be an assignment
        return None

    name_idx = code.co_names.index(symbol)

    # Match on opcode *names*: numeric opcode values are a CPython
    # implementation detail and were renumbered in 3.13.
    store_ops = ('STORE_NAME', 'STORE_GLOBAL')
    # LOAD_SMALL_INT (CPython 3.14+) replaces LOAD_CONST for small ints.
    const_ops = ('LOAD_CONST', 'LOAD_SMALL_INT')

    const = default

    for instr in dis.Bytecode(code):
        opname = instr.opname
        if opname == 'EXTENDED_ARG':
            # Argument prefix only; dis already folds it into the following
            # instruction's 'arg', so it must not reset the pending constant.
            continue
        if opname in const_ops:
            const = instr.argval
        elif opname in store_ops and instr.arg == name_idx:
            return const
        else:
            const = default

    return None
158
159
def _update_globals():
    """
    Remove the bytecode-inspecting helpers on platforms that lack them.

    When ``sys.platform`` starts with 'java' or equals 'cli',
    'extract_constant' and 'get_module_constant' are deleted from this
    module's globals and from ``__all__``; elsewhere this is a no-op.

    XXX it'd be better to test assertions about bytecode instead.
    """
    platform = sys.platform
    if platform.startswith('java') or platform == 'cli':
        namespace = globals()
        for unsupported in ('extract_constant', 'get_module_constant'):
            del namespace[unsupported]
            __all__.remove(unsupported)
173
174
# Run once at import time so that, on 'java'/'cli' platforms, the
# unsupported helpers are neither defined nor listed in __all__.
_update_globals()
176