# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Distributed measurement infrastructure to measure the runtime costs of tensor programs.

These functions are responsible for building the TVM module, uploading it to
remote devices, recording the running time costs, and checking the correctness of the output.

We separate the measurement into two steps: build and run.
A builder builds the executable binary files and a runner runs the binary files to
get the measurement results. The flow of data structures is

  .                `ProgramBuilder`                 `ProgramRunner`
  `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`

We implement these in Python to utilize Python's multiprocessing and error handling.
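
A minimal usage sketch (`task` and `state` are assumed to be produced by the
auto-scheduler search; the variable names are illustrative):

.. code-block:: python

    inputs = [MeasureInput(task, state)]
    builder = LocalBuilder()
    runner = LocalRunner()
    build_results = builder.build(inputs)
    measure_results = runner.run(inputs, build_results)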
32"""
33
34import os
35import time
36import shutil
37import traceback
38import tempfile
39import multiprocessing
40
41import tvm._ffi
42from tvm.runtime import Object, module, ndarray
43from tvm.driver import build_module
44from tvm.ir import transform
45from tvm.rpc.tracker import Tracker
46from tvm.rpc.server import Server
47from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
48from tvm.contrib import tar, ndk
49
50from . import _ffi_api
51from .loop_state import StateObject
52from .utils import (
53    get_const_tuple,
54    NoDaemonPool,
55    call_func_with_timeout,
56    request_remote,
57    check_remote,
58)
59
60# The maximum length of error message
61MAX_ERROR_MSG_LEN = 512
62
63# We use fork and a global variable to copy arguments between processes.
64# This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
65GLOBAL_BUILD_ARGUMENTS = None
66GLOBAL_RUN_ARGUMENTS = None
67
68
69@tvm._ffi.register_object("auto_scheduler.MeasureCallback")
70class MeasureCallback(Object):
71    """ The base class of measurement callback functions. """
72
73
74@tvm._ffi.register_object("auto_scheduler.MeasureInput")
75class MeasureInput(Object):
76    """Store the input of a measurement.
77
78    Parameters
79    ----------
80    task : SearchTask
81        The SearchTask of this measurement.
82    state : Union[State, StateObject]
83        The State to be measured.
84    """
85
86    def __init__(self, task, state):
87        state = state if isinstance(state, StateObject) else state.state_object
88        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state)


@tvm._ffi.register_object("auto_scheduler.BuildResult")
class BuildResult(Object):
    """Store the result of a build.

    Parameters
    ----------
    filename : Optional[str]
        The filename of the built binary file.
    args : List[Tensor]
        The arguments.
    error_no : int
        The error code.
    error_msg : Optional[str]
        The error message if there is any error.
    time_cost : float
        The time cost of the build.
    """

    def __init__(self, filename, args, error_no, error_msg, time_cost):
        filename = filename if filename else ""
        error_msg = error_msg if error_msg else ""

        self.__init_handle_by_constructor__(
            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost
        )


@tvm._ffi.register_object("auto_scheduler.MeasureResult")
class MeasureResult(Object):
    """Store the results of a measurement.

    Parameters
    ----------
    costs : List[float]
        The time costs of execution.
    error_no : int
        The error code.
    error_msg : Optional[str]
        The error message if there is any error.
    all_cost : float
        The total time cost of build and run.
    timestamp : float
        The timestamp of this measurement.
    """

    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
        error_msg = error_msg if error_msg else ""

        self.__init_handle_by_constructor__(
            _ffi_api.MeasureResult, costs, error_no, error_msg, all_cost, timestamp
        )


@tvm._ffi.register_object("auto_scheduler.ProgramBuilder")
class ProgramBuilder(Object):
    """ The base class of ProgramBuilders. """

    def build(self, measure_inputs, verbose=1):
        """Build programs and return results.

        Parameters
        ----------
        measure_inputs : List[MeasureInput]
            A List of MeasureInput.
        verbose: int = 1
            Verbosity level. 0 for silent, 1 to output information during program building.

        Returns
        -------
        res : List[BuildResult]
        """
        return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose)


@tvm._ffi.register_object("auto_scheduler.ProgramRunner")
class ProgramRunner(Object):
    """ The base class of ProgramRunners. """

    def run(self, measure_inputs, build_results, verbose=1):
        """Run measurement and return results.

        Parameters
        ----------
        measure_inputs : List[MeasureInput]
            A List of MeasureInput.
        build_results : List[BuildResult]
            A List of BuildResult to be run.
        verbose: int = 1
            Verbosity level. 0 for silent, 1 to output information during program running.

        Returns
        -------
        res : List[MeasureResult]
        """
        return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, verbose)


@tvm._ffi.register_object("auto_scheduler.LocalBuilder")
class LocalBuilder(ProgramBuilder):
    """LocalBuilder uses local CPU cores to build programs in parallel.

    Parameters
    ----------
    timeout : int = 15
        The timeout limit (in seconds) for each build thread.
        This is used in a wrapper of the multiprocessing.Process.join().
    n_parallel : int = multiprocessing.cpu_count()
        Number of threads used to build in parallel.
    build_func : str = 'default'
        The name of the registered build function.
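
    Examples
    --------
    A minimal sketch; `inputs` is a list of MeasureInput produced elsewhere
    (the variable name is illustrative):

    .. code-block:: python

        builder = LocalBuilder(timeout=15, build_func="default")
        build_results = builder.build(inputs)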
201    """
202
203    def __init__(self, timeout=15, n_parallel=multiprocessing.cpu_count(), build_func="default"):
204        self.__init_handle_by_constructor__(_ffi_api.LocalBuilder, timeout, n_parallel, build_func)
205
206
207@tvm._ffi.register_object("auto_scheduler.LocalRunner")
208class LocalRunner(ProgramRunner):
209    """LocalRunner that uses local CPU/GPU to measures the time cost of programs.
210
211    Parameters
212    ----------
213    timeout : int = 10
214        The timeout limit (in second) for each run.
215        This is used in a wrapper of the multiprocessing.Process.join().
216    number : int = 3
217        The number of times to run the generated code for taking average.
218        We call these runs as one `repeat` of measurement.
219    repeat : int = 1
220        The number of times to repeat the measurement.
221        In total, the generated code will be run (1 + number x repeat) times,
222        where the first "1" is warm up and will be discarded.
223        The returned result contains `repeat` costs,
224        each of which is an average of `number` costs.
225    min_repeat_ms : int = 100
226        The minimum duration of one `repeat` in milliseconds.
227        By default, one `repeat` contains `number` runs. If this parameter is set,
228        the parameters `number` will be dynamically adjusted to meet the
229        minimum duration requirement of one `repeat`.
230        i.e., When the run time of one `repeat` falls below this time, the `number` parameter
231        will be automatically increased.
232    cooldown_interval : float = 0.0
233        The cool down interval between two measurements.
234    enable_cpu_cache_flush: bool = False
235        Whether to flush cache on CPU between repeated measurements.
236        Flushing cache can make the measured latency of one operator closer to
237        its actual latency during end-to-end inference.
238        To make this option effective, the argument `number` should also be set to 1.
239        This is only has effect on CPU task.
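
    Examples
    --------
    A minimal sketch; `inputs` and `build_results` come from a ProgramBuilder
    (the variable names are illustrative):

    .. code-block:: python

        runner = LocalRunner(timeout=10, repeat=1, min_repeat_ms=100)
        measure_results = runner.run(inputs, build_results)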
240    """
241
242    def __init__(
243        self,
244        timeout=10,
245        number=3,
246        repeat=1,
247        min_repeat_ms=100,
248        cooldown_interval=0.0,
249        enable_cpu_cache_flush=False,
250    ):
251        self.__init_handle_by_constructor__(
252            _ffi_api.LocalRunner,
253            timeout,
254            number,
255            repeat,
256            min_repeat_ms,
257            cooldown_interval,
258            enable_cpu_cache_flush,
259        )
260
261
262@tvm._ffi.register_object("auto_scheduler.RPCRunner")
263class RPCRunner(ProgramRunner):
264    """RPCRunner that uses RPC call to measures the time cost of programs on remote devices.
265    Or sometime we may need to use RPC even in local running to insulate the thread environment.
266    (e.g. running CUDA programs)
267
268    Parameters
269    ----------
270    key : str
271        The key of the device registered in the RPC tracker.
272    host : str
273        The host address of the RPC Tracker.
274    port : int
275        The port of RPC Tracker.
276    priority : int = 1
277        The priority of this run request, larger is more prior.
278    n_parallel : int = 1
279        The number of tasks run in parallel.
280    timeout : int = 10
281        The timeout limit (in second) for each run.
282        This is used in a wrapper of the multiprocessing.Process.join().
283    number : int = 3
284        The number of times to run the generated code for taking average.
285        We call these runs as one `repeat` of measurement.
286    repeat : int = 1
287        The number of times to repeat the measurement.
288        In total, the generated code will be run (1 + number x repeat) times,
289        where the first "1" is warm up and will be discarded.
290        The returned result contains `repeat` costs,
291        each of which is an average of `number` costs.
292    min_repeat_ms : int = 100
293        The minimum duration of one `repeat` in milliseconds.
294        By default, one `repeat` contains `number` runs. If this parameter is set,
295        the parameters `number` will be dynamically adjusted to meet the
296        minimum duration requirement of one `repeat`.
297        i.e., When the run time of one `repeat` falls below this time, the `number` parameter
298        will be automatically increased.
299    cooldown_interval : float = 0.0
300        The cool down interval between two measurements.
301    enable_cpu_cache_flush: bool = False
302        Whether to flush cache on CPU between repeated measurements.
303        Flushing cache can make the measured latency of one operator closer to
304        its actual latency during end-to-end inference.
305        To make this option effective, the argument `number` should also be set to 1.
306        This is only has effect on CPU task.
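
    Examples
    --------
    A sketch assuming an RPC tracker is running with devices registered under a key;
    the key, host, and port values below are illustrative:

    .. code-block:: python

        runner = RPCRunner(key="rasp4b", host="0.0.0.0", port=9190, n_parallel=4)
        measure_results = runner.run(inputs, build_results)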
307    """
308
309    def __init__(
310        self,
311        key,
312        host,
313        port,
314        priority=1,
315        n_parallel=1,
316        timeout=10,
317        number=3,
318        repeat=1,
319        min_repeat_ms=100,
320        cooldown_interval=0.0,
321        enable_cpu_cache_flush=False,
322    ):
323        self.__init_handle_by_constructor__(
324            _ffi_api.RPCRunner,
325            key,
326            host,
327            port,
328            priority,
329            n_parallel,
330            timeout,
331            number,
332            repeat,
333            min_repeat_ms,
334            cooldown_interval,
335            enable_cpu_cache_flush,
336        )
337
338        if check_remote(key, host, port, priority, timeout):
339            print("Get devices for measurement successfully!")
340        else:
341            raise RuntimeError(
342                "Cannot get remote devices from the tracker. "
343                "Please check the status of tracker by "
344                "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
345                "and make sure you have free devices on the queue status."
346            )
347
348
349class LocalRPCMeasureContext:
350    """A context wrapper for running RPCRunner locally.
351    This will launch a local RPC Tracker and local RPC Server.
352
353    Parameters
354    ----------
355    priority : int = 1
356        The priority of this run request, larger is more prior.
357    n_parallel : int = 1
358        The number of tasks run in parallel.
359    timeout : int = 10
360        The timeout limit (in second) for each run.
361        This is used in a wrapper of the multiprocessing.Process.join().
362    number : int = 3
363        The number of times to run the generated code for taking average.
364        We call these runs as one `repeat` of measurement.
365    repeat : int = 1
366        The number of times to repeat the measurement.
367        In total, the generated code will be run (1 + number x repeat) times,
368        where the first "1" is warm up and will be discarded.
369        The returned result contains `repeat` costs,
370        each of which is an average of `number` costs.
371    min_repeat_ms : int = 0
372        The minimum duration of one `repeat` in milliseconds.
373        By default, one `repeat` contains `number` runs. If this parameter is set,
374        the parameters `number` will be dynamically adjusted to meet the
375        minimum duration requirement of one `repeat`.
376        i.e., When the run time of one `repeat` falls below this time, the `number` parameter
377        will be automatically increased.
378    cooldown_interval : float = 0.0
379        The cool down interval between two measurements.
380    enable_cpu_cache_flush: bool = False
381        Whether to flush cache on CPU between repeated measurements.
382        Flushing cache can make the measured latency of one operator closer to
383        its actual latency during end-to-end inference.
384        To make this option effective, the argument `number` should also be set to 1.
385        This is only has effect on CPU task.
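
    Examples
    --------
    A sketch: measure locally through RPC to isolate the runtime environment
    (useful, e.g., for CUDA targets). `inputs` and `build_results` are assumed
    to be produced by a ProgramBuilder beforehand:

    .. code-block:: python

        measure_ctx = LocalRPCMeasureContext(min_repeat_ms=300)
        measure_results = measure_ctx.runner.run(inputs, build_results)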
386    """
387
388    def __init__(
389        self,
390        priority=1,
391        n_parallel=1,
392        timeout=10,
393        number=3,
394        repeat=1,
395        min_repeat_ms=0,
396        cooldown_interval=0.0,
397        enable_cpu_cache_flush=False,
398    ):
399        ctx = tvm.context("cuda", 0)
400        if ctx.exist:
401            cuda_arch = "sm_" + "".join(ctx.compute_version.split("."))
402            set_cuda_target_arch(cuda_arch)
403        host = "0.0.0.0"
404        self.tracker = Tracker(host, port=9000, port_end=10000, silent=True)
405        device_key = "$local$device$%d" % self.tracker.port
406        self.server = Server(
407            host,
408            port=self.tracker.port,
409            port_end=10000,
410            key=device_key,
411            use_popen=True,
412            silent=True,
413            tracker_addr=(self.tracker.host, self.tracker.port),
414        )
415        self.runner = RPCRunner(
416            device_key,
417            host,
418            self.tracker.port,
419            priority,
420            n_parallel,
421            timeout,
422            number,
423            repeat,
424            min_repeat_ms,
425            cooldown_interval,
426            enable_cpu_cache_flush,
427        )
428        # Wait for the processes to start
429        time.sleep(0.5)
430
431    def __del__(self):
432        # Close the tracker and server before exit
433        self.tracker.terminate()
434        self.server.terminate()
435
436
437class MeasureErrorNo(object):
438    """ Error type for MeasureResult. """
439
440    NO_ERROR = 0  # No error
441    INSTANTIATION_ERROR = 1  # Errors happen when apply transform steps from init state
442    COMPILE_HOST = 2  # Errors happen when compiling code on host (e.g., tvm.build)
443    COMPILE_DEVICE = 3  # Errors happen when compiling code on device
444    # (e.g. OpenCL JIT on the device)
445    RUNTIME_DEVICE = 4  # Errors happen when run program on device
446    WRONG_ANSWER = 5  # Answer is wrong when compared to a reference output
447    BUILD_TIMEOUT = 6  # Timeout during compilation
448    RUN_TIMEOUT = 7  # Timeout during run
449    UNKNOWN_ERROR = 8  # Unknown error
450
451
452def make_error_msg():
453    """ Get the error message from traceback. """
454    error_msg = str(traceback.format_exc())
455    if len(error_msg) > MAX_ERROR_MSG_LEN:
456        error_msg = (
457            error_msg[: MAX_ERROR_MSG_LEN // 2] + "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN // 2 :]
458        )
459    return error_msg


def local_build_worker(index):
    """
    Build function of LocalBuilder to be run in the Builder thread pool.

    Parameters
    ----------
    index : int
        The MeasureInput index to be processed by the current Builder thread.

    Returns
    -------
    res : BuildResult
        The build result of this Builder thread.
    """
    global GLOBAL_BUILD_ARGUMENTS

    # We use fork and a global variable to copy arguments between processes.
    # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
    if not GLOBAL_BUILD_ARGUMENTS:
        raise ValueError("GLOBAL_BUILD_ARGUMENTS not found")
    measure_inputs, build_func, timeout, verbose = GLOBAL_BUILD_ARGUMENTS
    assert isinstance(build_func, str)

    if build_func == "default":
        build_func = tar.tar
    elif build_func == "ndk":
        build_func = ndk.create_shared
    else:
        raise ValueError("Invalid build_func: " + build_func)

    def timed_func():
        tic = time.time()
        inp = measure_inputs[index]
        task = inp.task

        error_no = MeasureErrorNo.NO_ERROR
        error_msg = None
        args = []

        try:
            sch, args = task.compute_dag.apply_steps_from_state(inp.state, layout_rewrite=True)
        # pylint: disable=broad-except
        except Exception:
            error_no = MeasureErrorNo.INSTANTIATION_ERROR
            error_msg = make_error_msg()

        if error_no == 0:
            dirname = tempfile.mkdtemp()
            filename = os.path.join(dirname, "tmp_func." + build_func.output_format)

            try:
                # TODO(merrymercy): Port the unroll pass.
                with transform.PassContext():
                    func = build_module.build(
                        sch, args, target=task.target, target_host=task.target_host
                    )
                func.export_library(filename, build_func)
            # pylint: disable=broad-except
            except Exception:
                error_no = MeasureErrorNo.COMPILE_HOST
                error_msg = make_error_msg()
        else:
            filename = ""

        if verbose >= 1:
            if error_no == MeasureErrorNo.NO_ERROR:
                print(".", end="")
            else:
                print(".E", end="")  # Build error
        return filename, args, error_no, error_msg, time.time() - tic

    res = call_func_with_timeout(timeout, timed_func)
    if isinstance(res, TimeoutError):
        if verbose >= 1:
            print(".T", end="")  # Build timeout
        res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout

    return res


@tvm._ffi.register_func("auto_scheduler.local_builder.build")
def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbose=1):
    """
    Build function of LocalBuilder to build the MeasureInputs into runnable modules.

    Parameters
    ----------
    inputs : List[MeasureInput]
        The MeasureInputs to be built.
    timeout : int
        The timeout limit (in seconds) for each build thread.
        This is used in a wrapper of the multiprocessing.Process.join().
    n_parallel : int
        Number of threads used to build in parallel.
    build_func : str = 'default'
        The name of the build function to process the built module.
    verbose: int = 1
        Verbosity level. 0 for silent, 1 to output information during program building.

    Returns
    -------
    res : List[BuildResult]
        The build results of these MeasureInputs.
    """
    # We use fork and a global variable to copy arguments between processes.
    # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
    global GLOBAL_BUILD_ARGUMENTS

    GLOBAL_BUILD_ARGUMENTS = (inputs, build_func, timeout, verbose)

    pool = NoDaemonPool(n_parallel)
    tuple_res = pool.map(local_build_worker, range(len(inputs)))
    pool.terminate()
    pool.join()
    del pool

    results = []
    for res in tuple_res:
        results.append(BuildResult(*res))

    return results


@tvm._ffi.register_func("auto_scheduler.local_runner.run")
def local_run(
    inputs,
    build_results,
    timeout=10,
    number=3,
    repeat=1,
    min_repeat_ms=0,
    cooldown_interval=0,
    enable_cpu_cache_flush=False,
    verbose=1,
):
    """
    Run function of LocalRunner to test the performance of the input BuildResults.

    Parameters
    ----------
    inputs : List[MeasureInput]
        The MeasureInputs to be measured.
    build_results : List[BuildResult]
        The BuildResults to be measured.
    timeout : int = 10
        The timeout limit (in seconds) for each run.
        This is used in a wrapper of the multiprocessing.Process.join().
    number : int = 3
        The number of times to run the generated code for taking the average.
        We call these runs one `repeat` of measurement.
    repeat : int = 1
        The number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first "1" is warm up and will be discarded.
        The returned result contains `repeat` costs,
        each of which is an average of `number` costs.
    min_repeat_ms : int = 0
        The minimum duration of one `repeat` in milliseconds.
        By default, one `repeat` contains `number` runs. If this parameter is set,
        the parameter `number` will be dynamically adjusted to meet the
        minimum duration requirement of one `repeat`.
        i.e., when the run time of one `repeat` falls below this time, the `number` parameter
        will be automatically increased.
    cooldown_interval : float = 0.0
        The cooldown interval between two measurements.
    enable_cpu_cache_flush: bool = False
        Whether to flush cache on CPU between repeated measurements.
        Flushing cache can make the measured latency of one operator closer to
        its actual latency during end-to-end inference.
        To make this option effective, the argument `number` should also be set to 1.
        This only has effect on CPU tasks.
    verbose: int = 1
        Verbosity level. 0 for silent, 1 to output information during program measuring.

    Returns
    -------
    res : List[MeasureResult]
        The measure results of these MeasureInputs.
    """
    max_float = 1e10  # We use 1e10 instead of sys.float_info.max for better readability in log

    def timed_func(inp, build_res):
        tic = time.time()
        error_no = 0
        error_msg = None
        try:
            func = module.load_module(build_res.filename)
            ctx = ndarray.context(str(inp.task.target), 0)
            # Limitation:
            # We cannot get PackedFunc directly in the remote mode as it is wrapped
            # under the std::function. We could lift the restriction later once we fold
            # the PackedFunc as an object. Currently, we pass function name to work
            # around it.
            f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
            time_f = func.time_evaluator(
                func.entry_name,
                ctx,
                number=number,
                repeat=repeat,
                min_repeat_ms=min_repeat_ms,
                f_preproc=f_prepare,
            )
        # pylint: disable=broad-except
        except Exception:
            costs = (max_float,)
            error_no = MeasureErrorNo.COMPILE_DEVICE
            error_msg = make_error_msg()

        if error_no == 0:
            try:
                args = [
                    ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args
                ]
                random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
                assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"
                for arg in args:
                    random_fill(arg)
                ctx.sync()
                costs = time_f(*args).results
            # pylint: disable=broad-except
            except Exception:
                costs = (max_float,)
                error_no = MeasureErrorNo.RUNTIME_DEVICE
                error_msg = make_error_msg()

        shutil.rmtree(os.path.dirname(build_res.filename))
        toc = time.time()
        time.sleep(cooldown_interval)

        if verbose >= 1:
            if error_no == MeasureErrorNo.NO_ERROR:
                print("*", end="")
            else:
                print("*E", end="")  # Run error
        return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc

    measure_results = []
    assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
    for inp, build_res in zip(inputs, build_results):
        if build_res.error_no != 0:
            res = (
                (max_float,),
                build_res.error_no,
                build_res.error_msg,
                build_res.time_cost,
                time.time(),
            )
        else:
            res = call_func_with_timeout(timeout, timed_func, args=(inp, build_res))
            if isinstance(res, TimeoutError):
                if verbose >= 1:
                    print("*T", end="")  # Run timeout
                res = (
                    (max_float,),
                    MeasureErrorNo.RUN_TIMEOUT,
                    None,
                    build_res.time_cost + timeout,
                    time.time(),
                )
        measure_results.append(MeasureResult(*res))

    if verbose >= 1:
        print("")

    return measure_results


def rpc_run_worker(index):
    """Function to be run in the RPCRunner thread pool.

    Parameters
    ----------
    index : int
        The MeasureInput and BuildResult index to be processed by the current Runner thread.

    Returns
    -------
    res : MeasureResult
        The measure result of this Runner thread.
    """
    global GLOBAL_RUN_ARGUMENTS
    (
        inputs,
        build_results,
        key,
        host,
        port,
        priority,
        timeout,
        number,
        repeat,
        min_repeat_ms,
        cooldown_interval,
        enable_cpu_cache_flush,
        verbose,
    ) = GLOBAL_RUN_ARGUMENTS

    max_float = 1e10  # We use 1e10 instead of sys.float_info.max for better readability in log
    inp = inputs[index]
    build_res = build_results[index]

    if build_res.error_no != MeasureErrorNo.NO_ERROR:
        return (
            (max_float,),
            build_res.error_no,
            build_res.error_msg,
            build_res.time_cost,
            time.time(),
        )

    def timed_func():
        tic = time.time()
        error_no = 0
        error_msg = None
        try:
            # upload the built module
            remote = request_remote(key, host, port, priority, timeout)
            remote.upload(build_res.filename)
            func = remote.load_module(os.path.split(build_res.filename)[1])
            ctx = remote.context(str(inp.task.target), 0)
            # Limitation:
            # We cannot get PackedFunc directly in the remote mode as it is wrapped
            # under the std::function. We could lift the restriction later once we fold
            # the PackedFunc as an object. Currently, we pass function name to work
            # around it.
            f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
            time_f = func.time_evaluator(
                func.entry_name,
                ctx,
                number=number,
                repeat=repeat,
                min_repeat_ms=min_repeat_ms,
                f_preproc=f_prepare,
            )
        # pylint: disable=broad-except
        except Exception:
            costs = (max_float,)
            error_no = MeasureErrorNo.COMPILE_DEVICE
            error_msg = make_error_msg()

        if error_no == 0:
            try:
                args = [
                    ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args
                ]
                try:
                    random_fill = remote.get_function("tvm.contrib.random.random_fill")
                except AttributeError:
                    raise AttributeError(
                        "Please make sure USE_RANDOM is ON in the config.cmake "
                        "on the remote devices"
                    )
                for arg in args:
                    random_fill(arg)
                ctx.sync()

                costs = time_f(*args).results
                # clean up remote files
                remote.remove(build_res.filename)
                remote.remove(os.path.splitext(build_res.filename)[0] + ".so")
                remote.remove("")
            # pylint: disable=broad-except
            except Exception:
                costs = (max_float,)
                error_no = MeasureErrorNo.RUNTIME_DEVICE
                error_msg = make_error_msg()

        shutil.rmtree(os.path.dirname(build_res.filename))
        toc = time.time()

        time.sleep(cooldown_interval)
        if verbose >= 1:
            if error_no == MeasureErrorNo.NO_ERROR:
                print("*", end="")
            else:
                print("*E", end="")  # Run error

        return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc

    res = call_func_with_timeout(timeout, timed_func)

    if isinstance(res, TimeoutError):
        if verbose >= 1:
            print("*T", end="")  # Run timeout
        res = (
            (max_float,),
            MeasureErrorNo.RUN_TIMEOUT,
            None,
            build_res.time_cost + timeout,
            time.time(),
        )
    return res


@tvm._ffi.register_func("auto_scheduler.rpc_runner.run")
def rpc_runner_run(
    inputs,
    build_results,
    key,
    host,
    port,
    priority=1,
    n_parallel=1,
    timeout=10,
    number=3,
    repeat=1,
    min_repeat_ms=0,
    cooldown_interval=0.0,
    enable_cpu_cache_flush=False,
    verbose=1,
):
    """Run function of RPCRunner to test the performance of the input BuildResults.

    Parameters
    ----------
    inputs : List[MeasureInput]
        The MeasureInputs to be measured.
    build_results : List[BuildResult]
        The BuildResults to be measured.
    key : str
        The key of the device registered in the RPC tracker.
    host : str
        The host address of the RPC Tracker.
    port : int
        The port of the RPC Tracker.
    priority : int = 1
        The priority of this run request; a larger value means higher priority.
    n_parallel : int = 1
        The number of tasks run in parallel.
    timeout : int = 10
        The timeout limit (in seconds) for each run.
        This is used in a wrapper of the multiprocessing.Process.join().
    number : int = 3
        The number of times to run the generated code for taking the average.
        We call these runs one `repeat` of measurement.
    repeat : int = 1
        The number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first "1" is warm up and will be discarded.
        The returned result contains `repeat` costs,
        each of which is an average of `number` costs.
    min_repeat_ms : int = 0
        The minimum duration of one `repeat` in milliseconds.
        By default, one `repeat` contains `number` runs. If this parameter is set,
        the parameter `number` will be dynamically adjusted to meet the
        minimum duration requirement of one `repeat`.
        i.e., when the run time of one `repeat` falls below this time, the `number` parameter
        will be automatically increased.
    cooldown_interval : float = 0.0
        The cooldown interval between two measurements.
    enable_cpu_cache_flush: bool = False
        Whether to flush cache on CPU between repeated measurements.
        Flushing cache can make the measured latency of one operator closer to
        its actual latency during end-to-end inference.
        To make this option effective, the argument `number` should also be set to 1.
        This only has effect on CPU tasks.
    verbose: int = 1
        Verbosity level. 0 for silent, 1 to output information during program measuring.

    Returns
    -------
    res : List[MeasureResult]
        The measure results of these MeasureInputs.
    """
    global GLOBAL_RUN_ARGUMENTS
    GLOBAL_RUN_ARGUMENTS = (
        inputs,
        build_results,
        key,
        host,
        port,
        priority,
        timeout,
        number,
        repeat,
        min_repeat_ms,
        cooldown_interval,
        enable_cpu_cache_flush,
        verbose,
    )

    assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
    pool = NoDaemonPool(n_parallel)
    tuple_res = pool.map(rpc_run_worker, range(len(build_results)))
    pool.terminate()
    pool.join()
    del pool

    results = []
    for res in tuple_res:
        results.append(MeasureResult(*res))

    if verbose >= 1:
        print("")

    return results