from __future__ import absolute_import, print_function, division
from functools import partial
import numpy as np

import theano
from theano import config, shared

from theano.gradient import DisconnectedType
from theano.gof.null_type import NullType
from theano.compile import function

from theano import tensor as T
from theano.tensor.shared_randomstreams import RandomStreams

from theano.compile.builders import OpFromGraph

from theano.tests import unittest_tools

test_params = unittest_tools.parameterized.expand(
    [(OpFromGraph,), (partial(OpFromGraph, inline=True),)])


class T_OpFromGraph(unittest_tools.InferShapeTester):

    @test_params
    def test_straightforward(self, cls_ofg):
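        # Wrap a simple elementwise expression in an OpFromGraph and check
        # that applying it twice with permuted inputs compiles and evaluates
        # to the expected values.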
        x, y, z = T.matrices('xyz')
        e = x + y * z
        op = cls_ofg([x, y, z], [e])
        # (1+3*5=array of 16) - (3+1*5=array of 8)
        f = op(x, y, z) - op(y, z, x)

        fn = function([x, y, z], f)
        xv = np.ones((2, 2), dtype=config.floatX)
        yv = np.ones((2, 2), dtype=config.floatX) * 3
        zv = np.ones((2, 2), dtype=config.floatX) * 5
        fn(xv, yv, zv)
        assert np.all(8.0 == fn(xv, yv, zv))
        assert np.all(8.0 == fn(xv, yv, zv))

    @test_params
    def test_size_changes(self, cls_ofg):
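        # Wrap a dot product so the output shape differs from the input
        # shapes, and check that nested applications compute the right
        # result and shape.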
        x, y, z = T.matrices('xyz')
        e = T.dot(x, y)
        op = cls_ofg([x, y], [e])
        f = op(x, op(y, z))
        fn = function([x, y, z], f)
        xv = np.ones((2, 3), dtype=config.floatX)
        yv = np.ones((3, 4), dtype=config.floatX) * 3
        zv = np.ones((4, 5), dtype=config.floatX) * 5
        res = fn(xv, yv, zv)
        assert res.shape == (2, 5)
        assert np.all(180.0 == res)
        res = fn(xv, yv, zv)
        assert res.shape == (2, 5)
        assert np.all(180.0 == res)

    @test_params
    def test_grad(self, cls_ofg):
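        # Check that T.grad can differentiate through an OpFromGraph output.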
        x, y, z = T.matrices('xyz')
        e = x + y * z
        op = cls_ofg([x, y, z], [e])
        f = op(x, y, z)
        f = f - T.grad(T.sum(f), y)
        fn = function([x, y, z], f)
        xv = np.ones((2, 2), dtype=config.floatX)
        yv = np.ones((2, 2), dtype=config.floatX) * 3
        zv = np.ones((2, 2), dtype=config.floatX) * 5
        assert np.all(11.0 == fn(xv, yv, zv))

    @test_params
    def test_grad_grad(self, cls_ofg):
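        # Check that an OpFromGraph output can be differentiated twice.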
        x, y, z = T.matrices('xyz')
        e = x + y * z
        op = cls_ofg([x, y, z], [e])
        f = op(x, y, z)
        f = f - T.grad(T.sum(f), y)
        f = f - T.grad(T.sum(f), y)
        fn = function([x, y, z], f)
        xv = np.ones((2, 2), dtype=config.floatX)
        yv = np.ones((2, 2), dtype=config.floatX) * 3
        zv = np.ones((2, 2), dtype=config.floatX) * 5
        assert np.allclose(6.0, fn(xv, yv, zv))

    @test_params
    def test_shared(self, cls_ofg):
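        # The inner graph uses a shared variable, which becomes an implicit
        # input of the OpFromGraph; it cancels out in the difference below.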
        x, y, z = T.matrices('xyz')
        s = shared(np.random.rand(2, 2).astype(config.floatX))
        e = x + y * z + s
        op = cls_ofg([x, y, z], [e])
        # (1+3*5+s) - (3+1*5+s) = array of 8
        f = op(x, y, z) - op(y, z, x)

        fn = function([x, y, z], f)
        xv = np.ones((2, 2), dtype=config.floatX)
        yv = np.ones((2, 2), dtype=config.floatX) * 3
        zv = np.ones((2, 2), dtype=config.floatX) * 5
        assert np.allclose(8.0, fn(xv, yv, zv))
        assert np.allclose(8.0, fn(xv, yv, zv))

    @test_params
    def test_shared_grad(self, cls_ofg):
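        # Check gradients through an OpFromGraph whose inner graph uses a
        # shared variable, both wrt an explicit input and wrt the shared
        # variable itself.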
        x, y, z = T.matrices('xyz')
        s = shared(np.random.rand(2, 2).astype(config.floatX))
        e = x + y * z + s
        op = cls_ofg([x, y, z], [e])
        f = op(x, y, z)
        f = f - T.grad(T.sum(f), y)
        fn = function([x, y, z], f)
        xv = np.ones((2, 2), dtype=config.floatX)
        yv = np.ones((2, 2), dtype=config.floatX) * 3
        zv = np.ones((2, 2), dtype=config.floatX) * 5
        assert np.allclose(11.0 + s.get_value(), fn(xv, yv, zv))

        # now take the gradient wrt the shared variable
        f = op(x, y, z)
        f = f - T.grad(T.sum(f), s)
        fn = function([x, y, z], f)
        assert np.allclose(15.0 + s.get_value(),
                           fn(xv, yv, zv))

    @test_params
    def test_grad_override(self, cls_ofg):
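        # Check grad_overrides given as a plain callable, as an OpFromGraph
        # instance, and as a per-input list mixing callables, 'default',
        # NullType and DisconnectedType entries.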
        x, y = T.vectors('xy')

        def go(inps, gs):
            x, y = inps
            g, = gs
            return [g * y * 2, g * x * 1.5]

        dedz = T.vector('dedz')
        op_mul_grad = cls_ofg([x, y, dedz], go([x, y], [dedz]))

        op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
        op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)

        # single override case (function or OfG instance)
        xx, yy = T.vector('xx'), T.vector('yy')
        for op in [op_mul, op_mul2]:
            zz = T.sum(op(xx, yy))
            dx, dy = T.grad(zz, [xx, yy])
            fn = function([xx, yy], [dx, dy])
            xv = np.random.rand(16).astype(config.floatX)
            yv = np.random.rand(16).astype(config.floatX)
            dxv, dyv = fn(xv, yv)
            assert np.allclose(yv * 2, dxv)
            assert np.allclose(xv * 1.5, dyv)

        # list override case
        def go1(inps, gs):
            x, w, b = inps
            g = gs[0]
            return g * w * 2

        def go2(inps, gs):
            x, w, b = inps
            g = gs[0]
            return g * x * 1.5

        w, b = T.vectors('wb')
        # the 3rd gradient is left as the default (no override)
        op_linear = cls_ofg(
            [x, w, b], [x * w + b], grad_overrides=[go1, go2, 'default'])
        xx, ww, bb = T.vector('xx'), T.vector('ww'), T.vector('bb')
        zz = T.sum(op_linear(xx, ww, bb))
        dx, dw, db = T.grad(zz, [xx, ww, bb])
        fn = function([xx, ww, bb], [dx, dw, db])
        xv = np.random.rand(16).astype(config.floatX)
        wv = np.random.rand(16).astype(config.floatX)
        bv = np.random.rand(16).astype(config.floatX)
        dxv, dwv, dbv = fn(xv, wv, bv)
        assert np.allclose(wv * 2, dxv)
        assert np.allclose(xv * 1.5, dwv)
        assert np.allclose(np.ones(16, dtype=config.floatX), dbv)

        # NullType and DisconnectedType overrides
        op_linear2 = cls_ofg(
            [x, w, b], [x * w + b],
            grad_overrides=[go1, NullType()(), DisconnectedType()()])
        zz2 = T.sum(op_linear2(xx, ww, bb))
        dx2, dw2, db2 = T.grad(
            zz2, [xx, ww, bb],
            return_disconnected='Disconnected',
            disconnected_inputs='ignore',
            null_gradients='return')
        assert isinstance(dx2.type, T.TensorType)
        assert dx2.ndim == 1
        assert isinstance(dw2.type, NullType)
        assert isinstance(db2.type, DisconnectedType)

    @test_params
    def test_lop_override(self, cls_ofg):
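        # Check lop_overrides given as a callable and as an OpFromGraph
        # instance, against a manually doubled sigmoid gradient.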
        x = T.vector()
        y = 1. / (1. + T.exp(-x))

        def lop_ov(inps, outs, grads):
            y_, = outs
            dedy_, = grads
            return [2. * y_ * (1. - y_) * dedy_]

        y_, dedy = T.vector(), T.vector()
        op_lop_ov = cls_ofg([x, y_, dedy], [2. * y_ * (1. - y_) * dedy])

        xx = T.vector()
        yy1 = T.sum(T.nnet.sigmoid(xx))
        gyy1 = 2. * T.grad(yy1, xx)

        for ov in [lop_ov, op_lop_ov]:
            op = cls_ofg([x], [y], lop_overrides=ov)
            yy2 = T.sum(op(xx))
            gyy2 = T.grad(yy2, xx)
            fn = function([xx], [gyy1, gyy2])

            xval = np.random.rand(32).astype(config.floatX)
            y1val, y2val = fn(xval)
            assert np.allclose(y1val, y2val)

    @test_params
    def test_rop(self, cls_ofg):
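        # Check that T.Rop works through an OpFromGraph wrapping a dot
        # product, by comparing against the R-operator computed by hand.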
        a = T.vector()
        M = T.matrix()
        b = T.dot(a, M)
        op_matmul = cls_ofg([a, M], [b])
        x = T.vector()
        W = T.matrix()
        y = op_matmul(x, W)
        du = T.vector()
        dv = T.Rop(y, x, du)
        fn = function([x, W, du], dv)
        xval = np.random.rand(16).astype(config.floatX)
        Wval = np.random.rand(16, 16).astype(config.floatX)
        duval = np.random.rand(16).astype(config.floatX)
        dvval = np.dot(duval, Wval)
        dvval2 = fn(xval, Wval, duval)
        assert np.allclose(dvval2, dvval)

    @test_params
    def test_rop_override(self, cls_ofg):
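        # Check rop_overrides given as a callable and as an OpFromGraph
        # instance.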
        x, y = T.vectors('xy')

        def ro(inps, epts):
            x, y = inps
            u, v = epts
            return [u * y * 2. + x * v * 1.5]

        u, v = T.vectors('uv')
        op_mul_rop = cls_ofg([x, y, u, v], ro([x, y], [u, v]))
        op_mul = cls_ofg([x, y], [x * y], rop_overrides=ro)
        op_mul2 = cls_ofg([x, y], [x * y], rop_overrides=op_mul_rop)

        # single override case
        xx, yy = T.vector('xx'), T.vector('yy')
        du, dv = T.vector('du'), T.vector('dv')
        for op in [op_mul, op_mul2]:
            zz = op(xx, yy)
            dw = T.Rop(zz, [xx, yy], [du, dv])
            fn = function([xx, yy, du, dv], dw)
            vals = np.random.rand(4, 32).astype(config.floatX)
            dwval = fn(*vals)
            assert np.allclose(
                dwval, vals[0] * vals[3] * 1.5 + vals[1] * vals[2] * 2.)

        # TODO list override case

    @test_params
    def test_connection_pattern_override(self, cls_ofg):
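        # Check that a user-supplied connection_pattern is honoured when the
        # gradient override keeps the backward path to x but disconnects y.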
        x, y = T.vectors('xy')

        def f1(x, y):
            del x
            # x is unused in the output, but we know how to backpropagate
            # to it for some reason, and we don't care about the gradient
            # wrt y.
            return y + T.round(y)

        def f1_back(inputs, output_gradients):
            return [
                output_gradients[0],
                theano.gradient.disconnected_type()]

        op = cls_ofg(
            inputs=[x, y],
            outputs=[f1(x, y)],
            grad_overrides=f1_back,
            connection_pattern=[[True], [False]],  # This is new
            on_unused_input='ignore')  # This is new

        c = op(x, y)

        g1 = theano.grad(c.sum(), x)

        out = g1.eval({
            x: np.ones((5,), dtype=np.float32),
            y: np.ones((5,), dtype=np.float32)})
        assert np.allclose(out, [1.] * 5)

    @test_params
    def test_nested(self, cls_ofg):
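        # Nest two multi-output OpFromGraph instances; the second inverts the
        # first, so the round trip should recover the original inputs.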
        x, y = T.vectors('xy')
        u, v = x + y, x - y
        op_ft = cls_ofg([x, y], [u, v])
        op_ift = cls_ofg([x, y], [u / 2, v / 2])

        xx, yy = T.vector('xx'), T.vector('yy')
        xx2, yy2 = op_ift(*op_ft(xx, yy))
        fn = function([xx, yy], [xx2, yy2])

        xv = np.random.rand(16).astype(config.floatX)
        yv = np.random.rand(16).astype(config.floatX)
        xv2, yv2 = fn(xv, yv)
        assert np.allclose(xv, xv2)
        assert np.allclose(yv, yv2)

    @test_params
    def test_connection_pattern(self, cls_ofg):
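        # Check connection_pattern for a multi-output inner graph, for nested
        # OpFromGraph instances, and for an inner graph whose random state is
        # an implicit input.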
        # Basic case
        x, y, z = T.matrices('xyz')
        out1 = x * y
        out2 = y * z

        op1 = cls_ofg([x, y, z], [out1, out2])
        results = op1.connection_pattern(None)
        expect_result = [[True, False],
                         [True, True],
                         [False, True]]
        assert results == expect_result

        # Graph with ops that don't have a 'full' connection pattern
        # and with ops that have multiple outputs
        m, n, p, q = T.matrices('mnpq')
        o1, o2 = op1(m, n, p)
        out1, out2 = op1(o1, q, o2)
        op2 = cls_ofg([m, n, p, q], [out1, out2])

        results = op2.connection_pattern(None)
        expect_result = [[True, False],
                         [True, True],
                         [False, True],
                         [True, True]]
        assert results == expect_result

        # Inner graph where some computation doesn't rely on explicit inputs
        srng = RandomStreams(seed=234)
        rv_u = srng.uniform((2, 2))
        x, y = T.matrices('xy')
        out1 = x + rv_u
        out2 = y + 3
        out3 = 3 + rv_u
        op3 = cls_ofg([x, y], [out1, out2, out3])

        results = op3.connection_pattern(None)
        # the third row corresponds to the random state, which becomes an
        # implicit input of the op
        expect_result = [[True, False, False],
                         [False, True, False],
                         [True, False, True]]
        assert results == expect_result

    def test_infer_shape(self):
        # infer_shape does not need to be tested against the inline case,
        # since the op is removed during the optimization phase
        x = T.matrix('x')
        y = T.matrix('y')
        o1 = x + y
        o2 = x * y
        op_graph = OpFromGraph([x, y], [o1, o2])

        q = T.matrix('q')
        p = T.matrix('p')
        self._compile_and_check([q, p],
                                op_graph(q, p),
                                [np.ones([3, 4], dtype=config.floatX),
                                 np.ones([3, 4], dtype=config.floatX)],
                                OpFromGraph)

    @theano.change_flags(compute_test_value='raise')
    def test_compute_test_value(self):
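        # Check that test values propagate through an OpFromGraph and its
        # gradient when compute_test_value is enabled.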
        x = T.scalar('x')
        x.tag.test_value = np.array(1., dtype=config.floatX)
        op = OpFromGraph([x], [x ** 3])
        y = T.scalar('y')
        y.tag.test_value = np.array(1., dtype=config.floatX)
        f = op(y)
        grad_f = T.grad(f, y)
        assert grad_f.tag.test_value is not None