summaryrefslogtreecommitdiffstats
path: root/src/optimize.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/optimize.c')
-rw-r--r--src/optimize.c560
1 files changed, 451 insertions, 109 deletions
diff --git a/src/optimize.c b/src/optimize.c
index 419a37f2..89ba0d9d 100644
--- a/src/optimize.c
+++ b/src/optimize.c
@@ -11,8 +11,8 @@
* programme.
*/
-#define _GNU_SOURCE
-#include <string.h>
+#include <nft.h>
+
#include <errno.h>
#include <inttypes.h>
#include <nftables.h>
@@ -21,6 +21,7 @@
#include <statement.h>
#include <utils.h>
#include <erec.h>
+#include <linux/netfilter.h>
#define MAX_STMTS 32
@@ -37,6 +38,8 @@ static bool __expr_cmp(const struct expr *expr_a, const struct expr *expr_b)
{
if (expr_a->etype != expr_b->etype)
return false;
+ if (expr_a->len != expr_b->len)
+ return false;
switch (expr_a->etype) {
case EXPR_PAYLOAD:
@@ -46,6 +49,8 @@ static bool __expr_cmp(const struct expr *expr_a, const struct expr *expr_b)
return false;
if (expr_a->payload.desc != expr_b->payload.desc)
return false;
+ if (expr_a->payload.inner_desc != expr_b->payload.inner_desc)
+ return false;
if (expr_a->payload.tmpl != expr_b->payload.tmpl)
return false;
break;
@@ -60,6 +65,8 @@ static bool __expr_cmp(const struct expr *expr_a, const struct expr *expr_b)
return false;
if (expr_a->meta.base != expr_b->meta.base)
return false;
+ if (expr_a->meta.inner_desc != expr_b->meta.inner_desc)
+ return false;
break;
case EXPR_CT:
if (expr_a->ct.key != expr_b->ct.key)
@@ -120,7 +127,19 @@ static bool __expr_cmp(const struct expr *expr_a, const struct expr *expr_b)
return false;
break;
case EXPR_BINOP:
- return __expr_cmp(expr_a->left, expr_b->left);
+ if (!__expr_cmp(expr_a->left, expr_b->left))
+ return false;
+
+ return __expr_cmp(expr_a->right, expr_b->right);
+ case EXPR_SYMBOL:
+ if (expr_a->symtype != expr_b->symtype)
+ return false;
+ if (expr_a->symtype != SYMBOL_VALUE)
+ return false;
+
+ return !strcmp(expr_a->identifier, expr_b->identifier);
+ case EXPR_VALUE:
+ return !mpz_cmp(expr_a->value, expr_b->value);
default:
return false;
}
@@ -128,16 +147,41 @@ static bool __expr_cmp(const struct expr *expr_a, const struct expr *expr_b)
return true;
}
+static bool is_bitmask(const struct expr *expr)
+{
+ switch (expr->etype) {
+ case EXPR_BINOP:
+ if (expr->op == OP_OR &&
+ !is_bitmask(expr->left))
+ return false;
+
+ return is_bitmask(expr->right);
+ case EXPR_VALUE:
+ case EXPR_SYMBOL:
+ return true;
+ default:
+ break;
+ }
+
+ return false;
+}
+
static bool stmt_expr_supported(const struct expr *expr)
{
switch (expr->right->etype) {
case EXPR_SYMBOL:
+ case EXPR_RANGE_SYMBOL:
case EXPR_RANGE:
+ case EXPR_RANGE_VALUE:
case EXPR_PREFIX:
case EXPR_SET:
case EXPR_LIST:
case EXPR_VALUE:
return true;
+ case EXPR_BINOP:
+ if (is_bitmask(expr->right))
+ return true;
+ break;
default:
break;
}
@@ -156,10 +200,10 @@ static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b,
{
struct expr *expr_a, *expr_b;
- if (stmt_a->ops->type != stmt_b->ops->type)
+ if (stmt_a->type != stmt_b->type)
return false;
- switch (stmt_a->ops->type) {
+ switch (stmt_a->type) {
case STMT_EXPRESSION:
expr_a = stmt_a->expr;
expr_b = stmt_b->expr;
@@ -212,9 +256,7 @@ static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b,
if (!stmt_a->log.prefix)
return true;
- if (stmt_a->log.prefix->etype != EXPR_VALUE ||
- stmt_b->log.prefix->etype != EXPR_VALUE ||
- mpz_cmp(stmt_a->log.prefix->value, stmt_b->log.prefix->value))
+ if (strcmp(stmt_a->log.prefix, stmt_b->log.prefix))
return false;
break;
case STMT_REJECT:
@@ -229,28 +271,65 @@ static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b,
if (!stmt_a->reject.expr)
return true;
- if (__expr_cmp(stmt_a->reject.expr, stmt_b->reject.expr))
+ if (!__expr_cmp(stmt_a->reject.expr, stmt_b->reject.expr))
return false;
break;
case STMT_NAT:
if (stmt_a->nat.type != stmt_b->nat.type ||
stmt_a->nat.flags != stmt_b->nat.flags ||
stmt_a->nat.family != stmt_b->nat.family ||
- stmt_a->nat.type_flags != stmt_b->nat.type_flags ||
- (stmt_a->nat.addr &&
- stmt_a->nat.addr->etype != EXPR_SYMBOL &&
- stmt_a->nat.addr->etype != EXPR_RANGE) ||
- (stmt_b->nat.addr &&
- stmt_b->nat.addr->etype != EXPR_SYMBOL &&
- stmt_b->nat.addr->etype != EXPR_RANGE) ||
- (stmt_a->nat.proto &&
- stmt_a->nat.proto->etype != EXPR_SYMBOL &&
- stmt_a->nat.proto->etype != EXPR_RANGE) ||
- (stmt_b->nat.proto &&
- stmt_b->nat.proto->etype != EXPR_SYMBOL &&
- stmt_b->nat.proto->etype != EXPR_RANGE))
+ stmt_a->nat.type_flags != stmt_b->nat.type_flags)
return false;
+ switch (stmt_a->nat.type) {
+ case NFT_NAT_SNAT:
+ case NFT_NAT_DNAT:
+ if ((stmt_a->nat.addr &&
+ stmt_a->nat.addr->etype != EXPR_SYMBOL &&
+ stmt_a->nat.addr->etype != EXPR_RANGE) ||
+ (stmt_b->nat.addr &&
+ stmt_b->nat.addr->etype != EXPR_SYMBOL &&
+ stmt_b->nat.addr->etype != EXPR_RANGE) ||
+ (stmt_a->nat.proto &&
+ stmt_a->nat.proto->etype != EXPR_SYMBOL &&
+ stmt_a->nat.proto->etype != EXPR_RANGE) ||
+ (stmt_b->nat.proto &&
+ stmt_b->nat.proto->etype != EXPR_SYMBOL &&
+ stmt_b->nat.proto->etype != EXPR_RANGE))
+ return false;
+ break;
+ case NFT_NAT_MASQ:
+ break;
+ case NFT_NAT_REDIR:
+ if ((stmt_a->nat.proto &&
+ stmt_a->nat.proto->etype != EXPR_SYMBOL &&
+ stmt_a->nat.proto->etype != EXPR_RANGE) ||
+ (stmt_b->nat.proto &&
+ stmt_b->nat.proto->etype != EXPR_SYMBOL &&
+ stmt_b->nat.proto->etype != EXPR_RANGE))
+ return false;
+
+ /* it should be possible to infer implicit redirections
+ * such as:
+ *
+ * tcp dport 1234 redirect
+ * tcp dport 3456 redirect to :7890
+ * merge:
+ * redirect to tcp dport map { 1234 : 1234, 3456 : 7890 }
+ *
+ * currently not implemented.
+ */
+ if (fully_compare &&
+ stmt_a->nat.type == NFT_NAT_REDIR &&
+ stmt_b->nat.type == NFT_NAT_REDIR &&
+ (!!stmt_a->nat.proto ^ !!stmt_b->nat.proto))
+ return false;
+
+ break;
+ default:
+ assert(0);
+ }
+
return true;
default:
/* ... Merging anything else is yet unsupported. */
@@ -281,7 +360,7 @@ static bool stmt_verdict_eq(const struct stmt *stmt_a, const struct stmt *stmt_b
{
struct expr *expr_a, *expr_b;
- assert (stmt_a->ops->type == STMT_VERDICT);
+ assert (stmt_a->type == STMT_VERDICT);
expr_a = stmt_a->expr;
expr_b = stmt_b->expr;
@@ -302,14 +381,14 @@ static bool stmt_type_find(struct optimize_ctx *ctx, const struct stmt *stmt)
uint32_t i;
for (i = 0; i < ctx->num_stmts; i++) {
- if (ctx->stmt[i]->ops->type == STMT_INVALID)
+ if (ctx->stmt[i]->type == STMT_INVALID)
unsupported_exists = true;
if (__stmt_type_eq(stmt, ctx->stmt[i], false))
return true;
}
- switch (stmt->ops->type) {
+ switch (stmt->type) {
case STMT_EXPRESSION:
case STMT_VERDICT:
case STMT_COUNTER:
@@ -328,13 +407,9 @@ static bool stmt_type_find(struct optimize_ctx *ctx, const struct stmt *stmt)
return false;
}
-static struct stmt_ops unsupported_stmt_ops = {
- .type = STMT_INVALID,
- .name = "unsupported",
-};
-
static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule)
{
+ const struct stmt_ops *ops;
struct stmt *stmt, *clone;
list_for_each_entry(stmt, &rule->stmts, list) {
@@ -344,18 +419,20 @@ static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule)
/* No refcounter available in statement objects, clone it to
* to store in the array of selectors.
*/
- clone = stmt_alloc(&internal_location, stmt->ops);
- switch (stmt->ops->type) {
+ ops = stmt_ops(stmt);
+ clone = stmt_alloc(&internal_location, ops);
+ switch (stmt->type) {
case STMT_EXPRESSION:
if (stmt->expr->op != OP_IMPLICIT &&
stmt->expr->op != OP_EQ) {
- clone->ops = &unsupported_stmt_ops;
+ clone->type = STMT_INVALID;
break;
}
if (stmt->expr->left->etype == EXPR_CONCAT) {
- clone->ops = &unsupported_stmt_ops;
+ clone->type = STMT_INVALID;
break;
}
+ /* fall-through */
case STMT_VERDICT:
clone->expr = expr_get(stmt->expr);
break;
@@ -365,9 +442,18 @@ static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule)
case STMT_LOG:
memcpy(&clone->log, &stmt->log, sizeof(clone->log));
if (stmt->log.prefix)
- clone->log.prefix = expr_get(stmt->log.prefix);
+ clone->log.prefix = xstrdup(stmt->log.prefix);
break;
case STMT_NAT:
+ if ((stmt->nat.addr &&
+ (stmt->nat.addr->etype == EXPR_MAP ||
+ stmt->nat.addr->etype == EXPR_VARIABLE)) ||
+ (stmt->nat.proto &&
+ (stmt->nat.proto->etype == EXPR_MAP ||
+ stmt->nat.proto->etype == EXPR_VARIABLE))) {
+ clone->type = STMT_INVALID;
+ break;
+ }
clone->nat.type = stmt->nat.type;
clone->nat.family = stmt->nat.family;
if (stmt->nat.addr)
@@ -385,7 +471,7 @@ static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule)
clone->reject.family = stmt->reject.family;
break;
default:
- clone->ops = &unsupported_stmt_ops;
+ clone->type = STMT_INVALID;
break;
}
@@ -402,7 +488,7 @@ static int unsupported_in_stmt_matrix(const struct optimize_ctx *ctx)
uint32_t i;
for (i = 0; i < ctx->num_stmts; i++) {
- if (ctx->stmt[i]->ops->type == STMT_INVALID)
+ if (ctx->stmt[i]->type == STMT_INVALID)
return i;
}
/* this should not happen. */
@@ -422,7 +508,7 @@ static int cmd_stmt_find_in_stmt_matrix(struct optimize_ctx *ctx, struct stmt *s
}
static struct stmt unsupported_stmt = {
- .ops = &unsupported_stmt_ops,
+ .type = STMT_INVALID,
};
static void rule_build_stmt_matrix_stmts(struct optimize_ctx *ctx,
@@ -449,7 +535,7 @@ static int stmt_verdict_find(const struct optimize_ctx *ctx)
uint32_t i;
for (i = 0; i < ctx->num_stmts; i++) {
- if (ctx->stmt[i]->ops->type != STMT_VERDICT)
+ if (ctx->stmt[i]->type != STMT_VERDICT)
continue;
return i;
@@ -465,6 +551,8 @@ struct merge {
/* statements to be merged (index relative to statement matrix) */
uint32_t stmt[MAX_STMTS];
uint32_t num_stmts;
+ /* merge has been invalidated */
+ bool skip;
};
static void merge_expr_stmts(const struct optimize_ctx *ctx,
@@ -516,7 +604,7 @@ static void merge_verdict_stmts(const struct optimize_ctx *ctx,
for (i = from + 1; i <= to; i++) {
stmt_b = ctx->stmt_matrix[i][merge->stmt[0]];
- switch (stmt_b->ops->type) {
+ switch (stmt_b->type) {
case STMT_VERDICT:
switch (stmt_b->expr->etype) {
case EXPR_MAP:
@@ -538,7 +626,7 @@ static void merge_stmts(const struct optimize_ctx *ctx,
{
struct stmt *stmt_a = ctx->stmt_matrix[from][merge->stmt[0]];
- switch (stmt_a->ops->type) {
+ switch (stmt_a->type) {
case STMT_EXPRESSION:
merge_expr_stmts(ctx, from, to, merge, stmt_a);
break;
@@ -550,12 +638,80 @@ static void merge_stmts(const struct optimize_ctx *ctx,
}
}
+static void __merge_concat(const struct optimize_ctx *ctx, uint32_t i,
+ const struct merge *merge, struct list_head *concat_list)
+{
+ struct expr *concat, *next, *expr, *concat_clone, *clone;
+ LIST_HEAD(pending_list);
+ struct stmt *stmt_a;
+ uint32_t k;
+
+ concat = concat_expr_alloc(&internal_location);
+ list_add(&concat->list, concat_list);
+
+ for (k = 0; k < merge->num_stmts; k++) {
+ list_for_each_entry_safe(concat, next, concat_list, list) {
+ stmt_a = ctx->stmt_matrix[i][merge->stmt[k]];
+ switch (stmt_a->expr->right->etype) {
+ case EXPR_SET:
+ list_for_each_entry(expr, &stmt_a->expr->right->expressions, list) {
+ concat_clone = expr_clone(concat);
+ clone = expr_clone(expr->key);
+ compound_expr_add(concat_clone, clone);
+ list_add_tail(&concat_clone->list, &pending_list);
+ }
+ list_del(&concat->list);
+ expr_free(concat);
+ break;
+ case EXPR_SYMBOL:
+ case EXPR_VALUE:
+ case EXPR_PREFIX:
+ case EXPR_RANGE_SYMBOL:
+ case EXPR_RANGE:
+ case EXPR_RANGE_VALUE:
+ clone = expr_clone(stmt_a->expr->right);
+ compound_expr_add(concat, clone);
+ break;
+ case EXPR_LIST:
+ list_for_each_entry(expr, &stmt_a->expr->right->expressions, list) {
+ concat_clone = expr_clone(concat);
+ clone = expr_clone(expr);
+ compound_expr_add(concat_clone, clone);
+ list_add_tail(&concat_clone->list, &pending_list);
+ }
+ list_del(&concat->list);
+ expr_free(concat);
+ break;
+ default:
+ assert(0);
+ break;
+ }
+ }
+ list_splice_init(&pending_list, concat_list);
+ }
+}
+
+static void __merge_concat_stmts(const struct optimize_ctx *ctx, uint32_t i,
+ const struct merge *merge, struct expr *set)
+{
+ struct expr *concat, *next, *elem;
+ LIST_HEAD(concat_list);
+
+ __merge_concat(ctx, i, merge, &concat_list);
+
+ list_for_each_entry_safe(concat, next, &concat_list, list) {
+ list_del(&concat->list);
+ elem = set_elem_expr_alloc(&internal_location, concat);
+ compound_expr_add(set, elem);
+ }
+}
+
static void merge_concat_stmts(const struct optimize_ctx *ctx,
uint32_t from, uint32_t to,
const struct merge *merge)
{
- struct expr *concat, *elem, *set;
struct stmt *stmt, *stmt_a;
+ struct expr *concat, *set;
uint32_t i, k;
stmt = ctx->stmt_matrix[from][merge->stmt[0]];
@@ -573,15 +729,9 @@ static void merge_concat_stmts(const struct optimize_ctx *ctx,
set = set_expr_alloc(&internal_location, NULL);
set->set_flags |= NFT_SET_ANONYMOUS;
- for (i = from; i <= to; i++) {
- concat = concat_expr_alloc(&internal_location);
- for (k = 0; k < merge->num_stmts; k++) {
- stmt_a = ctx->stmt_matrix[i][merge->stmt[k]];
- compound_expr_add(concat, expr_get(stmt_a->expr->right));
- }
- elem = set_elem_expr_alloc(&internal_location, concat);
- compound_expr_add(set, elem);
- }
+ for (i = from; i <= to; i++)
+ __merge_concat_stmts(ctx, i, merge, set);
+
expr_free(stmt->expr->right);
stmt->expr->right = set;
@@ -592,33 +742,52 @@ static void merge_concat_stmts(const struct optimize_ctx *ctx,
}
}
-static void build_verdict_map(struct expr *expr, struct stmt *verdict, struct expr *set)
+static void build_verdict_map(struct expr *expr, struct stmt *verdict,
+ struct expr *set, struct stmt *counter)
{
struct expr *item, *elem, *mapping;
+ struct stmt *counter_elem;
switch (expr->etype) {
case EXPR_LIST:
list_for_each_entry(item, &expr->expressions, list) {
elem = set_elem_expr_alloc(&internal_location, expr_get(item));
+ if (counter) {
+ counter_elem = counter_stmt_alloc(&counter->location);
+ list_add_tail(&counter_elem->list, &elem->stmt_list);
+ }
+
mapping = mapping_expr_alloc(&internal_location, elem,
expr_get(verdict->expr));
compound_expr_add(set, mapping);
}
+ stmt_free(counter);
break;
case EXPR_SET:
list_for_each_entry(item, &expr->expressions, list) {
elem = set_elem_expr_alloc(&internal_location, expr_get(item->key));
+ if (counter) {
+ counter_elem = counter_stmt_alloc(&counter->location);
+ list_add_tail(&counter_elem->list, &elem->stmt_list);
+ }
+
mapping = mapping_expr_alloc(&internal_location, elem,
expr_get(verdict->expr));
compound_expr_add(set, mapping);
}
+ stmt_free(counter);
break;
case EXPR_PREFIX:
+ case EXPR_RANGE_SYMBOL:
case EXPR_RANGE:
+ case EXPR_RANGE_VALUE:
case EXPR_VALUE:
case EXPR_SYMBOL:
case EXPR_CONCAT:
elem = set_elem_expr_alloc(&internal_location, expr_get(expr));
+ if (counter)
+ list_add_tail(&counter->list, &elem->stmt_list);
+
mapping = mapping_expr_alloc(&internal_location, elem,
expr_get(verdict->expr));
compound_expr_add(set, mapping);
@@ -640,13 +809,33 @@ static void remove_counter(const struct optimize_ctx *ctx, uint32_t from)
if (!stmt)
continue;
- if (stmt->ops->type == STMT_COUNTER) {
+ if (stmt->type == STMT_COUNTER) {
list_del(&stmt->list);
stmt_free(stmt);
}
}
}
+static struct stmt *zap_counter(const struct optimize_ctx *ctx, uint32_t from)
+{
+ struct stmt *stmt;
+ uint32_t i;
+
+ /* remove counter statement */
+ for (i = 0; i < ctx->num_stmts; i++) {
+ stmt = ctx->stmt_matrix[from][i];
+ if (!stmt)
+ continue;
+
+ if (stmt->type == STMT_COUNTER) {
+ list_del(&stmt->list);
+ return stmt;
+ }
+ }
+
+ return NULL;
+}
+
static void merge_stmts_vmap(const struct optimize_ctx *ctx,
uint32_t from, uint32_t to,
const struct merge *merge)
@@ -654,31 +843,33 @@ static void merge_stmts_vmap(const struct optimize_ctx *ctx,
struct stmt *stmt_a = ctx->stmt_matrix[from][merge->stmt[0]];
struct stmt *stmt_b, *verdict_a, *verdict_b, *stmt;
struct expr *expr_a, *expr_b, *expr, *left, *set;
+ struct stmt *counter;
uint32_t i;
int k;
k = stmt_verdict_find(ctx);
assert(k >= 0);
- verdict_a = ctx->stmt_matrix[from][k];
set = set_expr_alloc(&internal_location, NULL);
set->set_flags |= NFT_SET_ANONYMOUS;
expr_a = stmt_a->expr->right;
- build_verdict_map(expr_a, verdict_a, set);
+ verdict_a = ctx->stmt_matrix[from][k];
+ counter = zap_counter(ctx, from);
+ build_verdict_map(expr_a, verdict_a, set, counter);
+
for (i = from + 1; i <= to; i++) {
stmt_b = ctx->stmt_matrix[i][merge->stmt[0]];
expr_b = stmt_b->expr->right;
verdict_b = ctx->stmt_matrix[i][k];
-
- build_verdict_map(expr_b, verdict_b, set);
+ counter = zap_counter(ctx, i);
+ build_verdict_map(expr_b, verdict_b, set, counter);
}
left = expr_get(stmt_a->expr->left);
expr = map_expr_alloc(&internal_location, left, set);
stmt = verdict_stmt_alloc(&internal_location, expr);
- remove_counter(ctx, from);
list_add(&stmt->list, &stmt_a->list);
list_del(&stmt_a->list);
stmt_free(stmt_a);
@@ -686,14 +877,40 @@ static void merge_stmts_vmap(const struct optimize_ctx *ctx,
stmt_free(verdict_a);
}
+static void __merge_concat_stmts_vmap(const struct optimize_ctx *ctx,
+ uint32_t i, const struct merge *merge,
+ struct expr *set, struct stmt *verdict)
+{
+ struct expr *concat, *next, *elem, *mapping;
+ struct stmt *counter, *counter_elem;
+ LIST_HEAD(concat_list);
+
+ counter = zap_counter(ctx, i);
+ __merge_concat(ctx, i, merge, &concat_list);
+
+ list_for_each_entry_safe(concat, next, &concat_list, list) {
+ list_del(&concat->list);
+ elem = set_elem_expr_alloc(&internal_location, concat);
+ if (counter) {
+ counter_elem = counter_stmt_alloc(&counter->location);
+ list_add_tail(&counter_elem->list, &elem->stmt_list);
+ }
+
+ mapping = mapping_expr_alloc(&internal_location, elem,
+ expr_get(verdict->expr));
+ compound_expr_add(set, mapping);
+ }
+ stmt_free(counter);
+}
+
static void merge_concat_stmts_vmap(const struct optimize_ctx *ctx,
uint32_t from, uint32_t to,
const struct merge *merge)
{
struct stmt *orig_stmt = ctx->stmt_matrix[from][merge->stmt[0]];
- struct expr *concat_a, *concat_b, *expr, *set;
- struct stmt *stmt, *stmt_a, *stmt_b, *verdict;
- uint32_t i, j;
+ struct stmt *stmt, *stmt_a, *verdict;
+ struct expr *concat_a, *expr, *set;
+ uint32_t i;
int k;
k = stmt_verdict_find(ctx);
@@ -711,21 +928,13 @@ static void merge_concat_stmts_vmap(const struct optimize_ctx *ctx,
set->set_flags |= NFT_SET_ANONYMOUS;
for (i = from; i <= to; i++) {
- concat_b = concat_expr_alloc(&internal_location);
- for (j = 0; j < merge->num_stmts; j++) {
- stmt_b = ctx->stmt_matrix[i][merge->stmt[j]];
- expr = stmt_b->expr->right;
- compound_expr_add(concat_b, expr_get(expr));
- }
verdict = ctx->stmt_matrix[i][k];
- build_verdict_map(concat_b, verdict, set);
- expr_free(concat_b);
+ __merge_concat_stmts_vmap(ctx, i, merge, set, verdict);
}
expr = map_expr_alloc(&internal_location, concat_a, set);
stmt = verdict_stmt_alloc(&internal_location, expr);
- remove_counter(ctx, from);
list_add(&stmt->list, &orig_stmt->list);
list_del(&orig_stmt->list);
stmt_free(orig_stmt);
@@ -766,12 +975,35 @@ static bool stmt_verdict_cmp(const struct optimize_ctx *ctx,
return true;
}
-static int stmt_nat_find(const struct optimize_ctx *ctx)
+static int stmt_nat_type(const struct optimize_ctx *ctx, int from,
+ enum nft_nat_etypes *nat_type)
+{
+ uint32_t j;
+
+ for (j = 0; j < ctx->num_stmts; j++) {
+ if (!ctx->stmt_matrix[from][j])
+ continue;
+
+ if (ctx->stmt_matrix[from][j]->type == STMT_NAT) {
+ *nat_type = ctx->stmt_matrix[from][j]->nat.type;
+ return 0;
+ }
+ }
+
+ return -1;
+}
+
+static int stmt_nat_find(const struct optimize_ctx *ctx, int from)
{
+ enum nft_nat_etypes nat_type;
uint32_t i;
+ if (stmt_nat_type(ctx, from, &nat_type) < 0)
+ return -1;
+
for (i = 0; i < ctx->num_stmts; i++) {
- if (ctx->stmt[i]->ops->type != STMT_NAT)
+ if (ctx->stmt[i]->type != STMT_NAT ||
+ ctx->stmt[i]->nat.type != nat_type)
continue;
return i;
@@ -784,16 +1016,24 @@ static struct expr *stmt_nat_expr(struct stmt *nat_stmt)
{
struct expr *nat_expr;
+ assert(nat_stmt->type == STMT_NAT);
+
if (nat_stmt->nat.proto) {
- nat_expr = concat_expr_alloc(&internal_location);
- compound_expr_add(nat_expr, expr_get(nat_stmt->nat.addr));
- compound_expr_add(nat_expr, expr_get(nat_stmt->nat.proto));
+ if (nat_stmt->nat.addr) {
+ nat_expr = concat_expr_alloc(&internal_location);
+ compound_expr_add(nat_expr, expr_get(nat_stmt->nat.addr));
+ compound_expr_add(nat_expr, expr_get(nat_stmt->nat.proto));
+ } else {
+ nat_expr = expr_get(nat_stmt->nat.proto);
+ }
expr_free(nat_stmt->nat.proto);
nat_stmt->nat.proto = NULL;
} else {
nat_expr = expr_get(nat_stmt->nat.addr);
}
+ assert(nat_expr);
+
return nat_expr;
}
@@ -802,11 +1042,11 @@ static void merge_nat(const struct optimize_ctx *ctx,
const struct merge *merge)
{
struct expr *expr, *set, *elem, *nat_expr, *mapping, *left;
+ int k, family = NFPROTO_UNSPEC;
struct stmt *stmt, *nat_stmt;
uint32_t i;
- int k;
- k = stmt_nat_find(ctx);
+ k = stmt_nat_find(ctx, from);
assert(k >= 0);
set = set_expr_alloc(&internal_location, NULL);
@@ -826,11 +1066,23 @@ static void merge_nat(const struct optimize_ctx *ctx,
stmt = ctx->stmt_matrix[from][merge->stmt[0]];
left = expr_get(stmt->expr->left);
+ if (left->etype == EXPR_PAYLOAD) {
+ if (left->payload.desc == &proto_ip)
+ family = NFPROTO_IPV4;
+ else if (left->payload.desc == &proto_ip6)
+ family = NFPROTO_IPV6;
+ }
expr = map_expr_alloc(&internal_location, left, set);
nat_stmt = ctx->stmt_matrix[from][k];
+ if (nat_stmt->nat.family == NFPROTO_UNSPEC)
+ nat_stmt->nat.family = family;
+
expr_free(nat_stmt->nat.addr);
- nat_stmt->nat.addr = expr;
+ if (nat_stmt->nat.type == NFT_NAT_REDIR)
+ nat_stmt->nat.proto = expr;
+ else
+ nat_stmt->nat.addr = expr;
remove_counter(ctx, from);
list_del(&stmt->list);
@@ -842,11 +1094,11 @@ static void merge_concat_nat(const struct optimize_ctx *ctx,
const struct merge *merge)
{
struct expr *expr, *set, *elem, *nat_expr, *mapping, *left, *concat;
+ int k, family = NFPROTO_UNSPEC;
struct stmt *stmt, *nat_stmt;
uint32_t i, j;
- int k;
- k = stmt_nat_find(ctx);
+ k = stmt_nat_find(ctx, from);
assert(k >= 0);
set = set_expr_alloc(&internal_location, NULL);
@@ -873,11 +1125,20 @@ static void merge_concat_nat(const struct optimize_ctx *ctx,
for (j = 0; j < merge->num_stmts; j++) {
stmt = ctx->stmt_matrix[from][merge->stmt[j]];
left = stmt->expr->left;
+ if (left->etype == EXPR_PAYLOAD) {
+ if (left->payload.desc == &proto_ip)
+ family = NFPROTO_IPV4;
+ else if (left->payload.desc == &proto_ip6)
+ family = NFPROTO_IPV6;
+ }
compound_expr_add(concat, expr_get(left));
}
expr = map_expr_alloc(&internal_location, concat, set);
nat_stmt = ctx->stmt_matrix[from][k];
+ if (nat_stmt->nat.family == NFPROTO_UNSPEC)
+ nat_stmt->nat.family = family;
+
expr_free(nat_stmt->nat.addr);
nat_stmt->nat.addr = expr;
@@ -922,22 +1183,36 @@ static void rule_optimize_print(struct output_ctx *octx,
fprintf(octx->error_fp, "%s\n", line);
}
-static enum stmt_types merge_stmt_type(const struct optimize_ctx *ctx)
+enum {
+ MERGE_BY_VERDICT,
+ MERGE_BY_NAT_MAP,
+ MERGE_BY_NAT,
+};
+
+static uint32_t merge_stmt_type(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to)
{
- uint32_t i;
+ const struct stmt *stmt;
+ uint32_t i, j;
- for (i = 0; i < ctx->num_stmts; i++) {
- switch (ctx->stmt[i]->ops->type) {
- case STMT_VERDICT:
- case STMT_NAT:
- return ctx->stmt[i]->ops->type;
- default:
- continue;
+ for (i = from; i <= to; i++) {
+ for (j = 0; j < ctx->num_stmts; j++) {
+ stmt = ctx->stmt_matrix[i][j];
+ if (!stmt)
+ continue;
+ if (stmt->type == STMT_NAT) {
+ if ((stmt->nat.type == NFT_NAT_REDIR &&
+ !stmt->nat.proto) ||
+ stmt->nat.type == NFT_NAT_MASQ)
+ return MERGE_BY_NAT;
+
+ return MERGE_BY_NAT_MAP;
+ }
}
}
- /* actually no verdict, this assumes rules have the same verdict. */
- return STMT_VERDICT;
+ /* merge by verdict, even if no verdict is specified. */
+ return MERGE_BY_VERDICT;
}
static void merge_rules(const struct optimize_ctx *ctx,
@@ -945,14 +1220,14 @@ static void merge_rules(const struct optimize_ctx *ctx,
const struct merge *merge,
struct output_ctx *octx)
{
- enum stmt_types stmt_type;
+ uint32_t merge_type;
bool same_verdict;
uint32_t i;
- stmt_type = merge_stmt_type(ctx);
+ merge_type = merge_stmt_type(ctx, from, to);
- switch (stmt_type) {
- case STMT_VERDICT:
+ switch (merge_type) {
+ case MERGE_BY_VERDICT:
same_verdict = stmt_verdict_cmp(ctx, from, to);
if (merge->num_stmts > 1) {
if (same_verdict)
@@ -966,18 +1241,24 @@ static void merge_rules(const struct optimize_ctx *ctx,
merge_stmts_vmap(ctx, from, to, merge);
}
break;
- case STMT_NAT:
+ case MERGE_BY_NAT_MAP:
if (merge->num_stmts > 1)
merge_concat_nat(ctx, from, to, merge);
else
merge_nat(ctx, from, to, merge);
break;
+ case MERGE_BY_NAT:
+ if (merge->num_stmts > 1)
+ merge_concat_stmts(ctx, from, to, merge);
+ else
+ merge_stmts(ctx, from, to, merge);
+ break;
default:
assert(0);
}
if (ctx->rule[from]->comment) {
- xfree(ctx->rule[from]->comment);
+ free_const(ctx->rule[from]->comment);
ctx->rule[from]->comment = NULL;
}
@@ -1011,15 +1292,41 @@ static bool stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b)
return __stmt_type_eq(stmt_a, stmt_b, true);
}
+static bool stmt_is_mergeable(const struct stmt *stmt)
+{
+ if (!stmt)
+ return false;
+
+ switch (stmt->type) {
+ case STMT_VERDICT:
+ if (stmt->expr->etype == EXPR_MAP)
+ return true;
+ break;
+ case STMT_EXPRESSION:
+ case STMT_NAT:
+ return true;
+ default:
+ break;
+ }
+
+ return false;
+}
+
static bool rules_eq(const struct optimize_ctx *ctx, int i, int j)
{
- uint32_t k;
+ uint32_t k, mergeable = 0;
for (k = 0; k < ctx->num_stmts; k++) {
+ if (stmt_is_mergeable(ctx->stmt_matrix[i][k]))
+ mergeable++;
+
if (!stmt_type_eq(ctx->stmt_matrix[i][k], ctx->stmt_matrix[j][k]))
return false;
}
+ if (mergeable == 0)
+ return false;
+
return true;
}
@@ -1043,10 +1350,11 @@ static int chain_optimize(struct nft_ctx *nft, struct list_head *rules)
ctx->num_rules++;
}
- ctx->rule = xzalloc(sizeof(ctx->rule) * ctx->num_rules);
- ctx->stmt_matrix = xzalloc(sizeof(struct stmt *) * ctx->num_rules);
+ ctx->rule = xzalloc(sizeof(*ctx->rule) * ctx->num_rules);
+ ctx->stmt_matrix = xzalloc(sizeof(*ctx->stmt_matrix) * ctx->num_rules);
for (i = 0; i < ctx->num_rules; i++)
- ctx->stmt_matrix[i] = xzalloc(sizeof(struct stmt *) * MAX_STMTS);
+ ctx->stmt_matrix[i] = xzalloc_array(MAX_STMTS,
+ sizeof(**ctx->stmt_matrix));
merge = xzalloc(sizeof(*merge) * ctx->num_rules);
@@ -1078,14 +1386,48 @@ static int chain_optimize(struct nft_ctx *nft, struct list_head *rules)
}
}
- /* Step 4: Infer how to merge the candidate rules */
+ /* Step 4: Invalidate merge in case of duplicated keys in set/map. */
for (k = 0; k < num_merges; k++) {
+ uint32_t r1, r2;
+
+ i = merge[k].rule_from;
+
+ for (r1 = i; r1 < i + merge[k].num_rules; r1++) {
+ for (r2 = r1 + 1; r2 < i + merge[k].num_rules; r2++) {
+ bool match_same_value = true, match_seen = false;
+
+ for (m = 0; m < ctx->num_stmts; m++) {
+ if (!ctx->stmt_matrix[r1][m])
+ continue;
+
+ switch (ctx->stmt_matrix[r1][m]->type) {
+ case STMT_EXPRESSION:
+ match_seen = true;
+ if (!__expr_cmp(ctx->stmt_matrix[r1][m]->expr->right,
+ ctx->stmt_matrix[r2][m]->expr->right))
+ match_same_value = false;
+ break;
+ default:
+ break;
+ }
+ }
+ if (match_seen && match_same_value)
+ merge[k].skip = true;
+ }
+ }
+ }
+
+ /* Step 5: Infer how to merge the candidate rules */
+ for (k = 0; k < num_merges; k++) {
+ if (merge[k].skip)
+ continue;
+
i = merge[k].rule_from;
for (m = 0; m < ctx->num_stmts; m++) {
if (!ctx->stmt_matrix[i][m])
continue;
- switch (ctx->stmt_matrix[i][m]->ops->type) {
+ switch (ctx->stmt_matrix[i][m]->type) {
case STMT_EXPRESSION:
merge[k].stmt[merge[k].num_stmts++] = m;
break;
@@ -1103,16 +1445,16 @@ static int chain_optimize(struct nft_ctx *nft, struct list_head *rules)
}
ret = 0;
for (i = 0; i < ctx->num_rules; i++)
- xfree(ctx->stmt_matrix[i]);
+ free(ctx->stmt_matrix[i]);
- xfree(ctx->stmt_matrix);
- xfree(merge);
+ free(ctx->stmt_matrix);
+ free(merge);
err:
for (i = 0; i < ctx->num_stmts; i++)
stmt_free(ctx->stmt[i]);
- xfree(ctx->rule);
- xfree(ctx);
+ free(ctx->rule);
+ free(ctx);
return ret;
}
@@ -1146,7 +1488,7 @@ static int cmd_optimize(struct nft_ctx *nft, struct cmd *cmd)
int nft_optimize(struct nft_ctx *nft, struct list_head *cmds)
{
struct cmd *cmd;
- int ret;
+ int ret = 0;
list_for_each_entry(cmd, cmds, list) {
switch (cmd->op) {