1// Copyright 2012 Google, Inc. All rights reserved.
2// Copyright 2009-2011 Andreas Krennmair. All rights reserved.
3//
4// Use of this source code is governed by a BSD-style license
5// that can be found in the LICENSE file in the root of the source
6// tree.
7
8package layers
9
10import (
11	"errors"
12	"fmt"
13
14	"github.com/google/gopacket"
15)
16
// tcpipchecksum holds the state needed to compute TCP/UDP checksums: the
// network layer whose pseudo-header contributes to the sum. The field is
// assigned by SetNetworkLayerForChecksum and read by computeChecksum.
type tcpipchecksum struct {
	// pseudoheader supplies the pseudo-header partial sum. computeChecksum
	// returns an error while this is nil.
	pseudoheader tcpipPseudoHeader
}
21
// tcpipPseudoHeader is implemented by the network layers (*IPv4 and *IPv6 in
// this file) that can contribute their addresses to a TCP/UDP checksum.
type tcpipPseudoHeader interface {
	// pseudoheaderChecksum returns the unfolded 32-bit sum of the layer's
	// source and destination addresses, taken as big-endian 16-bit words, or
	// an error if that sum cannot be produced.
	pseudoheaderChecksum() (uint32, error)
}
25
26func (ip *IPv4) pseudoheaderChecksum() (csum uint32, err error) {
27	if err := ip.AddressTo4(); err != nil {
28		return 0, err
29	}
30	csum += (uint32(ip.SrcIP[0]) + uint32(ip.SrcIP[2])) << 8
31	csum += uint32(ip.SrcIP[1]) + uint32(ip.SrcIP[3])
32	csum += (uint32(ip.DstIP[0]) + uint32(ip.DstIP[2])) << 8
33	csum += uint32(ip.DstIP[1]) + uint32(ip.DstIP[3])
34	return csum, nil
35}
36
37func (ip *IPv6) pseudoheaderChecksum() (csum uint32, err error) {
38	if err := ip.AddressTo16(); err != nil {
39		return 0, err
40	}
41	for i := 0; i < 16; i += 2 {
42		csum += uint32(ip.SrcIP[i]) << 8
43		csum += uint32(ip.SrcIP[i+1])
44		csum += uint32(ip.DstIP[i]) << 8
45		csum += uint32(ip.DstIP[i+1])
46	}
47	return csum, nil
48}
49
// tcpipChecksum calculates the TCP/IP checksum defined in RFC 1071 over data.
// The passed-in csum is any partial sum that has already been computed (for
// example a pseudo-header sum); it is added in before folding.
//
// data is summed as big-endian 16-bit words. If data has an odd length, the
// final byte is used as the high byte of one last word whose low byte is
// zero. The return value is the one's complement of the folded 16-bit sum,
// so empty data with a zero csum yields 0xffff.
//
// The running sum is kept in 64 bits. A 32-bit accumulator wraps after about
// 65537 words of 0xffff (roughly 128 KiB of data), or sooner when csum is
// already large, and the wrapped carries would be lost from the result.
func tcpipChecksum(data []byte, csum uint32) uint16 {
	sum := uint64(csum)

	// Sum all complete 16-bit words. Combining the bytes by hand was
	// measured to be faster than calling binary.BigEndian.Uint16.
	even := len(data) &^ 1
	for i := 0; i < even; i += 2 {
		sum += uint64(data[i])<<8 | uint64(data[i+1])
	}
	// A trailing odd byte is the high half of a zero-padded final word.
	if even < len(data) {
		sum += uint64(data[even]) << 8
	}

	// Fold the carries back into the low 16 bits (end-around carry).
	for sum > 0xffff {
		sum = (sum >> 16) + (sum & 0xffff)
	}
	return ^uint16(sum)
}
71
72// computeChecksum computes a TCP or UDP checksum.  headerAndPayload is the
73// serialized TCP or UDP header plus its payload, with the checksum zero'd
74// out. headerProtocol is the IP protocol number of the upper-layer header.
75func (c *tcpipchecksum) computeChecksum(headerAndPayload []byte, headerProtocol IPProtocol) (uint16, error) {
76	if c.pseudoheader == nil {
77		return 0, errors.New("TCP/IP layer 4 checksum cannot be computed without network layer... call SetNetworkLayerForChecksum to set which layer to use")
78	}
79	length := uint32(len(headerAndPayload))
80	csum, err := c.pseudoheader.pseudoheaderChecksum()
81	if err != nil {
82		return 0, err
83	}
84	csum += uint32(headerProtocol)
85	csum += length & 0xffff
86	csum += length >> 16
87	return tcpipChecksum(headerAndPayload, csum), nil
88}
89
90// SetNetworkLayerForChecksum tells this layer which network layer is wrapping it.
91// This is needed for computing the checksum when serializing, since TCP/IP transport
92// layer checksums depends on fields in the IPv4 or IPv6 layer that contains it.
93// The passed in layer must be an *IPv4 or *IPv6.
94func (i *tcpipchecksum) SetNetworkLayerForChecksum(l gopacket.NetworkLayer) error {
95	switch v := l.(type) {
96	case *IPv4:
97		i.pseudoheader = v
98	case *IPv6:
99		i.pseudoheader = v
100	default:
101		return fmt.Errorf("cannot use layer type %v for tcp checksum network layer", l.LayerType())
102	}
103	return nil
104}
105