Use atomic_t for ucounts reference counting

The current implementation of the ucounts reference counter requires the
use of spin_lock. We're going to use get_ucounts() in more performance
critical areas like a handling of RLIMIT_SIGPENDING.

Now we need to use spin_lock only if we want to change the hashtable.

v10:
* Always try to put ucounts in case we cannot increase ucounts->count.
  This will allow to cover the case when all consumers will return
  ucounts at once.

v9:
* Use a negative value to check that the ucounts->count is close to
  overflow.

Signed-off-by: Alexey Gladkov <legion@kernel.org>
Link: https://lkml.kernel.org/r/94d1dbecab060a6b116b0a2d1accd8ca1bbb4f5f.1619094428.git.legion@kernel.org
Signed-off-by: Eric W. Biederman <ebiederm@xmission.com>
This commit is contained in:
Alexey Gladkov 2021-04-22 14:27:10 +02:00 committed by Eric W. Biederman
parent 905ae01c4a
commit b6c3365289
2 changed files with 21 additions and 36 deletions

View File

@ -95,7 +95,7 @@ struct ucounts {
struct hlist_node node; struct hlist_node node;
struct user_namespace *ns; struct user_namespace *ns;
kuid_t uid; kuid_t uid;
int count; atomic_t count;
atomic_long_t ucount[UCOUNT_COUNTS]; atomic_long_t ucount[UCOUNT_COUNTS];
}; };
@ -107,7 +107,7 @@ void retire_userns_sysctls(struct user_namespace *ns);
struct ucounts *inc_ucount(struct user_namespace *ns, kuid_t uid, enum ucount_type type); struct ucounts *inc_ucount(struct user_namespace *ns, kuid_t uid, enum ucount_type type);
void dec_ucount(struct ucounts *ucounts, enum ucount_type type); void dec_ucount(struct ucounts *ucounts, enum ucount_type type);
struct ucounts *alloc_ucounts(struct user_namespace *ns, kuid_t uid); struct ucounts *alloc_ucounts(struct user_namespace *ns, kuid_t uid);
struct ucounts *get_ucounts(struct ucounts *ucounts); struct ucounts * __must_check get_ucounts(struct ucounts *ucounts);
void put_ucounts(struct ucounts *ucounts); void put_ucounts(struct ucounts *ucounts);
#ifdef CONFIG_USER_NS #ifdef CONFIG_USER_NS

View File

@ -11,7 +11,7 @@
struct ucounts init_ucounts = { struct ucounts init_ucounts = {
.ns = &init_user_ns, .ns = &init_user_ns,
.uid = GLOBAL_ROOT_UID, .uid = GLOBAL_ROOT_UID,
.count = 1, .count = ATOMIC_INIT(1),
}; };
#define UCOUNTS_HASHTABLE_BITS 10 #define UCOUNTS_HASHTABLE_BITS 10
@ -139,6 +139,15 @@ static void hlist_add_ucounts(struct ucounts *ucounts)
spin_unlock_irq(&ucounts_lock); spin_unlock_irq(&ucounts_lock);
} }
struct ucounts *get_ucounts(struct ucounts *ucounts)
{
if (ucounts && atomic_add_negative(1, &ucounts->count)) {
put_ucounts(ucounts);
ucounts = NULL;
}
return ucounts;
}
struct ucounts *alloc_ucounts(struct user_namespace *ns, kuid_t uid) struct ucounts *alloc_ucounts(struct user_namespace *ns, kuid_t uid)
{ {
struct hlist_head *hashent = ucounts_hashentry(ns, uid); struct hlist_head *hashent = ucounts_hashentry(ns, uid);
@ -155,7 +164,7 @@ struct ucounts *alloc_ucounts(struct user_namespace *ns, kuid_t uid)
new->ns = ns; new->ns = ns;
new->uid = uid; new->uid = uid;
new->count = 0; atomic_set(&new->count, 1);
spin_lock_irq(&ucounts_lock); spin_lock_irq(&ucounts_lock);
ucounts = find_ucounts(ns, uid, hashent); ucounts = find_ucounts(ns, uid, hashent);
@ -163,33 +172,12 @@ struct ucounts *alloc_ucounts(struct user_namespace *ns, kuid_t uid)
kfree(new); kfree(new);
} else { } else {
hlist_add_head(&new->node, hashent); hlist_add_head(&new->node, hashent);
ucounts = new; spin_unlock_irq(&ucounts_lock);
return new;
} }
} }
if (ucounts->count == INT_MAX)
ucounts = NULL;
else
ucounts->count += 1;
spin_unlock_irq(&ucounts_lock); spin_unlock_irq(&ucounts_lock);
return ucounts; ucounts = get_ucounts(ucounts);
}
struct ucounts *get_ucounts(struct ucounts *ucounts)
{
unsigned long flags;
if (!ucounts)
return NULL;
spin_lock_irqsave(&ucounts_lock, flags);
if (ucounts->count == INT_MAX) {
WARN_ONCE(1, "ucounts: counter has reached its maximum value");
ucounts = NULL;
} else {
ucounts->count += 1;
}
spin_unlock_irqrestore(&ucounts_lock, flags);
return ucounts; return ucounts;
} }
@ -197,15 +185,12 @@ void put_ucounts(struct ucounts *ucounts)
{ {
unsigned long flags; unsigned long flags;
spin_lock_irqsave(&ucounts_lock, flags); if (atomic_dec_and_test(&ucounts->count)) {
ucounts->count -= 1; spin_lock_irqsave(&ucounts_lock, flags);
if (!ucounts->count)
hlist_del_init(&ucounts->node); hlist_del_init(&ucounts->node);
else spin_unlock_irqrestore(&ucounts_lock, flags);
ucounts = NULL; kfree(ucounts);
spin_unlock_irqrestore(&ucounts_lock, flags); }
kfree(ucounts);
} }
static inline bool atomic_long_inc_below(atomic_long_t *v, int u) static inline bool atomic_long_inc_below(atomic_long_t *v, int u)