diff --git a/xdp-server.c b/xdp-server.c index 1e9cf3e22..c5358ed6c 100644 --- a/xdp-server.c +++ b/xdp-server.c @@ -767,24 +767,43 @@ process_packet(struct xdp_server *xdp, uint8_t *pkt, if (!(udp = parse_ipv6(ipv6))) return -3; + uint16_t udp6_old_check = udp->check; + uint16_t udp6_check = calc_csum_udp6(udp, ipv6); + + if (udp6_check != udp6_old_check) + return -4; + dnslen -= (uint32_t) (sizeof(*eth) + sizeof(*ipv6) + sizeof(*udp)); data_before_dnshdr_len += sizeof(*ipv6); if (!dest_ip_allowed6(xdp, ipv6)) - return -4; + return -5; break; } case ETH_P_IP: { ipv4 = (struct iphdr *)(eth + 1); + __sum16 ipv4_old_check = ipv4->check; + ipv4->check = 0; + ipv4->check = ip_fast_csum(ipv4, ipv4->ihl); + + if (ipv4->check != ipv4_old_check) + return -6; + if (!(udp = parse_ipv4(ipv4))) - return -5; + return -7; + + uint16_t udp4_old_check = udp->check; + uint16_t udp4_check = calc_csum_udp4(udp, ipv4); + + if (udp4_check != udp4_old_check) + return -8; dnslen -= (uint32_t) (sizeof(*eth) + sizeof(*ipv4) + sizeof(*udp)); data_before_dnshdr_len += sizeof(*ipv4); if (!dest_ip_allowed4(xdp, ipv4)) - return -6; + return -9; break; } diff --git a/xdp-util.h b/xdp-util.h index 6994e7249..2d833a3bb 100644 --- a/xdp-util.h +++ b/xdp-util.h @@ -124,4 +124,89 @@ static inline uint16_t calc_csum_udp4(struct udphdr *udp, struct iphdr *ipv4) { return (uint16_t) sum; } +/* + * This function code has been taken from + * Linux kernel lib/checksum.c + */ +static inline unsigned short from32to16(unsigned int x) +{ + /* add up 16-bit and 16-bit for 16+c bit */ + x = (x & 0xffff) + (x >> 16); + /* add up carry.. */ + x = (x & 0xffff) + (x >> 16); + return x; +} + +/* + * This function code has been taken from + * Linux kernel lib/checksum.c + */ +static unsigned int do_csum(const unsigned char *buff, int len) +{ + unsigned int result = 0; + int odd; + + if (len <= 0) + goto out; + odd = 1 & (unsigned long)buff; + if (odd) { +#ifdef __LITTLE_ENDIAN + result += (*buff << 8); +#else + result = *buff; +#endif + len--; + buff++; + } + if (len >= 2) { + if (2 & (unsigned long)buff) { + result += *(unsigned short *)buff; + len -= 2; + buff += 2; + } + if (len >= 4) { + const unsigned char *end = buff + + ((unsigned int)len & ~3); + unsigned int carry = 0; + + do { + unsigned int w = *(unsigned int *)buff; + + buff += 4; + result += carry; + result += w; + carry = (w > result); + } while (buff < end); + result += carry; + result = (result & 0xffff) + (result >> 16); + } + if (len & 2) { + result += *(unsigned short *)buff; + buff += 2; + } + } + if (len & 1) +#ifdef __LITTLE_ENDIAN + result += *buff; +#else + result += (*buff << 8); +#endif + result = from32to16(result); + if (odd) + result = ((result >> 8) & 0xff) | ((result & 0xff) << 8); +out: + return result; +} + +/* + * This is a version of ip_compute_csum() optimized for IP headers, + * which always checksum on 4 octet boundaries. + * This function code has been taken from + * Linux kernel lib/checksum.c + */ +static inline __sum16 ip_fast_csum(const void *iph, unsigned int ihl) +{ + return (__sum16)~do_csum(iph, ihl * 4); +} + #endif /* XDP_UTIL_H */