Commit a659f91a authored by David S. Miller's avatar David S. Miller

Merge branch 'crypto_async'

Tadeusz Struk says:

====================
Add support for async socket operations

After the iocb parameter has been removed from sendmsg() and recvmsg() ops
the socket layer, and the network stack no longer support async operations.
This patch set adds support for asynchronous operations on sockets back.

Changes in v3:
* As sugested by Al Viro instead of adding new functions aio_sendmsg
  and aio_recvmsg, added a ptr to iocb into the kernel-side msghdr structure.
  This way no change to aio.c is required.

Changes in v2:
* removed redundant total_size param from aio_sendmsg and aio_recvmsg functions
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 8f2ddaac a596999b
...@@ -358,8 +358,8 @@ int af_alg_make_sg(struct af_alg_sgl *sgl, struct iov_iter *iter, int len) ...@@ -358,8 +358,8 @@ int af_alg_make_sg(struct af_alg_sgl *sgl, struct iov_iter *iter, int len)
npages = (off + n + PAGE_SIZE - 1) >> PAGE_SHIFT; npages = (off + n + PAGE_SIZE - 1) >> PAGE_SHIFT;
if (WARN_ON(npages == 0)) if (WARN_ON(npages == 0))
return -EINVAL; return -EINVAL;
/* Add one extra for linking */
sg_init_table(sgl->sg, npages); sg_init_table(sgl->sg, npages + 1);
for (i = 0, len = n; i < npages; i++) { for (i = 0, len = n; i < npages; i++) {
int plen = min_t(int, len, PAGE_SIZE - off); int plen = min_t(int, len, PAGE_SIZE - off);
...@@ -369,18 +369,26 @@ int af_alg_make_sg(struct af_alg_sgl *sgl, struct iov_iter *iter, int len) ...@@ -369,18 +369,26 @@ int af_alg_make_sg(struct af_alg_sgl *sgl, struct iov_iter *iter, int len)
off = 0; off = 0;
len -= plen; len -= plen;
} }
sg_mark_end(sgl->sg + npages - 1);
sgl->npages = npages;
return n; return n;
} }
EXPORT_SYMBOL_GPL(af_alg_make_sg); EXPORT_SYMBOL_GPL(af_alg_make_sg);
void af_alg_link_sg(struct af_alg_sgl *sgl_prev, struct af_alg_sgl *sgl_new)
{
sg_unmark_end(sgl_prev->sg + sgl_prev->npages - 1);
sg_chain(sgl_prev->sg, sgl_prev->npages + 1, sgl_new->sg);
}
EXPORT_SYMBOL(af_alg_link_sg);
void af_alg_free_sg(struct af_alg_sgl *sgl) void af_alg_free_sg(struct af_alg_sgl *sgl)
{ {
int i; int i;
i = 0; for (i = 0; i < sgl->npages; i++)
do {
put_page(sgl->pages[i]); put_page(sgl->pages[i]);
} while (!sg_is_last(sgl->sg + (i++)));
} }
EXPORT_SYMBOL_GPL(af_alg_free_sg); EXPORT_SYMBOL_GPL(af_alg_free_sg);
......
...@@ -39,6 +39,7 @@ struct skcipher_ctx { ...@@ -39,6 +39,7 @@ struct skcipher_ctx {
struct af_alg_completion completion; struct af_alg_completion completion;
atomic_t inflight;
unsigned used; unsigned used;
unsigned int len; unsigned int len;
...@@ -49,9 +50,65 @@ struct skcipher_ctx { ...@@ -49,9 +50,65 @@ struct skcipher_ctx {
struct ablkcipher_request req; struct ablkcipher_request req;
}; };
struct skcipher_async_rsgl {
struct af_alg_sgl sgl;
struct list_head list;
};
struct skcipher_async_req {
struct kiocb *iocb;
struct skcipher_async_rsgl first_sgl;
struct list_head list;
struct scatterlist *tsg;
char iv[];
};
#define GET_SREQ(areq, ctx) (struct skcipher_async_req *)((char *)areq + \
crypto_ablkcipher_reqsize(crypto_ablkcipher_reqtfm(&ctx->req)))
#define GET_REQ_SIZE(ctx) \
crypto_ablkcipher_reqsize(crypto_ablkcipher_reqtfm(&ctx->req))
#define GET_IV_SIZE(ctx) \
crypto_ablkcipher_ivsize(crypto_ablkcipher_reqtfm(&ctx->req))
#define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \ #define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \
sizeof(struct scatterlist) - 1) sizeof(struct scatterlist) - 1)
static void skcipher_free_async_sgls(struct skcipher_async_req *sreq)
{
struct skcipher_async_rsgl *rsgl, *tmp;
struct scatterlist *sgl;
struct scatterlist *sg;
int i, n;
list_for_each_entry_safe(rsgl, tmp, &sreq->list, list) {
af_alg_free_sg(&rsgl->sgl);
if (rsgl != &sreq->first_sgl)
kfree(rsgl);
}
sgl = sreq->tsg;
n = sg_nents(sgl);
for_each_sg(sgl, sg, n, i)
put_page(sg_page(sg));
kfree(sreq->tsg);
}
static void skcipher_async_cb(struct crypto_async_request *req, int err)
{
struct sock *sk = req->data;
struct alg_sock *ask = alg_sk(sk);
struct skcipher_ctx *ctx = ask->private;
struct skcipher_async_req *sreq = GET_SREQ(req, ctx);
struct kiocb *iocb = sreq->iocb;
atomic_dec(&ctx->inflight);
skcipher_free_async_sgls(sreq);
kfree(req);
aio_complete(iocb, err, err);
}
static inline int skcipher_sndbuf(struct sock *sk) static inline int skcipher_sndbuf(struct sock *sk)
{ {
struct alg_sock *ask = alg_sk(sk); struct alg_sock *ask = alg_sk(sk);
...@@ -96,7 +153,7 @@ static int skcipher_alloc_sgl(struct sock *sk) ...@@ -96,7 +153,7 @@ static int skcipher_alloc_sgl(struct sock *sk)
return 0; return 0;
} }
static void skcipher_pull_sgl(struct sock *sk, int used) static void skcipher_pull_sgl(struct sock *sk, int used, int put)
{ {
struct alg_sock *ask = alg_sk(sk); struct alg_sock *ask = alg_sk(sk);
struct skcipher_ctx *ctx = ask->private; struct skcipher_ctx *ctx = ask->private;
...@@ -123,7 +180,7 @@ static void skcipher_pull_sgl(struct sock *sk, int used) ...@@ -123,7 +180,7 @@ static void skcipher_pull_sgl(struct sock *sk, int used)
if (sg[i].length) if (sg[i].length)
return; return;
if (put)
put_page(sg_page(sg + i)); put_page(sg_page(sg + i));
sg_assign_page(sg + i, NULL); sg_assign_page(sg + i, NULL);
} }
...@@ -143,7 +200,7 @@ static void skcipher_free_sgl(struct sock *sk) ...@@ -143,7 +200,7 @@ static void skcipher_free_sgl(struct sock *sk)
struct alg_sock *ask = alg_sk(sk); struct alg_sock *ask = alg_sk(sk);
struct skcipher_ctx *ctx = ask->private; struct skcipher_ctx *ctx = ask->private;
skcipher_pull_sgl(sk, ctx->used); skcipher_pull_sgl(sk, ctx->used, 1);
} }
static int skcipher_wait_for_wmem(struct sock *sk, unsigned flags) static int skcipher_wait_for_wmem(struct sock *sk, unsigned flags)
...@@ -424,8 +481,149 @@ static ssize_t skcipher_sendpage(struct socket *sock, struct page *page, ...@@ -424,8 +481,149 @@ static ssize_t skcipher_sendpage(struct socket *sock, struct page *page,
return err ?: size; return err ?: size;
} }
static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg, static int skcipher_all_sg_nents(struct skcipher_ctx *ctx)
size_t ignored, int flags) {
struct skcipher_sg_list *sgl;
struct scatterlist *sg;
int nents = 0;
list_for_each_entry(sgl, &ctx->tsgl, list) {
sg = sgl->sg;
while (!sg->length)
sg++;
nents += sg_nents(sg);
}
return nents;
}
static int skcipher_recvmsg_async(struct socket *sock, struct msghdr *msg,
int flags)
{
struct sock *sk = sock->sk;
struct alg_sock *ask = alg_sk(sk);
struct skcipher_ctx *ctx = ask->private;
struct skcipher_sg_list *sgl;
struct scatterlist *sg;
struct skcipher_async_req *sreq;
struct ablkcipher_request *req;
struct skcipher_async_rsgl *last_rsgl = NULL;
unsigned int len = 0, tx_nents = skcipher_all_sg_nents(ctx);
unsigned int reqlen = sizeof(struct skcipher_async_req) +
GET_REQ_SIZE(ctx) + GET_IV_SIZE(ctx);
int i = 0;
int err = -ENOMEM;
lock_sock(sk);
req = kmalloc(reqlen, GFP_KERNEL);
if (unlikely(!req))
goto unlock;
sreq = GET_SREQ(req, ctx);
sreq->iocb = msg->msg_iocb;
memset(&sreq->first_sgl, '\0', sizeof(struct skcipher_async_rsgl));
INIT_LIST_HEAD(&sreq->list);
sreq->tsg = kcalloc(tx_nents, sizeof(*sg), GFP_KERNEL);
if (unlikely(!sreq->tsg)) {
kfree(req);
goto unlock;
}
sg_init_table(sreq->tsg, tx_nents);
memcpy(sreq->iv, ctx->iv, GET_IV_SIZE(ctx));
ablkcipher_request_set_tfm(req, crypto_ablkcipher_reqtfm(&ctx->req));
ablkcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
skcipher_async_cb, sk);
while (iov_iter_count(&msg->msg_iter)) {
struct skcipher_async_rsgl *rsgl;
unsigned long used;
if (!ctx->used) {
err = skcipher_wait_for_data(sk, flags);
if (err)
goto free;
}
sgl = list_first_entry(&ctx->tsgl,
struct skcipher_sg_list, list);
sg = sgl->sg;
while (!sg->length)
sg++;
used = min_t(unsigned long, ctx->used,
iov_iter_count(&msg->msg_iter));
used = min_t(unsigned long, used, sg->length);
if (i == tx_nents) {
struct scatterlist *tmp;
int x;
/* Ran out of tx slots in async request
* need to expand */
tmp = kcalloc(tx_nents * 2, sizeof(*tmp),
GFP_KERNEL);
if (!tmp)
goto free;
sg_init_table(tmp, tx_nents * 2);
for (x = 0; x < tx_nents; x++)
sg_set_page(&tmp[x], sg_page(&sreq->tsg[x]),
sreq->tsg[x].length,
sreq->tsg[x].offset);
kfree(sreq->tsg);
sreq->tsg = tmp;
tx_nents *= 2;
}
/* Need to take over the tx sgl from ctx
* to the asynch req - these sgls will be freed later */
sg_set_page(sreq->tsg + i++, sg_page(sg), sg->length,
sg->offset);
if (list_empty(&sreq->list)) {
rsgl = &sreq->first_sgl;
list_add_tail(&rsgl->list, &sreq->list);
} else {
rsgl = kzalloc(sizeof(*rsgl), GFP_KERNEL);
if (!rsgl) {
err = -ENOMEM;
goto free;
}
list_add_tail(&rsgl->list, &sreq->list);
}
used = af_alg_make_sg(&rsgl->sgl, &msg->msg_iter, used);
err = used;
if (used < 0)
goto free;
if (last_rsgl)
af_alg_link_sg(&last_rsgl->sgl, &rsgl->sgl);
last_rsgl = rsgl;
len += used;
skcipher_pull_sgl(sk, used, 0);
iov_iter_advance(&msg->msg_iter, used);
}
ablkcipher_request_set_crypt(req, sreq->tsg, sreq->first_sgl.sgl.sg,
len, sreq->iv);
err = ctx->enc ? crypto_ablkcipher_encrypt(req) :
crypto_ablkcipher_decrypt(req);
if (err == -EINPROGRESS) {
atomic_inc(&ctx->inflight);
err = -EIOCBQUEUED;
goto unlock;
}
free:
skcipher_free_async_sgls(sreq);
kfree(req);
unlock:
skcipher_wmem_wakeup(sk);
release_sock(sk);
return err;
}
static int skcipher_recvmsg_sync(struct socket *sock, struct msghdr *msg,
int flags)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct alg_sock *ask = alg_sk(sk); struct alg_sock *ask = alg_sk(sk);
...@@ -484,7 +682,7 @@ static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg, ...@@ -484,7 +682,7 @@ static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg,
goto unlock; goto unlock;
copied += used; copied += used;
skcipher_pull_sgl(sk, used); skcipher_pull_sgl(sk, used, 1);
iov_iter_advance(&msg->msg_iter, used); iov_iter_advance(&msg->msg_iter, used);
} }
...@@ -497,6 +695,13 @@ static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg, ...@@ -497,6 +695,13 @@ static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg,
return copied ?: err; return copied ?: err;
} }
static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg,
size_t ignored, int flags)
{
return (msg->msg_iocb && !is_sync_kiocb(msg->msg_iocb)) ?
skcipher_recvmsg_async(sock, msg, flags) :
skcipher_recvmsg_sync(sock, msg, flags);
}
static unsigned int skcipher_poll(struct file *file, struct socket *sock, static unsigned int skcipher_poll(struct file *file, struct socket *sock,
poll_table *wait) poll_table *wait)
...@@ -555,12 +760,25 @@ static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen) ...@@ -555,12 +760,25 @@ static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen)
return crypto_ablkcipher_setkey(private, key, keylen); return crypto_ablkcipher_setkey(private, key, keylen);
} }
static void skcipher_wait(struct sock *sk)
{
struct alg_sock *ask = alg_sk(sk);
struct skcipher_ctx *ctx = ask->private;
int ctr = 0;
while (atomic_read(&ctx->inflight) && ctr++ < 100)
msleep(100);
}
static void skcipher_sock_destruct(struct sock *sk) static void skcipher_sock_destruct(struct sock *sk)
{ {
struct alg_sock *ask = alg_sk(sk); struct alg_sock *ask = alg_sk(sk);
struct skcipher_ctx *ctx = ask->private; struct skcipher_ctx *ctx = ask->private;
struct crypto_ablkcipher *tfm = crypto_ablkcipher_reqtfm(&ctx->req); struct crypto_ablkcipher *tfm = crypto_ablkcipher_reqtfm(&ctx->req);
if (atomic_read(&ctx->inflight))
skcipher_wait(sk);
skcipher_free_sgl(sk); skcipher_free_sgl(sk);
sock_kzfree_s(sk, ctx->iv, crypto_ablkcipher_ivsize(tfm)); sock_kzfree_s(sk, ctx->iv, crypto_ablkcipher_ivsize(tfm));
sock_kfree_s(sk, ctx, ctx->len); sock_kfree_s(sk, ctx, ctx->len);
...@@ -592,6 +810,7 @@ static int skcipher_accept_parent(void *private, struct sock *sk) ...@@ -592,6 +810,7 @@ static int skcipher_accept_parent(void *private, struct sock *sk)
ctx->more = 0; ctx->more = 0;
ctx->merge = 0; ctx->merge = 0;
ctx->enc = 0; ctx->enc = 0;
atomic_set(&ctx->inflight, 0);
af_alg_init_completion(&ctx->completion); af_alg_init_completion(&ctx->completion);
ask->private = ctx; ask->private = ctx;
......
...@@ -58,8 +58,9 @@ struct af_alg_type { ...@@ -58,8 +58,9 @@ struct af_alg_type {
}; };
struct af_alg_sgl { struct af_alg_sgl {
struct scatterlist sg[ALG_MAX_PAGES]; struct scatterlist sg[ALG_MAX_PAGES + 1];
struct page *pages[ALG_MAX_PAGES]; struct page *pages[ALG_MAX_PAGES];
unsigned int npages;
}; };
int af_alg_register_type(const struct af_alg_type *type); int af_alg_register_type(const struct af_alg_type *type);
...@@ -70,6 +71,7 @@ int af_alg_accept(struct sock *sk, struct socket *newsock); ...@@ -70,6 +71,7 @@ int af_alg_accept(struct sock *sk, struct socket *newsock);
int af_alg_make_sg(struct af_alg_sgl *sgl, struct iov_iter *iter, int len); int af_alg_make_sg(struct af_alg_sgl *sgl, struct iov_iter *iter, int len);
void af_alg_free_sg(struct af_alg_sgl *sgl); void af_alg_free_sg(struct af_alg_sgl *sgl);
void af_alg_link_sg(struct af_alg_sgl *sgl_prev, struct af_alg_sgl *sgl_new);
int af_alg_cmsg_send(struct msghdr *msg, struct af_alg_control *con); int af_alg_cmsg_send(struct msghdr *msg, struct af_alg_control *con);
......
...@@ -51,6 +51,7 @@ struct msghdr { ...@@ -51,6 +51,7 @@ struct msghdr {
void *msg_control; /* ancillary data */ void *msg_control; /* ancillary data */
__kernel_size_t msg_controllen; /* ancillary data buffer length */ __kernel_size_t msg_controllen; /* ancillary data buffer length */
unsigned int msg_flags; /* flags on received message */ unsigned int msg_flags; /* flags on received message */
struct kiocb *msg_iocb; /* ptr to iocb for async requests */
}; };
struct user_msghdr { struct user_msghdr {
......
...@@ -79,6 +79,8 @@ ssize_t get_compat_msghdr(struct msghdr *kmsg, ...@@ -79,6 +79,8 @@ ssize_t get_compat_msghdr(struct msghdr *kmsg,
if (nr_segs > UIO_MAXIOV) if (nr_segs > UIO_MAXIOV)
return -EMSGSIZE; return -EMSGSIZE;
kmsg->msg_iocb = NULL;
err = compat_rw_copy_check_uvector(save_addr ? READ : WRITE, err = compat_rw_copy_check_uvector(save_addr ? READ : WRITE,
compat_ptr(uiov), nr_segs, compat_ptr(uiov), nr_segs,
UIO_FASTIOV, *iov, iov); UIO_FASTIOV, *iov, iov);
......
...@@ -798,7 +798,8 @@ static ssize_t sock_read_iter(struct kiocb *iocb, struct iov_iter *to) ...@@ -798,7 +798,8 @@ static ssize_t sock_read_iter(struct kiocb *iocb, struct iov_iter *to)
{ {
struct file *file = iocb->ki_filp; struct file *file = iocb->ki_filp;
struct socket *sock = file->private_data; struct socket *sock = file->private_data;
struct msghdr msg = {.msg_iter = *to}; struct msghdr msg = {.msg_iter = *to,
.msg_iocb = iocb};
ssize_t res; ssize_t res;
if (file->f_flags & O_NONBLOCK) if (file->f_flags & O_NONBLOCK)
...@@ -819,7 +820,8 @@ static ssize_t sock_write_iter(struct kiocb *iocb, struct iov_iter *from) ...@@ -819,7 +820,8 @@ static ssize_t sock_write_iter(struct kiocb *iocb, struct iov_iter *from)
{ {
struct file *file = iocb->ki_filp; struct file *file = iocb->ki_filp;
struct socket *sock = file->private_data; struct socket *sock = file->private_data;
struct msghdr msg = {.msg_iter = *from}; struct msghdr msg = {.msg_iter = *from,
.msg_iocb = iocb};
ssize_t res; ssize_t res;
if (iocb->ki_pos != 0) if (iocb->ki_pos != 0)
...@@ -1894,6 +1896,8 @@ static ssize_t copy_msghdr_from_user(struct msghdr *kmsg, ...@@ -1894,6 +1896,8 @@ static ssize_t copy_msghdr_from_user(struct msghdr *kmsg,
if (nr_segs > UIO_MAXIOV) if (nr_segs > UIO_MAXIOV)
return -EMSGSIZE; return -EMSGSIZE;
kmsg->msg_iocb = NULL;
err = rw_copy_check_uvector(save_addr ? READ : WRITE, err = rw_copy_check_uvector(save_addr ? READ : WRITE,
uiov, nr_segs, uiov, nr_segs,
UIO_FASTIOV, *iov, iov); UIO_FASTIOV, *iov, iov);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment