Skip to content

Commit 2d5e752

Browse files
authored
v3.0.1
Fixed a bug when searching by document type. Consolidated three vision model scripts into the new loader_images.py script.
1 parent 87ad685 commit 2d5e752

11 files changed

+412
-61
lines changed

src/constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,6 @@
307307
".html": "UnstructuredHTMLLoader",
308308
}
309309

310-
WHISPER_MODEL_NAMES = ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2"]
311-
312310
CHUNKS_ONLY_TOOLTIP = "Only return relevant chunks without connecting to the LLM. Extremely useful to test the chunk size/overlap settings."
313311

314312
SPEAK_RESPONSE_TOOLTIP = "Only click this after the LLM's entire response is received otherwise your computer might explode."

src/database_interactions.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,32 @@ def get_embeddings(self, EMBEDDING_MODEL_NAME, config_data):
108108
if __name__ == "__main__":
109109
create_vector_db = CreateVectorDB()
110110
create_vector_db.run()
111+
112+
# To delete entries based on the "hash" metadata attribute, you can use this as_retriever method to create a retriever that filters documents based on their metadata. Once you retrieve the documents with the specific hash, you can then extract their IDs and use the delete method to remove them from the vectorstore.
113+
114+
# Here is how you might implement this in your CreateVectorDB class:
115+
116+
# python
117+
118+
# class CreateVectorDB:
119+
# # ... [other methods] ...
120+
121+
# def delete_entries_by_hash(self, target_hash):
122+
# my_cprint(f"Deleting entries with hash: {target_hash}", "red")
123+
124+
# # Initialize the retriever with a filter for the specific hash
125+
# retriever = self.db.as_retriever(search_kwargs={'filter': {'hash': target_hash}})
126+
127+
# # Retrieve documents with the specific hash
128+
# documents = retriever.search("")
129+
130+
# # Extract IDs from the documents
131+
# ids_to_delete = [doc.id for doc in documents]
132+
133+
# # Delete entries with the extracted IDs
134+
# if ids_to_delete:
135+
# self.db.delete(ids=ids_to_delete)
136+
# my_cprint(f"Deleted {len(ids_to_delete)} entries from the database.", "green")
137+
# else:
138+
# my_cprint("No entries found with the specified hash.", "yellow")
139+

src/document_processor.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121
)
2222

2323
from constants import DOCUMENT_LOADERS
24-
from loader_vision_llava import llava_process_images
25-
from loader_vision_cogvlm import cogvlm_process_images
26-
from loader_salesforce import salesforce_process_images
24+
from loader_images import loader_cogvlm, loader_llava, loader_salesforce
2725
from extract_metadata import extract_document_metadata
2826
from utilities import my_cprint
2927

@@ -34,15 +32,18 @@
3432
for ext, loader_name in DOCUMENT_LOADERS.items():
3533
DOCUMENT_LOADERS[ext] = globals()[loader_name]
3634

37-
def process_images_wrapper(config):
35+
def choose_image_loader(config):
3836
chosen_model = config["vision"]["chosen_model"]
3937

4038
if chosen_model == 'llava' or chosen_model == 'bakllava':
41-
return llava_process_images()
39+
image_loader = loader_llava()
40+
return image_loader.llava_process_images()
4241
elif chosen_model == 'cogvlm':
43-
return cogvlm_process_images()
42+
image_loader = loader_cogvlm()
43+
return image_loader.cogvlm_process_images()
4444
elif chosen_model == 'salesforce':
45-
return salesforce_process_images()
45+
image_loader = loader_salesforce()
46+
return image_loader.salesforce_process_images()
4647
else:
4748
return []
4849

@@ -76,19 +77,16 @@ def load_single_document(file_path: Path) -> Document:
7677

7778
document = loader.load()[0]
7879

79-
metadata = extract_document_metadata(file_path) # get metadata
80+
metadata = extract_document_metadata(file_path)
8081
document.metadata.update(metadata)
81-
82-
# with open("output_load_single_document.txt", "w", encoding="utf-8") as output_file:
83-
# output_file.write(document.page_content)
8482

8583
return document
8684

