1# By Jake Vanderplas (2013) <jakevdp@cs.washington.edu>
2# written for the scikit-learn project
3# License: BSD
4
5import numpy as np
6cimport numpy as np
7np.import_array()  # required in order to use C-API
8
9
# First, define a function to get an ndarray from a memory buffer.
# Only the C signature is declared here; the symbol itself is taken from
# the "arrayobject.h" header. It is used by _buffer_to_ndarray below.
cdef extern from "arrayobject.h":
    # Arguments, in order: number of dimensions, pointer to the shape,
    # type number of the elements, and pointer to the data buffer.
    object PyArray_SimpleNewFromData(int nd, np.npy_intp* dims,
                                     int typenum, void* data)
14
15
cdef inline np.ndarray _buffer_to_ndarray(const DTYPE_t* x, np.npy_intp n):
    """Return a one-dimensional ndarray of length ``n`` built on buffer ``x``.

    Warning: this is not robust. The array is created directly from the
    raw pointer, so if ``x`` is deallocated while the returned array is
    still in scope, memory errors could result. This is not expected to
    happen for the way this helper is used, so it should be safe here.
    """
    # Note: this segfaults unless np.import_array() has been called at
    # module level (see the top of the file).
    cdef np.npy_intp length = n
    cdef void* raw = <void*>x
    return PyArray_SimpleNewFromData(1, &length, DTYPECODE, raw)
24
25
26# some handy constants
27from libc.math cimport fabs, sqrt, exp, pow, cos, sin, asin
28cdef DTYPE_t INF = np.inf
29
30from ..utils._typedefs cimport DTYPE_t, ITYPE_t, DITYPE_t, DTYPECODE
31from ..utils._typedefs import DTYPE, ITYPE
32from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper
33
34######################################################################
35# newObj function
36#  this is a helper function for pickling
def newObj(obj):
    """Return a new, uninitialised instance of the class ``obj``.

    The instance is produced with ``obj.__new__(obj)`` only, so
    ``__init__`` is not called. ``DistanceMetric.__reduce__`` hands this
    function to pickle as the object constructor.
    """
    instance = obj.__new__(obj)
    return instance
39
40
######################################################################
# metric mappings
#  Each metric id string maps to the metric class implementing it.
#  Several ids can be aliases of the same class.
METRIC_MAPPING = {
    # real-valued vector spaces
    'euclidean': EuclideanDistance,
    'l2': EuclideanDistance,
    'minkowski': MinkowskiDistance,
    'p': MinkowskiDistance,
    'manhattan': ManhattanDistance,
    'cityblock': ManhattanDistance,
    'l1': ManhattanDistance,
    'chebyshev': ChebyshevDistance,
    'infinity': ChebyshevDistance,
    'seuclidean': SEuclideanDistance,
    'mahalanobis': MahalanobisDistance,
    'wminkowski': WMinkowskiDistance,
    # integer-valued vector spaces
    'hamming': HammingDistance,
    'canberra': CanberraDistance,
    'braycurtis': BrayCurtisDistance,
    # boolean-valued vector spaces
    'matching': MatchingDistance,
    'jaccard': JaccardDistance,
    'dice': DiceDistance,
    'kulsinski': KulsinskiDistance,
    'rogerstanimoto': RogersTanimotoDistance,
    'russellrao': RussellRaoDistance,
    'sokalmichener': SokalMichenerDistance,
    'sokalsneath': SokalSneathDistance,
    # two-dimensional (latitude, longitude) data
    'haversine': HaversineDistance,
    # user-defined python function
    'pyfunc': PyFuncDistance,
}
69
70
def get_valid_metric_ids(L):
    """Return the metric IDs whose class is listed in ``L``.

    ``L`` holds metric classes and/or metric class names. Every key of
    ``METRIC_MAPPING`` whose class (or whose class ``__name__``) is found
    in ``L`` is returned, in the iteration order of ``METRIC_MAPPING``.

    Example:
    >>> L = get_valid_metric_ids([EuclideanDistance, 'ManhattanDistance'])
    >>> sorted(L)
    ['cityblock', 'euclidean', 'l1', 'l2', 'manhattan']
    """
    matching_ids = []
    for metric_id, metric_cls in METRIC_MAPPING.items():
        # accept either the class name or the class object itself
        if metric_cls.__name__ in L or metric_cls in L:
            matching_ids.append(metric_id)
    return matching_ids
