1/*
2 * Copyright © 2019-2021 A Bunch Tell LLC.
3 *
4 * This file is part of WriteFreely.
5 *
6 * WriteFreely is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Affero General Public License, included
8 * in the LICENSE file in this source code package.
9 */
10
11package writefreely
12
13import (
14	"context"
15	"fmt"
16	"github.com/gorilla/sessions"
17	"github.com/stretchr/testify/assert"
18	"github.com/writeas/impart"
19	"github.com/writeas/web-core/id"
20	"github.com/writefreely/writefreely/config"
21	"net/http"
22	"net/http/httptest"
23	"net/url"
24	"strings"
25	"testing"
26)
27
// MockOAuthDatastoreProvider is a test double that supplies the config,
// datastore and session store an oauthHandler is assembled from.
//
// Each Do* hook, when non-nil, replaces the default returned by the method of
// the same name without the "Do" prefix (DB, Config, SessionStore). When a hook
// is nil:
//   - DB returns an empty *MockOAuthDatastore;
//   - Config returns a SQLite config with placeholder "development" OAuth
//     credentials;
//   - SessionStore returns a cookie store keyed with a fixed test secret.
type MockOAuthDatastoreProvider struct {
	DoDB           func() OAuthDatastore
	DoConfig       func() *config.Config
	DoSessionStore func() sessions.Store
}
33
// MockOAuthDatastore is an in-memory OAuthDatastore for tests.
//
// Each Do* hook, when non-nil, handles calls to the method of the same name
// without the "Do" prefix. When a hook is nil, that method returns a fixed
// default and never fails.
//
// NOTE(review): DoCreateUser takes only (cfg, user, username). The description
// argument that CreateUser receives is not forwarded to the hook.
type MockOAuthDatastore struct {
	DoGenerateOAuthState func(context.Context, string, string, int64, string) (string, error)
	DoValidateOAuthState func(context.Context, string) (string, string, int64, string, error)
	DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
	DoCreateUser         func(*config.Config, *User, string) error
	DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
	DoGetUserByID        func(int64) (*User, error)
}
42
43var _ OAuthDatastore = &MockOAuthDatastore{}
44
// StringReadCloser adds a no-op Close to an embedded *strings.Reader. Tests
// use it to give canned http.Response values a body that can be read and
// closed.
type StringReadCloser struct {
	*strings.Reader
}

// Close always succeeds. An in-memory reader holds nothing to release.
func (s *StringReadCloser) Close() error { return nil }
52
// MockHTTPClient is a stub HTTP client for tests. When DoDo is set, it handles
// every request passed to Do.
type MockHTTPClient struct {
	DoDo func(req *http.Request) (*http.Response, error)
}

