Skip to content

Commit 441f49c

Browse files
authored
Merge pull request #53 from ai-forever/release-1.0.0
release 1.0.0
2 parents 1822b22 + 0bedc88 commit 441f49c

24 files changed

+1293
-438
lines changed

DPF/filters/images/aesthetic_improved_filter.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,22 @@ def get_improved_aesthetic_model(cache_folder: str) -> MLP:
7373

7474
class ImprovedAestheticFilter(ImageFilter):
7575
"""
76-
ImprovedAestheticFilter class
76+
Filter for improved aesthetic score calculating with LAION model. This repository is used:
77+
https://github.com/christophschuhmann/improved-aesthetic-predictor
78+
79+
Parameters
80+
----------
81+
weights_folder: str
82+
Path to the folder where the weights are located.
83+
If there are no weights, they will be downloaded automatically
84+
device: str = "cuda:0"
85+
Device to use
86+
workers: int = 16
87+
Number of processes to use for reading data and calculating flow scores
88+
batch_size: int = 64
89+
Batch size for model
90+
pbar: bool = True
91+
Whether to use a progress bar
7792
"""
7893

7994
def __init__(

DPF/filters/images/hash_filters.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,16 @@ def get_phash(pil_img: Image.Image, hash_size: int = 8, highfreq_factor: int = 4
3030

3131
class PHashFilter(ImageFilter):
3232
"""
33-
PHashFilter class
33+
Filter for calculating PHash (perceptual hash) for images
34+
35+
Parameters
36+
----------
37+
sim_hash_size: int = 8
38+
Hash size for PHash
39+
workers: int = 16
40+
Number of processes to use for reading data and calculating flow scores
41+
pbar: bool = True
42+
Whether to use a progress bar
3443
"""
3544

3645
def __init__(

DPF/filters/images/info_filter.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,14 @@ def get_image_info(img_bytes: bytes, data: dict[str, Any], key_column: str) -> I
5050

5151
class ImageInfoFilter(ImageFilter):
5252
"""
53-
ImageInfoFilter class
53+
Filter for gathering basic info about images (width, height, number of channels)
54+
55+
Parameters
56+
----------
57+
workers: int = 16
58+
Number of parallel dataloader workers
59+
pbar: bool = True
60+
Whether to show progress bar
5461
"""
5562

5663
def __init__(self, workers: int = 16, pbar: bool = True, _pbar_position: int = 0):

DPF/filters/images/nsfw_filter.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626

2727

2828
def load_safety_model(clip_model: str, cache_folder: str, device: Union[str, torch.device]) -> Any:
29-
"""load the safety model"""
30-
3129
gpus = tf.config.list_physical_devices("GPU")
3230
if gpus:
3331
try:
@@ -38,7 +36,7 @@ def load_safety_model(clip_model: str, cache_folder: str, device: Union[str, tor
3836
pass
3937

4038
if clip_model == "ViT-L/14":
41-
model_dir = cache_folder + "/clip_autokeras_binary_nsfw"
39+
model_dir = os.path.join(cache_folder, "clip_autokeras_binary_nsfw")
4240
url_model = (
4341
"https://raw.githubusercontent.com/LAION-AI/"
4442
"CLIP-based-NSFW-Detector/main/clip_autokeras_binary_nsfw.zip"
@@ -48,12 +46,13 @@ def load_safety_model(clip_model: str, cache_folder: str, device: Union[str, tor
4846

4947
if not os.path.exists(model_dir):
5048
os.makedirs(cache_folder, exist_ok=True)
51-
path_to_zip_file = cache_folder + "/clip_autokeras_binary_nsfw.zip"
49+
path_to_zip_file = os.path.join(cache_folder, "clip_autokeras_binary_nsfw.zip")
5250
urlretrieve(url_model, path_to_zip_file)
5351
with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
5452
zip_ref.extractall(cache_folder)
5553

5654
with tf.device(device):
55+
print(model_dir)
5756
loaded_model = load_model(model_dir, custom_objects=ak.CUSTOM_OBJECTS)
5857

5958
return loaded_model
@@ -72,7 +71,6 @@ class NSFWFilter(ImageFilter):
7271

7372
def __init__(
7473
self,
75-
clip_model: str,
7674
weights_folder: str,
7775
workers: int = 16,
7876
batch_size: int = 64,
@@ -81,7 +79,7 @@ def __init__(
8179
_pbar_position: int = 0
8280
):
8381
super().__init__(pbar, _pbar_position)
84-
82+
clip_model = "ViT-L/14"
8583
self.num_workers = workers
8684
self.batch_size = batch_size
8785
self.device = device

DPF/filters/images/ocr_filter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
self,
2424
weights_path: str,
2525
model_name: Optional[str] = None,
26+
text_box_col: str = "text_boxes",
2627
device: str = "cuda:0",
2728
workers: int = 16,
2829
pad: int = 5,
@@ -73,7 +74,7 @@ def __init__(
7374

7475
self.AlignCollate = AlignCollate(imgH=self.opt.imgH, imgW=self.opt.imgW, keep_ratio_with_pad=self.opt.PAD)
7576
#
76-
self.text_box_col = "text_boxes"
77+
self.text_box_col = text_box_col
7778
self.ocr_col = f"OCR_{self.model_name}"
7879

7980
@property

DPF/filters/images/ocr_model/model.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,3 @@
1-
"""
2-
Copyright (c) 2019-present NAVER Corp.
3-
4-
Licensed under the Apache License, Version 2.0 (the "License");
5-
you may not use this file except in compliance with the License.
6-
You may obtain a copy of the License at
7-
8-
http://www.apache.org/licenses/LICENSE-2.0
9-
10-
Unless required by applicable law or agreed to in writing, software
11-
distributed under the License is distributed on an "AS IS" BASIS,
12-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
See the License for the specific language governing permissions and
14-
limitations under the License.
15-
"""
16-
171
import torch.nn as nn
182

193
from .modules.feature_extraction import (
@@ -34,14 +18,12 @@ def __init__(self, opt):
3418
self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
3519
'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}
3620

37-
""" Transformation """
3821
if opt.Transformation == 'TPS':
3922
self.Transformation = TPS_SpatialTransformerNetwork(
4023
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
4124
else:
4225
print('No Transformation module specified')
4326

44-
""" FeatureExtraction """
4527
if opt.FeatureExtraction == 'VGG':
4628
self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
4729
elif opt.FeatureExtraction == 'RCNN':
@@ -53,7 +35,6 @@ def __init__(self, opt):
5335
self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512
5436
self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
5537

56-
""" Sequence modeling"""
5738
if opt.SequenceModeling == 'BiLSTM':
5839
self.SequenceModeling = nn.Sequential(
5940
BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
@@ -63,7 +44,6 @@ def __init__(self, opt):
6344
print('No SequenceModeling module specified')
6445
self.SequenceModeling_output = self.FeatureExtraction_output
6546

66-
""" Prediction """
6747
if opt.Prediction == 'CTC':
6848
self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
6949
elif opt.Prediction == 'Attn':
@@ -72,22 +52,18 @@ def __init__(self, opt):
7252
raise Exception('Prediction is neither CTC or Attn')
7353

7454
def forward(self, input, text, is_train=True):
75-
""" Transformation stage """
7655
if not self.stages['Trans'] == "None":
7756
input = self.Transformation(input)
7857

79-
""" Feature extraction stage """
8058
visual_feature = self.FeatureExtraction(input)
8159
visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]
8260
visual_feature = visual_feature.squeeze(3)
8361

84-
""" Sequence modeling stage """
8562
if self.stages['Seq'] == 'BiLSTM':
8663
contextual_feature = self.SequenceModeling(visual_feature)
8764
else:
8865
contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM
8966

90-
""" Prediction stage """
9167
if self.stages['Pred'] == 'CTC':
9268
prediction = self.Prediction(contextual_feature.contiguous())
9369
else:

DPF/filters/texts/google_translate_filter.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,26 @@ def translate_batch(translator: BaseTranslator, batch: list[str], delimiter: str
4242

4343
class GoogleTranslateFilter(ColumnFilter):
4444
"""
45-
GoogleTranslateFilter class
45+
Filter for translating texts with google translate api
46+
47+
Parameters
48+
----------
49+
text_column_name: str = "text"
50+
Name of column with texts
51+
source_lang: str = "auto"
52+
Source language to translate from
53+
target_lang: str = "en"
54+
Language to translate to
55+
max_symbols_in_batch: int = 3000
56+
Maximum symbols in one request to API.
57+
timeout: float = 1
58+
Timeout between requests
59+
timeout_on_error: float = 3
60+
Timeout between requests if error occured
61+
num_retries_per_batch: int = 1
62+
Number of retries of errors occured
63+
pbar: bool = True
64+
Whether to use a progress bar
4665
"""
4766

4867
def __init__(

DPF/filters/texts/lang_filter.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,16 @@
77

88
class LangFilter(ColumnFilter):
99
"""
10-
LangFilter class
10+
Filter for text language detection
11+
12+
Parameters
13+
----------
14+
text_column_name: str = "text"
15+
Name of column with texts
16+
workers: int = 16
17+
Number of processes to use
18+
pbar: bool = True
19+
Whether to use a progress bar
1120
"""
1221

1322
def __init__(

0 commit comments

Comments
 (0)