Skip to content

Commit 6c14b9d

Browse files
authored
Merge pull request #507 from transformerlab/add/audio-stt
Add/audio stt
2 parents 93826e0 + 00b5c37 commit 6c14b9d

File tree

4 files changed

+232
-76
lines changed

4 files changed

+232
-76
lines changed

transformerlab/fastchat_openai_api.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class ChatCompletionRequest(BaseModel):
9393
tools: Optional[List[Dict[str, Any]]] = None
9494

9595

96-
class AudioRequest(BaseModel):
96+
class AudioGenerationRequest(BaseModel):
9797
experiment_id: str
9898
model: str
9999
adaptor: Optional[str] = ""
@@ -106,6 +106,14 @@ class AudioRequest(BaseModel):
106106
voice: Optional[str] = None
107107
audio_path: Optional[str] = None
108108

109+
class AudioTranscriptionsRequest(BaseModel):
110+
experiment_id: str
111+
model: str
112+
adaptor: Optional[str] = ""
113+
audio_path: str
114+
# format: str
115+
# output_path: str note: probably we set this by ourself
116+
109117

110118
class VisualizationRequest(PydanticBaseModel):
111119
model: str
@@ -506,7 +514,7 @@ async def show_available_models():
506514

507515

508516
@router.post("/v1/audio/speech", tags=["audio"])
509-
async def create_audio_tts(request: AudioRequest):
517+
async def create_audio_tts(request: AudioGenerationRequest):
510518
error_check_ret = await check_model(request)
511519
if error_check_ret is not None:
512520
if isinstance(error_check_ret, JSONResponse):
@@ -532,6 +540,8 @@ async def create_audio_tts(request: AudioRequest):
532540
"top_p": request.top_p,
533541
"audio_path": request.audio_path,
534542
}
543+
gen_params["task"] = "tts"
544+
535545

