# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-argument, too-many-ancestors
"""Gluon Batch Processor for Estimators"""

from ...utils import split_and_load
from .... import autograd

__all__ = ['BatchProcessor']


class BatchProcessor(object):
    """Pluggable per-minibatch training and validation logic for estimators.

    During training or validation, data are divided into minibatches for
    processing. This class provides the hooks that run on a single minibatch:
    :py:meth:`fit_batch` for training and :py:meth:`evaluate_batch` for
    validation. Users may customize either step by inheriting from this class
    and overriding the corresponding method.

    :py:class:`BatchProcessor` can be used to replace fit_batch() and
    evaluate_batch() in the base estimator class.
    """

    def __init__(self):
        """Create a batch processor; the base class holds no state."""

    def _get_data_and_label(self, batch, ctx, batch_axis=0):
        """Split one batch into per-context shards of data and label.

        Parameters
        ----------
        batch : sequence
            Batch whose element 0 is the data and element 1 is the label.
        ctx : list
            Contexts forwarded to `split_and_load` as `ctx_list`.
        batch_axis : int, default 0
            Axis along which the batch is split.

        Returns
        -------
        data : list
            Result of `split_and_load` applied to the batch data.
        label : list
            Result of `split_and_load` applied to the batch label.
        """
        raw_data, raw_label = batch[0], batch[1]
        sharded_data = split_and_load(raw_data, ctx_list=ctx, batch_axis=batch_axis)
        sharded_label = split_and_load(raw_label, ctx_list=ctx, batch_axis=batch_axis)
        return sharded_data, sharded_label

    def evaluate_batch(self, estimator,
                       val_batch,
                       batch_axis=0):
        """Evaluate the estimator model on a batch of validation data.

        Parameters
        ----------
        estimator : Estimator
            Reference to the estimator
        val_batch : tuple
            Data and label of a batch from the validation data loader.
        batch_axis : int, default 0
            Batch axis to split the validation data into devices.

        Returns
        -------
        data: list
            Sharded data from the batch, one entry per context in
            `estimator.context`.
        label: list
            Sharded label from the batch, aligned with `data`.
        pred: list
            Output of `estimator.val_net` on each data shard.
        loss: list
            Output of `estimator.val_loss` on each (prediction, label) pair.
        """
        inputs, targets = self._get_data_and_label(
            val_batch, estimator.context, batch_axis)

        # Forward pass only: validation does not record gradients.
        outputs = [estimator.val_net(shard) for shard in inputs]
        losses = [
            estimator.val_loss(output, target)
            for output, target in zip(outputs, targets)
        ]

        return inputs, targets, outputs, losses

    def fit_batch(self, estimator,
                  train_batch,
                  batch_axis=0):
        """Trains the estimator model on a batch of training data.

        Runs the forward pass and loss computation under `autograd.record()`
        and then calls `backward()` on every per-shard loss. No parameter
        update is performed in this method.

        Parameters
        ----------
        estimator : Estimator
            Reference to the estimator
        train_batch : tuple
            Data and label of a batch from the training data loader.
        batch_axis : int, default 0
            Batch axis to split the training data into devices.

        Returns
        -------
        data: List of NDArray
            Sharded data from the batch. Data is sharded with
            `gluon.split_and_load`.
        label: List of NDArray
            Sharded label from the batch. Labels are sharded with
            `gluon.split_and_load`.
        pred: List of NDArray
            Prediction on each of the sharded inputs.
        loss: List of NDArray
            Loss on each of the sharded inputs.
        """
        inputs, targets = self._get_data_and_label(
            train_batch, estimator.context, batch_axis)

        # Both the forward pass and the loss must be recorded so that
        # backward() below can propagate through them.
        with autograd.record():
            outputs = [estimator.net(shard) for shard in inputs]
            losses = [
                estimator.loss(output, target)
                for output, target in zip(outputs, targets)
            ]

        # Backpropagate each shard's loss separately, outside the recording scope.
        for shard_loss in losses:
            shard_loss.backward()

        return inputs, targets, outputs, losses