1# Copyright 2018 MongoDB, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import warnings
16
17try:
18    import snappy
19    _HAVE_SNAPPY = True
20except ImportError:
21    # python-snappy isn't available.
22    _HAVE_SNAPPY = False
23
24try:
25    import zlib
26    _HAVE_ZLIB = True
27except ImportError:
28    # Python built without zlib support.
29    _HAVE_ZLIB = False
30
31try:
32    from zstandard import ZstdCompressor, ZstdDecompressor
33    _HAVE_ZSTD = True
34except ImportError:
35    _HAVE_ZSTD = False
36
37from pymongo.hello_compat import HelloCompat
38from pymongo.monitoring import _SENSITIVE_COMMANDS
39
# Compressor names this module recognizes (see validate_compressors).
_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
# Command names excluded from compression: the two HelloCompat command names
# plus everything in _SENSITIVE_COMMANDS. Presumably consulted by the message
# layer elsewhere when deciding whether to compress -- confirm at call sites.
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
_NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
43
44
def validate_compressors(dummy, value):
    """Return the subset of requested compressors that can actually be used.

    :Parameters:
      - `dummy`: the option name; unused.
      - `value`: either a comma separated string (e.g. ``"snappy,zlib"``) or
        any iterable of compressor names.

    :Returns: a new ``list`` holding the requested names in their original
      order, minus every name that is not in ``_SUPPORTED_COMPRESSORS`` or
      whose backing module failed to import. A warning is emitted for each
      name that is dropped.
    """
    try:
        # String input: split on commas (no whitespace stripping is done).
        requested = value.split(",")
    except AttributeError:
        # Non-string input: accept any iterable of names.
        requested = list(value)

    # Known compressors whose supporting module could not be imported,
    # mapped to the warning explaining why they are skipped.
    missing = {}
    if not _HAVE_SNAPPY:
        missing["snappy"] = (
            "Wire protocol compression with snappy is not available. "
            "You must install the python-snappy module for snappy support.")
    if not _HAVE_ZLIB:
        missing["zlib"] = (
            "Wire protocol compression with zlib is not available. "
            "The zlib module is not available.")
    if not _HAVE_ZSTD:
        missing["zstd"] = (
            "Wire protocol compression with zstandard is not available. "
            "You must install the zstandard module for zstandard support.")

    usable = []
    for name in requested:
        if name not in _SUPPORTED_COMPRESSORS:
            warnings.warn("Unsupported compressor: %s" % (name,))
        elif name in missing:
            warnings.warn(missing[name])
        else:
            usable.append(name)
    return usable
73
74
def validate_zlib_compression_level(option, value):
    """Validate and normalize a zlib compression level option.

    :Parameters:
      - `option`: the option name, used only in error messages.
      - `value`: anything ``int()`` accepts (an int, a numeric string, ...).

    :Returns: ``int(value)``, guaranteed to lie in the range -1..9
      (-1 selects zlib's default level).

    :Raises:
      - ``TypeError`` if `value` cannot be converted with ``int()``.
      - ``ValueError`` if the converted level is outside -1..9.
    """
    try:
        level = int(value)
    except (TypeError, ValueError, OverflowError) as exc:
        # Only the errors int() raises for unconvertible input are mapped to
        # TypeError. A bare ``except:`` would also swallow KeyboardInterrupt
        # and SystemExit.
        raise TypeError(
            "%s must be an integer, not %r." % (option, value)) from exc
    if level < -1 or level > 9:
        raise ValueError(
            "%s must be between -1 and 9, not %d." % (option, level))
    return level
