diff mbox series

[net-next,02/11] mptcp: pm: reduce entries iterations on connect

Message ID 20240902-net-next-mptcp-mib-mpjtx-misc-v1-2-d3e0f3773b90@kernel.org
State Accepted
Commit b83fbca1b4c9c45628aa55d582c14825b0e71c2b
Headers show
Series mptcp: MIB counters for MPJ TX + misc improvements | expand

Commit Message

Matthieu Baerts (NGI0) Sept. 2, 2024, 10:45 a.m. UTC
__mptcp_subflow_connect() is currently called from the path-managers,
which have all the required information to create subflows. No need to
call the PM again to re-iterate over the list of entries with RCU lock
to get more info.

Instead, it is possible to pass a mptcp_pm_addr_entry structure, instead
of a mptcp_addr_info one. The former contains the ifindex and the flags
that are required when creating the new subflow.

This is a partial revert of commit ee285257a9c1 ("mptcp: drop flags and
ifindex arguments").

While at it, the local ID can also be set if it is known and 0, to avoid
having to set it in the 'rebuild_header' hook, which will cause a new
iteration of the endpoint entries.

Reviewed-by: Mat Martineau <martineau@kernel.org>
Signed-off-by: Matthieu Baerts (NGI0) <matttbe@kernel.org>
---
 net/mptcp/pm.c           | 11 --------
 net/mptcp/pm_netlink.c   | 66 ++++++++++++++++++------------------------------
 net/mptcp/pm_userspace.c | 40 ++++++++++-------------------
 net/mptcp/protocol.h     | 16 +++++-------
 net/mptcp/subflow.c      | 29 +++++++++++++--------
 5 files changed, 62 insertions(+), 100 deletions(-)
diff mbox series

Patch

diff --git a/net/mptcp/pm.c b/net/mptcp/pm.c
index 37f6dbcd8434..620264c75dc2 100644
--- a/net/mptcp/pm.c
+++ b/net/mptcp/pm.c
@@ -430,17 +430,6 @@  bool mptcp_pm_is_backup(struct mptcp_sock *msk, struct sock_common *skc)
 	return mptcp_pm_nl_is_backup(msk, &skc_local);
 }
 
-int mptcp_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, unsigned int id,
-					 u8 *flags, int *ifindex)
-{
-	*flags = 0;
-	*ifindex = 0;
-
-	if (mptcp_pm_is_userspace(msk))
-		return mptcp_userspace_pm_get_flags_and_ifindex_by_id(msk, id, flags, ifindex);
-	return mptcp_pm_nl_get_flags_and_ifindex_by_id(msk, id, flags, ifindex);
-}
-
 int mptcp_pm_get_addr(struct sk_buff *skb, struct genl_info *info)
 {
 	if (info->attrs[MPTCP_PM_ATTR_TOKEN])
diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c
index 275959581586..62a42f7ee7cb 100644
--- a/net/mptcp/pm_netlink.c
+++ b/net/mptcp/pm_netlink.c
@@ -149,7 +149,7 @@  static bool lookup_subflow_by_daddr(const struct list_head *list,
 static bool
 select_local_address(const struct pm_nl_pernet *pernet,
 		     const struct mptcp_sock *msk,
-		     struct mptcp_pm_addr_entry *new_entry)
+		     struct mptcp_pm_local *new_local)
 {
 	struct mptcp_pm_addr_entry *entry;
 	bool found = false;
@@ -164,7 +164,9 @@  select_local_address(const struct pm_nl_pernet *pernet,
 		if (!test_bit(entry->addr.id, msk->pm.id_avail_bitmap))
 			continue;
 
-		*new_entry = *entry;
+		new_local->addr = entry->addr;
+		new_local->flags = entry->flags;
+		new_local->ifindex = entry->ifindex;
 		found = true;
 		break;
 	}
@@ -175,7 +177,7 @@  select_local_address(const struct pm_nl_pernet *pernet,
 
 static bool
 select_signal_address(struct pm_nl_pernet *pernet, const struct mptcp_sock *msk,
-		      struct mptcp_pm_addr_entry *new_entry)
+		     struct mptcp_pm_local *new_local)
 {
 	struct mptcp_pm_addr_entry *entry;
 	bool found = false;
@@ -193,7 +195,9 @@  select_signal_address(struct pm_nl_pernet *pernet, const struct mptcp_sock *msk,
 		if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
 			continue;
 
-		*new_entry = *entry;
+		new_local->addr = entry->addr;
+		new_local->flags = entry->flags;
+		new_local->ifindex = entry->ifindex;
 		found = true;
 		break;
 	}
@@ -524,11 +528,11 @@  __lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info)
 static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
 {
 	struct sock *sk = (struct sock *)msk;
-	struct mptcp_pm_addr_entry local;
 	unsigned int add_addr_signal_max;
 	bool signal_and_subflow = false;
 	unsigned int local_addr_max;
 	struct pm_nl_pernet *pernet;
+	struct mptcp_pm_local local;
 	unsigned int subflows_max;
 
 	pernet = pm_nl_get_pernet(sock_net(sk));
@@ -629,7 +633,7 @@  static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
 
 		spin_unlock_bh(&msk->pm.lock);
 		for (i = 0; i < nr; i++)
-			__mptcp_subflow_connect(sk, &local.addr, &addrs[i]);
+			__mptcp_subflow_connect(sk, &local, &addrs[i]);
 		spin_lock_bh(&msk->pm.lock);
 	}
 	mptcp_pm_nl_check_work_pending(msk);
@@ -650,7 +654,7 @@  static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
  */
 static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk,
 					     struct mptcp_addr_info *remote,
-					     struct mptcp_addr_info *addrs)
+					     struct mptcp_pm_local *locals)
 {
 	struct sock *sk = (struct sock *)msk;
 	struct mptcp_pm_addr_entry *entry;
@@ -673,13 +677,15 @@  static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk,
 			continue;
 
 		if (msk->pm.subflows < subflows_max) {
-			msk->pm.subflows++;
-			addrs[i] = entry->addr;
+			locals[i].addr = entry->addr;
+			locals[i].flags = entry->flags;
+			locals[i].ifindex = entry->ifindex;
 
 			/* Special case for ID0: set the correct ID */
-			if (mptcp_addresses_equal(&entry->addr, &mpc_addr, entry->addr.port))
-				addrs[i].id = 0;
+			if (mptcp_addresses_equal(&locals[i].addr, &mpc_addr, locals[i].addr.port))
+				locals[i].addr.id = 0;
 
+			msk->pm.subflows++;
 			i++;
 		}
 	}
