1import numpy as np
3from string import Template
6from ..arc import arc_center
7from ..entities import Line, Arc, Bezier
9from ...constants import log, tol
10from ...constants import res_path as res
12from ... import util
13from ... import grouping
14from ... import resources
15from ... import exceptions
17from ... transformations import transform_points, planar_matrix
20    # pip install svg.path
21    from svg.path import parse_path
22except BaseException as E:
23    # will re-raise the import exception when
24    # someone tries to call `parse_path`
25    parse_path = exceptions.closure(E)
28    from lxml import etree
29except BaseException as E:
30    # will re-raise the import exception when
31    # someone actually tries to use the module
32    etree = exceptions.ExceptionModule(E)
def svg_to_path(file_obj, file_type=None):
    """
    Load an SVG file into a Path2D object.

    Every `path` element in the document (in any XML namespace)
    is collected along with the transform accumulated from the
    element and its ancestors.

    Parameters
    -----------
    file_obj : open file object
      Contains SVG data
    file_type: None
      Not used

    Returns
    -----------
    loaded : dict
      With kwargs for Path2D constructor
    """

    def element_transform(e, max_depth=100):
        """
        Find a transformation matrix for an XML element.

        Collects the `transform` attribute of the element and of
        each ancestor, then combines them into one matrix.

        Parameters
        -----------
        e : XML element
          Element to find the total transform for
        max_depth : int
          Maximum number of levels (element plus ancestors)
          that will be inspected

        Returns
        -----------
        matrix : (3, 3) float
          Homogeneous transform, identity if none were specified
        """
        # ordered outermost ancestor first, element itself last
        matrices = []
        current = e
        for _ in range(max_depth):
            if 'transform' in current.attrib:
                # an ancestor applies outside (to the left of) everything
                # collected so far, while the transforms listed in one
                # attribute keep the order they were written in:
                # transform="A B" is equivalent to the matrix A * B
                matrices[:0] = transform_to_matrices(
                    current.attrib['transform'])
            current = current.getparent()
            if current is None:
                break

        if len(matrices) == 0:
            return np.eye(3)
        elif len(matrices) == 1:
            return matrices[0]
        # already in left-to-right multiplication order
        return util.multi_dot(matrices)

    # first parse the XML
    xml = etree.fromstring(file_obj.read())

    # store every path element as (path string, 3x3 matrix)
    paths = [(element.attrib['d'], element_transform(element))
             for element in xml.iter('{*}path')]

    return _svg_path_convert(paths)
def transform_to_matrices(transform):
    """
    Convert an SVG transform string to an array of matrices.

    i.e. "rotate(-10 50 100)
          translate(-36 45.5)
          skewX(40)
          scale(1 0.5)"

    Matrices are returned in the order they were written and are
    not multiplied together. Operations which are not recognized
    are logged as a warning and skipped.

    Parameters
    -----------
    transform : str
      Contains transformation information in SVG form

    Returns
    -----------
    matrices : (n, 3, 3) float
      Multiple transformation matrices from input transform string

    Raises
    -----------
    ValueError
      If a component can't be split into (operation, arguments)
    """
    # split the transform string in to components of:
    # (operation, args) i.e. (translate, '-1.0, 2.0')
    components = [
        [j.strip() for j in i.strip().split('(') if len(j) > 0]
        for i in transform.lower().split(')') if len(i) > 0]
    # store each matrix without dotting
    matrices = []
    for line in components:
        if len(line) == 0:
            continue
        elif len(line) != 2:
            raise ValueError('should always have two components!')
        # note the key has been lower-cased above
        key, args = line
        # convert string args to array of floats
        # support either comma or space delimiter
        values = np.array([float(i) for i in
                           args.replace(',', ' ').split()])
        if key == 'translate':
            # convert translation to a (3, 3) homogeneous matrix
            mat = np.eye(3)
            if len(values) == 1:
                # `translate(tx)` means ty is zero: assigning the
                # single value directly would broadcast it into Y
                mat[0, 2] = values[0]
            else:
                mat[:2, 2] = values
            matrices.append(mat)
        elif key == 'matrix':
            # [a b c d e f] ->
            # [[a c e],
            #  [b d f],
            #  [0 0 1]]
            matrices.append(np.vstack((
                values.reshape((3, 2)).T, [0, 0, 1])))
        elif key == 'rotate':
            # SVG rotations are in degrees so convert to radians
            # NOTE(review): confirm the rotation direction of
            # `planar_matrix` matches the SVG convention
            angle = np.radians(values[0])
            # if there are three values rotate around point
            if len(values) == 3:
                point = values[1:]
            else:
                point = None
            matrices.append(planar_matrix(theta=angle,
                                          point=point))
        elif key == 'scale':
            # supports (x_scale, y_scale) or (scale)
            mat = np.eye(3)
            mat[:2, :2] *= values
            matrices.append(mat)
        elif key == 'skewx':
            # skew along X by an angle in degrees:
            # x' = x + y * tan(angle)
            mat = np.eye(3)
            mat[0, 1] = np.tan(np.radians(values[0]))
            matrices.append(mat)
        elif key == 'skewy':
            # skew along Y by an angle in degrees:
            # y' = y + x * tan(angle)
            mat = np.eye(3)
            mat[1, 0] = np.tan(np.radians(values[0]))
            matrices.append(mat)
        else:
            log.warning('unknown SVG transform: {}'.format(key))

    return matrices
156def _svg_path_convert(paths):
157    """
158    Convert an SVG path string into a Path2D object
160    Parameters
161    -------------
162    paths: list of tuples
163      Containing (path string, (3, 3) matrix)
165    Returns
166    -------------
167    drawing : dict
168      Kwargs for Path2D constructor
169    """
170    def complex_to_float(values):
171        return np.array([[i.real, i.imag] for i in values])
173    def load_multi(multi):
174        # load a previously parsed multiline
175        return Line(np.arange(len(multi.points)) + count), multi.points
177    def load_arc(svg_arc):
178        # load an SVG arc into a trimesh arc
179        points = complex_to_float([svg_arc.start,
180                                   svg_arc.point(.5),
181                                   svg_arc.end])
182        return Arc(np.arange(3) + count), points
184    def load_quadratic(svg_quadratic):
185        # load a quadratic bezier spline
186        points = complex_to_float([svg_quadratic.start,
187                                   svg_quadratic.control,
188                                   svg_quadratic.end])
189        return Bezier(np.arange(3) + count), points
191    def load_cubic(svg_cubic):
192        # load a cubic bezier spline
193        points = complex_to_float([svg_cubic.start,
194                                   svg_cubic.control1,
195                                   svg_cubic.control2,
196                                   svg_cubic.end])
197        return Bezier(np.arange(4) + count), points
199    # store loaded values here
200    entities = []
201    vertices = []
202    # how many vertices have we loaded
203    count = 0
204    # load functions for each entity
205    loaders = {'Arc': load_arc,
206               'MultiLine': load_multi,
207               'CubicBezier': load_cubic,
208               'QuadraticBezier': load_quadratic}
210    class MultiLine(object):
211        # An object to hold one or multiple Line entities.
212        def __init__(self, lines):
213            if tol.strict:
214                # in unit tests make sure we only have lines
215                assert all(type(L).__name__ == 'Line'
216                           for L in lines)
217            # get the starting point of every line
218            points = [L.start for L in lines]
219            # append the endpoint
220            points.append(lines[-1].end)
221            # convert to (n, 2) float points
222            self.points = np.array([[i.real, i.imag]
223                                    for i in points],
224                                   dtype=np.float64)
226    for path_string, matrix in paths:
227        # get parsed entities from svg.path
228        raw = np.array(list(parse_path(path_string)))
229        # check to see if each entity is a Line
230        is_line = np.array([type(i).__name__ == 'Line'
231                            for i in raw])
232        # find groups of consecutive lines so we can combine them
233        blocks = grouping.blocks(
234            is_line, min_len=1, only_nonzero=False)
235        if tol.strict:
236            # in unit tests make sure we didn't lose any entities
237            assert np.allclose(np.hstack(blocks),
238                               np.arange(len(raw)))
240        # Combine consecutive lines into a single MultiLine
241        parsed = []
242        for b in blocks:
243            if type(raw[b[0]]).__name__ == 'Line':
244                # if entity consists of lines add a multiline
245                parsed.append(MultiLine(raw[b]))
246            else:
247                # otherwise just add the entities
248                parsed.extend(raw[b])
249        # loop through parsed entity objects
250        for svg_entity in parsed:
251            # keyed by entity class name
252            type_name = type(svg_entity).__name__
253            if type_name in loaders:
254                # get new entities and vertices
255                e, v = loaders[type_name](svg_entity)
256                # append them to the result
257                entities.append(e)
258                # create a sequence of vertex arrays
259                vertices.append(transform_points(v, matrix))
260                count += len(vertices[-1])
262    # store results as kwargs and stack vertices
263    loaded = {'entities': np.array(entities),
264              'vertices': np.vstack(vertices)}
265    return loaded
def export_svg(drawing,
               return_path=False,
               layers=None,
               **kwargs):
    """
    Export a Path2D object into an SVG file.

    Arcs are exported as SVG arc commands and every other
    entity is exported as a polyline of its discrete form.

    Parameters
    -----------
    drawing : Path2D
     Source geometry
    return_path : bool
      If True return only path string not wrapped in XML
    layers : None, or [str]
      Only export specified layers
    **kwargs
      `stroke_width` : float, width of exported stroke which
      defaults to the largest extent divided by 800

    Returns
    -----------
    as_svg : str
      XML formatted SVG, or path string

    Raises
    -----------
    ValueError
      If drawing is not a Path2D object
    """
    if not util.is_instance_named(drawing, 'Path2D'):
        raise ValueError('drawing must be Path2D object!')

    # copy the points and make sure they're not a TrackedArray
    points = drawing.vertices.view(np.ndarray).copy()

    def circle_to_svgpath(center, radius, reverse):
        # a closed circle is written as two half-circle arcs
        # starting from the left-most point of the circle
        radius_str = format(radius, res.export)
        path_str = ' M ' + format(center[0] - radius, res.export) + ','
        path_str += format(center[1], res.export)
        path_str += ' a ' + radius_str + ',' + radius_str
        path_str += ',0,1,' + str(int(reverse)) + ','
        path_str += format(2 * radius, res.export) + ',0'
        path_str += ' a ' + radius_str + ',' + radius_str
        path_str += ',0,1,' + str(int(reverse)) + ','
        path_str += format(-2 * radius, res.export) + ',0 Z'
        return path_str

    def svg_arc(arc, reverse):
        """
        arc string: (rx ry x-axis-rotation large-arc-flag sweep-flag x y)+
        large-arc-flag: greater than 180 degrees
        sweep flag: direction (cw/ccw)
        """
        # a step of -1 flips the order of the three control points
        arc_idx = arc.points[::((reverse * -2) + 1)]
        vertices = points[arc_idx]
        vertex_start, vertex_mid, vertex_end = vertices
        center_info = arc_center(vertices)
        C, R, angle = (center_info['center'],
                       center_info['radius'],
                       center_info['span'])
        if arc.closed:
            return circle_to_svgpath(C, R, reverse)

        large_flag = str(int(angle > np.pi))
        # the sign of the cross product gives the side of the
        # start-end chord that the middle point is on
        sweep_flag = str(int(np.cross(vertex_mid - vertex_start,
                                      vertex_end - vertex_start) > 0.0))

        arc_str = move_to(arc_idx[0])
        arc_str += 'A {},{} 0 {}, {} {},{}'.format(R,
                                                   R,
                                                   large_flag,
                                                   sweep_flag,
                                                   vertex_end[0],
                                                   vertex_end[1])
        return arc_str

    def move_to(vertex_id):
        # an absolute move command to an existing vertex
        x_ex = format(points[vertex_id][0], res.export)
        y_ex = format(points[vertex_id][1], res.export)
        move_str = ' M ' + x_ex + ',' + y_ex
        return move_str

    def svg_discrete(entity, reverse):
        """
        Use an entities discrete representation to export a
        curve as a polyline
        """
        discrete = entity.discrete(points)
        # if entity contains no geometry return
        if len(discrete) == 0:
            return ''
        # are we reversing the entity
        if reverse:
            discrete = discrete[::-1]
        # the format string for the SVG path
        template = ' M {},{} ' + (' L {},{}' * (len(discrete) - 1))
        # apply the data from the discrete curve
        result = template.format(*discrete.reshape(-1))
        return result

    def convert_entity(entity, reverse=False):
        # entities on layers which weren't requested export nothing
        if layers is not None and entity.layer not in layers:
            return ''
        # the class name of the entity
        etype = entity.__class__.__name__
        if etype == 'Arc':
            # export the exact version of the entity
            return svg_arc(entity, reverse=reverse)
        # just export the polyline version of the entity
        return svg_discrete(entity, reverse=reverse)

    # convert each entity to an SVG entity
    converted = [convert_entity(e) for e in drawing.entities]

    # append list of converted into a string
    path_str = ''.join(converted).strip()

    # return path string without XML wrapping
    if return_path:
        return path_str

    # the template is only needed for the XML wrapped result
    template_svg = Template(resources.get('svg.template.xml'))

    # format as XML
    if 'stroke_width' in kwargs:
        stroke_width = float(kwargs['stroke_width'])
    else:
        stroke_width = drawing.extents.max() / 800.0
    subs = {'PATH_STRING': path_str,
            'MIN_X': points[:, 0].min(),
            'MIN_Y': points[:, 1].min(),
            'WIDTH': drawing.extents[0],
            'HEIGHT': drawing.extents[1],
            'STROKE': stroke_width}
    result = template_svg.substitute(subs)
    return result