from __future__ import (division, absolute_import, print_function)
import os
import sys
import theano.tensor as T
from theano import config
from theano import gof
from theano.gof import local_optimizer
from theano.gof.cmodule import GCC_compiler
from theano.tensor.opt import register_canonicalize
from theano.tensor.extra_ops import cpu_contiguous
from theano.gradient import grad_undefined


def _ctc_find_lib():
    """
    Find the directory that contains libwarpctc.so
    """
    if config.ctc.root != '':
        for lib_dir in ["build", "lib", "lib64"]:
            lib_path = os.path.join(config.ctc.root, lib_dir)
            if os.path.isdir(lib_path) and os.path.exists(lib_path):
                lib_found = os.path.exists(os.path.join(lib_path, "libwarpctc.so"))
                if lib_found:
                    return lib_path
    return None


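# ``config.ctc.root`` is normally pointed at the top of a built warp-ctc tree.
# A sketch of the usual setup (the path below is purely illustrative), either
# in ``.theanorc``::
#
#     [ctc]
#     root = /path/to/warp-ctc
#
# or via ``THEANO_FLAGS=ctc.root=/path/to/warp-ctc``.  _ctc_find_lib() then
# searches the build/, lib/ and lib64/ subdirectories of that root for
# libwarpctc.so.

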
def _ctc_check_compile(ctc_lib_path):
    preambule = """
#include <string.h>
#include "ctc.h"
"""

    body = """
ctcOptions options;
memset(&options, 0, sizeof(ctcOptions));
options.loc = CTC_CPU;
options.num_threads = 1;
"""

    params = ['-I%s' % (os.path.dirname(__file__))]
    if ctc_lib_path is not None:
        params.extend(["-I%s" % (os.path.join(config.ctc.root, "include"))])
        params.extend(["-L%s" % (ctc_lib_path)])
    params.extend(["-l", "warpctc"])
    compiler_res = GCC_compiler.try_flags(
        params, preambule=preambule, body=body,
        try_run=False, output=True)

    avail, out, err = compiler_res if isinstance(compiler_res, tuple) \
        else (compiler_res, None, None)
    if not avail:
        return False, ("cannot compile with warp-ctc. "
                       "We got this error:\n" + str(err))
    return True, None


def ctc_present():
    if ctc_present.avail is not None:
        return ctc_present.avail
    ctc_lib_path = _ctc_find_lib()
    ctc_present.path = ctc_lib_path
    ctc_present.avail, ctc_present.msg = _ctc_check_compile(ctc_present.path)
    return ctc_present.avail


ctc_present.avail = None
ctc_present.msg = None
ctc_present.path = None


def ctc_available():
    if os.name == 'nt':
        ctc_available.msg = ('Windows platforms are currently not supported '
                             'by the underlying CTC library (warp-ctc).')
        return False
    elif not ctc_present():
        ctc_available.msg = ctc_present.msg
        return False

    ctc_available.path = ctc_present.path
    return True


ctc_available.msg = None
ctc_available.path = None
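
# Callers typically guard on ctc_available() before constructing the Op; a
# minimal sketch (the error handling shown is illustrative, not part of this
# module):
#
#     if not ctc_available():
#         raise RuntimeError("warp-ctc is unusable: %s" % (ctc_available.msg,))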


class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
    """
    CTC loss function wrapper.

    Notes
    -----
    Using the wrapper requires that Baidu's warp-ctc library be installed.
    If the warp-ctc library is not on your compiler's default library path,
    you must set the configuration variable ``config.ctc.root`` appropriately.

    Parameters
    ----------
    compute_grad
        If set to True, enables the computation of gradients of the CTC loss function.
    """
    __props__ = ('compute_grad',)

    _cop_num_inputs = 3
    _cop_num_outputs = 2

    func_file = os.path.join('c_code', 'ctc_wrapper.c')
    func_name = "APPLY_SPECIFIC(ctc_cost_cpu)"

    def __init__(self, compute_grad=True, openmp=None):
        if not ctc_available():
            raise RuntimeError('Baidu CTC is not available and '
                               'ConnectionistTemporalClassification Op '
                               'cannot be constructed.')

        gof.COp.__init__(self, self.func_file, self.func_name)
        gof.OpenMPOp.__init__(self, openmp=openmp)

        self.compute_grad = compute_grad
        # Return only the cost. Gradient will be returned by grad()
        self.default_output = 0

    def c_lib_dirs(self):
        lib_dirs = []
        if ctc_available.path is not None:
            lib_dirs += [ctc_available.path]
        return lib_dirs

    def c_compile_args(self):
        if ctc_available.path is not None:
            # Quote the rpath when it contains spaces (quoting is not applied
            # on Darwin).
            if sys.platform != 'darwin' and ' ' in ctc_available.path:
                return ['-Wl,-rpath,"' + ctc_available.path + '"']
            else:
                return ['-Wl,-rpath,' + ctc_available.path]
        return []

    def c_libraries(self):
        return ["warpctc"]

    def c_header_dirs(self):
        header_dirs = []
        if config.ctc.root != '':
            # We assume here that the header is available at the include directory
            # of the CTC root directory.
            header_dirs += [os.path.join(config.ctc.root, "include")]
        return header_dirs

    def c_headers(self):
        return ["ctc.h"] + gof.OpenMPOp.c_headers(self)

    def make_node(self, activations, labels, input_lengths):
        t_activations = T.as_tensor_variable(activations)
        # Ensure activations array is C-contiguous
        t_activations = cpu_contiguous(t_activations)

        t_labels = T.as_tensor_variable(labels)
        t_input_lengths = T.as_tensor_variable(input_lengths)

        if t_activations.type.dtype != 'float32':
            raise TypeError('activations must use the float32 type!')

        if t_activations.ndim != 3:
            raise ValueError('activations must have 3 dimensions.')

        if t_labels.type.dtype != 'int32':
            raise TypeError('labels must use the int32 type!')

        if t_labels.ndim != 2:
            raise ValueError('labels must have 2 dimensions.')

        if t_input_lengths.type.dtype != 'int32':
            raise TypeError('input_lengths must use the int32 type!')

        if t_input_lengths.ndim != 1:
            raise ValueError('input_lengths must have 1 dimension.')

        costs = T.fvector(name="ctc_cost")
        outputs = [costs]
        if self.compute_grad:
            gradients = T.ftensor3(name="ctc_grad")
            outputs += [gradients]

        return gof.Apply(self, inputs=[t_activations, t_labels, t_input_lengths],
                         outputs=outputs)

    def L_op(self, inputs, outputs, output_grads):
        # The Op precomputes the gradient of each example's cost w.r.t. the
        # activations and returns it as the second output, with shape (t, m, p).
        assert self.compute_grad and len(outputs) == 2
        gradients = outputs[1]
        assert gradients is not None

        # output_grads[0] has shape (m,): the gradient of the objective w.r.t.
        # each example's cost. batched_dot scales each example's precomputed
        # gradient by that factor; the dimshuffles move the minibatch axis to
        # the front for batched_dot and back afterwards.
        grad_op = output_grads[0]
        total_grad = T.basic.batched_dot(grad_op, gradients.dimshuffle(1, 0, 2)).dimshuffle(1, 0, 2)
        return [total_grad,
                grad_undefined(self, 1, inputs[1]),
                grad_undefined(self, 2, inputs[2])]


def ctc(activations, labels, input_lengths):
    """
    Compute CTC loss function.

    Notes
    -----
    Using the loss function requires that Baidu's warp-ctc library be installed.
    If the warp-ctc library is not on the compiler's default library path, the
    configuration variable ``config.ctc.root`` must be properly set.

    Parameters
    ----------
    activations
        Three-dimensional tensor, which has a shape of (t, m, p), where
        t is the time index, m is the minibatch index, and p is the index
        over the probabilities of each symbol in the alphabet. The memory
        layout is assumed to be C-ordered: dimensions run from the slowest
        changing (t) to the fastest changing (p), left to right.
    labels
        A 2-D tensor of all the labels for the minibatch. Each row holds one
        sequence of target labels. Negative values are assumed to be padding
        and are ignored. The blank symbol is assumed to have index 0 in the
        alphabet.
    input_lengths
        A 1-D tensor with the number of time steps for each sequence in
        the minibatch.

    Returns
    -------
    1-D array
        Cost of each example in the minibatch.
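
    Examples
    --------
    A minimal symbolic sketch (it requires warp-ctc to be installed; the
    variable names are illustrative, not part of the API):

    >>> import theano
    >>> import theano.tensor as T
    >>> activations = T.ftensor3('activations')     # (t, m, p), float32
    >>> labels = T.imatrix('labels')                # (m, max label length), int32
    >>> input_lengths = T.ivector('input_lengths')  # (m,), int32
    >>> costs = ctc(activations, labels, input_lengths)
    >>> grad = theano.grad(costs.mean(), wrt=activations)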
231    """
232    return ConnectionistTemporalClassification()(activations, labels, input_lengths)
233
234
235# Disable gradient computation if not needed
236@register_canonicalize('fast_compile')
237@local_optimizer([ConnectionistTemporalClassification])
238def local_ctc_no_grad(node):
239    if isinstance(node.op, ConnectionistTemporalClassification):
240        if len(node.outputs) > 1:
241            if len(node.outputs[1].clients) == 0:   # gradient is not used
242                return [ConnectionistTemporalClassification(compute_grad=False)(*node.inputs), None]
243    return False
244