# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Expression Intrinsics and math functions in TVM."""
# pylint: disable=redefined-builtin
from __future__ import absolute_import as _abs

from ._ffi.function import register_func as _register_func
from . import make as _make
from .api import convert, const
from .expr import Call as _Call
from .schedule import Buffer as _Buffer


def _pack_buffer(buf):
    """Build the intrinsic call that packs a buffer into an array handle.

    Parameters
    ----------
    buf : Buffer
        The buffer to pack. Its shape must be non-empty.

    Returns
    -------
    call : Expr
        A ``tvm_stack_make_array`` intrinsic call built from the buffer's
        data, shape, strides, number of dimensions, dtype and element offset.
    """
    assert buf.shape
    shape = _make.Call("handle", "tvm_stack_make_shape", buf.shape,
                       _Call.Intrinsic, None, 0)
    # Buffers without explicit strides pass a literal 0 in the strides slot.
    strides = _make.Call("handle", "tvm_stack_make_shape", buf.strides,
                         _Call.Intrinsic, None, 0) if buf.strides else 0
    pack_args = [buf.data,
                 shape,
                 strides,
                 len(buf.shape),
                 # A zero constant is used only to carry the buffer's dtype.
                 const(0, dtype=buf.dtype),
                 buf.elem_offset]
    return _make.Call("handle", "tvm_stack_make_array",
                      pack_args, _Call.Intrinsic, None, 0)


def call_packed(*args):
    """Build expression by calling an external packed function.

    The argument to packed function can be Expr or Buffer.
    The argument is the corresponding POD type when Expr is presented.

    When the argument is Buffer, the corresponding PackedFunc
    will receive a TVMArrayHandle whose content is valid during the
    callback period.
    If the PackedFunc is a python callback, then the corresponding argument
    is NDArray.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    Returns
    -------
    call : Expr
        The call expression.

    See Also
    --------
    tvm.extern : Create tensor with extern function call.
    """
    # Buffers are packed into array handles; everything else passes through.
    call_args = [_pack_buffer(x) if isinstance(x, _Buffer) else x for x in args]
    return _make.Call(
        "int32", "tvm_call_packed", call_args, _Call.Intrinsic, None, 0)


def call_pure_intrin(dtype, func_name, *args):
    """Build expression by calling a pure intrinsic function.

    Intrinsics can be overloaded with multiple data types via
    the intrinsic translation rule.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The intrinsic function name.

    args : list
        Positional arguments.

    Returns
    -------
    call : Expr
        The call expression.
    """
    # Convert the arguments exactly once, like the call_*extern helpers.
    return _make.Call(
        dtype, func_name, convert(args), _Call.PureIntrinsic, None, 0)


def call_intrin(dtype, func_name, *args):
    """Build expression by calling an intrinsic function.

    Intrinsics can be overloaded with multiple data types via
    the intrinsic translation rule.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The intrinsic function name.

    args : list
        Positional arguments.

    Returns
    -------
    call : Expr
        The call expression.
    """
    # Convert the arguments exactly once, like the call_*extern helpers.
    return _make.Call(
        dtype, func_name, convert(args), _Call.Intrinsic, None, 0)


def call_pure_extern(dtype, func_name, *args):
    """Build expression by calling a pure extern function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The extern function name.

    args : list
        Positional arguments.

    Returns
    -------
    call : Expr
        The call expression.
    """
    return _make.Call(
        dtype, func_name, convert(args), _Call.PureExtern, None, 0)


def call_extern(dtype, func_name, *args):
    """Build expression by calling an extern function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The extern function name.

    args : list
        Positional arguments.

    Returns
    -------
    call : Expr
        The call expression.
    """
    return _make.Call(
        dtype, func_name, convert(args), _Call.Extern, None, 0)


def call_llvm_intrin(dtype, name, *args):
    """Build expression by calling an llvm intrinsic function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    name : str
        The name of the llvm intrinsic function.

    args : list
        Positional arguments.

    Returns
    -------
    call : Expr
        The call expression.
    """
    # NOTE(review): imported inside the function, presumably to avoid a
    # circular import while the package is loading -- confirm before moving.
    import tvm
    llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
    # An id of 0 is treated as "no such intrinsic".
    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
    return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)


def exp(x):
    """Take exponential of input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "exp", x)


def erf(x):
    """Take gauss error function of the input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "erf", x)


def tanh(x):
    """Take hyperbolic tanh of input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "tanh", x)


def sigmoid(x):
    """Quick function to get sigmoid.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "sigmoid", x)


def log(x):
    """Take log of input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "log", x)


def cos(x):
    """Take cos of input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "cos", x)


def sin(x):
    """Take sin of input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "sin", x)


def atan(x):
    """Take atan of input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "atan", x)


def sqrt(x):
    """Take square root of input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "sqrt", x)


def rsqrt(x):
    """Take reciprocal of square root of input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "rsqrt", x)


def floor(x):
    """Take floor of float input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return _make.floor(x)


def ceil(x):
    """Take ceil of float input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return _make.ceil(x)


def trunc(x):
    """Get truncated value of the input.

    The truncated value of the scalar x is the
    nearest integer i which is closer to zero than x is.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return _make.trunc(x)


def abs(x):
    """Get absolute value of the input element-wise.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return _make.abs(x)


