4 files changed: +15 -6 lines changed

Sphinx documentation configuration (conf.py):

     f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent"
 )
 author = "Adrien Lafage and Olivier Laurent"
-release = "0.5.2"
+release = "0.5.2.post0"
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

pyproject.toml:

@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
 
 [project]
 name = "torch_uncertainty"
-version = "0.5.2"
+version = "0.5.2.post0"
 authors = [
   { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" },
   { name = "Adrien Lafage", email = "adrienlafage@outlook.com" },

Packed layers test suite:

 import pytest
 import torch
 from einops import repeat
+from torch import nn
 
 from torch_uncertainty.layers.functional.packed import (
     packed_in_projection_packed,
@@ -420,13 +421,15 @@ class TestPackedLayerNorm:
     """Testing the PackedGroupNorm layer class."""
 
     def test_one_estimator_forward(self, batched_qkv: torch.Tensor) -> None:
+        layer_norm = nn.LayerNorm(6)
         packed_layer_norm = PackedLayerNorm(
            embed_dim=6,
            num_estimators=1,
            alpha=1,
         )
-        out = packed_layer_norm(batched_qkv)
-        assert out.shape == torch.Size([2, 3, 6])
+        pe_out = packed_layer_norm(batched_qkv)
+        layer_norm_out = layer_norm(batched_qkv)
+        assert torch.allclose(pe_out, layer_norm_out)
 
 
 class TestPackedMultiheadAttention:
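
The new test relies on the batched_qkv fixture from the test suite; the same equivalence can be checked in isolation. The sketch below is not part of the PR, and the import path for PackedLayerNorm is an assumption about the public API:

# Standalone sketch of the equivalence the new test asserts (not part of the PR).
# The import path for PackedLayerNorm is an assumption about the public API.
import torch
from torch import nn
from torch_uncertainty.layers import PackedLayerNorm

x = torch.randn(2, 3, 6)  # (batch, sequence, embedding), same shape as the batched_qkv fixture

packed = PackedLayerNorm(embed_dim=6, num_estimators=1, alpha=1)
plain = nn.LayerNorm(6)

# With a single estimator and alpha=1, PackedLayerNorm should now match nn.LayerNorm.
assert torch.allclose(packed(x), plain(x))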

Packed layers implementation (PackedLayerNorm docstring and forward):

@@ -704,6 +704,9 @@ def __init__(
         Shape:
             - Input: :math:`(N, *)` where :math:`*` means any number of additional dimensions.
             - Output: :math:`(N, *)` (same shape as input)
+
+        Warnings:
+            This layer is only suitable to replace ``nn.LayerNorm`` when only the last dimension is normalized.
         """
         super().__init__(
             num_groups=num_estimators,
@@ -715,15 +718,18 @@ def __init__(
         )
 
     def forward(self, inputs: Tensor) -> Tensor:
-        x = rearrange(inputs, "b ... h -> b h ...")
+        shapes = {f"d{i}": size for i, size in enumerate(inputs.shape[1:-1])}
+        shape_str = " ".join(shapes.keys())
+
+        x = rearrange(inputs, "b ... h -> (b ...) h")
         x = F.group_norm(
             x,
             self.num_groups,
             self.weight,
             self.bias,
             self.eps,
         )
-        return rearrange(x, "b h ... -> b ... h")
+        return rearrange(x, f"(b {shape_str}) h -> b {shape_str} h", **shapes)
 
 
 class PackedMultiheadAttention(nn.Module):
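
The forward change is what makes that equivalence hold: instead of moving the embedding axis into the channel position, every leading dimension is now folded into the batch axis, so F.group_norm normalizes each sample over the last dimension only. A minimal sketch of the idea using plain torch and einops follows; the shape names and tolerance are illustrative, not from the PR:

# Minimal sketch (not from the PR): with a single group, group_norm over the
# flattened leading dimensions reduces to layer normalization over the last dimension.
import torch
import torch.nn.functional as F
from einops import rearrange

inputs = torch.randn(2, 3, 6)  # (b, d0, h)

# Fold the leading dimensions into the batch axis, as the new forward does.
shapes = {f"d{i}": size for i, size in enumerate(inputs.shape[1:-1])}
shape_str = " ".join(shapes.keys())
x = rearrange(inputs, "b ... h -> (b ...) h")

# One estimator -> one group: each row is normalized over its h features.
x = F.group_norm(x, num_groups=1)

# Restore the original shape.
out = rearrange(x, f"(b {shape_str}) h -> b {shape_str} h", **shapes)

# Same result as a plain layer norm over the last dimension (no affine parameters here).
expected = F.layer_norm(inputs, normalized_shape=(6,))
assert torch.allclose(out, expected, atol=1e-6)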