# Copyright (c) 2017, The MITRE Corporation. All rights reserved.
# See LICENSE.txt for complete terms.

import contextlib
import functools
import itertools
import json
import warnings

import cybox.utils
from mixbox.binding_utils import ExternalEncoding
from mixbox.entities import NamespaceCollector
from mixbox.vendor.six import iteritems, text_type

from stix.utils import silence_warnings


@contextlib.contextmanager
def ctx_assert_warnings(self):
    """Context manager for verifying that a block of code has raised a
    warning.

    Args:
        self: An object providing ``assertTrue`` (normally the running
            test case); it is used to report the failure.

    While the block runs, every warning is recorded; the ``'always'`` filter
    disables the usual once-per-location de-duplication. If the block exits
    without raising and no warning was recorded, the assertion fails.

    """
    with warnings.catch_warnings(record=True) as w:
        # Record every warning, even ones already emitted earlier.
        warnings.simplefilter('always')

        # Return to caller
        yield

        # Assert that a warning was raised.
        self.assertTrue(len(w) > 0, "Expected at least one warning to be raised")


def assert_warnings(func):
    """Test function decorator which asserts that a warning has been raised
    during the execution of the test.

    The first positional argument of the decorated function (``self``) is
    used to report the assertion failure.

    """
    @functools.wraps(func)
    def inner(*args, **kwargs):
        self = args[0]
        with ctx_assert_warnings(self):
            return func(*args, **kwargs)

    return inner


def round_trip_dict(cls, dict_):
    """Round-trip ``dict_`` through ``cls`` and return the resulting dict.

    Two conversion paths on ``cls`` are exercised:

    1. ``object_from_dict`` / ``dict_from_object``. The result is discarded,
       so this path only checks that the calls succeed.
    2. ``from_dict`` / ``to_dict``. The result of this path is returned.

    Args:
        cls: The entity class under test.
        dict_: The input dictionary.

    Returns:
        The dictionary produced by ``cls.to_dict(cls.from_dict(dict_))``.

    """
    # Exercise the object_from_dict/dict_from_object API. Its output is
    # intentionally not compared; only the from_dict/to_dict path is.
    obj = cls.object_from_dict(dict_)
    cls.dict_from_object(obj)

    api_obj = cls.from_dict(dict_)
    return cls.to_dict(api_obj)


def round_trip(o, output=False, list_=False):
    """Performs all eight conversions to verify import/export functionality.

    1. Entity -> dict/list
    2. dict/list -> JSON string
    3. JSON string -> dict/list
    4. dict/list -> Entity
    5. Entity -> Bindings Object
    6. Bindings Object -> XML String
    7. XML String -> Bindings Object
    8. Bindings object -> Entity

    Args:
        o: The entity to round-trip.
        output: If True, print the class and the intermediate JSON and XML
            strings.
        list_: If True, use ``to_list``/``from_list`` instead of
            ``to_dict``/``from_dict`` for steps 1 and 4.

    Returns:
        The entity produced by step 8. Tests which call this function can
        check it to ensure nothing was modified during any of the
        transforms.

    Raises:
        KeyError: Re-raised from step 6 after printing the namespace
            information collected in step 5.

    """
    klass = o.__class__
    if output:
        print("Class: ", klass)
        print("-" * 40)

    # 1. Entity -> dict/list
    if list_:
        d = o.to_list()
    else:
        d = o.to_dict()

    # 2. dict/list -> JSON string
    json_string = json.dumps(d)

    if output:
        print(json_string)
        print("-" * 40)

    # Before parsing the JSON, make sure the cache is clear
    cybox.utils.cache_clear()

    # 3. JSON string -> dict/list
    d2 = json.loads(json_string)

    # 4. dict/list -> Entity
    if list_:
        o2 = klass.from_list(d2)
    else:
        o2 = klass.from_dict(d2)

    # 5. Entity -> Bindings Object. The bindings object itself is not used
    # below. This call exercises to_obj() and fills ns_info so that it can be
    # dumped if XML serialization fails.
    ns_info = NamespaceCollector()
    o2.to_obj(ns_info=ns_info)

    # 6. Bindings Object -> XML String
    try:
        xml_string = o2.to_xml(encoding=ExternalEncoding)
    except KeyError as ex:
        # Debugging aid: show the error and the namespaces collected in
        # step 5, then re-raise with the original traceback intact.
        print(str(ex))
        ns_info.finalize()
        print(ns_info.binding_namespaces)
        raise

    # to_xml may return bytes; normalize to text for parsing/printing.
    if not isinstance(xml_string, text_type):
        xml_string = xml_string.decode(ExternalEncoding)

    if output:
        print(xml_string)
        print("-" * 40)

    # Before parsing the XML, make sure the cache is clear
    cybox.utils.cache_clear()

    # 7. XML String -> Bindings Object
    xobj2 = klass._binding.parseString(xml_string)

    # 8. Bindings object -> Entity
    return klass.from_obj(xobj2)


class EntityTestCase(object):
    """A base class for testing STIX Entities.

    It calls ``unittest.TestCase`` assertion methods but derives from
    ``object``, so a subclass must combine it with ``TestCase``. Subclasses
    define ``klass`` (the entity class under test) and ``_full_dict`` (the
    dictionary to round-trip).

    """

    def setUp(self):
        self.assertIsNotNone(self.klass)
        self.assertIsNotNone(self._full_dict)

    @silence_warnings
    def test_round_trip_full_dict(self):
        # Don't run this test on the base class
        if type(self) is EntityTestCase:
            return

        self._test_round_trip_dict(self._full_dict)

    def _combine(self, d):
        """Return a new (shallow) dict: ``_full_dict`` overlaid with ``d``.

        Keys present in both take their value from ``d``.

        """
        items = itertools.chain(
            iteritems(self._full_dict),
            iteritems(d)
        )

        return dict(items)

    @silence_warnings
    def test_round_trip_full(self):
        # Don't run this test on the base class
        if type(self) is EntityTestCase:
            return

        ent = self.klass.from_dict(self._full_dict)
        # The returned entity is not inspected; this test checks that the
        # full eight-step round trip completes without raising.
        round_trip(ent, output=True)

    @silence_warnings
    def _test_round_trip_dict(self, input):
        """Assert that ``input`` survives ``round_trip_dict`` unchanged."""
        dict2 = round_trip_dict(self.klass, input)
        self.maxDiff = None  # show the complete diff on failure
        self.assertEqual(input, dict2)

    @silence_warnings
    def _test_partial_dict(self, partial):
        """Round-trip ``_full_dict`` overlaid with ``partial``."""
        d = self._combine(partial)
        self._test_round_trip_dict(d)


class TypedListTestCase(object):
    """Base class whose test checks that
    ``klass.from_dict(_full_dict).to_dict()`` reproduces ``_full_dict``.

    Subclasses define ``klass`` and ``_full_dict`` and must also derive from
    ``unittest.TestCase``.

    """

    @silence_warnings
    def test_round_trip_rt(self):
        # Don't run this test on the base class
        if type(self) is TypedListTestCase:
            return

        obj = self.klass.from_dict(self._full_dict)
        dict2 = obj.to_dict()
        self.assertEqual(self._full_dict, dict2)