1// Package mastodon implements the OAuth2 protocol for authenticating users through Mastodon.
2// This package can be used as a reference implementation of an OAuth2 provider for Goth.
3package mastodon
4
5import (
6	"bytes"
7	"encoding/json"
8	"fmt"
9	"io"
10	"io/ioutil"
11	"net/http"
12	"strings"
13
14	"github.com/markbates/goth"
15	"golang.org/x/oauth2"
16)
17
// InstanceURL is the base URL of the Mastodon instance that New connects to.
// It defaults to mastodon.social, the flagship Mastodon instance.
var InstanceURL = "https://mastodon.social/"
22
23// Provider is the implementation of `goth.Provider` for accessing Mastodon.
24type Provider struct {
25	ClientKey    string
26	Secret       string
27	CallbackURL  string
28	HTTPClient   *http.Client
29	config       *oauth2.Config
30	providerName string
31	authURL      string
32	tokenURL     string
33	profileURL   string
34}
35
36// New creates a new Mastodon provider and sets up important connection details.
37// You should always call `mastodon.New` to get a new provider.  Never try to
38// create one manually.
39func New(clientKey, secret, callbackURL string, scopes ...string) *Provider {
40	return NewCustomisedURL(clientKey, secret, callbackURL, InstanceURL, scopes...)
41}
42
43// NewCustomisedURL is similar to New(...) but can be used to set custom URLs to connect to
44func NewCustomisedURL(clientKey, secret, callbackURL, instanceURL string, scopes ...string) *Provider {
45	instanceURL = fmt.Sprintf("%s/", strings.TrimSuffix(instanceURL, "/"))
46	profileURL := fmt.Sprintf("%sapi/v1/accounts/verify_credentials", instanceURL)
47	authURL := fmt.Sprintf("%soauth/authorize", instanceURL)
48	tokenURL := fmt.Sprintf("%soauth/token", instanceURL)
49	p := &Provider{
50		ClientKey:    clientKey,
51		Secret:       secret,
52		CallbackURL:  callbackURL,
53		providerName: "mastodon",
54		profileURL:   profileURL,
55	}
56	p.config = newConfig(p, authURL, tokenURL, scopes)
57	return p
58}
59
60// Name is the name used to retrieve this provider later.
61func (p *Provider) Name() string {
62	return p.providerName
63}
64
65// SetName is to update the name of the provider (needed in case of multiple providers of 1 type)
66func (p *Provider) SetName(name string) {
67	p.providerName = name
68}
69
70func (p *Provider) Client() *http.Client {
71	return goth.HTTPClientWithFallBack(p.HTTPClient)
72}
73
74// Debug is a no-op for the Mastodon package.
75func (p *Provider) Debug(debug bool) {}
76
77// BeginAuth asks Mastodon for an authentication end-point.
78func (p *Provider) BeginAuth(state string) (goth.Session, error) {
79	return &Session{
80		AuthURL: p.config.AuthCodeURL(state),
81	}, nil
82}
83
84// FetchUser will go to Mastodon and access basic information about the user.
85func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
86	sess := session.(*Session)
87	user := goth.User{
88		AccessToken:  sess.AccessToken,
89		Provider:     p.Name(),
90		RefreshToken: sess.RefreshToken,
91		ExpiresAt:    sess.ExpiresAt,
92	}
93
94	if user.AccessToken == "" {
95		// data is not yet retrieved since accessToken is still empty
96		return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName)
97	}
98
99	req, err := http.NewRequest("GET", p.profileURL, nil)
100	if err != nil {
101		return user, err
102	}
103
104	req.Header.Add("Authorization", "Bearer "+sess.AccessToken)
105	response, err := p.Client().Do(req)
106	if err != nil {
107		return user, err
108	}
109	defer response.Body.Close()
110
111	if response.StatusCode != http.StatusOK {
112		return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode)
113	}
114
115	bits, err := ioutil.ReadAll(response.Body)
116	if err != nil {
117		return user, err
118	}
119
120	err = json.NewDecoder(bytes.NewReader(bits)).Decode(&user.RawData)
121	if err != nil {
122		return user, err
123	}
124
125	err = userFromReader(bytes.NewReader(bits), &user)
126
127	return user, err
128}
129
130func newConfig(provider *Provider, authURL, tokenURL string, scopes []string) *oauth2.Config {
131	c := &oauth2.Config{
132		ClientID:     provider.ClientKey,
133		ClientSecret: provider.Secret,
134		RedirectURL:  provider.CallbackURL,
135		Endpoint: oauth2.Endpoint{
136			AuthURL:  authURL,
137			TokenURL: tokenURL,
138		},
139		Scopes: []string{},
140	}
141
142	if len(scopes) > 0 {
143		for _, scope := range scopes {
144			c.Scopes = append(c.Scopes, scope)
145		}
146	}
147	return c
148}
149
150func userFromReader(r io.Reader, user *goth.User) error {
151	u := struct {
152		Name      string `json:"display_name"`
153		NickName  string `json:"username"`
154		ID        string `json:"id"`
155		AvatarURL string `json:"avatar"`
156	}{}
157	err := json.NewDecoder(r).Decode(&u)
158	if err != nil {
159		return err
160	}
161	user.Name = u.Name
162	if len(user.Name) == 0 {
163		user.Name = u.NickName
164	}
165	user.NickName = u.NickName
166	user.UserID = u.ID
167	user.AvatarURL = u.AvatarURL
168	return nil
169}
170
171//RefreshTokenAvailable refresh token is provided by auth provider or not
172func (p *Provider) RefreshTokenAvailable() bool {
173	return true
174}
175
176//RefreshToken get new access token based on the refresh token
177func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
178	token := &oauth2.Token{RefreshToken: refreshToken}
179	ts := p.config.TokenSource(goth.ContextForClient(p.Client()), token)
180	newToken, err := ts.Token()
181	if err != nil {
182		return nil, err
183	}
184	return newToken, err
185}
186