diff --git a/type4py/data_loaders.py b/type4py/data_loaders.py index 2bb28cb..85d1521 100644 --- a/type4py/data_loaders.py +++ b/type4py/data_loaders.py @@ -1,9 +1,10 @@ from type4py import logger, MIN_DATA_POINTS -from typing import Tuple +from typing import Tuple, Dict from os.path import join from collections import Counter from time import time from torch.utils.data import TensorDataset, DataLoader +from tqdm import tqdm import torch import numpy as np @@ -407,9 +408,13 @@ def __init__(self, *in_sequences: torch.Tensor, labels: torch.Tensor, dataset_na self.labels = labels self.dataset_name = dataset_name self.train_mode = train_mode + self.precomputed_triplets: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {} self.get_item_func = self.get_item_train if self.train_mode else self.get_item_test + if self.train_mode: + self._precompute_triplets() + def get_item_train(self, index: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: """ @@ -419,16 +424,11 @@ def get_item_train(self, index: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], - The third tuple is different (data, label) from the given index """ - # Find a similar datapoint randomly - mask = self.labels == self.labels[index] - mask[index] = False # Making sure that the similar pair is NOT the same as the given index - mask = mask.nonzero() - a = mask[torch.randint(high=len(mask), size=(1,))][0] - - # Find a different datapoint randomly - mask = self.labels != self.labels[index] - mask = mask.nonzero() - b = mask[torch.randint(high=len(mask), size=(1,))][0] + label = self.labels[index].item() + # A similar datapoint + a = self.precomputed_triplets[label][0] + # A different datapoint + b = self.precomputed_triplets[label][1] return (self.data[index], self.labels[index]), (self.data[a.item()], self.labels[a.item()]), \ (self.data[b.item()], self.labels[b.item()]) @@ -442,3 +442,24 @@ def __getitem__(self, index: int) -> 
def __len__(self) -> int:
    """Return the number of data points held by the dataset."""
    return len(self.data)


def _precompute_triplets(self):
    """
    Pre-compute one (positive, negative) index pair per distinct label.

    For every distinct value in `self.labels`, `self.precomputed_triplets[label]` is set to a
    tuple `(p, n)` of one-element index tensors:

    - `p` is the index of a randomly chosen sample that has the same label, excluding the
      first occurrence of that label;
    - `n` is the index of a randomly chosen sample that has a different label.

    This speeds up the creation of training batches quite significantly. However, every
    training example of a given label shares ONLY one randomly-selected positive and negative
    pair. Previously, each training example could get a different negative/positive pair in
    every epoch.

    NOTE(review): because pairs are keyed by label and `get_item_train` looks them up by the
    anchor's label, the sample that was drawn as the positive for its label receives itself
    as its own positive when it is used as the anchor. Avoiding that requires keying the
    pairs by sample index, which also needs a change in `get_item_train`.

    Raises:
        RuntimeError: if a label has no other sample to serve as a positive, or if there is
            no sample with a different label to serve as a negative. (The failure used to
            surface as an opaque RuntimeError from `torch.randint(high=0, ...)`; the
            exception type is kept so existing handlers still apply.)
    """
    labels = self.labels

    # Iterating over a plain Python list avoids one `.item()` call per element; the list
    # elements are the same Python scalars that `.item()` would have produced as dict keys.
    for i, label in enumerate(tqdm(labels.tolist(), total=len(labels),
                                   desc="Pre-computing triplets")):
        if label in self.precomputed_triplets:
            continue

        # Positive candidates: same label, minus the sample that introduced the label.
        same_label = labels == label
        same_label[i] = False
        positives = same_label.nonzero()
        if len(positives) == 0:
            raise RuntimeError(
                f"Cannot pre-compute triplets: label {label!r} occurs only once, "
                f"so it has no positive example"
            )

        # Negative candidates: every sample with a different label.
        negatives = (labels != label).nonzero()
        if len(negatives) == 0:
            raise RuntimeError(
                f"Cannot pre-compute triplets: label {label!r} is the only label in the "
                f"dataset, so it has no negative example"
            )

        # Draw the positive first and the negative second, as before, so that seeded runs
        # keep producing the same pairs.
        p = positives[torch.randint(high=len(positives), size=(1,))][0]
        n = negatives[torch.randint(high=len(negatives), size=(1,))][0]

        self.precomputed_triplets[label] = (p, n)