diff --git a/crates/lance-graph-python/src/graph.rs b/crates/lance-graph-python/src/graph.rs
index 223c6aac..03a3c8d8 100644
--- a/crates/lance-graph-python/src/graph.rs
+++ b/crates/lance-graph-python/src/graph.rs
@@ -27,7 +27,8 @@ use lance_graph::{
     ast::{DistanceMetric as RustDistanceMetric, GraphPattern, ReadingClause},
     CypherQuery as RustCypherQuery, ExecutionStrategy as RustExecutionStrategy,
     GraphConfig as RustGraphConfig, GraphError as RustGraphError, InMemoryCatalog,
-    SqlQuery as RustSqlQuery, VectorSearch as RustVectorSearch,
+    SqlDialect as RustSqlDialect, SqlQuery as RustSqlQuery,
+    VectorSearch as RustVectorSearch,
 };
 use pyo3::{
     exceptions::{PyNotImplementedError, PyRuntimeError, PyValueError},
@@ -59,6 +60,34 @@ impl From<ExecutionStrategy> for RustExecutionStrategy {
     }
 }
 
+/// SQL dialect for generating SQL from Cypher queries
+#[pyclass(name = "SqlDialect", module = "lance.graph")]
+#[derive(Clone, Copy)]
+pub enum SqlDialect {
+    /// Generic SQL (DataFusion default dialect)
+    Default,
+    /// Spark SQL dialect
+    Spark,
+    /// PostgreSQL dialect
+    PostgreSql,
+    /// MySQL dialect
+    MySql,
+    /// SQLite dialect
+    Sqlite,
+}
+
+impl From<SqlDialect> for RustSqlDialect {
+    fn from(dialect: SqlDialect) -> Self {
+        match dialect {
+            SqlDialect::Default => RustSqlDialect::Default,
+            SqlDialect::Spark => RustSqlDialect::Spark,
+            SqlDialect::PostgreSql => RustSqlDialect::PostgreSql,
+            SqlDialect::MySql => RustSqlDialect::MySql,
+            SqlDialect::Sqlite => RustSqlDialect::Sqlite,
+        }
+    }
+}
+
 /// Distance metric for vector similarity search
 #[pyclass(name = "DistanceMetric", module = "lance.graph")]
 #[derive(Clone, Copy)]
@@ -494,6 +523,8 @@ impl CypherQuery {
     /// ----------
     /// datasets : dict
     ///     Dictionary mapping table names to Lance datasets
+    /// dialect : SqlDialect, optional
+    ///     SQL dialect to use. Defaults to SqlDialect.Default (generic DataFusion SQL).
     ///
     /// Returns
     /// -------
@@ -504,7 +535,15 @@ impl CypherQuery {
     /// ------
     /// RuntimeError
     ///     If SQL generation fails
-    fn to_sql(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
+    #[pyo3(signature = (datasets, dialect=None))]
+    fn to_sql(
+        &self,
+        py: Python,
+        datasets: &Bound<'_, PyDict>,
+        dialect: Option<SqlDialect>,
+    ) -> PyResult<String> {
+        let sql_dialect = dialect.map(|d| d.into());
+
         // Convert datasets to Arrow RecordBatch map
         let arrow_datasets = python_datasets_to_batches(datasets)?;
@@ -513,7 +552,7 @@ impl CypherQuery {
         // Execute via runtime
         let sql = RT
-            .block_on(Some(py), inner_query.to_sql(arrow_datasets))?
+            .block_on(Some(py), inner_query.to_sql(arrow_datasets, sql_dialect))?
             .map_err(graph_error_to_pyerr)?;
 
         Ok(sql)
@@ -1545,6 +1584,7 @@ pub fn register_graph_module(py: Python, parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
     let graph_module = PyModule::new(py, "graph")?;
 
     graph_module.add_class::<ExecutionStrategy>()?;
+    graph_module.add_class::<SqlDialect>()?;
     graph_module.add_class::<VectorSearch>()?;
     graph_module.add_class::<DistanceMetric>()?;
     graph_module.add_class::<GraphConfig>()?;
diff --git a/crates/lance-graph/src/lib.rs b/crates/lance-graph/src/lib.rs
index 8f93fddb..387033dd 100644
--- a/crates/lance-graph/src/lib.rs
+++ b/crates/lance-graph/src/lib.rs
@@ -47,6 +47,7 @@ pub mod parameter_substitution;
 pub mod parser;
 pub mod query;
 pub mod semantic;
+pub mod spark_dialect;
 pub mod sql_catalog;
 pub mod sql_query;
 pub mod table_readers;
@@ -67,7 +68,7 @@ pub use lance_graph_catalog::{
 #[cfg(feature = "unity-catalog")]
 pub use lance_graph_catalog::{UnityCatalogConfig, UnityCatalogProvider};
 pub use lance_vector_search::VectorSearch;
-pub use query::{CypherQuery, ExecutionStrategy};
+pub use query::{CypherQuery, ExecutionStrategy, SqlDialect};
 pub use sql_query::SqlQuery;
 #[cfg(feature = "delta")]
 pub use table_readers::DeltaTableReader;
diff --git a/crates/lance-graph/src/query.rs b/crates/lance-graph/src/query.rs
index 9625f273..e63e2a6a 100644
--- a/crates/lance-graph/src/query.rs
+++ b/crates/lance-graph/src/query.rs
@@ -9,13 +9,68 @@ use crate::config::GraphConfig;
 use crate::error::{GraphError, Result};
 use crate::logical_plan::LogicalPlanner;
 use crate::parser::parse_cypher_query;
+use crate::spark_dialect::build_spark_dialect;
 use arrow_array::RecordBatch;
 use arrow_schema::{Field, Schema, SchemaRef};
+use datafusion_sql::unparser::dialect::{
+    CustomDialect, DefaultDialect, MySqlDialect, PostgreSqlDialect, SqliteDialect,
+};
+use datafusion_sql::unparser::Unparser;
 use lance_graph_catalog::DirNamespace;
 use lance_namespace::models::DescribeTableRequest;
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 
+/// SQL dialect to use when generating SQL from Cypher queries.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+pub enum SqlDialect {
+    /// Generic SQL (DataFusion default dialect)
+    #[default]
+    Default,
+    /// Spark SQL dialect (backtick quoting, STRING type, EXTRACT, etc.)
+    Spark,
+    /// PostgreSQL dialect
+    PostgreSql,
+    /// MySQL dialect
+    MySql,
+    /// SQLite dialect
+    Sqlite,
+}
+
+/// Wrapper to hold the concrete dialect type and provide an `Unparser` reference.
+pub enum DialectUnparser {
+    Default(DefaultDialect),
+    Spark(Box<CustomDialect>),
+    PostgreSql(PostgreSqlDialect),
+    MySql(MySqlDialect),
+    Sqlite(SqliteDialect),
+}
+
+impl DialectUnparser {
+    pub fn as_unparser(&self) -> Unparser<'_> {
+        match self {
+            DialectUnparser::Default(d) => Unparser::new(d),
+            DialectUnparser::Spark(d) => Unparser::new(d.as_ref()),
+            DialectUnparser::PostgreSql(d) => Unparser::new(d),
+            DialectUnparser::MySql(d) => Unparser::new(d),
+            DialectUnparser::Sqlite(d) => Unparser::new(d),
+        }
+    }
+}
+
+impl SqlDialect {
+    /// Create a `DialectUnparser` configured for this dialect.
+    pub fn unparser(&self) -> DialectUnparser {
+        match self {
+            SqlDialect::Default => DialectUnparser::Default(DefaultDialect {}),
+            SqlDialect::Spark => DialectUnparser::Spark(Box::new(build_spark_dialect())),
+            SqlDialect::PostgreSql => DialectUnparser::PostgreSql(PostgreSqlDialect {}),
+            SqlDialect::MySql => DialectUnparser::MySql(MySqlDialect {}),
+            SqlDialect::Sqlite => DialectUnparser::Sqlite(SqliteDialect {}),
+        }
+    }
+}
+
 /// Normalize an Arrow schema to have lowercase field names.
 ///
 /// This ensures that column names in the dataset match the normalized
@@ -280,10 +335,10 @@ impl CypherQuery {
         self.explain_internal(Arc::new(catalog), ctx).await
     }
 
-    /// Convert the Cypher query to a DataFusion SQL string
+    /// Convert the Cypher query to a SQL string in the specified dialect.
     ///
     /// This method generates a SQL string that corresponds to the DataFusion logical plan
-    /// derived from the Cypher query. It uses the `datafusion-sql` unparser.
+    /// derived from the Cypher query, using the specified SQL dialect for unparsing.
     ///
     /// **WARNING**: This method is experimental and the generated SQL dialect may change.
     ///
@@ -293,16 +348,20 @@ impl CypherQuery {
     ///
     /// # Arguments
     /// * `datasets` - HashMap of table name to RecordBatch (nodes and relationships)
+    /// * `dialect` - The SQL dialect to use for generating the output SQL.
+    ///   Defaults to `SqlDialect::Default` (generic DataFusion SQL).
+    ///   Use `SqlDialect::Spark` for Spark SQL, `SqlDialect::PostgreSql`, etc.
     ///
     /// # Returns
-    /// A SQL string representing the query
+    /// A SQL string representing the query in the specified dialect
     pub async fn to_sql(
         &self,
         datasets: HashMap<String, RecordBatch>,
+        dialect: Option<SqlDialect>,
     ) -> Result<String> {
-        use datafusion_sql::unparser::plan_to_sql;
         use std::sync::Arc;
 
+        let dialect = dialect.unwrap_or_default();
         let _config = self.require_config()?;
 
         // Build catalog and context from datasets using the helper
@@ -323,11 +382,15 @@ impl CypherQuery {
                 location: snafu::Location::new(file!(), line!(), column!()),
             })?;
 
-        // Unparse to SQL
-        let sql_ast = plan_to_sql(&optimized_plan).map_err(|e| GraphError::PlanError {
-            message: format!("Failed to unparse plan to SQL: {}", e),
-            location: snafu::Location::new(file!(), line!(), column!()),
-        })?;
+        // Unparse to SQL using the specified dialect
+        let dialect_unparser = dialect.unparser();
+        let unparser = dialect_unparser.as_unparser();
+        let sql_ast = unparser
+            .plan_to_sql(&optimized_plan)
+            .map_err(|e| GraphError::PlanError {
+                message: format!("Failed to unparse plan to SQL: {}", e),
+                location: snafu::Location::new(file!(), line!(), column!()),
+            })?;
 
         Ok(sql_ast.to_string())
     }
@@ -1852,7 +1915,7 @@ mod tests {
             .unwrap()
             .with_config(cfg);
 
-        let sql = query.to_sql(datasets).await.unwrap();
+        let sql = query.to_sql(datasets, None).await.unwrap();
 
         println!("Generated SQL: {}", sql);
         assert!(sql.contains("SELECT"));
diff --git a/crates/lance-graph/src/spark_dialect.rs b/crates/lance-graph/src/spark_dialect.rs
new file mode 100644
index 00000000..9457ead8
--- /dev/null
+++ b/crates/lance-graph/src/spark_dialect.rs
@@ -0,0 +1,107 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+//! Spark SQL dialect for the DataFusion unparser.
+//!
+//! This module provides a Spark SQL dialect built using DataFusion's
+//! [`CustomDialectBuilder`].
+//!
+//! Key Spark SQL differences from standard SQL:
+//! - Backtick (`` ` ``) identifier quoting
+//! - `EXTRACT(field FROM expr)` for date field extraction
+//! - `STRING` type for casting (not `VARCHAR`)
+//! - `BIGINT`/`INT` for integer types
+//! - `TIMESTAMP` for all timestamp types (no timezone info in cast)
+//! - `LENGTH()` instead of `CHARACTER_LENGTH()`
+//! - Subqueries in FROM require aliases
+
+use datafusion_sql::sqlparser::ast::{self, Ident, ObjectName, TimezoneInfo};
+use datafusion_sql::unparser::dialect::{
+    CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle,
+};
+
+/// Build a Spark SQL dialect using DataFusion's `CustomDialectBuilder`.
+pub fn build_spark_dialect() -> CustomDialect {
+    CustomDialectBuilder::new()
+        .with_identifier_quote_style('`')
+        .with_supports_nulls_first_in_sort(true)
+        .with_use_timestamp_for_date64(true)
+        .with_utf8_cast_dtype(ast::DataType::Custom(
+            ObjectName::from(vec![Ident::new("STRING")]),
+            vec![],
+        ))
+        .with_large_utf8_cast_dtype(ast::DataType::Custom(
+            ObjectName::from(vec![Ident::new("STRING")]),
+            vec![],
+        ))
+        .with_date_field_extract_style(DateFieldExtractStyle::Extract)
+        .with_character_length_style(CharacterLengthStyle::Length)
+        .with_int64_cast_dtype(ast::DataType::BigInt(None))
+        .with_int32_cast_dtype(ast::DataType::Int(None))
+        .with_timestamp_cast_dtype(
+            ast::DataType::Timestamp(None, TimezoneInfo::None),
+            ast::DataType::Timestamp(None, TimezoneInfo::None),
+        )
+        .with_date32_cast_dtype(ast::DataType::Date)
+        .with_supports_column_alias_in_table_alias(true)
+        .with_requires_derived_table_alias(true)
+        .with_full_qualified_col(false)
+        .with_unnest_as_table_factor(false)
+        .with_float64_ast_dtype(ast::DataType::Double(ast::ExactNumberInfo::None))
+        .build()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion_sql::unparser::dialect::Dialect;
+
+    #[test]
+    fn test_spark_dialect_identifier_quoting() {
+        let dialect = build_spark_dialect();
+        assert_eq!(dialect.identifier_quote_style("table_name"), Some('`'));
+        assert_eq!(dialect.identifier_quote_style("column"), Some('`'));
+    }
+
+    #[test]
+    fn test_spark_dialect_type_mappings() {
+        let dialect = build_spark_dialect();
+        assert!(matches!(
+            dialect.utf8_cast_dtype(),
+            ast::DataType::Custom(..)
+        ));
+        assert!(matches!(
+            dialect.int64_cast_dtype(),
+            ast::DataType::BigInt(None)
+        ));
+        assert!(matches!(
+            dialect.int32_cast_dtype(),
+            ast::DataType::Int(None)
+        ));
+        assert!(matches!(dialect.date32_cast_dtype(), ast::DataType::Date));
+    }
+
+    #[test]
+    fn test_spark_dialect_requires_derived_table_alias() {
+        let dialect = build_spark_dialect();
+        assert!(dialect.requires_derived_table_alias());
+    }
+
+    #[test]
+    fn test_spark_dialect_extract_style() {
+        let dialect = build_spark_dialect();
+        assert!(matches!(
+            dialect.date_field_extract_style(),
+            DateFieldExtractStyle::Extract
+        ));
+    }
+
+    #[test]
+    fn test_spark_dialect_character_length_style() {
+        let dialect = build_spark_dialect();
+        assert!(matches!(
+            dialect.character_length_style(),
+            CharacterLengthStyle::Length
+        ));
+    }
+}
diff --git a/crates/lance-graph/tests/test_to_spark_sql.rs b/crates/lance-graph/tests/test_to_spark_sql.rs
new file mode 100644
index 00000000..ee2cac20
--- /dev/null
+++ b/crates/lance-graph/tests/test_to_spark_sql.rs
@@ -0,0 +1,293 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+//! Integration tests for the to_sql API with SqlDialect::Spark
+//!
+//! These tests verify that Cypher queries can be correctly converted to Spark SQL strings.
+
+use arrow::array::{Int32Array, StringArray};
+use arrow::datatypes::{DataType, Field, Schema};
+use arrow::record_batch::RecordBatch;
+use lance_graph::{CypherQuery, GraphConfig, SqlDialect};
+use std::collections::HashMap;
+use std::sync::Arc;
+
+fn create_person_table() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("person_id", DataType::Int32, false),
+        Field::new("name", DataType::Utf8, false),
+        Field::new("age", DataType::Int32, false),
+        Field::new("city", DataType::Utf8, false),
+    ]));
+
+    RecordBatch::try_new(
+        schema,
+        vec![
+            Arc::new(Int32Array::from(vec![1, 2, 3, 4])),
+            Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
+            Arc::new(Int32Array::from(vec![28, 34, 29, 42])),
+            Arc::new(StringArray::from(vec![
+                "New York",
+                "San Francisco",
+                "New York",
+                "Chicago",
+            ])),
+        ],
+    )
+    .unwrap()
+}
+
+fn create_company_table() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("company_id", DataType::Int32, false),
+        Field::new("company_name", DataType::Utf8, false),
+        Field::new("industry", DataType::Utf8, false),
+    ]));
+
+    RecordBatch::try_new(
+        schema,
+        vec![
+            Arc::new(Int32Array::from(vec![101, 102, 103])),
+            Arc::new(StringArray::from(vec!["TechCorp", "DataInc", "CloudSoft"])),
+            Arc::new(StringArray::from(vec!["Technology", "Analytics", "Cloud"])),
+        ],
+    )
+    .unwrap()
+}
+
+fn create_works_for_table() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("person_id", DataType::Int32, false),
+        Field::new("company_id", DataType::Int32, false),
+        Field::new("position", DataType::Utf8, false),
+        Field::new("salary", DataType::Int32, false),
+    ]));
+
+    RecordBatch::try_new(
+        schema,
+        vec![
+            Arc::new(Int32Array::from(vec![1, 2, 3, 4])),
+            Arc::new(Int32Array::from(vec![101, 101, 102, 103])),
+            Arc::new(StringArray::from(vec![
+                "Engineer", "Designer", "Manager", "Director",
+            ])),
+            Arc::new(Int32Array::from(vec![120000, 95000, 130000, 180000])),
+        ],
+    )
+    .unwrap()
+}
+
+#[tokio::test]
+async fn test_to_sql_spark_simple_node_scan() {
+    let config = GraphConfig::builder()
+        .with_node_label("Person", "person_id")
+        .build()
+        .unwrap();
+
+    let mut datasets = HashMap::new();
+    datasets.insert("Person".to_string(), create_person_table());
+
+    let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
+        .unwrap()
+        .with_config(config);
+
+    let sql = query
+        .to_sql(datasets, Some(SqlDialect::Spark))
+        .await
+        .unwrap();
+
+    assert!(sql.contains('`'), "Spark SQL should use backtick quoting");
+    assert!(
+        sql.to_uppercase().contains("SELECT"),
+        "SQL should contain SELECT"
+    );
+    assert!(sql.contains("name"), "SQL should reference name column");
+
+    println!("Generated Spark SQL:\n{}", sql);
+}
+
+#[tokio::test]
+async fn test_to_sql_spark_with_filter() {
+    let config = GraphConfig::builder()
+        .with_node_label("Person", "person_id")
+        .build()
+        .unwrap();
+
+    let mut datasets = HashMap::new();
+    datasets.insert("Person".to_string(), create_person_table());
+
+    let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age")
+        .unwrap()
+        .with_config(config);
+
+    let sql = query
+        .to_sql(datasets, Some(SqlDialect::Spark))
+        .await
+        .unwrap();
+
+    assert!(sql.contains('`'), "Spark SQL should use backtick quoting");
+    assert!(sql.contains("30"), "SQL should contain filter value");
+    assert!(sql.contains("age"), "SQL should reference age column");
+
+    println!("Generated Spark SQL with filter:\n{}", sql);
+}
+
+#[tokio::test]
+async fn test_to_sql_spark_with_relationship() {
+    let config = GraphConfig::builder()
+        .with_node_label("Person", "person_id")
+        .with_node_label("Company", "company_id")
+        .with_relationship("WORKS_FOR", "person_id", "company_id")
+        .build()
+        .unwrap();
+
+    let mut datasets = HashMap::new();
+    datasets.insert("Person".to_string(), create_person_table());
+    datasets.insert("Company".to_string(), create_company_table());
+    datasets.insert("WORKS_FOR".to_string(), create_works_for_table());
+
+    let query = CypherQuery::new(
+        "MATCH (p:Person)-[:WORKS_FOR]->(c:Company) RETURN p.name, c.company_name",
+    )
+    .unwrap()
+    .with_config(config);
+
+    let sql = query
+        .to_sql(datasets, Some(SqlDialect::Spark))
+        .await
+        .unwrap();
+
+    let sql_upper = sql.to_uppercase();
+    assert!(sql.contains('`'), "Spark SQL should use backtick quoting");
+    assert!(sql_upper.contains("JOIN"), "SQL should contain JOIN");
+
+    println!("Generated Spark SQL with relationship:\n{}", sql);
+}
+
+#[tokio::test]
+async fn test_to_sql_spark_complex_query() {
+    let config = GraphConfig::builder()
+        .with_node_label("Person", "person_id")
+        .with_node_label("Company", "company_id")
+        .with_relationship("WORKS_FOR", "person_id", "company_id")
+        .build()
+        .unwrap();
+
+    let mut datasets = HashMap::new();
+    datasets.insert("Person".to_string(), create_person_table());
+    datasets.insert("Company".to_string(), create_company_table());
+    datasets.insert("WORKS_FOR".to_string(), create_works_for_table());
+
+    let query = CypherQuery::new(
+        "MATCH (p:Person)-[w:WORKS_FOR]->(c:Company) \
+         WHERE p.age > 30 AND c.industry = 'Technology' \
+         RETURN p.name, c.company_name, w.position \
+         ORDER BY p.age DESC \
+         LIMIT 5",
+    )
+    .unwrap()
+    .with_config(config);
+
+    let sql = query
+        .to_sql(datasets, Some(SqlDialect::Spark))
+        .await
+        .unwrap();
+
+    assert!(sql.contains('`'), "Spark SQL should use backtick quoting");
+    assert!(
+        sql.contains("ORDER BY") || sql.contains("order by"),
+        "SQL should contain ORDER BY"
+    );
+    assert!(
+        sql.contains("LIMIT") || sql.contains("limit"),
+        "SQL should contain LIMIT"
+    );
+
+    println!("Generated complex Spark SQL:\n{}", sql);
+}
+
+#[tokio::test]
+async fn test_to_sql_default_dialect_no_backticks() {
+    let config = GraphConfig::builder()
+        .with_node_label("Person", "person_id")
+        .build()
+        .unwrap();
+
+    let mut datasets = HashMap::new();
+    datasets.insert("Person".to_string(), create_person_table());
+
+    let query = CypherQuery::new("MATCH (p:Person) RETURN p.name, p.age")
+        .unwrap()
+        .with_config(config);
+
+    // Default dialect (None) should not use backticks
+    let sql = query.to_sql(datasets, None).await.unwrap();
+
+    assert!(
+        sql.to_uppercase().contains("SELECT"),
+        "SQL should contain SELECT"
+    );
+
+    println!("Generated default SQL:\n{}", sql);
+}
+
+#[tokio::test]
+async fn test_spark_sql_differs_from_default_sql() {
+    let config = GraphConfig::builder()
+        .with_node_label("Person", "person_id")
+        .build()
+        .unwrap();
+
+    let mut datasets1 = HashMap::new();
+    datasets1.insert("Person".to_string(), create_person_table());
+
+    let mut datasets2 = HashMap::new();
+    datasets2.insert("Person".to_string(), create_person_table());
+
+    let query = CypherQuery::new("MATCH (p:Person) RETURN p.name, p.age")
+        .unwrap()
+        .with_config(config);
+
+    let default_sql = query.to_sql(datasets1, None).await.unwrap();
+    let spark_sql = query
+        .to_sql(datasets2, Some(SqlDialect::Spark))
+        .await
+        .unwrap();
+
+    // Spark SQL should use backtick quoting while default SQL may not
+    assert!(
+        spark_sql.contains('`'),
+        "Spark SQL should use backtick quoting"
+    );
+
+    println!("Default SQL:\n{}", default_sql);
+    println!("\nSpark SQL:\n{}", spark_sql);
+}
+
+#[tokio::test]
+async fn test_to_sql_postgresql_dialect() {
+    let config = GraphConfig::builder()
+        .with_node_label("Person", "person_id")
+        .build()
+        .unwrap();
+
+    let mut datasets = HashMap::new();
+    datasets.insert("Person".to_string(), create_person_table());
+
+    let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
+        .unwrap()
+        .with_config(config);
+
+    let sql = query
+        .to_sql(datasets, Some(SqlDialect::PostgreSql))
+        .await
+        .unwrap();
+
+    // PostgreSQL uses double-quote identifier quoting
+    assert!(
+        sql.contains('"'),
+        "PostgreSQL SQL should use double-quote quoting"
+    );
+
+    println!("Generated PostgreSQL SQL:\n{}", sql);
+}
diff --git a/crates/lance-graph/tests/test_to_sql.rs b/crates/lance-graph/tests/test_to_sql.rs
index 1ba8a8d6..9f39d4c7 100644
--- a/crates/lance-graph/tests/test_to_sql.rs
+++ b/crates/lance-graph/tests/test_to_sql.rs
@@ -97,7 +97,7 @@ async fn test_to_sql_simple_node_scan() {
         .unwrap()
         .with_config(config);
 
-    let sql = query.to_sql(datasets).await.unwrap();
+    let sql = query.to_sql(datasets, None).await.unwrap();
 
     // Verify SQL contains expected elements
     assert!(
@@ -129,7 +129,7 @@ async fn test_to_sql_with_filter() {
         .unwrap()
         .with_config(config);
 
-    let sql = query.to_sql(datasets).await.unwrap();
+    let sql = query.to_sql(datasets, None).await.unwrap();
 
     // Verify SQL contains filter condition
     assert!(sql.contains("SELECT"), "SQL should contain SELECT");
@@ -157,7 +157,7 @@ async fn test_to_sql_with_multiple_properties() {
         .unwrap()
         .with_config(config);
 
-    let sql = query.to_sql(datasets).await.unwrap();
+    let sql = query.to_sql(datasets, None).await.unwrap();
 
     // Verify all columns are present
     assert!(sql.contains("name"), "SQL should contain name");
@@ -187,7 +187,7 @@ async fn test_to_sql_with_relationship() {
         .unwrap()
         .with_config(config);
 
-    let sql = query.to_sql(datasets).await.unwrap();
+    let sql = query.to_sql(datasets, None).await.unwrap();
 
     // Verify SQL contains join
     let sql_upper = sql.to_uppercase();
@@ -223,7 +223,7 @@ async fn test_to_sql_with_relationship_filter() {
         .unwrap()
         .with_config(config);
 
-    let sql = query.to_sql(datasets).await.unwrap();
+    let sql = query.to_sql(datasets, None).await.unwrap();
 
     // Verify SQL contains filter on relationship property
    assert!(sql.contains("salary"), "SQL should reference salary");
@@ -246,7 +246,7 @@ async fn test_to_sql_with_order_by() {
         .unwrap()
         .with_config(config);
 
-    let sql = query.to_sql(datasets).await.unwrap();
+    let sql = query.to_sql(datasets, None).await.unwrap();
 
     // Verify SQL contains ORDER BY
     assert!(
@@ -272,7 +272,7 @@ async fn test_to_sql_with_limit() {
         .unwrap()
         .with_config(config);
 
-    let sql = query.to_sql(datasets).await.unwrap();
+    let sql = query.to_sql(datasets, None).await.unwrap();
 
     // Verify SQL contains LIMIT
     assert!(
@@ -298,7 +298,7 @@ async fn test_to_sql_with_distinct() {
         .unwrap()
         .with_config(config);
 
-    let sql = query.to_sql(datasets).await.unwrap();
+    let sql = query.to_sql(datasets, None).await.unwrap();
 
     // Verify SQL is generated successfully
     // Note: DISTINCT might be optimized away by DataFusion's optimizer in some cases
@@ -323,7 +323,7 @@ async fn test_to_sql_with_alias() {
         .unwrap()
         .with_config(config);
 
-    let sql = query.to_sql(datasets).await.unwrap();
+    let sql = query.to_sql(datasets, None).await.unwrap();
 
     // Verify SQL contains aliases
     assert!(
@@ -358,7 +358,7 @@ async fn test_to_sql_complex_query() {
         .unwrap()
         .with_config(config);
 
-    let sql = query.to_sql(datasets).await.unwrap();
+    let sql = query.to_sql(datasets, None).await.unwrap();
 
     // Verify complex query elements
     assert!(sql.contains("SELECT"), "SQL should contain SELECT");
@@ -390,7 +390,7 @@ async fn test_to_sql_missing_config() {
     let query = CypherQuery::new("MATCH (p:Person) RETURN p.name").unwrap();
     // Note: No config set
 
-    let result = query.to_sql(datasets).await;
+    let result = query.to_sql(datasets, None).await;
 
     // Should fail without config
     assert!(result.is_err(), "to_sql should fail without config");
@@ -413,7 +413,7 @@ async fn test_to_sql_empty_datasets() {
         .unwrap()
         .with_config(config);
 
-    let result = query.to_sql(datasets).await;
+    let result = query.to_sql(datasets, None).await;
 
     // Should fail with empty datasets
     assert!(result.is_err(), "to_sql should fail with empty datasets");
diff --git a/python/python/lance_graph/__init__.py b/python/python/lance_graph/__init__.py
index 1d463c7d..26731960 100644
--- a/python/python/lance_graph/__init__.py
+++ b/python/python/lance_graph/__init__.py
@@ -83,6 +83,7 @@ def _load_dev_build() -> ModuleType:
 SqlQuery = _bindings.graph.SqlQuery
 SqlEngine = _bindings.graph.SqlEngine
 ExecutionStrategy = _bindings.graph.ExecutionStrategy
+SqlDialect = _bindings.graph.SqlDialect
 VectorSearch = _bindings.graph.VectorSearch
 DistanceMetric = _bindings.graph.DistanceMetric
 
@@ -101,6 +102,7 @@ def _load_dev_build() -> ModuleType:
     "SqlQuery",
     "SqlEngine",
     "ExecutionStrategy",
+    "SqlDialect",
     "VectorSearch",
     "DistanceMetric",
     "DirNamespace",
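
For reference, end-to-end use of the new dialect parameter looks like the following minimal sketch, condensed from the integration tests above (the Person schema and data are illustrative, not part of the patch):

use arrow::array::{Int32Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use lance_graph::{CypherQuery, GraphConfig, SqlDialect};
use std::collections::HashMap;
use std::sync::Arc;

#[tokio::main]
async fn main() {
    // A single node table is enough to exercise the dialect switch.
    let schema = Arc::new(Schema::new(vec![
        Field::new("person_id", DataType::Int32, false),
        Field::new("name", DataType::Utf8, false),
    ]));
    let people = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["Alice", "Bob"])),
        ],
    )
    .unwrap();

    let config = GraphConfig::builder()
        .with_node_label("Person", "person_id")
        .build()
        .unwrap();

    let mut datasets = HashMap::new();
    datasets.insert("Person".to_string(), people);

    let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
        .unwrap()
        .with_config(config);

    // `None` falls back to SqlDialect::Default; Some(SqlDialect::Spark)
    // produces backtick-quoted Spark SQL instead.
    let spark_sql = query
        .to_sql(datasets, Some(SqlDialect::Spark))
        .await
        .unwrap();
    println!("{}", spark_sql);
}

From Python, the same option surfaces as the optional `dialect` argument added to `CypherQuery.to_sql`, e.g. `query.to_sql(datasets, dialect=SqlDialect.Spark)`; omitting it keeps the previous generic-SQL behavior.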
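
A note on the `DialectUnparser` wrapper introduced in query.rs: DataFusion's `Unparser::new` borrows its dialect rather than taking ownership, so some local value must own the concrete dialect for as long as the unparser is used. The enum owns one concrete dialect per variant, and `as_unparser` hands out a borrowing `Unparser`, which keeps `to_sql` free of trait objects at the call site. A minimal sketch of that call sequence under stated assumptions (the helper name is hypothetical, and `LogicalPlan` is assumed to come from `datafusion_expr`, the crate whose plans the unparser consumes):

use datafusion_expr::LogicalPlan;
use lance_graph::SqlDialect;

// Hypothetical helper illustrating the borrow pattern used inside to_sql.
fn unparse_with(dialect: SqlDialect, plan: &LogicalPlan) -> String {
    // Bind the wrapper to a local first: it owns the dialect that the
    // Unparser below merely borrows.
    let dialect_unparser = dialect.unparser();
    let unparser = dialect_unparser.as_unparser();
    unparser
        .plan_to_sql(plan)
        .expect("failed to unparse plan")
        .to_string()
}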