82
83
84######################################################################
85# Distance Metric Classes
cdef class DistanceMetric:
    """DistanceMetric class

    This class provides a uniform interface to fast distance metric
    functions.  The various metrics can be accessed via the :meth:`get_metric`
    class method and the metric string identifier (see below).

    Examples
    --------
    >>> from sklearn.metrics import DistanceMetric
    >>> dist = DistanceMetric.get_metric('euclidean')
    >>> X = [[0, 1, 2],
    ...      [3, 4, 5]]
    >>> dist.pairwise(X)
    array([[ 0.        ,  5.19615242],
           [ 5.19615242,  0.        ]])

    Available Metrics

    The following lists the string metric identifiers and the associated
    distance metric classes:

    **Metrics intended for real-valued vector spaces:**

    ==============  ====================  ========  ===============================
    identifier      class name            args      distance function
    --------------  --------------------  --------  -------------------------------
    "euclidean"     EuclideanDistance     -         ``sqrt(sum((x - y)^2))``
    "manhattan"     ManhattanDistance     -         ``sum(|x - y|)``
    "chebyshev"     ChebyshevDistance     -         ``max(|x - y|)``
    "minkowski"     MinkowskiDistance     p, w      ``sum(w * |x - y|^p)^(1/p)``
    "wminkowski"    WMinkowskiDistance    p, w      ``sum(|w * (x - y)|^p)^(1/p)``
    "seuclidean"    SEuclideanDistance    V         ``sqrt(sum((x - y)^2 / V))``
    "mahalanobis"   MahalanobisDistance   V or VI   ``sqrt((x - y)' V^-1 (x - y))``
    ==============  ====================  ========  ===============================

    Note that "minkowski" with a non-None `w` parameter actually calls
    `WMinkowskiDistance` with `w=w ** (1/p)` in order to be consistent with the
    parametrization of scipy 1.8 and later.

    **Metrics intended for two-dimensional vector spaces:**  Note that the haversine
    distance metric requires data in the form of [latitude, longitude] and both
    inputs and outputs are in units of radians.

    ============  ==================  ===============================================================
    identifier    class name          distance function
    ------------  ------------------  ---------------------------------------------------------------
    "haversine"   HaversineDistance   ``2 arcsin(sqrt(sin^2(0.5*dx) + cos(x1)cos(x2)sin^2(0.5*dy)))``
    ============  ==================  ===============================================================


    **Metrics intended for integer-valued vector spaces:**  Though intended
    for integer-valued vectors, these are also valid metrics in the case of
    real-valued vectors.

    =============  ====================  ========================================
    identifier     class name            distance function
    -------------  --------------------  ----------------------------------------
    "hamming"      HammingDistance       ``N_unequal(x, y) / N_tot``
    "canberra"     CanberraDistance      ``sum(|x - y| / (|x| + |y|))``
    "braycurtis"   BrayCurtisDistance    ``sum(|x - y|) / (sum(|x|) + sum(|y|))``
    =============  ====================  ========================================

    **Metrics intended for boolean-valued vector spaces:**  Any nonzero entry
    is evaluated to "True".  In the listings below, the following
    abbreviations are used:

     - N  : number of dimensions
     - NTT : number of dims in which both values are True
     - NTF : number of dims in which the first value is True, second is False
     - NFT : number of dims in which the first value is False, second is True
     - NFF : number of dims in which both values are False
     - NNEQ : number of non-equal dimensions, NNEQ = NTF + NFT
     - NNZ : number of nonzero dimensions, NNZ = NTF + NFT + NTT

    =================  =======================  ===============================
    identifier         class name               distance function
    -----------------  -----------------------  -------------------------------
    "jaccard"          JaccardDistance          NNEQ / NNZ
    "matching"         MatchingDistance         NNEQ / N
    "dice"             DiceDistance             NNEQ / (NTT + NNZ)
    "kulsinski"        KulsinskiDistance        (NNEQ + N - NTT) / (NNEQ + N)
    "rogerstanimoto"   RogersTanimotoDistance   2 * NNEQ / (N + NNEQ)
    "russellrao"       RussellRaoDistance       (N - NTT) / N
    "sokalmichener"    SokalMichenerDistance    2 * NNEQ / (N + NNEQ)
    "sokalsneath"      SokalSneathDistance      NNEQ / (NNEQ + 0.5 * NTT)
    =================  =======================  ===============================

    **User-defined distance:**

    ===========    ===============    =======
    identifier     class name         args
    -----------    ---------------    -------
    "pyfunc"       PyFuncDistance     func
    ===========    ===============    =======

    Here ``func`` is a function which takes two one-dimensional numpy
    arrays, and returns a distance.  Note that in order to be used within
    the BallTree, the distance must be a true metric:
    i.e. it must satisfy the following properties

    1) Non-negativity: d(x, y) >= 0
    2) Identity: d(x, y) = 0 if and only if x == y
    3) Symmetry: d(x, y) = d(y, x)
    4) Triangle Inequality: d(x, y) + d(y, z) >= d(x, z)

    Because of the Python object overhead involved in calling the python
    function, this will be fairly slow, but it will have the same
    scaling as other distances.
    """
    def __cinit__(self):
        # Default state shared by all metrics; subclasses overwrite the
        # attributes they need in their own __init__.
        self.p = 2
        self.vec = np.zeros(1, dtype=DTYPE, order='c')
        self.mat = np.zeros((1, 1), dtype=DTYPE, order='c')
        self.size = 1

    def __reduce__(self):
        """
        reduce method used for pickling
        """
        return (newObj, (self.__class__,), self.__getstate__())

    def __getstate__(self):
        """
        get state for pickling

        The state is the tuple ``(p, vec, mat)``; for ``PyFuncDistance``
        the user function and its keyword arguments are appended.
        """
        if self.__class__.__name__ == "PyFuncDistance":
            return (float(self.p), np.asarray(self.vec), np.asarray(self.mat), self.func, self.kwargs)
        return (float(self.p), np.asarray(self.vec), np.asarray(self.mat))

    def __setstate__(self, state):
        """
        set state for pickling

        ``state`` is a tuple laid out as produced by ``__getstate__``.
        """
        self.p = state[0]
        self.vec = ReadonlyArrayWrapper(state[1])
        self.mat = ReadonlyArrayWrapper(state[2])
        if self.__class__.__name__ == "PyFuncDistance":
            self.func = state[3]
            self.kwargs = state[4]
        self.size = self.vec.shape[0]

    @classmethod
    def get_metric(cls, metric, **kwargs):
        """Get the given distance metric from the string identifier.

        See the docstring of DistanceMetric for a list of available metrics.

        Parameters
        ----------
        metric : str, DistanceMetric instance, DistanceMetric subclass \
                or callable
            The distance metric to use. An instance is returned unchanged,
            a callable that is not a metric class is wrapped in a
            ``PyFuncDistance``, and a string is looked up in
            ``METRIC_MAPPING``.
        **kwargs
            additional arguments will be passed to the requested metric

        Returns
        -------
        DistanceMetric
            The requested metric object.

        Raises
        ------
        ValueError
            If ``metric`` is not a recognized identifier.
        """
        if isinstance(metric, DistanceMetric):
            return metric

        # The class test has to come before the ``callable`` test: classes
        # are callable too, so a metric class would otherwise be wrapped
        # in a PyFuncDistance instead of being instantiated.
        if isinstance(metric, type) and issubclass(metric, DistanceMetric):
            pass
        elif callable(metric):
            return PyFuncDistance(metric, **kwargs)
        else:
            # Map the metric string ID to the metric class
            try:
                metric = METRIC_MAPPING[metric]
            except (KeyError, TypeError):
                # KeyError: unknown id; TypeError: unhashable input.
                # The 1-tuple keeps tuple inputs from being unpacked by
                # the ``%`` operator.
                raise ValueError("Unrecognized metric '%s'" % (metric,))

        # In Minkowski special cases, return more efficient methods
        if metric is MinkowskiDistance:
            p = kwargs.pop('p', 2)
            w = kwargs.pop('w', None)
            if w is not None:
                # Be consistent with scipy 1.8 conventions: in scipy 1.8,
                # 'wminkowski' was removed in favor of passing a
                # weight vector directly to 'minkowski', however
                # the new weights apply to the absolute differences raised to
                # the p power instead of the absolute difference as in
                # previous versions of scipy.
                # WMinkowskiDistance in sklearn implements the weighting
                # scheme of the old 'wminkowski' in scipy < 1.8, hence the
                # following adaptation. ``np.asarray`` lets plain sequences
                # (e.g. lists) be used as weights.
                return WMinkowskiDistance(p, np.asarray(w) ** (1. / p),
                                          **kwargs)
            if p == 1:
                return ManhattanDistance(**kwargs)
            elif p == 2:
                return EuclideanDistance(**kwargs)
            elif np.isinf(p):
                return ChebyshevDistance(**kwargs)
            else:
                return MinkowskiDistance(p, **kwargs)
        else:
            return metric(**kwargs)

    def __init__(self):
        if self.__class__ is DistanceMetric:
            raise NotImplementedError("DistanceMetric is an abstract class")

    def _validate_data(self, X):
        """Validate the input data.

        This should be overridden in a derived class if a specific input
        format is required. The default implementation accepts anything.
        """
        return

    cdef DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                      ITYPE_t size) nogil except -1:
        """Compute the distance between vectors x1 and x2

        This should be overridden in a derived class; the default
        implementation returns the placeholder value -999.
        """
        return -999

    cdef DTYPE_t rdist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                       ITYPE_t size) nogil except -1:
        """Compute the rank-preserving surrogate distance between vectors x1 and x2.

        This can optionally be overridden in a derived class; by default
        the surrogate is the distance itself.

        The rank-preserving surrogate distance is any measure that yields the same
        rank as the distance, but is more efficient to compute. For example, for the
        Euclidean metric, the surrogate distance is the squared-euclidean distance.
        """
        return self.dist(x1, x2, size)

    cdef int pdist(self, const DTYPE_t[:, ::1] X, DTYPE_t[:, ::1] D) except -1:
        """compute the pairwise distances between points in X

        Only pairs with ``i2 >= i1`` are computed; each value is mirrored
        into the other triangle of ``D``.
        """
        cdef ITYPE_t i1, i2
        for i1 in range(X.shape[0]):
            for i2 in range(i1, X.shape[0]):
                D[i1, i2] = self.dist(&X[i1, 0], &X[i2, 0], X.shape[1])
                D[i2, i1] = D[i1, i2]
        return 0

    cdef int cdist(self, const DTYPE_t[:, ::1] X, const DTYPE_t[:, ::1] Y,
                   DTYPE_t[:, ::1] D) except -1:
        """compute the cross-pairwise distances between arrays X and Y

        Raises ValueError when X and Y differ in their second dimension.
        """
        cdef ITYPE_t i1, i2
        if X.shape[1] != Y.shape[1]:
            raise ValueError('X and Y must have the same second dimension')
        for i1 in range(X.shape[0]):
            for i2 in range(Y.shape[0]):
                D[i1, i2] = self.dist(&X[i1, 0], &Y[i2, 0], X.shape[1])
        return 0

    cdef DTYPE_t _rdist_to_dist(self, DTYPE_t rdist) nogil except -1:
        """Convert the rank-preserving surrogate distance to the distance"""
        return rdist

    cdef DTYPE_t _dist_to_rdist(self, DTYPE_t dist) nogil except -1:
        """Convert the distance to the rank-preserving surrogate distance"""
        return dist

    def rdist_to_dist(self, rdist):
        """Convert the rank-preserving surrogate distance to the distance.

        The surrogate distance is any measure that yields the same rank as the
        distance, but is more efficient to compute. For example, for the
        Euclidean metric, the surrogate distance is the squared-euclidean distance.

        Parameters
        ----------
        rdist : double
            Surrogate distance.

        Returns
        -------
        double
            True distance.
        """
        return rdist

    def dist_to_rdist(self, dist):
        """Convert the true distance to the rank-preserving surrogate distance.

        The surrogate distance is any measure that yields the same rank as the
        distance, but is more efficient to compute. For example, for the
        Euclidean metric, the surrogate distance is the squared-euclidean distance.

        Parameters
        ----------
        dist : double
            True distance.

        Returns
        -------
        double
            Surrogate distance.
        """
        return dist

    def pairwise(self, X, Y=None):
        """Compute the pairwise distances between X and Y

        This is a convenience routine for the sake of testing.  For many
        metrics, the utilities in scipy.spatial.distance.cdist and
        scipy.spatial.distance.pdist will be faster.

        Parameters
        ----------
        X : array-like
            Array of shape (Nx, D), representing Nx points in D dimensions.
        Y : array-like (optional)
            Array of shape (Ny, D), representing Ny points in D dimensions.
            If not specified, then Y=X.

        Returns
        -------
        dist : ndarray
            The shape (Nx, Ny) array of pairwise distances between points in
            X and Y.
        """
        cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] Xarr
        cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] Yarr
        cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] Darr

        Xarr = np.asarray(X, dtype=DTYPE, order='C')
        self._validate_data(Xarr)
        if Y is None:
            Darr = np.zeros((Xarr.shape[0], Xarr.shape[0]),
                         dtype=DTYPE, order='C')
            self.pdist(Xarr, Darr)
        else:
            Yarr = np.asarray(Y, dtype=DTYPE, order='C')
            self._validate_data(Yarr)
            Darr = np.zeros((Xarr.shape[0], Yarr.shape[0]),
                         dtype=DTYPE, order='C')
            self.cdist(Xarr, Yarr, Darr)
        return Darr
