# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Blocks for sampled losses."""
__all__ = ['ISDense', 'NCEDense', 'SparseISDense', 'SparseNCEDense']

from mxnet import nd
from mxnet.gluon import Block, HybridBlock

class _SampledDenseHelper(HybridBlock):
    """A helper block that computes the sampled logits and the corresponding labels.

    Parameters
    ----------
    num_classes: int
        Number of possible classes.
    num_sampled: int
        Number of classes randomly sampled for each batch.
    in_unit: int
        Dimensionality of the input space.
    remove_accidental_hits: bool
        Whether to remove "accidental hits" when a sampled candidate is equal to
        one of the true classes.
    sparse_label: bool
        Whether to output the label as an integer array instead of a probability
        distribution.
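
    The logits returned in `pred` are corrected for the sampling distribution
    by subtracting the log expected counts, the standard adjustment for
    sampled softmax and NCE. A sketch of what the forward pass computes::

        pred_true    = sum(x * w_true, axis=1) + b_true - log(expected_count_true)
        pred_sampled = dot(x, w_sampled.T) + b_sampled - log(expected_count_sampled)
        pred         = concat(pred_true, pred_sampled, dim=1)  # (batch_size, 1+num_sampled)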
    """
    def __init__(self, num_classes, num_sampled, in_unit, remove_accidental_hits,
                 sparse_label, prefix=None, params=None):
        super(_SampledDenseHelper, self).__init__(prefix=prefix, params=params)
        self._num_classes = num_classes
        self._num_sampled = num_sampled
        self._in_unit = in_unit
        self._remove_accidental_hits = remove_accidental_hits
        self._sparse_label = sparse_label

    # pylint: disable=arguments-differ
    def hybrid_forward(self, F, x, sampled_values, label, w_all, b_all):
        """Forward computation."""
        sampled_candidates, expected_count_sampled, expected_count_true = sampled_values
        # (num_sampled, in_unit)
        w_sampled = w_all.slice(begin=(0, 0), end=(self._num_sampled, None))
        w_true = w_all.slice(begin=(self._num_sampled, 0), end=(None, None))
        b_sampled = b_all.slice(begin=(0,), end=(self._num_sampled,))
        b_true = b_all.slice(begin=(self._num_sampled,), end=(None,))
        # true pred
        # (batch_size, 1)
        x = x.reshape((-1, self._in_unit))
        pred_true = (w_true * x).sum(axis=1) + b_true
        # samples pred
        # (batch_size, num_sampled)
        b_sampled = F.reshape(b_sampled, (-1,))
        pred_sampled = F.FullyConnected(x, weight=w_sampled, bias=b_sampled,
                                        num_hidden=self._num_sampled)

        # remove accidental hits
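        # an "accidental hit" is a sampled candidate equal to the true label;
        # adding a large negative constant to its logit effectively removes
        # it from the loss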
        if self._remove_accidental_hits:
            label_vec = F.reshape(label, (-1, 1)).astype('int32')
            sample_vec = F.reshape(sampled_candidates, (1, -1)).astype('int32')
            mask = F.broadcast_equal(label_vec, sample_vec).astype('float32') * -1e37
            pred_sampled = pred_sampled + mask

        # subtract log(q)
        expected_count_sampled = expected_count_sampled.astype('float32')
        expected_count_sampled = expected_count_sampled.reshape(shape=(1, self._num_sampled))
        expected_count_true = expected_count_true.astype('float32').reshape((-1,))
        pred_true = pred_true - F.log(expected_count_true)
        pred_true = pred_true.reshape((-1, 1))
        pred_sampled = F.broadcast_sub(pred_sampled, F.log(expected_count_sampled))

        # pred and new_labels
        # (batch_size, 1+num_sampled)
        pred = F.concat(pred_true, pred_sampled, dim=1)
        if self._sparse_label:
            new_label = F.zeros_like(label)
        else:
            label_vec = F.reshape(label, (-1, 1))
            new_label_true = F.ones_like(label_vec)
            new_label_sampled = F.zeros_like(pred_sampled)
            new_label = F.concat(new_label_true, new_label_sampled, dim=1)
        return pred, new_label

    def __repr__(self):
        s = '{name}({mapping})'
        mapping = '{0} -> {1}, with {2} samples'.format(self._in_unit, self._num_classes,
                                                        self._num_sampled)
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

class _SampledDense(HybridBlock):
    """Block that computes sampled predictions and labels during training,
    suitable for sampled softmax loss or noise contrastive estimation loss.

    Please use `loss.SoftmaxCrossEntropyLoss` for sampled softmax loss, and
    `loss.SigmoidBinaryCrossEntropyLoss` for NCE loss.

    Parameters
    ----------
    num_classes: int
        Number of possible classes.
    num_sampled: int
        Number of classes randomly sampled for each batch.
    in_unit: int
        Dimensionality of the input space.
    remove_accidental_hits: bool
        Whether to remove "accidental hits" when a sampled candidate is equal to
        one of the true classes.
    sparse_label: bool
        Whether to output the label as an integer array instead of a probability
        distribution.
    dtype : str or np.dtype, default 'float32'
        Data type of output embeddings.
    weight_initializer : str or `Initializer`, optional
        Initializer for the `kernel` weights matrix.
    bias_initializer: str or `Initializer`, optional
        Initializer for the bias vector.
    sparse_grad: bool, default True
        Whether to use sparse gradient.

    Inputs:
        - **x**: A tensor of shape `(batch_size, in_unit)`. The forward activation of
          the input network.
        - **sampled_values** : A list of three tensors for
          `sampled_classes` with shape `(num_samples,)`,
          `expected_count_sampled` with shape `(num_samples,)`, and
          `expected_count_true` with shape `(sequence_length, batch_size)`.
        - **label**: A tensor of shape `(batch_size, 1)`.
          The target classes.

    Outputs:
        - **out**: A tensor of shape `(batch_size, 1+num_sampled)`.
          The output probability for the true class and sampled classes.
        - **new_targets**: A tensor.
          The new target classes. The shape is `(batch_size,)` if `sparse_label` is `True`,
          `(batch_size, 1+num_sampled)` otherwise.

    """
    def __init__(self, num_classes, num_sampled, in_unit, remove_accidental_hits,
                 sparse_label, dtype='float32', weight_initializer=None,
                 bias_initializer='zeros', sparse_grad=True, prefix=None, params=None):
        super(_SampledDense, self).__init__(prefix=prefix, params=params)
        with self.name_scope():
            grad_stype = 'row_sparse' if sparse_grad else 'default'
            self.weight = self.params.get('weight', shape=(num_classes, in_unit),
                                          init=weight_initializer,
                                          dtype=dtype, grad_stype=grad_stype)
            self.bias = self.params.get('bias', shape=(num_classes,), init=bias_initializer,
                                        dtype=dtype)
        self._dense = _SampledDenseHelper(num_classes, num_sampled, in_unit,
                                          remove_accidental_hits, sparse_label)
        self._num_classes = num_classes
        self._num_sampled = num_sampled
        self._in_unit = in_unit
        self._remove_accidental_hits = remove_accidental_hits
        self._sparse_grad = sparse_grad

    # pylint: disable=arguments-differ
    def hybrid_forward(self, F, x, sampled_values, label, weight, bias):
        """Forward computation."""
        sampled_candidates, _, _ = sampled_values
        # (batch_size,)
        label = F.reshape(label, shape=(-1,))
        # (num_sampled+batch_size,)
        ids = F.concat(sampled_candidates.astype('int32'), label.astype('int32'), dim=0)
        # lookup weights and biases
        # (num_sampled+batch_size, dim)
        w_all = F.Embedding(data=ids, weight=weight,
                            input_dim=self._num_classes, output_dim=self._in_unit,
                            sparse_grad=self._sparse_grad)
        # (num_sampled+batch_size,)
        b_all = F.take(bias, indices=ids)
        return self._dense(x, sampled_values, label, w_all, b_all)

    def __repr__(self):
        s = '{name}({mapping})'
        mapping = '{0} -> {1}, with {2} samples'.format(self._in_unit, self._num_classes,
                                                        self._num_sampled)
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

class NCEDense(_SampledDense):
    """Noise contrastive estimated Dense block, which computes sampled predictions
    and labels for noise contrastive estimation loss during training.

    Reference:

    Exploring the Limits of Language Modeling
    Jozefowicz, Rafal and Vinyals, Oriol and Schuster, Mike and Shazeer, Noam and Wu, Yonghui
    https://arxiv.org/pdf/1602.02410

    Please use `loss.SigmoidBinaryCrossEntropyLoss` for noise contrastive estimation loss
    during training.

    .. note::

        If `sparse_grad` is set to True, the gradient w.r.t input and output
        embeddings will be sparse. Only a subset of optimizers support
        sparse gradients, including SGD, AdaGrad and Adam.
        By default `lazy_update` is turned on for these optimizers,
        which may perform differently from standard updates.
        For more details, please check the Optimization API at:
        https://mxnet.incubator.apache.org/api/python/optimization/optimization.html

    Example::

        # network with sampling for training
        encoder = Encoder(..)
        decoder = NCEDense(..)
        train_net.add(encoder)
        train_net.add(decoder)
        loss_train = SigmoidBinaryCrossEntropyLoss()

        # training
        for x, y, sampled_values in train_batches:
            pred, new_targets = train_net(x, sampled_values, y)
            l = loss_train(pred, new_targets)

        # network for testing
        test_net.add(encoder)
        test_net.add(Dense(..., params=decoder.params))
        loss_test = SoftmaxCrossEntropyLoss()

        # testing
        for x, y in test_batches:
            pred = test_net(x)
            l = loss_test(pred, y)
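
    The `sampled_values` input can be produced by a candidate sampler. A minimal
    sketch using `mx.nd.contrib.rand_zipfian` (note that it returns
    `(sampled_classes, expected_count_true, expected_count_sampled)`, so the two
    expected counts are swapped to match this block's input order)::

        import mxnet as mx
        num_classes, num_sampled = 10000, 50
        y = mx.nd.array([[3], [0]])  # (batch_size, 1) true classes
        sampled, cnt_true, cnt_sampled = mx.nd.contrib.rand_zipfian(
            y.reshape((-1,)), num_sampled, num_classes)
        sampled_values = [sampled, cnt_sampled, cnt_true]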

    Parameters
    ----------
    num_classes: int
        Number of possible classes.
    num_sampled: int
        Number of classes randomly sampled for each batch.
    in_unit: int
        Dimensionality of the input space.
    remove_accidental_hits: bool, default False
        Whether to remove "accidental hits" when a sampled candidate is equal to
        one of the true classes.
    dtype : str or np.dtype, default 'float32'
        Data type of output embeddings.
    weight_initializer : str or `Initializer`, optional
        Initializer for the `kernel` weights matrix.
    bias_initializer: str or `Initializer`, optional
        Initializer for the bias vector.
    sparse_grad: bool, default True
        Whether to use sparse gradient.

    Inputs:
        - **x**: A tensor of shape `(batch_size, in_unit)`. The forward activation of
          the input network.
        - **sampled_values** : A list of three tensors for
          `sampled_classes` with shape `(num_samples,)`,
          `expected_count_sampled` with shape `(num_samples,)`, and
          `expected_count_true` with shape `(sequence_length, batch_size)`.
        - **label**: A tensor of shape `(batch_size, 1)`.
          The target classes.

    Outputs:
        - **out**: A tensor of shape `(batch_size, 1+num_sampled)`.
          The output probability for the true class and sampled classes.
        - **new_targets**: A tensor of shape `(batch_size, 1+num_sampled)`.
          The new target classes.
    """
    def __init__(self, num_classes, num_sampled, in_unit, remove_accidental_hits=False,
                 dtype='float32', weight_initializer=None, bias_initializer='zeros',
                 sparse_grad=True, prefix=None, params=None):
        super(NCEDense, self).__init__(num_classes, num_sampled, in_unit, remove_accidental_hits,
                                       False, dtype=dtype, weight_initializer=weight_initializer,
                                       bias_initializer=bias_initializer, sparse_grad=sparse_grad,
                                       prefix=prefix, params=params)

class ISDense(_SampledDense):
    """Importance sampled Dense block, which computes sampled predictions and labels
    for importance sampled softmax loss during training.

    Reference:

    Exploring the Limits of Language Modeling
    Jozefowicz, Rafal and Vinyals, Oriol and Schuster, Mike and Shazeer, Noam and Wu, Yonghui
    https://arxiv.org/pdf/1602.02410

    Please use `loss.SoftmaxCrossEntropyLoss` for sampled softmax loss.

    .. note::

        If `sparse_grad` is set to True, the gradient w.r.t input and output
        embeddings will be sparse. Only a subset of optimizers support
        sparse gradients, including SGD, AdaGrad and Adam.
        By default `lazy_update` is turned on for these optimizers,
        which may perform differently from standard updates.
        For more details, please check the Optimization API at:
        https://mxnet.incubator.apache.org/api/python/optimization/optimization.html

    Example::

        # network with importance sampling for training
        encoder = Encoder(..)
        decoder = ISDense(..)
        train_net.add(encoder)
        train_net.add(decoder)
        loss = SoftmaxCrossEntropyLoss()

        # training
        for x, y, sampled_values in train_batches:
            pred, new_targets = train_net(x, sampled_values, y)
            l = loss(pred, new_targets)

        # network for testing
        test_net.add(encoder)
        test_net.add(Dense(..., params=decoder.params))

        # testing
        for x, y in test_batches:
            pred = test_net(x)
            l = loss(pred, y)
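
    Because the true-class logit always occupies column 0 of `pred`, the
    returned `new_targets` are simply all-zero class indices, so
    `SoftmaxCrossEntropyLoss` applies directly (a sketch)::

        pred, new_targets = decoder(x, sampled_values, y)  # pred: (batch_size, 1+num_sampled)
        l = SoftmaxCrossEntropyLoss()(pred, new_targets)   # new_targets: (batch_size,)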

    Parameters
    ----------
    num_classes: int
        Number of possible classes.
    num_sampled: int
        Number of classes randomly sampled for each batch.
    in_unit: int
        Dimensionality of the input space.
    remove_accidental_hits: bool, default True
        Whether to remove "accidental hits" when a sampled candidate is equal to
        one of the true classes.
    dtype : str or np.dtype, default 'float32'
        Data type of output embeddings.
    weight_initializer : str or `Initializer`, optional
        Initializer for the `kernel` weights matrix.
    bias_initializer: str or `Initializer`, optional
        Initializer for the bias vector.
    sparse_grad: bool, default True
        Whether to use sparse gradient.

    Inputs:
        - **x**: A tensor of shape `(batch_size, in_unit)`. The forward activation of
          the input network.
        - **sampled_values** : A list of three tensors for
          `sampled_classes` with shape `(num_samples,)`,
          `expected_count_sampled` with shape `(num_samples,)`, and
          `expected_count_true` with shape `(sequence_length, batch_size)`.
        - **label**: A tensor of shape `(batch_size, 1)`.
          The target classes.

    Outputs:
        - **out**: A tensor of shape `(batch_size, 1+num_sampled)`.
          The output probability for the true class and sampled classes.
        - **new_targets**: A tensor of shape `(batch_size,)`.
          The new target classes.
    """
    def __init__(self, num_classes, num_sampled, in_unit, remove_accidental_hits=True,
                 dtype='float32', weight_initializer=None, bias_initializer='zeros',
                 sparse_grad=True, prefix=None, params=None):
        super(ISDense, self).__init__(num_classes, num_sampled, in_unit, remove_accidental_hits,
                                      True, dtype=dtype, weight_initializer=weight_initializer,
                                      bias_initializer=bias_initializer, sparse_grad=sparse_grad,
                                      prefix=prefix, params=params)

class _SparseSampledDense(Block):
    """Block that computes sampled predictions and labels during training,
    suitable for sampled softmax loss or noise contrastive estimation loss.

    Please use `loss.SoftmaxCrossEntropyLoss` for sampled softmax loss, and
    `loss.SigmoidBinaryCrossEntropyLoss` for NCE loss.

    The block is designed for distributed training with an extremely large
    number of classes, to reduce communication overhead and memory consumption.
    Both the weight and the gradient w.r.t. the weight are `RowSparseNDArray`.

    Different from the `_SampledDense` block, the parameters have to be saved
    before they are used for testing.

    Example::

        # network with sampled_softmax_loss for training
        encoder = Encoder(..)
        train_net.add(encoder)
        train_net.add(_SparseSampledDense(.., prefix='decoder'))
        loss = SoftmaxCrossEntropyLoss()

        # training
        for x, y, sampled_values in train_batches:
            pred, new_targets = train_net(x, sampled_values, y)
            l = loss(pred, new_targets)

        # save params
        train_net.save_parameters('net.params')

        # network for testing
        test_net.add(encoder)
        test_net.add(Dense(..., prefix='decoder'))

        # load params
        test_net.load_parameters('net.params')

        # testing
        for x, y in test_batches:
            pred = test_net(x)
            l = loss(pred, y)

    Parameters
    ----------
    num_classes: int
        Number of possible classes.
    num_sampled: int
        Number of classes randomly sampled for each batch.
    in_unit: int
        Dimensionality of the input space.
    remove_accidental_hits: bool
        Whether to remove "accidental hits" when a sampled candidate is equal to
        one of the true classes.
    sparse_label: bool
        Whether to output the label as an integer array instead of a probability
        distribution.
    dtype : str or np.dtype, default 'float32'
        Data type of output embeddings.
    weight_initializer : str or `Initializer`, optional
        Initializer for the `kernel` weights matrix.
    bias_initializer: str or `Initializer`, optional
        Initializer for the bias vector.

    Inputs:
        - **x**: A tensor of shape `(batch_size, in_unit)`. The forward activation of
          the input network.
        - **sampled_values** : A list of three tensors for
          `sampled_classes` with shape `(num_samples,)`,
          `expected_count_sampled` with shape `(num_samples,)`, and
          `expected_count_true` with shape `(sequence_length, batch_size)`.
        - **label**: A tensor of shape `(batch_size, 1)`.
          The target classes.

    Outputs:
        - **out**: A tensor of shape `(batch_size, 1+num_sampled)`.
          The output probability for the true class and sampled classes.
        - **new_targets**: A tensor.
          The new target classes. The shape is `(batch_size,)` if `sparse_label` is `True`,
          `(batch_size, 1+num_sampled)` otherwise.

    """
    def __init__(self, num_classes, num_sampled, in_unit, remove_accidental_hits,
                 sparse_label, dtype='float32', weight_initializer=None,
                 bias_initializer='zeros', prefix=None, params=None):
        super(_SparseSampledDense, self).__init__(prefix=prefix, params=params)
        with self.name_scope():
            self.weight = self.params.get('weight', shape=(num_classes, in_unit),
                                          init=weight_initializer, dtype=dtype,
                                          grad_stype='row_sparse', stype='row_sparse')
            self.bias = self.params.get('bias', shape=(num_classes,), init=bias_initializer,
                                        dtype=dtype)
            self._dense = _SampledDenseHelper(num_classes, num_sampled, in_unit,
                                              remove_accidental_hits, sparse_label)
        self._num_classes = num_classes
        self._num_sampled = num_sampled
        self._in_unit = in_unit
        self._remove_accidental_hits = remove_accidental_hits
        self._kwargs = {'input_dim': self._num_classes, 'output_dim': self._in_unit,
                        'sparse_grad': True}

    def forward(self, x, sampled_values, label): # pylint: disable=arguments-differ
        """Forward computation."""
        sampled_candidates, _, _ = sampled_values
        # (batch_size,)
        label = label.reshape(shape=(-1,))
        # (num_sampled+batch_size,)
        ids = nd.concat(sampled_candidates.astype('int32'), label.astype('int32'), dim=0)
        # lookup weights and biases
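        # row_sparse_data(ids) retrieves only the rows of the weight that this
        # batch needs; in a distributed (kvstore) setting this limits
        # communication to the sampled and true classes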
        weight = self.weight.row_sparse_data(ids)
        bias = self.bias.data(ids.context)
        # (num_sampled+batch_size, dim)
        w_all = nd.Embedding(data=ids, weight=weight, **self._kwargs)
        # (num_sampled+batch_size,)
        b_all = nd.take(bias, indices=ids)
        out, new_targets = self._dense(x, sampled_values, label, w_all, b_all)
        return out, new_targets

    def __repr__(self):
        s = '{name}({mapping})'
        mapping = '{0} -> {1}, num_sampled = {2}, remove_accidental_hits = {3}'
        mapping = mapping.format(self._in_unit, self._num_classes, self._num_sampled,
                                 str(self._remove_accidental_hits))
        return s.format(name=self.__class__.__name__,
                        mapping=mapping, **self.__dict__)

class SparseISDense(_SparseSampledDense):
    """Importance sampled Dense block with sparse weights, which computes sampled
    predictions and labels for importance sampled softmax loss during training.

    Reference:

    Exploring the Limits of Language Modeling
    Jozefowicz, Rafal and Vinyals, Oriol and Schuster, Mike and Shazeer, Noam and Wu, Yonghui
    https://arxiv.org/pdf/1602.02410

    Please use `loss.SoftmaxCrossEntropyLoss` for sampled softmax loss.

    The block is designed for distributed training with an extremely large
    number of classes, to reduce communication overhead and memory consumption.
    Both the weight and the gradient w.r.t. the weight are `RowSparseNDArray`.

    .. note::

        Different from the `ISDense` block, the weight parameter is stored in
        row_sparse format, which helps reduce memory consumption and
        communication overhead during multi-GPU training. However,
        sparse parameters cannot be shared with other blocks, nor can we
        hybridize a block containing sparse parameters. Therefore, the
        parameters have to be saved before they are used for testing.

    Example::

        # network with importance sampled softmax for training
        encoder = Encoder(..)
        train_net.add(encoder)
        train_net.add(SparseISDense(.., prefix='decoder'))
        loss = SoftmaxCrossEntropyLoss()

        # training
        for x, y, sampled_values in train_batches:
            pred, new_targets = train_net(x, sampled_values, y)
            l = loss(pred, new_targets)

        # save params
        train_net.save_parameters('net.params')

        # network for testing
        test_net.add(encoder)
        test_net.add(Dense(..., prefix='decoder'))

        # load params
        test_net.load_parameters('net.params')

        # testing
        for x, y in test_batches:
            pred = test_net(x)
            l = loss(pred, y)

    Parameters
    ----------
    num_classes: int
        Number of possible classes.
    num_sampled: int
        Number of classes randomly sampled for each batch.
    in_unit: int
        Dimensionality of the input space.
    remove_accidental_hits: bool, default True
        Whether to remove "accidental hits" when a sampled candidate is equal to
        one of the true classes.
    dtype : str or np.dtype, default 'float32'
        Data type of output embeddings.
    weight_initializer : str or `Initializer`, optional
        Initializer for the `kernel` weights matrix.
    bias_initializer: str or `Initializer`, optional
        Initializer for the bias vector.

    Inputs:
        - **x**: A tensor of shape `(batch_size, in_unit)`. The forward activation of
          the input network.
        - **sampled_values** : A list of three tensors for
          `sampled_classes` with shape `(num_samples,)`,
          `expected_count_sampled` with shape `(num_samples,)`, and
          `expected_count_true` with shape `(sequence_length, batch_size)`.
        - **label**: A tensor of shape `(batch_size, 1)`.
          The target classes.

    Outputs:
        - **out**: A tensor of shape `(batch_size, 1+num_sampled)`.
          The output probability for the true class and sampled classes.
        - **new_targets**: A tensor of shape `(batch_size,)`.
          The new target classes.

    """
    def __init__(self, num_classes, num_sampled, in_unit, remove_accidental_hits=True,
                 dtype='float32', weight_initializer=None, bias_initializer='zeros',
                 prefix=None, params=None):
        super(SparseISDense, self).__init__(num_classes, num_sampled, in_unit,
                                            remove_accidental_hits, True, dtype,
                                            weight_initializer, bias_initializer,
                                            prefix=prefix, params=params)

class SparseNCEDense(_SparseSampledDense):
    """Noise contrastive estimated Dense block with sparse weights, which computes
    sampled predictions and labels for noise contrastive estimation loss during training.

    Reference:

    Exploring the Limits of Language Modeling
    Jozefowicz, Rafal and Vinyals, Oriol and Schuster, Mike and Shazeer, Noam and Wu, Yonghui
    https://arxiv.org/pdf/1602.02410

    Please use `loss.SigmoidBinaryCrossEntropyLoss` for noise contrastive estimation loss
    during training.

    The block is designed for distributed training with an extremely large
    number of classes, to reduce communication overhead and memory consumption.
    Both the weight and the gradient w.r.t. the weight are `RowSparseNDArray`.

    .. note::

        Different from the `NCEDense` block, the weight parameter is stored
        in row_sparse format, which helps reduce memory consumption and
        communication overhead during multi-GPU training. However,
        sparse parameters cannot be shared with other blocks, nor can we
        hybridize a block containing sparse parameters. Therefore, the
        parameters have to be saved before they are used for testing.

    Example::

        # network with noise contrastive estimation for training
        encoder = Encoder(..)
        train_net.add(encoder)
        train_net.add(SparseNCEDense(.., prefix='decoder'))
        train_loss = SigmoidBinaryCrossEntropyLoss()

        # training
        for x, y, sampled_values in train_batches:
            pred, new_targets = train_net(x, sampled_values, y)
            l = train_loss(pred, new_targets)

        # save params
        train_net.save_parameters('net.params')

        # network for testing
        test_net.add(encoder)
        test_net.add(Dense(..., prefix='decoder'))

        # load params
        test_net.load_parameters('net.params')
        test_loss = SoftmaxCrossEntropyLoss()

        # testing
        for x, y in test_batches:
            pred = test_net(x)
            l = test_loss(pred, y)
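
    Internally the true class occupies column 0 of `pred`, so the returned
    `new_targets` is a 0/1 matrix whose first column is all ones and whose
    remaining `num_sampled` columns are all zeros, matching what
    `SigmoidBinaryCrossEntropyLoss` expects (a sketch, with `decoder` standing
    for the `SparseNCEDense` block)::

        pred, new_targets = decoder(x, sampled_values, y)
        # new_targets[:, 0] == 1 and new_targets[:, 1:] == 0
        l = SigmoidBinaryCrossEntropyLoss()(pred, new_targets)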

    Parameters
    ----------
    num_classes: int
        Number of possible classes.
    num_sampled: int
        Number of classes randomly sampled for each batch.
    in_unit: int
        Dimensionality of the input space.
    remove_accidental_hits: bool, default True
        Whether to remove "accidental hits" when a sampled candidate is equal to
        one of the true classes.
    dtype : str or np.dtype, default 'float32'
        Data type of output embeddings.
    weight_initializer : str or `Initializer`, optional
        Initializer for the `kernel` weights matrix.
    bias_initializer: str or `Initializer`, optional
        Initializer for the bias vector.

    Inputs:
        - **x**: A tensor of shape `(batch_size, in_unit)`. The forward activation of
          the input network.
        - **sampled_values** : A list of three tensors for
          `sampled_classes` with shape `(num_samples,)`,
          `expected_count_sampled` with shape `(num_samples,)`, and
          `expected_count_true` with shape `(sequence_length, batch_size)`.
        - **label**: A tensor of shape `(batch_size, 1)`.
          The target classes.

    Outputs:
        - **out**: A tensor of shape `(batch_size, 1+num_sampled)`.
          The output probability for the true class and sampled classes.
        - **new_targets**: A tensor of shape `(batch_size, 1+num_sampled)`.
          The new target classes.

    """
    def __init__(self, num_classes, num_sampled, in_unit, remove_accidental_hits=True,
                 dtype='float32', weight_initializer=None, bias_initializer='zeros',
                 prefix=None, params=None):
        super(SparseNCEDense, self).__init__(num_classes, num_sampled, in_unit,
                                             remove_accidental_hits, False,
                                             dtype, weight_initializer, bias_initializer,
                                             prefix=prefix, params=params)