1# cython: language_level=3
2# distutils: language = c++
3# Copyright (c) 2021, Manfred Moitzi
4# License: MIT License
5# Cython implementation of the B-spline basis function.
6
7from typing import List, Iterable, Sequence, Tuple
8import cython
9from cpython.mem cimport PyMem_Malloc, PyMem_Free
10from .vector cimport Vec3, isclose, v3_mul, v3_sub, v3_from_cpp_vec3
11from ._cpp_vec3 cimport CppVec3
12
__all__ = ['Basis', 'Evaluator']

# Lookup table of n! for n = 0 .. 18, used by binomial_coefficient().
cdef double[19] FACTORIAL = [
    1., 1., 2., 6., 24., 120., 720., 5040., 40320., 362880., 3628800.,
    39916800., 479001600., 6227020800., 87178291200., 1307674368000.,
    20922789888000., 355687428096000., 6402373705728000.
]

# Single element templates, multiplied (e.g. NULL_LIST * n) to build
# new lists of n zeros or n ones.
NULL_LIST = [0.0]
ONE_LIST = [1.0]

# NOTE(review): NULLVEC is not referenced in the code visible here.
cdef Vec3 NULLVEC = Vec3()
# Compile-time tolerances passed to isclose() to snap a curve parameter
# onto the end of the knot vector (see Evaluator.point/derivative).
DEF ABS_TOL = 1e-12
DEF REL_TOL = 1e-9

# AutoCAD limits the degree to 11 or order = 12
# Also the size of the fixed C work arrays used by the Basis class.
DEF MAX_ORDER = 12
31
@cython.cdivision(True)
cdef double binomial_coefficient(int k, int i):
    """ Returns the binomial coefficient "k choose i" as double.

    Returns 0.0 for i < 0 or i > k. The FACTORIAL table covers 0! .. 18!,
    therefore k must not exceed 18 (C arrays are not bounds checked).
    """
    # Validate before any table lookup: the original code read FACTORIAL[i]
    # before testing i > k, which reads past the table for i > 18 and
    # before it for negative i.
    if i < 0 or i > k:
        return 0.0
    return FACTORIAL[k] / (FACTORIAL[k - i] * FACTORIAL[i])
41
@cython.boundscheck(False)
cdef int bisect_right(double*a, double x, int lo, int hi):
    """ Returns the insertion index of `x` into the ascending sorted C array
    `a`, searching the half-open index range [lo, hi). The index is placed
    right of all existing entries equal to `x`, like `bisect.bisect_right`.
    """
    cdef int middle
    while hi > lo:
        middle = (lo + hi) // 2
        if a[middle] > x:
            # x belongs into the left half
            hi = middle
        else:
            # x is >= a[middle]: continue right of middle
            lo = middle + 1
    return lo
52
cdef reset_double_array(double *a, int count, double value=0.0):
    # Overwrite the first `count` entries of the C array `a` with `value`.
    cdef int index = 0
    while index < count:
        a[index] = value
        index += 1
