1/*-
2 * Copyright 2014 Square Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package jose
18
19import (
20	"fmt"
21	"strings"
22
23	"gopkg.in/square/go-jose.v1/json"
24)
25
// rawJsonWebEncryption represents a raw JWE JSON object. Used for parsing/serializing.
//
// It models both JSON serialization syntaxes. The general syntax lists
// recipients in Recipients. The flattened syntax places a single recipient's
// Header and EncryptedKey at the top level. parseEncryptedCompact also decodes
// the compact format into this struct before it is validated by sanitized.
//
// NOTE(review): FullSerialize marshals this struct directly, so the field
// declaration order presumably fixes the member order of the JSON output;
// keep it stable.
type rawJsonWebEncryption struct {
	// Protected holds the serialized protected header; sanitized decodes it as JSON.
	Protected    *byteBuffer        `json:"protected,omitempty"`
	// Unprotected is the shared unprotected header.
	Unprotected  *rawHeader         `json:"unprotected,omitempty"`
	// Header is the single recipient's header (flattened syntax only).
	Header       *rawHeader         `json:"header,omitempty"`
	// Recipients is present only in the general syntax.
	Recipients   []rawRecipientInfo `json:"recipients,omitempty"`
	Aad          *byteBuffer        `json:"aad,omitempty"`
	// EncryptedKey is the single recipient's key (flattened syntax / compact).
	EncryptedKey *byteBuffer        `json:"encrypted_key,omitempty"`
	Iv           *byteBuffer        `json:"iv,omitempty"`
	Ciphertext   *byteBuffer        `json:"ciphertext,omitempty"`
	Tag          *byteBuffer        `json:"tag,omitempty"`
}
38
// rawRecipientInfo represents one raw per-recipient JSON object of a JWE (its
// header and encrypted key), i.e. one element of the "recipients" array in the
// general JSON syntax. Used for parsing/serializing.
type rawRecipientInfo struct {
	// Header is this recipient's unprotected header.
	Header       *rawHeader `json:"header,omitempty"`
	// EncryptedKey is kept as a base64url string here; sanitized decodes it
	// with base64URLDecode and FullSerialize encodes it with base64URLEncode.
	EncryptedKey string     `json:"encrypted_key,omitempty"`
}
44
// JsonWebEncryption represents an encrypted JWE object after parsing.
type JsonWebEncryption struct {
	// Header holds the sanitized merge of the protected and shared unprotected
	// headers for parsed objects (per-recipient headers are not included).
	Header                   JoseHeader
	// protected is the decoded protected header; unprotected is the shared
	// unprotected header. Either may be nil.
	protected, unprotected   *rawHeader
	// recipients has one entry per recipient; sanitized always yields at least one.
	recipients               []recipientInfo
	// aad, iv, ciphertext and tag are the decoded binary members.
	aad, iv, ciphertext, tag []byte
	// original is the raw structure this object was parsed from (set by
	// sanitized); computeAuthData prefers its protected-header bytes over
	// re-encoding protected.
	original                 *rawJsonWebEncryption
}
53
// recipientInfo represents a JWE per-recipient entry after parsing: the
// recipient's unprotected header and its decoded encrypted key.
type recipientInfo struct {
	// header is the per-recipient unprotected header; may be nil.
	header       *rawHeader
	// encryptedKey holds the encrypted key bytes (already base64url-decoded).
	encryptedKey []byte
}
59
60// GetAuthData retrieves the (optional) authenticated data attached to the object.
61func (obj JsonWebEncryption) GetAuthData() []byte {
62	if obj.aad != nil {
63		out := make([]byte, len(obj.aad))
64		copy(out, obj.aad)
65		return out
66	}
67
68	return nil
69}
70
71// Get the merged header values
72func (obj JsonWebEncryption) mergedHeaders(recipient *recipientInfo) rawHeader {
73	out := rawHeader{}
74	out.merge(obj.protected)
75	out.merge(obj.unprotected)
76
77	if recipient != nil {
78		out.merge(recipient.header)
79	}
80
81	return out
82}
83
84// Get the additional authenticated data from a JWE object.
85func (obj JsonWebEncryption) computeAuthData() []byte {
86	var protected string
87
88	if obj.original != nil {
89		protected = obj.original.Protected.base64()
90	} else {
91		protected = base64URLEncode(mustSerializeJSON((obj.protected)))
92	}
93
94	output := []byte(protected)
95	if obj.aad != nil {
96		output = append(output, '.')
97		output = append(output, []byte(base64URLEncode(obj.aad))...)
98	}
99
100	return output
101}
102
103// ParseEncrypted parses an encrypted message in compact or full serialization format.
104func ParseEncrypted(input string) (*JsonWebEncryption, error) {
105	input = stripWhitespace(input)
106	if strings.HasPrefix(input, "{") {
107		return parseEncryptedFull(input)
108	}
109
110	return parseEncryptedCompact(input)
111}
112
113// parseEncryptedFull parses a message in compact format.
114func parseEncryptedFull(input string) (*JsonWebEncryption, error) {
115	var parsed rawJsonWebEncryption
116	err := json.Unmarshal([]byte(input), &parsed)
117	if err != nil {
118		return nil, err
119	}
120
121	return parsed.sanitized()
122}
123
124// sanitized produces a cleaned-up JWE object from the raw JSON.
125func (parsed *rawJsonWebEncryption) sanitized() (*JsonWebEncryption, error) {
126	obj := &JsonWebEncryption{
127		original:    parsed,
128		unprotected: parsed.Unprotected,
129	}
130
131	// Check that there is not a nonce in the unprotected headers
132	if (parsed.Unprotected != nil && parsed.Unprotected.Nonce != "") ||
133		(parsed.Header != nil && parsed.Header.Nonce != "") {
134		return nil, ErrUnprotectedNonce
135	}
136
137	if parsed.Protected != nil && len(parsed.Protected.bytes()) > 0 {
138		err := json.Unmarshal(parsed.Protected.bytes(), &obj.protected)
139		if err != nil {
140			return nil, fmt.Errorf("square/go-jose: invalid protected header: %s, %s", err, parsed.Protected.base64())
141		}
142	}
143
144	// Note: this must be called _after_ we parse the protected header,
145	// otherwise fields from the protected header will not get picked up.
146	obj.Header = obj.mergedHeaders(nil).sanitized()
147
148	if len(parsed.Recipients) == 0 {
149		obj.recipients = []recipientInfo{
150			recipientInfo{
151				header:       parsed.Header,
152				encryptedKey: parsed.EncryptedKey.bytes(),
153			},
154		}
155	} else {
156		obj.recipients = make([]recipientInfo, len(parsed.Recipients))
157		for r := range parsed.Recipients {
158			encryptedKey, err := base64URLDecode(parsed.Recipients[r].EncryptedKey)
159			if err != nil {
160				return nil, err
161			}
162
163			// Check that there is not a nonce in the unprotected header
164			if parsed.Recipients[r].Header != nil && parsed.Recipients[r].Header.Nonce != "" {
165				return nil, ErrUnprotectedNonce
166			}
167
168			obj.recipients[r].header = parsed.Recipients[r].Header
169			obj.recipients[r].encryptedKey = encryptedKey
170		}
171	}
172
173	for _, recipient := range obj.recipients {
174		headers := obj.mergedHeaders(&recipient)
175		if headers.Alg == "" || headers.Enc == "" {
176			return nil, fmt.Errorf("square/go-jose: message is missing alg/enc headers")
177		}
178	}
179
180	obj.iv = parsed.Iv.bytes()
181	obj.ciphertext = parsed.Ciphertext.bytes()
182	obj.tag = parsed.Tag.bytes()
183	obj.aad = parsed.Aad.bytes()
184
185	return obj, nil
186}
187
188// parseEncryptedCompact parses a message in compact format.
189func parseEncryptedCompact(input string) (*JsonWebEncryption, error) {
190	parts := strings.Split(input, ".")
191	if len(parts) != 5 {
192		return nil, fmt.Errorf("square/go-jose: compact JWE format must have five parts")
193	}
194
195	rawProtected, err := base64URLDecode(parts[0])
196	if err != nil {
197		return nil, err
198	}
199
200	encryptedKey, err := base64URLDecode(parts[1])
201	if err != nil {
202		return nil, err
203	}
204
205	iv, err := base64URLDecode(parts[2])
206	if err != nil {
207		return nil, err
208	}
209
210	ciphertext, err := base64URLDecode(parts[3])
211	if err != nil {
212		return nil, err
213	}
214
215	tag, err := base64URLDecode(parts[4])
216	if err != nil {
217		return nil, err
218	}
219
220	raw := &rawJsonWebEncryption{
221		Protected:    newBuffer(rawProtected),
222		EncryptedKey: newBuffer(encryptedKey),
223		Iv:           newBuffer(iv),
224		Ciphertext:   newBuffer(ciphertext),
225		Tag:          newBuffer(tag),
226	}
227
228	return raw.sanitized()
229}
230
231// CompactSerialize serializes an object using the compact serialization format.
232func (obj JsonWebEncryption) CompactSerialize() (string, error) {
233	if len(obj.recipients) != 1 || obj.unprotected != nil ||
234		obj.protected == nil || obj.recipients[0].header != nil {
235		return "", ErrNotSupported
236	}
237
238	serializedProtected := mustSerializeJSON(obj.protected)
239
240	return fmt.Sprintf(
241		"%s.%s.%s.%s.%s",
242		base64URLEncode(serializedProtected),
243		base64URLEncode(obj.recipients[0].encryptedKey),
244		base64URLEncode(obj.iv),
245		base64URLEncode(obj.ciphertext),
246		base64URLEncode(obj.tag)), nil
247}
248
249// FullSerialize serializes an object using the full JSON serialization format.
250func (obj JsonWebEncryption) FullSerialize() string {
251	raw := rawJsonWebEncryption{
252		Unprotected:  obj.unprotected,
253		Iv:           newBuffer(obj.iv),
254		Ciphertext:   newBuffer(obj.ciphertext),
255		EncryptedKey: newBuffer(obj.recipients[0].encryptedKey),
256		Tag:          newBuffer(obj.tag),
257		Aad:          newBuffer(obj.aad),
258		Recipients:   []rawRecipientInfo{},
259	}
260
261	if len(obj.recipients) > 1 {
262		for _, recipient := range obj.recipients {
263			info := rawRecipientInfo{
264				Header:       recipient.header,
265				EncryptedKey: base64URLEncode(recipient.encryptedKey),
266			}
267			raw.Recipients = append(raw.Recipients, info)
268		}
269	} else {
270		// Use flattened serialization
271		raw.Header = obj.recipients[0].header
272		raw.EncryptedKey = newBuffer(obj.recipients[0].encryptedKey)
273	}
274
275	if obj.protected != nil {
276		raw.Protected = newBuffer(mustSerializeJSON(obj.protected))
277	}
278
279	return string(mustSerializeJSON(raw))
280}
281