summaryrefslogtreecommitdiffstats
path: root/iptables-test.py
diff options
context:
space:
mode:
authorPablo Neira Ayuso <pablo@netfilter.org>2012-08-21 19:43:09 +0200
committerPablo Neira Ayuso <pablo@netfilter.org>2013-10-07 16:35:25 +0200
commitc8b7aaabbe1fc6da1c97d1e3de8cfae67ad483a1 (patch)
tree73e4f5768d674a464b9f5dacd55e3cae7d4c7c6d /iptables-test.py
parentb90e798596465f0ea2daff75d95f4f978f0a8377 (diff)
add iptables unit test infrastructure
This patch adds a python script to verify unit test cases. Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
Diffstat (limited to 'iptables-test.py')
-rwxr-xr-xiptables-test.py311
1 files changed, 311 insertions, 0 deletions
diff --git a/iptables-test.py b/iptables-test.py
new file mode 100755
index 00000000..9e137f8c
--- /dev/null
+++ b/iptables-test.py
@@ -0,0 +1,311 @@
+#!/usr/bin/python
+#
+# (C) 2012-2013 by Pablo Neira Ayuso <pablo@netfilter.org>
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This software has been sponsored by Sophos Astaro <http://www.sophos.com>
+#
+
+import sys
+import os
+import subprocess
+import argparse
+
+IPTABLES = "iptables"
+IP6TABLES = "ip6tables"
+#IPTABLES = "xtables -4"
+#IP6TABLES = "xtables -6"
+
+IPTABLES_SAVE = "iptables-save"
+IP6TABLES_SAVE = "ip6tables-save"
+#IPTABLES_SAVE = ['xtables-save','-4']
+#IP6TABLES_SAVE = ['xtables-save','-6']
+
+EXTENSIONS_PATH = "extensions"
+LOGFILE="/tmp/iptables-test.log"
+log_file = None
+
+
+class Colors:
+ HEADER = '\033[95m'
+ BLUE = '\033[94m'
+ GREEN = '\033[92m'
+ YELLOW = '\033[93m'
+ RED = '\033[91m'
+ ENDC = '\033[0m'
+
+
+def print_error(reason, filename=None, lineno=None):
+ '''
+ Prints an error with nice colors, indicating file and line number.
+ '''
+ print (filename + ": " + Colors.RED + "ERROR" +
+ Colors.ENDC + ": line %d (%s)" % (lineno, reason))
+
+
+def delete_rule(iptables, rule, filename, lineno):
+ '''
+ Removes an iptables rule
+ '''
+ cmd = iptables + " -D " + rule
+ ret = execute_cmd(cmd, filename, lineno)
+ if ret == 1:
+ reason = "cannot delete: " + iptables + " -I " + rule
+ print_error(reason, filename, lineno)
+ return -1
+
+ return 0
+
+
+def run_test(iptables, rule, rule_save, res, filename, lineno):
+ '''
+ Executes an unit test. Returns the output of delete_rule().
+
+ Parameters:
+ :param iptables: string with the iptables command to execute
+ :param rule: string with iptables arguments for the rule to test
+ :param rule_save: string to find the rule in the output of iptables -save
+ :param res: expected result of the rule. Valid values: "OK", "FAIL"
+ :param filename: name of the file tested (used for print_error purposes)
+ :param lineno: line number being tested (used for print_error purposes)
+ '''
+ ret = 0
+
+ cmd = iptables + " -A " + rule
+ ret = execute_cmd(cmd, filename, lineno)
+
+ #
+ # report failed test
+ #
+ if ret:
+ if res == "OK":
+ reason = "cannot load: " + cmd
+ print_error(reason, filename, lineno)
+ return -1
+ else:
+ # do not report this error
+ return 0
+ else:
+ if res == "FAIL":
+ reason = "should fail: " + cmd
+ print_error(reason, filename, lineno)
+ delete_rule(iptables, rule, filename, lineno)
+ return -1
+
+ matching = 0
+ splitted = iptables.split(" ")
+ if len(splitted) == 2:
+ if splitted[1] == '-4':
+ command = IPTABLES_SAVE
+ elif splitted[1] == '-6':
+ command = IP6TABLES_SAVE
+ elif len(splitted) == 1:
+ if splitted[0] == IPTABLES:
+ command = IPTABLES_SAVE
+ elif splitted[0] == IP6TABLES:
+ command = IP6TABLES_SAVE
+ args = splitted[1:]
+ proc = subprocess.Popen(command, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ out, err = proc.communicate()
+
+ #
+ # check for segfaults
+ #
+ if proc.returncode == -11:
+ reason = "iptables-save segfaults: " + cmd
+ print_error(reason, filename, lineno)
+ delete_rule(iptables, rule, filename, lineno)
+ return -1
+
+ # find the rule
+ matching = out.find(rule_save)
+ if matching < 0:
+ reason = "cannot find: " + iptables + " -I " + rule
+ print_error(reason, filename, lineno)
+ delete_rule(iptables, rule, filename, lineno)
+ return -1
+
+ return delete_rule(iptables, rule, filename, lineno)
+
+
+def execute_cmd(cmd, filename, lineno):
+ '''
+ Executes a command, checking for segfaults and returning the command exit
+ code.
+
+ :param cmd: string with the command to be executed
+ :param filename: name of the file tested (used for print_error purposes)
+ :param lineno: line number being tested (used for print_error purposes)
+ '''
+ global log_file
+ print >> log_file, "command: %s" % cmd
+ ret = subprocess.call(cmd, shell=True, universal_newlines=True,
+ stderr=subprocess.STDOUT, stdout=log_file)
+ log_file.flush()
+
+ # generic check for segfaults
+ if ret == -11:
+ reason = "command segfaults: " + cmd
+ print_error(reason, filename, lineno)
+ return ret
+
+
+def run_test_file(filename):
+ '''
+ Runs a test file
+
+ :param filename: name of the file with the test rules
+ '''
+ #
+ # if this is not a test file, skip.
+ #
+ if not filename.endswith(".t"):
+ return 0, 0
+
+ if "libipt_" in filename:
+ iptables = IPTABLES
+ elif "libip6t_" in filename:
+ iptables = IP6TABLES
+ elif "libxt_" in filename:
+ iptables = IPTABLES
+ else:
+ # default to iptables if not known prefix
+ iptables = IPTABLES
+
+ f = open(filename)
+
+ tests = 0
+ passed = 0
+ table = ""
+ total_test_passed = True
+
+ for lineno, line in enumerate(f):
+ if line[0] == "#":
+ continue
+
+ if line[0] == ":":
+ chain_array = line.rstrip()[1:].split(",")
+ continue
+
+ # external non-iptables invocation, executed as is.
+ if line[0] == "@":
+ external_cmd = line.rstrip()[1:]
+ execute_cmd(external_cmd, filename, lineno)
+ continue
+
+ if line[0] == "*":
+ table = line.rstrip()[1:]
+ continue
+
+ if len(chain_array) == 0:
+ print "broken test, missing chain, leaving"
+ sys.exit()
+
+ test_passed = True
+ tests += 1
+
+ for chain in chain_array:
+ item = line.split(";")
+ if table == "":
+ rule = chain + " " + item[0]
+ else:
+ rule = chain + " -t " + table + " " + item[0]
+
+ if item[1] == "=":
+ rule_save = chain + " " + item[0]
+ else:
+ rule_save = chain + " " + item[1]
+
+ res = item[2].rstrip()
+
+ ret = run_test(iptables, rule, rule_save,
+ res, filename, lineno + 1)
+ if ret < 0:
+ test_passed = False
+ total_test_passed = False
+ break
+
+ if test_passed:
+ passed += 1
+
+ if total_test_passed:
+ print filename + ": " + Colors.GREEN + "OK" + Colors.ENDC
+
+ f.close()
+ return tests, passed
+
+
+def show_missing():
+ '''
+ Show the list of missing test files
+ '''
+ file_list = os.listdir(EXTENSIONS_PATH)
+ testfiles = [i for i in file_list if i.endswith('.t')]
+ libfiles = [i for i in file_list
+ if i.startswith('lib') and i.endswith('.c')]
+
+ def test_name(x):
+ return x[0:-2] + '.t'
+ missing = [test_name(i) for i in libfiles
+ if not test_name(i) in testfiles]
+
+ print '\n'.join(missing)
+
+
+#
+# main
+#
+def main():
+ parser = argparse.ArgumentParser(description='Run iptables tests')
+ parser.add_argument('filename', nargs='?',
+ metavar='path/to/file.t',
+ help='Run only this test')
+ parser.add_argument('-m', '--missing', action='store_true',
+ help='Check for missing tests')
+ args = parser.parse_args()
+
+ #
+ # show list of missing test files
+ #
+ if args.missing:
+ show_missing()
+ return
+
+ if os.getuid() != 0:
+ print "You need to be root to run this, sorry"
+ return
+
+ test_files = 0
+ tests = 0
+ passed = 0
+
+ # setup global var log file
+ global log_file
+ try:
+ log_file = open(LOGFILE, 'w')
+ except IOError:
+ print "Couldn't open log file %s" % LOGFILE
+ return
+
+ file_list = [os.path.join(EXTENSIONS_PATH, i)
+ for i in os.listdir(EXTENSIONS_PATH)]
+ if args.filename:
+ file_list = [args.filename]
+ for filename in file_list:
+ file_tests, file_passed = run_test_file(filename)
+ if file_tests:
+ tests += file_tests
+ passed += file_passed
+ test_files += 1
+
+ print ("%d test files, %d unit tests, %d passed" %
+ (test_files, tests, passed))
+
+
+if __name__ == '__main__':
+ main()