diff mbox series

[v4,01/10] iommu: Introduce a replace API for device pasid

Message ID 20240912131255.13305-2-yi.l.liu@intel.com
State New
Headers show
Series iommufd support pasid attach/replace | expand

Commit Message

Yi Liu Sept. 12, 2024, 1:12 p.m. UTC
Provide a high-level API to allow replacements of one domain with
another for specific pasid of a device. This is similar to
iommu_group_replace_domain() and it is expected to be used only by
IOMMUFD.

Co-developed-by: Lu Baolu <baolu.lu@linux.intel.com>
Signed-off-by: Lu Baolu <baolu.lu@linux.intel.com>
Signed-off-by: Yi Liu <yi.l.liu@intel.com>
---
 drivers/iommu/iommu-priv.h |  4 ++
 drivers/iommu/iommu.c      | 90 ++++++++++++++++++++++++++++++++++++--
 2 files changed, 90 insertions(+), 4 deletions(-)

Comments

Baolu Lu Sept. 13, 2024, 2:44 a.m. UTC | #1
On 9/12/24 9:12 PM, Yi Liu wrote:
> Provide a high-level API to allow replacements of one domain with
> another for specific pasid of a device. This is similar to
> iommu_group_replace_domain() and it is expected to be used only by
> IOMMUFD.
> 
> Co-developed-by: Lu Baolu <baolu.lu@linux.intel.com>
> Signed-off-by: Lu Baolu <baolu.lu@linux.intel.com>
> Signed-off-by: Yi Liu <yi.l.liu@intel.com>
> ---
>   drivers/iommu/iommu-priv.h |  4 ++
>   drivers/iommu/iommu.c      | 90 ++++++++++++++++++++++++++++++++++++--
>   2 files changed, 90 insertions(+), 4 deletions(-)
> 
> diff --git a/drivers/iommu/iommu-priv.h b/drivers/iommu/iommu-priv.h
> index de5b54eaa8bf..90b367de267e 100644
> --- a/drivers/iommu/iommu-priv.h
> +++ b/drivers/iommu/iommu-priv.h
> @@ -27,6 +27,10 @@ static inline const struct iommu_ops *iommu_fwspec_ops(struct iommu_fwspec *fwsp
>   int iommu_group_replace_domain(struct iommu_group *group,
>   			       struct iommu_domain *new_domain);
>   
> +int iommu_replace_device_pasid(struct iommu_domain *domain,
> +			       struct device *dev, ioasid_t pasid,
> +			       struct iommu_attach_handle *handle);
> +
>   int iommu_device_register_bus(struct iommu_device *iommu,
>   			      const struct iommu_ops *ops,
>   			      const struct bus_type *bus,
> diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
> index b6b44b184004..066f659018a5 100644
> --- a/drivers/iommu/iommu.c
> +++ b/drivers/iommu/iommu.c
> @@ -3347,14 +3347,15 @@ static void iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
>   }
>   
>   static int __iommu_set_group_pasid(struct iommu_domain *domain,
> -				   struct iommu_group *group, ioasid_t pasid)
> +				   struct iommu_group *group, ioasid_t pasid,
> +				   struct iommu_domain *old)
>   {
>   	struct group_device *device, *last_gdev;
>   	int ret;
>   
>   	for_each_group_device(group, device) {
>   		ret = domain->ops->set_dev_pasid(domain, device->dev,
> -						 pasid, NULL);
> +						 pasid, old);
>   		if (ret)
>   			goto err_revert;
>   	}
> @@ -3366,7 +3367,20 @@ static int __iommu_set_group_pasid(struct iommu_domain *domain,
>   	for_each_group_device(group, device) {
>   		if (device == last_gdev)
>   			break;
> -		iommu_remove_dev_pasid(device->dev, pasid, domain);
> +		/* If no old domain, undo the succeeded devices/pasid */
> +		if (!old) {
> +			iommu_remove_dev_pasid(device->dev, pasid, domain);
> +			continue;
> +		}
> +
> +		/*
> +		 * Rollback the succeeded devices/pasid to the old domain.
> +		 * And it is a driver bug to fail attaching with a previously
> +		 * good domain.
> +		 */
> +		if (WARN_ON(old->ops->set_dev_pasid(old, device->dev,
> +						    pasid, domain)))
> +			iommu_remove_dev_pasid(device->dev, pasid, domain);

You want to rollback to the 'old' domain, right? So, %s/domain/old/ ?

>   	}
>   	return ret;
>   }
> @@ -3425,7 +3439,7 @@ int iommu_attach_device_pasid(struct iommu_domain *domain,
>   	if (ret)
>   		goto out_unlock;
>   
> -	ret = __iommu_set_group_pasid(domain, group, pasid);
> +	ret = __iommu_set_group_pasid(domain, group, pasid, NULL);
>   	if (ret)
>   		xa_erase(&group->pasid_array, pasid);
>   out_unlock:
> @@ -3434,6 +3448,74 @@ int iommu_attach_device_pasid(struct iommu_domain *domain,
>   }
>   EXPORT_SYMBOL_GPL(iommu_attach_device_pasid);
>   
> +/**
> + * iommu_replace_device_pasid - Replace the domain that a pasid is attached to
> + * @domain: the new iommu domain
> + * @dev: the attached device.
> + * @pasid: the pasid of the device.
> + * @handle: the attach handle.
> + *
> + * This API allows the pasid to switch domains. Return 0 on success, or an
> + * error. The pasid will keep the old configuration if replacement failed.
> + * This is supposed to be used by iommufd, and iommufd can guarantee that
> + * both iommu_attach_device_pasid() and iommu_replace_device_pasid() would
> + * pass in a valid @handle.
> + */
> +int iommu_replace_device_pasid(struct iommu_domain *domain,
> +			       struct device *dev, ioasid_t pasid,
> +			       struct iommu_attach_handle *handle)

How about passing the old domain as a parameter?

> +{
> +	/* Caller must be a probed driver on dev */
> +	struct iommu_group *group = dev->iommu_group;
> +	struct iommu_attach_handle *curr;
> +	int ret;
> +
> +	if (!domain->ops->set_dev_pasid)
> +		return -EOPNOTSUPP;
> +
> +	if (!group)
> +		return -ENODEV;
> +
> +	if (!dev_has_iommu(dev) || dev_iommu_ops(dev) != domain->owner ||
> +	    pasid == IOMMU_NO_PASID || !handle)

dev_has_iommu() check is duplicate with above if (!group) check.

By the way, why do you require a non-NULL attach handle? In the current
design, attach handles are only used for domains with iopf capability.

> +		return -EINVAL;
> +
> +	handle->domain = domain;
> +
> +	mutex_lock(&group->mutex);
> +	/*
> +	 * The iommu_attach_handle of the pasid becomes inconsistent with the
> +	 * actual handle per the below operation. The concurrent PRI path will
> +	 * deliver the PRQs per the new handle, this does not have a function
> +	 * impact. The PRI path would eventually become consistent when the
> +	 * replacement is done.
> +	 */
> +	curr = (struct iommu_attach_handle *)xa_store(&group->pasid_array,
> +						      pasid, handle,
> +						      GFP_KERNEL);
> +	if (!curr) {
> +		xa_erase(&group->pasid_array, pasid);
> +		ret = -EINVAL;
> +		goto out_unlock;
> +	}

This seems to be broken as explained above. The attach handle is
currently only for iopf-capable domains.

If I understand it correctly, you just want the previous attached domain
here, right? If so, why not just passing it to this helper from callers?

> +
> +	ret = xa_err(curr);
> +	if (ret)
> +		goto out_unlock;
> +
> +	if (curr->domain == domain)
> +		goto out_unlock;
> +
> +	ret = __iommu_set_group_pasid(domain, group, pasid, curr->domain);
> +	if (ret)
> +		WARN_ON(handle != xa_store(&group->pasid_array, pasid,
> +					   curr, GFP_KERNEL));
> +out_unlock:
> +	mutex_unlock(&group->mutex);
> +	return ret;
> +}
> +EXPORT_SYMBOL_NS_GPL(iommu_replace_device_pasid, IOMMUFD_INTERNAL);
> +
>   /*
>    * iommu_detach_device_pasid() - Detach the domain from pasid of device
>    * @domain: the iommu domain.

Thanks,
baolu
Yi Liu Sept. 13, 2024, 12:04 p.m. UTC | #2
On 2024/9/13 10:44, Baolu Lu wrote:
> On 9/12/24 9:12 PM, Yi Liu wrote:
>> Provide a high-level API to allow replacements of one domain with
>> another for specific pasid of a device. This is similar to
>> iommu_group_replace_domain() and it is expected to be used only by
>> IOMMUFD.
>>
>> Co-developed-by: Lu Baolu <baolu.lu@linux.intel.com>
>> Signed-off-by: Lu Baolu <baolu.lu@linux.intel.com>
>> Signed-off-by: Yi Liu <yi.l.liu@intel.com>
>> ---
>>   drivers/iommu/iommu-priv.h |  4 ++
>>   drivers/iommu/iommu.c      | 90 ++++++++++++++++++++++++++++++++++++--
>>   2 files changed, 90 insertions(+), 4 deletions(-)
>>
>> diff --git a/drivers/iommu/iommu-priv.h b/drivers/iommu/iommu-priv.h
>> index de5b54eaa8bf..90b367de267e 100644
>> --- a/drivers/iommu/iommu-priv.h
>> +++ b/drivers/iommu/iommu-priv.h
>> @@ -27,6 +27,10 @@ static inline const struct iommu_ops 
>> *iommu_fwspec_ops(struct iommu_fwspec *fwsp
>>   int iommu_group_replace_domain(struct iommu_group *group,
>>                      struct iommu_domain *new_domain);
>> +int iommu_replace_device_pasid(struct iommu_domain *domain,
>> +                   struct device *dev, ioasid_t pasid,
>> +                   struct iommu_attach_handle *handle);
>> +
>>   int iommu_device_register_bus(struct iommu_device *iommu,
>>                     const struct iommu_ops *ops,
>>                     const struct bus_type *bus,
>> diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
>> index b6b44b184004..066f659018a5 100644
>> --- a/drivers/iommu/iommu.c
>> +++ b/drivers/iommu/iommu.c
>> @@ -3347,14 +3347,15 @@ static void iommu_remove_dev_pasid(struct device 
>> *dev, ioasid_t pasid,
>>   }
>>   static int __iommu_set_group_pasid(struct iommu_domain *domain,
>> -                   struct iommu_group *group, ioasid_t pasid)
>> +                   struct iommu_group *group, ioasid_t pasid,
>> +                   struct iommu_domain *old)
>>   {
>>       struct group_device *device, *last_gdev;
>>       int ret;
>>       for_each_group_device(group, device) {
>>           ret = domain->ops->set_dev_pasid(domain, device->dev,
>> -                         pasid, NULL);
>> +                         pasid, old);
>>           if (ret)
>>               goto err_revert;
>>       }
>> @@ -3366,7 +3367,20 @@ static int __iommu_set_group_pasid(struct 
>> iommu_domain *domain,
>>       for_each_group_device(group, device) {
>>           if (device == last_gdev)
>>               break;
>> -        iommu_remove_dev_pasid(device->dev, pasid, domain);
>> +        /* If no old domain, undo the succeeded devices/pasid */
>> +        if (!old) {
>> +            iommu_remove_dev_pasid(device->dev, pasid, domain);
>> +            continue;
>> +        }
>> +
>> +        /*
>> +         * Rollback the succeeded devices/pasid to the old domain.
>> +         * And it is a driver bug to fail attaching with a previously
>> +         * good domain.
>> +         */
>> +        if (WARN_ON(old->ops->set_dev_pasid(old, device->dev,
>> +                            pasid, domain)))
>> +            iommu_remove_dev_pasid(device->dev, pasid, domain);
> 
> You want to rollback to the 'old' domain, right? So, %s/domain/old/ ?

this will be invoked if the rollback failed. Since the set_dev_pasid op
would keep the 'old' configure, so at this point, the 'old' domain is 'domain'.

>>       }
>>       return ret;
>>   }
>> @@ -3425,7 +3439,7 @@ int iommu_attach_device_pasid(struct iommu_domain 
>> *domain,
>>       if (ret)
>>           goto out_unlock;
>> -    ret = __iommu_set_group_pasid(domain, group, pasid);
>> +    ret = __iommu_set_group_pasid(domain, group, pasid, NULL);
>>       if (ret)
>>           xa_erase(&group->pasid_array, pasid);
>>   out_unlock:
>> @@ -3434,6 +3448,74 @@ int iommu_attach_device_pasid(struct iommu_domain 
>> *domain,
>>   }
>>   EXPORT_SYMBOL_GPL(iommu_attach_device_pasid);
>> +/**
>> + * iommu_replace_device_pasid - Replace the domain that a pasid is 
>> attached to
>> + * @domain: the new iommu domain
>> + * @dev: the attached device.
>> + * @pasid: the pasid of the device.
>> + * @handle: the attach handle.
>> + *
>> + * This API allows the pasid to switch domains. Return 0 on success, or an
>> + * error. The pasid will keep the old configuration if replacement failed.
>> + * This is supposed to be used by iommufd, and iommufd can guarantee that
>> + * both iommu_attach_device_pasid() and iommu_replace_device_pasid() would
>> + * pass in a valid @handle.
>> + */
>> +int iommu_replace_device_pasid(struct iommu_domain *domain,
>> +                   struct device *dev, ioasid_t pasid,
>> +                   struct iommu_attach_handle *handle)
> 
> How about passing the old domain as a parameter?

I suppose it was agreed in the below link.

https://lore.kernel.org/linux-iommu/20240816124707.GZ2032816@nvidia.com/

>> +{
>> +    /* Caller must be a probed driver on dev */
>> +    struct iommu_group *group = dev->iommu_group;
>> +    struct iommu_attach_handle *curr;
>> +    int ret;
>> +
>> +    if (!domain->ops->set_dev_pasid)
>> +        return -EOPNOTSUPP;
>> +
>> +    if (!group)
>> +        return -ENODEV;
>> +
>> +    if (!dev_has_iommu(dev) || dev_iommu_ops(dev) != domain->owner ||
>> +        pasid == IOMMU_NO_PASID || !handle)
> 
> dev_has_iommu() check is duplicate with above if (!group) check.

I was just referring to the iommu_attach_device_pasid(). So both the two
path could drop the dev_has_iommu() check, is it?

> By the way, why do you require a non-NULL attach handle? In the current
> design, attach handles are only used for domains with iopf capability.

yeah, but it looks fine to always pass in an attach handle. The iopf
path would require hwpt->domain->iopf_handler.

>> +        return -EINVAL;
>> +
>> +    handle->domain = domain;
>> +
>> +    mutex_lock(&group->mutex);
>> +    /*
>> +     * The iommu_attach_handle of the pasid becomes inconsistent with the
>> +     * actual handle per the below operation. The concurrent PRI path will
>> +     * deliver the PRQs per the new handle, this does not have a function
>> +     * impact. The PRI path would eventually become consistent when the
>> +     * replacement is done.
>> +     */
>> +    curr = (struct iommu_attach_handle *)xa_store(&group->pasid_array,
>> +                              pasid, handle,
>> +                              GFP_KERNEL);
>> +    if (!curr) {
>> +        xa_erase(&group->pasid_array, pasid);
>> +        ret = -EINVAL;
>> +        goto out_unlock;
>> +    }
> 
> This seems to be broken as explained above. The attach handle is
> currently only for iopf-capable domains.

if attach handle is always passed, then this is not broken. is it?

> If I understand it correctly, you just want the previous attached domain
> here, right? If so, why not just passing it to this helper from callers?

yeah, I'm open about it. :) @Jason, your opinion?

>> +
>> +    ret = xa_err(curr);
>> +    if (ret)
>> +        goto out_unlock;
>> +
>> +    if (curr->domain == domain)
>> +        goto out_unlock;
>> +
>> +    ret = __iommu_set_group_pasid(domain, group, pasid, curr->domain);
>> +    if (ret)
>> +        WARN_ON(handle != xa_store(&group->pasid_array, pasid,
>> +                       curr, GFP_KERNEL));
>> +out_unlock:
>> +    mutex_unlock(&group->mutex);
>> +    return ret;
>> +}
>> +EXPORT_SYMBOL_NS_GPL(iommu_replace_device_pasid, IOMMUFD_INTERNAL);
>> +
>>   /*
>>    * iommu_detach_device_pasid() - Detach the domain from pasid of device
>>    * @domain: the iommu domain.
Tian, Kevin Sept. 30, 2024, 7:38 a.m. UTC | #3
> From: Liu, Yi L <yi.l.liu@intel.com>
> Sent: Thursday, September 12, 2024 9:13 PM
> 
> +/**
> + * iommu_replace_device_pasid - Replace the domain that a pasid is
> attached to
> + * @domain: the new iommu domain
> + * @dev: the attached device.
> + * @pasid: the pasid of the device.
> + * @handle: the attach handle.
> + *
> + * This API allows the pasid to switch domains. Return 0 on success, or an
> + * error. The pasid will keep the old configuration if replacement failed.
> + * This is supposed to be used by iommufd, and iommufd can guarantee
> that
> + * both iommu_attach_device_pasid() and iommu_replace_device_pasid()
> would
> + * pass in a valid @handle.

this function assumes handle is always valid. So above comment
makes it clear that iommufd is the only user and it will always
pass in a valid handle.

but the code in iommu_attach_device_pasid() allows handle to
be NULL. Then that comment is meaningless for it.

Also following patches are built on iommufd always passing in
a valid handle as it's required by pasid operations but there is
no detail explanation why it's mandatory or any alternative 
option exists. More explanation is welcomed.

> + */
> +int iommu_replace_device_pasid(struct iommu_domain *domain,
> +			       struct device *dev, ioasid_t pasid,
> +			       struct iommu_attach_handle *handle)
> +{
> +	/* Caller must be a probed driver on dev */
> +	struct iommu_group *group = dev->iommu_group;
> +	struct iommu_attach_handle *curr;
> +	int ret;
> +
> +	if (!domain->ops->set_dev_pasid)
> +		return -EOPNOTSUPP;
> +
> +	if (!group)
> +		return -ENODEV;
> +
> +	if (!dev_has_iommu(dev) || dev_iommu_ops(dev) != domain-
> >owner ||
> +	    pasid == IOMMU_NO_PASID || !handle)
> +		return -EINVAL;
> +
> +	handle->domain = domain;
> +
> +	mutex_lock(&group->mutex);
> +	/*
> +	 * The iommu_attach_handle of the pasid becomes inconsistent with
> the
> +	 * actual handle per the below operation. The concurrent PRI path
> will
> +	 * deliver the PRQs per the new handle, this does not have a function
> +	 * impact. The PRI path would eventually become consistent when

s/function/functional/

> the
> +	 * replacement is done.
> +	 */
> +	curr = (struct iommu_attach_handle *)xa_store(&group->pasid_array,
> +						      pasid, handle,
> +						      GFP_KERNEL);

Could you elaborate why the PRI path will eventually becomes
consistent with this path?
Yi Liu Oct. 12, 2024, 4:31 a.m. UTC | #4
On 2024/9/30 15:38, Tian, Kevin wrote:
>> From: Liu, Yi L <yi.l.liu@intel.com>
>> Sent: Thursday, September 12, 2024 9:13 PM
>>
>> +/**
>> + * iommu_replace_device_pasid - Replace the domain that a pasid is
>> attached to
>> + * @domain: the new iommu domain
>> + * @dev: the attached device.
>> + * @pasid: the pasid of the device.
>> + * @handle: the attach handle.
>> + *
>> + * This API allows the pasid to switch domains. Return 0 on success, or an
>> + * error. The pasid will keep the old configuration if replacement failed.
>> + * This is supposed to be used by iommufd, and iommufd can guarantee
>> that
>> + * both iommu_attach_device_pasid() and iommu_replace_device_pasid()
>> would
>> + * pass in a valid @handle.
> 
> this function assumes handle is always valid. So above comment
> makes it clear that iommufd is the only user and it will always
> pass in a valid handle.
> 
> but the code in iommu_attach_device_pasid() allows handle to
> be NULL. Then that comment is meaningless for it.

Actually, this is why I added the above comment. iommufd can ensure
it would pass valid handle to both iommu_attach_device_pasid() and
iommu_replace_device_pasid(), and iommu_replace_device_pasid() is
only used by iommufd, so iommu_replace_device_pasid() can assume
all the pasids have a valid handle stored in the pasid_array.

> 
> Also following patches are built on iommufd always passing in
> a valid handle as it's required by pasid operations but there is
> no detail explanation why it's mandatory or any alternative
> option exists. More explanation is welcomed.

There is more detail about it in the below link, but is it necessary
to add them in the comment as well, or is it ok to add more explanation
in commit message?

https://lore.kernel.org/linux-iommu/0bf383b7-ed96-49ca-b1da-d1fff48e161a@intel.com/

>> + */
>> +int iommu_replace_device_pasid(struct iommu_domain *domain,
>> +			       struct device *dev, ioasid_t pasid,
>> +			       struct iommu_attach_handle *handle)
>> +{
>> +	/* Caller must be a probed driver on dev */
>> +	struct iommu_group *group = dev->iommu_group;
>> +	struct iommu_attach_handle *curr;
>> +	int ret;
>> +
>> +	if (!domain->ops->set_dev_pasid)
>> +		return -EOPNOTSUPP;
>> +
>> +	if (!group)
>> +		return -ENODEV;
>> +
>> +	if (!dev_has_iommu(dev) || dev_iommu_ops(dev) != domain-
>>> owner ||
>> +	    pasid == IOMMU_NO_PASID || !handle)
>> +		return -EINVAL;
>> +
>> +	handle->domain = domain;
>> +
>> +	mutex_lock(&group->mutex);
>> +	/*
>> +	 * The iommu_attach_handle of the pasid becomes inconsistent with
>> the
>> +	 * actual handle per the below operation. The concurrent PRI path
>> will
>> +	 * deliver the PRQs per the new handle, this does not have a function
>> +	 * impact. The PRI path would eventually become consistent when
> 
> s/function/functional/

got it.

>> the
>> +	 * replacement is done.
>> +	 */
>> +	curr = (struct iommu_attach_handle *)xa_store(&group->pasid_array,
>> +						      pasid, handle,
>> +						      GFP_KERNEL);
> 
> Could you elaborate why the PRI path will eventually becomes
> consistent with this path?

Because the handle stored in pasid_array would be consistent with the
configuration of pasid. So the PRI would be forwarded to the correct
domain.
diff mbox series

Patch

diff --git a/drivers/iommu/iommu-priv.h b/drivers/iommu/iommu-priv.h
index de5b54eaa8bf..90b367de267e 100644
--- a/drivers/iommu/iommu-priv.h
+++ b/drivers/iommu/iommu-priv.h
@@ -27,6 +27,10 @@  static inline const struct iommu_ops *iommu_fwspec_ops(struct iommu_fwspec *fwsp
 int iommu_group_replace_domain(struct iommu_group *group,
 			       struct iommu_domain *new_domain);
 
+int iommu_replace_device_pasid(struct iommu_domain *domain,
+			       struct device *dev, ioasid_t pasid,
+			       struct iommu_attach_handle *handle);
+
 int iommu_device_register_bus(struct iommu_device *iommu,
 			      const struct iommu_ops *ops,
 			      const struct bus_type *bus,
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index b6b44b184004..066f659018a5 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -3347,14 +3347,15 @@  static void iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
 }
 
 static int __iommu_set_group_pasid(struct iommu_domain *domain,
-				   struct iommu_group *group, ioasid_t pasid)
+				   struct iommu_group *group, ioasid_t pasid,
+				   struct iommu_domain *old)
 {
 	struct group_device *device, *last_gdev;
 	int ret;
 
 	for_each_group_device(group, device) {
 		ret = domain->ops->set_dev_pasid(domain, device->dev,
-						 pasid, NULL);
+						 pasid, old);
 		if (ret)
 			goto err_revert;
 	}
@@ -3366,7 +3367,20 @@  static int __iommu_set_group_pasid(struct iommu_domain *domain,
 	for_each_group_device(group, device) {
 		if (device == last_gdev)
 			break;
-		iommu_remove_dev_pasid(device->dev, pasid, domain);
+		/* If no old domain, undo the succeeded devices/pasid */
+		if (!old) {
+			iommu_remove_dev_pasid(device->dev, pasid, domain);
+			continue;
+		}
+
+		/*
+		 * Rollback the succeeded devices/pasid to the old domain.
+		 * And it is a driver bug to fail attaching with a previously
+		 * good domain.
+		 */
+		if (WARN_ON(old->ops->set_dev_pasid(old, device->dev,
+						    pasid, domain)))
+			iommu_remove_dev_pasid(device->dev, pasid, domain);
 	}
 	return ret;
 }
@@ -3425,7 +3439,7 @@  int iommu_attach_device_pasid(struct iommu_domain *domain,
 	if (ret)
 		goto out_unlock;
 
-	ret = __iommu_set_group_pasid(domain, group, pasid);
+	ret = __iommu_set_group_pasid(domain, group, pasid, NULL);
 	if (ret)
 		xa_erase(&group->pasid_array, pasid);
 out_unlock:
@@ -3434,6 +3448,74 @@  int iommu_attach_device_pasid(struct iommu_domain *domain,
 }
 EXPORT_SYMBOL_GPL(iommu_attach_device_pasid);
 
+/**
+ * iommu_replace_device_pasid - Replace the domain that a pasid is attached to
+ * @domain: the new iommu domain
+ * @dev: the attached device.
+ * @pasid: the pasid of the device.
+ * @handle: the attach handle.
+ *
+ * This API allows the pasid to switch domains. Return 0 on success, or an
+ * error. The pasid will keep the old configuration if replacement failed.
+ * This is supposed to be used by iommufd, and iommufd can guarantee that
+ * both iommu_attach_device_pasid() and iommu_replace_device_pasid() would
+ * pass in a valid @handle.
+ */
+int iommu_replace_device_pasid(struct iommu_domain *domain,
+			       struct device *dev, ioasid_t pasid,
+			       struct iommu_attach_handle *handle)
+{
+	/* Caller must be a probed driver on dev */
+	struct iommu_group *group = dev->iommu_group;
+	struct iommu_attach_handle *curr;
+	int ret;
+
+	if (!domain->ops->set_dev_pasid)
+		return -EOPNOTSUPP;
+
+	if (!group)
+		return -ENODEV;
+
+	if (!dev_has_iommu(dev) || dev_iommu_ops(dev) != domain->owner ||
+	    pasid == IOMMU_NO_PASID || !handle)
+		return -EINVAL;
+
+	handle->domain = domain;
+
+	mutex_lock(&group->mutex);
+	/*
+	 * The iommu_attach_handle of the pasid becomes inconsistent with the
+	 * actual handle per the below operation. The concurrent PRI path will
+	 * deliver the PRQs per the new handle, this does not have a function
+	 * impact. The PRI path would eventually become consistent when the
+	 * replacement is done.
+	 */
+	curr = (struct iommu_attach_handle *)xa_store(&group->pasid_array,
+						      pasid, handle,
+						      GFP_KERNEL);
+	if (!curr) {
+		xa_erase(&group->pasid_array, pasid);
+		ret = -EINVAL;
+		goto out_unlock;
+	}
+
+	ret = xa_err(curr);
+	if (ret)
+		goto out_unlock;
+
+	if (curr->domain == domain)
+		goto out_unlock;
+
+	ret = __iommu_set_group_pasid(domain, group, pasid, curr->domain);
+	if (ret)
+		WARN_ON(handle != xa_store(&group->pasid_array, pasid,
+					   curr, GFP_KERNEL));
+out_unlock:
+	mutex_unlock(&group->mutex);
+	return ret;
+}
+EXPORT_SYMBOL_NS_GPL(iommu_replace_device_pasid, IOMMUFD_INTERNAL);
+
 /*
  * iommu_detach_device_pasid() - Detach the domain from pasid of device
  * @domain: the iommu domain.