From 9fca69b41044b9a85b03286c3850ea6e53ef13c4 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sat, 10 Feb 2024 22:14:52 +0100 Subject: [PATCH] Add check for VK_KHR_portability_enumeration for MoltenVK support --- ggml-vulkan.cpp | 41 +++++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 33b8a90..37123ac 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -1100,23 +1100,44 @@ static void ggml_vk_instance_init() { #endif vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; - const std::vector layers = { + + const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); +#ifdef __APPLE__ + bool portability_enumeration_ext = false; + // Check for portability enumeration extension for MoltenVK support + for (const auto& properties : instance_extensions) { + if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { + portability_enumeration_ext = true; + break; + } + } + if (!portability_enumeration_ext) { + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; + } +#endif + + std::vector layers = { #ifdef GGML_VULKAN_VALIDATE "VK_LAYER_KHRONOS_validation", #endif }; - const std::vector extensions = { + std::vector extensions = { #ifdef GGML_VULKAN_VALIDATE "VK_EXT_validation_features", -#endif -#ifdef __APPLE__ - "VK_KHR_portability_enumeration", #endif }; - vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags(), &app_info, layers, extensions); #ifdef __APPLE__ - instance_create_info.flags = vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; + if (portability_enumeration_ext) { + extensions.push_back("VK_KHR_portability_enumeration"); + } #endif + vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); +#ifdef __APPLE__ + if (portability_enumeration_ext) { + instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; + } +#endif + #ifdef GGML_VULKAN_VALIDATE const std::vector features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; @@ -1175,12 +1196,12 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { vk_instance.devices[idx] = std::make_shared(); ctx->device = vk_instance.devices[idx]; ctx->device.lock()->physical_device = devices[dev_num]; - std::vector ext_props = ctx->device.lock()->physical_device.enumerateDeviceExtensionProperties(); + const std::vector ext_props = ctx->device.lock()->physical_device.enumerateDeviceExtensionProperties(); bool maintenance4_support = false; // Check if maintenance4 is supported - for (auto properties : ext_props) { + for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { maintenance4_support = true; } @@ -1211,7 +1232,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { bool fp16_storage = false; bool fp16_compute = false; - for (auto properties : ext_props) { + for (const auto& properties : ext_props) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { fp16_storage = true; } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {