1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# pylint: skip-file
19from __future__ import print_function
20
21import os
22import logging
23import numpy as np
24from sklearn.cluster import KMeans
25from scipy.spatial.distance import cdist
26import mxnet as mx
27import data
28import model
29from autoencoder import AutoEncoderModel
30from solver import Solver, Monitor
31
32
def cluster_acc(Y_pred, Y):
    """Compute unsupervised clustering accuracy.

    Cluster labels are permutation-invariant, so the predicted label ids are
    matched to the ground-truth ids with an optimal one-to-one assignment
    (Hungarian algorithm) before counting agreements.

    Parameters
    ----------
    Y_pred : ndarray of int
        Predicted cluster assignment per sample.
    Y : ndarray
        Ground-truth labels per sample (cast to int internally).

    Returns
    -------
    (float, ndarray)
        Best-match accuracy in [0, 1], and the confusion matrix ``w`` where
        ``w[i, j]`` counts samples predicted ``i`` with true label ``j``.
    """
    # sklearn.utils.linear_assignment_ was removed in scikit-learn 0.23;
    # scipy.optimize.linear_sum_assignment is the supported replacement.
    from scipy.optimize import linear_sum_assignment
    assert Y_pred.size == Y.size
    D = max(Y_pred.max(), Y.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(Y_pred.size):
        w[Y_pred[i], int(Y[i])] += 1
    # Maximize matched counts == minimize (max - w) as an assignment cost.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / Y_pred.size, w
42
43
class DECModel(model.MXModel):
    """Deep Embedded Clustering (DEC) on top of a pretrained autoencoder.

    An autoencoder is (pre)trained to embed the data, then the embedded points
    are clustered by alternating between (a) computing soft assignments ``q``
    with a Student's-t kernel around learnable centroids and (b) minimizing the
    KL divergence to a sharpened target distribution ``p``.
    """

    class DECLoss(mx.operator.NumpyOp):
        """Custom numpy operator producing the soft assignments ``q``.

        ``need_top_grad=False``: this op is a loss head, so ``backward``
        computes gradients directly from ``p`` and ``q`` without a top grad.
        """

        def __init__(self, num_centers, alpha):
            super(DECModel.DECLoss, self).__init__(need_top_grad=False)
            self.num_centers = num_centers  # number of cluster centroids
            self.alpha = alpha  # degrees of freedom of the Student's-t kernel

        def forward(self, in_data, out_data):
            # q[i, j] ~ (1 + ||z_i - mu_j||^2 / alpha)^(-(alpha+1)/2),
            # row-normalized so each row of q is a distribution over centers.
            z = in_data[0]   # embedded batch, shape (batch, feat)
            mu = in_data[1]  # centroids, shape (num_centers, feat)
            q = out_data[0]
            # self.mask is cached for reuse in backward(), which therefore
            # relies on forward() having run on the same batch first.
            self.mask = 1.0/(1.0+cdist(z, mu)**2/self.alpha)
            q[:] = self.mask**((self.alpha+1.0)/2.0)
            q[:] = (q.T/q.sum(axis=1)).T

        def backward(self, out_grad, in_data, out_data, in_grad):
            # Gradients of KL(p || q) w.r.t. the embeddings z and centroids mu.
            q = out_data[0]
            z = in_data[0]
            mu = in_data[1]
            p = in_data[2]   # target distribution, fed through the 'label' slot
            dz = in_grad[0]
            dmu = in_grad[1]
            self.mask *= (self.alpha+1.0)/self.alpha*(p-q)
            dz[:] = (z.T*self.mask.sum(axis=1)).T - self.mask.dot(mu)
            dmu[:] = (mu.T*self.mask.sum(axis=0)).T - self.mask.T.dot(z)

        def infer_shape(self, in_shape):
            # Inputs: data (batch, feat), mu, label; output q is (batch, centers).
            assert len(in_shape) == 3
            assert len(in_shape[0]) == 2
            input_shape = in_shape[0]
            label_shape = (input_shape[0], self.num_centers)
            mu_shape = (self.num_centers, input_shape[1])
            out_shape = (input_shape[0], self.num_centers)
            return [input_shape, mu_shape, label_shape], [out_shape]

        def list_arguments(self):
            return ['data', 'mu', 'label']

    def setup(self, X, num_centers, alpha, save_to='dec_model'):
        """Build (or load) the autoencoder and attach the DEC loss head.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Full dataset; the first 90% is used for autoencoder training.
        num_centers : int
            Number of clusters.
        alpha : float
            Student's-t degrees of freedom for the soft-assignment kernel.
        save_to : str
            Path prefix for caching the pretrained autoencoder weights.
        """
        sep = X.shape[0]*9//10
        X_train = X[:sep]
        X_val = X[sep:]
        ae_model = AutoEncoderModel(self.xpu, [X.shape[1], 500, 500, 2000, 10], pt_dropout=0.2)
        # Pretrain only once; later runs reuse the cached '_pt.arg' weights.
        if not os.path.exists(save_to+'_pt.arg'):
            ae_model.layerwise_pretrain(X_train, 256, 50000, 'sgd', l_rate=0.1, decay=0.0,
                                        lr_scheduler=mx.lr_scheduler.FactorScheduler(20000, 0.1))
            ae_model.finetune(X_train, 256, 100000, 'sgd', l_rate=0.1, decay=0.0,
                              lr_scheduler=mx.lr_scheduler.FactorScheduler(20000, 0.1))
            ae_model.save(save_to+'_pt.arg')
            logging.log(logging.INFO, "Autoencoder Training error: %f"%ae_model.eval(X_train))
            logging.log(logging.INFO, "Autoencoder Validation error: %f"%ae_model.eval(X_val))
        else:
            ae_model.load(save_to+'_pt.arg')
        self.ae_model = ae_model

        self.dec_op = DECModel.DECLoss(num_centers, alpha)
        label = mx.sym.Variable('label')
        self.feature = self.ae_model.encoder
        self.loss = self.dec_op(data=self.ae_model.encoder, label=label, name='dec')
        # Keep only the encoder's weights; the decoder is no longer needed.
        self.args.update({k: v for k, v in self.ae_model.args.items() if k in self.ae_model.encoder.list_arguments()})
        self.args['dec_mu'] = mx.nd.empty((num_centers, self.ae_model.dims[-1]), ctx=self.xpu)
        self.args_grad.update({k: mx.nd.empty(v.shape, ctx=self.xpu) for k, v in self.args.items()})
        # Biases get a 2x learning-rate multiplier, weights 1x.
        self.args_mult.update({k: k.endswith('bias') and 2.0 or 1.0 for k in self.args})
        self.num_centers = num_centers

    def cluster(self, X, y=None, update_interval=None):
        """Run DEC optimization on X; returns clustering accuracy if y given.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Data to cluster.
        y : ndarray or None
            Optional ground-truth labels, used only for progress reporting and
            the returned accuracy.
        update_interval : int or None
            How often (in iterations) to refresh the target distribution p;
            defaults to one full pass over the data.

        Returns
        -------
        float
            ``cluster_acc`` against ``y`` if provided, else -1. Predicted
            assignments are left in ``self.y_pred``.
        """
        N = X.shape[0]
        if not update_interval:
            update_interval = N
        batch_size = 256
        test_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=False,
                                      last_batch_handle='pad')
        args = {k: mx.nd.array(v.asnumpy(), ctx=self.xpu) for k, v in self.args.items()}
        z = list(model.extract_feature(self.feature, args, None, test_iter, N, self.xpu).values())[0]
        # Initialize centroids with k-means on the embedded data.
        kmeans = KMeans(self.num_centers, n_init=20)
        kmeans.fit(z)
        args['dec_mu'][:] = kmeans.cluster_centers_
        solver = Solver('sgd', momentum=0.9, wd=0.0, learning_rate=0.01)

        def ce(label, pred):
            # KL(p || q) up to a constant; epsilon guards against log(0).
            return np.sum(label*np.log(label/(pred+0.000001)))/label.shape[0]
        solver.set_metric(mx.metric.CustomMetric(ce))

        label_buff = np.zeros((X.shape[0], self.num_centers))
        train_iter = mx.io.NDArrayIter({'data': X}, {'label': label_buff}, batch_size=batch_size,
                                       shuffle=False, last_batch_handle='roll_over')
        self.y_pred = np.zeros((X.shape[0]))

        def refresh(i):
            # Periodically recompute q over the whole dataset, derive the
            # sharpened target p, and stop when assignments have converged.
            if i%update_interval == 0:
                z = list(model.extract_feature(self.feature, args, None, test_iter, N, self.xpu).values())[0]
                p = np.zeros((z.shape[0], self.num_centers))
                self.dec_op.forward([z, args['dec_mu'].asnumpy()], [p])
                y_pred = p.argmax(axis=1)
                print(np.std(np.bincount(y_pred)), np.bincount(y_pred))
                if y is not None:
                    # BUG FIX: the ground-truth stats were previously printed
                    # before this guard, crashing when y is None. np.int was
                    # also replaced by int (removed in NumPy 1.24).
                    print(np.std(np.bincount(y.astype(int))), np.bincount(y.astype(int)))
                    print(cluster_acc(y_pred, y)[0])
                # Sharpen q into the target p, reweighting by inverse cluster
                # frequency to discourage degenerate clusters.
                weight = 1.0/p.sum(axis=0)
                weight *= self.num_centers/weight.sum()
                p = (p**2)*weight
                train_iter.data_list[1][:] = (p.T/p.sum(axis=1)).T
                print(np.sum(y_pred != self.y_pred), 0.001*y_pred.shape[0])
                # Converged: fewer than 0.1% of points changed assignment.
                if np.sum(y_pred != self.y_pred) < 0.001*y_pred.shape[0]:
                    self.y_pred = y_pred
                    return True
                self.y_pred = y_pred
        solver.set_iter_start_callback(refresh)
        solver.set_monitor(Monitor(50))

        solver.solve(self.xpu, self.loss, args, self.args_grad, None,
                     train_iter, 0, 1000000000, {}, False)
        self.end_args = args
        if y is not None:
            return cluster_acc(self.y_pred, y)[0]
        else:
            return -1
161
162
def mnist_exp(xpu):
    """Run the DEC MNIST experiment over a sweep of update intervals.

    Trains/loads a DEC model, then clusters MNIST once per update interval in
    {10, 20, 40, ..., 2560}, logging the accuracy of each run and the best.

    Parameters
    ----------
    xpu : mx.Context
        Device (cpu/gpu context) to run on.
    """
    X, Y = data.get_mnist()
    if not os.path.isdir('data'):
        os.makedirs('data')
    dec_model = DECModel(xpu, X, 10, 1.0, 'data/mnist')
    intervals = [10 * (2 ** j) for j in range(9)]
    acc = []
    for interval in intervals:
        result = dec_model.cluster(X, Y, interval)
        acc.append(result)
        logging.log(logging.INFO, 'Clustering Acc: %f at update interval: %d'%(result, interval))
    logging.info(str(acc))
    best = int(np.argmax(acc))
    logging.info('Best Clustering ACC: %f at update_interval: %d'%(np.max(acc), intervals[best]))
174
175
if __name__ == '__main__':
    # Emit INFO-level progress messages from the training/clustering loops.
    logging.basicConfig(level=logging.INFO)
    # Runs the full MNIST experiment on GPU 0; swap in mx.cpu() when no GPU
    # is available.
    mnist_exp(mx.gpu(0))
179