1from __future__ import absolute_import
2
3import re
4import sys
5import copy
6import codecs
7import itertools
8
9from . import TypeSlots
10from .ExprNodes import not_a_constant
11import cython
12cython.declare(UtilityCode=object, EncodedString=object, bytes_literal=object, encoded_string=object,
13               Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
14               UtilNodes=object, _py_int_types=object)
15
# Version dependent integer/string types, used e.g. in isinstance() checks
# on constant results further down in this module.
if sys.version_info[0] >= 3:
    _py_int_types = int
    _py_string_types = (bytes, str)
else:
    # Python 2 has separate 'long' and 'unicode' builtin types.
    _py_int_types = (int, long)
    _py_string_types = (bytes, unicode)
22
23from . import Nodes
24from . import ExprNodes
25from . import PyrexTypes
26from . import Visitor
27from . import Builtin
28from . import UtilNodes
29from . import Options
30
31from .Code import UtilityCode, TempitaUtilityCode
32from .StringEncoding import EncodedString, bytes_literal, encoded_string
33from .Errors import error, warning
34from .ParseTreeTransforms import SkipDeclarations
35
36try:
37    from __builtin__ import reduce
38except ImportError:
39    from functools import reduce
40
41try:
42    from __builtin__ import basestring
43except ImportError:
44    basestring = str # Python 3
45
46
def load_c_utility(name):
    """Return the cached utility code section 'name' from "Optimize.c"."""
    source_file = "Optimize.c"
    return UtilityCode.load_cached(name, source_file)
49
50
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Strip a single coercion wrapper from 'node'.

    Returns 'node.arg' if 'node' is an instance of one of 'coercion_nodes',
    otherwise 'node' itself.  Only one level is unwrapped.
    """
    return node.arg if isinstance(node, coercion_nodes) else node
55
56
def unwrap_node(node):
    """Follow the 'expression' links of ResultRefNodes and return the first
    node that is not a ResultRefNode (possibly 'node' itself).
    """
    current = node
    while True:
        if not isinstance(current, UtilNodes.ResultRefNode):
            return current
        current = current.expression
61
62
def is_common_value(a, b):
    """Check if two nodes refer to the same name or the same attribute chain.

    Both nodes are unwrapped from ResultRefNodes first.  Two NameNodes match
    if their names are equal.  Two AttributeNodes match if the first one is
    not a Python attribute lookup, their objects match recursively and their
    attribute names are equal.  Everything else does not match.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)

    name_type = ExprNodes.NameNode
    if isinstance(left, name_type) and isinstance(right, name_type):
        return left.name == right.name

    attribute_type = ExprNodes.AttributeNode
    if isinstance(left, attribute_type) and isinstance(right, attribute_type):
        if left.is_py_attr:
            return False
        return is_common_value(left.obj, right.obj) and left.attribute == right.attribute

    return False
71
72
def filter_none_node(node):
    """Map nodes whose constant result is None to a plain None.

    Returns None if 'node' is None or if 'node.constant_result' is None,
    otherwise returns 'node' unchanged.
    """
    if node is None:
        return node
    return None if node.constant_result is None else node
77
78
class _YieldNodeCollector(Visitor.TreeVisitor):
    """
    YieldExprNode finder for generator expressions.

    After a traversal, 'yield_nodes' lists the YieldExprNodes found and
    'yield_stat_nodes' maps those that are the direct expression of an
    ExprStatNode to that statement node.
    """
    def __init__(self):
        Visitor.TreeVisitor.__init__(self)
        self.yield_nodes = []
        self.yield_stat_nodes = {}

    # default handler: just descend into the children
    visit_Node = Visitor.TreeVisitor.visitchildren

    def visit_YieldExprNode(self, node):
        # record the yield itself before looking at its children
        self.yield_nodes.append(node)
        self.visitchildren(node)

    def visit_ExprStatNode(self, node):
        # children first, so that a directly contained yield is already known
        self.visitchildren(node)
        expr = node.expr
        if expr in self.yield_nodes:
            self.yield_stat_nodes[expr] = node

    # Nested scopes are not searched: their handlers do not visit any children.

    def visit_GeneratorExpressionNode(self, node):
        return None

    def visit_LambdaNode(self, node):
        return None

    def visit_FuncDefNode(self, node):
        return None
109
110
def _find_single_yield_expression(node):
    """Return the (yield argument, yield statement) pair found below 'node'
    if there is exactly one yield statement, and (None, None) otherwise.
    """
    found = _find_yield_statements(node)
    if len(found) == 1:
        return found[0]
    return None, None
116
117
def _find_yield_statements(node):
    """Collect (yield argument, yield statement) pairs from the children of 'node'.

    Returns an empty list if any collected yield expression is not the
    direct expression of an expression statement.
    """
    collector = _YieldNodeCollector()
    collector.visitchildren(node)
    stat_node_of = collector.yield_stat_nodes
    try:
        return [
            (yield_node.arg, stat_node_of[yield_node])
            for yield_node in collector.yield_nodes
        ]
    except KeyError:
        # found YieldExprNode without ExprStatNode (i.e. a non-statement usage of 'yield')
        return []
130
131
132class IterationTransform(Visitor.EnvTransform):
133    """Transform some common for-in loop patterns into efficient C loops:
134
135    - for-in-dict loop becomes a while loop calling PyDict_Next()
136    - for-in-enumerate is replaced by an external counter variable
137    - for-in-range loop becomes a plain C for loop
138    """
139    def visit_PrimaryCmpNode(self, node):
140        if node.is_ptr_contains():
141
142            # for t in operand2:
143            #     if operand1 == t:
144            #         res = True
145            #         break
146            # else:
147            #     res = False
148
149            pos = node.pos
150            result_ref = UtilNodes.ResultRefNode(node)
151            if node.operand2.is_subscript:
152                base_type = node.operand2.base.type.base_type
153            else:
154                base_type = node.operand2.type.base_type
155            target_handle = UtilNodes.TempHandle(base_type)
156            target = target_handle.ref(pos)
157            cmp_node = ExprNodes.PrimaryCmpNode(
158                pos, operator=u'==', operand1=node.operand1, operand2=target)
159            if_body = Nodes.StatListNode(
160                pos,
161                stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
162                         Nodes.BreakStatNode(pos)])
163            if_node = Nodes.IfStatNode(
164                pos,
165                if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
166                else_clause=None)
167            for_loop = UtilNodes.TempsBlockNode(
168                pos,
169                temps = [target_handle],
170                body = Nodes.ForInStatNode(
171                    pos,
172                    target=target,
173                    iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
174                    body=if_node,
175                    else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
176            for_loop = for_loop.analyse_expressions(self.current_env())
177            for_loop = self.visit(for_loop)
178            new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
179
180            if node.operator == 'not_in':
181                new_node = ExprNodes.NotNode(pos, operand=new_node)
182            return new_node
183
184        else:
185            self.visitchildren(node)
186            return node
187
188    def visit_ForInStatNode(self, node):
189        self.visitchildren(node)
190        return self._optimise_for_loop(node, node.iterator.sequence)
191
    def _optimise_for_loop(self, node, iterable, reversed=False):
        """Dispatch a for-in loop to a specialised transformation based on its iterable.

        The checks are done in this order: dict / set / frozenset typed (or
        'Dict' / 'Set' / 'FrozenSet' annotated) iterables, C pointers and
        arrays, bytes and unicode objects, argument-less dict method calls
        (keys/values/items and their iter*() variants), calls to the builtins
        enumerate() / reversed(), and calls to the builtins range() / xrange().

        'node' is the for-in loop statement and 'iterable' the expression that
        is being iterated over.  'reversed' is True when 'iterable' is the
        argument of a reversed() call (see _transform_reversed_iteration()).

        Returns the result of the selected _transform_*() method, or 'node'
        unchanged if no optimisation applies.
        """
        annotation_type = None
        if (iterable.is_name or iterable.is_attribute) and iterable.entry and iterable.entry.annotation:
            annotation = iterable.entry.annotation
            if annotation.is_subscript:
                annotation = annotation.base  # container base type
            # FIXME: generalise annotation evaluation => maybe provide a "qualified name" also for imported names?
            if annotation.is_name:
                if annotation.entry and annotation.entry.qualified_name == 'typing.Dict':
                    annotation_type = Builtin.dict_type
                elif annotation.name == 'Dict':
                    annotation_type = Builtin.dict_type
                # note: both 'Set' and 'FrozenSet' annotations map to the set type here
                if annotation.entry and annotation.entry.qualified_name in ('typing.Set', 'typing.FrozenSet'):
                    annotation_type = Builtin.set_type
                elif annotation.name in ('Set', 'FrozenSet'):
                    annotation_type = Builtin.set_type

        if Builtin.dict_type in (iterable.type, annotation_type):
            # like iterating over dict.keys()
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            return self._transform_dict_iteration(
                node, dict_obj=iterable, method=None, keys=True, values=False)

        if (Builtin.set_type in (iterable.type, annotation_type) or
                Builtin.frozenset_type in (iterable.type, annotation_type)):
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            return self._transform_set_iteration(node, iterable)

        # C array (slice) iteration?
        if iterable.type.is_ptr or iterable.type.is_array:
            return self._transform_carray_iteration(node, iterable, reversed=reversed)
        if iterable.type is Builtin.bytes_type:
            return self._transform_bytes_iteration(node, iterable, reversed=reversed)
        if iterable.type is Builtin.unicode_type:
            return self._transform_unicode_iteration(node, iterable, reversed=reversed)

        # the rest is based on function calls
        if not isinstance(iterable, ExprNodes.SimpleCallNode):
            return node

        # count the call arguments, leaving out one argument if 'iterable.self' is set
        if iterable.args is None:
            arg_count = iterable.arg_tuple and len(iterable.arg_tuple.args) or 0
        else:
            arg_count = len(iterable.args)
            if arg_count and iterable.self is not None:
                arg_count -= 1

        function = iterable.function
        # dict iteration?
        if function.is_attribute and not reversed and not arg_count:
            base_obj = iterable.self or function.obj
            method = function.attribute
            # in Py3, items() is equivalent to Py2's iteritems()
            is_safe_iter = self.global_scope().context.language_level >= 3

            if not is_safe_iter and method in ('keys', 'values', 'items'):
                # try to reduce this to the corresponding .iter*() methods
                if isinstance(base_obj, ExprNodes.CallNode):
                    inner_function = base_obj.function
                    if (inner_function.is_name and inner_function.name == 'dict'
                            and inner_function.entry
                            and inner_function.entry.is_builtin):
                        # e.g. dict(something).items() => safe to use .iter*()
                        is_safe_iter = True

            # select which parts of the dict items the loop iterates over
            keys = values = False
            if method == 'iterkeys' or (is_safe_iter and method == 'keys'):
                keys = True
            elif method == 'itervalues' or (is_safe_iter and method == 'values'):
                values = True
            elif method == 'iteritems' or (is_safe_iter and method == 'items'):
                keys = values = True

            if keys or values:
                return self._transform_dict_iteration(
                    node, base_obj, method, keys, values)

        # enumerate/reversed ?
        if iterable.self is None and function.is_name and \
               function.entry and function.entry.is_builtin:
            if function.name == 'enumerate':
                if reversed:
                    # CPython raises an error here: not a sequence
                    return node
                return self._transform_enumerate_iteration(node, iterable)
            elif function.name == 'reversed':
                if reversed:
                    # CPython raises an error here: not a sequence
                    return node
                return self._transform_reversed_iteration(node, iterable)

        # range() iteration?
        if Options.convert_range and arg_count >= 1 and (
                iterable.self is None and
                function.is_name and function.name in ('range', 'xrange') and
                function.entry and function.entry.is_builtin):
            if node.target.type.is_int or node.target.type.is_enum:
                return self._transform_range_iteration(node, iterable, reversed=reversed)
            if node.target.type.is_pyobject:
                # Assume that small integer ranges (C long >= 32bit) are best handled in C as well.
                # The for-else only transforms the loop if *all* arguments are
                # integer literals with a constant value in [-2**30, 2**30).
                for arg in (iterable.arg_tuple.args if iterable.args is None else iterable.args):
                    if isinstance(arg, ExprNodes.IntNode):
                        if arg.has_constant_result() and -2**30 <= arg.constant_result < 2**30:
                            continue
                    break
                else:
                    return self._transform_range_iteration(node, iterable, reversed=reversed)

        return node
305
306    def _transform_reversed_iteration(self, node, reversed_function):
307        args = reversed_function.arg_tuple.args
308        if len(args) == 0:
309            error(reversed_function.pos,
310                  "reversed() requires an iterable argument")
311            return node
312        elif len(args) > 1:
313            error(reversed_function.pos,
314                  "reversed() takes exactly 1 argument")
315            return node
316        arg = args[0]
317
318        # reversed(list/tuple) ?
319        if arg.type in (Builtin.tuple_type, Builtin.list_type):
320            node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable")
321            node.iterator.reversed = True
322            return node
323
324        return self._optimise_for_loop(node, arg, reversed=True)
325
    # C signature used for the "PyBytes_AS_STRING" call in _transform_bytes_iteration():
    # takes a bytes object 's' and returns a C char pointer.
    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    # C signature used for the "PyBytes_GET_SIZE" call in _transform_bytes_iteration():
    # takes a bytes object 's' and returns a Py_ssize_t.
    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])
335
336    def _transform_bytes_iteration(self, node, slice_node, reversed=False):
337        target_type = node.target.type
338        if not target_type.is_int and target_type is not Builtin.bytes_type:
339            # bytes iteration returns bytes objects in Py2, but
340            # integers in Py3
341            return node
342
343        unpack_temp_node = UtilNodes.LetRefNode(
344            slice_node.as_none_safe_node("'NoneType' is not iterable"))
345
346        slice_base_node = ExprNodes.PythonCapiCallNode(
347            slice_node.pos, "PyBytes_AS_STRING",
348            self.PyBytes_AS_STRING_func_type,
349            args = [unpack_temp_node],
350            is_temp = 0,
351            )
352        len_node = ExprNodes.PythonCapiCallNode(
353            slice_node.pos, "PyBytes_GET_SIZE",
354            self.PyBytes_GET_SIZE_func_type,
355            args = [unpack_temp_node],
356            is_temp = 0,
357            )
358
359        return UtilNodes.LetNode(
360            unpack_temp_node,
361            self._transform_carray_iteration(
362                node,
363                ExprNodes.SliceIndexNode(
364                    slice_node.pos,
365                    base = slice_base_node,
366                    start = None,
367                    step = None,
368                    stop = len_node,
369                    type = slice_base_node.type,
370                    is_temp = 1,
371                    ),
372                reversed = reversed))
373
    # C signature used for the "__Pyx_PyUnicode_READ" call in _transform_unicode_iteration():
    # (int kind, void* data, Py_ssize_t index) -> Py_UCS4
    PyUnicode_READ_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type, [
            PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None),
            PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None)
        ])

    # C signature used for the "__Pyx_init_unicode_iteration" call in
    # _transform_unicode_iteration(): takes the string object plus pointers
    # to the 'length', 'data' and 'kind' variables, returns an int and
    # declares '-1' as its exception value.
    init_unicode_iteration_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None),
            PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None)
        ],
        exception_value = '-1')
389
    def _transform_unicode_iteration(self, node, slice_node, reversed=False):
        """Turn a for-in loop over a unicode object into an index based C loop.

        String literals that can be encoded as Latin-1 are iterated as a C
        byte array via _transform_carray_iteration().  For everything else,
        "__Pyx_init_unicode_iteration" is called once to set up the length,
        data and kind temps, and the loop body reads each character with
        "__Pyx_PyUnicode_READ" at the current counter position.

        'node' is the for-in loop, 'slice_node' the unicode expression and
        'reversed' selects backwards iteration.  Returns the replacement node.
        """
        if slice_node.is_literal:
            # try to reduce to byte iteration for plain Latin-1 strings
            try:
                bytes_value = bytes_literal(slice_node.value.encode('latin1'), 'iso8859-1')
            except UnicodeEncodeError:
                pass
            else:
                bytes_slice = ExprNodes.SliceIndexNode(
                    slice_node.pos,
                    base=ExprNodes.BytesNode(
                        slice_node.pos, value=bytes_value,
                        constant_result=bytes_value,
                        type=PyrexTypes.c_const_char_ptr_type).coerce_to(
                            PyrexTypes.c_const_uchar_ptr_type, self.current_env()),
                    start=None,
                    stop=ExprNodes.IntNode(
                        slice_node.pos, value=str(len(bytes_value)),
                        constant_result=len(bytes_value),
                        type=PyrexTypes.c_py_ssize_t_type),
                    type=Builtin.unicode_type,  # hint for Python conversion
                )
                return self._transform_carray_iteration(node, bytes_slice, reversed)

        # evaluate the (None-checked) string object only once
        unpack_temp_node = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))

        # loop bounds: 0 and the string length (swapped for reversed iteration)
        start_node = ExprNodes.IntNode(
            node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
        length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        end_node = length_temp.ref(node.pos)
        if reversed:
            relation1, relation2 = '>', '>='
            start_node, end_node = end_node, start_node
        else:
            relation1, relation2 = '<=', '<'

        kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
        data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type)
        counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)

        # the character at the current counter position
        target_value = ExprNodes.PythonCapiCallNode(
            slice_node.pos, "__Pyx_PyUnicode_READ",
            self.PyUnicode_READ_func_type,
            args = [kind_temp.ref(slice_node.pos),
                    data_temp.ref(slice_node.pos),
                    counter_temp.ref(node.target.pos)],
            is_temp = False,
            )
        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_env())
        # assign the character to the loop target before running the original body
        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)
        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        loop_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_node, relation1=relation1,
            target=counter_temp.ref(node.target.pos),
            relation2=relation2, bound2=end_node,
            step=None, body=body,
            else_clause=node.else_clause,
            from_range=True)

        # one-time setup call that receives the addresses of the length, data and kind temps
        setup_node = Nodes.ExprStatNode(
            node.pos,
            expr = ExprNodes.PythonCapiCallNode(
                slice_node.pos, "__Pyx_init_unicode_iteration",
                self.init_unicode_iteration_func_type,
                args = [unpack_temp_node,
                        ExprNodes.AmpersandNode(slice_node.pos, operand=length_temp.ref(slice_node.pos),
                                                type=PyrexTypes.c_py_ssize_t_ptr_type),
                        ExprNodes.AmpersandNode(slice_node.pos, operand=data_temp.ref(slice_node.pos),
                                                type=PyrexTypes.c_void_ptr_ptr_type),
                        ExprNodes.AmpersandNode(slice_node.pos, operand=kind_temp.ref(slice_node.pos),
                                                type=PyrexTypes.c_int_ptr_type),
                        ],
                is_temp = True,
                result_is_used = False,
                utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c"),
                ))
        return UtilNodes.LetNode(
            unpack_temp_node,
            UtilNodes.TempsBlockNode(
                node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp],
                body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node])))
481
    def _transform_carray_iteration(self, node, slice_node, reversed=False):
        """Turn a for-in loop over a C array or pointer slice into a pointer based C loop.

        'slice_node' may be a SliceIndexNode, a subscript with a SliceNode
        index, or an expression of C array type.  The start, stop and step
        values are extracted from it, converted into start/stop pointers, and
        a ForFromStatNode is built that runs a temporary pointer between them
        and assigns the current item to the loop target before the original
        loop body.

        Returns 'node' unchanged if the required bounds or step are not
        known.  In that case, an error is reported, except in the branches
        that skip the error when the sliced base (or 'slice_node' itself)
        is a Python object.
        """
        neg_step = False
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = filter_none_node(slice_node.start)
            stop = filter_none_node(slice_node.stop)
            step = None
            if not stop:
                if not slice_base.type.is_pyobject:
                    error(slice_node.pos, "C array iteration requires known end index")
                return node

        elif slice_node.is_subscript:
            assert isinstance(slice_node.index, ExprNodes.SliceNode)
            slice_base = slice_node.base
            index = slice_node.index
            start = filter_none_node(index.start)
            stop = filter_none_node(index.stop)
            step = filter_none_node(index.step)
            if step:
                # give up on non-constant or zero steps, and on steps that
                # lack the bound on the side they are walking away from
                if not isinstance(step.constant_result, _py_int_types) \
                       or step.constant_result == 0 \
                       or step.constant_result > 0 and not stop \
                       or step.constant_result < 0 and not start:
                    if not slice_base.type.is_pyobject:
                        error(step.pos, "C array iteration requires known step size and end index")
                    return node
                else:
                    # step sign is handled internally by ForFromStatNode
                    step_value = step.constant_result
                    if reversed:
                        step_value = -step_value
                    neg_step = step_value < 0
                    step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                             value=str(abs(step_value)),
                                             constant_result=abs(step_value))

        elif slice_node.type.is_array:
            if slice_node.type.size is None:
                error(slice_node.pos, "C array iteration requires known end index")
                return node
            # iterate over the complete array: [0:size]
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size),
                type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
            step = None

        else:
            if not slice_node.type.is_pyobject:
                error(slice_node.pos, "C array iteration requires known end index")
            return node

        if start:
            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop is None:
            if neg_step:
                # negative step without a stop index: use -1 as the stop index
                stop = ExprNodes.IntNode(
                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
            else:
                error(slice_node.pos, "C array iteration requires known step size and end index")
                return node

        if reversed:
            if not start:
                start = ExprNodes.IntNode(slice_node.pos, value="0",  constant_result=0,
                                          type=PyrexTypes.c_py_ssize_t_type)
            # if step was provided, it was already negated above
            start, stop = stop, start

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_env())

        # start pointer: base + start (or just the base for a missing or constant 0 start)
        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        # stop pointer: base + stop, built on a clone of the base pointer node
        if stop and stop.constant_result != 0:
            stop_ptr_node = ExprNodes.AddNode(
                stop.pos,
                operand1=ExprNodes.CloneNode(carray_ptr),
                operator='+',
                operand2=stop,
                type=ptr_type
                ).coerce_to_simple(self.current_env())
        else:
            stop_ptr_node = ExprNodes.CloneNode(carray_ptr)

        # the loop counter is a temporary pointer into the array
        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes/unicode
            if slice_node.type is Builtin.unicode_type:
                target_value = ExprNodes.CastNode(
                    ExprNodes.DereferenceNode(
                        node.target.pos, operand=counter_temp,
                        type=ptr_type.base_type),
                    PyrexTypes.c_py_ucs4_type).coerce_to(
                        node.target.type, self.current_env())
            else:
                # char* -> bytes coercion requires slicing, not indexing
                target_value = ExprNodes.SliceIndexNode(
                    node.target.pos,
                    start=ExprNodes.IntNode(node.target.pos, value='0',
                                            constant_result=0,
                                            type=PyrexTypes.c_int_type),
                    stop=ExprNodes.IntNode(node.target.pos, value='1',
                                           constant_result=1,
                                           type=PyrexTypes.c_int_type),
                    base=counter_temp,
                    type=Builtin.bytes_type,
                    is_temp=1)
        elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
            # Allow iteration with pointer target to avoid copy.
            target_value = counter_temp
        else:
            # TODO: can this safely be replaced with DereferenceNode() as above?
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_env())

        # assign the current item to the loop target before running the original body
        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed)

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1=relation1,
            target=counter_temp,
            relation2=relation2, bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)
645
    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace a for-in loop over the builtin enumerate() by a loop over
        the plain iterable with an external counter.

        The counter lives in a LetRefNode that is initialised with the start
        argument (or 0).  Two assignments are prepended to the loop body: the
        counter is assigned to the first loop target and is then incremented
        by 1.  The loop itself is rewritten in place to iterate over the first
        enumerate() argument with the second loop target, and is passed back
        to _optimise_for_loop() for further optimisation.

        Reports an error for a wrong argument count.  Returns 'node' unchanged
        in that case, and also if the loop target is not a sequence of exactly
        two targets or if the counter target is neither a Python object nor a
        C integer.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 2:
            error(enumerate_function.pos,
                  "enumerate() takes at most 2 arguments")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        if len(args) == 2:
            # explicit start value: strip one coercion wrapper, then coerce to the counter type
            start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env())
        else:
            start = ExprNodes.IntNode(enumerate_function.pos,
                                      value='0',
                                      type=counter_type,
                                      constant_result=0)
        temp = UtilNodes.LetRefNode(start)

        # counter + 1
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            #inplace = True,   # not worth using in-place operation for Py ints
            is_temp = counter_type.is_pyobject
            )

        # "enumerate_target = counter; counter = counter + 1"
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        # prepend the counter handling to the original loop body
        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        # the loop now iterates directly over the first enumerate() argument
        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_env())
        node.iterator.sequence = args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence))
