Commit 2433fdc
fixes
1 parent 4f9e36e

1 file changed: +33 -1

src/database_interactions.py: 33 additions & 1 deletion
@@ -29,7 +29,7 @@
 from utilities import my_cprint, get_model_native_precision, get_appropriate_dtype, supports_flash_attention
 from constants import VECTOR_MODELS
 
-logging.basicConfig(level=logging.CRITICAL, force=True)
+logging.basicConfig(level=logging.INFO, force=True)
 # logging.basicConfig(level=logging.DEBUG, force=True)
 logger = logging.getLogger(__name__)
 
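The only substantive change in this hunk raises the root logging threshold from CRITICAL to INFO. A minimal sketch (not from the diff) of why the force=True flag matters here: since Python 3.8, logging.basicConfig(force=True) removes any handlers already attached to the root logger before reconfiguring, so the new level takes effect even if logging was configured earlier in the process.

import logging

# First configuration: only CRITICAL records pass.
logging.basicConfig(level=logging.CRITICAL)
logging.info("hidden")  # filtered out

# Without force=True this second call would be a no-op, because the
# root logger already has a handler; force=True removes the existing
# handler and reapplies the configuration at the new level.
logging.basicConfig(level=logging.INFO, force=True)
logging.info("now visible")  # emitted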
@@ -51,6 +51,7 @@ def prepare_encode_kwargs(self):
     def create(self):
         prepared_kwargs = self.prepare_kwargs()
         prepared_encode_kwargs = self.prepare_encode_kwargs()
+
         return HuggingFaceEmbeddings(
             model_name=self.model_name,
             show_progress=not self.is_query,
@@ -143,12 +144,42 @@ def prepare_kwargs(self):
         return stella_kwargs
 
 
+# class AlibabaEmbedding(BaseEmbeddingModel):
+#     def prepare_kwargs(self):
+#         ali_kwargs = deepcopy(self.model_kwargs)
+#         compute_device = ali_kwargs.get("device", "").lower()
+#         is_cuda = compute_device == "cuda"
+#         use_xformers = is_cuda and supports_flash_attention()
+#         ali_kwargs["tokenizer_kwargs"] = {
+#             "padding": "longest",
+#             "truncation": True,
+#             "max_length": 8192
+#         }
+#         ali_kwargs["config_kwargs"] = {
+#             "use_memory_efficient_attention": use_xformers,
+#             "unpad_inputs": use_xformers,
+#             "attn_implementation": "eager" if use_xformers else "sdpa"
+#         }
+#         return ali_kwargs
+
+#     def prepare_encode_kwargs(self):
+#         encode_kwargs = super().prepare_encode_kwargs()
+#         encode_kwargs.update({
+#             "padding": True,
+#             "truncation": True,
+#             "max_length": 8192
+#         })
+#         return encode_kwargs
+
+
 class AlibabaEmbedding(BaseEmbeddingModel):
     def prepare_kwargs(self):
         ali_kwargs = deepcopy(self.model_kwargs)
+
         compute_device = ali_kwargs.get("device", "").lower()
         is_cuda = compute_device == "cuda"
         use_xformers = is_cuda and supports_flash_attention()
+
         ali_kwargs["tokenizer_kwargs"] = {
             "padding": "longest",
             "truncation": True,
@@ -171,6 +202,7 @@ def prepare_encode_kwargs(self):
         return encode_kwargs
 
 
+
 def create_vector_db_in_process(database_name):
     create_vector_db = CreateVectorDB(database_name=database_name)
     create_vector_db.run()
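For context on the use_xformers gate in AlibabaEmbedding.prepare_kwargs: supports_flash_attention comes from the repo's utilities module and is not shown in this diff. Below is a hypothetical stand-in sketching the compute-capability check (Ampere or newer) that such helpers commonly perform; the name and logic are assumptions, not the repo's actual implementation.

import torch

def supports_flash_attention_sketch() -> bool:
    # Hypothetical stand-in for utilities.supports_flash_attention,
    # which this diff imports but does not define. Flash-attention
    # kernels generally require a CUDA device of compute capability
    # 8.0 (Ampere) or newer.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

# Mirrors the gating in the diff: the memory-efficient (xformers)
# path is enabled only on CUDA devices that pass the check. Note the
# diff pairs it with attn_implementation="eager" (the custom kernels
# take over attention) and falls back to PyTorch SDPA otherwise.
use_xformers = torch.cuda.is_available() and supports_flash_attention_sketch()
config_kwargs = {
    "use_memory_efficient_attention": use_xformers,
    "unpad_inputs": use_xformers,
    "attn_implementation": "eager" if use_xformers else "sdpa",
}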
