1 // license:BSD-3-Clause
2 // copyright-holders:Aaron Giles, Vas Crabb
3 //============================================================
4 //
5 //  winsocket.c - Windows socket (inet) access functions
6 //
7 //============================================================
8 
9 
10 #include "winfile.h"
11 
12 // MAME headers
13 #include "osdcore.h"
14 
15 #include <cassert>
16 #include <cstdio>
17 
18 // standard windows headers
19 #include <windows.h>
20 #include <winioctl.h>
21 #include <tchar.h>
22 #include <cstdlib>
23 #include <cctype>
24 
25 
26 namespace {
// prefix that a path must start with to be treated as a socket specification
constexpr char winfile_socket_identifier[] = "socket.";
28 
29 
30 class win_osd_socket : public osd_file
31 {
32 public:
33 	win_osd_socket(win_osd_socket const &) = delete;
34 	win_osd_socket(win_osd_socket &&) = delete;
35 	win_osd_socket& operator=(win_osd_socket const &) = delete;
36 	win_osd_socket& operator=(win_osd_socket &&) = delete;
37 
win_osd_socket(SOCKET s,bool l)38 	win_osd_socket(SOCKET s, bool l)
39 		: m_socket(s)
40 		, m_listening(l)
41 	{
42 		assert(INVALID_SOCKET != m_socket);
43 	}
44 
~win_osd_socket()45 	virtual ~win_osd_socket() override
46 	{
47 		closesocket(m_socket);
48 	}
49 
read(void * buffer,std::uint64_t offset,std::uint32_t length,std::uint32_t & actual)50 	virtual error read(void *buffer, std::uint64_t offset, std::uint32_t length, std::uint32_t &actual) override
51 	{
52 		fd_set readfds;
53 		FD_ZERO(&readfds);
54 		FD_SET(m_socket, &readfds);
55 
56 		struct timeval timeout;
57 		timeout.tv_sec = timeout.tv_usec = 0;
58 
59 		if (select(m_socket + 1, &readfds, nullptr, nullptr, &timeout) < 0)
60 		{
61 			char line[80];
62 			std::sprintf(line, "win_read_socket : %s : %d ", __FILE__,  __LINE__);
63 			std::perror(line);
64 			return error::FAILURE;
65 		}
66 		else if (FD_ISSET(m_socket, &readfds))
67 		{
68 			if (!m_listening)
69 			{
70 				// connected socket
71 				int const result = recv(m_socket, (char*)buffer, length, 0);
72 				if (result < 0)
73 				{
74 					return wsa_error_to_file_error(WSAGetLastError());
75 				}
76 				else
77 				{
78 					actual = result;
79 					return error::NONE;
80 				}
81 			}
82 			else
83 			{
84 				// listening socket
85 				SOCKET const accepted = accept(m_socket, nullptr, nullptr);
86 				if (INVALID_SOCKET == accepted)
87 				{
88 					return wsa_error_to_file_error(WSAGetLastError());
89 				}
90 				else
91 				{
92 					closesocket(m_socket);
93 					m_socket = accepted;
94 					m_listening = false;
95 					actual = 0;
96 
97 					return error::NONE;
98 				}
99 			}
100 		}
101 		else
102 		{
103 			return error::FAILURE;
104 		}
105 	}
106 
write(void const * buffer,std::uint64_t offset,std::uint32_t length,std::uint32_t & actual)107 	virtual error write(void const *buffer, std::uint64_t offset, std::uint32_t length, std::uint32_t &actual) override
108 	{
109 		auto const result = send(m_socket, reinterpret_cast<const char *>(buffer), length, 0);
110 		if (result < 0)
111 			return wsa_error_to_file_error(WSAGetLastError());
112 
113 		actual = result;
114 		return error::NONE;
115 	}
116 
truncate(std::uint64_t offset)117 	virtual error truncate(std::uint64_t offset) override
118 	{
119 		// doesn't make sense for a socket
120 		return error::INVALID_ACCESS;
121 	}
122 
flush()123 	virtual error flush() override
124 	{
125 		// no buffers to flush
126 		return error::NONE;
127 	}
128 
wsa_error_to_file_error(int err)129 	static error wsa_error_to_file_error(int err)
130 	{
131 		switch (err)
132 		{
133 		case 0:                 return error::NONE;
134 		case WSAEACCES:         return error::ACCESS_DENIED;
135 		case WSAEADDRINUSE:     return error::ALREADY_OPEN;
136 		case WSAEADDRNOTAVAIL:  return error::NOT_FOUND;
137 		case WSAECONNREFUSED:   return error::NOT_FOUND;
138 		case WSAEHOSTUNREACH:   return error::NOT_FOUND;
139 		case WSAENETUNREACH:    return error::NOT_FOUND;
140 		default:                return error::FAILURE;
141 		}
142 	}
143 
144 private:
145 	SOCKET  m_socket;
146 	bool    m_listening;
147 };
148 
149 } // anonymous namespace
150 
151 
win_init_sockets()152 bool win_init_sockets()
153 {
154 	WSADATA wsaData;
155 	WORD const version = MAKEWORD(2, 0);
156 	int const error = WSAStartup(version, &wsaData);
157 
158 	// check for error
159 	if (error)
160 	{
161 		// error occurred
162 		return false;
163 	}
164 
165 	// check for correct version
166 	if (LOBYTE(wsaData.wVersion) != 2 || HIBYTE(wsaData.wVersion ) != 0)
167 	{
168 		// incorrect WinSock version
169 		WSACleanup();
170 		return false;
171 	}
172 
173 	// WinSock has been initialized
174 	return true;
175 }
176 
177 
win_cleanup_sockets()178 void win_cleanup_sockets()
179 {
180 	WSACleanup();
181 }
182 
183 
win_check_socket_path(std::string const & path)184 bool win_check_socket_path(std::string const &path)
185 {
186 	if (strncmp(path.c_str(), winfile_socket_identifier, strlen(winfile_socket_identifier)) == 0 &&
187 		strchr(path.c_str(), ':') != nullptr) return true;
188 	return false;
189 }
190 
191 
win_open_socket(std::string const & path,std::uint32_t openflags,osd_file::ptr & file,std::uint64_t & filesize)192 osd_file::error win_open_socket(std::string const &path, std::uint32_t openflags, osd_file::ptr &file, std::uint64_t &filesize)
193 {
194 	char hostname[256];
195 	int port;
196 	std::sscanf(&path[strlen(winfile_socket_identifier)], "%255[^:]:%d", hostname, &port);
197 
198 	struct hostent const *const localhost = gethostbyname(hostname);
199 	if (!localhost)
200 		return osd_file::error::NOT_FOUND;
201 
202 	struct sockaddr_in sai;
203 	memset(&sai, 0, sizeof(sai));
204 	sai.sin_family = AF_INET;
205 	sai.sin_port = htons(port);
206 	sai.sin_addr = *reinterpret_cast<struct in_addr *>(localhost->h_addr);
207 
208 	SOCKET sock = socket(AF_INET, SOCK_STREAM, 0);
209 	if (INVALID_SOCKET == sock)
210 		return win_osd_socket::wsa_error_to_file_error(WSAGetLastError());
211 
212 	int const flag = 1;
213 	if (setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char *>(&flag), sizeof(flag)) == SOCKET_ERROR)
214 	{
215 		int const err = WSAGetLastError();
216 		closesocket(sock);
217 		return win_osd_socket::wsa_error_to_file_error(err);
218 	}
219 
220 	// listening socket support
221 	if (openflags & OPEN_FLAG_CREATE)
222 	{
223 		//printf("Listening for client at '%s' on port '%d'\n", hostname, port);
224 		// bind socket...
225 		if (bind(sock, reinterpret_cast<struct sockaddr const *>(&sai), sizeof(struct sockaddr)) == SOCKET_ERROR)
226 		{
227 			int const err = WSAGetLastError();
228 			closesocket(sock);
229 			return win_osd_socket::wsa_error_to_file_error(err);
230 		}
231 
232 		// start to listen...
233 		if (listen(sock, 1) == SOCKET_ERROR)
234 		{
235 			int const err = WSAGetLastError();
236 			closesocket(sock);
237 			return win_osd_socket::wsa_error_to_file_error(err);
238 		}
239 
240 		// mark socket as "listening"
241 		try
242 		{
243 			file = std::make_unique<win_osd_socket>(sock, true);
244 			filesize = 0;
245 			return osd_file::error::NONE;
246 		}
247 		catch (...)
248 		{
249 			closesocket(sock);
250 			return osd_file::error::OUT_OF_MEMORY;
251 		}
252 	}
253 	else
254 	{
255 		//printf("Connecting to server '%s' on port '%d'\n", hostname, port);
256 		if (connect(sock, reinterpret_cast<struct sockaddr const *>(&sai), sizeof(struct sockaddr)) == SOCKET_ERROR)
257 		{
258 			closesocket(sock);
259 			return osd_file::error::ACCESS_DENIED; // have to return this value or bitb won't try to bind on connect failure
260 		}
261 		try
262 		{
263 			file = std::make_unique<win_osd_socket>(sock, false);
264 			filesize = 0;
265 			return osd_file::error::NONE;
266 		}
267 		catch (...)
268 		{
269 			closesocket(sock);
270 			return osd_file::error::OUT_OF_MEMORY;
271 		}
272 	}
273 }
274