1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package net
6
import (
	"os"
	"sync"
	"time"
)
11
// cacheMaxAge is how long readHosts trusts its cached copy of the hosts
// file without consulting the file system. Once it elapses, the file is
// stat'ed again and re-parsed only if its path, mtime or size changed.
const cacheMaxAge = time.Second * 5
13
14func parseLiteralIP(addr string) string {
15	var ip IP
16	var zone string
17	ip = parseIPv4(addr)
18	if ip == nil {
19		ip, zone = parseIPv6(addr, true)
20	}
21	if ip == nil {
22		return ""
23	}
24	if zone == "" {
25		return ip.String()
26	}
27	return ip.String() + "%" + zone
28}
29
// hosts is the in-memory cache of known host entries read from the hosts
// file. readHosts populates it; lookupStaticHost and lookupStaticAddr hold
// the embedded Mutex around both the refresh and their reads.
var hosts struct {
	sync.Mutex

	// byName maps a host name (a DNS label, an FQDN or an absolute FQDN)
	// to the list of literal IP addresses recorded for it. For convenience
	// keys are stored ASCII-lowercased, as normalized by absDomainName.
	byName map[string][]string

	// byAddr maps a literal IP address, as produced by parseLiteralIP
	// (IPv6 addresses may carry a "%zone" suffix), to the list of host
	// names recorded for it. Old classful IP address notation is not
	// supported.
	byAddr map[string][]string

	// Bookkeeping for readHosts: expire is when the cached data must be
	// re-validated, path is the file it was loaded from, and mtime/size
	// are what stat reported for that file when it was loaded.
	expire time.Time
	path   string
	mtime  time.Time
	size   int64
}
50
51func readHosts() {
52	now := time.Now()
53	hp := testHookHostsPath
54
55	if now.Before(hosts.expire) && hosts.path == hp && len(hosts.byName) > 0 {
56		return
57	}
58	mtime, size, err := stat(hp)
59	if err == nil && hosts.path == hp && hosts.mtime.Equal(mtime) && hosts.size == size {
60		hosts.expire = now.Add(cacheMaxAge)
61		return
62	}
63
64	hs := make(map[string][]string)
65	is := make(map[string][]string)
66	var file *file
67	if file, _ = open(hp); file == nil {
68		return
69	}
70	for line, ok := file.readLine(); ok; line, ok = file.readLine() {
71		if i := byteIndex(line, '#'); i >= 0 {
72			// Discard comments.
73			line = line[0:i]
74		}
75		f := getFields(line)
76		if len(f) < 2 {
77			continue
78		}
79		addr := parseLiteralIP(f[0])
80		if addr == "" {
81			continue
82		}
83		for i := 1; i < len(f); i++ {
84			name := absDomainName([]byte(f[i]))
85			h := []byte(f[i])
86			lowerASCIIBytes(h)
87			key := absDomainName(h)
88			hs[key] = append(hs[key], addr)
89			is[addr] = append(is[addr], name)
90		}
91	}
92	// Update the data cache.
93	hosts.expire = now.Add(cacheMaxAge)
94	hosts.path = hp
95	hosts.byName = hs
96	hosts.byAddr = is
97	hosts.mtime = mtime
98	hosts.size = size
99	file.close()
100}
101
102// lookupStaticHost looks up the addresses for the given host from /etc/hosts.
103func lookupStaticHost(host string) []string {
104	hosts.Lock()
105	defer hosts.Unlock()
106	readHosts()
107	if len(hosts.byName) != 0 {
108		// TODO(jbd,bradfitz): avoid this alloc if host is already all lowercase?
109		// or linear scan the byName map if it's small enough?
110		lowerHost := []byte(host)
111		lowerASCIIBytes(lowerHost)
112		if ips, ok := hosts.byName[absDomainName(lowerHost)]; ok {
113			ipsCp := make([]string, len(ips))
114			copy(ipsCp, ips)
115			return ipsCp
116		}
117	}
118	return nil
119}
120
121// lookupStaticAddr looks up the hosts for the given address from /etc/hosts.
122func lookupStaticAddr(addr string) []string {
123	hosts.Lock()
124	defer hosts.Unlock()
125	readHosts()
126	addr = parseLiteralIP(addr)
127	if addr == "" {
128		return nil
129	}
130	if len(hosts.byAddr) != 0 {
131		if hosts, ok := hosts.byAddr[addr]; ok {
132			hostsCp := make([]string, len(hosts))
133			copy(hostsCp, hosts)
134			return hostsCp
135		}
136	}
137	return nil
138}
139