1/*-
2 * Copyright 2014 Square Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package jose
18
import (
	"bytes"
	"compress/flate"
	"encoding/base64"
	"encoding/binary"
	"fmt"
	"io"
	"math/big"
	"regexp"

	"gopkg.in/square/go-jose.v2/json"
)
30
// stripWhitespaceRegex matches a single character of RE2's \s class
// (space, \t, \n, \f, \r). It is compiled once at package initialization so
// that stripWhitespace never recompiles it.
var stripWhitespaceRegex = regexp.MustCompile(`\s`)
32
// mustSerializeJSON marshals a known-good value to JSON and panics on failure.
//
// Precondition: value must not be a nil pointer. A nil pointer marshals to the
// top-level JSON literal "null", which is never a valid JOSE message. If that
// output were embedded in a larger operation (for example, base64-encoded and
// fed to a signing algorithm, see https://github.com/square/go-jose/issues/22),
// the result would be silently wrong. A nil pointer is not a known-good object,
// so this helper panics rather than returning "null".
//
// The marshaled output is inspected instead of the argument itself because an
// interface holding a typed nil pointer does not compare equal to nil; see
// https://groups.google.com/forum/#!topic/golang-nuts/wnH302gBa4I
func mustSerializeJSON(value interface{}) []byte {
	serialized, err := json.Marshal(value)
	switch {
	case err != nil:
		panic(err)
	case string(serialized) == "null":
		panic("Tried to serialize a nil pointer.")
	}
	return serialized
}
56
57// Strip all newlines and whitespace
58func stripWhitespace(data string) string {
59	return stripWhitespaceRegex.ReplaceAllString(data, "")
60}
61
62// Perform compression based on algorithm
63func compress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) {
64	switch algorithm {
65	case DEFLATE:
66		return deflate(input)
67	default:
68		return nil, ErrUnsupportedAlgorithm
69	}
70}
71
72// Perform decompression based on algorithm
73func decompress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) {
74	switch algorithm {
75	case DEFLATE:
76		return inflate(input)
77	default:
78		return nil, ErrUnsupportedAlgorithm
79	}
80}
81
// deflate compresses input with compress/flate (raw DEFLATE, RFC 1951) at
// flate.BestSpeed, the level the literal 1 previously selected.
//
// Every error from the compressor is propagated instead of being discarded.
// With a bytes.Buffer sink they are not expected in practice, but ignoring
// them would silently yield truncated output if the sink ever changed. On
// error no partial output is returned.
func deflate(input []byte) ([]byte, error) {
	output := new(bytes.Buffer)

	writer, err := flate.NewWriter(output, flate.BestSpeed)
	if err != nil {
		return nil, err
	}
	if _, err := writer.Write(input); err != nil {
		return nil, err
	}
	// Close flushes any buffered data and writes the final block.
	if err := writer.Close(); err != nil {
		return nil, err
	}
	return output.Bytes(), nil
}
93
// inflate decompresses raw DEFLATE (RFC 1951) data.
//
// NOTE(review): the input presumably originates from a received JOSE message
// (e.g. a JWE with a DEFLATE "zip" header) and so may be attacker-controlled.
// A small input can expand to an enormous output (a "decompression bomb"). To
// bound memory use, the output is capped at max(250000, 10*len(input)) bytes.
// Anything larger is rejected with an error instead of being fully expanded.
func inflate(input []byte) ([]byte, error) {
	output := new(bytes.Buffer)
	reader := flate.NewReader(bytes.NewReader(input))

	maxDecompressedSize := 10 * int64(len(input))
	if maxDecompressedSize < 250000 {
		maxDecompressedSize = 250000
	}

	// Read one byte past the cap so "exactly at the cap" (allowed) can be
	// distinguished from "over the cap" (rejected).
	limit := maxDecompressedSize + 1
	n, err := io.CopyN(output, reader, limit)
	if err != nil && err != io.EOF {
		// io.EOF only means the stream ended before the limit, which is the
		// normal case; anything else (e.g. a corrupt stream) is a real error.
		return nil, err
	}
	if n == limit {
		return nil, fmt.Errorf("square/go-jose: uncompressed data would be too large (>%d bytes)", maxDecompressedSize)
	}

	if err := reader.Close(); err != nil {
		return nil, err
	}
	return output.Bytes(), nil
}
107
// byteBuffer wraps a byte slice that is marshaled to and unmarshaled from JSON
// as an unpadded, URL-safe base64 string (base64.RawURLEncoding).
type byteBuffer struct {
	data []byte
}

// newBuffer wraps data in a byteBuffer without copying it.
// A nil slice yields a nil *byteBuffer; an empty but non-nil slice yields a
// non-nil buffer with no contents.
func newBuffer(data []byte) *byteBuffer {
	var buf *byteBuffer
	if data != nil {
		buf = &byteBuffer{data: data}
	}
	return buf
}
121
122func newFixedSizeBuffer(data []byte, length int) *byteBuffer {
123	if len(data) > length {
124		panic("square/go-jose: invalid call to newFixedSizeBuffer (len(data) > length)")
125	}
126	pad := make([]byte, length-len(data))
127	return newBuffer(append(pad, data...))
128}
129
130func newBufferFromInt(num uint64) *byteBuffer {
131	data := make([]byte, 8)
132	binary.BigEndian.PutUint64(data, num)
133	return newBuffer(bytes.TrimLeft(data, "\x00"))
134}
135
136func (b *byteBuffer) MarshalJSON() ([]byte, error) {
137	return json.Marshal(b.base64())
138}
139
140func (b *byteBuffer) UnmarshalJSON(data []byte) error {
141	var encoded string
142	err := json.Unmarshal(data, &encoded)
143	if err != nil {
144		return err
145	}
146
147	if encoded == "" {
148		return nil
149	}
150
151	decoded, err := base64.RawURLEncoding.DecodeString(encoded)
152	if err != nil {
153		return err
154	}
155
156	*b = *newBuffer(decoded)
157
158	return nil
159}
160
161func (b *byteBuffer) base64() string {
162	return base64.RawURLEncoding.EncodeToString(b.data)
163}
164
165func (b *byteBuffer) bytes() []byte {
166	// Handling nil here allows us to transparently handle nil slices when serializing.
167	if b == nil {
168		return nil
169	}
170	return b.data
171}
172
173func (b byteBuffer) bigInt() *big.Int {
174	return new(big.Int).SetBytes(b.data)
175}
176
177func (b byteBuffer) toInt() int {
178	return int(b.bigInt().Int64())
179}
180