diff --git a/src/commands/agent.rs b/src/commands/agent.rs
new file mode 100644
index 000000000..367b26717
--- /dev/null
+++ b/src/commands/agent.rs
@@ -0,0 +1,375 @@
+use std::io::Write;
+
+use is_terminal::IsTerminal;
+use rand::Rng;
+use serde::Serialize;
+
+// Train-themed placeholder messages shown while waiting for the first SSE event.
+const THINKING_MESSAGES: &[&str] = &[
+    "Chugging along...",
+    "Full steam ahead...",
+    "Leaving the station...",
+    "Building up steam...",
+    "Coupling the cars...",
+    "Switching tracks...",
+    "Rolling down the line...",
+    "Stoking the engine...",
+    "Pulling into the yard...",
+    "All aboard...",
+];
+
+use crate::{
+    controllers::{
+        chat::{ChatEvent, ChatRequest, build_chat_client, get_chat_url, stream_chat},
+        environment::get_matched_environment,
+        project::get_project,
+        service::get_or_prompt_service,
+    },
+    interact_or,
+    util::progress::{create_spinner, fail_spinner, success_spinner},
+};
+
+use super::*;
+
+/// Interact with the Railway Agent
+#[derive(Parser)]
+#[clap(
+    about = "Interact with the Railway Agent",
+    after_help = "Examples:\n\n\
+    railway agent                                            # Interactive mode\n\
+    railway agent -p \"what's the status of my deployment?\"   # Single prompt\n\
+    railway agent -p \"why is my service crashing?\" --json    # JSON output"
+)]
+pub struct Args {
+    /// Send a single prompt (omit for interactive mode)
+    #[clap(short, long, value_name = "MESSAGE")]
+    prompt: Option<String>,
+
+    /// Output in JSON format
+    #[clap(long)]
+    json: bool,
+
+    /// Continue an existing chat thread
+    #[clap(long, value_name = "ID")]
+    thread_id: Option<String>,
+
+    /// Service to scope the chat to (name or ID)
+    #[clap(short, long)]
+    service: Option<String>,
+
+    /// Environment to use (defaults to linked environment)
+    #[clap(short, long)]
+    environment: Option<String>,
+}
+
+/// Accumulated result of one chat exchange, emitted with `--json`.
+#[derive(Default, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct JsonResponse {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    thread_id: Option<String>,
+    response: String,
+    tool_calls: Vec<JsonToolCall>,
+}
+
+/// One tool invocation surfaced in `--json` output; `result`/`is_error` are
+/// filled in when the matching `tool_execution_complete` event arrives.
+#[derive(Serialize)]
+#[serde(rename_all = "camelCase")]
+struct JsonToolCall {
+    tool_name: String,
+    args: serde_json::Value,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    result: Option<serde_json::Value>,
+    is_error: bool,
+}
+
+pub async fn command(args: Args) -> Result<()> {
+    let configs = Configs::new()?;
+    let linked_project = configs.get_linked_project().await?;
+    let project_id = linked_project.project.clone();
+
+    let client = GQLClient::new_authorized(&configs)?;
+    let project = get_project(&client, &configs, project_id.clone()).await?;
+
+    // Explicit --environment wins; otherwise fall back to the linked environment.
+    let environment_id = match args.environment.clone() {
+        Some(env) => get_matched_environment(&project, env)?.id,
+        None => linked_project.environment_id()?.to_string(),
+    };
+
+    let service_id = get_or_prompt_service(Some(linked_project), project, args.service).await?;
+
+    let chat_client = build_chat_client(&configs)?;
+    let url = get_chat_url(&configs);
+    let is_tty = std::io::stdout().is_terminal();
+
+    if let Some(message) = args.prompt {
+        run_single_shot(
+            &chat_client,
+            &url,
+            &ChatRequest {
+                project_id,
+                environment_id,
+                message,
+                thread_id: args.thread_id,
+                service_id,
+            },
+            args.json,
+            is_tty,
+        )
+        .await
+    } else {
+        run_repl(
+            &chat_client,
+            &url,
+            &project_id,
+            &environment_id,
+            service_id.as_deref(),
+            args.thread_id,
+            args.json,
+            is_tty,
+        )
+        .await
+    }
+}
+
+/// Send one prompt, stream the reply, and exit.
+async fn run_single_shot(
+    client: &reqwest::Client,
+    url: &str,
+    request: &ChatRequest,
+    json: bool,
+    is_tty: bool,
+) -> Result<()> {
+    if json {
+        let mut response = JsonResponse::default();
+
+        stream_chat(client, url, request, |event| {
+            accumulate_json_event(event, &mut response);
+        })
+        .await?;
+
+        println!("{}", serde_json::to_string_pretty(&response).unwrap());
+        Ok(())
+    } else {
+        let mut spinner: Option<indicatif::ProgressBar> = None;
+        let mut has_printed_text = false;
+
+        // Show a thinking spinner while waiting for the first event
+        if is_tty {
+            let msg = THINKING_MESSAGES[rand::thread_rng().gen_range(0..THINKING_MESSAGES.len())];
+            println!();
+            spinner = Some(create_spinner(msg.dimmed().to_string()));
+        }
+
+        stream_chat(client, url, request, |event| {
+            handle_event_human(event, &mut spinner, &mut has_printed_text, is_tty);
+        })
+        .await
+    }
+}
+
+/// Interactive loop: prompt, stream reply, repeat. Carries `thread_id` across
+/// turns so the conversation stays in one thread.
+async fn run_repl(
+    client: &reqwest::Client,
+    url: &str,
+    project_id: &str,
+    environment_id: &str,
+    service_id: Option<&str>,
+    initial_thread_id: Option<String>,
+    json: bool,
+    is_tty: bool,
+) -> Result<()> {
+    interact_or!(
+        "Interactive mode requires a terminal. Use `railway agent -p \"your message\"` for non-interactive use."
+    );
+
+    println!(
+        "{}",
+        "Railway Agent (type 'exit' or Ctrl+C to quit)".dimmed()
+    );
+    println!();
+
+    let mut thread_id = initial_thread_id;
+
+    loop {
+        let input = inquire::Text::new("You:")
+            .with_render_config(Configs::get_render_config())
+            .prompt();
+
+        let message = match input {
+            Ok(msg)
+                if msg.trim().eq_ignore_ascii_case("exit")
+                    || msg.trim().eq_ignore_ascii_case("quit") =>
+            {
+                break;
+            }
+            Ok(msg) if msg.trim().is_empty() => continue,
+            Ok(msg) => msg,
+            Err(inquire::InquireError::OperationInterrupted) => break,
+            Err(e) => return Err(e.into()),
+        };
+
+        let request = ChatRequest {
+            project_id: project_id.to_string(),
+            environment_id: environment_id.to_string(),
+            message,
+            thread_id: thread_id.clone(),
+            service_id: service_id.map(|s| s.to_string()),
+        };
+
+        if json {
+            let mut response = JsonResponse::default();
+
+            stream_chat(client, url, &request, |event| {
+                // Capture the server-assigned thread ID so the next turn continues it.
+                if let ChatEvent::Metadata {
+                    thread_id: ref tid, ..
+                } = event
+                {
+                    thread_id = Some(tid.clone());
+                }
+                accumulate_json_event(event, &mut response);
+            })
+            .await?;
+
+            println!("{}", serde_json::to_string_pretty(&response).unwrap());
+        } else {
+            let mut spinner: Option<indicatif::ProgressBar> = None;
+            let mut has_printed_text = false;
+
+            if is_tty {
+                let msg =
+                    THINKING_MESSAGES[rand::thread_rng().gen_range(0..THINKING_MESSAGES.len())];
+                println!();
+                spinner = Some(create_spinner(msg.dimmed().to_string()));
+            }
+
+            stream_chat(client, url, &request, |event| {
+                if let ChatEvent::Metadata {
+                    thread_id: ref tid, ..
+                } = event
+                {
+                    thread_id = Some(tid.clone());
+                }
+                handle_event_human(event, &mut spinner, &mut has_printed_text, is_tty);
+            })
+            .await?;
+        }
+
+        println!();
+    }
+
+    Ok(())
+}
+
+/// Render one SSE event for a human terminal: stream text chunks inline,
+/// show a spinner per tool call, and print errors/aborts to stderr.
+fn handle_event_human(
+    event: ChatEvent,
+    spinner: &mut Option<indicatif::ProgressBar>,
+    has_printed_text: &mut bool,
+    is_tty: bool,
+) {
+    match event {
+        ChatEvent::Chunk { text } => {
+            if let Some(s) = spinner.take() {
+                s.finish_and_clear();
+            }
+            if !*has_printed_text {
+                println!();
+                print!("{} ", "Railway Agent:".purple().bold());
+                *has_printed_text = true;
+            }
+            print!("{}", text);
+            // Chunks arrive without trailing newlines; flush so they appear immediately.
+            let _ = std::io::stdout().flush();
+        }
+        ChatEvent::ToolCallReady { tool_name, .. } => {
+            if is_tty {
+                if let Some(s) = spinner.take() {
+                    s.finish_and_clear();
+                }
+                *has_printed_text = false;
+                println!();
+                *spinner = Some(create_spinner(format!(
+                    "{} {}",
+                    "╰─".dimmed(),
+                    format!(" Agent Tool: {tool_name} ")
+                        .truecolor(255, 255, 255)
+                        .on_truecolor(68, 68, 68)
+                )));
+            }
+        }
+        ChatEvent::ToolExecutionComplete { is_error, .. } => {
+            if let Some(s) = spinner {
+                if is_error {
+                    fail_spinner(
+                        s,
+                        format!(
+                            "{}",
+                            " Tool failed "
+                                .truecolor(255, 255, 255)
+                                .on_truecolor(68, 68, 68)
+                        ),
+                    );
+                } else {
+                    success_spinner(
+                        s,
+                        format!(
+                            "{}",
+                            " Done ".truecolor(255, 255, 255).on_truecolor(68, 68, 68)
+                        ),
+                    );
+                }
+            }
+            *spinner = None;
+        }
+        ChatEvent::Error { message } => {
+            if let Some(s) = spinner.take() {
+                s.finish_and_clear();
+            }
+            eprintln!("{}: {}", "Error".red().bold(), message);
+        }
+        ChatEvent::Aborted { reason } => {
+            if let Some(s) = spinner.take() {
+                s.finish_and_clear();
+            }
+            let msg = reason.unwrap_or_else(|| "Request was aborted".to_string());
+            eprintln!("{}: {}", "Aborted".yellow().bold(), msg);
+        }
+        ChatEvent::WorkflowCompleted { .. } => {
+            println!();
+        }
+        ChatEvent::Metadata { .. } => {
+            // Thread ID captured by caller; no output
+        }
+    }
+}
+
+/// Fold one SSE event into the `--json` accumulator.
+fn accumulate_json_event(event: ChatEvent, response: &mut JsonResponse) {
+    match event {
+        ChatEvent::Metadata { thread_id, .. } => {
+            response.thread_id = Some(thread_id);
+        }
+        ChatEvent::Chunk { text } => {
+            response.response.push_str(&text);
+        }
+        ChatEvent::ToolCallReady {
+            tool_name, args, ..
+        } => {
+            response.tool_calls.push(JsonToolCall {
+                tool_name,
+                args,
+                result: None,
+                is_error: false,
+            });
+        }
+        ChatEvent::ToolExecutionComplete {
+            result, is_error, ..
+        } => {
+            // Events arrive in order, so the completion belongs to the most
+            // recently announced tool call.
+            if let Some(last) = response.tool_calls.last_mut() {
+                last.result = Some(result);
+                last.is_error = is_error;
+            }
+        }
+        ChatEvent::Error { message } => {
+            response.response.push_str(&format!("\nError: {message}"));
+        }
+        ChatEvent::Aborted { reason } => {
+            let msg = reason.unwrap_or_else(|| "Request was aborted".to_string());
+            response.response.push_str(&format!("\nAborted: {msg}"));
+        }
+        ChatEvent::WorkflowCompleted { .. } => {}
+    }
+}
diff --git a/src/commands/mod.rs b/src/commands/mod.rs
index 957aaba5c..af21b6db6 100644
--- a/src/commands/mod.rs
+++ b/src/commands/mod.rs
@@ -48,4 +48,5 @@
 pub mod variable;
 pub mod volume;
 pub mod whoami;
