@@ -135,7 +135,7 @@ static int __bpf_set_link_xdp_fd_replace(int ifindex, int fd, int old_fd,
__u32 flags)
{
int sock, seq = 0, ret;
- struct nlattr *nla, *nla_xdp;
+ struct nlattr *nla;
struct {
struct nlmsghdr nh;
struct ifinfomsg ifinfo;
@@ -157,36 +157,31 @@ static int __bpf_set_link_xdp_fd_replace(int ifindex, int fd, int old_fd,
req.ifinfo.ifi_index = ifindex;
/* started nested attribute for XDP */
- nla = (struct nlattr *)(((char *)&req)
- + NLMSG_ALIGN(req.nh.nlmsg_len));
- nla->nla_type = NLA_F_NESTED | IFLA_XDP;
- nla->nla_len = NLA_HDRLEN;
+ nla = begin_nlattr_nested(&req.nh, sizeof(req), IFLA_XDP);
+ if (!nla) {
+ ret = -EMSGSIZE;
+ goto cleanup;
+ }
/* add XDP fd */
- nla_xdp = (struct nlattr *)((char *)nla + nla->nla_len);
- nla_xdp->nla_type = IFLA_XDP_FD;
- nla_xdp->nla_len = NLA_HDRLEN + sizeof(int);
- memcpy((char *)nla_xdp + NLA_HDRLEN, &fd, sizeof(fd));
- nla->nla_len += nla_xdp->nla_len;
+ ret = add_nlattr(&req.nh, sizeof(req), IFLA_XDP_FD, &fd, sizeof(fd));
+ if (ret < 0)
+ goto cleanup;
/* if user passed in any flags, add those too */
if (flags) {
- nla_xdp = (struct nlattr *)((char *)nla + nla->nla_len);
- nla_xdp->nla_type = IFLA_XDP_FLAGS;
- nla_xdp->nla_len = NLA_HDRLEN + sizeof(flags);
- memcpy((char *)nla_xdp + NLA_HDRLEN, &flags, sizeof(flags));
- nla->nla_len += nla_xdp->nla_len;
+ ret = add_nlattr(&req.nh, sizeof(req), IFLA_XDP_FLAGS, &flags, sizeof(flags));
+ if (ret < 0)
+ goto cleanup;
}
if (flags & XDP_FLAGS_REPLACE) {
- nla_xdp = (struct nlattr *)((char *)nla + nla->nla_len);
- nla_xdp->nla_type = IFLA_XDP_EXPECTED_FD;
- nla_xdp->nla_len = NLA_HDRLEN + sizeof(old_fd);
- memcpy((char *)nla_xdp + NLA_HDRLEN, &old_fd, sizeof(old_fd));
- nla->nla_len += nla_xdp->nla_len;
+ ret = add_nlattr(&req.nh, sizeof(req), IFLA_XDP_EXPECTED_FD, &flags, sizeof(flags));
+ if (ret < 0)
+ goto cleanup;
}
- req.nh.nlmsg_len += NLA_ALIGN(nla->nla_len);
+ end_nlattr_nested(&req.nh, nla);
if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
ret = -errno;
@@ -10,7 +10,10 @@
#define __LIBBPF_NLATTR_H
#include <stdint.h>
+#include <string.h>
+#include <errno.h>
#include <linux/netlink.h>
+
/* avoid multiple definition of netlink features */
#define __LINUX_NETLINK_H
@@ -103,4 +106,44 @@ int libbpf_nla_parse_nested(struct nlattr *tb[], int maxtype,
int libbpf_nla_dump_errormsg(struct nlmsghdr *nlh);
+
+/* Helpers for preparing/consuming attributes */
+
+#define NLA_DATA(nla) ((struct nlattr *)((char *)(nla) + NLA_HDRLEN))
+
+static inline int add_nlattr(struct nlmsghdr *nh, size_t maxsz, int type,
+ const void *data, int len)
+{
+ struct nlattr *nla;
+
+ if (NLMSG_ALIGN(nh->nlmsg_len) + NLA_ALIGN(NLA_HDRLEN + len) > maxsz)
+ return -EMSGSIZE;
+ if ((!data && len) || (data && !len))
+ return -EINVAL;
+
+ nla = (struct nlattr *)((char *)nh + NLMSG_ALIGN(nh->nlmsg_len));
+ nla->nla_type = type;
+ nla->nla_len = NLA_HDRLEN + len;
+ if (data)
+ memcpy((char *)nla + NLA_HDRLEN, data, len);
+ nh->nlmsg_len = NLMSG_ALIGN(nh->nlmsg_len) + NLA_ALIGN(nla->nla_len);
+ return 0;
+}
+
+static inline struct nlattr *begin_nlattr_nested(struct nlmsghdr *nh, size_t maxsz,
+ int type)
+{
+ struct nlattr *tail;
+
+ tail = (struct nlattr *)((char *)nh + NLMSG_ALIGN(nh->nlmsg_len));
+ if (add_nlattr(nh, maxsz, type | NLA_F_NESTED, NULL, 0))
+ return NULL;
+ return tail;
+}
+
+static inline void end_nlattr_nested(struct nlmsghdr *nh, struct nlattr *tail)
+{
+ tail->nla_len = ((char *)nh + NLMSG_ALIGN(nh->nlmsg_len)) - (char *)(tail);
+}
+
#endif /* __LIBBPF_NLATTR_H */