
Commit ddb07ad

global progress for graphdit and lstm
1 parent 485903f commit ddb07ad


2 files changed: +42 -20 lines changed


torch_molecule/generator/graph_dit/modeling_graph_dit.py

Lines changed: 22 additions & 10 deletions
@@ -300,26 +300,30 @@ def fit(
 
         self.fitting_loss = []
         self.fitting_epoch = 0
+
+        # Calculate total steps for progress tracking
+        total_steps = self.epochs * len(train_loader)
+
+        # Initialize global progress bar
+        global_pbar = tqdm(total=total_steps, desc="Training Progress", disable=not self.verbose)
+
         for epoch in range(self.epochs):
-            train_losses = self._train_epoch(train_loader, optimizer, epoch)
+            train_losses = self._train_epoch(train_loader, optimizer, epoch, global_pbar)
             self.fitting_loss.append(np.mean(train_losses).item())
             if scheduler:
                 scheduler.step(np.mean(train_losses).item())
 
+        global_pbar.close()
         self.fitting_epoch = epoch
         self.is_fitted_ = True
         return self
 
-    def _train_epoch(self, train_loader, optimizer, epoch):
+    def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None):
         self.model.train()
         losses = []
-        iterator = (
-            tqdm(train_loader, desc="Training", leave=False)
-            if self.verbose
-            else train_loader
-        )
+        # Remove the local tqdm iterator since we're using global progress bar
         active_index = self.dataset_info["active_index"]
-        for step, batched_data in enumerate(iterator):
+        for step, batched_data in enumerate(train_loader):
             batched_data = batched_data.to(self.device)
             optimizer.zero_grad()
 
@@ -338,8 +342,16 @@ def _train_epoch(self, train_loader, optimizer, epoch):
             optimizer.step()
             losses.append(loss.item())
 
-            if self.verbose:
-                iterator.set_postfix({"Epoch": epoch, "Loss": f"{loss.item():.4f}", "Loss_X": f"{loss_X.item():.4f}", "Loss_E": f"{loss_E.item():.4f}"})
+            # Update global progress bar
+            if global_pbar is not None:
+                global_pbar.set_postfix({
+                    "Epoch": f"{epoch+1}/{self.epochs}",
+                    "Step": f"{step+1}/{len(train_loader)}",
+                    "Loss": f"{loss.item():.4f}",
+                    "Loss_X": f"{loss_X.item():.4f}",
+                    "Loss_E": f"{loss_E.item():.4f}"
+                })
+                global_pbar.update(1)
 
         return losses
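The same pattern appears in both files: fit() now creates one run-level tqdm bar sized to epochs * len(train_loader), hands it to _train_epoch(), advances it once per batch with a loss postfix, and closes it after the epoch loop, replacing the old per-epoch tqdm iterator. A minimal standalone sketch of that pattern follows; the toy model, data, and hyperparameters are illustrative stand-ins, not code from this repository.

import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# Toy setup (illustrative only): 256 samples, 8 features, batch size 32.
X = torch.randn(256, 8)
y = torch.randn(256, 1)
train_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
epochs = 5
verbose = True

# One bar for the whole run: one update per optimizer step.
total_steps = epochs * len(train_loader)
global_pbar = tqdm(total=total_steps, desc="Training Progress", disable=not verbose)

for epoch in range(epochs):
    model.train()
    for step, (xb, yb) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()

        # Same postfix fields the commit adds: epoch, step, and current loss.
        global_pbar.set_postfix({
            "Epoch": f"{epoch+1}/{epochs}",
            "Step": f"{step+1}/{len(train_loader)}",
            "Loss": f"{loss.item():.4f}",
        })
        global_pbar.update(1)

global_pbar.close()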

torch_molecule/generator/lstm/modeling_lstm.py

Lines changed: 20 additions & 10 deletions
@@ -178,25 +178,29 @@ def fit(
         self.fitting_loss = []
         self.fitting_epoch = 0
         criterion = torch.nn.CrossEntropyLoss()
+
+        # Calculate total steps for progress tracking
+        total_steps = self.epochs * len(train_loader)
+
+        # Initialize global progress bar
+        global_pbar = tqdm(total=total_steps, desc="Training Progress", disable=not self.verbose)
+
         for epoch in range(self.epochs):
-            train_losses = self._train_epoch(train_loader, optimizer, epoch, criterion)
+            train_losses = self._train_epoch(train_loader, optimizer, epoch, criterion, global_pbar)
             self.fitting_loss.append(np.mean(train_losses).item())
             if scheduler:
                 scheduler.step(np.mean(train_losses).item())
 
+        global_pbar.close()
         self.fitting_epoch = epoch
         self.is_fitted_ = True
         return self
 
-    def _train_epoch(self, train_loader, optimizer, epoch, criterion):
+    def _train_epoch(self, train_loader, optimizer, epoch, criterion, global_pbar=None):
         self.model.train()
         losses = []
-        iterator = (
-            tqdm(train_loader, desc="Training", leave=False)
-            if self.verbose
-            else train_loader
-        )
-        for step, batched_data in enumerate(iterator):
+
+        for step, batched_data in enumerate(train_loader):
             for i in range(len(batched_data)):
                 batched_data[i] = batched_data[i].to(self.device)
             optimizer.zero_grad()
 
@@ -208,8 +212,14 @@ def _train_epoch(self, train_loader, optimizer, epoch, criterion):
             optimizer.step()
             losses.append(loss.item())
 
-            if self.verbose:
-                iterator.set_postfix({"Epoch": epoch, "Loss": f"{loss.item():.4f}"})
+            # Update global progress bar
+            if global_pbar is not None:
+                global_pbar.set_postfix({
+                    "Epoch": f"{epoch+1}/{self.epochs}",
+                    "Step": f"{step+1}/{len(train_loader)}",
+                    "Loss": f"{loss.item():.4f}"
+                })
+                global_pbar.update(1)
 
         return losses
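Two details keep the new bar well-behaved in every configuration: passing disable=not self.verbose means a non-verbose run creates a disabled bar whose update, set_postfix, and close calls print nothing, and the global_pbar=None default lets _train_epoch run without any bar at all. A small sketch of both cases, using a hypothetical train_epoch stand-in rather than the repository's method:

from tqdm import tqdm

def train_epoch(batches, global_pbar=None):
    """Stand-in for _train_epoch: only touches the bar when one is provided."""
    losses = []
    for step, value in enumerate(batches):
        loss = float(value)  # placeholder for a real forward/backward step
        losses.append(loss)
        if global_pbar is not None:
            global_pbar.set_postfix({"Loss": f"{loss:.4f}"})
            global_pbar.update(1)
    return losses

batches = [0.9, 0.7, 0.5, 0.4]

# verbose=False case: the bar exists but is disabled, so nothing is printed.
silent_pbar = tqdm(total=len(batches), desc="Training Progress", disable=True)
train_epoch(batches, silent_pbar)
silent_pbar.close()

# No bar at all: the None default keeps the helper usable on its own.
train_epoch(batches)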
