diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c index d214aecdbeef..5ec15bb43e83 100644 --- a/net/sunrpc/auth_gss/auth_gss.c +++ b/net/sunrpc/auth_gss/auth_gss.c @@ -76,6 +76,7 @@ struct gss_pipe { struct rpc_pipe *pipe; struct rpc_clnt *clnt; const char *name; + struct kref kref; }; struct gss_auth { @@ -832,7 +833,6 @@ static struct gss_pipe *gss_pipe_alloc(struct rpc_clnt *clnt, const char *name, const struct rpc_pipe_ops *upcall_ops) { - struct net *net = rpc_net_ns(clnt); struct gss_pipe *p; int err = -ENOMEM; @@ -846,19 +846,71 @@ static struct gss_pipe *gss_pipe_alloc(struct rpc_clnt *clnt, } p->name = name; p->clnt = clnt; + kref_init(&p->kref); rpc_init_pipe_dir_object(&p->pdo, &gss_pipe_dir_object_ops, p); - err = rpc_add_pipe_dir_object(net, &clnt->cl_pipedir_objects, &p->pdo); - if (!err) - return p; - rpc_destroy_pipe_data(p->pipe); + return p; err_free_gss_pipe: kfree(p); err: return ERR_PTR(err); } +struct gss_alloc_pdo { + struct rpc_clnt *clnt; + const char *name; + const struct rpc_pipe_ops *upcall_ops; +}; + +static int gss_pipe_match_pdo(struct rpc_pipe_dir_object *pdo, void *data) +{ + struct gss_pipe *gss_pipe; + struct gss_alloc_pdo *args = data; + + if (pdo->pdo_ops != &gss_pipe_dir_object_ops) + return 0; + gss_pipe = container_of(pdo, struct gss_pipe, pdo); + if (strcmp(gss_pipe->name, args->name) != 0) + return 0; + if (!kref_get_unless_zero(&gss_pipe->kref)) + return 0; + return 1; +} + +static struct rpc_pipe_dir_object *gss_pipe_alloc_pdo(void *data) +{ + struct gss_pipe *gss_pipe; + struct gss_alloc_pdo *args = data; + + gss_pipe = gss_pipe_alloc(args->clnt, args->name, args->upcall_ops); + if (!IS_ERR(gss_pipe)) + return &gss_pipe->pdo; + return NULL; +} + +static struct gss_pipe *gss_pipe_get(struct rpc_clnt *clnt, + const char *name, + const struct rpc_pipe_ops *upcall_ops) +{ + struct net *net = rpc_net_ns(clnt); + struct rpc_pipe_dir_object *pdo; + struct gss_alloc_pdo args = { + .clnt = clnt, + .name = name, + .upcall_ops = upcall_ops, + }; + + pdo = rpc_find_or_alloc_pipe_dir_object(net, + &clnt->cl_pipedir_objects, + gss_pipe_match_pdo, + gss_pipe_alloc_pdo, + &args); + if (pdo != NULL) + return container_of(pdo, struct gss_pipe, pdo); + return ERR_PTR(-ENOMEM); +} + static void __gss_pipe_free(struct gss_pipe *p) { struct rpc_clnt *clnt = p->clnt; @@ -871,10 +923,17 @@ static void __gss_pipe_free(struct gss_pipe *p) kfree(p); } +static void __gss_pipe_release(struct kref *kref) +{ + struct gss_pipe *p = container_of(kref, struct gss_pipe, kref); + + __gss_pipe_free(p); +} + static void gss_pipe_free(struct gss_pipe *p) { if (p != NULL) - __gss_pipe_free(p); + kref_put(&p->kref, __gss_pipe_release); } /* @@ -930,14 +989,14 @@ gss_create(struct rpc_auth_create_args *args, struct rpc_clnt *clnt) * that we supported only the old pipe. So we instead create * the new pipe first. */ - gss_pipe = gss_pipe_alloc(clnt, "gssd", &gss_upcall_ops_v1); + gss_pipe = gss_pipe_get(clnt, "gssd", &gss_upcall_ops_v1); if (IS_ERR(gss_pipe)) { err = PTR_ERR(gss_pipe); goto err_destroy_credcache; } gss_auth->gss_pipe[1] = gss_pipe; - gss_pipe = gss_pipe_alloc(clnt, gss_auth->mech->gm_name, + gss_pipe = gss_pipe_get(clnt, gss_auth->mech->gm_name, &gss_upcall_ops_v0); if (IS_ERR(gss_pipe)) { err = PTR_ERR(gss_pipe); @@ -947,7 +1006,7 @@ gss_create(struct rpc_auth_create_args *args, struct rpc_clnt *clnt) return auth; err_destroy_pipe_1: - __gss_pipe_free(gss_auth->gss_pipe[1]); + gss_pipe_free(gss_auth->gss_pipe[1]); err_destroy_credcache: rpcauth_destroy_credcache(auth); err_put_mech: