1// Copyright The OpenTelemetry Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package oidcauthextension
16
17import (
18	"bytes"
19	"crypto"
20	"crypto/rand"
21	"crypto/rsa"
22	"crypto/sha1" // #nosec
23	"crypto/sha256"
24	"crypto/x509"
25	"crypto/x509/pkix"
26	"encoding/base64"
27	"encoding/binary"
28	"encoding/json"
29	"fmt"
30	"math/big"
31	"net/http"
32	"net/http/httptest"
33	"time"
34)
35
// oidcServer is an overly simplified OIDC mock server, good enough to sign the
// tokens required by the tests and to pass the verification done by the
// underlying libraries. It embeds the httptest server that serves the
// discovery and JWKS documents registered by newOIDCServer.
type oidcServer struct {
	// Server is the HTTP test server hosting the mock endpoints.
	*httptest.Server
	// x509Cert is the certificate produced by createCertificate for privateKey;
	// it is published in the "x5c" entry of the JWKS document.
	x509Cert   []byte
	// privateKey is the RSA key used to sign the tokens issued by token.
	privateKey *rsa.PrivateKey
}
43
44func newOIDCServer() (*oidcServer, error) {
45	jwks := map[string]interface{}{}
46
47	mux := http.NewServeMux()
48	server := httptest.NewUnstartedServer(mux)
49
50	mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, req *http.Request) {
51		w.Header().Set("Content-Type", "application/json; charset=utf-8")
52		err := json.NewEncoder(w).Encode(map[string]interface{}{
53			"issuer":   server.URL,
54			"jwks_uri": fmt.Sprintf("%s/.well-known/jwks.json", server.URL),
55		})
56		if err != nil {
57			w.WriteHeader(http.StatusInternalServerError)
58			return
59		}
60	})
61	mux.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, req *http.Request) {
62		w.Header().Set("Content-Type", "application/json; charset=utf-8")
63		if err := json.NewEncoder(w).Encode(jwks); err != nil {
64			w.WriteHeader(http.StatusInternalServerError)
65			return
66		}
67	})
68
69	privateKey, err := createPrivateKey()
70	if err != nil {
71		return nil, err
72	}
73
74	x509Cert, err := createCertificate(privateKey)
75	if err != nil {
76		return nil, err
77	}
78
79	eBytes := make([]byte, 8)
80	binary.BigEndian.PutUint64(eBytes, uint64(privateKey.E))
81	eBytes = bytes.TrimLeft(eBytes, "\x00")
82
83	// #nosec
84	sum := sha1.Sum(x509Cert)
85	jwks["keys"] = []map[string]interface{}{{
86		"alg": "RS256",
87		"kty": "RSA",
88		"use": "sig",
89		"x5c": []string{base64.StdEncoding.EncodeToString(x509Cert)},
90		"n":   base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()),
91		"e":   base64.RawURLEncoding.EncodeToString(eBytes),
92		"kid": base64.RawURLEncoding.EncodeToString(sum[:]),
93		"x5t": base64.RawURLEncoding.EncodeToString(sum[:]),
94	}}
95
96	return &oidcServer{server, x509Cert, privateKey}, nil
97}
98
99func (s *oidcServer) token(jsonPayload []byte) (string, error) {
100	jsonHeader, _ := json.Marshal(map[string]interface{}{
101		"alg": "RS256",
102		"typ": "JWT",
103	})
104
105	header := base64.RawURLEncoding.EncodeToString(jsonHeader)
106	payload := base64.RawURLEncoding.EncodeToString(jsonPayload)
107	digest := sha256.Sum256([]byte(fmt.Sprintf("%s.%s", header, payload)))
108
109	signature, err := rsa.SignPKCS1v15(rand.Reader, s.privateKey, crypto.SHA256, digest[:])
110	if err != nil {
111		return "", err
112	}
113
114	encodedSignature := base64.RawURLEncoding.EncodeToString(signature)
115	token := fmt.Sprintf("%s.%s.%s", header, payload, encodedSignature)
116	return token, nil
117}
118
// createCertificate issues a self-signed certificate for the public half of
// privateKey, signed with privateKey itself. The certificate has serial
// number 1, organization "Ecorp, Inc" and is valid for five minutes starting
// at the time of the call. The encoded certificate bytes are returned.
func createCertificate(privateKey *rsa.PrivateKey) ([]byte, error) {
	template := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{Organization: []string{"Ecorp, Inc"}},
		NotBefore:    time.Now(),
		NotAfter:     time.Now().Add(5 * time.Minute),
	}

	// Self-signed: the template acts as both the subject and the issuer.
	encoded, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
	if err != nil {
		return nil, err
	}
	return encoded, nil
}
136
// createPrivateKey generates a new 2048-bit RSA key pair using the
// cryptographically secure random source from crypto/rand.
func createPrivateKey() (*rsa.PrivateKey, error) {
	const keyBits = 2048
	return rsa.GenerateKey(rand.Reader, keyBits)
}
144