Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
attrs==21.2.0
bcrypt==3.2.0
black==22.12.0
black==25.11.0; python_version >= "3.9"
cached-property==1.5.2
certifi==2024.7.4
cffi>=1.17.0,<2.0.0
cffi>=1.17.0
charset-normalizer==2.0.1
cryptography==41.0.0
cryptography>=46.0.7
distro==1.5.0
docker>=5.0.0
grpcio>=1.42.0
idna==3.2
importlib-metadata==4.6.1
iniconfig==1.1.1
jsonschema==3.2.0
packaging==21.0
packaging>=21.0
paramiko==2.10.1
pluggy>=1.0.0
protobuf>=3.13.0,<7.0.0
Expand All @@ -25,13 +25,13 @@ pyrsistent==0.18.0
pytest>=7.0.0,<8.0.0
pytest-asyncio==0.21.0
pytest-docker>=3.0.0
python-dotenv==0.18.0
python-dotenv>=1.0.1
PyYAML==5.3.1
pyjwt==2.0.0
requests==2.31.0
pyjwt>=2.9.0
requests>=2.32.4
texttable==1.6.4
toml==0.10.2
typing-extensions==4.12.2
typing-extensions>=4.13.2
urllib3==1.26.6
websocket-client==0.59.0
zipp==3.19.1
Expand Down
80 changes: 70 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,41 @@
import os
import subprocess

import docker
import pytest
import ydb
from ydb import issues


def _docker_client():
    """Build a Docker SDK client that works with non-default sockets.

    ``docker.from_env()`` only honors ``DOCKER_HOST`` and a couple of fixed
    paths, so it fails on Colima / OrbStack / Docker Desktop on macOS where
    the socket lives elsewhere. Fall back to whatever the CLI's active
    context says — that's a one-shot subprocess call before any driver is
    running, so the gRPC fork race we worry about elsewhere doesn't apply.

    Returns:
        A connected ``docker.DockerClient``.

    Raises:
        RuntimeError: if neither the environment nor the active
            ``docker context`` yields a usable daemon endpoint.
    """
    try:
        return docker.from_env()
    except docker.errors.DockerException:
        # Environment-based discovery failed; fall through to the CLI context.
        pass
    try:
        host = (
            subprocess.check_output(
                ["docker", "context", "inspect", "--format", "{{.Endpoints.docker.Host}}"],
                stderr=subprocess.DEVNULL,
            )
            .decode()
            .strip()
        )
    except (subprocess.CalledProcessError, FileNotFoundError) as exc:
        raise RuntimeError(
            "Could not locate the Docker daemon socket. " "Set DOCKER_HOST or make sure `docker context` is configured."
        ) from exc
    if not host:
        # `docker context inspect` can succeed yet print nothing when the
        # active context has no docker endpoint; an empty base_url would
        # make DockerClient fail with a far less actionable error.
        raise RuntimeError(
            "Could not locate the Docker daemon socket. " "Set DOCKER_HOST or make sure `docker context` is configured."
        )
    return docker.DockerClient(base_url=host)


