@@ -1,8 +1,68 @@
// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES
*/
+#if IS_ENABLED(CONFIG_KVM)
+#include <linux/kvm_host.h>
+#endif
+
#include "iommufd_private.h"
+#if IS_ENABLED(CONFIG_KVM)
+/*
+ * Take a long-term reference on @kvm for the lifetime of the viommu, and
+ * remember how to drop it later.
+ *
+ * kvm_get_kvm_safe()/kvm_put_kvm() live in kvm.ko, which may be built as a
+ * module, so they are resolved via symbol_get() instead of being called
+ * directly.  On success the module reference pinned by
+ * symbol_get(kvm_put_kvm) is kept until viommu_put_kvm() releases it; the
+ * kvm_get_kvm_safe symbol is only needed transiently and is dropped right
+ * after use.
+ *
+ * On any failure (no kvm, VM already dying, symbols unavailable) viommu->kvm
+ * stays NULL and viommu_put_kvm() becomes a no-op.
+ */
+static void viommu_get_kvm_safe(struct iommufd_viommu *viommu, struct kvm *kvm)
+{
+ void (*pfn)(struct kvm *kvm);
+ bool (*fn)(struct kvm *kvm);
+ bool ret;
+
+ if (!kvm)
+ return;
+
+ /* Resolve the release side first so the success path cannot fail late. */
+ pfn = symbol_get(kvm_put_kvm);
+ if (WARN_ON(!pfn))
+ return;
+
+ fn = symbol_get(kvm_get_kvm_safe);
+ if (WARN_ON(!fn)) {
+ symbol_put(kvm_put_kvm);
+ return;
+ }
+
+ /* kvm_get_kvm_safe() fails if the VM is already being destroyed. */
+ ret = fn(kvm);
+ symbol_put(kvm_get_kvm_safe);
+ if (!ret) {
+ symbol_put(kvm_put_kvm);
+ return;
+ }
+
+ /* Both set together; viommu_put_kvm() relies on this pairing. */
+ viommu->put_kvm = pfn;
+ viommu->kvm = kvm;
+}
+
+/*
+ * Drop the kvm reference taken by viommu_get_kvm_safe(), if any, and release
+ * the module reference that kept kvm_put_kvm() resolvable.  Safe to call
+ * when no kvm was ever attached (viommu->kvm == NULL).
+ */
+static void viommu_put_kvm(struct iommufd_viommu *viommu)
+{
+ if (!viommu->kvm)
+ return;
+
+ /*
+ * put_kvm is always set together with kvm; if it is missing something
+ * corrupted the pairing — warn, but still clear the stale pointer.
+ */
+ if (WARN_ON(!viommu->put_kvm))
+ goto clear;
+
+ viommu->put_kvm(viommu->kvm);
+ viommu->put_kvm = NULL;
+ /* Balances the symbol_get(kvm_put_kvm) held since attach. */
+ symbol_put(kvm_put_kvm);
+
+clear:
+ viommu->kvm = NULL;
+}
+#else
+/*
+ * CONFIG_KVM=n stub: viommu->kvm is never set, so attach is a no-op.
+ *
+ * NOTE(review): <linux/kvm_host.h> is not included in this configuration,
+ * so this prototype relies on a forward declaration of struct kvm being
+ * visible (presumably from iommufd_private.h) — confirm one exists, or the
+ * struct-in-parameter-list warning/type-mismatch trap applies.
+ */
+static void viommu_get_kvm_safe(struct iommufd_viommu *viommu, struct kvm *kvm)
+{
+}
+
+/* CONFIG_KVM=n stub: nothing was acquired, so nothing to release. */
+static void viommu_put_kvm(struct iommufd_viommu *viommu)
+{
+}
+#endif
+
void iommufd_viommu_destroy(struct iommufd_object *obj)
{
struct iommufd_viommu *viommu =
@@ -10,6 +70,7 @@ void iommufd_viommu_destroy(struct iommufd_object *obj)
if (viommu->ops && viommu->ops->destroy)
viommu->ops->destroy(viommu);
+ viommu_put_kvm(viommu);
refcount_dec(&viommu->hwpt->common.obj.users);
xa_destroy(&viommu->vdevs);
}
@@ -68,6 +129,7 @@ int iommufd_viommu_alloc_ioctl(struct iommufd_ucmd *ucmd)
* on its own.
*/
viommu->iommu_dev = __iommu_get_iommu_dev(idev->dev);
+ viommu_get_kvm_safe(viommu, idev->kvm);
cmd->out_viommu_id = viommu->obj.id;
rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
@@ -104,6 +104,9 @@ struct iommufd_viommu {
struct rw_semaphore veventqs_rwsem;
unsigned int type;
+
+ struct kvm *kvm;
+ void (*put_kvm)(struct kvm *kvm);
};
/**