"""Transcribe an audio file with a NeMo CTC model and emit word-level timestamps.

Greedy CTC decoding is performed frame by frame on the model's log-probs;
word boundaries are taken at space tokens and frame indices are converted to
seconds via the preprocessor window stride.
"""

import argparse

import torch
import torchaudio
from nemo.collections.asr.models import EncDecCTCModel
from nemo.collections.asr.modules.audio_preprocessing import (
    AudioToMelSpectrogramPreprocessor as NeMoAudioToMelSpectrogramPreprocessor,
)
from nemo.collections.asr.parts.preprocessing.features import (
    FilterbankFeaturesTA as NeMoFilterbankFeaturesTA,
)

# Decoding alphabet from the model configuration; index 0 is the word
# separator (space).  The CTC blank id is len(vocabulary) (NeMo convention:
# blank is the last class).
DEFAULT_ALPHABET = [' ', 'а', 'б', 'в', 'г', 'д', 'е', 'ж', 'з', 'и', 'й', 'к', 'л', 'м', 'н', 'о', 'п', 'р', 'с', 'т', 'у', 'ф', 'х', 'ц', 'ч', 'ш', 'щ', 'ъ', 'ы', 'ь', 'э', 'ю', 'я']


class FilterbankFeaturesTA(NeMoFilterbankFeaturesTA):
    """Filterbank featurizer whose mel-spectrogram is computed by torchaudio."""

    def __init__(self, mel_scale: str = "htk", wkwargs=None, **kwargs):
        # Deprecated aliases kept in old configs; drop them so they are not
        # forwarded to the parent constructor.
        kwargs.pop("window_size", None)
        kwargs.pop("window_stride", None)

        super().__init__(**kwargs)

        # Replace the parent's extractor with a torchaudio implementation
        # configured from the same kwargs.
        self._mel_spec_extractor = torchaudio.transforms.MelSpectrogram(
            sample_rate=self._sample_rate,
            win_length=self.win_length,
            hop_length=self.hop_length,
            n_mels=kwargs["nfilt"],
            window_fn=self.torch_windows[kwargs["window"]],
            mel_scale=mel_scale,
            norm=kwargs["mel_norm"],
            n_fft=kwargs["n_fft"],
            f_max=kwargs.get("highfreq", None),
            f_min=kwargs.get("lowfreq", 0),
            wkwargs=wkwargs,
        )


class AudioToMelSpectrogramPreprocessor(NeMoAudioToMelSpectrogramPreprocessor):
    """Preprocessor that swaps in the torchaudio-backed featurizer above."""

    def __init__(self, mel_scale: str = "htk", **kwargs):
        super().__init__(**kwargs)
        # FilterbankFeaturesTA expects `nfilt` where the preprocessor config
        # says `features`; rename the key before forwarding.
        kwargs["nfilt"] = kwargs["features"]
        del kwargs["features"]
        self.featurizer = FilterbankFeaturesTA(  # Deprecated arguments; kept for config compatibility
            mel_scale=mel_scale,
            **kwargs,
        )


