# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import Base.push!

"""
    KVStore(kv_type = :local)

Create a key-value store of the given `kv_type`.

Two types are commonly used for single machine training:

- `local`: Copies all gradients to CPU memory and updates weights there.

- `device`: Aggregates gradients and updates weights on GPU(s).
  With this setting, the `KVStore` also attempts to use GPU peer-to-peer
  communication, potentially accelerating the communication.

Several further types are available for distributed training:

- `dist_sync`: Behaves similarly to `local` but with one major difference.
  With `dist_sync`, batch-size now means the batch size used on each machine.
  So if there are `n` machines and we use batch size ``b``,
  then `dist_sync` behaves like `local` with batch size `n * b`.

- `dist_device_sync`: Identical to `dist_sync` with the difference similar
  to `device` vs `local`.

- `dist_async`: Performs asynchronous updates.
  The weights are updated whenever gradients are received from any machine.
  No two updates happen on the same weight at the same time.
  However, the order is not guaranteed.
"""
mutable struct KVStore
    handle    :: MX_KVStoreHandle  # handle wrapping the native store
    updater_c :: Ptr{Cvoid}        # C callback pointer; null until `setupdater!` assigns it
    updater   :: Function          # left undefined by the inner constructor

    # Only `handle` and `updater_c` are initialized here.
    KVStore(hdr::MX_KVStoreHandle) = new(hdr, C_NULL)
end

function KVStore(kv_type::Symbol = :local)
    @assert kv_type ∈ (:local, :device, :dist_sync, :dist_device_sync, :dist_async)
    handle_ref = Ref{MX_handle}(0)
    @mxcall(:MXKVStoreCreate, (char_p, Ref{MX_handle}),
            dump_mx_param(kv_type), handle_ref)
    return KVStore(MX_KVStoreHandle(handle_ref[]))
end

# Conversions that allow a `KVStore` to be passed where an `MX_handle` is expected;
# all of them forward to the conversion of the wrapped `handle` field.
function Base.unsafe_convert(::Type{MX_handle}, obj::KVStore)
    return Base.unsafe_convert(MX_handle, obj.handle)
end
Base.convert(t::Type{MX_handle}, obj::KVStore) = Base.unsafe_convert(t, obj)
Base.cconvert(t::Type{MX_handle}, obj::KVStore) = Base.unsafe_convert(t, obj)

function Base.show(io::IO, kv::KVStore)
    print(io, "mx.KVStore @ ", get_type(kv))
end

# Flatten `keys[i] => vals[i]` (a vector of arrays per key) into two parallel vectors
# in which each key is repeated once for every array that belongs to it.
function _flatten_kvlist(keys::Vector{Int}, vals::Vector{<:Vector{<:NDArray}})
    @assert length(keys) == length(vals)
    flat_keys = Int[]
    flat_vals = NDArray[]
    for (key, group) in zip(keys, vals)
        append!(flat_keys, fill(key, length(group)))
        append!(flat_vals, group)
    end
    return (flat_keys, flat_vals)
end

"""
    init!(kv::KVStore, key::Int, val::NDArray)
    init!(kv::KVStore, keys, vals)

Initializes a single or a sequence of key-value pairs into the store.

For each key, one must `init!` it before calling `push!` or `pull!`.
When multiple workers invoke `init!` for the same key, only
the value supplied by worker with rank `0` is used. This function returns
after data has been initialized successfully.

```jldoctest
julia> kv = KVStore(:local)
mx.KVStore @ local

julia> init!(kv, 42, mx.rand(2, 3))
```
"""
function init!(kv::KVStore, key::Int, val::NDArray)
    return init!(kv, [key], [val])
end

# One key shared by several arrays: repeat the key once per array.
function init!(kv::KVStore, key::Int, vals::Vector{<:NDArray})
    return init!(kv, fill(key, length(vals)), vals)
end

# Several arrays per key: flatten to parallel key/value vectors first.
function init!(kv::KVStore, keys::Vector{Int}, vals::Vector{<:Vector{<:NDArray}})
    return init!(kv, _flatten_kvlist(keys, vals)...)
end

# Core `init!` method: converts keys and values to the C representation and
# forwards them to `MXKVStoreInit`.
function init!(kv::KVStore, keys::Vector{Int}, vals::VecOfNDArray)
    @assert length(keys) == length(vals)
    # Fresh locals (instead of rebinding the arguments) keep every variable at a
    # single concrete type and avoid splatting arbitrarily long vectors.
    ckeys   = convert(Vector{Cint}, keys)
    handles = MX_handle[v for v in vals]
    @mxcall(:MXKVStoreInit, (MX_handle, MX_uint, Ptr{Cint}, Ptr{MX_handle}),
            kv, length(ckeys), ckeys, handles)
end

