1import numpy as np
2import scipy.sparse as sp
3
4from Orange.projection import _som
5
6
class SOM:
    """Self-organizing map on a ``dim_y`` x ``dim_x`` grid of weight vectors.

    This class owns the weight arrays and their initialization; the
    training step and the winner lookup are delegated to the compiled
    ``_som`` module.

    Attributes:
        dim_x, dim_y: Grid size. ``weights`` has shape
            ``(dim_y, dim_x, n_features)``.
        weights: Weight vectors, or ``None`` before initialization.
        ssum_weights: ``(dim_y, dim_x)`` array with the sum of squared
            weights of each cell as of initialization, or ``None`` before
            it. It is only handed to the sparse ``_som`` routines
            (presumably they keep it up to date while training -- confirm
            against ``_som``).
        hexagonal: Selects the ``*_hex`` update routines and is forwarded
            to the winner lookup.
        pca_init: If true, ``fit`` initializes from principal components
            whenever the data is dense and has more than one column.
        random_seed: Seed for random initialization; ``None`` uses numpy's
            global random state.
    """

    def __init__(self, dim_x, dim_y,
                 hexagonal=False, pca_init=True, random_seed=None):
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.weights = self.ssum_weights = None
        self.hexagonal = hexagonal
        self.pca_init = pca_init
        self.random_seed = random_seed

    def _normalize_weights(self):
        """Scale each non-zero weight vector to unit length, in place.

        ``ssum_weights`` is recomputed from the normalized weights so that
        it always matches them, including all-zero vectors (whose sum of
        squares is 0, not 1).
        """
        norms = np.linalg.norm(self.weights, axis=2)
        norms[norms == 0] = 1  # avoid 0/0; zero vectors stay zero
        self.weights /= norms[:, :, None]
        self.ssum_weights = np.sum(self.weights ** 2, axis=2)

    def init_weights_random(self, x):
        """Initialize weights with random unit-length vectors.

        Components are drawn uniformly from [0, 1) and every vector is then
        normalized. Only ``x.shape[1]`` (the number of features) is used.

        Args:
            x: Data matrix with one column per feature.
        """
        random = (np.random if self.random_seed is None
                  else np.random.RandomState(self.random_seed))
        self.weights = random.rand(self.dim_y, self.dim_x, x.shape[1])
        self._normalize_weights()

    def init_weights_pca(self, x):
        """Initialize weights on the plane of the two leading principal axes.

        Cell ``(i, j)`` starts at ``a_i * pc0 + b_j * pc1``, where ``pc0`` and
        ``pc1`` are the eigenvectors of the covariance matrix of ``x`` with
        the largest and second-largest eigenvalue, and ``a``/``b`` run
        linearly from -1 to 1 along the grid rows/columns. Vectors are then
        normalized to unit length; the grid centre of an odd-sized grid is
        the zero vector and is left as such.

        Args:
            x: Dense data matrix with at least two columns (features).

        Raises:
            ValueError: If ``x`` has fewer than two columns.
        """
        if x.shape[1] < 2:
            raise ValueError(
                "PCA initialization requires at least two features")
        # The covariance matrix is symmetric, so use eigh: eigenvalues are
        # real and sorted in ascending order, and eigenvectors are the
        # COLUMNS of the returned matrix.
        _, components = np.linalg.eigh(np.cov(x.T))
        pc0, pc1 = components[:, -1], components[:, -2]
        row_coefs = np.linspace(-1, 1, self.dim_y)
        col_coefs = np.linspace(-1, 1, self.dim_x)
        # Broadcast to shape (dim_y, dim_x, n_features).
        self.weights = (row_coefs[:, None, None] * pc0
                        + col_coefs[None, :, None] * pc1)
        self._normalize_weights()

    def fit(self, x, n_iterations, learning_rate=0.5, sigma=1.0, callback=None):
        """Initialize the weights and run up to ``n_iterations`` updates.

        Weights are initialized with ``init_weights_pca`` when ``pca_init``
        is set, ``x`` is dense and has more than one column; otherwise with
        ``init_weights_random``. Each iteration calls the matching ``_som``
        update routine once with ``sigma`` and ``learning_rate`` divided by
        a decay factor that grows linearly from 1 towards 3.

        Args:
            x: Dense array or scipy sparse matrix, one row per instance.
            n_iterations: Maximal number of update calls.
            learning_rate: Initial learning rate.
            sigma: Initial value of ``sigma`` passed to the update routine.
            callback: Optional callable invoked after every iteration with
                ``iteration / n_iterations``. Training stops as soon as it
                returns a false value -- note that this includes ``None``.
        """
        is_sparse = sp.issparse(x)
        if self.pca_init and not is_sparse and x.shape[1] > 1:
            self.init_weights_pca(x)
        else:
            self.init_weights_random(x)

        # The closures read self.weights / self.ssum_weights at call time,
        # i.e. the arrays created by the initialization above.
        if is_sparse:
            f = _som.update_sparse_hex if self.hexagonal else _som.update_sparse

            def update(decay):
                f(self.weights, self.ssum_weights, x,
                  sigma / decay, learning_rate / decay)
        else:
            f = _som.update_hex if self.hexagonal else _som.update

            def update(decay):
                f(self.weights, x,
                  sigma / decay, learning_rate / decay)

        for iteration in range(n_iterations):
            update(1 + iteration / (n_iterations / 2))
            if callback is not None and not callback(iteration / n_iterations):
                break

    def winners(self, x):
        """Return the result of the winner lookup for ``x`` on this map.

        Thin wrapper around ``winner_from_weights`` using the instance's
        weights and grid type.
        """
        return self.winner_from_weights(
            x, self.weights, self.ssum_weights, self.hexagonal)

    @staticmethod
    def winner_from_weights(x, weights, ssum_weights, hexagonal):
        """Look up winners for ``x`` given explicit weights.

        Dispatches to ``_som.get_winners_sparse`` for scipy sparse input and
        to ``_som.get_winners`` otherwise; ``ssum_weights`` is used only in
        the sparse case. ``hexagonal`` is passed on as an int.
        """
        if sp.issparse(x):
            return _som.get_winners_sparse(
                weights, ssum_weights, x, int(hexagonal))
        return _som.get_winners(weights, x, int(hexagonal))
74