Skip to content

Commit 6c6bced

Browse files
authored
v2.7.3
1 parent feefd48 commit 6c6bced

File tree

4 files changed

+163
-5
lines changed

4 files changed

+163
-5
lines changed

src/bark_module.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import threading
2+
import queue
3+
from transformers import AutoProcessor, BarkModel
4+
import torch
5+
import numpy as np
6+
import re
7+
import time
8+
import pyaudio
9+
import gc
10+
11+
class BarkAudio:
    """Read the saved LLM response aloud using the Bark text-to-speech model.

    Two worker threads form a pipeline: ``process_text_thread`` splits text
    into sentences and synthesizes each one, and ``play_audio_thread`` plays
    the synthesized audio in order. ``None`` is used as the end-of-work
    sentinel on both queues.
    """

    # Sentence terminators used to chunk the text before synthesis.
    _SENTENCE_BOUNDARY = re.compile(r'[.!?;]+')
    _VOICE_PRESET = "v2/en_speaker_6"

    def __init__(self):
        """Load the Bark processor and model and create the work queues."""
        self.processor = AutoProcessor.from_pretrained("suno/bark-small")
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16).to(self.device)
        self.model = self.model.to_bettertransformer()
        # self.model.enable_cpu_offload()

        # (audio_array, sampling_rate, sentence_num) tuples awaiting playback.
        self.sentence_queue = queue.Queue()
        # Raw text prompts awaiting synthesis.
        self.processing_queue = queue.Queue()
        self.start_time = None

    def _elapsed(self):
        """Return seconds since ``run`` started, or 0.0 if it has not started."""
        if self.start_time is None:
            return 0.0
        return time.time() - self.start_time

    @staticmethod
    def _to_int16(audio_array):
        """Peak-normalize a float waveform and convert it to 16-bit PCM.

        The arithmetic is done in float32: 32767 is not representable in
        float16, so scaling a half-precision array can overflow int16 on
        full-scale samples. Silent, empty or non-finite input yields zeros
        instead of dividing by zero.
        """
        samples = np.asarray(audio_array, dtype=np.float32)
        peak = float(np.max(np.abs(samples))) if samples.size else 0.0
        if peak == 0.0 or not np.isfinite(peak):
            return np.zeros(samples.shape, dtype=np.int16)
        return (samples / peak * 32767).astype(np.int16)

    def play_audio_thread(self):
        """Play queued sentences until the ``None`` sentinel arrives.

        One PyAudio instance and output stream are reused for the whole run
        (the stream is reopened only if the sampling rate changes) and are
        always closed, even if playback raises. Model resources are released
        when playback ends.
        """
        p = pyaudio.PyAudio()
        stream = None
        stream_rate = None
        try:
            while True:
                queue_item = self.sentence_queue.get()
                if queue_item is None:
                    break

                audio_array, sampling_rate, sentence_num = queue_item
                print(f"({self._elapsed():.2f} seconds) Playing sentence #{sentence_num}")

                if stream is None or stream_rate != sampling_rate:
                    if stream is not None:
                        stream.stop_stream()
                        stream.close()
                    stream = p.open(format=pyaudio.paInt16, channels=1, rate=sampling_rate, output=True)
                    stream_rate = sampling_rate

                stream.write(audio_array.tobytes())
        finally:
            if stream is not None:
                stream.stop_stream()
                stream.close()
            p.terminate()
            self.release_resources()

    def process_text_thread(self):
        """Synthesize queued text sentence by sentence until ``None`` arrives.

        Each non-blank sentence is converted to int16 audio and pushed onto
        ``sentence_queue``. The playback sentinel is always forwarded on exit
        (normal or exceptional) so the playback thread cannot block forever.
        """
        sentence_count = 1
        try:
            while True:
                text_prompt = self.processing_queue.get()
                if text_prompt is None:
                    break

                for sentence in self._SENTENCE_BOUNDARY.split(text_prompt):
                    if not sentence.strip():
                        continue

                    # Take local references so a concurrent stop() cannot pull
                    # the model out from under an in-flight sentence.
                    model, processor = self.model, self.processor
                    if model is None or processor is None:
                        return

                    print(f"({self._elapsed():.2f} seconds) Processing sentence #{sentence_count}")

                    inputs = processor(text=sentence, voice_preset=self._VOICE_PRESET, return_tensors="pt")

                    with torch.no_grad():
                        speech_output = model.generate(**inputs.to(self.device), do_sample=True)

                    audio_array = self._to_int16(speech_output[0].cpu().numpy())
                    sampling_rate = model.generation_config.sample_rate

                    self.sentence_queue.put((audio_array, sampling_rate, sentence_count))
                    sentence_count += 1
        finally:
            self.sentence_queue.put(None)

    def run(self):
        """Speak the contents of ``chat_history.txt`` and block until done.

        Raises:
            OSError: if ``chat_history.txt`` cannot be read.
        """
        try:
            with open('chat_history.txt', 'r', encoding='utf-8') as file:
                llm_response = file.read()

            self.processing_queue.put(llm_response)
            # End-of-input sentinel: without it the processing thread would
            # wait on the queue forever and the joins below would never return.
            self.processing_queue.put(None)

            self.start_time = time.time()

            processing_thread = threading.Thread(target=self.process_text_thread)
            playback_thread = threading.Thread(target=self.play_audio_thread)
            processing_thread.start()
            playback_thread.start()

            processing_thread.join()
            playback_thread.join()
        finally:
            self.release_resources()

    def stop(self):
        """Ask both worker threads to finish and release the model."""
        self.processing_queue.put(None)
        self.sentence_queue.put(None)

        self.release_resources()

    def release_resources(self):
        """Drop the model and processor and free cached GPU memory.

        Safe to call more than once. Garbage is collected before the CUDA
        cache is emptied so memory held by just-dropped objects can be freed.
        """
        self.model = None
        self.processor = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
