1"""
2    Unit tests for simpleeval.
3    --------------------------
4
5    Most of this stuff is pretty basic.
6
7"""
8# pylint: disable=too-many-public-methods, missing-docstring
9import sys
10import unittest
11import operator
12import ast
13import simpleeval
14import os
15from simpleeval import (
16    SimpleEval, EvalWithCompoundTypes, FeatureNotAvailable, FunctionNotDefined, NameNotDefined,
17    InvalidExpression, AttributeDoesNotExist, simple_eval
18)
19
20
class DRYTest(unittest.TestCase):
    """ Shared fixture for the evaluator tests ("Don't Repeat Yourself").

        Each test gets a brand-new SimpleEval in ``self.s`` and the ``t``
        shorthand for "evaluate this expression and compare the result". """

    def setUp(self):
        """ build a fresh evaluator for every single test """
        self.s = SimpleEval()

    def t(self, expr, shouldbe):  # pylint: disable=invalid-name
        """ evaluate ``expr`` with ``self.s`` and assert it equals ``shouldbe`` """
        actual = self.s.eval(expr)
        self.assertEqual(actual, shouldbe)
32
33
class TestBasic(DRYTest):
    """ Simple expressions. """

    def test_maths_with_ints(self):
        """ simple maths expressions """
        cases = (
            ("21 + 21", 42),
            ("6*7", 42),
            ("20 + 1 + (10*2) + 1", 42),
            ("100/10", 10),
            ("12*12", 144),
            ("2 ** 10", 1024),
            ("100 % 9", 1),
        )
        for expression, expected in cases:
            self.t(expression, expected)

    def test_bools_and_or(self):
        """ ``and`` / ``or`` short-circuit and hand back the deciding operand """
        cases = (
            ('True and ""', ""),
            ('True and False', False),
            ('True or False', True),
            ('False or False', False),
            ('1 - 1 or 21', 21),
            ('1 - 1 and 11', 0),
            ('110 == 100 + 10 and True', True),
            ('110 != 100 + 10 and True', False),
            ('False or 42', 42),
            ('False or None', None),
            ('None or None', None),
        )
        for expression, expected in cases:
            self.t(expression, expected)

        # a compound rule mixing names, comparisons and both operators:
        self.s.names = {'out': True, 'position': 3}
        rule = ('(out and position <=6 and -10)'
                ' or (out and position > 6 and -5)'
                ' or (not out and 15)')
        self.t(rule, -10)

    def test_not(self):
        """ ``not`` on booleans and on ints """
        cases = (
            ('not False', True),
            ('not True', False),
            ('not 0', True),
            ('not 1', False),
        )
        for expression, expected in cases:
            self.t(expression, expected)

    def test_maths_with_floats(self):
        """ float arithmetic, including mixed int/float operands """
        for expression, expected in (("11.02 - 9.1", 1.92), ("29.1+39", 68.1)):
            self.t(expression, expected)

    def test_comparisons(self):
        """ ordering / equality operators, including chained comparisons """
        cases = (
            # GT & LT:
            ("1 > 0", True),
            ("100000 < 28", False),
            ("-2 < 11", True),
            ("+2 < 5", True),
            ("0 == 0", True),
            # GtE, LtE:
            ("-2 <= -2", True),
            ("2 >= 2", True),
            ("1 >= 12", False),
            ("1.09 <= 1967392", True),
            # chains must agree with native Python semantics:
            ('1 < 2 < 3 < 4', 1 < 2 < 3 < 4),
            ('1 < 2 > 3 < 4', 1 < 2 > 3 < 4),
            ('1<2<1+1', 1 < 2 < 1 + 1),
            ('1 == 1 == 2', 1 == 1 == 2),
            ('1 == 1 < 2', 1 == 1 < 2),
        )
        for expression, expected in cases:
            self.t(expression, expected)

    def test_mixed_comparisons(self):
        """ comparisons between ints, floats and bools """
        cases = (
            ("1 > 0.999999", True),
            ("1 == True", True),  # Note ==, not 'is'.
            ("0 == False", True),  # Note ==, not 'is'.
            ("False == False", True),
            ("False < True", True),
        )
        for expression, expected in cases:
            self.t(expression, expected)

    def test_if_else(self):
        """ x if y else z """
        cases = (
            ("'a' if 1 == 1 else 'b'", 'a'),
            ("'a' if 1 > 2 else 'b'", 'b'),
            # nested conditional expressions:
            ("'a' if 4 < 1 else 'b' if 1 == 2 else 'c'", 'c'),
        )
        for expression, expected in cases:
            self.t(expression, expected)

    def test_default_conversions(self):
        """ conversion between types """
        cases = (
            ('int("20") + int(0.22*100)', 42),
            ('float("42")', 42.0),
            ('"Test Stuff!" + str(11)', "Test Stuff!11"),
        )
        for expression, expected in cases:
            self.t(expression, expected)

    def test_slicing(self):
        """ indexing, slicing and substring membership on strings """
        # Python 2 provides operator.getslice; Python 3 slices through getitem.
        self.s.operators[ast.Slice] = getattr(operator, "getslice", operator.getitem)

        slices = (
            ("'hello'[1]", "e"),
            ("'hello'[:]", "hello"),
            ("'hello'[:3]", "hel"),
            ("'hello'[3:]", "lo"),
            ("'hello'[::2]", "hlo"),
            ("'hello'[::-1]", "olleh"),
            ("'hello'[3::]", "lo"),
            ("'hello'[:3:]", "hel"),
            ("'hello'[1:3]", "el"),
            ("'hello'[1:3:]", "el"),
            ("'hello'[1::2]", "el"),
            ("'hello'[:1:2]", "h"),
            ("'hello'[1:3:1]", "el"),
            ("'hello'[1:3:2]", "e"),
        )
        for expression, expected in slices:
            self.t(expression, expected)

        with self.assertRaises(IndexError):
            self.t("'hello'[90]", 0)

        memberships = (
            '"spam" not in "my breakfast"',
            '"silly" in "ministry of silly walks"',
            '"I" not in "team"',
            '"U" in "RUBBISH"',
        )
        for expression in memberships:
            self.t(expression, True)

    def test_is(self):
        """ identity checks with ``is`` and ``is not`` """
        cases = (
            ('1 is 1', True),
            ('1 is 2', False),
            ('1 is "a"', False),
            ('1 is None', False),
            ('None is None', True),

            ('1 is not 1', False),
            ('1 is not 2', True),
            ('1 is not "a"', True),
            ('1 is not None', True),
            ('None is not None', False),
        )
        for expression, expected in cases:
            self.t(expression, expected)

    def test_fstring(self):
        """ f-strings (only parseable on Python 3.6 and later) """
        if sys.version_info < (3, 6, 0):
            return

        cases = (
            ('f""', ""),
            ('f"stuff"', "stuff"),
            ('f"one is {1} and two is {2}"', "one is 1 and two is 2"),
            ('f"1+1 is {1+1}"', "1+1 is 2"),
            ('f"{\'dramatic\':!<11}"', "dramatic!!!"),
        )
        for expression, expected in cases:
            self.t(expression, expected)

    def test_set_not_allowed(self):
        """ set literals need the compound-types evaluator """
        self.assertRaises(FeatureNotAvailable, self.t, '{22}', False)
