diff mbox series

[v1,05/16] iommufd/viommu: Add IOMMU_VIOMMU_SET/UNSET_VDEV_ID ioctl

Message ID e35a24d4337b985aabbcfe7857cac2186d4f61e9.1723061378.git.nicolinc@nvidia.com
State New
Headers show
Series iommufd: Add VIOMMU infrastructure (Part-1) | expand

Commit Message

Nicolin Chen Aug. 7, 2024, 8:10 p.m. UTC
Introduce a pair of new ioctls to set/unset a per-viommu virtual device id
that should be linked to a physical device id via a struct device pointer.

Continue the support IOMMU_VIOMMU_TYPE_DEFAULT for a core-managed viommu.
Provide a lookup function for drivers to load device pointer by a virtual
device id.

Signed-off-by: Nicolin Chen <nicolinc@nvidia.com>
---
 drivers/iommu/iommufd/device.c          |   9 ++
 drivers/iommu/iommufd/iommufd_private.h |  20 ++++
 drivers/iommu/iommufd/main.c            |   6 ++
 drivers/iommu/iommufd/viommu.c          | 118 ++++++++++++++++++++++++
 include/uapi/linux/iommufd.h            |  40 ++++++++
 5 files changed, 193 insertions(+)

Comments

Nicolin Chen Aug. 14, 2024, 5:09 p.m. UTC | #1
On Wed, Aug 07, 2024 at 01:10:46PM -0700, Nicolin Chen wrote:
> @@ -135,7 +135,14 @@ void iommufd_device_destroy(struct iommufd_object *obj)
>  {
>         struct iommufd_device *idev =
>                 container_of(obj, struct iommufd_device, obj);
> +       struct iommufd_vdev_id *vdev_id, *curr;
> 
> +       list_for_each_entry(vdev_id, &idev->vdev_id_list, idev_item) {
> +               curr = xa_cmpxchg(&vdev_id->viommu->vdev_ids, vdev_id->vdev_id,
> +                                 vdev_id, NULL, GFP_KERNEL);
> +               WARN_ON(curr != vdev_id);
> +               kfree(vdev_id);
> +       }

Kevin already pointed out previously during the RFC review that
we probably should do one vdev_id per idev. And Jason expressed
okay to either way. I didn't plan to change this part until this
week for the VIRQ series.

My rethinking is that an idev is attached to one (and only one)
nested HWPT. The nested HWPT is associated to one (and only one)
VIOMMU object. So, it's unlikely we can a second vdev_id, i.e.
idev->vdev_id is enough.

This helps us to build a device-based virq report function:
+void iommufd_device_report_virq(struct device *dev, unsigned int data_type,
+                               void *data_ptr, size_t data_len);

I built a link from device to viommu reusing Baolu's work:
struct device -> struct iommu_group -> struct iommu_attach_handle
-> struct iommufd_attach_handle -> struct iommufd_device (idev)
-> struct iommufd_vdev_id (idev->vdev_id)

The vdev_id struct holds viommu and virtual ID, so allowing us
to add another two helpers:
+struct iommufd_viommu *iommufd_device_get_viommu(struct device *dev);
+u64 iommufd_device_get_virtual_id(struct device *dev);

A driver that reports event/irq per device can use these helpers
to report virq via the core-managed VIOMMU object. (If a driver
has some non-per-device type of IRQs, it would have to allocate
a driver-managed VIOMMU object instead.)

I have both a revised VIOMMU series and a new VIRQ series ready.
Will send in the following days after some testing/polishing.

Thanks
Nicolin
Jason Gunthorpe Aug. 14, 2024, 10:02 p.m. UTC | #2
On Wed, Aug 14, 2024 at 10:09:22AM -0700, Nicolin Chen wrote:

> This helps us to build a device-based virq report function:
> +void iommufd_device_report_virq(struct device *dev, unsigned int data_type,
> +                               void *data_ptr, size_t data_len);
> 
> I built a link from device to viommu reusing Baolu's work:
> struct device -> struct iommu_group -> struct iommu_attach_handle
> -> struct iommufd_attach_handle -> struct iommufd_device (idev)
> -> struct iommufd_vdev_id (idev->vdev_id)

That makes sense, the vdev id would be 1:1 with the struct device, and
the iommufd_device is also supposed to be 1:1 with the struct device.

Jason
Jason Gunthorpe Aug. 15, 2024, 7:08 p.m. UTC | #3
On Wed, Aug 07, 2024 at 01:10:46PM -0700, Nicolin Chen wrote:

> +int iommufd_viommu_set_vdev_id(struct iommufd_ucmd *ucmd)
> +{
> +	struct iommu_viommu_set_vdev_id *cmd = ucmd->cmd;
> +	struct iommufd_hwpt_nested *hwpt_nested;
> +	struct iommufd_vdev_id *vdev_id, *curr;
> +	struct iommufd_hw_pagetable *hwpt;
> +	struct iommufd_viommu *viommu;
> +	struct iommufd_device *idev;
> +	int rc = 0;
> +
> +	if (cmd->vdev_id > ULONG_MAX)
> +		return -EINVAL;
> +
> +	idev = iommufd_get_device(ucmd, cmd->dev_id);
> +	if (IS_ERR(idev))
> +		return PTR_ERR(idev);
> +	hwpt = idev->igroup->hwpt;
> +
> +	if (hwpt == NULL || hwpt->obj.type != IOMMUFD_OBJ_HWPT_NESTED) {
> +		rc = -EINVAL;
> +		goto out_put_idev;
> +	}
> +	hwpt_nested = container_of(hwpt, struct iommufd_hwpt_nested, common);

This doesn't seem like a necessary check, the attached hwpt can change
after this is established, so this can't be an invariant we enforce.

If you want to do 1:1 then somehow directly check if the idev is
already linked to a viommu.

> +static struct device *
> +iommufd_viommu_find_device(struct iommufd_viommu *viommu, u64 id)
> +{
> +	struct iommufd_vdev_id *vdev_id;
> +
> +	xa_lock(&viommu->vdev_ids);
> +	vdev_id = xa_load(&viommu->vdev_ids, (unsigned long)id);
> +	xa_unlock(&viommu->vdev_ids);

This lock doesn't do anything

> +	if (!vdev_id || vdev_id->vdev_id != id)
> +		return NULL;

And this is unlocked

> +	return vdev_id->dev;
> +}

This isn't good.. We can't return the struct device pointer here as
there is no locking for it anymore. We can't even know it is still
probed to VFIO anymore.

It has to work by having the iommu driver directly access the xarray
and the entirely under the spinlock the iommu driver can translate the
vSID to the pSID and the let go and push the invalidation to HW. No
races.


> +int iommufd_viommu_unset_vdev_id(struct iommufd_ucmd *ucmd)
> +{
> +	struct iommu_viommu_unset_vdev_id *cmd = ucmd->cmd;
> +	struct iommufd_vdev_id *vdev_id;
> +	struct iommufd_viommu *viommu;
> +	struct iommufd_device *idev;
> +	int rc = 0;
> +
> +	idev = iommufd_get_device(ucmd, cmd->dev_id);
> +	if (IS_ERR(idev))
> +		return PTR_ERR(idev);
> +
> +	viommu = iommufd_get_viommu(ucmd, cmd->viommu_id);
> +	if (IS_ERR(viommu)) {
> +		rc = PTR_ERR(viommu);
> +		goto out_put_idev;
> +	}
> +
> +	if (idev->dev != iommufd_viommu_find_device(viommu, cmd->vdev_id)) {

Swap the order around != to be more kernely

> +		rc = -EINVAL;
> +		goto out_put_viommu;
> +	}
> +
> +	vdev_id = xa_erase(&viommu->vdev_ids, cmd->vdev_id);

And this whole thing needs to be done under the xa_lock too.

xa_lock(&viommu->vdev_ids);
vdev_id = xa_load(&viommu->vdev_ids, cmd->vdev_id);
if (!vdev_id || vdev_id->vdev_id != cmd->vdev_id (????) || vdev_id->dev != idev->dev)
    err
__xa_erase(&viommu->vdev_ids, cmd->vdev_id);
xa_unlock((&viommu->vdev_ids);

Jason
Nicolin Chen Aug. 15, 2024, 7:53 p.m. UTC | #4
On Thu, Aug 15, 2024 at 12:46:29PM -0700, Nicolin Chen wrote:
> > > +static struct device *
> > > +iommufd_viommu_find_device(struct iommufd_viommu *viommu, u64 id)
> > > +{
> > > +	struct iommufd_vdev_id *vdev_id;
> > > +
> > > +	xa_lock(&viommu->vdev_ids);
> > > +	vdev_id = xa_load(&viommu->vdev_ids, (unsigned long)id);
> > > +	xa_unlock(&viommu->vdev_ids);
> > 
> > This lock doesn't do anything
> > 
> > > +	if (!vdev_id || vdev_id->vdev_id != id)
> > > +		return NULL;
> > 
> > And this is unlocked
> > 
> > > +	return vdev_id->dev;
> > > +}
> > 
> > This isn't good.. We can't return the struct device pointer here as
> > there is no locking for it anymore. We can't even know it is still
> > probed to VFIO anymore.
> > 
> > It has to work by having the iommu driver directly access the xarray
> > and the entirely under the spinlock the iommu driver can translate the
> > vSID to the pSID and the let go and push the invalidation to HW. No
> > races.
> 
> Maybe the iommufd_viommu_invalidate ioctl handler should hold that
> xa_lock around the viommu->ops->cache_invalidate, and then add lock
> assert in iommufd_viommu_find_device?

xa_lock/spinlock might be too heavy. We can have a mutex to wrap
around viommu ioctl handlers..
Jason Gunthorpe Aug. 15, 2024, 11:41 p.m. UTC | #5
On Thu, Aug 15, 2024 at 12:46:24PM -0700, Nicolin Chen wrote:
> On Thu, Aug 15, 2024 at 04:08:48PM -0300, Jason Gunthorpe wrote:
> > On Wed, Aug 07, 2024 at 01:10:46PM -0700, Nicolin Chen wrote:
> > 
> > > +int iommufd_viommu_set_vdev_id(struct iommufd_ucmd *ucmd)
> > > +{
> > > +	struct iommu_viommu_set_vdev_id *cmd = ucmd->cmd;
> > > +	struct iommufd_hwpt_nested *hwpt_nested;
> > > +	struct iommufd_vdev_id *vdev_id, *curr;
> > > +	struct iommufd_hw_pagetable *hwpt;
> > > +	struct iommufd_viommu *viommu;
> > > +	struct iommufd_device *idev;
> > > +	int rc = 0;
> > > +
> > > +	if (cmd->vdev_id > ULONG_MAX)
> > > +		return -EINVAL;
> > > +
> > > +	idev = iommufd_get_device(ucmd, cmd->dev_id);
> > > +	if (IS_ERR(idev))
> > > +		return PTR_ERR(idev);
> > > +	hwpt = idev->igroup->hwpt;
> > > +
> > > +	if (hwpt == NULL || hwpt->obj.type != IOMMUFD_OBJ_HWPT_NESTED) {
> > > +		rc = -EINVAL;
> > > +		goto out_put_idev;
> > > +	}
> > > +	hwpt_nested = container_of(hwpt, struct iommufd_hwpt_nested, common);
> > 
> > This doesn't seem like a necessary check, the attached hwpt can change
> > after this is established, so this can't be an invariant we enforce.
> > 
> > If you want to do 1:1 then somehow directly check if the idev is
> > already linked to a viommu.
> 
> But idev can't link to a viommu without a proxy hwpt_nested?

Why not? The idev becomes linked to the viommu when the dev id is set

Unless we are also going to enforce the idev is always attached to a
nested then I don't think we need to check it here.

Things will definately not entirely work as expected if the vdev is
directly attached to the s2 or a blocking, but it won't harm anything.

> the stage-2 only configuration should have an identity hwpt_nested
> right?

Yes, that is the right way to use the API

> > It has to work by having the iommu driver directly access the xarray
> > and the entirely under the spinlock the iommu driver can translate the
> > vSID to the pSID and the let go and push the invalidation to HW. No
> > races.
> 
> Maybe the iommufd_viommu_invalidate ioctl handler should hold that
> xa_lock around the viommu->ops->cache_invalidate, and then add lock
> assert in iommufd_viommu_find_device?

That doesn't seem like a great idea, you can't do copy_from_user under
a spinlock.

> > xa_lock(&viommu->vdev_ids);
> > vdev_id = xa_load(&viommu->vdev_ids, cmd->vdev_id);
> > if (!vdev_id || vdev_id->vdev_id != cmd->vdev_id (????) || vdev_id->dev != idev->dev)
> >     err
> > __xa_erase(&viommu->vdev_ids, cmd->vdev_id);
> > xa_unlock((&viommu->vdev_ids);
> 
> I've changed to xa_cmpxchg() in my local tree. Would it be simpler?

No, that is still not right, you can't take the vdev_id outside the
lock at all. Even for cmpxchng because the vdev_id could have been
freed and reallocated by another thread.

You must combine the validation of the vdev_id with the erase under a
single critical region.

Jason
Jason Gunthorpe Aug. 19, 2024, 5:33 p.m. UTC | #6
On Thu, Aug 15, 2024 at 05:21:57PM -0700, Nicolin Chen wrote:

> > Why not? The idev becomes linked to the viommu when the dev id is set
> 
> > Unless we are also going to enforce the idev is always attached to a
> > nested then I don't think we need to check it here.
> > 
> > Things will definately not entirely work as expected if the vdev is
> > directly attached to the s2 or a blocking, but it won't harm anything.
> 
> My view is that, the moment there is a VIOMMU object, that must
> be a nested IOMMU case, so there must be a nested hwpt. Blocking
> domain would be a hwpt_nested too (vSTE=Abort) as we previously
> concluded.

I'm not sure other vendors can do that vSTE=Abort/Bypass thing though
yet..

> Then, in a nested case, it feels odd that an idev is attached to
> an S2 hwpt..
>
> That being said, I think we can still do that with validations:
>  If idev->hwpt is nested, compare input viommu v.s idev->hwpt->viommu.
>  If idev->hwpt is paging, compare input viommu->hwpt v.s idev->hwpt.

But again, if you don't contiguously validate those invariants in all
the other attach paths it is sort of pointless to check them since the
userspace can still violate things.

> This complicates things overall especially with the VIRQ that has
> involved interrupt context polling vdev_id, where semaphore/mutex
> won't fit very well. Perhaps it would need a driver-level bottom
> half routine to call those helpers with locks. I am glad that you
> noticed the problem early.

I think you have to show the xarray to the driver and the driver can
use the spinlock to access it safely. Keeping it hidden in the core
code is causing all these locking problems.

Jason
Nicolin Chen Aug. 19, 2024, 6:10 p.m. UTC | #7
On Mon, Aug 19, 2024 at 02:33:32PM -0300, Jason Gunthorpe wrote:
> On Thu, Aug 15, 2024 at 05:21:57PM -0700, Nicolin Chen wrote:
> 
> > > Why not? The idev becomes linked to the viommu when the dev id is set
> > 
> > > Unless we are also going to enforce the idev is always attached to a
> > > nested then I don't think we need to check it here.
> > > 
> > > Things will definately not entirely work as expected if the vdev is
> > > directly attached to the s2 or a blocking, but it won't harm anything.
> > 
> > My view is that, the moment there is a VIOMMU object, that must
> > be a nested IOMMU case, so there must be a nested hwpt. Blocking
> > domain would be a hwpt_nested too (vSTE=Abort) as we previously
> > concluded.
> 
> I'm not sure other vendors can do that vSTE=Abort/Bypass thing though
> yet..
> 
> > Then, in a nested case, it feels odd that an idev is attached to
> > an S2 hwpt..
> >
> > That being said, I think we can still do that with validations:
> >  If idev->hwpt is nested, compare input viommu v.s idev->hwpt->viommu.
> >  If idev->hwpt is paging, compare input viommu->hwpt v.s idev->hwpt.
> 
> But again, if you don't contiguously validate those invariants in all
> the other attach paths it is sort of pointless to check them since the
> userspace can still violate things.

Hmm, would that be unsafe? I start to wonder if we should allow an
attach to viommu and put validations on that?

> > This complicates things overall especially with the VIRQ that has
> > involved interrupt context polling vdev_id, where semaphore/mutex
> > won't fit very well. Perhaps it would need a driver-level bottom
> > half routine to call those helpers with locks. I am glad that you
> > noticed the problem early.
> 
> I think you have to show the xarray to the driver and the driver can
> use the spinlock to access it safely. Keeping it hidden in the core
> code is causing all these locking problems.

Yea, I just figured that out... You have been right. I was able to
get rid of the locking problem with invalidation API. But then irq
became a headache as drivers would only know the dev pointer, so
everything that the dev could convert to would be unsafe as it can
not grab the idev/viommu locks until it converts.

Thanks
Nicolin
Jason Gunthorpe Aug. 19, 2024, 6:26 p.m. UTC | #8
On Mon, Aug 19, 2024 at 11:10:03AM -0700, Nicolin Chen wrote:
> On Mon, Aug 19, 2024 at 02:33:32PM -0300, Jason Gunthorpe wrote:
> > On Thu, Aug 15, 2024 at 05:21:57PM -0700, Nicolin Chen wrote:
> > 
> > > > Why not? The idev becomes linked to the viommu when the dev id is set
> > > 
> > > > Unless we are also going to enforce the idev is always attached to a
> > > > nested then I don't think we need to check it here.
> > > > 
> > > > Things will definately not entirely work as expected if the vdev is
> > > > directly attached to the s2 or a blocking, but it won't harm anything.
> > > 
> > > My view is that, the moment there is a VIOMMU object, that must
> > > be a nested IOMMU case, so there must be a nested hwpt. Blocking
> > > domain would be a hwpt_nested too (vSTE=Abort) as we previously
> > > concluded.
> > 
> > I'm not sure other vendors can do that vSTE=Abort/Bypass thing though
> > yet..
> > 
> > > Then, in a nested case, it feels odd that an idev is attached to
> > > an S2 hwpt..
> > >
> > > That being said, I think we can still do that with validations:
> > >  If idev->hwpt is nested, compare input viommu v.s idev->hwpt->viommu.
> > >  If idev->hwpt is paging, compare input viommu->hwpt v.s idev->hwpt.
> > 
> > But again, if you don't contiguously validate those invariants in all
> > the other attach paths it is sort of pointless to check them since the
> > userspace can still violate things.
> 
> Hmm, would that be unsafe? I start to wonder if we should allow an
> attach to viommu and put validations on that?

I don't think it is unsafe to mismatch things, if a device is
disconnected from it's VIOMMU then the HW should isolate it the same
as anything else

It doesn't matter if the VIOMMU has a devid mapping for the device
when it is not currently part of the viommu configuration.

IOW it is not the devid ioctl that causes the device to join the
VIOMMU, it is the attach of the nest.

Jason
diff mbox series

Patch

diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c
index 5fd3dd420290..ed29bc606f5e 100644
--- a/drivers/iommu/iommufd/device.c
+++ b/drivers/iommu/iommufd/device.c
@@ -135,7 +135,14 @@  void iommufd_device_destroy(struct iommufd_object *obj)
 {
 	struct iommufd_device *idev =
 		container_of(obj, struct iommufd_device, obj);
+	struct iommufd_vdev_id *vdev_id, *curr;
 
+	list_for_each_entry(vdev_id, &idev->vdev_id_list, idev_item) {
+		curr = xa_cmpxchg(&vdev_id->viommu->vdev_ids, vdev_id->vdev_id,
+				  vdev_id, NULL, GFP_KERNEL);
+		WARN_ON(curr != vdev_id);
+		kfree(vdev_id);
+	}
 	iommu_device_release_dma_owner(idev->dev);
 	iommufd_put_group(idev->igroup);
 	if (!iommufd_selftest_is_mock_dev(idev->dev))
@@ -217,6 +224,8 @@  struct iommufd_device *iommufd_device_bind(struct iommufd_ctx *ictx,
 	idev->igroup = igroup;
 	mutex_init(&idev->iopf_lock);
 
+	INIT_LIST_HEAD(&idev->vdev_id_list);
+
 	/*
 	 * If the caller fails after this success it must call
 	 * iommufd_unbind_device() which is safe since we hold this refcount.
diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h
index 443575fd3dd4..10c63972b9ab 100644
--- a/drivers/iommu/iommufd/iommufd_private.h
+++ b/drivers/iommu/iommufd/iommufd_private.h
@@ -417,6 +417,7 @@  struct iommufd_device {
 	struct iommufd_ctx *ictx;
 	struct iommufd_group *igroup;
 	struct list_head group_item;
+	struct list_head vdev_id_list;
 	/* always the physical device */
 	struct device *dev;
 	bool enforce_cache_coherency;
@@ -533,12 +534,31 @@  struct iommufd_viommu {
 	struct iommufd_ctx *ictx;
 	struct iommu_device *iommu_dev;
 	struct iommufd_hwpt_paging *hwpt;
+	struct xarray vdev_ids;
 
 	unsigned int type;
 };
 
+struct iommufd_vdev_id {
+	struct iommufd_viommu *viommu;
+	struct device *dev;
+	u64 vdev_id;
+
+	struct list_head idev_item;
+};
+
+static inline struct iommufd_viommu *
+iommufd_get_viommu(struct iommufd_ucmd *ucmd, u32 id)
+{
+	return container_of(iommufd_get_object(ucmd->ictx, id,
+					       IOMMUFD_OBJ_VIOMMU),
+			    struct iommufd_viommu, obj);
+}
+
 int iommufd_viommu_alloc_ioctl(struct iommufd_ucmd *ucmd);
 void iommufd_viommu_destroy(struct iommufd_object *obj);
+int iommufd_viommu_set_vdev_id(struct iommufd_ucmd *ucmd);
+int iommufd_viommu_unset_vdev_id(struct iommufd_ucmd *ucmd);
 
 #ifdef CONFIG_IOMMUFD_TEST
 int iommufd_test(struct iommufd_ucmd *ucmd);
diff --git a/drivers/iommu/iommufd/main.c b/drivers/iommu/iommufd/main.c
index 288ee51b6829..199ad90fa36b 100644
--- a/drivers/iommu/iommufd/main.c
+++ b/drivers/iommu/iommufd/main.c
@@ -334,6 +334,8 @@  union ucmd_buffer {
 	struct iommu_option option;
 	struct iommu_vfio_ioas vfio_ioas;
 	struct iommu_viommu_alloc viommu;
+	struct iommu_viommu_set_vdev_id set_vdev_id;
+	struct iommu_viommu_unset_vdev_id unset_vdev_id;
 #ifdef CONFIG_IOMMUFD_TEST
 	struct iommu_test_cmd test;
 #endif
@@ -387,6 +389,10 @@  static const struct iommufd_ioctl_op iommufd_ioctl_ops[] = {
 		 __reserved),
 	IOCTL_OP(IOMMU_VIOMMU_ALLOC, iommufd_viommu_alloc_ioctl,
 		 struct iommu_viommu_alloc, out_viommu_id),
+	IOCTL_OP(IOMMU_VIOMMU_SET_VDEV_ID, iommufd_viommu_set_vdev_id,
+		 struct iommu_viommu_set_vdev_id, vdev_id),
+	IOCTL_OP(IOMMU_VIOMMU_UNSET_VDEV_ID, iommufd_viommu_unset_vdev_id,
+		 struct iommu_viommu_unset_vdev_id, vdev_id),
 #ifdef CONFIG_IOMMUFD_TEST
 	IOCTL_OP(IOMMU_TEST_CMD, iommufd_test, struct iommu_test_cmd, last),
 #endif
diff --git a/drivers/iommu/iommufd/viommu.c b/drivers/iommu/iommufd/viommu.c
index 35ad6a77c9c1..05a688a471db 100644
--- a/drivers/iommu/iommufd/viommu.c
+++ b/drivers/iommu/iommufd/viommu.c
@@ -10,7 +10,14 @@  void iommufd_viommu_destroy(struct iommufd_object *obj)
 {
 	struct iommufd_viommu *viommu =
 		container_of(obj, struct iommufd_viommu, obj);
+	struct iommufd_vdev_id *vdev_id;
+	unsigned long index;
 
+	xa_for_each(&viommu->vdev_ids, index, vdev_id) {
+		list_del(&vdev_id->idev_item);
+		kfree(vdev_id);
+	}
+	xa_destroy(&viommu->vdev_ids);
 	refcount_dec(&viommu->hwpt->common.obj.users);
 }
 
@@ -73,3 +80,114 @@  int iommufd_viommu_alloc_ioctl(struct iommufd_ucmd *ucmd)
 	iommufd_put_object(ucmd->ictx, &idev->obj);
 	return rc;
 }
+
+int iommufd_viommu_set_vdev_id(struct iommufd_ucmd *ucmd)
+{
+	struct iommu_viommu_set_vdev_id *cmd = ucmd->cmd;
+	struct iommufd_hwpt_nested *hwpt_nested;
+	struct iommufd_vdev_id *vdev_id, *curr;
+	struct iommufd_hw_pagetable *hwpt;
+	struct iommufd_viommu *viommu;
+	struct iommufd_device *idev;
+	int rc = 0;
+
+	if (cmd->vdev_id > ULONG_MAX)
+		return -EINVAL;
+
+	idev = iommufd_get_device(ucmd, cmd->dev_id);
+	if (IS_ERR(idev))
+		return PTR_ERR(idev);
+	hwpt = idev->igroup->hwpt;
+
+	if (hwpt == NULL || hwpt->obj.type != IOMMUFD_OBJ_HWPT_NESTED) {
+		rc = -EINVAL;
+		goto out_put_idev;
+	}
+	hwpt_nested = container_of(hwpt, struct iommufd_hwpt_nested, common);
+
+	viommu = iommufd_get_viommu(ucmd, cmd->viommu_id);
+	if (IS_ERR(viommu)) {
+		rc = PTR_ERR(viommu);
+		goto out_put_idev;
+	}
+
+	if (hwpt_nested->viommu != viommu) {
+		rc = -EINVAL;
+		goto out_put_viommu;
+	}
+
+	vdev_id = kzalloc(sizeof(*vdev_id), GFP_KERNEL);
+	if (IS_ERR(vdev_id)) {
+		rc = PTR_ERR(vdev_id);
+		goto out_put_viommu;
+	}
+
+	vdev_id->viommu = viommu;
+	vdev_id->dev = idev->dev;
+	vdev_id->vdev_id = cmd->vdev_id;
+
+	curr = xa_cmpxchg(&viommu->vdev_ids, cmd->vdev_id,
+			  NULL, vdev_id, GFP_KERNEL);
+	if (curr) {
+		rc = xa_err(curr) ? : -EBUSY;
+		goto out_free_vdev_id;
+	}
+
+	list_add_tail(&vdev_id->idev_item, &idev->vdev_id_list);
+	goto out_put_viommu;
+
+out_free_vdev_id:
+	kfree(vdev_id);
+out_put_viommu:
+	iommufd_put_object(ucmd->ictx, &viommu->obj);
+out_put_idev:
+	iommufd_put_object(ucmd->ictx, &idev->obj);
+	return rc;
+}
+
+static struct device *
+iommufd_viommu_find_device(struct iommufd_viommu *viommu, u64 id)
+{
+	struct iommufd_vdev_id *vdev_id;
+
+	xa_lock(&viommu->vdev_ids);
+	vdev_id = xa_load(&viommu->vdev_ids, (unsigned long)id);
+	xa_unlock(&viommu->vdev_ids);
+	if (!vdev_id || vdev_id->vdev_id != id)
+		return NULL;
+	return vdev_id->dev;
+}
+
+int iommufd_viommu_unset_vdev_id(struct iommufd_ucmd *ucmd)
+{
+	struct iommu_viommu_unset_vdev_id *cmd = ucmd->cmd;
+	struct iommufd_vdev_id *vdev_id;
+	struct iommufd_viommu *viommu;
+	struct iommufd_device *idev;
+	int rc = 0;
+
+	idev = iommufd_get_device(ucmd, cmd->dev_id);
+	if (IS_ERR(idev))
+		return PTR_ERR(idev);
+
+	viommu = iommufd_get_viommu(ucmd, cmd->viommu_id);
+	if (IS_ERR(viommu)) {
+		rc = PTR_ERR(viommu);
+		goto out_put_idev;
+	}
+
+	if (idev->dev != iommufd_viommu_find_device(viommu, cmd->vdev_id)) {
+		rc = -EINVAL;
+		goto out_put_viommu;
+	}
+
+	vdev_id = xa_erase(&viommu->vdev_ids, cmd->vdev_id);
+	list_del(&vdev_id->idev_item);
+	kfree(vdev_id);
+
+out_put_viommu:
+	iommufd_put_object(ucmd->ictx, &viommu->obj);
+out_put_idev:
+	iommufd_put_object(ucmd->ictx, &idev->obj);
+	return rc;
+}
diff --git a/include/uapi/linux/iommufd.h b/include/uapi/linux/iommufd.h
index 0e384331a9c8..d5e72682ba57 100644
--- a/include/uapi/linux/iommufd.h
+++ b/include/uapi/linux/iommufd.h
@@ -52,6 +52,8 @@  enum {
 	IOMMUFD_CMD_HWPT_INVALIDATE = 0x8d,
 	IOMMUFD_CMD_FAULT_QUEUE_ALLOC = 0x8e,
 	IOMMUFD_CMD_VIOMMU_ALLOC = 0x8f,
+	IOMMUFD_CMD_VIOMMU_SET_VDEV_ID = 0x90,
+	IOMMUFD_CMD_VIOMMU_UNSET_VDEV_ID = 0x91,
 };
 
 /**
@@ -906,4 +908,42 @@  struct iommu_viommu_alloc {
 	__u32 out_viommu_id;
 };
 #define IOMMU_VIOMMU_ALLOC _IO(IOMMUFD_TYPE, IOMMUFD_CMD_VIOMMU_ALLOC)
+
+/**
+ * struct iommu_viommu_set_vdev_id - ioctl(IOMMU_VIOMMU_SET_VDEV_ID)
+ * @size: sizeof(struct iommu_viommu_set_vdev_id)
+ * @viommu_id: viommu ID to associate with the device to store its virtual ID
+ * @dev_id: device ID to set its virtual ID
+ * @__reserved: Must be 0
+ * @vdev_id: Virtual device ID
+ *
+ * Set a viommu-specific virtual ID of a device
+ */
+struct iommu_viommu_set_vdev_id {
+	__u32 size;
+	__u32 viommu_id;
+	__u32 dev_id;
+	__u32 __reserved;
+	__aligned_u64 vdev_id;
+};
+#define IOMMU_VIOMMU_SET_VDEV_ID _IO(IOMMUFD_TYPE, IOMMUFD_CMD_VIOMMU_SET_VDEV_ID)
+
+/**
+ * struct iommu_viommu_unset_vdev_id - ioctl(IOMMU_VIOMMU_UNSET_VDEV_ID)
+ * @size: sizeof(struct iommu_viommu_unset_vdev_id)
+ * @viommu_id: viommu ID associated with the device to delete its virtual ID
+ * @dev_id: device ID to unset its virtual ID
+ * @__reserved: Must be 0
+ * @vdev_id: Virtual device ID (for verification)
+ *
+ * Unset a viommu-specific virtual ID of a device
+ */
+struct iommu_viommu_unset_vdev_id {
+	__u32 size;
+	__u32 viommu_id;
+	__u32 dev_id;
+	__u32 __reserved;
+	__aligned_u64 vdev_id;
+};
+#define IOMMU_VIOMMU_UNSET_VDEV_ID _IO(IOMMUFD_TYPE, IOMMUFD_CMD_VIOMMU_UNSET_VDEV_ID)
 #endif