1package dsig
2
3import (
4	"crypto"
5	"crypto/rand"
6	"crypto/rsa"
7	_ "crypto/sha1"
8	_ "crypto/sha256"
9	"encoding/base64"
10	"errors"
11	"fmt"
12
13	"github.com/beevik/etree"
14	"github.com/russellhaering/goxmldsig/etreeutils"
15)
16
// SigningContext carries the settings used to build and sign XML digital
// signatures over etree elements.
type SigningContext struct {
	// Hash is used for every digest computed while signing and is the hash
	// passed to RSA PKCS#1 v1.5 signing. It also selects the SignatureMethod
	// and DigestMethod identifiers written into SignedInfo.
	Hash          crypto.Hash
	// KeyStore supplies the RSA private key and certificate. If it also
	// implements X509ChainStore, the chain it returns is written to KeyInfo
	// instead of the single certificate.
	KeyStore      X509KeyStore
	// IdAttribute names the attribute read from the signed element to build
	// the Reference URI ("#" + value); when the attribute is absent the URI
	// is the empty string.
	IdAttribute   string
	// Prefix is the namespace prefix given to every generated signature
	// element. When empty, the signature namespace is declared as the
	// default namespace (plain "xmlns") on the Signature element.
	Prefix        string
	// Canonicalizer canonicalizes content before it is digested, and its
	// Algorithm() is advertised in CanonicalizationMethod and in the
	// Reference's Transform list.
	Canonicalizer Canonicalizer
}
24
25func NewDefaultSigningContext(ks X509KeyStore) *SigningContext {
26	return &SigningContext{
27		Hash:          crypto.SHA256,
28		KeyStore:      ks,
29		IdAttribute:   DefaultIdAttr,
30		Prefix:        DefaultPrefix,
31		Canonicalizer: MakeC14N11Canonicalizer(),
32	}
33}
34
35func (ctx *SigningContext) SetSignatureMethod(algorithmID string) error {
36	hash, ok := signatureMethodsByIdentifier[algorithmID]
37	if !ok {
38		return fmt.Errorf("Unknown SignatureMethod: %s", algorithmID)
39	}
40
41	ctx.Hash = hash
42
43	return nil
44}
45
46func (ctx *SigningContext) digest(el *etree.Element) ([]byte, error) {
47	canonical, err := ctx.Canonicalizer.Canonicalize(el)
48	if err != nil {
49		return nil, err
50	}
51
52	hash := ctx.Hash.New()
53	_, err = hash.Write(canonical)
54	if err != nil {
55		return nil, err
56	}
57
58	return hash.Sum(nil), nil
59}
60
61func (ctx *SigningContext) constructSignedInfo(el *etree.Element, enveloped bool) (*etree.Element, error) {
62	digestAlgorithmIdentifier := ctx.GetDigestAlgorithmIdentifier()
63	if digestAlgorithmIdentifier == "" {
64		return nil, errors.New("unsupported hash mechanism")
65	}
66
67	signatureMethodIdentifier := ctx.GetSignatureMethodIdentifier()
68	if signatureMethodIdentifier == "" {
69		return nil, errors.New("unsupported signature method")
70	}
71
72	digest, err := ctx.digest(el)
73	if err != nil {
74		return nil, err
75	}
76
77	signedInfo := &etree.Element{
78		Tag:   SignedInfoTag,
79		Space: ctx.Prefix,
80	}
81
82	// /SignedInfo/CanonicalizationMethod
83	canonicalizationMethod := ctx.createNamespacedElement(signedInfo, CanonicalizationMethodTag)
84	canonicalizationMethod.CreateAttr(AlgorithmAttr, string(ctx.Canonicalizer.Algorithm()))
85
86	// /SignedInfo/SignatureMethod
87	signatureMethod := ctx.createNamespacedElement(signedInfo, SignatureMethodTag)
88	signatureMethod.CreateAttr(AlgorithmAttr, signatureMethodIdentifier)
89
90	// /SignedInfo/Reference
91	reference := ctx.createNamespacedElement(signedInfo, ReferenceTag)
92
93	dataId := el.SelectAttrValue(ctx.IdAttribute, "")
94	if dataId == "" {
95		reference.CreateAttr(URIAttr, "")
96	} else {
97		reference.CreateAttr(URIAttr, "#"+dataId)
98	}
99
100
101	// /SignedInfo/Reference/Transforms
102	transforms := ctx.createNamespacedElement(reference, TransformsTag)
103	if enveloped {
104		envelopedTransform := ctx.createNamespacedElement(transforms, TransformTag)
105		envelopedTransform.CreateAttr(AlgorithmAttr, EnvelopedSignatureAltorithmId.String())
106	}
107	canonicalizationAlgorithm := ctx.createNamespacedElement(transforms, TransformTag)
108	canonicalizationAlgorithm.CreateAttr(AlgorithmAttr, string(ctx.Canonicalizer.Algorithm()))
109
110	// /SignedInfo/Reference/DigestMethod
111	digestMethod := ctx.createNamespacedElement(reference, DigestMethodTag)
112	digestMethod.CreateAttr(AlgorithmAttr, digestAlgorithmIdentifier)
113
114	// /SignedInfo/Reference/DigestValue
115	digestValue := ctx.createNamespacedElement(reference, DigestValueTag)
116	digestValue.SetText(base64.StdEncoding.EncodeToString(digest))
117
118	return signedInfo, nil
119}
120
121func (ctx *SigningContext) ConstructSignature(el *etree.Element, enveloped bool) (*etree.Element, error) {
122	signedInfo, err := ctx.constructSignedInfo(el, enveloped)
123	if err != nil {
124		return nil, err
125	}
126
127	sig := &etree.Element{
128		Tag:   SignatureTag,
129		Space: ctx.Prefix,
130	}
131
132	xmlns := "xmlns"
133	if ctx.Prefix != "" {
134		xmlns += ":" + ctx.Prefix
135	}
136
137	sig.CreateAttr(xmlns, Namespace)
138	sig.AddChild(signedInfo)
139
140	// When using xml-c14n11 (ie, non-exclusive canonicalization) the canonical form
141	// of the SignedInfo must declare all namespaces that are in scope at it's final
142	// enveloped location in the document. In order to do that, we're going to construct
143	// a series of cascading NSContexts to capture namespace declarations:
144
145	// First get the context surrounding the element we are signing.
146	rootNSCtx, err := etreeutils.NSBuildParentContext(el)
147	if err != nil {
148		return nil, err
149	}
150
151	// Then capture any declarations on the element itself.
152	elNSCtx, err := rootNSCtx.SubContext(el)
153	if err != nil {
154		return nil, err
155	}
156
157	// Followed by declarations on the Signature (which we just added above)
158	sigNSCtx, err := elNSCtx.SubContext(sig)
159	if err != nil {
160		return nil, err
161	}
162
163	// Finally detatch the SignedInfo in order to capture all of the namespace
164	// declarations in the scope we've constructed.
165	detatchedSignedInfo, err := etreeutils.NSDetatch(sigNSCtx, signedInfo)
166	if err != nil {
167		return nil, err
168	}
169
170	digest, err := ctx.digest(detatchedSignedInfo)
171	if err != nil {
172		return nil, err
173	}
174
175	key, cert, err := ctx.KeyStore.GetKeyPair()
176	if err != nil {
177		return nil, err
178	}
179
180	certs := [][]byte{cert}
181	if cs, ok := ctx.KeyStore.(X509ChainStore); ok {
182		certs, err = cs.GetChain()
183		if err != nil {
184			return nil, err
185		}
186	}
187
188	rawSignature, err := rsa.SignPKCS1v15(rand.Reader, key, ctx.Hash, digest)
189	if err != nil {
190		return nil, err
191	}
192
193	signatureValue := ctx.createNamespacedElement(sig, SignatureValueTag)
194	signatureValue.SetText(base64.StdEncoding.EncodeToString(rawSignature))
195
196	keyInfo := ctx.createNamespacedElement(sig, KeyInfoTag)
197	x509Data := ctx.createNamespacedElement(keyInfo, X509DataTag)
198	for _, cert := range certs {
199		x509Certificate := ctx.createNamespacedElement(x509Data, X509CertificateTag)
200		x509Certificate.SetText(base64.StdEncoding.EncodeToString(cert))
201	}
202
203	return sig, nil
204}
205
206func (ctx *SigningContext) createNamespacedElement(el *etree.Element, tag string) *etree.Element {
207	child := el.CreateElement(tag)
208	child.Space = ctx.Prefix
209	return child
210}
211
212func (ctx *SigningContext) SignEnveloped(el *etree.Element) (*etree.Element, error) {
213	sig, err := ctx.ConstructSignature(el, true)
214	if err != nil {
215		return nil, err
216	}
217
218	ret := el.Copy()
219	ret.Child = append(ret.Child, sig)
220
221	return ret, nil
222}
223
224func (ctx *SigningContext) GetSignatureMethodIdentifier() string {
225	if ident, ok := signatureMethodIdentifiers[ctx.Hash]; ok {
226		return ident
227	}
228	return ""
229}
230
231func (ctx *SigningContext) GetDigestAlgorithmIdentifier() string {
232	if ident, ok := digestAlgorithmIdentifiers[ctx.Hash]; ok {
233		return ident
234	}
235	return ""
236}
237
238// Useful for signing query string (including DEFLATED AuthnRequest) when
239// using HTTP-Redirect to make a signed request.
240// See 3.4.4.1 DEFLATE Encoding of https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf
241func (ctx *SigningContext) SignString(content string) ([]byte, error) {
242	hash := ctx.Hash.New()
243	if ln, err := hash.Write([]byte(content)); err != nil {
244		return nil, fmt.Errorf("error calculating hash: %v", err)
245	} else if ln < 1 {
246		return nil, fmt.Errorf("zero length hash")
247	}
248	digest := hash.Sum(nil)
249
250	var signature []byte
251	if key, _, err := ctx.KeyStore.GetKeyPair(); err != nil {
252		return nil, fmt.Errorf("unable to fetch key for signing: %v", err)
253	} else if signature, err = rsa.SignPKCS1v15(rand.Reader, key, ctx.Hash, digest); err != nil {
254		return nil, fmt.Errorf("error signing: %v", err)
255	}
256	return signature, nil
257}
258