diff --git a/configs/detection/edffnet/edffnet.py b/configs/detection/edffnet/edffnet.py
index f42aa1c..f950973 100644
--- a/configs/detection/edffnet/edffnet.py
+++ b/configs/detection/edffnet/edffnet.py
@@ -2,13 +2,13 @@
 
 model = dict(
     type='EDFFNet',
-    backbone=dict(norm_eval=True),
+    backbone=dict(norm_eval=False),
     neck=dict(
         type='DFFPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         start_level=1,
-        add_extra_convs='on_input',
+        add_extra_convs='on_output',
         shape_level=2,
         num_outs=5),
     enhance_head=dict(
@@ -19,37 +19,21 @@
         loss_enhance=dict(type='mmdet.L1Loss', loss_weight=0.7),
         gt_preprocessor=dict(
             type='lqit.GTPixelPreprocessor',
-            mean=[128],
-            std=[57.12],
+            mean=[123.675],
+            std=[58.395],
             pad_size_divisor=32,
-            element_name='edge')))
+            element_name='edge')),
+)
 
 # dataset settings
 train_pipeline = [
-    dict(type='LoadImageFromFile'),
+    dict(
+        type='LoadImageFromFile',
+        file_client_args={{_base_.file_client_args}}),
     dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', prob=0.5),
     dict(type='lqit.GetEdgeGTFromImage', method='scharr'),
-    dict(
-        type='lqit.TransBroadcaster',
-        src_key='img',
-        dst_key='gt_edge',
-        transforms=[
-            dict(type='Resize', scale=(1333, 800), keep_ratio=True),
-            dict(type='RandomFlip', prob=0.5)
-        ]),
-    dict(type='lqit.PackInputs', )
+    dict(type='lqit.PackInputs')
 ]
 train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
-
-param_scheduler = [
-    dict(
-        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
-        end=1000),
-    dict(
-        type='MultiStepLR',
-        begin=0,
-        end=12,
-        by_epoch=True,
-        milestones=[8, 11],
-        gamma=0.1)
-]
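Note on the pipeline above: the edge ground truth is now generated after Resize and RandomFlip, so the Scharr edge map is computed on the already-augmented image and no TransBroadcaster synchronization is needed. For orientation, a minimal sketch of what a Scharr-based edge target can look like, assuming an OpenCV-style implementation (the actual lqit.GetEdgeGTFromImage transform may differ):

import cv2
import numpy as np

def scharr_edge(img: np.ndarray) -> np.ndarray:
    # gradient magnitude of the grayscale image via the Scharr operator
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    grad_x = cv2.Scharr(gray, cv2.CV_32F, 1, 0)
    grad_y = cv2.Scharr(gray, cv2.CV_32F, 0, 1)
    edge = cv2.magnitude(grad_x, grad_y)
    # rescale to [0, 255] so the edge map can serve as a pixel-level target
    return cv2.normalize(edge, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)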
diff --git a/configs/edit/_base_/datasets/cityscape_enhancement.py b/configs/edit/_base_/datasets/cityscape_enhancement.py
new file mode 100644
index 0000000..b623cf0
--- /dev/null
+++ b/configs/edit/_base_/datasets/cityscape_enhancement.py
@@ -0,0 +1,76 @@
+# dataset settings
+dataset_type = 'CityscapeFoggyImageDataset'
+data_root = 'data/Datasets/'
+
+file_client_args = dict(backend='disk')
+
+train_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='LoadGTImageFromFile', file_client_args=file_client_args),
+    dict(
+        type='TransBroadcaster',
+        src_key='img',
+        dst_key='gt_img',
+        transforms=[
+            dict(type='Resize', scale=(512, 512), keep_ratio=True),
+            dict(type='RandomFlip', prob=0.5),
+        ]),
+    dict(type='PackInputs')
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='LoadGTImageFromFile', file_client_args=file_client_args),
+    dict(
+        type='TransBroadcaster',
+        src_key='img',
+        dst_key='gt_img',
+        transforms=[dict(type='Resize', scale=(512, 512), keep_ratio=True)]),
+    dict(
+        type='PackInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+train_dataloader = dict(
+    batch_size=2,
+    num_workers=2,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        metainfo=dict(
+            dataset_type='cityscape_enhancement', task_name='enhancement'),
+        ann_file='cityscape_foggy/train/train.txt',
+        data_prefix=dict(
+            img='cityscape_foggy/train/', gt_img='cityscape/train/'),
+        search_key='img',
+        img_suffix=dict(img='png', gt_img='png'),
+        file_client_args=file_client_args,
+        pipeline=train_pipeline,
+        split_str='_foggy'))
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=2,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        test_mode=True,
+        metainfo=dict(
+            dataset_type='cityscape_enhancement', task_name='enhancement'),
+        ann_file='cityscape_foggy/test/test.txt',
+        data_prefix=dict(
+            img='cityscape_foggy/test/', gt_img='cityscape/test/'),
+        search_key='img',
+        img_suffix=dict(img='png', gt_img='png'),
+        file_client_args=file_client_args,
+        pipeline=test_pipeline,
+        split_str='_foggy'))
+test_dataloader = val_dataloader
+
+val_evaluator = [
+    dict(type='MSE', gt_key='img', pred_key='pred_img'),
+]
+test_evaluator = val_evaluator
diff --git a/configs/edit/aodnet/aodnet.py b/configs/edit/aodnet/aodnet.py
new file mode 100644
index 0000000..f2ddc76
--- /dev/null
+++ b/configs/edit/aodnet/aodnet.py
@@ -0,0 +1,39 @@
+_base_ = [
+    '../_base_/datasets/cityscape_enhancement.py',
+    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+model = dict(
+    type='lqit.BaseEditModel',
+    data_preprocessor=dict(
+        type='lqit.EditDataPreprocessor',
+        mean=[0.5, 0.5, 0.5],
+        std=[0.5, 0.5, 0.5],
+        bgr_to_rgb=True,
+        pad_size_divisor=32,
+        gt_name='img'),
+    generator=dict(
+        _scope_='lqit',
+        type='AODNetGenerator',
+        model=dict(type='AODNet'),
+        pixel_loss=dict(type='MSELoss', loss_weight=1.0)))
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=10, val_interval=1)
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=0.0001,
+        by_epoch=False,
+        begin=0,
+        end=1000),
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=10,
+        by_epoch=True,
+        milestones=[6, 9],
+        gamma=0.5)
+]
+
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='Adam', lr=0.0001, weight_decay=0.0001))
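For reference, the schedule above works out as follows (a worked sketch assuming MMEngine's LinearLR/MultiStepLR semantics): the learning rate warms up linearly from 1e-4 * 1e-4 = 1e-8 to the base 1e-4 over the first 1000 iterations, then is halved after epochs 6 and 9:

# epoch-level base lr under MultiStepLR(gamma=0.5, milestones=[6, 9])
base_lr, gamma, milestones = 1e-4, 0.5, [6, 9]
for epoch in range(10):
    lr = base_lr * gamma ** sum(epoch >= m for m in milestones)
    print(epoch, lr)  # epochs 0-5: 1e-4, epochs 6-8: 5e-5, epoch 9: 2.5e-5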
diff --git a/lqit/detection/datasets/__init__.py b/lqit/detection/datasets/__init__.py
index 188b7e3..1a736c5 100644
--- a/lqit/detection/datasets/__init__.py
+++ b/lqit/detection/datasets/__init__.py
@@ -1,9 +1,4 @@
-from .class_names import *  # noqa: F401,F403
 from .rtts import RTTSCocoDataset
 from .urpc import URPCCocoDataset, URPCXMLDataset
-from .xml_dataset import XMLDatasetWithMetaFile
 
-__all__ = [
-    'XMLDatasetWithMetaFile', 'URPCCocoDataset', 'URPCXMLDataset',
-    'RTTSCocoDataset'
-]
+__all__ = ['URPCCocoDataset', 'URPCXMLDataset', 'RTTSCocoDataset']
diff --git a/lqit/detection/models/detectors/edffnet.py b/lqit/detection/models/detectors/edffnet.py
index cc7df81..161b985 100644
--- a/lqit/detection/models/detectors/edffnet.py
+++ b/lqit/detection/models/detectors/edffnet.py
@@ -3,11 +3,11 @@
 from mmdet.registry import MODELS
 from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
 
-from .single_stage_enhance_head import SingleStageWithEnhanceHead
+from .single_stage_enhance_head import SingleStageDetector
 
 
 @MODELS.register_module()
-class EDFFNet(SingleStageWithEnhanceHead):
+class EDFFNet(SingleStageDetector):
 
     def __init__(self,
                  backbone: ConfigType,
diff --git a/lqit/detection/models/necks/__init__.py b/lqit/detection/models/necks/__init__.py
index d463b99..3ee293e 100644
--- a/lqit/detection/models/necks/__init__.py
+++ b/lqit/detection/models/necks/__init__.py
@@ -1,3 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 from .dffpn import DFFPN
 
 __all__ = ['DFFPN']
diff --git a/lqit/edit/datasets/__init__.py b/lqit/edit/datasets/__init__.py
index 156046e..f2878d7 100644
--- a/lqit/edit/datasets/__init__.py
+++ b/lqit/edit/datasets/__init__.py
@@ -1,3 +1,4 @@
 from .basic_image_dataset import BasicImageDataset
+from .cityscape_foggy_dataset import CityscapeFoggyImageDataset
 
-__all__ = ['BasicImageDataset']
+__all__ = ['BasicImageDataset', 'CityscapeFoggyImageDataset']
diff --git a/lqit/edit/datasets/cityscape_foggy_dataset.py b/lqit/edit/datasets/cityscape_foggy_dataset.py
new file mode 100644
index 0000000..1227cbf
--- /dev/null
+++ b/lqit/edit/datasets/cityscape_foggy_dataset.py
@@ -0,0 +1,95 @@
+# Modified from https://github.com/open-mmlab/mmediting/tree/1.x/
+import os.path as osp
+from typing import Callable, List, Optional, Union
+
+from lqit.registry import DATASETS
+from .basic_image_dataset import BasicImageDataset
+
+
+@DATASETS.register_module()
+class CityscapeFoggyImageDataset(BasicImageDataset):
+    """CityscapeFoggyImageDataset for pixel-level vision tasks with aligned
+    gts.
+
+    Args:
+        ann_file (str): Annotation file path. Defaults to ''.
+        metainfo (dict, optional): Meta information for the dataset, such
+            as class information. Defaults to None.
+        data_root (str, optional): The root directory for ``data_prefix``
+            and ``ann_file``. Defaults to None.
+        data_prefix (dict, optional): Prefix for data. Defaults to
+            dict(img='').
+        mapping_table (dict): Mapping table for data. Defaults to dict().
+        pipeline (list, optional): Processing pipeline. Defaults to [].
+        test_mode (bool, optional): ``test_mode=True`` means in test
+            phase. Defaults to False.
+        search_key (str, optional): The key used for searching the folder
+            to get the data_list. Defaults to None.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmengine.fileio.FileClient` for
+            details. Defaults to dict(backend='disk').
+        img_suffix (str or dict[str]): Image suffix that we are interested
+            in. Defaults to 'jpg'.
+        recursive (bool): If set to True, recursively scan the directory.
+            Defaults to False.
+        split_str (str): The string used to split a foggy image name into
+            the name of its clean (gt) counterpart. Defaults to '_foggy'.
+    """
+
+    def __init__(self,
+                 ann_file: str = '',
+                 metainfo: Optional[dict] = None,
+                 data_root: Optional[str] = None,
+                 data_prefix: dict = dict(img=''),
+                 mapping_table: dict = dict(),
+                 pipeline: List[Union[dict, Callable]] = [],
+                 test_mode: bool = False,
+                 search_key: Optional[str] = None,
+                 file_client_args: dict = dict(backend='disk'),
+                 img_suffix: Union[str, dict] = 'jpg',
+                 recursive: bool = False,
+                 split_str: str = '_foggy',
+                 **kwargs) -> None:
+
+        self.split_str = split_str
+
+        super().__init__(
+            ann_file=ann_file,
+            metainfo=metainfo,
+            data_root=data_root,
+            data_prefix=data_prefix,
+            mapping_table=mapping_table,
+            pipeline=pipeline,
+            test_mode=test_mode,
+            search_key=search_key,
+            file_client_args=file_client_args,
+            img_suffix=img_suffix,
+            recursive=recursive,
+            **kwargs)
+
+    def load_data_list(self) -> List[dict]:
+        """Load data list from folder or annotation file.
+
+        Returns:
+            list[dict]: A list of annotations.
+        """
+        img_ids = self._get_img_list()
+
+        data_list = []
+        # build img and gt img paths
+        for img_id in img_ids:
+            data = dict(key=img_id)
+            data['img_id'] = img_id
+            for key in self.data_prefix:
+                # use a per-key name so img_id is not overwritten
+                cur_id = self.mapping_table[key].format(img_id)
+                # The gt img name and the foggy img name do not match:
+                # one gt img corresponds to three foggy imgs.
+                if key == 'gt_img':
+                    cur_id = cur_id.split(self.split_str)[0]
+
+                path = osp.join(self.data_prefix[key],
+                                f'{cur_id}.{self.img_suffix[key]}')
+                data[f'{key}_path'] = path
+            data_list.append(data)
+        return data_list
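The split_str machinery above exists because, in Foggy Cityscapes, each clean frame is rendered at several fog densities, so several foggy inputs share one clean ground truth. A hedged illustration with an assumed Foggy Cityscapes file name:

# hypothetical file names following the Foggy Cityscapes convention
foggy_id = 'aachen_000000_000019_leftImg8bit_foggy_beta_0.02'
gt_id = foggy_id.split('_foggy')[0]
print(gt_id)  # 'aachen_000000_000019_leftImg8bit'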
diff --git a/lqit/edit/models/editors/__init__.py b/lqit/edit/models/editors/__init__.py
index dd58005..aa93705 100644
--- a/lqit/edit/models/editors/__init__.py
+++ b/lqit/edit/models/editors/__init__.py
@@ -1,2 +1,3 @@
+from .aodnet import *  # noqa: F401,F403
 from .unet import *  # noqa: F401,F403
 from .zero_dce import *  # noqa: F401,F403
diff --git a/lqit/edit/models/editors/aodnet/__init__.py b/lqit/edit/models/editors/aodnet/__init__.py
new file mode 100644
index 0000000..b5c2b65
--- /dev/null
+++ b/lqit/edit/models/editors/aodnet/__init__.py
@@ -0,0 +1,4 @@
+from .aodnet import AODNet
+from .aodnet_generator import AODNetGenerator
+
+__all__ = ['AODNet', 'AODNetGenerator']
diff --git a/lqit/edit/models/editors/aodnet/aodnet.py b/lqit/edit/models/editors/aodnet/aodnet.py
new file mode 100644
index 0000000..dc42378
--- /dev/null
+++ b/lqit/edit/models/editors/aodnet/aodnet.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from lqit.registry import MODELS
+
+
+@MODELS.register_module()
+class AODNet(nn.Module):
+    """AOD-Net: All-in-One Dehazing Network.
+
+    https://ieeexplore.ieee.org/document/8237773
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)
+        self.conv2 = nn.Conv2d(
+            in_channels=3, out_channels=3, kernel_size=3, padding=1)
+        self.conv3 = nn.Conv2d(
+            in_channels=6, out_channels=3, kernel_size=5, padding=2)
+        self.conv4 = nn.Conv2d(
+            in_channels=6, out_channels=3, kernel_size=7, padding=3)
+        self.conv5 = nn.Conv2d(
+            in_channels=12, out_channels=3, kernel_size=3, padding=1)
+        self.b = 1
+
+    def forward(self, x):
+        x1 = F.relu(self.conv1(x))
+        x2 = F.relu(self.conv2(x1))
+        cat1 = torch.cat((x1, x2), 1)
+        x3 = F.relu(self.conv3(cat1))
+        cat2 = torch.cat((x2, x3), 1)
+        x4 = F.relu(self.conv4(cat2))
+        cat3 = torch.cat((x1, x2, x3, x4), 1)
+        k = F.relu(self.conv5(cat3))
+
+        assert k.size() == x.size(), \
+            'k and the hazy image should have the same size'
+
+        output = k * x - k + self.b
+        return F.relu(output)
diff --git a/lqit/edit/models/editors/aodnet/aodnet_generator.py b/lqit/edit/models/editors/aodnet/aodnet_generator.py
new file mode 100644
index 0000000..b106758
--- /dev/null
+++ b/lqit/edit/models/editors/aodnet/aodnet_generator.py
@@ -0,0 +1,28 @@
+from typing import List
+
+from lqit.edit.models.base_models import BaseGenerator
+from lqit.edit.structures import BatchPixelData
+from lqit.registry import MODELS
+from lqit.utils.typing import ConfigType, OptMultiConfig
+
+
+@MODELS.register_module()
+class AODNetGenerator(BaseGenerator):
+
+    def __init__(self,
+                 model: ConfigType,
+                 pixel_loss: ConfigType = dict(
+                     type='MSELoss', loss_weight=1.0),
+                 init_cfg: OptMultiConfig = None,
+                 **kwargs) -> None:
+        super().__init__(model=model, pixel_loss=pixel_loss, init_cfg=init_cfg)
+
+    def loss(self, loss_input: BatchPixelData, batch_img_metas: List[dict]):
+        """Calculate the loss based on the outputs of the generator."""
+        batch_outputs = loss_input.output
+        batch_gt = loss_input.gt
+
+        pixel_loss = self.pixel_loss(batch_outputs, batch_gt)
+
+        losses = dict(pixel_loss=pixel_loss)
+        return losses
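AOD-Net collapses the atmospheric scattering model into a single learned module K(x), so the dehazed image is recovered as J(x) = K(x) * I(x) - K(x) + b, which is exactly what the forward pass above computes. Because the network is fully convolutional, output and input shapes match; a quick smoke test:

import torch
net = AODNet()
hazy = torch.rand(1, 3, 256, 256)
clean = net(hazy)
assert clean.shape == hazy.shape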
0000000..866a5b9
--- /dev/null
+++ b/lqit/edit/models/enhance_heads/edge_head.py
@@ -0,0 +1,76 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from lqit.registry import MODELS
+from .base_enhance_head import BaseEnhanceHead
+
+
+@MODELS.register_module()
+class EdgeHead(BaseEnhanceHead):
+    """Edge head: ``num_convs - 1`` stacked [conv+GN+ReLU] blocks followed
+    by a 1x1 conv that predicts a single-channel edge map."""
+
+    def __init__(self,
+                 in_channels=256,
+                 feat_channels=256,
+                 num_convs=5,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+                 act_cfg=dict(type='ReLU'),
+                 gt_preprocessor=None,
+                 loss_enhance=dict(type='mmdet.L1Loss', loss_weight=1.0),
+                 init_cfg=dict(type='Normal', layer='Conv2d', std=0.01)):
+        super().__init__(
+            loss_enhance=loss_enhance,
+            gt_preprocessor=gt_preprocessor,
+            init_cfg=init_cfg)
+        self.in_channels = in_channels
+        self.feat_channels = feat_channels
+        self.num_convs = num_convs
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self._init_layers()
+
+    def _init_layers(self):
+        assert self.num_convs > 0
+        enhance_conv = []
+        for i in range(self.num_convs):
+            in_channels = self.in_channels if i == 0 \
+                else self.feat_channels
+            if i < (self.num_convs - 1):
+                enhance_conv.append(
+                    ConvModule(
+                        in_channels,
+                        self.feat_channels,
+                        3,
+                        stride=1,
+                        padding=1,
+                        conv_cfg=self.conv_cfg,
+                        norm_cfg=self.norm_cfg,
+                        act_cfg=self.act_cfg))
+            else:
+                # 1x1 output conv: padding=0 keeps the spatial size intact
+                enhance_conv.append(
+                    nn.Conv2d(
+                        in_channels=in_channels,
+                        out_channels=1,
+                        kernel_size=1,
+                        stride=1,
+                        padding=0))
+        self.enhance_conv = nn.Sequential(*enhance_conv)
+
+    def forward(self, x):
+        # multi-level features (e.g. FPN outputs): use the first level
+        if isinstance(x, (tuple, list)):
+            x = x[0]
+        outs = self.enhance_conv(x)
+        return outs
+
+    def loss_by_feat(self, enhance_img, gt_imgs, img_metas):
+        reshape_gt_imgs = F.interpolate(
+            gt_imgs, size=enhance_img.shape[-2:], mode='bilinear')
+        enhance_loss = self.loss_enhance(enhance_img, reshape_gt_imgs)
+        return dict(loss_enhance=enhance_loss)
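With the defaults, the head turns a 256-channel feature map into a single-channel edge prediction at the same resolution. A hedged smoke test, assuming BaseEnhanceHead accepts gt_preprocessor=None as in the signature above:

import torch
head = EdgeHead(in_channels=256, feat_channels=256, num_convs=5)
feats = [torch.rand(2, 256, 64, 64), torch.rand(2, 256, 32, 32)]  # FPN levels
edge = head(feats)  # forward() picks the highest-resolution level
assert edge.shape == (2, 1, 64, 64)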