1import os
2import re
3import sys
4import unittest
5
6# Note that nothing should be imported from AccessControl, and in particular
7# nothing from ZopeGuards.py.  Transformed code may need several wrappers
8# in order to run at all, and most of the production wrappers are defined
9# in ZopeGuards.  But RestrictedPython isn't supposed to depend on
10# AccessControl, so we need to define throwaway wrapper implementations
11# here instead.
12
13from RestrictedPython import compile_restricted, PrintCollector
14from RestrictedPython.Eval import RestrictionCapableEval
15from RestrictedPython.tests import restricted_module, verify
16from RestrictedPython.RCompile import RModule, RFunction
17
try:
    __file__
except NameError:
    # No __file__ in this namespace: fall back to the path given as the
    # first command-line argument.
    __file__ = os.path.abspath(sys.argv[1])

# Absolute path of this file, and the directory that contains it.  The
# directory is where the companion source files are looked up.
_FILEPATH = os.path.abspath(__file__)
_HERE = os.path.split(_FILEPATH)[0]
24
def _getindent(line):
    """Return the width of the leading whitespace of *line*.

    A space counts as one column and a tab as eight; counting stops at
    the first character that is neither.
    """
    widths = {' ': 1, '\t': 8}
    total = 0
    for ch in line:
        if ch not in widths:
            break
        total = total + widths[ch]
    return total
33
def find_source(fn, func):
    """Try to find and return the Python source of a code object.

    *fn* is the path of the source file and *func* is a code object (only
    its ``co_firstlineno`` is used).  Returns the tuple ``(fn, source)``
    where ``source`` is the text from the code object's first line up to,
    but not including, the next line whose indentation equals that of the
    first line.

    Originally written by Harm van der Heijden (H.v.d.Heijden@phys.tue.nl)
    """
    f = open(fn, "r")
    try:
        # Skip ahead so that ``line`` is the line the code object starts on.
        line = ""
        for i in range(func.co_firstlineno):
            line = f.readline()
        ind = _getindent(line)
        chunks = []
        while line:
            chunks.append(line)
            line = f.readline()
            # the following should be <= ind, but then we get
            # confused by multiline docstrings. Using == works most of
            # the time... but not always!
            if _getindent(line) == ind:
                break
    finally:
        # Close the file even if reading fails part way through.
        f.close()
    return fn, "".join(chunks)
52
def get_source(func):
    """Less silly interface to find_source.

    Looks up the file name in the function's globals and returns the
    source text of *func* as found by find_source().  Raises
    AssertionError if no source text could be extracted.
    """
    path = func.func_globals['__file__']
    # A module loaded from bytecode reports the compiled file (.pyc, or
    # .pyo when running with -O); drop the last character to get the .py.
    if path[-4:] in ('.pyc', '.pyo'):
        path = path[:-1]
    source = find_source(path, func.func_code)[1]
    # Raised explicitly so the check is not stripped when running with -O.
    if not source.strip():
        raise AssertionError("Source should not be empty!")
    return source
61
62def create_rmodule():
63    global rmodule
64    fn = os.path.join(_HERE, 'restricted_module.py')
65    f = open(fn, 'r')
66    source = f.read()
67    f.close()
68    # Sanity check
69    compile(source, fn, 'exec')
70    # Now compile it for real
71    code = compile_restricted(source, fn, 'exec')
72    rmodule = {'__builtins__':{'__import__':__import__, 'None':None,
73                               '__name__': 'restricted_module'}}
74    builtins = getattr(__builtins__, '__dict__', __builtins__)
75    for name in ('map', 'reduce', 'int', 'pow', 'range', 'filter',
76                 'len', 'chr', 'ord',
77                 ):
78        rmodule[name] = builtins[name]
79    exec code in rmodule
80
class AccessDenied(Exception):
    """Raised by the throwaway guards when an access is refused."""


# Sentinel value: the guards compare against it by identity and raise
# AccessDenied when an attribute or item lookup returns it.
DisallowedObject = []
84
class RestrictedObject:
    """Fixture object handed to the restricted test functions.

    Some attributes and items resolve to the DisallowedObject sentinel,
    and some item operations raise AccessDenied directly; everything else
    is backed by the string ``s``.
    """

    # Attributes bound to the sentinel are refused by guarded_getattr.
    disallowed = DisallowedObject
    write = DisallowedObject

    allowed = 1
    _ = 2
    __ = 3
    _some_attr = 4
    __some_other_attr__ = 5
    s = 'Another day, another test...'
    # Names that TestGuard permits to be assigned on this object.
    __writeable_attrs__ = ('writeable',)

    def __len__(self):
        return len(self.s)

    def __getitem__(self, idx):
        if idx == 'protected':
            raise AccessDenied
        if idx == 0 or idx == 'safe':
            return 1
        if idx == 1:
            return DisallowedObject
        # Anything else indexes into the backing string.
        return self.s[idx]

    def __getslice__(self, lo, hi):
        return self.s[lo:hi]

    def __setitem__(self, idx, v):
        # Only the 'safe' key may be written.
        if not (idx == 'safe'):
            raise AccessDenied
        self.safe = v

    def __setslice__(self, lo, hi, value):
        raise AccessDenied
