From 995b16eef8dc8d151e7dfeeddf869dd7f4008c5a Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Wed, 11 Mar 2026 07:34:06 +0100 Subject: [PATCH] feat: add ai chat panel --- src-tauri/Cargo.lock | 138 +++++- src-tauri/Cargo.toml | 1 + src-tauri/src/ai.rs | 158 +++++++ src-tauri/src/dbs/mod.rs | 1 + src-tauri/src/dbs/project.rs | 15 +- src-tauri/src/dbs/query.rs | 15 +- src-tauri/src/dbs/settings.rs | 91 ++++ src-tauri/src/dbs/workspace.rs | 15 +- src-tauri/src/drivers/pgsql.rs | 68 +-- src-tauri/src/main.rs | 45 +- src/App.tsx | 16 + src/components/ai-chat-panel.tsx | 610 ++++++++++++++++++++++++++ src/components/ai-settings-dialog.tsx | 201 +++++++++ src/components/command-palette.tsx | 7 +- src/components/connection-modal.tsx | 69 ++- src/components/tab-bar.tsx | 12 +- src/lib/ai-service.ts | 94 ++++ src/lib/schema-context.ts | 107 +++++ src/lib/sql-classify.ts | 28 ++ src/stores/settings-store.ts | 139 ++++++ src/stores/ui-store.ts | 18 + src/tauri.ts | 33 ++ 22 files changed, 1793 insertions(+), 88 deletions(-) create mode 100644 src-tauri/src/ai.rs create mode 100644 src-tauri/src/dbs/settings.rs create mode 100644 src/components/ai-chat-panel.tsx create mode 100644 src/components/ai-settings-dialog.tsx create mode 100644 src/lib/ai-service.ts create mode 100644 src/lib/schema-context.ts create mode 100644 src/lib/sql-classify.ts create mode 100644 src/stores/settings-store.ts diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index cfd8924..9e93d94 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -469,7 +469,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "rustc-hash", + "rustc-hash 1.1.0", "shlex", "syn 2.0.106", "which", @@ -2050,9 +2050,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasi 0.14.5+wasi-0.2.4", + "wasm-bindgen", ] [[package]] @@ -2552,6 +2554,7 @@ dependencies = [ "tokio", "tokio-rustls 0.26.4", "tower-service", + "webpki-roots 1.0.6", ] [[package]] @@ -3298,6 +3301,12 @@ version = "0.4.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "mac" version = "0.1.1" @@ -4705,6 +4714,61 @@ dependencies = [ "memchr", ] +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases 0.2.1", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash 2.1.1", + "rustls 0.23.37", + "socket2 0.5.10", + "thiserror 2.0.16", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" +dependencies = [ + "bytes", + "getrandom 0.3.3", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash 2.1.1", + "rustls 0.23.37", + "rustls-pki-types", + "slab", + "thiserror 2.0.16", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases 0.2.1", + "libc", + "once_cell", + "socket2 0.5.10", + "tracing", + "windows-sys 0.52.0", +] + [[package]] name = "quote" version = "1.0.40" @@ -4946,6 +5010,44 @@ version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cadadef317c2f20755a64d7fdc48f9e7178ee6b0e1f7fce33fa60f1d68a276e6" +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-core", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.7.0", + "hyper-rustls 0.27.7", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls 0.23.37", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.2", + "tokio", + "tokio-rustls 0.26.4", + "tower 0.5.2", + "tower-http 0.6.8", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots 1.0.6", +] + [[package]] name = "reqwest" version = "0.13.2" @@ -5093,6 +5195,7 @@ dependencies = [ "portable-pty", "postgres-native-tls", "rayon", + "reqwest 0.12.28", "russh", "serde", "serde_json", @@ -5212,6 +5315,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustc_version" version = "0.4.1" @@ -5315,6 +5424,7 @@ version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ + "web-time", "zeroize", ] @@ -5645,6 +5755,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_with" version = "3.14.0" @@ -6306,7 +6428,7 @@ dependencies = [ "percent-encoding", "plist", "raw-window-handle", - "reqwest", + "reqwest 0.13.2", "serde", "serde_json", "serde_repr", @@ -6516,7 +6638,7 @@ dependencies = [ "minisign-verify", "osakit", "percent-encoding", - "reqwest", + "reqwest 0.13.2", "rustls 0.23.37", "semver", "serde", @@ -7573,6 +7695,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webkit2gtk" version = "2.0.2" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 455032e..99fc1e8 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -43,3 +43,4 @@ sysinfo = "0.38.3" csv = "1.3" futures-util = "0.3" russh = "0.57" +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } diff --git a/src-tauri/src/ai.rs b/src-tauri/src/ai.rs new file mode 100644 index 0000000..36797ee --- /dev/null +++ b/src-tauri/src/ai.rs @@ -0,0 +1,158 @@ +use serde::Serialize; +use tauri::Result; + +use crate::common::enums::AppError; + +#[derive(Serialize, Clone)] +pub struct AIModelInfo { + pub id: String, + pub label: String, +} + +#[tauri::command(rename_all = "snake_case")] +pub async fn ai_fetch_claude_models(api_key: &str) -> Result> { + let client = reqwest::Client::new(); + let resp = client + .get("https://api.anthropic.com/v1/models?limit=100") + .header("x-api-key", api_key) + .header("anthropic-version", "2023-06-01") + .send() + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(AppError::DatabaseError(format!( + "Claude API error ({}): {}", + status, body + )) + .into()); + } + + let data: serde_json::Value = resp + .json() + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut models = Vec::new(); + if let Some(arr) = data.get("data").and_then(|d| d.as_array()) { + for item in arr { + let id = item + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + if id.is_empty() { + continue; + } + let label = item + .get("display_name") + .and_then(|v| v.as_str()) + .unwrap_or(&id) + .to_string(); + models.push(AIModelInfo { id, label }); + } + } + + models.sort_by(|a, b| a.label.cmp(&b.label)); + Ok(models) +} + +#[tauri::command(rename_all = "snake_case")] +pub async fn ai_fetch_openai_models(api_key: &str) -> Result> { + let client = reqwest::Client::new(); + let resp = client + .get("https://api.openai.com/v1/models") + .header("Authorization", format!("Bearer {}", api_key)) + .send() + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(AppError::DatabaseError(format!( + "OpenAI API error ({}): {}", + status, body + )) + .into()); + } + + let data: serde_json::Value = resp + .json() + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + // Exclude non-chat and legacy model families + let exclude_prefixes = [ + "dall-e", + "tts-", + "whisper", + "text-embedding", + "text-moderation", + "davinci", + "babbage", + "curie", + "ada", + "codex-", + "omni-moderation", + "gpt-3.5", + "gpt-4-", + "ft:", + ]; + + // Skip dated snapshot variants like gpt-4o-2024-08-06 or gpt-4o-20240806 + let has_date_suffix = |id: &str| -> bool { + let parts: Vec<&str> = id.split('-').collect(); + if parts.len() >= 2 { + let last = parts[parts.len() - 1]; + if last.len() == 8 && last.chars().all(|c| c.is_ascii_digit()) { + return true; + } + if parts.len() >= 4 { + let y = parts[parts.len() - 3]; + let m = parts[parts.len() - 2]; + let d = parts[parts.len() - 1]; + if y.len() == 4 + && m.len() == 2 + && d.len() == 2 + && y.chars().all(|c| c.is_ascii_digit()) + && m.chars().all(|c| c.is_ascii_digit()) + && d.chars().all(|c| c.is_ascii_digit()) + { + return true; + } + } + } + false + }; + + let mut models = Vec::new(); + if let Some(arr) = data.get("data").and_then(|d| d.as_array()) { + for item in arr { + let id = item + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + if id.is_empty() { + continue; + } + let lower = id.to_lowercase(); + if exclude_prefixes.iter().any(|p| lower.starts_with(p)) { + continue; + } + if has_date_suffix(&id) { + continue; + } + models.push(AIModelInfo { + label: id.clone(), + id, + }); + } + } + + models.sort_by(|a, b| a.label.cmp(&b.label)); + Ok(models) +} diff --git a/src-tauri/src/dbs/mod.rs b/src-tauri/src/dbs/mod.rs index 8f815bb..24096ad 100644 --- a/src-tauri/src/dbs/mod.rs +++ b/src-tauri/src/dbs/mod.rs @@ -1,3 +1,4 @@ pub mod project; pub mod query; +pub mod settings; pub mod workspace; diff --git a/src-tauri/src/dbs/project.rs b/src-tauri/src/dbs/project.rs index bd480f1..de78ab9 100644 --- a/src-tauri/src/dbs/project.rs +++ b/src-tauri/src/dbs/project.rs @@ -6,10 +6,7 @@ use tauri::{Result, State}; #[tauri::command(rename_all = "snake_case")] pub async fn project_db_select(app_state: State<'_, AppState>) -> Result { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; let mut rows = conn .query( @@ -83,10 +80,7 @@ pub async fn project_db_insert( project_details: Vec, app_state: State<'_, AppState>, ) -> Result<()> { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; let driver = project_details.first().cloned().unwrap_or_default(); let username = project_details.get(1).cloned().unwrap_or_default(); @@ -118,10 +112,7 @@ pub async fn project_db_insert( #[tauri::command(rename_all = "snake_case")] pub async fn project_db_delete(project_id: &str, app_state: State<'_, AppState>) -> Result<()> { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; conn.execute( "DELETE FROM projects WHERE id = ?1", diff --git a/src-tauri/src/dbs/query.rs b/src-tauri/src/dbs/query.rs index 8e0261e..45cccad 100644 --- a/src-tauri/src/dbs/query.rs +++ b/src-tauri/src/dbs/query.rs @@ -5,10 +5,7 @@ use tauri::{AppHandle, Manager, Result, State}; #[tauri::command(rename_all = "snake_case")] pub async fn query_db_select(app_state: State<'_, AppState>) -> Result> { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; let mut rows = conn .query("SELECT id, sql FROM queries ORDER BY id", ()) @@ -35,10 +32,7 @@ pub async fn query_db_select(app_state: State<'_, AppState>) -> Result Result<()> { let app_state = app.state::(); - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; conn.execute( "INSERT OR REPLACE INTO queries (id, sql) VALUES (?1, ?2)", @@ -52,10 +46,7 @@ pub async fn query_db_insert(query_id: &str, sql: &str, app: AppHandle) -> Resul #[tauri::command(rename_all = "snake_case")] pub async fn query_db_delete(query_id: &str, app_state: State<'_, AppState>) -> Result<()> { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; conn.execute( "DELETE FROM queries WHERE id = ?1", diff --git a/src-tauri/src/dbs/settings.rs b/src-tauri/src/dbs/settings.rs new file mode 100644 index 0000000..05770bb --- /dev/null +++ b/src-tauri/src/dbs/settings.rs @@ -0,0 +1,91 @@ +use std::collections::HashMap; + +use crate::common::enums::AppError; +use crate::AppState; +use tauri::{Result, State}; + +#[tauri::command(rename_all = "snake_case")] +pub async fn settings_get_all( + app_state: State<'_, AppState>, +) -> Result> { + let conn = app_state.local_conn.lock().await; + + let mut rows = conn + .query("SELECT key, value FROM settings", ()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut map = HashMap::new(); + while let Some(row) = rows + .next() + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + { + let key: String = row + .get(0) + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let value: String = row + .get(1) + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + map.insert(key, value); + } + Ok(map) +} + +#[tauri::command(rename_all = "snake_case")] +pub async fn settings_get(key: &str, app_state: State<'_, AppState>) -> Result> { + let conn = app_state.local_conn.lock().await; + + let mut rows = conn + .query( + "SELECT value FROM settings WHERE key = ?1", + libsql::params![key], + ) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if let Some(row) = rows + .next() + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + { + let value: String = row + .get(0) + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(Some(value)) + } else { + Ok(None) + } +} + +#[tauri::command(rename_all = "snake_case")] +pub async fn settings_set( + key: &str, + value: &str, + app_state: State<'_, AppState>, +) -> Result<()> { + let conn = app_state.local_conn.lock().await; + + conn.execute( + "INSERT OR REPLACE INTO settings (key, value) VALUES (?1, ?2)", + libsql::params![key, value], + ) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) +} + +#[tauri::command(rename_all = "snake_case")] +pub async fn settings_delete(key: &str, app_state: State<'_, AppState>) -> Result<()> { + let conn = app_state.local_conn.lock().await; + + conn.execute( + "DELETE FROM settings WHERE key = ?1", + libsql::params![key], + ) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) +} diff --git a/src-tauri/src/dbs/workspace.rs b/src-tauri/src/dbs/workspace.rs index a0f0ceb..c10be84 100644 --- a/src-tauri/src/dbs/workspace.rs +++ b/src-tauri/src/dbs/workspace.rs @@ -4,10 +4,7 @@ use tauri::{Result, State}; #[tauri::command(rename_all = "snake_case")] pub async fn workspace_save(name: &str, tabs: &str, app_state: State<'_, AppState>) -> Result<()> { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; conn.execute( "INSERT OR REPLACE INTO workspaces (name, tabs) VALUES (?1, ?2)", @@ -21,10 +18,7 @@ pub async fn workspace_save(name: &str, tabs: &str, app_state: State<'_, AppStat #[tauri::command(rename_all = "snake_case")] pub async fn workspace_load_all(app_state: State<'_, AppState>) -> Result> { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; let mut rows = conn .query("SELECT name, tabs FROM workspaces ORDER BY name", ()) @@ -50,10 +44,7 @@ pub async fn workspace_load_all(app_state: State<'_, AppState>) -> Result) -> Result<()> { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; conn.execute( "DELETE FROM workspaces WHERE name = ?1", diff --git a/src-tauri/src/drivers/pgsql.rs b/src-tauri/src/drivers/pgsql.rs index 6d6c0e1..aa97570 100644 --- a/src-tauri/src/drivers/pgsql.rs +++ b/src-tauri/src/drivers/pgsql.rs @@ -129,10 +129,7 @@ async fn snapshot_upsert_metadata( page_size: usize, col_count: usize, ) -> std::result::Result<(), AppError> { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; conn.execute( "INSERT OR REPLACE INTO virtual_query_snapshots ( @@ -165,10 +162,7 @@ async fn snapshot_store_page( return Ok(()); } - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; for attempt in 0..SNAPSHOT_PAGE_WRITE_RETRIES { match conn @@ -208,10 +202,7 @@ async fn snapshot_load_page( query_id: &str, page_index: usize, ) -> std::result::Result, AppError> { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; let mut rows = conn .query( @@ -242,10 +233,7 @@ async fn snapshot_load_metadata( app_state: &AppState, query_id: &str, ) -> std::result::Result, AppError> { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; let mut rows = conn .query( @@ -296,10 +284,7 @@ async fn snapshot_cleanup_query( app_state: &AppState, query_id: &str, ) -> std::result::Result<(), AppError> { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; conn.execute( "DELETE FROM virtual_query_pages WHERE query_id = ?1", @@ -377,6 +362,39 @@ async fn restore_virtual_from_snapshot( Ok(true) } +#[tauri::command(rename_all = "snake_case")] +pub async fn pgsql_test_connection( + key: [&str; 6], +) -> Result { + let user = key[0]; + let password = key[1]; + let database = key[2]; + let host = key[3]; + let port: u16 = key[4].parse().unwrap_or(5432); + let use_ssl = key[5] == "true"; + + let mut cfg = Config::new(); + cfg.user(user) + .password(password) + .dbname(database) + .host(host) + .port(port); + + let pool = create_pg_pool(&cfg, use_ssl, 1)?; + let client = pool + .get() + .await + .map_err(|e| AppError::ConnectionFailed(full_error_chain(&e)))?; + + let row = client + .query_one("SELECT version()", &[]) + .await + .map_err(|e| AppError::ConnectionFailed(e.to_string()))?; + + let version: String = row.get(0); + Ok(version) +} + #[tauri::command(rename_all = "snake_case")] pub async fn pgsql_connector( project_id: &str, @@ -402,10 +420,7 @@ pub async fn pgsql_connector( key[5] == "true", ), None => { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; let mut rows = conn .query( "SELECT username, password, database, host, port, ssl FROM projects WHERE id = ?1", @@ -1090,10 +1105,7 @@ pub async fn pgsql_listen_start(project_id: &str, channel: &str, app: AppHandle) // Get connection config from local db let (cfg, use_ssl) = { - let conn = app_state - .local_db - .connect() - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let conn = app_state.local_conn.lock().await; let mut rows = conn .query( "SELECT username, password, database, host, port, ssl FROM projects WHERE id = ?1", diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 1dbbe5b..59e42a1 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -1,6 +1,7 @@ // Prevents additional console window on Windows in release, DO NOT REMOVE!! #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] +mod ai; mod common; mod dbs; mod drivers; @@ -23,7 +24,7 @@ pub struct AppState { pub meta_clients: Arc>>>, pub cancel_tokens: Arc>>, pub client_ssl: Arc>>, - pub local_db: libsql::Database, + pub local_conn: Arc>, pub resource_monitor: Arc>, pub virtual_cache: Arc>, pub notify_handles: Arc>>>, @@ -57,24 +58,25 @@ fn main() { let app_handle = app.handle().clone(); tauri::async_runtime::block_on(async move { - let db_path = if cfg!(debug_assertions) { - LOCAL_DB_NAME.to_string() - } else { - let app_dir = app_handle - .path() - .app_data_dir() - .expect("Failed to resolve app data directory"); - std::fs::create_dir_all(&app_dir).ok(); - app_dir.join(LOCAL_DB_NAME).to_string_lossy().to_string() - }; + let app_dir = app_handle + .path() + .app_data_dir() + .expect("Failed to resolve app data directory"); + std::fs::create_dir_all(&app_dir).ok(); + let db_path = app_dir.join(LOCAL_DB_NAME).to_string_lossy().to_string(); let db = libsql::Builder::new_local(&db_path) .build() .await .expect("Failed to open local database"); - // Create tables + // Enable WAL mode for concurrent access let conn = db.connect().expect("Failed to create connection"); + conn.execute("PRAGMA journal_mode=WAL", ()) + .await + .ok(); + + // Create tables conn.execute( "CREATE TABLE IF NOT EXISTS projects ( id TEXT PRIMARY KEY, @@ -111,6 +113,16 @@ fn main() { .await .expect("Failed to create workspaces table"); + conn.execute( + "CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL DEFAULT '' + )", + (), + ) + .await + .expect("Failed to create settings table"); + conn.execute( "CREATE TABLE IF NOT EXISTS virtual_query_snapshots ( query_id TEXT PRIMARY KEY, @@ -173,7 +185,7 @@ fn main() { meta_clients: Arc::new(Mutex::new(BTreeMap::new())), cancel_tokens: Arc::new(Mutex::new(BTreeMap::new())), client_ssl: Arc::new(Mutex::new(BTreeMap::new())), - local_db: db, + local_conn: Arc::new(Mutex::new(conn)), resource_monitor: Arc::new(Mutex::new(utils::ResourceMonitor::new())), virtual_cache: Arc::new(Mutex::new(BTreeMap::new())), notify_handles: Arc::new(Mutex::new(BTreeMap::new())), @@ -259,6 +271,11 @@ fn main() { dbs::workspace::workspace_save, dbs::workspace::workspace_load_all, dbs::workspace::workspace_delete, + dbs::settings::settings_get_all, + dbs::settings::settings_get, + dbs::settings::settings_set, + dbs::settings::settings_delete, + drivers::pgsql::pgsql_test_connection, drivers::pgsql::pgsql_connector, drivers::pgsql::pgsql_load_databases, drivers::pgsql::pgsql_load_tablespaces, @@ -316,6 +333,8 @@ fn main() { terminal::terminal_kill, utils::compute_diff, utils::system_resource_usage, + ai::ai_fetch_claude_models, + ai::ai_fetch_openai_models, ]) .run(tauri::generate_context!()) .expect("error while running tauri application"); diff --git a/src/App.tsx b/src/App.tsx index 52e1087..356512e 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -14,6 +14,7 @@ import { SchemaDiffPanel } from "@/components/schema-diff-panel"; import { ExtensionsPanel } from "@/components/extensions-panel"; import { EnumsPanel } from "@/components/enums-panel"; import { PgSettingsPanel } from "@/components/pg-settings-panel"; +import { AIChatPanel } from "@/components/ai-chat-panel"; import { TabBar } from "@/components/tab-bar"; import { TopBar } from "@/components/top-bar"; import { EditorToolbar } from "@/components/editor-toolbar"; @@ -65,6 +66,9 @@ export default function App() { const setConnectionModalOpen = useUIStore((s) => s.setConnectionModalOpen); const setSidebarWidth = useUIStore((s) => s.setSidebarWidth); const setEditorHeight = useUIStore((s) => s.setEditorHeight); + const aiPanelOpen = useUIStore((s) => s.aiPanelOpen); + const aiPanelWidth = useUIStore((s) => s.aiPanelWidth); + const setAIPanelWidth = useUIStore((s) => s.setAIPanelWidth); const loadProjects = useProjectStore((s) => s.loadProjects); const projects = useProjectStore((s) => s.projects); @@ -547,6 +551,18 @@ export default function App() { )} + + {aiPanelOpen && ( + <> + setAIPanelWidth(-d)} /> +
+ +
+ + )} diff --git a/src/components/ai-chat-panel.tsx b/src/components/ai-chat-panel.tsx new file mode 100644 index 0000000..f2b3923 --- /dev/null +++ b/src/components/ai-chat-panel.tsx @@ -0,0 +1,610 @@ +import { useState, useRef, useEffect, useCallback } from "react"; +import { + Settings, + Send, + Copy, + Play, + Square, + Sparkles, + X, + Database, + FileCode, + Trash2, +} from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { toast } from "sonner"; +import { useSettingsStore, getModelsForProvider } from "@/stores/settings-store"; +import { useProjectStore } from "@/stores/project-store"; +import { useTabStore } from "@/stores/tab-store"; +import { DriverFactory } from "@/lib/database-driver"; +import { aiChat, type ChatMessage } from "@/lib/ai-service"; +import { classifySQL } from "@/lib/sql-classify"; +import { buildSchemaContext } from "@/lib/schema-context"; +import { useUIStore } from "@/stores/ui-store"; +import { AISettingsDialog } from "@/components/ai-settings-dialog"; +import { ProjectConnectionStatus as PCS } from "@/types"; + +interface DisplayMessage { + role: "user" | "assistant"; + content: string; + sqlBlocks: SQLBlock[]; +} + +interface SQLBlock { + sql: string; + type: "select" | "write"; +} + +function extractSQLBlocks(text: string): SQLBlock[] { + const blocks: SQLBlock[] = []; + const regex = /```sql\s*\n([\s\S]*?)```/gi; + let match; + while ((match = regex.exec(text)) !== null) { + const sql = match[1].trim(); + if (sql) { + blocks.push({ sql, type: classifySQL(sql) }); + } + } + return blocks; +} + +/** Parse markdown-ish content into segments for rendering */ +function parseContent(content: string) { + const segments: Array< + | { type: "text"; value: string } + | { type: "sql"; value: string } + | { type: "code"; value: string; lang?: string } + | { type: "inline-code"; value: string } + > = []; + + // Match fenced code blocks (sql and other) + const codeBlockRegex = /```(\w*)\s*\n([\s\S]*?)```/g; + let lastIndex = 0; + let match; + + while ((match = codeBlockRegex.exec(content)) !== null) { + if (match.index > lastIndex) { + segments.push({ type: "text", value: content.slice(lastIndex, match.index) }); + } + const lang = match[1].toLowerCase(); + const code = match[2].trim(); + if (lang === "sql") { + segments.push({ type: "sql", value: code }); + } else { + segments.push({ type: "code", value: code, lang: lang || undefined }); + } + lastIndex = match.index + match[0].length; + } + + if (lastIndex < content.length) { + segments.push({ type: "text", value: content.slice(lastIndex) }); + } + + return segments; +} + +/** Render a text segment with inline formatting (bold, inline code, etc.) */ +function renderTextSegment(text: string, key: number) { + const parts: React.ReactNode[] = []; + // Process inline code, bold, and line breaks + const inlineRegex = /(`[^`]+`|\*\*[^*]+\*\*|\n)/g; + let lastIdx = 0; + let m; + let subKey = 0; + + while ((m = inlineRegex.exec(text)) !== null) { + if (m.index > lastIdx) { + parts.push({text.slice(lastIdx, m.index)}); + } + const token = m[0]; + if (token === "\n") { + parts.push(
); + } else if (token.startsWith("`") && token.endsWith("`")) { + parts.push( + + {token.slice(1, -1)} + , + ); + } else if (token.startsWith("**") && token.endsWith("**")) { + parts.push( + + {token.slice(2, -2)} + , + ); + } + lastIdx = m.index + token.length; + } + + if (lastIdx < text.length) { + parts.push({text.slice(lastIdx)}); + } + + return {parts}; +} + +function SQLCodeBlock({ + sql, + type, + onCopy, + onRun, +}: { + sql: string; + type: "select" | "write"; + onCopy: () => void; + onRun: () => void; +}) { + return ( +
+ {/* SQL header bar */} +
+ + sql + +
+ + {type === "select" ? ( + + ) : ( + + )} +
+
+ {/* SQL code */} +
+        
+          {sql}
+        
+      
+
+ ); +} + +export function AIChatPanel() { + const [messages, setMessages] = useState([]); + const [input, setInput] = useState(""); + const [loading, setLoading] = useState(false); + const [settingsOpen, setSettingsOpen] = useState(false); + const [selectedProjectId, setSelectedProjectId] = useState(); + const abortRef = useRef(null); + const messagesEndRef = useRef(null); + const inputRef = useRef(null); + + const aiProvider = useSettingsStore((s) => s.aiProvider); + const aiModel = useSettingsStore((s) => s.aiModel); + const claudeApiKey = useSettingsStore((s) => s.claudeApiKey); + const openaiApiKey = useSettingsStore((s) => s.openaiApiKey); + const loaded = useSettingsStore((s) => s.loaded); + + const projects = useProjectStore((s) => s.projects); + const connectionStatus = useProjectStore((s) => s.status); + const connectProject = useProjectStore((s) => s.connect); + + useEffect(() => { + if (!loaded) void useSettingsStore.getState().load(); + }, [loaded]); + + useEffect(() => { + messagesEndRef.current?.scrollIntoView({ behavior: "smooth" }); + }, [messages]); + + const apiKey = aiProvider === "claude" ? claudeApiKey : openaiApiKey; + const models = getModelsForProvider(aiProvider, useSettingsStore.getState()); + const modelLabel = models.find((m) => m.id === aiModel)?.label ?? aiModel; + + const connectedProjectId = + selectedProjectId && connectionStatus[selectedProjectId] === PCS.Connected + ? selectedProjectId + : undefined; + + // Auto-load full schema metadata when a connected project is selected + useEffect(() => { + if (!connectedProjectId) return; + + const loadAll = async () => { + const store = useProjectStore.getState(); + let projectSchemas = store.schemas[connectedProjectId] ?? []; + + if (projectSchemas.length === 0) { + await store.loadSchemas(connectedProjectId); + projectSchemas = useProjectStore.getState().schemas[connectedProjectId] ?? []; + } + + for (const schema of projectSchemas) { + const tableKey = `${connectedProjectId}::${schema}`; + if (!useProjectStore.getState().tables[tableKey]) { + await store.loadTables(connectedProjectId, schema); + } + const tables = useProjectStore.getState().tables[tableKey] ?? []; + + // Batch load column details for all tables that haven't been loaded + const toLoad = tables.filter((t) => { + const colKey = `${connectedProjectId}::${schema}::${t.name}`; + return !useProjectStore.getState().columnDetails[colKey]; + }); + if (toLoad.length > 0) { + await Promise.allSettled( + toLoad.map((t) => store.loadColumnDetails(connectedProjectId, schema, t.name)), + ); + } + } + }; + void loadAll(); + }, [connectedProjectId]); + + const handleSend = useCallback(async () => { + const text = input.trim(); + if (!text || loading) return; + + if (!apiKey) { + toast.error("No API key configured", { + description: "Open AI Settings to add your API key", + }); + setSettingsOpen(true); + return; + } + + const userMsg: DisplayMessage = { role: "user", content: text, sqlBlocks: [] }; + setMessages((prev) => [...prev, userMsg]); + setInput(""); + setLoading(true); + + const chatHistory: ChatMessage[] = [ + ...messages.map((m) => ({ role: m.role, content: m.content })), + { role: "user" as const, content: text }, + ]; + + const schemaCtx = connectedProjectId ? buildSchemaContext(connectedProjectId) : ""; + const dbInfo = connectedProjectId + ? (() => { + const d = projects[connectedProjectId]; + return d ? `Connected to: ${d.database} (${d.driver}) on ${d.host}:${d.port}` : ""; + })() + : ""; + + const systemPrompt = [ + `You are an expert PostgreSQL assistant embedded in a database GUI called RSQL. +You have direct access to the database schema provided below. Use it to write precise, correct SQL. + +Rules: +- NEVER ask the user for table names, column names, or types — you already have the full schema. +- ALWAYS use the exact table and column names from the schema. Respect schema prefixes (e.g. public.users). +- When providing SQL, wrap it in \`\`\`sql code blocks. +- Be concise. Give the SQL directly, with brief explanation only when needed. +- For SELECT queries: write them ready to run. +- For write operations (INSERT, UPDATE, DELETE, ALTER, DROP, CREATE): provide the SQL but note it will be opened in editor for review. +- If the user's request is ambiguous, make a reasonable assumption based on the schema and note your assumption briefly. +- Use modern PostgreSQL syntax (CTEs, window functions, JSONB operators, etc.) when appropriate. +- When the user asks to "show", "list", "find", or "get" something, always respond with a SELECT query.`, + dbInfo, + schemaCtx, + ] + .filter(Boolean) + .join("\n\n"); + + const controller = new AbortController(); + abortRef.current = controller; + + try { + const response = await aiChat({ + provider: aiProvider, + model: aiModel, + apiKey, + messages: chatHistory, + systemPrompt, + signal: controller.signal, + }); + + const sqlBlocks = extractSQLBlocks(response); + const assistantMsg: DisplayMessage = { + role: "assistant", + content: response, + sqlBlocks, + }; + setMessages((prev) => [...prev, assistantMsg]); + } catch (err: unknown) { + if (err instanceof Error && err.name === "AbortError") return; + const msg = err instanceof Error ? err.message : String(err); + toast.error("AI request failed", { description: msg }); + const errorMsg: DisplayMessage = { + role: "assistant", + content: `Error: ${msg}`, + sqlBlocks: [], + }; + setMessages((prev) => [...prev, errorMsg]); + } finally { + setLoading(false); + abortRef.current = null; + } + }, [input, loading, apiKey, messages, connectedProjectId, projects, aiProvider, aiModel]); + + const handleStop = useCallback(() => { + abortRef.current?.abort(); + setLoading(false); + }, []); + + const handleCopySQL = useCallback((sql: string) => { + void navigator.clipboard.writeText(sql); + toast.success("SQL copied to clipboard"); + }, []); + + const handleRunSQL = useCallback( + async (sql: string) => { + if (!connectedProjectId) { + toast.error("No active database connection"); + return; + } + + const type = classifySQL(sql); + if (type === "write") { + useTabStore.getState().openTab(connectedProjectId, sql); + toast.info("Query opened in new tab for review"); + return; + } + + const d = projects[connectedProjectId]; + if (!d) return; + + try { + const driver = DriverFactory.getDriver(d.driver); + const [cols, rows, time] = await driver.runQuery(connectedProjectId, sql); + const store = useTabStore.getState(); + store.openTab(connectedProjectId, sql); + const newIdx = store.tabs.length - 1; + store.updateResult(newIdx, { columns: cols, rows, time }); + toast.success(`${rows.length} rows in ${time.toFixed(1)}ms`); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : String(err); + toast.error("Query failed", { description: msg }); + } + }, + [connectedProjectId, projects], + ); + + const handleKeyDown = useCallback( + (e: React.KeyboardEvent) => { + if (e.key === "Enter" && !e.shiftKey) { + e.preventDefault(); + void handleSend(); + } + }, + [handleSend], + ); + + const handleClearChat = useCallback(() => { + setMessages([]); + }, []); + + // Build rendered segments for assistant messages + const renderAssistantContent = (msg: DisplayMessage) => { + const segments = parseContent(msg.content); + let sqlIdx = 0; + + return segments.map((seg, i) => { + if (seg.type === "text") { + return renderTextSegment(seg.value, i); + } + if (seg.type === "sql") { + const block = msg.sqlBlocks[sqlIdx]; + sqlIdx++; + return ( + handleCopySQL(seg.value)} + onRun={() => void handleRunSQL(seg.value)} + /> + ); + } + if (seg.type === "code") { + return ( +
+            
+              {seg.value}
+            
+          
+ ); + } + return null; + }); + }; + + return ( +
+ {/* Header */} +
+
+ + + {modelLabel} + +
+
+ {messages.length > 0 && ( + + )} + + +
+
+ + {/* DB selector */} +
+ + + {connectedProjectId && ( +
+ )} +
+ + {/* Messages */} +
+ {messages.length === 0 && ( +
+
+ +
+
+

SQL Assistant

+

+ Write queries, explain SQL, or explore your schema. + {!connectedProjectId && " Connect a database for schema context."} +

+
+
+ )} + +
+ {messages.map((msg, i) => + msg.role === "user" ? ( + /* User message */ +
+
+

{msg.content}

+
+
+ ) : ( + /* Assistant message */ +
+
+ {renderAssistantContent(msg)} +
+
+ ), + )} + + {loading && ( +
+
+ + + +
+ Thinking... +
+ )} + +
+
+
+ + {/* Input */} +
+
+