import torch
import json
+import os
from torch.optim import AdamW
from transformers import GPT2LMHeadModel, AutoTokenizer
-
from astra_rl import ASTProblem, ASTEnvironment, DPO, DetoxifyModerator, Harness
+from astra_rl.logging import logger

# MODEL_NAME = "sshleifer/tiny-gpt2"  # Runs fast on cpu only
MODEL_NAME = "gpt2"
@@ -32,6 +33,21 @@ def __init__(self, device="cuda"):

        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        self.tokenizer.padding_side = "left"
+        self.tokenizer.truncation_side = "left"
+
+        self.attacker.config.pad_token_id = self.tokenizer.eos_token_id
+        self.target.config.pad_token_id = self.tokenizer.eos_token_id
+
+        # model's usable max sequence length (GPT-2: 1024)
+        self.max_ctx = int(
+            getattr(
+                self.attacker.config,
+                "n_positions",
+                getattr(self.attacker.config, "max_position_embeddings", 1024),
+            )
+        )
+        print(f"Using model {MODEL_NAME} with max context length {self.max_ctx}")

    # TASK: you have to implement these for our API
    def get_target_logprobs(self, context, continuation):
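The getattr chain above is what makes the context cap portable beyond GPT-2. For readers who want the same lookup outside the class, here is a minimal sketch; resolve_max_ctx is a hypothetical helper name and the 1024 fallback is an assumption, but the attribute probing mirrors the __init__ code above.

from transformers import AutoConfig

def resolve_max_ctx(model_name: str, fallback: int = 1024) -> int:
    # GPT-2 stores its context limit as n_positions; most newer models
    # use max_position_embeddings, so probe both before falling back.
    config = AutoConfig.from_pretrained(model_name)
    for attr in ("n_positions", "max_position_embeddings"):
        value = getattr(config, attr, None)
        if isinstance(value, int) and value > 0:
            return value
    return fallback

print(resolve_max_ctx("gpt2"))  # prints 1024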
@@ -54,28 +70,32 @@ def rollout_prompt_with_target(self, prompt):
    def parameters(self):
        return self.attacker.parameters()

-    # two helper methods to make the implementatinos above easy
+    # two helper methods to make the implementations above easy
    # you don't have to implement these for the API, but you should probably
    # do something like this unless your attacker and defense is very different
    def __rollout(self, model, prompt):
-        ### TODO: remove this when find bug
-        for p in prompt:
-            assert isinstance(p, str), f"Bad prompt: {p}"
-            assert len(p) > 0, "Empty prompt detected"
-
-        # we truncate the prompt to 1024 tokens to avoid a PyTorch CUDA device-side indexing error (conversation contexts can get too long in the multiturn setting)
+        gen_length = 32
+        max_context_len = self.max_ctx - gen_length
+        # we truncate the prompt to max_ctx - gen_length tokens (1024 - 32 for GPT-2) to avoid a PyTorch CUDA device-side indexing error (conversation contexts can get too long in the multiturn setting)
        tokenized_prompt = self.tokenizer(
            prompt,
            padding=True,
            return_tensors="pt",
-            padding_side="left",
            truncation=True,
-            max_length=1024,
+            max_length=max_context_len,
+            add_special_tokens=False,  # I added this, is it okay?
        ).to(self.device)
+
+        # debug check: the truncated prompt must leave room for generation
+        ids = tokenized_prompt["input_ids"]
+        seq_len = ids.shape[1]
+        # print("ROLL seq_len:", seq_len, "max_new:", gen_length, "total_if_generated:", seq_len + gen_length)
+        assert seq_len + gen_length <= self.max_ctx
+
        output = model.generate(
            **tokenized_prompt,
            pad_token_id=self.tokenizer.eos_token_id,
-            max_new_tokens=32,
+            max_new_tokens=gen_length,
            do_sample=True,
            top_p=0.9,
            top_k=50,
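The key invariant in the hunk above is that the truncated prompt plus max_new_tokens can never exceed self.max_ctx. A standalone sketch of the same headroom pattern, runnable outside the ASTProblem class (the prompt string is made up; the model and sampling settings follow the example's gpt2 defaults):

from transformers import AutoTokenizer, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
model = GPT2LMHeadModel.from_pretrained("gpt2")

gen_length = 32
max_ctx = model.config.n_positions  # 1024 for GPT-2

# truncate from the left, reserving gen_length positions for new tokens
batch = tokenizer(
    ["a very long multiturn conversation context ..."],
    padding=True,
    truncation=True,
    max_length=max_ctx - gen_length,
    return_tensors="pt",
)
output = model.generate(
    **batch,
    max_new_tokens=gen_length,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)
assert output.shape[1] <= max_ctx  # prompt + generation fits the positional table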
@@ -87,11 +107,14 @@ def __rollout(self, model, prompt):
                self.tokenizer.batch_decode(output, skip_special_tokens=True), prompt
            )
        ]
+
        return continuation

    def __get_logprobs(self, model, context, continuation):
        # tokenize both context and continuation
+        # make sure context is not too long (context + continuation should be <= 1024 / max seq len for GPT2)
        context = self.tokenizer(context)
+        # continuation should be only 32 tokens long
        continuation = self.tokenizer(continuation)

        # create a mask such that the context is masked out
@@ -101,7 +124,7 @@ def __get_logprobs(self, model, context, continuation):
            for i, j in zip(context.input_ids, continuation.input_ids)
        ]

-        # combine context + continuation; compute how much to pad -- bug
+        # combine context + continuation; compute how much to pad
        combined = [i + j for i, j in zip(context.input_ids, continuation.input_ids)]
        max_length = max(len(i) for i in combined)

@@ -118,10 +141,16 @@ def __get_logprobs(self, model, context, continuation):
        # move things to torch and cuda (make sure indices <= 1024 for GPT2... this is model specific!)
        # TODO: show how to make this capping flexible to the model to help future users
        combined = torch.tensor(combined).to(self.device)[
-            :, -1024:
+            :, -self.max_ctx :
        ]  # cap length to 1024
-        attention_mask = torch.tensor(attention_mask).to(self.device)[:, -1024:]
-        combined_mask = torch.tensor(combined_mask).to(self.device)[:, -1024:]
+        attention_mask = torch.tensor(attention_mask).to(self.device)[
+            :, -self.max_ctx :
+        ]
+        combined_mask = torch.tensor(combined_mask).to(self.device)[:, -self.max_ctx :]
+
+        # debug check: the capped sequence must fit the model's context window
+        # print("LOGPROBS seq_len:", combined.shape[1])
+        assert combined.shape[1] <= self.max_ctx

        # run inference
        logits = (
@@ -138,10 +167,52 @@ def __get_logprobs(self, model, context, continuation):
        return logprobs
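The masking logic in __get_logprobs can be hard to follow in batched form. A minimal, unbatched sketch of the same idea, scoring only the continuation tokens given the context (continuation_logprob is a hypothetical helper, not part of this patch):

import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

def continuation_logprob(context: str, continuation: str) -> torch.Tensor:
    ctx_ids = tokenizer(context, add_special_tokens=False).input_ids
    cont_ids = tokenizer(continuation, add_special_tokens=False).input_ids
    input_ids = torch.tensor([ctx_ids + cont_ids])

    with torch.no_grad():
        logits = model(input_ids).logits

    # logits at position t predict token t+1, so shift by one before gathering
    logprobs = logits[:, :-1].log_softmax(dim=-1)
    targets = input_ids[:, 1:]
    per_token = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # keep only the positions that predict continuation tokens (mask out the context)
    return per_token[:, len(ctx_ids) - 1 :].sum()

print(continuation_logprob("The capital of France is", " Paris"))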


+# the following two functions will be implemented in the trainer class. This example
+# does not use a trainer so we implement it here
+def save(eval_env, step, tag="step"):
+    if tag == "best":
+        out = os.path.join("checkpoints", "best")  # single fixed path
+    else:
+        out = os.path.join("checkpoints", f"{tag}-{step}")
+
+    # Save the attacker and its tokenizer in HF format
+    os.makedirs(out, exist_ok=True)
+    eval_env.problem.attacker.save_pretrained(out)
+    eval_env.problem.tokenizer.save_pretrained(out)
+
+
+def eval_epoch(env, dev_prompts, best_score, step, tag="step"):
+    print(f"EVALUATING after training step {step}...")
+    rewards = []
+
+    for indx, i in enumerate(dev_prompts):
+        if indx % 30 == 0:
+            print(f"EVALUATED {indx}/{len(dev_prompts)} prompts...")
+        # perform a single eval rollout per dev prompt and collect its reward
+        rollout = env.eval_rollout(i)
+        final_rollout_reward = env.final_reward(rollout)
+
+        rewards += [final_rollout_reward]
+
+    print(f"EVALUATED {len(dev_prompts)}/{len(dev_prompts)} prompts...")
+    dev_score = sum(rewards) / len(rewards)
+
+    if dev_score > best_score:
+        logger.info(f"NEW BEST! {round(dev_score, 3)}")
+        logger.info({"training/dev_score": dev_score}, step=step)
+        save(env, step, "best")
+        best_score = dev_score  # track the best dev score seen so far
+
+    return best_score
+
+
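save() writes standard Hugging Face directories via save_pretrained, so a checkpoint can be reloaded later with the matching from_pretrained calls. A short sketch, assuming the checkpoints/best path produced by tag="best":

from transformers import AutoTokenizer, GPT2LMHeadModel

checkpoint_dir = "checkpoints/best"  # path written by save(..., tag="best")
attacker = GPT2LMHeadModel.from_pretrained(checkpoint_dir)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)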
def main() -> None:
+    best_score = -float("inf")  # best score so far, used to save the best model
    # prompts to use to seed initial stage
+    # read in training prompts
    with open("prompts_reddit_train.json") as f:
        PROMPTS = json.load(f)
+    # read in dev set of prompts
+    with open("prompts_reddit_dev.json") as f:
+        dev_prompts = json.load(f)

    DEVICE = "cuda"  # cuda/cpu/mps

@@ -180,6 +251,7 @@ def main() -> None:
        # TODO: Do we want to add other things here to logging?
        step_logs["step"] = step
        harness.log_current_step(step_logs)
+        best_score = eval_epoch(env, dev_prompts, best_score, step, "best")


if __name__ == "__main__":