1import warnings
2
3import numpy
4
5import chainer
6from chainer.dataset.convert import concat_examples
7from chainer.functions.evaluation import accuracy
8from chainer.functions.loss import softmax_cross_entropy
9from chainer import cuda, Variable  # NOQA
10from chainer import reporter
11from chainer_chemistry.models.prediction.base import BaseForwardModel
12
13
def _argmax(*args):
    """Return the argmax along axis 1 of the first positional argument.

    Only the first argument is used; any additional positional
    arguments are ignored.
    """
    logits, *_ = args
    return chainer.functions.argmax(logits, axis=1)
17
18
class Classifier(BaseForwardModel):

    """A simple classifier model.

    This is an example of chain that wraps another chain. It computes the
    loss and accuracy based on a given input/label pair.

    Args:
        predictor (~chainer.Link): Predictor network.
        lossfun (function): Loss function.
        accfun (function): DEPRECATED. Please use `metrics_fun` instead.
        metrics_fun (function or dict or None): Function that computes metrics.
        label_key (int or str): Key to specify label variable from arguments.
            When it is ``int``, a variable in positional arguments is used.
            And when it is ``str``, a variable in keyword arguments is used.
        device (int or chainer._backend.Device):
             GPU device id of this Classifier to be used.
             -1 indicates to use in CPU.

    Attributes:
        predictor (~chainer.Link): Predictor network.
        lossfun (function): Loss function.
        accfun (function): DEPRECATED. Please use `metrics_fun` instead.
        y (~chainer.Variable): Prediction for the last minibatch.
        loss (~chainer.Variable): Loss value for the last minibatch.
        metrics (dict): Metrics computed in last minibatch
        compute_metrics (bool): If ``True``, compute metrics on the forward
            computation. The default value is ``True``.

    .. note::
        The differences between original `Classifier` class in chainer and
        chainer chemistry are as follows.
        1. `predict` and `predict_proba` methods are supported.
        2. `device` can be managed internally by the `Classifier`
        3. `accfun` is deprecated, `metrics_fun` is used instead.
        4. `metrics_fun` can be `dict` which specifies the metrics name as key
           and function as value.

    .. note::
        This link uses :func:`chainer.softmax_cross_entropy` with
        default arguments as a loss function (specified by ``lossfun``),
        if users do not explicitly change it. In particular, the loss function
        does not support double backpropagation.
        If you need second or higher order differentiation, you need to turn
        it on with ``enable_double_backprop=True``:

          >>> import chainer.functions as F
          >>> import chainer.links as L
          >>>
          >>> def lossfun(x, t):
          ...     return F.softmax_cross_entropy(
          ...         x, t, enable_double_backprop=True)
          >>>
          >>> predictor = L.Linear(10)
          >>> model = L.Classifier(predictor, lossfun=lossfun)

    """

    compute_metrics = True

    def __init__(self, predictor,
                 lossfun=softmax_cross_entropy.softmax_cross_entropy,
                 accfun=None, metrics_fun=accuracy.accuracy,
                 label_key=-1, device=-1):
        if not isinstance(label_key, (int, str)):
            raise TypeError('label_key must be int or str, but is %s' %
                            type(label_key))
        if accfun is not None:
            warnings.warn(
                'accfun is deprecated, please use metrics_fun instead')
            warnings.warn('overriding metrics by accfun...')
            # override metrics by accfun
            metrics_fun = accfun

        super(Classifier, self).__init__()
        self.lossfun = lossfun
        # Normalize `metrics_fun` into a dict of name -> function so that
        # __call__ can report an arbitrary number of metrics uniformly.
        if metrics_fun is None:
            self.compute_metrics = False
            self.metrics_fun = {}
        elif callable(metrics_fun):
            self.metrics_fun = {'accuracy': metrics_fun}
        elif isinstance(metrics_fun, dict):
            self.metrics_fun = metrics_fun
        else:
            # Bug fix: report the type of `metrics_fun` here, not `accfun`.
            raise TypeError('Unexpected type metrics_fun must be None or '
                            'Callable or dict. actual {}'.format(
                                type(metrics_fun)))
        self.y = None
        self.loss = None
        self.metrics = None
        self.label_key = label_key

        with self.init_scope():
            self.predictor = predictor

        # `initialize` must be called after `init_scope`.
        self.initialize(device)

    def _convert_to_scalar(self, value):
        """Converts an input value to a scalar if its type is a Variable,

        numpy or cupy array, otherwise it returns the value as it is.
        """
        if isinstance(value, Variable):
            value = value.array
        if numpy.isscalar(value):
            return value
        # Bug fix: the original compared against `numpy.array` (a function,
        # not a type), so the condition was always True. The intent is to
        # move non-numpy (e.g. cupy) arrays to host memory first.
        if not isinstance(value, numpy.ndarray):
            value = cuda.to_cpu(value)
        # `numpy.asscalar` is deprecated (removed in NumPy >= 1.23);
        # `ndarray.item()` is the documented equivalent.
        return value.item()

    def __call__(self, *args, **kwargs):
        """Computes the loss value for an input and label pair.

        It also computes accuracy and stores it to the attribute.

        Args:
            args (list of ~chainer.Variable): Input minibatch.
            kwargs (dict of ~chainer.Variable): Input minibatch.

        When ``label_key`` is ``int``, the corresponding element in ``args``
        is treated as ground truth labels. And when it is ``str``, the
        element in ``kwargs`` is used.
        The all elements of ``args`` and ``kwargs`` except the ground truth
        labels are features.
        It feeds features to the predictor and compare the result
        with ground truth labels.

        Returns:
            ~chainer.Variable: Loss value.

        """

        # --- Separate `args` and `t` ---
        if isinstance(self.label_key, int):
            if not (-len(args) <= self.label_key < len(args)):
                msg = 'Label key %d is out of bounds' % self.label_key
                raise ValueError(msg)
            t = args[self.label_key]
            if self.label_key == -1:
                args = args[:-1]
            else:
                args = args[:self.label_key] + args[self.label_key + 1:]
        elif isinstance(self.label_key, str):
            if self.label_key not in kwargs:
                msg = 'Label key "%s" is not found' % self.label_key
                raise ValueError(msg)
            t = kwargs[self.label_key]
            del kwargs[self.label_key]
        else:
            raise TypeError('Label key type {} not supported'
                            .format(type(self.label_key)))

        # Reset cached results before the forward pass so stale values are
        # never reported if the predictor or loss raises.
        self.y = None
        self.loss = None
        self.metrics = None
        self.y = self.predictor(*args, **kwargs)
        self.loss = self.lossfun(self.y, t)
        reporter.report(
            {'loss': self._convert_to_scalar(self.loss)}, self)
        if self.compute_metrics:
            # Note: self.accuracy is `dict`, which is different from original
            # chainer implementation
            self.metrics = {key: self._convert_to_scalar(value(self.y, t))
                            for key, value in self.metrics_fun.items()}
            reporter.report(self.metrics, self)
        return self.loss

    def predict_proba(
            self, data, batchsize=16, converter=concat_examples,
            retain_inputs=False, preprocess_fn=None,
            postprocess_fn=chainer.functions.softmax):
        """Calculate probability of each category.

        Args:
            data: "train_x array" or "chainer dataset"
            fn (Callable): Main function to forward. Its input argument is
                either Variable, cupy.ndarray or numpy.ndarray, and returns
                Variable.
            batchsize (int): batch size
            converter (Callable): convert from `data` to `inputs`
            preprocess_fn (Callable): Its input is numpy.ndarray or
                cupy.ndarray, it can return either Variable, cupy.ndarray or
                numpy.ndarray
            postprocess_fn (Callable): Its input argument is Variable,
                but this method may return either Variable, cupy.ndarray or
                numpy.ndarray.
            retain_inputs (bool): If True, this instance keeps inputs in
                `self.inputs` or not.

        Returns (tuple or numpy.ndarray): Typically, it is 2-dimensional float
            array with shape (batchsize, number of category) which represents
            each examples probability to be each category.

        """
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            proba = self._forward(
                data, fn=self.predictor, batchsize=batchsize,
                converter=converter, retain_inputs=retain_inputs,
                preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn)
        return proba

    def predict(
            self, data, batchsize=16, converter=concat_examples,
            retain_inputs=False, preprocess_fn=None, postprocess_fn=_argmax):
        """Predict label of each category by taking argmax of the output.

        Args:
            data: input data
            batchsize (int): batch size
            converter (Callable): convert from `data` to `inputs`
            preprocess_fn (Callable): Its input is numpy.ndarray or
                cupy.ndarray, it can return either Variable, cupy.ndarray or
                numpy.ndarray
            postprocess_fn (Callable): Its input argument is Variable,
                but this method may return either Variable, cupy.ndarray or
                numpy.ndarray.
            retain_inputs (bool): If True, this instance keeps inputs in
                `self.inputs` or not.

        Returns (tuple or numpy.ndarray): Typically, it is 1-dimensional int
            array with shape (batchsize, ) which represents each examples
            category prediction.

        """
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            predict_labels = self._forward(
                data, fn=self.predictor, batchsize=batchsize,
                converter=converter, retain_inputs=retain_inputs,
                preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn)
        return predict_labels

    # --- For backward compatibility ---
    @property
    def compute_accuracy(self):
        warnings.warn('compute_accuracy is deprecated,'
                      'please use compute_metrics instead')
        return self.compute_metrics

    @compute_accuracy.setter
    def compute_accuracy(self, value):
        warnings.warn('compute_accuracy is deprecated,'
                      'please use compute_metrics instead')
        self.compute_metrics = value

    @property
    def accuracy(self):
        warnings.warn('accuracy is deprecated,'
                      'please use metrics instead')
        return self.metrics

    @accuracy.setter
    def accuracy(self, value):
        warnings.warn('accuracy is deprecated,'
                      'please use metrics instead')
        self.metrics = value

    @property
    def accfun(self):
        warnings.warn('accfun is deprecated,'
                      'please use metrics_fun instead')
        return self.metrics_fun

    @accfun.setter
    def accfun(self, value):
        warnings.warn('accfun is deprecated,'
                      'please use metrics_fun instead')
        self.metrics_fun = value