diff options
Diffstat (limited to 'src/optimize.c')
-rw-r--r-- | src/optimize.c | 560 |
1 files changed, 451 insertions, 109 deletions
diff --git a/src/optimize.c b/src/optimize.c index 419a37f2..89ba0d9d 100644 --- a/src/optimize.c +++ b/src/optimize.c @@ -11,8 +11,8 @@ * programme. */ -#define _GNU_SOURCE -#include <string.h> +#include <nft.h> + #include <errno.h> #include <inttypes.h> #include <nftables.h> @@ -21,6 +21,7 @@ #include <statement.h> #include <utils.h> #include <erec.h> +#include <linux/netfilter.h> #define MAX_STMTS 32 @@ -37,6 +38,8 @@ static bool __expr_cmp(const struct expr *expr_a, const struct expr *expr_b) { if (expr_a->etype != expr_b->etype) return false; + if (expr_a->len != expr_b->len) + return false; switch (expr_a->etype) { case EXPR_PAYLOAD: @@ -46,6 +49,8 @@ static bool __expr_cmp(const struct expr *expr_a, const struct expr *expr_b) return false; if (expr_a->payload.desc != expr_b->payload.desc) return false; + if (expr_a->payload.inner_desc != expr_b->payload.inner_desc) + return false; if (expr_a->payload.tmpl != expr_b->payload.tmpl) return false; break; @@ -60,6 +65,8 @@ static bool __expr_cmp(const struct expr *expr_a, const struct expr *expr_b) return false; if (expr_a->meta.base != expr_b->meta.base) return false; + if (expr_a->meta.inner_desc != expr_b->meta.inner_desc) + return false; break; case EXPR_CT: if (expr_a->ct.key != expr_b->ct.key) @@ -120,7 +127,19 @@ static bool __expr_cmp(const struct expr *expr_a, const struct expr *expr_b) return false; break; case EXPR_BINOP: - return __expr_cmp(expr_a->left, expr_b->left); + if (!__expr_cmp(expr_a->left, expr_b->left)) + return false; + + return __expr_cmp(expr_a->right, expr_b->right); + case EXPR_SYMBOL: + if (expr_a->symtype != expr_b->symtype) + return false; + if (expr_a->symtype != SYMBOL_VALUE) + return false; + + return !strcmp(expr_a->identifier, expr_b->identifier); + case EXPR_VALUE: + return !mpz_cmp(expr_a->value, expr_b->value); default: return false; } @@ -128,16 +147,41 @@ static bool __expr_cmp(const struct expr *expr_a, const struct expr *expr_b) return true; } +static bool is_bitmask(const struct expr *expr) +{ + switch (expr->etype) { + case EXPR_BINOP: + if (expr->op == OP_OR && + !is_bitmask(expr->left)) + return false; + + return is_bitmask(expr->right); + case EXPR_VALUE: + case EXPR_SYMBOL: + return true; + default: + break; + } + + return false; +} + static bool stmt_expr_supported(const struct expr *expr) { switch (expr->right->etype) { case EXPR_SYMBOL: + case EXPR_RANGE_SYMBOL: case EXPR_RANGE: + case EXPR_RANGE_VALUE: case EXPR_PREFIX: case EXPR_SET: case EXPR_LIST: case EXPR_VALUE: return true; + case EXPR_BINOP: + if (is_bitmask(expr->right)) + return true; + break; default: break; } @@ -156,10 +200,10 @@ static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b, { struct expr *expr_a, *expr_b; - if (stmt_a->ops->type != stmt_b->ops->type) + if (stmt_a->type != stmt_b->type) return false; - switch (stmt_a->ops->type) { + switch (stmt_a->type) { case STMT_EXPRESSION: expr_a = stmt_a->expr; expr_b = stmt_b->expr; @@ -212,9 +256,7 @@ static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b, if (!stmt_a->log.prefix) return true; - if (stmt_a->log.prefix->etype != EXPR_VALUE || - stmt_b->log.prefix->etype != EXPR_VALUE || - mpz_cmp(stmt_a->log.prefix->value, stmt_b->log.prefix->value)) + if (strcmp(stmt_a->log.prefix, stmt_b->log.prefix)) return false; break; case STMT_REJECT: @@ -229,28 +271,65 @@ static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b, if (!stmt_a->reject.expr) return true; - if (__expr_cmp(stmt_a->reject.expr, stmt_b->reject.expr)) + if (!__expr_cmp(stmt_a->reject.expr, stmt_b->reject.expr)) return false; break; case STMT_NAT: if (stmt_a->nat.type != stmt_b->nat.type || stmt_a->nat.flags != stmt_b->nat.flags || stmt_a->nat.family != stmt_b->nat.family || - stmt_a->nat.type_flags != stmt_b->nat.type_flags || - (stmt_a->nat.addr && - stmt_a->nat.addr->etype != EXPR_SYMBOL && - stmt_a->nat.addr->etype != EXPR_RANGE) || - (stmt_b->nat.addr && - stmt_b->nat.addr->etype != EXPR_SYMBOL && - stmt_b->nat.addr->etype != EXPR_RANGE) || - (stmt_a->nat.proto && - stmt_a->nat.proto->etype != EXPR_SYMBOL && - stmt_a->nat.proto->etype != EXPR_RANGE) || - (stmt_b->nat.proto && - stmt_b->nat.proto->etype != EXPR_SYMBOL && - stmt_b->nat.proto->etype != EXPR_RANGE)) + stmt_a->nat.type_flags != stmt_b->nat.type_flags) return false; + switch (stmt_a->nat.type) { + case NFT_NAT_SNAT: + case NFT_NAT_DNAT: + if ((stmt_a->nat.addr && + stmt_a->nat.addr->etype != EXPR_SYMBOL && + stmt_a->nat.addr->etype != EXPR_RANGE) || + (stmt_b->nat.addr && + stmt_b->nat.addr->etype != EXPR_SYMBOL && + stmt_b->nat.addr->etype != EXPR_RANGE) || + (stmt_a->nat.proto && + stmt_a->nat.proto->etype != EXPR_SYMBOL && + stmt_a->nat.proto->etype != EXPR_RANGE) || + (stmt_b->nat.proto && + stmt_b->nat.proto->etype != EXPR_SYMBOL && + stmt_b->nat.proto->etype != EXPR_RANGE)) + return false; + break; + case NFT_NAT_MASQ: + break; + case NFT_NAT_REDIR: + if ((stmt_a->nat.proto && + stmt_a->nat.proto->etype != EXPR_SYMBOL && + stmt_a->nat.proto->etype != EXPR_RANGE) || + (stmt_b->nat.proto && + stmt_b->nat.proto->etype != EXPR_SYMBOL && + stmt_b->nat.proto->etype != EXPR_RANGE)) + return false; + + /* it should be possible to infer implicit redirections + * such as: + * + * tcp dport 1234 redirect + * tcp dport 3456 redirect to :7890 + * merge: + * redirect to tcp dport map { 1234 : 1234, 3456 : 7890 } + * + * currently not implemented. + */ + if (fully_compare && + stmt_a->nat.type == NFT_NAT_REDIR && + stmt_b->nat.type == NFT_NAT_REDIR && + (!!stmt_a->nat.proto ^ !!stmt_b->nat.proto)) + return false; + + break; + default: + assert(0); + } + return true; default: /* ... Merging anything else is yet unsupported. */ @@ -281,7 +360,7 @@ static bool stmt_verdict_eq(const struct stmt *stmt_a, const struct stmt *stmt_b { struct expr *expr_a, *expr_b; - assert (stmt_a->ops->type == STMT_VERDICT); + assert (stmt_a->type == STMT_VERDICT); expr_a = stmt_a->expr; expr_b = stmt_b->expr; @@ -302,14 +381,14 @@ static bool stmt_type_find(struct optimize_ctx *ctx, const struct stmt *stmt) uint32_t i; for (i = 0; i < ctx->num_stmts; i++) { - if (ctx->stmt[i]->ops->type == STMT_INVALID) + if (ctx->stmt[i]->type == STMT_INVALID) unsupported_exists = true; if (__stmt_type_eq(stmt, ctx->stmt[i], false)) return true; } - switch (stmt->ops->type) { + switch (stmt->type) { case STMT_EXPRESSION: case STMT_VERDICT: case STMT_COUNTER: @@ -328,13 +407,9 @@ static bool stmt_type_find(struct optimize_ctx *ctx, const struct stmt *stmt) return false; } -static struct stmt_ops unsupported_stmt_ops = { - .type = STMT_INVALID, - .name = "unsupported", -}; - static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule) { + const struct stmt_ops *ops; struct stmt *stmt, *clone; list_for_each_entry(stmt, &rule->stmts, list) { @@ -344,18 +419,20 @@ static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule) /* No refcounter available in statement objects, clone it to * to store in the array of selectors. */ - clone = stmt_alloc(&internal_location, stmt->ops); - switch (stmt->ops->type) { + ops = stmt_ops(stmt); + clone = stmt_alloc(&internal_location, ops); + switch (stmt->type) { case STMT_EXPRESSION: if (stmt->expr->op != OP_IMPLICIT && stmt->expr->op != OP_EQ) { - clone->ops = &unsupported_stmt_ops; + clone->type = STMT_INVALID; break; } if (stmt->expr->left->etype == EXPR_CONCAT) { - clone->ops = &unsupported_stmt_ops; + clone->type = STMT_INVALID; break; } + /* fall-through */ case STMT_VERDICT: clone->expr = expr_get(stmt->expr); break; @@ -365,9 +442,18 @@ static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule) case STMT_LOG: memcpy(&clone->log, &stmt->log, sizeof(clone->log)); if (stmt->log.prefix) - clone->log.prefix = expr_get(stmt->log.prefix); + clone->log.prefix = xstrdup(stmt->log.prefix); break; case STMT_NAT: + if ((stmt->nat.addr && + (stmt->nat.addr->etype == EXPR_MAP || + stmt->nat.addr->etype == EXPR_VARIABLE)) || + (stmt->nat.proto && + (stmt->nat.proto->etype == EXPR_MAP || + stmt->nat.proto->etype == EXPR_VARIABLE))) { + clone->type = STMT_INVALID; + break; + } clone->nat.type = stmt->nat.type; clone->nat.family = stmt->nat.family; if (stmt->nat.addr) @@ -385,7 +471,7 @@ static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule) clone->reject.family = stmt->reject.family; break; default: - clone->ops = &unsupported_stmt_ops; + clone->type = STMT_INVALID; break; } @@ -402,7 +488,7 @@ static int unsupported_in_stmt_matrix(const struct optimize_ctx *ctx) uint32_t i; for (i = 0; i < ctx->num_stmts; i++) { - if (ctx->stmt[i]->ops->type == STMT_INVALID) + if (ctx->stmt[i]->type == STMT_INVALID) return i; } /* this should not happen. */ @@ -422,7 +508,7 @@ static int cmd_stmt_find_in_stmt_matrix(struct optimize_ctx *ctx, struct stmt *s } static struct stmt unsupported_stmt = { - .ops = &unsupported_stmt_ops, + .type = STMT_INVALID, }; static void rule_build_stmt_matrix_stmts(struct optimize_ctx *ctx, @@ -449,7 +535,7 @@ static int stmt_verdict_find(const struct optimize_ctx *ctx) uint32_t i; for (i = 0; i < ctx->num_stmts; i++) { - if (ctx->stmt[i]->ops->type != STMT_VERDICT) + if (ctx->stmt[i]->type != STMT_VERDICT) continue; return i; @@ -465,6 +551,8 @@ struct merge { /* statements to be merged (index relative to statement matrix) */ uint32_t stmt[MAX_STMTS]; uint32_t num_stmts; + /* merge has been invalidated */ + bool skip; }; static void merge_expr_stmts(const struct optimize_ctx *ctx, @@ -516,7 +604,7 @@ static void merge_verdict_stmts(const struct optimize_ctx *ctx, for (i = from + 1; i <= to; i++) { stmt_b = ctx->stmt_matrix[i][merge->stmt[0]]; - switch (stmt_b->ops->type) { + switch (stmt_b->type) { case STMT_VERDICT: switch (stmt_b->expr->etype) { case EXPR_MAP: @@ -538,7 +626,7 @@ static void merge_stmts(const struct optimize_ctx *ctx, { struct stmt *stmt_a = ctx->stmt_matrix[from][merge->stmt[0]]; - switch (stmt_a->ops->type) { + switch (stmt_a->type) { case STMT_EXPRESSION: merge_expr_stmts(ctx, from, to, merge, stmt_a); break; @@ -550,12 +638,80 @@ static void merge_stmts(const struct optimize_ctx *ctx, } } +static void __merge_concat(const struct optimize_ctx *ctx, uint32_t i, + const struct merge *merge, struct list_head *concat_list) +{ + struct expr *concat, *next, *expr, *concat_clone, *clone; + LIST_HEAD(pending_list); + struct stmt *stmt_a; + uint32_t k; + + concat = concat_expr_alloc(&internal_location); + list_add(&concat->list, concat_list); + + for (k = 0; k < merge->num_stmts; k++) { + list_for_each_entry_safe(concat, next, concat_list, list) { + stmt_a = ctx->stmt_matrix[i][merge->stmt[k]]; + switch (stmt_a->expr->right->etype) { + case EXPR_SET: + list_for_each_entry(expr, &stmt_a->expr->right->expressions, list) { + concat_clone = expr_clone(concat); + clone = expr_clone(expr->key); + compound_expr_add(concat_clone, clone); + list_add_tail(&concat_clone->list, &pending_list); + } + list_del(&concat->list); + expr_free(concat); + break; + case EXPR_SYMBOL: + case EXPR_VALUE: + case EXPR_PREFIX: + case EXPR_RANGE_SYMBOL: + case EXPR_RANGE: + case EXPR_RANGE_VALUE: + clone = expr_clone(stmt_a->expr->right); + compound_expr_add(concat, clone); + break; + case EXPR_LIST: + list_for_each_entry(expr, &stmt_a->expr->right->expressions, list) { + concat_clone = expr_clone(concat); + clone = expr_clone(expr); + compound_expr_add(concat_clone, clone); + list_add_tail(&concat_clone->list, &pending_list); + } + list_del(&concat->list); + expr_free(concat); + break; + default: + assert(0); + break; + } + } + list_splice_init(&pending_list, concat_list); + } +} + +static void __merge_concat_stmts(const struct optimize_ctx *ctx, uint32_t i, + const struct merge *merge, struct expr *set) +{ + struct expr *concat, *next, *elem; + LIST_HEAD(concat_list); + + __merge_concat(ctx, i, merge, &concat_list); + + list_for_each_entry_safe(concat, next, &concat_list, list) { + list_del(&concat->list); + elem = set_elem_expr_alloc(&internal_location, concat); + compound_expr_add(set, elem); + } +} + static void merge_concat_stmts(const struct optimize_ctx *ctx, uint32_t from, uint32_t to, const struct merge *merge) { - struct expr *concat, *elem, *set; struct stmt *stmt, *stmt_a; + struct expr *concat, *set; uint32_t i, k; stmt = ctx->stmt_matrix[from][merge->stmt[0]]; @@ -573,15 +729,9 @@ static void merge_concat_stmts(const struct optimize_ctx *ctx, set = set_expr_alloc(&internal_location, NULL); set->set_flags |= NFT_SET_ANONYMOUS; - for (i = from; i <= to; i++) { - concat = concat_expr_alloc(&internal_location); - for (k = 0; k < merge->num_stmts; k++) { - stmt_a = ctx->stmt_matrix[i][merge->stmt[k]]; - compound_expr_add(concat, expr_get(stmt_a->expr->right)); - } - elem = set_elem_expr_alloc(&internal_location, concat); - compound_expr_add(set, elem); - } + for (i = from; i <= to; i++) + __merge_concat_stmts(ctx, i, merge, set); + expr_free(stmt->expr->right); stmt->expr->right = set; @@ -592,33 +742,52 @@ static void merge_concat_stmts(const struct optimize_ctx *ctx, } } -static void build_verdict_map(struct expr *expr, struct stmt *verdict, struct expr *set) +static void build_verdict_map(struct expr *expr, struct stmt *verdict, + struct expr *set, struct stmt *counter) { struct expr *item, *elem, *mapping; + struct stmt *counter_elem; switch (expr->etype) { case EXPR_LIST: list_for_each_entry(item, &expr->expressions, list) { elem = set_elem_expr_alloc(&internal_location, expr_get(item)); + if (counter) { + counter_elem = counter_stmt_alloc(&counter->location); + list_add_tail(&counter_elem->list, &elem->stmt_list); + } + mapping = mapping_expr_alloc(&internal_location, elem, expr_get(verdict->expr)); compound_expr_add(set, mapping); } + stmt_free(counter); break; case EXPR_SET: list_for_each_entry(item, &expr->expressions, list) { elem = set_elem_expr_alloc(&internal_location, expr_get(item->key)); + if (counter) { + counter_elem = counter_stmt_alloc(&counter->location); + list_add_tail(&counter_elem->list, &elem->stmt_list); + } + mapping = mapping_expr_alloc(&internal_location, elem, expr_get(verdict->expr)); compound_expr_add(set, mapping); } + stmt_free(counter); break; case EXPR_PREFIX: + case EXPR_RANGE_SYMBOL: case EXPR_RANGE: + case EXPR_RANGE_VALUE: case EXPR_VALUE: case EXPR_SYMBOL: case EXPR_CONCAT: elem = set_elem_expr_alloc(&internal_location, expr_get(expr)); + if (counter) + list_add_tail(&counter->list, &elem->stmt_list); + mapping = mapping_expr_alloc(&internal_location, elem, expr_get(verdict->expr)); compound_expr_add(set, mapping); @@ -640,13 +809,33 @@ static void remove_counter(const struct optimize_ctx *ctx, uint32_t from) if (!stmt) continue; - if (stmt->ops->type == STMT_COUNTER) { + if (stmt->type == STMT_COUNTER) { list_del(&stmt->list); stmt_free(stmt); } } } +static struct stmt *zap_counter(const struct optimize_ctx *ctx, uint32_t from) +{ + struct stmt *stmt; + uint32_t i; + + /* remove counter statement */ + for (i = 0; i < ctx->num_stmts; i++) { + stmt = ctx->stmt_matrix[from][i]; + if (!stmt) + continue; + + if (stmt->type == STMT_COUNTER) { + list_del(&stmt->list); + return stmt; + } + } + + return NULL; +} + static void merge_stmts_vmap(const struct optimize_ctx *ctx, uint32_t from, uint32_t to, const struct merge *merge) @@ -654,31 +843,33 @@ static void merge_stmts_vmap(const struct optimize_ctx *ctx, struct stmt *stmt_a = ctx->stmt_matrix[from][merge->stmt[0]]; struct stmt *stmt_b, *verdict_a, *verdict_b, *stmt; struct expr *expr_a, *expr_b, *expr, *left, *set; + struct stmt *counter; uint32_t i; int k; k = stmt_verdict_find(ctx); assert(k >= 0); - verdict_a = ctx->stmt_matrix[from][k]; set = set_expr_alloc(&internal_location, NULL); set->set_flags |= NFT_SET_ANONYMOUS; expr_a = stmt_a->expr->right; - build_verdict_map(expr_a, verdict_a, set); + verdict_a = ctx->stmt_matrix[from][k]; + counter = zap_counter(ctx, from); + build_verdict_map(expr_a, verdict_a, set, counter); + for (i = from + 1; i <= to; i++) { stmt_b = ctx->stmt_matrix[i][merge->stmt[0]]; expr_b = stmt_b->expr->right; verdict_b = ctx->stmt_matrix[i][k]; - - build_verdict_map(expr_b, verdict_b, set); + counter = zap_counter(ctx, i); + build_verdict_map(expr_b, verdict_b, set, counter); } left = expr_get(stmt_a->expr->left); expr = map_expr_alloc(&internal_location, left, set); stmt = verdict_stmt_alloc(&internal_location, expr); - remove_counter(ctx, from); list_add(&stmt->list, &stmt_a->list); list_del(&stmt_a->list); stmt_free(stmt_a); @@ -686,14 +877,40 @@ static void merge_stmts_vmap(const struct optimize_ctx *ctx, stmt_free(verdict_a); } +static void __merge_concat_stmts_vmap(const struct optimize_ctx *ctx, + uint32_t i, const struct merge *merge, + struct expr *set, struct stmt *verdict) +{ + struct expr *concat, *next, *elem, *mapping; + struct stmt *counter, *counter_elem; + LIST_HEAD(concat_list); + + counter = zap_counter(ctx, i); + __merge_concat(ctx, i, merge, &concat_list); + + list_for_each_entry_safe(concat, next, &concat_list, list) { + list_del(&concat->list); + elem = set_elem_expr_alloc(&internal_location, concat); + if (counter) { + counter_elem = counter_stmt_alloc(&counter->location); + list_add_tail(&counter_elem->list, &elem->stmt_list); + } + + mapping = mapping_expr_alloc(&internal_location, elem, + expr_get(verdict->expr)); + compound_expr_add(set, mapping); + } + stmt_free(counter); +} + static void merge_concat_stmts_vmap(const struct optimize_ctx *ctx, uint32_t from, uint32_t to, const struct merge *merge) { struct stmt *orig_stmt = ctx->stmt_matrix[from][merge->stmt[0]]; - struct expr *concat_a, *concat_b, *expr, *set; - struct stmt *stmt, *stmt_a, *stmt_b, *verdict; - uint32_t i, j; + struct stmt *stmt, *stmt_a, *verdict; + struct expr *concat_a, *expr, *set; + uint32_t i; int k; k = stmt_verdict_find(ctx); @@ -711,21 +928,13 @@ static void merge_concat_stmts_vmap(const struct optimize_ctx *ctx, set->set_flags |= NFT_SET_ANONYMOUS; for (i = from; i <= to; i++) { - concat_b = concat_expr_alloc(&internal_location); - for (j = 0; j < merge->num_stmts; j++) { - stmt_b = ctx->stmt_matrix[i][merge->stmt[j]]; - expr = stmt_b->expr->right; - compound_expr_add(concat_b, expr_get(expr)); - } verdict = ctx->stmt_matrix[i][k]; - build_verdict_map(concat_b, verdict, set); - expr_free(concat_b); + __merge_concat_stmts_vmap(ctx, i, merge, set, verdict); } expr = map_expr_alloc(&internal_location, concat_a, set); stmt = verdict_stmt_alloc(&internal_location, expr); - remove_counter(ctx, from); list_add(&stmt->list, &orig_stmt->list); list_del(&orig_stmt->list); stmt_free(orig_stmt); @@ -766,12 +975,35 @@ static bool stmt_verdict_cmp(const struct optimize_ctx *ctx, return true; } -static int stmt_nat_find(const struct optimize_ctx *ctx) +static int stmt_nat_type(const struct optimize_ctx *ctx, int from, + enum nft_nat_etypes *nat_type) +{ + uint32_t j; + + for (j = 0; j < ctx->num_stmts; j++) { + if (!ctx->stmt_matrix[from][j]) + continue; + + if (ctx->stmt_matrix[from][j]->type == STMT_NAT) { + *nat_type = ctx->stmt_matrix[from][j]->nat.type; + return 0; + } + } + + return -1; +} + +static int stmt_nat_find(const struct optimize_ctx *ctx, int from) { + enum nft_nat_etypes nat_type; uint32_t i; + if (stmt_nat_type(ctx, from, &nat_type) < 0) + return -1; + for (i = 0; i < ctx->num_stmts; i++) { - if (ctx->stmt[i]->ops->type != STMT_NAT) + if (ctx->stmt[i]->type != STMT_NAT || + ctx->stmt[i]->nat.type != nat_type) continue; return i; @@ -784,16 +1016,24 @@ static struct expr *stmt_nat_expr(struct stmt *nat_stmt) { struct expr *nat_expr; + assert(nat_stmt->type == STMT_NAT); + if (nat_stmt->nat.proto) { - nat_expr = concat_expr_alloc(&internal_location); - compound_expr_add(nat_expr, expr_get(nat_stmt->nat.addr)); - compound_expr_add(nat_expr, expr_get(nat_stmt->nat.proto)); + if (nat_stmt->nat.addr) { + nat_expr = concat_expr_alloc(&internal_location); + compound_expr_add(nat_expr, expr_get(nat_stmt->nat.addr)); + compound_expr_add(nat_expr, expr_get(nat_stmt->nat.proto)); + } else { + nat_expr = expr_get(nat_stmt->nat.proto); + } expr_free(nat_stmt->nat.proto); nat_stmt->nat.proto = NULL; } else { nat_expr = expr_get(nat_stmt->nat.addr); } + assert(nat_expr); + return nat_expr; } @@ -802,11 +1042,11 @@ static void merge_nat(const struct optimize_ctx *ctx, const struct merge *merge) { struct expr *expr, *set, *elem, *nat_expr, *mapping, *left; + int k, family = NFPROTO_UNSPEC; struct stmt *stmt, *nat_stmt; uint32_t i; - int k; - k = stmt_nat_find(ctx); + k = stmt_nat_find(ctx, from); assert(k >= 0); set = set_expr_alloc(&internal_location, NULL); @@ -826,11 +1066,23 @@ static void merge_nat(const struct optimize_ctx *ctx, stmt = ctx->stmt_matrix[from][merge->stmt[0]]; left = expr_get(stmt->expr->left); + if (left->etype == EXPR_PAYLOAD) { + if (left->payload.desc == &proto_ip) + family = NFPROTO_IPV4; + else if (left->payload.desc == &proto_ip6) + family = NFPROTO_IPV6; + } expr = map_expr_alloc(&internal_location, left, set); nat_stmt = ctx->stmt_matrix[from][k]; + if (nat_stmt->nat.family == NFPROTO_UNSPEC) + nat_stmt->nat.family = family; + expr_free(nat_stmt->nat.addr); - nat_stmt->nat.addr = expr; + if (nat_stmt->nat.type == NFT_NAT_REDIR) + nat_stmt->nat.proto = expr; + else + nat_stmt->nat.addr = expr; remove_counter(ctx, from); list_del(&stmt->list); @@ -842,11 +1094,11 @@ static void merge_concat_nat(const struct optimize_ctx *ctx, const struct merge *merge) { struct expr *expr, *set, *elem, *nat_expr, *mapping, *left, *concat; + int k, family = NFPROTO_UNSPEC; struct stmt *stmt, *nat_stmt; uint32_t i, j; - int k; - k = stmt_nat_find(ctx); + k = stmt_nat_find(ctx, from); assert(k >= 0); set = set_expr_alloc(&internal_location, NULL); @@ -873,11 +1125,20 @@ static void merge_concat_nat(const struct optimize_ctx *ctx, for (j = 0; j < merge->num_stmts; j++) { stmt = ctx->stmt_matrix[from][merge->stmt[j]]; left = stmt->expr->left; + if (left->etype == EXPR_PAYLOAD) { + if (left->payload.desc == &proto_ip) + family = NFPROTO_IPV4; + else if (left->payload.desc == &proto_ip6) + family = NFPROTO_IPV6; + } compound_expr_add(concat, expr_get(left)); } expr = map_expr_alloc(&internal_location, concat, set); nat_stmt = ctx->stmt_matrix[from][k]; + if (nat_stmt->nat.family == NFPROTO_UNSPEC) + nat_stmt->nat.family = family; + expr_free(nat_stmt->nat.addr); nat_stmt->nat.addr = expr; @@ -922,22 +1183,36 @@ static void rule_optimize_print(struct output_ctx *octx, fprintf(octx->error_fp, "%s\n", line); } -static enum stmt_types merge_stmt_type(const struct optimize_ctx *ctx) +enum { + MERGE_BY_VERDICT, + MERGE_BY_NAT_MAP, + MERGE_BY_NAT, +}; + +static uint32_t merge_stmt_type(const struct optimize_ctx *ctx, + uint32_t from, uint32_t to) { - uint32_t i; + const struct stmt *stmt; + uint32_t i, j; - for (i = 0; i < ctx->num_stmts; i++) { - switch (ctx->stmt[i]->ops->type) { - case STMT_VERDICT: - case STMT_NAT: - return ctx->stmt[i]->ops->type; - default: - continue; + for (i = from; i <= to; i++) { + for (j = 0; j < ctx->num_stmts; j++) { + stmt = ctx->stmt_matrix[i][j]; + if (!stmt) + continue; + if (stmt->type == STMT_NAT) { + if ((stmt->nat.type == NFT_NAT_REDIR && + !stmt->nat.proto) || + stmt->nat.type == NFT_NAT_MASQ) + return MERGE_BY_NAT; + + return MERGE_BY_NAT_MAP; + } } } - /* actually no verdict, this assumes rules have the same verdict. */ - return STMT_VERDICT; + /* merge by verdict, even if no verdict is specified. */ + return MERGE_BY_VERDICT; } static void merge_rules(const struct optimize_ctx *ctx, @@ -945,14 +1220,14 @@ static void merge_rules(const struct optimize_ctx *ctx, const struct merge *merge, struct output_ctx *octx) { - enum stmt_types stmt_type; + uint32_t merge_type; bool same_verdict; uint32_t i; - stmt_type = merge_stmt_type(ctx); + merge_type = merge_stmt_type(ctx, from, to); - switch (stmt_type) { - case STMT_VERDICT: + switch (merge_type) { + case MERGE_BY_VERDICT: same_verdict = stmt_verdict_cmp(ctx, from, to); if (merge->num_stmts > 1) { if (same_verdict) @@ -966,18 +1241,24 @@ static void merge_rules(const struct optimize_ctx *ctx, merge_stmts_vmap(ctx, from, to, merge); } break; - case STMT_NAT: + case MERGE_BY_NAT_MAP: if (merge->num_stmts > 1) merge_concat_nat(ctx, from, to, merge); else merge_nat(ctx, from, to, merge); break; + case MERGE_BY_NAT: + if (merge->num_stmts > 1) + merge_concat_stmts(ctx, from, to, merge); + else + merge_stmts(ctx, from, to, merge); + break; default: assert(0); } if (ctx->rule[from]->comment) { - xfree(ctx->rule[from]->comment); + free_const(ctx->rule[from]->comment); ctx->rule[from]->comment = NULL; } @@ -1011,15 +1292,41 @@ static bool stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b) return __stmt_type_eq(stmt_a, stmt_b, true); } +static bool stmt_is_mergeable(const struct stmt *stmt) +{ + if (!stmt) + return false; + + switch (stmt->type) { + case STMT_VERDICT: + if (stmt->expr->etype == EXPR_MAP) + return true; + break; + case STMT_EXPRESSION: + case STMT_NAT: + return true; + default: + break; + } + + return false; +} + static bool rules_eq(const struct optimize_ctx *ctx, int i, int j) { - uint32_t k; + uint32_t k, mergeable = 0; for (k = 0; k < ctx->num_stmts; k++) { + if (stmt_is_mergeable(ctx->stmt_matrix[i][k])) + mergeable++; + if (!stmt_type_eq(ctx->stmt_matrix[i][k], ctx->stmt_matrix[j][k])) return false; } + if (mergeable == 0) + return false; + return true; } @@ -1043,10 +1350,11 @@ static int chain_optimize(struct nft_ctx *nft, struct list_head *rules) ctx->num_rules++; } - ctx->rule = xzalloc(sizeof(ctx->rule) * ctx->num_rules); - ctx->stmt_matrix = xzalloc(sizeof(struct stmt *) * ctx->num_rules); + ctx->rule = xzalloc(sizeof(*ctx->rule) * ctx->num_rules); + ctx->stmt_matrix = xzalloc(sizeof(*ctx->stmt_matrix) * ctx->num_rules); for (i = 0; i < ctx->num_rules; i++) - ctx->stmt_matrix[i] = xzalloc(sizeof(struct stmt *) * MAX_STMTS); + ctx->stmt_matrix[i] = xzalloc_array(MAX_STMTS, + sizeof(**ctx->stmt_matrix)); merge = xzalloc(sizeof(*merge) * ctx->num_rules); @@ -1078,14 +1386,48 @@ static int chain_optimize(struct nft_ctx *nft, struct list_head *rules) } } - /* Step 4: Infer how to merge the candidate rules */ + /* Step 4: Invalidate merge in case of duplicated keys in set/map. */ for (k = 0; k < num_merges; k++) { + uint32_t r1, r2; + + i = merge[k].rule_from; + + for (r1 = i; r1 < i + merge[k].num_rules; r1++) { + for (r2 = r1 + 1; r2 < i + merge[k].num_rules; r2++) { + bool match_same_value = true, match_seen = false; + + for (m = 0; m < ctx->num_stmts; m++) { + if (!ctx->stmt_matrix[r1][m]) + continue; + + switch (ctx->stmt_matrix[r1][m]->type) { + case STMT_EXPRESSION: + match_seen = true; + if (!__expr_cmp(ctx->stmt_matrix[r1][m]->expr->right, + ctx->stmt_matrix[r2][m]->expr->right)) + match_same_value = false; + break; + default: + break; + } + } + if (match_seen && match_same_value) + merge[k].skip = true; + } + } + } + + /* Step 5: Infer how to merge the candidate rules */ + for (k = 0; k < num_merges; k++) { + if (merge[k].skip) + continue; + i = merge[k].rule_from; for (m = 0; m < ctx->num_stmts; m++) { if (!ctx->stmt_matrix[i][m]) continue; - switch (ctx->stmt_matrix[i][m]->ops->type) { + switch (ctx->stmt_matrix[i][m]->type) { case STMT_EXPRESSION: merge[k].stmt[merge[k].num_stmts++] = m; break; @@ -1103,16 +1445,16 @@ static int chain_optimize(struct nft_ctx *nft, struct list_head *rules) } ret = 0; for (i = 0; i < ctx->num_rules; i++) - xfree(ctx->stmt_matrix[i]); + free(ctx->stmt_matrix[i]); - xfree(ctx->stmt_matrix); - xfree(merge); + free(ctx->stmt_matrix); + free(merge); err: for (i = 0; i < ctx->num_stmts; i++) stmt_free(ctx->stmt[i]); - xfree(ctx->rule); - xfree(ctx); + free(ctx->rule); + free(ctx); return ret; } @@ -1146,7 +1488,7 @@ static int cmd_optimize(struct nft_ctx *nft, struct cmd *cmd) int nft_optimize(struct nft_ctx *nft, struct list_head *cmds) { struct cmd *cmd; - int ret; + int ret = 0; list_for_each_entry(cmd, cmds, list) { switch (cmd->op) { |