diff --git a/stationapi/src/import.rs b/stationapi/src/import.rs
index 4c3ff026..1b991878 100644
--- a/stationapi/src/import.rs
+++ b/stationapi/src/import.rs
@@ -3,7 +3,7 @@
 use csv::{ReaderBuilder, StringRecord};
 use sqlx::{Connection, PgConnection};
 use stationapi::config::fetch_database_url;
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 use std::io::{Cursor, Read as _};
 use std::path::Path;
 use std::{env, fs};
@@ -105,50 +105,35 @@ pub async fn import_csv() -> Result<(), Box<dyn std::error::Error>> {
             }
         };
 
-        let table_columns = fetch_table_columns(&mut conn, table_name).await?;
-        let (column_names, column_indexes) = build_insert_columns(&headers, &table_columns);
-
-        let skipped_columns: Vec<&str> = headers
-            .iter()
-            .filter(|header| !header.starts_with('#') && !table_columns.contains(header.as_str()))
-            .map(String::as_str)
-            .collect();
-        if !skipped_columns.is_empty() {
-            warn!(
-                "Skipping unknown CSV columns for table {}: {}",
-                table_name,
-                skipped_columns.join(", ")
-            );
-        }
-
         // Skip empty CSV files to avoid generating invalid INSERT statements
         if csv_data.is_empty() {
            tracing::warn!("Skipping empty CSV file: {}", file_name);
            continue;
         }
 
-        if column_names.is_empty() {
-            tracing::warn!(
-                "Skipping CSV file {} because no importable columns matched table {}",
-                file_name,
-                table_name
-            );
-            continue;
-        }
-
         let mut sql_lines_inner = Vec::new();
-        sql_lines_inner.push(format!(
-            "INSERT INTO public.{table_name} ({}) VALUES ",
-            column_names.join(",")
-        ));
+        sql_lines_inner.push(format!("INSERT INTO public.{table_name} VALUES "));
 
         for (idx, data) in csv_data.iter().enumerate() {
-            let cols: Vec<_> = column_indexes
+            let cols: Vec<_> = data
                 .iter()
-                .map(|col_idx| match data.get(*col_idx).unwrap_or("") {
-                    "" => "NULL".to_string(),
-                    "DEFAULT" => "DEFAULT".to_string(),
-                    col => format!("'{}'", escape_sql_string(col)),
+                .enumerate()
+                .filter_map(|(col_idx, col)| {
+                    if headers
+                        .get(col_idx)
+                        .unwrap_or(&String::new())
+                        .starts_with('#')
+                    {
+                        return None;
+                    }
+
+                    if col.is_empty() {
+                        Some("NULL".to_string())
+                    } else if col == "DEFAULT" {
+                        Some("DEFAULT".to_string())
+                    } else {
+                        Some(format!("'{}'", escape_sql_string(col)))
+                    }
                 })
                 .collect();
 
@@ -173,69 +158,6 @@ pub async fn import_csv() -> Result<(), Box<dyn std::error::Error>> {
     Ok(())
 }
 
-async fn fetch_table_columns(
-    conn: &mut PgConnection,
-    table_name: &str,
-) -> Result<HashSet<String>, Box<dyn std::error::Error>> {
-    let rows = sqlx::query_scalar::<_, String>(
-        r#"
-        SELECT column_name
-        FROM information_schema.columns
-        WHERE table_schema = 'public' AND table_name = $1
-        "#,
-    )
-    .bind(table_name)
-    .fetch_all(conn)
-    .await?;
-
-    Ok(rows.into_iter().collect())
-}
-
-fn build_insert_columns(
-    headers: &[String],
-    table_columns: &HashSet<String>,
-) -> (Vec<String>, Vec<usize>) {
-    let mut column_names = Vec::new();
-    let mut column_indexes = Vec::new();
-
-    for (idx, header) in headers.iter().enumerate() {
-        if header.starts_with('#') || !table_columns.contains(header.as_str()) {
-            continue;
-        }
-
-        column_names.push(header.clone());
-        column_indexes.push(idx);
-    }
-
-    (column_names, column_indexes)
-}
-
-#[cfg(test)]
-mod import_tests {
-    use super::build_insert_columns;
-    use std::collections::HashSet;
-
-    #[test]
-    fn build_insert_columns_skips_comment_and_unknown_columns() {
-        let headers = vec![
-            "id".to_string(),
-            "station_cd".to_string(),
-            "#メモ".to_string(),
-            "unknown_col".to_string(),
-            "type_cd".to_string(),
-        ];
-        let table_columns: HashSet<String> = ["id", "station_cd", "type_cd"]
-            .into_iter()
-            .map(str::to_string)
-            .collect();
-
-        let (column_names, column_indexes) = build_insert_columns(&headers, &table_columns);
-
-        assert_eq!(column_names, vec!["id", "station_cd", "type_cd"]);
-        assert_eq!(column_indexes, vec![0, 1, 4]);
-    }
-}
-
 /// Represents a translation entry from translations.txt
 #[derive(Debug, Clone, Default)]
 struct Translation {