718
719    def _find_for_from_node_relations(self, neg_step_value, reversed):
720        if reversed:
721            if neg_step_value:
722                return '<', '<='
723            else:
724                return '>', '>='
725        else:
726            if neg_step_value:
727                return '>=', '>'
728            else:
729                return '<=', '<'
730
    def _transform_range_iteration(self, node, range_function, reversed=False):
        """Turn a for-in loop over a range call into a ForFromStatNode.

        node: the for-in loop node; its target, body and else clause are
            reused for the new loop.
        range_function: the call node whose ``arg_tuple.args`` hold the
            1-3 range arguments (stop / start, stop / start, stop, step).
        reversed: true if the range is to be iterated backwards
            (presumably because the call was wrapped in reversed() - the
            caller is outside this view, TODO confirm).

        Returns 'node' unchanged if the step is not a compile time integer
        constant or is zero.  Otherwise returns the new ForFromStatNode,
        wrapped in a LetNode if the stop bound had to be kept in a temp.
        """
        args = range_function.arg_tuple.args
        if len(args) < 3:
            # no explicit step given => step is 1
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1', constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, _py_int_types):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            # replace the user provided step expression by a fresh literal
            step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                     constant_result=step_value)

        if len(args) == 1:
            # only the stop value was given => start at 0
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_env())
        else:
            bound1 = args[0].coerce_to_integer(self.current_env())
            bound2 = args[1].coerce_to_integer(self.current_env())

        # the relations are chosen from the sign of the step and the direction
        relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)

        bound2_ref_node = None
        if reversed:
            # iterate from the other end: the loop now runs bound2 -> bound1
            bound1, bound2 = bound2, bound1
            abs_step = abs(step_value)
            if abs_step != 1:
                # with a step size other than 1, the new start bound must be
                # adjusted by a step dependent offset (formulas below)
                if (isinstance(bound1.constant_result, _py_int_types) and
                        isinstance(bound2.constant_result, _py_int_types)):
                    # calculate final bounds now
                    if step_value < 0:
                        begin_value = bound2.constant_result
                        end_value = bound1.constant_result
                        bound1_value = begin_value - abs_step * ((begin_value - end_value - 1) // abs_step) - 1
                    else:
                        begin_value = bound1.constant_result
                        end_value = bound2.constant_result
                        bound1_value = end_value + abs_step * ((begin_value - end_value - 1) // abs_step) + 1

                    bound1 = ExprNodes.IntNode(
                        bound1.pos, value=str(bound1_value), constant_result=bound1_value,
                        type=PyrexTypes.spanning_type(bound1.type, bound2.type))
                else:
                    # evaluate the same expression as above at runtime
                    # (bound2 is wrapped in a LetRefNode because the
                    # calculation and the loop header both refer to it)
                    bound2_ref_node = UtilNodes.LetRefNode(bound2)
                    bound1 = self._build_range_step_calculation(
                        bound1, bound2_ref_node, step, step_value)

        # the loop node receives the absolute step value; the sign of the
        # step was already taken into account when choosing the relations
        if step_value < 0:
            step_value = -step_value
        step.value = str(step_value)
        step.constant_result = step_value
        step = step.coerce_to_integer(self.current_env())

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            # reuse the reference created for the runtime calculation, if any
            bound2 = bound2_ref_node or UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)
        for_node.set_up_loop(self.current_env())

        if bound2_is_temp:
            # bind the LetRefNode of the stop bound around the whole loop
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node
813
814    def _build_range_step_calculation(self, bound1, bound2_ref_node, step, step_value):
815        abs_step = abs(step_value)
816        spanning_type = PyrexTypes.spanning_type(bound1.type, bound2_ref_node.type)
817        if step.type.is_int and abs_step < 0x7FFF:
818            # Avoid loss of integer precision warnings.
819            spanning_step_type = PyrexTypes.spanning_type(spanning_type, PyrexTypes.c_int_type)
820        else:
821            spanning_step_type = PyrexTypes.spanning_type(spanning_type, step.type)
822        if step_value < 0:
823            begin_value = bound2_ref_node
824            end_value = bound1
825            final_op = '-'
826        else:
827            begin_value = bound1
828            end_value = bound2_ref_node
829            final_op = '+'
830
831        step_calculation_node = ExprNodes.binop_node(
832            bound1.pos,
833            operand1=ExprNodes.binop_node(
834                bound1.pos,
835                operand1=bound2_ref_node,
836                operator=final_op,  # +/-
837                operand2=ExprNodes.MulNode(
838                    bound1.pos,
839                    operand1=ExprNodes.IntNode(
840                        bound1.pos,
841                        value=str(abs_step),
842                        constant_result=abs_step,
843                        type=spanning_step_type),
844                    operator='*',
845                    operand2=ExprNodes.DivNode(
846                        bound1.pos,
847                        operand1=ExprNodes.SubNode(
848                            bound1.pos,
849                            operand1=ExprNodes.SubNode(
850                                bound1.pos,
851                                operand1=begin_value,
852                                operator='-',
853                                operand2=end_value,
854                                type=spanning_type),
855                            operator='-',
856                            operand2=ExprNodes.IntNode(
857                                bound1.pos,
858                                value='1',
859                                constant_result=1),
860                            type=spanning_step_type),
861                        operator='//',
862                        operand2=ExprNodes.IntNode(
863                            bound1.pos,
864                            value=str(abs_step),
865                            constant_result=abs_step,
866                            type=spanning_step_type),
867                        type=spanning_step_type),
868                    type=spanning_step_type),
869                type=spanning_step_type),
870            operator=final_op,  # +/-
871            operand2=ExprNodes.IntNode(
872                bound1.pos,
873                value='1',
874                constant_result=1),
875            type=spanning_type)
876        return step_calculation_node
877
878    def _transform_dict_iteration(self, node, dict_obj, method, keys, values):
879        temps = []
880        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
881        temps.append(temp)
882        dict_temp = temp.ref(dict_obj.pos)
883        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
884        temps.append(temp)
885        pos_temp = temp.ref(node.pos)
886
887        key_target = value_target = tuple_target = None
888        if keys and values:
889            if node.target.is_sequence_constructor:
890                if len(node.target.args) == 2:
891                    key_target, value_target = node.target.args
892                else:
893                    # unusual case that may or may not lead to an error
894                    return node
895            else:
896                tuple_target = node.target
897        elif keys:
898            key_target = node.target
899        else:
900            value_target = node.target
901
902        if isinstance(node.body, Nodes.StatListNode):
903            body = node.body
904        else:
905            body = Nodes.StatListNode(pos = node.body.pos,
906                                      stats = [node.body])
907
908        # keep original length to guard against dict modification
909        dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
910        temps.append(dict_len_temp)
911        dict_len_temp_addr = ExprNodes.AmpersandNode(
912            node.pos, operand=dict_len_temp.ref(dict_obj.pos),
913            type=PyrexTypes.c_ptr_type(dict_len_temp.type))
914        temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
915        temps.append(temp)
916        is_dict_temp = temp.ref(node.pos)
917        is_dict_temp_addr = ExprNodes.AmpersandNode(
918            node.pos, operand=is_dict_temp,
919            type=PyrexTypes.c_ptr_type(temp.type))
920
921        iter_next_node = Nodes.DictIterationNextNode(
922            dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp,
923            key_target, value_target, tuple_target,
924            is_dict_temp)
925        iter_next_node = iter_next_node.analyse_expressions(self.current_env())
926        body.stats[0:0] = [iter_next_node]
927
928        if method:
929            method_node = ExprNodes.StringNode(
930                dict_obj.pos, is_identifier=True, value=method)
931            dict_obj = dict_obj.as_none_safe_node(
932                "'NoneType' object has no attribute '%{0}s'".format('.30' if len(method) <= 30 else ''),
933                error = "PyExc_AttributeError",
934                format_args = [method])
935        else:
936            method_node = ExprNodes.NullNode(dict_obj.pos)
937            dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable")
938
939        def flag_node(value):
940            value = value and 1 or 0
941            return ExprNodes.IntNode(node.pos, value=str(value), constant_result=value)
942
943        result_code = [
944            Nodes.SingleAssignmentNode(
945                node.pos,
946                lhs = pos_temp,
947                rhs = ExprNodes.IntNode(node.pos, value='0',
948                                        constant_result=0)),
949            Nodes.SingleAssignmentNode(
950                dict_obj.pos,
951                lhs = dict_temp,
952                rhs = ExprNodes.PythonCapiCallNode(
953                    dict_obj.pos,
954                    "__Pyx_dict_iterator",
955                    self.PyDict_Iterator_func_type,
956                    utility_code = UtilityCode.load_cached("dict_iter", "Optimize.c"),
957                    args = [dict_obj, flag_node(dict_obj.type is Builtin.dict_type),
958                            method_node, dict_len_temp_addr, is_dict_temp_addr,
959                            ],
960                    is_temp=True,
961                )),
962            Nodes.WhileStatNode(
963                node.pos,
964                condition = None,
965                body = body,
966                else_clause = node.else_clause
967                )
968            ]
969
970        return UtilNodes.TempsBlockNode(
971            node.pos, temps=temps,
972            body=Nodes.StatListNode(
973                node.pos,
974                stats = result_code
975                ))
976
    # C signature used for the "__Pyx_dict_iterator" call that
    # _transform_dict_iteration() generates.  It returns a Python object.
    # Arguments: the iterated object, a C int flag, the iteration method name
    # (or NULL), and two output pointers (Py_ssize_t* / int*).
    PyDict_Iterator_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("is_dict",  PyrexTypes.c_int_type, None),
            PyrexTypes.CFuncTypeArg("method_name",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("p_orig_length",  PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("p_is_dict",  PyrexTypes.c_int_ptr_type, None),
            ])
985
    # C signature used for the "__Pyx_set_iterator" call that
    # _transform_set_iteration() generates.  It returns a Python object.
    # Arguments: the iterated object, a C int flag, and two output pointers
    # (Py_ssize_t* / int*).
    PySet_Iterator_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("set",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("is_set",  PyrexTypes.c_int_type, None),
            PyrexTypes.CFuncTypeArg("p_orig_length",  PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("p_is_set",  PyrexTypes.c_int_ptr_type, None),
            ])
993
994    def _transform_set_iteration(self, node, set_obj):
995        temps = []
996        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
997        temps.append(temp)
998        set_temp = temp.ref(set_obj.pos)
999        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
1000        temps.append(temp)
1001        pos_temp = temp.ref(node.pos)
1002
1003        if isinstance(node.body, Nodes.StatListNode):
1004            body = node.body
1005        else:
1006            body = Nodes.StatListNode(pos = node.body.pos,
1007                                      stats = [node.body])
1008
1009        # keep original length to guard against set modification
1010        set_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
1011        temps.append(set_len_temp)
1012        set_len_temp_addr = ExprNodes.AmpersandNode(
1013            node.pos, operand=set_len_temp.ref(set_obj.pos),
1014            type=PyrexTypes.c_ptr_type(set_len_temp.type))
1015        temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
1016        temps.append(temp)
1017        is_set_temp = temp.ref(node.pos)
1018        is_set_temp_addr = ExprNodes.AmpersandNode(
1019            node.pos, operand=is_set_temp,
1020            type=PyrexTypes.c_ptr_type(temp.type))
1021
1022        value_target = node.target
1023        iter_next_node = Nodes.SetIterationNextNode(
1024            set_temp, set_len_temp.ref(set_obj.pos), pos_temp, value_target, is_set_temp)
1025        iter_next_node = iter_next_node.analyse_expressions(self.current_env())
1026        body.stats[0:0] = [iter_next_node]
1027
1028        def flag_node(value):
1029            value = value and 1 or 0
1030            return ExprNodes.IntNode(node.pos, value=str(value), constant_result=value)
1031
1032        result_code = [
1033            Nodes.SingleAssignmentNode(
1034                node.pos,
1035                lhs=pos_temp,
1036                rhs=ExprNodes.IntNode(node.pos, value='0', constant_result=0)),
1037            Nodes.SingleAssignmentNode(
1038                set_obj.pos,
1039                lhs=set_temp,
1040                rhs=ExprNodes.PythonCapiCallNode(
1041                    set_obj.pos,
1042                    "__Pyx_set_iterator",
1043                    self.PySet_Iterator_func_type,
1044                    utility_code=UtilityCode.load_cached("set_iter", "Optimize.c"),
1045                    args=[set_obj, flag_node(set_obj.type is Builtin.set_type),
1046                          set_len_temp_addr, is_set_temp_addr,
1047                          ],
1048                    is_temp=True,
1049                )),
1050            Nodes.WhileStatNode(
1051                node.pos,
1052                condition=None,
1053                body=body,
1054                else_clause=node.else_clause,
1055                )
1056            ]
1057
1058        return UtilNodes.TempsBlockNode(
1059            node.pos, temps=temps,
1060            body=Nodes.StatListNode(
1061                node.pos,
1062                stats = result_code
1063                ))
1064
1065
class SwitchTransform(Visitor.EnvTransform):
    """
    Rewrites chains of integer comparisons as C switch statements.

    A condition qualifies if each clause is a "var == value" test (or an
    'or' of such tests), all clauses test the same variable, and both the
    variable and the values have an int or enum type.
    """
    # result of the extract_*() methods when a condition does not qualify
    NO_MATCH = (None, None, None)

    def extract_conditions(self, cond, allow_not_in):
        """Analyse the condition expression 'cond'.

        Returns a tuple (not_in, tested variable node, list of value nodes),
        or NO_MATCH.  'not_in' is true for negated tests ('!=', 'not in'),
        which are only accepted if 'allow_not_in' is true.
        """
        # strip coercions, casts and temp wrappers from the condition
        while True:
            if isinstance(cond, (ExprNodes.CoerceToTempNode,
                                 ExprNodes.CoerceToBooleanNode)):
                cond = cond.arg
            elif isinstance(cond, ExprNodes.BoolBinopResultNode):
                cond = cond.arg.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                break

        if isinstance(cond, ExprNodes.PrimaryCmpNode):
            if cond.cascade is not None:
                return self.NO_MATCH

            if cond.is_c_string_contains() and isinstance(
                    cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
                negated = cond.operator == 'not_in'
                if negated and not allow_not_in:
                    return self.NO_MATCH
                if (isinstance(cond.operand2, ExprNodes.UnicodeNode)
                        and cond.operand2.contains_surrogates()):
                    # dealing with surrogates leads to different
                    # behaviour on wide and narrow Unicode
                    # platforms => refuse to optimise this case
                    return self.NO_MATCH
                return (negated, cond.operand1,
                        self.extract_in_string_conditions(cond.operand2))

            if not cond.is_python_comparison():
                if cond.operator == '==':
                    negated = False
                elif allow_not_in and cond.operator == '!=':
                    negated = True
                else:
                    return self.NO_MATCH
                # try "var == value" first, then "value == var"
                operand_pairs = ((cond.operand1, cond.operand2),
                                 (cond.operand2, cond.operand1))
                for var, value in operand_pairs:
                    # this looks somewhat silly, but it does the right
                    # checks for NameNode and AttributeNode
                    if not is_common_value(var, var):
                        continue
                    if value.is_literal:
                        return negated, var, [value]
                    if getattr(value, 'entry', None) and value.entry.is_const:
                        return negated, var, [value]

        elif isinstance(cond, ExprNodes.BoolBinopNode):
            is_and = cond.operator == 'and'
            if cond.operator == 'or' or (allow_not_in and is_and):
                allow_not_in = is_and
                negated1, var1, values1 = self.extract_conditions(cond.operand1, allow_not_in)
                negated2, var2, values2 = self.extract_conditions(cond.operand2, allow_not_in)
                if var1 is not None and negated1 == negated2 and is_common_value(var1, var2):
                    if allow_not_in or not negated1:
                        return negated1, var1, values1 + values2

        return self.NO_MATCH

    def extract_in_string_conditions(self, string_literal):
        """Return one character value node for each distinct character of
        the string literal node, sorted by value."""
        pos = string_literal.pos
        if isinstance(string_literal, ExprNodes.UnicodeNode):
            return [
                ExprNodes.IntNode(pos, value=str(charval), constant_result=charval)
                for charval in sorted(map(ord, set(string_literal.value)))
            ]

        # this is a bit tricky as Py3's bytes type returns
        # integers on iteration, whereas Py2 returns 1-char byte
        # strings
        byte_string = string_literal.value
        single_bytes = set(byte_string[i:i+1] for i in range(len(byte_string)))
        return [
            ExprNodes.CharNode(pos, value=charval, constant_result=charval)
            for charval in sorted(single_bytes)
        ]

    def extract_common_conditions(self, common_var, condition, allow_not_in):
        """Like extract_conditions(), but additionally require that the
        tested variable matches 'common_var' (unless that is None) and that
        the variable and all values have an int or enum type."""
        def is_switchable(node_type):
            return node_type.is_int or node_type.is_enum

        not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
        if var is None:
            return self.NO_MATCH
        if common_var is not None and not is_common_value(var, common_var):
            return self.NO_MATCH
        if not is_switchable(var.type):
            return self.NO_MATCH
        if any(not is_switchable(cond.type) for cond in conditions):
            return self.NO_MATCH
        return not_in, var, conditions

    def has_duplicate_values(self, condition_values):
        """Return True if two of the value nodes may denote the same value."""
        # duplicated values don't work in a switch statement
        seen = set()
        for value in condition_values:
            if value.has_constant_result():
                key = value.constant_result
            else:
                # this isn't completely safe as we don't know the
                # final C value, but this is about the best we can do
                try:
                    key = value.entry.cname
                except AttributeError:
                    return True  # play safe
            if key in seen:
                return True
            seen.add(key)
        return False

    def visit_IfStatNode(self, node):
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        # every if-clause must test the same variable
        common_var = None
        cases = []
        for if_clause in node.if_clauses:
            _, common_var, conditions = self.extract_common_conditions(
                common_var, if_clause.condition, False)
            if common_var is None:
                self.visitchildren(node)
                return node
            cases.append(Nodes.SwitchCaseNode(
                pos=if_clause.pos, conditions=conditions, body=if_clause.body))

        condition_values = []
        for case in cases:
            condition_values.extend(case.conditions)
        if len(condition_values) < 2 or self.has_duplicate_values(condition_values):
            self.visitchildren(node)
            return node

        # Recurse into body subtrees that we left untouched so far.
        self.visitchildren(node, 'else_clause')
        for case in cases:
            self.visitchildren(case, 'body')

        return Nodes.SwitchStatNode(
            pos=node.pos,
            test=unwrap_node(common_var),
            cases=cases,
            else_clause=node.else_clause)

    def visit_CondExprNode(self, node):
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        not_in, common_var, conditions = self.extract_common_conditions(
            None, node.test, True)
        if (common_var is not None
                and len(conditions) >= 2
                and not self.has_duplicate_values(conditions)):
            return self.build_simple_switch_statement(
                node, common_var, conditions, not_in,
                node.true_val, node.false_val)

        self.visitchildren(node)
        return node

    def visit_BoolBinopNode(self, node):
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        not_in, common_var, conditions = self.extract_common_conditions(
            None, node, True)
        if (common_var is not None
                and len(conditions) >= 2
                and not self.has_duplicate_values(conditions)):
            return self.build_simple_switch_statement(
                node, common_var, conditions, not_in,
                ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
                ExprNodes.BoolNode(node.pos, value=False, constant_result=False))

        self.visitchildren(node)
        node.wrap_operands(self.current_env())  # in case we changed the operands
        return node

    def visit_PrimaryCmpNode(self, node):
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        not_in, common_var, conditions = self.extract_common_conditions(
            None, node, True)
        if (common_var is not None
                and len(conditions) >= 2
                and not self.has_duplicate_values(conditions)):
            return self.build_simple_switch_statement(
                node, common_var, conditions, not_in,
                ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
                ExprNodes.BoolNode(node.pos, value=False, constant_result=False))

        self.visitchildren(node)
        return node

    def build_simple_switch_statement(self, node, common_var, conditions,
                                      not_in, true_val, false_val):
        """Build the replacement for the expression 'node': a switch on
        'common_var' with a single case (all 'conditions') plus a default
        branch, each of which assigns one of true_val / false_val to the
        result.  'not_in' swaps the two branches."""
        result_ref = UtilNodes.ResultRefNode(node)

        def assign_result(value_node):
            return Nodes.SingleAssignmentNode(
                node.pos,
                lhs=result_ref,
                rhs=value_node.coerce_to(node.type, self.current_env()),
                first=True)

        match_body = assign_result(true_val)
        default_body = assign_result(false_val)
        if not_in:
            match_body, default_body = default_body, match_body

        switch_node = Nodes.SwitchStatNode(
            pos=node.pos,
            test=unwrap_node(common_var),
            cases=[Nodes.SwitchCaseNode(
                pos=node.pos, conditions=conditions, body=match_body)],
            else_clause=default_body)
        return UtilNodes.TempResultFromStatNode(result_ref, switch_node)

    def visit_EvalWithTempExprNode(self, node):
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        # drop unused expression temp from FlattenInListTransform
        expr_before = node.subexpression
        temp_ref = node.lazy_temp
        self.visitchildren(node)
        was_restructured = node.subexpression is not expr_before
        # node was restructured => check if temp is still used
        if was_restructured and not Visitor.tree_contains(node.subexpression, temp_ref):
            return node.subexpression
        return node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
1319
1320
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    Rewrites membership tests against a literal tuple, list or set, i.e.
    "x in [val1, ..., valn]", as a chain of individual comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        self.visitchildren(node)
        if node.cascade is not None:
            return node

        # 'in' becomes an or-chain of '==', 'not_in' an and-chain of '!='
        if node.operator == 'in':
            conjunction, eq_or_neq = 'or', '=='
        elif node.operator == 'not_in':
            conjunction, eq_or_neq = 'and', '!='
        else:
            return node

        container = node.operand2
        if not isinstance(container, (ExprNodes.TupleNode,
                                      ExprNodes.ListNode,
                                      ExprNodes.SetNode)):
            return node

        args = container.args
        if not args:
            # note: lhs may have side effects
            return node

        pos = node.pos
        lhs = UtilNodes.ResultRefNode(node.operand1)

        comparisons = []
        temps = []
        for arg in args:
            try:
                # Trial optimisation to avoid redundant temp
                # assignments.  However, since is_simple() is meant to
                # be called after type analysis, we ignore any errors
                # and just play safe in that case.
                is_simple_arg = arg.is_simple()
            except Exception:
                is_simple_arg = False
            if not is_simple_arg:
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            comparison = ExprNodes.PrimaryCmpNode(
                pos=pos,
                operand1=lhs,
                operator=eq_or_neq,
                operand2=arg,
                cascade=None)
            comparisons.append(ExprNodes.TypecastNode(
                pos=pos,
                operand=comparison,
                type=PyrexTypes.c_bint_type))

        # left fold of all comparisons into a single boolean expression
        condition = comparisons[0]
        for comparison in comparisons[1:]:
            condition = ExprNodes.BoolBinopNode(
                pos=pos,
                operator=conjunction,
                operand1=condition,
                operand2=comparison)

        # wrap the temps from the inside out, so that the first one ends up outermost
        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        for temp in reversed(temps):
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
1391
1392
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Switch off reference counting on operands where that is safe.

    The only pattern currently acted upon is a parallel assignment that
    merely permutes a set of Python object names, e.g. ``a, b = b, a``.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        The node itself is always returned.  When the assignment qualifies,
        ``use_managed_ref`` is set to False on the operand nodes involved.
        """
        left_names, right_names = [], []
        left_indices, right_indices = [], []
        temps = []

        for stat in node.stats:
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                # FIXME: cascaded assignments (and anything else) are not handled
                return node
            if not (self._extract_operand(stat.lhs, left_names, left_indices, temps)
                    and self._extract_operand(stat.rhs, right_names, right_indices, temps)):
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = set(path for path, _ in left_names)
            rhs_paths = set(path for path, _ in right_names)
            if lhs_paths != rhs_paths or len(lhs_paths) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lindices = [self._extract_index_id(lhs_node) for lhs_node in left_indices]
            rindices = [self._extract_index_id(rhs_node) for rhs_node in right_indices]
            if not all(lindices) or not all(rindices):
                return node
            if set(lindices) != set(rindices) or len(set(lindices)) != len(right_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = []
        for temp in temps:
            temp_args.append(temp.arg)
            temp.use_managed_ref = False

        for _, name_node in left_names + right_names:
            # nodes wrapped in a temp keep their own setting
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in left_indices + right_indices:
            index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one assignment operand.

        Name and C attribute chains are appended to ``names`` as
        ``(dotted_path, node)``; integer indexing into a list-typed name
        is appended to ``indices``; CoerceToTempNode wrappers are recorded
        in ``temps`` and looked through.  Returns True if the operand was
        accepted and False otherwise.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        # walk down the attribute chain, collecting member names innermost-last
        path_parts = []
        obj_node = node
        while obj_node.is_attribute:
            if obj_node.is_py_attr:
                return False
            path_parts.append(obj_node.member)
            obj_node = obj_node.obj

        if obj_node.is_name:
            path_parts.append(obj_node.name)
            names.append(('.'.join(reversed(path_parts)), node))
            return True

        if not node.is_subscript:
            return False
        if node.base.type != Builtin.list_type:
            return False
        if not node.index.type.is_int:
            return False
        if not node.base.is_name:
            return False
        indices.append(node)
        return True

    def _extract_index_id(self, index_node):
        """Return ``(base name, index name)`` for an index node whose index
        is a plain name, or None for any other kind of index.
        """
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices (ExprNodes.ConstNode) are not supported
            return None
        return (index_node.base.name, index.name)
1507
1508
1509class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
1510    """Optimize some common calls to builtin types *before* the type
1511    analysis phase and *after* the declarations analysis phase.
1512
1513    This transform cannot make use of any argument types, but it can
1514    restructure the tree in a way that the type analysis phase can
1515    respond to.
1516
1517    Introducing C function calls here may not be a good idea.  Move
1518    them to the OptimizeBuiltinCalls transform instead, which runs
1519    after type analysis.
1520    """
    # only intercept on call nodes: all other node types are handled by
    # the generic recursion into their children
    visit_Node = Visitor.VisitorTransform.recurse_to_children
1523
1524    def visit_SimpleCallNode(self, node):
1525        self.visitchildren(node)
1526        function = node.function
1527        if not self._function_is_builtin_name(function):
1528            return node
1529        return self._dispatch_to_handler(node, function, node.args)
1530
1531    def visit_GeneralCallNode(self, node):
1532        self.visitchildren(node)
1533        function = node.function
1534        if not self._function_is_builtin_name(function):
1535            return node
1536        arg_tuple = node.positional_args
1537        if not isinstance(arg_tuple, ExprNodes.TupleNode):
1538            return node
1539        args = arg_tuple.args
1540        return self._dispatch_to_handler(
1541            node, function, args, node.keyword_args)
1542
1543    def _function_is_builtin_name(self, function):
1544        if not function.is_name:
1545            return False
1546        env = self.current_env()
1547        entry = env.lookup(function.name)
1548        if entry is not env.builtin_scope().lookup_here(function.name):
1549            return False
1550        # if entry is None, it's at least an undeclared name, so likely builtin
1551        return True
1552
1553    def _dispatch_to_handler(self, node, function, args, kwargs=None):
1554        if kwargs is None:
1555            handler_name = '_handle_simple_function_%s' % function.name
1556        else:
1557            handler_name = '_handle_general_function_%s' % function.name
1558        handle_call = getattr(self, handler_name, None)
1559        if handle_call is not None:
1560            if kwargs is None:
1561                return handle_call(node, args)
1562            else:
1563                return handle_call(node, args, kwargs)
1564        return node
1565
1566    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
1567        node.function = ExprNodes.PythonCapiFunctionNode(
1568            node.function.pos, node.function.name, cname, func_type,
1569            utility_code = utility_code)
1570
1571    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1572        if not expected: # None or 0
1573            arg_str = ''
1574        elif isinstance(expected, basestring) or expected > 1:
1575            arg_str = '...'
1576        elif expected == 1:
1577            arg_str = 'x'
1578        else:
1579            arg_str = ''
1580        if expected is not None:
1581            expected_str = 'expected %s, ' % expected
1582        else:
1583            expected_str = ''
1584        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1585            function_name, arg_str, expected_str, len(args)))
1586
1587    # specific handlers for simple call nodes
1588
1589    def _handle_simple_function_float(self, node, pos_args):
1590        if not pos_args:
1591            return ExprNodes.FloatNode(node.pos, value='0.0')
1592        if len(pos_args) > 1:
1593            self._error_wrong_arg_count('float', node, pos_args, 1)
1594        arg_type = getattr(pos_args[0], 'type', None)
1595        if arg_type in (PyrexTypes.c_double_type, Builtin.float_type):
1596            return pos_args[0]
1597        return node
1598
1599    def _handle_simple_function_slice(self, node, pos_args):
1600        arg_count = len(pos_args)
1601        start = step = None
1602        if arg_count == 1:
1603            stop, = pos_args
1604        elif arg_count == 2:
1605            start, stop = pos_args
1606        elif arg_count == 3:
1607            start, stop, step = pos_args
1608        else:
1609            self._error_wrong_arg_count('slice', node, pos_args)
1610            return node
1611        return ExprNodes.SliceNode(
1612            node.pos,
1613            start=start or ExprNodes.NoneNode(node.pos),
1614            stop=stop,
1615            step=step or ExprNodes.NoneNode(node.pos))
1616
1617    def _handle_simple_function_ord(self, node, pos_args):
1618        """Unpack ord('X').
1619        """
1620        if len(pos_args) != 1:
1621            return node
1622        arg = pos_args[0]
1623        if isinstance(arg, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
1624            if len(arg.value) == 1:
1625                return ExprNodes.IntNode(
1626                    arg.pos, type=PyrexTypes.c_long_type,
1627                    value=str(ord(arg.value)),
1628                    constant_result=ord(arg.value)
1629                )
1630        elif isinstance(arg, ExprNodes.StringNode):
1631            if arg.unicode_value and len(arg.unicode_value) == 1 \
1632                    and ord(arg.unicode_value) <= 255:  # Py2/3 portability
1633                return ExprNodes.IntNode(
1634                    arg.pos, type=PyrexTypes.c_int_type,
1635                    value=str(ord(arg.unicode_value)),
1636                    constant_result=ord(arg.unicode_value)
1637                )
1638        return node
1639
1640    # sequence processing
1641
1642    def _handle_simple_function_all(self, node, pos_args):
1643        """Transform
1644
1645        _result = all(p(x) for L in LL for x in L)
1646
1647        into
1648
1649        for L in LL:
1650            for x in L:
1651                if not p(x):
1652                    return False
1653        else:
1654            return True
1655        """
1656        return self._transform_any_all(node, pos_args, False)
1657
1658    def _handle_simple_function_any(self, node, pos_args):
1659        """Transform
1660
1661        _result = any(p(x) for L in LL for x in L)
1662
1663        into
1664
1665        for L in LL:
1666            for x in L:
1667                if p(x):
1668                    return True
1669        else:
1670            return False
1671        """
1672        return self._transform_any_all(node, pos_args, True)
1673
1674    def _transform_any_all(self, node, pos_args, is_any):
1675        if len(pos_args) != 1:
1676            return node
1677        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1678            return node
1679        gen_expr_node = pos_args[0]
1680        generator_body = gen_expr_node.def_node.gbody
1681        loop_node = generator_body.body
1682        yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
1683        if yield_expression is None:
1684            return node
1685
1686        if is_any:
1687            condition = yield_expression
1688        else:
1689            condition = ExprNodes.NotNode(yield_expression.pos, operand=yield_expression)
1690
1691        test_node = Nodes.IfStatNode(
1692            yield_expression.pos, else_clause=None, if_clauses=[
1693                Nodes.IfClauseNode(
1694                    yield_expression.pos,
1695                    condition=condition,
1696                    body=Nodes.ReturnStatNode(
1697                        node.pos,
1698                        value=ExprNodes.BoolNode(yield_expression.pos, value=is_any, constant_result=is_any))
1699                )]
1700        )
1701        loop_node.else_clause = Nodes.ReturnStatNode(
1702            node.pos,
1703            value=ExprNodes.BoolNode(yield_expression.pos, value=not is_any, constant_result=not is_any))
1704
1705        Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, test_node)
1706
1707        return ExprNodes.InlinedGeneratorExpressionNode(
1708            gen_expr_node.pos, gen=gen_expr_node, orig_func='any' if is_any else 'all')
1709
    # C signature used for calling PySequence_List(): takes one Python
    # object argument ("it") and returns a list.
    PySequence_List_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])
1713
    def _handle_simple_function_sorted(self, node, pos_args):
        """Transform sorted(genexpr) and sorted([listcomp]) into
        [listcomp].sort().  CPython just reads the iterable into a
        list and calls .sort() on it.  Expanding the iterable in a
        listcomp is still faster and the result can be sorted in
        place.

        Only calls with exactly one positional argument are handled.
        Literal sequences are converted to a fresh list, and any other
        argument is passed through PySequence_List().  The result is a
        TempResultFromStatNode that assigns the list to a temporary,
        calls .sort() on it and evaluates to that temporary.
        """
        if len(pos_args) != 1:
            return node

        arg = pos_args[0]
        if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type:
            # sorted([listcomp]): the comprehension already builds a new list
            list_node = pos_args[0]
            loop_node = list_node.loop

        elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
            # sorted(genexpr): inline the generator as a list comprehension
            gen_expr_node = arg
            loop_node = gen_expr_node.loop
            yield_statements = _find_yield_statements(loop_node)
            if not yield_statements:
                return node

            list_node = ExprNodes.InlinedGeneratorExpressionNode(
                node.pos, gen_expr_node, orig_func='sorted',
                comprehension_type=Builtin.list_type)

            # every yield becomes an append to the list being built
            for yield_expression, yield_stat_node in yield_statements:
                append_node = ExprNodes.ComprehensionAppendNode(
                    yield_expression.pos,
                    expr=yield_expression,
                    target=list_node.target)
                Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

        elif arg.is_sequence_constructor:
            # sorted([a, b, c]) or sorted((a, b, c)).  The result is always a list,
            # so starting off with a fresh one is more efficient.
            list_node = loop_node = arg.as_list()

        else:
            # Interestingly, PySequence_List works on a lot of non-sequence
            # things as well.
            list_node = loop_node = ExprNodes.PythonCapiCallNode(
                node.pos, "PySequence_List", self.PySequence_List_func_type,
                args=pos_args, is_temp=True)

        # temporary that holds the list while it gets sorted in place
        result_node = UtilNodes.ResultRefNode(
            pos=loop_node.pos, type=Builtin.list_type, may_hold_none=False)
        list_assign_node = Nodes.SingleAssignmentNode(
            node.pos, lhs=result_node, rhs=list_node, first=True)

        sort_method = ExprNodes.AttributeNode(
            node.pos, obj=result_node, attribute=EncodedString('sort'),
            # entry ? type ?
            needs_none_check=False)
        sort_node = Nodes.ExprStatNode(
            node.pos, expr=ExprNodes.SimpleCallNode(
                node.pos, function=sort_method, args=[]))

        # the new statement is created after the declaration analysis
        # phase, so analyse it explicitly here
        sort_node.analyse_declarations(self.current_env())

        return UtilNodes.TempResultFromStatNode(
            result_node,
            Nodes.StatListNode(node.pos, stats=[list_assign_node, sort_node]))
1777
    def __handle_simple_function_sum(self, node, pos_args):
        """Transform sum(genexpr) into an equivalent inlined aggregation loop.

        NOTE: the leading double underscore makes this a name-mangled
        private method, so the '_handle_simple_function_%s' lookup in
        _dispatch_to_handler() does not find it.  This handler therefore
        appears to be deliberately disabled (see also the FIXME below).
        """
        if len(pos_args) not in (1,2):
            return node
        if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
                                        ExprNodes.ComprehensionNode)):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
            # FIXME: currently nonfunctional
            # (the result above is discarded, so this branch always gives up)
            yield_expression = None
            if yield_expression is None:
                return node
        else:  # ComprehensionNode
            yield_stat_node = gen_expr_node.append
            yield_expression = yield_stat_node.expr
            try:
                if not yield_expression.is_literal or not yield_expression.type.is_int:
                    return node
            except AttributeError:
                return node # in case we don't have a type yet
            # special case: old Py2 backwards compatible "sum([int_const for ...])"
            # can safely be unpacked into a genexpr

        # the optional second argument is the start value, defaulting to 0
        if len(pos_args) == 1:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        else:
            start = pos_args[1]

        # "result = result + <item>" replaces the yield/append statement
        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
        add_node = Nodes.SingleAssignmentNode(
            yield_expression.pos,
            lhs = result_ref,
            rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
            )

        Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, add_node)

        # initialise the result with the start value, then run the loop
        exec_code = Nodes.StatListNode(
            node.pos,
            stats = [
                Nodes.SingleAssignmentNode(
                    start.pos,
                    lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
                    rhs = start,
                    first = True),
                loop_node
                ])

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = exec_code, result_node = result_ref,
            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
            has_local_scope = gen_expr_node.has_local_scope)
1835
1836    def _handle_simple_function_min(self, node, pos_args):
1837        return self._optimise_min_max(node, pos_args, '<')
1838
1839    def _handle_simple_function_max(self, node, pos_args):
1840        return self._optimise_min_max(node, pos_args, '>')
1841
1842    def _optimise_min_max(self, node, args, operator):
1843        """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1844        """
1845        if len(args) <= 1:
1846            if len(args) == 1 and args[0].is_sequence_constructor:
1847                args = args[0].args
1848            if len(args) <= 1:
1849                # leave this to Python
1850                return node
1851
1852        cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
1853
1854        last_result = args[0]
1855        for arg_node in cascaded_nodes:
1856            result_ref = UtilNodes.ResultRefNode(last_result)
1857            last_result = ExprNodes.CondExprNode(
1858                arg_node.pos,
1859                true_val = arg_node,
1860                false_val = result_ref,
1861                test = ExprNodes.PrimaryCmpNode(
1862                    arg_node.pos,
1863                    operand1 = arg_node,
1864                    operator = operator,
1865                    operand2 = result_ref,
1866                    )
1867                )
1868            last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
1869
1870        for ref_node in cascaded_nodes[::-1]:
1871            last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
1872
1873        return last_result
1874
1875    # builtin type creation
1876
1877    def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
1878        if not pos_args:
1879            return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
1880        # This is a bit special - for iterables (including genexps),
1881        # Python actually overallocates and resizes a newly created
1882        # tuple incrementally while reading items, which we can't
1883        # easily do without explicit node support. Instead, we read
1884        # the items into a list and then copy them into a tuple of the
1885        # final size.  This takes up to twice as much memory, but will
1886        # have to do until we have real support for genexps.
1887        result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
1888        if result is not node:
1889            return ExprNodes.AsTupleNode(node.pos, arg=result)
1890        return node
1891
1892    def _handle_simple_function_frozenset(self, node, pos_args):
1893        """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient.
1894        """
1895        if len(pos_args) != 1:
1896            return node
1897        if pos_args[0].is_sequence_constructor and not pos_args[0].args:
1898            del pos_args[0]
1899        elif isinstance(pos_args[0], ExprNodes.ListNode):
1900            pos_args[0] = pos_args[0].as_tuple()
1901        return node
1902
1903    def _handle_simple_function_list(self, node, pos_args):
1904        if not pos_args:
1905            return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
1906        return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
1907
1908    def _handle_simple_function_set(self, node, pos_args):
1909        if not pos_args:
1910            return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1911        return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)
1912
1913    def _transform_list_set_genexpr(self, node, pos_args, target_type):
1914        """Replace set(genexpr) and list(genexpr) by an inlined comprehension.
1915        """
1916        if len(pos_args) > 1:
1917            return node
1918        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1919            return node
1920        gen_expr_node = pos_args[0]
1921        loop_node = gen_expr_node.loop
1922
1923        yield_statements = _find_yield_statements(loop_node)
1924        if not yield_statements:
1925            return node
1926
1927        result_node = ExprNodes.InlinedGeneratorExpressionNode(
1928            node.pos, gen_expr_node,
1929            orig_func='set' if target_type is Builtin.set_type else 'list',
1930            comprehension_type=target_type)
1931
1932        for yield_expression, yield_stat_node in yield_statements:
1933            append_node = ExprNodes.ComprehensionAppendNode(
1934                yield_expression.pos,
1935                expr=yield_expression,
1936                target=result_node.target)
1937            Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
1938
1939        return result_node
1940
1941    def _handle_simple_function_dict(self, node, pos_args):
1942        """Replace dict( (a,b) for ... ) by an inlined { a:b for ... }
1943        """
1944        if len(pos_args) == 0:
1945            return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
1946        if len(pos_args) > 1:
1947            return node
1948        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1949            return node
1950        gen_expr_node = pos_args[0]
1951        loop_node = gen_expr_node.loop
1952
1953        yield_statements = _find_yield_statements(loop_node)
1954        if not yield_statements:
1955            return node
1956
1957        for yield_expression, _ in yield_statements:
1958            if not isinstance(yield_expression, ExprNodes.TupleNode):
1959                return node
1960            if len(yield_expression.args) != 2:
1961                return node
1962
1963        result_node = ExprNodes.InlinedGeneratorExpressionNode(
1964            node.pos, gen_expr_node, orig_func='dict',
1965            comprehension_type=Builtin.dict_type)
1966
1967        for yield_expression, yield_stat_node in yield_statements:
1968            append_node = ExprNodes.DictComprehensionAppendNode(
1969                yield_expression.pos,
1970                key_expr=yield_expression.args[0],
1971                value_expr=yield_expression.args[1],
1972                target=result_node.target)
1973            Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
1974
1975        return result_node
1976
1977    # specific handlers for general call nodes
1978
1979    def _handle_general_function_dict(self, node, pos_args, kwargs):
1980        """Replace dict(a=b,c=d,...) by the underlying keyword dict
1981        construction which is done anyway.
1982        """
1983        if len(pos_args) > 0:
1984            return node
1985        if not isinstance(kwargs, ExprNodes.DictNode):
1986            return node
1987        return kwargs
1988
1989
class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
    """Replace calls to names that are bound exactly once to a Python
    function by InlinedDefNodeCallNode, if the
    'optimize.inline_defnode_calls' directive is enabled.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def get_constant_value_node(self, name_node):
        """Return the right-hand side of the only assignment to the name,
        or None if control flow state is missing or null, the name has no
        entry, or it does not have exactly one recorded assignment.
        """
        cf_state = name_node.cf_state
        if cf_state is None or cf_state.cf_is_null:
            return None
        entry = self.current_env().lookup(name_node.name)
        if not entry:
            return None
        assignments = entry.cf_assignments
        if not assignments or len(assignments) != 1:
            # not just a single assignment in all closures
            return None
        return assignments[0].rhs

    def visit_SimpleCallNode(self, node):
        """Inline the call if its function name is known to refer to a
        single PyCFunctionNode and the inlined call node agrees.
        """
        self.visitchildren(node)
        if not self.current_directives.get('optimize.inline_defnode_calls'):
            return node
        callee_name = node.function
        if not callee_name.is_name:
            return node
        callee = self.get_constant_value_node(callee_name)
        if not isinstance(callee, ExprNodes.PyCFunctionNode):
            return node
        inlined_call = ExprNodes.InlinedDefNodeCallNode(
            node.pos, function_name=callee_name,
            function=callee, args=node.args)
        if not inlined_call.can_be_inlined():
            return node
        return self.replace(node, inlined_call)
2021
2022
2023class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
2024                           Visitor.MethodDispatcherTransform):
2025    """Optimize some common methods calls and instantiation patterns
2026    for builtin types *after* the type analysis phase.
2027
2028    Running after type analysis, this transform can only perform
2029    function replacements that do not alter the function return type
2030    in a way that was not anticipated by the type analysis.
2031    """
2032    ### cleanup to avoid redundant coercions to/from Python types
2033
2034    def visit_PyTypeTestNode(self, node):
2035        """Flatten redundant type checks after tree changes.
2036        """
2037        self.visitchildren(node)
2038        return node.reanalyse()
2039
2040    def _visit_TypecastNode(self, node):
2041        # disabled - the user may have had a reason to put a type
2042        # cast, even if it looks redundant to Cython
2043        """
2044        Drop redundant type casts.
2045        """
2046        self.visitchildren(node)
2047        if node.type == node.operand.type:
2048            return node.operand
2049        return node
2050
2051    def visit_ExprStatNode(self, node):
2052        """
2053        Drop dead code and useless coercions.
2054        """
2055        self.visitchildren(node)
2056        if isinstance(node.expr, ExprNodes.CoerceToPyTypeNode):
2057            node.expr = node.expr.arg
2058        expr = node.expr
2059        if expr is None or expr.is_none or expr.is_literal:
2060            # Expression was removed or is dead code => remove ExprStatNode as well.
2061            return None
2062        if expr.is_name and expr.entry and (expr.entry.is_local or expr.entry.is_arg):
2063            # Ignore dead references to local variables etc.
2064            return None
2065        return node
2066
2067    def visit_CoerceToBooleanNode(self, node):
2068        """Drop redundant conversion nodes after tree changes.
2069        """
2070        self.visitchildren(node)
2071        arg = node.arg
2072        if isinstance(arg, ExprNodes.PyTypeTestNode):
2073            arg = arg.arg
2074        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
2075            if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
2076                return arg.arg.coerce_to_boolean(self.current_env())
2077        return node
2078
    # C signature used for calling __Pyx_PyNumber_Float(): takes one
    # Python object argument ("o") and returns a Python object.
    PyNumber_Float_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)
            ])
