1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements.  See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership.  The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License.  You may obtain a copy of the License at
11#
12#   http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing,
15# software distributed under the License is distributed on an
16# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17# KIND, either express or implied.  See the License for the
18# specific language governing permissions and limitations
19# under the License.
20
21import sys
22sys.path.insert(0, "../../python/")
23import mxnet as mx
24import numpy as np
25import numpy.random as rnd
26import copy
27
28from mxnet.test_utils import assert_almost_equal
29
def check_diff_to_scalar(A, x, rank=None):
    """Assert that every element of ``A`` equals the scalar ``x`` exactly.

    ``A`` must support ``A - x`` and expose ``asnumpy()`` (callers in this
    file pass mx.nd.NDArray objects). The summed absolute difference must be
    exactly zero. On failure the assertion message carries ``rank``, the
    contents of ``A`` and ``x``.
    """
    difference = (A - x).asnumpy()
    mismatch = np.sum(np.abs(difference))
    assert mismatch == 0, (rank, A.asnumpy(), x)
33
def compute_expected_2bit_quantization(arr, curr_residual, threshold):
    """Simulate 2-bit gradient quantization on the CPU to get reference values.

    Each element of ``arr`` is added to the matching element of
    ``curr_residual``, and the sum ``a`` is quantized as follows:

    * ``a >= threshold``  -> code ``0b11``, dequantized to ``threshold``,
      new residual ``a - threshold``;
    * ``a <= -threshold`` -> code ``0b10``, dequantized to ``-threshold``,
      new residual ``a + threshold``;
    * otherwise           -> code ``0b00``, dequantized to ``0``,
      new residual ``a``.

    The 2-bit codes are laid out in C (row-major) element order, most
    significant bit first, so 16 codes fill one 32-bit word. The final word
    is zero-padded. Within a word, the first byte is the least significant
    one (the bytes are read as a little-endian float32).

    Args:
        arr: array exposing ``asnumpy()`` and ``shape`` (an mx.nd.NDArray in
            this file).
        curr_residual: residual carried from earlier pushes, elementwise
            aligned with ``arr`` (a NumPy array of ``arr.shape`` in this file).
        threshold: quantization threshold.

    Returns:
        Tuple ``(compr, new_residual, decompr)``:

        * ``compr`` is a 1-D float64 array holding each packed 32-bit word
          reinterpreted as a float32 value.
        * ``new_residual`` and ``decompr`` are reshaped to ``arr.shape``.
    """
    values = arr.asnumpy() + np.asarray(curr_residual)
    pos = values >= threshold
    # Mirror the original if/elif: an element that already counts as
    # positive is never also treated as negative.
    neg = (values <= -1 * threshold) & ~pos

    new_residual = np.where(pos, values - threshold,
                            np.where(neg, values + threshold, values))
    decompr = np.where(pos, threshold, np.where(neg, -1 * threshold, 0))

    # Interleave the two bits of each code: the high bit is set for both
    # +threshold and -threshold, the low bit only for +threshold.
    # This yields the codes '11', '10' and '00'.
    flat_pos = pos.ravel()
    bits = np.zeros(2 * flat_pos.size, dtype=np.uint8)
    bits[0::2] = flat_pos | neg.ravel()
    bits[1::2] = flat_pos
    # Zero-pad to a whole number of 32-bit words (16 codes per word).
    bits = np.concatenate([bits, np.zeros((-bits.size) % 32, dtype=np.uint8)])

    # packbits fills each byte MSB-first. Viewing every 4 bytes as '<f4'
    # makes the first byte of a word its least significant byte.
    compr = np.packbits(bits).view('<f4').astype(np.float64)
    return compr, new_residual.reshape(arr.shape), decompr.reshape(arr.shape)
