import numpy as np
import scipy.sparse as sp

from Orange.projection import _som


class SOM:
    """Self-organizing map trained by the compiled routines in ``_som``.

    After fitting, ``self.weights`` has shape ``(dim_y, dim_x, n_features)``
    (one weight vector per grid node) and ``self.ssum_weights`` has shape
    ``(dim_y, dim_x)``.  ``ssum_weights`` is handed to the sparse ``_som``
    routines alongside the weights; presumably it holds each node's sum of
    squared weights (TODO confirm against ``_som``).
    """

    def __init__(self, dim_x, dim_y,
                 hexagonal=False, pca_init=True, random_seed=None):
        """
        Args:
            dim_x: number of grid columns
            dim_y: number of grid rows
            hexagonal: if true, the ``*_hex`` variants of the ``_som`` update
                routines are used and the flag is passed to winner search
            pca_init: initialize from the two leading principal components
                when the data allows it (dense, >= 2 samples, >= 2 features);
                otherwise random initialization is used
            random_seed: seed for random initialization; ``None`` uses the
                global ``np.random`` state
        """
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.weights = self.ssum_weights = None
        self.hexagonal = hexagonal
        self.pca_init = pca_init
        self.random_seed = random_seed

    def _normalize_weights(self):
        """Scale every node's weight vector to unit Euclidean length.

        Zero vectors are left as they are.  ``ssum_weights`` is reset to ones,
        the squared length of every non-zero node after scaling.
        """
        # Divide by the L2 norm (not its square) so each node ends up with
        # length 1, consistent with ssum_weights == 1.
        norms = np.sqrt(np.sum(self.weights ** 2, axis=2))
        norms[norms == 0] = 1
        self.weights /= norms[:, :, None]
        self.ssum_weights = np.ones((self.dim_y, self.dim_x))

    def init_weights_random(self, x):
        """Initialize weights uniformly in [0, 1) and normalize each node.

        Only the number of columns of `x` (features) is used.
        """
        random = (np.random if self.random_seed is None
                  else np.random.RandomState(self.random_seed))
        self.weights = random.rand(self.dim_y, self.dim_x, x.shape[1])
        self._normalize_weights()

    def init_weights_pca(self, x):
        """Initialize weights from the two leading principal components.

        Node ``(i, j)`` gets ``a_i * pc0 + b_j * pc1``, where ``pc0``/``pc1``
        are the eigenvectors of the covariance matrix with the largest and
        second-largest eigenvalue.  ``a`` and ``b`` are evenly spaced in
        [-1, 1] over the rows and columns of the grid.  Each node vector is
        then normalized to unit length.

        Args:
            x: dense data matrix with samples in rows; needs at least two
                samples and two features
        """
        # The covariance matrix is symmetric: `eigh` returns real eigenvalues
        # in ascending order and unit eigenvectors as *columns*, so the two
        # leading components are the last two columns.
        _, eigvecs = np.linalg.eigh(np.cov(x.T))
        pc0, pc1 = eigvecs[:, -1], eigvecs[:, -2]
        row_coefs = np.linspace(-1, 1, self.dim_y)
        col_coefs = np.linspace(-1, 1, self.dim_x)
        # Broadcast to shape (dim_y, dim_x, n_features):
        # weights[i, j] = row_coefs[i] * pc0 + col_coefs[j] * pc1
        self.weights = (row_coefs[:, None, None] * pc0
                        + col_coefs[None, :, None] * pc1)
        self._normalize_weights()

    def fit(self, x, n_iterations, learning_rate=0.5, sigma=1.0,
            callback=None):
        """Initialize the weights and train the map on `x`.

        Args:
            x: data with samples in rows; a dense numpy array or a scipy
                sparse matrix (sparse input uses the ``*_sparse`` routines)
            n_iterations: number of calls to the ``_som`` update routine,
                each of which receives the whole of `x`.  At iteration ``k``
                both `sigma` and `learning_rate` are divided by
                ``1 + k / (n_iterations / 2)``.
            learning_rate: learning rate at the first iteration
            sigma: value of the ``_som`` `sigma` argument at the first
                iteration (presumably the neighbourhood width)
            callback: optional callable, called after every iteration with
                ``iteration / n_iterations``; training stops early when it
                returns a falsy value
        """
        sparse = sp.issparse(x)
        if sparse:
            update_fn = (_som.update_sparse_hex if self.hexagonal
                         else _som.update_sparse)

            def update(decay):
                # Reads self.weights at call time, i.e. after initialization.
                update_fn(self.weights, self.ssum_weights, x,
                          sigma / decay, learning_rate / decay)
        else:
            update_fn = _som.update_hex if self.hexagonal else _som.update

            def update(decay):
                update_fn(self.weights, x,
                          sigma / decay, learning_rate / decay)

        # PCA needs dense data, at least two samples (the covariance of a
        # single sample is undefined and the eigen-solver would fail) and at
        # least two features (two components are taken).
        if (self.pca_init and not sparse
                and x.shape[0] > 1 and x.shape[1] > 1):
            self.init_weights_pca(x)
        else:
            self.init_weights_random(x)

        for iteration in range(n_iterations):
            update(1 + iteration / (n_iterations / 2))
            if callback is not None and not callback(iteration / n_iterations):
                break

    def winners(self, x):
        """Run ``_som`` winner search for the rows of `x` using this map's
        current weights; see ``winner_from_weights``."""
        return self.winner_from_weights(
            x, self.weights, self.ssum_weights, self.hexagonal)

    @staticmethod
    def winner_from_weights(x, weights, ssum_weights, hexagonal):
        """Run ``_som`` winner search for `x` against the given weights.

        Sparse `x` is dispatched to ``_som.get_winners_sparse`` (which also
        takes `ssum_weights`); dense `x` to ``_som.get_winners``.  `hexagonal`
        is passed as an int flag.  The result is whatever the ``_som``
        routine returns (presumably best-matching node coordinates per row —
        TODO confirm).
        """
        if sp.issparse(x):
            return _som.get_winners_sparse(
                weights, ssum_weights, x, int(hexagonal))
        return _som.get_winners(weights, x, int(hexagonal))