// Copyright (c) 2017-2019 Snowflake Computing Inc. All rights reserved.
2
3package gosnowflake
4
5import (
6	"context"
7	"errors"
8	"net/http"
9	"net/url"
10	"testing"
11	"time"
12)
13
14func TestUnitPostBackURL(t *testing.T) {
15	c := `<html><form id="1" action="https&#x3a;&#x2f;&#x2f;abc.com&#x2f;"></form></html>`
16	pbURL, err := postBackURL([]byte(c))
17	if err != nil {
18		t.Fatalf("failed to get URL. err: %v, %v", err, c)
19	}
20	if pbURL.String() != "https://abc.com/" {
21		t.Errorf("failed to get URL. got: %v, %v", pbURL, c)
22	}
23	c = `<html></html>`
24	_, err = postBackURL([]byte(c))
25	if err == nil {
26		t.Fatalf("should have failed")
27	}
28	c = `<html><form id="1"/></html>`
29	_, err = postBackURL([]byte(c))
30	if err == nil {
31		t.Fatalf("should have failed")
32	}
33	c = `<html><form id="1" action="https&#x3a;&#x2f;&#x2f;abc.com&#x2f;/></html>`
34	_, err = postBackURL([]byte(c))
35	if err == nil {
36		t.Fatalf("should have failed")
37	}
38}
39
40func getTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) {
41	return &http.Response{
42		StatusCode: http.StatusOK,
43		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
44	}, errors.New("failed to run post method")
45}
46
47func getTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) {
48	return &http.Response{
49		StatusCode: http.StatusBadGateway,
50		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
51	}, nil
52}
53
54func getTestHTMLSuccess(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) {
55	return &http.Response{
56		StatusCode: http.StatusOK,
57		Body:       &fakeResponseBody{body: []byte("<htm></html>")},
58	}, nil
59}
60
61func TestUnitPostAuthSAML(t *testing.T) {
62	sr := &snowflakeRestful{
63		FuncPost:      postTestError,
64		TokenAccessor: getSimpleTokenAccessor(),
65	}
66	var err error
67	_, err = postAuthSAML(context.TODO(), sr, make(map[string]string), []byte{}, 0)
68	if err == nil {
69		t.Fatal("should have failed.")
70	}
71	sr.FuncPost = postTestAppBadGatewayError
72	_, err = postAuthSAML(context.TODO(), sr, make(map[string]string), []byte{}, 0)
73	if err == nil {
74		t.Fatal("should have failed.")
75	}
76	sr.FuncPost = postTestSuccessButInvalidJSON
77	_, err = postAuthSAML(context.TODO(), sr, make(map[string]string), []byte{0x12, 0x34}, 0)
78	if err == nil {
79		t.Fatalf("should have failed to post")
80	}
81}
82
83func TestUnitPostAuthOKTA(t *testing.T) {
84	sr := &snowflakeRestful{
85		FuncPost:      postTestError,
86		TokenAccessor: getSimpleTokenAccessor(),
87	}
88	var err error
89	_, err = postAuthOKTA(context.TODO(), sr, make(map[string]string), []byte{}, "hahah", 0)
90	if err == nil {
91		t.Fatal("should have failed.")
92	}
93	sr.FuncPost = postTestAppBadGatewayError
94	_, err = postAuthOKTA(context.TODO(), sr, make(map[string]string), []byte{}, "hahah", 0)
95	if err == nil {
96		t.Fatal("should have failed.")
97	}
98	sr.FuncPost = postTestSuccessButInvalidJSON
99	_, err = postAuthOKTA(context.TODO(), sr, make(map[string]string), []byte{0x12, 0x34}, "haha", 0)
100	if err == nil {
101		t.Fatal("should have failed to run post request after the renewal")
102	}
103}
104
105func TestUnitGetSSO(t *testing.T) {
106	sr := &snowflakeRestful{
107		FuncGet:       getTestError,
108		TokenAccessor: getSimpleTokenAccessor(),
109	}
110	var err error
111	_, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "hahah", 0)
112	if err == nil {
113		t.Fatal("should have failed.")
114	}
115	sr.FuncGet = getTestAppBadGatewayError
116	_, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "hahah", 0)
117	if err == nil {
118		t.Fatal("should have failed.")
119	}
120	sr.FuncGet = getTestHTMLSuccess
121	_, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "hahah", 0)
122	if err != nil {
123		t.Fatalf("failed to get HTML content. err: %v", err)
124	}
125}
126
127func postAuthSAMLError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
128	return &authResponse{}, errors.New("failed to get SAML response")
129}
130
131func postAuthSAMLAuthFail(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
132	return &authResponse{
133		Success: false,
134		Message: "SAML auth failed",
135	}, nil
136}
137
138func postAuthSAMLAuthSuccessButInvalidURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
139	return &authResponse{
140		Success: true,
141		Message: "",
142		Data: authResponseMain{
143			TokenURL: "https://1abc.com/token",
144			SSOURL:   "https://2abc.com/sso",
145		},
146	}, nil
147}
148
149func postAuthSAMLAuthSuccess(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
150	return &authResponse{
151		Success: true,
152		Message: "",
153		Data: authResponseMain{
154			TokenURL: "https://abc.com/token",
155			SSOURL:   "https://abc.com/sso",
156		},
157	}, nil
158}
159
160func postAuthOKTAError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ string, _ time.Duration) (*authOKTAResponse, error) {
161	return &authOKTAResponse{}, errors.New("failed to get SAML response")
162}
163
164func postAuthOKTASuccess(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ string, _ time.Duration) (*authOKTAResponse, error) {
165	return &authOKTAResponse{}, nil
166}
167
168func getSSOError(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) {
169	return []byte{}, errors.New("failed to get SSO html")
170}
171
172func getSSOSuccessButInvalidURL(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) {
173	return []byte(`<html><form id="1"/></html>`), nil
174}
175
176func getSSOSuccess(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) {
177	return []byte(`<html><form id="1" action="https&#x3a;&#x2f;&#x2f;abc.com&#x2f;"></form></html>`), nil
178}
179
180func TestUnitAuthenticateBySAML(t *testing.T) {
181	authenticator := &url.URL{
182		Scheme: "https",
183		Host:   "abc.com",
184	}
185	application := "testapp"
186	account := "testaccount"
187	user := "u"
188	password := "p"
189	sr := &snowflakeRestful{
190		Protocol:         "https",
191		Host:             "abc.com",
192		Port:             443,
193		FuncPostAuthSAML: postAuthSAMLError,
194		TokenAccessor:    getSimpleTokenAccessor(),
195	}
196	var err error
197	_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
198	if err == nil {
199		t.Fatal("should have failed.")
200	}
201	sr.FuncPostAuthSAML = postAuthSAMLAuthFail
202	_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
203	if err == nil {
204		t.Fatal("should have failed.")
205	}
206	sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL
207	_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
208	if err == nil {
209		t.Fatal("should have failed.")
210	}
211	driverErr, ok := err.(*SnowflakeError)
212	if !ok {
213		t.Fatalf("should be snowflake error. err: %v", err)
214	}
215	if driverErr.Number != ErrCodeIdpConnectionError {
216		t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeIdpConnectionError, driverErr.Number)
217	}
218	sr.FuncPostAuthSAML = postAuthSAMLAuthSuccess
219	sr.FuncPostAuthOKTA = postAuthOKTAError
220	_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
221	if err == nil {
222		t.Fatal("should have failed.")
223	}
224	sr.FuncPostAuthOKTA = postAuthOKTASuccess
225	sr.FuncGetSSO = getSSOError
226	_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
227	if err == nil {
228		t.Fatal("should have failed.")
229	}
230	sr.FuncGetSSO = getSSOSuccessButInvalidURL
231	_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
232	if err == nil {
233		t.Fatal("should have failed.")
234	}
235	sr.FuncGetSSO = getSSOSuccess
236	_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
237	if err != nil {
238		t.Fatalf("failed. err: %v", err)
239	}
240}
241