Skip to content
6 changes: 5 additions & 1 deletion configs/detection/edffnet/edffnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
src_key='img',
dst_key='gt_edge',
transforms=[
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
dict(type='Resize', scale=(256, 256), keep_ratio=True),
dict(type='RandomFlip', prob=0.5)
]),
dict(type='lqit.PackInputs', )
Expand All @@ -53,3 +53,7 @@
milestones=[8, 11],
gamma=0.1)
]

optim_wrapper = dict(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete

type='OptimWrapper',
optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001))
75 changes: 75 additions & 0 deletions configs/edit/_base_/datasets/cityscape_enhancement_with_anno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# dataset settings
# Enhancement dataset paired through COCO-style annotation files: foggy
# Cityscapes images ('img') with the corresponding clear images loaded as
# ground truth ('gt_img') via the DatasetWithClearImageWrapper below.
dataset_type = 'mmdet.CityscapesDataset'
data_root = 'data/Datasets/'

# Backend used by file clients when reading images.
file_client_args = dict(backend='disk')

train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    # Loads the paired clear image referenced by 'gt_img_path'.
    dict(type='LoadGTImageFromFile', file_client_args=file_client_args),
    dict(
        # Applies identical geometric transforms to 'img' and 'gt_img' so the
        # hazy/clear pair stays spatially aligned.
        type='TransBroadcaster',
        src_key='img',
        dst_key='gt_img',
        transforms=[
            dict(type='Resize', scale=(512, 512), keep_ratio=True),
            dict(type='RandomFlip', prob=0.5),
        ]),
    dict(type='PackInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='LoadGTImageFromFile', file_client_args=file_client_args),
    dict(
        type='TransBroadcaster',
        src_key='img',
        dst_key='gt_img',
        transforms=[dict(type='Resize', scale=(512, 512), keep_ratio=True)]),
    dict(
        type='PackInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        # Wrapper adds a 'gt_img_path' (clear image) to every data info.
        type='lqit.DatasetWithClearImageWrapper',
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            ann_file='cityscape_foggy/annotations_json/instancesonly_filtered_gtFine_train.json',
            data_prefix=dict(img='cityscape_foggy/train/', gt_img_path='cityscape/train/'),
            filter_cfg=dict(filter_empty_gt=True, min_size=32),
            pipeline=train_pipeline),
        suffix='png'
    ))

val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='lqit.DatasetWithClearImageWrapper',
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            test_mode=True,
            # NOTE(review): validation is limited to the first 100 samples;
            # confirm this subset is intentional and not leftover debugging.
            indices=100,
            ann_file='cityscape_foggy/annotations_json/instancesonly_filtered_gtFine_test.json',
            data_prefix=dict(img='cityscape_foggy/test/', gt_img_path='cityscape/test/'),
            filter_cfg=dict(filter_empty_gt=True, min_size=32),
            pipeline=test_pipeline),
        suffix='png'
    ))

test_dataloader = val_dataloader

# Mean squared error between the enhanced prediction and the input pair.
val_evaluator = [
    dict(type='MSE', gt_key='img', pred_key='pred_img'),
]
test_evaluator = val_evaluator
76 changes: 76 additions & 0 deletions configs/edit/_base_/datasets/cityscape_enhancement_with_txt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# dataset settings
# Enhancement dataset paired through plain txt split files: foggy Cityscapes
# images ('img') matched to clear images ('gt_img') by stripping the
# '_foggy' part of the file name (see `split_str` below).
dataset_type = 'CityscapeFoggyImageDataset'
data_root = 'data/Datasets/'

# Backend used by file clients when reading images.
file_client_args = dict(backend='disk')

