@@ -66,60 +66,6 @@ static void root_remove_peer_lists(struc
}
}
-static void walk_remove_by_peer(struct allowedips_node __rcu **top,
- struct wg_peer *peer, struct mutex *lock)
-{
-#define REF(p) rcu_access_pointer(p)
-#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
-#define PUSH(p) ({ \
- WARN_ON(IS_ENABLED(DEBUG) && len >= 128); \
- stack[len++] = p; \
- })
-
- struct allowedips_node __rcu **stack[128], **nptr;
- struct allowedips_node *node, *prev;
- unsigned int len;
-
- if (unlikely(!peer || !REF(*top)))
- return;
-
- for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
- nptr = stack[len - 1];
- node = DEREF(nptr);
- if (!node) {
- --len;
- continue;
- }
- if (!prev || REF(prev->bit[0]) == node ||
- REF(prev->bit[1]) == node) {
- if (REF(node->bit[0]))
- PUSH(&node->bit[0]);
- else if (REF(node->bit[1]))
- PUSH(&node->bit[1]);
- } else if (REF(node->bit[0]) == prev) {
- if (REF(node->bit[1]))
- PUSH(&node->bit[1]);
- } else {
- if (rcu_dereference_protected(node->peer,
- lockdep_is_held(lock)) == peer) {
- RCU_INIT_POINTER(node->peer, NULL);
- list_del_init(&node->peer_list);
- if (!node->bit[0] || !node->bit[1]) {
- rcu_assign_pointer(*nptr, DEREF(
- &node->bit[!REF(node->bit[0])]));
- kfree_rcu(node, rcu);
- node = DEREF(nptr);
- }
- }
- --len;
- }
- }
-
-#undef REF
-#undef DEREF
-#undef PUSH
-}
-
static unsigned int fls128(u64 a, u64 b)
{
return a ? fls64(a) + 64U : fls64(b);
@@ -224,6 +170,7 @@ static int add(struct allowedips_node __
RCU_INIT_POINTER(node->peer, peer);
list_add_tail(&node->peer_list, &peer->allowedips_list);
copy_and_assign_cidr(node, key, cidr, bits);
+ rcu_assign_pointer(node->parent_bit, trie);
rcu_assign_pointer(*trie, node);
return 0;
}
@@ -243,9 +190,9 @@ static int add(struct allowedips_node __
if (!node) {
down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
} else {
- down = rcu_dereference_protected(CHOOSE_NODE(node, key),
- lockdep_is_held(lock));
+ down = rcu_dereference_protected(CHOOSE_NODE(node, key), lockdep_is_held(lock));
if (!down) {
+ rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, key));
rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
return 0;
}
@@ -254,29 +201,37 @@ static int add(struct allowedips_node __
parent = node;
if (newnode->cidr == cidr) {
+ rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(newnode, down->bits));
rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
- if (!parent)
+ if (!parent) {
+ rcu_assign_pointer(newnode->parent_bit, trie);
rcu_assign_pointer(*trie, newnode);
- else
- rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits),
- newnode);
- } else {
- node = kzalloc(sizeof(*node), GFP_KERNEL);
- if (unlikely(!node)) {
- list_del(&newnode->peer_list);
- kfree(newnode);
- return -ENOMEM;
+ } else {
+ rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(parent, newnode->bits));
+ rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits), newnode);
}
- INIT_LIST_HEAD(&node->peer_list);
- copy_and_assign_cidr(node, newnode->bits, cidr, bits);
+ return 0;
+ }
+
+ node = kzalloc(sizeof(*node), GFP_KERNEL);
+ if (unlikely(!node)) {
+ list_del(&newnode->peer_list);
+ kfree(newnode);
+ return -ENOMEM;
+ }
+ INIT_LIST_HEAD(&node->peer_list);
+ copy_and_assign_cidr(node, newnode->bits, cidr, bits);
- rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
- rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
- if (!parent)
- rcu_assign_pointer(*trie, node);
- else
- rcu_assign_pointer(CHOOSE_NODE(parent, node->bits),
- node);
+ rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(node, down->bits));
+ rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
+ rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, newnode->bits));
+ rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
+ if (!parent) {
+ rcu_assign_pointer(node->parent_bit, trie);
+ rcu_assign_pointer(*trie, node);
+ } else {
+ rcu_assign_pointer(node->parent_bit, &CHOOSE_NODE(parent, node->bits));
+ rcu_assign_pointer(CHOOSE_NODE(parent, node->bits), node);
}
return 0;
}
@@ -335,9 +290,30 @@ int wg_allowedips_insert_v6(struct allow
void wg_allowedips_remove_by_peer(struct allowedips *table,
struct wg_peer *peer, struct mutex *lock)
{
+ struct allowedips_node *node, *child, *tmp;
+
+ if (list_empty(&peer->allowedips_list))
+ return;
++table->seq;
- walk_remove_by_peer(&table->root4, peer, lock);
- walk_remove_by_peer(&table->root6, peer, lock);
+ list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
+ list_del_init(&node->peer_list);
+ RCU_INIT_POINTER(node->peer, NULL);
+ if (node->bit[0] && node->bit[1])
+ continue;
+ child = rcu_dereference_protected(
+ node->bit[!rcu_access_pointer(node->bit[0])],
+ lockdep_is_held(lock));
+ if (child)
+ child->parent_bit = node->parent_bit;
+ *rcu_dereference_protected(node->parent_bit, lockdep_is_held(lock)) = child;
+ kfree_rcu(node, rcu);
+
+ /* TODO: Note that we currently don't walk up and down in order to
+ * free any potential filler nodes. This means that this function
+ * doesn't free up as much as it could, which could be revisited
+ * at some point.
+ */
+ }
}
int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
@@ -15,14 +15,11 @@ struct wg_peer;
struct allowedips_node {
struct wg_peer __rcu *peer;
struct allowedips_node __rcu *bit[2];
- /* While it may seem scandalous that we waste space for v4,
- * we're alloc'ing to the nearest power of 2 anyway, so this
- * doesn't actually make a difference.
- */
- u8 bits[16] __aligned(__alignof(u64));
u8 cidr, bit_at_a, bit_at_b, bitlen;
+ u8 bits[16] __aligned(__alignof(u64));
- /* Keep rarely used list at bottom to be beyond cache line. */
+ /* Keep rarely used members at bottom to be beyond cache line. */
+ struct allowedips_node *__rcu *parent_bit; /* XXX: this puts us at 68->128 bytes instead of 60->64 bytes!! */
union {
struct list_head peer_list;
struct rcu_head rcu;