// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/stream.h"

#ifdef GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif

#include "arrow/adapters/tensorflow/convert.h"
#include "arrow/api.h"
#include "arrow/io/api.h"
#include "arrow/util/logging.h"

// These headers do not include Python.h
#include "arrow/python/deserialize.h"
#include "arrow/python/serialize.h"

#include "plasma/client.h"

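// This file implements two asynchronous TensorFlow custom ops that move
// tensor data between TensorFlow and the Plasma object store:
//   - TensorToPlasma: copies one or more tf.Tensors into a single Plasma
//     object, prefixed with an Arrow ndarray header.
//   - PlasmaToTensor: fetches a Plasma object and exposes its payload as a
//     flat tf.Tensor.
// Both ops have CPU and (under GOOGLE_CUDA) GPU kernels.
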
namespace tf = tensorflow;

using ArrowStatus = arrow::Status;
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

using Event = perftools::gputools::Event;
using Stream = perftools::gputools::Stream;

// NOTE(zongheng): for some reason using unique_ptr or shared_ptr results in
// CUDA_ERROR_DEINITIALIZED on program exit. I suspect this is because the
// static object's dtor gets called *after* TensorFlow's own CUDA cleanup.
// Instead, we use a raw pointer here and manually clean up in the Ops' dtors.
static Stream* d2h_stream = nullptr;
static tf::mutex d2h_stream_mu;

// TODO(zongheng): the CPU kernels' std::memcpy calls might be sped up by
// parallelizing them.

int64_t get_byte_width(const arrow::DataType& dtype) {
  return arrow::internal::checked_cast<const arrow::FixedWidthType&>(dtype)
             .bit_width() / CHAR_BIT;
}
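// For example, arrow::float32() is a FixedWidthType with a bit width of 32,
// so its byte width is 4.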

// Put: tf.Tensor -> plasma.
template <typename Device>
class TensorToPlasmaOp : public tf::AsyncOpKernel {
 public:
  explicit TensorToPlasmaOp(tf::OpKernelConstruction* context) : tf::AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("plasma_store_socket_name",
                                             &plasma_store_socket_name_));
    tf::mutex_lock lock(mu_);
    if (!connected_) {
      VLOG(1) << "Connecting to Plasma...";
      ARROW_CHECK_OK(client_.Connect(plasma_store_socket_name_));
      VLOG(1) << "Connected!";
      connected_ = true;
    }
  }

  ~TensorToPlasmaOp() override {
    {
      tf::mutex_lock lock(mu_);
      ARROW_CHECK_OK(client_.Disconnect());
      connected_ = false;
    }
    {
      tf::mutex_lock lock(d2h_stream_mu);
      if (d2h_stream != nullptr) {
        delete d2h_stream;
      }
    }
  }

  void ComputeAsync(tf::OpKernelContext* context, DoneCallback done) override {
    const int num_inputs = context->num_inputs();
    OP_REQUIRES_ASYNC(
        context, num_inputs >= 2,
        tf::errors::InvalidArgument("Input should have at least 1 tensor and 1 object_id"),
        done);
    const int num_tensors = num_inputs - 1;

    // Check that all input tensors have the same data type; report a proper
    // InvalidArgument error instead of aborting the process on mismatch.
    tf::DataType tf_dtype = context->input(0).dtype();
    for (int i = 1; i < num_tensors; i++) {
      OP_REQUIRES_ASYNC(
          context, tf_dtype == context->input(i).dtype(),
          tf::errors::InvalidArgument("All input tensors must have the same data type"),
          done);
    }

    std::shared_ptr<arrow::DataType> arrow_dtype;
    ARROW_CHECK_OK(arrow::adapters::tensorflow::GetArrowType(tf_dtype, &arrow_dtype));
    int64_t byte_width = get_byte_width(*arrow_dtype);

    // Compute the byte offset of each input tensor within the concatenated
    // payload, plus the total payload size.
    std::vector<size_t> offsets;
    offsets.reserve(num_tensors + 1);
    offsets.push_back(0);
    int64_t total_bytes = 0;
    for (int i = 0; i < num_tensors; ++i) {
      const size_t s = context->input(i).TotalBytes();
      CHECK_EQ(s, context->input(i).NumElements() * byte_width);
      CHECK_GT(s, 0);
      total_bytes += s;
      offsets.push_back(total_bytes);
    }
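    // For example, two float32 inputs with 3 and 5 elements yield
    // offsets = {0, 12, 32} and total_bytes = 32.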

    const tf::Tensor& plasma_object_id = context->input(num_inputs - 1);
    CHECK_EQ(plasma_object_id.NumElements(), 1);
    const std::string& plasma_object_id_str = plasma_object_id.flat<std::string>()(0);
    VLOG(1) << "plasma_object_id_str: '" << plasma_object_id_str << "'";
    const plasma::ObjectID object_id =
        plasma::ObjectID::from_binary(plasma_object_id_str);

    // All inputs are flattened and concatenated, so the stored ndarray is 1-D.
    std::vector<int64_t> shape = {total_bytes / byte_width};

    // Measure the ndarray header size by writing the header to a mock stream
    // that only counts bytes.
    arrow::io::MockOutputStream mock;
    ARROW_CHECK_OK(arrow::py::WriteNdarrayHeader(arrow_dtype, shape, 0, &mock));
    int64_t header_size = mock.GetExtentBytesWritten();

    std::shared_ptr<arrow::Buffer> data_buffer;
    {
      tf::mutex_lock lock(mu_);
      ARROW_CHECK_OK(client_.Create(object_id, header_size + total_bytes,
                                    /*metadata=*/nullptr, 0, &data_buffer));
    }

    int64_t offset;
    arrow::io::FixedSizeBufferWriter buf(data_buffer);
    ARROW_CHECK_OK(arrow::py::WriteNdarrayHeader(arrow_dtype, shape, total_bytes, &buf));
    ARROW_CHECK_OK(buf.Tell(&offset));

    uint8_t* data = reinterpret_cast<uint8_t*>(data_buffer->mutable_data() + offset);
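    // At this point the Plasma object layout is
    //   [ ndarray header | tensor 0 bytes | tensor 1 bytes | ... ]
    // and `data` points just past the header, where the tensor bytes will go.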

    // Runs after all copies into the Plasma buffer have completed: seals and
    // releases the object, and (on GPU) unregisters the pinned host memory.
    auto wrapped_callback = [this, context, done, data_buffer, data, object_id]() {
      {
        tf::mutex_lock lock(mu_);
        ARROW_CHECK_OK(client_.Seal(object_id));
        ARROW_CHECK_OK(client_.Release(object_id));
#ifdef GOOGLE_CUDA
        auto orig_stream = context->op_device_context()->stream();
        auto stream_executor = orig_stream->parent();
        CHECK(stream_executor->HostMemoryUnregister(static_cast<void*>(data)));
#endif
      }
      context->SetStatus(tensorflow::Status::OK());
      done();
    };

    if (std::is_same<Device, CPUDevice>::value) {
      for (int i = 0; i < num_tensors; ++i) {
        const auto& input_tensor = context->input(i);
        std::memcpy(static_cast<void*>(data + offsets[i]),
                    input_tensor.tensor_data().data(),
                    static_cast<tf::uint64>(offsets[i + 1] - offsets[i]));
      }
      wrapped_callback();
    } else {
#ifdef GOOGLE_CUDA
      auto orig_stream = context->op_device_context()->stream();
      OP_REQUIRES_ASYNC(context, orig_stream != nullptr,
                        tf::errors::Internal("No GPU stream available."), done);
      auto stream_executor = orig_stream->parent();

      // NOTE(zongheng): this is critical for getting good performance out of
      // the D2H async memcpy. Under the hood it performs cuMemHostRegister();
      // see:
      // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gf0a9fe11544326dabd743b7aa6b54223
      CHECK(stream_executor->HostMemoryRegister(static_cast<void*>(data),
                                                static_cast<tf::uint64>(total_bytes)));

      {
        tf::mutex_lock l(d2h_stream_mu);
        if (d2h_stream == nullptr) {
          d2h_stream = new Stream(stream_executor);
          CHECK(d2h_stream->Init().ok());
        }
      }

      // Needed to make sure the input buffers have been computed.
      // NOTE(ekl): this is unnecessary when the op is already behind an NCCL
      // allreduce.
      CHECK(d2h_stream->ThenWaitFor(orig_stream).ok());

      for (int i = 0; i < num_tensors; ++i) {
        const auto& input_tensor = context->input(i);
        auto input_buffer = const_cast<char*>(input_tensor.tensor_data().data());
        perftools::gputools::DeviceMemoryBase wrapped_src(
            static_cast<void*>(input_buffer));
        const bool success =
            d2h_stream
                ->ThenMemcpy(static_cast<void*>(data + offsets[i]), wrapped_src,
                             static_cast<tf::uint64>(offsets[i + 1] - offsets[i]))
                .ok();
        OP_REQUIRES_ASYNC(context, success,
                          tf::errors::Internal("D2H memcpy failed to be enqueued."), done);
      }
      // Defer sealing and cleanup until every copy enqueued on d2h_stream has
      // finished.
      context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
          d2h_stream, std::move(wrapped_callback));
#endif
    }
  }

 private:
  std::string plasma_store_socket_name_;

  tf::mutex mu_;
  bool connected_ = false;
  plasma::PlasmaClient client_ GUARDED_BY(mu_);
};

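// As with d2h_stream above, a raw pointer plus manual deletion in the op's
// dtor avoids CUDA teardown-order problems at program exit.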
static Stream* h2d_stream = nullptr;
static tf::mutex h2d_stream_mu;

// Get: plasma -> tf.Tensor.
template <typename Device>
class PlasmaToTensorOp : public tf::AsyncOpKernel {
 public:
  explicit PlasmaToTensorOp(tf::OpKernelConstruction* context) : tf::AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("plasma_store_socket_name",
                                             &plasma_store_socket_name_));
    tf::mutex_lock lock(mu_);
    if (!connected_) {
      VLOG(1) << "Connecting to Plasma...";
      ARROW_CHECK_OK(client_.Connect(plasma_store_socket_name_));
      VLOG(1) << "Connected!";
      connected_ = true;
    }
  }

  ~PlasmaToTensorOp() override {
    {
      tf::mutex_lock lock(mu_);
      ARROW_CHECK_OK(client_.Disconnect());
      connected_ = false;
    }
    {
      tf::mutex_lock lock(h2d_stream_mu);
      if (h2d_stream != nullptr) {
        delete h2d_stream;
      }
    }
  }

  void ComputeAsync(tf::OpKernelContext* context, DoneCallback done) override {
    const tf::Tensor& plasma_object_id = context->input(0);
    CHECK_EQ(plasma_object_id.NumElements(), 1);
    const std::string& plasma_object_id_str = plasma_object_id.flat<std::string>()(0);

    VLOG(1) << "plasma_object_id_str: '" << plasma_object_id_str << "'";
    const plasma::ObjectID object_id =
        plasma::ObjectID::from_binary(plasma_object_id_str);

    plasma::ObjectBuffer object_buffer;
    {
      tf::mutex_lock lock(mu_);
      // NOTE(zongheng): this is a blocking call. We might want to (1) make
      // Plasma asynchronous, (2) launch a thread / event here ourselves, or
      // something like that...
      ARROW_CHECK_OK(client_.Get(&object_id, /*num_objects=*/1,
                                 /*timeout_ms=*/-1, &object_buffer));
    }

    std::shared_ptr<arrow::Tensor> ndarray;
    ARROW_CHECK_OK(arrow::py::NdarrayFromBuffer(object_buffer.data, &ndarray));

    int64_t byte_width = get_byte_width(*ndarray->type());
    const int64_t size_in_bytes = ndarray->data()->size();

    // The output is produced as a flat 1-D tensor; callers reshape as needed.
    tf::TensorShape shape({static_cast<int64_t>(size_in_bytes / byte_width)});

    // NOTE: this pointer is only used as an opaque byte address for memcpy
    // and host-memory registration; the actual dtype need not be float.
    const float* plasma_data = reinterpret_cast<const float*>(ndarray->raw_data());

    tf::Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, shape, &output_tensor),
                         done);

    // Runs after the copy out of Plasma has completed: releases the object
    // and (on GPU) unregisters the pinned host memory.
    auto wrapped_callback = [this, context, done, plasma_data, object_id]() {
      {
        tf::mutex_lock lock(mu_);
        ARROW_CHECK_OK(client_.Release(object_id));
#ifdef GOOGLE_CUDA
        auto orig_stream = context->op_device_context()->stream();
        auto stream_executor = orig_stream->parent();
        CHECK(stream_executor->HostMemoryUnregister(
            const_cast<void*>(static_cast<const void*>(plasma_data))));
#endif
      }
      done();
    };

    if (std::is_same<Device, CPUDevice>::value) {
      std::memcpy(
          reinterpret_cast<void*>(const_cast<char*>(output_tensor->tensor_data().data())),
          plasma_data, size_in_bytes);
      wrapped_callback();
    } else {
#ifdef GOOGLE_CUDA
      auto orig_stream = context->op_device_context()->stream();
      OP_REQUIRES_ASYNC(context, orig_stream != nullptr,
                        tf::errors::Internal("No GPU stream available."), done);
      auto stream_executor = orig_stream->parent();

      {
        tf::mutex_lock l(h2d_stream_mu);
        if (h2d_stream == nullptr) {
          h2d_stream = new Stream(stream_executor);
          CHECK(h2d_stream->Init().ok());
        }
      }

      // Important: see the HostMemoryRegister note in TensorToPlasmaOp.
      CHECK(stream_executor->HostMemoryRegister(
          const_cast<void*>(static_cast<const void*>(plasma_data)),
          static_cast<tf::uint64>(size_in_bytes)));

      perftools::gputools::DeviceMemoryBase wrapped_dst(
          reinterpret_cast<void*>(const_cast<char*>(output_tensor->tensor_data().data())));
      const bool success =
          h2d_stream
              ->ThenMemcpy(&wrapped_dst, static_cast<const void*>(plasma_data),
                           static_cast<tf::uint64>(size_in_bytes))
              .ok();
      OP_REQUIRES_ASYNC(context, success,
                        tf::errors::Internal("H2D memcpy failed to be enqueued."), done);

      // Without this sync the main compute stream might proceed to use the
      // Tensor buffer, but its contents might still be in flight from our
      // h2d_stream.
      CHECK(orig_stream->ThenWaitFor(h2d_stream).ok());

      context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
          h2d_stream, std::move(wrapped_callback));
#endif
    }
  }

 private:
  std::string plasma_store_socket_name_;

  tf::mutex mu_;
  bool connected_ = false;
  plasma::PlasmaClient client_ GUARDED_BY(mu_);
};

REGISTER_OP("TensorToPlasma")
    .Input("input_tensor: dtypes")
    .Input("plasma_object_id: string")
    .Attr("dtypes: list(type)")
    .Attr("plasma_store_socket_name: string");

REGISTER_KERNEL_BUILDER(Name("TensorToPlasma").Device(tf::DEVICE_CPU),
                        TensorToPlasmaOp<CPUDevice>);
#ifdef GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("TensorToPlasma").Device(tf::DEVICE_GPU),
                        TensorToPlasmaOp<GPUDevice>);
#endif

REGISTER_OP("PlasmaToTensor")
    .Input("plasma_object_id: string")
    .Output("tensor: dtype")
    .Attr("dtype: type")
    .Attr("plasma_store_socket_name: string");

REGISTER_KERNEL_BUILDER(Name("PlasmaToTensor").Device(tf::DEVICE_CPU),
                        PlasmaToTensorOp<CPUDevice>);
#ifdef GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("PlasmaToTensor").Device(tf::DEVICE_GPU),
                        PlasmaToTensorOp<GPUDevice>);
#endif
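
// A rough usage sketch from Python (illustrative only: the shared-library
// name, socket path, and snake_case op names below are assumptions based on
// TensorFlow's usual custom-op conventions, not definitions from this file;
// object_id_bytes must be the 20-byte binary form expected by
// plasma::ObjectID::from_binary):
//
//   import tensorflow as tf
//   mod = tf.load_op_library("plasma_op.so")
//   # Write: concatenates the inputs into one Plasma object.
//   mod.tensor_to_plasma([tensor_a, tensor_b], object_id_bytes,
//                        plasma_store_socket_name="/tmp/plasma")
//   # Read back as a flat 1-D tensor; reshape as needed.
//   flat = mod.plasma_to_tensor(object_id_bytes, dtype=tf.float32,
//                               plasma_store_socket_name="/tmp/plasma")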