172
173
class TestFunctions(DRYTest):
    """ Functions for expressions to play with """

    def test_load_file(self):
        """ add in a function which loads data from an external file. """

        # write to the file.  The removal is registered straight away so the
        # file is cleaned up even when one of the assertions below fails
        # (previously it was only removed on the success path):
        with open("testfile.txt", 'w') as f:
            f.write("42")
        self.addCleanup(os.remove, "testfile.txt")

        # define the function we'll send to the eval'er

        def load_file(filename):
            """ load a file and return its contents """
            with open(filename) as f2:
                return f2.read()

        # simple load:

        self.s.functions = {"read": load_file}
        self.t("read('testfile.txt')", "42")

        # and we should have *replaced* the default functions. Let's check:

        with self.assertRaises(simpleeval.FunctionNotDefined):
            self.t("int(read('testfile.txt'))", 42)

        # OK, so we can load in the default functions as well...

        self.s.functions.update(simpleeval.DEFAULT_FUNCTIONS)

        # now it works:

        self.t("int(read('testfile.txt'))", 42)

    def test_randoms(self):
        """ test the rand() and randint() functions """

        i = self.s.eval('randint(1000)')
        self.assertEqual(type(i), int)
        self.assertLessEqual(i, 1000)

        f = self.s.eval('rand()')
        self.assertEqual(type(f), float)

        self.t("randint(20)<20", True)
        self.t("rand()<1.0", True)

        # I don't know how to further test these functions.  Ideas?

    def test_methods(self):
        """ method calls work, and DISALLOW_METHODS can be relaxed """
        self.t('"WORD".lower()', 'word')

        # Temporarily allow every method.  The module-level list is restored by
        # a cleanup, so a failing assertion cannot leak the relaxed setting
        # into later tests.
        self.addCleanup(setattr, simpleeval, 'DISALLOW_METHODS', simpleeval.DISALLOW_METHODS)
        simpleeval.DISALLOW_METHODS = []
        self.t('"{}:{}".format(1, 2)', '1:2')

    def test_function_args_none(self):
        """ a zero-argument function """
        def foo():
            return 42

        self.s.functions['foo'] = foo
        self.t('foo()', 42)

    def test_function_args_required(self):
        """ a required positional argument, passed positionally or by keyword """
        def foo(toret):
            return toret

        self.s.functions['foo'] = foo
        with self.assertRaises(TypeError):
            self.t('foo()', 42)

        self.t('foo(12)', 12)
        self.t('foo(toret=100)', 100)

    def test_function_args_defaults(self):
        """ an optional argument with a default value """
        def foo(toret=9999):
            return toret

        self.s.functions['foo'] = foo
        self.t('foo()', 9999)

        self.t('foo(12)', 12)
        self.t('foo(toret=100)', 100)

    def test_function_args_bothtypes(self):
        """ a mix of required and defaulted arguments """
        def foo(mult, toret=100):
            return toret * mult

        self.s.functions['foo'] = foo
        with self.assertRaises(TypeError):
            self.t('foo()', 9999)

        self.t('foo(2)', 200)

        with self.assertRaises(TypeError):
            self.t('foo(toret=100)', 100)

        self.t('foo(4, toret=4)', 16)
        self.t('foo(mult=2, toret=4)', 8)
        self.t('foo(2, 10)', 20)
