You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# MODEL_NAME = "sshleifer/tiny-gpt2" # Runs fast on cpu only
16
20
MODEL_NAME="gpt2"
17
21
18
22
19
23
classExampleDetoxifyProblem(ASTProblem):
20
-
def__init__(self, device="cpu"):
24
+
def__init__(self, device="cuda"):
21
25
# TASK: initialize and pass to superclass
22
26
# your choice of moderator
23
27
super().__init__(DetoxifyModerator())
@@ -54,8 +58,19 @@ def parameters(self):
54
58
# you don't have to implement these for the API, but you should probably
55
59
# do something like this unless your attacker and defense is very different
56
60
def__rollout(self, model, prompt):
61
+
### TODO: remove this when find bug
62
+
forpinprompt:
63
+
assertisinstance(p, str), f"Bad prompt: {p}"
64
+
assertlen(p) >0, "Empty prompt detected"
65
+
66
+
# we truncate the prompt to 1024 tokens to avoid a PyTorch CUDA device-side indexing error (conversation contexts can get too long in the multiturn setting)
0 commit comments