Skip to content

Commit f6b6b10

Browse files
authored
Merge pull request #186 from ENSTA-U2IS-AI/dev
🐛 Fix `PackedLayerNorm`
2 parents 4634590 + 3afca55 commit f6b6b10

File tree

4 files changed

+15
-6
lines changed

4 files changed

+15
-6
lines changed

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent"
2121
)
2222
author = "Adrien Lafage and Olivier Laurent"
23-
release = "0.5.2"
23+
release = "0.5.2.post0"
2424

2525
# -- General configuration ---------------------------------------------------
2626
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
44

55
[project]
66
name = "torch_uncertainty"
7-
version = "0.5.2"
7+
version = "0.5.2.post0"
88
authors = [
99
{ name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" },
1010
{ name = "Adrien Lafage", email = "adrienlafage@outlook.com" },

tests/layers/test_packed.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import torch
33
from einops import repeat
4+
from torch import nn
45

56
from torch_uncertainty.layers.functional.packed import (
67
packed_in_projection_packed,
@@ -420,13 +421,15 @@ class TestPackedLayerNorm:
420421
"""Testing the PackedLayerNorm layer class."""
421422

422423
def test_one_estimator_forward(self, batched_qkv: torch.Tensor) -> None:
424+
layer_norm = nn.LayerNorm(6)
423425
packed_layer_norm = PackedLayerNorm(
424426
embed_dim=6,
425427
num_estimators=1,
426428
alpha=1,
427429
)
428-
out = packed_layer_norm(batched_qkv)
429-
assert out.shape == torch.Size([2, 3, 6])
430+
pe_out = packed_layer_norm(batched_qkv)
431+
layer_norm_out = layer_norm(batched_qkv)
432+
assert torch.allclose(pe_out, layer_norm_out)
430433

431434

432435
class TestPackedMultiheadAttention:

torch_uncertainty/layers/packed.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,9 @@ def __init__(
704704
Shape:
705705
- Input: :math:`(N, *)` where :math:`*` means any number of additional dimensions.
706706
- Output: :math:`(N, *)` (same shape as input)
707+
708+
Warnings:
709+
This layer is only suitable to replace ``nn.LayerNorm`` when only the last dimension is normalized.
707710
"""
708711
super().__init__(
709712
num_groups=num_estimators,
@@ -715,15 +718,18 @@ def __init__(
715718
)
716719

717720
def forward(self, inputs: Tensor) -> Tensor:
718-
x = rearrange(inputs, "b ... h -> b h ...")
721+
shapes = {f"d{i}": size for i, size in enumerate(inputs.shape[1:-1])}
722+
shape_str = " ".join(shapes.keys())
723+
724+
x = rearrange(inputs, "b ... h -> (b ...) h")
719725
x = F.group_norm(
720726
x,
721727
self.num_groups,
722728
self.weight,
723729
self.bias,
724730
self.eps,
725731
)
726-
return rearrange(x, "b h ... -> b ... h")
732+
return rearrange(x, f"(b {shape_str}) h -> b {shape_str} h", **shapes)
727733

728734

729735
class PackedMultiheadAttention(nn.Module):

0 commit comments

Comments
 (0)