1// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package netutil implements network-related utility functions.
16package netutil
17
18import (
19	"context"
20	"net"
21	"net/url"
22	"reflect"
23	"sort"
24	"time"
25
26	"github.com/coreos/etcd/pkg/types"
27	"github.com/coreos/pkg/capnslog"
28)
29
30var (
31	plog = capnslog.NewPackageLogger("github.com/coreos/etcd", "pkg/netutil")
32
33	// indirection for testing
34	resolveTCPAddr = resolveTCPAddrDefault
35)
36
37const retryInterval = time.Second
38
// resolveTCPAddrDefault resolves addr, which must have the form "host:port",
// to a TCP address. It is adapted from the standard library's ResolveTCPAddr
// code but performs every lookup through net.DefaultResolver with the
// caller-supplied ctx, so lookups honor the context's cancellation/deadline.
//
// The port may be numeric or a service name (looked up for "tcp"). The host
// may be an IP literal, which is used as-is, or a DNS name, in which case the
// first address returned by the resolver is used.
//
// An error is returned if addr cannot be split into host and port, if the
// port or host lookup fails, or if the host lookup yields no addresses.
func resolveTCPAddrDefault(ctx context.Context, addr string) (*net.TCPAddr, error) {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	portnum, err := net.DefaultResolver.LookupPort(ctx, "tcp", port)
	if err != nil {
		return nil, err
	}

	var ips []net.IPAddr
	if ip := net.ParseIP(host); ip != nil {
		ips = []net.IPAddr{{IP: ip}}
	} else {
		// Try as a DNS name.
		ips, err = net.DefaultResolver.LookupIPAddr(ctx, host)
		if err != nil {
			return nil, err
		}
	}
	// Guard the index below so a resolver that reports success without any
	// addresses produces an error instead of a panic.
	if len(ips) == 0 {
		return nil, &net.AddrError{Err: "no suitable address found", Addr: host}
	}
	// NOTE(review): no address-family preference or randomization is applied
	// here; the first resolver result is always used.
	ip := ips[0]
	return &net.TCPAddr{IP: ip.IP, Port: portnum, Zone: ip.Zone}, nil
}
65
66// resolveTCPAddrs is a convenience wrapper for net.ResolveTCPAddr.
67// resolveTCPAddrs return a new set of url.URLs, in which all DNS hostnames
68// are resolved.
69func resolveTCPAddrs(ctx context.Context, urls [][]url.URL) ([][]url.URL, error) {
70	newurls := make([][]url.URL, 0)
71	for _, us := range urls {
72		nus := make([]url.URL, len(us))
73		for i, u := range us {
74			nu, err := url.Parse(u.String())
75			if err != nil {
76				return nil, err
77			}
78			nus[i] = *nu
79		}
80		for i, u := range nus {
81			h, err := resolveURL(ctx, u)
82			if err != nil {
83				return nil, err
84			}
85			if h != "" {
86				nus[i].Host = h
87			}
88		}
89		newurls = append(newurls, nus)
90	}
91	return newurls, nil
92}
93
94func resolveURL(ctx context.Context, u url.URL) (string, error) {
95	for ctx.Err() == nil {
96		host, _, err := net.SplitHostPort(u.Host)
97		if err != nil {
98			plog.Errorf("could not parse url %s during tcp resolving", u.Host)
99			return "", err
100		}
101		if host == "localhost" || net.ParseIP(host) != nil {
102			return "", nil
103		}
104		tcpAddr, err := resolveTCPAddr(ctx, u.Host)
105		if err == nil {
106			plog.Infof("resolving %s to %s", u.Host, tcpAddr.String())
107			return tcpAddr.String(), nil
108		}
109		plog.Warningf("failed resolving host %s (%v); retrying in %v", u.Host, err, retryInterval)
110		select {
111		case <-ctx.Done():
112			plog.Errorf("could not resolve host %s", u.Host)
113			return "", err
114		case <-time.After(retryInterval):
115		}
116	}
117	return "", ctx.Err()
118}
119
120// urlsEqual checks equality of url.URLS between two arrays.
121// This check pass even if an URL is in hostname and opposite is in IP address.
122func urlsEqual(ctx context.Context, a []url.URL, b []url.URL) bool {
123	if len(a) != len(b) {
124		return false
125	}
126	urls, err := resolveTCPAddrs(ctx, [][]url.URL{a, b})
127	if err != nil {
128		return false
129	}
130	a, b = urls[0], urls[1]
131	sort.Sort(types.URLs(a))
132	sort.Sort(types.URLs(b))
133	for i := range a {
134		if !reflect.DeepEqual(a[i], b[i]) {
135			return false
136		}
137	}
138
139	return true
140}
141
142func URLStringsEqual(ctx context.Context, a []string, b []string) bool {
143	if len(a) != len(b) {
144		return false
145	}
146	urlsA := make([]url.URL, 0)
147	for _, str := range a {
148		u, err := url.Parse(str)
149		if err != nil {
150			return false
151		}
152		urlsA = append(urlsA, *u)
153	}
154	urlsB := make([]url.URL, 0)
155	for _, str := range b {
156		u, err := url.Parse(str)
157		if err != nil {
158			return false
159		}
160		urlsB = append(urlsB, *u)
161	}
162
163	return urlsEqual(ctx, urlsA, urlsB)
164}
165
// IsNetworkTimeoutError reports whether err is a net.Error whose Timeout
// method returns true. Only err's own dynamic type is checked; an error that
// merely wraps a net.Error is not recognized. A nil err yields false.
func IsNetworkTimeoutError(err error) bool {
	if netErr, ok := err.(net.Error); ok {
		return netErr.Timeout()
	}
	return false
}
170