
Commit 7935f98

Merge pull request #7 from iwknow/frank-dev
Fix types and enable multiple text/image inputs
2 parents: ab561c1 + def83c6

File tree

11 files changed: +128 additions, -71 deletions

.gitignore

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+
+# data folder contains downloadable public datasets
+data/
+
+# Editor
+.vscode/

lanistr/dataset/amazon/amazon_utils.py

Lines changed: 2 additions & 2 deletions

@@ -42,11 +42,11 @@ def load_multimodal_data(args: omegaconf.DictConfig) -> pd.DataFrame:
     A pandas DataFrame containing the loaded data.
   """
   if args.task == "pretrain":
-    path = os.path.join(args.data_dir, f"{args.category}_total.json.gz")
+    path = os.path.join(args.data_dir, f"{args.category}.json.gz")
     data = read_gzip(path)
   else:
     path_to_clean_data = os.path.join(
-        args.data_dir, f"{args.category}_total.csv"
+        args.data_dir, f"{args.category}.csv"
     )
     data = pd.read_csv(path_to_clean_data)
     data = data.reset_index(drop=True)
lanistr/dataset/amazon/download_images.py

Lines changed: 12 additions & 12 deletions

@@ -17,7 +17,7 @@
 import gzip
 import json
 import os
-from typing import List, Optional
+from typing import Dict, List, Optional

 import omegaconf
 import pandas as pd
@@ -60,7 +60,7 @@ def load_and_clean_meta_data(
   return metadata


-def read_gzip(name: str, args: omegaconf.DictConfig) -> List[dict[str, str]]:
+def read_gzip(name: str, args: omegaconf.DictConfig) -> List[Dict[str, str]]:
   """Reads a gzipped file and returns a list of JSON objects.

   Args:
@@ -103,8 +103,8 @@ def load_data(args: omegaconf.DictConfig) -> pd.DataFrame:
 def get_reviews(
     row: pd.Series,
     index: int,
-    nan_indices_summary: list[int],
-    nan_indices_review_text: list[int],
+    nan_indices_summary: List[int],
+    nan_indices_review_text: List[int],
 ) -> Optional[str]:
   """Extracts and cleans the review text from a row of data.
@@ -135,7 +135,7 @@ def get_reviews(


 def get_review_votes(
-    row: pd.Series, index: int, nan_indices_votes: list[int]
+    row: pd.Series, index: int, nan_indices_votes: List[int]
 ) -> int:
   """Extracts and cleans the review vote from a row of data.
@@ -181,7 +181,7 @@ def get_product_brands(


 def download_and_save_image(
-    image_data_dir: str, urls: list[str], index: int
+    image_data_dir: str, urls: List[str], index: int
 ) -> Optional[str]:
   """Downloads and saves the image from a URL.
@@ -236,7 +236,7 @@ def get_product_prices(


 def get_review_names(
-    row: pd.Series, index: int, nan_indices_reviewer_names: list[int]
+    row: pd.Series, index: int, nan_indices_reviewer_names: List[int]
 ) -> Optional[str]:
   """Extracts and cleans the reviewer name from a row of data.
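A note on the typing changes above: the annotations previously mixed builtin generics (list[int], dict[str, str]) with typing imports. Subscripting builtin types in annotations requires Python 3.9+ (PEP 585); on older interpreters the annotation is evaluated when the def executes and raises. A minimal sketch of the failure mode (function names are illustrative):

from typing import List

# On Python 3.8 and earlier this def raises at import time:
#   TypeError: 'type' object is not subscriptable
def bad(indices: list[int]) -> None:
  ...

# The typing generics this commit standardizes on also work on older versions.
def good(indices: List[int]) -> None:
  ...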
@@ -389,7 +389,7 @@ def main():
     # row = data.iloc[index]

     meta_row = meta_data.loc[meta_data['asin'] == row['asin'].item()]
-    meta_data_exists = True if meta_row else False
+    meta_data_exists = False if meta_row.empty else True

     amazon_image_exists = False
     user_image_exists = False
@@ -443,8 +443,8 @@ def main():

   categorical_cols = ['reviewerID', 'verified', 'asin', 'year']
   numerical_cols = ['vote', 'unixReviewTime']
-  image_col = ['ImageFileName']
-  text_col = ['Review']
+  image_cols = ['ImageFileName']
+  text_cols = ['Review']
   label_col = ['labels']

   d = pd.DataFrame()
@@ -474,8 +474,8 @@ def main():
       if item
       not in categorical_cols
       + numerical_cols
-      + image_col
-      + text_col
+      + image_cols
+      + text_cols
       + label_col
   ]
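The meta_data_exists change is a bug fix rather than a style fix: bool() on a pandas DataFrame raises ValueError whether or not the frame is empty, so `True if meta_row else False` raised on every row. DataFrame.empty is the supported emptiness check. A minimal sketch (data values are illustrative):

import pandas as pd

meta_data = pd.DataFrame({'asin': ['B001'], 'brand': ['Acme']})
meta_row = meta_data.loc[meta_data['asin'] == 'B000']  # no match: empty frame

# `True if meta_row else False` raises:
#   ValueError: The truth value of a DataFrame is ambiguous. ...
meta_data_exists = False if meta_row.empty else True  # i.e. not meta_row.empty
print(meta_data_exists)  # False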

lanistr/dataset/amazon/load_data.py

Lines changed: 70 additions & 39 deletions

@@ -56,6 +56,8 @@ def load_amazon(
       )
   )
   feature_names = categorical_cols + numerical_cols
+  image_names = ['ImageFileName']
+  text_names = ['Review']
   train_data, test_data, valid_data = get_train_and_test_splits(
       args, amazon_data
   )
@@ -65,6 +67,8 @@
       'cat_idxs': cat_idxs,
       'cat_dims': cat_dims,
       'feature_names': feature_names,
+      'image_names': image_names,
+      'text_names': text_names,
   }

   dataframes = {
@@ -103,6 +107,8 @@ def create_multimodal_dataset_from_dataframes(
       tokenizer=tokenizer,
       transform=train_transform,
       feature_names=dataframes['tabular_data_information']['feature_names'],
+      image_names=dataframes['tabular_data_information']['image_names'],
+      text_names=dataframes['tabular_data_information']['text_names'],
       text=args.text,
       image=args.image,
       tab=args.tab,
@@ -113,6 +119,8 @@
       tokenizer=tokenizer,
       transform=test_transform,
       feature_names=dataframes['tabular_data_information']['feature_names'],
+      image_names=dataframes['tabular_data_information']['image_names'],
+      text_names=dataframes['tabular_data_information']['text_names'],
       text=args.text,
       image=args.image,
       tab=args.tab,
@@ -123,6 +131,8 @@
       tokenizer=tokenizer,
       transform=train_transform,
       feature_names=dataframes['tabular_data_information']['feature_names'],
+      image_names=dataframes['tabular_data_information']['image_names'],
+      text_names=dataframes['tabular_data_information']['text_names'],
       text=args.text,
       image=args.image,
       tab=args.tab,
@@ -146,6 +156,8 @@ def __init__(
       tokenizer: transformers.BertTokenizer,
       transform: torchvision.transforms.Compose,
       feature_names: List[str],
+      image_names: List[str],
+      text_names: List[str],
       text: bool,
       image: bool,
       tab: bool,
@@ -157,7 +169,9 @@ def __init__(
       df: The dataframe to use for the dataset.
      tokenizer: The tokenizer to use for the text.
       transform: The transform to use for the images.
-      feature_names: The names of the features to use.
+      feature_names: The names of the feature columns.
+      image_names: The names of the image columns.
+      text_names: The names of the text columns.
       text: Whether to use text.
       image: Whether to use images.
       tab: Whether to use tabular data.
@@ -171,7 +185,10 @@ def __init__(
     self.features = self.df[feature_names].values

     if text:
-      self.reviews = df['Review'].values
+      self.texts = df[text_names].values
+
+    if image:
+      self.images = df[image_names].values

     self.mask_generator = MaskGenerator(
         input_size=args.image_size,
@@ -199,48 +216,37 @@ def __getitem__(self, index: int):

     # text
     if self.text:
-      review = self.reviews[index]
-
-      try:
-        item = self.tokenizer.encode_plus(
-            review,
-            max_length=self.args.max_token_length,
-            truncation=True,
-            add_special_tokens=True,
-            return_token_type_ids=False,
-            padding='max_length',
-            return_attention_mask=True,
-            return_tensors='pt',
-        )
-      except Exception as e:  # pylint: disable=broad-exception-caught
-        print(e)
-        item = self.tokenizer.encode_plus(
-            '',
-            max_length=self.args.max_token_length,
-            truncation=True,
-            add_special_tokens=True,
-            return_token_type_ids=False,
-            padding='max_length',
-            return_attention_mask=True,
-            return_tensors='pt',
-        )
+      input_ids_list = []
+      attention_mask_list = []
+      for text in self.texts[index]:
+        encode_result = self.encode_text(text)
+        input_ids_list.append(encode_result['input_ids'])
+        attention_mask_list.append(encode_result['attention_mask'])
+      # input_ids has shape (text_num, token_length)
+      item['input_ids'] = torch.cat(input_ids_list)
+      # attention_mask has shape (text_num, token_length)
+      item['attention_mask'] = torch.cat(attention_mask_list)

     # image
     if self.image:
-      image_filename = row['ImageFileName']
-      if isinstance(image_filename, str):
-        image_path = os.path.join(self.args.image_data_dir, image_filename)
-        img = Image.open(image_path).convert('RGB')
-        img = self.transform(img)
-        item['pixel_values'] = img
-        item['bool_masked_pos'] = self.mask_generator()
-      else:
-
-        item['pixel_values'] = torch.zeros(
+      pixel_values = []
+      bool_masked_positions = []
+      for image_data in self.images[index]:
+        if isinstance(image_data, str):
+          image_path = os.path.join(self.args.image_data_dir, image_data)
+          img = Image.open(image_path).convert('RGB')
+          img = self.transform(img)
+          pixel_values.append(img)
+        else:
+          pixel_values.append(torch.zeros(
             size=(3, self.args.image_size, self.args.image_size),
             dtype=torch.float,
-        )
-        item['bool_masked_pos'] = self.mask_generator()
+          ))
+        bool_masked_positions.append(self.mask_generator())
+      # pixel_values has shape (image_num, channel, width, height)
+      item['pixel_values'] = torch.stack(pixel_values)
+      # bool_masked_positions has shape (image_num, model_patch_size**2)
+      item['bool_masked_positions'] = torch.stack(bool_masked_positions)

     # tabular
     if self.tab:
@@ -261,3 +267,28 @@ def __len__(self) -> int:
       The length of the dataset.
     """
     return len(self.df)
+
+  def encode_text(self, text: str):
+    try:
+      return self.tokenizer.encode_plus(
+          text,
+          max_length=self.args.max_token_length,
+          truncation=True,
+          add_special_tokens=True,
+          return_token_type_ids=False,
+          padding='max_length',
+          return_attention_mask=True,
+          return_tensors='pt',
+      )
+    except Exception as e:  # pylint: disable=broad-exception-caught
+      print(e)
+      return self.tokenizer.encode_plus(
+          '',
+          max_length=self.args.max_token_length,
+          truncation=True,
+          add_special_tokens=True,
+          return_token_type_ids=False,
+          padding='max_length',
+          return_attention_mask=True,
+          return_tensors='pt',
+      )
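For orientation, here is what the new text path of __getitem__ produces per sample: each text column is tokenized separately via encode_text, and the resulting (1, token_length) tensors are concatenated along dim 0. A minimal, self-contained sketch under the commit's assumptions (a BERT tokenizer; the sample strings and max_token_length value are illustrative):

import torch
import transformers

tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
max_token_length = 16

texts = ['Great product, works as advertised.', 'Arrived two weeks late.']
input_ids_list = []
attention_mask_list = []
for text in texts:
  enc = tokenizer.encode_plus(
      text,
      max_length=max_token_length,
      truncation=True,
      add_special_tokens=True,
      return_token_type_ids=False,
      padding='max_length',
      return_attention_mask=True,
      return_tensors='pt',
  )
  input_ids_list.append(enc['input_ids'])            # each is (1, 16)
  attention_mask_list.append(enc['attention_mask'])

item = {
    'input_ids': torch.cat(input_ids_list),          # (text_num, token_length)
    'attention_mask': torch.cat(attention_mask_list),
}
print(item['input_ids'].shape)  # torch.Size([2, 16])

The image path is analogous: torch.stack over one (3, image_size, image_size) tensor per image column yields pixel_values of shape (image_num, 3, image_size, image_size).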

lanistr/model/modeling_lanistr.py

Lines changed: 29 additions & 12 deletions

@@ -382,29 +382,46 @@ def forward(self, batch: Mapping[str, torch.Tensor]) -> BaseModelOutput:
     embeds = []
     ##================================= Text =================================##
     if self.args.text:
-      batch['input_ids'] = batch['input_ids'].squeeze(1)
-      batch['attention_mask'] = batch['attention_mask'].squeeze(1)
-
-      # forwarding regular inputs:
-      outputs = self.text_encoder(
-          input_ids=batch['input_ids'],
-          attention_mask=batch['attention_mask'],
+      # batch['input_ids'] has shape (batch_size, text_num, id_length), e.g. [4, 2, 512].
+      batch_size = batch['input_ids'].shape[0]
+      text_num = batch['input_ids'].shape[1]
+      text_contents = batch['input_ids'].flatten(start_dim=0, end_dim=1)
+      attention_mask = batch['attention_mask'].flatten(start_dim=0, end_dim=1)
+
+      text_encoding = self.text_encoder(
+          input_ids=text_contents,
+          attention_mask=attention_mask,
       )
-      last_hidden_state = outputs.last_hidden_state
+      last_hidden_state = text_encoding.last_hidden_state
       text_embeddings = self.text_proj(
           last_hidden_state[:, self.target_token_idx, :]
       )
+      text_embeddings = text_embeddings.reshape(tuple([batch_size, text_num] + list(text_embeddings.shape)[1:]))
+
+      # Average the embeddings for all the text inputs.
+      text_embeddings = text_embeddings.mean(dim=1, keepdim=True)

+      # TODO(Reviewer): the internal code doesn't have normalization. Do we need this? Is the dimension correct? text_embeddings has shape (batch_size, dim1, dim2)
       text_embeddings = F.normalize(text_embeddings, dim=1)
-      embeds.append(text_embeddings.unsqueeze(dim=1))
+      embeds.append(text_embeddings)

     ##================================== Image ===============================##
     if self.args.image:
+      # batch['pixel_values'] has shape (batch_size, image_num, channel, width, height), e.g. [4, 2, 3, 224, 224].
+      batch_size = batch['pixel_values'].shape[0]
+      image_num = batch['pixel_values'].shape[1]
+      images = batch['pixel_values'].flatten(start_dim=0, end_dim=1)

-      image_features = self.image_encoder(
-          pixel_values=batch['pixel_values'], bool_masked_pos=None
+      image_encodings = self.image_encoder(
+          pixel_values=images, bool_masked_pos=None
       )
-      image_embeddings = self.image_proj(image_features.last_hidden_state)
+      image_embeddings = self.image_proj(image_encodings.last_hidden_state)
+      image_embeddings = image_embeddings.reshape(
+          tuple([batch_size, image_num] + list(image_embeddings.shape)[1:])
+      )
+      image_embeddings = image_embeddings.mean(dim=1)
+
+      # TODO(Reviewer): the internal code doesn't have normalization. Do we need this? Is the dimension correct? image_embeddings has shape (batch_size, dim1, dim2)
       image_embeddings = F.normalize(image_embeddings, dim=1)
       embeds.append(image_embeddings)
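The encoder changes above follow a flatten, encode, reshape, pool pattern: per-sample text/image inputs are folded into the batch dimension so the encoder sees an ordinary batch, then the fold is undone and the copies are averaged. A standalone sketch with a random stand-in for the encoder output (shapes follow the commit's example of batch_size=4, text_num=2):

import torch

batch_size, text_num, token_length, hidden = 4, 2, 512, 768
input_ids = torch.zeros(batch_size, text_num, token_length, dtype=torch.long)

# Fold text_num into the batch: (4, 2, 512) -> (8, 512).
flat_ids = input_ids.flatten(start_dim=0, end_dim=1)

# Stand-in for text_encoder(...).last_hidden_state, shape (8, 512, 768).
last_hidden_state = torch.randn(flat_ids.shape[0], token_length, hidden)

# Select the target token (e.g. [CLS] at index 0), then unfold the batch.
token_embeddings = last_hidden_state[:, 0, :]                      # (8, 768)
per_text = token_embeddings.reshape(batch_size, text_num, hidden)  # (4, 2, 768)

# Average over the text inputs; keepdim retains a length-1 modality axis.
text_embeddings = per_text.mean(dim=1, keepdim=True)               # (4, 1, 768)
print(text_embeddings.shape)  # torch.Size([4, 1, 768])

The TODO in the diff looks warranted: after mean(dim=1, keepdim=True) the text embedding is (batch_size, 1, dim), and F.normalize(..., dim=1) over that singleton axis maps every nonzero entry to +/-1; normalizing over the last dimension (dim=-1) is presumably what is intended.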

lanistr/scripts/download_amazon.sh

Lines changed: 3 additions & 3 deletions

@@ -37,8 +37,8 @@ wget https://datarepo.eng.ucsd.edu/mcauley_group/data/amazon_v2/metaFiles2/meta_
 cd ../../../

 # the following take nearly 30 minutes each.
-python datasets/amazon/download_images.py --category All_Beauty
-python datasets/amazon/download_images.py --category AMAZON_FASHION
+python dataset/amazon/download_images.py --category All_Beauty
+python dataset/amazon/download_images.py --category AMAZON_FASHION

 # this will take many hours but it goes by fast because there are not too many images
-python datasets/amazon/download_images.py --category Office_Products
+python dataset/amazon/download_images.py --category Office_Products
File renamed without changes.
