import sys
import collections
import fcntl
import functools
import os
import select
import socket
import struct
import subprocess
import unittest

# NOTE: this module is written to run on both Python 2.7 and Python 3:
# packet buffers are byte strings (b'' literals) and individual octets are
# read as integers through bytearray().


def parse_payload(proto, data):
    """Parse ``data`` as the upper-layer header identified by ``proto``.

    Args:
        proto: IP protocol number (IPv4 ``proto`` / IPv6 next-header value).
        data: the bytes that follow the IP header.

    Returns:
        An Ipv4Header, Ipv6Header, TcpHeader or UdpHeader instance, or None
        when ``proto`` is none of IPIP, IPv6, TCP or UDP.
    """
    if proto == socket.IPPROTO_IPIP:
        return Ipv4Header.parse(data)
    elif proto == socket.IPPROTO_IPV6:
        return Ipv6Header.parse(data)
    elif proto == socket.IPPROTO_TCP:
        return TcpHeader.parse(data)
    elif proto == socket.IPPROTO_UDP:
        return UdpHeader.parse(data)
    else:
        return None


def internet_csum(data):
    """Return the RFC 1071 Internet checksum of ``data`` as an int.

    16-bit words are assembled in network (big-endian) order, so the result
    is independent of host byte order and can be packed with ``'!H'``.
    An odd trailing byte is treated as if followed by a zero byte.
    """
    def carry_around_add(a, b):
        # One's-complement addition: fold the carry back into the low word.
        c = a + b
        return (c & 0xffff) + (c >> 16)

    octets = bytearray(data)  # integer octets for both py2 str and py3 bytes
    s = 0
    for i in range(0, len(octets), 2):
        w = octets[i] << 8
        if i + 1 < len(octets):
            w |= octets[i + 1]
        s = carry_around_add(s, w)
    return ~s & 0xffff


def send_raw_packet(packet):
    """Send a fully built IP packet through a raw IPPROTO_RAW socket.

    ``packet`` must provide ``address_family`` on its class, a ``dst``
    address string and a ``serialize()`` method (Ipv4Header / Ipv6Header).
    The socket is closed even when serialization or sending fails.
    """
    s = socket.socket(packet.__class__.address_family, socket.SOCK_RAW,
                      socket.IPPROTO_RAW)
    try:
        s.sendto(packet.serialize(), 0, (packet.dst, 0))
    finally:
        s.close()


def parse_packet(s):
    """Parse raw bytes as an IP packet.

    Returns None when ``s`` is None or empty (e.g. ``Tun.read`` timed out).
    A version nibble of 4 selects Ipv4Header; anything else is parsed as
    Ipv6Header.
    """
    if not s:
        return None
    if (bytearray(s[:1])[0] >> 4) == 4:
        return Ipv4Header.parse(s)
    return Ipv6Header.parse(s)


class Tun:
    """A TUN interface (IFF_TUN | IFF_NO_PI) that is brought up on creation."""

    def __init__(self):
        TUNSETIFF = 0x400454ca
        IFF_TUN = 0x0001
        IFF_NO_PI = 0x1000
        self.tun = open('/dev/net/tun', 'rb')
        try:
            # 'tun%d' lets the kernel pick the next free tunN name.
            ifr = struct.pack('16sH', b'tun%d', IFF_TUN | IFF_NO_PI)
            ifr = fcntl.ioctl(self.tun, TUNSETIFF, ifr)
        except BaseException:
            # Do not leak the device descriptor if the request is rejected.
            self.tun.close()
            raise
        name = struct.unpack('16sH', ifr)[0].split(b'\0')[0]
        if not isinstance(name, str):
            name = name.decode('ascii')  # Python 3: ioctl returns bytes
        self.name = name
        subprocess.call(['ip', 'link', 'set', 'dev', self.name, 'up'])

    def add_route(self, rt):
        """Add a route for destination ``rt`` through this device."""
        subprocess.call(['ip', 'route', 'add', 'dev', self.name, rt])

    def add_routes(self, routes):
        """Call add_route() for every entry of ``routes``."""
        for rt in routes:
            self.add_route(rt)

    def read(self, tmo=1):
        """Return one packet (at most 1500 bytes), or None after ``tmo`` s."""
        ready, _, _ = select.select([self.tun], [], [], tmo)
        if not ready:
            return None
        return os.read(self.tun.fileno(), 1500)

    def close(self):
        """Close the device file descriptor."""
        self.tun.close()


