1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=unused-import
18"""The computation schedule api of TVM."""
19import tvm._ffi
20from tvm._ffi.base import string_types
21
22from tvm.runtime import Object, convert
23from tvm.ir import container as _container
24from tvm.tir import IterVar, Buffer
25
26from . import tensor as _tensor
27from . import _ffi_api
28
29
@tvm._ffi.register_object
class Split(Object):
    """Split operation on axis.

    FFI-backed relation node; its attributes are defined on the C++ side.
    """
33
34
@tvm._ffi.register_object
class Fuse(Object):
    """Fuse operation on axis.

    FFI-backed relation node; its attributes are defined on the C++ side.
    """
38
39
@tvm._ffi.register_object
class Singleton(Object):
    """Singleton axis.

    FFI-backed relation node; its attributes are defined on the C++ side.
    """
43
44
def create_schedule(ops):
    """Create a schedule for a list of operations.

    Parameters
    ----------
    ops : list of Operations
        The source expression. A single operation is also accepted and
        treated as a one-element list.

    Returns
    -------
    sch : schedule.Schedule
        The created schedule.
    """
    # Normalize a bare operation into a singleton list before crossing the FFI.
    op_list = ops if isinstance(ops, (list, _container.Array)) else [ops]
    return _ffi_api.CreateSchedule(op_list)
61
62
@tvm._ffi.register_object
class Schedule(Object):
    """Schedule for all the stages."""

    def __getitem__(self, k):
        """Return the Stage scheduling ``k`` (a Tensor or an Operation)."""
        op = k.op if isinstance(k, _tensor.Tensor) else k
        if not isinstance(op, _tensor.Operation):
            raise ValueError("Expect schedule key to be Tensor or Operation")
        if op not in self.stage_map:
            raise ValueError("Cannot find the operation %s in schedule" % (str(op)))
        return self.stage_map[op]

    def normalize(self):
        """Build a normalized schedule from the current schedule.

        Inserts the rebases needed so that certain iter vars start from 0.
        Required before bound inference and follow-up steps.

        Returns
        -------
        sch : Schedule
            The normalized schedule.
        """
        return _ffi_api.ScheduleNormalize(self)

    def create_group(self, outputs, inputs, include_inputs=False):
        """Create a stage group from an output/input boundary.

        Every operator lying between ``outputs`` and ``inputs`` becomes a
        member of the group; outputs are included, inputs are not.

        Parameters
        ----------
        outputs : list of Tensors
            The outputs of the group.

        inputs : list of Tensors
            The inputs of the group.

        include_inputs : boolean, optional
            Whether to include input operations in the group when they are
            used by the outputs.

        Returns
        -------
        group : Stage
            A virtual stage representing the group; use compute_at to move
            the attachment point of the whole group.
        """
        outputs = [outputs] if isinstance(outputs, _tensor.Tensor) else outputs
        inputs = [inputs] if isinstance(inputs, _tensor.Tensor) else inputs
        return _ffi_api.ScheduleCreateGroup(self, outputs, inputs, include_inputs)

    def cache_read(self, tensor, scope, readers):
        """Create a cache read of the original tensor for the given readers.

        Mutates the bodies of the readers and introduces a new cache stage
        for the tensor. Call this before any split/fuse scheduling.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be cached.
        scope : str
            The scope of the cache.
        readers : list of Tensor or Operation
            The readers that will read from the cache.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
        """
        if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
            readers = [readers]
        # The FFI expects operations, so unwrap any Tensor into its op.
        reader_ops = []
        for reader in readers:
            reader_ops.append(reader.op if isinstance(reader, _tensor.Tensor) else reader)
        return _ffi_api.ScheduleCacheRead(self, tensor, scope, reader_ops)

    def cache_write(self, tensor, scope):
        """Create a cache write of the original tensor, before storing into it.

        Mutates the body of the tensor and creates a new cache stage feeding
        into it.

        This can be used to support data layout transformation: if a
        split/fuse/reorder was applied on the data-parallel axes of the
        tensor before cache_write is called, the intermediate cache stores
        the data in the iteration order of the leaf axes, and the data is
        transformed back to the original layout in the original tensor.
        Call compute_inline on the original stage to keep only the
        transformed layout.

        Parameters
        ----------
        tensor : Tensor, list or tuple
            The tensors to feed into. All of them must be produced by one computeOp.
        scope : str
            The scope of the cache.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
        """
        return _ffi_api.ScheduleCacheWrite(self, tensor, scope)

    def rfactor(self, tensor, axis, factor_axis=0):
        """Factor a reduction axis of ``tensor``'s schedule into an explicit axis.

        Creates a new stage producing a new tensor with ``axis`` as an
        explicit dimension; the original tensor's body is rewritten as a
        reduction over the factored tensor.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be factored.
        axis : IterVar
            The reduction axis in the schedule to be factored.
        factor_axis : int
            The position where the new axis is placed.

        Returns
        -------
        tfactor : Tensor or Array of Tensor
            The created factored tensor(s); a single tensor is unwrapped.
        """
        factored = _ffi_api.ScheduleRFactor(self, tensor, axis, factor_axis)
        if len(factored) == 1:
            return factored[0]
        return factored
195
196
@tvm._ffi.register_object
class Stage(Object):
    """A Stage represents the schedule for one operation."""

    def split(self, parent, factor=None, nparts=None):
        """Split the stage either by factor (providing the outer scope) or by nparts.

        Exactly one of ``factor`` and ``nparts`` must be given.

        Parameters
        ----------
        parent : IterVar
             The parent iter var.

        factor : Expr, optional
             The splitting factor (extent of the inner loop).

        nparts : Expr, optional
             The number of outer parts.

        Returns
        -------
        outer : IterVar
            The outer variable of iteration.

        inner : IterVar
            The inner variable of iteration.

        Raises
        ------
        ValueError
            If both or neither of ``factor`` and ``nparts`` are provided.
        """
        if nparts is not None:
            if factor is not None:
                raise ValueError("Do not need to provide both outer and nparts")
            outer, inner = _ffi_api.StageSplitByNParts(self, parent, nparts)
        else:
            if factor is None:
                raise ValueError("Either nparts or factor need to be provided")
            outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor)
        return outer, inner

    def fuse(self, *args):
        """Fuse multiple consecutive iteration variables into a single iteration variable.

        fused = fuse(...fuse(fuse(args[0], args[1]), args[2]),..., args[-1])
        The order is from outer to inner.

        Parameters
        ----------
        args : list of IterVars
            Consecutive itervars, listed from outer to inner.

        Returns
        -------
        fused : IterVar
            The fused variable of iteration.
        """
        fused = _ffi_api.StageFuse(self, args)
        return fused

    def set_scope(self, scope):
        """Set the thread scope of this stage.

        Parameters
        ----------
        scope : str
            The thread scope of this stage.
        """
        return _ffi_api.StageSetScope(self, scope)

    def bind(self, ivar, thread_ivar):
        """Bind ivar to the thread index thread_ivar.

        Parameters
        ----------
        ivar : IterVar
            The iteration to be bound to a thread.

        thread_ivar : IterVar
            The thread to bind to.
        """
        _ffi_api.StageBind(self, ivar, thread_ivar)

    def env_threads(self, threads):
        """Mark threads to be launched at the outer scope of the composed op.

        Parameters
        ----------
        threads : list of threads
            The threads to be launched. A single IterVar is also accepted.
        """
        if isinstance(threads, IterVar):
            threads = [threads]
        _ffi_api.StageEnvThreads(self, threads)

    def set_store_predicate(self, predicate):
        """Set the predicate under which store to the array can be performed.

        Use this when there are duplicated threads doing the same store and we only
        need one of them to do the store.

        Parameters
        ----------
        predicate : Expr
            The guard condition for the store.
        """
        _ffi_api.StageSetStorePredicate(self, predicate)

    def compute_at(self, parent, scope):
        """Attach the stage at parent's scope.

        Parameters
        ----------
        parent : Stage
            The parent stage.

        scope : IterVar
            The loop scope to be attached to.
        """
        _ffi_api.StageComputeAt(self, parent, scope)

    def compute_inline(self):
        """Mark this stage as inline.

        The stage's computation is inlined into its consumers instead of
        being materialized.
        """
        _ffi_api.StageComputeInline(self)

    def compute_root(self):
        """Attach this stage at the root, detaching it from any parent scope."""
        _ffi_api.StageComputeRoot(self)

    def reorder(self, *args):
        """Reorder the iteration variables in the specified order.

        Parameters
        ----------
        args : list of IterVar
            The desired iteration order, from outer to inner.
        """
        _ffi_api.StageReorder(self, args)

    def tile(self, x_parent, y_parent, x_factor, y_factor):
        """Perform tiling on two dimensions.

        The final loop order from outermost to innermost is
        [x_outer, y_outer, x_inner, y_inner].

        Parameters
        ----------
        x_parent : IterVar
            The original x dimension
        y_parent : IterVar
            The original y dimension
        x_factor : Expr
            The stride factor on x axis
        y_factor : Expr
            The stride factor on y axis

        Returns
        -------
        x_outer : IterVar
            Outer axis of x dimension
        y_outer : IterVar
            Outer axis of y dimension
        x_inner : IterVar
            Inner axis of x dimension
        y_inner : IterVar
            Inner axis of y dimension
        """
        x_outer, y_outer, x_inner, y_inner = _ffi_api.StageTile(
            self, x_parent, y_parent, x_factor, y_factor
        )
        return x_outer, y_outer, x_inner, y_inner

    def vectorize(self, var):
        """Vectorize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be vectorized.
        """
        _ffi_api.StageVectorize(self, var)

    def tensorize(self, var, tensor_intrin):
        """Tensorize the computation enclosed by var with tensor_intrin.

        Parameters
        ----------
        var : IterVar
            The iteration boundary of tensorization.

        tensor_intrin : TensorIntrin
            The tensor intrinsic used for computation.
        """
        _ffi_api.StageTensorize(self, var, tensor_intrin)

    def unroll(self, var):
        """Unroll the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be unrolled.
        """
        _ffi_api.StageUnroll(self, var)

    def parallel(self, var):
        """Parallelize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be parallelized.
        """
        _ffi_api.StageParallel(self, var)

    def pragma(self, var, pragma_type, pragma_value=None):
        """Annotate the iteration with a pragma.

        This will translate to a pragma_scope surrounding
        the corresponding loop generated.
        Useful to support experimental features and extensions.

        Parameters
        ----------
        var : IterVar
            The iteration to be annotated

        pragma_type : str
             The pragma string to be annotated

        pragma_value : Expr, optional
             The pragma value to pass along the pragma

        Note
        ----
        Most pragmas are advanced/experimental features
        and may be subject to change. List of supported pragmas:

        - **debug_skip_region**

          Force skip the region marked by the axis and turn it into no-op.
          This is useful for debug purposes.

        - **parallel_launch_point**

          Specify to launch parallel threads outside the
          specified iteration loop. By default the threads
          launch at the point of parallel construct.
          This pragma moves the launching point to even outer scope.
          The threads are launched once and reused across multiple
          parallel constructs as BSP style program.

        - **parallel_barrier_when_finish**

          Insert a synchronization barrier between working threads
          after the specified loop iteration finishes.

        - **parallel_stride_pattern**

          Hint parallel loop to execute in strided pattern.
          :code:`for (int i = task_id; i < end; i += num_task)`

        """
        # Plain Python strings must be converted to TVM expression objects
        # before crossing the FFI boundary.
        if isinstance(pragma_value, string_types):
            pragma_value = convert(pragma_value)
        _ffi_api.StagePragma(self, var, pragma_type, pragma_value)

    def prefetch(self, tensor, var, offset):
        """Prefetch the specified tensor.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be prefetched
        var : IterVar
            The loop point at which the prefetching is applied
        offset : Expr
            The number of iterations to be prefetched before actual execution
        """
        _ffi_api.StagePrefetch(self, tensor, var, offset)

    def storage_align(self, axis, factor, offset):
        """Set an alignment requirement for a specific axis.

        This ensures that stride[axis] == k * factor + offset for some k.
        This is useful to set a memory layout with a more friendly memory
        access pattern. For example, we can set alignment to be
        factor=2, offset=1 to avoid bank conflict for thread access on
        higher dimension in GPU shared memory.

        Parameters
        ----------
        axis : IterVar
            The axis dimension to be aligned.
        factor : int
            The factor in alignment specification.
        offset : int
            The offset in the alignment specification.
        """
        _ffi_api.StageStorageAlign(self, axis, factor, offset)

    def double_buffer(self):
        """Compute the current stage via double buffering.

        This can only be applied to an intermediate stage.
        This will double the storage cost of the current stage.
        Can be useful to hide load latency.
        """
        _ffi_api.StageDoubleBuffer(self)
514
@tvm._ffi.register_object
class SpecializedCondition(Object):
    """Specialized condition to enable op specialization."""

    def __init__(self, conditions):
        """Create a specialized condition.

        .. note::
            Conditions are represented in conjunctive joint form (CNF).
            Each condition should be a simple expression, e.g., n > 16,
            m % 8 == 0, etc., where n, m are tvm.Var that represents a
            dimension in the tensor shape.

        Parameters
        ----------
        conditions : List of tvm.Expr
            List of conditions in conjunctive joint form (CNF). A single
            condition is also accepted.
        """
        # Normalize a single condition into a list before constructing the node.
        cond_list = (
            conditions if isinstance(conditions, (list, _container.Array)) else [conditions]
        )
        self.__init_handle_by_constructor__(_ffi_api.CreateSpecializedCondition, cond_list)

    @staticmethod
    def current():
        """Return the specialized condition currently in scope."""
        return _ffi_api.GetCurrentSpecialization()

    def __enter__(self):
        """Enter the specialization scope guarded by this condition."""
        _ffi_api.EnterSpecializationScope(self)
        return self

    def __exit__(self, ptype, value, trace):
        """Leave the specialization scope."""
        _ffi_api.ExitSpecializationScope(self)
548
549
# Populate this module with the C++-registered functions under the
# "schedule" namespace (the _ffi_api entries used above).
tvm._ffi._init_api("schedule", __name__)