1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package wiremessage
8
9import (
10	"errors"
11	"fmt"
12)
13
14// Compressed represents the OP_COMPRESSED message of the MongoDB wire protocol.
15type Compressed struct {
16	MsgHeader         Header
17	OriginalOpCode    OpCode
18	UncompressedSize  int32
19	CompressorID      CompressorID
20	CompressedMessage []byte
21}
22
23// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
24func (c Compressed) MarshalWireMessage() ([]byte, error) {
25	b := make([]byte, 0, c.Len())
26	return c.AppendWireMessage(b)
27}
28
29// ValidateWireMessage implements the Validator and WireMessage interfaces.
30func (c Compressed) ValidateWireMessage() error {
31	if int(c.MsgHeader.MessageLength) != c.Len() {
32		return errors.New("incorrect header: message length is not correct")
33	}
34
35	if c.MsgHeader.OpCode != OpCompressed {
36		return errors.New("incorrect header: opcode is not OpCompressed")
37	}
38
39	if c.OriginalOpCode != c.MsgHeader.OpCode {
40		return errors.New("incorrect header: original opcode does not match opcode in message header")
41	}
42	return nil
43}
44
45// AppendWireMessage implements the Appender and WireMessage interfaces.
46//
47// AppendWireMessage will set the MessageLength property of MsgHeader if it is 0. It will also set the OpCode to
48// OpCompressed if the OpCode is 0. If either of these properties are non-zero and not correct, this method will return
49// both the []byte with the wire message appended to it and an invalid header error.
50func (c Compressed) AppendWireMessage(b []byte) ([]byte, error) {
51	err := c.MsgHeader.SetDefaults(c.Len(), OpCompressed)
52
53	b = c.MsgHeader.AppendHeader(b)
54	b = appendInt32(b, int32(c.OriginalOpCode))
55	b = appendInt32(b, c.UncompressedSize)
56	b = append(b, byte(c.CompressorID))
57	b = append(b, c.CompressedMessage...)
58
59	return b, err
60}
61
62// String implements the fmt.Stringer interface.
63func (c Compressed) String() string {
64	return fmt.Sprintf(
65		`OP_COMPRESSED{MsgHeader: %s, Uncompressed Size: %d, CompressorId: %d, Compressed message: %s}`,
66		c.MsgHeader, c.UncompressedSize, c.CompressorID, c.CompressedMessage,
67	)
68}
69
70// Len implements the WireMessage interface.
71func (c Compressed) Len() int {
72	// Header + OpCode + UncompressedSize + CompressorId + CompressedMessage
73	return 16 + 4 + 4 + 1 + len(c.CompressedMessage)
74}
75
76// UnmarshalWireMessage implements the Unmarshaler interface.
77func (c *Compressed) UnmarshalWireMessage(b []byte) error {
78	var err error
79	c.MsgHeader, err = ReadHeader(b, 0)
80	if err != nil {
81		return err
82	}
83
84	if len(b) < int(c.MsgHeader.MessageLength) {
85		return Error{Type: ErrOpCompressed, Message: "[]byte too small"}
86	}
87
88	c.OriginalOpCode = OpCode(readInt32(b, 16)) // skip first 16 for header
89	c.UncompressedSize = readInt32(b, 20)
90	c.CompressorID = CompressorID(b[24])
91
92	// messageLength - Header - OpCode - UncompressedSize - CompressorId
93	msgLen := c.MsgHeader.MessageLength - 16 - 4 - 4 - 1
94	c.CompressedMessage = b[25 : 25+msgLen]
95
96	return nil
97}
98
// CompressorID is the one-byte identifier of a Compressor type.
type CompressorID uint8

// Compressor IDs that can be carried by an OP_COMPRESSED message. The numeric
// values are spelled out because they are written to the wire as-is.
const (
	CompressorNoOp   CompressorID = 0
	CompressorSnappy CompressorID = 1
	CompressorZLib   CompressorID = 2
)
108
// DefaultZlibLevel is the compression level treated as the default for zlib.
// The value 6 is the level zlib documents as equivalent to its own default
// compression setting.
const DefaultZlibLevel = 6
111