278
279
class TestOperators(DRYTest):
    """ Test adding in new operators, removing them, make sure it works. """
    # TODO: no tests written yet -- this class is an empty placeholder for now.
284
class TestNewFeatures(DRYTest):
    """ Tests which will break when new features are added..."""

    def test_lambda(self):
        """ defining a lambda is currently rejected """
        self.assertRaises(FeatureNotAvailable, self.t, 'lambda x:22', None)

    def test_lambda_application(self):
        """ defining *and* immediately calling a lambda is rejected too """
        self.assertRaises(FeatureNotAvailable, self.t, '(lambda x:22)(44)', None)
294
295
class TestTryingToBreakOut(DRYTest):
    """ Test various weird methods to break the security sandbox...

        Several tests temporarily change module-level settings of simpleeval
        (MAX_POWER, DISALLOW_PREFIXES).  Each original value is restored via
        ``addCleanup`` so a failing assertion cannot leak the change into
        other tests.
    """

    def test_import(self):
        """ usual suspect. import """
        # cannot import things:
        with self.assertRaises(AttributeError):
            self.t("import sys", None)

    def test_long_running(self):
        """ exponent operations can take a long time. """
        self.t("9**9**5", 9 ** 9 ** 5)

        with self.assertRaises(simpleeval.NumberTooHigh):
            self.t("9**9**8", 0)

        # and does limiting work?  (the old limit is restored after the test)

        self.addCleanup(setattr, simpleeval, 'MAX_POWER', simpleeval.MAX_POWER)
        simpleeval.MAX_POWER = 100

        with self.assertRaises(simpleeval.NumberTooHigh):
            self.t("101**2", 0)

    def test_encode_bignums(self):
        # thanks gk
        if hasattr(1, 'from_bytes'):  # python3 only
            with self.assertRaises(simpleeval.IterableTooLong):
                self.t('(1).from_bytes(("123123123123123123123123").encode()*999999, "big")', 0)

    def test_string_length(self):
        with self.assertRaises(simpleeval.IterableTooLong):
            self.t("50000*'text'", 0)

        with self.assertRaises(simpleeval.IterableTooLong):
            self.t("'text'*50000", 0)

        with self.assertRaises(simpleeval.IterableTooLong):
            self.t("('text'*50000)*1000", 0)

        with self.assertRaises(simpleeval.IterableTooLong):
            self.t("(50000*'text')*1000", 0)

        self.t("'stuff'*20000", 20000 * 'stuff')

        self.t("20000*'stuff'", 20000 * 'stuff')

        with self.assertRaises(simpleeval.IterableTooLong):
            self.t("('stuff'*20000) + ('stuff'*20000) ", 0)

        with self.assertRaises(simpleeval.IterableTooLong):
            self.t("'stuff'*100000", 100000 * 'stuff')

        with self.assertRaises(simpleeval.IterableTooLong):
            self.t("'" + (10000 * "stuff") + "'*100", 0)

        with self.assertRaises(simpleeval.IterableTooLong):
            self.t("'" + (50000 * "stuff") + "'", 0)

        if sys.version_info >= (3, 6, 0):
            with self.assertRaises(simpleeval.IterableTooLong):
                self.t("f'{\"foo\"*50000}'", 0)

    def test_bytes_array_test(self):
        self.t("'20000000000000000000'.encode() * 5000",
               '20000000000000000000'.encode() * 5000)

        with self.assertRaises(simpleeval.IterableTooLong):
            self.t("'123121323123131231223'.encode() * 5000", 20)

    def test_list_length_test(self):
        self.t("'spam spam spam'.split() * 5000", ['spam', 'spam', 'spam'] * 5000)

        with self.assertRaises(simpleeval.IterableTooLong):
            self.t("('spam spam spam' * 5000).split() * 5000", None)

    def test_python_stuff(self):
        """ other various pythony things. """
        # it only evaluates the first statement:
        self.t("a = 11; x = 21; x + x", 11)

    def test_function_globals_breakout(self):
        """ by accessing function.__globals__ or func_... """
        # thanks perkinslr.

        self.s.functions['x'] = lambda y: y + y
        self.t('x(100)', 200)

        with self.assertRaises(simpleeval.FeatureNotAvailable):
            self.t('x.__globals__', None)

        class EscapeArtist(object):
            @staticmethod
            def trapdoor():
                return 42

            @staticmethod
            def _quasi_private():
                return 84

        self.s.names['houdini'] = EscapeArtist()

        with self.assertRaises(simpleeval.FeatureNotAvailable):
            self.t('houdini.trapdoor.__globals__', 0)

        with self.assertRaises(simpleeval.FeatureNotAvailable):
            self.t('houdini.trapdoor.func_globals', 0)

        with self.assertRaises(simpleeval.FeatureNotAvailable):
            self.t('houdini._quasi_private()', 0)

        # DISALLOW_PREFIXES is configurable: with only 'func_' blocked, the
        # '_'-prefixed method becomes reachable.  Restored after the test.

        self.addCleanup(setattr, simpleeval, 'DISALLOW_PREFIXES', simpleeval.DISALLOW_PREFIXES)
        simpleeval.DISALLOW_PREFIXES = ['func_']

        self.t('houdini.trapdoor()', 42)
        self.t('houdini._quasi_private()', 84)

    def test_mro_breakout(self):
        class Blah(object):
            x = 42

        self.s.names['b'] = Blah

        with self.assertRaises(simpleeval.FeatureNotAvailable):
            self.t('b.mro()', None)

    def test_builtins_private_access(self):
        # explicit attempt of the exploit from perkinslr
        with self.assertRaises(simpleeval.FeatureNotAvailable):
            self.t("True.__class__.__class__.__base__.__subclasses__()[-1]"
                   ".__init__.func_globals['sys'].exit(1)", 42)

    def test_string_format(self):
        # python has so many ways to break out!
        with self.assertRaises(simpleeval.FeatureNotAvailable):
            self.t('"{string.__class__}".format(string="things")', 0)

        # set up the name outside the assertRaises block so only the
        # evaluation itself is expected to raise:
        self.s.names['x'] = {"a": 1}
        with self.assertRaises(simpleeval.FeatureNotAvailable):
            self.t('"{a.__class__}".format_map(x)', 0)

        if sys.version_info >= (3, 6, 0):
            self.s.names['x'] = 42

            with self.assertRaises(simpleeval.FeatureNotAvailable):
                self.t('f"{x.__class__}"', 0)

            self.s.names['x'] = lambda y: y

            with self.assertRaises(simpleeval.FeatureNotAvailable):
                self.t('f"{x.__globals__}"', 0)

            class EscapeArtist(object):
                @staticmethod
                def trapdoor():
                    return 42

                @staticmethod
                def _quasi_private():
                    return 84

            self.s.names['houdini'] = EscapeArtist()  # let's just retest this, but in a f-string

            with self.assertRaises(simpleeval.FeatureNotAvailable):
                self.t('f"{houdini.trapdoor.__globals__}"', 0)

            with self.assertRaises(simpleeval.FeatureNotAvailable):
                self.t('f"{houdini.trapdoor.func_globals}"', 0)

            with self.assertRaises(simpleeval.FeatureNotAvailable):
                self.t('f"{houdini._quasi_private()}"', 0)

            # with only 'func_' blocked, '_'-prefixed methods are reachable
            # inside f-strings too.  Restored after the test.

            self.addCleanup(setattr, simpleeval, 'DISALLOW_PREFIXES',
                            simpleeval.DISALLOW_PREFIXES)
            simpleeval.DISALLOW_PREFIXES = ['func_']

            self.t('f"{houdini.trapdoor()}"', "42")
            self.t('f"{houdini._quasi_private()}"', "84")
