diff --git a/tests/vhost-user-test.c b/tests/vhost-user-test.c index ffdd398f86..a39846e6fd 100644 --- a/tests/vhost-user-test.c +++ b/tests/vhost-user-test.c @@ -130,6 +130,13 @@ static VhostUserMsg m __attribute__ ((unused)); #define VHOST_USER_VERSION (0x1) /*****************************************************************************/ +enum { + TEST_FLAGS_OK, + TEST_FLAGS_DISCONNECT, + TEST_FLAGS_BAD, + TEST_FLAGS_END, +}; + typedef struct TestServer { gchar *socket_path; gchar *mig_path; @@ -143,6 +150,7 @@ typedef struct TestServer { int log_fd; uint64_t rings; bool test_fail; + int test_flags; int queues; } TestServer; @@ -292,6 +300,10 @@ static void chr_read(void *opaque, const uint8_t *buf, int size) if (s->queues > 1) { msg.payload.u64 |= 0x1ULL << VIRTIO_NET_F_MQ; } + if (s->test_flags >= TEST_FLAGS_BAD) { + msg.payload.u64 = 0; + s->test_flags = TEST_FLAGS_END; + } p = (uint8_t *) &msg; qemu_chr_fe_write_all(chr, p, VHOST_USER_HDR_SIZE + msg.size); break; @@ -299,6 +311,10 @@ static void chr_read(void *opaque, const uint8_t *buf, int size) case VHOST_USER_SET_FEATURES: g_assert_cmpint(msg.payload.u64 & (0x1ULL << VHOST_USER_F_PROTOCOL_FEATURES), !=, 0ULL); + if (s->test_flags == TEST_FLAGS_DISCONNECT) { + qemu_chr_disconnect(chr); + s->test_flags = TEST_FLAGS_BAD; + } break; case VHOST_USER_GET_PROTOCOL_FEATURES: @@ -424,6 +440,16 @@ static TestServer *test_server_new(const gchar *name) return server; } +static void chr_event(void *opaque, int event) +{ + TestServer *s = opaque; + + if (s->test_flags == TEST_FLAGS_END && + event == CHR_EVENT_CLOSED) { + s->test_flags = TEST_FLAGS_OK; + } +} + static void test_server_create_chr(TestServer *server, const gchar *opt) { gchar *chr_path; @@ -432,7 +458,8 @@ static void test_server_create_chr(TestServer *server, const gchar *opt) server->chr = qemu_chr_new(server->chr_name, chr_path, NULL); g_free(chr_path); - qemu_chr_add_handlers(server->chr, chr_can_read, chr_read, NULL, server); + qemu_chr_add_handlers(server->chr, chr_can_read, chr_read, + chr_event, server); } static void test_server_listen(TestServer *server) @@ -774,6 +801,34 @@ static void test_connect_fail(void) g_free(path); } +static void test_flags_mismatch_subprocess(void) +{ + TestServer *s = test_server_new("flags-mismatch"); + char *cmd; + + s->test_flags = TEST_FLAGS_DISCONNECT; + g_thread_new("connect", connect_thread, s); + cmd = GET_QEMU_CMDE(s, 2, ",server", ""); + qtest_start(cmd); + g_free(cmd); + + init_virtio_dev(s); + wait_for_fds(s); + wait_for_rings_started(s, 2); + + qtest_end(); + test_server_free(s); +} + +static void test_flags_mismatch(void) +{ + gchar *path = g_strdup_printf("/%s/vhost-user/flags-mismatch/subprocess", + qtest_get_arch()); + g_test_trap_subprocess(path, 0, 0); + g_test_trap_assert_passed(); + g_free(path); +} + #endif static QVirtioPCIDevice *virtio_net_pci_init(QPCIBus *bus, int slot) @@ -908,6 +963,9 @@ int main(int argc, char **argv) qtest_add_func("/vhost-user/connect-fail/subprocess", test_connect_fail_subprocess); qtest_add_func("/vhost-user/connect-fail", test_connect_fail); + qtest_add_func("/vhost-user/flags-mismatch/subprocess", + test_flags_mismatch_subprocess); + qtest_add_func("/vhost-user/flags-mismatch", test_flags_mismatch); #endif ret = g_test_run();