@@ -689,21 +695,19 @@  static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk,
 	 * 'IPADDRANY' local address
 	 */
 	if (!i) {
-		struct mptcp_addr_info local;
-
-		memset(&local, 0, sizeof(local));
-		local.family =
+		memset(&locals[i], 0, sizeof(locals[i]));
+		locals[i].addr.family =
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
 			       remote->family == AF_INET6 &&
 			       ipv6_addr_v4mapped(&remote->addr6) ? AF_INET :
 #endif
 			       remote->family;
 
-		if (!mptcp_pm_addr_families_match(sk, &local, remote))
+		if (!mptcp_pm_addr_families_match(sk, &locals[i].addr, remote))
 			return 0;
 
 		msk->pm.subflows++;
-		addrs[i++] = local;
+		i++;
 	}
 
 	return i;
@@ -711,7 +715,7 @@  static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk,
 
 static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
 {
-	struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX];
+	struct mptcp_pm_local locals[MPTCP_PM_ADDR_MAX];
 	struct sock *sk = (struct sock *)msk;
 	unsigned int add_addr_accept_max;
 	struct mptcp_addr_info remote;
@@ -740,13 +744,13 @@  static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
 	/* connect to the specified remote address, using whatever
 	 * local address the routing configuration will pick.
 	 */
-	nr = fill_local_addresses_vec(msk, &remote, addrs);
+	nr = fill_local_addresses_vec(msk, &remote, locals);
 	if (nr == 0)
 		return;
 
 	spin_unlock_bh(&msk->pm.lock);
 	for (i = 0; i < nr; i++)
-		if (__mptcp_subflow_connect(sk, &addrs[i], &remote) == 0)
+		if (__mptcp_subflow_connect(sk, &locals[i], &remote) == 0)
 			sf_created = true;
 	spin_lock_bh(&msk->pm.lock);
 
@@ -1433,28 +1437,6 @@  int mptcp_pm_nl_add_addr_doit(struct sk_buff *skb, struct genl_info *info)
 	return ret;
 }
 
-int mptcp_pm_nl_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, unsigned int id,
-					    u8 *flags, int *ifindex)
-{
-	struct mptcp_pm_addr_entry *entry;
-	struct sock *sk = (struct sock *)msk;
-	struct net *net = sock_net(sk);
-
-	/* No entries with ID 0 */
-	if (id == 0)
-		return 0;
-
-	rcu_read_lock();
-	entry = __lookup_addr_by_id(pm_nl_get_pernet(net), id);
-	if (entry) {
-		*flags = entry->flags;
-		*ifindex = entry->ifindex;
-	}
-	rcu_read_unlock();
-
-	return 0;
-}
-
 static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
 				      const struct mptcp_addr_info *addr)
 {
diff --git a/net/mptcp/pm_userspace.c b/net/mptcp/pm_userspace.c
index 8eaa9fbe3e34..2cceded3a83a 100644
--- a/net/mptcp/pm_userspace.c
+++ b/net/mptcp/pm_userspace.c
@@ -119,23 +119,6 @@  mptcp_userspace_pm_lookup_addr_by_id(struct mptcp_sock *msk, unsigned int id)
 	return NULL;
 }
 
