1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17"""Functions defined in TVM.""" 18# pylint: disable=invalid-name,unused-import,redefined-builtin 19from __future__ import absolute_import as _abs 20 21from numbers import Integral as _Integral 22 23from ._ffi.base import string_types 24from ._ffi.object import register_object, Object 25from ._ffi.node import register_node, NodeBase 26from ._ffi.node import convert_to_node as _convert_to_node 27from ._ffi.node_generic import _scalar_type_inference 28from ._ffi.function import Function 29from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs 30from ._ffi.function import convert_to_tvm_func as _convert_tvm_func 31from ._ffi.runtime_ctypes import TVMType 32from . import _api_internal 33from . import make as _make 34from . import expr as _expr 35from . import tensor as _tensor 36from . import schedule as _schedule 37from . import container as _container 38from . import tag as _tag 39 40int8 = "int8" 41int32 = "int32" 42float32 = "float32" 43handle = "handle" 44 45 46def min_value(dtype): 47 """minimum value of dtype 48 49 Parameters 50 ---------- 51 dtype : str 52 The data type. 53 54 Returns 55 ------- 56 value : tvm.Expr 57 The minimum value of dtype. 58 """ 59 return _api_internal._min_value(dtype) 60 61 62def max_value(dtype): 63 """maximum value of dtype 64 65 Parameters 66 ---------- 67 dtype : str 68 The data type. 69 70 Returns 71 ------- 72 value : tvm.Expr 73 The maximum value of dtype. 74 """ 75 return _api_internal._max_value(dtype) 76 77 78def const(value, dtype=None): 79 """construct a constant 80 81 Parameters 82 ---------- 83 value : number 84 The content of the constant number. 85 86 dtype : str or None, optional 87 The data type. 88 89 Returns 90 ------- 91 const_val: tvm.Expr 92 The result expression. 93 """ 94 if dtype is None: 95 dtype = _scalar_type_inference(value) 96 return _api_internal._const(value, dtype) 97 98 99def get_env_func(name): 100 """Get an EnvFunc by a global name. 101 102 Parameters 103 ---------- 104 name: str 105 The name of the global function. 106 107 Returns 108 ------- 109 env_func : EnvFunc 110 The result env function. 111 112 Note 113 ---- 114 EnvFunc is a Node wrapper around 115 global function that can be serialized via its name. 116 This can be used to serialize function field in the language. 117 """ 118 return _api_internal._EnvFuncGet(name) 119 120 121def convert(value): 122 """Convert value to TVM node or function. 123 124 Parameters 125 ---------- 126 value : python value 127 128 Returns 129 ------- 130 tvm_val : Node or Function 131 Converted value in TVM 132 """ 133 if isinstance(value, (Function, NodeBase)): 134 return value 135 136 if callable(value): 137 return _convert_tvm_func(value) 138 139 return _convert_to_node(value) 140 141 142def load_json(json_str): 143 """Load tvm object from json_str. 144 145 Parameters 146 ---------- 147 json_str : str 148 The json string 149 150 Returns 151 ------- 152 node : Node 153 The loaded tvm node. 154 """ 155 return _api_internal._load_json(json_str) 156 157 158def save_json(node): 159 """Save tvm object as json string. 160 161 Parameters 162 ---------- 163 node : Node 164 A TVM Node object to be saved. 165 166 Returns 167 ------- 168 json_str : str 169 Saved json string. 170 """ 171 return _api_internal._save_json(node) 172 173 174def var(name="tindex", dtype=int32): 175 """Create a new variable with specified name and dtype 176 177 Parameters 178 ---------- 179 name : str 180 The name 181 182 dtype : int 183 The data type 184 185 Returns 186 ------- 187 var : Var 188 The result symbolic variable. 189 """ 190 return _api_internal._Var(name, dtype) 191 192 193def any(*args): 194 """Create a new experssion of the union of all conditions in the arguments 195 196 Parameters 197 ---------- 198 args : list 199 List of symbolic boolean expressions 200 201 Returns 202 ------- 203 expr: Expr 204 Expression 205 """ 206 if not args: 207 raise ValueError("Any must take at least 1 argument") 208 if len(args) == 1: 209 return args[0] 210 ret = _make._OpOr(args[0], args[1]) 211 for i in range(2, len(args)): 212 ret = _make._OpOr(ret, args[i]) 213 return ret 214 215 216def all(*args): 217 """Create a new experssion of the intersection of all conditions in the 218 arguments 219 220 Parameters 221 ---------- 222 args : list 223 List of symbolic boolean expressions 224 225 Returns 226 ------- 227 expr: Expr 228 Expression 229 """ 230 if not args: 231 raise ValueError("Any must take at least 1 argument") 232 if len(args) == 1: 233 return args[0] 234 ret = _make._OpAnd(args[0], args[1]) 235 for i in range(2, len(args)): 236 ret = _make._OpAnd(ret, args[i]) 237 return ret 238 239 240def placeholder(shape, dtype=None, name="placeholder"): 241 """Construct an empty tensor object. 242 243 Parameters 244 ---------- 245 shape: Tuple of Expr 246 The shape of the tensor 247 248 dtype: str, optional 249 The data type of the tensor 250 251 name: str, optional 252 The name hint of the tensor 253 254 Returns 255 ------- 256 tensor: Tensor 257 The created tensor 258 """ 259 shape = (shape,) if isinstance(shape, _expr.Expr) else shape 260 dtype = float32 if dtype is None else dtype 261 return _api_internal._Placeholder( 262 shape, dtype, name) 263 264 265def compute(shape, fcompute, name="compute", tag="", attrs=None): 266 """Construct a new tensor by computing over the shape domain. 267 268 The compute rule is result[axis] = fcompute(axis) 269 270 Parameters 271 ---------- 272 shape: Tuple of Expr 273 The shape of the tensor 274 275 fcompute: lambda function of indices-> value 276 Specifies the input source expression 277 278 name: str, optional 279 The name hint of the tensor 280 281 tag: str, optional 282 Additional tag information about the compute. 283 284 attrs: dict, optional 285 The additional auxiliary attributes about the compute. 286 287 Returns 288 ------- 289 tensor: Tensor 290 The created tensor 291 """ 292 if _tag.TagScope.get_current() is not None: 293 if tag != "": 294 raise ValueError("nested tag is not allowed for now") 295 tag = _tag.TagScope.get_current().tag 296 shape = (shape,) if isinstance(shape, _expr.Expr) else shape 297 # for python3 298 shape = tuple([int(s) if isinstance(s, float) else s for s in shape]) 299 ndim = len(shape) 300 code = fcompute.__code__ 301 302 out_ndim = ndim 303 if code.co_argcount == 0: 304 arg_names = ["i%d" % i for i in range(ndim)] 305 else: 306 arg_names = code.co_varnames[:code.co_argcount] 307 out_ndim = code.co_argcount 308 309 if out_ndim != len(arg_names): 310 raise ValueError("fcompute do not match dimension, ndim=%d" % ndim) 311 312 dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])] 313 body = fcompute(*[v.var for v in dim_var]) 314 315 if isinstance(body, _tensor.TensorIntrinCall): 316 for i, s in enumerate(shape[out_ndim:]): 317 var_name = "ax" + str(i) 318 dim_var.append(_IterVar((0, s), var_name, 4)) 319 op_node = _api_internal._TensorComputeOp(name, 320 tag, 321 dim_var, 322 body.reduce_axis, 323 out_ndim, 324 body.intrin, 325 body.tensors, 326 body.regions, 327 body.scalar_inputs) 328 else: 329 if not isinstance(body, (list, tuple)): 330 body = [body] 331 body = convert(body) 332 op_node = _api_internal._ComputeOp( 333 name, tag, attrs, dim_var, body) 334 335 num = op_node.num_outputs 336 outputs = tuple(op_node.output(i) for i in range(num)) 337 return outputs[0] if num == 1 else outputs 338 339 340def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attrs=None): 341 """Construct new tensors by scanning over axis. 342 343 Parameters 344 ---------- 345 init: Tensor or list of Tensor 346 The initial condition of first init.shape[0] timestamps 347 348 update: Tensor or list of Tensor 349 The update rule of the scan given by symbolic tensor. 350 351 state_placeholder: Tensor or list of Tensor 352 The placeholder variables used by update. 353 354 inputs: Tensor or list of Tensor, optional 355 The list of inputs to the scan. This is not required, but can 356 be useful for the compiler to detect scan body faster. 357 358 name: str, optional 359 The name hint of the tensor 360 361 tag: str, optional 362 Additonal tag information about the compute. 363 364 attrs: dict, optional 365 The additional auxiliary attributes about the compute. 366 367 Returns 368 ------- 369 tensor: Tensor or list of Tensors 370 The created tensor or tuple of tensors it it contains multiple outputs. 371 372 Example 373 ------- 374 .. code-block:: python 375 376 # The following code is equivalent to numpy.cumsum 377 m = tvm.var("m") 378 n = tvm.var("n") 379 X = tvm.placeholder((m, n), name="X") 380 s_state = tvm.placeholder((m, n)) 381 s_init = tvm.compute((1, n), lambda _, i: X[0, i]) 382 s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i]) 383 res = tvm.scan(s_init, s_update, s_state, X) 384 """ 385 if _tag.TagScope.get_current() is not None: 386 if tag != "": 387 raise ValueError("nested tag is not allowed for now") 388 tag = _tag.TagScope.get_current().tag 389 if isinstance(init, _tensor.Tensor): 390 init = [init] 391 if isinstance(update, _tensor.Tensor): 392 update = [update] 393 if isinstance(state_placeholder, _tensor.Tensor): 394 state_placeholder = [state_placeholder] 395 if isinstance(inputs, _tensor.Tensor): 396 inputs = [inputs] 397 if inputs is None: 398 inputs = [] 399 if len(init) != len(update) or len(init) != len(state_placeholder): 400 raise ValueError("init, update, state_placeholder must have same length") 401 axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3) 402 op = _api_internal._ScanOp(name, tag, attrs, 403 axis, init, update, 404 state_placeholder, inputs) 405 res = [op.output(i) for i in range(len(update))] 406 return res[0] if len(res) == 1 else res 407 408 409def extern(shape, 410 inputs, 411 fcompute, 412 name="extern", 413 dtype=None, 414 in_buffers=None, 415 out_buffers=None, 416 tag="", 417 attrs=None): 418 """Compute several tensor via extern function. 419 420 Parameters 421 ---------- 422 shape: tuple or list of tuples. 423 The shape of the outputs. 424 425 inputs: list of Tensor 426 The inputs 427 428 fcompute: lambda function of inputs, outputs-> stmt 429 Specifies the IR statement to do the computation. 430 See the following note for function signature of fcompute 431 432 .. note:: 433 **Parameters** 434 435 - **ins** (list of :any:`Buffer`) - Placeholder for each inputs 436 - **outs** (list of :any:`Buffer`) - Placeholder for each outputs 437 438 **Returns** 439 440 - **stmt** (:any:`Stmt`) - The statement that carries out array computation. 441 442 name: str, optional 443 The name hint of the tensor 444 445 dtype: str or list of str, optional 446 The data types of outputs, 447 by default dtype will be same as inputs. 448 449 in_buffers: Buffer or list of Buffer, optional 450 Input buffers. 451 452 out_buffers: Buffer or list of Buffers, optional 453 Output buffers. 454 455 456 tag: str, optional 457 Additonal tag information about the compute. 458 459 attrs: dict, optional 460 The additional auxiliary attributes about the compute. 461 462 Returns 463 ------- 464 tensor: Tensor or list of Tensors 465 The created tensor or tuple of tensors it it contains multiple outputs. 466 467 Example 468 ------- 469 In the code below, C is generated by calling external PackedFunc 470 `tvm.contrib.cblas.matmul` 471 472 .. code-block:: python 473 474 A = tvm.placeholder((n, l), name='A') 475 B = tvm.placeholder((l, m), name='B') 476 C = tvm.extern((n, m), [A, B], 477 lambda ins, outs: tvm.call_packed( 478 "tvm.contrib.cblas.matmul", 479 ins[0], ins[1], outs[0], 0, 0), name="C") 480 """ 481 if _tag.TagScope.get_current() is not None: 482 if tag != "": 483 raise ValueError("nested tag is not allowed for now") 484 tag = _tag.TagScope.get_current().tag 485 shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape 486 if shape == () or isinstance(shape[0], (_expr.Expr, _Integral)): 487 shape = [shape] 488 if in_buffers is not None: 489 in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers 490 if len(inputs) != len(in_buffers): 491 raise RuntimeError("Number of inputs and in_buffers mismatch: %d vs %d." 492 % (len(inputs), len(in_buffers))) 493 if out_buffers is not None: 494 out_buffers = [out_buffers] if not isinstance(out_buffers, list) else out_buffers 495 if len(shape) != len(out_buffers): 496 raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d." 497 % (len(shape), len(out_buffers))) 498 input_placeholders = in_buffers or [] 499 output_placeholders = out_buffers or [] 500 types = set() 501 for t in inputs: 502 if not isinstance(t, _tensor.Tensor): 503 raise ValueError("expect inputs to be tensor") 504 if in_buffers is None: 505 input_placeholders.append( 506 decl_buffer(t.shape, t.dtype, t.op.name)) 507 types.add(t.dtype) 508 509 if dtype is None: 510 if len(types) != 1: 511 raise ValueError("Cannot infer output type, please provide dtype argument") 512 infered_type = types.pop() 513 dtype = [infered_type for _ in shape] 514 if isinstance(dtype, str): 515 dtype = [dtype] 516 517 if out_buffers is None: 518 for shp, dt in zip(shape, dtype): 519 output_placeholders.append(decl_buffer(shp, dt, name)) 520 body = fcompute(input_placeholders, output_placeholders) 521 if isinstance(body, _expr.Expr): 522 body = _make.Evaluate(body) 523 524 op = _api_internal._ExternOp(name, tag, attrs, 525 inputs, input_placeholders, 526 output_placeholders, body) 527 res = [op.output(i) for i in range(len(output_placeholders))] 528 return res[0] if len(res) == 1 else res 529 530 531def decl_buffer(shape, 532 dtype=None, 533 name="buffer", 534 data=None, 535 strides=None, 536 elem_offset=None, 537 scope="", 538 data_alignment=-1, 539 offset_factor=0, 540 buffer_type=""): 541 """Declare a new symbolic buffer. 542 543 Normally buffer is created automatically during lower and build. 544 This is only needed if user want to specify their own buffer layout. 545 546 See the note below for detailed discussion on usage of buffer. 547 548 Parameters 549 ---------- 550 shape : tuple of Expr 551 The shape of the buffer. 552 553 dtype : str, optional 554 The data type of the buffer. 555 556 name : str, optional 557 The name of the buffer. 558 559 data : Var, optional 560 The data pointer in the buffer. 561 562 strides: array of Expr 563 The stride of the buffer. 564 565 elem_offset: Expr, optional 566 The beginning offset of the array to data. 567 In terms of number of elements of dtype. 568 569 scope: str, optional 570 The storage scope of the buffer, if not global. 571 If scope equals empty string, it means it is global memory. 572 573 data_alignment: int, optional 574 The alignment of data pointer in bytes. 575 If -1 is passed, the alignment will be set to TVM's internal default. 576 577 offset_factor: int, optional 578 The factor of elem_offset field, when set, 579 elem_offset is required to be multiple of offset_factor. 580 If 0 is pssed, the alignment will be set to 1. 581 if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None. 582 583 buffer_type: str, optional, {"", "auto_broadcast"} 584 auto_broadcast buffer allows one to implement broadcast computation 585 without considering whether dimension size equals to one. 586 TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1. 587 588 Returns 589 ------- 590 buffer : Buffer 591 The created buffer 592 593 Example 594 ------- 595 Here's an example of how broadcast buffer can be used to define a symbolic broadcast operation, 596 597 .. code-block:: python 598 599 m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2") 600 n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2") 601 o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2") 602 A = tvm.placeholder((m0, m1, m2), name='A') 603 B = tvm.placeholder((n0, n1, n2), name='B') 604 C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C') 605 Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast") 606 Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast") 607 s = tvm.create_schedule(C.op) 608 fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb}) 609 ctx = tvm.cpu(0) 610 a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx) 611 b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx) 612 c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx) 613 fadd(a, b, c) 614 tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) 615 616 Note 617 ---- 618 Buffer data structure reflects the DLTensor structure in dlpack. 619 While DLTensor data structure is very general, it is usually helpful 620 to create function that only handles specific case of data structure 621 and make compiled function benefit from it. 622 623 If user pass strides and elem_offset is passed as None 624 when constructing the function, then the function will be specialized 625 for the DLTensor that is compact and aligned. 626 If user pass a fully generic symbolic array to the strides, 627 then the resulting function becomes fully generic. 628 """ 629 shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape 630 dtype = float32 if dtype is None else dtype 631 strides = () if strides is None else strides 632 if offset_factor != 0 and elem_offset is None: 633 shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" 634 elem_offset = var('%s_elem_offset' % name, shape_dtype) 635 if data is None: 636 data = var(name, "handle") 637 return _api_internal._Buffer( 638 data, dtype, shape, strides, elem_offset, name, scope, 639 data_alignment, offset_factor, buffer_type) 640 641def layout(layout_str): 642 """Create a layout node from a string. 643 644 Parameters 645 ---------- 646 layout_str : str 647 A layout representation is composed of upper cases, lower cases and numbers, 648 where upper case indicates a primal axis and 649 the corresponding lower case with factor size indicates the subordinate axis. 650 For example, NCHW16c can describe a 5-D tensor of 651 [batch_size, channel, height, width, channel_block]. 652 Here subordinate axis channel_block=16 is the factor size of 653 the primal axis C (channel). 654 655 Returns 656 ------- 657 layout : Layout 658 The created layout 659 """ 660 return _api_internal._Layout(layout_str) 661 662def bijective_layout(src_layout, dst_layout): 663 """Create a bijective layout mapping. 664 665 Parameters 666 ---------- 667 src_layout : str or Layout 668 source layout. 669 670 dst_layout : str or Layout 671 destination layout. 672 673 Returns 674 ------- 675 bijective_layout : BijectiveLayout 676 The created bijective layout 677 """ 678 if isinstance(src_layout, str): 679 src_layout = layout(src_layout) 680 if isinstance(dst_layout, str): 681 dst_layout = layout(dst_layout) 682 return _api_internal._BijectiveLayout(src_layout, dst_layout) 683 684def _IterVar(dom, name, iter_type, thread_tag=''): 685 """Internal function to create IterVar 686 687 Parameters 688 ---------- 689 dom : Range 690 The domain of iteration. 691 692 name : str 693 The name of iteration variable. 694 695 iter_type : int 696 The type of iteration. 697 698 thread_tag : str 699 The thread tag of the iteration variable. 700 701 Returns 702 ------- 703 iter_var : IterVar 704 The result itervar 705 """ 706 if dom is not None: 707 if isinstance(dom, (list, tuple)): 708 if len(dom) != 2: 709 raise TypeError("need to be list of ranges") 710 dom = Range(dom[0], dom[1]) 711 712 if not isinstance(dom, _container.Range): 713 raise TypeError("dom need to be Range") 714 name = name if name else 'iter' 715 v = var(name) 716 return _api_internal._IterVar(dom, v, iter_type, thread_tag) 717 718 719def thread_axis(dom=None, tag='', name=''): 720 """Create a new IterVar to represent thread index. 721 722 Parameters 723 ---------- 724 dom : Range or str 725 The domain of iteration 726 When str is passed, dom is set to None and str is used as tag 727 728 tag : str, optional 729 The thread tag 730 731 name : str, optional 732 The name of the var. 733 734 Returns 735 ------- 736 axis : IterVar 737 The thread itervar. 738 """ 739 if isinstance(dom, string_types): 740 tag, dom = dom, None 741 if not tag: 742 raise ValueError("tag must be given as Positional or keyword argument") 743 name = name if name else tag 744 return _IterVar(dom, name, 1, tag) 745 746 747def reduce_axis(dom, name="rv"): 748 """Create a new IterVar for reduction. 749 750 Parameters 751 ---------- 752 dom : Range 753 The domain of iteration. 754 755 name : str 756 The name of the variable. 757 758 Returns 759 ------- 760 axis : IterVar 761 An iteration variable representing the value. 762 """ 763 return _IterVar(dom, name, 2) 764 765 766def comm_reducer(fcombine, fidentity, name="reduce"): 767 """Create a commutative reducer for reduction. 768 769 Parameters 770 ---------- 771 fcombine : function(Expr -> Expr -> Expr) 772 A binary function which takes two Expr as input to return a Expr. 773 774 fidentity : function(str -> Expr) 775 A function which takes a type string as input to return a const Expr. 776 777 Returns 778 ------- 779 reducer : function 780 A function which creates a reduce expression over axis. 781 There are two ways to use it: 782 783 1. accept (expr, axis, where) to produce an Reduce Expr on 784 specified axis; 785 2. simply use it with multiple Exprs. 786 787 Example 788 ------- 789 .. code-block:: python 790 791 n = tvm.var('n') 792 m = tvm.var('m') 793 mysum = tvm.comm_reducer(lambda x, y: x+y, 794 lambda t: tvm.const(0, dtype=t), name="mysum") 795 A = tvm.placeholder((n, m), name='A') 796 k = tvm.reduce_axis((0, m), name='k') 797 B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name='B') 798 """ 799 def _reduce_directly(*args): 800 num = len(args) 801 # process `where` is None 802 if num == 3 and args[2] is None: 803 num = 2 804 res = args[0] 805 for i in range(num-1): 806 res = fcombine(res, args[i+1]) 807 return res 808 809 def _make_reduce(expr, axis, where=None): 810 code = fcombine.__code__ 811 assert fcombine.__code__.co_argcount == 2 812 expr = convert(expr) 813 if isinstance(expr, _container.Array): 814 size = len(expr) 815 larr = [] 816 rarr = [] 817 dtypes = [] 818 for i in range(size): 819 dtype = expr[i].dtype 820 dtypes.append(dtype) 821 lname = code.co_varnames[0] + '_' + str(i) 822 larr.append(var(lname, dtype)) 823 rname = code.co_varnames[1] + '_' + str(i) 824 rarr.append(var(rname, dtype)) 825 lhs = convert(larr) 826 rhs = convert(rarr) 827 result = fcombine(lhs, rhs) 828 id_elem = fidentity(*dtypes) 829 else: 830 assert isinstance(expr, _expr.Expr) 831 size = 1 832 dtype = expr.dtype 833 lvar = var(code.co_varnames[0], dtype) 834 rvar = var(code.co_varnames[1], dtype) 835 result = [fcombine(lvar, rvar)] 836 id_elem = [fidentity(dtype)] 837 lhs = convert([lvar]) 838 rhs = convert([rvar]) 839 expr = convert([expr]) 840 result = convert(result) 841 id_elem = convert(id_elem) 842 combiner = _make.CommReducer(lhs, rhs, result, id_elem) 843 axis = convert(axis if isinstance(axis, (list, tuple)) else [axis]) 844 if where is None: 845 where = convert(True) 846 outputs = tuple(_expr.Reduce(combiner, expr, axis, where, i) 847 for i in range(size)) 848 return outputs[0] if size == 1 else outputs 849 850 # pylint: disable=keyword-arg-before-vararg 851 def reducer(expr, axis, where=None, *args): 852 if isinstance(axis, (_schedule.IterVar, list, tuple)): 853 assert not args 854 return _make_reduce(expr, axis, where) 855 if where is None: 856 assert not args 857 return _reduce_directly(expr, axis) 858 return _reduce_directly(expr, axis, where, *args) 859 860 doc_str = """Create a {0} expression over axis. 861 862 Parameters 863 ---------- 864 expr : Expr 865 The source expression. 866 axis : IterVar 867 The reduction IterVar axis 868 where : optional, Expr 869 Filtering predicate of the reduction. 870 Returns 871 ------- 872 value : Expr 873 The result value. 874 875 Example 876 ------- 877 .. code-block:: python 878 879 m = tvm.var("m") 880 n = tvm.var("n") 881 A = tvm.placeholder((m, n), name="A") 882 k = tvm.reduce_axis((0, n), name="k") 883 884 # there are two way to use this {0} reducer: 885 # mode 1, accept (expr, axis, where) to produce an Reduce Expr 886 B = tvm.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B") 887 888 # mode 2, simply use it with multiple Exprs: 889 {0}_res = tvm.{0}(m, n) 890 """ 891 reducer.__doc__ = doc_str.format(name) 892 return reducer 893 894def div(a, b): 895 """Compute a / b as in C/C++ semantics. 896 897 Parameters 898 ---------- 899 a : Expr 900 The left hand operand, known to be non-negative. 901 902 b : Expr 903 The right hand operand, known to be non-negative. 904 905 Returns 906 ------- 907 res : Expr 908 The result expression. 909 Note 910 ---- 911 When operands are integers, returns truncdiv(a, b). 912 """ 913 return _make._OpDiv(a, b) 914 915 916def indexdiv(a, b): 917 """Compute floor(a / b) where a and b are non-negative. 918 919 Parameters 920 ---------- 921 a : Expr 922 The left hand operand, known to be non-negative. 923 924 b : Expr 925 The right hand operand, known to be non-negative. 926 927 Returns 928 ------- 929 res : Expr 930 The result expression. 931 932 Note 933 ---- 934 Use this function to split non-negative indices. 935 This function may take advantage of operands' 936 non-negativeness. 937 """ 938 return _make._OpIndexDiv(a, b) 939 940 941def indexmod(a, b): 942 """Compute the remainder of indexdiv. a and b are non-negative. 943 944 Parameters 945 ---------- 946 a : Expr 947 The left hand operand, known to be non-negative. 948 949 b : Expr 950 The right hand operand, known to be non-negative. 951 952 Returns 953 ------- 954 res : Expr 955 The result expression. 956 957 Note 958 ---- 959 Use this function to split non-negative indices. 960 This function may take advantage of operands' 961 non-negativeness. 962 """ 963 return _make._OpIndexMod(a, b) 964 965 966def truncdiv(a, b): 967 """Compute the truncdiv of two expressions. 968 969 Parameters 970 ---------- 971 a : Expr 972 The left hand operand 973 974 b : Expr 975 The right hand operand 976 977 Returns 978 ------- 979 res : Expr 980 The result expression. 981 982 Note 983 ---- 984 This is the default integer division behavior in C. 985 """ 986 return _make._OpTruncDiv(a, b) 987 988 989def truncmod(a, b): 990 """Compute the truncmod of two expressions. 991 992 Parameters 993 ---------- 994 a : Expr 995 The left hand operand 996 997 b : Expr 998 The right hand operand 999 1000 Returns 1001 ------- 1002 res : Expr 1003 The result expression. 1004 1005 Note 1006 ---- 1007 This is the default integer division behavior in C. 1008 """ 1009 return _make._OpTruncMod(a, b) 1010 1011 1012def floordiv(a, b): 1013 """Compute the floordiv of two expressions. 1014 1015 Parameters 1016 ---------- 1017 a : Expr 1018 The left hand operand 1019 1020 b : Expr 1021 The right hand operand 1022 1023 Returns 1024 ------- 1025 res : Expr 1026 The result expression. 1027 """ 1028 return _make._OpFloorDiv(a, b) 1029 1030 1031def floormod(a, b): 1032 """Compute the floormod of two expressions. 1033 1034 Parameters 1035 ---------- 1036 a : Expr 1037 The left hand operand 1038 1039 b : Expr 1040 The right hand operand 1041 1042 Returns 1043 ------- 1044 res : Expr 1045 The result expression. 1046 """ 1047 return _make._OpFloorMod(a, b) 1048 1049 1050_init_api("tvm.api") 1051 1052#pylint: disable=unnecessary-lambda 1053sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum") 1054min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min') 1055max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max') 1056