Skip to content

Commit 496ceb9

Browse files
author
committed
Deployed b07be60 with MkDocs version: 1.6.1
1 parent a9a253d commit 496ceb9

File tree

9 files changed

+1936
-59
lines changed

9 files changed

+1936
-59
lines changed

api/algorithms/dpo.html

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,16 +1478,16 @@ <h3 id="astra_rl.algorithms.dpo.DPO" class="doc doc-heading">
14781478
<span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]:</span>
14791479
<span class="n">attacker_logprobs_win</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">problem</span><span class="o">.</span><span class="n">_get_attacker_logprobs_and_validate</span><span class="p">(</span>
14801480
<span class="n">batch</span><span class="o">.</span><span class="n">prefixes</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">suffix_pos</span>
1481-
<span class="p">)</span>
1481+
<span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Sum per-token logprobs to get sequence logprobs</span>
14821482
<span class="n">attacker_logprobs_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">problem</span><span class="o">.</span><span class="n">_get_attacker_logprobs_and_validate</span><span class="p">(</span>
14831483
<span class="n">batch</span><span class="o">.</span><span class="n">prefixes</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">suffix_neg</span>
1484-
<span class="p">)</span>
1484+
<span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Sum per-token logprobs to get sequence logprobs</span>
14851485
<span class="n">baseline_logprobs_win</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">problem</span><span class="o">.</span><span class="n">_get_baseline_logprobs_and_validate</span><span class="p">(</span>
14861486
<span class="n">batch</span><span class="o">.</span><span class="n">prefixes</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">suffix_pos</span>
1487-
<span class="p">)</span>
1487+
<span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Sum per-token logprobs to get sequence logprobs</span>
14881488
<span class="n">baseline_logprobs_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">problem</span><span class="o">.</span><span class="n">_get_baseline_logprobs_and_validate</span><span class="p">(</span>
14891489
<span class="n">batch</span><span class="o">.</span><span class="n">prefixes</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">suffix_neg</span>
1490-
<span class="p">)</span>
1490+
<span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Sum per-token logprobs to get sequence logprobs</span>
14911491

14921492
<span class="c1"># https://github.com/eric-mitchell/direct-preference-optimization/blob/ \</span>
14931493
<span class="c1"># f8b8c0f49dc92a430bae41585f9d467d3618fe2f/trainers.py#L70-L87</span>

0 commit comments

Comments
 (0)