-int mptcp_userspace_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk,
-						   unsigned int id,
-						   u8 *flags, int *ifindex)
-{
-	struct mptcp_pm_addr_entry *match;
-
-	spin_lock_bh(&msk->pm.lock);
-	match = mptcp_userspace_pm_lookup_addr_by_id(msk, id);
-	spin_unlock_bh(&msk->pm.lock);
-	if (match) {
-		*flags = match->flags;
-		*ifindex = match->ifindex;
-	}
-
-	return 0;
-}
-
 int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk,
 				    struct mptcp_addr_info *skc)
 {
@@ -352,8 +335,9 @@  int mptcp_pm_nl_subflow_create_doit(struct sk_buff *skb, struct genl_info *info)
 	struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE];
 	struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
 	struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR];
-	struct mptcp_pm_addr_entry local = { 0 };
+	struct mptcp_pm_addr_entry entry = { 0 };
 	struct mptcp_addr_info addr_r;
+	struct mptcp_pm_local local;
 	struct mptcp_sock *msk;
 	int err = -EINVAL;
 	struct sock *sk;
@@ -379,18 +363,18 @@  int mptcp_pm_nl_subflow_create_doit(struct sk_buff *skb, struct genl_info *info)
 		goto create_err;
 	}
 
-	err = mptcp_pm_parse_entry(laddr, info, true, &local);
+	err = mptcp_pm_parse_entry(laddr, info, true, &entry);
 	if (err < 0) {
 		NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr");
 		goto create_err;
 	}
 
-	if (local.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
+	if (entry.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
 		GENL_SET_ERR_MSG(info, "invalid addr flags");
 		err = -EINVAL;
 		goto create_err;
 	}
-	local.flags |= MPTCP_PM_ADDR_FLAG_SUBFLOW;
+	entry.flags |= MPTCP_PM_ADDR_FLAG_SUBFLOW;
 
 	err = mptcp_pm_parse_addr(raddr, info, &addr_r);
 	if (err < 0) {
@@ -398,27 +382,29 @@  int mptcp_pm_nl_subflow_create_doit(struct sk_buff *skb, struct genl_info *info)
 		goto create_err;
 	}
 
-	if (!mptcp_pm_addr_families_match(sk, &local.addr, &addr_r)) {
+	if (!mptcp_pm_addr_families_match(sk, &entry.addr, &addr_r)) {
 		GENL_SET_ERR_MSG(info, "families mismatch");
 		err = -EINVAL;
 		goto create_err;
 	}
 
-	err = mptcp_userspace_pm_append_new_local_addr(msk, &local, false);
+	err = mptcp_userspace_pm_append_new_local_addr(msk, &entry, false);
 	if (err < 0) {
 		GENL_SET_ERR_MSG(info, "did not match address and id");
 		goto create_err;
 	}
 
+	local.addr = entry.addr;
+	local.flags = entry.flags;
+	local.ifindex = entry.ifindex;
+
 	lock_sock(sk);
-
-	err = __mptcp_subflow_connect(sk, &local.addr, &addr_r);
-
+	err = __mptcp_subflow_connect(sk, &local, &addr_r);
 	release_sock(sk);
 
 	spin_lock_bh(&msk->pm.lock);
 	if (err)
-		mptcp_userspace_pm_delete_local_addr(msk, &local);
+		mptcp_userspace_pm_delete_local_addr(msk, &entry);
 	else
 		msk->pm.subflows++;
 	spin_unlock_bh(&msk->pm.lock);
diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h
index 3735b20f2626..bf03bff9ac44 100644
--- a/net/mptcp/protocol.h
+++ b/net/mptcp/protocol.h
@@ -236,6 +236,12 @@  struct mptcp_pm_data {
 	struct mptcp_rm_list rm_list_rx;
 };
 
+struct mptcp_pm_local {
+	struct mptcp_addr_info	addr;
+	u8			flags;
+	int			ifindex;
+};
+
 struct mptcp_pm_addr_entry {
 	struct list_head	list;
 	struct mptcp_addr_info	addr;
@@ -719,7 +725,7 @@  bool mptcp_addresses_equal(const struct mptcp_addr_info *a,
 void mptcp_local_address(const struct sock_common *skc, struct mptcp_addr_info *addr);
 
 /* called with sk socket lock held */
-int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc,
+int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_pm_local *local,
 			    const struct mptcp_addr_info *remote);
 int mptcp_subflow_create_socket(struct sock *sk, unsigned short family,
 				struct socket **new_sock);
@@ -1014,14 +1020,6 @@  mptcp_pm_del_add_timer(struct mptcp_sock *msk,
 struct mptcp_pm_add_entry *
 mptcp_lookup_anno_list_by_saddr(const struct mptcp_sock *msk,
 				const struct mptcp_addr_info *addr);
-int mptcp_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk,
-					 unsigned int id,
-					 u8 *flags, int *ifindex);
-int mptcp_pm_nl_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, unsigned int id,
-					    u8 *flags, int *ifindex);
-int mptcp_userspace_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk,
-						   unsigned int id,
-						   u8 *flags, int *ifindex);
 int mptcp_pm_set_flags(struct sk_buff *skb, struct genl_info *info);
 int mptcp_pm_nl_set_flags(struct sk_buff *skb, struct genl_info *info);
 int mptcp_userspace_pm_set_flags(struct sk_buff *skb, struct genl_info *info);
diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c
index 064ab3235893..0796122c9467 100644
--- a/net/mptcp/subflow.c
+++ b/net/mptcp/subflow.c
@@ -1565,26 +1565,24 @@  void mptcp_info2sockaddr(const struct mptcp_addr_info *info,
 #endif
 }
 
-int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc,
+int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_pm_local *local,
 			    const struct mptcp_addr_info *remote)
 {
 	struct mptcp_sock *msk = mptcp_sk(sk);
 	struct mptcp_subflow_context *subflow;
+	int local_id = local->addr.id;
 	struct sockaddr_storage addr;
 	int remote_id = remote->id;
-	int local_id = loc->id;
 	int err = -ENOTCONN;
 	struct socket *sf;
 	struct sock *ssk;
 	u32 remote_token;
 	int addrlen;
-	int ifindex;
-	u8 flags;
 
 	if (!mptcp_is_fully_established(sk))
 		goto err_out;
 
-	err = mptcp_subflow_create_socket(sk, loc->family, &sf);
+	err = mptcp_subflow_create_socket(sk, local->addr.family, &sf);
 	if (err)
 		goto err_out;
 
@@ -1594,23 +1592,32 @@  int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc,
 		get_random_bytes(&subflow->local_nonce, sizeof(u32));
 	} while (!subflow->local_nonce);
 
