/* * Copyright (c) 2021 Pablo Neira Ayuso * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 (or any * later) as published by the Free Software Foundation. */ /* Funded through the NGI0 PET Fund established by NLnet (https://nlnet.nl) * with support from the European Commission's Next Generation Internet * programme. */ #define _GNU_SOURCE #include #include #include #include #include #include #include #include #include #include #define MAX_STMTS 32 struct optimize_ctx { struct stmt *stmt[MAX_STMTS]; uint32_t num_stmts; struct stmt ***stmt_matrix; struct rule **rule; uint32_t num_rules; }; static bool __expr_cmp(const struct expr *expr_a, const struct expr *expr_b) { if (expr_a->etype != expr_b->etype) return false; switch (expr_a->etype) { case EXPR_PAYLOAD: if (expr_a->payload.base != expr_b->payload.base) return false; if (expr_a->payload.offset != expr_b->payload.offset) return false; if (expr_a->payload.desc != expr_b->payload.desc) return false; if (expr_a->payload.inner_desc != expr_b->payload.inner_desc) return false; if (expr_a->payload.tmpl != expr_b->payload.tmpl) return false; break; case EXPR_EXTHDR: if (expr_a->exthdr.desc != expr_b->exthdr.desc) return false; if (expr_a->exthdr.tmpl != expr_b->exthdr.tmpl) return false; break; case EXPR_META: if (expr_a->meta.key != expr_b->meta.key) return false; if (expr_a->meta.base != expr_b->meta.base) return false; break; case EXPR_CT: if (expr_a->ct.key != expr_b->ct.key) return false; if (expr_a->ct.base != expr_b->ct.base) return false; if (expr_a->ct.direction != expr_b->ct.direction) return false; if (expr_a->ct.nfproto != expr_b->ct.nfproto) return false; break; case EXPR_RT: if (expr_a->rt.key != expr_b->rt.key) return false; break; case EXPR_SOCKET: if (expr_a->socket.key != expr_b->socket.key) return false; if (expr_a->socket.level != expr_b->socket.level) return false; break; case EXPR_OSF: if (expr_a->osf.ttl != expr_b->osf.ttl) return false; if (expr_a->osf.flags != expr_b->osf.flags) return false; break; case EXPR_XFRM: if (expr_a->xfrm.key != expr_b->xfrm.key) return false; if (expr_a->xfrm.direction != expr_b->xfrm.direction) return false; break; case EXPR_FIB: if (expr_a->fib.flags != expr_b->fib.flags) return false; if (expr_a->fib.result != expr_b->fib.result) return false; break; case EXPR_NUMGEN: if (expr_a->numgen.type != expr_b->numgen.type) return false; if (expr_a->numgen.mod != expr_b->numgen.mod) return false; if (expr_a->numgen.offset != expr_b->numgen.offset) return false; break; case EXPR_HASH: if (expr_a->hash.mod != expr_b->hash.mod) return false; if (expr_a->hash.seed_set != expr_b->hash.seed_set) return false; if (expr_a->hash.seed != expr_b->hash.seed) return false; if (expr_a->hash.offset != expr_b->hash.offset) return false; if (expr_a->hash.type != expr_b->hash.type) return false; break; case EXPR_BINOP: return __expr_cmp(expr_a->left, expr_b->left); default: return false; } return true; } static bool stmt_expr_supported(const struct expr *expr) { switch (expr->right->etype) { case EXPR_SYMBOL: case EXPR_RANGE: case EXPR_PREFIX: case EXPR_SET: case EXPR_LIST: case EXPR_VALUE: return true; default: break; } return false; } static bool expr_symbol_set(const struct expr *expr) { return expr->right->etype == EXPR_SYMBOL && expr->right->symtype == SYMBOL_SET; } static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b, bool fully_compare) { struct expr *expr_a, *expr_b; if (stmt_a->ops->type != stmt_b->ops->type) return false; switch (stmt_a->ops->type) { case STMT_EXPRESSION: expr_a = stmt_a->expr; expr_b = stmt_b->expr; if (expr_a->op != expr_b->op) return false; if (expr_a->op != OP_IMPLICIT && expr_a->op != OP_EQ) return false; if (fully_compare) { if (!stmt_expr_supported(expr_a) || !stmt_expr_supported(expr_b)) return false; if (expr_symbol_set(expr_a) || expr_symbol_set(expr_b)) return false; } return __expr_cmp(expr_a->left, expr_b->left); case STMT_COUNTER: case STMT_NOTRACK: break; case STMT_VERDICT: if (!fully_compare) break; expr_a = stmt_a->expr; expr_b = stmt_b->expr; if (expr_a->etype != expr_b->etype) return false; if (expr_a->etype == EXPR_MAP && expr_b->etype == EXPR_MAP) return __expr_cmp(expr_a->map, expr_b->map); break; case STMT_LOG: if (stmt_a->log.snaplen != stmt_b->log.snaplen || stmt_a->log.group != stmt_b->log.group || stmt_a->log.qthreshold != stmt_b->log.qthreshold || stmt_a->log.level != stmt_b->log.level || stmt_a->log.logflags != stmt_b->log.logflags || stmt_a->log.flags != stmt_b->log.flags) return false; if (!!stmt_a->log.prefix ^ !!stmt_b->log.prefix) return false; if (!stmt_a->log.prefix) return true; if (stmt_a->log.prefix->etype != EXPR_VALUE || stmt_b->log.prefix->etype != EXPR_VALUE || mpz_cmp(stmt_a->log.prefix->value, stmt_b->log.prefix->value)) return false; break; case STMT_REJECT: if (stmt_a->reject.family != stmt_b->reject.family || stmt_a->reject.type != stmt_b->reject.type || stmt_a->reject.icmp_code != stmt_b->reject.icmp_code) return false; if (!!stmt_a->reject.expr ^ !!stmt_b->reject.expr) return false; if (!stmt_a->reject.expr) return true; if (__expr_cmp(stmt_a->reject.expr, stmt_b->reject.expr)) return false; break; case STMT_NAT: if (stmt_a->nat.type != stmt_b->nat.type || stmt_a->nat.flags != stmt_b->nat.flags || stmt_a->nat.family != stmt_b->nat.family || stmt_a->nat.type_flags != stmt_b->nat.type_flags) return false; switch (stmt_a->nat.type) { case NFT_NAT_SNAT: case NFT_NAT_DNAT: if ((stmt_a->nat.addr && stmt_a->nat.addr->etype != EXPR_SYMBOL && stmt_a->nat.addr->etype != EXPR_RANGE) || (stmt_b->nat.addr && stmt_b->nat.addr->etype != EXPR_SYMBOL && stmt_b->nat.addr->etype != EXPR_RANGE) || (stmt_a->nat.proto && stmt_a->nat.proto->etype != EXPR_SYMBOL && stmt_a->nat.proto->etype != EXPR_RANGE) || (stmt_b->nat.proto && stmt_b->nat.proto->etype != EXPR_SYMBOL && stmt_b->nat.proto->etype != EXPR_RANGE)) return false; break; case NFT_NAT_MASQ: break; case NFT_NAT_REDIR: if ((stmt_a->nat.proto && stmt_a->nat.proto->etype != EXPR_SYMBOL && stmt_a->nat.proto->etype != EXPR_RANGE) || (stmt_b->nat.proto && stmt_b->nat.proto->etype != EXPR_SYMBOL && stmt_b->nat.proto->etype != EXPR_RANGE)) return false; /* it should be possible to infer implicit redirections * such as: * * tcp dport 1234 redirect * tcp dport 3456 redirect to :7890 * merge: * redirect to tcp dport map { 1234 : 1234, 3456 : 7890 } * * currently not implemented. */ if (fully_compare && stmt_a->nat.type == NFT_NAT_REDIR && stmt_b->nat.type == NFT_NAT_REDIR && (!!stmt_a->nat.proto ^ !!stmt_b->nat.proto)) return false; break; default: assert(0); } return true; default: /* ... Merging anything else is yet unsupported. */ return false; } return true; } static bool expr_verdict_eq(const struct expr *expr_a, const struct expr *expr_b) { if (expr_a->verdict != expr_b->verdict) return false; if (expr_a->chain && expr_b->chain) { if (expr_a->chain->etype != expr_b->chain->etype) return false; if (expr_a->chain->etype == EXPR_VALUE && strcmp(expr_a->chain->identifier, expr_b->chain->identifier)) return false; } else if (expr_a->chain || expr_b->chain) { return false; } return true; } static bool stmt_verdict_eq(const struct stmt *stmt_a, const struct stmt *stmt_b) { struct expr *expr_a, *expr_b; assert (stmt_a->ops->type == STMT_VERDICT); expr_a = stmt_a->expr; expr_b = stmt_b->expr; if (expr_a->etype == EXPR_VERDICT && expr_b->etype == EXPR_VERDICT) return expr_verdict_eq(expr_a, expr_b); if (expr_a->etype == EXPR_MAP && expr_b->etype == EXPR_MAP) return __expr_cmp(expr_a->map, expr_b->map); return false; } static bool stmt_type_find(struct optimize_ctx *ctx, const struct stmt *stmt) { bool unsupported_exists = false; uint32_t i; for (i = 0; i < ctx->num_stmts; i++) { if (ctx->stmt[i]->ops->type == STMT_INVALID) unsupported_exists = true; if (__stmt_type_eq(stmt, ctx->stmt[i], false)) return true; } switch (stmt->ops->type) { case STMT_EXPRESSION: case STMT_VERDICT: case STMT_COUNTER: case STMT_NOTRACK: case STMT_LOG: case STMT_NAT: case STMT_REJECT: break; default: /* add unsupported statement only once to statement matrix. */ if (unsupported_exists) return true; break; } return false; } static struct stmt_ops unsupported_stmt_ops = { .type = STMT_INVALID, .name = "unsupported", }; static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule) { struct stmt *stmt, *clone; list_for_each_entry(stmt, &rule->stmts, list) { if (stmt_type_find(ctx, stmt)) continue; /* No refcounter available in statement objects, clone it to * to store in the array of selectors. */ clone = stmt_alloc(&internal_location, stmt->ops); switch (stmt->ops->type) { case STMT_EXPRESSION: if (stmt->expr->op != OP_IMPLICIT && stmt->expr->op != OP_EQ) { clone->ops = &unsupported_stmt_ops; break; } if (stmt->expr->left->etype == EXPR_CONCAT) { clone->ops = &unsupported_stmt_ops; break; } case STMT_VERDICT: clone->expr = expr_get(stmt->expr); break; case STMT_COUNTER: case STMT_NOTRACK: break; case STMT_LOG: memcpy(&clone->log, &stmt->log, sizeof(clone->log)); if (stmt->log.prefix) clone->log.prefix = expr_get(stmt->log.prefix); break; case STMT_NAT: if ((stmt->nat.addr && stmt->nat.addr->etype == EXPR_MAP) || (stmt->nat.proto && stmt->nat.proto->etype == EXPR_MAP)) { clone->ops = &unsupported_stmt_ops; break; } clone->nat.type = stmt->nat.type; clone->nat.family = stmt->nat.family; if (stmt->nat.addr) clone->nat.addr = expr_clone(stmt->nat.addr); if (stmt->nat.proto) clone->nat.proto = expr_clone(stmt->nat.proto); clone->nat.flags = stmt->nat.flags; clone->nat.type_flags = stmt->nat.type_flags; break; case STMT_REJECT: if (stmt->reject.expr) clone->reject.expr = expr_get(stmt->reject.expr); clone->reject.type = stmt->reject.type; clone->reject.icmp_code = stmt->reject.icmp_code; clone->reject.family = stmt->reject.family; break; default: clone->ops = &unsupported_stmt_ops; break; } ctx->stmt[ctx->num_stmts++] = clone; if (ctx->num_stmts >= MAX_STMTS) return -1; } return 0; } static int unsupported_in_stmt_matrix(const struct optimize_ctx *ctx) { uint32_t i; for (i = 0; i < ctx->num_stmts; i++) { if (ctx->stmt[i]->ops->type == STMT_INVALID) return i; } /* this should not happen. */ return -1; } static int cmd_stmt_find_in_stmt_matrix(struct optimize_ctx *ctx, struct stmt *stmt) { uint32_t i; for (i = 0; i < ctx->num_stmts; i++) { if (__stmt_type_eq(stmt, ctx->stmt[i], false)) return i; } return -1; } static struct stmt unsupported_stmt = { .ops = &unsupported_stmt_ops, }; static void rule_build_stmt_matrix_stmts(struct optimize_ctx *ctx, struct rule *rule, uint32_t *i) { struct stmt *stmt; int k; list_for_each_entry(stmt, &rule->stmts, list) { k = cmd_stmt_find_in_stmt_matrix(ctx, stmt); if (k < 0) { k = unsupported_in_stmt_matrix(ctx); assert(k >= 0); ctx->stmt_matrix[*i][k] = &unsupported_stmt; continue; } ctx->stmt_matrix[*i][k] = stmt; } ctx->rule[(*i)++] = rule; } static int stmt_verdict_find(const struct optimize_ctx *ctx) { uint32_t i; for (i = 0; i < ctx->num_stmts; i++) { if (ctx->stmt[i]->ops->type != STMT_VERDICT) continue; return i; } return -1; } struct merge { /* interval of rules to be merged */ uint32_t rule_from; uint32_t num_rules; /* statements to be merged (index relative to statement matrix) */ uint32_t stmt[MAX_STMTS]; uint32_t num_stmts; }; static void merge_expr_stmts(const struct optimize_ctx *ctx, uint32_t from, uint32_t to, const struct merge *merge, struct stmt *stmt_a) { struct expr *expr_a, *expr_b, *set, *elem; struct stmt *stmt_b; uint32_t i; set = set_expr_alloc(&internal_location, NULL); set->set_flags |= NFT_SET_ANONYMOUS; expr_a = stmt_a->expr->right; elem = set_elem_expr_alloc(&internal_location, expr_get(expr_a)); compound_expr_add(set, elem); for (i = from + 1; i <= to; i++) { stmt_b = ctx->stmt_matrix[i][merge->stmt[0]]; expr_b = stmt_b->expr->right; elem = set_elem_expr_alloc(&internal_location, expr_get(expr_b)); compound_expr_add(set, elem); } expr_free(stmt_a->expr->right); stmt_a->expr->right = set; } static void merge_vmap(const struct optimize_ctx *ctx, struct stmt *stmt_a, const struct stmt *stmt_b) { struct expr *mappings, *mapping, *expr; mappings = stmt_b->expr->mappings; list_for_each_entry(expr, &mappings->expressions, list) { mapping = expr_clone(expr); compound_expr_add(stmt_a->expr->mappings, mapping); } } static void merge_verdict_stmts(const struct optimize_ctx *ctx, uint32_t from, uint32_t to, const struct merge *merge, struct stmt *stmt_a) { struct stmt *stmt_b; uint32_t i; for (i = from + 1; i <= to; i++) { stmt_b = ctx->stmt_matrix[i][merge->stmt[0]]; switch (stmt_b->ops->type) { case STMT_VERDICT: switch (stmt_b->expr->etype) { case EXPR_MAP: merge_vmap(ctx, stmt_a, stmt_b); break; default: assert(0); } break; default: assert(0); break; } } } static void merge_stmts(const struct optimize_ctx *ctx, uint32_t from, uint32_t to, const struct merge *merge) { struct stmt *stmt_a = ctx->stmt_matrix[from][merge->stmt[0]]; switch (stmt_a->ops->type) { case STMT_EXPRESSION: merge_expr_stmts(ctx, from, to, merge, stmt_a); break; case STMT_VERDICT: merge_verdict_stmts(ctx, from, to, merge, stmt_a); break; default: assert(0); } } static void __merge_concat(const struct optimize_ctx *ctx, uint32_t i, const struct merge *merge, struct list_head *concat_list) { struct expr *concat, *next, *expr, *concat_clone, *clone; LIST_HEAD(pending_list); struct stmt *stmt_a; uint32_t k; concat = concat_expr_alloc(&internal_location); list_add(&concat->list, concat_list); for (k = 0; k < merge->num_stmts; k++) { list_for_each_entry_safe(concat, next, concat_list, list) { stmt_a = ctx->stmt_matrix[i][merge->stmt[k]]; switch (stmt_a->expr->right->etype) { case EXPR_SET: list_for_each_entry(expr, &stmt_a->expr->right->expressions, list) { concat_clone = expr_clone(concat); clone = expr_clone(expr->key); compound_expr_add(concat_clone, clone); list_add_tail(&concat_clone->list, &pending_list); } list_del(&concat->list); expr_free(concat); break; case EXPR_SYMBOL: case EXPR_VALUE: case EXPR_PREFIX: case EXPR_RANGE: clone = expr_clone(stmt_a->expr->right); compound_expr_add(concat, clone); break; default: assert(0); break; } } list_splice_init(&pending_list, concat_list); } } static void __merge_concat_stmts(const struct optimize_ctx *ctx, uint32_t i, const struct merge *merge, struct expr *set) { struct expr *concat, *next, *elem; LIST_HEAD(concat_list); __merge_concat(ctx, i, merge, &concat_list); list_for_each_entry_safe(concat, next, &concat_list, list) { list_del(&concat->list); elem = set_elem_expr_alloc(&internal_location, concat); compound_expr_add(set, elem); } } static void merge_concat_stmts(const struct optimize_ctx *ctx, uint32_t from, uint32_t to, const struct merge *merge) { struct stmt *stmt, *stmt_a; struct expr *concat, *set; uint32_t i, k; stmt = ctx->stmt_matrix[from][merge->stmt[0]]; /* build concatenation of selectors, eg. ifname . ip daddr . tcp dport */ concat = concat_expr_alloc(&internal_location); for (k = 0; k < merge->num_stmts; k++) { stmt_a = ctx->stmt_matrix[from][merge->stmt[k]]; compound_expr_add(concat, expr_get(stmt_a->expr->left)); } expr_free(stmt->expr->left); stmt->expr->left = concat; /* build set data contenation, eg. { eth0 . 1.1.1.1 . 22 } */ set = set_expr_alloc(&internal_location, NULL); set->set_flags |= NFT_SET_ANONYMOUS; for (i = from; i <= to; i++) __merge_concat_stmts(ctx, i, merge, set); expr_free(stmt->expr->right); stmt->expr->right = set; for (k = 1; k < merge->num_stmts; k++) { stmt_a = ctx->stmt_matrix[from][merge->stmt[k]]; list_del(&stmt_a->list); stmt_free(stmt_a); } } static void build_verdict_map(struct expr *expr, struct stmt *verdict, struct expr *set, struct stmt *counter) { struct expr *item, *elem, *mapping; switch (expr->etype) { case EXPR_LIST: list_for_each_entry(item, &expr->expressions, list) { elem = set_elem_expr_alloc(&internal_location, expr_get(item)); if (counter) list_add_tail(&counter->list, &elem->stmt_list); mapping = mapping_expr_alloc(&internal_location, elem, expr_get(verdict->expr)); compound_expr_add(set, mapping); } break; case EXPR_SET: list_for_each_entry(item, &expr->expressions, list) { elem = set_elem_expr_alloc(&internal_location, expr_get(item->key)); if (counter) list_add_tail(&counter->list, &elem->stmt_list); mapping = mapping_expr_alloc(&internal_location, elem, expr_get(verdict->expr)); compound_expr_add(set, mapping); } break; case EXPR_PREFIX: case EXPR_RANGE: case EXPR_VALUE: case EXPR_SYMBOL: case EXPR_CONCAT: elem = set_elem_expr_alloc(&internal_location, expr_get(expr)); if (counter) list_add_tail(&counter->list, &elem->stmt_list); mapping = mapping_expr_alloc(&internal_location, elem, expr_get(verdict->expr)); compound_expr_add(set, mapping); break; default: assert(0); break; } } static void remove_counter(const struct optimize_ctx *ctx, uint32_t from) { struct stmt *stmt; uint32_t i; /* remove counter statement */ for (i = 0; i < ctx->num_stmts; i++) { stmt = ctx->stmt_matrix[from][i]; if (!stmt) continue; if (stmt->ops->type == STMT_COUNTER) { list_del(&stmt->list); stmt_free(stmt); } } } static struct stmt *zap_counter(const struct optimize_ctx *ctx, uint32_t from) { struct stmt *stmt; uint32_t i; /* remove counter statement */ for (i = 0; i < ctx->num_stmts; i++) { stmt = ctx->stmt_matrix[from][i]; if (!stmt) continue; if (stmt->ops->type == STMT_COUNTER) { list_del(&stmt->list); return stmt; } } return NULL; } static void merge_stmts_vmap(const struct optimize_ctx *ctx, uint32_t from, uint32_t to, const struct merge *merge) { struct stmt *stmt_a = ctx->stmt_matrix[from][merge->stmt[0]]; struct stmt *stmt_b, *verdict_a, *verdict_b, *stmt; struct expr *expr_a, *expr_b, *expr, *left, *set; struct stmt *counter; uint32_t i; int k; k = stmt_verdict_find(ctx); assert(k >= 0); set = set_expr_alloc(&internal_location, NULL); set->set_flags |= NFT_SET_ANONYMOUS; expr_a = stmt_a->expr->right; verdict_a = ctx->stmt_matrix[from][k]; counter = zap_counter(ctx, from); build_verdict_map(expr_a, verdict_a, set, counter); for (i = from + 1; i <= to; i++) { stmt_b = ctx->stmt_matrix[i][merge->stmt[0]]; expr_b = stmt_b->expr->right; verdict_b = ctx->stmt_matrix[i][k]; counter = zap_counter(ctx, i); build_verdict_map(expr_b, verdict_b, set, counter); } left = expr_get(stmt_a->expr->left); expr = map_expr_alloc(&internal_location, left, set); stmt = verdict_stmt_alloc(&internal_location, expr); list_add(&stmt->list, &stmt_a->list); list_del(&stmt_a->list); stmt_free(stmt_a); list_del(&verdict_a->list); stmt_free(verdict_a); } static void __merge_concat_stmts_vmap(const struct optimize_ctx *ctx, uint32_t i, const struct merge *merge, struct expr *set, struct stmt *verdict) { struct expr *concat, *next, *elem, *mapping; LIST_HEAD(concat_list); struct stmt *counter; counter = zap_counter(ctx, i); __merge_concat(ctx, i, merge, &concat_list); list_for_each_entry_safe(concat, next, &concat_list, list) { list_del(&concat->list); elem = set_elem_expr_alloc(&internal_location, concat); if (counter) list_add_tail(&counter->list, &elem->stmt_list); mapping = mapping_expr_alloc(&internal_location, elem, expr_get(verdict->expr)); compound_expr_add(set, mapping); } } static void merge_concat_stmts_vmap(const struct optimize_ctx *ctx, uint32_t from, uint32_t to, const struct merge *merge) { struct stmt *orig_stmt = ctx->stmt_matrix[from][merge->stmt[0]]; struct stmt *stmt, *stmt_a, *verdict; struct expr *concat_a, *expr, *set; uint32_t i; int k; k = stmt_verdict_find(ctx); assert(k >= 0); /* build concatenation of selectors, eg. ifname . ip daddr . tcp dport */ concat_a = concat_expr_alloc(&internal_location); for (i = 0; i < merge->num_stmts; i++) { stmt_a = ctx->stmt_matrix[from][merge->stmt[i]]; compound_expr_add(concat_a, expr_get(stmt_a->expr->left)); } /* build set data contenation, eg. { eth0 . 1.1.1.1 . 22 : accept } */ set = set_expr_alloc(&internal_location, NULL); set->set_flags |= NFT_SET_ANONYMOUS; for (i = from; i <= to; i++) { verdict = ctx->stmt_matrix[i][k]; __merge_concat_stmts_vmap(ctx, i, merge, set, verdict); } expr = map_expr_alloc(&internal_location, concat_a, set); stmt = verdict_stmt_alloc(&internal_location, expr); list_add(&stmt->list, &orig_stmt->list); list_del(&orig_stmt->list); stmt_free(orig_stmt); for (i = 1; i < merge->num_stmts; i++) { stmt_a = ctx->stmt_matrix[from][merge->stmt[i]]; list_del(&stmt_a->list); stmt_free(stmt_a); } verdict = ctx->stmt_matrix[from][k]; list_del(&verdict->list); stmt_free(verdict); } static bool stmt_verdict_cmp(const struct optimize_ctx *ctx, uint32_t from, uint32_t to) { struct stmt *stmt_a, *stmt_b; uint32_t i; int k; k = stmt_verdict_find(ctx); if (k < 0) return true; for (i = from; i + 1 <= to; i++) { stmt_a = ctx->stmt_matrix[i][k]; stmt_b = ctx->stmt_matrix[i + 1][k]; if (!stmt_a && !stmt_b) continue; if (!stmt_a || !stmt_b) return false; if (!stmt_verdict_eq(stmt_a, stmt_b)) return false; } return true; } static int stmt_nat_type(const struct optimize_ctx *ctx, int from, enum nft_nat_etypes *nat_type) { uint32_t j; for (j = 0; j < ctx->num_stmts; j++) { if (!ctx->stmt_matrix[from][j]) continue; if (ctx->stmt_matrix[from][j]->ops->type == STMT_NAT) { *nat_type = ctx->stmt_matrix[from][j]->nat.type; return 0; } } return -1; } static int stmt_nat_find(const struct optimize_ctx *ctx, int from) { enum nft_nat_etypes nat_type; uint32_t i; if (stmt_nat_type(ctx, from, &nat_type) < 0) return -1; for (i = 0; i < ctx->num_stmts; i++) { if (ctx->stmt[i]->ops->type != STMT_NAT || ctx->stmt[i]->nat.type != nat_type) continue; return i; } return -1; } static struct expr *stmt_nat_expr(struct stmt *nat_stmt) { struct expr *nat_expr; assert(nat_stmt->ops->type == STMT_NAT); if (nat_stmt->nat.proto) { if (nat_stmt->nat.addr) { nat_expr = concat_expr_alloc(&internal_location); compound_expr_add(nat_expr, expr_get(nat_stmt->nat.addr)); compound_expr_add(nat_expr, expr_get(nat_stmt->nat.proto)); } else { nat_expr = expr_get(nat_stmt->nat.proto); } expr_free(nat_stmt->nat.proto); nat_stmt->nat.proto = NULL; } else { nat_expr = expr_get(nat_stmt->nat.addr); } assert(nat_expr); return nat_expr; } static void merge_nat(const struct optimize_ctx *ctx, uint32_t from, uint32_t to, const struct merge *merge) { struct expr *expr, *set, *elem, *nat_expr, *mapping, *left; int k, family = NFPROTO_UNSPEC; struct stmt *stmt, *nat_stmt; uint32_t i; k = stmt_nat_find(ctx, from); assert(k >= 0); set = set_expr_alloc(&internal_location, NULL); set->set_flags |= NFT_SET_ANONYMOUS; for (i = from; i <= to; i++) { stmt = ctx->stmt_matrix[i][merge->stmt[0]]; expr = stmt->expr->right; nat_stmt = ctx->stmt_matrix[i][k]; nat_expr = stmt_nat_expr(nat_stmt); elem = set_elem_expr_alloc(&internal_location, expr_get(expr)); mapping = mapping_expr_alloc(&internal_location, elem, nat_expr); compound_expr_add(set, mapping); } stmt = ctx->stmt_matrix[from][merge->stmt[0]]; left = expr_get(stmt->expr->left); if (left->etype == EXPR_PAYLOAD) { if (left->payload.desc == &proto_ip) family = NFPROTO_IPV4; else if (left->payload.desc == &proto_ip6) family = NFPROTO_IPV6; } expr = map_expr_alloc(&internal_location, left, set); nat_stmt = ctx->stmt_matrix[from][k]; if (nat_stmt->nat.family == NFPROTO_UNSPEC) nat_stmt->nat.family = family; expr_free(nat_stmt->nat.addr); if (nat_stmt->nat.type == NFT_NAT_REDIR) nat_stmt->nat.proto = expr; else nat_stmt->nat.addr = expr; remove_counter(ctx, from); list_del(&stmt->list); stmt_free(stmt); } static void merge_concat_nat(const struct optimize_ctx *ctx, uint32_t from, uint32_t to, const struct merge *merge) { struct expr *expr, *set, *elem, *nat_expr, *mapping, *left, *concat; int k, family = NFPROTO_UNSPEC; struct stmt *stmt, *nat_stmt; uint32_t i, j; k = stmt_nat_find(ctx, from); assert(k >= 0); set = set_expr_alloc(&internal_location, NULL); set->set_flags |= NFT_SET_ANONYMOUS; for (i = from; i <= to; i++) { concat = concat_expr_alloc(&internal_location); for (j = 0; j < merge->num_stmts; j++) { stmt = ctx->stmt_matrix[i][merge->stmt[j]]; expr = stmt->expr->right; compound_expr_add(concat, expr_get(expr)); } nat_stmt = ctx->stmt_matrix[i][k]; nat_expr = stmt_nat_expr(nat_stmt); elem = set_elem_expr_alloc(&internal_location, concat); mapping = mapping_expr_alloc(&internal_location, elem, nat_expr); compound_expr_add(set, mapping); } concat = concat_expr_alloc(&internal_location); for (j = 0; j < merge->num_stmts; j++) { stmt = ctx->stmt_matrix[from][merge->stmt[j]]; left = stmt->expr->left; if (left->etype == EXPR_PAYLOAD) { if (left->payload.desc == &proto_ip) family = NFPROTO_IPV4; else if (left->payload.desc == &proto_ip6) family = NFPROTO_IPV6; } compound_expr_add(concat, expr_get(left)); } expr = map_expr_alloc(&internal_location, concat, set); nat_stmt = ctx->stmt_matrix[from][k]; if (nat_stmt->nat.family == NFPROTO_UNSPEC) nat_stmt->nat.family = family; expr_free(nat_stmt->nat.addr); nat_stmt->nat.addr = expr; remove_counter(ctx, from); for (j = 0; j < merge->num_stmts; j++) { stmt = ctx->stmt_matrix[from][merge->stmt[j]]; list_del(&stmt->list); stmt_free(stmt); } } static void rule_optimize_print(struct output_ctx *octx, const struct rule *rule) { const struct location *loc = &rule->location; const struct input_descriptor *indesc = loc->indesc; const char *line = ""; char buf[1024]; switch (indesc->type) { case INDESC_BUFFER: case INDESC_CLI: line = indesc->data; *strchrnul(line, '\n') = '\0'; break; case INDESC_STDIN: line = indesc->data; line += loc->line_offset; *strchrnul(line, '\n') = '\0'; break; case INDESC_FILE: line = line_location(indesc, loc, buf, sizeof(buf)); break; case INDESC_INTERNAL: case INDESC_NETLINK: break; default: BUG("invalid input descriptor type %u\n", indesc->type); } print_location(octx->error_fp, indesc, loc); fprintf(octx->error_fp, "%s\n", line); } enum { MERGE_BY_VERDICT, MERGE_BY_NAT_MAP, MERGE_BY_NAT, }; static uint32_t merge_stmt_type(const struct optimize_ctx *ctx, uint32_t from, uint32_t to) { const struct stmt *stmt; uint32_t i, j; for (i = from; i <= to; i++) { for (j = 0; j < ctx->num_stmts; j++) { stmt = ctx->stmt_matrix[i][j]; if (!stmt) continue; if (stmt->ops->type == STMT_NAT) { if ((stmt->nat.type == NFT_NAT_REDIR && !stmt->nat.proto) || stmt->nat.type == NFT_NAT_MASQ) return MERGE_BY_NAT; return MERGE_BY_NAT_MAP; } } } /* merge by verdict, even if no verdict is specified. */ return MERGE_BY_VERDICT; } static void merge_rules(const struct optimize_ctx *ctx, uint32_t from, uint32_t to, const struct merge *merge, struct output_ctx *octx) { uint32_t merge_type; bool same_verdict; uint32_t i; merge_type = merge_stmt_type(ctx, from, to); switch (merge_type) { case MERGE_BY_VERDICT: same_verdict = stmt_verdict_cmp(ctx, from, to); if (merge->num_stmts > 1) { if (same_verdict) merge_concat_stmts(ctx, from, to, merge); else merge_concat_stmts_vmap(ctx, from, to, merge); } else { if (same_verdict) merge_stmts(ctx, from, to, merge); else merge_stmts_vmap(ctx, from, to, merge); } break; case MERGE_BY_NAT_MAP: if (merge->num_stmts > 1) merge_concat_nat(ctx, from, to, merge); else merge_nat(ctx, from, to, merge); break; case MERGE_BY_NAT: if (merge->num_stmts > 1) merge_concat_stmts(ctx, from, to, merge); else merge_stmts(ctx, from, to, merge); break; default: assert(0); } if (ctx->rule[from]->comment) { xfree(ctx->rule[from]->comment); ctx->rule[from]->comment = NULL; } octx->flags |= NFT_CTX_OUTPUT_STATELESS; fprintf(octx->error_fp, "Merging:\n"); rule_optimize_print(octx, ctx->rule[from]); for (i = from + 1; i <= to; i++) { rule_optimize_print(octx, ctx->rule[i]); list_del(&ctx->rule[i]->list); rule_free(ctx->rule[i]); } fprintf(octx->error_fp, "into:\n\t"); rule_print(ctx->rule[from], octx); fprintf(octx->error_fp, "\n"); octx->flags &= ~NFT_CTX_OUTPUT_STATELESS; } static bool stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b) { if (!stmt_a && !stmt_b) return true; else if (!stmt_a) return false; else if (!stmt_b) return false; return __stmt_type_eq(stmt_a, stmt_b, true); } static bool stmt_is_mergeable(const struct stmt *stmt) { if (!stmt) return false; switch (stmt->ops->type) { case STMT_VERDICT: if (stmt->expr->etype == EXPR_MAP) return true; break; case STMT_EXPRESSION: case STMT_NAT: return true; default: break; } return false; } static bool rules_eq(const struct optimize_ctx *ctx, int i, int j) { uint32_t k, mergeable = 0; for (k = 0; k < ctx->num_stmts; k++) { if (stmt_is_mergeable(ctx->stmt_matrix[i][k])) mergeable++; if (!stmt_type_eq(ctx->stmt_matrix[i][k], ctx->stmt_matrix[j][k])) return false; } if (mergeable == 0) return false; return true; } static int chain_optimize(struct nft_ctx *nft, struct list_head *rules) { struct optimize_ctx *ctx; uint32_t num_merges = 0; struct merge *merge; uint32_t i, j, m, k; struct rule *rule; int ret; ctx = xzalloc(sizeof(*ctx)); /* Step 1: collect statements in rules */ list_for_each_entry(rule, rules, list) { ret = rule_collect_stmts(ctx, rule); if (ret < 0) goto err; ctx->num_rules++; } ctx->rule = xzalloc(sizeof(*ctx->rule) * ctx->num_rules); ctx->stmt_matrix = xzalloc(sizeof(*ctx->stmt_matrix) * ctx->num_rules); for (i = 0; i < ctx->num_rules; i++) ctx->stmt_matrix[i] = xzalloc_array(MAX_STMTS, sizeof(**ctx->stmt_matrix)); merge = xzalloc(sizeof(*merge) * ctx->num_rules); /* Step 2: Build matrix of statements */ i = 0; list_for_each_entry(rule, rules, list) rule_build_stmt_matrix_stmts(ctx, rule, &i); /* Step 3: Look for common selectors for possible rule mergers */ for (i = 0; i < ctx->num_rules; i++) { for (j = i + 1; j < ctx->num_rules; j++) { if (!rules_eq(ctx, i, j)) { if (merge[num_merges].num_rules > 0) num_merges++; i = j - 1; break; } if (merge[num_merges].num_rules > 0) { merge[num_merges].num_rules++; } else { merge[num_merges].rule_from = i; merge[num_merges].num_rules = 2; } } if (j == ctx->num_rules && merge[num_merges].num_rules > 0) { num_merges++; break; } } /* Step 4: Infer how to merge the candidate rules */ for (k = 0; k < num_merges; k++) { i = merge[k].rule_from; for (m = 0; m < ctx->num_stmts; m++) { if (!ctx->stmt_matrix[i][m]) continue; switch (ctx->stmt_matrix[i][m]->ops->type) { case STMT_EXPRESSION: merge[k].stmt[merge[k].num_stmts++] = m; break; case STMT_VERDICT: if (ctx->stmt_matrix[i][m]->expr->etype == EXPR_MAP) merge[k].stmt[merge[k].num_stmts++] = m; break; default: break; } } j = merge[k].num_rules - 1; merge_rules(ctx, i, i + j, &merge[k], &nft->output); } ret = 0; for (i = 0; i < ctx->num_rules; i++) xfree(ctx->stmt_matrix[i]); xfree(ctx->stmt_matrix); xfree(merge); err: for (i = 0; i < ctx->num_stmts; i++) stmt_free(ctx->stmt[i]); xfree(ctx->rule); xfree(ctx); return ret; } static int cmd_optimize(struct nft_ctx *nft, struct cmd *cmd) { struct table *table; struct chain *chain; int ret = 0; switch (cmd->obj) { case CMD_OBJ_TABLE: table = cmd->table; if (!table) break; list_for_each_entry(chain, &table->chains, list) { if (chain->flags & CHAIN_F_HW_OFFLOAD) continue; chain_optimize(nft, &chain->rules); } break; default: break; } return ret; } int nft_optimize(struct nft_ctx *nft, struct list_head *cmds) { struct cmd *cmd; int ret = 0; list_for_each_entry(cmd, cmds, list) { switch (cmd->op) { case CMD_ADD: ret = cmd_optimize(nft, cmd); break; default: break; } } return ret; }