1package disco
2
3import (
4	"encoding/json"
5	"fmt"
6	"log"
7	"net/http"
8	"net/url"
9	"os"
10	"strconv"
11	"strings"
12	"time"
13
14	"github.com/hashicorp/go-version"
15)
16
// versionServiceID is the service ID under which a host advertises the
// versions service that VersionConstraints queries for version constraints.
const versionServiceID = "versions.v1"
18
// Host represents a service discovered host.
type Host struct {
	// discoURL is the base that relative service URLs found in the
	// discovery document are resolved against.
	discoURL  *url.URL
	// hostname is the host name reported in ErrServiceNotProvided and
	// ErrVersionNotSupported messages.
	hostname  string
	// services maps service IDs ("servicename.vN") to their raw definitions:
	// a URL string for ordinary services, or an object (a map, or a list of
	// maps of which only the first is used) for OAuth client services.
	services  map[string]interface{}
	// transport is the RoundTripper used for HTTP requests made on behalf
	// of this host (e.g. by VersionConstraints).
	transport http.RoundTripper
}
26
// Constraints represents the version constraints of a service.
//
// Values are decoded from the JSON body returned by the versions service
// (see VersionConstraints).
type Constraints struct {
	// Service and Product presumably echo the service ID and product the
	// constraints were requested for — confirm against the versions API.
	Service   string   `json:"service"`
	Product   string   `json:"product"`
	// Minimum, Maximum and Excluding presumably describe the supported
	// version range and any excluded versions. NOTE(review): their version
	// syntax and whether the bounds are inclusive are defined by the
	// versions service, not by this package — confirm before relying on it.
	Minimum   string   `json:"minimum"`
	Maximum   string   `json:"maximum"`
	Excluding []string `json:"excluding"`
}
35
// ErrServiceNotProvided is returned when a host does not offer the
// requested service in any version.
type ErrServiceNotProvided struct {
	hostname string // host that was queried; may be left empty
	service  string // service name without its ".vN" suffix
}

// Error implements the error interface. The host name is mentioned only
// when it is known.
func (e *ErrServiceNotProvided) Error() string {
	if e.hostname != "" {
		return fmt.Sprintf("host %s does not provide a %s service", e.hostname, e.service)
	}
	return fmt.Sprintf("host does not provide a %s service", e.service)
}
49
// ErrVersionNotSupported is returned when a host offers a service, but not
// in the requested version.
type ErrVersionNotSupported struct {
	hostname string // host that was queried; may be left empty
	service  string // service name without its ".vN" suffix
	version  string // the version as originally requested
}

// Error implements the error interface. The host name is mentioned only
// when it is known.
func (e *ErrVersionNotSupported) Error() string {
	if e.hostname != "" {
		return fmt.Sprintf("host %s does not support %s version %s", e.hostname, e.service, e.version)
	}
	return fmt.Sprintf("host does not support %s version %s", e.service, e.version)
}
64
// ErrNoVersionConstraints is returned when no version constraints could be
// obtained: either checkpoint was disabled, or the endpoint to query for
// version constraints was unavailable.
type ErrNoVersionConstraints struct {
	disabled bool // true when checkpoint was explicitly disabled
}

// Error implements the error interface, distinguishing a deliberately
// disabled checkpoint from an unreachable versions service.
func (e *ErrNoVersionConstraints) Error() string {
	if !e.disabled {
		return "unable to contact versions service"
	}
	return "checkpoint disabled"
}
78
79// ServiceURL returns the URL associated with the given service identifier,
80// which should be of the form "servicename.vN".
81//
82// A non-nil result is always an absolute URL with a scheme of either HTTPS
83// or HTTP.
84func (h *Host) ServiceURL(id string) (*url.URL, error) {
85	svc, ver, err := parseServiceID(id)
86	if err != nil {
87		return nil, err
88	}
89
90	// No services supported for an empty Host.
91	if h == nil || h.services == nil {
92		return nil, &ErrServiceNotProvided{service: svc}
93	}
94
95	urlStr, ok := h.services[id].(string)
96	if !ok {
97		// See if we have a matching service as that would indicate
98		// the service is supported, but not the requested version.
99		for serviceID := range h.services {
100			if strings.HasPrefix(serviceID, svc+".") {
101				return nil, &ErrVersionNotSupported{
102					hostname: h.hostname,
103					service:  svc,
104					version:  ver.Original(),
105				}
106			}
107		}
108
109		// No discovered services match the requested service.
110		return nil, &ErrServiceNotProvided{hostname: h.hostname, service: svc}
111	}
112
113	u, err := h.parseURL(urlStr)
114	if err != nil {
115		return nil, fmt.Errorf("Failed to parse service URL: %v", err)
116	}
117
118	return u, nil
119}
120
121// ServiceOAuthClient returns the OAuth client configuration associated with the
122// given service identifier, which should be of the form "servicename.vN".
123//
124// This is an alternative to ServiceURL for unusual services that require
125// a full OAuth2 client definition rather than just a URL. Use this only
126// for services whose specification calls for this sort of definition.
127func (h *Host) ServiceOAuthClient(id string) (*OAuthClient, error) {
128	svc, ver, err := parseServiceID(id)
129	if err != nil {
130		return nil, err
131	}
132
133	// No services supported for an empty Host.
134	if h == nil || h.services == nil {
135		return nil, &ErrServiceNotProvided{service: svc}
136	}
137
138	if _, ok := h.services[id]; !ok {
139		// See if we have a matching service as that would indicate
140		// the service is supported, but not the requested version.
141		for serviceID := range h.services {
142			if strings.HasPrefix(serviceID, svc+".") {
143				return nil, &ErrVersionNotSupported{
144					hostname: h.hostname,
145					service:  svc,
146					version:  ver.Original(),
147				}
148			}
149		}
150
151		// No discovered services match the requested service.
152		return nil, &ErrServiceNotProvided{hostname: h.hostname, service: svc}
153	}
154
155	var raw map[string]interface{}
156	switch v := h.services[id].(type) {
157	case map[string]interface{}:
158		raw = v // Great!
159	case []map[string]interface{}:
160		// An absolutely infuriating legacy HCL ambiguity.
161		raw = v[0]
162	default:
163		// Debug message because raw Go types don't belong in our UI.
164		log.Printf("[DEBUG] The definition for %s has Go type %T", id, h.services[id])
165		return nil, fmt.Errorf("Service %s must be declared with an object value in the service discovery document", id)
166	}
167
168	var grantTypes OAuthGrantTypeSet
169	if rawGTs, ok := raw["grant_types"]; ok {
170		if gts, ok := rawGTs.([]interface{}); ok {
171			var kws []string
172			for _, gtI := range gts {
173				gt, ok := gtI.(string)
174				if !ok {
175					// We'll ignore this so that we can potentially introduce
176					// other types into this array later if we need to.
177					continue
178				}
179				kws = append(kws, gt)
180			}
181			grantTypes = NewOAuthGrantTypeSet(kws...)
182		} else {
183			return nil, fmt.Errorf("Service %s is defined with invalid grant_types property: must be an array of grant type strings", id)
184		}
185	} else {
186		grantTypes = NewOAuthGrantTypeSet("authz_code")
187	}
188
189	ret := &OAuthClient{
190		SupportedGrantTypes: grantTypes,
191	}
192	if clientIDStr, ok := raw["client"].(string); ok {
193		ret.ID = clientIDStr
194	} else {
195		return nil, fmt.Errorf("Service %s definition is missing required property \"client\"", id)
196	}
197	if urlStr, ok := raw["authz"].(string); ok {
198		u, err := h.parseURL(urlStr)
199		if err != nil {
200			return nil, fmt.Errorf("Failed to parse authorization URL: %v", err)
201		}
202		ret.AuthorizationURL = u
203	} else {
204		if grantTypes.RequiresAuthorizationEndpoint() {
205			return nil, fmt.Errorf("Service %s definition is missing required property \"authz\"", id)
206		}
207	}
208	if urlStr, ok := raw["token"].(string); ok {
209		u, err := h.parseURL(urlStr)
210		if err != nil {
211			return nil, fmt.Errorf("Failed to parse token URL: %v", err)
212		}
213		ret.TokenURL = u
214	} else {
215		if grantTypes.RequiresTokenEndpoint() {
216			return nil, fmt.Errorf("Service %s definition is missing required property \"token\"", id)
217		}
218	}
219	if portsRaw, ok := raw["ports"].([]interface{}); ok {
220		if len(portsRaw) != 2 {
221			return nil, fmt.Errorf("Invalid \"ports\" definition for service %s: must be a two-element array", id)
222		}
223		invalidPortsErr := fmt.Errorf("Invalid \"ports\" definition for service %s: both ports must be whole numbers between 1024 and 65535", id)
224		ports := make([]uint16, 2)
225		for i := range ports {
226			switch v := portsRaw[i].(type) {
227			case float64:
228				// JSON unmarshaling always produces float64. HCL 2 might, if
229				// an invalid fractional number were given.
230				if float64(uint16(v)) != v || v < 1024 {
231					return nil, invalidPortsErr
232				}
233				ports[i] = uint16(v)
234			case int:
235				// Legacy HCL produces int. HCL 2 will too, if the given number
236				// is a whole number.
237				if v < 1024 || v > 65535 {
238					return nil, invalidPortsErr
239				}
240				ports[i] = uint16(v)
241			default:
242				// Debug message because raw Go types don't belong in our UI.
243				log.Printf("[DEBUG] Port value %d has Go type %T", i, portsRaw[i])
244				return nil, invalidPortsErr
245			}
246		}
247		if ports[1] < ports[0] {
248			return nil, fmt.Errorf("Invalid \"ports\" definition for service %s: minimum port cannot be greater than maximum port", id)
249		}
250		ret.MinPort = ports[0]
251		ret.MaxPort = ports[1]
252	} else {
253		// Default is to accept any port in the range, for a client that is
254		// able to call back to any localhost port.
255		ret.MinPort = 1024
256		ret.MaxPort = 65535
257	}
258	if scopesRaw, ok := raw["scopes"].([]interface{}); ok {
259		var scopes []string
260		for _, scopeI := range scopesRaw {
261			scope, ok := scopeI.(string)
262			if !ok {
263				return nil, fmt.Errorf("Invalid \"scopes\" for service %s: all scopes must be strings", id)
264			}
265			scopes = append(scopes, scope)
266		}
267		ret.Scopes = scopes
268	}
269
270	return ret, nil
271}
272
273func (h *Host) parseURL(urlStr string) (*url.URL, error) {
274	u, err := url.Parse(urlStr)
275	if err != nil {
276		return nil, err
277	}
278
279	// Make relative URLs absolute using our discovery URL.
280	if !u.IsAbs() {
281		u = h.discoURL.ResolveReference(u)
282	}
283
284	if u.Scheme != "https" && u.Scheme != "http" {
285		return nil, fmt.Errorf("unsupported scheme %s", u.Scheme)
286	}
287	if u.User != nil {
288		return nil, fmt.Errorf("embedded username/password information is not permitted")
289	}
290
291	// Fragment part is irrelevant, since we're not a browser.
292	u.Fragment = ""
293
294	return u, nil
295}
296
297// VersionConstraints returns the contraints for a given service identifier
298// (which should be of the form "servicename.vN") and product.
299//
300// When an exact (service and version) match is found, the constraints for
301// that service are returned.
302//
303// When the requested version is not provided but the service is, we will
304// search for all alternative versions. If mutliple alternative versions
305// are found, the contrains of the latest available version are returned.
306//
307// When a service is not provided at all an error will be returned instead.
308//
309// When checkpoint is disabled or when a 404 is returned after making the
310// HTTP call, an ErrNoVersionConstraints error will be returned.
311func (h *Host) VersionConstraints(id, product string) (*Constraints, error) {
312	svc, _, err := parseServiceID(id)
313	if err != nil {
314		return nil, err
315	}
316
317	// Return early if checkpoint is disabled.
318	if disabled := os.Getenv("CHECKPOINT_DISABLE"); disabled != "" {
319		return nil, &ErrNoVersionConstraints{disabled: true}
320	}
321
322	// No services supported for an empty Host.
323	if h == nil || h.services == nil {
324		return nil, &ErrServiceNotProvided{service: svc}
325	}
326
327	// Try to get the service URL for the version service and
328	// return early if the service isn't provided by the host.
329	u, err := h.ServiceURL(versionServiceID)
330	if err != nil {
331		return nil, err
332	}
333
334	// Check if we have an exact (service and version) match.
335	if _, ok := h.services[id].(string); !ok {
336		// If we don't have an exact match, we search for all matching
337		// services and then use the service ID of the latest version.
338		var services []string
339		for serviceID := range h.services {
340			if strings.HasPrefix(serviceID, svc+".") {
341				services = append(services, serviceID)
342			}
343		}
344
345		if len(services) == 0 {
346			// No discovered services match the requested service.
347			return nil, &ErrServiceNotProvided{hostname: h.hostname, service: svc}
348		}
349
350		// Set id to the latest service ID we found.
351		var latest *version.Version
352		for _, serviceID := range services {
353			if _, ver, err := parseServiceID(serviceID); err == nil {
354				if latest == nil || latest.LessThan(ver) {
355					id = serviceID
356					latest = ver
357				}
358			}
359		}
360	}
361
362	// Set a default timeout of 1 sec for the versions request (in milliseconds)
363	timeout := 1000
364	if v, err := strconv.Atoi(os.Getenv("CHECKPOINT_TIMEOUT")); err == nil {
365		timeout = v
366	}
367
368	client := &http.Client{
369		Transport: h.transport,
370		Timeout:   time.Duration(timeout) * time.Millisecond,
371	}
372
373	// Prepare the service URL by setting the service and product.
374	v := u.Query()
375	v.Set("product", product)
376	u.Path += id
377	u.RawQuery = v.Encode()
378
379	// Create a new request.
380	req, err := http.NewRequest("GET", u.String(), nil)
381	if err != nil {
382		return nil, fmt.Errorf("Failed to create version constraints request: %v", err)
383	}
384	req.Header.Set("Accept", "application/json")
385
386	log.Printf("[DEBUG] Retrieve version constraints for service %s and product %s", id, product)
387
388	resp, err := client.Do(req)
389	if err != nil {
390		return nil, fmt.Errorf("Failed to request version constraints: %v", err)
391	}
392	defer resp.Body.Close()
393
394	if resp.StatusCode == 404 {
395		return nil, &ErrNoVersionConstraints{disabled: false}
396	}
397
398	if resp.StatusCode != 200 {
399		return nil, fmt.Errorf("Failed to request version constraints: %s", resp.Status)
400	}
401
402	// Parse the constraints from the response body.
403	result := &Constraints{}
404	if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
405		return nil, fmt.Errorf("Error parsing version constraints: %v", err)
406	}
407
408	return result, nil
409}
410
411func parseServiceID(id string) (string, *version.Version, error) {
412	parts := strings.SplitN(id, ".", 2)
413	if len(parts) != 2 {
414		return "", nil, fmt.Errorf("Invalid service ID format (i.e. service.vN): %s", id)
415	}
416
417	version, err := version.NewVersion(parts[1])
418	if err != nil {
419		return "", nil, fmt.Errorf("Invalid service version: %v", err)
420	}
421
422	return parts[0], version, nil
423}
424