84
85
class CompressionSettings(object):
    """Compression configuration plus a factory for compression contexts.

    :Parameters:
      - `compressors`: the configured compressor names (presumably the
        output of ``validate_compressors`` -- confirm at the call site).
      - `zlib_compression_level`: level handed to ``ZlibContext`` whenever
        zlib is the chosen compressor.
    """

    def __init__(self, compressors, zlib_compression_level):
        self.compressors = compressors
        self.zlib_compression_level = zlib_compression_level

    def get_compression_context(self, compressors):
        """Build a context for the first name in `compressors`.

        Note that this looks only at the `compressors` argument, not at
        ``self.compressors``.

        :Returns: a ``SnappyContext``, ``ZlibContext`` or ``ZstdContext``.
          Returns ``None`` if `compressors` is empty/falsy or if its first
          entry is not one of those three names.
        """
        if not compressors:
            return None
        first = compressors[0]
        if first == "snappy":
            return SnappyContext()
        if first == "zlib":
            return ZlibContext(self.zlib_compression_level)
        if first == "zstd":
            return ZstdContext()
        return None
100
101
102def _zlib_no_compress(data):
103    """Compress data with zlib level 0."""
104    cobj = zlib.compressobj(0)
105    return b"".join([cobj.compress(data), cobj.flush()])
106
107
class SnappyContext(object):
    """Compression context backed by python-snappy."""

    # Numeric id that decompress() compares against to select snappy.
    compressor_id = 1

    @staticmethod
    def compress(data):
        """Return `data` compressed with ``snappy.compress``."""
        compressed = snappy.compress(data)
        return compressed
114
115
class ZlibContext(object):
    """Compression context backed by the stdlib zlib module.

    :Parameters:
      - `level`: zlib compression level; -1 means zlib's default level.
    """

    # Numeric id that decompress() compares against to select zlib.
    compressor_id = 2

    def __init__(self, level):
        if level == -1:
            # Jython's zlib.compress rejects an explicit -1, so fall back to
            # calling it without a level (the default level).
            self.compress = zlib.compress
        elif level == 0:
            # Jython's zlib.compress also rejects level 0; use compressobj.
            self.compress = _zlib_no_compress
        else:
            def _compress_at_level(data):
                return zlib.compress(data, level)

            self.compress = _compress_at_level
128
129
class ZstdContext(object):
    """Compression context backed by the zstandard package."""

    # Numeric id that decompress() compares against to select zstd.
    compressor_id = 3

    @staticmethod
    def compress(data):
        """Return `data` compressed with a freshly created ``ZstdCompressor``."""
        # A ZstdCompressor is not thread safe, so one is built per call.
        # TODO: Use a pool?
        compressor = ZstdCompressor()
        return compressor.compress(data)
138
139
def decompress(data, compressor_id):
    """Decompress a payload produced by the compressor identified by `compressor_id`.

    :Parameters:
      - `data`: the compressed payload; may be ``bytes`` or a ``memoryview``.
      - `compressor_id`: one of ``SnappyContext.compressor_id``,
        ``ZlibContext.compressor_id`` or ``ZstdContext.compressor_id``.

    :Returns: the decompressed payload.

    :Raises: ``ValueError`` if `compressor_id` matches none of the above.
    """
    if compressor_id == SnappyContext.compressor_id:
        # python-snappy does not accept objects that expose the buffer
        # interface (https://github.com/andrix/python-snappy/issues/65), so
        # copy to bytes first. That copy only costs anything for a
        # memoryview: bytes(data) returns the same object when data is
        # already bytes.
        # NOTE: on Python 2.7 bytes(memoryview) returns the memoryview repr;
        # memoryview.tobytes() would be correct there, but memoryview is
        # currently only used on Python 3.x.
        raw = bytes(data)
        return snappy.uncompress(raw)
    if compressor_id == ZlibContext.compressor_id:
        return zlib.decompress(data)
    if compressor_id == ZstdContext.compressor_id:
        # A ZstdDecompressor is not thread safe, so one is built per call.
        # TODO: Use a pool?
        decompressor = ZstdDecompressor()
        return decompressor.decompress(data)
    raise ValueError("Unknown compressorId %d" % (compressor_id,))
159