Commit 72904d7b authored by David Howells's avatar David Howells

rxrpc, afs: Allow afs to pin rxrpc_peer objects

Change rxrpc's API such that:

 (1) A new function, rxrpc_kernel_lookup_peer(), is provided to look up an
     rxrpc_peer record for a remote address and a corresponding function,
     rxrpc_kernel_put_peer(), is provided to dispose of it again.

 (2) When setting up a call, the rxrpc_peer object used during a call is
     now passed in rather than being set up by rxrpc_connect_call().  For
     afs, this meenat passing it to rxrpc_kernel_begin_call() rather than
     the full address (the service ID then has to be passed in as a
     separate parameter).

 (3) A new function, rxrpc_kernel_remote_addr(), is added so that afs can
     get a pointer to the transport address for display purposed, and
     another, rxrpc_kernel_remote_srx(), to gain a pointer to the full
     rxrpc address.

 (4) The function to retrieve the RTT from a call, rxrpc_kernel_get_srtt(),
     is then altered to take a peer.  This now returns the RTT or -1 if
     there are insufficient samples.

 (5) Rename rxrpc_kernel_get_peer() to rxrpc_kernel_call_get_peer().

 (6) Provide a new function, rxrpc_kernel_get_peer(), to get a ref on a
     peer the caller already has.

This allows the afs filesystem to pin the rxrpc_peer records that it is
using, allowing faster lookups and pointer comparisons rather than
comparing sockaddr_rxrpc contents.  It also makes it easier to get hold of
the RTT.  The following changes are made to afs:

 (1) The addr_list struct's addrs[] elements now hold a peer struct pointer
     and a service ID rather than a sockaddr_rxrpc.

 (2) When displaying the transport address, rxrpc_kernel_remote_addr() is
     used.

 (3) The port arg is removed from afs_alloc_addrlist() since it's always
     overridden.

 (4) afs_merge_fs_addr4() and afs_merge_fs_addr6() do peer lookup and may
     now return an error that must be handled.

 (5) afs_find_server() now takes a peer pointer to specify the address.

 (6) afs_find_server(), afs_compare_fs_alists() and afs_merge_fs_addr[46]{}
     now do peer pointer comparison rather than address comparison.
