1_Int = int
2_Float = float
3from _mnncengine._expr import *
4import _mnncengine._expr as _F
5
6_numpy_supported = False
7try:
8    import numpy as np
9    _numpy_supported = True
10except Exception:
11    print ("Numpy not found. Using MNN without numpy.")
12
13def _to_var(x, to_float=True):
14    if _numpy_supported:
15        if isinstance(x, np.ndarray): # convert numpy ndarray to MNN var
16            if to_float:
17                if x.dtype != np.float32:
18                    x = x.astype(np.float32)
19                return _F.const(x, x.shape)
20            if not to_float:
21                if x.dtype != np.int32:
22                    x = x.astype(np.int32)
23                return _F.const(x, x.shape, dtype=_F.int)
24        elif isinstance(x, (list, tuple)) and x: # convert list and tuple to MNN Var
25            x = np.array(x)
26            if to_float:
27                if x.dtype != np.float32:
28                    x = x.astype(np.float32)
29                return _F.const(x, x.shape)
30            if not to_float:
31                if x.dtype != np.int32:
32                    x = x.astype(np.int32)
33                return _F.const(x, x.shape, dtype=_F.int)
34    else: # No numpy support
35        if isinstance(x, _Int):
36            return _F.const(x, [], dtype=_F.int)
37        elif isinstance(x, _Float):
38            return _F.const(x, [], dtype=_F.float)
39    return x
40def scalar(value):
41    if type(value) == type(1):
42        res = _F.const([value], [], _F.NCHW, _F.int)
43        return res
44    elif type(value) == type(1.):
45        res = _F.const([value], [], _F.NCHW, _F.float)
46        return res
47    else:
48        raise NotImplementedError("not supported data type for creating scalar variable")
49def sign(x):
50    x = _to_var(x)
51    if not isinstance(x, Var):
52        raise RuntimeError("parameter x is not valid")
53    return _F.sign(x)
54def floor(x):
55    x = _to_var(x)
56    if not isinstance(x, Var):
57        raise RuntimeError("parameter x is not valid")
58    return _F.floor(x)
59def ceil(x):
60    x = _to_var(x)
61    if not isinstance(x, Var):
62        raise RuntimeError("parameter x is not valid")
63    return _F.ceil(x)
64def square(x):
65    x = _to_var(x)
66    if not isinstance(x, Var):
67        raise RuntimeError("parameter x is not valid")
68    return _F.square(x)
69def sqrt(x):
70    x = _to_var(x)
71    if not isinstance(x, Var):
72        raise RuntimeError("parameter x is not valid")
73    return _F.sqrt(x)
74def rsqrt(x):
75    x = _to_var(x)
76    if not isinstance(x, Var):
77        raise RuntimeError("parameter x is not valid")
78    return _F.rsqrt(x)
79def exp(x):
80    x = _to_var(x)
81    if not isinstance(x, Var):
82        raise RuntimeError("parameter x is not valid")
83    return _F.exp(x)
84def log(x):
85    x = _to_var(x)
86    if not isinstance(x, Var):
87        raise RuntimeError("parameter x is not valid")
88    return _F.log(x)
89def sin(x):
90    x = _to_var(x)
91    if not isinstance(x, Var):
92        raise RuntimeError("parameter x is not valid")
93    return _F.sin(x)
94def cos(x):
95    x = _to_var(x)
96    if not isinstance(x, Var):
97        raise RuntimeError("parameter x is not valid")
98    return _F.cos(x)
99def tan(x):
100    x = _to_var(x)
101    if not isinstance(x, Var):
102        raise RuntimeError("parameter x is not valid")
103    return _F.tan(x)
104def asin(x):
105    x = _to_var(x)
106    if not isinstance(x, Var):
107        raise RuntimeError("parameter x is not valid")
108    return _F.asin(x)
109def acos(x):
110    x = _to_var(x)
111    if not isinstance(x, Var):
112        raise RuntimeError("parameter x is not valid")
113    return _F.acos(x)
114def atan(x):
115    x = _to_var(x)
116    if not isinstance(x, Var):
117        raise RuntimeError("parameter x is not valid")
118    return _F.atan(x)
119def log1p(x):
120    x = _to_var(x)
121    if not isinstance(x, Var):
122        raise RuntimeError("parameter x is not valid")
123    return _F.log1p(x)
124def tanh(x):
125    x = _to_var(x)
126    if not isinstance(x, Var):
127        raise RuntimeError("parameter x is not valid")
128    return _F.tanh(x)
129def sigmoid(x):
130    x = _to_var(x)
131    if not isinstance(x, Var):
132        raise RuntimeError("parameter x is not valid")
133    return _F.sigmoid(x)
134def minimum(x, y):
135    x = _to_var(x)
136    y = _to_var(y)
137    if not isinstance(x, Var):
138        raise RuntimeError("parameter x is not valid")
139    if not isinstance(y, Var):
140        raise RuntimeError("parameter y is not valid")
141    return _F.minimum(x, y)
142def maximum(x, y):
143    x = _to_var(x)
144    y = _to_var(y)
145    if not isinstance(x, Var):
146        raise RuntimeError("parameter x is not valid")
147    if not isinstance(y, Var):
148        raise RuntimeError("parameter y is not valid")
149    return _F.maximum(x, y)
150def bias_add(value, bias):
151    """
152    Adds bias to value.
153
154    This is (mostly) a special case of add where bias is restricted to 1-D.
155    Broadcasting is supported, so value may have any number of dimensions.
156    Unlike add, the type of bias is allowed to differ from value in the case where both types are quantized.
157
158    Example usage:
159    >>> MNN.expr.bias_add(np.eye(3,3), np.ones(3))
160    array([[2., 1., 1.],
161       [1., 2., 1.],
162       [1., 1., 2.]], dtype=float32)
163
164    Args:
165    value: A variable with type dtype.float or dtype.int.
166    bias: A 1-D variable with size matching the channel dimension of value.
167          Must be the same type as value unless value is a quantized type, in which case a different quantized type may be used.
168
169    Returns:
170    A variable with the same type as value.
171  """
172    value = _to_var(value)
173    bias = _to_var(bias)
174    if not isinstance(value, Var):
175        raise RuntimeError("parameter value is not valid")
176    if not isinstance(bias, Var):
177        raise RuntimeError("parameter bias is not valid")
178    if len(bias.shape) != 1:
179        raise RuntimeError("parameter bias must be 1-D in bias_add")
180    if value.shape[-1] != bias.shape[-1]:
181        raise RuntimeError("parameter bias's dim must match parameter value's dim in bias_add")
182    return _F.bias_add(value, bias)
183def unravel_index(indices, dims):
184    indices = _to_var(indices, to_float=False)
185    dims = _to_var(dims, to_float=False)
186    if not isinstance(indices, Var):
187        raise RuntimeError("parameter indices is not valid")
188    if not isinstance(dims, Var):
189        raise RuntimeError("parameter dims is not valid")
190    return _F.unravel_index(indices, dims)
191def one_hot(indices, depth, on_value=1., off_value=0., axis=-1):
192    indices = _to_var(indices, to_float=False)
193    if not isinstance(indices, Var):
194        raise RuntimeError("parameter indices is not valid")
195    return _F.one_hot(indices, depth, on_value, off_value, axis)
196def broadcast_to(input, shape):
197    shape = _to_var(shape, to_float=False)
198    if not isinstance(input, Var):
199        raise RuntimeError("parameter input is not valid")
200    if not isinstance(shape, Var):
201        raise RuntimeError("parameter shape is not valid")
202    return _F.broadcast_to(input, shape)
203def zeros_like(input):
204    input = _to_var(input)
205    if not isinstance(input, Var):
206        raise RuntimeError("parameter input is not valid")
207    return _F.zeros_like(input)
208def range(start, limit, delta):
209    start = _to_var(start)
210    limit = _to_var(limit)
211    delta = _to_var(delta)
212    if not isinstance(start, Var):
213        raise RuntimeError("parameter start is not valid")
214    if not isinstance(limit, Var):
215        raise RuntimeError("parameter limit is not valid")
216    if not isinstance(delta, Var):
217        raise RuntimeError("parameter delta is not valid")
218    if limit.dtype != start.dtype or delta.dtype != start.dtype:
219        raise RuntimeError("parameter start/limit/delta must use same data type, either all int or all float")
220    return _F.range(start, limit, delta)
221def rank(input):
222    input = _to_var(input)
223    if not isinstance(input, Var):
224        raise RuntimeError("parameter input is not valid")
225    return _F.rank(input)
226def space_to_batch_nd(input, block_shape, paddings):
227    input = _to_var(input)
228    block_shape = _to_var(block_shape, to_float=False)
229    paddings = _to_var(paddings, to_float=False)
230    if not isinstance(input, Var):
231        raise RuntimeError("parameter input is not valid")
232    if not isinstance(block_shape, Var):
233        raise RuntimeError("parameter block_shape is not valid")
234    if not isinstance(paddings, Var):
235        raise RuntimeError("parameter paddings is not valid")
236    if len(input.shape) != 4 or input.data_format != _F.NC4HW4:
237        raise RuntimeError("parameter input must be 4-D w/ NC4HW4 format")
238    if block_shape.dtype != _F.int or paddings.dtype != _F.int:
239        raise RuntimeError("parameter block_shape/paddings must be int type")
240    if len(block_shape.shape) != 1:
241        raise RuntimeError("parameter block_shape must be 1-D w/ shape [M]")
242    if len(paddings.shape) != 2 or paddings.shape[-1] != 2:
243        raise RuntimeError("parameter paddings must be 2-D w/ shape [M, 2]")
244    return _F.space_to_batch_nd(input, block_shape, paddings)
245def batch_to_space_nd(input, block_shape, crops):
246    input = _to_var(input)
247    block_shape = _to_var(block_shape, to_float=False)
248    crops = _to_var(crops, to_float=False)
249    if not isinstance(input, Var):
250        raise RuntimeError("parameter input is not valid")
251    if not isinstance(block_shape, Var):
252        raise RuntimeError("parameter block_shape is not valid")
253    if not isinstance(crops, Var):
254        raise RuntimeError("parameter crops is not valid")
255    if len(input.shape) != 4 or input.data_format != _F.NC4HW4:
256        raise RuntimeError("parameter input must be 4-D w/ NC4HW4 format")
257    if block_shape.dtype != _F.int or crops.dtype != _F.int:
258        raise RuntimeError("parameter block_shape/crops must be int type")
259    if len(block_shape.shape) != 1:
260        raise RuntimeError("parameter block_shape must be 1-D w/ shape [M]")
261    if len(crops.shape) != 2 or crops.shape[-1] != 2 or crops.shape[0] != block_shape.shape[0]:
262        raise RuntimeError("parameter crops must be 2-D w/ shape [M, 2]")
263    return _F.batch_to_space_nd(input, block_shape, crops)
264def setdiff1d(x, y):
265    x = _to_var(x)
266    y = _to_var(y)
267    if not isinstance(x, Var):
268        raise RuntimeError("parameter x is not valid")
269    if not isinstance(y, Var):
270        raise RuntimeError("parameter y is not valid")
271    if len(x.shape) != 1 or len(y.shape) != 1:
272        raise RuntimeError("parameter x/y must be 1-D")
273    return _F.setdiff1d(x, y)
274def moments(x, axes=[2, 3], shift=None, keep_dims=True):
275    x = _to_var(x)
276    if not isinstance(x, Var):
277        raise RuntimeError("parameter x is not valid")
278    if len(x.shape) != 4 or x.data_format != _F.NC4HW4:
279        raise RuntimeError("parameter x must be 4-D w/ NC4HW4 format")
280    if axes != [2, 3] and axes != (2, 3):
281        raise RuntimeError("parameter axes must be [2, 3] in current implementation")
282    shift = _F.const([0.], [1]) #though it's not used, it's preserved
283    return _F.moments(x, axes, shift, True)
284def matrix_band_part(input, num_lower, num_upper):
285    input = _to_var(input)
286    num_lower = _to_var(num_lower)
287    num_upper = _to_var(num_upper)
288    if not isinstance(input, Var):
289        raise RuntimeError("parameter input is not valid")
290    if not isinstance(num_lower, Var):
291        raise RuntimeError("parameter num_lower is not valid")
292    if not isinstance(num_upper, Var):
293        raise RuntimeError("parameter num_upper is not valid")
294    if len(num_lower.shape) != 0 or num_lower.dtype != _F.int:
295        raise RuntimeError("parameter num_lower must be 0-D int")
296    if len(num_upper.shape) != 0 or num_upper.dtype != _F.int:
297        raise RuntimeError("parameter num_upper must be 0-D int")
298    return _F.matrix_band_part(input, num_lower, num_upper)
299def gather_nd(params, indices):
300    params = _to_var(params)
301    indices = _to_var(indices, to_float=False)
302    if not isinstance(params, Var):
303        raise RuntimeError("parameter params is not valid")
304    if not isinstance(indices, Var):
305        raise RuntimeError("parameter indices is not valid")
306    if indices.dtype != _F.int:
307        raise RuntimeError("parameter indices must be int type")
308    return _F.gather_nd(params, indices)
309def gather(params, indices):
310    params = _to_var(params)
311    indices = _to_var(indices, to_float=False)
312    if not isinstance(params, Var):
313        raise RuntimeError("parameter params is not valid")
314    if not isinstance(indices, Var):
315        raise RuntimeError("parameter indices is not valid")
316    if indices.dtype != _F.int:
317        raise RuntimeError("parameter indices must be int type")
318    return _F.gather(params, indices)
319def fill(dims, value):
320    dims = _to_var(dims, to_float=False)
321    value = _to_var(value)
322    if not isinstance(dims, Var):
323        raise RuntimeError("parameter dims is not valid")
324    if not isinstance(value, Var):
325        raise RuntimeError("parameter value is not valid")
326    if dims.dtype != _F.int:
327        raise RuntimeError("parameter dims must be int type")
328    if len(value.shape) != 0:
329        raise RuntimeError("parameter value must be 0-D")
330    return _F.fill(dims, value)
331def tile(input, multiples):
332    input = _to_var(input)
333    multiples = _to_var(multiples, to_float=False)
334    if not isinstance(input, Var):
335        raise RuntimeError("parameter input is not valid")
336    if not isinstance(multiples, Var):
337        raise RuntimeError("parameter multiples is not valid")
338    if multiples.dtype != _F.int or len(multiples.shape) != 1:
339        raise RuntimeError("parameter multiples must be 1-D int type")
340    if len(input.shape) != multiples.shape[-1]:
341        raise RuntimeError("parameter multiples's length must match w/ number of dimensions of input")
342    return _F.tile(input, multiples)
343def shape(input):
344    input = _to_var(input)
345    if not isinstance(input, Var):
346        raise RuntimeError("parameter input is not valid")
347    return _F.shape(input)
348def softplus(features):
349    features = _to_var(features)
350    if not isinstance(features, Var):
351        raise RuntimeError("parameter features is not valid")
352    return _F.softplus(features)
353def softsign(features):
354    features = _to_var(features)
355    if not isinstance(features, Var):
356        raise RuntimeError("parameter features is not valid")
357    return _F.softsign(features)
358def stack(values, axis=0):
359    if not isinstance(values, (list, tuple)):
360        raise RuntimeError("parameter values must be a list/tuple of MNN Var")
361    if len(values) < 1:
362        raise RuntimeError("parameter values must have at least one item")
363    for value in values:
364        if not isinstance(value, Var):
365            raise RuntimeError("all items in parameter values must be MNN Var type")
366        if value.shape != values[0].shape or value.dtype != values[0].dtype:
367            raise RuntimeError("all items in parameter values must have same shape and dtype")
368    return _F.stack(values, axis)
369def slice(input, starts, sizes):
370    input = _to_var(input)
371    starts = _to_var(starts, to_float=False)
372    sizes = _to_var(sizes, to_float=False)
373    if not isinstance(input, Var):
374        raise RuntimeError("parameter input is not valid")
375    if not isinstance(starts, Var):
376        raise RuntimeError("parameter starts is not valid")
377    if not isinstance(sizes, Var):
378        raise RuntimeError("parameter sizes is not valid")
379    if starts.dtype != _F.int or sizes.dtype != _F.int:
380        raise RuntimeError("parameter starts/sizes must be int type")
381    return _F.slice(input, starts, sizes)
382def transpose(x, perm):
383    return _F.transpose(x, perm)
384def pad(x, paddings, mode=CONSTANT):
385    x = _to_var(x)
386    paddings = _to_var(paddings, to_float=False)
387    if not isinstance(x, Var):
388        raise RuntimeError("parameter x is not valid")
389    if not isinstance(paddings, Var):
390        raise RuntimeError("parameter paddings is not valid")
391    if paddings.dtype != _F.int:
392        raise RuntimeError("parameter perm must be int type")
393    if len(paddings.shape) != 2 or paddings.shape[-1] != 2 or paddings.shape[0] != len(x.shape):
394        raise RuntimeError("parameter paddings must be 2-D w/ shape[n, 2], and n is the number of dimensions of parameter x")
395    return _F.pad(x, paddings, mode)
396def resize(images, x_scale, y_scale):
397    images = _to_var(images)
398    if not isinstance(images, Var):
399        raise RuntimeError("parameter images is not valid")
400    if len(images.shape) != 4 or images.data_format != _F.NC4HW4:
401        raise RuntimeError("parameter images must be 4-D NC4HW4 format")
402    return _F.resize(images, x_scale, y_scale)
403def crop(images, size, axis, offset):
404    images = _to_var(images)
405    size = _to_var(size)
406    if not isinstance(images, Var):
407        raise RuntimeError("parameter images is not valid")
408    if not isinstance(size, Var):
409        raise RuntimeError("parameter size is not valid")
410    if len(images.shape) != 4 or images.data_format != _F.NC4HW4:
411        raise RuntimeError("parameter images must be 4-D NC4HW4 format")
412    if len(size.shape) != 4:
413        raise RuntimeError("parameter size must be 4-D")
414    if axis != 2 and axis != 3:
415        raise RuntimeError("parameter axis must be 2 or 3, if 2 you may change both h/w, if 3 only w")
416    if axis == 2:
417        if len(offset) != 1 and len(offset) !=2:
418            raise RuntimeError("parameter offset must be at most 2 if you want to change h/w")
419    if axis == 3:
420        if len(offset) != 1:
421            raise RuntimeError("parameter offset must be at most 1 if you want to change w only")
422    return _F.crop(images, size, axis, offset)
423def crop_and_resize(image, boxes, box_ind, crop_size, method=BILINEAR, extrapolation_value=0.):
424    image = _to_var(image)
425    boxes = _to_var(boxes, to_float=False)
426    box_ind = _to_var(box_ind, to_float=False)
427    crop_size = _to_var(crop_size, to_float=False)
428    if not isinstance(image, Var):
429        raise RuntimeError("parameter image is not valid")
430    if not isinstance(boxes, Var):
431        raise RuntimeError("parameter boxes is not valid")
432    if not isinstance(box_ind, Var):
433        raise RuntimeError("parameter box_ind is not valid")
434    if not isinstance(crop_size, Var):
435        raise RuntimeError("parameter crop_size is not valid")
436    if len(image.shape) != 4:
437        raise RuntimeError("parameter image must be 4-D format")
438    if boxes.dtype != _F.int or box_ind.dtype != _F.int or crop_size.dtype != _F.int:
439        raise RuntimeError("parameter boxes/box_ind/crop_size must be int type")
440    if len(boxes.shape) != 2 or boxes.shape[-1] !=4:
441        raise RuntimeError("parameter boxes must be 2-D w/ shape [num_boxes, 4]")
442    if len(box_ind.shape) != 1 or box_ind.shape[-1] != boxes.shape[0]:
443        raise RuntimeError("parameter boxes must be 1-D w/ shape [num_boxes]")
444    if len(crop_size.shape) != 1 or crop_size.shape[0] != 2:
445        raise RuntimeError("parameter boxes must be 1-D w/ shape [2]")
446    return _F.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value)
447def reverse_sequence(x, y, batch_dim, seq_dim):
448    x = _to_var(x)
449    y = _to_var(y, to_float=False)
450    if not isinstance(x, Var):
451        raise RuntimeError("parameter x is not valid")
452    if not isinstance(y, Var):
453        raise RuntimeError("parameter y is not valid")
454    if y.dtype != _F.int or len(y.shape) != 1:
455        raise RuntimeError("parameter y must be 1-D int type")
456    if batch_dim < 0 or batch_dim >= len(x.shape):
457        raise RuntimeError("parameter batch_dim must be in range of the number of dimensions of parameter x")
458    if seq_dim < 0 or seq_dim >= len(x.shape):
459        raise RuntimeError("parameter seq_dim must be in range of the number of dimensions of parameter x")
460    if y.shape[0] != x.shape[batch_dim]:
461        raise RuntimeError("parameter y must be shape [x.shape[batch.dim]]")
462    return _F.reverse_sequence(x, y, batch_dim, seq_dim)
463def reshape(x, shape, original_format=NCHW):
464    x = _to_var(x)
465    if not isinstance(x, Var):
466        raise RuntimeError("parameter x is not valid")
467    if not isinstance(shape, (list, tuple)):
468        raise RuntimeError("parameter shape is not valid")
469    new_length = 1
470    skip = False
471    for value in shape:
472        if value < 0:
473            skip = True
474        new_length *= value
475
476    if new_length != x.size and not skip:
477        raise RuntimeError("parameter shape is not valid")
478    return _F.reshape(x, shape, original_format)
479