Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 20 additions & 98 deletions stationapi/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use csv::{ReaderBuilder, StringRecord};
use sqlx::{Connection, PgConnection};
use stationapi::config::fetch_database_url;
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::io::{Cursor, Read as _};
use std::path::Path;
use std::{env, fs};
Expand Down Expand Up @@ -105,50 +105,35 @@ pub async fn import_csv() -> Result<(), Box<dyn std::error::Error>> {
}
};

let table_columns = fetch_table_columns(&mut conn, table_name).await?;
let (column_names, column_indexes) = build_insert_columns(&headers, &table_columns);

let skipped_columns: Vec<&str> = headers
.iter()
.filter(|header| !header.starts_with('#') && !table_columns.contains(header.as_str()))
.map(String::as_str)
.collect();
if !skipped_columns.is_empty() {
warn!(
"Skipping unknown CSV columns for table {}: {}",
table_name,
skipped_columns.join(", ")
);
}

// Skip empty CSV files to avoid generating invalid INSERT statements
if csv_data.is_empty() {
tracing::warn!("Skipping empty CSV file: {}", file_name);
continue;
}

if column_names.is_empty() {
tracing::warn!(
"Skipping CSV file {} because no importable columns matched table {}",
file_name,
table_name
);
continue;
}

let mut sql_lines_inner = Vec::new();
sql_lines_inner.push(format!(
"INSERT INTO public.{table_name} ({}) VALUES ",
column_names.join(",")
));
sql_lines_inner.push(format!("INSERT INTO public.{table_name} VALUES "));

for (idx, data) in csv_data.iter().enumerate() {
let cols: Vec<_> = column_indexes
let cols: Vec<_> = data
.iter()
.map(|col_idx| match data.get(*col_idx).unwrap_or("") {
"" => "NULL".to_string(),
"DEFAULT" => "DEFAULT".to_string(),
col => format!("'{}'", escape_sql_string(col)),
.enumerate()
.filter_map(|(col_idx, col)| {
if headers
.get(col_idx)
.unwrap_or(&String::new())
.starts_with('#')
{
return None;
}

if col.is_empty() {
Some("NULL".to_string())
} else if col == "DEFAULT" {
Some("DEFAULT".to_string())
} else {
Some(format!("'{}'", escape_sql_string(col)))
}
})
Comment thread
TinyKitten marked this conversation as resolved.
.collect();

Expand All @@ -173,69 +158,6 @@ pub async fn import_csv() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

/// Fetches the set of column names that exist on `table_name` in the
/// `public` schema by querying `information_schema.columns`.
///
/// The importer uses this set to decide which CSV headers correspond to
/// real table columns before building an `INSERT` statement.
///
/// # Errors
/// Returns any database error produced while executing the query.
async fn fetch_table_columns(
conn: &mut PgConnection,
table_name: &str,
) -> Result<HashSet<String>, Box<dyn std::error::Error>> {
// Parameterized query: `table_name` is bound as $1, never interpolated,
// so no SQL-injection risk from the caller-supplied table name.
let rows = sqlx::query_scalar::<_, String>(
r#"
SELECT column_name
FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = $1
"#,
)
.bind(table_name)
.fetch_all(conn)
.await?;

// Collect into a HashSet for O(1) membership checks against CSV headers.
Ok(rows.into_iter().collect())
}

/// Determines which CSV `headers` are importable into the target table.
///
/// A header is importable when it is not a comment column (i.e. it does not
/// start with `#`) and it names a column present in `table_columns`.
///
/// Returns `(column_names, column_indexes)`, where `column_indexes[i]` is
/// the CSV field position of `column_names[i]`, in header order.
fn build_insert_columns(
    headers: &[String],
    table_columns: &HashSet<String>,
) -> (Vec<String>, Vec<usize>) {
    headers
        .iter()
        .enumerate()
        // Keep only non-comment headers that exist in the table schema.
        .filter(|(_, header)| !header.starts_with('#') && table_columns.contains(header.as_str()))
        .map(|(position, header)| (header.clone(), position))
        .unzip()
}

#[cfg(test)]
mod import_tests {
    use super::build_insert_columns;
    use std::collections::HashSet;

    /// Headers prefixed with `#` (comment columns) and headers that do not
    /// exist in the table schema must both be excluded from the generated
    /// INSERT column list and index list.
    #[test]
    fn build_insert_columns_skips_comment_and_unknown_columns() {
        let headers: Vec<String> = ["id", "station_cd", "#メモ", "unknown_col", "type_cd"]
            .into_iter()
            .map(str::to_string)
            .collect();
        let table_columns: HashSet<String> = HashSet::from([
            "id".to_string(),
            "station_cd".to_string(),
            "type_cd".to_string(),
        ]);

        let (column_names, column_indexes) = build_insert_columns(&headers, &table_columns);

        assert_eq!(column_names, vec!["id", "station_cd", "type_cd"]);
        assert_eq!(column_indexes, vec![0, 1, 4]);
    }
}

/// Represents a translation entry from translations.txt
#[derive(Debug, Clone, Default)]
struct Translation {
Expand Down
Loading