diff --git a/net/sched/cls_flower.c b/net/sched/cls_flower.c index e98313cd710a..92478bb122d3 100644 --- a/net/sched/cls_flower.c +++ b/net/sched/cls_flower.c @@ -1304,11 +1304,14 @@ static struct fl_flow_mask *fl_create_new_mask(struct cls_fl_head *head, INIT_LIST_HEAD_RCU(&newmask->filters); refcount_set(&newmask->refcnt, 1); - err = rhashtable_insert_fast(&head->ht, &newmask->ht_node, - mask_ht_params); + err = rhashtable_replace_fast(&head->ht, &mask->ht_node, + &newmask->ht_node, mask_ht_params); if (err) goto errout_destroy; + /* Wait until any potential concurrent users of mask are finished */ + synchronize_rcu(); + list_add_tail_rcu(&newmask->list, &head->masks); return newmask; @@ -1330,19 +1333,36 @@ static int fl_check_assign_mask(struct cls_fl_head *head, int ret = 0; rcu_read_lock(); - fnew->mask = rhashtable_lookup_fast(&head->ht, mask, mask_ht_params); + + /* Insert mask as temporary node to prevent concurrent creation of mask + * with same key. Any concurrent lookups with same key will return + * -EAGAIN because mask's refcnt is zero. It is safe to insert + * stack-allocated 'mask' to masks hash table because we call + * synchronize_rcu() before returning from this function (either in case + * of error or after replacing it with heap-allocated mask in + * fl_create_new_mask()). + */ + fnew->mask = rhashtable_lookup_get_insert_fast(&head->ht, + &mask->ht_node, + mask_ht_params); if (!fnew->mask) { rcu_read_unlock(); - if (fold) - return -EINVAL; + if (fold) { + ret = -EINVAL; + goto errout_cleanup; + } newmask = fl_create_new_mask(head, mask); - if (IS_ERR(newmask)) - return PTR_ERR(newmask); + if (IS_ERR(newmask)) { + ret = PTR_ERR(newmask); + goto errout_cleanup; + } fnew->mask = newmask; return 0; + } else if (IS_ERR(fnew->mask)) { + ret = PTR_ERR(fnew->mask); } else if (fold && fold->mask != fnew->mask) { ret = -EINVAL; } else if (!refcount_inc_not_zero(&fnew->mask->refcnt)) { @@ -1351,6 +1371,13 @@ static int fl_check_assign_mask(struct cls_fl_head *head, } rcu_read_unlock(); return ret; + +errout_cleanup: + rhashtable_remove_fast(&head->ht, &mask->ht_node, + mask_ht_params); + /* Wait until any potential concurrent users of mask are finished */ + synchronize_rcu(); + return ret; } static int fl_set_parms(struct net *net, struct tcf_proto *tp,