1#!/usr/bin/env python
2# Copyright 2009 Google Inc. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Provides data sources to use for benchmarking."""
17
18import ConfigParser
19import glob
20import os
21import os.path
22import random
23import re
24import subprocess
25import sys
26import time
27
28# relative
29import addr_util
30import selectors
31import util
32
# Pick the most accurate timer for a platform (same platform test as
# timeit.py): time.clock on Windows, time.time everywhere else.
DEFAULT_TIMER = time.time
if sys.platform[:3] == 'win':
  DEFAULT_TIMER = time.clock

# Module-level host cache; every DataSources instance uses this same dict.
GLOBAL_DATA_CACHE = {}

# Default configuration file handed to DataSources.__init__.
DEFAULT_CONFIG_PATH = 'config/data_sources.cfg'

# Default 'max_mtime_days' for a configured source.
MAX_FILE_MTIME_AGE_DAYS = 45
# Minimum file size (bytes) applied when listing sources with details.
MIN_FILE_SIZE = 10000
# Below this percentage of FQDN-looking records, hostnames get synthesized.
MAX_FQDN_SYNTHESIZE_PERCENT = 4

# NOTE(review): the next two are not referenced in this part of the file;
# presumably they are consumed by other modules -- confirm before changing.
MAX_NON_UNIQUE_RECORD_COUNT = 500000
MIN_RECOMMENDED_RECORD_COUNT = 200
47
48
49class DataSources(object):
50  """A collection of methods related to available hostname data sources."""
51
52  def __init__(self, config_path=DEFAULT_CONFIG_PATH, status_callback=None):
53    self.source_cache = GLOBAL_DATA_CACHE
54    self.source_config = {}
55    self.status_callback = status_callback
56    self._LoadConfigFromPath(config_path)
57
58  def msg(self, msg, **kwargs):
59    if self.status_callback:
60      self.status_callback(msg, **kwargs)
61    else:
62      print '- %s' % msg
63
64  def _LoadConfigFromPath(self, path):
65    """Load a configuration file describing data sources that may be available."""
66    conf_file = util.FindDataFile(path)
67    config = ConfigParser.ConfigParser()
68    config.read(conf_file)
69    for section in config.sections():
70      if section not in self.source_config:
71        self.source_config[section] = {
72            'name': None,
73            'search_paths': set(),
74            # Store whether or not this data source contains personal data
75            'full_hostnames': True,
76            'synthetic': False,
77            'include_duplicates': False,
78            'max_mtime_days': MAX_FILE_MTIME_AGE_DAYS
79        }
80
81      for (key, value) in config.items(section):
82        if key == 'name':
83          self.source_config[section]['name'] = value
84        elif key == 'full_hostnames' and int(value) == 0:
85          self.source_config[section]['full_hostnames'] = False
86        elif key == 'max_mtime_days':
87          self.source_config[section]['max_mtime_days'] = int(value)
88        elif key == 'include_duplicates':
89          self.source_config[section]['include_duplicates'] = bool(value)
90        elif key == 'synthetic':
91          self.source_config[section]['synthetic'] = bool(value)
92        else:
93          self.source_config[section]['search_paths'].add(value)
94
95  def ListSourceTypes(self):
96    """Get a list of all data sources we know about."""
97    return sorted(self.source_config.keys())
98
99  def ListSourcesWithDetails(self):
100    """Get a list of all data sources found with total counts.
101
102    Returns:
103      List of tuples in form of (short_name, full_name, full_hosts, # of entries)
104    """
105    for source in self.ListSourceTypes():
106      max_mtime = self.source_config[source]['max_mtime_days']
107      self._GetHostsFromSource(source, min_file_size=MIN_FILE_SIZE,
108                               max_mtime_age_days=max_mtime)
109
110    details = []
111    for source in self.source_cache:
112      details.append((source,
113                      self.source_config[source]['name'],
114                      self.source_config[source]['synthetic'],
115                      len(self.source_cache[source])))
116    return sorted(details, key=lambda x: (x[2], x[3] * -1))
117
118  def ListSourceTitles(self):
119    """Return a list of sources in title + count format."""
120    titles = []
121    seen_synthetic = False
122    seen_organic = False
123    for (unused_type, name, is_synthetic, count) in self.ListSourcesWithDetails():
124      if not is_synthetic:
125        seen_organic = True
126
127      if is_synthetic and seen_organic and not seen_synthetic:
128        titles.append('-' * 36)
129        seen_synthetic = True
130      titles.append('%s (%s)' % (name, count))
131    return titles
132
133  def ConvertSourceTitleToType(self, detail):
134    """Convert a detail name to a source type."""
135    for source_type in self.source_config:
136      if detail.startswith(self.source_config[source_type]['name']):
137        return source_type
138
139  def GetBestSourceDetails(self):
140    return self.ListSourcesWithDetails()[0]
141
142  def GetNameForSource(self, source):
143    if source in self.source_config:
144      return self.source_config[source]['name']
145    else:
146      # Most likely a custom file path
147      return source
148
149  def GetCachedRecordCountForSource(self, source):
150    return len(self.source_cache[source])
151
152  def _CreateRecordsFromHostEntries(self, entries, include_duplicates=False):
153    """Create records from hosts, removing duplicate entries and IP's.
154
155    Args:
156      entries: A list of test-data entries.
157      include_duplicates: Whether or not to filter duplicates (optional: False)
158
159    Returns:
160      A tuple of (filtered records, full_host_names (Boolean)
161
162    Raises:
163      ValueError: If no records could be grokked from the input.
164    """
165    last_entry = None
166
167    records = []
168    full_host_count = 0
169    for entry in entries:
170      if entry == last_entry and not include_duplicates:
171        continue
172      else:
173        last_entry = entry
174
175      if ' ' in entry:
176        (record_type, host) = entry.split(' ')
177      else:
178        record_type = 'A'
179        host = entry
180
181      if not addr_util.IP_RE.match(host) and not addr_util.INTERNAL_RE.search(host):
182        if not host.endswith('.'):
183          host += '.'
184        records.append((record_type, host))
185
186        if addr_util.FQDN_RE.match(host):
187          full_host_count += 1
188
189    if not records:
190      raise ValueError('No records could be created from: %s' % entries)
191
192    # Now that we've read everything, are we dealing with domains or full hostnames?
193    full_host_percent = full_host_count / float(len(records)) * 100
194    if full_host_percent < MAX_FQDN_SYNTHESIZE_PERCENT:
195      full_host_names = True
196    else:
197      full_host_names = False
198    return (records, full_host_names)
199
200  def GetTestsFromSource(self, source, count=50, select_mode=None):
201    """Parse records from source, and return tuples to use for testing.
202
203    Args:
204      source: A source name (str) that has been configured.
205      count: Number of tests to generate from the source (int)
206      select_mode: automatic, weighted, random, chunk (str)
207
208    Returns:
209      A list of record tuples in the form of (req_type, hostname)
210
211    Raises:
212      ValueError: If no usable records are found from the data source.
213
214    This is tricky because we support 3 types of input data:
215
216    - List of domains
217    - List of hosts
218    - List of record_type + hosts
219    """
220    records = []
221
222    if source in self.source_config:
223      include_duplicates = self.source_config[source].get('include_duplicates', False)
224    else:
225      include_duplicates = False
226
227    records = self._GetHostsFromSource(source)
228    if not records:
229      raise ValueError('Unable to generate records from %s (nothing found)' % source)
230
231    self.msg('Generating tests from %s (%s records, selecting %s %s)'
232             % (self.GetNameForSource(source), len(records), count, select_mode))
233    (records, are_records_fqdn) = self._CreateRecordsFromHostEntries(records,
234                                                                     include_duplicates=include_duplicates)
235    # First try to resolve whether to use weighted or random.
236    if select_mode in ('weighted', 'automatic', None):
237      # If we are in include_duplicates mode (cachemiss, cachehit, etc.), we have different rules.
238      if include_duplicates:
239        if count > len(records):
240          select_mode = 'random'
241        else:
242          select_mode = 'chunk'
243      elif len(records) != len(set(records)):
244        if select_mode == 'weighted':
245          self.msg('%s data contains duplicates, switching select_mode to random' % source)
246        select_mode = 'random'
247      else:
248        select_mode = 'weighted'
249
250    self.msg('Selecting %s out of %s sanitized records (%s mode).' %
251             (count, len(records), select_mode))
252    if select_mode == 'weighted':
253      records = selectors.WeightedDistribution(records, count)
254    elif select_mode == 'chunk':
255      records = selectors.ChunkSelect(records, count)
256    elif select_mode == 'random':
257      records = selectors.RandomSelect(records, count, include_duplicates=include_duplicates)
258    else:
259      raise ValueError('No such final selection mode: %s' % select_mode)
260
261    # For custom filenames
262    if source not in self.source_config:
263      self.source_config[source] = {'synthetic': True}
264
265    if are_records_fqdn:
266      self.source_config[source]['full_hostnames'] = False
267      self.msg('%s input appears to be predominantly domain names. Synthesizing FQDNs' % source)
268      synthesized = []
269      for (req_type, hostname) in records:
270        if not addr_util.FQDN_RE.match(hostname):
271          hostname = self._GenerateRandomHostname(hostname)
272        synthesized.append((req_type, hostname))
273      return synthesized
274    else:
275      return records
276
277  def _GenerateRandomHostname(self, domain):
278    """Generate a random hostname f or a given domain."""
279    oracle = random.randint(0, 100)
280    if oracle < 70:
281      return 'www.%s' % domain
282    elif oracle < 95:
283      return domain
284    elif oracle < 98:
285      return 'static.%s' % domain
286    else:
287      return 'cache-%s.%s' % (random.randint(0, 10), domain)
288
289  def _GetHostsFromSource(self, source, min_file_size=None, max_mtime_age_days=None):
290    """Get data for a particular source. This needs to be fast.
291
292    Args:
293      source: A configured source type (str)
294      min_file_size: What the minimum allowable file size is for this source (int)
295      max_mtime_age_days: Maximum days old the file can be for this source (int)
296
297    Returns:
298      list of hostnames gathered from data source.
299
300    The results of this function are cached by source type.
301    """
302    if source in self.source_cache:
303      return self.source_cache[source]
304    filename = self._FindBestFileForSource(source, min_file_size=min_file_size,
305                                           max_mtime_age_days=max_mtime_age_days)
306    if not filename:
307      return None
308
309    size_mb = os.path.getsize(filename) / 1024.0 / 1024.0
310    self.msg('Reading %s: %s (%0.1fMB)' % (self.GetNameForSource(source), filename, size_mb))
311    start_clock = DEFAULT_TIMER()
312    if filename.endswith('.pcap') or filename.endswith('.tcp'):
313      hosts = self._ExtractHostsFromPcapFile(filename)
314    else:
315      hosts = self._ExtractHostsFromHistoryFile(filename)
316
317    if not hosts:
318      hosts = self._ReadDataFile(filename)
319    duration = DEFAULT_TIMER() - start_clock
320    if duration > 5:
321      self.msg('%s data took %1.1fs to read!' % (self.GetNameForSource(source), duration))
322    self.source_cache[source] = hosts
323    return hosts
324
325  def _ExtractHostsFromHistoryFile(self, path):
326    """Get a list of sanitized records from a history file containing URLs."""
327    # This regexp is fairly general (no ip filtering), since we need speed more
328    # than precision at this stage.
329    parse_re = re.compile('https*://([\-\w]+\.[\-\w\.]+)')
330    return parse_re.findall(open(path, 'rb').read())
331
332  def _ExtractHostsFromPcapFile(self, path):
333    """Get a list of requests out of a pcap file - requires tcpdump."""
334    self.msg('Extracting requests from pcap file using tcpdump')
335    cmd = 'tcpdump -r %s -n port 53' % path
336    pipe = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE).stdout
337    parse_re = re.compile(' ([A-Z]+)\? ([\-\w\.]+)')
338    requests = []
339    for line in pipe:
340      if '?' not in line:
341        continue
342      match = parse_re.search(line)
343      if match:
344        requests.append(' '.join(match.groups()))
345    return requests
346
347  def _ReadDataFile(self, path):
348    """Read a line-based datafile."""
349    records = []
350    domains_re = re.compile('^\w[\w\.-]+[a-zA-Z]$')
351    requests_re = re.compile('^[A-Z]{1,4} \w[\w\.-]+\.$')
352    for line in open(path):
353      if domains_re.match(line) or requests_re.match(line):
354        records.append(line.rstrip())
355    return records
356
357  def _GetSourceSearchPaths(self, source):
358    """Get a list of possible search paths (globs) for a given source."""
359
360    # This is likely a custom file path
361    if source not in self.source_config:
362      return [source]
363
364    search_paths = []
365    environment_re = re.compile('%(\w+)%')
366
367    # First get through resolving environment variables
368    for path in self.source_config[source]['search_paths']:
369      env_vars = set(environment_re.findall(path))
370      if env_vars:
371        for variable in env_vars:
372          env_var = os.getenv(variable, False)
373          if env_var:
374            path = path.replace('%%%s%%' % variable, env_var)
375          else:
376            path = None
377
378      # If everything is good, replace all '/'  chars with the os path variable.
379      if path:
380        path = path.replace('/', os.sep)
381        search_paths.append(path)
382
383        # This moment of weirdness brought to you by Windows XP(tm). If we find
384        # a Local or Roaming keyword in path, add the other forms to the search
385        # path.
386        if sys.platform[:3] == 'win':
387          keywords = ['Local', 'Roaming']
388          for keyword in keywords:
389            if keyword in path:
390              replacement = keywords[keywords.index(keyword)-1]
391              search_paths.append(path.replace('\\%s' % keyword, '\\%s' % replacement))
392              search_paths.append(path.replace('\\%s' % keyword, ''))
393
394    return search_paths
395
396  def _FindBestFileForSource(self, source, min_file_size=None, max_mtime_age_days=None):
397    """Find the best file (newest over X size) to use for a given source type.
398
399    Args:
400      source: source type
401      min_file_size: What the minimum allowable file size is for this source (int)
402      max_mtime_age_days: Maximum days old the file can be for this source (int)
403
404    Returns:
405      A file path.
406    """
407    found = []
408    for path in self._GetSourceSearchPaths(source):
409      if not os.path.isabs(path):
410        path = util.FindDataFile(path)
411
412      for filename in glob.glob(path):
413        if min_file_size and os.path.getsize(filename) < min_file_size:
414          self.msg('Skipping %s (only %sb)' % (filename, os.path.getsize(filename)))
415        else:
416          try:
417            fp = open(filename, 'rb')
418            fp.close()
419            found.append(filename)
420          except IOError:
421            self.msg('Skipping %s (could not open)' % filename)
422
423    if found:
424      newest = sorted(found, key=os.path.getmtime)[-1]
425      age_days = (time.time() - os.path.getmtime(newest)) / 86400
426      if max_mtime_age_days and age_days > max_mtime_age_days:
427        self.msg('Skipping %s (%2.0fd old)' % (newest, age_days))
428      else:
429        return newest
430    else:
431      return None
432
# Manual smoke test: list the known sources, then generate tests from the
# highest-ranked source that was found.
if __name__ == '__main__':
  parser = DataSources()
  print(parser.ListSourceTypes())
  details = parser.ListSourcesWithDetails()
  print(details)
  if details:
    best = details[0][0]
    # The class offers GetTestsFromSource; it has no GetRecordsFromSource.
    print(len(parser.GetTestsFromSource(best)))
439