class Ipv4Header:
    """IPv4 header plus an attached payload header object.

    On serialize(), ``proto``, ``tot_len``, ``ident`` and ``csum`` that are
    left at 0 are computed and stored back on the instance.
    """

    fmt = '!BBHHHBBH4s4s'
    encaps_proto = socket.IPPROTO_IPIP
    address_family = socket.AF_INET

    def __init__(self, src, dst, proto=0, version=4, ihl=5, tos=0, tot_len=0,
                 ident=0, frag_off=0, ttl=42, csum=0, payload=None):
        self.src, self.dst, self.proto, self.version = src, dst, proto, version
        self.ihl, self.tos, self.tot_len, self.ident = ihl, tos, tot_len, ident
        self.frag_off, self.ttl, self.csum = frag_off, ttl, csum
        self.payload = payload

    @staticmethod
    def parse(data):
        """Parse an IPv4 packet; header options (ihl > 5) are skipped."""
        fields = struct.unpack(Ipv4Header.fmt, data[:20])
        ver_ihl, tos, tot_len, ident, frag_off = fields[:5]
        ttl, proto, csum, src_raw, dst_raw = fields[5:]
        version = ver_ihl >> 4
        ihl = ver_ihl & 0xf
        src = socket.inet_ntop(socket.AF_INET, src_raw)
        dst = socket.inet_ntop(socket.AF_INET, dst_raw)
        # ihl counts 32-bit words; tot_len bounds the payload.
        payload = parse_payload(proto, data[ihl * 4:tot_len])
        return Ipv4Header(src=src, dst=dst, proto=proto, version=version,
                          ihl=ihl, tos=tos, tot_len=tot_len, ident=ident,
                          frag_off=frag_off, ttl=ttl, csum=csum,
                          payload=payload)

    def __pseudo_header_fn(self, length):
        # IPv4 pseudo-header for transport checksums: src, dst, 0, proto, len.
        src_raw = socket.inet_pton(socket.AF_INET, self.src)
        dst_raw = socket.inet_pton(socket.AF_INET, self.dst)
        proto = self.payload.__class__.encaps_proto
        return struct.pack('!4s4sBBH', src_raw, dst_raw, 0, proto, length)

    def serialize(self, ph_fn=None):
        """Return the 20-byte header followed by the serialized payload.

        ``ph_fn`` is ignored; it is accepted so this header can itself be the
        payload of another IP header (IP-in-IP), which passes one.
        """
        ver_ihl = (self.version << 4) | self.ihl
        ps = self.payload.serialize(ph_fn=self.__pseudo_header_fn)
        if self.proto == 0:
            self.proto = self.payload.__class__.encaps_proto
        if self.tot_len == 0:
            self.tot_len = 20 + len(ps)
        if self.ident == 0:
            self.ident = 0xf00f
        src_raw = socket.inet_pton(socket.AF_INET, self.src)
        dst_raw = socket.inet_pton(socket.AF_INET, self.dst)
        if self.csum == 0:
            # The header checksum is computed with the csum field zeroed.
            fields = [ver_ihl, self.tos, self.tot_len, self.ident,
                      self.frag_off, self.ttl, self.proto, 0, src_raw, dst_raw]
            self.csum = internet_csum(struct.pack(self.__class__.fmt, *fields))
        fields = [ver_ihl, self.tos, self.tot_len, self.ident, self.frag_off,
                  self.ttl, self.proto, self.csum, src_raw, dst_raw]
        return struct.pack(self.__class__.fmt, *fields) + ps

    def __repr__(self):
        return ('Ipv4Header(src="{s.src}", dst="{s.dst}", proto="{s.proto}", ' +
                'version={s.version}, ihl={s.ihl}, tos={s.tos}, ' +
                'tot_len={s.tot_len}, ident={s.ident}, ' +
                'frag_off={s.frag_off}, ttl={s.ttl}, csum=0x{s.csum:x}, ' +
                'payload={s.payload})').format(s=self)


