TCP校验和的原理和实现

来源:互联网 发布:网络零售经纪模式 编辑:程序博客网 时间:2024/04/30 21:40

概述


TCP校验和是一个端到端的校验和,由发送端计算,然后由接收端验证。其目的是为了发现TCP首部和数据在发送端到接收端之间发生的任何改动。如果接收方检测到校验和有差错,则TCP段会被直接丢弃。

TCP校验和覆盖TCP首部和TCP数据,而IP首部中的校验和只覆盖IP的首部,不覆盖IP数据报中的任何数据。

TCP的校验和是必需的,而UDP的校验和是可选的。

TCP和UDP计算校验和时,都要加上一个12字节的伪首部。

 

Author : zhangskd @ csdn blog

 

伪首部

伪首部共有12字节,包含如下信息:源IP地址、目的IP地址、保留字节(置0)、传输层协议号(TCP是6)、TCP报文长度(报头+数据)。

伪首部是为了增加TCP校验和的检错能力:如检查TCP报文是否收错了(目的IP地址)、传输层协议是否选对了(传输层协议号)等。

 

定义

 

(1) RFC 793的TCP校验和定义

The checksum field is the 16 bit one's complement of the one's complement sum of all 16-bit words in the header and text.

If a segment contains an odd number of header and text octets to be checksummed, the last octet is padded on the right

with zeros to form a 16-bit word for checksum purposes. The pad is not transmitted as part of the segment. While computing

the checksum, the checksum field itself is replaced with zeros.

 (原码true code 补码complemental code 反码ones-complement code

上述的定义说得很明确:

首先,把伪首部、TCP报头、TCP数据分为16位的字,如果总长度为奇数个字节,则在最后增添一个位都为0的字节。

            把TCP报头中的校验和字段置为0(否则就陷入鸡生蛋还是蛋生鸡的问题)。

其次,用反码相加法累加所有的16位字(进位也要累加)。

最后,对计算结果取反,作为TCP的校验和。

 

(2) RFC 1071的IP校验和定义

1. Adjacent octets to be checksummed are paired to form 16-bit integers, and the 1's complement sum of these

    16-bit integers is formed.

2. To generate a checksum, the checksum field itself is cleared, the 16-bit 1's complement sum is computed over

    the octets concerned, and the 1's complement of this sum is placed in the checksum field.

3. To check a checksum, the 1's complement sum is computed over the same set of octets, including the checksum

    field. If the result is all 1 bits (-0 in 1's complement arithmetic), the check succeeds.

 

可以看到,TCP校验和、IP校验和的计算方法是基本一致的,除了计算的范围不同。

 

实现

 

基于2.6.18、x86_64。

csum_tcpudp_nofold()按4字节累加伪首部到sum中。

[java] view plaincopy
  1. static inline unsigned long csum_tcpudp_nofold (unsigned long saddr, unsigned long daddr,  
  2.                                                 unsigned short len, unsigned short proto,  
  3.                                                 unsigned int sum)  
  4. {  
  5.     asm("addl %1, %0\n"    /* 累加daddr */  
  6.         "adcl %2, %0\n"    /* 累加saddr */  
  7.         "adcl %3, %0\n"    /* 累加len(2字节), proto, 0*/  
  8.         "adcl $0, %0\n"    /*加上进位 */  
  9.         : "=r" (sum)  
  10.         : "g" (daddr), "g" (saddr), "g" ((ntohs(len) << 16) + proto*256), "0" (sum));  
  11.     return sum;  
  12. }   

 

csum_tcpudp_magic()产生最终的校验和。

首先,按4字节累加伪首部到sum中。

其次,累加sum的低16位、sum的高16位,并且对累加的结果取反。

最后,截取sum的高16位,作为校验和。

[java] view plaincopy
  1. static inline unsigned short int csum_tcpudp_magic(unsigned long saddr, unsigned long daddr,  
  2.                                                    unsigned short len, unsigned short proto,  
  3.                                                    unsigned int sum)  
  4. {  
  5.     return csum_fold(csum_tcpudp_nofold(saddr, daddr, len, proto, sum));  
  6. }  
  7.   
  8. static inline unsigned int csum_fold(unsigned int sum)  
  9. {  
  10.     __asm__(  
  11.         "addl %1, %0\n"  
  12.         "adcl 0xffff, %0"  
  13.         : "=r" (sum)  
  14.         : "r" (sum << 16), "0" (sum & 0xffff0000)   
  15.   
  16.         /* 将sum的低16位,作为寄存器1的高16位,寄存器1的低16位补0。 
  17.           * 将sum的高16位,作为寄存器0的高16位,寄存器0的低16位补0。 
  18.           * 这样,addl %1, %0就累加了sum的高16位和低16位。 
  19.           * 
  20.          * 还要考虑进位。如果有进位,adcl 0xfff, %0为:0x1 + 0xffff + %0,寄存器0的高16位加1。 
  21.           * 如果没有进位,adcl 0xffff, %0为:0xffff + %0,对寄存器0的高16位无影响。 
  22.           */  
  23.   
  24.     );  
  25.   
  26.     return (~sum) >> 16/* 对sum取反,返回它的高16位,作为最终的校验和 */  
  27. }  

 

发送校验

 

[java] view plaincopy
  1. #define CHECKSUM_NONE 0 /* 不使用校验和,UDP可选 */  
  2. #define CHECKSUM_HW 1 /* 由硬件计算报头和首部的校验和 */  
  3. #define CHECKSUM_UNNECESSARY 2 /* 表示不需要校验,或者已经成功校验了 */  
  4. #define CHECKSUM_PARTIAL CHECKSUM_HW  
  5. #define CHECKSUM_COMPLETE CHECKSUM_HW  

 

@tcp_transmit_skb()

    icsk->icsk_af_ops->send_check(sk, skb->len, skb); /* 计算校验和 */

 

[java] view plaincopy
  1. void tcp_v4_send_check(struct sock *sk, int len, struct sk_buff *skb)  
  2. {  
  3.     struct inet_sock *inet = inet_sk(sk);  
  4.     struct tcphdr *th = skb->h.th;  
  5.    
  6.     if (skb->ip_summed == CHECKSUM_HW) {  
  7.         /* 只计算伪首部,TCP报头和TCP数据的累加由硬件完成 */  
  8.         th->check = ~tcp_v4_check(th, len, inet->saddr, inet->daddr, 0);  
  9.         skb->csum = offsetof(struct tcphdr, check); /* 校验和值在TCP首部的偏移 */  
  10.   
  11.     } else {  
  12.         /* tcp_v4_check累加伪首部,获取最终的校验和。 
  13.          * csum_partial累加TCP报头。 
  14.          * 那么skb->csum应该是TCP数据部分的累加,这是在从用户空间复制时顺便累加的。 
  15.          */  
  16.         th->check = tcp_v4_check(th, len, inet->saddr, inet->daddr,  
  17.                                  csum_partial((char *)th, th->doff << 2, skb->csum));  
  18.     }  
  19. }  
[java] view plaincopy
  1. unsigned csum_partial(const unsigned char *buff, unsigned len, unsigned sum)  
  2. {  
  3.     return add32_with_carry(do_csum(buff, len), sum);  
  4. }  
  5.   
  6. static inline unsigned add32_with_carry(unsigned a, unsigned b)  
  7. {  
  8.     asm("addl %2, %0\n\t"  
  9.              "adcl $0, %0"  
  10.              : "=r" (a)  
  11.              : "0" (a), "r" (b));  
  12.     return a;  
  13. }   

 

do_csum()用于计算一段内存的校验和,这里用于累加TCP报头。

具体计算时用到一些技巧:

1. 反码累加时,按16位、32位、64位来累加的效果是一样的。

2. 使用内存对齐,减少内存操作的次数。

[java] view plaincopy
  1. static __force_inline unsigned do_csum(const unsigned char *buff, unsigned len)  
  2. {  
  3.     unsigned odd, count;  
  4.     unsigned long result = 0;  
  5.   
  6.     if (unlikely(len == 0))  
  7.         return result;  
  8.   
  9.     /* 使起始地址为XXX0,接下来可按2字节对齐 */  
  10.     odd = 1 & (unsigned long) buff;  
  11.     if (unlikely(odd)) {  
  12.         result = *buff << 8/* 因为机器是小端的 */  
  13.         len--;  
  14.         buff++;  
  15.     }  
  16.     count = len >> 1/* nr of 16-bit words,这里可能余下1字节未算,最后会处理*/  
  17.   
  18.     if (count) {  
  19.         /* 使起始地址为XX00,接下来可按4字节对齐 */  
  20.         if (2 & (unsigned long) buff) {  
  21.             result += *(unsigned short *)buff;  
  22.             count--;  
  23.             len -= 2;  
  24.             buff += 2;  
  25.         }  
  26.         count >>= 1/* nr of 32-bit words,这里可能余下2字节未算,最后会处理 */  
  27.   
  28.         if (count) {  
  29.             unsigned long zero;  
  30.             unsigned count64;  
  31.             /* 使起始地址为X000,接下来可按8字节对齐 */  
  32.             if (4 & (unsigned long)buff) {  
  33.                 result += *(unsigned int *)buff;  
  34.                 count--;  
  35.                 len -= 4;  
  36.                 buff += 4;  
  37.             }  
  38.             count >>= 1/* nr of 64-bit words,这里可能余下4字节未算,最后会处理*/  
  39.   
  40.             /* main loop using 64byte blocks */  
  41.             zero = 0;  
  42.             count64 = count >> 3/* 64字节的块数,这里可能余下56字节未算,最后会处理 */  
  43.             while (count64) { /* 反码累加所有的64字节块 */  
  44.                 asm ("addq 0*8(%[src]), %[res]\n\t"    /* b、w、l、q分别对应8、16、32、64位操作 */  
  45.                           "addq 1*8(%[src]), %[res]\n\t"    /* [src]为指定寄存器的别名,效果应该等同于0、1等 */  
  46.                           "adcq 2*8(%[src]), %[res]\n\t"  
  47.                           "adcq 3*8(%[src]), %[res]\n\t"  
  48.                           "adcq 4*8(%[src]), %[res]\n\t"  
  49.                           "adcq 5*8(%[src]), %[res]\n\t"  
  50.                           "adcq 6*8(%[src]), %[res]\n\t"  
  51.                           "adcq 7*8(%[src]), %[res]\n\t"  
  52.                           "adcq %[zero], %[res]"  
  53.                           : [res] "=r" (result)  
  54.                           : [src] "r" (buff), [zero] "r" (zero), "[res]" (result));  
  55.                 buff += 64;  
  56.                 count64--;  
  57.             }  
  58.   
  59.             /* 从这里开始,反序处理之前可能漏算的字节 */  
  60.   
  61.             /* last upto 7 8byte blocks,前面按8个8字节做计算单位,所以最多可能剩下7个8字节 */  
  62.             count %= 8;  
  63.             while (count) {  
  64.                 asm ("addq %1, %0\n\t"  
  65.                      "adcq %2, %0\n"  
  66.                      : "=r" (result)  
  67.                      : "m" (*(unsigned long *)buff), "r" (zero), "0" (result));  
  68.                 --count;  
  69.                 buff += 8;  
  70.             }  
  71.   
  72.             /* 带进位累加result的高32位和低32位 */  
  73.             result = add32_with_carry(result>>32, result&0xffffffff);  
  74.   
  75.             /* 之前始按8字节对齐,可能有4字节剩下 */  
  76.             if (len & 4) {  
  77.                 result += *(unsigned int *) buff;  
  78.                 buff += 4;  
  79.             }  
  80.         }  
  81.   
  82.        /* 更早前按4字节对齐,可能有2字节剩下 */  
  83.         if (len & 2) {  
  84.             result += *(unsigned short *) buff;  
  85.             buff += 2;  
  86.         }  
  87.     }  
  88.   
  89.     /* 最早之前按2字节对齐,可能有1字节剩下 */  
  90.     if (len & 1)  
  91.         result += *buff;  
  92.   
  93.     /* 再次带进位累加result的高32位和低32位 */  
  94.     result = add32_with_carry(result>>32, result & 0xffffffff);   
  95.   
  96.     /* 这里涉及到一个技巧,用于处理初始地址为奇数的情况 */  
  97.     if (unlikely(odd)) {  
  98.         result = from32to16(result); /* 累加到result的低16位 */  
  99.         /* result为:0 0 a b 
  100.          * 然后交换a和b,result变为:0 0 b a 
  101.          */  
  102.         result = ((result >> 8) & 0xff) | ((result & oxff) << 8);  
  103.     }  
  104.   
  105.     return result; /* 返回result的低32位 */  
  106. }  
[java] view plaincopy
  1. static inline unsigned short from32to16(unsigned a)  
  2. {  
  3.     unsigned short b = a >> 16;  
  4.     asm ("addw %w2, %w0\n\t"  
  5.               "adcw $0, %w0\n"  
  6.               : "=r" (b)  
  7.               : "0" (b), "r" (a));  
  8.     return b;  
  9. }  

 

csum_partial_copy_from_user()用于拷贝用户空间数据到内核空间,同时计算用户数据的校验和,

结果保存到skb->csum中(X86_64)。

[java] view plaincopy
  1. /** 
  2.  * csum_partial_copy_from_user - Copy and checksum from user space. 
  3.  * @src: source address (user space) 
  4.  * @dst: destination address 
  5.  * @len: number of bytes to be copied. 
  6.  * @isum: initial sum that is added into the result (32bit unfolded) 
  7.  * @errp: set to -EFAULT for an bad source address. 
  8.  * 
  9.  * Returns an 32bit unfolded checksum of the buffer. 
  10.  * src and dst are best aligned to 64bits. 
  11.  */  
  12.   
  13. unsigned int csum_partial_copy_from_user(const unsigned char __user *src,  
  14.                                   unsigned char *dst, int len, unsigned int isum, int *errp)  
  15. {  
  16.     might_sleep();  
  17.     *errp = 0;  
  18.   
  19.     if (likely(access_ok(VERIFY_READ, src, len))) {  
  20.   
  21.         /* Why 6, not 7? To handle odd addresses aligned we would need to do considerable 
  22.          * complications to fix the checksum which is defined as an 16bit accumulator. The fix 
  23.          * alignment code is primarily for performance compatibility with 32bit and that will handle 
  24.          * odd addresses slowly too. 
  25.          * 处理X010、X100、X110的起始地址。不处理X001,因为这会使复杂度大增加。 
  26.          */  
  27.         if (unlikely((unsigned long)src & 6)) {  
  28.             while (((unsigned long)src & 6) && len >= 2) {  
  29.                 __u16 val16;  
  30.                 *errp = __get_user(val16, (__u16 __user *)src);  
  31.                 if (*errp)  
  32.                     return isum;  
  33.                 *(__u16 *)dst = val16;  
  34.                 isum = add32_with_carry(isum, val16);  
  35.                 src += 2;  
  36.                 dst += 2;  
  37.                 len -= 2;  
  38.             }  
  39.         }  
  40.   
  41.         /* 计算函数是用纯汇编实现的,应该是因为效率吧 */  
  42.         isum = csum_parial_copy_generic((__force void *)src, dst, len, isum, errp, NULL);  
  43.   
  44.         if (likely(*errp == 0))  
  45.             return isum; /* 成功 */  
  46.     }  
  47.   
  48.     *errp = -EFAULT;  
  49.     memset(dst, 0, len);  
  50.     return isum;  
  51. }  

 

上述的实现比较复杂,来看下最简单的csum_partial_copy_from_user()实现(um)。

[java] view plaincopy
  1. unsigned int csum_partial_copy_from_user(const unsigned char *src,  
  2.                                          unsigned char *dst, int len, int sum,  
  3.                                          int *err_ptr)  
  4. {  
  5.     if (copy_from_user(dst, src, len)) { /* 拷贝用户空间数据到内核空间 */  
  6.         *err_ptr = -EFAULT; /* bad address */  
  7.         return (-1);  
  8.     }  
  9.   
  10.     return csum_partial(dst, len, sum); /* 计算用户数据的校验和,会存到skb->csum中 */  
  11. }  

 

接收校验

 

@tcp_v4_rcv

    /* 检查校验和 */

    if (skb->ip_summed != CHECKSUM_UNNECESSARY && tcp_v4_checksum_init(skb))

        goto bad_packet;   

 

接收校验的第一部分,主要是计算伪首部。

[java] view plaincopy
  1. static int tcp_v4_checksum_init(struct sk_buff *skb)  
  2. {  
  3.     /* 如果TCP报头、TCP数据的反码累加已经由硬件完成 */  
  4.     if (skb->ip_summed == CHECKSUM_HW) {  
  5.   
  6.         /* 现在只需要再累加上伪首部,取反获取最终的校验和。 
  7.          * 校验和为0时,表示TCP数据报正确。 
  8.          */  
  9.         if (! tcp_v4_check(skb->h.th, skb->len, skb->nh.iph->saddr, skb->nh.iph->daddr, skb->csum)) {  
  10.             skb->ip_summed = CHECKSUM_UNNECESSARY;  
  11.             return 0/* 校验成功 */  
  12.   
  13.         } /* 没有else失败退出吗?*/  
  14.     }  
  15.   
  16.     /* 对伪首部进行反码累加,主要用于软件方法 */  
  17.     skb->csum = csum_tcpudp_nofold(skb->nh.iph->saddr, skb->nh.iph->daddr, skb->len, IPPROTO_TCP, 0);  
  18.    
  19.   
  20.     /* 对于长度小于76字节的小包,接着累加TCP报头和报文,完成校验;否则,以后再完成检验。*/  
  21.     if (skb->len <= 76) {  
  22.         return __skb_checksum_complete(skb);  
  23.     }  
  24. }  

 

接收校验的第二部分,计算报头和报文。

tcp_v4_rcv、tcp_v4_do_rcv()

    | --> tcp_checksum_complete()

                | --> __tcp_checksum_complete()

                            | --> __skb_checksum_complete()

 

tcp_rcv_established()

    | --> tcp_checksum_complete_user()

                | --> __tcp_checksum_complete_user()

                            | --> __tcp_checksum_complete()

                                        | --> __skb_checksum_complete()

 

[java] view plaincopy
  1. unsigned int __skb_checksum_complete(struct sk_buff *skb)  
  2. {  
  3.     unsigned int sum;  
  4.   
  5.     sum = (u16) csum_fold(skb_checksum(skb, 0, skb->len, skb->csum));  
  6.   
  7.     if (likely(!sum)) { /* sum为0表示成功了 */  
  8.         /* 硬件检测失败,软件检测成功了,说明硬件检测有误 */  
  9.         if (unlikely(skb->ip_summed == CHECKSUM_HW))  
  10.             netdev_rx_csum_fault(skb->dev);  
  11.         skb->ip_summed = CHECKSUM_UNNECESSARY;  
  12.     }  
  13.     return sum;  
  14. }  

 

计算skb包的校验和时,可以指定相对于skb->data的偏移量offset。

由于skb包可能由分页和分段,所以需要考虑skb->data + offset是位于此skb段的线性区中、

还是此skb的分页中,或者位于其它分段中。这个函数逻辑比较复杂。

[java] view plaincopy
  1. /* Checksum skb data. */  
  2. unsigned int skb_checksum(const struct sk_buff *skb, int offset, int len, unsigned int csum)  
  3. {  
  4.     int start = skb_headlen(skb); /* 线性区域长度 */  
  5.     /* copy > 0,说明offset在线性区域中。 
  6.      * copy < 0,说明offset在此skb的分页数据中,或者在其它分段skb中。 
  7.      */  
  8.     int i, copy = start - offset;  
  9.     int pos = 0/* 表示校验了多少数据 */  
  10.   
  11.     /* Checksum header. */  
  12.     if (copy > 0) { /* 说明offset在本skb的线性区域中 */  
  13.         if (copy > len)  
  14.             copy = len; /* 不能超过指定的校验长度 */  
  15.   
  16.         /* 累加copy长度的线性区校验 */  
  17.         csum = csum_partial(skb->data + offset, copy, csum);  
  18.   
  19.         if ((len -= copy) == 0)  
  20.             return csum;  
  21.   
  22.         offset += copy; /* 接下来从这里继续处理 */  
  23.         pos = copy; /* 已处理数据长 */  
  24.     }  
  25.   
  26.     /* 累加本skb分页数据的校验和 */  
  27.     for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {  
  28.         int end;  
  29.         BUG_TRAP(start <= offset + len);  
  30.       
  31.         end = start + skb_shinfo(skb)->frags[i].size;  
  32.   
  33.         if ((copy = end - offset) > 0) { /* 如果offset位于本页中,或者线性区中 */  
  34.             unsigned int csum2;  
  35.             u8 *vaddr; /* 8位够吗?*/  
  36.             skb_frag_t *frag = &skb_shinfo(skb)->frags[i];  
  37.    
  38.             if (copy > len)  
  39.                 copy = len;  
  40.   
  41.             vaddr = kmap_skb_frag(frag); /* 把物理页映射到内核空间 */  
  42.             csum2 = csum_partial(vaddr + frag->page_offset + offset - start, copy, 0);  
  43.             kunmap_skb_frag(vaddr); /* 解除映射 */  
  44.   
  45.             /* 如果pos为奇数,需要对csum2进行处理。 
  46.              * csum2:a, b, c, d => b, a, d, c 
  47.              */  
  48.             csum = csum_block_add(csum, csum2, pos);  
  49.   
  50.             if (! (len -= copy))  
  51.                 return csum;  
  52.   
  53.             offset += copy;  
  54.             pos += copy;  
  55.         }  
  56.         start = end; /* 接下来从这里处理 */  
  57.     }  
  58.    
  59.     /* 如果此skb是个大包,还有其它分段 */  
  60.     if (skb_shinfo(skb)->frag_list) {  
  61.         struct sk_buff *list = skb_shinfo(skb)->frag_list;  
  62.   
  63.         for (; list; list = list->next) {  
  64.             int end;  
  65.             BUG_TRAP(start <= offset + len);  
  66.    
  67.             end = start + list->len;  
  68.   
  69.             if ((copy = end - offset) > 0) { /* 如果offset位于此skb分段中,或者分页,或者线性区 */  
  70.                 unsigned int csum2;  
  71.                 if (copy > len)  
  72.                     copy = len;  
  73.   
  74.                 csum2 = skb_checksum(list, offset - start, copy, 0); /* 递归调用 */  
  75.                 csum = csum_block_add(csum, csum2, pos);  
  76.                 if ((len -= copy) == 0)  
  77.                     return csum;  
  78.   
  79.                 offset += copy;  
  80.                 pos += copy;  
  81.             }  
  82.             start = end;  
  83.         }  
  84.     }  
  85.   
  86.     BUG_ON(len);  
  87.     return csum;  
  88. }  

http://blog.csdn.net/zhangskd/article/details/11770647


0 0