+pub mod agent;
 pub mod check_updates;
diff --git a/src/controllers/chat.rs b/src/controllers/chat.rs
new file mode 100644
index 000000000..6c221a918
--- /dev/null
+++ b/src/controllers/chat.rs
@@ -0,0 +1,200 @@
+use std::time::Duration;
+
+use anyhow::{Result, bail};
+use reqwest::{
+    Client,
+    header::{HeaderMap, HeaderValue},
+};
+use serde::{Deserialize, Serialize};
+
+use crate::{
+    client::auth_failure_error, commands::Environment, config::Configs, consts,
+    errors::RailwayError,
+};
+
+/// Request body for the chat endpoint.
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ChatRequest {
+    pub project_id: String,
+    pub environment_id: String,
+    pub message: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub thread_id: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub service_id: Option<String>,
+}
+
+/// SSE events surfaced to the CLI; tagged by the `type` field of the payload.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ChatEvent {
+    Metadata {
+        #[serde(rename = "threadId")]
+        thread_id: String,
+        #[serde(rename = "streamId")]
+        stream_id: String,
+    },
+    Chunk {
+        text: String,
+    },
+    ToolCallReady {
+        #[serde(rename = "toolCallId")]
+        tool_call_id: String,
+        #[serde(rename = "toolName")]
+        tool_name: String,
+        args: serde_json::Value,
+    },
+    ToolExecutionComplete {
+        #[serde(rename = "toolCallId")]
+        tool_call_id: String,
+        result: serde_json::Value,
+        #[serde(rename = "isError")]
+        is_error: bool,
+    },
+    Error {
+        message: String,
+    },
+    Aborted {
+        #[serde(default)]
+        reason: Option<String>,
+    },
+    WorkflowCompleted {
+        #[serde(rename = "completedAt")]
+        completed_at: String,
+    },
+}
+
+pub fn get_chat_url(configs: &Configs) -> String {
+    format!("https://backboard.{}/api/v1/chat", configs.get_host())
+}
+
+/// Build an HTTP client for the chat API.
+///
+/// The chat endpoint requires user OAuth tokens — project access tokens
+/// (`RAILWAY_TOKEN`) are not supported. We skip project tokens and only
+/// use the user's OAuth bearer token.
+pub fn build_chat_client(configs: &Configs) -> Result<Client> {
+    let mut headers = HeaderMap::new();
+    if let Some(token) = configs.get_railway_auth_token() {
+        headers.insert(
+            "authorization",
+            HeaderValue::from_str(&format!("Bearer {token}"))?,
+        );
+    } else {
+        return Err(RailwayError::Unauthorized.into());
+    }
+    headers.insert(
+        "x-source",
+        HeaderValue::from_static(consts::get_user_agent()),
+    );
+    let client = Client::builder()
+        .danger_accept_invalid_certs(matches!(Configs::get_environment_id(), Environment::Dev))
+        .user_agent(consts::get_user_agent())
+        .default_headers(headers)
+        .connect_timeout(Duration::from_secs(30))
+        // No overall timeout — SSE streams are long-lived
+        .build()?;
+    Ok(client)
+}
+
+/// POST the request and parse the SSE response incrementally, invoking
+/// `on_event` for each recognized event until the stream ends.
+pub async fn stream_chat(
+    client: &Client,
+    url: &str,
+    request: &ChatRequest,
+    mut on_event: impl FnMut(ChatEvent),
+) -> Result<()> {
+    let mut response = client
+        .post(url)
+        .header("Accept", "text/event-stream")
+        .header("Content-Type", "application/json")
+        .json(request)
+        .send()
+        .await?;
+
+    let status = response.status();
+    if !status.is_success() {
+        match status.as_u16() {
+            401 | 403 => return Err(auth_failure_error().into()),
+            429 => return Err(RailwayError::Ratelimited.into()),
+            _ => {
+                let body = response.text().await.unwrap_or_default();
+                bail!("Chat request failed ({}): {}", status, body);
+            }
+        }
+    }
+
+    // Minimal SSE parser: accumulate bytes, split on newlines, and dispatch
+    // an event whenever a blank line terminates an `event:`/`data:` pair.
+    let mut buffer = String::new();
+    let mut current_event_type = String::new();
+    let mut current_data = String::new();
+
+    while let Some(chunk) = response.chunk().await? {
+        buffer.push_str(&String::from_utf8_lossy(&chunk));
+
+        while let Some(line_end) = buffer.find('\n') {
+            let line = buffer[..line_end].trim_end_matches('\r').to_string();
+            buffer = buffer[line_end + 1..].to_string();
+
+            if line.is_empty() {
+                // Empty line signals end of SSE event
+                if !current_data.is_empty() {
+                    if let Some(event) = parse_sse_event(&current_event_type, &current_data) {
+                        on_event(event);
+                    }
+                    current_event_type.clear();
+                    current_data.clear();
+                }
+            } else if let Some(value) = line.strip_prefix("event: ") {
+                current_event_type = value.to_string();
+            } else if let Some(value) = line.strip_prefix("data: ") {
+                // Multi-line data fields are joined with newlines per the SSE spec.
+                if !current_data.is_empty() {
+                    current_data.push('\n');
+                }
+                current_data.push_str(value);
+            }
+            // Ignore comments (lines starting with :) and unknown fields
+        }
+    }
+
+    Ok(())
+}
+
+/// Map a raw SSE event (type + JSON data) to a [`ChatEvent`], or `None` for
+/// event types the CLI does not surface or payloads that fail to parse.
+fn parse_sse_event(event_type: &str, data: &str) -> Option<ChatEvent> {
+    match event_type {
+        "metadata" => {
+            serde_json::from_str(data)
+                .ok()
+                .map(|v: serde_json::Value| ChatEvent::Metadata {
+                    thread_id: v["threadId"].as_str().unwrap_or_default().to_string(),
+                    stream_id: v["streamId"].as_str().unwrap_or_default().to_string(),
+                })
+        }
+        "chunk" => serde_json::from_str(data)
+            .ok()
+            .map(|v: serde_json::Value| ChatEvent::Chunk {
+                text: v["text"].as_str().unwrap_or_default().to_string(),
+            }),
+        "tool_call_ready" => serde_json::from_str(data).ok(),
+        "tool_execution_complete" => serde_json::from_str(data).ok(),
+        "error" => serde_json::from_str(data)
+            .ok()
+            .map(|v: serde_json::Value| ChatEvent::Error {
+                message: v["error"]
+                    .as_str()
+                    .or_else(|| v["message"].as_str())
+                    .unwrap_or("Unknown error")
+                    .to_string(),
+            }),
+        "aborted" => {
+            serde_json::from_str(data)
+                .ok()
+                .map(|v: serde_json::Value| ChatEvent::Aborted {
+                    reason: v["reason"].as_str().map(|s| s.to_string()),
+                })
+        }
+        "workflow_completed" => serde_json::from_str(data).ok(),
+        // Ignore events we don't need to surface: started, tool_call_streaming_start,
+        // tool_call_delta, tool_execution_start, tool_output_delta, step_finish,
+        // completed, subagent_start, subagent_complete
+        _ => None,
+    }
+}
diff --git a/src/controllers/mod.rs b/src/controllers/mod.rs
index 0320dbe1a..9d939cf8c 100644
--- a/src/controllers/mod.rs
+++ b/src/controllers/mod.rs
@@ -1,3 +1,4 @@
+pub mod chat;
 pub mod config;
 pub mod database;
 pub mod deployment;
diff --git a/src/main.rs b/src/main.rs
index 2076e8bf1..62e769f04 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -29,6 +29,7 @@ mod telemetry;
 // Specify the modules you want to include in the commands_enum! macro
 commands!(
     add,
+    agent,
     autoupdate,
     bucket,
     completion,