1from __future__ import absolute_import, print_function, division
2from theano import scalar as scal
3from . import elemwise
4from theano import printing
5from theano.printing import pprint
6
7
def _scal_inplace(symbol):
    """Turn a stub definition into the matching elementwise tensor Op.

    The decorated function is never called; only its ``__name__`` and
    docstring are used, and the function object itself is kept for epydoc.

    A name of the form ``foo_inplace`` selects ``theano.scalar.foo``.  That
    scalar Op's class is re-instantiated with ``scal.transfer_type(0)`` and
    wrapped in an ``Elemwise`` with inplace pattern ``{0: 0}`` (output 0
    overwrites input 0).  Any other name ``foo`` wraps ``theano.scalar.foo``
    directly in a plain ``Elemwise``.

    Parameters
    ----------
    symbol : function
        Stub whose name selects the scalar Op.  Its docstring, if any, is
        prepended to the resulting Op's ``__doc__``.

    Returns
    -------
    Elemwise
        The new Op.  It is also registered with ``pprint`` through a
        ``FunctionPrinter`` named after the stub, with ``_inplace`` replaced
        by ``=``.
    """
    name = symbol.__name__
    suffix = '_inplace'

    if name.endswith(suffix):
        base_op = getattr(scal, name[:-len(suffix)])
        # Rebuild the scalar op of the same class. transfer_type(0) presumably
        # makes its output type follow input 0 so the result can be written
        # back into that input -- confirm in theano.scalar.
        destructive_op = base_op.__class__(scal.transfer_type(0))
        op = elemwise.Elemwise(destructive_op, {0: 0}, name=name)
    else:
        op = elemwise.Elemwise(getattr(scal, name), name=name)

    stub_doc = getattr(symbol, '__doc__', False)
    if stub_doc:
        op.__doc__ = stub_doc + '\n' + op.__doc__

    # Consumed by the ./epydoc script: it makes epydoc display the Op as if
    # it were a function rather than an object.
    op.__epydoc_asRoutine = symbol
    op.__module__ = 'theano.tensor.inplace'

    pprint.assign(op, printing.FunctionPrinter(name.replace(suffix, '=')))
    return op
31
32
@_scal_inplace
def lt_inplace(a, b):
    """Elementwise comparison `a` < `b` (inplace on `a`)."""


@_scal_inplace
def gt_inplace(a, b):
    """Elementwise comparison `a` > `b` (inplace on `a`)."""


@_scal_inplace
def le_inplace(a, b):
    """Elementwise comparison `a` <= `b` (inplace on `a`)."""


@_scal_inplace
def ge_inplace(a, b):
    """Elementwise comparison `a` >= `b` (inplace on `a`)."""


@_scal_inplace
def eq_inplace(a, b):
    """Elementwise comparison `a` == `b` (inplace on `a`)."""


@_scal_inplace
def neq_inplace(a, b):
    """Elementwise comparison `a` != `b` (inplace on `a`)."""
61
62
@_scal_inplace
def and__inplace(a, b):
    """Elementwise bitwise AND `a` & `b` (inplace on `a`)."""


@_scal_inplace
def or__inplace(a, b):
    """Elementwise bitwise OR `a` | `b` (inplace on `a`)."""


@_scal_inplace
def xor_inplace(a, b):
    """Elementwise bitwise XOR `a` ^ `b` (inplace on `a`)."""


@_scal_inplace
def invert_inplace(a):
    """Elementwise bitwise NOT ~`a` (inplace on `a`)."""
81
82
@_scal_inplace
def abs__inplace(a):
    """Absolute value |`a`| (inplace on `a`)."""


@_scal_inplace
def exp_inplace(a):
    """e raised to the power `a` (inplace on `a`)."""


@_scal_inplace
def exp2_inplace(a):
    """2 raised to the power `a` (inplace on `a`)."""


@_scal_inplace
def expm1_inplace(a):
    """e^`a` - 1 (inplace on `a`)."""


@_scal_inplace
def neg_inplace(a):
    """Negation -`a` (inplace on `a`)."""


@_scal_inplace
def inv_inplace(a):
    """Reciprocal 1.0 / `a` (inplace on `a`)."""


@_scal_inplace
def log_inplace(a):
    """Natural (base e) logarithm of `a` (inplace on `a`)."""


@_scal_inplace
def log1p_inplace(a):
    """Natural logarithm of 1 + `a`, i.e. log(1 + `a`) (inplace on `a`)."""


