From 3713072d3bcb5dc1cfbb7c5fa3e24b8a73fd4104 Mon Sep 17 00:00:00 2001 From: Jozsef Kadlecsik Date: Mon, 26 Nov 2018 10:54:36 +0100 Subject: Implement sorting for hash types in the ipset tool Support listing/saving with sorted entries for the hash types. (bitmap and list types are automatically sorted.) --- lib/Makefile.am | 1 + lib/list_sort.c | 141 ++++++++++++++++++++++++++++++++++++++++++++++++++++ lib/session.c | 152 +++++++++++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 277 insertions(+), 17 deletions(-) create mode 100644 lib/list_sort.c (limited to 'lib') diff --git a/lib/Makefile.am b/lib/Makefile.am index 281f693..3a82417 100644 --- a/lib/Makefile.am +++ b/lib/Makefile.am @@ -32,6 +32,7 @@ libipset_la_SOURCES = \ errcode.c \ icmp.c \ icmpv6.c \ + list_sort.c \ mnl.c \ parse.c \ print.c \ diff --git a/lib/list_sort.c b/lib/list_sort.c new file mode 100644 index 0000000..56c43f5 --- /dev/null +++ b/lib/list_sort.c @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copied from the Linux kernel lib/list_sort.c file */ +#include +#include /* memset */ +#include + +#define MAX_LIST_LENGTH_BITS 20 + +/* + * Returns a list organized in an intermediate format suited + * to chaining of merge() calls: null-terminated, no reserved or + * sentinel head node, "prev" links not maintained. + */ +static struct list_head *merge(void *priv, + int (*cmp)(void *priv, struct list_head *a, + struct list_head *b), + struct list_head *a, struct list_head *b) +{ + struct list_head head, *tail = &head; + + while (a && b) { + /* if equal, take 'a' -- important for sort stability */ + if ((*cmp)(priv, a, b) <= 0) { + tail->next = a; + a = a->next; + } else { + tail->next = b; + b = b->next; + } + tail = tail->next; + } + tail->next = a?:b; + return head.next; +} + +/* + * Combine final list merge with restoration of standard doubly-linked + * list structure. This approach duplicates code from merge(), but + * runs faster than the tidier alternatives of either a separate final + * prev-link restoration pass, or maintaining the prev links + * throughout. + */ +static void merge_and_restore_back_links(void *priv, + int (*cmp)(void *priv, struct list_head *a, + struct list_head *b), + struct list_head *head, + struct list_head *a, struct list_head *b) +{ + struct list_head *tail = head; + int count = 0; + + while (a && b) { + /* if equal, take 'a' -- important for sort stability */ + if ((*cmp)(priv, a, b) <= 0) { + tail->next = a; + a->prev = tail; + a = a->next; + } else { + tail->next = b; + b->prev = tail; + b = b->next; + } + tail = tail->next; + } + tail->next = a ? : b; + + do { + /* + * In worst cases this loop may run many iterations. + * Continue callbacks to the client even though no + * element comparison is needed, so the client's cmp() + * routine can invoke cond_resched() periodically. + */ + if (unlikely(!(++count))) + (*cmp)(priv, tail->next, tail->next); + + tail->next->prev = tail; + tail = tail->next; + } while (tail->next); + + tail->next = head; + head->prev = tail; +} + +/** + * list_sort - sort a list + * @priv: private data, opaque to list_sort(), passed to @cmp + * @head: the list to sort + * @cmp: the elements comparison function + * + * This function implements "merge sort", which has O(nlog(n)) + * complexity. + * + * The comparison function @cmp must return a negative value if @a + * should sort before @b, and a positive value if @a should sort after + * @b. If @a and @b are equivalent, and their original relative + * ordering is to be preserved, @cmp must return 0. + */ +void list_sort(void *priv, struct list_head *head, + int (*cmp)(void *priv, struct list_head *a, + struct list_head *b)) +{ + struct list_head *part[MAX_LIST_LENGTH_BITS+1]; /* sorted partial lists + -- last slot is a sentinel */ + int lev; /* index into part[] */ + int max_lev = 0; + struct list_head *list; + + if (list_empty(head)) + return; + + memset(part, 0, sizeof(part)); + + head->prev->next = NULL; + list = head->next; + + while (list) { + struct list_head *cur = list; + list = list->next; + cur->next = NULL; + + for (lev = 0; part[lev]; lev++) { + cur = merge(priv, cmp, part[lev], cur); + part[lev] = NULL; + } + if (lev > max_lev) { + if (unlikely(lev >= MAX_LIST_LENGTH_BITS)) { + // printk_once(KERN_DEBUG "list too long for efficiency\n"); + lev--; + } + max_lev = lev; + } + part[lev] = cur; + } + + for (lev = 0; lev < max_lev; lev++) + if (part[lev]) + list = merge(priv, cmp, part[lev], list); + + merge_and_restore_back_links(priv, cmp, head, part[max_lev], list); +} diff --git a/lib/session.c b/lib/session.c index 768cc05..9e1d7dd 100644 --- a/lib/session.c +++ b/lib/session.c @@ -27,10 +27,18 @@ #include /* default backend */ #include /* STREQ */ #include /* IPSET_ENV_* */ +#include /* list_sort */ #include /* prototypes */ #define IPSET_NEST_MAX 4 +/* When we want to sort the entries */ +struct ipset_sorted { + struct list_head list; + size_t offset; /* Offset in outbuf */ +}; + + /* The session structure */ struct ipset_session { const struct ipset_transport *transport;/* Transport protocol */ @@ -47,10 +55,15 @@ struct ipset_session { uint8_t protocol; /* The protocol used */ bool version_checked; /* Version checked */ /* Output buffer */ - char outbuf[IPSET_OUTBUFLEN]; /* Output buffer */ + char *outbuf; /* Output buffer */ + size_t outbuflen; /* Output buffer size */ + size_t pos; /* Printing position in outbuf */ + struct list_head sorted; /* Sorted entries */ + struct list_head pool; /* Pool to reuse */ enum ipset_output_mode mode; /* Output mode */ ipset_print_outfn print_outfn; /* Output function to file */ void *p; /* Private data for print_outfn */ + bool sort; /* Print sorted hash:* types */ /* Session IO */ bool normal_io, full_io; /* Default/normal/full IO */ FILE *istream, *ostream; /* Session input/output stream */ @@ -744,6 +757,7 @@ call_outfn(struct ipset_session *session) "%s", session->outbuf); session->outbuf[0] = '\0'; + session->pos = 0; return ret < 0 ? ret : 0; } @@ -751,20 +765,39 @@ call_outfn(struct ipset_session *session) /* Handle printing failures */ static jmp_buf printf_failure; +static void +realloc_outbuf(struct ipset_session *session) +{ + char *buf = realloc(session->outbuf, + session->outbuflen + IPSET_OUTBUFLEN); + if (!buf) { + ipset_err(session, + "Could not allocate memory to print sorted!"); + longjmp(printf_failure, 1); + } + session->outbuf = buf; + session->outbuflen += IPSET_OUTBUFLEN; +} + static int handle_snprintf_error(struct ipset_session *session, - int len, int ret, int loop) + int ret, int loop) { - if (ret < 0 || ret >= IPSET_OUTBUFLEN - len) { + if (ret < 0 || ret + session->pos >= session->outbuflen) { + if (session->sort && !loop) { + realloc_outbuf(session); + return 1; + } /* Buffer was too small, push it out and retry */ - D("print buffer and try again: len: %u, ret: %d", len, ret); + D("print buffer and try again: outbuflen: %lu, pos %lu, ret: %d", + session->outbuflen, session->pos, ret); if (loop) { ipset_err(session, "Internal error at printing, loop detected!"); longjmp(printf_failure, 1); } - session->outbuf[len] = '\0'; + session->outbuf[session->pos] = '\0'; if (call_outfn(session)) { ipset_err(session, "Internal error, could not print output buffer!"); @@ -772,6 +805,7 @@ handle_snprintf_error(struct ipset_session *session, } return 1; } + session->pos += ret; return 0; } @@ -779,17 +813,17 @@ static int __attribute__((format(printf, 2, 3))) safe_snprintf(struct ipset_session *session, const char *fmt, ...) { va_list args; - int len, ret, loop = 0; + int ret, loop = 0; do { - len = strlen(session->outbuf); - D("len: %u, retry %u", len, loop); + D("outbuflen: %lu, pos: %lu, retry %u", + session->outbuflen, session->pos, loop); va_start(args, fmt); - ret = vsnprintf(session->outbuf + len, - IPSET_OUTBUFLEN - len, + ret = vsnprintf(session->outbuf + session->pos, + session->outbuflen - session->pos, fmt, args); va_end(args); - loop = handle_snprintf_error(session, len, ret, loop); + loop = handle_snprintf_error(session, ret, loop); } while (loop); return ret; @@ -799,14 +833,15 @@ static int safe_dprintf(struct ipset_session *session, ipset_printfn fn, enum ipset_opt opt) { - int len, ret, loop = 0; + int ret, loop = 0; do { - len = strlen(session->outbuf); - D("len: %u, retry %u", len, loop); - ret = fn(session->outbuf + len, IPSET_OUTBUFLEN - len, + D("outbuflen: %lu, len: %lu, retry %u", + session->outbuflen, session->pos, loop); + ret = fn(session->outbuf + session->pos, + session->outbuflen - session->pos, session->data, opt, session->envopts); - loop = handle_snprintf_error(session, len, ret, loop); + loop = handle_snprintf_error(session, ret, loop); } while (loop); return ret; @@ -818,6 +853,7 @@ list_adt(struct ipset_session *session, struct nlattr *nla[]) const struct ipset_data *data = session->data; const struct ipset_type *type; const struct ipset_arg *arg; + size_t offset = 0; int i, found = 0; D("enter"); @@ -839,6 +875,13 @@ list_adt(struct ipset_session *session, struct nlattr *nla[]) if (!found) return MNL_CB_OK; + if (session->sort) { + if (session->outbuflen <= session->pos + 1) + realloc_outbuf(session); + session->pos++; /* \0 */ + offset = session->pos; + } + switch (session->mode) { case IPSET_LIST_SAVE: safe_snprintf(session, "add %s ", ipset_data_setname(data)); @@ -891,6 +934,24 @@ list_adt(struct ipset_session *session, struct nlattr *nla[]) else safe_snprintf(session, "\n"); + if (session->sort) { + struct ipset_sorted *sorted; + + if (!list_empty(&session->pool)) { + sorted = list_first_entry(&session->pool, + struct ipset_sorted, list); + list_del(&sorted->list); + } else { + sorted = calloc(1, sizeof(struct ipset_sorted)); + if (!sorted) { + ipset_err(session, + "Could not allocate memory to print sorted!"); + longjmp(printf_failure, 1); + } + } + sorted->offset = offset; + list_add_tail(&sorted->list, &session->sorted); + } return MNL_CB_OK; } @@ -1017,14 +1078,50 @@ list_create(struct ipset_session *session, struct nlattr *nla[]) } session->printed_set++; + session->sort = strncmp(type->name, "hash:", 5) == 0 && + ipset_envopt_test(session, IPSET_ENV_SORTED); + return MNL_CB_OK; } +static int +bystrcmp(void *priv, struct list_head *a, struct list_head *b) +{ + struct ipset_session *session = priv; + struct ipset_sorted *x = list_entry(a, struct ipset_sorted, list); + struct ipset_sorted *y = list_entry(b, struct ipset_sorted, list); + + return strcmp(session->outbuf + x->offset, + session->outbuf + y->offset); +} + static int print_set_done(struct ipset_session *session, bool callback_done) { D("called for %s", session->saved_setname[0] == '\0' ? "NONE" : session->saved_setname); + if (session->sort) { + struct ipset_sorted *pos; + int ret; + + /* Print set header */ + ret = call_outfn(session); + + if (ret) + return MNL_CB_ERROR; + + list_sort(session, &session->sorted, bystrcmp); + + list_for_each_entry(pos, &session->sorted, list) { + ret = session->print_outfn(session, session->p, + "%s", + session->outbuf + pos->offset); + if (ret < 0) + return MNL_CB_ERROR; + } + list_splice(&session->sorted, &session->pool); + INIT_LIST_HEAD(&session->sorted); + } switch (session->mode) { case IPSET_LIST_XML: if (session->envopts & IPSET_ENV_LIST_SETNAME) @@ -1138,6 +1235,8 @@ callback_list(struct ipset_session *session, struct nlattr *nla[], if (list_adt(session, adt) != MNL_CB_OK) return MNL_CB_ERROR; } + if (session->sort) + return MNL_CB_OK; } return call_outfn(session) ? MNL_CB_ERROR : MNL_CB_OK; } @@ -2121,11 +2220,17 @@ ipset_session_init(ipset_print_outfn print_outfn, void *p) session = calloc(1, sizeof(struct ipset_session) + bufsize); if (session == NULL) return NULL; + session->outbuf = calloc(1, IPSET_OUTBUFLEN); + if (session->outbuf == NULL) + goto free_session; + session->outbuflen = IPSET_OUTBUFLEN; session->bufsize = bufsize; session->buffer = session + 1; session->istream = stdin; session->ostream = stdout; session->protocol = IPSET_PROTOCOL; + INIT_LIST_HEAD(&session->sorted); + INIT_LIST_HEAD(&session->pool); /* The single transport method yet */ session->transport = &ipset_mnl_transport; @@ -2136,11 +2241,13 @@ ipset_session_init(ipset_print_outfn print_outfn, void *p) /* Initialize data structures */ session->data = ipset_data_init(); if (session->data == NULL) - goto free_session; + goto free_outbuf; ipset_cache_init(); return session; +free_outbuf: + free(session->outbuf); free_session: free(session); return NULL; @@ -2348,6 +2455,7 @@ ipset_session_io_close(struct ipset_session *session, int ipset_session_fini(struct ipset_session *session) { + struct ipset_sorted *pos, *n; assert(session); if (session->handle) @@ -2360,6 +2468,16 @@ ipset_session_fini(struct ipset_session *session) fclose(session->ostream); ipset_cache_fini(); + + list_for_each_entry_safe(pos, n, &session->sorted, list) { + list_del(&pos->list); + free(pos); + } + list_for_each_entry_safe(pos, n, &session->pool, list) { + list_del(&pos->list); + free(pos); + } + free(session->outbuf); free(session); return 0; } -- cgit v1.2.3