import copy

import numpy
import six

import chainer
from chainer.backends import cuda
from chainer import device_resident
from chainer import function
from chainer.initializers import uniform
from chainer import link
from chainer import utils
from chainer.utils import type_check
from chainer import variable


class TreeParser(object):

    """Parses a binary tree of tuples into per-leaf paths and codes."""

    def __init__(self, dtype):
        self.next_id = 0
        self.dtype = dtype

    def size(self):
        return self.next_id

    def get_paths(self):
        return self.paths

    def get_codes(self):
        return self.codes

    def parse(self, tree):
        self.next_id = 0
        self.path = []
        self.code = []
        self.paths = {}
        self.codes = {}
        self._parse(tree)

        assert len(self.path) == 0
        assert len(self.code) == 0
        assert len(self.paths) == len(self.codes)

    def _parse(self, node):
        if isinstance(node, tuple):
            # internal node
            if len(node) != 2:
                raise ValueError(
                    'All internal nodes must have two child nodes')
            left, right = node
            self.path.append(self.next_id)
            self.next_id += 1
            self.code.append(1.0)
            self._parse(left)

            self.code[-1] = -1.0
            self._parse(right)

            self.path.pop()
            self.code.pop()

        else:
            # leaf node
            self.paths[node] = numpy.array(self.path, dtype=numpy.int32)
            self.codes[node] = numpy.array(self.code, dtype=self.dtype)
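
# Example (added for illustration; not part of the original module): parsing
# the tree ``((1, 2), 3)`` assigns ids 0 and 1 to the two internal nodes, and
# each leaf gets the ids of the internal nodes on its root-to-leaf path plus
# branch codes (+1 when descending to the left child, -1 to the right):
#
#     parser = TreeParser(numpy.float32)
#     parser.parse(((1, 2), 3))
#     parser.get_paths()  # {1: [0, 1], 2: [0, 1], 3: [0]}   (int32 arrays)
#     parser.get_codes()  # {1: [1., 1.], 2: [1., -1.], 3: [-1.]}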


class BinaryHierarchicalSoftmaxFunction(
        device_resident.DeviceResident, function.Function):

    """Hierarchical softmax function based on a binary tree.

    This function object should be allocated beforehand and copied on every
    forward computation, since the initializer parses the given tree. See the
    implementation of :class:`BinaryHierarchicalSoftmax` for details.

    Args:
        tree: A binary tree made with tuples like ``((1, 2), 3)``.

    .. seealso::
       See :class:`BinaryHierarchicalSoftmax` for details.

    """

    def __init__(self, tree, dtype):
        device_resident.DeviceResident.__init__(self)

        parser = TreeParser(dtype)
        parser.parse(tree)
        paths = parser.get_paths()
        codes = parser.get_codes()
        n_vocab = max(paths.keys()) + 1

        # Flatten the per-word paths and codes into single arrays indexed by
        # ``begins`` (a CSR-like layout): the entries for word ``t`` live in
        # ``paths[begins[t]:begins[t + 1]]`` and the matching ``codes`` slice.
        self.paths = numpy.concatenate(
            [paths[i] for i in range(n_vocab) if i in paths])
        self.codes = numpy.concatenate(
            [codes[i] for i in range(n_vocab) if i in codes])
        begins = numpy.empty((n_vocab + 1,), dtype=numpy.int32)
        begins[0] = 0
        for i in range(0, n_vocab):
            length = len(paths[i]) if i in paths else 0
            begins[i + 1] = begins[i] + length
        self.begins = begins

        self.parser_size = parser.size()

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() == 3)
        x_type, t_type, w_type = in_types

        type_check.expect(
            x_type.dtype.kind == 'f',
            x_type.ndim == 2,
            t_type.dtype == numpy.int32,
            t_type.ndim == 1,
            x_type.shape[0] == t_type.shape[0],
            w_type.dtype == x_type.dtype,
            w_type.ndim == 2,
            w_type.shape[0] == self.parser_size,
            w_type.shape[1] == x_type.shape[1],
        )

    def device_resident_accept(self, visitor):
        super(BinaryHierarchicalSoftmaxFunction, self).device_resident_accept(
            visitor)
        self.paths = visitor.visit_array(self.paths)
        self.codes = visitor.visit_array(self.codes)
        self.begins = visitor.visit_array(self.begins)

    def forward_cpu(self, inputs):
        x, t, W = inputs

        loss = x.dtype.type(0.0)
        for ix, it in six.moves.zip(x, t):
            loss += self._forward_cpu_one(ix, it, W)
        return numpy.array(loss),

    def _forward_cpu_one(self, x, t, W):
        begin = self.begins[t]
        end = self.begins[t + 1]

        w = W[self.paths[begin:end]]
        wxy = w.dot(x) * self.codes[begin:end]
        loss = numpy.logaddexp(0.0, -wxy)  # == log(1 + exp(-wxy))
        return numpy.sum(loss)

    def backward_cpu(self, inputs, grad_outputs):
        x, t, W = inputs
        gloss, = grad_outputs
        gx = numpy.empty_like(x)
        gW = numpy.zeros_like(W)
        for i, (ix, it) in enumerate(six.moves.zip(x, t)):
            gx[i] = self._backward_cpu_one(ix, it, W, gloss, gW)
        return gx, None, gW

    def _backward_cpu_one(self, x, t, W, gloss, gW):
        begin = self.begins[t]
        end = self.begins[t + 1]

        path = self.paths[begin:end]
        w = W[path]
        wxy = w.dot(x) * self.codes[begin:end]
        g = -gloss * self.codes[begin:end] / (1.0 + numpy.exp(wxy))
        gx = g.dot(w)
        gw = g.reshape((g.shape[0], 1)).dot(x.reshape(1, x.shape[0]))
        gW[path] += gw
        return gx
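
    # Derivation note (added for clarity): each path entry contributes
    # ``log(1 + exp(-z))`` with ``z = code * w.dot(x)``, and
    # ``d/dz log(1 + exp(-z)) = -1 / (1 + exp(z))``.  Chaining through ``z``
    # yields the factor ``-gloss * codes / (1.0 + numpy.exp(wxy))`` used in
    # ``_backward_cpu_one`` above; the same expression appears in the GPU
    # kernel in ``backward_gpu`` below.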

    def forward_gpu(self, inputs):
        x, t, W = inputs
        max_length = cuda.reduce(
            'T t, raw T begins', 'T out', 'begins[t + 1] - begins[t]',
            'max(a, b)', 'out = a', '0',
            'binary_hierarchical_softmax_max_length')(t, self.begins)
        max_length = cuda.to_cpu(max_length)[()]

        length = max_length * x.shape[0]
        ls = cuda.cupy.empty((length,), dtype=x.dtype)
        n_in = x.shape[1]
        wxy = cuda.cupy.empty_like(ls)
        cuda.elementwise(
            '''raw T x, raw T w, raw int32 ts, raw int32 paths,
            raw T codes, raw int32 begins, int32 c, int32 max_length''',
            'T ls, T wxy',
            '''
            int ind = i / max_length;
            int offset = i - ind * max_length;
            int t = ts[ind];

            int begin = begins[t];
            int length = begins[t + 1] - begins[t];

            if (offset < length) {
              int p = begin + offset;
              int node = paths[p];

              T wx = 0;
              for (int j = 0; j < c; ++j) {
                int w_ind[] = {node, j};
                int x_ind[] = {ind, j};
                wx += w[w_ind] * x[x_ind];
              }
              wxy = wx * codes[p];
              ls = log(1 + exp(-wxy));
            } else {
              ls = 0;
            }
            ''',
            'binary_hierarchical_softmax_forward'
        )(x, W, t, self.paths, self.codes, self.begins, n_in, max_length, ls,
          wxy)
        self.max_length = max_length
        self.wxy = wxy
        return ls.sum(),

    def backward_gpu(self, inputs, grad_outputs):
        utils.nondeterministic('atomicAdd')
        x, t, W = inputs
        gloss, = grad_outputs

        n_in = x.shape[1]
        gx = cuda.cupy.zeros_like(x)
        gW = cuda.cupy.zeros_like(W)
        cuda.elementwise(
            '''T wxy, raw T x, raw T w, raw int32 ts, raw int32 paths,
            raw T codes, raw int32 begins, raw T gloss,
            int32 c, int32 max_length''',
            'raw T gx, raw T gw',
            '''
            int ind = i / max_length;
            int offset = i - ind * max_length;
            int t = ts[ind];

            int begin = begins[t];
            int length = begins[t + 1] - begins[t];

            if (offset < length) {
              int p = begin + offset;
              int node = paths[p];
              T code = codes[p];

              T g = -gloss[0] * code / (1.0 + exp(wxy));
              for (int j = 0; j < c; ++j) {
                int w_ind[] = {node, j};
                int x_ind[] = {ind, j};
                atomicAdd(&gx[x_ind], g * w[w_ind]);
                atomicAdd(&gw[w_ind], g * x[x_ind]);
              }
            }
            ''',
            'binary_hierarchical_softmax_bwd'
        )(self.wxy, x, W, t, self.paths, self.codes, self.begins, gloss, n_in,
          self.max_length, gx, gW)
        return gx, None, gW


class BinaryHierarchicalSoftmax(link.Link):

    """Hierarchical softmax layer over a binary tree.

    In natural language applications, the vocabulary size is too large to
    use the full softmax loss.
    Instead, the hierarchical softmax uses a product of sigmoid functions.
    It costs only :math:`O(\\log(n))` time on average, where :math:`n` is
    the vocabulary size.

    First, a user needs to prepare a binary tree in which each leaf
    corresponds to a word in the vocabulary.
    When a word :math:`x` is given, exactly one path exists from the root of
    the tree to the leaf of the word.
    Let :math:`\\mbox{path}(x) = ((e_1, b_1), \\dots, (e_m, b_m))` be the
    path of :math:`x`, where :math:`e_i` is the index of the :math:`i`-th
    internal node, and :math:`b_i \\in \\{-1, 1\\}` indicates the direction
    to move at the :math:`i`-th internal node (1 is left, and -1 is right,
    following the implementation of :class:`TreeParser`).
    Then, the probability of :math:`x` is given by:

    .. math::

       P(x) &= \\prod_{(e_i, b_i) \\in \\mbox{path}(x)}P(b_i | e_i) \\\\
            &= \\prod_{(e_i, b_i) \\in \\mbox{path}(x)}\\sigma(b_i x^\\top
               w_{e_i}),

    where :math:`\\sigma(\\cdot)` is the sigmoid function, and :math:`w` is
    a weight matrix.

    This function costs :math:`O(\\log(n))` time on average, as the expected
    length of a path is :math:`O(\\log(n))`, and :math:`O(n)` memory, as the
    number of internal nodes equals :math:`n - 1`.
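
    For example (a worked instance added for illustration), with the tree
    ``((1, 2), 3)`` the two internal nodes get indices 0 and 1, the path of
    word ``2`` is :math:`((e_1, b_1), (e_2, b_2)) = ((0, 1), (1, -1))`, and
    therefore

    .. math::

       P(2) = \\sigma(x^\\top w_0)\\sigma(-x^\\top w_1).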

    Args:
        in_size (int): Dimension of input vectors.
        tree: A binary tree made with tuples like ``((1, 2), 3)``.
        dtype (numpy.dtype): Type to use in computing.

    Attributes:
        W (~chainer.Variable): Weight parameter matrix.

    See: Hierarchical Probabilistic Neural Network Language Model [Morin+,
    AISTATS 2005].

    """

    def __init__(self, in_size, tree, dtype=None):
        # This function object is copied on every forward computation.
        super(BinaryHierarchicalSoftmax, self).__init__()
        dtype = chainer.get_dtype(dtype)
        self._func = BinaryHierarchicalSoftmaxFunction(tree, dtype)

        with self.init_scope():
            self.W = variable.Parameter(uniform.Uniform(1),
                                        (self._func.parser_size, in_size))

    def device_resident_accept(self, visitor):
        super(BinaryHierarchicalSoftmax, self).device_resident_accept(visitor)
        self._func.device_resident_accept(visitor)

    @staticmethod
    def create_huffman_tree(word_counts):
        """Makes a Huffman tree from a dictionary containing word counts.

        This method creates a binary Huffman tree, which is required for
        :class:`BinaryHierarchicalSoftmax`.
        For example, ``{0: 8, 1: 5, 2: 6, 3: 4}`` is converted to
        ``((3, 1), (2, 0))``.

        Args:
            word_counts (dict of int key and int or float values):
                Dictionary representing counts of words.

        Returns:
            A binary Huffman tree built from nested tuples whose leaves are
            the keys of ``word_counts``.

        """
        if not word_counts:
            raise ValueError('Empty vocabulary')

        q = six.moves.queue.PriorityQueue()
        # Add a unique id to each entry so that two entries with the same
        # count can still be compared.
        # Note that ``iteritems`` yields the entries in an arbitrary order.
        for uid, (w, c) in enumerate(six.iteritems(word_counts)):
            q.put((c, uid, w))

        # Repeatedly merge the two least frequent subtrees until one remains.
        while q.qsize() >= 2:
            (count1, id1, word1) = q.get()
            (count2, id2, word2) = q.get()
            count = count1 + count2
            tree = (word1, word2)
            q.put((count, min(id1, id2), tree))

        return q.get()[2]

    def forward(self, x, t):
        """Computes the loss value for given input and ground truth labels.

        Args:
            x (~chainer.Variable): Input to the classifier at each node.
            t (~chainer.Variable): Batch of ground truth labels.

        Returns:
            ~chainer.Variable: Loss value.

        """
        f = copy.copy(self._func)  # creates a copy of the function node
        return f(x, t, self.W)
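
# Minimal usage sketch (illustrative only; the vocabulary, dimensions, and
# random data below are made up):
#
#     counts = {0: 8, 1: 5, 2: 6, 3: 4}
#     tree = BinaryHierarchicalSoftmax.create_huffman_tree(counts)
#     model = BinaryHierarchicalSoftmax(in_size=100, tree=tree)
#     x = numpy.random.rand(16, 100).astype(numpy.float32)
#     t = numpy.random.randint(0, 4, size=16).astype(numpy.int32)
#     loss = model(x, t)  # scalar Variable holding the summed loss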