78
79## individual key interface
def test_kvstore(kv_type, stype):
    """Verify single-key push/pull on a kvstore of type ``kv_type``.

    Each repetition pushes, for every key, one gradient per GPU worker
    (converted to storage type ``stype``), then pulls each key back onto all
    workers. The pulled arrays are compared with a NumPy reference that
    accumulates ``lr * sum(worker gradients)`` per repetition. The summed L1
    error over workers, relative to the reference's L1 norm, must stay below
    1e-6.

    Reads the module-level ``keys``, ``shapes``, ``data``, ``lr``,
    ``nworker`` and ``nrepeat`` defined under ``__main__``.
    """
    print(kv_type)
    store = mx.kv.create(kv_type)
    store.set_optimizer(mx.optimizer.create('test', rescale_grad=lr))
    for key, shape in zip(keys, shapes):
        store.init(key, mx.nd.zeros(shape))

    expected = [np.zeros(shape) for shape in shapes]
    for rep in range(nrepeat):
        round_data = data[rep]
        for idx, key in enumerate(keys):
            worker_grads = [mx.nd.array(round_data[idx][dev], mx.gpu(dev)).tostype(stype)
                            for dev in range(nworker)]
            store.push(key, worker_grads)

        expected = [prev + sum(per_worker) * lr
                    for prev, per_worker in zip(expected, round_data)]
        for idx, key in enumerate(keys):
            pulled = [mx.nd.zeros(shapes[idx], mx.gpu(dev)) for dev in range(nworker)]
            store.pull(key, out=pulled)
            reference = expected[idx]
            abs_err = sum(np.sum(np.abs(o.asnumpy() - reference)) for o in pulled)
            rel_err = abs_err / np.sum(np.abs(reference))
            assert rel_err < 1e-6, (rel_err, shapes[idx])
