diff mbox series

[net-next,af_unix,v1,1/1] af_unix: Add OOB support

Message ID 20210801075707.176201-2-Rao.Shoaib@oracle.com
State New
Headers show
Series af_unix: Add OOB support | expand

Commit Message

Rao Shoaib Aug. 1, 2021, 7:57 a.m. UTC
From: Rao Shoaib <rao.shoaib@oracle.com>

This patch adds OOB support for AF_UNIX sockets.
The semantics is same as TCP.

The last byte of a message with the OOB flag is
treated as the OOB byte. The byte is separated into
a skb and a pointer to the skb is stored in unix_sock.
The pointer is used to enforce OOB semantics.

Signed-off-by: Rao Shoaib <rao.shoaib@oracle.com>
---
 include/net/af_unix.h                         |   3 +
 net/unix/Kconfig                              |   5 +
 net/unix/af_unix.c                            | 153 +++++-
 tools/testing/selftests/Makefile              |   1 +
 tools/testing/selftests/net/af_unix/Makefile  |   5 +
 .../selftests/net/af_unix/test_unix_oob.c     | 437 ++++++++++++++++++
 6 files changed, 602 insertions(+), 2 deletions(-)
 create mode 100644 tools/testing/selftests/net/af_unix/Makefile
 create mode 100644 tools/testing/selftests/net/af_unix/test_unix_oob.c
diff mbox series

Patch

diff --git a/include/net/af_unix.h b/include/net/af_unix.h
index f42fdddecd41..17965c8d6189 100644
--- a/include/net/af_unix.h
+++ b/include/net/af_unix.h
@@ -70,6 +70,9 @@  struct unix_sock {
 	struct socket_wq	peer_wq;
 	wait_queue_entry_t	peer_wake;
 	struct scm_stat		scm_stat;
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+	struct sk_buff		*oob_skb;
+#endif
 };
 
 static inline struct unix_sock *unix_sk(const struct sock *sk)
diff --git a/net/unix/Kconfig b/net/unix/Kconfig
index b6c4282899ec..b7f811216820 100644
--- a/net/unix/Kconfig
+++ b/net/unix/Kconfig
@@ -25,6 +25,11 @@  config UNIX_SCM
 	depends on UNIX
 	default y
 
+config	AF_UNIX_OOB
+	bool
+	depends on UNIX
+	default y
+
 config UNIX_DIAG
 	tristate "UNIX: socket monitoring interface"
 	depends on UNIX
diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c
index 23c92ad15c61..9435a17487bb 100644
--- a/net/unix/af_unix.c
+++ b/net/unix/af_unix.c
@@ -502,6 +502,12 @@  static void unix_sock_destructor(struct sock *sk)
 
 	skb_queue_purge(&sk->sk_receive_queue);
 
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+	if (u->oob_skb) {
+		kfree_skb(u->oob_skb);
+		u->oob_skb = NULL;
+	}
+#endif
 	WARN_ON(refcount_read(&sk->sk_wmem_alloc));
 	WARN_ON(!sk_unhashed(sk));
 	WARN_ON(sk->sk_socket);
@@ -1825,6 +1831,46 @@  static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
  */
 #define UNIX_SKB_FRAGS_SZ (PAGE_SIZE << get_order(32768))
 