2083
2084    def visit_CoerceToPyTypeNode(self, node):
2085        """Drop redundant conversion nodes after tree changes."""
2086        self.visitchildren(node)
2087        arg = node.arg
2088        if isinstance(arg, ExprNodes.CoerceFromPyTypeNode):
2089            arg = arg.arg
2090        if isinstance(arg, ExprNodes.PythonCapiCallNode):
2091            if arg.function.name == 'float' and len(arg.args) == 1:
2092                # undo redundant Py->C->Py coercion
2093                func_arg = arg.args[0]
2094                if func_arg.type is Builtin.float_type:
2095                    return func_arg.as_none_safe_node("float() argument must be a string or a number, not 'NoneType'")
2096                elif func_arg.type.is_pyobject:
2097                    return ExprNodes.PythonCapiCallNode(
2098                        node.pos, '__Pyx_PyNumber_Float', self.PyNumber_Float_func_type,
2099                        args=[func_arg],
2100                        py_name='float',
2101                        is_temp=node.is_temp,
2102                        result_is_used=node.result_is_used,
2103                    ).coerce_to(node.type, self.current_env())
2104        return node
2105
2106    def visit_CoerceFromPyTypeNode(self, node):
2107        """Drop redundant conversion nodes after tree changes.
2108
2109        Also, optimise away calls to Python's builtin int() and
2110        float() if the result is going to be coerced back into a C
2111        type anyway.
2112        """
2113        self.visitchildren(node)
2114        arg = node.arg
2115        if not arg.type.is_pyobject:
2116            # no Python conversion left at all, just do a C coercion instead
2117            if node.type != arg.type:
2118                arg = arg.coerce_to(node.type, self.current_env())
2119            return arg
2120        if isinstance(arg, ExprNodes.PyTypeTestNode):
2121            arg = arg.arg
2122        if arg.is_literal:
2123            if (node.type.is_int and isinstance(arg, ExprNodes.IntNode) or
2124                    node.type.is_float and isinstance(arg, ExprNodes.FloatNode) or
2125                    node.type.is_int and isinstance(arg, ExprNodes.BoolNode)):
2126                return arg.coerce_to(node.type, self.current_env())
2127        elif isinstance(arg, ExprNodes.CoerceToPyTypeNode):
2128            if arg.type is PyrexTypes.py_object_type:
2129                if node.type.assignable_from(arg.arg.type):
2130                    # completely redundant C->Py->C coercion
2131                    return arg.arg.coerce_to(node.type, self.current_env())
2132            elif arg.type is Builtin.unicode_type:
2133                if arg.arg.type.is_unicode_char and node.type.is_unicode_char:
2134                    return arg.arg.coerce_to(node.type, self.current_env())
2135        elif isinstance(arg, ExprNodes.SimpleCallNode):
2136            if node.type.is_int or node.type.is_float:
2137                return self._optimise_numeric_cast_call(node, arg)
2138        elif arg.is_subscript:
2139            index_node = arg.index
2140            if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
2141                index_node = index_node.arg
2142            if index_node.type.is_int:
2143                return self._optimise_int_indexing(node, arg, index_node)
2144        return node
2145
    # C signature used for calling __Pyx_PyBytes_GetItemInt(): takes a
    # bytes object, a Py_ssize_t index and an int "check_bounds" flag and
    # returns a C char.  The exception value is (char)-1 with
    # exception_check enabled.
    PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_type, [
            PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
            ],
        exception_value = "((char)-1)",
        exception_check = True)
2154
2155    def _optimise_int_indexing(self, coerce_node, arg, index_node):
2156        env = self.current_env()
2157        bound_check_bool = env.directives['boundscheck'] and 1 or 0
2158        if arg.base.type is Builtin.bytes_type:
2159            if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
2160                # bytes[index] -> char
2161                bound_check_node = ExprNodes.IntNode(
2162                    coerce_node.pos, value=str(bound_check_bool),
2163                    constant_result=bound_check_bool)
2164                node = ExprNodes.PythonCapiCallNode(
2165                    coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
2166                    self.PyBytes_GetItemInt_func_type,
2167                    args=[
2168                        arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
2169                        index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
2170                        bound_check_node,
2171                        ],
2172                    is_temp=True,
2173                    utility_code=UtilityCode.load_cached(
2174                        'bytes_index', 'StringTools.c'))
2175                if coerce_node.type is not PyrexTypes.c_char_type:
2176                    node = node.coerce_to(coerce_node.type, env)
2177                return node
2178        return coerce_node
2179
    # Maps each C floating point type (float, double, long double) to the
    # signature of a one-argument C function that takes and returns that type.
    # Used by _optimise_numeric_cast_call() for the trunc*() calls.
    float_float_func_types = dict(
        (float_type, PyrexTypes.CFuncType(
            float_type, [
                PyrexTypes.CFuncTypeArg("arg", float_type, None)
            ]))
        for float_type in (PyrexTypes.c_float_type, PyrexTypes.c_double_type, PyrexTypes.c_longdouble_type))
2186
2187    def _optimise_numeric_cast_call(self, node, arg):
2188        function = arg.function
2189        args = None
2190        if isinstance(arg, ExprNodes.PythonCapiCallNode):
2191            args = arg.args
2192        elif isinstance(function, ExprNodes.NameNode):
2193            if function.type.is_builtin_type and isinstance(arg.arg_tuple, ExprNodes.TupleNode):
2194                args = arg.arg_tuple.args
2195
2196        if args is None or len(args) != 1:
2197            return node
2198        func_arg = args[0]
2199        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
2200            func_arg = func_arg.arg
2201        elif func_arg.type.is_pyobject:
2202            # play it safe: Python conversion might work on all sorts of things
2203            return node
2204
2205        if function.name == 'int':
2206            if func_arg.type.is_int or node.type.is_int:
2207                if func_arg.type == node.type:
2208                    return func_arg
2209                elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
2210                    return ExprNodes.TypecastNode(node.pos, operand=func_arg, type=node.type)
2211            elif func_arg.type.is_float and node.type.is_numeric:
2212                if func_arg.type.math_h_modifier == 'l':
2213                    # Work around missing Cygwin definition.
2214                    truncl = '__Pyx_truncl'
2215                else:
2216                    truncl = 'trunc' + func_arg.type.math_h_modifier
2217                return ExprNodes.PythonCapiCallNode(
2218                    node.pos, truncl,
2219                    func_type=self.float_float_func_types[func_arg.type],
2220                    args=[func_arg],
2221                    py_name='int',
2222                    is_temp=node.is_temp,
2223                    result_is_used=node.result_is_used,
2224                ).coerce_to(node.type, self.current_env())
2225        elif function.name == 'float':
2226            if func_arg.type.is_float or node.type.is_float:
2227                if func_arg.type == node.type:
2228                    return func_arg
2229                elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
2230                    return ExprNodes.TypecastNode(
2231                        node.pos, operand=func_arg, type=node.type)
2232        return node
2233
2234    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
2235        if not expected: # None or 0
2236            arg_str = ''
2237        elif isinstance(expected, basestring) or expected > 1:
2238            arg_str = '...'
2239        elif expected == 1:
2240            arg_str = 'x'
2241        else:
2242            arg_str = ''
2243        if expected is not None:
2244            expected_str = 'expected %s, ' % expected
2245        else:
2246            expected_str = ''
2247        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
2248            function_name, arg_str, expected_str, len(args)))
2249
2250    ### generic fallbacks
2251
    def _handle_function(self, node, function_name, function, arg_list, kwargs):
        """Generic fallback for function calls: returns 'node' unchanged.
        """
        return node
2254
2255    def _handle_method(self, node, type_name, attr_name, function,
2256                       arg_list, is_unbound_method, kwargs):
2257        """
2258        Try to inject C-API calls for unbound method calls to builtin types.
2259        While the method declarations in Builtin.py already handle this, we
2260        can additionally resolve bound and unbound methods here that were
2261        assigned to variables ahead of time.
2262        """
2263        if kwargs:
2264            return node
2265        if not function or not function.is_attribute or not function.obj.is_name:
2266            # cannot track unbound method calls over more than one indirection as
2267            # the names might have been reassigned in the meantime
2268            return node
2269        type_entry = self.current_env().lookup(type_name)
2270        if not type_entry:
2271            return node
2272        method = ExprNodes.AttributeNode(
2273            node.function.pos,
2274            obj=ExprNodes.NameNode(
2275                function.pos,
2276                name=type_name,
2277                entry=type_entry,
2278                type=type_entry.type),
2279            attribute=attr_name,
2280            is_called=True).analyse_as_type_attribute(self.current_env())
2281        if method is None:
2282            return self._optimise_generic_builtin_method_call(
2283                node, attr_name, function, arg_list, is_unbound_method)
2284        args = node.args
2285        if args is None and node.arg_tuple:
2286            args = node.arg_tuple.args
2287        call_node = ExprNodes.SimpleCallNode(
2288            node.pos,
2289            function=method,
2290            args=args)
2291        if not is_unbound_method:
2292            call_node.self = function.obj
2293        call_node.analyse_c_function_call(self.current_env())
2294        call_node.analysed = True
2295        return call_node.coerce_to(node.type, self.current_env())
2296
2297    ### builtin types
2298
2299    def _optimise_generic_builtin_method_call(self, node, attr_name, function, arg_list, is_unbound_method):
2300        """
2301        Try to inject an unbound method call for a call to a method of a known builtin type.
2302        This enables caching the underlying C function of the method at runtime.
2303        """
2304        arg_count = len(arg_list)
2305        if is_unbound_method or arg_count >= 3 or not (function.is_attribute and function.is_py_attr):
2306            return node
2307        if not function.obj.type.is_builtin_type:
2308            return node
2309        if function.obj.type.name in ('basestring', 'type'):
2310            # these allow different actual types => unsafe
2311            return node
2312        return ExprNodes.CachedBuiltinMethodCallNode(
2313            node, function.obj, attr_name, arg_list)
2314
    # C signature (object) -> unicode, shared by the "__Pyx_PyUnicode_Unicode"
    # and "__Pyx_PyObject_Unicode" calls in _handle_simple_function_unicode().
    PyObject_Unicode_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
            ])
2319
2320    def _handle_simple_function_unicode(self, node, function, pos_args):
2321        """Optimise single argument calls to unicode().
2322        """
2323        if len(pos_args) != 1:
2324            if len(pos_args) == 0:
2325                return ExprNodes.UnicodeNode(node.pos, value=EncodedString(), constant_result=u'')
2326            return node
2327        arg = pos_args[0]
2328        if arg.type is Builtin.unicode_type:
2329            if not arg.may_be_none():
2330                return arg
2331            cname = "__Pyx_PyUnicode_Unicode"
2332            utility_code = UtilityCode.load_cached('PyUnicode_Unicode', 'StringTools.c')
2333        else:
2334            cname = "__Pyx_PyObject_Unicode"
2335            utility_code = UtilityCode.load_cached('PyObject_Unicode', 'StringTools.c')
2336        return ExprNodes.PythonCapiCallNode(
2337            node.pos, cname, self.PyObject_Unicode_func_type,
2338            args=pos_args,
2339            is_temp=node.is_temp,
2340            utility_code=utility_code,
2341            py_name="unicode")
2342
2343    def visit_FormattedValueNode(self, node):
2344        """Simplify or avoid plain string formatting of a unicode value.
2345        This seems misplaced here, but plain unicode formatting is essentially
2346        a call to the unicode() builtin, which is optimised right above.
2347        """
2348        self.visitchildren(node)
2349        if node.value.type is Builtin.unicode_type and not node.c_format_spec and not node.format_spec:
2350            if not node.conversion_char or node.conversion_char == 's':
2351                # value is definitely a unicode string and we don't format it any special
2352                return self._handle_simple_function_unicode(node, None, [node.value])
2353        return node
2354
    # C signature (dict) -> dict, used for the "PyDict_Copy" call in
    # _handle_simple_function_dict().
    PyDict_Copy_func_type = PyrexTypes.CFuncType(
        Builtin.dict_type, [
            PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
            ])
2359
2360    def _handle_simple_function_dict(self, node, function, pos_args):
2361        """Replace dict(some_dict) by PyDict_Copy(some_dict).
2362        """
2363        if len(pos_args) != 1:
2364            return node
2365        arg = pos_args[0]
2366        if arg.type is Builtin.dict_type:
2367            arg = arg.as_none_safe_node("'NoneType' is not iterable")
2368            return ExprNodes.PythonCapiCallNode(
2369                node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
2370                args = [arg],
2371                is_temp = node.is_temp
2372                )
2373        return node
2374
    # C signature (object) -> list, used for the "PySequence_List" call in
    # _handle_simple_function_list().
    PySequence_List_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])
2378
2379    def _handle_simple_function_list(self, node, function, pos_args):
2380        """Turn list(ob) into PySequence_List(ob).
2381        """
2382        if len(pos_args) != 1:
2383            return node
2384        arg = pos_args[0]
2385        return ExprNodes.PythonCapiCallNode(
2386            node.pos, "PySequence_List", self.PySequence_List_func_type,
2387            args=pos_args, is_temp=node.is_temp)
2388
    # C signature (list) -> tuple, used for the "PyList_AsTuple" call in
    # _handle_simple_function_tuple().
    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type, [
            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
            ])
2393
2394    def _handle_simple_function_tuple(self, node, function, pos_args):
2395        """Replace tuple([...]) by PyList_AsTuple or PySequence_Tuple.
2396        """
2397        if len(pos_args) != 1 or not node.is_temp:
2398            return node
2399        arg = pos_args[0]
2400        if arg.type is Builtin.tuple_type and not arg.may_be_none():
2401            return arg
2402        if arg.type is Builtin.list_type:
2403            pos_args[0] = arg.as_none_safe_node(
2404                "'NoneType' object is not iterable")
2405
2406            return ExprNodes.PythonCapiCallNode(
2407                node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
2408                args=pos_args, is_temp=node.is_temp)
2409        else:
2410            return ExprNodes.AsTupleNode(node.pos, arg=arg, type=Builtin.tuple_type)
2411
    # C signature (object) -> set, used for the "PySet_New" call in
    # _handle_simple_function_set().
    PySet_New_func_type = PyrexTypes.CFuncType(
        Builtin.set_type, [
            PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)
        ])
2416
2417    def _handle_simple_function_set(self, node, function, pos_args):
2418        if len(pos_args) != 1:
2419            return node
2420        if pos_args[0].is_sequence_constructor:
2421            # We can optimise set([x,y,z]) safely into a set literal,
2422            # but only if we create all items before adding them -
2423            # adding an item may raise an exception if it is not
2424            # hashable, but creating the later items may have
2425            # side-effects.
2426            args = []
2427            temps = []
2428            for arg in pos_args[0].args:
2429                if not arg.is_simple():
2430                    arg = UtilNodes.LetRefNode(arg)
2431                    temps.append(arg)
2432                args.append(arg)
2433            result = ExprNodes.SetNode(node.pos, is_temp=1, args=args)
2434            self.replace(node, result)
2435            for temp in temps[::-1]:
2436                result = UtilNodes.EvalWithTempExprNode(temp, result)
2437            return result
2438        else:
2439            # PySet_New(it) is better than a generic Python call to set(it)
2440            return self.replace(node, ExprNodes.PythonCapiCallNode(
2441                node.pos, "PySet_New",
2442                self.PySet_New_func_type,
2443                args=pos_args,
2444                is_temp=node.is_temp,
2445                py_name="set"))
2446
    # C signature (object) -> frozenset, used for the "__Pyx_PyFrozenSet_New"
    # call in _handle_simple_function_frozenset().
    PyFrozenSet_New_func_type = PyrexTypes.CFuncType(
        Builtin.frozenset_type, [
            PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)
        ])
2451
2452    def _handle_simple_function_frozenset(self, node, function, pos_args):
2453        if not pos_args:
2454            pos_args = [ExprNodes.NullNode(node.pos)]
2455        elif len(pos_args) > 1:
2456            return node
2457        elif pos_args[0].type is Builtin.frozenset_type and not pos_args[0].may_be_none():
2458            return pos_args[0]
2459        # PyFrozenSet_New(it) is better than a generic Python call to frozenset(it)
2460        return ExprNodes.PythonCapiCallNode(
2461            node.pos, "__Pyx_PyFrozenSet_New",
2462            self.PyFrozenSet_New_func_type,
2463            args=pos_args,
2464            is_temp=node.is_temp,
2465            utility_code=UtilityCode.load_cached('pyfrozenset_new', 'Builtins.c'),
2466            py_name="frozenset")
2467
    # C signature (object) -> double, used for the "__Pyx_PyObject_AsDouble"
    # call in _handle_simple_function_float().  It is declared with the
    # exception value "((double)-1)" and exception_check=True.
    PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_double_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "((double)-1)",
        exception_check = True)
2474
2475    def _handle_simple_function_float(self, node, function, pos_args):
2476        """Transform float() into either a C type cast or a faster C
2477        function call.
2478        """
2479        # Note: this requires the float() function to be typed as
2480        # returning a C 'double'
2481        if len(pos_args) == 0:
2482            return ExprNodes.FloatNode(
2483                node, value="0.0", constant_result=0.0
2484                ).coerce_to(Builtin.float_type, self.current_env())
2485        elif len(pos_args) != 1:
2486            self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
2487            return node
2488        func_arg = pos_args[0]
2489        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
2490            func_arg = func_arg.arg
2491        if func_arg.type is PyrexTypes.c_double_type:
2492            return func_arg
2493        elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
2494            return ExprNodes.TypecastNode(
2495                node.pos, operand=func_arg, type=node.type)
2496        return ExprNodes.PythonCapiCallNode(
2497            node.pos, "__Pyx_PyObject_AsDouble",
2498            self.PyObject_AsDouble_func_type,
2499            args = pos_args,
2500            is_temp = node.is_temp,
2501            utility_code = load_c_utility('pyobject_as_double'),
2502            py_name = "float")
2503
    # C signature (object) -> object, used for the "__Pyx_PyNumber_Int" call
    # in _handle_simple_function_int().
    PyNumber_Int_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)
            ])

    # C signature (double) -> object, used for the "__Pyx_PyInt_FromDouble"
    # call in _handle_simple_function_int().
    PyInt_FromDouble_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_double_type, None)
            ])