121
122
def guarded_getattr(ob, name):
    """Return ``getattr(ob, name)``, refusing the DisallowedObject sentinel.

    Raises AccessDenied when the attribute value is DisallowedObject.
    """
    value = getattr(ob, name)
    if value is not DisallowedObject:
        return value
    raise AccessDenied
128
SliceType = type(slice(0))
def guarded_getitem(ob, index):
    """Return ``ob[index]``, refusing the DisallowedObject sentinel.

    A slice object without a step is re-applied using slice syntax (a
    missing start becomes 0); any other index is used as-is.  Raises
    AccessDenied when the result is DisallowedObject.
    """
    if type(index) is not SliceType or index.step is not None:
        result = ob[index]
    else:
        lo = index.start
        hi = index.stop
        if lo is None:
            lo = 0
        if hi is None:
            result = ob[lo:]
        else:
            result = ob[lo:hi]
    if result is DisallowedObject:
        raise AccessDenied
    return result
145
def minimal_import(name, _globals, _locals, names):
    """Stand-in for ``__import__`` that only allows ``__future__``.

    Returns the ``__future__`` module when *name* is ``"__future__"``;
    the remaining arguments are ignored.  Raises ValueError for any
    other module name.
    """
    if name != "__future__":
        # Call form of raise: same behaviour, and not Python-2-only syntax.
        raise ValueError("Only future imports are allowed")
    import __future__
    return __future__
151
152
class TestGuard:
    '''A guard class

    Wraps an object and forwards writes to it.  Attribute assignment is
    only allowed for names listed in the wrapped object's
    ``__writeable_attrs__`` and never for names starting with ``func_``;
    item and slice assignment are passed straight through.
    '''
    def __init__(self, _ob, write=None):
        # Store through __dict__ so our own __setattr__ is bypassed.
        vars(self)['_ob'] = _ob

    # Write guard methods

    def __setattr__(self, name, value):
        target = self.__dict__['_ob']
        permitted = getattr(target, '__writeable_attrs__', ())
        if name not in permitted or name[:5] == 'func_':
            raise AccessDenied
        setattr(target, name, value)

    def __setitem__(self, index, value):
        target = self.__dict__['_ob']
        target[index] = value

    def __setslice__(self, lo, hi, value):
        target = self.__dict__['_ob']
        target[lo:hi] = value
176
# A wrapper for _apply_.  Each call is recorded in apply_wrapper_called
# so that tests can see that the wrapper was actually used.
apply_wrapper_called = []
def apply_wrapper(func, *args, **kws):
    """Record the call, then return ``func(*args, **kws)``."""
    apply_wrapper_called.append('yes')
    result = func(*args, **kws)
    return result
182
# Maps each in-place operator string seen by inplacevar_wrapper to the
# (x, y) operands of its most recent call.
inplacevar_wrapper_called = {}
def inplacevar_wrapper(op, x, y):
    """Apply the in-place operator *op* (e.g. ``'+='``) to *x* and *y*.

    Records the operands in inplacevar_wrapper_called[op] and returns
    the value of ``x`` after executing ``x <op> y``.
    """
    inplacevar_wrapper_called[op] = x, y
    # This is really lame.  But it's just a test. :)
    globs = {'x': x, 'y': y}
    # Tuple form of exec: equivalent to "exec ... in globs" on Python 2
    # and also valid as a function call on Python 3.
    exec('x' + op + 'y', globs)
    return globs['x']
