diff options
Diffstat (limited to 'src/callback.c')
-rw-r--r-- | src/callback.c | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/src/callback.c b/src/callback.c index 1b18b2a..f6637b3 100644 --- a/src/callback.c +++ b/src/callback.c @@ -48,7 +48,8 @@ static mnl_cb_t default_cb_array[NLMSG_MIN_TYPE] = { * mnl_cb_run2 - callback runqueue for netlink messages * @buf: buffer that contains the netlink messages * @numbytes: number of bytes stored in the buffer - * @seq: sequence number that we expect to receive (use zero to skip) + * @seq: sequence number that we expect to receive + * @portid: Netlink PortID that we expect to receive * @cb_data: callback handler for data messages * @data: pointer to data that will be passed to the data callback handler * @cb_ctl_array: array of custom callback handlers from control messages @@ -66,13 +67,18 @@ static mnl_cb_t default_cb_array[NLMSG_MIN_TYPE] = { * This function propagates the callback return value. */ int mnl_cb_run2(const char *buf, int numbytes, unsigned int seq, - mnl_cb_t cb_data, void *data, + unsigned int portid, mnl_cb_t cb_data, void *data, mnl_cb_t *cb_ctl_array, unsigned int cb_ctl_array_len) { int ret = MNL_CB_OK; struct nlmsghdr *nlh = (struct nlmsghdr *)buf; while (mnl_nlmsg_ok(nlh, numbytes)) { + /* check message source */ + if (!mnl_nlmsg_portid_ok(nlh, portid)) { + errno = EINVAL; + return -1; + } /* perform sequence tracking */ if (!mnl_nlmsg_seq_ok(nlh, seq)) { errno = EILSEQ; @@ -107,7 +113,8 @@ out: * mnl_cb_run - callback runqueue for netlink messages (simplified version) * @buf: buffer that contains the netlink messages * @numbytes: number of bytes stored in the buffer - * @seq: sequence number that we expect to receive (use zero to skip) + * @seq: sequence number that we expect to receive + * @portid: Netlink PortID that we expect to receive * @cb_data: callback handler for data messages * @data: pointer to data that will be passed to the data callback handler * @@ -122,7 +129,7 @@ out: * This function propagates the callback return value. */ int mnl_cb_run(const char *buf, int numbytes, unsigned int seq, - mnl_cb_t cb_data, void *data) + unsigned int portid, mnl_cb_t cb_data, void *data) { - return mnl_cb_run2(buf, numbytes, seq, cb_data, data, NULL, 0); + return mnl_cb_run2(buf, numbytes, seq, portid, cb_data, data, NULL, 0); } |