"""
    push!(kv::KVStore, key,  val;  priority = 0)
    push!(kv::KVStore, key,  vals; priority = 0)
    push!(kv::KVStore, keys, vals; priority = 0)

Pushes a single or a sequence of key-value pairs into the store.

This function returns immediately after adding an operator to the engine.
The actual operation is executed asynchronously. If there are consecutive
pushes to the same key, there is no guarantee on the serialization of pushes.
The execution of a push does not guarantee that all previous pushes are
finished. There is no synchronization between workers by default.
One can use [`barrier`](@ref) to sync all workers.

`push!` and `pull!` single `NDArray`:
```jldoctest
julia> kv = KVStore(:local)
mx.KVStore @ local

julia> x = NDArray(undef, 2, 3);

julia> init!(kv, 3, x)

julia> push!(kv, 3, mx.ones(2, 3) * 8)

julia> pull!(kv, 3, x)

julia> x
2×3 mx.NDArray{Float32,2} @ CPU0:
 8.0  8.0  8.0
 8.0  8.0  8.0
```

Aggregate values and `push!`:
```jldoctest
julia> vals = [mx.ones((2, 3), gpu(0)) * 3, mx.ones((2, 3), gpu(1)) * 4];

julia> push!(kv, 3, vals)

julia> pull!(kv, 3, x)

julia> x
2×3 mx.NDArray{Float32,2} @ CPU0:
 7.0  7.0  7.0
 7.0  7.0  7.0
```

`push!` a list of key to single device:

```jldoctest
julia> keys = [4, 5];

julia> init!(kv, keys, [NDArray(undef, 2, 3), NDArray(undef, 2, 3)])

julia> push!(kv, keys, [x, x])

julia> y, z = NDArray(undef, 2, 3), NDArray(undef, 2, 3);

julia> pull!(kv, keys, [y, z])
```
"""
push!(kv::KVStore, key::Int, val::NDArray; priority::Int = 0) =
    push!(kv, [key], [val]; priority = priority)
push!(kv::KVStore, key::Int, vals::Vector{<:NDArray}; priority::Int = 0) =
    push!(kv, fill(key, length(vals)), vals; priority = priority)
push!(kv::KVStore, keys::Vector{Int}, vals::Vector{<:Vector{<:NDArray}};
      priority::Int = 0) =
    push!(kv, _flatten_kvlist(keys, vals)...; priority = priority)

# Core `push!` method: forwards flat, equally long key/value vectors to `MXKVStorePush`.
function push!(kv::KVStore, keys::Vector{Int}, vals::Vector{<:NDArray}; priority::Int = 0)
    @assert length(keys) == length(vals)
    ckeys   = convert(Vector{Cint}, keys)
    handles = MX_handle[v for v in vals]
    @mxcall(:MXKVStorePush, (MX_handle, MX_uint, Ptr{Cint}, Ptr{MX_handle}, Cint),
            kv, length(ckeys), ckeys, handles, priority)
end

"""
    pull!(kv::KVStore, key,  out;  priority = 0)
    pull!(kv::KVStore, key,  outs; priority = 0)
    pull!(kv::KVStore, keys, outs; priority = 0)

Pulls a single value or a sequence of values from the store.

This function returns immediately after adding an operator to the engine.
Subsequent attempts to read from the `out` variable will be blocked until the
pull operation completes.

`pull!` is executed asynchronously after all previous `pull!` calls and only
the last `push!` call for the same input key(s) are finished.

The returned values are guaranteed to be the latest values in the store.

See [`push!`](@ref) for more examples.
"""
pull!(kv::KVStore, key::Int, out::NDArray; priority::Int = 0) =
    pull!(kv, [key], [out]; priority = priority)
pull!(kv::KVStore, key::Int, outs::Vector{<:NDArray}; priority::Int = 0) =
    pull!(kv, fill(key, length(outs)), outs; priority = priority)
pull!(kv::KVStore, keys::Vector{Int}, outs::Vector{<:Vector{<:NDArray}};
      priority::Int = 0) =
    pull!(kv, _flatten_kvlist(keys, outs)...; priority = priority)

# Core `pull!` method: forwards flat, equally long key/output vectors to `MXKVStorePull`.
function pull!(kv::KVStore, keys::Vector{Int}, outs::Vector{<:NDArray}; priority::Int = 0)
    @assert length(keys) == length(outs)
    ckeys   = convert(Vector{Cint}, keys)
    handles = MX_handle[o for o in outs]
    @mxcall(:MXKVStorePull, (MX_handle, MX_uint, Ptr{Cint}, Ptr{MX_handle}, Cint),
            kv, length(ckeys), ckeys, handles, priority)
end


# Query the store's type string via `MXKVStoreGetType` and return it as a `Symbol`.
function get_type(kv::KVStore)
    type_ref = Ref{char_p}(0)
    @mxcall(:MXKVStoreGetType, (MX_handle, Ref{char_p}), kv, type_ref)
    return Symbol(unsafe_string(type_ref[]))
