diff options
Diffstat (limited to 'iptables-test.py')
-rwxr-xr-x | iptables-test.py | 464 |
1 files changed, 363 insertions, 101 deletions
diff --git a/iptables-test.py b/iptables-test.py index ca5efb1b..cefe4233 100755 --- a/iptables-test.py +++ b/iptables-test.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # # (C) 2012-2013 by Pablo Neira Ayuso <pablo@netfilter.org> # @@ -15,6 +15,7 @@ import sys import os import subprocess import argparse +from difflib import unified_diff IPTABLES = "iptables" IP6TABLES = "ip6tables" @@ -32,30 +33,34 @@ EXTENSIONS_PATH = "extensions" LOGFILE="/tmp/iptables-test.log" log_file = None +STDOUT_IS_TTY = sys.stdout.isatty() +STDERR_IS_TTY = sys.stderr.isatty() -class Colors: - HEADER = '\033[95m' - BLUE = '\033[94m' - GREEN = '\033[92m' - YELLOW = '\033[93m' - RED = '\033[91m' - ENDC = '\033[0m' +def maybe_colored(color, text, isatty): + terminal_sequences = { + 'green': '\033[92m', + 'red': '\033[91m', + } + + return ( + terminal_sequences[color] + text + '\033[0m' if isatty else text + ) def print_error(reason, filename=None, lineno=None): ''' Prints an error with nice colors, indicating file and line number. ''' - print(filename + ": " + Colors.RED + "ERROR" + - Colors.ENDC + ": line %d (%s)" % (lineno, reason)) + print(filename + ": " + maybe_colored('red', "ERROR", STDERR_IS_TTY) + + ": line %d (%s)" % (lineno, reason), file=sys.stderr) -def delete_rule(iptables, rule, filename, lineno): +def delete_rule(iptables, rule, filename, lineno, netns = None): ''' Removes an iptables rule ''' cmd = iptables + " -D " + rule - ret = execute_cmd(cmd, filename, lineno) + ret = execute_cmd(cmd, filename, lineno, netns) if ret == 1: reason = "cannot delete: " + iptables + " -I " + rule print_error(reason, filename, lineno) @@ -69,26 +74,24 @@ def run_test(iptables, rule, rule_save, res, filename, lineno, netns): Executes an unit test. Returns the output of delete_rule(). Parameters: - :param iptables: string with the iptables command to execute + :param iptables: string with the iptables command to execute :param rule: string with iptables arguments for the rule to test - :param rule_save: string to find the rule in the output of iptables -save + :param rule_save: string to find the rule in the output of iptables-save :param res: expected result of the rule. Valid values: "OK", "FAIL" :param filename: name of the file tested (used for print_error purposes) :param lineno: line number being tested (used for print_error purposes) + :param netns: network namespace to call commands in (or None) ''' ret = 0 cmd = iptables + " -A " + rule - if netns: - cmd = "ip netns exec ____iptables-container-test " + EXECUTEABLE + " " + cmd - - ret = execute_cmd(cmd, filename, lineno) + ret = execute_cmd(cmd, filename, lineno, netns) # # report failed test # if ret: - if res == "OK": + if res != "FAIL": reason = "cannot load: " + cmd print_error(reason, filename, lineno) return -1 @@ -99,32 +102,32 @@ def run_test(iptables, rule, rule_save, res, filename, lineno, netns): if res == "FAIL": reason = "should fail: " + cmd print_error(reason, filename, lineno) - delete_rule(iptables, rule, filename, lineno) + delete_rule(iptables, rule, filename, lineno, netns) return -1 matching = 0 - splitted = iptables.split(" ") - if len(splitted) == 2: - if splitted[1] == '-4': + tokens = iptables.split(" ") + if len(tokens) == 2: + if tokens[1] == '-4': command = IPTABLES_SAVE - elif splitted[1] == '-6': + elif tokens[1] == '-6': command = IP6TABLES_SAVE - elif len(splitted) == 1: - if splitted[0] == IPTABLES: + elif len(tokens) == 1: + if tokens[0] == IPTABLES: command = IPTABLES_SAVE - elif splitted[0] == IP6TABLES: + elif tokens[0] == IP6TABLES: command = IP6TABLES_SAVE - elif splitted[0] == ARPTABLES: + elif tokens[0] == ARPTABLES: command = ARPTABLES_SAVE - elif splitted[0] == EBTABLES: + elif tokens[0] == EBTABLES: command = EBTABLES_SAVE - command = EXECUTEABLE + " " + command + command = EXECUTABLE + " " + command if netns: - command = "ip netns exec ____iptables-container-test " + command + command = "ip netns exec " + netns + " " + command - args = splitted[1:] + args = tokens[1:] proc = subprocess.Popen(command, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) @@ -134,18 +137,29 @@ def run_test(iptables, rule, rule_save, res, filename, lineno, netns): # check for segfaults # if proc.returncode == -11: - reason = "iptables-save segfaults: " + cmd + reason = command + " segfaults!" print_error(reason, filename, lineno) - delete_rule(iptables, rule, filename, lineno) + delete_rule(iptables, rule, filename, lineno, netns) return -1 # find the rule - matching = out.find(rule_save.encode('utf-8')) + matching = out.find("\n-A {}\n".format(rule_save).encode('utf-8')) + if matching < 0: - reason = "cannot find: " + iptables + " -I " + rule - print_error(reason, filename, lineno) - delete_rule(iptables, rule, filename, lineno) - return -1 + if res == "OK": + reason = "cannot find: " + iptables + " -I " + rule + print_error(reason, filename, lineno) + delete_rule(iptables, rule, filename, lineno, netns) + return -1 + else: + # do not report this error + return 0 + else: + if res != "OK": + reason = "should not match: " + cmd + print_error(reason, filename, lineno) + delete_rule(iptables, rule, filename, lineno, netns) + return -1 # Test "ip netns del NETNS" path with rules in place if netns: @@ -153,7 +167,7 @@ def run_test(iptables, rule, rule_save, res, filename, lineno, netns): return delete_rule(iptables, rule, filename, lineno) -def execute_cmd(cmd, filename, lineno): +def execute_cmd(cmd, filename, lineno = 0, netns = None): ''' Executes a command, checking for segfaults and returning the command exit code. @@ -161,10 +175,14 @@ def execute_cmd(cmd, filename, lineno): :param cmd: string with the command to be executed :param filename: name of the file tested (used for print_error purposes) :param lineno: line number being tested (used for print_error purposes) + :param netns: network namespace to run command in ''' global log_file if cmd.startswith('iptables ') or cmd.startswith('ip6tables ') or cmd.startswith('ebtables ') or cmd.startswith('arptables '): - cmd = EXECUTEABLE + " " + cmd + cmd = EXECUTABLE + " " + cmd + + if netns: + cmd = "ip netns exec " + netns + " " + cmd print("command: {}".format(cmd), file=log_file) ret = subprocess.call(cmd, shell=True, universal_newlines=True, @@ -172,17 +190,201 @@ def execute_cmd(cmd, filename, lineno): log_file.flush() # generic check for segfaults - if ret == -11: + if ret == -11: reason = "command segfaults: " + cmd print_error(reason, filename, lineno) return ret +def variant_res(res, variant, alt_res=None): + ''' + Adjust expected result with given variant + + If expected result is scoped to a variant, the other one yields a different + result. Therefore map @res to itself if given variant is current, use the + alternate result, @alt_res, if specified, invert @res otherwise. + + :param res: expected result from test spec ("OK", "FAIL" or "NOMATCH") + :param variant: variant @res is scoped to by test spec ("NFT" or "LEGACY") + :param alt_res: optional expected result for the alternate variant. + ''' + variant_executable = { + "NFT": "xtables-nft-multi", + "LEGACY": "xtables-legacy-multi" + } + res_inverse = { + "OK": "FAIL", + "FAIL": "OK", + "NOMATCH": "OK" + } + + if variant_executable[variant] == EXECUTABLE: + return res + if alt_res is not None: + return alt_res + return res_inverse[res] + +def fast_run_possible(filename): + ''' + Keep things simple, run only for simple test files: + - no external commands + - no multiple tables + - no variant-specific results + ''' + table = None + rulecount = 0 + for line in open(filename): + if line[0] in ["#", ":"] or len(line.strip()) == 0: + continue + if line[0] == "*": + if table or rulecount > 0: + return False + table = line.rstrip()[1:] + if line[0] in ["@", "%"]: + return False + if len(line.split(";")) > 3: + return False + rulecount += 1 + + return True + +def run_test_file_fast(iptables, filename, netns): + ''' + Run a test file, but fast + + :param filename: name of the file with the test rules + :param netns: network namespace to perform test run in + ''' + + f = open(filename) + + rules = {} + table = "filter" + chain_array = [] + tests = 0 + + for lineno, line in enumerate(f): + if line[0] == "#" or len(line.strip()) == 0: + continue + + if line[0] == "*": + table = line.rstrip()[1:] + continue + + if line[0] == ":": + chain_array = line.rstrip()[1:].split(",") + continue + + if len(chain_array) == 0: + return -1 + + tests += 1 + + for chain in chain_array: + item = line.split(";") + rule = chain + " " + item[0] + + if item[1] == "=": + rule_save = chain + " " + item[0] + else: + rule_save = chain + " " + item[1] + + if iptables == EBTABLES and rule_save.find('-j') < 0: + rule_save += " -j CONTINUE" + + res = item[2].rstrip() + if res != "OK": + rule = chain + " -t " + table + " " + item[0] + ret = run_test(iptables, rule, rule_save, + res, filename, lineno + 1, netns) + + if ret < 0: + return -1 + continue + + if not chain in rules.keys(): + rules[chain] = [] + rules[chain].append((rule, rule_save)) + + restore_data = ["*" + table] + out_expect = [] + for chain in ["PREROUTING", "INPUT", "FORWARD", "OUTPUT", "POSTROUTING"]: + if not chain in rules.keys(): + continue + for rule in rules[chain]: + restore_data.append("-A " + rule[0]) + out_expect.append("-A " + rule[1]) + restore_data.append("COMMIT") + + out_expect = "\n".join(out_expect) + + # load all rules via iptables_restore + + command = EXECUTABLE + " " + iptables + "-restore" + if netns: + command = "ip netns exec " + netns + " " + command + + for line in restore_data: + print(iptables + "-restore: " + line, file=log_file) + + proc = subprocess.Popen(command, shell = True, text = True, + stdin = subprocess.PIPE, + stdout = subprocess.PIPE, + stderr = subprocess.PIPE) + restore_data = "\n".join(restore_data) + "\n" + out, err = proc.communicate(input = restore_data) + + if proc.returncode == -11: + reason = iptables + "-restore segfaults!" + print_error(reason, filename, lineno) + msg = [iptables + "-restore segfault from:"] + msg.extend(["input: " + l for l in restore_data.split("\n")]) + print("\n".join(msg), file=log_file) + return -1 + + if proc.returncode != 0: + print("%s-restore returned %d: %s" % (iptables, proc.returncode, err), + file=log_file) + return -1 + + # find all rules in iptables_save output + + command = EXECUTABLE + " " + iptables + "-save" + if netns: + command = "ip netns exec " + netns + " " + command + + proc = subprocess.Popen(command, shell = True, + stdin = subprocess.PIPE, + stdout = subprocess.PIPE, + stderr = subprocess.PIPE) + out, err = proc.communicate() + + if proc.returncode == -11: + reason = iptables + "-save segfaults!" + print_error(reason, filename, lineno) + return -1 + + cmd = iptables + " -F -t " + table + execute_cmd(cmd, filename, 0, netns) + + out = out.decode('utf-8').rstrip() + if out.find(out_expect) < 0: + print("dumps differ!", file=log_file) + out_clean = [ l for l in out.split("\n") + if not l[0] in ['*', ':', '#']] + diff = unified_diff(out_expect.split("\n"), out_clean, + fromfile="expect", tofile="got", lineterm='') + print("\n".join(diff), file=log_file) + return -1 + + return tests + def run_test_file(filename, netns): ''' Runs a test file :param filename: name of the file with the test rules + :param netns: network namespace to perform test run in ''' # # if this is not a test file, skip. @@ -198,27 +400,36 @@ def run_test_file(filename, netns): iptables = IPTABLES elif "libarpt_" in filename: # only supported with nf_tables backend - if EXECUTEABLE != "xtables-nft-multi": + if EXECUTABLE != "xtables-nft-multi": return 0, 0 iptables = ARPTABLES elif "libebt_" in filename: # only supported with nf_tables backend - if EXECUTEABLE != "xtables-nft-multi": + if EXECUTABLE != "xtables-nft-multi": return 0, 0 iptables = EBTABLES else: # default to iptables if not known prefix iptables = IPTABLES + fast_failed = False + if fast_run_possible(filename): + tests = run_test_file_fast(iptables, filename, netns) + if tests > 0: + print(filename + ": " + maybe_colored('green', "OK", STDOUT_IS_TTY)) + return tests, tests + fast_failed = True + f = open(filename) tests = 0 passed = 0 table = "" + chain_array = [] total_test_passed = True if netns: - execute_cmd("ip netns add ____iptables-container-test", filename, 0) + execute_cmd("ip netns add " + netns, filename) for lineno, line in enumerate(f): if line[0] == "#" or len(line.strip()) == 0: @@ -228,20 +439,11 @@ def run_test_file(filename, netns): chain_array = line.rstrip()[1:].split(",") continue - # external non-iptables invocation, executed as is. - if line[0] == "@": + # external command invocation, executed as is. + # detects iptables commands to prefix with EXECUTABLE automatically + if line[0] in ["@", "%"]: external_cmd = line.rstrip()[1:] - if netns: - external_cmd = "ip netns exec ____iptables-container-test " + external_cmd - execute_cmd(external_cmd, filename, lineno) - continue - - # external iptables invocation, executed as is. - if line[0] == "%": - external_cmd = line.rstrip()[1:] - if netns: - external_cmd = "ip netns exec ____iptables-container-test " + EXECUTEABLE + " " + external_cmd - execute_cmd(external_cmd, filename, lineno) + execute_cmd(external_cmd, filename, lineno, netns) continue if line[0] == "*": @@ -249,8 +451,10 @@ def run_test_file(filename, netns): continue if len(chain_array) == 0: - print("broken test, missing chain, leaving") - sys.exit() + print_error("broken test, missing chain", + filename = filename, lineno = lineno) + total_test_passed = False + break test_passed = True tests += 1 @@ -267,7 +471,18 @@ def run_test_file(filename, netns): else: rule_save = chain + " " + item[1] + if iptables == EBTABLES and rule_save.find('-j') < 0: + rule_save += " -j CONTINUE" + res = item[2].rstrip() + if len(item) > 3: + variant = item[3].rstrip() + if len(item) > 4: + alt_res = item[4].rstrip() + else: + alt_res = None + res = variant_res(res, variant, alt_res) + ret = run_test(iptables, rule, rule_save, res, filename, lineno + 1, netns) @@ -280,9 +495,12 @@ def run_test_file(filename, netns): passed += 1 if netns: - execute_cmd("ip netns del ____iptables-container-test", filename, 0) + execute_cmd("ip netns del " + netns, filename) if total_test_passed: - print(filename + ": " + Colors.GREEN + "OK" + Colors.ENDC) + suffix = "" + if fast_failed: + suffix = maybe_colored('red', " but fast mode failed!", STDOUT_IS_TTY) + print(filename + ": " + maybe_colored('green', "OK", STDOUT_IS_TTY) + suffix) f.close() return tests, passed @@ -304,6 +522,31 @@ def show_missing(): print('\n'.join(missing)) +def spawn_netns(): + # prefer unshare module + try: + import unshare + unshare.unshare(unshare.CLONE_NEWNET) + return True + except: + pass + + # sledgehammer style: + # - call ourselves prefixed by 'unshare -n' if found + # - pass extra --no-netns parameter to avoid another recursion + try: + import shutil + + unshare = shutil.which("unshare") + if unshare is None: + return False + + sys.argv.append("--no-netns") + os.execv(unshare, [unshare, "-n", sys.executable] + sys.argv) + except: + pass + + return False # # main @@ -321,8 +564,11 @@ def main(): help='Check for missing tests') parser.add_argument('-n', '--nftables', action='store_true', help='Test iptables-over-nftables') - parser.add_argument('-N', '--netns', action='store_true', + parser.add_argument('-N', '--netns', action='store_const', + const='____iptables-container-test', help='Test netnamespace path') + parser.add_argument('--no-netns', action='store_true', + help='Do not run testsuite in own network namespace') args = parser.parse_args() # @@ -332,56 +578,72 @@ def main(): show_missing() return - global EXECUTEABLE - EXECUTEABLE = "xtables-legacy-multi" + variants = [] + if args.legacy: + variants.append("legacy") if args.nftables: - EXECUTEABLE = "xtables-nft-multi" + variants.append("nft") + if len(variants) == 0: + variants = [ "legacy", "nft" ] if os.getuid() != 0: - print("You need to be root to run this, sorry") - return + print("You need to be root to run this, sorry", file=sys.stderr) + return 77 + + if not args.netns and not args.no_netns and not spawn_netns(): + print("Cannot run in own namespace, connectivity might break", + file=sys.stderr) if not args.host: os.putenv("XTABLES_LIBDIR", os.path.abspath(EXTENSIONS_PATH)) os.putenv("PATH", "%s/iptables:%s" % (os.path.abspath(os.path.curdir), os.getenv("PATH"))) - test_files = 0 - tests = 0 - passed = 0 - - # setup global var log file - global log_file - try: - log_file = open(LOGFILE, 'w') - except IOError: - print("Couldn't open log file %s" % LOGFILE) - return + total_test_files = 0 + total_passed = 0 + total_tests = 0 + for variant in variants: + global EXECUTABLE + EXECUTABLE = "xtables-" + variant + "-multi" - if args.filename: - file_list = args.filename - else: - file_list = [os.path.join(EXTENSIONS_PATH, i) - for i in os.listdir(EXTENSIONS_PATH) - if i.endswith('.t')] - file_list.sort() + test_files = 0 + tests = 0 + passed = 0 - if not args.netns: + # setup global var log file + global log_file try: - import unshare - unshare.unshare(unshare.CLONE_NEWNET) - except: - print("Cannot run in own namespace, connectivity might break") - - for filename in file_list: - file_tests, file_passed = run_test_file(filename, args.netns) - if file_tests: - tests += file_tests - passed += file_passed - test_files += 1 - - print("%d test files, %d unit tests, %d passed" % (test_files, tests, passed)) + log_file = open(LOGFILE, 'w') + except IOError: + print("Couldn't open log file %s" % LOGFILE, file=sys.stderr) + return + if args.filename: + file_list = args.filename + else: + file_list = [os.path.join(EXTENSIONS_PATH, i) + for i in os.listdir(EXTENSIONS_PATH) + if i.endswith('.t')] + file_list.sort() + + for filename in file_list: + file_tests, file_passed = run_test_file(filename, args.netns) + if file_tests: + tests += file_tests + passed += file_passed + test_files += 1 + + print("%s: %d test files, %d unit tests, %d passed" + % (variant, test_files, tests, passed)) + + total_passed += passed + total_tests += tests + total_test_files = max(total_test_files, test_files) + + if len(variants) > 1: + print("total: %d test files, %d unit tests, %d passed" + % (total_test_files, total_tests, total_passed)) + return total_passed - total_tests if __name__ == '__main__': - main() + sys.exit(main()) |