1// Copyright (c) 2017-2021 Snowflake Computing Inc. All right reserved.
2// +build go1.10
3
4package gosnowflake
5
// This file contains variables and functions for test cases that should only run on Go version >= 1.10.
7
// To keep builds for older Go versions compiling, every variable or function added here must also be
// declared, with the same name and signature but default or empty content, in priv_key_test.go (see addParseDSNTest).
10
11import (
12	"bytes"
13	"crypto/rand"
14	"crypto/rsa"
15	"crypto/x509"
16	"database/sql"
17	"encoding/base64"
18	"encoding/pem"
19	"fmt"
20	"io/ioutil"
21	"os"
22	"testing"
23)
24
// generatePKCS8StringSupress serializes key as PKCS #8 DER and returns it
// encoded with URL-safe base64 (padded).
//
// The marshal error is intentionally discarded. It is only expected for key
// types the encoder does not support, or (on newer Go releases) a malformed
// RSA key. Callers in this file pass keys produced by rsa.GenerateKey or
// x509.ParsePKCS8PrivateKey.
func generatePKCS8StringSupress(key *rsa.PrivateKey) string {
	der, _ := x509.MarshalPKCS8PrivateKey(key)
	return base64.URLEncoding.EncodeToString(der)
}
33
// generatePKCS1String serializes key as PKCS #1 DER and returns it encoded
// with URL-safe base64 (padded).
func generatePKCS1String(key *rsa.PrivateKey) string {
	return base64.URLEncoding.EncodeToString(x509.MarshalPKCS1PrivateKey(key))
}
40
41// helper function to set up private key for testing
42func setupPrivateKey() {
43	env := func(key, defaultValue string) string {
44		if value := os.Getenv(key); value != "" {
45			return value
46		}
47		return defaultValue
48	}
49	privKeyPath := env("SNOWFLAKE_TEST_PRIVATE_KEY", "")
50	if privKeyPath == "" {
51		customPrivateKey = false
52		testPrivKey, _ = rsa.GenerateKey(rand.Reader, 2048)
53	} else {
54		// path to the DER file
55		customPrivateKey = true
56		data, _ := ioutil.ReadFile(privKeyPath)
57		block, _ := pem.Decode(data)
58		if block == nil || block.Type != "PRIVATE KEY" {
59			panic(fmt.Sprintf("%v is not a public key in PEM format.", privKeyPath))
60		}
61		privKey, _ := x509.ParsePKCS8PrivateKey(block.Bytes)
62		testPrivKey = privKey.(*rsa.PrivateKey)
63	}
64}
65
66// Helper function to add encoded private key to dsn
67func appendPrivateKeyString(dsn *string, key *rsa.PrivateKey) string {
68	var b bytes.Buffer
69	b.WriteString(*dsn)
70	b.WriteString(fmt.Sprintf("&authenticator=%v", AuthTypeJwt.String()))
71	b.WriteString(fmt.Sprintf("&privateKey=%s", generatePKCS8StringSupress(key)))
72	return b.String()
73}
74
75// Integration test for the JWT authentication function
76func TestJWTAuthentication(t *testing.T) {
77	// For private key generated on the fly, we want to load the public key to the server first
78	if !customPrivateKey {
79		db := openDB(t)
80		// Load server's public key to database
81		pubKeyByte, err := x509.MarshalPKIXPublicKey(testPrivKey.Public())
82		if err != nil {
83			t.Fatalf("error marshaling public key: %s", err.Error())
84		}
85		if _, err = db.Exec("USE ROLE ACCOUNTADMIN"); err != nil {
86			t.Fatalf("error changin role: %s", err.Error())
87		}
88		encodedKey := base64.StdEncoding.EncodeToString(pubKeyByte)
89		if _, err = db.Exec(fmt.Sprintf("ALTER USER %v set rsa_public_key='%v'", user, encodedKey)); err != nil {
90			t.Fatalf("error setting server's public key: %s", err.Error())
91		}
92		db.Close()
93	}
94
95	// Test that a valid private key can pass
96	jwtDSN := appendPrivateKeyString(&dsn, testPrivKey)
97	db, err := sql.Open("snowflake", jwtDSN)
98	if err != nil {
99		t.Fatalf("error creating a connection object: %s", err.Error())
100	}
101	if _, err = db.Exec("SELECT 1"); err != nil {
102		t.Fatalf("error executing: %s", err.Error())
103	}
104	db.Close()
105
106	// Test that an invalid private key cannot pass
107	invalidPrivateKey, _ := rsa.GenerateKey(rand.Reader, 2048)
108	jwtDSN = appendPrivateKeyString(&dsn, invalidPrivateKey)
109	db, _ = sql.Open("snowflake", jwtDSN)
110	if _, err = db.Exec("SELECT 1"); err == nil {
111		t.Fatalf("An invalid jwt token can pass")
112	}
113
114	db.Close()
115}
116