Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions api-docs/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -4778,6 +4778,10 @@
"description": "OpenAI settings",
"type": "string"
},
"api_url": {
"description": "Custom transcription API base URL (OpenAI adapter only)",
"type": "string"
},
"attention_context_left": {
"description": "NVIDIA Parakeet-specific parameters for long-form audio",
"type": "integer"
Expand Down Expand Up @@ -4930,6 +4934,10 @@
"threads": {
"type": "integer"
},
"timeout_minutes": {
"description": "HTTP request timeout in minutes (OpenAI adapter with custom base URL)",
"type": "integer"
},
"vad_method": {
"description": "VAD (Voice Activity Detection) settings",
"type": "string"
Expand Down
6 changes: 6 additions & 0 deletions api-docs/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,9 @@ definitions:
api_key:
description: OpenAI settings
type: string
api_url:
description: Custom transcription API base URL (OpenAI adapter only)
type: string
attention_context_left:
description: NVIDIA Parakeet-specific parameters for long-form audio
type: integer
Expand Down Expand Up @@ -747,6 +750,9 @@ definitions:
type: number
threads:
type: integer
timeout_minutes:
description: HTTP request timeout in minutes (OpenAI adapter with custom base URL)
type: integer
vad_method:
description: VAD (Voice Activity Detection) settings
type: string
Expand Down
4 changes: 3 additions & 1 deletion internal/models/transcription.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ type WhisperXParams struct {
CallbackURL *string `json:"callback_url,omitempty" gorm:"type:text"`

// OpenAI settings
APIKey *string `json:"api_key,omitempty" gorm:"type:text"`
APIKey *string `json:"api_key,omitempty" gorm:"type:text"`
APIURL *string `json:"api_url,omitempty" gorm:"type:text"`
TimeoutMinutes *int `json:"timeout_minutes,omitempty" gorm:"type:int"`

// Voxtral settings
MaxNewTokens *int `json:"max_new_tokens,omitempty" gorm:"type:int"`
Expand Down
12 changes: 11 additions & 1 deletion internal/transcription/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ err := adapter.ValidateParameters(params)
| `whisperx` | `whisper` | 90+ languages | Timestamps, Diarization, Translation |
| `parakeet` | `nvidia_parakeet` | English only | Timestamps, Long-form, High Quality |
| `canary` | `nvidia_canary` | 12 languages | Timestamps, Translation, Multilingual |
| `openai_whisper` | `openai` | 57 languages | Timestamps, Diarization, Translation, Custom Endpoint |

### Diarization Models

Expand Down Expand Up @@ -221,9 +222,18 @@ params := map[string]interface{}{
// NVIDIA Canary with translation
params := map[string]interface{}{
"source_lang": "es",
"target_lang": "en",
"target_lang": "en",
"task": "translate",
}

// OpenAI with custom self-hosted endpoint
params := map[string]interface{}{
"base_url": "http://localhost:8000/v1",
"model": "Systran/faster-whisper-large-v3",
"timeout_minutes": 30,
"diarize": true,
"diarize_model": "pyannote",
}
```

## Testing
Expand Down
50 changes: 38 additions & 12 deletions internal/transcription/adapters/openai_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter {
Features: map[string]bool{
"timestamps": true, // Verbose JSON response includes segments
"word_level": false, // Not supported by standard API yet (unless using verbose_json with timestamp_granularities which is beta)
"diarization": false, // Not supported by OpenAI API
"diarization": true, // Post-processing via pyannote/sortformer pipeline
"translation": true,
"language_detection": true,
"vad": true, // Implicit
Expand All @@ -59,13 +59,19 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter {
Description: "OpenAI API Key (overrides system default)",
Group: "authentication",
},
{
Name: "base_url",
Type: "string",
Required: false,
Description: "Custom transcription API base URL (overrides server default)",
Group: "authentication",
},
{
Name: "model",
Type: "string",
Required: false,
Default: "whisper-1",
Options: []string{"whisper-1"},
Description: "ID of the model to use",
Description: "Model name (e.g. whisper-1, or any model exposed by a custom endpoint)",
Group: "basic",
},
{
Expand All @@ -92,6 +98,15 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter {
Description: "Sampling temperature",
Group: "quality",
},
{
Name: "timeout_minutes",
Type: "int",
Required: false,
Default: nil,
Min: &[]float64{1}[0],
Description: "HTTP request timeout in minutes (increase for large files on self-hosted endpoints)",
Group: "advanced",
},
}

baseAdapter := NewBaseAdapter("openai_whisper", "", capabilities, schema)
Expand Down Expand Up @@ -153,7 +168,14 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
apiKey = key
}

if apiKey == "" {
const officialURL = "https://api.openai.com/v1/audio/transcriptions"
endpointURL := officialURL
if url := a.GetStringParameter(params, "base_url"); url != "" {
endpointURL = strings.TrimRight(url, "/") + "/audio/transcriptions"
}
isOfficialEndpoint := endpointURL == officialURL

if apiKey == "" && isOfficialEndpoint {
writeLog("Error: OpenAI API key is required but not provided")
return nil, fmt.Errorf("OpenAI API key is required but not provided")
}
Expand Down Expand Up @@ -188,7 +210,7 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
writeLog("Model: %s", model)
_ = writer.WriteField("model", model)

if strings.HasPrefix(model, "gpt-4o") {
if isOfficialEndpoint && strings.HasPrefix(model, "gpt-4o") {
if strings.Contains(model, "diarize") {
_ = writer.WriteField("response_format", "diarized_json")
} else {
Expand All @@ -197,7 +219,6 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
// gpt-4o models don't support timestamp_granularities with these formats
} else {
_ = writer.WriteField("response_format", "verbose_json")
// timestamp_granularities is only supported for whisper-1
if model == "whisper-1" {
_ = writer.WriteField("timestamp_granularities[]", "word") // Request word timestamps
_ = writer.WriteField("timestamp_granularities[]", "segment") // Request segment timestamps
Expand All @@ -224,8 +245,8 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
}

// Create request
writeLog("Sending request to OpenAI API...")
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/audio/transcriptions", body)
writeLog("Sending request to %s...", endpointURL)
req, err := http.NewRequestWithContext(ctx, "POST", endpointURL, body)
if err != nil {
writeLog("Error: Failed to create request: %v", err)
return nil, fmt.Errorf("failed to create request: %w", err)
Expand All @@ -235,9 +256,14 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
req.Header.Set("Authorization", "Bearer "+apiKey)

// Execute request
client := &http.Client{
Timeout: 10 * time.Minute, // Generous timeout for large files
timeout := 10 * time.Minute
if !isOfficialEndpoint {
timeout = 30 * time.Minute // Default for self-hosted endpoints
}
if t := a.GetIntParameter(params, "timeout_minutes"); t > 0 {
timeout = time.Duration(t) * time.Minute
}
client := &http.Client{Timeout: timeout}
resp, err := client.Do(req)
if err != nil {
writeLog("Error: Request failed: %v", err)
Expand All @@ -247,8 +273,8 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn

if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
writeLog("Error: OpenAI API error (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("OpenAI API error (status %d): %s", resp.StatusCode, string(respBody))
writeLog("Error: transcription API error (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("transcription API error (status %d): %s", resp.StatusCode, string(respBody))
}

writeLog("Response received. Parsing...")
Expand Down
147 changes: 145 additions & 2 deletions internal/transcription/adapters/sortformer_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"os/exec"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -64,6 +65,16 @@ func NewSortformerAdapter(envPath string) *SortformerAdapter {
Description: "Maximum number of speakers (optimized for 4)",
Group: "basic",
},
{
Name: "min_speakers",
Type: "int",
Required: false,
Default: 1,
Min: &[]float64{1}[0],
Max: &[]float64{8}[0],
Description: "Minimum number of speakers",
Group: "basic",
},
{
Name: "batch_size",
Type: "int",
Expand Down Expand Up @@ -423,10 +434,22 @@ func (s *SortformerAdapter) buildSortformerArgs(input interfaces.AudioInput, par
// parseResult loads the diarization output written to tempDir and applies the
// requested speaker cap before handing it back.
//
// The "output_format" parameter selects the parser: OutputFormatJSON is read
// with parseJSONResult, any other value with parseRTTMResult. A parser error
// is returned as-is with a nil result. On success the parsed result is passed
// through enforceSpeakerLimit, so both output formats honor "max_speakers".
func (s *SortformerAdapter) parseResult(tempDir string, input interfaces.AudioInput, params map[string]interface{}) (*interfaces.DiarizationResult, error) {
	var (
		result *interfaces.DiarizationResult
		err    error
	)

	// Parse exactly once; returning straight from the parser here would skip
	// the speaker-limit post-processing below.
	if s.GetStringParameter(params, "output_format") == OutputFormatJSON {
		result, err = s.parseJSONResult(tempDir)
	} else {
		result, err = s.parseRTTMResult(tempDir, input)
	}
	if err != nil {
		return nil, err
	}

	return s.enforceSpeakerLimit(result, params), nil
}

// parseJSONResult parses JSON format output
Expand Down Expand Up @@ -538,6 +561,126 @@ func (s *SortformerAdapter) parseRTTMResult(tempDir string, input interfaces.Aud
return result, nil
}

// sortformerSpeakerDuration pairs a speaker label with the total segment time
// credited to it, so that speakers can be ranked by how much they spoke.
type sortformerSpeakerDuration struct {
// speaker is the label as it appears on the diarization segments.
speaker string
// duration is the sum of the non-negative (End - Start) lengths of the
// segments carrying that label.
duration float64
}

// enforceSpeakerLimit caps the number of distinct speakers in result at the
// "max_speakers" parameter.
//
// A nil result or a non-positive limit is returned untouched. When the result
// already fits within the limit, only SpeakerCount and Speakers are refreshed
// from the segments. Otherwise the labels with the largest total segment
// duration are kept (ties broken by label, ascending), every segment carrying
// another label is relabeled to the kept speaker closest in time, the summary
// fields are rebuilt, and a warning is logged. The result is modified in
// place and returned.
func (s *SortformerAdapter) enforceSpeakerLimit(result *interfaces.DiarizationResult, params map[string]interface{}) *interfaces.DiarizationResult {
	limit := s.GetIntParameter(params, "max_speakers")
	if limit <= 0 || result == nil {
		return result
	}

	// Total speaking time per label; inverted segments count as zero length.
	totals := make(map[string]float64)
	for idx := range result.Segments {
		seg := result.Segments[idx]
		span := seg.End - seg.Start
		if span < 0 {
			span = 0
		}
		totals[seg.Speaker] += span
	}

	detected := len(totals)
	if detected <= limit {
		result.SpeakerCount = detected
		result.Speakers = sortedSpeakerList(totals)
		return result
	}

	// Rank labels by descending duration, then by label for a stable order.
	ranking := make([]sortformerSpeakerDuration, 0, detected)
	for label, total := range totals {
		ranking = append(ranking, sortformerSpeakerDuration{speaker: label, duration: total})
	}
	sort.Slice(ranking, func(i, j int) bool {
		lhs, rhs := ranking[i], ranking[j]
		if lhs.duration != rhs.duration {
			return lhs.duration > rhs.duration
		}
		return lhs.speaker < rhs.speaker
	})

	// detected > limit at this point, so ranking[:limit] is always in range.
	kept := make(map[string]bool, limit)
	for _, entry := range ranking[:limit] {
		kept[entry.speaker] = true
	}
	fallback := ranking[0].speaker

	// Relabel against a snapshot so earlier remaps do not influence later ones.
	snapshot := make([]interfaces.DiarizationSegment, len(result.Segments))
	copy(snapshot, result.Segments)
	for idx := range result.Segments {
		if !kept[result.Segments[idx].Speaker] {
			result.Segments[idx].Speaker = nearestKeptSpeaker(snapshot, idx, kept, fallback)
		}
	}

	rebuildSpeakerSummary(result)

	logger.Warn("Sortformer returned more speakers than requested; remapped extra speaker labels",
		"requested_max_speakers", limit,
		"original_speakers", detected,
		"final_speakers", result.SpeakerCount)

	return result
}

// nearestKeptSpeaker picks the replacement label for segments[targetIndex].
//
// It scans every other segment whose speaker is marked in keptSpeakers and
// returns the speaker of the one with the smallest segmentDistance to the
// target. Equal distances are resolved in favor of the lexically smaller
// label. fallbackSpeaker is returned when no kept segment exists.
func nearestKeptSpeaker(segments []interfaces.DiarizationSegment, targetIndex int, keptSpeakers map[string]bool, fallbackSpeaker string) string {
	var (
		anchor    = segments[targetIndex]
		chosen    = fallbackSpeaker
		shortest  float64
		haveMatch bool
	)

	for idx := range segments {
		if idx == targetIndex {
			continue
		}
		candidate := segments[idx]
		if !keptSpeakers[candidate.Speaker] {
			continue
		}

		gap := segmentDistance(anchor, candidate)
		closer := !haveMatch || gap < shortest
		tieBreak := gap == shortest && candidate.Speaker < chosen
		if closer || tieBreak {
			chosen, shortest, haveMatch = candidate.Speaker, gap, true
		}
	}

	return chosen
}

// segmentDistance returns the size of the time gap between two segments:
// the distance from the end of the earlier one to the start of the later one,
// or 0 when neither segment ends at or before the other's start.
func segmentDistance(a, b interfaces.DiarizationSegment) float64 {
	switch {
	case b.End <= a.Start:
		// b lies entirely before a.
		return a.Start - b.End
	case a.End <= b.Start:
		// a lies entirely before b.
		return b.Start - a.End
	default:
		return 0
	}
}

// rebuildSpeakerSummary recomputes result.Speakers and result.SpeakerCount
// from the speaker labels currently present on result.Segments. Speakers is
// set to the distinct labels in ascending order and SpeakerCount to their
// number.
func rebuildSpeakerSummary(result *interfaces.DiarizationResult) {
	totals := make(map[string]float64)
	for idx := range result.Segments {
		seg := result.Segments[idx]
		span := seg.End - seg.Start
		if span < 0 {
			span = 0
		}
		// Adding even a zero span registers the label in the map.
		totals[seg.Speaker] += span
	}

	labels := sortedSpeakerList(totals)
	result.Speakers = labels
	result.SpeakerCount = len(labels)
}

// sortedSpeakerList returns the keys of speakers in ascending lexical order.
// The map values are ignored. The returned slice is never nil, even for an
// empty or nil map.
func sortedSpeakerList(speakers map[string]float64) []string {
	labels := make([]string, len(speakers))
	next := 0
	for label := range speakers {
		labels[next] = label
		next++
	}
	sort.Strings(labels)
	return labels
}

// GetEstimatedProcessingTime provides Sortformer-specific time estimation
func (s *SortformerAdapter) GetEstimatedProcessingTime(input interfaces.AudioInput) time.Duration {
// Sortformer is typically very fast, often faster than real-time
Expand Down
Loading