1package msgpackzip
2
3import (
4	"bytes"
5	"fmt"
6	"sort"
7)
8
// compressor holds the state for one compression (or frequency-report) pass
// over a single msgpack-encoded input buffer.
type compressor struct {
	// input is the msgpack-encoded payload. It is decoded once to count
	// frequencies and, when compressing, a second time to write the output.
	input          []byte
	// valueWhitelist decides which string and binary values are counted and
	// therefore eligible for replacement.
	valueWhitelist ValueWhitelist
	// collectMapKeys controls whether map keys are counted during the
	// frequency pass (false when only value frequencies are being reported).
	collectMapKeys bool
}
14
15func newCompressor(b []byte, wl ValueWhitelist, c bool) *compressor {
16	return &compressor{input: b, valueWhitelist: wl, collectMapKeys: c}
17}
18
// ValueWhitelist can be used to specify which values can be compressed.
// Values are either strings or binary []byte arrays. NewValueWhitelist returns
// a whitelist with both sets already allocated.
type ValueWhitelist struct {
	// strings is the set of string values allowed to be compressed.
	strings     map[string]bool
	// binaries is the set of binary values allowed to be compressed, keyed by
	// their raw bytes converted to a string.
	binaries    map[string]bool
	// allValuesOk, when true, makes every string and binary value count as
	// whitelisted regardless of the two sets above.
	allValuesOk bool
}
26
27// NewValueWhitelist makes an empty value white list, initialzed with empty
28// lists.
29func NewValueWhitelist() *ValueWhitelist {
30	return &ValueWhitelist{
31		strings:  make(map[string]bool),
32		binaries: make(map[string]bool),
33	}
34}
35
36// AddString adds a string to the value whitelist
37func (v *ValueWhitelist) AddString(s string) {
38	v.strings[s] = true
39}
40
41// AddBinary adds a binary buffer to the value whitelist
42func (v *ValueWhitelist) AddBinary(b []byte) {
43	v.binaries[string(b)] = true
44}
45
46func (v *ValueWhitelist) hasString(s string) bool {
47	if v.allValuesOk {
48		return true
49	}
50	return v.strings[s]
51}
52
53func (v *ValueWhitelist) hasBinary(b []byte) bool {
54	if v.allValuesOk {
55		return true
56	}
57	return v.binaries[string(b)]
58}
59
60// CompressWithWhitelist takes as input a msgpack encoded payload,
61// and also a whitelist of values that it's OK to compress.  It then compresses
62// all map keys and values in the given whitelist, and returns a compression,
63// or an error on error.
64func CompressWithWhitelist(input []byte, wl ValueWhitelist) (output []byte, err error) {
65	return newCompressor(input, wl, true).run()
66}
67
68// Compress the given msgpack encoding, compressing only static map keys, and
69// not compressing any values.
70func Compress(input []byte) (output []byte, err error) {
71	return newCompressor(input, *NewValueWhitelist(), true).run()
72}
73
74// ReportValuesFrequencies takes as input a msgpack encoding, and reports
75// which values are the most fequent in the encoding. It returns a list
76// of Frequency objects, sorted from most frequent to least frequent.
77func ReportValuesFrequencies(input []byte) (ret []Frequency, err error) {
78	wl := NewValueWhitelist()
79	wl.allValuesOk = true
80	return newCompressor(input, *wl, false).collectAndSortFrequencies()
81}
82
83// collectAndSortFrequencies decodes and descends the input buffer, making
84// an list of frequencies of map keys (and values if we have an active whitelist)
85// and returns a sorted list of those values, from most frequent to least frequent.
86func (c *compressor) collectAndSortFrequencies() (ret []Frequency, err error) {
87
88	freqs, err := c.collectFrequencies()
89	if err != nil {
90		return nil, err
91	}
92	freqsSorted, err := c.sortFrequencies(freqs)
93	if err != nil {
94		return nil, err
95	}
96	return freqsSorted, nil
97}
98
99// run the compressor on the given input. Can be used only once
100// per instantiation of compressor object.
101func (c *compressor) run() (output []byte, err error) {
102
103	freqsSorted, err := c.collectAndSortFrequencies()
104	if err != nil {
105		return nil, err
106	}
107	keys, err := c.frequenciesToMap(freqsSorted)
108	if err != nil {
109		return nil, err
110	}
111	output, err = c.output(freqsSorted, keys)
112	return output, err
113}
114
// BinaryMapKey is a wrapper around a []byte vector of binary data so that it
// can be stored as an interface{} and differentiated from proper strings.
// Because it is a distinct type, a binary value and a string with the same
// bytes are separate keys in an interface{}-keyed map.
type BinaryMapKey string
118
119// collectFrequencies descends the input msgpack encoding and collects
120// the frequencies of map keys and values on the white list. It returns a map
121// of (key or value) to the number of times it shows up in the object.
122// The map is of type `map[interface{}]int`, which the `interface{}` can
123// be an int64, a plain old string, or a binary []byte buffer wrapped in a
124// BinaryMapKey.
125func (c *compressor) collectFrequencies() (ret map[interface{}]int, err error) {
126
127	ret = make(map[interface{}]int)
128	hooks := msgpackDecoderHooks{
129		mapKeyHook: func(d decodeStack) (decodeStack, error) {
130			d.hooks = msgpackDecoderHooks{
131				stringHook: func(l msgpackInt, s string) error {
132					if c.collectMapKeys {
133						ret[s]++
134					}
135					return nil
136				},
137				intHook: func(l msgpackInt) error {
138					i, err := l.toInt64()
139					if err != nil {
140						return err
141					}
142					if c.collectMapKeys {
143						ret[i]++
144					}
145					return nil
146				},
147				fallthroughHook: func(i interface{}, s string) error {
148					return fmt.Errorf("bad map key (type %T)", i)
149				},
150			}
151			return d, nil
152		},
153		stringHook: func(l msgpackInt, s string) error {
154			if c.valueWhitelist.hasString(s) {
155				ret[s]++
156			}
157			return nil
158		},
159		binaryHook: func(l msgpackInt, b []byte) error {
160			s := string(b)
161			if c.valueWhitelist.hasBinary(b) {
162				ret[BinaryMapKey(s)]++
163			}
164			return nil
165		},
166	}
167	err = newMsgpackDecoder(bytes.NewReader(c.input)).run(hooks)
168	if err != nil {
169		return nil, err
170	}
171	return ret, nil
172}
173
// Frequency is a tuple, with a `Key interface{}` that can be an int64, a string,
// or a BinaryMapKey (which is a wrapper around a binary buffer). The `Freq` field
// is a count for how many times the `Key` shows up in the encoded msgpack object.
type Frequency struct {
	// Key is the counted item: an int64 or string map key, a string value,
	// or a BinaryMapKey-wrapped binary value.
	Key  interface{}
	// Freq is the number of times Key occurred.
	Freq int
}
181
182// sortFrequencies converts a map of (keys -> counts) into an ordered vector of frequencies.
183func (c *compressor) sortFrequencies(freqs map[interface{}]int) (ret []Frequency, err error) {
184
185	ret = make([]Frequency, len(freqs))
186	var i int
187	for k, v := range freqs {
188		ret[i] = Frequency{k, v}
189		i++
190	}
191	sort.SliceStable(ret, func(i, j int) bool { return ret[i].Freq > ret[j].Freq })
192	return ret, nil
193}
194
195// frequenciesToMap converts a sorted vectors of frequencies to a map (key -> uint),
196// where the RHS values are ordered 0 to N. The idea is that the most frequent
197// keys get ths smallest values, which take of the least space when msgpack encoded.
198// This function returns the "keyMap" refered to later.
199func (c *compressor) frequenciesToMap(freqs []Frequency) (keys map[interface{}]uint, err error) {
200	ret := make(map[interface{}]uint, len(freqs))
201	for i, freq := range freqs {
202		ret[freq.Key] = uint(i)
203	}
204	return ret, nil
205}
206
207// output the data, the compressed keymap, and the version byte, which is the whole
208// encodeded compressed output.
209func (c *compressor) output(freqsSorted []Frequency, keys map[interface{}]uint) (output []byte, err error) {
210
211	version := Version(1)
212	data, err := c.outputData(keys)
213	if err != nil {
214		return nil, err
215	}
216	compressedKeymap, err := c.outputCompressedKeymap(freqsSorted)
217	if err != nil {
218		return nil, err
219	}
220	return c.outputFinalProduct(version, data, compressedKeymap)
221}
222
223// outputData, replacing all map Keys with their corresponding uints in the
224// keyMap. If we come across white-listed values, replace them with an
225// "external marker", followed by their position in the keyMap.
226func (c *compressor) outputData(keys map[interface{}]uint) (output []byte, err error) {
227
228	var data outputter
229
230	hooks := data.decoderHooks()
231
232	// mapKeys are rewritten to be uints that appear in the
233	// keyMap passed in.
234	hooks.mapKeyHook = func(d decodeStack) (decodeStack, error) {
235		d.hooks = msgpackDecoderHooks{
236			intHook: func(l msgpackInt) error {
237				i, err := l.toInt64()
238				if err != nil {
239					return err
240				}
241				val, ok := keys[i]
242				if !ok {
243					return fmt.Errorf("unexpected map key: %v", i)
244				}
245				return data.outputRawUint(val)
246			},
247			stringHook: func(l msgpackInt, s string) error {
248				val, ok := keys[s]
249				if !ok {
250					return fmt.Errorf("unexpected map key: %q", s)
251				}
252				return data.outputRawUint(val)
253			},
254			fallthroughHook: func(i interface{}, s string) error {
255				return fmt.Errorf("bad map key (type %T)", i)
256			},
257		}
258		return d, nil
259	}
260
261	// strings are rewritten if they are on the whitelist
262	hooks.stringHook = func(l msgpackInt, s string) error {
263		val, ok := keys[s]
264		if ok {
265			return data.outputExtUint(val)
266		}
267		return data.outputString(l, s)
268	}
269
270	// binary buffers are rewritten if they are on the whitelist
271	hooks.binaryHook = func(l msgpackInt, b []byte) error {
272		val, ok := keys[BinaryMapKey(string(b))]
273		if ok {
274			return data.outputExtUint(val)
275		}
276		return data.outputBinary(l, b)
277	}
278
279	// external data types are output and aren't allowed in inputs
280	hooks.extHook = func(b []byte) error {
281		return fmt.Errorf("cannot handle external data types")
282	}
283
284	err = newMsgpackDecoder(bytes.NewReader(c.input)).run(hooks)
285	if err != nil {
286		return nil, err
287	}
288
289	return data.Bytes(), nil
290}
291
292// outputCompressedKeymap msgpack encodes the keymap and then runs
293// `flate.Compress` on the output (which is gzip without the headers).
294// We're hand-encoding this map using our msgpack encoder. Note that we're
295// not compressing the keymap directly, but rather the frequence array
296// that we derive the keymap from. This is so that we get determinstic
297// output, since ranging of a map in Go is non-deterministic and randomized.
298func (c *compressor) outputCompressedKeymap(freqsSorted []Frequency) (output []byte, err error) {
299
300	var keymap outputter
301
302	// Now write out a msgpack dictionary for the keymaps;
303	// do it but hand, that's simplest for now, rather than pulling
304	// in a new encoder.
305	err = keymap.outputMapPrefix(msgpackIntFromUint(uint(len(freqsSorted))))
306	if err != nil {
307		return nil, err
308	}
309	for i, v := range freqsSorted {
310		// Note that we reverse the map to make decoding easier
311		err = keymap.outputInt(msgpackIntFromUint(uint(i)))
312		if err != nil {
313			return nil, err
314		}
315		err = keymap.outputStringOrUintOrBinary(v.Key)
316		if err != nil {
317			return nil, err
318		}
319	}
320	tmp := keymap.Bytes()
321
322	compressedKeymap, err := flateCompress(tmp)
323	if err != nil {
324		return nil, err
325	}
326
327	return compressedKeymap, nil
328}
329
// Version is the format version number written as the first element of the
// compressed output's top-level array (see outputFinalProduct). The
// compressor in this file writes version 1.
type Version int
331
332// outputFinalProduct is the final pass output routine. It outputs the wrapper
333// 3-value array, the version prefix, the encoded data, and the compressed, encoded
334// keyMap.
335func (c *compressor) outputFinalProduct(version Version, data []byte, compressedKeymap []byte) (output []byte, err error) {
336
337	var ret outputter
338
339	// 3 elements in the array, so output '3'
340	err = ret.outputArrayPrefix(msgpackIntFromUint(uint(3)))
341	if err != nil {
342		return nil, err
343	}
344	err = ret.outputInt(msgpackIntFromUint(uint(version)))
345	if err != nil {
346		return nil, err
347	}
348	err = ret.outputBinary(msgpackIntFromUint(uint(len(data))), data)
349	if err != nil {
350		return nil, err
351	}
352	err = ret.outputBinary(msgpackIntFromUint(uint(len(compressedKeymap))), compressedKeymap)
353	if err != nil {
354		return nil, err
355	}
356
357	return ret.Bytes(), nil
358}
359