2513
2514    def _handle_simple_function_int(self, node, function, pos_args):
2515        """Transform int() into a faster C function call.
2516        """
2517        if len(pos_args) == 0:
2518            return ExprNodes.IntNode(node.pos, value="0", constant_result=0,
2519                                     type=PyrexTypes.py_object_type)
2520        elif len(pos_args) != 1:
2521            return node  # int(x, base)
2522        func_arg = pos_args[0]
2523        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
2524            if func_arg.arg.type.is_float:
2525                return ExprNodes.PythonCapiCallNode(
2526                    node.pos, "__Pyx_PyInt_FromDouble", self.PyInt_FromDouble_func_type,
2527                    args=[func_arg.arg], is_temp=True, py_name='int',
2528                    utility_code=UtilityCode.load_cached("PyIntFromDouble", "TypeConversion.c"))
2529            else:
2530                return node  # handled in visit_CoerceFromPyTypeNode()
2531        if func_arg.type.is_pyobject and node.type.is_pyobject:
2532            return ExprNodes.PythonCapiCallNode(
2533                node.pos, "__Pyx_PyNumber_Int", self.PyNumber_Int_func_type,
2534                args=pos_args, is_temp=True, py_name='int')
2535        return node
2536
2537    def _handle_simple_function_bool(self, node, function, pos_args):
2538        """Transform bool(x) into a type coercion to a boolean.
2539        """
2540        if len(pos_args) == 0:
2541            return ExprNodes.BoolNode(
2542                node.pos, value=False, constant_result=False
2543                ).coerce_to(Builtin.bool_type, self.current_env())
2544        elif len(pos_args) != 1:
2545            self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
2546            return node
2547        else:
2548            # => !!<bint>(x)  to make sure it's exactly 0 or 1
2549            operand = pos_args[0].coerce_to_boolean(self.current_env())
2550            operand = ExprNodes.NotNode(node.pos, operand = operand)
2551            operand = ExprNodes.NotNode(node.pos, operand = operand)
2552            # coerce back to Python object as that's the result we are expecting
2553            return operand.coerce_to_pyobject(self.current_env())
2554
2555    ### builtin functions
2556
    # C signature (const char*) -> size_t, used for the "strlen" call in
    # _handle_simple_function_len().
    Pyx_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type, [
            PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None)
        ])

    # C signature (const Py_UNICODE*) -> size_t, used for the
    # "__Pyx_Py_UNICODE_strlen" call in _handle_simple_function_len().
    Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type, [
            PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_const_py_unicode_ptr_type, None)
        ])

    # C signature (object) -> Py_ssize_t with exception value "-1", used for all
    # Python object length functions selected in _handle_simple_function_len().
    PyObject_Size_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
        ],
        exception_value="-1")

    # Lookup function (the bound 'get' of a dict) that maps a builtin type to
    # the name of the C function or macro that returns its length.  It returns
    # None for types that are not listed.
    _map_to_capi_len_function = {
        Builtin.unicode_type:    "__Pyx_PyUnicode_GET_LENGTH",
        Builtin.bytes_type:      "PyBytes_GET_SIZE",
        Builtin.bytearray_type:  'PyByteArray_GET_SIZE',
        Builtin.list_type:       "PyList_GET_SIZE",
        Builtin.tuple_type:      "PyTuple_GET_SIZE",
        Builtin.set_type:        "PySet_GET_SIZE",
        Builtin.frozenset_type:  "PySet_GET_SIZE",
        Builtin.dict_type:       "PyDict_Size",
    }.get

    # Qualified names of types for which _handle_simple_function_len() uses
    # 'Py_SIZE' to get the length.
    _ext_types_with_pysize = set(["cpython.array.array"])
2585
2586    def _handle_simple_function_len(self, node, function, pos_args):
2587        """Replace len(char*) by the equivalent call to strlen(),
2588        len(Py_UNICODE) by the equivalent Py_UNICODE_strlen() and
2589        len(known_builtin_type) by an equivalent C-API call.
2590        """
2591        if len(pos_args) != 1:
2592            self._error_wrong_arg_count('len', node, pos_args, 1)
2593            return node
2594        arg = pos_args[0]
2595        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
2596            arg = arg.arg
2597        if arg.type.is_string:
2598            new_node = ExprNodes.PythonCapiCallNode(
2599                node.pos, "strlen", self.Pyx_strlen_func_type,
2600                args = [arg],
2601                is_temp = node.is_temp,
2602                utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c"))
2603        elif arg.type.is_pyunicode_ptr:
2604            new_node = ExprNodes.PythonCapiCallNode(
2605                node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type,
2606                args = [arg],
2607                is_temp = node.is_temp)
2608        elif arg.type.is_memoryviewslice:
2609            func_type = PyrexTypes.CFuncType(
2610                PyrexTypes.c_size_t_type, [
2611                    PyrexTypes.CFuncTypeArg("memoryviewslice", arg.type, None)
2612                ], nogil=True)
2613            new_node = ExprNodes.PythonCapiCallNode(
2614                node.pos, "__Pyx_MemoryView_Len", func_type,
2615                args=[arg], is_temp=node.is_temp)
2616        elif arg.type.is_pyobject:
2617            cfunc_name = self._map_to_capi_len_function(arg.type)
2618            if cfunc_name is None:
2619                arg_type = arg.type
2620                if ((arg_type.is_extension_type or arg_type.is_builtin_type)
2621                    and arg_type.entry.qualified_name in self._ext_types_with_pysize):
2622                    cfunc_name = 'Py_SIZE'
2623                else:
2624                    return node
2625            arg = arg.as_none_safe_node(
2626                "object of type 'NoneType' has no len()")
2627            new_node = ExprNodes.PythonCapiCallNode(
2628                node.pos, cfunc_name, self.PyObject_Size_func_type,
2629                args=[arg], is_temp=node.is_temp)
2630        elif arg.type.is_unicode_char:
2631            return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
2632                                     type=node.type)
2633        else:
2634            return node
2635        if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
2636            new_node = new_node.coerce_to(node.type, self.current_env())
2637        return new_node
2638
    # C signature (object) -> type, used for the "Py_TYPE" call in
    # _handle_simple_function_type().
    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
            ])
2643
2644    def _handle_simple_function_type(self, node, function, pos_args):
2645        """Replace type(o) by a macro call to Py_TYPE(o).
2646        """
2647        if len(pos_args) != 1:
2648            return node
2649        node = ExprNodes.PythonCapiCallNode(
2650            node.pos, "Py_TYPE", self.Pyx_Type_func_type,
2651            args = pos_args,
2652            is_temp = False)
2653        return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
2654
    # C signature (object) -> bint, used for all type check calls that are
    # generated in _handle_simple_function_isinstance().
    Py_type_check_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
            ])
2659
2660    def _handle_simple_function_isinstance(self, node, function, pos_args):
2661        """Replace isinstance() checks against builtin types by the
2662        corresponding C-API call.
2663        """
2664        if len(pos_args) != 2:
2665            return node
2666        arg, types = pos_args
2667        temps = []
2668        if isinstance(types, ExprNodes.TupleNode):
2669            types = types.args
2670            if len(types) == 1 and not types[0].type is Builtin.type_type:
2671                return node  # nothing to improve here
2672            if arg.is_attribute or not arg.is_simple():
2673                arg = UtilNodes.ResultRefNode(arg)
2674                temps.append(arg)
2675        elif types.type is Builtin.type_type:
2676            types = [types]
2677        else:
2678            return node
2679
2680        tests = []
2681        test_nodes = []
2682        env = self.current_env()
2683        for test_type_node in types:
2684            builtin_type = None
2685            if test_type_node.is_name:
2686                if test_type_node.entry:
2687                    entry = env.lookup(test_type_node.entry.name)
2688                    if entry and entry.type and entry.type.is_builtin_type:
2689                        builtin_type = entry.type
2690            if builtin_type is Builtin.type_type:
2691                # all types have type "type", but there's only one 'type'
2692                if entry.name != 'type' or not (
2693                        entry.scope and entry.scope.is_builtin_scope):
2694                    builtin_type = None
2695            if builtin_type is not None:
2696                type_check_function = entry.type.type_check_function(exact=False)
2697                if type_check_function in tests:
2698                    continue
2699                tests.append(type_check_function)
2700                type_check_args = [arg]
2701            elif test_type_node.type is Builtin.type_type:
2702                type_check_function = '__Pyx_TypeCheck'
2703                type_check_args = [arg, test_type_node]
2704            else:
2705                if not test_type_node.is_literal:
2706                    test_type_node = UtilNodes.ResultRefNode(test_type_node)
2707                    temps.append(test_type_node)
2708                type_check_function = 'PyObject_IsInstance'
2709                type_check_args = [arg, test_type_node]
2710            test_nodes.append(
2711                ExprNodes.PythonCapiCallNode(
2712                    test_type_node.pos, type_check_function, self.Py_type_check_func_type,
2713                    args=type_check_args,
2714                    is_temp=True,
2715                ))
2716
2717        def join_with_or(a, b, make_binop_node=ExprNodes.binop_node):
2718            or_node = make_binop_node(node.pos, 'or', a, b)
2719            or_node.type = PyrexTypes.c_bint_type
2720            or_node.wrap_operands(env)
2721            return or_node
2722
2723        test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
2724        for temp in temps[::-1]:
2725            test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
2726        return test_node
2727
2728    def _handle_simple_function_ord(self, node, function, pos_args):
2729        """Unpack ord(Py_UNICODE) and ord('X').
2730        """
2731        if len(pos_args) != 1:
2732            return node
2733        arg = pos_args[0]
2734        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
2735            if arg.arg.type.is_unicode_char:
2736                return ExprNodes.TypecastNode(
2737                    arg.pos, operand=arg.arg, type=PyrexTypes.c_long_type
2738                    ).coerce_to(node.type, self.current_env())
2739        elif isinstance(arg, ExprNodes.UnicodeNode):
2740            if len(arg.value) == 1:
2741                return ExprNodes.IntNode(
2742                    arg.pos, type=PyrexTypes.c_int_type,
2743                    value=str(ord(arg.value)),
2744                    constant_result=ord(arg.value)
2745                    ).coerce_to(node.type, self.current_env())
2746        elif isinstance(arg, ExprNodes.StringNode):
2747            if arg.unicode_value and len(arg.unicode_value) == 1 \
2748                    and ord(arg.unicode_value) <= 255:  # Py2/3 portability
2749                return ExprNodes.IntNode(
2750                    arg.pos, type=PyrexTypes.c_int_type,
2751                    value=str(ord(arg.unicode_value)),
2752                    constant_result=ord(arg.unicode_value)
2753                    ).coerce_to(node.type, self.current_env())
2754        return node
2755
2756    ### special methods
2757
    # C signature (object type, tuple args) -> object, used for the
    # "__Pyx_tp_new" call in _handle_any_slot__new__().
    Pyx_tp_new_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("type",   PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("args",   Builtin.tuple_type, None),
            ])

    # C signature (object type, tuple args, dict kwargs) -> object, used for
    # the "__Pyx_tp_new_kwargs" call in _handle_any_slot__new__().
    Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("type",   PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("args",   Builtin.tuple_type, None),
            PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None),
        ])
2770
    def _handle_any_slot__new__(self, node, function, args,
                                is_unbound_method, kwargs=None):
        """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new()

        Only unbound calls of the form 'T.__new__(T, ...)' are handled, where
        both occurrences of 'T' are plain names that are typed as 'type' and
        refer to the same type (the same type entry, or the same name if a
        type entry is missing).  Anything else is returned unchanged.

        For an extension type of the current module that has a tp_new slot
        function, that function is called directly.  Otherwise, the call goes
        through the __Pyx_tp_new() / __Pyx_tp_new_kwargs() helper functions.
        """
        obj = function.obj
        if not is_unbound_method or len(args) < 1:
            return node
        type_arg = args[0]
        if not obj.is_name or not type_arg.is_name:
            # play safe
            return node
        if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
            # not a known type, play safe
            return node
        if not type_arg.type_entry or not obj.type_entry:
            if obj.name != type_arg.name:
                return node
            # otherwise, we know it's a type and we know it's the same
            # type for both - that should do
        elif type_arg.type_entry != obj.type_entry:
            # different types - may or may not lead to an error at runtime
            return node

        # all arguments after the type are passed on as a tuple
        args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:])
        args_tuple = args_tuple.analyse_types(
            self.current_env(), skip_children=True)

        if type_arg.type_entry:
            ext_type = type_arg.type_entry.type
            if (ext_type.is_extension_type and ext_type.typeobj_cname and
                    ext_type.scope.global_scope() == self.current_env().global_scope()):
                # known type in current module
                tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
                slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot)
                if slot_func_cname:
                    # call the slot function directly, with a signature that
                    # returns the extension type itself
                    cython_scope = self.context.cython_scope
                    PyTypeObjectPtr = PyrexTypes.CPtrType(
                        cython_scope.lookup('PyTypeObject').type)
                    pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
                        ext_type, [
                            PyrexTypes.CFuncTypeArg("type",   PyTypeObjectPtr, None),
                            PyrexTypes.CFuncTypeArg("args",   PyrexTypes.py_object_type, None),
                            PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None),
                            ])

                    type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
                    if not kwargs:
                        # the slot function always takes a kwargs argument => pass NULL
                        kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type)  # hack?
                    return ExprNodes.PythonCapiCallNode(
                        node.pos, slot_func_cname,
                        pyx_tp_new_kwargs_func_type,
                        args=[type_arg, args_tuple, kwargs],
                        may_return_none=False,
                        is_temp=True)
        else:
            # arbitrary variable, needs a None check for safety
            type_arg = type_arg.as_none_safe_node(
                "object.__new__(X): X is not a type object (NoneType)")

        # generic case: go through the tp_new helper functions
        utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c')
        if kwargs:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type,
                args=[type_arg, args_tuple, kwargs],
                utility_code=utility_code,
                is_temp=node.is_temp
                )
        else:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
                args=[type_arg, args_tuple],
                utility_code=utility_code,
                is_temp=node.is_temp
            )
2845
2846    ### methods of builtin types
2847
    # C signature used for the append helpers in this class: takes the container
    # and the item as Python objects and returns a C return code, with -1
    # as the exception value.
    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ],
        exception_value="-1")
2854
2855    def _handle_simple_method_object_append(self, node, function, args, is_unbound_method):
2856        """Optimistic optimisation as X.append() is almost always
2857        referring to a list.
2858        """
2859        if len(args) != 2 or node.result_is_used:
2860            return node
2861
2862        return ExprNodes.PythonCapiCallNode(
2863            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2864            args=args,
2865            may_return_none=False,
2866            is_temp=node.is_temp,
2867            result_is_used=False,
2868            utility_code=load_c_utility('append')
2869        )
2870
    def _handle_simple_method_list_extend(self, node, function, args, is_unbound_method):
        """Replace list.extend([...]) for short sequence literal values by sequential appends
        to avoid creating an intermediate sequence argument.

        Only sequence constructor arguments with at most 8 items and without a
        multiplication factor are unrolled; every other call is returned unchanged.
        For an empty sequence, the call is reduced to the wrapped list object itself.
        """
        if len(args) != 2:
            return node
        obj, value = args
        if not value.is_sequence_constructor:
            return node
        items = list(value.args)
        if value.mult_factor is not None or len(items) > 8:
            # Appending wins for short sequences but slows down when multiple resize operations are needed.
            # This seems to be a good enough limit that avoids repeated resizing.
            # NOTE: the leading "False" makes the following branch unreachable; it is kept
            # as a disabled experiment (see the benchmarking remarks inside).
            if False and isinstance(value, ExprNodes.ListNode):
                # One would expect that tuples are more efficient here, but benchmarking with
                # Py3.5 and Py3.7 suggests that they are not. Probably worth revisiting at some point.
                # Might be related to the usage of PySequence_FAST() in CPython's list.extend(),
                # which is probably tuned more towards lists than tuples (and rightly so).
                tuple_node = args[1].as_tuple().analyse_types(self.current_env(), skip_children=True)
                Visitor.recursively_replace_node(node, args[1], tuple_node)
            return node
        wrapped_obj = self._wrap_self_arg(obj, function, is_unbound_method, 'extend')
        if not items:
            # Empty sequences are not likely to occur, but why waste a call to list.extend() for them?
            wrapped_obj.result_is_used = node.result_is_used
            return wrapped_obj
        cloned_obj = obj = wrapped_obj
        # The list object is passed to every append call, so a non-simple expression
        # is wrapped in a LetRefNode when more than one call gets generated.
        if len(items) > 1 and not obj.is_simple():
            cloned_obj = UtilNodes.LetRefNode(obj)
        # Use ListComp_Append() for all but the last item and finish with PyList_Append()
        # to shrink the list storage size at the very end if necessary.
        temps = []
        arg = items[-1]
        if not arg.is_simple():
            arg = UtilNodes.LetRefNode(arg)
            temps.append(arg)
        new_node = ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyList_Append", self.PyObject_Append_func_type,
            args=[cloned_obj, arg],
            is_temp=True,
            utility_code=load_c_utility("ListAppend"))
        # Walk the remaining items from the second-to-last back to the first.
        # Each new call becomes the left operand of '|', so the calls end up
        # in source order and their C return codes are or-ed together.
        for arg in items[-2::-1]:
            if not arg.is_simple():
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            new_node = ExprNodes.binop_node(
                node.pos, '|',
                ExprNodes.PythonCapiCallNode(
                    node.pos, "__Pyx_ListComp_Append", self.PyObject_Append_func_type,
                    args=[cloned_obj, arg], py_name="extend",
                    is_temp=True,
                    utility_code=load_c_utility("ListCompAppend")),
                new_node,
                type=PyrexTypes.c_returncode_type,
            )
        new_node.result_is_used = node.result_is_used
        # The list reference is appended last and therefore becomes the outermost wrapper.
        if cloned_obj is not obj:
            temps.append(cloned_obj)
        # Wrap the call chain in one EvalWithTempExprNode per LetRefNode collected above.
        for temp in temps:
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
            new_node.result_is_used = node.result_is_used
        return new_node
2933
    # C signature for appending a C int value to a bytearray object;
    # returns a C return code with -1 as the exception value.
    PyByteArray_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type, [
            PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None),
            ],
        exception_value="-1")

    # Same as above, but the appended value is passed as a Python object.
    PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type, [
            PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None),
            ],
        exception_value="-1")
2947
2948    def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method):
2949        if len(args) != 2:
2950            return node
2951        func_name = "__Pyx_PyByteArray_Append"
2952        func_type = self.PyByteArray_Append_func_type
2953
2954        value = unwrap_coerced_node(args[1])
2955        if value.type.is_int or isinstance(value, ExprNodes.IntNode):
2956            value = value.coerce_to(PyrexTypes.c_int_type, self.current_env())
2957            utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
2958        elif value.is_string_literal:
2959            if not value.can_coerce_to_char_literal():
2960                return node
2961            value = value.coerce_to(PyrexTypes.c_char_type, self.current_env())
2962            utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
2963        elif value.type.is_pyobject:
2964            func_name = "__Pyx_PyByteArray_AppendObject"
2965            func_type = self.PyByteArray_AppendObject_func_type
2966            utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c")
2967        else:
2968            return node
2969
2970        new_node = ExprNodes.PythonCapiCallNode(
2971            node.pos, func_name, func_type,
2972            args=[args[0], value],
2973            may_return_none=False,
2974            is_temp=node.is_temp,
2975            utility_code=utility_code,
2976        )
2977        if node.result_is_used:
2978            new_node = new_node.coerce_to(node.type, self.current_env())
2979        return new_node
2980
    # C signature for the argument-less pop helpers: takes the container
    # and returns a Python object.
    PyObject_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            ])

    # C signature for the indexed pop helpers.  The index is passed both as a
    # Python object and as a Py_ssize_t, together with a signedness flag.
    PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("py_index", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("c_index", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("is_signed", PyrexTypes.c_int_type, None),
        ],
        has_varargs=True)  # to fake the additional macro args that lack a proper C type
2994
2995    def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method):
2996        return self._handle_simple_method_object_pop(
2997            node, function, args, is_unbound_method, is_list=True)
2998
    def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False):
        """Optimistic optimisation as X.pop([n]) is almost always
        referring to a list.

        With ``is_list=True`` the object is known to be a list: the
        ``__Pyx_PyList_*`` helpers are used and a None check is added.
        Otherwise the ``__Pyx_PyObject_*`` helpers are used.  Calls with more
        than one index argument, or with an index that cannot be handled, are
        returned unchanged.
        """
        if not args:
            return node
        obj = args[0]
        if is_list:
            type_name = 'List'
            obj = obj.as_none_safe_node(
                "'NoneType' object has no attribute '%.30s'",
                error="PyExc_AttributeError",
                format_args=['pop'])
        else:
            type_name = 'Object'
        if len(args) == 1:
            # X.pop() without an index
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_Py%s_Pop" % type_name,
                self.PyObject_Pop_func_type,
                args=[obj],
                may_return_none=True,
                is_temp=node.is_temp,
                utility_code=load_c_utility('pop'),
            )
        elif len(args) == 2:
            # X.pop(index): pass the index as a Py_ssize_t and, where available,
            # also as a Python object (None otherwise).
            index = unwrap_coerced_node(args[1])
            py_index = ExprNodes.NoneNode(index.pos)
            orig_index_type = index.type
            if not index.type.is_int:
                if isinstance(index, ExprNodes.IntNode):
                    py_index = index.coerce_to_pyobject(self.current_env())
                    index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
                elif is_list:
                    if index.type.is_pyobject:
                        # Keep the Python index as a simple node and derive the
                        # C index from a clone of it.
                        py_index = index.coerce_to_simple(self.current_env())
                        index = ExprNodes.CloneNode(py_index)
                    index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
                else:
                    # non-integer index on an arbitrary object: not optimised
                    return node
            elif not PyrexTypes.numeric_type_fits(index.type, PyrexTypes.c_py_ssize_t_type):
                return node
            elif isinstance(index, ExprNodes.IntNode):
                py_index = index.coerce_to_pyobject(self.current_env())
            # real type might still be larger at runtime
            if not orig_index_type.is_int:
                orig_index_type = index.type
            if not orig_index_type.create_to_py_utility_code(self.current_env()):
                return node
            convert_func = orig_index_type.to_py_function
            conversion_type = PyrexTypes.CFuncType(
                PyrexTypes.py_object_type, [PyrexTypes.CFuncTypeArg("intval", orig_index_type, None)])
            # Besides the declared arguments, the original C index type and its
            # to-Python conversion function are passed as raw C names.
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_Py%s_PopIndex" % type_name,
                self.PyObject_PopIndex_func_type,
                args=[obj, py_index, index,
                      ExprNodes.IntNode(index.pos, value=str(orig_index_type.signed and 1 or 0),
                                        constant_result=orig_index_type.signed and 1 or 0,
                                        type=PyrexTypes.c_int_type),
                      ExprNodes.RawCNameExprNode(index.pos, PyrexTypes.c_void_type,
                                                 orig_index_type.empty_declaration_code()),
                      ExprNodes.RawCNameExprNode(index.pos, conversion_type, convert_func)],
                may_return_none=True,
                is_temp=node.is_temp,
                utility_code=load_c_utility("pop_index"),
            )

        return node
3066
    # C signature of a function that takes one Python object and returns
    # a C return code, with -1 as the exception value.
    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "-1")
3072
3073    def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method):
3074        """Call PyList_Sort() instead of the 0-argument l.sort().
3075        """
3076        if len(args) != 1:
3077            return node
3078        return self._substitute_method_call(
3079            node, function, "PyList_Sort", self.single_param_func_type,
3080            'sort', is_unbound_method, args).coerce_to(node.type, self.current_env)
3081
    # C signature for the dict.get() replacement: dict, key and default value
    # are Python objects, and so is the result.
    Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            ])
3088
3089    def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method):
3090        """Replace dict.get() by a call to PyDict_GetItem().
3091        """
3092        if len(args) == 2:
3093            args.append(ExprNodes.NoneNode(node.pos))
3094        elif len(args) != 3:
3095            self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
3096            return node
3097
3098        return self._substitute_method_call(
3099            node, function,
3100            "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
3101            'get', is_unbound_method, args,
3102            may_return_none = True,
3103            utility_code = load_c_utility("dict_getitem_default"))
3104
    # C signature for the dict.setdefault() replacement: dict, key and default
    # are Python objects, followed by a C int flag describing the key type.
    Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None),
            ])
