1package jws
2
3import (
4	"bytes"
5	"encoding/json"
6)
7
8// Flat serializes the JWS to its "flattened" form per
9// https://tools.ietf.org/html/rfc7515#section-7.2.2
10func (j *jws) Flat(key interface{}) ([]byte, error) {
11	if len(j.sb) < 1 {
12		return nil, ErrNotEnoughMethods
13	}
14	if err := j.sign(key); err != nil {
15		return nil, err
16	}
17	return json.Marshal(struct {
18		Payload rawBase64 `json:"payload"`
19		sigHead
20	}{
21		Payload: j.plcache,
22		sigHead: j.sb[0],
23	})
24}
25
26// General serializes the JWS into its "general" form per
27// https://tools.ietf.org/html/rfc7515#section-7.2.1
28//
29// If only one key is passed it's used for all the provided
30// crypto.SigningMethods. Otherwise, len(keys) must equal the number
31// of crypto.SigningMethods added.
32func (j *jws) General(keys ...interface{}) ([]byte, error) {
33	if err := j.sign(keys...); err != nil {
34		return nil, err
35	}
36	return json.Marshal(struct {
37		Payload    rawBase64 `json:"payload"`
38		Signatures []sigHead `json:"signatures"`
39	}{
40		Payload:    j.plcache,
41		Signatures: j.sb,
42	})
43}
44
45// Compact serializes the JWS into its "compact" form per
46// https://tools.ietf.org/html/rfc7515#section-7.1
47func (j *jws) Compact(key interface{}) ([]byte, error) {
48	if len(j.sb) < 1 {
49		return nil, ErrNotEnoughMethods
50	}
51
52	if err := j.sign(key); err != nil {
53		return nil, err
54	}
55
56	sig, err := j.sb[0].Signature.Base64()
57	if err != nil {
58		return nil, err
59	}
60	return format(
61		j.sb[0].Protected,
62		j.plcache,
63		sig,
64	), nil
65}
66
67// sign signs each index of j's sb member.
68func (j *jws) sign(keys ...interface{}) error {
69	if err := j.cache(); err != nil {
70		return err
71	}
72
73	if len(keys) < 1 ||
74		len(keys) > 1 && len(keys) != len(j.sb) {
75		return ErrNotEnoughKeys
76	}
77
78	if len(keys) == 1 {
79		k := keys[0]
80		keys = make([]interface{}, len(j.sb))
81		for i := range keys {
82			keys[i] = k
83		}
84	}
85
86	for i := range j.sb {
87		if err := j.sb[i].cache(); err != nil {
88			return err
89		}
90
91		raw := format(j.sb[i].Protected, j.plcache)
92		sig, err := j.sb[i].method.Sign(raw, keys[i])
93		if err != nil {
94			return err
95		}
96		j.sb[i].Signature = sig
97	}
98
99	return nil
100}
101
102// cache marshals the payload, but only if it's changed since the last cache.
103func (j *jws) cache() (err error) {
104	if !j.clean {
105		j.plcache, err = j.payload.Base64()
106		j.clean = err == nil
107	}
108	return err
109}
110
111// cache marshals the protected and unprotected headers, but only if
112// they've changed since their last cache.
113func (s *sigHead) cache() (err error) {
114	if !s.clean {
115		s.Protected, err = s.protected.Base64()
116		if err != nil {
117			return err
118		}
119		s.Unprotected, err = s.unprotected.Base64()
120		if err != nil {
121			return err
122		}
123	}
124	s.clean = true
125	return nil
126}
127
// format concatenates the given byte slices in order, separated by a
// single '.' byte. It builds both the signing input and the compact
// serialization.
func format(parts ...[]byte) []byte {
	sep := []byte(".")
	return bytes.Join(parts, sep)
}
133