1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import Base.push!
19
"""
    KVStore(kv_type = :local)

Creates a key-value store of type `kv_type`. Accepted values are `:local`,
`:device`, `:dist_sync`, `:dist_device_sync` and `:dist_async`.

For single machine training, there are two commonly used types:

- `local`: Copies all gradients to CPU memory and updates weights there.

- `device`: Aggregates gradients and updates weights on GPU(s).
  With this setting, the `KVStore` also attempts to use GPU peer-to-peer
  communication, potentially accelerating the communication.

For distributed training, `KVStore` also supports a number of types:

- `dist_sync`: Behaves similarly to `local` but with one major difference.
  With `dist_sync`, batch-size now means the batch size used on each machine.
  So if there are `n` machines and we use batch size `b`,
  then `dist_sync` behaves like `local` with batch size `n * b`.

- `dist_device_sync`: Identical to `dist_sync` with the difference similar
  to `device` vs `local`.

- `dist_async`: Performs asynchronous updates.
  The weights are updated whenever gradients are received from any machine.
  No two updates happen on the same weight at the same time.
  However, the order is not guaranteed.
"""
mutable struct KVStore
  handle    :: MX_KVStoreHandle  # handle to the native MXNet kvstore
  updater_c :: Ptr{Cvoid}        # C callback registered by `setupdater!`; null until then
  updater   :: Function          # Julia updater referenced by `setupdater!`; unset until then

  # The `updater` field is intentionally left undefined here; `setupdater!` assigns it.
  KVStore(hdr::MX_KVStoreHandle) = new(hdr, Ptr{Cvoid}(0))
end
53
function KVStore(kv_type::Symbol = :local)
  # Reject anything other than the store types listed in the `KVStore` docstring.
  @assert kv_type ∈ (:local, :device, :dist_sync, :dist_device_sync, :dist_async)
  hdr_out = Ref{MX_handle}(0)
  @mxcall(:MXKVStoreCreate, (char_p, Ref{MX_handle}), dump_mx_param(kv_type), hdr_out)
  handle = MX_KVStoreHandle(hdr_out[])
  return KVStore(handle)
end
60
# Conversions used when a `KVStore` is passed to `ccall` (directly or via `@mxcall`).
Base.unsafe_convert(::Type{MX_handle}, obj::KVStore) =
  Base.unsafe_convert(MX_handle, obj.handle)
Base.convert(t::Type{MX_handle}, obj::KVStore) = Base.unsafe_convert(t, obj)
# `ccall` keeps the value returned by `cconvert` rooted for the duration of the call
# and then derives the raw pointer with `unsafe_convert`. Returning the `KVStore`
# itself (instead of an already-extracted pointer) keeps the store, and therefore its
# native handle, alive while the C function runs.
Base.cconvert(::Type{MX_handle}, obj::KVStore) = obj
65
# Display as e.g. `mx.KVStore @ local`, using the type reported by the backend.
function Base.show(io::IO, kv::KVStore)
  print(io, "mx.KVStore @ ", get_type(kv))
end
68
# Expands grouped key/value lists into two parallel flat lists: every NDArray in
# `vals[i]` is paired with `keys[i]`, so each key is repeated once per array in its
# group. Returns the tuple `(flat_keys, flat_vals)`.
function _flatten_kvlist(keys::Vector{Int}, vals::Vector{<:Vector{<:NDArray}})
  @assert length(keys) == length(vals)
  flat_keys = Int[]
  flat_vals = NDArray[]
  for (key, group) in zip(keys, vals)
    append!(flat_keys, fill(key, length(group)))
    append!(flat_vals, group)
  end
  (flat_keys, flat_vals)
end
79
"""
    init!(kv::KVStore, key::Int, val::NDArray)
    init!(kv::KVStore, key::Int, vals::Vector{<:NDArray})
    init!(kv::KVStore, keys, vals)

Initializes a single key-value pair or a sequence of key-value pairs in the store.

For each key, one must `init!` it before calling `push!` or `pull!`.
When multiple workers invoke `init!` for the same key, only
the value supplied by the worker with rank `0` is used. This function returns
after data has been initialized successfully.

```jldoctest
julia> kv = KVStore(:local)
mx.KVStore @ local

julia> init!(kv, 42, mx.rand(2, 3))
```
"""
init!(kv::KVStore, key::Int, val::NDArray) = init!(kv, [key], [val])
init!(kv::KVStore, key::Int, vals::Vector{<:NDArray}) =
  init!(kv, fill(key, length(vals)), vals)
