Skip to content

Commit 2481b19

Browse files
minor changes in the workflow
1 parent 6af215a commit 2481b19

File tree

1 file changed

+2
-9
lines changed

1 file changed

+2
-9
lines changed

nebula/core/engine.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,6 @@ def __init__(
142142

143143
self._reporter = Reporter(config=self.config, trainer=self.trainer, cm=self.cm)
144144

145-
self.trainning_in_progress_lock = Locker(name="trainning_in_progress_lock", async_lock=True)
146-
147145
self._addon_manager = AddonManager(self, self.config)
148146

149147
@property
@@ -189,9 +187,6 @@ def get_federation_ready_lock(self):
189187
def get_federation_setup_lock(self):
190188
return self.federation_setup_lock
191189

192-
def get_trainning_in_progress_lock(self):
193-
return self.trainning_in_progress_lock
194-
195190
def get_round_lock(self):
196191
return self.round_lock
197192

@@ -712,10 +707,8 @@ def __init__(
712707

713708
async def _extended_learning_cycle(self):
714709
# Define the functionality of the aggregator node
715-
await self.trainer.test()
716-
await self.trainning_in_progress_lock.acquire_async()
717710
await self.trainer.train()
718-
await self.trainning_in_progress_lock.release_async()
711+
await self.trainer.test()
719712

720713
self_update_event = UpdateReceivedEvent(
721714
self.trainer.get_model_parameters(), self.trainer.get_model_weight(), self.addr, self.round
@@ -777,8 +770,8 @@ async def _extended_learning_cycle(self):
777770
# Define the functionality of the trainer node
778771
logging.info("Waiting global update | Assign _waiting_global_update = True")
779772

780-
await self.trainer.test()
781773
await self.trainer.train()
774+
await self.trainer.test()
782775

783776
self_update_event = UpdateReceivedEvent(
784777
self.trainer.get_model_parameters(), self.trainer.get_model_weight(), self.addr, self.round, local=True

0 commit comments

Comments
 (0)