diff options
Diffstat (limited to 'src/evaluate.c')
-rw-r--r-- | src/evaluate.c | 105 |
1 files changed, 97 insertions, 8 deletions
diff --git a/src/evaluate.c b/src/evaluate.c index 72a0e435..ab732616 100644 --- a/src/evaluate.c +++ b/src/evaluate.c @@ -362,30 +362,108 @@ conflict_resolution_gen_dependency(struct eval_ctx *ctx, int protocol, return 0; } +static uint8_t expr_offset_shift(const struct expr *expr, unsigned int offset) +{ + unsigned int new_offset, len; + int shift; + + new_offset = offset % BITS_PER_BYTE; + len = round_up(expr->len, BITS_PER_BYTE); + shift = len - (new_offset + expr->len); + assert(shift >= 0); + + return shift; +} + +static void expr_evaluate_bits(struct eval_ctx *ctx, struct expr **exprp) +{ + struct expr *expr = *exprp, *and, *mask, *lshift, *off; + unsigned masklen; + uint8_t shift; + mpz_t bitmask; + + switch (expr->ops->type) { + case EXPR_PAYLOAD: + shift = expr_offset_shift(expr, expr->payload.offset); + break; + case EXPR_EXTHDR: + shift = expr_offset_shift(expr, expr->exthdr.tmpl->offset); + break; + default: + BUG("Unknown expression %s\n", expr->ops->name); + } + + masklen = expr->len + shift; + assert(masklen <= NFT_REG_SIZE * BITS_PER_BYTE); + + mpz_init2(bitmask, masklen); + mpz_bitmask(bitmask, expr->len); + mpz_lshift_ui(bitmask, shift); + + mask = constant_expr_alloc(&expr->location, expr_basetype(expr), + BYTEORDER_HOST_ENDIAN, masklen, NULL); + mpz_set(mask->value, bitmask); + + and = binop_expr_alloc(&expr->location, OP_AND, expr, mask); + and->dtype = expr->dtype; + and->byteorder = expr->byteorder; + and->len = masklen; + + if (shift) { + off = constant_expr_alloc(&expr->location, + expr_basetype(expr), + BYTEORDER_BIG_ENDIAN, + sizeof(shift), &shift); + + lshift = binop_expr_alloc(&expr->location, OP_RSHIFT, and, off); + lshift->dtype = expr->dtype; + lshift->byteorder = expr->byteorder; + lshift->len = masklen; + + *exprp = lshift; + } else + *exprp = and; +} + +static int __expr_evaluate_exthdr(struct eval_ctx *ctx, struct expr **exprp) +{ + struct expr *expr = *exprp; + + if (expr_evaluate_primary(ctx, exprp) < 0) + return -1; + + if (expr->exthdr.tmpl->offset % BITS_PER_BYTE != 0 || + expr->len % BITS_PER_BYTE != 0) + expr_evaluate_bits(ctx, exprp); + + return 0; +} + /* * Exthdr expression: check whether dependencies are fulfilled, otherwise * generate the necessary relational expression and prepend it to the current * statement. */ -static int expr_evaluate_exthdr(struct eval_ctx *ctx, struct expr **expr) +static int expr_evaluate_exthdr(struct eval_ctx *ctx, struct expr **exprp) { const struct proto_desc *base; + struct expr *expr = *exprp; struct stmt *nstmt; base = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc; if (base == &proto_ip6) - return expr_evaluate_primary(ctx, expr); + return __expr_evaluate_exthdr(ctx, exprp); if (base) - return expr_error(ctx->msgs, *expr, + return expr_error(ctx->msgs, expr, "cannot use exthdr with %s", base->name); - if (exthdr_gen_dependency(ctx, *expr, &nstmt) < 0) + if (exthdr_gen_dependency(ctx, expr, &nstmt) < 0) return -1; list_add(&nstmt->list, &ctx->rule->stmts); - return expr_evaluate_primary(ctx, expr); + return __expr_evaluate_exthdr(ctx, exprp); } /* dependency supersede. @@ -509,12 +587,21 @@ static int __expr_evaluate_payload(struct eval_ctx *ctx, struct expr *expr) return 0; } -static int expr_evaluate_payload(struct eval_ctx *ctx, struct expr **expr) +static int expr_evaluate_payload(struct eval_ctx *ctx, struct expr **exprp) { - if (__expr_evaluate_payload(ctx, *expr) < 0) + struct expr *expr = *exprp; + + if (__expr_evaluate_payload(ctx, expr) < 0) return -1; - return expr_evaluate_primary(ctx, expr); + if (expr_evaluate_primary(ctx, exprp) < 0) + return -1; + + if (expr->payload.offset % BITS_PER_BYTE != 0 || + expr->len % BITS_PER_BYTE != 0) + expr_evaluate_bits(ctx, exprp); + + return 0; } /* @@ -1096,6 +1183,8 @@ static int binop_can_transfer(struct eval_ctx *ctx, "Comparison is always false"); return 1; case OP_RSHIFT: + if (ctx->ectx.len < right->len + mpz_get_uint32(left->right->value)) + ctx->ectx.len += mpz_get_uint32(left->right->value); return 1; case OP_XOR: return 1; |