Skip to content

Commit 9174cfe

Browse files
authored
Support changing the function used to calculate the euclidean distance in energy_score_empirical. (#3431)
* Support changing the function used to calculate the euclidean distance in energy_score_empirical. * Specify cdist when testing JIT compilation of energy_score_empirical. * Drop support for JIT compilation of energy_support_empirical (test was ineffective anyhow). * Remove unused import.
1 parent 428fba7 commit 9174cfe

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

pyro/ops/stats.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import math
55
import numbers
6-
from typing import List, Optional, Tuple, Union
6+
from typing import Callable, List, Optional, Tuple, Union
77

88
import torch
99
from torch.fft import irfft, rfft
@@ -511,7 +511,10 @@ def crps_empirical(pred, truth):
511511

512512

513513
def energy_score_empirical(
514-
pred: torch.Tensor, truth: torch.Tensor, pred_batch_size: Optional[int] = None
514+
pred: torch.Tensor,
515+
truth: torch.Tensor,
516+
pred_batch_size: Optional[int] = None,
517+
cdist: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.cdist,
515518
) -> torch.Tensor:
516519
r"""
517520
Computes negative Energy Score ES* (see equation 22 in [1]) between a
@@ -542,6 +545,10 @@ def energy_score_empirical(
542545
for the second leftmost dim which can have any value or be omitted.
543546
:param int pred_batch_size: If specified the predictions will be batched before calculation
544547
according to the specified batch size in order to reduce memory consumption.
548+
:param callable cdist: Function for calculating an euclidean distance (see
549+
https://github.com/pytorch/pytorch/issues/42479 for why you might need to change this in order to
550+
balance speed versus accuracy). Default is :any:`torch.cdist`.
551+
545552
:return: A tensor of shape ``truth.shape``.
546553
:rtype: torch.Tensor
547554
"""
@@ -558,8 +565,8 @@ def energy_score_empirical(
558565

559566
if pred_batch_size is None:
560567
retval = (
561-
torch.cdist(pred, truth).mean(dim=-2)
562-
- 0.5 * torch.cdist(pred, pred).mean(dim=[-1, -2])[..., None]
568+
cdist(pred, truth).mean(dim=-2)
569+
- 0.5 * cdist(pred, pred).mean(dim=[-1, -2])[..., None]
563570
)
564571
else:
565572
# Divide predictions into batches
@@ -571,10 +578,7 @@ def energy_score_empirical(
571578
# Calculate predictions distance to truth
572579
retval = (
573580
torch.stack(
574-
[
575-
torch.cdist(pred_batch, truth).sum(dim=-2)
576-
for pred_batch in pred_batches
577-
],
581+
[cdist(pred_batch, truth).sum(dim=-2) for pred_batch in pred_batches],
578582
dim=0,
579583
).sum(dim=0)
580584
/ pred_len
@@ -586,7 +590,7 @@ def energy_score_empirical(
586590
- 0.5
587591
* torch.stack(
588592
[
589-
torch.cdist(pred_batch, aux_pred_batch).sum(dim=[-1, -2])
593+
cdist(pred_batch, aux_pred_batch).sum(dim=[-1, -2])
590594
for pred_batch in pred_batches
591595
],
592596
dim=0,

tests/ops/test_stats.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,3 @@ def test_energy_score_empirical_batched_calculation(
373373
actual = energy_score_empirical(pred, truth, pred_batch_size=pred_batch_size)
374374
# Check accuracy
375375
assert_close(actual, expected)
376-
377-
378-
def test_jit_compilation():
379-
# Test that functions can be JIT compiled
380-
torch.jit.script(energy_score_empirical)

0 commit comments

Comments
 (0)