
Commit 677555d

fixed out of index error in gpt2 run, init tokenizer diff
1 parent cdeb37f commit 677555d

File tree: 1 file changed (+87, -15 lines)

examples/ast_basic.py

Lines changed: 87 additions & 15 deletions
@@ -11,10 +11,11 @@
 
 import torch
 import json
+import os
 from torch.optim import AdamW
 from transformers import GPT2LMHeadModel, AutoTokenizer
-
 from astra_rl import ASTProblem, ASTEnvironment, DPO, DetoxifyModerator, Harness
+from astra_rl.logging import logger
 
 # MODEL_NAME = "sshleifer/tiny-gpt2"  # Runs fast on cpu only
 MODEL_NAME = "gpt2"
@@ -32,6 +33,21 @@ def __init__(self, device="cuda"):
 
         self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        self.tokenizer.padding_side = "left"
+        self.tokenizer.truncation_side = "left"
+
+        self.attacker.config.pad_token_id = self.tokenizer.eos_token_id
+        self.target.config.pad_token_id = self.tokenizer.eos_token_id
+
+        # model's usable max sequence length (GPT-2: 1024)
+        self.max_ctx = int(
+            getattr(
+                self.attacker.config,
+                "n_positions",
+                getattr(self.attacker.config, "max_position_embeddings", 1024),
+            )
+        )
+        print(f"Using model {MODEL_NAME} with max context length {self.max_ctx}")
 
     # TASK: you have to implement these for our API
     def get_target_logprobs(self, context, continuation):
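Note on the new max_ctx lookup above: GPT-2 stores its context window in config.n_positions, while many newer causal LMs expose max_position_embeddings instead, so the nested getattr covers both and falls back to 1024. A minimal standalone sketch of that lookup (the helper name usable_context_length is illustrative, not part of the repo):

from transformers import AutoConfig

def usable_context_length(model_name: str, default: int = 1024) -> int:
    config = AutoConfig.from_pretrained(model_name)
    # GPT-2 exposes n_positions; many newer models expose max_position_embeddings
    return int(getattr(config, "n_positions", getattr(config, "max_position_embeddings", default)))

print(usable_context_length("gpt2"))  # -> 1024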
@@ -54,28 +70,32 @@ def rollout_prompt_with_target(self, prompt):
     def parameters(self):
         return self.attacker.parameters()
 
-    # two helper methods to make the implementatinos above easy
+    # two helper methods to make the implementations above easy
     # you don't have to implement these for the API, but you should probably
     # do something like this unless your attacker and defense is very different
     def __rollout(self, model, prompt):
-        ### TODO: remove this when find bug
-        for p in prompt:
-            assert isinstance(p, str), f"Bad prompt: {p}"
-            assert len(p) > 0, "Empty prompt detected"
-
-        # we truncate the prompt to 1024 tokens to avoid a PyTorch CUDA device-side indexing error (conversation contexts can get too long in the multiturn setting)
+        gen_length = 32
+        max_context_len = self.max_ctx - gen_length
+        # we truncate the prompt to 1024 - 32 tokens to avoid a PyTorch CUDA device-side indexing error (conversation contexts can get too long in the multiturn setting)
         tokenized_prompt = self.tokenizer(
             prompt,
             padding=True,
             return_tensors="pt",
-            padding_side="left",
             truncation=True,
-            max_length=1024,
+            max_length=max_context_len,
+            add_special_tokens=False,  # I added this, is it okay?
         ).to(self.device)
+
+        # print statements to find bug
+        ids = tokenized_prompt["input_ids"]
+        seq_len = ids.shape[1]
+        # print("ROLL seq_len:", seq_len, "max_new:", 32, "total_if_generated:", seq_len + 32)
+        assert seq_len + 32 <= getattr(model.config, "n_positions", 1024)
+
         output = model.generate(
             **tokenized_prompt,
             pad_token_id=self.tokenizer.eos_token_id,
-            max_new_tokens=32,
+            max_new_tokens=gen_length,
             do_sample=True,
             top_p=0.9,
             top_k=50,
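The hunk above is the core of the fix: the prompt is now truncated to max_ctx - gen_length tokens, so prompt plus max_new_tokens can never exceed GPT-2's 1024 position embeddings, which appears to be the out-of-index error the commit message refers to. A minimal sketch of the same budget outside the class (assumes transformers and torch are installed; the prompt text is made up):

import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
model = GPT2LMHeadModel.from_pretrained("gpt2")

max_ctx, gen_length = 1024, 32
batch = tokenizer(
    ["a very long multi-turn conversation context ..."],
    padding=True,
    truncation=True,
    max_length=max_ctx - gen_length,  # leave room for the generated tokens
    return_tensors="pt",
)
with torch.no_grad():
    output = model.generate(
        **batch,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=gen_length,
    )
assert output.shape[1] <= max_ctx  # prompt + continuation fit in GPT-2's window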
@@ -87,11 +107,14 @@ def __rollout(self, model, prompt):
                 self.tokenizer.batch_decode(output, skip_special_tokens=True), prompt
             )
         ]
+
         return continuation
 
     def __get_logprobs(self, model, context, continuation):
         # tokenize both context and continuation
+        # make sure context is not too long (context + continuation should be <= 1024 / max seq len for GPT2)
         context = self.tokenizer(context)
+        # continuation should be only 32 tokens long
         continuation = self.tokenizer(continuation)
 
         # create a mask such that the context is masked out
@@ -101,7 +124,7 @@ def __get_logprobs(self, model, context, continuation):
             for i, j in zip(context.input_ids, continuation.input_ids)
         ]
 
-        # combine context + continuation; compute how much to pad -- bug
+        # combine context + continuation; compute how much to pad
         combined = [i + j for i, j in zip(context.input_ids, continuation.input_ids)]
         max_length = max(len(i) for i in combined)
 
@@ -118,10 +141,16 @@ def __get_logprobs(self, model, context, continuation):
         # move things to torch and cuda (make sure indicies <= 1024 for GPT2... this is model specific!)
         # TODO: show how to make this capping flexible to the model to help future users
         combined = torch.tensor(combined).to(self.device)[
-            :, -1024:
+            :, -self.max_ctx :
         ]  # cap length to 1024
-        attention_mask = torch.tensor(attention_mask).to(self.device)[:, -1024:]
-        combined_mask = torch.tensor(combined_mask).to(self.device)[:, -1024:]
+        attention_mask = torch.tensor(attention_mask).to(self.device)[
+            :, -self.max_ctx :
+        ]
+        combined_mask = torch.tensor(combined_mask).to(self.device)[:, -self.max_ctx :]
+
+        # print statements to find bug
+        # print("LOGPROBS seq_len:", combined.shape[1])
+        assert combined.shape[1] <= getattr(model.config, "n_positions", 1024)
 
         # run inference
         logits = (
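The slicing above keeps only the most recent self.max_ctx token ids, and the attention and continuation masks are sliced the same way so all three tensors stay aligned before the forward pass. A toy illustration of that capping (tensor contents are random placeholders, not real data):

import torch

max_ctx = 1024
combined = torch.randint(0, 50257, (2, 1300))   # fake token ids, longer than GPT-2 allows
attention_mask = torch.ones_like(combined)

combined = combined[:, -max_ctx:]                # keep the last 1024 positions
attention_mask = attention_mask[:, -max_ctx:]    # slice masks identically to stay aligned
assert combined.shape[1] <= max_ctx              # safe for GPT-2's position embeddings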
@@ -138,10 +167,52 @@ def __get_logprobs(self, model, context, continuation):
         return logprobs
 
 
+# the following two functions will be implemented in the trainer class. This example
+# does not use a trainer so we implement it here
+def save(eval_env, step, tag="step"):
+    if tag == "best":
+        out = os.path.join("checkpoints", "best")  # single fixed path
+    else:
+        out = os.path.join("checkpoints", f"{tag}-{step}")
+
+    # Save attacker/target in HF format
+    os.makedirs(out, exist_ok=True)
+    eval_env.problem.attacker.save_pretrained(out)
+    eval_env.problem.tokenizer.save_pretrained(out)
+
+
+def eval_epoch(env, dev_prompts, best_score, step, tag="step"):
+    print(f"EVALUATING after training step {step}...")
+    rewards = []
+
+    for indx, i in enumerate(dev_prompts):
+        if indx % 30 == 0:
+            print(f"EVALUATED {indx}/{len(dev_prompts)} steps...")
+        # perform a single eval rollout per dev prompt and collect a list of rewards
+        # FIND WAY to extract rewards from the rollout
+        rollout = env.eval_rollout(i)
+        final_rollout_reward = env.final_reward(rollout)
+
+        rewards += [final_rollout_reward]
+
+    print(f"EVALUATED {indx}/{len(dev_prompts)} steps...")
+    dev_score = sum(rewards) / len(rewards)
+
+    if dev_score > best_score:
+        logger.info(f"NEW BEST! {round(dev_score, 3)}")
+        logger.info({"training/dev_score": dev_score}, step=step)
+        save(env, step, "best")
+
+
 def main() -> None:
+    best_score = -float("inf")  # best score so far, used to save the best model
     # prompts to use to seed initial stage
+    # read in training prompts
     with open("prompts_reddit_train.json") as f:
         PROMPTS = json.load(f)
+    # read in dev set of prompts
+    with open("prompts_reddit_dev.json") as f:
+        dev_prompts = json.load(f)
 
     DEVICE = "cuda"  # cuda/cpu/mps
 
@@ -180,6 +251,7 @@ def main() -> None:
         # TODO: Do we want to add other things here to logging?
         step_logs["step"] = step
         harness.log_current_step(step_logs)
+        eval_epoch(env, dev_prompts, best_score, step, "best")
 
 
 if __name__ == "__main__":
