@@ -1329,7 +1329,10 @@ <h3 id="astra_rl.algorithms.ppo.PPO" class="doc doc-heading">
1329
1329
< span class ="normal "> 122</ span >
1330
1330
< span class ="normal "> 123</ span >
1331
1331
< span class ="normal "> 124</ span >
1332
- < span class ="normal "> 125</ span > </ pre > </ div > </ td > < td class ="code "> < div > < pre > < span > </ span > < code > < span class ="k "> class</ span > < span class ="w "> </ span > < span class ="nc "> PPO</ span > < span class ="p "> (</ span >
1332
+ < span class ="normal "> 125</ span >
1333
+ < span class ="normal "> 126</ span >
1334
+ < span class ="normal "> 127</ span >
1335
+ < span class ="normal "> 128</ span > </ pre > </ div > </ td > < td class ="code "> < div > < pre > < span > </ span > < code > < span class ="k "> class</ span > < span class ="w "> </ span > < span class ="nc "> PPO</ span > < span class ="p "> (</ span >
1333
1336
< span class ="n "> Algorithm</ span > < span class ="p "> [</ span > < span class ="n "> StateT</ span > < span class ="p "> ,</ span > < span class ="n "> ActionT</ span > < span class ="p "> ,</ span > < span class ="n "> PPOStep</ span > < span class ="p "> [</ span > < span class ="n "> StateT</ span > < span class ="p "> ,</ span > < span class ="n "> ActionT</ span > < span class ="p "> ],</ span > < span class ="n "> PPOBatch</ span > < span class ="p "> [</ span > < span class ="n "> StateT</ span > < span class ="p "> ,</ span > < span class ="n "> ActionT</ span > < span class ="p "> ]],</ span >
1334
1337
< span class ="n "> ABC</ span > < span class ="p "> ,</ span >
1335
1338
< span class ="p "> ):</ span >
@@ -1399,7 +1402,10 @@ <h3 id="astra_rl.algorithms.ppo.PPO" class="doc doc-heading">
1399
1402
< span class ="n "> A</ span > < span class ="o "> =</ span > < span class ="n "> Q</ span > < span class ="o "> -</ span > < span class ="n "> values</ span >
1400
1403
1401
1404
< span class ="c1 "> # normalize advantages</ span >
1402
- < span class ="n "> A</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> A</ span > < span class ="o "> -</ span > < span class ="n "> A</ span > < span class ="o "> .</ span > < span class ="n "> mean</ span > < span class ="p "> ())</ span > < span class ="o "> /</ span > < span class ="p "> (</ span > < span class ="n "> A</ span > < span class ="o "> .</ span > < span class ="n "> std</ span > < span class ="p "> ()</ span > < span class ="o "> +</ span > < span class ="mf "> 1e-8</ span > < span class ="p "> )</ span >
1405
+ < span class ="k "> if</ span > < span class ="n "> A</ span > < span class ="o "> .</ span > < span class ="n "> size</ span > < span class ="p "> (</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span > < span class ="o "> ==</ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
1406
+ < span class ="n "> A</ span > < span class ="o "> =</ span > < span class ="p "> ((</ span > < span class ="n "> A</ span > < span class ="o "> -</ span > < span class ="n "> A</ span > < span class ="o "> .</ span > < span class ="n "> mean</ span > < span class ="p "> ())</ span > < span class ="o "> /</ span > < span class ="p "> (</ span > < span class ="n "> A</ span > < span class ="o "> .</ span > < span class ="n "> std</ span > < span class ="p "> ()</ span > < span class ="o "> +</ span > < span class ="mf "> 1e-8</ span > < span class ="p "> ))</ span > < span class ="o "> .</ span > < span class ="n "> squeeze</ span > < span class ="p "> (</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span >
1407
+ < span class ="k "> else</ span > < span class ="p "> :</ span >
1408
+ < span class ="n "> A</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> A</ span > < span class ="o "> -</ span > < span class ="n "> A</ span > < span class ="o "> .</ span > < span class ="n "> mean</ span > < span class ="p "> ())</ span > < span class ="o "> /</ span > < span class ="p "> (</ span > < span class ="n "> A</ span > < span class ="o "> .</ span > < span class ="n "> std</ span > < span class ="p "> ()</ span > < span class ="o "> +</ span > < span class ="mf "> 1e-8</ span > < span class ="p "> )</ span >
1403
1409
< span class ="c1 "> # compute ratio, should be 1 at the first iteration</ span >
1404
1410
< span class ="n "> ratio</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> exp</ span > < span class ="p "> ((</ span > < span class ="n "> logprobs_attacker</ span > < span class ="o "> -</ span > < span class ="n "> logprobs_baseline</ span > < span class ="o "> .</ span > < span class ="n "> detach</ span > < span class ="p "> ()))</ span >
1405
1411
0 commit comments