train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    # Loads the paired clear image referenced by 'gt_img_path'.
    dict(type='LoadGTImageFromFile', file_client_args=file_client_args),
    dict(
        # Applies identical geometric transforms to 'img' and 'gt_img' so the
        # hazy/clear pair stays spatially aligned.
        type='TransBroadcaster',
        src_key='img',
        dst_key='gt_img',
        transforms=[
            dict(type='Resize', scale=(512, 512), keep_ratio=True),
            dict(type='RandomFlip', prob=0.5),
        ]),
    dict(type='PackInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='LoadGTImageFromFile', file_client_args=file_client_args),
    dict(
        type='TransBroadcaster',
        src_key='img',
        dst_key='gt_img',
        transforms=[dict(type='Resize', scale=(512, 512), keep_ratio=True)]),
    dict(
        type='PackInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        metainfo=dict(
            dataset_type='cityscape_enhancement', task_name='enhancement'),
        ann_file='cityscape_foggy/train/train.txt',
        data_prefix=dict(img='cityscape_foggy/train/', gt_img='cityscape/train/'),
        # Image names are listed in the txt file under the 'img' key.
        search_key='img',
        img_suffix=dict(img='png', gt_img='png'),
        file_client_args=file_client_args,
        pipeline=train_pipeline,
        # Part of the file name that separates a foggy image from its
        # clear counterpart.
        split_str='_foggy'
    ))
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        metainfo=dict(
            dataset_type='cityscape_enhancement', task_name='enhancement'),
        ann_file='cityscape_foggy/test/test.txt',
        data_prefix=dict(img='cityscape_foggy/test/', gt_img='cityscape/test/'),
        search_key='img',
        img_suffix=dict(img='png', gt_img='png'),
        file_client_args=file_client_args,
        pipeline=test_pipeline,
        split_str='_foggy'
    ))
test_dataloader = val_dataloader

# Mean squared error between the enhanced prediction and the input pair.
val_evaluator = [
    dict(type='MSE', gt_key='img', pred_key='pred_img'),
]
test_evaluator = val_evaluator
39 changes: 39 additions & 0 deletions configs/edit/aodnet/aodnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# AOD-Net image dehazing config, trained on the foggy-Cityscapes
# enhancement dataset defined in the base config.
_base_ = [
    '../_base_/datasets/cityscape_enhancement_with_txt.py',
    '../_base_/schedules/schedule_1x.py',
    '../_base_/default_runtime.py'
]
model = dict(
    type='lqit.BaseEditModel',
    data_preprocessor=dict(
        type='lqit.EditDataPreprocessor',
        # Scale pixel values to [0, 1] and convert BGR -> RGB.
        mean=[0.0, 0.0, 0.0],
        std=[255.0, 255.0, 255.0],
        bgr_to_rgb=True,
        pad_size_divisor=32,
        # Supervise the generator with the image stored under 'img'.
        gt_name='img'),
    generator=dict(
        _scope_='lqit',
        type='AODNetGenerator',
        model=dict(type='AODNet'),
        pixel_loss=dict(type='MSELoss', loss_weight=1.0)
    ))


# Train for 10 epochs, validating after every epoch
# (overrides the 1x schedule from the base config).
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=10, val_interval=1)
param_scheduler = [
    # Linear warmup over the first 1000 iterations.
    dict(
        type='LinearLR', start_factor=0.0001, by_epoch=False, begin=0,
        end=1000),
    # Step the learning rate down by 10x at epochs 6 and 9.
    dict(
        type='MultiStepLR',
        begin=0,
        end=10,
        by_epoch=True,
        milestones=[6, 9],
        gamma=0.1)
]

optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001))
4 changes: 2 additions & 2 deletions lqit/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .data_preprocessor import * # noqa: F401,F403
from .dataset_wrappers import DatasetWithGTImageWrapper
from .dataset_wrappers import DatasetWithGTImageWrapper, DatasetWithClearImageWrapper
from .structures import * # noqa: F401,F403
from .transforms import * # noqa: F401,F403

__all__ = ['DatasetWithGTImageWrapper']
__all__ = ['DatasetWithGTImageWrapper', 'DatasetWithClearImageWrapper']
114 changes: 114 additions & 0 deletions lqit/common/dataset_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,117 @@ def parse_gt_img_info(self, data_info: dict) -> Union[dict, List[dict]]:
f'.{self.suffix}'
data_info['gt_img_path'] = osp.join(gt_img_root, img_name)
return data_info



