Message ID | 20210323131332.2461409-1-arseny.krasnov@kaspersky.com |
---|---|
State | New |
Headers | show |
Series | virtio/vsock: introduce SOCK_SEQPACKET support | expand |
On Tue, Mar 23, 2021 at 04:13:29PM +0300, Arseny Krasnov wrote: >This adds rest of logic for SEQPACKET: >1) SEQPACKET specific functions which send SEQ_BEGIN/SEQ_END. > Note that both functions may sleep to wait enough space for > SEQPACKET header. >2) SEQ_BEGIN/SEQ_END in TAP packet capture. >3) Send SHUTDOWN on socket close for SEQPACKET type. >4) Set SEQPACKET packet type during send. >5) Set MSG_EOR in flags for SEQPACKET during send. >6) 'seqpacket_allow' flag to virtio transport. > >Signed-off-by: Arseny Krasnov <arseny.krasnov@kaspersky.com> >--- > v6 -> v7: > In 'virtio_transport_seqpacket_enqueue()', 'next_tx_msg_id' is updated > in both cases when message send successfully or error occured. > > include/linux/virtio_vsock.h | 7 ++ > net/vmw_vsock/virtio_transport_common.c | 88 ++++++++++++++++++++++++- > 2 files changed, 93 insertions(+), 2 deletions(-) > >diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h >index 0e3aa395c07c..ab5f56fd7251 100644 >--- a/include/linux/virtio_vsock.h >+++ b/include/linux/virtio_vsock.h >@@ -22,6 +22,7 @@ struct virtio_vsock_seq_state { > u32 user_read_seq_len; > u32 user_read_copied; > u32 curr_rx_msg_id; >+ u32 next_tx_msg_id; > }; > > /* Per-socket state (accessed via vsk->trans) */ >@@ -76,6 +77,8 @@ struct virtio_transport { > > /* Takes ownership of the packet */ > int (*send_pkt)(struct virtio_vsock_pkt *pkt); >+ >+ bool seqpacket_allow; > }; > > ssize_t >@@ -89,6 +92,10 @@ virtio_transport_dgram_dequeue(struct vsock_sock *vsk, > size_t len, int flags); > > int >+virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk, >+ struct msghdr *msg, >+ size_t len); >+int > virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk, > struct msghdr *msg, > int flags, >diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c >index bfe0d7026bf8..01a56c7da8bd 100644 >--- a/net/vmw_vsock/virtio_transport_common.c >+++ b/net/vmw_vsock/virtio_transport_common.c >@@ -139,6 +139,8 @@ static struct sk_buff *virtio_transport_build_skb(void *opaque) > break; > case VIRTIO_VSOCK_OP_CREDIT_UPDATE: > case VIRTIO_VSOCK_OP_CREDIT_REQUEST: >+ case VIRTIO_VSOCK_OP_SEQ_BEGIN: >+ case VIRTIO_VSOCK_OP_SEQ_END: > hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL); > break; > default: >@@ -187,7 +189,12 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, > struct virtio_vsock_pkt *pkt; > u32 pkt_len = info->pkt_len; > >- info->type = VIRTIO_VSOCK_TYPE_STREAM; >+ info->type = virtio_transport_get_type(sk_vsock(vsk)); >+ >+ if (info->type == VIRTIO_VSOCK_TYPE_SEQPACKET && >+ info->msg && >+ info->msg->msg_flags & MSG_EOR) >+ info->flags |= VIRTIO_VSOCK_RW_EOR; > > t_ops = virtio_transport_get_ops(vsk); > if (unlikely(!t_ops)) >@@ -401,6 +408,43 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, > return err; > } > >+static int virtio_transport_seqpacket_send_ctrl(struct vsock_sock *vsk, >+ int type, >+ size_t len, >+ int flags) >+{ >+ struct virtio_vsock_sock *vvs = vsk->trans; >+ struct virtio_vsock_pkt_info info = { >+ .op = type, >+ .vsk = vsk, >+ .pkt_len = sizeof(struct virtio_vsock_seq_hdr) >+ }; >+ >+ struct virtio_vsock_seq_hdr seq_hdr = { >+ .msg_id = cpu_to_le32(vvs->seq_state.next_tx_msg_id), >+ .msg_len = cpu_to_le32(len) >+ }; >+ >+ struct kvec seq_hdr_kiov = { >+ .iov_base = (void *)&seq_hdr, >+ .iov_len = sizeof(struct virtio_vsock_seq_hdr) >+ }; >+ >+ struct msghdr msg = {0}; >+ >+ //XXX: do we need 'vsock_transport_send_notify_data' pointer? >+ if (vsock_wait_space(sk_vsock(vsk), >+ sizeof(struct virtio_vsock_seq_hdr), >+ flags, NULL)) >+ return -1; >+ >+ iov_iter_kvec(&msg.msg_iter, WRITE, &seq_hdr_kiov, 1, sizeof(seq_hdr)); >+ >+ info.msg = &msg; >+ >+ return virtio_transport_send_pkt_info(vsk, &info); >+} >+ > static inline void virtio_transport_remove_pkt(struct virtio_vsock_pkt *pkt) > { > list_del(&pkt->list); >@@ -595,6 +639,46 @@ virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk, > } > EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue); > >+int >+virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk, >+ struct msghdr *msg, >+ size_t len) >+{ >+ int written = -1; >+ >+ if (msg->msg_iter.iov_offset == 0) { >+ /* Send SEQBEGIN. */ >+ if (virtio_transport_seqpacket_send_ctrl(vsk, >+ VIRTIO_VSOCK_OP_SEQ_BEGIN, >+ len, >+ msg->msg_flags) < 0) >+ goto out; >+ } >+ >+ written = virtio_transport_stream_enqueue(vsk, msg, len); >+ >+ if (written < 0) >+ goto out; >+ >+ if (msg->msg_iter.count == 0) { >+ /* Send SEQEND. */ >+ virtio_transport_seqpacket_send_ctrl(vsk, >+ VIRTIO_VSOCK_OP_SEQ_END, >+ 0, >+ msg->msg_flags); What happen if this fail? In the previous version we returned -1, now we return the bytes transmitted, is that right? The rest LGTM. >+ } >+out: >+ /* Update next id on error or message transmission done. */ >+ if (written < 0 || msg->msg_iter.count == 0) { >+ struct virtio_vsock_sock *vvs = vsk->trans; >+ >+ vvs->seq_state.next_tx_msg_id++; >+ } >+ >+ return written; >+} >+EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue); >+ > int > virtio_transport_dgram_dequeue(struct vsock_sock *vsk, > struct msghdr *msg, >@@ -1014,7 +1098,7 @@ void virtio_transport_release(struct vsock_sock *vsk) > struct sock *sk = &vsk->sk; > bool remove_sock = true; > >- if (sk->sk_type == SOCK_STREAM) >+ if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) > remove_sock = virtio_transport_close(vsk); > > list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) { >-- >2.25.1 >
On 25.03.2021 13:18, Stefano Garzarella wrote: > On Tue, Mar 23, 2021 at 04:13:29PM +0300, Arseny Krasnov wrote: >> This adds rest of logic for SEQPACKET: >> 1) SEQPACKET specific functions which send SEQ_BEGIN/SEQ_END. >> Note that both functions may sleep to wait enough space for >> SEQPACKET header. >> 2) SEQ_BEGIN/SEQ_END in TAP packet capture. >> 3) Send SHUTDOWN on socket close for SEQPACKET type. >> 4) Set SEQPACKET packet type during send. >> 5) Set MSG_EOR in flags for SEQPACKET during send. >> 6) 'seqpacket_allow' flag to virtio transport. >> >> Signed-off-by: Arseny Krasnov <arseny.krasnov@kaspersky.com> >> --- >> v6 -> v7: >> In 'virtio_transport_seqpacket_enqueue()', 'next_tx_msg_id' is updated >> in both cases when message send successfully or error occured. >> >> include/linux/virtio_vsock.h | 7 ++ >> net/vmw_vsock/virtio_transport_common.c | 88 ++++++++++++++++++++++++- >> 2 files changed, 93 insertions(+), 2 deletions(-) >> >> diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h >> index 0e3aa395c07c..ab5f56fd7251 100644 >> --- a/include/linux/virtio_vsock.h >> +++ b/include/linux/virtio_vsock.h >> @@ -22,6 +22,7 @@ struct virtio_vsock_seq_state { >> u32 user_read_seq_len; >> u32 user_read_copied; >> u32 curr_rx_msg_id; >> + u32 next_tx_msg_id; >> }; >> >> /* Per-socket state (accessed via vsk->trans) */ >> @@ -76,6 +77,8 @@ struct virtio_transport { >> >> /* Takes ownership of the packet */ >> int (*send_pkt)(struct virtio_vsock_pkt *pkt); >> + >> + bool seqpacket_allow; >> }; >> >> ssize_t >> @@ -89,6 +92,10 @@ virtio_transport_dgram_dequeue(struct vsock_sock *vsk, >> size_t len, int flags); >> >> int >> +virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk, >> + struct msghdr *msg, >> + size_t len); >> +int >> virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk, >> struct msghdr *msg, >> int flags, >> diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c >> index bfe0d7026bf8..01a56c7da8bd 100644 >> --- a/net/vmw_vsock/virtio_transport_common.c >> +++ b/net/vmw_vsock/virtio_transport_common.c >> @@ -139,6 +139,8 @@ static struct sk_buff *virtio_transport_build_skb(void *opaque) >> break; >> case VIRTIO_VSOCK_OP_CREDIT_UPDATE: >> case VIRTIO_VSOCK_OP_CREDIT_REQUEST: >> + case VIRTIO_VSOCK_OP_SEQ_BEGIN: >> + case VIRTIO_VSOCK_OP_SEQ_END: >> hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL); >> break; >> default: >> @@ -187,7 +189,12 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, >> struct virtio_vsock_pkt *pkt; >> u32 pkt_len = info->pkt_len; >> >> - info->type = VIRTIO_VSOCK_TYPE_STREAM; >> + info->type = virtio_transport_get_type(sk_vsock(vsk)); >> + >> + if (info->type == VIRTIO_VSOCK_TYPE_SEQPACKET && >> + info->msg && >> + info->msg->msg_flags & MSG_EOR) >> + info->flags |= VIRTIO_VSOCK_RW_EOR; >> >> t_ops = virtio_transport_get_ops(vsk); >> if (unlikely(!t_ops)) >> @@ -401,6 +408,43 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, >> return err; >> } >> >> +static int virtio_transport_seqpacket_send_ctrl(struct vsock_sock *vsk, >> + int type, >> + size_t len, >> + int flags) >> +{ >> + struct virtio_vsock_sock *vvs = vsk->trans; >> + struct virtio_vsock_pkt_info info = { >> + .op = type, >> + .vsk = vsk, >> + .pkt_len = sizeof(struct virtio_vsock_seq_hdr) >> + }; >> + >> + struct virtio_vsock_seq_hdr seq_hdr = { >> + .msg_id = cpu_to_le32(vvs->seq_state.next_tx_msg_id), >> + .msg_len = cpu_to_le32(len) >> + }; >> + >> + struct kvec seq_hdr_kiov = { >> + .iov_base = (void *)&seq_hdr, >> + .iov_len = sizeof(struct virtio_vsock_seq_hdr) >> + }; >> + >> + struct msghdr msg = {0}; >> + >> + //XXX: do we need 'vsock_transport_send_notify_data' pointer? >> + if (vsock_wait_space(sk_vsock(vsk), >> + sizeof(struct virtio_vsock_seq_hdr), >> + flags, NULL)) >> + return -1; >> + >> + iov_iter_kvec(&msg.msg_iter, WRITE, &seq_hdr_kiov, 1, sizeof(seq_hdr)); >> + >> + info.msg = &msg; >> + >> + return virtio_transport_send_pkt_info(vsk, &info); >> +} >> + >> static inline void virtio_transport_remove_pkt(struct virtio_vsock_pkt *pkt) >> { >> list_del(&pkt->list); >> @@ -595,6 +639,46 @@ virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk, >> } >> EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue); >> >> +int >> +virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk, >> + struct msghdr *msg, >> + size_t len) >> +{ >> + int written = -1; >> + >> + if (msg->msg_iter.iov_offset == 0) { >> + /* Send SEQBEGIN. */ >> + if (virtio_transport_seqpacket_send_ctrl(vsk, >> + VIRTIO_VSOCK_OP_SEQ_BEGIN, >> + len, >> + msg->msg_flags) < 0) >> + goto out; >> + } >> + >> + written = virtio_transport_stream_enqueue(vsk, msg, len); >> + >> + if (written < 0) >> + goto out; >> + >> + if (msg->msg_iter.count == 0) { >> + /* Send SEQEND. */ >> + virtio_transport_seqpacket_send_ctrl(vsk, >> + VIRTIO_VSOCK_OP_SEQ_END, >> + 0, >> + msg->msg_flags); > What happen if this fail? > > In the previous version we returned -1, now we return the bytes > transmitted, is that right? Ack, i'll fix it > > The rest LGTM. > >> + } >> +out: >> + /* Update next id on error or message transmission done. */ >> + if (written < 0 || msg->msg_iter.count == 0) { >> + struct virtio_vsock_sock *vvs = vsk->trans; >> + >> + vvs->seq_state.next_tx_msg_id++; >> + } >> + >> + return written; >> +} >> +EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue); >> + >> int >> virtio_transport_dgram_dequeue(struct vsock_sock *vsk, >> struct msghdr *msg, >> @@ -1014,7 +1098,7 @@ void virtio_transport_release(struct vsock_sock *vsk) >> struct sock *sk = &vsk->sk; >> bool remove_sock = true; >> >> - if (sk->sk_type == SOCK_STREAM) >> + if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) >> remove_sock = virtio_transport_close(vsk); >> >> list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) { >> -- >> 2.25.1 >> >
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h index 0e3aa395c07c..ab5f56fd7251 100644 --- a/include/linux/virtio_vsock.h +++ b/include/linux/virtio_vsock.h @@ -22,6 +22,7 @@ struct virtio_vsock_seq_state { u32 user_read_seq_len; u32 user_read_copied; u32 curr_rx_msg_id; + u32 next_tx_msg_id; }; /* Per-socket state (accessed via vsk->trans) */ @@ -76,6 +77,8 @@ struct virtio_transport { /* Takes ownership of the packet */ int (*send_pkt)(struct virtio_vsock_pkt *pkt); + + bool seqpacket_allow; }; ssize_t @@ -89,6 +92,10 @@ virtio_transport_dgram_dequeue(struct vsock_sock *vsk, size_t len, int flags); int +virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len); +int virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk, struct msghdr *msg, int flags, diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index bfe0d7026bf8..01a56c7da8bd 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -139,6 +139,8 @@ static struct sk_buff *virtio_transport_build_skb(void *opaque) break; case VIRTIO_VSOCK_OP_CREDIT_UPDATE: case VIRTIO_VSOCK_OP_CREDIT_REQUEST: + case VIRTIO_VSOCK_OP_SEQ_BEGIN: + case VIRTIO_VSOCK_OP_SEQ_END: hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL); break; default: @@ -187,7 +189,12 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, struct virtio_vsock_pkt *pkt; u32 pkt_len = info->pkt_len; - info->type = VIRTIO_VSOCK_TYPE_STREAM; + info->type = virtio_transport_get_type(sk_vsock(vsk)); + + if (info->type == VIRTIO_VSOCK_TYPE_SEQPACKET && + info->msg && + info->msg->msg_flags & MSG_EOR) + info->flags |= VIRTIO_VSOCK_RW_EOR; t_ops = virtio_transport_get_ops(vsk); if (unlikely(!t_ops)) @@ -401,6 +408,43 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, return err; } +static int virtio_transport_seqpacket_send_ctrl(struct vsock_sock *vsk, + int type, + size_t len, + int flags) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + struct virtio_vsock_pkt_info info = { + .op = type, + .vsk = vsk, + .pkt_len = sizeof(struct virtio_vsock_seq_hdr) + }; + + struct virtio_vsock_seq_hdr seq_hdr = { + .msg_id = cpu_to_le32(vvs->seq_state.next_tx_msg_id), + .msg_len = cpu_to_le32(len) + }; + + struct kvec seq_hdr_kiov = { + .iov_base = (void *)&seq_hdr, + .iov_len = sizeof(struct virtio_vsock_seq_hdr) + }; + + struct msghdr msg = {0}; + + //XXX: do we need 'vsock_transport_send_notify_data' pointer? + if (vsock_wait_space(sk_vsock(vsk), + sizeof(struct virtio_vsock_seq_hdr), + flags, NULL)) + return -1; + + iov_iter_kvec(&msg.msg_iter, WRITE, &seq_hdr_kiov, 1, sizeof(seq_hdr)); + + info.msg = &msg; + + return virtio_transport_send_pkt_info(vsk, &info); +} + static inline void virtio_transport_remove_pkt(struct virtio_vsock_pkt *pkt) { list_del(&pkt->list); @@ -595,6 +639,46 @@ virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk, } EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue); +int +virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + int written = -1; + + if (msg->msg_iter.iov_offset == 0) { + /* Send SEQBEGIN. */ + if (virtio_transport_seqpacket_send_ctrl(vsk, + VIRTIO_VSOCK_OP_SEQ_BEGIN, + len, + msg->msg_flags) < 0) + goto out; + } + + written = virtio_transport_stream_enqueue(vsk, msg, len); + + if (written < 0) + goto out; + + if (msg->msg_iter.count == 0) { + /* Send SEQEND. */ + virtio_transport_seqpacket_send_ctrl(vsk, + VIRTIO_VSOCK_OP_SEQ_END, + 0, + msg->msg_flags); + } +out: + /* Update next id on error or message transmission done. */ + if (written < 0 || msg->msg_iter.count == 0) { + struct virtio_vsock_sock *vvs = vsk->trans; + + vvs->seq_state.next_tx_msg_id++; + } + + return written; +} +EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue); + int virtio_transport_dgram_dequeue(struct vsock_sock *vsk, struct msghdr *msg, @@ -1014,7 +1098,7 @@ void virtio_transport_release(struct vsock_sock *vsk) struct sock *sk = &vsk->sk; bool remove_sock = true; - if (sk->sk_type == SOCK_STREAM) + if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) remove_sock = virtio_transport_close(vsk); list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
This adds rest of logic for SEQPACKET: 1) SEQPACKET specific functions which send SEQ_BEGIN/SEQ_END. Note that both functions may sleep to wait enough space for SEQPACKET header. 2) SEQ_BEGIN/SEQ_END in TAP packet capture. 3) Send SHUTDOWN on socket close for SEQPACKET type. 4) Set SEQPACKET packet type during send. 5) Set MSG_EOR in flags for SEQPACKET during send. 6) 'seqpacket_allow' flag to virtio transport. Signed-off-by: Arseny Krasnov <arseny.krasnov@kaspersky.com> --- v6 -> v7: In 'virtio_transport_seqpacket_enqueue()', 'next_tx_msg_id' is updated in both cases when message send successfully or error occured. include/linux/virtio_vsock.h | 7 ++ net/vmw_vsock/virtio_transport_common.c | 88 ++++++++++++++++++++++++- 2 files changed, 93 insertions(+), 2 deletions(-)