
Commit 0f9dff2

fix: obtain list of gpus available based on user
1 parent 526a8cf commit 0f9dff2

File tree

5 files changed (+24, −22 lines)


nebula/controller.py

Lines changed: 8 additions & 9 deletions
@@ -113,9 +113,9 @@ async def get_least_memory_gpu():
     }
 
 
-@app.get("/available_gpu")
+@app.get("/available_gpus/")
 async def get_available_gpu():
-    available_gpu_index = None
+    available_gpus = []
 
     if importlib.util.find_spec("pynvml") is not None:
         try:
@@ -130,16 +130,15 @@ async def get_available_gpu():
                 memory_used_percent = (memory_info.used / memory_info.total) * 100
 
                 # Obtain available GPUs
-                if memory_used_percent < 5 and available_gpu_index is None:
-                    available_gpu_index = i
-
+                if memory_used_percent < 5:
+                    available_gpus.append(i)
+
+            return {
+                "available_gpus": available_gpus,
+            }
         except Exception:  # noqa: S110
             pass
 
-    return {
-        "available_gpu_index": available_gpu_index,
-    }
-
 
 class NebulaEventHandler(PatternMatchingEventHandler):
     """

nebula/core/training/lightning.py

Lines changed: 2 additions & 2 deletions
@@ -173,10 +173,10 @@ def create_logger(self):
     def create_trainer(self):
         # Create a new trainer and logger for each round
         self.create_logger()
-        num_gpus = torch.cuda.device_count()
+        num_gpus = len(self.config.participant["device_args"]["gpu_id"])
         if self.config.participant["device_args"]["accelerator"] == "gpu" and num_gpus > 0:
             # Use all available GPUs
-            if self.config.participant["device_args"]["gpu_id"] == -1:
+            if num_gpus > 1:
                 gpu_index = self.config.participant["device_args"]["idx"] % num_gpus
             # Use the selected GPU
             else:
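
With gpu_id now a list, num_gpus counts the GPUs assigned to this participant, and when more than one is assigned the node spreads participants over them by index. A small standalone sketch of that arithmetic, with made-up values in place of config.participant["device_args"]:

# Made-up values; in the trainer they come from config.participant["device_args"].
assigned_gpus = [0, 1, 2]   # "gpu_id": GPUs handed to this participant
participant_idx = 4         # "idx": the participant's position in the scenario

num_gpus = len(assigned_gpus)
if num_gpus > 1:
    # Same modulo as the diff above: participant 4 with 3 GPUs lands on slot 1.
    gpu_index = participant_idx % num_gpus
else:
    # With a single assigned GPU, the "use the selected GPU" branch applies instead.
    gpu_index = 0
print(gpu_index)  # 1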

nebula/frontend/app.py

Lines changed: 11 additions & 8 deletions
@@ -442,8 +442,8 @@ async def get_host_resources():
     return None
 
 
-async def get_available_gpu():
-    url = f"http://{settings.controller_host}:{settings.controller_port}/available_gpu"
+async def get_available_gpus():
+    url = f"http://{settings.controller_host}:{settings.controller_port}/available_gpus"
     async with aiohttp.ClientSession() as session:
         async with session.get(url) as response:
             if response.status == 200:
@@ -1229,15 +1229,18 @@ async def node_stopped(scenario_name: str, request: Request):
 
 async def assign_available_gpu(scenario_data, role):
     if scenario_data["accelerator"] == "cpu":
-        scenario_data["gpu_id"] = None
+        scenario_data["gpu_id"] = []
     else:
+        available_gpus = await get_available_gpus()
+
         if role == "user":
-            gpu = await get_available_gpu()
-            scenario_data["gpu_id"] = gpu.get("available_gpu_index")
+            json_available_gpus = available_gpus.pop()
+            scenario_data["gpu_id"] = json_available_gpus
         elif role == "admin":
-            scenario_data["gpu_id"] = -1
+            json_available_gpus = available_gpus
+            scenario_data["gpu_id"] = json_available_gpus
         else:
-            scenario_data["gpu_id"] = None
+            scenario_data["gpu_id"] = []
 
     return scenario_data
 
@@ -1266,7 +1269,7 @@ async def run_scenario(scenario_data, role, user):
         dataset=scenario_data["dataset"],
         rounds=scenario_data["rounds"],
         role=role,
-        gpu_id=scenario_data["gpu_id"]
+        gpu_id=json.dumps(scenario_data["gpu_id"])
     )
 
     # Run the actual scenario
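
In short, a cpu scenario gets an empty gpu_id, a user gets one entry popped from the controller's report, an admin keeps the whole thing, and whatever is chosen is JSON-encoded before it reaches the database. An illustrative sketch, assuming available_gpus ends up as a plain list of GPU indices:

import json

# Assumption for illustration: the controller reported GPUs 0 and 1 as free.
available_gpus = [0, 1]

user_gpu_id = available_gpus.pop()   # a single popped entry, e.g. 1
admin_gpu_id = available_gpus        # everything that remains, e.g. [0]

# run_scenario() serializes the chosen value before storing it.
print(json.dumps(user_gpu_id))    # 1
print(json.dumps(admin_gpu_id))   # [0]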

nebula/frontend/database.py

Lines changed: 2 additions & 2 deletions
@@ -110,7 +110,7 @@ async def initialize_databases():
                 rounds TEXT,
                 role TEXT,
                 username TEXT,
-                gpu_id INTEGER
+                gpu_id TEXT
             );
         """
         )
@@ -127,7 +127,7 @@ async def initialize_databases():
             "rounds": "TEXT",
             "role": "TEXT",
             "username": "TEXT",
-            "gpu_id" : "INTEGER",
+            "gpu_id" : "TEXT",
         }
         await ensure_columns(conn, "scenarios", desired_columns)
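
The column switches to TEXT because gpu_id is now stored as a JSON-encoded list rather than a single integer. A self-contained sketch of the round trip with an in-memory SQLite table (the schema fragment is illustrative, not the project's full scenarios table):

import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE scenarios (name TEXT, gpu_id TEXT)")

# Write the list the same way the frontend does: json.dumps(...) into a TEXT column.
conn.execute("INSERT INTO scenarios VALUES (?, ?)", ("demo", json.dumps([0, 1])))

# Read it back and decode the JSON string into a Python list.
(stored,) = conn.execute("SELECT gpu_id FROM scenarios WHERE name = ?", ("demo",)).fetchone()
print(json.loads(stored))  # [0, 1]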

nebula/scenarios.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def __init__(
             logginglevel (str): Logging level.
             report_status_data_queue (bool): Indicator to report information about the nodes of the scenario
             accelerator (str): Accelerator used.
-            gpu_id (int) : Id of the used gpu
+            gpu_id (list) : Id list of the used gpu
             network_subnet (str): Network subnet.
             network_gateway (str): Network gateway.
             epochs (int): Number of epochs.

0 commit comments