492
493
494
class TestCompoundTypes(DRYTest):
    """ Test the compound-types edition of the library """

    def setUp(self):
        """ these tests need the evaluator that understands dict/list/tuple/set """
        self.s = EvalWithCompoundTypes()

    def test_dict(self):
        self.t('{}', {})
        self.t('{"foo": "bar"}', {'foo': 'bar'})
        self.t('{"foo": "bar"}["foo"]', 'bar')
        self.t('dict()', {})
        self.t('dict(a=1)', {'a': 1})

    def test_dict_contains(self):
        self.t('{"a":22}["a"]', 22)
        with self.assertRaises(KeyError):
            self.t('{"a":22}["b"]', 22)

        self.t('{"a": 24}.get("b", 11)', 11)
        self.t('"a" in {"a": 24}', True)

    def test_tuple(self):
        self.t('()', ())
        self.t('(1,)', (1,))
        self.t('(1, 2, 3, 4, 5, 6)', (1, 2, 3, 4, 5, 6))
        self.t('(1, 2) + (3, 4)', (1, 2, 3, 4))
        self.t('(1, 2, 3)[1]', 2)
        self.t('tuple()', ())
        self.t('tuple("foo")', ('f', 'o', 'o'))

    def test_tuple_contains(self):
        self.t('("a","b")[1]', 'b')
        with self.assertRaises(IndexError):
            self.t('("a","b")[5]', 'b')
        self.t('"a" in ("b","c","a")', True)

    def test_list(self):
        self.t('[]', [])
        self.t('[1]', [1])
        self.t('[1, 2, 3, 4, 5]', [1, 2, 3, 4, 5])
        self.t('[1, 2, 3][1]', 2)
        self.t('list()', [])
        self.t('list("foo")', ['f', 'o', 'o'])

    def test_list_contains(self):
        self.t('["a","b"][1]', 'b')
        # out-of-range index on a *list* literal (this previously used a tuple
        # literal, copied from test_tuple_contains, so lists went untested):
        with self.assertRaises(IndexError):
            self.t('["a","b"][5]', 'b')

        self.t('"b" in ["a","b"]', True)

    def test_set(self):
        self.t('{1}', {1})
        self.t('{1, 2, 1, 2, 1, 2, 1}', {1, 2})
        self.t('set()', set())
        self.t('set("foo")', {'f', 'o'})

        self.t('2 in {1,2,3,4}', True)
        self.t('22 not in {1,2,3,4}', True)

    def test_not(self):
        self.t('not []', True)
        self.t('not [0]', False)
        self.t('not {}', True)
        self.t('not {0: 1}', False)
        self.t('not {0}', False)

    def test_use_func(self):
        self.s = EvalWithCompoundTypes(functions={"map": map, "str": str})
        self.t('list(map(str, [-1, 0, 1]))', ['-1', '0', '1'])
        with self.assertRaises(NameNotDefined):
            self.s.eval('list(map(bad, [-1, 0, 1]))')

        with self.assertRaises(FunctionNotDefined):
            self.s.eval('dir(str)')
        with self.assertRaises(FeatureNotAvailable):
            self.s.eval('str.__dict__')

        self.s = EvalWithCompoundTypes(functions={"dir": dir, "str": str})
        self.t('dir(str)', dir(str))
