1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17 18# pylint: skip-file 19from __future__ import print_function 20 21import os 22import logging 23import numpy as np 24from sklearn.cluster import KMeans 25from scipy.spatial.distance import cdist 26import mxnet as mx 27import data 28import model 29from autoencoder import AutoEncoderModel 30from solver import Solver, Monitor 31 32 33def cluster_acc(Y_pred, Y): 34 from sklearn.utils.linear_assignment_ import linear_assignment 35 assert Y_pred.size == Y.size 36 D = max(Y_pred.max(), Y.max())+1 37 w = np.zeros((D, D), dtype=np.int64) 38 for i in range(Y_pred.size): 39 w[Y_pred[i], int(Y[i])] += 1 40 ind = linear_assignment(w.max() - w) 41 return sum([w[i, j] for i, j in ind])*1.0/Y_pred.size, w 42 43 44class DECModel(model.MXModel): 45 class DECLoss(mx.operator.NumpyOp): 46 def __init__(self, num_centers, alpha): 47 super(DECModel.DECLoss, self).__init__(need_top_grad=False) 48 self.num_centers = num_centers 49 self.alpha = alpha 50 51 def forward(self, in_data, out_data): 52 z = in_data[0] 53 mu = in_data[1] 54 q = out_data[0] 55 self.mask = 1.0/(1.0+cdist(z, mu)**2/self.alpha) 56 q[:] = self.mask**((self.alpha+1.0)/2.0) 57 q[:] = (q.T/q.sum(axis=1)).T 58 59 def backward(self, out_grad, in_data, out_data, in_grad): 60 q = out_data[0] 61 z = in_data[0] 62 mu = in_data[1] 63 p = in_data[2] 64 dz = in_grad[0] 65 dmu = in_grad[1] 66 self.mask *= (self.alpha+1.0)/self.alpha*(p-q) 67 dz[:] = (z.T*self.mask.sum(axis=1)).T - self.mask.dot(mu) 68 dmu[:] = (mu.T*self.mask.sum(axis=0)).T - self.mask.T.dot(z) 69 70 def infer_shape(self, in_shape): 71 assert len(in_shape) == 3 72 assert len(in_shape[0]) == 2 73 input_shape = in_shape[0] 74 label_shape = (input_shape[0], self.num_centers) 75 mu_shape = (self.num_centers, input_shape[1]) 76 out_shape = (input_shape[0], self.num_centers) 77 return [input_shape, mu_shape, label_shape], [out_shape] 78 79 def list_arguments(self): 80 return ['data', 'mu', 'label'] 81 82 def setup(self, X, num_centers, alpha, save_to='dec_model'): 83 sep = X.shape[0]*9//10 84 X_train = X[:sep] 85 X_val = X[sep:] 86 ae_model = AutoEncoderModel(self.xpu, [X.shape[1], 500, 500, 2000, 10], pt_dropout=0.2) 87 if not os.path.exists(save_to+'_pt.arg'): 88 ae_model.layerwise_pretrain(X_train, 256, 50000, 'sgd', l_rate=0.1, decay=0.0, 89 lr_scheduler=mx.lr_scheduler.FactorScheduler(20000, 0.1)) 90 ae_model.finetune(X_train, 256, 100000, 'sgd', l_rate=0.1, decay=0.0, 91 lr_scheduler=mx.lr_scheduler.FactorScheduler(20000, 0.1)) 92 ae_model.save(save_to+'_pt.arg') 93 logging.log(logging.INFO, "Autoencoder Training error: %f"%ae_model.eval(X_train)) 94 logging.log(logging.INFO, "Autoencoder Validation error: %f"%ae_model.eval(X_val)) 95 else: 96 ae_model.load(save_to+'_pt.arg') 97 self.ae_model = ae_model 98 99 self.dec_op = DECModel.DECLoss(num_centers, alpha) 100 label = mx.sym.Variable('label') 101 self.feature = self.ae_model.encoder 102 self.loss = self.dec_op(data=self.ae_model.encoder, label=label, name='dec') 103 self.args.update({k: v for k, v in self.ae_model.args.items() if k in self.ae_model.encoder.list_arguments()}) 104 self.args['dec_mu'] = mx.nd.empty((num_centers, self.ae_model.dims[-1]), ctx=self.xpu) 105 self.args_grad.update({k: mx.nd.empty(v.shape, ctx=self.xpu) for k, v in self.args.items()}) 106 self.args_mult.update({k: k.endswith('bias') and 2.0 or 1.0 for k in self.args}) 107 self.num_centers = num_centers 108 109 def cluster(self, X, y=None, update_interval=None): 110 N = X.shape[0] 111 if not update_interval: 112 update_interval = N 113 batch_size = 256 114 test_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=False, 115 last_batch_handle='pad') 116 args = {k: mx.nd.array(v.asnumpy(), ctx=self.xpu) for k, v in self.args.items()} 117 z = list(model.extract_feature(self.feature, args, None, test_iter, N, self.xpu).values())[0] 118 kmeans = KMeans(self.num_centers, n_init=20) 119 kmeans.fit(z) 120 args['dec_mu'][:] = kmeans.cluster_centers_ 121 solver = Solver('sgd', momentum=0.9, wd=0.0, learning_rate=0.01) 122 123 def ce(label, pred): 124 return np.sum(label*np.log(label/(pred+0.000001)))/label.shape[0] 125 solver.set_metric(mx.metric.CustomMetric(ce)) 126 127 label_buff = np.zeros((X.shape[0], self.num_centers)) 128 train_iter = mx.io.NDArrayIter({'data': X}, {'label': label_buff}, batch_size=batch_size, 129 shuffle=False, last_batch_handle='roll_over') 130 self.y_pred = np.zeros((X.shape[0])) 131 132 def refresh(i): 133 if i%update_interval == 0: 134 z = list(model.extract_feature(self.feature, args, None, test_iter, N, self.xpu).values())[0] 135 p = np.zeros((z.shape[0], self.num_centers)) 136 self.dec_op.forward([z, args['dec_mu'].asnumpy()], [p]) 137 y_pred = p.argmax(axis=1) 138 print(np.std(np.bincount(y_pred)), np.bincount(y_pred)) 139 print(np.std(np.bincount(y.astype(np.int))), np.bincount(y.astype(np.int))) 140 if y is not None: 141 print(cluster_acc(y_pred, y)[0]) 142 weight = 1.0/p.sum(axis=0) 143 weight *= self.num_centers/weight.sum() 144 p = (p**2)*weight 145 train_iter.data_list[1][:] = (p.T/p.sum(axis=1)).T 146 print(np.sum(y_pred != self.y_pred), 0.001*y_pred.shape[0]) 147 if np.sum(y_pred != self.y_pred) < 0.001*y_pred.shape[0]: 148 self.y_pred = y_pred 149 return True 150 self.y_pred = y_pred 151 solver.set_iter_start_callback(refresh) 152 solver.set_monitor(Monitor(50)) 153 154 solver.solve(self.xpu, self.loss, args, self.args_grad, None, 155 train_iter, 0, 1000000000, {}, False) 156 self.end_args = args 157 if y is not None: 158 return cluster_acc(self.y_pred, y)[0] 159 else: 160 return -1 161 162 163def mnist_exp(xpu): 164 X, Y = data.get_mnist() 165 if not os.path.isdir('data'): 166 os.makedirs('data') 167 dec_model = DECModel(xpu, X, 10, 1.0, 'data/mnist') 168 acc = [] 169 for i in [10*(2**j) for j in range(9)]: 170 acc.append(dec_model.cluster(X, Y, i)) 171 logging.log(logging.INFO, 'Clustering Acc: %f at update interval: %d'%(acc[-1], i)) 172 logging.info(str(acc)) 173 logging.info('Best Clustering ACC: %f at update_interval: %d'%(np.max(acc), 10*(2**np.argmax(acc)))) 174 175 176if __name__ == '__main__': 177 logging.basicConfig(level=logging.INFO) 178 mnist_exp(mx.gpu(0)) 179