From 27fc96a47a39f34cf423743569f437b263098f52 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 31 Mar 2026 22:27:33 +0300 Subject: [PATCH] Ability to truly use pyformat style --- README.md | 105 ++++++++++++++ tests/test_connections.py | 61 ++++++++ tests/test_convert_parameters.py | 229 ++++++++++++++++++++++++++++++ ydb_dbapi/connections.py | 78 +++++++++-- ydb_dbapi/cursors.py | 28 +++- ydb_dbapi/utils.py | 233 +++++++++++++++++++++++++++++++ 6 files changed, 719 insertions(+), 15 deletions(-) create mode 100644 tests/test_convert_parameters.py diff --git a/README.md b/README.md index 5666a2c..5824376 100644 --- a/README.md +++ b/README.md @@ -47,3 +47,108 @@ async with async_connection.cursor() as cursor: rows = await cursor.fetchmany(size=5) rows = await cursor.fetchall() ``` + +## Query parameters + +### Standard mode (`pyformat=True`) + +Pass `pyformat=True` to `connect()` to enable familiar Python DB-API +parameter syntax such as `%(name)s` and `%s`. This mode is opt-in: the +default connection mode (`pyformat=False`) uses YDB-style `$name` +placeholders instead, so `%(name)s` and `%s` do not work unless +`pyformat=True` is set explicitly. The driver will convert placeholders +and infer YDB types from Python values automatically. + +**Named parameters** — `%(name)s` with a `dict`: + +```python +connection = ydb_dbapi.connect( + host="localhost", port="2136", database="/local", + pyformat=True, +) + +with connection.cursor() as cursor: + cursor.execute( + "SELECT * FROM users WHERE id = %(id)s AND active = %(active)s", + {"id": 42, "active": True}, + ) +``` + +**Positional parameters** — `%s` with a `list` or `tuple`: + +```python +with connection.cursor() as cursor: + cursor.execute( + "INSERT INTO users (id, name, score) VALUES (%s, %s, %s)", + [1, "Alice", 9.8], + ) +``` + +Use `%%` to insert a literal `%` character in the query. 
+ +The driver validates `pyformat` placeholders before executing the query: +- do not mix named `%(name)s` and positional `%s` placeholders in one query; +- named placeholders require a `dict` with bare keys like `{"id": 1}`; +- positional placeholders require a `list` or `tuple`; +- missing or extra parameters raise `ProgrammingError`; +- keys starting with `$` are not allowed in `pyformat=True` mode. + +**Automatic type mapping:** + +| Python type | YDB type | +|--------------------|-------------| +| `bool` | `Bool` | +| `int` | `Int64` | +| `float` | `Double` | +| `str` | `Utf8` | +| `bytes` | `String` | +| `datetime.datetime`| `Timestamp` | +| `datetime.date` | `Date` | +| `datetime.timedelta`| `Interval` | +| `decimal.Decimal` | `Decimal(22, 9)` | +| `None` | `NULL` (passed as-is) | + +**Explicit types with `ydb.TypedValue`:** + +When automatic inference is not suitable (e.g. you need `Int32` instead of +`Int64`, or `Json`), wrap the value in `ydb.TypedValue` — it will be passed +through unchanged: + +```python +import ydb + +with connection.cursor() as cursor: + cursor.execute( + "INSERT INTO events (id, payload) VALUES (%(id)s, %(payload)s)", + { + "id": ydb.TypedValue(99, ydb.PrimitiveType.Int32), + "payload": ydb.TypedValue('{"key": "value"}', ydb.PrimitiveType.Json), + }, + ) +``` + +If the driver cannot infer a YDB type for a Python value, it raises +`TypeError`. Use `ydb.TypedValue` for such values or when you need an explicit +YDB type. + +### Native YDB mode (default, deprecated) + +> **Deprecated.** Native YDB mode is the current default for backwards +> compatibility, but it will be removed in a future release. Migrate to +> `pyformat=True` at your earliest convenience. + +By default (`pyformat=False`) the driver passes the query and parameters +directly to the YDB SDK without any transformation. 
Use `$name` placeholders +in the query and supply a `dict` with `$`-prefixed keys: + +```python +connection = ydb_dbapi.connect( + host="localhost", port="2136", database="/local", +) + +with connection.cursor() as cursor: + cursor.execute( + "SELECT * FROM users WHERE id = $id", + {"$id": ydb.TypedValue(42, ydb.PrimitiveType.Int64)}, + ) +``` diff --git a/tests/test_connections.py b/tests/test_connections.py index 0eced67..d87f9ed 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -131,6 +131,47 @@ def _test_cursor_raw_query(self, connection: dbapi.Connection) -> None: maybe_await(cur.close()) + def _test_cursor_pyformat_query( + self, + connection: dbapi.Connection, + ) -> None: + cur = connection.cursor() + + with suppress(dbapi.DatabaseError): + maybe_await(cur.execute_scheme("DROP TABLE test_pyformat")) + + maybe_await( + cur.execute_scheme( + """ + CREATE TABLE test_pyformat( + id Int64 NOT NULL, + text Utf8, + PRIMARY KEY (id) + ) + """ + ) + ) + + maybe_await( + cur.execute( + "UPSERT INTO test_pyformat(id, text) VALUES (%s, %s)", + [17, "seventeen"], + ) + ) + maybe_await( + cur.execute( + "SELECT text FROM test_pyformat WHERE id = %(id)s", + {"id": 17}, + ) + ) + + row = cur.fetchone() + assert row is not None + assert row[0] == "seventeen" + + maybe_await(cur.execute_scheme("DROP TABLE test_pyformat")) + maybe_await(cur.close()) + def _test_errors( self, connection: dbapi.Connection, @@ -431,6 +472,13 @@ def test_connection(self, connection: dbapi.Connection) -> None: def test_cursor_raw_query(self, connection: dbapi.Connection) -> None: self._test_cursor_raw_query(connection) + def test_cursor_pyformat_query(self, connection_kwargs: dict) -> None: + connection = dbapi.connect(**{**connection_kwargs, "pyformat": True}) + try: + self._test_cursor_pyformat_query(connection) + finally: + connection.close() + def test_errors(self, connection: dbapi.Connection) -> None: self._test_errors(connection) @@ -542,6 +590,19 @@ async def 
test_cursor_raw_query( ) -> None: await greenlet_spawn(self._test_cursor_raw_query, connection) + @pytest.mark.asyncio + async def test_cursor_pyformat_query( + self, + connection_kwargs: dict, + ) -> None: + connection = await dbapi.async_connect( + **{**connection_kwargs, "pyformat": True} + ) + try: + await greenlet_spawn(self._test_cursor_pyformat_query, connection) + finally: + await connection.close() + @pytest.mark.asyncio async def test_errors(self, connection: dbapi.AsyncConnection) -> None: await greenlet_spawn(self._test_errors, connection) diff --git a/tests/test_convert_parameters.py b/tests/test_convert_parameters.py new file mode 100644 index 0000000..c0322fe --- /dev/null +++ b/tests/test_convert_parameters.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +import datetime +import decimal + +import pytest +import ydb +from ydb_dbapi.errors import ProgrammingError +from ydb_dbapi.utils import convert_query_parameters + + +class TestNamedStyle: + """%(name)s placeholders with a dict.""" + + def test_basic_query_transformation(self): + q, _ = convert_query_parameters( + "SELECT %(id)s FROM t WHERE name = %(name)s", + {"id": 1, "name": "alice"}, + ) + assert q == "SELECT $id FROM t WHERE name = $name" + + def test_keys_prefixed_with_dollar(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": 1}) + assert "$x" in p + assert "x" not in p + + def test_int(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": 42}) + assert p["$x"] == ydb.TypedValue(42, ydb.PrimitiveType.Int64) + + def test_float(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": 3.14}) + assert p["$x"] == ydb.TypedValue(3.14, ydb.PrimitiveType.Double) + + def test_str(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": "hello"}) + assert p["$x"] == ydb.TypedValue("hello", ydb.PrimitiveType.Utf8) + + def test_bytes(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": b"data"}) + assert p["$x"] == ydb.TypedValue(b"data", 
ydb.PrimitiveType.String) + + def test_bool(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": True}) + assert p["$x"] == ydb.TypedValue(True, ydb.PrimitiveType.Bool) + + def test_bool_not_confused_with_int(self): + # bool is subclass of int — must map to Bool, not Int64 + _, p = convert_query_parameters("SELECT %(x)s", {"x": False}) + assert p["$x"].value_type == ydb.PrimitiveType.Bool + + def test_date(self): + d = datetime.date(2024, 1, 15) + _, p = convert_query_parameters("SELECT %(x)s", {"x": d}) + assert p["$x"] == ydb.TypedValue(d, ydb.PrimitiveType.Date) + + def test_datetime(self): + tz = datetime.timezone.utc + dt = datetime.datetime(2024, 1, 15, 12, 0, 0, tzinfo=tz) + _, p = convert_query_parameters("SELECT %(x)s", {"x": dt}) + assert p["$x"] == ydb.TypedValue(dt, ydb.PrimitiveType.Timestamp) + + def test_datetime_not_confused_with_date(self): + # datetime is subclass of date — must map to Timestamp, not Date + tz = datetime.timezone.utc + dt = datetime.datetime(2024, 6, 1, 0, 0, 0, tzinfo=tz) + _, p = convert_query_parameters("SELECT %(x)s", {"x": dt}) + assert p["$x"].value_type == ydb.PrimitiveType.Timestamp + + def test_timedelta(self): + td = datetime.timedelta(seconds=60) + _, p = convert_query_parameters("SELECT %(x)s", {"x": td}) + assert p["$x"] == ydb.TypedValue(td, ydb.PrimitiveType.Interval) + + def test_decimal(self): + d = decimal.Decimal("3.14") + _, p = convert_query_parameters("SELECT %(x)s", {"x": d}) + assert isinstance(p["$x"], ydb.TypedValue) + assert p["$x"].value == d + + def test_none_passed_as_is(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": None}) + assert p["$x"] is None + + def test_unknown_type_raises_type_error(self): + obj = object() + with pytest.raises(TypeError, match="Could not infer YDB type"): + convert_query_parameters("SELECT %(x)s", {"x": obj}) + + def test_unknown_type_error_does_not_leak_value_repr(self): + class SecretValue: + def __repr__(self) -> str: + return "secret-token" + + 
with pytest.raises(TypeError) as exc_info: + convert_query_parameters("SELECT %(x)s", {"x": SecretValue()}) + + assert "secret-token" not in str(exc_info.value) + assert "SecretValue" in str(exc_info.value) + + def test_multiple_params(self): + q, p = convert_query_parameters( + "INSERT INTO t VALUES (%(a)s, %(b)s, %(c)s)", + {"a": 1, "b": "hi", "c": True}, + ) + assert q == "INSERT INTO t VALUES ($a, $b, $c)" + assert "$a" in p + assert "$b" in p + assert "$c" in p + + def test_percent_percent_escape(self): + q, _ = convert_query_parameters( + "SELECT %% as pct, %(x)s", {"x": 1} + ) + assert q == "SELECT % as pct, $x" + + def test_empty_params(self): + q, p = convert_query_parameters("SELECT 1", {}) + assert q == "SELECT 1" + assert p == {} + + def test_named_requires_dict(self): + with pytest.raises( + ProgrammingError, + match="require parameters to be a dict", + ): + convert_query_parameters("SELECT %(x)s", [1]) + + def test_named_missing_key_raises(self): + with pytest.raises(ProgrammingError, match="missing keys: 'y'"): + convert_query_parameters("SELECT %(x)s, %(y)s", {"x": 1}) + + def test_named_extra_key_raises(self): + with pytest.raises(ProgrammingError, match="unexpected keys: 'y'"): + convert_query_parameters("SELECT %(x)s", {"x": 1, "y": 2}) + + def test_named_dollar_prefixed_key_raises(self): + with pytest.raises( + ProgrammingError, + match="must not start with '\\$'", + ): + convert_query_parameters("SELECT %(x)s", {"$x": 1}) + + def test_mixing_named_and_positional_raises(self): + with pytest.raises(ProgrammingError, match="Mixing named"): + convert_query_parameters("SELECT %(x)s, %s", {"x": 1}) + + def test_invalid_percent_sequence_raises(self): + with pytest.raises( + ProgrammingError, + match="Invalid pyformat placeholder", + ): + convert_query_parameters("SELECT %%%", []) + + +class TestCustomTypes: + """Pass-through for ydb.TypedValue (explicit type hint).""" + + def test_typed_value_passed_through(self): + tv = ydb.TypedValue(42, 
ydb.PrimitiveType.Int32) + _, p = convert_query_parameters("SELECT %(x)s", {"x": tv}) + assert p["$x"] is tv + + def test_typed_value_not_double_wrapped(self): + tv = ydb.TypedValue("hello", ydb.PrimitiveType.Utf8) + _, p = convert_query_parameters("SELECT %(x)s", {"x": tv}) + assert isinstance(p["$x"], ydb.TypedValue) + assert p["$x"].value_type == ydb.PrimitiveType.Utf8 + + def test_typed_value_positional(self): + tv = ydb.TypedValue(99, ydb.PrimitiveType.Int32) + _, p = convert_query_parameters("SELECT %s", [tv]) + assert p["$p1"] is tv + + def test_unknown_type_raises_type_error(self): + val = object() + with pytest.raises(TypeError, match="Could not infer YDB type"): + convert_query_parameters("SELECT %(x)s", {"x": val}) + + +class TestPositionalStyle: + """Positional %s placeholders with a list or tuple.""" + + def test_basic_list(self): + q, p = convert_query_parameters("SELECT %s", [42]) + assert q == "SELECT $p1" + assert p["$p1"] == ydb.TypedValue(42, ydb.PrimitiveType.Int64) + + def test_basic_tuple(self): + q, p = convert_query_parameters("SELECT %s", (42,)) + assert q == "SELECT $p1" + assert p["$p1"] == ydb.TypedValue(42, ydb.PrimitiveType.Int64) + + def test_multiple_params_numbered_sequentially(self): + q, p = convert_query_parameters( + "INSERT INTO t VALUES (%s, %s, %s)", [1, "hi", 3.14] + ) + assert q == "INSERT INTO t VALUES ($p1, $p2, $p3)" + assert p["$p1"] == ydb.TypedValue(1, ydb.PrimitiveType.Int64) + assert p["$p2"] == ydb.TypedValue("hi", ydb.PrimitiveType.Utf8) + assert p["$p3"] == ydb.TypedValue(3.14, ydb.PrimitiveType.Double) + + def test_none_passed_as_is(self): + _, p = convert_query_parameters("SELECT %s", [None]) + assert p["$p1"] is None + + def test_percent_percent_escape(self): + q, p = convert_query_parameters("SELECT %%, %s", [7]) + assert q == "SELECT %, $p1" + assert p["$p1"] == ydb.TypedValue(7, ydb.PrimitiveType.Int64) + + def test_empty_list(self): + q, p = convert_query_parameters("SELECT 1", []) + assert q == "SELECT 1" 
+ assert p == {} + + def test_positional_requires_sequence(self): + with pytest.raises( + ProgrammingError, match="require parameters to be a list or tuple" + ): + convert_query_parameters("SELECT %s", {"x": 1}) + + def test_positional_count_mismatch_raises(self): + with pytest.raises(ProgrammingError, match="expected 2, got 1"): + convert_query_parameters("SELECT %s, %s", [1]) + + def test_positional_extra_args_raises(self): + with pytest.raises(ProgrammingError, match="expected 1, got 2"): + convert_query_parameters("SELECT %s", [1, 2]) diff --git a/ydb_dbapi/connections.py b/ydb_dbapi/connections.py index 1ee215b..cc7a117 100644 --- a/ydb_dbapi/connections.py +++ b/ydb_dbapi/connections.py @@ -73,7 +73,7 @@ class BaseConnection: def __init__( self, host: str = "", - port: str = "", + port: int | str = "", database: str = "", ydb_table_path_prefix: str = "", protocol: str | None = None, @@ -82,6 +82,7 @@ def __init__( root_certificates_path: str | None = None, root_certificates: str | None = None, driver_config_kwargs: dict | None = None, + pyformat: bool = False, **kwargs: Any, ) -> None: protocol = protocol if protocol else "grpc" @@ -89,6 +90,7 @@ def __init__( self.credentials = prepare_credentials(credentials) self.database = database self.table_path_prefix = ydb_table_path_prefix + self.pyformat = pyformat self.connection_kwargs: dict = kwargs @@ -207,15 +209,16 @@ class Connection(BaseConnection): def __init__( self, host: str = "", - port: str = "", + port: int | str = "", database: str = "", ydb_table_path_prefix: str = "", protocol: str | None = None, - credentials: ydb.Credentials | None = None, + credentials: ydb.Credentials | dict | str | None = None, ydb_session_pool: SessionPool | AsyncSessionPool | None = None, root_certificates_path: str | None = None, root_certificates: str | None = None, driver_config_kwargs: dict | None = None, + pyformat: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -229,6 +232,7 @@ def __init__( 
root_certificates_path=root_certificates_path, root_certificates=root_certificates, driver_config_kwargs=driver_config_kwargs, + pyformat=pyformat, **kwargs, ) self._current_cursor: Cursor | None = None @@ -242,6 +246,7 @@ def cursor(self) -> Cursor: table_path_prefix=self.table_path_prefix, request_settings=self.request_settings, retry_settings=self.retry_settings, + pyformat=self.pyformat, ) def wait_ready(self, timeout: int = 10) -> None: @@ -402,15 +407,16 @@ class AsyncConnection(BaseConnection): def __init__( self, host: str = "", - port: str = "", + port: int | str = "", database: str = "", ydb_table_path_prefix: str = "", protocol: str | None = None, - credentials: ydb.Credentials | None = None, + credentials: ydb.Credentials | dict | str | None = None, ydb_session_pool: SessionPool | AsyncSessionPool | None = None, root_certificates_path: str | None = None, root_certificates: str | None = None, driver_config_kwargs: dict | None = None, + pyformat: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -424,6 +430,7 @@ def __init__( root_certificates_path=root_certificates_path, root_certificates=root_certificates, driver_config_kwargs=driver_config_kwargs, + pyformat=pyformat, **kwargs, ) self._current_cursor: AsyncCursor | None = None @@ -437,6 +444,7 @@ def cursor(self) -> AsyncCursor: table_path_prefix=self.table_path_prefix, request_settings=self.request_settings, retry_settings=self.retry_settings, + pyformat=self.pyformat, ) async def wait_ready(self, timeout: int = 10) -> None: @@ -593,13 +601,65 @@ async def _invalidate_session(self) -> None: self._session = None -def connect(*args: Any, **kwargs: Any) -> Connection: - conn = Connection(*args, **kwargs) +def connect( + host: str = "", + port: int | str = "", + database: str = "", + ydb_table_path_prefix: str = "", + protocol: str | None = None, + credentials: ydb.Credentials | dict | str | None = None, + ydb_session_pool: SessionPool | AsyncSessionPool | None = None, + root_certificates_path: 
str | None = None, + root_certificates: str | None = None, + driver_config_kwargs: dict | None = None, + pyformat: bool = False, + **kwargs: Any, +) -> Connection: + conn = Connection( + host=host, + port=port, + database=database, + ydb_table_path_prefix=ydb_table_path_prefix, + protocol=protocol, + credentials=credentials, + ydb_session_pool=ydb_session_pool, + root_certificates_path=root_certificates_path, + root_certificates=root_certificates, + driver_config_kwargs=driver_config_kwargs, + pyformat=pyformat, + **kwargs, + ) conn.wait_ready() return conn -async def async_connect(*args: Any, **kwargs: Any) -> AsyncConnection: - conn = AsyncConnection(*args, **kwargs) +async def async_connect( + host: str = "", + port: int | str = "", + database: str = "", + ydb_table_path_prefix: str = "", + protocol: str | None = None, + credentials: ydb.Credentials | dict | str | None = None, + ydb_session_pool: SessionPool | AsyncSessionPool | None = None, + root_certificates_path: str | None = None, + root_certificates: str | None = None, + driver_config_kwargs: dict | None = None, + pyformat: bool = False, + **kwargs: Any, +) -> AsyncConnection: + conn = AsyncConnection( + host=host, + port=port, + database=database, + ydb_table_path_prefix=ydb_table_path_prefix, + protocol=protocol, + credentials=credentials, + ydb_session_pool=ydb_session_pool, + root_certificates_path=root_certificates_path, + root_certificates=root_certificates, + driver_config_kwargs=driver_config_kwargs, + pyformat=pyformat, + **kwargs, + ) await conn.wait_ready() return conn diff --git a/ydb_dbapi/cursors.py b/ydb_dbapi/cursors.py index 7060401..fa04936 100644 --- a/ydb_dbapi/cursors.py +++ b/ydb_dbapi/cursors.py @@ -19,6 +19,7 @@ from .errors import InterfaceError from .errors import ProgrammingError from .utils import CursorStatus +from .utils import convert_query_parameters from .utils import handle_ydb_errors from .utils import maybe_get_current_trace_id @@ -26,13 +27,17 @@ from .connections 
import AsyncConnection from .connections import Connection - ParametersType = dict[ - str, - Union[ - Any, - tuple[Any, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]], - ydb.TypedValue, + ParametersType = Union[ + dict[ + str, + Union[ + Any, + tuple[Any, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]], + ydb.TypedValue, + ], ], + list[Any], + tuple[Any, ...], ] @@ -202,6 +207,7 @@ def __init__( retry_settings: ydb.RetrySettings, tx_context: ydb.QueryTxContext | None = None, table_path_prefix: str = "", + pyformat: bool = False, ) -> None: super().__init__() self._connection = connection @@ -211,6 +217,7 @@ def __init__( self._retry_settings = retry_settings self._tx_context = tx_context self._table_path_prefix = table_path_prefix + self._pyformat = pyformat self._stream: Iterator | None = None def fetchone(self) -> tuple | None: @@ -328,6 +335,10 @@ def execute( self._raise_if_running() query = self._append_table_path_prefix(query) + + if self._pyformat and parameters is not None: + query, parameters = convert_query_parameters(query, parameters) + self._begin_query() if self._tx_context is not None: @@ -379,6 +390,7 @@ def __init__( retry_settings: ydb.RetrySettings, tx_context: ydb.aio.QueryTxContext | None = None, table_path_prefix: str = "", + pyformat: bool = False, ) -> None: super().__init__() self._connection = connection @@ -388,6 +400,7 @@ def __init__( self._retry_settings = retry_settings self._tx_context = tx_context self._table_path_prefix = table_path_prefix + self._pyformat = pyformat self._stream: AsyncIterator | None = None def fetchone(self) -> tuple | None: @@ -506,6 +519,9 @@ async def execute( query = self._append_table_path_prefix(query) + if self._pyformat and parameters is not None: + query, parameters = convert_query_parameters(query, parameters) + self._begin_query() if self._tx_context is not None: diff --git a/ydb_dbapi/utils.py b/ydb_dbapi/utils.py index 5e0e1f0..7cc62e8 100644 --- a/ydb_dbapi/utils.py +++ b/ydb_dbapi/utils.py 
import datetime
import decimal
import re

# Ordered (python_type, ydb_type) pairs, scanned top to bottom with
# isinstance(). Order matters: bool must come before int and datetime
# before date, because each is a subclass of the type that follows it and
# would otherwise be captured by the broader entry.
_PYTHON_TO_YDB_TYPE: list[tuple[type, Any]] = [
    (bool, ydb.PrimitiveType.Bool),
    (int, ydb.PrimitiveType.Int64),
    (float, ydb.PrimitiveType.Double),
    (str, ydb.PrimitiveType.Utf8),
    (bytes, ydb.PrimitiveType.String),
    (datetime.datetime, ydb.PrimitiveType.Timestamp),
    (datetime.date, ydb.PrimitiveType.Date),
    (datetime.timedelta, ydb.PrimitiveType.Interval),
    (decimal.Decimal, ydb.DecimalType(22, 9)),
]

# A single alternation recognises every "%" construct: "%%" (escaped
# percent), "%(name)s" (named; the name is captured in group 1), "%s"
# (positional) and, as the last alternative, a bare "%" that fit none of
# the previous forms and is therefore reported as invalid.
_PYFORMAT_TOKEN_RE = re.compile(r"%%|%\((\w+)\)s|%s|%")


def _infer_ydb_type(value: Any) -> Any:
    """Return the YDB type mapped to *value*'s Python type.

    The first entry of ``_PYTHON_TO_YDB_TYPE`` whose Python type matches
    via ``isinstance`` wins. Returns ``None`` when no entry matches.
    """
    for python_type, ydb_type in _PYTHON_TO_YDB_TYPE:
        if isinstance(value, python_type):
            return ydb_type
    return None


def _wrap_value(value: Any) -> Any:
    """Wrap a Python value in ``ydb.TypedValue`` using the inferred type.

    ``ydb.TypedValue`` instances are returned as-is so callers can supply
    an explicit type for values whose type cannot be inferred
    automatically. ``None`` is also returned unchanged.

    Raises:
        TypeError: if no YDB type can be inferred for *value*.
    """
    if isinstance(value, ydb.TypedValue):
        return value
    if value is None:
        return value
    ydb_type = _infer_ydb_type(value)
    if ydb_type is not None:
        return ydb.TypedValue(value, ydb_type)
    # Only the type name goes into the message, never the value itself.
    message = (
        "Could not infer YDB type for parameter of type "
        f"{type(value).__name__}. "
        "Wrap it in ydb.TypedValue to pass an explicit type."
    )
    raise TypeError(message)


def _extract_pyformat_placeholders(query: str) -> tuple[list[str], int]:
    """Scan *query* and report which pyformat placeholders it contains.

    Returns:
        A pair ``(named_placeholders, positional_count)``: the names of all
        ``%(name)s`` placeholders in order of appearance (duplicates kept)
        and the number of ``%s`` placeholders.

    Raises:
        ProgrammingError: on a bare ``%`` that is not part of ``%%``,
            ``%(name)s`` or ``%s``, or when named and positional
            placeholders are mixed in the same query.
    """
    named_placeholders: list[str] = []
    positional_count = 0

    for match in _PYFORMAT_TOKEN_RE.finditer(query):
        placeholder = match.group(0)
        if placeholder == "%":
            raise ProgrammingError(
                "Invalid pyformat placeholder in query. "
                "Use %% for a literal percent, %(name)s for named parameters, "
                "or %s for positional parameters."
            )
        if placeholder == "%%":
            continue
        if placeholder == "%s":
            positional_count += 1
            continue
        # The only remaining alternative is %(name)s; group 1 is the name.
        named_placeholders.append(match.group(1))

    if named_placeholders and positional_count:
        raise ProgrammingError(
            "Mixing named (%(name)s) and positional (%s) placeholders "
            "is not supported."
        )

    return named_placeholders, positional_count


def _convert_pyformat_query(query: str) -> str:
    """Rewrite pyformat placeholders in *query* as ``$``-prefixed names.

    ``%%`` becomes ``%``, the n-th ``%s`` becomes ``$pN`` (numbered from 1)
    and ``%(name)s`` becomes ``$name``. The query is not validated here;
    ``_extract_pyformat_placeholders`` performs the validation.
    """
    positional_index = 0

    def replace(match: re.Match) -> str:
        nonlocal positional_index
        placeholder = match.group(0)
        if placeholder == "%%":
            return "%"
        if placeholder == "%s":
            positional_index += 1
            return f"$p{positional_index}"
        return f"${match.group(1)}"

    return _PYFORMAT_TOKEN_RE.sub(replace, query)


def _validate_positional_parameters(
    parameters: dict | list | tuple,
    positional_count: int,
) -> None:
    """Check *parameters* against a query with ``%s`` placeholders.

    Raises:
        ProgrammingError: if *parameters* is not a list or tuple, or its
            length differs from *positional_count*.
    """
    if not isinstance(parameters, (list, tuple)):
        raise ProgrammingError(
            "Positional placeholders (%s) require parameters "
            "to be a list or tuple."
        )
    if len(parameters) == positional_count:
        return

    message = (
        "Positional placeholder count does not match the number "
        f"of parameters: expected {positional_count}, got {len(parameters)}."
    )
    raise ProgrammingError(message)


def _validate_named_parameters(
    parameters: dict | list | tuple,
    named_placeholders: list[str],
) -> dict:
    """Check *parameters* against a query with ``%(name)s`` placeholders.

    Returns:
        *parameters* unchanged when it is a dict whose keys are exactly
        the placeholder names.

    Raises:
        ProgrammingError: if *parameters* is not a dict, if any string key
            starts with ``$``, or if the key set differs from the set of
            placeholder names.
    """
    if not isinstance(parameters, dict):
        raise ProgrammingError(
            "Named placeholders (%(name)s) require parameters "
            "to be a dict."
        )

    # Scan the dict itself (insertion order) rather than a set of its keys,
    # so that with several offending keys the one named in the error is the
    # same on every run instead of depending on string hash randomization.
    invalid_key = next(
        (
            name
            for name in parameters
            if isinstance(name, str) and name.startswith("$")
        ),
        None,
    )
    if invalid_key is not None:
        message = (
            "Mapping parameter names must not start with '$' when using "
            "pyformat style; "
            f"got key {invalid_key!r}. Use bare names instead."
        )
        raise ProgrammingError(message)

    expected_keys = set(named_placeholders)
    actual_keys = set(parameters)
    missing_keys = expected_keys - actual_keys
    extra_keys = actual_keys - expected_keys
    if not (missing_keys or extra_keys):
        return parameters

    # Sort the reprs so the message is stable regardless of set order.
    details: list[str] = []
    if missing_keys:
        details.append(
            "missing keys: "
            + ", ".join(sorted(repr(key) for key in missing_keys))
        )
    if extra_keys:
        details.append(
            "unexpected keys: "
            + ", ".join(sorted(repr(key) for key in extra_keys))
        )
    raise ProgrammingError(
        "Named placeholders do not match the provided parameters: "
        + "; ".join(details)
        + "."
    )


def _validate_pyformat_parameters(
    parameters: dict | list | tuple,
    named_placeholders: list[str],
    positional_count: int,
) -> dict | list | tuple:
    """Validate *parameters* against the placeholders found in a query.

    Dispatches to the positional or named validator depending on which
    kind of placeholder the query uses. For a query with no placeholders,
    only an empty dict, list or tuple is accepted.

    Returns:
        *parameters* unchanged when validation succeeds.

    Raises:
        ProgrammingError: if the parameters do not fit the placeholders,
            or if parameters are supplied for a query without
            placeholders, or are not a dict, list or tuple.
    """
    if positional_count:
        _validate_positional_parameters(parameters, positional_count)
        return parameters
    if named_placeholders:
        return _validate_named_parameters(parameters, named_placeholders)
    if parameters:
        raise ProgrammingError(
            "Query does not contain pyformat placeholders, "
            "but parameters were provided."
        )
    # An empty value of the wrong kind ("", 0, None, ...) passes the
    # truthiness check above; reject it here so the caller does not fail
    # later with an AttributeError on ``.items()``.
    if not isinstance(parameters, (dict, list, tuple)):
        raise ProgrammingError(
            "pyformat parameters must be a dict, list or tuple."
        )
    return parameters


def convert_query_parameters(
    query: str,
    parameters: dict | list | tuple,
) -> tuple[str, dict]:
    """Convert pyformat-style query and parameters to YDB format.

    Supports two parameter styles:

    Named (``%(name)s``) with a mapping::

        convert_query_parameters(
            "SELECT %(id)s", {"id": 42}
        )
        # -> ("SELECT $id", {"$id": TypedValue(42, Int64)})

    Positional (``%s``) with a sequence::

        convert_query_parameters(
            "SELECT %s, %s", [42, "hi"]
        )
        # -> ("SELECT $p1, $p2", {"$p1": TypedValue(42, Int64),
        #                         "$p2": TypedValue("hi", Utf8)})

    ``%%`` is converted to a literal ``%`` in both modes.

    Python-to-YDB type mapping:
        bool      -> Bool
        int       -> Int64
        float     -> Double
        str       -> Utf8
        bytes     -> String
        datetime  -> Timestamp
        date      -> Date
        timedelta -> Interval
        Decimal   -> Decimal(22, 9)
        None      -> passed as-is (NULL)

    ``ydb.TypedValue`` instances are passed through unchanged.

    Returns:
        A pair ``(converted_query, converted_params)`` where the query uses
        ``$``-prefixed placeholders and the dict maps those ``$``-prefixed
        names to the wrapped values.

    Raises:
        ProgrammingError: if the query contains an invalid ``%`` sequence,
            mixes named and positional placeholders, or the parameters do
            not match the placeholders.
        TypeError: if a YDB type cannot be inferred for a parameter value.
    """
    named_placeholders, positional_count = _extract_pyformat_placeholders(
        query
    )
    # Validate before rewriting so that a bad query or parameter set never
    # yields a partially converted result.
    parameters = _validate_pyformat_parameters(
        parameters, named_placeholders, positional_count
    )
    converted_query = _convert_pyformat_query(query)

    if isinstance(parameters, (list, tuple)):
        # Numbering starts at 1 to match the $pN names emitted by
        # _convert_pyformat_query.
        converted_params = {
            f"$p{index}": _wrap_value(value)
            for index, value in enumerate(parameters, start=1)
        }
    else:
        converted_params = {
            f"${name}": _wrap_value(value)
            for name, value in parameters.items()
        }

    return converted_query, converted_params