1package okta
2
3import (
4	"bytes"
5	"encoding/json"
6	"fmt"
7	"io"
8	"io/ioutil"
9	"net/http"
10	"net/url"
11	"regexp"
12	"strconv"
13	"sync"
14	"time"
15
16	"github.com/google/go-querystring/query"
17
18	"reflect"
19)
20
// Package-level settings for the OKTA client: URL construction, HTTP header
// names, the default Limit value, filter operators and the rate-limit floor.
const (
	libraryVersion            = "1"
	userAgent                 = "oktasdk-go/" + libraryVersion
	productionDomain          = "okta.com"
	previewDomain             = "oktapreview.com"
	urlFormat                 = "https://%s.%s/api/v1/"
	headerRateLimit           = "X-Rate-Limit-Limit"
	headerRateRemaining       = "X-Rate-Limit-Remaining"
	headerRateReset           = "X-Rate-Limit-Reset"
	headerOKTARequestID       = "X-Okta-Request-Id"
	headerAuthorization       = "Authorization"
	headerAuthorizationFormat = "SSWS %v"
	mediaTypeJSON             = "application/json"
	defaultLimit              = 50
	// FilterEqualOperator is the filter operator for "equal".
	FilterEqualOperator = "eq"
	// FilterStartsWithOperator is the filter operator for "starts with".
	FilterStartsWithOperator = "sw"
	// FilterGreaterThanOperator is the filter operator for "greater than".
	FilterGreaterThanOperator = "gt"
	// FilterLessThanOperator is the filter operator for "less than".
	FilterLessThanOperator = "lt"

	// defaultRateRemainingFloor is the initial Client.RateRemainingFloor. When the
	// API reports an "X-Rate-Limit-Remaining" value below the floor, the SDK either
	// pauses or returns a RateLimitError, depending on Client.PauseOnRateLimit.
	defaultRateRemainingFloor = 100
)
48
// Client manages communication with the OKTA API.
// Construct it with NewClient, NewClientWithDomain or NewClientWithBaseURL.
type Client struct {
	clientMu sync.Mutex   // clientMu protects the client during calls that modify the CheckRedirect func.
	client   *http.Client // HTTP client used to communicate with the API.

	// BaseURL is the base URL for API requests.
	//  It is built automatically by the constructors as https://<org>.<domain>/api/v1/.
	//  Override it if your organization is not served from *.okta.com or *.oktapreview.com.
	BaseURL *url.URL

	// UserAgent is sent as the User-Agent header when communicating with the OKTA API.
	UserAgent string

	apiKey                   string
	authorizationHeaderValue string
	PauseOnRateLimit         bool // true: block until the rate limit resets; false: return a RateLimitError instead

	// RateRemainingFloor - If the API returns a "X-Rate-Limit-Remaining" header less than this the SDK will either pause
	//  or return a RateLimitError, depending on the client.PauseOnRateLimit value.
	//  It defaults to defaultRateRemainingFloor (100).
	// One client doing too much work can lock out all API access for every other client;
	// we are trying to be a "good API User Citizen".
	RateRemainingFloor int

	// rateMu guards mostRecentRate, the rate-limit state recorded from the last response.
	rateMu         sync.Mutex
	mostRecentRate Rate

	// Limit is initialized to defaultLimit (50); presumably the default page size
	// for list calls - confirm against the service implementations.
	Limit int
	// mostRecent rateLimitCategory

	common service // Reuse a single struct instead of allocating one for each service on the heap.

	// Services used for talking to different parts of the API.
	// Service for Working with Users
	Users *UsersService

	// Service for Working with Groups
	Groups *GroupsService

	// Service for Working with Apps
	Apps *AppsService
}
90
// service is the shared base of the API services. The Users, Groups and Apps
// services are conversions of one *service value that points back at the
// owning Client (see NewClientWithBaseURL).
type service struct {
	client *Client
}
94
95// NewClient returns a new OKTA API client.  If a nil httpClient is
96// provided, http.DefaultClient will be used.
97func NewClient(httpClient *http.Client, orgName string, apiToken string, isProduction bool) *Client {
98	var baseDomain string
99	if isProduction {
100		baseDomain = productionDomain
101	} else {
102		baseDomain = previewDomain
103	}
104	client, _ := NewClientWithDomain(httpClient, orgName, baseDomain, apiToken)
105	return client
106}
107
108// NewClientWithDomain creates a client based on the organziation name and
109// base domain for requests (okta.com, okta-emea.com, oktapreview.com, etc).
110func NewClientWithDomain(httpClient *http.Client, orgName string, domain string, apiToken string) (*Client, error) {
111	baseURL, err := url.Parse(fmt.Sprintf(urlFormat, orgName, domain))
112	if err != nil {
113		return nil, err
114	}
115	return NewClientWithBaseURL(httpClient, baseURL, apiToken), nil
116}
117
118// NewClientWithBaseURL creates a client based on the full base URL and api
119// token
120func NewClientWithBaseURL(httpClient *http.Client, baseURL *url.URL, apiToken string) *Client {
121	if httpClient == nil {
122		httpClient = http.DefaultClient
123	}
124
125	c := &Client{
126		client:    httpClient,
127		BaseURL:   baseURL,
128		UserAgent: userAgent,
129	}
130	c.PauseOnRateLimit = true // If rate limit found it will block until that time. If false then Error will be returned
131	c.authorizationHeaderValue = fmt.Sprintf(headerAuthorizationFormat, apiToken)
132	c.apiKey = apiToken
133	c.Limit = defaultLimit
134	c.RateRemainingFloor = defaultRateRemainingFloor
135	c.common.client = c
136
137	c.Users = (*UsersService)(&c.common)
138	c.Groups = (*GroupsService)(&c.common)
139	c.Apps = (*AppsService)(&c.common)
140	return c
141}
142
// Rate represents the rate-limit state reported by an OKTA API response
// (populated by parseRate from the X-Rate-Limit-* headers).
type Rate struct {
	// RatePerMinuteLimit is the number of requests per minute the client is
	// currently limited to (X-Rate-Limit-Limit).
	RatePerMinuteLimit int

	// Remaining is the number of requests the client can still make in the
	// current window (X-Rate-Limit-Remaining).
	Remaining int

	// ResetTime is when the current rate limit resets (X-Rate-Limit-Reset,
	// Unix seconds). It is the zero time when the header was absent or 0.
	ResetTime time.Time
}
154
// Response is an OKTA API response. It wraps the standard http.Response
// returned from OKTA and provides convenient access to things like
// pagination links, the request ID and the rate-limit headers.
type Response struct {
	*http.Response

	// These fields provide the page values for paginating through a set of
	// results. NextURL and SelfURL come from the rel="next" and rel="self"
	// entries of the Link header and are nil when absent. OKTARequestID holds
	// the X-Okta-Request-Id header.

	NextURL *url.URL
	// PrevURL       *url.URL
	SelfURL       *url.URL
	OKTARequestID string
	// Rate holds the rate-limit headers of this response.
	Rate
}
170
171// newResponse creates a new Response for the provided http.Response.
172func newResponse(r *http.Response) *Response {
173	response := &Response{Response: r}
174
175	response.OKTARequestID = r.Header.Get(headerOKTARequestID)
176
177	response.populatePaginationURLS()
178	response.Rate = parseRate(r)
179	return response
180}
181
182// populatePageValues parses the HTTP Link response headers and populates the
183// various pagination link values in the Response.
184
185// OKTA LINK Header takes this form:
186// 		Link: <https://yoursubdomain.okta.com/api/v1/users?after=00ubfjQEMYBLRUWIEDKK>; rel="next",
187// 			<https://yoursubdomain.okta.com/api/v1/users?after=00ub4tTFYKXCCZJSGFKM>; rel="self"
188
189func (r *Response) populatePaginationURLS() {
190
191	for k, v := range r.Header {
192
193		if k == "Link" {
194			nextRegex := regexp.MustCompile(`<(.*?)>; rel="next"`)
195			// prevRegex := regexp.MustCompile(`<(.*?)>; rel="prev"`)
196			selfRegex := regexp.MustCompile(`<(.*?)>; rel="self"`)
197
198			for _, linkValue := range v {
199				nextLinkMatch := nextRegex.FindStringSubmatch(linkValue)
200				if len(nextLinkMatch) != 0 {
201					r.NextURL, _ = url.Parse(nextLinkMatch[1])
202				}
203				selfLinkMatch := selfRegex.FindStringSubmatch(linkValue)
204				if len(selfLinkMatch) != 0 {
205					r.SelfURL, _ = url.Parse(selfLinkMatch[1])
206				}
207				// prevLinkMatch := prevRegex.FindStringSubmatch(linkValue)
208				// if len(prevLinkMatch) != 0 {
209				// 	r.PrevURL, _ = url.Parse(prevLinkMatch[1])
210				// }
211			}
212		}
213	}
214
215}
216
217// parseRate parses the rate related headers.
218func parseRate(r *http.Response) Rate {
219	var rate Rate
220
221	if limit := r.Header.Get(headerRateLimit); limit != "" {
222		rate.RatePerMinuteLimit, _ = strconv.Atoi(limit)
223	}
224	if remaining := r.Header.Get(headerRateRemaining); remaining != "" {
225		rate.Remaining, _ = strconv.Atoi(remaining)
226	}
227	if reset := r.Header.Get(headerRateReset); reset != "" {
228		if v, _ := strconv.ParseInt(reset, 10, 64); v != 0 {
229			rate.ResetTime = time.Unix(v, 0)
230		}
231	}
232	return rate
233}
234
235// Do sends an API request and returns the API response.  The API response is
236// JSON decoded and stored in the value pointed to by v, or returned as an
237// error if an API error has occurred.  If v implements the io.Writer
238// interface, the raw response body will be written to v, without attempting to
239// first decode it.  If rate limit is exceeded and reset time is in the future,
240// Do returns rate immediately without making a network API call.
241func (c *Client) Do(req *http.Request, v interface{}) (*Response, error) {
242
243	// If we've hit rate limit, don't make further requests before Reset time.
244	if err := c.checkRateLimitBeforeDo(req); err != nil {
245		return nil, err
246	}
247
248	resp, err := c.client.Do(req)
249	if err != nil {
250		return nil, err
251	}
252
253	defer func() {
254		// Drain up to 512 bytes and close the body to let the Transport reuse the connection
255		io.CopyN(ioutil.Discard, resp.Body, 512)
256		resp.Body.Close()
257	}()
258
259	response := newResponse(resp)
260
261	c.rateMu.Lock()
262	c.mostRecentRate.RatePerMinuteLimit = response.Rate.RatePerMinuteLimit
263	c.mostRecentRate.Remaining = response.Rate.Remaining
264	c.mostRecentRate.ResetTime = response.Rate.ResetTime
265	c.rateMu.Unlock()
266
267	err = CheckResponse(resp)
268	if err != nil {
269		// even though there was an error, we still return the response
270		// in case the caller wants to inspect it further
271		// fmt.Printf("Error after sdk.Do return\n")
272
273		return response, err
274	}
275
276	if v != nil {
277		if w, ok := v.(io.Writer); ok {
278			io.Copy(w, resp.Body)
279		} else {
280			err = json.NewDecoder(resp.Body).Decode(v)
281			if err == io.EOF {
282				err = nil // ignore EOF errors caused by empty response body
283			}
284		}
285	}
286
287	return response, err
288}
289
290// checkRateLimitBeforeDo does not make any network calls, but uses existing knowledge from
291// current client state in order to quickly check if *RateLimitError can be immediately returned
292// from Client.Do, and if so, returns it so that Client.Do can skip making a network API call unnecessarily.
293// Otherwise it returns nil, and Client.Do should proceed normally.
294// http://developer.okta.com/docs/api/getting_started/design_principles.html#rate-limiting
295func (c *Client) checkRateLimitBeforeDo(req *http.Request) error {
296
297	c.rateMu.Lock()
298	mostRecentRate := c.mostRecentRate
299	c.rateMu.Unlock()
300	// fmt.Printf("checkRateLimitBeforeDo: \t Remaining = %d, \t ResetTime = %s\n", mostRecentRate.Remaining, mostRecentRate.ResetTime.String())
301	if !mostRecentRate.ResetTime.IsZero() && mostRecentRate.Remaining < c.RateRemainingFloor && time.Now().Before(mostRecentRate.ResetTime) {
302
303		if c.PauseOnRateLimit {
304			// If rate limit is hitting threshold then pause until the rate limit resets
305			//   This behavior is controlled by the client PauseOnRateLimit value
306			// fmt.Printf("checkRateLimitBeforeDo: \t ***pause**** \t Time Now = %s \tPause After = %s\n", time.Now().String(), mostRecentRate.ResetTime.Sub(time.Now().Add(2*time.Second)).String())
307			<-time.After(mostRecentRate.ResetTime.Sub(time.Now().Add(2 * time.Second)))
308		} else {
309			// fmt.Printf("checkRateLimitBeforeDo: \t ***error****\n")
310
311			return &RateLimitError{
312				Rate: mostRecentRate,
313			}
314		}
315
316	}
317
318	return nil
319}
320
321// CheckResponse checks the API response for errors, and returns them if
322// present.  A response is considered an error if it has a status code outside
323// the 200 range.  API error responses are expected to have either no response
324// body, or a JSON response body that maps to ErrorResponse.  Any other
325// response body will be silently ignored.
326//
327// The error type will be *RateLimitError for rate limit exceeded errors,
328// and *TwoFactorAuthError for two-factor authentication errors.
329// TODO - check un-authorized
330func CheckResponse(r *http.Response) error {
331	if c := r.StatusCode; 200 <= c && c <= 299 {
332		return nil
333	}
334
335	errorResp := &errorResponse{Response: r}
336	data, err := ioutil.ReadAll(r.Body)
337	if err == nil && data != nil {
338		json.Unmarshal(data, &errorResp.ErrorDetail)
339	}
340	switch {
341	case r.StatusCode == http.StatusTooManyRequests:
342
343		return &RateLimitError{
344			Rate:        parseRate(r),
345			Response:    r,
346			ErrorDetail: errorResp.ErrorDetail}
347
348	default:
349		return errorResp
350	}
351
352}
353
// apiError mirrors the JSON error body returned by the OKTA API
// (errorCode, errorSummary, errorLink, errorId, errorCauses).
type apiError struct {
	ErrorCode    string `json:"errorCode"`
	ErrorSummary string `json:"errorSummary"`
	ErrorLink    string `json:"errorLink"`
	ErrorID      string `json:"errorId"`
	ErrorCauses  []struct {
		ErrorSummary string `json:"errorSummary"`
	} `json:"errorCauses"`
}

