lib/librte_vhost/vhost_user.c
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright(c) 2010-2018 Intel Corporation
3  */
4
5 /* Security model
6  * --------------
7  * The vhost-user protocol connection is an external interface, so it must be
8  * robust against invalid inputs.
9  *
10  * This is important because the vhost-user master is only one step removed
11  * from the guest.  A malicious guest that has escaped can then use the
12  * vhost-user master to launch further attacks.
13  *
14  * Even in deployments where guests are trusted, a bug in the vhost-user master
15  * can still cause invalid messages to be sent.  Such messages must not
16  * compromise the stability of the DPDK application by causing crashes, memory
17  * corruption, or other problematic behavior.
18  *
19  * Do not assume received VhostUserMsg fields contain sensible values!
20  */
21
22 #include <stdint.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <unistd.h>
27 #include <sys/mman.h>
28 #include <sys/types.h>
29 #include <sys/stat.h>
30 #include <assert.h>
31 #ifdef RTE_LIBRTE_VHOST_NUMA
32 #include <numaif.h>
33 #endif
34
35 #include <rte_common.h>
36 #include <rte_malloc.h>
37 #include <rte_log.h>
38
39 #include "iotlb.h"
40 #include "vhost.h"
41 #include "vhost_user.h"
42
43 #define VIRTIO_MIN_MTU 68
44 #define VIRTIO_MAX_MTU 65535
45
46 static const char *vhost_message_str[VHOST_USER_MAX] = {
47         [VHOST_USER_NONE] = "VHOST_USER_NONE",
48         [VHOST_USER_GET_FEATURES] = "VHOST_USER_GET_FEATURES",
49         [VHOST_USER_SET_FEATURES] = "VHOST_USER_SET_FEATURES",
50         [VHOST_USER_SET_OWNER] = "VHOST_USER_SET_OWNER",
51         [VHOST_USER_RESET_OWNER] = "VHOST_USER_RESET_OWNER",
52         [VHOST_USER_SET_MEM_TABLE] = "VHOST_USER_SET_MEM_TABLE",
53         [VHOST_USER_SET_LOG_BASE] = "VHOST_USER_SET_LOG_BASE",
54         [VHOST_USER_SET_LOG_FD] = "VHOST_USER_SET_LOG_FD",
55         [VHOST_USER_SET_VRING_NUM] = "VHOST_USER_SET_VRING_NUM",
56         [VHOST_USER_SET_VRING_ADDR] = "VHOST_USER_SET_VRING_ADDR",
57         [VHOST_USER_SET_VRING_BASE] = "VHOST_USER_SET_VRING_BASE",
58         [VHOST_USER_GET_VRING_BASE] = "VHOST_USER_GET_VRING_BASE",
59         [VHOST_USER_SET_VRING_KICK] = "VHOST_USER_SET_VRING_KICK",
60         [VHOST_USER_SET_VRING_CALL] = "VHOST_USER_SET_VRING_CALL",
61         [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR",
62         [VHOST_USER_GET_PROTOCOL_FEATURES]  = "VHOST_USER_GET_PROTOCOL_FEATURES",
63         [VHOST_USER_SET_PROTOCOL_FEATURES]  = "VHOST_USER_SET_PROTOCOL_FEATURES",
64         [VHOST_USER_GET_QUEUE_NUM]  = "VHOST_USER_GET_QUEUE_NUM",
65         [VHOST_USER_SET_VRING_ENABLE]  = "VHOST_USER_SET_VRING_ENABLE",
66         [VHOST_USER_SEND_RARP]  = "VHOST_USER_SEND_RARP",
67         [VHOST_USER_NET_SET_MTU]  = "VHOST_USER_NET_SET_MTU",
68         [VHOST_USER_SET_SLAVE_REQ_FD]  = "VHOST_USER_SET_SLAVE_REQ_FD",
69         [VHOST_USER_IOTLB_MSG]  = "VHOST_USER_IOTLB_MSG",
70         [VHOST_USER_CRYPTO_CREATE_SESS] = "VHOST_USER_CRYPTO_CREATE_SESS",
71         [VHOST_USER_CRYPTO_CLOSE_SESS] = "VHOST_USER_CRYPTO_CLOSE_SESS",
72 };
73
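/*
 * Return the block size reported by fstat() for the given fd; for files
 * backed by hugetlbfs this is the hugepage size.  Returns (uint64_t)-1 on
 * failure.
 */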
74 static uint64_t
75 get_blk_size(int fd)
76 {
77         struct stat stat;
78         int ret;
79
80         ret = fstat(fd, &stat);
81         return ret == -1 ? (uint64_t)-1 : (uint64_t)stat.st_blksize;
82 }
83
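/*
 * Unmap all guest memory regions previously set up by SET_MEM_TABLE and
 * close their backing file descriptors.
 */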
84 static void
85 free_mem_region(struct virtio_net *dev)
86 {
87         uint32_t i;
88         struct rte_vhost_mem_region *reg;
89
90         if (!dev || !dev->mem)
91                 return;
92
93         for (i = 0; i < dev->mem->nregions; i++) {
94                 reg = &dev->mem->regions[i];
95                 if (reg->host_user_addr) {
96                         munmap(reg->mmap_addr, reg->mmap_size);
97                         close(reg->fd);
98                 }
99         }
100 }
101
102 void
103 vhost_backend_cleanup(struct virtio_net *dev)
104 {
105         if (dev->mem) {
106                 free_mem_region(dev);
107                 rte_free(dev->mem);
108                 dev->mem = NULL;
109         }
110
111         free(dev->guest_pages);
112         dev->guest_pages = NULL;
113
114         if (dev->log_addr) {
115                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
116                 dev->log_addr = 0;
117         }
118
119         if (dev->slave_req_fd >= 0) {
120                 close(dev->slave_req_fd);
121                 dev->slave_req_fd = -1;
122         }
123 }
124
125 /*
126  * This function currently just returns success; nothing needs to be done
127  * here for the SET_OWNER request.
128  */
129 static int
130 vhost_user_set_owner(void)
131 {
132         return 0;
133 }
134
135 static int
136 vhost_user_reset_owner(struct virtio_net *dev)
137 {
138         struct rte_vdpa_device *vdpa_dev;
139         int did = -1;
140
141         if (dev->flags & VIRTIO_DEV_RUNNING) {
142                 did = dev->vdpa_dev_id;
143                 vdpa_dev = rte_vdpa_get_device(did);
144                 if (vdpa_dev && vdpa_dev->ops->dev_close)
145                         vdpa_dev->ops->dev_close(dev->vid);
146                 dev->flags &= ~VIRTIO_DEV_RUNNING;
147                 dev->notify_ops->destroy_device(dev->vid);
148         }
149
150         cleanup_device(dev, 0);
151         reset_device(dev);
152         return 0;
153 }
154
155 /*
156  * The vhost-user master requests the features we support.
157  */
158 static uint64_t
159 vhost_user_get_features(struct virtio_net *dev)
160 {
161         uint64_t features = 0;
162
163         rte_vhost_driver_get_features(dev->ifname, &features);
164         return features;
165 }
166
167 /*
168  * The vhost-user master requests the number of queues we support.
169  */
170 static uint32_t
171 vhost_user_get_queue_num(struct virtio_net *dev)
172 {
173         uint32_t queue_num = 0;
174
175         rte_vhost_driver_get_queue_num(dev->ifname, &queue_num);
176         return queue_num;
177 }
178
179 /*
180  * We receive the negotiated features supported by us and the virtio device.
181  */
182 static int
183 vhost_user_set_features(struct virtio_net *dev, uint64_t features)
184 {
185         uint64_t vhost_features = 0;
186         struct rte_vdpa_device *vdpa_dev;
187         int did = -1;
188
189         rte_vhost_driver_get_features(dev->ifname, &vhost_features);
190         if (features & ~vhost_features) {
191                 RTE_LOG(ERR, VHOST_CONFIG,
192                         "(%d) received invalid negotiated features.\n",
193                         dev->vid);
194                 return -1;
195         }
196
197         if (dev->flags & VIRTIO_DEV_RUNNING) {
198                 if (dev->features == features)
199                         return 0;
200
201                 /*
202                  * Error out if master tries to change features while device is
203                  * in running state. The exception being VHOST_F_LOG_ALL, which
204                  * is enabled when the live-migration starts.
205                  */
206                 if ((dev->features ^ features) & ~(1ULL << VHOST_F_LOG_ALL)) {
207                         RTE_LOG(ERR, VHOST_CONFIG,
208                                 "(%d) features changed while device is running.\n",
209                                 dev->vid);
210                         return -1;
211                 }
212
213                 if (dev->notify_ops->features_changed)
214                         dev->notify_ops->features_changed(dev->vid, features);
215         }
216
217         dev->features = features;
218         if (dev->features &
219                 ((1 << VIRTIO_NET_F_MRG_RXBUF) | (1ULL << VIRTIO_F_VERSION_1))) {
220                 dev->vhost_hlen = sizeof(struct virtio_net_hdr_mrg_rxbuf);
221         } else {
222                 dev->vhost_hlen = sizeof(struct virtio_net_hdr);
223         }
224         VHOST_LOG_DEBUG(VHOST_CONFIG,
225                 "(%d) mergeable RX buffers %s, virtio 1 %s\n",
226                 dev->vid,
227                 (dev->features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ? "on" : "off",
228                 (dev->features & (1ULL << VIRTIO_F_VERSION_1)) ? "on" : "off");
229
230         if ((dev->flags & VIRTIO_DEV_BUILTIN_VIRTIO_NET) &&
231             !(dev->features & (1ULL << VIRTIO_NET_F_MQ))) {
232                 /*
233                  * Remove all but first queue pair if MQ hasn't been
234                  * negotiated. This is safe because the device is not
235                  * running at this stage.
236                  */
237                 while (dev->nr_vring > 2) {
238                         struct vhost_virtqueue *vq;
239
240                         vq = dev->virtqueue[--dev->nr_vring];
241                         if (!vq)
242                                 continue;
243
244                         dev->virtqueue[dev->nr_vring] = NULL;
245                         cleanup_vq(vq, 1);
246                         free_vq(vq);
247                 }
248         }
249
250         did = dev->vdpa_dev_id;
251         vdpa_dev = rte_vdpa_get_device(did);
252         if (vdpa_dev && vdpa_dev->ops->set_features)
253                 vdpa_dev->ops->set_features(dev->vid);
254
255         return 0;
256 }
257
258 /*
259  * The virtio device sends us the size of the descriptor ring.
260  */
261 static int
262 vhost_user_set_vring_num(struct virtio_net *dev,
263                          VhostUserMsg *msg)
264 {
265         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
266
267         vq->size = msg->payload.state.num;
268
269         /* VIRTIO 1.0, 2.4 Virtqueues says:
270          *
271          *   Queue Size value is always a power of 2. The maximum Queue Size
272          *   value is 32768.
273          */
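        /* A non-zero size is a power of two iff (size & (size - 1)) == 0. */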
274         if ((vq->size & (vq->size - 1)) || vq->size > 32768) {
275                 RTE_LOG(ERR, VHOST_CONFIG,
276                         "invalid virtqueue size %u\n", vq->size);
277                 return -1;
278         }
279
280         if (dev->dequeue_zero_copy) {
281                 vq->nr_zmbuf = 0;
282                 vq->last_zmbuf_idx = 0;
283                 vq->zmbuf_size = vq->size;
284                 vq->zmbufs = rte_zmalloc(NULL, vq->zmbuf_size *
285                                          sizeof(struct zcopy_mbuf), 0);
286                 if (vq->zmbufs == NULL) {
287                         RTE_LOG(WARNING, VHOST_CONFIG,
288                                 "failed to allocate mem for zero copy; "
289                                 "zero copy is force disabled\n");
290                         dev->dequeue_zero_copy = 0;
291                 }
292                 TAILQ_INIT(&vq->zmbuf_list);
293         }
294
295         vq->shadow_used_ring = rte_malloc(NULL,
296                                 vq->size * sizeof(struct vring_used_elem),
297                                 RTE_CACHE_LINE_SIZE);
298         if (!vq->shadow_used_ring) {
299                 RTE_LOG(ERR, VHOST_CONFIG,
300                         "failed to allocate memory for shadow used ring.\n");
301                 return -1;
302         }
303
304         vq->batch_copy_elems = rte_malloc(NULL,
305                                 vq->size * sizeof(struct batch_copy_elem),
306                                 RTE_CACHE_LINE_SIZE);
307         if (!vq->batch_copy_elems) {
308                 RTE_LOG(ERR, VHOST_CONFIG,
309                         "failed to allocate memory for batching copy.\n");
310                 return -1;
311         }
312
313         return 0;
314 }
315
316 /*
317  * Reallocate the virtio_net and vhost_virtqueue data structures so that they
318  * reside on the same NUMA node as the vring descriptor memory.
319  */
320 #ifdef RTE_LIBRTE_VHOST_NUMA
321 static struct virtio_net*
322 numa_realloc(struct virtio_net *dev, int index)
323 {
324         int oldnode, newnode;
325         struct virtio_net *old_dev;
326         struct vhost_virtqueue *old_vq, *vq;
327         struct zcopy_mbuf *new_zmbuf;
328         struct vring_used_elem *new_shadow_used_ring;
329         struct batch_copy_elem *new_batch_copy_elems;
330         int ret;
331
332         old_dev = dev;
333         vq = old_vq = dev->virtqueue[index];
334
335         ret = get_mempolicy(&newnode, NULL, 0, old_vq->desc,
336                             MPOL_F_NODE | MPOL_F_ADDR);
337
338         /* check if we need to reallocate vq */
339         ret |= get_mempolicy(&oldnode, NULL, 0, old_vq,
340                              MPOL_F_NODE | MPOL_F_ADDR);
341         if (ret) {
342                 RTE_LOG(ERR, VHOST_CONFIG,
343                         "Unable to get vq numa information.\n");
344                 return dev;
345         }
346         if (oldnode != newnode) {
347                 RTE_LOG(INFO, VHOST_CONFIG,
348                         "reallocate vq from %d to %d node\n", oldnode, newnode);
349                 vq = rte_malloc_socket(NULL, sizeof(*vq), 0, newnode);
350                 if (!vq)
351                         return dev;
352
353                 memcpy(vq, old_vq, sizeof(*vq));
354                 TAILQ_INIT(&vq->zmbuf_list);
355
356                 new_zmbuf = rte_malloc_socket(NULL, vq->zmbuf_size *
357                         sizeof(struct zcopy_mbuf), 0, newnode);
358                 if (new_zmbuf) {
359                         rte_free(vq->zmbufs);
360                         vq->zmbufs = new_zmbuf;
361                 }
362
363                 new_shadow_used_ring = rte_malloc_socket(NULL,
364                         vq->size * sizeof(struct vring_used_elem),
365                         RTE_CACHE_LINE_SIZE,
366                         newnode);
367                 if (new_shadow_used_ring) {
368                         rte_free(vq->shadow_used_ring);
369                         vq->shadow_used_ring = new_shadow_used_ring;
370                 }
371
372                 new_batch_copy_elems = rte_malloc_socket(NULL,
373                         vq->size * sizeof(struct batch_copy_elem),
374                         RTE_CACHE_LINE_SIZE,
375                         newnode);
376                 if (new_batch_copy_elems) {
377                         rte_free(vq->batch_copy_elems);
378                         vq->batch_copy_elems = new_batch_copy_elems;
379                 }
380
381                 rte_free(old_vq);
382         }
383
384         /* check if we need to reallocate dev */
385         ret = get_mempolicy(&oldnode, NULL, 0, old_dev,
386                             MPOL_F_NODE | MPOL_F_ADDR);
387         if (ret) {
388                 RTE_LOG(ERR, VHOST_CONFIG,
389                         "Unable to get dev numa information.\n");
390                 goto out;
391         }
392         if (oldnode != newnode) {
393                 RTE_LOG(INFO, VHOST_CONFIG,
394                         "reallocate dev from %d to %d node\n",
395                         oldnode, newnode);
396                 dev = rte_malloc_socket(NULL, sizeof(*dev), 0, newnode);
397                 if (!dev) {
398                         dev = old_dev;
399                         goto out;
400                 }
401
402                 memcpy(dev, old_dev, sizeof(*dev));
403                 rte_free(old_dev);
404         }
405
406 out:
407         dev->virtqueue[index] = vq;
408         vhost_devices[dev->vid] = dev;
409
410         if (old_vq != vq)
411                 vhost_user_iotlb_init(dev, index);
412
413         return dev;
414 }
415 #else
416 static struct virtio_net*
417 numa_realloc(struct virtio_net *dev, int index __rte_unused)
418 {
419         return dev;
420 }
421 #endif
422
423 /* Converts QEMU virtual address to Vhost virtual address. */
424 static uint64_t
425 qva_to_vva(struct virtio_net *dev, uint64_t qva, uint64_t *len)
426 {
427         struct rte_vhost_mem_region *r;
428         uint32_t i;
429
430         /* Find the region where the address lives. */
431         for (i = 0; i < dev->mem->nregions; i++) {
432                 r = &dev->mem->regions[i];
433
434                 if (qva >= r->guest_user_addr &&
435                     qva <  r->guest_user_addr + r->size) {
436
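                        /* Clamp *len so the translated range does not run past the end of this region. */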
437                         if (unlikely(*len > r->guest_user_addr + r->size - qva))
438                                 *len = r->guest_user_addr + r->size - qva;
439
440                         return qva - r->guest_user_addr +
441                                r->host_user_addr;
442                 }
443         }
444         *len = 0;
445
446         return 0;
447 }
448
449
450 /*
451  * Converts ring address to Vhost virtual address.
452  * If IOMMU is enabled, the ring address is a guest IO virtual address,
453  * else it is a QEMU virtual address.
454  */
455 static uint64_t
456 ring_addr_to_vva(struct virtio_net *dev, struct vhost_virtqueue *vq,
457                 uint64_t ra, uint64_t *size)
458 {
459         if (dev->features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)) {
460                 uint64_t vva;
461
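                /*
                 * With an IOMMU, the ring address is a guest IO virtual address:
                 * look it up in the IOTLB cache and, on a miss, ask the master
                 * for the missing translation.
                 */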
462                 vva = vhost_user_iotlb_cache_find(vq, ra,
463                                         size, VHOST_ACCESS_RW);
464                 if (!vva)
465                         vhost_user_iotlb_miss(dev, ra, VHOST_ACCESS_RW);
466
467                 return vva;
468         }
469
470         return qva_to_vva(dev, ra, size);
471 }
472
473 static struct virtio_net *
474 translate_ring_addresses(struct virtio_net *dev, int vq_index)
475 {
476         struct vhost_virtqueue *vq = dev->virtqueue[vq_index];
477         struct vhost_vring_addr *addr = &vq->ring_addrs;
478         uint64_t len;
479
480         /* The addresses are converted from QEMU virtual to Vhost virtual. */
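        /* Nothing to do if the desc, avail and used rings are already mapped. */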
481         if (vq->desc && vq->avail && vq->used)
482                 return dev;
483
484         len = sizeof(struct vring_desc) * vq->size;
485         vq->desc = (struct vring_desc *)(uintptr_t)ring_addr_to_vva(dev,
486                         vq, addr->desc_user_addr, &len);
487         if (vq->desc == 0 || len != sizeof(struct vring_desc) * vq->size) {
488                 RTE_LOG(DEBUG, VHOST_CONFIG,
489                         "(%d) failed to map desc ring.\n",
490                         dev->vid);
491                 return dev;
492         }
493
494         dev = numa_realloc(dev, vq_index);
495         vq = dev->virtqueue[vq_index];
496         addr = &vq->ring_addrs;
497
498         len = sizeof(struct vring_avail) + sizeof(uint16_t) * vq->size;
499         vq->avail = (struct vring_avail *)(uintptr_t)ring_addr_to_vva(dev,
500                         vq, addr->avail_user_addr, &len);
501         if (vq->avail == 0 ||
502                         len != sizeof(struct vring_avail) +
503                         sizeof(uint16_t) * vq->size) {
504                 RTE_LOG(DEBUG, VHOST_CONFIG,
505                         "(%d) failed to map avail ring.\n",
506                         dev->vid);
507                 return dev;
508         }
509
510         len = sizeof(struct vring_used) +
511                 sizeof(struct vring_used_elem) * vq->size;
512         vq->used = (struct vring_used *)(uintptr_t)ring_addr_to_vva(dev,
513                         vq, addr->used_user_addr, &len);
514         if (vq->used == 0 || len != sizeof(struct vring_used) +
515                         sizeof(struct vring_used_elem) * vq->size) {
516                 RTE_LOG(DEBUG, VHOST_CONFIG,
517                         "(%d) failed to map used ring.\n",
518                         dev->vid);
519                 return dev;
520         }
521
522         if (vq->last_used_idx != vq->used->idx) {
523                 RTE_LOG(WARNING, VHOST_CONFIG,
524                         "last_used_idx (%u) and vq->used->idx (%u) mismatches; "
525                         "some packets maybe resent for Tx and dropped for Rx\n",
526                         vq->last_used_idx, vq->used->idx);
527                 vq->last_used_idx  = vq->used->idx;
528                 vq->last_avail_idx = vq->used->idx;
529         }
530
531         vq->log_guest_addr = addr->log_guest_addr;
532
533         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address desc: %p\n",
534                         dev->vid, vq->desc);
535         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address avail: %p\n",
536                         dev->vid, vq->avail);
537         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address used: %p\n",
538                         dev->vid, vq->used);
539         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) log_guest_addr: %" PRIx64 "\n",
540                         dev->vid, vq->log_guest_addr);
541
542         return dev;
543 }
544
545 /*
546  * The virtio device sends us the desc, used and avail ring addresses.
547  * This function then converts these to our address space.
548  */
549 static int
550 vhost_user_set_vring_addr(struct virtio_net **pdev, VhostUserMsg *msg)
551 {
552         struct vhost_virtqueue *vq;
553         struct vhost_vring_addr *addr = &msg->payload.addr;
554         struct virtio_net *dev = *pdev;
555
556         if (dev->mem == NULL)
557                 return -1;
558
559         /* addr->index refers to the queue index. The txq is 1, rxq is 0. */
560         vq = dev->virtqueue[msg->payload.addr.index];
561
562         /*
563          * Ring addresses should not be interpreted as long as the ring is not
564          * started and enabled
565          */
566         memcpy(&vq->ring_addrs, addr, sizeof(*addr));
567
568         vring_invalidate(dev, vq);
569
570         if (vq->enabled && (dev->features &
571                                 (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
572                 dev = translate_ring_addresses(dev, msg->payload.addr.index);
573                 if (!dev)
574                         return -1;
575
576                 *pdev = dev;
577         }
578
579         return 0;
580 }
581
582 /*
583  * The virtio device sends us the base (resume) index of the available ring.
584  */
585 static int
586 vhost_user_set_vring_base(struct virtio_net *dev,
587                           VhostUserMsg *msg)
588 {
589         dev->virtqueue[msg->payload.state.index]->last_used_idx  =
590                         msg->payload.state.num;
591         dev->virtqueue[msg->payload.state.index]->last_avail_idx =
592                         msg->payload.state.num;
593
594         return 0;
595 }
596
597 static int
598 add_one_guest_page(struct virtio_net *dev, uint64_t guest_phys_addr,
599                    uint64_t host_phys_addr, uint64_t size)
600 {
601         struct guest_page *page, *last_page;
602
603         if (dev->nr_guest_pages == dev->max_guest_pages) {
604                 dev->max_guest_pages *= 2;
605                 dev->guest_pages = realloc(dev->guest_pages,
606                                         dev->max_guest_pages * sizeof(*page));
607                 if (!dev->guest_pages) {
608                         RTE_LOG(ERR, VHOST_CONFIG, "cannot realloc guest_pages\n");
609                         return -1;
610                 }
611         }
612
613         if (dev->nr_guest_pages > 0) {
614                 last_page = &dev->guest_pages[dev->nr_guest_pages - 1];
615                 /* merge if the two pages are contiguous */
616                 if (host_phys_addr == last_page->host_phys_addr +
617                                       last_page->size) {
618                         last_page->size += size;
619                         return 0;
620                 }
621         }
622
623         page = &dev->guest_pages[dev->nr_guest_pages++];
624         page->guest_phys_addr = guest_phys_addr;
625         page->host_phys_addr  = host_phys_addr;
626         page->size = size;
627
628         return 0;
629 }
630
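/*
 * Walk a guest memory region page by page and record each guest-physical to
 * host-physical mapping; only used when dequeue zero-copy is enabled (see the
 * caller in vhost_user_set_mem_table).
 */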
631 static int
632 add_guest_pages(struct virtio_net *dev, struct rte_vhost_mem_region *reg,
633                 uint64_t page_size)
634 {
635         uint64_t reg_size = reg->size;
636         uint64_t host_user_addr  = reg->host_user_addr;
637         uint64_t guest_phys_addr = reg->guest_phys_addr;
638         uint64_t host_phys_addr;
639         uint64_t size;
640
641         host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)host_user_addr);
642         size = page_size - (guest_phys_addr & (page_size - 1));
643         size = RTE_MIN(size, reg_size);
644
645         if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr, size) < 0)
646                 return -1;
647
648         host_user_addr  += size;
649         guest_phys_addr += size;
650         reg_size -= size;
651
652         while (reg_size > 0) {
653                 size = RTE_MIN(reg_size, page_size);
654                 host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)
655                                                   host_user_addr);
656                 if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr,
657                                 size) < 0)
658                         return -1;
659
660                 host_user_addr  += size;
661                 guest_phys_addr += size;
662                 reg_size -= size;
663         }
664
665         return 0;
666 }
667
668 #ifdef RTE_LIBRTE_VHOST_DEBUG
669 /* TODO: enable it only in debug mode? */
670 static void
671 dump_guest_pages(struct virtio_net *dev)
672 {
673         uint32_t i;
674         struct guest_page *page;
675
676         for (i = 0; i < dev->nr_guest_pages; i++) {
677                 page = &dev->guest_pages[i];
678
679                 RTE_LOG(INFO, VHOST_CONFIG,
680                         "guest physical page region %u\n"
681                         "\t guest_phys_addr: %" PRIx64 "\n"
682                         "\t host_phys_addr : %" PRIx64 "\n"
683                         "\t size           : %" PRIx64 "\n",
684                         i,
685                         page->guest_phys_addr,
686                         page->host_phys_addr,
687                         page->size);
688         }
689 }
690 #else
691 #define dump_guest_pages(dev)
692 #endif
693
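/*
 * Return true if the memory layout described by the new SET_MEM_TABLE
 * message differs from the regions currently mapped.
 */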
694 static bool
695 vhost_memory_changed(struct VhostUserMemory *new,
696                      struct rte_vhost_memory *old)
697 {
698         uint32_t i;
699
700         if (new->nregions != old->nregions)
701                 return true;
702
703         for (i = 0; i < new->nregions; ++i) {
704                 VhostUserMemoryRegion *new_r = &new->regions[i];
705                 struct rte_vhost_mem_region *old_r = &old->regions[i];
706
707                 if (new_r->guest_phys_addr != old_r->guest_phys_addr)
708                         return true;
709                 if (new_r->memory_size != old_r->size)
710                         return true;
711                 if (new_r->userspace_addr != old_r->guest_user_addr)
712                         return true;
713         }
714
715         return false;
716 }
717
718 static int
719 vhost_user_set_mem_table(struct virtio_net **pdev, struct VhostUserMsg *pmsg)
720 {
721         struct virtio_net *dev = *pdev;
722         struct VhostUserMemory memory = pmsg->payload.memory;
723         struct rte_vhost_mem_region *reg;
724         void *mmap_addr;
725         uint64_t mmap_size;
726         uint64_t mmap_offset;
727         uint64_t alignment;
728         uint32_t i;
729         int populate;
730         int fd;
731
732         if (memory.nregions > VHOST_MEMORY_MAX_NREGIONS) {
733                 RTE_LOG(ERR, VHOST_CONFIG,
734                         "too many memory regions (%u)\n", memory.nregions);
735                 return -1;
736         }
737
738         if (dev->mem && !vhost_memory_changed(&memory, dev->mem)) {
739                 RTE_LOG(INFO, VHOST_CONFIG,
740                         "(%d) memory regions not changed\n", dev->vid);
741
742                 for (i = 0; i < memory.nregions; i++)
743                         close(pmsg->fds[i]);
744
745                 return 0;
746         }
747
748         if (dev->mem) {
749                 free_mem_region(dev);
750                 rte_free(dev->mem);
751                 dev->mem = NULL;
752         }
753
754         dev->nr_guest_pages = 0;
755         if (!dev->guest_pages) {
756                 dev->max_guest_pages = 8;
757                 dev->guest_pages = malloc(dev->max_guest_pages *
758                                                 sizeof(struct guest_page));
759                 if (dev->guest_pages == NULL) {
760                         RTE_LOG(ERR, VHOST_CONFIG,
761                                 "(%d) failed to allocate memory "
762                                 "for dev->guest_pages\n",
763                                 dev->vid);
764                         return -1;
765                 }
766         }
767
768         dev->mem = rte_zmalloc("vhost-mem-table", sizeof(struct rte_vhost_memory) +
769                 sizeof(struct rte_vhost_mem_region) * memory.nregions, 0);
770         if (dev->mem == NULL) {
771                 RTE_LOG(ERR, VHOST_CONFIG,
772                         "(%d) failed to allocate memory for dev->mem\n",
773                         dev->vid);
774                 return -1;
775         }
776         dev->mem->nregions = memory.nregions;
777
778         for (i = 0; i < memory.nregions; i++) {
779                 fd  = pmsg->fds[i];
780                 reg = &dev->mem->regions[i];
781
782                 reg->guest_phys_addr = memory.regions[i].guest_phys_addr;
783                 reg->guest_user_addr = memory.regions[i].userspace_addr;
784                 reg->size            = memory.regions[i].memory_size;
785                 reg->fd              = fd;
786
787                 mmap_offset = memory.regions[i].mmap_offset;
788
789                 /* Check for memory_size + mmap_offset overflow */
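                /*
                 * In uint64_t arithmetic, -reg->size equals 2^64 - reg->size, so
                 * this rejects any offset for which mmap_offset + reg->size
                 * would wrap around.
                 */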
790                 if (mmap_offset >= -reg->size) {
791                         RTE_LOG(ERR, VHOST_CONFIG,
792                                 "mmap_offset (%#"PRIx64") and memory_size "
793                                 "(%#"PRIx64") overflow\n",
794                                 mmap_offset, reg->size);
795                         goto err_mmap;
796                 }
797
798                 mmap_size = reg->size + mmap_offset;
799
800                 /* Without the MAP_ANONYMOUS flag, mmap() must be called with a
801                  * length argument aligned to the hugepage size on older long-term
802                  * Linux kernels, such as 2.6.32 and 3.2.72, or it will fail
803                  * with EINVAL.
804                  *
805                  * To avoid that failure, make sure the length passed to mmap()
806                  * is kept aligned.
807                  */
808                 alignment = get_blk_size(fd);
809                 if (alignment == (uint64_t)-1) {
810                         RTE_LOG(ERR, VHOST_CONFIG,
811                                 "couldn't get hugepage size through fstat\n");
812                         goto err_mmap;
813                 }
814                 mmap_size = RTE_ALIGN_CEIL(mmap_size, alignment);
815
816                 populate = (dev->dequeue_zero_copy) ? MAP_POPULATE : 0;
817                 mmap_addr = mmap(NULL, mmap_size, PROT_READ | PROT_WRITE,
818                                  MAP_SHARED | populate, fd, 0);
819
820                 if (mmap_addr == MAP_FAILED) {
821                         RTE_LOG(ERR, VHOST_CONFIG,
822                                 "mmap region %u failed.\n", i);
823                         goto err_mmap;
824                 }
825
826                 reg->mmap_addr = mmap_addr;
827                 reg->mmap_size = mmap_size;
828                 reg->host_user_addr = (uint64_t)(uintptr_t)mmap_addr +
829                                       mmap_offset;
830
831                 if (dev->dequeue_zero_copy)
832                         if (add_guest_pages(dev, reg, alignment) < 0) {
833                                 RTE_LOG(ERR, VHOST_CONFIG,
834                                         "adding guest pages to region %u failed.\n",
835                                         i);
836                                 goto err_mmap;
837                         }
838
839                 RTE_LOG(INFO, VHOST_CONFIG,
840                         "guest memory region %u, size: 0x%" PRIx64 "\n"
841                         "\t guest physical addr: 0x%" PRIx64 "\n"
842                         "\t guest virtual  addr: 0x%" PRIx64 "\n"
843                         "\t host  virtual  addr: 0x%" PRIx64 "\n"
844                         "\t mmap addr : 0x%" PRIx64 "\n"
845                         "\t mmap size : 0x%" PRIx64 "\n"
846                         "\t mmap align: 0x%" PRIx64 "\n"
847                         "\t mmap off  : 0x%" PRIx64 "\n",
848                         i, reg->size,
849                         reg->guest_phys_addr,
850                         reg->guest_user_addr,
851                         reg->host_user_addr,
852                         (uint64_t)(uintptr_t)mmap_addr,
853                         mmap_size,
854                         alignment,
855                         mmap_offset);
856         }
857
858         for (i = 0; i < dev->nr_vring; i++) {
859                 struct vhost_virtqueue *vq = dev->virtqueue[i];
860
861                 if (vq->desc || vq->avail || vq->used) {
862                         /*
863                          * If the memory table got updated, the ring addresses
864                          * need to be translated again as virtual addresses have
865                          * changed.
866                          */
867                         vring_invalidate(dev, vq);
868
869                         dev = translate_ring_addresses(dev, i);
870                         if (!dev)
871                                 return -1;
872
873                         *pdev = dev;
874                 }
875         }
876
877         dump_guest_pages(dev);
878
879         return 0;
880
881 err_mmap:
882         free_mem_region(dev);
883         rte_free(dev->mem);
884         dev->mem = NULL;
885         return -1;
886 }
887
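/*
 * A virtqueue is ready once its desc/avail/used rings are mapped and both
 * the kick and call eventfds have been received.
 */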
888 static int
889 vq_is_ready(struct vhost_virtqueue *vq)
890 {
891         return vq && vq->desc && vq->avail && vq->used &&
892                vq->kickfd != VIRTIO_UNINITIALIZED_EVENTFD &&
893                vq->callfd != VIRTIO_UNINITIALIZED_EVENTFD;
894 }
895
896 static int
897 virtio_is_ready(struct virtio_net *dev)
898 {
899         struct vhost_virtqueue *vq;
900         uint32_t i;
901
902         if (dev->nr_vring == 0)
903                 return 0;
904
905         for (i = 0; i < dev->nr_vring; i++) {
906                 vq = dev->virtqueue[i];
907
908                 if (!vq_is_ready(vq))
909                         return 0;
910         }
911
912         RTE_LOG(INFO, VHOST_CONFIG,
913                 "virtio is now ready for processing.\n");
914         return 1;
915 }
916
917 static void
918 vhost_user_set_vring_call(struct virtio_net *dev, struct VhostUserMsg *pmsg)
919 {
920         struct vhost_vring_file file;
921         struct vhost_virtqueue *vq;
922
923         file.index = pmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
924         if (pmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
925                 file.fd = VIRTIO_INVALID_EVENTFD;
926         else
927                 file.fd = pmsg->fds[0];
928         RTE_LOG(INFO, VHOST_CONFIG,
929                 "vring call idx:%d file:%d\n", file.index, file.fd);
930
931         vq = dev->virtqueue[file.index];
932         if (vq->callfd >= 0)
933                 close(vq->callfd);
934
935         vq->callfd = file.fd;
936 }
937
938 static void
939 vhost_user_set_vring_kick(struct virtio_net **pdev, struct VhostUserMsg *pmsg)
940 {
941         struct vhost_vring_file file;
942         struct vhost_virtqueue *vq;
943         struct virtio_net *dev = *pdev;
944
945         file.index = pmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
946         if (pmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
947                 file.fd = VIRTIO_INVALID_EVENTFD;
948         else
949                 file.fd = pmsg->fds[0];
950         RTE_LOG(INFO, VHOST_CONFIG,
951                 "vring kick idx:%d file:%d\n", file.index, file.fd);
952
953         /* Interpret ring addresses only when ring is started. */
954         dev = translate_ring_addresses(dev, file.index);
955         if (!dev)
956                 return;
957
958         *pdev = dev;
959
960         vq = dev->virtqueue[file.index];
961
962         /*
963          * When VHOST_USER_F_PROTOCOL_FEATURES is not negotiated,
964          * the ring starts already enabled. Otherwise, it is enabled via
965          * the SET_VRING_ENABLE message.
966          */
967         if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)))
968                 vq->enabled = 1;
969
970         if (vq->kickfd >= 0)
971                 close(vq->kickfd);
972         vq->kickfd = file.fd;
973 }
974
975 static void
976 free_zmbufs(struct vhost_virtqueue *vq)
977 {
978         struct zcopy_mbuf *zmbuf, *next;
979
980         for (zmbuf = TAILQ_FIRST(&vq->zmbuf_list);
981              zmbuf != NULL; zmbuf = next) {
982                 next = TAILQ_NEXT(zmbuf, next);
983
984                 rte_pktmbuf_free(zmbuf->mbuf);
985                 TAILQ_REMOVE(&vq->zmbuf_list, zmbuf, next);
986         }
987
988         rte_free(vq->zmbufs);
989 }
990
991 /*
992  * When virtio is stopped, QEMU sends us the GET_VRING_BASE message.
993  */
994 static int
995 vhost_user_get_vring_base(struct virtio_net *dev,
996                           VhostUserMsg *msg)
997 {
998         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
999         struct rte_vdpa_device *vdpa_dev;
1000         int did = -1;
1001
1002         /* We have to stop the queue (virtio) if it is running. */
1003         if (dev->flags & VIRTIO_DEV_RUNNING) {
1004                 did = dev->vdpa_dev_id;
1005                 vdpa_dev = rte_vdpa_get_device(did);
1006                 if (vdpa_dev && vdpa_dev->ops->dev_close)
1007                         vdpa_dev->ops->dev_close(dev->vid);
1008                 dev->flags &= ~VIRTIO_DEV_RUNNING;
1009                 dev->notify_ops->destroy_device(dev->vid);
1010         }
1011
1012         dev->flags &= ~VIRTIO_DEV_READY;
1013         dev->flags &= ~VIRTIO_DEV_VDPA_CONFIGURED;
1014
1015         /* Here we are safe to get the last avail index */
1016         msg->payload.state.num = vq->last_avail_idx;
1017
1018         RTE_LOG(INFO, VHOST_CONFIG,
1019                 "vring base idx:%d file:%d\n", msg->payload.state.index,
1020                 msg->payload.state.num);
1021         /*
1022          * Based on the current QEMU vhost-user implementation, this message is
1023          * sent, and only sent, from vhost_vring_stop.
1024          * TODO: clean up the vring; it isn't usable from this point on.
1025          */
1026         if (vq->kickfd >= 0)
1027                 close(vq->kickfd);
1028
1029         vq->kickfd = VIRTIO_UNINITIALIZED_EVENTFD;
1030
1031         if (vq->callfd >= 0)
1032                 close(vq->callfd);
1033
1034         vq->callfd = VIRTIO_UNINITIALIZED_EVENTFD;
1035
1036         if (dev->dequeue_zero_copy)
1037                 free_zmbufs(vq);
1038         rte_free(vq->shadow_used_ring);
1039         vq->shadow_used_ring = NULL;
1040
1041         rte_free(vq->batch_copy_elems);
1042         vq->batch_copy_elems = NULL;
1043
1044         return 0;
1045 }
1046
1047 /*
1048  * When the virtio queues are ready to work, QEMU sends us a message to
1049  * enable the virtio queue pair.
1050  */
1051 static int
1052 vhost_user_set_vring_enable(struct virtio_net *dev,
1053                             VhostUserMsg *msg)
1054 {
1055         int enable = (int)msg->payload.state.num;
1056         int index = (int)msg->payload.state.index;
1057         struct rte_vdpa_device *vdpa_dev;
1058         int did = -1;
1059
1060         RTE_LOG(INFO, VHOST_CONFIG,
1061                 "set queue enable: %d to qp idx: %d\n",
1062                 enable, index);
1063
1064         did = dev->vdpa_dev_id;
1065         vdpa_dev = rte_vdpa_get_device(did);
1066         if (vdpa_dev && vdpa_dev->ops->set_vring_state)
1067                 vdpa_dev->ops->set_vring_state(dev->vid, index, enable);
1068
1069         if (dev->notify_ops->vring_state_changed)
1070                 dev->notify_ops->vring_state_changed(dev->vid,
1071                                 index, enable);
1072
1073         dev->virtqueue[index]->enabled = enable;
1074
1075         return 0;
1076 }
1077
1078 static void
1079 vhost_user_get_protocol_features(struct virtio_net *dev,
1080                                  struct VhostUserMsg *msg)
1081 {
1082         uint64_t features, protocol_features;
1083
1084         rte_vhost_driver_get_features(dev->ifname, &features);
1085         rte_vhost_driver_get_protocol_features(dev->ifname, &protocol_features);
1086
1087         /*
1088          * For now, the REPLY_ACK protocol feature is only mandatory for the
1089          * IOMMU feature. If IOMMU is explicitly disabled by the application,
1090          * also disable the REPLY_ACK feature to cope with older buggy
1091          * QEMU versions (from v2.7.0 to v2.9.0).
1092          */
1093         if (!(features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)))
1094                 protocol_features &= ~(1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
1095
1096         msg->payload.u64 = protocol_features;
1097         msg->size = sizeof(msg->payload.u64);
1098 }
1099
1100 static void
1101 vhost_user_set_protocol_features(struct virtio_net *dev,
1102                                  uint64_t protocol_features)
1103 {
1104         if (protocol_features & ~VHOST_USER_PROTOCOL_FEATURES)
1105                 return;
1106
1107         dev->protocol_features = protocol_features;
1108 }
1109
1110 static int
1111 vhost_user_set_log_base(struct virtio_net *dev, struct VhostUserMsg *msg)
1112 {
1113         int fd = msg->fds[0];
1114         uint64_t size, off;
1115         void *addr;
1116
1117         if (fd < 0) {
1118                 RTE_LOG(ERR, VHOST_CONFIG, "invalid log fd: %d\n", fd);
1119                 return -1;
1120         }
1121
1122         if (msg->size != sizeof(VhostUserLog)) {
1123                 RTE_LOG(ERR, VHOST_CONFIG,
1124                         "invalid log base msg size: %"PRId32" != %d\n",
1125                         msg->size, (int)sizeof(VhostUserLog));
1126                 return -1;
1127         }
1128
1129         size = msg->payload.log.mmap_size;
1130         off  = msg->payload.log.mmap_offset;
1131
1132         /* Don't allow mmap_offset to point outside the mmap region */
1133         if (off > size) {
1134                 RTE_LOG(ERR, VHOST_CONFIG,
1135                         "log offset %#"PRIx64" exceeds log size %#"PRIx64"\n",
1136                         off, size);
1137                 return -1;
1138         }
1139
1140         RTE_LOG(INFO, VHOST_CONFIG,
1141                 "log mmap size: %"PRId64", offset: %"PRId64"\n",
1142                 size, off);
1143
1144         /*
1145          * mmap from offset 0 to work around a hugepage mmap bug: mmap will
1146          * fail when the offset is not page-size aligned.
1147          */
1148         addr = mmap(0, size + off, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
1149         close(fd);
1150         if (addr == MAP_FAILED) {
1151                 RTE_LOG(ERR, VHOST_CONFIG, "mmap log base failed!\n");
1152                 return -1;
1153         }
1154
1155         /*
1156          * Free previously mapped log memory in case multiple
1157          * VHOST_USER_SET_LOG_BASE messages are received.
1158          */
1159         if (dev->log_addr) {
1160                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
1161         }
1162         dev->log_addr = (uint64_t)(uintptr_t)addr;
1163         dev->log_base = dev->log_addr + off;
1164         dev->log_size = size;
1165
1166         return 0;
1167 }
1168
1169 /*
1170  * An RARP packet is constructed and broadcast to notify switches about
1171  * the new location of the migrated VM, so that packets from outside will
1172  * not be lost after migration.
1173  *
1174  * However, we don't actually "send" an RARP packet here; instead, we set
1175  * a flag 'broadcast_rarp' to let rte_vhost_dequeue_burst() inject it.
1176  */
1177 static int
1178 vhost_user_send_rarp(struct virtio_net *dev, struct VhostUserMsg *msg)
1179 {
1180         uint8_t *mac = (uint8_t *)&msg->payload.u64;
1181         struct rte_vdpa_device *vdpa_dev;
1182         int did = -1;
1183
1184         RTE_LOG(DEBUG, VHOST_CONFIG,
1185                 ":: mac: %02x:%02x:%02x:%02x:%02x:%02x\n",
1186                 mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]);
1187         memcpy(dev->mac.addr_bytes, mac, 6);
1188
1189         /*
1190          * Set the flag to inject a RARP broadcast packet at
1191          * rte_vhost_dequeue_burst().
1192          *
1193          * rte_smp_wmb() is for making sure the mac is copied
1194          * before the flag is set.
1195          */
1196         rte_smp_wmb();
1197         rte_atomic16_set(&dev->broadcast_rarp, 1);
1198         did = dev->vdpa_dev_id;
1199         vdpa_dev = rte_vdpa_get_device(did);
1200         if (vdpa_dev && vdpa_dev->ops->migration_done)
1201                 vdpa_dev->ops->migration_done(dev->vid);
1202
1203         return 0;
1204 }
1205
1206 static int
1207 vhost_user_net_set_mtu(struct virtio_net *dev, struct VhostUserMsg *msg)
1208 {
1209         if (msg->payload.u64 < VIRTIO_MIN_MTU ||
1210                         msg->payload.u64 > VIRTIO_MAX_MTU) {
1211                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid MTU size (%"PRIu64")\n",
1212                                 msg->payload.u64);
1213
1214                 return -1;
1215         }
1216
1217         dev->mtu = msg->payload.u64;
1218
1219         return 0;
1220 }
1221
1222 static int
1223 vhost_user_set_req_fd(struct virtio_net *dev, struct VhostUserMsg *msg)
1224 {
1225         int fd = msg->fds[0];
1226
1227         if (fd < 0) {
1228                 RTE_LOG(ERR, VHOST_CONFIG,
1229                                 "Invalid file descriptor for slave channel (%d)\n",
1230                                 fd);
1231                 return -1;
1232         }
1233
1234         dev->slave_req_fd = fd;
1235
1236         return 0;
1237 }
1238
1239 static int
1240 is_vring_iotlb_update(struct vhost_virtqueue *vq, struct vhost_iotlb_msg *imsg)
1241 {
1242         struct vhost_vring_addr *ra;
1243         uint64_t start, end;
1244
1245         start = imsg->iova;
1246         end = start + imsg->size;
1247
1248         ra = &vq->ring_addrs;
1249         if (ra->desc_user_addr >= start && ra->desc_user_addr < end)
1250                 return 1;
1251         if (ra->avail_user_addr >= start && ra->avail_user_addr < end)
1252                 return 1;
1253         if (ra->used_user_addr >= start && ra->used_user_addr < end)
1254                 return 1;
1255
1256         return 0;
1257 }
1258
1259 static int
1260 is_vring_iotlb_invalidate(struct vhost_virtqueue *vq,
1261                                 struct vhost_iotlb_msg *imsg)
1262 {
1263         uint64_t istart, iend, vstart, vend;
1264
1265         istart = imsg->iova;
1266         iend = istart + imsg->size - 1;
1267
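        /*
         * Two inclusive ranges [istart, iend] and [vstart, vend] overlap
         * iff vstart <= iend && istart <= vend.
         */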
1268         vstart = (uintptr_t)vq->desc;
1269         vend = vstart + sizeof(struct vring_desc) * vq->size - 1;
1270         if (vstart <= iend && istart <= vend)
1271                 return 1;
1272
1273         vstart = (uintptr_t)vq->avail;
1274         vend = vstart + sizeof(struct vring_avail);
1275         vend += sizeof(uint16_t) * vq->size - 1;
1276         if (vstart <= iend && istart <= vend)
1277                 return 1;
1278
1279         vstart = (uintptr_t)vq->used;
1280         vend = vstart + sizeof(struct vring_used);
1281         vend += sizeof(struct vring_used_elem) * vq->size - 1;
1282         if (vstart <= iend && istart <= vend)
1283                 return 1;
1284
1285         return 0;
1286 }
1287
1288 static int
1289 vhost_user_iotlb_msg(struct virtio_net **pdev, struct VhostUserMsg *msg)
1290 {
1291         struct virtio_net *dev = *pdev;
1292         struct vhost_iotlb_msg *imsg = &msg->payload.iotlb;
1293         uint16_t i;
1294         uint64_t vva, len;
1295
1296         switch (imsg->type) {
1297         case VHOST_IOTLB_UPDATE:
1298                 len = imsg->size;
1299                 vva = qva_to_vva(dev, imsg->uaddr, &len);
1300                 if (!vva)
1301                         return -1;
1302
1303                 for (i = 0; i < dev->nr_vring; i++) {
1304                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1305
1306                         vhost_user_iotlb_cache_insert(vq, imsg->iova, vva,
1307                                         len, imsg->perm);
1308
1309                         if (is_vring_iotlb_update(vq, imsg))
1310                                 *pdev = dev = translate_ring_addresses(dev, i);
1311                 }
1312                 break;
1313         case VHOST_IOTLB_INVALIDATE:
1314                 for (i = 0; i < dev->nr_vring; i++) {
1315                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1316
1317                         vhost_user_iotlb_cache_remove(vq, imsg->iova,
1318                                         imsg->size);
1319
1320                         if (is_vring_iotlb_invalidate(vq, imsg))
1321                                 vring_invalidate(dev, vq);
1322                 }
1323                 break;
1324         default:
1325                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid IOTLB message type (%d)\n",
1326                                 imsg->type);
1327                 return -1;
1328         }
1329
1330         return 0;
1331 }
1332
1333 /* Return the number of bytes read on success, or a negative value on failure. */
1334 static int
1335 read_vhost_message(int sockfd, struct VhostUserMsg *msg)
1336 {
1337         int ret;
1338
1339         ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
1340                 msg->fds, VHOST_MEMORY_MAX_NREGIONS);
1341         if (ret <= 0)
1342                 return ret;
1343
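        /* Header read successfully; now read the variable-sized payload, if any. */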
1344         if (msg && msg->size) {
1345                 if (msg->size > sizeof(msg->payload)) {
1346                         RTE_LOG(ERR, VHOST_CONFIG,
1347                                 "invalid msg size: %d\n", msg->size);
1348                         return -1;
1349                 }
1350                 ret = read(sockfd, &msg->payload, msg->size);
1351                 if (ret <= 0)
1352                         return ret;
1353                 if (ret != (int)msg->size) {
1354                         RTE_LOG(ERR, VHOST_CONFIG,
1355                                 "read control message failed\n");
1356                         return -1;
1357                 }
1358         }
1359
1360         return ret;
1361 }
1362
1363 static int
1364 send_vhost_message(int sockfd, struct VhostUserMsg *msg, int *fds, int fd_num)
1365 {
1366         if (!msg)
1367                 return 0;
1368
1369         return send_fd_message(sockfd, (char *)msg,
1370                 VHOST_USER_HDR_SIZE + msg->size, fds, fd_num);
1371 }
1372
1373 static int
1374 send_vhost_reply(int sockfd, struct VhostUserMsg *msg)
1375 {
1376         if (!msg)
1377                 return 0;
1378
1379         msg->flags &= ~VHOST_USER_VERSION_MASK;
1380         msg->flags &= ~VHOST_USER_NEED_REPLY;
1381         msg->flags |= VHOST_USER_VERSION;
1382         msg->flags |= VHOST_USER_REPLY_MASK;
1383
1384         return send_vhost_message(sockfd, msg, NULL, 0);
1385 }
1386
1387 /*
1388  * Allocate a queue pair if it hasn't been allocated yet
1389  */
1390 static int
1391 vhost_user_check_and_alloc_queue_pair(struct virtio_net *dev, VhostUserMsg *msg)
1392 {
1393         uint16_t vring_idx;
1394
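        /* The vring index is carried in a different payload field depending on the request type. */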
1395         switch (msg->request.master) {
1396         case VHOST_USER_SET_VRING_KICK:
1397         case VHOST_USER_SET_VRING_CALL:
1398         case VHOST_USER_SET_VRING_ERR:
1399                 vring_idx = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1400                 break;
1401         case VHOST_USER_SET_VRING_NUM:
1402         case VHOST_USER_SET_VRING_BASE:
1403         case VHOST_USER_SET_VRING_ENABLE:
1404                 vring_idx = msg->payload.state.index;
1405                 break;
1406         case VHOST_USER_SET_VRING_ADDR:
1407                 vring_idx = msg->payload.addr.index;
1408                 break;
1409         default:
1410                 return 0;
1411         }
1412
1413         if (vring_idx >= VHOST_MAX_VRING) {
1414                 RTE_LOG(ERR, VHOST_CONFIG,
1415                         "invalid vring index: %u\n", vring_idx);
1416                 return -1;
1417         }
1418
1419         if (dev->virtqueue[vring_idx])
1420                 return 0;
1421
1422         return alloc_vring_queue(dev, vring_idx);
1423 }
1424
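/*
 * Take the access_lock of every allocated virtqueue so the message handler
 * can safely modify device state while the datapath might otherwise be
 * processing the rings.
 */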
1425 static void
1426 vhost_user_lock_all_queue_pairs(struct virtio_net *dev)
1427 {
1428         unsigned int i = 0;
1429         unsigned int vq_num = 0;
1430
1431         while (vq_num < dev->nr_vring) {
1432                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1433
1434                 if (vq) {
1435                         rte_spinlock_lock(&vq->access_lock);
1436                         vq_num++;
1437                 }
1438                 i++;
1439         }
1440 }
1441
1442 static void
1443 vhost_user_unlock_all_queue_pairs(struct virtio_net *dev)
1444 {
1445         unsigned int i = 0;
1446         unsigned int vq_num = 0;
1447
1448         while (vq_num < dev->nr_vring) {
1449                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1450
1451                 if (vq) {
1452                         rte_spinlock_unlock(&vq->access_lock);
1453                         vq_num++;
1454                 }
1455                 i++;
1456         }
1457 }
1458
1459 int
1460 vhost_user_msg_handler(int vid, int fd)
1461 {
1462         struct virtio_net *dev;
1463         struct VhostUserMsg msg;
1464         struct rte_vdpa_device *vdpa_dev;
1465         int did = -1;
1466         int ret;
1467         int unlock_required = 0;
1468         uint32_t skip_master = 0;
1469
1470         dev = get_device(vid);
1471         if (dev == NULL)
1472                 return -1;
1473
1474         if (!dev->notify_ops) {
1475                 dev->notify_ops = vhost_driver_callback_get(dev->ifname);
1476                 if (!dev->notify_ops) {
1477                         RTE_LOG(ERR, VHOST_CONFIG,
1478                                 "failed to get callback ops for driver %s\n",
1479                                 dev->ifname);
1480                         return -1;
1481                 }
1482         }
1483
1484         ret = read_vhost_message(fd, &msg);
1485         if (ret <= 0 || msg.request.master >= VHOST_USER_MAX) {
1486                 if (ret < 0)
1487                         RTE_LOG(ERR, VHOST_CONFIG,
1488                                 "vhost read message failed\n");
1489                 else if (ret == 0)
1490                         RTE_LOG(INFO, VHOST_CONFIG,
1491                                 "vhost peer closed\n");
1492                 else
1493                         RTE_LOG(ERR, VHOST_CONFIG,
1494                                 "vhost read incorrect message\n");
1495
1496                 return -1;
1497         }
1498
1499         ret = 0;
1500         if (msg.request.master != VHOST_USER_IOTLB_MSG)
1501                 RTE_LOG(INFO, VHOST_CONFIG, "read message %s\n",
1502                         vhost_message_str[msg.request.master]);
1503         else
1504                 RTE_LOG(DEBUG, VHOST_CONFIG, "read message %s\n",
1505                         vhost_message_str[msg.request.master]);
1506
1507         ret = vhost_user_check_and_alloc_queue_pair(dev, &msg);
1508         if (ret < 0) {
1509                 RTE_LOG(ERR, VHOST_CONFIG,
1510                         "failed to alloc queue\n");
1511                 return -1;
1512         }
1513
1514         /*
1515          * Note: we don't lock all queues on VHOST_USER_GET_VRING_BASE
1516          * and VHOST_USER_RESET_OWNER, since they are sent when virtio stops
1517          * and the device is destroyed. destroy_device waits for queues to be
1518          * inactive, so it is safe. Otherwise taking the access_lock
1519          * would cause a deadlock.
1520          */
        switch (msg.request.master) {
        case VHOST_USER_SET_FEATURES:
        case VHOST_USER_SET_PROTOCOL_FEATURES:
        case VHOST_USER_SET_OWNER:
        case VHOST_USER_SET_MEM_TABLE:
        case VHOST_USER_SET_LOG_BASE:
        case VHOST_USER_SET_LOG_FD:
        case VHOST_USER_SET_VRING_NUM:
        case VHOST_USER_SET_VRING_ADDR:
        case VHOST_USER_SET_VRING_BASE:
        case VHOST_USER_SET_VRING_KICK:
        case VHOST_USER_SET_VRING_CALL:
        case VHOST_USER_SET_VRING_ERR:
        case VHOST_USER_SET_VRING_ENABLE:
        case VHOST_USER_SEND_RARP:
        case VHOST_USER_NET_SET_MTU:
        case VHOST_USER_SET_SLAVE_REQ_FD:
                vhost_user_lock_all_queue_pairs(dev);
                unlock_required = 1;
                break;
        default:
                break;
        }

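        /*
         * Give the external backend registered through dev->extern_ops
         * (e.g. the vhost crypto backend) a first look at the request.
         * It may ask for a reply to be sent and/or for the built-in
         * master handlers below to be skipped entirely.
         */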
        if (dev->extern_ops.pre_msg_handle) {
                uint32_t need_reply;

                ret = (*dev->extern_ops.pre_msg_handle)(dev->vid,
                                (void *)&msg, &need_reply, &skip_master);
                if (ret < 0)
                        goto skip_to_reply;

                if (need_reply)
                        send_vhost_reply(fd, &msg);

                if (skip_master)
                        goto skip_to_post_handle;
        }

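        /*
         * Dispatch the request to the built-in master message handlers.
         * Handlers that produce data fill in msg and reply on the same
         * socket; ret records the error status that may be reported back
         * through the VHOST_USER_NEED_REPLY acknowledgement below.
         */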
        switch (msg.request.master) {
        case VHOST_USER_GET_FEATURES:
                msg.payload.u64 = vhost_user_get_features(dev);
                msg.size = sizeof(msg.payload.u64);
                send_vhost_reply(fd, &msg);
                break;
        case VHOST_USER_SET_FEATURES:
                ret = vhost_user_set_features(dev, msg.payload.u64);
                if (ret) {
                        /* Drop the queue locks taken above before bailing out. */
                        if (unlock_required)
                                vhost_user_unlock_all_queue_pairs(dev);
                        return -1;
                }
                break;

        case VHOST_USER_GET_PROTOCOL_FEATURES:
                vhost_user_get_protocol_features(dev, &msg);
                send_vhost_reply(fd, &msg);
                break;
        case VHOST_USER_SET_PROTOCOL_FEATURES:
                vhost_user_set_protocol_features(dev, msg.payload.u64);
                break;

        case VHOST_USER_SET_OWNER:
                vhost_user_set_owner();
                break;
        case VHOST_USER_RESET_OWNER:
                vhost_user_reset_owner(dev);
                break;

        case VHOST_USER_SET_MEM_TABLE:
                ret = vhost_user_set_mem_table(&dev, &msg);
                break;

        case VHOST_USER_SET_LOG_BASE:
                vhost_user_set_log_base(dev, &msg);

                /* it needs a reply */
                msg.size = sizeof(msg.payload.u64);
                send_vhost_reply(fd, &msg);
                break;
        case VHOST_USER_SET_LOG_FD:
                close(msg.fds[0]);
                RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
                break;

        case VHOST_USER_SET_VRING_NUM:
                vhost_user_set_vring_num(dev, &msg);
                break;
        case VHOST_USER_SET_VRING_ADDR:
                vhost_user_set_vring_addr(&dev, &msg);
                break;
        case VHOST_USER_SET_VRING_BASE:
                vhost_user_set_vring_base(dev, &msg);
                break;

        case VHOST_USER_GET_VRING_BASE:
                vhost_user_get_vring_base(dev, &msg);
                msg.size = sizeof(msg.payload.state);
                send_vhost_reply(fd, &msg);
                break;

        case VHOST_USER_SET_VRING_KICK:
                vhost_user_set_vring_kick(&dev, &msg);
                break;
        case VHOST_USER_SET_VRING_CALL:
                vhost_user_set_vring_call(dev, &msg);
                break;

        case VHOST_USER_SET_VRING_ERR:
                if (!(msg.payload.u64 & VHOST_USER_VRING_NOFD_MASK))
                        close(msg.fds[0]);
                RTE_LOG(INFO, VHOST_CONFIG, "not implemented\n");
                break;

        case VHOST_USER_GET_QUEUE_NUM:
                msg.payload.u64 = (uint64_t)vhost_user_get_queue_num(dev);
                msg.size = sizeof(msg.payload.u64);
                send_vhost_reply(fd, &msg);
                break;

        case VHOST_USER_SET_VRING_ENABLE:
                vhost_user_set_vring_enable(dev, &msg);
                break;
        case VHOST_USER_SEND_RARP:
                vhost_user_send_rarp(dev, &msg);
                break;

        case VHOST_USER_NET_SET_MTU:
                ret = vhost_user_net_set_mtu(dev, &msg);
                break;

        case VHOST_USER_SET_SLAVE_REQ_FD:
                ret = vhost_user_set_req_fd(dev, &msg);
                break;

        case VHOST_USER_IOTLB_MSG:
                ret = vhost_user_iotlb_msg(&dev, &msg);
                break;

        default:
                ret = -1;
                break;
        }

skip_to_post_handle:
        if (dev->extern_ops.post_msg_handle) {
                uint32_t need_reply;

                ret = (*dev->extern_ops.post_msg_handle)(
                                dev->vid, (void *)&msg, &need_reply);
                if (ret < 0)
                        goto skip_to_reply;

                if (need_reply)
                        send_vhost_reply(fd, &msg);
        }

skip_to_reply:
        if (unlock_required)
                vhost_user_unlock_all_queue_pairs(dev);

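        /*
         * If the master set the VHOST_USER_NEED_REPLY flag, acknowledge the
         * request explicitly: payload.u64 is 0 on success and 1 on failure.
         */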
        if (msg.flags & VHOST_USER_NEED_REPLY) {
                msg.payload.u64 = !!ret;
                msg.size = sizeof(msg.payload.u64);
                send_vhost_reply(fd, &msg);
        }

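        /*
         * Once every virtqueue is fully set up and the device is not yet
         * running, mark it ready and hand it over to the application
         * through the new_device() callback.
         */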
        if (!(dev->flags & VIRTIO_DEV_RUNNING) && virtio_is_ready(dev)) {
                dev->flags |= VIRTIO_DEV_READY;

                if (!(dev->flags & VIRTIO_DEV_RUNNING)) {
                        if (dev->dequeue_zero_copy) {
                                RTE_LOG(INFO, VHOST_CONFIG,
                                                "dequeue zero copy is enabled\n");
                        }

                        if (dev->notify_ops->new_device(dev->vid) == 0)
                                dev->flags |= VIRTIO_DEV_RUNNING;
                }
        }

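        /*
         * For devices driven by a vDPA backend, call the backend's
         * dev_conf() once the rings are ready and a
         * VHOST_USER_SET_VRING_ENABLE request has been handled.  The
         * VIRTIO_DEV_VDPA_CONFIGURED flag ensures this is done only once.
         */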
        did = dev->vdpa_dev_id;
        vdpa_dev = rte_vdpa_get_device(did);
        if (vdpa_dev && virtio_is_ready(dev) &&
                        !(dev->flags & VIRTIO_DEV_VDPA_CONFIGURED) &&
                        msg.request.master == VHOST_USER_SET_VRING_ENABLE) {
                if (vdpa_dev->ops->dev_conf)
                        vdpa_dev->ops->dev_conf(dev->vid);
                dev->flags |= VIRTIO_DEV_VDPA_CONFIGURED;
        }

        return 0;
}

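/*
 * Ask the vhost-user master to translate a guest IOVA that missed in the
 * IOTLB cache: build a VHOST_USER_SLAVE_IOTLB_MSG request of type
 * VHOST_IOTLB_MISS and send it on the slave request channel (slave_req_fd).
 */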
int
vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm)
{
        int ret;
        struct VhostUserMsg msg = {
                .request.slave = VHOST_USER_SLAVE_IOTLB_MSG,
                .flags = VHOST_USER_VERSION,
                .size = sizeof(msg.payload.iotlb),
                .payload.iotlb = {
                        .iova = iova,
                        .perm = perm,
                        .type = VHOST_IOTLB_MISS,
                },
        };

        ret = send_vhost_message(dev->slave_req_fd, &msg, NULL, 0);
        if (ret < 0) {
                RTE_LOG(ERR, VHOST_CONFIG,
                                "Failed to send IOTLB miss message (%d)\n",
                                ret);
                return ret;
        }

        return 0;
}