3
3
4
4
import math
5
5
import numbers
6
- from typing import List , Tuple , Union
6
+ from typing import List , Optional , Tuple , Union
7
7
8
8
import torch
9
9
from torch .fft import irfft , rfft
@@ -510,7 +510,9 @@ def crps_empirical(pred, truth):
510
510
return (pred - truth ).abs ().mean (0 ) - (diff * weight ).sum (0 ) / num_samples ** 2
511
511
512
512
513
- def energy_score_empirical (pred : torch .Tensor , truth : torch .Tensor ) -> torch .Tensor :
513
+ def energy_score_empirical (
514
+ pred : torch .Tensor , truth : torch .Tensor , pred_batch_size : Optional [int ] = None
515
+ ) -> torch .Tensor :
514
516
r"""
515
517
Computes negative Energy Score ES* (see equation 22 in [1]) between a
516
518
set of multivariate samples ``pred`` and a true data vector ``truth``. Running time
@@ -538,6 +540,8 @@ def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Ten
538
540
The leftmost dim is that of the multivariate sample.
539
541
:param torch.Tensor truth: A tensor of true observations with same shape as ``pred`` except
540
542
for the second leftmost dim which can have any value or be omitted.
543
+ :param int pred_batch_size: If specified the predictions will be batched before calculation
544
+ according to the specified batch size in order to reduce memory consumption.
541
545
:return: A tensor of shape ``truth.shape``.
542
546
:rtype: torch.Tensor
543
547
"""
@@ -552,10 +556,44 @@ def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Ten
552
556
"Actual shapes: {} versus {}" .format (pred .shape , truth .shape )
553
557
)
554
558
555
- retval = (
556
- torch .cdist (pred , truth ).mean (dim = - 2 )
557
- - 0.5 * torch .cdist (pred , pred ).mean (dim = [- 1 , - 2 ])[..., None ]
558
- )
559
+ if pred_batch_size is None :
560
+ retval = (
561
+ torch .cdist (pred , truth ).mean (dim = - 2 )
562
+ - 0.5 * torch .cdist (pred , pred ).mean (dim = [- 1 , - 2 ])[..., None ]
563
+ )
564
+ else :
565
+ # Divide predictions into batches
566
+ pred_len = pred .shape [- 2 ]
567
+ pred_batches = []
568
+ while pred .numel () > 0 :
569
+ pred_batches .append (pred [..., :pred_batch_size , :])
570
+ pred = pred [..., pred_batch_size :, :]
571
+ # Calculate predictions distance to truth
572
+ retval = (
573
+ torch .stack (
574
+ [
575
+ torch .cdist (pred_batch , truth ).sum (dim = - 2 )
576
+ for pred_batch in pred_batches
577
+ ],
578
+ dim = 0 ,
579
+ ).sum (dim = 0 )
580
+ / pred_len
581
+ )
582
+ # Calculate predictions self distance
583
+ for aux_pred_batch in pred_batches :
584
+ retval = (
585
+ retval
586
+ - 0.5
587
+ * torch .stack (
588
+ [
589
+ torch .cdist (pred_batch , aux_pred_batch ).sum (dim = [- 1 , - 2 ])
590
+ for pred_batch in pred_batches
591
+ ],
592
+ dim = 0 ,
593
+ ).sum (dim = 0 )[..., None ]
594
+ / pred_len
595
+ / pred_len
596
+ )
559
597
560
598
if remove_leftmost_dim :
561
599
retval = retval [..., 0 ]
0 commit comments