1"""This module can be used for finding similar code"""
2import re
3
4import rope.refactor.wildcards
5from rope.base import libutils
6from rope.base import codeanalyze, exceptions, ast, builtins
7from rope.refactor import (patchedast, wildcards)
8
9from rope.refactor.patchedast import MismatchedTokenError
10
11
class BadNameInCheckError(exceptions.RefactoringError):
    """Refactoring error for a bad name in a check

    NOTE(review): nothing in this part of the file raises it; confirm
    the exact trigger against its callers.

    """
14
15
class SimilarFinder(object):
    """`SimilarFinder` can be used to find similar pieces of code

    See the notes in the `rope.refactor.restructure` module for more
    info.

    """

    def __init__(self, pymodule, wildcards=None):
        """Construct a SimilarFinder

        `pymodule` supplies the source code and the AST that are
        searched.  `wildcards` is an optional dict that maps wildcard
        kind names to wildcard objects; when it is omitted only a
        `DefaultWildcard` is registered, under the name returned by
        its `get_name()`.

        """
        self.source = pymodule.source_code
        # `_does_match` reads `self.args`; give it a value even before
        # the first `get_matches()` call.
        self.args = {}
        try:
            self.raw_finder = RawSimilarFinder(
                pymodule.source_code, pymodule.get_ast(), self._does_match)
        except MismatchedTokenError:
            # Report which file could not be patched, then propagate.
            print("in file %s" % pymodule.resource.path)
            raise
        self.pymodule = pymodule
        if wildcards is None:
            self.wildcards = {}
            for wildcard in [rope.refactor.wildcards.
                             DefaultWildcard(pymodule.pycore.project)]:
                self.wildcards[wildcard.get_name()] = wildcard
        else:
            self.wildcards = wildcards

    def get_matches(self, code, args=None, start=0, end=None):
        """Return the matches of `code` between `start` and `end`

        `args` maps wildcard names to their arguments (see
        `_does_match`); it defaults to a fresh empty dict.  The entry
        stored under the empty name ``''`` may hold a ``'skip'`` item
        of the form ``(resource, region)``: when `resource` is the
        resource of this module, `region` is handed to the raw finder
        as the region to skip.  `end` defaults to the end of the
        source.

        """
        # A ``{}`` default would be one dict shared by every call and
        # every finder (it is stored on `self`); build a new one instead.
        if args is None:
            args = {}
        self.args = args
        if end is None:
            end = len(self.source)
        skip_region = None
        if 'skip' in args.get('', {}):
            resource, region = args['']['skip']
            if resource == self.pymodule.get_resource():
                skip_region = region
        return self.raw_finder.get_matches(code, start=start, end=end,
                                           skip=skip_region)

    def get_match_regions(self, *args, **kwds):
        """Yield the region of each match found by `get_matches()`

        All arguments are forwarded to `get_matches()` unchanged.

        """
        for match in self.get_matches(*args, **kwds):
            yield match.get_region()

    def _does_match(self, node, name):
        """Tell whether `node` can be bound to the wildcard `name`

        The argument registered for `name` is either a plain value,
        which is checked by the ``'default'`` wildcard, or a
        ``(kind, argument)`` tuple or list that selects the wildcard
        to use.

        """
        arg = self.args.get(name, '')
        kind = 'default'
        if isinstance(arg, (tuple, list)):
            kind = arg[0]
            arg = arg[1]
        suspect = wildcards.Suspect(self.pymodule, node, name)
        return self.wildcards[kind].matches(suspect, arg)
