1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Expression Intrinsics and math functions in TVM."""
18# pylint: disable=redefined-builtin
19from __future__ import absolute_import as _abs
20
21from ._ffi.function import register_func as _register_func
22from . import make as _make
23from .api import convert, const
24from .expr import Call as _Call
25from .schedule import Buffer as _Buffer
26
def _pack_buffer(buf):
    """Pack a Buffer into a ``tvm_stack_make_array`` intrinsic call.

    The shape (and the strides, when the buffer has any) are each wrapped
    in a ``tvm_stack_make_shape`` intrinsic call, and those are combined
    with the data pointer, the number of dimensions, a zero constant of
    the buffer's dtype and the element offset.

    Parameters
    ----------
    buf : Buffer
        The buffer to pack. Its shape must be non-empty.

    Returns
    -------
    call : Expr
        A "handle"-typed intrinsic call describing the buffer.
    """
    assert buf.shape

    def _make_shape(dims):
        # Wrap one dimension vector (shape or strides) in its own intrinsic.
        return _make.Call("handle", "tvm_stack_make_shape", dims,
                          _Call.Intrinsic, None, 0)

    shape_call = _make_shape(buf.shape)
    if buf.strides:
        strides_call = _make_shape(buf.strides)
    else:
        # No explicit strides: pass a literal 0 in that slot.
        strides_call = 0
    ndim = len(buf.shape)
    # Zero constant typed with buf.dtype; presumably this conveys the element
    # type to the array builder -- TODO confirm against the lowering pass.
    dtype_marker = const(0, dtype=buf.dtype)
    return _make.Call("handle", "tvm_stack_make_array",
                      [buf.data, shape_call, strides_call, ndim,
                       dtype_marker, buf.elem_offset],
                      _Call.Intrinsic, None, 0)
43
def call_packed(*args):
    """Build an expression that calls an external packed function.

    Each argument may be an Expr or a Buffer. An Expr is passed as the
    corresponding POD value. A Buffer is packed first, and the PackedFunc
    receives a TVMArrayHandle whose content is valid for the duration of
    the callback. If the PackedFunc is a Python callback, that argument
    arrives as an NDArray.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    Returns
    -------
    call : Expr
        The "int32"-typed ``tvm_call_packed`` intrinsic call.

    See Also
    --------
    tvm.extern : Create tensor with extern function call.
    """
    call_args = []
    for arg in args:
        # Buffers are packed into array descriptors; everything else is
        # forwarded untouched.
        if isinstance(arg, _Buffer):
            call_args.append(_pack_buffer(arg))
        else:
            call_args.append(arg)
    return _make.Call("int32", "tvm_call_packed", call_args,
                      _Call.Intrinsic, None, 0)
71
72
def call_pure_intrin(dtype, func_name, *args):
    """Build expression by calling a pure intrinsic function.

    Intrinsics can be overloaded with multiple data types via
    the intrinsic translation rule.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The intrinsic function name.

    args : list
        Positional arguments.

    Returns
    -------
    call : Expr
        The call expression.
    """
    # Convert the positional arguments exactly once, matching
    # call_pure_extern / call_extern.
    return _make.Call(
        dtype, func_name, convert(args), _Call.PureIntrinsic, None, 0)
98
99
def call_intrin(dtype, func_name, *args):
    """Build expression by calling an intrinsic function.

    Intrinsics can be overloaded with multiple data types via
    the intrinsic translation rule.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The intrinsic function name.

    args : list
        Positional arguments.

    Returns
    -------
    call : Expr
        The call expression.
    """
    # Convert the positional arguments exactly once, matching
    # call_pure_extern / call_extern.
    return _make.Call(
        dtype, func_name, convert(args), _Call.Intrinsic, None, 0)
125
126
def call_pure_extern(dtype, func_name, *args):
    """Build an expression that calls a pure extern function.

    Parameters
    ----------
    dtype : str
        Data type of the call's result.

    func_name: str
        Name of the extern function to call.

    args : list
        Positional arguments, passed through ``convert``.

    Returns
    -------
    call : Expr
        A ``Call`` node of kind ``PureExtern``.
    """
    converted_args = convert(args)
    call_kind = _Call.PureExtern
    return _make.Call(dtype, func_name, converted_args, call_kind, None, 0)
148
149
def call_extern(dtype, func_name, *args):
    """Build an expression that calls an extern function.

    Parameters
    ----------
    dtype : str
        Data type of the call's result.

    func_name: str
        Name of the extern function to call.

    args : list
        Positional arguments, passed through ``convert``.

    Returns
    -------
    call : Expr
        A ``Call`` node of kind ``Extern``.
    """
    converted_args = convert(args)
    call_kind = _Call.Extern
    return _make.Call(dtype, func_name, converted_args, call_kind, None, 0)
171
172
def call_llvm_intrin(dtype, name, *args):
    """Build an expression that calls an LLVM intrinsic function.

    The intrinsic is looked up by name, and its numeric id is passed as a
    "uint32" constant first argument of a pure ``llvm_intrin`` call.

    Parameters
    ----------
    dtype : str
       Data type of the result.

    name : str
       Name of the LLVM intrinsic function.

    args : list
       Positional arguments.

    Returns
    -------
    call : Expr
        The call expression.
    """
    # Local import, presumably to avoid a circular import with the top-level
    # tvm package -- TODO confirm.
    import tvm
    intrin_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
    # An id of 0 means the lookup did not recognise the name.
    assert intrin_id != 0, "%s is not an LLVM intrinsic" % name
    id_arg = tvm.const(intrin_id, 'uint32')
    return call_pure_intrin(dtype, 'llvm_intrin', id_arg, *args)
196
197
def exp(x):
    """Take exponential of input x.

    Parameters
    ----------
    x : Expr
        Input argument. It is passed through ``convert`` first, so plain
        Python numbers are accepted as well as Exprs.

    Returns
    -------
    y : Expr
        The result.
    """
    # Python scalars have no ``dtype``; convert them (as power() does) first.
    x = convert(x)
    return call_pure_intrin(x.dtype, "exp", x)
212
213
def erf(x):
    """Take Gauss error function of the input x.

    Parameters
    ----------
    x : Expr
        Input argument. It is passed through ``convert`` first, so plain
        Python numbers are accepted as well as Exprs.

    Returns
    -------
    y : Expr
        The result.
    """
    # Python scalars have no ``dtype``; convert them (as power() does) first.
    x = convert(x)
    return call_pure_intrin(x.dtype, "erf", x)
228
229
def tanh(x):
    """Take hyperbolic tanh of input x.

    Parameters
    ----------
    x : Expr
        Input argument. It is passed through ``convert`` first, so plain
        Python numbers are accepted as well as Exprs.

    Returns
    -------
    y : Expr
        The result.
    """
    # Python scalars have no ``dtype``; convert them (as power() does) first.
    x = convert(x)
    return call_pure_intrin(x.dtype, "tanh", x)
244
245
def sigmoid(x):
    """Take sigmoid of input x.

    Parameters
    ----------
    x : Expr
        Input argument. It is passed through ``convert`` first, so plain
        Python numbers are accepted as well as Exprs.

    Returns
    -------
    y : Expr
        The result.
    """
    # Python scalars have no ``dtype``; convert them (as power() does) first.
    x = convert(x)
    return call_pure_intrin(x.dtype, "sigmoid", x)
260
261
def log(x):
    """Take log of input x.

    Parameters
    ----------
    x : Expr
        Input argument. It is passed through ``convert`` first, so plain
        Python numbers are accepted as well as Exprs.

    Returns
    -------
    y : Expr
        The result.
    """
    # Python scalars have no ``dtype``; convert them (as power() does) first.
    x = convert(x)
    return call_pure_intrin(x.dtype, "log", x)
276
def cos(x):
    """Take cos of input x.

    Parameters
    ----------
    x : Expr
        Input argument. It is passed through ``convert`` first, so plain
        Python numbers are accepted as well as Exprs.

    Returns
    -------
    y : Expr
        The result.
    """
    # Python scalars have no ``dtype``; convert them (as power() does) first.
    x = convert(x)
    return call_pure_intrin(x.dtype, "cos", x)
291
def sin(x):
    """Take sin of input x.

    Parameters
    ----------
    x : Expr
        Input argument. It is passed through ``convert`` first, so plain
        Python numbers are accepted as well as Exprs.

    Returns
    -------
    y : Expr
        The result.
    """
    # Python scalars have no ``dtype``; convert them (as power() does) first.
    x = convert(x)
    return call_pure_intrin(x.dtype, "sin", x)
306
def atan(x):
    """Take atan of input x.

    Parameters
    ----------
    x : Expr
        Input argument. It is passed through ``convert`` first, so plain
        Python numbers are accepted as well as Exprs.

    Returns
    -------
    y : Expr
        The result.
    """
    # Python scalars have no ``dtype``; convert them (as power() does) first.
    x = convert(x)
    return call_pure_intrin(x.dtype, "atan", x)
321
def sqrt(x):
    """Take square root of input x.

    Parameters
    ----------
    x : Expr
        Input argument. It is passed through ``convert`` first, so plain
        Python numbers are accepted as well as Exprs.

    Returns
    -------
    y : Expr
        The result.
    """
    # Python scalars have no ``dtype``; convert them (as power() does) first.
    x = convert(x)
    return call_pure_intrin(x.dtype, "sqrt", x)
336
337
def rsqrt(x):
    """Take reciprocal of square root of input x.

    Parameters
    ----------
    x : Expr
        Input argument. It is passed through ``convert`` first, so plain
        Python numbers are accepted as well as Exprs.

    Returns
    -------
    y : Expr
        The result.
    """
    # Python scalars have no ``dtype``; convert them (as power() does) first.
    x = convert(x)
    return call_pure_intrin(x.dtype, "rsqrt", x)
352
353
def floor(x):
    """Compute the floor of a floating-point input.

    Parameters
    ----------
    x : Expr
        Value to take the floor of.

    Returns
    -------
    y : Expr
        The floor expression built by ``_make.floor``.
    """
    floored = _make.floor(x)
    return floored
368
369
def ceil(x):
    """Compute the ceiling of a floating-point input.

    Parameters
    ----------
    x : Expr
        Value to take the ceiling of.

    Returns
    -------
    y : Expr
        The ceiling expression built by ``_make.ceil``.
    """
    ceiled = _make.ceil(x)
    return ceiled
384
385
def trunc(x):
    """Truncate the input toward zero.

    The truncated value of a scalar x is the nearest integer i that is
    closer to zero than x is.

    Parameters
    ----------
    x : Expr
        Value to truncate.

    Returns
    -------
    y : Expr
        The truncation expression built by ``_make.trunc``.
    """
    truncated = _make.trunc(x)
    return truncated
403
404
def abs(x):
    """Compute the element-wise absolute value of the input.

    Parameters
    ----------
    x : Expr
        Value whose magnitude is taken.

    Returns
    -------
    y : Expr
        The absolute-value expression built by ``_make.abs``.
    """
    magnitude = _make.abs(x)
    return magnitude
419
420
def round(x):
    """Round each element of the input to the nearest integer.

    Parameters
    ----------
    x : Expr
        Value to round.

    Returns
    -------
    y : Expr
        The rounding expression built by ``_make.round``.
    """
    rounded = _make.round(x)
    return rounded
435
436
def nearbyint(x):
    """Round each element of the input to the nearest integer.

    This intrinsic uses llvm.nearbyint instead of llvm.round. It is faster,
    but its results may differ from tvm.round: nearbyint honours the
    current rounding mode, whereas tvm.round (llvm.round) ignores it.
    For the differences between the two see:
    https://en.cppreference.com/w/cpp/numeric/math/round
    https://en.cppreference.com/w/cpp/numeric/math/nearbyint

    Parameters
    ----------
    x : Expr
        Value to round.

    Returns
    -------
    y : Expr
        The rounding expression built by ``_make.nearbyint``.
    """
    rounded = _make.nearbyint(x)
    return rounded
458
459
def isnan(x):
    """Test whether the input value is NaN.

    Parameters
    ----------
    x : Expr
        Value to test.

    Returns
    -------
    y : Expr
        The predicate expression built by ``_make.isnan``.
    """
    is_nan = _make.isnan(x)
    return is_nan
474
475
def power(x, y):
    """Raise x to the power y.

    Both operands are passed through ``convert`` before the power node is
    built.

    Parameters
    ----------
    x : Expr
        The base.

    y : Expr
        The exponent.

    Returns
    -------
    z : Expr
        The result.
    """
    base = convert(x)
    exponent = convert(y)
    return _make._OpPow(base, exponent)
493
494
def popcount(x):
    """Count the number of set bits in input x.

    Parameters
    ----------
    x : Expr
        Input argument. It is passed through ``convert`` first, so plain
        Python numbers are accepted as well as Exprs.

    Returns
    -------
    y : Expr
        The result.
    """
    # Python scalars have no ``dtype``; convert them (as power() does) first.
    x = convert(x)
    return call_pure_intrin(x.dtype, "popcount", x)
509
def fmod(x, y):
    """Return the remainder of x divided by y with the same sign as x.

    Parameters
    ----------
    x : Expr
        Dividend. It is passed through ``convert`` first, so a plain Python
        number is accepted.
    y : Expr
        Divisor. It is also passed through ``convert``.

    Returns
    -------
    z : Expr
        The result, typed with x's dtype.
    """
    # Python scalars have no ``dtype``; convert both operands (as power() does).
    x = convert(x)
    y = convert(y)
    return call_pure_intrin(x.dtype, "fmod", x, y)
526
527
def if_then_else(cond, t, f):
    """Build a conditional selection expression.

    All three operands are passed through ``convert`` before the node is
    built.

    Parameters
    ----------
    cond : Expr
        The condition.

    t : Expr
        Result expression when cond is true.

    f : Expr
        Result expression when cond is false.

    Returns
    -------
    result : Node
        The conditional expression.

    Note
    ----
    Unlike Select, if_then_else does not execute the branch that fails the
    condition, so it can guard against out-of-bound access. Also unlike
    Select, if_then_else cannot be vectorized when lanes of the vector
    have different conditions.
    """
    cond_expr, true_expr, false_expr = map(convert, (cond, t, f))
    return _make._OpIfThenElse(cond_expr, true_expr, false_expr)
556
557
558# Intrinsic rule related code
def register_intrin_rule(target, intrin, f=None, override=False):
    """Register an intrinsic function generation rule.

    Intrinsic generation rules are callbacks the code generator uses to
    obtain device-specific calls. This function is a thin wrapper
    equivalent to:

    :code:`register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)`

    TVM may already pre-register intrinsic rules in the backend. Users can
    call this function to change the translation behavior or to add new
    intrinsic rules at runtime.

    Parameters
    ----------
    target : str
        Name of the codegen target.

    intrin : str
        Name of the intrinsic.

    f : function, optional
        The function to register.

    override: boolean optional
        Whether to override an existing entry.

    Returns
    -------
    fregister : function
        The register function, if f is not specified.

    Examples
    --------
    The following code registers an exp expansion rule for opencl.

    .. code-block:: python

        register_intrin_rule("opencl", "exp", my_exp_rule, override=True)
    """
    registry_key = "tvm.intrin.rule.%s.%s" % (target, intrin)
    return _register_func(registry_key, f, override)
600
601
def _rule_float_suffix(op):
    """Intrinsic rule: call the extern, adding an ``f`` suffix for float32.

    This is an example intrinsic generation rule. A float32 call becomes a
    pure extern call to ``<name>f``. A float64 call becomes a pure extern
    call to ``<name>``. Any other dtype is returned unchanged.

    Parameters
    ----------
    op : Expr
        The call expression of the original intrinsic.

    Returns
    -------
    ret : Expr
        The translated call, or op itself if no translation is possible.

    See Also
    --------
    register_intrin_rule : The registration function for intrin rules.
    """
    dtype = op.dtype
    if dtype == "float32":
        extern_name = "%sf" % op.name
    elif dtype == "float64":
        extern_name = op.name
    else:
        # Not a float32/float64 call: leave it untouched.
        return op
    return call_pure_extern(dtype, extern_name, *op.args)
627
628
def _rule_float_direct(op):
    """Intrinsic rule: Directly call pure extern function for floats.

    This is an example intrinsic generation rule. A call with any float
    dtype becomes a pure extern call to the same name.

    Parameters
    ----------
    op : Expr
        The call expression of original intrinsic.

    Returns
    -------
    ret : Expr
        The translated intrinsic rule.
        Return same op if no translation is possible.

    See Also
    --------
    register_intrin_rule : The registration function for intrin rule.
    """
    if str(op.dtype).startswith("float"):
        return call_pure_extern(op.dtype, op.name, *op.args)
    # Non-float call: hand back op itself, as documented above and as
    # _rule_float_suffix does, instead of None.
    return op
652
@_register_func("tvm.default_trace_action")
def _tvm_default_trace_action(*args):
    """Default trace action: print all traced arguments as one list."""
    traced = list(args)
    print(traced)
656
def trace(args, trace_action="tvm.default_trace_action"):
    """Trace tensor data at the runtime.

    The trace function allows tracing a specific tensor at runtime. The
    value to trace must be the last argument, and the resulting call takes
    its dtype. The trace action names the packed function to invoke; by
    default tvm.default_trace_action is used.

    Parameters
    ----------
    args : list of Expr or Buffers.
        Positional arguments. Buffers are packed before the call.

    trace_action : str.
        The name of the trace action.

    Returns
    -------
    call : Expr
        The call expression.

    Raises
    ------
    TypeError
        If args is not a list.
    ValueError
        If args is empty, since there is no traced value to take the
        dtype from.

    See Also
    --------
    tvm.call_packed : Creates packed function.
    """
    if not isinstance(args, list):
        # TypeError subclasses Exception, so existing ``except Exception``
        # handlers keep working.
        raise TypeError("tvm.trace consumes the args as list type")
    if not args:
        raise ValueError("tvm.trace requires at least one argument to trace")
    call_args = [_pack_buffer(x) if isinstance(x, _Buffer) else x for x in args]
    # The trace action name is passed as the first call argument.
    call_args.insert(0, trace_action)
    # The call's dtype follows the traced value, i.e. the last argument.
    return _make.Call(
        args[-1].dtype, "tvm_call_trace_packed", call_args, _Call.Intrinsic, None, 0)
688
# OpenCL: every float dtype calls the extern function of the same name.
register_intrin_rule(target="opencl", intrin="exp",
                     f=_rule_float_direct, override=True)
# Default: float32 maps to the "f"-suffixed extern name, float64 to the plain one.
register_intrin_rule(target="default", intrin="exp",
                     f=_rule_float_suffix, override=True)
693