diff options
Diffstat (limited to 'src/evaluate.c')
-rw-r--r-- | src/evaluate.c | 2607 |
1 files changed, 2036 insertions, 571 deletions
diff --git a/src/evaluate.c b/src/evaluate.c index b64ed3c0..593a0140 100644 --- a/src/evaluate.c +++ b/src/evaluate.c @@ -8,11 +8,10 @@ * Development of this code funded by Astaro AG (http://www.astaro.com/) */ +#include <nft.h> + #include <stddef.h> -#include <stdlib.h> #include <stdio.h> -#include <stdint.h> -#include <string.h> #include <arpa/inet.h> #include <linux/netfilter.h> #include <linux/netfilter_arp.h> @@ -29,6 +28,7 @@ #include <expression.h> #include <statement.h> +#include <intervals.h> #include <netlink.h> #include <time.h> #include <rule.h> @@ -38,6 +38,13 @@ #include <utils.h> #include <xt.h> +struct proto_ctx *eval_proto_ctx(struct eval_ctx *ctx) +{ + uint8_t idx = ctx->inner_desc ? 1 : 0; + + return &ctx->_pctx[idx]; +} + static int expr_evaluate(struct eval_ctx *ctx, struct expr **expr); static const char * const byteorder_names[] = { @@ -67,6 +74,33 @@ static int __fmtstring(3, 4) set_error(struct eval_ctx *ctx, return -1; } +static const char *stmt_name(const struct stmt *stmt) +{ + switch (stmt->ops->type) { + case STMT_NAT: + switch (stmt->nat.type) { + case NFT_NAT_SNAT: + return "snat"; + case NFT_NAT_DNAT: + return "dnat"; + case NFT_NAT_REDIR: + return "redirect"; + case NFT_NAT_MASQ: + return "masquerade"; + } + break; + default: + break; + } + + return stmt->ops->name; +} + +static int stmt_error_range(struct eval_ctx *ctx, const struct stmt *stmt, const struct expr *e) +{ + return expr_error(ctx->msgs, e, "%s: range argument not supported", stmt_name(stmt)); +} + static void key_fix_dtype_byteorder(struct expr *key) { const struct datatype *dtype = key->dtype; @@ -74,7 +108,7 @@ static void key_fix_dtype_byteorder(struct expr *key) if (dtype->byteorder == key->byteorder) return; - datatype_set(key, set_datatype_alloc(dtype, key->byteorder)); + __datatype_set(key, set_datatype_alloc(dtype, key->byteorder)); } static int set_evaluate(struct eval_ctx *ctx, struct set *set); @@ -82,7 +116,8 @@ static struct expr *implicit_set_declaration(struct eval_ctx *ctx, const char *name, struct expr *key, struct expr *data, - struct expr *expr) + struct expr *expr, + uint32_t flags) { struct cmd *cmd; struct set *set; @@ -92,17 +127,25 @@ static struct expr *implicit_set_declaration(struct eval_ctx *ctx, key_fix_dtype_byteorder(key); set = set_alloc(&expr->location); - set->flags = NFT_SET_ANONYMOUS | expr->set_flags; + set->flags = expr->set_flags | flags; set->handle.set.name = xstrdup(name); set->key = key; set->data = data; set->init = expr; set->automerge = set->flags & NFT_SET_INTERVAL; + handle_merge(&set->handle, &ctx->cmd->handle); + + if (set_evaluate(ctx, set) < 0) { + if (set->flags & NFT_SET_MAP) + set->init = NULL; + set_free(set); + return NULL; + } + if (ctx->table != NULL) list_add_tail(&set->list, &ctx->table->sets); else { - handle_merge(&set->handle, &ctx->cmd->handle); memset(&h, 0, sizeof(h)); handle_merge(&h, &set->handle); h.set.location = expr->location; @@ -111,8 +154,6 @@ static struct expr *implicit_set_declaration(struct eval_ctx *ctx, list_add_tail(&cmd->list, &ctx->cmd->list); } - set_evaluate(ctx, set); - return set_ref_expr_alloc(&expr->location, set); } @@ -138,6 +179,7 @@ static enum ops byteorder_conversion_op(struct expr *expr, static int byteorder_conversion(struct eval_ctx *ctx, struct expr **expr, enum byteorder byteorder) { + enum datatypes basetype; enum ops op; assert(!expr_is_constant(*expr) || expr_is_singleton(*expr)); @@ -146,16 +188,54 @@ static int byteorder_conversion(struct eval_ctx *ctx, struct expr **expr, return 0; /* Conversion for EXPR_CONCAT is handled for single composing ranges */ - if ((*expr)->etype == EXPR_CONCAT) + if ((*expr)->etype == EXPR_CONCAT) { + struct expr *i, *next, *unary; + + list_for_each_entry_safe(i, next, &(*expr)->expressions, list) { + if (i->byteorder == BYTEORDER_BIG_ENDIAN) + continue; + + basetype = expr_basetype(i)->type; + if (basetype == TYPE_STRING) + continue; + + assert(basetype == TYPE_INTEGER); + + switch (i->etype) { + case EXPR_VALUE: + if (i->byteorder == BYTEORDER_HOST_ENDIAN) + mpz_switch_byteorder(i->value, div_round_up(i->len, BITS_PER_BYTE)); + break; + default: + if (div_round_up(i->len, BITS_PER_BYTE) >= 2) { + op = byteorder_conversion_op(i, byteorder); + unary = unary_expr_alloc(&i->location, op, i); + if (expr_evaluate(ctx, &unary) < 0) + return -1; + + list_replace(&i->list, &unary->list); + } + } + } + return 0; + } - if (expr_basetype(*expr)->type != TYPE_INTEGER) + basetype = expr_basetype(*expr)->type; + switch (basetype) { + case TYPE_INTEGER: + break; + case TYPE_STRING: + return 0; + default: return expr_error(ctx->msgs, *expr, - "Byteorder mismatch: expected %s, got %s", + "Byteorder mismatch: %s expected %s, %s got %s", byteorder_names[byteorder], + expr_name(*expr), byteorder_names[(*expr)->byteorder]); + } - if (expr_is_constant(*expr) || (*expr)->len / BITS_PER_BYTE < 2) + if (expr_is_constant(*expr) || div_round_up((*expr)->len, BITS_PER_BYTE) < 2) (*expr)->byteorder = byteorder; else { op = byteorder_conversion_op(*expr, byteorder); @@ -166,20 +246,6 @@ static int byteorder_conversion(struct eval_ctx *ctx, struct expr **expr, return 0; } -static struct table *table_lookup_global(struct eval_ctx *ctx) -{ - struct table *table; - - if (ctx->table != NULL) - return ctx->table; - - table = table_lookup(&ctx->cmd->handle, &ctx->nft->cache); - if (table == NULL) - return NULL; - - return table; -} - static int table_not_found(struct eval_ctx *ctx) { struct table *table; @@ -190,7 +256,7 @@ static int table_not_found(struct eval_ctx *ctx) "%s", strerror(ENOENT)); return cmd_error(ctx, &ctx->cmd->handle.table.location, - "%s; did you mean table ‘%s’ in family %s?", + "%s; did you mean table '%s' in family %s?", strerror(ENOENT), table->handle.table.name, family2str(table->handle.family)); } @@ -206,7 +272,7 @@ static int chain_not_found(struct eval_ctx *ctx) "%s", strerror(ENOENT)); return cmd_error(ctx, &ctx->cmd->handle.chain.location, - "%s; did you mean chain ‘%s’ in table %s ‘%s’?", + "%s; did you mean chain '%s' in table %s '%s'?", strerror(ENOENT), chain->handle.chain.name, family2str(chain->handle.family), table->handle.table.name); @@ -223,7 +289,7 @@ static int set_not_found(struct eval_ctx *ctx, const struct location *loc, return cmd_error(ctx, loc, "%s", strerror(ENOENT)); return cmd_error(ctx, loc, - "%s; did you mean %s ‘%s’ in table %s ‘%s’?", + "%s; did you mean %s '%s' in table %s '%s'?", strerror(ENOENT), set_is_map(set->flags) ? "map" : "set", set->handle.set.name, @@ -238,11 +304,11 @@ static int flowtable_not_found(struct eval_ctx *ctx, const struct location *loc, struct flowtable *ft; ft = flowtable_lookup_fuzzy(ft_name, &ctx->nft->cache, &table); - if (ft == NULL) + if (!ft) return cmd_error(ctx, loc, "%s", strerror(ENOENT)); return cmd_error(ctx, loc, - "%s; did you mean flowtable ‘%s’ in table %s ‘%s’?", + "%s; did you mean flowtable '%s' in table %s '%s'?", strerror(ENOENT), ft->handle.flowtable.name, family2str(ft->handle.family), table->handle.table.name); @@ -253,7 +319,10 @@ static int flowtable_not_found(struct eval_ctx *ctx, const struct location *loc, */ static int expr_evaluate_symbol(struct eval_ctx *ctx, struct expr **expr) { - struct parse_ctx parse_ctx = { .tbl = &ctx->nft->output.tbl, }; + struct parse_ctx parse_ctx = { + .tbl = &ctx->nft->output.tbl, + .input = &ctx->nft->input, + }; struct error_record *erec; struct table *table; struct set *set; @@ -269,12 +338,14 @@ static int expr_evaluate_symbol(struct eval_ctx *ctx, struct expr **expr) } break; case SYMBOL_SET: - table = table_lookup_global(ctx); + table = table_cache_find(&ctx->nft->cache.table_cache, + ctx->cmd->handle.table.name, + ctx->cmd->handle.family); if (table == NULL) return table_not_found(ctx); - set = set_lookup(table, (*expr)->identifier); - if (set == NULL) + set = set_cache_find(table, (*expr)->identifier); + if (set == NULL || !set->key) return set_not_found(ctx, &(*expr)->location, (*expr)->identifier); @@ -324,8 +395,11 @@ static int expr_evaluate_string(struct eval_ctx *ctx, struct expr **exprp) return 0; } - if (datalen >= 1 && - data[datalen - 1] == '\\') { + if (datalen == 0) + return expr_error(ctx->msgs, expr, + "All-wildcard strings are not supported"); + + if (data[datalen - 1] == '\\') { char unescaped_str[data_len]; memset(unescaped_str, 0, sizeof(unescaped_str)); @@ -338,15 +412,18 @@ static int expr_evaluate_string(struct eval_ctx *ctx, struct expr **exprp) *exprp = value; return 0; } + + data[datalen] = 0; value = constant_expr_alloc(&expr->location, ctx->ectx.dtype, BYTEORDER_HOST_ENDIAN, - datalen * BITS_PER_BYTE, data); + expr->len, data); prefix = prefix_expr_alloc(&expr->location, value, datalen * BITS_PER_BYTE); datatype_set(prefix, ctx->ectx.dtype); prefix->flags |= EXPR_F_CONSTANT; prefix->byteorder = BYTEORDER_HOST_ENDIAN; + prefix->len = expr->len; expr_free(expr); *exprp = prefix; @@ -357,6 +434,7 @@ static int expr_evaluate_integer(struct eval_ctx *ctx, struct expr **exprp) { struct expr *expr = *exprp; char *valstr, *rangestr; + uint32_t masklen; mpz_t mask; if (ctx->ectx.maxval > 0 && @@ -365,24 +443,29 @@ static int expr_evaluate_integer(struct eval_ctx *ctx, struct expr **exprp) expr_error(ctx->msgs, expr, "Value %s exceeds valid range 0-%u", valstr, ctx->ectx.maxval); - free(valstr); + nft_gmp_free(valstr); return -1; } - mpz_init_bitmask(mask, ctx->ectx.len); + if (ctx->stmt_len > ctx->ectx.len) + masklen = ctx->stmt_len; + else + masklen = ctx->ectx.len; + + mpz_init_bitmask(mask, masklen); if (mpz_cmp(expr->value, mask) > 0) { valstr = mpz_get_str(NULL, 10, expr->value); rangestr = mpz_get_str(NULL, 10, mask); expr_error(ctx->msgs, expr, "Value %s exceeds valid range 0-%s", valstr, rangestr); - free(valstr); - free(rangestr); + nft_gmp_free(valstr); + nft_gmp_free(rangestr); mpz_clear(mask); return -1; } expr->byteorder = ctx->ectx.byteorder; - expr->len = ctx->ectx.len; + expr->len = masklen; mpz_clear(mask); return 0; } @@ -414,20 +497,34 @@ static int expr_evaluate_primary(struct eval_ctx *ctx, struct expr **expr) return 0; } +int stmt_dependency_evaluate(struct eval_ctx *ctx, struct stmt *stmt) +{ + uint32_t stmt_len = ctx->stmt_len; + + if (stmt_evaluate(ctx, stmt) < 0) + return stmt_error(ctx, stmt, "dependency statement is invalid"); + + ctx->stmt_len = stmt_len; + + return 0; +} + static int -conflict_resolution_gen_dependency(struct eval_ctx *ctx, int protocol, - const struct expr *expr, - struct stmt **res) +ll_conflict_resolution_gen_dependency(struct eval_ctx *ctx, int protocol, + const struct expr *expr, + struct stmt **res) { enum proto_bases base = expr->payload.base; const struct proto_hdr_template *tmpl; const struct proto_desc *desc = NULL; struct expr *dep, *left, *right; + struct proto_ctx *pctx; struct stmt *stmt; assert(expr->payload.base == PROTO_BASE_LL_HDR); - desc = ctx->pctx.protocol[base].desc; + pctx = eval_proto_ctx(ctx); + desc = pctx->protocol[base].desc; tmpl = &desc->templates[desc->protocol_key]; left = payload_expr_alloc(&expr->location, desc, desc->protocol_key); @@ -437,10 +534,13 @@ conflict_resolution_gen_dependency(struct eval_ctx *ctx, int protocol, dep = relational_expr_alloc(&expr->location, OP_EQ, left, right); stmt = expr_stmt_alloc(&dep->location, dep); - if (stmt_evaluate(ctx, stmt) < 0) + if (stmt_dependency_evaluate(ctx, stmt) < 0) return expr_error(ctx->msgs, expr, "dependency statement is invalid"); + if (ctx->inner_desc) + left->payload.inner_desc = ctx->inner_desc; + *res = stmt; return 0; } @@ -461,10 +561,11 @@ static uint8_t expr_offset_shift(const struct expr *expr, unsigned int offset, return shift; } -static void expr_evaluate_bits(struct eval_ctx *ctx, struct expr **exprp) +static int expr_evaluate_bits(struct eval_ctx *ctx, struct expr **exprp) { struct expr *expr = *exprp, *and, *mask, *rshift, *off; unsigned masklen, len = expr->len, extra_len = 0; + enum byteorder byteorder; uint8_t shift; mpz_t bitmask; @@ -474,7 +575,7 @@ static void expr_evaluate_bits(struct eval_ctx *ctx, struct expr **exprp) &extra_len); break; case EXPR_EXTHDR: - shift = expr_offset_shift(expr, expr->exthdr.tmpl->offset, + shift = expr_offset_shift(expr, expr->exthdr.offset, &extra_len); break; default: @@ -482,7 +583,10 @@ static void expr_evaluate_bits(struct eval_ctx *ctx, struct expr **exprp) } masklen = len + shift; - assert(masklen <= NFT_REG_SIZE * BITS_PER_BYTE); + + if (masklen > NFT_REG_SIZE * BITS_PER_BYTE) + return expr_error(ctx->msgs, expr, "mask length %u exceeds allowed maximum of %u\n", + masklen, NFT_REG_SIZE * BITS_PER_BYTE); mpz_init2(bitmask, masklen); mpz_bitmask(bitmask, len); @@ -499,6 +603,16 @@ static void expr_evaluate_bits(struct eval_ctx *ctx, struct expr **exprp) and->len = masklen; if (shift) { + if ((ctx->ectx.key || ctx->stmt_len > 0) && + div_round_up(masklen, BITS_PER_BYTE) > 1) { + int op = byteorder_conversion_op(expr, BYTEORDER_HOST_ENDIAN); + and = unary_expr_alloc(&expr->location, op, and); + and->len = masklen; + byteorder = BYTEORDER_HOST_ENDIAN; + } else { + byteorder = expr->byteorder; + } + off = constant_expr_alloc(&expr->location, expr_basetype(expr), BYTEORDER_HOST_ENDIAN, @@ -506,7 +620,7 @@ static void expr_evaluate_bits(struct eval_ctx *ctx, struct expr **exprp) rshift = binop_expr_alloc(&expr->location, OP_RSHIFT, and, off); rshift->dtype = expr->dtype; - rshift->byteorder = expr->byteorder; + rshift->byteorder = byteorder; rshift->len = masklen; *exprp = rshift; @@ -515,10 +629,13 @@ static void expr_evaluate_bits(struct eval_ctx *ctx, struct expr **exprp) if (extra_len) expr->len += extra_len; + + return 0; } static int __expr_evaluate_exthdr(struct eval_ctx *ctx, struct expr **exprp) { + const struct expr *key = ctx->ectx.key; struct expr *expr = *exprp; if (expr->exthdr.flags & NFT_EXTHDR_F_PRESENT) @@ -527,18 +644,22 @@ static int __expr_evaluate_exthdr(struct eval_ctx *ctx, struct expr **exprp) if (expr_evaluate_primary(ctx, exprp) < 0) return -1; - if (expr->exthdr.tmpl->offset % BITS_PER_BYTE != 0 || - expr->len % BITS_PER_BYTE != 0) - expr_evaluate_bits(ctx, exprp); + ctx->ectx.key = key; + + if (expr->exthdr.offset % BITS_PER_BYTE != 0 || + expr->len % BITS_PER_BYTE != 0) { + int err = expr_evaluate_bits(ctx, exprp); + + if (err) + return err; + } switch (expr->exthdr.op) { case NFT_EXTHDR_OP_TCPOPT: { static const unsigned int max_tcpoptlen = (15 * 4 - 20) * BITS_PER_BYTE; - unsigned int totlen = 0; + unsigned int totlen; - totlen += expr->exthdr.tmpl->offset; - totlen += expr->exthdr.tmpl->len; - totlen += expr->exthdr.offset; + totlen = expr->exthdr.tmpl->len + expr->exthdr.offset; if (totlen > max_tcpoptlen) return expr_error(ctx->msgs, expr, @@ -548,11 +669,9 @@ static int __expr_evaluate_exthdr(struct eval_ctx *ctx, struct expr **exprp) } case NFT_EXTHDR_OP_IPV4: { static const unsigned int max_ipoptlen = 40 * BITS_PER_BYTE; - unsigned int totlen = 0; + unsigned int totlen; - totlen += expr->exthdr.tmpl->offset; - totlen += expr->exthdr.tmpl->len; - totlen += expr->exthdr.offset; + totlen = expr->exthdr.offset + expr->exthdr.tmpl->len; if (totlen > max_ipoptlen) return expr_error(ctx->msgs, expr, @@ -577,13 +696,14 @@ static int expr_evaluate_exthdr(struct eval_ctx *ctx, struct expr **exprp) const struct proto_desc *base, *dependency = NULL; enum proto_bases pb = PROTO_BASE_NETWORK_HDR; struct expr *expr = *exprp; + struct proto_ctx *pctx; struct stmt *nstmt; switch (expr->exthdr.op) { case NFT_EXTHDR_OP_TCPOPT: - dependency = &proto_tcp; - pb = PROTO_BASE_TRANSPORT_HDR; - break; + case NFT_EXTHDR_OP_SCTP: + case NFT_EXTHDR_OP_DCCP: + return __expr_evaluate_exthdr(ctx, exprp); case NFT_EXTHDR_OP_IPV4: dependency = &proto_ip; break; @@ -595,7 +715,8 @@ static int expr_evaluate_exthdr(struct eval_ctx *ctx, struct expr **exprp) assert(dependency); - base = ctx->pctx.protocol[pb].desc; + pctx = eval_proto_ctx(ctx); + base = pctx->protocol[pb].desc; if (base == dependency) return __expr_evaluate_exthdr(ctx, exprp); @@ -613,7 +734,7 @@ static int expr_evaluate_exthdr(struct eval_ctx *ctx, struct expr **exprp) /* dependency supersede. * - * 'inet' is a 'phony' l2 dependency used by NFPROTO_INET to fulfill network + * 'inet' is a 'phony' l2 dependency used by NFPROTO_INET to fulfil network * header dependency, i.e. ensure that 'ip saddr 1.2.3.4' only sees ip headers. * * If a match expression that depends on a particular L2 header, e.g. ethernet, @@ -625,7 +746,7 @@ static int expr_evaluate_exthdr(struct eval_ctx *ctx, struct expr **exprp) * Example: inet filter in ip saddr 1.2.3.4 ether saddr a:b:c:d:e:f * * ip saddr adds meta dependency on ipv4 packets - * ether saddr adds another dependeny on ethernet frames. + * ether saddr adds another dependency on ethernet frames. */ static int meta_iiftype_gen_dependency(struct eval_ctx *ctx, struct expr *payload, struct stmt **res) @@ -639,9 +760,11 @@ static int meta_iiftype_gen_dependency(struct eval_ctx *ctx, "for this family"); nstmt = meta_stmt_meta_iiftype(&payload->location, type); - if (stmt_evaluate(ctx, nstmt) < 0) - return expr_error(ctx->msgs, payload, - "dependency statement is invalid"); + if (stmt_dependency_evaluate(ctx, nstmt) < 0) + return -1; + + if (ctx->inner_desc) + nstmt->expr->left->meta.inner_desc = ctx->inner_desc; *res = nstmt; return 0; @@ -652,35 +775,92 @@ static bool proto_is_dummy(const struct proto_desc *desc) return desc == &proto_inet || desc == &proto_netdev; } -static int resolve_protocol_conflict(struct eval_ctx *ctx, - const struct proto_desc *desc, - struct expr *payload) +static int stmt_dep_conflict(struct eval_ctx *ctx, const struct stmt *nstmt) +{ + struct stmt *stmt; + + list_for_each_entry(stmt, &ctx->rule->stmts, list) { + if (stmt == nstmt) + break; + + if (stmt->ops->type != STMT_EXPRESSION || + stmt->expr->etype != EXPR_RELATIONAL || + stmt->expr->right->etype != EXPR_VALUE || + stmt->expr->left->etype != EXPR_PAYLOAD || + stmt->expr->left->etype != nstmt->expr->left->etype || + stmt->expr->left->len != nstmt->expr->left->len) + continue; + + if (stmt->expr->left->payload.desc != nstmt->expr->left->payload.desc || + stmt->expr->left->payload.inner_desc != nstmt->expr->left->payload.inner_desc || + stmt->expr->left->payload.base != nstmt->expr->left->payload.base || + stmt->expr->left->payload.offset != nstmt->expr->left->payload.offset) + continue; + + return stmt_binary_error(ctx, stmt, nstmt, + "conflicting statements"); + } + + return 0; +} + +static int rule_stmt_dep_add(struct eval_ctx *ctx, + struct stmt *nstmt, struct stmt *stmt) +{ + rule_stmt_insert_at(ctx->rule, nstmt, ctx->stmt); + + if (stmt_dep_conflict(ctx, nstmt) < 0) + return -1; + + return 0; +} + +static int resolve_ll_protocol_conflict(struct eval_ctx *ctx, + const struct proto_desc *desc, + struct expr *payload) { enum proto_bases base = payload->payload.base; struct stmt *nstmt = NULL; + struct proto_ctx *pctx; + unsigned int i; int link, err; - if (payload->payload.base == PROTO_BASE_LL_HDR && - proto_is_dummy(desc)) { - err = meta_iiftype_gen_dependency(ctx, payload, &nstmt); - if (err < 0) - return err; + assert(base == PROTO_BASE_LL_HDR); - rule_stmt_insert_at(ctx->rule, nstmt, ctx->stmt); - } + pctx = eval_proto_ctx(ctx); - assert(base <= PROTO_BASE_MAX); - /* This payload and the existing context don't match, conflict. */ - if (ctx->pctx.protocol[base + 1].desc != NULL) - return 1; + if (proto_is_dummy(desc)) { + if (ctx->inner_desc) { + proto_ctx_update(pctx, PROTO_BASE_LL_HDR, &payload->location, &proto_eth); + } else { + err = meta_iiftype_gen_dependency(ctx, payload, &nstmt); + if (err < 0) + return err; + + desc = payload->payload.desc; + if (rule_stmt_dep_add(ctx, nstmt, ctx->stmt) < 0) + return -1; + } + } else { + unsigned int i; + + /* payload desc stored in the L2 header stack? No conflict. */ + for (i = 0; i < pctx->stacked_ll_count; i++) { + if (pctx->stacked_ll[i] == payload->payload.desc) + return 0; + } + } link = proto_find_num(desc, payload->payload.desc); if (link < 0 || - conflict_resolution_gen_dependency(ctx, link, payload, &nstmt) < 0) + ll_conflict_resolution_gen_dependency(ctx, link, payload, &nstmt) < 0) return 1; - payload->payload.offset += ctx->pctx.protocol[base].offset; - rule_stmt_insert_at(ctx->rule, nstmt, ctx->stmt); + for (i = 0; i < pctx->stacked_ll_count; i++) + payload->payload.offset += pctx->stacked_ll[i]->length; + + if (rule_stmt_dep_add(ctx, nstmt, ctx->stmt) < 0) + return -1; return 0; } @@ -695,45 +875,106 @@ static int __expr_evaluate_payload(struct eval_ctx *ctx, struct expr *expr) struct expr *payload = expr; enum proto_bases base = payload->payload.base; const struct proto_desc *desc; + struct proto_ctx *pctx; struct stmt *nstmt; int err; if (expr->etype == EXPR_PAYLOAD && expr->payload.is_raw) return 0; - desc = ctx->pctx.protocol[base].desc; + pctx = eval_proto_ctx(ctx); + desc = pctx->protocol[base].desc; if (desc == NULL) { if (payload_gen_dependency(ctx, payload, &nstmt) < 0) return -1; - rule_stmt_insert_at(ctx->rule, nstmt, ctx->stmt); - } else { - /* No conflict: Same payload protocol as context, adjust offset - * if needed. - */ - if (desc == payload->payload.desc) { - payload->payload.offset += - ctx->pctx.protocol[base].offset; - return 0; - } - /* If we already have context and this payload is on the same - * base, try to resolve the protocol conflict. - */ - if (payload->payload.base == desc->base) { - err = resolve_protocol_conflict(ctx, desc, payload); - if (err <= 0) - return err; + if (rule_stmt_dep_add(ctx, nstmt, ctx->stmt) < 0) + return -1; - desc = ctx->pctx.protocol[base].desc; - if (desc == payload->payload.desc) - return 0; + desc = pctx->protocol[base].desc; + + if (desc == expr->payload.desc) + goto check_icmp; + + if (base == PROTO_BASE_LL_HDR) { + int link; + + link = proto_find_num(desc, payload->payload.desc); + if (link < 0 || + ll_conflict_resolution_gen_dependency(ctx, link, payload, &nstmt) < 0) + return expr_error(ctx->msgs, payload, + "conflicting protocols specified: %s vs. %s", + desc->name, + payload->payload.desc->name); + + assert(pctx->stacked_ll_count); + payload->payload.offset += pctx->stacked_ll[0]->length; + + if (rule_stmt_dep_add(ctx, nstmt, ctx->stmt) < 0) + return -1; + + return 1; } + goto check_icmp; + } + + if (payload->payload.base == desc->base && + proto_ctx_is_ambiguous(pctx, base)) { + desc = proto_ctx_find_conflict(pctx, base, payload->payload.desc); + assert(desc); + return expr_error(ctx->msgs, payload, "conflicting protocols specified: %s vs. %s", - ctx->pctx.protocol[base].desc->name, + desc->name, payload->payload.desc->name); } - return 0; + + /* No conflict: Same payload protocol as context, adjust offset + * if needed. + */ + if (desc == payload->payload.desc) { + const struct proto_hdr_template *tmpl; + + if (desc->base == PROTO_BASE_LL_HDR) { + unsigned int i; + + for (i = 0; i < pctx->stacked_ll_count; i++) + payload->payload.offset += pctx->stacked_ll[i]->length; + } +check_icmp: + if (desc != &proto_icmp && desc != &proto_icmp6) + return 0; + + tmpl = expr->payload.tmpl; + + if (!tmpl || !tmpl->icmp_dep) + return 0; + + if (payload_gen_icmp_dependency(ctx, expr, &nstmt) < 0) + return -1; + + if (nstmt && rule_stmt_dep_add(ctx, nstmt, ctx->stmt) < 0) + return -1; + + return 0; + } + /* If we already have context and this payload is on the same + * base, try to resolve the protocol conflict. + */ + if (base == PROTO_BASE_LL_HDR) { + err = resolve_ll_protocol_conflict(ctx, desc, payload); + if (err <= 0) + return err; + + desc = pctx->protocol[base].desc; + if (desc == payload->payload.desc) + return 0; + } + return expr_error(ctx->msgs, payload, + "conflicting %s protocols specified: %s vs. %s", + proto_base_names[base], + pctx->protocol[base].desc->name, + payload->payload.desc->name); } static bool payload_needs_adjustment(const struct expr *expr) @@ -744,6 +985,7 @@ static bool payload_needs_adjustment(const struct expr *expr) static int expr_evaluate_payload(struct eval_ctx *ctx, struct expr **exprp) { + const struct expr *key = ctx->ectx.key; struct expr *expr = *exprp; if (expr->payload.evaluated) @@ -755,14 +997,82 @@ static int expr_evaluate_payload(struct eval_ctx *ctx, struct expr **exprp) if (expr_evaluate_primary(ctx, exprp) < 0) return -1; - if (payload_needs_adjustment(expr)) - expr_evaluate_bits(ctx, exprp); + ctx->ectx.key = key; + + if (payload_needs_adjustment(expr)) { + int err = expr_evaluate_bits(ctx, exprp); + + if (err) + return err; + } expr->payload.evaluated = true; return 0; } +static int expr_evaluate_inner(struct eval_ctx *ctx, struct expr **exprp) +{ + struct proto_ctx *pctx = eval_proto_ctx(ctx); + const struct proto_desc *desc = NULL; + struct expr *expr = *exprp; + int ret; + + assert(expr->etype == EXPR_PAYLOAD); + + pctx = eval_proto_ctx(ctx); + desc = pctx->protocol[PROTO_BASE_TRANSPORT_HDR].desc; + + if (desc == NULL && + expr->payload.inner_desc->base < PROTO_BASE_INNER_HDR) { + struct stmt *nstmt; + + if (payload_gen_inner_dependency(ctx, expr, &nstmt) < 0) + return -1; + + if (rule_stmt_dep_add(ctx, nstmt, ctx->stmt) < 0) + return -1; + + proto_ctx_update(pctx, PROTO_BASE_TRANSPORT_HDR, &expr->location, expr->payload.inner_desc); + } + + if (expr->payload.inner_desc->base == PROTO_BASE_INNER_HDR) { + desc = pctx->protocol[expr->payload.inner_desc->base - 1].desc; + if (!desc) { + return expr_error(ctx->msgs, expr, + "no transport protocol specified"); + } + + if (proto_find_num(desc, expr->payload.inner_desc) < 0) { + return expr_error(ctx->msgs, expr, + "unexpected transport protocol %s", + desc->name); + } + + proto_ctx_update(pctx, expr->payload.inner_desc->base, &expr->location, + expr->payload.inner_desc); + } + + if (expr->payload.base != PROTO_BASE_INNER_HDR) + ctx->inner_desc = expr->payload.inner_desc; + + ret = expr_evaluate_payload(ctx, exprp); + + return ret; +} + +static int expr_evaluate_payload_inner(struct eval_ctx *ctx, struct expr **exprp) +{ + int ret; + + if ((*exprp)->payload.inner_desc) + ret = expr_evaluate_inner(ctx, exprp); + else + ret = expr_evaluate_payload(ctx, exprp); + + return ret; +} + /* * RT expression: validate protocol dependencies. */ @@ -770,20 +1080,22 @@ static int expr_evaluate_rt(struct eval_ctx *ctx, struct expr **expr) { static const char emsg[] = "cannot determine ip protocol version, use \"ip nexthop\" or \"ip6 nexthop\" instead"; struct expr *rt = *expr; + struct proto_ctx *pctx; - rt_expr_update_type(&ctx->pctx, rt); + pctx = eval_proto_ctx(ctx); + rt_expr_update_type(pctx, rt); switch (rt->rt.key) { case NFT_RT_NEXTHOP4: if (rt->dtype != &ipaddr_type) return expr_error(ctx->msgs, rt, "%s", emsg); - if (ctx->pctx.family == NFPROTO_IPV6) + if (pctx->family == NFPROTO_IPV6) return expr_error(ctx->msgs, rt, "%s nexthop will not match", "ip"); break; case NFT_RT_NEXTHOP6: if (rt->dtype != &ip6addr_type) return expr_error(ctx->msgs, rt, "%s", emsg); - if (ctx->pctx.family == NFPROTO_IPV4) + if (pctx->family == NFPROTO_IPV4) return expr_error(ctx->msgs, rt, "%s nexthop will not match", "ip6"); break; default: @@ -798,8 +1110,10 @@ static int ct_gen_nh_dependency(struct eval_ctx *ctx, struct expr *ct) const struct proto_desc *base, *base_now; struct expr *left, *right, *dep; struct stmt *nstmt = NULL; + struct proto_ctx *pctx; - base_now = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; + pctx = eval_proto_ctx(ctx); + base_now = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc; switch (ct->ct.nfproto) { case NFPROTO_IPV4: @@ -809,10 +1123,10 @@ static int ct_gen_nh_dependency(struct eval_ctx *ctx, struct expr *ct) base = &proto_ip6; break; default: - base = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; + base = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc; if (base == &proto_ip) ct->ct.nfproto = NFPROTO_IPV4; - else if (base == &proto_ip) + else if (base == &proto_ip6) ct->ct.nfproto = NFPROTO_IPV6; if (base) @@ -831,8 +1145,8 @@ static int ct_gen_nh_dependency(struct eval_ctx *ctx, struct expr *ct) return expr_error(ctx->msgs, ct, "conflicting dependencies: %s vs. %s\n", base->name, - ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc->name); - switch (ctx->pctx.family) { + pctx->protocol[PROTO_BASE_NETWORK_HDR].desc->name); + switch (pctx->family) { case NFPROTO_IPV4: case NFPROTO_IPV6: return 0; @@ -845,10 +1159,12 @@ static int ct_gen_nh_dependency(struct eval_ctx *ctx, struct expr *ct) constant_data_ptr(ct->ct.nfproto, left->len)); dep = relational_expr_alloc(&ct->location, OP_EQ, left, right); - relational_expr_pctx_update(&ctx->pctx, dep); + relational_expr_pctx_update(pctx, dep); nstmt = expr_stmt_alloc(&dep->location, dep); - rule_stmt_insert_at(ctx->rule, nstmt, ctx->stmt); + + if (rule_stmt_dep_add(ctx, nstmt, ctx->stmt) < 0) + return -1; return 0; } @@ -861,8 +1177,10 @@ static int expr_evaluate_ct(struct eval_ctx *ctx, struct expr **expr) { const struct proto_desc *base, *error; struct expr *ct = *expr; + struct proto_ctx *pctx; - base = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; + pctx = eval_proto_ctx(ctx); + base = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc; switch (ct->ct.key) { case NFT_CT_SRC: @@ -887,13 +1205,13 @@ static int expr_evaluate_ct(struct eval_ctx *ctx, struct expr **expr) break; } - ct_expr_update_type(&ctx->pctx, ct); + ct_expr_update_type(pctx, ct); return expr_evaluate_primary(ctx, expr); err_conflict: return stmt_binary_error(ctx, ct, - &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR], + &pctx->protocol[PROTO_BASE_NETWORK_HDR], "conflicting protocols specified: %s vs. %s", base->name, error->name); } @@ -940,7 +1258,6 @@ static int expr_evaluate_prefix(struct eval_ctx *ctx, struct expr **expr) mpz_prefixmask(mask->value, base->len, prefix->prefix_len); break; case TYPE_STRING: - mpz_init2(mask->value, base->len); mpz_bitmask(mask->value, prefix->prefix_len); break; } @@ -951,7 +1268,7 @@ static int expr_evaluate_prefix(struct eval_ctx *ctx, struct expr **expr) base = prefix->prefix; assert(expr_is_constant(base)); - prefix->dtype = base->dtype; + prefix->dtype = datatype_get(base->dtype); prefix->byteorder = base->byteorder; prefix->len = base->len; prefix->flags |= EXPR_F_CONSTANT; @@ -1002,9 +1319,8 @@ static int expr_evaluate_range(struct eval_ctx *ctx, struct expr **expr) left = range->left; right = range->right; - if (mpz_cmp(left->value, right->value) >= 0) - return expr_error(ctx->msgs, range, - "Range has zero or negative size"); + if (mpz_cmp(left->value, right->value) > 0) + return expr_error(ctx->msgs, range, "Range negative size"); datatype_set(range, left->dtype); range->flags |= EXPR_F_CONSTANT; @@ -1017,12 +1333,10 @@ static int expr_evaluate_range(struct eval_ctx *ctx, struct expr **expr) */ static int expr_evaluate_unary(struct eval_ctx *ctx, struct expr **expr) { - struct expr *unary = *expr, *arg; + struct expr *unary = *expr, *arg = unary->arg; enum byteorder byteorder; - if (expr_evaluate(ctx, &unary->arg) < 0) - return -1; - arg = unary->arg; + /* unary expression arguments has already been evaluated. */ assert(!expr_is_constant(arg)); assert(expr_basetype(arg)->type == TYPE_INTEGER); @@ -1041,7 +1355,7 @@ static int expr_evaluate_unary(struct eval_ctx *ctx, struct expr **expr) BUG("invalid unary operation %u\n", unary->op); } - unary->dtype = arg->dtype; + unary->dtype = datatype_clone(arg->dtype); unary->byteorder = byteorder; unary->len = arg->len; return 0; @@ -1108,14 +1422,24 @@ static int constant_binop_simplify(struct eval_ctx *ctx, struct expr **expr) static int expr_evaluate_shift(struct eval_ctx *ctx, struct expr **expr) { struct expr *op = *expr, *left = op->left, *right = op->right; + unsigned int shift, max_shift_len; - if (mpz_get_uint32(right->value) >= left->len) + /* mpz_get_uint32 has assert() for huge values */ + if (mpz_cmp_ui(right->value, UINT_MAX) > 0) return expr_binary_error(ctx->msgs, right, left, - "%s shift of %u bits is undefined " - "for type of %u bits width", + "shifts exceeding %u bits are not supported", UINT_MAX); + + shift = mpz_get_uint32(right->value); + if (ctx->stmt_len > left->len) + max_shift_len = ctx->stmt_len; + else + max_shift_len = left->len; + + if (shift >= max_shift_len) + return expr_binary_error(ctx->msgs, right, left, + "%s shift of %u bits is undefined for type of %u bits width", op->op == OP_LSHIFT ? "Left" : "Right", - mpz_get_uint32(right->value), - left->len); + shift, max_shift_len); /* Both sides need to be in host byte order */ if (byteorder_conversion(ctx, &op->left, BYTEORDER_HOST_ENDIAN) < 0) @@ -1124,9 +1448,9 @@ static int expr_evaluate_shift(struct eval_ctx *ctx, struct expr **expr) if (byteorder_conversion(ctx, &op->right, BYTEORDER_HOST_ENDIAN) < 0) return -1; - op->dtype = &integer_type; + datatype_set(op, &integer_type); op->byteorder = BYTEORDER_HOST_ENDIAN; - op->len = left->len; + op->len = max_shift_len; if (expr_is_constant(left)) return constant_binop_simplify(ctx, expr); @@ -1136,13 +1460,32 @@ static int expr_evaluate_shift(struct eval_ctx *ctx, struct expr **expr) static int expr_evaluate_bitwise(struct eval_ctx *ctx, struct expr **expr) { struct expr *op = *expr, *left = op->left; + const struct datatype *dtype; + enum byteorder byteorder; + unsigned int max_len; + + if (ctx->stmt_len > left->len) { + max_len = ctx->stmt_len; + byteorder = BYTEORDER_HOST_ENDIAN; + dtype = &integer_type; + + /* Both sides need to be in host byte order */ + if (byteorder_conversion(ctx, &op->left, BYTEORDER_HOST_ENDIAN) < 0) + return -1; - if (byteorder_conversion(ctx, &op->right, left->byteorder) < 0) + left = op->left; + } else { + max_len = left->len; + byteorder = left->byteorder; + dtype = left->dtype; + } + + if (byteorder_conversion(ctx, &op->right, byteorder) < 0) return -1; - op->dtype = left->dtype; - op->byteorder = left->byteorder; - op->len = left->len; + datatype_set(op, dtype); + op->byteorder = byteorder; + op->len = max_len; if (expr_is_constant(left)) return constant_binop_simplify(ctx, expr); @@ -1159,14 +1502,27 @@ static int expr_evaluate_binop(struct eval_ctx *ctx, struct expr **expr) { struct expr *op = *expr, *left, *right; const char *sym = expr_op_symbols[op->op]; + unsigned int max_shift_len = ctx->ectx.len; + int ret = -1; + + if (ctx->recursion >= USHRT_MAX) + return expr_binary_error(ctx->msgs, op, NULL, + "Binary operation limit %u reached ", + ctx->recursion); + ctx->recursion++; if (expr_evaluate(ctx, &op->left) < 0) return -1; left = op->left; - if (op->op == OP_LSHIFT || op->op == OP_RSHIFT) + if (op->op == OP_LSHIFT || op->op == OP_RSHIFT) { + if (left->len > max_shift_len) + max_shift_len = left->len; + __expr_set_context(&ctx->ectx, &integer_type, - left->byteorder, ctx->ectx.len, 0); + left->byteorder, max_shift_len, 0); + } + if (expr_evaluate(ctx, &op->right) < 0) return -1; right = op->right; @@ -1199,20 +1555,51 @@ static int expr_evaluate_binop(struct eval_ctx *ctx, struct expr **expr) "for %s expressions", sym, expr_name(right)); - /* The grammar guarantees this */ - assert(expr_basetype(left) == expr_basetype(right)); + if (!datatype_equal(expr_basetype(left), expr_basetype(right))) + return expr_binary_error(ctx->msgs, left, op, + "Binary operation (%s) with different base types " + "(%s vs %s) is not supported", + sym, expr_basetype(left)->name, expr_basetype(right)->name); switch (op->op) { case OP_LSHIFT: case OP_RSHIFT: - return expr_evaluate_shift(ctx, expr); + ret = expr_evaluate_shift(ctx, expr); + break; case OP_AND: case OP_XOR: case OP_OR: - return expr_evaluate_bitwise(ctx, expr); + ret = expr_evaluate_bitwise(ctx, expr); + break; default: BUG("invalid binary operation %u\n", op->op); } + + + if (ctx->recursion == 0) + BUG("recursion counter underflow"); + + /* can't check earlier: evaluate functions might do constant-merging + expr_free. + * + * So once we've evaluate everything check for remaining length of the + * binop chain. + */ + if (--ctx->recursion == 0) { + unsigned int to_linearize = 0; + + op = *expr; + while (op && op->etype == EXPR_BINOP && op->left != NULL) { + to_linearize++; + op = op->left; + + if (to_linearize >= NFT_MAX_EXPR_RECURSION) + return expr_binary_error(ctx->msgs, op, NULL, + "Binary operation limit %u reached ", + NFT_MAX_EXPR_RECURSION); + } + } + + return ret; } static int list_member_evaluate(struct eval_ctx *ctx, struct expr **expr) @@ -1227,17 +1614,39 @@ static int list_member_evaluate(struct eval_ctx *ctx, struct expr **expr) return err; } -static int expr_evaluate_concat(struct eval_ctx *ctx, struct expr **expr, - bool eval) +static int expr_evaluate_concat(struct eval_ctx *ctx, struct expr **expr) { const struct datatype *dtype = ctx->ectx.dtype, *tmp; uint32_t type = dtype ? dtype->type : 0, ntype = 0; int off = dtype ? dtype->subtypes : 0; unsigned int flags = EXPR_F_CONSTANT | EXPR_F_SINGLETON; - struct expr *i, *next; + const struct list_head *expressions = NULL; + struct expr *i, *next, *key = NULL; + const struct expr *key_ctx = NULL; + bool runaway = false; + uint32_t size = 0; + + if (ctx->ectx.key && ctx->ectx.key->etype == EXPR_CONCAT) { + key_ctx = ctx->ectx.key; + assert(!list_empty(&ctx->ectx.key->expressions)); + key = list_first_entry(&ctx->ectx.key->expressions, struct expr, list); + expressions = &ctx->ectx.key->expressions; + } list_for_each_entry_safe(i, next, &(*expr)->expressions, list) { - unsigned dsize_bytes; + enum byteorder bo = BYTEORDER_INVALID; + unsigned dsize_bytes, dsize = 0; + + if (runaway) { + return expr_binary_error(ctx->msgs, *expr, key_ctx, + "too many concatenation components"); + } + + if (i->etype == EXPR_CT && + (i->ct.key == NFT_CT_SRC || + i->ct.key == NFT_CT_DST)) + return expr_error(ctx->msgs, i, + "specify either ip or ip6 for address matching"); if (expr_is_constant(*expr) && dtype && off == 0) return expr_binary_error(ctx->msgs, i, *expr, @@ -1245,32 +1654,77 @@ static int expr_evaluate_concat(struct eval_ctx *ctx, struct expr **expr, "expecting %s", dtype->desc); - if (dtype == NULL) + if (key) { + tmp = key->dtype; + dsize = key->len; + bo = key->byteorder; + off--; + } else if (dtype == NULL || off == 0) { tmp = datatype_lookup(TYPE_INVALID); - else + } else { tmp = concat_subtype_lookup(type, --off); - expr_set_context(&ctx->ectx, tmp, tmp->size); + dsize = tmp->size; + bo = tmp->byteorder; + } + + __expr_set_context(&ctx->ectx, tmp, bo, dsize, 0); + ctx->ectx.key = i; - if (eval && list_member_evaluate(ctx, &i) < 0) + if (list_member_evaluate(ctx, &i) < 0) return -1; + + if (i->etype == EXPR_SET) + return expr_error(ctx->msgs, i, + "cannot use %s in concatenation", + expr_name(i)); + + if (!i->dtype) + return expr_error(ctx->msgs, i, + "cannot use %s in concatenation, lacks datatype", + expr_name(i)); + flags &= i->flags; + if (!key && i->dtype->type == TYPE_INTEGER) { + struct datatype *clone; + + clone = datatype_clone(i->dtype); + clone->size = i->len; + clone->byteorder = i->byteorder; + __datatype_set(i, clone); + } + if (dtype == NULL && i->dtype->size == 0) return expr_binary_error(ctx->msgs, i, *expr, "can not use variable sized " "data types (%s) in concat " "expressions", i->dtype->name); + if (dsize == 0) /* reload after evaluation or clone above */ + dsize = i->dtype->size; ntype = concat_subtype_add(ntype, i->dtype->type); - dsize_bytes = div_round_up(i->dtype->size, BITS_PER_BYTE); + dsize_bytes = div_round_up(dsize, BITS_PER_BYTE); (*expr)->field_len[(*expr)->field_count++] = dsize_bytes; + size += netlink_padded_len(dsize); + if (key && expressions) { + if (list_is_last(&key->list, expressions)) + runaway = true; + else + key = list_next_entry(key, list); + } + + ctx->inner_desc = NULL; + + if (size > NFT_MAX_EXPR_LEN_BITS) + return expr_error(ctx->msgs, i, "Concatenation of size %u exceeds maximum size of %u", + size, NFT_MAX_EXPR_LEN_BITS); } (*expr)->flags |= flags; - datatype_set(*expr, concat_type_alloc(ntype)); - (*expr)->len = (*expr)->dtype->size; + __datatype_set(*expr, concat_type_alloc(ntype)); + (*expr)->len = size; if (off > 0) return expr_error(ctx->msgs, *expr, @@ -1279,6 +1733,10 @@ static int expr_evaluate_concat(struct eval_ctx *ctx, struct expr **expr, dtype->desc, (*expr)->dtype->desc); expr_set_context(&ctx->ectx, (*expr)->dtype, (*expr)->len); + if (!key_ctx) + ctx->ectx.key = *expr; + else + ctx->ectx.key = key_ctx; return 0; } @@ -1290,16 +1748,22 @@ static int expr_evaluate_list(struct eval_ctx *ctx, struct expr **expr) mpz_init_set_ui(val, 0); list_for_each_entry_safe(i, next, &list->expressions, list) { - if (list_member_evaluate(ctx, &i) < 0) + if (list_member_evaluate(ctx, &i) < 0) { + mpz_clear(val); return -1; - if (i->etype != EXPR_VALUE) + } + if (i->etype != EXPR_VALUE) { + mpz_clear(val); return expr_error(ctx->msgs, i, "List member must be a constant " "value"); - if (i->dtype->basetype->type != TYPE_BITMASK) + } + if (datatype_basetype(i->dtype)->type != TYPE_BITMASK) { + mpz_clear(val); return expr_error(ctx->msgs, i, "Basetype of type %s is not bitmask", i->dtype->desc); + } mpz_ior(val, val, i->value); } @@ -1313,27 +1777,68 @@ static int expr_evaluate_list(struct eval_ctx *ctx, struct expr **expr) return 0; } -static int expr_evaluate_set_elem(struct eval_ctx *ctx, struct expr **expr) +static int __expr_evaluate_set_elem(struct eval_ctx *ctx, struct expr *elem) { + int num_elem_exprs = 0, num_set_exprs = 0; struct set *set = ctx->set; - struct expr *elem = *expr; + struct stmt *stmt; + + list_for_each_entry(stmt, &elem->stmt_list, list) + num_elem_exprs++; + list_for_each_entry(stmt, &set->stmt_list, list) + num_set_exprs++; + + if (num_elem_exprs > 0) { + struct stmt *set_stmt, *elem_stmt; + + if (num_set_exprs > 0 && num_elem_exprs != num_set_exprs) { + return expr_error(ctx->msgs, elem, + "number of statements mismatch, set expects %d " + "but element has %d", num_set_exprs, + num_elem_exprs); + } else if (num_set_exprs == 0) { + if (!(set->flags & NFT_SET_ANONYMOUS) && + !(set->flags & NFT_SET_EVAL)) { + elem_stmt = list_first_entry(&elem->stmt_list, struct stmt, list); + return stmt_error(ctx, elem_stmt, + "missing statement in %s declaration", + set_is_map(set->flags) ? "map" : "set"); + } + return 0; + } - if (elem->stmt) { - if (set->stmt && set->stmt->ops != elem->stmt->ops) { - return stmt_error(ctx, elem->stmt, - "statement mismatch, element expects %s, " - "but %s has type %s", - elem->stmt->ops->name, - set_is_map(set->flags) ? "map" : "set", - set->stmt->ops->name); - } else if (!set->stmt && !(set->flags & NFT_SET_EVAL)) { - return stmt_error(ctx, elem->stmt, - "missing %s statement in %s definition", - elem->stmt->ops->name, - set_is_map(set->flags) ? "map" : "set"); + set_stmt = list_first_entry(&set->stmt_list, struct stmt, list); + + list_for_each_entry(elem_stmt, &elem->stmt_list, list) { + if (set_stmt->ops != elem_stmt->ops) { + return stmt_error(ctx, elem_stmt, + "statement mismatch, element expects %s, " + "but %s has type %s", + elem_stmt->ops->name, + set_is_map(set->flags) ? "map" : "set", + set_stmt->ops->name); + } + set_stmt = list_next_entry(set_stmt, list); } } + return 0; +} + +static int expr_evaluate_set_elem(struct eval_ctx *ctx, struct expr **expr) +{ + struct expr *elem = *expr; + const struct expr *key; + + if (ctx->set) { + if (__expr_evaluate_set_elem(ctx, elem) < 0) + return -1; + + key = ctx->set->key; + __expr_set_context(&ctx->ectx, key->dtype, key->byteorder, key->len, 0); + ctx->ectx.key = key; + } + if (expr_evaluate(ctx, &elem->key) < 0) return -1; @@ -1342,9 +1847,19 @@ static int expr_evaluate_set_elem(struct eval_ctx *ctx, struct expr **expr) switch (elem->key->etype) { case EXPR_PREFIX: case EXPR_RANGE: - return expr_error(ctx->msgs, elem, - "You must add 'flags interval' to your %s declaration if you want to add %s elements", - set_is_map(ctx->set->flags) ? "map" : "set", expr_name(elem->key)); + key = elem->key; + goto err_missing_flag; + case EXPR_CONCAT: + list_for_each_entry(key, &elem->key->expressions, list) { + switch (key->etype) { + case EXPR_PREFIX: + case EXPR_RANGE: + goto err_missing_flag; + default: + break; + } + } + break; default: break; } @@ -1353,30 +1868,119 @@ static int expr_evaluate_set_elem(struct eval_ctx *ctx, struct expr **expr) datatype_set(elem, elem->key->dtype); elem->len = elem->key->len; elem->flags = elem->key->flags; + + return 0; + +err_missing_flag: + return expr_error(ctx->msgs, key, + "You must add 'flags interval' to your %s declaration if you want to add %s elements", + set_is_map(ctx->set->flags) ? "map" : "set", expr_name(key)); +} + +static int expr_evaluate_set_elem_catchall(struct eval_ctx *ctx, struct expr **expr) +{ + struct expr *elem = *expr; + + if (ctx->set) + elem->len = ctx->set->key->len; + return 0; } +static const struct expr *expr_set_elem(const struct expr *expr) +{ + if (expr->etype == EXPR_MAPPING) + return expr->left; + + return expr; +} + +static int interval_set_eval(struct eval_ctx *ctx, struct set *set, + struct expr *init) +{ + int ret; + + if (!init) + return 0; + + ret = 0; + switch (ctx->cmd->op) { + case CMD_CREATE: + case CMD_ADD: + case CMD_REPLACE: + case CMD_INSERT: + if (set->automerge) { + ret = set_automerge(ctx->msgs, ctx->cmd, set, init, + ctx->nft->debug_mask); + } else { + ret = set_overlap(ctx->msgs, set, init); + } + break; + case CMD_DELETE: + case CMD_DESTROY: + ret = set_delete(ctx->msgs, ctx->cmd, set, init, + ctx->nft->debug_mask); + break; + case CMD_GET: + break; + default: + BUG("unhandled op %d\n", ctx->cmd->op); + break; + } + + return ret; +} + +static void expr_evaluate_set_ref(struct eval_ctx *ctx, struct expr *expr) +{ + struct set *set = expr->set; + + expr_set_context(&ctx->ectx, set->key->dtype, set->key->len); + ctx->ectx.key = set->key; +} + static int expr_evaluate_set(struct eval_ctx *ctx, struct expr **expr) { struct expr *set = *expr, *i, *next; + const struct expr *elem; list_for_each_entry_safe(i, next, &set->expressions, list) { if (list_member_evaluate(ctx, &i) < 0) return -1; - if (i->etype == EXPR_SET_ELEM && - i->key->etype == EXPR_SET_REF) + if (i->etype == EXPR_MAPPING && + i->left->etype == EXPR_SET_ELEM && + i->left->key->etype == EXPR_SET) { + struct expr *new, *j; + + list_for_each_entry(j, &i->left->key->expressions, list) { + new = mapping_expr_alloc(&i->location, + expr_get(j), + expr_get(i->right)); + list_add_tail(&new->list, &set->expressions); + set->size++; + } + list_del(&i->list); + expr_free(i); + continue; + } + + elem = expr_set_elem(i); + + if (elem->etype == EXPR_SET_ELEM && + elem->key->etype == EXPR_SET_REF) return expr_error(ctx->msgs, i, "Set reference cannot be part of another set"); - if (i->etype == EXPR_SET_ELEM && - i->key->etype == EXPR_SET) { - struct expr *new = expr_clone(i->key); + if (elem->etype == EXPR_SET_ELEM && + elem->key->etype == EXPR_SET) { + struct expr *new = expr_get(elem->key); - set->set_flags |= i->key->set_flags; + set->set_flags |= elem->key->set_flags; list_replace(&i->list, &new->list); expr_free(i); i = new; + elem = expr_set_elem(i); } if (!expr_is_constant(i)) @@ -1392,30 +1996,114 @@ static int expr_evaluate_set(struct eval_ctx *ctx, struct expr **expr) expr_free(i); } else if (!expr_is_singleton(i)) { set->set_flags |= NFT_SET_INTERVAL; - if (i->key->etype == EXPR_CONCAT) + if (elem->key->etype == EXPR_CONCAT) set->set_flags |= NFT_SET_CONCAT; } } - if (ctx->set && (ctx->set->flags & NFT_SET_CONCAT)) - set->set_flags |= NFT_SET_CONCAT; + if (ctx->set) { + if (ctx->set->flags & NFT_SET_CONCAT) + set->set_flags |= NFT_SET_CONCAT; + } set->set_flags |= NFT_SET_CONSTANT; datatype_set(set, ctx->ectx.dtype); set->len = ctx->ectx.len; set->flags |= EXPR_F_CONSTANT; + return 0; } static int binop_transfer(struct eval_ctx *ctx, struct expr **expr); + +static void map_set_concat_info(struct expr *map) +{ + map->mappings->set->flags |= map->mappings->set->init->set_flags; + + if (map->mappings->set->flags & NFT_SET_INTERVAL && + map->map->etype == EXPR_CONCAT) { + memcpy(&map->mappings->set->desc.field_len, &map->map->field_len, + sizeof(map->mappings->set->desc.field_len)); + map->mappings->set->desc.field_count = map->map->field_count; + map->mappings->flags |= NFT_SET_CONCAT; + } +} + +static void __mapping_expr_expand(struct expr *i) +{ + struct expr *j, *range, *next; + + assert(i->etype == EXPR_MAPPING); + switch (i->right->etype) { + case EXPR_VALUE: + range = range_expr_alloc(&i->location, expr_get(i->right), expr_get(i->right)); + expr_free(i->right); + i->right = range; + break; + case EXPR_CONCAT: + list_for_each_entry_safe(j, next, &i->right->expressions, list) { + if (j->etype != EXPR_VALUE) + continue; + + range = range_expr_alloc(&j->location, expr_get(j), expr_get(j)); + list_replace(&j->list, &range->list); + expr_free(j); + } + i->right->flags &= ~EXPR_F_SINGLETON; + break; + default: + break; + } +} + +static int mapping_expr_expand(struct eval_ctx *ctx) +{ + struct expr *i; + + if (!set_is_anonymous(ctx->set->flags)) + return 0; + + list_for_each_entry(i, &ctx->set->init->expressions, list) { + if (i->etype != EXPR_MAPPING) + return expr_error(ctx->msgs, i, + "expected mapping, not %s", expr_name(i)); + __mapping_expr_expand(i); + } + + return 0; +} + +static bool datatype_compatible(const struct datatype *a, const struct datatype *b) +{ + return (a->type == TYPE_MARK && + datatype_equal(datatype_basetype(a), datatype_basetype(b))) || + datatype_equal(a, b); +} + static int expr_evaluate_map(struct eval_ctx *ctx, struct expr **expr) { - struct expr_ctx ectx = ctx->ectx; struct expr *map = *expr, *mappings; - const struct datatype *dtype; + struct expr_ctx ectx = ctx->ectx; struct expr *key, *data; + if (map->map->etype == EXPR_CT && + (map->map->ct.key == NFT_CT_SRC || + map->map->ct.key == NFT_CT_DST)) + return expr_error(ctx->msgs, map->map, + "specify either ip or ip6 for address matching"); + else if (map->map->etype == EXPR_CONCAT) { + struct expr *i; + + list_for_each_entry(i, &map->map->expressions, list) { + if (i->etype == EXPR_CT && + (i->ct.key == NFT_CT_SRC || + i->ct.key == NFT_CT_DST)) + return expr_error(ctx->msgs, i, + "specify either ip or ip6 for address matching"); + } + } + expr_set_context(&ctx->ectx, NULL, 0); if (expr_evaluate(ctx, &map->map) < 0) return -1; @@ -1423,23 +2111,45 @@ static int expr_evaluate_map(struct eval_ctx *ctx, struct expr **expr) return expr_error(ctx->msgs, map->map, "Map expression can not be constant"); + ctx->stmt_len = 0; mappings = map->mappings; mappings->set_flags |= NFT_SET_MAP; switch (map->mappings->etype) { + case EXPR_VARIABLE: case EXPR_SET: - key = constant_expr_alloc(&map->location, - ctx->ectx.dtype, - ctx->ectx.byteorder, - ctx->ectx.len, NULL); + if (ctx->ectx.key && ctx->ectx.key->etype == EXPR_CONCAT) { + key = expr_clone(ctx->ectx.key); + } else { + key = constant_expr_alloc(&map->location, + ctx->ectx.dtype, + ctx->ectx.byteorder, + ctx->ectx.len, NULL); + } - dtype = set_datatype_alloc(ectx.dtype, ectx.byteorder); - data = constant_expr_alloc(&netlink_location, dtype, - dtype->byteorder, ectx.len, NULL); + if (!ectx.dtype) { + expr_free(key); + return expr_error(ctx->msgs, map, + "Implicit map expression without known datatype"); + } + + if (ectx.dtype->type == TYPE_VERDICT) { + data = verdict_expr_alloc(&netlink_location, 0, NULL); + } else { + const struct datatype *dtype; + + dtype = set_datatype_alloc(ectx.dtype, ectx.byteorder); + data = constant_expr_alloc(&netlink_location, dtype, + dtype->byteorder, ectx.len, NULL); + datatype_free(dtype); + } mappings = implicit_set_declaration(ctx, "__map%d", key, data, - mappings); + mappings, + NFT_SET_ANONYMOUS); + if (!mappings) + return -1; if (ectx.len && mappings->set->data->len != ectx.len) BUG("%d vs %d\n", mappings->set->data->len, ectx.len); @@ -1449,17 +2159,33 @@ static int expr_evaluate_map(struct eval_ctx *ctx, struct expr **expr) ctx->set = mappings->set; if (expr_evaluate(ctx, &map->mappings->set->init) < 0) return -1; + + if (map->mappings->set->init->etype != EXPR_SET) { + return expr_error(ctx->msgs, map->mappings->set->init, + "Expression is not a map"); + } + + if (set_is_interval(map->mappings->set->init->set_flags) && + !(map->mappings->set->init->set_flags & NFT_SET_CONCAT) && + interval_set_eval(ctx, ctx->set, map->mappings->set->init) < 0) + return -1; + expr_set_context(&ctx->ectx, ctx->set->key->dtype, ctx->set->key->len); if (binop_transfer(ctx, expr) < 0) return -1; - if (ctx->set->data->flags & EXPR_F_INTERVAL) + if (ctx->set->data->flags & EXPR_F_INTERVAL) { ctx->set->data->len *= 2; + if (mapping_expr_expand(ctx)) + return -1; + } + ctx->set->key->len = ctx->ectx.len; ctx->set = NULL; map = *expr; - map->mappings->set->flags |= map->mappings->set->init->set_flags; + + map_set_concat_info(map); break; case EXPR_SYMBOL: if (expr_evaluate(ctx, &map->mappings) < 0) @@ -1469,12 +2195,18 @@ static int expr_evaluate_map(struct eval_ctx *ctx, struct expr **expr) return expr_error(ctx->msgs, map->mappings, "Expression is not a map"); break; + case EXPR_SET_REF: + /* symbol has been already evaluated to set reference */ + if (!set_is_map(mappings->set->flags)) + return expr_error(ctx->msgs, map->mappings, + "Expression is not a map"); + break; default: - BUG("invalid mapping expression %s\n", - expr_name(map->mappings)); + return expr_binary_error(ctx->msgs, map->mappings, map->map, + "invalid mapping expression %s", expr_name(map->mappings)); } - if (!datatype_equal(map->map->dtype, map->mappings->set->key->dtype)) + if (!datatype_compatible(map->mappings->set->key->dtype, map->map->dtype)) return expr_binary_error(ctx->msgs, map->mappings, map->map, "datatype mismatch, map expects %s, " "mapping expression has type %s", @@ -1492,6 +2224,26 @@ static int expr_evaluate_map(struct eval_ctx *ctx, struct expr **expr) return 0; } +static bool data_mapping_has_interval(struct expr *data) +{ + struct expr *i; + + if (data->etype == EXPR_RANGE || + data->etype == EXPR_PREFIX) + return true; + + if (data->etype != EXPR_CONCAT) + return false; + + list_for_each_entry(i, &data->expressions, list) { + if (i->etype == EXPR_RANGE || + i->etype == EXPR_PREFIX) + return true; + } + + return false; +} + static int expr_evaluate_mapping(struct eval_ctx *ctx, struct expr **expr) { struct expr *mapping = *expr; @@ -1512,17 +2264,14 @@ static int expr_evaluate_mapping(struct eval_ctx *ctx, struct expr **expr) "Key must be a constant"); mapping->flags |= mapping->left->flags & EXPR_F_SINGLETON; - if (set->data) { - if (!set_is_anonymous(set->flags) && - set->data->flags & EXPR_F_INTERVAL) - datalen = set->data->len / 2; - else - datalen = set->data->len; - - expr_set_context(&ctx->ectx, set->data->dtype, datalen); - } else { - assert((set->flags & NFT_SET_MAP) == 0); - } + assert(set->data != NULL); + if (!set_is_anonymous(set->flags) && + set->data->flags & EXPR_F_INTERVAL) + datalen = set->data->len / 2; + else + datalen = set->data->len; + __expr_set_context(&ctx->ectx, set->data->dtype, + set->data->byteorder, datalen, 0); if (expr_evaluate(ctx, &mapping->right) < 0) return -1; @@ -1531,15 +2280,23 @@ static int expr_evaluate_mapping(struct eval_ctx *ctx, struct expr **expr) "Value must be a constant"); if (set_is_anonymous(set->flags) && - (mapping->right->etype == EXPR_RANGE || - mapping->right->etype == EXPR_PREFIX)) + data_mapping_has_interval(mapping->right)) set->data->flags |= EXPR_F_INTERVAL; + if (!set_is_anonymous(set->flags) && + set->data->flags & EXPR_F_INTERVAL) + __mapping_expr_expand(mapping); + if (!(set->data->flags & EXPR_F_INTERVAL) && !expr_is_singleton(mapping->right)) return expr_error(ctx->msgs, mapping->right, "Value must be a singleton"); + if (set_is_objmap(set->flags) && mapping->right->etype != EXPR_VALUE) + return expr_error(ctx->msgs, mapping->right, + "Object mapping data should be a value, not %s", + expr_name(mapping->right)); + mapping->flags |= EXPR_F_CONSTANT; return 0; } @@ -1565,17 +2322,20 @@ static void expr_dtype_integer_compatible(struct eval_ctx *ctx, static int expr_evaluate_numgen(struct eval_ctx *ctx, struct expr **exprp) { struct expr *expr = *exprp; + unsigned int maxval; expr_dtype_integer_compatible(ctx, expr); + maxval = expr->numgen.mod + expr->numgen.offset - 1; __expr_set_context(&ctx->ectx, expr->dtype, expr->byteorder, expr->len, - expr->numgen.mod - 1); + maxval); return 0; } static int expr_evaluate_hash(struct eval_ctx *ctx, struct expr **exprp) { struct expr *expr = *exprp; + unsigned int maxval; expr_dtype_integer_compatible(ctx, expr); @@ -1588,8 +2348,9 @@ static int expr_evaluate_hash(struct eval_ctx *ctx, struct expr **exprp) * expression to be hashed. Since this input is transformed to a 4 bytes * integer, restore context to the datatype that results from hashing. */ + maxval = expr->hash.mod + expr->hash.offset - 1; __expr_set_context(&ctx->ectx, expr->dtype, expr->byteorder, expr->len, - expr->hash.mod - 1); + maxval); return 0; } @@ -1658,8 +2419,6 @@ static int binop_transfer_one(struct eval_ctx *ctx, return 0; } - expr_get(*right); - switch (left->op) { case OP_LSHIFT: (*right) = binop_expr_alloc(&(*right)->location, OP_RSHIFT, @@ -1772,7 +2531,7 @@ static int binop_transfer(struct eval_ctx *ctx, struct expr **expr) return 0; } -static bool lhs_is_meta_hour(const struct expr *meta) +bool lhs_is_meta_hour(const struct expr *meta) { if (meta->etype != EXPR_META) return false; @@ -1781,7 +2540,7 @@ static bool lhs_is_meta_hour(const struct expr *meta) meta->meta.key == NFT_META_TIME_DAY; } -static void swap_values(struct expr *range) +void range_expr_swap_values(struct expr *range) { struct expr *left_tmp; @@ -1798,16 +2557,54 @@ static bool range_needs_swap(const struct expr *range) return mpz_cmp(left->value, right->value) > 0; } +static void optimize_singleton_set(struct expr *rel, struct expr **expr) +{ + struct expr *set = rel->right, *i; + + i = list_first_entry(&set->expressions, struct expr, list); + if (i->etype == EXPR_SET_ELEM && + list_empty(&i->stmt_list)) { + + switch (i->key->etype) { + case EXPR_PREFIX: + case EXPR_RANGE: + case EXPR_VALUE: + rel->right = *expr = i->key; + i->key = NULL; + expr_free(set); + break; + default: + break; + } + } + + if (rel->op == OP_IMPLICIT && + rel->right->dtype->basetype && + rel->right->dtype->basetype->type == TYPE_BITMASK && + rel->right->dtype->type != TYPE_CT_STATE) { + rel->op = OP_EQ; + } +} + static int expr_evaluate_relational(struct eval_ctx *ctx, struct expr **expr) { struct expr *rel = *expr, *left, *right; + struct proto_ctx *pctx; struct expr *range; int ret; + right = rel->right; + if (right->etype == EXPR_SYMBOL && + right->symtype == SYMBOL_SET && + expr_evaluate(ctx, &rel->right) < 0) + return -1; + if (expr_evaluate(ctx, &rel->left) < 0) return -1; left = rel->left; + pctx = eval_proto_ctx(ctx); + if (rel->right->etype == EXPR_RANGE && lhs_is_meta_hour(rel->left)) { ret = __expr_evaluate_range(ctx, &rel->right); if (ret) @@ -1825,10 +2622,10 @@ static int expr_evaluate_relational(struct eval_ctx *ctx, struct expr **expr) "Inverting range values for cross-day hour matching\n\n"); if (rel->op == OP_EQ || rel->op == OP_IMPLICIT) { - swap_values(range); + range_expr_swap_values(range); rel->op = OP_NEQ; } else if (rel->op == OP_NEQ) { - swap_values(range); + range_expr_swap_values(range); rel->op = OP_EQ; } } @@ -1868,6 +2665,19 @@ static int expr_evaluate_relational(struct eval_ctx *ctx, struct expr **expr) return expr_binary_error(ctx->msgs, right, left, "Cannot be used with right hand side constant value"); + if (left->etype != EXPR_CONCAT) { + switch (rel->op) { + case OP_EQ: + case OP_IMPLICIT: + case OP_NEQ: + if (right->etype == EXPR_SET && right->size == 1) + optimize_singleton_set(rel, &right); + break; + default: + break; + } + } + switch (rel->op) { case OP_EQ: case OP_IMPLICIT: @@ -1875,11 +2685,22 @@ static int expr_evaluate_relational(struct eval_ctx *ctx, struct expr **expr) * Update protocol context for payload and meta iiftype * equality expressions. */ - if (expr_is_singleton(right)) - relational_expr_pctx_update(&ctx->pctx, rel); + relational_expr_pctx_update(pctx, rel); /* fall through */ case OP_NEQ: + case OP_NEG: + if (rel->op == OP_NEG) { + if (left->etype == EXPR_BINOP) + return expr_binary_error(ctx->msgs, left, right, + "cannot combine negation with binary expression"); + if (right->etype != EXPR_VALUE || + right->dtype->basetype == NULL || + right->dtype->basetype->type != TYPE_BITMASK) + return expr_binary_error(ctx->msgs, left, right, + "negation can only be used with singleton bitmask values. Did you mean \"!=\"?"); + } + switch (right->etype) { case EXPR_RANGE: if (byteorder_conversion(ctx, &rel->left, BYTEORDER_BIG_ENDIAN) < 0) @@ -1904,7 +2725,11 @@ static int expr_evaluate_relational(struct eval_ctx *ctx, struct expr **expr) right = rel->right = implicit_set_declaration(ctx, "__set%d", expr_get(left), NULL, - right); + right, + NFT_SET_ANONYMOUS); + if (!right) + return -1; + /* fall through */ case EXPR_SET_REF: if (rel->left->etype == EXPR_CT && @@ -1976,11 +2801,12 @@ static int expr_evaluate_fib(struct eval_ctx *ctx, struct expr **exprp) static int expr_evaluate_meta(struct eval_ctx *ctx, struct expr **exprp) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); struct expr *meta = *exprp; switch (meta->meta.key) { case NFT_META_NFPROTO: - if (ctx->pctx.family != NFPROTO_INET && + if (pctx->family != NFPROTO_INET && meta->flags & EXPR_F_PROTOCOL) return expr_error(ctx->msgs, meta, "meta nfproto is only useful in the inet family"); @@ -1998,9 +2824,11 @@ static int expr_evaluate_meta(struct eval_ctx *ctx, struct expr **exprp) static int expr_evaluate_socket(struct eval_ctx *ctx, struct expr **expr) { + enum nft_socket_keys key = (*expr)->socket.key; int maxval = 0; - if((*expr)->socket.key == NFT_SOCKET_TRANSPARENT) + if (key == NFT_SOCKET_TRANSPARENT || + key == NFT_SOCKET_WILDCARD) maxval = 1; __expr_set_context(&ctx->ectx, (*expr)->dtype, (*expr)->byteorder, (*expr)->len, maxval); @@ -2021,7 +2849,16 @@ static int expr_evaluate_osf(struct eval_ctx *ctx, struct expr **expr) static int expr_evaluate_variable(struct eval_ctx *ctx, struct expr **exprp) { - struct expr *new = expr_clone((*exprp)->sym->expr); + struct symbol *sym = (*exprp)->sym; + struct expr *new; + + /* If variable is reused from different locations in the ruleset, then + * clone expression. + */ + if (sym->refcnt > 2) + new = expr_clone(sym->expr); + else + new = expr_get(sym->expr); if (expr_evaluate(ctx, &new) < 0) { expr_free(new); @@ -2036,9 +2873,10 @@ static int expr_evaluate_variable(struct eval_ctx *ctx, struct expr **exprp) static int expr_evaluate_xfrm(struct eval_ctx *ctx, struct expr **exprp) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); struct expr *expr = *exprp; - switch (ctx->pctx.family) { + switch (pctx->family) { case NFPROTO_IPV4: case NFPROTO_IPV6: case NFPROTO_INET: @@ -2051,6 +2889,54 @@ static int expr_evaluate_xfrm(struct eval_ctx *ctx, struct expr **exprp) return expr_evaluate_primary(ctx, exprp); } +static int expr_evaluate_flagcmp(struct eval_ctx *ctx, struct expr **exprp) +{ + struct expr *expr = *exprp, *binop, *rel; + + if (expr->op != OP_EQ && + expr->op != OP_NEQ) + return expr_error(ctx->msgs, expr, "either == or != is allowed"); + + binop = binop_expr_alloc(&expr->location, OP_AND, + expr_get(expr->flagcmp.expr), + expr_get(expr->flagcmp.mask)); + rel = relational_expr_alloc(&expr->location, expr->op, binop, + expr_get(expr->flagcmp.value)); + expr_free(expr); + *exprp = rel; + + return expr_evaluate(ctx, exprp); +} + +static int verdict_validate_chainlen(struct eval_ctx *ctx, + struct expr *chain) +{ + if (chain->len > NFT_CHAIN_MAXNAMELEN * BITS_PER_BYTE) + return expr_error(ctx->msgs, chain, + "chain name too long (%u, max %u)", + chain->len / BITS_PER_BYTE, + NFT_CHAIN_MAXNAMELEN); + + return 0; +} + +static int expr_evaluate_verdict(struct eval_ctx *ctx, struct expr **exprp) +{ + struct expr *expr = *exprp; + + switch (expr->verdict) { + case NFT_GOTO: + case NFT_JUMP: + if (expr->chain->etype == EXPR_VALUE && + verdict_validate_chainlen(ctx, expr->chain)) + return -1; + + break; + } + + return expr_evaluate_primary(ctx, exprp); +} + static int expr_evaluate(struct eval_ctx *ctx, struct expr **expr) { if (ctx->nft->debug_mask & NFT_DEBUG_EVALUATION) { @@ -2069,13 +2955,14 @@ static int expr_evaluate(struct eval_ctx *ctx, struct expr **expr) case EXPR_VARIABLE: return expr_evaluate_variable(ctx, expr); case EXPR_SET_REF: + expr_evaluate_set_ref(ctx, *expr); return 0; case EXPR_VALUE: return expr_evaluate_value(ctx, expr); case EXPR_EXTHDR: return expr_evaluate_exthdr(ctx, expr); case EXPR_VERDICT: - return expr_evaluate_primary(ctx, expr); + return expr_evaluate_verdict(ctx, expr); case EXPR_META: return expr_evaluate_meta(ctx, expr); case EXPR_SOCKET: @@ -2085,7 +2972,7 @@ static int expr_evaluate(struct eval_ctx *ctx, struct expr **expr) case EXPR_FIB: return expr_evaluate_fib(ctx, expr); case EXPR_PAYLOAD: - return expr_evaluate_payload(ctx, expr); + return expr_evaluate_payload_inner(ctx, expr); case EXPR_RT: return expr_evaluate_rt(ctx, expr); case EXPR_CT: @@ -2099,7 +2986,7 @@ static int expr_evaluate(struct eval_ctx *ctx, struct expr **expr) case EXPR_BINOP: return expr_evaluate_binop(ctx, expr); case EXPR_CONCAT: - return expr_evaluate_concat(ctx, expr, true); + return expr_evaluate_concat(ctx, expr); case EXPR_LIST: return expr_evaluate_list(ctx, expr); case EXPR_SET: @@ -2118,6 +3005,10 @@ static int expr_evaluate(struct eval_ctx *ctx, struct expr **expr) return expr_evaluate_hash(ctx, expr); case EXPR_XFRM: return expr_evaluate_xfrm(ctx, expr); + case EXPR_SET_ELEM_CATCHALL: + return expr_evaluate_set_elem_catchall(ctx, expr); + case EXPR_FLAGCMP: + return expr_evaluate_flagcmp(ctx, expr); default: BUG("unknown expression type %s\n", expr_name(*expr)); } @@ -2188,13 +3079,18 @@ static int __stmt_evaluate_arg(struct eval_ctx *ctx, struct stmt *stmt, "expression has type %s with length %d", dtype->desc, (*expr)->dtype->desc, (*expr)->len); - else if ((*expr)->dtype->type != TYPE_INTEGER && - !datatype_equal((*expr)->dtype, dtype)) + + if (!datatype_compatible(dtype, (*expr)->dtype)) return stmt_binary_error(ctx, *expr, stmt, /* verdict vs invalid? */ "datatype mismatch: expected %s, " "expression has type %s", dtype->desc, (*expr)->dtype->desc); + if (dtype->type == TYPE_MARK && + datatype_equal(datatype_basetype(dtype), datatype_basetype((*expr)->dtype)) && + !expr_is_constant(*expr)) + return byteorder_conversion(ctx, expr, byteorder); + /* we are setting a value, we can't use a set */ switch ((*expr)->etype) { case EXPR_SET: @@ -2209,6 +3105,10 @@ static int __stmt_evaluate_arg(struct eval_ctx *ctx, struct stmt *stmt, return byteorder_conversion(ctx, expr, byteorder); case EXPR_PREFIX: return stmt_prefix_conversion(ctx, expr, byteorder); + case EXPR_NUMGEN: + if (dtype->type == TYPE_IPADDR) + return byteorder_conversion(ctx, expr, byteorder); + break; default: break; } @@ -2227,6 +3127,25 @@ static int stmt_evaluate_arg(struct eval_ctx *ctx, struct stmt *stmt, return __stmt_evaluate_arg(ctx, stmt, dtype, len, byteorder, expr); } +/* like stmt_evaluate_arg, but keep existing context created + * by previous expr_evaluate(). + * + * This is needed for add/update statements: + * ctx->ectx.key has the set key, which may be needed for 'typeof' + * sets: the 'add/update' expression might contain integer data types. + * + * Without the key we cannot derive the element size. + */ +static int stmt_evaluate_key(struct eval_ctx *ctx, struct stmt *stmt, + const struct datatype *dtype, unsigned int len, + enum byteorder byteorder, struct expr **expr) +{ + if (expr_evaluate(ctx, expr) < 0) + return -1; + + return __stmt_evaluate_arg(ctx, stmt, dtype, len, byteorder, expr); +} + static int stmt_evaluate_verdict(struct eval_ctx *ctx, struct stmt *stmt) { if (stmt_evaluate_arg(ctx, stmt, &verdict_type, 0, 0, &stmt->expr) < 0) @@ -2237,12 +3156,16 @@ static int stmt_evaluate_verdict(struct eval_ctx *ctx, struct stmt *stmt) if (stmt->expr->verdict != NFT_CONTINUE) stmt->flags |= STMT_F_TERMINAL; if (stmt->expr->chain != NULL) { - if (expr_evaluate(ctx, &stmt->expr->chain) < 0) + if (stmt_evaluate_arg(ctx, stmt, &string_type, 0, 0, + &stmt->expr->chain) < 0) return -1; if (stmt->expr->chain->etype != EXPR_VALUE) { return expr_error(ctx->msgs, stmt->expr->chain, "not a value expression"); } + + if (verdict_validate_chainlen(ctx, stmt->expr->chain)) + return -1; } break; case EXPR_MAP: @@ -2257,6 +3180,9 @@ static bool stmt_evaluate_payload_need_csum(const struct expr *payload) { const struct proto_desc *desc; + if (payload->payload.base == PROTO_BASE_INNER_HDR) + return true; + desc = payload->payload.desc; return desc && desc->checksum_key; @@ -2265,14 +3191,22 @@ static bool stmt_evaluate_payload_need_csum(const struct expr *payload) static int stmt_evaluate_exthdr(struct eval_ctx *ctx, struct stmt *stmt) { struct expr *exthdr; + int ret; if (__expr_evaluate_exthdr(ctx, &stmt->exthdr.expr) < 0) return -1; exthdr = stmt->exthdr.expr; - return stmt_evaluate_arg(ctx, stmt, exthdr->dtype, exthdr->len, - BYTEORDER_BIG_ENDIAN, - &stmt->exthdr.val); + ret = stmt_evaluate_arg(ctx, stmt, exthdr->dtype, exthdr->len, + BYTEORDER_BIG_ENDIAN, + &stmt->exthdr.val); + if (ret < 0) + return ret; + + if (stmt->exthdr.val->etype == EXPR_RANGE) + return stmt_error_range(ctx, stmt, stmt->exthdr.val); + + return 0; } static int stmt_evaluate_payload(struct eval_ctx *ctx, struct stmt *stmt) @@ -2285,6 +3219,11 @@ static int stmt_evaluate_payload(struct eval_ctx *ctx, struct stmt *stmt) mpz_t bitmask, ff; bool need_csum; + if (stmt->payload.expr->payload.inner_desc) { + return expr_error(ctx->msgs, stmt->payload.expr, + "payload statement for this expression is not supported"); + } + if (__expr_evaluate_payload(ctx, stmt->payload.expr) < 0) return -1; @@ -2298,6 +3237,9 @@ static int stmt_evaluate_payload(struct eval_ctx *ctx, struct stmt *stmt) payload->byteorder) < 0) return -1; + if (stmt->payload.val->etype == EXPR_RANGE) + return stmt_error_range(ctx, stmt, stmt->payload.val); + need_csum = stmt_evaluate_payload_need_csum(payload); if (!payload_needs_adjustment(payload)) { @@ -2317,6 +3259,11 @@ static int stmt_evaluate_payload(struct eval_ctx *ctx, struct stmt *stmt) payload_byte_size = div_round_up(payload->len + extra_len, BITS_PER_BYTE); + if (payload_byte_size > sizeof(data)) + return expr_error(ctx->msgs, stmt->payload.expr, + "uneven load cannot span more than %u bytes, got %u", + sizeof(data), payload_byte_size); + if (need_csum && payload_byte_size & 1) { payload_byte_size++; @@ -2355,9 +3302,9 @@ static int stmt_evaluate_payload(struct eval_ctx *ctx, struct stmt *stmt) mpz_clear(ff); assert(sizeof(data) * BITS_PER_BYTE >= masklen); - mpz_export_data(data, bitmask, BYTEORDER_HOST_ENDIAN, sizeof(data)); + mpz_export_data(data, bitmask, payload->byteorder, payload_byte_size); mask = constant_expr_alloc(&payload->location, expr_basetype(payload), - BYTEORDER_HOST_ENDIAN, masklen, data); + payload->byteorder, masklen, data); mpz_clear(bitmask); payload_bytes = payload_expr_alloc(&payload->location, NULL, 0); @@ -2365,6 +3312,7 @@ static int stmt_evaluate_payload(struct eval_ctx *ctx, struct stmt *stmt) payload_byte_offset * BITS_PER_BYTE, payload_byte_size * BITS_PER_BYTE); + payload_bytes->payload.is_raw = 1; payload_bytes->payload.desc = payload->payload.desc; payload_bytes->byteorder = payload->byteorder; @@ -2391,6 +3339,24 @@ static int stmt_evaluate_payload(struct eval_ctx *ctx, struct stmt *stmt) static int stmt_evaluate_meter(struct eval_ctx *ctx, struct stmt *stmt) { struct expr *key, *set, *setref; + struct set *existing_set; + struct table *table; + + table = table_cache_find(&ctx->nft->cache.table_cache, + ctx->cmd->handle.table.name, + ctx->cmd->handle.family); + if (table == NULL) + return table_not_found(ctx); + + existing_set = set_cache_find(table, stmt->meter.name); + if (existing_set) + return cmd_error(ctx, &stmt->location, + "%s; meter '%s' overlaps an existing %s '%s' in family %s", + strerror(EEXIST), + stmt->meter.name, + set_is_map(existing_set->flags) ? "map" : "set", + existing_set->handle.set.name, + family2str(existing_set->handle.family)); expr_set_context(&ctx->ectx, NULL, 0); if (expr_evaluate(ctx, &stmt->meter.key) < 0) @@ -2410,7 +3376,9 @@ static int stmt_evaluate_meter(struct eval_ctx *ctx, struct stmt *stmt) set->set_flags |= NFT_SET_TIMEOUT; setref = implicit_set_declaration(ctx, stmt->meter.name, - expr_get(key), NULL, set); + expr_get(key), NULL, set, 0); + if (!setref) + return -1; setref->set->desc.size = stmt->meter.size; stmt->meter.set = setref; @@ -2426,26 +3394,45 @@ static int stmt_evaluate_meter(struct eval_ctx *ctx, struct stmt *stmt) static int stmt_evaluate_meta(struct eval_ctx *ctx, struct stmt *stmt) { - return stmt_evaluate_arg(ctx, stmt, - stmt->meta.tmpl->dtype, - stmt->meta.tmpl->len, - stmt->meta.tmpl->byteorder, - &stmt->meta.expr); + int ret; + + ctx->stmt_len = stmt->meta.tmpl->len; + + ret = stmt_evaluate_arg(ctx, stmt, + stmt->meta.tmpl->dtype, + stmt->meta.tmpl->len, + stmt->meta.tmpl->byteorder, + &stmt->meta.expr); + if (ret < 0) + return ret; + + if (stmt->meta.expr->etype == EXPR_RANGE) + return stmt_error_range(ctx, stmt, stmt->meta.expr); + + return ret; } static int stmt_evaluate_ct(struct eval_ctx *ctx, struct stmt *stmt) { - if (stmt_evaluate_arg(ctx, stmt, - stmt->ct.tmpl->dtype, - stmt->ct.tmpl->len, - stmt->ct.tmpl->byteorder, - &stmt->ct.expr) < 0) + int ret; + + ctx->stmt_len = stmt->ct.tmpl->len; + + ret = stmt_evaluate_arg(ctx, stmt, + stmt->ct.tmpl->dtype, + stmt->ct.tmpl->len, + stmt->ct.tmpl->byteorder, + &stmt->ct.expr); + if (ret < 0) return -1; if (stmt->ct.key == NFT_CT_SECMARK && expr_is_constant(stmt->ct.expr)) return stmt_error(ctx, stmt, "ct secmark must not be set to constant value"); + if (stmt->ct.expr->etype == EXPR_RANGE) + return stmt_error_range(ctx, stmt, stmt->ct.expr); + return 0; } @@ -2453,9 +3440,10 @@ static int reject_payload_gen_dependency_tcp(struct eval_ctx *ctx, struct stmt *stmt, struct expr **payload) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct proto_desc *desc; - desc = ctx->pctx.protocol[PROTO_BASE_TRANSPORT_HDR].desc; + desc = pctx->protocol[PROTO_BASE_TRANSPORT_HDR].desc; if (desc != NULL) return 0; *payload = payload_expr_alloc(&stmt->location, &proto_tcp, @@ -2467,9 +3455,10 @@ static int reject_payload_gen_dependency_family(struct eval_ctx *ctx, struct stmt *stmt, struct expr **payload) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct proto_desc *base; - base = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; + base = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc; if (base != NULL) return 0; @@ -2528,7 +3517,7 @@ static int stmt_reject_gen_dependency(struct eval_ctx *ctx, struct stmt *stmt, */ list_add(&nstmt->list, &ctx->rule->stmts); out: - xfree(payload); + free(payload); return ret; } @@ -2536,6 +3525,7 @@ static int stmt_evaluate_reject_inet_family(struct eval_ctx *ctx, struct stmt *stmt, const struct proto_desc *desc) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct proto_desc *base; int protocol; @@ -2545,23 +3535,26 @@ static int stmt_evaluate_reject_inet_family(struct eval_ctx *ctx, case NFT_REJECT_ICMPX_UNREACH: break; case NFT_REJECT_ICMP_UNREACH: - base = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc; + base = pctx->protocol[PROTO_BASE_LL_HDR].desc; protocol = proto_find_num(base, desc); switch (protocol) { case NFPROTO_IPV4: + case __constant_htons(ETH_P_IP): if (stmt->reject.family == NFPROTO_IPV4) break; return stmt_binary_error(ctx, stmt->reject.expr, - &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR], + &pctx->protocol[PROTO_BASE_NETWORK_HDR], "conflicting protocols specified: ip vs ip6"); case NFPROTO_IPV6: + case __constant_htons(ETH_P_IPV6): if (stmt->reject.family == NFPROTO_IPV6) break; return stmt_binary_error(ctx, stmt->reject.expr, - &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR], + &pctx->protocol[PROTO_BASE_NETWORK_HDR], "conflicting protocols specified: ip vs ip6"); default: - BUG("unsupported family"); + return stmt_error(ctx, stmt, + "cannot infer ICMP reject variant to use: explicit value required.\n"); } break; } @@ -2572,9 +3565,10 @@ static int stmt_evaluate_reject_inet_family(struct eval_ctx *ctx, static int stmt_evaluate_reject_inet(struct eval_ctx *ctx, struct stmt *stmt, struct expr *expr) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct proto_desc *desc; - desc = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; + desc = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc; if (desc != NULL && stmt_evaluate_reject_inet_family(ctx, stmt, desc) < 0) return -1; @@ -2589,13 +3583,14 @@ static int stmt_evaluate_reject_bridge_family(struct eval_ctx *ctx, struct stmt *stmt, const struct proto_desc *desc) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct proto_desc *base; int protocol; switch (stmt->reject.type) { case NFT_REJECT_ICMPX_UNREACH: case NFT_REJECT_TCP_RST: - base = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc; + base = pctx->protocol[PROTO_BASE_LL_HDR].desc; protocol = proto_find_num(base, desc); switch (protocol) { case __constant_htons(ETH_P_IP): @@ -2603,29 +3598,29 @@ static int stmt_evaluate_reject_bridge_family(struct eval_ctx *ctx, break; default: return stmt_binary_error(ctx, stmt, - &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR], + &pctx->protocol[PROTO_BASE_NETWORK_HDR], "cannot reject this network family"); } break; case NFT_REJECT_ICMP_UNREACH: - base = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc; + base = pctx->protocol[PROTO_BASE_LL_HDR].desc; protocol = proto_find_num(base, desc); switch (protocol) { case __constant_htons(ETH_P_IP): if (NFPROTO_IPV4 == stmt->reject.family) break; return stmt_binary_error(ctx, stmt->reject.expr, - &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR], + &pctx->protocol[PROTO_BASE_NETWORK_HDR], "conflicting protocols specified: ip vs ip6"); case __constant_htons(ETH_P_IPV6): if (NFPROTO_IPV6 == stmt->reject.family) break; return stmt_binary_error(ctx, stmt->reject.expr, - &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR], + &pctx->protocol[PROTO_BASE_NETWORK_HDR], "conflicting protocols specified: ip vs ip6"); default: return stmt_binary_error(ctx, stmt, - &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR], + &pctx->protocol[PROTO_BASE_NETWORK_HDR], "cannot reject this network family"); } break; @@ -2637,15 +3632,15 @@ static int stmt_evaluate_reject_bridge_family(struct eval_ctx *ctx, static int stmt_evaluate_reject_bridge(struct eval_ctx *ctx, struct stmt *stmt, struct expr *expr) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct proto_desc *desc; - desc = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc; - if (desc != &proto_eth && desc != &proto_vlan) - return stmt_binary_error(ctx, - &ctx->pctx.protocol[PROTO_BASE_LL_HDR], - stmt, "unsupported link layer protocol"); + desc = pctx->protocol[PROTO_BASE_LL_HDR].desc; + if (desc != &proto_eth && desc != &proto_vlan && desc != &proto_netdev) + return __stmt_binary_error(ctx, &stmt->location, NULL, + "cannot reject from this link layer protocol"); - desc = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; + desc = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc; if (desc != NULL && stmt_evaluate_reject_bridge_family(ctx, stmt, desc) < 0) return -1; @@ -2659,7 +3654,9 @@ static int stmt_evaluate_reject_bridge(struct eval_ctx *ctx, struct stmt *stmt, static int stmt_evaluate_reject_family(struct eval_ctx *ctx, struct stmt *stmt, struct expr *expr) { - switch (ctx->pctx.family) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); + + switch (pctx->family) { case NFPROTO_ARP: return stmt_error(ctx, stmt, "cannot use reject with arp"); case NFPROTO_IPV4: @@ -2673,13 +3670,14 @@ static int stmt_evaluate_reject_family(struct eval_ctx *ctx, struct stmt *stmt, return stmt_binary_error(ctx, stmt->reject.expr, stmt, "abstracted ICMP unreachable not supported"); case NFT_REJECT_ICMP_UNREACH: - if (stmt->reject.family == ctx->pctx.family) + if (stmt->reject.family == pctx->family) break; return stmt_binary_error(ctx, stmt->reject.expr, stmt, "conflicting protocols specified: ip vs ip6"); } break; case NFPROTO_BRIDGE: + case NFPROTO_NETDEV: if (stmt_evaluate_reject_bridge(ctx, stmt, expr) < 0) return -1; break; @@ -2696,49 +3694,53 @@ static int stmt_evaluate_reject_family(struct eval_ctx *ctx, struct stmt *stmt, static int stmt_evaluate_reject_default(struct eval_ctx *ctx, struct stmt *stmt) { - int protocol; + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct proto_desc *desc, *base; + int protocol; - switch (ctx->pctx.family) { + switch (pctx->family) { case NFPROTO_IPV4: case NFPROTO_IPV6: stmt->reject.type = NFT_REJECT_ICMP_UNREACH; - stmt->reject.family = ctx->pctx.family; - if (ctx->pctx.family == NFPROTO_IPV4) + stmt->reject.family = pctx->family; + if (pctx->family == NFPROTO_IPV4) stmt->reject.icmp_code = ICMP_PORT_UNREACH; else stmt->reject.icmp_code = ICMP6_DST_UNREACH_NOPORT; break; case NFPROTO_INET: - desc = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; + desc = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc; if (desc == NULL) { stmt->reject.type = NFT_REJECT_ICMPX_UNREACH; stmt->reject.icmp_code = NFT_REJECT_ICMPX_PORT_UNREACH; break; } stmt->reject.type = NFT_REJECT_ICMP_UNREACH; - base = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc; + base = pctx->protocol[PROTO_BASE_LL_HDR].desc; protocol = proto_find_num(base, desc); switch (protocol) { case NFPROTO_IPV4: + case __constant_htons(ETH_P_IP): stmt->reject.family = NFPROTO_IPV4; stmt->reject.icmp_code = ICMP_PORT_UNREACH; break; case NFPROTO_IPV6: + case __constant_htons(ETH_P_IPV6): stmt->reject.family = NFPROTO_IPV6; stmt->reject.icmp_code = ICMP6_DST_UNREACH_NOPORT; break; } break; case NFPROTO_BRIDGE: - desc = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; + case NFPROTO_NETDEV: + desc = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc; if (desc == NULL) { stmt->reject.type = NFT_REJECT_ICMPX_UNREACH; stmt->reject.icmp_code = NFT_REJECT_ICMPX_PORT_UNREACH; break; } stmt->reject.type = NFT_REJECT_ICMP_UNREACH; - base = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc; + base = pctx->protocol[PROTO_BASE_LL_HDR].desc; protocol = proto_find_num(base, desc); switch (protocol) { case __constant_htons(ETH_P_IP): @@ -2757,7 +3759,10 @@ static int stmt_evaluate_reject_default(struct eval_ctx *ctx, static int stmt_evaluate_reject_icmp(struct eval_ctx *ctx, struct stmt *stmt) { - struct parse_ctx parse_ctx = { .tbl = &ctx->nft->output.tbl, }; + struct parse_ctx parse_ctx = { + .tbl = &ctx->nft->output.tbl, + .input = &ctx->nft->input, + }; struct error_record *erec; struct expr *code; @@ -2766,6 +3771,13 @@ static int stmt_evaluate_reject_icmp(struct eval_ctx *ctx, struct stmt *stmt) erec_queue(erec, ctx->msgs); return -1; } + + if (mpz_cmp_ui(code->value, UINT8_MAX) > 0) { + expr_free(code); + return expr_error(ctx->msgs, stmt->reject.expr, + "reject code must be integer in range 0-255"); + } + stmt->reject.icmp_code = mpz_get_uint8(code->value); expr_free(code); @@ -2774,9 +3786,9 @@ static int stmt_evaluate_reject_icmp(struct eval_ctx *ctx, struct stmt *stmt) static int stmt_evaluate_reset(struct eval_ctx *ctx, struct stmt *stmt) { - int protonum; + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct proto_desc *desc, *base; - struct proto_ctx *pctx = &ctx->pctx; + int protonum; desc = pctx->protocol[PROTO_BASE_TRANSPORT_HDR].desc; if (desc == NULL) @@ -2793,7 +3805,7 @@ static int stmt_evaluate_reset(struct eval_ctx *ctx, struct stmt *stmt) default: if (stmt->reject.type == NFT_REJECT_TCP_RST) { return stmt_binary_error(ctx, stmt, - &ctx->pctx.protocol[PROTO_BASE_TRANSPORT_HDR], + &pctx->protocol[PROTO_BASE_TRANSPORT_HDR], "you cannot use tcp reset with this protocol"); } break; @@ -2821,22 +3833,24 @@ static int stmt_evaluate_reject(struct eval_ctx *ctx, struct stmt *stmt) static int nat_evaluate_family(struct eval_ctx *ctx, struct stmt *stmt) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct proto_desc *nproto; - switch (ctx->pctx.family) { + switch (pctx->family) { case NFPROTO_IPV4: case NFPROTO_IPV6: if (stmt->nat.family == NFPROTO_UNSPEC) - stmt->nat.family = ctx->pctx.family; + stmt->nat.family = pctx->family; return 0; case NFPROTO_INET: - if (!stmt->nat.addr) + if (!stmt->nat.addr) { + stmt->nat.family = NFPROTO_INET; return 0; - + } if (stmt->nat.family != NFPROTO_UNSPEC) return 0; - nproto = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; + nproto = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc; if (nproto == &proto_ip) stmt->nat.family = NFPROTO_IPV4; @@ -2865,7 +3879,7 @@ static const struct datatype *get_addr_dtype(uint8_t family) static int evaluate_addr(struct eval_ctx *ctx, struct stmt *stmt, struct expr **expr) { - struct proto_ctx *pctx = &ctx->pctx; + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct datatype *dtype; dtype = get_addr_dtype(pctx->family); @@ -2875,51 +3889,133 @@ static int evaluate_addr(struct eval_ctx *ctx, struct stmt *stmt, expr); } +static bool nat_evaluate_addr_has_th_expr(const struct expr *map) +{ + const struct expr *i, *concat; + + if (!map || map->etype != EXPR_MAP) + return false; + + concat = map->map; + if (concat ->etype != EXPR_CONCAT) + return false; + + list_for_each_entry(i, &concat->expressions, list) { + enum proto_bases base; + + if (i->etype == EXPR_PAYLOAD && + i->payload.base == PROTO_BASE_TRANSPORT_HDR && + i->payload.desc != &proto_th) + return true; + + if ((i->flags & EXPR_F_PROTOCOL) == 0) + continue; + + switch (i->etype) { + case EXPR_META: + base = i->meta.base; + break; + case EXPR_PAYLOAD: + base = i->payload.base; + break; + default: + return false; + } + + if (base == PROTO_BASE_NETWORK_HDR) + return true; + } + + return false; +} + static int nat_evaluate_transport(struct eval_ctx *ctx, struct stmt *stmt, struct expr **expr) { - struct proto_ctx *pctx = &ctx->pctx; + struct proto_ctx *pctx = eval_proto_ctx(ctx); + int err; - if (pctx->protocol[PROTO_BASE_TRANSPORT_HDR].desc == NULL) + err = stmt_evaluate_arg(ctx, stmt, + &inet_service_type, 2 * BITS_PER_BYTE, + BYTEORDER_BIG_ENDIAN, expr); + if (err < 0) + return err; + + if (pctx->protocol[PROTO_BASE_TRANSPORT_HDR].desc == NULL && + !nat_evaluate_addr_has_th_expr(stmt->nat.addr)) return stmt_binary_error(ctx, *expr, stmt, "transport protocol mapping is only " "valid after transport protocol match"); - return stmt_evaluate_arg(ctx, stmt, - &inet_service_type, 2 * BITS_PER_BYTE, - BYTEORDER_BIG_ENDIAN, expr); + return 0; } static int stmt_evaluate_l3proto(struct eval_ctx *ctx, struct stmt *stmt, uint8_t family) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct proto_desc *nproto; - nproto = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; + nproto = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc; if ((nproto == &proto_ip && family != NFPROTO_IPV4) || (nproto == &proto_ip6 && family != NFPROTO_IPV6)) return stmt_binary_error(ctx, stmt, - &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR], + &pctx->protocol[PROTO_BASE_NETWORK_HDR], "conflicting protocols specified: %s vs. %s. You must specify ip or ip6 family in %s statement", - ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc->name, + pctx->protocol[PROTO_BASE_NETWORK_HDR].desc->name, family2str(family), stmt->ops->name); return 0; } +static void expr_family_infer(struct proto_ctx *pctx, const struct expr *expr, + uint8_t *family) +{ + struct expr *i; + + if (expr->etype == EXPR_MAP) { + switch (expr->map->etype) { + case EXPR_CONCAT: + list_for_each_entry(i, &expr->map->expressions, list) { + if (i->etype == EXPR_PAYLOAD) { + if (i->payload.desc == &proto_ip) + *family = NFPROTO_IPV4; + else if (i->payload.desc == &proto_ip6) + *family = NFPROTO_IPV6; + } + } + break; + case EXPR_PAYLOAD: + if (expr->map->payload.desc == &proto_ip) + *family = NFPROTO_IPV4; + else if (expr->map->payload.desc == &proto_ip6) + *family = NFPROTO_IPV6; + break; + default: + break; + } + } +} + static int stmt_evaluate_addr(struct eval_ctx *ctx, struct stmt *stmt, - uint8_t family, - struct expr **addr) + uint8_t *family, struct expr **addr) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct datatype *dtype; int err; - if (ctx->pctx.family == NFPROTO_INET) { - dtype = get_addr_dtype(family); - if (dtype->size == 0) + if (pctx->family == NFPROTO_INET) { + if (*family == NFPROTO_INET || + *family == NFPROTO_UNSPEC) + expr_family_infer(pctx, *addr, family); + + dtype = get_addr_dtype(*family); + if (dtype->size == 0) { return stmt_error(ctx, stmt, - "ip or ip6 must be specified with address for inet tables."); + "specify `%s ip' or '%s ip6' in %s table to disambiguate", + stmt_name(stmt), stmt_name(stmt), family2str(pctx->family)); + } err = stmt_evaluate_arg(ctx, stmt, dtype, dtype->size, BYTEORDER_BIG_ENDIAN, addr); @@ -2932,9 +4028,15 @@ static int stmt_evaluate_addr(struct eval_ctx *ctx, struct stmt *stmt, static int stmt_evaluate_nat_map(struct eval_ctx *ctx, struct stmt *stmt) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); struct expr *one, *two, *data, *tmp; - const struct datatype *dtype; - int addr_type, err; + const struct datatype *dtype = NULL; + const struct datatype *dtype2; + int addr_type; + int err; + + if (stmt->nat.family == NFPROTO_INET) + expr_family_infer(pctx, stmt->nat.addr, &stmt->nat.family); switch (stmt->nat.family) { case NFPROTO_IPV4: @@ -2944,53 +4046,117 @@ static int stmt_evaluate_nat_map(struct eval_ctx *ctx, struct stmt *stmt) addr_type = TYPE_IP6ADDR; break; default: - return -1; + return stmt_error(ctx, stmt, + "specify `%s ip' or '%s ip6' in %s table to disambiguate", + stmt_name(stmt), stmt_name(stmt), family2str(pctx->family)); } dtype = concat_type_alloc((addr_type << TYPE_BITS) | TYPE_INET_SERVICE); expr_set_context(&ctx->ectx, dtype, dtype->size); - if (expr_evaluate(ctx, &stmt->nat.addr)) - return -1; + if (expr_evaluate(ctx, &stmt->nat.addr)) { + err = -1; + goto out; + } - if (stmt->nat.addr->etype != EXPR_MAP) - return 0; + if (pctx->protocol[PROTO_BASE_TRANSPORT_HDR].desc == NULL && + !nat_evaluate_addr_has_th_expr(stmt->nat.addr)) { + err = stmt_binary_error(ctx, stmt->nat.addr, stmt, + "transport protocol mapping is only " + "valid after transport protocol match"); + goto out; + } + + if (stmt->nat.addr->etype != EXPR_MAP) { + err = 0; + goto out; + } data = stmt->nat.addr->mappings->set->data; + if (data->flags & EXPR_F_INTERVAL) + stmt->nat.type_flags |= STMT_NAT_F_INTERVAL; + datatype_set(data, dtype); - if (expr_ops(data)->type != EXPR_CONCAT) - return __stmt_evaluate_arg(ctx, stmt, dtype, dtype->size, + if (expr_ops(data)->type != EXPR_CONCAT) { + err = __stmt_evaluate_arg(ctx, stmt, dtype, dtype->size, BYTEORDER_BIG_ENDIAN, &stmt->nat.addr); + goto out; + } one = list_first_entry(&data->expressions, struct expr, list); two = list_entry(one->list.next, struct expr, list); - if (one == two || !list_is_last(&two->list, &data->expressions)) - return __stmt_evaluate_arg(ctx, stmt, dtype, dtype->size, + if (one == two || !list_is_last(&two->list, &data->expressions)) { + err = __stmt_evaluate_arg(ctx, stmt, dtype, dtype->size, BYTEORDER_BIG_ENDIAN, &stmt->nat.addr); + goto out; + } - dtype = get_addr_dtype(stmt->nat.family); + dtype2 = get_addr_dtype(stmt->nat.family); tmp = one; - err = __stmt_evaluate_arg(ctx, stmt, dtype, dtype->size, + err = __stmt_evaluate_arg(ctx, stmt, dtype2, dtype2->size, BYTEORDER_BIG_ENDIAN, &tmp); if (err < 0) - return err; + goto out; if (tmp != one) BUG("Internal error: Unexpected alteration of l3 expression"); tmp = two; err = nat_evaluate_transport(ctx, stmt, &tmp); if (err < 0) - return err; + goto out; if (tmp != two) BUG("Internal error: Unexpected alteration of l4 expression"); +out: + datatype_free(dtype); return err; } +static bool nat_concat_map(struct eval_ctx *ctx, struct stmt *stmt) +{ + struct expr *i; + + if (stmt->nat.addr->etype != EXPR_MAP) + return false; + + switch (stmt->nat.addr->mappings->etype) { + case EXPR_SET: + list_for_each_entry(i, &stmt->nat.addr->mappings->expressions, list) { + if (i->etype == EXPR_MAPPING && + i->right->etype == EXPR_CONCAT) { + stmt->nat.type_flags |= STMT_NAT_F_CONCAT; + return true; + } + } + break; + case EXPR_SYMBOL: + /* expr_evaluate_map() see EXPR_SET_REF after this evaluation. */ + if (expr_evaluate(ctx, &stmt->nat.addr->mappings)) + return false; + + if (!set_is_datamap(stmt->nat.addr->mappings->set->flags)) { + expr_error(ctx->msgs, stmt->nat.addr->mappings, + "Expression is not a map"); + return false; + } + + if (stmt->nat.addr->mappings->set->data->etype == EXPR_CONCAT || + stmt->nat.addr->mappings->set->data->dtype->subtypes) { + stmt->nat.type_flags |= STMT_NAT_F_CONCAT; + return true; + } + break; + default: + break; + } + + return false; +} + static int stmt_evaluate_nat(struct eval_ctx *ctx, struct stmt *stmt) { int err; @@ -3004,7 +4170,9 @@ static int stmt_evaluate_nat(struct eval_ctx *ctx, struct stmt *stmt) if (err < 0) return err; - if (stmt->nat.type_flags & STMT_NAT_F_CONCAT) { + if (nat_concat_map(ctx, stmt) || + stmt->nat.type_flags & STMT_NAT_F_CONCAT) { + err = stmt_evaluate_nat_map(ctx, stmt); if (err < 0) return err; @@ -3013,32 +4181,12 @@ static int stmt_evaluate_nat(struct eval_ctx *ctx, struct stmt *stmt) return 0; } - err = stmt_evaluate_addr(ctx, stmt, stmt->nat.family, + err = stmt_evaluate_addr(ctx, stmt, &stmt->nat.family, &stmt->nat.addr); if (err < 0) return err; } - if (stmt->nat.type_flags & STMT_NAT_F_INTERVAL) { - switch (stmt->nat.addr->etype) { - case EXPR_MAP: - if (!(stmt->nat.addr->mappings->set->data->flags & EXPR_F_INTERVAL)) - return expr_error(ctx->msgs, stmt->nat.addr, - "map is not defined as interval"); - break; - case EXPR_RANGE: - case EXPR_PREFIX: - break; - default: - return expr_error(ctx->msgs, stmt->nat.addr, - "neither prefix, range nor map expression"); - } - - stmt->flags |= STMT_F_TERMINAL; - - return 0; - } - if (stmt->nat.proto != NULL) { err = nat_evaluate_transport(ctx, stmt, &stmt->nat.proto); if (err < 0) @@ -3053,13 +4201,14 @@ static int stmt_evaluate_nat(struct eval_ctx *ctx, struct stmt *stmt) static int stmt_evaluate_tproxy(struct eval_ctx *ctx, struct stmt *stmt) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); int err; - switch (ctx->pctx.family) { + switch (pctx->family) { case NFPROTO_IPV4: case NFPROTO_IPV6: /* fallthrough */ if (stmt->tproxy.family == NFPROTO_UNSPEC) - stmt->tproxy.family = ctx->pctx.family; + stmt->tproxy.family = pctx->family; break; case NFPROTO_INET: break; @@ -3068,7 +4217,7 @@ static int stmt_evaluate_tproxy(struct eval_ctx *ctx, struct stmt *stmt) "tproxy is only supported for IPv4/IPv6/INET"); } - if (ctx->pctx.protocol[PROTO_BASE_TRANSPORT_HDR].desc == NULL) + if (pctx->protocol[PROTO_BASE_TRANSPORT_HDR].desc == NULL) return stmt_error(ctx, stmt, "Transparent proxy support requires" " transport protocol match"); @@ -3080,22 +4229,22 @@ static int stmt_evaluate_tproxy(struct eval_ctx *ctx, struct stmt *stmt) return err; if (stmt->tproxy.addr != NULL) { - if (stmt->tproxy.addr->etype == EXPR_RANGE) - return stmt_error(ctx, stmt, "Address ranges are not supported for tproxy."); - - err = stmt_evaluate_addr(ctx, stmt, stmt->tproxy.family, + err = stmt_evaluate_addr(ctx, stmt, &stmt->tproxy.family, &stmt->tproxy.addr); - if (err < 0) return err; + + if (stmt->tproxy.addr->etype == EXPR_RANGE) + return stmt_error(ctx, stmt, "Address ranges are not supported for tproxy."); } if (stmt->tproxy.port != NULL) { - if (stmt->tproxy.port->etype == EXPR_RANGE) - return stmt_error(ctx, stmt, "Port ranges are not supported for tproxy."); err = nat_evaluate_transport(ctx, stmt, &stmt->tproxy.port); if (err < 0) return err; + + if (stmt->tproxy.port->etype == EXPR_RANGE) + return stmt_error(ctx, stmt, "Port ranges are not supported for tproxy."); } return 0; @@ -3132,7 +4281,7 @@ static int stmt_evaluate_chain(struct eval_ctx *ctx, struct stmt *stmt) memset(&h, 0, sizeof(h)); handle_merge(&h, &chain->handle); h.family = ctx->rule->handle.family; - xfree(h.table.name); + free_const(h.table.name); h.table.name = xstrdup(ctx->rule->handle.table.name); h.chain.location = stmt->location; h.chain_id = chain->handle.chain_id; @@ -3147,13 +4296,14 @@ static int stmt_evaluate_chain(struct eval_ctx *ctx, struct stmt *stmt) struct eval_ctx rule_ctx = { .nft = ctx->nft, .msgs = ctx->msgs, + .cmd = ctx->cmd, }; struct handle h2 = {}; handle_merge(&rule->handle, &ctx->rule->handle); - xfree(rule->handle.table.name); + free_const(rule->handle.table.name); rule->handle.table.name = xstrdup(ctx->rule->handle.table.name); - xfree(rule->handle.chain.name); + free_const(rule->handle.chain.name); rule->handle.chain.name = NULL; rule->handle.chain_id = chain->handle.chain_id; if (rule_evaluate(&rule_ctx, rule, CMD_INVALID) < 0) @@ -3170,11 +4320,17 @@ static int stmt_evaluate_chain(struct eval_ctx *ctx, struct stmt *stmt) return 0; } +static int stmt_evaluate_optstrip(struct eval_ctx *ctx, struct stmt *stmt) +{ + return expr_evaluate(ctx, &stmt->optstrip.expr); +} + static int stmt_evaluate_dup(struct eval_ctx *ctx, struct stmt *stmt) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); int err; - switch (ctx->pctx.family) { + switch (pctx->family) { case NFPROTO_IPV4: case NFPROTO_IPV6: if (stmt->dup.to == NULL) @@ -3191,6 +4347,9 @@ static int stmt_evaluate_dup(struct eval_ctx *ctx, struct stmt *stmt) &stmt->dup.dev); if (err < 0) return err; + + if (stmt->dup.dev->etype == EXPR_RANGE) + return stmt_error_range(ctx, stmt, stmt->dup.dev); } break; case NFPROTO_NETDEV: @@ -3209,15 +4368,20 @@ static int stmt_evaluate_dup(struct eval_ctx *ctx, struct stmt *stmt) default: return stmt_error(ctx, stmt, "unsupported family"); } + + if (stmt->dup.to->etype == EXPR_RANGE) + return stmt_error_range(ctx, stmt, stmt->dup.to); + return 0; } static int stmt_evaluate_fwd(struct eval_ctx *ctx, struct stmt *stmt) { + struct proto_ctx *pctx = eval_proto_ctx(ctx); const struct datatype *dtype; int err, len; - switch (ctx->pctx.family) { + switch (pctx->family) { case NFPROTO_NETDEV: if (stmt->fwd.dev == NULL) return stmt_error(ctx, stmt, @@ -3229,6 +4393,9 @@ static int stmt_evaluate_fwd(struct eval_ctx *ctx, struct stmt *stmt) if (err < 0) return err; + if (stmt->fwd.dev->etype == EXPR_RANGE) + return stmt_error_range(ctx, stmt, stmt->fwd.dev); + if (stmt->fwd.addr != NULL) { switch (stmt->fwd.family) { case NFPROTO_IPV4: @@ -3247,6 +4414,9 @@ static int stmt_evaluate_fwd(struct eval_ctx *ctx, struct stmt *stmt) &stmt->fwd.addr); if (err < 0) return err; + + if (stmt->fwd.addr->etype == EXPR_RANGE) + return stmt_error_range(ctx, stmt, stmt->fwd.addr); } break; default: @@ -3263,14 +4433,16 @@ static int stmt_evaluate_queue(struct eval_ctx *ctx, struct stmt *stmt) BYTEORDER_HOST_ENDIAN, &stmt->queue.queue) < 0) return -1; - if (!expr_is_constant(stmt->queue.queue)) - return expr_error(ctx->msgs, stmt->queue.queue, - "queue number is not constant"); - if (stmt->queue.queue->etype != EXPR_RANGE && - (stmt->queue.flags & NFT_QUEUE_FLAG_CPU_FANOUT)) + + if ((stmt->queue.flags & NFT_QUEUE_FLAG_CPU_FANOUT) && + stmt->queue.queue->etype != EXPR_RANGE) return expr_error(ctx->msgs, stmt->queue.queue, "fanout requires a range to be " "specified"); + + if (ctx->ectx.maxval > USHRT_MAX) + return expr_error(ctx->msgs, stmt->queue.queue, + "queue expression max value exceeds %u", USHRT_MAX); } stmt->flags |= STMT_F_TERMINAL; return 0; @@ -3278,39 +4450,12 @@ static int stmt_evaluate_queue(struct eval_ctx *ctx, struct stmt *stmt) static int stmt_evaluate_log_prefix(struct eval_ctx *ctx, struct stmt *stmt) { - char prefix[NF_LOG_PREFIXLEN] = {}, tmp[NF_LOG_PREFIXLEN] = {}; - int len = sizeof(prefix), offset = 0, ret; - struct expr *expr; - size_t size = 0; + unsigned int len = strlen(stmt->log.prefix); - if (stmt->log.prefix->etype != EXPR_LIST) - return 0; - - list_for_each_entry(expr, &stmt->log.prefix->expressions, list) { - switch (expr->etype) { - case EXPR_VALUE: - expr_to_string(expr, tmp); - ret = snprintf(prefix + offset, len, "%s", tmp); - break; - case EXPR_VARIABLE: - ret = snprintf(prefix + offset, len, "%s", - expr->sym->expr->identifier); - break; - default: - BUG("unknown expresion type %s\n", expr_name(expr)); - break; - } - SNPRINTF_BUFFER_SIZE(ret, size, len, offset); - } - - if (len == NF_LOG_PREFIXLEN) + if (len >= NF_LOG_PREFIXLEN) return stmt_error(ctx, stmt, "log prefix is too long"); - - expr = constant_expr_alloc(&stmt->log.prefix->location, &string_type, - BYTEORDER_HOST_ENDIAN, - strlen(prefix) * BITS_PER_BYTE, prefix); - expr_free(stmt->log.prefix); - stmt->log.prefix = expr; + else if (len == 0) + return stmt_error(ctx, stmt, "log prefix must have a minimum length of 1 character"); return 0; } @@ -3341,6 +4486,9 @@ static int stmt_evaluate_log(struct eval_ctx *ctx, struct stmt *stmt) static int stmt_evaluate_set(struct eval_ctx *ctx, struct stmt *stmt) { + struct set *this_set; + struct stmt *this; + expr_set_context(&ctx->ectx, NULL, 0); if (expr_evaluate(ctx, &stmt->set.set) < 0) return -1; @@ -3348,7 +4496,7 @@ static int stmt_evaluate_set(struct eval_ctx *ctx, struct stmt *stmt) return expr_error(ctx->msgs, stmt->set.set, "Expression does not refer to a set"); - if (stmt_evaluate_arg(ctx, stmt, + if (stmt_evaluate_key(ctx, stmt, stmt->set.set->set->key->dtype, stmt->set.set->set->key->len, stmt->set.set->set->key->byteorder, @@ -3360,19 +4508,30 @@ static int stmt_evaluate_set(struct eval_ctx *ctx, struct stmt *stmt) if (stmt->set.key->comment != NULL) return expr_error(ctx->msgs, stmt->set.key, "Key expression comments are not supported"); - if (stmt->set.stmt) { - if (stmt_evaluate(ctx, stmt->set.stmt) < 0) + list_for_each_entry(this, &stmt->set.stmt_list, list) { + if (stmt_evaluate(ctx, this) < 0) return -1; - if (!(stmt->set.stmt->flags & STMT_F_STATEFUL)) - return stmt_binary_error(ctx, stmt->set.stmt, stmt, - "meter statement must be stateful"); + if (!(this->flags & STMT_F_STATEFUL)) + return stmt_error(ctx, this, + "statement must be stateful"); } + this_set = stmt->set.set->set; + + /* Make sure EVAL flag is set on set definition so that kernel + * picks a set that allows updates from the packet path. + * + * Alternatively we could error out in case 'flags dynamic' was + * not given, but we can repair this here. + */ + this_set->flags |= NFT_SET_EVAL; return 0; } static int stmt_evaluate_map(struct eval_ctx *ctx, struct stmt *stmt) { + struct stmt *this; + expr_set_context(&ctx->ectx, NULL, 0); if (expr_evaluate(ctx, &stmt->map.set) < 0) return -1; @@ -3380,7 +4539,11 @@ static int stmt_evaluate_map(struct eval_ctx *ctx, struct stmt *stmt) return expr_error(ctx->msgs, stmt->map.set, "Expression does not refer to a set"); - if (stmt_evaluate_arg(ctx, stmt, + if (!set_is_map(stmt->map.set->set->flags)) + return expr_error(ctx->msgs, stmt->map.set, + "%s is not a map", stmt->map.set->set->handle.set.name); + + if (stmt_evaluate_key(ctx, stmt, stmt->map.set->set->key->dtype, stmt->map.set->set->key->len, stmt->map.set->set->key->byteorder, @@ -3405,13 +4568,16 @@ static int stmt_evaluate_map(struct eval_ctx *ctx, struct stmt *stmt) if (stmt->map.data->comment != NULL) return expr_error(ctx->msgs, stmt->map.data, "Data expression comments are not supported"); + if (stmt->map.data->timeout > 0) + return expr_error(ctx->msgs, stmt->map.data, + "Data expression timeouts are not supported"); - if (stmt->map.stmt) { - if (stmt_evaluate(ctx, stmt->map.stmt) < 0) + list_for_each_entry(this, &stmt->map.stmt_list, list) { + if (stmt_evaluate(ctx, this) < 0) return -1; - if (!(stmt->map.stmt->flags & STMT_F_STATEFUL)) - return stmt_binary_error(ctx, stmt->map.stmt, stmt, - "meter statement must be stateful"); + if (!(this->flags & STMT_F_STATEFUL)) + return stmt_error(ctx, this, + "statement must be stateful"); } return 0; @@ -3434,6 +4600,7 @@ static int stmt_evaluate_objref_map(struct eval_ctx *ctx, struct stmt *stmt) mappings->set_flags |= NFT_SET_OBJECT; switch (map->mappings->etype) { + case EXPR_VARIABLE: case EXPR_SET: key = constant_expr_alloc(&stmt->location, ctx->ectx.dtype, @@ -3441,7 +4608,10 @@ static int stmt_evaluate_objref_map(struct eval_ctx *ctx, struct stmt *stmt) ctx->ectx.len, NULL); mappings = implicit_set_declaration(ctx, "__objmap%d", - key, NULL, mappings); + key, NULL, mappings, + NFT_SET_ANONYMOUS); + if (!mappings) + return -1; mappings->set->objtype = stmt->objref.type; map->mappings = mappings; @@ -3449,10 +4619,20 @@ static int stmt_evaluate_objref_map(struct eval_ctx *ctx, struct stmt *stmt) ctx->set = mappings->set; if (expr_evaluate(ctx, &map->mappings->set->init) < 0) return -1; + + if (map->mappings->set->init->etype != EXPR_SET) { + return expr_error(ctx->msgs, map->mappings->set->init, + "Expression is not a map"); + } + + if (set_is_interval(map->mappings->set->init->set_flags) && + !(map->mappings->set->init->set_flags & NFT_SET_CONCAT) && + interval_set_eval(ctx, ctx->set, map->mappings->set->init) < 0) + return -1; + ctx->set = NULL; - map->mappings->set->flags |= - map->mappings->set->init->set_flags; + map_set_concat_info(map); /* fall through */ case EXPR_SYMBOL: if (expr_evaluate(ctx, &map->mappings) < 0) @@ -3465,11 +4645,12 @@ static int stmt_evaluate_objref_map(struct eval_ctx *ctx, struct stmt *stmt) "Expression is not a map with objects"); break; default: - BUG("invalid mapping expression %s\n", - expr_name(map->mappings)); + return expr_binary_error(ctx->msgs, map->mappings, map->map, + "invalid mapping expression %s", + expr_name(map->mappings)); } - if (!datatype_equal(map->map->dtype, map->mappings->set->key->dtype)) + if (!datatype_compatible(map->mappings->set->key->dtype, map->map->dtype)) return expr_binary_error(ctx->msgs, map->mappings, map->map, "datatype mismatch, map expects %s, " "mapping expression has type %s", @@ -3517,9 +4698,12 @@ int stmt_evaluate(struct eval_ctx *ctx, struct stmt *stmt) erec_destroy(erec); } + ctx->stmt_len = 0; + switch (stmt->ops->type) { case STMT_CONNLIMIT: case STMT_COUNTER: + case STMT_LAST: case STMT_LIMIT: case STMT_QUOTA: case STMT_NOTRACK: @@ -3563,6 +4747,8 @@ int stmt_evaluate(struct eval_ctx *ctx, struct stmt *stmt) return stmt_evaluate_synproxy(ctx, stmt); case STMT_CHAIN: return stmt_evaluate_chain(ctx, stmt); + case STMT_OPTSTRIP: + return stmt_evaluate_optstrip(ctx, stmt); default: BUG("unknown statement type %s\n", stmt->ops->name); } @@ -3573,22 +4759,37 @@ static int setelem_evaluate(struct eval_ctx *ctx, struct cmd *cmd) struct table *table; struct set *set; - table = table_lookup_global(ctx); + table = table_cache_find(&ctx->nft->cache.table_cache, + ctx->cmd->handle.table.name, + ctx->cmd->handle.family); if (table == NULL) return table_not_found(ctx); - set = set_lookup(table, ctx->cmd->handle.set.name); + set = set_cache_find(table, ctx->cmd->handle.set.name); if (set == NULL) return set_not_found(ctx, &ctx->cmd->handle.set.location, ctx->cmd->handle.set.name); + if (set->key == NULL) + return -1; + + set->existing_set = set; ctx->set = set; expr_set_context(&ctx->ectx, set->key->dtype, set->key->len); if (expr_evaluate(ctx, &cmd->expr) < 0) return -1; - ctx->set = NULL; cmd->elem.set = set_get(set); + if (set_is_interval(ctx->set->flags)) { + if (!(set->flags & NFT_SET_CONCAT) && + interval_set_eval(ctx, ctx->set, cmd->expr) < 0) + return -1; + + assert(cmd->expr->etype == EXPR_SET); + cmd->expr->set_flags |= NFT_SET_INTERVAL; + } + + ctx->set = NULL; return 0; } @@ -3607,27 +4808,137 @@ static int set_key_data_error(struct eval_ctx *ctx, const struct set *set, dtype->name, name, hint); } +static int set_expr_evaluate_concat(struct eval_ctx *ctx, struct expr **expr) +{ + unsigned int flags = EXPR_F_CONSTANT | EXPR_F_SINGLETON; + uint32_t ntype = 0, size = 0; + struct expr *i, *next; + + list_for_each_entry_safe(i, next, &(*expr)->expressions, list) { + unsigned dsize_bytes; + + if (i->etype == EXPR_CT && + (i->ct.key == NFT_CT_SRC || + i->ct.key == NFT_CT_DST)) + return expr_error(ctx->msgs, i, + "specify either ip or ip6 for address matching"); + + if (i->etype == EXPR_PAYLOAD && + i->dtype->type == TYPE_INTEGER) { + struct datatype *dtype; + + dtype = datatype_clone(i->dtype); + dtype->size = i->len; + dtype->byteorder = i->byteorder; + __datatype_set(i, dtype); + } + + if (i->dtype->size == 0 && i->len == 0) + return expr_binary_error(ctx->msgs, i, *expr, + "can not use variable sized " + "data types (%s) in concat " + "expressions", + i->dtype->name); + + flags &= i->flags; + + ntype = concat_subtype_add(ntype, i->dtype->type); + + dsize_bytes = div_round_up(i->len, BITS_PER_BYTE); + + if (i->dtype->size) + assert(dsize_bytes == div_round_up(i->dtype->size, BITS_PER_BYTE)); + + (*expr)->field_len[(*expr)->field_count++] = dsize_bytes; + size += netlink_padded_len(i->len); + + if (size > NFT_MAX_EXPR_LEN_BITS) + return expr_error(ctx->msgs, i, "Concatenation of size %u exceeds maximum size of %u", + size, NFT_MAX_EXPR_LEN_BITS); + } + + (*expr)->flags |= flags; + __datatype_set(*expr, concat_type_alloc(ntype)); + (*expr)->len = size; + + expr_set_context(&ctx->ectx, (*expr)->dtype, (*expr)->len); + ctx->ectx.key = *expr; + + return 0; +} + +static int elems_evaluate(struct eval_ctx *ctx, struct set *set) +{ + ctx->set = set; + if (set->init != NULL) { + if (set->key == NULL) + return set_error(ctx, set, "set definition does not specify key"); + + __expr_set_context(&ctx->ectx, set->key->dtype, + set->key->byteorder, set->key->len, 0); + if (expr_evaluate(ctx, &set->init) < 0) { + set->errors = true; + return -1; + } + if (set->init->etype != EXPR_SET) + return expr_error(ctx->msgs, set->init, "Set %s: Unexpected initial type %s, missing { }?", + set->handle.set.name, expr_name(set->init)); + } + + if (set_is_interval(ctx->set->flags) && + !(ctx->set->flags & NFT_SET_CONCAT) && + interval_set_eval(ctx, ctx->set, set->init) < 0) + return -1; + + ctx->set = NULL; + + return 0; +} + static int set_evaluate(struct eval_ctx *ctx, struct set *set) { + struct set *existing_set = NULL; + unsigned int num_stmts = 0; struct table *table; + struct stmt *stmt; const char *type; - table = table_lookup_global(ctx); - if (table == NULL) - return table_not_found(ctx); + type = set_is_map(set->flags) ? "map" : "set"; + + if (set->key == NULL) + return set_error(ctx, set, "%s definition does not specify key", + type); + + if (!set_is_anonymous(set->flags)) { + table = table_cache_find(&ctx->nft->cache.table_cache, + set->handle.table.name, + set->handle.family); + if (table == NULL) + return table_not_found(ctx); + + existing_set = set_cache_find(table, set->handle.set.name); + if (!existing_set) + set_cache_add(set_get(set), table); + + if (existing_set && existing_set->flags & NFT_SET_EVAL) { + uint32_t existing_flags = existing_set->flags & ~NFT_SET_EVAL; + uint32_t new_flags = set->flags & ~NFT_SET_EVAL; + + if (existing_flags == new_flags) + set->flags |= NFT_SET_EVAL; + } + } if (!(set->flags & NFT_SET_INTERVAL) && set->automerge) return set_error(ctx, set, "auto-merge only works with interval sets"); - type = set_is_map(set->flags) ? "map" : "set"; - if (set->key == NULL) return set_error(ctx, set, "%s definition does not specify key", type); if (set->key->len == 0) { if (set->key->etype == EXPR_CONCAT && - expr_evaluate_concat(ctx, &set->key, false) < 0) + set_expr_evaluate_concat(ctx, &set->key) < 0) return -1; if (set->key->len == 0) @@ -3642,18 +4953,33 @@ static int set_evaluate(struct eval_ctx *ctx, struct set *set) set->flags |= NFT_SET_CONCAT; } + if (set_is_anonymous(set->flags) && set->key->etype == EXPR_CONCAT) { + struct expr *i; + + list_for_each_entry(i, &set->init->expressions, list) { + if ((i->etype == EXPR_SET_ELEM && + i->key->etype != EXPR_CONCAT && + i->key->etype != EXPR_SET_ELEM_CATCHALL) || + (i->etype == EXPR_MAPPING && + i->left->etype == EXPR_SET_ELEM && + i->left->key->etype != EXPR_CONCAT && + i->left->key->etype != EXPR_SET_ELEM_CATCHALL)) + return expr_error(ctx->msgs, i, "expression is not a concatenation"); + } + } + if (set_is_datamap(set->flags)) { if (set->data == NULL) return set_error(ctx, set, "map definition does not " "specify mapping data type"); - if (set->data->flags & EXPR_F_INTERVAL) - set->data->len *= 2; - if (set->data->etype == EXPR_CONCAT && - expr_evaluate_concat(ctx, &set->data, false) < 0) + set_expr_evaluate_concat(ctx, &set->data) < 0) return -1; + if (set->data->flags & EXPR_F_INTERVAL) + set->data->len *= 2; + if (set->data->len == 0 && set->data->dtype->type != TYPE_VERDICT) return set_key_data_error(ctx, set, set->data->dtype, type); @@ -3670,20 +4996,22 @@ static int set_evaluate(struct eval_ctx *ctx, struct set *set) if (set->timeout) set->flags |= NFT_SET_TIMEOUT; - if (set_is_anonymous(set->flags)) - return 0; + list_for_each_entry(stmt, &set->stmt_list, list) + num_stmts++; - ctx->set = set; - if (set->init != NULL) { - __expr_set_context(&ctx->ectx, set->key->dtype, - set->key->byteorder, set->key->len, 0); - if (expr_evaluate(ctx, &set->init) < 0) + if (num_stmts > 1) + set->flags |= NFT_SET_EXPR; + + if (set_is_anonymous(set->flags)) { + if (set_is_interval(set->init->set_flags) && + !(set->init->set_flags & NFT_SET_CONCAT) && + interval_set_eval(ctx, set, set->init) < 0) return -1; + + return 0; } - ctx->set = NULL; - if (set_lookup(table, set->handle.set.name) == NULL) - set_add_hash(set_get(set), table); + set->existing_set = existing_set; return 0; } @@ -3698,8 +5026,8 @@ static bool evaluate_priority(struct eval_ctx *ctx, struct prio_spec *prio, int prio_snd; char op; - ctx->ectx.dtype = &priority_type; - ctx->ectx.len = NFT_NAME_MAXLEN * BITS_PER_BYTE; + expr_set_context(&ctx->ectx, &priority_type, NFT_NAME_MAXLEN * BITS_PER_BYTE); + if (expr_evaluate(ctx, &prio->expr) < 0) return false; if (prio->expr->etype != EXPR_VALUE) { @@ -3714,7 +5042,7 @@ static bool evaluate_priority(struct eval_ctx *ctx, struct prio_spec *prio, NFT_NAME_MAXLEN); loc = prio->expr->location; - if (sscanf(prio_str, "%s %c %d", prio_fst, &op, &prio_snd) < 3) { + if (sscanf(prio_str, "%255s %c %d", prio_fst, &op, &prio_snd) < 3) { priority = std_prio_lookup(prio_str, family, hook); if (priority == NF_IP_PRI_LAST) return false; @@ -3789,7 +5117,7 @@ static bool evaluate_device_expr(struct eval_ctx *ctx, struct expr **dev_expr) case EXPR_VALUE: break; default: - BUG("invalid expresion type %s\n", expr_name(expr)); + BUG("invalid expression type %s\n", expr_name(expr)); break; } @@ -3806,10 +5134,19 @@ static int flowtable_evaluate(struct eval_ctx *ctx, struct flowtable *ft) { struct table *table; - table = table_lookup_global(ctx); + table = table_cache_find(&ctx->nft->cache.table_cache, + ctx->cmd->handle.table.name, + ctx->cmd->handle.family); if (table == NULL) return table_not_found(ctx); + if (!ft_cache_find(table, ft->handle.flowtable.name)) { + if (!ft->hook.name && !ft->dev_expr) + return chain_error(ctx, ft, "missing hook and priority in flowtable declaration"); + + ft_cache_add(flowtable_get(ft), table); + } + if (ft->hook.name) { ft->hook.num = str2hooknum(NFPROTO_NETDEV, ft->hook.name); if (ft->hook.num == NF_INET_NUMHOOKS) @@ -3851,11 +5188,13 @@ static int rule_cache_update(struct eval_ctx *ctx, enum cmd_ops op) struct table *table; struct chain *chain; - table = table_lookup(&rule->handle, &ctx->nft->cache); + table = table_cache_find(&ctx->nft->cache.table_cache, + rule->handle.table.name, + rule->handle.family); if (!table) return table_not_found(ctx); - chain = chain_lookup(table, &rule->handle); + chain = chain_cache_find(table, rule->handle.chain.name); if (!chain) return chain_not_found(ctx); @@ -3916,7 +5255,9 @@ static int rule_evaluate(struct eval_ctx *ctx, struct rule *rule, struct stmt *stmt, *tstmt = NULL; struct error_record *erec; - proto_ctx_init(&ctx->pctx, rule->handle.family, ctx->nft->debug_mask); + proto_ctx_init(&ctx->_pctx[0], rule->handle.family, ctx->nft->debug_mask, false); + /* use NFPROTO_BRIDGE to set up proto_eth as base protocol. */ + proto_ctx_init(&ctx->_pctx[1], NFPROTO_BRIDGE, ctx->nft->debug_mask, true); memset(&ctx->ectx, 0, sizeof(ctx->ectx)); ctx->rule = rule; @@ -3931,6 +5272,8 @@ static int rule_evaluate(struct eval_ctx *ctx, struct rule *rule, return -1; if (stmt->flags & STMT_F_TERMINAL) tstmt = stmt; + + ctx->inner_desc = NULL; } erec = rule_postprocess(rule); @@ -3939,7 +5282,7 @@ static int rule_evaluate(struct eval_ctx *ctx, struct rule *rule, return -1; } - if (cache_needs_update(&ctx->nft->cache)) + if (nft_cache_needs_update(&ctx->nft->cache)) return rule_cache_update(ctx, op); return 0; @@ -3951,10 +5294,13 @@ static uint32_t str2hooknum(uint32_t family, const char *hook) return NF_INET_NUMHOOKS; switch (family) { + case NFPROTO_INET: + if (!strcmp(hook, "ingress")) + return NF_INET_INGRESS; + /* fall through */ case NFPROTO_IPV4: case NFPROTO_BRIDGE: case NFPROTO_IPV6: - case NFPROTO_INET: /* These families have overlapping values for each hook */ if (!strcmp(hook, "prerouting")) return NF_INET_PRE_ROUTING; @@ -3978,6 +5324,8 @@ static uint32_t str2hooknum(uint32_t family, const char *hook) case NFPROTO_NETDEV: if (!strcmp(hook, "ingress")) return NF_NETDEV_INGRESS; + else if (!strcmp(hook, "egress")) + return NF_NETDEV_EGRESS; break; default: break; @@ -3989,22 +5337,23 @@ static uint32_t str2hooknum(uint32_t family, const char *hook) static int chain_evaluate(struct eval_ctx *ctx, struct chain *chain) { struct table *table; - struct rule *rule; - table = table_lookup_global(ctx); + table = table_cache_find(&ctx->nft->cache.table_cache, + ctx->cmd->handle.table.name, + ctx->cmd->handle.family); if (table == NULL) return table_not_found(ctx); if (chain == NULL) { - if (chain_lookup(table, &ctx->cmd->handle) == NULL) { - chain = chain_alloc(NULL); + if (!chain_cache_find(table, ctx->cmd->handle.chain.name)) { + chain = chain_alloc(); handle_merge(&chain->handle, &ctx->cmd->handle); - chain_add_hash(chain, table); + chain_cache_add(chain, table); } return 0; } else if (!(chain->flags & CHAIN_F_BINDING)) { - if (chain_lookup(table, &chain->handle) == NULL) - chain_add_hash(chain_get(chain), table); + if (!chain_cache_find(table, chain->handle.chain.name)) + chain_cache_add(chain_get(chain), table); } if (chain->flags & CHAIN_F_BASECHAIN) { @@ -4027,13 +5376,17 @@ static int chain_evaluate(struct eval_ctx *ctx, struct chain *chain) return chain_error(ctx, chain, "invalid policy expression %s", expr_name(chain->policy)); } + } - if (chain->handle.family == NFPROTO_NETDEV) { - if (!chain->dev_expr) - return __stmt_binary_error(ctx, &chain->loc, NULL, - "Missing `device' in this chain definition"); + if (chain->dev_expr) { + if (!(chain->flags & CHAIN_F_BASECHAIN)) + chain->flags |= CHAIN_F_BASECHAIN; - if (!evaluate_device_expr(ctx, &chain->dev_expr)) + if (chain->handle.family == NFPROTO_NETDEV || + (chain->handle.family == NFPROTO_INET && + chain->hook.num == NF_INET_INGRESS)) { + if (chain->dev_expr && + !evaluate_device_expr(ctx, &chain->dev_expr)) return -1; } else if (chain->dev_expr) { return __stmt_binary_error(ctx, &chain->dev_expr->location, NULL, @@ -4041,11 +5394,6 @@ static int chain_evaluate(struct eval_ctx *ctx, struct chain *chain) } } - list_for_each_entry(rule, &chain->rules, list) { - handle_merge(&rule->handle, &chain->handle); - if (rule_evaluate(ctx, rule, CMD_INVALID) < 0) - return -1; - } return 0; } @@ -4079,8 +5427,8 @@ static int ct_timeout_evaluate(struct eval_ctx *ctx, struct obj *obj) ct->timeout[ts->timeout_index] = ts->timeout_value; list_del(&ts->head); - xfree(ts->timeout_str); - xfree(ts); + free_const(ts->timeout_str); + free(ts); } return 0; @@ -4088,6 +5436,17 @@ static int ct_timeout_evaluate(struct eval_ctx *ctx, struct obj *obj) static int obj_evaluate(struct eval_ctx *ctx, struct obj *obj) { + struct table *table; + + table = table_cache_find(&ctx->nft->cache.table_cache, + ctx->cmd->handle.table.name, + ctx->cmd->handle.family); + if (!table) + return table_not_found(ctx); + + if (!obj_cache_find(table, obj->handle.obj.name, obj->type)) + obj_cache_add(obj_get(obj), table); + switch (obj->type) { case NFT_OBJECT_CT_TIMEOUT: return ct_timeout_evaluate(ctx, obj); @@ -4102,49 +5461,18 @@ static int obj_evaluate(struct eval_ctx *ctx, struct obj *obj) static int table_evaluate(struct eval_ctx *ctx, struct table *table) { - struct flowtable *ft; - struct chain *chain; - struct set *set; - struct obj *obj; - - if (table_lookup(&ctx->cmd->handle, &ctx->nft->cache) == NULL) { - if (table == NULL) { + if (!table_cache_find(&ctx->nft->cache.table_cache, + ctx->cmd->handle.table.name, + ctx->cmd->handle.family)) { + if (!table) { table = table_alloc(); handle_merge(&table->handle, &ctx->cmd->handle); - table_add_hash(table, &ctx->nft->cache); + table_cache_add(table, &ctx->nft->cache); } else { - table_add_hash(table_get(table), &ctx->nft->cache); + table_cache_add(table_get(table), &ctx->nft->cache); } } - if (ctx->cmd->table == NULL) - return 0; - - ctx->table = table; - list_for_each_entry(set, &table->sets, list) { - expr_set_context(&ctx->ectx, NULL, 0); - handle_merge(&set->handle, &table->handle); - if (set_evaluate(ctx, set) < 0) - return -1; - } - list_for_each_entry(chain, &table->chains, list) { - handle_merge(&chain->handle, &table->handle); - ctx->cmd->handle.chain.location = chain->location; - if (chain_evaluate(ctx, chain) < 0) - return -1; - } - list_for_each_entry(ft, &table->flowtables, list) { - handle_merge(&ft->handle, &table->handle); - if (flowtable_evaluate(ctx, ft) < 0) - return -1; - } - list_for_each_entry(obj, &table->objs, list) { - handle_merge(&obj->handle, &table->handle); - if (obj_evaluate(ctx, obj) < 0) - return -1; - } - - ctx->table = NULL; return 0; } @@ -4156,6 +5484,8 @@ static int cmd_evaluate_add(struct eval_ctx *ctx, struct cmd *cmd) case CMD_OBJ_SET: handle_merge(&cmd->set->handle, &cmd->handle); return set_evaluate(ctx, cmd->set); + case CMD_OBJ_SETELEMS: + return elems_evaluate(ctx, cmd->set); case CMD_OBJ_RULE: handle_merge(&cmd->rule->handle, &cmd->handle); return rule_evaluate(ctx, cmd->rule, cmd->op); @@ -4174,6 +5504,7 @@ static int cmd_evaluate_add(struct eval_ctx *ctx, struct cmd *cmd) case CMD_OBJ_SECMARK: case CMD_OBJ_CT_EXPECT: case CMD_OBJ_SYNPROXY: + handle_merge(&cmd->object->handle, &cmd->handle); return obj_evaluate(ctx, cmd->object); default: BUG("invalid command object type %u\n", cmd->obj); @@ -4184,35 +5515,149 @@ static void table_del_cache(struct eval_ctx *ctx, struct cmd *cmd) { struct table *table; - table = table_lookup(&cmd->handle, &ctx->nft->cache); + if (!cmd->handle.table.name) + return; + + table = table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family); if (!table) return; - list_del(&table->list); + table_cache_del(table); table_free(table); } +static void chain_del_cache(struct eval_ctx *ctx, struct cmd *cmd) +{ + struct table *table; + struct chain *chain; + + if (!cmd->handle.chain.name) + return; + + table = table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family); + if (!table) + return; + + chain = chain_cache_find(table, cmd->handle.chain.name); + if (!chain) + return; + + chain_cache_del(chain); + chain_free(chain); +} + +static void set_del_cache(struct eval_ctx *ctx, struct cmd *cmd) +{ + struct table *table; + struct set *set; + + if (!cmd->handle.set.name) + return; + + table = table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family); + if (!table) + return; + + set = set_cache_find(table, cmd->handle.set.name); + if (!set) + return; + + set_cache_del(set); + set_free(set); +} + +static void ft_del_cache(struct eval_ctx *ctx, struct cmd *cmd) +{ + struct flowtable *ft; + struct table *table; + + if (!cmd->handle.flowtable.name) + return; + + table = table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family); + if (!table) + return; + + ft = ft_cache_find(table, cmd->handle.flowtable.name); + if (!ft) + return; + + ft_cache_del(ft); + flowtable_free(ft); +} + +static void obj_del_cache(struct eval_ctx *ctx, struct cmd *cmd, int type) +{ + struct table *table; + struct obj *obj; + + if (!cmd->handle.obj.name) + return; + + table = table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family); + if (!table) + return; + + obj = obj_cache_find(table, cmd->handle.obj.name, type); + if (!obj) + return; + + obj_cache_del(obj); + obj_free(obj); +} + static int cmd_evaluate_delete(struct eval_ctx *ctx, struct cmd *cmd) { switch (cmd->obj) { case CMD_OBJ_ELEMENTS: return setelem_evaluate(ctx, cmd); case CMD_OBJ_SET: + set_del_cache(ctx, cmd); + return 0; case CMD_OBJ_RULE: + return 0; case CMD_OBJ_CHAIN: + chain_del_cache(ctx, cmd); return 0; case CMD_OBJ_TABLE: table_del_cache(ctx, cmd); return 0; case CMD_OBJ_FLOWTABLE: + ft_del_cache(ctx, cmd); + return 0; case CMD_OBJ_COUNTER: + obj_del_cache(ctx, cmd, NFT_OBJECT_COUNTER); + return 0; case CMD_OBJ_QUOTA: + obj_del_cache(ctx, cmd, NFT_OBJECT_QUOTA); + return 0; case CMD_OBJ_CT_HELPER: + obj_del_cache(ctx, cmd, NFT_OBJECT_CT_HELPER); + return 0; case CMD_OBJ_CT_TIMEOUT: + obj_del_cache(ctx, cmd, NFT_OBJECT_CT_TIMEOUT); + return 0; case CMD_OBJ_LIMIT: + obj_del_cache(ctx, cmd, NFT_OBJECT_LIMIT); + return 0; case CMD_OBJ_SECMARK: + obj_del_cache(ctx, cmd, NFT_OBJECT_SECMARK); + return 0; case CMD_OBJ_CT_EXPECT: + obj_del_cache(ctx, cmd, NFT_OBJECT_CT_EXPECT); + return 0; case CMD_OBJ_SYNPROXY: + obj_del_cache(ctx, cmd, NFT_OBJECT_SYNPROXY); return 0; default: BUG("invalid command object type %u\n", cmd->obj); @@ -4240,7 +5685,7 @@ static int obj_not_found(struct eval_ctx *ctx, const struct location *loc, return cmd_error(ctx, loc, "%s", strerror(ENOENT)); return cmd_error(ctx, loc, - "%s; did you mean obj ‘%s’ in table %s ‘%s’?", + "%s; did you mean obj '%s' in table %s '%s'?", strerror(ENOENT), obj->handle.obj.name, family2str(obj->handle.family), table->handle.table.name); @@ -4254,11 +5699,13 @@ static int cmd_evaluate_list_obj(struct eval_ctx *ctx, const struct cmd *cmd, if (obj_type == NFT_OBJECT_UNSPEC) obj_type = NFT_OBJECT_COUNTER; - table = table_lookup(&cmd->handle, &ctx->nft->cache); + table = table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family); if (table == NULL) return table_not_found(ctx); - if (obj_lookup(table, cmd->handle.obj.name, obj_type) == NULL) + if (!obj_cache_find(table, cmd->handle.obj.name, obj_type)) return obj_not_found(ctx, &cmd->handle.obj.location, cmd->handle.obj.name); @@ -4276,69 +5723,54 @@ static int cmd_evaluate_list(struct eval_ctx *ctx, struct cmd *cmd) if (cmd->handle.table.name == NULL) return 0; - table = table_lookup(&cmd->handle, &ctx->nft->cache); - if (table == NULL) + table = table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family); + if (!table) return table_not_found(ctx); return 0; case CMD_OBJ_SET: - table = table_lookup(&cmd->handle, &ctx->nft->cache); - if (table == NULL) - return table_not_found(ctx); - - set = set_lookup(table, cmd->handle.set.name); - if (set == NULL) - return set_not_found(ctx, &ctx->cmd->handle.set.location, - ctx->cmd->handle.set.name); - else if (!set_is_literal(set->flags)) - return cmd_error(ctx, &ctx->cmd->handle.set.location, - "%s", strerror(ENOENT)); - - return 0; - case CMD_OBJ_METER: - table = table_lookup(&cmd->handle, &ctx->nft->cache); - if (table == NULL) - return table_not_found(ctx); - - set = set_lookup(table, cmd->handle.set.name); - if (set == NULL) - return set_not_found(ctx, &ctx->cmd->handle.set.location, - ctx->cmd->handle.set.name); - else if (!set_is_meter(set->flags)) - return cmd_error(ctx, &ctx->cmd->handle.set.location, - "%s", strerror(ENOENT)); - - return 0; case CMD_OBJ_MAP: - table = table_lookup(&cmd->handle, &ctx->nft->cache); - if (table == NULL) + case CMD_OBJ_METER: + table = table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family); + if (!table) return table_not_found(ctx); - set = set_lookup(table, cmd->handle.set.name); + set = set_cache_find(table, cmd->handle.set.name); if (set == NULL) return set_not_found(ctx, &ctx->cmd->handle.set.location, ctx->cmd->handle.set.name); - else if (!map_is_literal(set->flags)) + if ((cmd->obj == CMD_OBJ_SET && !set_is_literal(set->flags)) || + (cmd->obj == CMD_OBJ_MAP && !map_is_literal(set->flags)) || + (cmd->obj == CMD_OBJ_METER && !set_is_meter_compat(set->flags))) return cmd_error(ctx, &ctx->cmd->handle.set.location, "%s", strerror(ENOENT)); + cmd->set = set_get(set); return 0; case CMD_OBJ_CHAIN: - table = table_lookup(&cmd->handle, &ctx->nft->cache); - if (table == NULL) + table = table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family); + if (!table) return table_not_found(ctx); - if (chain_lookup(table, &cmd->handle) == NULL) + if (!chain_cache_find(table, cmd->handle.chain.name)) return chain_not_found(ctx); return 0; case CMD_OBJ_FLOWTABLE: - table = table_lookup(&cmd->handle, &ctx->nft->cache); - if (table == NULL) + table = table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family); + if (!table) return table_not_found(ctx); - ft = flowtable_lookup(table, cmd->handle.flowtable.name); - if (ft == NULL) + ft = ft_cache_find(table, cmd->handle.flowtable.name); + if (!ft) return flowtable_not_found(ctx, &ctx->cmd->handle.flowtable.location, ctx->cmd->handle.flowtable.name); @@ -4367,9 +5799,13 @@ static int cmd_evaluate_list(struct eval_ctx *ctx, struct cmd *cmd) case CMD_OBJ_FLOWTABLES: case CMD_OBJ_SECMARKS: case CMD_OBJ_SYNPROXYS: + case CMD_OBJ_CT_TIMEOUTS: + case CMD_OBJ_CT_EXPECTATIONS: if (cmd->handle.table.name == NULL) return 0; - if (table_lookup(&cmd->handle, &ctx->nft->cache) == NULL) + if (!table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family)) return table_not_found(ctx); return 0; @@ -4378,6 +5814,16 @@ static int cmd_evaluate_list(struct eval_ctx *ctx, struct cmd *cmd) case CMD_OBJ_METERS: case CMD_OBJ_MAPS: return 0; + case CMD_OBJ_HOOKS: + if (cmd->handle.chain.name) { + int hooknum = str2hooknum(cmd->handle.family, cmd->handle.chain.name); + + if (hooknum == NF_INET_NUMHOOKS) + return chain_not_found(ctx); + + cmd->handle.chain_id = hooknum; + } + return 0; default: BUG("invalid command object type %u\n", cmd->obj); } @@ -4390,12 +5836,23 @@ static int cmd_evaluate_reset(struct eval_ctx *ctx, struct cmd *cmd) case CMD_OBJ_QUOTA: case CMD_OBJ_COUNTERS: case CMD_OBJ_QUOTAS: + case CMD_OBJ_RULES: + case CMD_OBJ_RULE: if (cmd->handle.table.name == NULL) return 0; - if (table_lookup(&cmd->handle, &ctx->nft->cache) == NULL) + if (!table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family)) return table_not_found(ctx); return 0; + case CMD_OBJ_ELEMENTS: + return setelem_evaluate(ctx, cmd); + case CMD_OBJ_TABLE: + case CMD_OBJ_CHAIN: + case CMD_OBJ_SET: + case CMD_OBJ_MAP: + return cmd_evaluate_list(ctx, cmd); default: BUG("invalid command object type %u\n", cmd->obj); } @@ -4411,6 +5868,7 @@ static void __flush_set_cache(struct set *set) static int cmd_evaluate_flush(struct eval_ctx *ctx, struct cmd *cmd) { + struct cache *table_cache = &ctx->nft->cache.table_cache; struct table *table; struct set *set; @@ -4425,11 +5883,12 @@ static int cmd_evaluate_flush(struct eval_ctx *ctx, struct cmd *cmd) /* Chains don't hold sets */ break; case CMD_OBJ_SET: - table = table_lookup(&cmd->handle, &ctx->nft->cache); - if (table == NULL) + table = table_cache_find(table_cache, cmd->handle.table.name, + cmd->handle.family); + if (!table) return table_not_found(ctx); - set = set_lookup(table, cmd->handle.set.name); + set = set_cache_find(table, cmd->handle.set.name); if (set == NULL) return set_not_found(ctx, &ctx->cmd->handle.set.location, ctx->cmd->handle.set.name); @@ -4441,11 +5900,12 @@ static int cmd_evaluate_flush(struct eval_ctx *ctx, struct cmd *cmd) return 0; case CMD_OBJ_MAP: - table = table_lookup(&cmd->handle, &ctx->nft->cache); - if (table == NULL) + table = table_cache_find(table_cache, cmd->handle.table.name, + cmd->handle.family); + if (!table) return table_not_found(ctx); - set = set_lookup(table, cmd->handle.set.name); + set = set_cache_find(table, cmd->handle.set.name); if (set == NULL) return set_not_found(ctx, &ctx->cmd->handle.set.location, ctx->cmd->handle.set.name); @@ -4457,15 +5917,16 @@ static int cmd_evaluate_flush(struct eval_ctx *ctx, struct cmd *cmd) return 0; case CMD_OBJ_METER: - table = table_lookup(&cmd->handle, &ctx->nft->cache); - if (table == NULL) + table = table_cache_find(table_cache, cmd->handle.table.name, + cmd->handle.family); + if (!table) return table_not_found(ctx); - set = set_lookup(table, cmd->handle.set.name); + set = set_cache_find(table, cmd->handle.set.name); if (set == NULL) return set_not_found(ctx, &ctx->cmd->handle.set.location, ctx->cmd->handle.set.name); - else if (!set_is_meter(set->flags)) + else if (!set_is_meter_compat(set->flags)) return cmd_error(ctx, &ctx->cmd->handle.set.location, "%s", strerror(ENOENT)); @@ -4484,11 +5945,13 @@ static int cmd_evaluate_rename(struct eval_ctx *ctx, struct cmd *cmd) switch (cmd->obj) { case CMD_OBJ_CHAIN: - table = table_lookup(&ctx->cmd->handle, &ctx->nft->cache); - if (table == NULL) + table = table_cache_find(&ctx->nft->cache.table_cache, + cmd->handle.table.name, + cmd->handle.family); + if (!table) return table_not_found(ctx); - if (chain_lookup(table, &ctx->cmd->handle) == NULL) + if (!chain_cache_find(table, ctx->cmd->handle.chain.name)) return chain_not_found(ctx); break; @@ -4626,11 +6089,12 @@ static const char * const cmd_op_name[] = { [CMD_EXPORT] = "export", [CMD_MONITOR] = "monitor", [CMD_DESCRIBE] = "describe", + [CMD_DESTROY] = "destroy", }; static const char *cmd_op_to_name(enum cmd_ops op) { - if (op > CMD_DESCRIBE) + if (op >= array_size(cmd_op_name)) return "unknown"; return cmd_op_name[op]; @@ -4658,6 +6122,7 @@ int cmd_evaluate(struct eval_ctx *ctx, struct cmd *cmd) case CMD_INSERT: return cmd_evaluate_add(ctx, cmd); case CMD_DELETE: + case CMD_DESTROY: return cmd_evaluate_delete(ctx, cmd); case CMD_GET: return cmd_evaluate_get(ctx, cmd); |