]> git.hungrycats.org Git - linux/commitdiff
[PATCH] Cleanup for SunRPC auth code
authorTrond Myklebust <trond.myklebust@fys.uio.no>
Wed, 8 Jan 2003 01:59:44 +0000 (17:59 -0800)
committerTrond Myklebust <trond.myklebust@fys.uio.no>
Wed, 8 Jan 2003 01:59:44 +0000 (17:59 -0800)
Converts the RPC client auth code to use 'list_head' rather than a
custom pointer scheme.

Fixes a (relatively harmless) race which could cause several cred
entries to be created for the same user.

include/linux/sunrpc/auth.h
net/sunrpc/auth.c
net/sunrpc/sunrpc_syms.c

index 6106cf73da6ba6563ddbdd2aba5af7c8cba976d3..5e481026fc7e91091a1d13db20f2a90f6d7fc163 100644 (file)
@@ -23,7 +23,7 @@
  * Client user credentials
  */
 struct rpc_cred {
-       struct rpc_cred *       cr_next;        /* linked list */
+       struct list_head        cr_hash;        /* hash chain */
        struct rpc_auth *       cr_auth;
        struct rpc_credops *    cr_ops;
        unsigned long           cr_expire;      /* when to gc */
@@ -49,7 +49,7 @@ struct rpc_cred {
 #define RPC_CREDCACHE_NR       8
 #define RPC_CREDCACHE_MASK     (RPC_CREDCACHE_NR - 1)
 struct rpc_auth {
-       struct rpc_cred *       au_credcache[RPC_CREDCACHE_NR];
+       struct list_head        au_credcache[RPC_CREDCACHE_NR];
        unsigned long           au_expire;      /* cache expiry interval */
        unsigned long           au_nextgc;      /* next garbage collection */
        unsigned int            au_cslack;      /* call cred size estimate */
@@ -101,8 +101,6 @@ struct rpc_cred *   rpcauth_bindcred(struct rpc_task *);
 void                   rpcauth_holdcred(struct rpc_task *);
 void                   put_rpccred(struct rpc_cred *);
 void                   rpcauth_unbindcred(struct rpc_task *);
-int                    rpcauth_matchcred(struct rpc_auth *,
-                                         struct rpc_cred *, int);
 u32 *                  rpcauth_marshcred(struct rpc_task *, u32 *);
 u32 *                  rpcauth_checkverf(struct rpc_task *, u32 *);
 int                    rpcauth_refreshcred(struct rpc_task *);
@@ -110,8 +108,6 @@ void                        rpcauth_invalcred(struct rpc_task *);
 int                    rpcauth_uptodatecred(struct rpc_task *);
 void                   rpcauth_init_credcache(struct rpc_auth *);
 void                   rpcauth_free_credcache(struct rpc_auth *);
-void                   rpcauth_insert_credcache(struct rpc_auth *,
-                                               struct rpc_cred *);
 
 static inline
 struct rpc_cred *      get_rpccred(struct rpc_cred *cred)
index f2ad100286eabbd1cd1aafb9a3722eb78b91b3c5..a45ab766376dbdee69ef333233f3379d3cd18819 100644 (file)
@@ -75,7 +75,9 @@ static spinlock_t rpc_credcache_lock = SPIN_LOCK_UNLOCKED;
 void
 rpcauth_init_credcache(struct rpc_auth *auth)
 {
-       memset(auth->au_credcache, 0, sizeof(auth->au_credcache));
+       int i;
+       for (i = 0; i < RPC_CREDCACHE_NR; i++)
+               INIT_LIST_HEAD(&auth->au_credcache[i]);
        auth->au_nextgc = jiffies + (auth->au_expire >> 1);
 }
 
@@ -86,11 +88,10 @@ static inline void
 rpcauth_crdestroy(struct rpc_cred *cred)
 {
 #ifdef RPC_DEBUG
-       if (cred->cr_magic != RPCAUTH_CRED_MAGIC)
-               BUG();
+       BUG_ON(cred->cr_magic != RPCAUTH_CRED_MAGIC ||
+                       atomic_read(&cred->cr_count) ||
+                       !list_empty(&cred->cr_hash));
        cred->cr_magic = 0;
-       if (atomic_read(&cred->cr_count) || cred->cr_auth)
-               BUG();
 #endif
        cred->cr_ops->crdestroy(cred);
 }
@@ -99,12 +100,13 @@ rpcauth_crdestroy(struct rpc_cred *cred)
  * Destroy a list of credentials
  */
 static inline
-void rpcauth_destroy_credlist(struct rpc_cred *head)
+void rpcauth_destroy_credlist(struct list_head *head)
 {
        struct rpc_cred *cred;
 
-       while ((cred = head) != NULL) {
-               head = cred->cr_next;
+       while (!list_empty(head)) {
+               cred = list_entry(head->next, struct rpc_cred, cr_hash);
+               list_del_init(&cred->cr_hash);
                rpcauth_crdestroy(cred);
        }
 }
@@ -116,137 +118,117 @@ void rpcauth_destroy_credlist(struct rpc_cred *head)
 void
 rpcauth_free_credcache(struct rpc_auth *auth)
 {
-       struct rpc_cred **q, *cred, *free = NULL;
+       LIST_HEAD(free);
+       struct list_head *pos, *next;
+       struct rpc_cred *cred;
        int             i;
 
        spin_lock(&rpc_credcache_lock);
        for (i = 0; i < RPC_CREDCACHE_NR; i++) {
-               q = &auth->au_credcache[i];
-               while ((cred = *q) != NULL) {
-                       *q = cred->cr_next;
+               list_for_each_safe(pos, next, &auth->au_credcache[i]) {
+                       cred = list_entry(pos, struct rpc_cred, cr_hash);
                        cred->cr_auth = NULL;
-                       if (atomic_read(&cred->cr_count) == 0) {
-                               cred->cr_next = free;
-                               free = cred;
-                       } else
-                               cred->cr_next = NULL;
+                       list_del_init(&cred->cr_hash);
+                       if (atomic_read(&cred->cr_count) == 0)
+                               list_add(&cred->cr_hash, &free);
                }
        }
        spin_unlock(&rpc_credcache_lock);
-       rpcauth_destroy_credlist(free);
+       rpcauth_destroy_credlist(&free);
+}
+
+static inline int
+rpcauth_prune_expired(struct rpc_cred *cred, struct list_head *free)
+{
+       if (atomic_read(&cred->cr_count) != 0)
+              return 0;
+       if (time_before(jiffies, cred->cr_expire))
+               return 0;
+       cred->cr_auth = NULL;
+       list_del(&cred->cr_hash);
+       list_add(&cred->cr_hash, free);
+       return 1;
 }
 
 /*
  * Remove stale credentials. Avoid sleeping inside the loop.
  */
 static void
-rpcauth_gc_credcache(struct rpc_auth *auth)
+rpcauth_gc_credcache(struct rpc_auth *auth, struct list_head *free)
 {
-       struct rpc_cred **q, *cred, *free = NULL;
+       struct list_head *pos, *next;
+       struct rpc_cred *cred;
        int             i;
 
        dprintk("RPC: gc'ing RPC credentials for auth %p\n", auth);
-       spin_lock(&rpc_credcache_lock);
        for (i = 0; i < RPC_CREDCACHE_NR; i++) {
-               q = &auth->au_credcache[i];
-               while ((cred = *q) != NULL) {
-                       if (!atomic_read(&cred->cr_count) &&
-                           time_before(cred->cr_expire, jiffies)) {
-                               *q = cred->cr_next;
-                               cred->cr_auth = NULL;
-                               cred->cr_next = free;
-                               free = cred;
-                               continue;
-                       }
-                       q = &cred->cr_next;
+               list_for_each_safe(pos, next, &auth->au_credcache[i]) {
+                       cred = list_entry(pos, struct rpc_cred, cr_hash);
+                       rpcauth_prune_expired(cred, free);
                }
        }
-       spin_unlock(&rpc_credcache_lock);
-       rpcauth_destroy_credlist(free);
        auth->au_nextgc = jiffies + auth->au_expire;
 }
 
-/*
- * Insert credential into cache
- */
-void
-rpcauth_insert_credcache(struct rpc_auth *auth, struct rpc_cred *cred)
-{
-       int             nr;
-
-       nr = (cred->cr_uid & RPC_CREDCACHE_MASK);
-       spin_lock(&rpc_credcache_lock);
-       cred->cr_next = auth->au_credcache[nr];
-       auth->au_credcache[nr] = cred;
-       cred->cr_auth = auth;
-       get_rpccred(cred);
-       spin_unlock(&rpc_credcache_lock);
-}
-
 /*
  * Look up a process' credentials in the authentication cache
  */
 static struct rpc_cred *
 rpcauth_lookup_credcache(struct rpc_auth *auth, int taskflags)
 {
-       struct rpc_cred **q, *cred = NULL;
+       LIST_HEAD(free);
+       struct list_head *pos, *next;
+       struct rpc_cred *new = NULL,
+                       *cred = NULL;
        int             nr = 0;
 
        if (!(taskflags & RPC_TASK_ROOTCREDS))
                nr = current->uid & RPC_CREDCACHE_MASK;
-
-       if (time_before(auth->au_nextgc, jiffies))
-               rpcauth_gc_credcache(auth);
-
+retry:
        spin_lock(&rpc_credcache_lock);
-       q = &auth->au_credcache[nr];
-       while ((cred = *q) != NULL) {
-               if (!(cred->cr_flags & RPCAUTH_CRED_DEAD) &&
-                   cred->cr_ops->crmatch(cred, taskflags)) {
-                       *q = cred->cr_next;
+       if (time_before(auth->au_nextgc, jiffies))
+               rpcauth_gc_credcache(auth, &free);
+       list_for_each_safe(pos, next, &auth->au_credcache[nr]) {
+               struct rpc_cred *entry;
+               entry = list_entry(pos, struct rpc_cred, cr_hash);
+               if (entry->cr_flags & RPCAUTH_CRED_DEAD)
+                       continue;
+               if (rpcauth_prune_expired(entry, &free))
+                       continue;
+               if (entry->cr_ops->crmatch(entry, taskflags)) {
+                       list_del(&entry->cr_hash);
+                       cred = entry;
                        break;
                }
-               q = &cred->cr_next;
+       }
+       if (new) {
+               if (cred)
+                       list_add(&new->cr_hash, &free);
+               else
+                       cred = new;
+       }
+       if (cred) {
+               list_add(&cred->cr_hash, &auth->au_credcache[nr]);
+               cred->cr_auth = auth;
+               get_rpccred(cred);
        }
        spin_unlock(&rpc_credcache_lock);
 
+       rpcauth_destroy_credlist(&free);
+
        if (!cred) {
-               cred = auth->au_ops->crcreate(taskflags);
+               new = auth->au_ops->crcreate(taskflags);
+               if (new) {
 #ifdef RPC_DEBUG
-               if (cred)
-                       cred->cr_magic = RPCAUTH_CRED_MAGIC;
+                       new->cr_magic = RPCAUTH_CRED_MAGIC;
 #endif
+                       goto retry;
+               }
        }
 
-       if (cred)
-               rpcauth_insert_credcache(auth, cred);
-
        return (struct rpc_cred *) cred;
 }
 
-/*
- * Remove cred handle from cache
- */
-static void
-rpcauth_remove_credcache(struct rpc_cred *cred)
-{
-       struct rpc_auth *auth = cred->cr_auth;
-       struct rpc_cred **q, *cr;
-       int             nr;
-
-       nr = (cred->cr_uid & RPC_CREDCACHE_MASK);
-       q = &auth->au_credcache[nr];
-       while ((cr = *q) != NULL) {
-               if (cred == cr) {
-                       *q = cred->cr_next;
-                       cred->cr_next = NULL;
-                       cred->cr_auth = NULL;
-                       break;
-               }
-               q = &cred->cr_next;
-       }
-}
-
 struct rpc_cred *
 rpcauth_lookupcred(struct rpc_auth *auth, int taskflags)
 {
@@ -268,14 +250,6 @@ rpcauth_bindcred(struct rpc_task *task)
        return task->tk_msg.rpc_cred;
 }
 
-int
-rpcauth_matchcred(struct rpc_auth *auth, struct rpc_cred *cred, int taskflags)
-{
-       dprintk("RPC:     matching %s cred %d\n",
-               auth->au_ops->au_name, taskflags);
-       return cred->cr_ops->crmatch(cred, taskflags);
-}
-
 void
 rpcauth_holdcred(struct rpc_task *task)
 {
@@ -291,10 +265,10 @@ put_rpccred(struct rpc_cred *cred)
        if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
                return;
 
-       if (cred->cr_auth && cred->cr_flags & RPCAUTH_CRED_DEAD)
-               rpcauth_remove_credcache(cred);
+       if ((cred->cr_flags & RPCAUTH_CRED_DEAD) && !list_empty(&cred->cr_hash))
+               list_del_init(&cred->cr_hash);
 
-       if (!cred->cr_auth) {
+       if (list_empty(&cred->cr_hash)) {
                spin_unlock(&rpc_credcache_lock);
                rpcauth_crdestroy(cred);
                return;
index 2c06ada571bec24d48913df8e6999353fa937e7a..dc56b0ea37487709ed5b4f11aa2f7bcba1667e58 100644 (file)
@@ -60,12 +60,7 @@ EXPORT_SYMBOL(xprt_set_timeout);
 /* Client credential cache */
 EXPORT_SYMBOL(rpcauth_register);
 EXPORT_SYMBOL(rpcauth_unregister);
-EXPORT_SYMBOL(rpcauth_init_credcache);
-EXPORT_SYMBOL(rpcauth_free_credcache);
-EXPORT_SYMBOL(rpcauth_insert_credcache);
 EXPORT_SYMBOL(rpcauth_lookupcred);
-EXPORT_SYMBOL(rpcauth_bindcred);
-EXPORT_SYMBOL(rpcauth_matchcred);
 EXPORT_SYMBOL(put_rpccred);
 
 /* RPC server stuff */