1"""
2Copyright 2008-2015 Free Software Foundation, Inc.
3This file is part of GNU Radio
4
5GNU Radio Companion is free software; you can redistribute it and/or
6modify it under the terms of the GNU General Public License
7as published by the Free Software Foundation; either version 2
8of the License, or (at your option) any later version.
9
10GNU Radio Companion is distributed in the hope that it will be useful,
11but WITHOUT ANY WARRANTY; without even the implied warranty of
12MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13GNU General Public License for more details.
14
15You should have received a copy of the GNU General Public License
16along with this program; if not, write to the Free Software
17Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA
18"""
19
20from __future__ import absolute_import
21
22import collections
23import itertools
24import copy
25
26import six
27from six.moves import range
28import re
29
30import ast
31
32from ._templates import MakoTemplates
33from ._flags import Flags
34
35from ..base import Element
36from ..utils.descriptors import lazy_property
37
def _get_elem(iterable, key):
    """
    Find the element of `iterable` whose `key` attribute equals `key`.

    Args:
        iterable: elements exposing a `key` attribute
        key: the key to look for

    Returns:
        the first element whose key matches

    Raises:
        ValueError: if no element carries the requested key
    """
    # Materialize once so the error message can list everything that was searched
    items = list(iterable)
    for item in items:
        if item.key == key:
            return item
    # The exception used to be *returned*, which handed callers an exception
    # object in place of an element; raise it instead.
    raise ValueError('Key "{}" not found in {}.'.format(key, items))
44
45
class Block(Element):
    """
    A flow graph block: a set of params plus source and sink ports.

    The class-level attributes below are defaults; __init__ reads
    parameters_data, inputs_data and outputs_data to build the instance's
    params and ports.
    """

    is_block = True

    # Valid values for the 'state' entry of self.states (see the state property)
    STATE_LABELS = ['disabled', 'enabled', 'bypassed']

    key = ''
    label = ''
    category = ''
    vtype = '' # This is only used for variables when we want C++ output
    # Queried with `in` (e.g. 'cpp' in self.flags) and via attributes such as disable_bypass
    flags = Flags('')
    documentation = {'': ''}

    # A truthy value marks the block as a variable (see is_variable)
    value = None
    # Expressions evaluated by _run_asserts(); each must evaluate to a truthy value
    # NOTE: this and the other list/dict defaults are class attributes, so they
    # are shared by every instance that does not override them
    asserts = []

    templates = MakoTemplates()
    # Keyword-argument dicts handed to the platform's make_param / make_port in __init__
    parameters_data = []
    inputs_data = []
    outputs_data = []

    extra_data = {}
    loaded_from = '(unknown)'