3112
3113    def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method):
3114        """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem().
3115        """
3116        if len(args) == 2:
3117            args.append(ExprNodes.NoneNode(node.pos))
3118        elif len(args) != 3:
3119            self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
3120            return node
3121        key_type = args[1].type
3122        if key_type.is_builtin_type:
3123            is_safe_type = int(key_type.name in
3124                               'str bytes unicode float int long bool')
3125        elif key_type is PyrexTypes.py_object_type:
3126            is_safe_type = -1  # don't know
3127        else:
3128            is_safe_type = 0   # definitely not
3129        args.append(ExprNodes.IntNode(
3130            node.pos, value=str(is_safe_type), constant_result=is_safe_type))
3131
3132        return self._substitute_method_call(
3133            node, function,
3134            "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
3135            'setdefault', is_unbound_method, args,
3136            may_return_none=True,
3137            utility_code=load_c_utility('dict_setdefault'))
3138
    # C signature for the dict.pop() replacement: dict, key and default value
    # are Python objects, and so is the result.
    PyDict_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            ])
3145
3146    def _handle_simple_method_dict_pop(self, node, function, args, is_unbound_method):
3147        """Replace dict.pop() by a call to _PyDict_Pop().
3148        """
3149        if len(args) == 2:
3150            args.append(ExprNodes.NullNode(node.pos))
3151        elif len(args) != 3:
3152            self._error_wrong_arg_count('dict.pop', node, args, "2 or 3")
3153            return node
3154
3155        return self._substitute_method_call(
3156            node, function,
3157            "__Pyx_PyDict_Pop", self.PyDict_Pop_func_type,
3158            'pop', is_unbound_method, args,
3159            may_return_none=True,
3160            utility_code=load_c_utility('py_dict_pop'))
3161
    # Lookup table of C signatures for the numeric binop helpers, keyed by
    # (C type of the constant operand, return type).  It covers C long and
    # C double constants, each with a Python object or a C bint result.
    # Python object results get no exception value; bint results use the
    # exception value of the return type.
    Pyx_BinopInt_func_types = dict(
        ((ctype, ret_type), PyrexTypes.CFuncType(
            ret_type, [
                PyrexTypes.CFuncTypeArg("op1", PyrexTypes.py_object_type, None),
                PyrexTypes.CFuncTypeArg("op2", PyrexTypes.py_object_type, None),
                PyrexTypes.CFuncTypeArg("cval", ctype, None),
                PyrexTypes.CFuncTypeArg("inplace", PyrexTypes.c_bint_type, None),
                PyrexTypes.CFuncTypeArg("zerodiv_check", PyrexTypes.c_bint_type, None),
            ], exception_value=None if ret_type.is_pyobject else ret_type.exception_value))
        for ctype in (PyrexTypes.c_long_type, PyrexTypes.c_double_type)
        for ret_type in (PyrexTypes.py_object_type, PyrexTypes.c_bint_type)
        )
3174
3175    def _handle_simple_method_object___add__(self, node, function, args, is_unbound_method):
3176        return self._optimise_num_binop('Add', node, function, args, is_unbound_method)
3177
3178    def _handle_simple_method_object___sub__(self, node, function, args, is_unbound_method):
3179        return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method)
3180
3181    def _handle_simple_method_object___eq__(self, node, function, args, is_unbound_method):
3182        return self._optimise_num_binop('Eq', node, function, args, is_unbound_method)
3183
3184    def _handle_simple_method_object___ne__(self, node, function, args, is_unbound_method):
3185        return self._optimise_num_binop('Ne', node, function, args, is_unbound_method)
3186
3187    def _handle_simple_method_object___and__(self, node, function, args, is_unbound_method):
3188        return self._optimise_num_binop('And', node, function, args, is_unbound_method)
3189
3190    def _handle_simple_method_object___or__(self, node, function, args, is_unbound_method):
3191        return self._optimise_num_binop('Or', node, function, args, is_unbound_method)
3192
3193    def _handle_simple_method_object___xor__(self, node, function, args, is_unbound_method):
3194        return self._optimise_num_binop('Xor', node, function, args, is_unbound_method)
3195
3196    def _handle_simple_method_object___rshift__(self, node, function, args, is_unbound_method):
3197        if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode):
3198            return node
3199        if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63):
3200            return node
3201        return self._optimise_num_binop('Rshift', node, function, args, is_unbound_method)
3202
3203    def _handle_simple_method_object___lshift__(self, node, function, args, is_unbound_method):
3204        if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode):
3205            return node
3206        if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63):
3207            return node
3208        return self._optimise_num_binop('Lshift', node, function, args, is_unbound_method)
3209
3210    def _handle_simple_method_object___mod__(self, node, function, args, is_unbound_method):
3211        return self._optimise_num_div('Remainder', node, function, args, is_unbound_method)
3212
3213    def _handle_simple_method_object___floordiv__(self, node, function, args, is_unbound_method):
3214        return self._optimise_num_div('FloorDivide', node, function, args, is_unbound_method)
3215
3216    def _handle_simple_method_object___truediv__(self, node, function, args, is_unbound_method):
3217        return self._optimise_num_div('TrueDivide', node, function, args, is_unbound_method)
3218
3219    def _handle_simple_method_object___div__(self, node, function, args, is_unbound_method):
3220        return self._optimise_num_div('Divide', node, function, args, is_unbound_method)
3221
3222    def _optimise_num_div(self, operator, node, function, args, is_unbound_method):
3223        if len(args) != 2 or not args[1].has_constant_result() or args[1].constant_result == 0:
3224            return node
3225        if isinstance(args[1], ExprNodes.IntNode):
3226            if not (-2**30 <= args[1].constant_result <= 2**30):
3227                return node
3228        elif isinstance(args[1], ExprNodes.FloatNode):
3229            if not (-2**53 <= args[1].constant_result <= 2**53):
3230                return node
3231        else:
3232            return node
3233        return self._optimise_num_binop(operator, node, function, args, is_unbound_method)
3234
3235    def _handle_simple_method_float___add__(self, node, function, args, is_unbound_method):
3236        return self._optimise_num_binop('Add', node, function, args, is_unbound_method)
3237
3238    def _handle_simple_method_float___sub__(self, node, function, args, is_unbound_method):
3239        return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method)
3240
3241    def _handle_simple_method_float___truediv__(self, node, function, args, is_unbound_method):
3242        return self._optimise_num_binop('TrueDivide', node, function, args, is_unbound_method)
3243
3244    def _handle_simple_method_float___div__(self, node, function, args, is_unbound_method):
3245        return self._optimise_num_binop('Divide', node, function, args, is_unbound_method)
3246
3247    def _handle_simple_method_float___mod__(self, node, function, args, is_unbound_method):
3248        return self._optimise_num_binop('Remainder', node, function, args, is_unbound_method)
3249
3250    def _handle_simple_method_float___eq__(self, node, function, args, is_unbound_method):
3251        return self._optimise_num_binop('Eq', node, function, args, is_unbound_method)
3252
3253    def _handle_simple_method_float___ne__(self, node, function, args, is_unbound_method):
3254        return self._optimise_num_binop('Ne', node, function, args, is_unbound_method)
3255
    def _optimise_num_binop(self, operator, node, function, args, is_unbound_method):
        """
        Optimise math operators for (likely) float or small integer operations.

        One operand must be an IntNode/FloatNode with a constant result and
        the other one a generic Python object.  The call is then replaced by
        a ``__Pyx_Py{Float,Int}_[Bool]<operator><order>`` helper, where
        ``<order>`` is 'ObjC' (constant on the right) or 'CObj' (constant on
        the left).  In every other case the node is returned unchanged.
        """
        if len(args) != 2:
            return node

        # Only Python object results are supported, plus C bint results for comparisons.
        if node.type.is_pyobject:
            ret_type = PyrexTypes.py_object_type
        elif node.type is PyrexTypes.c_bint_type and operator in ('Eq', 'Ne'):
            ret_type = PyrexTypes.c_bint_type
        else:
            return node

        # When adding IntNode/FloatNode to something else, assume other operand is also numeric.
        # Prefer constants on RHS as they allows better size control for some operators.
        num_nodes = (ExprNodes.IntNode, ExprNodes.FloatNode)
        if isinstance(args[1], num_nodes):
            if args[0].type is not PyrexTypes.py_object_type:
                return node
            numval = args[1]
            arg_order = 'ObjC'
        elif isinstance(args[0], num_nodes):
            if args[1].type is not PyrexTypes.py_object_type:
                return node
            numval = args[0]
            arg_order = 'CObj'
        else:
            return node

        if not numval.has_constant_result():
            return node

        is_float = isinstance(numval, ExprNodes.FloatNode)
        num_type = PyrexTypes.c_double_type if is_float else PyrexTypes.c_long_type
        if is_float:
            # float constants are only handled for this subset of operators
            if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'):
                return node
        elif operator == 'Divide':
            # mixed old-/new-style division is not currently optimised for integers
            return node
        elif abs(numval.constant_result) > 2**30:
            # Cut off at an integer border that is still safe for all operations.
            return node

        if operator in ('TrueDivide', 'FloorDivide', 'Divide', 'Remainder'):
            if args[1].constant_result == 0:
                # Don't optimise division by 0. :)
                return node

        # Work on a copy so that the caller's argument list is not modified.
        args = list(args)
        # Pass the constant a second time as a plain C value of type 'num_type'.
        args.append((ExprNodes.FloatNode if is_float else ExprNodes.IntNode)(
            numval.pos, value=numval.value, constant_result=numval.constant_result,
            type=num_type))
        inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False
        args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace))
        if is_float or operator not in ('Eq', 'Ne'):
            # "PyFloatBinop" and "PyIntBinop" take an additional "check for zero division" argument.
            # The check is only requested when the constant is the left operand
            # of a DivNode that does not use cdivision.
            zerodivision_check = arg_order == 'CObj' and (
                not node.cdivision if isinstance(node, ExprNodes.DivNode) else False)
            args.append(ExprNodes.BoolNode(node.pos, value=zerodivision_check, constant_result=zerodivision_check))

        utility_code = TempitaUtilityCode.load_cached(
            "PyFloatBinop" if is_float else "PyIntCompare" if operator in ('Eq', 'Ne') else "PyIntBinop",
            "Optimize.c",
            context=dict(op=operator, order=arg_order, ret_type=ret_type))

        # The attribute name passed below is derived from the first three letters
        # of the operator name; the None check is disabled for this call.
        call_node = self._substitute_method_call(
            node, function,
            "__Pyx_Py%s_%s%s%s" % (
                'Float' if is_float else 'Int',
                '' if ret_type.is_pyobject else 'Bool',
                operator,
                arg_order),
            self.Pyx_BinopInt_func_types[(num_type, ret_type)],
            '__%s__' % operator[:3].lower(), is_unbound_method, args,
            may_return_none=True,
            with_none_check=False,
            utility_code=utility_code)

        if node.type.is_pyobject and not ret_type.is_pyobject:
            call_node = ExprNodes.CoerceToPyTypeNode(call_node, self.current_env(), node.type)
        return call_node
3339
3340    ### unicode type methods
3341
    # C signature of the character predicates: takes a Py_UCS4 character
    # and returns a C bint.
    PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
            ])
3346
3347    def _inject_unicode_predicate(self, node, function, args, is_unbound_method):
3348        if is_unbound_method or len(args) != 1:
3349            return node
3350        ustring = args[0]
3351        if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
3352               not ustring.arg.type.is_unicode_char:
3353            return node
3354        uchar = ustring.arg
3355        method_name = function.attribute
3356        if method_name == 'istitle':
3357            # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
3358            utility_code = UtilityCode.load_cached(
3359                "py_unicode_istitle", "StringTools.c")
3360            function_name = '__Pyx_Py_UNICODE_ISTITLE'
3361        else:
3362            utility_code = None
3363            function_name = 'Py_UNICODE_%s' % method_name.upper()
3364        func_call = self._substitute_method_call(
3365            node, function,
3366            function_name, self.PyUnicode_uchar_predicate_func_type,
3367            method_name, is_unbound_method, [uchar],
3368            utility_code = utility_code)
3369        if node.type.is_pyobject:
3370            func_call = func_call.coerce_to_pyobject(self.current_env)
3371        return func_call
3372
    # All of these unicode predicate methods share the same handler, which
    # derives the C function name from the called method name.
    _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
    _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
    _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
    _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
    _handle_simple_method_unicode_islower   = _inject_unicode_predicate
    _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
    _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
    _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
    _handle_simple_method_unicode_isupper   = _inject_unicode_predicate
3382
    # C signature of the character conversions: takes a Py_UCS4 character
    # and returns a Py_UCS4 character.
    PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type, [
            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
            ])
3387
3388    def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
3389        if is_unbound_method or len(args) != 1:
3390            return node
3391        ustring = args[0]
3392        if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
3393               not ustring.arg.type.is_unicode_char:
3394            return node
3395        uchar = ustring.arg
3396        method_name = function.attribute
3397        function_name = 'Py_UNICODE_TO%s' % method_name.upper()
3398        func_call = self._substitute_method_call(
3399            node, function,
3400            function_name, self.PyUnicode_uchar_conversion_func_type,
3401            method_name, is_unbound_method, [uchar])
3402        if node.type.is_pyobject:
3403            func_call = func_call.coerce_to_pyobject(self.current_env)
3404        return func_call
3405
    # The case conversion methods share one handler, which derives the
    # C function name from the called method name.
    _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
    _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
    _handle_simple_method_unicode_title = _inject_unicode_character_conversion
3409
    # C signature for PyUnicode_Splitlines(): takes a unicode string and
    # a C bint 'keepends' flag, returns a list.
    PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
        Builtin.list_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
            ])
3415
3416    def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method):
3417        """Replace unicode.splitlines(...) by a direct call to the
3418        corresponding C-API function.
3419        """
3420        if len(args) not in (1,2):
3421            self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
3422            return node
3423        self._inject_bint_default_argument(node, args, 1, False)
3424
3425        return self._substitute_method_call(
3426            node, function,
3427            "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
3428            'splitlines', is_unbound_method, args)
3429
    # C signature for PyUnicode_Split(): takes a unicode string, a separator
    # object and a Py_ssize_t 'maxsplit' value, returns a list.
    PyUnicode_Split_func_type = PyrexTypes.CFuncType(
        Builtin.list_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
            ]
        )
3437
3438    def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method):
3439        """Replace unicode.split(...) by a direct call to the
3440        corresponding C-API function.
3441        """
3442        if len(args) not in (1,2,3):
3443            self._error_wrong_arg_count('unicode.split', node, args, "1-3")
3444            return node
3445        if len(args) < 2:
3446            args.append(ExprNodes.NullNode(node.pos))
3447        self._inject_int_default_argument(
3448            node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
3449
3450        return self._substitute_method_call(
3451            node, function,
3452            "PyUnicode_Split", self.PyUnicode_Split_func_type,
3453            'split', is_unbound_method, args)
3454
    # C signature declared for the "PyUnicode_Join" call:
    #     unicode f(unicode str, object seq)
    PyUnicode_Join_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("seq", PyrexTypes.py_object_type, None),
            ])
3460
3461    def _handle_simple_method_unicode_join(self, node, function, args, is_unbound_method):
3462        """
3463        unicode.join() builds a list first => see if we can do this more efficiently
3464        """
3465        if len(args) != 2:
3466            self._error_wrong_arg_count('unicode.join', node, args, "2")
3467            return node
3468        if isinstance(args[1], ExprNodes.GeneratorExpressionNode):
3469            gen_expr_node = args[1]
3470            loop_node = gen_expr_node.loop
3471
3472            yield_statements = _find_yield_statements(loop_node)
3473            if yield_statements:
3474                inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode(
3475                    node.pos, gen_expr_node, orig_func='list',
3476                    comprehension_type=Builtin.list_type)
3477
3478                for yield_expression, yield_stat_node in yield_statements:
3479                    append_node = ExprNodes.ComprehensionAppendNode(
3480                        yield_expression.pos,
3481                        expr=yield_expression,
3482                        target=inlined_genexpr.target)
3483
3484                    Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
3485
3486                args[1] = inlined_genexpr
3487
3488        return self._substitute_method_call(
3489            node, function,
3490            "PyUnicode_Join", self.PyUnicode_Join_func_type,
3491            'join', is_unbound_method, args)
3492
    # C signature shared by the "__Pyx_Py<Type>_Tailmatch" helper calls that
    # _inject_tailmatch() builds:
    #     bint f(object str, object substring,
    #            Py_ssize_t start, Py_ssize_t end, int direction)
    # declared with exception value -1.
    PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None),  # bytes/str/unicode
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-1')
3502
3503    def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method):
3504        return self._inject_tailmatch(
3505            node, function, args, is_unbound_method, 'unicode', 'endswith',
3506            unicode_tailmatch_utility_code, +1)
3507
3508    def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method):
3509        return self._inject_tailmatch(
3510            node, function, args, is_unbound_method, 'unicode', 'startswith',
3511            unicode_tailmatch_utility_code, -1)
3512
3513    def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name,
3514                          method_name, utility_code, direction):
3515        """Replace unicode.startswith(...) and unicode.endswith(...)
3516        by a direct call to the corresponding C-API function.
3517        """
3518        if len(args) not in (2,3,4):
3519            self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4")
3520            return node
3521        self._inject_int_default_argument(
3522            node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
3523        self._inject_int_default_argument(
3524            node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
3525        args.append(ExprNodes.IntNode(
3526            node.pos, value=str(direction), type=PyrexTypes.c_int_type))
3527
3528        method_call = self._substitute_method_call(
3529            node, function,
3530            "__Pyx_Py%s_Tailmatch" % type_name.capitalize(),
3531            self.PyString_Tailmatch_func_type,
3532            method_name, is_unbound_method, args,
3533            utility_code = utility_code)
3534        return method_call.coerce_to(Builtin.bool_type, self.current_env())
3535
    # C signature declared for the "PyUnicode_Find" call:
    #     Py_ssize_t f(unicode str, object substring,
    #                  Py_ssize_t start, Py_ssize_t end, int direction)
    # declared with exception value -2.
    PyUnicode_Find_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-2')
3545
3546    def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method):
3547        return self._inject_unicode_find(
3548            node, function, args, is_unbound_method, 'find', +1)
3549
3550    def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method):
3551        return self._inject_unicode_find(
3552            node, function, args, is_unbound_method, 'rfind', -1)
3553
3554    def _inject_unicode_find(self, node, function, args, is_unbound_method,
3555                             method_name, direction):
3556        """Replace unicode.find(...) and unicode.rfind(...) by a
3557        direct call to the corresponding C-API function.
3558        """
3559        if len(args) not in (2,3,4):
3560            self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
3561            return node
3562        self._inject_int_default_argument(
3563            node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
3564        self._inject_int_default_argument(
3565            node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
3566        args.append(ExprNodes.IntNode(
3567            node.pos, value=str(direction), type=PyrexTypes.c_int_type))
3568
3569        method_call = self._substitute_method_call(
3570            node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type,
3571            method_name, is_unbound_method, args)
3572        return method_call.coerce_to_pyobject(self.current_env())
3573
    # C signature declared for the "PyUnicode_Count" call:
    #     Py_ssize_t f(unicode str, object substring,
    #                  Py_ssize_t start, Py_ssize_t end)
    # declared with exception value -1.
    PyUnicode_Count_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            ],
        exception_value = '-1')
3582
3583    def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method):
3584        """Replace unicode.count(...) by a direct call to the
3585        corresponding C-API function.
3586        """
3587        if len(args) not in (2,3,4):
3588            self._error_wrong_arg_count('unicode.count', node, args, "2-4")
3589            return node
3590        self._inject_int_default_argument(
3591            node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
3592        self._inject_int_default_argument(
3593            node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
3594
3595        method_call = self._substitute_method_call(
3596            node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type,
3597            'count', is_unbound_method, args)
3598        return method_call.coerce_to_pyobject(self.current_env())
3599
    # C signature declared for the "PyUnicode_Replace" call:
    #     unicode f(unicode str, object substring, object replstr,
    #               Py_ssize_t maxcount)
    PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
            ])
3607
3608    def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method):
3609        """Replace unicode.replace(...) by a direct call to the
3610        corresponding C-API function.
3611        """
3612        if len(args) not in (3,4):
3613            self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
3614            return node
3615        self._inject_int_default_argument(
3616            node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
3617
3618        return self._substitute_method_call(
3619            node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
3620            'replace', is_unbound_method, args)
3621
    # C signature declared for the generic "PyUnicode_AsEncodedString" call:
    #     bytes f(unicode obj, const char* encoding, const char* errors)
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            ])

    # C signature declared for the codec specific "PyUnicode_As<Codec>String"
    # calls, which only take the string:
    #     bytes f(unicode obj)
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Codec names that get special treatment; the spelling is reused when
    # building C function names in the encode/decode handlers.
    _special_encodings = ['UTF8', 'UTF16', 'UTF-16LE', 'UTF-16BE', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, encoder function) pairs; _find_special_codec_name() compares the
    # encoder of a requested encoding against these to recognise aliases.
    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]
3639
3640    def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method):
3641        """Replace unicode.encode(...) by a direct C-API call to the
3642        corresponding codec.
3643        """
3644        if len(args) < 1 or len(args) > 3:
3645            self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
3646            return node
3647
3648        string_node = args[0]
3649
3650        if len(args) == 1:
3651            null_node = ExprNodes.NullNode(node.pos)
3652            return self._substitute_method_call(
3653                node, function, "PyUnicode_AsEncodedString",
3654                self.PyUnicode_AsEncodedString_func_type,
3655                'encode', is_unbound_method, [string_node, null_node, null_node])
3656
3657        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
3658        if parameters is None:
3659            return node
3660        encoding, encoding_node, error_handling, error_handling_node = parameters
3661
3662        if encoding and isinstance(string_node, ExprNodes.UnicodeNode):
3663            # constant, so try to do the encoding at compile time
3664            try:
3665                value = string_node.value.encode(encoding, error_handling)
3666            except:
3667                # well, looks like we can't
3668                pass
3669            else:
3670                value = bytes_literal(value, encoding)
3671                return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type)
3672
3673        if encoding and error_handling == 'strict':
3674            # try to find a specific encoder function
3675            codec_name = self._find_special_codec_name(encoding)
3676            if codec_name is not None and '-' not in codec_name:
3677                encode_function = "PyUnicode_As%sString" % codec_name
3678                return self._substitute_method_call(
3679                    node, function, encode_function,
3680                    self.PyUnicode_AsXyzString_func_type,
3681                    'encode', is_unbound_method, [string_node])
3682
3683        return self._substitute_method_call(
3684            node, function, "PyUnicode_AsEncodedString",
3685            self.PyUnicode_AsEncodedString_func_type,
3686            'encode', is_unbound_method,
3687            [string_node, encoding_node, error_handling_node])
3688
    # Pointer type for the codec specific decode functions that are passed to
    # the decode helpers as their "decode_func" argument:
    #     unicode (*f)(const char* string, Py_ssize_t size, const char* errors)
    PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
        ]))

    # Signature of the decode helper used for C strings (const char*).
    _decode_c_string_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
        ])

    # Signature of the decode helpers used for Python bytes/bytearray objects;
    # same as above except that the string is passed as a Python object.
    _decode_bytes_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
        ])

    # Signature of the C++ std::string decode helper; it needs the concrete
    # string type and is therefore created on first use by the decode handler.
    _decode_cpp_string_func_type = None  # lazy init
