diff --git a/drivers/infiniband/core/sa_query.c b/drivers/infiniband/core/sa_query.c
index cf474ec27070..78ea8157d622 100644
--- a/drivers/infiniband/core/sa_query.c
+++ b/drivers/infiniband/core/sa_query.c
@@ -361,7 +361,7 @@ static void update_sm_ah(struct work_struct *work)
 {
 	struct ib_sa_port *port =
 		container_of(work, struct ib_sa_port, update_task);
-	struct ib_sa_sm_ah *new_ah, *old_ah;
+	struct ib_sa_sm_ah *new_ah;
 	struct ib_port_attr port_attr;
 	struct ib_ah_attr   ah_attr;
 
@@ -397,12 +397,9 @@ static void update_sm_ah(struct work_struct *work)
 	}
 
 	spin_lock_irq(&port->ah_lock);
-	old_ah = port->sm_ah;
 	port->sm_ah = new_ah;
 	spin_unlock_irq(&port->ah_lock);
 
-	if (old_ah)
-		kref_put(&old_ah->ref, free_sm_ah);
 }
 
 static void ib_sa_event(struct ib_event_handler *handler, struct ib_event *event)
@@ -413,8 +410,17 @@ static void ib_sa_event(struct ib_event_handler *handler, struct ib_event *event
 	    event->event == IB_EVENT_PKEY_CHANGE ||
 	    event->event == IB_EVENT_SM_CHANGE   ||
 	    event->event == IB_EVENT_CLIENT_REREGISTER) {
-		struct ib_sa_device *sa_dev;
-		sa_dev = container_of(handler, typeof(*sa_dev), event_handler);
+		unsigned long flags;
+		struct ib_sa_device *sa_dev =
+			container_of(handler, typeof(*sa_dev), event_handler);
+		struct ib_sa_port *port =
+			&sa_dev->port[event->element.port_num - sa_dev->start_port];
+
+		spin_lock_irqsave(&port->ah_lock, flags);
+		if (port->sm_ah)
+			kref_put(&port->sm_ah->ref, free_sm_ah);
+		port->sm_ah = NULL;
+		spin_unlock_irqrestore(&port->ah_lock, flags);
 
 		schedule_work(&sa_dev->port[event->element.port_num -
 					    sa_dev->start_port].update_task);
@@ -519,6 +525,10 @@ static int alloc_mad(struct ib_sa_query *query, gfp_t gfp_mask)
 	unsigned long flags;
 
 	spin_lock_irqsave(&query->port->ah_lock, flags);
+	if (!query->port->sm_ah) {
+		spin_unlock_irqrestore(&query->port->ah_lock, flags);
+		return -EAGAIN;
+	}
 	kref_get(&query->port->sm_ah->ref);
 	query->sm_ah = query->port->sm_ah;
 	spin_unlock_irqrestore(&query->port->ah_lock, flags);