1 /*
2  * This file contains SW Implementation of checksum computation for IP,TCP,UDP
3  *
4  * Copyright (c) 2008-2017 Red Hat, Inc.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met :
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and / or other materials provided with the distribution.
14  * 3. Neither the names of the copyright holders nor the names of their contributors
15  *    may be used to endorse or promote products derived from this software
16  *    without specific prior written permission.
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED.IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE
21  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  */
29 #include "ndis56common.h"
30 
31 #ifdef WPP_EVENT_TRACING
32 #include "sw-offload.tmh"
33 #endif
34 #include <sal.h>
35 
// Largest accepted size, in bytes, of an IPv6 header chain (fixed header
// plus extension headers); bounded because the IP header size is kept in 8 bits
#define MAX_SUPPORTED_IPV6_HEADERS  (256 - 4)

// IPv6 address held as four 32-bit words
typedef ULONG IPV6_ADDRESS[4];
40 
41 // IPv6 Header RFC 2460 (40 bytes)
42 typedef struct _tagIPv6Header {
43     UCHAR       ip6_ver_tc;            // traffic class(low nibble), version (high nibble)
44     UCHAR       ip6_tc_fl;             // traffic class(high nibble), flow label
45     USHORT      ip6_fl;                // flow label, the rest
46     USHORT      ip6_payload_len;       // length of following headers and payload
47     UCHAR       ip6_next_header;       // next header type
48     UCHAR       ip6_hoplimit;          // hop limit
49     IPV6_ADDRESS ip6_src_address;    //
50     IPV6_ADDRESS ip6_dst_address;    //
51 } IPv6Header;
52 
// Overlay of both IP header layouts over the start of one packet buffer;
// the parser reads v4.ip_verlen first to find out which member is valid
typedef union
{
    IPv6Header v6;
    IPv4Header v4;
} IPHeader;
58 
59 // IPv6 Header RFC 2460 (n*8 bytes)
60 typedef struct _tagIPv6ExtHeader {
61     UCHAR       ip6ext_next_header;     // next header type
62     UCHAR       ip6ext_hdr_len;         // length of this header in 8 bytes unit, not including first 8 bytes
63     USHORT      options;                //
64 } IPv6ExtHeader;
65 
66 // IP Pseudo Header RFC 768
67 typedef struct _tagIPv4PseudoHeader {
68     ULONG       ipph_src;               // Source address
69     ULONG       ipph_dest;              // Destination address
70     UCHAR       ipph_zero;              // 0
71     UCHAR       ipph_protocol;          // TCP/UDP
72     USHORT      ipph_length;            // TCP/UDP length
73 }tIPv4PseudoHeader;
74 
75 // IPv6 Pseudo Header RFC 2460
76 typedef struct _tagIPv6PseudoHeader {
77     IPV6_ADDRESS ipph_src;              // Source address
78     IPV6_ADDRESS ipph_dest;             // Destination address
79     ULONG        ipph_length;               // TCP/UDP length
80     UCHAR        z1;                // 0
81     UCHAR        z2;                // 0
82     UCHAR        z3;                // 0
83     UCHAR        ipph_protocol;             // TCP/UDP
84 }tIPv6PseudoHeader;
85 
86 
// Protocol numbers of the two transport protocols handled in this file
#define PROTOCOL_TCP                    6
#define PROTOCOL_UDP                    17

// Header length in bytes: the low nibble of ip_verlen multiplied by 4.
// The argument is any expression yielding a pointer to a structure
// that has an ip_verlen member.
#define IP_HEADER_LENGTH(pHeader)  (((pHeader)->ip_verlen & 0x0F) << 2)

