1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Building blocks and utilities for models."""
18__all__ = ['WeightDropParameter']
19
20from mxnet import nd, gluon
21
22
class WeightDropParameter(gluon.Parameter):
    """A Container holding parameters (weights) of Blocks and performs dropout.

    Wraps an existing ``Parameter`` and shares its underlying state (data,
    gradients, contexts, trainer) so that the wrapper and the wrapped object
    stay in sync; ``data()`` additionally applies dropout to the returned
    weight.

    Parameters
    ----------
    parameter : Parameter
        The parameter which drops out.
    rate : float, default 0.0
        Fraction of the input units to drop. Must be a number between 0 and 1.
        Dropout is not applied if dropout_rate is 0.
    mode : str, default 'training'
        Whether to only turn on dropout during training or to also turn on for inference.
        Options are 'training' and 'always'.
    axes : tuple of int, default ()
        Axes on which dropout mask is shared.
    """
    def __init__(self, parameter, rate=0.0, mode='training', axes=()):
        p = parameter
        super(WeightDropParameter, self).__init__(
            name=p.name, grad_req=p.grad_req, shape=p._shape, dtype=p.dtype,
            lr_mult=p.lr_mult, wd_mult=p.wd_mult, init=p.init,
            allow_deferred_init=p._allow_deferred_init,
            differentiable=p._differentiable)
        # Copy the deferred-init state AFTER calling super().__init__():
        # Parameter.__init__ resets ``_deferred_init`` to an empty tuple, so
        # assigning it before the super() call (as the code previously did)
        # silently discarded any pending deferred initialization of the
        # wrapped parameter.
        self._deferred_init = p._deferred_init
        self._rate = rate
        self._mode = mode
        self._axes = axes
        # Share the wrapped parameter's internal state so both objects refer
        # to the same underlying arrays, contexts and trainer.
        self._var = p._var
        self._data = p._data
        self._grad = p._grad
        self._ctx_list = p._ctx_list
        self._ctx_map = p._ctx_map
        self._trainer = p._trainer

    def data(self, ctx=None):
        """Returns this parameter on one context, with dropout applied when
        ``rate`` is non-zero. Must have been initialized on this context
        before.

        Parameters
        ----------
        ctx : Context
            Desired context.

        Returns
        -------
        NDArray on ctx
        """
        d = self._check_and_get(self._data, ctx)
        if self._rate:
            # With mode='training' the mask is only applied inside a training
            # scope; mode='always' also applies it at inference time.
            d = nd.Dropout(d, self._rate, self._mode, self._axes)
        return d

    def __repr__(self):
        s = 'WeightDropParameter {name} (shape={shape}, dtype={dtype}, rate={rate}, mode={mode})'
        return s.format(name=self.name, shape=self.shape, dtype=self.dtype,
                        rate=self._rate, mode=self._mode)
78