class Ipv6Header:
    """IPv6 fixed header (no extension headers) plus a payload header object.

    On serialize(), ``proto`` and ``payload_len`` left at 0 are computed and
    stored back on the instance.
    """

    fmt = '!IHBB16s16s'
    encaps_proto = socket.IPPROTO_IPV6
    address_family = socket.AF_INET6

    def __init__(self, src, dst, proto=0, version=6, tclass=0, flow_label=0,
                 payload_len=0, hop_limit=42, payload=None):
        self.src, self.dst, self.proto, self.version = src, dst, proto, version
        self.tclass, self.flow_label = tclass, flow_label
        self.payload_len, self.hop_limit = payload_len, hop_limit
        self.payload = payload

    @staticmethod
    def parse(data):
        """Parse an IPv6 packet from its 40-byte fixed header onwards."""
        fields = struct.unpack(Ipv6Header.fmt, data[:40])
        fw, payload_len, proto, hop_limit, src_raw, dst_raw = fields
        # First word: 4-bit version, 8-bit traffic class, 20-bit flow label.
        version = (fw >> 28) & 0xf
        tclass = (fw >> 20) & 0xff
        flow_label = fw & 0xfffff
        src = socket.inet_ntop(socket.AF_INET6, src_raw)
        dst = socket.inet_ntop(socket.AF_INET6, dst_raw)
        payload = parse_payload(proto, data[40:payload_len + 40])
        return Ipv6Header(src=src, dst=dst, proto=proto, version=version,
                          tclass=tclass, flow_label=flow_label,
                          payload_len=payload_len, hop_limit=hop_limit,
                          payload=payload)

    def __pseudo_header_fn(self, length):
        # IPv6 pseudo-header: src, dst, 32-bit length, 3 zero bytes, next hdr.
        src_raw = socket.inet_pton(socket.AF_INET6, self.src)
        dst_raw = socket.inet_pton(socket.AF_INET6, self.dst)
        proto = self.payload.__class__.encaps_proto
        return struct.pack('!16s16sIBBBB',
                           src_raw, dst_raw, length, 0, 0, 0, proto)

    def serialize(self, ph_fn=None):
        """Return the 40-byte header followed by the serialized payload.

        ``ph_fn`` is ignored; it is accepted so this header can be nested as
        the payload of another IP header.
        """
        fw = (self.version << 28) | (self.tclass << 20) | self.flow_label
        ps = self.payload.serialize(ph_fn=self.__pseudo_header_fn)
        if self.proto == 0:
            self.proto = self.payload.__class__.encaps_proto
        if self.payload_len == 0:
            self.payload_len = len(ps)
        src_raw = socket.inet_pton(socket.AF_INET6, self.src)
        dst_raw = socket.inet_pton(socket.AF_INET6, self.dst)
        fields = [fw, self.payload_len, self.proto, self.hop_limit,
                  src_raw, dst_raw]
        return struct.pack(self.__class__.fmt, *fields) + ps

    def __repr__(self):
        return ('Ipv6Header(src="{s.src}", dst="{s.dst}", proto="{s.proto}", ' +
                'version={s.version}, tclass={s.tclass}, ' +
                'flow_label={s.flow_label}, payload_len={s.payload_len}, ' +
                'hop_limit={s.hop_limit}, payload={s.payload})').format(s=self)


