1"""
2This module is responsible for inferring *args and **kwargs for signatures.
3
For example, in this case::
5
6    def foo(a, b, c): ...
7
8    def bar(*args):
9        return foo(1, *args)
10
Here the inferred signature of ``bar`` should be ``bar(b, c)`` instead of
``bar(*args)``, because ``foo``'s first parameter ``a`` is already filled by ``1``.
12"""
13from inspect import Parameter
14
15from jedi.inference.utils import to_list
16from jedi.inference.names import ParamNameWrapper
17from jedi.inference.helpers import is_big_annoying_library
18
19
def _iter_nodes_for_param(param_name):
    """Yield ``(value, arguments)`` pairs for calls that forward *param_name*.

    *param_name* is a star parameter (``*args`` or ``**kwargs``). The body of
    the function that declares it is scanned for call arguments written as
    ``*name`` / ``**name``, where the number of stars is
    ``param_name.star_count``, and whose name resolves back to *param_name*.
    For every such call, each inferred value of the callee is yielded
    together with the call's ``TreeArguments``.
    """
    from parso.python.tree import search_ancestor
    from jedi.inference.arguments import TreeArguments

    execution_context = param_name.parent_context
    # Find the defining function from the parameter's own tree name, so the
    # lookup does not depend on which kind of context the parameter was
    # inferred in. Fall back to the context's node when there is no tree name.
    tree_name = getattr(param_name, 'tree_name', None)
    if tree_name is not None:
        function_node = search_ancestor(tree_name, 'funcdef', 'lambdef')
    else:
        function_node = execution_context.tree_node
    if function_node is None:
        return

    module_node = function_node.get_root_node()
    body = function_node.children[-1]
    start, end = body.start_pos, body.end_pos
    star_prefix = '*' * param_name.star_count
    # Default to an empty tuple so that a name missing from the mapping
    # iterates nothing instead of iterating ``None``.
    for name in module_node.get_used_names().get(param_name.string_name, ()):
        if not start <= name.start_pos < end:
            continue  # Not used inside this function's body.
        argument = name.parent
        if not (argument.type == 'argument'
                and argument.children[0] == star_prefix):
            continue  # Not a ``*name`` / ``**name`` call argument.
        trailer = search_ancestor(argument, 'trailer')
        if trailer is None:
            continue  # No enclosing trailer, so this is not a call argument.
        context = execution_context.create_context(trailer)
        if not _goes_to_param_name(param_name, context, name):
            continue  # Same spelling, but a different (e.g. shadowing) name.
        values = _to_callables(context, trailer)

        args = TreeArguments.create_cached(
            execution_context.inference_state,
            context=context,
            argument_node=trailer.children[1],
            trailer=trailer,
        )
        for value in values:
            yield value, args
49
50
51def _goes_to_param_name(param_name, context, potential_name):
52    if potential_name.type != 'name':
53        return False
54    from jedi.inference.names import TreeNameDefinition
55    found = TreeNameDefinition(context, potential_name).goto()
56    return any(param_name.parent_context == p.parent_context
57               and param_name.start_pos == p.start_pos
58               for p in found)
59
60
def _to_callables(context, trailer):
    """Infer the values that the call *trailer* is applied to.

    The atom of the enclosing expression is inferred first (a leading
    ``await`` is skipped). Then every trailer that precedes *trailer* is
    applied in order, stopping right before *trailer* itself.
    """
    from jedi.inference.syntax_tree import infer_trailer

    parts = trailer.parent.children
    # ``await`` is a leading keyword, not the atom itself.
    atom_index = 1 if parts[0] == 'await' else 0
    values = context.infer_node(parts[atom_index])
    for previous in parts[atom_index + 1:]:
        if trailer == previous:
            break
        values = infer_trailer(context, values, previous)
    return values
