1from __future__ import absolute_import, print_function, division
2import numpy as np
3from theano.gof.type import Type
4from theano.gof.graph import Variable, Apply, Constant
5from theano.gof.op import Op
6from theano.gof.opt import *
7from theano.gof.fg import FunctionGraph as Env
8from theano.gof.toolbox import *
9import theano.tensor.basic as T
10
11
def as_variable(x):
    """Return ``x`` unchanged after checking that it is a ``Variable``.

    Raises
    ------
    TypeError
        If ``x`` is not a ``Variable`` instance; the offending object is
        attached as the second exception argument.
    """
    if isinstance(x, Variable):
        return x
    raise TypeError("not a Variable", x)
16
17
class MyType(Type):
    """Minimal test Type: accepts any data and treats all instances as equal."""

    def filter(self, data, strict=False, allow_downcast=None):
        """Return ``data`` unchanged; no validation or conversion is done.

        ``strict`` and ``allow_downcast`` are accepted and ignored so that
        callers using the base-class ``filter(data, strict, allow_downcast)``
        calling convention also work.
        """
        return data

    def __eq__(self, other):
        # Every MyType instance is interchangeable with every other one.
        return isinstance(other, MyType)

    def __hash__(self):
        # Must agree with __eq__: all MyType instances compare equal, so
        # they share a single hash.  Defining __eq__ without __hash__ would
        # make instances unhashable on Python 3 (__hash__ becomes None).
        return hash(MyType)
25
26
class MyOp(Op):
    """Toy Op used to build graphs in optimizer tests.

    Parameters
    ----------
    name : str
        Returned by ``str()`` and ``repr()`` of the op.
    dmap : dict, optional
        Stored as ``destroy_map``; defaults to a fresh empty dict.
    x : hashable, optional
        Equality key: two MyOps with the same non-None ``x`` compare equal.
        When ``x`` is None the op is equal only to itself.
    """

    def __init__(self, name, dmap=None, x=None):
        # A fresh dict per instance avoids sharing one mutable default.
        if dmap is None:
            dmap = {}
        self.name = name
        self.destroy_map = dmap
        self.x = x

    def make_node(self, *inputs):
        """Build an Apply node on ``inputs`` with one fresh MyType output.

        Raises TypeError (via ``as_variable``) for non-Variable inputs and
        ``Exception("Error 1")`` for inputs whose type is not a MyType.
        """
        inputs = list(map(as_variable, inputs))
        for inp in inputs:
            if not isinstance(inp.type, MyType):
                raise Exception("Error 1")
        outputs = [MyType()()]
        return Apply(self, inputs, outputs)

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    def __eq__(self, other):
        # Identity always wins; otherwise ops are equal only when both carry
        # the same non-None key ``x``.
        return self is other or (isinstance(other, MyOp)
                                 and self.x is not None
                                 and self.x == other.x)

    def __hash__(self):
        # Consistent with __eq__: keyed ops hash by their key, unkeyed ops
        # by identity.  hash() is needed because __hash__ must return an
        # int; returning ``self.x`` directly fails for keys like 'a' or 1.0.
        if self.x is not None:
            return hash(self.x)
        return id(self)
# Module-level MyOp instance; no ``x`` key is given, so it equals only itself.
op1 = MyOp(name='Op1')
59
60
def test_merge_with_weird_eq():
    """MergeOptimizer must merge equal constants backed by numpy arrays.

    numpy arrays don't compare equal like other python objects (``==`` is
    elementwise), so two separately built constants with equal values are
    checked for merging.  A 0-d (scalar) array is checked first, then a
    1-d array, which was added to exercise TensorConstantSignature.
    """
    def check_constants_merged(make_value):
        # Two distinct constant objects holding equal array data.
        lhs = T.constant(make_value(), name='x')
        rhs = T.constant(make_value(), name='y')
        fgraph = Env([lhs, rhs], [lhs + rhs])
        MergeOptimizer().optimize(fgraph)

        # Only the addition node remains, and both of its inputs are now
        # the very same (merged) constant.
        remaining = list(fgraph.apply_nodes)
        assert len(remaining) == 1
        add_inputs = remaining[0].inputs
        assert len(add_inputs) == 2
        assert add_inputs[0] is add_inputs[1]

    # SCALAR CASE
    check_constants_merged(lambda: np.asarray(1))
    # NONSCALAR CASE
    check_constants_merged(lambda: np.ones(5))
86