1package samlidp
2
3import (
4	"bytes"
5	"encoding/xml"
6	"errors"
7	"io"
8	"io/ioutil"
9
10	xrv "github.com/mattermost/xml-roundtrip-validator"
11
12	"github.com/crewjam/saml"
13)
14
15func randomBytes(n int) []byte {
16	rv := make([]byte, n)
17	if _, err := saml.RandReader.Read(rv); err != nil {
18		panic(err)
19	}
20	return rv
21}
22
23func getSPMetadata(r io.Reader) (spMetadata *saml.EntityDescriptor, err error) {
24	var data []byte
25	if data, err = ioutil.ReadAll(r); err != nil {
26		return nil, err
27	}
28
29	spMetadata = &saml.EntityDescriptor{}
30	if err := xrv.Validate(bytes.NewBuffer(data)); err != nil {
31		return nil, err
32	}
33
34	if err := xml.Unmarshal(data, &spMetadata); err != nil {
35		if err.Error() == "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
36			entities := &saml.EntitiesDescriptor{}
37			if err := xml.Unmarshal(data, &entities); err != nil {
38				return nil, err
39			}
40
41			for _, e := range entities.EntityDescriptors {
42				if len(e.SPSSODescriptors) > 0 {
43					return &e, nil
44				}
45			}
46
47			// there were no SPSSODescriptors in the response
48			return nil, errors.New("metadata contained no service provider metadata")
49		}
50
51		return nil, err
52	}
53
54	return spMetadata, nil
55}
56