418
419
420#------------------------------------------------------------
421# Euclidean Distance
#  d = sqrt(sum((x_i - y_i)^2))
cdef class EuclideanDistance(DistanceMetric):
    r"""Euclidean Distance metric

    .. math::
       D(x, y) = \sqrt{ \sum_i (x_i - y_i) ^ 2 }

    The surrogate distance used here is the squared distance: the
    conversions below are ``sqrt(rdist)`` and ``dist * dist``.
    """
    def __init__(self):
        self.p = 2

    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        # the computation itself lives in the euclidean_dist helper
        cdef DTYPE_t d = euclidean_dist(x1, x2, size)
        return d

    cdef inline DTYPE_t rdist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                              ITYPE_t size) nogil except -1:
        # the computation itself lives in the euclidean_rdist helper
        cdef DTYPE_t rd = euclidean_rdist(x1, x2, size)
        return rd

    cdef inline DTYPE_t _rdist_to_dist(self, DTYPE_t rdist) nogil except -1:
        cdef DTYPE_t root = sqrt(rdist)
        return root

    cdef inline DTYPE_t _dist_to_rdist(self, DTYPE_t dist) nogil except -1:
        cdef DTYPE_t squared = dist * dist
        return squared

    def rdist_to_dist(self, rdist):
        """Convert surrogate distance(s) to true distance(s)."""
        true_dist = np.sqrt(rdist)
        return true_dist

    def dist_to_rdist(self, dist):
        """Convert true distance(s) to surrogate distance(s)."""
        surrogate = dist ** 2
        return surrogate
