diff options
-rw-r--r-- | doc/nft.txt | 5 | ||||
-rw-r--r-- | include/nftables.h | 3 | ||||
-rw-r--r-- | include/nftables/libnftables.h | 7 | ||||
-rw-r--r-- | src/Makefile.am | 1 | ||||
-rw-r--r-- | src/libnftables.c | 71 | ||||
-rw-r--r-- | src/libnftables.map | 5 | ||||
-rw-r--r-- | src/main.c | 9 | ||||
-rw-r--r-- | src/optimize.c | 478 | ||||
-rw-r--r-- | tests/shell/testcases/optimizations/dumps/merge_stmts.nft | 5 | ||||
-rwxr-xr-x | tests/shell/testcases/optimizations/merge_stmts | 13 |
10 files changed, 586 insertions, 11 deletions
diff --git a/doc/nft.txt b/doc/nft.txt index e4ed9824..7240deaa 100644 --- a/doc/nft.txt +++ b/doc/nft.txt @@ -62,6 +62,11 @@ understanding of their meaning. You can get information about options by running *--check*:: Check commands validity without actually applying the changes. +*-o*:: +*--optimize*:: + Optimize your ruleset. You can combine this option with '-c' to inspect + the proposed optimizations. + .Ruleset list output formatting that modify the output of the list ruleset command: *-a*:: diff --git a/include/nftables.h b/include/nftables.h index d6d9b9cc..d49eb579 100644 --- a/include/nftables.h +++ b/include/nftables.h @@ -123,6 +123,7 @@ struct nft_ctx { bool check; struct nft_cache cache; uint32_t flags; + uint32_t optimize_flags; struct parser_state *state; void *scanner; struct scope *top_scope; @@ -224,6 +225,8 @@ int nft_print(struct output_ctx *octx, const char *fmt, ...) int nft_gmp_print(struct output_ctx *octx, const char *fmt, ...) __attribute__((format(printf, 2, 0))); +int nft_optimize(struct nft_ctx *nft, struct list_head *cmds); + #define __NFT_OUTPUT_NOTSUPP UINT_MAX #endif /* NFTABLES_NFTABLES_H */ diff --git a/include/nftables/libnftables.h b/include/nftables/libnftables.h index 957b5fbe..85e08c9b 100644 --- a/include/nftables/libnftables.h +++ b/include/nftables/libnftables.h @@ -41,6 +41,13 @@ void nft_ctx_free(struct nft_ctx *ctx); bool nft_ctx_get_dry_run(struct nft_ctx *ctx); void nft_ctx_set_dry_run(struct nft_ctx *ctx, bool dry); +enum nft_optimize_flags { + NFT_OPTIMIZE_ENABLED = 0x1, +}; + +uint32_t nft_ctx_get_optimize(struct nft_ctx *ctx); +void nft_ctx_set_optimize(struct nft_ctx *ctx, uint32_t flags); + enum { NFT_CTX_OUTPUT_REVERSEDNS = (1 << 0), NFT_CTX_OUTPUT_SERVICE = (1 << 1), diff --git a/src/Makefile.am b/src/Makefile.am index 6ab07523..4cfba0af 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -68,6 +68,7 @@ libnftables_la_SOURCES = \ mnl.c \ iface.c \ mergesort.c \ + optimize.c \ osf.c \ nfnl_osf.c \ tcpopt.c \ 
diff --git a/src/libnftables.c b/src/libnftables.c index e76f32ef..bd71ae9e 100644 --- a/src/libnftables.c +++ b/src/libnftables.c @@ -395,6 +395,18 @@ void nft_ctx_set_dry_run(struct nft_ctx *ctx, bool dry) ctx->check = dry; } +EXPORT_SYMBOL(nft_ctx_get_optimize); +uint32_t nft_ctx_get_optimize(struct nft_ctx *ctx) +{ + return ctx->optimize_flags; +} + +EXPORT_SYMBOL(nft_ctx_set_optimize); +void nft_ctx_set_optimize(struct nft_ctx *ctx, uint32_t flags) +{ + ctx->optimize_flags = flags; +} + EXPORT_SYMBOL(nft_ctx_output_get_flags); unsigned int nft_ctx_output_get_flags(struct nft_ctx *ctx) { @@ -626,8 +638,7 @@ retry: return rc; } -EXPORT_SYMBOL(nft_run_cmd_from_filename); -int nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename) +static int __nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename) { struct cmd *cmd, *next; int rc, parser_rc; @@ -638,13 +649,6 @@ int nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename) if (rc < 0) goto err; - if (!strcmp(filename, "-")) - filename = "/dev/stdin"; - - if (!strcmp(filename, "/dev/stdin") && - !nft_output_json(&nft->output)) - nft->stdin_buf = stdin_to_buffer(); - rc = -EINVAL; if (nft_output_json(&nft->output)) rc = nft_parse_json_filename(nft, filename, &msgs, &cmds); @@ -653,6 +657,9 @@ int nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename) parser_rc = rc; + if (nft->optimize_flags) + nft_optimize(nft, &cmds); + rc = nft_evaluate(nft, &msgs, &cmds); if (rc < 0) goto err; @@ -694,7 +701,51 @@ err: if (rc) nft_cache_release(&nft->cache); + return rc; +} + +static int nft_run_optimized_file(struct nft_ctx *nft, const char *filename) +{ + uint32_t optimize_flags; + bool check; + int ret; + + check = nft->check; + nft->check = true; + optimize_flags = nft->optimize_flags; + nft->optimize_flags = 0; + + /* First check the original ruleset loads fine as is. 
*/ + ret = __nft_run_cmd_from_filename(nft, filename); + if (ret < 0) + return ret; + + nft->check = check; + nft->optimize_flags = optimize_flags; + + return __nft_run_cmd_from_filename(nft, filename); +} + +EXPORT_SYMBOL(nft_run_cmd_from_filename); +int nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename) +{ + int ret; + + if (!strcmp(filename, "-")) + filename = "/dev/stdin"; + + if (!strcmp(filename, "/dev/stdin") && + !nft_output_json(&nft->output)) + nft->stdin_buf = stdin_to_buffer(); + + if (nft->optimize_flags) { + ret = nft_run_optimized_file(nft, filename); + xfree(nft->stdin_buf); + return ret; + } + + ret = __nft_run_cmd_from_filename(nft, filename); xfree(nft->stdin_buf); - return rc; + return ret; } diff --git a/src/libnftables.map b/src/libnftables.map index d3a795ce..a511dd78 100644 --- a/src/libnftables.map +++ b/src/libnftables.map @@ -28,3 +28,8 @@ LIBNFTABLES_2 { nft_ctx_add_var; nft_ctx_clear_vars; } LIBNFTABLES_1; + +LIBNFTABLES_3 { + nft_ctx_set_optimize; + nft_ctx_get_optimize; +} LIBNFTABLES_2; @@ -36,7 +36,8 @@ enum opt_indices { IDX_INTERACTIVE, IDX_INCLUDEPATH, IDX_CHECK, -#define IDX_RULESET_INPUT_END IDX_CHECK + IDX_OPTIMIZE, +#define IDX_RULESET_INPUT_END IDX_OPTIMIZE /* Ruleset list formatting */ IDX_HANDLE, #define IDX_RULESET_LIST_START IDX_HANDLE @@ -80,6 +81,7 @@ enum opt_vals { OPT_NUMERIC_PROTO = 'p', OPT_NUMERIC_TIME = 'T', OPT_TERSE = 't', + OPT_OPTIMIZE = 'o', OPT_INVALID = '?', }; @@ -136,6 +138,8 @@ static const struct nft_opt nft_options[] = { "Format output in JSON"), [IDX_DEBUG] = NFT_OPT("debug", OPT_DEBUG, "<level [,level...]>", "Specify debugging level (scanner, parser, eval, netlink, mnl, proto-ctx, segtree, all)"), + [IDX_OPTIMIZE] = NFT_OPT("optimize", OPT_OPTIMIZE, NULL, + "Optimize ruleset"), }; #define NR_NFT_OPTIONS (sizeof(nft_options) / sizeof(nft_options[0])) @@ -484,6 +488,9 @@ int main(int argc, char * const *argv) case OPT_TERSE: output_flags |= NFT_CTX_OUTPUT_TERSE; break; + case OPT_OPTIMIZE: 
+ nft_ctx_set_optimize(nft, 0x1); + break; case OPT_INVALID: exit(EXIT_FAILURE); } diff --git a/src/optimize.c b/src/optimize.c new file mode 100644 index 00000000..bae36d76 --- /dev/null +++ b/src/optimize.c @@ -0,0 +1,478 @@ +/* + * Copyright (c) 2021 Pablo Neira Ayuso <pablo@netfilter.org> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 (or any + * later) as published by the Free Software Foundation. + */ + +/* Funded through the NGI0 PET Fund established by NLnet (https://nlnet.nl) + * with support from the European Commission's Next Generation Internet + * programme. + */ + +#define _GNU_SOURCE +#include <string.h> +#include <errno.h> +#include <inttypes.h> +#include <nftables.h> +#include <parser.h> +#include <expression.h> +#include <statement.h> +#include <utils.h> +#include <erec.h> + +#define MAX_STMTS 32 + +struct optimize_ctx { + struct stmt *stmt[MAX_STMTS]; + uint32_t num_stmts; + + struct stmt ***stmt_matrix; + struct rule **rule; + uint32_t num_rules; +}; + +static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b) +{ + struct expr *expr_a, *expr_b; + + if (stmt_a->ops->type != stmt_b->ops->type) + return false; + + switch (stmt_a->ops->type) { + case STMT_EXPRESSION: + expr_a = stmt_a->expr; + expr_b = stmt_b->expr; + + if (expr_a->left->etype != expr_b->left->etype) + return false; + + switch (expr_a->left->etype) { + case EXPR_PAYLOAD: + if (expr_a->left->payload.desc != expr_b->left->payload.desc) + return false; + if (expr_a->left->payload.tmpl != expr_b->left->payload.tmpl) + return false; + break; + case EXPR_EXTHDR: + if (expr_a->left->exthdr.desc != expr_b->left->exthdr.desc) + return false; + if (expr_a->left->exthdr.tmpl != expr_b->left->exthdr.tmpl) + return false; + break; + case EXPR_META: + if (expr_a->left->meta.key != expr_b->left->meta.key) + return false; + if (expr_a->left->meta.base != expr_b->left->meta.base) + 
return false; + break; + case EXPR_CT: + if (expr_a->left->ct.key != expr_b->left->ct.key) + return false; + if (expr_a->left->ct.base != expr_b->left->ct.base) + return false; + if (expr_a->left->ct.direction != expr_b->left->ct.direction) + return false; + if (expr_a->left->ct.nfproto != expr_b->left->ct.nfproto) + return false; + break; + case EXPR_RT: + if (expr_a->left->rt.key != expr_b->left->rt.key) + return false; + break; + case EXPR_SOCKET: + if (expr_a->left->socket.key != expr_b->left->socket.key) + return false; + if (expr_a->left->socket.level != expr_b->left->socket.level) + return false; + break; + default: + return false; + } + break; + case STMT_COUNTER: + case STMT_NOTRACK: + break; + case STMT_VERDICT: + expr_a = stmt_a->expr; + expr_b = stmt_b->expr; + if (expr_a->verdict != expr_b->verdict) + return false; + if (expr_a->chain && expr_b->chain) { + if (expr_a->chain->etype != expr_b->chain->etype) + return false; + if (expr_a->chain->etype == EXPR_VALUE && + strcmp(expr_a->chain->identifier, expr_b->chain->identifier)) + return false; + } else if (expr_a->chain || expr_b->chain) { + return false; + } + break; + case STMT_LIMIT: + if (stmt_a->limit.rate != stmt_b->limit.rate || + stmt_a->limit.unit != stmt_b->limit.unit || + stmt_a->limit.burst != stmt_b->limit.burst || + stmt_a->limit.type != stmt_b->limit.type || + stmt_a->limit.flags != stmt_b->limit.flags) + return false; + break; + case STMT_LOG: + if (stmt_a->log.snaplen != stmt_b->log.snaplen || + stmt_a->log.group != stmt_b->log.group || + stmt_a->log.qthreshold != stmt_b->log.qthreshold || + stmt_a->log.level != stmt_b->log.level || + stmt_a->log.logflags != stmt_b->log.logflags || + stmt_a->log.flags != stmt_b->log.flags || + stmt_a->log.prefix->etype != EXPR_VALUE || + stmt_b->log.prefix->etype != EXPR_VALUE || + mpz_cmp(stmt_a->log.prefix->value, stmt_b->log.prefix->value)) + return false; + break; + case STMT_REJECT: + if (stmt_a->reject.expr || stmt_b->reject.expr) + return false; 
+ + if (stmt_a->reject.family != stmt_b->reject.family || + stmt_a->reject.type != stmt_b->reject.type || + stmt_a->reject.icmp_code != stmt_b->reject.icmp_code) + return false; + break; + default: + /* ... Merging anything else is yet unsupported. */ + return false; + } + + return true; +} + +static bool stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b) +{ + if (!stmt_a && !stmt_b) + return true; + else if (!stmt_a) + return false; + else if (!stmt_b) + return false; + + return __stmt_type_eq(stmt_a, stmt_b); +} + +static bool stmt_type_find(struct optimize_ctx *ctx, const struct stmt *stmt) +{ + uint32_t i; + + for (i = 0; i < ctx->num_stmts; i++) { + if (__stmt_type_eq(stmt, ctx->stmt[i])) + return true; + } + + return false; +} + +static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule) +{ + struct stmt *stmt, *clone; + + list_for_each_entry(stmt, &rule->stmts, list) { + if (stmt_type_find(ctx, stmt)) + continue; + + /* No refcounter available in statement objects, clone it to + * to store in the array of selectors. + */ + clone = stmt_alloc(&internal_location, stmt->ops); + switch (stmt->ops->type) { + case STMT_EXPRESSION: + case STMT_VERDICT: + clone->expr = expr_get(stmt->expr); + break; + case STMT_COUNTER: + case STMT_NOTRACK: + break; + case STMT_LIMIT: + memcpy(&clone->limit, &stmt->limit, sizeof(clone->limit)); + break; + case STMT_LOG: + memcpy(&clone->log, &stmt->log, sizeof(clone->log)); + clone->log.prefix = expr_get(stmt->log.prefix); + break; + default: + break; + } + + ctx->stmt[ctx->num_stmts++] = clone; + if (ctx->num_stmts >= MAX_STMTS) + return -1; + } + + return 0; +} + +static int cmd_stmt_find_in_stmt_matrix(struct optimize_ctx *ctx, struct stmt *stmt) +{ + uint32_t i; + + for (i = 0; i < ctx->num_stmts; i++) { + if (__stmt_type_eq(stmt, ctx->stmt[i])) + return i; + } + /* should not ever happen. 
*/ + return 0; +} + +static void rule_build_stmt_matrix_stmts(struct optimize_ctx *ctx, + struct rule *rule, uint32_t *i) +{ + struct stmt *stmt; + int k; + + list_for_each_entry(stmt, &rule->stmts, list) { + k = cmd_stmt_find_in_stmt_matrix(ctx, stmt); + ctx->stmt_matrix[*i][k] = stmt; + } + ctx->rule[(*i)++] = rule; +} + +struct merge { + /* interval of rules to be merged */ + uint32_t rule_from; + uint32_t num_rules; + /* statements to be merged (index relative to statement matrix) */ + uint32_t stmt[MAX_STMTS]; + uint32_t num_stmts; +}; + +static void merge_stmts(const struct optimize_ctx *ctx, + uint32_t from, uint32_t to, const struct merge *merge) +{ + struct stmt *stmt_a = ctx->stmt_matrix[from][merge->stmt[0]]; + struct expr *expr_a, *expr_b, *set, *elem; + struct stmt *stmt_b; + uint32_t i; + + assert (stmt_a->ops->type == STMT_EXPRESSION); + + set = set_expr_alloc(&internal_location, NULL); + set->set_flags |= NFT_SET_ANONYMOUS; + + expr_a = stmt_a->expr->right; + elem = set_elem_expr_alloc(&internal_location, expr_get(expr_a)); + compound_expr_add(set, elem); + + for (i = from + 1; i <= to; i++) { + stmt_b = ctx->stmt_matrix[i][merge->stmt[0]]; + expr_b = stmt_b->expr->right; + elem = set_elem_expr_alloc(&internal_location, expr_get(expr_b)); + compound_expr_add(set, elem); + } + + expr_free(stmt_a->expr->right); + stmt_a->expr->right = set; +} + +static void rule_optimize_print(struct output_ctx *octx, + const struct rule *rule) +{ + const struct location *loc = &rule->location; + const struct input_descriptor *indesc = loc->indesc; + const char *line; + char buf[1024]; + + switch (indesc->type) { + case INDESC_BUFFER: + case INDESC_CLI: + line = indesc->data; + *strchrnul(line, '\n') = '\0'; + break; + case INDESC_STDIN: + line = indesc->data; + line += loc->line_offset; + *strchrnul(line, '\n') = '\0'; + break; + case INDESC_FILE: + line = line_location(indesc, loc, buf, sizeof(buf)); + break; + case INDESC_INTERNAL: + case INDESC_NETLINK: + break; + 
default: + BUG("invalid input descriptor type %u\n", indesc->type); + } + + print_location(octx->error_fp, indesc, loc); + fprintf(octx->error_fp, "%s\n", line); +} + +static void merge_rules(const struct optimize_ctx *ctx, + uint32_t from, uint32_t to, + const struct merge *merge, + struct output_ctx *octx) +{ + uint32_t i; + + if (merge->num_stmts > 1) { + return; + } else { + merge_stmts(ctx, from, to, merge); + } + + fprintf(octx->error_fp, "Merging:\n"); + rule_optimize_print(octx, ctx->rule[from]); + + for (i = from + 1; i <= to; i++) { + rule_optimize_print(octx, ctx->rule[i]); + list_del(&ctx->rule[i]->list); + rule_free(ctx->rule[i]); + } + + fprintf(octx->error_fp, "into:\n\t"); + rule_print(ctx->rule[from], octx); + fprintf(octx->error_fp, "\n"); +} + +static bool rules_eq(const struct optimize_ctx *ctx, int i, int j) +{ + uint32_t k; + + for (k = 0; k < ctx->num_stmts; k++) { + if (!stmt_type_eq(ctx->stmt_matrix[i][k], ctx->stmt_matrix[j][k])) + return false; + } + + return true; +} + +static int chain_optimize(struct nft_ctx *nft, struct list_head *rules) +{ + struct optimize_ctx *ctx; + uint32_t num_merges = 0; + struct merge *merge; + uint32_t i, j, m, k; + struct rule *rule; + int ret; + + ctx = xzalloc(sizeof(*ctx)); + + /* Step 1: collect statements in rules */ + list_for_each_entry(rule, rules, list) { + ret = rule_collect_stmts(ctx, rule); + if (ret < 0) + goto err; + + ctx->num_rules++; + } + + ctx->rule = xzalloc(sizeof(ctx->rule) * ctx->num_rules); + ctx->stmt_matrix = xzalloc(sizeof(struct stmt *) * ctx->num_rules); + for (i = 0; i < ctx->num_rules; i++) + ctx->stmt_matrix[i] = xzalloc(sizeof(struct stmt *) * MAX_STMTS); + + merge = xzalloc(sizeof(*merge) * ctx->num_rules); + + /* Step 2: Build matrix of statements */ + i = 0; + list_for_each_entry(rule, rules, list) + rule_build_stmt_matrix_stmts(ctx, rule, &i); + + /* Step 3: Look for common selectors for possible rule mergers */ + for (i = 0; i < ctx->num_rules; i++) { + for (j = i + 1; j 
< ctx->num_rules; j++) { + if (!rules_eq(ctx, i, j)) { + if (merge[num_merges].num_rules > 0) + num_merges++; + + i = j - 1; + break; + } + if (merge[num_merges].num_rules > 0) { + merge[num_merges].num_rules++; + } else { + merge[num_merges].rule_from = i; + merge[num_merges].num_rules = 2; + } + } + if (j == ctx->num_rules && merge[num_merges].num_rules > 0) { + num_merges++; + break; + } + } + + /* Step 4: Infer how to merge the candidate rules */ + for (k = 0; k < num_merges; k++) { + i = merge[k].rule_from; + + for (m = 0; m < ctx->num_stmts; m++) { + if (!ctx->stmt_matrix[i][m]) + continue; + switch (ctx->stmt_matrix[i][m]->ops->type) { + case STMT_EXPRESSION: + merge[k].stmt[merge[k].num_stmts++] = m; + break; + default: + break; + } + } + + j = merge[k].num_rules - 1; + merge_rules(ctx, i, i + j, &merge[k], &nft->output); + } + ret = 0; + for (i = 0; i < ctx->num_rules; i++) + xfree(ctx->stmt_matrix[i]); + + xfree(ctx->stmt_matrix); + xfree(merge); +err: + for (i = 0; i < ctx->num_stmts; i++) + stmt_free(ctx->stmt[i]); + + xfree(ctx->rule); + xfree(ctx); + + return ret; +} + +static int cmd_optimize(struct nft_ctx *nft, struct cmd *cmd) +{ + struct table *table; + struct chain *chain; + int ret = 0; + + switch (cmd->obj) { + case CMD_OBJ_TABLE: + table = cmd->table; + if (!table) + break; + + list_for_each_entry(chain, &table->chains, list) { + if (chain->flags & CHAIN_F_HW_OFFLOAD) + continue; + + chain_optimize(nft, &chain->rules); + } + break; + default: + break; + } + + return ret; +} + +int nft_optimize(struct nft_ctx *nft, struct list_head *cmds) +{ + struct cmd *cmd; + int ret; + + list_for_each_entry(cmd, cmds, list) { + switch (cmd->op) { + case CMD_ADD: + ret = cmd_optimize(nft, cmd); + break; + default: + break; + } + } + + return ret; +} diff --git a/tests/shell/testcases/optimizations/dumps/merge_stmts.nft b/tests/shell/testcases/optimizations/dumps/merge_stmts.nft new file mode 100644 index 00000000..b56ea3ed --- /dev/null +++ 
b/tests/shell/testcases/optimizations/dumps/merge_stmts.nft @@ -0,0 +1,5 @@ +table ip x { + chain y { + ip daddr { 192.168.0.1, 192.168.0.2, 192.168.0.3 } counter packets 0 bytes 0 accept + } +} diff --git a/tests/shell/testcases/optimizations/merge_stmts b/tests/shell/testcases/optimizations/merge_stmts new file mode 100755 index 00000000..0c35636e --- /dev/null +++ b/tests/shell/testcases/optimizations/merge_stmts @@ -0,0 +1,13 @@ +#!/bin/bash + +set -e + +RULESET="table ip x { + chain y { + ip daddr 192.168.0.1 counter accept + ip daddr 192.168.0.2 counter accept + ip daddr 192.168.0.3 counter accept + } +}" + +$NFT -o -f - <<< $RULESET |