Signed-off-by: default avatarDavid Howells <dhowells@redhat.com>
cc: Marc Dionne <marc.dionne@auristor.com>
cc: linux-afs@lists.infradead.org
parent 07f3502b
...@@ -13,26 +13,33 @@ ...@@ -13,26 +13,33 @@
#include "internal.h" #include "internal.h"
#include "afs_fs.h" #include "afs_fs.h"
static void afs_free_addrlist(struct rcu_head *rcu)
{
struct afs_addr_list *alist = container_of(rcu, struct afs_addr_list, rcu);
unsigned int i;
for (i = 0; i < alist->nr_addrs; i++)
rxrpc_kernel_put_peer(alist->addrs[i].peer);
}
/* /*
* Release an address list. * Release an address list.
*/ */
void afs_put_addrlist(struct afs_addr_list *alist) void afs_put_addrlist(struct afs_addr_list *alist)
{ {
if (alist && refcount_dec_and_test(&alist->usage)) if (alist && refcount_dec_and_test(&alist->usage))
kfree_rcu(alist, rcu); call_rcu(&alist->rcu, afs_free_addrlist);
} }
/* /*
* Allocate an address list. * Allocate an address list.
*/ */
struct afs_addr_list *afs_alloc_addrlist(unsigned int nr, struct afs_addr_list *afs_alloc_addrlist(unsigned int nr, u16 service_id)
unsigned short service,
unsigned short port)
{ {
struct afs_addr_list *alist; struct afs_addr_list *alist;
unsigned int i; unsigned int i;
_enter("%u,%u,%u", nr, service, port); _enter("%u,%u", nr, service_id);
if (nr > AFS_MAX_ADDRESSES) if (nr > AFS_MAX_ADDRESSES)
nr = AFS_MAX_ADDRESSES; nr = AFS_MAX_ADDRESSES;
...@@ -44,16 +51,8 @@ struct afs_addr_list *afs_alloc_addrlist(unsigned int nr, ...@@ -44,16 +51,8 @@ struct afs_addr_list *afs_alloc_addrlist(unsigned int nr,
refcount_set(&alist->usage, 1); refcount_set(&alist->usage, 1);
alist->max_addrs = nr; alist->max_addrs = nr;
for (i = 0; i < nr; i++) { for (i = 0; i < nr; i++)
struct sockaddr_rxrpc *srx = &alist->addrs[i].srx; alist->addrs[i].service_id = service_id;
srx->srx_family = AF_RXRPC;
srx->srx_service = service;
srx->transport_type = SOCK_DGRAM;
srx->transport_len = sizeof(srx->transport.sin6);
srx->transport.sin6.sin6_family = AF_INET6;
srx->transport.sin6.sin6_port = htons(port);
}
return alist; return alist;
} }
...@@ -126,7 +125,7 @@ struct afs_vlserver_list *afs_parse_text_addrs(struct afs_net *net, ...@@ -126,7 +125,7 @@ struct afs_vlserver_list *afs_parse_text_addrs(struct afs_net *net,
if (!vllist->servers[0].server) if (!vllist->servers[0].server)
goto error_vl; goto error_vl;
alist = afs_alloc_addrlist(nr, service, AFS_VL_PORT); alist = afs_alloc_addrlist(nr, service);
if (!alist) if (!alist)
goto error; goto error;
...@@ -197,9 +196,11 @@ struct afs_vlserver_list *afs_parse_text_addrs(struct afs_net *net, ...@@ -197,9 +196,11 @@ struct afs_vlserver_list *afs_parse_text_addrs(struct afs_net *net,
} }
if (family == AF_INET) if (family == AF_INET)
afs_merge_fs_addr4(alist, x[0], xport); ret = afs_merge_fs_addr4(net, alist, x[0], xport);
else else
afs_merge_fs_addr6(alist, x, xport); ret = afs_merge_fs_addr6(net, alist, x, xport);
if (ret < 0)
goto error;
} while (p < end); } while (p < end);
...@@ -271,25 +272,33 @@ struct afs_vlserver_list *afs_dns_query(struct afs_cell *cell, time64_t *_expiry ...@@ -271,25 +272,33 @@ struct afs_vlserver_list *afs_dns_query(struct afs_cell *cell, time64_t *_expiry
/* /*
* Merge an IPv4 entry into a fileserver address list. * Merge an IPv4 entry into a fileserver address list.
*/ */
void afs_merge_fs_addr4(struct afs_addr_list *alist, __be32 xdr, u16 port) int afs_merge_fs_addr4(struct afs_net *net, struct afs_addr_list *alist,
__be32 xdr, u16 port)
{ {
struct sockaddr_rxrpc *srx; struct sockaddr_rxrpc srx;
u32 addr = ntohl(xdr); struct rxrpc_peer *peer;
int i; int i;
if (alist->nr_addrs >= alist->max_addrs) if (alist->nr_addrs >= alist->max_addrs)
return; return 0;
for (i = 0; i < alist->nr_ipv4; i++) { srx.srx_family = AF_RXRPC;
struct sockaddr_in *a = &alist->addrs[i].srx.transport.sin; srx.transport_type = SOCK_DGRAM;
u32 a_addr = ntohl(a->sin_addr.s_addr); srx.transport_len = sizeof(srx.transport.sin);
u16 a_port = ntohs(a->sin_port); srx.transport.sin.sin_family = AF_INET;
srx.transport.sin.sin_port = htons(port);
srx.transport.sin.sin_addr.s_addr = xdr;
if (addr == a_addr && port == a_port) peer = rxrpc_kernel_lookup_peer(net->socket, &srx, GFP_KERNEL);
return; if (!peer)
if (addr == a_addr && port < a_port) return -ENOMEM;
break;
if (addr < a_addr) for (i = 0; i < alist->nr_ipv4; i++) {
if (peer == alist->addrs[i].peer) {
rxrpc_kernel_put_peer(peer);
return 0;
}
if (peer <= alist->addrs[i].peer)
break; break;
} }
...@@ -298,38 +307,42 @@ void afs_merge_fs_addr4(struct afs_addr_list *alist, __be32 xdr, u16 port) ...@@ -298,38 +307,42 @@ void afs_merge_fs_addr4(struct afs_addr_list *alist, __be32 xdr, u16 port)
alist->addrs + i, alist->addrs + i,
sizeof(alist->addrs[0]) * (alist->nr_addrs - i)); sizeof(alist->addrs[0]) * (alist->nr_addrs - i));
srx = &alist->addrs[i].srx; alist->addrs[i].peer = peer;
srx->srx_family = AF_RXRPC;
srx->transport_type = SOCK_DGRAM;
srx->transport_len = sizeof(srx->transport.sin);
srx->transport.sin.sin_family = AF_INET;
srx->transport.sin.sin_port = htons(port);
srx->transport.sin.sin_addr.s_addr = xdr;
alist->nr_ipv4++; alist->nr_ipv4++;
alist->nr_addrs++; alist->nr_addrs++;
return 0;
} }
/* /*
* Merge an IPv6 entry into a fileserver address list. * Merge an IPv6 entry into a fileserver address list.
*/ */
void afs_merge_fs_addr6(struct afs_addr_list *alist, __be32 *xdr, u16 port) int afs_merge_fs_addr6(struct afs_net *net, struct afs_addr_list *alist,
__be32 *xdr, u16 port)
{ {
struct sockaddr_rxrpc *srx; struct sockaddr_rxrpc srx;
int i, diff; struct rxrpc_peer *peer;
int i;
if (alist->nr_addrs >= alist->max_addrs) if (alist->nr_addrs >= alist->max_addrs)
return; return 0;
for (i = alist->nr_ipv4; i < alist->nr_addrs; i++) { srx.srx_family = AF_RXRPC;
struct sockaddr_in6 *a = &alist->addrs[i].srx.transport.sin6; srx.transport_type = SOCK_DGRAM;
u16 a_port = ntohs(a->sin6_port); srx.transport_len = sizeof(srx.transport.sin6);
srx.transport.sin6.sin6_family = AF_INET6;
srx.transport.sin6.sin6_port = htons(port);
memcpy(&srx.transport.sin6.sin6_addr, xdr, 16);
diff = memcmp(xdr, &a->sin6_addr, 16); peer = rxrpc_kernel_lookup_peer(net->socket, &srx, GFP_KERNEL);
if (diff == 0 && port == a_port) if (!peer)
return; return -ENOMEM;
if (diff == 0 && port < a_port)
break; for (i = alist->nr_ipv4; i < alist->nr_addrs; i++) {
if (diff < 0) if (peer == alist->addrs[i].peer) {
rxrpc_kernel_put_peer(peer);
return 0;
}
if (peer <= alist->addrs[i].peer)
break; break;
} }
...@@ -337,15 +350,9 @@ void afs_merge_fs_addr6(struct afs_addr_list *alist, __be32 *xdr, u16 port) ...@@ -337,15 +350,9 @@ void afs_merge_fs_addr6(struct afs_addr_list *alist, __be32 *xdr, u16 port)
memmove(alist->addrs + i + 1, memmove(alist->addrs + i + 1,
alist->addrs + i, alist->addrs + i,
sizeof(alist->addrs[0]) * (alist->nr_addrs - i)); sizeof(alist->addrs[0]) * (alist->nr_addrs - i));
alist->addrs[i].peer = peer;
srx = &alist->addrs[i].srx;
srx->srx_family = AF_RXRPC;
srx->transport_type = SOCK_DGRAM;
srx->transport_len = sizeof(srx->transport.sin6);
srx->transport.sin6.sin6_family = AF_INET6;
srx->transport.sin6.sin6_port = htons(port);
memcpy(&srx->transport.sin6.sin6_addr, xdr, 16);
alist->nr_addrs++; alist->nr_addrs++;
return 0;
} }
/* /*
......
...@@ -146,10 +146,11 @@ static int afs_find_cm_server_by_peer(struct afs_call *call) ...@@ -146,10 +146,11 @@ static int afs_find_cm_server_by_peer(struct afs_call *call)
{ {
struct sockaddr_rxrpc srx; struct sockaddr_rxrpc srx;
struct afs_server *server; struct afs_server *server;
struct rxrpc_peer *peer;
rxrpc_kernel_get_peer(call->net->socket, call->rxcall, &srx); peer = rxrpc_kernel_get_call_peer(call->net->socket, call->rxcall);
server = afs_find_server(call->net, &srx); server = afs_find_server(call->net, peer);
if (!server) { if (!server) {
trace_afs_cm_no_server(call, &srx); trace_afs_cm_no_server(call, &srx);
return 0; return 0;
......
...@@ -101,6 +101,7 @@ static void afs_fs_probe_not_done(struct afs_net *net, ...@@ -101,6 +101,7 @@ static void afs_fs_probe_not_done(struct afs_net *net,
void afs_fileserver_probe_result(struct afs_call *call) void afs_fileserver_probe_result(struct afs_call *call)
{ {
struct afs_addr_list *alist = call->alist; struct afs_addr_list *alist = call->alist;
struct afs_address *addr = &alist->addrs[call->addr_ix];
struct afs_server *server = call->server; struct afs_server *server = call->server;
unsigned int index = call->addr_ix; unsigned int index = call->addr_ix;
unsigned int rtt_us = 0, cap0; unsigned int rtt_us = 0, cap0;
...@@ -153,12 +154,12 @@ void afs_fileserver_probe_result(struct afs_call *call) ...@@ -153,12 +154,12 @@ void afs_fileserver_probe_result(struct afs_call *call)
if (call->service_id == YFS_FS_SERVICE) { if (call->service_id == YFS_FS_SERVICE) {
server->probe.is_yfs = true; server->probe.is_yfs = true;
set_bit(AFS_SERVER_FL_IS_YFS, &server->flags); set_bit(AFS_SERVER_FL_IS_YFS, &server->flags);
alist->addrs[index].srx.srx_service = call->service_id; addr->service_id = call->service_id;
} else { } else {
server->probe.not_yfs = true; server->probe.not_yfs = true;
if (!server->probe.is_yfs) { if (!server->probe.is_yfs) {
clear_bit(AFS_SERVER_FL_IS_YFS, &server->flags); clear_bit(AFS_SERVER_FL_IS_YFS, &server->flags);
alist->addrs[index].srx.srx_service = call->service_id; addr->service_id = call->service_id;
} }
cap0 = ntohl(call->tmp); cap0 = ntohl(call->tmp);
if (cap0 & AFS3_VICED_CAPABILITY_64BITFILES) if (cap0 & AFS3_VICED_CAPABILITY_64BITFILES)
...@@ -167,7 +168,7 @@ void afs_fileserver_probe_result(struct afs_call *call) ...@@ -167,7 +168,7 @@ void afs_fileserver_probe_result(struct afs_call *call)
clear_bit(AFS_SERVER_FL_HAS_FS64, &server->flags); clear_bit(AFS_SERVER_FL_HAS_FS64, &server->flags);
} }
rxrpc_kernel_get_srtt(call->net->socket, call->rxcall, &rtt_us); rtt_us = rxrpc_kernel_get_srtt(addr->peer);
if (rtt_us < server->probe.rtt) { if (rtt_us < server->probe.rtt) {
server->probe.rtt = rtt_us; server->probe.rtt = rtt_us;
server->rtt = rtt_us; server->rtt = rtt_us;
...@@ -181,8 +182,8 @@ void afs_fileserver_probe_result(struct afs_call *call) ...@@ -181,8 +182,8 @@ void afs_fileserver_probe_result(struct afs_call *call)
out: out:
spin_unlock(&server->probe_lock); spin_unlock(&server->probe_lock);
_debug("probe %pU [%u] %pISpc rtt=%u ret=%d", _debug("probe %pU [%u] %pISpc rtt=%d ret=%d",
&server->uuid, index, &alist->addrs[index].srx.transport, &server->uuid, index, rxrpc_kernel_remote_addr(alist->addrs[index].peer),
rtt_us, ret); rtt_us, ret);
return afs_done_one_fs_probe(call->net, server); return afs_done_one_fs_probe(call->net, server);
......
...@@ -72,6 +72,11 @@ enum afs_call_state { ...@@ -72,6 +72,11 @@ enum afs_call_state {
AFS_CALL_COMPLETE, /* Completed or failed */ AFS_CALL_COMPLETE, /* Completed or failed */
}; };
struct afs_address {
struct rxrpc_peer *peer;
u16 service_id;
};
/* /*
* List of server addresses. * List of server addresses.
*/ */
...@@ -87,9 +92,7 @@ struct afs_addr_list { ...@@ -87,9 +92,7 @@ struct afs_addr_list {
enum dns_lookup_status status:8; enum dns_lookup_status status:8;
unsigned long failed; /* Mask of addrs that failed locally/ICMP */ unsigned long failed; /* Mask of addrs that failed locally/ICMP */
unsigned long responded; /* Mask of addrs that responded */ unsigned long responded; /* Mask of addrs that responded */
struct { struct afs_address addrs[] __counted_by(max_addrs);
struct sockaddr_rxrpc srx;
} addrs[] __counted_by(max_addrs);
#define AFS_MAX_ADDRESSES ((unsigned int)(sizeof(unsigned long) * 8)) #define AFS_MAX_ADDRESSES ((unsigned int)(sizeof(unsigned long) * 8))
}; };
...@@ -420,7 +423,7 @@ struct afs_vlserver { ...@@ -420,7 +423,7 @@ struct afs_vlserver {
atomic_t probe_outstanding; atomic_t probe_outstanding;
spinlock_t probe_lock; spinlock_t probe_lock;
struct { struct {
unsigned int rtt; /* RTT in uS */ unsigned int rtt; /* Best RTT in uS (or UINT_MAX) */
u32 abort_code; u32 abort_code;
short error; short error;
unsigned short flags; unsigned short flags;
...@@ -537,7 +540,7 @@ struct afs_server { ...@@ -537,7 +540,7 @@ struct afs_server {
atomic_t probe_outstanding; atomic_t probe_outstanding;
spinlock_t probe_lock; spinlock_t probe_lock;
struct { struct {
unsigned int rtt; /* RTT in uS */ unsigned int rtt; /* Best RTT in uS (or UINT_MAX) */
u32 abort_code; u32 abort_code;
short error; short error;
bool responded:1; bool responded:1;
...@@ -964,9 +967,7 @@ static inline struct afs_addr_list *afs_get_addrlist(struct afs_addr_list *alist ...@@ -964,9 +967,7 @@ static inline struct afs_addr_list *afs_get_addrlist(struct afs_addr_list *alist
refcount_inc(&alist->usage); refcount_inc(&alist->usage);
return alist; return alist;
} }
extern struct afs_addr_list *afs_alloc_addrlist(unsigned int, extern struct afs_addr_list *afs_alloc_addrlist(unsigned int nr, u16 service_id);
unsigned short,
unsigned short);
extern void afs_put_addrlist(struct afs_addr_list *); extern void afs_put_addrlist(struct afs_addr_list *);
extern struct afs_vlserver_list *afs_parse_text_addrs(struct afs_net *, extern struct afs_vlserver_list *afs_parse_text_addrs(struct afs_net *,
const char *, size_t, char, const char *, size_t, char,
...@@ -977,8 +978,10 @@ extern struct afs_vlserver_list *afs_dns_query(struct afs_cell *, time64_t *); ...@@ -977,8 +978,10 @@ extern struct afs_vlserver_list *afs_dns_query(struct afs_cell *, time64_t *);
extern bool afs_iterate_addresses(struct afs_addr_cursor *); extern bool afs_iterate_addresses(struct afs_addr_cursor *);
extern int afs_end_cursor(struct afs_addr_cursor *); extern int afs_end_cursor(struct afs_addr_cursor *);
extern void afs_merge_fs_addr4(struct afs_addr_list *, __be32, u16); extern int afs_merge_fs_addr4(struct afs_net *net, struct afs_addr_list *addr,
extern void afs_merge_fs_addr6(struct afs_addr_list *, __be32 *, u16); __be32 xdr, u16 port);
extern int afs_merge_fs_addr6(struct afs_net *net, struct afs_addr_list *addr,
__be32 *xdr, u16 port);
/* /*
* callback.c * callback.c
...@@ -1405,8 +1408,7 @@ extern void __exit afs_clean_up_permit_cache(void); ...@@ -1405,8 +1408,7 @@ extern void __exit afs_clean_up_permit_cache(void);
*/ */
extern spinlock_t afs_server_peer_lock; extern spinlock_t afs_server_peer_lock;
extern struct afs_server *afs_find_server(struct afs_net *, extern struct afs_server *afs_find_server(struct afs_net *, const struct rxrpc_peer *);
const struct sockaddr_rxrpc *);
extern struct afs_server *afs_find_server_by_uuid(struct afs_net *, const uuid_t *); extern struct afs_server *afs_find_server_by_uuid(struct afs_net *, const uuid_t *);
extern struct afs_server *afs_lookup_server(struct afs_cell *, struct key *, const uuid_t *, u32); extern struct afs_server *afs_lookup_server(struct afs_cell *, struct key *, const uuid_t *, u32);
extern struct afs_server *afs_get_server(struct afs_server *, enum afs_server_trace); extern struct afs_server *afs_get_server(struct afs_server *, enum afs_server_trace);
......
...@@ -307,7 +307,7 @@ static int afs_proc_cell_vlservers_show(struct seq_file *m, void *v) ...@@ -307,7 +307,7 @@ static int afs_proc_cell_vlservers_show(struct seq_file *m, void *v)
for (i = 0; i < alist->nr_addrs; i++) for (i = 0; i < alist->nr_addrs; i++)
seq_printf(m, " %c %pISpc\n", seq_printf(m, " %c %pISpc\n",
alist->preferred == i ? '>' : '-', alist->preferred == i ? '>' : '-',
&alist->addrs[i].srx.transport); rxrpc_kernel_remote_addr(alist->addrs[i].peer));
} }
seq_printf(m, " info: fl=%lx rtt=%d\n", vlserver->flags, vlserver->rtt); seq_printf(m, " info: fl=%lx rtt=%d\n", vlserver->flags, vlserver->rtt);
seq_printf(m, " probe: fl=%x e=%d ac=%d out=%d\n", seq_printf(m, " probe: fl=%x e=%d ac=%d out=%d\n",
...@@ -398,9 +398,10 @@ static int afs_proc_servers_show(struct seq_file *m, void *v) ...@@ -398,9 +398,10 @@ static int afs_proc_servers_show(struct seq_file *m, void *v)
seq_printf(m, " - ALIST v=%u rsp=%lx f=%lx\n", seq_printf(m, " - ALIST v=%u rsp=%lx f=%lx\n",
alist->version, alist->responded, alist->failed); alist->version, alist->responded, alist->failed);
for (i = 0; i < alist->nr_addrs; i++) for (i = 0; i < alist->nr_addrs; i++)
seq_printf(m, " [%x] %pISpc%s\n", seq_printf(m, " [%x] %pISpc%s rtt=%d\n",
i, &alist->addrs[i].srx.transport, i, rxrpc_kernel_remote_addr(alist->addrs[i].peer),
alist->preferred == i ? "*" : ""); alist->preferred == i ? "*" : "",
rxrpc_kernel_get_srtt(alist->addrs[i].peer));
return 0; return 0;
} }
......
...@@ -113,7 +113,7 @@ bool afs_select_fileserver(struct afs_operation *op) ...@@ -113,7 +113,7 @@ bool afs_select_fileserver(struct afs_operation *op)
struct afs_server *server; struct afs_server *server;
struct afs_vnode *vnode = op->file[0].vnode; struct afs_vnode *vnode = op->file[0].vnode;
struct afs_error e; struct afs_error e;
u32 rtt; unsigned int rtt;
int error = op->ac.error, i; int error = op->ac.error, i;
_enter("%lx[%d],%lx[%d],%d,%d", _enter("%lx[%d],%lx[%d],%d,%d",
...@@ -420,7 +420,7 @@ bool afs_select_fileserver(struct afs_operation *op) ...@@ -420,7 +420,7 @@ bool afs_select_fileserver(struct afs_operation *op)
} }
op->index = -1; op->index = -1;
rtt = U32_MAX; rtt = UINT_MAX;
for (i = 0; i < op->server_list->nr_servers; i++) { for (i = 0; i < op->server_list->nr_servers; i++) {
struct afs_server *s = op->server_list->servers[i].server; struct afs_server *s = op->server_list->servers[i].server;
...@@ -488,7 +488,7 @@ bool afs_select_fileserver(struct afs_operation *op) ...@@ -488,7 +488,7 @@ bool afs_select_fileserver(struct afs_operation *op)
_debug("address [%u] %u/%u %pISp", _debug("address [%u] %u/%u %pISp",
op->index, op->ac.index, op->ac.alist->nr_addrs, op->index, op->ac.index, op->ac.alist->nr_addrs,
&op->ac.alist->addrs[op->ac.index].srx.transport); rxrpc_kernel_remote_addr(op->ac.alist->addrs[op->ac.index].peer));
_leave(" = t"); _leave(" = t");
return true; return true;
......
...@@ -296,7 +296,8 @@ static void afs_notify_end_request_tx(struct sock *sock, ...@@ -296,7 +296,8 @@ static void afs_notify_end_request_tx(struct sock *sock,
*/ */
void afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, gfp_t gfp) void afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, gfp_t gfp)
{ {
struct sockaddr_rxrpc *srx = &ac->alist->addrs[ac->index].srx; struct afs_address *addr = &ac->alist->addrs[ac->index];
struct rxrpc_peer *peer = addr->peer;
struct rxrpc_call *rxcall; struct rxrpc_call *rxcall;
struct msghdr msg; struct msghdr msg;
struct kvec iov[1]; struct kvec iov[1];
...@@ -304,7 +305,7 @@ void afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, gfp_t gfp) ...@@ -304,7 +305,7 @@ void afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, gfp_t gfp)
s64 tx_total_len; s64 tx_total_len;
int ret; int ret;
_enter(",{%pISp},", &srx->transport); _enter(",{%pISp},", rxrpc_kernel_remote_addr(addr->peer));
ASSERT(call->type != NULL); ASSERT(call->type != NULL);
ASSERT(call->type->name != NULL); ASSERT(call->type->name != NULL);
...@@ -333,7 +334,7 @@ void afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, gfp_t gfp) ...@@ -333,7 +334,7 @@ void afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, gfp_t gfp)
} }
/* create a call */ /* create a call */
rxcall = rxrpc_kernel_begin_call(call->net->socket, srx, call->key, rxcall = rxrpc_kernel_begin_call(call->net->socket, peer, call->key,
(unsigned long)call, (unsigned long)call,
tx_total_len, tx_total_len,
call->max_lifespan, call->max_lifespan,
...@@ -341,6 +342,7 @@ void afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, gfp_t gfp) ...@@ -341,6 +342,7 @@ void afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, gfp_t gfp)
(call->async ? (call->async ?
afs_wake_up_async_call : afs_wake_up_async_call :
afs_wake_up_call_waiter), afs_wake_up_call_waiter),
addr->service_id,
call->upgrade, call->upgrade,
(call->intr ? RXRPC_PREINTERRUPTIBLE : (call->intr ? RXRPC_PREINTERRUPTIBLE :
RXRPC_UNINTERRUPTIBLE), RXRPC_UNINTERRUPTIBLE),
...@@ -461,7 +463,7 @@ static void afs_log_error(struct afs_call *call, s32 remote_abort) ...@@ -461,7 +463,7 @@ static void afs_log_error(struct afs_call *call, s32 remote_abort)
max = m + 1; max = m + 1;
pr_notice("kAFS: Peer reported %s failure on %s [%pISp]\n", pr_notice("kAFS: Peer reported %s failure on %s [%pISp]\n",
msg, call->type->name, msg, call->type->name,
&call->alist->addrs[call->addr_ix].srx.transport); rxrpc_kernel_remote_addr(call->alist->addrs[call->addr_ix].peer));
} }
} }
......
...@@ -21,13 +21,12 @@ static void __afs_put_server(struct afs_net *, struct afs_server *); ...@@ -21,13 +21,12 @@ static void __afs_put_server(struct afs_net *, struct afs_server *);
/* /*
* Find a server by one of its addresses. * Find a server by one of its addresses.
*/ */
struct afs_server *afs_find_server(struct afs_net *net, struct afs_server *afs_find_server(struct afs_net *net, const struct rxrpc_peer *peer)
const struct sockaddr_rxrpc *srx)
{ {
const struct afs_addr_list *alist; const struct afs_addr_list *alist;
struct afs_server *server = NULL; struct afs_server *server = NULL;
unsigned int i; unsigned int i;
int seq = 1, diff; int seq = 1;
rcu_read_lock(); rcu_read_lock();
...@@ -38,38 +37,12 @@ struct afs_server *afs_find_server(struct afs_net *net, ...@@ -38,38 +37,12 @@ struct afs_server *afs_find_server(struct afs_net *net,
seq++; /* 2 on the 1st/lockless path, otherwise odd */ seq++; /* 2 on the 1st/lockless path, otherwise odd */
read_seqbegin_or_lock(&net->fs_addr_lock, &seq); read_seqbegin_or_lock(&net->fs_addr_lock, &seq);
if (srx->transport.family == AF_INET6) {
const struct sockaddr_in6 *a = &srx->transport.sin6, *b;
hlist_for_each_entry_rcu(server, &net->fs_addresses6, addr6_link) { hlist_for_each_entry_rcu(server, &net->fs_addresses6, addr6_link) {
alist = rcu_dereference(server->addresses); alist = rcu_dereference(server->addresses);
for (i = alist->nr_ipv4; i < alist->nr_addrs; i++) { for (i = 0; i < alist->nr_addrs; i++)
b = &alist->addrs[i].srx.transport.sin6; if (alist->addrs[i].peer == peer)
diff = ((u16 __force)a->sin6_port -
(u16 __force)b->sin6_port);
if (diff == 0)
diff = memcmp(&a->sin6_addr,
&b->sin6_addr,
sizeof(struct in6_addr));
if (diff == 0)
goto found; goto found;
} }
}
} else {
const struct sockaddr_in *a = &srx->transport.sin, *b;
hlist_for_each_entry_rcu(server, &net->fs_addresses4, addr4_link) {
alist = rcu_dereference(server->addresses);
for (i = 0; i < alist->nr_ipv4; i++) {
b = &alist->addrs[i].srx.transport.sin;
diff = ((u16 __force)a->sin_port -
(u16 __force)b->sin_port);
if (diff == 0)
diff = ((u32 __force)a->sin_addr.s_addr -
(u32 __force)b->sin_addr.s_addr);
if (diff == 0)
goto found;
}
}
}
server = NULL; server = NULL;
continue; continue;
......
...@@ -32,55 +32,6 @@ static struct afs_volume *afs_sample_volume(struct afs_cell *cell, struct key *k ...@@ -32,55 +32,6 @@ static struct afs_volume *afs_sample_volume(struct afs_cell *cell, struct key *k
return volume; return volume;
} }
/*
* Compare two addresses.
*/
static int afs_compare_addrs(const struct sockaddr_rxrpc *srx_a,
const struct sockaddr_rxrpc *srx_b)
{
short port_a, port_b;
int addr_a, addr_b, diff;
diff = (short)srx_a->transport_type - (short)srx_b->transport_type;
if (diff)
goto out;
switch (srx_a->transport_type) {
case AF_INET: {
const struct sockaddr_in *a = &srx_a->transport.sin;
const struct sockaddr_in *b = &srx_b->transport.sin;
addr_a = ntohl(a->sin_addr.s_addr);
addr_b = ntohl(b->sin_addr.s_addr);
diff = addr_a - addr_b;
if (diff == 0) {
port_a = ntohs(a->sin_port);
port_b = ntohs(b->sin_port);
diff = port_a - port_b;
}
break;
}
case AF_INET6: {
const struct sockaddr_in6 *a = &srx_a->transport.sin6;
const struct sockaddr_in6 *b = &srx_b->transport.sin6;
diff = memcmp(&a->sin6_addr, &b->sin6_addr, 16);
if (diff == 0) {
port_a = ntohs(a->sin6_port);
port_b = ntohs(b->sin6_port);
diff = port_a - port_b;
}
break;
}
default:
WARN_ON(1);
diff = 1;
}
out:
return diff;
}
/* /*
* Compare the address lists of a pair of fileservers. * Compare the address lists of a pair of fileservers.
*/ */
...@@ -94,9 +45,9 @@ static int afs_compare_fs_alists(const struct afs_server *server_a, ...@@ -94,9 +45,9 @@ static int afs_compare_fs_alists(const struct afs_server *server_a,
lb = rcu_dereference(server_b->addresses); lb = rcu_dereference(server_b->addresses);
while (a < la->nr_addrs && b < lb->nr_addrs) { while (a < la->nr_addrs && b < lb->nr_addrs) {
const struct sockaddr_rxrpc *srx_a = &la->addrs[a].srx; unsigned long pa = (unsigned long)la->addrs[a].peer;
const struct sockaddr_rxrpc *srx_b = &lb->addrs[b].srx; unsigned long pb = (unsigned long)lb->addrs[b].peer;
int diff = afs_compare_addrs(srx_a, srx_b); long diff = pa - pb;
if (diff < 0) { if (diff < 0) {
a++; a++;
......
...@@ -83,14 +83,15 @@ static u16 afs_extract_le16(const u8 **_b) ...@@ -83,14 +83,15 @@ static u16 afs_extract_le16(const u8 **_b)
/* /*
* Build a VL server address list from a DNS queried server list. * Build a VL server address list from a DNS queried server list.
*/ */
static struct afs_addr_list *afs_extract_vl_addrs(const u8 **_b, const u8 *end, static struct afs_addr_list *afs_extract_vl_addrs(struct afs_net *net,
const u8 **_b, const u8 *end,
u8 nr_addrs, u16 port) u8 nr_addrs, u16 port)
{ {
struct afs_addr_list *alist; struct afs_addr_list *alist;
const u8 *b = *_b; const u8 *b = *_b;
int ret = -EINVAL; int ret = -EINVAL;
alist = afs_alloc_addrlist(nr_addrs, VL_SERVICE, port); alist = afs_alloc_addrlist(nr_addrs, VL_SERVICE);
if (!alist) if (!alist)
return ERR_PTR(-ENOMEM); return ERR_PTR(-ENOMEM);
if (nr_addrs == 0) if (nr_addrs == 0)
...@@ -109,7 +110,9 @@ static struct afs_addr_list *afs_extract_vl_addrs(const u8 **_b, const u8 *end, ...@@ -109,7 +110,9 @@ static struct afs_addr_list *afs_extract_vl_addrs(const u8 **_b, const u8 *end,
goto error; goto error;
} }
memcpy(x, b, 4); memcpy(x, b, 4);
afs_merge_fs_addr4(alist, x[0], port); ret = afs_merge_fs_addr4(net, alist, x[0], port);
if (ret < 0)
goto error;
b += 4; b += 4;
break; break;
...@@ -119,7 +122,9 @@ static struct afs_addr_list *afs_extract_vl_addrs(const u8 **_b, const u8 *end, ...@@ -119,7 +122,9 @@ static struct afs_addr_list *afs_extract_vl_addrs(const u8 **_b, const u8 *end,
goto error; goto error;
} }
memcpy(x, b, 16); memcpy(x, b, 16);
afs_merge_fs_addr6(alist, x, port); ret = afs_merge_fs_addr6(net, alist, x, port);
if (ret < 0)
goto error;
b += 16; b += 16;
break; break;
...@@ -247,7 +252,7 @@ struct afs_vlserver_list *afs_extract_vlserver_list(struct afs_cell *cell, ...@@ -247,7 +252,7 @@ struct afs_vlserver_list *afs_extract_vlserver_list(struct afs_cell *cell,
/* Extract the addresses - note that we can't skip this as we /* Extract the addresses - note that we can't skip this as we
* have to advance the payload pointer. * have to advance the payload pointer.
*/ */
addrs = afs_extract_vl_addrs(&b, end, bs.nr_addrs, bs.port); addrs = afs_extract_vl_addrs(cell->net, &b, end, bs.nr_addrs, bs.port);
if (IS_ERR(addrs)) { if (IS_ERR(addrs)) {
ret = PTR_ERR(addrs); ret = PTR_ERR(addrs);
goto error_2; goto error_2;
......
...@@ -48,6 +48,7 @@ void afs_vlserver_probe_result(struct afs_call *call) ...@@ -48,6 +48,7 @@ void afs_vlserver_probe_result(struct afs_call *call)
{ {
struct afs_addr_list *alist = call->alist; struct afs_addr_list *alist = call->alist;
struct afs_vlserver *server = call->vlserver; struct afs_vlserver *server = call->vlserver;
struct afs_address *addr = &alist->addrs[call->addr_ix];
unsigned int server_index = call->server_index; unsigned int server_index = call->server_index;
unsigned int rtt_us = 0; unsigned int rtt_us = 0;
unsigned int index = call->addr_ix; unsigned int index = call->addr_ix;
...@@ -106,16 +107,16 @@ void afs_vlserver_probe_result(struct afs_call *call) ...@@ -106,16 +107,16 @@ void afs_vlserver_probe_result(struct afs_call *call)
if (call->service_id == YFS_VL_SERVICE) { if (call->service_id == YFS_VL_SERVICE) {
server->probe.flags |= AFS_VLSERVER_PROBE_IS_YFS; server->probe.flags |= AFS_VLSERVER_PROBE_IS_YFS;
set_bit(AFS_VLSERVER_FL_IS_YFS, &server->flags); set_bit(AFS_VLSERVER_FL_IS_YFS, &server->flags);
alist->addrs[index].srx.srx_service = call->service_id; addr->service_id = call->service_id;
} else { } else {
server->probe.flags |= AFS_VLSERVER_PROBE_NOT_YFS; server->probe.flags |= AFS_VLSERVER_PROBE_NOT_YFS;
if (!(server->probe.flags & AFS_VLSERVER_PROBE_IS_YFS)) { if (!(server->probe.flags & AFS_VLSERVER_PROBE_IS_YFS)) {
clear_bit(AFS_VLSERVER_FL_IS_YFS, &server->flags); clear_bit(AFS_VLSERVER_FL_IS_YFS, &server->flags);
alist->addrs[index].srx.srx_service = call->service_id; addr->service_id = call->service_id;
} }
} }
rxrpc_kernel_get_srtt(call->net->socket, call->rxcall, &rtt_us); rtt_us = rxrpc_kernel_get_srtt(addr->peer);
if (rtt_us < server->probe.rtt) { if (rtt_us < server->probe.rtt) {
server->probe.rtt = rtt_us; server->probe.rtt = rtt_us;
server->rtt = rtt_us; server->rtt = rtt_us;
...@@ -130,8 +131,9 @@ void afs_vlserver_probe_result(struct afs_call *call) ...@@ -130,8 +131,9 @@ void afs_vlserver_probe_result(struct afs_call *call)
out: out:
spin_unlock(&server->probe_lock); spin_unlock(&server->probe_lock);
_debug("probe [%u][%u] %pISpc rtt=%u ret=%d", _debug("probe [%u][%u] %pISpc rtt=%d ret=%d",
server_index, index, &alist->addrs[index].srx.transport, rtt_us, ret); server_index, index, rxrpc_kernel_remote_addr(addr->peer),
rtt_us, ret);
afs_done_one_vl_probe(server, have_result); afs_done_one_vl_probe(server, have_result);
} }
......
...@@ -92,7 +92,7 @@ bool afs_select_vlserver(struct afs_vl_cursor *vc) ...@@ -92,7 +92,7 @@ bool afs_select_vlserver(struct afs_vl_cursor *vc)
struct afs_addr_list *alist; struct afs_addr_list *alist;
struct afs_vlserver *vlserver; struct afs_vlserver *vlserver;
struct afs_error e; struct afs_error e;
u32 rtt; unsigned int rtt;
int error = vc->ac.error, i; int error = vc->ac.error, i;
_enter("%lx[%d],%lx[%d],%d,%d", _enter("%lx[%d],%lx[%d],%d,%d",
...@@ -194,7 +194,7 @@ bool afs_select_vlserver(struct afs_vl_cursor *vc) ...@@ -194,7 +194,7 @@ bool afs_select_vlserver(struct afs_vl_cursor *vc)
goto selected_server; goto selected_server;
vc->index = -1; vc->index = -1;
rtt = U32_MAX; rtt = UINT_MAX;
for (i = 0; i < vc->server_list->nr_servers; i++) { for (i = 0; i < vc->server_list->nr_servers; i++) {
struct afs_vlserver *s = vc->server_list->servers[i].server; struct afs_vlserver *s = vc->server_list->servers[i].server;
...@@ -249,7 +249,7 @@ bool afs_select_vlserver(struct afs_vl_cursor *vc) ...@@ -249,7 +249,7 @@ bool afs_select_vlserver(struct afs_vl_cursor *vc)
_debug("VL address %d/%d", vc->ac.index, vc->ac.alist->nr_addrs); _debug("VL address %d/%d", vc->ac.index, vc->ac.alist->nr_addrs);
_leave(" = t %pISpc", &vc->ac.alist->addrs[vc->ac.index].srx.transport); _leave(" = t %pISpc", rxrpc_kernel_remote_addr(vc->ac.alist->addrs[vc->ac.index].peer));
return true; return true;
next_server: next_server:
......
...@@ -208,7 +208,7 @@ static int afs_deliver_vl_get_addrs_u(struct afs_call *call) ...@@ -208,7 +208,7 @@ static int afs_deliver_vl_get_addrs_u(struct afs_call *call)
count = ntohl(*bp); count = ntohl(*bp);
nentries = min(nentries, count); nentries = min(nentries, count);
alist = afs_alloc_addrlist(nentries, FS_SERVICE, AFS_FS_PORT); alist = afs_alloc_addrlist(nentries, FS_SERVICE);
if (!alist) if (!alist)
return -ENOMEM; return -ENOMEM;
alist->version = uniquifier; alist->version = uniquifier;
...@@ -230,9 +230,13 @@ static int afs_deliver_vl_get_addrs_u(struct afs_call *call) ...@@ -230,9 +230,13 @@ static int afs_deliver_vl_get_addrs_u(struct afs_call *call)
alist = call->ret_alist; alist = call->ret_alist;
bp = call->buffer; bp = call->buffer;
count = min(call->count, 4U); count = min(call->count, 4U);
for (i = 0; i < count; i++) for (i = 0; i < count; i++) {
if (alist->nr_addrs < call->count2) if (alist->nr_addrs < call->count2) {
afs_merge_fs_addr4(alist, *bp++, AFS_FS_PORT); ret = afs_merge_fs_addr4(call->net, alist, *bp++, AFS_FS_PORT);
if (ret < 0)
return ret;
}
}
call->count -= count; call->count -= count;
if (call->count > 0) if (call->count > 0)
...@@ -450,7 +454,7 @@ static int afs_deliver_yfsvl_get_endpoints(struct afs_call *call) ...@@ -450,7 +454,7 @@ static int afs_deliver_yfsvl_get_endpoints(struct afs_call *call)
if (call->count > YFS_MAXENDPOINTS) if (call->count > YFS_MAXENDPOINTS)
return afs_protocol_error(call, afs_eproto_yvl_fsendpt_num); return afs_protocol_error(call, afs_eproto_yvl_fsendpt_num);
alist = afs_alloc_addrlist(call->count, FS_SERVICE, AFS_FS_PORT); alist = afs_alloc_addrlist(call->count, FS_SERVICE);
if (!alist) if (!alist)
return -ENOMEM; return -ENOMEM;
alist->version = uniquifier; alist->version = uniquifier;
...@@ -488,14 +492,18 @@ static int afs_deliver_yfsvl_get_endpoints(struct afs_call *call) ...@@ -488,14 +492,18 @@ static int afs_deliver_yfsvl_get_endpoints(struct afs_call *call)
if (ntohl(bp[0]) != sizeof(__be32) * 2) if (ntohl(bp[0]) != sizeof(__be32) * 2)
return afs_protocol_error( return afs_protocol_error(
call, afs_eproto_yvl_fsendpt4_len); call, afs_eproto_yvl_fsendpt4_len);
afs_merge_fs_addr4(alist, bp[1], ntohl(bp[2])); ret = afs_merge_fs_addr4(call->net, alist, bp[1], ntohl(bp[2]));
if (ret < 0)
return ret;
bp += 3; bp += 3;
break; break;
case YFS_ENDPOINT_IPV6: case YFS_ENDPOINT_IPV6:
if (ntohl(bp[0]) != sizeof(__be32) * 5) if (ntohl(bp[0]) != sizeof(__be32) * 5)
return afs_protocol_error( return afs_protocol_error(
call, afs_eproto_yvl_fsendpt6_len); call, afs_eproto_yvl_fsendpt6_len);
afs_merge_fs_addr6(alist, bp + 1, ntohl(bp[5])); ret = afs_merge_fs_addr6(call->net, alist, bp + 1, ntohl(bp[5]));
if (ret < 0)
return ret;
bp += 6; bp += 6;
break; break;
default: default:
......
...@@ -15,6 +15,7 @@ struct key; ...@@ -15,6 +15,7 @@ struct key;
struct sock; struct sock;
struct socket; struct socket;
struct rxrpc_call; struct rxrpc_call;
struct rxrpc_peer;
enum rxrpc_abort_reason; enum rxrpc_abort_reason;
enum rxrpc_interruptibility { enum rxrpc_interruptibility {
...@@ -41,13 +42,14 @@ void rxrpc_kernel_new_call_notification(struct socket *, ...@@ -41,13 +42,14 @@ void rxrpc_kernel_new_call_notification(struct socket *,
rxrpc_notify_new_call_t, rxrpc_notify_new_call_t,
rxrpc_discard_new_call_t); rxrpc_discard_new_call_t);
struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock, struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock,
struct sockaddr_rxrpc *srx, struct rxrpc_peer *peer,
struct key *key, struct key *key,
unsigned long user_call_ID, unsigned long user_call_ID,
s64 tx_total_len, s64 tx_total_len,
u32 hard_timeout, u32 hard_timeout,
gfp_t gfp, gfp_t gfp,
rxrpc_notify_rx_t notify_rx, rxrpc_notify_rx_t notify_rx,
u16 service_id,
bool upgrade, bool upgrade,
enum rxrpc_interruptibility interruptibility, enum rxrpc_interruptibility interruptibility,
unsigned int debug_id); unsigned int debug_id);
...@@ -60,9 +62,14 @@ bool rxrpc_kernel_abort_call(struct socket *, struct rxrpc_call *, ...@@ -60,9 +62,14 @@ bool rxrpc_kernel_abort_call(struct socket *, struct rxrpc_call *,
u32, int, enum rxrpc_abort_reason); u32, int, enum rxrpc_abort_reason);
void rxrpc_kernel_shutdown_call(struct socket *sock, struct rxrpc_call *call); void rxrpc_kernel_shutdown_call(struct socket *sock, struct rxrpc_call *call);
void rxrpc_kernel_put_call(struct socket *sock, struct rxrpc_call *call); void rxrpc_kernel_put_call(struct socket *sock, struct rxrpc_call *call);
void rxrpc_kernel_get_peer(struct socket *, struct rxrpc_call *, struct rxrpc_peer *rxrpc_kernel_lookup_peer(struct socket *sock,
struct sockaddr_rxrpc *); struct sockaddr_rxrpc *srx, gfp_t gfp);
bool rxrpc_kernel_get_srtt(struct socket *, struct rxrpc_call *, u32 *); void rxrpc_kernel_put_peer(struct rxrpc_peer *peer);
struct rxrpc_peer *rxrpc_kernel_get_peer(struct rxrpc_peer *peer);
struct rxrpc_peer *rxrpc_kernel_get_call_peer(struct socket *sock, struct rxrpc_call *call);
const struct sockaddr_rxrpc *rxrpc_kernel_remote_srx(const struct rxrpc_peer *peer);
const struct sockaddr *rxrpc_kernel_remote_addr(const struct rxrpc_peer *peer);
unsigned int rxrpc_kernel_get_srtt(const struct rxrpc_peer *);
int rxrpc_kernel_charge_accept(struct socket *, rxrpc_notify_rx_t, int rxrpc_kernel_charge_accept(struct socket *, rxrpc_notify_rx_t,
rxrpc_user_attach_call_t, unsigned long, gfp_t, rxrpc_user_attach_call_t, unsigned long, gfp_t,
unsigned int); unsigned int);
......
...@@ -178,7 +178,9 @@ ...@@ -178,7 +178,9 @@
#define rxrpc_peer_traces \ #define rxrpc_peer_traces \
EM(rxrpc_peer_free, "FREE ") \ EM(rxrpc_peer_free, "FREE ") \
EM(rxrpc_peer_get_accept, "GET accept ") \ EM(rxrpc_peer_get_accept, "GET accept ") \
EM(rxrpc_peer_get_application, "GET app ") \
EM(rxrpc_peer_get_bundle, "GET bundle ") \ EM(rxrpc_peer_get_bundle, "GET bundle ") \
EM(rxrpc_peer_get_call, "GET call ") \
EM(rxrpc_peer_get_client_conn, "GET cln-conn") \ EM(rxrpc_peer_get_client_conn, "GET cln-conn") \
EM(rxrpc_peer_get_input, "GET input ") \ EM(rxrpc_peer_get_input, "GET input ") \
EM(rxrpc_peer_get_input_error, "GET inpt-err") \ EM(rxrpc_peer_get_input_error, "GET inpt-err") \
...@@ -187,6 +189,7 @@ ...@@ -187,6 +189,7 @@
EM(rxrpc_peer_get_service_conn, "GET srv-conn") \ EM(rxrpc_peer_get_service_conn, "GET srv-conn") \
EM(rxrpc_peer_new_client, "NEW client ") \ EM(rxrpc_peer_new_client, "NEW client ") \
EM(rxrpc_peer_new_prealloc, "NEW prealloc") \ EM(rxrpc_peer_new_prealloc, "NEW prealloc") \
EM(rxrpc_peer_put_application, "PUT app ") \
EM(rxrpc_peer_put_bundle, "PUT bundle ") \ EM(rxrpc_peer_put_bundle, "PUT bundle ") \
EM(rxrpc_peer_put_call, "PUT call ") \ EM(rxrpc_peer_put_call, "PUT call ") \
EM(rxrpc_peer_put_conn, "PUT conn ") \ EM(rxrpc_peer_put_conn, "PUT conn ") \
......
...@@ -258,16 +258,62 @@ static int rxrpc_listen(struct socket *sock, int backlog) ...@@ -258,16 +258,62 @@ static int rxrpc_listen(struct socket *sock, int backlog)
return ret; return ret;
} }
/**
* rxrpc_kernel_lookup_peer - Obtain remote transport endpoint for an address
* @sock: The socket through which it will be accessed
* @srx: The network address
* @gfp: Allocation flags
*
* Lookup or create a remote transport endpoint record for the specified
* address and return it with a ref held.
*/
struct rxrpc_peer *rxrpc_kernel_lookup_peer(struct socket *sock,
struct sockaddr_rxrpc *srx, gfp_t gfp)
{
struct rxrpc_sock *rx = rxrpc_sk(sock->sk);
int ret;
ret = rxrpc_validate_address(rx, srx, sizeof(*srx));
if (ret < 0)
return ERR_PTR(ret);
return rxrpc_lookup_peer(rx->local, srx, gfp);
}
EXPORT_SYMBOL(rxrpc_kernel_lookup_peer);
/**
* rxrpc_kernel_get_peer - Get a reference on a peer
* @peer: The peer to get a reference on.
*
* Get a record for the remote peer in a call.
*/
struct rxrpc_peer *rxrpc_kernel_get_peer(struct rxrpc_peer *peer)
{
return peer ? rxrpc_get_peer(peer, rxrpc_peer_get_application) : NULL;
}
EXPORT_SYMBOL(rxrpc_kernel_get_peer);
/**
* rxrpc_kernel_put_peer - Allow a kernel app to drop a peer reference
* @peer: The peer to drop a ref on
*/
void rxrpc_kernel_put_peer(struct rxrpc_peer *peer)
{
rxrpc_put_peer(peer, rxrpc_peer_put_application);
}
EXPORT_SYMBOL(rxrpc_kernel_put_peer);
/** /**
* rxrpc_kernel_begin_call - Allow a kernel service to begin a call * rxrpc_kernel_begin_call - Allow a kernel service to begin a call
* @sock: The socket on which to make the call * @sock: The socket on which to make the call
* @srx: The address of the peer to contact * @peer: The peer to contact
* @key: The security context to use (defaults to socket setting) * @key: The security context to use (defaults to socket setting)
* @user_call_ID: The ID to use * @user_call_ID: The ID to use
* @tx_total_len: Total length of data to transmit during the call (or -1) * @tx_total_len: Total length of data to transmit during the call (or -1)
* @hard_timeout: The maximum lifespan of the call in sec * @hard_timeout: The maximum lifespan of the call in sec
* @gfp: The allocation constraints * @gfp: The allocation constraints
* @notify_rx: Where to send notifications instead of socket queue * @notify_rx: Where to send notifications instead of socket queue
* @service_id: The ID of the service to contact
* @upgrade: Request service upgrade for call * @upgrade: Request service upgrade for call
* @interruptibility: The call is interruptible, or can be canceled. * @interruptibility: The call is interruptible, or can be canceled.
* @debug_id: The debug ID for tracing to be assigned to the call * @debug_id: The debug ID for tracing to be assigned to the call
...@@ -280,13 +326,14 @@ static int rxrpc_listen(struct socket *sock, int backlog) ...@@ -280,13 +326,14 @@ static int rxrpc_listen(struct socket *sock, int backlog)
* supplying @srx and @key. * supplying @srx and @key.
*/ */
struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock, struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock,
struct sockaddr_rxrpc *srx, struct rxrpc_peer *peer,
struct key *key, struct key *key,
unsigned long user_call_ID, unsigned long user_call_ID,
s64 tx_total_len, s64 tx_total_len,
u32 hard_timeout, u32 hard_timeout,
gfp_t gfp, gfp_t gfp,
rxrpc_notify_rx_t notify_rx, rxrpc_notify_rx_t notify_rx,
u16 service_id,
bool upgrade, bool upgrade,
enum rxrpc_interruptibility interruptibility, enum rxrpc_interruptibility interruptibility,
unsigned int debug_id) unsigned int debug_id)
...@@ -295,13 +342,11 @@ struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock, ...@@ -295,13 +342,11 @@ struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock,
struct rxrpc_call_params p; struct rxrpc_call_params p;
struct rxrpc_call *call; struct rxrpc_call *call;
struct rxrpc_sock *rx = rxrpc_sk(sock->sk); struct rxrpc_sock *rx = rxrpc_sk(sock->sk);
int ret;
_enter(",,%x,%lx", key_serial(key), user_call_ID); _enter(",,%x,%lx", key_serial(key), user_call_ID);
ret = rxrpc_validate_address(rx, srx, sizeof(*srx)); if (WARN_ON_ONCE(peer->local != rx->local))
if (ret < 0) return ERR_PTR(-EIO);
return ERR_PTR(ret);
lock_sock(&rx->sk); lock_sock(&rx->sk);
...@@ -319,12 +364,13 @@ struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock, ...@@ -319,12 +364,13 @@ struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock,
memset(&cp, 0, sizeof(cp)); memset(&cp, 0, sizeof(cp));
cp.local = rx->local; cp.local = rx->local;
cp.peer = peer;
cp.key = key; cp.key = key;
cp.security_level = rx->min_sec_level; cp.security_level = rx->min_sec_level;
cp.exclusive = false; cp.exclusive = false;
cp.upgrade = upgrade; cp.upgrade = upgrade;
cp.service_id = srx->srx_service; cp.service_id = service_id;
call = rxrpc_new_client_call(rx, &cp, srx, &p, gfp, debug_id); call = rxrpc_new_client_call(rx, &cp, &p, gfp, debug_id);
/* The socket has been unlocked. */ /* The socket has been unlocked. */
if (!IS_ERR(call)) { if (!IS_ERR(call)) {
call->notify_rx = notify_rx; call->notify_rx = notify_rx;
......
...@@ -364,6 +364,7 @@ struct rxrpc_conn_proto { ...@@ -364,6 +364,7 @@ struct rxrpc_conn_proto {
struct rxrpc_conn_parameters { struct rxrpc_conn_parameters {
struct rxrpc_local *local; /* Representation of local endpoint */ struct rxrpc_local *local; /* Representation of local endpoint */
struct rxrpc_peer *peer; /* Representation of remote endpoint */
struct key *key; /* Security details */ struct key *key; /* Security details */
bool exclusive; /* T if conn is exclusive */ bool exclusive; /* T if conn is exclusive */
bool upgrade; /* T if service ID can be upgraded */ bool upgrade; /* T if service ID can be upgraded */
...@@ -867,7 +868,6 @@ struct rxrpc_call *rxrpc_find_call_by_user_ID(struct rxrpc_sock *, unsigned long ...@@ -867,7 +868,6 @@ struct rxrpc_call *rxrpc_find_call_by_user_ID(struct rxrpc_sock *, unsigned long
struct rxrpc_call *rxrpc_alloc_call(struct rxrpc_sock *, gfp_t, unsigned int); struct rxrpc_call *rxrpc_alloc_call(struct rxrpc_sock *, gfp_t, unsigned int);
struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *, struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *,
struct rxrpc_conn_parameters *, struct rxrpc_conn_parameters *,
struct sockaddr_rxrpc *,
struct rxrpc_call_params *, gfp_t, struct rxrpc_call_params *, gfp_t,
unsigned int); unsigned int);
void rxrpc_start_call_timer(struct rxrpc_call *call); void rxrpc_start_call_timer(struct rxrpc_call *call);
......
...@@ -193,7 +193,6 @@ struct rxrpc_call *rxrpc_alloc_call(struct rxrpc_sock *rx, gfp_t gfp, ...@@ -193,7 +193,6 @@ struct rxrpc_call *rxrpc_alloc_call(struct rxrpc_sock *rx, gfp_t gfp,
* Allocate a new client call. * Allocate a new client call.
*/ */
static struct rxrpc_call *rxrpc_alloc_client_call(struct rxrpc_sock *rx, static struct rxrpc_call *rxrpc_alloc_client_call(struct rxrpc_sock *rx,
struct sockaddr_rxrpc *srx,
struct rxrpc_conn_parameters *cp, struct rxrpc_conn_parameters *cp,
struct rxrpc_call_params *p, struct rxrpc_call_params *p,
gfp_t gfp, gfp_t gfp,
...@@ -211,10 +210,12 @@ static struct rxrpc_call *rxrpc_alloc_client_call(struct rxrpc_sock *rx, ...@@ -211,10 +210,12 @@ static struct rxrpc_call *rxrpc_alloc_client_call(struct rxrpc_sock *rx,
now = ktime_get_real(); now = ktime_get_real();
call->acks_latest_ts = now; call->acks_latest_ts = now;
call->cong_tstamp = now; call->cong_tstamp = now;
call->dest_srx = *srx; call->dest_srx = cp->peer->srx;
call->dest_srx.srx_service = cp->service_id;
call->interruptibility = p->interruptibility; call->interruptibility = p->interruptibility;
call->tx_total_len = p->tx_total_len; call->tx_total_len = p->tx_total_len;
call->key = key_get(cp->key); call->key = key_get(cp->key);
call->peer = rxrpc_get_peer(cp->peer, rxrpc_peer_get_call);
call->local = rxrpc_get_local(cp->local, rxrpc_local_get_call); call->local = rxrpc_get_local(cp->local, rxrpc_local_get_call);
call->security_level = cp->security_level; call->security_level = cp->security_level;
if (p->kernel) if (p->kernel)
...@@ -306,10 +307,6 @@ static int rxrpc_connect_call(struct rxrpc_call *call, gfp_t gfp) ...@@ -306,10 +307,6 @@ static int rxrpc_connect_call(struct rxrpc_call *call, gfp_t gfp)
_enter("{%d,%lx},", call->debug_id, call->user_call_ID); _enter("{%d,%lx},", call->debug_id, call->user_call_ID);
call->peer = rxrpc_lookup_peer(local, &call->dest_srx, gfp);
if (!call->peer)
goto error;
ret = rxrpc_look_up_bundle(call, gfp); ret = rxrpc_look_up_bundle(call, gfp);
if (ret < 0) if (ret < 0)
goto error; goto error;
...@@ -334,7 +331,6 @@ static int rxrpc_connect_call(struct rxrpc_call *call, gfp_t gfp) ...@@ -334,7 +331,6 @@ static int rxrpc_connect_call(struct rxrpc_call *call, gfp_t gfp)
*/ */
struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx, struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx,
struct rxrpc_conn_parameters *cp, struct rxrpc_conn_parameters *cp,
struct sockaddr_rxrpc *srx,
struct rxrpc_call_params *p, struct rxrpc_call_params *p,
gfp_t gfp, gfp_t gfp,
unsigned int debug_id) unsigned int debug_id)
...@@ -349,13 +345,18 @@ struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx, ...@@ -349,13 +345,18 @@ struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx,
_enter("%p,%lx", rx, p->user_call_ID); _enter("%p,%lx", rx, p->user_call_ID);
if (WARN_ON_ONCE(!cp->peer)) {
release_sock(&rx->sk);
return ERR_PTR(-EIO);
}
limiter = rxrpc_get_call_slot(p, gfp); limiter = rxrpc_get_call_slot(p, gfp);
if (!limiter) { if (!limiter) {
release_sock(&rx->sk); release_sock(&rx->sk);
return ERR_PTR(-ERESTARTSYS); return ERR_PTR(-ERESTARTSYS);
} }
call = rxrpc_alloc_client_call(rx, srx, cp, p, gfp, debug_id); call = rxrpc_alloc_client_call(rx, cp, p, gfp, debug_id);
if (IS_ERR(call)) { if (IS_ERR(call)) {
release_sock(&rx->sk); release_sock(&rx->sk);
up(limiter); up(limiter);
......
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include <net/ip6_route.h> #include <net/ip6_route.h>
#include "ar-internal.h" #include "ar-internal.h"
static const struct sockaddr_rxrpc rxrpc_null_addr;
/* /*
* Hash a peer key. * Hash a peer key.
*/ */
...@@ -457,39 +459,53 @@ void rxrpc_destroy_all_peers(struct rxrpc_net *rxnet) ...@@ -457,39 +459,53 @@ void rxrpc_destroy_all_peers(struct rxrpc_net *rxnet)
} }
/** /**
* rxrpc_kernel_get_peer - Get the peer address of a call * rxrpc_kernel_get_call_peer - Get the peer address of a call
* @sock: The socket on which the call is in progress. * @sock: The socket on which the call is in progress.
* @call: The call to query * @call: The call to query
* @_srx: Where to place the result
* *
* Get the address of the remote peer in a call. * Get a record for the remote peer in a call.
*/ */
void rxrpc_kernel_get_peer(struct socket *sock, struct rxrpc_call *call, struct rxrpc_peer *rxrpc_kernel_get_call_peer(struct socket *sock, struct rxrpc_call *call)
struct sockaddr_rxrpc *_srx)
{ {
*_srx = call->peer->srx; return call->peer;
} }
EXPORT_SYMBOL(rxrpc_kernel_get_peer); EXPORT_SYMBOL(rxrpc_kernel_get_call_peer);
/** /**
* rxrpc_kernel_get_srtt - Get a call's peer smoothed RTT * rxrpc_kernel_get_srtt - Get a call's peer smoothed RTT
* @sock: The socket on which the call is in progress. * @peer: The peer to query
* @call: The call to query
* @_srtt: Where to store the SRTT value.
* *
* Get the call's peer smoothed RTT in uS. * Get the call's peer smoothed RTT in uS or UINT_MAX if we have no samples.
*/ */
bool rxrpc_kernel_get_srtt(struct socket *sock, struct rxrpc_call *call, unsigned int rxrpc_kernel_get_srtt(const struct rxrpc_peer *peer)
u32 *_srtt)
{ {
struct rxrpc_peer *peer = call->peer; return peer->rtt_count > 0 ? peer->srtt_us >> 3 : UINT_MAX;
}
EXPORT_SYMBOL(rxrpc_kernel_get_srtt);
if (peer->rtt_count == 0) { /**
*_srtt = 1000000; /* 1S */ * rxrpc_kernel_remote_srx - Get the address of a peer
return false; * @peer: The peer to query
} *
* Get a pointer to the address from a peer record. The caller is responsible
* for making sure that the address is not deallocated.
*/
const struct sockaddr_rxrpc *rxrpc_kernel_remote_srx(const struct rxrpc_peer *peer)
{
return peer ? &peer->srx : &rxrpc_null_addr;
}
EXPORT_SYMBOL(rxrpc_kernel_remote_srx);
*_srtt = call->peer->srtt_us >> 3; /**
return true; * rxrpc_kernel_remote_addr - Get the peer transport address of a call
* @peer: The peer to query
*
* Get a pointer to the transport address from a peer record. The caller is
* responsible for making sure that the address is not deallocated.
*/
const struct sockaddr *rxrpc_kernel_remote_addr(const struct rxrpc_peer *peer)
{
return (const struct sockaddr *)
(peer ? &peer->srx.transport : &rxrpc_null_addr.transport);
} }
EXPORT_SYMBOL(rxrpc_kernel_get_srtt); EXPORT_SYMBOL(rxrpc_kernel_remote_addr);
...@@ -572,6 +572,7 @@ rxrpc_new_client_call_for_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg, ...@@ -572,6 +572,7 @@ rxrpc_new_client_call_for_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg,
__acquires(&call->user_mutex) __acquires(&call->user_mutex)
{ {
struct rxrpc_conn_parameters cp; struct rxrpc_conn_parameters cp;
struct rxrpc_peer *peer;
struct rxrpc_call *call; struct rxrpc_call *call;
struct key *key; struct key *key;
...@@ -584,21 +585,29 @@ rxrpc_new_client_call_for_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg, ...@@ -584,21 +585,29 @@ rxrpc_new_client_call_for_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg,
return ERR_PTR(-EDESTADDRREQ); return ERR_PTR(-EDESTADDRREQ);
} }
peer = rxrpc_lookup_peer(rx->local, srx, GFP_KERNEL);
if (!peer) {
release_sock(&rx->sk);
return ERR_PTR(-ENOMEM);
}
key = rx->key; key = rx->key;
if (key && !rx->key->payload.data[0]) if (key && !rx->key->payload.data[0])
key = NULL; key = NULL;
memset(&cp, 0, sizeof(cp)); memset(&cp, 0, sizeof(cp));
cp.local = rx->local; cp.local = rx->local;
cp.peer = peer;
cp.key = rx->key; cp.key = rx->key;
cp.security_level = rx->min_sec_level; cp.security_level = rx->min_sec_level;
cp.exclusive = rx->exclusive | p->exclusive; cp.exclusive = rx->exclusive | p->exclusive;
cp.upgrade = p->upgrade; cp.upgrade = p->upgrade;
cp.service_id = srx->srx_service; cp.service_id = srx->srx_service;
call = rxrpc_new_client_call(rx, &cp, srx, &p->call, GFP_KERNEL, call = rxrpc_new_client_call(rx, &cp, &p->call, GFP_KERNEL,
atomic_inc_return(&rxrpc_debug_id)); atomic_inc_return(&rxrpc_debug_id));
/* The socket is now unlocked */ /* The socket is now unlocked */
rxrpc_put_peer(peer, rxrpc_peer_put_application);
_leave(" = %p\n", call); _leave(" = %p\n", call);
return call; return call;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment