1/*-
2 * Copyright 2014 Square Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package jose
18
19import (
20	"fmt"
21	"strings"
22)
23
// rawJsonWebEncryption represents a raw JWE JSON object. Used for parsing/serializing.
//
// One struct covers every wire form handled in this file:
//   - general JSON serialization: Recipients is populated;
//   - flattened JSON serialization: top-level Header and EncryptedKey are used;
//   - compact serialization: parseEncryptedCompact copies the five decoded
//     parts into Protected, EncryptedKey, Iv, Ciphertext and Tag.
//
// NOTE(review): if mustSerializeJSON is backed by encoding/json, the field
// order below is the member order of the JSON emitted by FullSerialize —
// check before reordering.
type rawJsonWebEncryption struct {
	// Protected carries the protected header; sanitized decodes its bytes() as JSON.
	Protected    *byteBuffer        `json:"protected,omitempty"`
	Unprotected  *rawHeader         `json:"unprotected,omitempty"`
	// Header is the per-recipient header of the flattened form; sanitized ignores it when Recipients is non-empty.
	Header       *rawHeader         `json:"header,omitempty"`
	Recipients   []rawRecipientInfo `json:"recipients,omitempty"`
	Aad          *byteBuffer        `json:"aad,omitempty"`
	// EncryptedKey is the single recipient's key in the flattened/compact forms.
	EncryptedKey *byteBuffer        `json:"encrypted_key,omitempty"`
	Iv           *byteBuffer        `json:"iv,omitempty"`
	Ciphertext   *byteBuffer        `json:"ciphertext,omitempty"`
	Tag          *byteBuffer        `json:"tag,omitempty"`
}
36
// rawRecipientInfo represents a raw JWE Per-Recipient header JSON object. Used for parsing/serializing.
//
// Entries of this type make up the "recipients" array of the general JSON
// serialization. EncryptedKey is kept as base64url text exactly as it appears
// on the wire: sanitized decodes it with base64URLDecode, and FullSerialize
// produces it with base64URLEncode.
type rawRecipientInfo struct {
	// Header is this recipient's unprotected header; sanitized rejects it if it carries a nonce.
	Header       *rawHeader `json:"header,omitempty"`
	EncryptedKey string     `json:"encrypted_key,omitempty"`
}
42
43// JsonWebEncryption represents an encrypted JWE object after parsing.
44type JsonWebEncryption struct {
45	Header                   JoseHeader
46	protected, unprotected   *rawHeader
47	recipients               []recipientInfo
48	aad, iv, ciphertext, tag []byte
49	original                 *rawJsonWebEncryption
50}
51
// recipientInfo represents a raw JWE Per-Recipient header JSON object after parsing.
//
// header is the recipient's own unprotected header and may be nil (e.g. a
// flattened message with no top-level "header" member). encryptedKey holds
// the already-decoded encrypted key bytes, not base64url text.
type recipientInfo struct {
	header       *rawHeader
	encryptedKey []byte
}
57
58// GetAuthData retrieves the (optional) authenticated data attached to the object.
59func (obj JsonWebEncryption) GetAuthData() []byte {
60	if obj.aad != nil {
61		out := make([]byte, len(obj.aad))
62		copy(out, obj.aad)
63		return out
64	}
65
66	return nil
67}
68
69// Get the merged header values
70func (obj JsonWebEncryption) mergedHeaders(recipient *recipientInfo) rawHeader {
71	out := rawHeader{}
72	out.merge(obj.protected)
73	out.merge(obj.unprotected)
74
75	if recipient != nil {
76		out.merge(recipient.header)
77	}
78
79	return out
80}
81
82// Get the additional authenticated data from a JWE object.
83func (obj JsonWebEncryption) computeAuthData() []byte {
84	var protected string
85
86	if obj.original != nil {
87		protected = obj.original.Protected.base64()
88	} else {
89		protected = base64URLEncode(mustSerializeJSON((obj.protected)))
90	}
91
92	output := []byte(protected)
93	if obj.aad != nil {
94		output = append(output, '.')
95		output = append(output, []byte(base64URLEncode(obj.aad))...)
96	}
97
98	return output
99}
100
101// ParseEncrypted parses an encrypted message in compact or full serialization format.
102func ParseEncrypted(input string) (*JsonWebEncryption, error) {
103	input = stripWhitespace(input)
104	if strings.HasPrefix(input, "{") {
105		return parseEncryptedFull(input)
106	}
107
108	return parseEncryptedCompact(input)
109}
110
111// parseEncryptedFull parses a message in compact format.
112func parseEncryptedFull(input string) (*JsonWebEncryption, error) {
113	var parsed rawJsonWebEncryption
114	err := UnmarshalJSON([]byte(input), &parsed)
115	if err != nil {
116		return nil, err
117	}
118
119	return parsed.sanitized()
120}
121
122// sanitized produces a cleaned-up JWE object from the raw JSON.
123func (parsed *rawJsonWebEncryption) sanitized() (*JsonWebEncryption, error) {
124	obj := &JsonWebEncryption{
125		original:    parsed,
126		unprotected: parsed.Unprotected,
127	}
128
129	// Check that there is not a nonce in the unprotected headers
130	if (parsed.Unprotected != nil && parsed.Unprotected.Nonce != "") ||
131		(parsed.Header != nil && parsed.Header.Nonce != "") {
132		return nil, ErrUnprotectedNonce
133	}
134
135	if parsed.Protected != nil && len(parsed.Protected.bytes()) > 0 {
136		err := UnmarshalJSON(parsed.Protected.bytes(), &obj.protected)
137		if err != nil {
138			return nil, fmt.Errorf("square/go-jose: invalid protected header: %s, %s", err, parsed.Protected.base64())
139		}
140	}
141
142	// Note: this must be called _after_ we parse the protected header,
143	// otherwise fields from the protected header will not get picked up.
144	obj.Header = obj.mergedHeaders(nil).sanitized()
145
146	if len(parsed.Recipients) == 0 {
147		obj.recipients = []recipientInfo{
148			recipientInfo{
149				header:       parsed.Header,
150				encryptedKey: parsed.EncryptedKey.bytes(),
151			},
152		}
153	} else {
154		obj.recipients = make([]recipientInfo, len(parsed.Recipients))
155		for r := range parsed.Recipients {
156			encryptedKey, err := base64URLDecode(parsed.Recipients[r].EncryptedKey)
157			if err != nil {
158				return nil, err
159			}
160
161			// Check that there is not a nonce in the unprotected header
162			if parsed.Recipients[r].Header != nil && parsed.Recipients[r].Header.Nonce != "" {
163				return nil, ErrUnprotectedNonce
164			}
165
166			obj.recipients[r].header = parsed.Recipients[r].Header
167			obj.recipients[r].encryptedKey = encryptedKey
168		}
169	}
170
171	for _, recipient := range obj.recipients {
172		headers := obj.mergedHeaders(&recipient)
173		if headers.Alg == "" || headers.Enc == "" {
174			return nil, fmt.Errorf("square/go-jose: message is missing alg/enc headers")
175		}
176	}
177
178	obj.iv = parsed.Iv.bytes()
179	obj.ciphertext = parsed.Ciphertext.bytes()
180	obj.tag = parsed.Tag.bytes()
181	obj.aad = parsed.Aad.bytes()
182
183	return obj, nil
184}
185
186// parseEncryptedCompact parses a message in compact format.
187func parseEncryptedCompact(input string) (*JsonWebEncryption, error) {
188	parts := strings.Split(input, ".")
189	if len(parts) != 5 {
190		return nil, fmt.Errorf("square/go-jose: compact JWE format must have five parts")
191	}
192
193	rawProtected, err := base64URLDecode(parts[0])
194	if err != nil {
195		return nil, err
196	}
197
198	encryptedKey, err := base64URLDecode(parts[1])
199	if err != nil {
200		return nil, err
201	}
202
203	iv, err := base64URLDecode(parts[2])
204	if err != nil {
205		return nil, err
206	}
207
208	ciphertext, err := base64URLDecode(parts[3])
209	if err != nil {
210		return nil, err
211	}
212
213	tag, err := base64URLDecode(parts[4])
214	if err != nil {
215		return nil, err
216	}
217
218	raw := &rawJsonWebEncryption{
219		Protected:    newBuffer(rawProtected),
220		EncryptedKey: newBuffer(encryptedKey),
221		Iv:           newBuffer(iv),
222		Ciphertext:   newBuffer(ciphertext),
223		Tag:          newBuffer(tag),
224	}
225
226	return raw.sanitized()
227}
228
229// CompactSerialize serializes an object using the compact serialization format.
230func (obj JsonWebEncryption) CompactSerialize() (string, error) {
231	if len(obj.recipients) != 1 || obj.unprotected != nil ||
232		obj.protected == nil || obj.recipients[0].header != nil {
233		return "", ErrNotSupported
234	}
235
236	serializedProtected := mustSerializeJSON(obj.protected)
237
238	return fmt.Sprintf(
239		"%s.%s.%s.%s.%s",
240		base64URLEncode(serializedProtected),
241		base64URLEncode(obj.recipients[0].encryptedKey),
242		base64URLEncode(obj.iv),
243		base64URLEncode(obj.ciphertext),
244		base64URLEncode(obj.tag)), nil
245}
246
247// FullSerialize serializes an object using the full JSON serialization format.
248func (obj JsonWebEncryption) FullSerialize() string {
249	raw := rawJsonWebEncryption{
250		Unprotected:  obj.unprotected,
251		Iv:           newBuffer(obj.iv),
252		Ciphertext:   newBuffer(obj.ciphertext),
253		EncryptedKey: newBuffer(obj.recipients[0].encryptedKey),
254		Tag:          newBuffer(obj.tag),
255		Aad:          newBuffer(obj.aad),
256		Recipients:   []rawRecipientInfo{},
257	}
258
259	if len(obj.recipients) > 1 {
260		for _, recipient := range obj.recipients {
261			info := rawRecipientInfo{
262				Header:       recipient.header,
263				EncryptedKey: base64URLEncode(recipient.encryptedKey),
264			}
265			raw.Recipients = append(raw.Recipients, info)
266		}
267	} else {
268		// Use flattened serialization
269		raw.Header = obj.recipients[0].header
270		raw.EncryptedKey = newBuffer(obj.recipients[0].encryptedKey)
271	}
272
273	if obj.protected != nil {
274		raw.Protected = newBuffer(mustSerializeJSON(obj.protected))
275	}
276
277	return string(mustSerializeJSON(raw))
278}
279