Skip to content
4 changes: 4 additions & 0 deletions configs/detection/edffnet/edffnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,7 @@
milestones=[8, 11],
gamma=0.1)
]

# Optimizer settings: an ``OptimWrapper`` around an SGD optimizer.
# NOTE(review): a reviewer left a "delete" comment on this block -- confirm
# whether this override should be dropped before merging.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001))
76 changes: 76 additions & 0 deletions configs/edit/_base_/datasets/cityscape_enhancement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Dataset settings for the Cityscapes foggy-image enhancement task.
dataset_type = 'CityscapeFoggyImageDataset'
data_root = 'data/Datasets/'

# Storage backend arguments shared by the loading transforms and the datasets.
file_client_args = dict(backend='disk')

# Training pipeline: load the input image and its gt image, run Resize and
# RandomFlip through ``TransBroadcaster`` (src_key='img', dst_key='gt_img'),
# then pack the results.
train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='LoadGTImageFromFile', file_client_args=file_client_args),
    dict(
        type='TransBroadcaster',
        src_key='img',
        dst_key='gt_img',
        transforms=[
            dict(type='Resize', scale=(512, 512), keep_ratio=True),
            dict(type='RandomFlip', prob=0.5),
        ],
    ),
    dict(type='PackInputs'),
]

# Test pipeline: same loading and resize, no random flip, and an explicit
# list of meta keys to pack.
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='LoadGTImageFromFile', file_client_args=file_client_args),
    dict(
        type='TransBroadcaster',
        src_key='img',
        dst_key='gt_img',
        transforms=[
            dict(type='Resize', scale=(512, 512), keep_ratio=True),
        ],
    ),
    dict(
        type='PackInputs',
        meta_keys=(
            'img_id',
            'img_path',
            'ori_shape',
            'img_shape',
            'scale_factor',
        ),
    ),
]

# Training loader: shuffled, batches of 2 foggy/clean pairs.
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        metainfo=dict(
            dataset_type='cityscape_enhancement',
            task_name='enhancement',
        ),
        ann_file='cityscape_foggy/train/train.txt',
        # Key order matters to the dataset's path building: 'img' first.
        data_prefix=dict(
            img='cityscape_foggy/train/',
            gt_img='cityscape/train/',
        ),
        search_key='img',
        img_suffix=dict(img='png', gt_img='png'),
        file_client_args=file_client_args,
        pipeline=train_pipeline,
        split_str='_foggy',
    ),
)

# Validation loader: one image at a time, fixed order, nothing dropped.
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        metainfo=dict(
            dataset_type='cityscape_enhancement',
            task_name='enhancement',
        ),
        ann_file='cityscape_foggy/test/test.txt',
        data_prefix=dict(
            img='cityscape_foggy/test/',
            gt_img='cityscape/test/',
        ),
        search_key='img',
        img_suffix=dict(img='png', gt_img='png'),
        file_client_args=file_client_args,
        pipeline=test_pipeline,
        split_str='_foggy',
    ),
)
# Testing reuses the validation loader object as-is.
test_dataloader = val_dataloader

# Evaluation: MSE computed with gt_key='img' and pred_key='pred_img'.
val_evaluator = [
    dict(type='MSE', gt_key='img', pred_key='pred_img'),
]
test_evaluator = val_evaluator
39 changes: 39 additions & 0 deletions configs/edit/aodnet/aodnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
_base_ = [
'../_base_/datasets/cityscape_enhancement.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='lqit.BaseEditModel',
data_preprocessor=dict(
type='lqit.EditDataPreprocessor',
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
bgr_to_rgb=True,
pad_size_divisor=32,
gt_name='img'),
generator=dict(
_scope_='lqit',
type='AODNetGenerator',
model=dict(type='AODNet'),
pixel_loss=dict(type='MSELoss', loss_weight=1.0)))

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=10, val_interval=1)
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=False,
begin=0,
end=1000),
dict(
type='MultiStepLR',
begin=0,
end=10,
by_epoch=True,
milestones=[6, 9],
gamma=0.5)
]

