Commit adff6c65 authored by Florian Westphal's avatar Florian Westphal Committed by Pablo Neira Ayuso

netfilter: connlabels: change nf_connlabels_get bit arg to 'highest used'

nf_connlabel_set() takes the bit number that we would like to set.
nf_connlabels_get() however took the number of bits that we want to
support.

So e.g. nf_connlabels_get(32) support bits 0 to 31, but not 32.
This changes nf_connlabels_get() to take the highest bit that we want
to set.

Callers then don't have to cope with a potential integer wrap
when using nf_connlabels_get(bit + 1) anymore.

Current callers are fine, this change is only to make folloup
nft ct label set support simpler.
Signed-off-by: default avatarFlorian Westphal <fw@strlen.de>
Signed-off-by: default avatarPablo Neira Ayuso <pablo@netfilter.org>
parent 5a8145f7
...@@ -53,11 +53,11 @@ int nf_connlabels_replace(struct nf_conn *ct, ...@@ -53,11 +53,11 @@ int nf_connlabels_replace(struct nf_conn *ct,
#ifdef CONFIG_NF_CONNTRACK_LABELS #ifdef CONFIG_NF_CONNTRACK_LABELS
int nf_conntrack_labels_init(void); int nf_conntrack_labels_init(void);
void nf_conntrack_labels_fini(void); void nf_conntrack_labels_fini(void);
int nf_connlabels_get(struct net *net, unsigned int n_bits); int nf_connlabels_get(struct net *net, unsigned int bit);
void nf_connlabels_put(struct net *net); void nf_connlabels_put(struct net *net);
#else #else
static inline int nf_conntrack_labels_init(void) { return 0; } static inline int nf_conntrack_labels_init(void) { return 0; }
static inline void nf_conntrack_labels_fini(void) {} static inline void nf_conntrack_labels_fini(void) {}
static inline int nf_connlabels_get(struct net *net, unsigned int n_bits) { return 0; } static inline int nf_connlabels_get(struct net *net, unsigned int bit) { return 0; }
static inline void nf_connlabels_put(struct net *net) {} static inline void nf_connlabels_put(struct net *net) {}
#endif #endif
...@@ -78,15 +78,14 @@ int nf_connlabels_replace(struct nf_conn *ct, ...@@ -78,15 +78,14 @@ int nf_connlabels_replace(struct nf_conn *ct,
} }
EXPORT_SYMBOL_GPL(nf_connlabels_replace); EXPORT_SYMBOL_GPL(nf_connlabels_replace);
int nf_connlabels_get(struct net *net, unsigned int n_bits) int nf_connlabels_get(struct net *net, unsigned int bits)
{ {
size_t words; size_t words;
if (n_bits > (NF_CT_LABELS_MAX_SIZE * BITS_PER_BYTE)) words = BIT_WORD(bits) + 1;
if (words > NF_CT_LABELS_MAX_SIZE / sizeof(long))
return -ERANGE; return -ERANGE;
words = BITS_TO_LONGS(n_bits);
spin_lock(&nf_connlabels_lock); spin_lock(&nf_connlabels_lock);
net->ct.labels_used++; net->ct.labels_used++;
if (words > net->ct.label_words) if (words > net->ct.label_words)
...@@ -115,6 +114,8 @@ static struct nf_ct_ext_type labels_extend __read_mostly = { ...@@ -115,6 +114,8 @@ static struct nf_ct_ext_type labels_extend __read_mostly = {
int nf_conntrack_labels_init(void) int nf_conntrack_labels_init(void)
{ {
BUILD_BUG_ON(NF_CT_LABELS_MAX_SIZE / sizeof(long) >= U8_MAX);
spin_lock_init(&nf_connlabels_lock); spin_lock_init(&nf_connlabels_lock);
return nf_ct_extend_register(&labels_extend); return nf_ct_extend_register(&labels_extend);
} }
......
...@@ -484,6 +484,8 @@ static struct nft_expr_type nft_ct_type __read_mostly = { ...@@ -484,6 +484,8 @@ static struct nft_expr_type nft_ct_type __read_mostly = {
static int __init nft_ct_module_init(void) static int __init nft_ct_module_init(void)
{ {
BUILD_BUG_ON(NF_CT_LABELS_MAX_SIZE > NFT_REG_SIZE);
return nft_register_expr(&nft_ct_type); return nft_register_expr(&nft_ct_type);
} }
......
...@@ -65,7 +65,7 @@ static int connlabel_mt_check(const struct xt_mtchk_param *par) ...@@ -65,7 +65,7 @@ static int connlabel_mt_check(const struct xt_mtchk_param *par)
return ret; return ret;
} }
ret = nf_connlabels_get(par->net, info->bit + 1); ret = nf_connlabels_get(par->net, info->bit);
if (ret < 0) if (ret < 0)
nf_ct_l3proto_module_put(par->family); nf_ct_l3proto_module_put(par->family);
return ret; return ret;
......
...@@ -1344,7 +1344,7 @@ void ovs_ct_init(struct net *net) ...@@ -1344,7 +1344,7 @@ void ovs_ct_init(struct net *net)
unsigned int n_bits = sizeof(struct ovs_key_ct_labels) * BITS_PER_BYTE; unsigned int n_bits = sizeof(struct ovs_key_ct_labels) * BITS_PER_BYTE;
struct ovs_net *ovs_net = net_generic(net, ovs_net_id); struct ovs_net *ovs_net = net_generic(net, ovs_net_id);
if (nf_connlabels_get(net, n_bits)) { if (nf_connlabels_get(net, n_bits - 1)) {
ovs_net->xt_label = false; ovs_net->xt_label = false;
OVS_NLERR(true, "Failed to set connlabel length"); OVS_NLERR(true, "Failed to set connlabel length");
} else { } else {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment