1"""
2My own variation on function-specific inspect-like features.
3"""
4
5# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
6# Copyright (c) 2009 Gael Varoquaux
7# License: BSD Style, 3 clauses.
8
9import inspect
10import warnings
11import re
12import os
13import collections
14
15from itertools import islice
16from tokenize import open as open_py_source
17
18from .logger import pformat
19
# Space-separated field names, in the same order as the fields of
# inspect.FullArgSpec.
full_argspec_fields = ' '.join((
    'args', 'varargs', 'varkw', 'defaults',
    'kwonlyargs', 'kwonlydefaults', 'annotations',
))
# Named tuple type exposing the fields listed above.
full_argspec_type = collections.namedtuple(
    'FullArgSpec', full_argspec_fields.split())
23
24
def get_func_code(func):
    """ Attempts to retrieve a reliable function code hash.

        The reason we don't use inspect.getsource is that it caches the
        source, whereas we want this to be modified on the fly when the
        function is modified.

        Parameters
        ----------
        func: callable
            The object whose source code is looked up.

        Returns
        -------
        func_code: string
            The function code. When the source cannot be retrieved, this
            is the string form of the hash of the code object, or
            ``repr(func)`` for objects without a ``__code__`` attribute.
        source_file: string
            The path to the file in which the function is defined, or
            None when it could not be determined.
        first_line: int
            The first line of the code in the source file, or -1 when the
            source could not be retrieved.

        Notes
        ------
        This function does a bit more magic than inspect, and is thus
        more robust.
    """
    source_file = None
    try:
        code = func.__code__
        source_file = code.co_filename
        if not os.path.exists(source_file):
            # Use inspect for lambda functions and functions defined in an
            # interactive shell, or in doctests
            source_code = ''.join(inspect.getsourcelines(func)[0])
            line_no = 1
            if source_file.startswith('<doctest '):
                # Only '<doctest NAME.rst[N]>' is recognised here; any other
                # doctest file name makes re.match return None and ends up
                # in the fallback below.
                source_file, line_no = re.match(
                    r'\<doctest (.*\.rst)\[(.*)\]\>', source_file).groups()
                line_no = int(line_no)
                source_file = '<doctest %s>' % source_file
            return source_code, source_file, line_no
        # Try to retrieve the source code.
        with open_py_source(source_file) as source_file_obj:
            first_line = code.co_firstlineno
            # All the lines after the function definition:
            source_lines = list(islice(source_file_obj, first_line - 1, None))
        return ''.join(inspect.getblock(source_lines)), source_file, first_line
    except Exception:
        # Deliberately broad best-effort fallback, but restricted to
        # Exception so that KeyboardInterrupt and SystemExit still
        # propagate instead of being silently turned into a hash.
        # If the source code fails, we use the hash. This is fragile and
        # might change from one session to another.
        if hasattr(func, '__code__'):
            # Python 3.X
            return str(func.__code__.__hash__()), source_file, -1
        else:
            # Weird objects like numpy ufunc don't have __code__
            # This is fragile, as quite often the id of the object is
            # in the repr, so it might not persist across sessions,
            # however it will work for ufuncs.
            return repr(func), source_file, -1
79
80
81def _clean_win_chars(string):
82    """Windows cannot encode some characters in filename."""
83    import urllib
84    if hasattr(urllib, 'quote'):
85        quote = urllib.quote
86    else:
87        # In Python 3, quote is elsewhere
88        import urllib.parse
89        quote = urllib.parse.quote
90    for char in ('<', '>', '!', ':', '\\'):
91        string = string.replace(char, quote(char))
92    return string
93
94
def get_func_name(func, resolv_alias=True, win_characters=True):
    """ Return the function import path (as a list of module names), and
        a name for the function.

        Parameters
        ----------
        func: callable
            The func to inspect
        resolv_alias: boolean, optional
            If true, possible local aliases are indicated. The check reads
            the ``func_globals`` attribute and is skipped for objects that
            do not have it.
        win_characters: boolean, optional
            If true, substitute special characters using urllib.quote
            This is useful in Windows, as it cannot encode some filenames

        Returns
        -------
        module: list of str
            The dotted module path split on '.'. For objects defined in
            ``__main__`` the (mangled) source file path is appended to the
            first element.
        name: str
            The name of the function, or 'unknown' if it has none.
    """
    if hasattr(func, '__module__'):
        module = func.__module__
    else:
        try:
            module = inspect.getmodule(func)
        except TypeError:
            if hasattr(func, '__class__'):
                module = func.__class__.__module__
            else:
                module = 'unknown'
        else:
            # inspect.getmodule returns a module object (or None), whereas
            # the code below works on the dotted module name.
            if module is not None:
                module = getattr(module, '__name__', module)
    if module is None:
        # Happens in doctests, eg
        module = ''
    if module == '__main__':
        try:
            filename = os.path.abspath(inspect.getsourcefile(func))
        except Exception:
            # Best effort: getsourcefile may raise or return None (which
            # abspath rejects). KeyboardInterrupt/SystemExit propagate.
            filename = None
        if filename is not None:
            # mangling of full path to filename
            parts = filename.split(os.sep)
            if parts[-1].startswith('<ipython-input'):
                # We're in a IPython (or notebook) session. parts[-1] comes
                # from func.__code__.co_filename and is of the form
                # <ipython-input-N-XYZ>, where:
                # - N is the cell number where the function was defined
                # - XYZ is a hash representing the function's code (and name).
                #   It will be consistent across sessions and kernel restarts,
                #   and will change if the function's code/name changes
                # We remove N so that cache is properly hit if the cell where
                # the func is defined is re-executed.
                # The XYZ hash should avoid collisions between functions with
                # the same name, both within the same notebook but also across
                # notebooks
                splitted = parts[-1].split('-')
                parts[-1] = '-'.join(splitted[:2] + splitted[3:])
            elif len(parts) > 2 and parts[-2].startswith('ipykernel_'):
                # In a notebook session (ipykernel). Filename seems to be 'xyz'
                # of above. parts[-2] has the structure ipykernel_XXXXXX where
                # XXXXXX is a six-digit number identifying the current run (?).
                # If we split it off, the function again has the same
                # identifier across runs.
                parts[-2] = 'ipykernel'
            filename = '-'.join(parts)
            if filename.endswith('.py'):
                filename = filename[:-3]
            module = module + '-' + filename
    module = module.split('.')
    if hasattr(func, 'func_name'):
        name = func.func_name
    elif hasattr(func, '__name__'):
        name = func.__name__
    else:
        name = 'unknown'
    # Hack to detect functions not defined at the module-level
    if resolv_alias:
        # TODO: Maybe add a warning here?
        # NOTE: func_globals (like func_name above and im_class below) is a
        # Python 2 attribute name; plain Python 3 functions do not have it,
        # so this check only fires for objects that expose it explicitly.
        if hasattr(func, 'func_globals') and name in func.func_globals:
            if func.func_globals[name] is not func:
                name = '%s-alias' % name
    if inspect.ismethod(func):
        # We need to add the name of the class
        if hasattr(func, 'im_class'):
            klass = func.im_class
            module.append(klass.__name__)
    if os.name == 'nt' and win_characters:
        # Windows can't encode certain characters in filenames
        name = _clean_win_chars(name)
        module = [_clean_win_chars(s) for s in module]
    return module, name
179
180
181def _signature_str(function_name, arg_sig):
182    """Helper function to output a function signature"""
183    return '{}{}'.format(function_name, arg_sig)
184
185
186def _function_called_str(function_name, args, kwargs):
187    """Helper function to output a function call"""
188    template_str = '{0}({1}, {2})'
189
190    args_str = repr(args)[1:-1]
191    kwargs_str = ', '.join('%s=%s' % (k, v)
192                           for k, v in kwargs.items())
193    return template_str.format(function_name, args_str,
194                               kwargs_str)
195
196
def filter_args(func, ignore_lst, args=(), kwargs=None):
    """ Filters the given args and kwargs using a list of arguments to
        ignore, and a function specification.

        Parameters
        ----------
        func: callable
            Function giving the argument specification
        ignore_lst: list of strings
            List of arguments to ignore (either a name of an argument
            in the function spec, or '*', or '**')
        args: list
            Positional arguments passed to the function.
        kwargs: dict, optional
            Keyword arguments passed to the function. Defaults to a new
            empty dict.

        Returns
        -------
        filtered_args: dict
            Maps each argument name to the value it is bound to for this
            call. Extra positional arguments are stored under '*' and
            extra keyword arguments under '**' when the function accepts
            them. Entries named in ``ignore_lst`` are removed. For objects
            that are neither functions nor methods, only
            ``{'*': args, '**': kwargs}`` is returned.

        Raises
        ------
        ValueError
            If ``ignore_lst`` is a string, if an argument required by the
            function is missing, if a keyword-only parameter is passed
            positionally, or if ``ignore_lst`` names an argument that the
            function does not define.
        TypeError
            If ``kwargs`` contains a name the function does not accept.
    """
    args = list(args)
    if kwargs is None:
        # A fresh dict per call: a shared default would be handed back to
        # the caller in the non-inspectable branch below.
        kwargs = {}
    if isinstance(ignore_lst, str):
        # Catch a common mistake
        raise ValueError(
            'ignore_lst must be a list of parameters to ignore '
            '%s (type %s) was given' % (ignore_lst, type(ignore_lst)))
    # Special case for functools.partial objects
    if (not inspect.ismethod(func) and not inspect.isfunction(func)):
        if ignore_lst:
            warnings.warn('Cannot inspect object %s, ignore list will '
                          'not work.' % func, stacklevel=2)
        return {'*': args, '**': kwargs}

    arg_sig = inspect.signature(func)
    positional_names = []
    kwonly_names = []
    # Defaults are keyed by parameter name: a positional lookup breaks as
    # soon as a keyword-only parameter without default follows a
    # parameter that has one.
    arg_defaults = {}
    arg_varargs = None
    arg_varkw = None
    # NOTE: positional-only parameters have no branch here and are
    # therefore not tracked by name.
    for param in arg_sig.parameters.values():
        if param.kind is param.POSITIONAL_OR_KEYWORD:
            positional_names.append(param.name)
        elif param.kind is param.KEYWORD_ONLY:
            kwonly_names.append(param.name)
        elif param.kind is param.VAR_POSITIONAL:
            arg_varargs = param.name
        elif param.kind is param.VAR_KEYWORD:
            arg_varkw = param.name
        if param.default is not param.empty:
            arg_defaults[param.name] = param.default
    if inspect.ismethod(func):
        # First argument is 'self', it has been removed by Python
        # we need to add it back:
        args = [func.__self__, ] + args
        # func is an instance method, inspect.signature(func) does not
        # include self, we need to fetch it from the class method, i.e
        # func.__func__
        class_method_sig = inspect.signature(func.__func__)
        self_name = next(iter(class_method_sig.parameters))
        positional_names = [self_name] + positional_names
    # XXX: Maybe I need an inspect.isbuiltin to detect C-level methods, such
    # as on ndarrays.

    def _call_summary():
        # The name is only needed for error messages, so it is resolved
        # lazily rather than on every call.
        _, name = get_func_name(func, resolv_alias=False)
        return (name,
                _signature_str(name, arg_sig),
                _function_called_str(name, args, kwargs))

    def _keyword_or_default(arg_name):
        # Value for a parameter that was not supplied positionally.
        if arg_name in kwargs:
            return kwargs[arg_name]
        try:
            return arg_defaults[arg_name]
        except KeyError as e:
            # Missing argument
            _, signature, call = _call_summary()
            raise ValueError(
                'Wrong number of arguments for %s:\n'
                '     %s was called.'
                % (signature, call)
            ) from e

    arg_dict = dict()
    for position, arg_name in enumerate(positional_names):
        if position < len(args):
            # Positional argument or keyword argument given as positional
            arg_dict[arg_name] = args[position]
        else:
            arg_dict[arg_name] = _keyword_or_default(arg_name)

    # Whatever is left after filling the named positional parameters.
    extra_args = args[len(positional_names):]
    if extra_args and arg_varargs is None and kwonly_names:
        # Without *args, the surplus could only land on a keyword-only
        # parameter, which is not allowed.
        _, signature, call = _call_summary()
        raise ValueError(
            "Keyword-only parameter '%s' was passed as "
            'positional parameter for %s:\n'
            '     %s was called.'
            % (kwonly_names[0], signature, call)
        )

    for arg_name in kwonly_names:
        arg_dict[arg_name] = _keyword_or_default(arg_name)

    varkwargs = dict()
    for arg_name, arg_value in sorted(kwargs.items()):
        if arg_name in arg_dict:
            arg_dict[arg_name] = arg_value
        elif arg_varkw is not None:
            varkwargs[arg_name] = arg_value
        else:
            raise TypeError("Ignore list for %s() contains an unexpected "
                            "keyword argument '%s'"
                            % (_call_summary()[0], arg_name))

    if arg_varkw is not None:
        arg_dict['**'] = varkwargs
    if arg_varargs is not None:
        arg_dict['*'] = extra_args

    # Now remove the arguments to be ignored
    for item in ignore_lst:
        if item in arg_dict:
            arg_dict.pop(item)
        else:
            raise ValueError("Ignore list: argument '%s' is not defined for "
                             "function %s"
                             % (item, _call_summary()[1])
                             )
    # XXX: Return a sorted list of pairs?
    return arg_dict
323
324
def _format_arg(arg):
    """Pretty-print ``arg`` with ``pformat`` (indent of 2).

    A rendering longer than 1500 characters is cut down to its first 700
    characters followed by '...'.
    """
    text = pformat(arg, indent=2)
    return text if len(text) <= 1500 else '%s...' % text[:700]
330
331
def format_signature(func, *args, **kwargs):
    """Build a printable description of calling ``func`` with the given
    arguments.

    Returns
    -------
    module_path: str
        The non-empty module components and the function name joined by
        '.', or just the function name when there is no module component.
    signature: str
        The function name followed by the formatted arguments in
        parentheses.
    """
    # XXX: Should this use inspect.formatargvalues/formatargspec?
    module, name = get_func_name(func)
    path_parts = [part for part in module if part]
    module_path = '.'.join(path_parts + [name]) if path_parts else name

    rendered = []
    last_length = 0
    for positional in args:
        text = _format_arg(positional)
        # Start a new line after a long argument.
        if last_length > 80:
            text = '\n%s' % text
        last_length = len(text)
        rendered.append(text)
    for key, value in kwargs.items():
        rendered.append('%s=%s' % (key, _format_arg(value)))

    signature = '%s(%s)' % (name, ', '.join(rendered))
    return module_path, signature
354
355
def format_call(func, args, kwargs, object_name="Memory"):
    """ Returns a nicely formatted statement displaying the function
        call with the given arguments.

        The message is a line of 80 underscores, then a
        '[object_name] Calling path...' line, then the call signature.
    """
    path, signature = format_signature(func, *args, **kwargs)
    separator = 80 * '_'
    # XXX: Not using logging framework
    return '%s\n[%s] Calling %s...\n%s' % (
        separator, object_name, path, signature)
366