1package pki
2
3import (
4	"context"
5	"fmt"
6	"reflect"
7	"testing"
8
9	"strings"
10
11	"github.com/hashicorp/vault/sdk/framework"
12	"github.com/hashicorp/vault/sdk/logical"
13)
14
15func TestPki_FetchCertBySerial(t *testing.T) {
16	storage := &logical.InmemStorage{}
17
18	cases := map[string]struct {
19		Req    *logical.Request
20		Prefix string
21		Serial string
22	}{
23		"valid cert": {
24			&logical.Request{
25				Storage: storage,
26			},
27			"certs/",
28			"00:00:00:00:00:00:00:00",
29		},
30		"revoked cert": {
31			&logical.Request{
32				Storage: storage,
33			},
34			"revoked/",
35			"11:11:11:11:11:11:11:11",
36		},
37	}
38
39	// Test for colon-based paths in storage
40	for name, tc := range cases {
41		storageKey := fmt.Sprintf("%s%s", tc.Prefix, tc.Serial)
42		err := storage.Put(context.Background(), &logical.StorageEntry{
43			Key:   storageKey,
44			Value: []byte("some data"),
45		})
46		if err != nil {
47			t.Fatalf("error writing to storage on %s colon-based storage path: %s", name, err)
48		}
49
50		certEntry, err := fetchCertBySerial(context.Background(), tc.Req, tc.Prefix, tc.Serial)
51		if err != nil {
52			t.Fatalf("error on %s for colon-based storage path: %s", name, err)
53		}
54
55		// Check for non-nil on valid/revoked certs
56		if certEntry == nil {
57			t.Fatalf("nil on %s for colon-based storage path", name)
58		}
59
60		// Ensure that cert serials are converted/updated after fetch
61		expectedKey := tc.Prefix + normalizeSerial(tc.Serial)
62		se, err := storage.Get(context.Background(), expectedKey)
63		if err != nil {
64			t.Fatalf("error on %s for colon-based storage path:%s", name, err)
65		}
66		if strings.Compare(expectedKey, se.Key) != 0 {
67			t.Fatalf("expected: %s, got: %s", expectedKey, certEntry.Key)
68		}
69	}
70
71	// Reset storage
72	storage = &logical.InmemStorage{}
73
74	// Test for hyphen-base paths in storage
75	for name, tc := range cases {
76		storageKey := tc.Prefix + normalizeSerial(tc.Serial)
77		err := storage.Put(context.Background(), &logical.StorageEntry{
78			Key:   storageKey,
79			Value: []byte("some data"),
80		})
81		if err != nil {
82			t.Fatalf("error writing to storage on %s hyphen-based storage path: %s", name, err)
83		}
84
85		certEntry, err := fetchCertBySerial(context.Background(), tc.Req, tc.Prefix, tc.Serial)
86		if err != nil || certEntry == nil {
87			t.Fatalf("error on %s for hyphen-based storage path: err: %v, entry: %v", name, err, certEntry)
88		}
89	}
90
91	noConvCases := map[string]struct {
92		Req    *logical.Request
93		Prefix string
94		Serial string
95	}{
96		"ca": {
97			&logical.Request{
98				Storage: storage,
99			},
100			"",
101			"ca",
102		},
103		"crl": {
104			&logical.Request{
105				Storage: storage,
106			},
107			"",
108			"crl",
109		},
110	}
111
112	// Test for ca and crl case
113	for name, tc := range noConvCases {
114		err := storage.Put(context.Background(), &logical.StorageEntry{
115			Key:   tc.Serial,
116			Value: []byte("some data"),
117		})
118		if err != nil {
119			t.Fatalf("error writing to storage on %s: %s", name, err)
120		}
121
122		certEntry, err := fetchCertBySerial(context.Background(), tc.Req, tc.Prefix, tc.Serial)
123		if err != nil || certEntry == nil {
124			t.Fatalf("error on %s: err: %v, entry: %v", name, err, certEntry)
125		}
126	}
127}
128
129// Demonstrate that multiple OUs in the name are handled in an
130// order-preserving way.
131func TestPki_MultipleOUs(t *testing.T) {
132	var b backend
133	fields := addCACommonFields(map[string]*framework.FieldSchema{})
134
135	apiData := &framework.FieldData{
136		Schema: fields,
137		Raw: map[string]interface{}{
138			"cn":  "example.com",
139			"ttl": 3600,
140		},
141	}
142	input := &inputBundle{
143		apiData: apiData,
144		role: &roleEntry{
145			MaxTTL: 3600,
146			OU:     []string{"Z", "E", "V"},
147		},
148	}
149	cb, err := generateCreationBundle(&b, input, nil, nil)
150	if err != nil {
151		t.Fatalf("Error: %v", err)
152	}
153
154	expected := []string{"Z", "E", "V"}
155	actual := cb.Params.Subject.OrganizationalUnit
156
157	if !reflect.DeepEqual(expected, actual) {
158		t.Fatalf("Expected %v, got %v", expected, actual)
159	}
160}
161
162func TestPki_PermitFQDNs(t *testing.T) {
163	var b backend
164	fields := addCACommonFields(map[string]*framework.FieldSchema{})
165
166	apiData := &framework.FieldData{
167		Schema: fields,
168		Raw: map[string]interface{}{
169			"common_name": "example.com.",
170			"ttl":         3600,
171		},
172	}
173	input := &inputBundle{
174		apiData: apiData,
175		role: &roleEntry{
176			AllowAnyName:     true,
177			MaxTTL:           3600,
178			EnforceHostnames: true,
179		},
180	}
181	cb, err := generateCreationBundle(&b, input, nil, nil)
182	if err != nil {
183		t.Fatalf("Error: %v", err)
184	}
185
186	expected := []string{"example.com."}
187	actual := cb.Params.DNSNames
188
189	if !reflect.DeepEqual(expected, actual) {
190		t.Fatalf("Expected %v, got %v", expected, actual)
191	}
192}
193