diff mbox series

[RFC,16/30] iommufd/viommu: track the kvm pointer & its refcount in viommu core

Message ID 20250529053513.1592088-17-yilun.xu@linux.intel.com
State New
Headers show
Series Host side (KVM/VFIO/IOMMUFD) support for TDISP using TSM | expand

Commit Message

Xu Yilun May 29, 2025, 5:34 a.m. UTC
Track the kvm pointer and its refcount in viommu core. The kvm pointer
will be used later to support TSM Bind feature, which tells the secure
firmware the connection between a vPCI device and a CoCo VM.

There is existing need to reference kvm pointer in viommu [1], but in
that series kvm pointer is used & tracked in platform iommu drivers.
While in Confidential Computing (CC) case, viommu should manage a
generic routine for TSM Bind, i.e. call pci_tsm_bind(pdev, kvm, tdi_id)
So it is better the viommu core keeps and tracks the kvm pointer.

[1] https://lore.kernel.org/all/20250319173202.78988-5-shameerali.kolothum.thodi@huawei.com/

Signed-off-by: Lu Baolu <baolu.lu@linux.intel.com>
Signed-off-by: Xu Yilun <yilun.xu@linux.intel.com>
---
 drivers/iommu/iommufd/viommu.c | 62 ++++++++++++++++++++++++++++++++++
 include/linux/iommufd.h        |  3 ++
 2 files changed, 65 insertions(+)
diff mbox series

Patch

diff --git a/drivers/iommu/iommufd/viommu.c b/drivers/iommu/iommufd/viommu.c
index 488905989b7c..2fcef3f8d1a5 100644
--- a/drivers/iommu/iommufd/viommu.c
+++ b/drivers/iommu/iommufd/viommu.c
@@ -1,8 +1,68 @@ 
 // SPDX-License-Identifier: GPL-2.0-only
 /* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES
  */
+#if IS_ENABLED(CONFIG_KVM)
+#include <linux/kvm_host.h>
+#endif
+
 #include "iommufd_private.h"
 
+#if IS_ENABLED(CONFIG_KVM)
+static void viommu_get_kvm_safe(struct iommufd_viommu *viommu, struct kvm *kvm)
+{
+	void (*pfn)(struct kvm *kvm);
+	bool (*fn)(struct kvm *kvm);
+	bool ret;
+
+	if (!kvm)
+		return;
+
+	pfn = symbol_get(kvm_put_kvm);
+	if (WARN_ON(!pfn))
+		return;
+
+	fn = symbol_get(kvm_get_kvm_safe);
+	if (WARN_ON(!fn)) {
+		symbol_put(kvm_put_kvm);
+		return;
+	}
+
+	ret = fn(kvm);
+	symbol_put(kvm_get_kvm_safe);
+	if (!ret) {
+		symbol_put(kvm_put_kvm);
+		return;
+	}
+
+	viommu->put_kvm = pfn;
+	viommu->kvm = kvm;
+}
+
+static void viommu_put_kvm(struct iommufd_viommu *viommu)
+{
+	if (!viommu->kvm)
+		return;
+
+	if (WARN_ON(!viommu->put_kvm))
+		goto clear;
+
+	viommu->put_kvm(viommu->kvm);
+	viommu->put_kvm = NULL;
+	symbol_put(kvm_put_kvm);
+
+clear:
+	viommu->kvm = NULL;
+}
+#else
+static void viommu_get_kvm_safe(struct iommufd_viommu *viommu, struct kvm *kvm)
+{
+}
+
+static void viommu_put_kvm(struct iommufd_viommu *viommu)
+{
+}
+#endif
+
 void iommufd_viommu_destroy(struct iommufd_object *obj)
 {
 	struct iommufd_viommu *viommu =
@@ -10,6 +70,7 @@  void iommufd_viommu_destroy(struct iommufd_object *obj)
 
 	if (viommu->ops && viommu->ops->destroy)
 		viommu->ops->destroy(viommu);
+	viommu_put_kvm(viommu);
 	refcount_dec(&viommu->hwpt->common.obj.users);
 	xa_destroy(&viommu->vdevs);
 }
@@ -68,6 +129,7 @@  int iommufd_viommu_alloc_ioctl(struct iommufd_ucmd *ucmd)
 	 * on its own.
 	 */
 	viommu->iommu_dev = __iommu_get_iommu_dev(idev->dev);
+	viommu_get_kvm_safe(viommu, idev->kvm);
 
 	cmd->out_viommu_id = viommu->obj.id;
 	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
diff --git a/include/linux/iommufd.h b/include/linux/iommufd.h
index 2b2d6095309c..2712421802b9 100644
--- a/include/linux/iommufd.h
+++ b/include/linux/iommufd.h
@@ -104,6 +104,9 @@  struct iommufd_viommu {
 	struct rw_semaphore veventqs_rwsem;
 
 	unsigned int type;
+
+	struct kvm *kvm;
+	void (*put_kvm)(struct kvm *kvm);
 };
 
 /**