diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index e73ab43401..7afbdb37c4 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -106,7 +106,6 @@ from pyiceberg.table.update import ( AddPartitionSpecUpdate, AddSchemaUpdate, - AddSnapshotUpdate, AddSortOrderUpdate, AssertCreate, AssertRefSnapshotId, @@ -127,7 +126,6 @@ ) from pyiceberg.table.update.schema import UpdateSchema from pyiceberg.table.update.snapshot import ( - _SnapshotProducer, ManageSnapshots, UpdateSnapshot, _FastAppendFiles, @@ -161,6 +159,7 @@ from duckdb import DuckDBPyConnection from pyiceberg.catalog import Catalog + from pyiceberg.table.update import UpdateTableMetadata ALWAYS_TRUE = AlwaysTrue() DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write" @@ -262,7 +261,7 @@ class Transaction: _autocommit: bool _updates: Tuple[TableUpdate, ...] _requirements: Tuple[TableRequirement, ...] - _snapshot_operations: Tuple[_SnapshotProducer, ...] + _snapshot_operations: Tuple[UpdateTableMetadata, ...] def __init__(self, table: Table, autocommit: bool = False): """Open a transaction to stage and commit changes to a table. @@ -504,8 +503,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) for data_file in data_files: append_files.append_data_file(data_file) - self._snapshot_operations += (append_files,) - def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: """ Shorthand for overwriting existing partitions with a PyArrow table. 
@@ -561,8 +558,6 @@ def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[st for data_file in data_files: append_files.append_data_file(data_file) - self._snapshot_operations += (append_files,) - def overwrite( self, df: pa.Table, @@ -620,8 +615,6 @@ def overwrite( for data_file in data_files: append_files.append_data_file(data_file) - self._snapshot_operations += (append_files,) - def delete( self, delete_filter: Union[str, BooleanExpression], @@ -716,8 +709,6 @@ def delete( if not delete_snapshot.files_affected and not delete_snapshot.rewrites_needed: warnings.warn("Delete operation did not match any records") - self._snapshot_operations += (delete_snapshot,) - def add_files( self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True ) -> None: @@ -754,8 +745,6 @@ def add_files( for data_file in data_files: append_snapshot.append_data_file(data_file) - self._snapshot_operations += (append_snapshot,) - def update_spec(self) -> UpdateSpec: """Create a new UpdateSpec to update the partitioning of the table. 
diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py
index 1574d0f80e..1d1b48f0c5 100644
--- a/pyiceberg/table/update/__init__.py
+++ b/pyiceberg/table/update/__init__.py
@@ -104,7 +104,9 @@ def _before_commit_inner(state: RetryCallState) -> None:
         def commit_inner() -> None:
             self._transaction._apply(*self._commit())
 
-        return commit_inner()
+        commit_inner()
+        # Register with the transaction only after the update has been applied without raising
+        self._transaction._snapshot_operations += (self,)
 
     def _cleanup_commit_failure(self) -> None:
         """Prepare the snapshot producer to commit against the latest version of the table after it has been updated."""
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index 69bbab527e..ea9b3d8238 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint:disable=redefined-outer-name
 import json
+from unittest.mock import Mock
 import uuid
 from copy import copy
 from typing import Any, Dict
@@ -43,6 +44,7 @@
 from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.table import (
+    ALWAYS_TRUE,
     CommitTableRequest,
     StaticTable,
     Table,
@@ -94,6 +96,7 @@
     BucketTransform,
     IdentityTransform,
 )
+from pyiceberg.typedef import Record
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -1378,3 +1381,62 @@ def test_remove_statistics_update(table_v2_with_statistics: Table) -> None:
             table_v2_with_statistics.metadata,
             (RemoveStatisticsUpdate(snapshot_id=123456789),),
         )
+
+
+def test_transaction_commit_retry(table_v1: Table, mocker: Mock) -> None:
+    """Check that a failed first commit attempt is retried with the same kinds of updates and requirements."""
+    import pyarrow as pa
+
+    mock_data_file = DataFile(
+        content=DataFileContent.DATA,
+        file_path="s3://some-path/some-file.parquet",
+        file_format=FileFormat.PARQUET,
+        partition=Record(),
+        record_count=131327,
+        file_size_in_bytes=220669226,
+        column_sizes={1: 220661854},
+        value_counts={1: 131327},
+        null_value_counts={1: 0},
+        nan_value_counts={},
+        lower_bounds={1: b"aaaaaaaaaaaaaaaa"},
+        upper_bounds={1: b"zzzzzzzzzzzzzzzz"},
+        key_metadata=b"\xde\xad\xbe\xef",
+        split_offsets=[4, 133697593],
+        equality_ids=[],
+        sort_order_id=4,
+    )
+
+    # Patch out IO of data, manifests, and metadata
+    mocker.patch("pyiceberg.io.pyarrow.write_file", return_value=[mock_data_file])
+    mocker.patch("pyiceberg.table.update.snapshot.write_manifest")
+    mocker.patch("pyiceberg.table.update.snapshot.write_manifest_list")
+    mocker.patch("pyiceberg.catalog.noop.NoopCatalog.load_table", return_value=table_v1)
+    # Fail the first call to `Table._do_commit` to force a retry; let the second call succeed
+    do_commit = mocker.patch("pyiceberg.table.Table._do_commit", side_effect=[CommitFailedException("Test"), None])
+
+    schema = pa.schema(
+        [
+            pa.field("x", pa.int64(), nullable=False),
+            pa.field("y", pa.int64(), nullable=False),
+            pa.field("z", pa.int64(), nullable=False),
+        ]
+    )
+
+    trx = table_v1.transaction()
+    with pytest.warns(UserWarning):
+        trx.delete(ALWAYS_TRUE)
+    trx.append(pa.Table.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6], "z": [7, 8, 9]}, schema=schema))
+    trx.commit_transaction()
+
+    # The first attempt failed and the second succeeded, so exactly two calls are expected
+    assert do_commit.call_count == 2, f"Expected 2 calls to _do_commit, got {do_commit.call_count}"
+
+    first_call, second_call = do_commit.call_args_list
+    for name in ("updates", "requirements"):
+        # Index instead of `.get(..., ())` so a missing keyword fails loudly rather than comparing two empty tuples
+        first, second = first_call.kwargs[name], second_call.kwargs[name]
+        assert first, f"First commit attempt carried no {name}"
+        first_types = [type(item).__name__ for item in first]
+        second_types = [type(item).__name__ for item in second]
+        # Equal type-name lists also imply equal lengths
+        assert first_types == second_types, f"{name} types mismatch: {first_types} vs {second_types}"