69
70    def __init__(self, parent):
71        """Make a new block from nested data."""
72        super(Block, self).__init__(parent)
73        param_factory = self.parent_platform.make_param
74        port_factory = self.parent_platform.make_port
75
76        self.params = collections.OrderedDict(
77            (data['id'], param_factory(parent=self, **data))
78            for data in self.parameters_data
79        )
80        if self.key == 'options':
81            self.params['id'].hide = 'part'
82
83        self.sinks = [port_factory(parent=self, **params) for params in self.inputs_data]
84        self.sources = [port_factory(parent=self, **params) for params in self.outputs_data]
85
86        self.active_sources = []  # on rewrite
87        self.active_sinks = []  # on rewrite
88
89        self.states = {'state': True, 'bus_source': False, 'bus_sink': False, 'bus_structure': None}
90        self.block_namespace = {}
91
92        if 'cpp' in self.flags:
93            self.orig_cpp_templates = self.cpp_templates # The original template, in case we have to edit it when transpiling to C++
94
95        self.current_bus_structure = {'source': None, 'sink': None}
96
97    def get_bus_structure(self, direction):
98        if direction == 'source':
99            bus_structure = self.bus_structure_source
100        else:
101            bus_structure = self.bus_structure_sink
102
103        if not bus_structure:
104            return None
105
106        try:
107            clean_bus_structure = self.evaluate(bus_structure)
108            return clean_bus_structure
109        except:
110            return None
111
112
113    # region Rewrite_and_Validation
114    def rewrite(self):
115        """
116        Add and remove ports to adjust for the nports.
117        """
118        Element.rewrite(self)
119
120        def rekey(ports):
121            """Renumber non-message/message ports"""
122            domain_specific_port_index = collections.defaultdict(int)
123            for port in ports:
124                if not port.key.isdigit():
125                    continue
126                domain = port.domain
127                port.key = str(domain_specific_port_index[domain])
128                domain_specific_port_index[domain] += 1
129
130        # Adjust nports
131        for ports in (self.sources, self.sinks):
132            self._rewrite_nports(ports)
133            rekey(ports)
134
135        self.update_bus_logic()
136        # disconnect hidden ports
137        self.parent_flowgraph.disconnect(*[p for p in self.ports() if p.hidden])
138
139        self.active_sources = [p for p in self.sources if not p.hidden]
140        self.active_sinks = [p for p in self.sinks if not p.hidden]
141
142        # namespaces may have changed, update them
143        self.block_namespace.clear()
144        imports = ""
145        try:
146            imports = self.templates.render('imports')
147            exec(imports, self.block_namespace)
148        except ImportError:
149            # We do not have a good way right now to determine if an import is for a
150            # hier block, these imports will fail as they are not in the search path
151            # this is ok behavior, unfortunately we could be hiding other import bugs
152            pass
153        except Exception:
154            self.add_error_message("Failed to evaluate import expression {!r}".format(imports))
155
156    def update_bus_logic(self):
157        ###############################
158        ## Bus Logic
159        ###############################
160
161        for direc in {'source','sink'}:
162            if direc == 'source':
163                ports = self.sources
164                ports_gui = self.filter_bus_port(self.sources)
165                bus_state = self.bus_source
166            else:
167                ports = self.sinks
168                ports_gui = self.filter_bus_port(self.sinks)
169                bus_state = self.bus_sink
170
171            # Remove the bus ports
172            removed_bus_ports = []
173            removed_bus_connections = []
174            if 'bus' in map(lambda a: a.dtype, ports):
175                for port in ports_gui:
176                    for c in self.parent_flowgraph.connections:
177                        if port is c.source_port or port is c.sink_port:
178                            removed_bus_ports.append(port)
179                            removed_bus_connections.append(c)
180                    ports.remove(port)
181
182
183            if (bus_state):
184                struct = self.form_bus_structure(direc)
185                self.current_bus_structure[direc] = struct
186
187                # Hide ports that are not part of the bus structure
188                #TODO: Blocks where it is desired to only have a subset
189                # of ports included in the bus still has some issues
190                for idx, port in enumerate(ports):
191                    if any([idx in bus for bus in self.current_bus_structure[direc]]):
192                        if (port.stored_hidden_state is None):
193                            port.stored_hidden_state = port.hidden
194                            port.hidden = True
195
196                # Add the Bus Ports to the list of ports
197                for i in range(len(struct)):
198                    # self.sinks = [port_factory(parent=self, **params) for params in self.inputs_data]
199                    port = self.parent.parent.make_port(self,direction=direc,id=str(len(ports)),label='bus',dtype='bus',bus_struct=struct[i])
200                    ports.append(port)
201
202                    for (saved_port, connection) in zip(removed_bus_ports, removed_bus_connections):
203                        if port.key == saved_port.key:
204                            self.parent_flowgraph.connections.remove(connection)
205                            if saved_port.is_source:
206                                connection.source_port = port
207                            if saved_port.is_sink:
208                                connection.sink_port = port
209                            self.parent_flowgraph.connections.add(connection)
210
211
212            else:
213                self.current_bus_structure[direc] = None
214
215                # Re-enable the hidden property of the ports
216                for port in ports:
217                    if (port.stored_hidden_state is not None):
218                        port.hidden = port.stored_hidden_state
219                        port.stored_hidden_state = None
220
221
222
223    def _rewrite_nports(self, ports):
224        for port in ports:
225            if hasattr(port, 'master_port'):  # Not a master port and no left-over clones
226                port.dtype = port.master_port.dtype
227                port.vlen = port.master_port.vlen
228                continue
229            nports = port.multiplicity
230            for clone in port.clones[nports-1:]:
231                # Remove excess connections
232                self.parent_flowgraph.disconnect(clone)
233                port.remove_clone(clone)
234                ports.remove(clone)
235            # Add more cloned ports
236            for j in range(1 + len(port.clones), nports):
237                clone = port.add_clone()
238                ports.insert(ports.index(port) + j, clone)
239
240    def validate(self):
241        """
242        Validate this block.
243        Call the base class validate.
244        Evaluate the checks: each check must evaluate to True.
245        """
246        Element.validate(self)
247        self._run_asserts()
248        self._validate_generate_mode_compat()
249        self._validate_output_language_compat()
250        self._validate_var_value()
251
252    def _run_asserts(self):
253        """Evaluate the checks"""
254        for expr in self.asserts:
255            try:
256                if not self.evaluate(expr):
257                    self.add_error_message('Assertion "{}" failed.'.format(expr))
258            except Exception:
259                self.add_error_message('Assertion "{}" did not evaluate.'.format(expr))
260
261    def _validate_generate_mode_compat(self):
262        """check if this is a GUI block and matches the selected generate option"""
263        current_generate_option = self.parent.get_option('generate_options')
264
265        def check_generate_mode(label, flag, valid_options):
266            block_requires_mode = (
267                flag in self.flags or self.label.upper().startswith(label)
268            )
269            if block_requires_mode and current_generate_option not in valid_options:
270                self.add_error_message("Can't generate this block in mode: {} ".format(
271                                       repr(current_generate_option)))
272
273        check_generate_mode('QT GUI', Flags.NEED_QT_GUI, ('qt_gui', 'hb_qt_gui'))
274
275    def _validate_output_language_compat(self):
276        """check if this block supports the selected output language"""
277        current_output_language = self.parent.get_option('output_language')
278
279        if current_output_language == 'cpp':
280            if 'cpp' not in self.flags:
281                self.add_error_message("This block does not support C++ output.")
282
283            if self.key == 'parameter':
284                if not self.params['type'].value:
285                    self.add_error_message("C++ output requires you to choose a parameter type.")
286
287    def _validate_var_value(self):
288        """or variables check the value (only if var_value is used)"""
289        if self.is_variable and self.value != 'value':
290            try:
291                self.parent_flowgraph.evaluate(self.value, local_namespace=self.namespace)
292            except Exception as err:
293                self.add_error_message('Value "{}" cannot be evaluated:\n{}'.format(self.value, err))
294    # endregion
295
296    # region Properties
297
298    def __str__(self):
299        return 'Block - {} - {}({})'.format(self.name, self.label, self.key)
300
301    def __repr__(self):
302        try:
303            name = self.name
304        except Exception:
305            name = self.key
306        return 'block[' + name + ']'
307
308    @property
309    def name(self):
310        return self.params['id'].value
311
312    @lazy_property
313    def is_virtual_or_pad(self):
314        return self.key in ("virtual_source", "virtual_sink", "pad_source", "pad_sink")
315
316    @lazy_property
317    def is_variable(self):
318        return bool(self.value)
319
320    @lazy_property
321    def is_import(self):
322        return self.key == 'import'
323
324    @lazy_property
325    def is_snippet(self):
326        return self.key == 'snippet'
327
328    @property
329    def comment(self):
330        return self.params['comment'].value
331
332    @property
333    def state(self):
334        """Gets the block's current state."""
335        state = self.states['state']
336        return state if state in self.STATE_LABELS else 'enabled'
337
338    @state.setter
339    def state(self, value):
340        """Sets the state for the block."""
341        self.states['state'] = value
342
343    # Enable/Disable Aliases
344    @property
345    def enabled(self):
346        """Get the enabled state of the block"""
347        return self.state != 'disabled'
348
349    @property
350    def bus_sink(self):
351        """Gets the block's current Toggle Bus Sink state."""
352        return self.states['bus_sink']
353
354    @bus_sink.setter
355    def bus_sink(self, value):
356        """Sets the Toggle Bus Sink state for the block."""
357        self.states['bus_sink'] = value
358
359    @property
360    def bus_source(self):
361        """Gets the block's current Toggle Bus Sink state."""
362        return self.states['bus_source']
363
364    @bus_source.setter
365    def bus_source(self, value):
366        """Sets the Toggle Bus Source state for the block."""
367        self.states['bus_source'] = value
368
369    @property
370    def bus_structure_source(self):
371        """Gets the block's current source bus structure."""
372        try:
373            bus_structure = self.params['bus_structure_source'].value or None
374        except:
375            bus_structure = None
376        return bus_structure
377
378    @property
379    def bus_structure_sink(self):
380        """Gets the block's current source bus structure."""
381        try:
382            bus_structure = self.params['bus_structure_sink'].value or None
383        except:
384            bus_structure = None
385        return bus_structure
386
387    # endregion
388
389    ##############################################
390    # Getters (old)
391    ##############################################
392    def get_var_make(self):
393        return self.templates.render('var_make')
394
395    def get_cpp_var_make(self):
396        return self.cpp_templates.render('var_make')
397
398    def get_var_value(self):
399        return self.templates.render('var_value')
400
401    def get_callbacks(self):
402        """
403        Get a list of function callbacks for this block.
404
405        Returns:
406            a list of strings
407        """
408        def make_callback(callback):
409            if 'self.' in callback:
410                return callback
411            return 'self.{}.{}'.format(self.name, callback)
412
413        return [make_callback(c) for c in self.templates.render('callbacks')]
414
415    def get_cpp_callbacks(self):
416        """
417        Get a list of C++ function callbacks for this block.
418
419        Returns:
420            a list of strings
421        """
422        def make_callback(callback):
423            if self.is_variable:
424                return callback
425            if 'this->' in callback:
426                return callback
427            return 'this->{}->{}'.format(self.name, callback)
428
429        return [make_callback(c) for c in self.cpp_templates.render('callbacks')]
430
431    def decide_type(self):
432        """
433        Evaluate the value of the variable block and decide its type.
434
435        Returns:
436            None
437        """
438        value = self.params['value'].value
439        self.cpp_templates = copy.copy(self.orig_cpp_templates)
440
441        def get_type(element):
442            try:
443                evaluated = ast.literal_eval(element)
444
445            except ValueError or SyntaxError:
446                if re.match(r'^(numpy|np|scipy|sp)\.pi$', value):
447                    return 'pi'
448                else:
449                    return 'std::string'
450
451            else:
452                _vtype = type(evaluated)
453                if _vtype in [int, float, bool, list]:
454                    if _vtype == (int or long):
455                        return 'int'
456
457                    if _vtype == float:
458                        return 'double'
459
460                    if _vtype == bool:
461                        return 'bool'
462
463                    if _vtype == list:
464                        try:
465                            first_element_type = type(evaluated[0])
466                            if first_element_type != str:
467                                list_type = get_type(str(evaluated[0]))
468                            else:
469                                list_type = get_type(evaluated[0])
470
471                        except IndexError: # empty list
472                            return 'std::vector<std::string>'
473
474                        else:
475                            return 'std::vector<' + list_type + '>'
476
477                else:
478                    return 'std::string'
479
480        self.vtype = get_type(value)
481        if self.vtype == 'bool':
482            self.cpp_templates['var_make'] = self.cpp_templates['var_make'].replace('${value}', (value[0].lower() + value[1:]))
483
484        elif self.vtype == 'pi':
485            self.vtype = 'double'
486            self.cpp_templates['var_make'] = self.cpp_templates['var_make'].replace('${value}', 'boost::math::constants::pi<double>()')
487            self.cpp_templates['includes'].append('#include <boost/math/constants/constants.hpp>')
488
489        elif 'std::vector' in self.vtype:
490            self.cpp_templates['includes'].append('#include <vector>')
491            self.cpp_templates['var_make'] = self.cpp_templates['var_make'].replace('${value}', '{' + value[1:-1] + '}')
492
493        if 'string' in self.vtype:
494            self.cpp_templates['includes'].append('#include <string>')
495
496    def is_virtual_sink(self):
497        return self.key == 'virtual_sink'
498
499    def is_virtual_source(self):
500        return self.key == 'virtual_source'
501
502    # Block bypassing
503    def get_bypassed(self):
504        """
505        Check if the block is bypassed
506        """
507        return self.state == 'bypassed'
508
509    def set_bypassed(self):
510        """
511        Bypass the block
512
513        Returns:
514            True if block changes state
515        """
516        if self.state != 'bypassed' and self.can_bypass():
517            self.state = 'bypassed'
518            return True
519        return False
520
521    def can_bypass(self):
522        """ Check the number of sinks and sources and see if this block can be bypassed """
523        # Check to make sure this is a single path block
524        # Could possibly support 1 to many blocks
525        if len(self.sources) != 1 or len(self.sinks) != 1:
526            return False
527        if not (self.sources[0].dtype == self.sinks[0].dtype):
528            return False
529        if self.flags.disable_bypass:
530            return False
531        return True
532
533    def ports(self):
534        return itertools.chain(self.sources, self.sinks)
535
536    def active_ports(self):
537        return itertools.chain(self.active_sources, self.active_sinks)
538
539    def children(self):
540        return itertools.chain(six.itervalues(self.params), self.ports())
541
542    ##############################################
543    # Access
544    ##############################################
545
546    def get_sink(self, key):
547        return _get_elem(self.sinks, key)
548
549    def get_source(self, key):
550        return _get_elem(self.sources, key)
551
552    ##############################################
553    # Resolve
554    ##############################################
555    @property
556    def namespace(self):
557        # update block namespace
558        self.block_namespace.update({key:param.get_evaluated() for key, param in six.iteritems(self.params)})
559        return self.block_namespace
560
561    @property
562    def namespace_templates(self):
563        return {key: param.template_arg for key, param in six.iteritems(self.params)}
564
565    def evaluate(self, expr):
566        return self.parent_flowgraph.evaluate(expr, self.namespace)
567
568    ##############################################
569    # Import/Export Methods
570    ##############################################
571    def export_data(self):
572        """
573        Export this block's params to nested data.
574
575        Returns:
576            a nested data odict
577        """
578        data = collections.OrderedDict()
579        if self.key != 'options':
580            data['name'] = self.name
581            data['id'] = self.key
582        data['parameters'] = collections.OrderedDict(sorted(
583            (param_id, param.value) for param_id, param in self.params.items()
584            if (param_id != 'id' or self.key == 'options')
585        ))
586        data['states'] = collections.OrderedDict(sorted(self.states.items()))
587        return data
588
589    def import_data(self, name, states, parameters, **_):
590        """
591        Import this block's params from nested data.
592        Any param keys that do not exist will be ignored.
593        Since params can be dynamically created based another param,
594        call rewrite, and repeat the load until the params stick.
595        """
596        self.params['id'].value = name
597        self.states.update(states)
598
599        def get_hash():
600            return hash(tuple(hash(v) for v in self.params.values()))
601
602        pre_rewrite_hash = -1
603        while pre_rewrite_hash != get_hash():
604            for key, value in six.iteritems(parameters):
605                try:
606                    self.params[key].set_value(value)
607                except KeyError:
608                    continue
609            # Store hash and call rewrite
610            pre_rewrite_hash = get_hash()
611            self.rewrite()
612
613    ##############################################
614    # Controller Modify
615    ##############################################
616    def filter_bus_port(self, ports):
617        buslist = [p for p in ports if p.dtype == 'bus']
618        return buslist or ports
619
620    def type_controller_modify(self, direction):
621        """
622        Change the type controller.
623
624        Args:
625            direction: +1 or -1
626
627        Returns:
628            true for change
629        """
630        changed = False
631        type_param = None
632        for param in filter(lambda p: p.is_enum(), self.get_params()):
633            children = self.get_ports() + self.get_params()
634            # Priority to the type controller
635            if param.get_key() in ' '.join(map(lambda p: p._type, children)): type_param = param
636            # Use param if type param is unset
637            if not type_param:
638                type_param = param
639        if type_param:
640            # Try to increment the enum by direction
641            try:
642                keys = type_param.get_option_keys()
643                old_index = keys.index(type_param.get_value())
644                new_index = (old_index + direction + len(keys)) % len(keys)
645                type_param.set_value(keys[new_index])
646                changed = True
647            except:
648                pass
649        return changed
650
651    def form_bus_structure(self, direc):
652        if direc == 'source':
653            ports = self.sources
654            bus_structure = self.get_bus_structure('source')
655        else:
656            ports = self.sinks
657            bus_structure = self.get_bus_structure('sink')
658
659        struct = [range(len(ports))]
660        # struct = list(range(len(ports)))
661        #TODO for more complicated port structures, this code is needed but not working yet
662        if any([p.multiplicity for p in ports]):
663            structlet = []
664            last = 0
665            # group the ports with > n inputs together on the bus
666            cnt = 0
667            idx = 0
668            for p in ports:
669                if p.domain == 'message':
670                    continue
671                if cnt > 0:
672                    cnt -= 1
673                    continue
674
675                if p.multiplicity > 1:
676                    cnt = p.multiplicity-1
677                    structlet.append([idx+j for j in range(p.multiplicity)])
678                else:
679                    structlet.append([idx])
680
681            struct = structlet
682        if bus_structure:
683            struct = bus_structure
684
685        self.current_bus_structure[direc] = struct
686        return struct
687
688    def bussify(self, direc):
689        if direc == 'source':
690            ports = self.sources
691            ports_gui = self.filter_bus_port(self.sources)
692            self.bus_structure = self.get_bus_structure('source')
693            self.bus_source = not self.bus_source
694        else:
695            ports = self.sinks
696            ports_gui = self.filter_bus_port(self.sinks)
697            self.bus_structure = self.get_bus_structure('sink')
698            self.bus_sink = not self.bus_sink
699
700        # Disconnect all the connections when toggling the bus state
701        for port in ports:
702            l_connections = list(port.connections())
703            for connect in l_connections:
704                self.parent.remove_element(connect)
705
706        self.update_bus_logic()
707