summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorPablo Neira Ayuso <pablo@netfilter.org>2022-01-02 21:46:21 +0100
committerPablo Neira Ayuso <pablo@netfilter.org>2022-01-15 18:11:22 +0100
commitfb298877ece2739ffb08b1967c10829969859e2c (patch)
tree2a403f39cda489c44dcf9c1ef5a114f78299a621 /src
parent19a36424bc2949f2ae96731ada9714f8bce950d8 (diff)
src: add ruleset optimization infrastructure
This patch adds a new -o/--optimize option to enable ruleset optimization. You can combine this option with the dry run mode (--check) to review the proposed ruleset updates without actually loading the ruleset, e.g. # nft -c -o -f ruleset.test Merging: ruleset.nft:16:3-37: ip daddr 192.168.0.1 counter accept ruleset.nft:17:3-37: ip daddr 192.168.0.2 counter accept ruleset.nft:18:3-37: ip daddr 192.168.0.3 counter accept into: ip daddr { 192.168.0.1, 192.168.0.2, 192.168.0.3 } counter packets 0 bytes 0 accept This infrastructure collects the common statements that are used in rules, then it builds a matrix of rules vs. statements. Then, it looks for common statements in consecutive rules which allows to merge rules. This ruleset optimization always performs an implicit dry run to validate that the original ruleset is correct. Then, on a second pass, it performs the ruleset optimization and add the rules into the kernel (unless --check has been specified by the user). From libnftables perspective, there is a new API to enable this feature: uint32_t nft_ctx_get_optimize(struct nft_ctx *ctx); void nft_ctx_set_optimize(struct nft_ctx *ctx, uint32_t flags); This patch adds support for the first optimization: Collapse a linear list of rules matching on a single selector into a set as exposed in the example above. Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
Diffstat (limited to 'src')
-rw-r--r--src/Makefile.am1
-rw-r--r--src/libnftables.c71
-rw-r--r--src/libnftables.map5
-rw-r--r--src/main.c9
-rw-r--r--src/optimize.c478
5 files changed, 553 insertions, 11 deletions
diff --git a/src/Makefile.am b/src/Makefile.am
index 6ab07523..4cfba0af 100644
--- a/src/Makefile.am
+++ b/src/Makefile.am
@@ -68,6 +68,7 @@ libnftables_la_SOURCES = \
mnl.c \
iface.c \
mergesort.c \
+ optimize.c \
osf.c \
nfnl_osf.c \
tcpopt.c \
diff --git a/src/libnftables.c b/src/libnftables.c
index e76f32ef..bd71ae9e 100644
--- a/src/libnftables.c
+++ b/src/libnftables.c
@@ -395,6 +395,18 @@ void nft_ctx_set_dry_run(struct nft_ctx *ctx, bool dry)
ctx->check = dry;
}
+EXPORT_SYMBOL(nft_ctx_get_optimize);
+uint32_t nft_ctx_get_optimize(struct nft_ctx *ctx)
+{
+ return ctx->optimize_flags;
+}
+
+EXPORT_SYMBOL(nft_ctx_set_optimize);
+void nft_ctx_set_optimize(struct nft_ctx *ctx, uint32_t flags)
+{
+ ctx->optimize_flags = flags;
+}
+
EXPORT_SYMBOL(nft_ctx_output_get_flags);
unsigned int nft_ctx_output_get_flags(struct nft_ctx *ctx)
{
@@ -626,8 +638,7 @@ retry:
return rc;
}
-EXPORT_SYMBOL(nft_run_cmd_from_filename);
-int nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename)
+static int __nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename)
{
struct cmd *cmd, *next;
int rc, parser_rc;
@@ -638,13 +649,6 @@ int nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename)
if (rc < 0)
goto err;
- if (!strcmp(filename, "-"))
- filename = "/dev/stdin";
-
- if (!strcmp(filename, "/dev/stdin") &&
- !nft_output_json(&nft->output))
- nft->stdin_buf = stdin_to_buffer();
-
rc = -EINVAL;
if (nft_output_json(&nft->output))
rc = nft_parse_json_filename(nft, filename, &msgs, &cmds);
@@ -653,6 +657,9 @@ int nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename)
parser_rc = rc;
+ if (nft->optimize_flags)
+ nft_optimize(nft, &cmds);
+
rc = nft_evaluate(nft, &msgs, &cmds);
if (rc < 0)
goto err;
@@ -694,7 +701,51 @@ err:
if (rc)
nft_cache_release(&nft->cache);
+ return rc;
+}
+
+static int nft_run_optimized_file(struct nft_ctx *nft, const char *filename)
+{
+ uint32_t optimize_flags;
+ bool check;
+ int ret;
+
+ check = nft->check;
+ nft->check = true;
+ optimize_flags = nft->optimize_flags;
+ nft->optimize_flags = 0;
+
+ /* First check the original ruleset loads fine as is. */
+ ret = __nft_run_cmd_from_filename(nft, filename);
+ if (ret < 0)
+ return ret;
+
+ nft->check = check;
+ nft->optimize_flags = optimize_flags;
+
+ return __nft_run_cmd_from_filename(nft, filename);
+}
+
+EXPORT_SYMBOL(nft_run_cmd_from_filename);
+int nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename)
+{
+ int ret;
+
+ if (!strcmp(filename, "-"))
+ filename = "/dev/stdin";
+
+ if (!strcmp(filename, "/dev/stdin") &&
+ !nft_output_json(&nft->output))
+ nft->stdin_buf = stdin_to_buffer();
+
+ if (nft->optimize_flags) {
+ ret = nft_run_optimized_file(nft, filename);
+ xfree(nft->stdin_buf);
+ return ret;
+ }
+
+ ret = __nft_run_cmd_from_filename(nft, filename);
xfree(nft->stdin_buf);
- return rc;
+ return ret;
}
diff --git a/src/libnftables.map b/src/libnftables.map
index d3a795ce..a511dd78 100644
--- a/src/libnftables.map
+++ b/src/libnftables.map
@@ -28,3 +28,8 @@ LIBNFTABLES_2 {
nft_ctx_add_var;
nft_ctx_clear_vars;
} LIBNFTABLES_1;
+
+LIBNFTABLES_3 {
+ nft_set_optimize;
+ nft_get_optimize;
+} LIBNFTABLES_2;
diff --git a/src/main.c b/src/main.c
index 5847fc4a..9bd25db8 100644
--- a/src/main.c
+++ b/src/main.c
@@ -36,7 +36,8 @@ enum opt_indices {
IDX_INTERACTIVE,
IDX_INCLUDEPATH,
IDX_CHECK,
-#define IDX_RULESET_INPUT_END IDX_CHECK
+ IDX_OPTIMIZE,
+#define IDX_RULESET_INPUT_END IDX_OPTIMIZE
/* Ruleset list formatting */
IDX_HANDLE,
#define IDX_RULESET_LIST_START IDX_HANDLE
@@ -80,6 +81,7 @@ enum opt_vals {
OPT_NUMERIC_PROTO = 'p',
OPT_NUMERIC_TIME = 'T',
OPT_TERSE = 't',
+ OPT_OPTIMIZE = 'o',
OPT_INVALID = '?',
};
@@ -136,6 +138,8 @@ static const struct nft_opt nft_options[] = {
"Format output in JSON"),
[IDX_DEBUG] = NFT_OPT("debug", OPT_DEBUG, "<level [,level...]>",
"Specify debugging level (scanner, parser, eval, netlink, mnl, proto-ctx, segtree, all)"),
+ [IDX_OPTIMIZE] = NFT_OPT("optimize", OPT_OPTIMIZE, NULL,
+ "Optimize ruleset"),
};
#define NR_NFT_OPTIONS (sizeof(nft_options) / sizeof(nft_options[0]))
@@ -484,6 +488,9 @@ int main(int argc, char * const *argv)
case OPT_TERSE:
output_flags |= NFT_CTX_OUTPUT_TERSE;
break;
+ case OPT_OPTIMIZE:
+ nft_ctx_set_optimize(nft, 0x1);
+ break;
case OPT_INVALID:
exit(EXIT_FAILURE);
}
diff --git a/src/optimize.c b/src/optimize.c
new file mode 100644
index 00000000..bae36d76
--- /dev/null
+++ b/src/optimize.c
@@ -0,0 +1,478 @@
+/*
+ * Copyright (c) 2021 Pablo Neira Ayuso <pablo@netfilter.org>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 2 (or any
+ * later) as published by the Free Software Foundation.
+ */
+
+/* Funded through the NGI0 PET Fund established by NLnet (https://nlnet.nl)
+ * with support from the European Commission's Next Generation Internet
+ * programme.
+ */
+
+#define _GNU_SOURCE
+#include <string.h>
+#include <errno.h>
+#include <inttypes.h>
+#include <nftables.h>
+#include <parser.h>
+#include <expression.h>
+#include <statement.h>
+#include <utils.h>
+#include <erec.h>
+
+#define MAX_STMTS 32
+
+struct optimize_ctx {
+ struct stmt *stmt[MAX_STMTS];
+ uint32_t num_stmts;
+
+ struct stmt ***stmt_matrix;
+ struct rule **rule;
+ uint32_t num_rules;
+};
+
+static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b)
+{
+ struct expr *expr_a, *expr_b;
+
+ if (stmt_a->ops->type != stmt_b->ops->type)
+ return false;
+
+ switch (stmt_a->ops->type) {
+ case STMT_EXPRESSION:
+ expr_a = stmt_a->expr;
+ expr_b = stmt_b->expr;
+
+ if (expr_a->left->etype != expr_b->left->etype)
+ return false;
+
+ switch (expr_a->left->etype) {
+ case EXPR_PAYLOAD:
+ if (expr_a->left->payload.desc != expr_b->left->payload.desc)
+ return false;
+ if (expr_a->left->payload.tmpl != expr_b->left->payload.tmpl)
+ return false;
+ break;
+ case EXPR_EXTHDR:
+ if (expr_a->left->exthdr.desc != expr_b->left->exthdr.desc)
+ return false;
+ if (expr_a->left->exthdr.tmpl != expr_b->left->exthdr.tmpl)
+ return false;
+ break;
+ case EXPR_META:
+ if (expr_a->left->meta.key != expr_b->left->meta.key)
+ return false;
+ if (expr_a->left->meta.base != expr_b->left->meta.base)
+ return false;
+ break;
+ case EXPR_CT:
+ if (expr_a->left->ct.key != expr_b->left->ct.key)
+ return false;
+ if (expr_a->left->ct.base != expr_b->left->ct.base)
+ return false;
+ if (expr_a->left->ct.direction != expr_b->left->ct.direction)
+ return false;
+ if (expr_a->left->ct.nfproto != expr_b->left->ct.nfproto)
+ return false;
+ break;
+ case EXPR_RT:
+ if (expr_a->left->rt.key != expr_b->left->rt.key)
+ return false;
+ break;
+ case EXPR_SOCKET:
+ if (expr_a->left->socket.key != expr_b->left->socket.key)
+ return false;
+ if (expr_a->left->socket.level != expr_b->left->socket.level)
+ return false;
+ break;
+ default:
+ return false;
+ }
+ break;
+ case STMT_COUNTER:
+ case STMT_NOTRACK:
+ break;
+ case STMT_VERDICT:
+ expr_a = stmt_a->expr;
+ expr_b = stmt_b->expr;
+ if (expr_a->verdict != expr_b->verdict)
+ return false;
+ if (expr_a->chain && expr_b->chain) {
+ if (expr_a->chain->etype != expr_b->chain->etype)
+ return false;
+ if (expr_a->chain->etype == EXPR_VALUE &&
+ strcmp(expr_a->chain->identifier, expr_b->chain->identifier))
+ return false;
+ } else if (expr_a->chain || expr_b->chain) {
+ return false;
+ }
+ break;
+ case STMT_LIMIT:
+ if (stmt_a->limit.rate != stmt_b->limit.rate ||
+ stmt_a->limit.unit != stmt_b->limit.unit ||
+ stmt_a->limit.burst != stmt_b->limit.burst ||
+ stmt_a->limit.type != stmt_b->limit.type ||
+ stmt_a->limit.flags != stmt_b->limit.flags)
+ return false;
+ break;
+ case STMT_LOG:
+ if (stmt_a->log.snaplen != stmt_b->log.snaplen ||
+ stmt_a->log.group != stmt_b->log.group ||
+ stmt_a->log.qthreshold != stmt_b->log.qthreshold ||
+ stmt_a->log.level != stmt_b->log.level ||
+ stmt_a->log.logflags != stmt_b->log.logflags ||
+ stmt_a->log.flags != stmt_b->log.flags ||
+ stmt_a->log.prefix->etype != EXPR_VALUE ||
+ stmt_b->log.prefix->etype != EXPR_VALUE ||
+ mpz_cmp(stmt_a->log.prefix->value, stmt_b->log.prefix->value))
+ return false;
+ break;
+ case STMT_REJECT:
+ if (stmt_a->reject.expr || stmt_b->reject.expr)
+ return false;
+
+ if (stmt_a->reject.family != stmt_b->reject.family ||
+ stmt_a->reject.type != stmt_b->reject.type ||
+ stmt_a->reject.icmp_code != stmt_b->reject.icmp_code)
+ return false;
+ break;
+ default:
+ /* ... Merging anything else is yet unsupported. */
+ return false;
+ }
+
+ return true;
+}
+
+static bool stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b)
+{
+ if (!stmt_a && !stmt_b)
+ return true;
+ else if (!stmt_a)
+ return false;
+ else if (!stmt_b)
+ return false;
+
+ return __stmt_type_eq(stmt_a, stmt_b);
+}
+
+static bool stmt_type_find(struct optimize_ctx *ctx, const struct stmt *stmt)
+{
+ uint32_t i;
+
+ for (i = 0; i < ctx->num_stmts; i++) {
+ if (__stmt_type_eq(stmt, ctx->stmt[i]))
+ return true;
+ }
+
+ return false;
+}
+
+static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule)
+{
+ struct stmt *stmt, *clone;
+
+ list_for_each_entry(stmt, &rule->stmts, list) {
+ if (stmt_type_find(ctx, stmt))
+ continue;
+
+ /* No refcounter available in statement objects, clone it to
+ * to store in the array of selectors.
+ */
+ clone = stmt_alloc(&internal_location, stmt->ops);
+ switch (stmt->ops->type) {
+ case STMT_EXPRESSION:
+ case STMT_VERDICT:
+ clone->expr = expr_get(stmt->expr);
+ break;
+ case STMT_COUNTER:
+ case STMT_NOTRACK:
+ break;
+ case STMT_LIMIT:
+ memcpy(&clone->limit, &stmt->limit, sizeof(clone->limit));
+ break;
+ case STMT_LOG:
+ memcpy(&clone->log, &stmt->log, sizeof(clone->log));
+ clone->log.prefix = expr_get(stmt->log.prefix);
+ break;
+ default:
+ break;
+ }
+
+ ctx->stmt[ctx->num_stmts++] = clone;
+ if (ctx->num_stmts >= MAX_STMTS)
+ return -1;
+ }
+
+ return 0;
+}
+
+static int cmd_stmt_find_in_stmt_matrix(struct optimize_ctx *ctx, struct stmt *stmt)
+{
+ uint32_t i;
+
+ for (i = 0; i < ctx->num_stmts; i++) {
+ if (__stmt_type_eq(stmt, ctx->stmt[i]))
+ return i;
+ }
+ /* should not ever happen. */
+ return 0;
+}
+
+static void rule_build_stmt_matrix_stmts(struct optimize_ctx *ctx,
+ struct rule *rule, uint32_t *i)
+{
+ struct stmt *stmt;
+ int k;
+
+ list_for_each_entry(stmt, &rule->stmts, list) {
+ k = cmd_stmt_find_in_stmt_matrix(ctx, stmt);
+ ctx->stmt_matrix[*i][k] = stmt;
+ }
+ ctx->rule[(*i)++] = rule;
+}
+
+struct merge {
+ /* interval of rules to be merged */
+ uint32_t rule_from;
+ uint32_t num_rules;
+ /* statements to be merged (index relative to statement matrix) */
+ uint32_t stmt[MAX_STMTS];
+ uint32_t num_stmts;
+};
+
+static void merge_stmts(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to, const struct merge *merge)
+{
+ struct stmt *stmt_a = ctx->stmt_matrix[from][merge->stmt[0]];
+ struct expr *expr_a, *expr_b, *set, *elem;
+ struct stmt *stmt_b;
+ uint32_t i;
+
+ assert (stmt_a->ops->type == STMT_EXPRESSION);
+
+ set = set_expr_alloc(&internal_location, NULL);
+ set->set_flags |= NFT_SET_ANONYMOUS;
+
+ expr_a = stmt_a->expr->right;
+ elem = set_elem_expr_alloc(&internal_location, expr_get(expr_a));
+ compound_expr_add(set, elem);
+
+ for (i = from + 1; i <= to; i++) {
+ stmt_b = ctx->stmt_matrix[i][merge->stmt[0]];
+ expr_b = stmt_b->expr->right;
+ elem = set_elem_expr_alloc(&internal_location, expr_get(expr_b));
+ compound_expr_add(set, elem);
+ }
+
+ expr_free(stmt_a->expr->right);
+ stmt_a->expr->right = set;
+}
+
+static void rule_optimize_print(struct output_ctx *octx,
+ const struct rule *rule)
+{
+ const struct location *loc = &rule->location;
+ const struct input_descriptor *indesc = loc->indesc;
+ const char *line;
+ char buf[1024];
+
+ switch (indesc->type) {
+ case INDESC_BUFFER:
+ case INDESC_CLI:
+ line = indesc->data;
+ *strchrnul(line, '\n') = '\0';
+ break;
+ case INDESC_STDIN:
+ line = indesc->data;
+ line += loc->line_offset;
+ *strchrnul(line, '\n') = '\0';
+ break;
+ case INDESC_FILE:
+ line = line_location(indesc, loc, buf, sizeof(buf));
+ break;
+ case INDESC_INTERNAL:
+ case INDESC_NETLINK:
+ break;
+ default:
+ BUG("invalid input descriptor type %u\n", indesc->type);
+ }
+
+ print_location(octx->error_fp, indesc, loc);
+ fprintf(octx->error_fp, "%s\n", line);
+}
+
+static void merge_rules(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to,
+ const struct merge *merge,
+ struct output_ctx *octx)
+{
+ uint32_t i;
+
+ if (merge->num_stmts > 1) {
+ return;
+ } else {
+ merge_stmts(ctx, from, to, merge);
+ }
+
+ fprintf(octx->error_fp, "Merging:\n");
+ rule_optimize_print(octx, ctx->rule[from]);
+
+ for (i = from + 1; i <= to; i++) {
+ rule_optimize_print(octx, ctx->rule[i]);
+ list_del(&ctx->rule[i]->list);
+ rule_free(ctx->rule[i]);
+ }
+
+ fprintf(octx->error_fp, "into:\n\t");
+ rule_print(ctx->rule[from], octx);
+ fprintf(octx->error_fp, "\n");
+}
+
+static bool rules_eq(const struct optimize_ctx *ctx, int i, int j)
+{
+ uint32_t k;
+
+ for (k = 0; k < ctx->num_stmts; k++) {
+ if (!stmt_type_eq(ctx->stmt_matrix[i][k], ctx->stmt_matrix[j][k]))
+ return false;
+ }
+
+ return true;
+}
+
+static int chain_optimize(struct nft_ctx *nft, struct list_head *rules)
+{
+ struct optimize_ctx *ctx;
+ uint32_t num_merges = 0;
+ struct merge *merge;
+ uint32_t i, j, m, k;
+ struct rule *rule;
+ int ret;
+
+ ctx = xzalloc(sizeof(*ctx));
+
+ /* Step 1: collect statements in rules */
+ list_for_each_entry(rule, rules, list) {
+ ret = rule_collect_stmts(ctx, rule);
+ if (ret < 0)
+ goto err;
+
+ ctx->num_rules++;
+ }
+
+ ctx->rule = xzalloc(sizeof(ctx->rule) * ctx->num_rules);
+ ctx->stmt_matrix = xzalloc(sizeof(struct stmt *) * ctx->num_rules);
+ for (i = 0; i < ctx->num_rules; i++)
+ ctx->stmt_matrix[i] = xzalloc(sizeof(struct stmt *) * MAX_STMTS);
+
+ merge = xzalloc(sizeof(*merge) * ctx->num_rules);
+
+ /* Step 2: Build matrix of statements */
+ i = 0;
+ list_for_each_entry(rule, rules, list)
+ rule_build_stmt_matrix_stmts(ctx, rule, &i);
+
+ /* Step 3: Look for common selectors for possible rule mergers */
+ for (i = 0; i < ctx->num_rules; i++) {
+ for (j = i + 1; j < ctx->num_rules; j++) {
+ if (!rules_eq(ctx, i, j)) {
+ if (merge[num_merges].num_rules > 0)
+ num_merges++;
+
+ i = j - 1;
+ break;
+ }
+ if (merge[num_merges].num_rules > 0) {
+ merge[num_merges].num_rules++;
+ } else {
+ merge[num_merges].rule_from = i;
+ merge[num_merges].num_rules = 2;
+ }
+ }
+ if (j == ctx->num_rules && merge[num_merges].num_rules > 0) {
+ num_merges++;
+ break;
+ }
+ }
+
+ /* Step 4: Infer how to merge the candidate rules */
+ for (k = 0; k < num_merges; k++) {
+ i = merge[k].rule_from;
+
+ for (m = 0; m < ctx->num_stmts; m++) {
+ if (!ctx->stmt_matrix[i][m])
+ continue;
+ switch (ctx->stmt_matrix[i][m]->ops->type) {
+ case STMT_EXPRESSION:
+ merge[k].stmt[merge[k].num_stmts++] = m;
+ break;
+ default:
+ break;
+ }
+ }
+
+ j = merge[k].num_rules - 1;
+ merge_rules(ctx, i, i + j, &merge[k], &nft->output);
+ }
+ ret = 0;
+ for (i = 0; i < ctx->num_rules; i++)
+ xfree(ctx->stmt_matrix[i]);
+
+ xfree(ctx->stmt_matrix);
+ xfree(merge);
+err:
+ for (i = 0; i < ctx->num_stmts; i++)
+ stmt_free(ctx->stmt[i]);
+
+ xfree(ctx->rule);
+ xfree(ctx);
+
+ return ret;
+}
+
+static int cmd_optimize(struct nft_ctx *nft, struct cmd *cmd)
+{
+ struct table *table;
+ struct chain *chain;
+ int ret = 0;
+
+ switch (cmd->obj) {
+ case CMD_OBJ_TABLE:
+ table = cmd->table;
+ if (!table)
+ break;
+
+ list_for_each_entry(chain, &table->chains, list) {
+ if (chain->flags & CHAIN_F_HW_OFFLOAD)
+ continue;
+
+ chain_optimize(nft, &chain->rules);
+ }
+ break;
+ default:
+ break;
+ }
+
+ return ret;
+}
+
+int nft_optimize(struct nft_ctx *nft, struct list_head *cmds)
+{
+ struct cmd *cmd;
+ int ret;
+
+ list_for_each_entry(cmd, cmds, list) {
+ switch (cmd->op) {
+ case CMD_ADD:
+ ret = cmd_optimize(nft, cmd);
+ break;
+ default:
+ break;
+ }
+ }
+
+ return ret;
+}