Skip to content

Commit 27245d7

Browse files
clean code and fix minor issues
1 parent 2481b19 commit 27245d7

File tree

9 files changed

+31
-727
lines changed

9 files changed

+31
-727
lines changed

nebula/addons/topologymanager.py

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -250,23 +250,6 @@ def generate_ring_topology(self, increase_convergence=False):
250250
"""
251251
self.__ring_topology(increase_convergence=increase_convergence)
252252

253-
def generate_custom_topology(self, topology):
254-
"""
255-
Sets the network topology to a custom topology provided by the user.
256-
257-
This method allows for the manual configuration of the network topology by directly assigning
258-
the `topology` argument to the internal `self.topology` attribute.
259-
260-
Args:
261-
topology (numpy.ndarray): A 2D array representing the custom network topology.
262-
The array should have dimensions (n_nodes, n_nodes) where `n_nodes`
263-
is the number of nodes in the network.
264-
265-
Returns:
266-
None: The method modifies the internal `self.topology` to the provided custom topology.
267-
"""
268-
self.topology = topology
269-
270253
def generate_random_topology(self, probability):
271254
"""
272255
Generates a random topology using Erdos-Renyi model with given probability.
@@ -281,62 +264,6 @@ def generate_random_topology(self, probability):
281264
self.topology = nx.to_numpy_array(random_graph, dtype=np.float32)
282265
np.fill_diagonal(self.topology, 0) # No self-loops
283266

284-
def get_matrix_adjacency_from_neighbors(self, neighbors):
285-
"""
286-
Generates an adjacency matrix from a list of neighbors.
287-
288-
This method constructs an adjacency matrix for the network based on the provided list of neighbors
289-
for each node. A 1 in the matrix at position (i, j) indicates that node i is a neighbor of node j,
290-
while a 0 indicates no connection.
291-
292-
Args:
293-
neighbors (list of lists): A list of lists where each sublist contains the indices of the neighbors
294-
for the corresponding node. The length of the outer list should be equal
295-
to the number of nodes in the network (`self.n_nodes`).
296-
297-
Returns:
298-
numpy.ndarray: A 2D adjacency matrix of shape (n_nodes, n_nodes), where n_nodes is the total number
299-
of nodes in the network. The matrix contains 1s where there is a connection and 0s
300-
where there is no connection.
301-
"""
302-
matrix_adjacency = np.zeros((self.n_nodes, self.n_nodes), dtype=np.float32)
303-
for i in range(self.n_nodes):
304-
for j in range(self.n_nodes):
305-
if i in neighbors[j]:
306-
matrix_adjacency[i, j] = 1
307-
return matrix_adjacency
308-
309-
def get_topology(self):
310-
"""
311-
Returns the network topology.
312-
313-
This method retrieves the current topology of the network. The behavior of the method depends on whether
314-
the network is symmetric or asymmetric. For both cases in this implementation, it simply returns the
315-
`self.topology`.
316-
317-
Returns:
318-
numpy.ndarray: The current topology of the network as a 2D numpy array. The topology represents the
319-
connectivity between nodes, where a value of 1 indicates a connection and 0 indicates
320-
no connection between the nodes.
321-
"""
322-
if self.b_symmetric:
323-
return self.topology
324-
else:
325-
return self.topology
326-
327-
def get_nodes(self):
328-
"""
329-
Returns the nodes in the network.
330-
331-
This method retrieves the current list of nodes in the network. Each node is represented by an array of
332-
three values (such as coordinates or identifiers) in the `self.nodes` attribute.
333-
334-
Returns:
335-
numpy.ndarray: A 2D numpy array representing the nodes in the network. Each row represents a node,
336-
and the columns may represent different properties (e.g., position, identifier, etc.).
337-
"""
338-
return self.nodes
339-
340267
@staticmethod
341268
def get_coordinates(random_geo=True):
342269
"""
@@ -393,20 +320,6 @@ def update_nodes(self, config_participants):
393320
"""
394321
self.nodes = config_participants
395322