init!(kv::KVStore, keys::Vector{Int}, vals::Vector{<:Vector{<:NDArray}}) =
  init!(kv, _flatten_kvlist(keys, vals)...)

function init!(kv::KVStore, keys::Vector{Int}, vals::VecOfNDArray)
  @assert length(keys) == length(vals)
  # Build the C-compatible arrays under new names rather than rebinding the
  # arguments: this keeps `keys`/`vals` type-stable and keeps `vals` available below.
  ckeys   = Vector{Cint}(keys)  # throws InexactError for keys outside the Cint range
  handles = MX_handle[v for v in vals]
  # `handles` only contains raw pointers, so explicitly root the owning NDArrays
  # until the native call has returned.
  GC.@preserve vals begin
    @mxcall(:MXKVStoreInit, (MX_handle, MX_uint, Ptr{Cint}, Ptr{MX_handle}),
            kv, length(ckeys), ckeys, handles)
  end
end
111
"""
    push!(kv::KVStore, key,  val;  priority = 0)
    push!(kv::KVStore, key,  vals; priority = 0)
    push!(kv::KVStore, keys, vals; priority = 0)

Pushes a single or a sequence of key-value pairs into the store.

This function returns immediately after adding an operator to the engine.
The actual operation is executed asynchronously. If there are consecutive
pushes to the same key, there is no guarantee on the serialization of pushes.
The execution of a push does not guarantee that all previous pushes are
finished. There is no synchronization between workers by default.
One can use `barrier(kv)` to sync all workers.

`push!` and `pull!` a single `NDArray`:
```jldoctest
julia> kv = KVStore(:local)
mx.KVStore @ local

julia> x = NDArray(undef, 2, 3);

julia> init!(kv, 3, x)

julia> push!(kv, 3, mx.ones(2, 3) * 8)

julia> pull!(kv, 3, x)

julia> x
2×3 mx.NDArray{Float32,2} @ CPU0:
 8.0  8.0  8.0
 8.0  8.0  8.0
```

Pushing several values for the same key aggregates them; here the pulled
result is `3 + 4 = 7`:
```jldoctest
julia> vals = [mx.ones((2, 3), gpu(0)) * 3, mx.ones((2, 3), gpu(1)) * 4];

julia> push!(kv, 3, vals)

julia> pull!(kv, 3, x)

julia> x
2×3 mx.NDArray{Float32,2} @ CPU0:
 7.0  7.0  7.0
 7.0  7.0  7.0
```

`push!` a list of keys, with all values on a single device:

```jldoctest
julia> keys = [4, 5];

julia> init!(kv, keys, [NDArray(undef, 2, 3), NDArray(undef, 2, 3)])

julia> push!(kv, keys, [x, x])

julia> y, z = NDArray(undef, 2, 3), NDArray(undef, 2, 3);

julia> pull!(kv, keys, [y, z])
```
"""
push!(kv::KVStore, key::Int, val::NDArray; priority::Int = 0) =
  push!(kv, [key], [val]; priority = priority)
push!(kv::KVStore, key::Int, vals::Vector{<:NDArray}; priority::Int = 0) =
  push!(kv, fill(key, length(vals)), vals; priority = priority)
push!(kv::KVStore, keys::Vector{Int}, vals::Vector{<:Vector{<:NDArray}};
      priority::Int = 0) =
  push!(kv, _flatten_kvlist(keys, vals)...; priority = priority)

function push!(kv::KVStore, keys::Vector{Int}, vals::Vector{<:NDArray}; priority::Int = 0)
  @assert length(keys) == length(vals)
  # Fresh locals instead of rebinding the arguments (type stability, and `vals`
  # stays available for rooting below).
  ckeys   = Vector{Cint}(keys)  # throws InexactError for keys outside the Cint range
  handles = MX_handle[v for v in vals]
  # `handles` only contains raw pointers; keep the owning NDArrays alive for the call.
  GC.@preserve vals begin
    @mxcall(:MXKVStorePush, (MX_handle, MX_uint, Ptr{Cint}, Ptr{MX_handle}, Cint),
            kv, length(ckeys), ckeys, handles, priority)
  end