57
cdef class Basis:
    """ Immutable Basis function class.

    Evaluates B-spline basis functions and their derivatives by the
    algorithms of The NURBS Book by Piegl & Tiller.

    Args:
        knots: knot vector, exactly ``order + count`` values
        order: spline order (degree + 1), 2 <= order <= MAX_ORDER
        count: count of control points, >= 2
        weights: optional weights, one per control point; a non-empty
            weight sequence makes the basis rational

    Raises:
        ValueError: invalid order, count, weight count or knot count
        MemoryError: allocation of the knot array failed

    """
    # public:
    cdef readonly int order
    cdef readonly int count
    cdef readonly double max_t
    cdef tuple weights_  # accessed directly by the Cython Evaluator class
    # private:
    cdef double*_knots
    cdef int knot_count

    def __cinit__(self, knots: Iterable[float], int order, int count,
                  weights: Sequence[float] = None):
        # The C work arrays in basis_funcs() and basis_funcs_derivatives()
        # hold MAX_ORDER values, so order == MAX_ORDER (degree 11, the
        # AutoCAD limit) is still supported.
        if order < 2 or order > MAX_ORDER:
            raise ValueError('invalid order')
        self.order = order
        if count < 2:
            raise ValueError('invalid count')
        self.count = count
        self.knot_count = self.order + self.count
        # Test for None explicitly, the truth value of array-like
        # sequences (e.g. numpy arrays) is ambiguous.
        self.weights_ = (
            tuple(float(x) for x in weights) if weights is not None
            else tuple()
        )

        cdef Py_ssize_t i = len(self.weights_)
        if i != 0 and i != self.count:
            raise ValueError('invalid weight count')

        knots = [float(x) for x in knots]
        if len(knots) != self.knot_count:
            raise ValueError('invalid knot count')

        self._knots = <double *> PyMem_Malloc(self.knot_count * sizeof(double))
        if self._knots == NULL:
            raise MemoryError()
        for i in range(self.knot_count):
            self._knots[i] = knots[i]
        self.max_t = self._knots[self.knot_count - 1]

    def __dealloc__(self):
        # PyMem_Free(NULL) is a no-op, safe if __cinit__ failed early.
        PyMem_Free(self._knots)

    @property
    def degree(self) -> int:
        return self.order - 1

    @property
    def knots(self) -> Tuple[float, ...]:
        # Slicing a double* yields a Python list of floats.
        return tuple(self._knots[:self.knot_count])

    @property
    def weights(self) -> Tuple[float, ...]:
        return self.weights_

    @property
    def is_rational(self) -> bool:
        """ Returns ``True`` if curve is a rational B-spline. (has weights) """
        return bool(self.weights_)

    cpdef list basis_vector(self, double t):
        """ Returns the expanded basis vector, one value per control point;
        entries outside the active span are 0.0.
        """
        cdef int span = self.find_span(t)
        cdef int p = self.order - 1
        cdef int front = span - p
        cdef int back = self.count - span - 1
        cdef list result
        if front > 0:
            result = NULL_LIST * front
            result.extend(self.basis_funcs(span, t))
        else:
            result = self.basis_funcs(span, t)
        if back > 0:
            result.extend(NULL_LIST * back)
        return result

    cpdef int find_span(self, double u):
        """ Determine the knot span index. """
        # Linear search is more reliable than binary search of the Algorithm A2.1
        # from The NURBS Book by Piegl & Tiller.
        cdef double*knots = self._knots
        cdef int count = self.count  # text book: n+1
        cdef int p = self.order - 1
        cdef int span
        if u >= knots[count]:  # special case
            return count - 1

        # common clamped spline:
        if knots[p] == 0.0:  # use binary search
            # This is fast and works most of the time,
            # but Test 621 : test_weired_closed_spline()
            # goes into an infinite loop, because of
            # a weird knot configuration.
            return bisect_right(knots, u, p, count) - 1
        else:  # use linear search
            span = 0
            # check the index bound before reading the knot value
            while span < count and knots[span] <= u:
                span += 1
            return span - 1

    cpdef list basis_funcs(self, int span, double u):
        """ Returns the `order` non-zero basis function values of `span` at
        parameter `u`, weighted by span_weighting() for rational splines.
        """
        # Source: The NURBS Book: Algorithm A2.2
        cdef int order = self.order
        cdef double*knots = self._knots
        cdef double[MAX_ORDER] N, left, right
        cdef list result
        reset_double_array(N, order)
        reset_double_array(left, order)
        reset_double_array(right, order)

        cdef int j, r, i1
        cdef double temp, saved, temp_r, temp_l
        N[0] = 1.0
        for j in range(1, order):
            i1 = span + 1 - j
            if i1 < 0:  # guard against reading before the knot array
                i1 = 0
            left[j] = u - knots[i1]
            right[j] = knots[span + j] - u
            saved = 0.0
            for r in range(j):
                temp_r = right[r + 1]
                temp_l = left[j - r]
                temp = N[r] / (temp_r + temp_l)
                N[r] = saved + temp_r * temp
                saved = temp_l * temp
            N[j] = saved
        result = [x for x in N[:order]]
        if self.weights_:  # rational B-spline
            return self.span_weighting(result, span)
        else:
            return result

    cpdef list span_weighting(self, nbasis: List[float], int span):
        """ Returns the rational basis values: each basis value multiplied by
        its control point weight, normalized by the sum of these products.
        Returns zeros if that sum is 0.
        """
        cdef list products = [
            nb * w for nb, w in zip(
                nbasis,
                self.weights_[span - self.order + 1: span + 1]
            )
        ]
        s = sum(products)
        if s != 0:
            return [p / s for p in products]
        else:
            return NULL_LIST * len(nbasis)

    cpdef list basis_funcs_derivatives(self, int span, double u, int n = 1):
        """ Returns the basis functions of `span` at `u` and their
        derivatives up to `n` as list of rows [N, N', ..., N(n)];
        `n` is clamped to the degree.
        """
        # Source: The NURBS Book: Algorithm A2.3
        cdef int order = self.order
        cdef int p = order - 1
        if n > p:
            n = p
        cdef double*knots = self._knots
        cdef double[MAX_ORDER] left, right
        reset_double_array(left, order, 1.0)
        reset_double_array(right, order, 1.0)

        cdef list ndu = [ONE_LIST * order for _ in range(order)]
        cdef int j, r, i1
        cdef double temp, saved, tmp_r, tmp_l
        for j in range(1, order):
            i1 = span + 1 - j
            if i1 < 0:  # guard against reading before the knot array
                i1 = 0
            left[j] = u - knots[i1]
            right[j] = knots[span + j] - u
            saved = 0.0
            for r in range(j):
                # lower triangle
                tmp_r = right[r + 1]
                tmp_l = left[j - r]
                ndu[j][r] = tmp_r + tmp_l
                temp = ndu[r][j - 1] / ndu[j][r]
                # upper triangle
                ndu[r][j] = saved + (tmp_r * temp)
                saved = tmp_l * temp
            ndu[j][j] = saved

        # load the basis_vector functions
        cdef list derivatives = [NULL_LIST * order for _ in range(order)]
        for j in range(order):
            derivatives[0][j] = ndu[j][p]

        # loop over function index
        cdef list a = [ONE_LIST * order, ONE_LIST * order]
        cdef int s1, s2, k, rk, pk, j1, j2, t
        cdef double d
        for r in range(order):
            s1 = 0
            s2 = 1
            # alternate rows in array a
            a[0][0] = 1.0

            # loop to compute kth derivative
            for k in range(1, n + 1):
                d = 0.0
                rk = r - k
                pk = p - k
                if r >= k:
                    a[s2][0] = a[s1][0] / ndu[pk + 1][rk]
                    d = a[s2][0] * ndu[rk][pk]
                if rk >= -1:
                    j1 = 1
                else:
                    j1 = -rk
                if (r - 1) <= pk:
                    j2 = k - 1
                else:
                    j2 = p - r
                for j in range(j1, j2 + 1):
                    a[s2][j] = (a[s1][j] - a[s1][j - 1]) / ndu[pk + 1][rk + j]
                    d += (a[s2][j] * ndu[rk + j][pk])
                if r <= pk:
                    a[s2][k] = -a[s1][k - 1] / ndu[pk + 1][r]
                    d += (a[s2][k] * ndu[r][pk])
                derivatives[k][r] = d

                # Switch rows
                t = s1
                s1 = s2
                s2 = t

        # Multiply through by the correct factors
        cdef double rr = p
        for k in range(1, n + 1):
            for j in range(order):
                derivatives[k][j] *= rr
            rr *= (p - k)
        return derivatives[:n + 1]
