Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,22 @@ for utt in utterances:
transcription, (start, end) = utt["transcription"], utt["boundaries"]
print(f"[{gigaam.format_time(start)} - {gigaam.format_time(end)}]: {transcription}")

# Multichannel transcription (diarization)
# Supports stereo/multichannel files or a list of separate files
# Results are automatically sorted by time and interleaved between channels
stereo_file = "conversation_stereo.wav" # or list: ["channel_0.wav", "channel_1.wav"]
results = model.transcribe_multichannel(
stereo_file,
batch_size=4, # batch size for processing segments
pause_threshold=2.0, # pause threshold for grouping segments (seconds)
strict_limit_duration=30.0 # maximum segment duration for model (seconds)
)
for seg in results:
channel = seg["channel"] # channel number (0, 1, ...)
transcription = seg["transcription"]
start, end = seg["boundaries"]
print(f"[{start:.2f}s - {end:.2f}s] Channel {channel}: {transcription}")

# Emotion recognition
model = gigaam.load_model("emo")
emotion2prob = model.get_probs(audio_path)
Expand Down
16 changes: 16 additions & 0 deletions README_ru.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,22 @@ for utt in utterances:
transcription, (start, end) = utt["transcription"], utt["boundaries"]
print(f"[{gigaam.format_time(start)} - {gigaam.format_time(end)}]: {transcription}")

# Мультиканальная транскрибация (диаризация)
# Поддерживает стерео/многоканальные файлы или список отдельных файлов
# Результаты автоматически сортируются по времени и чередуются между каналами
stereo_file = "conversation_stereo.wav" # или список: ["channel_0.wav", "channel_1.wav"]
results = model.transcribe_multichannel(
stereo_file,
batch_size=4, # размер батча для обработки сегментов
pause_threshold=2.0, # порог паузы для группировки сегментов (секунды)
strict_limit_duration=30.0 # максимальная длительность сегмента для модели (секунды)
)
for seg in results:
channel = seg["channel"] # номер канала (0, 1, ...)
transcription = seg["transcription"]
start, end = seg["boundaries"]
print(f"[{start:.2f}s - {end:.2f}s] Канал {channel}: {transcription}")

# Распознавание эмоций
model = gigaam.load_model("emo")
emotion2prob = model.get_probs(audio_path)
Expand Down
3 changes: 2 additions & 1 deletion gigaam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
from tqdm import tqdm

from .model import GigaAM, GigaAMASR, GigaAMEmo
from .preprocess import load_audio
from .preprocess import load_audio, load_multichannel_audio
from .utils import format_time

__all__ = [
"GigaAM",
"GigaAMASR",
"GigaAMEmo",
"load_audio",
"load_multichannel_audio",
"format_time",
"load_model",
]
Expand Down
93 changes: 92 additions & 1 deletion gigaam/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import hydra
import omegaconf
Expand Down Expand Up @@ -169,6 +169,97 @@ def transcribe_longform(
)
return transcribed_segments

@torch.inference_mode()
def transcribe_multichannel(
    self,
    audio_input: Union[str, List[str]],
    batch_size: int = 4,
    **kwargs,
) -> List[Dict[str, Union[int, str, Tuple[float, float]]]]:
    """
    Transcribe multichannel audio, producing per-channel (diarized) segments.

    Supports:
    - a single stereo/multichannel file (str)
    - multiple separate per-channel audio files (List[str])

    Segmentation (including cutting segments where another channel starts
    speaking) is delegated to ``segment_multichannel_audio``.  Segments from
    all channels are then batched together for the forward pass; no decoder
    state is carried between segments, and the channel index is only
    attached to the final output.

    Parameters
    ----------
    audio_input : Union[str, List[str]]
        Either a single multichannel audio file or a list of separate files.
    batch_size : int
        Number of segments processed per forward pass (default: 4).
    **kwargs
        Additional arguments passed to ``segment_multichannel_audio``
        (e.g. ``pause_threshold``, ``strict_limit_duration``).

    Returns
    -------
    List[Dict]
        Dicts with keys: 'channel' (int), 'transcription' (str),
        'boundaries' (start, end) in seconds, in the order produced by the
        segmenter.
    """
    # Imported locally so that importing the model module does not require
    # the VAD dependencies.
    from .vad_utils import segment_multichannel_audio

    # Segment audio with diarization; segment tensors are expected to
    # already live on self._device.
    segments = segment_multichannel_audio(
        audio_input, SAMPLE_RATE, device=self._device, **kwargs
    )
    if not segments:
        return []

    transcribed_segments: List[Dict[str, Union[int, str, Tuple[float, float]]]] = []

    # All segments are processed in batches regardless of channel; channel
    # information is only used in the final output.
    for batch_start in range(0, len(segments), batch_size):
        batch_segments = segments[batch_start : batch_start + batch_size]

        batch_audio: List[torch.Tensor] = []
        batch_lengths: List[int] = []
        for seg in batch_segments:
            audio = seg["audio"]
            # Device should already be correct; only align the dtype.
            if audio.dtype != self._dtype:
                audio = audio.to(self._dtype)
            # Guarantee a 1D (samples,) tensor.  flatten() covers the
            # expected (1, N) case like squeeze() did, but can never leave
            # a 2D tensor behind (squeeze() would keep e.g. (2, N) intact,
            # making len(audio) report the channel count instead of the
            # sample count).
            if audio.dim() > 1:
                audio = audio.flatten()
            batch_audio.append(audio)
            batch_lengths.append(len(audio))

        # Zero-pad to the longest segment in the batch.  The floor of 320
        # samples (20 ms at 16 kHz) protects the feature extractor from a
        # degenerate, too-short input — presumably the minimum window the
        # frontend needs; TODO confirm against the preprocessor config.
        max_len = max(batch_lengths)
        batched_audio = torch.zeros(
            len(batch_audio),
            max(max_len, 320),
            dtype=self._dtype,
            device=self._device,
        )
        for i, audio in enumerate(batch_audio):
            batched_audio[i, : len(audio)] = audio

        # Shape (batch, samples) — same layout transcribe_longform uses.
        batched_lengths = torch.tensor(
            batch_lengths, device=self._device, dtype=torch.long
        )

        # Forward pass and decoding.
        encoded, encoded_len = self.forward(batched_audio, batched_lengths)
        batch_results = self.decoding.decode(self.head, encoded, encoded_len)

        for idx, seg in enumerate(batch_segments):
            transcribed_segments.append(
                {
                    "channel": seg["channel"],
                    "transcription": batch_results[idx],
                    "boundaries": seg["boundaries"],
                }
            )

    return transcribed_segments


