1/* 2 * Copyright © 2019-2021 A Bunch Tell LLC. 3 * 4 * This file is part of WriteFreely. 5 * 6 * WriteFreely is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Affero General Public License, included 8 * in the LICENSE file in this source code package. 9 */ 10 11package writefreely 12 13import ( 14 "context" 15 "fmt" 16 "github.com/gorilla/sessions" 17 "github.com/stretchr/testify/assert" 18 "github.com/writeas/impart" 19 "github.com/writeas/web-core/id" 20 "github.com/writefreely/writefreely/config" 21 "net/http" 22 "net/http/httptest" 23 "net/url" 24 "strings" 25 "testing" 26) 27 28type MockOAuthDatastoreProvider struct { 29 DoDB func() OAuthDatastore 30 DoConfig func() *config.Config 31 DoSessionStore func() sessions.Store 32} 33 34type MockOAuthDatastore struct { 35 DoGenerateOAuthState func(context.Context, string, string, int64, string) (string, error) 36 DoValidateOAuthState func(context.Context, string) (string, string, int64, string, error) 37 DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error) 38 DoCreateUser func(*config.Config, *User, string) error 39 DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error 40 DoGetUserByID func(int64) (*User, error) 41} 42 43var _ OAuthDatastore = &MockOAuthDatastore{} 44 45type StringReadCloser struct { 46 *strings.Reader 47} 48 49func (src *StringReadCloser) Close() error { 50 return nil 51} 52 53type MockHTTPClient struct { 54 DoDo func(req *http.Request) (*http.Response, error) 55} 56 57func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { 58 if m.DoDo != nil { 59 return m.DoDo(req) 60 } 61 return &http.Response{}, nil 62} 63 64func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store { 65 if m.DoSessionStore != nil { 66 return m.DoSessionStore() 67 } 68 return sessions.NewCookieStore([]byte("secret-key")) 69} 70 71func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore { 72 if m.DoDB != nil { 73 return m.DoDB() 74 } 75 return &MockOAuthDatastore{} 76} 77 78func (m *MockOAuthDatastoreProvider) Config() *config.Config { 79 if m.DoConfig != nil { 80 return m.DoConfig() 81 } 82 cfg := config.New() 83 cfg.UseSQLite(true) 84 cfg.WriteAsOauth = config.WriteAsOauthCfg{ 85 ClientID: "development", 86 ClientSecret: "development", 87 AuthLocation: "https://write.as/oauth/login", 88 TokenLocation: "https://write.as/oauth/token", 89 InspectLocation: "https://write.as/oauth/inspect", 90 } 91 cfg.SlackOauth = config.SlackOauthCfg{ 92 ClientID: "development", 93 ClientSecret: "development", 94 TeamID: "development", 95 } 96 return cfg 97} 98 99func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) { 100 if m.DoValidateOAuthState != nil { 101 return m.DoValidateOAuthState(ctx, state) 102 } 103 return "", "", 0, "", nil 104} 105 106func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { 107 if m.DoGetIDForRemoteUser != nil { 108 return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID) 109 } 110 return -1, nil 111} 112 113func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username, description string) error { 114 if m.DoCreateUser != nil { 115 return m.DoCreateUser(cfg, u, username) 116 } 117 u.ID = 1 118 return nil 119} 120 121func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { 122 if m.DoRecordRemoteUserID != nil { 123 return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken) 124 } 125 return nil 126} 127 128func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) { 129 if m.DoGetUserByID != nil { 130 return m.DoGetUserByID(userID) 131 } 132 user := &User{} 133 return user, nil 134} 135 136func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64, inviteCode string) (string, error) { 137 if m.DoGenerateOAuthState != nil { 138 return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID, inviteCode) 139 } 140 return id.Generate62RandomString(14), nil 141} 142 143func TestViewOauthInit(t *testing.T) { 144 145 t.Run("success", func(t *testing.T) { 146 app := &MockOAuthDatastoreProvider{} 147 h := oauthHandler{ 148 Config: app.Config(), 149 DB: app.DB(), 150 Store: app.SessionStore(), 151 EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, 152 oauthClient: writeAsOauthClient{ 153 ClientID: app.Config().WriteAsOauth.ClientID, 154 ClientSecret: app.Config().WriteAsOauth.ClientSecret, 155 ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, 156 InspectLocation: app.Config().WriteAsOauth.InspectLocation, 157 AuthLocation: app.Config().WriteAsOauth.AuthLocation, 158 CallbackLocation: "http://localhost/oauth/callback", 159 HttpClient: nil, 160 }, 161 } 162 req, err := http.NewRequest("GET", "/oauth/client", nil) 163 assert.NoError(t, err) 164 rr := httptest.NewRecorder() 165 err = h.viewOauthInit(nil, rr, req) 166 assert.NotNil(t, err) 167 httpErr, ok := err.(impart.HTTPError) 168 assert.True(t, ok) 169 assert.Equal(t, http.StatusTemporaryRedirect, httpErr.Status) 170 assert.NotEmpty(t, httpErr.Message) 171 locURI, err := url.Parse(httpErr.Message) 172 assert.NoError(t, err) 173 assert.Equal(t, "/oauth/login", locURI.Path) 174 assert.Equal(t, "development", locURI.Query().Get("client_id")) 175 assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri")) 176 assert.Equal(t, "code", locURI.Query().Get("response_type")) 177 assert.NotEmpty(t, locURI.Query().Get("state")) 178 }) 179 180 t.Run("state failure", func(t *testing.T) { 181 app := &MockOAuthDatastoreProvider{ 182 DoDB: func() OAuthDatastore { 183 return &MockOAuthDatastore{ 184 DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64, inviteCode string) (string, error) { 185 return "", fmt.Errorf("pretend unable to write state error") 186 }, 187 } 188 }, 189 } 190 h := oauthHandler{ 191 Config: app.Config(), 192 DB: app.DB(), 193 Store: app.SessionStore(), 194 EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, 195 oauthClient: writeAsOauthClient{ 196 ClientID: app.Config().WriteAsOauth.ClientID, 197 ClientSecret: app.Config().WriteAsOauth.ClientSecret, 198 ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, 199 InspectLocation: app.Config().WriteAsOauth.InspectLocation, 200 AuthLocation: app.Config().WriteAsOauth.AuthLocation, 201 CallbackLocation: "http://localhost/oauth/callback", 202 HttpClient: nil, 203 }, 204 } 205 req, err := http.NewRequest("GET", "/oauth/client", nil) 206 assert.NoError(t, err) 207 rr := httptest.NewRecorder() 208 err = h.viewOauthInit(nil, rr, req) 209 httpErr, ok := err.(impart.HTTPError) 210 assert.True(t, ok) 211 assert.NotEmpty(t, httpErr.Message) 212 assert.Equal(t, http.StatusInternalServerError, httpErr.Status) 213 assert.Equal(t, "could not prepare oauth redirect url", httpErr.Message) 214 }) 215} 216 217func TestViewOauthCallback(t *testing.T) { 218 t.Run("success", func(t *testing.T) { 219 app := &MockOAuthDatastoreProvider{} 220 h := oauthHandler{ 221 Config: app.Config(), 222 DB: app.DB(), 223 Store: app.SessionStore(), 224 EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, 225 oauthClient: writeAsOauthClient{ 226 ClientID: app.Config().WriteAsOauth.ClientID, 227 ClientSecret: app.Config().WriteAsOauth.ClientSecret, 228 ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, 229 InspectLocation: app.Config().WriteAsOauth.InspectLocation, 230 AuthLocation: app.Config().WriteAsOauth.AuthLocation, 231 CallbackLocation: "http://localhost/oauth/callback", 232 HttpClient: &MockHTTPClient{ 233 DoDo: func(req *http.Request) (*http.Response, error) { 234 switch req.URL.String() { 235 case "https://write.as/oauth/token": 236 return &http.Response{ 237 StatusCode: 200, 238 Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)}, 239 }, nil 240 case "https://write.as/oauth/inspect": 241 return &http.Response{ 242 StatusCode: 200, 243 Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)}, 244 }, nil 245 } 246 247 return &http.Response{ 248 StatusCode: http.StatusNotFound, 249 }, nil 250 }, 251 }, 252 }, 253 } 254 req, err := http.NewRequest("GET", "/oauth/callback", nil) 255 assert.NoError(t, err) 256 rr := httptest.NewRecorder() 257 err = h.viewOauthCallback(&App{cfg: app.Config(), sessionStore: app.SessionStore()}, rr, req) 258 assert.NoError(t, err) 259 assert.Equal(t, http.StatusTemporaryRedirect, rr.Code) 260 }) 261} 262