diff --git a/CMakeLists.txt b/CMakeLists.txt index eabff576cc6b..3c43b95bd93f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -111,6 +111,8 @@ option(LLAMA_BUILD_SERVER "llama: build server example" option(LLAMA_BUILD_APP "llama: build the unified binary" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_UI "llama: build the embedded Web UI for server" ON) option(LLAMA_USE_PREBUILT_UI "llama: use prebuilt UI from HF Bucket when available (requires LLAMA_BUILD_UI=ON)" ON) +option(LLAMA_BUILD_SPEECH "llama: build speech synthesis server support and tools" OFF) +option(LLAMA_BUILD_Q3TTS "llama: build SpaceMIT Qwen3-TTS speech backend; deprecated alias for speech" OFF) option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT}) option(LLAMA_TESTS_INSTALL "llama: install tests" ON) @@ -120,6 +122,14 @@ option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" ON) option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF) option(LLAMA_SERVER_SMT_VISION "llama: enable SpacemiT multimodal SMT extensions" OFF) +if (LLAMA_BUILD_Q3TTS) + set(LLAMA_BUILD_SPEECH ON CACHE BOOL "llama: build speech synthesis server support and tools" FORCE) +endif() + +if (LLAMA_BUILD_SPEECH) + add_compile_definitions(LLAMA_SERVER_SPEECH=1) +endif() + if (LLAMA_SERVER_SMT_VISION) add_compile_definitions(LLAMA_SERVER_SMT_VISION=1) endif() diff --git a/common/arg.cpp b/common/arg.cpp index d7c481b8f579..b512fc828d95 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2214,7 +2214,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.mmproj_use_gpu = value; } ).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_OFFLOAD")); -#if defined(LLAMA_SERVER_SMT_VISION) +#if defined(LLAMA_SERVER_SMT_VISION) || defined(LLAMA_SERVER_SPEECH) add_opt(common_arg( {"--media-backend", "--vision-backend"}, "{auto|mtmd|smt}", string_format("multimodal backend selection (default: %s)", params.media_backend.c_str()), @@ -3847,7 +3847,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.vocoder.speaker_file = value; } - ).set_examples({LLAMA_EXAMPLE_TTS})); + ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); // // diffusion params diff --git a/common/common.h b/common/common.h index 4381b13b0c77..521914d6490e 100644 --- a/common/common.h +++ b/common/common.h @@ -571,7 +571,7 @@ struct common_params { struct common_params_model mmproj; bool mmproj_use_gpu = true; // use GPU for multimodal model bool no_mmproj = false; // explicitly disable multimodal model -#if defined(LLAMA_SERVER_SMT_VISION) +#if defined(LLAMA_SERVER_SMT_VISION) || defined(LLAMA_SERVER_SPEECH) std::string media_backend = "auto"; // multimodal backend: auto|mtmd|smt std::string smt_config_dir; // SMT config dir (config.json + ONNX) #endif diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 4fb4b14c7956..39e016b7017e 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -174,6 +174,7 @@ option(GGML_LSX "ggml: enable lsx" ON) option(GGML_RVV "ggml: enable rvv" ON) option(GGML_RV_ZFH "ggml: enable riscv zfh" ON) option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) +option(GGML_RV_ZBA "ggml: enable riscv zba" ON) option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause" ON) option(GGML_RV_ZVFBFWMA "ggml: enable riscv zvfbfwma" OFF) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 6dfdf1269a94..9febe5136dbd 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -83,6 +83,14 @@ float ggml_table_f32_f16[1 << 16]; // precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h) float ggml_table_f32_e8m0_half[1 << 8]; +static bool ggml_cpu_fuse_swiglu_down_q8_enabled(void) { + static int enabled = -1; + if (enabled < 0) { + enabled = getenv("GGML_CPU_FUSE_SWIGLU_DOWN_Q8") != NULL; + } + return enabled != 0; +} + #if defined(__ARM_ARCH) struct ggml_arm_arch_features_type { int sve_cnt; @@ -1156,6 +1164,102 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, // ggml_compute_forward_mul_mat +static bool ggml_cpu_can_fuse_swiglu_down_q8_mul_mat( + const struct ggml_tensor * dst, + const enum ggml_type vec_dot_type) { + if (!ggml_cpu_fuse_swiglu_down_q8_enabled()) { + return false; + } + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + if (dst->op != GGML_OP_MUL_MAT || src0 == NULL || src1 == NULL) { + return false; + } + if (src0->type != GGML_TYPE_Q4_0 || vec_dot_type != GGML_TYPE_Q8_0) { + return false; + } + if (src1->op != GGML_OP_GLU || ggml_get_glu_op(src1) != GGML_GLU_OP_SWIGLU) { + return false; + } + if (src1->type != GGML_TYPE_F32 || src1->src[0] == NULL || src1->src[1] == NULL) { + return false; + } + + const struct ggml_tensor * gate = src1->src[0]; + const struct ggml_tensor * up = src1->src[1]; + + if (gate->type != GGML_TYPE_F32 || up->type != GGML_TYPE_F32) { + return false; + } + if (gate->ne[0] != src1->ne[0] || up->ne[0] != src1->ne[0]) { + return false; + } + if (src0->ne[0] != src1->ne[0]) { + return false; + } + if (ggml_nrows(src1) != 1) { + return false; + } + if (src1->nb[0] != sizeof(float) || gate->nb[0] != sizeof(float) || up->nb[0] != sizeof(float)) { + return false; + } + if (!ggml_is_contiguous_1(src1) || !ggml_is_contiguous_1(gate) || !ggml_is_contiguous_1(up)) { + return false; + } + + return true; +} + +static void ggml_cpu_quantize_swiglu_to_q8( + const struct ggml_compute_params * params, + const struct ggml_tensor * src1, + char * wdata, + const enum ggml_type vec_dot_type, + ggml_from_float_t from_float, + const int64_t ne10, + const int64_t ne11, + const int64_t ne12, + const int64_t ne13) { + const int ith = params->ith; + const int nth = params->nth; + + const struct ggml_tensor * gate = src1->src[0]; + const struct ggml_tensor * up = src1->src[1]; + + const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); + const size_t nbw2 = nbw1 * ne11; + const size_t nbw3 = nbw2 * ne12; + + const size_t bs = ggml_blck_size(vec_dot_type); + const int64_t ne10_block_start = (ith * ne10 / bs) / nth; + const int64_t ne10_block_end = ((ith + 1) * ne10 / bs) / nth; + const int64_t offset = ne10_block_start * (int64_t) bs; + const int64_t len = (ne10_block_end - ne10_block_start) * (int64_t) bs; + + if (len <= 0) { + return; + } + + float * tmp = (float *) alloca((size_t) len * sizeof(float)); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + const int64_t row = i11 + i12 * ne11 + i13 * ne12 * ne11; + const float * gate_row = (const float *) ((const char *) gate->data + row * gate->nb[1]); + const float * up_row = (const float *) ((const char *) up->data + row * up->nb[1]); + + ggml_vec_swiglu_f32((int) len, tmp, gate_row + offset, up_row + offset); + from_float(tmp, + (void *) (wdata + i13 * nbw3 + i12 * nbw2 + i11 * nbw1 + ne10_block_start * ggml_type_size(vec_dot_type)), + len); + } + } + } +} + static void ggml_compute_forward_mul_mat_one_chunk( const struct ggml_compute_params * params, struct ggml_tensor * dst, @@ -1325,30 +1429,34 @@ UseGgmlGemm1:; assert(params->wsize >= ne13*nbw3); GGML_ASSERT(src1->type == GGML_TYPE_F32); + if (ggml_cpu_can_fuse_swiglu_down_q8_mul_mat(dst, vec_dot_type)) { + ggml_cpu_quantize_swiglu_to_q8(params, src1, wdata, vec_dot_type, from_float, ne10, ne11, ne12, ne13); + } else { #if 0 - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = ith; i11 < ne11; i11 += nth) { - from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - ne10); + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); + } } } - } #else - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - size_t bs = ggml_blck_size(vec_dot_type); - int64_t ne10_block_start = (ith * ne10/bs) / nth; - int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth; - from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0), - (ne10_block_end - ne10_block_start) * bs); + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + size_t bs = ggml_blck_size(vec_dot_type); + int64_t ne10_block_start = (ith * ne10/bs) / nth; + int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth; + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0), + (ne10_block_end - ne10_block_start) * bs); + } } } - } #endif + } } if (ith == 0) { @@ -2982,8 +3090,11 @@ struct ggml_cplan ggml_graph_plan( } +#define GGML_CPU_FUSE_SKIP_CURRENT (-1) + // Try to fuse the current node with subsequent nodes for better performance. -// Returns the number of nodes skipped by fusion (>=1), or 0 if no fusion was applied. +// Returns the number of nodes skipped by fusion (>=1), GGML_CPU_FUSE_SKIP_CURRENT, +// or 0 if no fusion was applied. static bool ggml_cpu_disable_fusion = false; // initialized once in ggml_cpu_init(), read-only afterwards static int ggml_cpu_try_fuse_ops( @@ -2998,6 +3109,17 @@ static int ggml_cpu_try_fuse_ops( struct ggml_tensor * node = cgraph->nodes[node_n]; + if (node->op == GGML_OP_GLU && node_n + 1 < cgraph->n_nodes) { + struct ggml_tensor * mul_mat = cgraph->nodes[node_n + 1]; + if (mul_mat->op == GGML_OP_MUL_MAT && + mul_mat->src[0] != NULL && + mul_mat->src[1] == node && + ggml_node_has_n_uses(cgraph, node_n, 1) && + ggml_cpu_can_fuse_swiglu_down_q8_mul_mat(mul_mat, type_traits_cpu[mul_mat->src[0]->type].vec_dot_type)) { + return GGML_CPU_FUSE_SKIP_CURRENT; + } + } + if (node->op == GGML_OP_RMS_NORM) { // RMS_NORM + MUL fusion const enum ggml_op fuse_ops[] = { GGML_OP_RMS_NORM, GGML_OP_MUL }; @@ -3065,7 +3187,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_profile_log_op_begin(node, state->ith, params.nth); const int n_fused = ggml_cpu_try_fuse_ops(cgraph, node_n, ¶ms, cplan); - if (n_fused > 0) { + if (n_fused == GGML_CPU_FUSE_SKIP_CURRENT) { + // The following MUL_MAT will consume this GLU node's inputs directly. + } else if (n_fused > 0) { node_n += n_fused; } else { ggml_compute_forward(¶ms, node); diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp index 50e137b50af9..062e42adb341 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -27,6 +27,8 @@ #include #include #include // for GGML_ASSERT +#include +#include #include #include // clang-format off @@ -72,6 +74,64 @@ extern int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value); namespace ggml::cpu::riscv64_spacemit { +namespace { + +bool mm_env_flag_enabled(const char *name) { + const char *value = std::getenv(name); + if (value == nullptr || value[0] == '\0') { + return false; + } + return !(std::strcmp(value, "0") == 0 || std::strcmp(value, "false") == 0 || std::strcmp(value, "FALSE") == 0 || + std::strcmp(value, "off") == 0 || std::strcmp(value, "OFF") == 0); +} + +bool mm_q4_hp_m1_n64_enabled() { + static const bool enabled = mm_env_flag_enabled("SPACEMIT_Q4_HP_M1_N64"); + return enabled; +} + +bool mm_fuse_swiglu_down_q8_enabled() { + static const bool enabled = std::getenv("GGML_CPU_FUSE_SWIGLU_DOWN_Q8") != nullptr; + return enabled; +} + +bool mm_can_fuse_swiglu_down_q8(const ggml_tensor * src1, int64_t gemm_m, int64_t gemm_k) { + if (!mm_fuse_swiglu_down_q8_enabled() || gemm_m != 1) { + return false; + } + if (src1 == nullptr || src1->op != GGML_OP_GLU || ggml_get_glu_op(src1) != GGML_GLU_OP_SWIGLU) { + return false; + } + if (src1->type != GGML_TYPE_F32 || src1->src[0] == nullptr || src1->src[1] == nullptr) { + return false; + } + const ggml_tensor * gate = src1->src[0]; + const ggml_tensor * up = src1->src[1]; + if (gate->type != GGML_TYPE_F32 || up->type != GGML_TYPE_F32) { + return false; + } + if (src1->ne[0] != gemm_k || gate->ne[0] != gemm_k || up->ne[0] != gemm_k) { + return false; + } + if (gemm_k % 32 != 0) { + return false; + } + if (src1->nb[0] != sizeof(float) || gate->nb[0] != sizeof(float) || up->nb[0] != sizeof(float)) { + return false; + } + return ggml_is_contiguous_1(src1) && ggml_is_contiguous_1(gate) && ggml_is_contiguous_1(up); +} + +void mm_swiglu_slice(const ggml_tensor * src1, int64_t offset, int64_t len, float * tmp) { + const ggml_tensor * gate = src1->src[0]; + const ggml_tensor * up = src1->src[1]; + const float * gate_row = reinterpret_cast(gate->data); + const float * up_row = reinterpret_cast(up->data); + ggml_vec_swiglu_f32(static_cast(len), tmp, gate_row + offset, up_row + offset); +} + +} // namespace + struct TLSContext { int cpu_id{ -1 }; cpu_set_t cpuset; @@ -349,6 +409,7 @@ template class tensor_ const int64_t row_stride_b = b_k_blks * get_repacked_block_type_size(); const int64_t per_mb_rows_wsize = row_align * row_stride_a; const int64_t per_nb_cols_wsize = NB_COLS * row_stride_b; + const bool fuse_swiglu_down = mm_can_fuse_swiglu_down_q8(src1, gemm_m, gemm_k); const int64_t barrier_idx = static_cast(ith / 2); @@ -361,8 +422,16 @@ template class tensor_ int a_blk_start = ith * task_per_thread; int a_blk_end = std::min(a_blk_start + task_per_thread, (int) a_k_blks); if (a_blk_start < a_blk_end) { - quantize_a_row_i8(a_blk_len, feature + a_blk_start * a_blk_len, (a_blk_end - a_blk_start) * a_blk_len, - quant_a_buffer + a_blk_start * block_stride_a); + const int64_t offset = a_blk_start * (int64_t) a_blk_len; + const int64_t len = (a_blk_end - a_blk_start) * (int64_t) a_blk_len; + if (fuse_swiglu_down) { + float tmp[4096]; + GGML_ASSERT(len <= (int64_t) (sizeof(tmp) / sizeof(tmp[0]))); + mm_swiglu_slice(src1, offset, len, tmp); + quantize_a_row_i8(a_blk_len, tmp, len, quant_a_buffer + a_blk_start * block_stride_a); + } else { + quantize_a_row_i8(a_blk_len, feature + offset, len, quant_a_buffer + a_blk_start * block_stride_a); + } } } else { int task_per_thread = spacemit_kernels::div_round_up(row_blks, nth); @@ -387,7 +456,6 @@ template class tensor_ } } } - ggml_barrier(params->threadpool); const int64_t gemm_m_stride = gemm_n / gemm_m > 64 ? gemm_m : 16; @@ -442,52 +510,88 @@ template class tensor_ int64_t ni = ith * NB_COLS; int64_t nb_real = std::min(gemm_n - ni, NB_COLS); - if (ith % 2 == 0 && nb_real > 0) { - spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, - nb_real * row_stride_b); - if (a_row != quant_a_buffer) { - spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + if (gemm_m == 1 && a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + + int64_t m1_tile_cols = NB_COLS; + if (gemm_m == 1 && NB_COLS == 32) { + if (mm_q4_hp_m1_n64_enabled() && std::is_same_v && INTER_SIZE == 256 && + gemm_n >= 4096) { + m1_tile_cols = 2 * NB_COLS; + } else if constexpr (std::is_same_v && INTER_SIZE == 32) { + if (gemm_n >= 4096) { + m1_tile_cols = 2 * NB_COLS; + } + } } - } + ni = ith * m1_tile_cols; + for (; ni < gemm_n; ni += m1_tile_cols * nth) { + nb_real = std::min(gemm_n - ni, m1_tile_cols); + uint8_t * b_row = reinterpret_cast(w_data) + ni * row_stride_b; + uint8_t * b_row_zp = block_type_has_zp() ? b_row : nullptr; - spine_barrier_wait(cur_barrier); + int64_t rows_remaining = gemm_m; + float * c_blk = output + ni; + auto * a_row_cur = a_row; - if (ith % 2 != 0 && nb_real > 0) { - if (a_row != quant_a_buffer) { - spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_cur, b_row, b_row_zp, c_blk, rows_remaining, + nb_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_cur += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + } + } else { + if (ith % 2 == 0 && nb_real > 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } } - spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, - nb_real * row_stride_b); - } - for (; ni < gemm_n; ni += NB_COLS * nth) { - int64_t rows_remaining = gemm_m; - float * c_blk = output + ni; - auto * a_row_cur = a_row; + spine_barrier_wait(cur_barrier); - if (ith % 2 != 0) { - spine_barrier_wait(cur_barrier); + if (ith % 2 != 0 && nb_real > 0) { + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, + nb_real * row_stride_b); } - while (rows_remaining > 0) { - auto rows_handled = gemm_kernel(b_blk_len, a_row_cur, b_col, b_col_zp, c_blk, rows_remaining, - nb_real, b_k_blks, gemm_n); + for (; ni < gemm_n; ni += NB_COLS * nth) { + int64_t rows_remaining = gemm_m; + float * c_blk = output + ni; + auto * a_row_cur = a_row; - c_blk += rows_handled * gemm_n; - a_row_cur += rows_handled * row_stride_a; + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } - rows_remaining -= rows_handled; - } + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_cur, b_col, b_col_zp, c_blk, rows_remaining, + nb_real, b_k_blks, gemm_n); - if (ith % 2 == 0) { - spine_barrier_wait(cur_barrier); - } + c_blk += rows_handled * gemm_n; + a_row_cur += rows_handled * row_stride_a; - const int64_t next_ni = ni + NB_COLS * nth; - if (next_ni < gemm_n) { - nb_real = std::min(gemm_n - next_ni, NB_COLS); - spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + next_ni * row_stride_b, - nb_real * row_stride_b); + rows_remaining -= rows_handled; + } + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS * nth; + if (next_ni < gemm_n) { + nb_real = std::min(gemm_n - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + next_ni * row_stride_b, + nb_real * row_stride_b); + } } } } else { @@ -1278,6 +1382,7 @@ static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack { #if defined(RISCV64_SPACEMIT_IME2) if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + std::getenv("SPACEMIT_Q4_0_FORCE_32X32") == nullptr && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { return &ggml::cpu::riscv64_spacemit::q4_0_32x256_q8_0; } diff --git a/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp index 0c7a036a92af..3d9988e52629 100644 --- a/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp @@ -6,6 +6,8 @@ #include #include +#include +#include #include #if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) @@ -32,6 +34,24 @@ namespace spacemit_kernels { namespace ime2 { +namespace { + +bool env_flag_enabled(const char *name) { + const char *value = std::getenv(name); + if (value == nullptr || value[0] == '\0') { + return false; + } + return !(std::strcmp(value, "0") == 0 || std::strcmp(value, "false") == 0 || std::strcmp(value, "FALSE") == 0 || + std::strcmp(value, "off") == 0 || std::strcmp(value, "OFF") == 0); +} + +bool q4_hp_m1_n64_enabled() { + static const bool enabled = env_flag_enabled("SPACEMIT_Q4_HP_M1_N64"); + return enabled; +} + +} // namespace + template void gemm_kernel_i8i2k_mrow_ref(size_t blk_len, const uint8_t * quant_a_ptr, @@ -2904,7 +2924,132 @@ void gemm_kernel_i8i4_hp_m1(size_t blk_len, const size_t b_tile_stride = k_blks * b_superblk_stride; if (quant_b_zp == NULL) { - for (size_t ni = 0; ni < count_n; ni += 32) { + size_t ni = 0; + if (q4_hp_m1_n64_enabled()) { + for (; ni + 2 * NB_COLS <= count_n; ni += 2 * NB_COLS) { + uint8_t * b_data0 = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * b_data1 = b_data0 + b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c0 = c_ptr + ni; + float * dst_c1 = c_ptr + ni + NB_COLS; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + "li t0, 0x4c00 \n\t" + "fmv.h.x fa0, t0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP64%=: \n\t" + "li t5, 8 \n\t" + "addi t6, %[A], 288 \n\t" + "flh ft1, (t6) \n\t" + "addi t6, %[A], 272 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v17, v17, v17 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v19, v19 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v21, v21, v21 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v23, v23, v23 \n\t" + + "INNER_BLK_LOOP64%=: \n\t" + ".rept 4 \n\t" + "flh fa1, (t6) \n\t" + "addi t6, t6, 2 \n\t" + "flh ft0, (%[A]) \n\t" + "addi %[A], %[A], 2 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (%[A]) \n\t" + "addi %[A], %[A], 32 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v28, v3, 4 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v2, v3, v3, 3 \n\t" + "vnpack4.vv v3, v28, v28, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (%[B0]) \n\t" + "addi %[B0], %[B0], 64 \n\t" + "vl4r.v v4, (%[B0]) \n\t" + "addi %[B0], %[B0], 512 \n\t" + "vfmul.vf v8, v8, ft0 \n\t" + "vfmul.vf v9, v8, fa0 \n\t" + "vfmul.vf v10, v8, fa1 \n\t" + "vfwmacc.vf v30, ft1, v10 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v0, v8, v9, 3 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v16, v3, v4, v0, 4, i4 \n\t" + "vmadotsu.hp v17, v3, v5, v0, 5, i4 \n\t" + "vmadotsu.hp v18, v3, v6, v0, 6, i4 \n\t" + "vmadotsu.hp v19, v3, v7, v0, 7, i4 \n\t" + "vmadotu.hp v16, v2, v4, v0, 0, i4 \n\t" + "vmadotu.hp v17, v2, v5, v0, 1, i4 \n\t" + "vmadotu.hp v18, v2, v6, v0, 2, i4 \n\t" + "vmadotu.hp v19, v2, v7, v0, 3, i4 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (%[B1]) \n\t" + "addi %[B1], %[B1], 64 \n\t" + "vl4r.v v4, (%[B1]) \n\t" + "addi %[B1], %[B1], 512 \n\t" + "vfmul.vf v8, v8, ft0 \n\t" + "vfmul.vf v9, v8, fa0 \n\t" + "vfmul.vf v10, v8, fa1 \n\t" + "vfwmacc.vf v31, ft1, v10 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v0, v8, v9, 3 \n\t" + "addi t5, t5, -1 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v20, v3, v4, v0, 4, i4 \n\t" + "vmadotsu.hp v21, v3, v5, v0, 5, i4 \n\t" + "vmadotsu.hp v22, v3, v6, v0, 6, i4 \n\t" + "vmadotsu.hp v23, v3, v7, v0, 7, i4 \n\t" + "vmadotu.hp v20, v2, v4, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v5, v0, 1, i4 \n\t" + "vmadotu.hp v22, v2, v6, v0, 2, i4 \n\t" + "vmadotu.hp v23, v2, v7, v0, 3, i4 \n\t" + ".endr \n\t" + + "bgtz t5, INNER_BLK_LOOP64%= \n\t" + + "vpack.vv v8, v16, v17, 1 \n\t" + "vpack.vv v12, v18, v19, 1 \n\t" + "vpack.vv v24, v8, v12, 2 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v30, ft1, v24 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v8, v20, v21, 1 \n\t" + "vpack.vv v12, v22, v23, 1 \n\t" + "vpack.vv v24, v8, v12, 2 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v31, ft1, v24 \n\t" + + "addi t4, t4, -1 \n\t" + "addi %[A], t6, 2 \n\t" + "bgtz t4, BLK_LOOP64%= \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v30, (%[DST0]) \n\t" + "vse32.v v31, (%[DST1]) \n\t" + : [A] "+r"(a_data), [B0] "+r"(b_data0), [B1] "+r"(b_data1) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", + "ft0", "ft1"); + } + } + + for (; ni < count_n; ni += 32) { uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; int8_t * a_data = (int8_t *) quant_a_ptr; float * dst_c = c_ptr + ni; @@ -2925,12 +3070,13 @@ void gemm_kernel_i8i4_hp_m1(size_t blk_len, // init the acc fp16 "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v18, v18, v18 \n\t" "vxor.vv v16, v18, v18 \n\t" "vxor.vv v17, v18, v18 \n\t" - "vxor.vv v18, v18, v18 \n\t" "vxor.vv v19, v18, v18 \n\t" "INNER_BLK_LOOP%=: \n\t" + ".rept 4 \n\t" // load a sum and scale "flh fa1, (t6) \n\t" "addi t6, t6, 2 \n\t" @@ -2941,6 +3087,12 @@ void gemm_kernel_i8i4_hp_m1(size_t blk_len, "vle8.v v3, (%[A]) \n\t" // 1x32@i8 "addi %[A], %[A], 32 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v28, v3, 4 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v2, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v3, v28, v28, 3 \n\t" // hi4 of A + // load scale B and B "vsetvli t0, x0, e16, mf2 \n\t" "vle16.v v8, (%[B]) \n\t" // b_scale fp16 @@ -2954,11 +3106,8 @@ void gemm_kernel_i8i4_hp_m1(size_t blk_len, "vsetvli t0, x0, e8, m1 \n\t" "vpack.vv v0, v8, v9, 3 \n\t" - "vsrl.vi v28, v3, 4 \n\t" - "vsetvli t0, x0, e16, m1 \n\t" - "vnpack4.vv v2, v3, v3, 3 \n\t" // lo4 of A - "vnpack4.vv v3, v28, v28, 3 \n\t" // hi4 of A + "addi t5, t5, -1 \n\t" // i4 * i4 vmadot "vsetvli t0, x0, e16, m1 \n\t" @@ -2970,8 +3119,8 @@ void gemm_kernel_i8i4_hp_m1(size_t blk_len, "vmadotu.hp v17, v2, v5, v0, 1, i4 \n\t" "vmadotu.hp v18, v2, v6, v0, 2, i4 \n\t" "vmadotu.hp v19, v2, v7, v0, 3, i4 \n\t" + ".endr \n\t" - "addi t5, t5, -1 \n\t" "bgtz t5, INNER_BLK_LOOP%= \n\t" "vpack.vv v8, v16, v17, 1 \n\t" @@ -3004,6 +3153,176 @@ void gemm_kernel_i8i4_hp_m1(size_t blk_len, } } +void gemm_kernel_i8i4_hp_m2(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + GGML_ASSERT(count_m >= 2); + + const size_t a_row_stride = q8_hp_blk_size(blk_len, true, true) * k_blks; + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr; + int8_t * a_data1 = (int8_t *) quant_a_ptr + a_row_stride; + float * dst_c0 = c_ptr + ni; + float * dst_c1 = c_ptr + ldc + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t0, 0x4c00 \n\t" // 16 in fp16 + "fmv.h.x fa0, t0 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + "li t5, 4 \n\t" + + // row0 block scale and a_sum pointer + "addi t6, %[A0], 288 \n\t" + "flh ft1, (t6) \n\t" + "addi t6, %[A0], 272 \n\t" + + // row1 block scale and a_sum pointer + "addi s2, %[A1], 288 \n\t" + "flh ft2, (s2) \n\t" + "addi s2, %[A1], 272 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v17, v17, v17 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v19, v19 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v21, v21, v21 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v23, v23, v23 \n\t" + + "INNER_BLK_LOOP%=: \n\t" + ".rept 2 \n\t" + // load shared B scale and payload once for two rows + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v11, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (%[B]) \n\t" + "addi %[B], %[B], 512 \n\t" + + // row0: same arithmetic order as m1 + "flh fa1, (t6) \n\t" + "addi t6, t6, 2 \n\t" + "flh ft0, (%[A0]) \n\t" + "addi %[A0], %[A0], 2 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v8, v11, ft0 \n\t" + "vfmul.vf v9, v8, fa0 \n\t" + "vfmul.vf v10, v8, fa1 \n\t" + "vfwmacc.vf v30, ft1, v10 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v0, v8, v9, 3 \n\t" + "vsrl.vi v28, v3, 4 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v2, v3, v3, 3 \n\t" + "vnpack4.vv v3, v28, v28, 3 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v16, v3, v4, v0, 4, i4 \n\t" + "vmadotsu.hp v17, v3, v5, v0, 5, i4 \n\t" + "vmadotsu.hp v18, v3, v6, v0, 6, i4 \n\t" + "vmadotsu.hp v19, v3, v7, v0, 7, i4 \n\t" + "vmadotu.hp v16, v2, v4, v0, 0, i4 \n\t" + "vmadotu.hp v17, v2, v5, v0, 1, i4 \n\t" + "vmadotu.hp v18, v2, v6, v0, 2, i4 \n\t" + "vmadotu.hp v19, v2, v7, v0, 3, i4 \n\t" + + // row1: same arithmetic order as m1 + "flh fa2, (s2) \n\t" + "addi s2, s2, 2 \n\t" + "flh ft3, (%[A1]) \n\t" + "addi %[A1], %[A1], 2 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v8, v11, ft3 \n\t" + "vfmul.vf v9, v8, fa0 \n\t" + "vfmul.vf v10, v8, fa2 \n\t" + "vfwmacc.vf v31, ft2, v10 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v0, v8, v9, 3 \n\t" + "vsrl.vi v28, v3, 4 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v2, v3, v3, 3 \n\t" + "vnpack4.vv v3, v28, v28, 3 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v20, v3, v4, v0, 4, i4 \n\t" + "vmadotsu.hp v21, v3, v5, v0, 5, i4 \n\t" + "vmadotsu.hp v22, v3, v6, v0, 6, i4 \n\t" + "vmadotsu.hp v23, v3, v7, v0, 7, i4 \n\t" + "vmadotu.hp v20, v2, v4, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v5, v0, 1, i4 \n\t" + "vmadotu.hp v22, v2, v6, v0, 2, i4 \n\t" + "vmadotu.hp v23, v2, v7, v0, 3, i4 \n\t" + + ".endr \n\t" + "addi t5, t5, -1 \n\t" + "bgtz t5, INNER_BLK_LOOP%= \n\t" + + "vpack.vv v8, v16, v17, 1 \n\t" + "vpack.vv v12, v18, v19, 1 \n\t" + "vpack.vv v24, v8, v12, 2 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v30, ft1, v24 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v8, v20, v21, 1 \n\t" + "vpack.vv v12, v22, v23, 1 \n\t" + "vpack.vv v24, v8, v12, 2 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v31, ft2, v24 \n\t" + + "addi t4, t4, -1 \n\t" + "addi %[A0], t6, 2 \n\t" + "addi %[A1], s2, 2 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v30, (%[DST0]) \n\t" + "vse32.v v31, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft0", + "ft1", "ft2", "ft3"); + } + } else { + GGML_ABORT("gemm_kernel_i8i4_hp_m2 with quant_b_zp is not supported yet"); + } +} + void gemm_kernel_i8i4_m4(size_t blk_len, const uint8_t * quant_a_ptr, const uint8_t * quant_b_data, @@ -4776,9 +5095,9 @@ void gemm_kernel_i8i8_m1(size_t blk_len, const uint8_t * quant_b_zp, float * c_ptr, size_t count_m, - size_t count_n, - size_t k_blks, - size_t ldc) { + size_t count_n, + size_t k_blks, + size_t ldc) { for (size_t n = 0; n < count_n; n += 32) { size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // @@ -5625,6 +5944,14 @@ size_t gemm_kernel_i8i4_hp(size_t blk_len, gemm_kernel_i8i4_hp_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); #endif return 4; + } else if (count_m >= 2) { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m2(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 2; } else { #if 0 gemm_kernel_i8i4_hp_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.cpp b/ggml/src/ggml-cpu/spacemit/ime_env.cpp index a13ba391da2f..87a7793c9a49 100644 --- a/ggml/src/ggml-cpu/spacemit/ime_env.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime_env.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -176,13 +177,13 @@ spine_env_info::spine_env_info() { } char * spine_perfer_core_arch_str = getenv("SPACEMIT_PERFER_CORE_ARCH"); - if (spine_perfer_core_arch_str != nullptr && spine_perfer_core_arch_str != "") { + if (spine_perfer_core_arch_str != nullptr && spine_perfer_core_arch_str[0] != '\0') { perfer_core_arch_id = spine_core_arch_id{ hex_string_to_u16(spine_perfer_core_arch_str) }; } char * spine_perfer_core_id_str = getenv("SPACEMIT_PERFER_CORE_ID"); std::vector perfer_core_id_vec; - if (spine_perfer_core_id_str != nullptr && spine_perfer_core_id_str != "") { + if (spine_perfer_core_id_str != nullptr && spine_perfer_core_id_str[0] != '\0') { std::string perfer_core_id_str(spine_perfer_core_id_str); size_t start = 0; size_t end = 0; @@ -296,8 +297,7 @@ spine_env_info::spine_env_info() { if (init_barrier != nullptr) { init_barrier_is_shared_mem = true; } else { - GGML_LOG_WARN("CPU_RISCV64_SPACEMIT: failed to allocate init_barrier from shared mem, falling back to heap\n", - __func__); + GGML_LOG_WARN("CPU_RISCV64_SPACEMIT: failed to allocate init_barrier from shared mem, falling back to heap\n"); init_barrier = new spine_barrier_t[spine_init_barrier_count]; } diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 0ca1d20a4a3b..83c04b3fbccb 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -362,6 +362,11 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_UP, "blk.%d.ffn_gate_up" }, + { LLM_TENSOR_Q3TTS_CODEC_EMBD, "q3tts.codec_embedding" }, + { LLM_TENSOR_Q3TTS_TALKER_HEAD, "q3tts.talker_head_f16" }, + { LLM_TENSOR_Q3TTS_CP_EMBD, "q3tts.cp_embedding.%d" }, + { LLM_TENSOR_Q3TTS_CP_HEAD_F16, "q3tts.cp_head_f16.%d" }, { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, @@ -591,6 +596,11 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_Q3TTS_CODEC_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}}, + {LLM_TENSOR_Q3TTS_TALKER_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, + {LLM_TENSOR_Q3TTS_CP_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}}, + {LLM_TENSOR_Q3TTS_CP_HEAD_F16, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index a7a21b2ef606..05d3d56e7a1c 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -385,6 +385,11 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_UP, + LLM_TENSOR_Q3TTS_CODEC_EMBD, + LLM_TENSOR_Q3TTS_TALKER_HEAD, + LLM_TENSOR_Q3TTS_CP_EMBD, + LLM_TENSOR_Q3TTS_CP_HEAD_F16, LLM_TENSOR_FFN_ACT, LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility LLM_TENSOR_FFN_GATE_EXP, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 7985315b68e5..4252537a2fdd 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -15,10 +15,25 @@ #include #include +#include #include #include #include +static uint32_t llama_ctx_pad() { + const char * env = std::getenv("LLAMA_CTX_PAD"); + if (env == nullptr || env[0] == '\0') { + return 256; + } + + char * end = nullptr; + const long value = std::strtol(env, &end, 10); + if (end == env || value < 1) { + return 256; + } + return (uint32_t) value; +} + // // llama_context // @@ -44,6 +59,8 @@ llama_context::llama_context( t_start_us = model.t_start_us; t_load_us = model.t_load_us; + const char * graph_cache_2way_env = std::getenv("LLAMA_GRAPH_REUSE_2WAY"); + graph_cache_2way = graph_cache_2way_env && std::atoi(graph_cache_2way_env) != 0; const auto & hparams = model.hparams; @@ -201,13 +218,14 @@ llama_context::llama_context( } // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732 - cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); + const uint32_t n_ctx_pad = llama_ctx_pad(); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, n_ctx_pad); if (cparams.kv_unified) { cparams.n_ctx_seq = cparams.n_ctx; } else { cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max; - cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256); + cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, n_ctx_pad); if (cparams.n_ctx_seq == 0) { throw std::runtime_error("n_ctx_seq == 0"); @@ -433,7 +451,13 @@ void llama_context::sched_reserve() { LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); gf_res_prev.reset(new llm_graph_result(max_nodes)); + if (graph_cache_2way) { + gf_res_prev_alt.reset(new llm_graph_result(max_nodes)); + } else { + gf_res_prev_alt.reset(); + } gf_res_reserve.reset(new llm_graph_result(max_nodes)); + gf_res_sched = nullptr; sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload)); @@ -764,6 +788,10 @@ bool llama_context::memory_update(bool optimize) { // TODO: change the mctx->apply() to return information if a graph reserve is needed // reset the graph result only if the memory module did reset the scheduler gf_res_prev->reset(); + if (gf_res_prev_alt) { + gf_res_prev_alt->reset(); + } + gf_res_sched = nullptr; if (!mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); @@ -1064,6 +1092,8 @@ void llama_context::attach_threadpool( this->threadpool = threadpool; this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool; + this->graph_compute_threadpool = nullptr; + this->graph_compute_n_threads = -1; } void llama_context::detach_threadpool() { @@ -1071,6 +1101,8 @@ void llama_context::detach_threadpool() { this->threadpool = nullptr; this->threadpool_batch = nullptr; + this->graph_compute_threadpool = nullptr; + this->graph_compute_n_threads = -1; } void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) { @@ -1078,6 +1110,7 @@ void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) { cparams.n_threads = n_threads; cparams.n_threads_batch = n_threads_batch; + graph_compute_n_threads = -1; } void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) { @@ -1266,9 +1299,24 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll // the new graph parameters // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype); + auto gparams = graph_params(res, ubatch, mctx, gtype); + bool graph_reused = false; if (!graph_reuse_disable && res->can_reuse(gparams)) { + graph_reused = true; + } else if (!graph_reuse_disable && graph_cache_2way && gf_res_prev_alt) { + auto * alt_res = gf_res_prev_alt.get(); + const auto alt_gparams = graph_params(alt_res, ubatch, mctx, gtype); + if (alt_res->can_reuse(alt_gparams)) { + std::swap(gf_res_prev, gf_res_prev_alt); + res = gf_res_prev.get(); + gf = res->get_gf(); + gparams = graph_params(res, ubatch, mctx, gtype); + graph_reused = true; + } + } + + if (graph_reused) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); // with pipeline parallelism, the previous graph_compute_async may still be running @@ -1279,10 +1327,31 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } n_reused++; + + if (res != gf_res_sched) { + ggml_backend_sched_reset(sched.get()); + gf_res_sched = nullptr; + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate cached graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + } + + gf_res_sched = res; + } } else { + if (graph_cache_2way && gf_res_prev_alt) { + std::swap(gf_res_prev, gf_res_prev_alt); + res = gf_res_prev.get(); + gparams = graph_params(res, ubatch, mctx, gtype); + } + res->reset(); ggml_backend_sched_reset(sched.get()); + gf_res_sched = nullptr; ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); //const auto t_start_us = ggml_time_us(); @@ -1302,6 +1371,8 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll ret = GGML_STATUS_ALLOC_FAILED; return nullptr; } + + gf_res_sched = res; } // set the input data for the input tensors @@ -1651,7 +1722,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const int64_t n_embd = hparams.n_embd_inp(); // when computing embeddings, all tokens are output - const bool output_all = cparams.embeddings; + const bool output_all = cparams.embeddings && std::getenv("LLAMA_EMBEDDINGS_OUTPUT_ONLY") == nullptr; const bool has_samplers = !sampling.samplers.empty(); const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max; @@ -2243,6 +2314,10 @@ ggml_cgraph * llama_context::graph_reserve( // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that gf_res_prev->reset(); + if (gf_res_prev_alt) { + gf_res_prev_alt->reset(); + } + gf_res_sched = nullptr; // store the n_outputs as it is, and restore it afterwards // TODO: not sure if needed, might simplify in the future by removing this @@ -2318,17 +2393,21 @@ ggml_status llama_context::graph_compute( int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; ggml_threadpool_t tp = batched ? threadpool_batch : threadpool; - if (backend_cpu != nullptr) { + if (backend_cpu != nullptr && graph_compute_threadpool != tp) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); if (set_threadpool_fn) { set_threadpool_fn(backend_cpu, tp); } + graph_compute_threadpool = tp; } // set the number of threads for all the backends - for (const auto & set_n_threads_fn : set_n_threads_fns) { - set_n_threads_fn.second(set_n_threads_fn.first, n_threads); + if (graph_compute_n_threads != n_threads) { + for (const auto & set_n_threads_fn : set_n_threads_fns) { + set_n_threads_fn.second(set_n_threads_fn.first, n_threads); + } + graph_compute_n_threads = n_threads; } auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf); diff --git a/src/llama-context.h b/src/llama-context.h index d03f681d4a13..52486c20515f 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -336,6 +336,8 @@ struct llama_context { ggml_threadpool_t threadpool = nullptr; ggml_threadpool_t threadpool_batch = nullptr; + ggml_threadpool_t graph_compute_threadpool = nullptr; + int32_t graph_compute_n_threads = -1; ggml_abort_callback abort_callback = nullptr; void * abort_callback_data = nullptr; @@ -348,7 +350,9 @@ struct llama_context { std::vector backend_buf_exp_size; // expected buffer sizes llm_graph_result_ptr gf_res_prev; + llm_graph_result_ptr gf_res_prev_alt; llm_graph_result_ptr gf_res_reserve; + llm_graph_result * gf_res_sched = nullptr; // host buffer for the model output (logits and embeddings) ggml_backend_buffer_ptr buf_output; @@ -360,6 +364,7 @@ struct llama_context { // env: LLAMA_GRAPH_REUSE_DISABLE bool graph_reuse_disable = false; + bool graph_cache_2way = false; // perf mutable int64_t t_start_us = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b8056cac3496..6fe12f026ce6 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1588,7 +1588,7 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { } ggml_tensor * llama_model_base::create_tensor(llama_model_loader & ml, const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) { - const buft_list_t * buft_list_layer = tn.bid == -1 ? nullptr : pimpl->dev_layer.at(tn.bid).buft_list; + const buft_list_t * buft_list_layer = tn.bid == -1 || (flags & TENSOR_SKIP) ? nullptr : pimpl->dev_layer.at(tn.bid).buft_list; return ml.create_tensor( hparams, &pimpl->cpu_buft_list, pimpl->dev_input.buft_list, pimpl->dev_output.buft_list, buft_list_layer, tn, ne, flags); diff --git a/src/llama-model.h b/src/llama-model.h index 743feb970d99..289c85099bf9 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -284,6 +284,7 @@ struct llama_layer { struct ggml_tensor * ffn_gate = nullptr; // w1 struct ggml_tensor * ffn_down = nullptr; // w2 struct ggml_tensor * ffn_up = nullptr; // w3 + struct ggml_tensor * ffn_gate_up = nullptr; struct ggml_tensor * ffn_gate_enc = nullptr; struct ggml_tensor * ffn_down_enc = nullptr; struct ggml_tensor * ffn_up_enc = nullptr; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index d2955a846237..56665529d411 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -294,6 +294,10 @@ static bool tensor_allows_quantization(const llama_model_quantize_params * param const std::string name = ggml_get_name(tensor); + if (name.rfind("q3tts.", 0) == 0) { + return false; + } + // This used to be a regex, but has an extreme cost to compile times. bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? diff --git a/src/models/qwen3.cpp b/src/models/qwen3.cpp index 41b97fed9564..46844746e3a2 100644 --- a/src/models/qwen3.cpp +++ b/src/models/qwen3.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include + void llama_model_qwen3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { @@ -27,6 +29,13 @@ void llama_model_qwen3::load_arch_tensors(llama_model_loader &) { // output rerank head cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_Q3TTS_CODEC_EMBD, "weight"), {n_embd, 3072}, TENSOR_NOT_REQUIRED | TENSOR_SKIP); + create_tensor(tn(LLM_TENSOR_Q3TTS_TALKER_HEAD, "weight"), {n_embd, 3072}, TENSOR_NOT_REQUIRED | TENSOR_SKIP); + for (int i = 0; i < 15; ++i) { + create_tensor(tn(LLM_TENSOR_Q3TTS_CP_EMBD, "weight", i), {n_embd, 2048}, TENSOR_NOT_REQUIRED | TENSOR_SKIP); + create_tensor(tn(LLM_TENSOR_Q3TTS_CP_HEAD_F16, "weight", i), {n_embd, 2048}, TENSOR_NOT_REQUIRED | TENSOR_SKIP); + } + for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -39,9 +48,10 @@ void llama_model_qwen3::load_arch_tensors(llama_model_loader &) { layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_up = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP, "weight", i), {n_embd, 2*n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, layer.ffn_gate_up ? TENSOR_NOT_REQUIRED : 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, layer.ffn_gate_up ? TENSOR_NOT_REQUIRED : 0); } } @@ -121,12 +131,29 @@ llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_param LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, - model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, - model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); + if (model.layers[il].ffn_gate_up) { + ggml_tensor * gate_up = build_lora_mm(model.layers[il].ffn_gate_up, cur); + cb(gate_up, "ffn_gate_up", il); + + const int64_t n_ff = gate_up->ne[0] / 2; + ggml_tensor * gate = ggml_view_2d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->nb[1], 0); + cb(gate, "ffn_gate", il); + ggml_tensor * up = ggml_view_2d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->nb[1], n_ff * ggml_element_size(gate_up)); + cb(up, "ffn_up", il); + + cur = ggml_swiglu_split(ctx0, gate, up); + cb(cur, "ffn_swiglu", il); + + cur = build_lora_mm(model.layers[il].ffn_down, cur, model.layers[il].ffn_down_s); + cb(cur, "ffn_down", il); + } else { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + } cb(cur, "ffn_out", il); cur = ggml_add(ctx0, cur, ffn_inp); @@ -146,11 +173,14 @@ llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - // lm_head - cur = build_lora_mm(model.output, cur, model.output_s); + const bool embed_only = cparams.embeddings && std::getenv("LLAMA_QWEN3_EMBED_ONLY") != nullptr; + if (!embed_only) { + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); - cb(cur, "result_output", -1); - res->t_logits = cur; + cb(cur, "result_output", -1); + res->t_logits = cur; + } ggml_build_forward_expand(gf, cur); } diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index a60d3dab469a..cc8f524e67d2 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -29,6 +29,9 @@ else() add_subdirectory(tokenize) add_subdirectory(parser) add_subdirectory(tts) + if (LLAMA_BUILD_SPEECH) + add_subdirectory(speech) + endif() add_subdirectory(mtmd) if (GGML_RPC) add_subdirectory(rpc) diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index 7b11a82e9595..94084ebe59ff 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -27,6 +27,16 @@ target_include_directories(${TARGET} PRIVATE ../mtmd) target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) target_link_libraries(${TARGET} PUBLIC llama-common mtmd ${CMAKE_THREAD_LIBS_INIT}) +if(LLAMA_BUILD_SPEECH) + target_sources(${TARGET} PRIVATE + server-speech-backend.h + server-speech.cpp + server-speech.h + server-speech-qwen3-tts.cpp + server-speech-qwen3-tts.h + ) +endif() + if(LLAMA_SERVER_SMT_VISION) if(WIN32) message(FATAL_ERROR "LLAMA_SERVER_SMT_VISION is not supported on Windows because the SMT audio backend depends on POSIX dlopen APIs") diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index eec80696ad3f..33848e3951b0 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -14,6 +14,9 @@ #include "server-http.h" #include "server-queue.h" #include "server-smt-vision.h" +#if defined(LLAMA_SERVER_SPEECH) +# include "server-speech.h" +#endif #include "server-task.h" #include "speculative.h" @@ -22,8 +25,13 @@ #include #include #include +#include #include +#include +#include #include +#include + // fix problem with std::min and std::max #if defined(_WIN32) @@ -38,6 +46,7 @@ using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; + static uint32_t server_n_outputs_max(const common_params & params) { const uint32_t n_batch = params.n_batch; @@ -71,7 +80,7 @@ enum server_state { enum server_vision_backend_mode { SERVER_VISION_BACKEND_NONE, SERVER_VISION_BACKEND_MTMD, -#if defined(LLAMA_SERVER_SMT_VISION) +#if defined(LLAMA_SERVER_SMT_VISION) || defined(LLAMA_SERVER_SPEECH) SERVER_VISION_BACKEND_SMT, #endif }; @@ -715,7 +724,7 @@ struct server_context_impl { common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; - bool model_less_reconstruction = false; + bool model_less_backend = false; common_speculative_ptr spec; @@ -756,7 +765,7 @@ struct server_context_impl { switch (vision_backend) { case SERVER_VISION_BACKEND_MTMD: return "mtmd"; -#if defined(LLAMA_SERVER_SMT_VISION) +#if defined(LLAMA_SERVER_SMT_VISION) || defined(LLAMA_SERVER_SPEECH) case SERVER_VISION_BACKEND_SMT: return "smt"; #endif @@ -835,7 +844,7 @@ struct server_context_impl { } vision_backend = SERVER_VISION_BACKEND_SMT; - model_less_reconstruction = true; + model_less_backend = true; chat_params = { /* use_jinja */ params_base.use_jinja, /* prefill_assistant */ params_base.prefill_assistant, @@ -871,6 +880,47 @@ struct server_context_impl { } #endif +#if defined(LLAMA_SERVER_SPEECH) + if (server_speech_config_matches(params_base)) { + vision_backend = SERVER_VISION_BACKEND_SMT; + model_less_backend = true; + chat_params = { + /* use_jinja */ params_base.use_jinja, + /* prefill_assistant */ params_base.prefill_assistant, + /* reasoning_format */ params_base.reasoning_format, + /* chat_template_kwargs */ params_base.default_template_kwargs, + /* tmpls */ nullptr, + /* allow_image */ false, + /* allow_audio */ false, + /* image_bin_only */ false, + /* media_backend */ vision_backend_name(), + /* enable_thinking */ false, + /* reasoning_budget */ params_base.sampling.reasoning_budget_tokens, + /* reasoning_budget_msg */ params_base.sampling.reasoning_budget_message, + /* media_path */ params_base.media_path, + /* force_pure_content */ params_base.force_pure_content_parser + }; + + if (!params_base.model_alias.empty()) { + model_name = *params_base.model_alias.begin(); + } else if (!params_base.model.name.empty()) { + model_name = params_base.model.name; + } else { + model_name = "qwen3-tts"; + } + model_aliases = params_base.model_alias; + model_tags = params_base.model_tags; + + SRV_INF("loaded speech backend '%s' (smt), '%s'\n", + server_speech_backend_name(params_base).c_str(), params_base.smt_config_dir.c_str()); + params = params_base; + if (!is_resume) { + return init(); + } + return true; + } +#endif + std::string & mmproj_path = params_base.mmproj.path; bool has_mmproj = !mmproj_path.empty(); mtmd_context_params mparams = mtmd_context_params_default(); @@ -1311,7 +1361,7 @@ struct server_context_impl { // unlike load_model(), this is only called once during initialization bool init() { - if (!model_less_reconstruction) { + if (!model_less_backend) { GGML_ASSERT(ctx_tgt != nullptr); GGML_ASSERT(model_tgt != nullptr); } @@ -1355,7 +1405,7 @@ struct server_context_impl { } } - if (model_less_reconstruction) { + if (model_less_backend) { return true; } @@ -4222,6 +4272,10 @@ void server_routes::init_routes() { // IMPORTANT: all lambda functions must start with create_response() // this is to ensure that the server_res_generator can handle sleeping case correctly +#if defined(LLAMA_SERVER_SPEECH) + auto speech_service = std::make_shared(params); +#endif + this->get_health = [this](const server_http_req &) { // error and loading states are handled by middleware auto res = create_response(true); @@ -4654,6 +4708,49 @@ void server_routes::init_routes() { TASK_RESPONSE_TYPE_OAI_ASR); }; + this->post_speech_oai = [this +#if defined(LLAMA_SERVER_SPEECH) + , speech_service +#endif + ](const server_http_req & req) { + auto res = create_response(true); +#if defined(LLAMA_SERVER_SPEECH) + if (!server_speech_config_matches(params)) { + res->error(format_error_response("The current server is not configured for a speech backend.", + ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + try { + const json body = json::parse(req.body); + const auto result = speech_service->synthesize(body); + const double gen_rtf = result.audio_seconds > 0.0 ? result.wall_seconds / result.audio_seconds : 0.0; + res->status = 200; + res->content_type = "audio/wav"; + res->data = result.wav; + res->headers["X-Speech-Backend"] = result.backend; + res->headers["X-Speech-Segments"] = std::to_string(result.segments); + res->headers["X-Speech-Audio-Seconds"] = string_format("%.3f", result.audio_seconds); + res->headers["X-Speech-Wall-Seconds"] = string_format("%.3f", result.wall_seconds); + res->headers["X-Speech-Gen-RTF"] = string_format("%.2f", gen_rtf); + return res; + } catch (const json::exception & e) { + res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return res; + } catch (const std::invalid_argument & e) { + res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return res; + } catch (const std::exception & e) { + res->error(format_error_response(e.what(), ERROR_TYPE_SERVER)); + return res; + } +#else + res->error(format_error_response("This llama-server was not built with speech synthesis support.", + ERROR_TYPE_NOT_SUPPORTED)); + return res; +#endif + }; + this->post_anthropic_messages = [this](const server_http_req & req) { auto res = create_response(); std::vector files; diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 3083b733c5e0..53672066b1f6 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -115,6 +115,7 @@ struct server_routes { server_http_context::handler_t post_control; server_http_context::handler_t post_responses_oai; server_http_context::handler_t post_transcriptions_oai; + server_http_context::handler_t post_speech_oai; server_http_context::handler_t post_anthropic_messages; server_http_context::handler_t post_anthropic_count_tokens; server_http_context::handler_t post_apply_template; diff --git a/tools/server/server-speech-backend.h b/tools/server/server-speech-backend.h new file mode 100644 index 000000000000..1cdd5edc3413 --- /dev/null +++ b/tools/server/server-speech-backend.h @@ -0,0 +1,13 @@ +#pragma once + +#include "server-speech.h" + +#include + +class server_speech_backend { + public: + virtual ~server_speech_backend() = default; + + virtual const char * name() const = 0; + virtual server_speech_result synthesize(const nlohmann::ordered_json & body) = 0; +}; diff --git a/tools/server/server-speech-qwen3-tts.cpp b/tools/server/server-speech-qwen3-tts.cpp new file mode 100644 index 000000000000..4c458e4fe526 --- /dev/null +++ b/tools/server/server-speech-qwen3-tts.cpp @@ -0,0 +1,673 @@ +#include "server-speech-qwen3-tts.h" + +#include "log.h" +#include "server-speech-backend.h" +#include "server-common.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +using json = nlohmann::ordered_json; + +static bool speech_path_exists(const std::string & path) { + std::error_code ec; + return !path.empty() && std::filesystem::exists(path, ec); +} + +static bool qwen3_tts_name_matches(const std::string & value) { + return value.find("qwen3-tts") != std::string::npos || + value.find("Qwen3-TTS") != std::string::npos || + value.find("q3tts") != std::string::npos; +} + +static bool qwen3_tts_config_matches(const common_params & params) { + if (params.smt_config_dir.empty()) { + return false; + } + if (params.media_backend != "smt" && params.media_backend != "auto") { + return false; + } + if (qwen3_tts_name_matches(params.model.path) || qwen3_tts_name_matches(params.model.name)) { + return true; + } + for (const auto & alias : params.model_alias) { + if (qwen3_tts_name_matches(alias)) { + return true; + } + } + + const std::filesystem::path dir(params.smt_config_dir); + return speech_path_exists((dir / "onnx" / "codec_decoder_t25.q.onnx").string()) || + speech_path_exists((dir / "gguf" / "qwen3-tts-0.6b-talker-qkv-gateup-q8_0-side.gguf").string()); +} + +static std::string speech_current_exe_dir() { + char path[4096]; + const ssize_t n = readlink("/proc/self/exe", path, sizeof(path) - 1); + if (n <= 0) { + return {}; + } + path[n] = '\0'; + return std::filesystem::path(path).parent_path().string(); +} + +static int speech_env_int(const char * name, int fallback) { + const char * value = std::getenv(name); + if (value == nullptr || value[0] == '\0') { + return fallback; + } + return std::atoi(value); +} + +static std::string speech_env_str(const char * name, const std::string & fallback) { + const char * value = std::getenv(name); + if (value == nullptr || value[0] == '\0') { + return fallback; + } + return value; +} + +static std::string qwen3_tts_find_runner() { + if (const char * env = std::getenv("Q3TTS_RUN_BIN"); env != nullptr && env[0] != '\0') { + return env; + } + const std::string exe_dir = speech_current_exe_dir(); + if (!exe_dir.empty()) { + const std::string colocated = (std::filesystem::path(exe_dir) / "q3tts-run").string(); + if (speech_path_exists(colocated)) { + return colocated; + } + } + return "q3tts-run"; +} + +static std::string qwen3_tts_default_ref_bin(const std::string & model_dir) { + const std::filesystem::path dir(model_dir); + const std::vector candidates = { + dir / "refs" / "default.spk.bin", + dir / "refs" / "default.prompt.bin", + dir / "refs" / "warm_female.spk.bin", + dir / "refs" / "warm_female_full_prompt.spk.bin", + }; + for (const auto & path : candidates) { + if (speech_path_exists(path.string())) { + return path.string(); + } + } + return {}; +} + +static bool speech_write_all_fd(int fd, const std::string & data) { + const char * ptr = data.data(); + size_t left = data.size(); + while (left > 0) { + const ssize_t n = write(fd, ptr, left); + if (n < 0) { + if (errno == EINTR) { + continue; + } + return false; + } + ptr += n; + left -= static_cast(n); + } + return true; +} + +static std::vector speech_read_file(const std::string & path) { + std::ifstream in(path, std::ios::binary); + if (!in) { + throw std::runtime_error("failed to open wav segment: " + path); + } + const std::vector bytes((std::istreambuf_iterator(in)), std::istreambuf_iterator()); + return std::vector(bytes.begin(), bytes.end()); +} + +static uint16_t speech_u16le(const uint8_t * p) { + return static_cast(p[0]) | (static_cast(p[1]) << 8); +} + +static uint32_t speech_u32le(const uint8_t * p) { + return static_cast(p[0]) | + (static_cast(p[1]) << 8) | + (static_cast(p[2]) << 16) | + (static_cast(p[3]) << 24); +} + +static void speech_put_u16le(std::string & out, uint16_t v) { + out.push_back(static_cast(v & 0xff)); + out.push_back(static_cast((v >> 8) & 0xff)); +} + +static void speech_put_u32le(std::string & out, uint32_t v) { + out.push_back(static_cast(v & 0xff)); + out.push_back(static_cast((v >> 8) & 0xff)); + out.push_back(static_cast((v >> 16) & 0xff)); + out.push_back(static_cast((v >> 24) & 0xff)); +} + +static std::vector speech_extract_pcm16_mono24k(const std::string & path) { + const std::vector wav = speech_read_file(path); + if (wav.size() < 44 || std::memcmp(wav.data(), "RIFF", 4) != 0 || std::memcmp(wav.data() + 8, "WAVE", 4) != 0) { + throw std::runtime_error("unsupported wav segment header: " + path); + } + + bool fmt_ok = false; + size_t data_offset = 0; + size_t data_size = 0; + for (size_t pos = 12; pos + 8 <= wav.size();) { + const uint8_t * chunk = wav.data() + pos; + const uint32_t size = speech_u32le(chunk + 4); + const size_t payload = pos + 8; + if (payload + size > wav.size()) { + break; + } + if (std::memcmp(chunk, "fmt ", 4) == 0 && size >= 16) { + const uint16_t format = speech_u16le(wav.data() + payload); + const uint16_t channels = speech_u16le(wav.data() + payload + 2); + const uint32_t sample_rate = speech_u32le(wav.data() + payload + 4); + const uint16_t bits = speech_u16le(wav.data() + payload + 14); + fmt_ok = format == 1 && channels == 1 && sample_rate == 24000 && bits == 16; + } else if (std::memcmp(chunk, "data", 4) == 0) { + data_offset = payload; + data_size = size; + } + pos = payload + size + (size & 1u); + } + if (!fmt_ok || data_offset == 0) { + throw std::runtime_error("unsupported wav segment format: " + path); + } + return std::vector(wav.begin() + static_cast(data_offset), + wav.begin() + static_cast(data_offset + data_size)); +} + +static std::string speech_make_wav(const std::vector & pcm) { + std::string out; + out.reserve(44 + pcm.size()); + out.append("RIFF", 4); + speech_put_u32le(out, static_cast(36 + pcm.size())); + out.append("WAVE", 4); + out.append("fmt ", 4); + speech_put_u32le(out, 16); + speech_put_u16le(out, 1); + speech_put_u16le(out, 1); + speech_put_u32le(out, 24000); + speech_put_u32le(out, 24000 * 2); + speech_put_u16le(out, 2); + speech_put_u16le(out, 16); + out.append("data", 4); + speech_put_u32le(out, static_cast(pcm.size())); + out.append(reinterpret_cast(pcm.data()), pcm.size()); + return out; +} + +static std::string speech_merge_pcm16_mono24k_wavs(const std::vector & paths, int pause_ms, double & audio_seconds) { + std::vector pcm; + const size_t pause_bytes = static_cast(std::max(0, pause_ms)) * 24 * 2; + for (size_t i = 0; i < paths.size(); ++i) { + if (i > 0) { + pcm.insert(pcm.end(), pause_bytes, 0); + } + auto segment = speech_extract_pcm16_mono24k(paths[i]); + pcm.insert(pcm.end(), segment.begin(), segment.end()); + } + audio_seconds = static_cast(pcm.size()) / (24000.0 * 2.0); + return speech_make_wav(pcm); +} + +static void speech_replace_all(std::string & text, const std::string & from, const std::string & to) { + if (from.empty()) { + return; + } + size_t pos = 0; + while ((pos = text.find(from, pos)) != std::string::npos) { + text.replace(pos, from.size(), to); + pos += to.size(); + } +} + +static std::vector> speech_hotwords_from_json(const json & body) { + std::vector> hotwords; + const json * raw = nullptr; + if (body.contains("hotwords")) { + raw = &body.at("hotwords"); + } else if (body.contains("lexicon")) { + raw = &body.at("lexicon"); + } + if (raw == nullptr) { + return hotwords; + } + if (raw->is_object()) { + for (const auto & item : raw->items()) { + if (item.value().is_string() && !item.key().empty()) { + hotwords.emplace_back(item.key(), item.value().get()); + } + } + } else if (raw->is_array()) { + for (const auto & item : *raw) { + if (!item.is_object()) { + continue; + } + std::string from; + std::string to; + for (const char * key : {"word", "from", "text"}) { + if (item.contains(key) && item.at(key).is_string()) { + from = item.at(key).get(); + break; + } + } + for (const char * key : {"phoneme", "to", "replacement"}) { + if (item.contains(key) && item.at(key).is_string()) { + to = item.at(key).get(); + break; + } + } + if (!from.empty() && !to.empty()) { + hotwords.emplace_back(std::move(from), std::move(to)); + } + } + } + std::sort(hotwords.begin(), hotwords.end(), [](const auto & a, const auto & b) { + return a.first.size() > b.first.size(); + }); + return hotwords; +} + +class qwen3_tts_backend : public server_speech_backend { + public: + explicit qwen3_tts_backend(const common_params & params); + ~qwen3_tts_backend() override; + + const char * name() const override { + return server_speech_qwen3_tts_name(); + } + server_speech_result synthesize(const json & body) override; + + private: + struct SegmentResult { + std::string wav_path; + std::string skip_reason; + std::string skip_text; + }; + + bool enabled = false; + std::string model_dir; + std::string ref_file; + int frames; + int pause_ms; + int ready_timeout_sec; + int request_timeout_sec; + + std::mutex start_mutex; + std::mutex request_mutex; + std::mutex state_mutex; + std::condition_variable state_cv; + std::deque> request_ranges; + std::map segment_results; + std::thread reader; + pid_t child_pid = -1; + int child_stdin = -1; + bool ready = false; + bool child_closed = false; + + server_speech_result synthesize_text(std::string text); + void ensure_started(); + void start_process(); + void read_loop(int fd); + void mark_closed(); + void stop(); +}; + +qwen3_tts_backend::qwen3_tts_backend(const common_params & params) : + enabled(qwen3_tts_config_matches(params)), + model_dir(params.smt_config_dir), + ref_file(params.vocoder.speaker_file), + frames(speech_env_int("Q3TTS_SERVICE_FRAMES", 160)), + pause_ms(speech_env_int("Q3TTS_SERVICE_PAUSE_MS", 200)), + ready_timeout_sec(speech_env_int("Q3TTS_SERVICE_READY_TIMEOUT", 60)), + request_timeout_sec(speech_env_int("Q3TTS_SERVICE_REQUEST_TIMEOUT", 180)) { + if (ref_file.empty()) { + ref_file = qwen3_tts_default_ref_bin(model_dir); + } + if (enabled && speech_env_int("Q3TTS_SERVICE_PREWARM", 1) != 0) { + const std::string text = speech_env_str("Q3TTS_SERVICE_PREWARM_TEXT", "你好,这是预热。"); + if (!text.empty()) { + try { + (void) synthesize_text(text); + } catch (const std::exception & e) { + SRV_WRN("Qwen3-TTS prewarm failed: %s\n", e.what()); + } + } else { + try { + ensure_started(); + } catch (const std::exception & e) { + SRV_WRN("Qwen3-TTS runner startup failed: %s\n", e.what()); + } + } + } +} + +qwen3_tts_backend::~qwen3_tts_backend() { + stop(); +} + +server_speech_result qwen3_tts_backend::synthesize(const json & body) { + std::string text; + if (body.contains("input") && body.at("input").is_string()) { + text = body.at("input").get(); + } else if (body.contains("text") && body.at("text").is_string()) { + text = body.at("text").get(); + } + if (text.empty()) { + throw std::invalid_argument("\"input\" must be a non-empty string"); + } + + const std::string response_format = json_value(body, "response_format", std::string("wav")); + if (!response_format.empty() && response_format != "wav") { + throw std::invalid_argument("Qwen3-TTS speech currently supports response_format=wav"); + } + + for (const auto & hotword : speech_hotwords_from_json(body)) { + speech_replace_all(text, hotword.first, hotword.second); + } + return synthesize_text(std::move(text)); +} + +server_speech_result qwen3_tts_backend::synthesize_text(std::string text) { + if (!enabled) { + throw std::runtime_error("Qwen3-TTS speech backend is not enabled"); + } + const auto t0 = std::chrono::steady_clock::now(); + std::lock_guard req_lock(request_mutex); + ensure_started(); + + if (!speech_write_all_fd(child_stdin, text + "\n")) { + throw std::runtime_error("failed to write Qwen3-TTS request"); + } + + std::pair range; + { + std::unique_lock lock(state_mutex); + const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(request_timeout_sec); + state_cv.wait_until(lock, deadline, [&]() { + return !request_ranges.empty() || child_closed; + }); + if (request_ranges.empty()) { + throw std::runtime_error(child_closed ? "Qwen3-TTS runner exited" : "Qwen3-TTS request timeout"); + } + range = request_ranges.front(); + request_ranges.pop_front(); + } + if (range.first <= 0 || range.second < range.first) { + throw std::runtime_error("empty text after Qwen3-TTS segmentation"); + } + + std::vector wav_paths; + { + std::unique_lock lock(state_mutex); + const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(request_timeout_sec); + state_cv.wait_until(lock, deadline, [&]() { + if (child_closed) { + return true; + } + for (int i = range.first; i <= range.second; ++i) { + if (segment_results.find(i) == segment_results.end()) { + return false; + } + } + return true; + }); + for (int i = range.first; i <= range.second; ++i) { + auto it = segment_results.find(i); + if (it == segment_results.end()) { + throw std::runtime_error(child_closed ? "Qwen3-TTS runner exited" : "Qwen3-TTS segment timeout"); + } + if (!it->second.skip_reason.empty()) { + throw std::runtime_error("Qwen3-TTS skipped segment: " + it->second.skip_text); + } + wav_paths.push_back(it->second.wav_path); + segment_results.erase(it); + } + } + + double audio_seconds = 0.0; + std::string wav = speech_merge_pcm16_mono24k_wavs(wav_paths, pause_ms, audio_seconds); + for (const auto & path : wav_paths) { + std::remove(path.c_str()); + } + const auto t1 = std::chrono::steady_clock::now(); + return { + std::move(wav), + name(), + range.second - range.first + 1, + audio_seconds, + std::chrono::duration(t1 - t0).count(), + }; +} + +void qwen3_tts_backend::ensure_started() { + std::lock_guard lock(start_mutex); + if (child_pid > 0 && !child_closed) { + return; + } + start_process(); +} + +void qwen3_tts_backend::start_process() { + int in_pipe[2] = {-1, -1}; + int out_pipe[2] = {-1, -1}; + if (pipe(in_pipe) != 0 || pipe(out_pipe) != 0) { + throw std::runtime_error("pipe failed for Qwen3-TTS runner"); + } + + const pid_t pid = fork(); + if (pid < 0) { + close(in_pipe[0]); + close(in_pipe[1]); + close(out_pipe[0]); + close(out_pipe[1]); + throw std::runtime_error("fork failed for Qwen3-TTS runner"); + } + + if (pid == 0) { + dup2(in_pipe[0], STDIN_FILENO); + dup2(out_pipe[1], STDOUT_FILENO); + dup2(out_pipe[1], STDERR_FILENO); + close(in_pipe[0]); + close(in_pipe[1]); + close(out_pipe[0]); + close(out_pipe[1]); + + setenv("Q3TTS_MODEL_DIR", model_dir.c_str(), 1); + const std::string frames_str = std::to_string(frames); + setenv("Q3TTS_STDIN_MAX_FRAMES", frames_str.c_str(), 1); + + std::vector args = { + qwen3_tts_find_runner(), + "--stdin-segments", + "--no-clone-split", + "--frames", + frames_str, + "--wav", + "/tmp/qwen3_tts_server_merged.wav", + }; + if (!model_dir.empty()) { + args.push_back("--model-dir"); + args.push_back(model_dir); + } + if (!ref_file.empty()) { + const bool is_wav = ref_file.size() >= 4 && ref_file.substr(ref_file.size() - 4) == ".wav"; + args.push_back(is_wav ? "--ref-wav" : "--ref-bin"); + args.push_back(ref_file); + } + + std::vector argv; + argv.reserve(args.size() + 1); + for (auto & arg : args) { + argv.push_back(arg.data()); + } + argv.push_back(nullptr); + execv(argv[0], argv.data()); + _exit(127); + } + + close(in_pipe[0]); + close(out_pipe[1]); + child_pid = pid; + child_stdin = in_pipe[1]; + ready = false; + child_closed = false; + request_ranges.clear(); + segment_results.clear(); + reader = std::thread([this, fd = out_pipe[0]]() { read_loop(fd); }); + + std::unique_lock state_lock(state_mutex); + const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(ready_timeout_sec); + state_cv.wait_until(state_lock, deadline, [&]() { + return ready || child_closed; + }); + if (!ready) { + state_lock.unlock(); + if (child_stdin >= 0) { + close(child_stdin); + child_stdin = -1; + } + if (child_pid > 0) { + int status = 0; + kill(child_pid, SIGTERM); + waitpid(child_pid, &status, 0); + child_pid = -1; + } + if (reader.joinable()) { + reader.join(); + } + mark_closed(); + throw std::runtime_error("Qwen3-TTS runner did not become ready"); + } +} + +void qwen3_tts_backend::read_loop(int fd) { + FILE * stream = fdopen(fd, "r"); + if (stream == nullptr) { + close(fd); + mark_closed(); + return; + } + + char * line = nullptr; + size_t cap = 0; + const std::regex segment_re("^stream_segment\\s+([0-9]+)\\b.*\\swav\\s+(\\S+)"); + const std::regex skip_re("^stream_segment_skip\\s+([0-9]+)\\s+reason\\s+(\\S+)\\s+text\\s+(.*)"); + const std::regex truncated_re("^stream_segment_truncated\\s+([0-9]+)\\s+frames\\s+([0-9]+)\\s+max\\s+([0-9]+)\\s+text\\s+(.*)"); + const std::regex request_re("^stream_request\\s+([0-9]+)\\s+([0-9]+)"); + while (getline(&line, &cap, stream) >= 0) { + std::string msg(line); + if (!msg.empty() && msg.back() == '\n') { + msg.pop_back(); + } + if (msg.find("talker_stdin_ready") != std::string::npos) { + std::lock_guard lock(state_mutex); + ready = true; + state_cv.notify_all(); + continue; + } + + std::smatch match; + if (std::regex_match(msg, match, request_re)) { + std::lock_guard lock(state_mutex); + request_ranges.emplace_back(std::stoi(match[1].str()), std::stoi(match[2].str())); + state_cv.notify_all(); + continue; + } + if (std::regex_match(msg, match, truncated_re)) { + std::lock_guard lock(state_mutex); + auto & result = segment_results[std::stoi(match[1].str())]; + result.skip_reason = "truncated"; + result.skip_text = match[4].str(); + state_cv.notify_all(); + continue; + } + if (std::regex_match(msg, match, segment_re)) { + std::lock_guard lock(state_mutex); + segment_results[std::stoi(match[1].str())].wav_path = match[2].str(); + state_cv.notify_all(); + continue; + } + if (std::regex_match(msg, match, skip_re)) { + std::lock_guard lock(state_mutex); + auto & result = segment_results[std::stoi(match[1].str())]; + result.skip_reason = match[2].str(); + result.skip_text = match[3].str(); + state_cv.notify_all(); + continue; + } + } + free(line); + fclose(stream); + mark_closed(); +} + +void qwen3_tts_backend::mark_closed() { + std::lock_guard lock(state_mutex); + child_closed = true; + state_cv.notify_all(); +} + +void qwen3_tts_backend::stop() { + { + std::lock_guard lock(start_mutex); + if (child_stdin >= 0) { + close(child_stdin); + child_stdin = -1; + } + if (child_pid > 0) { + int status = 0; + if (waitpid(child_pid, &status, WNOHANG) == 0) { + kill(child_pid, SIGTERM); + waitpid(child_pid, &status, 0); + } + child_pid = -1; + } + } + if (reader.joinable()) { + reader.join(); + } + mark_closed(); +} + +const char * server_speech_qwen3_tts_name() { + return "qwen3-tts"; +} + +bool server_speech_qwen3_tts_config_matches(const common_params & params) { + return qwen3_tts_config_matches(params); +} + +std::unique_ptr server_speech_qwen3_tts_create(const common_params & params) { + return std::make_unique(params); +} diff --git a/tools/server/server-speech-qwen3-tts.h b/tools/server/server-speech-qwen3-tts.h new file mode 100644 index 000000000000..aa74262feb6d --- /dev/null +++ b/tools/server/server-speech-qwen3-tts.h @@ -0,0 +1,12 @@ +#pragma once + +#include "common.h" + +#include +#include + +class server_speech_backend; + +const char * server_speech_qwen3_tts_name(); +bool server_speech_qwen3_tts_config_matches(const common_params & params); +std::unique_ptr server_speech_qwen3_tts_create(const common_params & params); diff --git a/tools/server/server-speech.cpp b/tools/server/server-speech.cpp new file mode 100644 index 000000000000..dafa573dd923 --- /dev/null +++ b/tools/server/server-speech.cpp @@ -0,0 +1,62 @@ +#include "server-speech.h" + +#include "server-speech-backend.h" +#include "server-speech-qwen3-tts.h" + +#include + +#include +#include + +using json = nlohmann::ordered_json; + +static std::string server_speech_backend_name_for_config(const common_params & params) { + if (server_speech_qwen3_tts_config_matches(params)) { + return server_speech_qwen3_tts_name(); + } + return {}; +} + +static std::unique_ptr server_speech_create_backend(const common_params & params) { + if (server_speech_qwen3_tts_config_matches(params)) { + return server_speech_qwen3_tts_create(params); + } + return nullptr; +} + +bool server_speech_config_matches(const common_params & params) { + return !server_speech_backend_name_for_config(params).empty(); +} + +std::string server_speech_backend_name(const common_params & params) { + return server_speech_backend_name_for_config(params); +} + +struct server_speech_service::impl { + explicit impl(const common_params & params) : + backend(server_speech_create_backend(params)) { + } + + server_speech_result synthesize(const json & body) { + if (!backend) { + throw std::runtime_error("speech backend is not enabled"); + } + server_speech_result result = backend->synthesize(body); + if (result.backend.empty()) { + result.backend = backend->name(); + } + return result; + } + + std::unique_ptr backend; +}; + +server_speech_service::server_speech_service(const common_params & params) : + pimpl(new impl(params)) { +} + +server_speech_service::~server_speech_service() = default; + +server_speech_result server_speech_service::synthesize(const json & body) { + return pimpl->synthesize(body); +} diff --git a/tools/server/server-speech.h b/tools/server/server-speech.h new file mode 100644 index 000000000000..ce54631ec27b --- /dev/null +++ b/tools/server/server-speech.h @@ -0,0 +1,31 @@ +#pragma once + +#include "common.h" + +#include + +#include +#include + +struct server_speech_result { + std::string wav; + std::string backend; + int segments = 0; + double audio_seconds = 0.0; + double wall_seconds = 0.0; +}; + +bool server_speech_config_matches(const common_params & params); +std::string server_speech_backend_name(const common_params & params); + +class server_speech_service { + public: + explicit server_speech_service(const common_params & params); + ~server_speech_service(); + + server_speech_result synthesize(const nlohmann::ordered_json & body); + + private: + struct impl; + std::unique_ptr pimpl; +}; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 71d9efaa49d7..32170d801970 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -162,6 +162,7 @@ int llama_server(int argc, char ** argv) { routes.post_control = models_routes->proxy_post; routes.post_responses_oai = models_routes->proxy_post; routes.post_transcriptions_oai = models_routes->proxy_post; + routes.post_speech_oai = models_routes->proxy_post; routes.post_anthropic_messages = models_routes->proxy_post; routes.post_anthropic_count_tokens = models_routes->proxy_post; routes.post_infill = models_routes->proxy_post; @@ -202,6 +203,8 @@ int llama_server(int argc, char ** argv) { ctx_http.post("/responses", ex_wrapper(routes.post_responses_oai)); ctx_http.post("/v1/audio/transcriptions", ex_wrapper(routes.post_transcriptions_oai)); ctx_http.post("/audio/transcriptions", ex_wrapper(routes.post_transcriptions_oai)); + ctx_http.post("/v1/audio/speech", ex_wrapper(routes.post_speech_oai)); + ctx_http.post("/audio/speech", ex_wrapper(routes.post_speech_oai)); ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting ctx_http.post("/infill", ex_wrapper(routes.post_infill)); diff --git a/tools/speech/CMakeLists.txt b/tools/speech/CMakeLists.txt new file mode 100644 index 000000000000..1a5e24be35ee --- /dev/null +++ b/tools/speech/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(backends/qwen3_tts) + +add_custom_target(speech-install + DEPENDS q3tts-install +) diff --git a/tools/speech/README.md b/tools/speech/README.md new file mode 100644 index 000000000000..b022b0d78d14 --- /dev/null +++ b/tools/speech/README.md @@ -0,0 +1,18 @@ +# Speech Synthesis Backends + +This directory owns speech output backends for `llama-server` and the +OpenAI-compatible `/v1/audio/speech` endpoint. + +The server-facing layer is `tools/server/server-speech.*`. It should expose a +stable request/result boundary and keep HTTP endpoint behavior shared across +backends. + +Each backend lives under `tools/speech/backends//` and should keep +model-specific launchers, converters, kernels, and runtime code inside that +backend directory. Shared speech utilities should be added under `tools/speech` +instead of being copied into every backend. + +Current backends: + +- `qwen3_tts`: Qwen3-TTS runner, talker driver, codec runtime, ref-bin tooling, + and K3-specific runtime packaging. diff --git a/tools/speech/backends/qwen3_tts/CMakeLists.txt b/tools/speech/backends/qwen3_tts/CMakeLists.txt new file mode 100644 index 000000000000..03d1ed674d83 --- /dev/null +++ b/tools/speech/backends/qwen3_tts/CMakeLists.txt @@ -0,0 +1,286 @@ +include(GNUInstallDirs) + +set(Q3TTS_TARGET_PREFIX llama-q3tts) +set(Q3TTS_SCRIPT_OUTPUT_DIR ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) +if(NOT Q3TTS_SCRIPT_OUTPUT_DIR) + set(Q3TTS_SCRIPT_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) +endif() + +set(Q3TTS_ONNXRUNTIME_ROOT "" CACHE PATH "Root of an ONNX Runtime installation used by Qwen3-TTS") +set(Q3TTS_SPACEMIT_EP_ROOT "" CACHE PATH "Root of a SpacemiT EP installation copied into q3tts runtime packages") +option(Q3TTS_COPY_RUNTIME_LIBS "Copy ONNX Runtime and SpacemiT EP shared libraries into q3tts-install prefix" ON) + +set(Q3TTS_ONNXRUNTIME_INCLUDE_HINTS) +set(Q3TTS_ONNXRUNTIME_LIB_HINTS) +foreach(_q3tts_ort_root IN ITEMS "${Q3TTS_ONNXRUNTIME_ROOT}" "$ENV{Q3TTS_ONNXRUNTIME_ROOT}") + if(_q3tts_ort_root) + list(APPEND Q3TTS_ONNXRUNTIME_INCLUDE_HINTS + ${_q3tts_ort_root}/include + ${_q3tts_ort_root}/include/onnxruntime + ${_q3tts_ort_root}/include/onnxruntime/core/session + ) + list(APPEND Q3TTS_ONNXRUNTIME_LIB_HINTS + ${_q3tts_ort_root}/lib + ${_q3tts_ort_root}/lib64 + ) + endif() +endforeach() + +find_path(Q3TTS_ONNXRUNTIME_INCLUDE_DIR + NAMES onnxruntime_cxx_api.h + PATHS ${Q3TTS_ONNXRUNTIME_INCLUDE_HINTS} + NO_DEFAULT_PATH +) +find_library(Q3TTS_ONNXRUNTIME_LIB + NAMES onnxruntime + PATHS ${Q3TTS_ONNXRUNTIME_LIB_HINTS} + NO_DEFAULT_PATH +) +if(NOT Q3TTS_ONNXRUNTIME_INCLUDE_DIR) + find_path(Q3TTS_ONNXRUNTIME_INCLUDE_DIR + NAMES onnxruntime_cxx_api.h + PATHS + /usr/include + /usr/include/onnxruntime + /usr/include/onnxruntime/core/session + /usr/local/include + /usr/local/include/onnxruntime + /usr/local/include/onnxruntime/core/session + ) +endif() +if(NOT Q3TTS_ONNXRUNTIME_LIB) + find_library(Q3TTS_ONNXRUNTIME_LIB + NAMES onnxruntime + PATHS + /usr/lib + /usr/local/lib + ) +endif() + +set(Q3TTS_SPACEMIT_EP_LIB_HINTS) +foreach(_q3tts_ep_root IN ITEMS "${Q3TTS_SPACEMIT_EP_ROOT}" "$ENV{Q3TTS_SPACEMIT_EP_ROOT}") + if(_q3tts_ep_root) + list(APPEND Q3TTS_SPACEMIT_EP_LIB_HINTS + ${_q3tts_ep_root}/lib + ${_q3tts_ep_root}/lib64 + ) + endif() +endforeach() +find_library(Q3TTS_SPACEMIT_EP_LIB + NAMES spacemit_ep + PATHS ${Q3TTS_SPACEMIT_EP_LIB_HINTS} + NO_DEFAULT_PATH +) +if(NOT Q3TTS_SPACEMIT_EP_LIB) + find_library(Q3TTS_SPACEMIT_EP_LIB + NAMES spacemit_ep + PATHS /usr/lib /usr/local/lib + ) +endif() + +if(NOT Q3TTS_ONNXRUNTIME_INCLUDE_DIR OR NOT Q3TTS_ONNXRUNTIME_LIB) + message(STATUS "Qwen3-TTS tools disabled; ONNX Runtime headers/library not found") + return() +endif() +message(STATUS "Qwen3-TTS ONNX Runtime include: ${Q3TTS_ONNXRUNTIME_INCLUDE_DIR}") +message(STATUS "Qwen3-TTS ONNX Runtime library: ${Q3TTS_ONNXRUNTIME_LIB}") +if(Q3TTS_SPACEMIT_EP_LIB) + message(STATUS "Qwen3-TTS SpacemiT EP library: ${Q3TTS_SPACEMIT_EP_LIB}") +else() + message(STATUS "Qwen3-TTS SpacemiT EP library not found; runtime will rely on system lookup if needed") +endif() + +set(Q3TTS_INCLUDE_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/include/qwen3_tts + ${Q3TTS_ONNXRUNTIME_INCLUDE_DIR} + ${CMAKE_SOURCE_DIR}/include + ${CMAKE_SOURCE_DIR}/ggml/include +) + +add_library(${Q3TTS_TARGET_PREFIX}-runtime STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/src/qwen3_tts_runtime.cpp +) +target_include_directories(${Q3TTS_TARGET_PREFIX}-runtime + PUBLIC + ${Q3TTS_INCLUDE_DIRS} +) +target_link_libraries(${Q3TTS_TARGET_PREFIX}-runtime + PUBLIC + ${Q3TTS_ONNXRUNTIME_LIB} + ggml-base + PRIVATE + dl + pthread + m +) +target_compile_features(${Q3TTS_TARGET_PREFIX}-runtime PRIVATE cxx_std_17) + +find_path(Q3TTS_ALSA_INCLUDE_DIR + NAMES alsa/asoundlib.h + PATHS /usr/include /usr/local/include +) +find_library(Q3TTS_ALSA_LIB + NAMES asound + PATHS /usr/lib /usr/local/lib +) +if(Q3TTS_ALSA_INCLUDE_DIR AND Q3TTS_ALSA_LIB) + target_compile_definitions(${Q3TTS_TARGET_PREFIX}-runtime PUBLIC Q3TTS_ENABLE_SDK_AUDIO) + target_include_directories(${Q3TTS_TARGET_PREFIX}-runtime PUBLIC ${Q3TTS_ALSA_INCLUDE_DIR}) + target_link_libraries(${Q3TTS_TARGET_PREFIX}-runtime PUBLIC ${Q3TTS_ALSA_LIB}) +endif() + +add_executable(q3tts-runner ${CMAKE_CURRENT_SOURCE_DIR}/tools/q3tts_run_main.cpp) +target_link_libraries(q3tts-runner PRIVATE ${Q3TTS_TARGET_PREFIX}-runtime) +target_compile_features(q3tts-runner PRIVATE cxx_std_17) + +add_executable(q3tts-ref-to-bin ${CMAKE_CURRENT_SOURCE_DIR}/tools/q3tts_ref_to_bin.cpp) +target_include_directories(q3tts-ref-to-bin PRIVATE ${Q3TTS_INCLUDE_DIRS}) +target_link_libraries(q3tts-ref-to-bin PRIVATE + ${Q3TTS_ONNXRUNTIME_LIB} + ggml-base + pthread + m +) +target_compile_features(q3tts-ref-to-bin PRIVATE cxx_std_17) + +add_executable(talker_driver.headmain ${CMAKE_CURRENT_SOURCE_DIR}/src/talker_driver.c) +target_compile_definitions(talker_driver.headmain PRIVATE _GNU_SOURCE) +target_compile_options(talker_driver.headmain PRIVATE + $<$:-O2> + $<$:-fno-tree-vectorize> +) +if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(riscv)") + target_compile_options(talker_driver.headmain PRIVATE + $<$:-march=rv64gcv_zfh_zvfh_zba_zicbop_zihintpause> + $<$:-mabi=lp64d> + ) +endif() +target_include_directories(talker_driver.headmain PRIVATE + ${Q3TTS_INCLUDE_DIRS} + ${CMAKE_CURRENT_SOURCE_DIR}/src/kernels +) +target_link_libraries(talker_driver.headmain PRIVATE + llama + ggml-cpu + ggml-base + ggml + pthread + m +) + +add_executable(q3tts-cp-kernel-bench ${CMAKE_CURRENT_SOURCE_DIR}/tools/q3tts_cp_kernel_bench.cpp) +target_include_directories(q3tts-cp-kernel-bench PRIVATE + ${CMAKE_SOURCE_DIR}/ggml/src + ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cpu +) +target_link_libraries(q3tts-cp-kernel-bench PRIVATE + ggml-cpu + ggml-base + ggml + pthread + m +) +target_compile_features(q3tts-cp-kernel-bench PRIVATE cxx_std_17) + +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/q3tts-run.in + ${Q3TTS_SCRIPT_OUTPUT_DIR}/q3tts-run @ONLY) +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/talker_driver.qwen3tts-k3.in + ${Q3TTS_SCRIPT_OUTPUT_DIR}/talker_driver.qwen3tts-k3 @ONLY) + +execute_process(COMMAND chmod +x + ${Q3TTS_SCRIPT_OUTPUT_DIR}/q3tts-run + ${Q3TTS_SCRIPT_OUTPUT_DIR}/talker_driver.qwen3tts-k3 +) + +get_filename_component(Q3TTS_ONNXRUNTIME_LIB_DIR ${Q3TTS_ONNXRUNTIME_LIB} DIRECTORY) +set(Q3TTS_ONNXRUNTIME_RUNTIME_LIBS) +if(Q3TTS_COPY_RUNTIME_LIBS) + file(GLOB Q3TTS_ONNXRUNTIME_RUNTIME_LIBS + ${Q3TTS_ONNXRUNTIME_LIB_DIR}/libonnxruntime.so* + ) +endif() +set(Q3TTS_SPACEMIT_EP_RUNTIME_LIBS) +if(Q3TTS_COPY_RUNTIME_LIBS AND Q3TTS_SPACEMIT_EP_LIB) + get_filename_component(Q3TTS_SPACEMIT_EP_LIB_DIR ${Q3TTS_SPACEMIT_EP_LIB} DIRECTORY) + file(GLOB Q3TTS_SPACEMIT_EP_RUNTIME_LIBS + ${Q3TTS_SPACEMIT_EP_LIB_DIR}/libspacemit_ep.so* + ) +endif() +set(Q3TTS_RUNTIME_LIB_COPY_COMMANDS) +foreach(_q3tts_runtime_lib IN LISTS Q3TTS_ONNXRUNTIME_RUNTIME_LIBS Q3TTS_SPACEMIT_EP_RUNTIME_LIBS) + list(APPEND Q3TTS_RUNTIME_LIB_COPY_COMMANDS + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${_q3tts_runtime_lib} + ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/ + ) +endforeach() +set(Q3TTS_BUILD_RUNTIME_LIB_COPY_COMMANDS) +if(Q3TTS_COPY_RUNTIME_LIBS) + set(Q3TTS_BUILD_RUNTIME_LIB_DIR ${Q3TTS_SCRIPT_OUTPUT_DIR}/../${CMAKE_INSTALL_LIBDIR}) + list(APPEND Q3TTS_BUILD_RUNTIME_LIB_COPY_COMMANDS + COMMAND ${CMAKE_COMMAND} -E make_directory ${Q3TTS_BUILD_RUNTIME_LIB_DIR} + ) + foreach(_q3tts_runtime_lib IN LISTS Q3TTS_ONNXRUNTIME_RUNTIME_LIBS Q3TTS_SPACEMIT_EP_RUNTIME_LIBS) + list(APPEND Q3TTS_BUILD_RUNTIME_LIB_COPY_COMMANDS + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${_q3tts_runtime_lib} + ${Q3TTS_BUILD_RUNTIME_LIB_DIR}/ + ) + endforeach() +endif() +if(Q3TTS_BUILD_RUNTIME_LIB_COPY_COMMANDS) + add_custom_command(TARGET q3tts-runner POST_BUILD + ${Q3TTS_BUILD_RUNTIME_LIB_COPY_COMMANDS} + VERBATIM + ) +endif() + +set(Q3TTS_LLAMA_SERVER_COPY_COMMANDS) +set(Q3TTS_LLAMA_SERVER_DEPENDS) +if(TARGET llama-server) + list(APPEND Q3TTS_LLAMA_SERVER_COPY_COMMANDS + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $ + ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR}/llama-server + ) + list(APPEND Q3TTS_LLAMA_SERVER_DEPENDS llama-server) +endif() + +add_custom_target(q3tts-install + COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR} + COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/$ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/$ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/$ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/$ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/$ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/$ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/$ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/$ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/$ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/$ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/$ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/$ + ${Q3TTS_RUNTIME_LIB_COPY_COMMANDS} + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR}/q3tts-runner + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR}/q3tts-ref-to-bin + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR}/talker_driver.headmain + ${Q3TTS_LLAMA_SERVER_COPY_COMMANDS} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${Q3TTS_SCRIPT_OUTPUT_DIR}/q3tts-run ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR}/q3tts-run + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${Q3TTS_SCRIPT_OUTPUT_DIR}/talker_driver.qwen3tts-k3 ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR}/talker_driver.qwen3tts-k3 + COMMAND chmod +x + ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR}/q3tts-run + ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR}/talker_driver.qwen3tts-k3 + DEPENDS q3tts-runner q3tts-ref-to-bin talker_driver.headmain ${Q3TTS_LLAMA_SERVER_DEPENDS} + VERBATIM +) + +if(LLAMA_TOOLS_INSTALL) + install(TARGETS q3tts-runner q3tts-ref-to-bin talker_driver.headmain RUNTIME) + install(PROGRAMS + ${Q3TTS_SCRIPT_OUTPUT_DIR}/q3tts-run + ${Q3TTS_SCRIPT_OUTPUT_DIR}/talker_driver.qwen3tts-k3 + TYPE BIN + ) +endif() diff --git a/tools/speech/backends/qwen3_tts/README.md b/tools/speech/backends/qwen3_tts/README.md new file mode 100644 index 000000000000..dc857fb20f48 --- /dev/null +++ b/tools/speech/backends/qwen3_tts/README.md @@ -0,0 +1,61 @@ +# SpaceMIT Speech Qwen3-TTS Backend + +This directory builds the SpaceMIT Qwen3-TTS speech backend inside `llama.cpp`. +It contains the Qwen3-TTS runtime, the q3tts runner, the talker driver, and the +prompt reference converter. + +The public HTTP entrypoint is `llama-server` with the OpenAI-compatible +`/v1/audio/speech` endpoint. The TTS model zoo component should call that +endpoint; it should not embed this runtime code. + +## Build + +Qwen3-TTS depends on ONNX Runtime for codec decoding and optionally copies the +SpaceMIT EP shared library into the runtime prefix. Pass dependency roots +explicitly when building from source packages: + +```bash +cmake -B build-q3tts -S . \ + -DLLAMA_BUILD_SPEECH=ON \ + -DQ3TTS_ONNXRUNTIME_ROOT=/path/to/onnxruntime/install \ + -DQ3TTS_SPACEMIT_EP_ROOT=/path/to/spacemit-ep/install \ + -DGGML_CPU_RISCV64_SPACEMIT=ON + +cmake --build build-q3tts --parallel 8 --target q3tts-install +``` + +`q3tts-install` installs `llama-server`, `q3tts-ref-to-bin`, `q3tts-run`, +`q3tts-runner`, `talker_driver.headmain`, `talker_driver.qwen3tts-k3`, and the +shared runtime libraries needed by the launcher into `CMAKE_INSTALL_PREFIX`. +`q3tts-run` is an internal/diagnostic runner used by the server-side speech +backend. + +`LLAMA_BUILD_Q3TTS=ON` is kept as a compatibility alias for existing build +scripts, but new integrations should use `LLAMA_BUILD_SPEECH=ON`. + +## Run + +```bash +llama-server \ + -m ${HOME}/.cache/models/tts/qwen3-tts/gguf/qwen3-tts-0.6b-talker-qkv-gateup-q8_0-side.gguf \ + --media-backend smt \ + --smt-config-dir ${HOME}/.cache/models/tts/qwen3-tts \ + --host 127.0.0.1 \ + --port 8090 \ + --alias qwen3-tts \ + --tts-speaker-file ${HOME}/.cache/models/tts/qwen3-tts/refs/default.spk.bin \ + -t 4 \ + -c 128 +``` + +The server exposes an OpenAI-compatible `/v1/audio/speech` endpoint for the +model-zoo TTS backend. `--tts-speaker-file` accepts a `.spk.bin` reference bin +or a `.wav` reference audio file. If it is omitted, the backend looks for a +default reference under `${smt-config-dir}/refs`. + +The runtime defaults to the Q8 side-preserving talker GGUF because it is the +validated quality default. Q4 remains available as an explicit performance mode: + +```bash +Q3TTS_TALKER_GGUF=qwen3-tts-0.6b-talker-qkv-gateup-q4_0-side.gguf llama-server ... +``` diff --git a/tools/speech/backends/qwen3_tts/cmake/q3tts-run.in b/tools/speech/backends/qwen3_tts/cmake/q3tts-run.in new file mode 100644 index 000000000000..058c3d18279f --- /dev/null +++ b/tools/speech/backends/qwen3_tts/cmake/q3tts-run.in @@ -0,0 +1,63 @@ +#!/bin/sh +set -e + +root=$(CDPATH= cd -- "$(dirname -- "$0")/.." && pwd) +model=${Q3TTS_MODEL_DIR:-${HOME:-/tmp}/.cache/models/tts/qwen3-tts} +testdata=${Q3TTS_TESTDATA_DIR:-$root/examples/testdata} +work=${Q3TTS_RUN_DIR:-/tmp/q3tts-k3-run} +mkdir -p "$work" + +if [ -d "$model/onnx" ]; then + for name in codec_decoder_t25.q.onnx codec_decoder_t50.q.onnx codec_decoder_t75.q.onnx; do + [ -e "$model/onnx/$name" ] && ln -sfn "$model/onnx/$name" "$work/$name" + done +fi +if [ -d "$testdata" ]; then + [ -e "$testdata/e2e_long.npz" ] && ln -sfn "$testdata/e2e_long.npz" "$work/e2e_long.npz" + [ -e "$testdata/e2e_spk.npz" ] && ln -sfn "$testdata/e2e_spk.npz" "$work/e2e_spk.npz" +fi + +cd "$work" +if [ -z "${Q3TTS_SPACEMIT_EP_LIB+x}" ]; then + for ep in "$root/lib/libspacemit_ep.so" "$root/lib/libspacemit_ep.so.2" "$root/lib/libspacemit_ep.so.2.0.3"; do + if [ -e "$ep" ]; then + export Q3TTS_SPACEMIT_EP_LIB="$ep" + break + fi + done +fi +if [ -z "${Q3TTS_DISABLE_SWIGLU_DOWN_FUSION+x}" ]; then + export GGML_CPU_FUSE_SWIGLU_DOWN_Q8=1 +else + unset GGML_CPU_FUSE_SWIGLU_DOWN_Q8 +fi + +exec env \ + LD_LIBRARY_PATH="$root/lib:${LD_LIBRARY_PATH:-}" \ + Q3TTS_CODEC_BUCKETS="${Q3TTS_CODEC_BUCKETS:-50}" \ + Q3TTS_CODEC_FIRST_CHUNK="${Q3TTS_CODEC_FIRST_CHUNK:-50}" \ + Q3TTS_CODEC_CHUNK="${Q3TTS_CODEC_CHUNK:-50}" \ + Q3TTS_OUTPUT_ONLY_EMBEDDINGS="${Q3TTS_OUTPUT_ONLY_EMBEDDINGS:-1}" \ + Q3TTS_CP_EMBED_ONLY="${Q3TTS_CP_EMBED_ONLY:-1}" \ + Q3TTS_HEADS_BASE="${Q3TTS_HEADS_BASE:-2}" \ + Q3TTS_HEADS_MAIN_WORK="${Q3TTS_HEADS_MAIN_WORK:-1}" \ + Q3TTS_HEAD_ARGMAX="${Q3TTS_HEAD_ARGMAX:-1}" \ + Q3TTS_CP_META_CLEAR="${Q3TTS_CP_META_CLEAR:-1}" \ + Q3TTS_CP_PROFILE="${Q3TTS_CP_PROFILE:-0}" \ + Q3TTS_LLAMA_THREADPOOL_POLL="${Q3TTS_LLAMA_THREADPOOL_POLL:-100}" \ + Q3TTS_CP_ACTIVE_HEADS="${Q3TTS_CP_ACTIVE_HEADS:-15}" \ + Q3TTS_CP_FILL_CODE="${Q3TTS_CP_FILL_CODE:-0}" \ + Q3TTS_TALKER_CTX="${Q3TTS_TALKER_CTX:-128}" \ + Q3TTS_TALKER_BATCH="${Q3TTS_TALKER_BATCH:-128}" \ + Q3TTS_TALKER_UBATCH="${Q3TTS_TALKER_UBATCH:-16}" \ + Q3TTS_NOREF_CODEC_BUCKETS="${Q3TTS_NOREF_CODEC_BUCKETS:-25}" \ + Q3TTS_NOREF_CODEC_FIRST_CHUNK="${Q3TTS_NOREF_CODEC_FIRST_CHUNK:-25}" \ + Q3TTS_NOREF_CODEC_CHUNK="${Q3TTS_NOREF_CODEC_CHUNK:-25}" \ + Q3TTS_CODEC_THREADS="${Q3TTS_CODEC_THREADS:-4}" \ + Q3TTS_CODEC_AFFINITY="${Q3TTS_CODEC_AFFINITY:-14;15;8;9}" \ + Q3TTS_MODEL_DIR="$model" \ + SPACEMIT_Q4_HP_M1_N64="${SPACEMIT_Q4_HP_M1_N64:-1}" \ + SPACEMIT_PERFER_CORE_ID="${SPACEMIT_PERFER_CORE_ID:-10,11,12,13}" \ + TALKER_DRIVER="${TALKER_DRIVER:-$root/bin/talker_driver.qwen3tts-k3}" \ + TALKER_CPUSET="${TALKER_CPUSET:-4-7}" \ + "$root/bin/q3tts-runner" --npz e2e_long.npz --wav /tmp/q3tts.wav "$@" diff --git a/tools/speech/backends/qwen3_tts/cmake/talker_driver.qwen3tts-k3.in b/tools/speech/backends/qwen3_tts/cmake/talker_driver.qwen3tts-k3.in new file mode 100644 index 000000000000..ba4cd4895e30 --- /dev/null +++ b/tools/speech/backends/qwen3_tts/cmake/talker_driver.qwen3tts-k3.in @@ -0,0 +1,23 @@ +#!/bin/sh +set -e + +dir=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) +root=$(CDPATH= cd -- "$dir/.." && pwd) +model=${Q3TTS_MODEL_DIR:-${HOME:-/tmp}/.cache/models/tts/qwen3-tts} +: "${SPACEMIT_Q4_HP_M1_N64:=1}" +export SPACEMIT_Q4_HP_M1_N64 + +resolve_gguf() { + case "$1" in + /*|*/*) printf '%s\n' "$1" ;; + *) printf '%s\n' "$model/gguf/$1" ;; + esac +} + +if [ "$#" -ge 2 ]; then + exec "$dir/talker_driver.headmain" "$@" +fi + +talker=$(resolve_gguf "${Q3TTS_TALKER_GGUF:-qwen3-tts-0.6b-talker-qkv-gateup-q8_0-side.gguf}") +cp=$(resolve_gguf "${Q3TTS_CP_GGUF:-qwen3-tts-0.6b-cp-qkv-gateup-rawq4.gguf}") +exec "$dir/talker_driver.headmain" "$talker" "$cp" "$@" diff --git a/tools/speech/backends/qwen3_tts/include/qwen3_tts/q3tts_audio_sdk.h b/tools/speech/backends/qwen3_tts/include/qwen3_tts/q3tts_audio_sdk.h new file mode 100644 index 000000000000..8ec1f848a6fb --- /dev/null +++ b/tools/speech/backends/qwen3_tts/include/qwen3_tts/q3tts_audio_sdk.h @@ -0,0 +1,285 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef Q3TTS_ENABLE_SDK_AUDIO +#include +#endif + +namespace q3tts_audio { + +#ifdef Q3TTS_ENABLE_SDK_AUDIO + +using Clock = std::chrono::steady_clock; + +inline std::vector resample_pcm16_linear(const std::vector &input, + int input_rate, + int output_rate, + int channels) { + if (input.empty() || input_rate == output_rate) { + return input; + } + if (input_rate <= 0 || output_rate <= 0 || channels <= 0) { + throw std::runtime_error("invalid audio resample config"); + } + + const size_t in_frames = input.size() / static_cast(channels); + if (in_frames == 0) { + return {}; + } + const size_t out_frames = std::max( + 1, static_cast( + (static_cast(in_frames) * static_cast(output_rate) + + static_cast(input_rate) / 2) / + static_cast(input_rate))); + std::vector output(out_frames * static_cast(channels)); + const double scale = static_cast(input_rate) / static_cast(output_rate); + for (size_t out_frame = 0; out_frame < out_frames; ++out_frame) { + const double src = static_cast(out_frame) * scale; + const size_t i0 = std::min(static_cast(src), in_frames - 1); + const size_t i1 = std::min(i0 + 1, in_frames - 1); + const double frac = src - static_cast(i0); + for (int ch = 0; ch < channels; ++ch) { + const int16_t a = input[i0 * static_cast(channels) + static_cast(ch)]; + const int16_t b = input[i1 * static_cast(channels) + static_cast(ch)]; + const double v = static_cast(a) + (static_cast(b) - static_cast(a)) * frac; + output[out_frame * static_cast(channels) + static_cast(ch)] = + static_cast(std::max(-32768.0, std::min(32767.0, v))); + } + } + return output; +} + +inline std::vector convert_channels_pcm16(const std::vector &input, + int input_channels, + int output_channels) { + if (input.empty() || input_channels == output_channels) { + return input; + } + + const size_t frames = input.size() / static_cast(input_channels); + std::vector output(frames * static_cast(output_channels)); + for (size_t frame = 0; frame < frames; ++frame) { + const size_t in_base = frame * static_cast(input_channels); + const size_t out_base = frame * static_cast(output_channels); + if (output_channels == 1) { + int total = 0; + for (int ch = 0; ch < input_channels; ++ch) { + total += input[in_base + static_cast(ch)]; + } + output[out_base] = static_cast(total / input_channels); + } else if (input_channels == 1) { + for (int ch = 0; ch < output_channels; ++ch) { + output[out_base + static_cast(ch)] = input[in_base]; + } + } else { + for (int ch = 0; ch < output_channels; ++ch) { + const int in_ch = ch < input_channels ? ch : input_channels - 1; + output[out_base + static_cast(ch)] = + input[in_base + static_cast(in_ch)]; + } + } + } + return output; +} + +class SdkSegmentPlayer { +public: + SdkSegmentPlayer(int sample_rate, int channels, int device, int frames_per_buffer, + int tail_ms, int drain_ms, int segment_pause_ms) + : sample_rate_(sample_rate), channels_(channels), + tail_ms_(std::max(0, tail_ms)), drain_ms_(std::max(0, drain_ms)), + segment_pause_ms_(std::max(0, segment_pause_ms)) { + if (sample_rate_ <= 0 || channels_ <= 0) { + throw std::runtime_error("invalid SDK audio playback config"); + } + frames_per_buffer_ = frames_per_buffer > 0 ? frames_per_buffer : 1024; + const char *env_device = std::getenv("Q3TTS_ALSA_DEVICE"); + device_name_ = env_device && *env_device + ? env_device + : (device >= 0 ? "plughw:" + std::to_string(device) + ",0" : "default"); + + int err = snd_pcm_open(&pcm_, device_name_.c_str(), SND_PCM_STREAM_PLAYBACK, 0); + if (err < 0) { + throw std::runtime_error("ALSA open failed for " + device_name_ + ": " + snd_strerror(err)); + } + err = snd_pcm_set_params( + pcm_, SND_PCM_FORMAT_S16_LE, SND_PCM_ACCESS_RW_INTERLEAVED, + static_cast(channels_), static_cast(sample_rate_), + 1, 200000); + if (err < 0) { + snd_pcm_close(pcm_); + pcm_ = nullptr; + throw std::runtime_error("ALSA set params failed for " + device_name_ + ": " + snd_strerror(err)); + } + std::cerr << "[Q3TTSAudio] Opened ALSA " << device_name_ + << " " << sample_rate_ << "Hz ch " << channels_ + << " buffer " << frames_per_buffer_ << "\n"; + expected_end_ = Clock::now(); + worker_ = std::thread([this]() { run(); }); + } + + ~SdkSegmentPlayer() { + finish_no_throw(); + } + + SdkSegmentPlayer(const SdkSegmentPlayer &) = delete; + SdkSegmentPlayer &operator=(const SdkSegmentPlayer &) = delete; + + double enqueue_mono24k(std::vector samples) { + if (enqueued_ > 0 && segment_pause_ms_ > 0) { + const size_t pause_frames = + static_cast(24000) * static_cast(segment_pause_ms_) / 1000; + samples.insert(samples.end(), pause_frames, 0); + } + samples = resample_pcm16_linear(samples, 24000, sample_rate_, 1); + samples = convert_channels_pcm16(samples, 1, channels_); + const double audio_s = + static_cast(samples.size()) / static_cast(sample_rate_ * channels_); + + const auto now = Clock::now(); + double gap_s = 0.0; + if (enqueued_ > 0 && now > expected_end_) { + gap_s = std::chrono::duration(now - expected_end_).count(); + } + const auto start = now > expected_end_ ? now : expected_end_; + expected_end_ = start + + std::chrono::duration_cast(std::chrono::duration(audio_s)); + ++enqueued_; + + { + std::lock_guard lock(mu_); + queue_.push_back(std::move(samples)); + } + cv_.notify_one(); + return gap_s; + } + + void finish() { + { + std::lock_guard lock(mu_); + if (finished_) { + return; + } + finished_ = true; + if (enqueued_ > 0 && tail_ms_ > 0) { + const size_t tail_frames = + static_cast(sample_rate_) * static_cast(tail_ms_) / 1000; + if (tail_frames > 0) { + queue_.emplace_back(tail_frames * static_cast(channels_), 0); + const auto now = Clock::now(); + const auto start = now > expected_end_ ? now : expected_end_; + expected_end_ = start + std::chrono::duration_cast( + std::chrono::duration(static_cast(tail_ms_) / 1000.0)); + } + } + done_ = true; + } + cv_.notify_one(); + if (worker_.joinable()) { + worker_.join(); + } + if (enqueued_ > 0 && drain_ms_ > 0) { + const auto drain_until = expected_end_ + std::chrono::milliseconds(drain_ms_); + const auto now = Clock::now(); + if (now < drain_until) { + std::this_thread::sleep_until(drain_until); + } + } + if (pcm_) { + snd_pcm_drain(pcm_); + snd_pcm_close(pcm_); + pcm_ = nullptr; + } + if (!ok_) { + throw std::runtime_error("ALSA write failed"); + } + } + +private: + void finish_no_throw() { + try { + finish(); + } catch (...) { + } + } + + void run() { + while (true) { + std::vector samples; + { + std::unique_lock lock(mu_); + cv_.wait(lock, [&]() { return done_ || !queue_.empty(); }); + if (queue_.empty()) { + if (done_) { + break; + } + continue; + } + samples = std::move(queue_.front()); + queue_.pop_front(); + } + snd_pcm_sframes_t offset = 0; + snd_pcm_sframes_t remaining = + static_cast(samples.size() / static_cast(channels_)); + while (remaining > 0) { + const snd_pcm_sframes_t chunk = + std::min(remaining, static_cast(frames_per_buffer_)); + snd_pcm_sframes_t wrote = snd_pcm_writei( + pcm_, + samples.data() + static_cast(offset) * static_cast(channels_), + chunk); + if (wrote == -EPIPE) { + snd_pcm_prepare(pcm_); + continue; + } + if (wrote < 0) { + std::cerr << "[Q3TTSAudio] Write failed: " << snd_strerror(static_cast(wrote)) << "\n"; + ok_ = false; + break; + } + offset += wrote; + remaining -= wrote; + } + if (!ok_) { + break; + } + } + } + + int sample_rate_; + int channels_; + int tail_ms_; + int drain_ms_; + int segment_pause_ms_; + int frames_per_buffer_ = 1024; + std::string device_name_; + snd_pcm_t *pcm_ = nullptr; + std::thread worker_; + std::mutex mu_; + std::condition_variable cv_; + std::deque> queue_; + bool done_ = false; + bool finished_ = false; + bool ok_ = true; + size_t enqueued_ = 0; + Clock::time_point expected_end_; +}; + +#endif + +} // namespace q3tts_audio diff --git a/tools/speech/backends/qwen3_tts/include/qwen3_tts/q3tts_codec_ort.h b/tools/speech/backends/qwen3_tts/include/qwen3_tts/q3tts_codec_ort.h new file mode 100644 index 000000000000..1f22bb3f72cc --- /dev/null +++ b/tools/speech/backends/qwen3_tts/include/qwen3_tts/q3tts_codec_ort.h @@ -0,0 +1,200 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace q3tts_codec { + +inline bool file_exists(const std::string &path) { + std::ifstream f(path, std::ios::binary); + return f.good(); +} + +inline std::string path_join(const std::string &a, const std::string &b) { + if (a.empty()) { + return b; + } + if (a.back() == '/') { + return a + b; + } + return a + "/" + b; +} + +inline std::string first_existing(const std::vector &paths) { + for (const auto &p : paths) { + if (!p.empty() && file_exists(p)) { + return p; + } + } + throw std::runtime_error("missing file"); +} + +inline std::string codec_model_file(const std::string &model_dir, int bucket) { + const std::string name = "codec_decoder_t" + std::to_string(bucket) + ".q.onnx"; + return first_existing({ + path_join(path_join(model_dir, "onnx"), name), + path_join(model_dir, name), + name, + }); +} + +inline std::string first_existing_ep_lib() { + const char *env = std::getenv("Q3TTS_SPACEMIT_EP_LIB"); + const std::vector paths = { + env && *env ? std::string(env) : std::string(), + "/usr/lib/python3.14/dist-packages/spacemit_ort/libspacemit_ep.so.2.0.3", + "/usr/lib/python3.14/dist-packages/spacemit_ort/libspacemit_ep.so.2", + "/usr/lib/python3/dist-packages/spacemit_ort/libspacemit_ep.so.2.0.3", + "/usr/lib/python3/dist-packages/spacemit_ort/libspacemit_ep.so.2", + }; + return first_existing(paths); +} + +inline void init_spacemit_ep(Ort::SessionOptions &so, const std::string &ep_lib) { + using InitFn = OrtStatus *(ORT_API_CALL *)(OrtSessionOptions *, + const char *const *, + const char *const *, + size_t); + static void *handle = nullptr; + static InitFn init_fn = nullptr; + if (!handle) { + handle = dlopen(ep_lib.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (!handle) { + throw std::runtime_error(std::string("dlopen spacemit ep failed: ") + dlerror()); + } + init_fn = reinterpret_cast(dlsym(handle, "OrtSessionOptionsSpaceMITEnvInit")); + if (!init_fn) { + throw std::runtime_error(std::string("dlsym OrtSessionOptionsSpaceMITEnvInit failed: ") + dlerror()); + } + } + Ort::ThrowOnError(init_fn(so, nullptr, nullptr, 0)); +} + +struct Decoder { + int bucket = 0; + std::unique_ptr session; + std::string input_name; + std::string output_name; +}; + +struct DecoderPoolConfig { + std::string model_dir = "."; + std::vector buckets; + int intra_threads = 3; + std::string ep_lib; + std::function on_bucket_warm; +}; + +class DecoderPool { +public: + DecoderPool(Ort::Env &env, DecoderPoolConfig config) { + if (config.buckets.empty()) { + throw std::runtime_error("empty codec bucket list"); + } + const std::string ep_lib = config.ep_lib.empty() ? first_existing_ep_lib() : config.ep_lib; + Ort::AllocatorWithDefaultOptions allocator; + for (int b : config.buckets) { + Ort::SessionOptions so; + so.SetIntraOpNumThreads(config.intra_threads); + so.AddConfigEntry("session.intra_op.allow_spinning", "0"); + init_spacemit_ep(so, ep_lib); + + Decoder d; + d.bucket = b; + const std::string model = codec_model_file(config.model_dir, b); + d.session = std::make_unique(env, model.c_str(), so); + auto in = d.session->GetInputNameAllocated(0, allocator); + auto out = d.session->GetOutputNameAllocated(0, allocator); + d.input_name = in.get(); + d.output_name = out.get(); + + std::vector> warm(static_cast(b)); + for (auto &x : warm) { + x.fill(0); + } + (void)decode_with(d, warm, 0); + decoders_.emplace(b, std::move(d)); + if (config.on_bucket_warm) { + config.on_bucket_warm(b); + } + } + } + + std::vector decode(int bucket, + const std::vector> &codes, + int ctx) { + auto it = decoders_.find(bucket); + if (it == decoders_.end()) { + throw std::runtime_error("codec bucket not initialized"); + } + return decode_with(it->second, codes, ctx); + } + + std::vector decode_chunks(const std::vector> &frames, + const std::vector &buckets, + int first_chunk, + int chunk, + int ctx_limit) { + std::vector wav; + int done = 0; + while (done < static_cast(frames.size())) { + const int next_chunk = done == 0 ? first_chunk : chunk; + const int n = std::min(static_cast(frames.size()), done + next_chunk); + const int new_count = n - done; + auto it = std::find_if(buckets.begin(), buckets.end(), [&](int b) { return b >= new_count; }); + if (it == buckets.end()) { + throw std::runtime_error("no codec bucket for chunk"); + } + const int b = *it; + const int ctx = std::min({done, ctx_limit, b - new_count}); + std::vector> codes(frames.begin() + (done - ctx), frames.begin() + n); + auto chunk_wav = decode(b, codes, ctx); + wav.insert(wav.end(), chunk_wav.begin(), chunk_wav.end()); + done = n; + } + return wav; + } + +private: + static std::vector decode_with(Decoder &dec, + const std::vector> &codes, + int ctx) { + std::vector input(static_cast(16 * dec.bucket), 0); + for (size_t t = 0; t < codes.size(); ++t) { + for (int c = 0; c < 16; ++c) { + input[static_cast(c * dec.bucket) + t] = codes[t][c]; + } + } + std::array shape {1, 16, dec.bucket}; + auto mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + auto tensor = Ort::Value::CreateTensor(mem, input.data(), input.size(), shape.data(), shape.size()); + const char *in_names[] = {dec.input_name.c_str()}; + const char *out_names[] = {dec.output_name.c_str()}; + auto outputs = dec.session->Run(Ort::RunOptions{nullptr}, in_names, &tensor, 1, out_names, 1); + float *out = outputs[0].GetTensorMutableData(); + const size_t total = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount(); + const size_t begin = static_cast(ctx) * 1920; + const size_t end = std::min(total, codes.size() * static_cast(1920)); + if (begin > end) { + throw std::runtime_error("bad codec slice"); + } + return std::vector(out + begin, out + end); + } + + std::unordered_map decoders_; +}; + +} // namespace q3tts_codec diff --git a/tools/speech/backends/qwen3_tts/include/qwen3_tts/q3tts_frontend.h b/tools/speech/backends/qwen3_tts/include/qwen3_tts/q3tts_frontend.h new file mode 100644 index 000000000000..6a1106c2d6c2 --- /dev/null +++ b/tools/speech/backends/qwen3_tts/include/qwen3_tts/q3tts_frontend.h @@ -0,0 +1,1422 @@ +#pragma once + +#include "gguf.h" +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace q3tts_frontend { + +static constexpr int HID = 1024; +static constexpr int TEXT_IM_START = 151644; +static constexpr int TEXT_IM_END = 151645; +static constexpr int TTS_PAD = 151671; +static constexpr int TTS_BOS = 151672; +static constexpr int TTS_EOS = 151673; +static constexpr int CODEC_PAD = 2148; +static constexpr int CODEC_BOS = 2149; +static constexpr int CODEC_THINK = 2154; +static constexpr int CODEC_NOTHINK = 2155; +static constexpr int CODEC_THINK_BOS = 2156; +static constexpr int CODEC_THINK_EOS = 2157; +static constexpr int CODEC_LANGUAGE_ENGLISH = 2050; +static constexpr int CODEC_LANGUAGE_CHINESE = 2055; +static constexpr int CODE_GROUPS = 16; +static constexpr int CODE_VOCAB = 2048; +static constexpr int REF_CODEC_SAMPLES = 192000; +static constexpr int REF_CODEC_SAMPLES_PER_FRAME = 1920; + +struct FrontendInput { + std::vector prefill; + std::vector trailing; + std::vector pad; + int64_t n_prefill = 0; + int64_t n_trailing = 0; +}; + +struct FrontendConfig { + std::string model_dir = "."; + std::string text; + std::string ref_wav; + std::string ref_bin; + std::string language = "auto"; + std::string talker_gguf = "qwen3-tts-0.6b-talker-qkv-gateup-q8_0-side.gguf"; + std::string cp_gguf = "qwen3-tts-0.6b-cp-qkv-gateup-rawq4.gguf"; + int frontend_threads = 2; + bool full_prompt_non_streaming = false; +}; + +struct ReferencePrompt { + std::vector speaker; + std::string ref_text; + std::vector> ref_codes; + + bool full_prompt() const { + return !ref_text.empty() && !ref_codes.empty(); + } +}; + +inline bool file_exists(const std::string &path) { + std::ifstream f(path, std::ios::binary); + return f.good(); +} + +inline std::string path_join(const std::string &a, const std::string &b) { + if (a.empty() || a == ".") { + return b; + } + if (a.back() == '/') { + return a + b; + } + return a + "/" + b; +} + +inline std::string first_existing(const std::vector &paths) { + for (const auto &p : paths) { + if (!p.empty() && file_exists(p)) { + return p; + } + } + throw std::runtime_error("missing model asset: " + (paths.empty() ? std::string("") : paths.front())); +} + +inline std::string read_text_file(const std::string &path) { + std::ifstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("open failed: " + path); + } + f.seekg(0, std::ios::end); + const size_t n = static_cast(f.tellg()); + f.seekg(0, std::ios::beg); + std::string s(n, '\0'); + f.read(&s[0], n); + return s; +} + +inline std::vector read_bytes_file(const std::string &path) { + std::ifstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("open failed: " + path); + } + f.seekg(0, std::ios::end); + const size_t n = static_cast(f.tellg()); + f.seekg(0, std::ios::beg); + std::vector out(n); + f.read(reinterpret_cast(out.data()), n); + return out; +} + +inline void write_bytes_file(const std::string &path, const uint8_t *data, size_t n) { + std::ofstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("write failed: " + path); + } + f.write(reinterpret_cast(data), n); +} + +inline void append_utf8(uint32_t cp, std::string &out) { + if (cp <= 0x7f) { + out.push_back(static_cast(cp)); + } else if (cp <= 0x7ff) { + out.push_back(static_cast(0xc0 | (cp >> 6))); + out.push_back(static_cast(0x80 | (cp & 0x3f))); + } else if (cp <= 0xffff) { + out.push_back(static_cast(0xe0 | (cp >> 12))); + out.push_back(static_cast(0x80 | ((cp >> 6) & 0x3f))); + out.push_back(static_cast(0x80 | (cp & 0x3f))); + } else { + out.push_back(static_cast(0xf0 | (cp >> 18))); + out.push_back(static_cast(0x80 | ((cp >> 12) & 0x3f))); + out.push_back(static_cast(0x80 | ((cp >> 6) & 0x3f))); + out.push_back(static_cast(0x80 | (cp & 0x3f))); + } +} + +inline uint32_t parse_hex4(const std::string &s, size_t pos) { + uint32_t v = 0; + for (size_t i = 0; i < 4; ++i) { + char c = s[pos + i]; + v <<= 4; + if (c >= '0' && c <= '9') { + v |= static_cast(c - '0'); + } else if (c >= 'a' && c <= 'f') { + v |= static_cast(c - 'a' + 10); + } else if (c >= 'A' && c <= 'F') { + v |= static_cast(c - 'A' + 10); + } else { + throw std::runtime_error("bad json unicode escape"); + } + } + return v; +} + +inline std::string parse_json_string(const std::string &s, size_t &i) { + if (i >= s.size() || s[i] != '"') { + throw std::runtime_error("expected json string"); + } + ++i; + std::string out; + while (i < s.size()) { + unsigned char c = static_cast(s[i++]); + if (c == '"') { + return out; + } + if (c != '\\') { + out.push_back(static_cast(c)); + continue; + } + if (i >= s.size()) { + throw std::runtime_error("bad json escape"); + } + char e = s[i++]; + switch (e) { + case '"': out.push_back('"'); break; + case '\\': out.push_back('\\'); break; + case '/': out.push_back('/'); break; + case 'b': out.push_back('\b'); break; + case 'f': out.push_back('\f'); break; + case 'n': out.push_back('\n'); break; + case 'r': out.push_back('\r'); break; + case 't': out.push_back('\t'); break; + case 'u': { + if (i + 4 > s.size()) { + throw std::runtime_error("short json unicode escape"); + } + uint32_t cp = parse_hex4(s, i); + i += 4; + if (cp >= 0xd800 && cp <= 0xdbff && i + 6 <= s.size() && s[i] == '\\' && s[i + 1] == 'u') { + i += 2; + uint32_t lo = parse_hex4(s, i); + i += 4; + cp = 0x10000 + ((cp - 0xd800) << 10) + (lo - 0xdc00); + } + append_utf8(cp, out); + break; + } + default: + throw std::runtime_error("unsupported json escape"); + } + } + throw std::runtime_error("unterminated json string"); +} + +inline void skip_json_ws(const std::string &s, size_t &i) { + while (i < s.size()) { + char c = s[i]; + if (c == ' ' || c == '\n' || c == '\r' || c == '\t') { + ++i; + } else { + break; + } + } +} + +inline std::vector make_byte_encoder() { + std::vector bs; + for (int b = 33; b <= 126; ++b) bs.push_back(b); + for (int b = 161; b <= 172; ++b) bs.push_back(b); + for (int b = 174; b <= 255; ++b) bs.push_back(b); + std::vector cs = bs; + int n = 0; + for (int b = 0; b < 256; ++b) { + if (std::find(bs.begin(), bs.end(), b) == bs.end()) { + bs.push_back(b); + cs.push_back(256 + n); + ++n; + } + } + std::vector enc(256); + for (size_t i = 0; i < bs.size(); ++i) { + append_utf8(static_cast(cs[i]), enc[static_cast(bs[i])]); + } + return enc; +} + +inline uint32_t decode_utf8_one(const std::string &s, size_t &i, std::string *bytes = nullptr) { + const size_t start = i; + unsigned char c = static_cast(s[i++]); + uint32_t cp = c; + if (c < 0x80) { + cp = c; + } else if ((c >> 5) == 0x6 && i < s.size()) { + const uint32_t b1 = static_cast(s[i++]) & 0x3f; + cp = ((c & 0x1f) << 6) | b1; + } else if ((c >> 4) == 0xe && i + 1 < s.size()) { + const uint32_t b1 = static_cast(s[i++]) & 0x3f; + const uint32_t b2 = static_cast(s[i++]) & 0x3f; + cp = ((c & 0x0f) << 12) | (b1 << 6) | b2; + } else if ((c >> 3) == 0x1e && i + 2 < s.size()) { + const uint32_t b1 = static_cast(s[i++]) & 0x3f; + const uint32_t b2 = static_cast(s[i++]) & 0x3f; + const uint32_t b3 = static_cast(s[i++]) & 0x3f; + cp = ((c & 0x07) << 18) | (b1 << 12) | (b2 << 6) | b3; + } + if (bytes) { + *bytes = s.substr(start, i - start); + } + return cp; +} + +inline bool is_space_cp(uint32_t cp) { + return cp == ' ' || cp == '\n' || cp == '\r' || cp == '\t' || cp == '\v' || cp == '\f'; +} + +inline bool is_crlf_cp(uint32_t cp) { + return cp == '\n' || cp == '\r'; +} + +inline bool is_digit_cp(uint32_t cp) { + return cp >= '0' && cp <= '9'; +} + +inline bool is_letter_cp(uint32_t cp) { + if ((cp >= 'a' && cp <= 'z') || (cp >= 'A' && cp <= 'Z')) { + return true; + } + if ((cp >= 0x4e00 && cp <= 0x9fff) || (cp >= 0x3400 && cp <= 0x4dbf) || + (cp >= 0x3040 && cp <= 0x30ff) || (cp >= 0xac00 && cp <= 0xd7af)) { + return true; + } + if (cp >= 0x80 && !is_space_cp(cp) && + !(cp >= 0x3000 && cp <= 0x303f) && + !(cp >= 0xff00 && cp <= 0xff65)) { + return true; + } + return false; +} + +enum class CharKind { + Space, + Digit, + Letter, + Punct, +}; + +inline CharKind char_kind(uint32_t cp) { + if (is_space_cp(cp)) return CharKind::Space; + if (is_digit_cp(cp)) return CharKind::Digit; + if (is_letter_cp(cp)) return CharKind::Letter; + return CharKind::Punct; +} + +inline char ascii_lower(char c) { + return (c >= 'A' && c <= 'Z') ? static_cast(c - 'A' + 'a') : c; +} + +inline std::string lower_ascii(std::string s) { + for (char &c : s) { + c = ascii_lower(c); + } + return s; +} + +inline int codec_language_id_for(const std::string &language) { + const std::string l = lower_ascii(language); + if (l.empty() || l == "auto") { + return 0; + } + if (l == "chinese" || l == "zh" || l == "zh-cn") { + return CODEC_LANGUAGE_CHINESE; + } + if (l == "english" || l == "en" || l == "en-us") { + return CODEC_LANGUAGE_ENGLISH; + } + throw std::runtime_error("unsupported --language: " + language); +} + +inline std::vector> utf8_chars(const std::string &s) { + std::vector> out; + size_t i = 0; + while (i < s.size()) { + std::string b; + uint32_t cp = decode_utf8_one(s, i, &b); + out.emplace_back(cp, std::move(b)); + } + return out; +} + +inline std::vector split_utf8_strings(const std::string &s) { + std::vector out; + size_t i = 0; + while (i < s.size()) { + std::string b; + (void)decode_utf8_one(s, i, &b); + out.push_back(std::move(b)); + } + return out; +} + +class BpeTokenizer { +public: + BpeTokenizer(const std::string &vocab_path, const std::string &merges_path) { + byte_encoder_ = make_byte_encoder(); + load_vocab(vocab_path); + load_merges(merges_path); + specials_["<|im_start|>"] = TEXT_IM_START; + specials_["<|im_end|>"] = TEXT_IM_END; + } + + std::vector encode(const std::string &text) const { + std::vector out; + size_t pos = 0; + while (pos < text.size()) { + std::string special; + int special_id = -1; + for (const auto &kv : specials_) { + const std::string &tok = kv.first; + if (text.compare(pos, tok.size(), tok) == 0) { + if (tok.size() > special.size()) { + special = tok; + special_id = kv.second; + } + } + } + if (special_id >= 0) { + out.push_back(special_id); + pos += special.size(); + continue; + } + + size_t next = text.size(); + for (const auto &kv : specials_) { + size_t p = text.find(kv.first, pos); + if (p != std::string::npos) { + next = std::min(next, p); + } + } + encode_plain(text.substr(pos, next - pos), out); + pos = next; + } + return out; + } + +private: + std::unordered_map vocab_; + std::unordered_map ranks_; + std::unordered_map specials_; + std::vector byte_encoder_; + + static std::string pair_key(const std::string &a, const std::string &b) { + return a + "\001" + b; + } + + void load_vocab(const std::string &path) { + const std::string s = read_text_file(path); + size_t i = 0; + skip_json_ws(s, i); + if (i >= s.size() || s[i] != '{') { + throw std::runtime_error("bad vocab json"); + } + ++i; + while (i < s.size()) { + skip_json_ws(s, i); + if (i < s.size() && s[i] == '}') { + break; + } + std::string key = parse_json_string(s, i); + skip_json_ws(s, i); + if (i >= s.size() || s[i++] != ':') { + throw std::runtime_error("bad vocab json colon"); + } + skip_json_ws(s, i); + bool neg = false; + if (s[i] == '-') { + neg = true; + ++i; + } + int64_t val = 0; + while (i < s.size() && s[i] >= '0' && s[i] <= '9') { + val = val * 10 + (s[i++] - '0'); + } + vocab_[key] = neg ? -val : val; + skip_json_ws(s, i); + if (i < s.size() && s[i] == ',') { + ++i; + } + } + } + + void load_merges(const std::string &path) { + const std::string s = read_text_file(path); + size_t line_begin = 0; + int rank = 0; + while (line_begin < s.size()) { + size_t line_end = s.find('\n', line_begin); + if (line_end == std::string::npos) { + line_end = s.size(); + } + std::string line = s.substr(line_begin, line_end - line_begin); + if (!line.empty() && line.back() == '\r') { + line.pop_back(); + } + if (!line.empty() && line[0] != '#') { + size_t sp = line.find(' '); + if (sp != std::string::npos && sp + 1 < line.size()) { + ranks_[pair_key(line.substr(0, sp), line.substr(sp + 1))] = rank++; + } + } + line_begin = line_end + 1; + } + } + + std::vector pre_tokenize(const std::string &text) const { + auto chars = utf8_chars(text); + std::vector toks; + size_t i = 0; + while (i < chars.size()) { + if (chars[i].first == '\'' && i + 1 < chars.size()) { + std::string suf; + size_t j = i + 1; + while (j < chars.size() && chars[j].first < 128 && is_letter_cp(chars[j].first) && + suf.size() < 2) { + suf.push_back(ascii_lower(chars[j].second[0])); + ++j; + } + const bool one = suf == "s" || suf == "t" || suf == "m" || suf == "d"; + const bool two = suf == "re" || suf == "ve" || suf == "ll"; + if ((one && j == i + 2) || (two && j == i + 3)) { + std::string tok; + for (size_t k = i; k < j; ++k) { + tok += chars[k].second; + } + toks.push_back(std::move(tok)); + i = j; + continue; + } + } + + size_t j = i; + std::string tok; + if (!is_crlf_cp(chars[j].first) && !is_letter_cp(chars[j].first) && !is_digit_cp(chars[j].first) && + j + 1 < chars.size() && is_letter_cp(chars[j + 1].first)) { + tok += chars[j].second; + ++j; + } + if (j < chars.size() && is_letter_cp(chars[j].first)) { + while (j < chars.size() && is_letter_cp(chars[j].first)) { + tok += chars[j].second; + ++j; + } + toks.push_back(std::move(tok)); + i = j; + continue; + } + + if (is_digit_cp(chars[i].first)) { + toks.push_back(chars[i].second); + ++i; + continue; + } + + j = i; + tok.clear(); + if (chars[j].first == ' ' && j + 1 < chars.size() && + !is_space_cp(chars[j + 1].first) && !is_letter_cp(chars[j + 1].first) && + !is_digit_cp(chars[j + 1].first)) { + tok += chars[j].second; + ++j; + } + if (j < chars.size() && !is_space_cp(chars[j].first) && + !is_letter_cp(chars[j].first) && !is_digit_cp(chars[j].first)) { + while (j < chars.size() && !is_space_cp(chars[j].first) && + !is_letter_cp(chars[j].first) && !is_digit_cp(chars[j].first)) { + tok += chars[j].second; + ++j; + } + while (j < chars.size() && is_crlf_cp(chars[j].first)) { + tok += chars[j].second; + ++j; + } + toks.push_back(std::move(tok)); + i = j; + continue; + } + + j = i; + tok.clear(); + while (j < chars.size() && is_space_cp(chars[j].first) && !is_crlf_cp(chars[j].first)) { + tok += chars[j].second; + ++j; + } + if (j < chars.size() && is_crlf_cp(chars[j].first)) { + while (j < chars.size() && is_crlf_cp(chars[j].first)) { + tok += chars[j].second; + ++j; + } + toks.push_back(std::move(tok)); + i = j; + continue; + } + + tok.clear(); + while (i < chars.size() && is_space_cp(chars[i].first)) { + tok += chars[i].second; + ++i; + } + if (tok.empty()) { + tok = chars[i].second; + ++i; + } + toks.push_back(std::move(tok)); + } + return toks; + } + + std::vector bpe(const std::string &token) const { + std::string encoded; + encoded.reserve(token.size() * 2); + for (unsigned char c : token) { + encoded += byte_encoder_[c]; + } + std::vector word = split_utf8_strings(encoded); + if (word.size() <= 1) { + return word; + } + while (true) { + int best_rank = std::numeric_limits::max(); + size_t best = std::numeric_limits::max(); + for (size_t i = 0; i + 1 < word.size(); ++i) { + auto it = ranks_.find(pair_key(word[i], word[i + 1])); + if (it != ranks_.end() && it->second < best_rank) { + best_rank = it->second; + best = i; + } + } + if (best == std::numeric_limits::max()) { + break; + } + std::vector merged; + merged.reserve(word.size() - 1); + for (size_t i = 0; i < word.size();) { + if (i + 1 < word.size() && i == best) { + merged.push_back(word[i] + word[i + 1]); + i += 2; + } else { + merged.push_back(word[i]); + ++i; + } + } + word.swap(merged); + } + return word; + } + + void encode_plain(const std::string &text, std::vector &out) const { + for (const auto &tok : pre_tokenize(text)) { + for (const auto &piece : bpe(tok)) { + auto it = vocab_.find(piece); + if (it == vocab_.end()) { + throw std::runtime_error("token not found in vocab"); + } + out.push_back(it->second); + } + } + } +}; + +inline std::vector load_gguf_tensor_f32(const std::string &path, const std::string &name, size_t count) { + gguf_init_params params = {}; + params.no_alloc = true; + gguf_context *ctx = gguf_init_from_file(path.c_str(), params); + if (!ctx) { + throw std::runtime_error("open gguf failed: " + path); + } + int64_t tid = gguf_find_tensor(ctx, name.c_str()); + if (tid < 0) { + gguf_free(ctx); + throw std::runtime_error("tensor not found in gguf: " + name); + } + if (gguf_get_tensor_type(ctx, tid) != GGML_TYPE_F32 || + gguf_get_tensor_size(ctx, tid) != count * sizeof(float)) { + gguf_free(ctx); + throw std::runtime_error("bad gguf tensor shape/type: " + name); + } + const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, tid); + std::vector out(count); + FILE *f = std::fopen(path.c_str(), "rb"); + if (!f) { + gguf_free(ctx); + throw std::runtime_error("open gguf data failed: " + path); + } + if (std::fseek(f, static_cast(offset), SEEK_SET) != 0 || + std::fread(out.data(), sizeof(float), count, f) != count) { + std::fclose(f); + gguf_free(ctx); + throw std::runtime_error("read gguf tensor failed: " + name); + } + std::fclose(f); + gguf_free(ctx); + return out; +} + +inline std::vector run_text_embed(Ort::Env &env, + const std::string &onnx_path, + const std::vector &ids, + int threads) { + Ort::SessionOptions so; + so.SetIntraOpNumThreads(std::max(1, threads)); + so.AddConfigEntry("session.intra_op.allow_spinning", "0"); + Ort::Session sess(env, onnx_path.c_str(), so); + Ort::AllocatorWithDefaultOptions allocator; + auto in = sess.GetInputNameAllocated(0, allocator); + auto out = sess.GetOutputNameAllocated(0, allocator); + std::array shape {1, static_cast(ids.size())}; + auto mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + auto tensor = Ort::Value::CreateTensor(mem, + const_cast(ids.data()), + ids.size(), + shape.data(), + shape.size()); + const char *in_names[] = {in.get()}; + const char *out_names[] = {out.get()}; + auto outputs = sess.Run(Ort::RunOptions{nullptr}, in_names, &tensor, 1, out_names, 1); + float *p = outputs[0].GetTensorMutableData(); + const size_t n = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount(); + if (n != ids.size() * static_cast(HID)) { + throw std::runtime_error("text_embed_proj output shape mismatch"); + } + return std::vector(p, p + n); +} + +inline std::vector run_text_embed_session(Ort::Session &sess, + const std::string &input_name, + const std::string &output_name, + const std::vector &ids) { + std::array shape {1, static_cast(ids.size())}; + auto mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + auto tensor = Ort::Value::CreateTensor(mem, + const_cast(ids.data()), + ids.size(), + shape.data(), + shape.size()); + const char *in_names[] = {input_name.c_str()}; + const char *out_names[] = {output_name.c_str()}; + auto outputs = sess.Run(Ort::RunOptions{nullptr}, in_names, &tensor, 1, out_names, 1); + float *p = outputs[0].GetTensorMutableData(); + const size_t n = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount(); + if (n != ids.size() * static_cast(HID)) { + throw std::runtime_error("text_embed_proj output shape mismatch"); + } + return std::vector(p, p + n); +} + +inline std::vector read_wav_mono_24k(const std::string &path) { + const auto d = read_bytes_file(path); + if (d.size() < 44 || std::memcmp(d.data(), "RIFF", 4) != 0 || std::memcmp(d.data() + 8, "WAVE", 4) != 0) { + throw std::runtime_error("expected RIFF/WAVE wav: " + path); + } + uint16_t audio_format = 0; + uint16_t channels = 0; + uint32_t sample_rate = 0; + uint16_t bits = 0; + size_t data_pos = 0; + size_t data_bytes = 0; + size_t p = 12; + while (p + 8 <= d.size()) { + const char *id = reinterpret_cast(d.data() + p); + uint32_t sz = static_cast(d[p + 4]) | + (static_cast(d[p + 5]) << 8) | + (static_cast(d[p + 6]) << 16) | + (static_cast(d[p + 7]) << 24); + p += 8; + if (p + sz > d.size()) { + throw std::runtime_error("bad wav chunk size"); + } + if (std::memcmp(id, "fmt ", 4) == 0 && sz >= 16) { + audio_format = static_cast(d[p] | (d[p + 1] << 8)); + channels = static_cast(d[p + 2] | (d[p + 3] << 8)); + sample_rate = static_cast(d[p + 4]) | + (static_cast(d[p + 5]) << 8) | + (static_cast(d[p + 6]) << 16) | + (static_cast(d[p + 7]) << 24); + bits = static_cast(d[p + 14] | (d[p + 15] << 8)); + } else if (std::memcmp(id, "data", 4) == 0) { + data_pos = p; + data_bytes = sz; + } + p += sz + (sz & 1u); + } + if (audio_format != 1 || channels < 1 || bits != 16 || data_pos == 0) { + throw std::runtime_error("ref wav must be PCM16"); + } + const size_t frames = data_bytes / (2 * channels); + std::vector audio(frames); + for (size_t i = 0; i < frames; ++i) { + float sum = 0.0f; + for (uint16_t ch = 0; ch < channels; ++ch) { + const size_t p16 = data_pos + (i * channels + ch) * 2; + int16_t s = static_cast(d[p16] | (d[p16 + 1] << 8)); + sum += static_cast(s) / 32768.0f; + } + audio[i] = sum / channels; + } + if (sample_rate == 24000) { + return audio; + } + if (sample_rate == 0) { + throw std::runtime_error("bad wav sample rate"); + } + const size_t out_frames = std::max(1, static_cast( + std::llround(static_cast(audio.size()) * 24000.0 / sample_rate))); + std::vector resampled(out_frames); + for (size_t i = 0; i < out_frames; ++i) { + const double src = static_cast(i) * sample_rate / 24000.0; + const size_t j = std::min(static_cast(src), audio.size() - 1); + const size_t j1 = std::min(j + 1, audio.size() - 1); + const float frac = static_cast(src - j); + resampled[i] = audio[j] * (1.0f - frac) + audio[j1] * frac; + } + return resampled; +} + +inline double hz_to_mel(double hz) { + const double f_sp = 200.0 / 3.0; + const double min_log_hz = 1000.0; + const double min_log_mel = min_log_hz / f_sp; + const double logstep = std::log(6.4) / 27.0; + if (hz < min_log_hz) { + return hz / f_sp; + } + return min_log_mel + std::log(hz / min_log_hz) / logstep; +} + +inline double mel_to_hz(double mel) { + const double f_sp = 200.0 / 3.0; + const double min_log_hz = 1000.0; + const double min_log_mel = min_log_hz / f_sp; + const double logstep = std::log(6.4) / 27.0; + if (mel < min_log_mel) { + return mel * f_sp; + } + return min_log_hz * std::exp(logstep * (mel - min_log_mel)); +} + +inline std::vector mel_filterbank() { + constexpr int n_fft = 1024; + constexpr int n_mels = 128; + constexpr int n_freq = n_fft / 2 + 1; + std::vector mel_pts(n_mels + 2); + const double mel_min = hz_to_mel(0.0); + const double mel_max = hz_to_mel(12000.0); + for (int i = 0; i < n_mels + 2; ++i) { + mel_pts[i] = mel_to_hz(mel_min + (mel_max - mel_min) * i / (n_mels + 1)); + } + std::vector fb(n_mels * n_freq, 0.0f); + for (int m = 0; m < n_mels; ++m) { + const double left = mel_pts[m]; + const double center = mel_pts[m + 1]; + const double right = mel_pts[m + 2]; + const double enorm = 2.0 / (right - left); + for (int f = 0; f < n_freq; ++f) { + const double hz = 12000.0 * f / (n_freq - 1); + double w = 0.0; + if (hz >= left && hz <= center) { + w = (hz - left) / (center - left); + } else if (hz >= center && hz <= right) { + w = (right - hz) / (right - center); + } + fb[static_cast(m * n_freq + f)] = static_cast(std::max(0.0, w) * enorm); + } + } + return fb; +} + +inline void fft1024(std::array, 1024> &a) { + constexpr int n = 1024; + for (int i = 1, j = 0; i < n; ++i) { + int bit = n >> 1; + for (; j & bit; bit >>= 1) { + j ^= bit; + } + j ^= bit; + if (i < j) { + std::swap(a[i], a[j]); + } + } + for (int len = 2; len <= n; len <<= 1) { + const float ang = -2.0f * static_cast(M_PI) / static_cast(len); + const std::complex wlen(std::cos(ang), std::sin(ang)); + for (int i = 0; i < n; i += len) { + std::complex w(1.0f, 0.0f); + for (int j = 0; j < len / 2; ++j) { + std::complex u = a[i + j]; + std::complex v = a[i + j + len / 2] * w; + a[i + j] = u + v; + a[i + j + len / 2] = u - v; + w *= wlen; + } + } + } +} + +inline std::vector wav_to_mel_128(const std::string &path) { + constexpr int n_fft = 1024; + constexpr int hop = 256; + constexpr int pad = 384; + constexpr int n_mels = 128; + constexpr int n_freq = 513; + auto audio = read_wav_mono_24k(path); + if (audio.size() <= static_cast(pad + 1)) { + throw std::runtime_error("ref wav is too short"); + } + const int n = static_cast(audio.size()); + std::vector padded(static_cast(n + 2 * pad)); + for (int i = 0; i < static_cast(padded.size()); ++i) { + int j = i - pad; + while (j < 0 || j >= n) { + if (j < 0) { + j = -j; + } else { + j = 2 * n - 2 - j; + } + } + padded[static_cast(i)] = audio[static_cast(j)]; + } + const int frames = 1 + (static_cast(padded.size()) - n_fft) / hop; + std::vector mel(static_cast(frames * n_mels), 0.0f); + std::array window {}; + for (int i = 0; i < n_fft; ++i) { + window[i] = 0.5f - 0.5f * std::cos(2.0f * static_cast(M_PI) * i / n_fft); + } + const auto fb = mel_filterbank(); + std::array, n_fft> buf {}; + std::array mag {}; + for (int t = 0; t < frames; ++t) { + const int off = t * hop; + for (int i = 0; i < n_fft; ++i) { + buf[i] = std::complex(padded[static_cast(off + i)] * window[i], 0.0f); + } + fft1024(buf); + for (int f = 0; f < n_freq; ++f) { + const float re = buf[f].real(); + const float im = buf[f].imag(); + mag[f] = std::sqrt(re * re + im * im + 1e-9f); + } + for (int m = 0; m < n_mels; ++m) { + double v = 0.0; + for (int f = 0; f < n_freq; ++f) { + v += static_cast(fb[static_cast(m * n_freq + f)]) * mag[f]; + } + mel[static_cast(t * n_mels + m)] = std::log(static_cast(std::max(v, 1e-5))); + } + } + return mel; +} + +inline std::vector run_speaker_encoder(Ort::Env &env, + const std::string &onnx_path, + const std::string &wav_path, + int threads) { + auto mel = wav_to_mel_128(wav_path); + const int64_t frames = static_cast(mel.size() / 128); + Ort::SessionOptions so; + so.SetIntraOpNumThreads(std::max(1, threads)); + so.AddConfigEntry("session.intra_op.allow_spinning", "0"); + Ort::Session sess(env, onnx_path.c_str(), so); + Ort::AllocatorWithDefaultOptions allocator; + auto in = sess.GetInputNameAllocated(0, allocator); + auto out = sess.GetOutputNameAllocated(0, allocator); + std::array shape {1, frames, 128}; + auto mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + auto tensor = Ort::Value::CreateTensor(mem, mel.data(), mel.size(), shape.data(), shape.size()); + const char *in_names[] = {in.get()}; + const char *out_names[] = {out.get()}; + auto outputs = sess.Run(Ort::RunOptions{nullptr}, in_names, &tensor, 1, out_names, 1); + float *p = outputs[0].GetTensorMutableData(); + const size_t n = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount(); + if (n != HID) { + throw std::runtime_error("speaker encoder output shape mismatch"); + } + return std::vector(p, p + n); +} + +inline std::vector> run_codec_encoder(Ort::Env &env, + const std::string &onnx_path, + const std::string &wav_path, + int threads) { + auto audio = read_wav_mono_24k(wav_path); + if (audio.empty()) { + throw std::runtime_error("ref wav is empty"); + } + const size_t valid_samples = std::min(audio.size(), static_cast(REF_CODEC_SAMPLES)); + const int valid_frames = std::max(1, static_cast( + (valid_samples + REF_CODEC_SAMPLES_PER_FRAME - 1) / REF_CODEC_SAMPLES_PER_FRAME)); + + std::vector input(REF_CODEC_SAMPLES, 0.0f); + std::copy_n(audio.data(), valid_samples, input.data()); + + Ort::SessionOptions so; + so.SetIntraOpNumThreads(std::max(1, threads)); + so.AddConfigEntry("session.intra_op.allow_spinning", "0"); + Ort::Session sess(env, onnx_path.c_str(), so); + Ort::AllocatorWithDefaultOptions allocator; + auto in = sess.GetInputNameAllocated(0, allocator); + auto out = sess.GetOutputNameAllocated(0, allocator); + std::array shape {1, REF_CODEC_SAMPLES}; + auto mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + auto tensor = Ort::Value::CreateTensor(mem, input.data(), input.size(), shape.data(), shape.size()); + const char *in_names[] = {in.get()}; + const char *out_names[] = {out.get()}; + auto outputs = sess.Run(Ort::RunOptions{nullptr}, in_names, &tensor, 1, out_names, 1); + int64_t *p = outputs[0].GetTensorMutableData(); + const size_t n = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount(); + if (n != 100UL * CODE_GROUPS) { + throw std::runtime_error("codec encoder output shape mismatch"); + } + const int frames = std::min(valid_frames, 100); + std::vector> codes(static_cast(frames)); + for (int t = 0; t < frames; ++t) { + for (int g = 0; g < CODE_GROUPS; ++g) { + const int64_t v = p[static_cast(t * CODE_GROUPS + g)]; + if (v < 0 || v >= CODE_VOCAB) { + throw std::runtime_error("codec encoder produced out-of-range code"); + } + codes[static_cast(t)][g] = static_cast(v); + } + } + return codes; +} + +inline uint32_t read_u32_le(const uint8_t *p) { + return static_cast(p[0]) | + (static_cast(p[1]) << 8) | + (static_cast(p[2]) << 16) | + (static_cast(p[3]) << 24); +} + +inline void append_u32_le(std::vector &out, uint32_t v) { + out.push_back(static_cast(v & 0xff)); + out.push_back(static_cast((v >> 8) & 0xff)); + out.push_back(static_cast((v >> 16) & 0xff)); + out.push_back(static_cast((v >> 24) & 0xff)); +} + +inline void append_raw(std::vector &out, const void *p, size_t n) { + const auto *b = reinterpret_cast(p); + out.insert(out.end(), b, b + n); +} + +inline std::vector read_speaker_bin(const std::string &path) { + auto bytes = read_bytes_file(path); + if (bytes.size() != static_cast(HID) * sizeof(float)) { + throw std::runtime_error("speaker bin must be raw float32[1024]: " + path); + } + std::vector out(HID); + std::memcpy(out.data(), bytes.data(), bytes.size()); + return out; +} + +inline void write_speaker_bin(const std::string &path, const std::vector &spk) { + if (spk.size() != HID) { + throw std::runtime_error("speaker bin expects float32[1024]"); + } + write_bytes_file(path, reinterpret_cast(spk.data()), spk.size() * sizeof(float)); +} + +inline ReferencePrompt read_reference_prompt_bin(const std::string &path) { + auto bytes = read_bytes_file(path); + if (bytes.size() == static_cast(HID) * sizeof(float)) { + ReferencePrompt prompt; + prompt.speaker.resize(HID); + std::memcpy(prompt.speaker.data(), bytes.data(), bytes.size()); + return prompt; + } + + static constexpr char kMagic[8] = {'Q', '3', 'T', 'P', 'R', 'M', 'P', 'T'}; + if (bytes.size() < 40 || std::memcmp(bytes.data(), kMagic, sizeof(kMagic)) != 0) { + throw std::runtime_error("bad reference prompt bin: " + path); + } + size_t p = sizeof(kMagic); + const uint32_t version = read_u32_le(bytes.data() + p); p += 4; + const uint32_t hid = read_u32_le(bytes.data() + p); p += 4; + const uint32_t groups = read_u32_le(bytes.data() + p); p += 4; + const uint32_t frames = read_u32_le(bytes.data() + p); p += 4; + const uint32_t text_bytes = read_u32_le(bytes.data() + p); p += 4; + p += 12; // reserved + if (version != 1 || hid != HID || groups != CODE_GROUPS || frames == 0 || text_bytes == 0) { + throw std::runtime_error("unsupported reference prompt bin: " + path); + } + const size_t need = p + static_cast(HID) * sizeof(float) + + static_cast(frames) * CODE_GROUPS * sizeof(int32_t) + + static_cast(text_bytes); + if (need != bytes.size()) { + throw std::runtime_error("reference prompt bin size mismatch: " + path); + } + + ReferencePrompt prompt; + prompt.speaker.resize(HID); + std::memcpy(prompt.speaker.data(), bytes.data() + p, static_cast(HID) * sizeof(float)); + p += static_cast(HID) * sizeof(float); + prompt.ref_codes.resize(frames); + std::memcpy(prompt.ref_codes.data(), bytes.data() + p, + static_cast(frames) * CODE_GROUPS * sizeof(int32_t)); + p += static_cast(frames) * CODE_GROUPS * sizeof(int32_t); + prompt.ref_text.assign(reinterpret_cast(bytes.data() + p), text_bytes); + return prompt; +} + +inline void write_reference_prompt_bin(const std::string &path, + const std::vector &spk, + const std::string &ref_text, + const std::vector> &ref_codes) { + if (spk.size() != HID) { + throw std::runtime_error("reference prompt expects speaker float32[1024]"); + } + if (ref_text.empty()) { + throw std::runtime_error("reference prompt expects non-empty ref_text"); + } + if (ref_codes.empty()) { + throw std::runtime_error("reference prompt expects non-empty ref_codes"); + } + + std::vector out; + static constexpr char kMagic[8] = {'Q', '3', 'T', 'P', 'R', 'M', 'P', 'T'}; + append_raw(out, kMagic, sizeof(kMagic)); + append_u32_le(out, 1); + append_u32_le(out, HID); + append_u32_le(out, CODE_GROUPS); + append_u32_le(out, static_cast(ref_codes.size())); + append_u32_le(out, static_cast(ref_text.size())); + append_u32_le(out, 0); + append_u32_le(out, 0); + append_u32_le(out, 0); + append_raw(out, spk.data(), spk.size() * sizeof(float)); + append_raw(out, ref_codes.data(), ref_codes.size() * CODE_GROUPS * sizeof(int32_t)); + append_raw(out, ref_text.data(), ref_text.size()); + write_bytes_file(path, out.data(), out.size()); +} + +inline std::string speaker_encoder_path(const std::string &model_dir) { + return first_existing({ + path_join(path_join(model_dir, "onnx"), "speaker_encoder.onnx"), + path_join(model_dir, "speaker_encoder.onnx"), + "speaker_encoder.onnx", + }); +} + +inline std::string codec_encoder_path(const std::string &model_dir) { + return first_existing({ + path_join(path_join(model_dir, "onnx"), "codec_encoder.onnx"), + path_join(model_dir, "codec_encoder.onnx"), + "codec_encoder.onnx", + }); +} + +inline void append_vec(std::vector &dst, const float *src) { + dst.insert(dst.end(), src, src + HID); +} + +inline void append_sum(std::vector &dst, const float *a, const float *b) { + const size_t off = dst.size(); + dst.resize(off + HID); + for (int i = 0; i < HID; ++i) { + dst[off + i] = a[i] + b[i]; + } +} + +inline std::pair tokenizer_paths(const std::string &model_dir) { + const std::string tokenizer_dir = path_join(model_dir, "tokenizer"); + const std::string vocab = first_existing({ + path_join(tokenizer_dir, "vocab.json"), + path_join(model_dir, "vocab.json"), + }); + const std::string merges = first_existing({ + path_join(tokenizer_dir, "merges.txt"), + path_join(model_dir, "merges.txt"), + }); + return {vocab, merges}; +} + +inline std::string text_embed_path(const std::string &model_dir) { + std::vector candidates = { + path_join(path_join(model_dir, "onnx"), "text_embed_proj.onnx"), + "text_embed_proj.onnx", + }; + if (std::getenv("Q3TTS_ALLOW_DYNQ_TEXT")) { + candidates.push_back(path_join(path_join(model_dir, "onnx"), "text_embed_proj.dynq.onnx")); + candidates.push_back("text_embed_proj.dynq.onnx"); + } + return first_existing(candidates); +} + +inline std::string talker_gguf_path(const std::string &model_dir, const std::string &talker_gguf) { + return first_existing({ + path_join(path_join(model_dir, "gguf"), talker_gguf), + path_join(model_dir, talker_gguf), + talker_gguf, + }); +} + +inline std::string cp_gguf_path(const std::string &model_dir, const std::string &cp_gguf) { + return first_existing({ + path_join(path_join(model_dir, "gguf"), cp_gguf), + path_join(model_dir, cp_gguf), + cp_gguf, + }); +} + +inline std::vector tokenize_prompt(const std::string &model_dir, const std::string &text) { + const auto paths = tokenizer_paths(model_dir); + const std::string &vocab = paths.first; + const std::string &merges = paths.second; + BpeTokenizer tok(vocab, merges); + return tok.encode("<|im_start|>assistant\n" + text + "<|im_end|>\n<|im_start|>assistant\n"); +} + +class FrontendRuntime { +public: + FrontendRuntime(Ort::Env &env, const FrontendConfig &cfg) + : env_(env), + model_dir_(cfg.model_dir), + language_(cfg.language), + talker_gguf_name_(cfg.talker_gguf), + cp_gguf_name_(cfg.cp_gguf), + frontend_threads_(std::max(1, cfg.frontend_threads)) { + (void)codec_language_id_for(language_); + + const auto tok_paths = tokenizer_paths(model_dir_); + tokenizer_ = std::make_unique(tok_paths.first, tok_paths.second); + + text_onnx_ = text_embed_path(model_dir_); + talker_gguf_ = talker_gguf_path(model_dir_, talker_gguf_name_); + cp_gguf_ = cp_gguf_path(model_dir_, cp_gguf_name_); + + Ort::SessionOptions so; + so.SetIntraOpNumThreads(frontend_threads_); + so.AddConfigEntry("session.intra_op.allow_spinning", "0"); + text_session_ = std::make_unique(env_, text_onnx_.c_str(), so); + + Ort::AllocatorWithDefaultOptions allocator; + auto in = text_session_->GetInputNameAllocated(0, allocator); + auto out = text_session_->GetOutputNameAllocated(0, allocator); + text_input_name_ = in.get(); + text_output_name_ = out.get(); + + tts_h_ = run_text_embed_session( + *text_session_, text_input_name_, text_output_name_, + std::vector{TTS_BOS, TTS_EOS, TTS_PAD}); + if (tts_h_.size() != 3UL * HID) { + throw std::runtime_error("tts embed output shape mismatch"); + } + codec_ = load_gguf_tensor_f32(talker_gguf_, "q3tts.codec_embedding.weight", 3072UL * HID); + } + + FrontendInput build(const FrontendConfig &cfg, std::vector *ids_out = nullptr) { + if (cfg.model_dir != model_dir_ || cfg.language != language_ || + cfg.talker_gguf != talker_gguf_name_ || cfg.cp_gguf != cp_gguf_name_) { + throw std::runtime_error("FrontendRuntime used with different model config"); + } + + std::vector ids = tokenizer_->encode( + "<|im_start|>assistant\n" + cfg.text + "<|im_end|>\n<|im_start|>assistant\n"); + if (ids.size() < 9) { + throw std::runtime_error("tokenized prompt is too short"); + } + if (ids_out) { + *ids_out = ids; + } + + auto text_h = run_text_embed_session(*text_session_, text_input_name_, text_output_name_, ids); + const float *bos_e = tts_h_.data(); + const float *eos_e = tts_h_.data() + HID; + const float *pad_e = tts_h_.data() + 2 * HID; + + auto ce = [&](int id) -> const float * { + if (id < 0 || id >= 3072) { + throw std::runtime_error("codec id out of range"); + } + return codec_.data() + static_cast(id) * HID; + }; + + std::vector ctrl; + auto add_ctrl = [&](const float *v) { + ctrl.insert(ctrl.end(), v, v + HID); + }; + const int language_id = codec_language_id_for(language_); + if (language_id == 0) { + add_ctrl(ce(CODEC_NOTHINK)); + add_ctrl(ce(CODEC_THINK_BOS)); + add_ctrl(ce(CODEC_THINK_EOS)); + } else { + add_ctrl(ce(CODEC_THINK)); + add_ctrl(ce(CODEC_THINK_BOS)); + add_ctrl(ce(language_id)); + add_ctrl(ce(CODEC_THINK_EOS)); + } + ReferencePrompt ref_prompt; + if (!cfg.ref_bin.empty() || !cfg.ref_wav.empty()) { + std::vector spk; + if (!cfg.ref_bin.empty()) { + ref_prompt = read_reference_prompt_bin(cfg.ref_bin); + spk = ref_prompt.speaker; + } else { + spk = run_speaker_encoder(env_, speaker_encoder_path(model_dir_), cfg.ref_wav, frontend_threads_); + } + add_ctrl(spk.data()); + } + add_ctrl(ce(CODEC_PAD)); + add_ctrl(ce(CODEC_BOS)); + const int ctrl_n = static_cast(ctrl.size() / HID); + + FrontendInput out; + out.pad.assign(pad_e, pad_e + HID); + + for (int i = 0; i < 3; ++i) { + append_vec(out.prefill, text_h.data() + static_cast(i) * HID); + } + + if (ref_prompt.full_prompt()) { + for (int i = 0; i < ctrl_n - 1; ++i) { + const float *base = (i == ctrl_n - 2) ? bos_e : pad_e; + append_sum(out.prefill, base, ctrl.data() + static_cast(i) * HID); + } + + std::vector ref_ids = tokenizer_->encode( + "<|im_start|>assistant\n" + ref_prompt.ref_text + "<|im_end|>\n"); + if (ref_ids.size() < 6 || ids.size() < 9) { + throw std::runtime_error("tokenized full prompt is too short"); + } + std::vector icl_ids; + for (size_t i = 3; i < ref_ids.size() - 2; ++i) { + icl_ids.push_back(ref_ids[i]); + } + for (size_t i = 3; i < ids.size() - 5; ++i) { + icl_ids.push_back(ids[i]); + } + if (icl_ids.empty()) { + throw std::runtime_error("empty ICL text prompt"); + } + auto icl_text_h = run_text_embed_session(*text_session_, text_input_name_, text_output_name_, icl_ids); + ensure_cp_embeddings(); + + std::vector codec_embed; + codec_embed.reserve((ref_prompt.ref_codes.size() + 1) * HID); + append_vec(codec_embed, ce(CODEC_BOS)); + for (const auto &frame : ref_prompt.ref_codes) { + const size_t off = codec_embed.size(); + codec_embed.resize(off + HID, 0.0f); + const int c0 = frame[0]; + if (c0 < 0 || c0 >= CODE_VOCAB) { + throw std::runtime_error("ref code out of range"); + } + const float *base = ce(c0); + for (int h = 0; h < HID; ++h) { + codec_embed[off + h] += base[h]; + } + for (int g = 1; g < CODE_GROUPS; ++g) { + const int c = frame[g]; + if (c < 0 || c >= CODE_VOCAB) { + throw std::runtime_error("ref code out of range"); + } + const float *emb = cp_embeddings_[static_cast(g - 1)].data() + + static_cast(c) * HID; + for (int h = 0; h < HID; ++h) { + codec_embed[off + h] += emb[h]; + } + } + } + + const size_t text_lens = icl_ids.size() + 1; + const size_t codec_lens = ref_prompt.ref_codes.size() + 1; + auto text_ptr = [&](size_t i) -> const float * { + return (i == icl_ids.size()) ? eos_e : (icl_text_h.data() + i * HID); + }; + if (cfg.full_prompt_non_streaming) { + const float *codec_pad = ce(CODEC_PAD); + for (size_t i = 0; i < text_lens; ++i) { + append_sum(out.prefill, text_ptr(i), codec_pad); + } + for (size_t i = 0; i < codec_lens; ++i) { + append_sum(out.prefill, pad_e, codec_embed.data() + i * HID); + } + append_vec(out.trailing, pad_e); + } else { + const size_t paired = std::min(text_lens, codec_lens); + for (size_t i = 0; i < paired; ++i) { + append_sum(out.prefill, text_ptr(i), codec_embed.data() + i * HID); + } + for (size_t i = paired; i < codec_lens; ++i) { + append_sum(out.prefill, pad_e, codec_embed.data() + i * HID); + } + if (text_lens > codec_lens) { + for (size_t i = codec_lens; i < text_lens; ++i) { + append_vec(out.trailing, text_ptr(i)); + } + } else { + append_vec(out.trailing, pad_e); + } + } + } else { + for (int i = 0; i < ctrl_n - 1; ++i) { + const float *base = (i == ctrl_n - 2) ? bos_e : pad_e; + append_sum(out.prefill, base, ctrl.data() + static_cast(i) * HID); + } + append_sum(out.prefill, text_h.data() + 3UL * HID, ctrl.data() + static_cast(ctrl_n - 1) * HID); + + if (ids.size() > 9) { + const size_t begin = 4; + const size_t end = ids.size() - 5; + for (size_t i = begin; i < end; ++i) { + append_vec(out.trailing, text_h.data() + i * HID); + } + } + append_vec(out.trailing, eos_e); + } + out.n_prefill = static_cast(out.prefill.size() / HID); + out.n_trailing = static_cast(out.trailing.size() / HID); + return out; + } + +private: + void ensure_cp_embeddings() { + if (!cp_embeddings_.empty()) { + return; + } + cp_embeddings_.reserve(CODE_GROUPS - 1); + for (int i = 0; i < CODE_GROUPS - 1; ++i) { + cp_embeddings_.push_back(load_gguf_tensor_f32( + cp_gguf_, + "q3tts.cp_embedding." + std::to_string(i) + ".weight", + static_cast(CODE_VOCAB) * HID)); + } + } + + Ort::Env &env_; + std::string model_dir_; + std::string language_; + std::string talker_gguf_name_; + std::string cp_gguf_name_; + int frontend_threads_ = 1; + std::unique_ptr tokenizer_; + std::string text_onnx_; + std::string talker_gguf_; + std::string cp_gguf_; + std::unique_ptr text_session_; + std::string text_input_name_; + std::string text_output_name_; + std::vector tts_h_; + std::vector codec_; + std::vector> cp_embeddings_; +}; + +inline FrontendInput build(Ort::Env &env, const FrontendConfig &cfg, std::vector *ids_out = nullptr) { + FrontendRuntime runtime(env, cfg); + return runtime.build(cfg, ids_out); +} + +} // namespace q3tts_frontend diff --git a/tools/speech/backends/qwen3_tts/include/qwen3_tts/qwen3_tts_runtime.h b/tools/speech/backends/qwen3_tts/include/qwen3_tts/qwen3_tts_runtime.h new file mode 100644 index 000000000000..84caea836c9b --- /dev/null +++ b/tools/speech/backends/qwen3_tts/include/qwen3_tts/qwen3_tts_runtime.h @@ -0,0 +1,7 @@ +#pragma once + +namespace qwen3_tts { + +int run_cli(int argc, char **argv); + +} // namespace qwen3_tts diff --git a/tools/speech/backends/qwen3_tts/src/kernels/heads_pool.h b/tools/speech/backends/qwen3_tts/src/kernels/heads_pool.h new file mode 100644 index 000000000000..d93fe20600f8 --- /dev/null +++ b/tools/speech/backends/qwen3_tts/src/kernels/heads_pool.h @@ -0,0 +1,203 @@ +// X100 4-thread fp32 GEMV pool for the 15 cp lm-heads. +// (Production slice of cp_forward.h; in-process cp forward experiments live in ref/.) +#pragma once +#include +#include +#include +#include +#include + +#ifndef Q3TTS_HEADS_NTH +#define Q3TTS_HEADS_NTH 4 +#endif +#define NTH Q3TTS_HEADS_NTH +typedef struct { const uint16_t *wf, *xf; int fo, fi; float *y; int t0, t1; } Job; +static Job hd_job[NTH]; +static pthread_t hd_th[NTH]; +static int hd_core_base = 0; +static int hd_idle_yield = 0; +static int hd_main_work = 0; +static int hd_best_idx[NTH]; +static float hd_best_val[NTH]; +static volatile int hd_go[NTH], hd_done[NTH], hd_quit = 0; + +// w is fp16 (halves DDR traffic), x is f32 pre-narrowed by caller to fp16; widening MAC keeps f32 accum. +static inline float dot_f16(const uint16_t *w, const uint16_t *x, int n) { + size_t vl = __riscv_vsetvl_e16m4(128); + vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(0.f, vl); + for (int d = 0; d < n; d += (int)vl) { + acc = __riscv_vfwmacc_vv_f32m8(acc, __riscv_vle16_v_f16m4((const _Float16 *)(w + d), vl), + __riscv_vle16_v_f16m4((const _Float16 *)(x + d), vl), vl); + } + size_t vl32 = __riscv_vsetvl_e32m8(vl); + vfloat32m1_t r = __riscv_vfredusum_vs_f32m8_f32m1(acc, __riscv_vfmv_v_f_f32m1(0.f, 1), vl32); + return __riscv_vfmv_f_s_f32m1_f32(r); +} + +static void *hd_worker(void *a) { + long id = (long)a; + cpu_set_t cs; CPU_ZERO(&cs); CPU_SET(hd_core_base + id, &cs); + pthread_setaffinity_np(pthread_self(), sizeof(cs), &cs); + for (;;) { + while (!__atomic_load_n(&hd_go[id], __ATOMIC_ACQUIRE)) { + if (hd_quit) return NULL; + if (hd_idle_yield) { + sched_yield(); + } else { + __asm__ volatile("pause"); + } + } + __atomic_store_n(&hd_go[id], 0, __ATOMIC_RELAXED); + Job *j = &hd_job[id]; + if (j->y) { + for (int r = j->t0; r < j->t1; r++) { + __builtin_prefetch(j->wf + (long)(r + 1) * j->fi); + j->y[r] = dot_f16(j->wf + (long)r * j->fi, j->xf, j->fi); + } + } else { + float best = -1e30f; + int best_i = j->t0; + for (int r = j->t0; r < j->t1; r++) { + __builtin_prefetch(j->wf + (long)(r + 1) * j->fi); + float v = dot_f16(j->wf + (long)r * j->fi, j->xf, j->fi); + if (v > best) { + best = v; + best_i = r; + } + } + hd_best_val[id] = best; + hd_best_idx[id] = best_i; + } + __atomic_store_n(&hd_done[id], 1, __ATOMIC_RELEASE); + } +} +static void pools_init(void) { + const char *base = getenv("Q3TTS_HEADS_BASE"); + if (base) { + hd_core_base = atoi(base); + } + hd_idle_yield = getenv("Q3TTS_HEADS_IDLE_YIELD") != NULL; + hd_main_work = getenv("Q3TTS_HEADS_MAIN_WORK") != NULL; + for (long i = 0; i < NTH; i++) pthread_create(&hd_th[i], NULL, hd_worker, (void *)i); +} +static inline void prepare_xh(const float *x, uint16_t *xh, int i) { + size_t vl; + for (int d = 0; d < i; d += (int)vl) { + vl = __riscv_vsetvl_e32m4(i - d); + vfloat16m2_t h = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(x + d, vl), vl); + __riscv_vse16_v_f16m2((_Float16 *)(xh + d), h, vl); + } +} +static inline void mv_f16(const uint16_t *wf, int o, int i, const float *x, float *y) { + static uint16_t xh[4096]; + prepare_xh(x, xh, i); + const int parts = hd_main_work ? (NTH + 1) : NTH; + int per = (o + parts - 1) / parts; + for (int t = 0; t < NTH; t++) { + hd_job[t] = (Job){wf, xh, o, i, y, t * per, (t + 1) * per > o ? o : (t + 1) * per}; + __atomic_store_n(&hd_go[t], 1, __ATOMIC_RELEASE); + } + if (hd_main_work) { + int t0 = NTH * per; + int t1 = (NTH + 1) * per > o ? o : (NTH + 1) * per; + for (int r = t0; r < t1; r++) { + __builtin_prefetch(wf + (long)(r + 1) * i); + y[r] = dot_f16(wf + (long)r * i, xh, i); + } + } + for (int t = 0; t < NTH; t++) { + while (!__atomic_load_n(&hd_done[t], __ATOMIC_ACQUIRE)) { __asm__ volatile("pause"); } + __atomic_store_n(&hd_done[t], 0, __ATOMIC_RELAXED); + } +} +static inline int mv_f16_argmax(const uint16_t *wf, int o, int i, const float *x, float *best_out) { + static uint16_t xh[4096]; + prepare_xh(x, xh, i); + const int parts = hd_main_work ? (NTH + 1) : NTH; + int per = (o + parts - 1) / parts; + for (int t = 0; t < NTH; t++) { + hd_job[t] = (Job){wf, xh, o, i, NULL, t * per, (t + 1) * per > o ? o : (t + 1) * per}; + __atomic_store_n(&hd_go[t], 1, __ATOMIC_RELEASE); + } + int main_best_i = NTH * per; + float main_best = -1e30f; + if (hd_main_work) { + int t0 = NTH * per; + int t1 = (NTH + 1) * per > o ? o : (NTH + 1) * per; + for (int r = t0; r < t1; r++) { + __builtin_prefetch(wf + (long)(r + 1) * i); + float v = dot_f16(wf + (long)r * i, xh, i); + if (v > main_best) { + main_best = v; + main_best_i = r; + } + } + } + for (int t = 0; t < NTH; t++) { + while (!__atomic_load_n(&hd_done[t], __ATOMIC_ACQUIRE)) { __asm__ volatile("pause"); } + __atomic_store_n(&hd_done[t], 0, __ATOMIC_RELAXED); + } + int best_i = 0; + float best = -1e30f; + for (int t = 0; t < NTH; t++) { + if (hd_best_val[t] > best) { + best = hd_best_val[t]; + best_i = hd_best_idx[t]; + } + } + if (hd_main_work && main_best > best) { + best = main_best; + best_i = main_best_i; + } + if (best_out) { + *best_out = best; + } + return best_i; +} +static inline int mv_f16_argmax_eos(const uint16_t *wf, int o, int eos, int i, const float *x, float *best_out, float *eos_out) { + static uint16_t xh[4096]; + prepare_xh(x, xh, i); + const int parts = hd_main_work ? (NTH + 1) : NTH; + int per = (o + parts - 1) / parts; + for (int t = 0; t < NTH; t++) { + hd_job[t] = (Job){wf, xh, o, i, NULL, t * per, (t + 1) * per > o ? o : (t + 1) * per}; + __atomic_store_n(&hd_go[t], 1, __ATOMIC_RELEASE); + } + int main_best_i = NTH * per; + float main_best = -1e30f; + if (hd_main_work) { + int t0 = NTH * per; + int t1 = (NTH + 1) * per > o ? o : (NTH + 1) * per; + for (int r = t0; r < t1; r++) { + __builtin_prefetch(wf + (long)(r + 1) * i); + float v = dot_f16(wf + (long)r * i, xh, i); + if (v > main_best) { + main_best = v; + main_best_i = r; + } + } + } + for (int t = 0; t < NTH; t++) { + while (!__atomic_load_n(&hd_done[t], __ATOMIC_ACQUIRE)) { __asm__ volatile("pause"); } + __atomic_store_n(&hd_done[t], 0, __ATOMIC_RELAXED); + } + int best_i = 0; + float best = -1e30f; + for (int t = 0; t < NTH; t++) { + if (hd_best_val[t] > best) { + best = hd_best_val[t]; + best_i = hd_best_idx[t]; + } + } + if (hd_main_work && main_best > best) { + best = main_best; + best_i = main_best_i; + } + if (best_out) { + *best_out = best; + } + if (eos_out) { + *eos_out = dot_f16(wf + (long)eos * i, xh, i); + } + return best_i; +} diff --git a/tools/speech/backends/qwen3_tts/src/qwen3_tts_runtime.cpp b/tools/speech/backends/qwen3_tts/src/qwen3_tts_runtime.cpp new file mode 100644 index 000000000000..83fbf74ca74b --- /dev/null +++ b/tools/speech/backends/qwen3_tts/src/qwen3_tts_runtime.cpp @@ -0,0 +1,2326 @@ +#include + +#include "qwen3_tts_runtime.h" + +#include "q3tts_frontend.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "q3tts_audio_sdk.h" +#include "q3tts_codec_ort.h" + +namespace { + +using Clock = std::chrono::steady_clock; + +std::string env_str(const char *name, const std::string &fallback) { + const char *v = std::getenv(name); + return (v && *v) ? std::string(v) : fallback; +} + +int env_int(const char *name, int fallback) { + const char *v = std::getenv(name); + return (v && *v) ? std::atoi(v) : fallback; +} + +double env_double(const char *name, double fallback) { + const char *v = std::getenv(name); + if (!v || !*v) { + return fallback; + } + char *end = nullptr; + const double parsed = std::strtod(v, &end); + return end != v && std::isfinite(parsed) ? parsed : fallback; +} + +void set_default_env(const char *name, const char *value) { + if (!std::getenv(name)) { + setenv(name, value, 0); + } +} + +void set_env_override(const char *name, const std::string &value) { + setenv(name, value.c_str(), 1); +} + +void set_default_runtime_env(const std::string &talker_gguf) { + (void) talker_gguf; + if (!std::getenv("Q3TTS_DISABLE_SWIGLU_DOWN_FUSION")) { + set_default_env("GGML_CPU_FUSE_SWIGLU_DOWN_Q8", "1"); + } + set_default_env("LLAMA_CTX_PAD", "16"); + set_default_env("SPACEMIT_Q4_HP_M1_N64", "1"); + set_default_env("Q3TTS_CP_CTX", "16"); + set_default_env("Q3TTS_CP_CTX_MIN", "16"); +} + +void profile_event(bool enabled, Clock::time_point origin, const std::string &msg) { + if (!enabled) { + return; + } + const double ms = std::chrono::duration(Clock::now() - origin).count(); + const auto old_precision = std::cerr.precision(); + const auto old_flags = std::cerr.flags(); + std::cerr.setf(std::ios::fixed); + std::cerr.precision(2); + std::cerr << "[q3tts-profile +" << ms << "ms] " << msg << "\n"; + std::cerr.flags(old_flags); + std::cerr.precision(old_precision); +} + +bool exists(const std::string &path) { + std::ifstream f(path, std::ios::binary); + return f.good(); +} + +std::vector parse_int_list(const std::string &raw, const std::vector &fallback) { + if (raw.empty()) { + return fallback; + } + std::vector out; + std::string s = raw; + std::replace(s.begin(), s.end(), ';', ','); + std::stringstream ss(s); + std::string item; + while (std::getline(ss, item, ',')) { + if (!item.empty()) { + out.push_back(std::stoi(item)); + } + } + return out.empty() ? fallback : out; +} + +std::vector read_all(const std::string &path) { + std::ifstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("open failed: " + path); + } + f.seekg(0, std::ios::end); + const auto n = static_cast(f.tellg()); + f.seekg(0, std::ios::beg); + std::vector data(n); + f.read(reinterpret_cast(data.data()), data.size()); + return data; +} + +void write_all(const std::string &path, const uint8_t *data, size_t n) { + std::ofstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("write failed: " + path); + } + f.write(reinterpret_cast(data), n); +} + +uint16_t le16(const std::vector &d, size_t p) { + return static_cast(d[p]) | (static_cast(d[p + 1]) << 8); +} + +uint32_t le32(const std::vector &d, size_t p) { + return static_cast(d[p]) | + (static_cast(d[p + 1]) << 8) | + (static_cast(d[p + 2]) << 16) | + (static_cast(d[p + 3]) << 24); +} + +uint64_t le64(const std::vector &d, size_t p) { + return static_cast(le32(d, p)) | + (static_cast(le32(d, p + 4)) << 32); +} + +std::unordered_map> load_npz_stored(const std::string &path) { + auto zip = read_all(path); + std::unordered_map> out; + size_t p = 0; + while (p + 30 <= zip.size()) { + const uint32_t sig = le32(zip, p); + if (sig != 0x04034b50u) { + break; + } + const uint16_t flags = le16(zip, p + 6); + const uint16_t method = le16(zip, p + 8); + uint64_t comp_size = le32(zip, p + 18); + uint64_t uncomp_size = le32(zip, p + 22); + const uint16_t name_len = le16(zip, p + 26); + const uint16_t extra_len = le16(zip, p + 28); + const size_t name_pos = p + 30; + const size_t data_pos = name_pos + name_len + extra_len; + if ((comp_size == 0xffffffffULL || uncomp_size == 0xffffffffULL) && name_pos + name_len + extra_len <= zip.size()) { + size_t ep = name_pos + name_len; + const size_t extra_end = ep + extra_len; + while (ep + 4 <= extra_end) { + const uint16_t header_id = le16(zip, ep); + const uint16_t data_size = le16(zip, ep + 2); + ep += 4; + if (ep + data_size > extra_end) { + throw std::runtime_error("bad zip64 extra field in " + path); + } + if (header_id == 0x0001) { + size_t zp = ep; + if (uncomp_size == 0xffffffffULL && zp + 8 <= ep + data_size) { + uncomp_size = le64(zip, zp); + zp += 8; + } + if (comp_size == 0xffffffffULL && zp + 8 <= ep + data_size) { + comp_size = le64(zip, zp); + } + break; + } + ep += data_size; + } + } + if (data_pos + comp_size > zip.size()) { + throw std::runtime_error("bad zip entry size in " + path); + } + std::string name(reinterpret_cast(&zip[name_pos]), name_len); + if (method != 0 || (flags & 0x08u)) { + throw std::runtime_error("npz entry is compressed or uses data descriptor: " + name); + } + out[name] = std::vector(zip.begin() + data_pos, zip.begin() + data_pos + uncomp_size); + p = data_pos + comp_size; + } + return out; +} + +struct NpyArray { + std::string descr; + std::vector shape; + const uint8_t *data = nullptr; + size_t bytes = 0; +}; + +std::string trim(std::string s) { + auto is_ws = [](unsigned char c) { return std::isspace(c); }; + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [&](char c) { return !is_ws(c); })); + s.erase(std::find_if(s.rbegin(), s.rend(), [&](char c) { return !is_ws(c); }).base(), s.end()); + return s; +} + +using Hotwords = std::vector>; + +void add_hotword_spec(Hotwords &hotwords, const std::string &spec, const char *source) { + const size_t eq = spec.find('='); + if (eq == std::string::npos || eq == 0) { + throw std::runtime_error(std::string("bad ") + source + " hotword, expected FROM=TO"); + } + hotwords.emplace_back(spec.substr(0, eq), spec.substr(eq + 1)); +} + +void add_hotword_specs(Hotwords &hotwords, const std::string &specs, const char *source) { + std::stringstream ss(specs); + std::string item; + while (std::getline(ss, item, ';')) { + item = trim(item); + if (!item.empty()) { + add_hotword_spec(hotwords, item, source); + } + } +} + +Hotwords env_hotwords() { + Hotwords hotwords; + if (const char *v = std::getenv("Q3TTS_HOTWORDS"); v && *v) { + add_hotword_specs(hotwords, v, "Q3TTS_HOTWORDS"); + } + return hotwords; +} + +std::string apply_hotwords(std::string text, const Hotwords &hotwords) { + for (const auto &hw : hotwords) { + size_t pos = 0; + while ((pos = text.find(hw.first, pos)) != std::string::npos) { + text.replace(pos, hw.first.size(), hw.second); + pos += hw.second.size(); + } + } + return text; +} + +NpyArray parse_npy(const std::vector &npy) { + if (npy.size() < 16 || std::memcmp(npy.data(), "\x93NUMPY", 6) != 0) { + throw std::runtime_error("bad npy magic"); + } + const uint8_t major = npy[6]; + size_t header_len = 0; + size_t header_pos = 0; + if (major == 1) { + header_len = le16(npy, 8); + header_pos = 10; + } else if (major == 2 || major == 3) { + header_len = le32(npy, 8); + header_pos = 12; + } else { + throw std::runtime_error("unsupported npy version"); + } + if (header_pos + header_len > npy.size()) { + throw std::runtime_error("bad npy header length"); + } + std::string header(reinterpret_cast(&npy[header_pos]), header_len); + + NpyArray arr; + auto descr_pos = header.find("'descr'"); + if (descr_pos == std::string::npos) { + descr_pos = header.find("\"descr\""); + } + auto q1 = header.find('\'', header.find(':', descr_pos)); + auto q2 = header.find('\'', q1 + 1); + arr.descr = header.substr(q1 + 1, q2 - q1 - 1); + + auto shape_pos = header.find("'shape'"); + if (shape_pos == std::string::npos) { + shape_pos = header.find("\"shape\""); + } + auto l = header.find('(', shape_pos); + auto r = header.find(')', l); + std::stringstream ss(header.substr(l + 1, r - l - 1)); + std::string item; + while (std::getline(ss, item, ',')) { + item = trim(item); + if (!item.empty()) { + arr.shape.push_back(std::stoll(item)); + } + } + arr.data = npy.data() + header_pos + header_len; + arr.bytes = npy.size() - header_pos - header_len; + return arr; +} + +int64_t first_dim(const NpyArray &arr, const std::string &name) { + if (arr.shape.empty()) { + throw std::runtime_error("missing first dimension for " + name); + } + return arr.shape[0]; +} + +std::string default_model_dir() { + const char *home = std::getenv("HOME"); + if (home && *home) { + return std::string(home) + "/.cache/models/tts/qwen3-tts"; + } + return "/tmp/.cache/models/tts/qwen3-tts"; +} + +struct Args { + std::string npz; + std::string text; + std::string ref_wav; + std::string ref_bin; + std::string model_dir = env_str("Q3TTS_MODEL_DIR", default_model_dir()); + std::string talker_gguf = env_str("Q3TTS_TALKER_GGUF", "qwen3-tts-0.6b-talker-qkv-gateup-q8_0-side.gguf"); + std::string cp_gguf = env_str("Q3TTS_CP_GGUF", "qwen3-tts-0.6b-cp-qkv-gateup-rawq4.gguf"); + std::string language = env_str("Q3TTS_LANGUAGE", "auto"); + std::string clone_leadin = env_str("Q3TTS_CLONE_LEADIN", ""); + int frames = 60; + std::string wav = "cpp_driver.wav"; + Hotwords hotwords = env_hotwords(); + int play_rate = env_int("Q3TTS_PLAY_RATE", 24000); + int play_channels = env_int("Q3TTS_PLAY_CHANNELS", 1); + int play_device = env_int("Q3TTS_PLAY_DEVICE", -1); + int play_buffer = env_int("Q3TTS_PLAY_BUFFER", 1024); + int play_tail_ms = env_int("Q3TTS_PLAY_TAIL_MS", 300); + int play_drain_ms = env_int("Q3TTS_PLAY_DRAIN_MS", 250); + int play_segment_pause_ms = env_int("Q3TTS_PLAY_SEGMENT_PAUSE_MS", 120); + bool dump_ids = false; + bool dump_segments = false; + bool frontend_only = false; + bool no_clone_split = false; + bool play_segments = false; + bool stdin_segments = false; + bool talker_gguf_cli = false; +}; + +Args parse_args(int argc, char **argv) { + Args a; + for (int i = 1; i < argc; ++i) { + std::string k = argv[i]; + auto need = [&](const char *name) -> std::string { + if (i + 1 >= argc) { + throw std::runtime_error(std::string("missing value for ") + name); + } + return argv[++i]; + }; + if (k == "--npz") { + a.npz = need("--npz"); + } else if (k == "--text") { + a.text = need("--text"); + } else if (k == "--ref-wav") { + a.ref_wav = need("--ref-wav"); + } else if (k == "--ref-bin") { + a.ref_bin = need("--ref-bin"); + } else if (k == "--model-dir") { + a.model_dir = need("--model-dir"); + } else if (k == "--talker-gguf") { + a.talker_gguf = need("--talker-gguf"); + a.talker_gguf_cli = true; + } else if (k == "--cp-gguf") { + a.cp_gguf = need("--cp-gguf"); + } else if (k == "--language") { + a.language = need("--language"); + } else if (k == "--dump-ids") { + a.dump_ids = true; + } else if (k == "--dump-segments") { + a.dump_segments = true; + } else if (k == "--frontend-only") { + a.frontend_only = true; + } else if (k == "--no-clone-split") { + a.no_clone_split = true; + } else if (k == "--play-segments") { + a.play_segments = true; + } else if (k == "--stdin-segments") { + a.stdin_segments = true; + } else if (k == "--clone-leadin") { + a.clone_leadin = need("--clone-leadin"); + } else if (k == "--play-rate") { + a.play_rate = std::stoi(need("--play-rate")); + } else if (k == "--play-channels") { + a.play_channels = std::stoi(need("--play-channels")); + } else if (k == "--play-device") { + a.play_device = std::stoi(need("--play-device")); + } else if (k == "--play-buffer") { + a.play_buffer = std::stoi(need("--play-buffer")); + } else if (k == "--play-tail-ms") { + a.play_tail_ms = std::stoi(need("--play-tail-ms")); + } else if (k == "--play-drain-ms") { + a.play_drain_ms = std::stoi(need("--play-drain-ms")); + } else if (k == "--play-segment-pause-ms") { + a.play_segment_pause_ms = std::stoi(need("--play-segment-pause-ms")); + } else if (k == "--hotword") { + add_hotword_spec(a.hotwords, need("--hotword"), "--hotword"); + } else if (k == "--hotwords") { + add_hotword_specs(a.hotwords, need("--hotwords"), "--hotwords"); + } else if (k == "--frames") { + a.frames = std::stoi(need("--frames")); + } else if (k == "--wav") { + a.wav = need("--wav"); + } else { + throw std::runtime_error("unknown arg: " + k); + } + } + return a; +} + +bool has_clone_reference(const Args &args) { + return !args.ref_bin.empty() || !args.ref_wav.empty(); +} + +bool has_full_reference_prompt(const std::string &ref_bin) { + if (ref_bin.empty()) { + return false; + } + try { + return q3tts_frontend::read_reference_prompt_bin(ref_bin).full_prompt(); + } catch (const std::exception &) { + return false; + } +} + +void maybe_select_full_prompt_talker(Args &args) { + if (!has_full_reference_prompt(args.ref_bin)) { + return; + } + if (args.talker_gguf_cli || std::getenv("Q3TTS_TALKER_GGUF")) { + return; + } + const std::string full_talker = + env_str("Q3TTS_FULL_PROMPT_TALKER_GGUF", "qwen3-tts-0.6b-talker-qkv-gateup-q8_0-side.gguf"); + if (full_talker.empty() || full_talker == "0") { + return; + } + try { + (void)q3tts_frontend::first_existing({ + q3tts_frontend::path_join(q3tts_frontend::path_join(args.model_dir, "gguf"), full_talker), + q3tts_frontend::path_join(args.model_dir, full_talker), + full_talker, + }); + args.talker_gguf = full_talker; + std::cout << "full_prompt_talker " << full_talker << "\n"; + } catch (const std::exception &) { + std::cerr << "warning: full prompt talker not found, keeping " << args.talker_gguf << "\n"; + } +} + +std::vector> +load_ref_decode_prefix(const std::string &ref_bin) { + if (ref_bin.empty()) { + return {}; + } + auto ref_prompt = q3tts_frontend::read_reference_prompt_bin(ref_bin); + if (!ref_prompt.full_prompt()) { + return {}; + } + return ref_prompt.ref_codes; +} + +size_t reference_audio_cut_samples(size_t ref_frames, size_t generated_frames, size_t wav_samples) { + const size_t total_frames = ref_frames + generated_frames; + if (ref_frames == 0 || total_frames == 0 || wav_samples == 0) { + return 0; + } + return static_cast( + (static_cast(ref_frames) * static_cast(wav_samples)) / + static_cast(total_frames)); +} + +std::vector decode_with_reference_prefix( + q3tts_codec::DecoderPool &codec, + const std::vector> &generated, + const std::vector> &ref_prefix, + const std::vector &buckets, + int first_chunk, + int chunk, + int ctx_limit) { + if (ref_prefix.empty()) { + return codec.decode_chunks(generated, buckets, first_chunk, chunk, ctx_limit); + } + + std::vector> all; + all.reserve(ref_prefix.size() + generated.size()); + all.insert(all.end(), ref_prefix.begin(), ref_prefix.end()); + all.insert(all.end(), generated.begin(), generated.end()); + + auto wav = codec.decode_chunks(all, buckets, first_chunk, chunk, ctx_limit); + const size_t ref_samples = reference_audio_cut_samples(ref_prefix.size(), generated.size(), wav.size()); + if (wav.size() <= ref_samples) { + wav.clear(); + } else { + wav.erase(wav.begin(), wav.begin() + static_cast(ref_samples)); + } + return wav; +} + +void apply_biquad_lowpass(std::vector &samples, double cutoff_hz) { + if (samples.size() < 3 || cutoff_hz <= 0.0) { + return; + } + + constexpr double sample_rate = 24000.0; + constexpr double pi = 3.14159265358979323846; + cutoff_hz = std::max(1000.0, std::min(cutoff_hz, sample_rate * 0.49)); + + const double q = 0.7071067811865476; + const double w0 = 2.0 * pi * cutoff_hz / sample_rate; + const double c = std::cos(w0); + const double alpha = std::sin(w0) / (2.0 * q); + const double a0 = 1.0 + alpha; + const double b0 = ((1.0 - c) * 0.5) / a0; + const double b1 = (1.0 - c) / a0; + const double b2 = ((1.0 - c) * 0.5) / a0; + const double a1 = (-2.0 * c) / a0; + const double a2 = (1.0 - alpha) / a0; + + auto pass = [&](std::vector &x) { + double x1 = x.front(); + double x2 = x.front(); + double y1 = x.front(); + double y2 = x.front(); + for (float &sample : x) { + const double x0 = sample; + const double y0 = b0 * x0 + b1 * x1 + b2 * x2 - a1 * y1 - a2 * y2; + x2 = x1; + x1 = x0; + y2 = y1; + y1 = y0; + sample = static_cast(std::max(-1.0, std::min(1.0, y0))); + } + }; + + pass(samples); + pass(samples); +} + +void apply_peak_normalize(std::vector &samples, double target_db, double max_gain_db) { + if (samples.empty()) { + return; + } + double peak = 0.0; + for (float x : samples) { + peak = std::max(peak, static_cast(std::abs(x))); + } + if (peak <= 0.0) { + return; + } + + const double target = std::pow(10.0, target_db / 20.0); + const double max_gain = std::pow(10.0, std::max(0.0, max_gain_db) / 20.0); + const double gain = std::min(target / peak, max_gain); + for (float &x : samples) { + x = static_cast(std::max(-1.0, std::min(1.0, static_cast(x) * gain))); + } +} + +void postprocess_audio_f32(std::vector &samples) { + if (env_int("Q3TTS_AUDIO_POSTPROCESS", 1) == 0 || samples.empty()) { + return; + } + apply_biquad_lowpass(samples, env_double("Q3TTS_AUDIO_LOWPASS_HZ", 9000.0)); + apply_peak_normalize( + samples, + env_double("Q3TTS_AUDIO_PEAK_DB", -3.0), + env_double("Q3TTS_AUDIO_MAX_GAIN_DB", 9.0)); +} + +std::vector f32_to_pcm16(std::vector samples, bool postprocess) { + if (postprocess) { + postprocess_audio_f32(samples); + } + + std::vector pcm; + pcm.reserve(samples.size()); + for (float x : samples) { + x = std::max(-1.0f, std::min(1.0f, x)); + pcm.push_back(static_cast(x * 32767.0f)); + } + return pcm; +} + +void postprocess_pcm16(std::vector &samples) { + if (env_int("Q3TTS_AUDIO_POSTPROCESS", 1) == 0 || samples.empty()) { + return; + } + + std::vector f32; + f32.reserve(samples.size()); + for (int16_t sample : samples) { + f32.push_back(static_cast(sample) / 32768.0f); + } + postprocess_audio_f32(f32); + samples = f32_to_pcm16(std::move(f32), false); +} + +void write_wav_i16_samples(const std::string &path, const std::vector &samples); + +void write_wav_i16(const std::string &path, const std::vector &samples) { + write_wav_i16_samples(path, f32_to_pcm16(samples, true)); +} + +void write_wav_i16_samples(const std::string &path, const std::vector &samples) { + std::ofstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("write wav failed: " + path); + } + const uint32_t data_bytes = static_cast(samples.size() * sizeof(int16_t)); + const uint32_t riff_size = 36 + data_bytes; + auto w16 = [&](uint16_t v) { + char b[2] = {static_cast(v & 0xff), static_cast((v >> 8) & 0xff)}; + f.write(b, 2); + }; + auto w32 = [&](uint32_t v) { + char b[4] = {static_cast(v & 0xff), static_cast((v >> 8) & 0xff), + static_cast((v >> 16) & 0xff), static_cast((v >> 24) & 0xff)}; + f.write(b, 4); + }; + f.write("RIFF", 4); + w32(riff_size); + f.write("WAVEfmt ", 8); + w32(16); + w16(1); + w16(1); + w32(24000); + w32(48000); + w16(2); + w16(16); + f.write("data", 4); + w32(data_bytes); + for (int16_t s : samples) { + w16(static_cast(s)); + } +} + +std::vector read_wav_i16_mono_24k(const std::string &path) { + auto wav = read_all(path); + if (wav.size() < 44 || std::memcmp(wav.data(), "RIFF", 4) != 0 || + std::memcmp(wav.data() + 8, "WAVE", 4) != 0) { + throw std::runtime_error("bad wav: " + path); + } + + uint16_t audio_format = 0; + uint16_t channels = 0; + uint32_t sample_rate = 0; + uint16_t bits_per_sample = 0; + size_t data_pos = 0; + size_t data_bytes = 0; + for (size_t p = 12; p + 8 <= wav.size();) { + const std::string id(reinterpret_cast(&wav[p]), 4); + const uint32_t n = le32(wav, p + 4); + const size_t body = p + 8; + if (body + n > wav.size()) { + throw std::runtime_error("bad wav chunk size: " + path); + } + if (id == "fmt " && n >= 16) { + audio_format = le16(wav, body); + channels = le16(wav, body + 2); + sample_rate = le32(wav, body + 4); + bits_per_sample = le16(wav, body + 14); + } else if (id == "data") { + data_pos = body; + data_bytes = n; + } + p = body + n + (n & 1u); + } + if (audio_format != 1 || channels != 1 || sample_rate != 24000 || bits_per_sample != 16 || + data_pos == 0 || data_bytes == 0) { + throw std::runtime_error("expected PCM16 mono 24k wav: " + path); + } + std::vector samples(data_bytes / sizeof(int16_t)); + for (size_t i = 0; i < samples.size(); ++i) { + samples[i] = static_cast(le16(wav, data_pos + i * 2)); + } + return samples; +} + +size_t utf8_char_len(unsigned char c) { + if ((c & 0x80u) == 0) { + return 1; + } + if ((c & 0xe0u) == 0xc0u) { + return 2; + } + if ((c & 0xf0u) == 0xe0u) { + return 3; + } + if ((c & 0xf8u) == 0xf0u) { + return 4; + } + return 1; +} + +size_t utf8_len(const std::string &s) { + size_t n = 0; + for (size_t i = 0; i < s.size();) { + i += std::min(utf8_char_len(static_cast(s[i])), s.size() - i); + ++n; + } + return n; +} + +uint32_t utf8_codepoint(const std::string &s, size_t i, size_t n) { + const auto b0 = static_cast(s[i]); + if (n == 1) { + return b0; + } + if (n == 2 && i + 1 < s.size()) { + return (static_cast(b0 & 0x1fu) << 6) | + static_cast(static_cast(s[i + 1]) & 0x3fu); + } + if (n == 3 && i + 2 < s.size()) { + return (static_cast(b0 & 0x0fu) << 12) | + (static_cast(static_cast(s[i + 1]) & 0x3fu) << 6) | + static_cast(static_cast(s[i + 2]) & 0x3fu); + } + if (n == 4 && i + 3 < s.size()) { + return (static_cast(b0 & 0x07u) << 18) | + (static_cast(static_cast(s[i + 1]) & 0x3fu) << 12) | + (static_cast(static_cast(s[i + 2]) & 0x3fu) << 6) | + static_cast(static_cast(s[i + 3]) & 0x3fu); + } + return b0; +} + +std::string lower_ascii_copy(std::string s) { + for (char &c : s) { + if (c >= 'A' && c <= 'Z') { + c = static_cast(c - 'A' + 'a'); + } + } + return s; +} + +bool contains_ascii_letter(const std::string &s) { + for (unsigned char c : s) { + if ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) { + return true; + } + } + return false; +} + +bool contains_cjk(const std::string &s) { + for (size_t i = 0; i < s.size();) { + const size_t n = std::min(utf8_char_len(static_cast(s[i])), s.size() - i); + const uint32_t cp = utf8_codepoint(s, i, n); + if ((cp >= 0x3400 && cp <= 0x9fff) || + (cp >= 0xf900 && cp <= 0xfaff) || + (cp >= 0x20000 && cp <= 0x2ebef)) { + return true; + } + i += n; + } + return false; +} + +bool in_set(const std::vector &items, const std::string &x) { + return std::find(items.begin(), items.end(), x) != items.end(); +} + +bool starts_with(const std::string &s, const std::string &prefix) { + return s.rfind(prefix, 0) == 0; +} + +bool ends_with(const std::string &s, const std::string &suffix) { + return s.size() >= suffix.size() && + s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0; +} + +bool prefer_chinese_sentence_mark(const std::string &segment, const std::string &language) { + const std::string l = lower_ascii_copy(language); + if (l == "chinese" || l == "zh" || l == "zh-cn") { + return true; + } + if (l == "english" || l == "en" || l == "en-us") { + return false; + } + if (contains_cjk(segment)) { + return true; + } + if (contains_ascii_letter(segment)) { + return false; + } + return true; +} + +std::string sentence_mark_for(const std::string &segment, const std::string &language) { + return prefer_chinese_sentence_mark(segment, language) ? "\xe3\x80\x82" : "."; +} + +std::string normalize_segment_end(std::string segment, const std::string &language) { + const std::string mark = sentence_mark_for(segment, language); + const std::vector weak = {"\xef\xbc\x8c", ",", "\xe3\x80\x81", "\xef\xbc\x9a", ":"}; + for (const auto &punct : weak) { + if (ends_with(segment, punct)) { + segment.replace(segment.size() - punct.size(), punct.size(), mark); + break; + } + } + return segment; +} + +std::string ensure_sentence_end(std::string segment, const std::string &language) { + segment = normalize_segment_end(trim(segment), language); + if (segment.empty()) { + return segment; + } + const std::vector final = { + "\xe3\x80\x82", ".", "\xef\xbc\x81", "\xef\xbc\x9f", "!", "?", "\xef\xbc\x9b", ";", + "\xef\xbc\x8c", ",", "\xe3\x80\x81", "\xef\xbc\x9a", ":"}; + for (const auto &punct : final) { + if (ends_with(segment, punct)) { + return normalize_segment_end(segment, language); + } + } + segment += sentence_mark_for(segment, language); + return segment; +} + +std::vector split_on_punct(const std::string &text, + const std::vector &breaks, + size_t min_chars_before_break) { + std::vector out; + std::string cur; + for (size_t i = 0; i < text.size();) { + const size_t n = std::min(utf8_char_len(static_cast(text[i])), text.size() - i); + const std::string ch = text.substr(i, n); + cur += ch; + i += n; + if (in_set(breaks, ch) && utf8_len(cur) >= min_chars_before_break) { + std::string item = trim(cur); + if (!item.empty()) { + out.push_back(item); + } + cur.clear(); + } + } + std::string tail = trim(cur); + if (!tail.empty()) { + out.push_back(tail); + } + return out; +} + +std::vector split_by_max_chars(const std::string &text, size_t max_chars) { + std::vector out; + std::string cur; + size_t cur_chars = 0; + size_t last_soft_byte = std::string::npos; + const size_t min_soft_chars = std::min(8, max_chars); + size_t chars_at_last_soft = 0; + const auto soft = std::vector{" ", "\t", "\xef\xbc\x8c", ",", "\xe3\x80\x81", "\xef\xbc\x9a", ":"}; + + auto flush = [&]() { + std::string item = trim(cur); + if (!item.empty()) { + out.push_back(item); + } + cur.clear(); + cur_chars = 0; + last_soft_byte = std::string::npos; + chars_at_last_soft = 0; + }; + + for (size_t i = 0; i < text.size();) { + const size_t n = std::min(utf8_char_len(static_cast(text[i])), text.size() - i); + const std::string ch = text.substr(i, n); + if (cur_chars >= min_soft_chars && in_set(soft, ch)) { + last_soft_byte = cur.size() + n; + chars_at_last_soft = cur_chars + 1; + } + cur += ch; + ++cur_chars; + i += n; + if (cur_chars > max_chars) { + if (last_soft_byte != std::string::npos && chars_at_last_soft >= min_soft_chars) { + std::string head = trim(cur.substr(0, last_soft_byte)); + std::string tail = trim(cur.substr(last_soft_byte)); + if (!head.empty()) { + out.push_back(head); + } + cur = tail; + cur_chars = utf8_len(cur); + } else { + flush(); + } + last_soft_byte = std::string::npos; + chars_at_last_soft = 0; + } + } + flush(); + return out; +} + +std::vector split_clone_text(const std::string &text, + bool full_prompt_ref, + const std::string &language) { + const bool latin_text = !prefer_chinese_sentence_mark(text, language); + int hard_max = env_int("Q3TTS_CLONE_MAX_CHARS", full_prompt_ref ? (latin_text ? 96 : 16) : 0); + if (latin_text) { + hard_max = env_int("Q3TTS_CLONE_MAX_CHARS_LATIN", hard_max); + } + const size_t max_chars = hard_max > 0 ? static_cast(std::max(8, hard_max)) : 28UL; + const size_t weak_min_chars = + static_cast(std::max(1, env_int("Q3TTS_FULL_PROMPT_WEAK_MIN_CHARS", + full_prompt_ref ? (latin_text ? 24 : 6) : 12))); + const auto strong = std::vector{ + "\xe3\x80\x82", ".", "\xef\xbc\x81", "\xef\xbc\x9f", "!", "?", "\xef\xbc\x9b", ";"}; + const auto weak = std::vector{ + "\xef\xbc\x8c", ",", "\xe3\x80\x81", "\xef\xbc\x9a", ":"}; + std::vector first = split_on_punct(text, strong, 1); + std::vector out; + for (const auto &s : first) { + if (utf8_len(s) <= max_chars) { + out.push_back(s); + continue; + } + auto pieces = split_on_punct(s, weak, weak_min_chars); + if (pieces.size() <= 1) { + if (hard_max > 0) { + auto hard = split_by_max_chars(s, max_chars); + out.insert(out.end(), hard.begin(), hard.end()); + } else { + out.push_back(s); + } + } else { + for (const auto &piece : pieces) { + if (hard_max <= 0 || utf8_len(piece) <= max_chars) { + out.push_back(piece); + } else { + auto hard = split_by_max_chars(piece, max_chars); + out.insert(out.end(), hard.begin(), hard.end()); + } + } + } + } + if (out.empty()) { + std::string item = trim(text); + if (!item.empty()) { + out.push_back(item); + } + } + return out; +} + +std::vector split_stdin_text(const std::string &text, + bool full_prompt_ref, + const std::string &language) { + (void) full_prompt_ref; + if (env_int("Q3TTS_STDIN_SAFE_SPLIT", 1) == 0) { + std::string item = ensure_sentence_end(text, language); + return item.empty() ? std::vector{} : std::vector{item}; + } + const bool latin_text = !prefer_chinese_sentence_mark(text, language); + const int default_max_chars = latin_text ? 96 : 48; + const size_t max_chars = static_cast(std::max( + 8, env_int("Q3TTS_STDIN_MAX_CHARS_PER_SEGMENT", default_max_chars))); + const size_t weak_min_chars = static_cast( + std::max(1, env_int("Q3TTS_STDIN_WEAK_MIN_CHARS", 24))); + const auto strong = std::vector{ + "\xe3\x80\x82", ".", "\xef\xbc\x81", "\xef\xbc\x9f", "!", "?", "\xef\xbc\x9b", ";"}; + const auto weak = std::vector{ + "\xef\xbc\x8c", ",", "\xe3\x80\x81", "\xef\xbc\x9a", ":"}; + std::vector out; + for (const auto &sentence : split_on_punct(text, strong, 1)) { + std::vector pieces; + if (utf8_len(sentence) <= max_chars) { + pieces.push_back(sentence); + } else { + pieces = split_on_punct(sentence, weak, weak_min_chars); + if (pieces.size() <= 1) { + pieces = split_by_max_chars(sentence, max_chars); + } + } + for (const auto &piece : pieces) { + if (utf8_len(piece) <= max_chars) { + std::string item = ensure_sentence_end(piece, language); + if (!item.empty()) { + out.push_back(item); + } + } else { + for (const auto &hard : split_by_max_chars(piece, max_chars)) { + std::string item = ensure_sentence_end(hard, language); + if (!item.empty()) { + out.push_back(item); + } + } + } + } + } + if (out.empty()) { + std::string item = ensure_sentence_end(text, language); + if (!item.empty()) { + out.push_back(item); + } + } + return out; +} + +bool needs_clone_leadin(const std::string &segment, const std::string &leadin) { + if (leadin.empty()) { + return false; + } + if (env_int("Q3TTS_CLONE_LEADIN_ALWAYS", 1) != 0) { + return true; + } + const std::vector stable_prefixes = { + "\xe4\xbd\xa0\xe5\xa5\xbd", + "\xe6\x82\xa8\xe5\xa5\xbd", + "hello", + "Hello"}; + for (const auto &prefix : stable_prefixes) { + if (starts_with(segment, prefix)) { + return false; + } + } + return true; +} + +int clone_leadin_trim_frames(const std::string &leadin) { + const std::string default_leadin = "\xe4\xbd\xa0\xe5\xa5\xbd\xef\xbc\x8c"; + const int default_frames = (leadin == default_leadin) ? 7 : 0; + return std::max(0, env_int("Q3TTS_CLONE_LEADIN_TRIM_FRAMES", default_frames)); +} + +size_t clone_leadin_trim_samples(const std::string &leadin) { + const int explicit_samples = env_int("Q3TTS_CLONE_LEADIN_TRIM_SAMPLES", -1); + if (explicit_samples >= 0) { + return static_cast(explicit_samples); + } + return static_cast(clone_leadin_trim_frames(leadin)) * 1920UL; +} + +void trim_leading_samples(std::vector &samples, size_t n) { + if (n == 0) { + return; + } + if (n >= samples.size()) { + samples.clear(); + return; + } + samples.erase(samples.begin(), samples.begin() + static_cast(n)); +} + +std::vector prepare_clone_segments(const Args &args) { + const bool full_prompt_ref = has_full_reference_prompt(args.ref_bin); + std::vector segments = split_clone_text(args.text, full_prompt_ref, args.language); + for (auto &segment : segments) { + segment = normalize_segment_end(segment, args.language); + } + return segments; +} + +std::string shell_quote(const std::string &s) { + std::string out = "'"; + for (char c : s) { + if (c == '\'') { + out += "'\\''"; + } else { + out += c; + } + } + out += "'"; + return out; +} + +double rms_i16(const std::vector &samples, size_t begin, size_t end) { + double sum = 0.0; + for (size_t i = begin; i < end; ++i) { + const double x = static_cast(samples[i]) / 32768.0; + sum += x * x; + } + return std::sqrt(sum / static_cast(end - begin)); +} + +void trim_trailing_silence(std::vector &samples) { + const size_t block = 2400; + const size_t keep = 2400; + const double threshold = 0.003; + for (size_t end = samples.size(); end > 0;) { + const size_t begin = end > block ? end - block : 0; + if (rms_i16(samples, begin, end) > threshold) { + const size_t trimmed = std::min(samples.size(), end + keep); + samples.resize(trimmed); + return; + } + end = begin; + } +} + +double seconds_since(Clock::time_point a, Clock::time_point b); +std::string model_file(const std::string &model_dir, const std::string &subdir, const std::string &name); + +int run_clone_split(const Args &args, const char *argv0, const std::vector &segments) { + (void)argv0; + if (segments.empty()) { + throw std::runtime_error("empty text"); + } + + const auto t0 = Clock::now(); + const bool profile = env_int("Q3TTS_PROFILE", 0) != 0; + profile_event(profile, t0, "clone_split_start segments=" + std::to_string(segments.size())); + const std::string pid = std::to_string(static_cast(getpid())); + std::vector temps; + auto cleanup = [&]() { + for (const auto &p : temps) { + std::remove(p.c_str()); + } + }; + + try { + if (!std::getenv("Q3TTS_SKIP_TCM_CLEAR")) { + const int clear_rc = std::system("spacemit-tcm-smi -c >/dev/null 2>&1"); + (void)clear_rc; + } + profile_event(profile, t0, "tcm_clear_done"); + + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "q3tts_cpp_driver"); + std::string ref_bin = args.ref_bin; + if (ref_bin.empty() && !args.ref_wav.empty()) { + ref_bin = "/tmp/q3tts_split_" + pid + "_ref.spk.bin"; + temps.push_back(ref_bin); + auto spk = q3tts_frontend::run_speaker_encoder( + env, q3tts_frontend::speaker_encoder_path(args.model_dir), args.ref_wav, + env_int("Q3TTS_FRONTEND_THREADS", 2)); + q3tts_frontend::write_speaker_bin(ref_bin, spk); + } + profile_event(profile, t0, "ref_ready"); + + q3tts_frontend::FrontendConfig base_fc; + base_fc.model_dir = args.model_dir; + base_fc.ref_bin = ref_bin; + base_fc.language = args.language; + base_fc.talker_gguf = args.talker_gguf; + base_fc.cp_gguf = args.cp_gguf; + base_fc.frontend_threads = env_int("Q3TTS_FRONTEND_THREADS", 2); + base_fc.full_prompt_non_streaming = env_int("Q3TTS_FULL_PROMPT_NON_STREAMING", 0) != 0; + q3tts_frontend::FrontendRuntime frontend(env, base_fc); + profile_event(profile, t0, "frontend_runtime_ready"); + + struct SegmentJob { + std::string prefill; + std::string trailing; + std::string pad; + std::string codes; + std::string done; + std::string wav; + std::string text; + std::string synth_text; + size_t leadin_trim_samples = 0; + int64_t np = 0; + int64_t nt = 0; + }; + std::vector jobs; + jobs.reserve(segments.size()); + for (size_t i = 0; i < segments.size(); ++i) { + SegmentJob job; + const std::string base = "/tmp/q3tts_split_" + pid + "_" + std::to_string(i); + job.prefill = base + "_prefill.bin"; + job.trailing = base + "_trailing.bin"; + job.pad = base + "_pad.bin"; + job.codes = base + "_codes.bin"; + job.done = base + "_done"; + job.wav = base + ".wav"; + job.text = segments[i]; + if (needs_clone_leadin(job.text, args.clone_leadin)) { + job.synth_text = args.clone_leadin + job.text; + job.leadin_trim_samples = clone_leadin_trim_samples(args.clone_leadin); + } else { + job.synth_text = job.text; + } + temps.push_back(job.prefill); + temps.push_back(job.trailing); + temps.push_back(job.pad); + temps.push_back(job.codes); + temps.push_back(job.done); + temps.push_back(job.wav); + + q3tts_frontend::FrontendConfig fc = base_fc; + fc.text = job.synth_text; + auto input = frontend.build(fc); + write_all(job.prefill, + reinterpret_cast(input.prefill.data()), + input.prefill.size() * sizeof(float)); + write_all(job.trailing, + reinterpret_cast(input.trailing.data()), + input.trailing.size() * sizeof(float)); + write_all(job.pad, + reinterpret_cast(input.pad.data()), + input.pad.size() * sizeof(float)); + job.np = input.n_prefill; + job.nt = input.n_trailing; + jobs.push_back(std::move(job)); + profile_event(profile, t0, "frontend_segment_done i=" + std::to_string(i)); + } + + const std::string job_list = "/tmp/q3tts_split_" + pid + "_jobs.txt"; + temps.push_back(job_list); + { + std::ofstream jf(job_list); + if (!jf) { + throw std::runtime_error("write jobs failed: " + job_list); + } + for (const auto &job : jobs) { + jf << job.prefill << " " << job.np << " " + << job.trailing << " " << job.nt << " " + << job.pad << " " << args.frames << " " + << job.codes << " " << job.done << "\n"; + } + } + profile_event(profile, t0, "job_list_ready"); + + const int codec_threads = env_int("Q3TTS_CODEC_THREADS", 3); + if (const char *aff = std::getenv("Q3TTS_CODEC_AFFINITY"); aff && *aff) { + setenv("SPACEMIT_EP_INTRA_THREAD_AFFINITY", aff, 1); + if (!std::getenv("SPACEMIT_EP_INTRA_THREAD_NUM")) { + setenv("SPACEMIT_EP_INTRA_THREAD_NUM", std::to_string(codec_threads).c_str(), 1); + } + if (!std::getenv("SPACEMIT_EP_USE_GLOBAL_INTRA_THREAD")) { + setenv("SPACEMIT_EP_USE_GLOBAL_INTRA_THREAD", "0", 1); + } + } + + const std::vector buckets = + parse_int_list(env_str("Q3TTS_CODEC_BUCKETS", ""), std::vector{50}); + const int chunk = env_int("Q3TTS_CODEC_CHUNK", 50); + const int first_chunk = env_int("Q3TTS_CODEC_FIRST_CHUNK", chunk); + const int ctx_limit = env_int("Q3TTS_CODEC_CTX", 25); + +#ifndef Q3TTS_ENABLE_SDK_AUDIO + if (args.play_segments) { + throw std::runtime_error("--play-segments requires Q3TTS_ENABLE_SDK_AUDIO build"); + } +#else + std::unique_ptr realtime_player; + if (args.play_segments) { + realtime_player = std::make_unique( + args.play_rate, args.play_channels, args.play_device, args.play_buffer, + args.play_tail_ms, args.play_drain_ms, args.play_segment_pause_ms); + } +#endif + + const std::string driver = env_str("TALKER_DRIVER", "./talker_driver"); + const std::string talker_cpuset = env_str("TALKER_CPUSET", "4-7"); + const std::string talker_gguf = model_file(args.model_dir, "gguf", args.talker_gguf); + const std::string cp_gguf = model_file(args.model_dir, "gguf", args.cp_gguf); + std::stringstream cmd; + cmd << "taskset -c " << shell_quote(talker_cpuset) + << " " << shell_quote(driver) + << " " << shell_quote(talker_gguf) + << " " << shell_quote(cp_gguf) + << " --jobs " << shell_quote(job_list); + std::atomic talker_rc{-1}; + std::thread talker_thread([&]() { + talker_rc.store(std::system(cmd.str().c_str())); + }); + profile_event(profile, t0, "talker_launched"); + + q3tts_codec::DecoderPoolConfig codec_cfg; + codec_cfg.model_dir = args.model_dir; + codec_cfg.buckets = buckets; + codec_cfg.intra_threads = codec_threads; + codec_cfg.on_bucket_warm = [&](int b) { + profile_event(profile, t0, "codec_warm bucket=" + std::to_string(b)); + }; + q3tts_codec::DecoderPool codec(env, codec_cfg); + profile_event(profile, t0, "codec_ready"); + + const auto ref_decode_prefix = load_ref_decode_prefix(ref_bin); + + auto decode_codes = [&](const std::string &path, size_t *frame_count) -> std::vector { + auto data = read_all(path); + const size_t frame_bytes = sizeof(int32_t) * 16; + if (data.size() % frame_bytes != 0) { + throw std::runtime_error("bad code file size: " + path); + } + if (frame_count) { + *frame_count = data.size() / frame_bytes; + } + std::vector> frames(data.size() / frame_bytes); + if (!frames.empty()) { + std::memcpy(frames.data(), data.data(), data.size()); + } + return decode_with_reference_prefix(codec, frames, ref_decode_prefix, buckets, first_chunk, chunk, ctx_limit); + }; + + auto decode_job_samples = [&](const SegmentJob &job) -> std::vector { + size_t code_frames = 0; + auto wav_f32 = decode_codes(job.codes, &code_frames); + if (args.frames > 0 && code_frames >= static_cast(args.frames)) { + std::cout << "clone_segment_truncated" + << " frames " << code_frames + << " max " << args.frames + << " text " << job.text << std::endl; + if (env_int("Q3TTS_ALLOW_TRUNCATED", 0) == 0) { + throw std::runtime_error( + "clone split segment reached max frames without EOS: " + job.text); + } + } + std::vector samples = f32_to_pcm16(std::move(wav_f32), false); + trim_leading_samples(samples, job.leadin_trim_samples); + trim_trailing_silence(samples); + if (samples.empty() || rms_i16(samples, 0, samples.size()) <= 0.003) { + throw std::runtime_error("persistent clone split segment produced silent audio"); + } + postprocess_pcm16(samples); + return samples; + }; + + if (args.play_segments) { +#ifndef Q3TTS_ENABLE_SDK_AUDIO + throw std::runtime_error("--play-segments requires Q3TTS_ENABLE_SDK_AUDIO build"); +#else + try { + std::vector merged; + double total_gap = 0.0; + double max_gap = 0.0; + bool got_first_queue = false; + Clock::time_point t_first_queue = t0; + Clock::time_point t_generation_done = t0; + + for (size_t i = 0; i < jobs.size(); ++i) { + const auto wait0 = Clock::now(); + while (!exists(jobs[i].done)) { + if (talker_rc.load() != -1) { + throw std::runtime_error("talker stopped before segment marker"); + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + const double wait_s = seconds_since(wait0, Clock::now()); + + const auto decode0 = Clock::now(); + auto samples = decode_job_samples(jobs[i]); + write_wav_i16_samples(jobs[i].wav, samples); + const double decode_s = seconds_since(decode0, Clock::now()); + const double audio_s = static_cast(samples.size()) / 24000.0; + + const double gap_s = realtime_player->enqueue_mono24k(samples); + if (!got_first_queue) { + got_first_queue = true; + t_first_queue = Clock::now(); + } + total_gap += gap_s; + max_gap = std::max(max_gap, gap_s); + + if (!merged.empty()) { + merged.insert(merged.end(), 4800, 0); + } + merged.insert(merged.end(), samples.begin(), samples.end()); + + std::cout.setf(std::ios::fixed); + std::cout.precision(3); + std::cout << "segment " << (i + 1) + << " wait " << wait_s << "s" + << " decode " << decode_s << "s" + << " audio " << audio_s << "s" + << " gap " << gap_s << "s" + << " wav " << jobs[i].wav << std::endl; + } + t_generation_done = Clock::now(); + + if (talker_thread.joinable()) { + talker_thread.join(); + } + profile_event(profile, t0, "talker_done"); + if (talker_rc.load() != 0) { + throw std::runtime_error("persistent clone split talker failed"); + } + realtime_player->finish(); + + write_wav_i16_samples(args.wav, merged); + profile_event(profile, t0, "wav_written"); + cleanup(); + + const double wall = seconds_since(t0, Clock::now()); + const double audio = static_cast(merged.size()) / 24000.0; + const double gen_wall = seconds_since(t0, t_generation_done); + const double warm = got_first_queue ? seconds_since(t_first_queue, t_generation_done) : gen_wall; + std::cout.setf(std::ios::fixed); + std::cout.precision(2); + std::cout << "clone_split_realtime segments " << segments.size() + << " gen_wall " << gen_wall << "s" + << " wall " << wall << "s" + << " audio " << audio << "s" + << " genRTF " << (gen_wall / audio) + << " RTF " << (wall / audio) + << " warmRTF " << (warm / audio) + << " total_gap " << total_gap << "s" + << " max_gap " << max_gap << "s\n"; + std::cout << "wav " << args.wav << "\n"; + return 0; + } catch (...) { + if (talker_thread.joinable()) { + talker_thread.join(); + } + if (realtime_player) { + try { + realtime_player->finish(); + } catch (...) { + } + } + throw; + } +#endif + } + + talker_thread.join(); + profile_event(profile, t0, "talker_done"); + if (talker_rc.load() != 0) { + throw std::runtime_error("persistent clone split talker failed"); + } + + std::vector merged; + for (const auto &job : jobs) { + auto samples = decode_job_samples(job); + if (!merged.empty()) { + merged.insert(merged.end(), 4800, 0); + } + merged.insert(merged.end(), samples.begin(), samples.end()); + profile_event(profile, t0, "decode_segment_done codes=" + job.codes); + } + + write_wav_i16_samples(args.wav, merged); + profile_event(profile, t0, "wav_written"); + cleanup(); + + const double wall = seconds_since(t0, Clock::now()); + const double audio = static_cast(merged.size()) / 24000.0; + std::cout.setf(std::ios::fixed); + std::cout.precision(2); + std::cout << "clone_split_persistent segments " << segments.size() + << " cold_wall " << wall << "s" + << " audio " << audio << "s" + << " coldRTF " << (wall / audio) << "\n"; + std::cout << "wav " << args.wav << "\n"; + return 0; + } catch (...) { + cleanup(); + throw; + } +} + +double seconds_since(Clock::time_point a, Clock::time_point b) { + return std::chrono::duration(b - a).count(); +} + +std::string model_file(const std::string &model_dir, const std::string &subdir, const std::string &name) { + return q3tts_frontend::first_existing({ + q3tts_frontend::path_join(q3tts_frontend::path_join(model_dir, subdir), name), + q3tts_frontend::path_join(model_dir, name), + name, + }); +} + +int run_stdin_segments(Args args, const std::vector &input_lines = {}) { + std::signal(SIGPIPE, SIG_IGN); + + const auto t0 = Clock::now(); + const bool profile = env_int("Q3TTS_PROFILE", 0) != 0; + profile_event(profile, t0, "stdin_segments_start"); + const bool has_ref = !args.ref_bin.empty() || !args.ref_wav.empty(); + const std::string pid = std::to_string(static_cast(getpid())); + std::vector temps; + auto cleanup = [&]() { + for (const auto &p : temps) { + std::remove(p.c_str()); + } + }; + + struct StreamJob { + int index = 0; + std::string text; + std::string synth_text; + std::string prefill; + std::string trailing; + std::string pad; + std::string codes; + std::string done; + std::string wav; + int64_t np = 0; + int64_t nt = 0; + int max_frames = 0; + size_t leadin_trim_samples = 0; + }; + + try { + if (!std::getenv("Q3TTS_SKIP_TCM_CLEAR")) { + const int clear_rc = std::system("spacemit-tcm-smi -c >/dev/null 2>&1"); + (void)clear_rc; + } + profile_event(profile, t0, "tcm_clear_done"); + + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "q3tts_cpp_driver"); + std::string ref_bin = args.ref_bin; + if (ref_bin.empty() && !args.ref_wav.empty()) { + ref_bin = "/tmp/q3tts_stdin_" + pid + "_ref.spk.bin"; + temps.push_back(ref_bin); + auto spk = q3tts_frontend::run_speaker_encoder( + env, q3tts_frontend::speaker_encoder_path(args.model_dir), args.ref_wav, + env_int("Q3TTS_FRONTEND_THREADS", 2)); + q3tts_frontend::write_speaker_bin(ref_bin, spk); + } + profile_event(profile, t0, "ref_ready"); + + q3tts_frontend::FrontendConfig base_fc; + base_fc.model_dir = args.model_dir; + base_fc.ref_bin = ref_bin; + base_fc.language = args.language; + base_fc.talker_gguf = args.talker_gguf; + base_fc.cp_gguf = args.cp_gguf; + base_fc.frontend_threads = env_int("Q3TTS_FRONTEND_THREADS", 2); + base_fc.full_prompt_non_streaming = env_int("Q3TTS_FULL_PROMPT_NON_STREAMING", 0) != 0; + q3tts_frontend::FrontendRuntime frontend(env, base_fc); + profile_event(profile, t0, "frontend_runtime_ready"); + + const int codec_threads = env_int("Q3TTS_CODEC_THREADS", 3); + if (const char *aff = std::getenv("Q3TTS_CODEC_AFFINITY"); aff && *aff) { + setenv("SPACEMIT_EP_INTRA_THREAD_AFFINITY", aff, 1); + if (!std::getenv("SPACEMIT_EP_INTRA_THREAD_NUM")) { + setenv("SPACEMIT_EP_INTRA_THREAD_NUM", std::to_string(codec_threads).c_str(), 1); + } + if (!std::getenv("SPACEMIT_EP_USE_GLOBAL_INTRA_THREAD")) { + setenv("SPACEMIT_EP_USE_GLOBAL_INTRA_THREAD", "0", 1); + } + } + + const bool no_ref_text = !has_ref; + const std::vector buckets = no_ref_text + ? parse_int_list(env_str("Q3TTS_NOREF_CODEC_BUCKETS", ""), std::vector{25}) + : parse_int_list(env_str("Q3TTS_CODEC_BUCKETS", ""), std::vector{50}); + const int chunk = no_ref_text ? env_int("Q3TTS_NOREF_CODEC_CHUNK", 25) + : env_int("Q3TTS_CODEC_CHUNK", 50); + const int first_chunk = no_ref_text ? env_int("Q3TTS_NOREF_CODEC_FIRST_CHUNK", chunk) + : env_int("Q3TTS_CODEC_FIRST_CHUNK", chunk); + const int ctx_limit = env_int("Q3TTS_CODEC_CTX", 25); + +#ifndef Q3TTS_ENABLE_SDK_AUDIO + if (args.play_segments) { + throw std::runtime_error("--play-segments requires Q3TTS_ENABLE_SDK_AUDIO build"); + } +#else + std::unique_ptr realtime_player; + if (args.play_segments) { + realtime_player = std::make_unique( + args.play_rate, args.play_channels, args.play_device, args.play_buffer, + args.play_tail_ms, args.play_drain_ms, args.play_segment_pause_ms); + } +#endif + + q3tts_codec::DecoderPoolConfig codec_cfg; + codec_cfg.model_dir = args.model_dir; + codec_cfg.buckets = buckets; + codec_cfg.intra_threads = codec_threads; + codec_cfg.on_bucket_warm = [&](int b) { + profile_event(profile, t0, "codec_warm bucket=" + std::to_string(b)); + }; + q3tts_codec::DecoderPool codec(env, codec_cfg); + profile_event(profile, t0, "codec_ready"); + + const std::string driver = env_str("TALKER_DRIVER", "./talker_driver"); + const std::string talker_cpuset = env_str("TALKER_CPUSET", "4-7"); + const std::string talker_gguf = model_file(args.model_dir, "gguf", args.talker_gguf); + const std::string cp_gguf = model_file(args.model_dir, "gguf", args.cp_gguf); + const int stdin_max_prefill = env_int("Q3TTS_STDIN_MAX_PREFILL", 128); + const int stdin_max_frames = env_int("Q3TTS_STDIN_MAX_FRAMES", args.frames); + std::stringstream cmd; + cmd << "taskset -c " << shell_quote(talker_cpuset) + << " " << shell_quote(driver) + << " " << shell_quote(talker_gguf) + << " " << shell_quote(cp_gguf) + << " --jobs-stdin " << stdin_max_prefill << " " << stdin_max_frames; + profile_event(profile, t0, "popen_talker_stdin cmd=" + cmd.str()); + FILE *talker = popen(cmd.str().c_str(), "w"); + if (!talker) { + throw std::runtime_error("popen talker stdin failed"); + } + + std::queue decode_queue; + std::mutex mu; + std::condition_variable cv; + bool producer_done = false; + std::exception_ptr decoder_error = nullptr; + std::atomic talker_closed{false}; + std::atomic talker_rc{-1}; + std::vector merged; + double total_gap = 0.0; + double max_gap = 0.0; + bool got_first_queue = false; + bool got_first_input = false; + Clock::time_point t_first_queue = t0; + Clock::time_point t_first_input = t0; + Clock::time_point t_generation_done = t0; + int decoded_segments = 0; + int skipped_segments = 0; + int truncated_segments = 0; + int written_segments = 0; + const auto ref_decode_prefix = load_ref_decode_prefix(ref_bin); + + auto decode_codes = [&](const std::string &path, size_t *frame_count) -> std::vector { + auto data = read_all(path); + const size_t frame_bytes = sizeof(int32_t) * 16; + if (data.size() % frame_bytes != 0) { + throw std::runtime_error("bad code file size: " + path); + } + if (frame_count) { + *frame_count = data.size() / frame_bytes; + } + std::vector> frames(data.size() / frame_bytes); + if (!frames.empty()) { + std::memcpy(frames.data(), data.data(), data.size()); + } + return decode_with_reference_prefix(codec, frames, ref_decode_prefix, buckets, first_chunk, chunk, ctx_limit); + }; + + auto decode_codes_streaming = [&](const StreamJob &job, size_t *frame_count) -> std::vector { + const size_t frame_bytes = sizeof(int32_t) * 16; + std::vector> frames; + frames.insert(frames.end(), ref_decode_prefix.begin(), ref_decode_prefix.end()); + std::vector>> chunks; + size_t generated = 0; + int done = 0; + + auto append_available_frames = [&]() { + if (!exists(job.codes)) { + return; + } + auto data = read_all(job.codes); + const size_t available = data.size() / frame_bytes; + if (available <= generated) { + return; + } + const size_t old = generated; + generated = available; + frames.resize(ref_decode_prefix.size() + generated); + std::memcpy(frames.data() + ref_decode_prefix.size() + old, + data.data() + old * frame_bytes, + (generated - old) * frame_bytes); + }; + + auto submit = [&](int start, int n) -> int { + const int new_count = n - start; + auto it = std::find_if(buckets.begin(), buckets.end(), [&](int b) { return b >= new_count; }); + if (it == buckets.end()) { + throw std::runtime_error("no codec bucket for stdin chunk"); + } + const int b = *it; + const int ctx = std::min({start, ctx_limit, b - new_count}); + std::vector> codes(frames.begin() + (start - ctx), frames.begin() + n); + chunks.emplace_back(start, codec.decode(b, codes, ctx)); + return n; + }; + + while (true) { + append_available_frames(); + while (static_cast(frames.size()) - done >= (done == 0 ? first_chunk : chunk)) { + const int next = done == 0 ? first_chunk : chunk; + done = submit(done, done + next); + } + if (exists(job.done)) { + append_available_frames(); + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + if (static_cast(frames.size()) > done) { + submit(done, static_cast(frames.size())); + } + if (frame_count) { + *frame_count = generated; + } + + std::sort(chunks.begin(), chunks.end(), [](const auto &a, const auto &b) { + return a.first < b.first; + }); + size_t total_samples = 0; + for (const auto &chunk_wav : chunks) { + total_samples += chunk_wav.second.size(); + } + std::vector wav; + wav.reserve(total_samples); + for (auto &chunk_wav : chunks) { + wav.insert(wav.end(), chunk_wav.second.begin(), chunk_wav.second.end()); + } + + const size_t ref_decode_samples = reference_audio_cut_samples( + ref_decode_prefix.size(), generated, wav.size()); + if (ref_decode_samples > 0) { + if (wav.size() <= ref_decode_samples) { + wav.clear(); + } else { + wav.erase(wav.begin(), wav.begin() + static_cast(ref_decode_samples)); + } + } + return wav; + }; + + auto decode_job_samples = [&](const StreamJob &job) -> std::vector { + size_t code_frames = 0; + auto wav_f32 = env_int("Q3TTS_STDIN_STREAM_DECODE", 1) != 0 + ? decode_codes_streaming(job, &code_frames) + : decode_codes(job.codes, &code_frames); + if (job.max_frames > 0 && code_frames >= static_cast(job.max_frames)) { + ++truncated_segments; + std::cout << "stream_segment_truncated " << job.index + << " frames " << code_frames + << " max " << job.max_frames + << " text " << job.text << std::endl; + } + std::vector samples = f32_to_pcm16(std::move(wav_f32), false); + trim_leading_samples(samples, job.leadin_trim_samples); + trim_trailing_silence(samples); + if (samples.empty() || rms_i16(samples, 0, samples.size()) <= 0.003) { + std::cout << "stream_segment_skip " << job.index + << " reason silent_audio text " << job.text << std::endl; + return {}; + } + postprocess_pcm16(samples); + return samples; + }; + + auto cleanup_stream_job_inputs = [](const StreamJob &job) { + std::remove(job.prefill.c_str()); + std::remove(job.trailing.c_str()); + std::remove(job.pad.c_str()); + std::remove(job.codes.c_str()); + std::remove(job.done.c_str()); + }; + + std::thread decoder([&]() { + try { + while (true) { + StreamJob job; + { + std::unique_lock lk(mu); + cv.wait(lk, [&]() { return producer_done || !decode_queue.empty(); }); + if (decode_queue.empty()) { + return; + } + job = std::move(decode_queue.front()); + decode_queue.pop(); + } + + const auto decode0 = Clock::now(); + auto samples = decode_job_samples(job); + cleanup_stream_job_inputs(job); + if (samples.empty()) { + ++skipped_segments; + t_generation_done = Clock::now(); + continue; + } + write_wav_i16_samples(job.wav, samples); + const double decode_s = seconds_since(decode0, Clock::now()); + const double wait_s = 0.0; + const double audio_s = static_cast(samples.size()) / 24000.0; + + double gap_s = 0.0; +#ifdef Q3TTS_ENABLE_SDK_AUDIO + if (realtime_player) { + gap_s = realtime_player->enqueue_mono24k(samples); + if (!got_first_queue) { + got_first_queue = true; + t_first_queue = Clock::now(); + } + } +#endif + total_gap += gap_s; + max_gap = std::max(max_gap, gap_s); + + if (!merged.empty()) { + merged.insert(merged.end(), 4800, 0); + } + merged.insert(merged.end(), samples.begin(), samples.end()); + ++decoded_segments; + t_generation_done = Clock::now(); + + std::cout.setf(std::ios::fixed); + std::cout.precision(3); + std::cout << "stream_segment " << job.index + << " wait " << wait_s << "s" + << " decode " << decode_s << "s" + << " audio " << audio_s << "s" + << " gap " << gap_s << "s" + << " wav " << job.wav << std::endl; + } + } catch (...) { + decoder_error = std::current_exception(); + } + }); + + auto submit_line = [&](std::string line) -> std::pair { + line = trim(line); + if (line.empty()) { + return {0, 0}; + } + const int first_index = written_segments + 1; + if (!args.hotwords.empty()) { + line = apply_hotwords(line, args.hotwords); + } + Args line_args = args; + line_args.text = line; + std::vector segments; + if (args.no_clone_split) { + segments = split_stdin_text(line, has_full_reference_prompt(ref_bin), line_args.language); + } else { + segments = prepare_clone_segments(line_args); + } + for (const auto &segment : segments) { + StreamJob job; + job.index = ++written_segments; + job.text = segment; + if (has_ref && needs_clone_leadin(job.text, args.clone_leadin)) { + job.synth_text = args.clone_leadin + job.text; + job.leadin_trim_samples = clone_leadin_trim_samples(args.clone_leadin); + } else { + job.synth_text = job.text; + } + const std::string base = "/tmp/q3tts_stdin_" + pid + "_" + std::to_string(job.index); + job.prefill = base + "_prefill.bin"; + job.trailing = base + "_trailing.bin"; + job.pad = base + "_pad.bin"; + job.codes = base + "_codes.bin"; + job.done = base + "_done"; + job.wav = base + ".wav"; + job.max_frames = stdin_max_frames; + temps.push_back(job.prefill); + temps.push_back(job.trailing); + temps.push_back(job.pad); + temps.push_back(job.codes); + temps.push_back(job.done); + temps.push_back(job.wav); + + q3tts_frontend::FrontendConfig fc = base_fc; + fc.text = job.synth_text; + auto input = frontend.build(fc); + write_all(job.prefill, + reinterpret_cast(input.prefill.data()), + input.prefill.size() * sizeof(float)); + write_all(job.trailing, + reinterpret_cast(input.trailing.data()), + input.trailing.size() * sizeof(float)); + write_all(job.pad, + reinterpret_cast(input.pad.data()), + input.pad.size() * sizeof(float)); + job.np = input.n_prefill; + job.nt = input.n_trailing; + if (job.np > stdin_max_prefill) { + throw std::runtime_error("stdin segment prefill exceeds Q3TTS_STDIN_MAX_PREFILL"); + } + + if (std::fprintf(talker, "%s %lld %s %lld %s %d %s %s\n", + job.prefill.c_str(), static_cast(job.np), + job.trailing.c_str(), static_cast(job.nt), + job.pad.c_str(), job.max_frames, + job.codes.c_str(), job.done.c_str()) < 0 || + std::fflush(talker) != 0) { + throw std::runtime_error("write talker stdin job failed"); + } + if (!got_first_input) { + got_first_input = true; + t_first_input = Clock::now(); + } + { + std::lock_guard lk(mu); + decode_queue.push(job); + } + cv.notify_one(); + std::cout << "stream_text " << job.index << " " << segment << std::endl; + } + return {first_index, written_segments}; + }; + + if (!input_lines.empty()) { + for (const auto &line : input_lines) { + const auto range = submit_line(line); + std::cout << "stream_request " << range.first << " " << range.second << std::endl; + } + } else { + std::string line; + while (std::getline(std::cin, line)) { + const auto range = submit_line(line); + std::cout << "stream_request " << range.first << " " << range.second << std::endl; + } + } + + { + std::lock_guard lk(mu); + producer_done = true; + } + cv.notify_all(); + + const int rc = pclose(talker); + talker_rc.store(rc); + talker_closed.store(true); + cv.notify_all(); + + if (decoder.joinable()) { + decoder.join(); + } + if (decoder_error) { + std::rethrow_exception(decoder_error); + } + if (rc != 0) { + throw std::runtime_error("stdin talker exited with status " + std::to_string(rc)); + } + if (truncated_segments > 0 && env_int("Q3TTS_ALLOW_TRUNCATED", 0) == 0) { + throw std::runtime_error( + "stdin segment reached max frames without EOS; split text or increase --frames " + "(set Q3TTS_ALLOW_TRUNCATED=1 to keep partial audio)"); + } + +#ifdef Q3TTS_ENABLE_SDK_AUDIO + if (realtime_player) { + realtime_player->finish(); + } +#endif + write_wav_i16_samples(args.wav, merged); + cleanup(); + + const double wall = seconds_since(t0, Clock::now()); + const double input_wall = got_first_input ? seconds_since(t_first_input, t_generation_done) : wall; + const double audio = static_cast(merged.size()) / 24000.0; + const double warm = got_first_queue ? seconds_since(t_first_queue, t_generation_done) : input_wall; + std::cout.setf(std::ios::fixed); + std::cout.precision(2); + std::cout << "stdin_realtime segments " << decoded_segments + << " skipped " << skipped_segments + << " gen_wall " << input_wall << "s" + << " wall " << wall << "s" + << " audio " << audio << "s" + << " genRTF " << (audio > 0 ? input_wall / audio : 0.0) + << " RTF " << (audio > 0 ? wall / audio : 0.0) + << " warmRTF " << (audio > 0 ? warm / audio : 0.0) + << " total_gap " << total_gap << "s" + << " max_gap " << max_gap << "s\n"; + std::cout << "wav " << args.wav << "\n"; + return 0; + } catch (...) { + cleanup(); + throw; + } +} + +} // namespace + +namespace qwen3_tts { +int run_cli(int argc, char **argv) { + try { + Args args = parse_args(argc, argv); + if (!args.stdin_segments && !args.text.empty() && !args.hotwords.empty()) { + args.text = apply_hotwords(args.text, args.hotwords); + } + if (args.dump_segments) { + std::vector segments = prepare_clone_segments(args); + for (size_t i = 0; i < segments.size(); ++i) { + std::cout << (i + 1) << "\t" << utf8_len(segments[i]) << "\t" << segments[i] << "\n"; + } + return 0; + } + maybe_select_full_prompt_talker(args); + set_default_runtime_env(args.talker_gguf); + const bool has_ref = has_clone_reference(args); + if (has_ref) { + set_env_override( + "Q3TTS_TALKER_REPETITION_PENALTY", + env_str("Q3TTS_CLONE_TALKER_REPETITION_PENALTY", "1.15")); + } + if (args.stdin_segments) { + return run_stdin_segments(args); + } + if (has_ref && !args.text.empty() && !args.dump_ids && !args.frontend_only) { + std::vector segments = prepare_clone_segments(args); + if (segments.size() == 1) { + args.text = segments[0]; + args.no_clone_split = true; + } else if (!args.no_clone_split || env_int("Q3TTS_CLONE_UNSAFE_NOSPLIT", 0) == 0) { + if (args.no_clone_split) { + std::cerr << "clone_force_split segments " << segments.size() + << " text_chars " << utf8_len(args.text) << "\n"; + } + if (has_full_reference_prompt(args.ref_bin)) { + Args stream_args = args; + stream_args.no_clone_split = true; + return run_stdin_segments(stream_args, segments); + } + return run_clone_split(args, argv[0], segments); + } + } + const bool profile = env_int("Q3TTS_PROFILE", 0) != 0; + const auto t_profile0 = Clock::now(); + profile_event(profile, t_profile0, "driver_start"); + + if (!std::getenv("Q3TTS_SKIP_TCM_CLEAR")) { + const int clear_rc = std::system("spacemit-tcm-smi -c >/dev/null 2>&1"); + (void)clear_rc; + } + profile_event(profile, t_profile0, "tcm_clear_done"); + + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "q3tts_cpp_driver"); + + const std::string tmp_tag = std::to_string(static_cast(getpid())); + const std::string prefill_bin = "/tmp/q3tts_cpp_" + tmp_tag + "_prefill.bin"; + const std::string trailing_bin = "/tmp/q3tts_cpp_" + tmp_tag + "_trailing.bin"; + const std::string pad_bin = "/tmp/q3tts_cpp_" + tmp_tag + "_pad.bin"; + int64_t np = 0; + int64_t nt = 0; + if (!args.text.empty()) { + if (args.dump_ids) { + auto ids = q3tts_frontend::tokenize_prompt(args.model_dir, args.text); + for (size_t i = 0; i < ids.size(); ++i) { + if (i != 0) { + std::cout << " "; + } + std::cout << ids[i]; + } + std::cout << "\n"; + return 0; + } + q3tts_frontend::FrontendConfig fc; + fc.model_dir = args.model_dir; + fc.text = args.text; + fc.ref_wav = args.ref_wav; + fc.ref_bin = args.ref_bin; + fc.language = args.language; + fc.talker_gguf = args.talker_gguf; + fc.cp_gguf = args.cp_gguf; + fc.frontend_threads = env_int("Q3TTS_FRONTEND_THREADS", 2); + fc.full_prompt_non_streaming = env_int("Q3TTS_FULL_PROMPT_NON_STREAMING", 0) != 0; + auto input = q3tts_frontend::build(env, fc); + write_all(prefill_bin, + reinterpret_cast(input.prefill.data()), + input.prefill.size() * sizeof(float)); + write_all(trailing_bin, + reinterpret_cast(input.trailing.data()), + input.trailing.size() * sizeof(float)); + write_all(pad_bin, + reinterpret_cast(input.pad.data()), + input.pad.size() * sizeof(float)); + np = input.n_prefill; + nt = input.n_trailing; + profile_event(profile, t_profile0, + "frontend_materialized np=" + std::to_string(np) + " nt=" + std::to_string(nt)); + } else { + if (args.npz.empty()) { + args.npz = "e2e_spk.npz"; + } + auto npz = load_npz_stored(args.npz); + auto prefill = parse_npy(npz.at("prefill.npy")); + auto trailing = parse_npy(npz.at("trailing.npy")); + auto pad = parse_npy(npz.at("pad.npy")); + if (prefill.descr != " buckets = no_ref_text + ? parse_int_list(env_str("Q3TTS_NOREF_CODEC_BUCKETS", ""), std::vector{25}) + : parse_int_list(env_str("Q3TTS_CODEC_BUCKETS", ""), std::vector{50}); + const int chunk = no_ref_text ? env_int("Q3TTS_NOREF_CODEC_CHUNK", 25) + : env_int("Q3TTS_CODEC_CHUNK", 50); + const int first_chunk = no_ref_text ? env_int("Q3TTS_NOREF_CODEC_FIRST_CHUNK", chunk) + : env_int("Q3TTS_CODEC_FIRST_CHUNK", chunk); + const int ctx_limit = env_int("Q3TTS_CODEC_CTX", 25); + const int codec_nice = env_int("Q3TTS_CODEC_NICE", 0); + if (codec_nice > 0) { + if (setpriority(PRIO_PROCESS, 0, codec_nice) != 0) { + std::perror("setpriority codec"); + } + } + + q3tts_codec::DecoderPoolConfig codec_cfg; + codec_cfg.model_dir = args.model_dir; + codec_cfg.buckets = buckets; + codec_cfg.intra_threads = codec_threads; + codec_cfg.on_bucket_warm = [&](int b) { + profile_event(profile, t_profile0, "codec_warm bucket=" + std::to_string(b)); + }; + q3tts_codec::DecoderPool codec(env, codec_cfg); + if (codec_nice > 0 && setpriority(PRIO_PROCESS, 0, 0) != 0) { + std::perror("setpriority main"); + } + + const auto ref_decode_prefix = load_ref_decode_prefix(args.ref_bin); + + std::vector> frames; + frames.insert(frames.end(), ref_decode_prefix.begin(), ref_decode_prefix.end()); + std::vector>> chunks; + std::mutex mu; + std::condition_variable cv; + std::queue>, int, int, int>> jobs; + bool stop = false; + + std::thread worker([&]() { + while (true) { + std::tuple>, int, int, int> job; + { + std::unique_lock lk(mu); + cv.wait(lk, [&] { return stop || !jobs.empty(); }); + if (jobs.empty()) { + return; + } + job = std::move(jobs.front()); + jobs.pop(); + } + auto &[codes, off, ctx, bucket] = job; + const auto t_job0 = Clock::now(); + auto wav = codec.decode(bucket, codes, ctx); + const auto t_job1 = Clock::now(); + if (profile) { + std::stringstream msg; + msg << "codec_done off=" << off + << " input=" << codes.size() + << " new=" << (static_cast(codes.size()) - ctx) + << " ctx=" << ctx + << " bucket=" << bucket + << " run_ms=" << (seconds_since(t_job0, t_job1) * 1000.0); + profile_event(true, t_profile0, msg.str()); + } + { + std::lock_guard lk(mu); + chunks.emplace_back(off, std::move(wav)); + } + } + }); + + auto submit = [&](int done, int n) -> int { + const int new_count = n - done; + auto it = std::find_if(buckets.begin(), buckets.end(), [&](int b) { return b >= new_count; }); + if (it == buckets.end()) { + throw std::runtime_error("no codec bucket for chunk"); + } + const int b = *it; + const int ctx = std::min({done, ctx_limit, b - new_count}); + std::vector> codes(frames.begin() + (done - ctx), frames.begin() + n); + { + std::lock_guard lk(mu); + jobs.emplace(std::move(codes), done, ctx, b); + } + if (profile) { + std::stringstream msg; + msg << "codec_submit off=" << done + << " input=" << (new_count + ctx) + << " new=" << new_count + << " ctx=" << ctx + << " bucket=" << b; + profile_event(true, t_profile0, msg.str()); + } + cv.notify_one(); + return n; + }; + + const std::string driver = env_str("TALKER_DRIVER", "./talker_driver"); + const std::string talker_cpuset = env_str("TALKER_CPUSET", "4-7"); + const std::string talker_gguf = model_file(args.model_dir, "gguf", args.talker_gguf); + const std::string cp_gguf = model_file(args.model_dir, "gguf", args.cp_gguf); + std::stringstream cmd; + cmd << "taskset -c " << talker_cpuset << " " << driver + << " " << talker_gguf + << " " << cp_gguf + << " aux " << prefill_bin << " " << np + << " " << trailing_bin << " " << nt + << " " << pad_bin << " " << args.frames; + + profile_event(profile, t_profile0, "popen_talker cmd=" + cmd.str()); + FILE *pipe = popen(cmd.str().c_str(), "r"); + if (!pipe) { + throw std::runtime_error("popen talker failed"); + } + profile_event(profile, t_profile0, "popen_done"); + + auto t0 = Clock::now(); + Clock::time_point t_first; + bool got_first = false; + int done = 0; + int generated_frames = 0; + while (true) { + std::array f {}; + const size_t got = fread(f.data(), sizeof(int32_t), 16, pipe); + if (got < 16) { + break; + } + if (!got_first) { + t_first = Clock::now(); + got_first = true; + profile_event(profile, t_profile0, "first_frame"); + } + frames.push_back(f); + ++generated_frames; + const int next_chunk = done == 0 ? first_chunk : chunk; + if (static_cast(frames.size()) - done >= next_chunk) { + done = submit(done, done + next_chunk); + } + } + profile_event(profile, t_profile0, "talker_stdout_eof frames=" + std::to_string(generated_frames)); + if (static_cast(frames.size()) > done) { + submit(done, static_cast(frames.size())); + } + if (const char *dump_codes = std::getenv("Q3TTS_DUMP_CODES"); dump_codes && *dump_codes) { + const auto *dump_begin = frames.data() + static_cast(ref_decode_prefix.size()); + write_all(dump_codes, + reinterpret_cast(dump_begin), + static_cast(generated_frames) * sizeof(frames[0])); + } + const auto t_pclose0 = Clock::now(); + const int rc = pclose(pipe); + const auto t_pclose1 = Clock::now(); + if (profile) { + std::stringstream msg; + msg << "pclose_done rc=" << rc + << " wait_ms=" << (seconds_since(t_pclose0, t_pclose1) * 1000.0); + profile_event(true, t_profile0, msg.str()); + } + { + std::lock_guard lk(mu); + stop = true; + } + cv.notify_one(); + worker.join(); + profile_event(profile, t_profile0, "codec_worker_join chunks=" + std::to_string(chunks.size())); + if (rc != 0) { + std::remove(prefill_bin.c_str()); + std::remove(trailing_bin.c_str()); + std::remove(pad_bin.c_str()); + throw std::runtime_error("talker exited with status " + std::to_string(rc)); + } + + auto t1 = Clock::now(); + std::sort(chunks.begin(), chunks.end(), [](const auto &a, const auto &b) { + return a.first < b.first; + }); + std::vector wav; + size_t total_samples = 0; + for (const auto &c : chunks) { + total_samples += c.second.size(); + } + wav.reserve(total_samples); + for (auto &c : chunks) { + wav.insert(wav.end(), c.second.begin(), c.second.end()); + } + const size_t ref_decode_samples = reference_audio_cut_samples( + ref_decode_prefix.size(), static_cast(generated_frames), wav.size()); + if (ref_decode_samples > 0) { + if (wav.size() <= ref_decode_samples) { + wav.clear(); + } else { + wav.erase(wav.begin(), wav.begin() + static_cast(ref_decode_samples)); + } + } + const double wall = seconds_since(t0, t1); + const double audio = static_cast(wav.size()) / 24000.0; + const double warm = got_first ? seconds_since(t_first, t1) : wall; + if (args.frames > 0 && generated_frames >= args.frames && + env_int("Q3TTS_ALLOW_TRUNCATED", 0) == 0) { + std::remove(prefill_bin.c_str()); + std::remove(trailing_bin.c_str()); + std::remove(pad_bin.c_str()); + throw std::runtime_error( + "talker reached max frames without EOS; split text or increase --frames " + "(set Q3TTS_ALLOW_TRUNCATED=1 to keep partial audio)"); + } + if (profile) { + std::stringstream msg; + msg << "wav_assembled samples=" << wav.size() + << " chunks=" << chunks.size() + << " wall_ms=" << (wall * 1000.0) + << " warm_ms=" << (warm * 1000.0); + profile_event(true, t_profile0, msg.str()); + } + std::cout.setf(std::ios::fixed); + std::cout.precision(2); + std::cout << "frames " << generated_frames + << " E2E wall " << wall << "s" + << " audio " << audio << "s" + << " RTF " << (wall / audio) + << " warmRTF " << (warm / audio) << "\n"; + write_wav_i16(args.wav, wav); + std::remove(prefill_bin.c_str()); + std::remove(trailing_bin.c_str()); + std::remove(pad_bin.c_str()); + std::cout << "wav " << args.wav << "\n"; + } catch (const std::exception &e) { + std::cerr << "error: " << e.what() << "\n"; + return 1; + } + return 0; +} + +} // namespace qwen3_tts diff --git a/tools/speech/backends/qwen3_tts/src/talker_driver.c b/tools/speech/backends/qwen3_tts/src/talker_driver.c new file mode 100644 index 000000000000..9f11443ee420 --- /dev/null +++ b/tools/speech/backends/qwen3_tts/src/talker_driver.c @@ -0,0 +1,729 @@ +// Qwen3-TTS frame loop in C: talker + code-predictor in one process (two llama contexts), +// cp lm-heads as in-process GEMV. Streams 16-code frames (int32) to stdout for the codec. +// usage: talker_driver TALKER.gguf CP.gguf _ PREFILL.bin NPREFILL TRAILING.bin NTRAIL PAD.bin MAXFRAMES +// talker_driver TALKER.gguf CP.gguf --jobs JOBS.txt +// talker_driver TALKER.gguf CP.gguf --jobs-stdin MAX_PREFILL MAXFRAMES +#include "llama.h" +#include "ggml-cpu.h" +#include "gguf.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#include "heads_pool.h" + +#define HID 1024 +#define NCB 16 +#define CPV 2048 +#define EOS 2150 + +static double now_s(void) { + struct timespec ts; clock_gettime(CLOCK_MONOTONIC, &ts); + return ts.tv_sec + ts.tv_nsec * 1e-9; +} + +static void log_to_stderr(enum ggml_log_level level, const char *text, void *user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); +} + +static int saved_stdout = -1; +static int sticky_qwen3_embed_only = 0; + +static void default_llama_env(void) { + if (getenv("LLAMA_GRAPH_REUSE_2WAY") == NULL) { + setenv("LLAMA_GRAPH_REUSE_2WAY", "1", 0); + } +} + +static void stdout_to_stderr(void) { + fflush(stdout); + saved_stdout = dup(STDOUT_FILENO); + if (saved_stdout >= 0) { + dup2(STDERR_FILENO, STDOUT_FILENO); + } +} + +static void restore_stdout(void) { + fflush(stdout); + if (saved_stdout >= 0) { + dup2(saved_stdout, STDOUT_FILENO); + close(saved_stdout); + saved_stdout = -1; + } +} + +static int cp_decode(struct llama_context *ctx, struct llama_batch batch) { + if (sticky_qwen3_embed_only) { + return llama_decode(ctx, batch); + } + if (getenv("Q3TTS_CP_EMBED_ONLY") == NULL) { + return llama_decode(ctx, batch); + } + setenv("LLAMA_QWEN3_EMBED_ONLY", "1", 1); + int ret = llama_decode(ctx, batch); + unsetenv("LLAMA_QWEN3_EMBED_ONLY"); + return ret; +} + +static void set_ctx_threads_if_changed(struct llama_context *ctx, int *cur_threads, int *cur_threads_batch, + int threads, int threads_batch) { + if (threads <= 0 || threads_batch <= 0) { + return; + } + if (*cur_threads == threads && *cur_threads_batch == threads_batch) { + return; + } + llama_set_n_threads(ctx, threads, threads_batch); + *cur_threads = threads; + *cur_threads_batch = threads_batch; +} + +static int talker_decode(struct llama_context *ctx, struct llama_batch batch, int embed_only) { + if (sticky_qwen3_embed_only) { + return llama_decode(ctx, batch); + } + if (!embed_only) { + return llama_decode(ctx, batch); + } + setenv("LLAMA_QWEN3_EMBED_ONLY", "1", 1); + int ret = llama_decode(ctx, batch); + unsetenv("LLAMA_QWEN3_EMBED_ONLY"); + return ret; +} + +static float *load_bin(const char *path, long n) { + FILE *f = fopen(path, "rb"); + if (!f) { fprintf(stderr, "open %s fail\n", path); exit(1); } + float *p = malloc(n * 4); + if (fread(p, 4, n, f) != (size_t)n) { fprintf(stderr, "read %s fail\n", path); exit(1); } + fclose(f); + return p; +} +static uint16_t *load_bin16(const char *path, long n) { + FILE *f = fopen(path, "rb"); + if (!f) { fprintf(stderr, "open %s fail\n", path); exit(1); } + uint16_t *p = malloc(n * 2); + if (fread(p, 2, n, f) != (size_t)n) { fprintf(stderr, "read %s fail\n", path); exit(1); } + fclose(f); + return p; +} + +static void *load_gguf_tensor(const char *path, const char *name, enum ggml_type type, size_t bytes) { + struct gguf_init_params params = { + .no_alloc = true, + .ctx = NULL, + }; + struct gguf_context *ctx = gguf_init_from_file(path, params); + if (!ctx) { fprintf(stderr, "open gguf %s fail\n", path); exit(1); } + + int64_t tid = gguf_find_tensor(ctx, name); + if (tid < 0) { fprintf(stderr, "tensor %s not found in %s\n", name, path); exit(1); } + if (gguf_get_tensor_type(ctx, tid) != type) { + fprintf(stderr, "tensor %s type mismatch in %s\n", name, path); + exit(1); + } + if (gguf_get_tensor_size(ctx, tid) != bytes) { + fprintf(stderr, "tensor %s size mismatch in %s\n", name, path); + exit(1); + } + + FILE *f = fopen(path, "rb"); + if (!f) { fprintf(stderr, "open %s fail\n", path); exit(1); } + if (fseek(f, (long)(gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, tid)), SEEK_SET) != 0) { + fprintf(stderr, "seek %s:%s fail\n", path, name); + exit(1); + } + void *p = malloc(bytes); + if (!p) { fprintf(stderr, "malloc %zu fail\n", bytes); exit(1); } + if (fread(p, 1, bytes, f) != bytes) { + fprintf(stderr, "read %s:%s fail\n", path, name); + exit(1); + } + fclose(f); + gguf_free(ctx); + return p; +} + +static int env_int(const char *name, int fallback) { + const char *v = getenv(name); + if (!v || !*v) { + return fallback; + } + return atoi(v); +} + +static float env_float(const char *name, float fallback) { + const char *v = getenv(name); + if (!v || !*v) { + return fallback; + } + return (float)atof(v); +} + +static int env_bool_default(const char *name, int fallback) { + const char *v = getenv(name); + if (!v || !*v) { + return fallback; + } + if (v[0] == '0' || v[0] == 'f' || v[0] == 'F' || v[0] == 'n' || v[0] == 'N') { + return 0; + } + return 1; +} + +static int max_i(int a, int b) { + return a > b ? a : b; +} + +static struct ggml_threadpool *create_llama_threadpool(int n_threads) { + if (env_int("Q3TTS_LLAMA_THREADPOOL", 1) == 0 || n_threads <= 1) { + return NULL; + } + + struct ggml_threadpool_params tp = ggml_threadpool_params_default(n_threads); + int poll = env_int("Q3TTS_LLAMA_THREADPOOL_POLL", -1); + if (poll >= 0) { + if (poll > 100) { + poll = 100; + } + tp.poll = (uint32_t)poll; + } + + struct ggml_threadpool *threadpool = ggml_threadpool_new(&tp); + if (!threadpool) { + fprintf(stderr, "llama_threadpool create failed threads %d\n", n_threads); + return NULL; + } + fprintf(stderr, "llama_threadpool threads %d poll %u\n", n_threads, tp.poll); + return threadpool; +} + +static struct llama_context *mk_ctx(struct llama_model *m, int n_ctx, int n_batch, int n_ubatch, int n_threads, int n_threads_batch, int n_seq_max) { + struct llama_context_params cp = llama_context_default_params(); + cp.n_ctx = n_ctx; cp.n_batch = n_batch; cp.n_ubatch = n_ubatch; + if (n_seq_max > 0) { cp.n_seq_max = n_seq_max; } + cp.embeddings = true; cp.pooling_type = LLAMA_POOLING_TYPE_NONE; +#ifdef Q3TTS_LLAMA_BOOL_FLASH_ATTN + cp.flash_attn = env_bool_default("Q3TTS_FLASH_ATTN", 1) != 0; +#else + cp.flash_attn_type = env_bool_default("Q3TTS_FLASH_ATTN", 1) != 0 ? + LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED; +#endif + if (n_threads > 0) { cp.n_threads = n_threads; } + if (n_threads_batch > 0) { cp.n_threads_batch = n_threads_batch; } + else if (n_threads > 0) { cp.n_threads_batch = n_threads; } + return llama_init_from_model(m, cp); +} + +struct job_spec { + char prefill[512]; + int n_prefill; + char trailing[512]; + int n_trail; + char pad[512]; + int max_frames; + char out[512]; + char done[512]; +}; + +static int read_job_line(FILE *f, struct job_spec *j) { + memset(j, 0, sizeof(*j)); + int got = fscanf(f, "%511s %d %511s %d %511s %d %511s %511s", + j->prefill, &j->n_prefill, j->trailing, &j->n_trail, + j->pad, &j->max_frames, j->out, j->done); + if (got == EOF) { + return 0; + } + if (got != 7 && got != 8) { + return -1; + } + return 1; +} + +static int read_jobs(const char *path, struct job_spec **jobs_out, int *n_jobs_out) { + FILE *f = fopen(path, "r"); + if (!f) { + fprintf(stderr, "open jobs %s fail\n", path); + return 1; + } + int cap = 8; + int n = 0; + struct job_spec *jobs = calloc((size_t)cap, sizeof(*jobs)); + if (!jobs) { + fclose(f); + return 1; + } + while (1) { + struct job_spec j; + int got = read_job_line(f, &j); + if (got == 0) { + break; + } + if (got < 0) { + fprintf(stderr, "bad jobs line in %s\n", path); + free(jobs); + fclose(f); + return 1; + } + if (n == cap) { + cap *= 2; + struct job_spec *next = realloc(jobs, (size_t)cap * sizeof(*jobs)); + if (!next) { + free(jobs); + fclose(f); + return 1; + } + jobs = next; + } + jobs[n++] = j; + } + fclose(f); + *jobs_out = jobs; + *n_jobs_out = n; + return n == 0; +} + +static int run_job(struct llama_context *ct, struct llama_context *cc, + struct llama_batch *bt, struct llama_batch *bc, + float *cemb, float **pemb, uint16_t **heads, uint16_t *talker_head, + const struct job_spec *job, FILE *out) { + float *prefill = load_bin(job->prefill, (long)job->n_prefill * HID); + float *trail = load_bin(job->trailing, (long)job->n_trail * HID); + float *pad = load_bin(job->pad, HID); + + llama_memory_clear(llama_get_memory(ct), true); + llama_memory_clear(llama_get_memory(cc), true); + + double t0 = now_s(); + memcpy(bt->embd, prefill, (size_t)job->n_prefill * HID * 4); + for (int i = 0; i < job->n_prefill; i++) { + bt->pos[i] = i; + bt->n_seq_id[i] = 1; + bt->seq_id[i][0] = 0; + bt->logits[i] = i == job->n_prefill - 1; + } + bt->n_tokens = job->n_prefill; + if (talker_decode(ct, *bt, talker_head != NULL)) { + return 1; + } + fprintf(stderr, "prefill %d tok %.3fs\n", job->n_prefill, now_s() - t0); + + double t_talker = 0, t_cp = 0, t_heads = 0; + double t_cp_clear = 0, t_cp_init_decode = 0, t_cp_step_decode = 0; + const int cp_seq_per_frame = getenv("Q3TTS_CP_SEQ_PER_FRAME") != NULL; + const int head_argmax = getenv("Q3TTS_HEAD_ARGMAX") != NULL; + const int cp_meta_clear = getenv("Q3TTS_CP_META_CLEAR") != NULL; + const int cp_seq_rm = getenv("Q3TTS_CP_SEQ_RM") != NULL; + const int cp_profile = env_int("Q3TTS_CP_PROFILE", 0) != 0; + int cp_active_heads = env_int("Q3TTS_CP_ACTIVE_HEADS", 15); + if (cp_active_heads < 1) { + cp_active_heads = 1; + } + if (cp_active_heads > 15) { + cp_active_heads = 15; + } + int cp_fill_code = env_int("Q3TTS_CP_FILL_CODE", 0); + const int cp_switch_threads = + getenv("Q3TTS_CP_INIT_THREADS") || getenv("Q3TTS_CP_INIT_THREADS_BATCH") || + getenv("Q3TTS_CP_STEP_THREADS") || getenv("Q3TTS_CP_STEP_THREADS_BATCH"); + int cp_cur_threads = llama_n_threads(cc); + int cp_cur_threads_batch = llama_n_threads_batch(cc); + const int cp_init_threads = env_int("Q3TTS_CP_INIT_THREADS", cp_cur_threads); + const int cp_init_threads_batch = env_int("Q3TTS_CP_INIT_THREADS_BATCH", cp_cur_threads_batch); + const int cp_step_threads = env_int("Q3TTS_CP_STEP_THREADS", cp_init_threads); + const int cp_step_threads_batch = env_int("Q3TTS_CP_STEP_THREADS_BATCH", cp_init_threads_batch); + int frames = 0; + int codes[NCB]; + unsigned char seen_c0[CPV]; + memset(seen_c0, 0, sizeof(seen_c0)); + const float repetition_penalty = env_float("Q3TTS_TALKER_REPETITION_PENALTY", 1.0f); + const int use_repetition_penalty = repetition_penalty > 1.0f; + for (int f = 0; f < job->max_frames; f++) { + for (int i = 0; i < NCB; i++) { + codes[i] = cp_fill_code; + } + const float *hid = llama_get_embeddings_ith(ct, bt->n_tokens - 1); + const float *lg = talker_head ? NULL : llama_get_logits_ith(ct, bt->n_tokens - 1); + int c0 = 0; + float best = -1e30f; + float eos_logit = -1e30f; + if (talker_head) { + if (use_repetition_penalty) { + static float talker_logits[EOS + 1]; + mv_f16(talker_head, EOS + 1, HID, hid, talker_logits); + for (int i = 0; i < CPV; i++) { + float v = talker_logits[i]; + if (seen_c0[i]) { + v = v < 0.0f ? v * repetition_penalty : v / repetition_penalty; + } + if (v > best) { + best = v; + c0 = i; + } + } + eos_logit = talker_logits[EOS]; + } else { + c0 = mv_f16_argmax_eos(talker_head, 2048, EOS, HID, hid, &best, &eos_logit); + } + } else { + for (int i = 0; i < 2048; i++) { + float v = lg[i]; + if (use_repetition_penalty && seen_c0[i]) { + v = v < 0.0f ? v * repetition_penalty : v / repetition_penalty; + } + if (v > best) { + best = v; + c0 = i; + } + } + eos_logit = lg[EOS]; + } + if (f >= 2 && eos_logit > best) { + fprintf(stderr, "EOS at %d\n", f); + break; + } + codes[0] = c0; + if (c0 >= 0 && c0 < CPV) { + seen_c0[c0] = 1; + } + + double tc = now_s(); + const int cp_seq = cp_seq_per_frame ? (f + 1) : 0; + if (!cp_seq_per_frame) { + double tclear = now_s(); + if (cp_seq_rm) { + llama_memory_seq_rm(llama_get_memory(cc), 0, -1, -1); + } else { + llama_memory_clear(llama_get_memory(cc), !cp_meta_clear); + } + t_cp_clear += now_s() - tclear; + } + memcpy(bc->embd, hid, HID * 4); + memcpy(bc->embd + HID, cemb + (long)c0 * HID, HID * 4); + for (int i = 0; i < 2; i++) { + bc->pos[i] = i; + bc->n_seq_id[i] = 1; + bc->seq_id[i][0] = cp_seq; + bc->logits[i] = i == 1; + } + bc->n_tokens = 2; + if (cp_switch_threads) { + set_ctx_threads_if_changed(cc, &cp_cur_threads, &cp_cur_threads_batch, + cp_init_threads, cp_init_threads_batch); + } + double td = now_s(); + if (cp_decode(cc, *bc)) { + return 1; + } + t_cp_init_decode += now_s() - td; + if (cp_switch_threads) { + set_ctx_threads_if_changed(cc, &cp_cur_threads, &cp_cur_threads_batch, + cp_step_threads, cp_step_threads_batch); + } + for (int s = 0; s < cp_active_heads; s++) { + const float *h = llama_get_embeddings_ith(cc, bc->n_tokens - 1); + double th = now_s(); + static float hl[CPV]; + int g = 0; + float gb = -1e30f; + if (head_argmax) { + g = mv_f16_argmax(heads[s], CPV, HID, h, &gb); + } else { + mv_f16(heads[s], CPV, HID, h, hl); + for (int v = 0; v < CPV; v++) { + if (hl[v] > gb) { + gb = hl[v]; + g = v; + } + } + } + t_heads += now_s() - th; + codes[s + 1] = g; + if (s == cp_active_heads - 1) { + break; + } + memcpy(bc->embd, pemb[s] + (long)g * HID, HID * 4); + bc->pos[0] = 2 + s; + bc->n_seq_id[0] = 1; + bc->seq_id[0][0] = cp_seq; + bc->logits[0] = 1; + bc->n_tokens = 1; + td = now_s(); + if (cp_decode(cc, *bc)) { + return 1; + } + t_cp_step_decode += now_s() - td; + } + if (cp_fill_code < 0 && cp_active_heads < 15) { + int fill = codes[cp_active_heads]; + for (int i = cp_active_heads + 1; i < NCB; i++) { + codes[i] = fill; + } + } + t_cp += now_s() - tc; + + fwrite(codes, 4, NCB, out); + fflush(out); + frames++; + + double tt = now_s(); + float *e = bt->embd; + const float *tr = f < job->n_trail ? trail + (long)f * HID : pad; + for (int d = 0; d < HID; d++) { + float acc = cemb[(long)codes[0] * HID + d] + tr[d]; + for (int i = 0; i < 15; i++) { + acc += pemb[i][(long)codes[i + 1] * HID + d]; + } + e[d] = acc; + } + bt->pos[0] = job->n_prefill + f; + bt->n_seq_id[0] = 1; + bt->seq_id[0][0] = 0; + bt->logits[0] = 1; + bt->n_tokens = 1; + if (talker_decode(ct, *bt, talker_head != NULL)) { + return 1; + } + t_talker += now_s() - tt; + } + + if (frames > 0) { + fprintf(stderr, "frames %d talker %.1fms/f cp %.1fms/f (heads %.1f) total %.2fs\n", + frames, t_talker / frames * 1000, t_cp / frames * 1000, + t_heads / frames * 1000, now_s() - t0); + if (cp_profile) { + fprintf(stderr, + "cp_detail clear %.3fms/f init_decode %.3fms/f step_decode %.3fms/f heads %.3fms/f\n", + t_cp_clear / frames * 1000, + t_cp_init_decode / frames * 1000, + t_cp_step_decode / frames * 1000, + t_heads / frames * 1000); + } + } else { + fprintf(stderr, "frames 0 total %.2fs\n", now_s() - t0); + } + + free(prefill); + free(trail); + free(pad); + return 0; +} + +int main(int argc, char **argv) { + if (argc < 4) { + fprintf(stderr, "args\n"); + return 1; + } + + struct job_spec *jobs = NULL; + int n_jobs = 0; + int jobs_mode = argc >= 5 && strcmp(argv[3], "--jobs") == 0; + int jobs_stdin_mode = argc >= 6 && strcmp(argv[3], "--jobs-stdin") == 0; + if (jobs_mode) { + if (read_jobs(argv[4], &jobs, &n_jobs)) { + return 1; + } + } else if (jobs_stdin_mode) { + jobs = NULL; + n_jobs = 0; + } else { + if (argc < 10) { + fprintf(stderr, "args\n"); + return 1; + } + jobs = calloc(1, sizeof(*jobs)); + if (!jobs) { + return 1; + } + n_jobs = 1; + snprintf(jobs[0].prefill, sizeof(jobs[0].prefill), "%s", argv[4]); + jobs[0].n_prefill = atoi(argv[5]); + snprintf(jobs[0].trailing, sizeof(jobs[0].trailing), "%s", argv[6]); + jobs[0].n_trail = atoi(argv[7]); + snprintf(jobs[0].pad, sizeof(jobs[0].pad), "%s", argv[8]); + jobs[0].max_frames = atoi(argv[9]); + jobs[0].out[0] = '\0'; + } + + int max_prefill = 0; + int max_frames = 0; + if (jobs_stdin_mode) { + max_prefill = atoi(argv[4]); + max_frames = atoi(argv[5]); + if (max_prefill <= 0 || max_frames <= 0) { + fprintf(stderr, "bad --jobs-stdin limits\n"); + return 1; + } + } else { + for (int i = 0; i < n_jobs; i++) { + if (jobs[i].n_prefill > max_prefill) { + max_prefill = jobs[i].n_prefill; + } + if (jobs[i].max_frames > max_frames) { + max_frames = jobs[i].max_frames; + } + } + } + + char p[512]; + float *cemb = load_gguf_tensor(argv[1], "q3tts.codec_embedding.weight", GGML_TYPE_F32, 3072UL * HID * sizeof(float)); + float *pemb[15]; + uint16_t *heads[15]; + for (int i = 0; i < 15; i++) { + snprintf(p, sizeof(p), "q3tts.cp_embedding.%d.weight", i); + pemb[i] = load_gguf_tensor(argv[2], p, GGML_TYPE_F32, (size_t)CPV * HID * sizeof(float)); + snprintf(p, sizeof(p), "q3tts.cp_head_f16.%d.weight", i); + heads[i] = load_gguf_tensor(argv[2], p, GGML_TYPE_F16, (size_t)CPV * HID * sizeof(uint16_t)); + } + uint16_t *talker_head = load_gguf_tensor(argv[1], "q3tts.talker_head_f16.weight", GGML_TYPE_F16, 3072UL * HID * sizeof(uint16_t)); + sticky_qwen3_embed_only = talker_head != NULL && getenv("Q3TTS_CP_EMBED_ONLY") != NULL; + if (sticky_qwen3_embed_only) { + setenv("LLAMA_QWEN3_EMBED_ONLY", "1", 1); + } + + if (getenv("Q3TTS_OUTPUT_ONLY_EMBEDDINGS") != NULL) { + setenv("LLAMA_EMBEDDINGS_OUTPUT_ONLY", "1", 1); + } + default_llama_env(); + stdout_to_stderr(); + llama_log_set(log_to_stderr, NULL); + llama_backend_init(); + struct llama_model_params mp = llama_model_default_params(); + struct llama_model *mt = llama_model_load_from_file(argv[1], mp); + if (!mt) { + return 1; + } + struct llama_model *mc = llama_model_load_from_file(argv[2], mp); + if (!mc) { + return 1; + } + + int talker_ctx = env_int("Q3TTS_TALKER_CTX", 4096); + int min_talker_ctx = max_prefill + max_frames + 1; + if (talker_ctx < min_talker_ctx) { + talker_ctx = min_talker_ctx; + } + const int cp_seq_per_frame = getenv("Q3TTS_CP_SEQ_PER_FRAME") != NULL; + int cp_ctx = env_int("Q3TTS_CP_CTX", cp_seq_per_frame ? (max_frames * 16 + 16) : 64); + int cp_ctx_min = env_int("Q3TTS_CP_CTX_MIN", 32); + if (cp_ctx_min < 16) { + cp_ctx_min = 16; + } + if (cp_ctx < cp_ctx_min) { + cp_ctx = cp_ctx_min; + } + int cp_seq_max = env_int("Q3TTS_CP_SEQ_MAX", cp_seq_per_frame ? (max_frames + 1) : 1); + int talker_batch = env_int("Q3TTS_TALKER_BATCH", talker_ctx); + if (talker_batch < max_prefill) { + talker_batch = max_prefill; + } + int talker_ubatch = env_int("Q3TTS_TALKER_UBATCH", 512); + if (talker_ubatch < 1) { + talker_ubatch = 1; + } + int cp_batch = env_int("Q3TTS_CP_BATCH", cp_ctx); + if (cp_batch < 2) { + cp_batch = 2; + } + int cp_ubatch = env_int("Q3TTS_CP_UBATCH", 512); + if (cp_ubatch < 1) { + cp_ubatch = 1; + } + int talker_threads = env_int("Q3TTS_TALKER_THREADS", 0); + int cp_threads = env_int("Q3TTS_CP_THREADS", 0); + int talker_threads_batch = env_int("Q3TTS_TALKER_THREADS_BATCH", talker_threads); + int cp_threads_batch = env_int("Q3TTS_CP_THREADS_BATCH", cp_threads); + int cp_step_threads = env_int("Q3TTS_CP_STEP_THREADS", cp_threads); + int cp_step_threads_batch = env_int("Q3TTS_CP_STEP_THREADS_BATCH", cp_threads_batch); + struct llama_context *ct = mk_ctx(mt, talker_ctx, talker_batch, talker_ubatch, talker_threads, + talker_threads_batch, 1); + struct llama_context *cc = mk_ctx(mc, cp_ctx, cp_batch, cp_ubatch, cp_threads, + cp_threads_batch, cp_seq_max); + int llama_threads = max_i(llama_n_threads(ct), llama_n_threads_batch(ct)); + llama_threads = max_i(llama_threads, llama_n_threads(cc)); + llama_threads = max_i(llama_threads, llama_n_threads_batch(cc)); + llama_threads = max_i(llama_threads, cp_step_threads); + llama_threads = max_i(llama_threads, cp_step_threads_batch); + struct ggml_threadpool *llama_threadpool = create_llama_threadpool(llama_threads); + if (llama_threadpool) { + llama_attach_threadpool(ct, llama_threadpool, llama_threadpool); + llama_attach_threadpool(cc, llama_threadpool, llama_threadpool); + } + struct llama_batch bt = llama_batch_init(max_prefill + max_frames + 1, HID, 1); + struct llama_batch bc = llama_batch_init(2, HID, 1); + pools_init(); + restore_stdout(); + if (jobs_stdin_mode) { + fprintf(stderr, "talker_stdin_ready\n"); + fflush(stderr); + } + + int exit_code = 0; + for (int i = 0; jobs_stdin_mode || i < n_jobs; i++) { + struct job_spec stdin_job; + struct job_spec *job = jobs_stdin_mode ? &stdin_job : &jobs[i]; + if (jobs_stdin_mode) { + int got = read_job_line(stdin, job); + if (got == 0) { + break; + } + if (got < 0) { + fprintf(stderr, "bad stdin jobs line\n"); + exit_code = 1; + break; + } + if (job->n_prefill > max_prefill || job->max_frames > max_frames) { + fprintf(stderr, "stdin job exceeds limits: prefill %d/%d frames %d/%d\n", + job->n_prefill, max_prefill, job->max_frames, max_frames); + exit_code = 1; + break; + } + } + FILE *out = stdout; + if (job->out[0]) { + out = fopen(job->out, "wb"); + if (!out) { + fprintf(stderr, "open output %s fail\n", job->out); + exit_code = 1; + break; + } + } + int rc = run_job(ct, cc, &bt, &bc, cemb, pemb, heads, talker_head, job, out); + if (job->out[0]) { + fclose(out); + } + if (rc) { + exit_code = rc; + break; + } + if (job->done[0]) { + FILE *done = fopen(job->done, "wb"); + if (!done) { + fprintf(stderr, "open done marker %s fail\n", job->done); + exit_code = 1; + break; + } + fclose(done); + } + } + + if (exit_code == 0 && getenv("Q3TTS_LLAMA_PERF_PRINT") != NULL) { + fprintf(stderr, "talker perf:\n"); + llama_perf_context_print(ct); + fprintf(stderr, "cp perf:\n"); + llama_perf_context_print(cc); + } + if (llama_threadpool) { + llama_detach_threadpool(ct); + llama_detach_threadpool(cc); + ggml_threadpool_free(llama_threadpool); + } + return exit_code; +} diff --git a/tools/speech/backends/qwen3_tts/tools/q3tts_cp_kernel_bench.cpp b/tools/speech/backends/qwen3_tts/tools/q3tts_cp_kernel_bench.cpp new file mode 100644 index 000000000000..ac5ec7801225 --- /dev/null +++ b/tools/speech/backends/qwen3_tts/tools/q3tts_cp_kernel_bench.cpp @@ -0,0 +1,256 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "spacemit/ime_kernels.h" +#include "spacemit/rvv_kernels.h" + +extern "C" { +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n); +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n); +} + +namespace { + +struct block_q4_0x32_layout { + _Float16 d[32]; + uint8_t qs[16 * 32]; +}; + +static uint32_t xorshift32(uint32_t & state) { + state ^= state << 13; + state ^= state >> 17; + state ^= state << 5; + return state; +} + +static void fill_f32(std::vector & data) { + uint32_t state = 0x12345678u; + for (float & v : data) { + const int x = (int) (xorshift32(state) & 0x3ffu) - 512; + v = (float) x / 256.0f; + } +} + +static void fill_q4_repacked(std::vector & data, int64_t k_blks, int64_t n) { + constexpr int64_t k_subblks_per_superblk = 8; + const int64_t b_superblk_stride = (int64_t) sizeof(block_q4_0x32_layout) * k_subblks_per_superblk; + const int64_t b_tile_stride = k_blks * b_superblk_stride; + uint32_t state = 0x87654321u; + + for (int64_t ni = 0; ni < n; ni += 32) { + uint8_t * tile = data.data() + (ni / 32) * b_tile_stride; + for (int64_t kb = 0; kb < k_blks; ++kb) { + uint8_t * superblk = tile + kb * b_superblk_stride; + auto * blocks = reinterpret_cast(superblk); + for (int s = 0; s < k_subblks_per_superblk; ++s) { + for (int i = 0; i < 32; ++i) { + blocks[s].d[i] = (_Float16) 0.02f; + } + for (uint8_t & q : blocks[s].qs) { + q = (uint8_t) xorshift32(state); + } + } + } + } +} + +static double now_s() { + using clock = std::chrono::steady_clock; + return std::chrono::duration(clock::now().time_since_epoch()).count(); +} + +static float checksum(const std::vector & values) { + float sum = 0.0f; + for (float v : values) { + sum += v * 0.000001f; + } + return sum; +} + +struct shape_buffers { + static constexpr int64_t block_len = 256; + + int64_t k = 0; + int64_t n = 0; + int64_t k_blks = 0; + size_t a_stride = 0; + size_t b_superblk_stride = 0; + size_t b_tile_stride = 0; + std::vector src; + std::vector a; + std::vector b; + std::vector out; +}; + +static shape_buffers make_shape(int64_t k, int64_t n) { + shape_buffers shape; + shape.k = k; + shape.n = n; + shape.k_blks = spacemit_kernels::div_round_up(k, shape_buffers::block_len); + shape.a_stride = spacemit_kernels::q8_hp_blk_size(shape_buffers::block_len, true, true); + shape.b_superblk_stride = sizeof(block_q4_0x32_layout) * 8; + shape.b_tile_stride = (size_t) shape.k_blks * shape.b_superblk_stride; + + const size_t a_bytes = (size_t) shape.k_blks * shape.a_stride; + const size_t b_bytes = (size_t) spacemit_kernels::div_round_up(n, int64_t(32)) * shape.b_tile_stride; + + shape.src.resize((size_t) k); + shape.a.resize(a_bytes); + shape.b.resize(b_bytes); + shape.out.resize((size_t) n); + fill_f32(shape.src); + fill_q4_repacked(shape.b, shape.k_blks, n); + spacemit_kernels::rvv::quantize_a_row_i8_hp(shape_buffers::block_len, shape.src.data(), (size_t) k, + shape.a.data()); + return shape; +} + +static bool env_enabled(const char * name) { + const char * value = std::getenv(name); + if (value == nullptr || value[0] == '\0') { + return false; + } + return std::strcmp(value, "0") != 0 && std::strcmp(value, "false") != 0 && std::strcmp(value, "FALSE") != 0; +} + +static void run_shape_single(shape_buffers & shape, int iters) { + for (int i = 0; i < 50; ++i) { + spacemit_kernels::ime2::gemm_kernel_i8i4_hp(shape_buffers::block_len, shape.a.data(), shape.b.data(), nullptr, + shape.out.data(), 1, (size_t) shape.n, (size_t) shape.k_blks, + (size_t) shape.n); + } + + const double t0 = now_s(); + for (int i = 0; i < iters; ++i) { + spacemit_kernels::ime2::gemm_kernel_i8i4_hp(shape_buffers::block_len, shape.a.data(), shape.b.data(), nullptr, + shape.out.data(), 1, (size_t) shape.n, (size_t) shape.k_blks, + (size_t) shape.n); + } + const double t1 = now_s(); + const double us = (t1 - t0) * 1000000.0 / (double) iters; + std::printf("single m=1 k=%lld n=%lld k_blks=%lld iters=%d time_us=%.3f checksum=%g\n", (long long) shape.k, + (long long) shape.n, (long long) shape.k_blks, iters, us, checksum(shape.out)); +} + +struct dispatch_context { + shape_buffers * shape = nullptr; + pthread_barrier_t start; + pthread_barrier_t done; + int nth = 4; + int iters = 0; + int warmup = 50; + int64_t tile_cols = 32; +}; + +struct worker_context { + dispatch_context * dispatch = nullptr; + int ith = 0; +}; + +static void run_dispatch_tiles(dispatch_context * dispatch, int ith) { + shape_buffers & shape = *dispatch->shape; + for (int64_t ni = (int64_t) ith * dispatch->tile_cols; ni < shape.n; ni += dispatch->tile_cols * dispatch->nth) { + const int64_t nb_real = std::min(shape.n - ni, dispatch->tile_cols); + uint8_t * b_row = shape.b.data() + (size_t) (ni / 32) * shape.b_tile_stride; + float * c_blk = shape.out.data() + ni; + spacemit_kernels::ime2::gemm_kernel_i8i4_hp(shape_buffers::block_len, shape.a.data(), b_row, nullptr, c_blk, 1, + (size_t) nb_real, (size_t) shape.k_blks, (size_t) shape.n); + } +} + +static void * dispatch_worker(void * arg) { + worker_context * worker = static_cast(arg); + dispatch_context * dispatch = worker->dispatch; + + ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(worker->ith); + for (int i = 0; i < dispatch->warmup + dispatch->iters; ++i) { + pthread_barrier_wait(&dispatch->start); + run_dispatch_tiles(dispatch, worker->ith); + pthread_barrier_wait(&dispatch->done); + } + ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(worker->ith); + return nullptr; +} + +static void run_shape_dispatch(shape_buffers & shape, int iters) { + dispatch_context dispatch; + dispatch.shape = &shape; + dispatch.iters = iters; + dispatch.tile_cols = env_enabled("SPACEMIT_Q4_HP_M1_N64") ? 64 : 32; + + pthread_barrier_init(&dispatch.start, nullptr, (unsigned) dispatch.nth + 1); + pthread_barrier_init(&dispatch.done, nullptr, (unsigned) dispatch.nth + 1); + + std::vector threads((size_t) dispatch.nth); + std::vector workers((size_t) dispatch.nth); + for (int ith = 0; ith < dispatch.nth; ++ith) { + workers[(size_t) ith].dispatch = &dispatch; + workers[(size_t) ith].ith = ith; + pthread_create(&threads[(size_t) ith], nullptr, dispatch_worker, &workers[(size_t) ith]); + } + + for (int i = 0; i < dispatch.warmup; ++i) { + pthread_barrier_wait(&dispatch.start); + pthread_barrier_wait(&dispatch.done); + } + + const double t0 = now_s(); + for (int i = 0; i < iters; ++i) { + pthread_barrier_wait(&dispatch.start); + pthread_barrier_wait(&dispatch.done); + } + const double t1 = now_s(); + + for (pthread_t thread : threads) { + pthread_join(thread, nullptr); + } + + pthread_barrier_destroy(&dispatch.start); + pthread_barrier_destroy(&dispatch.done); + + const double us = (t1 - t0) * 1000000.0 / (double) iters; + std::printf("dispatch4 m=1 k=%lld n=%lld tile_cols=%lld iters=%d time_us=%.3f checksum=%g\n", + (long long) shape.k, (long long) shape.n, (long long) dispatch.tile_cols, iters, us, + checksum(shape.out)); +} + +static void run_shape(int64_t k, int64_t n, int iters, bool dispatch) { + shape_buffers shape = make_shape(k, n); + if (dispatch) { + run_shape_dispatch(shape, iters); + } else { + run_shape_single(shape, iters); + } +} + +} // namespace + +int main(int argc, char ** argv) { + int iters = 2000; + if (argc > 1) { + iters = std::atoi(argv[1]); + if (iters <= 0) { + iters = 2000; + } + } + const bool dispatch = argc > 2 && std::strcmp(argv[2], "dispatch") == 0; + + if (!dispatch) { + ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(0); + } + run_shape(1024, 1024, iters, dispatch); + run_shape(2048, 1024, iters, dispatch); + run_shape(3072, 1024, iters, dispatch); + run_shape(1024, 4096, iters, dispatch); + run_shape(1024, 6144, iters, dispatch); + run_shape(3072, 1024, iters, dispatch); + if (!dispatch) { + ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(0); + } + return 0; +} diff --git a/tools/speech/backends/qwen3_tts/tools/q3tts_ref_to_bin.cpp b/tools/speech/backends/qwen3_tts/tools/q3tts_ref_to_bin.cpp new file mode 100644 index 000000000000..9c06ce3757b4 --- /dev/null +++ b/tools/speech/backends/qwen3_tts/tools/q3tts_ref_to_bin.cpp @@ -0,0 +1,116 @@ +#include + +#include "q3tts_frontend.h" + +#include +#include +#include +#include + +namespace { + +std::string env_str(const char *name, const std::string &fallback) { + const char *v = std::getenv(name); + return (v && *v) ? std::string(v) : fallback; +} + +int env_int(const char *name, int fallback) { + const char *v = std::getenv(name); + return (v && *v) ? std::atoi(v) : fallback; +} + +std::string default_model_dir() { + const char *home = std::getenv("HOME"); + if (home && *home) { + return std::string(home) + "/.cache/models/tts/qwen3-tts"; + } + return "/tmp/.cache/models/tts/qwen3-tts"; +} + +struct Args { + std::string ref_wav; + std::string ref_text; + std::string ref_text_file; + std::string out; + std::string model_dir = env_str("Q3TTS_MODEL_DIR", default_model_dir()); + int threads = env_int("Q3TTS_FRONTEND_THREADS", 2); +}; + +Args parse_args(int argc, char **argv) { + Args args; + for (int i = 1; i < argc; ++i) { + std::string k = argv[i]; + auto need = [&](const char *name) -> std::string { + if (i + 1 >= argc) { + throw std::runtime_error(std::string("missing value for ") + name); + } + return argv[++i]; + }; + if (k == "--ref-wav") { + args.ref_wav = need("--ref-wav"); + } else if (k == "--ref-text") { + args.ref_text = need("--ref-text"); + } else if (k == "--ref-text-file") { + args.ref_text_file = need("--ref-text-file"); + } else if (k == "--out") { + args.out = need("--out"); + } else if (k == "--model-dir") { + args.model_dir = need("--model-dir"); + } else if (k == "--threads") { + args.threads = std::stoi(need("--threads")); + } else { + throw std::runtime_error("unknown arg: " + k); + } + } + if (args.ref_wav.empty()) { + throw std::runtime_error("missing --ref-wav"); + } + if (args.out.empty()) { + throw std::runtime_error("missing --out"); + } + if (!args.ref_text_file.empty()) { + args.ref_text = q3tts_frontend::read_text_file(args.ref_text_file); + while (!args.ref_text.empty() && + (args.ref_text.back() == '\n' || args.ref_text.back() == '\r')) { + args.ref_text.pop_back(); + } + } + return args; +} + +} // namespace + +int main(int argc, char **argv) { + try { + Args args = parse_args(argc, argv); + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "q3tts_ref_to_bin"); + + const auto ref_audio = q3tts_frontend::read_wav_mono_24k(args.ref_wav); + const double ref_seconds = static_cast(ref_audio.size()) / 24000.0; + if (ref_seconds < 4.0) { + std::cerr << "warning: reference audio is short (" << ref_seconds + << "s); use a clean 5-10s clip for stable voice cloning\n"; + } + + const std::string speaker_onnx = q3tts_frontend::speaker_encoder_path(args.model_dir); + auto spk = q3tts_frontend::run_speaker_encoder(env, speaker_onnx, args.ref_wav, args.threads); + if (args.ref_text.empty()) { + q3tts_frontend::write_speaker_bin(args.out, spk); + std::cout << "ref_bin " << args.out << " floats " << spk.size() + << " bytes " << (spk.size() * sizeof(float)) << "\n"; + } else { + const std::string codec_onnx = q3tts_frontend::codec_encoder_path(args.model_dir); + auto codes = q3tts_frontend::run_codec_encoder(env, codec_onnx, args.ref_wav, args.threads); + q3tts_frontend::write_reference_prompt_bin(args.out, spk, args.ref_text, codes); + std::cout << "ref_prompt_bin " << args.out + << " speaker_floats " << spk.size() + << " ref_frames " << codes.size() + << " ref_text_bytes " << args.ref_text.size() + << "\n"; + } + } catch (const std::exception &e) { + std::cerr << "error: " << e.what() << "\n"; + return 1; + } + return 0; +} diff --git a/tools/speech/backends/qwen3_tts/tools/q3tts_run_main.cpp b/tools/speech/backends/qwen3_tts/tools/q3tts_run_main.cpp new file mode 100644 index 000000000000..1e38a5a1e026 --- /dev/null +++ b/tools/speech/backends/qwen3_tts/tools/q3tts_run_main.cpp @@ -0,0 +1,5 @@ +#include "qwen3_tts_runtime.h" + +int main(int argc, char **argv) { + return qwen3_tts::run_cli(argc, argv); +}