1from __future__ import absolute_import, print_function, division 2import numpy as np 3from theano.gof.type import Type 4from theano.gof.graph import Variable, Apply, Constant 5from theano.gof.op import Op 6from theano.gof.opt import * 7from theano.gof.fg import FunctionGraph as Env 8from theano.gof.toolbox import * 9import theano.tensor.basic as T 10 11 12def as_variable(x): 13 if not isinstance(x, Variable): 14 raise TypeError("not a Variable", x) 15 return x 16 17 18class MyType(Type): 19 20 def filter(self, data): 21 return data 22 23 def __eq__(self, other): 24 return isinstance(other, MyType) 25 26 27class MyOp(Op): 28 29 def __init__(self, name, dmap=None, x=None): 30 if dmap is None: 31 dmap = {} 32 self.name = name 33 self.destroy_map = dmap 34 self.x = x 35 36 def make_node(self, *inputs): 37 inputs = list(map(as_variable, inputs)) 38 for input in inputs: 39 if not isinstance(input.type, MyType): 40 raise Exception("Error 1") 41 outputs = [MyType()()] 42 return Apply(self, inputs, outputs) 43 44 def __str__(self): 45 return self.name 46 47 def __repr__(self): 48 return self.name 49 50 def __eq__(self, other): 51 return (self is other or isinstance(other, MyOp) and self.x is not None 52 and self.x == other.x) 53 54 def __hash__(self): 55 if self.x is not None: 56 return self.x 57 else: return id(self) 58op1 = MyOp('Op1') 59 60 61def test_merge_with_weird_eq(): 62 # numpy arrays don't compare equal like other python objects 63 64 # SCALAR CASE 65 x = T.constant(np.asarray(1), name='x') 66 y = T.constant(np.asarray(1), name='y') 67 g = Env([x, y], [x+y]) 68 MergeOptimizer().optimize(g) 69 70 assert len(g.apply_nodes) == 1 71 node = list(g.apply_nodes)[0] 72 assert len(node.inputs) == 2 73 assert node.inputs[0] is node.inputs[1] 74 75 # NONSCALAR CASE 76 # This was created to test TensorConstantSignature 77 x = T.constant(np.ones(5), name='x') 78 y = T.constant(np.ones(5), name='y') 79 g = Env([x, y], [x+y]) 80 MergeOptimizer().optimize(g) 81 82 assert len(g.apply_nodes) == 1 83 node = list(g.apply_nodes)[0] 84 assert len(node.inputs) == 2 85 assert node.inputs[0] is node.inputs[1] 86