1 #include <sys/types.h>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <stdint.h>
5 #include <assert.h>
6 
7 #include "inc.h"
8 #include "tcpcrypt_ctl.h"
9 #include "tcpcrypt.h"
10 #include "tcpcryptd.h"
11 #include "checksum.h"
12 #include "config.h"
13 
/*
 * Short integer aliases used by the csum_* helpers in this file.
 * The widths follow the platform's char/short/int; the checksum code
 * presumably expects int to be 32 bits wide -- TODO confirm for new ports.
 */
typedef signed char    __s8;
typedef unsigned char  __u8;

typedef signed short   __s16;
typedef unsigned short __u16;

typedef signed int     __s32;
typedef unsigned int   __u32;

/* 16-bit folded checksum and 32-bit running (unfolded) sum. */
typedef __u16 __sum16;
typedef __u32 __wsum;

typedef __u32 u32;
typedef u32   __be32;

/* Annotation used in casts below; it expands to nothing. */
#define __force

/* Defined outside this file unless NO_ASM is set (see the stub below). */
extern unsigned int csum_partial(const unsigned char *buff, int len,
				 unsigned int sum);
33 
/*
 * Backend selection.  _use_linux is read by checksum_tcp(), checksum_ip()
 * and checksum() to choose between the csum_partial()-based routines and
 * the generic in_cksum()-based ones.
 */
#ifndef NO_ASM
static int _use_linux = 1;
#else /* NO_ASM */
static int _use_linux = 0;

/*
 * Placeholder so the file still links when no real csum_partial() exists.
 * With _use_linux == 0 the dispatchers never select the csum_partial()
 * path, so reaching this function is a fatal error.
 */
unsigned int csum_partial(const unsigned char *buff, int len, unsigned int sum)
{
	(void) buff;
	(void) len;
	(void) sum;

	abort();
}
#endif /* NO_ASM */
44 
/*
 * Pseudo-header that in_cksum() sums ahead of the TCP segment.
 * The bytes of this struct are checksummed exactly as laid out in memory,
 * so the field order and sizes must not change.
 */
struct tcp_ph {
        struct in_addr  ph_src;   /* copied from ip_src */
        struct in_addr  ph_dst;   /* copied from ip_dst */
        uint8_t         ph_zero;  /* set to 0 by the caller */
        uint8_t         ph_proto; /* copied from ip_p */
        uint16_t        ph_len;   /* segment length, stored via htons() */
};

/*
 * Build a 16-bit word from the bytes at p, in memory order.  Copying byte
 * by byte keeps the access valid for any alignment of p.  When only one
 * byte is available the second byte of the word is left as zero.
 */
static uint16_t in_cksum_load(const unsigned char *p, int avail)
{
	uint16_t word = 0;
	unsigned char *dst = (unsigned char *) &word;

	dst[0] = p[0];
	if (avail > 1)
		dst[1] = p[1];

	return word;
}

/*
 * One's-complement checksum over an optional pseudo-header plus a buffer.
 *
 * ph      pseudo-header summed first, or NULL to skip it
 * ptr     start of the data; any alignment is accepted
 * nbytes  number of data bytes; a trailing odd byte is padded with zero
 * s       initial partial sum, interpreted as an unsigned 32-bit value
 *
 * Returns the complemented 16-bit sum in the same in-memory word order as
 * the data, so it can be stored straight into a header field.
 */
static unsigned short in_cksum(struct tcp_ph *ph, unsigned short *ptr,
                               int nbytes, int s)
{
	const unsigned char *p;
	uint64_t sum = (uint32_t) s; /* unsigned: a seed with bit 31 set must not go negative */
	size_t i;

	if (ph) {
		p = (const unsigned char *) ph;
		for (i = 0; i < sizeof(*ph) / 2; i++, p += 2)
			sum += in_cksum_load(p, 2);
	}

	p = (const unsigned char *) ptr;
	for (; nbytes > 1; nbytes -= 2, p += 2)
		sum += in_cksum_load(p, 2);

	if (nbytes == 1)
		sum += in_cksum_load(p, 1);

	/* Fold the carries back in until only 16 bits remain. */
	while (sum >> 16)
		sum = (sum & 0xffff) + (sum >> 16);

	return (unsigned short) (~sum & 0xffff);
}
89 
checksum_ip_generic(struct ip * ip)90 static void checksum_ip_generic(struct ip *ip)
91 {
92         ip->ip_sum = 0;
93         ip->ip_sum = in_cksum(NULL, (unsigned short*) ip, sizeof(*ip), 0);
94 }
95 
checksum_tcp_generic(struct ip * ip,struct tcphdr * tcp,int sum)96 static void checksum_tcp_generic(struct ip *ip, struct tcphdr *tcp, int sum)
97 {
98         struct tcp_ph ph;
99 	int len;
100 
101 	len = ntohs(ip->ip_len) - (ip->ip_hl << 2);
102 
103         ph.ph_src   = ip->ip_src;
104         ph.ph_dst   = ip->ip_dst;
105         ph.ph_zero  = 0;
106         ph.ph_proto = ip->ip_p;
107         ph.ph_len   = htons(len);
108 
109 	if (sum != 0)
110 		len = tcp->th_off << 2;
111 
112         tcp->th_sum = 0;
113         tcp->th_sum = in_cksum(&ph, (unsigned short*) tcp, len, sum);
114 }
115 
csum_fold(__wsum sum)116 static inline __sum16 csum_fold(__wsum sum)
117 {
118         asm("addl %1, %0                ;\n"
119             "adcl $0xffff, %0   ;\n"
120             : "=r" (sum)
121             : "r" ((__force u32)sum << 16),
122               "0" ((__force u32)sum & 0xffff0000));
123         return (__force __sum16)(~(__force u32)sum >> 16);
124 }
125 
/*
 * Add the TCP/UDP pseudo-header fields to a running 32-bit sum.
 *
 * Accumulates sum + daddr + saddr + ((len + proto) << 8) with end-around
 * carry and returns the unfolded 32-bit result.  Written in plain C so the
 * file builds on any target, including NO_ASM builds; the result matches
 * the former x86 addl/adcl chain.
 *
 * NOTE(review): the "<< 8" is carried over unchanged from the x86 version
 * and looks little-endian specific -- confirm before relying on this path
 * on a big-endian host.
 */