396-
def get_node(self, node_idx):
397-
"""
398-
Retrieves the node information based on the given index.
399-
400-
This method returns the details of a specific node from the `nodes` attribute using its index.
401-
402-
Parameters:
403-
node_idx (int): The index of the node to retrieve from the `nodes` list.
404-
405-
Returns:
406-
numpy.ndarray: A tuple or array containing the node's information at the given index.
407-
"""
408-
return self.nodes[node_idx]
409-
410323
def get_neighbors_string(self, node_idx):
411324
"""
412325
Retrieves the neighbors of a given node as a string representation.

nebula/controller.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ async def get_least_memory_gpu():
9090
gpu_with_least_memory_index = None
9191

9292
if importlib.util.find_spec("pynvml") is not None:
93+
max_memory_used_percent = 50
9394
try:
9495
import pynvml
9596

nebula/core/aggregation/aggregator.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def __init__(self, config=None, engine=None):
4444
logging.info(f"[{self.__class__.__name__}] Starting Aggregator")
4545
self._federation_nodes = set()
4646
self._pending_models_to_aggregate = {}
47-
self._pending_models_to_aggregate_lock = Locker(name="pending_models_to_aggregate_lock", async_lock=True)
4847
self._aggregation_done_lock = Locker(name="aggregation_done_lock", async_lock=True)
4948
self._aggregation_waiting_skip = asyncio.Event()
5049

@@ -92,7 +91,7 @@ async def get_aggregation(self):
9291
await self.us.notify_if_all_updates_received()
9392
lock_task = asyncio.create_task(self._aggregation_done_lock.acquire_async(timeout=timeout))
9493
skip_task = asyncio.create_task(self._aggregation_waiting_skip.wait())
95-
done, pending = await asyncio.wait(
94+
done, _ = await asyncio.wait(
9695
[lock_task, skip_task],
9796
return_when=asyncio.FIRST_COMPLETED,
9897
)
@@ -131,13 +130,10 @@ async def get_aggregation(self):
131130
return aggregated_result
132131

133132
def print_model_size(self, model):
134-
total_params = 0
135133
total_memory = 0
136134

137135
for _, param in model.items():
138136
num_params = param.numel()
139-
total_params += num_params
140-
141137
memory_usage = param.element_size() * num_params
142138
total_memory += memory_usage
143139

nebula/core/aggregation/dualhistagg.py

Lines changed: 0 additions & 67 deletions
This file was deleted.

nebula/core/aggregation/fedavgSVM.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

nebula/core/engine.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ async def update_neighbors(self, removed_neighbor_addr, neighbors, remove=False)
376376
updt_nei_event = UpdateNeighborEvent(removed_neighbor_addr, remove)
377377
asyncio.create_task(EventManager.get_instance().publish_node_event(updt_nei_event))
378378

379-
async def broadcast_models_include(self, age: AggregationEvent):
379+
async def broadcast_models_include(self, aggregation_event: AggregationEvent):
380380
logging.info(f"🔄 Broadcasting MODELS_INCLUDED for round {self.get_round()}")
381381
message = self.cm.create_message(
382382
"federation", "federation_models_included", [str(arg) for arg in [self.get_round()]]
@@ -530,13 +530,6 @@ async def _waiting_model_updates(self):
530530
else:
531531
logging.error("Aggregation finished with no parameters")
532532

533-
def print_round_information(self):
534-
print_msg_box(
535-
msg=f"Round {self.round} of {self.total_rounds} started.",
536-
indent=2,
537-
title="Round information",
538-
)
539-
540533
def learning_cycle_finished(self):
541534
return not (self.round < self.total_rounds)
542535

@@ -707,8 +700,8 @@ def __init__(
707700

708701
async def _extended_learning_cycle(self):
709702
# Define the functionality of the aggregator node
710-
await self.trainer.train()
711703
await self.trainer.test()
704+
await self.trainer.train()
712705

713706
self_update_event = UpdateReceivedEvent(
714707
self.trainer.get_model_parameters(), self.trainer.get_model_weight(), self.addr, self.round
@@ -770,8 +763,8 @@ async def _extended_learning_cycle(self):
770763
# Define the functionality of the trainer node
771764
logging.info("Waiting global update | Assign _waiting_global_update = True")
772765

773-
await self.trainer.train()
774766
await self.trainer.test()
767+
await self.trainer.train()
775768

776769
self_update_event = UpdateReceivedEvent(
777770
self.trainer.get_model_parameters(), self.trainer.get_model_weight(), self.addr, self.round, local=True

0 commit comments

Comments
 (0)