66
67
class RawSimilarFinder(object):
    """A class for finding similar expressions and statements"""

    def __init__(self, source, node=None, does_match=None):
        """Prepare a finder for `source`

        `node` is the AST of `source`; `source` is parsed when it is
        omitted.  `does_match` is a ``does_match(node, name)`` callable
        that decides whether `node` may be bound to the wildcard
        `name`; by default every expression node is accepted.

        """
        if node is None:
            node = ast.parse(source)
        if does_match is None:
            self.does_match = self._simple_does_match
        else:
            self.does_match = does_match
        self._init_using_ast(node, source)

    def _simple_does_match(self, node, name):
        """Default wildcard check: accept any expression node"""
        return isinstance(node, (ast.expr, ast.Name))

    def _init_using_ast(self, node, source):
        """Store `source` and its AST, patching the AST if needed"""
        self.source = source
        # Cache of the matches already computed, keyed by pattern code.
        self._matched_asts = {}
        # Only patch ASTs whose root does not carry a `region` yet.
        if not hasattr(node, 'region'):
            patchedast.patch_ast(node, source)
        self.ast = node

    def get_matches(self, code, start=0, end=None, skip=None):
        """Search for `code` in source and yield the `Match` objects

        `code` can contain wildcards.  ``${name}`` matches normal
        names and ``${?name}`` can match any expression.  You can use
        `Match.get_ast()` for getting the node that has matched a
        given pattern.

        Only matches that lie completely between `start` and `end`
        (the end of the source by default) are produced.  `skip` is an
        optional ``(start, end)`` region; matches overlapping it are
        left out.

        The matches of each `code` are computed once and cached, so
        the `does_match` callback is consulted only the first time a
        given pattern is searched.

        """
        if end is None:
            end = len(self.source)
        for match in self._get_matched_asts(code):
            match_start, match_end = match.get_region()
            if start <= match_start and match_end <= end:
                if skip is not None and (skip[0] < match_end and
                                         skip[1] > match_start):
                    continue
                yield match

    def _get_matched_asts(self, code):
        """Return all matches of `code`, computing them on first use"""
        if code not in self._matched_asts:
            wanted = self._create_pattern(code)
            matches = _ASTMatcher(self.ast, wanted,
                                  self.does_match).find_matches()
            self._matched_asts[code] = matches
        return self._matched_asts[code]

    def _create_pattern(self, expression):
        """Turn pattern code into an AST node or a list of statements"""
        expression = self._replace_wildcards(expression)
        node = ast.parse(expression)
        # The statements of the parsed pattern
        nodes = node.body
        if len(nodes) == 1 and isinstance(nodes[0], ast.Expr):
            # A lone expression statement is unwrapped so that the
            # pattern is matched as an expression.
            wanted = nodes[0].value
        else:
            wanted = nodes
        return wanted

    def _replace_wildcards(self, expression):
        """Replace each ``${...}`` wildcard by its rope variable name"""
        ropevar = _RopeVariable()
        template = CodeTemplate(expression)
        mapping = {}
        for name in template.get_names():
            mapping[name] = ropevar.get_var(name)
        return template.substitute(mapping)
