summaryrefslogtreecommitdiffstats
path: root/tests/regression/nft-test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/regression/nft-test.py')
-rwxr-xr-xtests/regression/nft-test.py100
1 files changed, 95 insertions, 5 deletions
diff --git a/tests/regression/nft-test.py b/tests/regression/nft-test.py
index 153f5e8b..26fc2ec4 100755
--- a/tests/regression/nft-test.py
+++ b/tests/regression/nft-test.py
@@ -423,9 +423,32 @@ def output_clean(pre_output, chain):
return ""
return rule
+def payload_check(payload_buffer, file, cmd):
+
+ file.seek(0, 0)
+
+ ret = False
+ i = 0
+
+ for lineno, want_line in enumerate(payload_buffer):
+ line = file.readline()
+
+ if want_line == line:
+ i += 1
+ continue
+
+ if want_line.find('[') < 0 and line.find('[') < 0:
+ continue
+ if want_line.find(']') < 0 and line.find(']') < 0:
+ continue
+
+ print_differences_warning(file.name, lineno, want_line.strip(), line.strip(), cmd);
+ return 0
+
+ return i > 0
def rule_add(rule, table_list, chain_list, filename, lineno,
- force_all_family_option):
+ force_all_family_option, filename_path):
'''
Adds a rule
'''
@@ -437,7 +460,23 @@ def rule_add(rule, table_list, chain_list, filename, lineno,
print_error(reason, filename, lineno)
return [-1, warning, error, unit_tests]
+ payload_expected = []
+
for table in table_list:
+ try:
+ payload_log = open("%s.payload.%s" % (filename_path, table[0]))
+ except (IOError):
+ payload_log = open("%s.payload" % filename_path)
+
+ if rule[1].strip() == "ok":
+ try:
+ payload_expected.index(rule[0])
+ except (ValueError):
+ payload_expected = payload_find_expected(payload_log, rule[0])
+
+ if payload_expected == []:
+ print_error("did not find payload information for rule '%s'" % rule[0], payload_log.name, 1)
+
for chain in chain_list:
if len(rule) == 1:
reason = "Skipping malformed test. (" + \
@@ -450,7 +489,10 @@ def rule_add(rule, table_list, chain_list, filename, lineno,
table_info = " " + table[0] + " " + table[1] + " "
cmd = "nft add rule" + table_info + chain + " " + rule[0]
- ret = execute_cmd(cmd, filename, lineno)
+ payload_log = os.tmpfile();
+
+ cmd = "nft add rule --debug=netlink" + table_info + chain + " " + rule[0]
+ ret = execute_cmd(cmd, filename, lineno, payload_log)
state = rule[1].rstrip()
if (ret == 0 and state == "fail") or (ret != 0 and state == "ok"):
@@ -470,6 +512,20 @@ def rule_add(rule, table_list, chain_list, filename, lineno,
continue
if ret == 0:
+ # Check for matching payload
+ if state == "ok" and not payload_check(payload_expected, payload_log, cmd):
+ error += 1
+ gotf = open("%s.payload.got" % filename_path, 'a')
+ payload_log.seek(0, 0)
+ gotf.write("# %s\n" % rule[0])
+ while True:
+ line = payload_log.readline()
+ if line == "":
+ break
+ gotf.write(line)
+ gotf.close()
+ print_warning("Wrote payload for rule %s" % rule[0], gotf.name, 1)
+
# Check output of nft
process = subprocess.Popen(['nft', '-nnn', 'list', 'table'] + table,
shell=False, stdout=subprocess.PIPE,
@@ -536,7 +592,7 @@ def signal_handler(signal, frame):
signal_received = 1
-def execute_cmd(cmd, filename, lineno):
+def execute_cmd(cmd, filename, lineno, stdout_log = False):
'''
Executes a command, checks for segfaults and returns the command exit
code.
@@ -549,8 +605,12 @@ def execute_cmd(cmd, filename, lineno):
print >> log_file, "command: %s" % cmd
if debug_option:
print cmd
+
+ if not stdout_log:
+ stdout_log = log_file
+
ret = subprocess.call(cmd, shell=True, universal_newlines=True,
- stderr=subprocess.STDOUT, stdout=log_file,
+ stderr=log_file, stdout=stdout_log,
preexec_fn=preexec)
log_file.flush()
@@ -619,6 +679,36 @@ def set_element_process(element_line, filename, lineno):
return set_add_elements(set_element, set_name, all_set, rule_state,
table_list, filename, lineno)
+def payload_find_expected(payload_log, rule):
+ '''
+ Find the netlink payload that should be generated by given rule in payload_log
+
+ :param payload_log: open file handle of the payload data
+ :param rule: nft rule we are going to add
+ '''
+ found = 0
+ pos = 0
+ payload_buffer = []
+
+ while True:
+ line = payload_log.readline()
+ if not line:
+ break
+
+ if line[0] == "#": # rule start
+ rule_line = line.strip()[2:]
+
+ if rule_line == rule.strip():
+ found = 1
+ continue
+
+ if found == 1:
+ payload_buffer.append(line)
+ if line.isspace():
+ return payload_buffer
+
+ payload_log.seek(0, 0)
+ return payload_buffer
def run_test_file(filename, force_all_family_option, specific_file):
'''
@@ -699,7 +789,7 @@ def run_test_file(filename, force_all_family_option, specific_file):
continue
result = rule_add(rule, table_list, chain_list, filename, lineno,
- force_all_family_option)
+ force_all_family_option, filename_path)
tests += 1
ret = result[0]
warning = result[1]