# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Building blocks and utility for models."""
__all__ = ['WeightDropParameter']

from mxnet import nd, gluon


class WeightDropParameter(gluon.Parameter):
    """A Parameter that wraps another Parameter and applies dropout to its weights.

    The wrapper takes over the wrapped parameter's symbol, data and gradient
    lists, context bookkeeping and trainer. Every call to :meth:`data`
    returns the stored weights with dropout applied (when ``rate`` is non-zero).

    Parameters
    ----------
    parameter : Parameter
        The parameter whose weights are dropped out.
    rate : float, default 0.0
        Fraction of the weight entries to drop. Must be a number between 0 and 1.
        Dropout is not applied if `rate` is 0.
    mode : str, default 'training'
        Whether to only turn on dropout during training or to also turn on for inference.
        Options are 'training' and 'always'.
    axes : tuple of int, default ()
        Axes on which dropout mask is shared.
    """
    def __init__(self, parameter, rate=0.0, mode='training', axes=()):
        p = parameter
        super(WeightDropParameter, self).__init__(
            name=p.name, grad_req=p.grad_req, shape=p._shape, dtype=p.dtype,
            lr_mult=p.lr_mult, wd_mult=p.wd_mult, init=p.init,
            allow_deferred_init=p._allow_deferred_init,
            differentiable=p._differentiable)
        # Copy the deferred-initialization state only *after* the base
        # constructor has run: gluon.Parameter.__init__ resets ``_deferred_init``
        # to ``()``, so assigning it beforehand was silently discarded and a
        # deferred parameter would report "not initialized" instead of raising
        # DeferredInitializationError.
        self._deferred_init = p._deferred_init
        self._rate = rate
        self._mode = mode
        self._axes = axes
        # Alias (not copy) the wrapped parameter's internal state, so this
        # object refers to the same symbol, data/grad lists, context maps and
        # trainer as ``parameter``.
        self._var = p._var
        self._data = p._data
        self._grad = p._grad
        self._ctx_list = p._ctx_list
        self._ctx_map = p._ctx_map
        self._trainer = p._trainer

    def data(self, ctx=None):
        """Returns this parameter's data on one context, with dropout applied.

        Must have been initialized on this context before. When ``rate`` is
        non-zero the stored weights are passed through ``nd.Dropout`` using
        this parameter's ``rate``, ``mode`` and ``axes``; otherwise the stored
        NDArray is returned unchanged.

        Parameters
        ----------
        ctx : Context
            Desired context.

        Returns
        -------
        NDArray on ctx
        """
        d = self._check_and_get(self._data, ctx)
        if self._rate:
            # Keyword arguments make the mapping to nd.Dropout's (p, mode, axes)
            # explicit rather than relying on positional order.
            d = nd.Dropout(d, p=self._rate, mode=self._mode, axes=self._axes)
        return d

    def __repr__(self):
        s = 'WeightDropParameter {name} (shape={shape}, dtype={dtype}, rate={rate}, mode={mode})'
        return s.format(name=self.name, shape=self.shape, dtype=self.dtype,
                        rate=self._rate, mode=self._mode)