1/*-
2 * Copyright 2016 Zbigniew Mandziejewicz
3 * Copyright 2016 Square, Inc.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package jwt
19
20import (
21	"bytes"
22	"reflect"
23
24	"gopkg.in/square/go-jose.v2/json"
25
26	"gopkg.in/square/go-jose.v2"
27)
28
29// Builder is a utility for making JSON Web Tokens. Calls can be chained, and
30// errors are accumulated until the final call to CompactSerialize/FullSerialize.
31type Builder interface {
32	// Claims encodes claims into JWE/JWS form. Multiple calls will merge claims
33	// into single JSON object. If you are passing private claims, make sure to set
34	// struct field tags to specify the name for the JSON key to be used when
35	// serializing.
36	Claims(i interface{}) Builder
37	// Token builds a JSONWebToken from provided data.
38	Token() (*JSONWebToken, error)
39	// FullSerialize serializes a token using the full serialization format.
40	FullSerialize() (string, error)
41	// CompactSerialize serializes a token using the compact serialization format.
42	CompactSerialize() (string, error)
43}
44
45// NestedBuilder is a utility for making Signed-Then-Encrypted JSON Web Tokens.
46// Calls can be chained, and errors are accumulated until final call to
47// CompactSerialize/FullSerialize.
48type NestedBuilder interface {
49	// Claims encodes claims into JWE/JWS form. Multiple calls will merge claims
50	// into single JSON object. If you are passing private claims, make sure to set
51	// struct field tags to specify the name for the JSON key to be used when
52	// serializing.
53	Claims(i interface{}) NestedBuilder
54	// Token builds a NestedJSONWebToken from provided data.
55	Token() (*NestedJSONWebToken, error)
56	// FullSerialize serializes a token using the full serialization format.
57	FullSerialize() (string, error)
58	// CompactSerialize serializes a token using the compact serialization format.
59	CompactSerialize() (string, error)
60}
61
62type builder struct {
63	payload map[string]interface{}
64	err     error
65}
66
67type signedBuilder struct {
68	builder
69	sig jose.Signer
70}
71
72type encryptedBuilder struct {
73	builder
74	enc jose.Encrypter
75}
76
77type nestedBuilder struct {
78	builder
79	sig jose.Signer
80	enc jose.Encrypter
81}
82
83// Signed creates builder for signed tokens.
84func Signed(sig jose.Signer) Builder {
85	return &signedBuilder{
86		sig: sig,
87	}
88}
89
90// Encrypted creates builder for encrypted tokens.
91func Encrypted(enc jose.Encrypter) Builder {
92	return &encryptedBuilder{
93		enc: enc,
94	}
95}
96
97// SignedAndEncrypted creates builder for signed-then-encrypted tokens.
98// ErrInvalidContentType will be returned if encrypter doesn't have JWT content type.
99func SignedAndEncrypted(sig jose.Signer, enc jose.Encrypter) NestedBuilder {
100	if contentType, _ := enc.Options().ExtraHeaders[jose.HeaderContentType].(jose.ContentType); contentType != "JWT" {
101		return &nestedBuilder{
102			builder: builder{
103				err: ErrInvalidContentType,
104			},
105		}
106	}
107	return &nestedBuilder{
108		sig: sig,
109		enc: enc,
110	}
111}
112
113func (b builder) claims(i interface{}) builder {
114	if b.err != nil {
115		return b
116	}
117
118	m, ok := i.(map[string]interface{})
119	switch {
120	case ok:
121		return b.merge(m)
122	case reflect.Indirect(reflect.ValueOf(i)).Kind() == reflect.Struct:
123		m, err := normalize(i)
124		if err != nil {
125			return builder{
126				err: err,
127			}
128		}
129		return b.merge(m)
130	default:
131		return builder{
132			err: ErrInvalidClaims,
133		}
134	}
135}
136
// normalize converts a JSON-marshalable value (typically a claims struct)
// into a generic map by round-tripping it through JSON. Numbers are decoded
// as json.Number rather than float64, so integer claims are not forced
// through floating point. Errors from encoding or decoding are returned
// unchanged.
func normalize(i interface{}) (map[string]interface{}, error) {
	encoded, err := json.Marshal(i)
	if err != nil {
		return nil, err
	}

	dec := json.NewDecoder(bytes.NewReader(encoded))
	dec.UseNumber()

	out := make(map[string]interface{})
	if err = dec.Decode(&out); err != nil {
		return nil, err
	}
	return out, nil
}
154
155func (b *builder) merge(m map[string]interface{}) builder {
156	p := make(map[string]interface{})
157	for k, v := range b.payload {
158		p[k] = v
159	}
160	for k, v := range m {
161		p[k] = v
162	}
163
164	return builder{
165		payload: p,
166	}
167}
168
169func (b *builder) token(p func(interface{}) ([]byte, error), h []jose.Header) (*JSONWebToken, error) {
170	return &JSONWebToken{
171		payload: p,
172		Headers: h,
173	}, nil
174}
175
176func (b *signedBuilder) Claims(i interface{}) Builder {
177	return &signedBuilder{
178		builder: b.builder.claims(i),
179		sig:     b.sig,
180	}
181}
182
183func (b *signedBuilder) Token() (*JSONWebToken, error) {
184	sig, err := b.sign()
185	if err != nil {
186		return nil, err
187	}
188
189	h := make([]jose.Header, len(sig.Signatures))
190	for i, v := range sig.Signatures {
191		h[i] = v.Header
192	}
193
194	return b.builder.token(sig.Verify, h)
195}
196
197func (b *signedBuilder) CompactSerialize() (string, error) {
198	sig, err := b.sign()
199	if err != nil {
200		return "", err
201	}
202
203	return sig.CompactSerialize()
204}
205
206func (b *signedBuilder) FullSerialize() (string, error) {
207	sig, err := b.sign()
208	if err != nil {
209		return "", err
210	}
211
212	return sig.FullSerialize(), nil
213}
214
215func (b *signedBuilder) sign() (*jose.JSONWebSignature, error) {
216	if b.err != nil {
217		return nil, b.err
218	}
219
220	p, err := json.Marshal(b.payload)
221	if err != nil {
222		return nil, err
223	}
224
225	return b.sig.Sign(p)
226}
227
228func (b *encryptedBuilder) Claims(i interface{}) Builder {
229	return &encryptedBuilder{
230		builder: b.builder.claims(i),
231		enc:     b.enc,
232	}
233}
234
235func (b *encryptedBuilder) CompactSerialize() (string, error) {
236	enc, err := b.encrypt()
237	if err != nil {
238		return "", err
239	}
240
241	return enc.CompactSerialize()
242}
243
244func (b *encryptedBuilder) FullSerialize() (string, error) {
245	enc, err := b.encrypt()
246	if err != nil {
247		return "", err
248	}
249
250	return enc.FullSerialize(), nil
251}
252
253func (b *encryptedBuilder) Token() (*JSONWebToken, error) {
254	enc, err := b.encrypt()
255	if err != nil {
256		return nil, err
257	}
258
259	return b.builder.token(enc.Decrypt, []jose.Header{enc.Header})
260}
261
262func (b *encryptedBuilder) encrypt() (*jose.JSONWebEncryption, error) {
263	if b.err != nil {
264		return nil, b.err
265	}
266
267	p, err := json.Marshal(b.payload)
268	if err != nil {
269		return nil, err
270	}
271
272	return b.enc.Encrypt(p)
273}
274
275func (b *nestedBuilder) Claims(i interface{}) NestedBuilder {
276	return &nestedBuilder{
277		builder: b.builder.claims(i),
278		sig:     b.sig,
279		enc:     b.enc,
280	}
281}
282
283func (b *nestedBuilder) Token() (*NestedJSONWebToken, error) {
284	enc, err := b.signAndEncrypt()
285	if err != nil {
286		return nil, err
287	}
288
289	return &NestedJSONWebToken{
290		enc:     enc,
291		Headers: []jose.Header{enc.Header},
292	}, nil
293}
294
295func (b *nestedBuilder) CompactSerialize() (string, error) {
296	enc, err := b.signAndEncrypt()
297	if err != nil {
298		return "", err
299	}
300
301	return enc.CompactSerialize()
302}
303
304func (b *nestedBuilder) FullSerialize() (string, error) {
305	enc, err := b.signAndEncrypt()
306	if err != nil {
307		return "", err
308	}
309
310	return enc.FullSerialize(), nil
311}
312
313func (b *nestedBuilder) signAndEncrypt() (*jose.JSONWebEncryption, error) {
314	if b.err != nil {
315		return nil, b.err
316	}
317
318	p, err := json.Marshal(b.payload)
319	if err != nil {
320		return nil, err
321	}
322
323	sig, err := b.sig.Sign(p)
324	if err != nil {
325		return nil, err
326	}
327
328	p2, err := sig.CompactSerialize()
329	if err != nil {
330		return nil, err
331	}
332
333	return b.enc.Encrypt([]byte(p2))
334}
335