summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xiptables-test.py84
1 files changed, 51 insertions, 33 deletions
diff --git a/iptables-test.py b/iptables-test.py
index 6504b231..b5a70e44 100755
--- a/iptables-test.py
+++ b/iptables-test.py
@@ -408,10 +408,13 @@ def main():
show_missing()
return
- global EXECUTABLE
- EXECUTABLE = "xtables-legacy-multi"
+ variants = []
+ if args.legacy:
+ variants.append("legacy")
if args.nftables:
- EXECUTABLE = "xtables-nft-multi"
+ variants.append("nft")
+ if len(variants) == 0:
+ variants = [ "legacy", "nft" ]
if os.getuid() != 0:
print("You need to be root to run this, sorry", file=sys.stderr)
@@ -426,36 +429,51 @@ def main():
os.putenv("PATH", "%s/iptables:%s" % (os.path.abspath(os.path.curdir),
os.getenv("PATH")))
- test_files = 0
- tests = 0
- passed = 0
-
- # setup global var log file
- global log_file
- try:
- log_file = open(LOGFILE, 'w')
- except IOError:
- print("Couldn't open log file %s" % LOGFILE, file=sys.stderr)
- return
-
- if args.filename:
- file_list = args.filename
- else:
- file_list = [os.path.join(EXTENSIONS_PATH, i)
- for i in os.listdir(EXTENSIONS_PATH)
- if i.endswith('.t')]
- file_list.sort()
-
- for filename in file_list:
- file_tests, file_passed = run_test_file(filename, args.netns)
- if file_tests:
- tests += file_tests
- passed += file_passed
- test_files += 1
-
- print("%d test files, %d unit tests, %d passed" % (test_files, tests, passed))
- return passed - tests
-
+ total_test_files = 0
+ total_passed = 0
+ total_tests = 0
+ for variant in variants:
+ global EXECUTABLE
+ EXECUTABLE = "xtables-" + variant + "-multi"
+
+ test_files = 0
+ tests = 0
+ passed = 0
+
+ # setup global var log file
+ global log_file
+ try:
+ log_file = open(LOGFILE, 'w')
+ except IOError:
+ print("Couldn't open log file %s" % LOGFILE, file=sys.stderr)
+ return
+
+ if args.filename:
+ file_list = args.filename
+ else:
+ file_list = [os.path.join(EXTENSIONS_PATH, i)
+ for i in os.listdir(EXTENSIONS_PATH)
+ if i.endswith('.t')]
+ file_list.sort()
+
+ for filename in file_list:
+ file_tests, file_passed = run_test_file(filename, args.netns)
+ if file_tests:
+ tests += file_tests
+ passed += file_passed
+ test_files += 1
+
+ print("%s: %d test files, %d unit tests, %d passed"
+ % (variant, test_files, tests, passed))
+
+ total_passed += passed
+ total_tests += tests
+ total_test_files = max(total_test_files, test_files)
+
+ if len(variants) > 1:
+ print("total: %d test files, %d unit tests, %d passed"
+ % (total_test_files, total_tests, total_passed))
+ return total_passed - total_tests
if __name__ == '__main__':
sys.exit(main())