1# Copyright 2009 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""A base user-interface workflow, to be inherited by UI modules."""
16
17import tempfile
18
19import benchmark
20import better_webbrowser
21import config
22import data_sources
23import geoip
24import nameserver_list
25import reporter
26import site_connector
27import util
28
29
30__author__ = 'tstromberg@google.com (Thomas Stromberg)'
31
32
class BaseUI(object):
  """Common methods for all UI implementations.

  Several attributes read by these methods are never assigned in this class;
  presumably the subclass sets them before calling the workflow methods:
    options: settings object (input_source, query_count, template, csv_file,
      output_file, upload_results, ...).
    supplied_ns, global_ns, regional_ns, include_internal: nameserver inputs
      consumed by PrepareNameServers().
    status_callback (optional): callable(msg, **kwargs) used by UpdateStatus().

  The methods build on each other's state: LoadDataSources() must run before
  PrepareTestRecords(), PrepareNameServers() before PrepareBenchmark(), and
  RunAndOpenReports() needs both the benchmark object and the test records.
  """

  def __init__(self):
    self.SetupDataStructures()

  def SetupDataStructures(self):
    """Initialize per-run state to empty defaults.

    Kept separate from __init__() so subclasses can call it from their own
    __init__() instead of having to inherit BaseUI.__init__().
    """
    self.reporter = None
    self.nameservers = None
    self.bmark = None
    self.report_path = None
    self.csv_path = None
    self.json_path = None
    self.geodata = None
    self.country = None
    self.sources = {}
    self.url = None
    self.share_state = None
    self.test_records = []

  def UpdateStatus(self, msg, **kwargs):
    """Update the little status message on the bottom of the window.

    Args:
      msg: status text.
      **kwargs: extra hints (e.g. debug=True) forwarded to status_callback.

    Without a status_callback the message is printed to stdout and kwargs
    are ignored.
    """
    callback = getattr(self, 'status_callback', None)
    if callback:
      callback(msg, **kwargs)
    else:
      # Single-argument print(...) behaves the same on Python 2 and 3.
      print(msg)

  def DebugMsg(self, message):
    """Send a status message flagged as debug output."""
    self.UpdateStatus(message, debug=True)

  def LoadDataSources(self):
    """Create the DataSources helper used by PrepareTestRecords()."""
    self.data_src = data_sources.DataSources(status_callback=self.UpdateStatus)

  def PrepareTestRecords(self):
    """Figure out what data source a user wants, and create test_records.

    If options.input_source is unset, the best source reported by the data
    sources helper is used and written back to options.input_source.
    """
    if self.options.input_source:
      src_type = self.options.input_source
    else:
      src_type = self.data_src.GetBestSourceDetails()[0]
      self.options.input_source = src_type

    self.test_records = self.data_src.GetTestsFromSource(
        src_type,
        self.options.query_count,
        select_mode=self.options.select_mode
    )

  def PrepareNameServers(self):
    """Build self.nameservers from the configured inputs and health-check it.

    Censorship tests from the latest sanity reference are only passed on when
    options.enable_censorship_checks is set.
    """
    self.nameservers = nameserver_list.NameServers(
        self.supplied_ns,
        global_servers=self.global_ns,
        regional_servers=self.regional_ns,
        include_internal=self.include_internal,
        num_servers=self.options.num_servers,
        timeout=self.options.timeout,
        ping_timeout=self.options.ping_timeout,
        health_timeout=self.options.health_timeout,
        ipv6_only=self.options.ipv6_only,
        status_callback=self.UpdateStatus
    )
    if self.options.invalidate_cache:
      self.nameservers.InvalidateSecondaryCache()

    self.nameservers.cache_dir = tempfile.gettempdir()

    # Only override the health-check thread count when there is more than one
    # server; with a single server the NameServers default is kept.
    if len(self.nameservers) > 1:
      self.nameservers.thread_count = int(self.options.health_thread_count)

    self.UpdateStatus('Checking latest sanity reference')
    (primary_checks, secondary_checks, censor_tests) = config.GetLatestSanityChecks()
    if not self.options.enable_censorship_checks:
      censor_tests = []
    else:
      self.UpdateStatus('Censorship checks enabled: %s found.' % len(censor_tests))

    self.nameservers.CheckHealth(primary_checks, secondary_checks, censor_tests=censor_tests)

  def PrepareBenchmark(self):
    """Setup the benchmark object with the appropriate dataset.

    A single nameserver is benchmarked with one thread; otherwise
    options.benchmark_thread_count is used.
    """
    if len(self.nameservers) == 1:
      thread_count = 1
    else:
      thread_count = self.options.benchmark_thread_count

    self.bmark = benchmark.Benchmark(self.nameservers,
                                     query_count=self.options.query_count,
                                     run_count=self.options.run_count,
                                     thread_count=thread_count,
                                     status_callback=self.UpdateStatus)

  def RunBenchmark(self):
    """Run the benchmark and build self.reporter from the results.

    When uploading is enabled, this also runs the index-host queries, looks up
    geolocation, and (for more than one server) runs the port-behavior checks.
    """
    results = self.bmark.Run(self.test_records)
    index = []
    if self.options.upload_results in (1, True):
      connector = site_connector.SiteConnector(self.options, status_callback=self.UpdateStatus)
      index_hosts = connector.GetIndexHosts()
      if index_hosts:
        index = self.bmark.RunIndex(index_hosts)
      self.DiscoverLocation()
      if len(self.nameservers) > 1:
        self.nameservers.RunPortBehaviorThreads()

    self.reporter = reporter.ReportGenerator(self.options, self.nameservers,
                                             results, index=index, geodata=self.geodata)

  def DiscoverLocation(self):
    """Look up geolocation data once and cache it on self.

    Also sets self.country from the 'country_name' key. Returns self.geodata.
    """
    if not getattr(self, 'geodata', None):
      self.geodata = geoip.GetGeoData()
      self.country = self.geodata.get('country_name', None)
    return self.geodata

  def RunAndOpenReports(self):
    """Run the benchmark, write the reports and announce the result."""
    self.RunBenchmark()
    best = self.reporter.BestOverallNameServer()
    self.CreateReports()
    if self.options.template == 'html':
      self.DisplayHtmlReport()
    if self.url:
      self.UpdateStatus('Complete! Your results: %s' % self.url)
    else:
      self.UpdateStatus('Complete! %s [%s] is the best.' % (best.name, best.ip))

  def CreateReports(self):
    """Create CSV & HTML reports for the latest run.

    Paths come from options.output_file / options.csv_file when set; otherwise
    util.GenerateOutputFilename() chooses them.

    When options.upload_results is enabled, this also:
      - saves the anonymized JSON locally (for transparency);
      - uploads it;
      - stores the returned sharing URL/state on self and passes them to the
        report generator.
    """
    if self.options.output_file:
      self.report_path = self.options.output_file
    else:
      self.report_path = util.GenerateOutputFilename(self.options.template)

    if self.options.csv_file:
      self.csv_path = self.options.csv_file
    else:
      self.csv_path = util.GenerateOutputFilename('csv')

    if self.options.upload_results in (1, True):
      # This is for debugging and transparency only.
      self.json_path = util.GenerateOutputFilename('js')
      self.UpdateStatus('Saving anonymized JSON to %s' % self.json_path)
      json_data = self.reporter.CreateJsonData()
      json_fp = open(self.json_path, 'w')
      try:
        json_fp.write(json_data)
      finally:
        json_fp.close()

      self.UpdateStatus('Uploading results to %s' % self.options.site_url)
      connector = site_connector.SiteConnector(self.options, status_callback=self.UpdateStatus)
      self.url, self.share_state = connector.UploadJsonResults(
          json_data,
          hide_results=self.options.hide_results
      )

      if self.url:
        self.UpdateStatus('Your sharing URL: %s (%s)' % (self.url, self.share_state))

    self.UpdateStatus('Saving report to %s' % self.report_path)
    # Close the report file even if report generation raises.
    report_fp = open(self.report_path, 'w')
    try:
      self.reporter.CreateReport(format=self.options.template,
                                 output_fp=report_fp,
                                 csv_path=self.csv_path,
                                 sharing_url=self.url,
                                 sharing_state=self.share_state)
    finally:
      report_fp.close()

    self.UpdateStatus('Saving detailed results to %s' % self.csv_path)
    self.reporter.SaveResultsToCsv(self.csv_path)

  def DisplayHtmlReport(self):
    """Open the saved report in a browser.

    better_webbrowser's output is routed to DebugMsg first.
    """
    self.UpdateStatus('Opening %s' % self.report_path)
    better_webbrowser.output = self.DebugMsg
    better_webbrowser.open(self.report_path)
210
211