diff mbox series

[RFC,v2,1/6] af_vsock/virtio/vsock: change seqpacket receive logic

Message ID 20210704080942.89177-1-arseny.krasnov@kaspersky.com
State New
Headers show
Series Improve SOCK_SEQPACKET receive logic | expand

Commit Message

Arseny Krasnov July 4, 2021, 8:09 a.m. UTC
1) In af_vsock "loop" now is really loop: it receives
   message fragments one by one, until 'msg_ready'
   value is returned by transport.
2) In virtio transport, dequeue callback is called
   everytime when at least one fragment of message is
   received.

Signed-off-by: Arseny Krasnov <arseny.krasnov@kaspersky.com>
---
 include/linux/virtio_vsock.h            |  3 +-
 include/net/af_vsock.h                  |  2 +-
 net/vmw_vsock/af_vsock.c                | 33 +++++++++----
 net/vmw_vsock/virtio_transport_common.c | 62 +++++++++++--------------
 4 files changed, 52 insertions(+), 48 deletions(-)
diff mbox series

Patch

diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index 35d7eedb5e8e..e68b4029f038 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -88,7 +88,8 @@  virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
 ssize_t
 virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
 				   struct msghdr *msg,
-				   int flags);
+				   int flags,
+				   bool *msg_ready);
 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk);
 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk);
 u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk);
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index ab207677e0a8..c40d341611b0 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -137,7 +137,7 @@  struct vsock_transport {
 
 	/* SEQ_PACKET. */
 	ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg,
-				     int flags);
+				     int flags, bool *msg_ready);
 	int (*seqpacket_enqueue)(struct vsock_sock *vsk, struct msghdr *msg,
 				 size_t len);
 	bool (*seqpacket_allow)(u32 remote_cid);
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 3e02cc3b24f8..b66884def8e8 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -1881,7 +1881,7 @@  static int vsock_connectible_wait_data(struct sock *sk,
 	err = 0;
 	transport = vsk->transport;
 
-	while ((data = vsock_connectible_has_data(vsk)) == 0) {
+	while ((data = vsock_stream_has_data(vsk)) == 0) {
 		prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE);
 
 		if (sk->sk_err != 0 ||
@@ -2013,6 +2013,7 @@  static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
 				     size_t len, int flags)
 {
 	const struct vsock_transport *transport;
+	bool msg_ready;
 	struct vsock_sock *vsk;
 	ssize_t record_len;
 	long timeout;
@@ -2023,23 +2024,36 @@  static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
 	transport = vsk->transport;
 
 	timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+	msg_ready = false;
+	record_len = 0;
 
-	err = vsock_connectible_wait_data(sk, &wait, timeout, NULL, 0);
-	if (err <= 0)
-		goto out;
+	while (!msg_ready) {
+		ssize_t fragment_len;
+		int intr_err;
 
-	record_len = transport->seqpacket_dequeue(vsk, msg, flags);
+		intr_err = vsock_connectible_wait_data(sk, &wait, timeout, NULL, 0);
+		if (intr_err <= 0) {
+			err = intr_err;
+			break;
+		}
 
-	if (record_len < 0) {
-		err = -ENOMEM;
-		goto out;
+		fragment_len = transport->seqpacket_dequeue(vsk, msg, flags, &msg_ready);
+
+		if (fragment_len < 0) {
+			err = -ENOMEM;
+			break;
+		}
+
+		record_len += fragment_len;
 	}
 
 	if (sk->sk_err) {
 		err = -sk->sk_err;
 	} else if (sk->sk_shutdown & RCV_SHUTDOWN) {
 		err = 0;
-	} else {
+	}
+
+	if (msg_ready && !err) {
 		/* User sets MSG_TRUNC, so return real length of
 		 * packet.
 		 */
@@ -2055,7 +2069,6 @@  static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
 			msg->msg_flags |= MSG_TRUNC;
 	}
 
-out:
 	return err;
 }
 
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index 169ba8b72a63..053bcea1a03f 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -407,58 +407,48 @@  virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
 
 static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
 						 struct msghdr *msg,
-						 int flags)
+						 int flags,
+						 bool *msg_ready)
 {
 	struct virtio_vsock_sock *vvs = vsk->trans;
 	struct virtio_vsock_pkt *pkt;
 	int dequeued_len = 0;
 	size_t user_buf_len = msg_data_left(msg);
-	bool msg_ready = false;
 
+	*msg_ready = false;
 	spin_lock_bh(&vvs->rx_lock);
 
-	if (vvs->msg_count == 0) {
-		spin_unlock_bh(&vvs->rx_lock);
-		return 0;
-	}
+	while (!*msg_ready && !list_empty(&vvs->rx_queue) && dequeued_len >= 0) {
+		size_t pkt_len;
+		size_t bytes_to_copy;
 
-	while (!msg_ready) {
 		pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, list);
+		pkt_len = (size_t)le32_to_cpu(pkt->hdr.len);
 
-		if (dequeued_len >= 0) {
-			size_t pkt_len;
-			size_t bytes_to_copy;
+		bytes_to_copy = min(user_buf_len, pkt_len);
 
-			pkt_len = (size_t)le32_to_cpu(pkt->hdr.len);
-			bytes_to_copy = min(user_buf_len, pkt_len);
-
-			if (bytes_to_copy) {
-				int err;
-
-				/* sk_lock is held by caller so no one else can dequeue.
-				 * Unlock rx_lock since memcpy_to_msg() may sleep.
-				 */
-				spin_unlock_bh(&vvs->rx_lock);
+		if (bytes_to_copy) {
+			int err;
+			/* sk_lock is held by caller so no one else can dequeue.
+			 * Unlock rx_lock since memcpy_to_msg() may sleep.
+			 */
+			spin_unlock_bh(&vvs->rx_lock);
 
-				err = memcpy_to_msg(msg, pkt->buf, bytes_to_copy);
-				if (err) {
-					/* Copy of message failed. Rest of
-					 * fragments will be freed without copy.
-					 */
-					dequeued_len = err;
-				} else {
-					user_buf_len -= bytes_to_copy;
-				}
+			err = memcpy_to_msg(msg, pkt->buf, bytes_to_copy);
 
-				spin_lock_bh(&vvs->rx_lock);
-			}
+			spin_lock_bh(&vvs->rx_lock);
 
-			if (dequeued_len >= 0)
-				dequeued_len += pkt_len;
+			if (err)
+				dequeued_len = err;
+			else
+				user_buf_len -= bytes_to_copy;
 		}
 
+		if (dequeued_len >= 0)
+			dequeued_len += pkt_len;
+
 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
-			msg_ready = true;
+			*msg_ready = true;
 			vvs->msg_count--;
 		}
 
@@ -489,12 +479,12 @@  EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
 ssize_t
 virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
 				   struct msghdr *msg,
-				   int flags)
+				   int flags, bool *msg_ready)
 {
 	if (flags & MSG_PEEK)
 		return -EOPNOTSUPP;
 
-	return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
+	return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags, msg_ready);
 }
 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);