1import cython
2import numpy as np
3
4cimport numpy as cnp
5from numpy cimport (
6    float32_t,
7    float64_t,
8    int8_t,
9    int16_t,
10    int32_t,
11    int64_t,
12    ndarray,
13    uint8_t,
14)
15
16cnp.import_array()
17
18
19# -----------------------------------------------------------------------------
20# Preamble stuff
21
# C-level double constants for "not a number" and positive infinity.
# ``np.nan`` is used instead of the ``np.NaN`` alias, which was removed in
# NumPy 2.0 and would make this module fail at import time.
cdef float64_t NaN = <float64_t>np.nan
cdef float64_t INF = <float64_t>np.inf
24
25# -----------------------------------------------------------------------------
26
27
cdef class SparseIndex:
    """
    Common base type for the sparse index classes.

    This class is not meant to be instantiated on its own: its initializer
    always raises ``NotImplementedError``.
    """

    def __init__(self):
        # Concrete index types provide their own initializer.
        raise NotImplementedError()
35
36
cdef class IntIndex(SparseIndex):
    """
    Object for holding exact integer sparse indexing information

    Parameters
    ----------
    length : integer
        Exclusive upper bound for every entry of ``indices``.
    indices : array-like
        Contains integers corresponding to the indices.  Stored as a
        contiguous int32 array.
    check_integrity : bool, default=True
        Check integrity of the input.
    """

    cdef readonly:
        Py_ssize_t length, npoints
        ndarray indices

    def __init__(self, Py_ssize_t length, indices, bint check_integrity=True):
        self.length = length
        self.indices = np.ascontiguousarray(indices, dtype=np.int32)
        self.npoints = len(self.indices)

        if check_integrity:
            self.check_integrity()

    def __reduce__(self):
        # Pickle support: rebuild from (length, indices); the integrity
        # check runs again on reconstruction because it defaults to True.
        args = (self.length, self.indices)
        return IntIndex, args

    def __repr__(self) -> str:
        output = 'IntIndex\n'
        output += f'Indices: {repr(self.indices)}\n'
        return output

    @property
    def nbytes(self) -> int:
        """Number of bytes used by the stored ``indices`` array."""
        return self.indices.nbytes

    def check_integrity(self):
        """
        Checks the following:

        - Indices are strictly ascending
        - Number of indices is at most self.length
        - Indices are at least 0 and at most the total length less one

        A ValueError is raised if any of these conditions is violated.
        """

        if self.npoints > self.length:
            raise ValueError(
                f"Too many indices. Expected {self.length} but found {self.npoints}"
            )

        # Indices are vacuously ordered and non-negative
        # if the sequence of indices is empty.
        if self.npoints == 0:
            return

        if self.indices.min() < 0:
            raise ValueError("No index can be less than zero")

        if self.indices.max() >= self.length:
            raise ValueError("All indices must be less than the length")

        monotonic = np.all(self.indices[:-1] < self.indices[1:])
        if not monotonic:
            raise ValueError("Indices must be strictly increasing")

    def equals(self, other: object) -> bool:
        """
        Return True if ``other`` is an IntIndex with the same ``length`` and
        the same ``indices``; False otherwise.
        """
        if not isinstance(other, IntIndex):
            return False

        if self is other:
            return True

        same_length = self.length == other.length
        same_indices = np.array_equal(self.indices, other.indices)
        return same_length and same_indices

    @property
    def ngaps(self) -> int:
        """``length`` minus the number of stored indices."""
        return self.length - self.npoints

    def to_int_index(self):
        """Return ``self``; this is already an IntIndex."""
        return self

    def to_block_index(self):
        """Return a BlockIndex built from the runs of consecutive indices."""
        locs, lens = get_blocks(self.indices)
        return BlockIndex(self.length, locs, lens)

    cpdef IntIndex intersect(self, SparseIndex y_):
        """
        Return a new IntIndex holding the indices present in both ``self``
        and ``y_``.

        Parameters
        ----------
        y_ : SparseIndex
            Converted with ``to_int_index()`` before use.

        Returns
        -------
        IntIndex

        Raises
        ------
        ValueError
            If the two indexes do not have the same ``length``.
        """
        cdef:
            Py_ssize_t xi, yi = 0, result_indexer = 0
            int32_t xind
            ndarray[int32_t, ndim=1] xindices, yindices, new_indices
            IntIndex y

        # if is one already, returns self
        y = y_.to_int_index()

        if self.length != y.length:
            # ValueError (a subclass of Exception) matches make_union.
            raise ValueError('Indices must reference same underlying length')

        xindices = self.indices
        yindices = y.indices
        # The intersection cannot have more entries than the smaller input.
        new_indices = np.empty(min(
            len(xindices), len(yindices)), dtype=np.int32)

        # Two-pointer walk over both index arrays; yi only moves forward.
        for xi in range(self.npoints):
            xind = xindices[xi]

            while yi < y.npoints and yindices[yi] < xind:
                yi += 1

            if yi >= y.npoints:
                break

            # TODO: would a two-pass algorithm be faster?
            if yindices[yi] == xind:
                new_indices[result_indexer] = xind
                result_indexer += 1

        new_indices = new_indices[:result_indexer]
        return IntIndex(self.length, new_indices)

    cpdef IntIndex make_union(self, SparseIndex y_):
        """
        Return a new IntIndex holding the indices present in ``self`` or in
        ``y_`` (computed with ``np.union1d``).

        Raises
        ------
        ValueError
            If the two indexes do not have the same ``length``.
        """

        cdef:
            ndarray[int32_t, ndim=1] new_indices
            IntIndex y

        # if is one already, returns self
        y = y_.to_int_index()

        if self.length != y.length:
            raise ValueError('Indices must reference same underlying length')

        new_indices = np.union1d(self.indices, y.indices)
        return IntIndex(self.length, new_indices)

    @cython.wraparound(False)
    cpdef int32_t lookup(self, Py_ssize_t index):
        """
        Return the internal location if value exists on given index.
        Return -1 otherwise.
        """
        cdef:
            int32_t res
            ndarray[int32_t, ndim=1] inds

        inds = self.indices
        if self.npoints == 0:
            return -1
        elif index < 0 or self.length <= index:
            return -1

        # searchsorted gives the insertion point; it is a hit only when the
        # entry already stored at that position equals ``index``.
        res = inds.searchsorted(index)
        if res == self.npoints:
            return -1
        elif inds[res] == index:
            return res
        else:
            return -1

    @cython.wraparound(False)
    cpdef ndarray[int32_t] lookup_array(self, ndarray[int32_t, ndim=1] indexer):
        """
        Vectorized lookup, returns ndarray[int32_t]

        Each output entry is the internal location of the corresponding
        ``indexer`` value, or -1 when that value is not stored.
        """
        cdef:
            ndarray[int32_t, ndim=1] inds
            ndarray[uint8_t, ndim=1, cast=True] mask
            ndarray[int32_t, ndim=1] masked
            ndarray[int32_t, ndim=1] res
            ndarray[int32_t, ndim=1] results

        results = np.empty(len(indexer), dtype=np.int32)
        results[:] = -1

        if self.npoints == 0:
            return results

        inds = self.indices
        # Only values inside [first index, last index] can possibly match;
        # restricting to them keeps every searchsorted result in bounds.
        mask = (inds[0] <= indexer) & (indexer <= inds[len(inds) - 1])

        masked = indexer[mask]
        res = inds.searchsorted(masked).astype(np.int32)

        res[inds[res] != masked] = -1
        results[mask] = res
        return results

    cpdef ndarray reindex(self, ndarray[float64_t, ndim=1] values,
                          float64_t fill_value, SparseIndex other_):
        """
        Align ``values`` (one entry per index of ``self``) to ``other_``.

        Parameters
        ----------
        values : ndarray[float64]
            ``values[j]`` is taken as the value at ``self.indices[j]``.
        fill_value : float64
            Used for every index of ``other_`` that is not in ``self``.
        other_ : SparseIndex
            Converted with ``to_int_index()`` before use.

        Returns
        -------
        ndarray[float64]
            One entry per index of ``other_``.
        """
        cdef:
            Py_ssize_t i = 0, j = 0
            IntIndex other
            ndarray[float64_t, ndim=1] result
            ndarray[int32_t, ndim=1] sinds, oinds

        other = other_.to_int_index()

        oinds = other.indices
        sinds = self.indices

        result = np.empty(other.npoints, dtype=np.float64)
        result[:] = fill_value

        for i in range(other.npoints):
            # Test the bound first so sinds[j] is never read with
            # j == self.npoints (including the case where self is empty).
            while j < self.npoints and oinds[i] > sinds[j]:
                j += 1

            if j == self.npoints:
                # self is exhausted; the rest keeps fill_value
                break

            if oinds[i] < sinds[j]:
                continue
            elif oinds[i] == sinds[j]:
                result[i] = values[j]
                j += 1

        return result

    cpdef put(self, ndarray[float64_t, ndim=1] values,
              ndarray[int32_t, ndim=1] indices, object to_put):
        # Placeholder: no behavior is implemented here.
        pass

    cpdef take(self, ndarray[float64_t, ndim=1] values,
               ndarray[int32_t, ndim=1] indices):
        # Placeholder: no behavior is implemented here.
        pass
270
271
cpdef get_blocks(ndarray[int32_t, ndim=1] indices):
    """
    Collapse an int32 index array into runs of consecutive values.

    A new run starts whenever an entry exceeds its predecessor by more
    than one.

    Returns
    -------
    (locs, lens) : tuple of int32 ndarrays
        ``locs[k]`` is the first value of run ``k`` and ``lens[k]`` its
        number of entries.  Both are empty when ``indices`` is empty.
    """
    cdef:
        Py_ssize_t pos, count, nblocks = 0
        int32_t start, run = 1, here, last
        ndarray[int32_t, ndim=1] locs, lens

    count = len(indices)

    # nothing to scan: hand back two empty arrays
    if count == 0:
        return np.array([], dtype=np.int32), np.array([], dtype=np.int32)

    # there can never be more runs than there are entries
    locs = np.empty(count, dtype=np.int32)
    lens = np.empty(count, dtype=np.int32)

    # TODO: two-pass algorithm faster?
    start = indices[0]
    last = start
    for pos in range(1, count):
        here = indices[pos]
        if here - last <= 1:
            # still inside the current run
            run += 1
        else:
            # gap found: record the finished run and open a new one
            locs[nblocks] = start
            lens[nblocks] = run
            nblocks += 1
            start = here
            run = 1

        last = here

    # record the run that was still open when the scan ended
    locs[nblocks] = start
    lens[nblocks] = run
    nblocks += 1

    locs = locs[:nblocks]
    lens = lens[:nblocks]
    return locs, lens
311
312
313# -----------------------------------------------------------------------------
314# BlockIndex
315
cdef class BlockIndex(SparseIndex):
    """
    Object for holding block-based sparse indexing information

    Parameters
    ----------
    length : integer
        Exclusive upper bound for the end of the last block.
    blocs : array-like
        Start location of each block.  Stored as a contiguous int32 array.
    blengths : array-like
        Number of points in each block.  Stored as a contiguous int32 array.
    """
    cdef readonly:
        int32_t nblocks, npoints, length
        ndarray blocs, blengths

    cdef:
        object __weakref__  # need to be picklable
        int32_t *locbuf
        int32_t *lenbuf

    def __init__(self, length, blocs, blengths):

        self.blocs = np.ascontiguousarray(blocs, dtype=np.int32)
        self.blengths = np.ascontiguousarray(blengths, dtype=np.int32)

        # raw pointers into the two arrays above (used by to_int_index)
        self.locbuf = <int32_t*>self.blocs.data
        self.lenbuf = <int32_t*>self.blengths.data

        self.length = length
        self.nblocks = np.int32(len(self.blocs))
        self.npoints = self.blengths.sum()

        self.check_integrity()

    def __reduce__(self):
        # Pickle support: rebuild from (length, blocs, blengths).
        args = (self.length, self.blocs, self.blengths)
        return BlockIndex, args

    def __repr__(self) -> str:
        output = 'BlockIndex\n'
        output += f'Block locations: {repr(self.blocs)}\n'
        output += f'Block lengths: {repr(self.blengths)}'

        return output

    @property
    def nbytes(self) -> int:
        """Number of bytes used by the ``blocs`` and ``blengths`` arrays."""
        return self.blocs.nbytes + self.blengths.nbytes

    @property
    def ngaps(self) -> int:
        """``length`` minus the total number of points in all blocks."""
        return self.length - self.npoints

    cpdef check_integrity(self):
        """
        Check:
        - Locations are in ascending order
        - No overlapping blocks
        - Blocks do not start after end of index, nor extend beyond end
        - No block has zero length

        A ValueError is raised if any of these conditions is violated.
        """
        cdef:
            Py_ssize_t i
            ndarray[int32_t, ndim=1] blocs, blengths

        blocs = self.blocs
        blengths = self.blengths

        if len(blocs) != len(blengths):
            raise ValueError('block bound arrays must be same length')

        for i in range(self.nblocks):
            if i > 0:
                if blocs[i] <= blocs[i - 1]:
                    raise ValueError('Locations not in ascending order')

            if i < self.nblocks - 1:
                # a block may end exactly where the next one starts
                if blocs[i] + blengths[i] > blocs[i + 1]:
                    raise ValueError(f'Block {i} overlaps')
            else:
                if blocs[i] + blengths[i] > self.length:
                    raise ValueError(f'Block {i} extends beyond end')

            # no zero-length blocks
            if blengths[i] == 0:
                raise ValueError(f'Zero-length block {i}')

    def equals(self, other: object) -> bool:
        """
        Return True if ``other`` is a BlockIndex with the same ``length``,
        block locations and block lengths; False otherwise.
        """
        if not isinstance(other, BlockIndex):
            return False

        if self is other:
            return True

        same_length = self.length == other.length
        same_blocks = (np.array_equal(self.blocs, other.blocs) and
                       np.array_equal(self.blengths, other.blengths))
        return same_length and same_blocks

    def to_block_index(self):
        """Return ``self``; this is already a BlockIndex."""
        return self

    def to_int_index(self):
        """Return an IntIndex listing every point covered by the blocks."""
        cdef:
            int32_t i = 0, j, b
            int32_t offset
            ndarray[int32_t, ndim=1] indices

        indices = np.empty(self.npoints, dtype=np.int32)

        # expand each block into offset, offset + 1, ..., offset + len - 1
        for b in range(self.nblocks):
            offset = self.locbuf[b]

            for j in range(self.lenbuf[b]):
                indices[i] = offset + j
                i += 1

        return IntIndex(self.length, indices)

    cpdef BlockIndex intersect(self, SparseIndex other):
        """
        Intersect two BlockIndex objects

        Parameters
        ----------
        other : SparseIndex
            Converted with ``to_block_index()`` before use.

        Returns
        -------
        BlockIndex

        Raises
        ------
        ValueError
            If the two indexes do not have the same ``length``.
        """
        cdef:
            BlockIndex y
            ndarray[int32_t, ndim=1] xloc, xlen, yloc, ylen, out_bloc, out_blen
            Py_ssize_t xi = 0, yi = 0, max_len, result_indexer = 0
            int32_t cur_loc, cur_length, diff

        y = other.to_block_index()

        if self.length != y.length:
            # ValueError is a subclass of Exception, so existing handlers
            # that catch Exception keep working.
            raise ValueError('Indices must reference same underlying length')

        xloc = self.blocs
        xlen = self.blengths
        yloc = y.blocs
        ylen = y.blengths

        # Every pass through the loop below advances exactly one of xi / yi
        # and the loop stops as soon as either side is exhausted, so at most
        # nblocks_x + nblocks_y output blocks can be written.  (A bound based
        # on length // 2 is too small: adjacent blocks are valid, so an index
        # may hold up to ``length`` blocks.)
        max_len = self.nblocks + y.nblocks
        out_bloc = np.empty(max_len, dtype=np.int32)
        out_blen = np.empty(max_len, dtype=np.int32)

        while True:
            # we are done (or possibly never began)
            if xi >= self.nblocks or yi >= y.nblocks:
                break

            # completely symmetric...would like to avoid code dup but oh well
            if xloc[xi] >= yloc[yi]:
                cur_loc = xloc[xi]
                diff = xloc[xi] - yloc[yi]

                if ylen[yi] <= diff:
                    # have to skip this block
                    yi += 1
                    continue

                if ylen[yi] - diff < xlen[xi]:
                    # take end of y block, move onward
                    cur_length = ylen[yi] - diff
                    yi += 1
                else:
                    # take end of x block
                    cur_length = xlen[xi]
                    xi += 1

            else:  # xloc[xi] < yloc[yi]
                cur_loc = yloc[yi]
                diff = yloc[yi] - xloc[xi]

                if xlen[xi] <= diff:
                    # have to skip this block
                    xi += 1
                    continue

                if xlen[xi] - diff < ylen[yi]:
                    # take end of x block, move onward
                    cur_length = xlen[xi] - diff
                    xi += 1
                else:
                    # take end of y block
                    cur_length = ylen[yi]
                    yi += 1

            out_bloc[result_indexer] = cur_loc
            out_blen[result_indexer] = cur_length
            result_indexer += 1

        out_bloc = out_bloc[:result_indexer]
        out_blen = out_blen[:result_indexer]

        return BlockIndex(self.length, out_bloc, out_blen)

    cpdef BlockIndex make_union(self, SparseIndex y):
        """
        Combine together two BlockIndex objects, accepting indices if contained
        in one or the other

        Parameters
        ----------
        y : SparseIndex
            Converted with ``to_block_index()`` before use.

        Notes
        -----
        union is a protected keyword in Cython, hence make_union

        Returns
        -------
        BlockIndex
        """
        return BlockUnion(self, y.to_block_index()).result

    cpdef Py_ssize_t lookup(self, Py_ssize_t index):
        """
        Return the internal location if value exists on given index.
        Return -1 otherwise.
        """
        cdef:
            Py_ssize_t i, cum_len
            ndarray[int32_t, ndim=1] locs, lens

        locs = self.blocs
        lens = self.blengths

        if self.nblocks == 0:
            return -1
        elif index < locs[0]:
            return -1

        # cum_len counts the points held by the blocks already passed
        cum_len = 0
        for i in range(self.nblocks):
            if index >= locs[i] and index < locs[i] + lens[i]:
                return cum_len + index - locs[i]
            cum_len += lens[i]

        return -1

    @cython.wraparound(False)
    cpdef ndarray[int32_t] lookup_array(self, ndarray[int32_t, ndim=1] indexer):
        """
        Vectorized lookup, returns ndarray[int32_t]

        Each output entry is the internal location of the corresponding
        ``indexer`` value, or -1 when no block contains that value.
        """
        cdef:
            Py_ssize_t n, i, j, ind_val, cum_len
            ndarray[int32_t, ndim=1] locs, lens
            ndarray[int32_t, ndim=1] results

        locs = self.blocs
        lens = self.blengths

        n = len(indexer)
        results = np.empty(n, dtype=np.int32)
        results[:] = -1

        if self.npoints == 0:
            return results

        for i in range(n):
            ind_val = indexer[i]
            if not (ind_val < 0 or self.length <= ind_val):
                cum_len = 0
                for j in range(self.nblocks):
                    if ind_val >= locs[j] and ind_val < locs[j] + lens[j]:
                        results[i] = cum_len + ind_val - locs[j]
                        # blocks do not overlap, so there is no other match
                        break
                    cum_len += lens[j]
        return results

    cpdef ndarray reindex(self, ndarray[float64_t, ndim=1] values,
                          float64_t fill_value, SparseIndex other_):
        """
        Align ``values`` (one entry per point of ``self``) to ``other_``.

        Parameters
        ----------
        values : ndarray[float64]
            ``values[p]`` is taken as the value of the p-th point of
            ``self``, counting through the blocks in order.
        fill_value : float64
            Used for every point of ``other_`` that is not in ``self``.
        other_ : SparseIndex
            Converted with ``to_block_index()`` before use.

        Returns
        -------
        ndarray[float64]
            One entry per point of ``other_``.
        """
        cdef:
            Py_ssize_t i, k, j = 0
            Py_ssize_t ostart, oend, sstart, send, lo, hi
            Py_ssize_t opos = 0   # points of other_ before its block i
            Py_ssize_t spos = 0   # points of self before its block j
            BlockIndex other
            ndarray[float64_t, ndim=1] result
            ndarray[int32_t, ndim=1] slocs, slens, olocs, olens

        other = other_.to_block_index()

        olocs = other.blocs
        olens = other.blengths
        slocs = self.blocs
        slens = self.blengths

        result = np.empty(other.npoints, dtype=np.float64)
        result[:] = fill_value

        for i in range(other.nblocks):
            ostart = olocs[i]
            oend = ostart + olens[i]

            # drop the blocks of self that end at or before this block
            while j < self.nblocks and slocs[j] + slens[j] <= ostart:
                spos += slens[j]
                j += 1

            # copy the overlap with every block of self starting before oend
            while j < self.nblocks and slocs[j] < oend:
                sstart = slocs[j]
                send = sstart + slens[j]
                lo = sstart if sstart > ostart else ostart
                hi = send if send < oend else oend

                for k in range(lo, hi):
                    result[opos + k - ostart] = values[spos + k - sstart]

                if send > oend:
                    # this block of self may also reach into the next block
                    # of other_, so it must not be consumed yet
                    break

                spos += slens[j]
                j += 1

            opos += olens[i]

        return result

    cpdef put(self, ndarray[float64_t, ndim=1] values,
              ndarray[int32_t, ndim=1] indices, object to_put):
        # Placeholder: no behavior is implemented here.
        pass

    cpdef take(self, ndarray[float64_t, ndim=1] values,
               ndarray[int32_t, ndim=1] indices):
        # Placeholder: no behavior is implemented here.
        pass
619
620
cdef class BlockMerge:
    """
    Object-oriented approach makes sharing state between recursive functions a
    lot easier and reduces code duplication

    The merged index is computed in the initializer by calling
    ``_make_merged_blocks`` and stored in ``result``; subclasses override
    that method.

    Parameters
    ----------
    x, y : BlockIndex
        The two indexes to merge.  They must have the same ``length``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not have the same ``length``.
    """
    cdef:
        BlockIndex x, y, result
        # xlen / ylen are declared but never assigned in this class
        ndarray xstart, xlen, xend, ystart, ylen, yend
        int32_t xi, yi  # block indices

    def __init__(self, BlockIndex x, BlockIndex y):
        self.x = x
        self.y = y

        if x.length != y.length:
            # ValueError matches IntIndex.make_union and, being a subclass
            # of Exception, keeps existing ``except Exception`` handlers
            # working.
            raise ValueError('Indices must reference same underlying length')

        self.xstart = self.x.blocs
        self.ystart = self.y.blocs

        # exclusive end position of every block
        self.xend = self.x.blocs + self.x.blengths
        self.yend = self.y.blocs + self.y.blengths

        self.xi = 0
        self.yi = 0

        self.result = self._make_merged_blocks()

    cdef _make_merged_blocks(self):
        # Supplied by subclasses.
        raise NotImplementedError

    cdef _set_current_indices(self, int32_t xi, int32_t yi, bint mode):
        # mode 0: store the positions as given.
        # otherwise: the caller worked with x and y swapped, so swap back.
        if mode == 0:
            self.xi = xi
            self.yi = yi
        else:
            self.xi = yi
            self.yi = xi
662
663
cdef class BlockUnion(BlockMerge):
    """
    Merge two BlockIndex objects into their union.

    Blocks that overlap or touch are fused into a single output block.  The
    merged BlockIndex is built by ``_make_merged_blocks``.
    """

    cdef _make_merged_blocks(self):
        """
        Build the union BlockIndex by walking the blocks of ``x`` and ``y``
        in order of start position.
        """
        cdef:
            ndarray[int32_t, ndim=1] xstart, xend, ystart
            ndarray[int32_t, ndim=1] yend, out_bloc, out_blen
            int32_t nstart, nend
            Py_ssize_t max_len, result_indexer = 0

        xstart = self.xstart
        xend = self.xend
        ystart = self.ystart
        yend = self.yend

        # Every output block consumes at least one input block, so the
        # result holds at most nblocks_x + nblocks_y blocks.  (A bound based
        # on length // 2 is too small: adjacent blocks are valid, so a single
        # index may already hold up to ``length`` blocks.)
        max_len = self.x.nblocks + self.y.nblocks
        out_bloc = np.empty(max_len, dtype=np.int32)
        out_blen = np.empty(max_len, dtype=np.int32)

        while True:
            # we are done (or possibly never began)
            if self.xi >= self.x.nblocks and self.yi >= self.y.nblocks:
                break
            elif self.yi >= self.y.nblocks:
                # through with y, just pass through x blocks
                nstart = xstart[self.xi]
                nend = xend[self.xi]
                self.xi += 1
            elif self.xi >= self.x.nblocks:
                # through with x, just pass through y blocks
                nstart = ystart[self.yi]
                nend = yend[self.yi]
                self.yi += 1
            else:
                # find end of new block; the helper also advances xi / yi
                if xstart[self.xi] < ystart[self.yi]:
                    nstart = xstart[self.xi]
                    nend = self._find_next_block_end(0)
                else:
                    nstart = ystart[self.yi]
                    nend = self._find_next_block_end(1)

            out_bloc[result_indexer] = nstart
            out_blen[result_indexer] = nend - nstart
            result_indexer += 1

        out_bloc = out_bloc[:result_indexer]
        out_blen = out_blen[:result_indexer]

        return BlockIndex(self.x.length, out_bloc, out_blen)

    cdef int32_t _find_next_block_end(self, bint mode) except -1:
        """
        Return the end of the merged block that starts at the current block
        of one index, and advance ``self.xi`` / ``self.yi`` past every block
        absorbed into it.

        mode 0: block started in index x
        mode 1: block started in index y

        ``mode`` is a ``bint``, so it is always 0 or 1 here.
        """
        cdef:
            ndarray[int32_t, ndim=1] xend, ystart, yend
            int32_t xi, yi, ynblocks, nend

        # so symmetric code will work: below, "x" is always the index the
        # current block started in and "y" is the other one
        if mode == 0:
            xend = self.xend
            xi = self.xi

            ystart = self.ystart
            yend = self.yend
            yi = self.yi
            ynblocks = self.y.nblocks
        else:
            xend = self.yend
            xi = self.yi

            ystart = self.xstart
            yend = self.xend
            yi = self.xi
            ynblocks = self.x.nblocks

        nend = xend[xi]

        # done with y?
        if yi == ynblocks:
            self._set_current_indices(xi + 1, yi, mode)
            return nend
        elif nend < ystart[yi]:
            # block ends before y block
            self._set_current_indices(xi + 1, yi, mode)
            return nend
        else:
            # skip the y blocks that end before the current block does
            while yi < ynblocks and nend > yend[yi]:
                yi += 1

            self._set_current_indices(xi + 1, yi, mode)

            if yi == ynblocks:
                return nend

            if nend < ystart[yi]:
                # we're done, return the block end
                return nend
            else:
                # merge blocks, continue searching
                # the y block starts at or before nend and ends at or after
                # it, so it extends the merged block; carry on from that
                # block with the roles of x and y swapped
                return self._find_next_block_end(1 - mode)
778
779
780# -----------------------------------------------------------------------------
781# Sparse arithmetic
782
783include "sparse_op_helper.pxi"
784
785
786# -----------------------------------------------------------------------------
787# SparseArray mask create operations
788
def make_mask_object_ndarray(ndarray[object, ndim=1] arr, object fill_value):
    """
    Build a boolean mask over an object array.

    An entry of the mask is False when the matching element of ``arr``
    compares equal to ``fill_value`` and has exactly the same type as
    ``fill_value``; it is True otherwise.

    Returns
    -------
    ndarray
        An int8 array of the same length as ``arr``, viewed as bool.
    """
    cdef:
        object item
        Py_ssize_t pos, count
        ndarray[int8_t, ndim=1] flags

    count = len(arr)

    # start with everything flagged, then clear the fill-value positions
    flags = np.ones(count, dtype=np.int8)

    for pos in range(count):
        item = arr[pos]
        # equality is tested first, then the exact type
        if not (item == fill_value and type(item) == type(fill_value)):
            continue
        flags[pos] = 0

    return flags.view(dtype=bool)
804