Skip to content

Commit 125cf9e

Browse files
committed
safe division
1 parent ee2776f commit 125cf9e

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

hamburger_pytorch/hamburger_pytorch.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ def __init__(
2424
dim,
2525
n,
2626
ratio = 8,
27-
K = 6
27+
K = 6,
28+
eps = 2e-8
2829
):
2930
super().__init__()
3031
r = dim // ratio
@@ -36,8 +37,10 @@ def __init__(
3637
self.D = nn.Parameter(D)
3738
self.C = nn.Parameter(C)
3839

40+
self.eps = eps
41+
3942
def forward(self, x):
40-
b, D, C = x.shape[0], self.D, self.C
43+
b, D, C, eps = x.shape[0], self.D, self.C, self.eps
4144

4245
# x is made non-negative with relu as proposed in paper
4346
x = F.relu(x)
@@ -52,8 +55,8 @@ def forward(self, x):
5255
# only calculate gradients on the last step, per propose 'One-step Gradient'
5356
context = null_context if k == 0 else torch.no_grad
5457
with context():
55-
C_new = C * ((t(D) @ x) / (t(D) @ D @ C))
56-
D_new = D * ((x @ t(C)) / (D @ C @ t(C)))
58+
C_new = C * ((t(D) @ x) / ((t(D) @ D @ C) + eps))
59+
D_new = D * ((x @ t(C)) / ((D @ C @ t(C)) + eps))
5760
C, D = C_new, D_new
5861

5962
return D @ C

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'hamburger-pytorch',
55
packages = find_packages(),
6-
version = '0.0.2',
6+
version = '0.0.3',
77
license='MIT',
88
description = 'Hamburger - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)