1# -*- coding: utf-8 -*-
2#
3# Copyright (c) 2020 Martin Owens <doctormo@gmail.com>
4#                    Sergei Izmailov <sergei.a.izmailov@gmail.com>
5#                    Thomas Holder <thomas.holder@schrodinger.com>
6#
7# This program is free software; you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation; either version 2 of the License, or
10# (at your option) any later version.
11#
12# This program is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with this program; if not, write to the Free Software
19# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
20#
21# pylint: disable=arguments-differ
22"""
23Provide extra utility to each svg element type specific to its type.
24
25This is useful for having a common interface for each element which can
26give path, transform, and property access easily.
27"""
28
29from collections import defaultdict
30from copy import deepcopy
31from lxml import etree
32
33from ..paths import Path
34from ..styles import Style, AttrFallbackStyle, Classes
35from ..transforms import Transform, BoundingBox
36from ..utils import FragmentError
37from ..units import convert_unit, render_unit
38from ._utils import ChildToProperty, NSS, addNS, removeNS, splitNS
39
40from typing import overload, DefaultDict, Type, Any, List, Tuple, Union, Optional  # pylint: disable=unused-import
41
class NodeBasedLookup(etree.PythonElementClassLookup):
    """
    We choose what kind of Elements we should return for each element, providing useful
    SVG based API to our extensions system.
    """
    # (ns, tag) -> list(cls) ; ascending priority (registration order).
    # lookup() walks it in reverse so later registrations are tried first.
    lookup_table = defaultdict(list) # type: DefaultDict[Tuple[str, str], List[Any]]

    @classmethod
    def register_class(cls, klass):
        """Register the given class using its attached tag name"""
        cls.lookup_table[splitNS(klass.tag_name)].append(klass)

    @classmethod
    def find_class(cls, xpath):
        """Find the class for this type of element defined by an xpath

        :param xpath: a class (returned unchanged) or a string whose last
            '/'-separated segment is a (prefixed) tag name, e.g. "svg:defs".
        :return: the first registered class, in ascending priority order,
            whose _is_class_element() accepts a freshly constructed instance.
        :raises KeyError: when no registered class matches.
        """
        if isinstance(xpath, type):
            return xpath
        # .get() so that probing for an unknown tag doesn't insert an empty
        # entry into the shared defaultdict.
        for kls in cls.lookup_table.get(splitNS(xpath.split('/')[-1]), ()):
            # TODO: We could apply the xpath attrs to the test element
            # to narrow the search, but this does everything we need right now.
            test_element = kls()
            if kls._is_class_element(test_element): # pylint: disable=protected-access
                return kls
        raise KeyError(f"Could not find svg tag for '{xpath}'")

    def lookup(self, doc, element): # pylint: disable=unused-argument
        """Lookup called by lxml when assigning elements their object class

        Tries candidates registered for the element's (ns, tag) from highest to
        lowest priority; falls back to BaseElement when none accepts it.
        """
        try:
            candidates = self.lookup_table.get(splitNS(element.tag), ())
            for kls in reversed(candidates):
                if kls._is_class_element(element): # pylint: disable=protected-access
                    return kls
        except TypeError:
            # Handle non-element proxies case
            # The documentation implies that it's not possible
            # Didn't find a reliable way to check whether proxy corresponds to element or not
            # Looks like an lxml issue.
            # The troubling element is "<!--Comment-->"
            return None
        return BaseElement
82
83
# Module-wide parser for SVG documents.
#  - strip_cdata=False keeps CDATA sections instead of flattening them to text.
#  - huge_tree=True lifts libxml2's depth/text-size safety limits so very
#    large drawings can still be parsed.
SVG_PARSER = etree.XMLParser(strip_cdata=False, huge_tree=True)
SVG_PARSER.set_element_class_lookup(NodeBasedLookup())
86
def load_svg(stream):
    """Parse an SVG document with SVG_PARSER and return an ElementTree.

    ``stream`` is treated as inline markup when it is a ``str`` or ``bytes``
    whose first non-whitespace character is ``<``; anything else is handed
    to ``etree.parse`` (e.g. a filename or file-like object).
    """
    if isinstance(stream, str):
        inline = stream.lstrip().startswith('<')
    elif isinstance(stream, bytes):
        inline = stream.lstrip().startswith(b'<')
    else:
        inline = False

    if not inline:
        return etree.parse(stream, parser=SVG_PARSER)
    root = etree.fromstring(stream, parser=SVG_PARSER)
    return etree.ElementTree(root)
93
class BaseElement(etree.ElementBase):
    """Provide automatic namespaces to all calls"""
    def __init_subclass__(cls, **kwargs):
        """Register every subclass that declares a tag_name with NodeBasedLookup"""
        # Chain up so cooperative mixins / class keyword arguments keep working.
        super().__init_subclass__(**kwargs)
        if cls.tag_name:
            NodeBasedLookup.register_class(cls)

    @classmethod
    def _is_class_element(cls, el):  # type: (etree.Element) -> bool
        """Hook to do more restrictive check in addition to (ns,tag) match"""
        return True

    tag_name = ''

    @property
    def TAG(self): # pylint: disable=invalid-name
        """Return the tag_name without NS"""
        if not self.tag_name:
            return removeNS(super().tag)[-1]
        return removeNS(self.tag_name)[-1]

    @classmethod
    def new(cls, *children, **attrs):
        """Create a new element, converting attrs values to strings."""
        obj = cls(*children)
        obj.update(**attrs)
        return obj

    NAMESPACE = property(lambda self: splitNS(self.tag_name)[0])
    PARSER = SVG_PARSER
    WRAPPED_ATTRS = (
        # (prop_name, [optional: attr_name], cls)
        ('transform', Transform),
        ('style', Style),
        ('classes', 'class', Classes),
    ) # type: Tuple[Tuple[Any, ...], ...]

    # Both maps are rebuilt from WRAPPED_ATTRS on every access, so subclasses
    # can extend that tuple through inheritance.
    @property
    def wrapped_attrs(self):
        """Map attribute name -> (property name, wrapper class)"""
        return {row[-2]: (row[0], row[-1]) for row in self.WRAPPED_ATTRS}

    @property
    def wrapped_props(self):
        """Map property name -> (attribute name, wrapper class)"""
        return {row[0]: (row[-2], row[-1]) for row in self.WRAPPED_ATTRS}

    typename = property(lambda self: type(self).__name__)
    xml_path = property(lambda self: self.getroottree().getpath(self))
    desc = ChildToProperty("svg:desc", prepend=True)
    title = ChildToProperty("svg:title", prepend=True)

    def __getattr__(self, name):
        """Build the wrapper object for a wrapped property (transform, style, classes).

        :raises AttributeError: for any name that is not a wrapped property.
        """
        wrapped = self.wrapped_props.get(name)
        if wrapped is None:
            raise AttributeError(f"Can't find attribute {self.typename}.{name}")
        (attr, cls) = wrapped
        # The reason we do this here and not in _init is because lxml
        # is inconsistent about when elements are initialised.
        def _set_attr(new_item):
            # Handed to the wrapper as its callback (presumably invoked when the
            # wrapper is modified): write the new value back, or drop it if empty.
            if new_item:
                self.set(attr, str(new_item))
            else:
                self.attrib.pop(attr, None) # pylint: disable=no-member

        # pylint: disable=no-member
        value = cls(self.attrib.get(attr, None), callback=_set_attr)
        # This goes through __setattr__, which writes the string form back to
        # the attribute (or removes it when the wrapper is empty). Nothing is
        # stored on the proxy, so each access builds a fresh wrapper.
        setattr(self, name, value)
        return value

    def __setattr__(self, name, value):
        """Set the attribute, update it if needed"""
        wrapped = self.wrapped_props.get(name)
        if wrapped is None:
            super().__setattr__(name, value)
            return
        (attr, cls) = wrapped
        # Write self.attrib directly: self.set/self.get would loop back here.
        if value:
            if not isinstance(value, cls):
                value = cls(value)
            self.attrib[attr] = str(value)
        else:
            self.attrib.pop(attr, None) # pylint: disable=no-member

    def get(self, attr, default=None):
        """Get element attribute named, with addNS support.

        For wrapped attributes an empty value (e.g. an identity transform or an
        empty style) is treated as missing and ``default`` is returned.
        """
        wrapped = self.wrapped_attrs.get(attr)
        if wrapped is not None:
            (prop, _) = wrapped
            value = getattr(self, prop, None)
            # Empty transformations and style attributes are equivalent to
            # not-existing; honour the caller's default even when it is falsy.
            return str(value) if value else default
        return super().get(addNS(attr), default)

    def set(self, attr, value):
        """Set element attribute named, with addNS support.

        ``None`` removes the attribute; other values are stored via ``str()``.
        """
        wrapped = self.wrapped_attrs.get(attr)
        if wrapped is not None:
            # Always keep the local wrapped class up to date.
            (prop, cls) = wrapped
            setattr(self, prop, cls(value))
            value = getattr(self, prop)
            if not value:
                return
        if value is None:
            self.attrib.pop(addNS(attr), None) # pylint: disable=no-member
        else:
            super().set(addNS(attr), str(value))

    def update(self, **kwargs):
        """
        Update element attributes using keyword arguments

        Note: double underscore is used as namespace separator,
        i.e. "namespace__attr" argument name will be treated as "namespace:attr"

        :param kwargs: dict with name=value pairs
        :return: self
        """
        for name, value in kwargs.items():
            self.set(name, value)
        return self

    def pop(self, attr, default=None):
        """Delete/remove the element attribute named, with addNS support.

        For wrapped attributes the previous wrapper object is returned and
        ``default`` is not used.
        """
        wrapped = self.wrapped_attrs.get(attr)
        if wrapped is not None:
            # Always keep the local wrapped class up to date.
            (prop, cls) = wrapped
            value = getattr(self, prop)
            setattr(self, prop, cls(None))
            return value
        return self.attrib.pop(addNS(attr), default) # pylint: disable=no-member

    def add(self, *children):
        """
        Like append, but will do multiple children and will return
        children or only child
        """
        for child in children:
            self.append(child)
        return children if len(children) != 1 else children[0]

    def tostring(self):
        """Return this element as it would appear in an svg document"""
        # This kind of hack is pure madness, but etree provides very little
        # in the way of fragment printing, preferring to always output valid xml
        from ..base import SvgOutputMixin
        svg = SvgOutputMixin.get_template(width=0, height=0).getroot()
        svg.append(self.copy())
        return svg.tostring().split(b'>\n    ', 1)[-1][:-6]

    def set_random_id(self, prefix=None, size=4, backlinks=False):
        """Sets the id attribute to a new unique id (prefix defaults to the tag name)."""
        prefix = str(self) if prefix is None else prefix
        self.set_id(self.root.get_unique_id(prefix, size=size), backlinks=backlinks)

    def set_random_ids(self, prefix=None, levels=-1, backlinks=False):
        """Same as set_random_id, but will apply also to children

        ``levels`` limits the depth; a negative value recurses without limit.
        """
        self.set_random_id(prefix=prefix, backlinks=backlinks)
        if levels != 0:
            for child in self:
                if hasattr(child, 'set_random_ids'):
                    child.set_random_ids(prefix=prefix, levels=levels-1, backlinks=backlinks)

    eid = property(lambda self: self.get_id())
    def get_id(self, as_url=0):
        """Get the id for the element, will set a new unique id if not set.

        as_url - If set to 1, returns #{id} as a string
                 If set to 2, returns url(#{id}) as a string
        """
        if 'id' not in self.attrib:
            self.set_random_id(self.TAG)
        eid = self.get('id')
        if as_url > 0:
            eid = '#' + eid
        if as_url > 1:
            eid = f'url({eid})'
        return eid

    def set_id(self, new_id, backlinks=False):
        """Set the id and update backlinks to xlink and style urls if needed"""
        old_id = self.get('id', None)
        self.set('id', new_id)
        if backlinks and old_id:
            for elem in self.root.getElementsByHref(old_id):
                elem.href = self
            for elem in self.root.getElementsByStyleUrl(old_id):
                elem.style.update_urls(old_id, new_id)

    @property
    def root(self):
        """Get the root document element from any element descendent

        :raises FragmentError: when the topmost ancestor is not an SVG document.
        """
        if self.getparent() is not None:
            return self.getparent().root
        from ._svg import SvgDocumentElement
        if not isinstance(self, SvgDocumentElement):
            raise FragmentError("Element fragment does not have a document root!")
        return self

    def get_or_create(self, xpath, nodeclass=None, prepend=False):
        """Get or create the given xpath, pre/append new node if not found."""
        node = self.findone(xpath)
        if node is None:
            if nodeclass is None:
                nodeclass = NodeBasedLookup.find_class(xpath)
            node = nodeclass()
            if prepend:
                self.insert(0, node)
            else:
                self.append(node)
        return node

    def descendants(self):
        """Walks the element tree and yields all elements, parent first"""
        from ._selected import ElementList
        return ElementList(self.root, self._descendants())

    def _descendants(self):
        yield self
        for child in self:
            if hasattr(child, '_descendants'):
                yield from child._descendants() # pylint: disable=protected-access

    def ancestors(self, elem=None, stop_at=()):
        """
        Walk the parents and yield all the ancestor elements, parent first

        If elem is provided, it will stop at the last common ancestor.
        If stop_at is provided, it will stop at the first parent that is in this list.
        """
        from ._selected import ElementList
        return ElementList(self.root, self._ancestors(elem=elem, stop_at=stop_at))

    def _ancestors(self, elem, stop_at):
        if isinstance(elem, BaseElement):
            stop_at = list(elem.ancestors())
        parent = self.getparent()
        if parent is not None:
            yield parent
            if parent not in stop_at:
                # stop_at is already resolved; pass elem=None so the recursion
                # does not recompute elem's ancestors at every level.
                yield from parent._ancestors(elem=None, stop_at=stop_at) # pylint: disable=protected-access

    def backlinks(self, *types):
        """Get elements which link back to this element, like ancestors but via xlinks"""
        if not types or isinstance(self, types):
            yield self
        my_id = self.get('id')
        if my_id is not None:
            elems = list(self.root.getElementsByHref(my_id)) \
                  + list(self.root.getElementsByStyleUrl(my_id))
            for elem in elems:
                if hasattr(elem, 'backlinks'):
                    for child in elem.backlinks(*types):
                        yield child

    def xpath(self, pattern, namespaces=NSS):  # pylint: disable=dangerous-default-value
        """Wrap xpath call and add svg namespaces"""
        return super().xpath(pattern, namespaces=namespaces)

    def findall(self, pattern, namespaces=NSS):  # pylint: disable=dangerous-default-value
        """Wrap findall call and add svg namespaces"""
        return super().findall(pattern, namespaces=namespaces)

    def findone(self, xpath):
        """Gets a single element from the given xpath or returns None"""
        el_list = self.xpath(xpath)
        return el_list[0] if el_list else None

    def delete(self):
        """Delete this node from its parent node"""
        if self.getparent() is not None:
            self.getparent().remove(self)

    def remove_all(self, *types):
        """Remove all children or child types"""
        types = tuple(NodeBasedLookup.find_class(t) for t in types)
        # Iterate over a snapshot: removing from the element being iterated
        # is fragile.
        for child in list(self):
            if not types or isinstance(child, types):
                self.remove(child)

    def replace_with(self, elem):
        """Replace this element with the given element

        The id and inkscape label are carried over when elem lacks them.
        """
        self.addnext(elem)
        if not elem.get('id') and self.get('id'):
            elem.set('id', self.get('id'))
        if not elem.label and self.label:
            elem.label = self.label
        self.delete()
        return elem

    def copy(self):
        """Make a deep copy of the element (with its id removed) and return it"""
        elem = deepcopy(self)
        elem.set('id', None)
        return elem

    def duplicate(self):
        """Like copy(), but the copy stays in the tree and sets a random id"""
        elem = self.copy()
        self.addnext(elem)
        elem.set_random_id()
        return elem

    def __str__(self):
        # We would do more here, but lxml is VERY unpleasant when it comes to
        # namespaces, basically over printing details and providing no
        # suppression mechanisms to turn off xml's over engineering.
        return str(self.tag).split('}')[-1]

    @property
    def href(self):
        """Returns the referred-to element if available"""
        ref = self.get('xlink:href')
        if not ref:
            return None
        return self.root.getElementById(ref.strip('#'))

    @href.setter
    def href(self, elem):
        """Set the href object"""
        if isinstance(elem, BaseElement):
            elem = elem.get_id()
        self.set('xlink:href', '#' + elem)

    def fallback_style(self, move=False):
        """Get styles falling back to element attributes"""
        return AttrFallbackStyle(self, move=move)

    @property
    def label(self):
        """Returns the inkscape label"""
        return self.get('inkscape:label', None)

    @label.setter
    def label(self, value):
        """Set the inkscape label; assigning None removes it"""
        # Previously str(None) stored the literal label "None".
        self.set('inkscape:label', None if value is None else str(value))

    def is_sensitive(self):
        """Return true if this element is sensitive in inkscape"""
        return self.get('sodipodi:insensitive', None) != 'true'

    def set_sensitive(self, sensitive=True):
        """Set the sensitivity of the element/layer"""
        self.set('sodipodi:insensitive', str((not sensitive)).lower())

    @property
    def unit(self):
        """Return the unit used by the owning document ('px' for fragments)"""
        try:
            return self.root.unit
        except FragmentError:
            return 'px'

    def uutounit(self, value, to_unit='px'):
        """Convert value to to_unit, with the document unit passed as convert_unit's default"""
        return convert_unit(value, to_unit, default=self.unit)

    def unittouu(self, value):
        """Convert a unit value into the document's units"""
        return convert_unit(value, self.unit)

    def add_unit(self, value):
        """Add document unit when no unit is specified in the string """
        return render_unit(value, self.unit)
461
462
463
class ShapeElement(BaseElement):
    """Elements which have a visible representation on the canvas"""
    @property
    def path(self):
        """Gets the outline or path of the element, this may be a simple bounding box"""
        return Path(self.get_path())

    @path.setter
    def path(self, path):
        self.set_path(path)

    @property
    def clip(self):
        """Gets the clip path element (if any)"""
        ref = self.get('clip-path')
        if not ref:
            return None
        # NOTE(review): ref is typically "url(#id)" (see the setter); assumes
        # root.getElementById accepts that form - confirm.
        return self.root.getElementById(ref)

    @clip.setter
    def clip(self, elem):
        self.set('clip-path', elem.get_id(as_url=2))

    def get_path(self):
        """Generate a path for this object which can inform the bounding box"""
        raise NotImplementedError(f"Path should be provided by svg elem {self.typename}.")

    def set_path(self, path):
        """Set the path for this object (if possible)"""
        raise AttributeError(
            f"Path can not be set on this element: {self.typename} <- {path}.")

    def to_path_element(self):
        """Return a new PathElement with this element's path, effective style and transform.

        The new element is not inserted into the tree; this element is left in place.
        """
        from ._polygons import PathElement
        elem = PathElement()
        elem.path = self.path
        elem.style = self.effective_style()
        elem.transform = self.transform
        return elem

    def composed_transform(self, other=None):
        """Calculate every transform down to the other element
          if none specified the transform is to the root document element

        NOTE(review): ``other`` is accepted but not used by this implementation;
        the result always composes the transforms of every consecutive
        ShapeElement ancestor with this element's own transform.
        """
        parent = self.getparent()
        if parent is not None and isinstance(parent, ShapeElement):
            return parent.composed_transform() * self.transform
        return self.transform

    def composed_style(self):
        """Calculate the final styles applied to this element"""
        parent = self.getparent()
        if parent is not None and isinstance(parent, ShapeElement):
            return parent.composed_style() + self.style
        return self.style

    def cascaded_style(self):
        """Add all cascaded styles, do not write to this Style object"""
        ret = Style()
        for style in self.root.stylesheets.lookup(self.get('id')):
            ret += style
        return ret + self.style

    def effective_style(self):
        """Without parent styles, what is the effective style is"""
        return self.style

    def bounding_box(self, transform=None):
        # type: (Union[Transform, bool, None]) -> Optional[BoundingBox]
        """BoundingBox of the shape (adjusted for its clip path if applicable)

        ``transform`` has the same meaning as in shape_box(): ``True`` uses the
        composed transform, otherwise it is an extra transform applied on top
        of this element's own transform.
        """
        shape_box = self.shape_box(transform)
        clip = self.clip
        if clip is None or shape_box is None:
            return shape_box
        # The clip must be mapped into the same space as shape_box. Passing
        # True straight into Transform() is not a valid transform.
        if transform is True:
            clip_transform = self.composed_transform()
        else:
            clip_transform = Transform(transform) * self.transform
        return shape_box & clip.bounding_box(clip_transform)

    def shape_box(self, transform=None):
        # type: (Union[Transform, bool, None]) -> Optional[BoundingBox]
        """BoundingBox of the unclipped shape

        ``transform=True`` applies the composed transform; any other truthy
        value is applied after this element's own transform.
        """
        path = self.path.to_absolute()
        if transform is True:
            path = path.transform(self.composed_transform())
        else:
            path = path.transform(self.transform)
            if transform:  # apply extra transformation
                path = path.transform(transform)
        return path.bounding_box()

    def is_visible(self):
        """Returns false if the css says this object is invisible"""
        if self.style.get('display', '') == 'none':
            return False
        if not float(self.style.get('opacity', 1.0)):
            return False
        return True
559