1import cytoolz
2import cytoolz.curried
3from cytoolz.curried import (take, first, second, sorted, merge_with, reduce,
4                           merge, operator as cop)
5from collections import defaultdict
6from importlib import import_module
7from operator import add
8
9
def test_take():
    """Curried ``take`` accepts the count alone and the sequence in a later call."""
    take_two = take(2)
    result = list(take_two([1, 2, 3]))
    assert result == [1, 2]
12
13
def test_first():
    """``cytoolz.curried.first`` is the very same object as ``cytoolz.itertoolz.first``."""
    expected = cytoolz.itertoolz.first
    assert first is expected
16
17
def test_merge():
    """Curried ``merge`` works with and without ``factory``, in either call shape."""
    def int_defaultdict():
        return defaultdict(int)

    # ``factory`` bound first; the dict is supplied in a second call.
    merge_into_default = merge(factory=int_defaultdict)
    assert merge_into_default({1: 1}) == {1: 1}

    # A single dict and no factory.
    assert merge({1: 1}) == {1: 1}

    # Dict and ``factory`` supplied in the same call.
    assert merge({1: 1}, factory=int_defaultdict) == {1: 1}
22
23
def test_merge_with():
    """Curried ``merge_with`` binds the combining function, then takes the dicts."""
    summing_merge = merge_with(sum)
    assert summing_merge({1: 1}, {1: 2}) == {1: 3}
26
27
def test_merge_with_list():
    """``merge_with`` also accepts its dicts packed into a single list argument."""
    dicts = [{'a': 1}, {'a': 2}]
    combined = merge_with(sum, dicts)
    assert combined == {'a': 3}
30
31
def test_sorted():
    """Curried ``sorted`` takes ``key`` alone and the iterable in a later call."""
    by_second = sorted(key=second)
    pairs = [(1, 2), (2, 1)]
    assert by_second(pairs) == [(2, 1), (1, 2)]
34
35
def test_reduce():
    """Curried ``reduce`` binds the binary function, then folds the given sequence."""
    summing = reduce(add)
    assert summing((1, 2, 3)) == 6
38
39
def test_module_name():
    """The curried namespace reports its own dotted module name."""
    module_name = cytoolz.curried.__name__
    assert module_name == 'cytoolz.curried'
42
43
def test_curried_operator():
    """Every callable in ``cytoolz.curried.operator`` is curried or unary.

    A callable that is not a ``cytoolz.curry`` instance passes if it can be
    called with the single argument ``1`` or, failing that, ``'x'`` without
    raising ``TypeError``. Any other exception from such a call propagates.
    The test also checks that ``add``, ``sub`` and ``mul`` are present, so an
    empty module cannot pass vacuously.
    """
    def callable_with_one_arg(func):
        # Probe with an int first, then a string; only TypeError means "no".
        for probe in (1, 'x'):
            try:
                func(probe)
            except TypeError:
                continue
            return True
        return False

    for name, obj in vars(cop).items():
        if not callable(obj) or isinstance(obj, cytoolz.curry):
            continue
        if not callable_with_one_arg(obj):
            raise AssertionError(
                'cytoolz.curried.operator.%s is not curried!' % name,
            )

    # Guard against the module being (nearly) empty.
    assert {'add', 'sub', 'mul'} <= set(vars(cop))
66
67
def test_curried_namespace():
    """Check that ``cytoolz.curried`` exposes exactly the expected callables.

    The expected namespace is rebuilt from ``vars(cytoolz)`` and from
    ``cytoolz.curried.exceptions``:
    - every entry whose name contains no double underscore is wrapped in
      ``cytoolz.curry`` when ``should_curry`` says so;
    - entries from the exceptions module take precedence over same-named
      entries from ``cytoolz``.

    Only callables are compared on both sides.

    Raises:
        AssertionError: listing missing and/or extra names, or, when the
            names agree, each name bound to an unexpected object.
    """
    exceptions = import_module('cytoolz.curried.exceptions')

    def should_curry(func):
        # Non-callables and already-curried objects are left untouched.
        if not callable(func) or isinstance(func, cytoolz.curry):
            return False
        nargs = cytoolz.functoolz.num_required_args(func)
        # Unknown arity, or more than one required argument: curry it.
        if nargs is None or nargs > 1:
            return True
        # A one-argument function is curried only if it also accepts keywords.
        return nargs == 1 and cytoolz.functoolz.has_keywords(func)

    def curry_namespace(ns):
        # Skip any name containing a double underscore (e.g. ``__name__``).
        return {
            name: cytoolz.curry(f) if should_curry(f) else f
            for name, f in ns.items() if '__' not in name
        }

    def indented_names(names):
        # One sorted name per line, each indented under the message header.
        return '\n    '.join(sorted(names))

    from_cytoolz = curry_namespace(vars(cytoolz))
    from_exceptions = curry_namespace(vars(exceptions))
    # ``merge`` lets later dicts win, so the exceptions module overrides cytoolz.
    namespace = cytoolz.valfilter(callable, cytoolz.merge(from_cytoolz, from_exceptions))
    curried_namespace = cytoolz.valfilter(callable, cytoolz.curried.__dict__)

    if namespace == curried_namespace:
        return

    # Report missing and extra names together instead of stopping at the first kind.
    problems = []
    missing = set(namespace) - set(curried_namespace)
    if missing:
        problems.append('There are missing functions in cytoolz.curried:\n    %s'
                        % indented_names(missing))
    extra = set(curried_namespace) - set(namespace)
    if extra:
        problems.append('There are extra functions in cytoolz.curried:\n    %s'
                        % indented_names(extra))
    if problems:
        raise AssertionError('\n'.join(problems))

    # Both sides now have the same names, so at least one value must differ.
    mismatched = [name for name in sorted(namespace)
                  if namespace[name] != curried_namespace[name]]
    messages = []
    for name in mismatched:
        if name in from_exceptions:
            messages.append('%s should come from cytoolz.curried.exceptions' % name)
        elif should_curry(getattr(cytoolz, name)):
            messages.append('%s should be curried from cytoolz' % name)
        else:
            messages.append('%s should come from cytoolz and NOT be curried' % name)
    raise AssertionError('\n'.join(messages))
114