@_scal_inplace
def log2_inplace(a):
    """Base 2 logarithm of `a` (inplace on `a`)."""


@_scal_inplace
def log10_inplace(a):
    """Base 10 logarithm of `a` (inplace on `a`)."""
131
132
@_scal_inplace
def sgn_inplace(a):
    """Sign of `a` (inplace on `a`)."""


@_scal_inplace
def ceil_inplace(a):
    """Ceiling of `a` (inplace on `a`)."""


@_scal_inplace
def floor_inplace(a):
    """Floor of `a` (inplace on `a`)."""


@_scal_inplace
def trunc_inplace(a):
    """Truncation of `a` toward zero (inplace on `a`)."""


@_scal_inplace
def round_half_to_even_inplace(a):
    """Round `a` to the nearest integer, resolving ties (x.5) to the even
    neighbour (inplace on `a`)."""


@_scal_inplace
def round_half_away_from_zero_inplace(a):
    """Round `a` to the nearest integer, resolving ties (x.5) away from zero
    (inplace on `a`)."""


@_scal_inplace
def sqr_inplace(a):
    """Square of `a` (inplace on `a`)."""


@_scal_inplace
def sqrt_inplace(a):
    """Square root of `a` (inplace on `a`)."""
171
172
@_scal_inplace
def deg2rad_inplace(a):
    """Convert `a` from degrees to radians (inplace on `a`)."""


@_scal_inplace
def rad2deg_inplace(a):
    """Convert `a` from radians to degrees (inplace on `a`)."""


@_scal_inplace
def cos_inplace(a):
    """Cosine of `a` (inplace on `a`)."""


@_scal_inplace
def arccos_inplace(a):
    """Inverse cosine of `a` (inplace on `a`)."""


@_scal_inplace
def sin_inplace(a):
    """Sine of `a` (inplace on `a`)."""


@_scal_inplace
def arcsin_inplace(a):
    """Inverse sine of `a` (inplace on `a`)."""


@_scal_inplace
def tan_inplace(a):
    """Tangent of `a` (inplace on `a`)."""


@_scal_inplace
def arctan_inplace(a):
    """Inverse tangent of `a` (inplace on `a`)."""


@_scal_inplace
def arctan2_inplace(a, b):
    """Two-argument inverse tangent of `a` / `b` (inplace on `a`)."""


@_scal_inplace
def cosh_inplace(a):
    """Hyperbolic cosine of `a` (inplace on `a`)."""


@_scal_inplace
def arccosh_inplace(a):
    """Inverse hyperbolic cosine of `a` (inplace on `a`)."""


@_scal_inplace
def sinh_inplace(a):
    """Hyperbolic sine of `a` (inplace on `a`)."""


@_scal_inplace
def arcsinh_inplace(a):
    """Inverse hyperbolic sine of `a` (inplace on `a`)."""


@_scal_inplace
def tanh_inplace(a):
    """Hyperbolic tangent of `a` (inplace on `a`)."""


@_scal_inplace
def arctanh_inplace(a):
    """Inverse hyperbolic tangent of `a` (inplace on `a`)."""
246
247
@_scal_inplace
def erf_inplace(a):
    """Error function of `a` (inplace on `a`)."""


@_scal_inplace
def erfc_inplace(a):
    """Complementary error function of `a` (inplace on `a`)."""


@_scal_inplace
def erfcx_inplace(a):
    """Scaled complementary error function of `a` (inplace on `a`)."""


@_scal_inplace
def gamma_inplace(a):
    """Gamma function of `a` (inplace on `a`)."""


@_scal_inplace
def gammaln_inplace(a):
    """Logarithm of the gamma function of `a` (inplace on `a`)."""


@_scal_inplace
def psi_inplace(a):
    """Derivative of the log gamma function at `a` (inplace on `a`)."""


@_scal_inplace
def tri_gamma_inplace(a):
    """Second derivative of the log gamma function at `a` (inplace on `a`)."""


@_scal_inplace
def chi2sf_inplace(x, k):
    """Chi squared survival function of `x` with parameter `k`
    (inplace on `x`, the first input)."""


@_scal_inplace
def j0_inplace(x):
    """Bessel function of the first kind of order 0 at `x` (inplace on `x`)."""


@_scal_inplace
def j1_inplace(x):
    """Bessel function of the first kind of order 1 at `x` (inplace on `x`)."""


