#include <sys/types.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <assert.h>
6
7 #include "inc.h"
8 #include "tcpcrypt_ctl.h"
9 #include "tcpcrypt.h"
10 #include "tcpcryptd.h"
11 #include "checksum.h"
12 #include "config.h"
13
/*
 * Kernel-style integer aliases so the checksum helpers below (csum_fold,
 * csum_tcpudp_nofold, csum_tcpudp_magic, ip_compute_csum) can keep their
 * original kernel signatures. __signed__ is the GCC spelling of 'signed'.
 */
typedef __signed__ char __s8;
typedef unsigned char __u8;

typedef __signed__ short __s16;
typedef unsigned short __u16;

typedef __signed__ int __s32;
typedef unsigned int __u32;

/* __sum16: folded 16-bit checksum (csum_fold's result);
 * __wsum: unfolded 32-bit partial sum (csum_tcpudp_nofold's result) */
typedef __u16 __sum16;
typedef __u32 __wsum;

typedef __u32 u32;
/* 32-bit value kept in network byte order (callers pass ip_src.s_addr) */
typedef u32 __be32;

/* Cast annotation carried over from the kernel-style code; expands to
 * nothing here. */
# define __force

/*
 * Partial checksum of 'len' bytes at 'buff', accumulated onto 'sum'
 * (its result is fed to csum_fold / csum_tcpudp_magic by the callers below).
 * Under NO_ASM it is defined further down as an aborting stub; otherwise it
 * must come from another object file — NOTE(review): presumably an
 * assembly implementation, confirm in the build.
 */
extern unsigned int csum_partial(const unsigned char * buff, int len,
                                 unsigned int sum);
33
/*
 * Backend selection.
 *
 * Default build: csum_partial() is supplied externally and _use_linux = 1
 * routes checksum_tcp(), checksum_ip() and checksum() to the csum_partial()
 * based helpers.
 *
 * NO_ASM build: _use_linux = 0 sends every caller to the portable in_cksum()
 * path, so the csum_partial() defined here only exists to satisfy the linker
 * and aborts if it is ever reached.
 */
#ifndef NO_ASM
static int _use_linux = 1;
#else
static int _use_linux = 0;

unsigned int csum_partial(const unsigned char * buff, int len, unsigned int sum)
{
    /* unreachable while _use_linux == 0 */
    (void) buff;
    (void) len;
    (void) sum;
    abort();
}
#endif /* NO_ASM */
44
/*
 * TCP pseudo-header, summed by in_cksum() ahead of the TCP segment.
 * Its four fields add up to 12 bytes (no padding on common ABIs), which
 * in_cksum() walks as sizeof / 2 sixteen-bit words.
 */
struct tcp_ph {
    struct in_addr ph_src, ph_dst;     /* IP source / destination address */
    uint8_t        ph_zero, ph_proto;  /* zero byte / IP protocol number  */
    uint16_t       ph_len;             /* TCP length, network byte order  */
};
52
53
in_cksum(struct tcp_ph * ph,unsigned short * ptr,int nbytes,int s)54 static unsigned short in_cksum(struct tcp_ph *ph, unsigned short *ptr,
55 int nbytes, int s)
56 {
57 register long sum;
58 u_short oddbyte;
59 register u_short answer;
60
61 sum = s;
62
63 if (ph) {
64 unsigned short *p = (unsigned short*) ph;
65 int i;
66
67 for (i = 0; i < sizeof(*ph) >> 1; i++)
68 sum += *p++;
69 }
70
71 while (nbytes > 1)
72 {
73 sum += *ptr++;
74 nbytes -= 2;
75 }
76
77 if (nbytes == 1)
78 {
79 oddbyte = 0;
80 *((u_char *) & oddbyte) = *(u_char *) ptr;
81 sum += oddbyte;
82 }
83
84 sum = (sum >> 16) + (sum & 0xffff);
85 sum += (sum >> 16);
86 answer = ~sum;
87 return (answer);
88 }
89
checksum_ip_generic(struct ip * ip)90 static void checksum_ip_generic(struct ip *ip)
91 {
92 ip->ip_sum = 0;
93 ip->ip_sum = in_cksum(NULL, (unsigned short*) ip, sizeof(*ip), 0);
94 }
95
checksum_tcp_generic(struct ip * ip,struct tcphdr * tcp,int sum)96 static void checksum_tcp_generic(struct ip *ip, struct tcphdr *tcp, int sum)
97 {
98 struct tcp_ph ph;
99 int len;
100
101 len = ntohs(ip->ip_len) - (ip->ip_hl << 2);
102
103 ph.ph_src = ip->ip_src;
104 ph.ph_dst = ip->ip_dst;
105 ph.ph_zero = 0;
106 ph.ph_proto = ip->ip_p;
107 ph.ph_len = htons(len);
108
109 if (sum != 0)
110 len = tcp->th_off << 2;
111
112 tcp->th_sum = 0;
113 tcp->th_sum = in_cksum(&ph, (unsigned short*) tcp, len, sum);
114 }
115
csum_fold(__wsum sum)116 static inline __sum16 csum_fold(__wsum sum)
117 {
118 asm("addl %1, %0 ;\n"
119 "adcl $0xffff, %0 ;\n"
120 : "=r" (sum)
121 : "r" ((__force u32)sum << 16),
122 "0" ((__force u32)sum & 0xffff0000));
123 return (__force __sum16)(~(__force u32)sum >> 16);
124 }
125
/*
 * Add the TCP/UDP pseudo-header (source and destination address, length,
 * protocol) into the 32-bit partial sum 'sum' with one's-complement
 * (end-around carry) arithmetic. The result is NOT folded to 16 bits.
 *
 * Plain C instead of x86 inline asm so the file also builds on other
 * targets (this function is compiled even under NO_ASM). Summing in 64 bits
 * and folding the carries back in twice yields the same 32-bit value as the
 * former addl/adcl/adcl/adcl chain.
 *
 * (len + proto) << 8 is, after folding, the one's-complement equivalent of
 * htons(len + proto) on a little-endian host; the original x86 code relied
 * on the same trick.
 */
