1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Defines a top-level glue class that operates the Transport and Flasher classes."""
19
20import logging
21import time
22
23from .._ffi import get_global_func
24from ..contrib import graph_runtime
25from ..rpc import RPCSession
26from .transport import TransportLogger
27
28try:
29    from .base import _rpc_connect
30except ImportError:
31    raise ImportError("micro tvm is not enabled. Set USE_MICRO to ON in config.cmake")
32
33
class Session:
    """MicroTVM Device Session

    Context manager that drives a single RPC session with the on-device RPC server.
    On entry it optionally flashes a binary, opens the transport, and establishes the
    RPC session. On exit it closes the transport.

    Parameters
    ----------
    binary : MicroBinary, optional
        Binary flashed to the device on entry. Requires `flasher`.
    flasher : Flasher, optional
        Used to flash `binary` on entry. The value returned by ``flasher.flash`` becomes
        the transport context manager.
    transport_context_manager : ContextManager[transport.Transport], optional
        Used directly to create the transport when no `flasher` is given.
    session_name : str
        Name of the session, used for debugging.

    Example
    --------
    .. code-block:: python

      with tvm.micro.Session(binary=micro_binary, flasher=flasher) as sess:
          system_lib = sess.get_system_lib()
          graph_mod = create_local_graph_runtime(graph_json_str, system_lib, sess.context)
    """

    def __init__(
        self, binary=None, flasher=None, transport_context_manager=None, session_name="micro-rpc"
    ):
        """Configure a new session.

        Nothing is connected here; the device is only touched in ``__enter__``.

        Parameters
        ----------
        binary : MicroBinary
            If given, `flasher` must also be given. During session initialization, this binary will
            be flashed to the device before the transport is created.
        flasher : Flasher
            If given, `binary` must also be given. Used to flash `binary` during session
            initialization.
        transport_context_manager : ContextManager[transport.Transport]
            If given, `flasher` and `binary` should not be given. On entry, this context manager
            should establish a transport between this TVM instance and the device.
        session_name : str
            Name of the session, used for debugging.
        """
        self.binary = binary
        self.flasher = flasher
        self.transport_context_manager = transport_context_manager
        self.session_name = session_name

        # Populated by __enter__.
        self._rpc = None
        self._graph_runtime = None

    def get_system_lib(self):
        """Return the module produced by the remote ``runtime.SystemLib`` function.

        This looks up ``runtime.SystemLib`` over the RPC session, so it is only usable
        after ``__enter__`` has established that session.

        Returns
        -------
        tvm.runtime.Module :
            The module returned by the remote ``runtime.SystemLib`` call.
        """
        return self._rpc.get_function("runtime.SystemLib")()

    def __enter__(self):
        """Initialize this session and establish an RPC session with the on-device RPC server.

        If the RPC session cannot be established, the transport opened here is closed
        before the exception propagates.

        Returns
        -------
        Session :
            Returns self.
        """
        if self.flasher is not None:
            self.transport_context_manager = self.flasher.flash(self.binary)
            # NOTE: fixed delay after flashing, presumably to let the device restart
            # before the transport is opened.
            time.sleep(3.0)

        self.transport = TransportLogger(
            self.session_name, self.transport_context_manager, level=logging.INFO
        ).__enter__()
        try:
            self._rpc = RPCSession(
                _rpc_connect(self.session_name, self.transport.write, self.transport.read)
            )
            self.context = self._rpc.cpu(0)
        except BaseException as err:
            # A `with` statement never calls __exit__ when __enter__ raises. Release the
            # transport here so whatever it holds is not leaked, then re-raise.
            self._rpc = None
            self.transport.__exit__(type(err), err, err.__traceback__)
            raise
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Tear down this session by exiting the transport context opened in ``__enter__``.

        The exception info is forwarded to the transport's ``__exit__``. Its return value
        is not propagated, so exceptions raised inside the ``with`` body are never
        suppressed.
        """
        self.transport.__exit__(exc_type, exc_value, exc_traceback)
107
108
def create_local_graph_runtime(graph_json_str, mod, ctx):
    """Build a host-side graph runtime whose execution happens on the given remote context.

    Parameters
    ----------
    graph_json_str : str
        JSON string holding the graph representation.

    mod : tvm.runtime.Module
        Remote module providing the functions referenced by `graph_json_str`.

    ctx : tvm.Context
        Remote CPU execution context; its ``device_type`` and ``device_id`` are passed to
        ``tvm.graph_runtime.create``.

    Returns
    -------
    tvm.contrib.graph_runtime.GraphModule :
        Local wrapper around the runtime created by ``tvm.graph_runtime.create``.
    """
    device_type, device_id = ctx.device_type, ctx.device_id
    create_runtime = get_global_func("tvm.graph_runtime.create")
    runtime_handle = create_runtime(graph_json_str, mod, device_type, device_id)
    return graph_runtime.GraphModule(runtime_handle)
131