@@ -2830,7 +2830,9 @@ <h3 id="astra_rl.core.problem.ValueFunctionProblem" class="doc doc-heading">
2830
2830
< span class ="normal "> 289</ span >
2831
2831
< span class ="normal "> 290</ span >
2832
2832
< span class ="normal "> 291</ span >
2833
- < span class ="normal "> 292</ span > </ pre > </ div > </ td > < td class ="code "> < div > < pre > < span > </ span > < code > < span class ="k "> class</ span > < span class ="w "> </ span > < span class ="nc "> ValueFunctionProblem</ span > < span class ="p "> (</ span > < span class ="n "> Problem</ span > < span class ="p "> [</ span > < span class ="n "> StateT</ span > < span class ="p "> ,</ span > < span class ="n "> ActionT</ span > < span class ="p "> ],</ span > < span class ="n "> ABC</ span > < span class ="p "> ):</ span >
2833
+ < span class ="normal "> 292</ span >
2834
+ < span class ="normal "> 293</ span >
2835
+ < span class ="normal "> 294</ span > </ pre > </ div > </ td > < td class ="code "> < div > < pre > < span > </ span > < code > < span class ="k "> class</ span > < span class ="w "> </ span > < span class ="nc "> ValueFunctionProblem</ span > < span class ="p "> (</ span > < span class ="n "> Problem</ span > < span class ="p "> [</ span > < span class ="n "> StateT</ span > < span class ="p "> ,</ span > < span class ="n "> ActionT</ span > < span class ="p "> ],</ span > < span class ="n "> ABC</ span > < span class ="p "> ):</ span >
2834
2836
< span class ="w "> </ span > < span class ="sd "> """Extends `Problem` to be able to return sequence values with a value head.</ span >
2835
2837
2836
2838
< span class ="sd "> Note:</ span >
@@ -2861,7 +2863,9 @@ <h3 id="astra_rl.core.problem.ValueFunctionProblem" class="doc doc-heading">
2861
2863
< span class ="sd "> Returns:</ span >
2862
2864
< span class ="sd "> torch.Tensor[batch_size, max_continuation_length]: The per-token values of</ span >
2863
2865
< span class ="sd "> the given squence by the sequence predictor. Do not include the value of the input</ span >
2864
- < span class ="sd "> prefixes.</ span >
2866
+ < span class ="sd "> prefixes. If you are predicting on the whole input, you should be slicing on</ span >
2867
+ < span class ="sd "> `[:, :-1]`, meaning you should *not* return the value of the last token, whose</ span >
2868
+ < span class ="sd "> input is eos/context length limit.</ span >
2865
2869
< span class ="sd "> """</ span >
2866
2870
2867
2871
< span class ="k "> pass</ span >
@@ -2970,7 +2974,27 @@ <h4 id="astra_rl.core.problem.ValueFunctionProblem.value" class="doc doc-heading
2970
2974
</ td >
2971
2975
< td >
2972
2976
< div class ="doc-md-description ">
2973
- < p > prefixes.</ p >
2977
+ < p > prefixes. If you are predicting on the whole input, you should be slicing on</ p >
2978
+ </ div >
2979
+ </ td >
2980
+ </ tr >
2981
+ < tr class ="doc-section-item ">
2982
+ < td >
2983
+ < code > < span title ="torch.Tensor "> Tensor</ span > </ code >
2984
+ </ td >
2985
+ < td >
2986
+ < div class ="doc-md-description ">
2987
+ < p > < code > [:, :-1]</ code > , meaning you should < em > not</ em > return the value of the last token, whose</ p >
2988
+ </ div >
2989
+ </ td >
2990
+ </ tr >
2991
+ < tr class ="doc-section-item ">
2992
+ < td >
2993
+ < code > < span title ="torch.Tensor "> Tensor</ span > </ code >
2994
+ </ td >
2995
+ < td >
2996
+ < div class ="doc-md-description ">
2997
+ < p > input is eos/context length limit.</ p >
2974
2998
</ div >
2975
2999
</ td >
2976
3000
</ tr >
@@ -2999,7 +3023,9 @@ <h4 id="astra_rl.core.problem.ValueFunctionProblem.value" class="doc doc-heading
2999
3023
< span class ="normal "> 289</ span >
3000
3024
< span class ="normal "> 290</ span >
3001
3025
< span class ="normal "> 291</ span >
3002
- < span class ="normal "> 292</ span > </ pre > </ div > </ td > < td class ="code "> < div > < pre > < span > </ span > < code > < span class ="nd "> @abstractmethod</ span >
3026
+ < span class ="normal "> 292</ span >
3027
+ < span class ="normal "> 293</ span >
3028
+ < span class ="normal "> 294</ span > </ pre > </ div > </ td > < td class ="code "> < div > < pre > < span > </ span > < code > < span class ="nd "> @abstractmethod</ span >
3003
3029
< span class ="k "> def</ span > < span class ="w "> </ span > < span class ="nf "> value</ span > < span class ="p "> (</ span >
3004
3030
< span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> context</ span > < span class ="p "> :</ span > < span class ="n "> Sequence</ span > < span class ="p "> [</ span > < span class ="n "> StateT</ span > < span class ="p "> ],</ span > < span class ="n "> continuation</ span > < span class ="p "> :</ span > < span class ="n "> Sequence</ span > < span class ="p "> [</ span > < span class ="n "> ActionT</ span > < span class ="p "> ]</ span >
3005
3031
< span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> :</ span >
@@ -3015,7 +3041,9 @@ <h4 id="astra_rl.core.problem.ValueFunctionProblem.value" class="doc doc-heading
3015
3041
< span class ="sd "> Returns:</ span >
3016
3042
< span class ="sd "> torch.Tensor[batch_size, max_continuation_length]: The per-token values of</ span >
3017
3043
< span class ="sd "> the given squence by the sequence predictor. Do not include the value of the input</ span >
3018
- < span class ="sd "> prefixes.</ span >
3044
+ < span class ="sd "> prefixes. If you are predicting on the whole input, you should be slicing on</ span >
3045
+ < span class ="sd "> `[:, :-1]`, meaning you should *not* return the value of the last token, whose</ span >
3046
+ < span class ="sd "> input is eos/context length limit.</ span >
3019
3047
< span class ="sd "> """</ span >
3020
3048
3021
3049
< span class ="k "> pass</ span >
0 commit comments