1 //
2 // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 #include "td/utils/port/StdStreams.h"
8 
9 #include "td/utils/logging.h"
10 #include "td/utils/misc.h"
11 #include "td/utils/port/detail/Iocp.h"
12 #include "td/utils/port/detail/NativeFd.h"
13 #include "td/utils/port/detail/PollableFd.h"
14 #include "td/utils/port/PollFlags.h"
15 #include "td/utils/port/thread.h"
16 #include "td/utils/ScopeGuard.h"
17 #include "td/utils/Slice.h"
18 #include "td/utils/SliceBuilder.h"
19 
20 #include <atomic>
21 
22 namespace td {
23 
24 #if TD_PORT_POSIX
25 template <int id>
get_file_fd()26 static FileFd &get_file_fd() {
27   static FileFd result = FileFd::from_native_fd(NativeFd(id, true));
28   static auto guard = ScopeExit() + [&] {
29     result.move_as_native_fd().release();
30   };
31   return result;
32 }
33 
Stdin()34 FileFd &Stdin() {
35   return get_file_fd<0>();
36 }
Stdout()37 FileFd &Stdout() {
38   return get_file_fd<1>();
39 }
Stderr()40 FileFd &Stderr() {
41   return get_file_fd<2>();
42 }
43 #elif TD_PORT_WINDOWS
44 template <DWORD id>
45 static FileFd &get_file_fd() {
46 #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM)
47   static auto handle = GetStdHandle(id);
48   LOG_IF(FATAL, handle == INVALID_HANDLE_VALUE) << "Failed to GetStdHandle " << id;
49   static FileFd result = FileFd::from_native_fd(NativeFd(handle, true));
50   static auto guard = ScopeExit() + [&] {
51     result.move_as_native_fd().release();
52   };
53 #else
54   static FileFd result;
55 #endif
56   return result;
57 }
58 
59 FileFd &Stdin() {
60   return get_file_fd<STD_INPUT_HANDLE>();
61 }
62 FileFd &Stdout() {
63   return get_file_fd<STD_OUTPUT_HANDLE>();
64 }
65 FileFd &Stderr() {
66   return get_file_fd<STD_ERROR_HANDLE>();
67 }
68 #endif
69 
70 #if TD_PORT_WINDOWS
71 namespace detail {
// Windows implementation of the buffered stdin reader.
//
// A dedicated thread (read_thread_) repeatedly calls ReadFile on the standard
// input handle and appends the data to writer_. After each chunk it posts a
// completion to the IOCP; on_iocp() then marks info_ as readable, and
// flush_read() publishes the appended data to reader_.
//
// Lifetime uses an intrusive reference count (refcnt_):
// - the object starts with one reference;
// - each successful post in the read loop holds one extra reference, which
//   on_iocp() drops;
// - the final post after the loop consumes the initial reference;
// - dec_refcnt() deletes the object when the count reaches zero.
class BufferedStdinImpl final : private Iocp::Callback {
 public:
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM)
  // Wraps the process stdin handle, takes a reference to the global IOCP and
  // starts the reading thread.
  // NOTE(review): read_loop() may call read_thread_.detach() (when its final
  // post fails). If that happened before this assignment to read_thread_
  // completes, the thread object would be accessed concurrently. Confirm
  // this cannot occur in practice.
  BufferedStdinImpl() : info_(NativeFd(GetStdHandle(STD_INPUT_HANDLE), true)) {
    iocp_ref_ = Iocp::get()->get_ref();
    read_thread_ = thread([this] { this->read_loop(); });
  }
#else
  // No reading thread is started; the reader is marked closed immediately.
  // NOTE(review): in this configuration nothing ever calls dec_refcnt(), and
  // the deleter only calls close(), so the object appears to be leaked. Confirm.
  BufferedStdinImpl() {
    close();
  }
#endif
  BufferedStdinImpl(const BufferedStdinImpl &) = delete;
  BufferedStdinImpl &operator=(const BufferedStdinImpl &) = delete;
  BufferedStdinImpl(BufferedStdinImpl &&) = delete;
  BufferedStdinImpl &operator=(BufferedStdinImpl &&) = delete;
  // Detaches the native handle from info_ via release(). Presumably this is so
  // destroying info_ does not close the process stdin handle.
  ~BufferedStdinImpl() {
    info_.move_as_native_fd().release();
  }
  // Asks the read loop to stop. read_loop() checks close_flag_ only between
  // reads, so this takes effect once the ReadFile in progress returns.
  void close() {
    close_flag_ = true;
  }

  // Reader side of the chain buffer filled by the read thread.
  ChainBufferReader &input_buffer() {
    return reader_;
  }

  PollableFdInfo &get_poll_info() {
    return info_;
  }
  const PollableFdInfo &get_poll_info() const {
    return info_;
  }

  // Syncs info_ with the poller, clears its Read flag, and publishes data
  // appended by the read thread to reader_.
  // Returns the total number of bytes currently held by reader_, not only
  // those added since the previous call.
  // NOTE(review): max_read is ignored by this implementation.
  Result<size_t> flush_read(size_t max_read = std::numeric_limits<size_t>::max()) TD_WARN_UNUSED_RESULT {
    info_.sync_with_poll();
    info_.clear_flags(PollFlags::Read());
    reader_.sync_with_writer();
    return reader_.size();
  }

 private:
  PollableFdInfo info_;
  ChainBufferWriter writer_;
  // Declared after writer_ because it is initialized from writer_.
  ChainBufferReader reader_ = writer_.extract_reader();
  thread read_thread_;
  // Set by close(); polled by read_loop() on the reading thread.
  std::atomic<bool> close_flag_{false};
  IocpRef iocp_ref_;
  // Intrusive reference count; see the class comment.
  std::atomic<int> refcnt_{1};

  // Body of read_thread_.
  // Reads chunks into writer_ until close() is requested or a read fails,
  // and posts one IOCP completion per chunk. Each post holds a reference,
  // which is dropped immediately if the post fails.
  // After the loop, a final post hands the initial reference over to
  // on_iocp(). If that post fails, the thread detaches itself and drops the
  // reference directly, which may delete the object.
  void read_loop() {
    while (!close_flag_) {
      auto slice = writer_.prepare_append();
      auto r_size = read(slice);
      if (r_size.is_error()) {
        LOG(ERROR) << "Stop read stdin loop: " << r_size.error();
        break;
      }
      writer_.confirm_append(r_size.ok());
      inc_refcnt();
      if (!iocp_ref_.post(0, this, nullptr)) {
        dec_refcnt();
      }
    }
    if (!iocp_ref_.post(0, this, nullptr)) {
      read_thread_.detach();
      dec_refcnt();
    }
  }
  // IOCP callback for each completion posted by read_loop(). It marks info_
  // as readable and drops the reference held by that post. Both parameters
  // are unused.
  void on_iocp(Result<size_t> r_size, WSAOVERLAPPED *overlapped) final {
    info_.add_flags_from_poll(PollFlags::Read());
    dec_refcnt();
  }

  // Drops one reference; deletes the object and returns true when it was the
  // last one.
  bool dec_refcnt() {
    if (--refcnt_ == 0) {
      delete this;
      return true;
    }
    return false;
  }
  // Adds a reference; the object must still be alive (count non-zero).
  void inc_refcnt() {
    CHECK(refcnt_ != 0);
    refcnt_++;
  }

  // One synchronous ReadFile (no OVERLAPPED) into slice. Returns the number
  // of bytes read, or the OS error.
  Result<size_t> read(MutableSlice slice) {
    auto native_fd = info_.native_fd().fd();
    DWORD bytes_read = 0;
    auto res = ReadFile(native_fd, slice.data(), narrow_cast<DWORD>(slice.size()), &bytes_read, nullptr);
    if (res) {
      return static_cast<size_t>(bytes_read);
    }
    return OS_ERROR(PSLICE() << "Read from " << info_.native_fd() << " has failed");
  }
};
operator ()(BufferedStdinImpl * impl)168 void BufferedStdinImplDeleter::operator()(BufferedStdinImpl *impl) {
169   //  LOG(ERROR) << "Close";
170   impl->close();
171 }
172 }  // namespace detail
173 #elif TD_PORT_POSIX
174 namespace detail {
175 class BufferedStdinImpl {
176  public:
BufferedStdinImpl()177   BufferedStdinImpl() {
178     file_fd_ = FileFd::from_native_fd(NativeFd(Stdin().get_native_fd().fd()));
179     file_fd_.get_native_fd().set_is_blocking(false);
180   }
181   BufferedStdinImpl(const BufferedStdinImpl &) = delete;
182   BufferedStdinImpl &operator=(const BufferedStdinImpl &) = delete;
183   BufferedStdinImpl(BufferedStdinImpl &&) = delete;
184   BufferedStdinImpl &operator=(BufferedStdinImpl &&) = delete;
~BufferedStdinImpl()185   ~BufferedStdinImpl() {
186     file_fd_.get_native_fd().set_is_blocking(true);
187     file_fd_.move_as_native_fd().release();
188   }
189 
input_buffer()190   ChainBufferReader &input_buffer() {
191     return reader_;
192   }
193 
get_poll_info()194   PollableFdInfo &get_poll_info() {
195     return file_fd_.get_poll_info();
196   }
get_poll_info() const197   const PollableFdInfo &get_poll_info() const {
198     return file_fd_.get_poll_info();
199   }
200 
flush_read(size_t max_read=std::numeric_limits<size_t>::max ())201   Result<size_t> flush_read(size_t max_read = std::numeric_limits<size_t>::max()) TD_WARN_UNUSED_RESULT {
202     size_t result = 0;
203     ::td::sync_with_poll(*this);
204     while (::td::can_read_local(*this) && max_read) {
205       MutableSlice slice = writer_.prepare_append();
206       slice.truncate(max_read);
207       TRY_RESULT(x, file_fd_.read(slice));
208       slice.truncate(x);
209       writer_.confirm_append(x);
210       result += x;
211       max_read -= x;
212     }
213     if (result) {
214       reader_.sync_with_writer();
215     }
216     return result;
217   }
218 
219  private:
220   FileFd file_fd_;
221   ChainBufferWriter writer_;
222   ChainBufferReader reader_ = writer_.extract_reader();
223 };
operator ()(BufferedStdinImpl * impl)224 void BufferedStdinImplDeleter::operator()(BufferedStdinImpl *impl) {
225   delete impl;
226 }
227 }  // namespace detail
228 #endif
229 
BufferedStdin()230 BufferedStdin::BufferedStdin() : impl_(make_unique<detail::BufferedStdinImpl>().release()) {
231 }
// Move operations and the destructor are defaulted here, in the .cpp file,
// presumably so that detail::BufferedStdinImpl and its deleter only need to
// be complete in this translation unit. TODO confirm against the header.
BufferedStdin::BufferedStdin(BufferedStdin &&) noexcept = default;
BufferedStdin &BufferedStdin::operator=(BufferedStdin &&) noexcept = default;
BufferedStdin::~BufferedStdin() = default;
235 
input_buffer()236 ChainBufferReader &BufferedStdin::input_buffer() {
237   return impl_->input_buffer();
238 }
get_poll_info()239 PollableFdInfo &BufferedStdin::get_poll_info() {
240   return impl_->get_poll_info();
241 }
get_poll_info() const242 const PollableFdInfo &BufferedStdin::get_poll_info() const {
243   return impl_->get_poll_info();
244 }
flush_read(size_t max_read)245 Result<size_t> BufferedStdin::flush_read(size_t max_read) {
246   return impl_->flush_read(max_read);
247 }
248 
249 }  // namespace td
250