@_scal_inplace
def jv_inplace(v, x):
    """Bessel function of the first kind of order `v` (real) at `x`.

    Inplace on `v`: the first input, not `x`, is overwritten by the result.
    """


@_scal_inplace
def i0_inplace(x):
    """Modified Bessel function of the first kind of order 0 at `x`
    (inplace on `x`)."""


@_scal_inplace
def i1_inplace(x):
    """Modified Bessel function of the first kind of order 1 at `x`
    (inplace on `x`)."""


@_scal_inplace
def iv_inplace(v, x):
    """Modified Bessel function of the first kind of order `v` (real) at `x`.

    Inplace on `v`: the first input, not `x`, is overwritten by the result.
    """
316
317
@_scal_inplace
def second_inplace(a, b):
    """Fill `a` with `b` (inplace on `a`)."""


# ``fill_inplace`` is the public alias for the same Op.  ``_scal_inplace``
# registered a ``second=`` printer; this adds a ``fill=`` printer for it
# (assumes a later pprint.assign takes precedence -- confirm in theano.printing).
fill_inplace = second_inplace
pprint.assign(fill_inplace, printing.FunctionPrinter('fill='))
324
325
@_scal_inplace
def maximum_inplace(a, b):
    """Elementwise maximum of `a` and `b` (inplace on `a`)."""


@_scal_inplace
def minimum_inplace(a, b):
    """Elementwise minimum of `a` and `b` (inplace on `a`)."""
334
335
@_scal_inplace
def add_inplace(a, b):
    """Elementwise addition `a` + `b` (inplace on `a`)."""


@_scal_inplace
def sub_inplace(a, b):
    """Elementwise subtraction `a` - `b` (inplace on `a`)."""


@_scal_inplace
def mul_inplace(a, b):
    """Elementwise multiplication `a` * `b` (inplace on `a`)."""


@_scal_inplace
def true_div_inplace(a, b):
    """Elementwise true division `a` / `b` (inplace on `a`)."""


@_scal_inplace
def int_div_inplace(a, b):
    """Elementwise integer division `a` // `b` (inplace on `a`)."""


@_scal_inplace
def mod_inplace(a, b):
    """Elementwise modulo `a` % `b` (inplace on `a`)."""


@_scal_inplace
def pow_inplace(a, b):
    """Elementwise power `a` ** `b` (inplace on `a`)."""


@_scal_inplace
def conj_inplace(a):
    """Elementwise complex conjugate of `a` (inplace on `a`)."""
374
# Register infix printers for the inplace arithmetic Ops, after the generic
# ``name=`` function printers assigned by ``_scal_inplace``.  Each row is
# (op, operator symbol, numeric argument, associativity) as passed to
# ``printing.OperatorPrinter``.  The numeric argument is presumably the
# operator precedence -- confirm in theano.printing.
for _op, _symbol, _precedence, _assoc in (
        (add_inplace, '+=', -2, 'either'),
        (mul_inplace, '*=', -1, 'either'),
        (sub_inplace, '-=', -2, 'left'),
        (neg_inplace, '-=', 0, 'either'),
        (true_div_inplace, '/=', -1, 'left'),
        (int_div_inplace, '//=', -1, 'left'),
        (pow_inplace, '**=', 1, 'right')):
    pprint.assign(_op, printing.OperatorPrinter(_symbol, _precedence, _assoc))
# Do not leak the loop temporaries into the module namespace.
del _op, _symbol, _precedence, _assoc
382
383
def transpose_inplace(x, **kwargs):
    """Reverse all dimensions of `x` without copying the underlying storage.

    Builds a ``DimShuffle`` with ``inplace=True`` whose new axis order is
    ``(ndim - 1, ..., 1, 0)``, i.e. a full transpose, and applies it to `x`.

    Parameters
    ----------
    x : tensor variable
        Input; only ``x.ndim`` and ``x.broadcastable`` are read here.
    **kwargs
        Accepted and ignored (presumably for call compatibility with the
        non-inplace ``transpose`` -- TODO confirm).

    Returns
    -------
    The result of applying the inplace ``DimShuffle`` to `x`.
    """
    reversed_axes = list(reversed(range(x.ndim)))
    shuffle = elemwise.DimShuffle(x.broadcastable, reversed_axes, inplace=True)
    return shuffle(x)
388