575
576
class TestComprehensions(DRYTest):
    """ Test the comprehensions support of the compound-types edition of the class. """

    def setUp(self):
        """ comprehensions are only available in the compound-types evaluator """
        self.s = EvalWithCompoundTypes()

    def _assert_too_long(self, expression):
        """ with ``range`` available, evaluating ``expression`` must raise IterableTooLong """
        self.s.functions = {'range': range}
        with self.assertRaises(simpleeval.IterableTooLong):
            self.s.eval(expression)

    def test_basic(self):
        """ one generator, no filter """
        self.t('[a + 1 for a in [1,2,3]]', [2, 3, 4])

    def test_with_self_reference(self):
        """ the loop variable may appear more than once """
        self.t('[a + a for a in [1,2,3]]', [2, 4, 6])

    def test_with_if(self):
        """ a single ``if`` filter """
        self.t('[a for a in [1,2,3,4,5] if a <= 3]', [1, 2, 3])

    def test_with_multiple_if(self):
        """ a filter combining two conditions """
        self.t('[a for a in [1,2,3,4,5] if a <= 3 and a > 1 ]', [2, 3])

    def test_attr_access_fails(self):
        """ the sandbox still applies inside the comprehension body """
        with self.assertRaises(FeatureNotAvailable):
            self.t('[a.__class__ for a in [1,2,3]]', None)

    def test_unpack(self):
        """ tuple targets are unpacked """
        self.t('[a+b for a,b in ((1,2),(3,4))]', [3, 7])

    def test_nested_unpack(self):
        """ nested tuple targets are unpacked too """
        self.t('[a+b+c for a, (b, c) in ((1,(1,1)),(3,(2,2)))]', [3, 7])

    def test_other_places(self):
        """ list comprehensions and generator expressions as call arguments """
        self.s.functions = {'sum': sum}
        for expression in ('sum([a+1 for a in [1,2,3,4,5]])',
                           'sum(a+1 for a in [1,2,3,4,5])'):
            self.t(expression, 20)

    def test_external_names_work(self):
        """ names from ``self.s.names`` are visible inside comprehensions """
        self.s.names = {'x': [22, 102, 12.3]}
        self.t('[a/2 for a in x]', [11.0, 51.0, 6.15])

        # with a callable resolver, b, c and d go through it while 'a' is the
        # comprehension's own variable:
        self.s.names = lambda node: ord(node.id)
        self.t('[a + a for a in [b, c, d]]', [ord(letter) * 2 for letter in 'bcd'])

    def test_multiple_generators(self):
        """ several for/if clauses give the same result as native Python """
        self.s.functions = {'range': range}
        expression = '[j for i in range(100) if i > 10 for j in range(i) if j < 20]'
        self.t(expression, eval(expression))

    def test_triple_generators(self):
        """ three chained generators, each depending on the previous one """
        self.s.functions = {'range': range}
        expression = '[(a,b,c) for a in range(4) for b in range(a) for c in range(b)]'
        self.t(expression, eval(expression))

    def test_too_long_generator(self):
        self._assert_too_long('[j for i in range(1000) if i > 10 for j in range(i) if j < 20]')

    def test_too_long_generator_2(self):
        self._assert_too_long(
            '[j for i in range(100) if i > 1 for j in range(i+10) if j < 100 for k in range(i*j)]')

    def test_nesting_generators_to_cheat(self):
        self._assert_too_long('[[[c for c in range(a)] for a in range(b)] for b in range(200)]')

    def test_no_leaking_names(self):
        """ see issue #52: failing list comprehensions could leak locals """
        with self.assertRaises(simpleeval.NameNotDefined):
            self.s.eval('[x if x == "2" else y for x in "123"]')

        # 'x' from the failed comprehension above must not survive:
        with self.assertRaises(simpleeval.NameNotDefined):
            self.s.eval('x')