3717
    def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Handles C strings, C++ strings and values typed as Python bytes or
        bytearray; for any other string type the original node is returned.
        ``args`` is the string node (optionally a slice of it) followed by
        the optional encoding and error handling arguments.

        The result is a call to one of the ``__Pyx_decode_*`` helpers, which
        receives the string, the start/stop slice bounds, the encoding and
        error mode as C strings (or NULL) and, for the special codecs, a
        pointer to the codec specific decode function.
        """
        if not (1 <= len(args) <= 3):
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node

        # normalise input nodes
        string_node = args[0]
        start = stop = None
        if isinstance(string_node, ExprNodes.SliceIndexNode):
            # decode the sliced base directly and pass the bounds to the helper
            index_node = string_node
            string_node = index_node.base
            start, stop = index_node.start, index_node.stop
            if not start or start.constant_result == 0:
                # a missing or constant zero start is the same as no start
                start = None
        if isinstance(string_node, ExprNodes.CoerceToPyTypeNode):
            # look through the coercion to a Python object
            string_node = string_node.arg

        string_type = string_node.type
        if string_type in (Builtin.bytes_type, Builtin.bytearray_type):
            # Python object: reject None with the error message that matches
            # the way the method was called
            if is_unbound_method:
                string_node = string_node.as_none_safe_node(
                    "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                    format_args=['decode', string_type.name])
            else:
                string_node = string_node.as_none_safe_node(
                    "'NoneType' object has no attribute '%.30s'",
                    error="PyExc_AttributeError",
                    format_args=['decode'])
        elif not string_type.is_string and not string_type.is_cpp_string:
            # nothing to optimise here
            return node

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # the helpers take C integer bounds: default the start to 0 and
        # coerce non-integer bounds to Py_ssize_t
        if not start:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        elif not start.type.is_int:
            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop and not stop.type.is_int:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())

        # try to find a specific decoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            if codec_name in ('UTF16', 'UTF-16LE', 'UTF-16BE'):
                # the UTF-16 variants go through __Pyx_ prefixed functions
                codec_cname = "__Pyx_PyUnicode_Decode%s" % codec_name.replace('-', '')
            else:
                codec_cname = "PyUnicode_Decode%s" % codec_name
            decode_function = ExprNodes.RawCNameExprNode(
                node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type, cname=codec_cname)
            # the decode function identifies the codec, so pass NULL as encoding
            encoding_node = ExprNodes.NullNode(node.pos)
        else:
            decode_function = ExprNodes.NullNode(node.pos)

        # build the helper function call
        temps = []
        if string_type.is_string:
            # C string
            if not stop:
                # use strlen() to find the string length, just as CPython would
                if not string_node.is_name:
                    string_node = UtilNodes.LetRefNode(string_node) # used twice
                    temps.append(string_node)
                stop = ExprNodes.PythonCapiCallNode(
                    string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args=[string_node],
                    is_temp=False,
                    utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"),
                ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
            helper_func_type = self._decode_c_string_func_type
            utility_code_name = 'decode_c_string'
        elif string_type.is_cpp_string:
            # C++ std::string
            if not stop:
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                         constant_result=ExprNodes.not_a_constant)
            if self._decode_cpp_string_func_type is None:
                # lazy init to reuse the C++ string type
                self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
                    Builtin.unicode_type, [
                        PyrexTypes.CFuncTypeArg("string", string_type, None),
                        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
                        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
                        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
                        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
                        PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None),
                    ])
            helper_func_type = self._decode_cpp_string_func_type
            utility_code_name = 'decode_cpp_string'
        else:
            # Python bytes/bytearray object
            if not stop:
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                         constant_result=ExprNodes.not_a_constant)
            helper_func_type = self._decode_bytes_func_type
            if string_type is Builtin.bytes_type:
                utility_code_name = 'decode_bytes'
            else:
                utility_code_name = 'decode_bytearray'

        # the helper's C name is the utility code name with a "__Pyx_" prefix
        node = ExprNodes.PythonCapiCallNode(
            node.pos, '__Pyx_%s' % utility_code_name, helper_func_type,
            args=[string_node, start, stop, encoding_node, error_handling_node, decode_function],
            is_temp=node.is_temp,
            utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'),
        )

        # wrap the call in the temps collected above (innermost last)
        for temp in temps[::-1]:
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node
3836
    # bytearray.decode() shares the bytes handler, which picks the
    # 'decode_bytearray' helper for values that are not typed as bytes.
    _handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode
3838
3839    def _find_special_codec_name(self, encoding):
3840        try:
3841            requested_codec = codecs.getencoder(encoding)
3842        except LookupError:
3843            return None
3844        for name, codec in self._special_codecs:
3845            if codec == requested_codec:
3846                if '_' in name:
3847                    name = ''.join([s.capitalize()
3848                                    for s in name.split('_')])
3849                return name
3850        return None
3851
3852    def _unpack_encoding_and_error_mode(self, pos, args):
3853        null_node = ExprNodes.NullNode(pos)
3854
3855        if len(args) >= 2:
3856            encoding, encoding_node = self._unpack_string_and_cstring_node(args[1])
3857            if encoding_node is None:
3858                return None
3859        else:
3860            encoding = None
3861            encoding_node = null_node
3862
3863        if len(args) == 3:
3864            error_handling, error_handling_node = self._unpack_string_and_cstring_node(args[2])
3865            if error_handling_node is None:
3866                return None
3867            if error_handling == 'strict':
3868                error_handling_node = null_node
3869        else:
3870            error_handling = 'strict'
3871            error_handling_node = null_node
3872
3873        return (encoding, encoding_node, error_handling, error_handling_node)
3874
3875    def _unpack_string_and_cstring_node(self, node):
3876        if isinstance(node, ExprNodes.CoerceToPyTypeNode):
3877            node = node.arg
3878        if isinstance(node, ExprNodes.UnicodeNode):
3879            encoding = node.value
3880            node = ExprNodes.BytesNode(
3881                node.pos, value=encoding.as_utf8_string(), type=PyrexTypes.c_const_char_ptr_type)
3882        elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)):
3883            encoding = node.value.decode('ISO-8859-1')
3884            node = ExprNodes.BytesNode(
3885                node.pos, value=node.value, type=PyrexTypes.c_const_char_ptr_type)
3886        elif node.type is Builtin.bytes_type:
3887            encoding = None
3888            node = node.coerce_to(PyrexTypes.c_const_char_ptr_type, self.current_env())
3889        elif node.type.is_string:
3890            encoding = None
3891        else:
3892            encoding = node = None
3893        return encoding, node
3894
3895    def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method):
3896        return self._inject_tailmatch(
3897            node, function, args, is_unbound_method, 'str', 'endswith',
3898            str_tailmatch_utility_code, +1)
3899
3900    def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method):
3901        return self._inject_tailmatch(
3902            node, function, args, is_unbound_method, 'str', 'startswith',
3903            str_tailmatch_utility_code, -1)
3904
3905    def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method):
3906        return self._inject_tailmatch(
3907            node, function, args, is_unbound_method, 'bytes', 'endswith',
3908            bytes_tailmatch_utility_code, +1)
3909
3910    def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method):
3911        return self._inject_tailmatch(
3912            node, function, args, is_unbound_method, 'bytes', 'startswith',
3913            bytes_tailmatch_utility_code, -1)
3914
    # NOTE: the triple-quoted block below is a bare string expression, not code,
    # so the bytearray endswith/startswith handlers inside it are not defined.
    '''   # disabled for now, enable when we consider it worth it (see StringTools.c)
    def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method):
        return self._inject_tailmatch(
            node, function, args, is_unbound_method, 'bytearray', 'endswith',
            bytes_tailmatch_utility_code, +1)

    def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method):
        return self._inject_tailmatch(
            node, function, args, is_unbound_method, 'bytearray', 'startswith',
            bytes_tailmatch_utility_code, -1)
    '''
3926
3927    ### helpers
3928
3929    def _substitute_method_call(self, node, function, name, func_type,
3930                                attr_name, is_unbound_method, args=(),
3931                                utility_code=None, is_temp=None,
3932                                may_return_none=ExprNodes.PythonCapiCallNode.may_return_none,
3933                                with_none_check=True):
3934        args = list(args)
3935        if with_none_check and args:
3936            args[0] = self._wrap_self_arg(args[0], function, is_unbound_method, attr_name)
3937        if is_temp is None:
3938            is_temp = node.is_temp
3939        return ExprNodes.PythonCapiCallNode(
3940            node.pos, name, func_type,
3941            args = args,
3942            is_temp = is_temp,
3943            utility_code = utility_code,
3944            may_return_none = may_return_none,
3945            result_is_used = node.result_is_used,
3946            )
3947
3948    def _wrap_self_arg(self, self_arg, function, is_unbound_method, attr_name):
3949        if self_arg.is_literal:
3950            return self_arg
3951        if is_unbound_method:
3952            self_arg = self_arg.as_none_safe_node(
3953                "descriptor '%s' requires a '%s' object but received a 'NoneType'",
3954                format_args=[attr_name, self_arg.type.name])
3955        else:
3956            self_arg = self_arg.as_none_safe_node(
3957                "'NoneType' object has no attribute '%{0}s'".format('.30' if len(attr_name) <= 30 else ''),
3958                error="PyExc_AttributeError",
3959                format_args=[attr_name])
3960        return self_arg
3961
3962    def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
3963        assert len(args) >= arg_index
3964        if len(args) == arg_index:
3965            args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
3966                                          type=type, constant_result=default_value))
3967        else:
3968            args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
3969
3970    def _inject_bint_default_argument(self, node, args, arg_index, default_value):
3971        assert len(args) >= arg_index
3972        if len(args) == arg_index:
3973            default_value = bool(default_value)
3974            args.append(ExprNodes.BoolNode(node.pos, value=default_value,
3975                                           constant_result=default_value))
3976        else:
3977            args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
3978
3979
# Utility code for the startswith()/endswith() replacements, loaded (cached)
# from StringTools.c: one entry each for unicode, bytes and str.
unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c')
bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c')
str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c')
3983
3984
3985class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
3986    """Calculate the result of constant expressions to store it in
3987    ``expr_node.constant_result``, and replace trivial cases by their
3988    constant result.
3989
3990    General rules:
3991
3992    - We calculate float constants to make them available to the
3993      compiler, but we do not aggregate them into a single literal
3994      node to prevent any loss of precision.
3995
3996    - We recursively calculate constants from non-literal nodes to
3997      make them available to the compiler, but we only aggregate
3998      literal nodes at each step.  Non-literal nodes are never merged
3999      into a single node.
4000    """
4001
    def __init__(self, reevaluate=False):
        """
        The reevaluate argument specifies whether constant values that were
        previously computed should be recomputed.

        With the default (False), nodes whose ``constant_result`` has
        already been set are left alone by ``_calculate_const()``.
        """
        super(ConstantFolding, self).__init__()
        self.reevaluate = reevaluate
4009
4010    def _calculate_const(self, node):
4011        if (not self.reevaluate and
4012                node.constant_result is not ExprNodes.constant_value_not_set):
4013            return
4014
4015        # make sure we always set the value
4016        not_a_constant = ExprNodes.not_a_constant
4017        node.constant_result = not_a_constant
4018
4019        # check if all children are constant
4020        children = self.visitchildren(node)
4021        for child_result in children.values():
4022            if type(child_result) is list:
4023                for child in child_result:
4024                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
4025                        return
4026            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
4027                return
4028
4029        # now try to calculate the real constant value
4030        try:
4031            node.calculate_constant_result()
4032#            if node.constant_result is not ExprNodes.not_a_constant:
4033#                print node.__class__.__name__, node.constant_result
4034        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
4035            # ignore all 'normal' errors here => no constant result
4036            pass
4037        except Exception:
4038            # this looks like a real error
4039            import traceback, sys
4040            traceback.print_exc(file=sys.stdout)
4041
    # Literal node classes in the order used by _widest_node_class(), which
    # selects the class with the highest index in this list.
    NODE_TYPE_ORDER = [ExprNodes.BoolNode, ExprNodes.CharNode,
                       ExprNodes.IntNode, ExprNodes.FloatNode]
4044
4045    def _widest_node_class(self, *nodes):
4046        try:
4047            return self.NODE_TYPE_ORDER[
4048                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
4049        except ValueError:
4050            return None
4051
4052    def _bool_node(self, node, value):
4053        value = bool(value)
4054        return ExprNodes.BoolNode(node.pos, value=value, constant_result=value)
4055
    def visit_ExprNode(self, node):
        """Calculate the constant result of ``node`` and return the node itself
        (no replacement is made here).
        """
        self._calculate_const(node)
        return node
4059
4060    def visit_UnopNode(self, node):
4061        self._calculate_const(node)
4062        if not node.has_constant_result():
4063            if node.operator == '!':
4064                return self._handle_NotNode(node)
4065            return node
4066        if not node.operand.is_literal:
4067            return node
4068        if node.operator == '!':
4069            return self._bool_node(node, node.constant_result)
4070        elif isinstance(node.operand, ExprNodes.BoolNode):
4071            return ExprNodes.IntNode(node.pos, value=str(int(node.constant_result)),
4072                                     type=PyrexTypes.c_int_type,
4073                                     constant_result=int(node.constant_result))
4074        elif node.operator == '+':
4075            return self._handle_UnaryPlusNode(node)
4076        elif node.operator == '-':
4077            return self._handle_UnaryMinusNode(node)
4078        return node
4079
    # Lookup function (the dict's bound .get) that maps a comparison operator
    # to its negated counterpart; returns None for operators without an entry.
    _negate_operator = {
        'in': 'not_in',
        'not_in': 'in',
        'is': 'is_not',
        'is_not': 'is'
    }.get
4086
4087    def _handle_NotNode(self, node):
4088        operand = node.operand
4089        if isinstance(operand, ExprNodes.PrimaryCmpNode):
4090            operator = self._negate_operator(operand.operator)
4091            if operator:
4092                node = copy.copy(operand)
4093                node.operator = operator
4094                node = self.visit_PrimaryCmpNode(node)
4095        return node
4096
4097    def _handle_UnaryMinusNode(self, node):
4098        def _negate(value):
4099            if value.startswith('-'):
4100                value = value[1:]
4101            else:
4102                value = '-' + value
4103            return value
4104
4105        node_type = node.operand.type
4106        if isinstance(node.operand, ExprNodes.FloatNode):
4107            # this is a safe operation
4108            return ExprNodes.FloatNode(node.pos, value=_negate(node.operand.value),
4109                                       type=node_type,
4110                                       constant_result=node.constant_result)
4111        if node_type.is_int and node_type.signed or \
4112                isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject:
4113            return ExprNodes.IntNode(node.pos, value=_negate(node.operand.value),
4114                                     type=node_type,
4115                                     longness=node.operand.longness,
4116                                     constant_result=node.constant_result)
4117        return node
4118
4119    def _handle_UnaryPlusNode(self, node):
4120        if (node.operand.has_constant_result() and
4121                    node.constant_result == node.operand.constant_result):
4122            return node.operand
4123        return node
4124
4125    def visit_BoolBinopNode(self, node):
4126        self._calculate_const(node)
4127        if not node.operand1.has_constant_result():
4128            return node
4129        if node.operand1.constant_result:
4130            if node.operator == 'and':
4131                return node.operand2
4132            else:
4133                return node.operand1
4134        else:
4135            if node.operator == 'and':
4136                return node.operand1
4137            else:
4138                return node.operand2
4139
    def visit_BinopNode(self, node):
        """Replace a binary operation on two literal operands by a literal node.

        The node is returned unchanged when it has no constant result, when
        the constant result is a Python float, when either operand is not a
        literal or has no type, or when the operand classes are not both
        listed in NODE_TYPE_ORDER.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            # float results are never replaced by a new literal here
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type

        target_class = self._widest_node_class(operand1, operand2)
        if target_class is None:
            return node
        elif target_class is ExprNodes.BoolNode and node.operator in '+-//<<%**>>':
            # C arithmetic results in at least an int type
            target_class = ExprNodes.IntNode
        elif target_class is ExprNodes.CharNode and node.operator in '+-//<<%**>>&|^':
            # C arithmetic results in at least an int type
            target_class = ExprNodes.IntNode

        if target_class is ExprNodes.IntNode:
            # the result is only marked unsigned if both operands are
            unsigned = getattr(operand1, 'unsigned', '') and \
                       getattr(operand2, 'unsigned', '')
            # use the longer of the two operands' longness suffixes ('', 'L' or 'LL')
            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                                 len(getattr(operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned=unsigned, longness=longness,
                                         value=str(int(node.constant_result)),
                                         constant_result=int(node.constant_result))
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            if target_class is ExprNodes.BoolNode:
                # BoolNode takes the value itself, the other classes its string form
                node_value = node.constant_result
            else:
                node_value = str(node.constant_result)
            new_node = target_class(pos=node.pos, type = widest_type,
                                    value = node_value,
                                    constant_result = node.constant_result)
        return new_node
4197
4198    def visit_AddNode(self, node):
4199        self._calculate_const(node)
4200        if node.constant_result is ExprNodes.not_a_constant:
4201            return node
4202        if node.operand1.is_string_literal and node.operand2.is_string_literal:
4203            # some people combine string literals with a '+'
4204            str1, str2 = node.operand1, node.operand2
4205            if isinstance(str1, ExprNodes.UnicodeNode) and isinstance(str2, ExprNodes.UnicodeNode):
4206                bytes_value = None
4207                if str1.bytes_value is not None and str2.bytes_value is not None:
4208                    if str1.bytes_value.encoding == str2.bytes_value.encoding:
4209                        bytes_value = bytes_literal(
4210                            str1.bytes_value + str2.bytes_value,
4211                            str1.bytes_value.encoding)
4212                string_value = EncodedString(node.constant_result)
4213                return ExprNodes.UnicodeNode(
4214                    str1.pos, value=string_value, constant_result=node.constant_result, bytes_value=bytes_value)
4215            elif isinstance(str1, ExprNodes.BytesNode) and isinstance(str2, ExprNodes.BytesNode):
4216                if str1.value.encoding == str2.value.encoding:
4217                    bytes_value = bytes_literal(node.constant_result, str1.value.encoding)
4218                    return ExprNodes.BytesNode(str1.pos, value=bytes_value, constant_result=node.constant_result)
4219            # all other combinations are rather complicated
4220            # to get right in Py2/3: encodings, unicode escapes, ...
4221        return self.visit_BinopNode(node)
4222
4223    def visit_MulNode(self, node):
4224        self._calculate_const(node)
4225        if node.operand1.is_sequence_constructor:
4226            return self._calculate_constant_seq(node, node.operand1, node.operand2)
4227        if isinstance(node.operand1, ExprNodes.IntNode) and \
4228                node.operand2.is_sequence_constructor:
4229            return self._calculate_constant_seq(node, node.operand2, node.operand1)
4230        if node.operand1.is_string_literal:
4231            return self._multiply_string(node, node.operand1, node.operand2)
4232        elif node.operand2.is_string_literal:
4233            return self._multiply_string(node, node.operand2, node.operand1)
4234        return self.visit_BinopNode(node)
4235
    def _multiply_string(self, node, string_node, multiplier_node):
        """Fold 'string literal * constant int' into the string literal node.

        'node' is the multiplication node and is returned unchanged when the
        multiplier is not a constant integer, when the node has no constant
        string result, or when that result is longer than 256 items.
        Otherwise, 'string_node' is updated IN PLACE (value, constant result
        and its alternative unicode/bytes representation, if any) and returned.
        """
        multiplier = multiplier_node.constant_result
        if not isinstance(multiplier, _py_int_types):
            return node
        if not (node.has_constant_result() and isinstance(node.constant_result, _py_string_types)):
            return node
        if len(node.constant_result) > 256:
            # Too long for static creation, leave it to runtime.  (-> arbitrary limit)
            return node

        # select the constructor for the new value and also multiply
        # the secondary representation that some node classes carry
        build_string = encoded_string
        if isinstance(string_node, ExprNodes.BytesNode):
            build_string = bytes_literal
        elif isinstance(string_node, ExprNodes.StringNode):
            if string_node.unicode_value is not None:
                string_node.unicode_value = encoded_string(
                    string_node.unicode_value * multiplier,
                    string_node.unicode_value.encoding)
            build_string = encoded_string if string_node.value.is_unicode else bytes_literal
        elif isinstance(string_node, ExprNodes.UnicodeNode):
            if string_node.bytes_value is not None:
                string_node.bytes_value = bytes_literal(
                    string_node.bytes_value * multiplier,
                    string_node.bytes_value.encoding)
        else:
            assert False, "unknown string node type: %s" % type(string_node)
        string_node.value = build_string(
            string_node.value * multiplier,
            string_node.value.encoding)
        # follow constant-folding and use unicode_value in preference
        if isinstance(string_node, ExprNodes.StringNode) and string_node.unicode_value is not None:
            string_node.constant_result = string_node.unicode_value
        else:
            string_node.constant_result = string_node.value
        return string_node
4271
4272    def _calculate_constant_seq(self, node, sequence_node, factor):
4273        if factor.constant_result != 1 and sequence_node.args:
4274            if isinstance(factor.constant_result, _py_int_types) and factor.constant_result <= 0:
4275                del sequence_node.args[:]
4276                sequence_node.mult_factor = None
4277            elif sequence_node.mult_factor is not None:
4278                if (isinstance(factor.constant_result, _py_int_types) and
4279                        isinstance(sequence_node.mult_factor.constant_result, _py_int_types)):
4280                    value = sequence_node.mult_factor.constant_result * factor.constant_result
4281                    sequence_node.mult_factor = ExprNodes.IntNode(
4282                        sequence_node.mult_factor.pos,
4283                        value=str(value), constant_result=value)
4284                else:
4285                    # don't know if we can combine the factors, so don't
4286                    return self.visit_BinopNode(node)
4287            else:
4288                sequence_node.mult_factor = factor
4289        return sequence_node
4290
4291    def visit_ModNode(self, node):
4292        self.visitchildren(node)
4293        if isinstance(node.operand1, ExprNodes.UnicodeNode) and isinstance(node.operand2, ExprNodes.TupleNode):
4294            if not node.operand2.mult_factor:
4295                fstring = self._build_fstring(node.operand1.pos, node.operand1.value, node.operand2.args)
4296                if fstring is not None:
4297                    return fstring
4298        return self.visit_BinopNode(node)
4299
    # Pattern that _build_fstring() passes to re.split().  Since the whole
    # pattern is one capturing group, each matched '%...' placeholder is kept
    # as a separate item between the literal text parts of the split result.
    _parse_string_format_regex = (
        u'(%(?:'              # %...
        u'(?:[-0-9]+|[ ])?'   # width (optional) or space prefix fill character (optional)
        u'(?:[.][0-9]+)?'     # precision (optional)
        u')?.)'               # format type (or something different for unsupported formats)
    )
4306
    def _build_fstring(self, pos, ustring, format_args):
        """Convert a '%'-formatting template and its arguments into a JoinedStrNode.

        'ustring' is the format template and 'format_args' the list of argument
        nodes.  Literal text parts become UnicodeNodes and the supported
        placeholders ('%a', '%s', '%r', '%f', '%d', '%o', '%x', '%X') become
        FormattedValueNodes.

        Returns the result of visit_JoinedStrNode() for the new node, or None
        if the template cannot be converted (unsupported or incomplete format,
        starred argument, argument count mismatch).
        """
        # Issues formatting warnings instead of errors since we really only catch a few errors by accident.
        args = iter(format_args)
        substrings = []
        can_be_optimised = True
        for s in re.split(self._parse_string_format_regex, ustring):
            if not s:
                continue
            if s == u'%%':
                # escaped percent sign
                substrings.append(ExprNodes.UnicodeNode(pos, value=EncodedString(u'%'), constant_result=u'%'))
                continue
            if s[0] != u'%':
                # literal text between the placeholders
                if s[-1] == u'%':
                    warning(pos, "Incomplete format: '...%s'" % s[-3:], level=1)
                    can_be_optimised = False
                substrings.append(ExprNodes.UnicodeNode(pos, value=EncodedString(s), constant_result=s))
                continue
            # a placeholder: its last character is the format type
            format_type = s[-1]
            try:
                arg = next(args)
            except StopIteration:
                warning(pos, "Too few arguments for format placeholders", level=1)
                can_be_optimised = False
                break
            if arg.is_starred:
                can_be_optimised = False
                break
            if format_type in u'asrfdoxX':
                # everything after the '%' serves as the format spec
                format_spec = s[1:]
                conversion_char = None
                if format_type in u'doxX' and u'.' in format_spec:
                    # Precision is not allowed for integers in format(), but ok in %-formatting.
                    can_be_optimised = False
                elif format_type in u'ars':
                    # these are expressed as a conversion character, not as part of the spec
                    format_spec = format_spec[:-1]
                    conversion_char = format_type
                    if format_spec.startswith('0'):
                        format_spec = '>' + format_spec[1:]  # right-alignment '%05s' spells '{:>5}'
                elif format_type == u'd':
                    # '%d' formatting supports float, but '{obj:d}' does not => convert to int first.
                    conversion_char = 'd'

                if format_spec.startswith('-'):
                    format_spec = '<' + format_spec[1:]  # left-alignment '%-5s' spells '{:<5}'

                substrings.append(ExprNodes.FormattedValueNode(
                    arg.pos, value=arg,
                    conversion_char=conversion_char,
                    format_spec=ExprNodes.UnicodeNode(
                        pos, value=EncodedString(format_spec), constant_result=format_spec)
                        if format_spec else None,
                ))
            else:
                # keep it simple for now ...
                can_be_optimised = False
                break

        if not can_be_optimised:
            # Print all warnings we can find before finally giving up here.
            return None

        # all placeholders are consumed: any argument left over is one too many
        try:
            next(args)
        except StopIteration: pass
        else:
            warning(pos, "Too many arguments for format placeholders", level=1)
            return None

        node = ExprNodes.JoinedStrNode(pos, values=substrings)
        return self.visit_JoinedStrNode(node)
4377
4378    def visit_FormattedValueNode(self, node):
4379        self.visitchildren(node)
4380        conversion_char = node.conversion_char or 's'
4381        if isinstance(node.format_spec, ExprNodes.UnicodeNode) and not node.format_spec.value:
4382            node.format_spec = None
4383        if node.format_spec is None and isinstance(node.value, ExprNodes.IntNode):
4384            value = EncodedString(node.value.value)
4385            if value.isdigit():
4386                return ExprNodes.UnicodeNode(node.value.pos, value=value, constant_result=value)
4387        if node.format_spec is None and conversion_char == 's':
4388            value = None
4389            if isinstance(node.value, ExprNodes.UnicodeNode):
4390                value = node.value.value
4391            elif isinstance(node.value, ExprNodes.StringNode):
4392                value = node.value.unicode_value
4393            if value is not None:
4394                return ExprNodes.UnicodeNode(node.value.pos, value=value, constant_result=value)
4395        return node
4396
4397    def visit_JoinedStrNode(self, node):
4398        """
4399        Clean up after the parser by discarding empty Unicode strings and merging
4400        substring sequences.  Empty or single-value join lists are not uncommon
4401        because f-string format specs are always parsed into JoinedStrNodes.
4402        """
4403        self.visitchildren(node)
4404        unicode_node = ExprNodes.UnicodeNode
4405
4406        values = []
4407        for is_unode_group, substrings in itertools.groupby(node.values, lambda v: isinstance(v, unicode_node)):
4408            if is_unode_group:
4409                substrings = list(substrings)
4410                unode = substrings[0]
4411                if len(substrings) > 1:
4412                    value = EncodedString(u''.join(value.value for value in substrings))
4413                    unode = ExprNodes.UnicodeNode(unode.pos, value=value, constant_result=value)
4414                # ignore empty Unicode strings
4415                if unode.value:
4416                    values.append(unode)
4417            else:
4418                values.extend(substrings)
4419
4420        if not values:
4421            value = EncodedString('')
4422            node = ExprNodes.UnicodeNode(node.pos, value=value, constant_result=value)
4423        elif len(values) == 1:
4424            node = values[0]
4425        elif len(values) == 2:
4426            # reduce to string concatenation
4427            node = ExprNodes.binop_node(node.pos, '+', *values)
4428        else:
4429            node.values = values
4430        return node
4431
4432    def visit_MergedDictNode(self, node):
4433        """Unpack **args in place if we can."""
4434        self.visitchildren(node)
4435        args = []
4436        items = []
4437
4438        def add(arg):
4439            if arg.is_dict_literal:
4440                if items:
4441                    items[0].key_value_pairs.extend(arg.key_value_pairs)
4442                else:
4443                    items.append(arg)
4444            elif isinstance(arg, ExprNodes.MergedDictNode):
4445                for child_arg in arg.keyword_args:
4446                    add(child_arg)
4447            else:
4448                if items:
4449                    args.append(items[0])
4450                    del items[:]
4451                args.append(arg)
4452
4453        for arg in node.keyword_args:
4454            add(arg)
4455        if items:
4456            args.append(items[0])
4457
4458        if len(args) == 1:
4459            arg = args[0]
4460            if arg.is_dict_literal or isinstance(arg, ExprNodes.MergedDictNode):
4461                return arg
4462        node.keyword_args[:] = args
4463        self._calculate_const(node)
4464        return node
4465
4466    def visit_MergedSequenceNode(self, node):
4467        """Unpack *args in place if we can."""
4468        self.visitchildren(node)
4469
4470        is_set = node.type is Builtin.set_type
4471        args = []
4472        values = []
4473
4474        def add(arg):
4475            if (is_set and arg.is_set_literal) or (arg.is_sequence_constructor and not arg.mult_factor):
4476                if values:
4477                    values[0].args.extend(arg.args)
4478                else:
4479                    values.append(arg)
4480            elif isinstance(arg, ExprNodes.MergedSequenceNode):
4481                for child_arg in arg.args:
4482                    add(child_arg)
4483            else:
4484                if values:
4485                    args.append(values[0])
4486                    del values[:]
4487                args.append(arg)
4488
4489        for arg in node.args:
4490            add(arg)
4491        if values:
4492            args.append(values[0])
4493
4494        if len(args) == 1:
4495            arg = args[0]
4496            if ((is_set and arg.is_set_literal) or
4497                    (arg.is_sequence_constructor and arg.type is node.type) or
4498                    isinstance(arg, ExprNodes.MergedSequenceNode)):
4499                return arg
4500        node.args[:] = args
4501        self._calculate_const(node)
4502        return node
4503
4504    def visit_SequenceNode(self, node):
4505        """Unpack *args in place if we can."""
4506        self.visitchildren(node)
4507        args = []
4508        for arg in node.args:
4509            if not arg.is_starred:
4510                args.append(arg)
4511            elif arg.target.is_sequence_constructor and not arg.target.mult_factor:
4512                args.extend(arg.target.args)
4513            else:
4514                args.append(arg)
4515        node.args[:] = args
4516        self._calculate_const(node)
4517        return node
4518
    def visit_PrimaryCmpNode(self, node):
        """Fold the constant parts of a (possibly cascaded) comparison.

        A simple comparison with a constant result becomes a BoolNode.  In a
        cascade, comparisons that are constantly true are removed, a constantly
        false comparison cuts off everything behind it, and the remaining
        non-constant parts are regrouped into separate comparisons that are
        joined by 'and'.
        """
        # calculate constant partial results in the comparison cascade
        self.visitchildren(node, ['operand1'])
        left_node = node.operand1
        cmp_node = node
        while cmp_node is not None:
            self.visitchildren(cmp_node, ['operand2'])
            right_node = cmp_node.operand2
            # reset first, so that a failing calculation leaves no stale result behind
            cmp_node.constant_result = not_a_constant
            if left_node.has_constant_result() and right_node.has_constant_result():
                try:
                    cmp_node.calculate_cascaded_constant_result(left_node.constant_result)
                except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
                    pass  # ignore all 'normal' errors here => no constant result
            # the right operand is the left operand of the next comparison in the cascade
            left_node = right_node
            cmp_node = cmp_node.cascade

        if not node.cascade:
            if node.has_constant_result():
                return self._bool_node(node, node.constant_result)
            return node

        # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
        cascades = [[node.operand1]]
        final_false_result = []

        def split_cascades(cmp_node):
            if cmp_node.has_constant_result():
                if not cmp_node.constant_result:
                    # False => short-circuit
                    final_false_result.append(self._bool_node(cmp_node, False))
                    return
                else:
                    # True => discard and start new cascade
                    cascades.append([cmp_node.operand2])
            else:
                # not constant => append to current cascade
                cascades[-1].append(cmp_node)
            if cmp_node.cascade:
                split_cascades(cmp_node.cascade)

        split_cascades(node)

        # rebuild one comparison node per partial cascade
        cmp_nodes = []
        for cascade in cascades:
            if len(cascade) < 2:
                # only a start value without any comparison => nothing to evaluate
                continue
            cmp_node = cascade[1]
            pcmp_node = ExprNodes.PrimaryCmpNode(
                cmp_node.pos,
                operand1=cascade[0],
                operator=cmp_node.operator,
                operand2=cmp_node.operand2,
                constant_result=not_a_constant)
            cmp_nodes.append(pcmp_node)

            # chain the remaining comparisons of this partial cascade behind the new head
            last_cmp_node = pcmp_node
            for cmp_node in cascade[2:]:
                last_cmp_node.cascade = cmp_node
                last_cmp_node = cmp_node
            last_cmp_node.cascade = None

        if final_false_result:
            # last cascade was constant False
            cmp_nodes.append(final_false_result[0])
        elif not cmp_nodes:
            # only constants, but no False result
            return self._bool_node(node, True)
        node = cmp_nodes[0]
        if len(cmp_nodes) == 1:
            if node.has_constant_result():
                return self._bool_node(node, node.constant_result)
        else:
            # combine the partial results from left to right with 'and'
            for cmp_node in cmp_nodes[1:]:
                node = ExprNodes.BoolBinopNode(
                    node.pos,
                    operand1=node,
                    operator='and',
                    operand2=cmp_node,
                    constant_result=not_a_constant)
        return node
4600
4601    def visit_CondExprNode(self, node):
4602        self._calculate_const(node)
4603        if not node.test.has_constant_result():
4604            return node
4605        if node.test.constant_result:
4606            return node.true_val
4607        else:
4608            return node.false_val
4609
4610    def visit_IfStatNode(self, node):
4611        self.visitchildren(node)
4612        # eliminate dead code based on constant condition results
4613        if_clauses = []
4614        for if_clause in node.if_clauses:
4615            condition = if_clause.condition
4616            if condition.has_constant_result():
4617                if condition.constant_result:
4618                    # always true => subsequent clauses can safely be dropped
4619                    node.else_clause = if_clause.body
4620                    break
4621                # else: false => drop clause
4622            else:
4623                # unknown result => normal runtime evaluation
4624                if_clauses.append(if_clause)
4625        if if_clauses:
4626            node.if_clauses = if_clauses
4627            return node
4628        elif node.else_clause:
4629            return node.else_clause
4630        else:
4631            return Nodes.StatListNode(node.pos, stats=[])
4632
4633    def visit_SliceIndexNode(self, node):
4634        self._calculate_const(node)
4635        # normalise start/stop values
4636        if node.start is None or node.start.constant_result is None:
4637            start = node.start = None
4638        else:
4639            start = node.start.constant_result
4640        if node.stop is None or node.stop.constant_result is None:
4641            stop = node.stop = None
4642        else:
4643            stop = node.stop.constant_result
4644        # cut down sliced constant sequences
4645        if node.constant_result is not not_a_constant:
4646            base = node.base
4647            if base.is_sequence_constructor and base.mult_factor is None:
4648                base.args = base.args[start:stop]
4649                return base
4650            elif base.is_string_literal:
4651                base = base.as_sliced_node(start, stop)
4652                if base is not None:
4653                    return base
4654        return node
4655
4656    def visit_ComprehensionNode(self, node):
4657        self.visitchildren(node)
4658        if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats:
4659            # loop was pruned already => transform into literal
4660            if node.type is Builtin.list_type:
4661                return ExprNodes.ListNode(
4662                    node.pos, args=[], constant_result=[])
4663            elif node.type is Builtin.set_type:
4664                return ExprNodes.SetNode(
4665                    node.pos, args=[], constant_result=set())
4666            elif node.type is Builtin.dict_type:
4667                return ExprNodes.DictNode(
4668                    node.pos, key_value_pairs=[], constant_result={})
4669        return node
4670
4671    def visit_ForInStatNode(self, node):
4672        self.visitchildren(node)
4673        sequence = node.iterator.sequence
4674        if isinstance(sequence, ExprNodes.SequenceNode):
4675            if not sequence.args:
4676                if node.else_clause:
4677                    return node.else_clause
4678                else:
4679                    # don't break list comprehensions
4680                    return Nodes.StatListNode(node.pos, stats=[])
4681            # iterating over a list literal? => tuples are more efficient
4682            if isinstance(sequence, ExprNodes.ListNode):
4683                node.iterator.sequence = sequence.as_tuple()
4684        return node
4685
4686    def visit_WhileStatNode(self, node):
4687        self.visitchildren(node)
4688        if node.condition and node.condition.has_constant_result():
4689            if node.condition.constant_result:
4690                node.condition = None
4691                node.else_clause = None
4692            else:
4693                return node.else_clause
4694        return node
4695
4696    def visit_ExprStatNode(self, node):
4697        self.visitchildren(node)
4698        if not isinstance(node.expr, ExprNodes.ExprNode):
4699            # ParallelRangeTransform does this ...
4700            return node
4701        # drop unused constant expressions
4702        if node.expr.has_constant_result():
4703            return None
4704        return node
4705
    # In the future, other nodes can have their own handler method here
    # that can replace them with a constant result node.

    # Handler for all remaining node types: reuses
    # VisitorTransform.recurse_to_children instead of a dedicated method.
    visit_Node = Visitor.VisitorTransform.recurse_to_children
4710
4711
4712class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin):
4713    """
4714    This visitor handles several commuting optimizations, and is run
4715    just before the C code generation phase.
4716
4717    The optimizations currently implemented in this class are:
4718        - eliminate None assignment and refcounting for first assignment.
4719        - isinstance -> typecheck for cdef types
4720        - eliminate checks for None and/or types that became redundant after tree changes
4721        - eliminate useless string formatting steps
4722        - replace Python function calls that look like method calls by a faster PyMethodCallNode
4723    """
4724    in_loop = False
4725
4726    def visit_SingleAssignmentNode(self, node):
4727        """Avoid redundant initialisation of local variables before their
4728        first assignment.
4729        """
4730        self.visitchildren(node)
4731        if node.first:
4732            lhs = node.lhs
4733            lhs.lhs_of_first_assignment = True
4734        return node
4735
    def visit_SimpleCallNode(self, node):
        """
        Replace generic calls to isinstance(x, type) by a more efficient type check.
        Replace likely Python method calls by a specialised PyMethodCallNode.
        """
        self.visitchildren(node)
        function = node.function
        if function.type.is_cfunction and function.is_name:
            if function.name == 'isinstance' and len(node.args) == 2:
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    # redirect the call to PyObject_TypeCheck() and cast
                    # the type argument to the pointer type it expects
                    cython_scope = self.context.cython_scope
                    function.entry = cython_scope.lookup('PyObject_TypeCheck')
                    function.type = function.entry.type
                    PyTypeObjectPtr = PyrexTypes.CPtrType(cython_scope.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        # The directive that enables the method call optimisation is
        # '..._in_pyinit' for module level code outside of loops and
        # the general one everywhere else.
        elif (node.is_temp and function.type.is_pyobject and self.current_directives.get(
                "optimize.unpack_method_calls_in_pyinit"
                if not self.in_loop and self.current_env().is_module_scope
                else "optimize.unpack_method_calls")):
            # optimise simple Python methods calls
            if isinstance(node.arg_tuple, ExprNodes.TupleNode) and not (
                    node.arg_tuple.mult_factor or (node.arg_tuple.is_literal and len(node.arg_tuple.args) > 1)):
                # simple call, now exclude calls to objects that are definitely not methods
                may_be_a_method = True
                if function.type is Builtin.type_type:
                    may_be_a_method = False
                elif function.is_attribute:
                    if function.entry and function.entry.type.is_cfunction:
                        # optimised builtin method
                        may_be_a_method = False
                elif function.is_name:
                    entry = function.entry
                    if entry.is_builtin or entry.type.is_cfunction:
                        may_be_a_method = False
                    elif entry.cf_assignments:
                        # local functions/classes are definitely not methods
                        non_method_nodes = (ExprNodes.PyCFunctionNode, ExprNodes.ClassNode, ExprNodes.Py3ClassNode)
                        # a single assignment of anything else keeps the method call candidate alive
                        may_be_a_method = any(
                            assignment.rhs and not isinstance(assignment.rhs, non_method_nodes)
                            for assignment in entry.cf_assignments)
                if may_be_a_method:
                    if (node.self and function.is_attribute and
                            isinstance(function.obj, ExprNodes.CloneNode) and function.obj.arg is node.self):
                        # function self object was moved into a CloneNode => undo
                        function.obj = function.obj.arg
                    node = self.replace(node, ExprNodes.PyMethodCallNode.from_node(
                        node, function=function, arg_tuple=node.arg_tuple, type=node.type))
        return node
4785
4786    def visit_NumPyMethodCallNode(self, node):
4787        # Exclude from replacement above.
4788        self.visitchildren(node)
4789        return node
4790
4791    def visit_PyTypeTestNode(self, node):
4792        """Remove tests for alternatively allowed None values from
4793        type tests when we know that the argument cannot be None
4794        anyway.
4795        """
4796        self.visitchildren(node)
4797        if not node.notnone:
4798            if not node.arg.may_be_none():
4799                node.notnone = True
4800        return node
4801
4802    def visit_NoneCheckNode(self, node):
4803        """Remove None checks from expressions that definitely do not
4804        carry a None value.
4805        """
4806        self.visitchildren(node)
4807        if not node.arg.may_be_none():
4808            return node.arg
4809        return node
4810
4811    def visit_LoopNode(self, node):
4812        """Remember when we enter a loop as some expensive optimisations might still be worth it there.
4813        """
4814        old_val = self.in_loop
4815        self.in_loop = True
4816        self.visitchildren(node)
4817        self.in_loop = old_val
4818        return node
4819
4820
class ConsolidateOverflowCheck(Visitor.CythonTransform):
    """
    Lets all nodes of a nested arithmetic expression share one overflow
    check.  For example, given the expression a*b + c, where a, b, and c
    are all possibly overflowing ints, the entire sequence will be evaluated
    and the overflow bit checked only at the end.
    """
    # Outermost foldable binop of the expression currently being visited,
    # or None while not inside such an expression.
    overflow_bit_node = None

    def visit_Node(self, node):
        """Visit a generic node's children outside of any shared overflow check.

        A pending owner of the overflow bit is suspended for the duration of
        the child visit and reinstated afterwards.
        """
        suspended = self.overflow_bit_node
        self.overflow_bit_node = None
        self.visitchildren(node)
        self.overflow_bit_node = suspended
        return node

    def visit_NumBinopNode(self, node):
        """Link foldable, overflow-checked binops to the outermost one.

        The outermost such node becomes the owner of the overflow bit while
        its children are visited.  Nested ones get a reference to the owner
        and have their own overflow check switched off.
        """
        if not (node.overflow_check and node.overflow_fold):
            # Nothing to consolidate here, but keep any current owner for the children.
            self.visitchildren(node)
            return node

        owner = self.overflow_bit_node
        if owner is None:
            # Top of a foldable expression: this node does the final check.
            self.overflow_bit_node = node
            self.visitchildren(node)
            self.overflow_bit_node = None
        else:
            # Nested operation: defer the check to the owning node.
            node.overflow_bit_node = owner
            node.overflow_check = False
            self.visitchildren(node)
        return node
4854