
LSD update: Generate Transformer Embeddings #5


Draft: wants to merge 6 commits into main
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
3.12.3
6 changes: 3 additions & 3 deletions lsd/config.py
@@ -82,7 +82,7 @@ class AutoencoderMultiverse(Multiverse):
    module: str = "lsd.generate.autoencoders"
    data_choices: str = "lsd/design/ae.yaml"
    model_choices: str = "lsd/design/ae.yaml"
-   implementation_choices: str = "/design/ae.yaml"
+   implementation_choices: str = "lsd/design/ae.yaml"


@dataclass
@@ -96,7 +96,7 @@ class DimReductionMultiverse(Multiverse):
    module: str = "lsd.generate.dim_reductions"
    model_choices: str = "lsd/design/dr.yaml"
    data_choices: str = "lsd/design/dr.yaml"
-   implementation_choices: str = "/design/dr.yaml"
+   implementation_choices: str = "lsd/design/dr.yaml"


@dataclass
@@ -113,4 +113,4 @@ class TransformerMultiverse(Multiverse):
    module: str = "lsd.generate.transformers"
    model_choices: str = "lsd/design/tf.yaml"
    data_choices: str = "lsd/design/tf.yaml"
-   implementation_choices: str = "/design/tf.yaml"
+   implementation_choices: str = "lsd/design/tf.yaml"
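
The fix in all three multiverses is the same: the old defaults such as "/design/ae.yaml" began with a leading slash and were therefore absolute paths that could never resolve inside a checkout. A minimal sketch to sanity-check the corrected defaults, assuming it is run from the repository root:

```python
from pathlib import Path

# The old defaults like "/design/ae.yaml" were absolute and never existed;
# the corrected ones resolve relative to the repository checkout.
for rel in ("lsd/design/ae.yaml", "lsd/design/dr.yaml", "lsd/design/tf.yaml"):
    print(rel, "exists:", Path(rel).exists())
```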
6 changes: 3 additions & 3 deletions lsd/generate/transformers/__init__.py
@@ -1,5 +1,5 @@
- from .configs import Mistral, Ada, MiniLM, arXiv, CNN
- from .tf import Transformer, Custom
+ from .configs import Mistral, Ada, MiniLM, arXiv, CNN, Tokenizer
+ from .tf import Transformer

__all__ = [
    "Mistral",
@@ -8,5 +8,5 @@
"arXiv",
"CNN",
"Transformer",
"Custom",
"Tokenizer",
]
91 changes: 62 additions & 29 deletions lsd/generate/transformers/configs.py
@@ -1,4 +1,4 @@
- from typing import Any, Protocol
+ from typing import Protocol
from dataclasses import dataclass


@@ -8,43 +8,50 @@


@dataclass
- class Transformer(Protocol):
-     pass
+ class PretrainedLanguageModel(Protocol):
+     module: str = "lsd.generate.transformers.models.pretrained"


@dataclass
- class PretrainedLanguageModel(Transformer):
-     pass
+ class Ada(PretrainedLanguageModel):
+     module: str = "lsd.generate.transformers.models.huggingface"
+     name: str = "Ada"
+     version: str = "v1"


@dataclass
- class ADA(PretrainedLanguageModel):
-     pass
+ class Mistral(PretrainedLanguageModel):
+     module: str = "lsd.generate.transformers.models.huggingface"
+     name: str = "Mistral"
+     version: str = "v1"


@dataclass
- class MISTRAL(PretrainedLanguageModel):
-     pass
+ class DistilRoberta(PretrainedLanguageModel):
+     module: str = "lsd.generate.transformers.models.huggingface"
+     name: str = "distilroberta-base"
+     version: str = "v1"


@dataclass
- class DISTILROBERTA(PretrainedLanguageModel):
-     pass
-
-
- @dataclass
- class MINILM(PretrainedLanguageModel):
-     pass
+ class MiniLM(PretrainedLanguageModel):
+     module: str = "lsd.generate.transformers.models.sbert"
+     name: str = "sentence-transformers/all-MiniLM-L6-v2"
+     version: str = "v1"


@dataclass
class MPNET(PretrainedLanguageModel):
-     pass
+     module: str = "lsd.generate.transformers.models.sbert"
+     name: str = "sentence-transformers/all-mpnet-base-v2"
+     version: str = "v1"


@dataclass
- class QA_DISTILBERT(PretrainedLanguageModel):
-     pass
+ class QA_DistilBert(PretrainedLanguageModel):
+     module: str = "lsd.generate.transformers.models.huggingface"
+     name: str = "distilbert-base-cased-distilled-squad"
+     version: str = "v1"


# ╭──────────────────────────────────────────────────────────╮
@@ -53,28 +60,42 @@ class QA_DISTILBERT(PretrainedLanguageModel):


@dataclass
- class Embedding(Protocol):
-     pass
+ class HuggingFaceData(Protocol):
+     name: str
+     version: str
+     split: str = "train"
+     host: str = "huggingface"


@dataclass
- class arXiv(Embedding):
-     pass
+ class arXiv(HuggingFaceData):
+     name: str = "arxiv"
+     version: str = "1.0.0"


@dataclass
- class BBC(Embedding):
-     pass
+ class BBC(HuggingFaceData):
+     name: str = "bbc"
+     version: str = "1.0.0"


@dataclass
- class CNN(Embedding):
-     pass
+ class CNN(HuggingFaceData):
+     name: str = "cnn_daily_mail"
+     version: str = "3.0.0"


@dataclass
- class Patents(Embedding):
-     pass
+ class Patents(HuggingFaceData):
+     name: str = "patents"
+     version: str = "1.0.0"
+
+
+ @dataclass
+ class LocalData(Protocol):
+     name: str
+     path: str
+     host: str = "local"


# ╭──────────────────────────────────────────────────────────╮
@@ -84,4 +105,16 @@ class Patents(Embedding):

@dataclass
class Implementation(Protocol):
    version: str


+ @dataclass
+ class Sampler(Implementation):
+     pass


+ @dataclass
+ class Tokenizer(Implementation):
+     name: str = "Tokenizer"
+     version: str = "v1"
+     aggregation: str = "mean"
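
To make the intent of these configs concrete, here is a minimal usage sketch (instantiation only; all field names and defaults are taken directly from this diff):

```python
from lsd.generate.transformers.configs import MiniLM, Tokenizer

model_cfg = MiniLM()
print(model_cfg.module)  # lsd.generate.transformers.models.sbert
print(model_cfg.name)    # sentence-transformers/all-MiniLM-L6-v2

impl_cfg = Tokenizer()
print(impl_cfg.aggregation)  # mean (pooling strategy for token embeddings)
```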
9 changes: 9 additions & 0 deletions lsd/generate/transformers/models/__init__.py
@@ -0,0 +1,9 @@
from .huggingface import HuggingFaceModel
from .sbert import SentenceTransformerModel
from .pretrained import BasePretrainedModel

__all__ = [
"BasePretrainedModel",
"HuggingFaceModel",
"SentenceTransformerModel"
]
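
Each models submodule is expected to expose a module-level `initialize()` factory (see `huggingface.py` below). A hedged sketch of how the generator might resolve a config's `module` string into a model class (the resolution mechanism itself is an assumption, since it is not part of this diff):

```python
import importlib


def resolve_model_class(config) -> type:
    # Hypothetical helper: import the module named by the config (e.g.
    # "lsd.generate.transformers.models.huggingface") and ask its
    # initialize() factory for the model class to instantiate.
    module = importlib.import_module(config.module)
    return module.initialize()
```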
143 changes: 143 additions & 0 deletions lsd/generate/transformers/models/huggingface.py
@@ -0,0 +1,143 @@
from transformers import AutoModel, AutoTokenizer
import torch
from typing import Union, List

from lsd.generate.transformers.models.pretrained import BasePretrainedModel


class HuggingFaceModel(BasePretrainedModel):
    """
    Implementation of a Pretrained Model using HuggingFace Transformers.

    This class handles loading transformer models from the HuggingFace Hub,
    processing text, and generating embeddings from the transformer's
    hidden states.

    Parameters
    ----------
    config : dict
        Configuration dictionary containing model-related parameters.
        Expected keys:
        - 'name': Model name/identifier on the HuggingFace Hub
        - 'version': Version of the model (optional)

    Attributes
    ----------
    config : dict
        Stores the configuration parameters provided during instantiation.
    model : AutoModel
        The loaded HuggingFace transformer model.
    tokenizer : AutoTokenizer
        The corresponding tokenizer for the model.

    Methods
    -------
    load_model()
        Loads the HuggingFace transformer model and tokenizer.
    process_text(text: Union[str, List[str]])
        Tokenizes and processes the input text.
    embed(text: Union[str, List[str]])
        Generates embeddings for the input text using the transformer.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = None
        self.tokenizer = None
        self.load_model()

    def load_model(self):
        """
        Loads the HuggingFace transformer model and tokenizer.

        The model is loaded from the `name` key in the config dictionary.
        Defaults to 'distilbert-base-uncased' if no name is provided.
        """
        model_name = self.config.get("name", "distilbert-base-uncased")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModel.from_pretrained(model_name)

            # Set model to evaluation mode
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load model '{model_name}': {e}") from e

    def process_text(self, text: Union[str, List[str]]):
        """
        Processes the text by tokenizing it using the HuggingFace tokenizer.

        Parameters
        ----------
        text : Union[str, List[str]]
            The input text to be processed. Can be a single string or a
            list of strings.

        Returns
        -------
        dict
            The tokenized version of the input text with attention masks.
        """
        if isinstance(text, str):
            text = [text]

        return self.tokenizer(
            text,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )

    def embed(self, text: Union[str, List[str]]):
        """
        Generates embeddings for the input text using the loaded model.

        Parameters
        ----------
        text : Union[str, List[str]]
            The input text to be embedded. Can be a single string or a
            list of strings.

        Returns
        -------
        torch.Tensor
            The generated embeddings for the input text. Uses mean pooling
            of the last hidden states.
        """
        # Process text
        inputs = self.process_text(text)

        # Generate embeddings
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Get last hidden states
        last_hidden_states = outputs.last_hidden_state

        # Apply mean pooling with attention mask
        attention_mask = inputs["attention_mask"]

        # Expand attention mask to match hidden state dimensions
        attention_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
        )

        # Apply mask and compute mean
        masked_embeddings = last_hidden_states * attention_mask_expanded
        summed_embeddings = torch.sum(masked_embeddings, dim=1)
        summed_mask = torch.clamp(attention_mask_expanded.sum(1), min=1e-9)

        # Mean pooling
        embeddings = summed_embeddings / summed_mask

        return embeddings


def initialize():
    """
    Returns the HuggingFaceModel class for the generator to instantiate.

    Returns
    -------
    type
        The HuggingFaceModel class itself (not an instance); callers
        construct it with a config.
    """
    return HuggingFaceModel
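
A minimal usage sketch of the class above, assuming network access to the HuggingFace Hub; the config is passed as a plain dict, matching the `config.get(...)` lookups in `load_model`:

```python
from lsd.generate.transformers.models.huggingface import HuggingFaceModel

model = HuggingFaceModel({"name": "distilbert-base-uncased"})
embeddings = model.embed(["a first sentence", "a second, longer sentence"])
print(embeddings.shape)  # torch.Size([2, 768]) for DistilBERT's hidden size
```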
49 changes: 49 additions & 0 deletions lsd/generate/transformers/models/pretrained.py
@@ -0,0 +1,49 @@
from abc import ABC, abstractmethod


class BasePretrainedModel(ABC):
    """
    Abstract Base Class for Pretrained Transformer Models.

    Child classes must implement the following methods:
    - load_model
    - process_text
    - embed


    Parameters
    ----------
    config : ConfigType
        Configuration parameters for setting up the model. This can be a
        dictionary or a `ConfigType` object.

    Attributes
    ----------
    config : ConfigType
        Stores the configuration parameters provided during instantiation.

    Methods
    -------
    load_model()
        Abstract method for loading the pretrained model.
    process_text(text)
        Abstract method for tokenizing and preparing the input text.
    embed(text: str)
        Abstract method for embedding the input text.
    """

    def __init__(self, config):
        self.config = config

    @abstractmethod
    def load_model(self):
        pass

    @abstractmethod
    def process_text(self, text):
        # Tokenize and otherwise prepare raw text for the model.
        pass

    @abstractmethod
    def embed(self, text):
        pass
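
For reference, a toy subclass showing the surface area the ABC enforces (`EchoModel` is purely illustrative and not part of this PR):

```python
from lsd.generate.transformers.models.pretrained import BasePretrainedModel


class EchoModel(BasePretrainedModel):
    """Toy subclass: whitespace tokenizer, token-count 'embeddings'."""

    def load_model(self):
        self.model = None  # nothing to load for this stub

    def process_text(self, text):
        texts = [text] if isinstance(text, str) else text
        return [t.split() for t in texts]

    def embed(self, text):
        # Stand-in embedding: one scalar (token count) per input string,
        # e.g. EchoModel({}).embed(["two words"]) -> [2]
        return [len(tokens) for tokens in self.process_text(text)]
```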