283
cdef class Evaluator:
    """ B-spline curve point and curve derivative evaluator.

    Args:
        basis: :class:`Basis` of the B-spline
        control_points: control points, stored as a tuple of :class:`Vec3`

    """
    cdef Basis _basis
    cdef tuple _control_points

    def __cinit__(self, basis: Basis, control_points: Sequence[Vec3]):
        self._basis = basis
        self._control_points = Vec3.tuple(control_points)

    cdef CppVec3 control_point(self, int index):
        # Returns control point `index` as CppVec3 for fast C++ arithmetic.
        cdef Vec3 v3 = <Vec3> self._control_points[index]
        return v3.to_cpp_vec3()

    cpdef Vec3 point(self, double u):
        """ Returns the curve point for parameter `u`. """
        # Source: The NURBS Book: Algorithm A3.1
        cdef Basis basis = self._basis
        cdef int p = basis.order - 1
        # snap parameters very close to the end onto the last knot value
        if isclose(u, basis.max_t, REL_TOL, ABS_TOL):
            u = basis.max_t

        cdef int span = basis.find_span(u)
        cdef list N = basis.basis_funcs(span, u)
        cdef int i
        cdef CppVec3 sum_ = CppVec3(), cpoint
        for i in range(p + 1):
            cpoint = self.control_point(span - p + i)
            sum_ = sum_ + (cpoint * <double> N[i])
        return v3_from_cpp_vec3(sum_)

    def points(self, t: Iterable[float]) -> Iterable[Vec3]:
        """ Yields the curve points for all parameters in `t`. """
        cdef double u
        for u in t:
            yield self.point(u)

    cpdef list derivative(self, double u, int n = 1):
        """ Return point and derivatives up to n <= degree for parameter u.

        `n` is clamped to the degree, like Basis.basis_funcs_derivatives()
        does, so the result has at most degree + 1 entries.
        """
        # Source: The NURBS Book: Algorithm A3.2
        cdef Vec3 vec3sum
        cdef CppVec3 cppsum
        cdef list CK = [], CKw = [], wders = []
        cdef tuple weights
        cdef Basis basis = self._basis
        # snap parameters very close to the end onto the last knot value
        if isclose(u, basis.max_t, REL_TOL, ABS_TOL):
            u = basis.max_t

        cdef int p = basis.order - 1
        # basis_funcs_derivatives() returns only p + 1 rows for n > p,
        # iterating over the unclamped n raised an IndexError.
        if n > p:
            n = p
        cdef int span = basis.find_span(u)
        cdef list basis_funcs_ders = basis.basis_funcs_derivatives(span, u, n)
        cdef int k, j, i
        cdef double wder, bas_func_weight, bas_func
        if basis.weights_:  # rational B-spline
            # Homogeneous point representation required:
            # (x*w, y*w, z*w, w)
            weights = basis.weights_
            for k in range(n + 1):
                cppsum = CppVec3()
                wder = 0.0
                for j in range(p + 1):
                    i = span - p + j
                    bas_func_weight = basis_funcs_ders[k][j] * weights[i]
                    # control_point * weight * bas_func_der = (x*w, y*w, z*w) * bas_func_der
                    cppsum = cppsum + (self.control_point(i) * bas_func_weight)
                    wder += bas_func_weight
                CKw.append(v3_from_cpp_vec3(cppsum))
                wders.append(wder)

            # Source: The NURBS Book: Algorithm A4.2
            for k in range(n + 1):
                vec3sum = CKw[k]
                for j in range(1, k + 1):
                    bas_func_weight = binomial_coefficient(k, j) * wders[j]
                    vec3sum = v3_sub(vec3sum,
                                     v3_mul(CK[k - j], bas_func_weight))
                CK.append(vec3sum / wders[0])
        else:
            for k in range(n + 1):
                cppsum = CppVec3()
                for j in range(p + 1):
                    bas_func = basis_funcs_ders[k][j]
                    cppsum = cppsum + (
                            self.control_point(span - p + j) * bas_func
                    )
                CK.append(v3_from_cpp_vec3(cppsum))
        return CK

    def derivatives(
            self, t: Iterable[float], int n = 1) -> Iterable[List[Vec3]]:
        """ Yields point and derivatives up to `n` for all parameters in `t`. """
        cdef double u
        for u in t:
            yield self.derivative(u, n)
374