136
137
class _ASTMatcher(object):
    """Collects the places of an AST where a pattern matches"""

    def __init__(self, body, pattern, does_match):
        """Searches the given pattern in the body AST.

        body is an AST node and pattern can be either an AST node or
        a list of ASTs nodes.  does_match is called as
        ``does_match(node, name)`` to decide whether a node may be
        bound to the wildcard called name.
        """
        self.matches_callback = does_match
        self.ropevar = _RopeVariable()
        self.matches = None
        self.pattern = pattern
        self.body = body

    def find_matches(self):
        """Return the list of matches, walking the body only once"""
        if self.matches is not None:
            return self.matches
        self.matches = []
        ast.call_for_nodes(self.body, self._check_node, recursive=True)
        return self.matches

    def _check_node(self, node):
        # A list pattern is a run of statements; anything else is a
        # single expression node.
        if isinstance(self.pattern, list):
            checker = self._check_statements
        else:
            checker = self._check_expression
        checker(node)

    def _check_expression(self, node):
        bindings = {}
        if not self._match_nodes(self.pattern, node, bindings):
            return
        self.matches.append(ExpressionMatch(node, bindings))

    def _check_statements(self, node):
        sequences = (child for child in ast.get_children(node)
                     if isinstance(child, (list, tuple)))
        for stmts in sequences:
            self.__check_stmt_list(stmts)

    def __check_stmt_list(self, nodes):
        """Try the pattern on every window of `nodes` that can hold it"""
        width = len(self.pattern)
        for index, _ in enumerate(nodes):
            if len(nodes) - index < width:
                # Too few statements are left, here and further on.
                break
            window = nodes[index:index + width]
            bindings = {}
            if self._match_stmts(window, bindings):
                self.matches.append(StatementMatch(window, bindings))

    def _match_nodes(self, expected, node, mapping):
        """Tell whether `node` has the shape described by `expected`

        Wildcards bound while matching are recorded in `mapping`.
        """
        if isinstance(expected, ast.Name) and \
           self.ropevar.is_var(expected.id):
            return self._match_wildcard(expected, node, mapping)
        if not isinstance(expected, ast.AST):
            return expected == node
        if expected.__class__ != node.__class__:
            return False

        wanted_children = self._get_children(expected)
        actual_children = self._get_children(node)
        if len(wanted_children) != len(actual_children):
            return False
        for wanted, actual in zip(wanted_children, actual_children):
            if isinstance(wanted, ast.AST):
                matched = self._match_nodes(wanted, actual, mapping)
            elif isinstance(wanted, (list, tuple)):
                matched = self._match_sequences(wanted, actual, mapping)
            else:
                # Plain values must agree in both type and value.
                matched = type(wanted) is type(actual) and \
                    not (wanted != actual)
            if not matched:
                return False
        return True

    def _match_sequences(self, wanted, actual, mapping):
        """Match two child sequences item by item"""
        if not isinstance(actual, (list, tuple)):
            return False
        if len(wanted) != len(actual):
            return False
        return all(self._match_nodes(want, got, mapping)
                   for want, got in zip(wanted, actual))

    def _get_children(self, node):
        """Return not `ast.expr_context` children of `node`"""
        return [item for item in ast.get_children(node)
                if not isinstance(item, ast.expr_context)]

    def _match_stmts(self, current_stmts, mapping):
        """Match `current_stmts` against the statements of the pattern"""
        if len(current_stmts) != len(self.pattern):
            return False
        return all(self._match_nodes(expected, stmt, mapping)
                   for stmt, expected in zip(current_stmts, self.pattern))

    def _match_wildcard(self, node1, node2, mapping):
        """Bind the wildcard `node1` to `node2` or check an old binding"""
        name = self.ropevar.get_base(node1.id)
        if name in mapping:
            # Already bound: the new node must look like the bound one.
            return self._match_nodes(mapping[name], node2, {})
        if not self.matches_callback(node2, name):
            return False
        mapping[name] = node2
        return True
234
235
class Match(object):
    """Base class of the matches; holds the wildcard bindings"""

    def __init__(self, mapping):
        # Maps wildcard names to the AST nodes they have matched.
        self.mapping = mapping

    def get_region(self):
        """Returns match region"""
        return None

    def get_ast(self, name):
        """Return the ast node that has matched rope variables"""
        return self.mapping.get(name)
247
248
class ExpressionMatch(Match):
    """A match that consists of a single AST node"""

    def __init__(self, ast, mapping):
        self.ast = ast
        super(ExpressionMatch, self).__init__(mapping)

    def get_region(self):
        """Return the `region` attribute of the matched node"""
        matched_node = self.ast
        return matched_node.region
257
258
class StatementMatch(Match):
    """A match that consists of a run of statement nodes"""

    def __init__(self, ast_list, mapping):
        self.ast_list = ast_list
        super(StatementMatch, self).__init__(mapping)

    def get_region(self):
        """Return the span from the first statement to the last one"""
        first, last = self.ast_list[0], self.ast_list[-1]
        return first.region[0], last.region[1]
