1/*-
2 * Copyright 2014 Square Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package jose
18
19import (
20	"bytes"
21	"compress/flate"
22	"encoding/base64"
23	"encoding/binary"
24	"io"
25	"math/big"
26	"strings"
27	"unicode"
28
29	"gopkg.in/square/go-jose.v2/json"
30)
31
// mustSerializeJSON marshals value to JSON and panics on failure. It is meant
// only for values the caller already knows to be serializable; in particular
// the caller must not pass a nil pointer.
//
// A nil pointer marshals to the bare top-level JSON value "null", which is
// never a valid JOSE message. If such output were embedded in a larger
// operation (for example base64-encoded and fed to a signing algorithm, see
// https://github.com/square/go-jose/issues/22) it would silently produce an
// incorrect result. An interface holding a typed nil pointer is not caught by
// a plain `value == nil` comparison, so the marshaled output itself is
// compared against "null" instead; any value that marshals to top-level null
// is rejected the same way.
// See https://groups.google.com/forum/#!topic/golang-nuts/wnH302gBa4I
func mustSerializeJSON(value interface{}) []byte {
	serialized, err := json.Marshal(value)
	switch {
	case err != nil:
		panic(err)
	case string(serialized) == "null":
		panic("Tried to serialize a nil pointer.")
	default:
		return serialized
	}
}
55
// stripWhitespace returns data with every rune classified as whitespace by
// unicode.IsSpace (spaces, tabs, newlines, and Unicode spaces such as U+00A0)
// removed. All other runes are kept in order; invalid UTF-8 bytes come out as
// U+FFFD.
func stripWhitespace(data string) string {
	dropSpace := func(r rune) rune {
		if unicode.IsSpace(r) {
			return -1 // negative result tells strings.Map to drop the rune
		}
		return r
	}
	return strings.Map(dropSpace, data)
}
67
68// Perform compression based on algorithm
69func compress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) {
70	switch algorithm {
71	case DEFLATE:
72		return deflate(input)
73	default:
74		return nil, ErrUnsupportedAlgorithm
75	}
76}
77
78// Perform decompression based on algorithm
79func decompress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) {
80	switch algorithm {
81	case DEFLATE:
82		return inflate(input)
83	default:
84		return nil, ErrUnsupportedAlgorithm
85	}
86}
87
88// Compress with DEFLATE
89func deflate(input []byte) ([]byte, error) {
90	output := new(bytes.Buffer)
91
92	// Writing to byte buffer, err is always nil
93	writer, _ := flate.NewWriter(output, 1)
94	_, _ = io.Copy(writer, bytes.NewBuffer(input))
95
96	err := writer.Close()
97	return output.Bytes(), err
98}
99
// decompressedTooLargeError is returned by inflate when a DEFLATE stream would
// expand past the output limit that inflate enforces.
type decompressedTooLargeError struct{}

func (decompressedTooLargeError) Error() string {
	return "square/go-jose: uncompressed data would be too large"
}

// inflate decompresses a raw DEFLATE stream, such as one produced by deflate.
//
// Compressed input can expand by orders of magnitude (a "decompression
// bomb"), so the decompressed size is capped at the larger of 250000 bytes and
// ten times the compressed size. Input that would decompress beyond that cap
// is rejected with decompressedTooLargeError. A truncated or corrupt stream
// returns the error reported by compress/flate. On any error the returned
// slice is nil.
func inflate(input []byte) ([]byte, error) {
	const (
		// Floor so that small but highly compressible payloads still decode.
		minOutputLimit = 250000
		// Allowed output bytes per compressed input byte above the floor.
		expansionRatio = 10
	)
	limit := expansionRatio * int64(len(input))
	if limit < minOutputLimit {
		limit = minOutputLimit
	}

	reader := flate.NewReader(bytes.NewReader(input))
	output := new(bytes.Buffer)

	// Ask for one byte beyond the limit: receiving it proves the stream is
	// too large while never buffering more than limit+1 bytes. CopyN reports
	// io.EOF when the stream ends before that, which is the normal case.
	n, err := io.CopyN(output, reader, limit+1)
	if err != nil && err != io.EOF {
		_ = reader.Close() // the copy error is the one worth reporting
		return nil, err
	}
	if n > limit {
		_ = reader.Close()
		return nil, decompressedTooLargeError{}
	}

	if err := reader.Close(); err != nil {
		return nil, err
	}
	return output.Bytes(), nil
}
113
// byteBuffer wraps a byte slice whose JSON form (see MarshalJSON and
// UnmarshalJSON) is an unpadded, URL-safe base64 string.
type byteBuffer struct{ data []byte }
118
119func newBuffer(data []byte) *byteBuffer {
120	if data == nil {
121		return nil
122	}
123	return &byteBuffer{
124		data: data,
125	}
126}
127
128func newFixedSizeBuffer(data []byte, length int) *byteBuffer {
129	if len(data) > length {
130		panic("square/go-jose: invalid call to newFixedSizeBuffer (len(data) > length)")
131	}
132	pad := make([]byte, length-len(data))
133	return newBuffer(append(pad, data...))
134}
135
136func newBufferFromInt(num uint64) *byteBuffer {
137	data := make([]byte, 8)
138	binary.BigEndian.PutUint64(data, num)
139	return newBuffer(bytes.TrimLeft(data, "\x00"))
140}
141
142func (b *byteBuffer) MarshalJSON() ([]byte, error) {
143	return json.Marshal(b.base64())
144}
145
146func (b *byteBuffer) UnmarshalJSON(data []byte) error {
147	var encoded string
148	err := json.Unmarshal(data, &encoded)
149	if err != nil {
150		return err
151	}
152
153	if encoded == "" {
154		return nil
155	}
156
157	decoded, err := base64.RawURLEncoding.DecodeString(encoded)
158	if err != nil {
159		return err
160	}
161
162	*b = *newBuffer(decoded)
163
164	return nil
165}
166
167func (b *byteBuffer) base64() string {
168	return base64.RawURLEncoding.EncodeToString(b.data)
169}
170
171func (b *byteBuffer) bytes() []byte {
172	// Handling nil here allows us to transparently handle nil slices when serializing.
173	if b == nil {
174		return nil
175	}
176	return b.data
177}
178
179func (b byteBuffer) bigInt() *big.Int {
180	return new(big.Int).SetBytes(b.data)
181}
182
183func (b byteBuffer) toInt() int {
184	return int(b.bigInt().Int64())
185}
186