Commit 8e8d9442 authored by Linus Torvalds's avatar Linus Torvalds

Merge tag 'vfio-v5.14-rc1' of git://github.com/awilliam/linux-vfio

Pull VFIO updates from Alex Williamson:

 - Module reference fixes, structure renaming (Max Gurtovoy)

 - Export and use common pci_dev_trylock() (Luis Chamberlain)

 - Enable direct mdev device creation and probing by parent (Christoph
   Hellwig & Jason Gunthorpe)

 - Fix mdpy error path leak (Colin Ian King)

 - Fix mtty list entry leak (Jason Gunthorpe)

 - Enforce mtty device limit (Alex Williamson)

 - Resolve concurrent vfio-pci mmap faults (Alex Williamson)

* tag 'vfio-v5.14-rc1' of git://github.com/awilliam/linux-vfio:
  vfio/pci: Handle concurrent vma faults
  vfio/mtty: Enforce available_instances
  vfio/mtty: Delete mdev_devices_list
  vfio: use the new pci_dev_trylock() helper to simplify try lock
  PCI: Export pci_dev_trylock() and pci_dev_unlock()
  vfio/mdpy: Fix memory leak of object mdev_state->vconfig
  vfio/iommu_type1: rename vfio_group struck to vfio_iommu_group
  vfio/mbochs: Convert to use vfio_register_group_dev()
  vfio/mdpy: Convert to use vfio_register_group_dev()
  vfio/mtty: Convert to use vfio_register_group_dev()
  vfio/mdev: Allow the mdev_parent_ops to specify the device driver to bind
  vfio/mdev: Remove CONFIG_VFIO_MDEV_DEVICE
  driver core: Export device_driver_attach()
  driver core: Don't return EPROBE_DEFER to userspace during sysfs bind
  driver core: Flow the return code from ->probe() through to sysfs bind
  driver core: Better distinguish probe errors in really_probe
  driver core: Pull required checks into driver_probe_device()
  vfio/platform: remove unneeded parent_module attribute
  vfio: centralize module refcount in subsystem layer
parents 58ec9059 6a45ece4
...@@ -93,7 +93,7 @@ interfaces: ...@@ -93,7 +93,7 @@ interfaces:
Registration Interface for a Mediated Bus Driver Registration Interface for a Mediated Bus Driver
------------------------------------------------ ------------------------------------------------
The registration interface for a mediated bus driver provides the following The registration interface for a mediated device driver provides the following
structure to represent a mediated device's driver:: structure to represent a mediated device's driver::
/* /*
...@@ -136,37 +136,26 @@ The structures in the mdev_parent_ops structure are as follows: ...@@ -136,37 +136,26 @@ The structures in the mdev_parent_ops structure are as follows:
* dev_attr_groups: attributes of the parent device * dev_attr_groups: attributes of the parent device
* mdev_attr_groups: attributes of the mediated device * mdev_attr_groups: attributes of the mediated device
* supported_config: attributes to define supported configurations * supported_config: attributes to define supported configurations
* device_driver: device driver to bind for mediated device instances
The functions in the mdev_parent_ops structure are as follows: The mdev_parent_ops also still has various functions pointers. Theses exist
for historical reasons only and shall not be used for new drivers.
* create: allocate basic resources in a driver for a mediated device When a driver wants to add the GUID creation sysfs to an existing device it has
* remove: free resources in a driver when a mediated device is destroyed probe'd to then it should call::
(Note that mdev-core provides no implicit serialization of create/remove
callbacks per mdev parent device, per mdev type, or any other categorization.
Vendor drivers are expected to be fully asynchronous in this respect or
provide their own internal resource protection.)
The callbacks in the mdev_parent_ops structure are as follows:
* open: open callback of mediated device
* close: close callback of mediated device
* ioctl: ioctl callback of mediated device
* read : read emulation callback
* write: write emulation callback
* mmap: mmap emulation callback
A driver should use the mdev_parent_ops structure in the function call to
register itself with the mdev core driver::
extern int mdev_register_device(struct device *dev, extern int mdev_register_device(struct device *dev,
const struct mdev_parent_ops *ops); const struct mdev_parent_ops *ops);
However, the mdev_parent_ops structure is not required in the function call This will provide the 'mdev_supported_types/XX/create' files which can then be
that a driver should use to unregister itself with the mdev core driver:: used to trigger the creation of a mdev_device. The created mdev_device will be
attached to the specified driver.
When the driver needs to remove itself it calls::
extern void mdev_unregister_device(struct device *dev); extern void mdev_unregister_device(struct device *dev);
Which will unbind and destroy all the created mdevs and remove the sysfs files.
Mediated Device Management Interface Through sysfs Mediated Device Management Interface Through sysfs
================================================== ==================================================
......
...@@ -514,7 +514,6 @@ These are the steps: ...@@ -514,7 +514,6 @@ These are the steps:
* S390_AP_IOMMU * S390_AP_IOMMU
* VFIO * VFIO
* VFIO_MDEV * VFIO_MDEV
* VFIO_MDEV_DEVICE
* KVM * KVM
If using make menuconfig select the following to build the vfio_ap module:: If using make menuconfig select the following to build the vfio_ap module::
......
...@@ -767,7 +767,7 @@ config VFIO_CCW ...@@ -767,7 +767,7 @@ config VFIO_CCW
config VFIO_AP config VFIO_AP
def_tristate n def_tristate n
prompt "VFIO support for AP devices" prompt "VFIO support for AP devices"
depends on S390_AP_IOMMU && VFIO_MDEV_DEVICE && KVM depends on S390_AP_IOMMU && VFIO_MDEV && KVM
depends on ZCRYPT depends on ZCRYPT
help help
This driver grants access to Adjunct Processor (AP) devices This driver grants access to Adjunct Processor (AP) devices
......
...@@ -152,7 +152,6 @@ extern int driver_add_groups(struct device_driver *drv, ...@@ -152,7 +152,6 @@ extern int driver_add_groups(struct device_driver *drv,
const struct attribute_group **groups); const struct attribute_group **groups);
extern void driver_remove_groups(struct device_driver *drv, extern void driver_remove_groups(struct device_driver *drv,
const struct attribute_group **groups); const struct attribute_group **groups);
int device_driver_attach(struct device_driver *drv, struct device *dev);
void device_driver_detach(struct device *dev); void device_driver_detach(struct device *dev);
extern char *make_class_name(const char *name, struct kobject *kobj); extern char *make_class_name(const char *name, struct kobject *kobj);
......
...@@ -210,15 +210,11 @@ static ssize_t bind_store(struct device_driver *drv, const char *buf, ...@@ -210,15 +210,11 @@ static ssize_t bind_store(struct device_driver *drv, const char *buf,
int err = -ENODEV; int err = -ENODEV;
dev = bus_find_device_by_name(bus, NULL, buf); dev = bus_find_device_by_name(bus, NULL, buf);
if (dev && dev->driver == NULL && driver_match_device(drv, dev)) { if (dev && driver_match_device(drv, dev)) {
err = device_driver_attach(drv, dev); err = device_driver_attach(drv, dev);
if (!err) {
if (err > 0) {
/* success */ /* success */
err = count; err = count;
} else if (err == 0) {
/* driver didn't accept device */
err = -ENODEV;
} }
} }
put_device(dev); put_device(dev);
......
...@@ -471,6 +471,8 @@ static void driver_sysfs_remove(struct device *dev) ...@@ -471,6 +471,8 @@ static void driver_sysfs_remove(struct device *dev)
* (It is ok to call with no other effort from a driver's probe() method.) * (It is ok to call with no other effort from a driver's probe() method.)
* *
* This function must be called with the device lock held. * This function must be called with the device lock held.
*
* Callers should prefer to use device_driver_attach() instead.
*/ */
int device_bind_driver(struct device *dev) int device_bind_driver(struct device *dev)
{ {
...@@ -491,15 +493,6 @@ EXPORT_SYMBOL_GPL(device_bind_driver); ...@@ -491,15 +493,6 @@ EXPORT_SYMBOL_GPL(device_bind_driver);
static atomic_t probe_count = ATOMIC_INIT(0); static atomic_t probe_count = ATOMIC_INIT(0);
static DECLARE_WAIT_QUEUE_HEAD(probe_waitqueue); static DECLARE_WAIT_QUEUE_HEAD(probe_waitqueue);
static void driver_deferred_probe_add_trigger(struct device *dev,
int local_trigger_count)
{
driver_deferred_probe_add(dev);
/* Did a trigger occur while probing? Need to re-trigger if yes */
if (local_trigger_count != atomic_read(&deferred_trigger_count))
driver_deferred_probe_trigger();
}
static ssize_t state_synced_show(struct device *dev, static ssize_t state_synced_show(struct device *dev,
struct device_attribute *attr, char *buf) struct device_attribute *attr, char *buf)
{ {
...@@ -513,12 +506,43 @@ static ssize_t state_synced_show(struct device *dev, ...@@ -513,12 +506,43 @@ static ssize_t state_synced_show(struct device *dev,
} }
static DEVICE_ATTR_RO(state_synced); static DEVICE_ATTR_RO(state_synced);
static int call_driver_probe(struct device *dev, struct device_driver *drv)
{
int ret = 0;
if (dev->bus->probe)
ret = dev->bus->probe(dev);
else if (drv->probe)
ret = drv->probe(dev);
switch (ret) {
case 0:
break;
case -EPROBE_DEFER:
/* Driver requested deferred probing */
dev_dbg(dev, "Driver %s requests probe deferral\n", drv->name);
break;
case -ENODEV:
case -ENXIO:
pr_debug("%s: probe of %s rejects match %d\n",
drv->name, dev_name(dev), ret);
break;
default:
/* driver matched but the probe failed */
pr_warn("%s: probe of %s failed with error %d\n",
drv->name, dev_name(dev), ret);
break;
}
return ret;
}
static int really_probe(struct device *dev, struct device_driver *drv) static int really_probe(struct device *dev, struct device_driver *drv)
{ {
int ret = -EPROBE_DEFER;
int local_trigger_count = atomic_read(&deferred_trigger_count);
bool test_remove = IS_ENABLED(CONFIG_DEBUG_TEST_DRIVER_REMOVE) && bool test_remove = IS_ENABLED(CONFIG_DEBUG_TEST_DRIVER_REMOVE) &&
!drv->suppress_bind_attrs; !drv->suppress_bind_attrs;
int ret;
if (defer_all_probes) { if (defer_all_probes) {
/* /*
...@@ -527,17 +551,13 @@ static int really_probe(struct device *dev, struct device_driver *drv) ...@@ -527,17 +551,13 @@ static int really_probe(struct device *dev, struct device_driver *drv)
* wait_for_device_probe() right after that to avoid any races. * wait_for_device_probe() right after that to avoid any races.
*/ */
dev_dbg(dev, "Driver %s force probe deferral\n", drv->name); dev_dbg(dev, "Driver %s force probe deferral\n", drv->name);
driver_deferred_probe_add(dev); return -EPROBE_DEFER;
return ret;
} }
ret = device_links_check_suppliers(dev); ret = device_links_check_suppliers(dev);
if (ret == -EPROBE_DEFER)
driver_deferred_probe_add_trigger(dev, local_trigger_count);
if (ret) if (ret)
return ret; return ret;
atomic_inc(&probe_count);
pr_debug("bus: '%s': %s: probing driver %s with device %s\n", pr_debug("bus: '%s': %s: probing driver %s with device %s\n",
drv->bus->name, __func__, drv->name, dev_name(dev)); drv->bus->name, __func__, drv->name, dev_name(dev));
if (!list_empty(&dev->devres_head)) { if (!list_empty(&dev->devres_head)) {
...@@ -572,13 +592,13 @@ static int really_probe(struct device *dev, struct device_driver *drv) ...@@ -572,13 +592,13 @@ static int really_probe(struct device *dev, struct device_driver *drv)
goto probe_failed; goto probe_failed;
} }
if (dev->bus->probe) { ret = call_driver_probe(dev, drv);
ret = dev->bus->probe(dev); if (ret) {
if (ret) /*
goto probe_failed; * Return probe errors as positive values so that the callers
} else if (drv->probe) { * can distinguish them from other errors.
ret = drv->probe(dev); */
if (ret) ret = -ret;
goto probe_failed; goto probe_failed;
} }
...@@ -621,7 +641,6 @@ static int really_probe(struct device *dev, struct device_driver *drv) ...@@ -621,7 +641,6 @@ static int really_probe(struct device *dev, struct device_driver *drv)
dev->pm_domain->sync(dev); dev->pm_domain->sync(dev);
driver_bound(dev); driver_bound(dev);
ret = 1;
pr_debug("bus: '%s': %s: bound device %s to driver %s\n", pr_debug("bus: '%s': %s: bound device %s to driver %s\n",
drv->bus->name, __func__, dev_name(dev), drv->name); drv->bus->name, __func__, dev_name(dev), drv->name);
goto done; goto done;
...@@ -650,31 +669,7 @@ static int really_probe(struct device *dev, struct device_driver *drv) ...@@ -650,31 +669,7 @@ static int really_probe(struct device *dev, struct device_driver *drv)
dev->pm_domain->dismiss(dev); dev->pm_domain->dismiss(dev);
pm_runtime_reinit(dev); pm_runtime_reinit(dev);
dev_pm_set_driver_flags(dev, 0); dev_pm_set_driver_flags(dev, 0);
switch (ret) {
case -EPROBE_DEFER:
/* Driver requested deferred probing */
dev_dbg(dev, "Driver %s requests probe deferral\n", drv->name);
driver_deferred_probe_add_trigger(dev, local_trigger_count);
break;
case -ENODEV:
case -ENXIO:
pr_debug("%s: probe of %s rejects match %d\n",
drv->name, dev_name(dev), ret);
break;
default:
/* driver matched but the probe failed */
pr_warn("%s: probe of %s failed with error %d\n",
drv->name, dev_name(dev), ret);
}
/*
* Ignore errors returned by ->probe so that the next driver can try
* its luck.
*/
ret = 0;
done: done:
atomic_dec(&probe_count);
wake_up_all(&probe_waitqueue);
return ret; return ret;
} }
...@@ -728,25 +723,14 @@ void wait_for_device_probe(void) ...@@ -728,25 +723,14 @@ void wait_for_device_probe(void)
} }
EXPORT_SYMBOL_GPL(wait_for_device_probe); EXPORT_SYMBOL_GPL(wait_for_device_probe);
/** static int __driver_probe_device(struct device_driver *drv, struct device *dev)
* driver_probe_device - attempt to bind device & driver together
* @drv: driver to bind a device to
* @dev: device to try to bind to the driver
*
* This function returns -ENODEV if the device is not registered,
* 1 if the device is bound successfully and 0 otherwise.
*
* This function must be called with @dev lock held. When called for a
* USB interface, @dev->parent lock must be held as well.
*
* If the device has a parent, runtime-resume the parent before driver probing.
*/
static int driver_probe_device(struct device_driver *drv, struct device *dev)
{ {
int ret = 0; int ret = 0;
if (!device_is_registered(dev)) if (dev->p->dead || !device_is_registered(dev))
return -ENODEV; return -ENODEV;
if (dev->driver)
return -EBUSY;
dev->can_match = true; dev->can_match = true;
pr_debug("bus: '%s': %s: matched device %s with driver %s\n", pr_debug("bus: '%s': %s: matched device %s with driver %s\n",
...@@ -770,6 +754,42 @@ static int driver_probe_device(struct device_driver *drv, struct device *dev) ...@@ -770,6 +754,42 @@ static int driver_probe_device(struct device_driver *drv, struct device *dev)
return ret; return ret;
} }
/**
* driver_probe_device - attempt to bind device & driver together
* @drv: driver to bind a device to
* @dev: device to try to bind to the driver
*
* This function returns -ENODEV if the device is not registered, -EBUSY if it
* already has a driver, 0 if the device is bound successfully and a positive
* (inverted) error code for failures from the ->probe method.
*
* This function must be called with @dev lock held. When called for a
* USB interface, @dev->parent lock must be held as well.
*
* If the device has a parent, runtime-resume the parent before driver probing.
*/
static int driver_probe_device(struct device_driver *drv, struct device *dev)
{
int trigger_count = atomic_read(&deferred_trigger_count);
int ret;
atomic_inc(&probe_count);
ret = __driver_probe_device(drv, dev);
if (ret == -EPROBE_DEFER || ret == EPROBE_DEFER) {
driver_deferred_probe_add(dev);
/*
* Did a trigger occur while probing? Need to re-trigger if yes
*/
if (trigger_count != atomic_read(&deferred_trigger_count) &&
!defer_all_probes)
driver_deferred_probe_trigger();
}
atomic_dec(&probe_count);
wake_up_all(&probe_waitqueue);
return ret;
}
static inline bool cmdline_requested_async_probing(const char *drv_name) static inline bool cmdline_requested_async_probing(const char *drv_name)
{ {
return parse_option_str(async_probe_drv_names, drv_name); return parse_option_str(async_probe_drv_names, drv_name);
...@@ -867,7 +887,14 @@ static int __device_attach_driver(struct device_driver *drv, void *_data) ...@@ -867,7 +887,14 @@ static int __device_attach_driver(struct device_driver *drv, void *_data)
if (data->check_async && async_allowed != data->want_async) if (data->check_async && async_allowed != data->want_async)
return 0; return 0;
return driver_probe_device(drv, dev); /*
* Ignore errors returned by ->probe so that the next driver can try
* its luck.
*/
ret = driver_probe_device(drv, dev);
if (ret < 0)
return ret;
return ret == 0;
} }
static void __device_attach_async_helper(void *_dev, async_cookie_t cookie) static void __device_attach_async_helper(void *_dev, async_cookie_t cookie)
...@@ -1023,43 +1050,34 @@ static void __device_driver_unlock(struct device *dev, struct device *parent) ...@@ -1023,43 +1050,34 @@ static void __device_driver_unlock(struct device *dev, struct device *parent)
* @dev: Device to attach it to * @dev: Device to attach it to
* *
* Manually attach driver to a device. Will acquire both @dev lock and * Manually attach driver to a device. Will acquire both @dev lock and
* @dev->parent lock if needed. * @dev->parent lock if needed. Returns 0 on success, -ERR on failure.
*/ */
int device_driver_attach(struct device_driver *drv, struct device *dev) int device_driver_attach(struct device_driver *drv, struct device *dev)
{ {
int ret = 0; int ret;
__device_driver_lock(dev, dev->parent); __device_driver_lock(dev, dev->parent);
ret = __driver_probe_device(drv, dev);
/*
* If device has been removed or someone has already successfully
* bound a driver before us just skip the driver probe call.
*/
if (!dev->p->dead && !dev->driver)
ret = driver_probe_device(drv, dev);
__device_driver_unlock(dev, dev->parent); __device_driver_unlock(dev, dev->parent);
/* also return probe errors as normal negative errnos */
if (ret > 0)
ret = -ret;
if (ret == -EPROBE_DEFER)
return -EAGAIN;
return ret; return ret;
} }
EXPORT_SYMBOL_GPL(device_driver_attach);
static void __driver_attach_async_helper(void *_dev, async_cookie_t cookie) static void __driver_attach_async_helper(void *_dev, async_cookie_t cookie)
{ {
struct device *dev = _dev; struct device *dev = _dev;
struct device_driver *drv; struct device_driver *drv;
int ret = 0; int ret;
__device_driver_lock(dev, dev->parent); __device_driver_lock(dev, dev->parent);
drv = dev->p->async_driver; drv = dev->p->async_driver;
/*
* If device has been removed or someone has already successfully
* bound a driver before us just skip the driver probe call.
*/
if (!dev->p->dead && !dev->driver)
ret = driver_probe_device(drv, dev); ret = driver_probe_device(drv, dev);
__device_driver_unlock(dev, dev->parent); __device_driver_unlock(dev, dev->parent);
dev_dbg(dev, "driver %s async attach completed: %d\n", drv->name, ret); dev_dbg(dev, "driver %s async attach completed: %d\n", drv->name, ret);
...@@ -1114,7 +1132,9 @@ static int __driver_attach(struct device *dev, void *data) ...@@ -1114,7 +1132,9 @@ static int __driver_attach(struct device *dev, void *data)
return 0; return 0;
} }
device_driver_attach(drv, dev); __device_driver_lock(dev, dev->parent);
driver_probe_device(drv, dev);
__device_driver_unlock(dev, dev->parent);
return 0; return 0;
} }
......
...@@ -125,7 +125,7 @@ config DRM_I915_GVT_KVMGT ...@@ -125,7 +125,7 @@ config DRM_I915_GVT_KVMGT
tristate "Enable KVM/VFIO support for Intel GVT-g" tristate "Enable KVM/VFIO support for Intel GVT-g"
depends on DRM_I915_GVT depends on DRM_I915_GVT
depends on KVM depends on KVM
depends on VFIO_MDEV && VFIO_MDEV_DEVICE depends on VFIO_MDEV
default n default n
help help
Choose this option if you want to enable KVMGT support for Choose this option if you want to enable KVMGT support for
......
...@@ -5038,7 +5038,7 @@ static void pci_dev_lock(struct pci_dev *dev) ...@@ -5038,7 +5038,7 @@ static void pci_dev_lock(struct pci_dev *dev)
} }
/* Return 1 on successful lock, 0 on contention */ /* Return 1 on successful lock, 0 on contention */
static int pci_dev_trylock(struct pci_dev *dev) int pci_dev_trylock(struct pci_dev *dev)
{ {
if (pci_cfg_access_trylock(dev)) { if (pci_cfg_access_trylock(dev)) {
if (device_trylock(&dev->dev)) if (device_trylock(&dev->dev))
...@@ -5048,12 +5048,14 @@ static int pci_dev_trylock(struct pci_dev *dev) ...@@ -5048,12 +5048,14 @@ static int pci_dev_trylock(struct pci_dev *dev)
return 0; return 0;
} }
EXPORT_SYMBOL_GPL(pci_dev_trylock);
static void pci_dev_unlock(struct pci_dev *dev) void pci_dev_unlock(struct pci_dev *dev)
{ {
device_unlock(&dev->dev); device_unlock(&dev->dev);
pci_cfg_access_unlock(dev); pci_cfg_access_unlock(dev);
} }
EXPORT_SYMBOL_GPL(pci_dev_unlock);
static void pci_dev_save_and_disable(struct pci_dev *dev) static void pci_dev_save_and_disable(struct pci_dev *dev)
{ {
......
...@@ -140,26 +140,18 @@ static int vfio_fsl_mc_open(struct vfio_device *core_vdev) ...@@ -140,26 +140,18 @@ static int vfio_fsl_mc_open(struct vfio_device *core_vdev)
{ {
struct vfio_fsl_mc_device *vdev = struct vfio_fsl_mc_device *vdev =
container_of(core_vdev, struct vfio_fsl_mc_device, vdev); container_of(core_vdev, struct vfio_fsl_mc_device, vdev);
int ret; int ret = 0;
if (!try_module_get(THIS_MODULE))
return -ENODEV;
mutex_lock(&vdev->reflck->lock); mutex_lock(&vdev->reflck->lock);
if (!vdev->refcnt) { if (!vdev->refcnt) {
ret = vfio_fsl_mc_regions_init(vdev); ret = vfio_fsl_mc_regions_init(vdev);
if (ret) if (ret)
goto err_reg_init; goto out;
} }
vdev->refcnt++; vdev->refcnt++;
out:
mutex_unlock(&vdev->reflck->lock); mutex_unlock(&vdev->reflck->lock);
return 0;
err_reg_init:
mutex_unlock(&vdev->reflck->lock);
module_put(THIS_MODULE);
return ret; return ret;
} }
...@@ -196,8 +188,6 @@ static void vfio_fsl_mc_release(struct vfio_device *core_vdev) ...@@ -196,8 +188,6 @@ static void vfio_fsl_mc_release(struct vfio_device *core_vdev)
} }
mutex_unlock(&vdev->reflck->lock); mutex_unlock(&vdev->reflck->lock);
module_put(THIS_MODULE);
} }
static long vfio_fsl_mc_ioctl(struct vfio_device *core_vdev, static long vfio_fsl_mc_ioctl(struct vfio_device *core_vdev,
......
...@@ -9,10 +9,3 @@ config VFIO_MDEV ...@@ -9,10 +9,3 @@ config VFIO_MDEV
See Documentation/driver-api/vfio-mediated-device.rst for more details. See Documentation/driver-api/vfio-mediated-device.rst for more details.
If you don't know what do here, say N. If you don't know what do here, say N.
config VFIO_MDEV_DEVICE
tristate "VFIO driver for Mediated devices"
depends on VFIO && VFIO_MDEV
default n
help
VFIO based driver for Mediated devices.
# SPDX-License-Identifier: GPL-2.0-only # SPDX-License-Identifier: GPL-2.0-only
mdev-y := mdev_core.o mdev_sysfs.o mdev_driver.o mdev-y := mdev_core.o mdev_sysfs.o mdev_driver.o vfio_mdev.o
obj-$(CONFIG_VFIO_MDEV) += mdev.o obj-$(CONFIG_VFIO_MDEV) += mdev.o
obj-$(CONFIG_VFIO_MDEV_DEVICE) += vfio_mdev.o
...@@ -94,9 +94,11 @@ static void mdev_device_remove_common(struct mdev_device *mdev) ...@@ -94,9 +94,11 @@ static void mdev_device_remove_common(struct mdev_device *mdev)
mdev_remove_sysfs_files(mdev); mdev_remove_sysfs_files(mdev);
device_del(&mdev->dev); device_del(&mdev->dev);
lockdep_assert_held(&parent->unreg_sem); lockdep_assert_held(&parent->unreg_sem);
if (parent->ops->remove) {
ret = parent->ops->remove(mdev); ret = parent->ops->remove(mdev);
if (ret) if (ret)
dev_err(&mdev->dev, "Remove failed: err=%d\n", ret); dev_err(&mdev->dev, "Remove failed: err=%d\n", ret);
}
/* Balances with device_initialize() */ /* Balances with device_initialize() */
put_device(&mdev->dev); put_device(&mdev->dev);
...@@ -127,7 +129,9 @@ int mdev_register_device(struct device *dev, const struct mdev_parent_ops *ops) ...@@ -127,7 +129,9 @@ int mdev_register_device(struct device *dev, const struct mdev_parent_ops *ops)
char *envp[] = { env_string, NULL }; char *envp[] = { env_string, NULL };
/* check for mandatory ops */ /* check for mandatory ops */
if (!ops || !ops->create || !ops->remove || !ops->supported_type_groups) if (!ops || !ops->supported_type_groups)
return -EINVAL;
if (!ops->device_driver && (!ops->create || !ops->remove))
return -EINVAL; return -EINVAL;
dev = get_device(dev); dev = get_device(dev);
...@@ -256,6 +260,7 @@ int mdev_device_create(struct mdev_type *type, const guid_t *uuid) ...@@ -256,6 +260,7 @@ int mdev_device_create(struct mdev_type *type, const guid_t *uuid)
int ret; int ret;
struct mdev_device *mdev, *tmp; struct mdev_device *mdev, *tmp;
struct mdev_parent *parent = type->parent; struct mdev_parent *parent = type->parent;
struct mdev_driver *drv = parent->ops->device_driver;
mutex_lock(&mdev_list_lock); mutex_lock(&mdev_list_lock);
...@@ -296,14 +301,22 @@ int mdev_device_create(struct mdev_type *type, const guid_t *uuid) ...@@ -296,14 +301,22 @@ int mdev_device_create(struct mdev_type *type, const guid_t *uuid)
goto out_put_device; goto out_put_device;
} }
if (parent->ops->create) {
ret = parent->ops->create(mdev); ret = parent->ops->create(mdev);
if (ret) if (ret)
goto out_unlock; goto out_unlock;
}
ret = device_add(&mdev->dev); ret = device_add(&mdev->dev);
if (ret) if (ret)
goto out_remove; goto out_remove;
if (!drv)
drv = &vfio_mdev_driver;
ret = device_driver_attach(&drv->driver, &mdev->dev);
if (ret)
goto out_del;
ret = mdev_create_sysfs_files(mdev); ret = mdev_create_sysfs_files(mdev);
if (ret) if (ret)
goto out_del; goto out_del;
...@@ -317,6 +330,7 @@ int mdev_device_create(struct mdev_type *type, const guid_t *uuid) ...@@ -317,6 +330,7 @@ int mdev_device_create(struct mdev_type *type, const guid_t *uuid)
out_del: out_del:
device_del(&mdev->dev); device_del(&mdev->dev);
out_remove: out_remove:
if (parent->ops->remove)
parent->ops->remove(mdev); parent->ops->remove(mdev);
out_unlock: out_unlock:
up_read(&parent->unreg_sem); up_read(&parent->unreg_sem);
...@@ -360,11 +374,24 @@ int mdev_device_remove(struct mdev_device *mdev) ...@@ -360,11 +374,24 @@ int mdev_device_remove(struct mdev_device *mdev)
static int __init mdev_init(void) static int __init mdev_init(void)
{ {
return mdev_bus_register(); int rc;
rc = mdev_bus_register();
if (rc)
return rc;
rc = mdev_register_driver(&vfio_mdev_driver);
if (rc)
goto err_bus;
return 0;
err_bus:
mdev_bus_unregister();
return rc;
} }
static void __exit mdev_exit(void) static void __exit mdev_exit(void)
{ {
mdev_unregister_driver(&vfio_mdev_driver);
if (mdev_bus_compat_class) if (mdev_bus_compat_class)
class_compat_unregister(mdev_bus_compat_class); class_compat_unregister(mdev_bus_compat_class);
...@@ -378,4 +405,3 @@ MODULE_VERSION(DRIVER_VERSION); ...@@ -378,4 +405,3 @@ MODULE_VERSION(DRIVER_VERSION);
MODULE_LICENSE("GPL v2"); MODULE_LICENSE("GPL v2");
MODULE_AUTHOR(DRIVER_AUTHOR); MODULE_AUTHOR(DRIVER_AUTHOR);
MODULE_DESCRIPTION(DRIVER_DESC); MODULE_DESCRIPTION(DRIVER_DESC);
MODULE_SOFTDEP("post: vfio_mdev");
...@@ -71,10 +71,20 @@ static int mdev_remove(struct device *dev) ...@@ -71,10 +71,20 @@ static int mdev_remove(struct device *dev)
return 0; return 0;
} }
static int mdev_match(struct device *dev, struct device_driver *drv)
{
/*
* No drivers automatically match. Drivers are only bound by explicit
* device_driver_attach()
*/
return 0;
}
struct bus_type mdev_bus_type = { struct bus_type mdev_bus_type = {
.name = "mdev", .name = "mdev",
.probe = mdev_probe, .probe = mdev_probe,
.remove = mdev_remove, .remove = mdev_remove,
.match = mdev_match,
}; };
EXPORT_SYMBOL_GPL(mdev_bus_type); EXPORT_SYMBOL_GPL(mdev_bus_type);
......
...@@ -37,6 +37,8 @@ struct mdev_type { ...@@ -37,6 +37,8 @@ struct mdev_type {
#define to_mdev_type(_kobj) \ #define to_mdev_type(_kobj) \
container_of(_kobj, struct mdev_type, kobj) container_of(_kobj, struct mdev_type, kobj)
extern struct mdev_driver vfio_mdev_driver;
int parent_create_sysfs_files(struct mdev_parent *parent); int parent_create_sysfs_files(struct mdev_parent *parent);
void parent_remove_sysfs_files(struct mdev_parent *parent); void parent_remove_sysfs_files(struct mdev_parent *parent);
......
...@@ -17,28 +17,15 @@ ...@@ -17,28 +17,15 @@
#include "mdev_private.h" #include "mdev_private.h"
#define DRIVER_VERSION "0.1"
#define DRIVER_AUTHOR "NVIDIA Corporation"
#define DRIVER_DESC "VFIO based driver for Mediated device"
static int vfio_mdev_open(struct vfio_device *core_vdev) static int vfio_mdev_open(struct vfio_device *core_vdev)
{ {
struct mdev_device *mdev = to_mdev_device(core_vdev->dev); struct mdev_device *mdev = to_mdev_device(core_vdev->dev);
struct mdev_parent *parent = mdev->type->parent; struct mdev_parent *parent = mdev->type->parent;
int ret;
if (unlikely(!parent->ops->open)) if (unlikely(!parent->ops->open))
return -EINVAL; return -EINVAL;
if (!try_module_get(THIS_MODULE)) return parent->ops->open(mdev);
return -ENODEV;
ret = parent->ops->open(mdev);
if (ret)
module_put(THIS_MODULE);
return ret;
} }
static void vfio_mdev_release(struct vfio_device *core_vdev) static void vfio_mdev_release(struct vfio_device *core_vdev)
...@@ -48,8 +35,6 @@ static void vfio_mdev_release(struct vfio_device *core_vdev) ...@@ -48,8 +35,6 @@ static void vfio_mdev_release(struct vfio_device *core_vdev)
if (likely(parent->ops->release)) if (likely(parent->ops->release))
parent->ops->release(mdev); parent->ops->release(mdev);
module_put(THIS_MODULE);
} }
static long vfio_mdev_unlocked_ioctl(struct vfio_device *core_vdev, static long vfio_mdev_unlocked_ioctl(struct vfio_device *core_vdev,
...@@ -151,7 +136,7 @@ static void vfio_mdev_remove(struct mdev_device *mdev) ...@@ -151,7 +136,7 @@ static void vfio_mdev_remove(struct mdev_device *mdev)
kfree(vdev); kfree(vdev);
} }
static struct mdev_driver vfio_mdev_driver = { struct mdev_driver vfio_mdev_driver = {
.driver = { .driver = {
.name = "vfio_mdev", .name = "vfio_mdev",
.owner = THIS_MODULE, .owner = THIS_MODULE,
...@@ -160,21 +145,3 @@ static struct mdev_driver vfio_mdev_driver = { ...@@ -160,21 +145,3 @@ static struct mdev_driver vfio_mdev_driver = {
.probe = vfio_mdev_probe, .probe = vfio_mdev_probe,
.remove = vfio_mdev_remove, .remove = vfio_mdev_remove,
}; };
static int __init vfio_mdev_init(void)
{
return mdev_register_driver(&vfio_mdev_driver);
}
static void __exit vfio_mdev_exit(void)
{
mdev_unregister_driver(&vfio_mdev_driver);
}
module_init(vfio_mdev_init)
module_exit(vfio_mdev_exit)
MODULE_VERSION(DRIVER_VERSION);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR(DRIVER_AUTHOR);
MODULE_DESCRIPTION(DRIVER_DESC);
...@@ -477,13 +477,10 @@ static void vfio_pci_disable(struct vfio_pci_device *vdev) ...@@ -477,13 +477,10 @@ static void vfio_pci_disable(struct vfio_pci_device *vdev)
* We can not use the "try" reset interface here, which will * We can not use the "try" reset interface here, which will
* overwrite the previously restored configuration information. * overwrite the previously restored configuration information.
*/ */
if (vdev->reset_works && pci_cfg_access_trylock(pdev)) { if (vdev->reset_works && pci_dev_trylock(pdev)) {
if (device_trylock(&pdev->dev)) {
if (!__pci_reset_function_locked(pdev)) if (!__pci_reset_function_locked(pdev))
vdev->needs_reset = false; vdev->needs_reset = false;
device_unlock(&pdev->dev); pci_dev_unlock(pdev);
}
pci_cfg_access_unlock(pdev);
} }
pci_restore_state(pdev); pci_restore_state(pdev);
...@@ -558,8 +555,6 @@ static void vfio_pci_release(struct vfio_device *core_vdev) ...@@ -558,8 +555,6 @@ static void vfio_pci_release(struct vfio_device *core_vdev)
} }
mutex_unlock(&vdev->reflck->lock); mutex_unlock(&vdev->reflck->lock);
module_put(THIS_MODULE);
} }
static int vfio_pci_open(struct vfio_device *core_vdev) static int vfio_pci_open(struct vfio_device *core_vdev)
...@@ -568,9 +563,6 @@ static int vfio_pci_open(struct vfio_device *core_vdev) ...@@ -568,9 +563,6 @@ static int vfio_pci_open(struct vfio_device *core_vdev)
container_of(core_vdev, struct vfio_pci_device, vdev); container_of(core_vdev, struct vfio_pci_device, vdev);
int ret = 0; int ret = 0;
if (!try_module_get(THIS_MODULE))
return -ENODEV;
mutex_lock(&vdev->reflck->lock); mutex_lock(&vdev->reflck->lock);
if (!vdev->refcnt) { if (!vdev->refcnt) {
...@@ -584,8 +576,6 @@ static int vfio_pci_open(struct vfio_device *core_vdev) ...@@ -584,8 +576,6 @@ static int vfio_pci_open(struct vfio_device *core_vdev)
vdev->refcnt++; vdev->refcnt++;
error: error:
mutex_unlock(&vdev->reflck->lock); mutex_unlock(&vdev->reflck->lock);
if (ret)
module_put(THIS_MODULE);
return ret; return ret;
} }
...@@ -1594,6 +1584,7 @@ static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf) ...@@ -1594,6 +1584,7 @@ static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)
{ {
struct vm_area_struct *vma = vmf->vma; struct vm_area_struct *vma = vmf->vma;
struct vfio_pci_device *vdev = vma->vm_private_data; struct vfio_pci_device *vdev = vma->vm_private_data;
struct vfio_pci_mmap_vma *mmap_vma;
vm_fault_t ret = VM_FAULT_NOPAGE; vm_fault_t ret = VM_FAULT_NOPAGE;
mutex_lock(&vdev->vma_lock); mutex_lock(&vdev->vma_lock);
...@@ -1601,24 +1592,36 @@ static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf) ...@@ -1601,24 +1592,36 @@ static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)
if (!__vfio_pci_memory_enabled(vdev)) { if (!__vfio_pci_memory_enabled(vdev)) {
ret = VM_FAULT_SIGBUS; ret = VM_FAULT_SIGBUS;
mutex_unlock(&vdev->vma_lock);
goto up_out; goto up_out;
} }
if (__vfio_pci_add_vma(vdev, vma)) { /*
ret = VM_FAULT_OOM; * We populate the whole vma on fault, so we need to test whether
mutex_unlock(&vdev->vma_lock); * the vma has already been mapped, such as for concurrent faults
* to the same vma. io_remap_pfn_range() will trigger a BUG_ON if
* we ask it to fill the same range again.
*/
list_for_each_entry(mmap_vma, &vdev->vma_list, vma_next) {
if (mmap_vma->vma == vma)
goto up_out; goto up_out;
} }
mutex_unlock(&vdev->vma_lock);
if (io_remap_pfn_range(vma, vma->vm_start, vma->vm_pgoff, if (io_remap_pfn_range(vma, vma->vm_start, vma->vm_pgoff,
vma->vm_end - vma->vm_start, vma->vm_page_prot)) vma->vm_end - vma->vm_start,
vma->vm_page_prot)) {
ret = VM_FAULT_SIGBUS; ret = VM_FAULT_SIGBUS;
zap_vma_ptes(vma, vma->vm_start, vma->vm_end - vma->vm_start);
goto up_out;
}
if (__vfio_pci_add_vma(vdev, vma)) {
ret = VM_FAULT_OOM;
zap_vma_ptes(vma, vma->vm_start, vma->vm_end - vma->vm_start);
}
up_out: up_out:
up_read(&vdev->memory_lock); up_read(&vdev->memory_lock);
mutex_unlock(&vdev->vma_lock);
return ret; return ret;
} }
......
...@@ -59,7 +59,6 @@ static int vfio_amba_probe(struct amba_device *adev, const struct amba_id *id) ...@@ -59,7 +59,6 @@ static int vfio_amba_probe(struct amba_device *adev, const struct amba_id *id)
vdev->flags = VFIO_DEVICE_FLAGS_AMBA; vdev->flags = VFIO_DEVICE_FLAGS_AMBA;
vdev->get_resource = get_amba_resource; vdev->get_resource = get_amba_resource;
vdev->get_irq = get_amba_irq; vdev->get_irq = get_amba_irq;
vdev->parent_module = THIS_MODULE;
vdev->reset_required = false; vdev->reset_required = false;
ret = vfio_platform_probe_common(vdev, &adev->dev); ret = vfio_platform_probe_common(vdev, &adev->dev);
......
...@@ -50,7 +50,6 @@ static int vfio_platform_probe(struct platform_device *pdev) ...@@ -50,7 +50,6 @@ static int vfio_platform_probe(struct platform_device *pdev)
vdev->flags = VFIO_DEVICE_FLAGS_PLATFORM; vdev->flags = VFIO_DEVICE_FLAGS_PLATFORM;
vdev->get_resource = get_platform_resource; vdev->get_resource = get_platform_resource;
vdev->get_irq = get_platform_irq; vdev->get_irq = get_platform_irq;
vdev->parent_module = THIS_MODULE;
vdev->reset_required = reset_required; vdev->reset_required = reset_required;
ret = vfio_platform_probe_common(vdev, &pdev->dev); ret = vfio_platform_probe_common(vdev, &pdev->dev);
......
...@@ -241,8 +241,6 @@ static void vfio_platform_release(struct vfio_device *core_vdev) ...@@ -241,8 +241,6 @@ static void vfio_platform_release(struct vfio_device *core_vdev)
} }
mutex_unlock(&driver_lock); mutex_unlock(&driver_lock);
module_put(vdev->parent_module);
} }
static int vfio_platform_open(struct vfio_device *core_vdev) static int vfio_platform_open(struct vfio_device *core_vdev)
...@@ -251,9 +249,6 @@ static int vfio_platform_open(struct vfio_device *core_vdev) ...@@ -251,9 +249,6 @@ static int vfio_platform_open(struct vfio_device *core_vdev)
container_of(core_vdev, struct vfio_platform_device, vdev); container_of(core_vdev, struct vfio_platform_device, vdev);
int ret; int ret;
if (!try_module_get(vdev->parent_module))
return -ENODEV;
mutex_lock(&driver_lock); mutex_lock(&driver_lock);
if (!vdev->refcnt) { if (!vdev->refcnt) {
...@@ -291,7 +286,6 @@ static int vfio_platform_open(struct vfio_device *core_vdev) ...@@ -291,7 +286,6 @@ static int vfio_platform_open(struct vfio_device *core_vdev)
vfio_platform_regions_cleanup(vdev); vfio_platform_regions_cleanup(vdev);
err_reg: err_reg:
mutex_unlock(&driver_lock); mutex_unlock(&driver_lock);
module_put(vdev->parent_module);
return ret; return ret;
} }
......
...@@ -50,7 +50,6 @@ struct vfio_platform_device { ...@@ -50,7 +50,6 @@ struct vfio_platform_device {
u32 num_irqs; u32 num_irqs;
int refcnt; int refcnt;
struct mutex igate; struct mutex igate;
struct module *parent_module;
const char *compat; const char *compat;
const char *acpihid; const char *acpihid;
struct module *reset_module; struct module *reset_module;
......
...@@ -1369,8 +1369,14 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf) ...@@ -1369,8 +1369,14 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
if (IS_ERR(device)) if (IS_ERR(device))
return PTR_ERR(device); return PTR_ERR(device);
if (!try_module_get(device->dev->driver->owner)) {
vfio_device_put(device);
return -ENODEV;
}
ret = device->ops->open(device); ret = device->ops->open(device);
if (ret) { if (ret) {
module_put(device->dev->driver->owner);
vfio_device_put(device); vfio_device_put(device);
return ret; return ret;
} }
...@@ -1382,6 +1388,7 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf) ...@@ -1382,6 +1388,7 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
ret = get_unused_fd_flags(O_CLOEXEC); ret = get_unused_fd_flags(O_CLOEXEC);
if (ret < 0) { if (ret < 0) {
device->ops->release(device); device->ops->release(device);
module_put(device->dev->driver->owner);
vfio_device_put(device); vfio_device_put(device);
return ret; return ret;
} }
...@@ -1392,6 +1399,7 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf) ...@@ -1392,6 +1399,7 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
put_unused_fd(ret); put_unused_fd(ret);
ret = PTR_ERR(filep); ret = PTR_ERR(filep);
device->ops->release(device); device->ops->release(device);
module_put(device->dev->driver->owner);
vfio_device_put(device); vfio_device_put(device);
return ret; return ret;
} }
...@@ -1550,6 +1558,8 @@ static int vfio_device_fops_release(struct inode *inode, struct file *filep) ...@@ -1550,6 +1558,8 @@ static int vfio_device_fops_release(struct inode *inode, struct file *filep)
device->ops->release(device); device->ops->release(device);
module_put(device->dev->driver->owner);
vfio_group_try_dissolve_container(device->group); vfio_group_try_dissolve_container(device->group);
vfio_device_put(device); vfio_device_put(device);
......
...@@ -110,7 +110,7 @@ struct vfio_batch { ...@@ -110,7 +110,7 @@ struct vfio_batch {
int offset; /* of next entry in pages */ int offset; /* of next entry in pages */
}; };
struct vfio_group { struct vfio_iommu_group {
struct iommu_group *iommu_group; struct iommu_group *iommu_group;
struct list_head next; struct list_head next;
bool mdev_group; /* An mdev group */ bool mdev_group; /* An mdev group */
...@@ -160,7 +160,8 @@ struct vfio_regions { ...@@ -160,7 +160,8 @@ struct vfio_regions {
static int put_pfn(unsigned long pfn, int prot); static int put_pfn(unsigned long pfn, int prot);
static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu, static struct vfio_iommu_group*
vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
struct iommu_group *iommu_group); struct iommu_group *iommu_group);
/* /*
...@@ -836,7 +837,7 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data, ...@@ -836,7 +837,7 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
unsigned long *phys_pfn) unsigned long *phys_pfn)
{ {
struct vfio_iommu *iommu = iommu_data; struct vfio_iommu *iommu = iommu_data;
struct vfio_group *group; struct vfio_iommu_group *group;
int i, j, ret; int i, j, ret;
unsigned long remote_vaddr; unsigned long remote_vaddr;
struct vfio_dma *dma; struct vfio_dma *dma;
...@@ -1875,10 +1876,10 @@ static void vfio_test_domain_fgsp(struct vfio_domain *domain) ...@@ -1875,10 +1876,10 @@ static void vfio_test_domain_fgsp(struct vfio_domain *domain)
__free_pages(pages, order); __free_pages(pages, order);
} }
static struct vfio_group *find_iommu_group(struct vfio_domain *domain, static struct vfio_iommu_group *find_iommu_group(struct vfio_domain *domain,
struct iommu_group *iommu_group) struct iommu_group *iommu_group)
{ {
struct vfio_group *g; struct vfio_iommu_group *g;
list_for_each_entry(g, &domain->group_list, next) { list_for_each_entry(g, &domain->group_list, next) {
if (g->iommu_group == iommu_group) if (g->iommu_group == iommu_group)
...@@ -1888,11 +1889,12 @@ static struct vfio_group *find_iommu_group(struct vfio_domain *domain, ...@@ -1888,11 +1889,12 @@ static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
return NULL; return NULL;
} }
static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu, static struct vfio_iommu_group*
vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
struct iommu_group *iommu_group) struct iommu_group *iommu_group)
{ {
struct vfio_domain *domain; struct vfio_domain *domain;
struct vfio_group *group = NULL; struct vfio_iommu_group *group = NULL;
list_for_each_entry(domain, &iommu->domain_list, next) { list_for_each_entry(domain, &iommu->domain_list, next) {
group = find_iommu_group(domain, iommu_group); group = find_iommu_group(domain, iommu_group);
...@@ -1967,7 +1969,7 @@ static int vfio_mdev_detach_domain(struct device *dev, void *data) ...@@ -1967,7 +1969,7 @@ static int vfio_mdev_detach_domain(struct device *dev, void *data)
} }
static int vfio_iommu_attach_group(struct vfio_domain *domain, static int vfio_iommu_attach_group(struct vfio_domain *domain,
struct vfio_group *group) struct vfio_iommu_group *group)
{ {
if (group->mdev_group) if (group->mdev_group)
return iommu_group_for_each_dev(group->iommu_group, return iommu_group_for_each_dev(group->iommu_group,
...@@ -1978,7 +1980,7 @@ static int vfio_iommu_attach_group(struct vfio_domain *domain, ...@@ -1978,7 +1980,7 @@ static int vfio_iommu_attach_group(struct vfio_domain *domain,
} }
static void vfio_iommu_detach_group(struct vfio_domain *domain, static void vfio_iommu_detach_group(struct vfio_domain *domain,
struct vfio_group *group) struct vfio_iommu_group *group)
{ {
if (group->mdev_group) if (group->mdev_group)
iommu_group_for_each_dev(group->iommu_group, domain->domain, iommu_group_for_each_dev(group->iommu_group, domain->domain,
...@@ -2242,7 +2244,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data, ...@@ -2242,7 +2244,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
struct iommu_group *iommu_group) struct iommu_group *iommu_group)
{ {
struct vfio_iommu *iommu = iommu_data; struct vfio_iommu *iommu = iommu_data;
struct vfio_group *group; struct vfio_iommu_group *group;
struct vfio_domain *domain, *d; struct vfio_domain *domain, *d;
struct bus_type *bus = NULL; struct bus_type *bus = NULL;
int ret; int ret;
...@@ -2518,7 +2520,7 @@ static int vfio_iommu_resv_refresh(struct vfio_iommu *iommu, ...@@ -2518,7 +2520,7 @@ static int vfio_iommu_resv_refresh(struct vfio_iommu *iommu,
struct list_head *iova_copy) struct list_head *iova_copy)
{ {
struct vfio_domain *d; struct vfio_domain *d;
struct vfio_group *g; struct vfio_iommu_group *g;
struct vfio_iova *node; struct vfio_iova *node;
dma_addr_t start, end; dma_addr_t start, end;
LIST_HEAD(resv_regions); LIST_HEAD(resv_regions);
...@@ -2560,7 +2562,7 @@ static void vfio_iommu_type1_detach_group(void *iommu_data, ...@@ -2560,7 +2562,7 @@ static void vfio_iommu_type1_detach_group(void *iommu_data,
{ {
struct vfio_iommu *iommu = iommu_data; struct vfio_iommu *iommu = iommu_data;
struct vfio_domain *domain; struct vfio_domain *domain;
struct vfio_group *group; struct vfio_iommu_group *group;
bool update_dirty_scope = false; bool update_dirty_scope = false;
LIST_HEAD(iova_copy); LIST_HEAD(iova_copy);
...@@ -2681,7 +2683,7 @@ static void *vfio_iommu_type1_open(unsigned long arg) ...@@ -2681,7 +2683,7 @@ static void *vfio_iommu_type1_open(unsigned long arg)
static void vfio_release_domain(struct vfio_domain *domain, bool external) static void vfio_release_domain(struct vfio_domain *domain, bool external)
{ {
struct vfio_group *group, *group_tmp; struct vfio_iommu_group *group, *group_tmp;
list_for_each_entry_safe(group, group_tmp, list_for_each_entry_safe(group, group_tmp,
&domain->group_list, next) { &domain->group_list, next) {
......
...@@ -846,6 +846,8 @@ static inline void *dev_get_platdata(const struct device *dev) ...@@ -846,6 +846,8 @@ static inline void *dev_get_platdata(const struct device *dev)
* Manual binding of a device to driver. See drivers/base/bus.c * Manual binding of a device to driver. See drivers/base/bus.c
* for information on use. * for information on use.
*/ */
int __must_check device_driver_attach(struct device_driver *drv,
struct device *dev);
int __must_check device_bind_driver(struct device *dev); int __must_check device_bind_driver(struct device *dev);
void device_release_driver(struct device *dev); void device_release_driver(struct device *dev);
int __must_check device_attach(struct device *dev); int __must_check device_attach(struct device *dev);
......
...@@ -55,6 +55,7 @@ struct device *mtype_get_parent_dev(struct mdev_type *mtype); ...@@ -55,6 +55,7 @@ struct device *mtype_get_parent_dev(struct mdev_type *mtype);
* register the device to mdev module. * register the device to mdev module.
* *
* @owner: The module owner. * @owner: The module owner.
* @device_driver: Which device driver to probe() on newly created devices
* @dev_attr_groups: Attributes of the parent device. * @dev_attr_groups: Attributes of the parent device.
* @mdev_attr_groups: Attributes of the mediated device. * @mdev_attr_groups: Attributes of the mediated device.
* @supported_type_groups: Attributes to define supported types. It is mandatory * @supported_type_groups: Attributes to define supported types. It is mandatory
...@@ -103,6 +104,7 @@ struct device *mtype_get_parent_dev(struct mdev_type *mtype); ...@@ -103,6 +104,7 @@ struct device *mtype_get_parent_dev(struct mdev_type *mtype);
**/ **/
struct mdev_parent_ops { struct mdev_parent_ops {
struct module *owner; struct module *owner;
struct mdev_driver *device_driver;
const struct attribute_group **dev_attr_groups; const struct attribute_group **dev_attr_groups;
const struct attribute_group **mdev_attr_groups; const struct attribute_group **mdev_attr_groups;
struct attribute_group **supported_type_groups; struct attribute_group **supported_type_groups;
......
...@@ -1624,6 +1624,9 @@ void pci_cfg_access_lock(struct pci_dev *dev); ...@@ -1624,6 +1624,9 @@ void pci_cfg_access_lock(struct pci_dev *dev);
bool pci_cfg_access_trylock(struct pci_dev *dev); bool pci_cfg_access_trylock(struct pci_dev *dev);
void pci_cfg_access_unlock(struct pci_dev *dev); void pci_cfg_access_unlock(struct pci_dev *dev);
int pci_dev_trylock(struct pci_dev *dev);
void pci_dev_unlock(struct pci_dev *dev);
/* /*
* PCI domain support. Sometimes called PCI segment (eg by ACPI), * PCI domain support. Sometimes called PCI segment (eg by ACPI),
* a PCI domain is defined to be a set of PCI buses which share * a PCI domain is defined to be a set of PCI buses which share
......
...@@ -154,14 +154,14 @@ config SAMPLE_UHID ...@@ -154,14 +154,14 @@ config SAMPLE_UHID
config SAMPLE_VFIO_MDEV_MTTY config SAMPLE_VFIO_MDEV_MTTY
tristate "Build VFIO mtty example mediated device sample code -- loadable modules only" tristate "Build VFIO mtty example mediated device sample code -- loadable modules only"
depends on VFIO_MDEV_DEVICE && m depends on VFIO_MDEV && m
help help
Build a virtual tty sample driver for use as a VFIO Build a virtual tty sample driver for use as a VFIO
mediated device mediated device
config SAMPLE_VFIO_MDEV_MDPY config SAMPLE_VFIO_MDEV_MDPY
tristate "Build VFIO mdpy example mediated device sample code -- loadable modules only" tristate "Build VFIO mdpy example mediated device sample code -- loadable modules only"
depends on VFIO_MDEV_DEVICE && m depends on VFIO_MDEV && m
help help
Build a virtual display sample driver for use as a VFIO Build a virtual display sample driver for use as a VFIO
mediated device. It is a simple framebuffer and supports mediated device. It is a simple framebuffer and supports
...@@ -178,7 +178,7 @@ config SAMPLE_VFIO_MDEV_MDPY_FB ...@@ -178,7 +178,7 @@ config SAMPLE_VFIO_MDEV_MDPY_FB
config SAMPLE_VFIO_MDEV_MBOCHS config SAMPLE_VFIO_MDEV_MBOCHS
tristate "Build VFIO mdpy example mediated device sample code -- loadable modules only" tristate "Build VFIO mdpy example mediated device sample code -- loadable modules only"
depends on VFIO_MDEV_DEVICE && m depends on VFIO_MDEV && m
select DMA_SHARED_BUFFER select DMA_SHARED_BUFFER
help help
Build a virtual display sample driver for use as a VFIO Build a virtual display sample driver for use as a VFIO
......
...@@ -130,6 +130,7 @@ static struct class *mbochs_class; ...@@ -130,6 +130,7 @@ static struct class *mbochs_class;
static struct cdev mbochs_cdev; static struct cdev mbochs_cdev;
static struct device mbochs_dev; static struct device mbochs_dev;
static int mbochs_used_mbytes; static int mbochs_used_mbytes;
static const struct vfio_device_ops mbochs_dev_ops;
struct vfio_region_info_ext { struct vfio_region_info_ext {
struct vfio_region_info base; struct vfio_region_info base;
...@@ -160,6 +161,7 @@ struct mbochs_dmabuf { ...@@ -160,6 +161,7 @@ struct mbochs_dmabuf {
/* State of each mdev device */ /* State of each mdev device */
struct mdev_state { struct mdev_state {
struct vfio_device vdev;
u8 *vconfig; u8 *vconfig;
u64 bar_mask[3]; u64 bar_mask[3];
u32 memory_bar_mask; u32 memory_bar_mask;
...@@ -425,11 +427,9 @@ static void handle_edid_blob(struct mdev_state *mdev_state, u16 offset, ...@@ -425,11 +427,9 @@ static void handle_edid_blob(struct mdev_state *mdev_state, u16 offset,
memcpy(buf, mdev_state->edid_blob + offset, count); memcpy(buf, mdev_state->edid_blob + offset, count);
} }
static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count, static ssize_t mdev_access(struct mdev_state *mdev_state, char *buf,
loff_t pos, bool is_write) size_t count, loff_t pos, bool is_write)
{ {
struct mdev_state *mdev_state = mdev_get_drvdata(mdev);
struct device *dev = mdev_dev(mdev);
struct page *pg; struct page *pg;
loff_t poff; loff_t poff;
char *map; char *map;
...@@ -478,7 +478,7 @@ static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count, ...@@ -478,7 +478,7 @@ static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count,
put_page(pg); put_page(pg);
} else { } else {
dev_dbg(dev, "%s: %s @0x%llx (unhandled)\n", dev_dbg(mdev_state->vdev.dev, "%s: %s @0x%llx (unhandled)\n",
__func__, is_write ? "WR" : "RD", pos); __func__, is_write ? "WR" : "RD", pos);
ret = -1; ret = -1;
goto accessfailed; goto accessfailed;
...@@ -493,9 +493,8 @@ static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count, ...@@ -493,9 +493,8 @@ static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count,
return ret; return ret;
} }
static int mbochs_reset(struct mdev_device *mdev) static int mbochs_reset(struct mdev_state *mdev_state)
{ {
struct mdev_state *mdev_state = mdev_get_drvdata(mdev);
u32 size64k = mdev_state->memsize / (64 * 1024); u32 size64k = mdev_state->memsize / (64 * 1024);
int i; int i;
...@@ -506,12 +505,13 @@ static int mbochs_reset(struct mdev_device *mdev) ...@@ -506,12 +505,13 @@ static int mbochs_reset(struct mdev_device *mdev)
return 0; return 0;
} }
static int mbochs_create(struct mdev_device *mdev) static int mbochs_probe(struct mdev_device *mdev)
{ {
const struct mbochs_type *type = const struct mbochs_type *type =
&mbochs_types[mdev_get_type_group_id(mdev)]; &mbochs_types[mdev_get_type_group_id(mdev)];
struct device *dev = mdev_dev(mdev); struct device *dev = mdev_dev(mdev);
struct mdev_state *mdev_state; struct mdev_state *mdev_state;
int ret = -ENOMEM;
if (type->mbytes + mbochs_used_mbytes > max_mbytes) if (type->mbytes + mbochs_used_mbytes > max_mbytes)
return -ENOMEM; return -ENOMEM;
...@@ -519,6 +519,7 @@ static int mbochs_create(struct mdev_device *mdev) ...@@ -519,6 +519,7 @@ static int mbochs_create(struct mdev_device *mdev)
mdev_state = kzalloc(sizeof(struct mdev_state), GFP_KERNEL); mdev_state = kzalloc(sizeof(struct mdev_state), GFP_KERNEL);
if (mdev_state == NULL) if (mdev_state == NULL)
return -ENOMEM; return -ENOMEM;
vfio_init_group_dev(&mdev_state->vdev, &mdev->dev, &mbochs_dev_ops);
mdev_state->vconfig = kzalloc(MBOCHS_CONFIG_SPACE_SIZE, GFP_KERNEL); mdev_state->vconfig = kzalloc(MBOCHS_CONFIG_SPACE_SIZE, GFP_KERNEL);
if (mdev_state->vconfig == NULL) if (mdev_state->vconfig == NULL)
...@@ -537,7 +538,6 @@ static int mbochs_create(struct mdev_device *mdev) ...@@ -537,7 +538,6 @@ static int mbochs_create(struct mdev_device *mdev)
mutex_init(&mdev_state->ops_lock); mutex_init(&mdev_state->ops_lock);
mdev_state->mdev = mdev; mdev_state->mdev = mdev;
mdev_set_drvdata(mdev, mdev_state);
INIT_LIST_HEAD(&mdev_state->dmabufs); INIT_LIST_HEAD(&mdev_state->dmabufs);
mdev_state->next_id = 1; mdev_state->next_id = 1;
...@@ -547,32 +547,38 @@ static int mbochs_create(struct mdev_device *mdev) ...@@ -547,32 +547,38 @@ static int mbochs_create(struct mdev_device *mdev)
mdev_state->edid_regs.edid_offset = MBOCHS_EDID_BLOB_OFFSET; mdev_state->edid_regs.edid_offset = MBOCHS_EDID_BLOB_OFFSET;
mdev_state->edid_regs.edid_max_size = sizeof(mdev_state->edid_blob); mdev_state->edid_regs.edid_max_size = sizeof(mdev_state->edid_blob);
mbochs_create_config_space(mdev_state); mbochs_create_config_space(mdev_state);
mbochs_reset(mdev); mbochs_reset(mdev_state);
mbochs_used_mbytes += type->mbytes; mbochs_used_mbytes += type->mbytes;
ret = vfio_register_group_dev(&mdev_state->vdev);
if (ret)
goto err_mem;
dev_set_drvdata(&mdev->dev, mdev_state);
return 0; return 0;
err_mem: err_mem:
kfree(mdev_state->vconfig); kfree(mdev_state->vconfig);
kfree(mdev_state); kfree(mdev_state);
return -ENOMEM; return ret;
} }
static int mbochs_remove(struct mdev_device *mdev) static void mbochs_remove(struct mdev_device *mdev)
{ {
struct mdev_state *mdev_state = mdev_get_drvdata(mdev); struct mdev_state *mdev_state = dev_get_drvdata(&mdev->dev);
mbochs_used_mbytes -= mdev_state->type->mbytes; mbochs_used_mbytes -= mdev_state->type->mbytes;
mdev_set_drvdata(mdev, NULL); vfio_unregister_group_dev(&mdev_state->vdev);
kfree(mdev_state->pages); kfree(mdev_state->pages);
kfree(mdev_state->vconfig); kfree(mdev_state->vconfig);
kfree(mdev_state); kfree(mdev_state);
return 0;
} }
static ssize_t mbochs_read(struct mdev_device *mdev, char __user *buf, static ssize_t mbochs_read(struct vfio_device *vdev, char __user *buf,
size_t count, loff_t *ppos) size_t count, loff_t *ppos)
{ {
struct mdev_state *mdev_state =
container_of(vdev, struct mdev_state, vdev);
unsigned int done = 0; unsigned int done = 0;
int ret; int ret;
...@@ -582,7 +588,7 @@ static ssize_t mbochs_read(struct mdev_device *mdev, char __user *buf, ...@@ -582,7 +588,7 @@ static ssize_t mbochs_read(struct mdev_device *mdev, char __user *buf,
if (count >= 4 && !(*ppos % 4)) { if (count >= 4 && !(*ppos % 4)) {
u32 val; u32 val;
ret = mdev_access(mdev, (char *)&val, sizeof(val), ret = mdev_access(mdev_state, (char *)&val, sizeof(val),
*ppos, false); *ppos, false);
if (ret <= 0) if (ret <= 0)
goto read_err; goto read_err;
...@@ -594,7 +600,7 @@ static ssize_t mbochs_read(struct mdev_device *mdev, char __user *buf, ...@@ -594,7 +600,7 @@ static ssize_t mbochs_read(struct mdev_device *mdev, char __user *buf,
} else if (count >= 2 && !(*ppos % 2)) { } else if (count >= 2 && !(*ppos % 2)) {
u16 val; u16 val;
ret = mdev_access(mdev, (char *)&val, sizeof(val), ret = mdev_access(mdev_state, (char *)&val, sizeof(val),
*ppos, false); *ppos, false);
if (ret <= 0) if (ret <= 0)
goto read_err; goto read_err;
...@@ -606,7 +612,7 @@ static ssize_t mbochs_read(struct mdev_device *mdev, char __user *buf, ...@@ -606,7 +612,7 @@ static ssize_t mbochs_read(struct mdev_device *mdev, char __user *buf,
} else { } else {
u8 val; u8 val;
ret = mdev_access(mdev, (char *)&val, sizeof(val), ret = mdev_access(mdev_state, (char *)&val, sizeof(val),
*ppos, false); *ppos, false);
if (ret <= 0) if (ret <= 0)
goto read_err; goto read_err;
...@@ -629,9 +635,11 @@ static ssize_t mbochs_read(struct mdev_device *mdev, char __user *buf, ...@@ -629,9 +635,11 @@ static ssize_t mbochs_read(struct mdev_device *mdev, char __user *buf,
return -EFAULT; return -EFAULT;
} }
static ssize_t mbochs_write(struct mdev_device *mdev, const char __user *buf, static ssize_t mbochs_write(struct vfio_device *vdev, const char __user *buf,
size_t count, loff_t *ppos) size_t count, loff_t *ppos)
{ {
struct mdev_state *mdev_state =
container_of(vdev, struct mdev_state, vdev);
unsigned int done = 0; unsigned int done = 0;
int ret; int ret;
...@@ -644,7 +652,7 @@ static ssize_t mbochs_write(struct mdev_device *mdev, const char __user *buf, ...@@ -644,7 +652,7 @@ static ssize_t mbochs_write(struct mdev_device *mdev, const char __user *buf,
if (copy_from_user(&val, buf, sizeof(val))) if (copy_from_user(&val, buf, sizeof(val)))
goto write_err; goto write_err;
ret = mdev_access(mdev, (char *)&val, sizeof(val), ret = mdev_access(mdev_state, (char *)&val, sizeof(val),
*ppos, true); *ppos, true);
if (ret <= 0) if (ret <= 0)
goto write_err; goto write_err;
...@@ -656,7 +664,7 @@ static ssize_t mbochs_write(struct mdev_device *mdev, const char __user *buf, ...@@ -656,7 +664,7 @@ static ssize_t mbochs_write(struct mdev_device *mdev, const char __user *buf,
if (copy_from_user(&val, buf, sizeof(val))) if (copy_from_user(&val, buf, sizeof(val)))
goto write_err; goto write_err;
ret = mdev_access(mdev, (char *)&val, sizeof(val), ret = mdev_access(mdev_state, (char *)&val, sizeof(val),
*ppos, true); *ppos, true);
if (ret <= 0) if (ret <= 0)
goto write_err; goto write_err;
...@@ -668,7 +676,7 @@ static ssize_t mbochs_write(struct mdev_device *mdev, const char __user *buf, ...@@ -668,7 +676,7 @@ static ssize_t mbochs_write(struct mdev_device *mdev, const char __user *buf,
if (copy_from_user(&val, buf, sizeof(val))) if (copy_from_user(&val, buf, sizeof(val)))
goto write_err; goto write_err;
ret = mdev_access(mdev, (char *)&val, sizeof(val), ret = mdev_access(mdev_state, (char *)&val, sizeof(val),
*ppos, true); *ppos, true);
if (ret <= 0) if (ret <= 0)
goto write_err; goto write_err;
...@@ -754,9 +762,10 @@ static const struct vm_operations_struct mbochs_region_vm_ops = { ...@@ -754,9 +762,10 @@ static const struct vm_operations_struct mbochs_region_vm_ops = {
.fault = mbochs_region_vm_fault, .fault = mbochs_region_vm_fault,
}; };
static int mbochs_mmap(struct mdev_device *mdev, struct vm_area_struct *vma) static int mbochs_mmap(struct vfio_device *vdev, struct vm_area_struct *vma)
{ {
struct mdev_state *mdev_state = mdev_get_drvdata(mdev); struct mdev_state *mdev_state =
container_of(vdev, struct mdev_state, vdev);
if (vma->vm_pgoff != MBOCHS_MEMORY_BAR_OFFSET >> PAGE_SHIFT) if (vma->vm_pgoff != MBOCHS_MEMORY_BAR_OFFSET >> PAGE_SHIFT)
return -EINVAL; return -EINVAL;
...@@ -963,7 +972,7 @@ mbochs_dmabuf_find_by_id(struct mdev_state *mdev_state, u32 id) ...@@ -963,7 +972,7 @@ mbochs_dmabuf_find_by_id(struct mdev_state *mdev_state, u32 id)
static int mbochs_dmabuf_export(struct mbochs_dmabuf *dmabuf) static int mbochs_dmabuf_export(struct mbochs_dmabuf *dmabuf)
{ {
struct mdev_state *mdev_state = dmabuf->mdev_state; struct mdev_state *mdev_state = dmabuf->mdev_state;
struct device *dev = mdev_dev(mdev_state->mdev); struct device *dev = mdev_state->vdev.dev;
DEFINE_DMA_BUF_EXPORT_INFO(exp_info); DEFINE_DMA_BUF_EXPORT_INFO(exp_info);
struct dma_buf *buf; struct dma_buf *buf;
...@@ -991,15 +1000,10 @@ static int mbochs_dmabuf_export(struct mbochs_dmabuf *dmabuf) ...@@ -991,15 +1000,10 @@ static int mbochs_dmabuf_export(struct mbochs_dmabuf *dmabuf)
return 0; return 0;
} }
static int mbochs_get_region_info(struct mdev_device *mdev, static int mbochs_get_region_info(struct mdev_state *mdev_state,
struct vfio_region_info_ext *ext) struct vfio_region_info_ext *ext)
{ {
struct vfio_region_info *region_info = &ext->base; struct vfio_region_info *region_info = &ext->base;
struct mdev_state *mdev_state;
mdev_state = mdev_get_drvdata(mdev);
if (!mdev_state)
return -EINVAL;
if (region_info->index >= MBOCHS_NUM_REGIONS) if (region_info->index >= MBOCHS_NUM_REGIONS)
return -EINVAL; return -EINVAL;
...@@ -1047,15 +1051,13 @@ static int mbochs_get_region_info(struct mdev_device *mdev, ...@@ -1047,15 +1051,13 @@ static int mbochs_get_region_info(struct mdev_device *mdev,
return 0; return 0;
} }
static int mbochs_get_irq_info(struct mdev_device *mdev, static int mbochs_get_irq_info(struct vfio_irq_info *irq_info)
struct vfio_irq_info *irq_info)
{ {
irq_info->count = 0; irq_info->count = 0;
return 0; return 0;
} }
static int mbochs_get_device_info(struct mdev_device *mdev, static int mbochs_get_device_info(struct vfio_device_info *dev_info)
struct vfio_device_info *dev_info)
{ {
dev_info->flags = VFIO_DEVICE_FLAGS_PCI; dev_info->flags = VFIO_DEVICE_FLAGS_PCI;
dev_info->num_regions = MBOCHS_NUM_REGIONS; dev_info->num_regions = MBOCHS_NUM_REGIONS;
...@@ -1063,11 +1065,9 @@ static int mbochs_get_device_info(struct mdev_device *mdev, ...@@ -1063,11 +1065,9 @@ static int mbochs_get_device_info(struct mdev_device *mdev,
return 0; return 0;
} }
static int mbochs_query_gfx_plane(struct mdev_device *mdev, static int mbochs_query_gfx_plane(struct mdev_state *mdev_state,
struct vfio_device_gfx_plane_info *plane) struct vfio_device_gfx_plane_info *plane)
{ {
struct mdev_state *mdev_state = mdev_get_drvdata(mdev);
struct device *dev = mdev_dev(mdev);
struct mbochs_dmabuf *dmabuf; struct mbochs_dmabuf *dmabuf;
struct mbochs_mode mode; struct mbochs_mode mode;
int ret; int ret;
...@@ -1121,18 +1121,16 @@ static int mbochs_query_gfx_plane(struct mdev_device *mdev, ...@@ -1121,18 +1121,16 @@ static int mbochs_query_gfx_plane(struct mdev_device *mdev,
done: done:
if (plane->drm_plane_type == DRM_PLANE_TYPE_PRIMARY && if (plane->drm_plane_type == DRM_PLANE_TYPE_PRIMARY &&
mdev_state->active_id != plane->dmabuf_id) { mdev_state->active_id != plane->dmabuf_id) {
dev_dbg(dev, "%s: primary: %d => %d\n", __func__, dev_dbg(mdev_state->vdev.dev, "%s: primary: %d => %d\n",
mdev_state->active_id, plane->dmabuf_id); __func__, mdev_state->active_id, plane->dmabuf_id);
mdev_state->active_id = plane->dmabuf_id; mdev_state->active_id = plane->dmabuf_id;
} }
mutex_unlock(&mdev_state->ops_lock); mutex_unlock(&mdev_state->ops_lock);
return 0; return 0;
} }
static int mbochs_get_gfx_dmabuf(struct mdev_device *mdev, static int mbochs_get_gfx_dmabuf(struct mdev_state *mdev_state, u32 id)
u32 id)
{ {
struct mdev_state *mdev_state = mdev_get_drvdata(mdev);
struct mbochs_dmabuf *dmabuf; struct mbochs_dmabuf *dmabuf;
mutex_lock(&mdev_state->ops_lock); mutex_lock(&mdev_state->ops_lock);
...@@ -1154,9 +1152,11 @@ static int mbochs_get_gfx_dmabuf(struct mdev_device *mdev, ...@@ -1154,9 +1152,11 @@ static int mbochs_get_gfx_dmabuf(struct mdev_device *mdev,
return dma_buf_fd(dmabuf->buf, 0); return dma_buf_fd(dmabuf->buf, 0);
} }
static long mbochs_ioctl(struct mdev_device *mdev, unsigned int cmd, static long mbochs_ioctl(struct vfio_device *vdev, unsigned int cmd,
unsigned long arg) unsigned long arg)
{ {
struct mdev_state *mdev_state =
container_of(vdev, struct mdev_state, vdev);
int ret = 0; int ret = 0;
unsigned long minsz, outsz; unsigned long minsz, outsz;
...@@ -1173,7 +1173,7 @@ static long mbochs_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -1173,7 +1173,7 @@ static long mbochs_ioctl(struct mdev_device *mdev, unsigned int cmd,
if (info.argsz < minsz) if (info.argsz < minsz)
return -EINVAL; return -EINVAL;
ret = mbochs_get_device_info(mdev, &info); ret = mbochs_get_device_info(&info);
if (ret) if (ret)
return ret; return ret;
...@@ -1197,7 +1197,7 @@ static long mbochs_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -1197,7 +1197,7 @@ static long mbochs_ioctl(struct mdev_device *mdev, unsigned int cmd,
if (outsz > sizeof(info)) if (outsz > sizeof(info))
return -EINVAL; return -EINVAL;
ret = mbochs_get_region_info(mdev, &info); ret = mbochs_get_region_info(mdev_state, &info);
if (ret) if (ret)
return ret; return ret;
...@@ -1220,7 +1220,7 @@ static long mbochs_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -1220,7 +1220,7 @@ static long mbochs_ioctl(struct mdev_device *mdev, unsigned int cmd,
(info.index >= VFIO_PCI_NUM_IRQS)) (info.index >= VFIO_PCI_NUM_IRQS))
return -EINVAL; return -EINVAL;
ret = mbochs_get_irq_info(mdev, &info); ret = mbochs_get_irq_info(&info);
if (ret) if (ret)
return ret; return ret;
...@@ -1243,7 +1243,7 @@ static long mbochs_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -1243,7 +1243,7 @@ static long mbochs_ioctl(struct mdev_device *mdev, unsigned int cmd,
if (plane.argsz < minsz) if (plane.argsz < minsz)
return -EINVAL; return -EINVAL;
ret = mbochs_query_gfx_plane(mdev, &plane); ret = mbochs_query_gfx_plane(mdev_state, &plane);
if (ret) if (ret)
return ret; return ret;
...@@ -1260,19 +1260,19 @@ static long mbochs_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -1260,19 +1260,19 @@ static long mbochs_ioctl(struct mdev_device *mdev, unsigned int cmd,
if (get_user(dmabuf_id, (__u32 __user *)arg)) if (get_user(dmabuf_id, (__u32 __user *)arg))
return -EFAULT; return -EFAULT;
return mbochs_get_gfx_dmabuf(mdev, dmabuf_id); return mbochs_get_gfx_dmabuf(mdev_state, dmabuf_id);
} }
case VFIO_DEVICE_SET_IRQS: case VFIO_DEVICE_SET_IRQS:
return -EINVAL; return -EINVAL;
case VFIO_DEVICE_RESET: case VFIO_DEVICE_RESET:
return mbochs_reset(mdev); return mbochs_reset(mdev_state);
} }
return -ENOTTY; return -ENOTTY;
} }
static int mbochs_open(struct mdev_device *mdev) static int mbochs_open(struct vfio_device *vdev)
{ {
if (!try_module_get(THIS_MODULE)) if (!try_module_get(THIS_MODULE))
return -ENODEV; return -ENODEV;
...@@ -1280,9 +1280,10 @@ static int mbochs_open(struct mdev_device *mdev) ...@@ -1280,9 +1280,10 @@ static int mbochs_open(struct mdev_device *mdev)
return 0; return 0;
} }
static void mbochs_close(struct mdev_device *mdev) static void mbochs_close(struct vfio_device *vdev)
{ {
struct mdev_state *mdev_state = mdev_get_drvdata(mdev); struct mdev_state *mdev_state =
container_of(vdev, struct mdev_state, vdev);
struct mbochs_dmabuf *dmabuf, *tmp; struct mbochs_dmabuf *dmabuf, *tmp;
mutex_lock(&mdev_state->ops_lock); mutex_lock(&mdev_state->ops_lock);
...@@ -1306,8 +1307,7 @@ static ssize_t ...@@ -1306,8 +1307,7 @@ static ssize_t
memory_show(struct device *dev, struct device_attribute *attr, memory_show(struct device *dev, struct device_attribute *attr,
char *buf) char *buf)
{ {
struct mdev_device *mdev = mdev_from_dev(dev); struct mdev_state *mdev_state = dev_get_drvdata(dev);
struct mdev_state *mdev_state = mdev_get_drvdata(mdev);
return sprintf(buf, "%d MB\n", mdev_state->type->mbytes); return sprintf(buf, "%d MB\n", mdev_state->type->mbytes);
} }
...@@ -1398,12 +1398,7 @@ static struct attribute_group *mdev_type_groups[] = { ...@@ -1398,12 +1398,7 @@ static struct attribute_group *mdev_type_groups[] = {
NULL, NULL,
}; };
static const struct mdev_parent_ops mdev_fops = { static const struct vfio_device_ops mbochs_dev_ops = {
.owner = THIS_MODULE,
.mdev_attr_groups = mdev_dev_groups,
.supported_type_groups = mdev_type_groups,
.create = mbochs_create,
.remove = mbochs_remove,
.open = mbochs_open, .open = mbochs_open,
.release = mbochs_close, .release = mbochs_close,
.read = mbochs_read, .read = mbochs_read,
...@@ -1412,6 +1407,23 @@ static const struct mdev_parent_ops mdev_fops = { ...@@ -1412,6 +1407,23 @@ static const struct mdev_parent_ops mdev_fops = {
.mmap = mbochs_mmap, .mmap = mbochs_mmap,
}; };
static struct mdev_driver mbochs_driver = {
.driver = {
.name = "mbochs",
.owner = THIS_MODULE,
.mod_name = KBUILD_MODNAME,
.dev_groups = mdev_dev_groups,
},
.probe = mbochs_probe,
.remove = mbochs_remove,
};
static const struct mdev_parent_ops mdev_fops = {
.owner = THIS_MODULE,
.device_driver = &mbochs_driver,
.supported_type_groups = mdev_type_groups,
};
static const struct file_operations vd_fops = { static const struct file_operations vd_fops = {
.owner = THIS_MODULE, .owner = THIS_MODULE,
}; };
...@@ -1434,11 +1446,15 @@ static int __init mbochs_dev_init(void) ...@@ -1434,11 +1446,15 @@ static int __init mbochs_dev_init(void)
cdev_add(&mbochs_cdev, mbochs_devt, MINORMASK + 1); cdev_add(&mbochs_cdev, mbochs_devt, MINORMASK + 1);
pr_info("%s: major %d\n", __func__, MAJOR(mbochs_devt)); pr_info("%s: major %d\n", __func__, MAJOR(mbochs_devt));
ret = mdev_register_driver(&mbochs_driver);
if (ret)
goto err_cdev;
mbochs_class = class_create(THIS_MODULE, MBOCHS_CLASS_NAME); mbochs_class = class_create(THIS_MODULE, MBOCHS_CLASS_NAME);
if (IS_ERR(mbochs_class)) { if (IS_ERR(mbochs_class)) {
pr_err("Error: failed to register mbochs_dev class\n"); pr_err("Error: failed to register mbochs_dev class\n");
ret = PTR_ERR(mbochs_class); ret = PTR_ERR(mbochs_class);
goto failed1; goto err_driver;
} }
mbochs_dev.class = mbochs_class; mbochs_dev.class = mbochs_class;
mbochs_dev.release = mbochs_device_release; mbochs_dev.release = mbochs_device_release;
...@@ -1446,19 +1462,21 @@ static int __init mbochs_dev_init(void) ...@@ -1446,19 +1462,21 @@ static int __init mbochs_dev_init(void)
ret = device_register(&mbochs_dev); ret = device_register(&mbochs_dev);
if (ret) if (ret)
goto failed2; goto err_class;
ret = mdev_register_device(&mbochs_dev, &mdev_fops); ret = mdev_register_device(&mbochs_dev, &mdev_fops);
if (ret) if (ret)
goto failed3; goto err_device;
return 0; return 0;
failed3: err_device:
device_unregister(&mbochs_dev); device_unregister(&mbochs_dev);
failed2: err_class:
class_destroy(mbochs_class); class_destroy(mbochs_class);
failed1: err_driver:
mdev_unregister_driver(&mbochs_driver);
err_cdev:
cdev_del(&mbochs_cdev); cdev_del(&mbochs_cdev);
unregister_chrdev_region(mbochs_devt, MINORMASK + 1); unregister_chrdev_region(mbochs_devt, MINORMASK + 1);
return ret; return ret;
...@@ -1470,6 +1488,7 @@ static void __exit mbochs_dev_exit(void) ...@@ -1470,6 +1488,7 @@ static void __exit mbochs_dev_exit(void)
mdev_unregister_device(&mbochs_dev); mdev_unregister_device(&mbochs_dev);
device_unregister(&mbochs_dev); device_unregister(&mbochs_dev);
mdev_unregister_driver(&mbochs_driver);
cdev_del(&mbochs_cdev); cdev_del(&mbochs_cdev);
unregister_chrdev_region(mbochs_devt, MINORMASK + 1); unregister_chrdev_region(mbochs_devt, MINORMASK + 1);
class_destroy(mbochs_class); class_destroy(mbochs_class);
......
...@@ -85,9 +85,11 @@ static struct class *mdpy_class; ...@@ -85,9 +85,11 @@ static struct class *mdpy_class;
static struct cdev mdpy_cdev; static struct cdev mdpy_cdev;
static struct device mdpy_dev; static struct device mdpy_dev;
static u32 mdpy_count; static u32 mdpy_count;
static const struct vfio_device_ops mdpy_dev_ops;
/* State of each mdev device */ /* State of each mdev device */
struct mdev_state { struct mdev_state {
struct vfio_device vdev;
u8 *vconfig; u8 *vconfig;
u32 bar_mask; u32 bar_mask;
struct mutex ops_lock; struct mutex ops_lock;
...@@ -162,11 +164,9 @@ static void handle_pci_cfg_write(struct mdev_state *mdev_state, u16 offset, ...@@ -162,11 +164,9 @@ static void handle_pci_cfg_write(struct mdev_state *mdev_state, u16 offset,
} }
} }
static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count, static ssize_t mdev_access(struct mdev_state *mdev_state, char *buf,
loff_t pos, bool is_write) size_t count, loff_t pos, bool is_write)
{ {
struct mdev_state *mdev_state = mdev_get_drvdata(mdev);
struct device *dev = mdev_dev(mdev);
int ret = 0; int ret = 0;
mutex_lock(&mdev_state->ops_lock); mutex_lock(&mdev_state->ops_lock);
...@@ -187,8 +187,9 @@ static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count, ...@@ -187,8 +187,9 @@ static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count,
memcpy(buf, mdev_state->memblk, count); memcpy(buf, mdev_state->memblk, count);
} else { } else {
dev_info(dev, "%s: %s @0x%llx (unhandled)\n", dev_info(mdev_state->vdev.dev,
__func__, is_write ? "WR" : "RD", pos); "%s: %s @0x%llx (unhandled)\n", __func__,
is_write ? "WR" : "RD", pos);
ret = -1; ret = -1;
goto accessfailed; goto accessfailed;
} }
...@@ -202,9 +203,8 @@ static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count, ...@@ -202,9 +203,8 @@ static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count,
return ret; return ret;
} }
static int mdpy_reset(struct mdev_device *mdev) static int mdpy_reset(struct mdev_state *mdev_state)
{ {
struct mdev_state *mdev_state = mdev_get_drvdata(mdev);
u32 stride, i; u32 stride, i;
/* initialize with gray gradient */ /* initialize with gray gradient */
...@@ -216,13 +216,14 @@ static int mdpy_reset(struct mdev_device *mdev) ...@@ -216,13 +216,14 @@ static int mdpy_reset(struct mdev_device *mdev)
return 0; return 0;
} }
static int mdpy_create(struct mdev_device *mdev) static int mdpy_probe(struct mdev_device *mdev)
{ {
const struct mdpy_type *type = const struct mdpy_type *type =
&mdpy_types[mdev_get_type_group_id(mdev)]; &mdpy_types[mdev_get_type_group_id(mdev)];
struct device *dev = mdev_dev(mdev); struct device *dev = mdev_dev(mdev);
struct mdev_state *mdev_state; struct mdev_state *mdev_state;
u32 fbsize; u32 fbsize;
int ret;
if (mdpy_count >= max_devices) if (mdpy_count >= max_devices)
return -ENOMEM; return -ENOMEM;
...@@ -230,6 +231,7 @@ static int mdpy_create(struct mdev_device *mdev) ...@@ -230,6 +231,7 @@ static int mdpy_create(struct mdev_device *mdev)
mdev_state = kzalloc(sizeof(struct mdev_state), GFP_KERNEL); mdev_state = kzalloc(sizeof(struct mdev_state), GFP_KERNEL);
if (mdev_state == NULL) if (mdev_state == NULL)
return -ENOMEM; return -ENOMEM;
vfio_init_group_dev(&mdev_state->vdev, &mdev->dev, &mdpy_dev_ops);
mdev_state->vconfig = kzalloc(MDPY_CONFIG_SPACE_SIZE, GFP_KERNEL); mdev_state->vconfig = kzalloc(MDPY_CONFIG_SPACE_SIZE, GFP_KERNEL);
if (mdev_state->vconfig == NULL) { if (mdev_state->vconfig == NULL) {
...@@ -250,36 +252,42 @@ static int mdpy_create(struct mdev_device *mdev) ...@@ -250,36 +252,42 @@ static int mdpy_create(struct mdev_device *mdev)
mutex_init(&mdev_state->ops_lock); mutex_init(&mdev_state->ops_lock);
mdev_state->mdev = mdev; mdev_state->mdev = mdev;
mdev_set_drvdata(mdev, mdev_state);
mdev_state->type = type; mdev_state->type = type;
mdev_state->memsize = fbsize; mdev_state->memsize = fbsize;
mdpy_create_config_space(mdev_state); mdpy_create_config_space(mdev_state);
mdpy_reset(mdev); mdpy_reset(mdev_state);
mdpy_count++; mdpy_count++;
ret = vfio_register_group_dev(&mdev_state->vdev);
if (ret) {
kfree(mdev_state->vconfig);
kfree(mdev_state);
return ret;
}
dev_set_drvdata(&mdev->dev, mdev_state);
return 0; return 0;
} }
static int mdpy_remove(struct mdev_device *mdev) static void mdpy_remove(struct mdev_device *mdev)
{ {
struct mdev_state *mdev_state = mdev_get_drvdata(mdev); struct mdev_state *mdev_state = dev_get_drvdata(&mdev->dev);
struct device *dev = mdev_dev(mdev);
dev_info(dev, "%s\n", __func__); dev_info(&mdev->dev, "%s\n", __func__);
mdev_set_drvdata(mdev, NULL); vfio_unregister_group_dev(&mdev_state->vdev);
vfree(mdev_state->memblk); vfree(mdev_state->memblk);
kfree(mdev_state->vconfig); kfree(mdev_state->vconfig);
kfree(mdev_state); kfree(mdev_state);
mdpy_count--; mdpy_count--;
return 0;
} }
static ssize_t mdpy_read(struct mdev_device *mdev, char __user *buf, static ssize_t mdpy_read(struct vfio_device *vdev, char __user *buf,
size_t count, loff_t *ppos) size_t count, loff_t *ppos)
{ {
struct mdev_state *mdev_state =
container_of(vdev, struct mdev_state, vdev);
unsigned int done = 0; unsigned int done = 0;
int ret; int ret;
...@@ -289,7 +297,7 @@ static ssize_t mdpy_read(struct mdev_device *mdev, char __user *buf, ...@@ -289,7 +297,7 @@ static ssize_t mdpy_read(struct mdev_device *mdev, char __user *buf,
if (count >= 4 && !(*ppos % 4)) { if (count >= 4 && !(*ppos % 4)) {
u32 val; u32 val;
ret = mdev_access(mdev, (char *)&val, sizeof(val), ret = mdev_access(mdev_state, (char *)&val, sizeof(val),
*ppos, false); *ppos, false);
if (ret <= 0) if (ret <= 0)
goto read_err; goto read_err;
...@@ -301,7 +309,7 @@ static ssize_t mdpy_read(struct mdev_device *mdev, char __user *buf, ...@@ -301,7 +309,7 @@ static ssize_t mdpy_read(struct mdev_device *mdev, char __user *buf,
} else if (count >= 2 && !(*ppos % 2)) { } else if (count >= 2 && !(*ppos % 2)) {
u16 val; u16 val;
ret = mdev_access(mdev, (char *)&val, sizeof(val), ret = mdev_access(mdev_state, (char *)&val, sizeof(val),
*ppos, false); *ppos, false);
if (ret <= 0) if (ret <= 0)
goto read_err; goto read_err;
...@@ -313,7 +321,7 @@ static ssize_t mdpy_read(struct mdev_device *mdev, char __user *buf, ...@@ -313,7 +321,7 @@ static ssize_t mdpy_read(struct mdev_device *mdev, char __user *buf,
} else { } else {
u8 val; u8 val;
ret = mdev_access(mdev, (char *)&val, sizeof(val), ret = mdev_access(mdev_state, (char *)&val, sizeof(val),
*ppos, false); *ppos, false);
if (ret <= 0) if (ret <= 0)
goto read_err; goto read_err;
...@@ -336,9 +344,11 @@ static ssize_t mdpy_read(struct mdev_device *mdev, char __user *buf, ...@@ -336,9 +344,11 @@ static ssize_t mdpy_read(struct mdev_device *mdev, char __user *buf,
return -EFAULT; return -EFAULT;
} }
static ssize_t mdpy_write(struct mdev_device *mdev, const char __user *buf, static ssize_t mdpy_write(struct vfio_device *vdev, const char __user *buf,
size_t count, loff_t *ppos) size_t count, loff_t *ppos)
{ {
struct mdev_state *mdev_state =
container_of(vdev, struct mdev_state, vdev);
unsigned int done = 0; unsigned int done = 0;
int ret; int ret;
...@@ -351,7 +361,7 @@ static ssize_t mdpy_write(struct mdev_device *mdev, const char __user *buf, ...@@ -351,7 +361,7 @@ static ssize_t mdpy_write(struct mdev_device *mdev, const char __user *buf,
if (copy_from_user(&val, buf, sizeof(val))) if (copy_from_user(&val, buf, sizeof(val)))
goto write_err; goto write_err;
ret = mdev_access(mdev, (char *)&val, sizeof(val), ret = mdev_access(mdev_state, (char *)&val, sizeof(val),
*ppos, true); *ppos, true);
if (ret <= 0) if (ret <= 0)
goto write_err; goto write_err;
...@@ -363,7 +373,7 @@ static ssize_t mdpy_write(struct mdev_device *mdev, const char __user *buf, ...@@ -363,7 +373,7 @@ static ssize_t mdpy_write(struct mdev_device *mdev, const char __user *buf,
if (copy_from_user(&val, buf, sizeof(val))) if (copy_from_user(&val, buf, sizeof(val)))
goto write_err; goto write_err;
ret = mdev_access(mdev, (char *)&val, sizeof(val), ret = mdev_access(mdev_state, (char *)&val, sizeof(val),
*ppos, true); *ppos, true);
if (ret <= 0) if (ret <= 0)
goto write_err; goto write_err;
...@@ -375,7 +385,7 @@ static ssize_t mdpy_write(struct mdev_device *mdev, const char __user *buf, ...@@ -375,7 +385,7 @@ static ssize_t mdpy_write(struct mdev_device *mdev, const char __user *buf,
if (copy_from_user(&val, buf, sizeof(val))) if (copy_from_user(&val, buf, sizeof(val)))
goto write_err; goto write_err;
ret = mdev_access(mdev, (char *)&val, sizeof(val), ret = mdev_access(mdev_state, (char *)&val, sizeof(val),
*ppos, true); *ppos, true);
if (ret <= 0) if (ret <= 0)
goto write_err; goto write_err;
...@@ -393,9 +403,10 @@ static ssize_t mdpy_write(struct mdev_device *mdev, const char __user *buf, ...@@ -393,9 +403,10 @@ static ssize_t mdpy_write(struct mdev_device *mdev, const char __user *buf,
return -EFAULT; return -EFAULT;
} }
static int mdpy_mmap(struct mdev_device *mdev, struct vm_area_struct *vma) static int mdpy_mmap(struct vfio_device *vdev, struct vm_area_struct *vma)
{ {
struct mdev_state *mdev_state = mdev_get_drvdata(mdev); struct mdev_state *mdev_state =
container_of(vdev, struct mdev_state, vdev);
if (vma->vm_pgoff != MDPY_MEMORY_BAR_OFFSET >> PAGE_SHIFT) if (vma->vm_pgoff != MDPY_MEMORY_BAR_OFFSET >> PAGE_SHIFT)
return -EINVAL; return -EINVAL;
...@@ -409,16 +420,10 @@ static int mdpy_mmap(struct mdev_device *mdev, struct vm_area_struct *vma) ...@@ -409,16 +420,10 @@ static int mdpy_mmap(struct mdev_device *mdev, struct vm_area_struct *vma)
return remap_vmalloc_range(vma, mdev_state->memblk, 0); return remap_vmalloc_range(vma, mdev_state->memblk, 0);
} }
static int mdpy_get_region_info(struct mdev_device *mdev, static int mdpy_get_region_info(struct mdev_state *mdev_state,
struct vfio_region_info *region_info, struct vfio_region_info *region_info,
u16 *cap_type_id, void **cap_type) u16 *cap_type_id, void **cap_type)
{ {
struct mdev_state *mdev_state;
mdev_state = mdev_get_drvdata(mdev);
if (!mdev_state)
return -EINVAL;
if (region_info->index >= VFIO_PCI_NUM_REGIONS && if (region_info->index >= VFIO_PCI_NUM_REGIONS &&
region_info->index != MDPY_DISPLAY_REGION) region_info->index != MDPY_DISPLAY_REGION)
return -EINVAL; return -EINVAL;
...@@ -447,15 +452,13 @@ static int mdpy_get_region_info(struct mdev_device *mdev, ...@@ -447,15 +452,13 @@ static int mdpy_get_region_info(struct mdev_device *mdev,
return 0; return 0;
} }
static int mdpy_get_irq_info(struct mdev_device *mdev, static int mdpy_get_irq_info(struct vfio_irq_info *irq_info)
struct vfio_irq_info *irq_info)
{ {
irq_info->count = 0; irq_info->count = 0;
return 0; return 0;
} }
static int mdpy_get_device_info(struct mdev_device *mdev, static int mdpy_get_device_info(struct vfio_device_info *dev_info)
struct vfio_device_info *dev_info)
{ {
dev_info->flags = VFIO_DEVICE_FLAGS_PCI; dev_info->flags = VFIO_DEVICE_FLAGS_PCI;
dev_info->num_regions = VFIO_PCI_NUM_REGIONS; dev_info->num_regions = VFIO_PCI_NUM_REGIONS;
...@@ -463,11 +466,9 @@ static int mdpy_get_device_info(struct mdev_device *mdev, ...@@ -463,11 +466,9 @@ static int mdpy_get_device_info(struct mdev_device *mdev,
return 0; return 0;
} }
static int mdpy_query_gfx_plane(struct mdev_device *mdev, static int mdpy_query_gfx_plane(struct mdev_state *mdev_state,
struct vfio_device_gfx_plane_info *plane) struct vfio_device_gfx_plane_info *plane)
{ {
struct mdev_state *mdev_state = mdev_get_drvdata(mdev);
if (plane->flags & VFIO_GFX_PLANE_TYPE_PROBE) { if (plane->flags & VFIO_GFX_PLANE_TYPE_PROBE) {
if (plane->flags == (VFIO_GFX_PLANE_TYPE_PROBE | if (plane->flags == (VFIO_GFX_PLANE_TYPE_PROBE |
VFIO_GFX_PLANE_TYPE_REGION)) VFIO_GFX_PLANE_TYPE_REGION))
...@@ -496,14 +497,13 @@ static int mdpy_query_gfx_plane(struct mdev_device *mdev, ...@@ -496,14 +497,13 @@ static int mdpy_query_gfx_plane(struct mdev_device *mdev,
return 0; return 0;
} }
static long mdpy_ioctl(struct mdev_device *mdev, unsigned int cmd, static long mdpy_ioctl(struct vfio_device *vdev, unsigned int cmd,
unsigned long arg) unsigned long arg)
{ {
int ret = 0; int ret = 0;
unsigned long minsz; unsigned long minsz;
struct mdev_state *mdev_state; struct mdev_state *mdev_state =
container_of(vdev, struct mdev_state, vdev);
mdev_state = mdev_get_drvdata(mdev);
switch (cmd) { switch (cmd) {
case VFIO_DEVICE_GET_INFO: case VFIO_DEVICE_GET_INFO:
...@@ -518,7 +518,7 @@ static long mdpy_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -518,7 +518,7 @@ static long mdpy_ioctl(struct mdev_device *mdev, unsigned int cmd,
if (info.argsz < minsz) if (info.argsz < minsz)
return -EINVAL; return -EINVAL;
ret = mdpy_get_device_info(mdev, &info); ret = mdpy_get_device_info(&info);
if (ret) if (ret)
return ret; return ret;
...@@ -543,7 +543,7 @@ static long mdpy_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -543,7 +543,7 @@ static long mdpy_ioctl(struct mdev_device *mdev, unsigned int cmd,
if (info.argsz < minsz) if (info.argsz < minsz)
return -EINVAL; return -EINVAL;
ret = mdpy_get_region_info(mdev, &info, &cap_type_id, ret = mdpy_get_region_info(mdev_state, &info, &cap_type_id,
&cap_type); &cap_type);
if (ret) if (ret)
return ret; return ret;
...@@ -567,7 +567,7 @@ static long mdpy_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -567,7 +567,7 @@ static long mdpy_ioctl(struct mdev_device *mdev, unsigned int cmd,
(info.index >= mdev_state->dev_info.num_irqs)) (info.index >= mdev_state->dev_info.num_irqs))
return -EINVAL; return -EINVAL;
ret = mdpy_get_irq_info(mdev, &info); ret = mdpy_get_irq_info(&info);
if (ret) if (ret)
return ret; return ret;
...@@ -590,7 +590,7 @@ static long mdpy_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -590,7 +590,7 @@ static long mdpy_ioctl(struct mdev_device *mdev, unsigned int cmd,
if (plane.argsz < minsz) if (plane.argsz < minsz)
return -EINVAL; return -EINVAL;
ret = mdpy_query_gfx_plane(mdev, &plane); ret = mdpy_query_gfx_plane(mdev_state, &plane);
if (ret) if (ret)
return ret; return ret;
...@@ -604,12 +604,12 @@ static long mdpy_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -604,12 +604,12 @@ static long mdpy_ioctl(struct mdev_device *mdev, unsigned int cmd,
return -EINVAL; return -EINVAL;
case VFIO_DEVICE_RESET: case VFIO_DEVICE_RESET:
return mdpy_reset(mdev); return mdpy_reset(mdev_state);
} }
return -ENOTTY; return -ENOTTY;
} }
static int mdpy_open(struct mdev_device *mdev) static int mdpy_open(struct vfio_device *vdev)
{ {
if (!try_module_get(THIS_MODULE)) if (!try_module_get(THIS_MODULE))
return -ENODEV; return -ENODEV;
...@@ -617,7 +617,7 @@ static int mdpy_open(struct mdev_device *mdev) ...@@ -617,7 +617,7 @@ static int mdpy_open(struct mdev_device *mdev)
return 0; return 0;
} }
static void mdpy_close(struct mdev_device *mdev) static void mdpy_close(struct vfio_device *vdev)
{ {
module_put(THIS_MODULE); module_put(THIS_MODULE);
} }
...@@ -626,8 +626,7 @@ static ssize_t ...@@ -626,8 +626,7 @@ static ssize_t
resolution_show(struct device *dev, struct device_attribute *attr, resolution_show(struct device *dev, struct device_attribute *attr,
char *buf) char *buf)
{ {
struct mdev_device *mdev = mdev_from_dev(dev); struct mdev_state *mdev_state = dev_get_drvdata(dev);
struct mdev_state *mdev_state = mdev_get_drvdata(mdev);
return sprintf(buf, "%dx%d\n", return sprintf(buf, "%dx%d\n",
mdev_state->type->width, mdev_state->type->width,
...@@ -716,12 +715,7 @@ static struct attribute_group *mdev_type_groups[] = { ...@@ -716,12 +715,7 @@ static struct attribute_group *mdev_type_groups[] = {
NULL, NULL,
}; };
static const struct mdev_parent_ops mdev_fops = { static const struct vfio_device_ops mdpy_dev_ops = {
.owner = THIS_MODULE,
.mdev_attr_groups = mdev_dev_groups,
.supported_type_groups = mdev_type_groups,
.create = mdpy_create,
.remove = mdpy_remove,
.open = mdpy_open, .open = mdpy_open,
.release = mdpy_close, .release = mdpy_close,
.read = mdpy_read, .read = mdpy_read,
...@@ -730,6 +724,23 @@ static const struct mdev_parent_ops mdev_fops = { ...@@ -730,6 +724,23 @@ static const struct mdev_parent_ops mdev_fops = {
.mmap = mdpy_mmap, .mmap = mdpy_mmap,
}; };
static struct mdev_driver mdpy_driver = {
.driver = {
.name = "mdpy",
.owner = THIS_MODULE,
.mod_name = KBUILD_MODNAME,
.dev_groups = mdev_dev_groups,
},
.probe = mdpy_probe,
.remove = mdpy_remove,
};
static const struct mdev_parent_ops mdev_fops = {
.owner = THIS_MODULE,
.device_driver = &mdpy_driver,
.supported_type_groups = mdev_type_groups,
};
static const struct file_operations vd_fops = { static const struct file_operations vd_fops = {
.owner = THIS_MODULE, .owner = THIS_MODULE,
}; };
...@@ -752,11 +763,15 @@ static int __init mdpy_dev_init(void) ...@@ -752,11 +763,15 @@ static int __init mdpy_dev_init(void)
cdev_add(&mdpy_cdev, mdpy_devt, MINORMASK + 1); cdev_add(&mdpy_cdev, mdpy_devt, MINORMASK + 1);
pr_info("%s: major %d\n", __func__, MAJOR(mdpy_devt)); pr_info("%s: major %d\n", __func__, MAJOR(mdpy_devt));
ret = mdev_register_driver(&mdpy_driver);
if (ret)
goto err_cdev;
mdpy_class = class_create(THIS_MODULE, MDPY_CLASS_NAME); mdpy_class = class_create(THIS_MODULE, MDPY_CLASS_NAME);
if (IS_ERR(mdpy_class)) { if (IS_ERR(mdpy_class)) {
pr_err("Error: failed to register mdpy_dev class\n"); pr_err("Error: failed to register mdpy_dev class\n");
ret = PTR_ERR(mdpy_class); ret = PTR_ERR(mdpy_class);
goto failed1; goto err_driver;
} }
mdpy_dev.class = mdpy_class; mdpy_dev.class = mdpy_class;
mdpy_dev.release = mdpy_device_release; mdpy_dev.release = mdpy_device_release;
...@@ -764,19 +779,21 @@ static int __init mdpy_dev_init(void) ...@@ -764,19 +779,21 @@ static int __init mdpy_dev_init(void)
ret = device_register(&mdpy_dev); ret = device_register(&mdpy_dev);
if (ret) if (ret)
goto failed2; goto err_class;
ret = mdev_register_device(&mdpy_dev, &mdev_fops); ret = mdev_register_device(&mdpy_dev, &mdev_fops);
if (ret) if (ret)
goto failed3; goto err_device;
return 0; return 0;
failed3: err_device:
device_unregister(&mdpy_dev); device_unregister(&mdpy_dev);
failed2: err_class:
class_destroy(mdpy_class); class_destroy(mdpy_class);
failed1: err_driver:
mdev_unregister_driver(&mdpy_driver);
err_cdev:
cdev_del(&mdpy_cdev); cdev_del(&mdpy_cdev);
unregister_chrdev_region(mdpy_devt, MINORMASK + 1); unregister_chrdev_region(mdpy_devt, MINORMASK + 1);
return ret; return ret;
...@@ -788,6 +805,7 @@ static void __exit mdpy_dev_exit(void) ...@@ -788,6 +805,7 @@ static void __exit mdpy_dev_exit(void)
mdev_unregister_device(&mdpy_dev); mdev_unregister_device(&mdpy_dev);
device_unregister(&mdpy_dev); device_unregister(&mdpy_dev);
mdev_unregister_driver(&mdpy_driver);
cdev_del(&mdpy_cdev); cdev_del(&mdpy_cdev);
unregister_chrdev_region(mdpy_devt, MINORMASK + 1); unregister_chrdev_region(mdpy_devt, MINORMASK + 1);
class_destroy(mdpy_class); class_destroy(mdpy_class);
......
...@@ -127,6 +127,7 @@ struct serial_port { ...@@ -127,6 +127,7 @@ struct serial_port {
/* State of each mdev device */ /* State of each mdev device */
struct mdev_state { struct mdev_state {
struct vfio_device vdev;
int irq_fd; int irq_fd;
struct eventfd_ctx *intx_evtfd; struct eventfd_ctx *intx_evtfd;
struct eventfd_ctx *msi_evtfd; struct eventfd_ctx *msi_evtfd;
...@@ -143,13 +144,14 @@ struct mdev_state { ...@@ -143,13 +144,14 @@ struct mdev_state {
int nr_ports; int nr_ports;
}; };
static struct mutex mdev_list_lock; static atomic_t mdev_avail_ports = ATOMIC_INIT(MAX_MTTYS);
static struct list_head mdev_devices_list;
static const struct file_operations vd_fops = { static const struct file_operations vd_fops = {
.owner = THIS_MODULE, .owner = THIS_MODULE,
}; };
static const struct vfio_device_ops mtty_dev_ops;
/* function prototypes */ /* function prototypes */
static int mtty_trigger_interrupt(struct mdev_state *mdev_state); static int mtty_trigger_interrupt(struct mdev_state *mdev_state);
...@@ -631,22 +633,15 @@ static void mdev_read_base(struct mdev_state *mdev_state) ...@@ -631,22 +633,15 @@ static void mdev_read_base(struct mdev_state *mdev_state)
} }
} }
static ssize_t mdev_access(struct mdev_device *mdev, u8 *buf, size_t count, static ssize_t mdev_access(struct mdev_state *mdev_state, u8 *buf, size_t count,
loff_t pos, bool is_write) loff_t pos, bool is_write)
{ {
struct mdev_state *mdev_state;
unsigned int index; unsigned int index;
loff_t offset; loff_t offset;
int ret = 0; int ret = 0;
if (!mdev || !buf) if (!buf)
return -EINVAL;
mdev_state = mdev_get_drvdata(mdev);
if (!mdev_state) {
pr_err("%s mdev_state not found\n", __func__);
return -EINVAL; return -EINVAL;
}
mutex_lock(&mdev_state->ops_lock); mutex_lock(&mdev_state->ops_lock);
...@@ -708,14 +703,26 @@ static ssize_t mdev_access(struct mdev_device *mdev, u8 *buf, size_t count, ...@@ -708,14 +703,26 @@ static ssize_t mdev_access(struct mdev_device *mdev, u8 *buf, size_t count,
return ret; return ret;
} }
static int mtty_create(struct mdev_device *mdev) static int mtty_probe(struct mdev_device *mdev)
{ {
struct mdev_state *mdev_state; struct mdev_state *mdev_state;
int nr_ports = mdev_get_type_group_id(mdev) + 1; int nr_ports = mdev_get_type_group_id(mdev) + 1;
int avail_ports = atomic_read(&mdev_avail_ports);
int ret;
do {
if (avail_ports < nr_ports)
return -ENOSPC;
} while (!atomic_try_cmpxchg(&mdev_avail_ports,
&avail_ports, avail_ports - nr_ports));
mdev_state = kzalloc(sizeof(struct mdev_state), GFP_KERNEL); mdev_state = kzalloc(sizeof(struct mdev_state), GFP_KERNEL);
if (mdev_state == NULL) if (mdev_state == NULL) {
atomic_add(nr_ports, &mdev_avail_ports);
return -ENOMEM; return -ENOMEM;
}
vfio_init_group_dev(&mdev_state->vdev, &mdev->dev, &mtty_dev_ops);
mdev_state->nr_ports = nr_ports; mdev_state->nr_ports = nr_ports;
mdev_state->irq_index = -1; mdev_state->irq_index = -1;
...@@ -726,63 +733,50 @@ static int mtty_create(struct mdev_device *mdev) ...@@ -726,63 +733,50 @@ static int mtty_create(struct mdev_device *mdev)
if (mdev_state->vconfig == NULL) { if (mdev_state->vconfig == NULL) {
kfree(mdev_state); kfree(mdev_state);
atomic_add(nr_ports, &mdev_avail_ports);
return -ENOMEM; return -ENOMEM;
} }
mutex_init(&mdev_state->ops_lock); mutex_init(&mdev_state->ops_lock);
mdev_state->mdev = mdev; mdev_state->mdev = mdev;
mdev_set_drvdata(mdev, mdev_state);
mtty_create_config_space(mdev_state); mtty_create_config_space(mdev_state);
mutex_lock(&mdev_list_lock); ret = vfio_register_group_dev(&mdev_state->vdev);
list_add(&mdev_state->next, &mdev_devices_list); if (ret) {
mutex_unlock(&mdev_list_lock); kfree(mdev_state);
atomic_add(nr_ports, &mdev_avail_ports);
return ret;
}
dev_set_drvdata(&mdev->dev, mdev_state);
return 0; return 0;
} }
static int mtty_remove(struct mdev_device *mdev) static void mtty_remove(struct mdev_device *mdev)
{ {
struct mdev_state *mds, *tmp_mds; struct mdev_state *mdev_state = dev_get_drvdata(&mdev->dev);
struct mdev_state *mdev_state = mdev_get_drvdata(mdev); int nr_ports = mdev_state->nr_ports;
int ret = -EINVAL;
vfio_unregister_group_dev(&mdev_state->vdev);
mutex_lock(&mdev_list_lock);
list_for_each_entry_safe(mds, tmp_mds, &mdev_devices_list, next) {
if (mdev_state == mds) {
list_del(&mdev_state->next);
mdev_set_drvdata(mdev, NULL);
kfree(mdev_state->vconfig); kfree(mdev_state->vconfig);
kfree(mdev_state); kfree(mdev_state);
ret = 0; atomic_add(nr_ports, &mdev_avail_ports);
break;
}
}
mutex_unlock(&mdev_list_lock);
return ret;
} }
static int mtty_reset(struct mdev_device *mdev) static int mtty_reset(struct mdev_state *mdev_state)
{ {
struct mdev_state *mdev_state;
if (!mdev)
return -EINVAL;
mdev_state = mdev_get_drvdata(mdev);
if (!mdev_state)
return -EINVAL;
pr_info("%s: called\n", __func__); pr_info("%s: called\n", __func__);
return 0; return 0;
} }
static ssize_t mtty_read(struct mdev_device *mdev, char __user *buf, static ssize_t mtty_read(struct vfio_device *vdev, char __user *buf,
size_t count, loff_t *ppos) size_t count, loff_t *ppos)
{ {
struct mdev_state *mdev_state =
container_of(vdev, struct mdev_state, vdev);
unsigned int done = 0; unsigned int done = 0;
int ret; int ret;
...@@ -792,7 +786,7 @@ static ssize_t mtty_read(struct mdev_device *mdev, char __user *buf, ...@@ -792,7 +786,7 @@ static ssize_t mtty_read(struct mdev_device *mdev, char __user *buf,
if (count >= 4 && !(*ppos % 4)) { if (count >= 4 && !(*ppos % 4)) {
u32 val; u32 val;
ret = mdev_access(mdev, (u8 *)&val, sizeof(val), ret = mdev_access(mdev_state, (u8 *)&val, sizeof(val),
*ppos, false); *ppos, false);
if (ret <= 0) if (ret <= 0)
goto read_err; goto read_err;
...@@ -804,7 +798,7 @@ static ssize_t mtty_read(struct mdev_device *mdev, char __user *buf, ...@@ -804,7 +798,7 @@ static ssize_t mtty_read(struct mdev_device *mdev, char __user *buf,
} else if (count >= 2 && !(*ppos % 2)) { } else if (count >= 2 && !(*ppos % 2)) {
u16 val; u16 val;
ret = mdev_access(mdev, (u8 *)&val, sizeof(val), ret = mdev_access(mdev_state, (u8 *)&val, sizeof(val),
*ppos, false); *ppos, false);
if (ret <= 0) if (ret <= 0)
goto read_err; goto read_err;
...@@ -816,7 +810,7 @@ static ssize_t mtty_read(struct mdev_device *mdev, char __user *buf, ...@@ -816,7 +810,7 @@ static ssize_t mtty_read(struct mdev_device *mdev, char __user *buf,
} else { } else {
u8 val; u8 val;
ret = mdev_access(mdev, (u8 *)&val, sizeof(val), ret = mdev_access(mdev_state, (u8 *)&val, sizeof(val),
*ppos, false); *ppos, false);
if (ret <= 0) if (ret <= 0)
goto read_err; goto read_err;
...@@ -839,9 +833,11 @@ static ssize_t mtty_read(struct mdev_device *mdev, char __user *buf, ...@@ -839,9 +833,11 @@ static ssize_t mtty_read(struct mdev_device *mdev, char __user *buf,
return -EFAULT; return -EFAULT;
} }
static ssize_t mtty_write(struct mdev_device *mdev, const char __user *buf, static ssize_t mtty_write(struct vfio_device *vdev, const char __user *buf,
size_t count, loff_t *ppos) size_t count, loff_t *ppos)
{ {
struct mdev_state *mdev_state =
container_of(vdev, struct mdev_state, vdev);
unsigned int done = 0; unsigned int done = 0;
int ret; int ret;
...@@ -854,7 +850,7 @@ static ssize_t mtty_write(struct mdev_device *mdev, const char __user *buf, ...@@ -854,7 +850,7 @@ static ssize_t mtty_write(struct mdev_device *mdev, const char __user *buf,
if (copy_from_user(&val, buf, sizeof(val))) if (copy_from_user(&val, buf, sizeof(val)))
goto write_err; goto write_err;
ret = mdev_access(mdev, (u8 *)&val, sizeof(val), ret = mdev_access(mdev_state, (u8 *)&val, sizeof(val),
*ppos, true); *ppos, true);
if (ret <= 0) if (ret <= 0)
goto write_err; goto write_err;
...@@ -866,7 +862,7 @@ static ssize_t mtty_write(struct mdev_device *mdev, const char __user *buf, ...@@ -866,7 +862,7 @@ static ssize_t mtty_write(struct mdev_device *mdev, const char __user *buf,
if (copy_from_user(&val, buf, sizeof(val))) if (copy_from_user(&val, buf, sizeof(val)))
goto write_err; goto write_err;
ret = mdev_access(mdev, (u8 *)&val, sizeof(val), ret = mdev_access(mdev_state, (u8 *)&val, sizeof(val),
*ppos, true); *ppos, true);
if (ret <= 0) if (ret <= 0)
goto write_err; goto write_err;
...@@ -878,7 +874,7 @@ static ssize_t mtty_write(struct mdev_device *mdev, const char __user *buf, ...@@ -878,7 +874,7 @@ static ssize_t mtty_write(struct mdev_device *mdev, const char __user *buf,
if (copy_from_user(&val, buf, sizeof(val))) if (copy_from_user(&val, buf, sizeof(val)))
goto write_err; goto write_err;
ret = mdev_access(mdev, (u8 *)&val, sizeof(val), ret = mdev_access(mdev_state, (u8 *)&val, sizeof(val),
*ppos, true); *ppos, true);
if (ret <= 0) if (ret <= 0)
goto write_err; goto write_err;
...@@ -896,19 +892,11 @@ static ssize_t mtty_write(struct mdev_device *mdev, const char __user *buf, ...@@ -896,19 +892,11 @@ static ssize_t mtty_write(struct mdev_device *mdev, const char __user *buf,
return -EFAULT; return -EFAULT;
} }
static int mtty_set_irqs(struct mdev_device *mdev, uint32_t flags, static int mtty_set_irqs(struct mdev_state *mdev_state, uint32_t flags,
unsigned int index, unsigned int start, unsigned int index, unsigned int start,
unsigned int count, void *data) unsigned int count, void *data)
{ {
int ret = 0; int ret = 0;
struct mdev_state *mdev_state;
if (!mdev)
return -EINVAL;
mdev_state = mdev_get_drvdata(mdev);
if (!mdev_state)
return -EINVAL;
mutex_lock(&mdev_state->ops_lock); mutex_lock(&mdev_state->ops_lock);
switch (index) { switch (index) {
...@@ -1024,21 +1012,13 @@ static int mtty_trigger_interrupt(struct mdev_state *mdev_state) ...@@ -1024,21 +1012,13 @@ static int mtty_trigger_interrupt(struct mdev_state *mdev_state)
return ret; return ret;
} }
static int mtty_get_region_info(struct mdev_device *mdev, static int mtty_get_region_info(struct mdev_state *mdev_state,
struct vfio_region_info *region_info, struct vfio_region_info *region_info,
u16 *cap_type_id, void **cap_type) u16 *cap_type_id, void **cap_type)
{ {
unsigned int size = 0; unsigned int size = 0;
struct mdev_state *mdev_state;
u32 bar_index; u32 bar_index;
if (!mdev)
return -EINVAL;
mdev_state = mdev_get_drvdata(mdev);
if (!mdev_state)
return -EINVAL;
bar_index = region_info->index; bar_index = region_info->index;
if (bar_index >= VFIO_PCI_NUM_REGIONS) if (bar_index >= VFIO_PCI_NUM_REGIONS)
return -EINVAL; return -EINVAL;
...@@ -1073,8 +1053,7 @@ static int mtty_get_region_info(struct mdev_device *mdev, ...@@ -1073,8 +1053,7 @@ static int mtty_get_region_info(struct mdev_device *mdev,
return 0; return 0;
} }
static int mtty_get_irq_info(struct mdev_device *mdev, static int mtty_get_irq_info(struct vfio_irq_info *irq_info)
struct vfio_irq_info *irq_info)
{ {
switch (irq_info->index) { switch (irq_info->index) {
case VFIO_PCI_INTX_IRQ_INDEX: case VFIO_PCI_INTX_IRQ_INDEX:
...@@ -1098,8 +1077,7 @@ static int mtty_get_irq_info(struct mdev_device *mdev, ...@@ -1098,8 +1077,7 @@ static int mtty_get_irq_info(struct mdev_device *mdev,
return 0; return 0;
} }
static int mtty_get_device_info(struct mdev_device *mdev, static int mtty_get_device_info(struct vfio_device_info *dev_info)
struct vfio_device_info *dev_info)
{ {
dev_info->flags = VFIO_DEVICE_FLAGS_PCI; dev_info->flags = VFIO_DEVICE_FLAGS_PCI;
dev_info->num_regions = VFIO_PCI_NUM_REGIONS; dev_info->num_regions = VFIO_PCI_NUM_REGIONS;
...@@ -1108,19 +1086,13 @@ static int mtty_get_device_info(struct mdev_device *mdev, ...@@ -1108,19 +1086,13 @@ static int mtty_get_device_info(struct mdev_device *mdev,
return 0; return 0;
} }
static long mtty_ioctl(struct mdev_device *mdev, unsigned int cmd, static long mtty_ioctl(struct vfio_device *vdev, unsigned int cmd,
unsigned long arg) unsigned long arg)
{ {
struct mdev_state *mdev_state =
container_of(vdev, struct mdev_state, vdev);
int ret = 0; int ret = 0;
unsigned long minsz; unsigned long minsz;
struct mdev_state *mdev_state;
if (!mdev)
return -EINVAL;
mdev_state = mdev_get_drvdata(mdev);
if (!mdev_state)
return -ENODEV;
switch (cmd) { switch (cmd) {
case VFIO_DEVICE_GET_INFO: case VFIO_DEVICE_GET_INFO:
...@@ -1135,7 +1107,7 @@ static long mtty_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -1135,7 +1107,7 @@ static long mtty_ioctl(struct mdev_device *mdev, unsigned int cmd,
if (info.argsz < minsz) if (info.argsz < minsz)
return -EINVAL; return -EINVAL;
ret = mtty_get_device_info(mdev, &info); ret = mtty_get_device_info(&info);
if (ret) if (ret)
return ret; return ret;
...@@ -1160,7 +1132,7 @@ static long mtty_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -1160,7 +1132,7 @@ static long mtty_ioctl(struct mdev_device *mdev, unsigned int cmd,
if (info.argsz < minsz) if (info.argsz < minsz)
return -EINVAL; return -EINVAL;
ret = mtty_get_region_info(mdev, &info, &cap_type_id, ret = mtty_get_region_info(mdev_state, &info, &cap_type_id,
&cap_type); &cap_type);
if (ret) if (ret)
return ret; return ret;
...@@ -1184,7 +1156,7 @@ static long mtty_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -1184,7 +1156,7 @@ static long mtty_ioctl(struct mdev_device *mdev, unsigned int cmd,
(info.index >= mdev_state->dev_info.num_irqs)) (info.index >= mdev_state->dev_info.num_irqs))
return -EINVAL; return -EINVAL;
ret = mtty_get_irq_info(mdev, &info); ret = mtty_get_irq_info(&info);
if (ret) if (ret)
return ret; return ret;
...@@ -1218,25 +1190,25 @@ static long mtty_ioctl(struct mdev_device *mdev, unsigned int cmd, ...@@ -1218,25 +1190,25 @@ static long mtty_ioctl(struct mdev_device *mdev, unsigned int cmd,
return PTR_ERR(data); return PTR_ERR(data);
} }
ret = mtty_set_irqs(mdev, hdr.flags, hdr.index, hdr.start, ret = mtty_set_irqs(mdev_state, hdr.flags, hdr.index, hdr.start,
hdr.count, data); hdr.count, data);
kfree(ptr); kfree(ptr);
return ret; return ret;
} }
case VFIO_DEVICE_RESET: case VFIO_DEVICE_RESET:
return mtty_reset(mdev); return mtty_reset(mdev_state);
} }
return -ENOTTY; return -ENOTTY;
} }
static int mtty_open(struct mdev_device *mdev) static int mtty_open(struct vfio_device *vdev)
{ {
pr_info("%s\n", __func__); pr_info("%s\n", __func__);
return 0; return 0;
} }
static void mtty_close(struct mdev_device *mdev) static void mtty_close(struct vfio_device *mdev)
{ {
pr_info("%s\n", __func__); pr_info("%s\n", __func__);
} }
...@@ -1308,14 +1280,9 @@ static ssize_t available_instances_show(struct mdev_type *mtype, ...@@ -1308,14 +1280,9 @@ static ssize_t available_instances_show(struct mdev_type *mtype,
struct mdev_type_attribute *attr, struct mdev_type_attribute *attr,
char *buf) char *buf)
{ {
struct mdev_state *mds;
unsigned int ports = mtype_get_type_group_id(mtype) + 1; unsigned int ports = mtype_get_type_group_id(mtype) + 1;
int used = 0;
list_for_each_entry(mds, &mdev_devices_list, next)
used += mds->nr_ports;
return sprintf(buf, "%d\n", (MAX_MTTYS - used)/ports); return sprintf(buf, "%d\n", atomic_read(&mdev_avail_ports) / ports);
} }
static MDEV_TYPE_ATTR_RO(available_instances); static MDEV_TYPE_ATTR_RO(available_instances);
...@@ -1351,13 +1318,8 @@ static struct attribute_group *mdev_type_groups[] = { ...@@ -1351,13 +1318,8 @@ static struct attribute_group *mdev_type_groups[] = {
NULL, NULL,
}; };
static const struct mdev_parent_ops mdev_fops = { static const struct vfio_device_ops mtty_dev_ops = {
.owner = THIS_MODULE, .name = "vfio-mtty",
.dev_attr_groups = mtty_dev_groups,
.mdev_attr_groups = mdev_dev_groups,
.supported_type_groups = mdev_type_groups,
.create = mtty_create,
.remove = mtty_remove,
.open = mtty_open, .open = mtty_open,
.release = mtty_close, .release = mtty_close,
.read = mtty_read, .read = mtty_read,
...@@ -1365,6 +1327,24 @@ static const struct mdev_parent_ops mdev_fops = { ...@@ -1365,6 +1327,24 @@ static const struct mdev_parent_ops mdev_fops = {
.ioctl = mtty_ioctl, .ioctl = mtty_ioctl,
}; };
static struct mdev_driver mtty_driver = {
.driver = {
.name = "mtty",
.owner = THIS_MODULE,
.mod_name = KBUILD_MODNAME,
.dev_groups = mdev_dev_groups,
},
.probe = mtty_probe,
.remove = mtty_remove,
};
static const struct mdev_parent_ops mdev_fops = {
.owner = THIS_MODULE,
.device_driver = &mtty_driver,
.dev_attr_groups = mtty_dev_groups,
.supported_type_groups = mdev_type_groups,
};
static void mtty_device_release(struct device *dev) static void mtty_device_release(struct device *dev)
{ {
dev_dbg(dev, "mtty: released\n"); dev_dbg(dev, "mtty: released\n");
...@@ -1393,12 +1373,16 @@ static int __init mtty_dev_init(void) ...@@ -1393,12 +1373,16 @@ static int __init mtty_dev_init(void)
pr_info("major_number:%d\n", MAJOR(mtty_dev.vd_devt)); pr_info("major_number:%d\n", MAJOR(mtty_dev.vd_devt));
ret = mdev_register_driver(&mtty_driver);
if (ret)
goto err_cdev;
mtty_dev.vd_class = class_create(THIS_MODULE, MTTY_CLASS_NAME); mtty_dev.vd_class = class_create(THIS_MODULE, MTTY_CLASS_NAME);
if (IS_ERR(mtty_dev.vd_class)) { if (IS_ERR(mtty_dev.vd_class)) {
pr_err("Error: failed to register mtty_dev class\n"); pr_err("Error: failed to register mtty_dev class\n");
ret = PTR_ERR(mtty_dev.vd_class); ret = PTR_ERR(mtty_dev.vd_class);
goto failed1; goto err_driver;
} }
mtty_dev.dev.class = mtty_dev.vd_class; mtty_dev.dev.class = mtty_dev.vd_class;
...@@ -1407,28 +1391,22 @@ static int __init mtty_dev_init(void) ...@@ -1407,28 +1391,22 @@ static int __init mtty_dev_init(void)
ret = device_register(&mtty_dev.dev); ret = device_register(&mtty_dev.dev);
if (ret) if (ret)
goto failed2; goto err_class;
ret = mdev_register_device(&mtty_dev.dev, &mdev_fops); ret = mdev_register_device(&mtty_dev.dev, &mdev_fops);
if (ret) if (ret)
goto failed3; goto err_device;
return 0;
mutex_init(&mdev_list_lock);
INIT_LIST_HEAD(&mdev_devices_list);
goto all_done;
failed3:
err_device:
device_unregister(&mtty_dev.dev); device_unregister(&mtty_dev.dev);
failed2: err_class:
class_destroy(mtty_dev.vd_class); class_destroy(mtty_dev.vd_class);
err_driver:
failed1: mdev_unregister_driver(&mtty_driver);
err_cdev:
cdev_del(&mtty_dev.vd_cdev); cdev_del(&mtty_dev.vd_cdev);
unregister_chrdev_region(mtty_dev.vd_devt, MINORMASK + 1); unregister_chrdev_region(mtty_dev.vd_devt, MINORMASK + 1);
all_done:
return ret; return ret;
} }
...@@ -1439,6 +1417,7 @@ static void __exit mtty_dev_exit(void) ...@@ -1439,6 +1417,7 @@ static void __exit mtty_dev_exit(void)
device_unregister(&mtty_dev.dev); device_unregister(&mtty_dev.dev);
idr_destroy(&mtty_dev.vd_idr); idr_destroy(&mtty_dev.vd_idr);
mdev_unregister_driver(&mtty_driver);
cdev_del(&mtty_dev.vd_cdev); cdev_del(&mtty_dev.vd_cdev);
unregister_chrdev_region(mtty_dev.vd_devt, MINORMASK + 1); unregister_chrdev_region(mtty_dev.vd_devt, MINORMASK + 1);
class_destroy(mtty_dev.vd_class); class_destroy(mtty_dev.vd_class);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment