Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions crates/lance-graph-python/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ use lance_graph::{
ast::{DistanceMetric as RustDistanceMetric, GraphPattern, ReadingClause},
CypherQuery as RustCypherQuery, ExecutionStrategy as RustExecutionStrategy,
GraphConfig as RustGraphConfig, GraphError as RustGraphError, InMemoryCatalog,
SqlQuery as RustSqlQuery, VectorSearch as RustVectorSearch,
SqlDialect as RustSqlDialect, SqlQuery as RustSqlQuery,
VectorSearch as RustVectorSearch,
};
use pyo3::{
exceptions::{PyNotImplementedError, PyRuntimeError, PyValueError},
Expand Down Expand Up @@ -59,6 +60,34 @@ impl From<ExecutionStrategy> for RustExecutionStrategy {
}
}

/// SQL dialect for generating SQL from Cypher queries.
///
/// Python-visible mirror of the core `lance_graph::SqlDialect` enum; values
/// are converted to the Rust-side enum via `From` before being passed to the
/// unparser. Exposed to Python as `lance.graph.SqlDialect`.
#[pyclass(name = "SqlDialect", module = "lance.graph")]
#[derive(Clone, Copy)]
pub enum SqlDialect {
    /// Generic SQL (DataFusion's default unparser dialect).
    Default,
    /// Spark SQL dialect (backtick identifier quoting, `STRING` casts, etc.).
    Spark,
    /// PostgreSQL dialect.
    PostgreSql,
    /// MySQL dialect.
    MySql,
    /// SQLite dialect.
    Sqlite,
}

impl From<SqlDialect> for RustSqlDialect {
fn from(dialect: SqlDialect) -> Self {
match dialect {
SqlDialect::Default => RustSqlDialect::Default,
SqlDialect::Spark => RustSqlDialect::Spark,
SqlDialect::PostgreSql => RustSqlDialect::PostgreSql,
SqlDialect::MySql => RustSqlDialect::MySql,
SqlDialect::Sqlite => RustSqlDialect::Sqlite,
}
}
}

/// Distance metric for vector similarity search
#[pyclass(name = "DistanceMetric", module = "lance.graph")]
#[derive(Clone, Copy)]
Expand Down Expand Up @@ -494,6 +523,8 @@ impl CypherQuery {
/// ----------
/// datasets : dict
/// Dictionary mapping table names to Lance datasets
/// dialect : SqlDialect, optional
/// SQL dialect to use. Defaults to SqlDialect.Default (generic DataFusion SQL).
///
/// Returns
/// -------
Expand All @@ -504,7 +535,15 @@ impl CypherQuery {
/// ------
/// RuntimeError
/// If SQL generation fails
fn to_sql(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
#[pyo3(signature = (datasets, dialect=None))]
fn to_sql(
&self,
py: Python,
datasets: &Bound<'_, PyDict>,
dialect: Option<SqlDialect>,
) -> PyResult<String> {
let sql_dialect = dialect.map(|d| d.into());

// Convert datasets to Arrow RecordBatch map
let arrow_datasets = python_datasets_to_batches(datasets)?;

Expand All @@ -513,7 +552,7 @@ impl CypherQuery {

// Execute via runtime
let sql = RT
.block_on(Some(py), inner_query.to_sql(arrow_datasets))?
.block_on(Some(py), inner_query.to_sql(arrow_datasets, sql_dialect))?
.map_err(graph_error_to_pyerr)?;

Ok(sql)
Expand Down Expand Up @@ -1545,6 +1584,7 @@ pub fn register_graph_module(py: Python, parent_module: &Bound<'_, PyModule>) ->
let graph_module = PyModule::new(py, "graph")?;

graph_module.add_class::<ExecutionStrategy>()?;
graph_module.add_class::<SqlDialect>()?;
graph_module.add_class::<DistanceMetric>()?;
graph_module.add_class::<GraphConfig>()?;
graph_module.add_class::<GraphConfigBuilder>()?;
Expand Down
3 changes: 2 additions & 1 deletion crates/lance-graph/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub mod parameter_substitution;
pub mod parser;
pub mod query;
pub mod semantic;
pub mod spark_dialect;
pub mod sql_catalog;
pub mod sql_query;
pub mod table_readers;
Expand All @@ -67,7 +68,7 @@ pub use lance_graph_catalog::{
#[cfg(feature = "unity-catalog")]
pub use lance_graph_catalog::{UnityCatalogConfig, UnityCatalogProvider};
pub use lance_vector_search::VectorSearch;
pub use query::{CypherQuery, ExecutionStrategy};
pub use query::{CypherQuery, ExecutionStrategy, SqlDialect};
pub use sql_query::SqlQuery;
#[cfg(feature = "delta")]
pub use table_readers::DeltaTableReader;
Expand Down
83 changes: 73 additions & 10 deletions crates/lance-graph/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,68 @@ use crate::config::GraphConfig;
use crate::error::{GraphError, Result};
use crate::logical_plan::LogicalPlanner;
use crate::parser::parse_cypher_query;
use crate::spark_dialect::build_spark_dialect;
use arrow_array::RecordBatch;
use arrow_schema::{Field, Schema, SchemaRef};
use datafusion_sql::unparser::dialect::{
CustomDialect, DefaultDialect, MySqlDialect, PostgreSqlDialect, SqliteDialect,
};
use datafusion_sql::unparser::Unparser;
use lance_graph_catalog::DirNamespace;
use lance_namespace::models::DescribeTableRequest;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

/// SQL dialect to use when generating SQL from Cypher queries.
///
/// Selected by callers of `CypherQuery::to_sql`; each variant maps to a
/// DataFusion unparser dialect via [`SqlDialect::unparser`]. `Default` is the
/// `#[default]` variant, so `Option<SqlDialect>::unwrap_or_default()` yields
/// generic DataFusion SQL when no dialect is requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqlDialect {
    /// Generic SQL (DataFusion default dialect)
    #[default]
    Default,
    /// Spark SQL dialect (backtick quoting, STRING type, EXTRACT, etc.)
    Spark,
    /// PostgreSQL dialect
    PostgreSql,
    /// MySQL dialect
    MySql,
    /// SQLite dialect
    Sqlite,
}

/// Wrapper to hold the concrete dialect type and provide an `Unparser` reference.
///
/// DataFusion's `Unparser` borrows a `&dyn Dialect`, so something must own the
/// concrete dialect value for the duration of unparsing; this enum is that
/// owner. The Spark variant is the builder-produced `CustomDialect`, held in a
/// `Box` — presumably to keep the variant sizes balanced (CustomDialect
/// carries many configuration fields); confirm if enum size ever matters.
pub enum DialectUnparser {
    Default(DefaultDialect),
    Spark(Box<CustomDialect>),
    PostgreSql(PostgreSqlDialect),
    MySql(MySqlDialect),
    Sqlite(SqliteDialect),
}

impl DialectUnparser {
    /// Borrow the owned dialect as a DataFusion [`Unparser`].
    ///
    /// The returned unparser holds a reference into `self`, so the wrapper
    /// must stay alive for as long as the unparser is in use.
    pub fn as_unparser(&self) -> Unparser<'_> {
        match self {
            Self::Default(dialect) => Unparser::new(dialect),
            // Deref through the Box to reach the concrete CustomDialect.
            Self::Spark(boxed) => Unparser::new(&**boxed),
            Self::PostgreSql(dialect) => Unparser::new(dialect),
            Self::MySql(dialect) => Unparser::new(dialect),
            Self::Sqlite(dialect) => Unparser::new(dialect),
        }
    }
}

impl SqlDialect {
/// Create a `DialectUnparser` configured for this dialect.
pub fn unparser(&self) -> DialectUnparser {
match self {
SqlDialect::Default => DialectUnparser::Default(DefaultDialect {}),
SqlDialect::Spark => DialectUnparser::Spark(Box::new(build_spark_dialect())),
SqlDialect::PostgreSql => DialectUnparser::PostgreSql(PostgreSqlDialect {}),
SqlDialect::MySql => DialectUnparser::MySql(MySqlDialect {}),
SqlDialect::Sqlite => DialectUnparser::Sqlite(SqliteDialect {}),
}
}
}

/// Normalize an Arrow schema to have lowercase field names.
///
/// This ensures that column names in the dataset match the normalized
Expand Down Expand Up @@ -280,10 +335,10 @@ impl CypherQuery {
self.explain_internal(Arc::new(catalog), ctx).await
}

/// Convert the Cypher query to a DataFusion SQL string
/// Convert the Cypher query to a SQL string in the specified dialect.
///
/// This method generates a SQL string that corresponds to the DataFusion logical plan
/// derived from the Cypher query. It uses the `datafusion-sql` unparser.
/// derived from the Cypher query, using the specified SQL dialect for unparsing.
///
/// **WARNING**: This method is experimental and the generated SQL dialect may change.
///
Expand All @@ -293,16 +348,20 @@ impl CypherQuery {
///
/// # Arguments
/// * `datasets` - HashMap of table name to RecordBatch (nodes and relationships)
/// * `dialect` - The SQL dialect to use for generating the output SQL.
/// Defaults to `SqlDialect::Default` (generic DataFusion SQL).
/// Use `SqlDialect::Spark` for Spark SQL, `SqlDialect::PostgreSql`, etc.
///
/// # Returns
/// A SQL string representing the query
/// A SQL string representing the query in the specified dialect
pub async fn to_sql(
&self,
datasets: HashMap<String, arrow::record_batch::RecordBatch>,
dialect: Option<SqlDialect>,
) -> Result<String> {
use datafusion_sql::unparser::plan_to_sql;
use std::sync::Arc;

let dialect = dialect.unwrap_or_default();
let _config = self.require_config()?;

// Build catalog and context from datasets using the helper
Expand All @@ -323,11 +382,15 @@ impl CypherQuery {
location: snafu::Location::new(file!(), line!(), column!()),
})?;

// Unparse to SQL
let sql_ast = plan_to_sql(&optimized_plan).map_err(|e| GraphError::PlanError {
message: format!("Failed to unparse plan to SQL: {}", e),
location: snafu::Location::new(file!(), line!(), column!()),
})?;
// Unparse to SQL using the specified dialect
let dialect_unparser = dialect.unparser();
let unparser = dialect_unparser.as_unparser();
let sql_ast = unparser
.plan_to_sql(&optimized_plan)
.map_err(|e| GraphError::PlanError {
message: format!("Failed to unparse plan to SQL: {}", e),
location: snafu::Location::new(file!(), line!(), column!()),
})?;

Ok(sql_ast.to_string())
}
Expand Down Expand Up @@ -1852,7 +1915,7 @@ mod tests {
.unwrap()
.with_config(cfg);

let sql = query.to_sql(datasets).await.unwrap();
let sql = query.to_sql(datasets, None).await.unwrap();
println!("Generated SQL: {}", sql);

assert!(sql.contains("SELECT"));
Expand Down
107 changes: 107 additions & 0 deletions crates/lance-graph/src/spark_dialect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Spark SQL dialect for the DataFusion unparser.
//!
//! This module provides a Spark SQL dialect built using DataFusion's
//! [`CustomDialectBuilder`].
//!
//! Key Spark SQL differences from standard SQL:
//! - Backtick (`` ` ``) identifier quoting
//! - `EXTRACT(field FROM expr)` for date field extraction
//! - `STRING` type for casting (not `VARCHAR`)
//! - `BIGINT`/`INT` for integer types
//! - `TIMESTAMP` for all timestamp types (no timezone info in cast)
//! - `LENGTH()` instead of `CHARACTER_LENGTH()`
//! - Subqueries in FROM require aliases

use datafusion_sql::sqlparser::ast::{self, Ident, ObjectName, TimezoneInfo};
use datafusion_sql::unparser::dialect::{
CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle,
};

/// Build a Spark SQL dialect using DataFusion's `CustomDialectBuilder`.
///
/// Configures the unparser for Spark's surface syntax: backtick identifier
/// quoting, `STRING` casts instead of `VARCHAR`, `EXTRACT(...)` for date
/// fields, `LENGTH()` instead of `CHARACTER_LENGTH()`, and mandatory aliases
/// on derived tables.
pub fn build_spark_dialect() -> CustomDialect {
    // Spark spells string casts as `STRING`, which sqlparser has no built-in
    // DataType for — model it as a custom type name. Used for both Utf8 and
    // LargeUtf8.
    let spark_string = ast::DataType::Custom(
        ObjectName::from(vec![Ident::new("STRING")]),
        vec![],
    );
    // Spark casts render as plain `TIMESTAMP`, with no timezone qualifier,
    // for both timezone-aware and naive Arrow timestamps.
    let spark_timestamp = ast::DataType::Timestamp(None, TimezoneInfo::None);

    CustomDialectBuilder::new()
        .with_identifier_quote_style('`')
        .with_supports_nulls_first_in_sort(true)
        .with_use_timestamp_for_date64(true)
        .with_utf8_cast_dtype(spark_string.clone())
        .with_large_utf8_cast_dtype(spark_string)
        .with_date_field_extract_style(DateFieldExtractStyle::Extract)
        .with_character_length_style(CharacterLengthStyle::Length)
        .with_int64_cast_dtype(ast::DataType::BigInt(None))
        .with_int32_cast_dtype(ast::DataType::Int(None))
        .with_timestamp_cast_dtype(spark_timestamp.clone(), spark_timestamp)
        .with_date32_cast_dtype(ast::DataType::Date)
        .with_supports_column_alias_in_table_alias(true)
        // Spark rejects unaliased subqueries in FROM.
        .with_requires_derived_table_alias(true)
        .with_full_qualified_col(false)
        .with_unnest_as_table_factor(false)
        .with_float64_ast_dtype(ast::DataType::Double(ast::ExactNumberInfo::None))
        .build()
}

#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_sql::unparser::dialect::Dialect;

    // Identifiers must be backtick-quoted regardless of the identifier text.
    #[test]
    fn test_spark_dialect_identifier_quoting() {
        let dialect = build_spark_dialect();
        assert_eq!(dialect.identifier_quote_style("table_name"), Some('`'));
        assert_eq!(dialect.identifier_quote_style("column"), Some('`'));
    }

    // Cast target types: STRING (custom) for Utf8, BIGINT/INT for integers,
    // DATE for Date32 — matching the builder configuration above.
    #[test]
    fn test_spark_dialect_type_mappings() {
        let dialect = build_spark_dialect();
        assert!(matches!(
            dialect.utf8_cast_dtype(),
            ast::DataType::Custom(..)
        ));
        assert!(matches!(
            dialect.int64_cast_dtype(),
            ast::DataType::BigInt(None)
        ));
        assert!(matches!(
            dialect.int32_cast_dtype(),
            ast::DataType::Int(None)
        ));
        assert!(matches!(dialect.date32_cast_dtype(), ast::DataType::Date));
    }

    // Spark requires subqueries in FROM to carry an alias.
    #[test]
    fn test_spark_dialect_requires_derived_table_alias() {
        let dialect = build_spark_dialect();
        assert!(dialect.requires_derived_table_alias());
    }

    // Date fields are rendered with EXTRACT(field FROM expr).
    #[test]
    fn test_spark_dialect_extract_style() {
        let dialect = build_spark_dialect();
        assert!(matches!(
            dialect.date_field_extract_style(),
            DateFieldExtractStyle::Extract
        ));
    }

    // String length is rendered as LENGTH(), not CHARACTER_LENGTH().
    #[test]
    fn test_spark_dialect_character_length_style() {
        let dialect = build_spark_dialect();
        assert!(matches!(
            dialect.character_length_style(),
            CharacterLengthStyle::Length
        ));
    }
}
Loading
Loading