267
268
class CodeTemplate(object):
    """A piece of code that contains ``${name}`` placeholders"""

    # Compiled on first use by `_get_pattern()` and then reused.
    _match_pattern = None

    def __init__(self, template):
        self.template = template
        self._find_names()

    def _find_names(self):
        """Record the ``(start, end)`` span of every placeholder"""
        self.names = {}
        for found in CodeTemplate._get_pattern().finditer(self.template):
            # Matches of the other alternatives leave `name` unset.
            if found.groupdict().get('name') is None:
                continue
            span = found.span('name')
            # Drop the leading ``${`` and the trailing ``}``.
            name = self.template[span[0] + 2:span[1] - 1]
            self.names.setdefault(name, []).append(span)

    def get_names(self):
        """Return the names of the placeholders found in the template"""
        return self.names.keys()

    def substitute(self, mapping):
        """Return the template with each placeholder replaced

        `mapping` gives the replacement text of each placeholder
        name.
        """
        collector = codeanalyze.ChangeCollector(self.template)
        for name, occurrences in self.names.items():
            for start, end in occurrences:
                collector.add_change(start, end, mapping[name])
        changed = collector.get_changed()
        return self.template if changed is None else changed

    @classmethod
    def _get_pattern(cls):
        """Return the compiled placeholder pattern, building it once"""
        if cls._match_pattern is None:
            alternatives = (
                codeanalyze.get_comment_pattern(),
                codeanalyze.get_string_pattern(),
                r'(?P<name>\$\{[^\s\$\}]*\})',
            )
            cls._match_pattern = re.compile('|'.join(alternatives))
        return cls._match_pattern
309
310
class _RopeVariable(object):
    """Transform and identify rope inserted wildcards"""

    _normal_prefix = '__rope__variable_normal_'
    _any_prefix = '__rope__variable_any_'

    def get_var(self, name):
        """Return the variable name that stands for wildcard `name`"""
        build = self._get_any if name.startswith('?') else self._get_normal
        return build(name)

    def is_var(self, name):
        """Tell whether `name` was produced by `get_var()`"""
        return any(check(name) for check in (self._is_normal, self._is_var))

    def get_base(self, name):
        """Return the wildcard name behind the variable `name`

        The result is `None` when `name` has neither prefix.
        """
        for prefix, marker in ((self._normal_prefix, ''),
                               (self._any_prefix, '?')):
            if name.startswith(prefix):
                return marker + name[len(prefix):]
        return None

    def _get_normal(self, name):
        return '%s%s' % (self._normal_prefix, name)

    def _get_any(self, name):
        # The leading ``?`` is dropped; the prefix carries that meaning.
        return '%s%s' % (self._any_prefix, name[1:])

    def _is_normal(self, name):
        return name.startswith(self._normal_prefix)

    def _is_var(self, name):
        return name.startswith(self._any_prefix)
343
344
def make_pattern(code, variables):
    """Return `code` with the given names turned into wildcards

    Every name node in `code` whose id is one of `variables` is
    replaced by ``${name}``.  `code` itself is returned when nothing
    has been replaced.
    """
    unique_names = set(variables)
    collector = codeanalyze.ChangeCollector(code)

    def does_match(node, name):
        # Bind a wildcard only to a name node with that very id.
        return isinstance(node, ast.Name) and node.id == name

    finder = RawSimilarFinder(code, does_match=does_match)
    for variable in unique_names:
        placeholder = '${%s}' % variable
        for match in finder.get_matches(placeholder):
            start, end = match.get_region()
            collector.add_change(start, end, placeholder)
    changed = collector.get_changed()
    if changed is None:
        return code
    return changed
358
359
def _pydefined_to_str(pydefined):
    """Return the dotted name of `pydefined`

    Builtin classes and functions are reported under
    ``__builtins__``.  Anything else is named by the module name of
    its outermost parent followed by the names on the way down to it.
    """
    if isinstance(pydefined,
                  (builtins.BuiltinClass, builtins.BuiltinFunction)):
        return '__builtins__.' + pydefined.get_name()

    # Walk up to the object that has no parent, collecting the names
    # innermost first.
    names = []
    while pydefined.parent is not None:
        names.append(pydefined.get_name())
        pydefined = pydefined.parent
    names.reverse()
    module_name = libutils.modname(pydefined.resource)
    return '.'.join(module_name.split('.') + names)
371