@@ -48,6 +48,18 @@ def log(t, eps = 1e-5):
 def entropy(prob):
     return (-prob * log(prob)).sum(dim = -1)
 
+# cosine sim linear
+
+class CosineSimLinear(Module):
+    def __init__(self, dim_in, dim_out, **kwargs):
+        super().__init__()
+        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
+
+    def forward(self, x):
+        x = F.normalize(x, dim = -1)
+        w = F.normalize(self.weight, dim = 0)
+        return x @ w
+
 # class
 
 class LFQ(Module):
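
The CosineSimLinear added above unit-normalizes both the input (along its last dimension) and each weight column, so every output entry is a cosine similarity bounded in [-1, 1] regardless of the input's scale. A minimal standalone sketch of that property, mirroring the class from the hunk (the test tensor and assert are illustrative assumptions, not part of the commit):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module

class CosineSimLinear(Module):
    def __init__(self, dim_in, dim_out, **kwargs):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))

    def forward(self, x):
        x = F.normalize(x, dim = -1)           # unit-normalize each input vector
        w = F.normalize(self.weight, dim = 0)  # unit-normalize each weight column
        return x @ w                           # entries are cosine similarities

layer = CosineSimLinear(64, 16)
out = layer(torch.randn(2, 1024, 64) * 100.)   # deliberately large input scale
assert out.abs().max() <= 1.0 + 1e-6           # outputs stay bounded in [-1, 1]
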
@@ -66,7 +78,8 @@ def __init__(
         frac_per_sample_entropy = 1.,                   # make less than 1. to only use a random fraction of the probs for per sample entropy
         use_code_agnostic_commit_loss = False,
         projection_has_bias = True,
-        soft_clamp_input_value = None
+        soft_clamp_input_value = None,
+        cosine_sim_project_in = False
     ):
         super().__init__()
 
@@ -82,7 +95,9 @@ def __init__(
         dim = default(dim, codebook_dims)
 
         has_projections = dim != codebook_dims
-        self.project_in = nn.Linear(dim, codebook_dims, bias = projection_has_bias) if has_projections else nn.Identity()
+
+        project_in_klass = CosineSimLinear if cosine_sim_project_in else nn.Linear
+        self.project_in = project_in_klass(dim, codebook_dims, bias = projection_has_bias) if has_projections else nn.Identity()
         self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
         self.has_projections = has_projections
 
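
A minimal usage sketch for the new flag (an assumption for illustration, not part of the commit; it relies on the public LFQ constructor from vector_quantize_pytorch and its usual (quantized, indices, entropy_aux_loss) return): when dim differs from the codebook dimension, has_projections is True and cosine_sim_project_in = True swaps the nn.Linear project_in for CosineSimLinear. Note that CosineSimLinear accepts **kwargs, so the bias = projection_has_bias argument is simply ignored on that path.

import torch
from vector_quantize_pytorch import LFQ   # assumed import path

quantizer = LFQ(
    dim = 256,                      # input feature dimension
    codebook_size = 2 ** 8,         # implies 8 codebook dims, so has_projections is True
    cosine_sim_project_in = True    # use the cosine-sim input projection added in this commit
)

x = torch.randn(1, 1024, 256)
quantized, indices, entropy_aux_loss = quantizer(x)   # quantized has the same shape as x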