1// Copyright (c) 2017-2019 Snowflake Computing Inc. All right reserved. 2 3package gosnowflake 4 5import ( 6 "context" 7 "errors" 8 "net/http" 9 "net/url" 10 "testing" 11 "time" 12) 13 14func TestUnitPostBackURL(t *testing.T) { 15 c := `<html><form id="1" action="https://abc.com/"></form></html>` 16 pbURL, err := postBackURL([]byte(c)) 17 if err != nil { 18 t.Fatalf("failed to get URL. err: %v, %v", err, c) 19 } 20 if pbURL.String() != "https://abc.com/" { 21 t.Errorf("failed to get URL. got: %v, %v", pbURL, c) 22 } 23 c = `<html></html>` 24 _, err = postBackURL([]byte(c)) 25 if err == nil { 26 t.Fatalf("should have failed") 27 } 28 c = `<html><form id="1"/></html>` 29 _, err = postBackURL([]byte(c)) 30 if err == nil { 31 t.Fatalf("should have failed") 32 } 33 c = `<html><form id="1" action="https://abc.com//></html>` 34 _, err = postBackURL([]byte(c)) 35 if err == nil { 36 t.Fatalf("should have failed") 37 } 38} 39 40func getTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) { 41 return &http.Response{ 42 StatusCode: http.StatusOK, 43 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 44 }, errors.New("failed to run post method") 45} 46 47func getTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) { 48 return &http.Response{ 49 StatusCode: http.StatusBadGateway, 50 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 51 }, nil 52} 53 54func getTestHTMLSuccess(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) { 55 return &http.Response{ 56 StatusCode: http.StatusOK, 57 Body: &fakeResponseBody{body: []byte("<htm></html>")}, 58 }, nil 59} 60 61func TestUnitPostAuthSAML(t *testing.T) { 62 sr := &snowflakeRestful{ 63 FuncPost: postTestError, 64 TokenAccessor: getSimpleTokenAccessor(), 65 } 66 var err error 67 _, err = postAuthSAML(context.TODO(), sr, make(map[string]string), []byte{}, 0) 68 if err == nil { 69 t.Fatal("should have failed.") 70 } 71 sr.FuncPost = postTestAppBadGatewayError 72 _, err = postAuthSAML(context.TODO(), sr, make(map[string]string), []byte{}, 0) 73 if err == nil { 74 t.Fatal("should have failed.") 75 } 76 sr.FuncPost = postTestSuccessButInvalidJSON 77 _, err = postAuthSAML(context.TODO(), sr, make(map[string]string), []byte{0x12, 0x34}, 0) 78 if err == nil { 79 t.Fatalf("should have failed to post") 80 } 81} 82 83func TestUnitPostAuthOKTA(t *testing.T) { 84 sr := &snowflakeRestful{ 85 FuncPost: postTestError, 86 TokenAccessor: getSimpleTokenAccessor(), 87 } 88 var err error 89 _, err = postAuthOKTA(context.TODO(), sr, make(map[string]string), []byte{}, "hahah", 0) 90 if err == nil { 91 t.Fatal("should have failed.") 92 } 93 sr.FuncPost = postTestAppBadGatewayError 94 _, err = postAuthOKTA(context.TODO(), sr, make(map[string]string), []byte{}, "hahah", 0) 95 if err == nil { 96 t.Fatal("should have failed.") 97 } 98 sr.FuncPost = postTestSuccessButInvalidJSON 99 _, err = postAuthOKTA(context.TODO(), sr, make(map[string]string), []byte{0x12, 0x34}, "haha", 0) 100 if err == nil { 101 t.Fatal("should have failed to run post request after the renewal") 102 } 103} 104 105func TestUnitGetSSO(t *testing.T) { 106 sr := &snowflakeRestful{ 107 FuncGet: getTestError, 108 TokenAccessor: getSimpleTokenAccessor(), 109 } 110 var err error 111 _, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "hahah", 0) 112 if err == nil { 113 t.Fatal("should have failed.") 114 } 115 sr.FuncGet = getTestAppBadGatewayError 116 _, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "hahah", 0) 117 if err == nil { 118 t.Fatal("should have failed.") 119 } 120 sr.FuncGet = getTestHTMLSuccess 121 _, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "hahah", 0) 122 if err != nil { 123 t.Fatalf("failed to get HTML content. err: %v", err) 124 } 125} 126 127func postAuthSAMLError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { 128 return &authResponse{}, errors.New("failed to get SAML response") 129} 130 131func postAuthSAMLAuthFail(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { 132 return &authResponse{ 133 Success: false, 134 Message: "SAML auth failed", 135 }, nil 136} 137 138func postAuthSAMLAuthSuccessButInvalidURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { 139 return &authResponse{ 140 Success: true, 141 Message: "", 142 Data: authResponseMain{ 143 TokenURL: "https://1abc.com/token", 144 SSOURL: "https://2abc.com/sso", 145 }, 146 }, nil 147} 148 149func postAuthSAMLAuthSuccess(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { 150 return &authResponse{ 151 Success: true, 152 Message: "", 153 Data: authResponseMain{ 154 TokenURL: "https://abc.com/token", 155 SSOURL: "https://abc.com/sso", 156 }, 157 }, nil 158} 159 160func postAuthOKTAError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ string, _ time.Duration) (*authOKTAResponse, error) { 161 return &authOKTAResponse{}, errors.New("failed to get SAML response") 162} 163 164func postAuthOKTASuccess(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ string, _ time.Duration) (*authOKTAResponse, error) { 165 return &authOKTAResponse{}, nil 166} 167 168func getSSOError(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) { 169 return []byte{}, errors.New("failed to get SSO html") 170} 171 172func getSSOSuccessButInvalidURL(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) { 173 return []byte(`<html><form id="1"/></html>`), nil 174} 175 176func getSSOSuccess(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) { 177 return []byte(`<html><form id="1" action="https://abc.com/"></form></html>`), nil 178} 179 180func TestUnitAuthenticateBySAML(t *testing.T) { 181 authenticator := &url.URL{ 182 Scheme: "https", 183 Host: "abc.com", 184 } 185 application := "testapp" 186 account := "testaccount" 187 user := "u" 188 password := "p" 189 sr := &snowflakeRestful{ 190 Protocol: "https", 191 Host: "abc.com", 192 Port: 443, 193 FuncPostAuthSAML: postAuthSAMLError, 194 TokenAccessor: getSimpleTokenAccessor(), 195 } 196 var err error 197 _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) 198 if err == nil { 199 t.Fatal("should have failed.") 200 } 201 sr.FuncPostAuthSAML = postAuthSAMLAuthFail 202 _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) 203 if err == nil { 204 t.Fatal("should have failed.") 205 } 206 sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL 207 _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) 208 if err == nil { 209 t.Fatal("should have failed.") 210 } 211 driverErr, ok := err.(*SnowflakeError) 212 if !ok { 213 t.Fatalf("should be snowflake error. err: %v", err) 214 } 215 if driverErr.Number != ErrCodeIdpConnectionError { 216 t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeIdpConnectionError, driverErr.Number) 217 } 218 sr.FuncPostAuthSAML = postAuthSAMLAuthSuccess 219 sr.FuncPostAuthOKTA = postAuthOKTAError 220 _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) 221 if err == nil { 222 t.Fatal("should have failed.") 223 } 224 sr.FuncPostAuthOKTA = postAuthOKTASuccess 225 sr.FuncGetSSO = getSSOError 226 _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) 227 if err == nil { 228 t.Fatal("should have failed.") 229 } 230 sr.FuncGetSSO = getSSOSuccessButInvalidURL 231 _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) 232 if err == nil { 233 t.Fatal("should have failed.") 234 } 235 sr.FuncGetSSO = getSSOSuccess 236 _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) 237 if err != nil { 238 t.Fatalf("failed. err: %v", err) 239 } 240} 241