# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-argument, too-many-ancestors
"""Gluon Batch Processor for Estimators"""

from ...utils import split_and_load
from .... import autograd

__all__ = ['BatchProcessor']

class BatchProcessor(object):
    """BatchProcessor class for plug-and-play fit_batch & evaluate_batch.

    During training or validation, data are divided into minibatches for processing.
    This class provides hooks for training or validating on a single minibatch of
    data. Users may supply customized fit_batch() and evaluate_batch() methods by
    inheriting from this class and overriding them.

    :py:class:`BatchProcessor` can be used to replace fit_batch() and evaluate_batch()
    in the base estimator class.
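
    Examples
    --------
    A minimal sketch of a customized processor; the loss scaling below is an
    illustrative assumption, not part of this API::

        class ScaledLossProcessor(BatchProcessor):
            def fit_batch(self, estimator, train_batch, batch_axis=0):
                data, label = self._get_data_and_label(
                    train_batch, estimator.context, batch_axis)
                with autograd.record():
                    pred = [estimator.net(x) for x in data]
                    # assumed: scale each shard's loss before backward
                    loss = [estimator.loss(p, y) * 0.5
                            for p, y in zip(pred, label)]
                for l in loss:
                    l.backward()
                return data, label, pred, loss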
    """

    def __init__(self):
        pass

    def _get_data_and_label(self, batch, ctx, batch_axis=0):
        """Split a (data, label) batch along `batch_axis` and load each
        shard onto the corresponding device in `ctx`."""
        data = batch[0]
        label = batch[1]
        data = split_and_load(data, ctx_list=ctx, batch_axis=batch_axis)
        label = split_and_load(label, ctx_list=ctx, batch_axis=batch_axis)
        return data, label

    def evaluate_batch(self, estimator,
                       val_batch,
                       batch_axis=0):
        """Evaluate the estimator model on a batch of validation data.

        Parameters
        ----------
        estimator : Estimator
            Reference to the estimator
        val_batch : tuple
            Data and label of a batch from the validation data loader.
        batch_axis : int, default 0
            Batch axis along which to split the validation data across devices.

        Returns
        -------
        data : list of NDArray
            Sharded data from the batch, split with `gluon.utils.split_and_load`.
        label : list of NDArray
            Sharded labels from the batch, split with `gluon.utils.split_and_load`.
        pred : list of NDArray
            Predictions on each of the sharded inputs.
        loss : list of NDArray
            Loss on each of the sharded inputs.
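
        Examples
        --------
        A minimal sketch of a validation loop driving this hook; ``est`` and
        ``val_data`` are illustrative assumptions::

            processor = BatchProcessor()
            for batch in val_data:
                data, label, pred, loss = processor.evaluate_batch(est, batch)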
        """
        data, label = self._get_data_and_label(val_batch, estimator.context, batch_axis)
        pred = [estimator.val_net(x) for x in data]
        loss = [estimator.val_loss(y_hat, y) for y_hat, y in zip(pred, label)]

        return data, label, pred, loss

    def fit_batch(self, estimator,
                  train_batch,
                  batch_axis=0):
        """Train the estimator model on a batch of training data.

        Parameters
        ----------
        estimator : Estimator
            Reference to the estimator
        train_batch : tuple
            Data and label of a batch from the training data loader.
        batch_axis : int, default 0
            Batch axis along which to split the training data across devices.

        Returns
        -------
        data : list of NDArray
            Sharded data from the batch, split with `gluon.utils.split_and_load`.
        label : list of NDArray
            Sharded labels from the batch, split with `gluon.utils.split_and_load`.
        pred : list of NDArray
            Predictions on each of the sharded inputs.
        loss : list of NDArray
            Loss on each of the sharded inputs.
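
        Examples
        --------
        A minimal sketch of a training loop driving this hook; ``est``,
        ``train_data``, and the ``est.trainer.step`` call are illustrative
        assumptions (parameter updates happen outside ``fit_batch``)::

            processor = BatchProcessor()
            for batch in train_data:
                data, label, pred, loss = processor.fit_batch(est, batch)
                est.trainer.step(batch[0].shape[0])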
        """
        data, label = self._get_data_and_label(train_batch, estimator.context, batch_axis)

        # Record the forward pass so gradients can be computed.
        with autograd.record():
            pred = [estimator.net(x) for x in data]
            loss = [estimator.loss(y_hat, y) for y_hat, y in zip(pred, label)]

        # Backpropagate the loss computed on each device.
        for l in loss:
            l.backward()

        return data, label, pred, loss