Commit 3c27a36f authored by Marcelo Diop-Gonzalez's avatar Marcelo Diop-Gonzalez Committed by Greg Kroah-Hartman

staging: vc04_services: use kref + RCU to reference count services

Currently reference counts are implemented by locking service_spinlock
and then incrementing the service's ->ref_count field, calling
kfree() when the last reference has been dropped. But at the same
time, there's code in multiple places that dereferences pointers
to services without having a reference, so there could be a race there.

It should be possible to avoid taking any lock in unlock_service()
or service_release() because we are setting a single array element
to NULL, and on service creation, a mutex is locked before looking
for a NULL spot to put the new service in.

Using a struct kref and RCU-delaying the freeing of services fixes
this race condition while still making it possible to skip
grabbing a reference in many places. Also it avoids the need to
acquire a single spinlock when e.g. taking a reference on
state->services[i] when somebody else is in the middle of taking
a reference on state->services[j].
Signed-off-by: default avatarMarcelo Diop-Gonzalez <marcgonzalez@google.com>
Link: https://lore.kernel.org/r/3bf6f1ec6ace64d7072025505e165b8dd18b25ca.1581532523.git.marcgonzalez@google.comSigned-off-by: default avatarGreg Kroah-Hartman <gregkh@linuxfoundation.org>
parent 0e35fa61
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <linux/platform_device.h> #include <linux/platform_device.h>
#include <linux/compat.h> #include <linux/compat.h>
#include <linux/dma-mapping.h> #include <linux/dma-mapping.h>
#include <linux/rcupdate.h>
#include <soc/bcm2835/raspberrypi-firmware.h> #include <soc/bcm2835/raspberrypi-firmware.h>
#include "vchiq_core.h" #include "vchiq_core.h"
...@@ -2096,10 +2097,12 @@ int vchiq_dump_platform_instances(void *dump_context) ...@@ -2096,10 +2097,12 @@ int vchiq_dump_platform_instances(void *dump_context)
/* There is no list of instances, so instead scan all services, /* There is no list of instances, so instead scan all services,
marking those that have been dumped. */ marking those that have been dumped. */
rcu_read_lock();
for (i = 0; i < state->unused_service; i++) { for (i = 0; i < state->unused_service; i++) {
struct vchiq_service *service = state->services[i]; struct vchiq_service *service;
struct vchiq_instance *instance; struct vchiq_instance *instance;
service = rcu_dereference(state->services[i]);
if (!service || service->base.callback != service_callback) if (!service || service->base.callback != service_callback)
continue; continue;
...@@ -2107,18 +2110,26 @@ int vchiq_dump_platform_instances(void *dump_context) ...@@ -2107,18 +2110,26 @@ int vchiq_dump_platform_instances(void *dump_context)
if (instance) if (instance)
instance->mark = 0; instance->mark = 0;
} }
rcu_read_unlock();
for (i = 0; i < state->unused_service; i++) { for (i = 0; i < state->unused_service; i++) {
struct vchiq_service *service = state->services[i]; struct vchiq_service *service;
struct vchiq_instance *instance; struct vchiq_instance *instance;
int err; int err;
if (!service || service->base.callback != service_callback) rcu_read_lock();
service = rcu_dereference(state->services[i]);
if (!service || service->base.callback != service_callback) {
rcu_read_unlock();
continue; continue;
}
instance = service->instance; instance = service->instance;
if (!instance || instance->mark) if (!instance || instance->mark) {
rcu_read_unlock();
continue; continue;
}
rcu_read_unlock();
len = snprintf(buf, sizeof(buf), len = snprintf(buf, sizeof(buf),
"Instance %pK: pid %d,%s completions %d/%d", "Instance %pK: pid %d,%s completions %d/%d",
...@@ -2128,7 +2139,6 @@ int vchiq_dump_platform_instances(void *dump_context) ...@@ -2128,7 +2139,6 @@ int vchiq_dump_platform_instances(void *dump_context)
instance->completion_insert - instance->completion_insert -
instance->completion_remove, instance->completion_remove,
MAX_COMPLETIONS); MAX_COMPLETIONS);
err = vchiq_dump(dump_context, buf, len + 1); err = vchiq_dump(dump_context, buf, len + 1);
if (err) if (err)
return err; return err;
...@@ -2585,8 +2595,10 @@ vchiq_dump_service_use_state(struct vchiq_state *state) ...@@ -2585,8 +2595,10 @@ vchiq_dump_service_use_state(struct vchiq_state *state)
if (active_services > MAX_SERVICES) if (active_services > MAX_SERVICES)
only_nonzero = 1; only_nonzero = 1;
rcu_read_lock();
for (i = 0; i < active_services; i++) { for (i = 0; i < active_services; i++) {
struct vchiq_service *service_ptr = state->services[i]; struct vchiq_service *service_ptr =
rcu_dereference(state->services[i]);
if (!service_ptr) if (!service_ptr)
continue; continue;
...@@ -2604,6 +2616,7 @@ vchiq_dump_service_use_state(struct vchiq_state *state) ...@@ -2604,6 +2616,7 @@ vchiq_dump_service_use_state(struct vchiq_state *state)
if (found >= MAX_SERVICES) if (found >= MAX_SERVICES)
break; break;
} }
rcu_read_unlock();
read_unlock_bh(&arm_state->susp_res_lock); read_unlock_bh(&arm_state->susp_res_lock);
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#include <linux/mutex.h> #include <linux/mutex.h>
#include <linux/completion.h> #include <linux/completion.h>
#include <linux/kthread.h> #include <linux/kthread.h>
#include <linux/kref.h>
#include <linux/rcupdate.h>
#include <linux/wait.h> #include <linux/wait.h>
#include "vchiq_cfg.h" #include "vchiq_cfg.h"
...@@ -251,7 +253,8 @@ struct vchiq_slot_info { ...@@ -251,7 +253,8 @@ struct vchiq_slot_info {
struct vchiq_service { struct vchiq_service {
struct vchiq_service_base base; struct vchiq_service_base base;
unsigned int handle; unsigned int handle;
unsigned int ref_count; struct kref ref_count;
struct rcu_head rcu;
int srvstate; int srvstate;
vchiq_userdata_term userdata_term; vchiq_userdata_term userdata_term;
unsigned int localport; unsigned int localport;
...@@ -464,7 +467,7 @@ struct vchiq_state { ...@@ -464,7 +467,7 @@ struct vchiq_state {
int error_count; int error_count;
} stats; } stats;
struct vchiq_service *services[VCHIQ_MAX_SERVICES]; struct vchiq_service __rcu *services[VCHIQ_MAX_SERVICES];
struct vchiq_service_quota service_quotas[VCHIQ_MAX_SERVICES]; struct vchiq_service_quota service_quotas[VCHIQ_MAX_SERVICES];
struct vchiq_slot_info slot_info[VCHIQ_MAX_SLOTS]; struct vchiq_slot_info slot_info[VCHIQ_MAX_SLOTS];
...@@ -545,12 +548,13 @@ request_poll(struct vchiq_state *state, struct vchiq_service *service, ...@@ -545,12 +548,13 @@ request_poll(struct vchiq_state *state, struct vchiq_service *service,
static inline struct vchiq_service * static inline struct vchiq_service *
handle_to_service(unsigned int handle) handle_to_service(unsigned int handle)
{ {
int idx = handle & (VCHIQ_MAX_SERVICES - 1);
struct vchiq_state *state = vchiq_states[(handle / VCHIQ_MAX_SERVICES) & struct vchiq_state *state = vchiq_states[(handle / VCHIQ_MAX_SERVICES) &
(VCHIQ_MAX_STATES - 1)]; (VCHIQ_MAX_STATES - 1)];
if (!state) if (!state)
return NULL; return NULL;
return rcu_dereference(state->services[idx]);
return state->services[handle & (VCHIQ_MAX_SERVICES - 1)];
} }
extern struct vchiq_service * extern struct vchiq_service *
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment