Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 96 additions & 2 deletions cosyvoice/cli/cosyvoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type
from cosyvoice.utils.file_utils import load_wav
import datetime
import torchaudio
import uuid


class CosyVoice:
Expand Down Expand Up @@ -190,7 +194,8 @@ def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk

class CosyVoice3(CosyVoice2):

def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1,
speaker_info_dir=None, graph_mode=False):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
Expand All @@ -212,7 +217,8 @@ def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_c
if not _has_accelerator and (load_trt is True or fp16 is True):
load_trt, fp16 = False, False
logging.warning('no cuda/npu device, set load_trt/fp16 to False')
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16,
model_dir=model_dir, graph_mode=graph_mode)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
Expand All @@ -227,6 +233,94 @@ def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_c
self.fp16)
del configs

# Resolve the speaker_info directory (defaults to <model_dir>/speaker_info)
self.speaker_info_dir = speaker_info_dir or os.path.join(model_dir, 'speaker_info')
if not os.path.exists(self.speaker_info_dir):
assert False
logging.warning(f'speaker_info_dir {self.speaker_info_dir} does not exist')

# Preload all speaker info
self.promote_wave_info = {}
self._preload_all_wave_info()

def _preload_all_wave_info(self):
    """Load every speaker listed in ``speaker_info.json`` into ``self.promote_wave_info``.

    The file is looked up as ``<self.speaker_info_dir>/speaker_info.json`` and is
    iterated as a mapping of speaker id -> entry. Each entry must provide
    ``prompt_wav`` (a path relative to ``self.speaker_info_dir``) and
    ``prompt_text``. For every entry whose audio file exists, ``prompt_wav`` is
    replaced in place by the result of ``load_wav(<resolved path>, 16000)`` and
    the entry is stored under its speaker id.

    Loading is best-effort: a missing directory or JSON file returns silently
    (the latter with a warning), and any per-speaker or file-level failure is
    logged and skipped rather than raised.
    """
    if not os.path.exists(self.speaker_info_dir):
        return

    import json

    speaker_info_file = os.path.join(self.speaker_info_dir, 'speaker_info.json')
    if not os.path.exists(speaker_info_file):
        logging.warning(f'speaker_info.json not found in {self.speaker_info_dir}')
        return

    try:
        with open(speaker_info_file, 'r', encoding='utf-8') as f:
            all_speakers = json.load(f)

        for spk_id, wave_info in all_speakers.items():
            try:
                prompt_wav_path = os.path.join(self.speaker_info_dir, wave_info['prompt_wav'])
                if not os.path.exists(prompt_wav_path):
                    logging.error(f'Audio file not found for speaker {spk_id}: {prompt_wav_path}')
                    continue
                # Read prompt_text before registering, so an entry that lacks it
                # is rejected as a whole instead of being half-registered.
                prompt_text = wave_info['prompt_text']
                # Load from the resolved path that was just checked; the raw JSON
                # value is relative to speaker_info_dir, not to the process CWD.
                wave_info['prompt_wav'] = load_wav(prompt_wav_path, 16000)
                self.promote_wave_info[spk_id] = wave_info
                logging.info(f"Loaded speaker {spk_id} from {prompt_wav_path}, promote={prompt_text}")
            except Exception as e:
                # Deliberately broad: one bad speaker must not abort the others.
                logging.error(f'Failed to load speaker {spk_id}: {e}')

        logging.info(f'Preloaded {len(self.promote_wave_info)} speakers from {speaker_info_file}')

    except Exception as e:
        # Deliberately broad: an unreadable or malformed file leaves the cache as-is.
        logging.error(f'Failed to load {speaker_info_file}: {e}')

def inference_zero_shot_by_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
    """Synthesize ``tts_text`` with a speaker looked up by id in ``self.promote_wave_info``.

    This is a generator: nothing runs (including the id check) until it is
    first iterated.

    Args:
        tts_text: Text to synthesize; it is split into segments by
            ``self.frontend.text_normalize(..., split=True)``.
        spk_id: Key into ``self.promote_wave_info``.
        stream: Forwarded to ``self.model.tts``.
        speed: Forwarded to ``self.model.tts``.
        text_frontend: Forwarded to ``self.frontend.text_normalize``.

    Yields:
        Each output produced by ``self.model.tts``; every output is indexed
        with ``'tts_speech'`` for logging.

    Raises:
        ValueError: If ``spk_id`` is not present in ``self.promote_wave_info``.
    """
    if spk_id not in self.promote_wave_info:
        raise ValueError(f'Speaker ID {spk_id} not found. Available IDs: {list(self.promote_wave_info.keys())}')
    wave_info = self.promote_wave_info[spk_id]

    # Register the speaker with the frontend the first time this id is used.
    if spk_id not in self.frontend.spk2info:
        self.add_zero_shot_spk(wave_info['prompt_text'], wave_info['prompt_wav'], spk_id)

    for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
        model_input = self.frontend.frontend_zero_shot(i, wave_info['prompt_text'], wave_info['prompt_wav'], self.sample_rate, spk_id)
        start_time = time.time()
        logging.info('synthesis text {}'.format(i))
        for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
            yield model_output
            # Restart the timer so the next rtf covers only the next chunk.
            start_time = time.time()

