diff --git a/drivers/xen/pvcalls-front.c b/drivers/xen/pvcalls-front.c index 326395d09e31..8d4a43e6aa46 100644 --- a/drivers/xen/pvcalls-front.c +++ b/drivers/xen/pvcalls-front.c @@ -59,6 +59,18 @@ struct sock_mapping { bool active_socket; struct list_head list; struct socket *sock; + union { + struct { + int irq; + grant_ref_t ref; + struct pvcalls_data_intf *ring; + struct pvcalls_data data; + struct mutex in_mutex; + struct mutex out_mutex; + + wait_queue_head_t inflight_conn_req; + } active; + }; }; static inline int get_request(struct pvcalls_bedata *bedata, int *req_id) @@ -121,6 +133,18 @@ static void pvcalls_front_free_map(struct pvcalls_bedata *bedata, { } +static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map) +{ + struct sock_mapping *map = sock_map; + + if (map == NULL) + return IRQ_HANDLED; + + wake_up_interruptible(&map->active.inflight_conn_req); + + return IRQ_HANDLED; +} + int pvcalls_front_socket(struct socket *sock) { struct pvcalls_bedata *bedata; @@ -196,6 +220,132 @@ int pvcalls_front_socket(struct socket *sock) return ret; } +static int create_active(struct sock_mapping *map, int *evtchn) +{ + void *bytes; + int ret = -ENOMEM, irq = -1, i; + + *evtchn = -1; + init_waitqueue_head(&map->active.inflight_conn_req); + + map->active.ring = (struct pvcalls_data_intf *) + __get_free_page(GFP_KERNEL | __GFP_ZERO); + if (map->active.ring == NULL) + goto out_error; + map->active.ring->ring_order = PVCALLS_RING_ORDER; + bytes = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO, + PVCALLS_RING_ORDER); + if (bytes == NULL) + goto out_error; + for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++) + map->active.ring->ref[i] = gnttab_grant_foreign_access( + pvcalls_front_dev->otherend_id, + pfn_to_gfn(virt_to_pfn(bytes) + i), 0); + + map->active.ref = gnttab_grant_foreign_access( + pvcalls_front_dev->otherend_id, + pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0); + + map->active.data.in = bytes; + map->active.data.out = bytes + + XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER); + + ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn); + if (ret) + goto out_error; + irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler, + 0, "pvcalls-frontend", map); + if (irq < 0) { + ret = irq; + goto out_error; + } + + map->active.irq = irq; + map->active_socket = true; + mutex_init(&map->active.in_mutex); + mutex_init(&map->active.out_mutex); + + return 0; + +out_error: + if (irq >= 0) + unbind_from_irqhandler(irq, map); + else if (*evtchn >= 0) + xenbus_free_evtchn(pvcalls_front_dev, *evtchn); + kfree(map->active.data.in); + kfree(map->active.ring); + return ret; +} + +int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr, + int addr_len, int flags) +{ + struct pvcalls_bedata *bedata; + struct sock_mapping *map = NULL; + struct xen_pvcalls_request *req; + int notify, req_id, ret, evtchn; + + if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM) + return -EOPNOTSUPP; + + pvcalls_enter(); + if (!pvcalls_front_dev) { + pvcalls_exit(); + return -ENOTCONN; + } + + bedata = dev_get_drvdata(&pvcalls_front_dev->dev); + + map = (struct sock_mapping *)sock->sk->sk_send_head; + if (!map) { + pvcalls_exit(); + return -ENOTSOCK; + } + + spin_lock(&bedata->socket_lock); + ret = get_request(bedata, &req_id); + if (ret < 0) { + spin_unlock(&bedata->socket_lock); + pvcalls_exit(); + return ret; + } + ret = create_active(map, &evtchn); + if (ret < 0) { + spin_unlock(&bedata->socket_lock); + pvcalls_exit(); + return ret; + } + + req = RING_GET_REQUEST(&bedata->ring, req_id); + req->req_id = req_id; + req->cmd = PVCALLS_CONNECT; + req->u.connect.id = (uintptr_t)map; + req->u.connect.len = addr_len; + req->u.connect.flags = flags; + req->u.connect.ref = map->active.ref; + req->u.connect.evtchn = evtchn; + memcpy(req->u.connect.addr, addr, sizeof(*addr)); + + map->sock = sock; + + bedata->ring.req_prod_pvt++; + RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify); + spin_unlock(&bedata->socket_lock); + + if (notify) + notify_remote_via_irq(bedata->irq); + + wait_event(bedata->inflight_req, + READ_ONCE(bedata->rsp[req_id].req_id) == req_id); + + /* read req_id, then the content */ + smp_rmb(); + ret = bedata->rsp[req_id].ret; + bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; + pvcalls_exit(); + return ret; +} + static const struct xenbus_device_id pvcalls_front_ids[] = { { "pvcalls" }, { "" } @@ -212,6 +362,14 @@ static int pvcalls_front_remove(struct xenbus_device *dev) if (bedata->irq >= 0) unbind_from_irqhandler(bedata->irq, dev); + list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) { + map->sock->sk->sk_send_head = NULL; + if (map->active_socket) { + map->active.ring->in_error = -EBADF; + wake_up_interruptible(&map->active.inflight_conn_req); + } + } + smp_mb(); while (atomic_read(&pvcalls_refcount) > 0) cpu_relax(); diff --git a/drivers/xen/pvcalls-front.h b/drivers/xen/pvcalls-front.h index b7dabedf5ccb..63b0417c31d3 100644 --- a/drivers/xen/pvcalls-front.h +++ b/drivers/xen/pvcalls-front.h @@ -4,5 +4,7 @@ #include int pvcalls_front_socket(struct socket *sock); +int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr, + int addr_len, int flags); #endif