451
452
453#------------------------------------------------------------
454# SEuclidean Distance
#  d = sqrt(sum((x_i - y_i)^2 / v_i))
cdef class SEuclideanDistance(DistanceMetric):
    r"""Standardized Euclidean Distance metric

    .. math::
       D(x, y) = \sqrt{ \sum_i \frac{ (x_i - y_i) ^ 2}{V_i} }

    ``V`` is stored in ``self.vec``; its length must equal the number of
    columns of the data (checked in ``_validate_data``).
    """
    def __init__(self, V):
        v_arr = np.asarray(V, dtype=DTYPE)
        self.vec = ReadonlyArrayWrapper(v_arr)
        self.size = self.vec.shape[0]
        self.p = 2

    def _validate_data(self, X):
        n_cols = X.shape[1]
        if n_cols != self.size:
            raise ValueError('SEuclidean dist: size of V does not match')

    cdef inline DTYPE_t rdist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                              ITYPE_t size) nogil except -1:
        # accumulate the squared differences, each divided by V_k
        cdef DTYPE_t diff, total = 0
        cdef np.intp_t k
        for k in range(size):
            diff = x1[k] - x2[k]
            total += diff * diff / self.vec[k]
        return total

    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef DTYPE_t squared = self.rdist(x1, x2, size)
        return sqrt(squared)

    cdef inline DTYPE_t _rdist_to_dist(self, DTYPE_t rdist) nogil except -1:
        cdef DTYPE_t root = sqrt(rdist)
        return root

    cdef inline DTYPE_t _dist_to_rdist(self, DTYPE_t dist) nogil except -1:
        cdef DTYPE_t squared = dist * dist
        return squared

    def rdist_to_dist(self, rdist):
        """Convert surrogate distance(s) to true distance(s)."""
        true_dist = np.sqrt(rdist)
        return true_dist

    def dist_to_rdist(self, dist):
        """Convert true distance(s) to surrogate distance(s)."""
        surrogate = dist ** 2
        return surrogate
495
496
497#------------------------------------------------------------
498# Manhattan Distance
499#  d = sum(abs(x_i - y_i))
cdef class ManhattanDistance(DistanceMetric):
    r"""Manhattan/City-block Distance metric

    .. math::
       D(x, y) = \sum_i |x_i - y_i|
    """
    def __init__(self):
        self.p = 1

    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        # running sum of the absolute coordinate differences
        cdef DTYPE_t total = 0
        cdef np.intp_t k
        for k in range(size):
            total += fabs(x1[k] - x2[k])
        return total
516
517
518#------------------------------------------------------------
519# Chebyshev Distance
520#  d = max_i(abs(x_i - y_i))
cdef class ChebyshevDistance(DistanceMetric):
    """Chebyshev/Infinity Distance

    .. math::
       D(x, y) = max_i (|x_i - y_i|)

    Examples
    --------
    >>> from sklearn.metrics import DistanceMetric
    >>> dist = DistanceMetric.get_metric('chebyshev')
    >>> X = [[0, 1, 2],
    ...      [3, 4, 5]]
    >>> Y = [[-1, 0, 1],
    ...      [3, 4, 5]]
    >>> dist.pairwise(X, Y)
    array([[1., 3.],
           [4., 0.]])
    """
    def __init__(self):
        self.p = INF

    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        # keep the largest absolute coordinate difference seen so far
        cdef DTYPE_t d = 0
        cdef np.intp_t j
        for j in range(size):
            d = fmax(d, fabs(x1[j] - x2[j]))
        return d
