7
7
"""
8
8
9
9
from math import log2 , ceil
10
+ from functools import partial
10
11
from collections import namedtuple
11
12
12
13
import torch
@@ -51,14 +52,20 @@ def entropy(prob):
51
52
# cosine sim linear
52
53
53
54
class CosineSimLinear(Module):
    """Bias-free linear layer that outputs scaled cosine similarities.

    Each row of the input and each column of the weight matrix are
    L2-normalized before the matmul, so every output entry is the cosine
    similarity between an input vector and a weight column, multiplied by
    a fixed `scale`.

    Args:
        dim_in: size of the last dimension of the input.
        dim_out: number of output features (weight columns).
        scale: constant multiplier applied to the similarities (default 1.).
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        scale = 1.
    ):
        super().__init__()
        # scale is a plain float, not a learned parameter
        self.scale = scale
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))

    def forward(self, x):
        # normalize rows of x and columns of the weight so that
        # x @ w yields cosine similarities in [-1, 1]
        x = F.normalize(x, dim = -1)
        w = F.normalize(self.weight, dim = 0)
        return (x @ w) * self.scale
62
69
63
70
# class
64
71
@@ -79,7 +86,8 @@ def __init__(
79
86
use_code_agnostic_commit_loss = False ,
80
87
projection_has_bias = True ,
81
88
soft_clamp_input_value = None ,
82
- cosine_sim_project_in = False
89
+ cosine_sim_project_in = False ,
90
+ cosine_sim_project_in_scale = None
83
91
):
84
92
super ().__init__ ()
85
93
@@ -96,8 +104,13 @@ def __init__(
96
104
97
105
has_projections = dim != codebook_dims
98
106
99
- project_in_klass = CosineSimLinear if cosine_sim_project_in else nn .Linear
100
- self .project_in = project_in_klass (dim , codebook_dims , bias = projection_has_bias ) if has_projections else nn .Identity ()
107
+ if cosine_sim_project_in :
108
+ cosine_sim_project_in = default (cosine_sim_project_in_scale , codebook_scale )
109
+ project_in_klass = partial (CosineSimLinear , scale = cosine_sim_project_in )
110
+ else :
111
+ project_in_klass = partial (nn .Linear , bias = projection_has_bias )
112
+
113
+ self .project_in = project_in_klass (dim , codebook_dims ) if has_projections else nn .Identity ()
101
114
self .project_out = nn .Linear (codebook_dims , dim , bias = projection_has_bias ) if has_projections else nn .Identity ()
102
115
self .has_projections = has_projections
103
116
0 commit comments