def get_timestamps(logprobs, blank_id, stride, sample_rate, alphabet=None):
    """Greedy CTC decode of ``logprobs`` with per-word timestamps.

    Args:
        logprobs: array-like of shape (1, frames, classes) — batch of one.
        blank_id: index of the CTC blank class.
        stride: feature hop length in samples (``n_window_stride``).
            NOTE(review): this assumes one logprob frame per feature frame;
            if the encoder subsamples in time (e.g. conformer 4x/8x), the
            effective stride is larger — confirm against the model.
        sample_rate: audio sample rate in Hz.
        alphabet: decoding characters; index 0 must be the word separator.
            Defaults to :data:`DEFAULT_ALPHABET`.

    Returns:
        Tuple ``(transcription, word_timestamps)`` where ``word_timestamps``
        is a list of ``(word, start_seconds, end_seconds)``.
    """
    if alphabet is None:
        alphabet = DEFAULT_ALPHABET

    def to_seconds(frame_index):
        return frame_index * stride / sample_rate

    words, word_timestamps = [], []
    current_word = ''
    word_start_frame = 0
    last_char = None

    for frame, frame_logprob in enumerate(logprobs[0]):
        char = frame_logprob.argmax().item()

        # Standard CTC collapse: emit only on a change of class.  ``last_char``
        # is updated on EVERY frame (including blanks) so that a repeated
        # letter separated by a blank — the CTC encoding of doubled letters —
        # is not wrongly merged.
        if char != blank_id and char != last_char:
            if char == 0:  # space: close the current word, if any
                if current_word:
                    words.append(current_word)
                    word_timestamps.append(
                        (current_word, to_seconds(word_start_frame), to_seconds(frame))
                    )
                    current_word = ''
            else:
                if not current_word:
                    # Word starts at its first letter, not at the preceding space.
                    word_start_frame = frame
                current_word += alphabet[char]
        last_char = char

    # Flush the trailing word (no closing space at the end of the utterance).
    if current_word:
        words.append(current_word)
        word_timestamps.append(
            (current_word, to_seconds(word_start_frame), to_seconds(len(logprobs[0])))
        )

    return ' '.join(words), word_timestamps


def parse_with_timestamp(audio_path: str, model: EncDecCTCModel):
    """Run ``model`` on ``audio_path`` and return (transcription, timestamps).

    The audio is mixed down to mono, resampled to the model's preprocessor
    rate if needed, and decoded greedily with :func:`get_timestamps`.
    """
    device = model.device

    audio, sample_rate = torchaudio.load(audio_path)

    # Normalize to shape (1, time): add a batch axis for 1-D input, squeeze
    # spurious axes, and mix multi-channel audio down to mono.
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    elif audio.dim() > 2:
        audio = audio.squeeze()
    if audio.dim() == 2 and audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)

    # The model's featurizer is built for a fixed rate; resample if the file
    # rate differs, otherwise both recognition and timestamps are wrong.
    model_sample_rate = model.cfg.preprocessor['sample_rate']
    if sample_rate != model_sample_rate:
        audio = torchaudio.functional.resample(audio, sample_rate, model_sample_rate)
        sample_rate = model_sample_rate

    audio = audio.to(device)
    audio_length = torch.tensor([audio.shape[1]], device=device)

    with torch.no_grad():
        log_probs, encoded_len, greedy_predictions = model(
            input_signal=audio, input_signal_length=audio_length
        )

    # Feature hop length in samples (see the subsampling caveat in
    # get_timestamps).
    stride = model.cfg.preprocessor['n_window_stride']

    # NeMo CTC convention: blank is the last class index.
    blank_id = len(model.decoder.vocabulary)

    transcription, timestamps = get_timestamps(
        log_probs.cpu().numpy(), blank_id, stride, sample_rate
    )
    return transcription, timestamps


def _parse_args():
    """Command-line arguments; defaults match the original hard-coded paths."""
    parser = argparse.ArgumentParser(
        description="CTC transcription with word-level timestamps."
    )
    parser.add_argument("--model-config", default="./ctc_model_config.yaml")
    parser.add_argument("--model-weights", default="./ctc_model_weights.ckpt")
    parser.add_argument(
        "--device", default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--audio-path", default="example.wav")
    return parser.parse_args()


def main(model_config, model_weights, device, audio_path):
    """Load the model, transcribe one file, and print the result."""
    print(device)
    model = EncDecCTCModel.from_config_file(model_config)
    # NOTE(review): torch.load on an untrusted checkpoint can execute
    # arbitrary code via pickle — only load weights you trust.
    ckpt = torch.load(model_weights, map_location="cpu")
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    model = model.to(device)

    transcription, timestamps = parse_with_timestamp(audio_path, model)
    print(f"transcription: {transcription}")
    print(f"timestamps: {timestamps}")


if __name__ == "__main__":
    args = _parse_args()
    main(
        model_config=args.model_config,
        model_weights=args.model_weights,
        device=args.device,
        audio_path=args.audio_path,
    )