diff --git a/README.md b/README.md index 0f1798a..d54c1a7 100644 --- a/README.md +++ b/README.md @@ -343,6 +343,7 @@ Notes: - `pipeline.greeting` plays when a session starts. `pipeline.greeting_outgoing` is used for outbound SIP calls when present. - `pipeline.debug = true` emits timing events over the DataChannel. - `stt.provider = "openai"` uses Whisper-style final transcription instead of streaming partials. +- `test.turn_endpoint = true` enables a dev-only `POST /test-turn` text regression harness. Keep this disabled on public deployments. - `llm.provider = "ollama"` uses a local Ollama instance instead of OpenAI. Make sure Ollama is running and the specified model is pulled (e.g., `ollama pull llama3.2`). - `stt.provider = "vibevoice"` and `tts.provider = "vibevoice"` use local VibeVoice models. Start the Python servers first (see [Local VibeVoice Setup](#local-vibevoice-setup)). - `rag.provider` enables built-in RAG. When set, the server embeds each user utterance and retrieves the top-k most relevant chunks from your vector store before calling the LLM — all in a single LLM pass with no tool-call overhead. diff --git a/config.toml.example b/config.toml.example index ffb38a7..f2cb231 100644 --- a/config.toml.example +++ b/config.toml.example @@ -16,6 +16,9 @@ greeting = "" # Spoken when a session connects greeting_outgoing = "" # Spoken for outbound SIP calls; falls back to greeting debug = false # Emit timing events over the DataChannel +[test] +turn_endpoint = false # Dev-only: enable POST /test-turn text regression harness. Do not expose publicly. + # Provider selection [stt] @@ -73,4 +76,4 @@ embedding_model = "text-embedding-3-small" # optional, this is the default [supabase] url = "https://xxx.supabase.co" api_key = "your-service-role-key" -function = "match_documents" # optional, this is the default \ No newline at end of file +function = "match_documents" # optional, this is the default diff --git a/internal/config/config.go b/internal/config/config.go index cb7fcc8..d8a44bd 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,6 +14,7 @@ type Config struct { Server ServerConfig `toml:"server"` Plugins PluginsConfig `toml:"plugins"` Pipeline PipelineConfig `toml:"pipeline"` + Test TestConfig `toml:"test"` STT STTConfig `toml:"stt"` LLM LLMConfig `toml:"llm"` TTS TTSConfig `toml:"tts"` @@ -39,6 +40,10 @@ type PipelineConfig struct { Debug bool `toml:"debug"` // Emit per-turn timing events over the DataChannel } +type TestConfig struct { + TurnEndpoint bool `toml:"turn_endpoint"` // Enable dev-only POST /test-turn text harness. Disabled by default. +} + type ServerConfig struct { Port string `toml:"port"` PublicIP string `toml:"public_ip"` // Public IP for ICE candidates when behind NAT (e.g., EC2). Leave empty for local/direct connections. diff --git a/internal/testturn/handler.go b/internal/testturn/handler.go new file mode 100644 index 0000000..2e2f6b1 --- /dev/null +++ b/internal/testturn/handler.go @@ -0,0 +1,244 @@ +package testturn + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "strings" + "time" + + "github.com/streamcoreai/server/internal/config" + "github.com/streamcoreai/server/internal/llm" + "github.com/streamcoreai/server/internal/plugin" + "github.com/streamcoreai/server/internal/rag" +) + +const maxRequestBytes = 1 << 20 +const unsupportedVisionTool = "vision.analyze" + +type Message struct { + Role string `json:"role"` + Text string `json:"text"` + At string `json:"at,omitempty"` +} + +type TurnRequest struct { + Text string `json:"text,omitempty"` + CustomerText string `json:"customerText,omitempty"` + Messages []Message `json:"messages,omitempty"` +} + +type Event struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Stage string `json:"stage,omitempty"` + Ms int64 `json:"ms,omitempty"` +} + +type TurnResponse struct { + Spoken string `json:"spoken"` + Events []Event `json:"events,omitempty"` + LatencyMs int64 `json:"latencyMs"` +} + +type turnRunner func(context.Context, TurnRequest) (TurnResponse, error) + +// NewHandler returns the disabled-by-default text turn harness used by local +// regression tools. It bypasses WebRTC, STT, and TTS, but reuses the configured +// LLM, plugins, skills, and optional RAG context. +func NewHandler(cfg *config.Config, pluginMgr *plugin.Manager, ragClient rag.Client) http.HandlerFunc { + return newHTTPHandler(newAgent(cfg, pluginMgr, ragClient).run) +} + +func newHTTPHandler(run turnRunner) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSON(w, http.StatusMethodNotAllowed, map[string]string{"error": "method not allowed"}) + return + } + + var req TurnRequest + decoder := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxRequestBytes)) + if err := decoder.Decode(&req); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + writeJSON(w, http.StatusRequestEntityTooLarge, map[string]string{"error": "request body too large"}) + return + } + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON request"}) + return + } + + if req.latestCustomerText() == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "text or customerText is required"}) + return + } + + resp, err := run(r.Context(), req) + if err != nil { + log.Printf("[test-turn] error: %v", err) + writeJSON(w, http.StatusBadGateway, map[string]string{"error": "test turn failed"}) + return + } + + writeJSON(w, http.StatusOK, resp) + } +} + +type agent struct { + cfg *config.Config + pluginMgr *plugin.Manager + ragClient rag.Client +} + +func newAgent(cfg *config.Config, pluginMgr *plugin.Manager, ragClient rag.Client) *agent { + return &agent{cfg: cfg, pluginMgr: pluginMgr, ragClient: ragClient} +} + +func (a *agent) run(ctx context.Context, req TurnRequest) (TurnResponse, error) { + started := time.Now() + + client, err := llm.NewClient(a.cfg) + if err != nil { + return TurnResponse{}, err + } + a.configureClient(client) + + input := req.prompt() + if a.ragClient != nil { + chunks, err := a.ragClient.Search(ctx, req.latestCustomerText(), 0) + if err != nil { + log.Printf("[test-turn] RAG search error: %v", err) + } else if len(chunks) > 0 { + input = fmt.Sprintf("[Context:\n%s]\n\n%s", strings.Join(chunks, "\n---\n"), input) + } + } + + events := make([]Event, 0, 8) + spoken, err := client.Chat(ctx, input, func(chunk string) { + events = append(events, Event{Type: "response", Text: chunk}) + }, nil) + if err != nil { + return TurnResponse{}, err + } + spoken = strings.TrimSpace(spoken) + if spoken == "" { + return TurnResponse{}, fmt.Errorf("LLM returned an empty response") + } + + latency := time.Since(started).Milliseconds() + events = append(events, Event{Type: "timing", Stage: "llm_complete", Ms: latency}) + return TurnResponse{ + Spoken: spoken, + Events: events, + LatencyMs: latency, + }, nil +} + +func (a *agent) configureClient(client llm.Client) { + if a.pluginMgr == nil { + return + } + + tools := a.pluginMgr.Tools() + if len(tools) > 0 { + defs := make([]llm.ToolDefinition, 0, len(tools)) + for _, tool := range tools { + if tool.Name() == unsupportedVisionTool { + log.Printf("[test-turn] skipping unsupported pipeline-dependent tool: %s", tool.Name()) + continue + } + defs = append(defs, llm.ToolDefinition{ + Name: tool.Name(), + Description: tool.Description(), + Parameters: tool.Parameters(), + }) + } + if len(defs) > 0 { + client.SetTools(defs) + client.SetToolHandler(func(callCtx context.Context, call llm.ToolCall) (string, error) { + tool, ok := a.pluginMgr.GetTool(call.Name) + if !ok { + return "", fmt.Errorf("unknown tool: %s", call.Name) + } + return tool.Execute(call.Arguments) + }) + } + } + + if skillsPrompt := a.pluginMgr.SkillsPrompt(); skillsPrompt != "" { + client.AppendSystemPrompt(skillsPrompt) + } +} + +func (req TurnRequest) latestCustomerText() string { + if text := strings.TrimSpace(req.CustomerText); text != "" { + return text + } + if text := strings.TrimSpace(req.Text); text != "" { + return text + } + for i := len(req.Messages) - 1; i >= 0; i-- { + if strings.EqualFold(req.Messages[i].Role, "customer") || strings.EqualFold(req.Messages[i].Role, "user") { + return strings.TrimSpace(req.Messages[i].Text) + } + } + return "" +} + +func (req TurnRequest) prompt() string { + messages := req.normalizedMessages() + if len(messages) == 0 { + return req.latestCustomerText() + } + + var b strings.Builder + b.WriteString("Conversation transcript:\n") + for _, msg := range messages { + switch strings.ToLower(strings.TrimSpace(msg.Role)) { + case "assistant": + b.WriteString("Assistant: ") + default: + b.WriteString("User: ") + } + b.WriteString(strings.TrimSpace(msg.Text)) + b.WriteString("\n") + } + b.WriteString("\nRespond to the latest user turn. Keep the reply concise and natural for voice.") + return b.String() +} + +func (req TurnRequest) normalizedMessages() []Message { + messages := make([]Message, 0, len(req.Messages)+1) + for _, msg := range req.Messages { + text := strings.TrimSpace(msg.Text) + if text == "" { + continue + } + role := strings.TrimSpace(msg.Role) + if role == "" { + role = "user" + } + messages = append(messages, Message{Role: role, Text: text, At: msg.At}) + } + + latest := req.latestCustomerText() + if latest == "" { + return messages + } + if len(messages) == 0 || strings.TrimSpace(messages[len(messages)-1].Text) != latest { + messages = append(messages, Message{Role: "user", Text: latest}) + } + return messages +} + +func writeJSON(w http.ResponseWriter, status int, body interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(body); err != nil { + log.Printf("[test-turn] write response error: %v", err) + } +} diff --git a/internal/testturn/handler_test.go b/internal/testturn/handler_test.go new file mode 100644 index 0000000..dbe6a0b --- /dev/null +++ b/internal/testturn/handler_test.go @@ -0,0 +1,144 @@ +package testturn + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/streamcoreai/server/internal/llm" + "github.com/streamcoreai/server/internal/plugin" +) + +func TestHandlerReturnsSpokenResponse(t *testing.T) { + handler := newHTTPHandler(func(ctx context.Context, req TurnRequest) (TurnResponse, error) { + if got := req.latestCustomerText(); got != "what does StreamCoreAI do?" { + t.Fatalf("latestCustomerText() = %q", got) + } + if prompt := req.prompt(); !strings.Contains(prompt, "Conversation transcript:") { + t.Fatalf("prompt() missing transcript context: %q", prompt) + } + return TurnResponse{Spoken: "StreamCoreAI runs realtime voice agents.", LatencyMs: 12}, nil + }) + + req := httptest.NewRequest(http.MethodPost, "/test-turn", strings.NewReader(`{ + "suiteName": "Voice Agent TestOps", + "customerText": "what does StreamCoreAI do?", + "merchant": {"name": "StreamCoreAI demo"}, + "messages": [ + {"role": "customer", "text": "what does StreamCoreAI do?"} + ] + }`)) + rec := httptest.NewRecorder() + + handler(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, body = %s", rec.Code, rec.Body.String()) + } + + var body TurnResponse + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode response: %v", err) + } + if body.Spoken != "StreamCoreAI runs realtime voice agents." { + t.Fatalf("spoken = %q", body.Spoken) + } +} + +func TestHandlerRejectsMissingText(t *testing.T) { + handler := newHTTPHandler(func(ctx context.Context, req TurnRequest) (TurnResponse, error) { + t.Fatal("runner should not be called") + return TurnResponse{}, nil + }) + + req := httptest.NewRequest(http.MethodPost, "/test-turn", strings.NewReader(`{"messages":[]}`)) + rec := httptest.NewRecorder() + + handler(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, body = %s", rec.Code, rec.Body.String()) + } +} + +func TestHandlerMapsRunnerErrors(t *testing.T) { + handler := newHTTPHandler(func(ctx context.Context, req TurnRequest) (TurnResponse, error) { + return TurnResponse{}, errors.New("provider unavailable") + }) + + req := httptest.NewRequest(http.MethodPost, "/test-turn", strings.NewReader(`{"text":"hello"}`)) + rec := httptest.NewRecorder() + + handler(rec, req) + + if rec.Code != http.StatusBadGateway { + t.Fatalf("status = %d, body = %s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "test turn failed") { + t.Fatalf("body missing generic runner error: %s", rec.Body.String()) + } + if strings.Contains(rec.Body.String(), "provider unavailable") { + t.Fatalf("body leaked runner error: %s", rec.Body.String()) + } +} + +func TestHandlerRejectsTooLargeBody(t *testing.T) { + handler := newHTTPHandler(func(ctx context.Context, req TurnRequest) (TurnResponse, error) { + t.Fatal("runner should not be called") + return TurnResponse{}, nil + }) + + body := `{"text":"` + strings.Repeat("a", maxRequestBytes) + `"}` + req := httptest.NewRequest(http.MethodPost, "/test-turn", strings.NewReader(body)) + rec := httptest.NewRecorder() + + handler(rec, req) + + if rec.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("status = %d, body = %s", rec.Code, rec.Body.String()) + } +} + +func TestAgentConfigureClientSkipsVisionTool(t *testing.T) { + pluginMgr := plugin.NewManager("") + pluginMgr.RegisterNative(fakeTool{name: unsupportedVisionTool}) + pluginMgr.RegisterNative(fakeTool{name: "math.calculate"}) + + client := &fakeLLMClient{} + newAgent(nil, pluginMgr, nil).configureClient(client) + + if len(client.tools) != 1 { + t.Fatalf("configured tools = %d, want 1", len(client.tools)) + } + if client.tools[0].Name != "math.calculate" { + t.Fatalf("configured tool = %q", client.tools[0].Name) + } +} + +type fakeTool struct { + name string +} + +func (t fakeTool) Name() string { return t.name } +func (t fakeTool) Description() string { return "test tool" } +func (t fakeTool) Parameters() json.RawMessage { return json.RawMessage(`{"type":"object"}`) } +func (t fakeTool) Execute(json.RawMessage) (string, error) { return "ok", nil } +func (t fakeTool) ConfirmationRequired() bool { return false } +func (t fakeTool) ThinkingSound() bool { return false } + +type fakeLLMClient struct { + tools []llm.ToolDefinition +} + +func (c *fakeLLMClient) Chat(context.Context, string, func(string), func(string)) (string, error) { + return "", nil +} +func (c *fakeLLMClient) SetTools(tools []llm.ToolDefinition) { c.tools = tools } +func (c *fakeLLMClient) SetToolHandler(func(context.Context, llm.ToolCall) (string, error)) { +} +func (c *fakeLLMClient) AppendSystemPrompt(string) {} +func (c *fakeLLMClient) Reset() {} diff --git a/main.go b/main.go index 8cf7a27..d2aafb0 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ import ( "github.com/streamcoreai/server/internal/rag" "github.com/streamcoreai/server/internal/session" "github.com/streamcoreai/server/internal/signaling" + "github.com/streamcoreai/server/internal/testturn" turnserver "github.com/streamcoreai/server/internal/turn" ) @@ -69,6 +70,16 @@ func main() { w.Write([]byte("ok")) }) + if cfg.Test.TurnEndpoint { + log.Println("Dev test-turn endpoint enabled at POST /test-turn") + testTurnHandler := testturn.NewHandler(cfg, pluginMgr, ragClient) + if cfg.Server.JWTSecret != "" { + log.Println("JWT authentication enabled for /test-turn") + testTurnHandler = jwtMiddleware(cfg.Server.JWTSecret, testTurnHandler) + } + mux.HandleFunc("/test-turn", testTurnHandler) + } + if cfg.Server.JWTSecret != "" { mux.HandleFunc("/token", tokenHandler(cfg.Server.JWTSecret, cfg.Server.APIKey)) }