def pytest_addoption(parser):
"""Add custom command line options for pytest-docker compatibility."""
parser.addoption(
Expand All @@ -29,24 +60,53 @@ def docker_cleanup():


class DockerProject:
"""Compatibility wrapper for pytest-docker-compose docker_project fixture."""

def __init__(self, docker_compose, docker_services, endpoint):
"""Compatibility wrapper for pytest-docker-compose docker_project fixture.

The `stop()` method talks to the Docker daemon via the SDK (unix socket)
instead of forking a `docker compose kill` subprocess. Forking from a
python process that has active gRPC threads crashes the child with
SIGABRT on Python 3.9 (long-standing gRPC fork-handler issue), and
`stop()` is the worst offender because the driver is in the middle of
busy traffic when it's called.

`start()` still uses `docker compose up -d --force-recreate` because
recreating a YDB container after SIGKILL needs the full compose config
(in-memory PDisks lose state, plain `container.start()` can't recover
them). By the time `start()` runs, the driver's connections are already
broken and its background threads are sleeping in retry backoff, so
forking a subprocess is safe.
"""

def __init__(self, project_name, docker_compose, docker_services, endpoint):
    # Build the SDK client first so a misconfigured Docker setup fails fast,
    # before any compose state is recorded on the instance.
    self._docker = _docker_client()
    self._project_name = project_name
    self._docker_compose = docker_compose
    self._docker_services = docker_services
    self._endpoint = endpoint
    # Flipped by stop(); start() checks it to decide whether a full
    # `up -d --force-recreate` is required.
    self._stopped = False

def _ydb_container(self):
    """Look up the compose-managed ``ydb`` service container for this project.

    Matches on the standard compose labels so only this project's container
    is returned, even with several compose projects running side by side.
    """
    label_filter = [
        f"com.docker.compose.project={self._project_name}",
        "com.docker.compose.service=ydb",
    ]
    # all=True so a freshly killed (stopped) container is still found.
    matches = self._docker.containers.list(all=True, filters={"label": label_filter})
    if matches:
        return matches[0]
    raise RuntimeError(f"YDB container for compose project '{self._project_name}' not found")

def stop(self):
"""Stop all containers (marks as stopped, actual restart happens in start())."""
"""Instantly kill the YDB container (simulates network failure)."""
self._stopped = True
# Use 'kill' for instant stop (simulates network failure better than graceful stop)
self._docker_compose.execute("kill")
self._ydb_container().kill()

def start(self):
"""Restart containers and wait for YDB to be ready."""
"""Restart containers and wait until YDB is responsive."""
if self._stopped:
# After kill, we need to recreate the container to restore YDB properly
self._docker_compose.execute("up -d --force-recreate")
self._stopped = False
else:
Expand All @@ -60,9 +120,9 @@ def start(self):


@pytest.fixture(scope="module")
def docker_project(docker_services, endpoint):
def docker_project(docker_compose_project_name, docker_services, endpoint):
"""Compatibility fixture providing stop/start methods like pytest-docker-compose."""
return DockerProject(docker_services._docker_compose, docker_services, endpoint)
return DockerProject(docker_compose_project_name, docker_services._docker_compose, docker_services, endpoint)


def is_ydb_responsive(endpoint):
Expand Down
6 changes: 3 additions & 3 deletions tests/oauth2_token_exchange/test_token_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def test_oauth2_token_exchange_credentials_file():
},
},
http_request_is_expected=False,
get_token_error_text_part="Could not deserialize key data.",
get_token_error_text_part="Oauth2TokenExchangeCredentials:",
),
DataForConfigTest(
cfg={
Expand All @@ -640,7 +640,7 @@ def test_oauth2_token_exchange_credentials_file():
},
},
http_request_is_expected=False,
get_token_error_text_part="Could not deserialize key data.",
get_token_error_text_part="Oauth2TokenExchangeCredentials:",
),
DataForConfigTest(
cfg={
Expand All @@ -651,7 +651,7 @@ def test_oauth2_token_exchange_credentials_file():
},
},
http_request_is_expected=False,
get_token_error_text_part="sign() missing 1 required positional argument",
get_token_error_text_part="Oauth2TokenExchangeCredentials:",
),
DataForConfigTest(
cfg_file="~/unknown-file.cfg",
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ commands =
show-source = true
builtins = _
max-line-length = 160
ignore=E203,W503
ignore=E203,E704,W503
exclude=*_pb2.py,*_grpc.py,.venv,.git,.tox,dist,doc,*egg,ydb/public/api/protos/*,docs/*,ydb/public/api/grpc/*,persqueue/*,client/*,dbapi/*,ydb/default_pem.py,*docs/conf.py
per-file-ignores = ydb/table.py:F401

Expand Down
26 changes: 9 additions & 17 deletions ydb/_grpc/grpcwrapper/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,36 +54,31 @@ class IFromProto(abc.ABC, Generic[ProtoT, ResultT]):

@staticmethod
@abc.abstractmethod
def from_proto(msg: ProtoT) -> ResultT:
...
def from_proto(msg: ProtoT) -> ResultT: ...


class IFromProtoWithProtoType(IFromProto[ProtoT, ResultT]):
"""Extended interface that also knows which proto message type to create."""

@staticmethod
@abc.abstractmethod
def empty_proto_message() -> ProtoT:
...
def empty_proto_message() -> ProtoT: ...


class IToProto(abc.ABC):
@abc.abstractmethod
def to_proto(self) -> Message:
...
def to_proto(self) -> Message: ...


class IFromPublic(abc.ABC):
@staticmethod
@abc.abstractmethod
def from_public(o: typing.Any) -> typing.Any:
...
def from_public(o: typing.Any) -> typing.Any: ...


class IToPublic(abc.ABC):
@abc.abstractmethod
def to_public(self) -> typing.Any:
...
def to_public(self) -> typing.Any: ...


class UnknownGrpcMessageError(issues.Error):
Expand Down Expand Up @@ -148,16 +143,13 @@ async def __anext__(self):

class IGrpcWrapperAsyncIO(abc.ABC):
@abc.abstractmethod
async def receive(self, timeout: Optional[int] = None) -> Any:
...
async def receive(self, timeout: Optional[int] = None) -> Any: ...

@abc.abstractmethod
def write(self, wrap_message: IToProto):
...
def write(self, wrap_message: IToProto): ...

@abc.abstractmethod
def close(self):
...
def close(self): ...


# SupportedDriverType imported from ydb._typing
Expand Down Expand Up @@ -295,7 +287,7 @@ def from_proto(
msg: Union[
ydb_topic_pb2.StreamReadMessage.FromServer,
ydb_topic_pb2.StreamWriteMessage.FromServer,
]
],
) -> "ServerStatus":
return ServerStatus(msg.status, msg.issues)

Expand Down
2 changes: 1 addition & 1 deletion ydb/_grpc/grpcwrapper/ydb_topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def to_public(self) -> List[ydb_topic_public_types.PublicCodec]:

@staticmethod
def from_public(
codecs: Optional[List[Union[ydb_topic_public_types.PublicCodec, int]]]
codecs: Optional[List[Union[ydb_topic_public_types.PublicCodec, int]]],
) -> Optional["SupportedCodecs"]:
if codecs is None:
return None
Expand Down
6 changes: 2 additions & 4 deletions ydb/_topic_reader/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@

class ICommittable(abc.ABC):
@abc.abstractmethod
def _commit_get_partition_session(self) -> PartitionSession:
...
def _commit_get_partition_session(self) -> PartitionSession: ...

@abc.abstractmethod
def _commit_get_offsets_range(self) -> OffsetsRange:
...
def _commit_get_offsets_range(self) -> OffsetsRange: ...


class ISessionAlive(abc.ABC):
Expand Down
5 changes: 3 additions & 2 deletions ydb/_topic_writer/topic_writer_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,15 +756,16 @@ async def test_buffered_messages_on_reconnect_sent_as_single_batch(
]
await reconnector.write_with_ack_future(messages)

sent = await asyncio.wait_for(stream_writer.from_client.get(), 1)
sent = await asyncio.wait_for(stream_writer.from_client.get(), 5)
assert len(sent) == 3

# ack first message, then trigger retriable error
stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1))
stream_writer.from_server.put_nowait(issues.Overloaded("test"))

second_writer = get_stream_writer()
resent = await asyncio.wait_for(second_writer.from_client.get(), 1)
# backoff after Overloaded can sleep up to 1s, so allow generous timeout
resent = await asyncio.wait_for(second_writer.from_client.get(), 5)

# msg2 and msg3 must arrive as a single batch, not two separate sends
assert resent == [InternalMessage(messages[1]), InternalMessage(messages[2])]
Expand Down
6 changes: 2 additions & 4 deletions ydb/issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@ class _StatusResponseProtocol(Protocol):
"""Protocol for objects that have status and issues attributes."""

@property
def status(self) -> Union[StatusCode, int]:
...
def status(self) -> Union[StatusCode, int]: ...

@property
def issues(self) -> Iterable[Any]:
...
def issues(self) -> Iterable[Any]: ...


_TRANSPORT_STATUSES_FIRST = 401000
Expand Down
24 changes: 8 additions & 16 deletions ydb/query/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,12 @@ def _on_execute_stream_error(self, e: Exception) -> None:
@overload
def _create_call(
self: "BaseQuerySession[SyncDriver]", settings: Optional[BaseRequestSettings] = None
) -> "BaseQuerySession[SyncDriver]":
...
) -> "BaseQuerySession[SyncDriver]": ...

@overload
def _create_call(
self: "BaseQuerySession[AsyncDriver]", settings: Optional[BaseRequestSettings] = None
) -> Awaitable["BaseQuerySession[AsyncDriver]"]:
...
) -> Awaitable["BaseQuerySession[AsyncDriver]"]: ...

def _create_call(
self, settings: Optional[BaseRequestSettings] = None
Expand All @@ -170,14 +168,12 @@ def _create_call(
@overload
def _delete_call(
self: "BaseQuerySession[SyncDriver]", settings: Optional[BaseRequestSettings] = None
) -> "BaseQuerySession[SyncDriver]":
...
) -> "BaseQuerySession[SyncDriver]": ...

@overload
def _delete_call(
self: "BaseQuerySession[AsyncDriver]", settings: Optional[BaseRequestSettings] = None
) -> Awaitable["BaseQuerySession[AsyncDriver]"]:
...
) -> Awaitable["BaseQuerySession[AsyncDriver]"]: ...

def _delete_call(
self, settings: Optional[BaseRequestSettings] = None
Expand All @@ -197,14 +193,12 @@ def _delete_call(
@overload
def _attach_call(
self: "BaseQuerySession[SyncDriver]",
) -> GrpcStreamCall[_apis.ydb_query.SessionState]:
...
) -> GrpcStreamCall[_apis.ydb_query.SessionState]: ...

@overload
def _attach_call(
self: "BaseQuerySession[AsyncDriver]",
) -> Awaitable[GrpcStreamCall[_apis.ydb_query.SessionState]]:
...
) -> Awaitable[GrpcStreamCall[_apis.ydb_query.SessionState]]: ...

def _attach_call(
self,
Expand Down Expand Up @@ -233,8 +227,7 @@ def _execute_call(
arrow_format_settings: Optional[base.ArrowFormatSettings] = None,
concurrent_result_sets: bool = False,
settings: Optional[BaseRequestSettings] = None,
) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]:
...
) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]: ...

@overload
def _execute_call(
Expand All @@ -250,8 +243,7 @@ def _execute_call(
arrow_format_settings: Optional[base.ArrowFormatSettings] = None,
concurrent_result_sets: bool = False,
settings: Optional[BaseRequestSettings] = None,
) -> Awaitable[Iterable[_apis.ydb_query.ExecuteQueryResponsePart]]:
...
) -> Awaitable[Iterable[_apis.ydb_query.ExecuteQueryResponsePart]]: ...

def _execute_call(
self,
Expand Down
Loading
Loading