1# Copyright (c) 2017, The MITRE Corporation. All rights reserved.
2# See LICENSE.txt for complete terms.
3
4import contextlib
5import functools
6import itertools
7import json
8import warnings
9
10import cybox.utils
11from mixbox.binding_utils import ExternalEncoding
12from mixbox.entities import NamespaceCollector
13from mixbox.vendor.six import iteritems, text_type
14
15from stix.utils import silence_warnings
16
17
@contextlib.contextmanager
def ctx_assert_warnings(self):
    """Assert that the wrapped block emits at least one warning.

    While the block runs, the warning filter is forced to 'always' and every
    warning is recorded instead of printed. Once the block finishes without
    raising, ``self.assertTrue`` checks that something was recorded.

    Args:
        self: The test case whose ``assertTrue`` performs the final check.

    """
    with warnings.catch_warnings(record=True) as caught:
        # Make sure repeated/duplicate warnings are not suppressed.
        warnings.simplefilter('always')

        # Hand control to the body of the ``with`` statement.
        yield

        # Only reached when the body completed normally.
        self.assertTrue(len(caught) != 0)
33
34
def assert_warnings(func):
    """Decorate a test method so that it fails unless it emits a warning.

    The first positional argument of each call is treated as the test case
    instance and passed to ``ctx_assert_warnings``, which performs the check
    after the wrapped function returns.

    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        test_case = args[0]
        with ctx_assert_warnings(test_case):
            result = func(*args, **kwargs)
        return result

    return wrapper
47
48
def round_trip_dict(cls, dict_):
    """Round-trip ``dict_`` through ``cls`` and return the resulting dict.

    Two conversion paths are exercised:

    1. The legacy ``object_from_dict`` / ``dict_from_object`` class methods.
       Their result is NOT compared or returned; the calls are kept so that
       any exception they raise still fails the calling test.
    2. ``cls.from_dict`` followed by ``cls.to_dict``. This result is returned.

    Args:
        cls: The entity class under test.
        dict_: The dictionary representation to round-trip.

    Returns:
        The dictionary produced by ``cls.to_dict(cls.from_dict(dict_))``.

    """
    # Legacy path: exercised for errors only (previously its result was
    # assigned to ``dict2`` and then silently overwritten).
    binding_obj = cls.object_from_dict(dict_)
    cls.dict_from_object(binding_obj)

    # Primary path whose output the caller compares against ``dict_``.
    # ``cls.to_dict`` is called through the class on purpose, matching the
    # original behaviour even if ``from_dict`` returns a subclass instance.
    api_obj = cls.from_dict(dict_)
    return cls.to_dict(api_obj)
56
57
def round_trip(o, output=False, list_=False):
    """Perform all eight conversions to verify import/export functionality.

    1. Entity -> dict/list
    2. dict/list -> JSON string
    3. JSON string -> dict/list
    4. dict/list -> Entity
    5. Entity -> Bindings Object
    6. Bindings Object -> XML String
    7. XML String -> Bindings Object
    8. Bindings Object -> Entity

    Args:
        o: The entity instance to round-trip. Its class must provide
            ``from_dict`` (or ``from_list``), ``from_obj`` and a ``_binding``
            module exposing ``parseString``.
        output: If True, print the class, the intermediate JSON and the
            intermediate XML as they are produced (debugging aid).
        list_: If True, use ``to_list``/``from_list`` instead of
            ``to_dict``/``from_dict`` for steps 1 and 4.

    Returns:
        The entity produced by the final conversion, so tests which call this
        function can check that it was not modified during any transform.

    Raises:
        KeyError: Re-raised with its original traceback if serializing to XML
            fails. The namespace information collected in step 5 is printed
            first to help diagnose the failure.

    """
    klass = o.__class__
    if output:
        print("Class: ", klass)
        print("-" * 40)

    # 1. Entity -> dict/list
    if list_:
        d = o.to_list()
    else:
        d = o.to_dict()

    # 2. dict/list -> JSON string
    json_string = json.dumps(d)

    if output:
        print(json_string)
        print("-" * 40)

    # Before parsing the JSON, make sure the cache is clear
    cybox.utils.cache_clear()

    # 3. JSON string -> dict/list
    d2 = json.loads(json_string)

    # 4. dict/list -> Entity
    if list_:
        o2 = klass.from_list(d2)
    else:
        o2 = klass.from_dict(d2)

    # 5. Entity -> Bindings Object
    # The returned bindings object is not reused below (step 6 serializes
    # ``o2`` directly). The call is kept to exercise ``to_obj``; presumably it
    # also fills ``ns_info``, which is dumped if step 6 fails.
    ns_info = NamespaceCollector()
    o2.to_obj(ns_info=ns_info)

    # 6. Bindings Object -> XML String
    try:
        xml_string = o2.to_xml(encoding=ExternalEncoding)
    except KeyError as ex:
        print(str(ex))
        ns_info.finalize()
        print(ns_info.binding_namespaces)
        # Bare ``raise`` keeps the original traceback (``raise ex`` drops it
        # on Python 2).
        raise

    # ``to_xml`` may hand back encoded bytes; normalize to text.
    if not isinstance(xml_string, text_type):
        xml_string = xml_string.decode(ExternalEncoding)

    if output:
        print(xml_string)
        print("-" * 40)

    # Before parsing the XML, make sure the cache is clear
    cybox.utils.cache_clear()

    # 7. XML String -> Bindings Object
    xobj2 = klass._binding.parseString(xml_string)

    # 8. Bindings Object -> Entity
    o3 = klass.from_obj(xobj2)

    return o3
135
136
class EntityTestCase(object):
    """Mixin providing dict/XML round-trip tests for STIX Entity classes.

    Subclasses are expected to set ``klass`` (the Entity class under test)
    and ``_full_dict`` (a dictionary representation of a fully populated
    instance). The ``self.assert*`` calls assume the concrete subclass also
    derives from ``unittest.TestCase``; this mixin itself does not.
    """

    def setUp(self):
        # Fail fast if a subclass forgot to configure its fixture.
        for attr_name in ("klass", "_full_dict"):
            self.assertNotEqual(getattr(self, attr_name), None)

    @silence_warnings
    def test_round_trip_full_dict(self):
        # The bare mixin has no fixture of its own; skip it.
        if type(self) is EntityTestCase:
            return

        round_tripped = round_trip_dict(self.klass, self._full_dict)
        self.maxDiff = None  # show the complete diff on failure
        self.assertEqual(self._full_dict, round_tripped)

    def _combine(self, d):
        """Return a new dict holding ``_full_dict`` overlaid with *d*.

        Keys present in *d* take precedence over those in ``_full_dict``.
        """
        pairs = itertools.chain(iteritems(self._full_dict), iteritems(d))
        return dict(pairs)

    @silence_warnings
    def test_round_trip_full(self):
        # The bare mixin has no fixture of its own; skip it.
        if type(self) is EntityTestCase:
            return

        # Smoke test: the final entity is not compared with anything, so this
        # only fails if one of the conversion steps raises.
        entity = self.klass.from_dict(self._full_dict)
        round_trip(entity, output=True)

    @silence_warnings
    def _test_round_trip_dict(self, input):
        """Assert that *input* survives a dict round trip through ``klass``."""
        round_tripped = round_trip_dict(self.klass, input)
        self.maxDiff = None
        self.assertEqual(input, round_tripped)

    @silence_warnings
    def _test_partial_dict(self, partial):
        """Overlay *partial* on ``_full_dict`` and round-trip the result."""
        self._test_round_trip_dict(self._combine(partial))
181
182
class TypedListTestCase(object):
    """Mixin checking that a typed-list class survives a dict round trip.

    Subclasses are expected to set ``klass`` and ``_full_dict``; the
    ``assertEqual`` call assumes they also derive from ``unittest.TestCase``.
    """

    @silence_warnings
    def test_round_trip_rt(self):
        # Nothing to test on the bare mixin itself.
        if type(self) is TypedListTestCase:
            return

        rebuilt = self.klass.from_dict(self._full_dict).to_dict()
        self.assertEqual(self._full_dict, rebuilt)
193