// Header length in bytes: the value held in bits 4..7 of tcp_flags multiplied by 4.
// The argument is parenthesized so that expressions such as &hdr or (p + 1)
// can be passed safely.
#define TCP_HEADER_LENGTH(pHeader) ((((pHeader)->tcp_flags) & 0xF0) >> 2)
93 
94 
95 
CheckSumCalculator(ULONG val,PVOID buffer,ULONG len)96 static __inline USHORT CheckSumCalculator(ULONG val, PVOID buffer, ULONG len)
97 {
98     PUSHORT pus = (PUSHORT)buffer;
99     ULONG count = len >> 1;
100     while (count--) val += *pus++;
101     if (len & 1) val += (USHORT)*(PUCHAR)pus;
102     val = (((val >> 16) | (val << 16)) + val) >> 16;
103     return (USHORT)~val;
104 }
105 
106 
107 /******************************************
108     IP header checksum calculator
109 *******************************************/
CalculateIpChecksum(IPv4Header * pIpHeader)110 static __inline VOID CalculateIpChecksum(IPv4Header *pIpHeader)
111 {
112     pIpHeader->ip_xsum = 0;
113     pIpHeader->ip_xsum = CheckSumCalculator(0, pIpHeader, IP_HEADER_LENGTH(pIpHeader));
114 }
115 
116 static __inline tTcpIpPacketParsingResult
ProcessTCPHeader(tTcpIpPacketParsingResult _res,PVOID pIpHeader,ULONG len,USHORT ipHeaderSize)117 ProcessTCPHeader(tTcpIpPacketParsingResult _res, PVOID pIpHeader, ULONG len, USHORT ipHeaderSize)
118 {
119     ULONG tcpipDataAt;
120     tTcpIpPacketParsingResult res = _res;
121     tcpipDataAt = ipHeaderSize + sizeof(TCPHeader);
122     res.xxpStatus = ppresXxpIncomplete;
123     res.TcpUdp = ppresIsTCP;
124 
125     if (len >= tcpipDataAt)
126     {
127         TCPHeader *pTcpHeader = (TCPHeader *)RtlOffsetToPointer(pIpHeader, ipHeaderSize);
128         res.xxpStatus = ppresXxpKnown;
129         tcpipDataAt = ipHeaderSize + TCP_HEADER_LENGTH(pTcpHeader);
130         res.XxpIpHeaderSize = tcpipDataAt;
131     }
132     else
133     {
134         DPrintf(2, ("tcp: %d < min headers %d", len, tcpipDataAt));
135     }
136     return res;
137 }
138 
139 static __inline tTcpIpPacketParsingResult
ProcessUDPHeader(tTcpIpPacketParsingResult _res,PVOID pIpHeader,ULONG len,USHORT ipHeaderSize)140 ProcessUDPHeader(tTcpIpPacketParsingResult _res, PVOID pIpHeader, ULONG len, USHORT ipHeaderSize)
141 {
142     tTcpIpPacketParsingResult res = _res;
143     ULONG udpDataStart = ipHeaderSize + sizeof(UDPHeader);
144     res.xxpStatus = ppresXxpIncomplete;
145     res.TcpUdp = ppresIsUDP;
146     res.XxpIpHeaderSize = udpDataStart;
147     if (len >= udpDataStart)
148     {
149         UDPHeader *pUdpHeader = (UDPHeader *)RtlOffsetToPointer(pIpHeader, ipHeaderSize);
150         USHORT datagramLength = swap_short(pUdpHeader->udp_length);
151         res.xxpStatus = ppresXxpKnown;
152         // may be full or not, but the datagram length is known
153         DPrintf(2, ("udp: len %d, datagramLength %d", len, datagramLength));
154     }
155     return res;
156 }
157 
158 static __inline tTcpIpPacketParsingResult
QualifyIpPacket(IPHeader * pIpHeader,ULONG len)159 QualifyIpPacket(IPHeader *pIpHeader, ULONG len)
160 {
161     tTcpIpPacketParsingResult res;
162     UCHAR  ver_len = pIpHeader->v4.ip_verlen;
163     UCHAR  ip_version = (ver_len & 0xF0) >> 4;
164     USHORT ipHeaderSize = 0;
165     USHORT fullLength = 0;
166     res.value = 0;
167 
168     if (ip_version == 4)
169     {
170         ipHeaderSize = (ver_len & 0xF) << 2;
171         fullLength = swap_short(pIpHeader->v4.ip_length);
172         DPrintf(3, ("ip_version %d, ipHeaderSize %d, protocol %d, iplen %d",
173             ip_version, ipHeaderSize, pIpHeader->v4.ip_protocol, fullLength));
174         res.ipStatus = (ipHeaderSize >= sizeof(IPv4Header)) ? ppresIPV4 : ppresNotIP;
175         if (len < ipHeaderSize) res.ipCheckSum = ppresIPTooShort;
176         if (fullLength) {}
177         else
178         {
179             DPrintf(2, ("ip v.%d, iplen %d", ip_version, fullLength));
180         }
181     }
182     else if (ip_version == 6)
183     {
184         UCHAR nextHeader = pIpHeader->v6.ip6_next_header;
185         BOOLEAN bParsingDone = FALSE;
186         ipHeaderSize = sizeof(pIpHeader->v6);
187         res.ipStatus = ppresIPV6;
188         res.ipCheckSum = ppresCSOK;
189         fullLength = swap_short(pIpHeader->v6.ip6_payload_len);
190         fullLength += ipHeaderSize;
191         while (nextHeader != 59)
192         {
193             IPv6ExtHeader *pExt;
194             switch (nextHeader)
195             {
196                 case PROTOCOL_TCP:
197                     bParsingDone = TRUE;
198                     res.xxpStatus = ppresXxpKnown;
199                     res.TcpUdp = ppresIsTCP;
200                     res.xxpFull = len >= fullLength ? 1 : 0;
201                     res = ProcessTCPHeader(res, pIpHeader, len, ipHeaderSize);
202                     break;
203                 case PROTOCOL_UDP:
204                     bParsingDone = TRUE;
205                     res.xxpStatus = ppresXxpKnown;
206                     res.TcpUdp = ppresIsUDP;
207                     res.xxpFull = len >= fullLength ? 1 : 0;
208                     res = ProcessUDPHeader(res, pIpHeader, len, ipHeaderSize);
209                     break;
210                     //existing extended headers
211                 case 0:
212                     __fallthrough;
213                 case 60:
214                     __fallthrough;
215                 case 43:
216                     __fallthrough;
217                 case 44:
218                     __fallthrough;
219                 case 51:
220                     __fallthrough;
221                 case 50:
222                     __fallthrough;
223                 case 135:
224                     if (len >= ((ULONG)ipHeaderSize + 8))
225                     {
226                         pExt = (IPv6ExtHeader *)((PUCHAR)pIpHeader + ipHeaderSize);
227                         nextHeader = pExt->ip6ext_next_header;
228                         ipHeaderSize += 8;
229                         ipHeaderSize += pExt->ip6ext_hdr_len * 8;
230                     }
231                     else
232                     {
233                         DPrintf(0, ("[%s] ERROR: Break in the middle of ext. headers(len %d, hdr > %d)", __FUNCTION__, len, ipHeaderSize));
234                         res.ipStatus = ppresNotIP;
235                         bParsingDone = TRUE;
236                     }
237                     break;
238                     //any other protocol
239                 default:
240                     res.xxpStatus = ppresXxpOther;
241                     bParsingDone = TRUE;
242                     break;
243             }
244             if (bParsingDone)
245                 break;
246         }
247         if (ipHeaderSize <= MAX_SUPPORTED_IPV6_HEADERS)
248         {
249             DPrintf(3, ("ip_version %d, ipHeaderSize %d, protocol %d, iplen %d",
250                 ip_version, ipHeaderSize, nextHeader, fullLength));
251             res.ipHeaderSize = ipHeaderSize;
252         }
253         else
254         {
255             DPrintf(0, ("[%s] ERROR: IP chain is too large (%d)", __FUNCTION__, ipHeaderSize));
256             res.ipStatus = ppresNotIP;
257         }
258     }
259 
260     if (res.ipStatus == ppresIPV4)
261     {
262         res.ipHeaderSize = ipHeaderSize;
263         res.xxpFull = len >= fullLength ? 1 : 0;
264         // bit "more fragments" or fragment offset mean the packet is fragmented
265         res.IsFragment = (pIpHeader->v4.ip_offset & ~0xC0) != 0;
266         switch (pIpHeader->v4.ip_protocol)
267         {
268             case PROTOCOL_TCP:
269             {
270                 res = ProcessTCPHeader(res, pIpHeader, len, ipHeaderSize);
271             }
272             break;
273         case PROTOCOL_UDP:
274             {
275                 res = ProcessUDPHeader(res, pIpHeader, len, ipHeaderSize);
276             }
277             break;
278         default:
279             res.xxpStatus = ppresXxpOther;
280             break;
281         }
282     }
283     return res;
284 }
285 
GetXxpHeaderAndPayloadLen(IPHeader * pIpHeader,tTcpIpPacketParsingResult res)286 static __inline USHORT GetXxpHeaderAndPayloadLen(IPHeader *pIpHeader, tTcpIpPacketParsingResult res)
287 {
288     if (res.ipStatus == ppresIPV4)
289     {
290         USHORT headerLength = IP_HEADER_LENGTH(&pIpHeader->v4);
291         USHORT len = swap_short(pIpHeader->v4.ip_length);
292         return len - headerLength;
293     }
294     if (res.ipStatus == ppresIPV6)
295     {
296         USHORT fullLength = swap_short(pIpHeader->v6.ip6_payload_len);
297         return fullLength + sizeof(pIpHeader->v6) - (USHORT)res.ipHeaderSize;
298     }
299     return 0;
300 }
301 
/******************************************
    Builds the IPv4 pseudo header from the addresses and
    protocol of pIpHeader and from headerAndPayloadLen
    (stored byte-swapped) and returns its 16-bit sum,
    i.e. the non-complemented result of CheckSumCalculator
*******************************************/
static __inline USHORT CalculateIpv4PseudoHeaderChecksum(IPv4Header *pIpHeader, USHORT headerAndPayloadLen)
{
    tIPv4PseudoHeader pseudo;
    USHORT complemented;

    pseudo.ipph_zero     = 0;
    pseudo.ipph_protocol = pIpHeader->ip_protocol;
    pseudo.ipph_length   = swap_short(headerAndPayloadLen);
    pseudo.ipph_src      = pIpHeader->ip_src;
    pseudo.ipph_dest     = pIpHeader->ip_dest;

    complemented = CheckSumCalculator(0, &pseudo, sizeof(pseudo));
    // undo the final complement applied by the calculator
    return (USHORT)~complemented;
}
314 
315 
/******************************************
    Builds the IPv6 pseudo header from the addresses of
    pIpHeader and from headerAndPayloadLen (stored
    byte-swapped) and returns its 16-bit sum, i.e. the
    non-complemented result of CheckSumCalculator.
    The protocol is taken from ip6_next_header of the fixed
    header. NOTE(review): with extension headers present that
    field does not name the transport protocol - confirm
    how callers handle such packets
*******************************************/
static __inline USHORT CalculateIpv6PseudoHeaderChecksum(IPv6Header *pIpHeader, USHORT headerAndPayloadLen)
{
    tIPv6PseudoHeader pseudo;
    USHORT complemented;
    int i;

    for (i = 0; i < 4; ++i)
    {
        pseudo.ipph_src[i]  = pIpHeader->ip6_src_address[i];
        pseudo.ipph_dest[i] = pIpHeader->ip6_dst_address[i];
    }
    pseudo.ipph_length = swap_short(headerAndPayloadLen);
    pseudo.z1 = 0;
    pseudo.z2 = 0;
    pseudo.z3 = 0;
    pseudo.ipph_protocol = pIpHeader->ip6_next_header;

    complemented = CheckSumCalculator(0, &pseudo, sizeof(pseudo));
    // undo the final complement applied by the calculator
    return (USHORT)~complemented;
}
334 
/******************************************
    Dispatches to the IPv4 or IPv6 pseudo header checksum
    according to res.ipStatus; returns 0 for anything else
*******************************************/
static __inline USHORT CalculateIpPseudoHeaderChecksum(IPHeader *pIpHeader,
                                                       tTcpIpPacketParsingResult res,
                                                       USHORT headerAndPayloadLen)
{
    switch (res.ipStatus)
    {
        case ppresIPV4:
            return CalculateIpv4PseudoHeaderChecksum(&pIpHeader->v4, headerAndPayloadLen);
        case ppresIPV6:
            return CalculateIpv6PseudoHeaderChecksum(&pIpHeader->v6, headerAndPayloadLen);
        default:
            return 0;
    }
}
345 
346 static __inline BOOLEAN
CompareNetCheckSumOnEndSystem(USHORT computedChecksum,USHORT arrivedChecksum)347 CompareNetCheckSumOnEndSystem(USHORT computedChecksum, USHORT arrivedChecksum)
348 {
349     //According to RFC 1624 sec. 3
350     //Checksum verification mechanism should treat 0xFFFF
351     //checksum value from received packet as 0x0000
352     if(arrivedChecksum == 0xFFFF)
353         arrivedChecksum = 0;
354 
355     return computedChecksum == arrivedChecksum;
356 }
357 
358 /******************************************
359   Calculates IP header checksum calculator
360   it can be already calculated
361   the header must be complete!
362 *******************************************/
363 static __inline tTcpIpPacketParsingResult
VerifyIpChecksum(IPv4Header * pIpHeader,tTcpIpPacketParsingResult known,BOOLEAN bFix)364 VerifyIpChecksum(
365     IPv4Header *pIpHeader,
366     tTcpIpPacketParsingResult known,
367     BOOLEAN bFix)
368 {
369     tTcpIpPacketParsingResult res = known;
370     if (res.ipCheckSum != ppresIPTooShort)
371     {
372         USHORT saved = pIpHeader->ip_xsum;
373         CalculateIpChecksum(pIpHeader);
374         res.ipCheckSum = CompareNetCheckSumOnEndSystem(pIpHeader->ip_xsum, saved) ? ppresCSOK : ppresCSBad;
375         if (!bFix)
376             pIpHeader->ip_xsum = saved;
377         else
378             res.fixedIpCS = res.ipCheckSum == ppresCSBad;
379     }
380     return res;
381 }
382 
383 /*********************************************
384 Calculates UDP checksum, assuming the checksum field
385 is initialized with pseudoheader checksum
386 **********************************************/
CalculateUdpChecksumGivenPseudoCS(UDPHeader * pUdpHeader,ULONG udpLength)387 static VOID CalculateUdpChecksumGivenPseudoCS(UDPHeader *pUdpHeader, ULONG udpLength)
388 {
389     pUdpHeader->udp_xsum = CheckSumCalculator(0, pUdpHeader, udpLength);
390 }
391 
392 /*********************************************
393 Calculates TCP checksum, assuming the checksum field
394 is initialized with pseudoheader checksum
395 **********************************************/
CalculateTcpChecksumGivenPseudoCS(TCPHeader * pTcpHeader,ULONG tcpLength)396 static __inline VOID CalculateTcpChecksumGivenPseudoCS(TCPHeader *pTcpHeader, ULONG tcpLength)
397 {
398     pTcpHeader->tcp_xsum = CheckSumCalculator(0, pTcpHeader, tcpLength);
399 }
400 
401 /************************************************
402 Checks (and fix if required) the TCP checksum
403 sets flags in result structure according to verification
404 TcpPseudoOK if valid pseudo CS was found
405 TcpOK if valid TCP checksum was found
406 ************************************************/
407 static __inline tTcpIpPacketParsingResult
VerifyTcpChecksum(IPHeader * pIpHeader,ULONG len,tTcpIpPacketParsingResult known,ULONG whatToFix)408 VerifyTcpChecksum( IPHeader *pIpHeader, ULONG len, tTcpIpPacketParsingResult known, ULONG whatToFix)
409 {
410     USHORT  phcs;
411     tTcpIpPacketParsingResult res = known;
412     TCPHeader *pTcpHeader = (TCPHeader *)RtlOffsetToPointer(pIpHeader, res.ipHeaderSize);
413     USHORT saved = pTcpHeader->tcp_xsum;
414     USHORT xxpHeaderAndPayloadLen = GetXxpHeaderAndPayloadLen(pIpHeader, res);
415     if (len >= res.ipHeaderSize)
416     {
417         phcs = CalculateIpPseudoHeaderChecksum(pIpHeader, res, xxpHeaderAndPayloadLen);
418         res.xxpCheckSum = CompareNetCheckSumOnEndSystem(phcs, saved) ?  ppresPCSOK : ppresCSBad;
419         if (res.xxpCheckSum != ppresPCSOK || whatToFix)
420         {
421             if (whatToFix & pcrFixPHChecksum)
422             {
423                 if (len >= (ULONG)(res.ipHeaderSize + sizeof(*pTcpHeader)))
424                 {
425                     pTcpHeader->tcp_xsum = phcs;
426                     res.fixedXxpCS = res.xxpCheckSum != ppresPCSOK;
427                 }
428                 else
429                     res.xxpStatus = ppresXxpIncomplete;
430             }
431             else if (res.xxpFull)
432             {
433                 //USHORT ipFullLength = swap_short(pIpHeader->v4.ip_length);
434                 pTcpHeader->tcp_xsum = phcs;
435                 CalculateTcpChecksumGivenPseudoCS(pTcpHeader, xxpHeaderAndPayloadLen);
436                 if (CompareNetCheckSumOnEndSystem(pTcpHeader->tcp_xsum, saved))
437                     res.xxpCheckSum = ppresCSOK;
438 
439                 if (!(whatToFix & pcrFixXxpChecksum))
440                     pTcpHeader->tcp_xsum = saved;
441                 else
442                     res.fixedXxpCS =
443                         res.xxpCheckSum == ppresCSBad || res.xxpCheckSum == ppresPCSOK;
444             }
445             else if (whatToFix)
446             {
447                 res.xxpStatus = ppresXxpIncomplete;
448             }
449         }
450         else if (res.xxpFull)
451         {
452             // we have correct PHCS and we do not need to fix anything
453             // there is a very small chance that it is also good TCP CS
454             // in such rare case we give a priority to TCP CS
455             CalculateTcpChecksumGivenPseudoCS(pTcpHeader, xxpHeaderAndPayloadLen);
456             if (CompareNetCheckSumOnEndSystem(pTcpHeader->tcp_xsum, saved))
457                 res.xxpCheckSum = ppresCSOK;
458             pTcpHeader->tcp_xsum = saved;
459         }
460     }
461     else
462         res.ipCheckSum = ppresIPTooShort;
463     return res;
464 }
465 
466 /************************************************
467 Checks (and fix if required) the UDP checksum
468 sets flags in result structure according to verification
469 UdpPseudoOK if valid pseudo CS was found
470 UdpOK if valid UDP checksum was found
471 ************************************************/
472 static __inline tTcpIpPacketParsingResult
VerifyUdpChecksum(IPHeader * pIpHeader,ULONG len,tTcpIpPacketParsingResult known,ULONG whatToFix)473 VerifyUdpChecksum( IPHeader *pIpHeader, ULONG len, tTcpIpPacketParsingResult known, ULONG whatToFix)
474 {
475     USHORT  phcs;
476     tTcpIpPacketParsingResult res = known;
477     UDPHeader *pUdpHeader = (UDPHeader *)RtlOffsetToPointer(pIpHeader, res.ipHeaderSize);
478     USHORT saved = pUdpHeader->udp_xsum;
479     USHORT xxpHeaderAndPayloadLen = GetXxpHeaderAndPayloadLen(pIpHeader, res);
480     if (len >= res.ipHeaderSize)
481     {
482         phcs = CalculateIpPseudoHeaderChecksum(pIpHeader, res, xxpHeaderAndPayloadLen);
483         res.xxpCheckSum = CompareNetCheckSumOnEndSystem(phcs, saved) ?  ppresPCSOK : ppresCSBad;
484         if (whatToFix & pcrFixPHChecksum)
485         {
486             if (len >= (ULONG)(res.ipHeaderSize + sizeof(UDPHeader)))
487             {
488                 pUdpHeader->udp_xsum = phcs;
489                 res.fixedXxpCS = res.xxpCheckSum != ppresPCSOK;
490             }
491             else
492                 res.xxpStatus = ppresXxpIncomplete;
493         }
494         else if (res.xxpCheckSum != ppresPCSOK || (whatToFix & pcrFixXxpChecksum))
495         {
496             if (res.xxpFull)
497             {
498                 pUdpHeader->udp_xsum = phcs;
499                 CalculateUdpChecksumGivenPseudoCS(pUdpHeader, xxpHeaderAndPayloadLen);
500                 if (CompareNetCheckSumOnEndSystem(pUdpHeader->udp_xsum, saved))
501                     res.xxpCheckSum = ppresCSOK;
502 
503                 if (!(whatToFix & pcrFixXxpChecksum))
504                     pUdpHeader->udp_xsum = saved;
505                 else
506                     res.fixedXxpCS =
507                         res.xxpCheckSum == ppresCSBad || res.xxpCheckSum == ppresPCSOK;
508             }
509             else
510                 res.xxpCheckSum = ppresXxpIncomplete;
511         }
512         else if (res.xxpFull)
513         {
514             // we have correct PHCS and we do not need to fix anything
515             // there is a very small chance that it is also good UDP CS
516             // in such rare case we give a priority to UDP CS
517             CalculateUdpChecksumGivenPseudoCS(pUdpHeader, xxpHeaderAndPayloadLen);
518             if (CompareNetCheckSumOnEndSystem(pUdpHeader->udp_xsum, saved))
519                 res.xxpCheckSum = ppresCSOK;
520             pUdpHeader->udp_xsum = saved;
521         }
522     }
523     else
524         res.ipCheckSum = ppresIPTooShort;
525 
526     return res;
527 }
528 
GetPacketCase(tTcpIpPacketParsingResult res)529 static LPCSTR __inline GetPacketCase(tTcpIpPacketParsingResult res)
530 {
531     static const char *const IPCaseName[4] = { "not tested", "Non-IP", "IPv4", "IPv6" };
532     if (res.xxpStatus == ppresXxpKnown) return res.TcpUdp == ppresIsTCP ?
533         (res.ipStatus == ppresIPV4 ? "TCPv4" : "TCPv6") :
534         (res.ipStatus == ppresIPV4 ? "UDPv4" : "UDPv6");
535     if (res.xxpStatus == ppresXxpIncomplete) return res.TcpUdp == ppresIsTCP ? "Incomplete TCP" : "Incomplete UDP";
536     if (res.xxpStatus == ppresXxpOther) return "IP";
537     return  IPCaseName[res.ipStatus];
538 }
539 
GetIPCSCase(tTcpIpPacketParsingResult res)540 static LPCSTR __inline GetIPCSCase(tTcpIpPacketParsingResult res)
541 {
542     static const char *const CSCaseName[4] = { "not tested", "(too short)", "OK", "Bad" };
543     return CSCaseName[res.ipCheckSum];
544 }
545 
GetXxpCSCase(tTcpIpPacketParsingResult res)546 static LPCSTR __inline GetXxpCSCase(tTcpIpPacketParsingResult res)
547 {
548     static const char *const CSCaseName[4] = { "-", "PCS", "CS", "Bad" };
549     return CSCaseName[res.xxpCheckSum];
550 }
551 
PrintOutParsingResult(tTcpIpPacketParsingResult res,int level,LPCSTR procname)552 static __inline VOID PrintOutParsingResult(
553     tTcpIpPacketParsingResult res,
554     int level,
555     LPCSTR procname)
556 {
557     DPrintf(level, ("[%s] %s packet IPCS %s%s, checksum %s%s", procname,
558         GetPacketCase(res),
559         GetIPCSCase(res),
560         res.fixedIpCS ? "(fixed)" : "",
561         GetXxpCSCase(res),
562         res.fixedXxpCS ? "(fixed)" : ""));
563 }
564 
ParaNdis_CheckSumVerify(PVOID buffer,ULONG size,ULONG flags,LPCSTR caller)565 tTcpIpPacketParsingResult ParaNdis_CheckSumVerify(PVOID buffer, ULONG size, ULONG flags, LPCSTR caller)
566 {
567     tTcpIpPacketParsingResult res = QualifyIpPacket(buffer, size);
568     if (res.ipStatus == ppresIPV4)
569     {
570         if (flags & pcrIpChecksum)
571             res = VerifyIpChecksum(buffer, res, (flags & pcrFixIPChecksum) != 0);
572         if(res.xxpStatus == ppresXxpKnown)
573         {
574             if (res.TcpUdp == ppresIsTCP) /* TCP */
575             {
576                 if(flags & pcrTcpV4Checksum)
577                 {
578                     res = VerifyTcpChecksum(buffer, size, res, flags & (pcrFixPHChecksum | pcrFixTcpV4Checksum));
579                 }
580             }
581             else /* UDP */
582             {
583                 if (flags & pcrUdpV4Checksum)
584                 {
585                     res = VerifyUdpChecksum(buffer, size, res, flags & (pcrFixPHChecksum | pcrFixUdpV4Checksum));
586                 }
587             }
588         }
589     }
590     else if (res.ipStatus == ppresIPV6)
591     {
592         if(res.xxpStatus == ppresXxpKnown)
593         {
594             if (res.TcpUdp == ppresIsTCP) /* TCP */
595             {
596                 if(flags & pcrTcpV6Checksum)
597                 {
598                     res = VerifyTcpChecksum(buffer, size, res, flags & (pcrFixPHChecksum | pcrFixTcpV6Checksum));
599                 }
600             }
601             else /* UDP */
602             {
603                 if (flags & pcrUdpV6Checksum)
604                 {
605                     res = VerifyUdpChecksum(buffer, size, res, flags & (pcrFixPHChecksum | pcrFixUdpV6Checksum));
606                 }
607             }
608         }
609     }
610     PrintOutParsingResult(res, 1, caller);
611     return res;
612 }
613 
ParaNdis_ReviewIPPacket(PVOID buffer,ULONG size,LPCSTR caller)614 tTcpIpPacketParsingResult ParaNdis_ReviewIPPacket(PVOID buffer, ULONG size, LPCSTR caller)
615 {
616     tTcpIpPacketParsingResult res = QualifyIpPacket(buffer, size);
617     PrintOutParsingResult(res, 1, caller);
618     return res;
619 }
620