#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import sys
sys.path.insert(0, "../../python/")
import mxnet as mx
import numpy as np
import numpy.random as rnd
import copy

from mxnet.test_utils import assert_almost_equal

# NOTE: the test functions below read the module-level names ``keys``,
# ``shapes``, ``lr``, ``nworker``, ``nrepeat``, ``data`` and
# ``gc_init_test_key``.  Those are only defined when this file is executed
# as a script (see the ``__main__`` section at the bottom).


def check_diff_to_scalar(A, x, rank=None):
    """Assert that every element of ``A`` is exactly equal to ``x``.

    Parameters
    ----------
    A : object supporting ``A - x`` and ``.asnumpy()``
        Array under test.
    x : scalar
        Expected value of every element.
    rank : optional
        Only included in the assertion message.
    """
    total_abs_diff = np.sum(np.abs((A - x).asnumpy()))
    assert total_abs_diff == 0, (rank, A.asnumpy(), x)


def compute_expected_2bit_quantization(arr, curr_residual, threshold):
    """Simulate 2-bit quantization of ``arr`` element by element on the host.

    For each element, the matching entry of ``curr_residual`` is added and
    the sum is encoded with two bits:

    * ``'11'`` if the sum is ``>= threshold``  (dequantized as ``threshold``)
    * ``'10'`` if the sum is ``<= -threshold`` (dequantized as ``-threshold``)
    * ``'00'`` otherwise                        (dequantized as ``0``)

    The new residual of an element is the sum minus its dequantized value.

    Parameters
    ----------
    arr : object with ``.asnumpy()`` and ``.shape``
        Values to quantize.
    curr_residual : np.ndarray
        Residual carried over from earlier calls; indexed like ``arr``.
    threshold : float
        Quantization threshold.

    Returns
    -------
    tuple of np.ndarray
        ``(compr, new_residual, decompr)``: the packed bits reinterpreted as
        float32 values (one per 32 bits), the updated residual reshaped to
        ``arr.shape``, and the dequantized values reshaped to ``arr.shape``.
    """
    from struct import pack, unpack

    def as_float32(bits):
        # Parse the bit string (most significant bit first) as an unsigned
        # integer and reinterpret those 32 bits as a float.
        return unpack("f", pack("I", int(bits, 2)))[0]

    # Collect the 2-bit codes in a list and join once instead of growing a
    # string with ``+=`` for every element.
    codes = []
    new_residual = []
    decompr = []

    for idx, elem in np.ndenumerate(arr.asnumpy()):
        value = elem + curr_residual[idx]
        if value >= threshold:
            codes.append('11')
            new_residual.append(value - threshold)
            decompr.append(threshold)
        elif value <= -threshold:
            codes.append('10')
            new_residual.append(value + threshold)
            decompr.append(-threshold)
        else:
            codes.append('00')
            new_residual.append(value)
            decompr.append(0)

    # str_quant stores the quantized representation as a sequence of bits
    str_quant = ''.join(codes)

    # Pad with zero bits up to a whole 32-bit word (16 elements per word),
    # which is the unit consumed below.  The padding bits only ever land in
    # the high-order bytes of the last word, so they do not change its value.
    leftover = len(str_quant) % 32
    if leftover:
        str_quant += '0' * (32 - leftover)

    # Convert the bit string into floats, 32 characters at a time.  The four
    # 8-bit groups of each word are reversed before parsing.
    compr = []
    for start in range(0, len(str_quant), 32):
        word = str_quant[start:start + 32]
        cur_float = word[24:32] + word[16:24] + word[8:16] + word[0:8]
        compr.append(as_float32(cur_float))

    return (np.array(compr),
            np.array(new_residual).reshape(arr.shape),
            np.array(decompr).reshape(arr.shape))


## individual key interface
def test_kvstore(kv_type, stype):
    """Push/pull every key individually and compare with a NumPy reference.

    Parameters
    ----------
    kv_type : str
        KVStore type passed to ``mx.kv.create``.
    stype : str
        Storage type the pushed arrays are converted to via ``tostype``.
    """
    print(kv_type)
    kv = mx.kv.create(kv_type)
    kv.set_optimizer(mx.optimizer.create('test', rescale_grad=lr))
    for key, shape in zip(keys, shapes):
        kv.init(key, mx.nd.zeros(shape))

    # Host-side reference: accumulated sum over workers, scaled by ``lr``.
    res = [np.zeros(shape) for shape in shapes]
    for i in range(nrepeat):
        step_data = data[i]
        for j, key in enumerate(keys):
            pushed = [mx.nd.array(step_data[j][g], mx.gpu(g)).tostype(stype)
                      for g in range(nworker)]
            kv.push(key, pushed)

        res = [acc + sum(per_worker) * lr
               for acc, per_worker in zip(res, step_data)]
        for j, key in enumerate(keys):
            out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
            kv.pull(key, out=out)
            abs_errs = [np.sum(np.abs(o.asnumpy() - res[j])) for o in out]
            rel_err = sum(abs_errs) / np.sum(np.abs(res[j]))
            assert rel_err < 1e-6, (rel_err, shapes[j])