-	if (local_id)
+	/* if 'IPADDRANY', the ID will be set later, after the routing */
+	if (local->addr.family == AF_INET) {
+		if (!local->addr.addr.s_addr)
+			local_id = -1;
+#if IS_ENABLED(CONFIG_MPTCP_IPV6)
+	} else if (sk->sk_family == AF_INET6) {
+		if (ipv6_addr_any(&local->addr.addr6))
+			local_id = -1;
+#endif
+	}
+
+	if (local_id >= 0)
 		subflow_set_local_id(subflow, local_id);
 
-	mptcp_pm_get_flags_and_ifindex_by_id(msk, local_id,
-					     &flags, &ifindex);
 	subflow->remote_key_valid = 1;
 	subflow->remote_key = READ_ONCE(msk->remote_key);
 	subflow->local_key = READ_ONCE(msk->local_key);
 	subflow->token = msk->token;
-	mptcp_info2sockaddr(loc, &addr, ssk->sk_family);
+	mptcp_info2sockaddr(&local->addr, &addr, ssk->sk_family);
 
 	addrlen = sizeof(struct sockaddr_in);
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
 	if (addr.ss_family == AF_INET6)
 		addrlen = sizeof(struct sockaddr_in6);
 #endif
-	ssk->sk_bound_dev_if = ifindex;
+	ssk->sk_bound_dev_if = local->ifindex;
 	err = kernel_bind(sf, (struct sockaddr *)&addr, addrlen);
 	if (err)
 		goto failed;
@@ -1621,7 +1628,7 @@  int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc,
 	subflow->remote_token = remote_token;
 	WRITE_ONCE(subflow->remote_id, remote_id);
 	subflow->request_join = 1;
-	subflow->request_bkup = !!(flags & MPTCP_PM_ADDR_FLAG_BACKUP);
+	subflow->request_bkup = !!(local->flags & MPTCP_PM_ADDR_FLAG_BACKUP);
 	subflow->subflow_id = msk->subflow_id++;
 	mptcp_info2sockaddr(remote, &addr, ssk->sk_family);