1// Copyright (c) 2017-2021 Snowflake Computing Inc. All right reserved. 2// +build go1.10 3 4package gosnowflake 5 6// This file contains variables or functions of test cases that we want to run for go version >= 1.10 7 8// For compile concern, should any newly added variables or functions here must also be added with same 9// name or signature but with default or empty content in the priv_key_test.go(See addParseDSNTest) 10 11import ( 12 "bytes" 13 "crypto/rand" 14 "crypto/rsa" 15 "crypto/x509" 16 "database/sql" 17 "encoding/base64" 18 "encoding/pem" 19 "fmt" 20 "io/ioutil" 21 "os" 22 "testing" 23) 24 25// helper function to generate PKCS8 encoded base64 string of a private key 26func generatePKCS8StringSupress(key *rsa.PrivateKey) string { 27 // Error would only be thrown when the private key type is not supported 28 // We would be safe as long as we are using rsa.PrivateKey 29 tmpBytes, _ := x509.MarshalPKCS8PrivateKey(key) 30 privKeyPKCS8 := base64.URLEncoding.EncodeToString(tmpBytes) 31 return privKeyPKCS8 32} 33 34// helper function to generate PKCS1 encoded base64 string of a private key 35func generatePKCS1String(key *rsa.PrivateKey) string { 36 tmpBytes := x509.MarshalPKCS1PrivateKey(key) 37 privKeyPKCS1 := base64.URLEncoding.EncodeToString(tmpBytes) 38 return privKeyPKCS1 39} 40 41// helper function to set up private key for testing 42func setupPrivateKey() { 43 env := func(key, defaultValue string) string { 44 if value := os.Getenv(key); value != "" { 45 return value 46 } 47 return defaultValue 48 } 49 privKeyPath := env("SNOWFLAKE_TEST_PRIVATE_KEY", "") 50 if privKeyPath == "" { 51 customPrivateKey = false 52 testPrivKey, _ = rsa.GenerateKey(rand.Reader, 2048) 53 } else { 54 // path to the DER file 55 customPrivateKey = true 56 data, _ := ioutil.ReadFile(privKeyPath) 57 block, _ := pem.Decode(data) 58 if block == nil || block.Type != "PRIVATE KEY" { 59 panic(fmt.Sprintf("%v is not a public key in PEM format.", privKeyPath)) 60 } 61 privKey, _ := x509.ParsePKCS8PrivateKey(block.Bytes) 62 testPrivKey = privKey.(*rsa.PrivateKey) 63 } 64} 65 66// Helper function to add encoded private key to dsn 67func appendPrivateKeyString(dsn *string, key *rsa.PrivateKey) string { 68 var b bytes.Buffer 69 b.WriteString(*dsn) 70 b.WriteString(fmt.Sprintf("&authenticator=%v", AuthTypeJwt.String())) 71 b.WriteString(fmt.Sprintf("&privateKey=%s", generatePKCS8StringSupress(key))) 72 return b.String() 73} 74 75// Integration test for the JWT authentication function 76func TestJWTAuthentication(t *testing.T) { 77 // For private key generated on the fly, we want to load the public key to the server first 78 if !customPrivateKey { 79 db := openDB(t) 80 // Load server's public key to database 81 pubKeyByte, err := x509.MarshalPKIXPublicKey(testPrivKey.Public()) 82 if err != nil { 83 t.Fatalf("error marshaling public key: %s", err.Error()) 84 } 85 if _, err = db.Exec("USE ROLE ACCOUNTADMIN"); err != nil { 86 t.Fatalf("error changin role: %s", err.Error()) 87 } 88 encodedKey := base64.StdEncoding.EncodeToString(pubKeyByte) 89 if _, err = db.Exec(fmt.Sprintf("ALTER USER %v set rsa_public_key='%v'", user, encodedKey)); err != nil { 90 t.Fatalf("error setting server's public key: %s", err.Error()) 91 } 92 db.Close() 93 } 94 95 // Test that a valid private key can pass 96 jwtDSN := appendPrivateKeyString(&dsn, testPrivKey) 97 db, err := sql.Open("snowflake", jwtDSN) 98 if err != nil { 99 t.Fatalf("error creating a connection object: %s", err.Error()) 100 } 101 if _, err = db.Exec("SELECT 1"); err != nil { 102 t.Fatalf("error executing: %s", err.Error()) 103 } 104 db.Close() 105 106 // Test that an invalid private key cannot pass 107 invalidPrivateKey, _ := rsa.GenerateKey(rand.Reader, 2048) 108 jwtDSN = appendPrivateKeyString(&dsn, invalidPrivateKey) 109 db, _ = sql.Open("snowflake", jwtDSN) 110 if _, err = db.Exec("SELECT 1"); err == nil { 111 t.Fatalf("An invalid jwt token can pass") 112 } 113 114 db.Close() 115} 116