8785
def load_document_batch(filepaths):
8886
with ThreadPoolExecutor(len(filepaths)) as exe:
8987
futures = [exe.submit(load_single_document, name) for name in filepaths]
9088
data_list = [future.result() for future in futures]
91-
return (data_list, filepaths) # "data_list" = list of all document objects created by load single document
89+
return (data_list, filepaths)
9290

9391
def load_documents(source_dir: Path) -> list[Document]:
9492
all_files = list(source_dir.iterdir())
@@ -118,9 +116,9 @@ def load_documents(source_dir: Path) -> list[Document]:
118116
with open("config.yaml", "r") as config_file:
119117
config = yaml.safe_load(config_file)
120118

121-
# Use ProcessPoolExecutor to process images
119+
# ProcessPoolExecutor to process images
122120
with ProcessPoolExecutor(1) as executor:
123-
future = executor.submit(process_images_wrapper, config)
121+
future = executor.submit(choose_image_loader, config)
124122
processed_docs = future.result()
125123
additional_docs = processed_docs if processed_docs is not None else []
126124

@@ -137,10 +135,6 @@ def split_documents(documents):
137135

138136
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
139137
texts = text_splitter.split_documents(documents)
140-
141-
# Add 'text' attribute to metadata of each split document
142-
#for document in texts:
143-
#document.metadata["text"] = document.page_content
144138

145139
my_cprint(f"Number of Chunks: {len(texts)}", "white")
146140

@@ -156,9 +150,4 @@ def split_documents(documents):
156150
count = sum(lower_bound <= size <= upper_bound for size in chunk_sizes)
157151
my_cprint(f"Chunks between {lower_bound} and {upper_bound} characters: {count}", "white")
158152

159-
return texts
160-
161-
'''
162-
# document object structure: Document(page_content="[ALL TEXT EXTRACTED]", metadata={'source': '[FULL FILE PATH WITH DOUBLE BACKSLASHES'})
163-
# list structure: [Document(page_content="...", metadata={'source': '...'}), Document(page_content="...", metadata={'source': '...'})]
164-
'''
153+
return texts

src/extract_metadata.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
import os
22
import datetime
3+
import hashlib
34

4-
def extract_image_metadata(file_path, file_name):
5+
def compute_file_hash(file_path):
6+
hash_sha256 = hashlib.sha256()
7+
with open(file_path, "rb") as f:
8+
for chunk in iter(lambda: f.read(4096), b""):
9+
hash_sha256.update(chunk)
10+
return hash_sha256.hexdigest()
511

12+
def extract_image_metadata(file_path, file_name):
613
file_type = os.path.splitext(file_name)[1]
714
file_size = os.path.getsize(file_path)
815
creation_date = datetime.datetime.fromtimestamp(os.path.getctime(file_path)).isoformat()
916
modification_date = datetime.datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat()
17+
file_hash = compute_file_hash(file_path)
1018

1119
return {
1220
"file_path": file_path,
@@ -15,14 +23,16 @@ def extract_image_metadata(file_path, file_name):
1523
"file_size": file_size,
1624
"creation_date": creation_date,
1725
"modification_date": modification_date,
18-
"image": "True"
26+
"document_type": "image",
27+
"hash": file_hash
1928
}
2029

2130
def extract_document_metadata(file_path):
2231
file_type = os.path.splitext(file_path)[1]
2332
file_size = os.path.getsize(file_path)
2433
creation_date = datetime.datetime.fromtimestamp(os.path.getctime(file_path)).isoformat()
2534
modification_date = datetime.datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat()
35+
file_hash = compute_file_hash(file_path)
2636

2737
return {
2838
"file_path": str(file_path),
@@ -31,5 +41,6 @@ def extract_document_metadata(file_path):
3141
"file_size": file_size,
3242
"creation_date": creation_date,
3343
"modification_date": modification_date,
34-
"image": "False"
35-
}
44+
"document_type": "document",
45+
"hash": file_hash
46+
}

src/gui_tabs_settings_database_query.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def __init__(self):
1313
self.database_creation_device = config_data['Compute_Device']['database_creation']
1414
self.database_query_device = config_data['Compute_Device']['database_query']
1515
self.search_term = config_data['database'].get('search_term', '')
16+
self.document_type = config_data['database'].get('document_types', '')
1617

1718
v_layout = QVBoxLayout()
1819
h_layout_device = QHBoxLayout()
@@ -64,7 +65,17 @@ def __init__(self):
6465
h_layout_search_term.addWidget(self.filter_button)
6566

