Skip to content

Commit 6c670c8

Browse files
authored
v3.0.3
1 parent e2d4f64 commit 6c670c8

11 files changed

+252
-281
lines changed

src/bark_module.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
# Import necessary libraries including tqdm
12
import warnings
23
import threading
34
import queue
@@ -10,6 +11,7 @@
1011
import yaml
1112
from termcolor import cprint
1213
import platform
14+
from tqdm import tqdm # Importing tqdm for the progress bar
1315

1416
warnings.filterwarnings("ignore", message="torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
1517

@@ -37,7 +39,6 @@ def load_config(self):
3739

3840
def initialize_model_and_processor(self):
3941
os_name = platform.system().lower()
40-
4142
# set compute device
4243
if torch.cuda.is_available():
4344
if torch.version.hip and os_name == 'linux':
@@ -125,8 +126,8 @@ def process_text_thread(self):
125126
break
126127

127128
sentences = re.split(r'[.!?;]+', text_prompt)
128-
129-
for sentence in sentences:
129+
# Adding tqdm progress bar
130+
for sentence in tqdm(sentences, desc="Processing Sentences"):
130131
if sentence.strip():
131132
voice_preset = self.config['speaker']
132133
inputs = self.processor(text=sentence, voice_preset=voice_preset, return_tensors="pt")
@@ -179,4 +180,4 @@ def release_resources(self):
179180

180181
if __name__ == "__main__":
181182
bark_audio = BarkAudio()
182-
bark_audio.run()
183+
bark_audio.run()

src/check_gpu.py

Lines changed: 63 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,32 +1,74 @@
11
import sys
22
from PySide6.QtWidgets import QApplication, QMessageBox
3-
import torch
43

5-
def display_info():
6-
app = QApplication(sys.argv)
7-
info_message = ""
4+
try:
5+
import torch
6+
except ImportError:
7+
def display_info():
8+
app = QApplication(sys.argv)
9+
msg_box = QMessageBox(QMessageBox.Information, "PyTorch Not Installed", "PyTorch is not installed on this system.")
10+
msg_box.exec()
811

9-
if torch.cuda.is_available():
10-
info_message += "CUDA is available!\n"
11-
info_message += "CUDA version: {}\n\n".format(torch.version.cuda)
12-
else:
13-
info_message += "CUDA is not available.\n\n"
12+
else:
13+
def check_bitsandbytes():
14+
try:
15+
import bitsandbytes as bnb
16+
p = torch.nn.Parameter(torch.rand(10, 10).cuda())
17+
a = torch.rand(10, 10).cuda()
1418

15-
if torch.backends.mps.is_available():
16-
info_message += "Metal/MPS is available!\n\n"
17-
else:
18-
info_message += "Metal/MPS is not available.\n\n"
19+
p1 = p.data.sum().item()
1920

20-
info_message += "If you want to check the version of Metal and MPS on your macOS device, you can go to \"About This Mac\" -> \"System Report\" -> \"Graphics/Displays\" and look for information related to Metal and MPS.\n\n"
21+
adam = bnb.optim.Adam([p])
2122

22-
if torch.version.hip is not None:
23-
info_message += "ROCm is available!\n"
24-
info_message += "ROCm version: {}\n".format(torch.version.hip)
25-
else:
26-
info_message += "ROCm is not available.\n"
23+
out = a * p
24+
loss = out.sum()
25+
loss.backward()
26+
adam.step()
2727

28-
msg_box = QMessageBox(QMessageBox.Information, "GPU Acceleration Available?", info_message)
29-
msg_box.exec()
28+
p2 = p.data.sum().item()
29+
30+
assert p1 != p2
31+
return "SUCCESS!\nInstallation of bitsandbytes was successful!"
32+
except ImportError:
33+
return "bitsandbytes is not installed."
34+
except AssertionError:
35+
return "bitsandbytes is installed, but the installation seems incorrect."
36+
except Exception as e:
37+
return f"An error occurred while checking bitsandbytes: {e}"
38+
39+
def display_info():
40+
app = QApplication(sys.argv)
41+
info_message = ""
42+
43+
if torch.cuda.is_available():
44+
info_message += "CUDA is available!\n"
45+
info_message += "CUDA version: {}\n\n".format(torch.version.cuda)
46+
else:
47+
info_message += "CUDA is not available.\n\n"
48+
49+
if torch.backends.mps.is_available():
50+
info_message += "Metal/MPS is available!\n\n"
51+
else:
52+
info_message += "Metal/MPS is not available.\n\n"
53+
if not torch.backends.mps.is_built():
54+
info_message += "MPS not available because the current PyTorch install was not built with MPS enabled.\n\n"
55+
else:
56+
info_message += "MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.\n\n"
57+
58+
info_message += "If you want to check the version of Metal and MPS on your macOS device, you can go to \"About This Mac\" -> \"System Report\" -> \"Graphics/Displays\" and look for information related to Metal and MPS.\n\n"
59+
60+
if torch.version.hip is not None:
61+
info_message += "ROCm is available!\n"
62+
info_message += "ROCm version: {}\n".format(torch.version.hip)
63+
else:
64+
info_message += "ROCm is not available.\n"
65+
66+
# Check for bitsandbytes
67+
bitsandbytes_message = check_bitsandbytes()
68+
info_message += "\n" + bitsandbytes_message
69+
70+
msg_box = QMessageBox(QMessageBox.Information, "GPU Acceleration and Library Check", info_message)
71+
msg_box.exec()
3072

3173
if __name__ == "__main__":
3274
display_info()

src/choose_documents.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,14 +2,12 @@
22
import os
33
from pathlib import Path
44
from PySide6.QtWidgets import QApplication, QFileDialog, QDialog, QVBoxLayout, QTextEdit, QPushButton, QHBoxLayout
5-
import sys
5+
import platform
66

77
def choose_documents_directory():
88
allowed_extensions = ['.pdf', '.docx', '.epub', '.txt', '.enex', '.eml', '.msg', '.csv', '.xls', '.xlsx', '.rtf', '.odt',
99
'.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.html', '.htm', '.md']
1010
current_dir = Path(__file__).parent.resolve()
11-
docs_folder = current_dir / "Docs_for_DB"
12-
images_folder = current_dir / "Images_for_DB"
1311
file_dialog = QFileDialog()
1412
file_dialog.setFileMode(QFileDialog.ExistingFiles)
1513
file_paths, _ = file_dialog.getOpenFileNames(None, "Choose Documents and Images for Database", str(current_dir))
@@ -21,12 +19,18 @@ def choose_documents_directory():
2119
for file_path in file_paths:
2220
extension = Path(file_path).suffix.lower()
2321
if extension in allowed_extensions:
22+
# Determine target folder without creating it
2423
if extension in ['.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff']:
25-
target_folder = images_folder
24+
target_folder = current_dir / "Images_for_DB"
2625
else:
27-
target_folder = docs_folder
28-
target_folder.mkdir(parents=True, exist_ok=True)
26+
target_folder = current_dir / "Docs_for_DB"
27+
28+
# Check and unlink existing symlink if necessary
2929
symlink_target = target_folder / Path(file_path).name
30+
if symlink_target.exists():
31+
symlink_target.unlink()
32+
33+
# Create new symlink
3034
symlink_target.symlink_to(file_path)
3135
else:
3236
incompatible_files.append(Path(file_path).name)
@@ -62,14 +66,12 @@ def see_documents_directory():
6266
current_dir = Path(__file__).parent.resolve()
6367
docs_folder = current_dir / "Docs_for_DB"
6468

65-
docs_folder.mkdir(parents=True, exist_ok=True)
66-
67-
# Cross-platform directory opening
68-
if os.name == 'nt': # Windows
69+
os_name = platform.system()
70+
if os_name == 'Windows':
6971
subprocess.Popen(['explorer', str(docs_folder)])
70-
elif sys.platform == 'darwin': # macOS
72+
elif os_name == 'Darwin':
7173
subprocess.Popen(['open', str(docs_folder)])
72-
elif sys.platform.startswith('linux'): # Linux
74+
elif os_name == 'Linux':
7375
subprocess.Popen(['xdg-open', str(docs_folder)])
7476

7577
if __name__ == '__main__':

src/create_database.py

Lines changed: 35 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -1,34 +1,24 @@
1-
import gc
2-
from mailbox import Message
3-
import os
41
import shutil
5-
from pathlib import Path
6-
from typing import Self
7-
8-
import torch
92
import yaml
10-
from chromadb.config import Settings
3+
import gc
114
from langchain.docstore.document import Document
12-
from langchain.embeddings import (
13-
HuggingFaceBgeEmbeddings,
14-
HuggingFaceEmbeddings,
15-
HuggingFaceInstructEmbeddings,
16-
)
5+
from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
176
from langchain.vectorstores import Chroma
18-
from termcolor import cprint
19-
7+
from chromadb.config import Settings
208
from document_processor import load_documents, split_documents
9+
import torch
2110
from utilities import validate_symbolic_links
11+
from termcolor import cprint
12+
from pathlib import Path
13+
import os
2214

2315
ENABLE_PRINT = True
2416

25-
2617
def my_cprint(*args, **kwargs):
2718
if ENABLE_PRINT:
2819
modified_message = f"create_database.py: {args[0]}"
2920
cprint(modified_message, *args[1:], **kwargs)
3021

31-
3222
ROOT_DIRECTORY = Path(__file__).resolve().parent
3323
SOURCE_DIRECTORY = ROOT_DIRECTORY / "Docs_for_DB"
3424
PERSIST_DIRECTORY = ROOT_DIRECTORY / "Vector_DB"
@@ -37,93 +27,84 @@ def my_cprint(*args, **kwargs):
3727
CHROMA_SETTINGS = Settings(
3828
chroma_db_impl="duckdb+parquet",
3929
persist_directory=str(PERSIST_DIRECTORY),
40-
anonymized_telemetry=False,
30+
anonymized_telemetry=False
4131
)
4232

43-
4433
def main():
45-
with open(ROOT_DIRECTORY / "config.yaml", "r") as stream:
34+
35+
with open(ROOT_DIRECTORY / "config.yaml", 'r') as stream:
4636
config_data = yaml.safe_load(stream)
4737

4838
EMBEDDING_MODEL_NAME = config_data.get("EMBEDDING_MODEL_NAME")
4939

5040
my_cprint(f"Loading documents.", "white")
51-
documents = load_documents(
52-
SOURCE_DIRECTORY
53-
) # invoke document_processor.py; returns a list of document objects
54-
if documents == None or len(documents) == 0:
55-
cprint(f"No documents to load.")
41+
documents = load_documents(SOURCE_DIRECTORY) # invoke document_processor.py; returns a list of document objects
42+
if documents is None or len(documents) == 0:
43+
cprint("No documents to load.", "red")
5644
return
57-
5845
my_cprint(f"Successfully loaded documents.", "white")
59-
texts = split_documents(
60-
documents
61-
) # invoke document_processor.py again; returns a list of split document objects
62-
46+
47+
texts = split_documents(documents) # invoke document_processor.py again; returns a list of split document objects
48+
6349
embeddings = get_embeddings(EMBEDDING_MODEL_NAME, config_data)
6450
my_cprint("Embedding model loaded.", "green")
51+
6552
if PERSIST_DIRECTORY.exists():
6653
shutil.rmtree(PERSIST_DIRECTORY)
6754
PERSIST_DIRECTORY.mkdir(parents=True, exist_ok=True)
6855

6956
my_cprint("Creating database.", "white")
70-
57+
7158
db = Chroma.from_documents(
72-
texts,
73-
embeddings,
74-
persist_directory=str(PERSIST_DIRECTORY),
59+
texts, embeddings,
60+
persist_directory=str(PERSIST_DIRECTORY),
7561
client_settings=CHROMA_SETTINGS,
7662
)
77-
63+
7864
my_cprint("Persisting database.", "white")
7965
db.persist()
8066
my_cprint("Database persisted.", "white")
81-
67+
68+
del embeddings.client
69+
del embeddings
8270
torch.cuda.empty_cache()
8371
gc.collect()
8472
my_cprint("Embedding model removed from memory.", "red")
8573

86-
8774
def get_embeddings(EMBEDDING_MODEL_NAME, config_data, normalize_embeddings=False):
8875
my_cprint("Creating embeddings.", "white")
89-
90-
compute_device = config_data["Compute_Device"]["database_creation"]
91-
76+
77+
compute_device = config_data['Compute_Device']['database_creation']
78+
9279
if "instructor" in EMBEDDING_MODEL_NAME:
93-
embed_instruction = config_data["embedding-models"]["instructor"].get(
94-
"embed_instruction"
95-
)
96-
query_instruction = config_data["embedding-models"]["instructor"].get(
97-
"query_instruction"
98-
)
80+
embed_instruction = config_data['embedding-models']['instructor'].get('embed_instruction')
81+
query_instruction = config_data['embedding-models']['instructor'].get('query_instruction')
9982

10083
return HuggingFaceInstructEmbeddings(
10184
model_name=EMBEDDING_MODEL_NAME,
10285
model_kwargs={"device": compute_device},
10386
encode_kwargs={"normalize_embeddings": normalize_embeddings},
10487
embed_instruction=embed_instruction,
105-
query_instruction=query_instruction,
88+
query_instruction=query_instruction
10689
)
10790

10891
elif "bge" in EMBEDDING_MODEL_NAME:
109-
query_instruction = config_data["embedding-models"]["bge"].get(
110-
"query_instruction"
111-
)
92+
query_instruction = config_data['embedding-models']['bge'].get('query_instruction')
11293

11394
return HuggingFaceBgeEmbeddings(
11495
model_name=EMBEDDING_MODEL_NAME,
11596
model_kwargs={"device": compute_device},
11697
query_instruction=query_instruction,
117-
encode_kwargs={"normalize_embeddings": normalize_embeddings},
98+
encode_kwargs={"normalize_embeddings": normalize_embeddings}
11899
)
119-
100+
120101
else:
102+
121103
return HuggingFaceEmbeddings(
122104
model_name=EMBEDDING_MODEL_NAME,
123105
model_kwargs={"device": compute_device},
124-
encode_kwargs={"normalize_embeddings": normalize_embeddings},
106+
encode_kwargs={"normalize_embeddings": normalize_embeddings}
125107
)
126108

127-
128109
if __name__ == "__main__":
129110
main()

0 commit comments

Comments
 (0)