95+
96+
if __name__ == "__main__":
    # Script entry point: load Bark and speak the saved chat response.
    BarkAudio().run()
99+
100+
'''
101+
INSTRUCTIONS:
102+
103+
(1) Bark consists of 4 models but only one is used at any given moment. You can uncomment "self.model.enable_cpu_offload()"
104+
to put 3 models into RAM and only the one being used into VRAM. This saves VRAM at a significant speed cost.
105+
106+
(2) Delete ", torch_dtype=torch.float16" verbatim to run the model in float32 instead of float16. You must leave ".to(self.device)".
107+
108+
(3) You can comment out "self.model = self.model.to_bettertransformer()" to NOT use "Better Transformer," which is a library from Hugging Face.
109+
Only do this if Better Transformers isn't compatible with your system, but it should be, and it provides a 5-20% speedup.
110+
111+
(4) Finally, to use the Bark full-size model remove "-small" on the two lines above; for example, it should read "suno/bark" instead.
112+
113+
*** You can experiment with any combination of items (1)-(4) above to get the VRAM/speed/quality you want. For example, using the
114+
full-size Bark models but only at float16...or using the "-small" models but at full float32. ***
115+
'''
116+
117+
'''
118+
INSTRUCTIONS:
119+
120+
Go here for examples of different voices:
121+
122+
https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c
123+
'''

src/gui.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
import yaml
99
import sys
10+
import threading # Import threading
1011
from initialize import main as initialize_system
1112
from metrics_bar import MetricsBar
1213
from download_model import download_embedding_model
@@ -19,6 +20,9 @@
1920
import voice_recorder_module
2021
from utilities import list_theme_files, make_theme_changer, load_stylesheet
2122

23+
# Import BarkAudio from bark_module
24+
from bark_module import BarkAudio
25+
2226
class DocQA_GUI(QWidget):
2327
def __init__(self):
2428
super().__init__()
@@ -93,9 +97,15 @@ def init_ui(self):
9397
self.submit_button.clicked.connect(self.on_submit_button_clicked)
9498
right_vbox.addWidget(self.submit_button)
9599

100+
# Test Embeddings checkbox and Bark button
101+
checkbox_button_hbox = QHBoxLayout()
96102
self.test_embeddings_checkbox = QCheckBox("Test Embeddings")
97103
self.test_embeddings_checkbox.stateChanged.connect(self.on_test_embeddings_changed)
98-
right_vbox.addWidget(self.test_embeddings_checkbox)
104+
checkbox_button_hbox.addWidget(self.test_embeddings_checkbox)
105+
bark_button = QPushButton("Bark")
106+
bark_button.clicked.connect(self.on_bark_button_clicked) # Connect to the new handler
107+
checkbox_button_hbox.addWidget(bark_button)
108+
right_vbox.addLayout(checkbox_button_hbox)
99109

100110
# Create and add button row
101111
button_row_widget = self.create_button_row(self.on_submit_button_clicked)
@@ -196,6 +206,15 @@ def stop_recording():
196206

197207
return row_widget
198208

209+
# Handler for the Bark button click
210+
def on_bark_button_clicked(self):
    """Start Bark on a background thread so the GUI stays responsive.

    Clicks are ignored while a previous run is still active, so repeated
    clicks do not load several models or overlap audio. The worker is a
    daemon thread so it cannot keep the process alive after the window
    is closed.
    """
    worker = getattr(self, "_bark_thread", None)
    if worker is not None and worker.is_alive():
        return
    self._bark_thread = threading.Thread(target=self.run_bark_module, daemon=True)
    self._bark_thread.start()
212+
213+
# Method to instantiate and run BarkAudio
214+
def run_bark_module(self):
    """Create a BarkAudio instance and run it to completion."""
    # The instance is built on demand, when the Bark button is clicked.
    BarkAudio().run()
217+
199218
if __name__ == '__main__':
200219
app = QApplication(sys.argv)
201220
app.setStyle(QStyleFactory.create('fusion'))

src/requirements.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,8 @@ PyAudio==0.2.14
1818
faster-whisper==0.10.0
1919
termcolor==2.3.0
2020
pypandoc==1.12
21-
pydub==0.25.1
22-
PyYAML==6.0.1
21+
PyYAML==6.0.1
22+
transformers==4.36.0
23+
accelerate==0.25.0
24+
optimum==1.15.0
25+
pydub==0.25.1

src/server_connector.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,25 @@ def ask_local_chatgpt(query, persist_directory=PERSIST_DIRECTORY, client_setting
173173

174174
response_json = connect_to_local_chatgpt(augmented_query)
175175

176+
full_response = []
177+
176178
for chunk_message in response_json:
179+
if full_response and isinstance(full_response[-1], str):
180+
full_response[-1] += chunk_message
181+
else:
182+
full_response.append(chunk_message)
183+
177184
yield chunk_message
178185

186+
# Save the full response to chat_history.txt
187+
with open('chat_history.txt', 'w', encoding='utf-8') as file:
188+
for message in full_response:
189+
file.write(message)
190+
179191
yield "\n\n"
180192

193+
# LLM's response complete
194+
# format and append citations
181195
citations = format_metadata_as_citations(metadata_list)
182196

183197
unique_citations = []
@@ -195,7 +209,6 @@ def ask_local_chatgpt(query, persist_directory=PERSIST_DIRECTORY, client_setting
195209

196210
return {"answer": response_json, "sources": relevant_contexts}
197211

198-
199212
if __name__ == "__main__":
200213
user_input = "Your query here"
201-
interact_with_chat(user_input)
214+
interact_with_chat(user_input)

0 commit comments

Comments
 (0)