
Commit f6a0e46

Added s2s files
1 parent 90c38d4 commit f6a0e46

12 files changed: +488 −10 lines changed

TTS/parler_handler.py

Lines changed: 10 additions & 4 deletions
@@ -36,7 +36,9 @@ class ParlerTTSHandler(BaseHandler):
     def setup(
         self,
         should_listen,
-        model_name="ylacombe/parler-tts-mini-jenny-30H",
+        # model_name="ylacombe/parler-tts-mini-jenny-30H",
+        model_name="ylacombe/parler_tts_mini_v0.1",
+        # model_name="parler-tts/parler_tts_mini_v0.1",
         device="cuda",
         torch_dtype="float16",
         compile_mode=None,
@@ -82,7 +84,9 @@ def setup(

         self.viseme_flag = viseme_flag
         if self.viseme_flag:
-            self.speech_to_visemes = SpeechToVisemes()
+            self.speech_to_visemes = SpeechToVisemes(
+                device=self.device
+            )

         self.warmup()

@@ -100,13 +104,15 @@ def prepare_model_inputs(
             self.description, return_tensors="pt"
         )
         input_ids = tokenized_description.input_ids.to(self.device)
-        attention_mask = tokenized_description.attention_mask.to(self.device)
+        # attention_mask = tokenized_description.attention_mask.to(self.device)
+        attention_mask = None

         tokenized_prompt = self.prompt_tokenizer(
             prompt, return_tensors="pt", **pad_args_prompt
         )
         prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
-        prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
+        # prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
+        prompt_attention_mask = None

         gen_kwargs = {
             "input_ids": input_ids,

arguments_classes/parler_tts_arguments.py

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,8 @@
 @dataclass
 class ParlerTTSHandlerArguments:
     tts_model_name: str = field(
-        default="ylacombe/parler-tts-mini-jenny-30H",
+        # default="ylacombe/parler-tts-mini-jenny-30H",
+        default="ylacombe/parler_tts_mini_v0.1",
         metadata={
             "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
         },

config.json

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+{
+    "device": "cpu",
+    "stt": "whisper",
+    "stt_model_name": "openai/whisper-tiny"
+}
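run.py (added further down in this commit) launches the pipeline as `python s2s_pipeline.py config.json`, so these JSON keys are presumably mapped onto the pipeline's argument dataclasses. Below is a minimal sketch of that mapping, assuming a hypothetical PipelineArguments dataclass standing in for the repo's real argument classes; it is illustrative only and not part of the commit.

# Minimal sketch (not part of the commit): mapping a JSON config like
# config.json onto dataclass fields with transformers.HfArgumentParser.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class PipelineArguments:
    # Hypothetical subset of the pipeline's arguments, for illustration only.
    device: str = field(default="cpu")
    stt: str = field(default="whisper")
    stt_model_name: str = field(default="openai/whisper-tiny")


if __name__ == "__main__":
    parser = HfArgumentParser((PipelineArguments,))
    # parse_json_file reads the JSON keys into the dataclass fields.
    (pipeline_kwargs,) = parser.parse_json_file("config.json")
    print(pipeline_kwargs)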

listen_and_dont_play.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+import socket
+import threading
+from queue import Queue
+from dataclasses import dataclass, field
+import soundfile as sf
+import numpy as np
+import struct
+import pickle
+
+@dataclass
+class ListenAndPlayArguments:
+    send_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."})
+    recv_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."})
+    list_play_chunk_size: int = field(
+        default=512,
+        metadata={"help": "The size of data chunks (in bytes). Default is 512."},
+    )
+    host: str = field(
+        default="localhost",
+        metadata={
+            "help": "The hostname or IP address for listening and playing. Default is 'localhost'."
+        },
+    )
+    send_port: int = field(
+        default=12345,
+        metadata={"help": "The network port for sending data. Default is 12345."},
+    )
+    recv_port: int = field(
+        default=12346,
+        metadata={"help": "The network port for receiving data. Default is 12346."},
+    )
+    input_audio_file: str = field(
+        default="sample_audio.wav",
+        metadata={"help": "Path to the audio file to use as input."},
+    )
+
+
+def listen_and_dont_play(
+    send_rate=16000,
+    recv_rate=16000,
+    list_play_chunk_size=512,
+    host="localhost",
+    send_port=12345,
+    recv_port=12346,
+    input_audio_file="sample_audio.wav",
+):
+    send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    send_socket.connect((host, send_port))
+
+    recv_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    recv_socket.connect((host, recv_port))
+
+    print(f"Simulating recording and streaming using {input_audio_file}...")
+
+    stop_event = threading.Event()
+    recv_queue = Queue()
+    send_queue = Queue()
+
+    def load_audio_chunks(file_path, chunk_size, sample_rate, append_silence_secs=5):
+        """Load audio file, append silence, and yield chunks of audio data."""
+        # Read audio file
+        audio_data, audio_sample_rate = sf.read(file_path, dtype='int16')
+        if audio_sample_rate != sample_rate:
+            raise ValueError(f"Expected sample rate of {sample_rate}, but got {audio_sample_rate}")
+
+        # Calculate and append 5 seconds of silence
+        silence = np.zeros(int(sample_rate * append_silence_secs), dtype='int16')
+        combined_audio = np.concatenate([audio_data, silence])
+
+        # Break audio into chunks
+        for i in range(0, len(combined_audio), chunk_size):
+            yield combined_audio[i:i + chunk_size].tobytes()
+
+    def send(stop_event, send_queue):
+        for chunk in load_audio_chunks(input_audio_file, list_play_chunk_size, send_rate):
+            if stop_event.is_set():
+                break
+            send_queue.put(chunk)
+
+        send_queue.put(b"END")
+
+    def recv(stop_event, recv_queue):
+        def receive_full_chunk(conn, chunk_size):
+            data = b""
+            while len(data) < chunk_size:
+                packet = conn.recv(chunk_size - len(data))
+                if not packet:
+                    return None  # Connection has been closed
+                data += packet
+            return data
+
+        while not stop_event.is_set():
+            # Step 1: Receive the first 4 bytes to get the packet length
+            length_data = receive_full_chunk(recv_socket, 4)
+            if not length_data:
+                continue  # Handle disconnection or data not available
+
+            # Step 2: Unpack the length (4 bytes)
+            packet_length = struct.unpack('!I', length_data)[0]
+
+            # Step 3: Receive the full packet based on the length
+            serialized_packet = receive_full_chunk(recv_socket, packet_length)
+            if serialized_packet:
+                # Step 4: Deserialize the packet using pickle
+                packet = pickle.loads(serialized_packet)
+                # Step 5: Extract the packet contents (text, visemes, audio)
+                if 'text' in packet:
+                    print(f"Transcribed Text: {packet['text']}")
+                if 'visemes' in packet:
+                    print(f"Visemes: {packet['visemes']}")
+                # We're no longer playing audio, but you could process it if needed
+                if 'audio' in packet:
+                    recv_queue.put(packet['audio'])
+
+    try:
+        send_thread = threading.Thread(target=send, args=(stop_event, send_queue))
+        send_thread.start()
+        recv_thread = threading.Thread(target=recv, args=(stop_event, recv_queue))
+        recv_thread.start()
+
+        input("Press Enter to stop...")
+
+    except KeyboardInterrupt:
+        print("Finished streaming.")
+
+    finally:
+        stop_event.set()
+        recv_thread.join()
+        send_thread.join()
+        send_socket.close()
+        recv_socket.close()
+        print("Connection closed.")
+
+
+if __name__ == "__main__":
+    parser = HfArgumentParser((ListenAndPlayArguments,))
+    (listen_and_play_kwargs,) = parser.parse_args_into_dataclasses()
+    listen_and_dont_play(**vars(listen_and_play_kwargs))
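The recv loop above expects each packet from the pipeline as a 4-byte big-endian length prefix followed by a pickled dict that may carry 'text', 'visemes', and 'audio' keys (note that the `__main__` block uses HfArgumentParser, which this file does not import; `from transformers import HfArgumentParser` would be needed to run it as a script). Below is a minimal sketch of the matching sender-side framing, assuming an already-connected socket; it is illustrative only and not part of the commit.

# Sketch of the framing the recv() handler above expects: a 4-byte
# big-endian ('!I') length prefix, then a pickled dict with optional
# 'text', 'visemes', and 'audio' fields.
import pickle
import struct


def send_packet(sock, text=None, visemes=None, audio=None):
    fields = {"text": text, "visemes": visemes, "audio": audio}
    # Keep only the fields that are actually present in this packet.
    packet = {key: value for key, value in fields.items() if value is not None}
    payload = pickle.dumps(packet)
    sock.sendall(struct.pack("!I", len(payload)) + payload)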

listen_and_play.py

Lines changed: 4 additions & 4 deletions
@@ -95,11 +95,11 @@ def receive_full_chunk(conn, chunk_size):
                 packet = pickle.loads(serialized_packet)
                 # Step 5: Extract the packet contents
                 if 'text' in packet:
-                    pass
-                    # print(packet['text'])
+                    # pass
+                    print(packet['text'])
                 if 'visemes' in packet:
-                    pass
-                    # print(packet['visemes'])
+                    # pass
+                    print(packet['visemes'])

                 # Step 6: Put the packet audio data into the queue for sending
                 recv_queue.put(packet['audio'].tobytes())

resampler.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import librosa
+import soundfile as sf
+
+def resample_audio(input_audio_file, target_sample_rate=16000):
+    # Load the audio file with librosa
+    audio_data, original_sample_rate = librosa.load(input_audio_file, sr=None)
+
+    # Resample the audio to the target sample rate
+    if original_sample_rate != target_sample_rate:
+        audio_data = librosa.resample(audio_data, orig_sr=original_sample_rate, target_sr=target_sample_rate)
+
+    # Save the resampled audio
+    resampled_audio_file = "sample_audio.wav"
+    sf.write(resampled_audio_file, audio_data, target_sample_rate)
+
+    return resampled_audio_file
+
+# Usage
+input_audio_file = "unsampled_audio.wav"
+resampled_audio_file = resample_audio(input_audio_file)
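The resampled output is written to sample_audio.wav at 16 kHz, which matches the defaults of listen_and_dont_play.py above (input_audio_file="sample_audio.wav", send_rate=16000), so the two scripts can be chained directly. Note that the usage at module level runs on import and expects unsampled_audio.wav to exist in the working directory.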

run.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+import subprocess
+import sys
+
+def run_pipeline():
+    try:
+        # Run the command and stream the output in real-time
+        process = subprocess.Popen(['python', 's2s_pipeline.py', 'config.json'],
+                                   stdout=subprocess.PIPE,
+                                   stderr=subprocess.STDOUT,  # Combine stdout and stderr
+                                   text=True,
+                                   bufsize=1)  # Line-buffered output
+
+        # Stream the stdout as it comes
+        for line in process.stdout:
+            sys.stdout.write(line)  # Write directly to sys.stdout for real-time output
+            sys.stdout.flush()  # Ensure each line is printed immediately
+
+        process.wait()  # Wait for the process to complete
+
+    except subprocess.CalledProcessError as e:
+        print(f"Error occurred: {e.output}")
+
+if __name__ == "__main__":
+    run_pipeline()

s2s_pipeline.py

Lines changed: 2 additions & 1 deletion
@@ -119,7 +119,8 @@ def main():
     # 1. Handle logger
     global logger
     logging.basicConfig(
-        level=module_kwargs.log_level.upper(),
+        # level=module_kwargs.log_level.upper(),
+        level=logging.DEBUG,
         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
     )
     logger = logging.getLogger(__name__)

simplified_pipeline.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+import queue
+import threading
+import logging
+import numpy as np
+import torch
+import torchaudio
+from VAD.vad_handler import VADHandler
+from STT.whisper_stt_handler import WhisperSTTHandler  # Import your Whisper handler
+
+# Configure logger
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+# Function to run the local pipeline with VAD and STT
+def run_local_pipeline():
+    # Initialize events and queues
+    stop_event = threading.Event()
+    should_listen = threading.Event()
+    recv_audio_chunks_queue = queue.Queue()
+    spoken_prompt_queue = queue.Queue()  # Queue for audio with detected speech
+    text_prompt_queue = queue.Queue()  # Queue for converted text (output of STT)
+
+    # Initialize the VAD handler
+    vad = VADHandler(
+        stop_event,
+        queue_in=recv_audio_chunks_queue,
+        queue_out=spoken_prompt_queue,
+        setup_args=(should_listen,),
+        setup_kwargs={
+            'thresh': 0.3,
+            'sample_rate': 16000,
+            'audio_enhancement': False  # Set to True if you want enhancement
+        }
+    )
+
+    print("Setup VAD")
+
+    # Initialize the Whisper STT handler
+    stt = WhisperSTTHandler(
+        stop_event,
+        queue_in=spoken_prompt_queue,  # Speech detected audio chunks go here
+        queue_out=text_prompt_queue,  # The output text from STT goes here
+        setup_kwargs={
+            'device': 'cpu',  # Assuming you're using CPU. Change to 'cuda' for GPU.
+            'model_name': 'openai/whisper-tiny',  # Whisper model being used
+            'language': 'en',  # Set to English to avoid language detection
+            'compile_mode': None,
+        }
+    )
+
+
+    print("Setup STT")
+
+    # Simulate receiving audio chunks and processing them with VAD and STT
+    try:
+        print("Running simplified pipeline locally with VAD and STT...")
+
+        # Simulate processing 5 chunks of audio (each 1 second long, silence)
+        for _ in range(5):
+            dummy_audio_chunk = b'\x00' * 16000  # Simulate a 1-second silent audio chunk
+            vad.process(dummy_audio_chunk)
+
+            # After VAD, check if any speech was detected, if so, send to STT
+            while not spoken_prompt_queue.empty():
+                audio_chunk = spoken_prompt_queue.get()  # Get the detected speech audio
+                stt.process(audio_chunk)  # Convert the audio to text
+
+                # Retrieve the transcribed text from the STT
+                if not text_prompt_queue.empty():
+                    transcribed_text = text_prompt_queue.get()
+                    print(f"STT Result: {transcribed_text}")
+
+        print("Pipeline completed.")
+    except KeyboardInterrupt:
+        print("Pipeline stopped.")
+
+# Run the local pipeline
+if __name__ == "__main__":
+    run_local_pipeline()

0 commit comments
