Skip to content

Commit 04c371f

Browse files
BenZickelBen Zickel
andauthored
Add batched calculation option to energy_score_empirical in order to reduce memory consumption (#3402)
* Add batched calculation option to energy_score_empirical in order to reduce memory consumption. * Replace native Python sum with torch.stack(...).sum(). --------- Co-authored-by: Ben Zickel <ben@vesttoo.com>
1 parent 0d3243a commit 04c371f

File tree

2 files changed

+67
-6
lines changed

2 files changed

+67
-6
lines changed

pyro/ops/stats.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import math
55
import numbers
6-
from typing import List, Tuple, Union
6+
from typing import List, Optional, Tuple, Union
77

88
import torch
99
from torch.fft import irfft, rfft
@@ -510,7 +510,9 @@ def crps_empirical(pred, truth):
510510
return (pred - truth).abs().mean(0) - (diff * weight).sum(0) / num_samples**2
511511

512512

513-
def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
513+
def energy_score_empirical(
514+
pred: torch.Tensor, truth: torch.Tensor, pred_batch_size: Optional[int] = None
515+
) -> torch.Tensor:
514516
r"""
515517
Computes negative Energy Score ES* (see equation 22 in [1]) between a
516518
set of multivariate samples ``pred`` and a true data vector ``truth``. Running time
@@ -538,6 +540,8 @@ def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Ten
538540
The leftmost dim is that of the multivariate sample.
539541
:param torch.Tensor truth: A tensor of true observations with same shape as ``pred`` except
540542
for the second leftmost dim which can have any value or be omitted.
543+
:param int pred_batch_size: If specified, the predictions are split into batches of at most
544+
this size (which must be a positive integer) before calculation, in order to reduce memory consumption.
541545
:return: A tensor of shape ``truth.shape``.
542546
:rtype: torch.Tensor
543547
"""
@@ -552,10 +556,44 @@ def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Ten
552556
"Actual shapes: {} versus {}".format(pred.shape, truth.shape)
553557
)
554558

555-
retval = (
556-
torch.cdist(pred, truth).mean(dim=-2)
557-
- 0.5 * torch.cdist(pred, pred).mean(dim=[-1, -2])[..., None]
558-
)
559+
if pred_batch_size is None:
560+
retval = (
561+
torch.cdist(pred, truth).mean(dim=-2)
562+
- 0.5 * torch.cdist(pred, pred).mean(dim=[-1, -2])[..., None]
563+
)
564+
else:
565+
# Divide predictions into batches
566+
pred_len = pred.shape[-2]
567+
pred_batches = []
568+
while pred.numel() > 0:
569+
pred_batches.append(pred[..., :pred_batch_size, :])
570+
pred = pred[..., pred_batch_size:, :]
571+
# Calculate predictions distance to truth
572+
retval = (
573+
torch.stack(
574+
[
575+
torch.cdist(pred_batch, truth).sum(dim=-2)
576+
for pred_batch in pred_batches
577+
],
578+
dim=0,
579+
).sum(dim=0)
580+
/ pred_len
581+
)
582+
# Calculate predictions self distance
583+
for aux_pred_batch in pred_batches:
584+
retval = (
585+
retval
586+
- 0.5
587+
* torch.stack(
588+
[
589+
torch.cdist(pred_batch, aux_pred_batch).sum(dim=[-1, -2])
590+
for pred_batch in pred_batches
591+
],
592+
dim=0,
593+
).sum(dim=0)[..., None]
594+
/ pred_len
595+
/ pred_len
596+
)
559597

560598
if remove_leftmost_dim:
561599
retval = retval[..., 0]

tests/ops/test_stats.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,26 @@ def test_multivariate_energy_score(sample_dim, num_samples=10000):
355355
rtol=0.02,
356356
)
357357
assert energy_score * 1.02 < energy_score_uncorrelated
358+
359+
360+
@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)])
@pytest.mark.parametrize("sample_dim", [30, 100])
@pytest.mark.parametrize(
    "num_samples, pred_batch_size", [(100, 10), (100, 30), (100, 100), (100, 200)]
)
def test_energy_score_empirical_batched_calculation(
    batch_shape, sample_dim, num_samples, pred_batch_size
):
    """
    Check that the batched code path of ``energy_score_empirical`` agrees
    with the unbatched one.

    The parametrized batch sizes cover a size that divides the number of
    samples evenly (10), one that does not (30), one equal to the number of
    samples (100) and one larger than it (200).
    """
    # Random inputs: a single truth vector per batch element and
    # ``num_samples`` predicted vectors per batch element.
    truth_shape = batch_shape + (sample_dim,)
    pred_shape = batch_shape + (num_samples, sample_dim)
    truth = torch.randn(truth_shape)
    pred = torch.randn(pred_shape)

    # The unbatched result is the reference the batched result must match.
    reference = energy_score_empirical(pred, truth)
    batched = energy_score_empirical(pred, truth, pred_batch_size=pred_batch_size)

    assert_close(batched, reference)
376+
377+
378+
def test_jit_compilation():
    """Check that ``energy_score_empirical`` can be compiled by TorchScript."""
    # No explicit assertion: the test fails if scripting raises an error.
    torch.jit.script(energy_score_empirical)

0 commit comments

Comments
 (0)