end

# Return the value reported by `MXKVStoreGetGroupSize` as an `Int`.
function get_num_workers(kv::KVStore)
    ref_size = Ref{Cint}(0)
    @mxcall(:MXKVStoreGetGroupSize, (MX_handle, Ref{Cint}), kv, ref_size)
    return Int(ref_size[])
end

# Return the value reported by `MXKVStoreGetRank` as an `Int`.
function get_rank(kv::KVStore)
    ref_rank = Ref{Cint}(0)
    @mxcall(:MXKVStoreGetRank, (MX_handle, Ref{Cint}), kv, ref_rank)
    return Int(ref_rank[])
end

"""
    barrier(kv::KVStore)

Invokes global barrier among all worker nodes.

For example, assume there are `n` machines. We would like machine `0` to first
`init!` the values and then have all the workers `pull!` the initialized value.
Before pulling, we can invoke `barrier(kv)` to guarantee that the
initialization is finished.
"""
barrier(kv::KVStore) = @mxcall(:MXKVStoreBarrier, (MX_handle,), kv)


# TODO: Currently Julia does not support closure in c-callbacks, so we are making use of the
# extra handle parameter of the API to pass the updater object around. Fix this when someday
# full closure cfunction is supported in Julia.
function _kvstore_update_wrapper(key::Cint, nd_recv::MX_handle, nd_local::MX_handle,
                                 updater::Ptr{Cvoid})
    # `updater` is the Julia object that `setupdater!` handed to the C API.
    updater_func = unsafe_pointer_to_objref(updater)
    updater_func(Int(key), NDArray(MX_NDArrayHandle(nd_recv)),
                 NDArray(MX_NDArrayHandle(nd_local)))
    nothing
end

"""
    setupdater!(kv, updater)

Sets a `push!` updater into the store.

`updater` is called as `updater(key::Int, recv::NDArray, local::NDArray)`.

This function only changes the local store.
When running on multiple machines one must use [`setoptimizer!`](@ref).

```jldoctest
julia> update(key, val, orig) = mx.@inplace orig += val .* .2
update (generic function with 1 method)

julia> kv = KVStore(:local)
mx.KVStore @ local

julia> mx.setupdater!(kv, update)

julia> init!(kv, 42, mx.ones(2, 3))

julia> push!(kv, 42, mx.ones(2, 3))

julia> x = NDArray(undef, 2, 3);

julia> pull!(kv, 42, x)

julia> x
2×3 mx.NDArray{Float32,2} @ CPU0:
 1.2  1.2  1.2
 1.2  1.2  1.2
```
"""
function setupdater!(kv::KVStore, updater)
    kv.updater = updater # keep a reference to the julia object so that updater_c is kept valid
    kv.updater_c = @cfunction(_kvstore_update_wrapper, Cvoid,
                              (Cint, MX_handle, MX_handle, Ptr{Cvoid}))
    @mxcall(:MXKVStoreSetUpdater, (MX_handle, Ptr{Cvoid}, Any),
            kv, kv.updater_c, updater)
end

"""
    setoptimizer!(kv::KVStore, opt)

Registers an optimizer with the kvstore.

When using a single machine, this function updates the local optimizer.
If using multiple machines and this operation is invoked from a worker node,
it will serialize the optimizer and send it to all servers.
The function returns after all servers have been updated.

!!! note
    The distributed (worker node) path is not implemented yet and raises an error.

```jldoctest
julia> kv = KVStore()
mx.KVStore @ local

julia> W = mx.zeros(2, 3)  # 2×3 weight matrix
2×3 mx.NDArray{Float32,2} @ CPU0:
 0.0  0.0  0.0
 0.0  0.0  0.0

julia> init!(kv, 42, W)

julia> setoptimizer!(kv, SGD(η = .2))  # SGD with .2 as learning rate

julia> ∇W = mx.ones(2, 3)  # assume it's the gradient
2×3 mx.NDArray{Float32,2} @ CPU0:
 1.0  1.0  1.0
 1.0  1.0  1.0

julia> push!(kv, 42, ∇W)

julia> pull!(kv, 42, W)  # fetch weight and write back to `W`

julia> W
2×3 mx.NDArray{Float32,2} @ CPU0:
 -0.2  -0.2  -0.2
 -0.2  -0.2  -0.2
```
"""
function setoptimizer!(kv::KVStore, opt::AbstractOptimizer)
    if occursin("dist", string(get_type(kv))) && _isworker()
        # TODO
        error("not implemented")
    else
        setupdater!(kv, getupdater(opt))
    end
end

# Ask the library whether the current process is a worker node.
function _isworker()::Bool
    ref = Ref{Cint}(0)
    @mxcall(:MXKVStoreIsWorkerNode, (Ref{Cint},), ref)
    # The C API writes an integer flag; any non-zero value means "worker".
    return ref[] != 0
end

# TODO: sparse support?