@@ -1,3 +1,4 @@
+import os
 import types
 from datetime import timedelta
 from typing import Optional, Union
@@ -9,6 +10,7 @@
 from lmms_eval.models.llava import Llava as LLaVA
 from loguru import logger
 from packaging import version
+from PIL import Image
 from transformers import AutoConfig, AutoTokenizer
 
 from llmc.utils.registry_factory import MODEL_REGISTRY
@@ -17,8 +19,11 @@
 
 try:
     from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                                 DEFAULT_IMAGE_PATCH_TOKEN, IMAGE_TOKEN_INDEX)
-    from llava.mm_utils import get_model_name_from_path
+                                 DEFAULT_IMAGE_PATCH_TOKEN,
+                                 DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
+    from llava.conversation import SeparatorStyle, conv_templates
+    from llava.mm_utils import (get_model_name_from_path, process_images,
+                                tokenizer_image_token)
     from llava.model.builder import load_pretrained_model
     from llava.model.language_model.llava_llama import LlavaConfig
 except Exception as e:
@@ -45,7 +50,7 @@ def build_model(self):
         self.vlm_model_config.use_cache = True
         logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
 
-        self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
+        self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model(
             self.model_path,
             None,
             get_model_name_from_path(self.model_path),
@@ -137,6 +142,96 @@ def get_subsets_in_block(self, block):
         else:
             raise Exception(f'Llava do not support {self.get_modality()} modality.')
 
+    def eval_custom_samples_just_infer(
+        self,
+        img_qas,
+        eval_cfg
+    ):  # noqa
+
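+        # Inference-only pass over custom samples: answer every question and
+        # collect outputs on a copy so the caller's img_qas is left unchanged.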
+        custom_samples_ans = img_qas.copy()
+
+        self.vlm_model.cuda()
+
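+        # Helpers: open each image file and force RGB so the image processor
+        # always receives 3-channel input.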
+        def load_image(image_file):
+            image = Image.open(image_file).convert('RGB')
+            return image
+
+        def load_images(image_files):
+            out = []
+            for image_file in image_files:
+                image = load_image(image_file)
+                out.append(image)
+            return out
+
+        self.first_turn_question = True
+
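+        # One outer iteration per sample; each entry carries image path(s) plus
+        # an ordered question list treated as a multi-turn dialogue.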
+        for data_idx, questions in enumerate(img_qas):
+            self.first_turn_question = True
+
+            custom_samples_ans[data_idx]['answer'] = []
+
+            image_files = questions['image']
+            image_files = [os.path.join(eval_cfg.path, 'images', image_file) for image_file in image_files]  # noqa
+            images = load_images(image_files)
+            image_sizes = [x.size for x in images]
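+            # Preprocess all images once per sample and cast to fp16 on the
+            # model's device; every turn below reuses this tensor.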
+            images_tensor = process_images(
+                images,
+                self.image_processor,
+                self.vlm_model.config
+            ).to(self.vlm_model.device, dtype=torch.float16)
+
+            input_ids_old = None
+
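+            # Build one llava_v1 prompt per turn: only the first turn carries
+            # the image token; later turns drop the system prompt.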
+            for question_idx, question in enumerate(questions['question']):
+
+                conv_mode = 'llava_v1'
+                conv = conv_templates[conv_mode].copy()
+                if question_idx > 0:
+                    conv.system = ''
+                    qs = question
+                    self.first_turn_question = False
+                else:
+                    qs = DEFAULT_IMAGE_TOKEN + '\n' + question
+                conv.append_message(conv.roles[0], qs)
+                conv.append_message(conv.roles[1], None)
+                prompt = conv.get_prompt()
+
+                input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()  # noqa
+                # print(f"input_ids 1: {input_ids}, {input_ids.shape}")
+                if input_ids_old is not None:
+                    input_ids = torch.cat((input_ids_old, input_ids), dim=1)
+                # print(f"input_ids 2: {input_ids}, {input_ids.shape}")
+
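+                # Greedy decoding over the concatenated dialogue ids; the full
+                # running context (prompt + earlier answers) is passed each turn.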
+                with torch.inference_mode():
+                    output_ids = self.vlm_model.generate(
+                        input_ids,
+                        attention_mask=input_ids.new_ones(input_ids.shape, dtype=torch.bool),
+                        images=images_tensor,
+                        image_sizes=image_sizes,
+                        do_sample=False,
+                        top_p=None,
+                        num_beams=1,
+                        max_new_tokens=eval_cfg.max_new_tokens,
+                        use_cache=True,
+                    )
+
+                # print(f"output_ids: {output_ids}, {output_ids.shape}")
+
+                outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+                print('--------------------------------')
+                print(f'data_idx: {data_idx}')
+                print(f'question_idx: {question_idx}')
+                print(f'question: {question}')
+                print(f'outputs: {outputs}')
+                print('--------------------------------')
+
+                custom_samples_ans[data_idx]['answer'].append(outputs[0])
+
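+                # Extend the running ids with this turn's prompt and answer so
+                # the next turn sees the whole conversation so far.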
+                input_ids_old = torch.cat((input_ids, output_ids), dim=1)
+
+        return custom_samples_ans
+
 
 if version.parse(torch.__version__) >= version.parse('2.1.2'):
     best_fit_attn_implementation = 'sdpa'