1 //
2 // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 #include "td/utils/port/StdStreams.h"
8
9 #include "td/utils/logging.h"
10 #include "td/utils/misc.h"
11 #include "td/utils/port/detail/Iocp.h"
12 #include "td/utils/port/detail/NativeFd.h"
13 #include "td/utils/port/detail/PollableFd.h"
14 #include "td/utils/port/PollFlags.h"
15 #include "td/utils/port/thread.h"
16 #include "td/utils/ScopeGuard.h"
17 #include "td/utils/Slice.h"
18 #include "td/utils/SliceBuilder.h"
19
20 #include <atomic>
21
22 namespace td {
23
24 #if TD_PORT_POSIX
25 template <int id>
get_file_fd()26 static FileFd &get_file_fd() {
27 static FileFd result = FileFd::from_native_fd(NativeFd(id, true));
28 static auto guard = ScopeExit() + [&] {
29 result.move_as_native_fd().release();
30 };
31 return result;
32 }
33
Stdin()34 FileFd &Stdin() {
35 return get_file_fd<0>();
36 }
Stdout()37 FileFd &Stdout() {
38 return get_file_fd<1>();
39 }
Stderr()40 FileFd &Stderr() {
41 return get_file_fd<2>();
42 }
43 #elif TD_PORT_WINDOWS
44 template <DWORD id>
45 static FileFd &get_file_fd() {
46 #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM)
47 static auto handle = GetStdHandle(id);
48 LOG_IF(FATAL, handle == INVALID_HANDLE_VALUE) << "Failed to GetStdHandle " << id;
49 static FileFd result = FileFd::from_native_fd(NativeFd(handle, true));
50 static auto guard = ScopeExit() + [&] {
51 result.move_as_native_fd().release();
52 };
53 #else
54 static FileFd result;
55 #endif
56 return result;
57 }
58
59 FileFd &Stdin() {
60 return get_file_fd<STD_INPUT_HANDLE>();
61 }
62 FileFd &Stdout() {
63 return get_file_fd<STD_OUTPUT_HANDLE>();
64 }
65 FileFd &Stderr() {
66 return get_file_fd<STD_ERROR_HANDLE>();
67 }
68 #endif
69
70 #if TD_PORT_WINDOWS
71 namespace detail {
// Windows implementation behind BufferedStdin.
//
// How data flows:
// - A dedicated thread (read_loop) performs blocking ReadFile calls on the standard input handle and
//   appends the data to writer_.
// - After every read it posts a completion to the IOCP.
// - on_iocp() then marks info_ as readable, so the owner's poll loop can call flush_read().
//
// Lifetime is reference counted:
// - refcnt_ starts at 1; this is the reference held by the reader thread.
// - Every successful post inside the loop holds one extra reference, which on_iocp() drops.
// - dec_refcnt() executes `delete this` when the count reaches zero.
// The object therefore destroys itself; close() only asks the reader thread to stop.
class BufferedStdinImpl final : private Iocp::Callback {
 public:
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM)
  // Wraps the process stdin handle, takes a reference to the global IOCP and starts the reader thread.
  BufferedStdinImpl() : info_(NativeFd(GetStdHandle(STD_INPUT_HANDLE), true)) {
    iocp_ref_ = Iocp::get()->get_ref();
    read_thread_ = thread([this] { this->read_loop(); });
  }
#else
  // Outside the desktop/system API partitions no reader thread is started; the object is marked
  // closed immediately.
  // NOTE(review): with no reader thread, nothing in this class ever drops the initial reference in
  // refcnt_, so the object is never freed through dec_refcnt(). Confirm this leak is acceptable.
  BufferedStdinImpl() {
    close();
  }
#endif
  BufferedStdinImpl(const BufferedStdinImpl &) = delete;
  BufferedStdinImpl &operator=(const BufferedStdinImpl &) = delete;
  BufferedStdinImpl(BufferedStdinImpl &&) = delete;
  BufferedStdinImpl &operator=(BufferedStdinImpl &&) = delete;
  // Moves the stdin handle out of info_ and releases it (presumably without closing it).
  ~BufferedStdinImpl() {
    info_.move_as_native_fd().release();
  }
  // Asks the reader thread to stop. The flag is only checked between reads, so a thread blocked
  // inside ReadFile notices it only after that read returns.
  void close() {
    close_flag_ = true;
  }

  // Buffer through which received stdin data is exposed to the owner.
  ChainBufferReader &input_buffer() {
    return reader_;
  }

  PollableFdInfo &get_poll_info() {
    return info_;
  }
  const PollableFdInfo &get_poll_info() const {
    return info_;
  }

  // Syncs info_ with the poller, clears the Read flag and makes data appended by the reader thread
  // visible through input_buffer().
  // NOTE(review): max_read is ignored on Windows. The return value is reader_.size(), presumably all
  // unread buffered data rather than only bytes received since the previous call; confirm against
  // ChainBufferReader::size().
  Result<size_t> flush_read(size_t max_read = std::numeric_limits<size_t>::max()) TD_WARN_UNUSED_RESULT {
    info_.sync_with_poll();
    info_.clear_flags(PollFlags::Read());
    reader_.sync_with_writer();
    return reader_.size();
  }

 private:
  PollableFdInfo info_;
  ChainBufferWriter writer_;  // written only by the reader thread
  ChainBufferReader reader_ = writer_.extract_reader();
  thread read_thread_;
  std::atomic<bool> close_flag_{false};
  IocpRef iocp_ref_;
  // 1 for the reader thread, plus 1 per in-flight in-loop IOCP post; the object deletes itself at 0.
  std::atomic<int> refcnt_{1};

  // Body of the reader thread: read, append, post a notification; repeat until close() or a read error.
  void read_loop() {
    // NOTE(review): a successful zero-byte read is not treated as end of input. If ReadFile keeps
    // returning 0 bytes (e.g. stdin redirected from a file that reached EOF), this loop keeps
    // re-reading and posting notifications. Confirm whether it should stop instead.
    while (!close_flag_) {
      auto slice = writer_.prepare_append();
      auto r_size = read(slice);
      if (r_size.is_error()) {
        LOG(ERROR) << "Stop read stdin loop: " << r_size.error();
        break;
      }
      writer_.confirm_append(r_size.ok());
      // Each post carries its own reference, dropped in on_iocp(); if the post fails, drop it here.
      inc_refcnt();
      if (!iocp_ref_.post(0, this, nullptr)) {
        dec_refcnt();
      }
    }
    // Final wake-up. No extra reference is taken, so on_iocp() for this post drops the reader
    // thread's initial reference. If the post fails, the thread detaches itself and drops that
    // reference directly. Presumably the detach avoids destroying a joinable read_thread_ from its
    // own thread.
    if (!iocp_ref_.post(0, this, nullptr)) {
      read_thread_.detach();
      dec_refcnt();
    }
  }
  // Called for completions posted by read_loop(); r_size and overlapped are unused.
  // Marks info_ readable and drops the reference associated with the post.
  void on_iocp(Result<size_t> r_size, WSAOVERLAPPED *overlapped) final {
    info_.add_flags_from_poll(PollFlags::Read());
    dec_refcnt();
  }

  // Drops one reference. Deletes the object and returns true when it was the last one.
  bool dec_refcnt() {
    if (--refcnt_ == 0) {
      delete this;
      return true;
    }
    return false;
  }
  // Takes an additional reference. Must not be called once the count has reached zero.
  void inc_refcnt() {
    CHECK(refcnt_ != 0);
    refcnt_++;
  }

  // Synchronous (no OVERLAPPED) ReadFile from the stdin handle into slice.
  // Returns the number of bytes read, or an OS error.
  Result<size_t> read(MutableSlice slice) {
    auto native_fd = info_.native_fd().fd();
    DWORD bytes_read = 0;
    auto res = ReadFile(native_fd, slice.data(), narrow_cast<DWORD>(slice.size()), &bytes_read, nullptr);
    if (res) {
      return static_cast<size_t>(bytes_read);
    }
    return OS_ERROR(PSLICE() << "Read from " << info_.native_fd() << " has failed");
  }
};
operator ()(BufferedStdinImpl * impl)168 void BufferedStdinImplDeleter::operator()(BufferedStdinImpl *impl) {
169 // LOG(ERROR) << "Close";
170 impl->close();
171 }
172 } // namespace detail
173 #elif TD_PORT_POSIX
174 namespace detail {
175 class BufferedStdinImpl {
176 public:
BufferedStdinImpl()177 BufferedStdinImpl() {
178 file_fd_ = FileFd::from_native_fd(NativeFd(Stdin().get_native_fd().fd()));
179 file_fd_.get_native_fd().set_is_blocking(false);
180 }
181 BufferedStdinImpl(const BufferedStdinImpl &) = delete;
182 BufferedStdinImpl &operator=(const BufferedStdinImpl &) = delete;
183 BufferedStdinImpl(BufferedStdinImpl &&) = delete;
184 BufferedStdinImpl &operator=(BufferedStdinImpl &&) = delete;
~BufferedStdinImpl()185 ~BufferedStdinImpl() {
186 file_fd_.get_native_fd().set_is_blocking(true);
187 file_fd_.move_as_native_fd().release();
188 }
189
input_buffer()190 ChainBufferReader &input_buffer() {
191 return reader_;
192 }
193
get_poll_info()194 PollableFdInfo &get_poll_info() {
195 return file_fd_.get_poll_info();
196 }
get_poll_info() const197 const PollableFdInfo &get_poll_info() const {
198 return file_fd_.get_poll_info();
199 }
200
flush_read(size_t max_read=std::numeric_limits<size_t>::max ())201 Result<size_t> flush_read(size_t max_read = std::numeric_limits<size_t>::max()) TD_WARN_UNUSED_RESULT {
202 size_t result = 0;
203 ::td::sync_with_poll(*this);
204 while (::td::can_read_local(*this) && max_read) {
205 MutableSlice slice = writer_.prepare_append();
206 slice.truncate(max_read);
207 TRY_RESULT(x, file_fd_.read(slice));
208 slice.truncate(x);
209 writer_.confirm_append(x);
210 result += x;
211 max_read -= x;
212 }
213 if (result) {
214 reader_.sync_with_writer();
215 }
216 return result;
217 }
218
219 private:
220 FileFd file_fd_;
221 ChainBufferWriter writer_;
222 ChainBufferReader reader_ = writer_.extract_reader();
223 };
operator ()(BufferedStdinImpl * impl)224 void BufferedStdinImplDeleter::operator()(BufferedStdinImpl *impl) {
225 delete impl;
226 }
227 } // namespace detail
228 #endif
229
BufferedStdin()230 BufferedStdin::BufferedStdin() : impl_(make_unique<detail::BufferedStdinImpl>().release()) {
231 }
232 BufferedStdin::BufferedStdin(BufferedStdin &&) noexcept = default;
233 BufferedStdin &BufferedStdin::operator=(BufferedStdin &&) noexcept = default;
234 BufferedStdin::~BufferedStdin() = default;
235
input_buffer()236 ChainBufferReader &BufferedStdin::input_buffer() {
237 return impl_->input_buffer();
238 }
get_poll_info()239 PollableFdInfo &BufferedStdin::get_poll_info() {
240 return impl_->get_poll_info();
241 }
get_poll_info() const242 const PollableFdInfo &BufferedStdin::get_poll_info() const {
243 return impl_->get_poll_info();
244 }
flush_read(size_t max_read)245 Result<size_t> BufferedStdin::flush_read(size_t max_read) {
246 return impl_->flush_read(max_read);
247 }
248
249 } // namespace td
250