1 //===- llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h ----------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H 10 #define LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H 11 12 #include "llvm/ADT/StringRef.h" 13 #include "llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h" 14 #include "llvm/Support/Endian.h" 15 #include "llvm/Support/Error.h" 16 #include <cstdint> 17 #include <mutex> 18 #include <string> 19 #include <type_traits> 20 21 namespace llvm { 22 namespace orc { 23 namespace rpc { 24 25 /// Interface for byte-streams to be used with RPC. 26 class RawByteChannel { 27 public: 28 virtual ~RawByteChannel() = default; 29 30 /// Read Size bytes from the stream into *Dst. 31 virtual Error readBytes(char *Dst, unsigned Size) = 0; 32 33 /// Read size bytes from *Src and append them to the stream. 34 virtual Error appendBytes(const char *Src, unsigned Size) = 0; 35 36 /// Flush the stream if possible. 37 virtual Error send() = 0; 38 39 /// Notify the channel that we're starting a message send. 40 /// Locks the channel for writing. 41 template <typename FunctionIdT, typename SequenceIdT> startSendMessage(const FunctionIdT & FnId,const SequenceIdT & SeqNo)42 Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { 43 writeLock.lock(); 44 if (auto Err = serializeSeq(*this, FnId, SeqNo)) { 45 writeLock.unlock(); 46 return Err; 47 } 48 return Error::success(); 49 } 50 51 /// Notify the channel that we're ending a message send. 52 /// Unlocks the channel for writing. endSendMessage()53 Error endSendMessage() { 54 writeLock.unlock(); 55 return Error::success(); 56 } 57 58 /// Notify the channel that we're starting a message receive. 59 /// Locks the channel for reading. 60 template <typename FunctionIdT, typename SequenceNumberT> startReceiveMessage(FunctionIdT & FnId,SequenceNumberT & SeqNo)61 Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { 62 readLock.lock(); 63 if (auto Err = deserializeSeq(*this, FnId, SeqNo)) { 64 readLock.unlock(); 65 return Err; 66 } 67 return Error::success(); 68 } 69 70 /// Notify the channel that we're ending a message receive. 71 /// Unlocks the channel for reading. endReceiveMessage()72 Error endReceiveMessage() { 73 readLock.unlock(); 74 return Error::success(); 75 } 76 77 /// Get the lock for stream reading. getReadLock()78 std::mutex &getReadLock() { return readLock; } 79 80 /// Get the lock for stream writing. getWriteLock()81 std::mutex &getWriteLock() { return writeLock; } 82 83 private: 84 std::mutex readLock, writeLock; 85 }; 86 87 template <typename ChannelT, typename T> 88 class SerializationTraits< 89 ChannelT, T, T, 90 std::enable_if_t< 91 std::is_base_of<RawByteChannel, ChannelT>::value && 92 (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value || 93 std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value || 94 std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value || 95 std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value || 96 std::is_same<T, char>::value)>> { 97 public: serialize(ChannelT & C,T V)98 static Error serialize(ChannelT &C, T V) { 99 support::endian::byte_swap<T, support::big>(V); 100 return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); 101 }; 102 deserialize(ChannelT & C,T & V)103 static Error deserialize(ChannelT &C, T &V) { 104 if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) 105 return Err; 106 support::endian::byte_swap<T, support::big>(V); 107 return Error::success(); 108 }; 109 }; 110 111 template <typename ChannelT> 112 class SerializationTraits< 113 ChannelT, bool, bool, 114 std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { 115 public: serialize(ChannelT & C,bool V)116 static Error serialize(ChannelT &C, bool V) { 117 uint8_t Tmp = V ? 1 : 0; 118 if (auto Err = 119 C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1)) 120 return Err; 121 return Error::success(); 122 } 123 deserialize(ChannelT & C,bool & V)124 static Error deserialize(ChannelT &C, bool &V) { 125 uint8_t Tmp = 0; 126 if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1)) 127 return Err; 128 V = Tmp != 0; 129 return Error::success(); 130 } 131 }; 132 133 template <typename ChannelT> 134 class SerializationTraits< 135 ChannelT, std::string, StringRef, 136 std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { 137 public: 138 /// RPC channel serialization for std::strings. serialize(RawByteChannel & C,StringRef S)139 static Error serialize(RawByteChannel &C, StringRef S) { 140 if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) 141 return Err; 142 return C.appendBytes((const char *)S.data(), S.size()); 143 } 144 }; 145 146 template <typename ChannelT, typename T> 147 class SerializationTraits< 148 ChannelT, std::string, T, 149 std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value && 150 (std::is_same<T, const char *>::value || 151 std::is_same<T, char *>::value)>> { 152 public: serialize(RawByteChannel & C,const char * S)153 static Error serialize(RawByteChannel &C, const char *S) { 154 return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, 155 S); 156 } 157 }; 158 159 template <typename ChannelT> 160 class SerializationTraits< 161 ChannelT, std::string, std::string, 162 std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { 163 public: 164 /// RPC channel serialization for std::strings. serialize(RawByteChannel & C,const std::string & S)165 static Error serialize(RawByteChannel &C, const std::string &S) { 166 return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, 167 S); 168 } 169 170 /// RPC channel deserialization for std::strings. deserialize(RawByteChannel & C,std::string & S)171 static Error deserialize(RawByteChannel &C, std::string &S) { 172 uint64_t Count = 0; 173 if (auto Err = deserializeSeq(C, Count)) 174 return Err; 175 S.resize(Count); 176 return C.readBytes(&S[0], Count); 177 } 178 }; 179 180 } // end namespace rpc 181 } // end namespace orc 182 } // end namespace llvm 183 184 #endif // LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H 185