class TcpHeader:
    """TCP header without options, carrying a raw byte-string payload."""

    fmt = '!HHIIHHHH'
    encaps_proto = socket.IPPROTO_TCP

    def __init__(self, src, dst, seq=0, ack_num=0, data_off=5, ns=0, cwr=0,
                 ece=0, urg=0, ack=0, psh=0, rst=0, syn=0, fin=0,
                 win_size=100, csum=0, urg_ptr=0, payload=b''):
        self.src, self.dst, self.seq, self.ack_num = src, dst, seq, ack_num
        self.data_off, self.ns, self.cwr, self.ece = data_off, ns, cwr, ece
        self.urg, self.ack, self.psh, self.rst = urg, ack, psh, rst
        self.syn, self.fin, self.win_size, self.csum = syn, fin, win_size, csum
        self.urg_ptr, self.payload = urg_ptr, payload

    @staticmethod
    def parse(data):
        """Parse a TCP segment.

        Options are skipped, not preserved.  NOTE(review): the parsed
        data_off is not passed on, so the result keeps the default of 5.
        """
        fields = struct.unpack(TcpHeader.fmt, data[:20])
        src, dst, seq, ack_num, flags, win_size, csum, urg_ptr = fields
        data_off = (flags >> 12) & 0xf
        # Flag bits 8..0, most significant first.
        fbits = [(flags >> i) & 0x1 for i in range(8, -1, -1)]
        ns, cwr, ece, urg, ack, psh, rst, syn, fin = fbits
        payload = data[(data_off * 4):]
        return TcpHeader(src=src, dst=dst, seq=seq, ack_num=ack_num,
                         ns=ns, cwr=cwr, ece=ece, urg=urg, ack=ack, psh=psh,
                         rst=rst, syn=syn, fin=fin, win_size=win_size,
                         csum=csum, urg_ptr=urg_ptr, payload=payload)

    def serialize(self, ph_fn):
        """Return header + payload bytes.

        ``ph_fn(length)`` must return the enclosing IP pseudo-header.  A
        ``csum`` of 0 means "compute it"; the result is stored on self.
        """
        if self.data_off == 0:
            self.data_off = 5
        bits = [self.ns, self.cwr, self.ece, self.urg, self.ack, self.psh,
                self.rst, self.syn, self.fin]
        flags = functools.reduce(lambda acc, bit: (acc << 1) | bit, bits, 0)
        flags |= self.data_off << 12
        if self.csum == 0:
            fields = [self.src, self.dst, self.seq, self.ack_num, flags,
                      self.win_size, 0, self.urg_ptr]
            s = struct.pack(self.__class__.fmt, *fields) + self.payload
            # The length covers the fixed 20-byte header plus the payload.
            self.csum = internet_csum(ph_fn(20 + len(self.payload)) + s)
        fields = [self.src, self.dst, self.seq, self.ack_num, flags,
                  self.win_size, self.csum, self.urg_ptr]
        return struct.pack(self.__class__.fmt, *fields) + self.payload

    def __repr__(self):
        return ('TcpHeader(src={s.src}, dst={s.dst}, seq={s.seq}, ' +
                'ack_num={s.ack_num}, ns={s.ns}, cwr={s.cwr}, ece={s.ece}, ' +
                'urg={s.urg}, ack={s.ack}, psh={s.psh}, rst={s.rst}, ' +
                'syn={s.syn}, fin={s.fin}, win_size={s.win_size}, ' +
                'csum=0x{s.csum:x}, urg_ptr={s.urg_ptr})').format(s=self)


class UdpHeader:
    """UDP header carrying a raw byte-string payload.

    Wire layout (RFC 768): source port, destination port, length, checksum.
    """

    fmt = '!HHHH'
    encaps_proto = socket.IPPROTO_UDP

    def __init__(self, src, dst, csum=0, payload=b''):
        self.src, self.dst, self.csum, self.payload = src, dst, csum, payload

    @staticmethod
    def parse(data):
        """Parse a UDP datagram; the length field bounds the payload."""
        src, dst, length, csum = struct.unpack(UdpHeader.fmt, data[:8])
        payload = data[8:length]
        return UdpHeader(src=src, dst=dst, csum=csum, payload=payload)

    def serialize(self, ph_fn):
        """Return header + payload bytes.

        ``ph_fn(length)`` must return the enclosing IP pseudo-header.  A
        ``csum`` of 0 means "compute it"; the result is stored on self.
        """
        length = 8 + len(self.payload)
        if self.csum == 0:
            s = struct.pack(self.__class__.fmt,
                            self.src, self.dst, length, 0) + self.payload
            csum = internet_csum(ph_fn(length) + s)
            # A computed checksum of 0 is sent as 0xffff (RFC 768), because 0
            # on the wire means "no checksum".
            self.csum = csum or 0xffff
        return struct.pack(self.__class__.fmt,
                           self.src, self.dst, length, self.csum) + self.payload

    def __repr__(self):
        return ('UdpHeader(src={s.src}, dst={s.dst}, csum=0x{s.csum:x}, ' +
                'payload={s.payload})').format(s=self)


RESET_IPVS = """\
ipvsadm --clear &&
lsmod | grep '^ip_vs_' | cut -d' ' -f1 | xargs -r -n1 rmmod &&
rmmod ip_vs;
"""


def add_service(vip, port, real_servers, scheduler='rr', service_type='tcp'):
    """Create an IPVS virtual service and attach real servers via IP-in-IP.

    Args:
        vip: virtual IP; IPv6 literals are bracketed automatically.
        port: virtual service port.
        real_servers: real-server addresses, added in reverse list order.
        scheduler: ipvsadm scheduler name.
        service_type: 'tcp' for a TCP service; anything else means UDP.

    Raises:
        AssertionError: if any ipvsadm invocation exits non-zero.
    """
    if ':' in vip and not vip.startswith('['):
        # Bracket IPv6 literals so the ':port' suffix below is unambiguous.
        vip = '[' + vip + ']'

    def c(*args):
        # Explicit raise rather than `assert` so the check survives `-O`.
        if subprocess.call(args) != 0:
            raise AssertionError('command failed: %s' % ' '.join(args))

    service = '%s:%d' % (vip, port)
    sf = '--tcp-service' if service_type == 'tcp' else '--udp-service'
    ef = '--ipip'
    c('ipvsadm', '-A', sf, service, '--scheduler', scheduler)
    for rs in real_servers[::-1]:
        c('ipvsadm', '-a', sf, service, '--real-server', rs, ef)