class GigaAMEmo(GigaAM):
"""
Expand Down
101 changes: 100 additions & 1 deletion gigaam/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from subprocess import CalledProcessError, run
from typing import Tuple
from typing import List, Tuple, Union

import torch
import torchaudio
Expand Down Expand Up @@ -40,6 +40,105 @@ def load_audio(audio_path: str, sample_rate: int = SAMPLE_RATE) -> Tensor:
return torch.frombuffer(audio, dtype=torch.int16).float() / 32768.0


def load_multichannel_audio(
    audio_input: Union[str, List[str]],
    sample_rate: int = SAMPLE_RATE
) -> Tuple[List[Tensor], int]:
    """
    Load multichannel audio from either:
    - a single stereo/multichannel file (str), or
    - multiple separate audio files (List[str]).

    Parameters
    ----------
    audio_input : Union[str, List[str]]
        Path to one multichannel file, or a list of per-channel file paths.
    sample_rate : int
        Target sample rate; audio is decoded/resampled to this rate.

    Returns
    -------
    Tuple[List[Tensor], int]
        A list of 1D float channel tensors scaled to [-1, 1], and the
        length (in samples) of the longest channel.

    Raises
    ------
    ValueError
        If an empty list of file paths is given.
    RuntimeError
        If ffmpeg fails to decode a single-file input.
    """
    if not isinstance(audio_input, str):
        # List of separate per-channel files: one tensor per file.
        if not audio_input:
            raise ValueError("audio_input must contain at least one file path")
        channels = [load_audio(path, sample_rate) for path in audio_input]
        return channels, max(len(ch) for ch in channels)

    # Single file.  Preferred path: torchaudio, which reports the channel
    # count reliably.  (torchaudio is already imported at module level, so
    # no local import is needed.)
    try:
        waveform, file_sr = torchaudio.load(audio_input)
        if file_sr != sample_rate:
            waveform = torchaudio.transforms.Resample(file_sr, sample_rate)(waveform)
        channels = [waveform[i] for i in range(waveform.shape[0])]
        return channels, max(len(ch) for ch in channels)
    except Exception:
        # Fall through to the ffmpeg-based decoder below.
        pass

    # Fallback: decode to raw interleaved 16-bit PCM with ffmpeg.
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads",
        "0",
        "-i",
        audio_input,
        "-f",
        "s16le",
        "-acodec",
        "pcm_s16le",
        "-ar",
        str(sample_rate),
        "-",
    ]
    try:
        audio_bytes = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as exc:
        raise RuntimeError(f"Failed to load audio from {audio_input}") from exc

    # Determine the channel count with ffprobe; default to stereo when the
    # probe fails.  NOTE(review): for a mono file with an even sample count
    # this stereo default would wrongly de-interleave the signal; this is
    # tolerated only because the torchaudio path above normally handles
    # such files first — confirm against real inputs.
    num_channels = 2
    probe_cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "stream=channels",
        "-of", "default=noprint_wrappers=1:nokey=1",
        audio_input,
    ]
    try:
        probe_result = run(probe_cmd, capture_output=True, check=True)
        num_channels = int(probe_result.stdout.strip().split()[0])
    except (CalledProcessError, ValueError, IndexError):
        # Keep the stereo default if ffprobe is unavailable or unparseable.
        pass

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        audio_data = torch.frombuffer(audio_bytes, dtype=torch.int16).float() / 32768.0

    # ffmpeg emits interleaved frames; de-interleave into per-channel rows.
    # The divisibility check guards against a mismatched channel count.
    if num_channels > 1 and len(audio_data) % num_channels == 0:
        audio_data = audio_data.view(-1, num_channels).transpose(0, 1)
        channels = [audio_data[i] for i in range(num_channels)]
    else:
        # Single channel, or channel count could not be applied safely.
        channels = [audio_data]

    return channels, max(len(ch) for ch in channels)


class SpecScaler(nn.Module):
"""
Module that applies logarithmic scaling to spectrogram values.
Expand Down
Loading