def test_compress_kvstore(kv_type, compression='2bit', threshold=0.5):
    """Exercise gradient compression on a KVStore.

    Runs, in order: a check that ``init`` values are pulled back unchanged,
    a pull before any push, pushes of zeros, a residual-accumulation
    sequence, repeated negative pushes, and a random-gradient comparison
    against ``compute_expected_2bit_quantization``.

    Parameters
    ----------
    kv_type : str
        KVStore type passed to ``mx.kv.create``.
    compression : str
        Value for the ``'type'`` entry given to ``set_gradient_compression``.
    threshold : float
        Value for the ``'threshold'`` entry given to
        ``set_gradient_compression``.
    """
    print(kv_type + ' with ' + compression + ' compression')
    rate = 2
    kv = mx.kv.create(kv_type)
    kv.set_gradient_compression({'type':compression, 'threshold':threshold})
    kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate))
    for key, shape in zip(keys, shapes):
        kv.init(key, mx.nd.zeros(shape))
    # init one key with 1s so we can check if it was compressed during init
    kv.init(gc_init_test_key, mx.nd.ones(shapes[0]))
    # use different keys for random tests so that
    # we can track residual from start
    random_keys = [13, 15, 17]
    for key, shape in zip(random_keys, shapes):
        kv.init(key, mx.nd.zeros(shape))

    def zeros_per_device(shape):
        # One zero-filled array per worker GPU.
        return [mx.nd.zeros(shape, mx.gpu(g)) for g in range(nworker)]

    def ones_per_device(shape):
        # One one-filled array per worker GPU.
        return [mx.nd.ones(shape, mx.gpu(g)) for g in range(nworker)]

    def constant_per_device(shape, value):
        # One array per worker GPU with every element equal to ``value``.
        return [mx.nd.ones(shape, mx.gpu(g)) * value for g in range(nworker)]

    def assert_all_close_to(outputs, expected):
        for o in outputs:
            assert_almost_equal(o.asnumpy(), expected)

    def assert_all_equal_scalar(outputs, scalar):
        for o in outputs:
            check_diff_to_scalar(o, scalar)

    def pull_init_test(kv):
        # checks that compression is not applied to init of key
        out = zeros_per_device(shapes[0])
        kv.pull(gc_init_test_key, out=out)
        assert_all_close_to(out, np.ones_like(out[0].asnumpy()))

    def pull_before_push(kv):
        # Nothing has been pushed yet, so every key should still read zeros.
        for _ in range(nrepeat):
            for key, shape in zip(keys, shapes):
                out = ones_per_device(shape)
                kv.pull(key, out=out)
                assert_all_close_to(out, np.zeros_like(out[0].asnumpy()))

    def push_zeros(kv):
        # Pushing zeros must leave the stored value at zero.
        for _ in range(nrepeat):
            for key, shape in zip(keys, shapes):
                kv.push(key, zeros_per_device(shape))
                out = ones_per_device(shape)
                kv.pull(key, out=out)
                assert_all_close_to(out, np.zeros_like(out[0].asnumpy()))

    def verify_residual(kv, threshold, rate):
        # NOTE(review): the literal 0.4 / 0.2 pushes appear tuned for the
        # default threshold of 0.5 -- confirm before using other thresholds.
        for key, shape in zip(keys, shapes):
            # First push stays below the threshold: pulled value expected 0.
            kv.push(key, constant_per_device(shape, 0.4))
            out = zeros_per_device(shape)
            kv.pull(key, out=out)
            assert_all_equal_scalar(out, 0)

            # Second push: expected value is one threshold step per worker.
            kv.push(key, constant_per_device(shape, threshold-0.3))
            out = zeros_per_device(shape)
            kv.pull(key, out=out)
            curval = threshold * rate * nworker
            assert_all_equal_scalar(out, curval)

            # Third push: expected value unchanged.
            kv.push(key, constant_per_device(shape, 0.2))
            out = zeros_per_device(shape)
            kv.pull(key, out=out)
            assert_all_equal_scalar(out, curval)

            # Fourth push: expected value grows by another threshold step.
            kv.push(key, constant_per_device(shape, threshold-0.3))
            out = zeros_per_device(shape)
            kv.pull(key, out=out)
            curval += threshold*rate*nworker
            assert_all_equal_scalar(out, curval)
            # residual would be 0 now
        return curval

    def check_neg(kv, neg, rate, curval):
        for _ in range(nrepeat):
            curval = curval + rate*nworker*neg
            for key, shape in zip(keys, shapes):
                kv.push(key, constant_per_device(shape, neg))
                out = ones_per_device(shape)
                kv.pull(key, out=out)
                assert_all_equal_scalar(out, curval)
                # residual would be 0 again

    def check_compr_random(kv, threshold):
        for key, shape in zip(random_keys, shapes):
            curr_residual = [np.zeros(shape) for _ in range(nworker)]
            orig_val = zeros_per_device(shape)
            kv.pull(key, out=orig_val)
            grads = [mx.nd.random_uniform(-0.6, 0.6, shape=shape, ctx=mx.gpu(g))
                     for g in range(nworker)]
            # Keep a copy of the gradients for the host-side simulation.
            grads_cpy = copy.deepcopy(grads)
            kv.push(key, grads)
            val = zeros_per_device(shape)
            kv.pull(key, out=val)
            diffs = [val[g] - orig_val[g] for g in range(nworker)]
            # compute expected by using simulation of operator
            # on cpu
            sum_dequantized_vals = np.zeros(shape)
            for g in range(nworker):
                _, curr_residual[g], decompr = compute_expected_2bit_quantization(
                    grads_cpy[g], curr_residual[g], threshold)
                sum_dequantized_vals += (decompr * rate)

            for g in range(nworker):
                assert_almost_equal(diffs[g].asnumpy(), sum_dequantized_vals)

    pull_init_test(kv)
    pull_before_push(kv)
    push_zeros(kv)
    curval = verify_residual(kv, threshold, rate)
    check_neg(kv, -1*threshold, rate, curval)
    check_compr_random(kv, threshold)


