1#!/usr/local/bin/python3.8
2#
3# Copyright 2017 Ettus Research, a National Instruments Company
4#
5# SPDX-License-Identifier: GPL-3.0-or-later
6#
7"""
8RPC shell to debug USRP MPM capable devices
9"""
10
11from __future__ import print_function
12import re
13import cmd
14import time
15import argparse
16import multiprocessing
17from importlib import import_module
18
# Fallback RPC port, used whenever usrp_mpm cannot supply one.
DEFAULT_MPM_RPC_PORT = 49601

try:
    from usrp_mpm.mpmtypes import MPM_RPC_PORT
except ImportError:
    # usrp_mpm is not importable here; resolved to the fallback below.
    MPM_RPC_PORT = None

if MPM_RPC_PORT is None:
    MPM_RPC_PORT = DEFAULT_MPM_RPC_PORT
elif MPM_RPC_PORT != DEFAULT_MPM_RPC_PORT:
    # The value baked into this script has drifted from the one in usrp_mpm.
    print("Warning: Default encoded MPM RPC port does not match that in MPM.")
29
30
def parse_args():
    """
    Build the command line parser of the shell and evaluate the arguments.

    Returns the argparse namespace carrying the attributes ``host``,
    ``port``, ``claim``, ``hijack`` and ``script``.
    """
    cli = argparse.ArgumentParser(description="MPM Shell")
    # One (flags, settings) entry per argument, in registration order.
    argument_table = (
        (('host',), {
            'help': "Specify host to connect to.",
            'default': None,
        }),
        (('-p', '--port'), {
            'type': int,
            'help': "Specify port to connect to.",
            'default': MPM_RPC_PORT,
        }),
        (('-c', '--claim'), {
            'action': 'store_true',
            'help': "Claim device after connecting.",
        }),
        (('-j', '--hijack'), {
            'type': str,
            'help': "Hijack running session (excludes --claim).",
        }),
        (('-s', '--script'), {
            'type': str,
            'help': "Run shell in scripting mode. Specified script contains "
                    "MPM shell commands, one per line.",
        }),
    )
    for flags, settings in argument_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
61
62
def split_args(args, *default_args):
    """
    Split a whitespace-separated argument string and pad it with defaults.

    args -- string holding the arguments as typed by the user
    default_args -- positional fallback values; the N-th default is used
                    when ``args`` contains fewer than N+1 tokens

    Returns a list with the tokens found in ``args`` (always strings),
    followed by the defaults for the positions the user did not fill in.
    Tokens beyond the number of defaults are returned as well.
    """
    tokens = args.split()
    # Only the defaults for positions past the supplied tokens are appended.
    return tokens + list(default_args[len(tokens):])
70
71
class MPMClaimer(object):
    """
    Holds a claim.

    On construction a separate process ("Claimer Loop") is started. It
    opens its own RPC connection and services the commands 'claim',
    'unclaim' and 'exit' that are posted on a command queue. The resulting
    token (or None) is reported back on a token queue. While the loop holds
    a token and no command is pending, it calls 'reclaim' with that token
    once per loop iteration.
    """
    def __init__(self, host, port):
        """
        Create the queues and start the claim loop process.

        host, port -- passed on to the RPCClient created in claim_loop()
        """
        self.token = None
        self.hijacked = False
        # Commands go to the claim loop, tokens come back from it
        self._cmd_q = multiprocessing.Queue()
        self._token_q = multiprocessing.Queue()
        self._claim_loop = multiprocessing.Process(
            target=self.claim_loop,
            name="Claimer Loop",
            args=(host, port, self._cmd_q, self._token_q)
        )
        self._claim_loop.start()

    def claim_loop(self, host, port, cmd_q, token_q):
        """
        Run a claim loop

        This is the target of the claimer process. Every iteration handles
        at most one command taken from cmd_q, then sleeps for one second.

        host, port -- address of the RPC server
        cmd_q -- queue delivering 'claim', 'unclaim' or 'exit'
        token_q -- queue on which the token (or None) is reported after
                   each handled command

        An RPCError ends the loop after printing the error. Nothing is put
        on token_q in that case.
        """
        from mprpc import RPCClient
        from mprpc.exceptions import RPCError
        command = None
        token = None
        exit_loop = False
        # NOTE(review): the client is created outside the try block, so a
        # failure to connect is not reported by the handler below.
        client = RPCClient(host, port, pack_params={'use_bin_type': True})
        try:
            while not exit_loop:
                if token and not command:
                    # Idle and holding a claim: refresh it
                    client.call('reclaim', token)
                elif command == 'claim':
                    if not token:
                        token = client.call('claim', 'MPM Shell')
                    else:
                        print("Already have claim")
                    token_q.put(token)
                elif command == 'unclaim':
                    if token:
                        client.call('unclaim', token)
                    token = None
                    token_q.put(None)
                elif command == 'exit':
                    if token:
                        client.call('unclaim', token)
                    token = None
                    token_q.put(None)
                    exit_loop = True
                time.sleep(1)
                command = None
                if not cmd_q.empty():
                    command = cmd_q.get(False)
        except RPCError as ex:
            print("Unexpected RPC error in claimer loop!")
            print(str(ex))

    def exit(self):
        """
        Unclaim device and exit claim loop.

        Blocks until the claim loop process has terminated.
        """
        self.unclaim()
        self._cmd_q.put('exit')
        self._claim_loop.join()

    def unclaim(self):
        """
        Unclaim device.

        A hijacked token is only forgotten locally; the claim loop is not
        asked to unclaim it.
        """
        if not self.hijacked:
            self._cmd_q.put('unclaim')
        else:
            self.hijacked = False
        self.token = None

    def claim(self, timeout=5.0):
        """
        Claim device.

        Asks the claim loop to claim the device and waits for the token.

        timeout -- maximum time in seconds to wait for a token

        If no token arrives in time (e.g. because the claim loop is no
        longer running), a message is printed and the token is left at
        None instead of raising queue.Empty.
        """
        from queue import Empty
        self._cmd_q.put('claim')
        deadline = time.time() + timeout
        token = None
        # A None entry is the reply to an earlier 'unclaim'; skip it and
        # keep waiting for the reply to this claim request.
        while token is None:
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            try:
                token = self._token_q.get(True, remaining)
            except Empty:
                break
        if token is None:
            print("Claim request timed out, no token received.")
        self.token = token

    def get_token(self):
        """
        Get current token (if any)

        All pending entries of the token queue are consumed first, so the
        most recently reported state is returned.
        """
        from queue import Empty
        while True:
            try:
                self.token = self._token_q.get(False)
            except Empty:
                break
        return self.token

    def hijack(self, token):
        """
        Take over existing session by providing session token.

        Does nothing but print a message if a token is already held.
        """
        if self.token:
            print("Already have token")
            return
        self.token = token
        self.hijacked = True