653
654
class TestNames(DRYTest):
    """ 'names', what other languages call variables... """

    def test_none(self):
        """ what to do when names isn't defined, or is 'none' """
        # an unknown name with the default names:
        with self.assertRaises(NameNotDefined):
            self.t("a == 2", None)

        # 'a' is still unknown, even inside an augmented assignment:
        self.s.names["s"] = 21
        with self.assertRaises(NameNotDefined):
            self.t("s += a", 21)

        # names=None cannot be used for lookups at all:
        self.s.names = None
        with self.assertRaises(InvalidExpression):
            self.t('s', 21)

        # a missing key halfway down a dotted path:
        self.s.names = {'a': {'b': {'c': 42}}}
        with self.assertRaises(AttributeDoesNotExist):
            self.t('a.b.d**2', 42)

    def test_dict(self):
        """ using a normal dict for names lookup """
        self.s.names = {'a': 42}
        self.t("a + a", 84)

        self.s.names['also'] = 100
        self.t("a + also - a", 100)

        # an assignment evaluates to its value but stores nothing --
        # neither into a plain name...
        self.t("a = 200", 200)
        self.assertEqual(self.s.names['a'], 42)

        # ...nor into a list item...
        self.s.names['b'] = [0]
        self.t("b[0] = 11", 11)
        self.assertEqual(self.s.names['b'], [0])

        # reading list items is fine, though:
        self.s.names['b'] = [6, 7]
        self.t("b[0] * b[1]", 42)

        # as is reading from a dict, by subscript or with .get():
        self.s.names['c'] = {'i': 11}
        for expression, expected in (("c['i']", 11),
                                     ("c.get('i')", 11),
                                     ("c.get('j', 11)", 11),
                                     ("c.get('j')", None)):
            self.t(expression, expected)

        # ...but writing a dict key is ignored as well:
        self.t("c['b'] = 99", 99)
        self.assertNotIn('b', self.s.names['c'])

        # and going all 'inception' on it doesn't work either:
        self.s.names['c']['c'] = {'c': 11}
        self.t("c['c']['c'] = 21", 21)
        self.assertEqual(self.s.names['c']['c']['c'], 11)

    def test_dict_attr_access(self):
        """ with ATTR_INDEX_FALLBACK on (the default), ``a.b`` reads ``a['b']`` """
        self.assertEqual(self.s.ATTR_INDEX_FALLBACK, True)

        self.s.names = {'a': {'b': {'c': 42}}}
        self.t("a.b.c*2", 84)

        # assigning through the fallback returns the value but stores nothing:
        self.t("a.b.c = 11", 11)
        self.assertEqual(self.s.names['a']['b']['c'], 42)

        # the same goes for a brand-new key: 'd' is never created in 'a'.
        self.t("a.d = 11", 11)
        with self.assertRaises(KeyError):
            self.assertEqual(self.s.names['a']['d'], 11)

    def test_dict_attr_access_disabled(self):
        """ with ATTR_INDEX_FALLBACK off, only real subscripts reach dict keys """
        self.s.ATTR_INDEX_FALLBACK = False
        self.assertEqual(self.s.ATTR_INDEX_FALLBACK, False)

        self.s.names = {'a': {'b': {'c': 42}}}

        with self.assertRaises(simpleeval.AttributeDoesNotExist):
            self.t("a.b.c * 2", 84)

        self.t("a['b']['c'] * 2", 84)
        self.assertEqual(self.s.names['a']['b']['c'], 42)

    def test_object(self):
        """ using an object for name lookup """

        class TestObject(object):
            @staticmethod
            def method_thing():
                return 42

        outer = TestObject()
        outer.a = 23
        outer.b = 42
        outer.c = TestObject()
        outer.c.d = 9001

        self.s.names = {'o': outer}

        for expression, expected in (('o', outer),
                                     ('o.a', 23),
                                     ('o.b + o.c.d', 9043),
                                     ('o.method_thing()', 42)):
            self.t(expression, expected)

        with self.assertRaises(AttributeDoesNotExist):
            self.t('o.d', None)

    def test_func(self):
        """ using a function for 'names lookup' """
        # every name resolves to 1024, whatever it is called:
        self.s.names = lambda _node: 1024
        self.t("a", 1024)
        self.t("a + b - c - d", 0)

        # the resolver receives the Name node, so it can use the textual name
        # (here every name becomes its own text, twice):
        self.s.names = lambda node: node.id + node.id
        self.t("a", "aa")

    def test_from_doc(self):
        """ the 'name first letter as value' example from the docs """

        def name_handler(node):
            """ a -> 1, b -> 2, ...: alphabet position of the name's first letter """
            first_letter = node.id[0].lower()
            return ord(first_letter) - 96

        self.s.names = name_handler
        self.t('a', 1)
        self.t('a + b', 3)
