1 /*-
2  ***********************************************************************
3  *
4  * $Id: socket.c,v 1.18 2012/01/07 07:56:14 mavrik Exp $
5  *
6  ***********************************************************************
7  *
8  * Copyright 2001-2012 The WebJob Project, All Rights Reserved.
9  *
10  ***********************************************************************
11  */
12 #include "all-includes.h"
13 
14 /*-
15  ***********************************************************************
16  *
17  * SocketCleanup
18  *
19  ***********************************************************************
20  */
21 void
SocketCleanup(SOCKET_CONTEXT * psSocketCTX)22 SocketCleanup(SOCKET_CONTEXT *psSocketCTX)
23 {
24   if (psSocketCTX != NULL)
25   {
26     if (psSocketCTX->iSocket != -1)
27     {
28 #ifdef WIN32
29       closesocket(psSocketCTX->iSocket);
30       WSACleanup();
31 #endif
32 #ifdef UNIX
33       close(psSocketCTX->iSocket);
34 #endif
35     }
36 #ifdef USE_SSL
37     if (psSocketCTX->iType == SOCKET_TYPE_SSL)
38     {
39       SslSessionCleanup(psSocketCTX->pssl);
40     }
41 #endif
42     free(psSocketCTX);
43   }
44 }
45 
46 
47 /*-
48  ***********************************************************************
49  *
50  * SocketConnect
51  *
52  ***********************************************************************
53  */
54 SOCKET_CONTEXT
SocketConnect(unsigned long ulIP,unsigned short usPort,int iType,void * psslCTX,char * pcError)55 *SocketConnect(unsigned long ulIP, unsigned short usPort, int iType, void *psslCTX, char *pcError)
56 {
57   const char          acRoutine[] = "SocketConnect()";
58 #ifdef USE_SSL
59   char                acLocalError[MESSAGE_SIZE] = "";
60 #endif
61   struct sockaddr_in  sServerAddr;
62   SOCKET_CONTEXT     *psSocketCTX;
63 
64 #ifdef WIN32
65   DWORD               dwStatus;
66   WORD                wVersion;
67   WSADATA             wsaData;
68 #endif
69 
70   /*-
71    *********************************************************************
72    *
73    * Initialize a socket context.
74    *
75    *********************************************************************
76    */
77   psSocketCTX = malloc(sizeof(SOCKET_CONTEXT));
78   if (psSocketCTX == NULL)
79   {
80     snprintf(pcError, MESSAGE_SIZE, "%s: malloc(): %s", acRoutine, strerror(errno));
81     return NULL;
82   }
83   memset(psSocketCTX, 0, sizeof(SOCKET_CONTEXT));
84   psSocketCTX->iSocket = -1;
85 
86   /*-
87    ***********************************************************************
88    *
89    * Create a socket, and open a TCP connection to the server.
90    *
91    ***********************************************************************
92    */
93   sServerAddr.sin_family = AF_INET;
94   sServerAddr.sin_addr.s_addr = ulIP;
95   sServerAddr.sin_port = htons(usPort);
96 
97 #ifdef WIN32
98   wVersion = (WORD)(1) | ((WORD)(1) << 8); /* MAKEWORD(1, 1) */
99   if ((dwStatus = WSAStartup(wVersion, &wsaData)) != 0)
100   {
101     snprintf(pcError, MESSAGE_SIZE, "%s: WSAStartup(): %u", acRoutine, dwStatus);
102     SocketCleanup(psSocketCTX);
103     return NULL;
104   }
105 
106   psSocketCTX->iSocket = socket(PF_INET, SOCK_STREAM, 0);
107   if (psSocketCTX->iSocket == INVALID_SOCKET)
108   {
109     snprintf(pcError, MESSAGE_SIZE, "%s: socket(): %u", acRoutine, WSAGetLastError());
110     SocketCleanup(psSocketCTX);
111     return NULL;
112   }
113 
114   if (connect(psSocketCTX->iSocket, (struct sockaddr *) & sServerAddr, sizeof(sServerAddr)) == SOCKET_ERROR)
115   {
116     snprintf(pcError, MESSAGE_SIZE, "%s: connect(): %u", acRoutine, WSAGetLastError());
117     SocketCleanup(psSocketCTX);
118     return NULL;
119   }
120 #endif
121 
122 #ifdef UNIX
123   psSocketCTX->iSocket = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
124   if (psSocketCTX->iSocket == -1)
125   {
126     snprintf(pcError, MESSAGE_SIZE, "%s: socket(): %s", acRoutine, strerror(errno));
127     SocketCleanup(psSocketCTX);
128     return NULL;
129   }
130 
131   if (connect(psSocketCTX->iSocket, (struct sockaddr *) & sServerAddr, sizeof(sServerAddr)) == -1)
132   {
133     snprintf(pcError, MESSAGE_SIZE, "%s: connect(): %s", acRoutine, strerror(errno));
134     SocketCleanup(psSocketCTX);
135     return NULL;
136   }
137 #endif
138 
139 #ifdef USE_SSL
140   /*-
141    ***********************************************************************
142    *
143    * We have a TCP conncetion... begin SSL negotiations.
144    *
145    ***********************************************************************
146    */
147   if (iType == SOCKET_TYPE_SSL)
148   {
149     if (psslCTX != NULL)
150     {
151       psSocketCTX->psslCTX = (SSL_CTX *) psslCTX;
152       psSocketCTX->pssl = SslConnect(psSocketCTX->iSocket, psSocketCTX->psslCTX, acLocalError);
153       if (psSocketCTX->pssl == NULL)
154       {
155         snprintf(pcError, MESSAGE_SIZE, "%s: %s", acRoutine, acLocalError);
156         SocketCleanup(psSocketCTX);
157         return NULL;
158       }
159     }
160     else
161     {
162       snprintf(pcError, MESSAGE_SIZE, "%s: Undefined SSL_CTX.", acRoutine);
163       SocketCleanup(psSocketCTX);
164       return NULL;
165     }
166   }
167 #endif
168 
169   psSocketCTX->iType = iType;
170 
171   return psSocketCTX;
172 }
173 
174 
175 /*-
176  ***********************************************************************
177  *
178  * SocketRead
179  *
180  ***********************************************************************
181  */
182 int
SocketRead(SOCKET_CONTEXT * psSocketCTX,char * pcData,int iToRead,char * pcError)183 SocketRead(SOCKET_CONTEXT *psSocketCTX, char *pcData, int iToRead, char *pcError)
184 {
185   const char          acRoutine[] = "SocketRead()";
186 #ifdef USE_SSL
187   char                acLocalError[MESSAGE_SIZE] = "";
188 #endif
189   int                 iNRead;
190 
191   switch (psSocketCTX->iType)
192   {
193 #ifdef USE_SSL
194   case SOCKET_TYPE_SSL:
195     iNRead = SslRead(psSocketCTX->pssl, pcData, iToRead, acLocalError);
196     if (iNRead == -1)
197     {
198       snprintf(pcError, MESSAGE_SIZE, "%s: %s", acRoutine, acLocalError);
199     }
200     break;
201 #endif
202   default:
203     iNRead = recv(psSocketCTX->iSocket, pcData, iToRead, 0);
204     if (iNRead == -1)
205     {
206       snprintf(pcError, MESSAGE_SIZE, "%s: recv(): %s", acRoutine, strerror(errno));
207     }
208     break;
209   }
210 
211   return iNRead;
212 }
213 
214 
215 /*-
216  ***********************************************************************
217  *
218  * SocketWrite
219  *
220  ***********************************************************************
221  */
222 int
SocketWrite(SOCKET_CONTEXT * psSocketCTX,char * pcData,int iToSend,char * pcError)223 SocketWrite(SOCKET_CONTEXT *psSocketCTX, char *pcData, int iToSend, char *pcError)
224 {
225   const char          acRoutine[] = "SocketWrite()";
226 #ifdef USE_SSL
227   char                acLocalError[MESSAGE_SIZE] = "";
228 #endif
229   int                 iNSent;
230 
231   if (iToSend == 0)
232   {
233     return 0;
234   }
235 
236   switch (psSocketCTX->iType)
237   {
238 #ifdef USE_SSL
239   case SOCKET_TYPE_SSL:
240     iNSent = SslWrite(psSocketCTX->pssl, pcData, iToSend, acLocalError);
241     if (iNSent == -1)
242     {
243       snprintf(pcError, MESSAGE_SIZE, "%s: %s", acRoutine, acLocalError);
244     }
245     break;
246 #endif
247   default:
248     iNSent = send(psSocketCTX->iSocket, pcData, iToSend, 0);
249     if (iNSent == -1)
250     {
251       snprintf(pcError, MESSAGE_SIZE, "%s: send(): %s", acRoutine, strerror(errno));
252     }
253     break;
254   }
255 
256   return iNSent;
257 }
258