From 8306436378856c418c31cd5f3539a2e6bb9211d8 Mon Sep 17 00:00:00 2001 From: Yu Chen Date: Thu, 5 Mar 2026 23:09:21 -0800 Subject: [PATCH 1/6] feat: add SQL dialect support with Spark SQL dialect Add SqlDialect enum and SparkDialect implementation using DataFusion's unparser Dialect trait. The to_sql() method now accepts an optional dialect parameter to generate dialect-specific SQL from Cypher queries. Supported dialects: Default, Spark, PostgreSQL, MySQL, SQLite. Spark SQL differences: backtick quoting, STRING type, EXTRACT for date parts, LENGTH instead of CHARACTER_LENGTH, required subquery aliases. Python API updated: to_sql(datasets, dialect="spark") Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/lance-graph-python/src/graph.rs | 29 +- crates/lance-graph/src/lib.rs | 2 + crates/lance-graph/src/query.rs | 29 +- crates/lance-graph/src/spark_dialect.rs | 218 +++++++++++++ crates/lance-graph/tests/test_to_spark_sql.rs | 293 ++++++++++++++++++ crates/lance-graph/tests/test_to_sql.rs | 24 +- 6 files changed, 571 insertions(+), 24 deletions(-) create mode 100644 crates/lance-graph/src/spark_dialect.rs create mode 100644 crates/lance-graph/tests/test_to_spark_sql.rs diff --git a/crates/lance-graph-python/src/graph.rs b/crates/lance-graph-python/src/graph.rs index 223c6aac..127de651 100644 --- a/crates/lance-graph-python/src/graph.rs +++ b/crates/lance-graph-python/src/graph.rs @@ -494,6 +494,9 @@ impl CypherQuery { /// ---------- /// datasets : dict /// Dictionary mapping table names to Lance datasets + /// dialect : str, optional + /// SQL dialect to use. One of "default", "spark", "postgresql", "mysql", "sqlite". + /// Defaults to "default" (generic DataFusion SQL). 
/// /// Returns /// ------- @@ -504,7 +507,29 @@ impl CypherQuery { /// ------ /// RuntimeError /// If SQL generation fails - fn to_sql(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult { + /// ValueError + /// If an invalid dialect is specified + #[pyo3(signature = (datasets, dialect=None))] + fn to_sql( + &self, + py: Python, + datasets: &Bound<'_, PyDict>, + dialect: Option<&str>, + ) -> PyResult { + let sql_dialect = match dialect { + None | Some("default") => None, + Some("spark") => Some(lance_graph::SqlDialect::Spark), + Some("postgresql") | Some("postgres") => Some(lance_graph::SqlDialect::PostgreSql), + Some("mysql") => Some(lance_graph::SqlDialect::MySql), + Some("sqlite") => Some(lance_graph::SqlDialect::Sqlite), + Some(other) => { + return Err(PyValueError::new_err(format!( + "Unknown SQL dialect: '{}'. Valid options: 'default', 'spark', 'postgresql', 'mysql', 'sqlite'", + other + ))); + } + }; + // Convert datasets to Arrow RecordBatch map let arrow_datasets = python_datasets_to_batches(datasets)?; @@ -513,7 +538,7 @@ impl CypherQuery { // Execute via runtime let sql = RT - .block_on(Some(py), inner_query.to_sql(arrow_datasets))? + .block_on(Some(py), inner_query.to_sql(arrow_datasets, sql_dialect))? 
.map_err(graph_error_to_pyerr)?; Ok(sql) diff --git a/crates/lance-graph/src/lib.rs b/crates/lance-graph/src/lib.rs index 8f93fddb..9aed43ff 100644 --- a/crates/lance-graph/src/lib.rs +++ b/crates/lance-graph/src/lib.rs @@ -47,6 +47,7 @@ pub mod parameter_substitution; pub mod parser; pub mod query; pub mod semantic; +pub mod spark_dialect; pub mod sql_catalog; pub mod sql_query; pub mod table_readers; @@ -68,6 +69,7 @@ pub use lance_graph_catalog::{ pub use lance_graph_catalog::{UnityCatalogConfig, UnityCatalogProvider}; pub use lance_vector_search::VectorSearch; pub use query::{CypherQuery, ExecutionStrategy}; +pub use spark_dialect::{SqlDialect, SparkDialect}; pub use sql_query::SqlQuery; #[cfg(feature = "delta")] pub use table_readers::DeltaTableReader; diff --git a/crates/lance-graph/src/query.rs b/crates/lance-graph/src/query.rs index 9625f273..ffd6abd8 100644 --- a/crates/lance-graph/src/query.rs +++ b/crates/lance-graph/src/query.rs @@ -280,10 +280,10 @@ impl CypherQuery { self.explain_internal(Arc::new(catalog), ctx).await } - /// Convert the Cypher query to a DataFusion SQL string + /// Convert the Cypher query to a SQL string in the specified dialect. /// /// This method generates a SQL string that corresponds to the DataFusion logical plan - /// derived from the Cypher query. It uses the `datafusion-sql` unparser. + /// derived from the Cypher query, using the specified SQL dialect for unparsing. /// /// **WARNING**: This method is experimental and the generated SQL dialect may change. /// @@ -293,16 +293,20 @@ impl CypherQuery { /// /// # Arguments /// * `datasets` - HashMap of table name to RecordBatch (nodes and relationships) + /// * `dialect` - The SQL dialect to use for generating the output SQL. + /// Defaults to `SqlDialect::Default` (generic DataFusion SQL). + /// Use `SqlDialect::Spark` for Spark SQL, `SqlDialect::PostgreSql`, etc. 
/// /// # Returns - /// A SQL string representing the query + /// A SQL string representing the query in the specified dialect pub async fn to_sql( &self, datasets: HashMap, + dialect: Option, ) -> Result { - use datafusion_sql::unparser::plan_to_sql; use std::sync::Arc; + let dialect = dialect.unwrap_or_default(); let _config = self.require_config()?; // Build catalog and context from datasets using the helper @@ -323,11 +327,16 @@ impl CypherQuery { location: snafu::Location::new(file!(), line!(), column!()), })?; - // Unparse to SQL - let sql_ast = plan_to_sql(&optimized_plan).map_err(|e| GraphError::PlanError { - message: format!("Failed to unparse plan to SQL: {}", e), - location: snafu::Location::new(file!(), line!(), column!()), - })?; + // Unparse to SQL using the specified dialect + let dialect_unparser = dialect.unparser(); + let unparser = dialect_unparser.as_unparser(); + let sql_ast = + unparser + .plan_to_sql(&optimized_plan) + .map_err(|e| GraphError::PlanError { + message: format!("Failed to unparse plan to SQL: {}", e), + location: snafu::Location::new(file!(), line!(), column!()), + })?; Ok(sql_ast.to_string()) } @@ -1852,7 +1861,7 @@ mod tests { .unwrap() .with_config(cfg); - let sql = query.to_sql(datasets).await.unwrap(); + let sql = query.to_sql(datasets, None).await.unwrap(); println!("Generated SQL: {}", sql); assert!(sql.contains("SELECT")); diff --git a/crates/lance-graph/src/spark_dialect.rs b/crates/lance-graph/src/spark_dialect.rs new file mode 100644 index 00000000..be7bb60e --- /dev/null +++ b/crates/lance-graph/src/spark_dialect.rs @@ -0,0 +1,218 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! SQL dialect support for the DataFusion unparser. +//! +//! This module provides a [`SqlDialect`] enum for selecting which SQL dialect +//! to use when unparsing DataFusion logical plans to SQL strings, and includes +//! a [`SparkDialect`] implementation for Spark SQL. +//! +//! 
Key Spark SQL differences from standard SQL: +//! - Backtick (`` ` ``) identifier quoting +//! - `EXTRACT(field FROM expr)` for date field extraction +//! - `STRING` type for casting (not `VARCHAR`) +//! - `BIGINT`/`INT` for integer types +//! - `TIMESTAMP` for all timestamp types (no timezone info in cast) +//! - `LENGTH()` instead of `CHARACTER_LENGTH()` +//! - Subqueries in FROM require aliases + +use std::sync::Arc; + +use arrow::datatypes::TimeUnit; +use datafusion_common::Result; +use datafusion_expr::Expr; +use datafusion_sql::unparser::dialect::{ + CharacterLengthStyle, DateFieldExtractStyle, DefaultDialect, Dialect, IntervalStyle, + MySqlDialect, PostgreSqlDialect, SqliteDialect, +}; +use datafusion_sql::unparser::Unparser; +use datafusion_sql::sqlparser::ast::{self, Ident, ObjectName, TimezoneInfo}; + +/// SQL dialect to use when generating SQL from Cypher queries. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum SqlDialect { + /// Generic SQL (DataFusion default dialect) + #[default] + Default, + /// Spark SQL dialect (backtick quoting, STRING type, EXTRACT, etc.) + Spark, + /// PostgreSQL dialect + PostgreSql, + /// MySQL dialect + MySql, + /// SQLite dialect + Sqlite, +} + +impl SqlDialect { + /// Create a DataFusion `Unparser` configured for this dialect. + pub fn unparser(&self) -> DialectUnparser { + match self { + SqlDialect::Default => DialectUnparser::Default(DefaultDialect {}), + SqlDialect::Spark => DialectUnparser::Spark(SparkDialect), + SqlDialect::PostgreSql => DialectUnparser::PostgreSql(PostgreSqlDialect {}), + SqlDialect::MySql => DialectUnparser::MySql(MySqlDialect {}), + SqlDialect::Sqlite => DialectUnparser::Sqlite(SqliteDialect {}), + } + } +} + +/// Wrapper to hold the concrete dialect type and provide an `Unparser` reference. 
+pub enum DialectUnparser { + Default(DefaultDialect), + Spark(SparkDialect), + PostgreSql(PostgreSqlDialect), + MySql(MySqlDialect), + Sqlite(SqliteDialect), +} + +impl DialectUnparser { + pub fn as_unparser(&self) -> Unparser<'_> { + match self { + DialectUnparser::Default(d) => Unparser::new(d), + DialectUnparser::Spark(d) => Unparser::new(d), + DialectUnparser::PostgreSql(d) => Unparser::new(d), + DialectUnparser::MySql(d) => Unparser::new(d), + DialectUnparser::Sqlite(d) => Unparser::new(d), + } + } +} + +/// A Spark SQL dialect for unparsing DataFusion logical plans to Spark-compatible SQL. +pub struct SparkDialect; + +impl Dialect for SparkDialect { + fn identifier_quote_style(&self, _identifier: &str) -> Option { + Some('`') + } + + fn supports_nulls_first_in_sort(&self) -> bool { + true + } + + fn use_timestamp_for_date64(&self) -> bool { + true + } + + fn interval_style(&self) -> IntervalStyle { + IntervalStyle::SQLStandard + } + + fn float64_ast_dtype(&self) -> ast::DataType { + ast::DataType::Double(ast::ExactNumberInfo::None) + } + + fn utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Custom( + ObjectName::from(vec![Ident::new("STRING")]), + vec![], + ) + } + + fn large_utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Custom( + ObjectName::from(vec![Ident::new("STRING")]), + vec![], + ) + } + + fn date_field_extract_style(&self) -> DateFieldExtractStyle { + DateFieldExtractStyle::Extract + } + + fn character_length_style(&self) -> CharacterLengthStyle { + CharacterLengthStyle::Length + } + + fn int64_cast_dtype(&self) -> ast::DataType { + ast::DataType::BigInt(None) + } + + fn int32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Int(None) + } + + fn timestamp_cast_dtype( + &self, + _time_unit: &TimeUnit, + _tz: &Option>, + ) -> ast::DataType { + ast::DataType::Timestamp(None, TimezoneInfo::None) + } + + fn date32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Date + } + + fn supports_column_alias_in_table_alias(&self) 
-> bool { + true + } + + fn requires_derived_table_alias(&self) -> bool { + true + } + + fn full_qualified_col(&self) -> bool { + false + } + + fn unnest_as_table_factor(&self) -> bool { + false + } + + fn scalar_function_to_sql_overrides( + &self, + _unparser: &Unparser, + _func_name: &str, + _args: &[Expr], + ) -> Result> { + // character_length -> length is handled by CharacterLengthStyle::Length + // Additional Spark-specific function mappings can be added here as needed + Ok(None) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_spark_dialect_identifier_quoting() { + let dialect = SparkDialect; + assert_eq!(dialect.identifier_quote_style("table_name"), Some('`')); + assert_eq!(dialect.identifier_quote_style("column"), Some('`')); + } + + #[test] + fn test_spark_dialect_type_mappings() { + let dialect = SparkDialect; + assert!(matches!(dialect.utf8_cast_dtype(), ast::DataType::Custom(..))); + assert!(matches!(dialect.int64_cast_dtype(), ast::DataType::BigInt(None))); + assert!(matches!(dialect.int32_cast_dtype(), ast::DataType::Int(None))); + assert!(matches!(dialect.date32_cast_dtype(), ast::DataType::Date)); + } + + #[test] + fn test_spark_dialect_requires_derived_table_alias() { + let dialect = SparkDialect; + assert!(dialect.requires_derived_table_alias()); + } + + #[test] + fn test_spark_dialect_extract_style() { + let dialect = SparkDialect; + assert!(matches!( + dialect.date_field_extract_style(), + DateFieldExtractStyle::Extract + )); + } + + #[test] + fn test_spark_dialect_character_length_style() { + let dialect = SparkDialect; + assert!(matches!( + dialect.character_length_style(), + CharacterLengthStyle::Length + )); + } +} diff --git a/crates/lance-graph/tests/test_to_spark_sql.rs b/crates/lance-graph/tests/test_to_spark_sql.rs new file mode 100644 index 00000000..ee2cac20 --- /dev/null +++ b/crates/lance-graph/tests/test_to_spark_sql.rs @@ -0,0 +1,293 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 
Copyright The Lance Authors + +//! Integration tests for the to_sql API with SqlDialect::Spark +//! +//! These tests verify that Cypher queries can be correctly converted to Spark SQL strings. + +use arrow::array::{Int32Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use lance_graph::{CypherQuery, GraphConfig, SqlDialect}; +use std::collections::HashMap; +use std::sync::Arc; + +fn create_person_table() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("person_id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int32, false), + Field::new("city", DataType::Utf8, false), + ])); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4])), + Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])), + Arc::new(Int32Array::from(vec![28, 34, 29, 42])), + Arc::new(StringArray::from(vec![ + "New York", + "San Francisco", + "New York", + "Chicago", + ])), + ], + ) + .unwrap() +} + +fn create_company_table() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("company_id", DataType::Int32, false), + Field::new("company_name", DataType::Utf8, false), + Field::new("industry", DataType::Utf8, false), + ])); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![101, 102, 103])), + Arc::new(StringArray::from(vec!["TechCorp", "DataInc", "CloudSoft"])), + Arc::new(StringArray::from(vec!["Technology", "Analytics", "Cloud"])), + ], + ) + .unwrap() +} + +fn create_works_for_table() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("person_id", DataType::Int32, false), + Field::new("company_id", DataType::Int32, false), + Field::new("position", DataType::Utf8, false), + Field::new("salary", DataType::Int32, false), + ])); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4])), + 
Arc::new(Int32Array::from(vec![101, 101, 102, 103])), + Arc::new(StringArray::from(vec![ + "Engineer", "Designer", "Manager", "Director", + ])), + Arc::new(Int32Array::from(vec![120000, 95000, 130000, 180000])), + ], + ) + .unwrap() +} + +#[tokio::test] +async fn test_to_sql_spark_simple_node_scan() { + let config = GraphConfig::builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new("MATCH (p:Person) RETURN p.name") + .unwrap() + .with_config(config); + + let sql = query + .to_sql(datasets, Some(SqlDialect::Spark)) + .await + .unwrap(); + + assert!(sql.contains('`'), "Spark SQL should use backtick quoting"); + assert!( + sql.to_uppercase().contains("SELECT"), + "SQL should contain SELECT" + ); + assert!(sql.contains("name"), "SQL should reference name column"); + + println!("Generated Spark SQL:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_spark_with_filter() { + let config = GraphConfig::builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age") + .unwrap() + .with_config(config); + + let sql = query + .to_sql(datasets, Some(SqlDialect::Spark)) + .await + .unwrap(); + + assert!(sql.contains('`'), "Spark SQL should use backtick quoting"); + assert!(sql.contains("30"), "SQL should contain filter value"); + assert!(sql.contains("age"), "SQL should reference age column"); + + println!("Generated Spark SQL with filter:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_spark_with_relationship() { + let config = GraphConfig::builder() + .with_node_label("Person", "person_id") + .with_node_label("Company", "company_id") + .with_relationship("WORKS_FOR", "person_id", "company_id") + .build() + 
.unwrap(); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), create_person_table()); + datasets.insert("Company".to_string(), create_company_table()); + datasets.insert("WORKS_FOR".to_string(), create_works_for_table()); + + let query = CypherQuery::new( + "MATCH (p:Person)-[:WORKS_FOR]->(c:Company) RETURN p.name, c.company_name", + ) + .unwrap() + .with_config(config); + + let sql = query + .to_sql(datasets, Some(SqlDialect::Spark)) + .await + .unwrap(); + + let sql_upper = sql.to_uppercase(); + assert!(sql.contains('`'), "Spark SQL should use backtick quoting"); + assert!(sql_upper.contains("JOIN"), "SQL should contain JOIN"); + + println!("Generated Spark SQL with relationship:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_spark_complex_query() { + let config = GraphConfig::builder() + .with_node_label("Person", "person_id") + .with_node_label("Company", "company_id") + .with_relationship("WORKS_FOR", "person_id", "company_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), create_person_table()); + datasets.insert("Company".to_string(), create_company_table()); + datasets.insert("WORKS_FOR".to_string(), create_works_for_table()); + + let query = CypherQuery::new( + "MATCH (p:Person)-[w:WORKS_FOR]->(c:Company) \ + WHERE p.age > 30 AND c.industry = 'Technology' \ + RETURN p.name, c.company_name, w.position \ + ORDER BY p.age DESC \ + LIMIT 5", + ) + .unwrap() + .with_config(config); + + let sql = query + .to_sql(datasets, Some(SqlDialect::Spark)) + .await + .unwrap(); + + assert!(sql.contains('`'), "Spark SQL should use backtick quoting"); + assert!( + sql.contains("ORDER BY") || sql.contains("order by"), + "SQL should contain ORDER BY" + ); + assert!( + sql.contains("LIMIT") || sql.contains("limit"), + "SQL should contain LIMIT" + ); + + println!("Generated complex Spark SQL:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_default_dialect_no_backticks() { + let 
config = GraphConfig::builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new("MATCH (p:Person) RETURN p.name, p.age") + .unwrap() + .with_config(config); + + // Default dialect (None) should not use backticks + let sql = query.to_sql(datasets, None).await.unwrap(); + + assert!( + sql.to_uppercase().contains("SELECT"), + "SQL should contain SELECT" + ); + + println!("Generated default SQL:\n{}", sql); +} + +#[tokio::test] +async fn test_spark_sql_differs_from_default_sql() { + let config = GraphConfig::builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let mut datasets1 = HashMap::new(); + datasets1.insert("Person".to_string(), create_person_table()); + + let mut datasets2 = HashMap::new(); + datasets2.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new("MATCH (p:Person) RETURN p.name, p.age") + .unwrap() + .with_config(config); + + let default_sql = query.to_sql(datasets1, None).await.unwrap(); + let spark_sql = query + .to_sql(datasets2, Some(SqlDialect::Spark)) + .await + .unwrap(); + + // Spark SQL should use backtick quoting while default SQL may not + assert!( + spark_sql.contains('`'), + "Spark SQL should use backtick quoting" + ); + + println!("Default SQL:\n{}", default_sql); + println!("\nSpark SQL:\n{}", spark_sql); +} + +#[tokio::test] +async fn test_to_sql_postgresql_dialect() { + let config = GraphConfig::builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new("MATCH (p:Person) RETURN p.name") + .unwrap() + .with_config(config); + + let sql = query + .to_sql(datasets, Some(SqlDialect::PostgreSql)) + .await + .unwrap(); + + // PostgreSQL uses double-quote identifier quoting + assert!( 
+ sql.contains('"'), + "PostgreSQL SQL should use double-quote quoting" + ); + + println!("Generated PostgreSQL SQL:\n{}", sql); +} diff --git a/crates/lance-graph/tests/test_to_sql.rs b/crates/lance-graph/tests/test_to_sql.rs index 1ba8a8d6..9f39d4c7 100644 --- a/crates/lance-graph/tests/test_to_sql.rs +++ b/crates/lance-graph/tests/test_to_sql.rs @@ -97,7 +97,7 @@ async fn test_to_sql_simple_node_scan() { .unwrap() .with_config(config); - let sql = query.to_sql(datasets).await.unwrap(); + let sql = query.to_sql(datasets, None).await.unwrap(); // Verify SQL contains expected elements assert!( @@ -129,7 +129,7 @@ async fn test_to_sql_with_filter() { .unwrap() .with_config(config); - let sql = query.to_sql(datasets).await.unwrap(); + let sql = query.to_sql(datasets, None).await.unwrap(); // Verify SQL contains filter condition assert!(sql.contains("SELECT"), "SQL should contain SELECT"); @@ -157,7 +157,7 @@ async fn test_to_sql_with_multiple_properties() { .unwrap() .with_config(config); - let sql = query.to_sql(datasets).await.unwrap(); + let sql = query.to_sql(datasets, None).await.unwrap(); // Verify all columns are present assert!(sql.contains("name"), "SQL should contain name"); @@ -187,7 +187,7 @@ async fn test_to_sql_with_relationship() { .unwrap() .with_config(config); - let sql = query.to_sql(datasets).await.unwrap(); + let sql = query.to_sql(datasets, None).await.unwrap(); // Verify SQL contains join let sql_upper = sql.to_uppercase(); @@ -223,7 +223,7 @@ async fn test_to_sql_with_relationship_filter() { .unwrap() .with_config(config); - let sql = query.to_sql(datasets).await.unwrap(); + let sql = query.to_sql(datasets, None).await.unwrap(); // Verify SQL contains filter on relationship property assert!(sql.contains("salary"), "SQL should reference salary"); @@ -246,7 +246,7 @@ async fn test_to_sql_with_order_by() { .unwrap() .with_config(config); - let sql = query.to_sql(datasets).await.unwrap(); + let sql = query.to_sql(datasets, None).await.unwrap(); // 
Verify SQL contains ORDER BY assert!( @@ -272,7 +272,7 @@ async fn test_to_sql_with_limit() { .unwrap() .with_config(config); - let sql = query.to_sql(datasets).await.unwrap(); + let sql = query.to_sql(datasets, None).await.unwrap(); // Verify SQL contains LIMIT assert!( @@ -298,7 +298,7 @@ async fn test_to_sql_with_distinct() { .unwrap() .with_config(config); - let sql = query.to_sql(datasets).await.unwrap(); + let sql = query.to_sql(datasets, None).await.unwrap(); // Verify SQL is generated successfully // Note: DISTINCT might be optimized away by DataFusion's optimizer in some cases @@ -323,7 +323,7 @@ async fn test_to_sql_with_alias() { .unwrap() .with_config(config); - let sql = query.to_sql(datasets).await.unwrap(); + let sql = query.to_sql(datasets, None).await.unwrap(); // Verify SQL contains aliases assert!( @@ -358,7 +358,7 @@ async fn test_to_sql_complex_query() { .unwrap() .with_config(config); - let sql = query.to_sql(datasets).await.unwrap(); + let sql = query.to_sql(datasets, None).await.unwrap(); // Verify complex query elements assert!(sql.contains("SELECT"), "SQL should contain SELECT"); @@ -390,7 +390,7 @@ async fn test_to_sql_missing_config() { let query = CypherQuery::new("MATCH (p:Person) RETURN p.name").unwrap(); // Note: No config set - let result = query.to_sql(datasets).await; + let result = query.to_sql(datasets, None).await; // Should fail without config assert!(result.is_err(), "to_sql should fail without config"); @@ -413,7 +413,7 @@ async fn test_to_sql_empty_datasets() { .unwrap() .with_config(config); - let result = query.to_sql(datasets).await; + let result = query.to_sql(datasets, None).await; // Should fail with empty datasets assert!(result.is_err(), "to_sql should fail with empty datasets"); From 1bc5623d421335d68db20aac667dfcffa0f59dd2 Mon Sep 17 00:00:00 2001 From: Yu Chen Date: Mon, 23 Mar 2026 15:41:35 -0700 Subject: [PATCH 2/6] refactor: address PR review comments for SQL dialect support - Use CustomDialectBuilder for 
Spark dialect instead of manual Dialect impl - Move SqlDialect enum from spark_dialect.rs to query.rs as general-purpose type - Expose SqlDialect as a Python enum instead of error-prone string parameter Co-authored-by: Isaac --- crates/lance-graph-python/src/graph.rs | 55 ++++--- crates/lance-graph/src/lib.rs | 3 +- crates/lance-graph/src/query.rs | 18 ++- crates/lance-graph/src/spark_dialect.rs | 185 +++++++----------------- python/python/lance_graph/__init__.py | 2 + 5 files changed, 107 insertions(+), 156 deletions(-) diff --git a/crates/lance-graph-python/src/graph.rs b/crates/lance-graph-python/src/graph.rs index 127de651..03a3c8d8 100644 --- a/crates/lance-graph-python/src/graph.rs +++ b/crates/lance-graph-python/src/graph.rs @@ -27,7 +27,8 @@ use lance_graph::{ ast::{DistanceMetric as RustDistanceMetric, GraphPattern, ReadingClause}, CypherQuery as RustCypherQuery, ExecutionStrategy as RustExecutionStrategy, GraphConfig as RustGraphConfig, GraphError as RustGraphError, InMemoryCatalog, - SqlQuery as RustSqlQuery, VectorSearch as RustVectorSearch, + SqlDialect as RustSqlDialect, SqlQuery as RustSqlQuery, + VectorSearch as RustVectorSearch, }; use pyo3::{ exceptions::{PyNotImplementedError, PyRuntimeError, PyValueError}, @@ -59,6 +60,34 @@ impl From for RustExecutionStrategy { } } +/// SQL dialect for generating SQL from Cypher queries +#[pyclass(name = "SqlDialect", module = "lance.graph")] +#[derive(Clone, Copy)] +pub enum SqlDialect { + /// Generic SQL (DataFusion default dialect) + Default, + /// Spark SQL dialect + Spark, + /// PostgreSQL dialect + PostgreSql, + /// MySQL dialect + MySql, + /// SQLite dialect + Sqlite, +} + +impl From for RustSqlDialect { + fn from(dialect: SqlDialect) -> Self { + match dialect { + SqlDialect::Default => RustSqlDialect::Default, + SqlDialect::Spark => RustSqlDialect::Spark, + SqlDialect::PostgreSql => RustSqlDialect::PostgreSql, + SqlDialect::MySql => RustSqlDialect::MySql, + SqlDialect::Sqlite => RustSqlDialect::Sqlite, 
+ } + } +} + /// Distance metric for vector similarity search #[pyclass(name = "DistanceMetric", module = "lance.graph")] #[derive(Clone, Copy)] @@ -494,9 +523,8 @@ impl CypherQuery { /// ---------- /// datasets : dict /// Dictionary mapping table names to Lance datasets - /// dialect : str, optional - /// SQL dialect to use. One of "default", "spark", "postgresql", "mysql", "sqlite". - /// Defaults to "default" (generic DataFusion SQL). + /// dialect : SqlDialect, optional + /// SQL dialect to use. Defaults to SqlDialect.Default (generic DataFusion SQL). /// /// Returns /// ------- @@ -507,28 +535,14 @@ impl CypherQuery { /// ------ /// RuntimeError /// If SQL generation fails - /// ValueError - /// If an invalid dialect is specified #[pyo3(signature = (datasets, dialect=None))] fn to_sql( &self, py: Python, datasets: &Bound<'_, PyDict>, - dialect: Option<&str>, + dialect: Option, ) -> PyResult { - let sql_dialect = match dialect { - None | Some("default") => None, - Some("spark") => Some(lance_graph::SqlDialect::Spark), - Some("postgresql") | Some("postgres") => Some(lance_graph::SqlDialect::PostgreSql), - Some("mysql") => Some(lance_graph::SqlDialect::MySql), - Some("sqlite") => Some(lance_graph::SqlDialect::Sqlite), - Some(other) => { - return Err(PyValueError::new_err(format!( - "Unknown SQL dialect: '{}'. 
Valid options: 'default', 'spark', 'postgresql', 'mysql', 'sqlite'", - other - ))); - } - }; + let sql_dialect = dialect.map(|d| d.into()); // Convert datasets to Arrow RecordBatch map let arrow_datasets = python_datasets_to_batches(datasets)?; @@ -1570,6 +1584,7 @@ pub fn register_graph_module(py: Python, parent_module: &Bound<'_, PyModule>) -> let graph_module = PyModule::new(py, "graph")?; graph_module.add_class::()?; + graph_module.add_class::()?; graph_module.add_class::()?; graph_module.add_class::()?; graph_module.add_class::()?; diff --git a/crates/lance-graph/src/lib.rs b/crates/lance-graph/src/lib.rs index 9aed43ff..387033dd 100644 --- a/crates/lance-graph/src/lib.rs +++ b/crates/lance-graph/src/lib.rs @@ -68,8 +68,7 @@ pub use lance_graph_catalog::{ #[cfg(feature = "unity-catalog")] pub use lance_graph_catalog::{UnityCatalogConfig, UnityCatalogProvider}; pub use lance_vector_search::VectorSearch; -pub use query::{CypherQuery, ExecutionStrategy}; -pub use spark_dialect::{SqlDialect, SparkDialect}; +pub use query::{CypherQuery, ExecutionStrategy, SqlDialect}; pub use sql_query::SqlQuery; #[cfg(feature = "delta")] pub use table_readers::DeltaTableReader; diff --git a/crates/lance-graph/src/query.rs b/crates/lance-graph/src/query.rs index ffd6abd8..b6a6fe9a 100644 --- a/crates/lance-graph/src/query.rs +++ b/crates/lance-graph/src/query.rs @@ -16,6 +16,22 @@ use lance_namespace::models::DescribeTableRequest; use std::collections::{HashMap, HashSet}; use std::sync::Arc; +/// SQL dialect to use when generating SQL from Cypher queries. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum SqlDialect { + /// Generic SQL (DataFusion default dialect) + #[default] + Default, + /// Spark SQL dialect (backtick quoting, STRING type, EXTRACT, etc.) + Spark, + /// PostgreSQL dialect + PostgreSql, + /// MySQL dialect + MySql, + /// SQLite dialect + Sqlite, +} + /// Normalize an Arrow schema to have lowercase field names. 
/// /// This ensures that column names in the dataset match the normalized @@ -302,7 +318,7 @@ impl CypherQuery { pub async fn to_sql( &self, datasets: HashMap, - dialect: Option, + dialect: Option, ) -> Result { use std::sync::Arc; diff --git a/crates/lance-graph/src/spark_dialect.rs b/crates/lance-graph/src/spark_dialect.rs index be7bb60e..3d594843 100644 --- a/crates/lance-graph/src/spark_dialect.rs +++ b/crates/lance-graph/src/spark_dialect.rs @@ -3,9 +3,9 @@ //! SQL dialect support for the DataFusion unparser. //! -//! This module provides a [`SqlDialect`] enum for selecting which SQL dialect -//! to use when unparsing DataFusion logical plans to SQL strings, and includes -//! a [`SparkDialect`] implementation for Spark SQL. +//! This module provides a Spark SQL dialect built using DataFusion's +//! [`CustomDialectBuilder`], and a helper to build an [`Unparser`] for any +//! supported [`SqlDialect`]. //! //! Key Spark SQL differences from standard SQL: //! - Backtick (`` ` ``) identifier quoting @@ -16,51 +16,50 @@ //! - `LENGTH()` instead of `CHARACTER_LENGTH()` //! - Subqueries in FROM require aliases -use std::sync::Arc; - -use arrow::datatypes::TimeUnit; -use datafusion_common::Result; -use datafusion_expr::Expr; use datafusion_sql::unparser::dialect::{ - CharacterLengthStyle, DateFieldExtractStyle, DefaultDialect, Dialect, IntervalStyle, - MySqlDialect, PostgreSqlDialect, SqliteDialect, + CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, + DefaultDialect, MySqlDialect, PostgreSqlDialect, SqliteDialect, }; use datafusion_sql::unparser::Unparser; use datafusion_sql::sqlparser::ast::{self, Ident, ObjectName, TimezoneInfo}; -/// SQL dialect to use when generating SQL from Cypher queries. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum SqlDialect { - /// Generic SQL (DataFusion default dialect) - #[default] - Default, - /// Spark SQL dialect (backtick quoting, STRING type, EXTRACT, etc.) 
- Spark, - /// PostgreSQL dialect - PostgreSql, - /// MySQL dialect - MySql, - /// SQLite dialect - Sqlite, -} +use crate::query::SqlDialect; -impl SqlDialect { - /// Create a DataFusion `Unparser` configured for this dialect. - pub fn unparser(&self) -> DialectUnparser { - match self { - SqlDialect::Default => DialectUnparser::Default(DefaultDialect {}), - SqlDialect::Spark => DialectUnparser::Spark(SparkDialect), - SqlDialect::PostgreSql => DialectUnparser::PostgreSql(PostgreSqlDialect {}), - SqlDialect::MySql => DialectUnparser::MySql(MySqlDialect {}), - SqlDialect::Sqlite => DialectUnparser::Sqlite(SqliteDialect {}), - } - } +/// Build a Spark SQL dialect using DataFusion's `CustomDialectBuilder`. +pub fn build_spark_dialect() -> CustomDialect { + CustomDialectBuilder::new() + .with_identifier_quote_style('`') + .with_supports_nulls_first_in_sort(true) + .with_use_timestamp_for_date64(true) + .with_utf8_cast_dtype(ast::DataType::Custom( + ObjectName::from(vec![Ident::new("STRING")]), + vec![], + )) + .with_large_utf8_cast_dtype(ast::DataType::Custom( + ObjectName::from(vec![Ident::new("STRING")]), + vec![], + )) + .with_date_field_extract_style(DateFieldExtractStyle::Extract) + .with_character_length_style(CharacterLengthStyle::Length) + .with_int64_cast_dtype(ast::DataType::BigInt(None)) + .with_int32_cast_dtype(ast::DataType::Int(None)) + .with_timestamp_cast_dtype( + ast::DataType::Timestamp(None, TimezoneInfo::None), + ast::DataType::Timestamp(None, TimezoneInfo::None), + ) + .with_date32_cast_dtype(ast::DataType::Date) + .with_supports_column_alias_in_table_alias(true) + .with_requires_derived_table_alias(true) + .with_full_qualified_col(false) + .with_unnest_as_table_factor(false) + .with_float64_ast_dtype(ast::DataType::Double(ast::ExactNumberInfo::None)) + .build() } /// Wrapper to hold the concrete dialect type and provide an `Unparser` reference. 
pub enum DialectUnparser { Default(DefaultDialect), - Spark(SparkDialect), + Spark(CustomDialect), PostgreSql(PostgreSqlDialect), MySql(MySqlDialect), Sqlite(SqliteDialect), @@ -78,114 +77,34 @@ impl DialectUnparser { } } -/// A Spark SQL dialect for unparsing DataFusion logical plans to Spark-compatible SQL. -pub struct SparkDialect; - -impl Dialect for SparkDialect { - fn identifier_quote_style(&self, _identifier: &str) -> Option { - Some('`') - } - - fn supports_nulls_first_in_sort(&self) -> bool { - true - } - - fn use_timestamp_for_date64(&self) -> bool { - true - } - - fn interval_style(&self) -> IntervalStyle { - IntervalStyle::SQLStandard - } - - fn float64_ast_dtype(&self) -> ast::DataType { - ast::DataType::Double(ast::ExactNumberInfo::None) - } - - fn utf8_cast_dtype(&self) -> ast::DataType { - ast::DataType::Custom( - ObjectName::from(vec![Ident::new("STRING")]), - vec![], - ) - } - - fn large_utf8_cast_dtype(&self) -> ast::DataType { - ast::DataType::Custom( - ObjectName::from(vec![Ident::new("STRING")]), - vec![], - ) - } - - fn date_field_extract_style(&self) -> DateFieldExtractStyle { - DateFieldExtractStyle::Extract - } - - fn character_length_style(&self) -> CharacterLengthStyle { - CharacterLengthStyle::Length - } - - fn int64_cast_dtype(&self) -> ast::DataType { - ast::DataType::BigInt(None) - } - - fn int32_cast_dtype(&self) -> ast::DataType { - ast::DataType::Int(None) - } - - fn timestamp_cast_dtype( - &self, - _time_unit: &TimeUnit, - _tz: &Option>, - ) -> ast::DataType { - ast::DataType::Timestamp(None, TimezoneInfo::None) - } - - fn date32_cast_dtype(&self) -> ast::DataType { - ast::DataType::Date - } - - fn supports_column_alias_in_table_alias(&self) -> bool { - true - } - - fn requires_derived_table_alias(&self) -> bool { - true - } - - fn full_qualified_col(&self) -> bool { - false - } - - fn unnest_as_table_factor(&self) -> bool { - false - } - - fn scalar_function_to_sql_overrides( - &self, - _unparser: &Unparser, - _func_name: &str, 
- _args: &[Expr], - ) -> Result> { - // character_length -> length is handled by CharacterLengthStyle::Length - // Additional Spark-specific function mappings can be added here as needed - Ok(None) +impl SqlDialect { + /// Create a `DialectUnparser` configured for this dialect. + pub fn unparser(&self) -> DialectUnparser { + match self { + SqlDialect::Default => DialectUnparser::Default(DefaultDialect {}), + SqlDialect::Spark => DialectUnparser::Spark(build_spark_dialect()), + SqlDialect::PostgreSql => DialectUnparser::PostgreSql(PostgreSqlDialect {}), + SqlDialect::MySql => DialectUnparser::MySql(MySqlDialect {}), + SqlDialect::Sqlite => DialectUnparser::Sqlite(SqliteDialect {}), + } } } #[cfg(test)] mod tests { use super::*; + use datafusion_sql::unparser::dialect::Dialect; #[test] fn test_spark_dialect_identifier_quoting() { - let dialect = SparkDialect; + let dialect = build_spark_dialect(); assert_eq!(dialect.identifier_quote_style("table_name"), Some('`')); assert_eq!(dialect.identifier_quote_style("column"), Some('`')); } #[test] fn test_spark_dialect_type_mappings() { - let dialect = SparkDialect; + let dialect = build_spark_dialect(); assert!(matches!(dialect.utf8_cast_dtype(), ast::DataType::Custom(..))); assert!(matches!(dialect.int64_cast_dtype(), ast::DataType::BigInt(None))); assert!(matches!(dialect.int32_cast_dtype(), ast::DataType::Int(None))); @@ -194,13 +113,13 @@ mod tests { #[test] fn test_spark_dialect_requires_derived_table_alias() { - let dialect = SparkDialect; + let dialect = build_spark_dialect(); assert!(dialect.requires_derived_table_alias()); } #[test] fn test_spark_dialect_extract_style() { - let dialect = SparkDialect; + let dialect = build_spark_dialect(); assert!(matches!( dialect.date_field_extract_style(), DateFieldExtractStyle::Extract @@ -209,7 +128,7 @@ mod tests { #[test] fn test_spark_dialect_character_length_style() { - let dialect = SparkDialect; + let dialect = build_spark_dialect(); assert!(matches!( 
dialect.character_length_style(), CharacterLengthStyle::Length diff --git a/python/python/lance_graph/__init__.py b/python/python/lance_graph/__init__.py index 1d463c7d..26731960 100644 --- a/python/python/lance_graph/__init__.py +++ b/python/python/lance_graph/__init__.py @@ -83,6 +83,7 @@ def _load_dev_build() -> ModuleType: SqlQuery = _bindings.graph.SqlQuery SqlEngine = _bindings.graph.SqlEngine ExecutionStrategy = _bindings.graph.ExecutionStrategy +SqlDialect = _bindings.graph.SqlDialect VectorSearch = _bindings.graph.VectorSearch DistanceMetric = _bindings.graph.DistanceMetric @@ -101,6 +102,7 @@ def _load_dev_build() -> ModuleType: "SqlQuery", "SqlEngine", "ExecutionStrategy", + "SqlDialect", "VectorSearch", "DistanceMetric", "DirNamespace", From 6eaffe7379fdc0ca46e7e0d67e9beea3c2b31ef9 Mon Sep 17 00:00:00 2001 From: Yu Chen Date: Mon, 23 Mar 2026 22:08:24 -0700 Subject: [PATCH 3/6] refactor: move DialectUnparser to query.rs and box large enum variant Move DialectUnparser and SqlDialect::unparser() from spark_dialect.rs to query.rs to co-locate with the SqlDialect enum. Box the Spark CustomDialect variant to fix clippy::large_enum_variant lint. 
Co-authored-by: Isaac
---
 crates/lance-graph/src/query.rs         | 39 ++++++++++++++++++++++
 crates/lance-graph/src/spark_dialect.rs | 43 ++-----------------------
 2 files changed, 41 insertions(+), 41 deletions(-)

diff --git a/crates/lance-graph/src/query.rs b/crates/lance-graph/src/query.rs
index b6a6fe9a..e62b7167 100644
--- a/crates/lance-graph/src/query.rs
+++ b/crates/lance-graph/src/query.rs
@@ -9,8 +9,13 @@ use crate::config::GraphConfig;
 use crate::error::{GraphError, Result};
 use crate::logical_plan::LogicalPlanner;
 use crate::parser::parse_cypher_query;
+use crate::spark_dialect::build_spark_dialect;
 use arrow_array::RecordBatch;
 use arrow_schema::{Field, Schema, SchemaRef};
+use datafusion_sql::unparser::dialect::{
+    CustomDialect, DefaultDialect, MySqlDialect, PostgreSqlDialect, SqliteDialect,
+};
+use datafusion_sql::unparser::Unparser;
 use lance_graph_catalog::DirNamespace;
 use lance_namespace::models::DescribeTableRequest;
 use std::collections::{HashMap, HashSet};
@@ -32,6 +37,40 @@ pub enum SqlDialect {
     Sqlite,
 }
 
+/// Wrapper to hold the concrete dialect type and provide an `Unparser` reference.
+pub enum DialectUnparser {
+    Default(DefaultDialect),
+    Spark(Box<CustomDialect>),
+    PostgreSql(PostgreSqlDialect),
+    MySql(MySqlDialect),
+    Sqlite(SqliteDialect),
+}
+
+impl DialectUnparser {
+    pub fn as_unparser(&self) -> Unparser<'_> {
+        match self {
+            DialectUnparser::Default(d) => Unparser::new(d),
+            DialectUnparser::Spark(d) => Unparser::new(d.as_ref()),
+            DialectUnparser::PostgreSql(d) => Unparser::new(d),
+            DialectUnparser::MySql(d) => Unparser::new(d),
+            DialectUnparser::Sqlite(d) => Unparser::new(d),
+        }
+    }
+}
+
+impl SqlDialect {
+    /// Create a `DialectUnparser` configured for this dialect.
+ pub fn unparser(&self) -> DialectUnparser { + match self { + SqlDialect::Default => DialectUnparser::Default(DefaultDialect {}), + SqlDialect::Spark => DialectUnparser::Spark(Box::new(build_spark_dialect())), + SqlDialect::PostgreSql => DialectUnparser::PostgreSql(PostgreSqlDialect {}), + SqlDialect::MySql => DialectUnparser::MySql(MySqlDialect {}), + SqlDialect::Sqlite => DialectUnparser::Sqlite(SqliteDialect {}), + } + } +} + /// Normalize an Arrow schema to have lowercase field names. /// /// This ensures that column names in the dataset match the normalized diff --git a/crates/lance-graph/src/spark_dialect.rs b/crates/lance-graph/src/spark_dialect.rs index 3d594843..3e5505e0 100644 --- a/crates/lance-graph/src/spark_dialect.rs +++ b/crates/lance-graph/src/spark_dialect.rs @@ -1,11 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -//! SQL dialect support for the DataFusion unparser. +//! Spark SQL dialect for the DataFusion unparser. //! //! This module provides a Spark SQL dialect built using DataFusion's -//! [`CustomDialectBuilder`], and a helper to build an [`Unparser`] for any -//! supported [`SqlDialect`]. +//! [`CustomDialectBuilder`]. //! //! Key Spark SQL differences from standard SQL: //! - Backtick (`` ` ``) identifier quoting @@ -18,13 +17,9 @@ use datafusion_sql::unparser::dialect::{ CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, - DefaultDialect, MySqlDialect, PostgreSqlDialect, SqliteDialect, }; -use datafusion_sql::unparser::Unparser; use datafusion_sql::sqlparser::ast::{self, Ident, ObjectName, TimezoneInfo}; -use crate::query::SqlDialect; - /// Build a Spark SQL dialect using DataFusion's `CustomDialectBuilder`. pub fn build_spark_dialect() -> CustomDialect { CustomDialectBuilder::new() @@ -56,40 +51,6 @@ pub fn build_spark_dialect() -> CustomDialect { .build() } -/// Wrapper to hold the concrete dialect type and provide an `Unparser` reference. 
-pub enum DialectUnparser { - Default(DefaultDialect), - Spark(CustomDialect), - PostgreSql(PostgreSqlDialect), - MySql(MySqlDialect), - Sqlite(SqliteDialect), -} - -impl DialectUnparser { - pub fn as_unparser(&self) -> Unparser<'_> { - match self { - DialectUnparser::Default(d) => Unparser::new(d), - DialectUnparser::Spark(d) => Unparser::new(d), - DialectUnparser::PostgreSql(d) => Unparser::new(d), - DialectUnparser::MySql(d) => Unparser::new(d), - DialectUnparser::Sqlite(d) => Unparser::new(d), - } - } -} - -impl SqlDialect { - /// Create a `DialectUnparser` configured for this dialect. - pub fn unparser(&self) -> DialectUnparser { - match self { - SqlDialect::Default => DialectUnparser::Default(DefaultDialect {}), - SqlDialect::Spark => DialectUnparser::Spark(build_spark_dialect()), - SqlDialect::PostgreSql => DialectUnparser::PostgreSql(PostgreSqlDialect {}), - SqlDialect::MySql => DialectUnparser::MySql(MySqlDialect {}), - SqlDialect::Sqlite => DialectUnparser::Sqlite(SqliteDialect {}), - } - } -} - #[cfg(test)] mod tests { use super::*; From d4500212697d8604e420cb7784f362b3e8db35b9 Mon Sep 17 00:00:00 2001 From: Yu Chen Date: Mon, 23 Mar 2026 22:57:41 -0700 Subject: [PATCH 4/6] feat: add lance-graph-cli crate with lgraph binary for Cypher-to-SQL translation Adds a CLI tool that translates Cypher queries into SQL for various dialects (default, spark, postgresql, mysql, sqlite) using a JSON config file that describes the graph schema. 
Co-authored-by: Isaac --- Cargo.lock | 96 +++++++++++++ Cargo.toml | 1 + crates/lance-graph-cli/Cargo.toml | 23 ++++ crates/lance-graph-cli/src/main.rs | 208 +++++++++++++++++++++++++++++ 4 files changed, 328 insertions(+) create mode 100644 crates/lance-graph-cli/Cargo.toml create mode 100644 crates/lance-graph-cli/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 29add1d9..d4215458 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -78,12 +78,56 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + [[package]] name = "anstyle" version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.60.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.60.2", +] + [[package]] name = "anyhow" version = "1.0.100" @@ -1218,6 +1262,7 @@ 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2134bb3ea021b78629caa971416385309e0131b351b25e01dc16fb54e1b5fae" dependencies = [ "clap_builder", + "clap_derive", ] [[package]] @@ -1226,8 +1271,22 @@ version = "4.5.48" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2ba64afa3c0a6df7fa517765e31314e983f51dda798ffba27b988194fb65dc9" dependencies = [ + "anstream", "anstyle", "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbfd7eae0b0f1a6e63d4b13c9c478de77c2eb546fba158ad50b4203dc24b9f9c" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] @@ -1245,6 +1304,12 @@ dependencies = [ "cc", ] +[[package]] +name = "colorchoice" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" + [[package]] name = "comfy-table" version = "7.1.2" @@ -3666,6 +3731,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + [[package]] name = "itertools" version = "0.10.5" @@ -4135,6 +4206,19 @@ dependencies = [ "wiremock", ] +[[package]] +name = "lance-graph-cli" +version = "0.5.3" +dependencies = [ + "arrow-array", + "arrow-schema", + "clap", + "lance-graph", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "lance-graph-python" version = "0.5.3" @@ -4966,6 +5050,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + [[package]] name = "oneshot" version = "0.1.11" @@ -7404,6 +7494,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.18.1" diff --git a/Cargo.toml b/Cargo.toml index c66043bd..b1c87588 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "crates/lance-graph", "crates/lance-graph-catalog", + "crates/lance-graph-cli", "crates/lance-graph-python", "crates/lance-graph-benches", ] diff --git a/crates/lance-graph-cli/Cargo.toml b/crates/lance-graph-cli/Cargo.toml new file mode 100644 index 00000000..65ec3f4f --- /dev/null +++ b/crates/lance-graph-cli/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "lance-graph-cli" +version = "0.5.3" +edition = "2021" +license = "Apache-2.0" +authors = ["Lance Devs "] +repository = "https://github.com/lancedb/lance-graph" +description = "CLI tool for translating Cypher queries to SQL" +keywords = ["lance", "graph", "cypher", "sql", "cli"] +categories = ["command-line-utilities", "database"] + +[[bin]] +name = "lgraph" +path = "src/main.rs" + +[dependencies] +arrow-array = "56.2" +arrow-schema = "56.2" +clap = { version = "4", features = ["derive"] } +lance-graph = { path = "../lance-graph", version = "0.5.3", default-features = false } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tokio = { version = "1.37", features = ["macros", "rt-multi-thread"] } diff --git a/crates/lance-graph-cli/src/main.rs b/crates/lance-graph-cli/src/main.rs new file mode 100644 index 00000000..3ded568e --- /dev/null +++ b/crates/lance-graph-cli/src/main.rs @@ -0,0 +1,208 @@ +// SPDX-License-Identifier: 
Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+//! `lgraph` — CLI tool for translating Cypher queries to SQL.
+//!
+//! # Usage
+//!
+//! ```sh
+//! # Translate a Cypher query using a graph config file
+//! lgraph -c graph.json "MATCH (p:Person) WHERE p.age > 30 RETURN p.name"
+//!
+//! # Specify a dialect
+//! lgraph -c graph.json -d spark "MATCH (p:Person) RETURN p.name"
+//!
+//! # Read query from stdin
+//! echo "MATCH (p:Person) RETURN p.name" | lgraph -c graph.json
+//! ```
+
+use std::collections::HashMap;
+use std::io::{self, Read};
+use std::path::PathBuf;
+use std::sync::Arc;
+
+use arrow_array::RecordBatch;
+use arrow_schema::{DataType, Field, Schema};
+use clap::Parser;
+use lance_graph::{CypherQuery, GraphConfig, SqlDialect};
+use serde::Deserialize;
+
+/// Translate Cypher queries to SQL for various dialects.
+#[derive(Parser)]
+#[command(name = "lgraph", version, about)]
+struct Cli {
+    /// Cypher query string. If omitted, reads from stdin.
+    query: Option<String>,
+
+    /// Path to graph config JSON file describing the schema.
+    #[arg(short, long)]
+    config: PathBuf,
+
+    /// SQL dialect for the output.
+    #[arg(short, long, value_enum, default_value_t = DialectArg::Default)]
+    dialect: DialectArg,
+}
+
+#[derive(Clone, Copy, clap::ValueEnum)]
+enum DialectArg {
+    Default,
+    Spark,
+    #[value(alias = "postgres")]
+    Postgresql,
+    Mysql,
+    Sqlite,
+}
+
+impl From<DialectArg> for SqlDialect {
+    fn from(d: DialectArg) -> Self {
+        match d {
+            DialectArg::Default => SqlDialect::Default,
+            DialectArg::Spark => SqlDialect::Spark,
+            DialectArg::Postgresql => SqlDialect::PostgreSql,
+            DialectArg::Mysql => SqlDialect::MySql,
+            DialectArg::Sqlite => SqlDialect::Sqlite,
+        }
+    }
+}
+
+// ── Config file schema ──────────────────────────────────────────────
+
+/// Top-level graph config file format.
+///
+/// Example `graph.json`:
+/// ```json
+/// {
+///   "nodes": {
+///     "Person": {
+///       "id_field": "person_id",
+///       "fields": { "name": "Utf8", "age": "Int32" }
+///     }
+///   },
+///   "relationships": {
+///     "KNOWS": {
+///       "source_field": "src_id",
+///       "target_field": "dst_id",
+///       "fields": { "since": "Int32" }
+///     }
+///   }
+/// }
+/// ```
+#[derive(Deserialize)]
+struct ConfigFile {
+    nodes: HashMap<String, NodeDef>,
+    #[serde(default)]
+    relationships: HashMap<String, RelDef>,
+}
+
+#[derive(Deserialize)]
+struct NodeDef {
+    id_field: String,
+    #[serde(default)]
+    fields: HashMap<String, String>,
+}
+
+#[derive(Deserialize)]
+struct RelDef {
+    source_field: String,
+    target_field: String,
+    #[serde(default)]
+    fields: HashMap<String, String>,
+}
+
+/// Parse a type name string into an Arrow DataType.
+fn parse_data_type(s: &str) -> DataType {
+    match s.to_lowercase().as_str() {
+        "bool" | "boolean" => DataType::Boolean,
+        "int8" => DataType::Int8,
+        "int16" => DataType::Int16,
+        "int32" | "int" => DataType::Int32,
+        "int64" | "bigint" | "long" => DataType::Int64,
+        "uint8" => DataType::UInt8,
+        "uint16" => DataType::UInt16,
+        "uint32" => DataType::UInt32,
+        "uint64" => DataType::UInt64,
+        "float16" | "half" => DataType::Float16,
+        "float32" | "float" => DataType::Float32,
+        "float64" | "double" => DataType::Float64,
+        "date32" | "date" => DataType::Date32,
+        "date64" => DataType::Date64,
+        _ => DataType::Utf8, // default to string
+    }
+}
+
+/// Build a [`GraphConfig`] and empty-schema datasets from the config file.
+fn build_from_config(
+    cfg: &ConfigFile,
+) -> Result<(GraphConfig, HashMap<String, RecordBatch>), Box<dyn std::error::Error>> {
+    let mut builder = GraphConfig::builder();
+    let mut datasets: HashMap<String, RecordBatch> = HashMap::new();
+
+    for (label, node_def) in &cfg.nodes {
+        builder = builder.with_node_label(label, &node_def.id_field);
+
+        // Build schema: id field + declared fields
+        let mut fields = vec![Field::new(&node_def.id_field, DataType::Utf8, true)];
+        for (fname, ftype) in &node_def.fields {
+            fields.push(Field::new(fname, parse_data_type(ftype), true));
+        }
+        let schema = Arc::new(Schema::new(fields));
+        let batch = RecordBatch::new_empty(schema);
+        datasets.insert(label.clone(), batch);
+    }
+
+    for (rel_type, rel_def) in &cfg.relationships {
+        builder =
+            builder.with_relationship(rel_type, &rel_def.source_field, &rel_def.target_field);
+
+        // Build schema: source + target fields + declared fields
+        let mut fields = vec![
+            Field::new(&rel_def.source_field, DataType::Utf8, true),
+            Field::new(&rel_def.target_field, DataType::Utf8, true),
+        ];
+        for (fname, ftype) in &rel_def.fields {
+            fields.push(Field::new(fname, parse_data_type(ftype), true));
+        }
+        let schema = Arc::new(Schema::new(fields));
+        let batch = RecordBatch::new_empty(schema);
+        datasets.insert(rel_type.clone(), batch);
+    }
+
+    let config = builder.build()?;
+    Ok((config, datasets))
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let cli = Cli::parse();
+
+    // Read query from argument or stdin
+    let query_str = match cli.query {
+        Some(q) => q,
+        None => {
+            let mut buf = String::new();
+            io::stdin().read_to_string(&mut buf)?;
+            buf.trim().to_string()
+        }
+    };
+
+    if query_str.is_empty() {
+        eprintln!("Error: no Cypher query provided");
+        std::process::exit(1);
+    }
+
+    // Load config
+    let config_text = std::fs::read_to_string(&cli.config)
+        .map_err(|e| format!("Failed to read config file {:?}: {}", cli.config, e))?;
+    let config_file: ConfigFile = serde_json::from_str(&config_text)
+        .map_err(|e| format!("Failed to parse config 
file: {}", e))?; + + let (graph_config, datasets) = build_from_config(&config_file)?; + + // Build CypherQuery, translate to SQL + let cypher = CypherQuery::new(&query_str)?.with_config(graph_config); + let dialect: SqlDialect = cli.dialect.into(); + let sql = cypher.to_sql(datasets, Some(dialect)).await?; + + println!("{sql}"); + Ok(()) +} From d1bd741aa0d1034911864ae31e2779fc6cee142c Mon Sep 17 00:00:00 2001 From: Yu Chen Date: Mon, 23 Mar 2026 22:59:29 -0700 Subject: [PATCH 5/6] Revert "feat: add lance-graph-cli crate with lgraph binary for Cypher-to-SQL translation" This reverts commit d4500212697d8604e420cb7784f362b3e8db35b9. --- Cargo.lock | 96 ------------- Cargo.toml | 1 - crates/lance-graph-cli/Cargo.toml | 23 ---- crates/lance-graph-cli/src/main.rs | 208 ----------------------------- 4 files changed, 328 deletions(-) delete mode 100644 crates/lance-graph-cli/Cargo.toml delete mode 100644 crates/lance-graph-cli/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index d4215458..29add1d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -78,56 +78,12 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" -[[package]] -name = "anstream" -version = "0.6.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" -dependencies = [ - "anstyle", - "anstyle-parse", - "anstyle-query", - "anstyle-wincon", - "colorchoice", - "is_terminal_polyfill", - "utf8parse", -] - [[package]] name = "anstyle" version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" -[[package]] -name = "anstyle-parse" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" -dependencies = [ - 
"utf8parse", -] - -[[package]] -name = "anstyle-query" -version = "1.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" -dependencies = [ - "windows-sys 0.60.2", -] - -[[package]] -name = "anstyle-wincon" -version = "3.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" -dependencies = [ - "anstyle", - "once_cell_polyfill", - "windows-sys 0.60.2", -] - [[package]] name = "anyhow" version = "1.0.100" @@ -1262,7 +1218,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2134bb3ea021b78629caa971416385309e0131b351b25e01dc16fb54e1b5fae" dependencies = [ "clap_builder", - "clap_derive", ] [[package]] @@ -1271,22 +1226,8 @@ version = "4.5.48" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2ba64afa3c0a6df7fa517765e31314e983f51dda798ffba27b988194fb65dc9" dependencies = [ - "anstream", "anstyle", "clap_lex", - "strsim", -] - -[[package]] -name = "clap_derive" -version = "4.5.47" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbfd7eae0b0f1a6e63d4b13c9c478de77c2eb546fba158ad50b4203dc24b9f9c" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn 2.0.106", ] [[package]] @@ -1304,12 +1245,6 @@ dependencies = [ "cc", ] -[[package]] -name = "colorchoice" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" - [[package]] name = "comfy-table" version = "7.1.2" @@ -3731,12 +3666,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "is_terminal_polyfill" -version = "1.70.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" - [[package]] name = "itertools" version = "0.10.5" 
@@ -4206,19 +4135,6 @@ dependencies = [ "wiremock", ] -[[package]] -name = "lance-graph-cli" -version = "0.5.3" -dependencies = [ - "arrow-array", - "arrow-schema", - "clap", - "lance-graph", - "serde", - "serde_json", - "tokio", -] - [[package]] name = "lance-graph-python" version = "0.5.3" @@ -5050,12 +4966,6 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" -[[package]] -name = "once_cell_polyfill" -version = "1.70.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" - [[package]] name = "oneshot" version = "0.1.11" @@ -7494,12 +7404,6 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" -[[package]] -name = "utf8parse" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" - [[package]] name = "uuid" version = "1.18.1" diff --git a/Cargo.toml b/Cargo.toml index b1c87588..c66043bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,6 @@ members = [ "crates/lance-graph", "crates/lance-graph-catalog", - "crates/lance-graph-cli", "crates/lance-graph-python", "crates/lance-graph-benches", ] diff --git a/crates/lance-graph-cli/Cargo.toml b/crates/lance-graph-cli/Cargo.toml deleted file mode 100644 index 65ec3f4f..00000000 --- a/crates/lance-graph-cli/Cargo.toml +++ /dev/null @@ -1,23 +0,0 @@ -[package] -name = "lance-graph-cli" -version = "0.5.3" -edition = "2021" -license = "Apache-2.0" -authors = ["Lance Devs "] -repository = "https://github.com/lancedb/lance-graph" -description = "CLI tool for translating Cypher queries to SQL" -keywords = ["lance", "graph", "cypher", "sql", "cli"] -categories = ["command-line-utilities", "database"] - 
-[[bin]] -name = "lgraph" -path = "src/main.rs" - -[dependencies] -arrow-array = "56.2" -arrow-schema = "56.2" -clap = { version = "4", features = ["derive"] } -lance-graph = { path = "../lance-graph", version = "0.5.3", default-features = false } -serde = { version = "1", features = ["derive"] } -serde_json = "1" -tokio = { version = "1.37", features = ["macros", "rt-multi-thread"] } diff --git a/crates/lance-graph-cli/src/main.rs b/crates/lance-graph-cli/src/main.rs deleted file mode 100644 index 3ded568e..00000000 --- a/crates/lance-graph-cli/src/main.rs +++ /dev/null @@ -1,208 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The Lance Authors - -//! `lgraph` — CLI tool for translating Cypher queries to SQL. -//! -//! # Usage -//! -//! ```sh -//! # Translate a Cypher query using a graph config file -//! lgraph -c graph.json "MATCH (p:Person) WHERE p.age > 30 RETURN p.name" -//! -//! # Specify a dialect -//! lgraph -c graph.json -d spark "MATCH (p:Person) RETURN p.name" -//! -//! # Read query from stdin -//! echo "MATCH (p:Person) RETURN p.name" | lgraph -c graph.json -//! ``` - -use std::collections::HashMap; -use std::io::{self, Read}; -use std::path::PathBuf; -use std::sync::Arc; - -use arrow_array::RecordBatch; -use arrow_schema::{DataType, Field, Schema}; -use clap::Parser; -use lance_graph::{CypherQuery, GraphConfig, SqlDialect}; -use serde::Deserialize; - -/// Translate Cypher queries to SQL for various dialects. -#[derive(Parser)] -#[command(name = "lgraph", version, about)] -struct Cli { - /// Cypher query string. If omitted, reads from stdin. - query: Option, - - /// Path to graph config JSON file describing the schema. - #[arg(short, long)] - config: PathBuf, - - /// SQL dialect for the output. 
- #[arg(short, long, value_enum, default_value_t = DialectArg::Default)] - dialect: DialectArg, -} - -#[derive(Clone, Copy, clap::ValueEnum)] -enum DialectArg { - Default, - Spark, - #[value(alias = "postgres")] - Postgresql, - Mysql, - Sqlite, -} - -impl From for SqlDialect { - fn from(d: DialectArg) -> Self { - match d { - DialectArg::Default => SqlDialect::Default, - DialectArg::Spark => SqlDialect::Spark, - DialectArg::Postgresql => SqlDialect::PostgreSql, - DialectArg::Mysql => SqlDialect::MySql, - DialectArg::Sqlite => SqlDialect::Sqlite, - } - } -} - -// ── Config file schema ────────────────────────────────────────────── - -/// Top-level graph config file format. -/// -/// Example `graph.json`: -/// ```json -/// { -/// "nodes": { -/// "Person": { -/// "id_field": "person_id", -/// "fields": { "name": "Utf8", "age": "Int32" } -/// } -/// }, -/// "relationships": { -/// "KNOWS": { -/// "source_field": "src_id", -/// "target_field": "dst_id", -/// "fields": { "since": "Int32" } -/// } -/// } -/// } -/// ``` -#[derive(Deserialize)] -struct ConfigFile { - nodes: HashMap, - #[serde(default)] - relationships: HashMap, -} - -#[derive(Deserialize)] -struct NodeDef { - id_field: String, - #[serde(default)] - fields: HashMap, -} - -#[derive(Deserialize)] -struct RelDef { - source_field: String, - target_field: String, - #[serde(default)] - fields: HashMap, -} - -/// Parse a type name string into an Arrow DataType. 
-fn parse_data_type(s: &str) -> DataType { - match s.to_lowercase().as_str() { - "bool" | "boolean" => DataType::Boolean, - "int8" => DataType::Int8, - "int16" => DataType::Int16, - "int32" | "int" => DataType::Int32, - "int64" | "bigint" | "long" => DataType::Int64, - "uint8" => DataType::UInt8, - "uint16" => DataType::UInt16, - "uint32" => DataType::UInt32, - "uint64" => DataType::UInt64, - "float16" | "half" => DataType::Float16, - "float32" | "float" => DataType::Float32, - "float64" | "double" => DataType::Float64, - "date32" | "date" => DataType::Date32, - "date64" => DataType::Date64, - _ => DataType::Utf8, // default to string - } -} - -/// Build a [`GraphConfig`] and empty-schema datasets from the config file. -fn build_from_config( - cfg: &ConfigFile, -) -> Result<(GraphConfig, HashMap), Box> { - let mut builder = GraphConfig::builder(); - let mut datasets: HashMap = HashMap::new(); - - for (label, node_def) in &cfg.nodes { - builder = builder.with_node_label(label, &node_def.id_field); - - // Build schema: id field + declared fields - let mut fields = vec![Field::new(&node_def.id_field, DataType::Utf8, true)]; - for (fname, ftype) in &node_def.fields { - fields.push(Field::new(fname, parse_data_type(ftype), true)); - } - let schema = Arc::new(Schema::new(fields)); - let batch = RecordBatch::new_empty(schema); - datasets.insert(label.clone(), batch); - } - - for (rel_type, rel_def) in &cfg.relationships { - builder = - builder.with_relationship(rel_type, &rel_def.source_field, &rel_def.target_field); - - // Build schema: source + target fields + declared fields - let mut fields = vec![ - Field::new(&rel_def.source_field, DataType::Utf8, true), - Field::new(&rel_def.target_field, DataType::Utf8, true), - ]; - for (fname, ftype) in &rel_def.fields { - fields.push(Field::new(fname, parse_data_type(ftype), true)); - } - let schema = Arc::new(Schema::new(fields)); - let batch = RecordBatch::new_empty(schema); - datasets.insert(rel_type.clone(), batch); - } - - 
let config = builder.build()?; - Ok((config, datasets)) -} - -#[tokio::main] -async fn main() -> Result<(), Box> { - let cli = Cli::parse(); - - // Read query from argument or stdin - let query_str = match cli.query { - Some(q) => q, - None => { - let mut buf = String::new(); - io::stdin().read_to_string(&mut buf)?; - buf.trim().to_string() - } - }; - - if query_str.is_empty() { - eprintln!("Error: no Cypher query provided"); - std::process::exit(1); - } - - // Load config - let config_text = std::fs::read_to_string(&cli.config) - .map_err(|e| format!("Failed to read config file {:?}: {}", cli.config, e))?; - let config_file: ConfigFile = serde_json::from_str(&config_text) - .map_err(|e| format!("Failed to parse config file: {}", e))?; - - let (graph_config, datasets) = build_from_config(&config_file)?; - - // Build CypherQuery, translate to SQL - let cypher = CypherQuery::new(&query_str)?.with_config(graph_config); - let dialect: SqlDialect = cli.dialect.into(); - let sql = cypher.to_sql(datasets, Some(dialect)).await?; - - println!("{sql}"); - Ok(()) -} From 0d1a864331e0ed4bbd7ce221877b77a0479e40e7 Mon Sep 17 00:00:00 2001 From: Yu Chen Date: Mon, 23 Mar 2026 23:22:50 -0700 Subject: [PATCH 6/6] style: apply rustfmt formatting Co-authored-by: Isaac --- crates/lance-graph/src/query.rs | 13 ++++++------- crates/lance-graph/src/spark_dialect.rs | 17 +++++++++++++---- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/crates/lance-graph/src/query.rs b/crates/lance-graph/src/query.rs index e62b7167..e63e2a6a 100644 --- a/crates/lance-graph/src/query.rs +++ b/crates/lance-graph/src/query.rs @@ -385,13 +385,12 @@ impl CypherQuery { // Unparse to SQL using the specified dialect let dialect_unparser = dialect.unparser(); let unparser = dialect_unparser.as_unparser(); - let sql_ast = - unparser - .plan_to_sql(&optimized_plan) - .map_err(|e| GraphError::PlanError { - message: format!("Failed to unparse plan to SQL: {}", e), - location: 
snafu::Location::new(file!(), line!(), column!()), - })?; + let sql_ast = unparser + .plan_to_sql(&optimized_plan) + .map_err(|e| GraphError::PlanError { + message: format!("Failed to unparse plan to SQL: {}", e), + location: snafu::Location::new(file!(), line!(), column!()), + })?; Ok(sql_ast.to_string()) } diff --git a/crates/lance-graph/src/spark_dialect.rs b/crates/lance-graph/src/spark_dialect.rs index 3e5505e0..9457ead8 100644 --- a/crates/lance-graph/src/spark_dialect.rs +++ b/crates/lance-graph/src/spark_dialect.rs @@ -15,10 +15,10 @@ //! - `LENGTH()` instead of `CHARACTER_LENGTH()` //! - Subqueries in FROM require aliases +use datafusion_sql::sqlparser::ast::{self, Ident, ObjectName, TimezoneInfo}; use datafusion_sql::unparser::dialect::{ CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, }; -use datafusion_sql::sqlparser::ast::{self, Ident, ObjectName, TimezoneInfo}; /// Build a Spark SQL dialect using DataFusion's `CustomDialectBuilder`. pub fn build_spark_dialect() -> CustomDialect { @@ -66,9 +66,18 @@ mod tests { #[test] fn test_spark_dialect_type_mappings() { let dialect = build_spark_dialect(); - assert!(matches!(dialect.utf8_cast_dtype(), ast::DataType::Custom(..))); - assert!(matches!(dialect.int64_cast_dtype(), ast::DataType::BigInt(None))); - assert!(matches!(dialect.int32_cast_dtype(), ast::DataType::Int(None))); + assert!(matches!( + dialect.utf8_cast_dtype(), + ast::DataType::Custom(..) + )); + assert!(matches!( + dialect.int64_cast_dtype(), + ast::DataType::BigInt(None) + )); + assert!(matches!( + dialect.int32_cast_dtype(), + ast::DataType::Int(None) + )); assert!(matches!(dialect.date32_cast_dtype(), ast::DataType::Date)); }