6667
self.file_type_combo = QComboBox()
67-
self.file_type_combo.addItems(["All Files", "Images Only", "Non-Images Only"])
68+
file_type_items = ["All Files", "Images Only", "Documents Only"]
69+
self.file_type_combo.addItems(file_type_items)
70+
71+
if self.document_type == 'image':
72+
default_index = file_type_items.index("Images Only")
73+
elif self.document_type == 'document':
74+
default_index = file_type_items.index("Documents Only")
75+
else:
76+
default_index = file_type_items.index("All Files")
77+
self.file_type_combo.setCurrentIndex(default_index)
78+
6879
h_layout_search_term.addWidget(self.file_type_combo)
6980

7081
v_layout.addLayout(h_layout_search_term)
@@ -106,16 +117,16 @@ def update_config(self):
106117

107118
file_type_map = {
108119
"All Files": '',
109-
"Images Only": True,
110-
"Non-Images Only": False
120+
"Images Only": 'image',
121+
"Documents Only": 'document'
111122
}
112123

113124
file_type_selection = self.file_type_combo.currentText()
114-
images_only_value = file_type_map[file_type_selection]
125+
document_type_value = file_type_map[file_type_selection]
115126

116-
if images_only_value != config_data['database'].get('images_only', ''):
127+
if document_type_value != config_data['database'].get('document_types', ''):
117128
settings_changed = True
118-
config_data['database']['images_only'] = images_only_value
129+
config_data['database']['document_types'] = document_type_value
119130

120131
if settings_changed:
121132
with open('config.yaml', 'w') as f:

src/gui_tabs_settings_whisper.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def create_layout(self):
5757
model_label = QLabel("Model")
5858
layout.addWidget(model_label, 0, 0)
5959
self.model_combo = QComboBox()
60-
self.model_combo.addItems(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2"])
60+
self.model_combo.addItems(["whisper-tiny.en", "whisper-base.en", "whisper-small.en", "whisper-medium.en", "whisper-large-v2"])
6161
layout.addWidget(self.model_combo, 0, 1)
6262

6363
# Quantization
@@ -116,12 +116,3 @@ def update_config(self):
116116
yaml.dump(config_data, f)
117117

118118
return settings_changed
119-
120-
if __name__ == "__main__":
121-
from PySide6.QtWidgets import QApplication
122-
import sys
123-
124-
app = QApplication(sys.argv)
125-
transcriber_settings_tab = TranscriberSettingsTab()
126-
transcriber_settings_tab.show()
127-
sys.exit(app.exec())

src/gui_tabs_tools_transcribe.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from PySide6.QtCore import Qt
66
import yaml
77
from pathlib import Path
8-
from constants import WHISPER_MODEL_NAMES
98
from transcribe_module import TranscribeFile
109
import threading
1110

@@ -38,7 +37,7 @@ def create_layout(self):
3837
hbox1 = QHBoxLayout()
3938
hbox1.addWidget(QLabel("Model"))
4039
self.model_combo = QComboBox()
41-
self.model_combo.addItems([model for model in WHISPER_MODEL_NAMES if model not in ["tiny", "tiny.en", "base", "base.en"]])
40+
self.model_combo.addItems(["whisper-small.en", "whisper-medium.en", "whisper-large-v2"])
4241
self.model_combo.setCurrentText(self.default_model)
4342
self.model_combo.currentTextChanged.connect(self.update_model_in_config)
4443
hbox1.addWidget(self.model_combo)
@@ -73,7 +72,7 @@ def create_layout(self):
7372

7473
main_layout.addLayout(hbox2)
7574

76-
# Third row of widgets (Select Audio File and Transcribe buttons)
75+
# Third row of widgets
7776
hbox3 = QHBoxLayout()
7877
self.select_file_button = QPushButton("Select Audio File")
7978
self.select_file_button.clicked.connect(self.select_audio_file)
@@ -85,7 +84,6 @@ def create_layout(self):
8584

8685
main_layout.addLayout(hbox3)
8786

88-
# Label for displaying the selected file path
8987
self.file_path_label = QLabel("No file currently selected")
9088
main_layout.addWidget(self.file_path_label)
9189

0 commit comments

Comments
 (0)