File tree Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Original file line number Diff line number Diff line change @@ -136,6 +136,12 @@ def get_args():
136
136
required = True ,
137
137
help = "Path to yaml config file" ,
138
138
)
139
+ parser .add_argument (
140
+ "--class_id" ,
141
+ type = int ,
142
+ default = None ,
143
+ help = "Number up to 1000 that corresponds to a class" ,
144
+ )
139
145
140
146
return parser .parse_args ()
141
147
@@ -166,7 +172,11 @@ def main():
166
172
model .load_state_dict (torch .load (args .checkpoint_path , map_location = "cpu" ))
167
173
model = model .eval ().to (device )
168
174
169
- y = torch .ones (args .batch_size , dtype = torch .int ).to (device ) * 3
175
+ y = (
176
+ torch .ones (args .batch_size , dtype = torch .int ).to (device ) * args .class_id
177
+ if args .class_id is not None
178
+ else None
179
+ )
170
180
autoencoder = (
171
181
get_autoencoder (config ["autoencoder" ]["autoencoder_checkpoint_path" ])
172
182
if "autoencoder" in config
Original file line number Diff line number Diff line change @@ -9,4 +9,5 @@ python sampler.py \
9
9
--seed 1 \
10
10
--config_path $config_path \
11
11
--checkpoint_path $checkpoint_path \
12
- --output_folder $output_folder
12
+ --output_folder $output_folder \
13
+ --class_id 3
You can’t perform that action at this time.
0 commit comments