170
class MPMShell(cmd.Cmd):
    """
    RPC Shell class. See cmd module.

    The methods reported by the RPC server are attached to the instance as
    do_<name> commands when connecting and removed again when
    disconnecting.
    """
    def __init__(self, host, port, claim, hijack, script):
        """
        Set up the shell and optionally connect right away.

        host, port -- RPC server address; no connection is attempted if
                      host is None
        claim -- if true, claim the device after connecting
        hijack -- token of a running session to take over (only used when
                  claim is false)
        script -- path of a file holding shell commands, or None
        """
        cmd.Cmd.__init__(self)
        self.prompt = "> "
        self.client = None
        self.remote_methods = []
        self._host = host
        self._port = port
        self._device_info = None
        self._claimer = MPMClaimer(self._host, self._port)
        # Claiming or hijacking only makes sense on an established
        # connection, so honour the result of connect().
        if host is not None and self.connect(host, port):
            if claim:
                self.claim()
            elif hijack:
                self.hijack(hijack)
        self.update_prompt()
        self._script = script
        if self._script:
            self.parse_script()

    def _add_command(self, command, docs, requires_token=False):
        """
        Add a command to the current session

        command -- name of the remote method
        docs -- text used as docstring of the new command
        requires_token -- whether the claim token is passed as first
                          argument of the remote call

        Names for which a do_<command> attribute already exists are left
        untouched.
        """
        cmd_name = 'do_' + command
        if not hasattr(self, cmd_name):
            new_command = lambda args: self.rpc_template(
                str(command), requires_token, args
            )
            new_command.__doc__ = docs
            setattr(self, cmd_name, new_command)
            self.remote_methods.append(command)

    def _print_response(self, response):
        " Print a response, every line prefixed with '< ' "
        print(re.sub("^", "< ", response, flags=re.MULTILINE))

    def rpc_template(self, command, requires_token, args=None):
        """
        Template function to create new RPC shell commands

        command -- name of the remote method
        requires_token -- prepend the current claim token to the arguments
        args -- argument string as typed by the user, see expand_args()

        Returns True (which stops the command loop) after an unexpected
        exception, False otherwise.
        """
        from mprpc.exceptions import RPCError
        if requires_token and \
                (self._claimer is None or self._claimer.get_token() is None):
            self._print_response("Cannot execute `{}' -- "
                                 "no claim available!".format(command))
            return False
        try:
            if args or requires_token:
                expanded_args = self.expand_args(args)
                if requires_token:
                    expanded_args.insert(0, self._claimer.get_token())
                response = self.client.call(command, *expanded_args)
            else:
                response = self.client.call(command)
        except RPCError as ex:
            self._print_response("RPC Command failed!\nError: {}".format(ex))
            return False
        except Exception as ex:
            self._print_response("Unexpected exception!\nError: {}".format(ex))
            return True
        if isinstance(response, bool):
            if response:
                self._print_response("Command succeeded.")
            else:
                self._print_response("Command failed!")
        else:
            self._print_response(str(response))

        return False

    def get_names(self):
        " We need this for tab completion. "
        # dir(self) also lists the commands added to the instance
        return dir(self)

    ###########################################################################
    # Cmd module specific
    ###########################################################################
    def default(self, line):
        " Report a line that does not match any command "
        self._print_response("*** Unknown syntax: %s" % line)

    def preloop(self):
        """
        In script mode add Execution start marker to ease parsing script output
        :return: None
        """
        if self._script:
            print("Execute %s" % self._script)

    def precmd(self, line):
        """
        Add command prepended by "> " in scripting mode to ease parsing script
        output.
        """
        if self.cmdqueue:
            print("> %s" % line)
        return line

    def postcmd(self, stop, line):
        """
        Is run after every command executes. Does:
        - Update prompt
        """
        self.update_prompt()
        return stop

    ###########################################################################
    # Internal methods
    ###########################################################################
    def connect(self, host, port):
        """
        Launch a connection.

        Opens the RPC client, registers the remote methods as commands and
        reads the device info.

        Returns False if the client could not be created, True otherwise.
        Errors of the RPC calls made after connecting are not handled here.
        """
        from mprpc import RPCClient
        print("Attempting to connect to {host}:{port}...".format(
            host=host, port=port
        ))
        try:
            self.client = RPCClient(host, port, pack_params={'use_bin_type': True})
            print("Connection successful.")
        except Exception as ex:
            print("Connection refused")
            print("Error: {}".format(ex))
            return False
        self._host = host
        self._port = port
        print("Getting methods...")
        methods = self.client.call('list_methods')
        for method in methods:
            # Each entry is unpacked into the arguments of _add_command()
            self._add_command(*method)
        print("Added {} methods.".format(len(methods)))
        print("Quering device info...")
        self._device_info = self.client.call('get_device_info')
        return True

    def disconnect(self):
        """
        Clean up after a connection was closed.

        Stops the claimer, closes the client and removes all commands that
        were added by connect().
        """
        from mprpc.exceptions import RPCError
        self._device_info = None
        if self._claimer is not None:
            self._claimer.exit()
        if self.client:
            try:
                self.client.close()
            except RPCError as ex:
                print("Error while closing the connection")
                print("Error: {}".format(ex))
        for method in self.remote_methods:
            delattr(self, "do_" + method)
        self.remote_methods = []
        self.client = None
        self._host = None
        self._port = None

    def claim(self):
        " Initialize claim "
        print("Claiming device...")
        self._claimer.claim()
        return True

    def hijack(self, token):
        " Hijack running session "
        if self._claimer.hijacked:
            print("Claimer already active. Can't hijack.")
            return False
        print("Hijacking device...")
        self._claimer.hijack(token)
        return True

    def unclaim(self):
        """
        unclaim
        """
        self._claimer.unclaim()

    def update_prompt(self):
        """
        Update prompt

        Without device info the prompt is '> '. Otherwise it shows the
        device name (or its serial, or '?') followed by ' [C]' for an own
        claim or ' [H]' for a hijacked one.
        """
        if self._device_info is None:
            self.prompt = '> '
        else:
            token = self._claimer.get_token()
            if token is None:
                claim_status = ''
            elif self._claimer.hijacked:
                claim_status = ' [H]'
            else:
                claim_status = ' [C]'
            self.prompt = '{dev_id}{claim_status}> '.format(
                dev_id=self._device_info.get(
                    'name', self._device_info.get('serial', '?')
                ),
                claim_status=claim_status,
            )

    def parse_script(self):
        """
        Adding script command from file pointed to by self._script.

        The commands are read from file one per line and added to cmdqueue of
        parent class. This way they will be executed instead of input from
        stdin. An EOF command is appended to the list to ensure the shell exits
        after script execution.
        :return: None
        """
        try:
            with open(self._script, "r") as script:
                for command in script:
                    self.cmdqueue.append(command.strip())
        except OSError as ex:
            print("Failed to read script. (%s)" % ex)
        self.cmdqueue.append("EOF") # terminate shell after script execution

    def expand_args(self, args):
        """
        Takes a string and returns a list

        args -- argument string, or None for "no arguments"

        Every '$T' is replaced by the current claim token, if there is one.
        A string starting with '=' is evaluated as a Python expression; a
        result that is not a list is wrapped into one. Otherwise the string
        is split on whitespace and every item is converted to int (any base
        prefix accepted by int(x, 0)), else to float, else kept as string.
        """
        if args is None:
            # rpc_template() passes its default through when a command that
            # requires a token is called without arguments
            args = ''
        if self._claimer is not None and self._claimer.get_token() is not None:
            args = args.replace('$T', str(self._claimer.get_token()))
        eval_preamble = '='
        args = args.strip()
        if args.startswith(eval_preamble):
            # SECURITY: eval() executes arbitrary code typed at the prompt
            # or read from a script file. Only feed this shell trusted
            # input.
            parsed_args = eval(args.lstrip(eval_preamble))
            if not isinstance(parsed_args, list):
                parsed_args = [parsed_args]
        else:
            parsed_args = []
            for arg in args.split():
                try:
                    parsed_args.append(int(arg, 0))
                    continue
                except ValueError:
                    pass
                try:
                    parsed_args.append(float(arg))
                    continue
                except ValueError:
                    pass
                parsed_args.append(arg)
        return parsed_args

    ###########################################################################
    # Predefined commands
    ###########################################################################
    def do_connect(self, args):
        """
        Connect to a remote MPM server. See connect()
        """
        defaults = ('localhost', MPM_RPC_PORT)
        fields = list(split_args(args, *defaults))
        # Pad here as well, so a missing host or port falls back to its
        # default no matter how many fields split_args() returned.
        fields.extend(defaults[len(fields):])
        host, port = fields[:2]
        try:
            port = int(port)
        except ValueError:
            self._print_response("Invalid port: {}".format(port))
            return
        self.connect(host, port)

    def do_claim(self, _):
        """
        Spawn a claim loop
        """
        self.claim()

    def do_hijack(self, token):
        """
        Hijack a running session
        """
        self.hijack(token)

    def do_unclaim(self, _):
        """
        unclaim
        """
        self.unclaim()

    def do_disconnect(self, _):
        """
        disconnect from the RPC server
        """
        self.disconnect()

    def do_import(self, args):
        """import a python module into the global namespace"""
        try:
            globals()[args] = import_module(args)
        except (ImportError, ValueError) as ex:
            # A mistyped or empty module name must not end the session
            self._print_response("Import failed!\nError: {}".format(ex))

    # pylint: disable=invalid-name
    def do_EOF(self, _):
        """
        When catching EOF, exit the program.
        """
        print("Exiting...")
        self.disconnect()
        return True # orderly shutdown
465
def main():
    """
    Parse the command line, create the shell and run its command loop.

    Returns True when the loop ended normally or was interrupted from the
    keyboard, False when any other exception escaped from it. The shell is
    disconnected in both exceptional cases.
    """
    cli_args = parse_args()
    shell = MPMShell(
        cli_args.host,
        cli_args.port,
        cli_args.claim,
        cli_args.hijack,
        cli_args.script,
    )
    succeeded = True
    try:
        shell.cmdloop()
    except KeyboardInterrupt:
        shell.disconnect()
    except Exception as ex: # pylint: disable=broad-except
        print("Uncaught exception: " + str(ex))
        shell.disconnect()
        succeeded = False
    return succeeded
481
if __name__ == "__main__":
    # The bare exit() helper is injected by the site module for interactive
    # use and is missing when Python runs with -S; sys.exit() is always
    # available. The exit status is 0 when main() returns True, 1 otherwise.
    import sys
    sys.exit(not main())
484