+#if (IS_ENABLED(CONFIG_AF_UNIX_OOB))
+static int queue_oob(struct socket *sock, struct msghdr *msg, struct sock *other)
+{
+	struct unix_sock *ousk = unix_sk(other);
+	struct sk_buff *skb;
+	int err = 0;
+
+	skb = sock_alloc_send_skb(sock->sk, 1, msg->msg_flags & MSG_DONTWAIT, &err);
+
+	if (!skb)
+		return err;
+
+	skb_put(skb, 1);
+	skb->len = 1;
+	err = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, 1);
+
+	if (err) {
+		kfree_skb(skb);
+		return err;
+	}
+
+	unix_state_lock(other);
+	maybe_add_creds(skb, sock, other);
+	skb_get(skb);
+
+	if (ousk->oob_skb)
+		kfree_skb(ousk->oob_skb);
+
+	ousk->oob_skb = skb;
+
+	scm_stat_add(other, skb);
+	skb_queue_tail(&other->sk_receive_queue, skb);
+	sk_send_sigurg(other);
+	unix_state_unlock(other);
+	other->sk_data_ready(other);
+
+	return err;
+}
+#endif
+
 static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 			       size_t len)
 {
@@ -1843,8 +1889,14 @@  static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 		return err;
 
 	err = -EOPNOTSUPP;
-	if (msg->msg_flags&MSG_OOB)
-		goto out_err;
+	if (msg->msg_flags & MSG_OOB) {
+#if (IS_ENABLED(CONFIG_AF_UNIX_OOB))
+		if (len)
+			len--;
+		else
+#endif
+			goto out_err;
+	}
 
 	if (msg->msg_namelen) {
 		err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP;
@@ -1909,6 +1961,15 @@  static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 		sent += size;
 	}
 
+#if (IS_ENABLED(CONFIG_AF_UNIX_OOB))
+	if (msg->msg_flags & MSG_OOB) {
+		err = queue_oob(sock, msg, other);
+		if (err)
+			goto out_err;
+		sent++;
+	}
+#endif
+
 	scm_destroy(&scm);
 
 	return sent;
@@ -2247,6 +2308,59 @@  struct unix_stream_read_state {
 	unsigned int splice_flags;
 };
 
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+static int unix_stream_recv_urg(struct unix_stream_read_state *state)
+{
+	struct socket *sock = state->socket;
+	struct sock *sk = sock->sk;
+	struct unix_sock *u = unix_sk(sk);
+	int chunk = 1;
+
+	if (sock_flag(sk, SOCK_URGINLINE) || !u->oob_skb)
+		return -EINVAL;
+
+	chunk = state->recv_actor(u->oob_skb, 0, chunk, state);
+	if (chunk < 0)
+		return -EFAULT;
+
+	if (!(state->flags & MSG_PEEK)) {
+		UNIXCB(u->oob_skb).consumed += 1;
+		kfree_skb(u->oob_skb);
+		u->oob_skb = NULL;
+	}
+	state->msg->msg_flags |= MSG_OOB;
+	return 1;
+}
+
+static struct sk_buff *manage_oob(struct sk_buff *skb, struct sock *sk,
+				  int flags, int copied)
+{
+	struct unix_sock *u = unix_sk(sk);
+
+	if (!unix_skb_len(skb) && !(flags & MSG_PEEK)) {
+		skb_unlink(skb, &sk->sk_receive_queue);
+		consume_skb(skb);
+		skb = NULL;
+	} else {
+		if (skb == u->oob_skb) {
+			if (copied) {
+				skb = NULL;
+			} else if (sock_flag(sk, SOCK_URGINLINE)) {
+				if (!(flags & MSG_PEEK)) {
+					u->oob_skb = NULL;
+					consume_skb(skb);
+				}
+			} else if (!(flags & MSG_PEEK)) {
+				skb_unlink(skb, &sk->sk_receive_queue);
+				consume_skb(skb);
+				skb = skb_peek(&sk->sk_receive_queue);
+			}
+		}
+	}
+	return skb;
+}
+#endif
+
 static int unix_stream_read_generic(struct unix_stream_read_state *state,
 				    bool freezable)
 {
@@ -2272,6 +2386,15 @@  static int unix_stream_read_generic(struct unix_stream_read_state *state,
 
 	if (unlikely(flags & MSG_OOB)) {
 		err = -EOPNOTSUPP;
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+		mutex_lock(&u->iolock);
+		unix_state_lock(sk);
+
+		err = unix_stream_recv_urg(state);
+
+		unix_state_unlock(sk);
+		mutex_unlock(&u->iolock);
+#endif
 		goto out;
 	}
 
@@ -2300,6 +2423,18 @@  static int unix_stream_read_generic(struct unix_stream_read_state *state,
 		}
 		last = skb = skb_peek(&sk->sk_receive_queue);
 		last_len = last ? last->len : 0;
+
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+		if (skb) {
+			skb = manage_oob(skb, sk, flags, copied);
+			if (!skb) {
+				unix_state_unlock(sk);
+				if (copied)
+					break;
+				goto redo;
+			}
+		}
+#endif
 again:
 		if (skb == NULL) {
 			if (copied >= target)
@@ -2635,6 +2770,20 @@  static int unix_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
 	case SIOCUNIXFILE:
 		err = unix_open_file(sk);
 		break;
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+	case SIOCATMARK:
+		{
+			struct sk_buff *skb;
+			struct unix_sock *u = unix_sk(sk);
+			int answ = 0;
+
+			skb = skb_peek(&sk->sk_receive_queue);
+			if (skb && skb == u->oob_skb)
+				answ = 1;
+			err = put_user(answ, (int __user *)arg);
+		}
+		break;
+#endif
 	default:
 		err = -ENOIOCTLCMD;
 		break;
diff --git a/tools/testing/selftests/Makefile b/tools/testing/selftests/Makefile
index fb010a35d61a..da9e8b699e42 100644
--- a/tools/testing/selftests/Makefile
+++ b/tools/testing/selftests/Makefile
@@ -38,6 +38,7 @@  TARGETS += mount_setattr
 TARGETS += mqueue
 TARGETS += nci
 TARGETS += net
+TARGETS += net/af_unix
 TARGETS += net/forwarding
 TARGETS += net/mptcp
 TARGETS += netfilter
diff --git a/tools/testing/selftests/net/af_unix/Makefile b/tools/testing/selftests/net/af_unix/Makefile
new file mode 100644
index 000000000000..cfc7f4f97fd1
--- /dev/null
+++ b/tools/testing/selftests/net/af_unix/Makefile
@@ -0,0 +1,5 @@ 
+##TEST_GEN_FILES := test_unix_oob
+TEST_PROGS := test_unix_oob
+include ../../lib.mk
+
+all: $(TEST_PROGS)
diff --git a/tools/testing/selftests/net/af_unix/test_unix_oob.c b/tools/testing/selftests/net/af_unix/test_unix_oob.c
new file mode 100644
index 000000000000..0f3e3763f4f8
--- /dev/null
+++ b/tools/testing/selftests/net/af_unix/test_unix_oob.c
@@ -0,0 +1,437 @@ 
+// SPDX-License-Identifier: GPL-2.0-or-later
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/socket.h>
+#include <arpa/inet.h>
+#include <unistd.h>
+#include <string.h>
+#include <fcntl.h>
+#include <sys/ioctl.h>
+#include <errno.h>
+#include <netinet/tcp.h>
+#include <sys/un.h>
+#include <sys/signal.h>
+#include <sys/poll.h>
+
+static int pipefd[2];
+static int signal_recvd;
+static pid_t producer_id;
+static char sock_name[32];
+
+static void sig_hand(int sn, siginfo_t *si, void *p)
+{
+	signal_recvd = sn;
+}
+
+static int set_sig_handler(int signal)
+{
+	struct sigaction sa;
+
+	sa.sa_sigaction = sig_hand;
+	sigemptyset(&sa.sa_mask);
+	sa.sa_flags = SA_SIGINFO | SA_RESTART;
+
+	return sigaction(signal, &sa, NULL);
+}
+
+static void set_filemode(int fd, int set)
+{
+	int flags = fcntl(fd, F_GETFL, 0);
+
+	if (set)
+		flags &= ~O_NONBLOCK;
+	else
+		flags |= O_NONBLOCK;
+	fcntl(fd, F_SETFL, flags);
+}
+
+static void signal_producer(int fd)
+{
+	char cmd;
+
+	cmd = 'S';
+	write(fd, &cmd, sizeof(cmd));
+}
+
+static void wait_for_signal(int fd)
+{
+	char buf[5];
+
+	read(fd, buf, 5);
+}
+
+static void die(int status)
+{
+	fflush(NULL);
+	unlink(sock_name);
+	kill(producer_id, SIGTERM);
+	exit(status);
+}
+
+int is_sioctatmark(int fd)
+{
+	int ans = -1;
+
+	if (ioctl(fd, SIOCATMARK, &ans, sizeof(ans)) < 0) {
+#ifdef DEBUG
+		perror("SIOCATMARK Failed");
+#endif
+	}
+	return ans;
+}
+
+void read_oob(int fd, char *c)
+{
+
+	*c = ' ';
+	if (recv(fd, c, sizeof(*c), MSG_OOB) < 0) {
+#ifdef DEBUG
+		perror("Reading MSG_OOB Failed");
+#endif
+	}
+}
+
+int read_data(int pfd, char *buf, int size)
+{
+	int len = 0;
+
+	memset(buf, size, '0');
+	len = read(pfd, buf, size);
+#ifdef DEBUG
+	if (len < 0)
+		perror("read failed");
+#endif
+	return len;
+}
+
+static void wait_for_data(int pfd, int event)
+{
+	struct pollfd pfds[1];
+
+	pfds[0].fd = pfd;
+	pfds[0].events = event;
+	poll(pfds, 1, -1);
+}
+
+void producer(struct sockaddr_un *consumer_addr)
+{
+	int cfd;
+	char buf[64];
+	int i;
+
+	memset(buf, 'x', sizeof(buf));
+	cfd = socket(AF_UNIX, SOCK_STREAM, 0);
+
+	wait_for_signal(pipefd[0]);
+	if (connect(cfd, (struct sockaddr *)consumer_addr,
+		     sizeof(struct sockaddr)) != 0) {
+		perror("Connect failed");
+		kill(0, SIGTERM);
+		exit(1);
+	}
+
+	for (i = 0; i < 2; i++) {
+		/* Test 1: Test for SIGURG and OOB */
+		wait_for_signal(pipefd[0]);
+		memset(buf, 'x', sizeof(buf));
+		buf[63] = '@';
+		send(cfd, buf, sizeof(buf), MSG_OOB);
+
+		wait_for_signal(pipefd[0]);
+
+		/* Test 2: Test for OOB being overwitten */
+		memset(buf, 'x', sizeof(buf));
+		buf[63] = '%';
+		send(cfd, buf, sizeof(buf), MSG_OOB);
+
+		memset(buf, 'x', sizeof(buf));
+		buf[63] = '#';
+		send(cfd, buf, sizeof(buf), MSG_OOB);
+
+		wait_for_signal(pipefd[0]);
+
+		/* Test 3: Test for SIOCATMARK */
+		memset(buf, 'x', sizeof(buf));
+		buf[63] = '@';
+		send(cfd, buf, sizeof(buf), MSG_OOB);
+
+		memset(buf, 'x', sizeof(buf));
+		buf[63] = '%';
+		send(cfd, buf, sizeof(buf), MSG_OOB);
+
+		memset(buf, 'x', sizeof(buf));
+		send(cfd, buf, sizeof(buf), 0);
+
+		wait_for_signal(pipefd[0]);
+
+		/* Test 4: Test for 1byte OOB msg */
+		memset(buf, 'x', sizeof(buf));
+		buf[0] = '@';
+		send(cfd, buf, 1, MSG_OOB);
+	}
+}
+
+int
+main(int argc, char **argv)
+{
+	int lfd, pfd;
+	struct sockaddr_un consumer_addr, paddr;
+	socklen_t len = sizeof(consumer_addr);
+	char buf[1024];
+	int on = 0;
+	char oob;
+	int flags;
+	int atmark;
+	char *tmp_file;
+
+	lfd = socket(AF_UNIX, SOCK_STREAM, 0);
+	memset(&consumer_addr, 0, sizeof(consumer_addr));
+	consumer_addr.sun_family = AF_UNIX;
+	sprintf(sock_name, "unix_oob_%d", getpid());
+	unlink(sock_name);
+	strcpy(consumer_addr.sun_path, sock_name);
+
+	if ((bind(lfd, (struct sockaddr *)&consumer_addr,
+		  sizeof(consumer_addr))) != 0) {
+		perror("socket bind failed");
+		exit(1);
+	}
+
+	pipe(pipefd);
+
+	listen(lfd, 1);
+
+	producer_id = fork();
+	if (producer_id == 0) {
+		producer(&consumer_addr);
+		exit(0);
+	}
+
+	set_sig_handler(SIGURG);
+	signal_producer(pipefd[1]);
+
+	pfd = accept(lfd, (struct sockaddr *) &paddr, &len);
+	fcntl(pfd, F_SETOWN, getpid());
+
+	signal_recvd = 0;
+	signal_producer(pipefd[1]);
+
+	/* Test 1:
+	 * veriyf that SIGURG is
+	 * delivered and 63 bytes are
+	 * read and oob is '@'
+	 */
+	wait_for_data(pfd, POLLIN | POLLPRI);
+	read_oob(pfd, &oob);
+	len = read_data(pfd, buf, 1024);
+	if (!signal_recvd || len != 63 || oob != '@') {
+		fprintf(stderr, "Test 1 failed sigurg %d len %d %c\n",
+			 signal_recvd, len, oob);
+			die(1);
+	}
+
+	signal_recvd = 0;
+	signal_producer(pipefd[1]);
+
+	/* Test 2:
+	 * Verify that the first OOB is over written by
+	 * the 2nd one and the first OOB is returned as
+	 * part of the read, and sigurg is received.
+	 */
+	wait_for_data(pfd, POLLIN | POLLPRI);
+	len = 0;
+	while (len < 70)
+		len = recv(pfd, buf, 1024, MSG_PEEK);
+	len = read_data(pfd, buf, 1024);
+	read_oob(pfd, &oob);
+	if (!signal_recvd || len != 127 || oob != '#') {
+		fprintf(stderr, "Test 2 failed, sigurg %d len %d OOB %c\n",
+		signal_recvd, len, oob);
+		die(1);
+	}
+
+	signal_recvd = 0;
+	signal_producer(pipefd[1]);
+
+	/* Test 3:
+	 * verify that 2nd oob over writes
+	 * the first one and read breaks at
+	 * oob boundary returning 127 bytes
+	 * and sigurg is received and atmark
+	 * is set.
+	 * oob is '%' and second read returns
+	 * 64 bytes.
+	 */
+	len = 0;
+	wait_for_data(pfd, POLLIN | POLLPRI);
+	while (len < 150)
+		len = recv(pfd, buf, 1024, MSG_PEEK);
+	len = read_data(pfd, buf, 1024);
+	atmark = is_sioctatmark(pfd);
+	read_oob(pfd, &oob);
+
+	if (!signal_recvd || len != 127 || oob != '%' || atmark != 1) {
+		fprintf(stderr, "Test 3 failed, sigurg %d len %d OOB %c ",
+		"atmark %d\n", signal_recvd, len, oob, atmark);
+		die(1);
+	}
+
+	signal_recvd = 0;
+
+	len = read_data(pfd, buf, 1024);
+	if (len != 64) {
+		fprintf(stderr, "Test 3.1 failed, sigurg %d len %d OOB %c\n",
+			signal_recvd, len, oob);
+		die(1);
+	}
+
+	signal_recvd = 0;
+	signal_producer(pipefd[1]);
+
+	/* Test 4:
+	 * verify that a single byte
+	 * oob message is delivered.
+	 * set non blocking mode and
+	 * check proper error is
+	 * returned and sigurg is
+	 * received and correct
+	 * oob is read.
+	 */
+
+	set_filemode(pfd, 0);
+
+	wait_for_data(pfd, POLLIN | POLLPRI);
+	len = read_data(pfd, buf, 1024);
+	if ((len == -1) && (errno == 11))
+		len = 0;
+
+	read_oob(pfd, &oob);
+
+	if (!signal_recvd || len != 0 || oob != '@') {
+		fprintf(stderr, "Test 4 failed, sigurg %d len %d OOB %c\n",
+			 signal_recvd, len, oob);
+		die(1);
+	}
+
+	set_filemode(pfd, 1);
+
+	/* Inline Testing */
+
+	on = 1;
+	if (setsockopt(pfd, SOL_SOCKET, SO_OOBINLINE, &on, sizeof(on))) {
+		perror("SO_OOBINLINE");
+		die(1);
+	}
+
+	signal_recvd = 0;
+	signal_producer(pipefd[1]);
+
+	/* Test 1 -- Inline:
+	 * Check that SIGURG is
+	 * delivered and 63 bytes are
+	 * read and oob is '@'
+	 */
+
+	wait_for_data(pfd, POLLIN | POLLPRI);
+	len = read_data(pfd, buf, 1024);
+
+	if (!signal_recvd || len != 63) {
+		fprintf(stderr, "Test 1 Inline failed, sigurg %d len %d\n",
+			signal_recvd, len);
+		die(1);
+	}
+
+	len = read_data(pfd, buf, 1024);
+
+	if (len != 1) {
+		fprintf(stderr,
+			 "Test 1.1 Inline failed, sigurg %d len %d oob %c\n",
+			 signal_recvd, len, oob);
+		die(1);
+	}
+
+	signal_recvd = 0;
+	signal_producer(pipefd[1]);
+
+	/* Test 2 -- Inline:
+	 * Verify that the first OOB is over written by
+	 * the 2nd one and read breaks correctly on
+	 * 2nd OOB boundary with the first OOB returned as
+	 * part of the read, and sigurg is delivered and
+	 * siocatmark returns true.
+	 * next read returns one byte, the oob byte
+	 * and siocatmark returns false.
+	 */
+	len = 0;
+	wait_for_data(pfd, POLLIN | POLLPRI);
+	while (len < 70)
+		len = recv(pfd, buf, 1024, MSG_PEEK);
+	len = read_data(pfd, buf, 1024);
+	atmark = is_sioctatmark(pfd);
+	if (len != 127 || atmark != 1 || !signal_recvd) {
+		fprintf(stderr, "Test 2 Inline failed, len %d atmark %d\n",
+			 len, atmark);
+		die(1);
+	}
+
+	len = read_data(pfd, buf, 1024);
+	atmark = is_sioctatmark(pfd);
+	if (len != 1 || buf[0] != '#' || atmark == 1) {
+		fprintf(stderr, "Test 2.1 Inline failed, len %d data %c atmark %d\n",
+			len, buf[0], atmark);
+		die(1);
+	}
+
+	signal_recvd = 0;
+	signal_producer(pipefd[1]);
+
+	/* Test 3 -- Inline:
+	 * verify that 2nd oob over writes
+	 * the first one and read breaks at
+	 * oob boundary returning 127 bytes
+	 * and sigurg is received and siocatmark
+	 * is true after the read.
+	 * subsequent read returns 65 bytes
+	 * because of oob which should be '%'.
+	 */
+	len = 0;
+	wait_for_data(pfd, POLLIN | POLLPRI);
+	while (len < 126)
+		len = recv(pfd, buf, 1024, MSG_PEEK);
+	len = read_data(pfd, buf, 1024);
+	atmark = is_sioctatmark(pfd);
+	if (!signal_recvd || len != 127 || !atmark) {
+		fprintf(stderr,
+			 "Test 3 Inline failed, sigurg %d len %d data %c\n",
+			 signal_recvd, len, buf[0]);
+		die(1);
+	}
+
+	len = read_data(pfd, buf, 1024);
+	atmark = is_sioctatmark(pfd);
+	if (len != 65 || buf[0] != '%' || atmark != 0) {
+		fprintf(stderr,
+			 "Test 3.1 Inline failed, len %d oob %c atmark %d\n",
+			 len, buf[0], atmark);
+		die(1);
+	}
+
+	signal_recvd = 0;
+	signal_producer(pipefd[1]);
+
+	/* Test 4 -- Inline:
+	 * verify that a single
+	 * byte oob message is delivered
+	 * and read returns one byte, the oob
+	 * byte and sigurg is received
+	 */
+	wait_for_data(pfd, POLLIN | POLLPRI);
+	len = read_data(pfd, buf, 1024);
+	if (!signal_recvd || len != 1 || buf[0] != '@') {
+		fprintf(stderr,
+			"Test 4 Inline failed, signal %d len %d data %c\n",
+		signal_recvd, len, buf[0]);
+		die(1);
+	}
+	die(0);
+}