536546
# Add voice parameter if provided
537547
if request.voice:
@@ -564,6 +574,34 @@ async def upload_audio_reference(experimentId: str, audio: UploadFile = File(...
564574

565575
return JSONResponse({"audioPath": file_path})
566576

577+
@router.post("/v1/audio/transcriptions", tags=["audio"])
578+
async def create_text_stt(request: AudioTranscriptionsRequest):
579+
error_check_ret = await check_model(request)
580+
if error_check_ret is not None:
581+
if isinstance(error_check_ret, JSONResponse):
582+
return error_check_ret
583+
elif isinstance(error_check_ret, dict) and "model_name" in error_check_ret.keys():
584+
request.model = error_check_ret["model_name"]
585+
586+
exp_obj = Experiment.get(request.experiment_id)
587+
experiment_dir = exp_obj.get_dir()
588+
transcription_dir = os.path.join(experiment_dir, "transcriptions")
589+
os.makedirs(transcription_dir, exist_ok=True)
590+
591+
gen_params = {
592+
"model": request.model,
593+
"audio_path": request.audio_path,
594+
"output_path": transcription_dir,
595+
#"format": request.format,
596+
}
597+
gen_params["task"] = "stt"
598+
try:
599+
content = await generate_completion(gen_params)
600+
return content
601+
except Exception as e:
602+
return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
603+
604+
567605

568606
@router.post("/v1/chat/completions", dependencies=[Depends(check_api_key)], tags=["chat"])
569607
async def create_openapi_chat_completion(request: ChatCompletionRequest):

transformerlab/plugins/mlx_audio_server/index.json

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
{
22
"name": "Apple Audio MLX Server",
33
"uniqueId": "mlx_audio_server",
4-
"description": "A text-to-speech (TTS) library built on Apple's MLX framework, providing efficient speech synthesis on Apple Silicon.",
4+
"description": "A text-to-speech (TTS), speech-to-text(STT) library built on Apple's MLX framework, providing efficient speech synthesis on Apple Silicon.",
55
"plugin-format": "python",
66
"type": "loader",
7-
"version": "0.0.7",
8-
"supports": ["Text-to-Speech", "Audio"],
9-
"model_architectures": ["MLXTextToSpeech", "StyleTTS2"],
7+
"version": "0.1.0",
8+
"supports": [
9+
"Text-to-Speech",
10+
"Audio",
11+
"Speech-to-Text"
12+
],
13+
"model_architectures": [
14+
"MLXTextToSpeech",
15+
"StyleTTS2",
16+
"MLXSpeechToText"
17+
],
1018
"supported_hardware_architectures": ["mlx"],
1119
"files": ["main.py", "setup.sh"],
1220
"setup-script": "setup.sh"

transformerlab/plugins/mlx_audio_server/main.py

Lines changed: 127 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616
from fastapi.responses import JSONResponse
1717

1818
from fastchat.serve.model_worker import logger
19-
from transformerlab.plugin import WORKSPACE_DIR
2019

2120
from mlx_audio.tts.generate import generate_audio
21+
from mlx_audio.stt.generate import generate
2222
from datetime import datetime
2323

24+
from lab.dirs import get_experiments_dir, get_workspace_dir
25+
from werkzeug.utils import secure_filename
26+
2427
worker_id = str(uuid.uuid4())[:8]
2528

2629
from fastchat.serve.base_model_worker import BaseModelWorker # noqa
@@ -70,78 +73,134 @@ def __init__(
7073
async def generate(self, params):
7174
self.call_ct += 1
7275

73-
text = params.get("text", "")
74-
model = params.get("model", None)
75-
speed = params.get("speed", 1.0)
76-
# file_prefix = params.get("file_prefix", "audio")
77-
audio_format = params.get("audio_format", "wav")
78-
sample_rate = params.get("sample_rate", 24000)
79-
temperature = params.get("temperature", 0.0)
80-
top_p = params.get("top_p", 1.0)
81-
stream = params.get("stream", False)
82-
voice = params.get("voice", None)
83-
lang_code = params.get("lang_code", None)
84-
85-
audio_dir = params.get("audio_dir", None)
86-
if not audio_dir:
87-
audio_dir = os.path.join(WORKSPACE_DIR, "audio")
88-
os.makedirs(name=audio_dir, exist_ok=True)
89-
90-
# Generate a UUID for this file name:
91-
file_prefix = str(uuid.uuid4())
92-
93-
try:
94-
kwargs = {
95-
"text": text,
96-
"model_path": model,
97-
"speed": speed,
98-
"file_prefix": os.path.join(audio_dir, file_prefix),
99-
"sample_rate": sample_rate,
100-
"join_audio": True,
101-
"verbose": True,
102-
"temperature": temperature,
103-
"top_p": top_p,
104-
"stream": stream,
105-
"voice": voice,
106-
}
107-
if lang_code:
108-
kwargs["lang_code"] = lang_code
109-
110-
generate_audio(**kwargs)
111-
112-
# Also save the parameters and metadata used to generate the audio
113-
metadata = {
114-
"type": "audio",
115-
"text": text,
116-
"voice": voice,
117-
"filename": f"{file_prefix}.{audio_format}",
118-
"model": model,
119-
"speed": speed,
120-
"audio_format": audio_format,
121-
"sample_rate": sample_rate,
122-
"temperature": temperature,
123-
"top_p": top_p,
124-
"date": datetime.now().isoformat(), # Store the real date and time
125-
}
126-
127-
metadata_file = os.path.join(audio_dir, f"{file_prefix}.json")
128-
with open(metadata_file, "w") as f:
129-
json.dump(metadata, f)
130-
131-
logger.info(f"Audio successfully generated: {audio_dir}/{file_prefix}.{audio_format}")
132-
133-
return {
134-
"status": "success",
135-
"message": f"{audio_dir}/{file_prefix}.{audio_format}",
136-
}
137-
except Exception:
138-
logger.error(f"Error generating audio: {audio_dir}/{file_prefix}.{audio_format}")
76+
task = params.get("task")
77+
if task == "tts":
78+
79+
text = params.get("text", "")
80+
model = params.get("model", None)
81+
speed = params.get("speed", 1.0)
82+
file_prefix = secure_filename(params.get("file_prefix", "audio"))
83+
audio_format = params.get("audio_format", "wav")
84+
sample_rate = params.get("sample_rate", 24000)
85+
temperature = params.get("temperature", 0.0)
86+
top_p = params.get("top_p", 1.0)
87+
stream = params.get("stream", False)
88+
voice = params.get("voice", None)
89+
lang_code = params.get("lang_code", None)
90+
stream = params.get("stream", False)
91+
92+
experiment_dir = get_experiments_dir()
93+
audio_dir_name = secure_filename(params.get("audio_dir", "audio"))
94+
audio_dir = os.path.join(experiment_dir, audio_dir_name)
95+
os.makedirs(name=audio_dir, exist_ok=True)
96+
97+
try:
98+
kwargs = {
99+
"text": text,
100+
"model_path": model,
101+
"speed": speed,
102+
"file_prefix": os.path.join(audio_dir, file_prefix),
103+
"sample_rate": sample_rate,
104+
"join_audio": True,
105+
"verbose": True,
106+
"temperature": temperature,
107+
"top_p": top_p,
108+
"stream": stream,
109+
"voice": voice,
110+
}
111+
if lang_code:
112+
kwargs["lang_code"] = lang_code
113+
114+
generate_audio(**kwargs)
115+
116+
# Also save the parameters and metadata used to generate the audio
117+
metadata = {
118+
"type": "audio",
119+
"text": text,
120+
"voice": voice,
121+
"filename": f"{file_prefix}.{audio_format}",
122+
"model": model,
123+
"speed": speed,
124+
"audio_format": audio_format,
125+
"sample_rate": sample_rate,
126+
"temperature": temperature,
127+
"top_p": top_p,
128+
"date": datetime.now().isoformat(), # Store the real date and time
129+
}
130+
131+
metadata_file = os.path.join(audio_dir, f"{file_prefix}.json")
132+
with open(metadata_file, "w") as f:
133+
json.dump(metadata, f)
134+
135+
logger.info(f"Audio successfully generated: {audio_dir}/{file_prefix}.{audio_format}")
136+
137+
return {
138+
"status": "success",
139+
"message": f"{audio_dir}/{file_prefix}.{audio_format}",
140+
}
141+
except Exception:
142+
logger.error(f"Error generating audio: {audio_dir}/{file_prefix}.{audio_format}")
143+
return {
144+
"status": "error",
145+
"message": f"Error generating audio: {audio_dir}/{file_prefix}.{audio_format}",
146+
}
147+
148+
elif task == "stt":
149+
audio_path = params.get("audio_path", "")
150+
model = params.get("model", None)
151+
format = params.get("format", "txt")
152+
output_path_name = secure_filename(params.get("output_path", "transcriptions"))
153+
transcriptions_dir = os.path.join(get_workspace_dir(), output_path_name)
154+
os.makedirs(name=transcriptions_dir, exist_ok=True)
155+
156+
# Generate a UUID for this file name:
157+
file_prefix = str(uuid.uuid4())
158+
159+
try:
160+
generate(
161+
audio_path=audio_path,
162+
model_path=model,
163+
format=format,
164+
output_path=os.path.join(transcriptions_dir, file_prefix),
165+
verbose=True, # Set to False to disable print messages
166+
)
167+
168+
# Also save the parameters and metadata used to generate the audio
169+
metadata = {
170+
"type": "text",
171+
"audio_folder": "uploaded_audio",
172+
"audio_path": audio_path.split("/").pop(),
173+
"filename": f"{file_prefix}.{format}",
174+
"model": model,
175+
"text_format": format,
176+
"date": datetime.now().isoformat(), # Store the real date and time
177+
}
178+
metadata_file = os.path.join(transcriptions_dir, f"{file_prefix}.json")
179+
with open(metadata_file, "w") as f:
180+
json.dump(metadata, f)
181+
182+
logger.info(f"Transcription successfully generated: {transcriptions_dir}/{file_prefix}.{format}")
183+
184+
return {
185+
"status": "success",
186+
"message": f"{transcriptions_dir}/{file_prefix}.{format}",
187+
}
188+
except Exception:
189+
logger.error(f"Error generating transcription: {transcriptions_dir}/{file_prefix}.{format}")
190+
return {
191+
"status": "error",
192+
"message": f"Error generating transcription: {transcriptions_dir}/{file_prefix}.{format}",
193+
}
194+
195+
else:
196+
logger.error(f"Unknown task type: {task}")
139197
return {
140198
"status": "error",
141-
"message": f"Error generating audio: {audio_dir}/{file_prefix}.{audio_format}",
199+
"message": f"Unknown task type: {task}",
142200
}
143201

144202

203+
145204
def release_worker_semaphore():
146205
worker.semaphore.release()
147206

transformerlab/routers/experiment/conversations.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,13 @@ async def list_audio(experimentId: str):
130130

131131

132132
@router.get(path="/download_audio")
133-
async def download_audio(experimentId: str, filename: str):
133+
async def download_audio(experimentId: str, filename: str, audioFolder: str = "audio"):
134134
exp_obj = Experiment.get(experimentId)
135135
experiment_dir = exp_obj.get_dir()
136-
audio_dir = os.path.join(experiment_dir, "audio")
136+
137+
# Use the provided audioFolder parameter, defaulting to "audio"
138+
audioFolder = secure_filename(audioFolder)
139+
audio_dir = os.path.join(experiment_dir, audioFolder)
137140

138141
# now download the audio file
139142
filename = secure_filename(filename)
@@ -193,3 +196,51 @@ async def delete_audio(experimentId: str, id: str):
193196
os.remove(audio_path)
194197

195198
return {"message": f"Audio file {id} deleted from experiment {experimentId}"}
199+
200+
@router.get("/list_transcription")
201+
async def list_transcription(experimentId: str):
202+
# Get experiment object and directory
203+
exp_obj = Experiment.get(experimentId)
204+
experiment_dir = exp_obj.get_dir()
205+
transcription_dir = os.path.join(experiment_dir, "transcriptions")
206+
os.makedirs(transcription_dir, exist_ok=True)
207+
208+
# List all .json files in the transcription directory
209+
transcription_files_metadata = []
210+
for filename in os.listdir(transcription_dir):
211+
if filename.endswith(".json"):
212+
file_path = os.path.join(transcription_dir, filename)
213+
with open(file_path, "r") as f:
214+
try:
215+
data = json.load(f)
216+
# Add the file modification time for sorting
217+
data["id"] = filename[:-5] # Remove .json from the filename
218+
data["file_date"] = os.path.getmtime(file_path)
219+
transcription_files_metadata.append(data)
220+
except Exception:
221+
continue
222+
transcription_files_metadata.sort(key=lambda x: x["file_date"], reverse=True)
223+
return transcription_files_metadata
224+
225+
@router.get("/download_transcription")
226+
async def download_transcription(experimentId: str, filename: str):
227+
exp_obj = Experiment.get(experimentId)
228+
experiment_dir = exp_obj.get_dir()
229+
text_dir = os.path.join(experiment_dir, "transcriptions")
230+
filename = secure_filename(filename)
231+
file_path = os.path.join(text_dir, filename)
232+
if not os.path.exists(file_path):
233+
return {"message": f"Text file {filename} does not exist in experiment {experimentId}"}
234+
return FileResponse(path=file_path, filename=filename, media_type="text/plain")
235+
236+
@router.delete("/delete_transcription")
237+
async def delete_transcription(experimentId: str, id: str):
238+
exp_obj = Experiment.get(experimentId)
239+
experiment_dir = exp_obj.get_dir()
240+
text_dir = os.path.join(experiment_dir, "transcriptions")
241+
id = secure_filename(id)
242+
text_path = os.path.join(text_dir, id + ".json")
243+
if not os.path.exists(text_path):
244+
return {"message": f"Text file {id} does not exist in experiment {experimentId}"}
245+
os.remove(text_path)
246+
return {"message": f"Text file {id} deleted from experiment {experimentId}"}

0 commit comments

Comments
 (0)