1import pytest
2
3from stix2.datastore import CompositeDataSource, make_id
4from stix2.datastore.filters import Filter
5from stix2.datastore.memory import MemorySink, MemorySource, MemoryStore
6from stix2.utils import parse_into_datetime
7from stix2.v21.common import TLP_GREEN
8
9
def test_add_remove_composite_datasource():
    """Check CompositeDataSource source bookkeeping.

    Adding a batch that contains a non-DataSource (a MemorySink) raises
    TypeError with a specific message; re-adding a source that is already
    present does not create a second entry; removing by id empties it.
    """
    composite = CompositeDataSource()
    source_a = MemorySource()
    source_b = MemorySource()
    not_a_source = MemorySink()

    expected_message = (
        "DataSource (to be added) is not of type "
        "stix2.DataSource. DataSource type is '<class 'stix2.datastore.memory.MemorySink'>'"
    )
    with pytest.raises(TypeError) as excinfo:
        composite.add_data_sources([source_a, source_b, source_a, not_a_source])
    assert str(excinfo.value) == expected_message

    # source_a is listed twice; only one copy of it may be registered.
    composite.add_data_sources([source_a, source_b, source_a])
    assert len(composite.get_all_data_sources()) == 2

    composite.remove_data_sources([source_a.id, source_b.id])
    assert len(composite.get_all_data_sources()) == 0
30
31
def test_composite_datasource_operations(stix_objs1, stix_objs2):
    """Exercise get/all_versions/query on a CompositeDataSource.

    The outer composite first holds two MemorySources. A second, identical
    composite is then nested inside it. Lookups are checked for an indicator
    and a URL SCO. A composite-level filter is then attached, and the
    lookups are re-checked with that filter in place.
    """
    indicator_id = "indicator--00000000-0000-4000-8000-000000000001"
    url_id = "url--cc1deced-d99b-4d72-9268-8182420cb2fd"
    expected_modified = parse_into_datetime("2017-01-31T13:49:53.935Z")

    def assert_is_expected_indicator(obj):
        # get() must resolve to the version with the expected 'modified' time.
        assert obj["id"] == indicator_id
        assert obj["modified"] == expected_modified
        assert obj["type"] == "indicator"

    bundle = {
        "id": "bundle--{}".format(make_id()),
        "objects": stix_objs1,
        "type": "bundle",
    }

    outer = CompositeDataSource()
    outer_members = [
        MemorySource(stix_data=bundle),
        MemorySource(stix_data=stix_objs2),
    ]

    inner = CompositeDataSource()
    inner_members = [
        MemorySource(stix_data=bundle),
        MemorySource(stix_data=stix_objs2),
    ]

    outer.add_data_sources(outer_members)
    inner.add_data_sources(inner_members)

    # stix_objs2 carries copies of this indicator with a later 'modified'
    # time; together the two sources expose three versions of it.
    assert len(outer.all_versions(indicator_id)) == 3

    # Nest the second composite inside the first.
    outer.add_data_sources([inner])

    assert_is_expected_indicator(outer.get(indicator_id))

    url_obj = outer.get(url_id)
    assert url_obj["id"] == url_id

    url_versions = outer.all_versions(url_id)
    assert len(url_versions) == 1
    assert url_versions[0]["id"] == url_id

    matched_by_value = outer.query([Filter("value", "=", "http://example.com/")])
    assert len(matched_by_value) == 1
    assert matched_by_value[0]["id"] == url_id

    # Attach a filter to the composite itself; subsequent queries carry it.
    outer.filters.add([Filter("valid_from", "=", "2017-01-27T13:49:53.935382Z")])

    typed = outer.query([Filter("type", "=", "indicator")])
    assert len(typed) == 4

    assert_is_expected_indicator(outer.get(indicator_id))

    assert len(outer.all_versions(indicator_id)) == 3

    # With no per-query filters only the composite's valid_from filter is in
    # effect, which yields the same count as the type-filtered query above.
    assert len(outer.query([])) == 4
102
103
def test_source_markings():
    """A MemorySource seeded with a marking definition serves it back by id."""
    source = MemorySource(TLP_GREEN)
    marking_id = TLP_GREEN.id

    assert source.get(marking_id) == TLP_GREEN
    assert source.all_versions(marking_id) == [TLP_GREEN]
    assert source.query(Filter("id", "=", marking_id)) == [TLP_GREEN]
110
111
def test_sink_markings():
    """Smoke test: seed a MemorySink with a marking and add it again.

    Nothing is asserted; the test passes as long as neither call raises.
    """
    sink = MemorySink(TLP_GREEN)
    sink.add(TLP_GREEN)
116
117
def test_store_markings():
    """A MemoryStore seeded with a marking definition serves it back by id."""
    store = MemoryStore(TLP_GREEN)
    marking_id = TLP_GREEN.id

    assert store.get(marking_id) == TLP_GREEN
    assert store.all_versions(marking_id) == [TLP_GREEN]
    assert store.query(Filter("id", "=", marking_id)) == [TLP_GREEN]
124
125
def test_source_mixed(indicator):
    """A MemorySource holding a marking plus the `indicator` fixture serves
    each object back by id, and an unfiltered query returns exactly both."""
    source = MemorySource([TLP_GREEN, indicator])

    for expected in (TLP_GREEN, indicator):
        obj_id = expected.id
        assert source.get(obj_id) == expected
        assert source.all_versions(obj_id) == [expected]
        assert source.query(Filter("id", "=", obj_id)) == [expected]

    everything = source.query()
    assert TLP_GREEN in everything
    assert indicator in everything
    assert len(everything) == 2
141
142
def test_sink_mixed(indicator):
    """Smoke test: a MemorySink accepts a marking/indicator mix both at
    construction and through add(); nothing is asserted beyond not raising."""
    contents = (TLP_GREEN, indicator)
    sink = MemorySink(list(contents))
    sink.add(list(contents))
147
148
def test_store_mixed(indicator):
    """A MemoryStore holding a marking plus the `indicator` fixture serves
    each object back by id, and an unfiltered query returns exactly both."""
    store = MemoryStore([TLP_GREEN, indicator])

    for expected in (TLP_GREEN, indicator):
        obj_id = expected.id
        assert store.get(obj_id) == expected
        assert store.all_versions(obj_id) == [expected]
        assert store.query(Filter("id", "=", obj_id)) == [expected]

    everything = store.query()
    assert TLP_GREEN in everything
    assert indicator in everything
    assert len(everything) == 2
164