def reload_wave_info(self):
    """Empty the speaker cache in place, then run the preload step again."""
    cache = self.promote_wave_info
    cache.clear()
    self._preload_all_wave_info()

def get_available_spk_ids(self):
    """Return a new list holding the ids of all currently cached speakers."""
    return list(self.promote_wave_info)


def AutoModel(**kwargs):
if not os.path.exists(kwargs['model_dir']):
Expand Down
18 changes: 15 additions & 3 deletions cosyvoice/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,13 @@ def load_vllm(self, model_dir):
export_cosyvoice2_vllm(self.llm, model_dir, self.device)
from vllm import EngineArgs, LLMEngine
engine_args = EngineArgs(model=model_dir,
dtype='float16', #new added on 2026/2/10
skip_tokenizer_init=True,
enable_prompt_embeds=True,
gpu_memory_utilization=0.2)
gpu_memory_utilization=0.7, #adjust from 0.2 to 0.9, then to 0.5
additional_config={"torchair_graph_config":{"enabled":True}}, #new added on 2026/2/10
compilation_config={'cudagraph_capture_sizes':[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,512,1024,2048],"cudagraph_mode": "FULL"},
enable_prefix_caching=False)
self.llm.vllm = LLMEngine.from_engine_args(engine_args)
self.llm.lock = threading.Lock()
del self.llm.llm.model.model.layers
Expand Down Expand Up @@ -369,8 +373,14 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
token_offset = 0
prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
while True:
time.sleep(0.1)
time.sleep(0.05)
this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
if token_offset == 0: #���װ���Ϊ 20token
#this_token_hop_len = 20
this_token_hop_len = 10
#this_token_hop_len = 9
#print("len(self.tts_speech_token_dict[this_uuid]): ", len(self.tts_speech_token_dict[this_uuid]))
#print("self.flow.pre_lookahead_len: ", self.flow.pre_lookahead_len)
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
Expand Down Expand Up @@ -421,7 +431,9 @@ def __init__(self,
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
fp16: bool = False,
model_dir: str = None,
graph_mode: bool = False):
self.device = _get_device()
self.llm = llm
self.flow = flow
Expand Down
32 changes: 25 additions & 7 deletions cosyvoice/flow/DiT/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import torchaudio

from x_transformers.x_transformers import apply_rotary_pos_emb

import torch_npu
from cosyvoice.utils.file_utils import logging
import math

# raw wav to mel spec
class MelSpec(nn.Module):
Expand Down Expand Up @@ -84,17 +86,22 @@ def forward(self, x, scale=1000):


# convolutional position embedding

class Mish(nn.Module):
    """Mish activation, ``x * tanh(softplus(x))``, built from elementary tensor ops.

    NOTE(review): this hand-rolled module is used where ``nn.Mish`` was used
    before, presumably for backend/graph-mode compatibility - confirm before
    swapping it for a fused op.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply Mish element-wise and return a tensor shaped like ``x``."""
        # softplus(x) = log(1 + exp(x)). Evaluating that literally overflows
        # exp() for large x: the forward value survives (tanh(inf) == 1) but
        # the backward pass yields NaN (inf * 0). Above the cutoff softplus(x)
        # equals x to float precision, so pass x through there and only
        # exponentiate the clamped value; log1p also keeps precision for
        # very negative x. The cutoff of 20 mirrors torch.nn.Softplus's default.
        clipped = torch.clamp(x, max=20.0)
        softplus = torch.where(x > 20.0, x, torch.log1p(torch.exp(clipped)))
        return x * torch.tanh(softplus)

class ConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
Mish(),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
Mish(),
)

def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
Expand All @@ -119,11 +126,11 @@ def __init__(self, dim, kernel_size=31, groups=16):
self.kernel_size = kernel_size
self.conv1 = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
nn.Mish(),
Mish(),
)
self.conv2 = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
nn.Mish(),
Mish(),
)

def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
Expand Down Expand Up @@ -388,7 +395,18 @@ def __call__(
else:
attn_mask = None

x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
#x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
atten_mask_npu = torch.logical_not(attn_mask) # atten_mask��Ҫȡ��
head_num = query.shape[1]
x = torch_npu.npu_fusion_attention(
query, key, value, head_num, input_layout="BNSD",
pse=None,
atten_mask=atten_mask_npu,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=2147483647,
next_tockens=2147483647,
keep_prob=1
)[0]
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)

