1# Copyright (c) 2010-2020 Manfred Moitzi
2# License: MIT License
3from typing import TYPE_CHECKING, Iterable, Tuple, Sequence
4from functools import lru_cache
5import math
6from ezdxf.math import Vec3, NULLVEC, Matrix44
7from .construct2d import linspace
8
9if TYPE_CHECKING:
10    from ezdxf.eztypes import Vertex
11
12__all__ = ['Bezier']
13
14"""
15
16Bezier curves
17=============
18
19https://www.cl.cam.ac.uk/teaching/2000/AGraphHCI/SMEG/node3.html
20
21A Bezier curve is a weighted sum of n+1 control points,  P0, P1, ..., Pn, where
22the weights are the Bernstein polynomials.
23
24The Bezier curve of order n+1 (degree n) has n+1 control points. These are the
25first three orders of Bezier curve definitions.
26
27(75) linear P(t) = (1-t)*P0 + t*P1
(76) quadratic P(t) = (1-t)^2*P0 + 2*(1-t)*t*P1 + t^2*P2
29(77) cubic P(t) = (1-t)^3*P0 + 3*(1-t)^2*t*P1 + 3*(1-t)*t^2*P2 + t^3*P3
30
31Ways of thinking about Bezier curves
32------------------------------------
33
34There are several useful ways in which you can think about Bezier curves.
35Here are the ones that I use.
36
37Linear interpolation
38~~~~~~~~~~~~~~~~~~~~
39
40Equation (75) is obviously a linear interpolation between two points. Equation
41(76) can be rewritten as a linear interpolation between linear interpolations
42between points.
43
44Weighted average
45~~~~~~~~~~~~~~~~
46
47A Bezier curve can be seen as a weighted average of all of its control points.
Because all of the weights are non-negative for t in [0, 1], and because the
weights sum to one, the Bezier curve is guaranteed to lie within the convex hull
of its control points.
50
51Refinement of the control polygon
52~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
53
54A Bezier curve can be seen as some sort of refinement of the polygon made by
55connecting its control points in order. The Bezier curve starts and ends at the
56two end points and its shape is determined by the relative positions of the n-1
57other control points, although it will generally not pass through these other
58control points. The tangent vectors at the start and end of the curve pass
59through the end point and the immediately adjacent point.
60
61Continuity
62----------
63
64You should note that each Bezier curve is independent of any other Bezier curve.
65If we wish two Bezier curves to join with any type of continuity, then we must
66explicitly position the control points of the second curve so that they bear
67the appropriate relationship with the control points in the first curve.
68
69Any Bezier curve is infinitely differentiable within itself, and is therefore
70continuous to any degree.
71
72"""
73
74
class Bezier:
    """ A `Bézier curve`_ is a parametric curve used in computer graphics and
    related fields. Bézier curves are used to model smooth curves that can be
    scaled indefinitely. "Paths", as they are commonly referred to in image
    manipulation programs, are combinations of linked Bézier curves.
    Paths are not bound by the limits of rasterized images and are intuitive to
    modify. (Source: Wikipedia)

    This is a generic implementation which works with any count of definition
    points (two or more), but it is a simple and slow implementation. For more
    performance look at the specialized :class:`Bezier4P` class.

    Objects are immutable.

    Args:
        defpoints: iterable of definition points as :class:`Vec3` compatible objects.

    """

    def __init__(self, defpoints: Iterable['Vertex']):
        self._defpoints: Sequence[Vec3] = Vec3.tuple(defpoints)

    @property
    def control_points(self) -> Sequence[Vec3]:
        """ Control points as tuple of :class:`Vec3` objects. """
        return self._defpoints

    def approximate(self, segments: int = 20) -> Iterable[Vec3]:
        """ Approximates curve by vertices as :class:`Vec3` objects, vertices
        count = segments + 1.
        """
        return self.points(self.params(segments))

    def flattening(self, distance: float,
                   segments: int = 4) -> Iterable[Vec3]:
        """ Adaptive recursive flattening. The argument `segments` is the
        minimum count of approximation segments, if the distance from the center
        of the approximation segment to the curve is bigger than `distance` the
        segment will be subdivided.

        Args:
            distance: maximum distance from the center of the curve (Cn)
                to the center of the linear (C1) curve between two
                approximation points to determine if a segment should be
                subdivided. Must be > 0, and should be well above floating
                point precision, otherwise the subdivision never terminates.
            segments: minimum segment count, values < 1 are treated as 1

        Raises:
            ValueError: `distance` is not a positive value

        .. versionadded:: 0.15

        """
        if distance <= 0.0:
            # "chord-center distance < distance" can never be true for a
            # non-positive distance, which would recurse without end.
            raise ValueError('Parameter distance has to be > 0')

        def subdiv(start_point, end_point, start_t: float, end_t: float):
            mid_t = (start_t + end_t) * 0.5
            mid_point = self.point(mid_t)
            # Comparing the curve point at mid_t with the center of the chord
            # is faster than projecting the mid point onto the vector
            # start -> end:
            chk_point = start_point.lerp(end_point)
            if chk_point.distance(mid_point) < distance:
                yield end_point
            else:
                yield from subdiv(start_point, mid_point, start_t, mid_t)
                yield from subdiv(mid_point, end_point, mid_t, end_t)

        segments = max(int(segments), 1)
        defpoints = self._defpoints
        start_point = defpoints[0]
        yield start_point
        t0 = 0.0
        # Compute each segment parameter from its index instead of
        # accumulating a float step, so no rounding drift can build up and
        # the last segment ends exactly at t = 1.
        for k in range(1, segments + 1):
            if k == segments:
                t1 = 1.0
                end_point = defpoints[-1]
            else:
                t1 = k / segments
                end_point = self.point(t1)
            yield from subdiv(start_point, end_point, t0, t1)
            t0 = t1
            start_point = end_point

    def params(self, segments: int) -> Iterable[float]:
        """ Yield evenly spaced parameters from 0 to 1 for given segment count. """
        yield from linspace(0.0, 1.0, segments + 1)

    def point(self, t: float) -> Vec3:
        """
        Returns a point for parameter `t` in range [0, 1] as :class:`Vec3` object.

        Raises:
            ValueError: `t` outside of the range [0, 1]
        """
        if t < 0.0 or t > 1.0:
            raise ValueError('Parameter t not in range [0, 1]')
        if (1.0 - t) < 5e-6:
            # snap parameters very close to the end onto the end
            t = 1.0
        point = NULLVEC
        pts = self._defpoints
        n = len(pts)

        for i in range(n):
            point += bernstein_basis(n - 1, i, t) * pts[i]
        return point

    def points(self, t: Iterable[float]) -> Iterable[Vec3]:
        """ Yields multiple points for parameters in vector `t` as :class:`Vec3` objects.
        Parameters have to be in range [0, 1].
        """
        for u in t:
            yield self.point(u)

    def derivative(self, t: float) -> Tuple[Vec3, Vec3, Vec3]:
        """
        Returns (point, 1st derivative, 2nd derivative) tuple for parameter `t` in range [0, 1]
        as :class:`Vec3` objects.

        Inside the open interval (0, 1) the derivatives are derived from the
        Bernstein basis values. The basis derivative formulas divide by
        t*(1-t), so at the end points (t == 0 and t == 1) the closed
        form end derivatives of the control polygon are used instead.
        Derivatives that do not exist for the curve degree (1st derivative of
        a single point, 2nd derivative of a linear curve) are returned as
        null vectors.

        Raises:
            ValueError: `t` outside of the range [0, 1]
        """
        if t < 0.0 or t > 1.0:
            raise ValueError('Parameter t not in range [0, 1]')

        if (1.0 - t) < 5e-6:
            t = 1.0
        pts = self._defpoints
        n0 = len(pts) - 1  # curve degree
        point = NULLVEC
        d1 = NULLVEC
        d2 = NULLVEC
        interior = 0.0 < t < 1.0
        if interior:
            t2 = t * t
            _1_t = 1.0 - t
        for i in range(len(pts)):
            tmp_bas = bernstein_basis(n0, i, t)
            point += tmp_bas * pts[i]
            if interior:
                # B'  = B * (i - n*t) / (t*(1-t))
                # B'' = B * ((i - n*t)^2 - n*t^2 - i*(1-2t)) / (t*(1-t))^2
                i_n0_t = i - n0 * t
                d1 += i_n0_t / (t * _1_t) * tmp_bas * pts[i]
                d2 += (i_n0_t * i_n0_t - n0 * t2 - i * (1. - 2. * t)) / (t2 * _1_t * _1_t) * tmp_bas * pts[i]

        if t == 0.0:
            if n0 >= 1:
                d1 = n0 * (pts[1] - pts[0])
            if n0 >= 2:  # a linear curve has no pts[2]
                d2 = n0 * (n0 - 1) * (pts[0] - 2. * pts[1] + pts[2])
        elif t == 1.0:
            if n0 >= 1:
                d1 = n0 * (pts[n0] - pts[n0 - 1])
            if n0 >= 2:  # avoid negative index wrap-around for low degrees
                d2 = n0 * (n0 - 1) * (pts[n0] - 2. * pts[n0 - 1] + pts[n0 - 2])
        return point, d1, d2

    def derivatives(self, t: Iterable[float]) -> Iterable[Tuple[Vec3, Vec3, Vec3]]:
        """
        Returns multiple (point, 1st derivative, 2nd derivative) tuples for parameter vector  `t` as :class:`Vec3` objects.
        Parameters in range [0, 1]
        """
        for u in t:
            yield self.derivative(u)

    def reverse(self) -> 'Bezier':
        """ Returns a new Bèzier-curve with reversed control point order. """
        return Bezier(list(reversed(self.control_points)))

    def transform(self, m: Matrix44) -> 'Bezier':
        """ General transformation interface, returns a new :class:`Bezier` curve.

        Args:
             m: 4x4 transformation matrix (:class:`ezdxf.math.Matrix44`)

        .. versionadded:: 0.14

        """
        defpoints = tuple(m.transform_vertices(self.control_points))
        return Bezier(defpoints)
237
238
def bernstein_basis(n: int, i: int, t: float) -> float:
    """ Returns the value of the Bernstein basis polynomial b(i, n) at `t`:

        C(n, i) * t^i * (1 - t)^(n - i)

    with the binomial coefficient C(n, i) = n! / (i! * (n - i)!).
    """
    # The two "zero to the power of zero" cases are set to 1.0 explicitly
    # instead of being left to pow().
    t_pow_i = 1.0 if (i == 0 and t == 0.0) else pow(t, i)
    rest_pow = 1.0 if (i == n and t == 1.0) else pow((1.0 - t), (n - i))
    binomial = factorial(n) / (factorial(i) * factorial(n - i))
    return binomial * t_pow_i * rest_pow
251
252
@lru_cache(maxsize=16)
def factorial(n: int):
    """ Returns n! via :func:`math.factorial`. The results are memoized
    for the 16 most recently used arguments.
    """
    result = math.factorial(n)
    return result
256