1# Copyright 2016 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Classes for dealing with I/O from ML pipelines.
15"""
16
17import csv
18import datetime
19import json
20import logging
21
22import apache_beam as beam
23from six.moves import cStringIO
24import yaml
25
26from google.cloud.ml.util import _decoders
27from google.cloud.ml.util import _file
28
29
30# TODO(user): Use a ProtoCoder once b/29055158 is resolved.
class ExampleProtoCoder(beam.coders.Coder):
  """A coder to encode and decode TensorFlow Example objects."""

  def __init__(self):
    # Imported lazily so that merely loading this module does not require
    # TensorFlow to be installed.
    import tensorflow as tf  # pylint: disable=g-import-not-at-top
    self._tf_train = tf.train

  def encode(self, example_proto):
    """Serializes a TensorFlow Example object into a string.

    Args:
      example_proto: A Tensorflow Example object

    Returns:
      String.
    """
    return example_proto.SerializeToString()

  def decode(self, serialized_str):
    """Parses a serialized string back into a TensorFlow Example object.

    Args:
      serialized_str: string

    Returns:
      Tensorflow Example object.
    """
    parsed_example = self._tf_train.Example()
    parsed_example.ParseFromString(serialized_str)
    return parsed_example
61
62
class JsonCoder(beam.coders.Coder):
  """A coder to encode and decode JSON formatted data."""

  def __init__(self, indent=None):
    """Initializes the coder.

    Args:
      indent: Optional indent level forwarded to json.dumps; None yields the
          most compact, single-line output.
    """
    self._indent = indent

  def encode(self, obj):
    """Encodes a python object into a JSON string.

    Args:
      obj: python object.

    Returns:
      JSON string.
    """
    # Explicit separators stop json.dumps from emitting trailing whitespace
    # after commas when indentation is enabled.
    return json.dumps(obj, separators=(',', ': '), indent=self._indent)

  def decode(self, json_string):
    """Decodes a JSON string to a python object.

    Args:
      json_string: A JSON string.

    Returns:
      A python object.
    """
    return json.loads(json_string)
91
92
class CsvCoder(beam.coders.Coder):
  """A coder to encode and decode CSV formatted data.
  """

  class _WriterWrapper(object):
    """A picklable wrapper around csv.writer / csv.DictWriter."""

    def __init__(self, column_names, delimiter, decode_to_dict):
      # The constructor arguments are remembered so that pickling can rebuild
      # the writer: the underlying csv writer and its buffer can't be pickled.
      self._state = (column_names, delimiter, decode_to_dict)
      self._buffer = cStringIO()
      # lineterminator='' means each record is encoded as a single line with
      # no trailing newline.
      if decode_to_dict:
        self._writer = csv.DictWriter(
            self._buffer, column_names, lineterminator='', delimiter=delimiter)
      else:
        self._writer = csv.writer(
            self._buffer, lineterminator='', delimiter=delimiter)

    def encode_record(self, record):
      self._writer.writerow(record)
      encoded = self._buffer.getvalue()
      # Rewind and clear the buffer so it can be reused for the next record.
      self._buffer.seek(0)
      self._buffer.truncate(0)
      return encoded

    def __getstate__(self):
      return self._state

    def __setstate__(self, state):
      # Rebuild the unpicklable writer/buffer from the saved arguments.
      self.__init__(*state)

  def __init__(self, column_names, numeric_column_names, delimiter=',',
               decode_to_dict=True, fail_on_error=True,
               skip_initial_space=False):
    """Initializes CsvCoder.

    Args:
      column_names: Tuple of strings. Order must match the order in the file.
      numeric_column_names: Tuple of strings. Contains column names that are
          numeric. Every name in numeric_column_names must also be in
          column_names.
      delimiter: A one-character string used to separate fields.
      decode_to_dict: Boolean indicating whether the decoder should generate a
          dictionary instead of a raw sequence. True by default.
      fail_on_error: Whether to fail if a corrupt row is found. Default is True.
      skip_initial_space: When True, whitespace immediately following the
          delimiter is ignored when reading.
    """
    self._decoder = _decoders.CsvDecoder(
        column_names, numeric_column_names, delimiter, decode_to_dict,
        fail_on_error, skip_initial_space)
    self._encoder = self._WriterWrapper(
        column_names=column_names,
        delimiter=delimiter,
        decode_to_dict=decode_to_dict)

  def decode(self, csv_line):
    """Decode one csv line into a python collection.

    Args:
      csv_line: String. One csv line from the file.

    Returns:
      Python dict where the keys are the column names from the file. The dict
      values are strings or numbers depending if a column name was listed in
      numeric_column_names. Missing string columns have the value '', while
      missing numeric columns have the value None. If there is an error in
      parsing csv_line, a python dict is returned where every value is '' or
      None.

    Raises:
      Exception: The number of columns to not match.
    """
    return self._decoder.decode(csv_line)

  def encode(self, python_data):
    """Encode a python collection to a csv-formatted string.

    Args:
      python_data: A python collection, depending on the value of decode_to_dict
          it will be a python dictionary where the keys are the column names or
          a sequence.

    Returns:
      A csv-formatted string. The order of the columns is given by column_names.
    """
    return self._encoder.encode_record(python_data)
185
186
class YamlCoder(beam.coders.Coder):
  """A coder to encode and decode YAML formatted data."""

  def __init__(self):
    """Selects the most efficient available YAML implementation.

    The libyaml C bindings are preferred when PyYAML was built with them.
    Otherwise we fall back to the native yaml library; use with caution, as it
    is far less efficient, uses excessive memory, and leaks memory.
    """
    # TODO(user): Always use libyaml once possible.
    if not yaml.__with_libyaml__:
      logging.warning(
          'Can\'t find libyaml so it is not used for YamlCoder, the '
          'implementation used is far slower and has a memory leak.')
      self._safe_dumper = yaml.SafeDumper
      self._safe_loader = yaml.SafeLoader
    else:
      self._safe_dumper = yaml.CSafeDumper
      self._safe_loader = yaml.CSafeLoader

  def encode(self, obj):
    """Encodes a python object into a YAML string.

    Args:
      obj: python object.

    Returns:
      YAML string.
    """
    return yaml.dump(
        obj,
        Dumper=self._safe_dumper,
        default_flow_style=False,
        encoding='utf-8')

  def decode(self, yaml_string):
    """Decodes a YAML string to a python object.

    Args:
      yaml_string: A YAML string.

    Returns:
      A python object.
    """
    # Safe loader only: never use yaml.load without a Loader on this data.
    return yaml.load(yaml_string, Loader=self._safe_loader)
233
234
class MetadataCoder(beam.coders.Coder):
  """A coder to encode and decode CloudML metadata."""

  def encode(self, obj):
    """Encodes a python object into an indented JSON string.

    Args:
      obj: python object.

    Returns:
      JSON string.
    """
    return JsonCoder(indent=1).encode(obj)

  def decode(self, metadata_string):
    """Decodes a metadata string to a python object.

    Args:
      metadata_string: A metadata string, either in json or yaml format.

    Returns:
      A python object.
    """
    return self._decode_internal(metadata_string)

  @classmethod
  def load_from(cls, path):
    """Reads and decodes a metadata file.

    The content is assumed to be json by default; if that parse fails, it is
    retried as yaml.

    Args:
      path: A metadata file path string.

    Returns:
      A decoded metadata object.
    """
    return cls._decode_internal(_file.load_file(path))

  @staticmethod
  def _decode_internal(metadata_string):
    # Try json first; yaml is only the fallback format.
    try:
      return JsonCoder().decode(metadata_string)
    except ValueError:
      return YamlCoder().decode(metadata_string)
282
283
class TrainingJobRequestCoder(beam.coders.Coder):
  """Custom coder for a TrainingJobRequest object."""

  # Fields holding datetime.timedelta values; json can't serialize those
  # directly, so they travel on the wire as total seconds (a float).
  _TIMEDELTA_FIELDS = ('timeout', 'polling_interval')

  def encode(self, training_job_request):
    """Encode a TrainingJobRequest to a JSON string.

    Args:
      training_job_request: A TrainingJobRequest object.

    Returns:
      A JSON string
    """
    d = dict(training_job_request.__dict__)

    # Convert timedelta values to something json encodable. Use an explicit
    # None check (not truthiness): a zero-length timedelta is falsy but must
    # still be converted, otherwise json.dumps would raise TypeError.
    for k in self._TIMEDELTA_FIELDS:
      if d.get(k) is not None:
        d[k] = d[k].total_seconds()
    return json.dumps(d)

  def decode(self, training_job_request_string):
    """Decode a JSON string representing a TrainingJobRequest.

    Args:
      training_job_request_string: A string representing a TrainingJobRequest.

    Returns:
      TrainingJobRequest object.
    """
    r = TrainingJobRequest()
    d = json.loads(training_job_request_string)

    # Restore timedelta values from their total-seconds representation.
    # An explicit None check ensures 0 seconds round-trips to timedelta(0),
    # and .get() tolerates strings missing the key entirely.
    for k in self._TIMEDELTA_FIELDS:
      if d.get(k) is not None:
        d[k] = datetime.timedelta(seconds=d[k])

    r.__dict__.update(d)
    return r
324
325
class TrainingJobResultCoder(beam.coders.Coder):
  """Custom coder for TrainingJobResult."""

  def encode(self, training_job_result):
    """Encode a TrainingJobResult object into a JSON string.

    Args:
      training_job_result: A TrainingJobResult object.

    Returns:
      A JSON string
    """
    d = dict(training_job_result.__dict__)

    # The nested request object is not json serializable, so it is stored as
    # its own separately-encoded string.
    if d['training_request'] is not None:
      d['training_request'] = TrainingJobRequestCoder().encode(
          d['training_request'])
    return json.dumps(d)

  def decode(self, training_job_result_string):
    """Decode a string to a TrainingJobResult object.

    Args:
      training_job_result_string: A string representing a TrainingJobResult.

    Returns:
      A TrainingJobResult object.
    """
    d = json.loads(training_job_result_string)

    # The nested request was encoded separately; decode it back into an
    # object before populating the result.
    if d['training_request'] is not None:
      d['training_request'] = TrainingJobRequestCoder().decode(
          d['training_request'])

    result = TrainingJobResult()
    result.__dict__.update(d)
    return result
366
367
class TrainingJobRequest(object):
  """This class contains the parameters for running a training job.
  """

  def __init__(self,
               parent=None,
               job_name=None,
               job_args=None,
               package_uris=None,
               python_module=None,
               timeout=None,
               polling_interval=datetime.timedelta(seconds=30),
               scale_tier=None,
               hyperparameters=None,
               region=None,
               master_type=None,
               worker_type=None,
               ps_type=None,
               worker_count=None,
               ps_count=None,
               endpoint=None,
               runtime_version=None):
    """Construct an instance of TrainingJobRequest.

    Args:
      parent: The project name. This is named parent because the parent object
          of jobs is the project.
      job_name: A job name. This must be unique within the project.
      job_args: Additional arguments to pass to the job.
      package_uris: A list of URIs to tarballs with the training program.
      python_module: The module name of the python file within the tarball.
      timeout: A datetime.timedelta expressing the amount of time to wait before
          giving up. The timeout applies to a single invocation of the process
          method in TrainModelDo. A DoFn can be retried several times before a
          pipeline fails.
      polling_interval: A datetime.timedelta to represent the amount of time to
          wait between requests polling for the files.
      scale_tier: Google Cloud ML tier to run in.
      hyperparameters: (Optional) Hyperparameter config to use for the job.
      region: (Optional) Google Cloud region in which to run.
      master_type: Master type to use with a CUSTOM scale tier.
      worker_type: Worker type to use with a CUSTOM scale tier.
      ps_type: Parameter Server type to use with a CUSTOM scale tier.
      worker_count: Worker count to use with a CUSTOM scale tier.
      ps_count: Parameter Server count to use with a CUSTOM scale tier.
      endpoint: (Optional) The endpoint for the Cloud ML API.
      runtime_version: (Optional) the Google Cloud ML runtime version to use.

    """
    self.parent = parent
    self.job_name = job_name
    self.job_args = job_args
    self.python_module = python_module
    self.package_uris = package_uris
    self.scale_tier = scale_tier
    self.hyperparameters = hyperparameters
    self.region = region
    self.master_type = master_type
    self.worker_type = worker_type
    self.ps_type = ps_type
    self.worker_count = worker_count
    self.ps_count = ps_count
    self.timeout = timeout
    self.polling_interval = polling_interval
    self.endpoint = endpoint
    self.runtime_version = runtime_version

  @property
  def project(self):
    # Alias: the parent of a job is its project.
    return self.parent

  def copy(self):
    """Return a shallow copy of the object."""
    r = TrainingJobRequest()
    r.__dict__.update(self.__dict__)

    return r

  def __eq__(self, o):
    """Field-by-field value equality with another TrainingJobRequest."""
    for f in ['parent', 'job_name', 'job_args', 'package_uris', 'python_module',
              'timeout', 'polling_interval', 'endpoint', 'hyperparameters',
              'scale_tier', 'worker_type', 'ps_type', 'master_type', 'region',
              'ps_count', 'worker_count', 'runtime_version']:
      if getattr(self, f) != getattr(o, f):
        return False

    return True

  def __ne__(self, o):
    return not self == o

  def __repr__(self):
    fields = []
    # Use items() rather than the python-2-only iteritems() so this works
    # under both python 2 and python 3.
    for k, v in self.__dict__.items():
      fields.append('{0}={1}'.format(k, v))
    return 'TrainingJobRequest({0})'.format(', '.join(fields))
464
# Register the custom coder so Beam can (de)serialize TrainingJobRequest
# values when they cross pipeline stage boundaries.
beam.coders.registry.register_coder(TrainingJobRequest, TrainingJobRequestCoder)
467
468
class TrainingJobResult(object):
  """Result of training a model."""

  def __init__(self):
    # A copy of the training request that created the job.
    self.training_request = None

    # An instance of TrainingJobMetadata as returned by the API.
    self.training_job_metadata = None

    # At most one of error and training_job_result will be specified.
    # These fields will only be supplied if the job completed.
    # training_job_result will be provided if the job completed successfully
    # and error will be supplied otherwise.
    self.error = None
    self.training_job_result = None

  def __eq__(self, o):
    """Field-by-field value equality with another TrainingJobResult."""
    for f in ['training_request', 'training_job_metadata', 'error',
              'training_job_result']:
      if getattr(self, f) != getattr(o, f):
        return False

    return True

  def __ne__(self, o):
    return not self == o

  def __repr__(self):
    fields = []
    # Use items() rather than the python-2-only iteritems() so this works
    # under both python 2 and python 3.
    for k, v in self.__dict__.items():
      fields.append('{0}={1}'.format(k, v))
    return 'TrainingJobResult({0})'.format(', '.join(fields))
502
# Register the custom coder so Beam can (de)serialize TrainingJobResult
# values when they cross pipeline stage boundaries.
beam.coders.registry.register_coder(TrainingJobResult, TrainingJobResultCoder)
505