1package dsig
2
3import (
4	"crypto"
5	"crypto/rand"
6	"crypto/rsa"
7	_ "crypto/sha1"
8	_ "crypto/sha256"
9	"encoding/base64"
10	"errors"
11	"fmt"
12
13	"github.com/beevik/etree"
14	"github.com/russellhaering/goxmldsig/etreeutils"
15)
16
// SigningContext holds the configuration used to produce enveloped XML
// digital signatures with RSA (PKCS#1 v1.5) keys.
type SigningContext struct {
	// Hash is used both to digest the referenced element and as the hash
	// passed to rsa.SignPKCS1v15 when signing the SignedInfo. It also selects
	// the DigestMethod and SignatureMethod algorithm identifiers.
	// SetSignatureMethod sets it from a SignatureMethod identifier.
	Hash          crypto.Hash
	// KeyStore supplies the RSA private key used to sign and the certificate
	// bytes that are base64-encoded into KeyInfo/X509Data/X509Certificate
	// (presumably DER-encoded — confirm against the X509KeyStore implementation).
	KeyStore      X509KeyStore
	// IdAttribute names the attribute on the signed element whose value is
	// used to build the Reference URI ("#" + value). Signing fails if the
	// element lacks a non-empty value for it.
	IdAttribute   string
	// Prefix is the namespace prefix applied to every generated signature
	// element, and declared on Signature as xmlns:Prefix. An empty Prefix
	// declares the signature namespace as the default namespace (plain xmlns).
	Prefix        string
	// Canonicalizer canonicalizes elements before they are hashed. Its
	// Algorithm() identifier is written to SignedInfo/CanonicalizationMethod
	// and to a Reference Transform.
	Canonicalizer Canonicalizer
}
24
25func NewDefaultSigningContext(ks X509KeyStore) *SigningContext {
26	return &SigningContext{
27		Hash:          crypto.SHA256,
28		KeyStore:      ks,
29		IdAttribute:   DefaultIdAttr,
30		Prefix:        DefaultPrefix,
31		Canonicalizer: MakeC14N11Canonicalizer(),
32	}
33}
34
35func (ctx *SigningContext) SetSignatureMethod(algorithmID string) error {
36	hash, ok := signatureMethodsByIdentifier[algorithmID]
37	if !ok {
38		return fmt.Errorf("Unknown SignatureMethod: %s", algorithmID)
39	}
40
41	ctx.Hash = hash
42
43	return nil
44}
45
46func (ctx *SigningContext) digest(el *etree.Element) ([]byte, error) {
47	canonical, err := ctx.Canonicalizer.Canonicalize(el)
48	if err != nil {
49		return nil, err
50	}
51
52	hash := ctx.Hash.New()
53	_, err = hash.Write(canonical)
54	if err != nil {
55		return nil, err
56	}
57
58	return hash.Sum(nil), nil
59}
60
61func (ctx *SigningContext) constructSignedInfo(el *etree.Element, enveloped bool) (*etree.Element, error) {
62	digestAlgorithmIdentifier, ok := digestAlgorithmIdentifiers[ctx.Hash]
63	if !ok {
64		return nil, errors.New("unsupported hash mechanism")
65	}
66
67	signatureMethodIdentifier, ok := signatureMethodIdentifiers[ctx.Hash]
68	if !ok {
69		return nil, errors.New("unsupported signature method")
70	}
71
72	digest, err := ctx.digest(el)
73	if err != nil {
74		return nil, err
75	}
76
77	signedInfo := &etree.Element{
78		Tag:   SignedInfoTag,
79		Space: ctx.Prefix,
80	}
81
82	// /SignedInfo/CanonicalizationMethod
83	canonicalizationMethod := ctx.createNamespacedElement(signedInfo, CanonicalizationMethodTag)
84	canonicalizationMethod.CreateAttr(AlgorithmAttr, string(ctx.Canonicalizer.Algorithm()))
85
86	// /SignedInfo/SignatureMethod
87	signatureMethod := ctx.createNamespacedElement(signedInfo, SignatureMethodTag)
88	signatureMethod.CreateAttr(AlgorithmAttr, signatureMethodIdentifier)
89
90	// /SignedInfo/Reference
91	reference := ctx.createNamespacedElement(signedInfo, ReferenceTag)
92
93	dataId := el.SelectAttrValue(ctx.IdAttribute, "")
94	if dataId == "" {
95		return nil, errors.New("Missing data ID")
96	}
97
98	reference.CreateAttr(URIAttr, "#"+dataId)
99
100	// /SignedInfo/Reference/Transforms
101	transforms := ctx.createNamespacedElement(reference, TransformsTag)
102	if enveloped {
103		envelopedTransform := ctx.createNamespacedElement(transforms, TransformTag)
104		envelopedTransform.CreateAttr(AlgorithmAttr, EnvelopedSignatureAltorithmId.String())
105	}
106	canonicalizationAlgorithm := ctx.createNamespacedElement(transforms, TransformTag)
107	canonicalizationAlgorithm.CreateAttr(AlgorithmAttr, string(ctx.Canonicalizer.Algorithm()))
108
109	// /SignedInfo/Reference/DigestMethod
110	digestMethod := ctx.createNamespacedElement(reference, DigestMethodTag)
111	digestMethod.CreateAttr(AlgorithmAttr, digestAlgorithmIdentifier)
112
113	// /SignedInfo/Reference/DigestValue
114	digestValue := ctx.createNamespacedElement(reference, DigestValueTag)
115	digestValue.SetText(base64.StdEncoding.EncodeToString(digest))
116
117	return signedInfo, nil
118}
119
120func (ctx *SigningContext) constructSignature(el *etree.Element, enveloped bool) (*etree.Element, error) {
121	signedInfo, err := ctx.constructSignedInfo(el, enveloped)
122	if err != nil {
123		return nil, err
124	}
125
126	sig := &etree.Element{
127		Tag:   SignatureTag,
128		Space: ctx.Prefix,
129	}
130
131	xmlns := "xmlns"
132	if ctx.Prefix != "" {
133		xmlns += ":" + ctx.Prefix
134	}
135
136	sig.CreateAttr(xmlns, Namespace)
137	sig.AddChild(signedInfo)
138
139	// When using xml-c14n11 (ie, non-exclusive canonicalization) the canonical form
140	// of the SignedInfo must declare all namespaces that are in scope at it's final
141	// enveloped location in the document. In order to do that, we're going to construct
142	// a series of cascading NSContexts to capture namespace declarations:
143
144	// First get the context surrounding the element we are signing.
145	rootNSCtx, err := etreeutils.NSBuildParentContext(el)
146	if err != nil {
147		return nil, err
148	}
149
150	// Then capture any declarations on the element itself.
151	elNSCtx, err := rootNSCtx.SubContext(el)
152	if err != nil {
153		return nil, err
154	}
155
156	// Followed by declarations on the Signature (which we just added above)
157	sigNSCtx, err := elNSCtx.SubContext(sig)
158	if err != nil {
159		return nil, err
160	}
161
162	// Finally detatch the SignedInfo in order to capture all of the namespace
163	// declarations in the scope we've constructed.
164	detatchedSignedInfo, err := etreeutils.NSDetatch(sigNSCtx, signedInfo)
165	if err != nil {
166		return nil, err
167	}
168
169	digest, err := ctx.digest(detatchedSignedInfo)
170	if err != nil {
171		return nil, err
172	}
173
174	key, cert, err := ctx.KeyStore.GetKeyPair()
175	if err != nil {
176		return nil, err
177	}
178
179	rawSignature, err := rsa.SignPKCS1v15(rand.Reader, key, ctx.Hash, digest)
180	if err != nil {
181		return nil, err
182	}
183
184	signatureValue := ctx.createNamespacedElement(sig, SignatureValueTag)
185	signatureValue.SetText(base64.StdEncoding.EncodeToString(rawSignature))
186
187	keyInfo := ctx.createNamespacedElement(sig, KeyInfoTag)
188	x509Data := ctx.createNamespacedElement(keyInfo, X509DataTag)
189	x509Certificate := ctx.createNamespacedElement(x509Data, X509CertificateTag)
190	x509Certificate.SetText(base64.StdEncoding.EncodeToString(cert))
191
192	return sig, nil
193}
194
195func (ctx *SigningContext) createNamespacedElement(el *etree.Element, tag string) *etree.Element {
196	child := el.CreateElement(tag)
197	child.Space = ctx.Prefix
198	return child
199}
200
201func (ctx *SigningContext) SignEnveloped(el *etree.Element) (*etree.Element, error) {
202	sig, err := ctx.constructSignature(el, true)
203	if err != nil {
204		return nil, err
205	}
206
207	ret := el.Copy()
208	ret.Child = append(ret.Child, sig)
209
210	return ret, nil
211}
212