// Do forwards req to DoDo when it is set.
//
// Otherwise it returns an empty response with a nil error. The response's Body
// is http.NoBody rather than nil, because http.Client.Do guarantees a non-nil
// Body whenever the error is nil. Code that reads or closes resp.Body would
// otherwise dereference a nil interface.
func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
	if m.DoDo != nil {
		return m.DoDo(req)
	}
	return &http.Response{Body: http.NoBody}, nil
}
63
64func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store {
65	if m.DoSessionStore != nil {
66		return m.DoSessionStore()
67	}
68	return sessions.NewCookieStore([]byte("secret-key"))
69}
70
71func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore {
72	if m.DoDB != nil {
73		return m.DoDB()
74	}
75	return &MockOAuthDatastore{}
76}
77
78func (m *MockOAuthDatastoreProvider) Config() *config.Config {
79	if m.DoConfig != nil {
80		return m.DoConfig()
81	}
82	cfg := config.New()
83	cfg.UseSQLite(true)
84	cfg.WriteAsOauth = config.WriteAsOauthCfg{
85		ClientID:        "development",
86		ClientSecret:    "development",
87		AuthLocation:    "https://write.as/oauth/login",
88		TokenLocation:   "https://write.as/oauth/token",
89		InspectLocation: "https://write.as/oauth/inspect",
90	}
91	cfg.SlackOauth = config.SlackOauthCfg{
92		ClientID:     "development",
93		ClientSecret: "development",
94		TeamID:       "development",
95	}
96	return cfg
97}
98
99func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) {
100	if m.DoValidateOAuthState != nil {
101		return m.DoValidateOAuthState(ctx, state)
102	}
103	return "", "", 0, "", nil
104}
105
106func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
107	if m.DoGetIDForRemoteUser != nil {
108		return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID)
109	}
110	return -1, nil
111}
112
113func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username, description string) error {
114	if m.DoCreateUser != nil {
115		return m.DoCreateUser(cfg, u, username)
116	}
117	u.ID = 1
118	return nil
119}
120
121func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
122	if m.DoRecordRemoteUserID != nil {
123		return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken)
124	}
125	return nil
126}
127
128func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
129	if m.DoGetUserByID != nil {
130		return m.DoGetUserByID(userID)
131	}
132	user := &User{}
133	return user, nil
134}
135
136func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64, inviteCode string) (string, error) {
137	if m.DoGenerateOAuthState != nil {
138		return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID, inviteCode)
139	}
140	return id.Generate62RandomString(14), nil
141}
142
143func TestViewOauthInit(t *testing.T) {
144
145	t.Run("success", func(t *testing.T) {
146		app := &MockOAuthDatastoreProvider{}
147		h := oauthHandler{
148			Config:   app.Config(),
149			DB:       app.DB(),
150			Store:    app.SessionStore(),
151			EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
152			oauthClient: writeAsOauthClient{
153				ClientID:         app.Config().WriteAsOauth.ClientID,
154				ClientSecret:     app.Config().WriteAsOauth.ClientSecret,
155				ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
156				InspectLocation:  app.Config().WriteAsOauth.InspectLocation,
157				AuthLocation:     app.Config().WriteAsOauth.AuthLocation,
158				CallbackLocation: "http://localhost/oauth/callback",
159				HttpClient:       nil,
160			},
161		}
162		req, err := http.NewRequest("GET", "/oauth/client", nil)
163		assert.NoError(t, err)
164		rr := httptest.NewRecorder()
165		err = h.viewOauthInit(nil, rr, req)
166		assert.NotNil(t, err)
167		httpErr, ok := err.(impart.HTTPError)
168		assert.True(t, ok)
169		assert.Equal(t, http.StatusTemporaryRedirect, httpErr.Status)
170		assert.NotEmpty(t, httpErr.Message)
171		locURI, err := url.Parse(httpErr.Message)
172		assert.NoError(t, err)
173		assert.Equal(t, "/oauth/login", locURI.Path)
174		assert.Equal(t, "development", locURI.Query().Get("client_id"))
175		assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri"))
176		assert.Equal(t, "code", locURI.Query().Get("response_type"))
177		assert.NotEmpty(t, locURI.Query().Get("state"))
178	})
179
180	t.Run("state failure", func(t *testing.T) {
181		app := &MockOAuthDatastoreProvider{
182			DoDB: func() OAuthDatastore {
183				return &MockOAuthDatastore{
184					DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64, inviteCode string) (string, error) {
185						return "", fmt.Errorf("pretend unable to write state error")
186					},
187				}
188			},
189		}
190		h := oauthHandler{
191			Config:   app.Config(),
192			DB:       app.DB(),
193			Store:    app.SessionStore(),
194			EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
195			oauthClient: writeAsOauthClient{
196				ClientID:         app.Config().WriteAsOauth.ClientID,
197				ClientSecret:     app.Config().WriteAsOauth.ClientSecret,
198				ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
199				InspectLocation:  app.Config().WriteAsOauth.InspectLocation,
200				AuthLocation:     app.Config().WriteAsOauth.AuthLocation,
201				CallbackLocation: "http://localhost/oauth/callback",
202				HttpClient:       nil,
203			},
204		}
205		req, err := http.NewRequest("GET", "/oauth/client", nil)
206		assert.NoError(t, err)
207		rr := httptest.NewRecorder()
208		err = h.viewOauthInit(nil, rr, req)
209		httpErr, ok := err.(impart.HTTPError)
210		assert.True(t, ok)
211		assert.NotEmpty(t, httpErr.Message)
212		assert.Equal(t, http.StatusInternalServerError, httpErr.Status)
213		assert.Equal(t, "could not prepare oauth redirect url", httpErr.Message)
214	})
215}
216
217func TestViewOauthCallback(t *testing.T) {
218	t.Run("success", func(t *testing.T) {
219		app := &MockOAuthDatastoreProvider{}
220		h := oauthHandler{
221			Config:   app.Config(),
222			DB:       app.DB(),
223			Store:    app.SessionStore(),
224			EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
225			oauthClient: writeAsOauthClient{
226				ClientID:         app.Config().WriteAsOauth.ClientID,
227				ClientSecret:     app.Config().WriteAsOauth.ClientSecret,
228				ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
229				InspectLocation:  app.Config().WriteAsOauth.InspectLocation,
230				AuthLocation:     app.Config().WriteAsOauth.AuthLocation,
231				CallbackLocation: "http://localhost/oauth/callback",
232				HttpClient: &MockHTTPClient{
233					DoDo: func(req *http.Request) (*http.Response, error) {
234						switch req.URL.String() {
235						case "https://write.as/oauth/token":
236							return &http.Response{
237								StatusCode: 200,
238								Body:       &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)},
239							}, nil
240						case "https://write.as/oauth/inspect":
241							return &http.Response{
242								StatusCode: 200,
243								Body:       &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
244							}, nil
245						}
246
247						return &http.Response{
248							StatusCode: http.StatusNotFound,
249						}, nil
250					},
251				},
252			},
253		}
254		req, err := http.NewRequest("GET", "/oauth/callback", nil)
255		assert.NoError(t, err)
256		rr := httptest.NewRecorder()
257		err = h.viewOauthCallback(&App{cfg: app.Config(), sessionStore: app.SessionStore()}, rr, req)
258		assert.NoError(t, err)
259		assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
260	})
261}
262