Skip to content

Commit 0368fa2

Browse files
committed
compute FID only on selected samples, following the exp settings in Aug-PE
1 parent 13b8782 commit 0368fa2

File tree

6 files changed

+24
-6
lines changed

6 files changed

+24
-6
lines changed

example/text/openreview_huggingface/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pe.callback import SaveTextToCSV
2626
from pe.logger import CSVPrint
2727
from pe.logger import LogPrint
28+
from pe.constant.data import VARIATION_API_FOLD_ID_COLUMN_NAME
2829

2930
import pandas as pd
3031
import os
@@ -57,7 +58,9 @@
5758
)
5859

5960
save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, "checkpoint"))
60-
compute_fid = ComputeFID(priv_data=data, embedding=embedding)
61+
compute_fid = ComputeFID(
62+
priv_data=data, embedding=embedding, filter_criterion={VARIATION_API_FOLD_ID_COLUMN_NAME: -1}
63+
)
6164
save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text"))
6265

6366
csv_print = CSVPrint(output_folder=exp_folder)

example/text/openreview_openai/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from pe.callback import SaveTextToCSV
4949
from pe.logger import CSVPrint
5050
from pe.logger import LogPrint
51+
from pe.constant.data import VARIATION_API_FOLD_ID_COLUMN_NAME
5152

5253
import pandas as pd
5354
import os
@@ -87,7 +88,9 @@
8788
)
8889

8990
save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, "checkpoint"))
90-
compute_fid = ComputeFID(priv_data=data, embedding=embedding)
91+
compute_fid = ComputeFID(
92+
priv_data=data, embedding=embedding, filter_criterion={VARIATION_API_FOLD_ID_COLUMN_NAME: -1}
93+
)
9194
save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text"))
9295

9396
csv_print = CSVPrint(output_folder=exp_folder)

example/text/pubmed_huggingface/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pe.callback import SaveTextToCSV
2626
from pe.logger import CSVPrint
2727
from pe.logger import LogPrint
28+
from pe.constant.data import VARIATION_API_FOLD_ID_COLUMN_NAME
2829

2930
import pandas as pd
3031
import os
@@ -57,7 +58,9 @@
5758
)
5859

5960
save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, "checkpoint"))
60-
compute_fid = ComputeFID(priv_data=data, embedding=embedding)
61+
compute_fid = ComputeFID(
62+
priv_data=data, embedding=embedding, filter_criterion={VARIATION_API_FOLD_ID_COLUMN_NAME: -1}
63+
)
6164
save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text"))
6265

6366
csv_print = CSVPrint(output_folder=exp_folder)

example/text/pubmed_openai/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from pe.callback import SaveTextToCSV
4949
from pe.logger import CSVPrint
5050
from pe.logger import LogPrint
51+
from pe.constant.data import VARIATION_API_FOLD_ID_COLUMN_NAME
5152

5253
import pandas as pd
5354
import os
@@ -87,7 +88,9 @@
8788
)
8889

8990
save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, "checkpoint"))
90-
compute_fid = ComputeFID(priv_data=data, embedding=embedding)
91+
compute_fid = ComputeFID(
92+
priv_data=data, embedding=embedding, filter_criterion={VARIATION_API_FOLD_ID_COLUMN_NAME: -1}
93+
)
9194
save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text"))
9295

9396
csv_print = CSVPrint(output_folder=exp_folder)

example/text/yelp_huggingface/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pe.callback import SaveTextToCSV
2626
from pe.logger import CSVPrint
2727
from pe.logger import LogPrint
28+
from pe.constant.data import VARIATION_API_FOLD_ID_COLUMN_NAME
2829

2930
import pandas as pd
3031
import os
@@ -57,7 +58,9 @@
5758
)
5859

5960
save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, "checkpoint"))
60-
compute_fid = ComputeFID(priv_data=data, embedding=embedding)
61+
compute_fid = ComputeFID(
62+
priv_data=data, embedding=embedding, filter_criterion={VARIATION_API_FOLD_ID_COLUMN_NAME: -1}
63+
)
6164
save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text"))
6265

6366
csv_print = CSVPrint(output_folder=exp_folder)

example/text/yelp_openai/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from pe.callback import SaveTextToCSV
4949
from pe.logger import CSVPrint
5050
from pe.logger import LogPrint
51+
from pe.constant.data import VARIATION_API_FOLD_ID_COLUMN_NAME
5152

5253
import pandas as pd
5354
import os
@@ -87,7 +88,9 @@
8788
)
8889

8990
save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, "checkpoint"))
90-
compute_fid = ComputeFID(priv_data=data, embedding=embedding)
91+
compute_fid = ComputeFID(
92+
priv_data=data, embedding=embedding, filter_criterion={VARIATION_API_FOLD_ID_COLUMN_NAME: -1}
93+
)
9194
save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text"))
9295

9396
csv_print = CSVPrint(output_folder=exp_folder)

0 commit comments

Comments
 (0)