diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c index f84c52b6eafc..10daefd7f938 100644 --- a/net/mpls/af_mpls.c +++ b/net/mpls/af_mpls.c @@ -24,7 +24,10 @@ #include #include "internal.h" -#define MAX_NEW_LABELS 2 +/* put a reasonable limit on the number of labels + * we will accept from userspace + */ +#define MAX_NEW_LABELS 30 /* max memory we will use for mpls_route */ #define MAX_MPLS_ROUTE_MEM 4096 @@ -698,9 +701,6 @@ static int mpls_nh_build_from_cfg(struct mpls_route_config *cfg, return -ENOMEM; err = -EINVAL; - /* Ensure only a supported number of labels are present */ - if (cfg->rc_output_labels > MAX_NEW_LABELS) - goto errout; nh->nh_labels = cfg->rc_output_labels; for (i = 0; i < nh->nh_labels; i++) @@ -725,7 +725,7 @@ errout: static int mpls_nh_build(struct net *net, struct mpls_route *rt, struct mpls_nh *nh, int oif, struct nlattr *via, - struct nlattr *newdst) + struct nlattr *newdst, u8 max_labels) { int err = -ENOMEM; @@ -733,7 +733,7 @@ static int mpls_nh_build(struct net *net, struct mpls_route *rt, goto errout; if (newdst) { - err = nla_get_labels(newdst, MAX_NEW_LABELS, + err = nla_get_labels(newdst, max_labels, &nh->nh_labels, nh->nh_label); if (err) goto errout; @@ -759,21 +759,19 @@ errout: } static u8 mpls_count_nexthops(struct rtnexthop *rtnh, int len, - u8 cfg_via_alen, u8 *max_via_alen) + u8 cfg_via_alen, u8 *max_via_alen, + u8 *max_labels) { int remaining = len; u8 nhs = 0; - if (!rtnh) { - *max_via_alen = cfg_via_alen; - return 1; - } - *max_via_alen = 0; + *max_labels = 0; while (rtnh_ok(rtnh, remaining)) { struct nlattr *nla, *attrs = rtnh_attrs(rtnh); int attrlen; + u8 n_labels = 0; attrlen = rtnh_attrlen(rtnh); nla = nla_find(attrs, attrlen, RTA_VIA); @@ -787,6 +785,13 @@ static u8 mpls_count_nexthops(struct rtnexthop *rtnh, int len, via_alen); } + nla = nla_find(attrs, attrlen, RTA_NEWDST); + if (nla && + nla_get_labels(nla, MAX_NEW_LABELS, &n_labels, NULL) != 0) + return 0; + + *max_labels = max_t(u8, *max_labels, n_labels); + /* number of nexthops is tracked by a u8. * Check for overflow. */ @@ -802,7 +807,7 @@ static u8 mpls_count_nexthops(struct rtnexthop *rtnh, int len, } static int mpls_nh_build_multi(struct mpls_route_config *cfg, - struct mpls_route *rt) + struct mpls_route *rt, u8 max_labels) { struct rtnexthop *rtnh = cfg->rc_mp; struct nlattr *nla_via, *nla_newdst; @@ -835,7 +840,8 @@ static int mpls_nh_build_multi(struct mpls_route_config *cfg, } err = mpls_nh_build(cfg->rc_nlinfo.nl_net, rt, nh, - rtnh->rtnh_ifindex, nla_via, nla_newdst); + rtnh->rtnh_ifindex, nla_via, nla_newdst, + max_labels); if (err) goto errout; @@ -862,6 +868,7 @@ static int mpls_route_add(struct mpls_route_config *cfg) int err = -EINVAL; u8 max_via_alen; unsigned index; + u8 max_labels; u8 nhs; index = cfg->rc_label; @@ -900,13 +907,21 @@ static int mpls_route_add(struct mpls_route_config *cfg) goto errout; err = -EINVAL; - nhs = mpls_count_nexthops(cfg->rc_mp, cfg->rc_mp_len, - cfg->rc_via_alen, &max_via_alen); + if (cfg->rc_mp) { + nhs = mpls_count_nexthops(cfg->rc_mp, cfg->rc_mp_len, + cfg->rc_via_alen, &max_via_alen, + &max_labels); + } else { + max_via_alen = cfg->rc_via_alen; + max_labels = cfg->rc_output_labels; + nhs = 1; + } + if (nhs == 0) goto errout; err = -ENOMEM; - rt = mpls_rt_alloc(nhs, max_via_alen, MAX_NEW_LABELS); + rt = mpls_rt_alloc(nhs, max_via_alen, max_labels); if (IS_ERR(rt)) { err = PTR_ERR(rt); goto errout; @@ -917,7 +932,7 @@ static int mpls_route_add(struct mpls_route_config *cfg) rt->rt_ttl_propagate = cfg->rc_ttl_propagate; if (cfg->rc_mp) - err = mpls_nh_build_multi(cfg, rt); + err = mpls_nh_build_multi(cfg, rt, max_labels); else err = mpls_nh_build_from_cfg(cfg, rt); if (err) @@ -1531,16 +1546,18 @@ int nla_put_labels(struct sk_buff *skb, int attrtype, EXPORT_SYMBOL_GPL(nla_put_labels); int nla_get_labels(const struct nlattr *nla, - u32 max_labels, u8 *labels, u32 label[]) + u8 max_labels, u8 *labels, u32 label[]) { unsigned len = nla_len(nla); - unsigned nla_labels; struct mpls_shim_hdr *nla_label; + u8 nla_labels; bool bos; int i; - /* len needs to be an even multiple of 4 (the label size) */ - if (len & 3) + /* len needs to be an even multiple of 4 (the label size). Number + * of labels is a u8 so check for overflow. + */ + if (len & 3 || len / 4 > 255) return -EINVAL; /* Limit the number of new labels allowed */ @@ -1548,6 +1565,10 @@ int nla_get_labels(const struct nlattr *nla, if (nla_labels > max_labels) return -EINVAL; + /* when label == NULL, caller wants number of labels */ + if (!label) + goto out; + nla_label = nla_data(nla); bos = true; for (i = nla_labels - 1; i >= 0; i--, bos = false) { @@ -1571,6 +1592,7 @@ int nla_get_labels(const struct nlattr *nla, label[i] = dec.label; } +out: *labels = nla_labels; return 0; } @@ -1632,7 +1654,6 @@ static int rtm_to_route_config(struct sk_buff *skb, struct nlmsghdr *nlh, err = -EINVAL; rtm = nlmsg_data(nlh); - memset(cfg, 0, sizeof(*cfg)); if (rtm->rtm_family != AF_MPLS) goto errout; @@ -1731,27 +1752,43 @@ errout: static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh) { - struct mpls_route_config cfg; + struct mpls_route_config *cfg; int err; - err = rtm_to_route_config(skb, nlh, &cfg); - if (err < 0) - return err; + cfg = kzalloc(sizeof(*cfg), GFP_KERNEL); + if (!cfg) + return -ENOMEM; - return mpls_route_del(&cfg); + err = rtm_to_route_config(skb, nlh, cfg); + if (err < 0) + goto out; + + err = mpls_route_del(cfg); +out: + kfree(cfg); + + return err; } static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh) { - struct mpls_route_config cfg; + struct mpls_route_config *cfg; int err; - err = rtm_to_route_config(skb, nlh, &cfg); - if (err < 0) - return err; + cfg = kzalloc(sizeof(*cfg), GFP_KERNEL); + if (!cfg) + return -ENOMEM; - return mpls_route_add(&cfg); + err = rtm_to_route_config(skb, nlh, cfg); + if (err < 0) + goto out; + + err = mpls_route_add(cfg); +out: + kfree(cfg); + + return err; } static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event, @@ -1980,7 +2017,7 @@ static int resize_platform_label_table(struct net *net, size_t limit) /* In case the predefined labels need to be populated */ if (limit > MPLS_LABEL_IPV4NULL) { struct net_device *lo = net->loopback_dev; - rt0 = mpls_rt_alloc(1, lo->addr_len, MAX_NEW_LABELS); + rt0 = mpls_rt_alloc(1, lo->addr_len, 0); if (IS_ERR(rt0)) goto nort0; RCU_INIT_POINTER(rt0->rt_nh->nh_dev, lo); @@ -1994,7 +2031,7 @@ static int resize_platform_label_table(struct net *net, size_t limit) } if (limit > MPLS_LABEL_IPV6NULL) { struct net_device *lo = net->loopback_dev; - rt2 = mpls_rt_alloc(1, lo->addr_len, MAX_NEW_LABELS); + rt2 = mpls_rt_alloc(1, lo->addr_len, 0); if (IS_ERR(rt2)) goto nort2; RCU_INIT_POINTER(rt2->rt_nh->nh_dev, lo); diff --git a/net/mpls/internal.h b/net/mpls/internal.h index cc324c022049..c5d2f5bc37ec 100644 --- a/net/mpls/internal.h +++ b/net/mpls/internal.h @@ -197,7 +197,7 @@ static inline struct mpls_dev *mpls_dev_get(const struct net_device *dev) int nla_put_labels(struct sk_buff *skb, int attrtype, u8 labels, const u32 label[]); -int nla_get_labels(const struct nlattr *nla, u32 max_labels, u8 *labels, +int nla_get_labels(const struct nlattr *nla, u8 max_labels, u8 *labels, u32 label[]); int nla_get_via(const struct nlattr *nla, u8 *via_alen, u8 *via_table, u8 via[]);