549
550
551#------------------------------------------------------------
552# Minkowski Distance
cdef class MinkowskiDistance(DistanceMetric):
    r"""Minkowski Distance

    .. math::
       D(x, y) = [\sum_i |x_i - y_i|^p] ^ (1/p)

    Minkowski Distance requires p >= 1 and finite. For p = infinity,
    use ChebyshevDistance.
    Note that for p=1, ManhattanDistance is more efficient, and for
    p=2, EuclideanDistance is more efficient.

    The surrogate distance is the sum before the ``1/p`` root is taken.
    """
    def __init__(self, p):
        # reject orders below 1 first, then infinite orders
        if p < 1:
            raise ValueError("p must be greater than 1")
        if np.isinf(p):
            raise ValueError("MinkowskiDistance requires finite p. "
                             "For p=inf, use ChebyshevDistance.")
        self.p = p

    cdef inline DTYPE_t rdist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                              ITYPE_t size) nogil except -1:
        cdef DTYPE_t total = 0
        cdef np.intp_t k
        for k in range(size):
            total += pow(fabs(x1[k] - x2[k]), self.p)
        return total

    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef DTYPE_t powered_sum = self.rdist(x1, x2, size)
        return pow(powered_sum, 1. / self.p)

    cdef inline DTYPE_t _rdist_to_dist(self, DTYPE_t rdist) nogil except -1:
        cdef DTYPE_t rooted = pow(rdist, 1. / self.p)
        return rooted

    cdef inline DTYPE_t _dist_to_rdist(self, DTYPE_t dist) nogil except -1:
        cdef DTYPE_t powered = pow(dist, self.p)
        return powered

    def rdist_to_dist(self, rdist):
        """Convert surrogate distance(s) to true distance(s)."""
        true_dist = rdist ** (1. / self.p)
        return true_dist

    def dist_to_rdist(self, dist):
        """Convert true distance(s) to surrogate distance(s)."""
        surrogate = dist ** self.p
        return surrogate
595
596
597#------------------------------------------------------------
598# W-Minkowski Distance
cdef class WMinkowskiDistance(DistanceMetric):
    r"""Weighted Minkowski Distance

    .. math::
       D(x, y) = [\sum_i |w_i * (x_i - y_i)|^p] ^ (1/p)

    Weighted Minkowski Distance requires p >= 1 and finite.

    Parameters
    ----------
    p : int
        The order of the norm of the difference :math:`{||u-v||}_p`.
    w : (N,) array-like
        The weight vector. Its length must equal the number of columns
        of the data (checked in ``_validate_data``).

    """
    def __init__(self, p, w):
        # reject orders below 1 first, then infinite orders
        if p < 1:
            raise ValueError("p must be greater than 1")
        if np.isinf(p):
            raise ValueError("WMinkowskiDistance requires finite p. "
                             "For p=inf, use ChebyshevDistance.")
        self.p = p
        weights = np.asarray(w, dtype=DTYPE)
        self.vec = ReadonlyArrayWrapper(weights)
        self.size = self.vec.shape[0]

    def _validate_data(self, X):
        n_cols = X.shape[1]
        if n_cols != self.size:
            raise ValueError('WMinkowskiDistance dist: '
                             'size of w does not match')

    cdef inline DTYPE_t rdist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                              ITYPE_t size) nogil except -1:
        cdef DTYPE_t total = 0
        cdef np.intp_t k
        for k in range(size):
            total += pow(self.vec[k] * fabs(x1[k] - x2[k]), self.p)
        return total

    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef DTYPE_t powered_sum = self.rdist(x1, x2, size)
        return pow(powered_sum, 1. / self.p)

    cdef inline DTYPE_t _rdist_to_dist(self, DTYPE_t rdist) nogil except -1:
        cdef DTYPE_t rooted = pow(rdist, 1. / self.p)
        return rooted

    cdef inline DTYPE_t _dist_to_rdist(self, DTYPE_t dist) nogil except -1:
        cdef DTYPE_t powered = pow(dist, self.p)
        return powered

    def rdist_to_dist(self, rdist):
        """Convert surrogate distance(s) to true distance(s)."""
        true_dist = rdist ** (1. / self.p)
        return true_dist

    def dist_to_rdist(self, dist):
        """Convert true distance(s) to surrogate distance(s)."""
        surrogate = dist ** self.p
        return surrogate
653
654
655#------------------------------------------------------------
656# Mahalanobis Distance
657#  d = sqrt( (x - y)^T V^-1 (x - y) )
cdef class MahalanobisDistance(DistanceMetric):
    r"""Mahalanobis Distance

    .. math::
       D(x, y) = \sqrt{ (x - y)^T V^{-1} (x - y) }

    Parameters
    ----------
    V : array-like
        Symmetric positive-definite covariance matrix.
        The inverse of this matrix will be explicitly computed.
    VI : array-like
        optionally specify the inverse directly.  If VI is passed,
        then V is not referenced.

    Raises
    ------
    ValueError
        If neither ``V`` nor ``VI`` is given, or if the (inverse) matrix
        is not a square two-dimensional array.
    """
    def __init__(self, V=None, VI=None):
        if VI is None:
            if V is None:
                raise ValueError("Must provide either V or VI "
                                 "for Mahalanobis distance")
            VI = np.linalg.inv(V)

        # Convert before looking at ``ndim``/``shape`` so that plain
        # array-likes (e.g. nested lists) are accepted for VI as well.
        VI = np.asarray(VI, dtype=float, order='C')
        if VI.ndim != 2 or VI.shape[0] != VI.shape[1]:
            raise ValueError("V/VI must be square")

        self.mat = ReadonlyArrayWrapper(VI)

        self.size = self.mat.shape[0]

        # we need vec as a work buffer
        self.vec = np.zeros(self.size, dtype=DTYPE)

    def _validate_data(self, X):
        if X.shape[1] != self.size:
            raise ValueError('Mahalanobis dist: size of V does not match')

    cdef inline DTYPE_t rdist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                              ITYPE_t size) nogil except -1:
        cdef DTYPE_t tmp, d = 0
        cdef np.intp_t i, j

        # compute (x1 - x2).T * VI * (x1 - x2)
        # The difference vector is written into the self.vec work buffer.
        # NOTE(review): the buffer belongs to the instance, so concurrent
        # calls on one instance would share it - confirm against callers.
        for i in range(size):
            self.vec[i] = x1[i] - x2[i]

        for i in range(size):
            # tmp = row i of VI times the difference vector
            tmp = 0
            for j in range(size):
                tmp += self.mat[i, j] * self.vec[j]
            d += tmp * self.vec[i]
        return d

    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        return sqrt(self.rdist(x1, x2, size))

    cdef inline DTYPE_t _rdist_to_dist(self, DTYPE_t rdist) nogil except -1:
        return sqrt(rdist)

    cdef inline DTYPE_t _dist_to_rdist(self, DTYPE_t dist) nogil except -1:
        return dist * dist

    def rdist_to_dist(self, rdist):
        """Convert surrogate distance(s) to true distance(s)."""
        return np.sqrt(rdist)

    def dist_to_rdist(self, dist):
        """Convert true distance(s) to surrogate distance(s)."""
        return dist ** 2
724
725
726#------------------------------------------------------------
727# Hamming Distance
728#  d = N_unequal(x, y) / N_tot
cdef class HammingDistance(DistanceMetric):
    r"""Hamming Distance

    Hamming distance is meant for discrete-valued vectors, though it is
    a valid metric for real-valued vectors. It is the fraction of
    positions at which the two vectors differ.

    .. math::
       D(x, y) = \frac{1}{N} \sum_i [x_i \neq y_i]
    """
    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        # count the positions where the two vectors disagree
        cdef int n_diff = 0
        cdef np.intp_t k
        for k in range(size):
            if x1[k] != x2[k]:
                n_diff = n_diff + 1
        return float(n_diff) / size
746
747
748#------------------------------------------------------------
749# Canberra Distance
750#  D(x, y) = sum[ abs(x_i - y_i) / (abs(x_i) + abs(y_i)) ]
cdef class CanberraDistance(DistanceMetric):
    r"""Canberra Distance

    Canberra distance is meant for discrete-valued vectors, though it is
    a valid metric for real-valued vectors.

    .. math::
       D(x, y) = \sum_i \frac{|x_i - y_i|}{|x_i| + |y_i|}

    Terms whose denominator is not positive are left out of the sum.
    """
    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef DTYPE_t scale, total = 0
        cdef np.intp_t k
        for k in range(size):
            scale = fabs(x1[k]) + fabs(x2[k])
            # a non-positive denominator contributes nothing
            if scale > 0:
                total += fabs(x1[k] - x2[k]) / scale
        return total
769
770
771#------------------------------------------------------------
772# Bray-Curtis Distance
773#  D(x, y) = sum[abs(x_i - y_i)] / sum[abs(x_i) + abs(y_i)]
cdef class BrayCurtisDistance(DistanceMetric):
    r"""Bray-Curtis Distance

    Bray-Curtis distance is meant for discrete-valued vectors, though it is
    a valid metric for real-valued vectors.

    .. math::
       D(x, y) = \frac{\sum_i |x_i - y_i|}{\sum_i(|x_i| + |y_i|)}

    When the denominator is not positive the distance is 0.
    """
    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef DTYPE_t diff_sum = 0, abs_sum = 0
        cdef np.intp_t k
        for k in range(size):
            diff_sum += fabs(x1[k] - x2[k])
            abs_sum += fabs(x1[k]) + fabs(x2[k])
        if abs_sum > 0:
            return diff_sum / abs_sum
        # non-positive denominator: report zero distance
        return 0.0
794
795
796#------------------------------------------------------------
797# Jaccard Distance (boolean)
798#  D(x, y) = N_unequal(x, y) / N_nonzero(x, y)
cdef class JaccardDistance(DistanceMetric):
    r"""Jaccard Distance

    Jaccard Distance is a dissimilarity measure for boolean-valued
    vectors. All nonzero entries will be treated as True, zero entries will
    be treated as False.

    .. math::
       D(x, y) = \frac{N_{TF} + N_{FT}}{N_{TT} + N_{TF} + N_{FT}}

    Two all-zero vectors have distance 0.
    """
    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        # n_both: positions where both entries are nonzero
        # n_any: positions where at least one entry is nonzero
        cdef int a_true, b_true, n_both = 0, n_any = 0
        cdef np.intp_t k
        for k in range(size):
            a_true = x1[k] != 0
            b_true = x2[k] != 0
            n_any += (a_true or b_true)
            n_both += (a_true and b_true)
        # Based on https://github.com/scipy/scipy/pull/7373
        # When comparing two all-zero vectors, scipy>=1.2.0 jaccard metric
        # was changed to return 0, instead of nan.
        if n_any == 0:
            return 0
        return (n_any - n_both) * 1.0 / n_any
824
825
826#------------------------------------------------------------
827# Matching Distance (boolean)
828#  D(x, y) = n_neq / n
cdef class MatchingDistance(DistanceMetric):
    r"""Matching Distance

    A dissimilarity measure for boolean-valued vectors: the fraction of
    positions at which exactly one of the two vectors is True. Every
    nonzero entry is read as True and every zero entry as False.

    .. math::
       D(x, y) = \frac{N_{TF} + N_{FT}}{N}
    """
    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef int n_mismatch = 0
        cdef np.intp_t k
        for k in range(size):
            # Count positions whose truth values disagree.
            if (x1[k] != 0) != (x2[k] != 0):
                n_mismatch += 1
        return n_mismatch * 1. / size
848
849
850#------------------------------------------------------------
851# Dice Distance (boolean)
852#  D(x, y) = n_neq / (2 * ntt + n_neq)
cdef class DiceDistance(DistanceMetric):
    r"""Dice Distance

    A dissimilarity measure for boolean-valued vectors. Every nonzero
    entry is read as True and every zero entry as False.

    .. math::
       D(x, y) = \frac{N_{TF} + N_{FT}}{2 * N_{TT} + N_{TF} + N_{FT}}
    """
    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef int a_set, b_set
        cdef int n_both = 0       # N_TT
        cdef int n_mismatch = 0   # N_TF + N_FT
        cdef np.intp_t k
        for k in range(size):
            a_set = x1[k] != 0
            b_set = x2[k] != 0
            if a_set != b_set:
                n_mismatch += 1
            elif a_set:
                # Truth values agree and the first is True: a TT pair.
                n_both += 1
        # NOTE(review): the denominator is zero when neither vector has a
        # nonzero entry; no special case is applied here.
        return n_mismatch / (2.0 * n_both + n_mismatch)
873
874
875#------------------------------------------------------------
876# Kulsinski Distance (boolean)
877#  D(x, y) = (ntf + nft - ntt + n) / (n_neq + n)
cdef class KulsinskiDistance(DistanceMetric):
    r"""Kulsinski Distance

    A dissimilarity measure for boolean-valued vectors. Every nonzero
    entry is read as True and every zero entry as False.

    .. math::
       D(x, y) = 1 - \frac{N_{TT}}{N + N_{TF} + N_{FT}}
    """
    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef int a_set, b_set
        cdef int n_mismatch = 0   # N_TF + N_FT
        cdef int n_both = 0       # N_TT
        cdef np.intp_t k
        for k in range(size):
            a_set = x1[k] != 0
            b_set = x2[k] != 0
            if a_set != b_set:
                n_mismatch += 1
            elif a_set:
                # Truth values agree and the first is True: a TT pair.
                n_both += 1
        # Same quantity as the docstring formula, put over a common
        # denominator: (N + N_neq - N_TT) / (N + N_neq).
        return (n_mismatch - n_both + size) * 1.0 / (n_mismatch + size)
898
899
900#------------------------------------------------------------
901# Rogers-Tanimoto Distance (boolean)
902#  D(x, y) = 2 * n_neq / (n + n_neq)
cdef class RogersTanimotoDistance(DistanceMetric):
    r"""Rogers-Tanimoto Distance

    A dissimilarity measure for boolean-valued vectors. Every nonzero
    entry is read as True and every zero entry as False.

    .. math::
       D(x, y) = \frac{2 (N_{TF} + N_{FT})}{N + N_{TF} + N_{FT}}
    """
    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef int n_mismatch = 0   # N_TF + N_FT
        cdef np.intp_t k
        for k in range(size):
            # Count positions whose truth values disagree.
            if (x1[k] != 0) != (x2[k] != 0):
                n_mismatch += 1
        return (2.0 * n_mismatch) / (size + n_mismatch)
922
923
924#------------------------------------------------------------
925# Russell-Rao Distance (boolean)
926#  D(x, y) = (n - ntt) / n
cdef class RussellRaoDistance(DistanceMetric):
    r"""Russell-Rao Distance

    A dissimilarity measure for boolean-valued vectors: the fraction of
    positions that are not True in both vectors. Every nonzero entry is
    read as True and every zero entry as False.

    .. math::
       D(x, y) = \frac{N - N_{TT}}{N}
    """
    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef int n_both = 0   # N_TT
        cdef np.intp_t k
        for k in range(size):
            if x1[k] != 0 and x2[k] != 0:
                n_both += 1
        return (size - n_both) * 1. / size
946
947
948#------------------------------------------------------------
949# Sokal-Michener Distance (boolean)
950#  D(x, y) = 2 * n_neq / (n + n_neq)
cdef class SokalMichenerDistance(DistanceMetric):
    r"""Sokal-Michener Distance

    A dissimilarity measure for boolean-valued vectors. Every nonzero
    entry is read as True and every zero entry as False.

    .. math::
       D(x, y) = \frac{2 (N_{TF} + N_{FT})}{N + N_{TF} + N_{FT}}
    """
    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef int n_mismatch = 0   # N_TF + N_FT
        cdef np.intp_t k
        for k in range(size):
            # Count positions whose truth values disagree.
            if (x1[k] != 0) != (x2[k] != 0):
                n_mismatch += 1
        return (2.0 * n_mismatch) / (size + n_mismatch)
970
971
972#------------------------------------------------------------
973# Sokal-Sneath Distance (boolean)
974#  D(x, y) = n_neq / (0.5 * n_tt + n_neq)
cdef class SokalSneathDistance(DistanceMetric):
    r"""Sokal-Sneath Distance

    A dissimilarity measure for boolean-valued vectors. Every nonzero
    entry is read as True and every zero entry as False.

    .. math::
       D(x, y) = \frac{N_{TF} + N_{FT}}{N_{TT} / 2 + N_{TF} + N_{FT}}
    """
    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef int a_set, b_set
        cdef int n_mismatch = 0   # N_TF + N_FT
        cdef int n_both = 0       # N_TT
        cdef np.intp_t k
        for k in range(size):
            a_set = x1[k] != 0
            b_set = x2[k] != 0
            if a_set != b_set:
                n_mismatch += 1
            elif a_set:
                # Truth values agree and the first is True: a TT pair.
                n_both += 1
        # NOTE(review): the denominator is zero when neither vector has a
        # nonzero entry; no special case is applied here.
        return n_mismatch / (0.5 * n_both + n_mismatch)
995
996
997#------------------------------------------------------------
998# Haversine Distance (2 dimensional)
999#  D(x, y) = 2 arcsin{sqrt[sin^2 ((x1 - y1) / 2)
1000#                          + cos(x1) cos(y1) sin^2 ((x2 - y2) / 2)]}
cdef class HaversineDistance(DistanceMetric):
    r"""Haversine (Spherical) Distance

    The Haversine distance is the angular distance between two points on
    the surface of a sphere.  The first coordinate of each point is assumed
    to be the latitude, the second is the longitude, given in radians.
    The dimension of the points must be 2:

    .. math::
       D(x, y) = 2\arcsin[\sqrt{\sin^2((x1 - y1) / 2)
                                + \cos(x1)\cos(y1)\sin^2((x2 - y2) / 2)}]

    The reduced distance is the quantity under the square root, so that
    ``dist = 2 * arcsin(sqrt(rdist))`` and ``rdist = sin(dist / 2) ** 2``.
    """

    def _validate_data(self, X):
        # Each point needs exactly two coordinates.
        n_features = X.shape[1]
        if n_features != 2:
            raise ValueError("Haversine distance only valid "
                             "in 2 dimensions")

    cdef inline DTYPE_t rdist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                              ITYPE_t size) nogil except -1:
        # Sines of the half-differences of the first and second coordinate.
        cdef DTYPE_t half_sin_first = sin(0.5 * (x1[0] - x2[0]))
        cdef DTYPE_t half_sin_second = sin(0.5 * (x1[1] - x2[1]))
        return (half_sin_first * half_sin_first
                + cos(x1[0]) * cos(x2[0]) * half_sin_second * half_sin_second)

    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        cdef DTYPE_t reduced = self.rdist(x1, x2, size)
        return 2 * asin(sqrt(reduced))

    cdef inline DTYPE_t _rdist_to_dist(self, DTYPE_t rdist) nogil except -1:
        cdef DTYPE_t root = sqrt(rdist)
        return 2 * asin(root)

    cdef inline DTYPE_t _dist_to_rdist(self, DTYPE_t dist) nogil except -1:
        cdef DTYPE_t half_sin = sin(0.5 * dist)
        return half_sin * half_sin

    def rdist_to_dist(self, rdist):
        # NumPy counterpart of _rdist_to_dist.
        root = np.sqrt(rdist)
        return 2 * np.arcsin(root)

    def dist_to_rdist(self, dist):
        # NumPy counterpart of _dist_to_rdist.
        half_sin = np.sin(0.5 * dist)
        return half_sin * half_sin
