Commit 2a3d15f2 authored by Jason Gunthorpe's avatar Jason Gunthorpe Committed by Alex Williamson

vfio/mdev: Add missing typesafety around mdev_device

The mdev API should accept and pass a 'struct mdev_device *' in all
places, not pass a 'struct device *' and cast it internally with
to_mdev_device(). Particularly in its struct mdev_driver functions, the
whole point of a bus's struct device_driver wrapper is to provide type
safety compared to the default struct device_driver.

Further, the driver core standard is for bus drivers to expose their
device structure in their public headers that can be used with
container_of() inlines and '&foo->dev' to go between the class levels, and
'&foo->dev' to be used with dev_err/etc driver core helper functions. Move
'struct mdev_device' to mdev.h

Once done this allows moving some one instruction exported functions to
static inlines, which in turns allows removing one of the two grotesque
symbol_get()'s related to mdev in the core code.
Reviewed-by: default avatarKevin Tian <kevin.tian@intel.com>
Reviewed-by: default avatarCornelia Huck <cohuck@redhat.com>
Signed-off-by: default avatarJason Gunthorpe <jgg@nvidia.com>
Reviewed-by: default avatarChristoph Hellwig <hch@lst.de>
Message-Id: <3-v2-d36939638fc6+d54-vfio2_jgg@nvidia.com>
Signed-off-by: default avatarAlex Williamson <alex.williamson@redhat.com>
parent b5a1f892
...@@ -105,8 +105,8 @@ structure to represent a mediated device's driver:: ...@@ -105,8 +105,8 @@ structure to represent a mediated device's driver::
*/ */
struct mdev_driver { struct mdev_driver {
const char *name; const char *name;
int (*probe) (struct device *dev); int (*probe) (struct mdev_device *dev);
void (*remove) (struct device *dev); void (*remove) (struct mdev_device *dev);
struct device_driver driver; struct device_driver driver;
}; };
......
...@@ -33,36 +33,6 @@ struct device *mdev_parent_dev(struct mdev_device *mdev) ...@@ -33,36 +33,6 @@ struct device *mdev_parent_dev(struct mdev_device *mdev)
} }
EXPORT_SYMBOL(mdev_parent_dev); EXPORT_SYMBOL(mdev_parent_dev);
void *mdev_get_drvdata(struct mdev_device *mdev)
{
return mdev->driver_data;
}
EXPORT_SYMBOL(mdev_get_drvdata);
void mdev_set_drvdata(struct mdev_device *mdev, void *data)
{
mdev->driver_data = data;
}
EXPORT_SYMBOL(mdev_set_drvdata);
struct device *mdev_dev(struct mdev_device *mdev)
{
return &mdev->dev;
}
EXPORT_SYMBOL(mdev_dev);
struct mdev_device *mdev_from_dev(struct device *dev)
{
return dev_is_mdev(dev) ? to_mdev_device(dev) : NULL;
}
EXPORT_SYMBOL(mdev_from_dev);
const guid_t *mdev_uuid(struct mdev_device *mdev)
{
return &mdev->uuid;
}
EXPORT_SYMBOL(mdev_uuid);
/* Should be called holding parent_list_lock */ /* Should be called holding parent_list_lock */
static struct mdev_parent *__find_parent_device(struct device *dev) static struct mdev_parent *__find_parent_device(struct device *dev)
{ {
...@@ -107,7 +77,7 @@ static void mdev_device_remove_common(struct mdev_device *mdev) ...@@ -107,7 +77,7 @@ static void mdev_device_remove_common(struct mdev_device *mdev)
int ret; int ret;
type = to_mdev_type(mdev->type_kobj); type = to_mdev_type(mdev->type_kobj);
mdev_remove_sysfs_files(&mdev->dev, type); mdev_remove_sysfs_files(mdev, type);
device_del(&mdev->dev); device_del(&mdev->dev);
parent = mdev->parent; parent = mdev->parent;
lockdep_assert_held(&parent->unreg_sem); lockdep_assert_held(&parent->unreg_sem);
...@@ -122,12 +92,10 @@ static void mdev_device_remove_common(struct mdev_device *mdev) ...@@ -122,12 +92,10 @@ static void mdev_device_remove_common(struct mdev_device *mdev)
static int mdev_device_remove_cb(struct device *dev, void *data) static int mdev_device_remove_cb(struct device *dev, void *data)
{ {
if (dev_is_mdev(dev)) { struct mdev_device *mdev = mdev_from_dev(dev);
struct mdev_device *mdev;
mdev = to_mdev_device(dev); if (mdev)
mdev_device_remove_common(mdev); mdev_device_remove_common(mdev);
}
return 0; return 0;
} }
...@@ -332,7 +300,7 @@ int mdev_device_create(struct kobject *kobj, ...@@ -332,7 +300,7 @@ int mdev_device_create(struct kobject *kobj,
if (ret) if (ret)
goto add_fail; goto add_fail;
ret = mdev_create_sysfs_files(&mdev->dev, type); ret = mdev_create_sysfs_files(mdev, type);
if (ret) if (ret)
goto sysfs_fail; goto sysfs_fail;
...@@ -354,13 +322,11 @@ int mdev_device_create(struct kobject *kobj, ...@@ -354,13 +322,11 @@ int mdev_device_create(struct kobject *kobj,
return ret; return ret;
} }
int mdev_device_remove(struct device *dev) int mdev_device_remove(struct mdev_device *mdev)
{ {
struct mdev_device *mdev, *tmp; struct mdev_device *tmp;
struct mdev_parent *parent; struct mdev_parent *parent;
mdev = to_mdev_device(dev);
mutex_lock(&mdev_list_lock); mutex_lock(&mdev_list_lock);
list_for_each_entry(tmp, &mdev_list, next) { list_for_each_entry(tmp, &mdev_list, next) {
if (tmp == mdev) if (tmp == mdev)
...@@ -390,24 +356,6 @@ int mdev_device_remove(struct device *dev) ...@@ -390,24 +356,6 @@ int mdev_device_remove(struct device *dev)
return 0; return 0;
} }
int mdev_set_iommu_device(struct device *dev, struct device *iommu_device)
{
struct mdev_device *mdev = to_mdev_device(dev);
mdev->iommu_device = iommu_device;
return 0;
}
EXPORT_SYMBOL(mdev_set_iommu_device);
struct device *mdev_get_iommu_device(struct device *dev)
{
struct mdev_device *mdev = to_mdev_device(dev);
return mdev->iommu_device;
}
EXPORT_SYMBOL(mdev_get_iommu_device);
static int __init mdev_init(void) static int __init mdev_init(void)
{ {
return mdev_bus_register(); return mdev_bus_register();
......
...@@ -48,7 +48,7 @@ static int mdev_probe(struct device *dev) ...@@ -48,7 +48,7 @@ static int mdev_probe(struct device *dev)
return ret; return ret;
if (drv && drv->probe) { if (drv && drv->probe) {
ret = drv->probe(dev); ret = drv->probe(mdev);
if (ret) if (ret)
mdev_detach_iommu(mdev); mdev_detach_iommu(mdev);
} }
...@@ -62,7 +62,7 @@ static int mdev_remove(struct device *dev) ...@@ -62,7 +62,7 @@ static int mdev_remove(struct device *dev)
struct mdev_device *mdev = to_mdev_device(dev); struct mdev_device *mdev = to_mdev_device(dev);
if (drv && drv->remove) if (drv && drv->remove)
drv->remove(dev); drv->remove(mdev);
mdev_detach_iommu(mdev); mdev_detach_iommu(mdev);
......
...@@ -24,23 +24,6 @@ struct mdev_parent { ...@@ -24,23 +24,6 @@ struct mdev_parent {
struct rw_semaphore unreg_sem; struct rw_semaphore unreg_sem;
}; };
struct mdev_device {
struct device dev;
struct mdev_parent *parent;
guid_t uuid;
void *driver_data;
struct list_head next;
struct kobject *type_kobj;
struct device *iommu_device;
bool active;
};
static inline struct mdev_device *to_mdev_device(struct device *dev)
{
return container_of(dev, struct mdev_device, dev);
}
#define dev_is_mdev(d) ((d)->bus == &mdev_bus_type)
struct mdev_type { struct mdev_type {
struct kobject kobj; struct kobject kobj;
struct kobject *devices_kobj; struct kobject *devices_kobj;
...@@ -57,11 +40,11 @@ struct mdev_type { ...@@ -57,11 +40,11 @@ struct mdev_type {
int parent_create_sysfs_files(struct mdev_parent *parent); int parent_create_sysfs_files(struct mdev_parent *parent);
void parent_remove_sysfs_files(struct mdev_parent *parent); void parent_remove_sysfs_files(struct mdev_parent *parent);
int mdev_create_sysfs_files(struct device *dev, struct mdev_type *type); int mdev_create_sysfs_files(struct mdev_device *mdev, struct mdev_type *type);
void mdev_remove_sysfs_files(struct device *dev, struct mdev_type *type); void mdev_remove_sysfs_files(struct mdev_device *mdev, struct mdev_type *type);
int mdev_device_create(struct kobject *kobj, int mdev_device_create(struct kobject *kobj,
struct device *dev, const guid_t *uuid); struct device *dev, const guid_t *uuid);
int mdev_device_remove(struct device *dev); int mdev_device_remove(struct mdev_device *dev);
#endif /* MDEV_PRIVATE_H */ #endif /* MDEV_PRIVATE_H */
...@@ -225,6 +225,7 @@ int parent_create_sysfs_files(struct mdev_parent *parent) ...@@ -225,6 +225,7 @@ int parent_create_sysfs_files(struct mdev_parent *parent)
static ssize_t remove_store(struct device *dev, struct device_attribute *attr, static ssize_t remove_store(struct device *dev, struct device_attribute *attr,
const char *buf, size_t count) const char *buf, size_t count)
{ {
struct mdev_device *mdev = to_mdev_device(dev);
unsigned long val; unsigned long val;
if (kstrtoul(buf, 0, &val) < 0) if (kstrtoul(buf, 0, &val) < 0)
...@@ -233,7 +234,7 @@ static ssize_t remove_store(struct device *dev, struct device_attribute *attr, ...@@ -233,7 +234,7 @@ static ssize_t remove_store(struct device *dev, struct device_attribute *attr,
if (val && device_remove_file_self(dev, attr)) { if (val && device_remove_file_self(dev, attr)) {
int ret; int ret;
ret = mdev_device_remove(dev); ret = mdev_device_remove(mdev);
if (ret) if (ret)
return ret; return ret;
} }
...@@ -248,34 +249,37 @@ static const struct attribute *mdev_device_attrs[] = { ...@@ -248,34 +249,37 @@ static const struct attribute *mdev_device_attrs[] = {
NULL, NULL,
}; };
int mdev_create_sysfs_files(struct device *dev, struct mdev_type *type) int mdev_create_sysfs_files(struct mdev_device *mdev, struct mdev_type *type)
{ {
struct kobject *kobj = &mdev->dev.kobj;
int ret; int ret;
ret = sysfs_create_link(type->devices_kobj, &dev->kobj, dev_name(dev)); ret = sysfs_create_link(type->devices_kobj, kobj, dev_name(&mdev->dev));
if (ret) if (ret)
return ret; return ret;
ret = sysfs_create_link(&dev->kobj, &type->kobj, "mdev_type"); ret = sysfs_create_link(kobj, &type->kobj, "mdev_type");
if (ret) if (ret)
goto type_link_failed; goto type_link_failed;
ret = sysfs_create_files(&dev->kobj, mdev_device_attrs); ret = sysfs_create_files(kobj, mdev_device_attrs);
if (ret) if (ret)
goto create_files_failed; goto create_files_failed;
return ret; return ret;
create_files_failed: create_files_failed:
sysfs_remove_link(&dev->kobj, "mdev_type"); sysfs_remove_link(kobj, "mdev_type");
type_link_failed: type_link_failed:
sysfs_remove_link(type->devices_kobj, dev_name(dev)); sysfs_remove_link(type->devices_kobj, dev_name(&mdev->dev));
return ret; return ret;
} }
void mdev_remove_sysfs_files(struct device *dev, struct mdev_type *type) void mdev_remove_sysfs_files(struct mdev_device *mdev, struct mdev_type *type)
{ {
sysfs_remove_files(&dev->kobj, mdev_device_attrs); struct kobject *kobj = &mdev->dev.kobj;
sysfs_remove_link(&dev->kobj, "mdev_type");
sysfs_remove_link(type->devices_kobj, dev_name(dev)); sysfs_remove_files(kobj, mdev_device_attrs);
sysfs_remove_link(kobj, "mdev_type");
sysfs_remove_link(type->devices_kobj, dev_name(&mdev->dev));
} }
...@@ -124,9 +124,8 @@ static const struct vfio_device_ops vfio_mdev_dev_ops = { ...@@ -124,9 +124,8 @@ static const struct vfio_device_ops vfio_mdev_dev_ops = {
.request = vfio_mdev_request, .request = vfio_mdev_request,
}; };
static int vfio_mdev_probe(struct device *dev) static int vfio_mdev_probe(struct mdev_device *mdev)
{ {
struct mdev_device *mdev = to_mdev_device(dev);
struct vfio_device *vdev; struct vfio_device *vdev;
int ret; int ret;
...@@ -144,9 +143,9 @@ static int vfio_mdev_probe(struct device *dev) ...@@ -144,9 +143,9 @@ static int vfio_mdev_probe(struct device *dev)
return 0; return 0;
} }
static void vfio_mdev_remove(struct device *dev) static void vfio_mdev_remove(struct mdev_device *mdev)
{ {
struct vfio_device *vdev = dev_get_drvdata(dev); struct vfio_device *vdev = dev_get_drvdata(&mdev->dev);
vfio_unregister_group_dev(vdev); vfio_unregister_group_dev(vdev);
kfree(vdev); kfree(vdev);
......
...@@ -1933,28 +1933,13 @@ static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions, ...@@ -1933,28 +1933,13 @@ static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
return ret; return ret;
} }
static struct device *vfio_mdev_get_iommu_device(struct device *dev)
{
struct device *(*fn)(struct device *dev);
struct device *iommu_device;
fn = symbol_get(mdev_get_iommu_device);
if (fn) {
iommu_device = fn(dev);
symbol_put(mdev_get_iommu_device);
return iommu_device;
}
return NULL;
}
static int vfio_mdev_attach_domain(struct device *dev, void *data) static int vfio_mdev_attach_domain(struct device *dev, void *data)
{ {
struct mdev_device *mdev = to_mdev_device(dev);
struct iommu_domain *domain = data; struct iommu_domain *domain = data;
struct device *iommu_device; struct device *iommu_device;
iommu_device = vfio_mdev_get_iommu_device(dev); iommu_device = mdev_get_iommu_device(mdev);
if (iommu_device) { if (iommu_device) {
if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX)) if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
return iommu_aux_attach_device(domain, iommu_device); return iommu_aux_attach_device(domain, iommu_device);
...@@ -1967,10 +1952,11 @@ static int vfio_mdev_attach_domain(struct device *dev, void *data) ...@@ -1967,10 +1952,11 @@ static int vfio_mdev_attach_domain(struct device *dev, void *data)
static int vfio_mdev_detach_domain(struct device *dev, void *data) static int vfio_mdev_detach_domain(struct device *dev, void *data)
{ {
struct mdev_device *mdev = to_mdev_device(dev);
struct iommu_domain *domain = data; struct iommu_domain *domain = data;
struct device *iommu_device; struct device *iommu_device;
iommu_device = vfio_mdev_get_iommu_device(dev); iommu_device = mdev_get_iommu_device(mdev);
if (iommu_device) { if (iommu_device) {
if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX)) if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
iommu_aux_detach_device(domain, iommu_device); iommu_aux_detach_device(domain, iommu_device);
...@@ -2018,9 +2004,10 @@ static bool vfio_bus_is_mdev(struct bus_type *bus) ...@@ -2018,9 +2004,10 @@ static bool vfio_bus_is_mdev(struct bus_type *bus)
static int vfio_mdev_iommu_device(struct device *dev, void *data) static int vfio_mdev_iommu_device(struct device *dev, void *data)
{ {
struct mdev_device *mdev = to_mdev_device(dev);
struct device **old = data, *new; struct device **old = data, *new;
new = vfio_mdev_get_iommu_device(dev); new = mdev_get_iommu_device(mdev);
if (!new || (*old && *old != new)) if (!new || (*old && *old != new))
return -EINVAL; return -EINVAL;
......
...@@ -10,7 +10,21 @@ ...@@ -10,7 +10,21 @@
#ifndef MDEV_H #ifndef MDEV_H
#define MDEV_H #define MDEV_H
struct mdev_device; struct mdev_device {
struct device dev;
struct mdev_parent *parent;
guid_t uuid;
void *driver_data;
struct list_head next;
struct kobject *type_kobj;
struct device *iommu_device;
bool active;
};
static inline struct mdev_device *to_mdev_device(struct device *dev)
{
return container_of(dev, struct mdev_device, dev);
}
/* /*
* Called by the parent device driver to set the device which represents * Called by the parent device driver to set the device which represents
...@@ -19,12 +33,17 @@ struct mdev_device; ...@@ -19,12 +33,17 @@ struct mdev_device;
* *
* @dev: the mediated device that iommu will isolate. * @dev: the mediated device that iommu will isolate.
* @iommu_device: a pci device which represents the iommu for @dev. * @iommu_device: a pci device which represents the iommu for @dev.
*
* Return 0 for success, otherwise negative error value.
*/ */
int mdev_set_iommu_device(struct device *dev, struct device *iommu_device); static inline void mdev_set_iommu_device(struct mdev_device *mdev,
struct device *iommu_device)
{
mdev->iommu_device = iommu_device;
}
struct device *mdev_get_iommu_device(struct device *dev); static inline struct device *mdev_get_iommu_device(struct mdev_device *mdev)
{
return mdev->iommu_device;
}
/** /**
* struct mdev_parent_ops - Structure to be registered for each parent device to * struct mdev_parent_ops - Structure to be registered for each parent device to
...@@ -126,16 +145,25 @@ struct mdev_type_attribute mdev_type_attr_##_name = \ ...@@ -126,16 +145,25 @@ struct mdev_type_attribute mdev_type_attr_##_name = \
**/ **/
struct mdev_driver { struct mdev_driver {
const char *name; const char *name;
int (*probe)(struct device *dev); int (*probe)(struct mdev_device *dev);
void (*remove)(struct device *dev); void (*remove)(struct mdev_device *dev);
struct device_driver driver; struct device_driver driver;
}; };
#define to_mdev_driver(drv) container_of(drv, struct mdev_driver, driver) #define to_mdev_driver(drv) container_of(drv, struct mdev_driver, driver)
void *mdev_get_drvdata(struct mdev_device *mdev); static inline void *mdev_get_drvdata(struct mdev_device *mdev)
void mdev_set_drvdata(struct mdev_device *mdev, void *data); {
const guid_t *mdev_uuid(struct mdev_device *mdev); return mdev->driver_data;
}
static inline void mdev_set_drvdata(struct mdev_device *mdev, void *data)
{
mdev->driver_data = data;
}
static inline const guid_t *mdev_uuid(struct mdev_device *mdev)
{
return &mdev->uuid;
}
extern struct bus_type mdev_bus_type; extern struct bus_type mdev_bus_type;
...@@ -146,7 +174,13 @@ int mdev_register_driver(struct mdev_driver *drv, struct module *owner); ...@@ -146,7 +174,13 @@ int mdev_register_driver(struct mdev_driver *drv, struct module *owner);
void mdev_unregister_driver(struct mdev_driver *drv); void mdev_unregister_driver(struct mdev_driver *drv);
struct device *mdev_parent_dev(struct mdev_device *mdev); struct device *mdev_parent_dev(struct mdev_device *mdev);
struct device *mdev_dev(struct mdev_device *mdev); static inline struct device *mdev_dev(struct mdev_device *mdev)
struct mdev_device *mdev_from_dev(struct device *dev); {
return &mdev->dev;
}
static inline struct mdev_device *mdev_from_dev(struct device *dev)
{
return dev->bus == &mdev_bus_type ? to_mdev_device(dev) : NULL;
}
#endif /* MDEV_H */ #endif /* MDEV_H */
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment