1import copy
2import random
3
4import numpy as np
5from numpy.testing import assert_array_equal
6
7import tiledb
8from tiledb import TileDBError, core
9from tiledb.tests.common import DiskTestCase, rand_ascii
10
11
12class CoreCCTest(DiskTestCase):
13    def test_pyquery_basic(self):
14        ctx = tiledb.Ctx()
15        uri = self.path("test_pyquery_basic")
16        with tiledb.from_numpy(uri, np.random.rand(4)) as A:
17            pass
18
19        with tiledb.open(uri) as a:
20            with tiledb.scope_ctx({"py.init_buffer_bytes": "abcd"}) as testctx:
21                with self.assertRaises(ValueError):
22                    core.PyQuery(testctx, a, ("",), (), 0, False)
23
24            q = core.PyQuery(ctx, a, ("",), (), 0, False)
25
26            try:
27                q._test_err("bad foo happened")
28            except Exception as exc:
29                assert isinstance(exc, tiledb.TileDBError)
30                assert exc.message == "bad foo happened"
31
32            q.set_ranges([[(0, 3)]])
33
34            with self.assertRaises(TileDBError):
35                q.set_ranges([[(0, 3.0)]])
36
37            q.set_ranges([[(0, np.int32(3))]])
38
39            with self.assertRaises(TileDBError):
40                q.set_ranges([[(3, "a")]])
41
42            with self.assertRaisesRegex(
43                TileDBError,
44                "Failed to cast dim range '\\(1.2344, 5.6789\\)' to dim type UINT64.*$",
45            ):
46                q.set_ranges([[(1.2344, 5.6789)]])
47
48            with self.assertRaisesRegex(
49                TileDBError,
50                "Failed to cast dim range '\\('aa', 'bbbb'\\)' to dim type UINT64.*$",
51            ):
52                q.set_ranges([[("aa", "bbbb")]])
53
54        with tiledb.open(uri) as a:
55            q2 = core.PyQuery(ctx, a, ("",), (), 0, False)
56
57            q2.set_ranges([[(0, 3)]])
58            q2.submit()
59            res = q2.results()[""][0]
60            res.dtype = np.double
61            assert_array_equal(res, a[:])
62
63    def test_pyquery_init(self):
64        uri = self.path("test_pyquery_init")
65        intmax = np.iinfo(np.int64).max
66        config_dict = {
67            "sm.tile_cache_size": "100",
68            "py.init_buffer_bytes": str(intmax),
69            "py.alloc_max_bytes": str(intmax),
70        }
71        with tiledb.scope_ctx(config_dict) as ctx:
72            with tiledb.from_numpy(uri, np.random.rand(4)) as A:
73                pass
74
75            with tiledb.open(uri) as a:
76                q = core.PyQuery(ctx, a, ("",), (), 0, False)
77                self.assertEqual(q._test_init_buffer_bytes, intmax)
78                self.assertEqual(q._test_alloc_max_bytes, intmax)
79
80                with self.assertRaisesRegex(
81                    ValueError,
82                    "Invalid parameter: 'py.alloc_max_bytes' must be >= 1 MB ",
83                ), tiledb.scope_ctx({"py.alloc_max_bytes": 10}) as ctx2:
84                    q = core.PyQuery(ctx2, a, ("",), (), 0, False)
85
86    def test_import_buffer(self):
87        uri = self.path("test_import_buffer")
88
89        def_tile = 1
90        if tiledb.libtiledb.version() < (2, 2):
91            def_tile = 2
92
93        dom = tiledb.Domain(
94            tiledb.Dim(domain=(0, 3), tile=def_tile, dtype=np.int64),
95            tiledb.Dim(domain=(0, 3), tile=def_tile, dtype=np.int64),
96        )
97        attrs = [
98            tiledb.Attr(name="", dtype=np.float64),
99            tiledb.Attr(name="foo", dtype=np.int32),
100            tiledb.Attr(name="str", dtype=str),
101        ]
102        schema = tiledb.ArraySchema(domain=dom, attrs=attrs, sparse=False)
103        tiledb.DenseArray.create(uri, schema)
104
105        data_orig = {
106            "": 2.5 * np.identity(4, dtype=np.float64),
107            "foo": 8 * np.identity(4, dtype=np.int32),
108            "str": np.array(
109                [rand_ascii(random.randint(0, 5)) for _ in range(16)]
110            ).reshape(4, 4),
111        }
112
113        with tiledb.open(uri, "w") as A:
114            A[:] = data_orig
115
116        with tiledb.open(uri) as B:
117            assert_array_equal(B[:][""], data_orig[""]),
118            assert_array_equal(B[:]["foo"], data_orig["foo"])
119
120        data_mod = {
121            "": 5 * np.identity(4, dtype=np.float64),
122            "foo": 32 * np.identity(4, dtype=np.int32),
123            "str": np.array(
124                [rand_ascii(random.randint(1, 7)) for _ in range(16)], dtype="U0"
125            ).reshape(4, 4),
126        }
127
128        str_offsets = np.array(
129            [0] + [len(x) for x in data_mod["str"].flatten()[:-1]], dtype=np.uint64
130        )
131        str_offsets = np.cumsum(str_offsets)
132
133        str_raw = np.array(
134            [ord(c) for c in "".join([x for x in data_mod["str"].flatten()])],
135            dtype=np.uint8,
136        )
137
138        data_mod_bfr = {
139            "": (data_mod[""].flatten().view(np.uint8), np.array([], dtype=np.uint64)),
140            "foo": (
141                data_mod["foo"].flatten().view(np.uint8),
142                np.array([], dtype=np.uint64),
143            ),
144            "str": (str_raw.flatten().view(np.uint8), str_offsets),
145        }
146
147        with tiledb.open(uri) as C:
148            res = C.multi_index[0:3, 0:3]
149            assert_array_equal(res[""], data_orig[""])
150            assert_array_equal(res["foo"], data_orig["foo"])
151            assert_array_equal(res["str"], data_orig["str"])
152
153            C._set_buffers(copy.deepcopy(data_mod_bfr))
154            res = C.multi_index[0:3, 0:3]
155            assert_array_equal(res[""], data_mod[""])
156            assert_array_equal(res["foo"], data_mod["foo"])
157            assert_array_equal(res["str"], data_mod["str"])
158
159        with tiledb.open(uri) as D:
160            D._set_buffers(copy.deepcopy(data_mod_bfr))
161            res = D[:, :]
162            assert_array_equal(res[""], data_mod[""])
163            assert_array_equal(res["foo"], data_mod["foo"])
164            assert_array_equal(res["str"], data_mod["str"])
165
166        with tiledb.DenseArray(uri, mode="r") as E, tiledb.scope_ctx() as ctx:
167            # Ensure that query only returns specified attributes
168            q = core.PyQuery(ctx, E, ("foo",), (), 0, False)
169            q.set_ranges([[(0, 1)]])
170            q.submit()
171            r = q.results()
172            self.assertTrue("foo" in r)
173            self.assertTrue("str" not in r)
174            del q
175