Skip to content

Commit

Permalink
xfrm: Add support for per cpu xfrm state handling.
Browse files Browse the repository at this point in the history
Currently all flows for a certain SA must be processed by the same
cpu to avoid packet reordering and lock contention of the xfrm
state lock.

To get rid of this limitation, the IETF standardized per cpu SAs
in RFC 9611. This patch implements the xfrm part of it.

We add the cpu as a lookup key for xfrm states and a config option
to generate acquire messages for each cpu.

With that, we can have on each cpu a SA with identical traffic selector
so that flows can be processed in parallel on all cpus.

Signed-off-by: Steffen Klassert <[email protected]>
Tested-by: Antony Antony <[email protected]>
Tested-by: Tobias Brunner <[email protected]>
  • Loading branch information
klassert committed Oct 29, 2024
1 parent ab101c5 commit 1ddf991
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 22 deletions.
5 changes: 3 additions & 2 deletions include/net/xfrm.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ struct xfrm_state {
refcount_t refcnt;
spinlock_t lock;

u32 pcpu_num;
struct xfrm_id id;
struct xfrm_selector sel;
struct xfrm_mark mark;
Expand Down Expand Up @@ -1684,7 +1685,7 @@ struct xfrmk_spdinfo {
u32 spdhmcnt;
};

struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq);
struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num);
int xfrm_state_delete(struct xfrm_state *x);
int xfrm_state_flush(struct net *net, u8 proto, bool task_valid, bool sync);
int xfrm_dev_state_flush(struct net *net, struct net_device *dev, bool task_valid);
Expand Down Expand Up @@ -1796,7 +1797,7 @@ int verify_spi_info(u8 proto, u32 min, u32 max, struct netlink_ext_ack *extack);
int xfrm_alloc_spi(struct xfrm_state *x, u32 minspi, u32 maxspi,
struct netlink_ext_ack *extack);
struct xfrm_state *xfrm_find_acq(struct net *net, const struct xfrm_mark *mark,
u8 mode, u32 reqid, u32 if_id, u8 proto,
u8 mode, u32 reqid, u32 if_id, u32 pcpu_num, u8 proto,
const xfrm_address_t *daddr,
const xfrm_address_t *saddr, int create,
unsigned short family);
Expand Down
2 changes: 2 additions & 0 deletions include/uapi/linux/xfrm.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ enum xfrm_attr_type_t {
XFRMA_MTIMER_THRESH, /* __u32 in seconds for input SA */
XFRMA_SA_DIR, /* __u8 */
XFRMA_NAT_KEEPALIVE_INTERVAL, /* __u32 in seconds for NAT keepalive */
XFRMA_SA_PCPU, /* __u32 */
__XFRMA_MAX

#define XFRMA_OUTPUT_MARK XFRMA_SET_MARK /* Compatibility */
Expand Down Expand Up @@ -437,6 +438,7 @@ struct xfrm_userpolicy_info {
#define XFRM_POLICY_LOCALOK 1 /* Allow user to override global policy */
/* Automatically expand selector to include matching ICMP payloads. */
#define XFRM_POLICY_ICMP 2
#define XFRM_POLICY_CPU_ACQUIRE 4
__u8 share;
};

Expand Down
7 changes: 4 additions & 3 deletions net/key/af_key.c
Original file line number Diff line number Diff line change
Expand Up @@ -1354,15 +1354,16 @@ static int pfkey_getspi(struct sock *sk, struct sk_buff *skb, const struct sadb_
}

if (hdr->sadb_msg_seq) {
x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq);
x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq, UINT_MAX);
if (x && !xfrm_addr_equal(&x->id.daddr, xdaddr, family)) {
xfrm_state_put(x);
x = NULL;
}
}

if (!x)
x = xfrm_find_acq(net, &dummy_mark, mode, reqid, 0, proto, xdaddr, xsaddr, 1, family);
x = xfrm_find_acq(net, &dummy_mark, mode, reqid, 0, UINT_MAX,
proto, xdaddr, xsaddr, 1, family);

if (x == NULL)
return -ENOENT;
Expand Down Expand Up @@ -1417,7 +1418,7 @@ static int pfkey_acquire(struct sock *sk, struct sk_buff *skb, const struct sadb
if (hdr->sadb_msg_seq == 0 || hdr->sadb_msg_errno == 0)
return 0;

x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq);
x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq, UINT_MAX);
if (x == NULL)
return 0;

Expand Down
6 changes: 4 additions & 2 deletions net/xfrm/xfrm_compat.c
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ static const struct nla_policy compat_policy[XFRMA_MAX+1] = {
[XFRMA_MTIMER_THRESH] = { .type = NLA_U32 },
[XFRMA_SA_DIR] = NLA_POLICY_RANGE(NLA_U8, XFRM_SA_DIR_IN, XFRM_SA_DIR_OUT),
[XFRMA_NAT_KEEPALIVE_INTERVAL] = { .type = NLA_U32 },
[XFRMA_SA_PCPU] = { .type = NLA_U32 },
};

static struct nlmsghdr *xfrm_nlmsg_put_compat(struct sk_buff *skb,
Expand Down Expand Up @@ -282,9 +283,10 @@ static int xfrm_xlate64_attr(struct sk_buff *dst, const struct nlattr *src)
case XFRMA_MTIMER_THRESH:
case XFRMA_SA_DIR:
case XFRMA_NAT_KEEPALIVE_INTERVAL:
case XFRMA_SA_PCPU:
return xfrm_nla_cpy(dst, src, nla_len(src));
default:
BUILD_BUG_ON(XFRMA_MAX != XFRMA_NAT_KEEPALIVE_INTERVAL);
BUILD_BUG_ON(XFRMA_MAX != XFRMA_SA_PCPU);
pr_warn_once("unsupported nla_type %d\n", src->nla_type);
return -EOPNOTSUPP;
}
Expand Down Expand Up @@ -439,7 +441,7 @@ static int xfrm_xlate32_attr(void *dst, const struct nlattr *nla,
int err;

if (type > XFRMA_MAX) {
BUILD_BUG_ON(XFRMA_MAX != XFRMA_NAT_KEEPALIVE_INTERVAL);
BUILD_BUG_ON(XFRMA_MAX != XFRMA_SA_PCPU);
NL_SET_ERR_MSG(extack, "Bad attribute");
return -EOPNOTSUPP;
}
Expand Down
58 changes: 47 additions & 11 deletions net/xfrm/xfrm_state.c
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ struct xfrm_state *xfrm_state_alloc(struct net *net)
x->lft.hard_packet_limit = XFRM_INF;
x->replay_maxage = 0;
x->replay_maxdiff = 0;
x->pcpu_num = UINT_MAX;
spin_lock_init(&x->lock);
}
return x;
Expand Down Expand Up @@ -1155,6 +1156,12 @@ static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x,
struct xfrm_state **best, int *acq_in_progress,
int *error)
{
/* We need the cpu id just as a lookup key,
* we don't require it to be stable.
*/
unsigned int pcpu_id = get_cpu();
put_cpu();

/* Resolution logic:
* 1. There is a valid state with matching selector. Done.
* 2. Valid state with inappropriate selector. Skip.
Expand All @@ -1174,13 +1181,18 @@ static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x,
&fl->u.__fl_common))
return;

if (x->pcpu_num != UINT_MAX && x->pcpu_num != pcpu_id)
return;

if (!*best ||
((*best)->pcpu_num == UINT_MAX && x->pcpu_num == pcpu_id) ||
(*best)->km.dying > x->km.dying ||
((*best)->km.dying == x->km.dying &&
(*best)->curlft.add_time < x->curlft.add_time))
*best = x;
} else if (x->km.state == XFRM_STATE_ACQ) {
*acq_in_progress = 1;
if (!*best || x->pcpu_num == pcpu_id)
*acq_in_progress = 1;
} else if (x->km.state == XFRM_STATE_ERROR ||
x->km.state == XFRM_STATE_EXPIRED) {
if ((!x->sel.family ||
Expand Down Expand Up @@ -1209,6 +1221,13 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
unsigned short encap_family = tmpl->encap_family;
unsigned int sequence;
struct km_event c;
unsigned int pcpu_id;

/* We need the cpu id just as a lookup key,
* we don't require it to be stable.
*/
pcpu_id = get_cpu();
put_cpu();

to_put = NULL;

Expand Down Expand Up @@ -1282,7 +1301,10 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
}

found:
x = best;
if (!(pol->flags & XFRM_POLICY_CPU_ACQUIRE) ||
(best && (best->pcpu_num == pcpu_id)))
x = best;