def round(x):
    """Round elements of the array to the nearest integer.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return _make.round(x)


def nearbyint(x):
    """Round elements of the array to the nearest integer.

    This intrinsic uses llvm.nearbyint instead of llvm.round,
    which is faster but gives results that differ from tvm.round.
    Notably nearbyint rounds according to the rounding mode,
    whereas tvm.round (llvm.round) ignores that.
    For differences between the two see:
    https://en.cppreference.com/w/cpp/numeric/math/round
    https://en.cppreference.com/w/cpp/numeric/math/nearbyint

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return _make.nearbyint(x)


def isnan(x):
    """Check if input value is NaN.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return _make.isnan(x)


def power(x, y):
    """Raise x to the power y.

    Parameters
    ----------
    x : Expr
        Input argument.

    y : Expr
        The exponent

    Returns
    -------
    z : Expr
        The result.
    """
    return _make._OpPow(convert(x), convert(y))


def popcount(x):
    """Count the number of set bits in input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "popcount", x)


def fmod(x, y):
    """Return the remainder of x divided by y with the same sign as x.

    Parameters
    ----------
    x : Expr
        Input argument.
    y : Expr
        Input argument.

    Returns
    -------
    z : Expr
        The result.
    """
    # The result dtype is taken from the first operand.
    return call_pure_intrin(x.dtype, "fmod", x, y)


def if_then_else(cond, t, f):
    """Conditional selection expression.

    Parameters
    ----------
    cond : Expr
        The condition

    t : Expr
        The result expression if cond is true.

    f : Expr
        The result expression if cond is false.

    Returns
    -------
    result : Node
        The result of conditional expression.

    Note
    ----
    Unlike Select, if_then_else will not execute
    the branch that does not satisfy the condition.
    You can use it to guard against out of bound access.
    Unlike Select, if_then_else cannot be vectorized
    if some lanes in the vector have different conditions.
    """
    return _make._OpIfThenElse(convert(cond), convert(t), convert(f))


# Intrinsic rule related code
def register_intrin_rule(target, intrin, f=None, override=False):
    """Register an intrinsic function generation rule.

    Intrinsic generation rules are callback functions for
    code generator to get device specific calls.
    This function simply translates to:

    :code:`register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)`

    TVM may already pre-register intrinsic rules in the backend.
    However, user can use this function to change the intrinsic translation
    behavior or add new intrinsic rules during runtime.

    Parameters
    ----------
    target : str
        The name of codegen target.

    intrin : str
        The name of the intrinsic.

    f : function, optional
        The function to be registered.

    override: boolean optional
        Whether override existing entry.

    Returns
    -------
    fregister : function
        Register function if f is not specified.

    Examples
    --------
    The following code registers exp expansion rule for opencl.

    .. code-block:: python

        register_intrin_rule("opencl", "exp", my_exp_rule, override=True)
    """
    return _register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)


def _rule_float_suffix(op):
    """Intrinsic rule: Add float suffix if it is float32.

    This is an example intrinsic generation rule.

    Parameters
    ----------
    op : Expr
        The call expression of original intrinsic.

    Returns
    -------
    ret : Expr
        The translated intrinsic rule.
        Return same op if no translation is possible.

    See Also
    --------
    register_intrin_rule : The registration function for intrin rule.
    """
    if op.dtype == "float32":
        # float32 calls the "f"-suffixed extern, e.g. "exp" becomes "expf".
        return call_pure_extern(op.dtype, "%sf" % op.name, *op.args)
    if op.dtype == "float64":
        return call_pure_extern(op.dtype, op.name, *op.args)
    return op


def _rule_float_direct(op):
    """Intrinsic rule: Directly call pure extern function for floats.

    This is an example intrinsic generation rule.

    Parameters
    ----------
    op : Expr
        The call expression of original intrinsic.

    Returns
    -------
    ret : Expr
        The translated intrinsic rule.
        Return same op if no translation is possible.

    See Also
    --------
    register_intrin_rule : The registration function for intrin rule.
    """
    if str(op.dtype).startswith("float"):
        return call_pure_extern(op.dtype, op.name, *op.args)
    # Hand back the untouched op (not None) when no translation applies, as
    # the documented contract states and as _rule_float_suffix already does.
    return op


@_register_func("tvm.default_trace_action")
def _tvm_default_trace_action(*args):
    """Default trace action: print the received arguments as a list."""
    print(list(args))


def trace(args, trace_action="tvm.default_trace_action"):
    """Trace tensor data at the runtime.

    The trace function allows to trace specific tensor at the
    runtime. The tracing value should come as last argument.
    The trace action should be specified, by default
    tvm.default_trace_action is used.

    Parameters
    ----------
    args : list of Expr or Buffers.
        Positional arguments. Must be non-empty: the dtype of the last
        element becomes the dtype of the returned call.

    trace_action : str.
        The name of the trace action.

    Returns
    -------
    call : Expr
        The call expression.

    Raises
    ------
    TypeError
        If ``args`` is not a list.

    See Also
    --------
    tvm.call_packed : Creates packed function.
    """
    if not isinstance(args, list):
        raise TypeError("tvm.trace consumes the args as list type")
    call_args = [_pack_buffer(x) if isinstance(x, _Buffer) else x for x in args]
    # The trace action name goes first, ahead of the traced values.
    call_args.insert(0, trace_action)
    return _make.Call(
        args[-1].dtype, "tvm_call_trace_packed", call_args, _Call.Intrinsic, None, 0)


# opencl pattern for exp
register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
# default pattern for exp
register_intrin_rule("default", "exp", _rule_float_suffix, override=True)