Skip to content

Commit f43bb8d

Browse files
author
ThibaultGROUEIX
committed
fix memory leak
1 parent 6f4b030 commit f43bb8d

File tree

7 files changed

+16
-101
lines changed

7 files changed

+16
-101
lines changed

extension/dist_chamfer.py

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,68 +10,45 @@
1010

1111
# Chamfer's distance module @thibaultgroueix
1212
# GPU tensors only
13-
# remember to call chamferFunction.clean() after loss.backward() to avoid memory leak
14-
1513
class chamferFunction(Function):
16-
def __enter__(self):
17-
""" """
18-
19-
def __exit__(self, exc_type, exc_value, traceback):
20-
del self.xyz1
21-
del self.xyz2
22-
del self.idx1
23-
del self.idx2
24-
del self.dist1
25-
del self.dist2
26-
27-
def clean(self):
28-
# print('Destructor called, vehicle deleted.')
29-
del self.xyz1
30-
del self.xyz2
31-
del self.idx1
32-
del self.idx2
33-
del self.dist1
34-
del self.dist2
35-
36-
def forward(self, xyz1, xyz2):
14+
@staticmethod
15+
def forward(ctx, xyz1, xyz2):
3716
batchsize, n, _ = xyz1.size()
3817
_, m, _ = xyz2.size()
39-
self.xyz1 = xyz1
40-
self.xyz2 = xyz2
18+
4119
dist1 = torch.zeros(batchsize, n)
4220
dist2 = torch.zeros(batchsize, m)
4321

44-
self.idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
45-
self.idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
22+
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
23+
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
4624

4725
dist1 = dist1.cuda()
4826
dist2 = dist2.cuda()
49-
self.idx1 = self.idx1.cuda()
50-
self.idx2 = self.idx2.cuda()
51-
chamfer.forward(xyz1, xyz2, dist1, dist2, self.idx1, self.idx2)
27+
idx1 = idx1.cuda()
28+
idx2 = idx2.cuda()
5229

53-
self.dist1 = dist1
54-
self.dist2 = dist2
30+
chamfer.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
31+
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
5532
return dist1, dist2
5633

57-
def backward(self, graddist1, graddist2):
34+
@staticmethod
35+
def backward(ctx, graddist1, graddist2):
36+
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
5837
graddist1 = graddist1.contiguous()
5938
graddist2 = graddist2.contiguous()
6039

61-
gradxyz1 = torch.zeros(self.xyz1.size())
62-
gradxyz2 = torch.zeros(self.xyz2.size())
40+
gradxyz1 = torch.zeros(xyz1.size())
41+
gradxyz2 = torch.zeros(xyz2.size())
6342

6443
gradxyz1 = gradxyz1.cuda()
6544
gradxyz2 = gradxyz2.cuda()
66-
chamfer.backward(self.xyz1, self.xyz2, gradxyz1, gradxyz2, graddist1, graddist2, self.idx1, self.idx2)
45+
chamfer.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
6746
return gradxyz1, gradxyz2
6847

6948
class chamferDist(nn.Module):
7049
def __init__(self):
7150
super(chamferDist, self).__init__()
72-
self.cham = chamferFunction()
7351

7452
def forward(self, input1, input2):
75-
self.cham = chamferFunction()
76-
return self.cham(input1, input2)
53+
return chamferFunction.apply(input1, input2)
7754

training/train_AE_AtlasNet.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,6 @@ def distChamfer(a,b):
153153
loss_net.backward()
154154
train_loss.update(loss_net.item())
155155
optimizer.step() #gradient update
156-
157-
# This is neccesary to avoid a memory leak issue
158-
if opt.accelerated_chamfer:
159-
distChamfer.cham.clean()
160-
del distChamfer.cham
161-
162156
# VIZUALIZE
163157
if i%50 <= 0:
164158
vis.scatter(X = points.transpose(2,1).contiguous()[0].data.cpu(),
@@ -200,10 +194,6 @@ def distChamfer(a,b):
200194
pointsReconstructed = network(points)
201195
dist1, dist2 = distChamfer(points.transpose(2,1).contiguous(), pointsReconstructed)
202196
loss_net = (torch.mean(dist1)) + (torch.mean(dist2))
203-
if opt.accelerated_chamfer:
204-
distChamfer.cham.clean()
205-
del distChamfer.cham
206-
207197
val_loss.update(loss_net.item())
208198
dataset_test.perCatValueMeter[cat[0]].update(loss_net.item())
209199
if i%200 ==0 :

training/train_AE_AtlasNet_sphere.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,6 @@ def distChamfer(a,b):
158158
loss_net.backward()
159159
train_loss.update(loss_net.item())
160160
optimizer.step() #gradient update
161-
# This is neccesary to avoid a memory leak issue
162-
if opt.accelerated_chamfer:
163-
distChamfer.cham.clean()
164-
del distChamfer.cham
165-
166161
# VIZUALIZE
167162
if i%50 <= 0:
168163
vis.scatter(X = points.transpose(2,1).contiguous()[0].data.cpu(),
@@ -203,10 +198,6 @@ def distChamfer(a,b):
203198
pointsReconstructed = network(points)
204199
dist1, dist2 = distChamfer(points.transpose(2,1).contiguous(), pointsReconstructed)
205200
loss_net = (torch.mean(dist1)) + (torch.mean(dist2))
206-
if opt.accelerated_chamfer:
207-
distChamfer.cham.clean()
208-
del distChamfer.cham
209-
210201
val_loss.update(loss_net.item())
211202
dataset_test.perCatValueMeter[cat[0]].update(loss_net.item())
212203
if i%200 ==0 :

training/train_AE_Baseline.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,6 @@ def distChamfer(a,b):
148148
loss_net.backward()
149149
train_loss.update(loss_net.item())
150150
optimizer.step() #gradient update
151-
152-
# This is neccesary to avoid a memory leak issue
153-
if opt.accelerated_chamfer:
154-
distChamfer.cham.clean()
155-
del distChamfer.cham
156-
157151
# VIZUALIZE
158152
if i%50 <= 0:
159153
vis.scatter(X = points.transpose(2,1).contiguous()[0].data.cpu(),
@@ -194,10 +188,6 @@ def distChamfer(a,b):
194188
pointsReconstructed = network(points)
195189
dist1, dist2 = distChamfer(points.transpose(2,1).contiguous(), pointsReconstructed)
196190
loss_net = (torch.mean(dist1)) + (torch.mean(dist2))
197-
if opt.accelerated_chamfer:
198-
distChamfer.cham.clean()
199-
del distChamfer.cham
200-
201191
val_loss.update(loss_net.item())
202192
dataset_test.perCatValueMeter[cat[0]].update(loss_net.item())
203193
if i%200 ==0 :

training/train_SVR_AtlasNet.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,6 @@ def distChamfer(a,b):
126126
loss_net = (torch.mean(dist1)) + (torch.mean(dist2))
127127
val_loss.update(loss_net.item())
128128
# This is neccesary to avoid a memory leak issue
129-
if opt.accelerated_chamfer:
130-
distChamfer.cham.clean()
131-
del distChamfer.cham
132129

133130
print("Previous decoder performances : ", val_loss.avg)
134131

@@ -211,9 +208,6 @@ def distChamfer(a,b):
211208

212209
optimizer.step()
213210
# This is neccesary to avoid a memory leak issue
214-
if opt.accelerated_chamfer:
215-
distChamfer.cham.clean()
216-
del distChamfer.cham
217211

218212
# VIZUALIZE
219213
if i%50 <= 0:
@@ -252,9 +246,6 @@ def distChamfer(a,b):
252246
dist1, dist2 = distChamfer(points, pointsReconstructed)
253247
loss_net = (torch.mean(dist1)) + (torch.mean(dist2))
254248
# This is neccesary to avoid a memory leak issue
255-
if opt.accelerated_chamfer:
256-
distChamfer.cham.clean()
257-
del distChamfer.cham
258249
val_view_loss.update(loss_net.item())
259250
#UPDATE CURVES
260251
val_view_curve.append(val_view_loss.avg)
@@ -274,9 +265,6 @@ def distChamfer(a,b):
274265
loss_net = (torch.mean(dist1)) + (torch.mean(dist2))
275266
val_loss.update(loss_net.item())
276267
# This is neccesary to avoid a memory leak issue
277-
if opt.accelerated_chamfer:
278-
distChamfer.cham.clean()
279-
del distChamfer.cham
280268
dataset_test.perCatValueMeter[cat[0]].update(loss_net.item())
281269
if i%25 ==0 :
282270
vis.image(img[0].data.cpu().contiguous(), win = 'INPUT IMAGE VAL', opts = dict( title = "INPUT IMAGE TRAIN"))

training/train_SVR_AtlasNet_sphere.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,6 @@ def distChamfer(a,b):
125125
dist1, dist2 = distChamfer(points.transpose(2,1).contiguous(), pointsReconstructed)
126126
loss_net = (torch.mean(dist1)) + (torch.mean(dist2))
127127
val_loss.update(loss_net.item())
128-
if opt.accelerated_chamfer:
129-
distChamfer.cham.clean()
130-
del distChamfer.cham
131128
print("Previous decoder performances : ", val_loss.avg)
132129

133130
#Create network
@@ -200,9 +197,6 @@ def distChamfer(a,b):
200197
train_loss.update(loss_net.item())
201198

202199
optimizer.step()
203-
if opt.accelerated_chamfer:
204-
distChamfer.cham.clean()
205-
del distChamfer.cham
206200

207201
# VIZUALIZE
208202
if i%50 <= 0:
@@ -243,9 +237,6 @@ def distChamfer(a,b):
243237
loss_net = (torch.mean(dist1)) + (torch.mean(dist2))
244238
val_loss.update(loss_net.item())
245239
dataset_test.perCatValueMeter[cat[0]].update(loss_net.item())
246-
if opt.accelerated_chamfer:
247-
distChamfer.cham.clean()
248-
del distChamfer.cham
249240

250241
if i%25 ==0 :
251242
vis.image(img[0].data.cpu().contiguous(), win = 'INPUT IMAGE VAL', opts = dict( title = "INPUT IMAGE TRAIN"))

training/train_SVR_Baseline.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,6 @@ def distChamfer(a,b):
125125
dist1, dist2 = distChamfer(points.transpose(2,1).contiguous(), pointsReconstructed)
126126
loss_net = (torch.mean(dist1)) + (torch.mean(dist2))
127127
val_loss.update(loss_net.item())
128-
if opt.accelerated_chamfer:
129-
distChamfer.cham.clean()
130-
del distChamfer.cham
131128
print("Previous decoder performances : ", val_loss.avg)
132129

133130
#Create network
@@ -200,9 +197,6 @@ def distChamfer(a,b):
200197
train_loss.update(loss_net.item())
201198

202199
optimizer.step()
203-
if opt.accelerated_chamfer:
204-
distChamfer.cham.clean()
205-
del distChamfer.cham
206200

207201
# VIZUALIZE
208202
if i%50 <= 0:
@@ -241,9 +235,6 @@ def distChamfer(a,b):
241235
dist1, dist2 = distChamfer(points, pointsReconstructed)
242236
loss_net = (torch.mean(dist1)) + (torch.mean(dist2))
243237
val_view_loss.update(loss_net.item())
244-
if opt.accelerated_chamfer:
245-
distChamfer.cham.clean()
246-
del distChamfer.cham
247238

248239

249240
#UPDATE CURVES
@@ -266,9 +257,6 @@ def distChamfer(a,b):
266257
dist1, dist2 = distChamfer(points, pointsReconstructed)
267258
loss_net = (torch.mean(dist1)) + (torch.mean(dist2))
268259
val_loss.update(loss_net.item())
269-
if opt.accelerated_chamfer:
270-
distChamfer.cham.clean()
271-
del distChamfer.cham
272260

273261
dataset_test.perCatValueMeter[cat[0]].update(loss_net.item())
274262
if i%25 ==0 :

0 commit comments

Comments
 (0)