# Adam takes no ``momentum`` argument (passing one raises ``TypeError`` when
# the optimizer is constructed); its first-moment decay is the first entry of
# ``betas``, so the intended 0.9 is expressed there.
# NOTE(review): if the inherited schedule defines extra optimizer keys (e.g.
# SGD's ``momentum``), config inheritance may merge them into this dict --
# confirm against ``schedule_1x.py`` whether ``_delete_=True`` is needed.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='Adam', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001))
3 changes: 2 additions & 1 deletion lqit/edit/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .basic_image_dataset import BasicImageDataset
from .cityscape_foggy_dataset import CityscapeFoggyImageDataset

# Public dataset classes exported by this package.
__all__ = ['BasicImageDataset', 'CityscapeFoggyImageDataset']
95 changes: 95 additions & 0 deletions lqit/edit/datasets/cityscape_foggy_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Modified from https://github.com/open-mmlab/mmediting/tree/1.x/
import os.path as osp
from typing import Callable, List, Optional, Union

from lqit.registry import DATASETS
from .basic_image_dataset import BasicImageDataset


@DATASETS.register_module()
class CityscapeFoggyImageDataset(BasicImageDataset):
    """CityscapeFoggyImageDataset for pixel-level vision tasks that have
    aligned gts.

    The gt image names and the input image names do not match: one gt image
    corresponds to three input images. The gt file name is therefore derived
    from the image name by keeping only the part before ``split_str``.

    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for data. Defaults to
            dict(img='').
        mapping_table (dict): Mapping table for data. Each value is used as
            a ``str.format`` template that receives the image id.
            Defaults to dict().
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        search_key (str, optional): The key used for searching the folder to
            get data_list. Defaults to None.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmengine.fileio.FileClient` for details.
            Defaults to dict(backend='disk').
        img_suffix (str or dict[str]): Image suffix that we are interested in.
            Defaults to 'jpg'.
        recursive (bool): If set to True, recursively scan the
            directory. Defaults to False.
        split_str (str): Separator used to turn an image name into its gt
            image name; everything from its first occurrence on is dropped.
            Defaults to '_foggy'.
        **kwargs: Extra keyword arguments forwarded to
            :class:`BasicImageDataset`.
    """

    def __init__(self,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(img=''),
                 mapping_table: dict = dict(),
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 search_key: Optional[str] = None,
                 file_client_args: dict = dict(backend='disk'),
                 img_suffix: Union[str, dict] = 'jpg',
                 recursive: bool = False,
                 split_str: str = '_foggy',
                 **kwargs) -> None:

        # Assigned before ``super().__init__`` -- presumably because the base
        # constructor may already call ``load_data_list``, which reads it.
        # NOTE(review): confirm against BasicImageDataset.
        self.split_str = split_str

        super().__init__(
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            mapping_table=mapping_table,
            pipeline=pipeline,
            test_mode=test_mode,
            search_key=search_key,
            file_client_args=file_client_args,
            img_suffix=img_suffix,
            recursive=recursive,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load data list from folder or annotation file.

        For every image id one dict is built that holds ``key``, ``img_id``
        and a ``'{key}_path'`` entry for each key of ``self.data_prefix``.

        Returns:
            list[dict]: A list of annotation.
        """
        img_ids = self._get_img_list()

        data_list = []
        # deal with img and gt img path
        for img_id in img_ids:
            data = dict(key=img_id, img_id=img_id)
            for key in self.data_prefix:
                # Always start from the untouched ``img_id`` so that the file
                # name of one key never leaks into the next key, whatever the
                # order of ``data_prefix`` is.
                file_stem = self.mapping_table[key].format(img_id)
                # The gt img name and img name do not match.
                # one gt img corresponds to three imgs
                if key == 'gt_img':
                    file_stem = file_stem.split(self.split_str)[0]

                # NOTE(review): indexes ``img_suffix`` and ``mapping_table``
                # by key -- assumes the base class provides per-key mappings.
                data[f'{key}_path'] = osp.join(
                    self.data_prefix[key],
                    f'{file_stem}.{self.img_suffix[key]}')
            data_list.append(data)
        return data_list
1 change: 1 addition & 0 deletions lqit/edit/models/editors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .aodnet import * # noqa: F401,F403
from .unet import * # noqa: F401,F403
from .zero_dce import * # noqa: F401,F403
4 changes: 4 additions & 0 deletions lqit/edit/models/editors/aodnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .aodnet import AODNet
from .aodnet_generator import AODNetGenerator

__all__ = ['AODNet', 'AODNetGenerator']
40 changes: 40 additions & 0 deletions lqit/edit/models/editors/aodnet/aodnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from lqit.registry import MODELS


@MODELS.register_module()
class AODNet(nn.Module):
    """AOD-Net: All-in-One Dehazing Network.

    https://ieeexplore.ieee.org/document/8237773

    Five convolutions (kernel sizes 1, 3, 5, 7, 3) with channel-wise
    concatenation of earlier features estimate a 3-channel map ``k``. The
    output is ``relu(k * x - k + b)`` with the constant ``b = 1``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)
        self.conv2 = nn.Conv2d(
            in_channels=3, out_channels=3, kernel_size=3, padding=1)
        # conv3/conv4 consume two concatenated 3-channel feature maps.
        self.conv3 = nn.Conv2d(
            in_channels=6, out_channels=3, kernel_size=5, padding=2)
        self.conv4 = nn.Conv2d(
            in_channels=6, out_channels=3, kernel_size=7, padding=3)
        # conv5 consumes all four earlier feature maps (4 x 3 channels).
        self.conv5 = nn.Conv2d(
            in_channels=12, out_channels=3, kernel_size=3, padding=1)
        # Constant bias of the output formula.
        self.b = 1

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x (Tensor): Input batch of shape (N, 3, H, W); features are
                concatenated along dim 1.

        Returns:
            Tensor: ``relu(k * x - k + b)``, same shape as ``x``.

        Raises:
            ValueError: If the estimated map ``k`` and ``x`` differ in shape.
        """
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))
        x3 = F.relu(self.conv3(torch.cat((x1, x2), 1)))
        x4 = F.relu(self.conv4(torch.cat((x2, x3), 1)))
        k = F.relu(self.conv5(torch.cat((x1, x2, x3, x4), 1)))

        # An explicit raise (rather than ``assert``) keeps the check active
        # under ``python -O``; ValueError is still an ``Exception`` subclass.
        if k.size() != x.size():
            raise ValueError(
                f'The estimated K map has shape {tuple(k.size())} but the '
                f'input hazy image has shape {tuple(x.size())}; '
                'they must be identical.')

        return F.relu(k * x - k + self.b)
27 changes: 27 additions & 0 deletions lqit/edit/models/editors/aodnet/aodnet_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import List

from lqit.edit.models.base_models import BaseGenerator
from lqit.edit.structures import BatchPixelData
from lqit.registry import MODELS
from lqit.utils.typing import ConfigType, OptMultiConfig


@MODELS.register_module()
class AODNetGenerator(BaseGenerator):
    """Generator for AOD-Net that is trained with a single pixel loss.

    Args:
        model (ConfigType): Config passed to :class:`BaseGenerator` as
            ``model``.
        pixel_loss (ConfigType): Config of the pixel loss. Defaults to
            dict(type='MSELoss', loss_weight=1.0).
        init_cfg (OptMultiConfig): Initialization config passed to
            :class:`BaseGenerator`. Defaults to None.
        **kwargs: Extra keyword arguments forwarded unchanged to
            :class:`BaseGenerator`.
    """

    def __init__(self,
                 model: ConfigType,
                 pixel_loss: ConfigType = dict(
                     type='MSELoss', loss_weight=1.0),
                 init_cfg: OptMultiConfig = None,
                 **kwargs) -> None:
        # Extra keys are passed through so that options understood by the
        # base generator can be set from the config (requested in review).
        super().__init__(
            model=model, pixel_loss=pixel_loss, init_cfg=init_cfg, **kwargs)

    def loss(self, loss_input: BatchPixelData, batch_img_metas: List[dict]):
        """Calculate the loss based on the outputs of generator.

        Args:
            loss_input (BatchPixelData): Its ``output`` and ``gt`` attributes
                are compared by the pixel loss.
            batch_img_metas (list[dict]): Not used by this implementation.

        Returns:
            dict: ``dict(pixel_loss=...)`` with the pixel loss value.
        """
        # ``self.pixel_loss`` is not set in this class -- presumably built by
        # BaseGenerator from the ``pixel_loss`` config.
        pixel_loss = self.pixel_loss(loss_input.output, loss_input.gt)
        return dict(pixel_loss=pixel_loss)