
Commit 49ec375

Version 1.2 files
Added support for Metal/MPS and AMD GPU acceleration (via ROCm). Provided detailed installation instructions. Added check_gpu.py to test whether GPU acceleration is installed correctly.
1 parent 5a496ec commit 49ec375

3 files changed (+46, -11 lines)

check_gpu.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+import torch
+
+if torch.cuda.is_available():
+    print("CUDA is available!")
+    print("CUDA version:", torch.version.cuda)
+else:
+    print("CUDA is not available.")
+
+print()
+
+if torch.backends.mps.is_available():
+    print("Metal/MPS is available!")
+else:
+    print("Metal/MPS is not available.")
+
+print("If you want to check the version of Metal and MPS on your macOS device, you can go to \"About This Mac\" -> \"System Report\" -> \"Graphics/Displays\" and look for information related to Metal and MPS.")
+
+print()
+
+if torch.version.hip is not None:
+    print("ROCm is available!")
+    print("ROCm version:", torch.version.hip)
+else:
+    print("ROCm is not available.")

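Side note (not part of this commit): the three checks above can be collapsed into a single device-selection helper. A minimal sketch follows, with a hypothetical pick_device name. Note that on ROCm builds of PyTorch, torch.cuda.is_available() also returns True and the device string remains "cuda".

import torch

def pick_device() -> str:
    # Hypothetical helper mirroring check_gpu.py's checks.
    if torch.cuda.is_available():
        # True on both NVIDIA (torch.version.cuda) and AMD ROCm
        # (torch.version.hip) builds; both use the "cuda" device string.
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print("Selected device:", pick_device())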
gui_logic.py

Lines changed: 8 additions & 9 deletions
@@ -32,9 +32,9 @@ def buttons(self):
 class DocQA_Logic:
     def __init__(self, gui: DocQA_GUI):
         self.gui = gui
-        self.embed_model_name = ""  # Store the selected embedding model name
+        self.embed_model_name = ""
 
-        # Connect the buttons to their respective actions
+        # Connecting the GUI buttons to their logic
         self.gui.download_embedding_model_button.config(command=self.download_embedding_model)
         self.gui.select_embedding_model_button.config(command=self.select_embedding_model_directory)
         self.gui.choose_documents_button.config(command=self.choose_documents)
@@ -48,27 +48,27 @@ def download_embedding_model(self):
 
         # Opening the dialog window
         dialog = DownloadModelDialog(self.gui.root)
-        selected_model = dialog.model_var.get()  # this gets the selected model's name
+        selected_model = dialog.model_var.get()
 
         if selected_model:
             # Construct the URL for the Hugging Face model repository
             model_url = f"https://huggingface.co/{selected_model}"
 
-            # Define the target directory for the download
+            # Define the directory to download the model to
             target_directory = os.path.join("Embedding_Models", selected_model.replace("/", "--"))
 
-            # Clone the repository using the subprocess module
+            # Clone the repository to the directory
             subprocess.run(["git", "clone", model_url, target_directory])
 
     def select_embedding_model_directory(self):
         initial_dir = 'Embedding_Models' if os.path.exists('Embedding_Models') else os.path.expanduser("~")
         chosen_directory = filedialog.askdirectory(initialdir=initial_dir, title="Select Embedding Model Directory")
 
-        # Store the chosen directory locally
+        # Choose the model directory to use
         if chosen_directory:
             self.embedding_model_directory = chosen_directory
 
-            # Also update the global variable in server_connector.py
+            # Update the global variable in server_connector.py
             server_connector.EMBEDDING_MODEL_NAME = chosen_directory
 
             # Optionally, you can print or display a confirmation to the user
@@ -85,13 +85,12 @@ def choose_documents(self):
 
         for file_path in file_paths:
             shutil.copy(file_path, docs_folder)
-            # Add any additional logic to handle the selected files
 
     def create_chromadb(self):
         current_dir = os.path.dirname(os.path.realpath(__file__))
         vector_db_folder = os.path.join(current_dir, "Vector_DB")
 
-        # Check if the "Vector_DB" folder exists, and create it if not
+        # Create the "Vector_DB" folder if it doesn't exist
        if not os.path.exists(vector_db_folder):
             os.mkdir(vector_db_folder)

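A caveat on the clone step above: subprocess.run() as written ignores a failed git clone (a mistyped model name, no network). A hedged sketch of how the call could surface errors via check=True; the URL and directory below are stand-ins for the values download_embedding_model() computes from the dialog.

import subprocess

# Stand-in values; download_embedding_model() builds these at runtime.
model_url = "https://huggingface.co/hkunlp/instructor-large"
target_directory = "Embedding_Models/hkunlp--instructor-large"

try:
    # check=True raises CalledProcessError on a non-zero git exit code.
    subprocess.run(["git", "clone", model_url, target_directory], check=True)
except subprocess.CalledProcessError as err:
    print(f"git clone failed with exit code {err.returncode}")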
ingest_improved.py

Lines changed: 14 additions & 2 deletions
@@ -77,7 +77,15 @@ def split_documents(documents: list[Document]) -> tuple[list[Document], list[Document]]:
     return documents, []  # We're only processing PDFs now, no more split based on extensions
 
 def main():
-    device_type = "cuda"  # Change to 'cpu' if needed
+    # Determine the appropriate compute device
+    if torch.cuda.is_available() and torch.version.cuda:
+        device_type = "cuda"
+    elif torch.backends.mps.is_available():
+        device_type = "mps"
+    elif torch.cuda.is_available() and torch.version.hip:
+        device_type = "cuda"
+    else:
+        device_type = "cpu"
 
     logging.info(f"Loading documents from {SOURCE_DIRECTORY}")
     documents = load_documents(SOURCE_DIRECTORY)
@@ -92,9 +100,13 @@ def main():
         embeddings = HuggingFaceInstructEmbeddings(
             model_name=EMBEDDING_MODEL_NAME,
             model_kwargs={"device": device_type},
+            query_instruction="Represent the legal treatise for retrieval."
         )
     else:
-        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
+        embeddings = HuggingFaceEmbeddings(
+            model_name=EMBEDDING_MODEL_NAME,
+            model_kwargs={"device": device_type},
+        )
 
     # Delete contents of the PERSIST_DIRECTORY before creating the vector database
     if os.path.exists(PERSIST_DIRECTORY):
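For a quick sanity check of the new embedding configuration, a minimal smoke test (not part of the commit; the model name is a stand-in, and the device string can be whatever check_gpu.py reports as available):

from langchain.embeddings import HuggingFaceInstructEmbeddings

# Stand-in model; EMBEDDING_MODEL_NAME normally points at a local download.
embeddings = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-large",
    model_kwargs={"device": "cpu"},  # or "cuda" / "mps" per check_gpu.py
    query_instruction="Represent the legal treatise for retrieval."
)

vector = embeddings.embed_query("What are the elements of adverse possession?")
print(len(vector))  # embedding dimensionality; 768 for instructor-large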
