diff options
Diffstat (limited to 'tests/py/nft-test.py')
-rwxr-xr-x | tests/py/nft-test.py | 183 |
1 files changed, 161 insertions, 22 deletions
diff --git a/tests/py/nft-test.py b/tests/py/nft-test.py index 6edca3c6..1bc89558 100755 --- a/tests/py/nft-test.py +++ b/tests/py/nft-test.py @@ -28,7 +28,7 @@ os.environ['TZ'] = 'UTC-2' from nftables import Nftables -TESTS_DIRECTORY = ["any", "arp", "bridge", "inet", "ip", "ip6"] +TESTS_DIRECTORY = ["any", "arp", "bridge", "inet", "ip", "ip6", "netdev"] LOGFILE = "/tmp/nftables-test.log" log_file = None table_list = [] @@ -39,7 +39,7 @@ signal_received = 0 class Colors: - if sys.stdout.isatty(): + if sys.stdout.isatty() and sys.stderr.isatty(): HEADER = '\033[95m' GREEN = '\033[92m' YELLOW = '\033[93m' @@ -86,11 +86,12 @@ class Table: class Set: """Class that represents a set""" - def __init__(self, family, table, name, type, timeout, flags): + def __init__(self, family, table, name, type, data, timeout, flags): self.family = family self.table = table self.name = name self.type = type + self.data = data self.timeout = timeout self.flags = flags @@ -366,7 +367,11 @@ def set_add(s, test_result, filename, lineno): if flags != "": flags = "flags %s; " % flags - cmd = "add set %s %s { type %s;%s %s}" % (table, s.name, s.type, s.timeout, flags) + if s.data == "": + cmd = "add set %s %s { %s;%s %s}" % (table, s.name, s.type, s.timeout, flags) + else: + cmd = "add map %s %s { %s : %s;%s %s}" % (table, s.name, s.type, s.data, s.timeout, flags) + ret = execute_cmd(cmd, filename, lineno) if (ret == 0 and test_result == "fail") or \ @@ -384,6 +389,44 @@ def set_add(s, test_result, filename, lineno): return 0 +def map_add(s, test_result, filename, lineno): + ''' + Adds a map + ''' + if not table_list: + reason = "Missing table to add rule" + print_error(reason, filename, lineno) + return -1 + + for table in table_list: + s.table = table.name + s.family = table.family + if _map_exist(s, filename, lineno): + reason = "Map %s already exists in %s" % (s.name, table) + print_error(reason, filename, lineno) + return -1 + + flags = s.flags + if flags != "": + flags = "flags %s; " % flags + + cmd = "add map %s %s { %s : %s;%s %s}" % (table, s.name, s.type, s.data, s.timeout, flags) + + ret = execute_cmd(cmd, filename, lineno) + + if (ret == 0 and test_result == "fail") or \ + (ret != 0 and test_result == "ok"): + reason = "%s: I cannot add the set %s" % (cmd, s.name) + print_error(reason, filename, lineno) + return -1 + + if not _map_exist(s, filename, lineno): + reason = "I have just added the set %s to " \ + "the table %s but it does not exist" % (s.name, table) + print_error(reason, filename, lineno) + return -1 + + def set_add_elements(set_element, set_name, state, filename, lineno): ''' Adds elements to the set. @@ -407,7 +450,11 @@ def set_add_elements(set_element, set_name, state, filename, lineno): ret = execute_cmd(cmd, filename, lineno) if (state == "fail" and ret == 0) or (state == "ok" and ret != 0): - test_state = "This rule should have failed." + if state == "fail": + test_state = "This rule should have failed." + else: + test_state = "This rule should not have failed." + reason = cmd + ": " + test_state print_error(reason, filename, lineno) return -1 @@ -486,6 +533,16 @@ def _set_exist(s, filename, lineno): return True if (ret == 0) else False +def _map_exist(s, filename, lineno): + ''' + Check if the map exists. + ''' + cmd = "list map %s %s %s" % (s.family, s.table, s.name) + ret = execute_cmd(cmd, filename, lineno) + + return True if (ret == 0) else False + + def set_check_element(rule1, rule2): ''' Check if element exists in anonymous sets. @@ -712,8 +769,10 @@ def rule_add(rule, filename, lineno, force_all_family_option, filename_path): if rule[1].strip() == "ok": payload_expected = None + payload_path = None try: payload_log = open("%s.payload" % filename_path) + payload_path = payload_log.name payload_expected = payload_find_expected(payload_log, rule[0]) except: payload_log = None @@ -750,12 +809,15 @@ def rule_add(rule, filename, lineno, force_all_family_option, filename_path): reason = "Invalid JSON syntax in expected output: %s" % json_expected print_error(reason) return [-1, warning, error, unit_tests] + if json_expected == json_input: + print_warning("Recorded JSON output matches input for: %s" % rule[0]) for table in table_list: if rule[1].strip() == "ok": table_payload_expected = None try: payload_log = open("%s.payload.%s" % (filename_path, table.family)) + payload_path = payload_log.name table_payload_expected = payload_find_expected(payload_log, rule[0]) except: if not payload_log: @@ -802,17 +864,26 @@ def rule_add(rule, filename, lineno, force_all_family_option, filename_path): if state == "ok" and not payload_check(table_payload_expected, payload_log, cmd): error += 1 - gotf = open("%s.payload.got" % filename_path, 'a') + + try: + gotf = open("%s.got" % payload_path) + gotf_payload_expected = payload_find_expected(gotf, rule[0]) + gotf.close() + except: + gotf_payload_expected = None payload_log.seek(0, 0) - gotf.write("# %s\n" % rule[0]) - while True: - line = payload_log.readline() - if line == "": - break - gotf.write(line) - gotf.close() - print_warning("Wrote payload for rule %s" % rule[0], - gotf.name, 1) + if not payload_check(gotf_payload_expected, payload_log, cmd): + gotf = open("%s.got" % payload_path, 'a') + payload_log.seek(0, 0) + gotf.write("# %s\n" % rule[0]) + while True: + line = payload_log.readline() + if line == "": + break + gotf.write(line) + gotf.close() + print_warning("Wrote payload for rule %s" % rule[0], + gotf.name, 1) # Check for matching ruleset listing numeric_proto_old = nftables.set_numeric_proto_output(True) @@ -1022,6 +1093,8 @@ def execute_cmd(cmd, filename, lineno, stdout_log=False, debug=False): if debug_option: print(cmd) + log_file.flush() + if debug: debug_old = nftables.get_debug() nftables.set_debug(debug) @@ -1073,14 +1146,28 @@ def set_process(set_line, filename, lineno): tokens = set_line[0].split(" ") set_name = tokens[0] - set_type = tokens[2] + parse_typeof = tokens[1] == "typeof" + set_type = tokens[1] + " " + tokens[2] + set_data = "" set_flags = "" i = 3 + if parse_typeof and tokens[i] == "id": + set_type += " " + tokens[i] + i += 1; + while len(tokens) > i and tokens[i] == ".": set_type += " . " + tokens[i+1] i += 2 + while len(tokens) > i and tokens[i] == ":": + set_data = tokens[i+1] + i += 2 + + if parse_typeof and tokens[i] == "mark": + set_data += " " + tokens[i] + i += 1; + if len(tokens) == i+2 and tokens[i] == "timeout": timeout = "timeout " + tokens[i+1] + ";" i += 2 @@ -1090,9 +1177,13 @@ def set_process(set_line, filename, lineno): elif len(tokens) != i: print_error(set_name + " bad flag: " + tokens[i], filename, lineno) - s = Set("", "", set_name, set_type, timeout, set_flags) + s = Set("", "", set_name, set_type, set_data, timeout, set_flags) + + if set_data == "": + ret = set_add(s, test_result, filename, lineno) + else: + ret = map_add(s, test_result, filename, lineno) - ret = set_add(s, test_result, filename, lineno) if ret == 0: all_set[set_name] = set() @@ -1340,6 +1431,33 @@ def run_test_file(filename, force_all_family_option, specific_file): return [tests, passed, total_warning, total_error, total_unit_run] +def spawn_netns(): + # prefer unshare module + try: + import unshare + unshare.unshare(unshare.CLONE_NEWNET) + return True + except: + pass + + # sledgehammer style: + # - call ourselves prefixed by 'unshare -n' if found + # - pass extra --no-netns parameter to avoid another recursion + try: + import shutil + + unshare = shutil.which("unshare") + if unshare is None: + return False + + sys.argv.append("--no-netns") + if debug_option: + print("calling: ", [unshare, "-n", sys.executable] + sys.argv) + os.execv(unshare, [unshare, "-n", sys.executable] + sys.argv) + except: + pass + + return False def main(): parser = argparse.ArgumentParser(description='Run nft tests') @@ -1357,10 +1475,20 @@ def main(): dest='force_all_family', help='keep testing all families on error') + parser.add_argument('-H', '--host', action='store_true', + help='run tests against installed libnftables.so.1') + parser.add_argument('-j', '--enable-json', action='store_true', dest='enable_json', help='test JSON functionality as well') + parser.add_argument('-l', '--library', default=None, + help='path to libntables.so.1, overrides --host') + + parser.add_argument('-N', '--no-netns', action='store_true', + dest='no_netns', + help='Do not run in own network namespace') + parser.add_argument('-s', '--schema', action='store_true', dest='enable_schema', help='verify json input/output against schema') @@ -1385,12 +1513,23 @@ def main(): print("You need to be root to run this, sorry") return + if not args.no_netns and not spawn_netns(): + print_warning("cannot run in own namespace, connectivity might break") + # Change working directory to repository root os.chdir(TESTS_PATH + "/../..") - if not os.path.exists('src/.libs/libnftables.so'): - print("The nftables library does not exist. " - "You need to build the project.") + check_lib_path = True + if args.library is None: + if args.host: + args.library = 'libnftables.so.1' + check_lib_path = False + else: + args.library = 'src/.libs/libnftables.so.1' + + if check_lib_path and not os.path.exists(args.library): + print("The nftables library at '%s' does not exist. " + "You need to build the project." % args.library) return if args.enable_schema and not args.enable_json: @@ -1398,7 +1537,7 @@ def main(): return global nftables - nftables = Nftables(sofile = 'src/.libs/libnftables.so') + nftables = Nftables(sofile = args.library) test_files = files_ok = run_total = 0 tests = passed = warnings = errors = 0 |