@DATASETS.register_module()
class DatasetWithClearImageWrapper:
    """Dataset wrapper for the image dehazing task.

    Attaches the path of the paired clear (haze-free) image to every data
    info produced by the wrapped dataset under the key ``gt_img_path``.

    Args:
        dataset (BaseDataset or dict): The dataset to wrap, or a config
            dict used to build it.
        suffix (str): File suffix of the clear ground-truth image.
            Defaults to 'jpg'.
        lazy_init (bool, optional): Whether to defer loading annotations
            until ``full_init`` is called explicitly. Defaults to False.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 suffix: str = 'jpg',
                 lazy_init: bool = False) -> None:
        self.suffix = suffix
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'elements in datasets sequence should be config or '
                f'`BaseDataset` instance, but got {type(dataset)}')
        self._metainfo = self.dataset.metainfo

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the wrapped dataset.

        Returns:
            dict: The meta information of the wrapped dataset.
        """
        return self._metainfo

    def full_init(self) -> None:
        """Fully initialize the wrapped dataset. Idempotent."""
        if self._fully_initialized:
            return
        self.dataset.full_init()
        # BUGFIX: the flag set in `__init__` was never updated, leaving the
        # wrapper permanently marked as uninitialized.
        self._fully_initialized = True

    def get_data_info(self, idx: int) -> dict:
        """Get annotation info of one sample from the wrapped dataset."""
        return self.dataset.get_data_info(idx)

    def prepare_data(self, idx) -> Any:
        """Get data processed by ``self.pipeline``.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.
        """
        data_info = self.get_data_info(idx)
        # Inject the clear-image path before the pipeline loads images.
        data_info = self.parse_gt_img_info(data_info)
        return self.dataset.pipeline(data_info)

    def __getitem__(self, idx):
        if not self.dataset._fully_initialized:
            warnings.warn(
                'Please call `full_init()` method manually to accelerate '
                'the speed.')
            self.dataset.full_init()

        if self.dataset.test_mode:
            data = self.prepare_data(idx)
            if data is None:
                # BUGFIX: message previously misspelled "pipline".
                raise Exception('Test time pipeline should not get `None` '
                                'data_sample')
            return data

        for _ in range(self.dataset.max_refetch + 1):
            data = self.prepare_data(idx)
            # Broken images or random augmentations may cause the returned
            # data to be None.
            if data is None:
                idx = self.dataset._rand_another()
                continue
            return data

        # BUGFIX: was `self.max_refetch`, an attribute that does not exist
        # on the wrapper and would raise AttributeError instead of this
        # message; the limit lives on the wrapped dataset.
        raise Exception(
            f'Cannot find valid image after {self.dataset.max_refetch}! '
            'Please check your image path and pipeline')

    def __len__(self):
        return len(self.dataset)

    def parse_gt_img_info(self, data_info: dict) -> Union[dict, List[dict]]:
        """Add the paired clear-image path to a data info dict.

        Args:
            data_info (dict): Data information of one sample; must contain
                ``img_path``.

        Returns:
            Union[dict, List[dict]]: ``data_info`` with ``gt_img_path`` set.
        """
        gt_img_root = self.dataset.data_prefix.get('gt_img_path', None)

        if gt_img_root is None:
            warnings.warn(
                'Cannot get gt_img_root, please set `gt_img_path` in '
                '`dataset.data_prefix`')
            # Best-effort fallback: use the hazy image itself as GT.
            data_info['gt_img_path'] = data_info['img_path']
        else:
            img_path = data_info['img_path']
            # Foggy files are `<city>/<name>_foggy_<...>.<ext>`; the clear
            # counterpart is `<gt_img_root>/<city>/<name>.<suffix>`.
            city_dir = osp.basename(osp.dirname(img_path))
            stem = osp.basename(img_path).split('_foggy_')[0]
            data_info['gt_img_path'] = osp.join(
                gt_img_root, city_dir, f'{stem}.{self.suffix}')
        return data_info
3 changes: 2 additions & 1 deletion lqit/edit/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .basic_image_dataset import BasicImageDataset
from .cityscape_foggy_dataset import CityscapeFoggyImageDataset

__all__ = ['BasicImageDataset']
__all__ = ['BasicImageDataset', 'CityscapeFoggyImageDataset']
Loading