diff options
Diffstat (limited to 'src/netlink_delinearize.c')
-rw-r--r-- | src/netlink_delinearize.c | 1150 |
1 files changed, 857 insertions, 293 deletions
diff --git a/src/netlink_delinearize.c b/src/netlink_delinearize.c index d82d9f51..da9f7a91 100644 --- a/src/netlink_delinearize.c +++ b/src/netlink_delinearize.c @@ -9,9 +9,8 @@ * Development of this code funded by Astaro AG (http://www.astaro.com/) */ -#include <stdlib.h> -#include <stdbool.h> -#include <string.h> +#include <nft.h> + #include <limits.h> #include <linux/netfilter/nf_tables.h> #include <arpa/inet.h> @@ -30,6 +29,16 @@ #include <cache.h> #include <xt.h> +struct dl_proto_ctx *dl_proto_ctx(struct rule_pp_ctx *ctx) +{ + return ctx->dl; +} + +static struct dl_proto_ctx *dl_proto_ctx_outer(struct rule_pp_ctx *ctx) +{ + return &ctx->_dl[0]; +} + static int netlink_parse_expr(const struct nftnl_expr *nle, struct netlink_parse_ctx *ctx); @@ -72,8 +81,7 @@ static void netlink_set_register(struct netlink_parse_ctx *ctx, return; } - if (ctx->registers[reg] != NULL) - expr_free(ctx->registers[reg]); + expr_free(ctx->registers[reg]); ctx->registers[reg] = expr; } @@ -100,7 +108,7 @@ static void netlink_release_registers(struct netlink_parse_ctx *ctx) { int i; - for (i = 0; i < MAX_REGS; i++) + for (i = 0; i <= MAX_REGS; i++) expr_free(ctx->registers[i]); } @@ -134,6 +142,50 @@ err: return NULL; } +static struct expr *netlink_parse_concat_key(struct netlink_parse_ctx *ctx, + const struct location *loc, + unsigned int reg, + const struct expr *key) +{ + uint32_t type = key->dtype->type; + unsigned int n, len = key->len; + struct expr *concat, *expr; + unsigned int consumed; + + concat = concat_expr_alloc(loc); + n = div_round_up(fls(type), TYPE_BITS); + + while (len > 0) { + const struct datatype *i; + + expr = netlink_get_register(ctx, loc, reg); + if (expr == NULL) { + netlink_error(ctx, loc, + "Concat expression size mismatch"); + goto err; + } + + if (n > 0 && concat_subtype_id(type, --n)) { + i = concat_subtype_lookup(type, n); + + expr_set_type(expr, i, i->byteorder); + } + + compound_expr_add(concat, expr); + + consumed = netlink_padded_len(expr->len); + assert(consumed > 0); + len -= consumed; + reg += netlink_register_space(expr->len); + } + + return concat; + +err: + expr_free(concat); + return NULL; +} + static struct expr *netlink_parse_concat_data(struct netlink_parse_ctx *ctx, const struct location *loc, unsigned int reg, @@ -156,6 +208,7 @@ static struct expr *netlink_parse_concat_data(struct netlink_parse_ctx *ctx, len -= netlink_padded_len(expr->len); reg += netlink_register_space(expr->len); + expr_free(expr); } return concat; @@ -174,6 +227,13 @@ static void netlink_parse_chain_verdict(struct netlink_parse_ctx *ctx, expr_chain_export(expr->chain, chain_name); chain = chain_binding_lookup(ctx->table, chain_name); + + /* Special case: 'nft list chain x y' needs to pull in implicit chains */ + if (!chain && !strncmp(chain_name, "__chain", strlen("__chain"))) { + nft_chain_cache_update(ctx->nlctx, ctx->table, chain_name); + chain = chain_binding_lookup(ctx->table, chain_name); + } + if (chain) { ctx->stmt = chain_stmt_alloc(loc, chain, verdict); expr_free(expr); @@ -427,7 +487,7 @@ static struct expr *netlink_parse_bitwise_bool(struct netlink_parse_ctx *ctx, mpz_ior(m, m, o); } - if (left->len > 0 && mpz_scan0(m, 0) == left->len) { + if (left->len > 0 && mpz_scan0(m, 0) >= left->len) { /* mask encompasses the entire value */ expr_free(mask); } else { @@ -475,7 +535,7 @@ static struct expr *netlink_parse_bitwise_shift(struct netlink_parse_ctx *ctx, right->byteorder = BYTEORDER_HOST_ENDIAN; expr = binop_expr_alloc(loc, op, left, right); - expr->len = left->len; + expr->len = nftnl_expr_get_u32(nle, NFTNL_EXPR_BITWISE_LEN) * BITS_PER_BYTE; return expr; } @@ -561,6 +621,10 @@ static void netlink_parse_payload_expr(struct netlink_parse_ctx *ctx, struct expr *expr; base = nftnl_expr_get_u32(nle, NFTNL_EXPR_PAYLOAD_BASE) + 1; + + if (base == NFT_PAYLOAD_TUN_HEADER + 1) + base = NFT_PAYLOAD_INNER_HEADER + 1; + offset = nftnl_expr_get_u32(nle, NFTNL_EXPR_PAYLOAD_OFFSET) * BITS_PER_BYTE; len = nftnl_expr_get_u32(nle, NFTNL_EXPR_PAYLOAD_LEN) * BITS_PER_BYTE; @@ -568,9 +632,80 @@ static void netlink_parse_payload_expr(struct netlink_parse_ctx *ctx, payload_init_raw(expr, base, offset, len); dreg = netlink_parse_register(nle, NFTNL_EXPR_PAYLOAD_DREG); + + if (ctx->inner) + ctx->inner_reg = dreg; + netlink_set_register(ctx, dreg, expr); } +static void netlink_parse_inner(struct netlink_parse_ctx *ctx, + const struct location *loc, + const struct nftnl_expr *nle) +{ + const struct proto_desc *inner_desc; + const struct nftnl_expr *inner_nle; + uint32_t hdrsize, flags, type; + struct expr *expr; + + hdrsize = nftnl_expr_get_u32(nle, NFTNL_EXPR_INNER_HDRSIZE); + type = nftnl_expr_get_u32(nle, NFTNL_EXPR_INNER_TYPE); + flags = nftnl_expr_get_u32(nle, NFTNL_EXPR_INNER_FLAGS); + + inner_nle = nftnl_expr_get(nle, NFTNL_EXPR_INNER_EXPR, NULL); + if (!inner_nle) { + netlink_error(ctx, loc, "Could not parse inner expression"); + return; + } + + inner_desc = proto_find_inner(type, hdrsize, flags); + + ctx->inner = true; + if (netlink_parse_expr(inner_nle, ctx) < 0) { + ctx->inner = false; + return; + } + ctx->inner = false; + + expr = netlink_get_register(ctx, loc, ctx->inner_reg); + assert(expr); + + if (expr->etype == EXPR_PAYLOAD && + expr->payload.base == PROTO_BASE_INNER_HDR) { + const struct proto_hdr_template *tmpl; + unsigned int i; + + for (i = 1; i < array_size(inner_desc->templates); i++) { + tmpl = &inner_desc->templates[i]; + + if (tmpl->len == 0) + return; + + if (tmpl->offset != expr->payload.offset || + tmpl->len != expr->len) + continue; + + expr->payload.desc = inner_desc; + expr->payload.tmpl = tmpl; + break; + } + } + + switch (expr->etype) { + case EXPR_PAYLOAD: + expr->payload.inner_desc = inner_desc; + break; + case EXPR_META: + expr->meta.inner_desc = inner_desc; + break; + default: + assert(0); + break; + } + + netlink_set_register(ctx, ctx->inner_reg, expr); +} + static void netlink_parse_payload_stmt(struct netlink_parse_ctx *ctx, const struct location *loc, const struct nftnl_expr *nle) @@ -637,7 +772,7 @@ static void netlink_parse_exthdr(struct netlink_parse_ctx *ctx, sreg = netlink_parse_register(nle, NFTNL_EXPR_EXTHDR_SREG); val = netlink_get_register(ctx, loc, sreg); if (val == NULL) { - xfree(expr); + expr_free(expr); return netlink_error(ctx, loc, "exthdr statement has no expression"); } @@ -646,6 +781,10 @@ static void netlink_parse_exthdr(struct netlink_parse_ctx *ctx, stmt = exthdr_stmt_alloc(loc, expr, val); rule_stmt_append(ctx->rule, stmt); + } else { + struct stmt *stmt = optstrip_stmt_alloc(loc, expr); + + rule_stmt_append(ctx->rule, stmt); } } @@ -679,7 +818,7 @@ static void netlink_parse_hash(struct netlink_parse_ctx *ctx, len = nftnl_expr_get_u32(nle, NFTNL_EXPR_HASH_LEN) * BITS_PER_BYTE; if (hexpr->len < len) { - xfree(hexpr); + expr_free(hexpr); hexpr = netlink_parse_concat_expr(ctx, loc, sreg, len); if (hexpr == NULL) goto out_err; @@ -691,7 +830,7 @@ static void netlink_parse_hash(struct netlink_parse_ctx *ctx, netlink_set_register(ctx, dreg, expr); return; out_err: - xfree(expr); + expr_free(expr); } static void netlink_parse_fib(struct netlink_parse_ctx *ctx, @@ -723,6 +862,9 @@ static void netlink_parse_meta_expr(struct netlink_parse_ctx *ctx, expr = meta_expr_alloc(loc, key); dreg = netlink_parse_register(nle, NFTNL_EXPR_META_DREG); + if (ctx->inner) + ctx->inner_reg = dreg; + netlink_set_register(ctx, dreg, expr); } @@ -731,11 +873,12 @@ static void netlink_parse_socket(struct netlink_parse_ctx *ctx, const struct nftnl_expr *nle) { enum nft_registers dreg; - uint32_t key; + uint32_t key, level; struct expr * expr; key = nftnl_expr_get_u32(nle, NFTNL_EXPR_SOCKET_KEY); - expr = socket_expr_alloc(loc, key); + level = nftnl_expr_get_u32(nle, NFTNL_EXPR_SOCKET_LEVEL); + expr = socket_expr_alloc(loc, key, level); dreg = netlink_parse_register(nle, NFTNL_EXPR_SOCKET_DREG); netlink_set_register(ctx, dreg, expr); @@ -775,7 +918,9 @@ static void netlink_parse_meta_stmt(struct netlink_parse_ctx *ctx, key = nftnl_expr_get_u32(nle, NFTNL_EXPR_META_KEY); stmt = meta_stmt_alloc(loc, key, expr); - expr_set_type(expr, stmt->meta.tmpl->dtype, stmt->meta.tmpl->byteorder); + + if (stmt->meta.tmpl) + expr_set_type(expr, stmt->meta.tmpl->dtype, stmt->meta.tmpl->byteorder); ctx->stmt = stmt; } @@ -922,6 +1067,19 @@ static void netlink_parse_counter(struct netlink_parse_ctx *ctx, ctx->stmt = stmt; } +static void netlink_parse_last(struct netlink_parse_ctx *ctx, + const struct location *loc, + const struct nftnl_expr *nle) +{ + struct stmt *stmt; + + stmt = last_stmt_alloc(loc); + stmt->last.used = nftnl_expr_get_u64(nle, NFTNL_EXPR_LAST_MSECS); + stmt->last.set = nftnl_expr_get_u32(nle, NFTNL_EXPR_LAST_SET); + + ctx->stmt = stmt; +} + static void netlink_parse_log(struct netlink_parse_ctx *ctx, const struct location *loc, const struct nftnl_expr *nle) @@ -1014,7 +1172,8 @@ static void netlink_parse_reject(struct netlink_parse_ctx *ctx, ctx->stmt = stmt; } -static bool is_nat_addr_map(const struct expr *addr, uint8_t family) +static bool is_nat_addr_map(const struct expr *addr, uint8_t family, + struct stmt *stmt) { const struct expr *mappings, *data; const struct set *set; @@ -1033,14 +1192,31 @@ static bool is_nat_addr_map(const struct expr *addr, uint8_t family) if (!(data->flags & EXPR_F_INTERVAL)) return false; + stmt->nat.family = family; + /* if we're dealing with an address:address map, * the length will be bit_sizeof(addr) + 32 (one register). */ switch (family) { case NFPROTO_IPV4: - return data->len == 32 + 32; + if (data->len == 32 + 32) { + stmt->nat.type_flags |= STMT_NAT_F_INTERVAL; + return true; + } else if (data->len == 32 + 32 + 32 + 32) { + stmt->nat.type_flags |= STMT_NAT_F_INTERVAL | + STMT_NAT_F_CONCAT; + return true; + } + break; case NFPROTO_IPV6: - return data->len == 128 + 128; + if (data->len == 128 + 128) { + stmt->nat.type_flags |= STMT_NAT_F_INTERVAL; + return true; + } else if (data->len == 128 + 32 + 128 + 32) { + stmt->nat.type_flags |= STMT_NAT_F_INTERVAL | + STMT_NAT_F_CONCAT; + return true; + } } return false; @@ -1116,9 +1292,8 @@ static void netlink_parse_nat(struct netlink_parse_ctx *ctx, stmt->nat.addr = addr; } - if (is_nat_addr_map(addr, family)) { + if (is_nat_addr_map(addr, family, stmt)) { stmt->nat.family = family; - stmt->nat.type_flags |= STMT_NAT_F_INTERVAL; ctx->stmt = stmt; return; } @@ -1182,7 +1357,7 @@ static void netlink_parse_nat(struct netlink_parse_ctx *ctx, ctx->stmt = stmt; return; out_err: - xfree(stmt); + stmt_free(stmt); } static void netlink_parse_synproxy(struct netlink_parse_ctx *ctx, @@ -1246,7 +1421,7 @@ static void netlink_parse_tproxy(struct netlink_parse_ctx *ctx, ctx->stmt = stmt; return; err: - xfree(stmt); + stmt_free(stmt); } static void netlink_parse_masq(struct netlink_parse_ctx *ctx, @@ -1293,7 +1468,7 @@ static void netlink_parse_masq(struct netlink_parse_ctx *ctx, ctx->stmt = stmt; return; out_err: - xfree(stmt); + stmt_free(stmt); } static void netlink_parse_redir(struct netlink_parse_ctx *ctx, @@ -1344,7 +1519,7 @@ static void netlink_parse_redir(struct netlink_parse_ctx *ctx, ctx->stmt = stmt; return; out_err: - xfree(stmt); + stmt_free(stmt); } static void netlink_parse_dup(struct netlink_parse_ctx *ctx, @@ -1397,7 +1572,7 @@ static void netlink_parse_dup(struct netlink_parse_ctx *ctx, ctx->stmt = stmt; return; out_err: - xfree(stmt); + stmt_free(stmt); } static void netlink_parse_fwd(struct netlink_parse_ctx *ctx, @@ -1459,34 +1634,43 @@ static void netlink_parse_fwd(struct netlink_parse_ctx *ctx, ctx->stmt = stmt; return; out_err: - xfree(stmt); + stmt_free(stmt); } static void netlink_parse_queue(struct netlink_parse_ctx *ctx, const struct location *loc, const struct nftnl_expr *nle) { - struct expr *expr, *high; - struct stmt *stmt; - uint16_t num, total; + struct expr *expr; + uint16_t flags; + + if (nftnl_expr_is_set(nle, NFTNL_EXPR_QUEUE_SREG_QNUM)) { + enum nft_registers reg = netlink_parse_register(nle, NFTNL_EXPR_QUEUE_SREG_QNUM); + + expr = netlink_get_register(ctx, loc, reg); + if (!expr) { + netlink_error(ctx, loc, "queue statement has no sreg expression"); + return; + } + } else { + uint16_t total = nftnl_expr_get_u16(nle, NFTNL_EXPR_QUEUE_TOTAL); + uint16_t num = nftnl_expr_get_u16(nle, NFTNL_EXPR_QUEUE_NUM); - num = nftnl_expr_get_u16(nle, NFTNL_EXPR_QUEUE_NUM); - total = nftnl_expr_get_u16(nle, NFTNL_EXPR_QUEUE_TOTAL); + expr = constant_expr_alloc(loc, &integer_type, + BYTEORDER_HOST_ENDIAN, 16, &num); - expr = constant_expr_alloc(loc, &integer_type, - BYTEORDER_HOST_ENDIAN, 16, &num); - if (total > 1) { - total += num - 1; - high = constant_expr_alloc(loc, &integer_type, + if (total > 1) { + struct expr *high; + + total += num - 1; + high = constant_expr_alloc(loc, &integer_type, BYTEORDER_HOST_ENDIAN, 16, &total); - expr = range_expr_alloc(loc, expr, high); + expr = range_expr_alloc(loc, expr, high); + } } - stmt = queue_stmt_alloc(loc); - stmt->queue.queue = expr; - stmt->queue.flags = nftnl_expr_get_u16(nle, NFTNL_EXPR_QUEUE_FLAGS); - - ctx->stmt = stmt; + flags = nftnl_expr_get_u16(nle, NFTNL_EXPR_QUEUE_FLAGS); + ctx->stmt = queue_stmt_alloc(loc, expr, flags); } struct dynset_parse_ctx { @@ -1545,9 +1729,11 @@ static void netlink_parse_dynset(struct netlink_parse_ctx *ctx, if (expr->len < set->key->len) { expr_free(expr); - expr = netlink_parse_concat_expr(ctx, loc, sreg, set->key->len); + expr = netlink_parse_concat_key(ctx, loc, sreg, set->key); if (expr == NULL) return; + } else if (expr->dtype == &invalid_type) { + expr_set_type(expr, datatype_get(set->key->dtype), set->key->byteorder); } expr = set_elem_expr_alloc(&expr->location, expr); @@ -1577,6 +1763,14 @@ static void netlink_parse_dynset(struct netlink_parse_ctx *ctx, if (nftnl_expr_is_set(nle, NFTNL_EXPR_DYNSET_SREG_DATA)) { sreg_data = netlink_parse_register(nle, NFTNL_EXPR_DYNSET_SREG_DATA); expr_data = netlink_get_register(ctx, loc, sreg_data); + + if (expr_data && expr_data->len < set->data->len) { + expr_free(expr_data); + expr_data = netlink_parse_concat_expr(ctx, loc, sreg_data, set->data->len); + if (expr_data == NULL) + netlink_error(ctx, loc, + "Could not parse dynset map data expressions"); + } } if (expr_data != NULL) { @@ -1590,7 +1784,7 @@ static void netlink_parse_dynset(struct netlink_parse_ctx *ctx, &stmt->map.stmt_list); } else { if (!list_empty(&dynset_parse_ctx.stmt_list) && - set->flags & NFT_SET_ANONYMOUS) { + set_is_anonymous(set->flags)) { stmt = meter_stmt_alloc(loc); stmt->meter.set = set_ref_expr_alloc(loc, set); stmt->meter.key = expr; @@ -1613,7 +1807,7 @@ out_err: list_for_each_entry_safe(dstmt, next, &dynset_parse_ctx.stmt_list, list) stmt_free(dstmt); - xfree(expr); + expr_free(expr); } static void netlink_parse_objref(struct netlink_parse_ctx *ctx, @@ -1689,6 +1883,7 @@ static const struct expr_handler netlink_parsers[] = { { .name = "bitwise", .parse = netlink_parse_bitwise }, { .name = "byteorder", .parse = netlink_parse_byteorder }, { .name = "payload", .parse = netlink_parse_payload }, + { .name = "inner", .parse = netlink_parse_inner }, { .name = "exthdr", .parse = netlink_parse_exthdr }, { .name = "meta", .parse = netlink_parse_meta }, { .name = "socket", .parse = netlink_parse_socket }, @@ -1697,6 +1892,7 @@ static const struct expr_handler netlink_parsers[] = { { .name = "ct", .parse = netlink_parse_ct }, { .name = "connlimit", .parse = netlink_parse_connlimit }, { .name = "counter", .parse = netlink_parse_counter }, + { .name = "last", .parse = netlink_parse_last }, { .name = "log", .parse = netlink_parse_log }, { .name = "limit", .parse = netlink_parse_limit }, { .name = "range", .parse = netlink_parse_range }, @@ -1723,47 +1919,26 @@ static const struct expr_handler netlink_parsers[] = { { .name = "synproxy", .parse = netlink_parse_synproxy }, }; -static const struct expr_handler **expr_handle_ht; - -#define NFT_EXPR_HSIZE 4096 - -void expr_handler_init(void) -{ - unsigned int i; - uint32_t hash; - - expr_handle_ht = calloc(NFT_EXPR_HSIZE, sizeof(expr_handle_ht)); - if (!expr_handle_ht) - memory_allocation_error(); - - for (i = 0; i < array_size(netlink_parsers); i++) { - hash = djb_hash(netlink_parsers[i].name) % NFT_EXPR_HSIZE; - assert(expr_handle_ht[hash] == NULL); - expr_handle_ht[hash] = &netlink_parsers[i]; - } -} - -void expr_handler_exit(void) -{ - xfree(expr_handle_ht); -} - static int netlink_parse_expr(const struct nftnl_expr *nle, struct netlink_parse_ctx *ctx) { const char *type = nftnl_expr_get_str(nle, NFTNL_EXPR_NAME); struct location loc; - uint32_t hash; + unsigned int i; memset(&loc, 0, sizeof(loc)); loc.indesc = &indesc_netlink; loc.nle = nle; - hash = djb_hash(type) % NFT_EXPR_HSIZE; - if (expr_handle_ht[hash]) - expr_handle_ht[hash]->parse(ctx, &loc, nle); - else - netlink_error(ctx, &loc, "unknown expression type '%s'", type); + for (i = 0; i < array_size(netlink_parsers); i++) { + if (strcmp(type, netlink_parsers[i].name)) + continue; + + netlink_parsers[i].parse(ctx, &loc, nle); + + return 0; + } + netlink_error(ctx, &loc, "unknown expression type '%s'", type); return 0; } @@ -1806,6 +1981,37 @@ struct stmt *netlink_parse_set_expr(const struct set *set, return pctx->stmt; } +static bool meta_outer_may_dependency_kill(struct rule_pp_ctx *ctx, + const struct expr *expr) +{ + struct dl_proto_ctx *dl_outer = dl_proto_ctx_outer(ctx); + struct stmt *stmt = dl_outer->pdctx.pdeps[expr->payload.inner_desc->base]; + struct expr *dep; + uint8_t l4proto; + + if (!stmt) + return false; + + dep = stmt->expr; + + if (dep->left->meta.key != NFT_META_L4PROTO) + return false; + + l4proto = mpz_get_uint8(dep->right->value); + + switch (l4proto) { + case IPPROTO_GRE: + if (expr->payload.inner_desc == &proto_gre || + expr->payload.inner_desc == &proto_gretap) + return true; + break; + default: + break; + } + + return false; +} + static void expr_postprocess(struct rule_pp_ctx *ctx, struct expr **exprp); static void payload_match_expand(struct rule_pp_ctx *ctx, @@ -1814,15 +2020,12 @@ static void payload_match_expand(struct rule_pp_ctx *ctx, { struct expr *left = payload, *right = expr->right, *tmp; struct list_head list = LIST_HEAD_INIT(list); - struct stmt *nstmt; - struct expr *nexpr = NULL; + struct dl_proto_ctx *dl = dl_proto_ctx(ctx); enum proto_bases base = left->payload.base; - bool stacked; - - if (ctx->pdctx.icmp_type) - ctx->pctx.th_dep.icmp.type = ctx->pdctx.icmp_type; + struct expr *nexpr = NULL; + struct stmt *nstmt; - payload_expr_expand(&list, left, &ctx->pctx); + payload_expr_expand(&list, left, &dl->pctx); list_for_each_entry(left, &list, list) { tmp = constant_expr_splice(right, left->len); @@ -1837,7 +2040,7 @@ static void payload_match_expand(struct rule_pp_ctx *ctx, nexpr = relational_expr_alloc(&expr->location, expr->op, left, tmp); if (expr->op == OP_EQ) - relational_expr_pctx_update(&ctx->pctx, nexpr); + relational_expr_pctx_update(&dl->pctx, nexpr); nstmt = expr_stmt_alloc(&ctx->stmt->location, nexpr); list_add_tail(&nstmt->list, &ctx->stmt->list); @@ -1846,36 +2049,95 @@ static void payload_match_expand(struct rule_pp_ctx *ctx, assert(left->payload.base); assert(base == left->payload.base); - stacked = payload_is_stacked(ctx->pctx.protocol[base].desc, nexpr); + if (expr->left->payload.inner_desc) { + if (expr->left->payload.inner_desc == expr->left->payload.desc) { + nexpr->left->payload.desc = expr->left->payload.desc; + nexpr->left->payload.tmpl = expr->left->payload.tmpl; + } + nexpr->left->payload.inner_desc = expr->left->payload.inner_desc; + + if (meta_outer_may_dependency_kill(ctx, expr->left)) { + struct dl_proto_ctx *dl_outer = dl_proto_ctx_outer(ctx); + + payload_dependency_release(&dl_outer->pdctx, expr->left->payload.inner_desc->base); + } + } + + if (payload_is_stacked(dl->pctx.protocol[base].desc, nexpr)) + base--; /* Remember the first payload protocol expression to * kill it later on if made redundant by a higher layer * payload expression. */ - if (ctx->pdctx.pbase == PROTO_BASE_INVALID && - expr->op == OP_EQ && - left->flags & EXPR_F_PROTOCOL) { - payload_dependency_store(&ctx->pdctx, nstmt, base - stacked); - } else { - payload_dependency_kill(&ctx->pdctx, nexpr->left, - ctx->pctx.family); - if (expr->op == OP_EQ && left->flags & EXPR_F_PROTOCOL) - payload_dependency_store(&ctx->pdctx, nstmt, base - stacked); - } + payload_dependency_kill(&dl->pdctx, nexpr->left, + dl->pctx.family); + if (expr->op == OP_EQ && left->flags & EXPR_F_PROTOCOL) + payload_dependency_store(&dl->pdctx, nstmt, base); } list_del(&ctx->stmt->list); stmt_free(ctx->stmt); ctx->stmt = NULL; } +static void payload_icmp_check(struct rule_pp_ctx *rctx, struct expr *expr, const struct expr *value) +{ + struct dl_proto_ctx *dl = dl_proto_ctx(rctx); + const struct proto_hdr_template *tmpl; + const struct proto_desc *desc; + uint8_t icmp_type; + unsigned int i; + + assert(expr->etype == EXPR_PAYLOAD); + assert(value->etype == EXPR_VALUE); + + if (expr->payload.base != PROTO_BASE_TRANSPORT_HDR) + return; + + /* icmp(v6) type is 8 bit, if value is smaller or larger, this is not + * a protocol dependency. + */ + if (expr->len != 8 || value->len != 8 || dl->pctx.th_dep.icmp.type) + return; + + desc = dl->pctx.protocol[expr->payload.base].desc; + if (desc == NULL) + return; + + /* not icmp? ignore. */ + if (desc != &proto_icmp && desc != &proto_icmp6) + return; + + assert(desc->base == expr->payload.base); + + icmp_type = mpz_get_uint8(value->value); + + for (i = 1; i < array_size(desc->templates); i++) { + tmpl = &desc->templates[i]; + + if (tmpl->len == 0) + return; + + if (tmpl->offset != expr->payload.offset || + tmpl->len != expr->len) + continue; + + /* Matches but doesn't load a protocol key -> ignore. */ + if (desc->protocol_key != i) + return; + + expr->payload.desc = desc; + expr->payload.tmpl = tmpl; + dl->pctx.th_dep.icmp.type = icmp_type; + return; + } +} + static void payload_match_postprocess(struct rule_pp_ctx *ctx, struct expr *expr, struct expr *payload) { - enum proto_bases base = payload->payload.base; - - assert(payload->payload.offset >= ctx->pctx.protocol[base].offset); - payload->payload.offset -= ctx->pctx.protocol[base].offset; + struct dl_proto_ctx *dl = dl_proto_ctx(ctx); switch (expr->op) { case OP_EQ: @@ -1883,19 +2145,82 @@ static void payload_match_postprocess(struct rule_pp_ctx *ctx, if (expr->right->etype == EXPR_VALUE) { payload_match_expand(ctx, expr, payload); break; + } else if (expr->right->etype == EXPR_SET_REF) { + struct set *set = expr->right->set; + + if (set_is_anonymous(set->flags) && + set->init && + !list_empty(&set->init->expressions)) { + struct expr *elem; + + elem = list_first_entry(&set->init->expressions, struct expr, list); + + if (elem->etype == EXPR_SET_ELEM && + elem->key->etype == EXPR_VALUE) + payload_icmp_check(ctx, payload, elem->key); + } } /* Fall through */ default: - payload_expr_complete(payload, &ctx->pctx); + payload_expr_complete(payload, &dl->pctx); expr_set_type(expr->right, payload->dtype, payload->byteorder); - payload_dependency_kill(&ctx->pdctx, payload, ctx->pctx.family); + payload_dependency_kill(&dl->pdctx, payload, dl->pctx.family); + break; + } +} + +static uint8_t ether_type_to_nfproto(uint16_t l3proto) +{ + switch(l3proto) { + case ETH_P_IP: + return NFPROTO_IPV4; + case ETH_P_IPV6: + return NFPROTO_IPV6; + default: break; } + + return NFPROTO_UNSPEC; +} + +static bool __meta_dependency_may_kill(const struct expr *dep, uint8_t *nfproto) +{ + uint16_t l3proto; + + switch (dep->left->etype) { + case EXPR_META: + switch (dep->left->meta.key) { + case NFT_META_NFPROTO: + *nfproto = mpz_get_uint8(dep->right->value); + break; + case NFT_META_PROTOCOL: + l3proto = mpz_get_uint16(dep->right->value); + *nfproto = ether_type_to_nfproto(l3proto); + break; + default: + return true; + } + break; + case EXPR_PAYLOAD: + if (dep->left->payload.base != PROTO_BASE_LL_HDR) + return true; + + if (dep->left->dtype != ðertype_type) + return true; + + l3proto = mpz_get_uint16(dep->right->value); + *nfproto = ether_type_to_nfproto(l3proto); + break; + default: + return true; + } + + return false; } /* We have seen a protocol key expression that restricts matching at the network - * base, leave it in place since this is meaninful in bridge, inet and netdev + * base, leave it in place since this is meaningful in bridge, inet and netdev * families. Exceptions are ICMP and ICMPv6 where this code assumes that can * only happen with IPv4 and IPv6. */ @@ -1903,11 +2228,13 @@ static bool meta_may_dependency_kill(struct payload_dep_ctx *ctx, unsigned int family, const struct expr *expr) { - struct expr *dep = ctx->pdep->expr; - uint16_t l3proto; - uint8_t l4proto; + uint8_t l4proto, nfproto = NFPROTO_UNSPEC; + struct expr *dep = payload_dependency_get(ctx, PROTO_BASE_NETWORK_HDR); - if (ctx->pbase != PROTO_BASE_NETWORK_HDR) + if (!dep) + return true; + + if (__meta_dependency_may_kill(dep, &nfproto)) return true; switch (family) { @@ -1915,6 +2242,9 @@ static bool meta_may_dependency_kill(struct payload_dep_ctx *ctx, case NFPROTO_NETDEV: case NFPROTO_BRIDGE: break; + case NFPROTO_IPV4: + case NFPROTO_IPV6: + return family == nfproto; default: return true; } @@ -1922,32 +2252,6 @@ static bool meta_may_dependency_kill(struct payload_dep_ctx *ctx, if (expr->left->meta.key != NFT_META_L4PROTO) return true; - l3proto = mpz_get_uint16(dep->right->value); - - switch (dep->left->etype) { - case EXPR_META: - if (dep->left->meta.key != NFT_META_NFPROTO) - return true; - break; - case EXPR_PAYLOAD: - if (dep->left->payload.base != PROTO_BASE_LL_HDR) - return true; - - switch(l3proto) { - case ETH_P_IP: - l3proto = NFPROTO_IPV4; - break; - case ETH_P_IPV6: - l3proto = NFPROTO_IPV6; - break; - default: - break; - } - break; - default: - break; - } - l4proto = mpz_get_uint8(expr->right->value); switch (l4proto) { @@ -1958,8 +2262,8 @@ static bool meta_may_dependency_kill(struct payload_dep_ctx *ctx, return false; } - if ((l3proto == NFPROTO_IPV4 && l4proto == IPPROTO_ICMPV6) || - (l3proto == NFPROTO_IPV6 && l4proto == IPPROTO_ICMP)) + if ((nfproto == NFPROTO_IPV4 && l4proto == IPPROTO_ICMPV6) || + (nfproto == NFPROTO_IPV6 && l4proto == IPPROTO_ICMP)) return false; return true; @@ -1969,6 +2273,7 @@ static void ct_meta_common_postprocess(struct rule_pp_ctx *ctx, const struct expr *expr, enum proto_bases base) { + struct dl_proto_ctx *dl = dl_proto_ctx(ctx); const struct expr *left = expr->left; struct expr *right = expr->right; @@ -1982,18 +2287,16 @@ static void ct_meta_common_postprocess(struct rule_pp_ctx *ctx, expr->right->etype == EXPR_SET_REF) break; - relational_expr_pctx_update(&ctx->pctx, expr); + relational_expr_pctx_update(&dl->pctx, expr); + + if (base < PROTO_BASE_TRANSPORT_HDR) { + if (payload_dependency_exists(&dl->pdctx, base) && + meta_may_dependency_kill(&dl->pdctx, + dl->pctx.family, expr)) + payload_dependency_release(&dl->pdctx, base); - if (ctx->pdctx.pbase == PROTO_BASE_INVALID && - left->flags & EXPR_F_PROTOCOL) { - payload_dependency_store(&ctx->pdctx, ctx->stmt, base); - } else if (ctx->pdctx.pbase < PROTO_BASE_TRANSPORT_HDR) { - if (payload_dependency_exists(&ctx->pdctx, base) && - meta_may_dependency_kill(&ctx->pdctx, - ctx->pctx.family, expr)) - payload_dependency_release(&ctx->pdctx); if (left->flags & EXPR_F_PROTOCOL) - payload_dependency_store(&ctx->pdctx, ctx->stmt, base); + payload_dependency_store(&dl->pdctx, ctx->stmt, base); } break; default: @@ -2078,8 +2381,8 @@ static void binop_adjust_one(const struct expr *binop, struct expr *value, } } -static void __binop_adjust(const struct expr *binop, struct expr *right, - unsigned int shift) +static void binop_adjust(const struct expr *binop, struct expr *right, + unsigned int shift) { struct expr *i; @@ -2088,6 +2391,9 @@ static void __binop_adjust(const struct expr *binop, struct expr *right, binop_adjust_one(binop, right, shift); break; case EXPR_SET_REF: + if (!set_is_anonymous(right->set->flags)) + break; + list_for_each_entry(i, &right->set->init->expressions, list) { switch (i->key->etype) { case EXPR_VALUE: @@ -2098,7 +2404,7 @@ static void __binop_adjust(const struct expr *binop, struct expr *right, binop_adjust_one(binop, i->key->right, shift); break; case EXPR_SET_ELEM: - __binop_adjust(binop, i->key->key, shift); + binop_adjust(binop, i->key->key, shift); break; default: BUG("unknown expression type %s\n", expr_name(i->key)); @@ -2115,22 +2421,24 @@ static void __binop_adjust(const struct expr *binop, struct expr *right, } } -static void binop_adjust(struct expr *expr, unsigned int shift) +static void __binop_postprocess(struct rule_pp_ctx *ctx, + struct expr *expr, + struct expr *left, + struct expr *mask, + struct expr **expr_binop) { - __binop_adjust(expr->left, expr->right, shift); -} - -static void binop_postprocess(struct rule_pp_ctx *ctx, struct expr *expr) -{ - struct expr *binop = expr->left; - struct expr *left = binop->left; - struct expr *mask = binop->right; + struct dl_proto_ctx *dl = dl_proto_ctx(ctx); + struct expr *binop = *expr_binop; unsigned int shift; + assert(binop->etype == EXPR_BINOP); + if ((left->etype == EXPR_PAYLOAD && - payload_expr_trim(left, mask, &ctx->pctx, &shift)) || + payload_expr_trim(left, mask, &dl->pctx, &shift)) || (left->etype == EXPR_EXTHDR && exthdr_find_template(left, mask, &shift))) { + struct expr *right = NULL; + /* mask is implicit, binop needs to be removed. * * Fix all values of the expression according to the mask @@ -2140,45 +2448,84 @@ static void binop_postprocess(struct rule_pp_ctx *ctx, struct expr *expr) * Finally, convert the expression to 1) by replacing * the binop with the binop payload/exthdr expression. */ - binop_adjust(expr, shift); + switch (expr->etype) { + case EXPR_BINOP: + case EXPR_RELATIONAL: + right = expr->right; + binop_adjust(binop, right, shift); + break; + case EXPR_MAP: + right = expr->mappings; + binop_adjust(binop, right, shift); + break; + default: + break; + } - assert(expr->left->etype == EXPR_BINOP); assert(binop->left == left); - expr->left = expr_get(left); - expr_free(binop); + *expr_binop = expr_get(left); + if (left->etype == EXPR_PAYLOAD) payload_match_postprocess(ctx, expr, left); - else if (left->etype == EXPR_EXTHDR) - expr_set_type(expr->right, left->dtype, left->byteorder); + else if (left->etype == EXPR_EXTHDR && right) + expr_set_type(right, left->dtype, left->byteorder); + + expr_free(binop); } } +static void binop_postprocess(struct rule_pp_ctx *ctx, struct expr *expr, + struct expr **expr_binop) +{ + struct expr *binop = *expr_binop; + struct expr *left = binop->left; + struct expr *mask = binop->right; + + __binop_postprocess(ctx, expr, left, mask, expr_binop); +} + static void map_binop_postprocess(struct rule_pp_ctx *ctx, struct expr *expr) { - struct expr *binop = expr->left; + struct expr *binop = expr->map; if (binop->op != OP_AND) return; if (binop->left->etype == EXPR_PAYLOAD && binop->right->etype == EXPR_VALUE) - binop_postprocess(ctx, expr); + binop_postprocess(ctx, expr, &expr->map); +} + +static bool is_shift_by_zero(const struct expr *binop) +{ + struct expr *rhs; + + if (binop->op != OP_RSHIFT && binop->op != OP_LSHIFT) + return false; + + rhs = binop->right; + if (rhs->etype != EXPR_VALUE || rhs->len > 64) + return false; + + return mpz_get_uint64(rhs->value) == 0; } -static void relational_binop_postprocess(struct rule_pp_ctx *ctx, struct expr *expr) +static void relational_binop_postprocess(struct rule_pp_ctx *ctx, + struct expr **exprp) { - struct expr *binop = expr->left, *value = expr->right; + struct expr *expr = *exprp, *binop = expr->left, *right = expr->right; if (binop->op == OP_AND && (expr->op == OP_NEQ || expr->op == OP_EQ) && - value->dtype->basetype && - value->dtype->basetype->type == TYPE_BITMASK && - !mpz_cmp_ui(value->value, 0)) { + right->dtype->basetype && + right->dtype->basetype->type == TYPE_BITMASK && + right->etype == EXPR_VALUE && + !mpz_cmp_ui(right->value, 0)) { /* Flag comparison: data & flags != 0 * * Split the flags into a list of flag values and convert the * op to OP_EQ. */ - expr_free(value); + expr_free(right); expr->left = expr_get(binop->left); expr->right = binop_tree_to_list(NULL, binop->right); @@ -2198,9 +2545,9 @@ static void relational_binop_postprocess(struct rule_pp_ctx *ctx, struct expr *e expr_mask_is_prefix(binop->right)) { expr->left = expr_get(binop->left); expr->right = prefix_expr_alloc(&expr->location, - expr_get(value), + expr_get(right), expr_mask_to_prefix(binop->right)); - expr_free(value); + expr_free(right); expr_free(binop); } else if (binop->op == OP_AND && binop->right->etype == EXPR_VALUE) { @@ -2227,8 +2574,58 @@ static void relational_binop_postprocess(struct rule_pp_ctx *ctx, struct expr *e * payload_expr_trim will figure out if the mask is needed to match * templates. */ - binop_postprocess(ctx, expr); + binop_postprocess(ctx, expr, &expr->left); + } else if (binop->op == OP_RSHIFT && binop->left->op == OP_AND && + binop->right->etype == EXPR_VALUE && binop->left->right->etype == EXPR_VALUE) { + /* Handle 'ip version @s4' and similar, i.e. set lookups where the lhs needs + * fixups to mask out unwanted bits AND a shift. + */ + + binop_postprocess(ctx, binop, &binop->left); + if (is_shift_by_zero(binop)) { + struct expr *lhs = binop->left; + + expr_get(lhs); + expr_free(binop); + expr->left = lhs; + } + } +} + +static bool payload_binop_postprocess(struct rule_pp_ctx *ctx, + struct expr **exprp) +{ + struct expr *expr = *exprp; + + if (expr->op != OP_RSHIFT) + return false; + + if (expr->left->etype == EXPR_UNARY) { + /* + * If the payload value was originally in a different byte-order + * from the payload expression, there will be a byte-order + * conversion to remove. + */ + struct expr *left = expr_get(expr->left->arg); + expr_free(expr->left); + expr->left = left; } + + if (expr->left->etype != EXPR_BINOP || expr->left->op != OP_AND) + return false; + + if (expr->left->left->etype != EXPR_PAYLOAD) + return false; + + expr_set_type(expr->right, &integer_type, + BYTEORDER_HOST_ENDIAN); + expr_postprocess(ctx, &expr->right); + + binop_postprocess(ctx, expr, &expr->left); + *exprp = expr_get(expr->left); + expr_free(expr); + + return true; } static struct expr *string_wildcard_expr_alloc(struct location *loc, @@ -2248,56 +2645,33 @@ static struct expr *string_wildcard_expr_alloc(struct location *loc, expr->len + BITS_PER_BYTE, data); } -static void escaped_string_wildcard_expr_alloc(struct expr **exprp, - unsigned int len) -{ - struct expr *expr = *exprp, *tmp; - char data[len + 3]; - int pos; - - mpz_export_data(data, expr->value, BYTEORDER_HOST_ENDIAN, len); - pos = div_round_up(len, BITS_PER_BYTE); - data[pos - 1] = '\\'; - data[pos] = '*'; - - tmp = constant_expr_alloc(&expr->location, expr->dtype, - BYTEORDER_HOST_ENDIAN, - expr->len + BITS_PER_BYTE, data); - expr_free(expr); - *exprp = tmp; -} - /* This calculates the string length and checks if it is nul-terminated, this * function is quite a hack :) */ static bool __expr_postprocess_string(struct expr **exprp) { struct expr *expr = *exprp; - unsigned int len = expr->len; - bool nulterminated = false; - mpz_t tmp; - - mpz_init(tmp); - while (len >= BITS_PER_BYTE) { - mpz_bitmask(tmp, BITS_PER_BYTE); - mpz_lshift_ui(tmp, len - BITS_PER_BYTE); - mpz_and(tmp, tmp, expr->value); - if (mpz_cmp_ui(tmp, 0)) - break; - else - nulterminated = true; - len -= BITS_PER_BYTE; - } + unsigned int len = div_round_up(expr->len, BITS_PER_BYTE); + char data[len + 1]; - mpz_rshift_ui(tmp, len - BITS_PER_BYTE); + mpz_export_data(data, expr->value, BYTEORDER_HOST_ENDIAN, len); - if (nulterminated && - mpz_cmp_ui(tmp, '*') == 0) - escaped_string_wildcard_expr_alloc(exprp, len); + if (data[len - 1] != '\0') + return false; - mpz_clear(tmp); + len = strlen(data); + if (len && data[len - 1] == '*') { + data[len - 1] = '\\'; + data[len] = '*'; + data[len + 1] = '\0'; + expr = constant_expr_alloc(&expr->location, expr->dtype, + BYTEORDER_HOST_ENDIAN, + (len + 2) * BITS_PER_BYTE, data); + expr_free(*exprp); + *exprp = expr; + } - return nulterminated; + return true; } static struct expr *expr_postprocess_string(struct expr *expr) @@ -2311,14 +2685,62 @@ static struct expr *expr_postprocess_string(struct expr *expr) mask = constant_expr_alloc(&expr->location, &integer_type, BYTEORDER_HOST_ENDIAN, expr->len + BITS_PER_BYTE, NULL); + mpz_clear(mask->value); mpz_init_bitmask(mask->value, expr->len); out = string_wildcard_expr_alloc(&expr->location, mask, expr); + expr_free(expr); expr_free(mask); return out; } +static void expr_postprocess_value(struct rule_pp_ctx *ctx, struct expr **exprp) +{ + bool interval = (ctx->set && ctx->set->flags & NFT_SET_INTERVAL); + struct expr *expr = *exprp; + + // FIXME + if (expr->byteorder == BYTEORDER_HOST_ENDIAN && !interval) + mpz_switch_byteorder(expr->value, expr->len / BITS_PER_BYTE); + + if (expr_basetype(expr)->type == TYPE_STRING) + *exprp = expr_postprocess_string(expr); + + expr = *exprp; + if (expr->dtype->basetype != NULL && + expr->dtype->basetype->type == TYPE_BITMASK) + *exprp = bitmask_expr_to_binops(expr); +} + +static void expr_postprocess_concat(struct rule_pp_ctx *ctx, struct expr **exprp) +{ + struct expr *i, *n, *expr = *exprp; + unsigned int type = expr->dtype->type, ntype = 0; + int off = expr->dtype->subtypes; + const struct datatype *dtype; + LIST_HEAD(tmp); + + assert(expr->etype == EXPR_CONCAT); + + ctx->flags |= RULE_PP_IN_CONCATENATION; + list_for_each_entry_safe(i, n, &expr->expressions, list) { + if (type) { + dtype = concat_subtype_lookup(type, --off); + expr_set_type(i, dtype, dtype->byteorder); + } + list_del(&i->list); + expr_postprocess(ctx, &i); + list_add_tail(&i->list, &tmp); + + ntype = concat_subtype_add(ntype, i->dtype->type); + } + ctx->flags &= ~RULE_PP_IN_CONCATENATION; + list_splice(&tmp, &expr->expressions); + __datatype_set(expr, concat_type_alloc(ntype)); +} + static void expr_postprocess(struct rule_pp_ctx *ctx, struct expr **exprp) { + struct dl_proto_ctx *dl = dl_proto_ctx(ctx); struct expr *expr = *exprp, *i; switch (expr->etype) { @@ -2342,28 +2764,17 @@ static void expr_postprocess(struct rule_pp_ctx *ctx, struct expr **exprp) list_for_each_entry(i, &expr->expressions, list) expr_postprocess(ctx, &i); break; - case EXPR_CONCAT: { - unsigned int type = expr->dtype->type, ntype = 0; - int off = expr->dtype->subtypes; - const struct datatype *dtype; - - list_for_each_entry(i, &expr->expressions, list) { - if (type) { - dtype = concat_subtype_lookup(type, --off); - expr_set_type(i, dtype, dtype->byteorder); - } - expr_postprocess(ctx, &i); - - ntype = concat_subtype_add(ntype, i->dtype->type); - } - datatype_set(expr, concat_type_alloc(ntype)); + case EXPR_CONCAT: + expr_postprocess_concat(ctx, exprp); break; - } case EXPR_UNARY: expr_postprocess(ctx, &expr->arg); expr_set_type(expr, expr->arg->dtype, !expr->arg->byteorder); break; case EXPR_BINOP: + if (payload_binop_postprocess(ctx, exprp)) + break; + expr_postprocess(ctx, &expr->left); switch (expr->op) { case OP_LSHIFT: @@ -2371,20 +2782,89 @@ static void expr_postprocess(struct rule_pp_ctx *ctx, struct expr **exprp) expr_set_type(expr->right, &integer_type, BYTEORDER_HOST_ENDIAN); break; + case OP_AND: + if (expr->right->len > expr->left->len) { + expr_set_type(expr->right, expr->left->dtype, + BYTEORDER_HOST_ENDIAN); + } else { + expr_set_type(expr->right, expr->left->dtype, + expr->left->byteorder); + } + + /* Do not process OP_AND in ordinary rule context. + * + * Removal needs to be performed as part of the relational + * operation because the RHS constant might need to be adjusted + * (shifted). + * + * This is different in set element context or concatenations: + * There is no relational operation (eq, neq and so on), thus + * it needs to be processed right away. + */ + if ((ctx->flags & RULE_PP_REMOVE_OP_AND) && + expr->left->etype == EXPR_PAYLOAD && + expr->right->etype == EXPR_VALUE) { + __binop_postprocess(ctx, expr, expr->left, expr->right, exprp); + return; + } + break; default: - expr_set_type(expr->right, expr->left->dtype, - expr->left->byteorder); + if (expr->right->len > expr->left->len) { + expr_set_type(expr->right, expr->left->dtype, + BYTEORDER_HOST_ENDIAN); + } else { + expr_set_type(expr->right, expr->left->dtype, + expr->left->byteorder); + } } expr_postprocess(ctx, &expr->right); - expr_set_type(expr, expr->left->dtype, - expr->left->byteorder); + switch (expr->op) { + case OP_LSHIFT: + case OP_RSHIFT: + expr_set_type(expr, &xinteger_type, + BYTEORDER_HOST_ENDIAN); + break; + default: + expr_set_type(expr, expr->left->dtype, + expr->left->byteorder); + } + break; case EXPR_RELATIONAL: switch (expr->left->etype) { case EXPR_PAYLOAD: payload_match_postprocess(ctx, expr, expr->left); return; + case EXPR_CONCAT: + if (expr->right->etype == EXPR_SET_REF) { + assert(expr->left->dtype == &invalid_type); + assert(expr->right->dtype != &invalid_type); + + datatype_set(expr->left, expr->right->dtype); + } + ctx->set = expr->right->set; + expr_postprocess(ctx, &expr->left); + ctx->set = NULL; + break; + case EXPR_UNARY: + if (lhs_is_meta_hour(expr->left->arg) && + expr->right->etype == EXPR_RANGE) { + struct expr *range = expr->right; + + /* Cross-day range needs to be reversed. + * Kernel handles time in UTC. Therefore, + * 03:00-14:00 AEDT (Sidney, Australia) time + * is a cross-day range. + */ + if (mpz_cmp(range->left->value, + range->right->value) <= 0 && + expr->op == OP_NEQ) { + range_expr_swap_values(range); + expr->op = OP_IMPLICIT; + } + } + /* fallthrough */ default: expr_postprocess(ctx, &expr->left); break; @@ -2401,41 +2881,40 @@ static void expr_postprocess(struct rule_pp_ctx *ctx, struct expr **exprp) meta_match_postprocess(ctx, expr); break; case EXPR_BINOP: - relational_binop_postprocess(ctx, expr); + relational_binop_postprocess(ctx, exprp); break; default: break; } break; case EXPR_PAYLOAD: - payload_expr_complete(expr, &ctx->pctx); - payload_dependency_kill(&ctx->pdctx, expr, ctx->pctx.family); + payload_expr_complete(expr, &dl->pctx); + if (expr->payload.inner_desc) { + if (meta_outer_may_dependency_kill(ctx, expr)) { + struct dl_proto_ctx *dl_outer = dl_proto_ctx_outer(ctx); + + payload_dependency_release(&dl_outer->pdctx, expr->payload.inner_desc->base); + } + } + payload_dependency_kill(&dl->pdctx, expr, dl->pctx.family); break; case EXPR_VALUE: - // FIXME - if (expr->byteorder == BYTEORDER_HOST_ENDIAN) - mpz_switch_byteorder(expr->value, expr->len / BITS_PER_BYTE); - - if (expr_basetype(expr)->type == TYPE_STRING) - *exprp = expr_postprocess_string(expr); - - expr = *exprp; - if (expr->dtype->basetype != NULL && - expr->dtype->basetype->type == TYPE_BITMASK) - *exprp = bitmask_expr_to_binops(expr); - + expr_postprocess_value(ctx, exprp); break; case EXPR_RANGE: expr_postprocess(ctx, &expr->left); expr_postprocess(ctx, &expr->right); + break; case EXPR_PREFIX: expr_postprocess(ctx, &expr->prefix); break; case EXPR_SET_ELEM: + ctx->flags |= RULE_PP_IN_SET_ELEM; expr_postprocess(ctx, &expr->key); + ctx->flags &= ~RULE_PP_IN_SET_ELEM; break; case EXPR_EXTHDR: - exthdr_dependency_kill(&ctx->pdctx, expr, ctx->pctx.family); + exthdr_dependency_kill(&dl->pdctx, expr, dl->pctx.family); break; case EXPR_SET_REF: case EXPR_META: @@ -2452,7 +2931,7 @@ static void expr_postprocess(struct rule_pp_ctx *ctx, struct expr **exprp) expr_postprocess(ctx, &expr->hash.expr); break; case EXPR_CT: - ct_expr_update_type(&ctx->pctx, expr); + ct_expr_update_type(&dl->pctx, expr); break; default: BUG("unknown expression type %s\n", expr_name(expr)); @@ -2461,32 +2940,35 @@ static void expr_postprocess(struct rule_pp_ctx *ctx, struct expr **exprp) static void stmt_reject_postprocess(struct rule_pp_ctx *rctx) { + struct dl_proto_ctx *dl = dl_proto_ctx(rctx); const struct proto_desc *desc, *base; struct stmt *stmt = rctx->stmt; int protocol; - switch (rctx->pctx.family) { + switch (dl->pctx.family) { case NFPROTO_IPV4: - stmt->reject.family = rctx->pctx.family; - datatype_set(stmt->reject.expr, &icmp_code_type); + stmt->reject.family = dl->pctx.family; + datatype_set(stmt->reject.expr, &reject_icmp_code_type); if (stmt->reject.type == NFT_REJECT_TCP_RST && - payload_dependency_exists(&rctx->pdctx, + payload_dependency_exists(&dl->pdctx, PROTO_BASE_TRANSPORT_HDR)) - payload_dependency_release(&rctx->pdctx); + payload_dependency_release(&dl->pdctx, + PROTO_BASE_TRANSPORT_HDR); break; case NFPROTO_IPV6: - stmt->reject.family = rctx->pctx.family; - datatype_set(stmt->reject.expr, &icmpv6_code_type); + stmt->reject.family = dl->pctx.family; + datatype_set(stmt->reject.expr, &reject_icmpv6_code_type); if (stmt->reject.type == NFT_REJECT_TCP_RST && - payload_dependency_exists(&rctx->pdctx, + payload_dependency_exists(&dl->pdctx, PROTO_BASE_TRANSPORT_HDR)) - payload_dependency_release(&rctx->pdctx); + payload_dependency_release(&dl->pdctx, + PROTO_BASE_TRANSPORT_HDR); break; case NFPROTO_INET: case NFPROTO_BRIDGE: case NFPROTO_NETDEV: if (stmt->reject.type == NFT_REJECT_ICMPX_UNREACH) { - datatype_set(stmt->reject.expr, &icmpx_code_type); + datatype_set(stmt->reject.expr, &reject_icmpx_code_type); break; } @@ -2495,26 +2977,27 @@ static void stmt_reject_postprocess(struct rule_pp_ctx *rctx) */ stmt->reject.verbose_print = 1; - base = rctx->pctx.protocol[PROTO_BASE_LL_HDR].desc; - desc = rctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; + base = dl->pctx.protocol[PROTO_BASE_LL_HDR].desc; + desc = dl->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; protocol = proto_find_num(base, desc); switch (protocol) { case NFPROTO_IPV4: /* INET */ case __constant_htons(ETH_P_IP): /* BRIDGE, NETDEV */ stmt->reject.family = NFPROTO_IPV4; - datatype_set(stmt->reject.expr, &icmp_code_type); + datatype_set(stmt->reject.expr, &reject_icmp_code_type); break; case NFPROTO_IPV6: /* INET */ case __constant_htons(ETH_P_IPV6): /* BRIDGE, NETDEV */ stmt->reject.family = NFPROTO_IPV6; - datatype_set(stmt->reject.expr, &icmpv6_code_type); + datatype_set(stmt->reject.expr, &reject_icmpv6_code_type); break; default: break; } - if (payload_dependency_exists(&rctx->pdctx, PROTO_BASE_NETWORK_HDR)) - payload_dependency_release(&rctx->pdctx); + if (payload_dependency_exists(&dl->pdctx, PROTO_BASE_NETWORK_HDR)) + payload_dependency_release(&dl->pdctx, + PROTO_BASE_NETWORK_HDR); break; default: break; @@ -2557,23 +3040,24 @@ static bool expr_may_merge_range(struct expr *expr, struct expr *prev, static void expr_postprocess_range(struct rule_pp_ctx *ctx, enum ops op) { + struct dl_proto_ctx *dl = dl_proto_ctx(ctx); struct stmt *nstmt, *stmt = ctx->stmt; struct expr *nexpr, *rel; - nexpr = range_expr_alloc(&ctx->pdctx.prev->location, - expr_clone(ctx->pdctx.prev->expr->right), + nexpr = range_expr_alloc(&dl->pdctx.prev->location, + expr_clone(dl->pdctx.prev->expr->right), expr_clone(stmt->expr->right)); expr_set_type(nexpr, stmt->expr->right->dtype, stmt->expr->right->byteorder); - rel = relational_expr_alloc(&ctx->pdctx.prev->location, op, + rel = relational_expr_alloc(&dl->pdctx.prev->location, op, expr_clone(stmt->expr->left), nexpr); nstmt = expr_stmt_alloc(&stmt->location, rel); list_add_tail(&nstmt->list, &stmt->list); - list_del(&ctx->pdctx.prev->list); - stmt_free(ctx->pdctx.prev); + list_del(&dl->pdctx.prev->list); + stmt_free(dl->pdctx.prev); list_del(&stmt->list); stmt_free(stmt); @@ -2582,26 +3066,28 @@ static void expr_postprocess_range(struct rule_pp_ctx *ctx, enum ops op) static void stmt_expr_postprocess(struct rule_pp_ctx *ctx) { + struct dl_proto_ctx *dl = dl_proto_ctx(ctx); enum ops op; expr_postprocess(ctx, &ctx->stmt->expr); - if (ctx->pdctx.prev && ctx->stmt && - ctx->stmt->ops->type == ctx->pdctx.prev->ops->type && - expr_may_merge_range(ctx->stmt->expr, ctx->pdctx.prev->expr, &op)) + if (dl->pdctx.prev && ctx->stmt && + ctx->stmt->ops->type == dl->pdctx.prev->ops->type && + expr_may_merge_range(ctx->stmt->expr, dl->pdctx.prev->expr, &op)) expr_postprocess_range(ctx, op); } static void stmt_payload_binop_pp(struct rule_pp_ctx *ctx, struct expr *binop) { + struct dl_proto_ctx *dl = dl_proto_ctx(ctx); struct expr *payload = binop->left; struct expr *mask = binop->right; unsigned int shift; assert(payload->etype == EXPR_PAYLOAD); - if (payload_expr_trim(payload, mask, &ctx->pctx, &shift)) { - __binop_adjust(binop, mask, shift); - payload_expr_complete(payload, &ctx->pctx); + if (payload_expr_trim(payload, mask, &dl->pctx, &shift)) { + binop_adjust(binop, mask, shift); + payload_expr_complete(payload, &dl->pctx); expr_set_type(mask, payload->dtype, payload->byteorder); } @@ -2695,7 +3181,7 @@ static void stmt_payload_binop_postprocess(struct rule_pp_ctx *ctx) mpz_set(mask->value, bitmask); mpz_clear(bitmask); - binop_postprocess(ctx, expr); + binop_postprocess(ctx, expr, &expr->left); if (!payload_is_known(payload)) { mpz_set(mask->value, tmp); mpz_clear(tmp); @@ -2756,20 +3242,34 @@ static void stmt_payload_binop_postprocess(struct rule_pp_ctx *ctx) static void stmt_payload_postprocess(struct rule_pp_ctx *ctx) { + struct dl_proto_ctx *dl = dl_proto_ctx(ctx); struct stmt *stmt = ctx->stmt; + payload_expr_complete(stmt->payload.expr, &dl->pctx); + if (!payload_is_known(stmt->payload.expr)) + stmt_payload_binop_postprocess(ctx); + expr_postprocess(ctx, &stmt->payload.expr); expr_set_type(stmt->payload.val, stmt->payload.expr->dtype, stmt->payload.expr->byteorder); - if (!payload_is_known(stmt->payload.expr)) - stmt_payload_binop_postprocess(ctx); - expr_postprocess(ctx, &stmt->payload.val); } +static void stmt_queue_postprocess(struct rule_pp_ctx *ctx) +{ + struct stmt *stmt = ctx->stmt; + struct expr *e = stmt->queue.queue; + + if (e == NULL || e->etype == EXPR_VALUE || + e->etype == EXPR_RANGE) + return; + + expr_postprocess(ctx, &stmt->queue.queue); +} + /* * We can only remove payload dependencies if they occur without * a statement with side effects in between. @@ -2791,18 +3291,75 @@ rule_maybe_reset_payload_deps(struct payload_dep_ctx *pdctx, enum stmt_types t) payload_dependency_reset(pdctx); } +static bool has_inner_desc(const struct expr *expr) +{ + struct expr *i; + + switch (expr->etype) { + case EXPR_BINOP: + return has_inner_desc(expr->left); + case EXPR_CONCAT: + list_for_each_entry(i, &expr->expressions, list) { + if (has_inner_desc(i)) + return true; + } + break; + case EXPR_META: + return expr->meta.inner_desc; + case EXPR_PAYLOAD: + return expr->payload.inner_desc; + case EXPR_SET_ELEM: + return has_inner_desc(expr->key); + default: + break; + } + + return false; +} + +static struct dl_proto_ctx *rule_update_dl_proto_ctx(struct rule_pp_ctx *rctx) +{ + const struct stmt *stmt = rctx->stmt; + bool inner = false; + + switch (stmt->ops->type) { + case STMT_EXPRESSION: + if (has_inner_desc(stmt->expr->left)) + inner = true; + break; + case STMT_SET: + if (has_inner_desc(stmt->set.key)) + inner = true; + break; + default: + break; + } + + if (inner) + rctx->dl = &rctx->_dl[1]; + else + rctx->dl = &rctx->_dl[0]; + + return rctx->dl; +} + static void rule_parse_postprocess(struct netlink_parse_ctx *ctx, struct rule *rule) { - struct rule_pp_ctx rctx; struct stmt *stmt, *next; + struct dl_proto_ctx *dl; + struct rule_pp_ctx rctx; + struct expr *expr; memset(&rctx, 0, sizeof(rctx)); - proto_ctx_init(&rctx.pctx, rule->handle.family, ctx->debug_mask); + proto_ctx_init(&rctx._dl[0].pctx, rule->handle.family, ctx->debug_mask, false); + /* use NFPROTO_BRIDGE to set up proto_eth as base protocol. */ + proto_ctx_init(&rctx._dl[1].pctx, NFPROTO_BRIDGE, ctx->debug_mask, true); list_for_each_entry_safe(stmt, next, &rule->stmts, list) { enum stmt_types type = stmt->ops->type; rctx.stmt = stmt; + dl = rule_update_dl_proto_ctx(&rctx); switch (type) { case STMT_EXPRESSION: @@ -2823,24 +3380,24 @@ static void rule_parse_postprocess(struct netlink_parse_ctx *ctx, struct rule *r expr_postprocess(&rctx, &stmt->ct.expr); if (stmt->ct.expr->etype == EXPR_BINOP && - stmt->ct.key == NFT_CT_EVENTMASK) - stmt->ct.expr = binop_tree_to_list(NULL, - stmt->ct.expr); + stmt->ct.key == NFT_CT_EVENTMASK) { + expr = binop_tree_to_list(NULL, stmt->ct.expr); + expr_free(stmt->ct.expr); + stmt->ct.expr = expr; + } } break; case STMT_NAT: if (stmt->nat.addr != NULL) expr_postprocess(&rctx, &stmt->nat.addr); - if (stmt->nat.proto != NULL) { - payload_dependency_reset(&rctx.pdctx); + if (stmt->nat.proto != NULL) expr_postprocess(&rctx, &stmt->nat.proto); - } break; case STMT_TPROXY: if (stmt->tproxy.addr) expr_postprocess(&rctx, &stmt->tproxy.addr); if (stmt->tproxy.port) { - payload_dependency_reset(&rctx.pdctx); + payload_dependency_reset(&dl->pdctx); expr_postprocess(&rctx, &stmt->tproxy.port); } break; @@ -2871,13 +3428,16 @@ static void rule_parse_postprocess(struct netlink_parse_ctx *ctx, struct rule *r case STMT_OBJREF: expr_postprocess(&rctx, &stmt->objref.expr); break; + case STMT_QUEUE: + stmt_queue_postprocess(&rctx); + break; default: break; } - rctx.pdctx.prev = rctx.stmt; + dl->pdctx.prev = rctx.stmt; - rule_maybe_reset_payload_deps(&rctx.pdctx, type); + rule_maybe_reset_payload_deps(&dl->pdctx, type); } } @@ -2929,6 +3489,7 @@ struct rule *netlink_delinearize_rule(struct netlink_ctx *ctx, memset(&_ctx, 0, sizeof(_ctx)); _ctx.msgs = ctx->msgs; _ctx.debug_mask = ctx->nft->debug_mask; + _ctx.nlctx = ctx; memset(&h, 0, sizeof(h)); h.family = nftnl_rule_get_u32(nlr, NFTNL_RULE_FAMILY); @@ -2942,7 +3503,10 @@ struct rule *netlink_delinearize_rule(struct netlink_ctx *ctx, pctx->rule = rule_alloc(&netlink_location, &h); pctx->table = table_cache_find(&ctx->nft->cache.table_cache, h.table.name, h.family); - assert(pctx->table != NULL); + if (!pctx->table) { + errno = ENOENT; + return NULL; + } pctx->rule->comment = nftnl_rule_get_comment(nlr); |