@@ -70,7 +70,7 @@ dispatch_hid_bpf_device_event(struct hid_device *hdev, enum hid_report_type type
memset(ctx_kern.data, 0, hdev->bpf.allocated_data);
memcpy(ctx_kern.data, data, *size);
- ret = hid_bpf_prog_run(hdev, HID_BPF_PROG_TYPE_DEVICE_EVENT, &ctx_kern);
+ ret = hid_bpf_prog_run(hdev, HID_BPF_PROG_TYPE_DEVICE_EVENT, &ctx_kern, false);
if (ret < 0)
return ERR_PTR(ret);
@@ -122,7 +122,7 @@ u8 *call_hid_bpf_rdesc_fixup(struct hid_device *hdev, u8 *rdesc, unsigned int *s
memcpy(ctx_kern.data, rdesc, min_t(unsigned int, *size, HID_MAX_DESCRIPTOR_SIZE));
- ret = hid_bpf_prog_run(hdev, HID_BPF_PROG_TYPE_RDESC_FIXUP, &ctx_kern);
+ ret = hid_bpf_prog_run(hdev, HID_BPF_PROG_TYPE_RDESC_FIXUP, &ctx_kern, false);
if (ret < 0)
goto ignore_bpf;
@@ -205,7 +205,7 @@ int hid_bpf_reconnect(struct hid_device *hdev)
static int do_hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
hid_bpf_cb_t prog_fn, struct bpf_prog *prog,
- __u32 flags)
+ __u32 flags, bool sleepable)
{
int fd, err;
@@ -213,7 +213,7 @@ static int do_hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_typ
if (err)
return err;
- fd = __hid_bpf_attach_prog(hdev, prog_type, prog_fn, prog, flags);
+ fd = __hid_bpf_attach_prog(hdev, prog_type, prog_fn, prog, flags, sleepable);
if (fd < 0)
return fd;
@@ -228,6 +228,56 @@ static int do_hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_typ
return fd;
}
+static int
+hid_bpf_attach_prog(unsigned int hid_id, enum hid_bpf_prog_type prog_type,
+ hid_bpf_cb_t prog_fn, __u32 flags, void *prog__aux,
+ bool sleepable)
+{
+ struct bpf_prog_aux *aux = (struct bpf_prog_aux *)prog__aux;
+ struct bpf_prog *prog = aux->prog;
+ struct hid_device *hdev;
+ struct device *dev;
+ int err, fd;
+
+ if (!hid_bpf_ops)
+ return -EINVAL;
+
+ if ((flags & ~HID_BPF_FLAG_MASK))
+ return -EINVAL;
+
+ if (prog_type < 0 || prog_type >= HID_BPF_PROG_TYPE_MAX)
+ return -EINVAL;
+
+ dev = bus_find_device(hid_bpf_ops->bus_type, NULL, &hid_id, device_match_id);
+ if (!dev)
+ return -EINVAL;
+
+ hdev = to_hid_device(dev);
+
+ /*
+ * take a ref on the prog itself, it will be released
+ * on errors or when it'll be detached
+ */
+ prog = bpf_prog_inc_not_zero(prog);
+ if (IS_ERR(prog)) {
+ err = PTR_ERR(prog);
+ goto out_dev_put;
+ }
+
+ fd = do_hid_bpf_attach_prog(hdev, prog_type, prog_fn, prog, flags, sleepable);
+ if (fd < 0) {
+ err = fd;
+ goto out_prog_put;
+ }
+
+ return fd;
+
+ out_prog_put:
+ bpf_prog_put(prog);
+ out_dev_put:
+ put_device(dev);
+ return err;
+}
/* Disables missing prototype warnings */
__bpf_kfunc_start_defs();
@@ -272,50 +322,22 @@ __bpf_kfunc int
hid_bpf_attach_prog_impl(unsigned int hid_id, enum hid_bpf_prog_type prog_type,
hid_bpf_cb_t prog_fn__async, __u32 flags, void *prog__aux)
{
- struct bpf_prog_aux *aux = (struct bpf_prog_aux *)prog__aux;
- struct bpf_prog *prog = aux->prog;
- struct hid_device *hdev;
- struct device *dev;
- int err, fd;
-
- if (!hid_bpf_ops)
- return -EINVAL;
-
- if ((flags & ~HID_BPF_FLAG_MASK))
- return -EINVAL;
-
- if (prog_type < 0 || prog_type >= HID_BPF_PROG_TYPE_MAX)
- return -EINVAL;
+ return hid_bpf_attach_prog(hid_id, prog_type, prog_fn__async, flags, prog__aux, false);
+}
- dev = bus_find_device(hid_bpf_ops->bus_type, NULL, &hid_id, device_match_id);
- if (!dev)
+__bpf_kfunc int
+hid_bpf_attach_sleepable_prog_impl(unsigned int hid_id, enum hid_bpf_prog_type prog_type,
+ hid_bpf_cb_t prog_fn__s_async, __u32 flags, void *prog__aux)
+{
+ switch (prog_type) {
+ case HID_BPF_PROG_TYPE_RAW_REQUEST:
+ /* OK */
+ break;
+ default:
return -EINVAL;
-
- hdev = to_hid_device(dev);
-
- /*
- * take a ref on the prog itself, it will be released
- * on errors or when it'll be detached
- */
- prog = bpf_prog_inc_not_zero(prog);
- if (IS_ERR(prog)) {
- err = PTR_ERR(prog);
- goto out_dev_put;
- }
-
- fd = do_hid_bpf_attach_prog(hdev, prog_type, prog_fn__async, prog, flags);
- if (fd < 0) {
- err = fd;
- goto out_prog_put;
}
- return fd;
-
- out_prog_put:
- bpf_prog_put(prog);
- out_dev_put:
- put_device(dev);
- return err;
+ return hid_bpf_attach_prog(hid_id, prog_type, prog_fn__s_async, flags, prog__aux, true);
}
/**
@@ -538,6 +560,7 @@ __bpf_kfunc_end_defs();
BTF_KFUNCS_START(hid_bpf_kfunc_ids)
BTF_ID_FLAGS(func, hid_bpf_get_data, KF_RET_NULL)
BTF_ID_FLAGS(func, hid_bpf_attach_prog_impl, KF_SLEEPABLE)
+BTF_ID_FLAGS(func, hid_bpf_attach_sleepable_prog_impl, KF_SLEEPABLE)
BTF_ID_FLAGS(func, hid_bpf_allocate_context, KF_ACQUIRE | KF_RET_NULL | KF_SLEEPABLE)
BTF_ID_FLAGS(func, hid_bpf_release_context, KF_RELEASE | KF_SLEEPABLE)
BTF_ID_FLAGS(func, hid_bpf_hw_request, KF_SLEEPABLE)
@@ -14,10 +14,10 @@ typedef int (*hid_bpf_cb_t)(struct hid_bpf_ctx *hid_ctx);
int __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
int (prog_fn__async)(struct hid_bpf_ctx *hid_ctx),
- struct bpf_prog *prog, __u32 flags);
+ struct bpf_prog *prog, __u32 flags, bool sleepable);
void __hid_bpf_destroy_device(struct hid_device *hdev);
int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
- struct hid_bpf_ctx_kern *ctx_kern);
+ struct hid_bpf_ctx_kern *ctx_kern, bool is_sleepable);
int hid_bpf_reconnect(struct hid_device *hdev);
struct bpf_prog;
@@ -39,6 +39,7 @@ struct hid_bpf_prog_entry {
struct hid_bpf_prog_cb {
struct bpf_prog *prog;
void *fn;
+ bool sleepable;
};
struct hid_bpf_jmp_table {
@@ -99,14 +100,20 @@ static int hid_bpf_program_count(struct hid_device *hdev,
}
int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
- struct hid_bpf_ctx_kern *ctx_kern)
+ struct hid_bpf_ctx_kern *ctx_kern, bool is_sleepable)
{
struct hid_bpf_prog_list *prog_list;
bpf_callback_t prog_fn;
int i, idx, err = 0;
- rcu_read_lock();
- prog_list = rcu_dereference(hdev->bpf.progs[type]);
+ if (is_sleepable) {
+ prog_list = READ_ONCE(hdev->bpf.progs[type]);
+ rcu_read_lock_trace();
+ might_fault();
+ } else {
+ rcu_read_lock();
+ prog_list = rcu_dereference(hdev->bpf.progs[type]);
+ }
if (!prog_list)
goto out_unlock;
@@ -117,6 +124,10 @@ int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
if (!test_bit(idx, jmp_table.enabled))
continue;
+ /* prevent a sleepable program to be run in a non sleepable context */
+ if (!is_sleepable && jmp_table.prog_cbs[idx].sleepable)
+ continue;
+
ctx_kern->ctx.index = idx;
prog_fn = jmp_table.prog_cbs[idx].fn;
migrate_disable();
@@ -129,7 +140,10 @@ int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
}
out_unlock:
- rcu_read_unlock();
+ if (is_sleepable)
+ rcu_read_unlock_trace();
+ else
+ rcu_read_unlock();
return err;
}
@@ -279,7 +293,7 @@ static void hid_bpf_release_progs(struct work_struct *work)
* Insert the given BPF program represented by its function call in the jmp table.
* Returns the index in the jump table or a negative error.
*/
-static int hid_bpf_insert_prog(struct bpf_prog *prog, hid_bpf_cb_t prog_fn)
+static int hid_bpf_insert_prog(struct bpf_prog *prog, hid_bpf_cb_t prog_fn, bool sleepable)
{
int i, index = -1, err = -EINVAL;
@@ -289,6 +303,7 @@ static int hid_bpf_insert_prog(struct bpf_prog *prog, hid_bpf_cb_t prog_fn)
/* mark the index as used */
jmp_table.prog_cbs[i].fn = prog_fn;
jmp_table.prog_cbs[i].prog = prog;
+ jmp_table.prog_cbs[i].sleepable = sleepable;
index = i;
__set_bit(i, jmp_table.enabled);
}
@@ -340,7 +355,8 @@ static const struct bpf_link_ops hid_bpf_link_lops = {
/* called from syscall */
noinline int
__hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
- hid_bpf_cb_t prog_fn, struct bpf_prog *prog, __u32 flags)
+ hid_bpf_cb_t prog_fn, struct bpf_prog *prog, __u32 flags,
+ bool sleepable)
{
struct bpf_link_primer link_primer;
struct hid_bpf_link *link;
@@ -370,7 +386,7 @@ __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
goto err_unlock;
}
- prog_table_idx = hid_bpf_insert_prog(prog, prog_fn);
+ prog_table_idx = hid_bpf_insert_prog(prog, prog_fn, sleepable);
/* if the jmp table is full, abort */
if (prog_table_idx < 0) {
err = prog_table_idx;
Signed-off-by: Benjamin Tissoires <bentiss@kernel.org> --- drivers/hid/bpf/hid_bpf_dispatch.c | 111 ++++++++++++++++++++++-------------- drivers/hid/bpf/hid_bpf_dispatch.h | 4 +- drivers/hid/bpf/hid_bpf_jmp_table.c | 30 +++++++--- 3 files changed, 92 insertions(+), 53 deletions(-)