Skip to content

Commit 4634590

Browse files
authored
Merge pull request #184 from ENSTA-U2IS-AI/dev
🔨 Rework FrostImages dataset & Rename a plotting function
2 parents f0b5dea + f22f039 commit 4634590

File tree

8 files changed

+28
-65
lines changed

8 files changed

+28
-65
lines changed

auto_tutorial_source/Bayesian_Methods/tutorial_bayesian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@
119119
#
120120
# Now that the model is trained, let's test it on MNIST.
121121
# Please note that we apply a reshape to the logits to determine the dimension corresponding to the ensemble
122-
# and to the batch. As for TorchUncertainty 0.5.1, the ensemble dimension is merged with the batch dimension
122+
# and to the batch. As for TorchUncertainty 0.5.2, the ensemble dimension is merged with the batch dimension
123123
# in this order (num_estimator x batch, classes).
124124

125125
import matplotlib.pyplot as plt

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent"
2121
)
2222
author = "Adrien Lafage and Olivier Laurent"
23-
release = "0.5.1"
23+
release = "0.5.2"
2424

2525
# -- General configuration ---------------------------------------------------
2626
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
44

55
[project]
66
name = "torch_uncertainty"
7-
version = "0.5.1"
7+
version = "0.5.2"
88
authors = [
99
{ name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" },
1010
{ name = "Adrien Lafage", email = "adrienlafage@outlook.com" },
@@ -41,7 +41,7 @@ dependencies = [
4141

4242
[project.optional-dependencies]
4343
experiments = ["tensorboard", "huggingface-hub>=0.31", "safetensors"]
44-
image = ["kornia", "h5py", "opencv-python"]
44+
image = ["kornia", "h5py", "opencv-python", "torch-uncertainty-assets"]
4545
tabular = ["pandas"]
4646
dev = [
4747
"torch_uncertainty[experiments,image]",

torch_uncertainty/datasets/frost.py

Lines changed: 18 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,43 @@
1-
import logging
21
from collections.abc import Callable
2+
from importlib import util
3+
from importlib.abc import Traversable
4+
from importlib.resources import files
35
from pathlib import Path
46
from typing import Any
57

68
from PIL import Image
79
from torchvision.datasets import VisionDataset
8-
from torchvision.datasets.utils import (
9-
check_integrity,
10-
download_and_extract_archive,
11-
)
1210

11+
FROST_ASSETS_MOD = "torch_uncertainty_assets.frost"
12+
tu_assets_installed = util.find_spec("torch_uncertainty_assets")
1313

14-
def pil_loader(path: Path) -> Image.Image:
15-
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
14+
15+
def pil_loader(path: Path | Traversable) -> Image.Image:
1616
with path.open("rb") as f:
1717
img = Image.open(f)
1818
return img.convert("RGB")
1919

2020

21-
class FrostImages(VisionDataset): # TODO: Use ImageFolder
22-
url = "https://zenodo.org/records/10438904/files/frost.zip"
23-
zip_md5 = "d82f29f620d43a68e71e34b28f7c35cb"
24-
filename = "frost.zip"
25-
samples = [
26-
"frost1.png",
27-
"frost2.png",
28-
"frost3.jpg",
29-
"frost4.jpg",
30-
"frost5.jpg",
31-
]
32-
21+
class FrostImages(VisionDataset):
3322
def __init__(
3423
self,
35-
root: str | Path,
36-
transform: Callable[..., Any] | None,
24+
transform: Callable[..., Any] | None = None,
3725
target_transform: Callable[..., Any] | None = None,
38-
download: bool = False,
3926
) -> None:
40-
self.root = Path(root)
41-
42-
if download:
43-
self.download()
44-
45-
if not self._check_integrity():
46-
raise RuntimeError(
47-
"Dataset not found or corrupted. You can use download=True to download it."
27+
if not tu_assets_installed: # coverage: ignore
28+
raise ImportError(
29+
"The torch-uncertainty-assets library is not installed. Please install"
30+
"torch_uncertainty with the image option:"
31+
"""pip install -U "torch_uncertainty[image]"."""
4832
)
49-
5033
super().__init__(
51-
self.root / "frost",
34+
FROST_ASSETS_MOD,
5235
transform=transform,
5336
target_transform=target_transform,
5437
)
5538
self.loader = pil_loader
56-
57-
def _check_integrity(self) -> bool:
58-
fpath = self.root / self.filename
59-
return check_integrity(
60-
fpath,
61-
self.zip_md5,
62-
)
63-
64-
def download(self) -> None:
65-
if self._check_integrity():
66-
logging.info("Files already downloaded and verified")
67-
return
68-
69-
download_and_extract_archive(
70-
self.url,
71-
download_root=self.root,
72-
filename=self.filename,
73-
md5=self.zip_md5,
74-
)
75-
logging.info("Downloaded %s to %s.", self.filename, self.root)
39+
sample_path = files(FROST_ASSETS_MOD)
40+
self.samples = [sample_path.joinpath(f"frost{i}.jpg") for i in range(1, 6)]
7641

7742
def __getitem__(self, index: int) -> Any:
7843
"""Get the samples of the dataset.
@@ -83,8 +48,7 @@ def __getitem__(self, index: int) -> Any:
8348
Returns:
8449
tuple: (sample, target) where target is class_index of the target class.
8550
"""
86-
path = self.root / self.samples[index]
87-
sample = self.loader(path)
51+
sample = self.loader(self.samples[index])
8852
if self.transform is not None:
8953
sample = self.transform(sample)
9054
return sample

torch_uncertainty/routines/segmentation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
)
3838
from torch_uncertainty.post_processing import PostProcessing
3939
from torch_uncertainty.utils import csv_writer
40-
from torch_uncertainty.utils.plotting import show
40+
from torch_uncertainty.utils.plotting import show_segmentation_predictions
4141

4242

4343
class SegmentationRoutine(LightningModule):
@@ -421,7 +421,7 @@ def log_segmentation_plots(self) -> None:
421421

422422
self.logger.experiment.add_figure(
423423
f"Segmentation results/{i}",
424-
show(pred_mask, gt_mask),
424+
show_segmentation_predictions(pred_mask, gt_mask),
425425
)
426426

427427
def subsample(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:

torch_uncertainty/transforms/corruption.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ def __init__(self, severity: int, seed: int | None = None) -> None:
584584
super().__init__(severity)
585585
self.rng = np.random.default_rng(seed)
586586
self.mix = [(1, 0.4), (0.8, 0.6), (0.7, 0.7), (0.65, 0.7), (0.6, 0.75)][severity - 1]
587-
self.frost_ds = FrostImages("./data", download=True, transform=ToTensor())
587+
self.frost_ds = FrostImages(transform=ToTensor())
588588

589589
def forward(self, img: Tensor) -> Tensor:
590590
if self.severity == 0:

torch_uncertainty/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
from .evaluation_loop import TUEvaluationLoop
66
from .hub import load_hf
77
from .misc import csv_writer
8-
from .plotting import plot_hist, show
8+
from .plotting import plot_hist, show_segmentation_predictions
99
from .trainer import TUTrainer
1010
from .transforms import interpolation_modes_from_str

torch_uncertainty/utils/plotting.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from torch import Tensor
88

99

10-
def show(prediction: Tensor, target: Tensor) -> Figure:
10+
def show_segmentation_predictions(prediction: Tensor, target: Tensor) -> Figure:
1111
imgs = [prediction, target]
12-
fig, axs = plt.subplots(ncols=len(imgs), figsize=(12, 6))
12+
fig, axs = plt.subplots(ncols=2, figsize=(12, 6), dpi=300)
1313
for i, img in enumerate(imgs):
1414
img = img.detach()
1515
img = F.to_pil_image(img)
@@ -18,7 +18,6 @@ def show(prediction: Tensor, target: Tensor) -> Figure:
1818

1919
axs[0].set(title="Prediction")
2020
axs[1].set(title="Ground Truth")
21-
2221
return fig
2322

2423

0 commit comments

Comments
 (0)