@@ -66,7 +66,137 @@ def example_reading_spec(self, label_key=None):
66
66
return data_fields , data_items_to_decoders
67
67
68
68
69
- # French street names dataset.
69
@registry.register_problem("image_celeba_tune")
class ImageCeleba(ImageProblem):
  """CelebA dataset, aligned and cropped face images.

  Downloads the aligned-and-cropped CelebA images plus the landmark and
  attribute annotation files from Google Drive, crops away the background
  border, and emits (8x8 input, 32x32 target) image pairs — a
  super-resolution style setup. Each example also carries the per-image
  attribute and landmark integer vectors.
  """

  # (file name, download URL) pairs for the three CelebA artifacts.
  IMG_DATA = ("img_align_celeba.zip",
              "https://drive.google.com/uc?export=download&"
              "id=0B7EVK8r0v71pZjFTYXZWM3FlRnM")
  LANDMARKS_DATA = ("celeba_landmarks_align",
                    "https://drive.google.com/uc?export=download&"
                    "id=0B7EVK8r0v71pd0FJY3Blby1HUTQ")
  ATTR_DATA = ("celeba_attr", "https://drive.google.com/uc?export=download&"
               "id=0B7EVK8r0v71pblRyaVFSWGxPY0U")

  # Column headings for the landmark file (10 ints per image).
  LANDMARK_HEADINGS = ("lefteye_x lefteye_y righteye_x righteye_y "
                       "nose_x nose_y leftmouth_x leftmouth_y rightmouth_x "
                       "rightmouth_y").split()
  # Column headings for the attribute file (40 ints per image).
  ATTR_HEADINGS = (
      "5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs "
      "Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair "
      "Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair "
      "Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache "
      "Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline "
      "Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings "
      "Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young"
  ).split()

  def preprocess_examples(self, examples, unused_mode, unused_hparams):
    """Crop out CelebA borders; set 8x8 inputs and 32x32 targets."""

    def resize(img, size):
      # AREA resampling, then cast back to int64 pixel values.
      return tf.to_int64(
          tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))

    inputs = examples["inputs"]
    # Remove boundaries in CelebA images. Remove 40 pixels each side
    # vertically and 20 pixels each side horizontally.
    inputs = tf.image.crop_to_bounding_box(inputs, 40, 20, 218 - 80, 178 - 40)
    examples["inputs"] = resize(inputs, 8)
    examples["targets"] = resize(inputs, 32)
    return examples

  def hparams(self, defaults, model_hparams):
    """Mutates `defaults` in place with modality and batching settings."""
    p = defaults
    p.input_modality = {"inputs": ("image:identity_no_pad", None)}
    p.target_modality = ("image:identity_no_pad", None)
    p.batch_size_multiplier = 256
    p.max_expected_batch_size_per_shard = 4
    p.input_space_id = 1
    p.target_space_id = 1

  def generator(self, tmp_dir, how_many, start_from=0):
    """Image generator for CELEBA dataset.

    Args:
      tmp_dir: path to temporary storage directory.
      how_many: how many images and labels to generate.
      start_from: from which image to start.

    Yields:
      A dictionary representing the images with the following fields:
      * image/encoded: the string encoding the image as JPEG,
      * image/format: the string "jpeg" representing image format,
      * attributes: list of ints, one per ATTR_HEADINGS column,
      * landmarks: list of ints, one per LANDMARK_HEADINGS column.
    """
    out_paths = []
    for fname, url in [self.IMG_DATA, self.LANDMARKS_DATA, self.ATTR_DATA]:
      path = generator_utils.maybe_download_from_drive(tmp_dir, fname, url)
      out_paths.append(path)

    img_path, landmarks_path, attr_path = out_paths  # pylint: disable=unbalanced-tuple-unpacking
    # Strip the ".zip" suffix to get the extraction directory name.
    unzipped_folder = img_path[:-4]
    if not tf.gfile.Exists(unzipped_folder):
      zipfile.ZipFile(img_path, "r").extractall(tmp_dir)

    with tf.gfile.Open(landmarks_path) as f:
      landmarks_raw = f.read()

    with tf.gfile.Open(attr_path) as f:
      attr_raw = f.read()

    def process_annotations(raw_data):
      """Parse one CelebA annotation file.

      The landmark and attribute files share the same layout: line 0 is a
      count, line 1 the headings, then one "<img_name> v1 v2 ..." row of
      integers per image (with a trailing empty line).

      Returns:
        Tuple of (dict mapping image name -> list of int values,
                  list of heading strings).
      """
      annotations = {}
      lines = raw_data.split("\n")
      headings = lines[1].strip().split()
      for line in lines[2:-1]:
        values = line.strip().split()
        img_name = values[0]
        annotations[img_name] = [int(v) for v in values[1:]]
      return annotations, headings

    # The two annotation files have identical formats, so one parser
    # serves both (replaces the former duplicated process_landmarks /
    # process_attrs helpers).
    img_landmarks, _ = process_annotations(landmarks_raw)
    img_attrs, _ = process_annotations(attr_raw)

    image_files = tf.gfile.Glob(unzipped_folder + "/*.jpg")
    for filename in image_files[start_from:start_from + how_many]:
      img_name = os.path.basename(filename)
      landmarks = img_landmarks[img_name]
      attrs = img_attrs[img_name]

      # Read in binary mode: JPEG bytes must not pass through text decoding
      # (mode "r" breaks on Python 3).
      with tf.gfile.Open(filename, "rb") as f:
        encoded_image_data = f.read()
      yield {
          "image/encoded": [encoded_image_data],
          "image/format": ["jpeg"],
          "attributes": attrs,
          "landmarks": landmarks,
      }

  @property
  def train_shards(self):
    # Number of training output shards.
    return 100

  @property
  def dev_shards(self):
    # Number of dev (validation) output shards.
    return 10

  def generate_data(self, data_dir, tmp_dir, task_id=-1):
    """Download CelebA and write shuffled train/dev shard files."""
    generator_utils.generate_dataset_and_shuffle(
        self.generator(tmp_dir, 162770),  # train
        self.training_filepaths(data_dir, self.train_shards, shuffled=False),
        self.generator(tmp_dir, 19867, 162770),  # dev
        self.dev_filepaths(data_dir, self.dev_shards, shuffled=False))
70
200
71
201
72
202
@registry .register_problem
@@ -199,7 +329,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
199
329
"instructions at https://github.com/tensorflow/models/blob/master"
200
330
"/inception/README.md#getting-started" )
201
331
202
- def preprocess_examples (self , examples , mode ):
332
  def preprocess_examples(self, examples, mode, _):
    """Apply standard ImageNet preprocessing.

    The third positional argument is unused here; presumably it is the
    hparams object passed to other preprocess_examples implementations
    in this file — confirm against the base class signature.
    """
    return imagenet_preprocess_examples(examples, mode)
204
334
205
335
@@ -638,7 +768,7 @@ def train_shards(self):
638
768
  def dev_shards(self):
    # Number of dev (validation) output shards.
    return 10
640
770
641
- def preprocess_examples (self , examples , mode ):
771
  def preprocess_examples(self, examples, mode, _):
    """Apply standard ImageNet preprocessing; third argument is unused
    (presumably hparams, kept for interface compatibility — verify)."""
    return imagenet_preprocess_examples(examples, mode)
643
773
644
774
def generator (self , data_dir , tmp_dir , is_training ):
@@ -700,41 +830,3 @@ class ImageMsCocoTokens32k(ImageMsCocoTokens8k):
700
830
@property
701
831
def targeted_vocab_size (self ):
702
832
return 2 ** 15 # 32768
703
-
704
-
705
- # URL and filename for CELEBA data.
706
- _CELEBA_NAME = "img_align_celeba"
707
- _CELEBA_URL = "https://drive.google.com/uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM"
708
-
709
-
710
def _get_celeba(directory):
  """Download and extract CELEBA to directory unless it is there."""
  # maybe_download_from_drive returns the local path of the (possibly
  # already cached) download; extraction runs only when the extracted
  # target is absent.
  path = generator_utils.maybe_download_from_drive(directory, _CELEBA_NAME,
                                                   _CELEBA_URL)
  if not tf.gfile.Exists(path):
    zipfile.ZipFile(path + ".zip", "r").extractall(directory)
717
-
718
-
719
def celeba_generator(tmp_dir, how_many, start_from=0):
  """Image generator for CELEBA dataset.

  Args:
    tmp_dir: path to temporary storage directory.
    how_many: how many images and labels to generate.
    start_from: from which image to start.

  Yields:
    A dictionary representing the images with the following fields:
    * image/encoded: the string encoding the image as JPEG,
    * image/format: the string "jpeg" representing image format,
  """
  # Ensure the dataset is downloaded and unpacked before globbing.
  _get_celeba(tmp_dir)
  glob_pattern = os.path.join(tmp_dir, _CELEBA_NAME) + "/*.jpg"
  all_images = tf.gfile.Glob(glob_pattern)
  end_at = start_from + how_many
  for image_path in all_images[start_from:end_at]:
    with tf.gfile.Open(image_path, "r") as image_file:
      jpeg_bytes = image_file.read()
    record = {
        "image/encoded": [jpeg_bytes],
        "image/format": ["jpeg"],
    }
    yield record
0 commit comments