1
- import gc
2
- from mailbox import Message
3
- import os
4
1
import shutil
5
- from pathlib import Path
6
- from typing import Self
7
-
8
- import torch
9
2
import yaml
10
- from chromadb . config import Settings
3
+ import gc
11
4
from langchain .docstore .document import Document
12
- from langchain .embeddings import (
13
- HuggingFaceBgeEmbeddings ,
14
- HuggingFaceEmbeddings ,
15
- HuggingFaceInstructEmbeddings ,
16
- )
5
+ from langchain .embeddings import HuggingFaceInstructEmbeddings , HuggingFaceEmbeddings , HuggingFaceBgeEmbeddings
17
6
from langchain .vectorstores import Chroma
18
- from termcolor import cprint
19
-
7
+ from chromadb .config import Settings
20
8
from document_processor import load_documents , split_documents
9
+ import torch
21
10
from utilities import validate_symbolic_links
11
+ from termcolor import cprint
12
+ from pathlib import Path
13
+ import os
22
14
23
15
ENABLE_PRINT = True
24
16
25
-
26
17
def my_cprint (* args , ** kwargs ):
27
18
if ENABLE_PRINT :
28
19
modified_message = f"create_database.py: { args [0 ]} "
29
20
cprint (modified_message , * args [1 :], ** kwargs )
30
21
31
-
32
22
ROOT_DIRECTORY = Path (__file__ ).resolve ().parent
33
23
SOURCE_DIRECTORY = ROOT_DIRECTORY / "Docs_for_DB"
34
24
PERSIST_DIRECTORY = ROOT_DIRECTORY / "Vector_DB"
@@ -37,93 +27,84 @@ def my_cprint(*args, **kwargs):
37
27
CHROMA_SETTINGS = Settings (
38
28
chroma_db_impl = "duckdb+parquet" ,
39
29
persist_directory = str (PERSIST_DIRECTORY ),
40
- anonymized_telemetry = False ,
30
+ anonymized_telemetry = False
41
31
)
42
32
43
-
44
33
def main ():
45
- with open (ROOT_DIRECTORY / "config.yaml" , "r" ) as stream :
34
+
35
+ with open (ROOT_DIRECTORY / "config.yaml" , 'r' ) as stream :
46
36
config_data = yaml .safe_load (stream )
47
37
48
38
EMBEDDING_MODEL_NAME = config_data .get ("EMBEDDING_MODEL_NAME" )
49
39
50
40
my_cprint (f"Loading documents." , "white" )
51
- documents = load_documents (
52
- SOURCE_DIRECTORY
53
- ) # invoke document_processor.py; returns a list of document objects
54
- if documents == None or len (documents ) == 0 :
55
- cprint (f"No documents to load." )
41
+ documents = load_documents (SOURCE_DIRECTORY ) # invoke document_processor.py; returns a list of document objects
42
+ if documents is None or len (documents ) == 0 :
43
+ cprint ("No documents to load." , "red" )
56
44
return
57
-
58
45
my_cprint (f"Successfully loaded documents." , "white" )
59
- texts = split_documents (
60
- documents
61
- ) # invoke document_processor.py again; returns a list of split document objects
62
-
46
+
47
+ texts = split_documents (documents ) # invoke document_processor.py again; returns a list of split document objects
48
+
63
49
embeddings = get_embeddings (EMBEDDING_MODEL_NAME , config_data )
64
50
my_cprint ("Embedding model loaded." , "green" )
51
+
65
52
if PERSIST_DIRECTORY .exists ():
66
53
shutil .rmtree (PERSIST_DIRECTORY )
67
54
PERSIST_DIRECTORY .mkdir (parents = True , exist_ok = True )
68
55
69
56
my_cprint ("Creating database." , "white" )
70
-
57
+
71
58
db = Chroma .from_documents (
72
- texts ,
73
- embeddings ,
74
- persist_directory = str (PERSIST_DIRECTORY ),
59
+ texts , embeddings ,
60
+ persist_directory = str (PERSIST_DIRECTORY ),
75
61
client_settings = CHROMA_SETTINGS ,
76
62
)
77
-
63
+
78
64
my_cprint ("Persisting database." , "white" )
79
65
db .persist ()
80
66
my_cprint ("Database persisted." , "white" )
81
-
67
+
68
+ del embeddings .client
69
+ del embeddings
82
70
torch .cuda .empty_cache ()
83
71
gc .collect ()
84
72
my_cprint ("Embedding model removed from memory." , "red" )
85
73
86
-
87
74
def get_embeddings (EMBEDDING_MODEL_NAME , config_data , normalize_embeddings = False ):
88
75
my_cprint ("Creating embeddings." , "white" )
89
-
90
- compute_device = config_data [" Compute_Device" ][ " database_creation" ]
91
-
76
+
77
+ compute_device = config_data [' Compute_Device' ][ ' database_creation' ]
78
+
92
79
if "instructor" in EMBEDDING_MODEL_NAME :
93
- embed_instruction = config_data ["embedding-models" ]["instructor" ].get (
94
- "embed_instruction"
95
- )
96
- query_instruction = config_data ["embedding-models" ]["instructor" ].get (
97
- "query_instruction"
98
- )
80
+ embed_instruction = config_data ['embedding-models' ]['instructor' ].get ('embed_instruction' )
81
+ query_instruction = config_data ['embedding-models' ]['instructor' ].get ('query_instruction' )
99
82
100
83
return HuggingFaceInstructEmbeddings (
101
84
model_name = EMBEDDING_MODEL_NAME ,
102
85
model_kwargs = {"device" : compute_device },
103
86
encode_kwargs = {"normalize_embeddings" : normalize_embeddings },
104
87
embed_instruction = embed_instruction ,
105
- query_instruction = query_instruction ,
88
+ query_instruction = query_instruction
106
89
)
107
90
108
91
elif "bge" in EMBEDDING_MODEL_NAME :
109
- query_instruction = config_data ["embedding-models" ]["bge" ].get (
110
- "query_instruction"
111
- )
92
+ query_instruction = config_data ['embedding-models' ]['bge' ].get ('query_instruction' )
112
93
113
94
return HuggingFaceBgeEmbeddings (
114
95
model_name = EMBEDDING_MODEL_NAME ,
115
96
model_kwargs = {"device" : compute_device },
116
97
query_instruction = query_instruction ,
117
- encode_kwargs = {"normalize_embeddings" : normalize_embeddings },
98
+ encode_kwargs = {"normalize_embeddings" : normalize_embeddings }
118
99
)
119
-
100
+
120
101
else :
102
+
121
103
return HuggingFaceEmbeddings (
122
104
model_name = EMBEDDING_MODEL_NAME ,
123
105
model_kwargs = {"device" : compute_device },
124
- encode_kwargs = {"normalize_embeddings" : normalize_embeddings },
106
+ encode_kwargs = {"normalize_embeddings" : normalize_embeddings }
125
107
)
126
108
127
-
128
109
if __name__ == "__main__" :
129
110
main ()
0 commit comments