824
825
class TestWhitespace(DRYTest):
    """ test that incorrect whitespace (preceding/trailing) doesn't matter. """

    # every expression below is some padding of "200 + 200"
    _SUM = 400

    def test_no_whitespace(self):
        """ baseline: no extra padding at all """
        self.t('200 + 200', self._SUM)

    def test_trailing(self):
        """ trailing spaces only """
        self.t('200 + 200       ', self._SUM)

    def test_preciding_whitespace(self):
        """ leading spaces only """
        self.t('    200 + 200', self._SUM)

    def test_preceding_tab_whitespace(self):
        """ a leading tab """
        self.t("\t200 + 200", self._SUM)

    def test_preceding_mixed_whitespace(self):
        """ leading spaces and a tab mixed """
        self.t("  \t 200 + 200", self._SUM)

    def test_both_ends_whitespace(self):
        """ padding on both sides """
        self.t("  \t 200 + 200  ", self._SUM)
846
847
class TestSimpleEval(unittest.TestCase):
    """ test the 'simple_eval' wrapper function """

    def test_basic_run(self):
        """ the one-shot wrapper evaluates a plain expression """
        self.assertEqual(simple_eval('6*7'), 42)

    def test_default_functions(self):
        """ rand() and randint() are available without any configuration """
        self.assertEqual(simple_eval('rand() < 1.0 and rand() > -0.01'), True)
        # NOTE(review): assumes rand() wraps random.random, whose range
        # [0.0, 1.0) includes 0.0 -- so compare with >= 0 rather than > 0,
        # which could (rarely) fail on a legitimate result.
        self.assertEqual(simple_eval('randint(200) < 200 and rand() >= 0'), True)
857
858
class TestMethodChaining(unittest.TestCase):
    """ methods that return ``self`` can be chained inside one expression """

    def test_chaining_correct(self):
        """
            Contributed by Khalid Grandi (xaled).
        """
        class Recorder(object):
            """ appends a marker to ``a`` on every add/sub call and returns itself """

            def __init__(self):
                self.a = "0"

            def _record(self, marker, b):
                self.a += marker + str(b)
                return self

            def add(self, b):
                return self._record("-add", b)

            def sub(self, b):
                return self._record("-sub", b)

            def tostring(self):
                return str(self.a)

        recorder = Recorder()
        result = simple_eval("x.add(1).sub(2).sub(3).tostring()", names={"x": recorder})
        self.assertEqual(result, "0-add1-sub2-sub3")
881
class TestExtendingClass(unittest.TestCase):
    """
        It should be pretty easy to extend / inherit from the SimpleEval class,
        to further lock things down, or unlock stuff, or whatever.
    """

    def test_methods_forbidden(self):
        # Example from README
        class EvalNoMethods(simpleeval.SimpleEval):
            def _eval_call(self, node):
                """ refuse ``obj.method(...)`` calls; defer everything else """
                is_method_call = isinstance(node.func, ast.Attribute)
                if not is_method_call:
                    return super(EvalNoMethods, self)._eval_call(node)
                raise simpleeval.FeatureNotAvailable("No methods please, we're British")

        evaluator = EvalNoMethods()

        # literals, operators and plain function calls still work...
        for expression, expected in (('"stuff happens"', "stuff happens"),
                                     ('22 + 20', 42),
                                     ('int("42")', 42)):
            self.assertEqual(evaluator.eval(expression), expected)

        # ...but any method call is rejected:
        with self.assertRaises(simpleeval.FeatureNotAvailable):
            evaluator.eval('"  blah  ".strip()')
904
905
class TestExceptions(unittest.TestCase):
    """
        confirm a few attributes exist properly and haven't been
        eaten by 2to3 or whatever... (see #41)
    """

    def _assert_attributes(self, exc_class, attr_name):
        """ raise ``exc_class("foo", "foo in bar")`` and check what it stored.

            ``attr_name`` must hold the first argument ('foo') and
            ``expression`` the second ('foo in bar').  Unittest assertions are
            used instead of bare ``assert`` so the checks still run under
            ``python -O``.
        """
        with self.assertRaises(exc_class) as context:
            raise exc_class("foo", "foo in bar")
        error = context.exception

        self.assertTrue(hasattr(error, attr_name))
        self.assertEqual(getattr(error, attr_name), 'foo')
        self.assertTrue(hasattr(error, 'expression'))
        self.assertEqual(getattr(error, 'expression'), 'foo in bar')

    def test_functionnotdefined(self):
        self._assert_attributes(FunctionNotDefined, 'func_name')

    def test_namenotdefined(self):
        self._assert_attributes(NameNotDefined, 'name')

    def test_attributedoesnotexist(self):
        self._assert_attributes(AttributeDoesNotExist, 'attr')