// errorResponse is the error returned by CheckResponse for non-2xx responses
// other than 429. It pairs the HTTP response with the decoded OKTA error body.
type errorResponse struct {
	Response    *http.Response // the failed response; its body has already been consumed
	ErrorDetail apiError
}

// Error describes the failed request (method and URL), its status code, and
// the OKTA error code, summary and causes.
//
// A missing Response, Request or URL is rendered as "<nil>" (status 0) rather
// than causing a panic. CheckResponse is exported and may be handed a response
// that was not produced by http.Client.Do, and therefore has no Request.
func (r *errorResponse) Error() string {
	method, target, status := "<nil>", "<nil>", 0
	if resp := r.Response; resp != nil {
		status = resp.StatusCode
		if req := resp.Request; req != nil {
			method = req.Method
			if req.URL != nil {
				target = req.URL.String()
			}
		}
	}
	return fmt.Sprintf("HTTP Method: %v - URL: %v: - HTTP Status Code: %d, OKTA Error Code: %v, OKTA Error Summary: %v, OKTA Error Causes: %v",
		method, target, status, r.ErrorDetail.ErrorCode, r.ErrorDetail.ErrorSummary, r.ErrorDetail.ErrorCauses)
}
373
374// RateLimitError occurs when OKTA returns 429 "Too Many Requests" response with a rate limit
375// remaining value of 0, and error message starts with "API rate limit exceeded for ".
376type RateLimitError struct {
377	Rate        Rate // Rate specifies last known rate limit for the client
378	ErrorDetail apiError
379	Response    *http.Response //
380}
381
382func (r *RateLimitError) Error() string {
383
384	return fmt.Sprintf("rate reset in %v", r.Rate.ResetTime.Sub(time.Now()))
385
386}
387
// stringify returns a human-readable representation of message. It was adapted
// from the go-github library. Pointers are followed to their values, and struct
// fields holding nil pointers or nil slices are omitted (see stringifyValue).
// A nil message yields "<nil>".
func stringify(message interface{}) string {
	var buf bytes.Buffer
	stringifyValue(&buf, reflect.ValueOf(message))
	return buf.String()
}

// stringifyValue writes a readable representation of val to w. It was heavily
// inspired by the goprotobuf library.
//
//   - an invalid (zero) reflect.Value or a nil pointer is written as <nil>;
//   - a non-nil pointer is followed once to the value it points at;
//   - strings are wrapped in double quotes (no escaping);
//   - slices are written as [elem elem ...];
//   - structs are written as TypeName{Field:value, ...}, skipping fields that
//     hold nil pointers or nil slices (unnamed struct types omit TypeName);
//   - other kinds are written with fmt.Fprint, but only when the value is
//     interfaceable; values read through unexported fields print nothing.
func stringifyValue(w io.Writer, val reflect.Value) {
	// reflect.ValueOf(nil) yields the zero Value, and CanInterface panics on
	// it, so render it the same way as a nil pointer.
	if !val.IsValid() || (val.Kind() == reflect.Ptr && val.IsNil()) {
		w.Write([]byte("<nil>"))
		return
	}

	v := reflect.Indirect(val)

	switch v.Kind() {
	case reflect.String:
		fmt.Fprintf(w, `"%s"`, v)
	case reflect.Slice:
		w.Write([]byte{'['})
		for i := 0; i < v.Len(); i++ {
			if i > 0 {
				w.Write([]byte{' '})
			}
			stringifyValue(w, v.Index(i))
		}
		w.Write([]byte{']'})
	case reflect.Struct:
		if v.Type().Name() != "" {
			w.Write([]byte(v.Type().String()))
		}
		w.Write([]byte{'{'})

		var sep bool
		for i := 0; i < v.NumField(); i++ {
			fv := v.Field(i)
			if fv.Kind() == reflect.Ptr && fv.IsNil() {
				continue
			}
			if fv.Kind() == reflect.Slice && fv.IsNil() {
				continue
			}

			if sep {
				w.Write([]byte(", "))
			} else {
				sep = true
			}

			w.Write([]byte(v.Type().Field(i).Name))
			w.Write([]byte{':'})
			stringifyValue(w, fv)
		}

		w.Write([]byte{'}'})
	default:
		if v.CanInterface() {
			fmt.Fprint(w, v.Interface())
		}
	}
}
458
459// NewRequest creates an API request. A relative URL can be provided in urlStr,
460// in which case it is resolved relative to the BaseURL of the Client.
461// Relative URLs should always be specified without a preceding slash.  If
462// specified, the value pointed to by body is JSON encoded and included as the
463// request body.
464func (c *Client) NewRequest(method, urlStr string, body interface{}) (*http.Request, error) {
465	rel, err := url.Parse(urlStr)
466	if err != nil {
467		return nil, err
468	}
469
470	u := c.BaseURL.ResolveReference(rel)
471
472	var buf io.ReadWriter
473	if body != nil {
474		buf = new(bytes.Buffer)
475		err := json.NewEncoder(buf).Encode(body)
476		if err != nil {
477			return nil, err
478		}
479	}
480
481	req, err := http.NewRequest(method, u.String(), buf)
482	if err != nil {
483		return nil, err
484	}
485	if c.apiKey != "" {
486		req.Header.Set(headerAuthorization, fmt.Sprintf(headerAuthorizationFormat, c.apiKey))
487	}
488	if body != nil {
489		req.Header.Set("Content-Type", mediaTypeJSON)
490	}
491
492	if c.UserAgent != "" {
493		req.Header.Set("User-Agent", c.UserAgent)
494	}
495	return req, nil
496}
497
498// addOptions adds the parameters in opt as URL query parameters to s.  opt
499// must be a struct whose fields may contain "url" tags.
500func addOptions(s string, opt interface{}) (string, error) {
501	v := reflect.ValueOf(opt)
502	if v.Kind() == reflect.Ptr && v.IsNil() {
503		return s, nil
504	}
505
506	u, err := url.Parse(s)
507	if err != nil {
508		return s, err
509	}
510
511	qs, err := query.Values(opt)
512	if err != nil {
513		return s, err
514	}
515
516	u.RawQuery = qs.Encode()
517	return u.String(), nil
518}
519
// dateFilter pairs a timestamp with a filter operator. The operator is
// presumably one of the Filter*Operator constants (e.g. FilterGreaterThanOperator).
// Its users are not visible here; confirm against the service implementations.
type dateFilter struct {
	Value    time.Time
	Operator string
}
524