end
188
"""
    pull!(kv::KVStore, key,  out;  priority = 0)
    pull!(kv::KVStore, key,  outs; priority = 0)
    pull!(kv::KVStore, keys, outs; priority = 0)

Pulls a single value or a sequence of values from the store into `out`/`outs`.

This function returns immediately after adding an operator to the engine.
Subsequent attempts to read from the `out` variable will be blocked until the
pull operation completes.

`pull!` is executed asynchronously after all previous `pull!` calls and only
the last `push!` call for the same input key(s) are finished.

The returned values are guaranteed to be the latest values in the store.

See [`push!`](@ref) for examples.
"""
pull!(kv::KVStore, key::Int, out::NDArray; priority::Int = 0) =
  pull!(kv, [key], [out]; priority = priority)
pull!(kv::KVStore, key::Int, outs::Vector{<:NDArray}; priority::Int = 0) =
  pull!(kv, fill(key, length(outs)), outs; priority = priority)
pull!(kv::KVStore, keys::Vector{Int}, outs::Vector{<:Vector{<:NDArray}};
      priority::Int = 0) =
  pull!(kv, _flatten_kvlist(keys, outs)...; priority = priority)

function pull!(kv::KVStore, keys::Vector{Int}, outs::Vector{<:NDArray}; priority::Int = 0)
  @assert length(keys) == length(outs)
  # Fresh locals instead of rebinding the arguments (type stability, and `outs`
  # stays available for rooting below).
  ckeys   = Vector{Cint}(keys)  # throws InexactError for keys outside the Cint range
  handles = MX_handle[o for o in outs]
  # `handles` only contains raw pointers; keep the destination NDArrays alive for the call.
  GC.@preserve outs begin
    @mxcall(:MXKVStorePull, (MX_handle, MX_uint, Ptr{Cint}, Ptr{MX_handle}, Cint),
            kv, length(ckeys), ckeys, handles, priority)
  end
end
217
218
"""
    get_type(kv::KVStore) -> Symbol

Queries the MXNet backend for the type of `kv` and returns it as a `Symbol`
(e.g. `:local` for a store created with `KVStore(:local)`).
"""
function get_type(kv::KVStore)
  name_out = Ref{char_p}(0)
  @mxcall(:MXKVStoreGetType, (MX_handle, Ref{char_p}), kv, name_out)
  name = unsafe_string(name_out[])
  Symbol(name)
end
224
"""
    get_num_workers(kv::KVStore) -> Int

Returns the group size of `kv` as reported by `MXKVStoreGetGroupSize`.
"""
function get_num_workers(kv::KVStore)
  group_size = Ref{Cint}(0)
  @mxcall(:MXKVStoreGetGroupSize, (MX_handle, Ref{Cint}), kv, group_size)
  Int(group_size[])
end
230
"""
    get_rank(kv::KVStore) -> Int

Returns the rank of the current process within `kv`'s group, as reported by
`MXKVStoreGetRank`.
"""
function get_rank(kv::KVStore)
  rank_out = Ref{Cint}(0)
  @mxcall(:MXKVStoreGetRank, (MX_handle, Ref{Cint}), kv, rank_out)
  Int(rank_out[])
end
236
"""
    barrier(kv::KVStore)

Invokes a global barrier among all worker nodes.

For example, assume there are `n` machines. We would like machine `0` to first
`init!` the values and then have all the workers `pull!` the initialized value.
Before pulling, we can invoke `barrier(kv)` to guarantee that the
initialization is finished.
"""
barrier(kv::KVStore) = @mxcall(:MXKVStoreBarrier, (MX_handle,), kv)
248
249
# C-callable trampoline registered by `setupdater!`. MXNet calls it with the opaque
# `updater` pointer that `setupdater!` passed alongside it; the pointer is turned back
# into the Julia updater object, which is then invoked with the key and the two
# arrays wrapped as `NDArray`s.
# TODO: the updater is routed through this extra handle parameter because closures
# could not be used as C callbacks when this was written; revisit once closure
# cfunctions are usable on all supported platforms.
function _kvstore_update_wrapper(key::Cint, nd_recv::MX_handle, nd_local::MX_handle,
                                 updater::Ptr{Cvoid})
  callback = unsafe_pointer_to_objref(updater)
  received = NDArray(MX_NDArrayHandle(nd_recv))
  stored   = NDArray(MX_NDArrayHandle(nd_local))
  callback(Int(key), received, stored)
  return nothing
end
260
"""
    setupdater!(kv, updater)

Sets a `push!` updater into the store.

`updater` is called as `updater(key::Int, recv::NDArray, stored::NDArray)` and is
expected to update the stored value in place, as `update` does in the example below.

This function only changes the local store.
When running on multiple machines one must use `setoptimizer!`.

```jldoctest
julia> update(key, val, orig) = mx.@inplace orig += val .* .2
update (generic function with 1 method)

julia> kv = KVStore(:local)
mx.KVStore @ local

julia> mx.setupdater!(kv, update)

julia> init!(kv, 42, mx.ones(2, 3))

julia> push!(kv, 42, mx.ones(2, 3))

julia> x = NDArray(undef, 2, 3);

julia> pull!(kv, 42, x)

julia> x
2×3 mx.NDArray{Float32,2} @ CPU0:
 1.2  1.2  1.2
 1.2  1.2  1.2
```
"""
function setupdater!(kv::KVStore, updater)
  # The native store receives `updater` only as an opaque pointer (passed as `Any`
  # below) and hands it back to `_kvstore_update_wrapper`, which dereferences it with
  # `unsafe_pointer_to_objref`. Keep a Julia reference in `kv` so the updater object
  # is not garbage collected while the store may still call it.
  kv.updater = updater
  kv.updater_c = @cfunction(_kvstore_update_wrapper, Cvoid,
                            (Cint, MX_handle, MX_handle, Ptr{Cvoid}))
  @mxcall(:MXKVStoreSetUpdater, (MX_handle, Ptr{Cvoid}, Any),
          kv, kv.updater_c, updater)
end
299
"""
    setoptimizer!(kv::KVStore, opt::AbstractOptimizer)

Registers an optimizer with the kvstore.

The optimizer is installed as the store's updater via
`setupdater!(kv, getupdater(opt))`, so subsequent `push!`es update the stored
weights with `opt` (see the example below).

Distributed stores (whose type contains `dist`) are not supported yet when this
function is invoked from a worker node; an error is raised in that case.

```jldoctest
julia> kv = KVStore()
mx.KVStore @ local

julia> W = mx.zeros(2, 3)  # 2×3 weight matrix
2×3 mx.NDArray{Float32,2} @ CPU0:
 0.0  0.0  0.0
 0.0  0.0  0.0

julia> init!(kv, 42, W)

julia> setoptimizer!(kv, SGD(η = .2))  # SGD with .2 as learning rate

julia> ∇W = mx.ones(2, 3)  # assume it's the gradient
2×3 mx.NDArray{Float32,2} @ CPU0:
 1.0  1.0  1.0
 1.0  1.0  1.0

julia> push!(kv, 42, ∇W)

julia> pull!(kv, 42, W)  # fetch weight and write back to `W`

julia> W
2×3 mx.NDArray{Float32,2} @ CPU0:
 -0.2  -0.2  -0.2
 -0.2  -0.2  -0.2
```
"""
function setoptimizer!(kv::KVStore, opt::AbstractOptimizer)
  if occursin(r"dist", string(get_type(kv))) && _isworker()
    # TODO: ship the optimizer to the servers for distributed training.
    error("not implemented")
  else
    setupdater!(kv, getupdater(opt))
  end
end
346
# Returns `true` when the current process is a worker node, as reported by
# `MXKVStoreIsWorkerNode`.
function _isworker()::Bool
  ref = Ref{Cint}(0)
  @mxcall(:MXKVStoreIsWorkerNode, (Ref{Cint},), ref)
  # Treat any nonzero flag as true; `Bool(::Cint)` would throw for values other than 0/1.
  ref[] != 0
end
352
353# TODO: sparse support?
354