diff --git a/sense-voice/csrc/main.cc b/sense-voice/csrc/main.cc
index c9c2e82..1df62ce 100644
--- a/sense-voice/csrc/main.cc
+++ b/sense-voice/csrc/main.cc
@@ -73,6 +73,7 @@ struct sense_voice_params {
     bool flash_attn = false;
     bool use_itn = false;
     bool use_prefix = false;
+    bool use_repl = false;

     std::string language = "auto";
     std::string prompt;
@@ -173,7 +174,8 @@ static void sense_voice_print_usage(int /*argc*/, char ** argv, const sense_voic
     fprintf(stderr, "  -ng,      --no-gpu         [%-7s] disable GPU\n",     params.use_gpu ? "false" : "true");
    fprintf(stderr, "  -fa,      --flash-attn     [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
     fprintf(stderr, "  -itn,     --use-itn        [%-7s] use itn\n",         params.use_itn ? "true" : "false");
-    fprintf(stderr, "  -prefix,  --use-prefix     [%-7s] use itn\n",         params.use_itn ? "true" : "false");
+    fprintf(stderr, "  -prefix,  --use-prefix     [%-7s] use prefix\n",      params.use_prefix ? "true" : "false");
+    fprintf(stderr, "  -repl,    --use-repl       [%-7s] use REPL mode\n",   params.use_repl ? "true" : "false");
     fprintf(stderr, "\n");
 }
@@ -265,6 +267,7 @@ static bool sense_voice_params_parse(int argc, char ** argv, sense_voice_params
         else if (                 arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
         else if (arg == "-itn"    || arg == "--use-itn")      { params.use_itn = true; }
         else if (arg == "-prefix" || arg == "--use-prefix")   { params.use_prefix = true; }
+        else if (arg == "-repl"   || arg == "--use-repl")     { params.use_repl = true; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             sense_voice_print_usage(argc, argv, params);
@@ -415,6 +418,187 @@ void sense_voice_free(struct sense_voice_context * ctx) {
     }
 }

+bool process_audio_file(struct sense_voice_context * ctx, sense_voice_params & params, char ** argv, std::string fname_inp) {
+    std::vector<float> pcmf32; // mono-channel F32 PCM
+
+    // `load_wav_file` writes to stdout, so validate fname first: a stray
+    // error message would interfere with programs that consume this
+    // program's output.
+    if (fname_inp.empty()) {
+        fprintf(stderr, "error: no input file\n");
+        return false;
+    }
+
+    int sample_rate;
+    if (!::load_wav_file(fname_inp.c_str(), &sample_rate, pcmf32)) {
+        fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
+        return false;
+    }
+
+    if (!params.no_prints) {
+        // print system information
+        fprintf(stderr, "\n");
+        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
+                params.n_threads*params.n_processors, std::thread::hardware_concurrency(), sense_voice_print_system_info());
+
+        // print some info about the processing
+        fprintf(stderr, "\n");
+        fprintf(stderr, "%s: processing audio (%d samples, %.5f sec) , %d threads, %d processors, lang = %s...\n",
+                __func__, int(pcmf32.size()), float(pcmf32.size())/sample_rate,
+                params.n_threads, params.n_processors,
+                params.language.c_str());
+        ctx->state->duration = float(pcmf32.size())/sample_rate;
+        fprintf(stderr, "\n");
+    }
+    sense_voice_full_params wparams = sense_voice_full_default_params(SENSE_VOICE_SAMPLING_GREEDY);
+
+    {
+        wparams.strategy = (params.beam_size > 1 ) ? SENSE_VOICE_SAMPLING_BEAM_SEARCH : SENSE_VOICE_SAMPLING_GREEDY;
+        wparams.print_progress = params.print_progress;
+        wparams.print_timestamps = !params.no_timestamps;
+        wparams.language = params.language.c_str();
+        wparams.n_threads = params.n_threads;
+        wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
+        wparams.offset_ms = params.offset_t_ms;
+        wparams.duration_ms = params.duration_ms;
+        wparams.debug_mode = params.debug_mode;
+        wparams.greedy.best_of = params.best_of;
+        wparams.beam_search.beam_size = params.beam_size;
+        wparams.no_timestamps = params.no_timestamps;
+
+
+        int n_pad = 0;
+        std::vector<float> chunk(CONTEXT_SIZE + SENSE_VOICE_VAD_CHUNK_PAD_SIZE, 0);
+
+        // run vad and asr
+        int offset = CHUNK_SIZE - CONTEXT_SIZE;
+
+        auto & sched = ctx->state->sched_vad.sched;
+
+//        ggml_backend_sched_set_eval_callback(sched, ctx->params.cb_eval, &ctx->params.cb_eval_user_data);
+
+        // var for vad
+        bool triggered = false;
+        int32_t temp_end = 0;
+        int32_t prev_end = 0, next_start = 0;
+        int32_t current_speech_start = 0, current_speech_end = 0;
+        int32_t min_speech_samples = sample_rate * params.min_speech_duration_ms / 1000;
+        int32_t speech_pad_samples = sample_rate * params.speech_pad_ms / 1000;
+        int32_t max_speech_samples = sample_rate * params.max_speech_duration_ms / 1000 - CHUNK_SIZE - 2 * speech_pad_samples;
+        int32_t min_silence_samples = sample_rate * params.min_silence_duration_ms / 1000;
+        int32_t min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
+        std::vector<float> speech_segment;
+        for (int i = 0; i < pcmf32.size(); i += CHUNK_SIZE){
+
+            n_pad = CHUNK_SIZE <= pcmf32.size() - i ? 0 : CHUNK_SIZE + i - pcmf32.size();
+
+            for (int j = i + offset; j < i + CHUNK_SIZE; j++) {
+                if (j > 0 && j < i + CONTEXT_SIZE - n_pad && j < pcmf32.size()){
+                    chunk[j - i - offset] = pcmf32[j] / 32768;
+                } else {
+                    // pad the chunk when the first chunk lacks left context or the last chunk runs out of data on the right
+                    chunk[j - i - offset] = 0;
+                }
+            }
+            // implements reflection pad
+            for (int j = CONTEXT_SIZE; j < chunk.size(); j++) {
+                chunk[j] = chunk[2 * CONTEXT_SIZE - j - 2];
+            }
+
+            {
+                if (triggered && i - current_speech_start > max_speech_samples) {
+                    if (prev_end) {
+                        current_speech_end = prev_end;
+
+                        // find an endpoint in speech
+                        speech_segment.clear();
+                        speech_segment.assign(pcmf32.begin() + current_speech_start, pcmf32.begin() + current_speech_end);
+                        printf("[%.2f-%.2f] ", current_speech_start / (sample_rate * 1.0), current_speech_end / (sample_rate * 1.0));
+                        if (sense_voice_full_parallel(ctx, wparams, speech_segment, speech_segment.size(), params.n_processors) != 0) {
+                            fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                            return false;
+                        }
+                        sense_voice_print_output(ctx, params.use_prefix, params.use_itn, false);
+                        current_speech_end = current_speech_start = 0;
+                        if (next_start < prev_end) {
+                            triggered = false;
+                        } else {
+                            current_speech_start = next_start;
+                        }
+                        prev_end = 0;
+                    }
+                }
+
+                float speech_prob = 0;
+                silero_vad_encode_internal(*ctx, *ctx->state, chunk, params.n_threads, speech_prob);
+                if (speech_prob >= params.threshold) {
+                    if (temp_end) temp_end = 0;
+                    if (next_start < prev_end) next_start = i;
+                }
+
+                if (speech_prob >= params.threshold && !triggered) {
+                    triggered = true;
+                    current_speech_start = i;
+                    continue;
+                }
+
+                if (speech_prob < params.neg_threshold && triggered) {
+                    if (temp_end == 0) {
+                        temp_end = i;
+                    }
+
+                    if (i - temp_end > min_silence_samples_at_max_speech) {
+                        prev_end = temp_end;
+                    } else {
+                        continue;
+                    }
+
+                    // TODO min_silence_samples -> max_silence_samples
+                    if (i - prev_end < min_silence_samples) {
+                        continue;
+                    } else {
+                        current_speech_end = prev_end;
+                        if (current_speech_end - current_speech_start > min_speech_samples) {
+                            // find an endpoint in speech
+                            speech_segment.clear();
+                            speech_segment.assign(pcmf32.begin() + current_speech_start, pcmf32.begin() + current_speech_end);
+                            printf("[%.2f-%.2f] ", current_speech_start / (sample_rate * 1.0), current_speech_end / (sample_rate * 1.0));
+                            if (sense_voice_full_parallel(ctx, wparams, speech_segment, speech_segment.size(), params.n_processors) != 0) {
+                                fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                                return false;
+                            }
+                            sense_voice_print_output(ctx, params.use_prefix, params.use_itn, false);
+                            current_speech_end = current_speech_start = 0;
+                        }
+                        prev_end = next_start = 0;
+                        triggered = false;
+                        continue;
+                    }
+                }
+            }
+        }
+        // last segment speech
+        if (triggered && pcmf32.size() - 1 - current_speech_start > min_speech_samples) {
+            if (temp_end) {
+                current_speech_end = temp_end;
+            } else {
+                current_speech_end = pcmf32.size() - 1;
+            }
+            speech_segment.clear();
+            speech_segment.assign(pcmf32.begin() + current_speech_start, pcmf32.begin() + current_speech_end);
+            printf("[%.2f-%.2f] ", current_speech_start / (sample_rate * 1.0), current_speech_end / (sample_rate * 1.0));
+            if (sense_voice_full_parallel(ctx, wparams, speech_segment, speech_segment.size(), params.n_processors) != 0) {
+                fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                return false;
+            }
+            sense_voice_print_output(ctx, params.use_prefix, params.use_itn, false);
+        }
+    }
+    return true;
+}
+
 int main(int argc, char ** argv) {
     sense_voice_params params;

@@ -423,20 +607,22 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    // remove non-existent files
-    for (auto it = params.fname_inp.begin(); it != params.fname_inp.end();) {
-        const auto fname_inp = it->c_str();
+    if (!params.use_repl) {
+        // remove non-existent files for non-repl mode
+        for (auto it = params.fname_inp.begin(); it != params.fname_inp.end();) {
+            const auto fname_inp = it->c_str();
+
+            if (*it != "-" && !is_file_exist(fname_inp)) {
+                fprintf(stderr, "error: input file not found '%s'\n", fname_inp);
+                it = params.fname_inp.erase(it);
+                continue;
+            }

-        if (*it != "-" && !is_file_exist(fname_inp)) {
-            fprintf(stderr, "error: input file not found '%s'\n", fname_inp);
-            it = params.fname_inp.erase(it);
-            continue;
+            it++;
         }
-
-        it++;
     }

-    if (params.fname_inp.empty()) {
+    if (params.fname_inp.empty() && !params.use_repl) {
         fprintf(stderr, "error: no input files specified\n");
         sense_voice_print_usage(argc, argv, params);
         return 2;
@@ -470,208 +656,61 @@ int main(int argc, char ** argv) {

     ctx->language_id = sense_voice_lang_id(params.language.c_str());

-    for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
-        const auto fname_inp = params.fname_inp[f];
-        const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
+    // === init vad state start ===
+    ctx->state->vad_ctx = ggml_init({VAD_LSTM_STATE_MEMORY_SIZE, nullptr, true});
+    ctx->state->vad_lstm_context = ggml_new_tensor_1d(ctx->state->vad_ctx, GGML_TYPE_F32, VAD_LSTM_STATE_DIM);
+    ctx->state->vad_lstm_hidden_state = ggml_new_tensor_1d(ctx->state->vad_ctx, GGML_TYPE_F32, VAD_LSTM_STATE_DIM);
+
+    ctx->state->vad_lstm_context_buffer = ggml_backend_alloc_buffer(ctx->state->backends[0],
+                                                                    ggml_nbytes(ctx->state->vad_lstm_context)
+                                                                    + ggml_backend_get_alignment(ctx->state->backends[0]));
+    ctx->state->vad_lstm_hidden_state_buffer = ggml_backend_alloc_buffer(ctx->state->backends[0],
+                                                                         ggml_nbytes(ctx->state->vad_lstm_hidden_state)
+                                                                         + ggml_backend_get_alignment(ctx->state->backends[0]));
+    auto context_alloc = ggml_tallocr_new(ctx->state->vad_lstm_context_buffer);
+    ggml_tallocr_alloc(&context_alloc, ctx->state->vad_lstm_context);
+
+    auto state_alloc = ggml_tallocr_new(ctx->state->vad_lstm_hidden_state_buffer);
+    ggml_tallocr_alloc(&state_alloc, ctx->state->vad_lstm_hidden_state);
+
+    ggml_set_zero(ctx->state->vad_lstm_context);
+    ggml_set_zero(ctx->state->vad_lstm_hidden_state);
+    // === init vad state end ===
+
+    if (params.use_repl) {
+        printf("[__INIT__]\n");
+        fflush(stdout);
+    }

-        std::vector<float> pcmf32; // mono-channel F32 PCM
+    if (params.use_repl) {
+        // read one WAV path per line until EOF or an explicit "exit"
+        for (std::string line; std::getline(std::cin, line) && line != "exit"; ) {
+            if (line == "") continue;
+            if (!process_audio_file(ctx, params, argv, line)) continue;
+            printf("[__DONE__]\n");
+            fflush(stdout);

-        int sample_rate;
-        if (!::load_wav_file(fname_inp.c_str(), &sample_rate, pcmf32)) {
-            fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
-            continue;
-        }
+            SENSE_VOICE_LOG_INFO("\n%s: decoder audio use %f s, rtf is %f. \n\n",
+                                 __func__,
+                                 (ctx->state->t_encode_us + ctx->state->t_decode_us) / 1e6,
+                                 (ctx->state->t_encode_us + ctx->state->t_decode_us) / (1e6 * ctx->state->duration));

-        if (!params.no_prints) {
-            // print system information
-            fprintf(stderr, "\n");
-            fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                    params.n_threads*params.n_processors, std::thread::hardware_concurrency(), sense_voice_print_system_info());
-
-            // print some info about the processing
-            fprintf(stderr, "\n");
-            fprintf(stderr, "%s: processing audio (%d samples, %.5f sec) , %d threads, %d processors, lang = %s...\n",
-                    __func__, int(pcmf32.size()), float(pcmf32.size())/sample_rate,
-                    params.n_threads, params.n_processors,
-                    params.language.c_str());
-            ctx->state->duration = float(pcmf32.size())/sample_rate;
-            fprintf(stderr, "\n");
-        }
+        }
-        sense_voice_full_params wparams = sense_voice_full_default_params(SENSE_VOICE_SAMPLING_GREEDY);
-
-        {
-            wparams.strategy = (params.beam_size > 1 ) ? SENSE_VOICE_SAMPLING_BEAM_SEARCH : SENSE_VOICE_SAMPLING_GREEDY;
-            wparams.print_progress = params.print_progress;
-            wparams.print_timestamps = !params.no_timestamps;
-            wparams.language = params.language.c_str();
-            wparams.n_threads = params.n_threads;
-            wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
-            wparams.offset_ms = params.offset_t_ms;
-            wparams.duration_ms = params.duration_ms;
-            wparams.debug_mode = params.debug_mode;
-            wparams.greedy.best_of = params.best_of;
-            wparams.beam_search.beam_size = params.beam_size;
-            wparams.no_timestamps = params.no_timestamps;
-
-
-            int n_pad = 0;
-            std::vector<float> chunk(CONTEXT_SIZE + SENSE_VOICE_VAD_CHUNK_PAD_SIZE, 0);
-
-            // run vad and asr
-
-            {
-                // init state
-                ctx->state->vad_ctx = ggml_init({VAD_LSTM_STATE_MEMORY_SIZE, nullptr, true});
-                ctx->state->vad_lstm_context = ggml_new_tensor_1d(ctx->state->vad_ctx, GGML_TYPE_F32, VAD_LSTM_STATE_DIM);
-                ctx->state->vad_lstm_hidden_state = ggml_new_tensor_1d(ctx->state->vad_ctx, GGML_TYPE_F32, VAD_LSTM_STATE_DIM);
-
-                ctx->state->vad_lstm_context_buffer = ggml_backend_alloc_buffer(ctx->state->backends[0],
-                                                                                ggml_nbytes(ctx->state->vad_lstm_context)
-                                                                                + ggml_backend_get_alignment(ctx->state->backends[0]));
-                ctx->state->vad_lstm_hidden_state_buffer = ggml_backend_alloc_buffer(ctx->state->backends[0],
-                                                                                     ggml_nbytes(ctx->state->vad_lstm_hidden_state)
-                                                                                     + ggml_backend_get_alignment(ctx->state->backends[0]));
-                auto context_alloc = ggml_tallocr_new(ctx->state->vad_lstm_context_buffer);
-                ggml_tallocr_alloc(&context_alloc, ctx->state->vad_lstm_context);
-
-                auto state_alloc = ggml_tallocr_new(ctx->state->vad_lstm_hidden_state_buffer);
-                ggml_tallocr_alloc(&state_alloc, ctx->state->vad_lstm_hidden_state);
-
-                ggml_set_zero(ctx->state->vad_lstm_context);
-                ggml_set_zero(ctx->state->vad_lstm_hidden_state);
-            }
-
-            int offset = offset = CHUNK_SIZE - CONTEXT_SIZE;
-
-            auto & sched = ctx->state->sched_vad.sched;
-
-//            ggml_backend_sched_set_eval_callback(sched, ctx->params.cb_eval, &ctx->params.cb_eval_user_data);
-
-            // var for vad
-            bool triggered = false;
-            int32_t temp_end = 0;
-            int32_t prev_end = 0, next_start = 0;
-            int32_t current_speech_start = 0, current_speech_end = 0;
-            int32_t min_speech_samples = sample_rate * params.min_speech_duration_ms / 1000;
-            int32_t speech_pad_samples = sample_rate * params.speech_pad_ms / 1000;
-            int32_t max_speech_samples = sample_rate * params.max_speech_duration_ms / 1000 - CHUNK_SIZE - 2 * speech_pad_samples;
-            int32_t min_silence_samples = sample_rate * params.min_silence_duration_ms / 1000;
-            int32_t min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
-            std::vector<float> speech_segment;
-            for (int i = 0; i < pcmf32.size(); i += CHUNK_SIZE){
-
-                n_pad = CHUNK_SIZE <= pcmf32.size() - i ? 0 : CHUNK_SIZE + i - pcmf32.size();
-
-                for (int j = i + offset; j < i + CHUNK_SIZE; j++) {
-                    if (j > 0 && j < i + CONTEXT_SIZE - n_pad && j < pcmf32.size()){
-                        chunk[j - i - offset] = pcmf32[j] / 32768;
-                    } else{
-                        //pad chunk when first chunk in left or data not enough in right
-                        chunk[j - i - offset] = 0;
-                    }
-
-                }
-                // implements reflection pad
-                for (int j = CONTEXT_SIZE; j < chunk.size(); j++) {
-                    chunk[j] = chunk[2 * CONTEXT_SIZE - j - 2];
-                }
-
-
-                {
-                    if (triggered && i - current_speech_start > max_speech_samples) {
-                        if (prev_end) {
-                            current_speech_end = prev_end;
-
-                            // find an endpoint in speech
-                            speech_segment.clear();
-                            speech_segment.assign(pcmf32.begin() + current_speech_start, pcmf32.begin() + current_speech_end);
-                            printf("[%.2f-%.2f] ", current_speech_start / (sample_rate * 1.0), current_speech_end / (sample_rate * 1.0));
-                            if (sense_voice_full_parallel(ctx, wparams, speech_segment, speech_segment.size(), params.n_processors) != 0) {
-                                fprintf(stderr, "%s: failed to process audio\n", argv[0]);
-                                return 10;
-                            }
-                            sense_voice_print_output(ctx, params.use_prefix, params.use_itn, false);
-                            current_speech_end = current_speech_start = 0;
-                            if (next_start < prev_end) {
-                                triggered = false;
-                            } else {
-                                current_speech_start = next_start;
-                            }
-                            prev_end = 0;
-                        }
-                    }
-
-                    float speech_prob = 0;
-                    silero_vad_encode_internal(*ctx, *ctx->state, chunk, params.n_threads, speech_prob);
-                    if (speech_prob >= params.threshold) {
-                        if (temp_end) temp_end = 0;
-                        if (next_start < prev_end) next_start = i;
-                    }
-
-                    if (speech_prob >= params.threshold && !triggered) {
-                        triggered = true;
-                        current_speech_start = i;
-                        continue;
-                    }
-
-                    if (speech_prob < params.neg_threshold && triggered) {
-                        if (temp_end == 0) {
-                            temp_end = i;
-                        }
-
-                        if (i - temp_end > min_silence_samples_at_max_speech) {
-                            prev_end = temp_end;
-                        } else {
-                            continue;
-                        }
-
-                        // TODO min_silence_samples -> max_silence_samples
-                        if (i - prev_end < min_silence_samples) {
-                            continue;
-                        } else {
-                            current_speech_end = prev_end;
-                            if (current_speech_end - current_speech_start > min_speech_samples) {
-                                // find an endpoint in speech
-                                speech_segment.clear();
-                                speech_segment.assign(pcmf32.begin() + current_speech_start, pcmf32.begin() + current_speech_end);
-                                printf("[%.2f-%.2f] ", current_speech_start / (sample_rate * 1.0), current_speech_end / (sample_rate * 1.0));
-                                if (sense_voice_full_parallel(ctx, wparams, speech_segment, speech_segment.size(), params.n_processors) != 0) {
-                                    fprintf(stderr, "%s: failed to process audio\n", argv[0]);
-                                    return 10;
-                                }
-                                sense_voice_print_output(ctx, params.use_prefix, params.use_itn, false);
-                                current_speech_end = current_speech_start = 0;
-                            }
-                            prev_end = next_start = 0;
-                            triggered = false;
-                            continue;
-                        }
-                    }
-                }
-            }
-            // last segment speech
-            if (triggered && pcmf32.size() - 1 - current_speech_start > min_speech_samples) {
-                if (temp_end) {
-                    current_speech_end = temp_end;
-                } else {
-                    current_speech_end = pcmf32.size() - 1;
-                }
-                speech_segment.clear();
-                speech_segment.assign(pcmf32.begin() + current_speech_start, pcmf32.begin() + current_speech_end);
-                printf("[%.2f-%.2f] ", current_speech_start / (sample_rate * 1.0), current_speech_end / (sample_rate * 1.0));
-                if (sense_voice_full_parallel(ctx, wparams, speech_segment, speech_segment.size(), params.n_processors) != 0) {
-                    fprintf(stderr, "%s: failed to process audio\n", argv[0]);
-                    return 10;
-                }
-                sense_voice_print_output(ctx, true, params.use_prefix, false);
-            }
-        }
+    } else {
+        for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
+            const auto fname_inp = params.fname_inp[f];
+            const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
+
+            if (fname_inp == "") continue;
+            if (!process_audio_file(ctx, params, argv, fname_inp)) continue;
+
+            SENSE_VOICE_LOG_INFO("\n%s: decoder audio use %f s, rtf is %f. \n\n",
+                                 __func__,
+                                 (ctx->state->t_encode_us + ctx->state->t_decode_us) / 1e6,
+                                 (ctx->state->t_encode_us + ctx->state->t_decode_us) / (1e6 * ctx->state->duration));
+        }
+    }

-        SENSE_VOICE_LOG_INFO("\n%s: decoder audio use %f s, rtf is %f. \n\n",
-                             __func__,
-                             (ctx->state->t_encode_us + ctx->state->t_decode_us) / 1e6,
-                             (ctx->state->t_encode_us + ctx->state->t_decode_us) / (1e6 * ctx->state->duration));
-    }
+
+    sense_voice_free(ctx);

     return 0;
 }
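
Note on the -repl mode added above: it defines a small line protocol on stdio. The program prints `[__INIT__]` once the model and VAD state are ready, then reads one WAV path per line from stdin, writes `[start-end] transcript` segments to stdout, prints `[__DONE__]` after each file, and stops on `exit` or EOF. Below is a minimal POSIX driver sketch for a parent process; the binary path `./sense-voice-main`, the `-m model.gguf` flag, and `test.wav` are placeholders/assumptions, not part of this diff.

```cpp
#include <cstdio>
#include <cstring>
#include <unistd.h>
#include <sys/wait.h>

int main() {
    int inpipe[2], outpipe[2];                       // parent -> child, child -> parent
    if (pipe(inpipe) != 0 || pipe(outpipe) != 0) return 1;

    pid_t pid = fork();
    if (pid == 0) {                                  // child: wire pipes to stdio, exec the CLI
        dup2(inpipe[0], STDIN_FILENO);
        dup2(outpipe[1], STDOUT_FILENO);
        close(inpipe[1]); close(outpipe[0]);
        execlp("./sense-voice-main", "./sense-voice-main",
               "-m", "model.gguf", "-repl", (char *) nullptr);  // paths/flags assumed
        _exit(127);
    }

    close(inpipe[0]); close(outpipe[1]);
    FILE * to_child   = fdopen(inpipe[1], "w");
    FILE * from_child = fdopen(outpipe[0], "r");

    char buf[4096];
    // wait for the ready marker printed after model + VAD state init
    while (fgets(buf, sizeof buf, from_child) && strcmp(buf, "[__INIT__]\n") != 0) {}

    // one request: send a file name, collect segments until the done marker
    fputs("test.wav\n", to_child);
    fflush(to_child);
    while (fgets(buf, sizeof buf, from_child) && strcmp(buf, "[__DONE__]\n") != 0) {
        fputs(buf, stdout);                          // "[start-end] transcript" lines
    }

    fputs("exit\n", to_child);                       // terminates the REPL loop
    fflush(to_child);
    fclose(to_child); fclose(from_child);
    waitpid(pid, nullptr, 0);
    return 0;
}
```

Diagnostics stay usable in this setup because the CLI routes its own logging to stderr and reserves stdout for the protocol, which is also why `process_audio_file` validates the file name before calling `load_wav_file`.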
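Note on the reflection padding in the VAD loop: each chunk carries CONTEXT_SIZE real samples at its head, and the pad region at the tail is filled by mirroring those samples around the last head sample via `chunk[j] = chunk[2 * CONTEXT_SIZE - j - 2]`. A toy standalone illustration, with CONTEXT_SIZE shrunk to 4 (the real values come from the project headers):

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int CONTEXT_SIZE = 4;                   // toy value, not the real constant
    std::vector<float> chunk = {1, 2, 3, 4, 0, 0, 0};
    // same index arithmetic as the diff: mirror around the last head sample
    for (size_t j = CONTEXT_SIZE; j < chunk.size(); j++) {
        chunk[j] = chunk[2 * CONTEXT_SIZE - j - 2];
    }
    for (float v : chunk) printf("%g ", v);       // prints: 1 2 3 4 3 2 1
    printf("\n");
    return 0;
}
```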