static inline __wsum csum_tcpudp_nofold(__be32 saddr, __be32 daddr,
                                        unsigned short len,
                                        unsigned short proto,
                                        __wsum sum)
{
    uint64_t acc = (uint64_t)(__force u32)sum;

    acc += (__force u32)daddr;
    acc += (__force u32)saddr;
    acc += (u32)(len + proto) << 8;

    /* end-around carry: fold anything above bit 31 back into the low word */
    acc = (acc & 0xffffffffu) + (acc >> 32);
    acc = (acc & 0xffffffffu) + (acc >> 32);
    return (__force __wsum)acc;
}
140
141 /*
142 * computes the checksum of the TCP/UDP pseudo-header
143 * returns a 16-bit checksum, already complemented
144 */
csum_tcpudp_magic(__be32 saddr,__be32 daddr,unsigned short len,unsigned short proto,__wsum sum)145 static inline __sum16 csum_tcpudp_magic(__be32 saddr, __be32 daddr,
146 unsigned short len,
147 unsigned short proto,
148 __wsum sum)
149 {
150 return csum_fold(csum_tcpudp_nofold(saddr, daddr, len, proto, sum));
151 }
152
checksum_tcp_linux(struct tc * tc,struct ip * ip,struct tcphdr * tcp)153 static void checksum_tcp_linux(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
154 {
155 int len = ntohs(ip->ip_len) - (ip->ip_hl << 2);
156 int p;
157 int sum = tc->tc_csum;
158
159 tcp->th_sum = 0;
160
161 if (sum) {
162 sum = (sum >> 16) + (sum & 0xffff);
163 sum += (sum >> 16);
164 sum &= 0xffff;
165
166 p = csum_partial((unsigned char*) tcp, tcp->th_off << 2, sum);
167 } else
168 p = csum_partial((unsigned char*) tcp, len, 0);
169
170 tcp->th_sum = csum_tcpudp_magic(ip->ip_src.s_addr,
171 ip->ip_dst.s_addr,
172 len,
173 IPPROTO_TCP,
174 p);
175 }
176
checksum_tcp(struct tc * tc,struct ip * ip,struct tcphdr * tcp)177 void checksum_tcp(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
178 {
179 if (tc && _use_linux)
180 checksum_tcp_linux(tc, ip, tcp);
181 else
182 checksum_tcp_generic(ip, tcp, 0);
183 }
184
ip_compute_csum(const void * buff,int len)185 static inline __sum16 ip_compute_csum(const void *buff, int len)
186 {
187 return csum_fold(csum_partial(buff, len, 0));
188 }
189
checksum_ip_linux(struct ip * ip)190 static void checksum_ip_linux(struct ip *ip)
191 {
192 ip->ip_sum = 0;
193 ip->ip_sum = ip_compute_csum(ip, ip->ip_hl << 2);
194 }
195
checksum_ip(struct ip * ip)196 void checksum_ip(struct ip *ip)
197 {
198 if (_use_linux)
199 checksum_ip_linux(ip);
200 else
201 checksum_ip_generic(ip);
202 }
203
checksum(void * data,int len)204 uint16_t checksum(void *data, int len)
205 {
206 if (_use_linux)
207 return ip_compute_csum(data, len);
208 else
209 return in_cksum(NULL, data, len, 0);
210 }
211