import copy

import numpy
import six

import chainer
from chainer.backends import cuda
from chainer import device_resident
from chainer import function
from chainer.initializers import uniform
from chainer import link
from chainer import utils
from chainer.utils import type_check
from chainer import variable

class TreeParser(object):

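    """Parses a binary tree of tuples into per-leaf paths and codes.

    Internal nodes are numbered in pre-order. For example, parsing the
    toy tree ``((0, 1), 2)`` yields
    ``paths == {0: [0, 1], 1: [0, 1], 2: [0]}`` and
    ``codes == {0: [1., 1.], 1: [1., -1.], 2: [-1.]}``, where a code of
    ``1`` marks the left branch and ``-1`` the right branch.
    """
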
    def __init__(self, dtype):
        self.next_id = 0
        self.dtype = dtype

    def size(self):
        return self.next_id

    def get_paths(self):
        return self.paths

    def get_codes(self):
        return self.codes

    def parse(self, tree):
        self.next_id = 0
        self.path = []
        self.code = []
        self.paths = {}
        self.codes = {}
        self._parse(tree)

        assert len(self.path) == 0
        assert len(self.code) == 0
        assert len(self.paths) == len(self.codes)

    def _parse(self, node):
        if isinstance(node, tuple):
            # internal node
            if len(node) != 2:
                raise ValueError(
                    'All internal nodes must have two child nodes')
            left, right = node
            self.path.append(self.next_id)
            self.next_id += 1
            self.code.append(1.0)
            self._parse(left)

            self.code[-1] = -1.0
            self._parse(right)

            self.path.pop()
            self.code.pop()

        else:
            # leaf node
            self.paths[node] = numpy.array(self.path, dtype=numpy.int32)
            self.codes[node] = numpy.array(self.code, dtype=self.dtype)


class BinaryHierarchicalSoftmaxFunction(
        device_resident.DeviceResident, function.Function):

    """Hierarchical softmax function based on a binary tree.

    This function object should be allocated beforehand, and be copied on
    every forward computation, since the initializer parses the given tree.
    See the implementation of :class:`BinaryHierarchicalSoftmax` for details.

    Args:
        tree: A binary tree made with tuples like ``((1, 2), 3)``.

    .. seealso::
       See :class:`BinaryHierarchicalSoftmax` for details.

    """

    def __init__(self, tree, dtype):
        device_resident.DeviceResident.__init__(self)

        parser = TreeParser(dtype)
        parser.parse(tree)
        paths = parser.get_paths()
        codes = parser.get_codes()
        n_vocab = max(paths.keys()) + 1

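        # The per-leaf arrays are flattened CSR-style below: the path and
        # code entries for word ``t`` live in
        # ``paths[begins[t]:begins[t + 1]]`` and the matching slice of
        # ``codes``.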
        self.paths = numpy.concatenate(
            [paths[i] for i in range(n_vocab) if i in paths])
        self.codes = numpy.concatenate(
            [codes[i] for i in range(n_vocab) if i in codes])
        begins = numpy.empty((n_vocab + 1,), dtype=numpy.int32)
        begins[0] = 0
        for i in range(0, n_vocab):
            length = len(paths[i]) if i in paths else 0
            begins[i + 1] = begins[i] + length
        self.begins = begins

        self.parser_size = parser.size()

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() == 3)
        x_type, t_type, w_type = in_types

        type_check.expect(
            x_type.dtype.kind == 'f',
            x_type.ndim == 2,
            t_type.dtype == numpy.int32,
            t_type.ndim == 1,
            x_type.shape[0] == t_type.shape[0],
            w_type.dtype == x_type.dtype,
            w_type.ndim == 2,
            w_type.shape[0] == self.parser_size,
            w_type.shape[1] == x_type.shape[1],
        )

    def device_resident_accept(self, visitor):
        super(BinaryHierarchicalSoftmaxFunction, self).device_resident_accept(
            visitor)
        self.paths = visitor.visit_array(self.paths)
        self.codes = visitor.visit_array(self.codes)
        self.begins = visitor.visit_array(self.begins)

    def forward_cpu(self, inputs):
        x, t, W = inputs

        # Accumulate the per-sample losses over the minibatch.
        loss = x.dtype.type(0.0)
        for ix, it in six.moves.zip(x, t):
            loss += self._forward_cpu_one(ix, it, W)
        return numpy.array(loss),

    def _forward_cpu_one(self, x, t, W):
        begin = self.begins[t]
        end = self.begins[t + 1]

        w = W[self.paths[begin:end]]
        wxy = w.dot(x) * self.codes[begin:end]
        loss = numpy.logaddexp(0.0, -wxy)  # == log(1 + exp(-wxy))
        return numpy.sum(loss)

    def backward_cpu(self, inputs, grad_outputs):
        x, t, W = inputs
        gloss, = grad_outputs
        gx = numpy.empty_like(x)
        gW = numpy.zeros_like(W)
        for i, (ix, it) in enumerate(six.moves.zip(x, t)):
            gx[i] = self._backward_cpu_one(ix, it, W, gloss, gW)
        return gx, None, gW

    def _backward_cpu_one(self, x, t, W, gloss, gW):
        begin = self.begins[t]
        end = self.begins[t + 1]

        path = self.paths[begin:end]
        w = W[path]
        wxy = w.dot(x) * self.codes[begin:end]
        g = -gloss * self.codes[begin:end] / (1.0 + numpy.exp(wxy))
        gx = g.dot(w)
        # Outer product of g and x: gradient w.r.t. the visited rows of W.
        gw = g.reshape((g.shape[0], 1)).dot(x.reshape(1, x.shape[0]))
        gW[path] += gw
        return gx

    def forward_gpu(self, inputs):
        x, t, W = inputs
        max_length = cuda.reduce(
            'T t, raw T begins', 'T out', 'begins[t + 1] - begins[t]',
            'max(a, b)', 'out = a', '0',
            'binary_hierarchical_softmax_max_length')(t, self.begins)
        max_length = cuda.to_cpu(max_length)[()]

        length = max_length * x.shape[0]
        ls = cuda.cupy.empty((length,), dtype=x.dtype)
        n_in = x.shape[1]
        wxy = cuda.cupy.empty_like(ls)
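        # Each kernel thread i handles one (sample, path offset) pair:
        # ``ind = i / max_length`` selects the sample and ``offset`` the
        # position along that sample's path. Threads past the end of a
        # path contribute zero loss.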
        cuda.elementwise(
            '''raw T x, raw T w, raw int32 ts, raw int32 paths,
            raw T codes, raw int32 begins, int32 c, int32 max_length''',
            'T ls, T wxy',
            '''
            int ind = i / max_length;
            int offset = i - ind * max_length;
            int t = ts[ind];

            int begin = begins[t];
            int length = begins[t + 1] - begins[t];

            if (offset < length) {
              int p = begin + offset;
              int node = paths[p];

              T wx = 0;
              for (int j = 0; j < c; ++j) {
                int w_ind[] = {node, j};
                int x_ind[] = {ind, j};
                wx += w[w_ind] * x[x_ind];
              }
              wxy = wx * codes[p];
              ls = log(1 + exp(-wxy));
            } else {
              ls = 0;
            }
            ''',
            'binary_hierarchical_softmax_forward'
        )(x, W, t, self.paths, self.codes, self.begins, n_in, max_length, ls,
          wxy)
        self.max_length = max_length
        self.wxy = wxy
        return ls.sum(),

    def backward_gpu(self, inputs, grad_outputs):
        utils.nondeterministic('atomicAdd')
        x, t, W = inputs
        gloss, = grad_outputs

        n_in = x.shape[1]
        gx = cuda.cupy.zeros_like(x)
        gW = cuda.cupy.zeros_like(W)
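        # Mirrors the indexing of the forward kernel. Different threads may
        # update the same rows of gx and gw, hence the atomicAdd calls
        # (which make this backward pass nondeterministic).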
        cuda.elementwise(
            '''T wxy, raw T x, raw T w, raw int32 ts, raw int32 paths,
            raw T codes, raw int32 begins, raw T gloss,
            int32 c, int32 max_length''',
            'raw T gx, raw T gw',
            '''
            int ind = i / max_length;
            int offset = i - ind * max_length;
            int t = ts[ind];

            int begin = begins[t];
            int length = begins[t + 1] - begins[t];

            if (offset < length) {
              int p = begin + offset;
              int node = paths[p];
              T code = codes[p];

              T g = -gloss[0] * code / (1.0 + exp(wxy));
              for (int j = 0; j < c; ++j) {
                int w_ind[] = {node, j};
                int x_ind[] = {ind, j};
                atomicAdd(&gx[x_ind], g * w[w_ind]);
                atomicAdd(&gw[w_ind], g * x[x_ind]);
              }
            }
            ''',
            'binary_hierarchical_softmax_bwd'
        )(self.wxy, x, W, t, self.paths, self.codes, self.begins, gloss, n_in,
          self.max_length, gx, gW)
        return gx, None, gW

class BinaryHierarchicalSoftmax(link.Link):

    """Hierarchical softmax layer over a binary tree.

    In natural language applications, the vocabulary size is often too
    large to use the plain softmax loss.
    Instead, the hierarchical softmax uses a product of sigmoid functions.
    It costs only :math:`O(\\log(n))` time on average, where :math:`n` is
    the vocabulary size.

    First, a user needs to prepare a binary tree in which each leaf
    corresponds to a word in the vocabulary.
    When a word :math:`x` is given, exactly one path from the root of the
    tree to the leaf of the word exists.
    Let :math:`\\mbox{path}(x) = ((e_1, b_1), \\dots, (e_m, b_m))` be the path
    of :math:`x`, where :math:`e_i` is the index of the :math:`i`-th internal
    node, and :math:`b_i \\in \\{-1, 1\\}` indicates the direction to move at
    the :math:`i`-th internal node (1 is left, and -1 is right).
    Then, the probability of :math:`x` is given as below:

    .. math::

       P(x) &= \\prod_{(e_i, b_i) \\in \\mbox{path}(x)}P(b_i | e_i)  \\\\
            &= \\prod_{(e_i, b_i) \\in \\mbox{path}(x)}\\sigma(b_i x^\\top
               w_{e_i}),

    where :math:`\\sigma(\\cdot)` is the sigmoid function, and :math:`w` is a
    weight matrix.

    This function costs :math:`O(\\log(n))` time as the average length of
    paths is :math:`O(\\log(n))`, and :math:`O(n)` memory as the number of
    internal nodes equals :math:`n - 1`.

    Args:
        in_size (int): Dimension of input vectors.
        tree: A binary tree made with tuples like ``((1, 2), 3)``.
        dtype (numpy.dtype): Type to use in computing.

    Attributes:
        W (~chainer.Variable): Weight parameter matrix.

    See: Hierarchical Probabilistic Neural Network Language Model [Morin+,
    AISTATS 2005].

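    .. admonition:: Example

       A minimal usage sketch with a toy tree, assuming
       ``import numpy as np`` and the default float32 dtype:

       >>> tree = ((0, 1), (2, 3))
       >>> model = BinaryHierarchicalSoftmax(5, tree)
       >>> x = np.random.rand(2, 5).astype(np.float32)
       >>> t = np.array([0, 2], dtype=np.int32)
       >>> loss = model(x, t)  # scalar Variable holding the summed loss
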
    """

    def __init__(self, in_size, tree, dtype=None):
        # This function object is copied on every forward computation.
        super(BinaryHierarchicalSoftmax, self).__init__()
        dtype = chainer.get_dtype(dtype)
        self._func = BinaryHierarchicalSoftmaxFunction(tree, dtype)

        with self.init_scope():
            self.W = variable.Parameter(uniform.Uniform(1),
                                        (self._func.parser_size, in_size))

    def device_resident_accept(self, visitor):
        super(BinaryHierarchicalSoftmax, self).device_resident_accept(visitor)
        self._func.device_resident_accept(visitor)

    @staticmethod
    def create_huffman_tree(word_counts):
        """Makes a Huffman tree from a dictionary containing word counts.

        This method creates a binary Huffman tree, which is required for
        :class:`BinaryHierarchicalSoftmax`.
        For example, ``{0: 8, 1: 5, 2: 6, 3: 4}`` is converted to
        ``((3, 1), (2, 0))``.

        Args:
            word_counts (dict of int key and int or float values):
                Dictionary representing counts of words.

        Returns:
            Binary Huffman tree with tuples and keys of ``word_counts``.

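        A doctest-style sketch; the exact nesting can depend on the
        iteration order of ``word_counts``:

        >>> BinaryHierarchicalSoftmax.create_huffman_tree(
        ...     {0: 8, 1: 5, 2: 6, 3: 4})
        ((3, 1), (2, 0))
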
        """
        if not word_counts:
            raise ValueError('Empty vocabulary')

        q = six.moves.queue.PriorityQueue()
        # Add a unique id to each entry so that we can compare two entries
        # with the same counts.
        # Note that iteritems orders the entries arbitrarily.
        for uid, (w, c) in enumerate(six.iteritems(word_counts)):
            q.put((c, uid, w))

        while q.qsize() >= 2:
            (count1, id1, word1) = q.get()
            (count2, id2, word2) = q.get()
            count = count1 + count2
            tree = (word1, word2)
            q.put((count, min(id1, id2), tree))

        return q.get()[2]

    def forward(self, x, t):
        """Computes the loss value for given input and ground truth labels.

        Args:
            x (~chainer.Variable): Input to the classifier at each node.
            t (~chainer.Variable): Batch of ground truth labels.

        Returns:
            ~chainer.Variable: Loss value.

        """
        f = copy.copy(self._func)  # creates a copy of the function object
        return f(x, t, self.W)