1# Copyright (c) 2012 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import copy
6import hashlib
7import json
8import string
9import sys
10import urllib2
11
# Version number stored in newly created manifests. SDKManifest.Validate
# rejects any manifest that declares a higher version than this.
MANIFEST_VERSION = 2

# Key names that are shared by several of the manifest objects.
ARCHIVES_KEY = 'archives'
BUNDLES_KEY = 'bundles'
NAME_KEY = 'name'
REVISION_KEY = 'revision'
VERSION_KEY = 'version'

# The values accepted for an archive's host_os field.
HOST_OS_LITERALS = frozenset(('mac', 'win', 'linux', 'all'))

# The keys an archive may contain; used for validation.
VALID_ARCHIVE_KEYS = frozenset(('host_os', 'size', 'checksum', 'url'))

# The values accepted for a bundle's stability field.
STABILITY_LITERALS = [
    'obsolete',
    'post_stable',
    'stable',
    'beta',
    'dev',
    'canary',
]

# The values accepted for a bundle's recommended field.
YES_NO_LITERALS = ['yes', 'no']

# The keys a bundle may contain; used for validation.
VALID_BUNDLES_KEYS = frozenset((
    ARCHIVES_KEY,
    NAME_KEY,
    VERSION_KEY,
    REVISION_KEY,
    'description',
    'desc_url',
    'stability',
    'recommended',
    'repath',
    'sdk_revision',
))

# The keys allowed at the top level of a manifest; used for validation.
VALID_MANIFEST_KEYS = frozenset(('manifest_version', BUNDLES_KEY))
40
41
def GetHostOS():
  '''Returns the host_os value that corresponds to the current host OS.

  Return:
    'linux', 'mac' or 'win', derived from sys.platform.
  Raises:
    KeyError: if sys.platform is not one of the recognized platforms.'''
  platform = sys.platform
  # Linux is matched by prefix: sys.platform is 'linux2' (or 'linux3') on
  # older interpreters and plain 'linux' from Python 3.3 onwards.
  if platform.startswith('linux'):
    return 'linux'
  return {
      'darwin': 'mac',
      'cygwin': 'win',
      'win32':  'win'
      }[platform]
50
51
def DictToJSON(pydict):
  """Serialize |pydict| as pretty-printed JSON text.

  Keys are sorted and nesting is indented by two spaces. Every line has its
  trailing whitespace removed and the text always ends with a newline,
  neither of which json.dumps guarantees on its own.
  """
  serialized = json.dumps(pydict, sort_keys=True, indent=2)
  cleaned = [row.rstrip() for row in serialized.split('\n')]
  # An empty final element makes the join emit the terminating newline.
  cleaned.append('')
  return '\n'.join(cleaned)
59
60
def DownloadAndComputeHash(from_stream, to_stream=None, progress_func=None):
  '''Download the archive data from from-stream and generate sha1 and
  size info.

  Args:
    from_stream:   An input stream that supports read.
    to_stream:     [optional] the data is written to to_stream if it is
                   provided.
    progress_func: [optional] A function used to report download progress. If
                   provided, progress_func is called with progress=0 at the
                   beginning of the download, with the number of bytes read
                   so far (passed positionally) after each chunk, and with
                   progress=100 at the end.

  Return
    A tuple (sha1, size) where sha1 is a sha1-hash for the archive data and
    size is the size of the archive data in bytes.'''
  # Use a no-op progress function if none is specified.
  def progress_no_op(progress):
    pass
  if progress_func is None:
    progress_func = progress_no_op

  sha1_hash = hashlib.sha1()
  size = 0
  progress_func(progress=0)
  while True:
    data = from_stream.read(32768)
    if not data:
      break
    sha1_hash.update(data)
    size += len(data)
    # Compare against None rather than testing truthiness: a stream object
    # may evaluate to False (e.g. one defining __len__ that is still empty)
    # and must receive the data all the same.
    if to_stream is not None:
      to_stream.write(data)
    progress_func(size)

  progress_func(progress=100)
  return sha1_hash.hexdigest(), size