Expand Down
7 changes: 4 additions & 3 deletions cosyvoice/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from omegaconf import DictConfig
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path

import torchair

class MaskedDiffWithXvec(torch.nn.Module):
def __init__(self,
Expand Down Expand Up @@ -376,7 +376,8 @@ def inference(self,
prompt_feat_len,
embedding,
streaming,
finalize):
finalize,
n_timesteps=None):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
Expand Down Expand Up @@ -406,7 +407,7 @@ def inference(self,
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10,
n_timesteps=10 if n_timesteps is None else n_timesteps,
streaming=streaming
)
feat = feat[:, :, mel_len1:]
Expand Down
8 changes: 5 additions & 3 deletions cosyvoice/utils/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,11 @@ def add_optional_chunk_mask(xs: torch.Tensor,
else:
chunk_masks = masks
assert chunk_masks.dtype == torch.bool
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
# if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
# print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
# chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
new_mask = (chunk_masks.sum(dim=-1, keepdim=True) == 0)
chunk_masks = torch.logical_or(chunk_masks, new_mask)
return chunk_masks


Expand Down
18 changes: 16 additions & 2 deletions runtime/python/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import torch
import numpy as np
from cosyvoice.utils.file_utils import load_wav
import time


def main():
Expand All @@ -44,6 +45,8 @@ def main():
zero_shot_request.prompt_text = args.prompt_text
prompt_speech = load_wav(args.prompt_wav, 16000)
zero_shot_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
# 关键:需要定义 proto 中 zeroshotRequest 有 stream 字段
zero_shot_request.stream = args.stream
request.zero_shot_request.CopyFrom(zero_shot_request)
elif args.mode == 'cross_lingual':
logging.info('send cross_lingual request')
Expand All @@ -52,6 +55,13 @@ def main():
prompt_speech = load_wav(args.prompt_wav, 16000)
cross_lingual_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
request.cross_lingual_request.CopyFrom(cross_lingual_request)
elif args.mode == 'zero_shot_by_id':
logging.info('send zero_shot_by_id request')
zero_shot_by_id_request = cosyvoice_pb2.zeroshotByIdRequest()
zero_shot_by_id_request.tts_text = args.tts_text
zero_shot_by_id_request.spk_id = args.spk_id
zero_shot_by_id_request.stream = args.stream
request.zero_shot_by_id_request.CopyFrom(zero_shot_by_id_request)
else:
logging.info('send instruct request')
instruct_request = cosyvoice_pb2.instructRequest()
Expand Down Expand Up @@ -80,7 +90,7 @@ def main():
default='50000')
parser.add_argument('--mode',
default='sft',
choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
choices=['sft', 'zero_shot', 'cross_lingual', 'instruct', 'zero_shot_by_id'],
help='request mode')
parser.add_argument('--tts_text',
type=str,
Expand All @@ -101,6 +111,10 @@ def main():
parser.add_argument('--tts_wav',
type=str,
default='demo.wav')
# 新增 stream 开关
parser.add_argument('--stream',
action='store_true',
help='whether to use streaming inference')
args = parser.parse_args()
prompt_sr, target_sr = 16000, 22050
prompt_sr, target_sr = 16000, 24000
main()
8 changes: 8 additions & 0 deletions runtime/python/grpc/cosyvoice.proto
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ message Request{
zeroshotRequest zero_shot_request = 2;
crosslingualRequest cross_lingual_request = 3;
instructRequest instruct_request = 4;
zeroshotByIdRequest zero_shot_by_id_request = 5;
}
}

Expand All @@ -25,6 +26,7 @@ message zeroshotRequest{
string tts_text = 1;
string prompt_text = 2;
bytes prompt_audio = 3;
bool stream = 4;
}

message crosslingualRequest{
Expand All @@ -38,6 +40,12 @@ message instructRequest{
string instruct_text = 3;
}

// Zero-shot synthesis request that names a speaker by id; unlike
// zeroshotRequest it carries no prompt text or prompt audio.
message zeroshotByIdRequest{
// Text to synthesize.
string tts_text = 1;
// Speaker identifier. NOTE(review): presumably must match an id preloaded on
// the server - confirm against the server implementation.
string spk_id = 2;
// Streaming flag; the Python client fills it from its --stream option.
bool stream = 3;
}

// Response message carrying synthesized audio as raw bytes.
message Response{
// Audio payload. NOTE(review): the sample format is not specified in this
// file - confirm the encoding against the server and client code.
bytes tts_audio = 1;
}
Loading