Skip to content

Commit 2a9cd17

Browse files
committed
* 2022年4月23日
- 为了便于使用和配置,对大量代码进行了调整和修改,与之前版本相比,使用上也存在部分差异。 - 评估部分: - 支持多个方法的json文件同时使用评估。 - 更新了指标统计类,便于更灵活的指定不同的指标。 - 对一些医学二值分割的指标提供了支持。 - 绘图部分: - 支持多个曲线npy文件同时用于绘图。 - 将个性化配置尽可能独立出来,提供了独立的绘图配置文件。 - 重构了绘图类,便于使用yaml文件对matplotlib的默认设定进行覆盖。
1 parent 897612b commit 2a9cd17

14 files changed

+756
-154
lines changed

.gitignore

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
# Big files
22
**/*.png
3+
**/*.pdf
4+
**/*.jpg
5+
**/*.bmp
36
**/*.zip
47
**/*.7z
58
**/*.rar
9+
**/*.tar*
610

711
# Byte-compiled / optimized / DLL files
812
__pycache__/
@@ -274,7 +278,7 @@ gen
274278
/output/
275279
/untracked/
276280
/configs/
277-
/*.py
278-
/*.sh
281+
# /*.py
282+
# /*.sh
279283
/results/rgb_sod.md
280284
/results/htmls/*.html

eval.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# -*- coding: utf-8 -*-
2+
import argparse
3+
import os
4+
import textwrap
5+
import warnings
6+
7+
from metrics import cal_sod_matrics
8+
from utils.generate_info import get_datasets_info, get_methods_info
9+
from utils.misc import make_dir
10+
from utils.recorders import METRIC_MAPPING
11+
12+
13+
def get_args():
14+
parser = argparse.ArgumentParser(
15+
description=textwrap.dedent(
16+
r"""
17+
INCLUDE:
18+
19+
- F-measure-Threshold Curve
20+
- Precision-Recall Curve
21+
- MAE
22+
- weighted F-measure
23+
- S-measure
24+
- max/average/adaptive F-measure
25+
- max/average/adaptive E-measure
26+
- max/average Precision
27+
- max/average Sensitivity
28+
- max/average Specificity
29+
- max/average F-measure
30+
- max/average Dice
31+
- max/average IoU
32+
33+
NOTE:
34+
35+
- Our method automatically calculates the intersection of `pre` and `gt`.
36+
- Currently supported pre naming rules: `prefix + gt_name_wo_ext + suffix_w_ext`
37+
38+
EXAMPLES:
39+
40+
python eval_all.py \
41+
--dataset-json configs/datasets/json/rgbd_sod.json \
42+
--method-json configs/methods/json/rgbd_other_methods.json configs/methods/json/rgbd_our_method.json --metric-npy output/rgbd_metrics.npy \
43+
--curves-npy output/rgbd_curves.npy \
44+
--record-tex output/rgbd_results.txt
45+
"""
46+
),
47+
formatter_class=argparse.RawTextHelpFormatter,
48+
)
49+
parser.add_argument("--dataset-json", required=True, type=str, help="Json file for datasets.")
50+
parser.add_argument(
51+
"--method-json", required=True, nargs="+", type=str, help="Json file for methods."
52+
)
53+
parser.add_argument("--metric-npy", type=str, help="Npy file for saving metric results.")
54+
parser.add_argument("--curves-npy", type=str, help="Npy file for saving curve results.")
55+
parser.add_argument("--record-txt", type=str, help="Txt file for saving metric results.")
56+
parser.add_argument("--to-overwrite", action="store_true", help="To overwrite the txt file.")
57+
parser.add_argument("--record-xlsx", type=str, help="Xlsx file for saving metric results.")
58+
parser.add_argument(
59+
"--include-methods",
60+
type=str,
61+
nargs="+",
62+
help="Names of only specific methods you want to evaluate.",
63+
)
64+
parser.add_argument(
65+
"--exclude-methods",
66+
type=str,
67+
nargs="+",
68+
help="Names of some specific methods you do not want to evaluate.",
69+
)
70+
parser.add_argument(
71+
"--include-datasets",
72+
type=str,
73+
nargs="+",
74+
help="Names of only specific datasets you want to evaluate.",
75+
)
76+
parser.add_argument(
77+
"--exclude-datasets",
78+
type=str,
79+
nargs="+",
80+
help="Names of some specific datasets you do not want to evaluate.",
81+
)
82+
parser.add_argument(
83+
"--num-workers",
84+
type=int,
85+
default=4,
86+
help="Number of workers for multi-threading or multi-processing. Default: 4",
87+
)
88+
parser.add_argument(
89+
"--num-bits",
90+
type=int,
91+
default=3,
92+
help="Number of decimal places for showing results. Default: 3",
93+
)
94+
parser.add_argument(
95+
"--metric-names",
96+
type=str,
97+
nargs="+",
98+
default=["mae", "fm", "em", "sm", "wfm"],
99+
choices=METRIC_MAPPING.keys(),
100+
help="Names of metrics",
101+
)
102+
args = parser.parse_args()
103+
104+
if args.metric_npy is not None:
105+
make_dir(os.path.dirname(args.metric_npy))
106+
if args.curves_npy is not None:
107+
make_dir(os.path.dirname(args.curves_npy))
108+
if args.record_txt is not None:
109+
make_dir(os.path.dirname(args.record_txt))
110+
if args.record_xlsx is not None:
111+
make_dir(os.path.dirname(args.record_xlsx))
112+
if args.to_overwrite and not args.record_txt:
113+
warnings.warn("--to-overwrite only works with a valid --record-txt")
114+
return args
115+
116+
117+
def main():
118+
args = get_args()
119+
120+
# 包含所有数据集信息的字典
121+
datasets_info = get_datasets_info(
122+
datastes_info_json=args.dataset_json,
123+
include_datasets=args.include_datasets,
124+
exclude_datasets=args.exclude_datasets,
125+
)
126+
# 包含所有待比较模型结果的信息的字典
127+
methods_info = get_methods_info(
128+
methods_info_jsons=args.method_json,
129+
for_drawing=True,
130+
include_methods=args.include_methods,
131+
exclude_methods=args.exclude_methods,
132+
)
133+
134+
# 确保多进程在windows上也可以正常使用
135+
cal_sod_matrics.cal_sod_matrics(
136+
sheet_name="Results",
137+
to_append=not args.to_overwrite,
138+
txt_path=args.record_txt,
139+
xlsx_path=args.record_xlsx,
140+
methods_info=methods_info,
141+
datasets_info=datasets_info,
142+
curves_npy_path=args.curves_npy,
143+
metrics_npy_path=args.metric_npy,
144+
num_bits=args.num_bits,
145+
num_workers=args.num_workers,
146+
use_mp=False,
147+
metric_names=args.metric_names,
148+
ncols_tqdm=119,
149+
)
150+
151+
152+
if __name__ == "__main__":
153+
main()

metrics/cal_sod_matrics.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from utils.misc import get_gt_pre_with_name, get_name_list, make_dir
1212
from utils.print_formatter import formatter_for_tabulate
13-
from utils.recorders import MetricExcelRecorder, MetricRecorder, TxtRecorder
13+
from utils.recorders import MetricExcelRecorder, MetricRecorder_V2, TxtRecorder
1414

1515

1616
class Recorder:
@@ -80,6 +80,8 @@ def cal_sod_matrics(
8080
num_bits: int = 3,
8181
num_workers: int = 2,
8282
use_mp: bool = False,
83+
metric_names: tuple = ("mae", "fm", "em", "sm", "wfm"),
84+
ncols_tqdm: int = 79,
8385
):
8486
"""
8587
Save the results of all models on different datasets in a `npy` file in the form of a
@@ -112,6 +114,8 @@ def cal_sod_matrics(
112114
:param num_bits: the number of bits used to format results
113115
:param num_workers: the number of workers of multiprocessing or multithreading
114116
:param use_mp: using multiprocessing or multithreading
117+
:param metric_names: names of metrics
118+
:param ncols_tqdm: number of columns for tqdm
115119
"""
116120
recorder = Recorder(
117121
txt_path=txt_path,
@@ -181,6 +185,8 @@ def cal_sod_matrics(
181185
desc=f"[{dataset_name}({len(gt_name_list)}):{method_name}({len(pre_name_list)})]",
182186
proc_idx=procs_idx,
183187
blocking=use_mp,
188+
metric_names=metric_names,
189+
ncols_tqdm=ncols_tqdm,
184190
),
185191
callback=partial(recorder.record, method_name=method_name),
186192
)
@@ -211,16 +217,18 @@ def evaluate_data(
211217
desc="",
212218
proc_idx=None,
213219
blocking=True,
220+
metric_names=None,
221+
ncols_tqdm=79,
214222
):
215-
metric_recoder = MetricRecorder()
223+
metric_recoder = MetricRecorder_V2(metric_names=metric_names)
216224
# https://github.com/tqdm/tqdm#parameters
217225
# https://github.com/tqdm/tqdm/blob/master/examples/parallel_bars.py
218226
tqdm_bar = tqdm(
219227
names,
220228
total=len(names),
221229
desc=desc,
222230
position=proc_idx,
223-
ncols=79,
231+
ncols=ncols_tqdm,
224232
lock_args=None if blocking else (False,),
225233
)
226234
for name in tqdm_bar:

0 commit comments

Comments
 (0)