# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Defines a top-level glue class that operates the Transport and Flasher classes."""

import logging
import time

from .._ffi import get_global_func
from ..contrib import graph_runtime
from ..rpc import RPCSession
from .transport import TransportLogger

try:
    from .base import _rpc_connect
except ImportError as err:
    raise ImportError("micro tvm is not enabled. Set USE_MICRO to ON in config.cmake") from err


class Session:
    """MicroTVM Device Session

    Entering a Session (as a context manager) optionally flashes a binary onto the device,
    opens a transport to it, and establishes an RPC session with the on-device RPC server.
    Exiting the Session closes the transport.

    See ``__init__`` for the constructor parameters.

    Example
    --------
    .. code-block:: python

      binary = ...   # a MicroBinary to run on the device
      flasher = ...  # a Flasher able to program the device
      with tvm.micro.Session(binary=binary, flasher=flasher) as sess:
          system_lib = sess.get_system_lib()
    """

    def __init__(
        self, binary=None, flasher=None, transport_context_manager=None, session_name="micro-rpc"
    ):
        """Configure a new session.

        No device communication happens here; flashing and connecting are deferred to
        ``__enter__``.

        Parameters
        ----------
        binary : MicroBinary
            If given, `flasher` must also be given. During session initialization, this binary
            will be flashed to the device before the transport is created.
        flasher : Flasher
            If given, `binary` must also be given. Used to flash `binary` during session
            initialization; the value returned by ``flasher.flash(binary)`` becomes the
            transport context manager for the session.
        transport_context_manager : ContextManager[transport.Transport]
            If given, `flasher` and `binary` should not be given (when `flasher` is given, this
            value is replaced by the result of flashing). On entry, this context manager
            should establish a transport between this TVM instance and the device.
        session_name : str
            Name of the session, used for debugging.
        """
        self.binary = binary
        self.flasher = flasher
        self.transport_context_manager = transport_context_manager
        self.session_name = session_name

        self._rpc = None
        self._graph_runtime = None
        # Populated by __enter__; declared here so the attributes always exist.
        self.transport = None
        self.context = None

    def get_system_lib(self):
        """Return the device's system library.

        Invokes the remote ``runtime.SystemLib`` function over the RPC session, so the session
        must have been entered (``__enter__``) first; before that ``self._rpc`` is None.

        Returns
        -------
        object :
            The result of the remote ``runtime.SystemLib`` call (presumably a remote
            ``tvm.runtime.Module`` -- confirm against the RPC layer).
        """
        return self._rpc.get_function("runtime.SystemLib")()

    def __enter__(self):
        """Initialize this session and establish an RPC session with the on-device RPC server.

        If a `flasher` was given, `binary` is flashed first and the transport context manager
        returned by the flasher is used. The transport is wrapped in a TransportLogger (INFO
        level) and entered. If establishing the RPC session then fails, the transport is exited
        again before the error propagates, so it is not leaked.

        Returns
        -------
        Session :
            Returns self.
        """
        if self.flasher is not None:
            self.transport_context_manager = self.flasher.flash(self.binary)
            # Fixed delay after flashing; presumably lets the device reset/boot before the
            # transport is opened. TODO confirm whether 3 s is required on all targets.
            time.sleep(3.0)

        self.transport = TransportLogger(
            self.session_name, self.transport_context_manager, level=logging.INFO
        ).__enter__()
        try:
            self._rpc = RPCSession(
                _rpc_connect(self.session_name, self.transport.write, self.transport.read)
            )
            self.context = self._rpc.cpu(0)
        except BaseException as exc:
            # A `with` statement never calls __exit__ when __enter__ raises, so release the
            # already-entered transport here to avoid leaking it, then propagate the error.
            self.transport.__exit__(type(exc), exc, exc.__traceback__)
            raise
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Tear down this session and associated RPC session resources.

        Exits the transport entered in ``__enter__``, forwarding the exception information
        (if any) raised in the ``with`` body. Exceptions are not suppressed.
        """
        self.transport.__exit__(exc_type, exc_value, exc_traceback)


def create_local_graph_runtime(graph_json_str, mod, ctx):
    """Create a local graph runtime driving execution on the remote CPU context given.

    The runtime is created through this process's global ``tvm.graph_runtime.create``
    function, while `mod` and `ctx` refer to the remote device.

    Parameters
    ----------
    graph_json_str : str
        A string containing the graph representation.

    mod : tvm.runtime.Module
        The remote module containing functions in graph_json_str.

    ctx : tvm.Context
        The remote CPU execution context.

    Returns
    -------
    tvm.contrib.GraphRuntime :
        A local graph runtime instance that executes on the remote device.
    """
    fcreate = get_global_func("tvm.graph_runtime.create")
    runtime_mod = fcreate(graph_json_str, mod, ctx.device_type, ctx.device_id)
    return graph_runtime.GraphModule(runtime_mod)