static inline __wsum csum_tcpudp_nofold(__be32 saddr, __be32 daddr,
                                        unsigned short len,
                                        unsigned short proto,
                                        __wsum sum)
{
	uint64_t acc = (__force u32) sum;

	acc += (__force u32) daddr;
	acc += (__force u32) saddr;
	acc += (u32) ((len + proto) << 8);

	/* Fold the carries above bit 31 back in; two passes always suffice. */
	acc = (acc & 0xffffffffu) + (acc >> 32);
	acc = (acc & 0xffffffffu) + (acc >> 32);

	return (__force __wsum) acc;
}
140 
141 /*
142  * computes the checksum of the TCP/UDP pseudo-header
143  * returns a 16-bit checksum, already complemented
144  */
csum_tcpudp_magic(__be32 saddr,__be32 daddr,unsigned short len,unsigned short proto,__wsum sum)145 static inline __sum16 csum_tcpudp_magic(__be32 saddr, __be32 daddr,
146                                         unsigned short len,
147                                         unsigned short proto,
148                                         __wsum sum)
149 {
150         return csum_fold(csum_tcpudp_nofold(saddr, daddr, len, proto, sum));
151 }
152 
checksum_tcp_linux(struct tc * tc,struct ip * ip,struct tcphdr * tcp)153 static void checksum_tcp_linux(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
154 {
155 	int len = ntohs(ip->ip_len) - (ip->ip_hl << 2);
156 	int p;
157 	int sum = tc->tc_csum;
158 
159 	tcp->th_sum = 0;
160 
161 	if (sum) {
162   		sum  = (sum >> 16) + (sum & 0xffff);
163   		sum += (sum >> 16);
164 		sum &= 0xffff;
165 
166 		p = csum_partial((unsigned char*) tcp, tcp->th_off << 2, sum);
167 	} else
168 		p = csum_partial((unsigned char*) tcp, len, 0);
169 
170 	tcp->th_sum = csum_tcpudp_magic(ip->ip_src.s_addr,
171 					ip->ip_dst.s_addr,
172 					len,
173 					IPPROTO_TCP,
174 					p);
175 }
176 
checksum_tcp(struct tc * tc,struct ip * ip,struct tcphdr * tcp)177 void checksum_tcp(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
178 {
179 	if (tc && _use_linux)
180 		checksum_tcp_linux(tc, ip, tcp);
181 	else
182 		checksum_tcp_generic(ip, tcp, 0);
183 }
184 
ip_compute_csum(const void * buff,int len)185 static inline __sum16 ip_compute_csum(const void *buff, int len)
186 {
187     return csum_fold(csum_partial(buff, len, 0));
188 }
189 
checksum_ip_linux(struct ip * ip)190 static void checksum_ip_linux(struct ip *ip)
191 {
192 	ip->ip_sum = 0;
193 	ip->ip_sum = ip_compute_csum(ip, ip->ip_hl << 2);
194 }
195 
checksum_ip(struct ip * ip)196 void checksum_ip(struct ip *ip)
197 {
198 	if (_use_linux)
199 		checksum_ip_linux(ip);
200 	else
201 		checksum_ip_generic(ip);
202 }
203 
checksum(void * data,int len)204 uint16_t checksum(void *data, int len)
205 {
206 	if (_use_linux)
207 		return ip_compute_csum(data, len);
208 	else
209 		return in_cksum(NULL, data, len, 0);
210 }
211