1import struct
2import sys
3import unittest
4
5import zstandard as zstd
6
7from .common import (
8    generate_samples,
9    make_cffi,
10    random_input_data,
11    TestCase,
12)
13
# Integer type that the tests check dictionary IDs against: ``int`` on
# Python 3, ``long`` on Python 2. ``long`` is only evaluated when running
# under Python 2, so this never raises NameError on Python 3.
int_type = int if sys.version_info[0] >= 3 else long  # noqa: F821
18
19
@make_cffi
class TestTrainDictionary(TestCase):
    """Tests for ``zstd.train_dictionary()`` argument handling and output."""

    # Leading bytes every trained dictionary is expected to start with
    # (the zstd dictionary magic number 0xEC30A437, little-endian).
    _DICT_MAGIC = b"\x37\xa4\x30\xec"

    def test_no_args(self):
        # Calling with no arguments at all is a TypeError.
        self.assertRaises(TypeError, zstd.train_dictionary)

    def test_bad_args(self):
        # A bare text string as the samples argument raises TypeError...
        self.assertRaises(TypeError, zstd.train_dictionary, 8192, u"foo")

        # ...while a list holding a text string raises ValueError.
        self.assertRaises(ValueError, zstd.train_dictionary, 8192, [u"foo"])

    def test_no_params(self):
        trained = zstd.train_dictionary(8192, random_input_data())
        dict_id = trained.dict_id()
        self.assertIsInstance(dict_id, int_type)

        # The dictionary ID may be different across platforms, so the
        # expected 8-byte header is built from whatever ID was assigned.
        header = trained.as_bytes()[:8]
        self.assertEqual(header, self._DICT_MAGIC + struct.pack("<I", dict_id))

    def test_basic(self):
        trained = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
        self.assertIsInstance(trained.dict_id(), int_type)
        self.assertEqual(trained.as_bytes()[:4], self._DICT_MAGIC)

        # The explicitly requested k/d values are reported back.
        self.assertEqual(trained.k, 64)
        self.assertEqual(trained.d, 16)

    def test_set_dict_id(self):
        trained = zstd.train_dictionary(
            8192, generate_samples(), k=64, d=16, dict_id=42
        )
        # A caller-supplied dict_id is what dict_id() reports.
        self.assertEqual(trained.dict_id(), 42)

    def test_optimize(self):
        trained = zstd.train_dictionary(
            8192, generate_samples(), threads=-1, steps=1, d=16
        )

        # k is not passed here, so the trainer chooses it; the chosen value
        # varies by platform. d was pinned explicitly.
        self.assertIn(trained.k, (50, 2000))
        self.assertEqual(trained.d, 16)
67
68
@make_cffi
class TestCompressionDict(TestCase):
    """Tests for ``zstd.ZstdCompressionDict`` construction and precomputation."""

    def test_bad_mode(self):
        # An unrecognized dict_type value (42) is rejected.
        with self.assertRaisesRegex(ValueError, "invalid dictionary load mode"):
            zstd.ZstdCompressionDict(b"foo", dict_type=42)

    def test_bad_precompute_compress(self):
        trained = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)

        # Neither level nor compression_params supplied.
        with self.assertRaisesRegex(
            ValueError, "must specify one of level or "
        ):
            trained.precompute_compress()

        # Both level and compression_params supplied at once. The parameters
        # object is created inside the context so the exception scope matches
        # the single-expression form.
        with self.assertRaisesRegex(
            ValueError, "must only specify one of level or "
        ):
            params = zstd.CompressionParameters()
            trained.precompute_compress(level=3, compression_params=params)

    def test_precompute_compress_rawcontent(self):
        content = b"dictcontent" * 64

        # Arbitrary bytes loaded as raw content can be precomputed.
        raw_dict = zstd.ZstdCompressionDict(
            content, dict_type=zstd.DICT_TYPE_RAWCONTENT
        )
        raw_dict.precompute_compress(level=1)

        # The same bytes loaded as a full dictionary fail precomputation
        # (presumably because they lack a dictionary header).
        full_dict = zstd.ZstdCompressionDict(
            content, dict_type=zstd.DICT_TYPE_FULLDICT
        )
        with self.assertRaisesRegex(
            zstd.ZstdError, "unable to precompute dictionary"
        ):
            full_dict.precompute_compress(level=1)
103