|
16 | 16 | from fastapi.responses import JSONResponse |
17 | 17 |
|
18 | 18 | from fastchat.serve.model_worker import logger |
19 | | -from transformerlab.plugin import WORKSPACE_DIR |
20 | 19 |
|
21 | 20 | from mlx_audio.tts.generate import generate_audio |
| 21 | +from mlx_audio.stt.generate import generate |
22 | 22 | from datetime import datetime |
23 | 23 |
|
| 24 | +from lab.dirs import get_experiments_dir, get_workspace_dir |
| 25 | +from werkzeug.utils import secure_filename |
| 26 | + |
24 | 27 | worker_id = str(uuid.uuid4())[:8] |
25 | 28 |
|
26 | 29 | from fastchat.serve.base_model_worker import BaseModelWorker # noqa |
@@ -70,78 +73,134 @@ def __init__( |
async def generate(self, params):
    """Run a text-to-speech or speech-to-text job described by ``params``.

    ``params["task"]`` selects the job:

    * ``"tts"``: synthesise ``params["text"]`` with ``generate_audio``.
      Output is written under ``<experiments dir>/<audio_dir>/`` using
      ``file_prefix`` as the file stem.
    * ``"stt"``: transcribe ``params["audio_path"]`` with the mlx_audio
      STT ``generate`` function. Output is written under
      ``<workspace dir>/<output_path>/`` using a fresh UUID as the file stem.

    Both jobs also write a ``<stem>.json`` sidecar next to the output that
    records the parameters used and the time of the run.

    Args:
        params: dict of request parameters. Directory and file names taken
            from it (``audio_dir``, ``file_prefix``, ``output_path``) are
            sanitised with ``secure_filename`` before being joined onto the
            base directory.

    Returns:
        dict with ``"status"`` (``"success"`` or ``"error"``) and
        ``"message"`` (the output path on success, a description of the
        failure otherwise). Failures inside the generation step are logged
        and reported through the return value rather than raised.
    """

    def _safe_name(raw, fallback):
        # secure_filename() can return "" (e.g. for "../" or input with no
        # ASCII characters). Joining "" onto a directory would silently
        # target the directory itself, so use the default name instead.
        cleaned = secure_filename(raw) if isinstance(raw, str) else ""
        return cleaned or fallback

    self.call_ct += 1

    task = params.get("task")
    if task == "tts":
        text = params.get("text", "")
        model = params.get("model", None)
        speed = params.get("speed", 1.0)
        file_prefix = _safe_name(params.get("file_prefix"), "audio")
        # NOTE(review): audio_format is only recorded in the metadata and the
        # response; it is not forwarded to generate_audio. Confirm the file
        # written on disk really carries this extension.
        audio_format = params.get("audio_format", "wav")
        sample_rate = params.get("sample_rate", 24000)
        temperature = params.get("temperature", 0.0)
        top_p = params.get("top_p", 1.0)
        stream = params.get("stream", False)
        voice = params.get("voice", None)
        lang_code = params.get("lang_code", None)

        experiment_dir = get_experiments_dir()
        audio_dir_name = _safe_name(params.get("audio_dir"), "audio")
        audio_dir = os.path.join(experiment_dir, audio_dir_name)
        os.makedirs(name=audio_dir, exist_ok=True)

        try:
            kwargs = {
                "text": text,
                "model_path": model,
                "speed": speed,
                "file_prefix": os.path.join(audio_dir, file_prefix),
                "sample_rate": sample_rate,
                "join_audio": True,
                "verbose": True,
                "temperature": temperature,
                "top_p": top_p,
                "stream": stream,
                "voice": voice,
            }
            # Only pass lang_code when the caller supplied one.
            if lang_code:
                kwargs["lang_code"] = lang_code

            generate_audio(**kwargs)

            # Also save the parameters and metadata used to generate the audio
            metadata = {
                "type": "audio",
                "text": text,
                "voice": voice,
                "filename": f"{file_prefix}.{audio_format}",
                "model": model,
                "speed": speed,
                "audio_format": audio_format,
                "sample_rate": sample_rate,
                "temperature": temperature,
                "top_p": top_p,
                "date": datetime.now().isoformat(),  # Store the real date and time
            }

            metadata_file = os.path.join(audio_dir, f"{file_prefix}.json")
            with open(metadata_file, "w", encoding="utf-8") as f:
                json.dump(metadata, f)

            logger.info(f"Audio successfully generated: {audio_dir}/{file_prefix}.{audio_format}")

            return {
                "status": "success",
                "message": f"{audio_dir}/{file_prefix}.{audio_format}",
            }
        except Exception:
            # logger.exception keeps the traceback so the real cause is not lost.
            logger.exception(f"Error generating audio: {audio_dir}/{file_prefix}.{audio_format}")
            return {
                "status": "error",
                "message": f"Error generating audio: {audio_dir}/{file_prefix}.{audio_format}",
            }

    elif task == "stt":
        # Imported under an alias so the call below cannot be confused with
        # this method, which shares the name ``generate``.
        from mlx_audio.stt.generate import generate as transcribe_audio

        audio_path = params.get("audio_path", "")
        model = params.get("model", None)
        # Named text_format to avoid shadowing the builtin ``format``.
        text_format = params.get("format", "txt")
        output_path_name = _safe_name(params.get("output_path"), "transcriptions")
        transcriptions_dir = os.path.join(get_workspace_dir(), output_path_name)
        os.makedirs(name=transcriptions_dir, exist_ok=True)

        # Generate a UUID for this file name:
        file_prefix = str(uuid.uuid4())

        try:
            transcribe_audio(
                audio_path=audio_path,
                model_path=model,
                format=text_format,
                output_path=os.path.join(transcriptions_dir, file_prefix),
                verbose=True,  # Set to False to disable print messages
            )

            # Also save the parameters and metadata used to generate the transcription
            metadata = {
                "type": "text",
                "audio_folder": "uploaded_audio",
                "audio_path": os.path.basename(audio_path),
                "filename": f"{file_prefix}.{text_format}",
                "model": model,
                "text_format": text_format,
                "date": datetime.now().isoformat(),  # Store the real date and time
            }
            metadata_file = os.path.join(transcriptions_dir, f"{file_prefix}.json")
            with open(metadata_file, "w", encoding="utf-8") as f:
                json.dump(metadata, f)

            logger.info(f"Transcription successfully generated: {transcriptions_dir}/{file_prefix}.{text_format}")

            return {
                "status": "success",
                "message": f"{transcriptions_dir}/{file_prefix}.{text_format}",
            }
        except Exception:
            # logger.exception keeps the traceback so the real cause is not lost.
            logger.exception(f"Error generating transcription: {transcriptions_dir}/{file_prefix}.{text_format}")
            return {
                "status": "error",
                "message": f"Error generating transcription: {transcriptions_dir}/{file_prefix}.{text_format}",
            }

    else:
        logger.error(f"Unknown task type: {task}")
        return {
            "status": "error",
            "message": f"Unknown task type: {task}",
        }
143 | 201 |
|
144 | 202 |
|
| 203 | + |
def release_worker_semaphore():
    """Release the semaphore held on the global ``worker`` object."""
    semaphore = worker.semaphore
    semaphore.release()
147 | 206 |
|
|
0 commit comments