## group keys interface
def test_group_kvstore(kv_type, stype):
    """Push/pull all keys in one call and compare with a NumPy reference.

    Parameters
    ----------
    kv_type : str
        KVStore type passed to ``mx.kv.create``.
    stype : str
        Storage type the pushed arrays are converted to via ``tostype``.
    """
    print(kv_type)
    kv = mx.kv.create(kv_type)
    kv.set_optimizer(mx.optimizer.create('test', rescale_grad=lr))
    kv.init(keys, [mx.nd.zeros(shape) for shape in shapes])
    res = [np.zeros(shape) for shape in shapes]
    out = [[mx.nd.zeros(shape, mx.gpu(g)) for g in range(nworker)]
           for shape in shapes]
    for i in range(nrepeat):
        step_data = data[i]
        pushed = [[mx.nd.array(step_data[j][g], mx.gpu(g)).tostype(stype)
                   for g in range(nworker)]
                  for j in range(len(keys))]
        kv.push(keys, pushed)

        kv.pull(keys, out=out)
        res = [acc + sum(per_worker) * lr
               for acc, per_worker in zip(res, step_data)]
        for expected, per_device in zip(res, out):
            abs_errs = [np.sum(np.abs(o.asnumpy() - expected))
                        for o in per_device]
            rel_err = sum(abs_errs) / np.sum(np.abs(expected))
            assert rel_err < 1e-6, (rel_err, expected.shape)


if __name__ == "__main__":
    keys = [3, 5, 7]
    # let the last shape exceed MXNET_KVSTORE_BIGARRAY_BOUND
    shapes = [(4, 4), (100, 100), (2000, 2000)]
    stypes = ['default', 'row_sparse']

    gc_init_test_key = 9

    lr = .1
    nworker = 4
    nrepeat = 10

    # generate data
    data = [[[np.random.random(s)*2-1 for i in range(nworker)] for s in shapes] for j in range(nrepeat)]

    for stype in stypes:
        test_kvstore('local_update_cpu', stype)
        test_kvstore('local_allreduce_cpu', stype)
        test_kvstore('local_allreduce_device', stype)

    ## compression for local kvstore happens only when reduce is on device
    test_compress_kvstore('local_allreduce_device')
    for stype in stypes:
        test_group_kvstore('local_update_cpu', stype)
        test_group_kvstore('local_allreduce_cpu', stype)
        test_group_kvstore('local_allreduce_device', stype)