import warnings

import numpy

import chainer
from chainer.dataset.convert import concat_examples
from chainer.functions.evaluation import accuracy
from chainer.functions.loss import softmax_cross_entropy
from chainer import cuda, Variable  # NOQA
from chainer import reporter
from chainer_chemistry.models.prediction.base import BaseForwardModel


def _argmax(*args):
    """Return argmax along axis 1 of the first argument.

    Default ``postprocess_fn`` for :meth:`Classifier.predict`.
    Extra positional arguments are ignored.
    """
    x = args[0]
    return chainer.functions.argmax(x, axis=1)


class Classifier(BaseForwardModel):

    """A simple classifier model.

    This is an example of chain that wraps another chain. It computes the
    loss and accuracy based on a given input/label pair.

    Args:
        predictor (~chainer.Link): Predictor network.
        lossfun (function): Loss function.
        accfun (function): DEPRECATED. Please use `metrics_fun` instead.
        metrics_fun (function or dict or None): Function that computes metrics.
        label_key (int or str): Key to specify label variable from arguments.
            When it is ``int``, a variable in positional arguments is used.
            And when it is ``str``, a variable in keyword arguments is used.
        device (int or chainer._backend.Device):
            GPU device id of this Regressor to be used.
            -1 indicates to use in CPU.

    Attributes:
        predictor (~chainer.Link): Predictor network.
        lossfun (function): Loss function.
        accfun (function): DEPRECATED. Please use `metrics_fun` instead.
        y (~chainer.Variable): Prediction for the last minibatch.
        loss (~chainer.Variable): Loss value for the last minibatch.
        metrics (dict): Metrics computed in last minibatch
        compute_metrics (bool): If ``True``, compute metrics on the forward
            computation. The default value is ``True``.

    .. note::
        The differences between original `Classifier` class in chainer and
        chainer chemistry are as follows.
        1. `predict` and `predict_proba` methods are supported.
        2. `device` can be managed internally by the `Classifier`
        3. `accfun` is deprecated, `metrics_fun` is used instead.
        4. `metrics_fun` can be `dict` which specifies the metrics name as key
           and function as value.

    .. note::
        This link uses :func:`chainer.softmax_cross_entropy` with
        default arguments as a loss function (specified by ``lossfun``),
        if users do not explicitly change it. In particular, the loss function
        does not support double backpropagation.
        If you need second or higher order differentiation, you need to turn
        it on with ``enable_double_backprop=True``:

          >>> import chainer.functions as F
          >>> import chainer.links as L
          >>>
          >>> def lossfun(x, t):
          ...     return F.softmax_cross_entropy(
          ...         x, t, enable_double_backprop=True)
          >>>
          >>> predictor = L.Linear(10)
          >>> model = L.Classifier(predictor, lossfun=lossfun)

    """

    compute_metrics = True

    def __init__(self, predictor,
                 lossfun=softmax_cross_entropy.softmax_cross_entropy,
                 accfun=None, metrics_fun=accuracy.accuracy,
                 label_key=-1, device=-1):
        if not (isinstance(label_key, (int, str))):
            raise TypeError('label_key must be int or str, but is %s' %
                            type(label_key))
        if accfun is not None:
            warnings.warn(
                'accfun is deprecated, please use metrics_fun instead')
            warnings.warn('overriding metrics by accfun...')
            # override metrics by accfun
            metrics_fun = accfun

        super(Classifier, self).__init__()
        self.lossfun = lossfun
        # Normalize `metrics_fun` into a dict of {name: callable}; a bare
        # callable is registered under the key 'accuracy'.
        if metrics_fun is None:
            self.compute_metrics = False
            self.metrics_fun = {}
        elif callable(metrics_fun):
            self.metrics_fun = {'accuracy': metrics_fun}
        elif isinstance(metrics_fun, dict):
            self.metrics_fun = metrics_fun
        else:
            # Bug fix: report the type of `metrics_fun` (the offending
            # value), not `accfun`, which may simply be None here.
            raise TypeError('Unexpected type metrics_fun must be None or '
                            'Callable or dict. actual {}'.format(
                                type(metrics_fun)))
        self.y = None
        self.loss = None
        self.metrics = None
        self.label_key = label_key

        with self.init_scope():
            self.predictor = predictor

        # `initialize` must be called after `init_scope`.
        self.initialize(device)

    def _convert_to_scalar(self, value):
        """Converts an input value to a scalar if its type is a Variable,

        numpy or cupy array, otherwise it returns the value as it is.
        """
        if isinstance(value, Variable):
            value = value.array
        if numpy.isscalar(value):
            return value
        # Bug fix: the original check `type(value) is not numpy.array`
        # compared against the `numpy.array` *function* (always True).
        # Only device (e.g. cupy) arrays need to be copied to host.
        if not isinstance(value, numpy.ndarray):
            value = cuda.to_cpu(value)
        # `numpy.asscalar` is deprecated (removed in NumPy >= 1.23);
        # `ndarray.item()` is the documented equivalent.
        return value.item()

    def __call__(self, *args, **kwargs):
        """Computes the loss value for an input and label pair.

        It also computes accuracy and stores it to the attribute.

        Args:
            args (list of ~chainer.Variable): Input minibatch.
            kwargs (dict of ~chainer.Variable): Input minibatch.

        When ``label_key`` is ``int``, the corresponding element in ``args``
        is treated as ground truth labels. And when it is ``str``, the
        element in ``kwargs`` is used.
        The all elements of ``args`` and ``kwargs`` except the ground truth
        labels are features.
        It feeds features to the predictor and compare the result
        with ground truth labels.

        Returns:
            ~chainer.Variable: Loss value.

        """

        # --- Separate `args` and `t` ---
        if isinstance(self.label_key, int):
            if not (-len(args) <= self.label_key < len(args)):
                msg = 'Label key %d is out of bounds' % self.label_key
                raise ValueError(msg)
            t = args[self.label_key]
            if self.label_key == -1:
                args = args[:-1]
            else:
                args = args[:self.label_key] + args[self.label_key + 1:]
        elif isinstance(self.label_key, str):
            if self.label_key not in kwargs:
                msg = 'Label key "%s" is not found' % self.label_key
                raise ValueError(msg)
            t = kwargs[self.label_key]
            del kwargs[self.label_key]
        else:
            raise TypeError('Label key type {} not supported'
                            .format(type(self.label_key)))

        self.y = None
        self.loss = None
        self.metrics = None
        self.y = self.predictor(*args, **kwargs)
        self.loss = self.lossfun(self.y, t)
        reporter.report(
            {'loss': self._convert_to_scalar(self.loss)}, self)
        if self.compute_metrics:
            # Note: self.metrics is `dict`, which is different from original
            # chainer implementation
            self.metrics = {key: self._convert_to_scalar(value(self.y, t))
                            for key, value in self.metrics_fun.items()}
            reporter.report(self.metrics, self)
        return self.loss

    def predict_proba(
            self, data, batchsize=16, converter=concat_examples,
            retain_inputs=False, preprocess_fn=None,
            postprocess_fn=chainer.functions.softmax):
        """Calculate probability of each category.

        Args:
            data: "train_x array" or "chainer dataset"
            batchsize (int): batch size
            converter (Callable): convert from `data` to `inputs`
            preprocess_fn (Callable): Its input is numpy.ndarray or
                cupy.ndarray, it can return either Variable, cupy.ndarray or
                numpy.ndarray
            postprocess_fn (Callable): Its input argument is Variable,
                but this method may return either Variable, cupy.ndarray or
                numpy.ndarray.
            retain_inputs (bool): If True, this instance keeps inputs in
                `self.inputs` or not.

        Returns (tuple or numpy.ndarray): Typically, it is 2-dimensional float
            array with shape (batchsize, number of category) which represents
            each examples probability to be each category.

        """
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            proba = self._forward(
                data, fn=self.predictor, batchsize=batchsize,
                converter=converter, retain_inputs=retain_inputs,
                preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn)
        return proba

    def predict(
            self, data, batchsize=16, converter=concat_examples,
            retain_inputs=False, preprocess_fn=None, postprocess_fn=_argmax):
        """Predict label of each category by taking argmax of the output.

        Args:
            data: input data
            batchsize (int): batch size
            converter (Callable): convert from `data` to `inputs`
            preprocess_fn (Callable): Its input is numpy.ndarray or
                cupy.ndarray, it can return either Variable, cupy.ndarray or
                numpy.ndarray
            postprocess_fn (Callable): Its input argument is Variable,
                but this method may return either Variable, cupy.ndarray or
                numpy.ndarray.
            retain_inputs (bool): If True, this instance keeps inputs in
                `self.inputs` or not.

        Returns (tuple or numpy.ndarray): Typically, it is 1-dimensional int
            array with shape (batchsize, ) which represents each examples
            category prediction.

        """
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            predict_labels = self._forward(
                data, fn=self.predictor, batchsize=batchsize,
                converter=converter, retain_inputs=retain_inputs,
                preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn)
        return predict_labels

    # --- For backward compatibility ---
    @property
    def compute_accuracy(self):
        warnings.warn('compute_accuracy is deprecated,'
                      'please use compute_metrics instead')
        return self.compute_metrics

    @compute_accuracy.setter
    def compute_accuracy(self, value):
        warnings.warn('compute_accuracy is deprecated,'
                      'please use compute_metrics instead')
        self.compute_metrics = value

    @property
    def accuracy(self):
        warnings.warn('accuracy is deprecated,'
                      'please use metrics instead')
        return self.metrics

    @accuracy.setter
    def accuracy(self, value):
        warnings.warn('accuracy is deprecated,'
                      'please use metrics instead')
        self.metrics = value

    @property
    def accfun(self):
        warnings.warn('accfun is deprecated,'
                      'please use metrics_fun instead')
        return self.metrics_fun

    @accfun.setter
    def accfun(self, value):
        warnings.warn('accfun is deprecated,'
                      'please use metrics_fun instead')
        self.metrics_fun = value