Skip to content

Commit 036ca71

Browse files
committed
Add class_id argument in sampler.py
1 parent c061cb5 commit 036ca71

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

sampler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,12 @@ def get_args():
136136
required=True,
137137
help="Path to yaml config file",
138138
)
139+
parser.add_argument(
140+
"--class_id",
141+
type=int,
142+
default=None,
143+
help="Number up to 1000 that corresponds to a class",
144+
)
139145

140146
return parser.parse_args()
141147

@@ -166,7 +172,11 @@ def main():
166172
model.load_state_dict(torch.load(args.checkpoint_path, map_location="cpu"))
167173
model = model.eval().to(device)
168174

169-
y = torch.ones(args.batch_size, dtype=torch.int).to(device) * 3
175+
y = (
176+
torch.ones(args.batch_size, dtype=torch.int).to(device) * args.class_id
177+
if args.class_id is not None
178+
else None
179+
)
170180
autoencoder = (
171181
get_autoencoder(config["autoencoder"]["autoencoder_checkpoint_path"])
172182
if "autoencoder" in config

scripts/sample.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ python sampler.py \
99
--seed 1 \
1010
--config_path $config_path \
1111
--checkpoint_path $checkpoint_path \
12-
--output_folder $output_folder
12+
--output_folder $output_folder \
13+
--class_id 3

0 commit comments

Comments
 (0)