if (!x && !error && !acquire_in_progress) {
if (tmpl->id.spi &&
(x0 = __xfrm_state_lookup_all(net, mark, daddr,
Expand Down Expand Up @@ -1314,6 +1336,8 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
xfrm_init_tempstate(x, fl, tmpl, daddr, saddr, family);
memcpy(&x->mark, &pol->mark, sizeof(x->mark));
x->if_id = if_id;
if ((pol->flags & XFRM_POLICY_CPU_ACQUIRE) && best)
x->pcpu_num = pcpu_id;

error = security_xfrm_state_alloc_acquire(x, pol->security, fl->flowi_secid);
if (error) {
Expand Down Expand Up @@ -1392,6 +1416,11 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
x = NULL;
error = -ESRCH;
}

/* Use the already installed 'fallback' while the CPU-specific
* SA acquire is handled*/
if (best)
x = best;
}
out:
if (x) {
Expand Down Expand Up @@ -1524,12 +1553,14 @@ static void __xfrm_state_bump_genids(struct xfrm_state *xnew)
unsigned int h;
u32 mark = xnew->mark.v & xnew->mark.m;
u32 if_id = xnew->if_id;
u32 cpu_id = xnew->pcpu_num;

h = xfrm_dst_hash(net, &xnew->id.daddr, &xnew->props.saddr, reqid, family);
hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
if (x->props.family == family &&
x->props.reqid == reqid &&
x->if_id == if_id &&
x->pcpu_num == cpu_id &&
(mark & x->mark.m) == x->mark.v &&
xfrm_addr_equal(&x->id.daddr, &xnew->id.daddr, family) &&
xfrm_addr_equal(&x->props.saddr, &xnew->props.saddr, family))
Expand All @@ -1552,7 +1583,7 @@ EXPORT_SYMBOL(xfrm_state_insert);
static struct xfrm_state *__find_acq_core(struct net *net,
const struct xfrm_mark *m,
unsigned short family, u8 mode,
u32 reqid, u32 if_id, u8 proto,
u32 reqid, u32 if_id, u32 pcpu_num, u8 proto,
const xfrm_address_t *daddr,
const xfrm_address_t *saddr,
int create)
Expand All @@ -1569,6 +1600,7 @@ static struct xfrm_state *__find_acq_core(struct net *net,
x->id.spi != 0 ||
x->id.proto != proto ||
(mark & x->mark.m) != x->mark.v ||
x->pcpu_num != pcpu_num ||
!xfrm_addr_equal(&x->id.daddr, daddr, family) ||
!xfrm_addr_equal(&x->props.saddr, saddr, family))
continue;
Expand Down Expand Up @@ -1602,6 +1634,7 @@ static struct xfrm_state *__find_acq_core(struct net *net,
break;
}

x->pcpu_num = pcpu_num;
x->km.state = XFRM_STATE_ACQ;
x->id.proto = proto;
x->props.family = family;
Expand Down Expand Up @@ -1630,7 +1663,7 @@ static struct xfrm_state *__find_acq_core(struct net *net,
return x;
}

static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq);
static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num);

int xfrm_state_add(struct xfrm_state *x)
{
Expand All @@ -1656,7 +1689,7 @@ int xfrm_state_add(struct xfrm_state *x)
}

if (use_spi && x->km.seq) {
x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq);
x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq, x->pcpu_num);
if (x1 && ((x1->id.proto != x->id.proto) ||
!xfrm_addr_equal(&x1->id.daddr, &x->id.daddr, family))) {
to_put = x1;
Expand All @@ -1666,7 +1699,7 @@ int xfrm_state_add(struct xfrm_state *x)

if (use_spi && !x1)
x1 = __find_acq_core(net, &x->mark, family, x->props.mode,
x->props.reqid, x->if_id, x->id.proto,
x->props.reqid, x->if_id, x->pcpu_num, x->id.proto,
&x->id.daddr, &x->props.saddr, 0);

__xfrm_state_bump_genids(x);
Expand Down Expand Up @@ -1791,6 +1824,7 @@ static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig,
x->props.flags = orig->props.flags;
x->props.extra_flags = orig->props.extra_flags;

x->pcpu_num = orig->pcpu_num;
x->if_id = orig->if_id;
x->tfcpad = orig->tfcpad;
x->replay_maxdiff = orig->replay_maxdiff;
Expand Down Expand Up @@ -2066,13 +2100,14 @@ EXPORT_SYMBOL(xfrm_state_lookup_byaddr);

struct xfrm_state *
xfrm_find_acq(struct net *net, const struct xfrm_mark *mark, u8 mode, u32 reqid,
u32 if_id, u8 proto, const xfrm_address_t *daddr,
u32 if_id, u32 pcpu_num, u8 proto, const xfrm_address_t *daddr,
const xfrm_address_t *saddr, int create, unsigned short family)
{
struct xfrm_state *x;

spin_lock_bh(&net->xfrm.xfrm_state_lock);
x = __find_acq_core(net, mark, family, mode, reqid, if_id, proto, daddr, saddr, create);
x = __find_acq_core(net, mark, family, mode, reqid, if_id, pcpu_num,
proto, daddr, saddr, create);
spin_unlock_bh(&net->xfrm.xfrm_state_lock);

return x;
Expand Down Expand Up @@ -2207,14 +2242,15 @@ xfrm_state_sort(struct xfrm_state **dst, struct xfrm_state **src, int n,

/* Silly enough, but I'm lazy to build resolution list */

static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq)
static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num)
{
unsigned int h = xfrm_seq_hash(net, seq);
struct xfrm_state *x;

hlist_for_each_entry_rcu(x, net->xfrm.state_byseq + h, byseq) {
if (x->km.seq == seq &&
(mark & x->mark.m) == x->mark.v &&
x->pcpu_num == pcpu_num &&
x->km.state == XFRM_STATE_ACQ) {
xfrm_state_hold(x);
return x;
Expand All @@ -2224,12 +2260,12 @@ static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 s
return NULL;
}

struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq)
struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num)
{
struct xfrm_state *x;

spin_lock_bh(&net->xfrm.xfrm_state_lock);
x = __xfrm_find_acq_byseq(net, mark, seq);
x = __xfrm_find_acq_byseq(net, mark, seq, pcpu_num);
spin_unlock_bh(&net->xfrm.xfrm_state_lock);
return x;
}
Expand Down
Loading

0 comments on commit 1ddf991

Please sign in to comment.