100
def test_compress_kvstore(kv_type, compression='2bit', threshold=0.5):
    """Exercise gradient compression on a kvstore of type ``kv_type``.

    Creates a kvstore with ``set_gradient_compression`` and the 'test'
    optimizer (``rescale_grad=rate``), then runs these helpers in order:

    * ``pull_init_test``: a key initialised with ones pulls back as ones.
    * ``pull_before_push``: freshly initialised keys pull back as zeros.
    * ``push_zeros``: pushing zeros keeps the pulled values at zero.
    * ``verify_residual``: sub-threshold pushes are expected to produce no
      update until their running sum reaches ``threshold``.
    * ``check_neg``: repeated pushes of ``-threshold`` on every worker.
    * ``check_compr_random``: random gradients compared against the CPU
      reference ``compute_expected_2bit_quantization``.

    The helpers rely on the kvstore state left by the previous ones (for
    example, ``verify_residual`` expects stored values of 0 after
    ``push_zeros``), so their call order matters.

    Reads module-level ``keys``, ``shapes``, ``nworker``, ``nrepeat`` and
    ``gc_init_test_key`` defined under ``__main__``.

    NOTE(review): the hard-coded push values in ``verify_residual`` (0.4,
    ``threshold-0.3``, 0.2) appear to assume ``threshold`` close to the
    default 0.5. With ``threshold <= 0.4``, the first 0.4 push would already
    reach the threshold while the test expects 0. Also, ``check_compr_random``
    always uses the 2-bit reference, whatever ``compression`` is.
    """
    print(kv_type + ' with ' + compression + ' compression')
    rate = 2
    kv = mx.kv.create(kv_type)
    kv.set_gradient_compression({'type':compression, 'threshold':threshold})
    kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate))
    for k, s in zip(keys, shapes):
        kv.init(k, mx.nd.zeros(s))
    # init one key with 1s so we can check if it was compressed during init
    kv.init(gc_init_test_key, mx.nd.ones(shapes[0]))
    # use different keys for random tests so that
    # we can track residual from start
    random_keys = [13, 15, 17]
    for k, s in zip(random_keys, shapes):
        kv.init(k, mx.nd.zeros(s))

    def pull_init_test(kv):
        # checks that compression is not applied to init of key
        out = [mx.nd.zeros(shapes[0], mx.gpu(g)) for g in range(nworker)]
        kv.pull(gc_init_test_key, out=out)
        exp = np.ones_like(out[0].asnumpy())
        for o in out:
            assert_almost_equal(o.asnumpy(), exp)

    def pull_before_push(kv):
        # Nothing has been pushed yet, so every key must still pull as zeros.
        # Outputs start as ones so a no-op pull would be detected.
        for i in range(nrepeat):
            for j in range(len(keys)):
                out = [mx.nd.ones(shapes[j], mx.gpu(g)) for g in range(nworker)]
                kv.pull(keys[j], out=out)
                exp = np.zeros_like(out[0].asnumpy())
                for o in out:
                    assert_almost_equal(o.asnumpy(), exp)

    def push_zeros(kv):
        # Pushing all-zero gradients must leave the stored values at zero.
        for i in range(nrepeat):
            for j in range(len(keys)):
                kv.push(keys[j], [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)])
                out = [mx.nd.ones(shapes[j], mx.gpu(g)) for g in range(nworker)]
                kv.pull(keys[j], out=out)
                exp = np.zeros_like(out[0].asnumpy())
                for o in out:
                    assert_almost_equal(o.asnumpy(), exp)

    def verify_residual(kv, threshold, rate):
        # Per key, push four constant gradients on every worker and check the
        # pulled value after each push. Running per-worker totals with the
        # default threshold of 0.5: 0.4 (expect no update), 0.6 (expect one
        # +threshold step), ~0.3 after that step (no update), ~0.5 (a second
        # step). Each step adds threshold*rate*nworker to the stored value.
        # Returns the expected value after the last push (same for every key).
        for j in range(len(keys)):
            kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*0.4 for g in range(nworker)])
            out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
            kv.pull(keys[j],out=out)
            for o in out:
                check_diff_to_scalar(o, 0)

            kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*(threshold-0.3) for g in range(nworker)])
            out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
            kv.pull(keys[j],out=out)
            # Stored value was 0 (checked by push_zeros), so one step lands here.
            curval = threshold * rate * nworker
            for o in out:
                check_diff_to_scalar(o, curval)

            kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*(0.2) for g in range(nworker)])
            out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
            kv.pull(keys[j],out=out)
            # Still below threshold: the stored value must not change.
            for o in out:
                check_diff_to_scalar(o, curval)

            kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*(threshold-0.3) for g in range(nworker)])
            out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
            kv.pull(keys[j],out=out)
            curval += threshold*rate*nworker
            for o in out:
                check_diff_to_scalar(o, curval)
            # residual would be 0 now
        return curval

    def check_neg(kv, neg, rate, curval):
        # Each repetition pushes ``neg`` on every worker for every key. The
        # expected stored value moves by rate*nworker*neg per repetition,
        # starting from the value verify_residual left on every key.
        for r in range(nrepeat):
            curval = curval + rate*nworker*neg
            for j in range(len(keys)):
                kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*neg for g in range(nworker)])
                out = [mx.nd.ones(shapes[j], mx.gpu(g)) for g in range(nworker)]
                kv.pull(keys[j], out=out)
                for o in out:
                    check_diff_to_scalar(o, curval)
            # residual would be 0 again

    def check_compr_random(kv, threshold):
        # For each random key, push uniform(-0.6, 0.6) gradients once. Then
        # check that the change in the pulled value equals the sum over
        # workers of rate * (reference dequantized gradient).
        for k, s in zip(random_keys, shapes):
            # These keys were only initialised, never pushed, so the
            # reference residual starts at zero for every worker.
            curr_residual = [np.zeros(s) for g in range(nworker)]
            orig_val = [mx.nd.zeros(s, mx.gpu(g)) for g in range(nworker)]
            kv.pull(k, out=orig_val)
            grads = [mx.nd.random_uniform(-0.6, 0.6, shape=s, ctx=mx.gpu(g)) for g in range(nworker)]
            # Copy taken before push; presumably guards against push
            # modifying ``grads`` in place -- TODO confirm.
            grads_cpy = copy.deepcopy(grads)
            kv.push(k, grads)
            val = [mx.nd.zeros(s, mx.gpu(g)) for g in range(nworker)]
            kv.pull(k, out=val)
            diffs = [val[g] - orig_val[g] for g in range(nworker)]
            # compute expected by using simulation of operator
            # on cpu
            sum_dequantized_vals = np.zeros(s)
            for g in range(nworker):
                # ``compr`` (the packed words) is not checked here.
                compr, curr_residual[g], decompr = compute_expected_2bit_quantization(
                                                    grads_cpy[g], curr_residual[g], threshold)
                sum_dequantized_vals += (decompr * rate)

            for g in range(nworker):
                assert_almost_equal(diffs[g].asnumpy(), sum_dequantized_vals)

    # Order matters: each step assumes the kvstore state left by the previous one.
    pull_init_test(kv)
    pull_before_push(kv)
    push_zeros(kv)
    curval = verify_residual(kv, threshold, rate)
    check_neg(kv, -1*threshold, rate, curval)
    check_compr_random(kv, threshold)
213
214## group keys interface
def test_group_kvstore(kv_type, stype):
    """Verify the grouped (list-of-keys) push/pull interface of a kvstore.

    All keys are initialised, pushed and pulled in single calls. Every
    repetition pushes one gradient per GPU worker and key (converted to
    storage type ``stype``). The pulled arrays are compared with a NumPy
    reference that accumulates ``lr * sum(worker gradients)``. The summed L1
    error over workers, relative to the reference's L1 norm, must stay below
    1e-6.

    Reads the module-level ``keys``, ``shapes``, ``data``, ``lr``,
    ``nworker`` and ``nrepeat`` defined under ``__main__``.
    """
    print(kv_type)
    store = mx.kv.create(kv_type)
    store.set_optimizer(mx.optimizer.create('test', rescale_grad=lr))
    store.init(keys, [mx.nd.zeros(shape) for shape in shapes])
    expected = [np.zeros(shape) for shape in shapes]
    # Output buffers are allocated once and reused for every pull.
    pulled = [[mx.nd.zeros(shape, mx.gpu(dev)) for dev in range(nworker)]
              for shape in shapes]
    for rep in range(nrepeat):
        round_data = data[rep]
        grads = [[mx.nd.array(round_data[idx][dev], mx.gpu(dev)).tostype(stype)
                  for dev in range(nworker)]
                 for idx in range(len(keys))]
        store.push(keys, grads)

        store.pull(keys, out=pulled)
        expected = [prev + sum(per_worker) * lr
                    for prev, per_worker in zip(expected, round_data)]
        for reference, outs in zip(expected, pulled):
            abs_err = sum(np.sum(np.abs(o.asnumpy() - reference)) for o in outs)
            rel_err = abs_err / np.sum(np.abs(reference))
            assert rel_err < 1e-6, (rel_err, reference.shape)
233
if __name__ == "__main__":
    # Module-level configuration read as globals by the test_* functions.
    keys = [3, 5, 7]
    # The last shape is intentionally large enough to exceed
    # MXNET_KVSTORE_BIGARRAY_BOUND.
    shapes = [(4, 4), (100, 100), (2000, 2000)]
    stypes = ['default', 'row_sparse']

    gc_init_test_key = 9

    lr = .1
    nworker = 4
    nrepeat = 10

    # data[rep][key_index][worker] is a random array in [-1, 1) with that key's shape.
    data = [[[np.random.random(shape) * 2 - 1 for _worker in range(nworker)]
             for shape in shapes]
            for _rep in range(nrepeat)]

    local_kv_types = ('local_update_cpu', 'local_allreduce_cpu', 'local_allreduce_device')
    for stype in stypes:
        for kv_type in local_kv_types:
            test_kvstore(kv_type, stype)

    # For a local kvstore, compression only happens when the reduce runs on the device.
    test_compress_kvstore('local_allreduce_device')
    for stype in stypes:
        for kv_type in local_kv_types:
            test_group_kvstore(kv_type, stype)
260