1042
1043
1044#------------------------------------------------------------
1045# Yule Distance (boolean)
1046#  D(x, y) = 2 * ntf * nft / (ntt * nff + ntf * nft)
1047# [This is not a true metric, so we will leave it out.]
1048#
1049#cdef class YuleDistance(DistanceMetric):
1050#    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
1051#                             ITYPE_t size):
1052#        cdef int tf1, tf2, ntf = 0, nft = 0, ntt = 0, nff = 0
1053#        cdef np.intp_t j
1054#        for j in range(size):
1055#            tf1 = x1[j] != 0
1056#            tf2 = x2[j] != 0
1057#            ntt += tf1 and tf2
1058#            ntf += tf1 and (tf2 == 0)
1059#            nft += (tf1 == 0) and tf2
1060#        nff = size - ntt - ntf - nft
1061#        return (2.0 * ntf * nft) / (ntt * nff + ntf * nft)
1062
1063
1064#------------------------------------------------------------
1065# Cosine Distance
#  D(x, y) = 1 - dot(x, y) / (|x| * |y|)
1067# [This is not a true metric, so we will leave it out.]
1068#
1069#cdef class CosineDistance(DistanceMetric):
1070#    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
1071#                             ITYPE_t size):
1072#        cdef DTYPE_t d = 0, norm1 = 0, norm2 = 0
1073#        cdef np.intp_t j
1074#        for j in range(size):
1075#            d += x1[j] * x2[j]
1076#            norm1 += x1[j] * x1[j]
1077#            norm2 += x2[j] * x2[j]
1078#        return 1.0 - d / sqrt(norm1 * norm2)
1079
1080
1081#------------------------------------------------------------
1082# Correlation Distance
#  D(x, y) = 1 - dot((x - mx), (y - my)) / (|x - mx| * |y - my|)
1084# [This is not a true metric, so we will leave it out.]
1085#
1086#cdef class CorrelationDistance(DistanceMetric):
1087#    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
1088#                             ITYPE_t size):
1089#        cdef DTYPE_t mu1 = 0, mu2 = 0, x1nrm = 0, x2nrm = 0, x1Tx2 = 0
1090#        cdef DTYPE_t tmp1, tmp2
1091#
1092#        cdef np.intp_t i
1093#        for i in range(size):
1094#            mu1 += x1[i]
1095#            mu2 += x2[i]
1096#        mu1 /= size
1097#        mu2 /= size
1098#
1099#        for i in range(size):
1100#            tmp1 = x1[i] - mu1
1101#            tmp2 = x2[i] - mu2
1102#            x1nrm += tmp1 * tmp1
1103#            x2nrm += tmp2 * tmp2
1104#            x1Tx2 += tmp1 * tmp2
1105#
1106#        return (1. - x1Tx2) / sqrt(x1nrm * x2nrm)
1107
1108
1109#------------------------------------------------------------
1110# User-defined distance
1111#
cdef class PyFuncDistance(DistanceMetric):
    """PyFunc Distance

    A user-defined distance, computed by calling a Python function on
    each pair of vectors.

    Parameters
    ----------
    func : function
        func should take two numpy arrays as input, and return a distance.
    **kwargs
        Additional keyword arguments; they are stored and forwarded to
        ``func`` on every call.
    """
    def __init__(self, func, **kwargs):
        # Keep the callable and its extra keyword arguments for later calls.
        self.func = func
        self.kwargs = kwargs

    # In cython < 0.26, the GIL was required to be acquired both in the
    # definition of the function and inside its body. This is not allowed
    # in cython >= 0.26, since it is a redundant GIL acquisition. The only
    # way to stay backward compatible is to inherit `dist` from the base
    # class without the GIL and have it call an inline `_dist`, which
    # acquires the GIL.
    cdef inline DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        # Thin nogil entry point: all the work happens in `_dist`.
        return self._dist(x1, x2, size)

    # Evaluate the user function on the two vectors while holding the GIL.
    cdef inline DTYPE_t _dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                              ITYPE_t size) except -1 with gil:
        cdef np.ndarray x1arr
        cdef np.ndarray x2arr
        # Wrap the raw buffers as 1-d ndarrays of length `size` so they can
        # be handed to Python code.
        x1arr = _buffer_to_ndarray(x1, size)
        x2arr = _buffer_to_ndarray(x2, size)
        d = self.func(x1arr, x2arr, **self.kwargs)
        try:
            # Cython generates code here that results in a TypeError
            # if d is the wrong type.
            return d
        except TypeError:
            # Re-raise with a message that points at the user's function.
            raise TypeError("Custom distance function must accept two "
                            "vectors and return a float.")
1149
1150
cdef inline double fmax(double a, double b) nogil:
    # Larger of the two doubles: b when it compares strictly greater than
    # a, otherwise a.
    return b if b > a else a
1153