98
99
class Error(Exception):
  """The generic exception type used by the manifest_util module"""
103
104
class Archive(dict):
  """A placeholder for sdk archive information. We derive Archive from
     dict so that it is easily serializable.

  Entries can also be read and written as attributes (archive.url instead of
  archive['url']). The 'checksum' attribute is special: it reads and writes
  the sha1 digest stored inside the checksum dict, not the dict itself."""

  def __init__(self, host_os_name):
    """ Create a new archive for the given host-os name.

    Args:
      host_os_name: The value stored under the 'host_os' key."""
    super(Archive, self).__init__()
    self['host_os'] = host_os_name

  def CopyFrom(self, src):
    """Update the content of the archive by copying values from the given
       dictionary.

    Args:
      src: The dictionary whose values must be copied to the archive."""
    for key, value in src.items():
      self[key] = value

  def Validate(self, error_on_unknown_keys=False):
    """Validate the content of the archive object. Raise an Error if
       an invalid or missing field is found.

    Args:
      error_on_unknown_keys: If True, raise an Error when unknown keys are
      found in the archive.
    """
    host_os = self.get('host_os', None)
    if host_os and host_os not in HOST_OS_LITERALS:
      raise Error('Invalid host-os name in archive')
    # Ensure host_os has a valid string. We'll use it for pretty printing.
    if not host_os:
      host_os = 'all (default)'
    if not self.get('url', None):
      raise Error('Archive "%s" has no URL' % host_os)
    if not self.get('size', None):
      raise Error('Archive "%s" has no size' % host_os)
    checksum = self.get('checksum', None)
    # The dict case is tested first so that an empty dict is reported as
    # such; a plain falsiness test up front would report it as missing.
    if isinstance(checksum, dict):
      if not checksum:
        raise Error('Archive "%s" has an empty checksum dict' % host_os)
    elif not checksum:
      raise Error('Archive "%s" has no checksum' % host_os)
    else:
      raise Error('Archive "%s" has a checksum, but it is not a dict' % host_os)
    # Verify that all key names are valid.
    if error_on_unknown_keys:
      for key in self:
        if key not in VALID_ARCHIVE_KEYS:
          raise Error('Archive "%s" has invalid attribute "%s"' % (
              host_os, key))

  def UpdateVitals(self, revision):
    """Update the size and checksum information for this archive
    based on the content currently at the URL.

    The archive's url is treated as a string.Template: its 'revision'
    placeholder is replaced with |revision| and the expanded url is stored
    back into the archive before the content is downloaded.

    This allows the template manifest to be maintained without
    the need for size and checksums to be present.

    Args:
      revision: The value substituted for the 'revision' placeholder.
    """
    template = string.Template(self['url'])
    self['url'] = template.substitute({'revision': revision})
    from_stream = urllib2.urlopen(self['url'])
    # Always release the connection, even when the download fails part way.
    try:
      sha1_hash, size = DownloadAndComputeHash(from_stream)
    finally:
      from_stream.close()
    self['size'] = size
    self['checksum'] = { 'sha1': sha1_hash }

  def __getattr__(self, name):
    """Retrieve values from this dict using attributes.

    This allows for foo.bar instead of foo['bar'].

    Args:
      name: the name of the key, 'bar' in the example above.
    Returns:
      The value associated with that key.
    Raises:
      AttributeError: if there is no such key."""
    if name not in self:
      raise AttributeError(name)
    # special case, self.checksum returns the sha1, not the checksum dict.
    if name == 'checksum':
      return self.GetChecksum()
    return self.__getitem__(name)

  def __setattr__(self, name, value):
    """Set values in this dict using attributes.

    This allows for foo.bar instead of foo['bar'].

    Args:
      name: The name of the key, 'bar' in the example above.
      value: The value to associate with that key."""
    # special case, self.checksum sets the sha1, not the checksum dict.
    if name == 'checksum':
      self.setdefault('checksum', {})['sha1'] = value
      return
    return self.__setitem__(name, value)

  def GetChecksum(self, hash_type='sha1'):
    """Returns a given cryptographic checksum of the archive

    Args:
      hash_type: The key to look up in the archive's checksum dict."""
    return self['checksum'][hash_type]
202
203
class Bundle(dict):
  """A placeholder for sdk bundle information. We derive Bundle from
     dict so that it is easily serializable.

  Entries can also be read and written as attributes (bundle.name instead of
  bundle['name'])."""

  def __init__(self, obj):
    """ Create a new bundle.

    Args:
      obj: Either the bundle name (a string), in which case the bundle starts
          out with that name and an empty archive list, or an object accepted
          by the dict constructor, whose entries become the bundle's
          entries."""
    if isinstance(obj, str) or isinstance(obj, unicode):
      dict.__init__(self, [(ARCHIVES_KEY, []), (NAME_KEY, obj)])
    else:
      dict.__init__(self, obj)

  def MergeWithBundle(self, bundle):
    """Merge this bundle with |bundle|.

    Merges dict in |bundle| with this one in such a way that keys are not
    duplicated: the values of the keys in |bundle| take precedence in the
    resulting dictionary.

    Archives in |bundle| will be appended to archives in self.

    Args:
      bundle: The other bundle.  Must be a dict.
    """
    assert self is not bundle

    for k, v in bundle.items():
      if k == ARCHIVES_KEY:
        # setdefault (not get) so that the list being extended is stored in
        # this bundle even when it had no archive list of its own.
        self.setdefault(k, []).extend(v)
      else:
        self[k] = v

  def __str__(self):
    return self.GetDataAsString()

  def GetDataAsString(self):
    """Returns the JSON bundle object, pretty-printed"""
    return DictToJSON(self)

  def LoadDataFromString(self, json_string):
    """Load a JSON bundle string. Raises an exception if json_string
       is not well-formed JSON.

    Args:
      json_string: a JSON-formatted string containing the bundle
    """
    self.CopyFrom(json.loads(json_string))

  def CopyFrom(self, source):
    """Update the content of the bundle by copying values from the given
       dictionary.

    The archive list is rebuilt as a list of Archive instances; every other
    value is stored as-is.

    Args:
      source: The dictionary whose values must be copied to the bundle."""
    for key, value in source.items():
      if key == ARCHIVES_KEY:
        archives = []
        for a in value:
          new_archive = Archive(a['host_os'])
          new_archive.CopyFrom(a)
          archives.append(new_archive)
        self[ARCHIVES_KEY] = archives
      else:
        self[key] = value

  def Validate(self, add_missing_info=False, error_on_unknown_keys=False):
    """Validate the content of the bundle. Raise an Error if an invalid or
       missing field is found.

    Args:
      add_missing_info: If True, an archive without a 'size' entry gets its
      size and checksum filled in by Archive.UpdateVitals before it is
      validated.
      error_on_unknown_keys: If True, raise an Error when unknown keys are
      found in the bundle.
    """
    # Check required fields.
    if not self.get(NAME_KEY):
      raise Error('Bundle has no name')
    if self.get(REVISION_KEY) is None:
      raise Error('Bundle "%s" is missing a revision number' % self[NAME_KEY])
    if self.get(VERSION_KEY) is None:
      raise Error('Bundle "%s" is missing a version number' % self[NAME_KEY])
    if not self.get('description'):
      raise Error('Bundle "%s" is missing a description' % self[NAME_KEY])
    if not self.get('stability'):
      raise Error('Bundle "%s" is missing stability info' % self[NAME_KEY])
    if self.get('recommended') is None:
      raise Error('Bundle "%s" is missing the recommended field' %
                  self[NAME_KEY])
    # Check specific values
    if self['stability'] not in STABILITY_LITERALS:
      raise Error('Bundle "%s" has invalid stability field: "%s"' %
                  (self[NAME_KEY], self['stability']))
    if self['recommended'] not in YES_NO_LITERALS:
      raise Error(
          'Bundle "%s" has invalid recommended field: "%s"' %
          (self[NAME_KEY], self['recommended']))
    # Verify that all key names are valid.
    if error_on_unknown_keys:
      for key in self:
        if key not in VALID_BUNDLES_KEYS:
          raise Error('Bundle "%s" has invalid attribute "%s"' %
                      (self[NAME_KEY], key))
    # Validate the archives
    for archive in self[ARCHIVES_KEY]:
      if add_missing_info and 'size' not in archive:
        archive.UpdateVitals(self[REVISION_KEY])
      archive.Validate(error_on_unknown_keys)

  def GetArchive(self, host_os_name):
    """Retrieve the archive for the given host os.

    The first archive whose host_os is either |host_os_name| or 'all' is
    returned.

    Args:
      host_os_name: name of host os whose archive must be retrieved.
    Return:
      An Archive instance or None if it doesn't exist."""
    for archive in self[ARCHIVES_KEY]:
      if archive.host_os == host_os_name or archive.host_os == 'all':
        return archive
    return None

  def GetHostOSArchive(self):
    """Retrieve the archive for the current host os."""
    return self.GetArchive(GetHostOS())

  def GetHostOSArchives(self):
    """Retrieve all archives for the current host os, or marked all.
    """
    return [archive for archive in self.GetArchives()
        if archive.host_os in (GetHostOS(), 'all')]

  def GetArchives(self):
    """Returns all the archives in this bundle"""
    return self[ARCHIVES_KEY]

  def AddArchive(self, archive):
    """Add an archive to this bundle."""
    self[ARCHIVES_KEY].append(archive)

  def RemoveAllArchives(self):
    """Remove all archives from this Bundle."""
    del self[ARCHIVES_KEY][:]

  def RemoveAllArchivesForHostOS(self, host_os_name):
    """Remove every archive for the given host os from this Bundle.

    Args:
      host_os_name: The host os whose archives are removed. 'all' removes
      every archive regardless of its host os; any other name removes only
      the archives whose host_os is exactly that name."""
    archives = self[ARCHIVES_KEY]
    if host_os_name == 'all':
      del archives[:]
    else:
      # Rebuild the list contents in one step. Deleting entries while
      # iterating over the same list would skip the entry following each
      # deletion and leave adjacent matches behind.
      archives[:] = [archive for archive in archives
                     if archive.host_os != host_os_name]

  def __getattr__(self, name):
    """Retrieve values from this dict using attributes.

    This allows for foo.bar instead of foo['bar'].

    Args:
      name: the name of the key, 'bar' in the example above.
    Returns:
      The value associated with that key.
    Raises:
      AttributeError: if there is no such key."""
    if name not in self:
      raise AttributeError(name)
    return self.__getitem__(name)

  def __setattr__(self, name, value):
    """Set values in this dict using attributes.

    This allows for foo.bar instead of foo['bar'].

    Args:
      name: The name of the key, 'bar' in the example above.
      value: The value to associate with that key."""
    self.__setitem__(name, value)

  def __eq__(self, bundle):
    """Test if two bundles are equal.

    Normally the default comparison for two dicts is fine, but in this case we
    don't care about the list order of the archives.

    Args:
      bundle: The other bundle to compare against.
    Returns:
      True if the bundles are equal."""
    if not isinstance(bundle, Bundle):
      return False
    if len(self) != len(bundle):
      return False
    for key in self:
      if key not in bundle:
        return False
      # special comparison for ARCHIVES_KEY because we don't care about the
      # list ordering.
      if key == ARCHIVES_KEY:
        if len(self[key]) != len(bundle[key]):
          return False
        # Pair every archive with a distinct, equal archive of the other
        # bundle. Looking archives up by host os is not enough: several
        # archives may share a host os or be marked 'all'.
        unmatched = list(bundle[key])
        for archive in self[key]:
          if archive not in unmatched:
            return False
          unmatched.remove(archive)
      elif self[key] != bundle[key]:
        return False
    return True

  def __ne__(self, bundle):
    """Test if two bundles are unequal.

    See __eq__ for more info."""
    return not self.__eq__(bundle)
411
412
class SDKManifest(object):
  """This class contains utilities for manipulating an SDK manifest string

  For ease of unit-testing, this class should not contain any file I/O.
  """

  def __init__(self):
    """Create a new SDKManifest object with default contents"""
    self._manifest_data = {
        "manifest_version": MANIFEST_VERSION,
        "bundles": [],
        }

  def Validate(self, add_missing_info=False):
    """Validate the Manifest file and raises an exception for problems

    Args:
      add_missing_info: Passed on to the Validate method of every bundle.
    """
    # Validate the manifest top level
    if self._manifest_data["manifest_version"] > MANIFEST_VERSION:
      raise Error("Manifest version too high: %s" %
                  self._manifest_data["manifest_version"])
    # Verify that all key names are valid.
    for key in self._manifest_data:
      if key not in VALID_MANIFEST_KEYS:
        raise Error('Manifest has invalid attribute "%s"' % key)
    # Validate each bundle
    for bundle in self._manifest_data[BUNDLES_KEY]:
      bundle.Validate(add_missing_info)

  def GetBundle(self, name):
    """Get a bundle from the array of bundles.

    A warning is written to stderr when several bundles share the name.

    Args:
      name: the name of the bundle to return.
    Return:
      The first bundle with the given name, or None if it is not found."""
    if BUNDLES_KEY not in self._manifest_data:
      return None
    bundles = [bundle for bundle in self._manifest_data[BUNDLES_KEY]
               if bundle[NAME_KEY] == name]
    if len(bundles) > 1:
      sys.stderr.write("WARNING: More than one bundle with name "
                       "'%s' exists.\n" % name)
    return bundles[0] if bundles else None

  def GetBundles(self):
    """Return all the bundles in the manifest."""
    return self._manifest_data[BUNDLES_KEY]

  def SetBundle(self, new_bundle):
    """Add or replace a bundle in the manifest.

    Note: If a bundle in the manifest already exists with this name, it will be
    overwritten with a copy of this bundle, at the same index as the original.

    Args:
      new_bundle: The bundle. A deep copy of it is stored.
    """
    name = new_bundle[NAME_KEY]
    bundles = self.GetBundles()
    new_bundle_copy = copy.deepcopy(new_bundle)
    for i, bundle in enumerate(bundles):
      if bundle[NAME_KEY] == name:
        bundles[i] = new_bundle_copy
        return
    # Bundle not already in list, append it.
    bundles.append(new_bundle_copy)

  def RemoveBundle(self, name):
    """Remove a bundle by name.

    Only the first bundle with that name is removed.

    Args:
      name: the name of the bundle to remove.
    Return:
      True if the bundle was removed, False if there is no bundle with that
      name.
    """
    if BUNDLES_KEY not in self._manifest_data:
      return False
    bundles = self._manifest_data[BUNDLES_KEY]
    for i, bundle in enumerate(bundles):
      if bundle[NAME_KEY] == name:
        # Safe to delete during iteration because we return right away.
        del bundles[i]
        return True
    return False

  def BundleNeedsUpdate(self, bundle):
    """Decides if a bundle needs to be updated.

    A bundle needs to be updated if it is not installed (doesn't exist in this
    manifest file) or if its (version, revision) pair is later than the pair
    recorded in this file.

    Args:
      bundle: The Bundle to test.
    Returns:
      True if Bundle needs to be updated.
    Raises:
      KeyError: if |bundle| has no 'name' key.
    """
    if NAME_KEY not in bundle:
      raise KeyError("Bundle must have a 'name' key.")
    local_bundle = self.GetBundle(bundle[NAME_KEY])
    return (local_bundle is None) or (
           (local_bundle[VERSION_KEY], local_bundle[REVISION_KEY]) <
           (bundle[VERSION_KEY], bundle[REVISION_KEY]))

  def MergeBundle(self, bundle, allow_existing=True):
    """Merge a Bundle into this manifest.

    The new bundle is added if not present, or merged into the existing bundle.

    Args:
      bundle: The bundle to merge.
      allow_existing: If False, raise an Error instead of merging when a
      bundle with the same name is already in the manifest.
    Raises:
      KeyError: if |bundle| has no 'name' key.
    """
    if NAME_KEY not in bundle:
      raise KeyError("Bundle must have a 'name' key.")
    # Index by key rather than by attribute so that a plain dict, which the
    # check above accepts, works here as well.
    name = bundle[NAME_KEY]
    local_bundle = self.GetBundle(name)
    if local_bundle is None:
      self.SetBundle(bundle)
    else:
      if not allow_existing:
        raise Error('cannot merge manifest bundle \'%s\', it already exists'
                    % name)
      local_bundle.MergeWithBundle(bundle)

  def MergeManifest(self, manifest):
    '''Merge another manifest into this manifest, disallowing overriding.

    Args
      manifest: The manifest to merge.
    '''
    for bundle in manifest.GetBundles():
      self.MergeBundle(bundle, allow_existing=False)

  def FilterBundles(self, predicate):
    """Filter the list of bundles by |predicate|.

    For all bundles in this manifest, if predicate(bundle) is False, the bundle
    is removed from the manifest.

    Args:
      predicate: a function that takes a bundle and returns True to keep it
      or False to remove it.
    """
    # Materialize the result: the bundles must stay a real list because the
    # other methods index, append to and delete from it.
    self._manifest_data[BUNDLES_KEY] = list(
        filter(predicate, self.GetBundles()))

  def LoadDataFromString(self, json_string, add_missing_info=False):
    """Load a JSON manifest string. Raises an exception if json_string
       is not well-formed JSON.

    The loaded manifest is validated; every bundle in it is converted to a
    Bundle instance.

    Args:
      json_string: a JSON-formatted string containing the previous manifest
      add_missing_info: Passed on to Validate once the data is loaded."""
    new_manifest = json.loads(json_string)
    for key, value in new_manifest.items():
      if key == BUNDLES_KEY:
        # Remap each bundle in |value| to a Bundle instance
        bundles = []
        for b in value:
          new_bundle = Bundle(b[NAME_KEY])
          new_bundle.CopyFrom(b)
          bundles.append(new_bundle)
        self._manifest_data[key] = bundles
      else:
        self._manifest_data[key] = value
    self.Validate(add_missing_info)

  def __str__(self):
    return self.GetDataAsString()

  def __eq__(self, other):
    """Test if two manifests are equal.

    Manifests are equal when they have the same manifest version and the same
    bundles; the order of the bundles is ignored.

    Args:
      other: The object to compare against.
    Returns:
      True if |other| is an equal SDKManifest, False otherwise."""
    # Anything that is not a manifest is simply unequal; reading its
    # _manifest_data would raise AttributeError.
    if not isinstance(other, SDKManifest):
      return False
    # Access to protected member _manifest_data of a client class
    # pylint: disable=W0212
    if (self._manifest_data['manifest_version'] !=
        other._manifest_data['manifest_version']):
      return False

    self_bundle_names = set(b.name for b in self.GetBundles())
    other_bundle_names = set(b.name for b in other.GetBundles())
    if self_bundle_names != other_bundle_names:
      return False

    for bundle_name in self_bundle_names:
      if self.GetBundle(bundle_name) != other.GetBundle(bundle_name):
        return False

    return True

  def __ne__(self, other):
    """Test if two manifests are unequal. See __eq__ for more info."""
    return not (self == other)

  def GetDataAsString(self):
    """Returns the current JSON manifest object, pretty-printed"""
    return DictToJSON(self._manifest_data)
604