diff options
Diffstat (limited to 'src/mnl.c')
-rw-r--r-- | src/mnl.c | 250 |
1 files changed, 250 insertions, 0 deletions
@@ -21,9 +21,15 @@ #include <mnl.h> #include <errno.h> #include <utils.h> +#include <nftables.h> static int seq; +uint32_t mnl_seqnum_alloc(void) +{ + return seq++; +} + static int mnl_talk(struct mnl_socket *nf_sock, const void *data, unsigned int len, int (*cb)(const struct nlmsghdr *nlh, void *data), void *cb_data) @@ -51,6 +57,250 @@ out: } /* + * Batching + */ + +/* selected batch page is 256 Kbytes long to load ruleset of + * half a million rules without hitting -EMSGSIZE due to large + * iovec. + */ +#define BATCH_PAGE_SIZE getpagesize() * 32 + +static struct mnl_nlmsg_batch *batch; + +static struct mnl_nlmsg_batch *mnl_batch_alloc(void) +{ + static char *buf; + + /* libmnl needs higher buffer to handle batch overflows */ + buf = xmalloc(BATCH_PAGE_SIZE + getpagesize()); + return mnl_nlmsg_batch_start(buf, BATCH_PAGE_SIZE); +} + +void mnl_batch_init(void) +{ + batch = mnl_batch_alloc(); +} + +static LIST_HEAD(batch_page_list); +static int batch_num_pages; + +struct batch_page { + struct list_head head; + struct mnl_nlmsg_batch *batch; +}; + +static void mnl_batch_page_add(void) +{ + struct batch_page *batch_page; + + batch_page = xmalloc(sizeof(struct batch_page)); + batch_page->batch = batch; + list_add_tail(&batch_page->head, &batch_page_list); + batch_num_pages++; + batch = mnl_batch_alloc(); +} + +static void mnl_batch_put(int type) +{ + struct nlmsghdr *nlh; + struct nfgenmsg *nfg; + + nlh = mnl_nlmsg_put_header(mnl_nlmsg_batch_current(batch)); + nlh->nlmsg_type = type; + nlh->nlmsg_flags = NLM_F_REQUEST; + nlh->nlmsg_seq = mnl_seqnum_alloc(); + + nfg = mnl_nlmsg_put_extra_header(nlh, sizeof(*nfg)); + nfg->nfgen_family = AF_INET; + nfg->version = NFNETLINK_V0; + nfg->res_id = NFNL_SUBSYS_NFTABLES; + + if (!mnl_nlmsg_batch_next(batch)) + mnl_batch_page_add(); +} + +void mnl_batch_begin(void) +{ + mnl_batch_put(NFNL_MSG_BATCH_BEGIN); +} + +void mnl_batch_end(void) +{ + mnl_batch_put(NFNL_MSG_BATCH_END); +} + +bool mnl_batch_ready(void) +{ + /* Check if the batch only contains the initial and trailing batch + * messages. In that case, the batch is empty. + */ + return mnl_nlmsg_batch_size(batch) != (NLMSG_HDRLEN+sizeof(struct nfgenmsg)) * 2; +} + +void mnl_batch_reset(void) +{ + mnl_nlmsg_batch_reset(batch); +} + +static void mnl_err_list_node_add(struct list_head *err_list, int error, + int seqnum) +{ + struct mnl_err *err = xmalloc(sizeof(struct mnl_err)); + + err->seqnum = seqnum; + err->err = error; + list_add_tail(&err->head, err_list); +} + +void mnl_err_list_free(struct mnl_err *err) +{ + list_del(&err->head); + xfree(err); +} + +static int nlbuffsiz; + +static void mnl_set_sndbuffer(const struct mnl_socket *nl) +{ + int newbuffsiz; + + if (batch_num_pages * BATCH_PAGE_SIZE <= nlbuffsiz) + return; + + newbuffsiz = batch_num_pages * BATCH_PAGE_SIZE; + + /* Rise sender buffer length to avoid hitting -EMSGSIZE */ + if (setsockopt(mnl_socket_get_fd(nl), SOL_SOCKET, SO_SNDBUFFORCE, + &newbuffsiz, sizeof(socklen_t)) < 0) + return; + + nlbuffsiz = newbuffsiz; +} + +static ssize_t mnl_nft_socket_sendmsg(const struct mnl_socket *nl) +{ + static const struct sockaddr_nl snl = { + .nl_family = AF_NETLINK + }; + struct iovec iov[batch_num_pages]; + struct msghdr msg = { + .msg_name = (struct sockaddr *) &snl, + .msg_namelen = sizeof(snl), + .msg_iov = iov, + .msg_iovlen = batch_num_pages, + }; + struct batch_page *batch_page, *next; + int i = 0; + + mnl_set_sndbuffer(nl); + + list_for_each_entry_safe(batch_page, next, &batch_page_list, head) { + iov[i].iov_base = mnl_nlmsg_batch_head(batch_page->batch); + iov[i].iov_len = mnl_nlmsg_batch_size(batch_page->batch); + i++; +#ifdef DEBUG + if (debug_level & DEBUG_NETLINK) { + mnl_nlmsg_fprintf(stdout, + mnl_nlmsg_batch_head(batch_page->batch), + mnl_nlmsg_batch_size(batch_page->batch), + sizeof(struct nfgenmsg)); + } +#endif + list_del(&batch_page->head); + xfree(batch_page->batch); + xfree(batch_page); + batch_num_pages--; + } + + return sendmsg(mnl_socket_get_fd(nl), &msg, 0); +} + +int mnl_batch_talk(struct mnl_socket *nl, struct list_head *err_list) +{ + int ret, fd = mnl_socket_get_fd(nl), portid = mnl_socket_get_portid(nl); + char rcv_buf[MNL_SOCKET_BUFFER_SIZE]; + fd_set readfds; + struct timeval tv = { + .tv_sec = 0, + .tv_usec = 0 + }; + + if (!mnl_nlmsg_batch_is_empty(batch)) + mnl_batch_page_add(); + + ret = mnl_nft_socket_sendmsg(nl); + if (ret == -1) + goto err; + + FD_ZERO(&readfds); + FD_SET(fd, &readfds); + + /* receive and digest all the acknowledgments from the kernel. */ + ret = select(fd+1, &readfds, NULL, NULL, &tv); + if (ret == -1) + goto err; + + while (ret > 0 && FD_ISSET(fd, &readfds)) { + struct nlmsghdr *nlh = (struct nlmsghdr *)rcv_buf; + + ret = mnl_socket_recvfrom(nl, rcv_buf, sizeof(rcv_buf)); + if (ret == -1) + goto err; + + ret = mnl_cb_run(rcv_buf, ret, 0, portid, NULL, NULL); + /* Continue on error, make sure we get all acknoledgments */ + if (ret == -1) + mnl_err_list_node_add(err_list, errno, nlh->nlmsg_seq); + + ret = select(fd+1, &readfds, NULL, NULL, &tv); + if (ret == -1) + goto err; + + FD_ZERO(&readfds); + FD_SET(fd, &readfds); + } +err: + mnl_nlmsg_batch_reset(batch); + return ret; +} + +int mnl_nft_rule_batch_add(struct nft_rule *nlr, unsigned int flags, + uint32_t seqnum) +{ + struct nlmsghdr *nlh; + + nlh = nft_table_nlmsg_build_hdr(mnl_nlmsg_batch_current(batch), + NFT_MSG_NEWRULE, + nft_rule_attr_get_u32(nlr, NFT_RULE_ATTR_FAMILY), + flags|NLM_F_ACK|NLM_F_CREATE, seqnum); + + nft_rule_nlmsg_build_payload(nlh, nlr); + if (!mnl_nlmsg_batch_next(batch)) + mnl_batch_page_add(); + + return 0; +} + +int mnl_nft_rule_batch_del(struct nft_rule *nlr, unsigned int flags, + uint32_t seqnum) +{ + struct nlmsghdr *nlh; + + nlh = nft_table_nlmsg_build_hdr(mnl_nlmsg_batch_current(batch), + NFT_MSG_DELRULE, + nft_rule_attr_get_u32(nlr, NFT_RULE_ATTR_FAMILY), + NLM_F_ACK, seqnum); + + nft_rule_nlmsg_build_payload(nlh, nlr); + + if (!mnl_nlmsg_batch_next(batch)) + mnl_batch_page_add(); + + return 0; +} + +/* * Rule */ int mnl_nft_rule_add(struct mnl_socket *nf_sock, struct nft_rule *nlr, |