939
class TestUnusualComparisons(DRYTest):
    """ comparison operators whose results are not plain booleans """

    def test_custom_comparison_returner(self):
        """ whatever __gt__ returns is handed back untouched """
        class Blah(object):
            def __gt__(self, other):
                return self

        b = Blah()
        self.s.names = {'b': b}
        self.t('b > 2', b)

    def test_custom_comparison_doesnt_return_boolable(self):
        """
            SqlAlchemy, bless its cotton socks, returns BinaryExpression objects
            when asking for comparisons between things.  These BinaryExpressions
            raise a TypeError if you try and check for Truthyiness.
        """
        class BinaryExpression(object):
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == getattr(other, 'value', other)

            def __repr__(self):
                return '<BinaryExpression:{}>'.format(self.value)

            def __bool__(self):
                # This is the only important part, to match SqlAlchemy - the rest
                # of the methods are just to make testing a bit easier...
                raise TypeError("Boolean value of this clause is not defined")

            # Python 2 spells the truth-value hook __nonzero__; alias it so the
            # trap fires there as well.
            __nonzero__ = __bool__

        class Blah(object):
            def __gt__(self, other):
                return BinaryExpression('GT')

            def __lt__(self, other):
                return BinaryExpression('LT')

        b = Blah()
        self.s.names = {'b': b}

        # Plain Python copes with this too.  eval() inserts a '__builtins__' key
        # into the globals dict it is given, so hand it a copy -- passing
        # self.s.names directly would add that key to the evaluator's names.
        eval('b > 2', dict(self.s.names))

        self.t('b > 2', BinaryExpression('GT'))
        self.t('1 < 5 > b', BinaryExpression('LT'))
981
class TestGetItemUnhappy(DRYTest):
    # Again, SqlAlchemy doing unusual things.  Throwing its own errors, rather than
    # the expected types...

    def test_getitem_not_implemented(self):
        class Meh(object):
            """ every attribute reads as 42; every subscript raises NotImplementedError """

            def __getitem__(self, key):
                raise NotImplementedError("booya!")

            def __getattr__(self, key):
                return 42

        meh = Meh()

        # sanity-check the fixture in plain Python first:
        self.assertEqual(meh.anything, 42)
        with self.assertRaises(NotImplementedError):
            meh['nothing']

        self.s.names = {"m": meh}

        def check_evaluator():
            """ attribute reads work; the subscript's own error propagates unchanged """
            self.t("m.anything", 42)
            with self.assertRaises(NotImplementedError):
                self.t("m['nothing']", None)

        # first with the default ATTR_INDEX_FALLBACK, then with it switched off:
        check_evaluator()
        self.s.ATTR_INDEX_FALLBACK = False
        check_evaluator()
1011
1012
class TestShortCircuiting(DRYTest):
    """ operands/branches that are not needed must never be evaluated """

    def _install_tracer(self):
        """ register ``foo(y)``, which records ``y`` and returns it; return the record list """
        calls = []

        def foo(y):
            calls.append(y)
            return y

        self.s.functions = {'foo': foo}
        return calls

    def test_shortcircuit_if(self):
        calls = self._install_tracer()

        # the condition is evaluated first, then only the chosen branch:
        self.t('foo(1) if foo(2) else foo(3)', 1)
        self.assertListEqual(calls, [2, 1])

        del calls[:]
        self.t('42 if True else foo(99)', 42)
        self.assertListEqual(calls, [])

    def test_shortcircuit_comparison(self):
        calls = self._install_tracer()

        self.t('foo(11) < 12', True)
        self.assertListEqual(calls, [11])

        # a chain stops at the first false link, so foo(22) never runs:
        del calls[:]
        self.t('1 > 2 < foo(22)', False)
        self.assertListEqual(calls, [])
1039
1040
class TestDisallowedFunctions(DRYTest):
    """ dangerous builtins must be refused, whether registered at init or later """

    @staticmethod
    def _disallowed_functions():
        """ return the builtins simpleeval is expected to refuse to call """
        disallowed = [type, isinstance, eval, getattr, setattr, help, repr, compile, open]
        if simpleeval.PYTHON3:
            # exec is not a function in Python2 (it's a statement), so it can
            # only be referenced indirectly for this file to stay parseable there.
            exec('disallowed.append(exec)')
        return disallowed

    def _assert_module_list_covered(self, disallowed):
        """ everything in simpleeval.DISALLOW_FUNCTIONS must be in ``disallowed`` """
        for func in simpleeval.DISALLOW_FUNCTIONS:
            self.assertIn(func, disallowed)

    def test_functions_are_disallowed_at_init(self):
        disallowed = self._disallowed_functions()
        self._assert_module_list_covered(disallowed)

        for func in disallowed:
            with self.assertRaises(FeatureNotAvailable):
                SimpleEval(functions={'foo': func})

    def test_functions_are_disallowed_in_expressions(self):
        disallowed = self._disallowed_functions()
        self._assert_module_list_covered(disallowed)

        saved_defaults = simpleeval.DEFAULT_FUNCTIONS.copy()
        # put a fresh copy of the defaults back afterwards, even on failure:
        self.addCleanup(setattr, simpleeval, 'DEFAULT_FUNCTIONS', saved_defaults.copy())

        for func in disallowed:
            # reset the module defaults each round, in case the previous
            # iteration's registration reached the shared dict
            simpleeval.DEFAULT_FUNCTIONS = saved_defaults.copy()
            evaluator = SimpleEval()
            evaluator.functions['foo'] = func
            # only the call itself is expected to be refused:
            with self.assertRaises(FeatureNotAvailable):
                evaluator.eval('foo(42)')
1074
if __name__ == "__main__":  # pragma: no cover
    # allow running this test module directly as a script
    unittest.main(module="__main__")
1077