# cython: boundscheck=False
# cython: wraparound=False
# cython: cdivision=True
# cython: initializedcheck=False
# cython: warn.undeclared=True
# cython: language_level=3
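# KL divergence evaluation for t-SNE embeddings: an exact O(n^2) computation
# and two approximate variants (Barnes-Hut and interpolation/FFT based) that
# reuse the negative-gradient estimators from ``_tsne``.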
cimport numpy as np
import numpy as np
from .quad_tree cimport QuadTree
from ._tsne cimport (
    estimate_negative_gradient_bh,
    estimate_negative_gradient_fft_1d,
    estimate_negative_gradient_fft_2d,
)
# This returns a tuple, and can't be called from C
from ._tsne import estimate_positive_gradient_nn


# Machine epsilon, used to guard the logarithms below against log(0)
cdef double EPSILON = np.finfo(np.float64).eps

cdef extern from "math.h":
    double log(double x) nogil


# Squared Euclidean distance between two points of the embedding
cdef double sqeuclidean(double[:] x, double[:] y):
    cdef:
        Py_ssize_t n_dims = x.shape[0]
        double result = 0
        Py_ssize_t i

    for i in range(n_dims):
        result += (x[i] - y[i]) ** 2

    return result


cpdef double kl_divergence_exact(double[:, ::1] P, double[:, ::1] embedding):
    """Compute the exact KL divergence."""
    cdef:
        Py_ssize_t n_samples = embedding.shape[0]
        Py_ssize_t i, j

        double sum_P = 0, sum_Q = 0, p_ij, q_ij
        double kl_divergence = 0

    for i in range(n_samples):
        for j in range(n_samples):
            if i != j:
                p_ij = P[i, j]
                q_ij = 1 / (1 + sqeuclidean(embedding[i], embedding[j]))
                sum_Q += q_ij
                sum_P += p_ij
                if p_ij > 0:
                    kl_divergence += p_ij * log(p_ij / (q_ij + EPSILON))

    # The loop above uses the unnormalized q_ij = 1 / (1 + d_ij^2); this term
    # accounts for the normalization of Q by sum_Q
    kl_divergence += sum_P * log(sum_Q + EPSILON)

    return kl_divergence
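# A minimal usage sketch for the exact evaluation (illustrative only): it
# assumes the extension has been compiled and this function imported from the
# built module; ``Y`` and ``P`` below are example inputs, not part of this
# module.
#
#     import numpy as np
#
#     rng = np.random.default_rng(0)
#     Y = rng.normal(size=(100, 2))                        # embedding, float64, C-contiguous
#     D = ((Y[:, None, :] - Y[None, :, :]) ** 2).sum(-1)   # pairwise squared distances
#     P = np.exp(-D)                                       # toy affinities
#     np.fill_diagonal(P, 0)
#     P /= P.sum()                                         # joint probabilities sum to 1
#
#     kl = kl_divergence_exact(P, Y)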


cpdef double kl_divergence_approx_bh(
    int[:] indices,
    int[:] indptr,
    double[:] P_data,
    double[:, ::1] embedding,
    double theta=0.5,
    double dof=1,
):
    """Compute the KL divergence using the Barnes-Hut approximation."""
    cdef:
        Py_ssize_t n_samples = embedding.shape[0]
        Py_ssize_t i, j

        QuadTree tree = QuadTree(embedding)
        # We don't actually care about the gradient, so don't waste time
        # initializing memory
        double[:, ::1] gradient = np.empty_like(embedding, dtype=float)

        double sum_P = 0, sum_Q = 0
        double kl_divergence = 0

    sum_Q = estimate_negative_gradient_bh(tree, embedding, gradient, theta, dof)
    sum_P, kl_divergence = estimate_positive_gradient_nn(
        indices,
        indptr,
        P_data,
        embedding,
        embedding,
        gradient,
        dof=dof,
        should_eval_error=True,
    )

    kl_divergence += sum_P * log(sum_Q + EPSILON)

    return kl_divergence
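# A hedged sketch of the expected sparse inputs (CSR layout, e.g. from
# ``scipy.sparse.csr_matrix``); ``P_csr`` and ``Y`` are illustrative names and
# the affinities here are toy values, only meant to show the array types.
#
#     import numpy as np
#     import scipy.sparse as sp
#
#     P_csr = sp.random(100, 100, density=0.05, format="csr", dtype=np.float64)
#     Y = np.ascontiguousarray(np.random.randn(100, 2))    # 2d embedding
#
#     kl = kl_divergence_approx_bh(
#         P_csr.indices.astype(np.int32),
#         P_csr.indptr.astype(np.int32),
#         P_csr.data,
#         Y,
#         theta=0.5,
#     )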


cpdef double kl_divergence_approx_fft(
    int[:] indices,
    int[:] indptr,
    double[:] P_data,
    double[:, ::1] embedding,
    double dof=1,
    Py_ssize_t n_interpolation_points=3,
    Py_ssize_t min_num_intervals=10,
    double ints_in_interval=1,
):
    """Compute the KL divergence using the interpolation based approximation."""
    cdef:
        Py_ssize_t n_samples = embedding.shape[0]
        Py_ssize_t n_dims = embedding.shape[1]
        Py_ssize_t i, j

        # We don't actually care about the gradient, so don't waste time
        # initializing memory
        double[:, ::1] gradient = np.empty_like(embedding, dtype=float)

        double sum_P = 0, sum_Q = 0
        double kl_divergence = 0

    if n_dims == 1:
        sum_Q = estimate_negative_gradient_fft_1d(
            # Memoryviews do not implement ``ravel``, so flatten through numpy;
            # for C-contiguous data this creates a view, not a copy
            np.asarray(embedding).ravel(),
            np.asarray(gradient).ravel(),
            n_interpolation_points,
            min_num_intervals,
            ints_in_interval,
            dof,
        )
    elif n_dims == 2:
        sum_Q = estimate_negative_gradient_fft_2d(
            embedding,
            gradient,
            n_interpolation_points,
            min_num_intervals,
            ints_in_interval,
            dof,
        )
    else:
        # The interpolation-based approximation only supports 1d and 2d
        # embeddings
        return -1

    sum_P, kl_divergence = estimate_positive_gradient_nn(
        indices,
        indptr,
        P_data,
        embedding,
        embedding,
        gradient,
        dof=dof,
        should_eval_error=True,
    )

    kl_divergence += sum_P * log(sum_Q + EPSILON)

    return kl_divergence
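# A hedged usage sketch for the interpolation-based variant: it takes the same
# CSR-style inputs as ``kl_divergence_approx_bh`` and, per the branches above,
# only handles 1d and 2d embeddings (``P_csr`` and ``Y`` as in the earlier
# sketch are illustrative, not defined here).
#
#     kl = kl_divergence_approx_fft(
#         P_csr.indices.astype(np.int32),
#         P_csr.indptr.astype(np.int32),
#         P_csr.data,
#         Y,                            # shape (n_samples, 1) or (n_samples, 2)
#         n_interpolation_points=3,
#         min_num_intervals=10,
#     )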