def _ipvs_test_env_available():
    """Return True if the IPVS tests below can plausibly run here.

    They open /dev/net/tun, send on raw sockets and run ipvsadm, which
    generally requires root, the TUN device node and ipvsadm on PATH.
    """
    if not (hasattr(os, 'geteuid') and os.geteuid() == 0):
        return False
    if not os.path.exists('/dev/net/tun'):
        return False
    path_dirs = os.environ.get('PATH', '').split(os.pathsep)
    return any(os.access(os.path.join(d, 'ipvsadm'), os.X_OK)
               for d in path_dirs)


@unittest.skipUnless(_ipvs_test_env_available(),
                     'requires root, /dev/net/tun and ipvsadm')
class TestIpvsRoundRobin(unittest.TestCase):
    """Send SYNs to an IPVS VIP and inspect what is forwarded to a TUN."""

    def setUp(self):
        subprocess.call(RESET_IPVS, shell=True)
        self.tun = Tun()

    def tearDown(self):
        self.tun.close()
        self.tun = None

    def _setup_service(self, vip, port, real_servers):
        """Route the real servers via the TUN device and create the service."""
        self.tun.add_routes(real_servers)
        # NOTE: registering this unrelated dummy service first was found to be
        # necessary for the IPv6 tests to pass; the root cause is unknown.
        add_service('192.168.255.38', 999, [])
        add_service(vip, port, real_servers)

    def _send_syn(self, iph_fn, src_ip, vip, src_port, port):
        """Send one SYN to the VIP; return the dst of the captured packet."""
        p = iph_fn(src_ip, vip, payload=TcpHeader(src_port, port, syn=1))
        send_raw_packet(p)
        forwarded = parse_packet(self.tun.read())
        self.assertIsNotNone(forwarded,
                             'no packet forwarded to %s' % self.tun.name)
        return forwarded.dst

    def do_balance_test(self, vip, port, real_servers, src_ip, iph_fn):
        self._setup_service(vip, port, real_servers)
        buckets = collections.defaultdict(int)
        base_port = 15000
        for i in range(10000):
            dst_ip = self._send_syn(iph_fn, src_ip, vip, base_port + i, port)
            buckets[dst_ip] += 1
        avg = float(sum(buckets.values())) / len(buckets)
        for k, v in buckets.items():
            self.assertLessEqual(abs(v - avg), 1.0,
                                 '%s got %d packets, average %.2f'
                                 % (k, v, avg))

    def test_balance_v4(self):
        real_servers = ['1.2.3.%d' % i for i in range(5, 32)]
        self.do_balance_test(
            '1.2.3.4', 15213, real_servers, '9.9.9.9', Ipv4Header)

    def test_balance_v6(self):
        real_servers = ['face::%x' % i for i in range(5, 32)]
        self.do_balance_test(
            'face::4', 15213, real_servers, 'b00c::1', Ipv6Header)

    def do_stickiness_test(self, vip, port, real_servers, src_ip, iph_fn):
        self._setup_service(vip, port, real_servers)
        buckets = {}
        base_port = 15000
        for i in range(10000):
            src = base_port + (i % 43)
            dst_ip = self._send_syn(iph_fn, src_ip, vip, src, port)
            # The first packet from a source port pins its real server;
            # every later packet from that port must go to the same one.
            self.assertEqual(buckets.setdefault(src, dst_ip), dst_ip)

    def test_stickiness_v4(self):
        real_servers = ['1.2.3.%d' % i for i in range(5, 32)]
        self.do_stickiness_test(
            '1.2.3.4', 15213, real_servers, '9.9.9.9', Ipv4Header)

    def test_stickiness_v6(self):
        real_servers = ['face::%x' % i for i in range(5, 32)]
        self.do_stickiness_test(
            'face::4', 15213, real_servers, 'b00c::1', Ipv6Header)


if __name__ == '__main__':
    unittest.main()