Skip to content

Commit ed28c03

Browse files
authored
v3.4.3
1 parent 22b3fe4 commit ed28c03

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/utilities.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@ def load_config(config_file):
1414
return yaml.safe_load(file)
1515

1616
def is_nvidia_gpu_available(config):
17-
return config["Compute_Device"]["gpu_brand"].upper() == "NVIDIA"
17+
gpu_brand = config.get("Compute_Device", {}).get("gpu_brand")
18+
19+
if isinstance(gpu_brand, str):
20+
normalized_gpu_brand = gpu_brand.strip().lower()
21+
return normalized_gpu_brand == "nvidia"
22+
return False
1823

1924
config = load_config('config.yaml')
2025

0 commit comments

Comments
 (0)