1# Copyright 2009 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16"""Module for all nameserver health checks."""
17
18__author__ = 'tstromberg@google.com (Thomas Stromberg)'
19
20import random
21import sys
22import time
23import util
24
25from dns import rcode
26
# Domains StoreWildcardCache picks from when building random hostnames
# (presumably chosen because they answer for arbitrary subdomains - confirm).
WILDCARD_DOMAINS = ('live.com.', 'blogspot.com.', 'wordpress.com.')
# Records for which TestAnswers words an unexpected answer as "is hijacked"
# instead of "appears incorrect".
LIKELY_HIJACKS = ['www.google.com.', 'windowsupdate.microsoft.com.', 'www.paypal.com.']

# Multiplier applied to health_timeout for the queries issued by
# StoreWildcardCache and TestSharedCache.
SHARED_CACHE_TIMEOUT_MULTIPLIER = 1.25
# NOTE(review): not referenced by any of the checks defined below.
ROOT_SERVER_TIMEOUT_MULTIPLIER = 0.5
# Timeout handed to TestAnswers by CheckCensorship.
CENSORSHIP_TIMEOUT = 30
# Number of hostnames StoreWildcardCache tries before disabling the server.
MAX_STORE_ATTEMPTS = 4
# Number of cache_checks entries StoreWildcardCache collects.
TOTAL_WILDCARDS_TO_STORE = 2
# Upper bound on the attempts made by TestPortBehavior.
MAX_PORT_BEHAVIOR_TRIES = 2

# Response codes that TestAnswers and TestRootNsResponse treat as errors.
FATAL_RCODES = ['REFUSED', 'NOTAUTH']
39
40class NameServerHealthChecks(object):
41  """Health checks for a nameserver."""
42
43  def TestAnswers(self, record_type, record, expected, critical=False, timeout=None):
44    """Test to see that an answer returns correct IP's.
45
46    Args:
47      record_type: text record type for NS query (A, CNAME, etc)
48      record: string to query for
49      expected: tuple of strings expected in all answers
50      critical: If this query fails, should it count against the server.
51      timeout: timeout for query in seconds (int)
52
53    Returns:
54      (is_broken, error_msg, duration)
55    """
56    is_broken = False
57    unmatched_answers = []
58    if not timeout:
59      timeout = self.health_timeout
60    (response, duration, error_msg) = self.TimedRequest(record_type, record, timeout)
61    if response:
62      response_code = rcode.to_text(response.rcode())
63      if response_code in FATAL_RCODES:
64        error_msg = 'Responded with: %s' % response_code
65        if critical:
66          is_broken = True
67      elif not response.answer:
68        # Avoid preferring broken DNS servers that respond quickly
69        duration = util.SecondsToMilliseconds(self.health_timeout)
70        error_msg = 'No answer (%s): %s' % (response_code, record)
71        is_broken = True
72      else:
73        found_usable_record = False
74        for answer in response.answer:
75          if found_usable_record:
76            break
77
78          # Process the first sane rdata object available in the answers
79          for rdata in answer:
80            # CNAME
81            if rdata.rdtype == 5:
82              reply = str(rdata.target)
83            # A Record
84            elif rdata.rdtype == 1:
85              reply = str(rdata.address)
86            else:
87              continue
88
89            found_usable_record = True
90            found_match = False
91            for string in expected:
92              if reply.startswith(string) or reply.endswith(string):
93                found_match = True
94                break
95            if not found_match:
96              unmatched_answers.append(reply)
97
98        if unmatched_answers:
99          hijack_text = ', '.join(unmatched_answers).rstrip('.')
100          if record in LIKELY_HIJACKS:
101            error_msg = '%s is hijacked: %s' % (record.rstrip('.'), hijack_text)
102          else:
103            error_msg = '%s appears incorrect: %s' % (record.rstrip('.'), hijack_text)
104    else:
105      if not error_msg:
106        error_msg = 'No response'
107      is_broken = True
108
109    return (is_broken, error_msg, duration)
110
111  def TestBindVersion(self):
112    """Test for BIND version. This acts as a pretty decent ping."""
113    (unused_response, duration, error_msg) = self.RequestVersion()
114    # Sometimes nameservers aren't able to respond to this request in a way that
115    # dnspython likes. Lets just always call this one good and use it for latency.
116    return (False, None, duration)
117
118  def TestNodeId(self):
119    """Get the current node id."""
120    self.RequestNodeId()
121    return (False, False, 0.0)
122
123  def TestNegativeResponse(self, prefix=None):
124    """Test for NXDOMAIN hijaaking."""
125    is_broken = False
126    if prefix:
127      hostname = prefix
128      warning_suffix = ' (%s)' % prefix
129    else:
130      hostname = 'test'
131      warning_suffix = ''
132    poison_test = '%s.nb%s.google.com.' % (hostname, random.random())
133    (response, duration, error_msg) = self.TimedRequest('A', poison_test,
134                                                        timeout=self.health_timeout*2)
135    if not response:
136      if not error_msg:
137        error_msg = 'No response'
138      is_broken = True
139    elif response.answer:
140      error_msg = 'NXDOMAIN Hijacking' + warning_suffix
141
142    return (is_broken, error_msg, duration)
143
144  def TestRootNsResponse(self):
145    """Test a . NS response.
146
147    NOTE: This is a bad way to gauge performance of a nameserver, as the
148    response length varies between nameserver configurations.
149    """
150    is_broken = False
151    error_msg = None
152    (response, duration, error_msg) = self.TimedRequest('NS', '.')
153    if not response:
154      response_code = None
155      is_broken = True
156      if not error_msg:
157        error_msg = 'No response'
158    else:
159      response_code = rcode.to_text(response.rcode())
160      if response_code in FATAL_RCODES:
161        error_msg = response_code
162        is_broken = True
163
164    return (is_broken, error_msg, duration)
165
166  def TestWwwNegativeResponse(self):
167    return self.TestNegativeResponse(prefix='www')
168
169  def TestARootServerResponse(self):
170    return self.TestAnswers('A', 'a.root-servers.net.', '198.41.0.4', critical=True)
171
172  def TestPortBehavior(self, tries=0):
173    """This is designed to be called multiple times to retry bad results."""
174
175    if self.port_behavior:
176      if 'UNKNOWN' not in self.port_behavior:
177        return (False, None, 0)
178
179    tries += 1
180    response = self.TimedRequest('TXT', 'porttest.dns-oarc.net.', timeout=5)[0]
181    if response and response.answer:
182      if len(response.answer) > 1:
183        text = response.answer[1].to_rdataset().to_text()
184        self.port_behavior = text.split('"')[1]
185
186    if (not self.port_behavior or 'UNKNOWN' in self.port_behavior) and tries < MAX_PORT_BEHAVIOR_TRIES:
187      time.sleep(1)
188      return self.TestPortBehavior(tries=tries)
189
190#    print "%s behavior: %s (tries=%s)" % (self, self.port_behavior, tries)
191    return (False, None, 0)
192
193  def StoreWildcardCache(self):
194    """Store a set of wildcard records."""
195    timeout = self.health_timeout * SHARED_CACHE_TIMEOUT_MULTIPLIER
196    attempted = []
197
198    while len(self.cache_checks) != TOTAL_WILDCARDS_TO_STORE:
199      if len(attempted) == MAX_STORE_ATTEMPTS:
200        self.disabled = 'Unable to get uncached results for: %s' % ', '.join(attempted)
201        return False
202      domain = random.choice(WILDCARD_DOMAINS)
203      hostname = 'namebench%s.%s' % (random.randint(1, 2**32), domain)
204      attempted.append(hostname)
205      response = self.TimedRequest('A', hostname, timeout=timeout)[0]
206      if response and response.answer:
207        self.cache_checks.append((hostname, response, self.timer()))
208      else:
209        sys.stdout.write('x')
210
211  def TestSharedCache(self, other_ns):
212    """Is this nameserver sharing a cache with another nameserver?
213
214    Args:
215      other_ns: A nameserver to compare it to.
216
217    Returns:
218      A tuple containing:
219        - Boolean of whether or not this host has a shared cache
220        - The faster NameServer object
221        - The slower NameServer object
222    """
223    timeout = self.health_timeout * SHARED_CACHE_TIMEOUT_MULTIPLIER
224    checked = []
225    shared = False
226
227    if self.disabled or other_ns.disabled:
228      return False
229
230    if not other_ns.cache_checks:
231      print '%s has no cache checks (disabling - how did this happen?)' % other_ns
232      other_ns.disabled = 'Unable to perform cache checks.'
233      return False
234
235    for (ref_hostname, ref_response, ref_timestamp) in other_ns.cache_checks:
236      response = self.TimedRequest('A', ref_hostname, timeout=timeout)[0]
237      # Retry once - this *may* cause false positives however, as the TTL may be updated.
238      if not response or not response.answer:
239        sys.stdout.write('x')
240        response = self.TimedRequest('A', ref_hostname, timeout=timeout)[0]
241
242      if response and response.answer:
243        ref_ttl = ref_response.answer[0].ttl
244        ttl = response.answer[0].ttl
245        delta = abs(ref_ttl - ttl)
246        query_age = self.timer() - ref_timestamp
247        delta_age_delta = abs(query_age - delta)
248
249        if delta > 0 and delta_age_delta < 2:
250          return other_ns
251      else:
252        sys.stdout.write('!')
253      checked.append(ref_hostname)
254
255    if not checked:
256      self.AddFailure('Failed to test %s wildcard caches'  % len(other_ns.cache_checks))
257    return shared
258
259  def CheckCensorship(self, tests):
260    """Check to see if results from a nameserver are being censored."""
261    for (check, expected) in tests:
262      (req_type, req_name) = check.split(' ')
263      expected_values = expected.split(',')
264      result = self.TestAnswers(req_type.upper(), req_name, expected_values,
265                                timeout=CENSORSHIP_TIMEOUT)
266      warning = result[1]
267      if warning:
268        self.AddWarning(warning, penalty=False)
269
270  def CheckHealth(self, sanity_checks=None, fast_check=False, final_check=False, port_check=False):
271    """Qualify a nameserver to see if it is any good."""
272
273    is_fatal = False
274    if fast_check:
275      tests = [(self.TestARootServerResponse, [])]
276      is_fatal = True
277      sanity_checks = []
278    elif final_check:
279      tests = [(self.TestWwwNegativeResponse, []), (self.TestPortBehavior, []), (self.TestNodeId, [])]
280    elif port_check:
281      tests = [(self.TestPortBehavior, []), (self.TestNodeId, [])]
282    else:
283      # Put the bind version here so that we have a great minimum latency measurement.
284      tests = [(self.TestNegativeResponse, []), (self.TestBindVersion, [])]
285
286    if sanity_checks:
287      for (check, expected_value) in sanity_checks:
288        (req_type, req_name) = check.split(' ')
289        expected_values = expected_value.split(',')
290        tests.append((self.TestAnswers, [req_type.upper(), req_name, expected_values]))
291
292    for test in tests:
293      (function, args) = test
294      (is_broken, warning, duration) = function(*args)
295      if args:
296        test_name = args[1]
297      else:
298        test_name = function.__name__
299
300      self.checks.append((test_name, is_broken, warning, duration))
301      if is_broken:
302        self.AddFailure('%s: %s' % (test_name, warning), fatal=is_fatal)
303      if warning:
304        # Special case for NXDOMAIN de-duplication
305        if not ('NXDOMAIN' in warning and 'NXDOMAIN Hijacking' in self.warnings):
306          self.AddWarning(warning)
307      if self.disabled:
308        break
309
310    return self.disabled
311
312