73
74
75def _remove_given_params(arguments, param_names):
76    count = 0
77    used_keys = set()
78    for key, _ in arguments.unpack():
79        if key is None:
80            count += 1
81        else:
82            used_keys.add(key)
83
84    for p in param_names:
85        if count and p.maybe_positional_argument():
86            count -= 1
87            continue
88        if p.string_name in used_keys and p.maybe_keyword_argument():
89            continue
90        yield p
91
92
@to_list
def process_params(param_names, star_count=3):  # default means both * and **
    """Replace ``*args``/``**kwargs`` with the parameters they are forwarded to.

    For each star parameter in *param_names*, the calls in the function body
    that pass it on (``foo(*args)`` / ``foo(**kwargs)``) are looked up. The
    callee's own parameters, minus those the call already supplies, are then
    spliced into the result, recursively. When no callee signature is found,
    the original star parameter is kept.

    :param param_names: Parameter names of one signature, in order.
    :param star_count: Bit mask of what to handle. ``1`` keeps positional
        parameters and resolves ``*args``; ``2`` keeps keyword parameters and
        resolves ``**kwargs``; ``3`` (the default) does both. With ``1``,
        positional-or-keyword parameters are reported as positional-only;
        with ``2`` they are reported as keyword-only.
    :returns: The resulting parameter names (the generator is wrapped by
        ``to_list``).
    """
    if param_names:
        if is_big_annoying_library(param_names[0].parent_context):
            # At first this feature can look innocent, but it does a lot of
            # type inference in some cases, so we just ditch it.
            yield from param_names
            return

    # Names already emitted that a later keyword-only name must not repeat.
    used_names = set()
    arg_callables = []
    kwarg_callables = []

    kw_only_names = []
    kwarg_names = []
    arg_names = []
    original_arg_name = None
    original_kwarg_name = None
    for p in param_names:
        kind = p.get_kind()
        if kind == Parameter.VAR_POSITIONAL:
            if star_count & 1:
                # Left lazy: it is only consumed once, by the *args loop below.
                arg_callables = _iter_nodes_for_param(p)
                original_arg_name = p
        elif kind == Parameter.VAR_KEYWORD:
            if star_count & 2:
                # Materialized: the *args loop below tests membership in it
                # and removes the entries it handles itself.
                kwarg_callables = list(_iter_nodes_for_param(p))
                original_kwarg_name = p
        elif kind == Parameter.KEYWORD_ONLY:
            if star_count & 2:
                kw_only_names.append(p)
        elif kind == Parameter.POSITIONAL_ONLY:
            if star_count & 1:
                yield p
        else:
            if star_count == 1:
                yield ParamNameFixedKind(p, Parameter.POSITIONAL_ONLY)
            elif star_count == 2:
                kw_only_names.append(ParamNameFixedKind(p, Parameter.KEYWORD_ONLY))
            else:
                used_names.add(p.string_name)
                yield p

    # First process *args
    longest_param_names = ()
    found_arg_signature = False
    found_kwarg_signature = False
    for func_and_argument in arg_callables:
        func, arguments = func_and_argument
        new_star_count = star_count
        if func_and_argument in kwarg_callables:
            # This call was also found for **kwargs (presumably it forwards
            # both stars); resolve it here and drop it from the **kwargs pass.
            kwarg_callables.remove(func_and_argument)
        else:
            new_star_count = 1

        for signature in func.get_signatures():
            found_arg_signature = True
            if new_star_count == 3:
                found_kwarg_signature = True
            args_for_this_func = []
            # NOTE(review): a function forwarding *args to itself would
            # presumably re-enter here repeatedly; confirm a guard exists.
            for p in process_params(
                    list(_remove_given_params(
                        arguments,
                        signature.get_param_names(resolve_stars=False)
                    )), new_star_count):
                p_kind = p.get_kind()
                if p_kind == Parameter.VAR_KEYWORD:
                    kwarg_names.append(p)
                elif p_kind == Parameter.VAR_POSITIONAL:
                    arg_names.append(p)
                elif p_kind == Parameter.KEYWORD_ONLY:
                    kw_only_names.append(p)
                else:
                    args_for_this_func.append(p)
            # Only one callee's positional parameters are kept: the longest.
            if len(args_for_this_func) > len(longest_param_names):
                longest_param_names = args_for_this_func

    for p in longest_param_names:
        p_kind = p.get_kind()
        if star_count == 1 and p_kind != Parameter.VAR_POSITIONAL:
            yield ParamNameFixedKind(p, Parameter.POSITIONAL_ONLY)
        else:
            if p_kind == Parameter.POSITIONAL_OR_KEYWORD:
                used_names.add(p.string_name)
            yield p

    # Keep the original *args if nothing it forwards to could be resolved.
    if not found_arg_signature and original_arg_name is not None:
        yield original_arg_name
    elif arg_names:
        yield arg_names[0]

    # Then process **kwargs
    for func, arguments in kwarg_callables:
        for signature in func.get_signatures():
            found_kwarg_signature = True
            for p in process_params(
                    list(_remove_given_params(
                        arguments,
                        signature.get_param_names(resolve_stars=False)
                    )), star_count=2):
                p_kind = p.get_kind()
                if p_kind == Parameter.VAR_KEYWORD:
                    kwarg_names.append(p)
                elif p_kind == Parameter.KEYWORD_ONLY:
                    kw_only_names.append(p)

    # Skip keyword-only names already emitted, either as positional-or-keyword
    # parameters or earlier in this loop.
    for p in kw_only_names:
        if p.string_name in used_names:
            continue
        yield p
        used_names.add(p.string_name)

    # Keep the original **kwargs if nothing it forwards to could be resolved.
    if not found_kwarg_signature and original_kwarg_name is not None:
        yield original_kwarg_name
    elif kwarg_names:
        yield kwarg_names[0]
206
207
class ParamNameFixedKind(ParamNameWrapper):
    """A parameter name whose reported kind is pinned to a given value.

    Only :meth:`get_kind` is overridden here; everything else comes from
    :class:`ParamNameWrapper` around *param_name*.
    """

    def __init__(self, param_name, new_kind):
        super().__init__(param_name)
        # The kind reported instead of the wrapped name's own kind.
        self._new_kind = new_kind

    def get_kind(self):
        """Return the pinned kind rather than the wrapped name's kind."""
        fixed_kind = self._new_kind
        return fixed_kind
215