
Commit b3da0f1

support multi turn questions

1 parent fbaf4f7 commit b3da0f1

File tree (5 files changed: +165, -14 lines)

llmc/compression/token_reduction/random.py
llmc/eval/__init__.py
llmc/eval/eval_custom_generate_just_infer.py
llmc/eval/utils.py
llmc/models/llava.py


llmc/compression/token_reduction/random.py

Lines changed: 25 additions & 8 deletions

@@ -3,6 +3,7 @@
 from types import MethodType

 import torch
+from loguru import logger

 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

@@ -62,22 +63,37 @@ def input_hook(module, input_args, pruning_paras):
 @prefill_wrapper
 def random_pruning_hook(module, args, kwargs, pruning_paras):

+    logger.info(' ========random_pruning_hook======== ')
+
     rate = pruning_paras['rate']
     image_token_start_index = pruning_paras['image_token_start_index']
     image_token_length = pruning_paras['image_token_length']

     hidden_states = args[0]
     causal_mask = kwargs['attention_mask']

+    logger.info(f'before hidden_states : {hidden_states.shape}')
+
     device = hidden_states.device
-    vision_indexes = torch.arange(
-        image_token_start_index,
-        image_token_start_index + image_token_length,
-        device=device,
-    )
-    num_keep = round(image_token_length * (1 - rate))
-    rand_idx = torch.randperm(image_token_length, device=device)[:num_keep]
-    vision_indexes = vision_indexes[rand_idx]
+
+    if self.model.first_turn_question:
+        logger.info(' -----first_turn_question-----')
+        vision_indexes = torch.arange(
+            image_token_start_index,
+            image_token_start_index + image_token_length,
+            device=device,
+        )
+        num_keep = round(image_token_length * (1 - rate))
+        rand_idx = torch.randperm(image_token_length, device=device)[:num_keep]
+        vision_indexes = vision_indexes[rand_idx]
+
+        # save vision_indexes to module
+        module.register_buffer('vision_indexes', vision_indexes)
+    else:
+        logger.info(' -----not first_turn_question-----')
+        # load vision_indexes from module (prompt cache)
+        vision_indexes = module.vision_indexes
+
     # keep index
     keep_indexs = torch.cat(
         (
@@ -115,6 +131,7 @@ def random_pruning_hook(module, args, kwargs, pruning_paras):
     position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
     position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)

+    logger.info(f'after hidden_states : {hidden_states.shape}')
     return (hidden_states,), kwargs

 if self.model.__class__.__name__ == 'LlavaHf':
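Reviewer note on the hunk above: the hook now draws the random subset of vision-token indices only on the first turn of a conversation and caches it on the module, so every follow-up turn prunes the same tokens and stays consistent with the cached prompt. A minimal, self-contained sketch of that caching pattern (the helper name and the toy usage are illustrative; only register_buffer and the vision_indexes buffer mirror the diff):

import torch
import torch.nn as nn

def select_vision_indexes(module, start, length, rate, first_turn):
    if first_turn:
        # First turn: randomly pick which vision tokens to keep and cache the choice.
        num_keep = round(length * (1 - rate))
        rand_idx = torch.randperm(length)[:num_keep]
        vision_indexes = torch.arange(start, start + length)[rand_idx]
        module.register_buffer('vision_indexes', vision_indexes)
    else:
        # Later turns: reuse the cached subset instead of re-sampling.
        vision_indexes = module.vision_indexes
    return vision_indexes

layer = nn.Identity()
first = select_vision_indexes(layer, start=5, length=8, rate=0.5, first_turn=True)
again = select_vision_indexes(layer, start=5, length=8, rate=0.5, first_turn=False)
assert torch.equal(first, again)  # the same tokens are dropped on every turn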

llmc/eval/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 from .eval_acc import AccuracyEval
 from .eval_code import HumanEval
 from .eval_custom_generate import CustomGenerate
+from .eval_custom_generate_just_infer import CustomGenerateJustInfer
 from .eval_ppl import DecodePerplexityEval, PerplexityEval
 from .eval_token_consist import TokenConsistencyEval
 from .eval_video_generate import VideoGenerateEval

llmc/eval/eval_custom_generate_just_infer.py (new file)

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+import glob
+import json
+import os
+
+import torch
+from human_eval.data import stream_jsonl, write_jsonl
+from human_eval.evaluation import evaluate_functional_correctness
+from loguru import logger
+from tqdm import tqdm
+
+from .eval_base import BaseEval
+
+
+class CustomGenerateJustInfer:
+    def __init__(self, model, config):
+        self.model = model
+        self.config = config
+        self.eval_cfg = config.eval
+
+    @torch.no_grad()
+    def eval(self, model, eval_pos=None):
+        logger.info('start inference')
+
+        with open(os.path.join(self.eval_cfg.path, 'samples.json'), 'r') as f:
+            questions_list = json.load(f)
+
+        custom_samples_ans = self.model.eval_custom_samples_just_infer(
+            questions_list,
+            self.eval_cfg
+        )
+
+        with open(os.path.join('custom_samples_ans.json'), 'w') as f:
+            json.dump(custom_samples_ans, f, indent=4)
+
+        torch.cuda.empty_cache()
+        return 'custom gen done.'
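For context: this evaluator loads <eval.path>/samples.json, forwards the parsed list to eval_custom_samples_just_infer (added to llmc/models/llava.py below), then dumps the returned list to custom_samples_ans.json. Judging from how that method indexes each entry, a sample roughly looks like the sketch below; the field names come from this diff, while the file name and question strings are placeholders, not a documented schema:

samples = [
    {
        'image': ['example.jpg'],              # resolved under <eval.path>/images/
        'question': [
            'What is shown in the image?',     # turn 1: the image token is prepended
            'What color is it?',               # later turns: appended to the running dialogue
        ],
        # the evaluator fills in an 'answer' list with one reply per question
    },
]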

llmc/eval/utils.py

Lines changed: 5 additions & 3 deletions

@@ -3,9 +3,9 @@

 from loguru import logger

-from llmc.eval import (AccuracyEval, CustomGenerate, DecodePerplexityEval,
-                        HumanEval, PerplexityEval, TokenConsistencyEval,
-                        VideoGenerateEval, VQAEval)
+from llmc.eval import (AccuracyEval, CustomGenerate, CustomGenerateJustInfer,
+                        DecodePerplexityEval, HumanEval, PerplexityEval,
+                        TokenConsistencyEval, VideoGenerateEval, VQAEval)
 from llmc.utils import deploy_all_modality


@@ -57,6 +57,8 @@ def get_eval_list(model, config):
         eval_class = HumanEval(model, config_for_eval)
     elif config_tmp.eval.type == 'generate_only':
         eval_class = CustomGenerate(model, config_for_eval)
+    elif config_tmp.eval.type == 'just_infer':
+        eval_class = CustomGenerateJustInfer(model, config_for_eval)
     elif config_tmp.eval.type == 'token_acc':
         eval_class = TokenConsistencyEval(model, config_for_eval)
     elif config_tmp.eval.type == 'ppl':
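With this dispatch in place, an eval entry whose type is 'just_infer' is routed to CustomGenerateJustInfer. In llmc the eval section lives in the YAML config; the Python dict below is only a hedged sketch of the three fields the code in this commit actually reads (the path value is a placeholder):

eval_cfg = {
    'type': 'just_infer',            # selects CustomGenerateJustInfer in get_eval_list
    'path': '/path/to/custom_set',   # directory holding samples.json and images/
    'max_new_tokens': 128,           # forwarded to vlm_model.generate
}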

llmc/models/llava.py

Lines changed: 98 additions & 3 deletions

@@ -1,3 +1,4 @@
+import os
 import types
 from datetime import timedelta
 from typing import Optional, Union
@@ -9,6 +10,7 @@
 from lmms_eval.models.llava import Llava as LLaVA
 from loguru import logger
 from packaging import version
+from PIL import Image
 from transformers import AutoConfig, AutoTokenizer

 from llmc.utils.registry_factory import MODEL_REGISTRY
@@ -17,8 +19,11 @@

 try:
     from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                                 DEFAULT_IMAGE_PATCH_TOKEN, IMAGE_TOKEN_INDEX)
-    from llava.mm_utils import get_model_name_from_path
+                                 DEFAULT_IMAGE_PATCH_TOKEN,
+                                 DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
+    from llava.conversation import SeparatorStyle, conv_templates
+    from llava.mm_utils import (get_model_name_from_path, process_images,
+                                tokenizer_image_token)
     from llava.model.builder import load_pretrained_model
     from llava.model.language_model.llava_llama import LlavaConfig
 except Exception as e:
@@ -45,7 +50,7 @@ def build_model(self):
         self.vlm_model_config.use_cache = True
         logger.info(f'self.vlm_model_config : {self.vlm_model_config}')

-        self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
+        self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model(
             self.model_path,
             None,
             get_model_name_from_path(self.model_path),
@@ -137,6 +142,96 @@ def get_subsets_in_block(self, block):
         else:
             raise Exception(f'Llava do not support {self.get_modality()} modality.')

+    def eval_custom_samples_just_infer(
+        self,
+        img_qas,
+        eval_cfg
+    ):  # noqa
+
+        custom_samples_ans = img_qas.copy()
+
+        self.vlm_model.cuda()
+
+        def load_image(image_file):
+            image = Image.open(image_file).convert('RGB')
+            return image
+
+        def load_images(image_files):
+            out = []
+            for image_file in image_files:
+                image = load_image(image_file)
+                out.append(image)
+            return out
+
+        self.first_turn_question = True
+
+        for data_idx, questions in enumerate(img_qas):
+            self.first_turn_question = True
+
+            custom_samples_ans[data_idx]['answer'] = []
+
+            image_files = questions['image']
+            image_files = [os.path.join(eval_cfg.path, 'images', image_file) for image_file in image_files]  # noqa
+            images = load_images(image_files)
+            image_sizes = [x.size for x in images]
+            images_tensor = process_images(
+                images,
+                self.image_processor,
+                self.vlm_model.config
+            ).to(self.vlm_model.device, dtype=torch.float16)
+
+            input_ids_old = None
+
+            for question_idx, question in enumerate(questions['question']):
+
+                conv_mode = 'llava_v1'
+                conv = conv_templates[conv_mode].copy()
+                if question_idx > 0:
+                    conv.system = ''
+                    qs = question
+                    self.first_turn_question = False
+                else:
+                    qs = DEFAULT_IMAGE_TOKEN + '\n' + question
+                conv.append_message(conv.roles[0], qs)
+                conv.append_message(conv.roles[1], None)
+                prompt = conv.get_prompt()
+
+                input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()  # noqa
+                # print(f"input_ids 1: {input_ids}, {input_ids.shape}")
+                if input_ids_old is not None:
+                    input_ids = torch.cat((input_ids_old, input_ids), dim=1)
+                # print(f"input_ids 2: {input_ids}, {input_ids.shape}")
+
+                with torch.inference_mode():
+                    output_ids = self.vlm_model.generate(
+                        input_ids,
+                        attention_mask=input_ids.new_ones(input_ids.shape, dtype=torch.bool),
+                        images=images_tensor,
+                        image_sizes=image_sizes,
+                        do_sample=False,
+                        top_p=None,
+                        num_beams=1,
+                        max_new_tokens=eval_cfg.max_new_tokens,
+                        use_cache=True,
+                    )
+
+                # print(f"output_ids: {output_ids}, {output_ids.shape}")
+
+                outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+                print('--------------------------------')
+                print(f'data_idx: {data_idx}')
+                print(f'question_idx: {question_idx}')
+                print(f'question: {question}')
+                print(f'outputs: {outputs}')
+                print('--------------------------------')
+
+                custom_samples_ans[data_idx]['answer'].append(outputs[0])
+
+                input_ids_old = torch.cat((input_ids, output_ids), dim=1)
+
+        return custom_samples_ans
+

 if version.parse(torch.__version__) >= version.parse('2.1.2'):
     best_fit_attn_implementation = 'sdpa'
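The core of the new method is how it threads turns together: the first turn builds a llava_v1 prompt with DEFAULT_IMAGE_TOKEN, every later turn clears the system prompt and concatenates its freshly tokenized prompt onto the previous turn's prompt-plus-reply, and self.first_turn_question tells the pruning hook in random.py when to reuse its cached vision indices (the flag is reset to True at the top of each sample's loop, so every conversation re-samples its own subset). A toy, self-contained version of just the id-threading part; the tokenizer and model are stubbed, and only the torch.cat pattern mirrors the diff:

import torch

def fake_tokenize(text):                         # stand-in for tokenizer_image_token
    return torch.randint(0, 100, (1, len(text.split())))

def fake_generate(input_ids, max_new_tokens=4):  # stand-in for self.vlm_model.generate
    return torch.randint(0, 100, (1, max_new_tokens))

input_ids_old = None
for question_idx, question in enumerate(['<image> What is shown?', 'And what color is it?']):
    input_ids = fake_tokenize(question)
    if input_ids_old is not None:                             # follow-up turn: prepend history
        input_ids = torch.cat((input_ids_old, input_ids), dim=1)
    output_ids = fake_generate(input_ids)
    input_ids_old = torch.cat((input_ids, output_ids), dim=1)  # carry prompt + reply forward
    print(f'turn {question_idx}: context is now {input_ids_old.shape[1]} tokens')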

0 commit comments
