1package pki 2 3import ( 4 "context" 5 "fmt" 6 "reflect" 7 "testing" 8 9 "strings" 10 11 "github.com/hashicorp/vault/sdk/framework" 12 "github.com/hashicorp/vault/sdk/logical" 13) 14 15func TestPki_FetchCertBySerial(t *testing.T) { 16 storage := &logical.InmemStorage{} 17 18 cases := map[string]struct { 19 Req *logical.Request 20 Prefix string 21 Serial string 22 }{ 23 "valid cert": { 24 &logical.Request{ 25 Storage: storage, 26 }, 27 "certs/", 28 "00:00:00:00:00:00:00:00", 29 }, 30 "revoked cert": { 31 &logical.Request{ 32 Storage: storage, 33 }, 34 "revoked/", 35 "11:11:11:11:11:11:11:11", 36 }, 37 } 38 39 // Test for colon-based paths in storage 40 for name, tc := range cases { 41 storageKey := fmt.Sprintf("%s%s", tc.Prefix, tc.Serial) 42 err := storage.Put(context.Background(), &logical.StorageEntry{ 43 Key: storageKey, 44 Value: []byte("some data"), 45 }) 46 if err != nil { 47 t.Fatalf("error writing to storage on %s colon-based storage path: %s", name, err) 48 } 49 50 certEntry, err := fetchCertBySerial(context.Background(), tc.Req, tc.Prefix, tc.Serial) 51 if err != nil { 52 t.Fatalf("error on %s for colon-based storage path: %s", name, err) 53 } 54 55 // Check for non-nil on valid/revoked certs 56 if certEntry == nil { 57 t.Fatalf("nil on %s for colon-based storage path", name) 58 } 59 60 // Ensure that cert serials are converted/updated after fetch 61 expectedKey := tc.Prefix + normalizeSerial(tc.Serial) 62 se, err := storage.Get(context.Background(), expectedKey) 63 if err != nil { 64 t.Fatalf("error on %s for colon-based storage path:%s", name, err) 65 } 66 if strings.Compare(expectedKey, se.Key) != 0 { 67 t.Fatalf("expected: %s, got: %s", expectedKey, certEntry.Key) 68 } 69 } 70 71 // Reset storage 72 storage = &logical.InmemStorage{} 73 74 // Test for hyphen-base paths in storage 75 for name, tc := range cases { 76 storageKey := tc.Prefix + normalizeSerial(tc.Serial) 77 err := storage.Put(context.Background(), &logical.StorageEntry{ 78 Key: storageKey, 79 Value: []byte("some data"), 80 }) 81 if err != nil { 82 t.Fatalf("error writing to storage on %s hyphen-based storage path: %s", name, err) 83 } 84 85 certEntry, err := fetchCertBySerial(context.Background(), tc.Req, tc.Prefix, tc.Serial) 86 if err != nil || certEntry == nil { 87 t.Fatalf("error on %s for hyphen-based storage path: err: %v, entry: %v", name, err, certEntry) 88 } 89 } 90 91 noConvCases := map[string]struct { 92 Req *logical.Request 93 Prefix string 94 Serial string 95 }{ 96 "ca": { 97 &logical.Request{ 98 Storage: storage, 99 }, 100 "", 101 "ca", 102 }, 103 "crl": { 104 &logical.Request{ 105 Storage: storage, 106 }, 107 "", 108 "crl", 109 }, 110 } 111 112 // Test for ca and crl case 113 for name, tc := range noConvCases { 114 err := storage.Put(context.Background(), &logical.StorageEntry{ 115 Key: tc.Serial, 116 Value: []byte("some data"), 117 }) 118 if err != nil { 119 t.Fatalf("error writing to storage on %s: %s", name, err) 120 } 121 122 certEntry, err := fetchCertBySerial(context.Background(), tc.Req, tc.Prefix, tc.Serial) 123 if err != nil || certEntry == nil { 124 t.Fatalf("error on %s: err: %v, entry: %v", name, err, certEntry) 125 } 126 } 127} 128 129// Demonstrate that multiple OUs in the name are handled in an 130// order-preserving way. 131func TestPki_MultipleOUs(t *testing.T) { 132 var b backend 133 fields := addCACommonFields(map[string]*framework.FieldSchema{}) 134 135 apiData := &framework.FieldData{ 136 Schema: fields, 137 Raw: map[string]interface{}{ 138 "cn": "example.com", 139 "ttl": 3600, 140 }, 141 } 142 input := &inputBundle{ 143 apiData: apiData, 144 role: &roleEntry{ 145 MaxTTL: 3600, 146 OU: []string{"Z", "E", "V"}, 147 }, 148 } 149 cb, err := generateCreationBundle(&b, input, nil, nil) 150 if err != nil { 151 t.Fatalf("Error: %v", err) 152 } 153 154 expected := []string{"Z", "E", "V"} 155 actual := cb.Params.Subject.OrganizationalUnit 156 157 if !reflect.DeepEqual(expected, actual) { 158 t.Fatalf("Expected %v, got %v", expected, actual) 159 } 160} 161 162func TestPki_PermitFQDNs(t *testing.T) { 163 var b backend 164 fields := addCACommonFields(map[string]*framework.FieldSchema{}) 165 166 apiData := &framework.FieldData{ 167 Schema: fields, 168 Raw: map[string]interface{}{ 169 "common_name": "example.com.", 170 "ttl": 3600, 171 }, 172 } 173 input := &inputBundle{ 174 apiData: apiData, 175 role: &roleEntry{ 176 AllowAnyName: true, 177 MaxTTL: 3600, 178 EnforceHostnames: true, 179 }, 180 } 181 cb, err := generateCreationBundle(&b, input, nil, nil) 182 if err != nil { 183 t.Fatalf("Error: %v", err) 184 } 185 186 expected := []string{"example.com."} 187 actual := cb.Params.DNSNames 188 189 if !reflect.DeepEqual(expected, actual) { 190 t.Fatalf("Expected %v, got %v", expected, actual) 191 } 192} 193