190
191class RestrictionTests(unittest.TestCase):
192    def execFunc(self, name, *args, **kw):
193        func = rmodule[name]
194        verify.verify(func.func_code)
195        func.func_globals.update({'_getattr_': guarded_getattr,
196                                  '_getitem_': guarded_getitem,
197                                  '_write_': TestGuard,
198                                  '_print_': PrintCollector,
199        # I don't want to write something as involved as ZopeGuard's
200        # SafeIter just for these tests.  Using the builtin list() function
201        # worked OK for everything the tests did at the time this was added,
202        # but may fail in the future.  If Python 2.1 is no longer an
203        # interesting platform then, using 2.2's builtin iter() here should
204        # work for everything.
205                                  '_getiter_': list,
206                                  '_apply_': apply_wrapper,
207                                  '_inplacevar_': inplacevar_wrapper,
208                                  })
209        return func(*args, **kw)
210
211    def checkPrint(self):
212        for i in range(2):
213            res = self.execFunc('print%s' % i)
214            self.assertEqual(res, 'Hello, world!')
215
216    def checkPrintToNone(self):
217        try:
218            res = self.execFunc('printToNone')
219        except AttributeError:
220            # Passed.  "None" has no "write" attribute.
221            pass
222        else:
223            self.fail(0, res)
224
225    def checkPrintStuff(self):
226        res = self.execFunc('printStuff')
227        self.assertEqual(res, 'a b c')
228
229    def checkPrintLines(self):
230        res = self.execFunc('printLines')
231        self.assertEqual(res,  '0 1 2\n3 4 5\n6 7 8\n')
232
233    def checkPrimes(self):
234        res = self.execFunc('primes')
235        self.assertEqual(res, '[2, 3, 5, 7, 11, 13, 17, 19]')
236
237    def checkAllowedSimple(self):
238        res = self.execFunc('allowed_simple')
239        self.assertEqual(res, 'abcabcabc')
240
241    def checkAllowedRead(self):
242        self.execFunc('allowed_read', RestrictedObject())
243
244    def checkAllowedWrite(self):
245        self.execFunc('allowed_write', RestrictedObject())
246
247    def checkAllowedArgs(self):
248        self.execFunc('allowed_default_args', RestrictedObject())
249
250    def checkTryMap(self):
251        res = self.execFunc('try_map')
252        self.assertEqual(res, "[2, 3, 4]")
253
254    def checkApply(self):
255        del apply_wrapper_called[:]
256        res = self.execFunc('try_apply')
257        self.assertEqual(apply_wrapper_called, ["yes"])
258        self.assertEqual(res, "321")
259
260    def checkInplace(self):
261        inplacevar_wrapper_called.clear()
262        res = self.execFunc('try_inplace')
263        self.assertEqual(inplacevar_wrapper_called['+='], (1, 3))
264
265    def checkDenied(self):
266        for k in rmodule.keys():
267            if k[:6] == 'denied':
268                try:
269                    self.execFunc(k, RestrictedObject())
270                except AccessDenied:
271                    # Passed the test
272                    pass
273                else:
274                    self.fail('%s() did not trip security' % k)
275
276    def checkSyntaxSecurity(self):
277        self._checkSyntaxSecurity('security_in_syntax.py')
278        if sys.version_info >= (2, 6):
279            self._checkSyntaxSecurity('security_in_syntax26.py')
280        if sys.version_info >= (2, 7):
281            self._checkSyntaxSecurity('security_in_syntax27.py')
282
283    def _checkSyntaxSecurity(self, mod_name):
284        # Ensures that each of the functions in security_in_syntax.py
285        # throws a SyntaxError when using compile_restricted.
286        fn = os.path.join(_HERE, mod_name)
287        f = open(fn, 'r')
288        source = f.read()
289        f.close()
290        # Unrestricted compile.
291        code = compile(source, fn, 'exec')
292        m = {'__builtins__': {'__import__':minimal_import}}
293        exec code in m
294        for k, v in m.items():
295            if hasattr(v, 'func_code'):
296                filename, source = find_source(fn, v.func_code)
297                # Now compile it with restrictions
298                try:
299                    code = compile_restricted(source, filename, 'exec')
300                except SyntaxError:
301                    # Passed the test.
302                    pass
303                else:
304                    self.fail('%s should not have compiled' % k)
305
306    def checkOrderOfOperations(self):
307        res = self.execFunc('order_of_operations')
308        self.assertEqual(res, 0)
309
310    def checkRot13(self):
311        res = self.execFunc('rot13', 'Zope is k00l')
312        self.assertEqual(res, 'Mbcr vf x00y')
313
314    def checkNestedScopes1(self):
315        res = self.execFunc('nested_scopes_1')
316        self.assertEqual(res, 2)
317
318    def checkUnrestrictedEval(self):
319        expr = RestrictionCapableEval("{'a':[m.pop()]}['a'] + [m[0]]")
320        v = [12, 34]
321        expect = v[:]
322        expect.reverse()
323        res = expr.eval({'m':v})
324        self.assertEqual(res, expect)
325        v = [12, 34]
326        res = expr(m=v)
327        self.assertEqual(res, expect)
328
329    def checkStackSize(self):
330        for k, rfunc in rmodule.items():
331            if not k.startswith('_') and hasattr(rfunc, 'func_code'):
332                rss = rfunc.func_code.co_stacksize
333                ss = getattr(restricted_module, k).func_code.co_stacksize
334                self.failUnless(
335                    rss >= ss, 'The stack size estimate for %s() '
336                    'should have been at least %d, but was only %d'
337                    % (k, ss, rss))
338
339
340    def checkBeforeAndAfter(self):
341        from RestrictedPython.RCompile import RModule
342        from RestrictedPython.tests import before_and_after
343        from compiler import parse
344
345        defre = re.compile(r'def ([_A-Za-z0-9]+)_(after|before)\(')
346
347        beforel = [name for name in before_and_after.__dict__
348                   if name.endswith("_before")]
349
350        for name in beforel:
351            before = getattr(before_and_after, name)
352            before_src = get_source(before)
353            before_src = re.sub(defre, r'def \1(', before_src)
354            rm = RModule(before_src, '')
355            tree_before = rm._get_tree()
356
357            after = getattr(before_and_after, name[:-6]+'after')
358            after_src = get_source(after)
359            after_src = re.sub(defre, r'def \1(', after_src)
360            tree_after = parse(after_src)
361
362            self.assertEqual(str(tree_before), str(tree_after))
363
364            rm.compile()
365            verify.verify(rm.getCode())
366
367    def _checkBeforeAndAfter(self, mod):
368            from RestrictedPython.RCompile import RModule
369            from compiler import parse
370
371            defre = re.compile(r'def ([_A-Za-z0-9]+)_(after|before)\(')
372
373            beforel = [name for name in mod.__dict__
374                       if name.endswith("_before")]
375
376            for name in beforel:
377                before = getattr(mod, name)
378                before_src = get_source(before)
379                before_src = re.sub(defre, r'def \1(', before_src)
380                rm = RModule(before_src, '')
381                tree_before = rm._get_tree()
382
383                after = getattr(mod, name[:-6]+'after')
384                after_src = get_source(after)
385                after_src = re.sub(defre, r'def \1(', after_src)
386                tree_after = parse(after_src)
387
388                self.assertEqual(str(tree_before), str(tree_after))
389
390                rm.compile()
391                verify.verify(rm.getCode())
392
393    if sys.version_info[:2] >= (2, 4):
394        def checkBeforeAndAfter24(self):
395            from RestrictedPython.tests import before_and_after24
396            self._checkBeforeAndAfter(before_and_after24)
397
398    if sys.version_info[:2] >= (2, 5):
399        def checkBeforeAndAfter25(self):
400            from RestrictedPython.tests import before_and_after25
401            self._checkBeforeAndAfter(before_and_after25)
402
403    if sys.version_info[:2] >= (2, 6):
404        def checkBeforeAndAfter26(self):
405            from RestrictedPython.tests import before_and_after26
406            self._checkBeforeAndAfter(before_and_after26)
407
408    if sys.version_info[:2] >= (2, 7):
409        def checkBeforeAndAfter27(self):
410            from RestrictedPython.tests import before_and_after27
411            self._checkBeforeAndAfter(before_and_after27)
412
413    def _compile_file(self, name):
414        path = os.path.join(_HERE, name)
415        f = open(path, "r")
416        source = f.read()
417        f.close()
418
419        co = compile_restricted(source, path, "exec")
420        verify.verify(co)
421        return co
422
423    def checkUnpackSequence(self):
424        co = self._compile_file("unpack.py")
425        calls = []
426        def getiter(seq):
427            calls.append(seq)
428            return list(seq)
429        globals = {"_getiter_": getiter, '_inplacevar_': inplacevar_wrapper}
430        exec co in globals, {}
431        # The comparison here depends on the exact code that is
432        # contained in unpack.py.
433        # The test doing implicit unpacking in an "except:" clause is
434        # a pain, because there are two levels of unpacking, and the top
435        # level is unpacking the specific TypeError instance constructed
436        # by the test.  We have to worm around that one.
437        ineffable =  "a TypeError instance"
438        expected = [[1, 2],
439                    (1, 2),
440                    "12",
441                    [1],
442                    [1, [2, 3], 4],
443                    [2, 3],
444                    (1, (2, 3), 4),
445                    (2, 3),
446                    [1, 2, 3],
447                    2,
448                    ('a', 'b'),
449                    ((1, 2), (3, 4)), (1, 2),
450                    ((1, 2), (3, 4)), (3, 4),
451                    ineffable, [42, 666],
452                    [[0, 1], [2, 3], [4, 5]], [0, 1], [2, 3], [4, 5],
453                    ([[[1, 2]]], [[[3, 4]]]), [[[1, 2]]], [[1, 2]], [1, 2],
454                                              [[[3, 4]]], [[3, 4]], [3, 4],
455                    ]
456        i = expected.index(ineffable)
457        self.assert_(isinstance(calls[i], TypeError))
458        expected[i] = calls[i]
459        self.assertEqual(calls, expected)
460
461    def checkUnpackSequenceExpression(self):
462        co = compile_restricted("[x for x, y in [(1, 2)]]", "<string>", "eval")
463        verify.verify(co)
464        calls = []
465        def getiter(s):
466            calls.append(s)
467            return list(s)
468        globals = {"_getiter_": getiter}
469        exec co in globals, {}
470        self.assertEqual(calls, [[(1,2)], (1, 2)])
471
472    def checkUnpackSequenceSingle(self):
473        co = compile_restricted("x, y = 1, 2", "<string>", "single")
474        verify.verify(co)
475        calls = []
476        def getiter(s):
477            calls.append(s)
478            return list(s)
479        globals = {"_getiter_": getiter}
480        exec co in globals, {}
481        self.assertEqual(calls, [(1, 2)])
482
483    def checkClass(self):
484        getattr_calls = []
485        setattr_calls = []
486
487        def test_getattr(obj, attr):
488            getattr_calls.append(attr)
489            return getattr(obj, attr)
490
491        def test_setattr(obj):
492            setattr_calls.append(obj.__class__.__name__)
493            return obj
494
495        co = self._compile_file("class.py")
496        globals = {"_getattr_": test_getattr,
497                   "_write_": test_setattr,
498                   }
499        exec co in globals, {}
500        # Note that the getattr calls don't correspond to the method call
501        # order, because the x.set method is fetched before its arguments
502        # are evaluated.
503        self.assertEqual(getattr_calls,
504                         ["set", "set", "get", "state", "get", "state"])
505        self.assertEqual(setattr_calls, ["MyClass", "MyClass"])
506
507    def checkLambda(self):
508        co = self._compile_file("lambda.py")
509        exec co in {}, {}
510
511    def checkEmpty(self):
512        rf = RFunction("", "", "issue945", "empty.py", {})
513        rf.parse()
514        rf2 = RFunction("", "# still empty\n\n# by", "issue945", "empty.py", {})
515        rf2.parse()
516
517    def checkSyntaxError(self):
518        err = ("def f(x, y):\n"
519               "    if x, y < 2 + 1:\n"
520               "        return x + y\n"
521               "    else:\n"
522               "        return x - y\n")
523        self.assertRaises(SyntaxError,
524                          compile_restricted, err, "<string>", "exec")
525
526    # these two tests check that source code with Windows line
527    # endings still works.
528
529    def checkLineEndingsRFunction(self):
530        from RestrictedPython.RCompile import RFunction
531        gen = RFunction(
532            p='',
533            body='# testing\r\nprint "testing"\r\nreturn printed\n',
534            name='test',
535            filename='<test>',
536            globals=(),
537            )
538        gen.mode = 'exec'
539        # if the source has any line ending other than \n by the time
540        # parse() is called, then you'll get a syntax error.
541        gen.parse()
542
543    def checkLineEndingsRestrictedCompileMode(self):
544        from RestrictedPython.RCompile import RestrictedCompileMode
545        gen = RestrictedCompileMode(
546            '# testing\r\nprint "testing"\r\nreturn printed\n',
547            '<testing>'
548            )
549        gen.mode='exec'
550        # if the source has any line ending other than \n by the time
551        # parse() is called, then you'll get a syntax error.
552        gen.parse()
553
554    def checkCollector2295(self):
555        from RestrictedPython.RCompile import RestrictedCompileMode
556        gen = RestrictedCompileMode(
557            'if False:\n  pass\n# Me Grok, Say Hi',
558            '<testing>'
559            )
560        gen.mode='exec'
561        # if the source has any line ending other than \n by the time
562        # parse() is called, then you'll get a syntax error.
563        gen.parse()
564
565
# Build the global ``rmodule`` namespace at import time; the tests look
# their restricted functions up in it.
create_rmodule()

def test_suite():
    """Return a suite of every ``check*`` method of RestrictionTests."""
    loader = unittest.TestLoader()
    loader.testMethodPrefix = 'check'
    return loader.loadTestsFromTestCase(RestrictionTests)

if __name__=='__main__':
    unittest.main(defaultTest="test_suite")
573