
Commit cdeb37f

fixed out of index error in gpt2 run
1 parent 85aed60 commit cdeb37f

6 files changed: +29 -12 lines changed

.gitignore

Lines changed: 4 additions & 1 deletion

@@ -189,4 +189,7 @@ cython_debug/
 
 # Weights & Biases
 wandb/
-wandb-debug.log
+wandb-debug.log
+
+# Ignore training prompts (for now!)
+examples/GPT2_v_GPT2/prompts_reddit_train.json
4 files renamed without changes.

examples/first_trial_GPT2/firstTrial.py renamed to examples/GPT2_v_GPT2/ast_basic_1.py

Lines changed: 25 additions & 11 deletions
@@ -6,18 +6,22 @@
 corpora below of initial prompts.
 """
 
+# requirements: transformers tokenizers
+# requirements: ..
+
 import torch
 import json
 from torch.optim import AdamW
 from transformers import GPT2LMHeadModel, AutoTokenizer
+
 from astra_rl import ASTProblem, ASTEnvironment, DPO, DetoxifyModerator, Harness
 
 # MODEL_NAME = "sshleifer/tiny-gpt2"  # Runs fast on cpu only
 MODEL_NAME = "gpt2"
 
 
 class ExampleDetoxifyProblem(ASTProblem):
-    def __init__(self, device="cpu"):
+    def __init__(self, device="cuda"):
         # TASK: initialize and pass to superclass
         # your choice of moderator
         super().__init__(DetoxifyModerator())
@@ -54,8 +58,19 @@ def parameters(self):
     # you don't have to implement these for the API, but you should probably
     # do something like this unless your attacker and defense is very different
     def __rollout(self, model, prompt):
+        ### TODO: remove this when find bug
+        for p in prompt:
+            assert isinstance(p, str), f"Bad prompt: {p}"
+            assert len(p) > 0, "Empty prompt detected"
+
+        # we truncate the prompt to 1024 tokens to avoid a PyTorch CUDA device-side indexing error (conversation contexts can get too long in the multiturn setting)
         tokenized_prompt = self.tokenizer(
-            prompt, padding=True, return_tensors="pt", padding_side="left"
+            prompt,
+            padding=True,
+            return_tensors="pt",
+            padding_side="left",
+            truncation=True,
+            max_length=1024,
         ).to(self.device)
         output = model.generate(
             **tokenized_prompt,
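
For readers hitting the same error: GPT-2's learned positional embeddings only cover 1024 positions, so any sequence longer than that indexes past the embedding table, which on CUDA shows up as an opaque device-side assert rather than a clean Python exception. A minimal standalone sketch of the same fix (hypothetical names and prompts, not the repo's helpers) looks roughly like this:

# Minimal sketch, outside the repo: truncate inputs to GPT-2's context window
# before generation so position ids never exceed the 1024-entry embedding table.
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token   # GPT-2 ships without a pad token
tokenizer.padding_side = "left"             # left-pad for decoder-only generation
model = GPT2LMHeadModel.from_pretrained("gpt2")

max_new_tokens = 32
prompts = ["A short prompt.", "A very long multiturn context... " * 200]
batch = tokenizer(
    prompts,
    padding=True,
    truncation=True,
    # leave headroom so prompt + generated tokens stay under the 1024 limit
    max_length=tokenizer.model_max_length - max_new_tokens,
    return_tensors="pt",
)
with torch.no_grad():
    output = model.generate(
        **batch, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id
    )
print(tokenizer.batch_decode(output, skip_special_tokens=True))

One caveat worth noting: tokenizer truncation defaults to cutting from the right (keeping the oldest turns), whereas the capping added to __get_logprobs below keeps the last 1024 tokens; setting tokenizer.truncation_side = "left" would make the two consistent if that matters for the attack setting.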
@@ -86,7 +101,7 @@ def __get_logprobs(self, model, context, continuation):
             for i, j in zip(context.input_ids, continuation.input_ids)
         ]
 
-        # combine context + continuation; compute how much to pad
+        # combine context + continuation; compute how much to pad -- bug
         combined = [i + j for i, j in zip(context.input_ids, continuation.input_ids)]
         max_length = max(len(i) for i in combined)
 
@@ -100,10 +115,13 @@ def __get_logprobs(self, model, context, continuation):
             [True] * len(i) + [False] * (max_length - len(i)) for i in combined_mask
         ]
 
-        # move things to torch and cuda
-        combined = torch.tensor(combined).to(self.device)
-        attention_mask = torch.tensor(attention_mask).to(self.device)
-        combined_mask = torch.tensor(combined_mask).to(self.device)
+        # move things to torch and cuda (make sure indicies <= 1024 for GPT2... this is model specific!)
+        # TODO: show how to make this capping flexible to the model to help future users
+        combined = torch.tensor(combined).to(self.device)[
+            :, -1024:
+        ]  # cap length to 1024
+        attention_mask = torch.tensor(attention_mask).to(self.device)[:, -1024:]
+        combined_mask = torch.tensor(combined_mask).to(self.device)[:, -1024:]
 
         # run inference
         logits = (
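
On the TODO about making the cap flexible to the model: one option (a sketch, not part of this commit) is to read the context window from the model config instead of hardcoding 1024, since GPT-2 exposes it as n_positions and most newer configs expose max_position_embeddings:

# Hypothetical helper, not in this commit: derive the cap from the model config.
def context_limit(model, default: int = 1024) -> int:
    cfg = model.config
    # GPT-2 calls its window n_positions; most other configs use max_position_embeddings
    return getattr(cfg, "n_positions", None) or getattr(cfg, "max_position_embeddings", default)

# __get_logprobs could then cap against the actual model instead of 1024:
#     limit = context_limit(model)
#     combined = torch.tensor(combined).to(self.device)[:, -limit:]
#     attention_mask = torch.tensor(attention_mask).to(self.device)[:, -limit:]
#     combined_mask = torch.tensor(combined_mask).to(self.device)[:, -limit:]

Slicing to the last limit tokens keeps the most recent turns, which is presumably what matters most in the multiturn setting the comments describe.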
@@ -122,10 +140,6 @@ def __get_logprobs(self, model, context, continuation):
 
 def main() -> None:
     # prompts to use to seed initial stage
-    # PROMPTS = [
-    #     "Tell us about your day! My day was",
-    #     "The Federal Reserve said last Tuesday that",
-    # ]
     with open("prompts_reddit_train.json") as f:
         PROMPTS = json.load(f)
 