summaryrefslogtreecommitdiffstats
path: root/tests/py/nft-test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/py/nft-test.py')
-rwxr-xr-xtests/py/nft-test.py161
1 files changed, 143 insertions, 18 deletions
diff --git a/tests/py/nft-test.py b/tests/py/nft-test.py
index 01ee6c98..1bc89558 100755
--- a/tests/py/nft-test.py
+++ b/tests/py/nft-test.py
@@ -28,7 +28,7 @@ os.environ['TZ'] = 'UTC-2'
from nftables import Nftables
-TESTS_DIRECTORY = ["any", "arp", "bridge", "inet", "ip", "ip6"]
+TESTS_DIRECTORY = ["any", "arp", "bridge", "inet", "ip", "ip6", "netdev"]
LOGFILE = "/tmp/nftables-test.log"
log_file = None
table_list = []
@@ -39,7 +39,7 @@ signal_received = 0
class Colors:
- if sys.stdout.isatty():
+ if sys.stdout.isatty() and sys.stderr.isatty():
HEADER = '\033[95m'
GREEN = '\033[92m'
YELLOW = '\033[93m'
@@ -86,11 +86,12 @@ class Table:
class Set:
"""Class that represents a set"""
- def __init__(self, family, table, name, type, timeout, flags):
+ def __init__(self, family, table, name, type, data, timeout, flags):
self.family = family
self.table = table
self.name = name
self.type = type
+ self.data = data
self.timeout = timeout
self.flags = flags
@@ -366,7 +367,11 @@ def set_add(s, test_result, filename, lineno):
if flags != "":
flags = "flags %s; " % flags
- cmd = "add set %s %s { type %s;%s %s}" % (table, s.name, s.type, s.timeout, flags)
+ if s.data == "":
+ cmd = "add set %s %s { %s;%s %s}" % (table, s.name, s.type, s.timeout, flags)
+ else:
+ cmd = "add map %s %s { %s : %s;%s %s}" % (table, s.name, s.type, s.data, s.timeout, flags)
+
ret = execute_cmd(cmd, filename, lineno)
if (ret == 0 and test_result == "fail") or \
@@ -384,6 +389,44 @@ def set_add(s, test_result, filename, lineno):
return 0
+def map_add(s, test_result, filename, lineno):
+ '''
+ Adds a map
+ '''
+ if not table_list:
+ reason = "Missing table to add rule"
+ print_error(reason, filename, lineno)
+ return -1
+
+ for table in table_list:
+ s.table = table.name
+ s.family = table.family
+ if _map_exist(s, filename, lineno):
+ reason = "Map %s already exists in %s" % (s.name, table)
+ print_error(reason, filename, lineno)
+ return -1
+
+ flags = s.flags
+ if flags != "":
+ flags = "flags %s; " % flags
+
+ cmd = "add map %s %s { %s : %s;%s %s}" % (table, s.name, s.type, s.data, s.timeout, flags)
+
+ ret = execute_cmd(cmd, filename, lineno)
+
+ if (ret == 0 and test_result == "fail") or \
+ (ret != 0 and test_result == "ok"):
+ reason = "%s: I cannot add the set %s" % (cmd, s.name)
+ print_error(reason, filename, lineno)
+ return -1
+
+ if not _map_exist(s, filename, lineno):
+ reason = "I have just added the set %s to " \
+ "the table %s but it does not exist" % (s.name, table)
+ print_error(reason, filename, lineno)
+ return -1
+
+
def set_add_elements(set_element, set_name, state, filename, lineno):
'''
Adds elements to the set.
@@ -407,7 +450,11 @@ def set_add_elements(set_element, set_name, state, filename, lineno):
ret = execute_cmd(cmd, filename, lineno)
if (state == "fail" and ret == 0) or (state == "ok" and ret != 0):
- test_state = "This rule should have failed."
+ if state == "fail":
+ test_state = "This rule should have failed."
+ else:
+ test_state = "This rule should not have failed."
+
reason = cmd + ": " + test_state
print_error(reason, filename, lineno)
return -1
@@ -486,6 +533,16 @@ def _set_exist(s, filename, lineno):
return True if (ret == 0) else False
+def _map_exist(s, filename, lineno):
+ '''
+ Check if the map exists.
+ '''
+ cmd = "list map %s %s %s" % (s.family, s.table, s.name)
+ ret = execute_cmd(cmd, filename, lineno)
+
+ return True if (ret == 0) else False
+
+
def set_check_element(rule1, rule2):
'''
Check if element exists in anonymous sets.
@@ -712,8 +769,10 @@ def rule_add(rule, filename, lineno, force_all_family_option, filename_path):
if rule[1].strip() == "ok":
payload_expected = None
+ payload_path = None
try:
payload_log = open("%s.payload" % filename_path)
+ payload_path = payload_log.name
payload_expected = payload_find_expected(payload_log, rule[0])
except:
payload_log = None
@@ -750,12 +809,15 @@ def rule_add(rule, filename, lineno, force_all_family_option, filename_path):
reason = "Invalid JSON syntax in expected output: %s" % json_expected
print_error(reason)
return [-1, warning, error, unit_tests]
+ if json_expected == json_input:
+ print_warning("Recorded JSON output matches input for: %s" % rule[0])
for table in table_list:
if rule[1].strip() == "ok":
table_payload_expected = None
try:
payload_log = open("%s.payload.%s" % (filename_path, table.family))
+ payload_path = payload_log.name
table_payload_expected = payload_find_expected(payload_log, rule[0])
except:
if not payload_log:
@@ -802,17 +864,26 @@ def rule_add(rule, filename, lineno, force_all_family_option, filename_path):
if state == "ok" and not payload_check(table_payload_expected,
payload_log, cmd):
error += 1
- gotf = open("%s.payload.got" % filename_path, 'a')
+
+ try:
+ gotf = open("%s.got" % payload_path)
+ gotf_payload_expected = payload_find_expected(gotf, rule[0])
+ gotf.close()
+ except:
+ gotf_payload_expected = None
payload_log.seek(0, 0)
- gotf.write("# %s\n" % rule[0])
- while True:
- line = payload_log.readline()
- if line == "":
- break
- gotf.write(line)
- gotf.close()
- print_warning("Wrote payload for rule %s" % rule[0],
- gotf.name, 1)
+ if not payload_check(gotf_payload_expected, payload_log, cmd):
+ gotf = open("%s.got" % payload_path, 'a')
+ payload_log.seek(0, 0)
+ gotf.write("# %s\n" % rule[0])
+ while True:
+ line = payload_log.readline()
+ if line == "":
+ break
+ gotf.write(line)
+ gotf.close()
+ print_warning("Wrote payload for rule %s" % rule[0],
+ gotf.name, 1)
# Check for matching ruleset listing
numeric_proto_old = nftables.set_numeric_proto_output(True)
@@ -1022,6 +1093,8 @@ def execute_cmd(cmd, filename, lineno, stdout_log=False, debug=False):
if debug_option:
print(cmd)
+ log_file.flush()
+
if debug:
debug_old = nftables.get_debug()
nftables.set_debug(debug)
@@ -1073,14 +1146,28 @@ def set_process(set_line, filename, lineno):
tokens = set_line[0].split(" ")
set_name = tokens[0]
- set_type = tokens[2]
+ parse_typeof = tokens[1] == "typeof"
+ set_type = tokens[1] + " " + tokens[2]
+ set_data = ""
set_flags = ""
i = 3
+ if parse_typeof and tokens[i] == "id":
+ set_type += " " + tokens[i]
+ i += 1;
+
while len(tokens) > i and tokens[i] == ".":
set_type += " . " + tokens[i+1]
i += 2
+ while len(tokens) > i and tokens[i] == ":":
+ set_data = tokens[i+1]
+ i += 2
+
+ if parse_typeof and tokens[i] == "mark":
+ set_data += " " + tokens[i]
+ i += 1;
+
if len(tokens) == i+2 and tokens[i] == "timeout":
timeout = "timeout " + tokens[i+1] + ";"
i += 2
@@ -1090,9 +1177,13 @@ def set_process(set_line, filename, lineno):
elif len(tokens) != i:
print_error(set_name + " bad flag: " + tokens[i], filename, lineno)
- s = Set("", "", set_name, set_type, timeout, set_flags)
+ s = Set("", "", set_name, set_type, set_data, timeout, set_flags)
+
+ if set_data == "":
+ ret = set_add(s, test_result, filename, lineno)
+ else:
+ ret = map_add(s, test_result, filename, lineno)
- ret = set_add(s, test_result, filename, lineno)
if ret == 0:
all_set[set_name] = set()
@@ -1340,6 +1431,33 @@ def run_test_file(filename, force_all_family_option, specific_file):
return [tests, passed, total_warning, total_error, total_unit_run]
+def spawn_netns():
+ # prefer unshare module
+ try:
+ import unshare
+ unshare.unshare(unshare.CLONE_NEWNET)
+ return True
+ except:
+ pass
+
+ # sledgehammer style:
+ # - call ourselves prefixed by 'unshare -n' if found
+ # - pass extra --no-netns parameter to avoid another recursion
+ try:
+ import shutil
+
+ unshare = shutil.which("unshare")
+ if unshare is None:
+ return False
+
+ sys.argv.append("--no-netns")
+ if debug_option:
+ print("calling: ", [unshare, "-n", sys.executable] + sys.argv)
+ os.execv(unshare, [unshare, "-n", sys.executable] + sys.argv)
+ except:
+ pass
+
+ return False
def main():
parser = argparse.ArgumentParser(description='Run nft tests')
@@ -1367,6 +1485,10 @@ def main():
parser.add_argument('-l', '--library', default=None,
help='path to libntables.so.1, overrides --host')
+ parser.add_argument('-N', '--no-netns', action='store_true',
+ dest='no_netns',
+ help='Do not run in own network namespace')
+
parser.add_argument('-s', '--schema', action='store_true',
dest='enable_schema',
help='verify json input/output against schema')
@@ -1391,6 +1513,9 @@ def main():
print("You need to be root to run this, sorry")
return
+ if not args.no_netns and not spawn_netns():
+ print_warning("cannot run in own namespace, connectivity might break")
+
# Change working directory to repository root
os.chdir(TESTS_PATH + "/../..")