Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vulkan] Prioritize discrete GPUs as device_id=0. #8588

Merged
merged 1 commit into from
Jul 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/runtime/vulkan/vulkan_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,27 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance,
device_name = properties.properties.deviceName;
driver_version = properties.properties.driverVersion;

switch (properties.properties.deviceType) {
case VK_PHYSICAL_DEVICE_TYPE_OTHER:
device_type = "other";
break;
case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU:
device_type = "integrated";
break;
case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU:
device_type = "discrete";
break;
case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU:
device_type = "virtual";
break;
case VK_PHYSICAL_DEVICE_TYPE_CPU:
device_type = "cpu";
break;
default:
LOG(FATAL) << "Unknown vulkan device type: " << properties.properties.deviceType;
break;
}

// By default, use the maximum API version that the driver allows,
// so that any supported features can be used by TVM shaders.
// However, if we can query the conformance version, then limit to
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/vulkan/vulkan_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ struct VulkanDeviceProperties {
uint32_t max_storage_buffer_range{1 << 27};
uint32_t max_per_stage_descriptor_storage_buffer{4};
uint32_t max_shared_memory_per_block{16384};
std::string device_name{"unknown device name"};
std::string device_type{"unknown_device_type"};
std::string device_name{"unknown_device_name"};
uint32_t driver_version{0};
uint32_t vulkan_api_version{VK_API_VERSION_1_0};
uint32_t max_spirv_version{0x10000};
Expand Down
26 changes: 24 additions & 2 deletions src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,28 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
devices_.push_back(std::move(device));
}
}

// Move discrete GPUs to the start of the list, so the default
// device_id=0 preferentially uses a discrete GPU.
auto preference = [](const VulkanDevice& device) {
const std::string& type = device.device_properties.device_type;
if (type == "discrete") {
return 0;
} else if (type == "integrated") {
return 1;
} else if (type == "virtual") {
return 2;
} else if (type == "cpu") {
return 3;
} else {
return 4;
}
};

std::stable_sort(devices_.begin(), devices_.end(),
[&preference](const VulkanDevice& a, const VulkanDevice& b) {
return preference(a) < preference(b);
});
}

VulkanDeviceAPI::~VulkanDeviceAPI() {}
Expand Down Expand Up @@ -214,8 +236,8 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property,
if (property == "max_shared_memory_per_block") {
*rv = int64_t(prop.max_shared_memory_per_block);
}
if (property == ":string device_name") {
*rv = prop.device_name;
if (property == "device_name") {
*rv = String(prop.device_name);
}
if (property == "driver_version") {
*rv = int64_t(prop.driver_version);
Expand Down
3 changes: 2 additions & 1 deletion src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ Map<String, ObjectRef> UpdateVulkanAttrs(Map<String, ObjectRef> attrs) {
"driver_version",
"vulkan_api_version",
"max_spirv_version"};
std::vector<const char*> str_opts = {"device_name"};
std::vector<const char*> str_opts = {"device_name", "device_type"};

for (auto& key : bool_opts) {
if (!attrs.count(key)) {
Expand Down Expand Up @@ -387,6 +387,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.add_attr_option<Integer>("max_per_stage_descriptor_storage_buffer")
.add_attr_option<Integer>("max_shared_memory_per_block")
// Other device properties
.add_attr_option<String>("device_type")
.add_attr_option<String>("device_name")
.add_attr_option<Integer>("driver_version")
.add_attr_option<Integer>("vulkan_api_version")
Expand Down