From ec27e15dd56b661ccbe9619fcf943a1fd178de89 Mon Sep 17 00:00:00 2001 From: Ford Date: Fri, 19 Dec 2025 12:17:39 -0800 Subject: [PATCH 1/6] Add base infrastructure for loader test generalization Foundational work to enable all loader tests to inherit common test patterns --- .../loaders/backends/test_postgresql.py | 345 ++++++++++++++++++ tests/integration/loaders/conftest.py | 165 +++++++++ tests/integration/loaders/test_base_loader.py | 142 +++++++ 3 files changed, 652 insertions(+) create mode 100644 tests/integration/loaders/backends/test_postgresql.py create mode 100644 tests/integration/loaders/conftest.py create mode 100644 tests/integration/loaders/test_base_loader.py diff --git a/tests/integration/loaders/backends/test_postgresql.py b/tests/integration/loaders/backends/test_postgresql.py new file mode 100644 index 0000000..f45e317 --- /dev/null +++ b/tests/integration/loaders/backends/test_postgresql.py @@ -0,0 +1,345 @@ +""" +PostgreSQL-specific loader integration tests. + +This module provides PostgreSQL-specific test configuration and tests that +inherit from the generalized base test classes. +""" + +import time +from typing import Any, Dict, List, Optional + +import pytest + +try: + from src.amp.loaders.implementations.postgresql_loader import PostgreSQLLoader + from tests.integration.loaders.conftest import LoaderTestConfig + from tests.integration.loaders.test_base_loader import BaseLoaderTests +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +class PostgreSQLTestConfig(LoaderTestConfig): + """PostgreSQL-specific test configuration""" + + loader_class = PostgreSQLLoader + config_fixture_name = 'postgresql_test_config' + + supports_overwrite = True + supports_streaming = True + supports_multi_network = True + supports_null_values = True + + def get_row_count(self, loader: PostgreSQLLoader, table_name: str) -> int: + """Get row count from PostgreSQL table""" + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {table_name}') + return cur.fetchone()[0] + finally: + loader.pool.putconn(conn) + + def query_rows( + self, loader: PostgreSQLLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from PostgreSQL table""" + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + # Get column names first + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = %s + ORDER BY ordinal_position + """, + (table_name,), + ) + columns = [row[0] for row in cur.fetchall()] + + # Build query + query = f'SELECT * FROM {table_name}' + if where: + query += f' WHERE {where}' + if order_by: + query += f' ORDER BY {order_by}' + + cur.execute(query) + rows = cur.fetchall() + + # Convert to list of dicts + return [dict(zip(columns, row)) for row in rows] + finally: + loader.pool.putconn(conn) + + def cleanup_table(self, loader: PostgreSQLLoader, table_name: str) -> None: + """Drop PostgreSQL table""" + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + cur.execute(f'DROP TABLE IF EXISTS {table_name} CASCADE') + conn.commit() + finally: + loader.pool.putconn(conn) + + def get_column_names(self, loader: PostgreSQLLoader, table_name: str) -> List[str]: + """Get column names from PostgreSQL table""" + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = %s + ORDER BY 
ordinal_position + """, + (table_name,), + ) + return [row[0] for row in cur.fetchall()] + finally: + loader.pool.putconn(conn) + + +@pytest.mark.postgresql +class TestPostgreSQLCore(BaseLoaderTests): + """PostgreSQL core loader tests (inherited from base)""" + + config = PostgreSQLTestConfig() + + +@pytest.fixture +def cleanup_tables(postgresql_test_config): + """Cleanup test tables after tests""" + tables_to_clean = [] + + yield tables_to_clean + + # Cleanup + loader = PostgreSQLLoader(postgresql_test_config) + try: + loader.connect() + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + for table in tables_to_clean: + try: + cur.execute(f'DROP TABLE IF EXISTS {table} CASCADE') + conn.commit() + except Exception: + pass + finally: + loader.pool.putconn(conn) + loader.disconnect() + except Exception: + pass + + +@pytest.mark.postgresql +class TestPostgreSQLSpecific: + """PostgreSQL-specific tests that cannot be generalized""" + + def test_connection_pooling(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): + """Test PostgreSQL connection pooling behavior""" + from src.amp.loaders.base import LoadMode + + cleanup_tables.append(test_table_name) + + loader = PostgreSQLLoader(postgresql_test_config) + + with loader: + # Perform multiple operations to test pool reuse + for i in range(5): + subset = small_test_data.slice(i, 1) + mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND + + result = loader.load_table(subset, test_table_name, mode=mode) + assert result.success == True + + # Verify pool is managing connections properly + # Note: _used is a dict in ThreadedConnectionPool, not an int + assert len(loader.pool._used) <= loader.pool.maxconn + + def test_binary_data_handling(self, postgresql_test_config, test_table_name, cleanup_tables): + """Test binary data handling with INSERT fallback""" + import pyarrow as pa + + cleanup_tables.append(test_table_name) + + # Create data with binary columns + data = {'id': [1, 2, 3], 'binary_data': [b'hello', b'world', b'test'], 'text_data': ['a', 'b', 'c']} + table = pa.Table.from_pydict(data) + + loader = PostgreSQLLoader(postgresql_test_config) + + with loader: + result = loader.load_table(table, test_table_name) + assert result.success == True + assert result.rows_loaded == 3 + + # Verify binary data was stored correctly + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + cur.execute(f'SELECT id, binary_data FROM {test_table_name} ORDER BY id') + rows = cur.fetchall() + assert rows[0][1].tobytes() == b'hello' + assert rows[1][1].tobytes() == b'world' + assert rows[2][1].tobytes() == b'test' + finally: + loader.pool.putconn(conn) + + def test_schema_retrieval(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): + """Test schema retrieval functionality""" + cleanup_tables.append(test_table_name) + + loader = PostgreSQLLoader(postgresql_test_config) + + with loader: + # Create table + result = loader.load_table(small_test_data, test_table_name) + assert result.success == True + + # Get schema + schema = loader.get_table_schema(test_table_name) + assert schema is not None + + # Filter out metadata columns added by PostgreSQL loader + non_meta_fields = [ + field for field in schema if not (field.name.startswith('_meta_') or field.name.startswith('_amp_')) + ] + + assert len(non_meta_fields) == len(small_test_data.schema) + + # Verify column names match (excluding metadata columns) + original_names = set(small_test_data.schema.names) + retrieved_names = 
set(field.name for field in non_meta_fields) + assert original_names == retrieved_names + + def test_performance_metrics(self, postgresql_test_config, medium_test_table, test_table_name, cleanup_tables): + """Test performance metrics in results""" + cleanup_tables.append(test_table_name) + + loader = PostgreSQLLoader(postgresql_test_config) + + with loader: + start_time = time.time() + result = loader.load_table(medium_test_table, test_table_name) + end_time = time.time() + + assert result.success == True + assert result.duration > 0 + assert result.duration <= (end_time - start_time) + assert result.rows_loaded == 10000 + + # Check metadata contains performance info + assert 'table_size_bytes' in result.metadata + assert result.metadata['table_size_bytes'] > 0 + + def test_null_value_handling_detailed( + self, postgresql_test_config, null_test_data, test_table_name, cleanup_tables + ): + """Test comprehensive null value handling across all PostgreSQL data types""" + cleanup_tables.append(test_table_name) + + loader = PostgreSQLLoader(postgresql_test_config) + + with loader: + result = loader.load_table(null_test_data, test_table_name) + assert result.success == True + assert result.rows_loaded == 10 + + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + # Check text field nulls (rows 3, 6, 9 have index 2, 5, 8) + cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE text_field IS NULL') + text_nulls = cur.fetchone()[0] + assert text_nulls == 3 + + # Check int field nulls (rows 2, 5, 8 have index 1, 4, 7) + cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE int_field IS NULL') + int_nulls = cur.fetchone()[0] + assert int_nulls == 3 + + # Check float field nulls (rows 3, 6, 9 have index 2, 5, 8) + cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE float_field IS NULL') + float_nulls = cur.fetchone()[0] + assert float_nulls == 3 + + # Check bool field nulls (rows 3, 6, 9 have index 2, 5, 8) + cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE bool_field IS NULL') + bool_nulls = cur.fetchone()[0] + assert bool_nulls == 3 + + # Check timestamp field nulls + cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE timestamp_field IS NULL') + timestamp_nulls = cur.fetchone()[0] + assert timestamp_nulls == 4 + + # Check json field nulls + cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE json_field IS NULL') + json_nulls = cur.fetchone()[0] + assert json_nulls == 3 + + # Verify non-null values are intact + cur.execute(f'SELECT text_field FROM {test_table_name} WHERE id = 1') + text_val = cur.fetchone()[0] + assert text_val in ['a', '"a"'] # Handle potential CSV quoting + + cur.execute(f'SELECT int_field FROM {test_table_name} WHERE id = 1') + int_val = cur.fetchone()[0] + assert int_val == 1 + + cur.execute(f'SELECT float_field FROM {test_table_name} WHERE id = 1') + float_val = cur.fetchone()[0] + assert abs(float_val - 1.1) < 0.01 + + cur.execute(f'SELECT bool_field FROM {test_table_name} WHERE id = 1') + bool_val = cur.fetchone()[0] + assert bool_val == True + finally: + loader.pool.putconn(conn) + + +@pytest.mark.postgresql +@pytest.mark.slow +class TestPostgreSQLPerformance: + """PostgreSQL performance tests""" + + def test_large_data_loading(self, postgresql_test_config, test_table_name, cleanup_tables): + """Test loading large datasets""" + import pyarrow as pa + from datetime import datetime + + cleanup_tables.append(test_table_name) + + # Create large dataset + large_data = { + 'id': list(range(50000)), + 'value': [i * 0.123 for i 
in range(50000)], + 'category': [f'category_{i % 100}' for i in range(50000)], + 'description': [f'This is a longer text description for row {i}' for i in range(50000)], + 'created_at': [datetime.now() for _ in range(50000)], + } + large_table = pa.Table.from_pydict(large_data) + + loader = PostgreSQLLoader(postgresql_test_config) + + with loader: + result = loader.load_table(large_table, test_table_name) + + assert result.success == True + assert result.rows_loaded == 50000 + assert result.duration < 60 # Should complete within 60 seconds + + # Verify data integrity + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') + count = cur.fetchone()[0] + assert count == 50000 + finally: + loader.pool.putconn(conn) diff --git a/tests/integration/loaders/conftest.py b/tests/integration/loaders/conftest.py new file mode 100644 index 0000000..ba7d3ea --- /dev/null +++ b/tests/integration/loaders/conftest.py @@ -0,0 +1,165 @@ +""" +Base test classes and configuration for loader integration tests. + +This module provides abstract base classes that enable test generalization +across different loader implementations while maintaining full test coverage. +""" + +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, List, Optional, Type + +import pytest + +from src.amp.loaders.base import DataLoader + + +@pytest.fixture +def test_table_name(): + """Generate unique table name for each test""" + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') + return f'test_table_{timestamp}' + + +class LoaderTestConfig(ABC): + """ + Configuration for a specific loader's tests. + + Each loader implementation must provide a concrete implementation + that defines loader-specific query methods and capabilities. + """ + + # Required: Specify the loader class being tested + loader_class: Type[DataLoader] = None + + # Required: Name of the pytest fixture that provides the loader config dict + config_fixture_name: str = None + + # Loader capability flags + supports_overwrite: bool = True + supports_streaming: bool = True + supports_multi_network: bool = True + supports_null_values: bool = True + + @abstractmethod + def get_row_count(self, loader: DataLoader, table_name: str) -> int: + """ + Get the number of rows in a table. + + Args: + loader: The loader instance + table_name: Name of the table + + Returns: + Number of rows in the table + """ + raise NotImplementedError + + @abstractmethod + def query_rows( + self, loader: DataLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Query rows from a table. + + Args: + loader: The loader instance + table_name: Name of the table + where: Optional WHERE clause (without 'WHERE' keyword) + order_by: Optional ORDER BY clause (without 'ORDER BY' keyword) + + Returns: + List of dictionaries representing rows + """ + raise NotImplementedError + + @abstractmethod + def cleanup_table(self, loader: DataLoader, table_name: str) -> None: + """ + Clean up/drop a table. + + Args: + loader: The loader instance + table_name: Name of the table to drop + """ + raise NotImplementedError + + @abstractmethod + def get_column_names(self, loader: DataLoader, table_name: str) -> List[str]: + """ + Get column names for a table. 
+ + Args: + loader: The loader instance + table_name: Name of the table + + Returns: + List of column names + """ + raise NotImplementedError + + +class LoaderTestBase: + """ + Base class for all loader tests. + + Test classes should inherit from this and set the `config` class attribute + to a LoaderTestConfig instance. + + Example: + class TestPostgreSQLCore(BaseLoaderTests): + config = PostgreSQLTestConfig() + """ + + # Override this in subclasses + config: LoaderTestConfig = None + + @pytest.fixture + def loader(self, request): + """ + Create a loader instance from the config. + + This fixture dynamically retrieves the loader config from the + fixture name specified in LoaderTestConfig.config_fixture_name. + """ + if self.config is None: + raise ValueError('Test class must define a config attribute') + if self.config.loader_class is None: + raise ValueError('LoaderTestConfig must define loader_class') + if self.config.config_fixture_name is None: + raise ValueError('LoaderTestConfig must define config_fixture_name') + + # Get the loader config from the specified fixture + loader_config = request.getfixturevalue(self.config.config_fixture_name) + + # Create and return the loader instance + return self.config.loader_class(loader_config) + + @pytest.fixture + def cleanup_tables(self, request): + """ + Cleanup fixture that drops tables after tests. + + Tests should append table names to this list, and they will be + cleaned up automatically after the test completes. + """ + tables_to_clean = [] + + yield tables_to_clean + + # Cleanup after test + if tables_to_clean and self.config: + loader_config = request.getfixturevalue(self.config.config_fixture_name) + loader = self.config.loader_class(loader_config) + try: + loader.connect() + for table_name in tables_to_clean: + try: + self.config.cleanup_table(loader, table_name) + except Exception: + # Ignore cleanup errors + pass + loader.disconnect() + except Exception: + # Ignore connection errors during cleanup + pass diff --git a/tests/integration/loaders/test_base_loader.py b/tests/integration/loaders/test_base_loader.py new file mode 100644 index 0000000..0bb24ee --- /dev/null +++ b/tests/integration/loaders/test_base_loader.py @@ -0,0 +1,142 @@ +""" +Generalized core loader tests that work across all loader implementations. + +These tests use the LoaderTestConfig abstraction to run identical test logic +across different storage backends (PostgreSQL, Redis, Snowflake, etc.). +""" + +import pytest + +from src.amp.loaders.base import LoadMode +from tests.integration.loaders.conftest import LoaderTestBase + + +class BaseLoaderTests(LoaderTestBase): + """ + Base test class with core loader functionality tests. + + All loaders should inherit from this class and provide a LoaderTestConfig. 
+ + Example: + class TestPostgreSQLCore(BaseLoaderTests): + config = PostgreSQLTestConfig() + """ + + @pytest.mark.integration + def test_connection(self, loader): + """Test basic connection and disconnection to the storage backend""" + # Test connection + loader.connect() + assert loader._is_connected == True + + # Test disconnection + loader.disconnect() + assert loader._is_connected == False + + @pytest.mark.integration + def test_context_manager(self, loader, small_test_data, test_table_name, cleanup_tables): + """Test context manager functionality""" + cleanup_tables.append(test_table_name) + + with loader: + assert loader._is_connected == True + + result = loader.load_table(small_test_data, test_table_name) + assert result.success == True + + # Should be disconnected after context + assert loader._is_connected == False + + @pytest.mark.integration + def test_batch_loading(self, loader, medium_test_table, test_table_name, cleanup_tables): + """Test batch loading functionality with sequential batches""" + cleanup_tables.append(test_table_name) + + with loader: + # Test loading individual batches + batches = medium_test_table.to_batches(max_chunksize=250) + + for i, batch in enumerate(batches): + mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND + result = loader.load_batch(batch, test_table_name, mode=mode) + + assert result.success == True + assert result.rows_loaded == batch.num_rows + assert result.metadata['batch_size'] == batch.num_rows + + # Verify all data was loaded + total_rows = self.config.get_row_count(loader, test_table_name) + assert total_rows == 10000 + + @pytest.mark.integration + def test_append_mode(self, loader, small_test_data, test_table_name, cleanup_tables): + """Test append mode functionality""" + cleanup_tables.append(test_table_name) + + with loader: + # Initial load + result = loader.load_table(small_test_data, test_table_name, mode=LoadMode.APPEND) + assert result.success == True + assert result.rows_loaded == 5 + + # Append additional data + result = loader.load_table(small_test_data, test_table_name, mode=LoadMode.APPEND) + assert result.success == True + assert result.rows_loaded == 5 + + # Verify total rows + total_rows = self.config.get_row_count(loader, test_table_name) + assert total_rows == 10 # 5 + 5 + + @pytest.mark.integration + def test_overwrite_mode(self, loader, small_test_data, test_table_name, cleanup_tables): + """Test overwrite mode functionality""" + if not self.config.supports_overwrite: + pytest.skip('Loader does not support overwrite mode') + + cleanup_tables.append(test_table_name) + + with loader: + # Initial load + result = loader.load_table(small_test_data, test_table_name, mode=LoadMode.OVERWRITE) + assert result.success == True + assert result.rows_loaded == 5 + + # Overwrite with different data + new_data = small_test_data.slice(0, 3) # First 3 rows + result = loader.load_table(new_data, test_table_name, mode=LoadMode.OVERWRITE) + assert result.success == True + assert result.rows_loaded == 3 + + # Verify only new data remains + total_rows = self.config.get_row_count(loader, test_table_name) + assert total_rows == 3 + + @pytest.mark.integration + def test_null_handling(self, loader, null_test_data, test_table_name, cleanup_tables): + """Test null value handling across all data types""" + if not self.config.supports_null_values: + pytest.skip('Loader does not support null values') + + cleanup_tables.append(test_table_name) + + with loader: + result = loader.load_table(null_test_data, test_table_name) + assert 
result.success == True + assert result.rows_loaded == 10 + + # Verify data was loaded (basic sanity check) + # Specific null value verification is loader-specific and tested in backend tests + total_rows = self.config.get_row_count(loader, test_table_name) + assert total_rows == 10 + + @pytest.mark.integration + def test_error_handling(self, loader, small_test_data): + """Test error handling scenarios""" + with loader: + # Test loading to non-existent table without create_table + result = loader.load_table(small_test_data, 'non_existent_table', create_table=False) + + assert result.success == False + assert result.error is not None + assert result.rows_loaded == 0 From 42f0c07e4c6fbf2450c7feb43fb8145a11f35d07 Mon Sep 17 00:00:00 2001 From: Ford Date: Sat, 3 Jan 2026 10:39:39 -0800 Subject: [PATCH 2/6] Add streaming tests and migrate Redis/Snowflake loaders Adds generalized streaming test infrastructure and migrates Redis and Snowflake loader tests to use the shared base classes. --- .../loaders/backends/test_postgresql.py | 8 + .../loaders/backends/test_redis.py | 380 ++++++++++++++++++ .../loaders/backends/test_snowflake.py | 282 +++++++++++++ tests/integration/loaders/conftest.py | 15 + .../loaders/test_base_streaming.py | 348 ++++++++++++++++ 5 files changed, 1033 insertions(+) create mode 100644 tests/integration/loaders/backends/test_redis.py create mode 100644 tests/integration/loaders/backends/test_snowflake.py create mode 100644 tests/integration/loaders/test_base_streaming.py diff --git a/tests/integration/loaders/backends/test_postgresql.py b/tests/integration/loaders/backends/test_postgresql.py index f45e317..163afb6 100644 --- a/tests/integration/loaders/backends/test_postgresql.py +++ b/tests/integration/loaders/backends/test_postgresql.py @@ -14,6 +14,7 @@ from src.amp.loaders.implementations.postgresql_loader import PostgreSQLLoader from tests.integration.loaders.conftest import LoaderTestConfig from tests.integration.loaders.test_base_loader import BaseLoaderTests + from tests.integration.loaders.test_base_streaming import BaseStreamingTests except ImportError: pytest.skip('amp modules not available', allow_module_level=True) @@ -109,6 +110,13 @@ class TestPostgreSQLCore(BaseLoaderTests): config = PostgreSQLTestConfig() +@pytest.mark.postgresql +class TestPostgreSQLStreaming(BaseStreamingTests): + """PostgreSQL streaming tests (inherited from base)""" + + config = PostgreSQLTestConfig() + + @pytest.fixture def cleanup_tables(postgresql_test_config): """Cleanup test tables after tests""" diff --git a/tests/integration/loaders/backends/test_redis.py b/tests/integration/loaders/backends/test_redis.py new file mode 100644 index 0000000..b25e7e2 --- /dev/null +++ b/tests/integration/loaders/backends/test_redis.py @@ -0,0 +1,380 @@ +""" +Redis-specific loader integration tests. + +This module provides Redis-specific test configuration and tests that +inherit from the generalized base test classes. 
+""" + +import json +from typing import Any, Dict, List, Optional + +import pytest + +try: + import redis + + from src.amp.loaders.implementations.redis_loader import RedisLoader + from tests.integration.loaders.conftest import LoaderTestConfig + from tests.integration.loaders.test_base_loader import BaseLoaderTests + from tests.integration.loaders.test_base_streaming import BaseStreamingTests +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +class RedisTestConfig(LoaderTestConfig): + """Redis-specific test configuration""" + + loader_class = RedisLoader + config_fixture_name = 'redis_test_config' + + supports_overwrite = True + supports_streaming = True + supports_multi_network = True + supports_null_values = True + + def get_row_count(self, loader: RedisLoader, table_name: str) -> int: + """Get row count from Redis based on data structure""" + data_structure = loader.config.get('data_structure', 'hash') + + if data_structure == 'hash': + # Count hash keys matching pattern + pattern = f'{table_name}:*' + count = 0 + for _ in loader.redis_client.scan_iter(match=pattern, count=1000): + count += 1 + return count + elif data_structure == 'string': + # Count string keys matching pattern + pattern = f'{table_name}:*' + count = 0 + for _ in loader.redis_client.scan_iter(match=pattern, count=1000): + count += 1 + return count + elif data_structure == 'stream': + # Get stream length + try: + return loader.redis_client.xlen(table_name) + except: + return 0 + else: + # For other structures, scan for keys + pattern = f'{table_name}:*' + count = 0 + for _ in loader.redis_client.scan_iter(match=pattern, count=1000): + count += 1 + return count + + def query_rows( + self, loader: RedisLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from Redis - limited functionality due to Redis architecture""" + # Redis doesn't support SQL-like queries, so this is a simplified implementation + data_structure = loader.config.get('data_structure', 'hash') + rows = [] + + if data_structure == 'hash': + pattern = f'{table_name}:*' + for key in loader.redis_client.scan_iter(match=pattern, count=100): + data = loader.redis_client.hgetall(key) + row = {k.decode(): v.decode() if isinstance(v, bytes) else v for k, v in data.items()} + rows.append(row) + elif data_structure == 'stream': + # Read from stream + messages = loader.redis_client.xrange(table_name, count=100) + for msg_id, data in messages: + row = {k.decode(): v.decode() if isinstance(v, bytes) else v for k, v in data.items()} + row['_stream_id'] = msg_id.decode() + rows.append(row) + + return rows[:100] # Limit to 100 rows + + def cleanup_table(self, loader: RedisLoader, table_name: str) -> None: + """Drop Redis keys for table""" + # Delete all keys matching the table pattern + patterns = [f'{table_name}:*', table_name] + + for pattern in patterns: + for key in loader.redis_client.scan_iter(match=pattern, count=1000): + loader.redis_client.delete(key) + + def get_column_names(self, loader: RedisLoader, table_name: str) -> List[str]: + """Get column names from Redis - return fields from first record""" + data_structure = loader.config.get('data_structure', 'hash') + + if data_structure == 'hash': + pattern = f'{table_name}:*' + for key in loader.redis_client.scan_iter(match=pattern, count=1): + data = loader.redis_client.hgetall(key) + return [k.decode() if isinstance(k, bytes) else k for k in data.keys()] + elif data_structure == 'stream': + 
messages = loader.redis_client.xrange(table_name, count=1) + if messages: + _, data = messages[0] + return [k.decode() if isinstance(k, bytes) else k for k in data.keys()] + + return [] + + +@pytest.fixture +def cleanup_redis(redis_test_config): + """Cleanup Redis data after tests""" + keys_to_clean = [] + patterns_to_clean = [] + + yield (keys_to_clean, patterns_to_clean) + + # Cleanup + try: + r = redis.Redis( + host=redis_test_config['host'], + port=redis_test_config['port'], + db=redis_test_config['db'], + password=redis_test_config['password'], + ) + + # Delete specific keys + for key in keys_to_clean: + r.delete(key) + + # Delete keys matching patterns + for pattern in patterns_to_clean: + for key in r.scan_iter(match=pattern, count=1000): + r.delete(key) + + r.close() + except Exception: + pass + + +@pytest.mark.redis +class TestRedisCore(BaseLoaderTests): + """Redis core loader tests (inherited from base)""" + + config = RedisTestConfig() + + +@pytest.mark.redis +class TestRedisStreaming(BaseStreamingTests): + """Redis streaming tests (inherited from base)""" + + config = RedisTestConfig() + + +@pytest.mark.redis +class TestRedisSpecific: + """Redis-specific tests that cannot be generalized""" + + def test_hash_storage(self, redis_test_config, small_test_data, cleanup_redis): + """Test Redis hash data structure storage""" + import pyarrow as pa + + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_hash' + patterns_to_clean.append(f'{table_name}:*') + + config = {**redis_test_config, 'data_structure': 'hash', 'key_field': 'id'} + loader = RedisLoader(config) + + with loader: + result = loader.load_table(small_test_data, table_name) + assert result.success == True + assert result.rows_loaded == 5 + + # Verify data is stored as hashes + pattern = f'{table_name}:*' + keys = list(loader.redis_client.scan_iter(match=pattern, count=100)) + assert len(keys) == 5 + + # Verify hash structure + first_key = keys[0] + data = loader.redis_client.hgetall(first_key) + assert len(data) > 0 + + def test_string_storage(self, redis_test_config, small_test_data, cleanup_redis): + """Test Redis string data structure storage""" + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_string' + patterns_to_clean.append(f'{table_name}:*') + + config = {**redis_test_config, 'data_structure': 'string', 'key_field': 'id'} + loader = RedisLoader(config) + + with loader: + result = loader.load_table(small_test_data, table_name) + assert result.success == True + assert result.rows_loaded == 5 + + # Verify data is stored as strings + pattern = f'{table_name}:*' + keys = list(loader.redis_client.scan_iter(match=pattern, count=100)) + assert len(keys) == 5 + + # Verify string structure (should be JSON) + first_key = keys[0] + data = loader.redis_client.get(first_key) + assert data is not None + # Should be JSON-encoded + json.loads(data) + + def test_stream_storage(self, redis_test_config, small_test_data, cleanup_redis): + """Test Redis stream data structure storage""" + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_stream' + keys_to_clean.append(table_name) + + config = {**redis_test_config, 'data_structure': 'stream'} + loader = RedisLoader(config) + + with loader: + result = loader.load_table(small_test_data, table_name) + assert result.success == True + assert result.rows_loaded == 5 + + # Verify data is in stream + stream_len = loader.redis_client.xlen(table_name) + assert stream_len == 5 + + def test_set_storage(self, redis_test_config, 
small_test_data, cleanup_redis): + """Test Redis set data structure storage""" + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_set' + patterns_to_clean.append(f'{table_name}:*') + + config = {**redis_test_config, 'data_structure': 'set', 'key_field': 'id'} + loader = RedisLoader(config) + + with loader: + result = loader.load_table(small_test_data, table_name) + assert result.success == True + assert result.rows_loaded == 5 + + def test_ttl_functionality(self, redis_test_config, small_test_data, cleanup_redis): + """Test TTL (time-to-live) functionality""" + import time + + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_ttl' + patterns_to_clean.append(f'{table_name}:*') + + config = {**redis_test_config, 'data_structure': 'hash', 'key_field': 'id', 'ttl': 2} # 2 second TTL + loader = RedisLoader(config) + + with loader: + result = loader.load_table(small_test_data, table_name) + assert result.success == True + + # Verify data exists + pattern = f'{table_name}:*' + keys_before = list(loader.redis_client.scan_iter(match=pattern, count=100)) + assert len(keys_before) == 5 + + # Verify TTL is set + ttl = loader.redis_client.ttl(keys_before[0]) + assert ttl > 0 and ttl <= 2 + + # Wait for TTL to expire + time.sleep(3) + + # Verify data has expired + keys_after = list(loader.redis_client.scan_iter(match=pattern, count=100)) + assert len(keys_after) == 0 + + def test_key_pattern_generation(self, redis_test_config, cleanup_redis): + """Test custom key pattern generation""" + import pyarrow as pa + + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_key_pattern' + patterns_to_clean.append(f'{table_name}:*') + + data = {'user_id': ['user1', 'user2', 'user3'], 'score': [100, 200, 300], 'level': [1, 2, 3]} + test_data = pa.Table.from_pydict(data) + + config = { + **redis_test_config, + 'data_structure': 'hash', + 'key_field': 'user_id', + 'key_pattern': '{table}:user:{key_value}', + } + loader = RedisLoader(config) + + with loader: + result = loader.load_table(test_data, table_name) + assert result.success == True + + # Verify custom key pattern + pattern = f'{table_name}:user:*' + keys = list(loader.redis_client.scan_iter(match=pattern, count=100)) + assert len(keys) == 3 + + # Verify key format + key_str = keys[0].decode() if isinstance(keys[0], bytes) else keys[0] + assert key_str.startswith(f'{table_name}:user:') + + def test_data_structure_comparison(self, redis_test_config, comprehensive_test_data, cleanup_redis): + """Test performance comparison between different data structures""" + import time + + keys_to_clean, patterns_to_clean = cleanup_redis + + structures = ['hash', 'string', 'stream'] + results = {} + + for structure in structures: + table_name = f'test_perf_{structure}' + patterns_to_clean.append(f'{table_name}:*') + keys_to_clean.append(table_name) + + config = {**redis_test_config, 'data_structure': structure} + if structure in ['hash', 'string']: + config['key_field'] = 'id' + + loader = RedisLoader(config) + + with loader: + start_time = time.time() + result = loader.load_table(comprehensive_test_data, table_name) + duration = time.time() - start_time + + results[structure] = { + 'success': result.success, + 'duration': duration, + 'rows_loaded': result.rows_loaded, + } + + # Verify all structures work + for structure, data in results.items(): + assert data['success'] == True + assert data['rows_loaded'] == 1000 + + +@pytest.mark.redis +@pytest.mark.slow +class TestRedisPerformance: + """Redis performance tests""" + 
+ def test_large_data_loading(self, redis_test_config, cleanup_redis): + """Test loading large datasets to Redis""" + import pyarrow as pa + + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_large' + patterns_to_clean.append(f'{table_name}:*') + + # Create large dataset + large_data = { + 'id': list(range(10000)), + 'value': [i * 0.123 for i in range(10000)], + 'category': [f'cat_{i % 100}' for i in range(10000)], + } + large_table = pa.Table.from_pydict(large_data) + + config = {**redis_test_config, 'data_structure': 'hash', 'key_field': 'id'} + loader = RedisLoader(config) + + with loader: + result = loader.load_table(large_table, table_name) + + assert result.success == True + assert result.rows_loaded == 10000 + assert result.duration < 60 # Should complete within 60 seconds diff --git a/tests/integration/loaders/backends/test_snowflake.py b/tests/integration/loaders/backends/test_snowflake.py new file mode 100644 index 0000000..6f42e3f --- /dev/null +++ b/tests/integration/loaders/backends/test_snowflake.py @@ -0,0 +1,282 @@ +""" +Snowflake-specific loader integration tests. + +This module provides Snowflake-specific test configuration and tests that +inherit from the generalized base test classes. + +Note: Snowflake tests require valid Snowflake credentials and are typically +skipped in CI/CD. Run manually with: pytest -m snowflake +""" + +from typing import Any, Dict, List, Optional + +import pytest + +try: + from src.amp.loaders.implementations.snowflake_loader import SnowflakeLoader + from tests.integration.loaders.conftest import LoaderTestConfig + from tests.integration.loaders.test_base_loader import BaseLoaderTests + from tests.integration.loaders.test_base_streaming import BaseStreamingTests +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +class SnowflakeTestConfig(LoaderTestConfig): + """Snowflake-specific test configuration""" + + loader_class = SnowflakeLoader + config_fixture_name = 'snowflake_config' + + supports_overwrite = True + supports_streaming = True + supports_multi_network = True + supports_null_values = True + + def get_row_count(self, loader: SnowflakeLoader, table_name: str) -> int: + """Get row count from Snowflake table""" + with loader.conn.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {table_name}') + return cur.fetchone()[0] + + def query_rows( + self, loader: SnowflakeLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from Snowflake table""" + query = f'SELECT * FROM {table_name}' + if where: + query += f' WHERE {where}' + if order_by: + query += f' ORDER BY {order_by}' + query += ' LIMIT 100' + + with loader.conn.cursor() as cur: + cur.execute(query) + columns = [col[0] for col in cur.description] + rows = cur.fetchall() + return [dict(zip(columns, row)) for row in rows] + + def cleanup_table(self, loader: SnowflakeLoader, table_name: str) -> None: + """Drop Snowflake table""" + with loader.conn.cursor() as cur: + cur.execute(f'DROP TABLE IF EXISTS {table_name}') + + def get_column_names(self, loader: SnowflakeLoader, table_name: str) -> List[str]: + """Get column names from Snowflake table""" + with loader.conn.cursor() as cur: + cur.execute(f'SELECT * FROM {table_name} LIMIT 0') + return [col[0] for col in cur.description] + + +@pytest.mark.snowflake +class TestSnowflakeCore(BaseLoaderTests): + """Snowflake core loader tests (inherited from base)""" + + config = SnowflakeTestConfig() + + 
+@pytest.mark.snowflake +class TestSnowflakeStreaming(BaseStreamingTests): + """Snowflake streaming tests (inherited from base)""" + + config = SnowflakeTestConfig() + + +@pytest.mark.snowflake +class TestSnowflakeSpecific: + """Snowflake-specific tests that cannot be generalized""" + + def test_stage_loading_method(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): + """Test Snowflake stage-based loading (Snowflake-specific optimization)""" + import pyarrow as pa + + cleanup_tables.append(test_table_name) + + # Configure for stage loading + config = {**snowflake_config, 'loading_method': 'stage'} + loader = SnowflakeLoader(config) + + with loader: + result = loader.load_table(small_test_table, test_table_name) + assert result.success == True + assert result.rows_loaded == 100 + + # Verify data loaded + with loader.conn.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') + count = cur.fetchone()[0] + assert count == 100 + + def test_insert_loading_method(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): + """Test Snowflake INSERT-based loading""" + cleanup_tables.append(test_table_name) + + # Configure for INSERT loading + config = {**snowflake_config, 'loading_method': 'insert'} + loader = SnowflakeLoader(config) + + with loader: + result = loader.load_table(small_test_table, test_table_name) + assert result.success == True + assert result.rows_loaded == 100 + + def test_table_info_retrieval(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): + """Test Snowflake table information retrieval""" + cleanup_tables.append(test_table_name) + + loader = SnowflakeLoader(snowflake_config) + + with loader: + # Load data first + loader.load_table(small_test_table, test_table_name) + + # Get table info + info = loader.get_table_info(test_table_name) + + assert info is not None + assert 'row_count' in info + assert 'bytes' in info + assert 'clustering_key' in info + assert info['row_count'] == 100 + + def test_concurrent_batch_loading(self, snowflake_config, medium_test_table, test_table_name, cleanup_tables): + """Test concurrent batch loading to Snowflake""" + from concurrent.futures import ThreadPoolExecutor, as_completed + + from src.amp.loaders.base import LoadMode + + cleanup_tables.append(test_table_name) + + loader = SnowflakeLoader(snowflake_config) + + with loader: + # Split table into batches + batches = medium_test_table.to_batches(max_chunksize=2000) + + # Load batches concurrently + def load_batch_with_mode(batch_tuple): + i, batch = batch_tuple + mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND + return loader.load_batch(batch, test_table_name, mode=mode) + + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(load_batch_with_mode, (i, batch)) for i, batch in enumerate(batches)] + + results = [] + for future in as_completed(futures): + result = future.result() + results.append(result) + + # Verify all batches succeeded + assert all(r.success for r in results) + + # Verify total row count + with loader.conn.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') + count = cur.fetchone()[0] + assert count == 10000 + + def test_schema_special_characters(self, snowflake_config, test_table_name, cleanup_tables): + """Test handling of special characters in column names""" + import pyarrow as pa + + cleanup_tables.append(test_table_name) + + # Create data with special characters in column names + data = { + 'id': [1, 2, 3], + 'user name': 
['Alice', 'Bob', 'Charlie'],  # Space in name
+            'email@address': ['a@ex.com', 'b@ex.com', 'c@ex.com'],  # @ symbol
+            'data-value': [100, 200, 300],  # Hyphen
+        }
+        test_data = pa.Table.from_pydict(data)
+
+        loader = SnowflakeLoader(snowflake_config)
+
+        with loader:
+            result = loader.load_table(test_data, test_table_name)
+            assert result.success == True
+            assert result.rows_loaded == 3
+
+            # Verify columns were properly escaped
+            with loader.conn.cursor() as cur:
+                cur.execute(f'SELECT * FROM {test_table_name} LIMIT 1')
+                columns = [col[0] for col in cur.description]
+                # Snowflake normalizes column names
+                assert len(columns) >= 4
+
+    def test_history_preservation_with_reorg(self, snowflake_config, test_table_name, cleanup_tables):
+        """Test Snowflake's history preservation during reorg (Time Travel feature)"""
+        import pyarrow as pa
+
+        from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch
+
+        cleanup_tables.append(test_table_name)
+
+        config = {**snowflake_config, 'preserve_history': True}
+        loader = SnowflakeLoader(config)
+
+        with loader:
+            # Load initial data
+            batch1 = pa.RecordBatch.from_pydict({'tx_hash': ['0x100', '0x101'], 'block_num': [100, 101]})
+
+            loader._create_table_from_schema(batch1.schema, test_table_name)
+
+            response1 = ResponseBatch.data_batch(
+                data=batch1,
+                metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=101, hash='0xaaa')]),
+            )
+
+            results = list(loader.load_stream_continuous(iter([response1]), test_table_name))
+            assert len(results) == 1
+            assert results[0].success
+
+            # Verify initial count
+            with loader.conn.cursor() as cur:
+                cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                initial_count = cur.fetchone()[0]
+                assert initial_count == 2
+
+            # Perform reorg (with history preservation, data is soft-deleted)
+            reorg_response = ResponseBatch.reorg_batch(
+                invalidation_ranges=[BlockRange(network='ethereum', start=100, end=101)]
+            )
+            reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name))
+            assert len(reorg_results) == 1
+
+            # With history preservation, data may be marked deleted rather than physically removed
+            # Verify the reorg was processed (exact behavior depends on implementation)
+            # Use a fresh cursor here; the previous cursor was closed when its context exited
+            with loader.conn.cursor() as cur:
+                cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                after_count = cur.fetchone()[0]
+                # Count should be 0 or data should have _deleted flag
+                assert after_count >= 0  # Flexible assertion for different implementations
+
+
+@pytest.mark.snowflake
+@pytest.mark.slow
+class TestSnowflakePerformance:
+    """Snowflake performance tests"""
+
+    def test_large_batch_loading(self, snowflake_config, performance_test_data, test_table_name, cleanup_tables):
+        """Test loading large batches to Snowflake"""
+        cleanup_tables.append(test_table_name)
+
+        loader = SnowflakeLoader(snowflake_config)
+
+        with loader:
+            result = loader.load_table(performance_test_data, test_table_name)
+
+            assert result.success == True
+            assert result.rows_loaded == 50000
+            assert result.duration < 120  # Should complete within 2 minutes
+
+            # Verify data integrity
+            with loader.conn.cursor() as cur:
+                cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                count = cur.fetchone()[0]
+                assert count == 50000
+
+
+# Note: Snowpipe Streaming tests (TestSnowpipeStreamingIntegration) are kept
+# in the original test file as they are highly Snowflake-specific and involve
+# complex channel management that doesn't generalize well. 
See: +# tests/integration/test_snowflake_loader.py::TestSnowpipeStreamingIntegration diff --git a/tests/integration/loaders/conftest.py b/tests/integration/loaders/conftest.py index ba7d3ea..7e078ad 100644 --- a/tests/integration/loaders/conftest.py +++ b/tests/integration/loaders/conftest.py @@ -121,6 +121,9 @@ def loader(self, request): This fixture dynamically retrieves the loader config from the fixture name specified in LoaderTestConfig.config_fixture_name. + + For streaming tests, state management is enabled by default to support + reorg operations and batch tracking. """ if self.config is None: raise ValueError('Test class must define a config attribute') @@ -132,6 +135,18 @@ def loader(self, request): # Get the loader config from the specified fixture loader_config = request.getfixturevalue(self.config.config_fixture_name) + # Enable state management for streaming tests (needed for reorg) + # Make a copy to avoid modifying the original fixture + if isinstance(loader_config, dict): + loader_config = loader_config.copy() + # Only enable state if not explicitly configured + if 'state' not in loader_config: + loader_config['state'] = { + 'enabled': True, + 'storage': 'memory', + 'store_batch_id': True + } + # Create and return the loader instance return self.config.loader_class(loader_config) diff --git a/tests/integration/loaders/test_base_streaming.py b/tests/integration/loaders/test_base_streaming.py new file mode 100644 index 0000000..76768f8 --- /dev/null +++ b/tests/integration/loaders/test_base_streaming.py @@ -0,0 +1,348 @@ +""" +Generalized streaming and reorg tests that work across all loader implementations. + +These tests use the LoaderTestConfig abstraction to run identical streaming test logic +across different storage backends (PostgreSQL, Redis, Snowflake, etc.). +""" + +import pyarrow as pa +import pytest + +from tests.integration.loaders.conftest import LoaderTestBase + + +class BaseStreamingTests(LoaderTestBase): + """ + Base test class with streaming and reorg functionality tests. + + Loaders that support streaming should inherit from this class and provide a LoaderTestConfig. 
+ + Example: + class TestPostgreSQLStreaming(BaseStreamingTests): + config = PostgreSQLTestConfig() + """ + + @pytest.mark.integration + def test_streaming_metadata_columns(self, loader, test_table_name, cleanup_tables): + """Test that streaming data creates tables with metadata columns""" + if not self.config.supports_streaming: + pytest.skip('Loader does not support streaming') + + cleanup_tables.append(test_table_name) + + from src.amp.streaming.types import BlockRange + + # Create test data with metadata + data = { + 'block_number': [100, 101, 102], + 'transaction_hash': ['0xabc', '0xdef', '0x123'], + 'value': [1.0, 2.0, 3.0], + } + batch = pa.RecordBatch.from_pydict(data) + + # Create metadata with block ranges + block_ranges = [BlockRange(network='ethereum', start=100, end=102)] + + with loader: + # Add metadata columns (simulating what load_stream_continuous does) + batch_with_metadata = loader._add_metadata_columns(batch, block_ranges) + + # Load the batch + result = loader.load_batch(batch_with_metadata, test_table_name, create_table=True) + assert result.success == True + assert result.rows_loaded == 3 + + # Verify metadata columns were created using backend-specific method + column_names = self.config.get_column_names(loader, test_table_name) + + # Should have original columns plus metadata columns + assert '_amp_batch_id' in column_names + + # Verify data was loaded + row_count = self.config.get_row_count(loader, test_table_name) + assert row_count == 3 + + @pytest.mark.integration + def test_reorg_deletion(self, loader, test_table_name, cleanup_tables): + """Test that _handle_reorg correctly deletes invalidated ranges""" + if not self.config.supports_streaming: + pytest.skip('Loader does not support streaming') + + cleanup_tables.append(test_table_name) + + from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch + + with loader: + # Create streaming batches with metadata + batch1 = pa.RecordBatch.from_pydict( + { + 'tx_hash': ['0x100', '0x101', '0x102'], + 'block_num': [100, 101, 102], + 'value': [10.0, 11.0, 12.0], + } + ) + batch2 = pa.RecordBatch.from_pydict( + {'tx_hash': ['0x200', '0x201'], 'block_num': [103, 104], 'value': [12.0, 33.0]} + ) + batch3 = pa.RecordBatch.from_pydict( + {'tx_hash': ['0x300', '0x301'], 'block_num': [105, 106], 'value': [7.0, 9.0]} + ) + batch4 = pa.RecordBatch.from_pydict( + {'tx_hash': ['0x400', '0x401'], 'block_num': [107, 108], 'value': [6.0, 73.0]} + ) + + # Create table from first batch schema + loader._create_table_from_schema(batch1.schema, test_table_name) + + # Create response batches with hashes + response1 = ResponseBatch.data_batch( + data=batch1, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=102, hash='0xaaa')], + ranges_complete=True, # Mark as complete so state tracking works + ), + ) + response2 = ResponseBatch.data_batch( + data=batch2, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=103, end=104, hash='0xbbb')], + ranges_complete=True, + ), + ) + response3 = ResponseBatch.data_batch( + data=batch3, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=105, end=106, hash='0xccc')], + ranges_complete=True, + ), + ) + response4 = ResponseBatch.data_batch( + data=batch4, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=107, end=108, hash='0xddd')], + ranges_complete=True, + ), + ) + + # Load via streaming API (with connection_name for state tracking) + stream = [response1, response2, response3, response4] + 
results = list(loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_conn')) + assert len(results) == 4 + assert all(r.success for r in results) + + # Verify initial data count + initial_count = self.config.get_row_count(loader, test_table_name) + assert initial_count == 9 # 3 + 2 + 2 + 2 + + # Test reorg deletion - invalidate blocks 104-108 on ethereum + reorg_response = ResponseBatch.reorg_batch( + invalidation_ranges=[BlockRange(network='ethereum', start=104, end=108)] + ) + reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name, connection_name='test_conn')) + assert len(reorg_results) == 1 + assert reorg_results[0].success + + # Should delete batch2, batch3 and batch4 leaving only the 3 rows from batch1 + after_reorg_count = self.config.get_row_count(loader, test_table_name) + assert after_reorg_count == 3 + + @pytest.mark.integration + def test_reorg_overlapping_ranges(self, loader, test_table_name, cleanup_tables): + """Test reorg deletion with overlapping block ranges""" + if not self.config.supports_streaming: + pytest.skip('Loader does not support streaming') + + cleanup_tables.append(test_table_name) + + from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch + + with loader: + # Load data with overlapping ranges that should be invalidated + batch = pa.RecordBatch.from_pydict( + {'tx_hash': ['0x150', '0x175', '0x250'], 'block_num': [150, 175, 250], 'value': [15.0, 17.5, 25.0]} + ) + + # Create table from batch schema + loader._create_table_from_schema(batch.schema, test_table_name) + + response = ResponseBatch.data_batch( + data=batch, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=150, end=175, hash='0xaaa')], + ranges_complete=True, + ), + ) + + # Load via streaming API (with connection_name for state tracking) + results = list(loader.load_stream_continuous(iter([response]), test_table_name, connection_name='test_conn')) + assert len(results) == 1 + assert results[0].success + + # Verify initial data + initial_count = self.config.get_row_count(loader, test_table_name) + assert initial_count == 3 + + # Test partial overlap invalidation (160-180) + # This should invalidate our range [150,175] because they overlap + reorg_response = ResponseBatch.reorg_batch( + invalidation_ranges=[BlockRange(network='ethereum', start=160, end=180)] + ) + reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name, connection_name='test_conn')) + assert len(reorg_results) == 1 + assert reorg_results[0].success + + # All data should be deleted due to overlap + after_reorg_count = self.config.get_row_count(loader, test_table_name) + assert after_reorg_count == 0 + + @pytest.mark.integration + def test_reorg_multi_network(self, loader, test_table_name, cleanup_tables): + """Test that reorg only affects specified network""" + if not self.config.supports_streaming: + pytest.skip('Loader does not support streaming') + if not self.config.supports_multi_network: + pytest.skip('Loader does not support multi-network isolation') + + cleanup_tables.append(test_table_name) + + from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch + + with loader: + # Load data from multiple networks with same block ranges + batch_eth = pa.RecordBatch.from_pydict( + {'tx_hash': ['0x100_eth'], 'network_id': ['ethereum'], 'block_num': [100], 'value': [10.0]} + ) + batch_poly = pa.RecordBatch.from_pydict( + {'tx_hash': ['0x100_poly'], 'network_id': ['polygon'], 
'block_num': [100], 'value': [10.0]} + ) + + # Create table from batch schema + loader._create_table_from_schema(batch_eth.schema, test_table_name) + + response_eth = ResponseBatch.data_batch( + data=batch_eth, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=100, hash='0xaaa')], + ranges_complete=True, + ), + ) + response_poly = ResponseBatch.data_batch( + data=batch_poly, + metadata=BatchMetadata( + ranges=[BlockRange(network='polygon', start=100, end=100, hash='0xbbb')], + ranges_complete=True, + ), + ) + + # Load both batches via streaming API (with connection_name for state tracking) + stream = [response_eth, response_poly] + results = list(loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_conn')) + assert len(results) == 2 + assert all(r.success for r in results) + + # Verify both networks' data exists + initial_count = self.config.get_row_count(loader, test_table_name) + assert initial_count == 2 + + # Invalidate only ethereum network + reorg_response = ResponseBatch.reorg_batch( + invalidation_ranges=[BlockRange(network='ethereum', start=100, end=100)] + ) + reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name, connection_name='test_conn')) + assert len(reorg_results) == 1 + assert reorg_results[0].success + + # Should only delete ethereum data, polygon should remain + after_reorg_count = self.config.get_row_count(loader, test_table_name) + assert after_reorg_count == 1 + + @pytest.mark.integration + def test_microbatch_deduplication(self, loader, test_table_name, cleanup_tables): + """ + Test that multiple RecordBatches within the same microbatch are all loaded, + and deduplication only happens at microbatch boundaries when ranges_complete=True. + + This test verifies the fix for the critical bug where we were marking batches + as processed after every RecordBatch instead of waiting for ranges_complete=True. 
+ """ + if not self.config.supports_streaming: + pytest.skip('Loader does not support streaming') + + cleanup_tables.append(test_table_name) + + from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch + + with loader: + # Create table first from the schema + batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) + loader._create_table_from_schema(batch1_data.schema, test_table_name) + + # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange + # First RecordBatch of the microbatch (ranges_complete=False) + response1 = ResponseBatch.data_batch( + data=batch1_data, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], + ranges_complete=False, # Not the last batch in this microbatch + ), + ) + + # Second RecordBatch (same BlockRange, ranges_complete=False) + batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) + response2 = ResponseBatch.data_batch( + data=batch2_data, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], + ranges_complete=False, + ), + ) + + # Third RecordBatch (same BlockRange, ranges_complete=True) + batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]}) + response3 = ResponseBatch.data_batch( + data=batch3_data, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], + ranges_complete=True, # Last batch - safe to mark as processed + ), + ) + + # Process the microbatch stream + stream = [response1, response2, response3] + results = list( + loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection') + ) + + # CRITICAL: All 3 RecordBatches should be loaded successfully + assert len(results) == 3, 'All RecordBatches within microbatch should be processed' + assert all(r.success for r in results), 'All batches should succeed' + assert results[0].rows_loaded == 2, 'First batch should load 2 rows' + assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)' + assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)' + + # Verify all 6 rows in table + total_count = self.config.get_row_count(loader, test_table_name) + assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table' + + # Test duplicate detection - send the same microbatch again + duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]}) + duplicate_response = ResponseBatch.data_batch( + data=duplicate_batch, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], + ranges_complete=True, + ), + ) + + duplicate_results = list( + loader.load_stream_continuous( + iter([duplicate_response]), test_table_name, connection_name='test_connection' + ) + ) + + # Duplicate should be skipped + assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped' + + # Verify count hasn't changed + final_count = self.config.get_row_count(loader, test_table_name) + assert final_count == 6, 'No additional rows should be added after duplicate' From 70f7cd2b8c48abee92fd9213fdff6c8cee7f33fc Mon Sep 17 00:00:00 2001 From: Ford Date: Wed, 7 Jan 2026 12:34:51 -0800 Subject: [PATCH 3/6] Add generalized tests for DeltaLake, Iceberg, and LMDB loaders Migrates the final three loader test suites to use the shared base test infrastructure --- .../loaders/backends/test_deltalake.py | 294 
+++++++++++++++ .../loaders/backends/test_iceberg.py | 313 ++++++++++++++++ .../integration/loaders/backends/test_lmdb.py | 334 ++++++++++++++++++ 3 files changed, 941 insertions(+) create mode 100644 tests/integration/loaders/backends/test_deltalake.py create mode 100644 tests/integration/loaders/backends/test_iceberg.py create mode 100644 tests/integration/loaders/backends/test_lmdb.py diff --git a/tests/integration/loaders/backends/test_deltalake.py b/tests/integration/loaders/backends/test_deltalake.py new file mode 100644 index 0000000..f2374c4 --- /dev/null +++ b/tests/integration/loaders/backends/test_deltalake.py @@ -0,0 +1,294 @@ +""" +DeltaLake-specific loader integration tests. + +This module provides DeltaLake-specific test configuration and tests that +inherit from the generalized base test classes. +""" + +from pathlib import Path +from typing import Any, Dict, List, Optional + +import pytest + +try: + from deltalake import DeltaTable + + from src.amp.loaders.implementations.deltalake_loader import DeltaLakeLoader + from tests.integration.loaders.conftest import LoaderTestConfig + from tests.integration.loaders.test_base_loader import BaseLoaderTests + from tests.integration.loaders.test_base_streaming import BaseStreamingTests +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +class DeltaLakeTestConfig(LoaderTestConfig): + """DeltaLake-specific test configuration""" + + loader_class = DeltaLakeLoader + config_fixture_name = 'delta_basic_config' + + supports_overwrite = True + supports_streaming = True + supports_multi_network = True + supports_null_values = True + + def get_row_count(self, loader: DeltaLakeLoader, table_name: str) -> int: + """Get row count from DeltaLake table""" + # DeltaLake uses the table_path as the identifier + table_path = loader.config.table_path + dt = DeltaTable(table_path) + df = dt.to_pyarrow_table() + return len(df) + + def query_rows( + self, loader: DeltaLakeLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from DeltaLake table""" + table_path = loader.config.table_path + dt = DeltaTable(table_path) + df = dt.to_pyarrow_table() + + # Convert to list of dicts (simple implementation, no filtering) + result = [] + for i in range(min(100, len(df))): + row = {col: df[col][i].as_py() for col in df.column_names} + result.append(row) + return result + + def cleanup_table(self, loader: DeltaLakeLoader, table_name: str) -> None: + """Delete DeltaLake table directory""" + import shutil + + table_path = loader.config.table_path + if Path(table_path).exists(): + shutil.rmtree(table_path, ignore_errors=True) + + def get_column_names(self, loader: DeltaLakeLoader, table_name: str) -> List[str]: + """Get column names from DeltaLake table""" + table_path = loader.config.table_path + dt = DeltaTable(table_path) + schema = dt.schema() + return [field.name for field in schema.fields] + + +@pytest.mark.delta_lake +class TestDeltaLakeCore(BaseLoaderTests): + """DeltaLake core loader tests (inherited from base)""" + + config = DeltaLakeTestConfig() + + +@pytest.mark.delta_lake +class TestDeltaLakeStreaming(BaseStreamingTests): + """DeltaLake streaming tests (inherited from base)""" + + config = DeltaLakeTestConfig() + + +@pytest.mark.delta_lake +class TestDeltaLakeSpecific: + """DeltaLake-specific tests that cannot be generalized""" + + def test_partitioning(self, delta_partitioned_config, small_test_data): + """Test DeltaLake partitioning 
functionality""" + import pyarrow as pa + + loader = DeltaLakeLoader(delta_partitioned_config) + + # Verify partitioning is configured + assert loader.partition_by == ['year', 'month', 'day'] + + with loader: + result = loader.load_table(small_test_data, 'test_table') + assert result.success == True + assert result.rows_loaded == 5 + + # Verify partitions were created + dt = DeltaTable(loader.config.table_path) + files = dt.file_uris() + # Partitioned tables create subdirectories + assert len(files) > 0 + + def test_optimization_operations(self, delta_basic_config, comprehensive_test_data): + """Test DeltaLake OPTIMIZE operations""" + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + # Load data + result = loader.load_table(comprehensive_test_data, 'test_table') + assert result.success == True + + # DeltaLake may auto-optimize if configured + dt = DeltaTable(loader.config.table_path) + version = dt.version() + assert version >= 0 + + # Verify table can be read after optimization + df = dt.to_pyarrow_table() + assert len(df) == 1000 + + def test_schema_evolution(self, delta_basic_config, small_test_data): + """Test DeltaLake schema evolution""" + import pyarrow as pa + + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + # Load initial data + result = loader.load_table(small_test_data, 'test_table') + assert result.success == True + + # Add new column to schema + extended_data = { + **{col: small_test_data[col].to_pylist() for col in small_test_data.column_names}, + 'new_column': [100, 200, 300, 400, 500], + } + extended_table = pa.Table.from_pydict(extended_data) + + # Load with new schema (if schema evolution enabled) + from src.amp.loaders.base import LoadMode + + result2 = loader.load_table(extended_table, 'test_table', mode=LoadMode.APPEND) + + # Result depends on merge_schema configuration + dt = DeltaTable(loader.config.table_path) + schema = dt.schema() + # New column may or may not be present depending on config + assert len(schema.fields) >= len(small_test_data.schema) + + def test_table_history(self, delta_basic_config, small_test_data): + """Test DeltaLake table version history""" + from src.amp.loaders.base import LoadMode + + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + # Load initial data + loader.load_table(small_test_data, 'test_table') + + # Append more data + loader.load_table(small_test_data, 'test_table', mode=LoadMode.APPEND) + + # Check version history + dt = DeltaTable(loader.config.table_path) + version = dt.version() + assert version >= 1 # At least 2 operations (create + append) + + # Verify history is accessible + history = dt.history() + assert len(history) >= 1 + + def test_metadata_completeness(self, delta_basic_config, comprehensive_test_data): + """Test DeltaLake metadata in load results""" + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + result = loader.load_table(comprehensive_test_data, 'test_table') + + assert result.success == True + assert 'delta_version' in result.metadata + assert 'files_added' in result.metadata + assert result.metadata['delta_version'] >= 0 + + def test_query_operations(self, delta_basic_config, comprehensive_test_data): + """Test querying DeltaLake tables""" + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + result = loader.load_table(comprehensive_test_data, 'test_table') + assert result.success == True + + # Query the table + dt = DeltaTable(loader.config.table_path) + df = dt.to_pyarrow_table() + + # Verify data integrity + assert len(df) == 1000 + 
assert 'id' in df.column_names + assert 'user_id' in df.column_names + + def test_file_size_calculation(self, delta_basic_config, comprehensive_test_data): + """Test file size calculation for DeltaLake tables""" + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + result = loader.load_table(comprehensive_test_data, 'test_table') + assert result.success == True + + # Get table size + dt = DeltaTable(loader.config.table_path) + files = dt.file_uris() + assert len(files) > 0 + + # Calculate total size + total_size = 0 + for file_uri in files: + # Remove file:// prefix if present + file_path = file_uri.replace('file://', '') + if Path(file_path).exists(): + total_size += Path(file_path).stat().st_size + + assert total_size > 0 + + def test_concurrent_operations_safety(self, delta_basic_config, small_test_data): + """Test that DeltaLake handles concurrent operations safely""" + from concurrent.futures import ThreadPoolExecutor, as_completed + + from src.amp.loaders.base import LoadMode + + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + # Load initial data + loader.load_table(small_test_data, 'test_table') + + # Try concurrent appends + def append_data(i): + return loader.load_table(small_test_data, 'test_table', mode=LoadMode.APPEND) + + with ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(append_data, i) for i in range(3)] + results = [future.result() for future in as_completed(futures)] + + # All operations should succeed + assert all(r.success for r in results) + + # Verify final row count + dt = DeltaTable(loader.config.table_path) + df = dt.to_pyarrow_table() + assert len(df) == 20 # 5 initial + 3 * 5 appends + + +@pytest.mark.delta_lake +@pytest.mark.slow +class TestDeltaLakePerformance: + """DeltaLake performance tests""" + + def test_large_data_loading(self, delta_basic_config): + """Test loading large datasets to DeltaLake""" + import pyarrow as pa + + # Create large dataset + large_data = { + 'id': list(range(50000)), + 'value': [i * 0.123 for i in range(50000)], + 'category': [f'cat_{i % 100}' for i in range(50000)], + 'year': [2024 if i < 40000 else 2023 for i in range(50000)], + 'month': [(i // 100) % 12 + 1 for i in range(50000)], + 'day': [(i // 10) % 28 + 1 for i in range(50000)], + } + large_table = pa.Table.from_pydict(large_data) + + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + result = loader.load_table(large_table, 'test_table') + + assert result.success == True + assert result.rows_loaded == 50000 + assert result.duration < 60 # Should complete within 60 seconds + + # Verify data integrity + dt = DeltaTable(loader.config.table_path) + df = dt.to_pyarrow_table() + assert len(df) == 50000 diff --git a/tests/integration/loaders/backends/test_iceberg.py b/tests/integration/loaders/backends/test_iceberg.py new file mode 100644 index 0000000..543fa06 --- /dev/null +++ b/tests/integration/loaders/backends/test_iceberg.py @@ -0,0 +1,313 @@ +""" +Iceberg-specific loader integration tests. + +This module provides Iceberg-specific test configuration and tests that +inherit from the generalized base test classes. 
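The config class below verifies results through the pyiceberg catalog rather than
through the loader itself. The round trip it relies on looks roughly like this
(a sketch using the same SQLite-backed catalog settings as the test fixtures;
the paths are placeholders):

    from pyiceberg.catalog import load_catalog

    catalog = load_catalog(
        'test',
        **{
            'type': 'sql',
            'uri': 'sqlite:////tmp/iceberg_test/catalog.db',    # placeholder local path
            'warehouse': 'file:///tmp/iceberg_test/warehouse',  # placeholder local path
        },
    )
    table = catalog.load_table(('test_data', 'test_table'))     # (namespace, table) identifier
    arrow_table = table.scan().to_arrow()                       # read back for row-count checks
    print(arrow_table.num_rows)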
+""" + +from typing import Any, Dict, List, Optional + +import pytest + +try: + from pyiceberg.catalog import load_catalog + from pyiceberg.schema import Schema + + from src.amp.loaders.implementations.iceberg_loader import IcebergLoader + from tests.integration.loaders.conftest import LoaderTestConfig + from tests.integration.loaders.test_base_loader import BaseLoaderTests + from tests.integration.loaders.test_base_streaming import BaseStreamingTests +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +class IcebergTestConfig(LoaderTestConfig): + """Iceberg-specific test configuration""" + + loader_class = IcebergLoader + config_fixture_name = 'iceberg_basic_config' + + supports_overwrite = True + supports_streaming = True + supports_multi_network = True + supports_null_values = True + + def get_row_count(self, loader: IcebergLoader, table_name: str) -> int: + """Get row count from Iceberg table""" + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + df = table.scan().to_arrow() + return len(df) + + def query_rows( + self, loader: IcebergLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from Iceberg table""" + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + df = table.scan().to_arrow() + + # Convert to list of dicts (simple implementation, no filtering) + result = [] + for i in range(min(100, len(df))): + row = {col: df[col][i].as_py() for col in df.column_names} + result.append(row) + return result + + def cleanup_table(self, loader: IcebergLoader, table_name: str) -> None: + """Drop Iceberg table""" + try: + catalog = loader._catalog + catalog.drop_table((loader.config.namespace, table_name)) + except Exception: + pass # Table may not exist + + def get_column_names(self, loader: IcebergLoader, table_name: str) -> List[str]: + """Get column names from Iceberg table""" + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + schema = table.schema() + return [field.name for field in schema.fields] + + +@pytest.mark.iceberg +class TestIcebergCore(BaseLoaderTests): + """Iceberg core loader tests (inherited from base)""" + + config = IcebergTestConfig() + + +@pytest.mark.iceberg +class TestIcebergStreaming(BaseStreamingTests): + """Iceberg streaming tests (inherited from base)""" + + config = IcebergTestConfig() + + +@pytest.mark.iceberg +class TestIcebergSpecific: + """Iceberg-specific tests that cannot be generalized""" + + def test_catalog_initialization(self, iceberg_basic_config): + """Test Iceberg catalog initialization""" + loader = IcebergLoader(iceberg_basic_config) + + with loader: + assert loader._catalog is not None + assert loader.config.namespace is not None + + # Verify namespace exists or was created + namespaces = loader._catalog.list_namespaces() + assert any(ns == (loader.config.namespace,) for ns in namespaces) + + def test_partitioning(self, iceberg_basic_config, small_test_data): + """Test Iceberg partitioning (partition spec)""" + import pyarrow as pa + + # Create config with partitioning + config = {**iceberg_basic_config, 'partition_spec': [('year', 'identity'), ('month', 'identity')]} + loader = IcebergLoader(config) + + table_name = 'test_partitioned' + + with loader: + result = loader.load_table(small_test_data, table_name) + assert result.success == True + assert result.rows_loaded == 5 + + # Verify table was 
created with partition spec + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + spec = table.spec() + # Partition spec should have fields + assert len(spec.fields) > 0 + + # Cleanup + catalog.drop_table((loader.config.namespace, table_name)) + + def test_schema_evolution(self, iceberg_basic_config, small_test_data): + """Test Iceberg schema evolution""" + import pyarrow as pa + + from src.amp.loaders.base import LoadMode + + loader = IcebergLoader(iceberg_basic_config) + table_name = 'test_schema_evolution' + + with loader: + # Load initial data + result = loader.load_table(small_test_data, table_name) + assert result.success == True + + # Add new column + extended_data = { + **{col: small_test_data[col].to_pylist() for col in small_test_data.column_names}, + 'new_column': [100, 200, 300, 400, 500], + } + extended_table = pa.Table.from_pydict(extended_data) + + # Load with new schema + result2 = loader.load_table(extended_table, table_name, mode=LoadMode.APPEND) + + # Schema evolution depends on config + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + schema = table.schema() + # New column may or may not be present + assert len(schema.fields) >= len(small_test_data.schema) + + # Cleanup + catalog.drop_table((loader.config.namespace, table_name)) + + def test_timestamp_conversion(self, iceberg_basic_config): + """Test timestamp conversion for Iceberg""" + import pyarrow as pa + from datetime import datetime + + # Create data with timestamps + data = { + 'id': [1, 2, 3], + 'timestamp': [datetime(2024, 1, 1), datetime(2024, 1, 2), datetime(2024, 1, 3)], + 'value': [100, 200, 300], + } + test_data = pa.Table.from_pydict(data) + + loader = IcebergLoader(iceberg_basic_config) + table_name = 'test_timestamps' + + with loader: + result = loader.load_table(test_data, table_name) + assert result.success == True + assert result.rows_loaded == 3 + + # Verify timestamps were converted correctly + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + df = table.scan().to_arrow() + + assert len(df) == 3 + assert 'timestamp' in df.column_names + + # Cleanup + catalog.drop_table((loader.config.namespace, table_name)) + + def test_multiple_tables(self, iceberg_basic_config, small_test_data): + """Test managing multiple tables with same loader""" + from src.amp.loaders.base import LoadMode + + loader = IcebergLoader(iceberg_basic_config) + + with loader: + # Create first table + result1 = loader.load_table(small_test_data, 'table1') + assert result1.success == True + + # Create second table + result2 = loader.load_table(small_test_data, 'table2') + assert result2.success == True + + # Verify both exist + catalog = loader._catalog + tables = catalog.list_tables(loader.config.namespace) + table_names = [t[1] for t in tables] + + assert 'table1' in table_names + assert 'table2' in table_names + + # Cleanup + catalog.drop_table((loader.config.namespace, 'table1')) + catalog.drop_table((loader.config.namespace, 'table2')) + + def test_upsert_operations(self, iceberg_basic_config): + """Test Iceberg upsert operations (merge on read)""" + import pyarrow as pa + + from src.amp.loaders.base import LoadMode + + loader = IcebergLoader(iceberg_basic_config) + table_name = 'test_upsert' + + # Initial data + data1 = {'id': [1, 2, 3], 'value': [100, 200, 300]} + table1 = pa.Table.from_pydict(data1) + + # Updated data (overlapping IDs) + data2 = {'id': [2, 3, 4], 'value': [250, 350, 400]} + 
table2 = pa.Table.from_pydict(data2) + + with loader: + # Load initial + loader.load_table(table1, table_name) + + # Upsert (if supported) or append + try: + # Try upsert mode if supported + loader.load_table(table2, table_name, mode=LoadMode.APPEND) + + # Verify row count + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + df = table.scan().to_arrow() + + # With append, we should have 6 rows (3 + 3) + # With upsert, we should have 4 rows (deduplicated) + assert len(df) >= 3 + + # Cleanup + catalog.drop_table((loader.config.namespace, table_name)) + except Exception: + # Upsert may not be supported + catalog = loader._catalog + catalog.drop_table((loader.config.namespace, table_name)) + + def test_metadata_completeness(self, iceberg_basic_config, comprehensive_test_data): + """Test Iceberg metadata in load results""" + loader = IcebergLoader(iceberg_basic_config) + table_name = 'test_metadata' + + with loader: + result = loader.load_table(comprehensive_test_data, table_name) + + assert result.success == True + assert 'snapshot_id' in result.metadata or 'files_written' in result.metadata + + # Cleanup + catalog = loader._catalog + catalog.drop_table((loader.config.namespace, table_name)) + + +@pytest.mark.iceberg +@pytest.mark.slow +class TestIcebergPerformance: + """Iceberg performance tests""" + + def test_large_data_loading(self, iceberg_basic_config): + """Test loading large datasets to Iceberg""" + import pyarrow as pa + + # Create large dataset + large_data = { + 'id': list(range(50000)), + 'value': [i * 0.123 for i in range(50000)], + 'category': [f'cat_{i % 100}' for i in range(50000)], + } + large_table = pa.Table.from_pydict(large_data) + + loader = IcebergLoader(iceberg_basic_config) + table_name = 'test_large' + + with loader: + result = loader.load_table(large_table, table_name) + + assert result.success == True + assert result.rows_loaded == 50000 + assert result.duration < 60 # Should complete within 60 seconds + + # Verify data integrity + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + df = table.scan().to_arrow() + assert len(df) == 50000 + + # Cleanup + catalog.drop_table((loader.config.namespace, table_name)) diff --git a/tests/integration/loaders/backends/test_lmdb.py b/tests/integration/loaders/backends/test_lmdb.py new file mode 100644 index 0000000..aac0000 --- /dev/null +++ b/tests/integration/loaders/backends/test_lmdb.py @@ -0,0 +1,334 @@ +""" +LMDB-specific loader integration tests. + +This module provides LMDB-specific test configuration and tests that +inherit from the generalized base test classes. 
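The LMDB tests below exercise three key strategies: a single key_column, a
key_pattern template, and composite key_columns. The kind of key derivation they
verify can be sketched like this (a hypothetical helper for illustration; the
loader's actual implementation may differ):

    def derive_key(row, key_column=None, key_columns=None, key_pattern=None):
        # Hypothetical sketch; mirrors the configs used in the tests, not the loader's code.
        if key_columns:                                   # composite key, e.g. 'eth:100:0'
            raw = ':'.join(str(row[c]) for c in key_columns)
        else:
            raw = row[key_column]                         # single column, e.g. 'id' or 'tx_hash'
            if isinstance(raw, bytes):                    # byte keys stored as-is
                return raw
            raw = str(raw)
        if key_pattern:                                   # e.g. 'tx:{key}' -> 'tx:0x100'
            raw = key_pattern.format(key=raw)
        return raw.encode()

    # Example with the composite-key test data:
    # derive_key({'network': 'eth', 'block': 100, 'tx_index': 0},
    #            key_columns=['network', 'block', 'tx_index'])  -> b'eth:100:0'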
+""" + +from pathlib import Path +from typing import Any, Dict, List, Optional + +import pytest + +try: + import lmdb + + from src.amp.loaders.implementations.lmdb_loader import LMDBLoader + from tests.integration.loaders.conftest import LoaderTestConfig + from tests.integration.loaders.test_base_loader import BaseLoaderTests + from tests.integration.loaders.test_base_streaming import BaseStreamingTests +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +class LMDBTestConfig(LoaderTestConfig): + """LMDB-specific test configuration""" + + loader_class = LMDBLoader + config_fixture_name = 'lmdb_config' # LMDB uses config fixture + + supports_overwrite = True + supports_streaming = True + supports_multi_network = True + supports_null_values = True + + def get_row_count(self, loader: LMDBLoader, table_name: str) -> int: + """Get row count from LMDB database""" + count = 0 + with loader.env.begin(db=loader.db) as txn: + cursor = txn.cursor() + for key, value in cursor: + count += 1 + return count + + def query_rows( + self, loader: LMDBLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from LMDB database""" + import json + + rows = [] + with loader.env.begin(db=loader.db) as txn: + cursor = txn.cursor() + for i, (key, value) in enumerate(cursor): + if i >= 100: # Limit to 100 + break + try: + # Try to decode as JSON + row_data = json.loads(value.decode()) + row_data['_key'] = key.decode() + rows.append(row_data) + except: + # Fallback to raw value + rows.append({'_key': key.decode(), '_value': value.decode()}) + return rows + + def cleanup_table(self, loader: LMDBLoader, table_name: str) -> None: + """Clear LMDB database""" + # LMDB doesn't have tables - clear the entire database + with loader.env.begin(db=loader.db, write=True) as txn: + cursor = txn.cursor() + # Delete all keys + keys_to_delete = [key for key, _ in cursor] + for key in keys_to_delete: + txn.delete(key) + + def get_column_names(self, loader: LMDBLoader, table_name: str) -> List[str]: + """Get column names from LMDB database (from first record)""" + import json + + with loader.env.begin(db=loader.db) as txn: + cursor = txn.cursor() + for key, value in cursor: + try: + row_data = json.loads(value.decode()) + return list(row_data.keys()) + except: + return ['_value'] # Fallback + return [] + + +@pytest.fixture +def lmdb_test_env(): + """Create and cleanup temporary directory for LMDB databases""" + import tempfile + import shutil + + temp_dir = tempfile.mkdtemp(prefix='lmdb_test_') + yield temp_dir + + # Cleanup + shutil.rmtree(temp_dir, ignore_errors=True) + + +@pytest.fixture +def lmdb_config(lmdb_test_env): + """Create LMDB config from test env directory""" + return { + 'db_path': str(Path(lmdb_test_env) / 'test.lmdb'), + 'map_size': 100 * 1024**2, # 100MB + 'transaction_size': 1000, + 'create_if_missing': True, + } + + +@pytest.fixture +def lmdb_perf_config(lmdb_test_env): + """LMDB configuration for performance testing""" + return { + 'db_path': str(Path(lmdb_test_env) / 'perf.lmdb'), + 'map_size': 500 * 1024**2, # 500MB for performance tests + 'transaction_size': 5000, + 'create_if_missing': True, + } + + +@pytest.mark.lmdb +class TestLMDBCore(BaseLoaderTests): + """LMDB core loader tests (inherited from base)""" + + # Note: LMDB config needs special handling + config = LMDBTestConfig() + + +@pytest.mark.lmdb +class TestLMDBStreaming(BaseStreamingTests): + """LMDB streaming tests (inherited from base)""" 
+ + config = LMDBTestConfig() + + +@pytest.mark.lmdb +class TestLMDBSpecific: + """LMDB-specific tests that cannot be generalized""" + + def test_key_column_strategy(self, lmdb_config, small_test_data): + """Test LMDB key column strategy""" + config = {**lmdb_config, 'key_column': 'id'} + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(small_test_data, 'test_table') + assert result.success == True + assert result.rows_loaded == 5 + + # Verify keys are based on id column + with loader.env.begin(db=loader.db) as txn: + cursor = txn.cursor() + keys = [key.decode() for key, _ in cursor] + + # Keys should be string representations of IDs + assert len(keys) == 5 + # Keys may be prefixed with table name + + def test_key_pattern_strategy(self, lmdb_config): + """Test custom key pattern generation""" + import pyarrow as pa + + data = {'tx_hash': ['0x100', '0x101', '0x102'], 'block': [100, 101, 102], 'value': [10.0, 11.0, 12.0]} + test_data = pa.Table.from_pydict(data) + + config = {**lmdb_config, 'key_column': 'tx_hash', 'key_pattern': 'tx:{key}'} + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(test_data, 'test_table') + assert result.success == True + + # Verify custom key pattern + with loader.env.begin(db=loader.db) as txn: + cursor = txn.cursor() + keys = [key.decode() for key, _ in cursor] + + # Keys should follow pattern + assert len(keys) == 3 + # At least one key should contain the pattern + assert any('tx:' in key or '0x' in key for key in keys) + + def test_composite_key_strategy(self, lmdb_config): + """Test composite key generation""" + import pyarrow as pa + + data = {'network': ['eth', 'eth', 'poly'], 'block': [100, 101, 100], 'tx_index': [0, 0, 0]} + test_data = pa.Table.from_pydict(data) + + config = {**lmdb_config, 'key_columns': ['network', 'block', 'tx_index']} + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(test_data, 'test_table') + assert result.success == True + assert result.rows_loaded == 3 + + # Verify composite keys + with loader.env.begin(db=loader.db) as txn: + cursor = txn.cursor() + keys = [key.decode() for key, _ in cursor] + assert len(keys) == 3 + + def test_named_database(self, lmdb_config): + """Test LMDB named databases (sub-databases)""" + import pyarrow as pa + + data = {'id': [1, 2, 3], 'value': [100, 200, 300]} + test_data = pa.Table.from_pydict(data) + + config = {**lmdb_config, 'database_name': 'my_table'} + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(test_data, 'test_table') + assert result.success == True + assert result.rows_loaded == 3 + + # Verify data in named database + with loader.env.begin(db=loader.db) as txn: + cursor = txn.cursor() + count = sum(1 for _ in cursor) + assert count == 3 + + def test_transaction_batching(self, lmdb_test_env): + """Test transaction batching for large datasets""" + import pyarrow as pa + + # Create large dataset + large_data = {'id': list(range(5000)), 'value': [i * 10 for i in range(5000)]} + large_table = pa.Table.from_pydict(large_data) + + config = { + 'db_path': str(Path(lmdb_test_env) / 'batch_test.lmdb'), + 'map_size': 100 * 1024**2, + 'transaction_size': 500, # Batch every 500 rows + 'create_if_missing': True, + 'key_column': 'id', + } + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(large_table, 'test_table') + assert result.success == True + assert result.rows_loaded == 5000 + + # Verify all data was loaded + with loader.env.begin(db=loader.db) as txn: + cursor = 
txn.cursor() + count = sum(1 for _ in cursor) + assert count == 5000 + + def test_byte_key_handling(self, lmdb_config): + """Test handling of byte keys""" + import pyarrow as pa + + data = {'key': [b'key1', b'key2', b'key3'], 'value': [100, 200, 300]} + test_data = pa.Table.from_pydict(data) + + config = {**lmdb_config, 'key_column': 'key'} + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(test_data, 'test_table') + assert result.success == True + assert result.rows_loaded == 3 + + def test_data_persistence(self, lmdb_config, small_test_data): + """Test that data persists after closing and reopening""" + from src.amp.loaders.base import LoadMode + + # Load data + loader1 = LMDBLoader(lmdb_config) + with loader1: + result = loader1.load_table(small_test_data, 'test_table') + assert result.success == True + + # Close and reopen + loader2 = LMDBLoader(lmdb_config) + with loader2: + # Data should still be there + with loader2.env.begin(db=loader2.db) as txn: + cursor = txn.cursor() + count = sum(1 for _ in cursor) + assert count == 5 + + # Can append more + result2 = loader2.load_table(small_test_data, 'test_table', mode=LoadMode.APPEND) + assert result2.success == True + + # Now should have 10 + with loader2.env.begin(db=loader2.db) as txn: + cursor = txn.cursor() + count = sum(1 for _ in cursor) + assert count == 10 + + +@pytest.mark.lmdb +@pytest.mark.slow +class TestLMDBPerformance: + """LMDB performance tests""" + + def test_large_data_loading(self, lmdb_perf_config): + """Test loading large datasets to LMDB""" + import pyarrow as pa + + # Create large dataset + large_data = { + 'id': list(range(50000)), + 'value': [i * 0.123 for i in range(50000)], + 'category': [f'cat_{i % 100}' for i in range(50000)], + } + large_table = pa.Table.from_pydict(large_data) + + config = {**lmdb_perf_config, 'key_column': 'id'} + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(large_table, 'test_table') + + assert result.success == True + assert result.rows_loaded == 50000 + assert result.duration < 60 # Should complete within 60 seconds + + # Verify data integrity + with loader.env.begin(db=loader.db) as txn: + cursor = txn.cursor() + count = sum(1 for _ in cursor) + assert count == 50000 From bd57ceedbd569dd2e65ad8e19b3c5da224d7dbaf Mon Sep 17 00:00:00 2001 From: Ford Date: Thu, 8 Jan 2026 16:41:31 -0800 Subject: [PATCH 4/6] Delete old loader test files --- tests/integration/test_deltalake_loader.py | 870 ------------- tests/integration/test_iceberg_loader.py | 745 ------------ tests/integration/test_lmdb_loader.py | 653 ---------- tests/integration/test_postgresql_loader.py | 838 ------------- tests/integration/test_redis_loader.py | 981 --------------- tests/integration/test_snowflake_loader.py | 1215 ------------------- 6 files changed, 5302 deletions(-) delete mode 100644 tests/integration/test_deltalake_loader.py delete mode 100644 tests/integration/test_iceberg_loader.py delete mode 100644 tests/integration/test_lmdb_loader.py delete mode 100644 tests/integration/test_postgresql_loader.py delete mode 100644 tests/integration/test_redis_loader.py delete mode 100644 tests/integration/test_snowflake_loader.py diff --git a/tests/integration/test_deltalake_loader.py b/tests/integration/test_deltalake_loader.py deleted file mode 100644 index c19c9a9..0000000 --- a/tests/integration/test_deltalake_loader.py +++ /dev/null @@ -1,870 +0,0 @@ -# tests/integration/test_deltalake_loader.py -""" -Integration tests for Delta Lake loader implementation. 
-These tests require actual Delta Lake functionality and local filesystem access. -""" - -import json -import shutil -import tempfile -from datetime import datetime, timedelta -from pathlib import Path - -import pyarrow as pa -import pytest - -from src.amp.loaders.base import LoadMode - -try: - from src.amp.loaders.implementations.deltalake_loader import DELTALAKE_AVAILABLE, DeltaLakeLoader - - # Skip all tests if deltalake is not available - if not DELTALAKE_AVAILABLE: - pytest.skip('Delta Lake not available', allow_module_level=True) - -except ImportError: - pytest.skip('amp modules not available', allow_module_level=True) - - -@pytest.fixture(scope='session') -def delta_test_env(): - """Setup Delta Lake test environment for the session""" - temp_dir = tempfile.mkdtemp(prefix='delta_test_') - yield temp_dir - # Cleanup - shutil.rmtree(temp_dir, ignore_errors=True) - - -@pytest.fixture -def delta_basic_config(delta_test_env): - """Get basic Delta Lake configuration""" - return { - 'table_path': str(Path(delta_test_env) / 'basic_table'), - 'partition_by': ['year', 'month'], - 'optimize_after_write': True, - 'vacuum_after_write': False, - 'schema_evolution': True, - 'merge_schema': True, - 'storage_options': {}, - } - - -@pytest.fixture -def delta_partitioned_config(delta_test_env): - """Get partitioned Delta Lake configuration""" - return { - 'table_path': str(Path(delta_test_env) / 'partitioned_table'), - 'partition_by': ['year', 'month', 'day'], - 'optimize_after_write': True, - 'vacuum_after_write': True, - 'schema_evolution': True, - 'merge_schema': True, - 'storage_options': {}, - } - - -@pytest.fixture -def comprehensive_test_data(): - """Create comprehensive test data for Delta Lake testing""" - base_date = datetime(2024, 1, 1) - - data = { - 'id': list(range(1000)), - 'user_id': [f'user_{i % 100}' for i in range(1000)], - 'transaction_amount': [round((i * 12.34) % 1000, 2) for i in range(1000)], - 'category': [['electronics', 'clothing', 'books', 'food', 'travel'][i % 5] for i in range(1000)], - 'timestamp': [(base_date + timedelta(days=i // 50, hours=i % 24)).isoformat() for i in range(1000)], - 'year': [2024 if i < 800 else 2023 for i in range(1000)], - 'month': [(i // 80) % 12 + 1 for i in range(1000)], - 'day': [(i // 30) % 28 + 1 for i in range(1000)], - 'is_weekend': [i % 7 in [0, 6] for i in range(1000)], - 'metadata': [ - json.dumps( - { - 'session_id': f'session_{i}', - 'device': ['mobile', 'desktop', 'tablet'][i % 3], - 'location': ['US', 'UK', 'DE', 'FR', 'JP'][i % 5], - } - ) - for i in range(1000) - ], - 'score': [i * 0.123 for i in range(1000)], - 'active': [i % 2 == 0 for i in range(1000)], - } - - return pa.Table.from_pydict(data) - - -@pytest.fixture -def small_test_data(): - """Create small test data for quick tests""" - data = { - 'id': [1, 2, 3, 4, 5], - 'name': ['a', 'b', 'c', 'd', 'e'], - 'value': [10.1, 20.2, 30.3, 40.4, 50.5], - 'year': [2024, 2024, 2024, 2024, 2024], - 'month': [1, 1, 1, 1, 1], - 'day': [1, 2, 3, 4, 5], - 'active': [True, False, True, False, True], - } - - return pa.Table.from_pydict(data) - - -@pytest.mark.integration -@pytest.mark.delta_lake -class TestDeltaLakeLoaderIntegration: - """Integration tests for Delta Lake loader""" - - def test_loader_initialization(self, delta_basic_config): - """Test loader initialization and connection""" - loader = DeltaLakeLoader(delta_basic_config) - - # Test configuration - assert loader.config.table_path == delta_basic_config['table_path'] - assert loader.config.partition_by == ['year', 'month'] - 
assert loader.config.optimize_after_write == True - assert loader.storage_backend == 'Local' - - # Test connection - loader.connect() - assert loader._is_connected == True - - # Test disconnection - loader.disconnect() - assert loader._is_connected == False - - def test_basic_table_operations(self, delta_basic_config, comprehensive_test_data): - """Test basic table creation and data loading""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Test initial table creation - result = loader.load_table(comprehensive_test_data, 'test_transactions', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.rows_loaded == 1000 - assert result.metadata['write_mode'] == 'overwrite' - assert result.metadata['storage_backend'] == 'Local' - assert result.metadata['partition_columns'] == ['year', 'month'] - - # Verify table exists - assert loader._table_exists == True - assert loader._delta_table is not None - - # Test table statistics - stats = loader.get_table_stats() - assert 'version' in stats - assert stats['storage_backend'] == 'Local' - assert stats['partition_columns'] == ['year', 'month'] - - def test_append_mode(self, delta_basic_config, comprehensive_test_data): - """Test append mode functionality""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Initial load - result = loader.load_table(comprehensive_test_data, 'test_append', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 1000 - - # Append additional data - additional_data = comprehensive_test_data.slice(0, 100) # First 100 rows - result = loader.load_table(additional_data, 'test_append', mode=LoadMode.APPEND) - - assert result.success == True - assert result.rows_loaded == 100 - assert result.metadata['write_mode'] == 'append' - - # Verify total data - final_query = loader.query_table() - assert final_query.num_rows == 1100 # 1000 + 100 - - def test_batch_loading(self, delta_basic_config, comprehensive_test_data): - """Test batch loading functionality""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Test loading individual batches - batches = comprehensive_test_data.to_batches(max_chunksize=200) - - for i, batch in enumerate(batches): - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - result = loader.load_batch(batch, 'test_batches', mode=mode) - - assert result.success == True - assert result.rows_loaded == batch.num_rows - assert result.metadata['operation'] == 'load_batch' - assert result.metadata['batch_size'] == batch.num_rows - - # Verify all data was loaded - final_query = loader.query_table() - assert final_query.num_rows == 1000 - - def test_partitioning(self, delta_partitioned_config, small_test_data): - """Test table partitioning functionality""" - loader = DeltaLakeLoader(delta_partitioned_config) - - with loader: - # Load partitioned data - result = loader.load_table(small_test_data, 'test_partitioned', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.metadata['partition_columns'] == ['year', 'month', 'day'] - - # Verify partition structure exists - table_path = Path(delta_partitioned_config['table_path']) - assert table_path.exists() - - def test_schema_evolution(self, delta_basic_config, small_test_data): - """Test schema evolution functionality""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Load initial data - result = loader.load_table(small_test_data, 'test_schema_evolution', mode=LoadMode.OVERWRITE) - - assert result.success == True - 
initial_schema = loader.get_table_schema() - initial_columns = set(initial_schema.names) - - # Create data with additional columns - extended_data_dict = small_test_data.to_pydict() - extended_data_dict['new_column'] = list(range(len(extended_data_dict['id']))) - extended_data_dict['another_field'] = ['test_value'] * len(extended_data_dict['id']) - extended_table = pa.Table.from_pydict(extended_data_dict) - - # Load extended data (should add new columns) - result = loader.load_table(extended_table, 'test_schema_evolution', mode=LoadMode.APPEND) - - assert result.success == True - - # Verify schema has evolved - evolved_schema = loader.get_table_schema() - evolved_columns = set(evolved_schema.names) - - assert 'new_column' in evolved_columns - assert 'another_field' in evolved_columns - assert evolved_columns.issuperset(initial_columns) - - def test_optimization_operations(self, delta_basic_config, comprehensive_test_data): - """Test table optimization operations""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Load data multiple times to create multiple files - for i in range(3): - subset = comprehensive_test_data.slice(i * 300, 300) - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - - result = loader.load_table(subset, 'test_optimization', mode=mode) - assert result.success == True - - optimize_result = loader.optimize_table() - - assert optimize_result['success'] == True - assert 'duration_seconds' in optimize_result - assert 'metrics' in optimize_result - - # Verify data integrity after optimization - final_data = loader.query_table() - assert final_data.num_rows == 900 # 3 * 300 - - def test_query_operations(self, delta_basic_config, comprehensive_test_data): - """Test table querying operations""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Load data - result = loader.load_table(comprehensive_test_data, 'test_query', mode=LoadMode.OVERWRITE) - assert result.success == True - - # Test basic query - query_result = loader.query_table() - assert query_result.num_rows == 1000 - - # Test column selection - query_result = loader.query_table(columns=['id', 'user_id', 'transaction_amount']) - assert query_result.num_rows == 1000 - assert query_result.column_names == ['id', 'user_id', 'transaction_amount'] - - # Test limit - query_result = loader.query_table(limit=50) - assert query_result.num_rows == 50 - - # Test combined options - query_result = loader.query_table(columns=['id', 'category'], limit=10) - assert query_result.num_rows == 10 - assert query_result.column_names == ['id', 'category'] - - def test_error_handling(self, delta_temp_config): - """Test error handling scenarios""" - loader = DeltaLakeLoader(delta_temp_config) - - with loader: - # Test loading invalid data (missing partition columns) - invalid_data = pa.table( - { - 'id': [1, 2, 3], - 'name': ['a', 'b', 'c'], - # Missing 'year' and 'month' partition columns - } - ) - - result = loader.load_table(invalid_data, 'test_errors', mode=LoadMode.OVERWRITE) - - # Should handle error gracefully - assert result.success == False - assert result.error is not None - assert result.rows_loaded == 0 - - def test_table_history(self, delta_basic_config, small_test_data): - """Test table history functionality""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Create multiple versions - for i in range(3): - subset = small_test_data.slice(i, 1) - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - - result = loader.load_table(subset, 'test_history', mode=mode) - 
assert result.success == True - - # Get history - history = loader.get_table_history() - assert len(history) >= 3 - - # Verify history structure - for entry in history: - assert 'version' in entry - assert 'operation' in entry - assert 'timestamp' in entry - - def test_context_manager(self, delta_basic_config, small_test_data): - """Test context manager functionality""" - loader = DeltaLakeLoader(delta_basic_config) - - # Test context manager - with loader: - assert loader._is_connected == True - - result = loader.load_table(small_test_data, 'test_context', mode=LoadMode.OVERWRITE) - assert result.success == True - - # Should be disconnected after context - assert loader._is_connected == False - - def test_metadata_completeness(self, delta_basic_config, comprehensive_test_data): - """Test metadata completeness in results""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - result = loader.load_table(comprehensive_test_data, 'test_metadata', mode=LoadMode.OVERWRITE) - - assert result.success == True - - # Check required metadata fields - metadata = result.metadata - required_fields = [ - 'write_mode', - 'storage_backend', - 'partition_columns', - 'throughput_rows_per_sec', - 'table_version', - ] - - for field in required_fields: - assert field in metadata, f'Missing metadata field: {field}' - - # Verify metadata values - assert metadata['write_mode'] == 'overwrite' - assert metadata['storage_backend'] == 'Local' - assert metadata['partition_columns'] == ['year', 'month'] - assert metadata['throughput_rows_per_sec'] > 0 - - def test_null_value_handling(self, delta_basic_config, null_test_data): - """Test comprehensive null value handling across all data types""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - result = loader.load_table(null_test_data, 'test_nulls', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 10 - - query_result = loader.query_table() - assert query_result.num_rows == 10 - - df = query_result.to_pandas() - - text_nulls = df['text_field'].isna().sum() - assert text_nulls == 3 # Rows 3, 6, 9 should be NULL - - int_nulls = df['int_field'].isna().sum() - assert int_nulls == 3 # Rows 2, 5, 8 should be NULL - - float_nulls = df['float_field'].isna().sum() - assert float_nulls == 3 # Rows 3, 6, 9 should be NULL - - bool_nulls = df['bool_field'].isna().sum() - assert bool_nulls == 3 # Rows 3, 6, 9 should be NULL - - timestamp_nulls = df['timestamp_field'].isna().sum() - assert timestamp_nulls == 4 # Rows where i % 3 == 0 - - # Verify non-null values are intact - assert df.loc[df['id'] == 1, 'text_field'].iloc[0] == 'a' - assert df.loc[df['id'] == 1, 'int_field'].iloc[0] == 1 - assert abs(df.loc[df['id'] == 1, 'float_field'].iloc[0] - 1.1) < 0.01 - assert df.loc[df['id'] == 1, 'bool_field'].iloc[0] == True - - # Test schema evolution with null values - from datetime import datetime - - additional_data = pa.table( - { - 'id': [11, 12], - 'text_field': ['k', None], - 'int_field': [None, 12], - 'float_field': [11.1, None], - 'bool_field': [None, False], - 'timestamp_field': [datetime.now(), None], # At least one non-null to preserve type - 'json_field': [None, '{"test": "value"}'], - 'year': [2024, 2024], - 'month': [1, 1], - 'day': [11, 12], - 'new_nullable_field': [None, 'new_value'], # New field with nulls - } - ) - - result = loader.load_table(additional_data, 'test_nulls', mode=LoadMode.APPEND) - assert result.success == True - assert result.rows_loaded == 2 - - # Verify schema evolved and nulls handled in 
new column - final_query = loader.query_table() - assert final_query.num_rows == 12 - - final_df = final_query.to_pandas() - new_field_nulls = final_df['new_nullable_field'].isna().sum() - assert new_field_nulls == 11 # All original rows + 1 new null row - - def test_file_size_calculation_modern_api(self, delta_basic_config, comprehensive_test_data): - """Test file size calculation using modern get_add_file_sizes API""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - result = loader.load_table(comprehensive_test_data, 'test_file_sizes', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 1000 - - table_info = loader._get_table_info() - - # Verify size calculation worked - assert 'size_bytes' in table_info - assert table_info['size_bytes'] > 0, 'File size should be greater than 0' - assert table_info['num_files'] > 0, 'Should have at least one file' - - # Verify metadata includes size information - assert 'total_size_bytes' in result.metadata - assert result.metadata['total_size_bytes'] > 0 - - -@pytest.mark.integration -@pytest.mark.delta_lake -@pytest.mark.slow -class TestDeltaLakeLoaderAdvanced: - """Advanced integration tests for Delta Lake loader""" - - def test_large_data_performance(self, delta_basic_config): - """Test performance with larger datasets""" - # Create larger test dataset - large_data = { - 'id': list(range(50000)), - 'value': [i * 0.123 for i in range(50000)], - 'category': [f'category_{i % 10}' for i in range(50000)], - 'year': [2024] * 50000, - 'month': [(i // 4000) % 12 + 1 for i in range(50000)], - 'timestamp': [datetime.now().isoformat() for _ in range(50000)], - } - - large_table = pa.Table.from_pydict(large_data) - - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Load large dataset - result = loader.load_table(large_table, 'test_performance', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.rows_loaded == 50000 - - # Verify performance metrics - assert result.metadata['throughput_rows_per_sec'] > 100 # Should be reasonably fast - assert result.duration < 120 # Should complete within reasonable time - - def test_concurrent_operations_safety(self, delta_basic_config, small_test_data): - """Test that operations are handled safely (basic concurrency test)""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Load initial data - result = loader.load_table(small_test_data, 'test_concurrent', mode=LoadMode.OVERWRITE) - assert result.success == True - - # Perform multiple operations in sequence (simulating concurrent-like scenario) - operations = [] - - # Append operations - for i in range(3): - subset = small_test_data.slice(i, 1) - result = loader.load_table(subset, 'test_concurrent', mode=LoadMode.APPEND) - operations.append(result) - - # Verify all operations succeeded - for result in operations: - assert result.success == True - - # Verify final data integrity - final_data = loader.query_table() - assert final_data.num_rows == 8 # 5 + 3 * 1 - - def test_handle_reorg_no_table(self, delta_basic_config): - """Test reorg handling when table doesn't exist""" - from src.amp.streaming.types import BlockRange - - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Call handle reorg on non-existent table - invalidation_ranges = [BlockRange(network='ethereum', start=100, end=200)] - - # Should not raise any errors - loader._handle_reorg(invalidation_ranges, 'test_reorg_empty', 'test_connection') - - def test_handle_reorg_no_metadata_column(self, 
delta_basic_config): - """Test reorg handling when table lacks metadata column""" - from src.amp.streaming.types import BlockRange - - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Create table without metadata column - data = pa.table( - { - 'id': [1, 2, 3], - 'block_num': [100, 150, 200], - 'value': [10.0, 20.0, 30.0], - 'year': [2024, 2024, 2024], - 'month': [1, 1, 1], - } - ) - loader.load_table(data, 'test_reorg_no_meta', mode=LoadMode.OVERWRITE) - - # Call handle reorg - invalidation_ranges = [BlockRange(network='ethereum', start=150, end=250)] - - # Should log warning and not modify data - loader._handle_reorg(invalidation_ranges, 'test_reorg_no_meta', 'test_connection') - - # Verify data unchanged - remaining_data = loader.query_table() - assert remaining_data.num_rows == 3 - - def test_handle_reorg_single_network(self, delta_temp_config): - """Test reorg handling for single network data""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = DeltaLakeLoader(delta_temp_config) - - with loader: - # Create streaming batches with metadata - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'block_num': [105], 'year': [2024], 'month': [1]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'block_num': [155], 'year': [2024], 'month': [1]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'block_num': [205], 'year': [2024], 'month': [1]}) - - # Create response batches with hashes - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0x123')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_single')) - assert len(results) == 3 - assert all(r.success for r in results) - - # Verify all data exists - initial_data = loader.query_table() - assert initial_data.num_rows == 3 - - # Reorg from block 155 - should delete rows 2 and 3 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=155, end=300)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_single')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - assert reorg_results[0].is_reorg - - # Verify only first row remains - remaining_data = loader.query_table() - assert remaining_data.num_rows == 1 - assert remaining_data['id'][0].as_py() == 1 - - def test_handle_reorg_multi_network(self, delta_temp_config): - """Test reorg handling preserves data from unaffected networks""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = DeltaLakeLoader(delta_temp_config) - - with loader: - # Create streaming batches from multiple networks - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'network': ['ethereum'], 'year': [2024], 'month': [1]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 
'network': ['polygon'], 'year': [2024], 'month': [1]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'network': ['ethereum'], 'year': [2024], 'month': [1]}) - batch4 = pa.RecordBatch.from_pydict({'id': [4], 'network': ['polygon'], 'year': [2024], 'month': [1]}) - - # Create response batches with network-specific ranges - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='polygon', start=100, end=110, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response4 = ResponseBatch.data_batch( - data=batch4, - metadata=BatchMetadata( - ranges=[BlockRange(network='polygon', start=150, end=160, hash='0xddd')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3, response4] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_multi')) - assert len(results) == 4 - assert all(r.success for r in results) - - # Reorg only ethereum from block 150 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_multi')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Verify ethereum row 3 deleted, but polygon rows preserved - remaining_data = loader.query_table() - assert remaining_data.num_rows == 3 - remaining_ids = sorted([id.as_py() for id in remaining_data['id']]) - assert remaining_ids == [1, 2, 4] # Row 3 deleted - - def test_handle_reorg_overlapping_ranges(self, delta_temp_config): - """Test reorg with overlapping block ranges""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = DeltaLakeLoader(delta_temp_config) - - with loader: - # Create streaming batches with different ranges - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'year': [2024], 'month': [1]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'year': [2024], 'month': [1]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'year': [2024], 'month': [1]}) - - # Batch 1: 90-110 (ends before reorg start of 150) - # Batch 2: 140-160 (overlaps with reorg) - # Batch 3: 170-190 (after reorg, but should be deleted as 170 >= 150) - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=90, end=110, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=140, end=160, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=170, end=190, hash='0xccc')], - ranges_complete=True, # Mark as complete so it 
gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_overlap')) - assert len(results) == 3 - assert all(r.success for r in results) - - # Reorg from block 150 - should delete batches 2 and 3 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_overlap')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Only first row should remain (ends at 110 < 150) - remaining_data = loader.query_table() - assert remaining_data.num_rows == 1 - assert remaining_data['id'][0].as_py() == 1 - - def test_handle_reorg_version_history(self, delta_temp_config): - """Test that reorg creates proper version history in Delta Lake""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = DeltaLakeLoader(delta_temp_config) - - with loader: - # Create streaming batches - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'year': [2024], 'month': [1]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'year': [2024], 'month': [1]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'year': [2024], 'month': [1]}) - - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=0, end=10, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=50, end=60, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_history')) - assert len(results) == 3 - - initial_version = loader._delta_table.version() - - # Perform reorg - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=50, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_history')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Check that version increased - final_version = loader._delta_table.version() - assert final_version > initial_version - - # Check history - history = loader.get_table_history(limit=5) - assert len(history) >= 2 - # Latest operation should be an overwrite (from reorg) - assert history[0]['operation'] == 'WRITE' - - def test_streaming_with_reorg(self, delta_temp_config): - """Test streaming data with reorg support""" - from src.amp.streaming.types import ( - BatchMetadata, - BlockRange, - ResponseBatch, - ) - - loader = DeltaLakeLoader(delta_temp_config) - - with loader: - # Create streaming data with metadata - data1 = pa.RecordBatch.from_pydict( - {'id': [1, 2], 'value': [100, 200], 'year': [2024, 2024], 'month': [1, 1]} - ) - - data2 = pa.RecordBatch.from_pydict( - {'id': [3, 4], 'value': [300, 400], 'year': [2024, 2024], 'month': [1, 1]} - ) - - # Create response batches using factory methods (with hashes for 
proper state management) - response1 = ResponseBatch.data_batch( - data=data1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - response2 = ResponseBatch.data_batch( - data=data2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Simulate reorg event using factory method - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - - # Process streaming data - stream = [response1, response2, reorg_response] - results = list(loader.load_stream_continuous(iter(stream), 'test_streaming_reorg')) - - # Verify results - assert len(results) == 3 - assert results[0].success - assert results[0].rows_loaded == 2 - assert results[1].success - assert results[1].rows_loaded == 2 - assert results[2].success - assert results[2].is_reorg - - # Verify reorg deleted the second batch - final_data = loader.query_table() - assert final_data.num_rows == 2 - remaining_ids = sorted([id.as_py() for id in final_data['id']]) - assert remaining_ids == [1, 2] # 3 and 4 deleted by reorg diff --git a/tests/integration/test_iceberg_loader.py b/tests/integration/test_iceberg_loader.py deleted file mode 100644 index cbbe4bf..0000000 --- a/tests/integration/test_iceberg_loader.py +++ /dev/null @@ -1,745 +0,0 @@ -# tests/integration/test_iceberg_loader.py -""" -Integration tests for Apache Iceberg loader implementation. -These tests require actual Iceberg functionality and catalog access. -""" - -import json -import tempfile -from datetime import datetime, timedelta - -import pyarrow as pa -import pytest - -from src.amp.loaders.base import LoadMode - -try: - from src.amp.loaders.implementations.iceberg_loader import ICEBERG_AVAILABLE, IcebergLoader - - # Skip all tests if iceberg is not available - if not ICEBERG_AVAILABLE: - pytest.skip('Apache Iceberg not available', allow_module_level=True) - -except ImportError: - pytest.skip('amp modules not available', allow_module_level=True) - - -@pytest.fixture(scope='session') -def iceberg_test_env(): - """Setup Iceberg test environment for the session""" - temp_dir = tempfile.mkdtemp(prefix='iceberg_test_') - yield temp_dir - # Note: cleanup is handled by temp directory auto-cleanup - - -@pytest.fixture -def iceberg_basic_config(iceberg_test_env): - """Get basic Iceberg configuration with local file catalog""" - return { - 'catalog_config': { - 'type': 'sql', - 'uri': f'sqlite:///{iceberg_test_env}/catalog.db', - 'warehouse': f'file://{iceberg_test_env}/warehouse', - }, - 'namespace': 'test_data', - 'create_namespace': True, - 'create_table': True, - 'schema_evolution': True, - 'batch_size': 1000, - } - - -@pytest.fixture -def iceberg_partitioned_config(iceberg_test_env): - """Get partitioned Iceberg configuration""" - # Note: partition_spec should be created with actual PartitionSpec when needed - # For now, return config without partitioning since we need schema first - return { - 'catalog_config': { - 'type': 'sql', - 'uri': f'sqlite:///{iceberg_test_env}/catalog.db', - 'warehouse': f'file://{iceberg_test_env}/warehouse', - }, - 'namespace': 'partitioned_data', - 'create_namespace': True, - 'create_table': True, - 'partition_spec': None, # Will be set in test if needed - 'schema_evolution': True, - 'batch_size': 500, - } - - 
-@pytest.fixture -def iceberg_temp_config(iceberg_test_env): - """Get temporary Iceberg configuration with unique namespace""" - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - return { - 'catalog_config': { - 'type': 'sql', - 'uri': f'sqlite:///{iceberg_test_env}/catalog.db', - 'warehouse': f'file://{iceberg_test_env}/warehouse', - }, - 'namespace': f'temp_data_{timestamp}', - 'create_namespace': True, - 'create_table': True, - 'schema_evolution': True, - 'batch_size': 2000, - } - - -@pytest.fixture -def comprehensive_test_data(): - """Create comprehensive test data for Iceberg testing""" - base_date = datetime(2024, 1, 1) - - data = { - 'id': list(range(1000)), - 'user_id': [f'user_{i % 100}' for i in range(1000)], - 'transaction_amount': [round((i * 12.34) % 1000, 2) for i in range(1000)], - 'category': [['electronics', 'clothing', 'books', 'food', 'travel'][i % 5] for i in range(1000)], - 'timestamp': pa.array( - [(base_date + timedelta(days=i // 50, hours=i % 24)) for i in range(1000)], - type=pa.timestamp('ns', tz='UTC'), - ), - 'year': [2024 if i < 800 else 2023 for i in range(1000)], - 'month': [(i // 80) % 12 + 1 for i in range(1000)], - 'day': [(i // 30) % 28 + 1 for i in range(1000)], - 'is_weekend': [i % 7 in [0, 6] for i in range(1000)], - 'metadata': [ - json.dumps( - { - 'session_id': f'session_{i}', - 'device': ['mobile', 'desktop', 'tablet'][i % 3], - 'location': ['US', 'UK', 'DE', 'FR', 'JP'][i % 5], - } - ) - for i in range(1000) - ], - 'score': [i * 0.123 for i in range(1000)], - 'active': [i % 2 == 0 for i in range(1000)], - } - - return pa.Table.from_pydict(data) - - -@pytest.fixture -def small_test_data(): - """Create small test data for quick tests""" - data = { - 'id': [1, 2, 3, 4, 5], - 'name': ['a', 'b', 'c', 'd', 'e'], - 'value': [10.1, 20.2, 30.3, 40.4, 50.5], - 'timestamp': pa.array( - [ - datetime(2024, 1, 1, 10, 0, 0), - datetime(2024, 1, 1, 11, 0, 0), - datetime(2024, 1, 1, 12, 0, 0), - datetime(2024, 1, 1, 13, 0, 0), - datetime(2024, 1, 1, 14, 0, 0), - ], - type=pa.timestamp('ns', tz='UTC'), - ), - 'year': [2024, 2024, 2024, 2024, 2024], - 'month': [1, 1, 1, 1, 1], - 'day': [1, 2, 3, 4, 5], - 'active': [True, False, True, False, True], - } - - return pa.Table.from_pydict(data) - - -@pytest.mark.integration -@pytest.mark.iceberg -class TestIcebergLoaderIntegration: - """Integration tests for Iceberg loader""" - - def test_loader_initialization(self, iceberg_basic_config): - """Test loader initialization and connection""" - loader = IcebergLoader(iceberg_basic_config) - - assert loader.config.namespace == iceberg_basic_config['namespace'] - assert loader.config.create_namespace == True - assert loader.config.create_table == True - - loader.connect() - assert loader._is_connected == True - assert loader._catalog is not None - assert loader._namespace_exists == True - - loader.disconnect() - assert loader._is_connected == False - assert loader._catalog is None - - def test_basic_table_operations(self, iceberg_basic_config, comprehensive_test_data): - """Test basic table creation and data loading""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - result = loader.load_table(comprehensive_test_data, 'test_transactions', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.rows_loaded == 1000 - assert result.loader_type == 'iceberg' - assert result.table_name == 'test_transactions' - assert 'operation' in result.metadata - assert 'rows_loaded' in result.metadata - assert result.metadata['namespace'] == 
iceberg_basic_config['namespace'] - - def test_append_mode(self, iceberg_basic_config, comprehensive_test_data): - """Test append mode functionality""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - result = loader.load_table(comprehensive_test_data, 'test_append', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 1000 - - additional_data = comprehensive_test_data.slice(0, 100) - result = loader.load_table(additional_data, 'test_append', mode=LoadMode.APPEND) - - assert result.success == True - assert result.rows_loaded == 100 - assert result.metadata['operation'] == 'load_table' - - def test_batch_loading(self, iceberg_basic_config, comprehensive_test_data): - """Test batch loading functionality""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - batches = comprehensive_test_data.to_batches(max_chunksize=200) - - for i, batch in enumerate(batches): - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - result = loader.load_batch(batch, 'test_batches', mode=mode) - - assert result.success == True - assert result.rows_loaded == batch.num_rows - assert result.metadata['operation'] == 'load_batch' - assert result.metadata['batch_size'] == batch.num_rows - assert result.metadata['schema_fields'] == len(batch.schema) - - def test_partitioning(self, iceberg_partitioned_config, small_test_data): - """Test table partitioning functionality""" - loader = IcebergLoader(iceberg_partitioned_config) - - with loader: - # Load partitioned data - result = loader.load_table(small_test_data, 'test_partitioned', mode=LoadMode.OVERWRITE) - - assert result.success == True - # Note: Partitioning requires creating PartitionSpec objects now - assert result.metadata['namespace'] == iceberg_partitioned_config['namespace'] - - def test_timestamp_conversion(self, iceberg_basic_config): - """Test timestamp precision conversion (ns -> us)""" - # Create data with nanosecond timestamps - timestamp_data = pa.table( - { - 'id': [1, 2, 3], - 'event_time': pa.array( - [ - datetime(2024, 1, 1, 10, 0, 0, 123456), # microsecond precision - datetime(2024, 1, 1, 11, 0, 0, 654321), - datetime(2024, 1, 1, 12, 0, 0, 987654), - ], - type=pa.timestamp('ns', tz='UTC'), - ), # nanosecond type - 'name': ['event1', 'event2', 'event3'], - } - ) - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Just verify that we can load nanosecond timestamps successfully - result = loader.load_table(timestamp_data, 'test_timestamps', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.rows_loaded == 3 - - # The conversion happens internally - we just care that it works - - def test_schema_evolution(self, iceberg_basic_config, small_test_data): - """Test schema evolution functionality""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - result = loader.load_table(small_test_data, 'test_schema_evolution', mode=LoadMode.OVERWRITE) - assert result.success == True - - extended_data_dict = small_test_data.to_pydict() - extended_data_dict['new_column'] = list(range(len(extended_data_dict['id']))) - extended_data_dict['another_field'] = ['test_value'] * len(extended_data_dict['id']) - extended_table = pa.Table.from_pydict(extended_data_dict) - - result = loader.load_table(extended_table, 'test_schema_evolution', mode=LoadMode.APPEND) - - # Schema evolution should work successfully - assert result.success == True - assert result.rows_loaded > 0 - - # Verify that the new columns were added to the schema - table_info = 
loader.get_table_info('test_schema_evolution') - assert table_info['exists'] == True - assert 'new_column' in table_info['columns'] - assert 'another_field' in table_info['columns'] - - def test_error_handling_invalid_catalog(self): - """Test error handling with invalid catalog configuration""" - invalid_config = { - 'catalog_config': {'type': 'invalid_catalog_type', 'uri': 'invalid://invalid'}, - 'namespace': 'test', - 'create_namespace': True, - 'create_table': True, - } - - loader = IcebergLoader(invalid_config) - - # Should raise an error on connect - with pytest.raises(ValueError): - loader.connect() - - def test_error_handling_invalid_namespace(self, iceberg_test_env): - """Test error handling when namespace creation fails""" - config = { - 'catalog_config': { - 'type': 'sql', - 'uri': f'sqlite:///{iceberg_test_env}/catalog.db', - 'warehouse': f'file://{iceberg_test_env}/warehouse', - }, - 'namespace': 'test_namespace', - 'create_namespace': False, # Don't create namespace - 'create_table': True, - } - - loader = IcebergLoader(config) - - # Should fail if namespace doesn't exist and create_namespace=False - from pyiceberg.exceptions import NoSuchNamespaceError - - with pytest.raises(NoSuchNamespaceError): - loader.connect() - - def test_load_mode_overwrite(self, iceberg_basic_config, small_test_data): - """Test overwrite mode functionality""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Initial load - result = loader.load_table(small_test_data, 'test_overwrite', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 5 - - # Create different data - different_data = pa.table( - { - 'id': [10, 20], - 'name': ['x', 'y'], - 'value': [100.0, 200.0], - 'timestamp': pa.array( - [datetime(2024, 2, 1, 10, 0, 0), datetime(2024, 2, 1, 11, 0, 0)], - type=pa.timestamp('ns', tz='UTC'), - ), - 'year': [2024, 2024], - 'month': [2, 2], - 'day': [1, 1], - 'active': [True, True], - } - ) - - # Overwrite with different data - result = loader.load_table(different_data, 'test_overwrite', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 2 - - def test_context_manager(self, iceberg_basic_config, small_test_data): - """Test context manager functionality""" - loader = IcebergLoader(iceberg_basic_config) - - # Test context manager auto-connect/disconnect - assert not loader._is_connected - - with loader: - assert loader._is_connected == True - - result = loader.load_table(small_test_data, 'test_context', mode=LoadMode.OVERWRITE) - assert result.success == True - - # Should be disconnected after context exit - assert loader._is_connected == False - - def test_load_result_metadata(self, iceberg_basic_config, comprehensive_test_data): - """Test that LoadResult contains proper metadata""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - result = loader.load_table(comprehensive_test_data, 'test_metadata', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.loader_type == 'iceberg' - assert result.table_name == 'test_metadata' - assert result.rows_loaded == 1000 - assert result.duration > 0 - - # Check metadata content - metadata = result.metadata - assert 'operation' in metadata - assert 'rows_loaded' in metadata - assert 'columns' in metadata - assert 'namespace' in metadata - assert metadata['namespace'] == iceberg_basic_config['namespace'] - assert metadata['rows_loaded'] == 1000 - assert metadata['columns'] == len(comprehensive_test_data.schema) - - -@pytest.mark.integration 
-@pytest.mark.iceberg -@pytest.mark.slow -class TestIcebergLoaderAdvanced: - """Advanced integration tests for Iceberg loader""" - - def test_large_data_performance(self, iceberg_basic_config): - """Test performance with larger datasets""" - # Create larger test dataset - large_data = { - 'id': list(range(10000)), - 'value': [i * 0.123 for i in range(10000)], - 'category': [f'category_{i % 10}' for i in range(10000)], - 'year': [2024] * 10000, - 'month': [(i // 800) % 12 + 1 for i in range(10000)], - 'timestamp': pa.array( - [datetime(2024, 1, 1) + timedelta(seconds=i) for i in range(10000)], type=pa.timestamp('ns', tz='UTC') - ), - } - - large_table = pa.Table.from_pydict(large_data) - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Load large dataset - result = loader.load_table(large_table, 'test_performance', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.rows_loaded == 10000 - assert result.duration < 300 # Should complete within reasonable time - - def test_multiple_tables_same_loader(self, iceberg_basic_config, small_test_data): - """Test loading multiple tables with the same loader instance""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - table_names = ['table_1', 'table_2', 'table_3'] - - for table_name in table_names: - result = loader.load_table(small_test_data, table_name, mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.table_name == table_name - assert result.rows_loaded == 5 - - def test_batch_streaming(self, iceberg_basic_config, comprehensive_test_data): - """Test streaming batch operations""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Convert to batch iterator - batches = comprehensive_test_data.to_batches(max_chunksize=100) - batch_list = list(batches) - - # Load using load_stream method from base class - results = list(loader.load_stream(iter(batch_list), 'test_streaming')) - - # Verify all batches were processed - total_rows = sum(r.rows_loaded for r in results if r.success) - assert total_rows == 1000 - - # Verify all operations succeeded - for result in results: - assert result.success == True - assert result.loader_type == 'iceberg' - - def test_upsert_operations(self, iceberg_basic_config): - """Test UPSERT/MERGE operations with automatic matching""" - # Use basic config - no special configuration needed for upsert - upsert_config = iceberg_basic_config.copy() - - # Initial data - initial_data = {'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie'], 'value': [100, 200, 300]} - initial_table = pa.Table.from_pydict(initial_data) - - loader = IcebergLoader(upsert_config) - - with loader: - # Load initial data - result1 = loader.load_table(initial_table, 'test_upsert', mode=LoadMode.APPEND) - assert result1.success == True - assert result1.rows_loaded == 3 - - # Upsert data (update existing + insert new) - upsert_data = { - 'id': [2, 3, 4], # 2,3 exist (update), 4 is new (insert) - 'name': ['Bob_Updated', 'Charlie_Updated', 'David'], - 'value': [250, 350, 400], - } - upsert_table = pa.Table.from_pydict(upsert_data) - - result2 = loader.load_table(upsert_table, 'test_upsert', mode=LoadMode.UPSERT) - assert result2.success == True - assert result2.rows_loaded == 3 - - def test_upsert_simple(self, iceberg_basic_config): - """Test simple UPSERT operations with default behavior""" - - test_data = {'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie'], 'value': [100, 200, 300]} - test_table = pa.Table.from_pydict(test_data) - - loader = 
IcebergLoader(iceberg_basic_config) - - with loader: - # Simple upsert with default settings - result = loader.load_table(test_table, 'test_simple_upsert', mode=LoadMode.UPSERT) - assert result.success == True - assert result.rows_loaded == 3 - - def test_upsert_fallback_to_append(self, iceberg_basic_config): - """Test that UPSERT falls back to APPEND when upsert fails""" - - test_data = {'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie'], 'value': [100, 200, 300]} - test_table = pa.Table.from_pydict(test_data) - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Even if upsert fails, should fallback gracefully - result = loader.load_table(test_table, 'test_upsert_fallback', mode=LoadMode.UPSERT) - assert result.success == True - assert result.rows_loaded == 3 - - def test_handle_reorg_empty_table(self, iceberg_basic_config): - """Test reorg handling on empty table""" - from src.amp.streaming.types import BlockRange - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create table with one row first - initial_data = pa.table( - {'id': [999], 'block_num': [999], '_meta_block_ranges': ['[{"network": "test", "start": 1, "end": 2}]']} - ) - loader.load_table(initial_data, 'test_reorg_empty', mode=LoadMode.OVERWRITE) - - # Now overwrite with empty data to simulate empty table - empty_data = pa.table({'id': [], 'block_num': [], '_meta_block_ranges': []}) - loader.load_table(empty_data, 'test_reorg_empty', mode=LoadMode.OVERWRITE) - - # Call handle reorg on empty table - invalidation_ranges = [BlockRange(network='ethereum', start=100, end=200)] - - # Should not raise any errors - loader._handle_reorg(invalidation_ranges, 'test_reorg_empty', 'test_connection') - - # Verify table still exists - table_info = loader.get_table_info('test_reorg_empty') - assert table_info['exists'] == True - - def test_handle_reorg_no_metadata_column(self, iceberg_basic_config): - """Test reorg handling when table lacks metadata column""" - from src.amp.streaming.types import BlockRange - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create table without metadata column - data = pa.table({'id': [1, 2, 3], 'block_num': [100, 150, 200], 'value': [10.0, 20.0, 30.0]}) - loader.load_table(data, 'test_reorg_no_meta', mode=LoadMode.OVERWRITE) - - # Call handle reorg - invalidation_ranges = [BlockRange(network='ethereum', start=150, end=250)] - - # Should log warning and not modify data - loader._handle_reorg(invalidation_ranges, 'test_reorg_no_meta', 'test_connection') - - # Verify data unchanged - table_info = loader.get_table_info('test_reorg_no_meta') - assert table_info['exists'] == True - - def test_handle_reorg_single_network(self, iceberg_basic_config): - """Test reorg handling for single network data""" - from src.amp.streaming.types import BlockRange - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create table with metadata - block_ranges = [ - [{'network': 'ethereum', 'start': 100, 'end': 110}], - [{'network': 'ethereum', 'start': 150, 'end': 160}], - [{'network': 'ethereum', 'start': 200, 'end': 210}], - ] - - data = pa.table( - { - 'id': [1, 2, 3], - 'block_num': [105, 155, 205], - '_meta_block_ranges': [json.dumps(ranges) for ranges in block_ranges], - } - ) - - # Load initial data - result = loader.load_table(data, 'test_reorg_single', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 3 - - # Reorg from block 155 - should delete rows 2 and 3 - invalidation_ranges = [BlockRange(network='ethereum', 
start=155, end=300)] - loader._handle_reorg(invalidation_ranges, 'test_reorg_single', 'test_connection') - - # Verify only first row remains - # Since we can't easily query Iceberg tables in tests, we'll verify through table info - table_info = loader.get_table_info('test_reorg_single') - assert table_info['exists'] == True - # The actual row count verification would require scanning the table - - def test_handle_reorg_multi_network(self, iceberg_basic_config): - """Test reorg handling preserves data from unaffected networks""" - from src.amp.streaming.types import BlockRange - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create table with data from multiple networks - block_ranges = [ - [{'network': 'ethereum', 'start': 100, 'end': 110}], - [{'network': 'polygon', 'start': 100, 'end': 110}], - [{'network': 'ethereum', 'start': 150, 'end': 160}], - [{'network': 'polygon', 'start': 150, 'end': 160}], - ] - - data = pa.table( - { - 'id': [1, 2, 3, 4], - 'network': ['ethereum', 'polygon', 'ethereum', 'polygon'], - '_meta_block_ranges': [json.dumps(r) for r in block_ranges], - } - ) - - # Load initial data - result = loader.load_table(data, 'test_reorg_multi', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 4 - - # Reorg only ethereum from block 150 - invalidation_ranges = [BlockRange(network='ethereum', start=150, end=200)] - loader._handle_reorg(invalidation_ranges, 'test_reorg_multi', 'test_connection') - - # Verify ethereum row 3 deleted, but polygon rows preserved - table_info = loader.get_table_info('test_reorg_multi') - assert table_info['exists'] == True - - def test_handle_reorg_overlapping_ranges(self, iceberg_basic_config): - """Test reorg with overlapping block ranges""" - from src.amp.streaming.types import BlockRange - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create data with overlapping ranges - block_ranges = [ - [{'network': 'ethereum', 'start': 90, 'end': 110}], # Overlaps with reorg - [{'network': 'ethereum', 'start': 140, 'end': 160}], # Overlaps with reorg - [{'network': 'ethereum', 'start': 170, 'end': 190}], # After reorg - ] - - data = pa.table({'id': [1, 2, 3], '_meta_block_ranges': [json.dumps(ranges) for ranges in block_ranges]}) - - # Load initial data - result = loader.load_table(data, 'test_reorg_overlap', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 3 - - # Reorg from block 150 - should delete rows where end >= 150 - invalidation_ranges = [BlockRange(network='ethereum', start=150, end=200)] - loader._handle_reorg(invalidation_ranges, 'test_reorg_overlap', 'test_connection') - - # Only first row should remain (ends at 110 < 150) - table_info = loader.get_table_info('test_reorg_overlap') - assert table_info['exists'] == True - - def test_handle_reorg_multiple_invalidations(self, iceberg_basic_config): - """Test handling multiple invalidation ranges""" - from src.amp.streaming.types import BlockRange - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create data from multiple networks - block_ranges = [ - [{'network': 'ethereum', 'start': 100, 'end': 110}], - [{'network': 'polygon', 'start': 200, 'end': 210}], - [{'network': 'arbitrum', 'start': 300, 'end': 310}], - [{'network': 'ethereum', 'start': 150, 'end': 160}], - [{'network': 'polygon', 'start': 250, 'end': 260}], - ] - - data = pa.table({'id': [1, 2, 3, 4, 5], '_meta_block_ranges': [json.dumps(r) for r in block_ranges]}) - - # Load initial data - result = 
loader.load_table(data, 'test_reorg_multiple', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 5 - - # Multiple reorgs - invalidation_ranges = [ - BlockRange(network='ethereum', start=150, end=200), # Affects row 4 - BlockRange(network='polygon', start=250, end=300), # Affects row 5 - ] - loader._handle_reorg(invalidation_ranges, 'test_reorg_multiple', 'test_connection') - - # Rows 1, 2, 3 should remain - table_info = loader.get_table_info('test_reorg_multiple') - assert table_info['exists'] == True - - def test_streaming_with_reorg(self, iceberg_basic_config): - """Test streaming data with reorg support""" - from src.amp.streaming.types import ( - BatchMetadata, - BlockRange, - ResponseBatch, - ) - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create streaming data with metadata - data1 = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) - - data2 = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) - - # Create response batches using factory methods (with hashes for proper state management) - response1 = ResponseBatch.data_batch( - data=data1, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')]), - ) - - response2 = ResponseBatch.data_batch( - data=data2, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')]), - ) - - # Simulate reorg event using factory method - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - - # Process streaming data - stream = [response1, response2, reorg_response] - results = list(loader.load_stream_continuous(iter(stream), 'test_streaming_reorg')) - - # Verify results - assert len(results) == 3 - assert results[0].success == True - assert results[0].rows_loaded == 2 - assert results[1].success == True - assert results[1].rows_loaded == 2 - assert results[2].success == True - assert results[2].is_reorg == True diff --git a/tests/integration/test_lmdb_loader.py b/tests/integration/test_lmdb_loader.py deleted file mode 100644 index 20e2f67..0000000 --- a/tests/integration/test_lmdb_loader.py +++ /dev/null @@ -1,653 +0,0 @@ -# tests/integration/test_lmdb_loader.py -""" -Integration tests for LMDB loader implementation. -These tests use a local LMDB instance. 
-""" - -import os -import shutil -import tempfile -import time -from datetime import datetime - -import pyarrow as pa -import pytest - -try: - from src.amp.loaders.base import LoadMode - from src.amp.loaders.implementations.lmdb_loader import LMDBLoader -except ImportError: - pytest.skip('LMD loader modules not available', allow_module_level=True) - - -@pytest.fixture -def lmdb_test_dir(): - """Create and cleanup temporary directory for LMDB databases""" - temp_dir = tempfile.mkdtemp(prefix='lmdb_test_') - yield temp_dir - - # Cleanup - shutil.rmtree(temp_dir, ignore_errors=True) - - -@pytest.fixture -def lmdb_config(lmdb_test_dir): - """LMDB configuration for testing""" - return { - 'db_path': os.path.join(lmdb_test_dir, 'test.lmdb'), - 'map_size': 100 * 1024**2, # 100MB for tests - 'transaction_size': 1000, - } - - -@pytest.fixture -def test_table_name(): - """Generate unique table name for each test""" - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') - return f'test_table_{timestamp}' - - -@pytest.fixture -def sample_test_data(): - """Create sample test data""" - data = { - 'id': list(range(100)), - 'name': [f'name_{i}' for i in range(100)], - 'value': [i * 1.5 for i in range(100)], - 'active': [i % 2 == 0 for i in range(100)], - } - return pa.Table.from_pydict(data) - - -@pytest.fixture -def blockchain_test_data(): - """Create blockchain-like test data""" - data = { - 'block_number': list(range(1000, 2000)), - 'block_hash': [f'0x{i:064x}' for i in range(1000, 2000)], - 'timestamp': [1600000000 + i * 12 for i in range(1000)], - 'transaction_count': [i % 100 + 1 for i in range(1000)], - 'gas_used': [21000 + i * 1000 for i in range(1000)], - } - return pa.Table.from_pydict(data) - - -@pytest.mark.integration -@pytest.mark.lmdb -class TestLMDBLoaderIntegration: - """Test LMDB loader implementation""" - - def test_connection(self, lmdb_config): - """Test basic connection to LMDB""" - loader = LMDBLoader(lmdb_config) - - # Should not be connected initially - assert not loader.is_connected - - # Connect - loader.connect() - assert loader.is_connected - assert loader.env is not None - - # Check that database was created - assert os.path.exists(lmdb_config['db_path']) - - # Disconnect - loader.disconnect() - assert not loader.is_connected - - def test_load_table_basic(self, lmdb_config, sample_test_data, test_table_name): - """Test basic table loading""" - loader = LMDBLoader(lmdb_config) - loader.connect() - - # Load data - result = loader.load_table(sample_test_data, test_table_name) - - assert result.success - assert result.rows_loaded == 100 - assert result.table_name == test_table_name - assert result.loader_type == 'lmdb' - - # Verify data was stored - with loader.env.begin() as txn: - cursor = txn.cursor() - count = sum(1 for _ in cursor) - assert count == 100 - - loader.disconnect() - - def test_key_column_strategy(self, lmdb_config, sample_test_data, test_table_name): - """Test key generation using specific column""" - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - result = loader.load_table(sample_test_data, test_table_name) - assert result.success - - # Verify keys were generated correctly - with loader.env.begin() as txn: - # Check specific keys - for i in range(10): - key = str(i).encode('utf-8') - value = txn.get(key) - assert value is not None - - loader.disconnect() - - def test_key_pattern_strategy(self, lmdb_config, blockchain_test_data, test_table_name): - """Test pattern-based key generation""" - config = 
{**lmdb_config, 'key_pattern': '{table}:block:{block_number}'} - loader = LMDBLoader(config) - loader.connect() - - result = loader.load_table(blockchain_test_data, test_table_name) - assert result.success - - # Verify pattern-based keys - with loader.env.begin() as txn: - key = f'{test_table_name}:block:1000'.encode('utf-8') - value = txn.get(key) - assert value is not None - - loader.disconnect() - - def test_composite_key_strategy(self, lmdb_config, sample_test_data, test_table_name): - """Test composite key generation""" - config = {**lmdb_config, 'composite_key_columns': ['name', 'id']} - loader = LMDBLoader(config) - loader.connect() - - result = loader.load_table(sample_test_data, test_table_name) - assert result.success - - # Verify composite keys - with loader.env.begin() as txn: - key = 'name_0:0'.encode('utf-8') - value = txn.get(key) - assert value is not None - - loader.disconnect() - - def test_named_database(self, lmdb_config, sample_test_data): - """Test using named databases""" - config = {**lmdb_config, 'database_name': 'blocks', 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Load to named database - result = loader.load_table(sample_test_data, 'any_table') - assert result.success - - # Verify data is in named database - db = loader.env.open_db(b'blocks') - with loader.env.begin(db=db) as txn: - cursor = txn.cursor() - count = sum(1 for _ in cursor) - assert count == 100 - - loader.disconnect() - - def test_load_modes(self, lmdb_config, sample_test_data, test_table_name): - """Test different load modes""" - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Initial load - result1 = loader.load_table(sample_test_data, test_table_name) - assert result1.success - assert result1.rows_loaded == 100 - - # Append mode (should fail for duplicate keys) - result2 = loader.load_table(sample_test_data, test_table_name, mode=LoadMode.APPEND) - assert not result2.success # Should fail due to duplicate keys - assert result2.rows_loaded == 0 # No new rows added - assert 'Key already exists' in str(result2.error) - - # Overwrite mode - result3 = loader.load_table(sample_test_data, test_table_name, mode=LoadMode.OVERWRITE) - assert result3.success - assert result3.rows_loaded == 100 - - loader.disconnect() - - def test_transaction_batching(self, lmdb_test_dir, blockchain_test_data, test_table_name): - """Test transaction batching performance""" - # Small transactions - config1 = { - 'db_path': os.path.join(lmdb_test_dir, 'test1.lmdb'), - 'map_size': 100 * 1024**2, - 'transaction_size': 100, - 'key_column': 'block_number', - } - loader1 = LMDBLoader(config1) - loader1.connect() - - start = time.time() - result1 = loader1.load_table(blockchain_test_data, test_table_name) - time1 = time.time() - start - - loader1.disconnect() - - # Large transactions - use different database - config2 = { - 'db_path': os.path.join(lmdb_test_dir, 'test2.lmdb'), - 'map_size': 100 * 1024**2, - 'transaction_size': 1000, - 'key_column': 'block_number', - } - loader2 = LMDBLoader(config2) - loader2.connect() - - start = time.time() - result2 = loader2.load_table(blockchain_test_data, test_table_name) - time2 = time.time() - start - - loader2.disconnect() - - # Both should succeed - assert result1.success - assert result2.success - - # Larger transactions should generally be faster - # (though this might not always be true in small test datasets) - print(f'Small txn time: {time1:.3f}s, Large txn time: {time2:.3f}s') - - def 
test_byte_key_handling(self, lmdb_config): - """Test handling of byte array keys""" - # Create data with byte keys - data = {'key': [b'key1', b'key2', b'key3'], 'value': [1, 2, 3]} - table = pa.Table.from_pydict(data) - - config = {**lmdb_config, 'key_column': 'key'} - loader = LMDBLoader(config) - loader.connect() - - result = loader.load_table(table, 'byte_test') - assert result.success - assert result.rows_loaded == 3 - - # Verify byte keys work correctly - with loader.env.begin() as txn: - assert txn.get(b'key1') is not None - assert txn.get(b'key2') is not None - assert txn.get(b'key3') is not None - - loader.disconnect() - - def test_large_batch_loading(self, lmdb_config): - """Test loading large batches""" - # Create larger dataset - size = 50000 - data = { - 'id': list(range(size)), - 'data': ['x' * 100 for _ in range(size)], # 100 chars per row - } - table = pa.Table.from_pydict(data) - - config = { - **lmdb_config, - 'map_size': 500 * 1024**2, # 500MB - 'transaction_size': 5000, - 'key_column': 'id', - } - loader = LMDBLoader(config) - loader.connect() - - start = time.time() - result = loader.load_table(table, 'large_test') - duration = time.time() - start - - assert result.success - assert result.rows_loaded == size - - # Check performance metrics - throughput = result.metadata['throughput_rows_per_sec'] - print(f'Loaded {size} rows in {duration:.2f}s ({throughput:.0f} rows/sec)') - assert throughput > 10000 # Should handle at least 10k rows/sec - - loader.disconnect() - - def test_error_handling(self, lmdb_config, sample_test_data): - """Test error handling""" - # Test with invalid key column - config = {**lmdb_config, 'key_column': 'nonexistent_column'} - loader = LMDBLoader(config) - loader.connect() - - result = loader.load_table(sample_test_data, 'error_test') - assert not result.success - assert 'not found in data' in result.error - - loader.disconnect() - - def test_data_persistence(self, lmdb_config, sample_test_data, test_table_name): - """Test that data persists across connections""" - config = {**lmdb_config, 'key_column': 'id'} - - # First connection - write data - loader1 = LMDBLoader(config) - loader1.connect() - result = loader1.load_table(sample_test_data, test_table_name) - assert result.success - loader1.disconnect() - - # Second connection - verify data - loader2 = LMDBLoader(config) - loader2.connect() - - with loader2.env.begin() as txn: - # Check a few keys - for i in range(10): - key = str(i).encode('utf-8') - value = txn.get(key) - assert value is not None - - # Deserialize and verify it's valid Arrow data - import pyarrow as pa - - reader = pa.ipc.open_stream(value) - batch = reader.read_next_batch() - assert batch.num_rows == 1 - - loader2.disconnect() - - def test_handle_reorg_empty_db(self, lmdb_config): - """Test reorg handling on empty database""" - from src.amp.streaming.types import BlockRange - - loader = LMDBLoader(lmdb_config) - loader.connect() - - # Call handle reorg on empty database - invalidation_ranges = [BlockRange(network='ethereum', start=100, end=200)] - - # Should not raise any errors - loader._handle_reorg(invalidation_ranges, 'test_reorg_empty', 'test_connection') - - loader.disconnect() - - def test_handle_reorg_no_metadata(self, lmdb_config): - """Test reorg handling when data lacks metadata column""" - from src.amp.streaming.types import BlockRange - - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Create data without metadata column - data = pa.table({'id': [1, 2, 3], 
'block_num': [100, 150, 200], 'value': [10.0, 20.0, 30.0]}) - loader.load_table(data, 'test_reorg_no_meta', mode=LoadMode.OVERWRITE) - - # Call handle reorg - invalidation_ranges = [BlockRange(network='ethereum', start=150, end=250)] - - # Should not delete any data (no metadata to check) - loader._handle_reorg(invalidation_ranges, 'test_reorg_no_meta', 'test_connection') - - # Verify data still exists - with loader.env.begin() as txn: - assert txn.get(b'1') is not None - assert txn.get(b'2') is not None - assert txn.get(b'3') is not None - - loader.disconnect() - - def test_handle_reorg_single_network(self, lmdb_config): - """Test reorg handling for single network data""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Create streaming batches with metadata - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'block_num': [105]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'block_num': [155]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'block_num': [205]}) - - # Create response batches with hashes - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0x123')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_single')) - assert len(results) == 3 - assert all(r.success for r in results) - - # Verify all data exists - with loader.env.begin() as txn: - assert txn.get(b'1') is not None - assert txn.get(b'2') is not None - assert txn.get(b'3') is not None - - # Reorg from block 155 - should delete rows 2 and 3 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=155, end=300)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_single')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - assert reorg_results[0].is_reorg - - # Verify only first row remains - with loader.env.begin() as txn: - assert txn.get(b'1') is not None - assert txn.get(b'2') is None # Deleted - assert txn.get(b'3') is None # Deleted - - loader.disconnect() - - def test_handle_reorg_multi_network(self, lmdb_config): - """Test reorg handling preserves data from unaffected networks""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Create streaming batches from multiple networks - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'network': ['ethereum']}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'network': ['polygon']}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'network': ['ethereum']}) - batch4 = pa.RecordBatch.from_pydict({'id': [4], 'network': ['polygon']}) - - # Create 
response batches with network-specific ranges - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='polygon', start=100, end=110, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response4 = ResponseBatch.data_batch( - data=batch4, - metadata=BatchMetadata( - ranges=[BlockRange(network='polygon', start=150, end=160, hash='0xddd')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3, response4] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_multi')) - assert len(results) == 4 - assert all(r.success for r in results) - - # Reorg only ethereum from block 150 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_multi')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Verify ethereum row 3 deleted, but polygon rows preserved - with loader.env.begin() as txn: - assert txn.get(b'1') is not None # ethereum block 100 - assert txn.get(b'2') is not None # polygon block 100 - assert txn.get(b'3') is None # ethereum block 150 (deleted) - assert txn.get(b'4') is not None # polygon block 150 - - loader.disconnect() - - def test_handle_reorg_overlapping_ranges(self, lmdb_config): - """Test reorg with overlapping block ranges""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Create streaming batches with different ranges - batch1 = pa.RecordBatch.from_pydict({'id': [1]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3]}) - - # Batch 1: 90-110 (ends before reorg start of 150) - # Batch 2: 140-160 (overlaps with reorg) - # Batch 3: 170-190 (after reorg, but should be deleted as 170 >= 150) - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=90, end=110, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=140, end=160, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=170, end=190, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_overlap')) - assert len(results) == 3 - assert all(r.success for r in 
results) - - # Reorg from block 150 - should delete batches 2 and 3 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_overlap')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Only first row should remain (ends at 110 < 150) - with loader.env.begin() as txn: - assert txn.get(b'1') is not None - assert txn.get(b'2') is None # Deleted (end=160 >= 150) - assert txn.get(b'3') is None # Deleted (end=190 >= 150) - - loader.disconnect() - - def test_streaming_with_reorg(self, lmdb_config): - """Test streaming data with reorg support""" - from src.amp.streaming.types import ( - BatchMetadata, - BlockRange, - ResponseBatch, - ) - - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Create streaming data with metadata - data1 = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) - - data2 = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) - - # Create response batches using factory methods (with hashes for proper state management) - response1 = ResponseBatch.data_batch( - data=data1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - response2 = ResponseBatch.data_batch( - data=data2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Simulate reorg event using factory method - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - - # Process streaming data - stream = [response1, response2, reorg_response] - results = list(loader.load_stream_continuous(iter(stream), 'test_streaming_reorg')) - - # Verify results - assert len(results) == 3 - assert results[0].success - assert results[0].rows_loaded == 2 - assert results[1].success - assert results[1].rows_loaded == 2 - assert results[2].success - assert results[2].is_reorg - - # Verify reorg deleted the second batch - with loader.env.begin() as txn: - assert txn.get(b'1') is not None - assert txn.get(b'2') is not None - assert txn.get(b'3') is None # Deleted by reorg - assert txn.get(b'4') is None # Deleted by reorg - - loader.disconnect() - - -if __name__ == '__main__': - pytest.main([__file__, '-v']) diff --git a/tests/integration/test_postgresql_loader.py b/tests/integration/test_postgresql_loader.py deleted file mode 100644 index 649868c..0000000 --- a/tests/integration/test_postgresql_loader.py +++ /dev/null @@ -1,838 +0,0 @@ -# tests/integration/test_postgresql_loader.py -""" -Integration tests for PostgreSQL loader implementation. -These tests require a running PostgreSQL instance. 
-""" - -import time -from datetime import datetime - -import pyarrow as pa -import pytest - -try: - from src.amp.loaders.base import LoadMode - from src.amp.loaders.implementations.postgresql_loader import PostgreSQLLoader -except ImportError: - pytest.skip('amp modules not available', allow_module_level=True) - - -@pytest.fixture -def test_table_name(): - """Generate unique table name for each test""" - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') - return f'test_table_{timestamp}' - - -@pytest.fixture -def postgresql_type_test_data(): - """Create test data specifically for PostgreSQL data type testing""" - data = { - 'id': list(range(1000)), - 'text_field': [f'text_{i}' for i in range(1000)], - 'float_field': [i * 1.23 for i in range(1000)], - 'bool_field': [i % 2 == 0 for i in range(1000)], - } - return pa.Table.from_pydict(data) - - -@pytest.fixture -def cleanup_tables(postgresql_test_config): - """Cleanup test tables after tests""" - tables_to_clean = [] - - yield tables_to_clean - - # Cleanup - loader = PostgreSQLLoader(postgresql_test_config) - try: - loader.connect() - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - for table in tables_to_clean: - try: - cur.execute(f'DROP TABLE IF EXISTS {table} CASCADE') - conn.commit() - except Exception: - pass - finally: - loader.pool.putconn(conn) - loader.disconnect() - except Exception: - pass - - -@pytest.mark.integration -@pytest.mark.postgresql -class TestPostgreSQLLoaderIntegration: - """Integration tests for PostgreSQL loader""" - - def test_loader_connection(self, postgresql_test_config): - """Test basic connection to PostgreSQL""" - loader = PostgreSQLLoader(postgresql_test_config) - - # Test connection - loader.connect() - assert loader._is_connected == True - assert loader.pool is not None - - # Test disconnection - loader.disconnect() - assert loader._is_connected == False - assert loader.pool is None - - def test_context_manager(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): - """Test context manager functionality""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - assert loader._is_connected == True - - result = loader.load_table(small_test_data, test_table_name) - assert result.success == True - - # Should be disconnected after context - assert loader._is_connected == False - - def test_basic_table_operations(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): - """Test basic table creation and data loading""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Test initial table creation - result = loader.load_table(small_test_data, test_table_name, create_table=True) - - assert result.success == True - assert result.rows_loaded == 5 - assert result.loader_type == 'postgresql' - assert result.table_name == test_table_name - assert 'columns' in result.metadata - assert result.metadata['columns'] == 7 - - def test_append_mode(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): - """Test append mode functionality""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Initial load - result = loader.load_table(small_test_data, test_table_name, mode=LoadMode.APPEND) - assert result.success == True - assert result.rows_loaded == 5 - - # Append additional data - result = loader.load_table(small_test_data, test_table_name, 
mode=LoadMode.APPEND) - assert result.success == True - assert result.rows_loaded == 5 - - # Verify total rows - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = cur.fetchone()[0] - assert count == 10 # 5 + 5 - finally: - loader.pool.putconn(conn) - - def test_overwrite_mode(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): - """Test overwrite mode functionality""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Initial load - result = loader.load_table(small_test_data, test_table_name, mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 5 - - # Overwrite with different data - new_data = small_test_data.slice(0, 3) # First 3 rows - result = loader.load_table(new_data, test_table_name, mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 3 - - # Verify only new data remains - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = cur.fetchone()[0] - assert count == 3 - finally: - loader.pool.putconn(conn) - - def test_batch_loading(self, postgresql_test_config, medium_test_table, test_table_name, cleanup_tables): - """Test batch loading functionality""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Test loading individual batches - batches = medium_test_table.to_batches(max_chunksize=250) - - for i, batch in enumerate(batches): - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - result = loader.load_batch(batch, test_table_name, mode=mode) - - assert result.success == True - assert result.rows_loaded == batch.num_rows - assert result.metadata['batch_size'] == batch.num_rows - - # Verify all data was loaded - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = cur.fetchone()[0] - assert count == 10000 - finally: - loader.pool.putconn(conn) - - def test_data_types(self, postgresql_test_config, postgresql_type_test_data, test_table_name, cleanup_tables): - """Test various data types are handled correctly""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - result = loader.load_table(postgresql_type_test_data, test_table_name) - assert result.success == True - assert result.rows_loaded == 1000 - - # Verify data integrity - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - # Check various data types - cur.execute(f'SELECT id, text_field, float_field, bool_field FROM {test_table_name} WHERE id = 10') - row = cur.fetchone() - assert row[0] == 10 - assert row[1] in ['text_10', '"text_10"'] # Handle potential CSV quoting - assert abs(row[2] - 12.3) < 0.01 # 10 * 1.23 = 12.3 - assert row[3] == True - finally: - loader.pool.putconn(conn) - - def test_null_value_handling(self, postgresql_test_config, null_test_data, test_table_name, cleanup_tables): - """Test comprehensive null value handling across all data types""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - result = loader.load_table(null_test_data, test_table_name) - assert result.success == True - assert result.rows_loaded == 10 - - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - # Check text field 
nulls (rows 3, 6, 9 have index 2, 5, 8) - cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE text_field IS NULL') - text_nulls = cur.fetchone()[0] - assert text_nulls == 3 - - # Check int field nulls (rows 2, 5, 8 have index 1, 4, 7) - cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE int_field IS NULL') - int_nulls = cur.fetchone()[0] - assert int_nulls == 3 - - # Check float field nulls (rows 3, 6, 9 have index 2, 5, 8) - cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE float_field IS NULL') - float_nulls = cur.fetchone()[0] - assert float_nulls == 3 - - # Check bool field nulls (rows 3, 6, 9 have index 2, 5, 8) - cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE bool_field IS NULL') - bool_nulls = cur.fetchone()[0] - assert bool_nulls == 3 - - # Check timestamp field nulls - # (rows where i % 3 == 0, which are ids 3, 6, 9, plus id 1 due to zero indexing) - cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE timestamp_field IS NULL') - timestamp_nulls = cur.fetchone()[0] - assert timestamp_nulls == 4 - - # Check json field nulls (rows where i % 4 == 0, which are ids 4, 8 due to zero indexing pattern) - cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE json_field IS NULL') - json_nulls = cur.fetchone()[0] - assert json_nulls == 3 - - # Verify non-null values are intact - cur.execute(f'SELECT text_field FROM {test_table_name} WHERE id = 1') - text_val = cur.fetchone()[0] - assert text_val in ['a', '"a"'] # Handle potential CSV quoting - - cur.execute(f'SELECT int_field FROM {test_table_name} WHERE id = 1') - int_val = cur.fetchone()[0] - assert int_val == 1 - - cur.execute(f'SELECT float_field FROM {test_table_name} WHERE id = 1') - float_val = cur.fetchone()[0] - assert abs(float_val - 1.1) < 0.01 - - cur.execute(f'SELECT bool_field FROM {test_table_name} WHERE id = 1') - bool_val = cur.fetchone()[0] - assert bool_val == True - finally: - loader.pool.putconn(conn) - - def test_binary_data_handling(self, postgresql_test_config, test_table_name, cleanup_tables): - """Test binary data handling with INSERT fallback""" - cleanup_tables.append(test_table_name) - - # Create data with binary columns - data = {'id': [1, 2, 3], 'binary_data': [b'hello', b'world', b'test'], 'text_data': ['a', 'b', 'c']} - table = pa.Table.from_pydict(data) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - result = loader.load_table(table, test_table_name) - assert result.success == True - assert result.rows_loaded == 3 - - # Verify binary data was stored correctly - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT id, binary_data FROM {test_table_name} ORDER BY id') - rows = cur.fetchall() - assert rows[0][1].tobytes() == b'hello' - assert rows[1][1].tobytes() == b'world' - assert rows[2][1].tobytes() == b'test' - finally: - loader.pool.putconn(conn) - - def test_schema_retrieval(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): - """Test schema retrieval functionality""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Create table - result = loader.load_table(small_test_data, test_table_name) - assert result.success == True - - # Get schema - schema = loader.get_table_schema(test_table_name) - assert schema is not None - - # Filter out metadata columns added by PostgreSQL loader - non_meta_fields = [ - field for field in schema if not (field.name.startswith('_meta_') or field.name.startswith('_amp_')) - ] 
- - assert len(non_meta_fields) == len(small_test_data.schema) - - # Verify column names match (excluding metadata columns) - original_names = set(small_test_data.schema.names) - retrieved_names = set(field.name for field in non_meta_fields) - assert original_names == retrieved_names - - def test_error_handling(self, postgresql_test_config, small_test_data): - """Test error handling scenarios""" - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Test loading to non-existent table without create_table - result = loader.load_table(small_test_data, 'non_existent_table', create_table=False) - - assert result.success == False - assert result.error is not None - assert result.rows_loaded == 0 - assert 'does not exist' in result.error - - def test_connection_pooling(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): - """Test connection pooling behavior""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Perform multiple operations to test pool reuse - for i in range(5): - subset = small_test_data.slice(i, 1) - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - - result = loader.load_table(subset, test_table_name, mode=mode) - assert result.success == True - - # Verify pool is managing connections properly - # Note: _used is a dict in ThreadedConnectionPool, not an int - assert len(loader.pool._used) <= loader.pool.maxconn - - def test_performance_metrics(self, postgresql_test_config, medium_test_table, test_table_name, cleanup_tables): - """Test performance metrics in results""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - start_time = time.time() - result = loader.load_table(medium_test_table, test_table_name) - end_time = time.time() - - assert result.success == True - assert result.duration > 0 - assert result.duration <= (end_time - start_time) - assert result.rows_loaded == 10000 - - # Check metadata contains performance info - assert 'table_size_bytes' in result.metadata - assert result.metadata['table_size_bytes'] > 0 - - -@pytest.mark.integration -@pytest.mark.postgresql -@pytest.mark.slow -class TestPostgreSQLLoaderPerformance: - """Performance tests for PostgreSQL loader""" - - def test_large_data_loading(self, postgresql_test_config, test_table_name, cleanup_tables): - """Test loading large datasets""" - cleanup_tables.append(test_table_name) - - # Create large dataset - large_data = { - 'id': list(range(50000)), - 'value': [i * 0.123 for i in range(50000)], - 'category': [f'category_{i % 100}' for i in range(50000)], - 'description': [f'This is a longer text description for row {i}' for i in range(50000)], - 'created_at': [datetime.now() for _ in range(50000)], - } - large_table = pa.Table.from_pydict(large_data) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - result = loader.load_table(large_table, test_table_name) - - assert result.success == True - assert result.rows_loaded == 50000 - assert result.duration < 60 # Should complete within 60 seconds - - # Verify data integrity - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = cur.fetchone()[0] - assert count == 50000 - finally: - loader.pool.putconn(conn) - - -@pytest.mark.integration -@pytest.mark.postgresql -class TestPostgreSQLLoaderStreaming: - """Integration tests for PostgreSQL loader streaming functionality""" - - def 
test_streaming_metadata_columns(self, postgresql_test_config, test_table_name, cleanup_tables): - """Test that streaming data creates tables with metadata columns""" - cleanup_tables.append(test_table_name) - - # Import streaming types - from src.amp.streaming.types import BlockRange - - # Create test data with metadata - data = { - 'block_number': [100, 101, 102], - 'transaction_hash': ['0xabc', '0xdef', '0x123'], - 'value': [1.0, 2.0, 3.0], - } - batch = pa.RecordBatch.from_pydict(data) - - # Create metadata with block ranges - block_ranges = [BlockRange(network='ethereum', start=100, end=102)] - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Add metadata columns (simulating what load_stream_continuous does) - batch_with_metadata = loader._add_metadata_columns(batch, block_ranges) - - # Load the batch - result = loader.load_batch(batch_with_metadata, test_table_name, create_table=True) - assert result.success == True - assert result.rows_loaded == 3 - - # Verify metadata columns were created in the table - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - # Check table schema includes metadata columns - cur.execute( - """ - SELECT column_name, data_type - FROM information_schema.columns - WHERE table_name = %s - ORDER BY ordinal_position - """, - (test_table_name,), - ) - - columns = cur.fetchall() - column_names = [col[0] for col in columns] - - # Should have original columns plus metadata columns - assert '_amp_batch_id' in column_names - - # Verify metadata column types - column_types = {col[0]: col[1] for col in columns} - assert ( - 'text' in column_types['_amp_batch_id'].lower() - or 'varchar' in column_types['_amp_batch_id'].lower() - ) - - # Verify data was stored correctly - cur.execute(f'SELECT "_amp_batch_id" FROM {test_table_name} LIMIT 1') - meta_row = cur.fetchone() - - # _amp_batch_id contains a compact 16-char hex string (or multiple separated by |) - batch_id_str = meta_row[0] - assert batch_id_str is not None - assert isinstance(batch_id_str, str) - assert len(batch_id_str) >= 16 # At least one 16-char batch ID - - finally: - loader.pool.putconn(conn) - - def test_handle_reorg_deletion(self, postgresql_test_config, test_table_name, cleanup_tables): - """Test that _handle_reorg correctly deletes invalidated ranges""" - cleanup_tables.append(test_table_name) - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Create streaming batches with metadata - batch1 = pa.RecordBatch.from_pydict( - { - 'tx_hash': ['0x100', '0x101', '0x102'], - 'block_num': [100, 101, 102], - 'value': [10.0, 11.0, 12.0], - } - ) - batch2 = pa.RecordBatch.from_pydict( - {'tx_hash': ['0x200', '0x201'], 'block_num': [103, 104], 'value': [12.0, 33.0]} - ) - batch3 = pa.RecordBatch.from_pydict( - {'tx_hash': ['0x300', '0x301'], 'block_num': [105, 106], 'value': [7.0, 9.0]} - ) - batch4 = pa.RecordBatch.from_pydict( - {'tx_hash': ['0x400', '0x401'], 'block_num': [107, 108], 'value': [6.0, 73.0]} - ) - - # Create table from first batch schema - loader._create_table_from_schema(batch1.schema, test_table_name) - - # Create response batches with hashes - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=102, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - 
metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=103, end=104, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=105, end=106, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response4 = ResponseBatch.data_batch( - data=batch4, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=107, end=108, hash='0xddd')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3, response4] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - assert len(results) == 4 - assert all(r.success for r in results) - - # Verify initial data count - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - initial_count = cur.fetchone()[0] - assert initial_count == 9 # 3 + 2 + 2 + 2 - - # Test reorg deletion - invalidate blocks 104-108 on ethereum - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=104, end=108)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Should delete batch2, batch3 and batch4 leaving only the 3 rows from batch1 - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - after_reorg_count = cur.fetchone()[0] - assert after_reorg_count == 3 - - finally: - loader.pool.putconn(conn) - - def test_reorg_with_overlapping_ranges(self, postgresql_test_config, test_table_name, cleanup_tables): - """Test reorg deletion with overlapping block ranges""" - cleanup_tables.append(test_table_name) - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Load data with overlapping ranges that should be invalidated - batch = pa.RecordBatch.from_pydict( - {'tx_hash': ['0x150', '0x175', '0x250'], 'block_num': [150, 175, 250], 'value': [15.0, 17.5, 25.0]} - ) - - # Create table from batch schema - loader._create_table_from_schema(batch.schema, test_table_name) - - response = ResponseBatch.data_batch( - data=batch, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=175, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - results = list(loader.load_stream_continuous(iter([response]), test_table_name)) - assert len(results) == 1 - assert results[0].success - - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - # Verify initial data - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - assert cur.fetchone()[0] == 3 - - # Test partial overlap invalidation (160-180) - # This should invalidate our range [150,175] because they overlap - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=160, end=180)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # All data should be deleted due to overlap - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - assert cur.fetchone()[0] == 0 - 
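The partial-overlap case above (invalidating blocks 160-180 removes data tagged with the range 150-175) follows from ordinary closed-interval intersection applied per network. A small sketch of that check, using a local stand-in tuple rather than the project's BlockRange type so the example stays self-contained:

from typing import NamedTuple

class Range(NamedTuple):
    # Local stand-in for illustration only; not the project's BlockRange class.
    network: str
    start: int
    end: int

def overlaps(stored: Range, invalidated: Range) -> bool:
    """Closed intervals on the same network overlap unless one ends before the other starts."""
    return (
        stored.network == invalidated.network
        and stored.start <= invalidated.end
        and invalidated.start <= stored.end
    )

assert overlaps(Range('ethereum', 150, 175), Range('ethereum', 160, 180))      # partial overlap -> invalidated
assert not overlaps(Range('ethereum', 150, 175), Range('ethereum', 176, 200))  # disjoint -> kept
assert not overlaps(Range('ethereum', 150, 175), Range('polygon', 160, 180))   # other network -> kept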
- finally: - loader.pool.putconn(conn) - - def test_reorg_preserves_different_networks(self, postgresql_test_config, test_table_name, cleanup_tables): - """Test that reorg only affects specified network""" - cleanup_tables.append(test_table_name) - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Load data from multiple networks with same block ranges - batch_eth = pa.RecordBatch.from_pydict( - {'tx_hash': ['0x100_eth'], 'network_id': ['ethereum'], 'block_num': [100], 'value': [10.0]} - ) - batch_poly = pa.RecordBatch.from_pydict( - {'tx_hash': ['0x100_poly'], 'network_id': ['polygon'], 'block_num': [100], 'value': [10.0]} - ) - - # Create table from batch schema - loader._create_table_from_schema(batch_eth.schema, test_table_name) - - response_eth = ResponseBatch.data_batch( - data=batch_eth, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=100, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response_poly = ResponseBatch.data_batch( - data=batch_poly, - metadata=BatchMetadata( - ranges=[BlockRange(network='polygon', start=100, end=100, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load both batches via streaming API - stream = [response_eth, response_poly] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - assert len(results) == 2 - assert all(r.success for r in results) - - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - # Verify both networks' data exists - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - assert cur.fetchone()[0] == 2 - - # Invalidate only ethereum network - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=100, end=100)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Should only delete ethereum data, polygon should remain - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - assert cur.fetchone()[0] == 1 - - finally: - loader.pool.putconn(conn) - - def test_microbatch_deduplication(self, postgresql_test_config, test_table_name, cleanup_tables): - """ - Test that multiple RecordBatches within the same microbatch are all loaded, - and deduplication only happens at microbatch boundaries when ranges_complete=True. - - This test verifies the fix for the critical bug where we were marking batches - as processed after every RecordBatch instead of waiting for ranges_complete=True. 
- """ - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - cleanup_tables.append(test_table_name) - - # Enable state management to test deduplication - config_with_state = { - **postgresql_test_config, - 'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True}, - } - loader = PostgreSQLLoader(config_with_state) - - with loader: - # Create table first from the schema - batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) - loader._create_table_from_schema(batch1_data.schema, test_table_name) - - # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange - # This happens when the server sends large microbatches in smaller chunks - - # First RecordBatch of the microbatch (ranges_complete=False) - response1 = ResponseBatch.data_batch( - data=batch1_data, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], - ranges_complete=False, # Not the last batch in this microbatch - ), - ) - - # Second RecordBatch of the microbatch (ranges_complete=False) - batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) - response2 = ResponseBatch.data_batch( - data=batch2_data, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], # Same BlockRange! - ranges_complete=False, # Still not the last batch - ), - ) - - # Third RecordBatch of the microbatch (ranges_complete=True) - batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]}) - response3 = ResponseBatch.data_batch( - data=batch3_data, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], # Same BlockRange! - ranges_complete=True, # Last batch in this microbatch - safe to mark as processed - ), - ) - - # Process the microbatch stream - stream = [response1, response2, response3] - results = list( - loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection') - ) - - # CRITICAL: All 3 RecordBatches should be loaded successfully - # Before the fix, only the first batch would load (the other 2 would be skipped as "duplicates") - assert len(results) == 3, 'All RecordBatches within microbatch should be processed' - assert all(r.success for r in results), 'All batches should succeed' - assert results[0].rows_loaded == 2, 'First batch should load 2 rows' - assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)' - assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)' - - # Verify total rows in table (all batches loaded) - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - total_count = cur.fetchone()[0] - assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table' - - # Verify the actual IDs are present - cur.execute(f'SELECT id FROM {test_table_name} ORDER BY id') - all_ids = [row[0] for row in cur.fetchall()] - assert all_ids == [1, 2, 3, 4, 5, 6], 'All rows from all RecordBatches should be present' - - finally: - loader.pool.putconn(conn) - - # Now test that re-sending the complete microbatch is properly deduplicated - # This time, the first batch has ranges_complete=True (entire microbatch in one RecordBatch) - duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]}) - duplicate_response = ResponseBatch.data_batch( - data=duplicate_batch, - metadata=BatchMetadata( - 
ranges=[ - BlockRange(network='ethereum', start=100, end=110, hash='0xabc123') - ], # Same range as before! - ranges_complete=True, # Complete microbatch - ), - ) - - # Process duplicate microbatch - duplicate_results = list( - loader.load_stream_continuous( - iter([duplicate_response]), test_table_name, connection_name='test_connection' - ) - ) - - # The duplicate microbatch should be skipped (already processed) - assert len(duplicate_results) == 1 - assert duplicate_results[0].success is True - assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped' - assert duplicate_results[0].metadata.get('operation') == 'skip_duplicate', 'Should be marked as duplicate' - - # Verify row count unchanged (duplicate was skipped) - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - final_count = cur.fetchone()[0] - assert final_count == 6, 'Row count should not increase after duplicate microbatch' - - finally: - loader.pool.putconn(conn) diff --git a/tests/integration/test_redis_loader.py b/tests/integration/test_redis_loader.py deleted file mode 100644 index ce23da7..0000000 --- a/tests/integration/test_redis_loader.py +++ /dev/null @@ -1,981 +0,0 @@ -# tests/integration/test_redis_loader.py -""" -Integration tests for Redis loader implementation. -These tests require a running Redis instance. -""" - -import json -import time -from datetime import datetime, timedelta - -import pyarrow as pa -import pytest - -try: - from src.amp.loaders.base import LoadMode - from src.amp.loaders.implementations.redis_loader import RedisLoader -except ImportError: - pytest.skip('amp modules not available', allow_module_level=True) - - -@pytest.fixture -def small_test_data(): - """Create small test data for quick tests""" - data = { - 'id': [1, 2, 3, 4, 5], - 'name': ['Alice', 'Bob', 'Charlie', 'David', 'Eve'], - 'score': [100, 200, 150, 300, 250], - 'active': [True, False, True, False, True], - 'created_at': [datetime.now() - timedelta(days=i) for i in range(5)], - } - return pa.Table.from_pydict(data) - - -@pytest.fixture -def comprehensive_test_data(): - """Create comprehensive test data with various data types""" - base_date = datetime(2024, 1, 1) - - data = { - 'id': list(range(1000)), - 'user_id': [f'user_{i % 100}' for i in range(1000)], - 'score': [i * 10 for i in range(1000)], - 'text_field': [f'text_{i}' for i in range(1000)], - 'float_field': [i * 0.123 for i in range(1000)], - 'bool_field': [i % 2 == 0 for i in range(1000)], - 'timestamp': [(base_date + timedelta(days=i // 10, hours=i % 24)).timestamp() for i in range(1000)], - 'binary_field': [f'binary_{i}'.encode() for i in range(1000)], - 'json_data': [json.dumps({'index': i, 'value': f'val_{i}'}) for i in range(1000)], - 'nullable_field': [i if i % 10 != 0 else None for i in range(1000)], - } - - return pa.Table.from_pydict(data) - - -@pytest.fixture -def cleanup_redis(redis_test_config): - """Cleanup Redis data after tests""" - keys_to_clean = [] - patterns_to_clean = [] - - yield (keys_to_clean, patterns_to_clean) - - # Cleanup - try: - import redis - - r = redis.Redis( - host=redis_test_config['host'], - port=redis_test_config['port'], - db=redis_test_config['db'], - password=redis_test_config['password'], - ) - - # Delete specific keys - for key in keys_to_clean: - r.delete(key) - - # Delete keys matching patterns - for pattern in patterns_to_clean: - for key in r.scan_iter(match=pattern, count=1000): - r.delete(key) - - r.close() - except Exception: 
- pass - - -@pytest.mark.integration -@pytest.mark.redis -class TestRedisLoaderIntegration: - """Integration tests for Redis loader""" - - def test_loader_connection(self, redis_test_config): - """Test basic connection to Redis""" - loader = RedisLoader(redis_test_config) - - # Test connection - loader.connect() - assert loader._is_connected == True - assert loader.redis_client is not None - assert loader.connection_pool is not None - - # Test health check - health = loader.health_check() - assert health['healthy'] == True - assert 'redis_version' in health - - # Test disconnection - loader.disconnect() - assert loader._is_connected == False - assert loader.redis_client is None - - def test_context_manager(self, redis_test_config, small_test_data, cleanup_redis): - """Test context manager functionality""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_context:*') - - loader = RedisLoader({**redis_test_config, 'data_structure': 'hash'}) - - with loader: - assert loader._is_connected == True - - result = loader.load_table(small_test_data, 'test_context') - assert result.success == True - - # Should be disconnected after context - assert loader._is_connected == False - - def test_hash_storage(self, redis_test_config, small_test_data, cleanup_redis): - """Test hash data structure storage""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_hash:*') - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_hash:{id}'} - loader = RedisLoader(config) - - with loader: - # Test loading - result = loader.load_table(small_test_data, 'test_hash') - - assert result.success == True - assert result.rows_loaded == 5 - assert result.metadata['data_structure'] == 'hash' - - # Verify data was stored - for i in range(5): - key = f'test_hash:{i + 1}' - assert loader.redis_client.exists(key) - - # Check specific fields - name = loader.redis_client.hget(key, 'name') - assert name.decode() == ['Alice', 'Bob', 'Charlie', 'David', 'Eve'][i] - - score = loader.redis_client.hget(key, 'score') - assert int(score.decode()) == [100, 200, 150, 300, 250][i] - - def test_string_storage(self, redis_test_config, small_test_data, cleanup_redis): - """Test string (JSON) data structure storage""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_string:*') - - config = {**redis_test_config, 'data_structure': 'string', 'key_pattern': 'test_string:{id}'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_string') - - assert result.success == True - assert result.rows_loaded == 5 - - # Verify data was stored as JSON - for i in range(5): - key = f'test_string:{i + 1}' - assert loader.redis_client.exists(key) - - # Parse JSON data - json_data = json.loads(loader.redis_client.get(key)) - assert json_data['name'] == ['Alice', 'Bob', 'Charlie', 'David', 'Eve'][i] - assert json_data['score'] == [100, 200, 150, 300, 250][i] - - def test_stream_storage(self, redis_test_config, small_test_data, cleanup_redis): - """Test stream data structure storage""" - keys_to_clean, patterns_to_clean = cleanup_redis - keys_to_clean.append('test_stream:stream') - - config = {**redis_test_config, 'data_structure': 'stream'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_stream') - - assert result.success == True - assert result.rows_loaded == 5 - - # Verify stream was created - stream_key = 'test_stream:stream' - assert 
loader.redis_client.exists(stream_key) - - # Check stream length - info = loader.redis_client.xinfo_stream(stream_key) - assert info['length'] == 5 - - def test_set_storage(self, redis_test_config, small_test_data, cleanup_redis): - """Test set data structure storage""" - keys_to_clean, patterns_to_clean = cleanup_redis - keys_to_clean.append('test_set:set') - - config = {**redis_test_config, 'data_structure': 'set', 'unique_field': 'name'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_set') - - assert result.success == True - assert result.rows_loaded == 5 - - # Verify set was created - set_key = 'test_set:set' - assert loader.redis_client.exists(set_key) - assert loader.redis_client.scard(set_key) == 5 - - # Check members - members = loader.redis_client.smembers(set_key) - names = {m.decode() for m in members} - assert names == {'Alice', 'Bob', 'Charlie', 'David', 'Eve'} - - def test_sorted_set_storage(self, redis_test_config, small_test_data, cleanup_redis): - """Test sorted set data structure storage""" - keys_to_clean, patterns_to_clean = cleanup_redis - keys_to_clean.append('test_zset:zset') - - config = {**redis_test_config, 'data_structure': 'sorted_set', 'score_field': 'score'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_zset') - - assert result.success == True - assert result.rows_loaded == 5 - - # Verify sorted set was created - zset_key = 'test_zset:zset' - assert loader.redis_client.exists(zset_key) - assert loader.redis_client.zcard(zset_key) == 5 - - # Check score ordering - members_with_scores = loader.redis_client.zrange(zset_key, 0, -1, withscores=True) - scores = [score for _, score in members_with_scores] - assert scores == [100.0, 150.0, 200.0, 250.0, 300.0] # Should be sorted - - def test_list_storage(self, redis_test_config, small_test_data, cleanup_redis): - """Test list data structure storage""" - keys_to_clean, patterns_to_clean = cleanup_redis - keys_to_clean.append('test_list:list') - - config = {**redis_test_config, 'data_structure': 'list'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_list') - - assert result.success == True - assert result.rows_loaded == 5 - - # Verify list was created - list_key = 'test_list:list' - assert loader.redis_client.exists(list_key) - assert loader.redis_client.llen(list_key) == 5 - - def test_append_mode(self, redis_test_config, small_test_data, cleanup_redis): - """Test append mode functionality""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_append:*') - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_append:{id}'} - loader = RedisLoader(config) - - with loader: - # Initial load - result = loader.load_table(small_test_data, 'test_append', mode=LoadMode.APPEND) - assert result.success == True - assert result.rows_loaded == 5 - - # Append more data (with different IDs) - # Convert to pydict, modify, and create new table - new_data_dict = small_test_data.to_pydict() - new_data_dict['id'] = [6, 7, 8, 9, 10] - new_table = pa.Table.from_pydict(new_data_dict) - - result = loader.load_table(new_table, 'test_append', mode=LoadMode.APPEND) - assert result.success == True - assert result.rows_loaded == 5 - - # Verify all keys exist - for i in range(1, 11): - key = f'test_append:{i}' - assert loader.redis_client.exists(key) - - def test_overwrite_mode(self, redis_test_config, small_test_data, cleanup_redis): 
- """Test overwrite mode functionality""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_overwrite:*') - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_overwrite:{id}'} - loader = RedisLoader(config) - - with loader: - # Initial load - result = loader.load_table(small_test_data, 'test_overwrite', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 5 - - # Overwrite with less data - new_data = small_test_data.slice(0, 3) - result = loader.load_table(new_data, 'test_overwrite', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 3 - - # Verify old keys were deleted - assert not loader.redis_client.exists('test_overwrite:4') - assert not loader.redis_client.exists('test_overwrite:5') - - def test_batch_loading(self, redis_test_config, comprehensive_test_data, cleanup_redis): - """Test batch loading functionality""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_batch:*') - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_batch:{id}', 'batch_size': 250} - loader = RedisLoader(config) - - with loader: - # Test loading individual batches - batches = comprehensive_test_data.to_batches(max_chunksize=250) - total_rows = 0 - - for i, batch in enumerate(batches): - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - result = loader.load_batch(batch, 'test_batch', mode=mode) - - assert result.success == True - assert result.rows_loaded == batch.num_rows - total_rows += batch.num_rows - - assert total_rows == 1000 - - def test_ttl_functionality(self, redis_test_config, small_test_data, cleanup_redis): - """Test TTL (time-to-live) functionality""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_ttl:*') - - config = { - **redis_test_config, - 'data_structure': 'hash', - 'key_pattern': 'test_ttl:{id}', - 'ttl': 2, # 2 seconds TTL - } - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_ttl') - assert result.success == True - - # Check keys exist - key = 'test_ttl:1' - assert loader.redis_client.exists(key) - - # Check TTL is set - ttl = loader.redis_client.ttl(key) - assert ttl > 0 and ttl <= 2 - - # Wait for expiration - time.sleep(3) - assert not loader.redis_client.exists(key) - - def test_null_value_handling(self, redis_test_config, null_test_data, cleanup_redis): - """Test comprehensive null value handling across all data types""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_nulls:*') - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_nulls:{id}'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(null_test_data, 'test_nulls') - assert result.success == True - assert result.rows_loaded == 10 - - # Verify null values were handled correctly (skipped in hash fields) - for i in range(1, 11): - key = f'test_nulls:{i}' - assert loader.redis_client.exists(key) - - # Check that null fields are not stored (Redis hash skips null values) - fields = loader.redis_client.hkeys(key) - field_names = {f.decode() if isinstance(f, bytes) else f for f in fields} - - # Based on our null test data pattern, verify expected fields - # text_field pattern: ['a', 'b', None, 'd', 'e', None, 'g', 'h', None, 'j'] - # So ids 3, 6, 9 should have None (indices 2, 5, 8) - if i in [3, 6, 9]: # text_field is None for these IDs - assert 
'text_field' not in field_names - else: - assert 'text_field' in field_names, f'Expected text_field for id={i}, but got fields: {field_names}' - - # int_field pattern: [1, None, 3, 4, None, 6, 7, None, 9, 10] - # So ids 2, 5, 8 should have None (indices 1, 4, 7) - if i in [2, 5, 8]: # int_field is None for these IDs - assert 'int_field' not in field_names - else: - assert 'int_field' in field_names - - # float_field pattern: [1.1, 2.2, None, 4.4, 5.5, None, 7.7, 8.8, None, 10.0] - # So ids 3, 6, 9 should have None (indices 2, 5, 8) - if i in [3, 6, 9]: # float_field is None for these IDs - assert 'float_field' not in field_names - else: - assert 'float_field' in field_names - - # Verify non-null values are intact - if 'text_field' in field_names: - text_val = loader.redis_client.hget(key, 'text_field') - expected_chars = ['a', 'b', None, 'd', 'e', None, 'g', 'h', None, 'j'] - expected_char = expected_chars[i - 1] # Convert id to index - assert text_val.decode() == expected_char - - if 'int_field' in field_names: - int_val = loader.redis_client.hget(key, 'int_field') - expected_ints = [1, None, 3, 4, None, 6, 7, None, 9, 10] - expected_int = expected_ints[i - 1] # Convert id to index - assert int(int_val.decode()) == expected_int - - def test_null_value_handling_string_structure(self, redis_test_config, null_test_data, cleanup_redis): - """Test null value handling with string (JSON) data structure""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_json_nulls:*') - - config = {**redis_test_config, 'data_structure': 'string', 'key_pattern': 'test_json_nulls:{id}'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(null_test_data, 'test_json_nulls') - assert result.success == True - assert result.rows_loaded == 10 - - for i in range(1, 11): - key = f'test_json_nulls:{i}' - assert loader.redis_client.exists(key) - - json_str = loader.redis_client.get(key) - json_data = json.loads(json_str) - - if i in [3, 6, 9]: - assert 'text_field' not in json_data - else: - expected_chars = ['a', 'b', None, 'd', 'e', None, 'g', 'h', None, 'j'] - expected_char = expected_chars[i - 1] - assert json_data['text_field'] == expected_char - - if i in [2, 5, 8]: - assert 'int_field' not in json_data - else: - expected_ints = [1, None, 3, 4, None, 6, 7, None, 9, 10] - expected_int = expected_ints[i - 1] - assert json_data['int_field'] == expected_int - - def test_binary_data_handling(self, redis_test_config, cleanup_redis): - """Test binary data handling""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_binary:*') - - # Create data with binary columns - data = {'id': [1, 2, 3], 'binary_data': [b'hello', b'world', b'\x00\x01\x02\x03'], 'text_data': ['a', 'b', 'c']} - table = pa.Table.from_pydict(data) - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_binary:{id}'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(table, 'test_binary') - assert result.success == True - assert result.rows_loaded == 3 - - # Verify binary data was stored correctly - assert loader.redis_client.hget('test_binary:1', 'binary_data') == b'hello' - assert loader.redis_client.hget('test_binary:2', 'binary_data') == b'world' - assert loader.redis_client.hget('test_binary:3', 'binary_data') == b'\x00\x01\x02\x03' - - def test_comprehensive_stats(self, redis_test_config, small_test_data, cleanup_redis): - """Test comprehensive statistics functionality""" - keys_to_clean, patterns_to_clean = 
cleanup_redis - patterns_to_clean.append('test_stats:*') - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_stats:{id}'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_stats') - assert result.success == True - - # Get stats - stats = loader.get_comprehensive_stats('test_stats') - assert 'table_name' in stats - assert stats['data_structure'] == 'hash' - assert stats['key_count'] == 5 - assert 'estimated_memory_bytes' in stats - assert 'estimated_memory_mb' in stats - - def test_error_handling(self, redis_test_config, small_test_data): - """Test error handling scenarios""" - # Test with invalid configuration - invalid_config = {**redis_test_config, 'host': 'invalid-host-that-does-not-exist', 'socket_connect_timeout': 1} - loader = RedisLoader(invalid_config) - - import redis - - with pytest.raises(redis.exceptions.ConnectionError): - loader.connect() - - def test_key_pattern_generation(self, redis_test_config, cleanup_redis): - """Test various key pattern generations""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('complex:*') - - # Create data with multiple fields for complex key - data = {'user_id': ['u1', 'u2', 'u3'], 'session_id': ['s1', 's2', 's3'], 'timestamp': [100, 200, 300]} - table = pa.Table.from_pydict(data) - - config = { - **redis_test_config, - 'data_structure': 'hash', - 'key_pattern': 'complex:{user_id}:{session_id}:{timestamp}', - } - loader = RedisLoader(config) - - with loader: - result = loader.load_table(table, 'complex') - assert result.success == True - - # Verify complex keys were created - assert loader.redis_client.exists('complex:u1:s1:100') - assert loader.redis_client.exists('complex:u2:s2:200') - assert loader.redis_client.exists('complex:u3:s3:300') - - def test_performance_metrics(self, redis_test_config, comprehensive_test_data, cleanup_redis): - """Test performance metrics in results""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_perf:*') - - config = { - **redis_test_config, - 'data_structure': 'hash', - 'key_pattern': 'test_perf:{id}', - 'batch_size': 100, - 'pipeline_size': 500, - } - loader = RedisLoader(config) - - with loader: - start_time = time.time() - result = loader.load_table(comprehensive_test_data, 'test_perf') - end_time = time.time() - - assert result.success == True - assert result.duration > 0 - assert result.duration <= (end_time - start_time) - - # Check performance info - ops_per_second is now a direct attribute - assert hasattr(result, 'ops_per_second') - assert result.ops_per_second > 0 - assert 'batches_processed' in result.metadata - assert 'avg_batch_size' in result.metadata - - -@pytest.mark.integration -@pytest.mark.redis -@pytest.mark.slow -class TestRedisLoaderPerformance: - """Performance tests for Redis loader""" - - def test_large_data_loading(self, redis_test_config, cleanup_redis): - """Test loading large datasets""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_large:*') - - # Create large dataset - large_data = { - 'id': list(range(10000)), - 'value': [i * 0.123 for i in range(10000)], - 'category': [f'category_{i % 100}' for i in range(10000)], - 'description': [f'This is a longer text description for row {i}' for i in range(10000)], - 'active': [i % 2 == 0 for i in range(10000)], - } - large_table = pa.Table.from_pydict(large_data) - - config = { - **redis_test_config, - 'data_structure': 'hash', - 'key_pattern': 
'test_large:{id}', - 'batch_size': 1000, - 'pipeline_size': 1000, - } - loader = RedisLoader(config) - - with loader: - result = loader.load_table(large_table, 'test_large') - - assert result.success == True - assert result.rows_loaded == 10000 - assert result.duration < 30 # Should complete within 30 seconds - - # Verify performance metrics - assert result.ops_per_second > 100 # Should handle >100 ops/sec - - def test_data_structure_performance_comparison(self, redis_test_config, cleanup_redis): - """Compare performance across different data structures""" - keys_to_clean, patterns_to_clean = cleanup_redis - - # Create test data - data = {'id': list(range(1000)), 'score': list(range(1000)), 'data': [f'value_{i}' for i in range(1000)]} - table = pa.Table.from_pydict(data) - - structures = ['hash', 'string', 'set', 'sorted_set', 'list'] - results = {} - - for structure in structures: - patterns_to_clean.append(f'perf_{structure}:*') - keys_to_clean.append(f'perf_{structure}:{structure}') - - config = { - **redis_test_config, - 'data_structure': structure, - 'key_pattern': f'perf_{structure}:{{id}}', - 'score_field': 'score' if structure == 'sorted_set' else None, - } - loader = RedisLoader(config) - - with loader: - result = loader.load_table(table, f'perf_{structure}') - results[structure] = result.ops_per_second - - # All structures should perform reasonably well - for structure, ops_per_sec in results.items(): - assert ops_per_sec > 50, f'{structure} performance too low: {ops_per_sec} ops/sec' - - -@pytest.mark.integration -@pytest.mark.redis -class TestRedisLoaderStreaming: - """Integration tests for Redis loader streaming functionality""" - - def test_streaming_metadata_columns(self, redis_test_config, cleanup_redis): - """Test that streaming data stores batch ID metadata""" - keys_to_clean, patterns_to_clean = cleanup_redis - table_name = 'streaming_test' - patterns_to_clean.append(f'{table_name}:*') - patterns_to_clean.append(f'block_index:{table_name}:*') - - # Import streaming types - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - # Create test data with metadata - data = { - 'id': [1, 2, 3], # Required for Redis key generation - 'block_number': [100, 101, 102], - 'transaction_hash': ['0xabc', '0xdef', '0x123'], - 'value': [1.0, 2.0, 3.0], - } - batch = pa.RecordBatch.from_pydict(data) - - # Create metadata with block ranges - block_ranges = [BlockRange(network='ethereum', start=100, end=102, hash='0xabc')] - - config = {**redis_test_config, 'data_structure': 'hash'} - loader = RedisLoader(config) - - with loader: - # Load via streaming API - response = ResponseBatch.data_batch(data=batch, metadata=BatchMetadata(ranges=block_ranges)) - results = list(loader.load_stream_continuous(iter([response]), table_name)) - assert len(results) == 1 - assert results[0].success == True - assert results[0].rows_loaded == 3 - - # Verify data was stored - primary_keys = [f'{table_name}:1', f'{table_name}:2', f'{table_name}:3'] - for key in primary_keys: - assert loader.redis_client.exists(key) - # Check that batch_id metadata was stored - batch_id_field = loader.redis_client.hget(key, '_amp_batch_id') - assert batch_id_field is not None - batch_id_str = batch_id_field.decode('utf-8') - assert isinstance(batch_id_str, str) - assert len(batch_id_str) >= 16 # At least one 16-char batch ID - - def test_handle_reorg_deletion(self, redis_test_config, cleanup_redis): - """Test that _handle_reorg correctly deletes invalidated ranges""" - keys_to_clean, patterns_to_clean = 
cleanup_redis - table_name = 'reorg_test' - patterns_to_clean.append(f'{table_name}:*') - patterns_to_clean.append(f'block_index:{table_name}:*') - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**redis_test_config, 'data_structure': 'hash'} - loader = RedisLoader(config) - - with loader: - # Create streaming batches with metadata - batch1 = pa.RecordBatch.from_pydict( - { - 'id': [1, 2, 3], # Required for Redis key generation - 'tx_hash': ['0x100', '0x101', '0x102'], - 'block_num': [100, 101, 102], - 'value': [10.0, 11.0, 12.0], - } - ) - batch2 = pa.RecordBatch.from_pydict( - {'id': [4, 5], 'tx_hash': ['0x200', '0x201'], 'block_num': [103, 104], 'value': [13.0, 14.0]} - ) - batch3 = pa.RecordBatch.from_pydict( - {'id': [6, 7], 'tx_hash': ['0x300', '0x301'], 'block_num': [105, 106], 'value': [15.0, 16.0]} - ) - - # Create response batches with hashes - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=102, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=103, end=104, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=105, end=106, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), table_name)) - assert len(results) == 3 - assert all(r.success for r in results) - - # Verify initial data - initial_keys = [] - pattern = f'{table_name}:*' - for key in loader.redis_client.scan_iter(match=pattern): - if not key.decode('utf-8').startswith('block_index'): - initial_keys.append(key) - assert len(initial_keys) == 7 # 3 + 2 + 2 - - # Test reorg deletion - invalidate blocks 104-108 on ethereum - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=104, end=108)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Should delete batch2 and batch3, leaving only batch1 (3 keys) - remaining_keys = [] - for key in loader.redis_client.scan_iter(match=pattern): - if not key.decode('utf-8').startswith('block_index'): - remaining_keys.append(key) - assert len(remaining_keys) == 3 - - def test_reorg_with_overlapping_ranges(self, redis_test_config, cleanup_redis): - """Test reorg deletion with overlapping block ranges""" - keys_to_clean, patterns_to_clean = cleanup_redis - table_name = 'overlap_test' - patterns_to_clean.append(f'{table_name}:*') - patterns_to_clean.append(f'block_index:{table_name}:*') - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**redis_test_config, 'data_structure': 'hash'} - loader = RedisLoader(config) - - with loader: - # Load data with overlapping ranges that should be invalidated - batch = pa.RecordBatch.from_pydict( - { - 'id': [1, 2, 3], - 'tx_hash': ['0x150', '0x175', '0x250'], - 'block_num': [150, 175, 250], - 'value': [15.0, 17.5, 25.0], - } - ) - - response = ResponseBatch.data_batch( - data=batch, - 
metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=175, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - results = list(loader.load_stream_continuous(iter([response]), table_name)) - assert len(results) == 1 - assert results[0].success - - # Verify initial data - pattern = f'{table_name}:*' - initial_keys = [] - for key in loader.redis_client.scan_iter(match=pattern): - if not key.decode('utf-8').startswith('block_index'): - initial_keys.append(key) - assert len(initial_keys) == 3 - - # Test partial overlap invalidation (160-180) - # This should invalidate our range [150,175] because they overlap - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=160, end=180)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # All data should be deleted due to overlap - remaining_keys = [] - for key in loader.redis_client.scan_iter(match=pattern): - if not key.decode('utf-8').startswith('block_index'): - remaining_keys.append(key) - assert len(remaining_keys) == 0 - - def test_reorg_preserves_different_networks(self, redis_test_config, cleanup_redis): - """Test that reorg only affects specified network""" - keys_to_clean, patterns_to_clean = cleanup_redis - table_name = 'multinetwork_test' - patterns_to_clean.append(f'{table_name}:*') - patterns_to_clean.append(f'block_index:{table_name}:*') - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**redis_test_config, 'data_structure': 'hash'} - loader = RedisLoader(config) - - with loader: - # Load data from multiple networks with same block ranges - batch_eth = pa.RecordBatch.from_pydict( - { - 'id': [1], - 'tx_hash': ['0x100_eth'], - 'network_id': ['ethereum'], - 'block_num': [100], - 'value': [10.0], - } - ) - batch_poly = pa.RecordBatch.from_pydict( - { - 'id': [2], - 'tx_hash': ['0x100_poly'], - 'network_id': ['polygon'], - 'block_num': [100], - 'value': [10.0], - } - ) - - response_eth = ResponseBatch.data_batch( - data=batch_eth, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=100, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response_poly = ResponseBatch.data_batch( - data=batch_poly, - metadata=BatchMetadata( - ranges=[BlockRange(network='polygon', start=100, end=100, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load both batches via streaming API - stream = [response_eth, response_poly] - results = list(loader.load_stream_continuous(iter(stream), table_name)) - assert len(results) == 2 - assert all(r.success for r in results) - - # Verify both networks' data exists - pattern = f'{table_name}:*' - initial_keys = [] - for key in loader.redis_client.scan_iter(match=pattern): - if not key.decode('utf-8').startswith('block_index'): - initial_keys.append(key) - assert len(initial_keys) == 2 - - # Invalidate only ethereum network - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=100, end=100)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Should only delete ethereum data, polygon should remain - 
remaining_keys = [] - for key in loader.redis_client.scan_iter(match=pattern): - if not key.decode('utf-8').startswith('block_index'): - remaining_keys.append(key) - assert len(remaining_keys) == 1 - - # Verify remaining data is from polygon (just check batch_id exists) - remaining_key = remaining_keys[0] - batch_id_field = loader.redis_client.hget(remaining_key, '_amp_batch_id') - assert batch_id_field is not None - # Batch ID is a compact string, not network-specific, so we just verify it exists - - def test_streaming_with_string_data_structure(self, redis_test_config, cleanup_redis): - """Test streaming support with string data structure""" - keys_to_clean, patterns_to_clean = cleanup_redis - table_name = 'string_streaming_test' - patterns_to_clean.append(f'{table_name}:*') - patterns_to_clean.append(f'block_index:{table_name}:*') - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**redis_test_config, 'data_structure': 'string'} - loader = RedisLoader(config) - - with loader: - # Create test data - data = { - 'id': [1, 2, 3], - 'transaction_hash': ['0xaaa', '0xbbb', '0xccc'], - 'value': [100.0, 200.0, 300.0], - } - batch = pa.RecordBatch.from_pydict(data) - block_ranges = [BlockRange(network='polygon', start=200, end=202, hash='0xabc')] - - # Load via streaming API - response = ResponseBatch.data_batch( - data=batch, - metadata=BatchMetadata( - ranges=block_ranges, - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - results = list(loader.load_stream_continuous(iter([response]), table_name)) - assert len(results) == 1 - assert results[0].success == True - assert results[0].rows_loaded == 3 - - # Verify data was stored as JSON strings - for _i, id_val in enumerate([1, 2, 3]): - key = f'{table_name}:{id_val}' - assert loader.redis_client.exists(key) - - # Get and parse JSON data - json_data = loader.redis_client.get(key) - parsed_data = json.loads(json_data.decode('utf-8')) - assert '_amp_batch_id' in parsed_data - batch_id_str = parsed_data['_amp_batch_id'] - assert isinstance(batch_id_str, str) - assert len(batch_id_str) >= 16 # At least one 16-char batch ID - - # Verify reorg handling works with string data structure - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='polygon', start=201, end=205)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # All data should be deleted since ranges overlap - pattern = f'{table_name}:*' - remaining_keys = [] - for key in loader.redis_client.scan_iter(match=pattern): - if not key.decode('utf-8').startswith('block_index'): - remaining_keys.append(key) - assert len(remaining_keys) == 0 diff --git a/tests/integration/test_snowflake_loader.py b/tests/integration/test_snowflake_loader.py deleted file mode 100644 index 50f96c7..0000000 --- a/tests/integration/test_snowflake_loader.py +++ /dev/null @@ -1,1215 +0,0 @@ -# tests/integration/test_snowflake_loader.py -""" -Integration tests for Snowflake loader implementation. -These tests require a running Snowflake instance with proper credentials. - -NOTE: Snowflake integration tests are currently disabled because they require an active snowflake account -To re-enable these tests: -1. Set up a valid Snowflake account with billing enabled -2. 
Configure the following environment variables: - - SNOWFLAKE_ACCOUNT - - SNOWFLAKE_USER - - SNOWFLAKE_PASSWORD - - SNOWFLAKE_WAREHOUSE - - SNOWFLAKE_DATABASE - - SNOWFLAKE_SCHEMA (optional, defaults to PUBLIC) -3. Remove the skip decorator below -""" - -import time -from datetime import datetime - -import pyarrow as pa -import pytest - -try: - from src.amp.loaders.base import LoadMode - from src.amp.loaders.implementations.snowflake_loader import SnowflakeLoader - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch -except ImportError: - pytest.skip('amp modules not available', allow_module_level=True) - - -def wait_for_snowpipe_data(loader, table_name, expected_count, max_wait=30, poll_interval=2): - """ - Wait for Snowpipe streaming data to become queryable. - - Snowpipe streaming has eventual consistency, so data may not be immediately - queryable after insertion. This helper polls until the expected row count is visible. - - Args: - loader: SnowflakeLoader instance with active connection - table_name: Name of the table to query - expected_count: Expected number of rows - max_wait: Maximum seconds to wait (default 30) - poll_interval: Seconds between poll attempts (default 2) - - Returns: - int: Actual row count found - - Raises: - AssertionError: If expected count not reached within max_wait seconds - """ - elapsed = 0 - while elapsed < max_wait: - loader.cursor.execute(f'SELECT COUNT(*) FROM {table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - if count == expected_count: - return count - time.sleep(poll_interval) - elapsed += poll_interval - - # Final check before giving up - loader.cursor.execute(f'SELECT COUNT(*) FROM {table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == expected_count, f'Expected {expected_count} rows after {max_wait}s, but found {count}' - return count - - -# Skip all Snowflake tests -pytestmark = pytest.mark.skip(reason='Requires active Snowflake account - see module docstring for details') - - -@pytest.fixture -def test_table_name(): - """Generate unique table name for each test""" - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') - return f'test_table_{timestamp}' - - -@pytest.fixture -def cleanup_tables(snowflake_config): - """Cleanup test tables after tests""" - tables_to_clean = [] - - yield tables_to_clean - - loader = SnowflakeLoader(snowflake_config) - try: - loader.connect() - for table in tables_to_clean: - try: - loader.cursor.execute(f'DROP TABLE IF EXISTS {table}') - loader.connection.commit() - except Exception: - pass - except Exception: - pass - finally: - if loader._is_connected: - loader.disconnect() - - -@pytest.mark.integration -@pytest.mark.snowflake -class TestSnowflakeLoaderIntegration: - """Integration tests for Snowflake loader""" - - def test_loader_connection(self, snowflake_config): - """Test basic connection to Snowflake""" - loader = SnowflakeLoader(snowflake_config) - - loader.connect() - assert loader._is_connected is True - assert loader.connection is not None - assert loader.cursor is not None - - loader.disconnect() - assert loader._is_connected is False - assert loader.connection is None - assert loader.cursor is None - - def test_basic_table_loading_via_stage(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): - """Test basic table loading using stage""" - cleanup_tables.append(test_table_name) - - config = {**snowflake_config, 'loading_method': 'stage'} - loader = SnowflakeLoader(config) - - with loader: - result = 
loader.load_table(small_test_table, test_table_name, create_table=True) - - assert result.success is True - assert result.rows_loaded == small_test_table.num_rows - assert result.table_name == test_table_name - assert result.loader_type == 'snowflake' - assert result.metadata['loading_method'] == 'stage' - - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == small_test_table.num_rows - - def test_basic_table_loading_via_insert(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): - """Test basic table loading using INSERT (Note: currently defaults to stage for performance)""" - cleanup_tables.append(test_table_name) - - # Use insert loading (Note: implementation may default to stage for small tables) - config = {**snowflake_config, 'loading_method': 'insert'} - loader = SnowflakeLoader(config) - - with loader: - result = loader.load_table(small_test_table, test_table_name, create_table=True) - - assert result.success is True - assert result.rows_loaded == small_test_table.num_rows - # Note: Implementation uses stage by default for performance - assert result.metadata['loading_method'] in ['insert', 'stage'] - - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == small_test_table.num_rows - - def test_batch_loading(self, snowflake_config, medium_test_table, test_table_name, cleanup_tables): - """Test loading data in batches""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Use smaller batch size to force multiple batches (medium_test_table has 10000 rows) - result = loader.load_table(medium_test_table, test_table_name, create_table=True, batch_size=5000) - - assert result.success is True - assert result.rows_loaded == medium_test_table.num_rows - # Implementation may optimize batching, so just check >= 1 - assert result.metadata.get('batches_processed', 1) >= 1 - - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == medium_test_table.num_rows - - def test_overwrite_mode(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): - """Test OVERWRITE mode is not supported""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - result1 = loader.load_table(small_test_table, test_table_name, create_table=True) - assert result1.success is True - - # OVERWRITE mode should fail with error message - result2 = loader.load_table(small_test_table, test_table_name, mode=LoadMode.OVERWRITE) - assert result2.success is False - assert 'Unsupported mode LoadMode.OVERWRITE' in result2.error - - def test_append_mode(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): - """Test APPEND mode adds to existing data""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - result1 = loader.load_table(small_test_table, test_table_name, create_table=True) - assert result1.success is True - - result2 = loader.load_table(small_test_table, test_table_name, mode=LoadMode.APPEND) - assert result2.success is True - - # Should have double the rows - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == small_test_table.num_rows * 2 - - def test_comprehensive_data_types(self, snowflake_config, 
comprehensive_test_data, test_table_name, cleanup_tables): - """Test various data types from comprehensive test data""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - result = loader.load_table(comprehensive_test_data, test_table_name, create_table=True) - assert result.success is True - - loader.cursor.execute(f""" - SELECT - "id", - "user_id", - "transaction_amount", - "category", - "timestamp", - "is_weekend", - "score", - "active" - FROM {test_table_name} - WHERE "id" = 0 - """) - - row = loader.cursor.fetchone() - assert row['id'] == 0 - assert row['user_id'] == 'user_0' - assert abs(row['transaction_amount'] - 0.0) < 0.001 - assert row['category'] == 'electronics' - assert row['is_weekend'] is True # id=0 is weekend - assert abs(row['score'] - 0.0) < 0.001 - assert row['active'] is True - - def test_null_handling(self, snowflake_config, null_test_data, test_table_name, cleanup_tables): - """Test proper handling of NULL values""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - result = loader.load_table(null_test_data, test_table_name, create_table=True) - assert result.success is True - - loader.cursor.execute(f""" - SELECT COUNT(*) as null_count - FROM {test_table_name} - WHERE "text_field" IS NULL - """) - - null_count = loader.cursor.fetchone()['NULL_COUNT'] - expected_nulls = sum(1 for val in null_test_data.column('text_field').to_pylist() if val is None) - assert null_count == expected_nulls - - loader.cursor.execute(f""" - SELECT - COUNT(CASE WHEN "int_field" IS NULL THEN 1 END) as int_nulls, - COUNT(CASE WHEN "float_field" IS NULL THEN 1 END) as float_nulls, - COUNT(CASE WHEN "bool_field" IS NULL THEN 1 END) as bool_nulls - FROM {test_table_name} - """) - - null_counts = loader.cursor.fetchone() - assert null_counts['INT_NULLS'] > 0 - assert null_counts['FLOAT_NULLS'] > 0 - assert null_counts['BOOL_NULLS'] > 0 - - def test_table_info(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): - """Test getting table information""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - result = loader.load_table(small_test_table, test_table_name, create_table=True) - assert result.success is True - - info = loader.get_table_info(test_table_name) - - assert info is not None - assert info['table_name'] == test_table_name.upper() - assert info['schema'] == snowflake_config.get('schema', 'PUBLIC') - # Table should have original columns + _amp_batch_id metadata column - assert len(info['columns']) == len(small_test_table.schema) + 1 - - # Verify _amp_batch_id column exists - batch_id_col = next((col for col in info['columns'] if col['name'].lower() == '_amp_batch_id'), None) - assert batch_id_col is not None, 'Expected _amp_batch_id metadata column' - - # In Snowflake, quoted column names are case-sensitive but INFORMATION_SCHEMA may return them differently - # Let's find the ID column by looking for either case variant - id_col = None - for col in info['columns']: - if col['name'].upper() == 'ID' or col['name'] == 'id': - id_col = col - break - - assert id_col is not None, f'Could not find ID column in {[col["name"] for col in info["columns"]]}' - assert 'INT' in id_col['type'] or 'NUMBER' in id_col['type'] - - @pytest.mark.slow - def test_performance_batch_loading(self, snowflake_config, performance_test_data, test_table_name, cleanup_tables): - """Test performance with larger dataset""" - 
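        # Editorial note: the performance test below asserts correctness (success and
        # row count) but does not assert a minimum throughput; the rows/sec and MB/sec
        # figures are printed for inspection only, since wall-clock numbers vary with
        # warehouse size and network conditions.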
cleanup_tables.append(test_table_name) - - config = {**snowflake_config, 'loading_method': 'stage'} - loader = SnowflakeLoader(config) - - with loader: - start_time = time.time() - - result = loader.load_table(performance_test_data, test_table_name, create_table=True) - - duration = time.time() - start_time - - assert result.success is True - assert result.rows_loaded == performance_test_data.num_rows - - rows_per_second = result.rows_loaded / duration - mb_per_second = (performance_test_data.nbytes / 1024 / 1024) / duration - - print('\nPerformance metrics:') - print(f' Total rows: {result.rows_loaded:,}') - print(f' Duration: {duration:.2f}s') - print(f' Throughput: {rows_per_second:,.0f} rows/sec') - print(f' Data rate: {mb_per_second:.2f} MB/sec') - print(f' Batches: {result.metadata.get("batches_processed", "N/A")}') - - def test_error_handling_invalid_table(self, snowflake_config, small_test_table): - """Test error handling for invalid table operations""" - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Try to load without creating table - result = loader.load_table(small_test_table, 'non_existent_table_xyz', create_table=False) - - assert result.success is False - assert result.error is not None - - @pytest.mark.slow - def test_concurrent_batch_loading(self, snowflake_config, medium_test_table, test_table_name, cleanup_tables): - """Test loading multiple batches concurrently""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Create table first - batch = medium_test_table.to_batches(max_chunksize=1)[0] - loader.load_batch(batch, test_table_name, create_table=True) - - # Load multiple batches - total_rows = 0 - for _i, batch in enumerate(medium_test_table.to_batches(max_chunksize=500)): - result = loader.load_batch(batch, test_table_name) - assert result.success is True - total_rows += result.rows_loaded - - assert total_rows == medium_test_table.num_rows - - # Verify all data loaded - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == medium_test_table.num_rows + 1 # +1 for initial batch - - # Removed test_stage_and_compression_options - compression parameter not supported in current config - - def test_schema_with_special_characters(self, snowflake_config, test_table_name, cleanup_tables): - """Test handling of column names with special characters""" - cleanup_tables.append(test_table_name) - - data = { - 'user-id': [1, 2, 3], - 'first name': ['Alice', 'Bob', 'Charlie'], - 'total$amount': [100.0, 200.0, 300.0], - '2024_data': ['a', 'b', 'c'], - } - special_table = pa.Table.from_pydict(data) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - result = loader.load_table(special_table, test_table_name, create_table=True) - assert result.success is True - - # Verify row count - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == 3 - - # Verify we can query columns with special characters - # Note: Snowflake typically converts column names to uppercase and may need quoting - loader.cursor.execute(f""" - SELECT - "user-id", - "first name", - "total$amount", - "2024_data" - FROM {test_table_name} - WHERE "user-id" = 1 - """) - - row = loader.cursor.fetchone() - - assert row['user-id'] == 1 - assert row['first name'] == 'Alice' - assert abs(row['total$amount'] - 100.0) < 0.001 - assert row['2024_data'] == 'a' - - def 
test_handle_reorg_no_metadata_column(self, snowflake_config, test_table_name, cleanup_tables): - """Test reorg handling when table lacks metadata column""" - from src.amp.streaming.types import BlockRange - - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Create table without metadata column - data = pa.table({'id': [1, 2, 3], 'block_num': [100, 150, 200], 'value': [10.0, 20.0, 30.0]}) - loader.load_table(data, test_table_name, create_table=True) - - # Call handle reorg - invalidation_ranges = [BlockRange(network='ethereum', start=150, end=250)] - - # Should log warning and not modify data - loader._handle_reorg(invalidation_ranges, test_table_name, 'test_connection') - - # Verify data unchanged - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == 3 - - def test_handle_reorg_single_network(self, snowflake_config, test_table_name, cleanup_tables): - """Test reorg handling for single network data""" - - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Create batches with proper metadata - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'block_num': [105]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'block_num': [155]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'block_num': [205]}) - - # Create streaming responses with block ranges - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')]), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')]), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0xghi')]), - ) - - # Load data via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - - # Verify all data loaded successfully - assert len(results) == 3 - assert all(r.success for r in results) - - # Verify all data exists - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == 3 - - # Trigger reorg from block 155 - should delete rows 2 and 3 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=155, end=300)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - - # Verify reorg processed - assert len(reorg_results) == 1 - assert reorg_results[0].is_reorg - - # Verify only first row remains - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == 1 - - loader.cursor.execute(f'SELECT "id" FROM {test_table_name}') - remaining_id = loader.cursor.fetchone()['id'] - assert remaining_id == 1 - - def test_handle_reorg_multi_network(self, snowflake_config, test_table_name, cleanup_tables): - """Test reorg handling preserves data from unaffected networks""" - - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Create batches from multiple networks - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'network': ['ethereum']}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'network': ['polygon']}) - batch3 = 
pa.RecordBatch.from_pydict({'id': [3], 'network': ['ethereum']}) - batch4 = pa.RecordBatch.from_pydict({'id': [4], 'network': ['polygon']}) - - # Create streaming responses with block ranges - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xa')]), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata(ranges=[BlockRange(network='polygon', start=100, end=110, hash='0xb')]), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xc')]), - ) - response4 = ResponseBatch.data_batch( - data=batch4, - metadata=BatchMetadata(ranges=[BlockRange(network='polygon', start=150, end=160, hash='0xd')]), - ) - - # Load data via streaming API - stream = [response1, response2, response3, response4] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - - # Verify all data loaded successfully - assert len(results) == 4 - assert all(r.success for r in results) - - # Trigger reorg for ethereum only from block 150 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - - # Verify reorg processed - assert len(reorg_results) == 1 - assert reorg_results[0].is_reorg - - # Verify ethereum row 3 deleted, but polygon rows preserved - loader.cursor.execute(f'SELECT "id" FROM {test_table_name} ORDER BY "id"') - remaining_ids = [row['id'] for row in loader.cursor.fetchall()] - assert remaining_ids == [1, 2, 4] # Row 3 deleted - - def test_handle_reorg_overlapping_ranges(self, snowflake_config, test_table_name, cleanup_tables): - """Test reorg with overlapping block ranges""" - - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Create batches with overlapping ranges - batch1 = pa.RecordBatch.from_pydict({'id': [1]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3]}) - - # Create streaming responses with block ranges - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=90, end=110, hash='0xa')] - ), # Before reorg - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=140, end=160, hash='0xb')] - ), # Overlaps - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=170, end=190, hash='0xc')] - ), # Overlaps - ) - - # Load data via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - - # Verify all data loaded successfully - assert len(results) == 3 - assert all(r.success for r in results) - - # Trigger reorg from block 150 - should delete rows where end >= 150 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - - # Verify reorg processed - assert len(reorg_results) == 1 - assert reorg_results[0].is_reorg - - # Only first row should remain (ends at 110 < 150) - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = 
loader.cursor.fetchone()['COUNT(*)'] - assert count == 1 - - loader.cursor.execute(f'SELECT "id" FROM {test_table_name}') - remaining_id = loader.cursor.fetchone()['id'] - assert remaining_id == 1 - - def test_handle_reorg_with_history_preservation(self, snowflake_config, test_table_name, cleanup_tables): - """Test reorg history preservation mode - rows are updated instead of deleted""" - - cleanup_tables.append(test_table_name) - cleanup_tables.append(f'{test_table_name}_current') - cleanup_tables.append(f'{test_table_name}_history') - - # Enable history preservation - config_with_history = {**snowflake_config, 'preserve_reorg_history': True} - loader = SnowflakeLoader(config_with_history) - - with loader: - # Create batches with proper metadata - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'block_num': [105]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'block_num': [155]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'block_num': [205]}) - - # Create streaming responses with block ranges - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')]), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')]), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0xghi')]), - ) - - # Load data via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - - # Verify all data loaded successfully - assert len(results) == 3 - assert all(r.success for r in results) - - # Verify temporal columns exist and are set correctly - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE "_amp_is_current" = TRUE') - current_count = loader.cursor.fetchone()['COUNT(*)'] - assert current_count == 3 - - # Verify reorg columns exist - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE "_amp_reorg_batch_id" IS NULL') - not_reorged_count = loader.cursor.fetchone()['COUNT(*)'] - assert not_reorged_count == 3 # All current rows should have NULL reorg_batch_id - - # Verify views exist - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}_current') - view_count = loader.cursor.fetchone()['COUNT(*)'] - assert view_count == 3 - - # Trigger reorg from block 155 - should UPDATE rows 2 and 3, not delete them - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=155, end=300)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - - # Verify reorg processed - assert len(reorg_results) == 1 - assert reorg_results[0].is_reorg - - # Verify ALL 3 rows still exist in base table - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - total_count = loader.cursor.fetchone()['COUNT(*)'] - assert total_count == 3 - - # Verify only first row is current - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE "_amp_is_current" = TRUE') - current_count = loader.cursor.fetchone()['COUNT(*)'] - assert current_count == 1 - - # Verify _current view shows only active row - loader.cursor.execute(f'SELECT "id" FROM {test_table_name}_current') - current_ids = [row['id'] for row in loader.cursor.fetchall()] - assert current_ids == [1] - - # Verify _history view shows all rows - 
loader.cursor.execute(f'SELECT "id" FROM {test_table_name}_history ORDER BY "id"') - history_ids = [row['id'] for row in loader.cursor.fetchall()] - assert history_ids == [1, 2, 3] - - # Verify reorged rows have simplified reorg columns set correctly - loader.cursor.execute( - f'''SELECT "id", "_amp_is_current", "_amp_batch_id", "_amp_reorg_batch_id" - FROM {test_table_name} - WHERE "_amp_is_current" = FALSE - ORDER BY "id"''' - ) - reorged_rows = loader.cursor.fetchall() - assert len(reorged_rows) == 2 - assert reorged_rows[0]['id'] == 2 - assert reorged_rows[1]['id'] == 3 - # Verify reorg_batch_id is set (identifies which reorg event superseded these rows) - assert reorged_rows[0]['_amp_reorg_batch_id'] is not None - assert reorged_rows[1]['_amp_reorg_batch_id'] is not None - # Both rows superseded by same reorg event - assert reorged_rows[0]['_amp_reorg_batch_id'] == reorged_rows[1]['_amp_reorg_batch_id'] - - def test_parallel_streaming_with_stage(self, snowflake_config, test_table_name, cleanup_tables): - """Test parallel streaming using stage loading method""" - import threading - - cleanup_tables.append(test_table_name) - config = {**snowflake_config, 'loading_method': 'stage'} - loader = SnowflakeLoader(config) - - with loader: - # Create table first - initial_batch = pa.RecordBatch.from_pydict({'id': [1], 'partition': ['partition_0'], 'value': [100]}) - loader.load_batch(initial_batch, test_table_name, create_table=True) - - # Thread lock for serializing access to shared Snowflake connection - # (Snowflake connector is not thread-safe) - load_lock = threading.Lock() - - # Load multiple batches in parallel from different "streams" - def load_partition_data(partition_id: int, start_id: int): - """Simulate a stream partition loading data""" - for batch_num in range(3): - batch_start = start_id + (batch_num * 10) - batch = pa.RecordBatch.from_pydict( - { - 'id': list(range(batch_start, batch_start + 10)), - 'partition': [f'partition_{partition_id}'] * 10, - 'value': list(range(batch_start * 100, (batch_start + 10) * 100, 100)), - } - ) - # Use lock to ensure thread-safe access to shared connection - with load_lock: - result = loader.load_batch(batch, test_table_name, create_table=False) - assert result.success, f'Partition {partition_id} batch {batch_num} failed: {result.error}' - - # Launch 3 parallel "streams" (threads simulating parallel streaming) - threads = [] - for partition_id in range(3): - start_id = 100 + (partition_id * 100) - thread = threading.Thread(target=load_partition_data, args=(partition_id, start_id)) - threads.append(thread) - thread.start() - - # Wait for all streams to complete - for thread in threads: - thread.join() - - # Verify all data loaded correctly - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - # 1 initial + (3 partitions * 3 batches * 10 rows) = 91 rows - assert count == 91 - - # Verify each partition loaded correctly - for partition_id in range(3): - loader.cursor.execute( - f'SELECT COUNT(*) FROM {test_table_name} WHERE "partition" = \'partition_{partition_id}\'' - ) - partition_count = loader.cursor.fetchone()['COUNT(*)'] - # partition_0 has 31 rows (1 initial + 30 from thread), others have 30 - expected_count = 31 if partition_id == 0 else 30 - assert partition_count == expected_count - - def test_streaming_with_reorg(self, snowflake_config, test_table_name, cleanup_tables): - """Test streaming data with reorg support""" - from src.amp.streaming.types import BatchMetadata, 
BlockRange, ResponseBatch - - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Create streaming data with metadata - data1 = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) - - data2 = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) - - # Create response batches using factory methods (with hashes for proper state management) - response1 = ResponseBatch.data_batch( - data=data1, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')]), - ) - - response2 = ResponseBatch.data_batch( - data=data2, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')]), - ) - - # Simulate reorg event using factory method - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - - # Process streaming data - stream = [response1, response2, reorg_response] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - - # Verify results - assert len(results) == 3 - assert results[0].success - assert results[0].rows_loaded == 2 - assert results[1].success - assert results[1].rows_loaded == 2 - assert results[2].success - assert results[2].is_reorg - - # Verify reorg deleted the second batch - loader.cursor.execute(f'SELECT "id" FROM {test_table_name} ORDER BY "id"') - remaining_ids = [row['id'] for row in loader.cursor.fetchall()] - assert remaining_ids == [1, 2] # 3 and 4 deleted by reorg - - -@pytest.fixture -def snowflake_streaming_config(): - """ - Snowflake Snowpipe Streaming configuration from environment. - - Requires: - - SNOWFLAKE_ACCOUNT: Account identifier - - SNOWFLAKE_USER: Username - - SNOWFLAKE_WAREHOUSE: Warehouse name - - SNOWFLAKE_DATABASE: Database name - - SNOWFLAKE_PRIVATE_KEY: Private key in PEM format (as string) - - SNOWFLAKE_SCHEMA: Schema name (optional, defaults to PUBLIC) - - SNOWFLAKE_ROLE: Role (optional) - """ - import os - - config = { - 'account': os.getenv('SNOWFLAKE_ACCOUNT', 'test_account'), - 'user': os.getenv('SNOWFLAKE_USER', 'test_user'), - 'warehouse': os.getenv('SNOWFLAKE_WAREHOUSE', 'test_warehouse'), - 'database': os.getenv('SNOWFLAKE_DATABASE', 'test_database'), - 'schema': os.getenv('SNOWFLAKE_SCHEMA', 'PUBLIC'), - 'loading_method': 'snowpipe_streaming', - 'streaming_channel_prefix': 'test_amp', - 'streaming_max_retries': 3, - 'streaming_buffer_flush_interval': 1, - } - - # Private key is required for Snowpipe Streaming - if os.getenv('SNOWFLAKE_PRIVATE_KEY'): - config['private_key'] = os.getenv('SNOWFLAKE_PRIVATE_KEY') - else: - pytest.skip('Snowpipe Streaming requires SNOWFLAKE_PRIVATE_KEY environment variable') - - if os.getenv('SNOWFLAKE_ROLE'): - config['role'] = os.getenv('SNOWFLAKE_ROLE') - - return config - - -@pytest.mark.integration -@pytest.mark.snowflake -class TestSnowpipeStreamingIntegration: - """Integration tests for Snowpipe Streaming functionality""" - - def test_streaming_connection(self, snowflake_streaming_config): - """Test connection with Snowpipe Streaming enabled""" - loader = SnowflakeLoader(snowflake_streaming_config) - - loader.connect() - assert loader._is_connected is True - assert loader.connection is not None - # Streaming channels dict is initialized empty (channels created on first load) - assert hasattr(loader, 'streaming_channels') - - loader.disconnect() - assert loader._is_connected is False - - def test_basic_streaming_batch_load( - self, snowflake_streaming_config, 
small_test_table, test_table_name, cleanup_tables - ): - """Test basic batch loading via Snowpipe Streaming""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - # Load first batch - batch = small_test_table.to_batches(max_chunksize=50)[0] - result = loader.load_batch(batch, test_table_name, create_table=True) - - assert result.success is True - assert result.rows_loaded == batch.num_rows - assert result.table_name == test_table_name - assert result.metadata['loading_method'] == 'snowpipe_streaming' - - # Wait for Snowpipe streaming data to become queryable (eventual consistency) - count = wait_for_snowpipe_data(loader, test_table_name, batch.num_rows) - assert count == batch.num_rows - - def test_streaming_multiple_batches( - self, snowflake_streaming_config, medium_test_table, test_table_name, cleanup_tables - ): - """Test loading multiple batches via Snowpipe Streaming""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - # Load multiple batches - total_rows = 0 - for i, batch in enumerate(medium_test_table.to_batches(max_chunksize=1000)): - result = loader.load_batch(batch, test_table_name, create_table=(i == 0)) - assert result.success is True - total_rows += result.rows_loaded - - assert total_rows == medium_test_table.num_rows - - # Wait for Snowpipe streaming data to become queryable (eventual consistency) - count = wait_for_snowpipe_data(loader, test_table_name, medium_test_table.num_rows) - assert count == medium_test_table.num_rows - - def test_streaming_channel_management( - self, snowflake_streaming_config, small_test_table, test_table_name, cleanup_tables - ): - """Test that channels are created and reused properly""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - # Load batches with same channel suffix - batch = small_test_table.to_batches(max_chunksize=50)[0] - - result1 = loader.load_batch(batch, test_table_name, create_table=True, channel_suffix='partition_0') - assert result1.success is True - - result2 = loader.load_batch(batch, test_table_name, channel_suffix='partition_0') - assert result2.success is True - - # Verify channel was reused (check loader's channel cache) - channel_key = f'{test_table_name}:test_amp_{test_table_name}_partition_0' - assert channel_key in loader.streaming_channels - - # Wait for Snowpipe streaming data to become queryable (eventual consistency) - count = wait_for_snowpipe_data(loader, test_table_name, batch.num_rows * 2) - assert count == batch.num_rows * 2 - - def test_streaming_multiple_partitions( - self, snowflake_streaming_config, small_test_table, test_table_name, cleanup_tables - ): - """Test parallel streaming with multiple partition channels""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - batch = small_test_table.to_batches(max_chunksize=30)[0] - - # Load to different partitions - result1 = loader.load_batch(batch, test_table_name, create_table=True, channel_suffix='partition_0') - result2 = loader.load_batch(batch, test_table_name, channel_suffix='partition_1') - result3 = loader.load_batch(batch, test_table_name, channel_suffix='partition_2') - - assert result1.success and result2.success and result3.success - - # Verify multiple channels created - assert len(loader.streaming_channels) == 3 - - # Wait for Snowpipe streaming data to become queryable (eventual 
consistency) - count = wait_for_snowpipe_data(loader, test_table_name, batch.num_rows * 3) - assert count == batch.num_rows * 3 - - def test_streaming_data_types( - self, snowflake_streaming_config, comprehensive_test_data, test_table_name, cleanup_tables - ): - """Test Snowpipe Streaming with various data types""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - result = loader.load_table(comprehensive_test_data, test_table_name, create_table=True) - assert result.success is True - - # Wait for Snowpipe streaming data to become queryable (eventual consistency) - count = wait_for_snowpipe_data(loader, test_table_name, comprehensive_test_data.num_rows) - assert count == comprehensive_test_data.num_rows - - # Verify specific row - loader.cursor.execute(f'SELECT * FROM {test_table_name} WHERE "id" = 0') - row = loader.cursor.fetchone() - assert row['id'] == 0 - - def test_streaming_null_handling(self, snowflake_streaming_config, null_test_data, test_table_name, cleanup_tables): - """Test Snowpipe Streaming with NULL values""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - result = loader.load_table(null_test_data, test_table_name, create_table=True) - assert result.success is True - - # Wait for Snowpipe streaming data to become queryable (eventual consistency) - wait_for_snowpipe_data(loader, test_table_name, null_test_data.num_rows) - - # Verify NULL handling - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE "text_field" IS NULL') - null_count = loader.cursor.fetchone()['COUNT(*)'] - expected_nulls = sum(1 for val in null_test_data.column('text_field').to_pylist() if val is None) - assert null_count == expected_nulls - - def test_streaming_reorg_channel_closure(self, snowflake_streaming_config, test_table_name, cleanup_tables): - """Test that reorg properly closes streaming channels""" - import json - - from src.amp.streaming.types import BlockRange - - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - # Load initial data with multiple channels - batch = pa.RecordBatch.from_pydict( - { - 'id': [1, 2, 3], - 'value': [100, 200, 300], - '_meta_block_ranges': [json.dumps([{'network': 'ethereum', 'start': 100, 'end': 110}])] * 3, - } - ) - - loader.load_batch(batch, test_table_name, create_table=True, channel_suffix='partition_0') - loader.load_batch(batch, test_table_name, channel_suffix='partition_1') - - # Verify channels exist - assert len(loader.streaming_channels) == 2 - - # Wait for data to be queryable - time.sleep(5) - - # Trigger reorg - invalidation_ranges = [BlockRange(network='ethereum', start=100, end=200)] - loader._handle_reorg(invalidation_ranges, test_table_name, 'test_connection') - - # Verify channels were closed - assert len(loader.streaming_channels) == 0 - - # Verify data was deleted - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == 0 - - @pytest.mark.slow - def test_streaming_performance( - self, snowflake_streaming_config, performance_test_data, test_table_name, cleanup_tables - ): - """Test Snowpipe Streaming performance with larger dataset""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - start_time = time.time() - result = loader.load_table(performance_test_data, test_table_name, create_table=True) - 
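            # Note: the duration measured below covers only the client-side load_table()
            # call; Snowpipe Streaming commits data asynchronously, so queryability is
            # verified separately afterwards via wait_for_snowpipe_data().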
duration = time.time() - start_time - - assert result.success is True - assert result.rows_loaded == performance_test_data.num_rows - - rows_per_second = result.rows_loaded / duration - - print('\nSnowpipe Streaming Performance:') - print(f' Total rows: {result.rows_loaded:,}') - print(f' Duration: {duration:.2f}s') - print(f' Throughput: {rows_per_second:,.0f} rows/sec') - print(f' Loading method: {result.metadata.get("loading_method")}') - - # Wait for Snowpipe streaming data to become queryable - # (eventual consistency, larger dataset may take longer) - count = wait_for_snowpipe_data(loader, test_table_name, performance_test_data.num_rows, max_wait=60) - assert count == performance_test_data.num_rows - - def test_streaming_error_handling(self, snowflake_streaming_config, test_table_name, cleanup_tables): - """Test error handling in Snowpipe Streaming""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - # Create table first - initial_data = pa.table({'id': [1, 2, 3], 'value': [100, 200, 300]}) - result = loader.load_table(initial_data, test_table_name, create_table=True) - assert result.success is True - - # Try to load data with extra column (Snowpipe streaming handles gracefully) - # Note: Snowpipe streaming accepts data with extra columns and silently ignores them - incompatible_data = pa.RecordBatch.from_pydict( - { - 'id': [4, 5], - 'different_column': ['a', 'b'], # Extra column not in table schema - } - ) - - result = loader.load_batch(incompatible_data, test_table_name) - # Snowpipe streaming handles this gracefully - it loads the matching columns - # and ignores columns that don't exist in the table - assert result.success is True - assert result.rows_loaded == 2 - - def test_microbatch_deduplication(self, snowflake_config, test_table_name, cleanup_tables): - """ - Test that multiple RecordBatches within the same microbatch are all loaded, - and deduplication only happens at microbatch boundaries when ranges_complete=True. - - This test verifies the fix for the critical bug where we were marking batches - as processed after every RecordBatch instead of waiting for ranges_complete=True. - """ - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - cleanup_tables.append(test_table_name) - - # Enable state management to test deduplication - config_with_state = { - **snowflake_config, - 'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True}, - } - loader = SnowflakeLoader(config_with_state) - - with loader: - # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange - # This happens when the server sends large microbatches in smaller chunks - - # First RecordBatch of the microbatch (ranges_complete=False) - batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) - response1 = ResponseBatch.data_batch( - data=batch1_data, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], - ranges_complete=False, # Not the last batch in this microbatch - ), - ) - - # Second RecordBatch of the microbatch (ranges_complete=False) - batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) - response2 = ResponseBatch.data_batch( - data=batch2_data, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], # Same BlockRange! 
- ranges_complete=False, # Still not the last batch - ), - ) - - # Third RecordBatch of the microbatch (ranges_complete=True) - batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]}) - response3 = ResponseBatch.data_batch( - data=batch3_data, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], # Same BlockRange! - ranges_complete=True, # Last batch in this microbatch - safe to mark as processed - ), - ) - - # Process the microbatch stream - stream = [response1, response2, response3] - results = list( - loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection') - ) - - # CRITICAL: All 3 RecordBatches should be loaded successfully - # Before the fix, only the first batch would load (the other 2 would be skipped as "duplicates") - assert len(results) == 3, 'All RecordBatches within microbatch should be processed' - assert all(r.success for r in results), 'All batches should succeed' - assert results[0].rows_loaded == 2, 'First batch should load 2 rows' - assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)' - assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)' - - # Verify total rows in table (all batches loaded) - loader.cursor.execute(f'SELECT COUNT(*) as count FROM {test_table_name}') - total_count = loader.cursor.fetchone()['COUNT'] - assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table' - - # Verify the actual IDs are present - loader.cursor.execute(f'SELECT "id" FROM {test_table_name} ORDER BY "id"') - all_ids = [row['id'] for row in loader.cursor.fetchall()] - assert all_ids == [1, 2, 3, 4, 5, 6], 'All rows from all RecordBatches should be present' - - # Now test that re-sending the complete microbatch is properly deduplicated - # This time, the first batch has ranges_complete=True (entire microbatch in one RecordBatch) - duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]}) - duplicate_response = ResponseBatch.data_batch( - data=duplicate_batch, - metadata=BatchMetadata( - ranges=[ - BlockRange(network='ethereum', start=100, end=110, hash='0xabc123') - ], # Same range as before! 
- ranges_complete=True, # Complete microbatch - ), - ) - - # Process duplicate microbatch - duplicate_results = list( - loader.load_stream_continuous( - iter([duplicate_response]), test_table_name, connection_name='test_connection' - ) - ) - - # The duplicate microbatch should be skipped (already processed) - assert len(duplicate_results) == 1 - assert duplicate_results[0].success is True - assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped' - assert duplicate_results[0].metadata.get('operation') == 'skip_duplicate', 'Should be marked as duplicate' - - # Verify row count unchanged (duplicate was skipped) - loader.cursor.execute(f'SELECT COUNT(*) as count FROM {test_table_name}') - final_count = loader.cursor.fetchone()['COUNT'] - assert final_count == 6, 'Row count should not increase after duplicate microbatch' From 6d44011bded6cc1fdc80eb7d1422ba0bea90e1b6 Mon Sep 17 00:00:00 2001 From: Ford Date: Thu, 8 Jan 2026 19:30:32 -0800 Subject: [PATCH 5/6] Update loader implementation docs to explain new shared test structure --- docs/implementing_data_loaders.md | 226 ++++++++++++++++++++++++------ 1 file changed, 181 insertions(+), 45 deletions(-) diff --git a/docs/implementing_data_loaders.md b/docs/implementing_data_loaders.md index 3699532..20ab942 100644 --- a/docs/implementing_data_loaders.md +++ b/docs/implementing_data_loaders.md @@ -305,71 +305,207 @@ def _get_table_metadata(self, table: pa.Table, duration: float, batch_count: int ## Testing -### Integration Test Structure +### Generalized Test Infrastructure -Create integration tests in `tests/integration/test_{system}_loader.py`: +The project uses a generalized test infrastructure that eliminates code duplication across loader tests. Instead of writing standalone tests for each loader, you inherit from shared base test classes. + +### Architecture + +``` +tests/integration/loaders/ +├── conftest.py # Base classes and fixtures +├── test_base_loader.py # 7 core tests (all loaders inherit) +├── test_base_streaming.py # 5 streaming tests (for loaders with reorg support) +└── backends/ + ├── test_postgresql.py # PostgreSQL-specific config + tests + ├── test_redis.py # Redis-specific config + tests + └── test_example.py # Your loader tests here +``` + +### Step 1: Create Configuration Fixture + +Add your loader's configuration fixture to `tests/conftest.py`: + +```python +@pytest.fixture(scope='session') +def example_test_config(request): + """Example loader configuration from testcontainer or environment""" + # Use testcontainers for CI, or fall back to environment variables + if TESTCONTAINERS_AVAILABLE and USE_TESTCONTAINERS: + # Set up testcontainer (if applicable) + example_container = request.getfixturevalue('example_container') + return { + 'host': example_container.get_container_host_ip(), + 'port': example_container.get_exposed_port(5432), + 'database': 'test_db', + 'user': 'test_user', + 'password': 'test_pass', + } + else: + # Fall back to environment variables + return { + 'host': os.getenv('EXAMPLE_HOST', 'localhost'), + 'port': int(os.getenv('EXAMPLE_PORT', '5432')), + 'database': os.getenv('EXAMPLE_DB', 'test_db'), + 'user': os.getenv('EXAMPLE_USER', 'test_user'), + 'password': os.getenv('EXAMPLE_PASSWORD', 'test_pass'), + } +``` + +### Step 2: Create Test Configuration Class + +Create `tests/integration/loaders/backends/test_example.py`: ```python -# tests/integration/test_example_loader.py +""" +Example loader integration tests using generalized test infrastructure. 
+""" +from typing import Any, Dict, List, Optional import pytest -import pyarrow as pa -from src.amp.loaders.base import LoadMode + from src.amp.loaders.implementations.example_loader import ExampleLoader +from tests.integration.loaders.conftest import LoaderTestConfig +from tests.integration.loaders.test_base_loader import BaseLoaderTests +from tests.integration.loaders.test_base_streaming import BaseStreamingTests + + +class ExampleTestConfig(LoaderTestConfig): + """Example-specific test configuration""" + + loader_class = ExampleLoader + config_fixture_name = 'example_test_config' + + # Declare loader capabilities + supports_overwrite = True + supports_streaming = True # Set to False if no streaming support + supports_multi_network = True # For blockchain loaders with reorg + supports_null_values = True + + def get_row_count(self, loader: ExampleLoader, table_name: str) -> int: + """Get row count from table""" + # Implement using your loader's API + return loader._connection.query(f"SELECT COUNT(*) FROM {table_name}")[0]['count'] + + def query_rows( + self, + loader: ExampleLoader, + table_name: str, + where: Optional[str] = None, + order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from table""" + query = f"SELECT * FROM {table_name}" + if where: + query += f" WHERE {where}" + if order_by: + query += f" ORDER BY {order_by}" + return loader._connection.query(query) + + def cleanup_table(self, loader: ExampleLoader, table_name: str) -> None: + """Drop table""" + loader._connection.execute(f"DROP TABLE IF EXISTS {table_name}") + + def get_column_names(self, loader: ExampleLoader, table_name: str) -> List[str]: + """Get column names from table""" + result = loader._connection.query( + f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table_name}'" + ) + return [row['column_name'] for row in result] -@pytest.fixture -def example_config(): - return { - 'host': 'localhost', - 'port': 5432, - 'database': 'test_db', - 'user': 'test_user', - 'password': 'test_pass' - } -@pytest.fixture -def test_data(): - return pa.Table.from_pydict({ - 'id': [1, 2, 3], - 'name': ['a', 'b', 'c'], - 'value': [1.0, 2.0, 3.0] - }) +# Core tests - ALL loaders must inherit these +class TestExampleCore(BaseLoaderTests): + """Inherits 7 core tests: connection, context manager, batching, modes, null handling, errors""" + config = ExampleTestConfig() + +# Streaming tests - Only for loaders with streaming/reorg support +class TestExampleStreaming(BaseStreamingTests): + """Inherits 5 streaming tests: metadata columns, reorg deletion, overlapping ranges, multi-network, microbatch dedup""" + config = ExampleTestConfig() + + +# Loader-specific tests @pytest.mark.integration @pytest.mark.example -class TestExampleLoaderIntegration: - def test_connection(self, example_config): - loader = ExampleLoader(example_config) - - loader.connect() - assert loader.is_connected - - loader.disconnect() - assert not loader.is_connected - - def test_basic_loading(self, example_config, test_data): - loader = ExampleLoader(example_config) - +class TestExampleSpecific: + """Example-specific functionality tests""" + config = ExampleTestConfig() + + def test_custom_feature(self, loader, test_table_name, cleanup_tables): + """Test example-specific functionality""" + cleanup_tables.append(test_table_name) + with loader: - result = loader.load_table(test_data, 'test_table') - + # Test your loader's unique features + result = loader.some_custom_method(test_table_name) assert result.success - assert 
result.rows_loaded == 3 - assert result.metadata['operation'] == 'load_table' - assert result.metadata['batches_processed'] > 0 +``` + +### What You Get Automatically + +By inheriting from the base test classes, you automatically get: + +**From `BaseLoaderTests` (7 core tests):** +- `test_connection` - Connection establishment and disconnection +- `test_context_manager` - Context manager functionality +- `test_batch_loading` - Basic batch loading +- `test_append_mode` - Append mode operations +- `test_overwrite_mode` - Overwrite mode operations +- `test_null_handling` - Null value handling +- `test_error_handling` - Error scenarios + +**From `BaseStreamingTests` (5 streaming tests):** +- `test_streaming_metadata_columns` - Metadata column creation +- `test_reorg_deletion` - Blockchain reorganization handling +- `test_reorg_overlapping_ranges` - Overlapping range invalidation +- `test_reorg_multi_network` - Multi-network reorg isolation +- `test_microbatch_deduplication` - Microbatch duplicate detection + +### Required LoaderTestConfig Methods + +You must implement these four methods in your `LoaderTestConfig` subclass: + +```python +def get_row_count(self, loader, table_name: str) -> int: + """Return number of rows in table""" + +def query_rows(self, loader, table_name: str, where=None, order_by=None) -> List[Dict]: + """Query and return rows as list of dicts""" + +def cleanup_table(self, loader, table_name: str) -> None: + """Drop/delete the table""" + +def get_column_names(self, loader, table_name: str) -> List[str]: + """Return list of column names""" +``` + +### Capability Flags + +Set these flags in your `LoaderTestConfig` to control which tests run: + +```python +supports_overwrite = True # Can overwrite existing data +supports_streaming = True # Supports streaming with metadata +supports_multi_network = True # Supports multi-network isolation (blockchain loaders) +supports_null_values = True # Handles NULL values correctly ``` ### Running Tests ```bash -# Run all integration tests -make test-integration +# Run all tests for your loader +uv run pytest tests/integration/loaders/backends/test_example.py -v + +# Run only core tests +uv run pytest tests/integration/loaders/backends/test_example.py::TestExampleCore -v -# Run specific loader tests -make test-example +# Run only streaming tests +uv run pytest tests/integration/loaders/backends/test_example.py::TestExampleStreaming -v -# Run with environment variables -uv run --env-file .test.env pytest tests/integration/test_example_loader.py -v +# Run specific test +uv run pytest tests/integration/loaders/backends/test_example.py::TestExampleCore::test_connection -v ``` ## Best Practices From f910b03e4c7c35cd2f2ba6af2e2101aedca56948 Mon Sep 17 00:00:00 2001 From: Ford Date: Fri, 9 Jan 2026 09:46:27 -0800 Subject: [PATCH 6/6] Fix linting/formatting issues in loader test files --- docs/implementing_data_loaders.md | 2 -- .../loaders/backends/test_deltalake.py | 3 +-- .../integration/loaders/backends/test_iceberg.py | 10 +++------- tests/integration/loaders/backends/test_lmdb.py | 12 +++++------- .../loaders/backends/test_postgresql.py | 5 +++-- tests/integration/loaders/backends/test_redis.py | 5 ++--- .../loaders/backends/test_snowflake.py | 3 +-- tests/integration/loaders/conftest.py | 6 +----- tests/integration/loaders/test_base_streaming.py | 16 ++++++++++++---- 9 files changed, 28 insertions(+), 34 deletions(-) diff --git a/docs/implementing_data_loaders.md b/docs/implementing_data_loaders.md index 20ab942..38f791e 100644 --- 
a/docs/implementing_data_loaders.md +++ b/docs/implementing_data_loaders.md @@ -781,5 +781,3 @@ class KeyValueLoader(DataLoader[KeyValueConfig]): 'database': self.config.database } ``` - -This documentation provides everything needed to implement new data loaders efficiently and consistently! \ No newline at end of file diff --git a/tests/integration/loaders/backends/test_deltalake.py b/tests/integration/loaders/backends/test_deltalake.py index f2374c4..cd6281b 100644 --- a/tests/integration/loaders/backends/test_deltalake.py +++ b/tests/integration/loaders/backends/test_deltalake.py @@ -91,7 +91,6 @@ class TestDeltaLakeSpecific: def test_partitioning(self, delta_partitioned_config, small_test_data): """Test DeltaLake partitioning functionality""" - import pyarrow as pa loader = DeltaLakeLoader(delta_partitioned_config) @@ -148,7 +147,7 @@ def test_schema_evolution(self, delta_basic_config, small_test_data): # Load with new schema (if schema evolution enabled) from src.amp.loaders.base import LoadMode - result2 = loader.load_table(extended_table, 'test_table', mode=LoadMode.APPEND) + loader.load_table(extended_table, 'test_table', mode=LoadMode.APPEND) # Result depends on merge_schema configuration dt = DeltaTable(loader.config.table_path) diff --git a/tests/integration/loaders/backends/test_iceberg.py b/tests/integration/loaders/backends/test_iceberg.py index 543fa06..64376a8 100644 --- a/tests/integration/loaders/backends/test_iceberg.py +++ b/tests/integration/loaders/backends/test_iceberg.py @@ -10,9 +10,6 @@ import pytest try: - from pyiceberg.catalog import load_catalog - from pyiceberg.schema import Schema - from src.amp.loaders.implementations.iceberg_loader import IcebergLoader from tests.integration.loaders.conftest import LoaderTestConfig from tests.integration.loaders.test_base_loader import BaseLoaderTests @@ -102,7 +99,6 @@ def test_catalog_initialization(self, iceberg_basic_config): def test_partitioning(self, iceberg_basic_config, small_test_data): """Test Iceberg partitioning (partition spec)""" - import pyarrow as pa # Create config with partitioning config = {**iceberg_basic_config, 'partition_spec': [('year', 'identity'), ('month', 'identity')]} @@ -147,7 +143,7 @@ def test_schema_evolution(self, iceberg_basic_config, small_test_data): extended_table = pa.Table.from_pydict(extended_data) # Load with new schema - result2 = loader.load_table(extended_table, table_name, mode=LoadMode.APPEND) + loader.load_table(extended_table, table_name, mode=LoadMode.APPEND) # Schema evolution depends on config catalog = loader._catalog @@ -161,9 +157,10 @@ def test_schema_evolution(self, iceberg_basic_config, small_test_data): def test_timestamp_conversion(self, iceberg_basic_config): """Test timestamp conversion for Iceberg""" - import pyarrow as pa from datetime import datetime + import pyarrow as pa + # Create data with timestamps data = { 'id': [1, 2, 3], @@ -193,7 +190,6 @@ def test_timestamp_conversion(self, iceberg_basic_config): def test_multiple_tables(self, iceberg_basic_config, small_test_data): """Test managing multiple tables with same loader""" - from src.amp.loaders.base import LoadMode loader = IcebergLoader(iceberg_basic_config) diff --git a/tests/integration/loaders/backends/test_lmdb.py b/tests/integration/loaders/backends/test_lmdb.py index aac0000..abf0d19 100644 --- a/tests/integration/loaders/backends/test_lmdb.py +++ b/tests/integration/loaders/backends/test_lmdb.py @@ -11,8 +11,6 @@ import pytest try: - import lmdb - from 
src.amp.loaders.implementations.lmdb_loader import LMDBLoader from tests.integration.loaders.conftest import LoaderTestConfig from tests.integration.loaders.test_base_loader import BaseLoaderTests @@ -37,7 +35,7 @@ def get_row_count(self, loader: LMDBLoader, table_name: str) -> int: count = 0 with loader.env.begin(db=loader.db) as txn: cursor = txn.cursor() - for key, value in cursor: + for _key, _value in cursor: count += 1 return count @@ -58,7 +56,7 @@ def query_rows( row_data = json.loads(value.decode()) row_data['_key'] = key.decode() rows.append(row_data) - except: + except Exception: # Fallback to raw value rows.append({'_key': key.decode(), '_value': value.decode()}) return rows @@ -79,11 +77,11 @@ def get_column_names(self, loader: LMDBLoader, table_name: str) -> List[str]: with loader.env.begin(db=loader.db) as txn: cursor = txn.cursor() - for key, value in cursor: + for _key, value in cursor: try: row_data = json.loads(value.decode()) return list(row_data.keys()) - except: + except Exception: return ['_value'] # Fallback return [] @@ -91,8 +89,8 @@ def get_column_names(self, loader: LMDBLoader, table_name: str) -> List[str]: @pytest.fixture def lmdb_test_env(): """Create and cleanup temporary directory for LMDB databases""" - import tempfile import shutil + import tempfile temp_dir = tempfile.mkdtemp(prefix='lmdb_test_') yield temp_dir diff --git a/tests/integration/loaders/backends/test_postgresql.py b/tests/integration/loaders/backends/test_postgresql.py index 163afb6..765223d 100644 --- a/tests/integration/loaders/backends/test_postgresql.py +++ b/tests/integration/loaders/backends/test_postgresql.py @@ -70,7 +70,7 @@ def query_rows( rows = cur.fetchall() # Convert to list of dicts - return [dict(zip(columns, row)) for row in rows] + return [dict(zip(columns, row, strict=False)) for row in rows] finally: loader.pool.putconn(conn) @@ -318,9 +318,10 @@ class TestPostgreSQLPerformance: def test_large_data_loading(self, postgresql_test_config, test_table_name, cleanup_tables): """Test loading large datasets""" - import pyarrow as pa from datetime import datetime + import pyarrow as pa + cleanup_tables.append(test_table_name) # Create large dataset diff --git a/tests/integration/loaders/backends/test_redis.py b/tests/integration/loaders/backends/test_redis.py index b25e7e2..0a52e0b 100644 --- a/tests/integration/loaders/backends/test_redis.py +++ b/tests/integration/loaders/backends/test_redis.py @@ -54,7 +54,7 @@ def get_row_count(self, loader: RedisLoader, table_name: str) -> int: # Get stream length try: return loader.redis_client.xlen(table_name) - except: + except Exception: return 0 else: # For other structures, scan for keys @@ -166,7 +166,6 @@ class TestRedisSpecific: def test_hash_storage(self, redis_test_config, small_test_data, cleanup_redis): """Test Redis hash data structure storage""" - import pyarrow as pa keys_to_clean, patterns_to_clean = cleanup_redis table_name = 'test_hash' @@ -343,7 +342,7 @@ def test_data_structure_comparison(self, redis_test_config, comprehensive_test_d } # Verify all structures work - for structure, data in results.items(): + for _structure, data in results.items(): assert data['success'] == True assert data['rows_loaded'] == 1000 diff --git a/tests/integration/loaders/backends/test_snowflake.py b/tests/integration/loaders/backends/test_snowflake.py index 6f42e3f..6ef02a8 100644 --- a/tests/integration/loaders/backends/test_snowflake.py +++ b/tests/integration/loaders/backends/test_snowflake.py @@ -53,7 +53,7 @@ def query_rows( 
cur.execute(query) columns = [col[0] for col in cur.description] rows = cur.fetchall() - return [dict(zip(columns, row)) for row in rows] + return [dict(zip(columns, row, strict=False)) for row in rows] def cleanup_table(self, loader: SnowflakeLoader, table_name: str) -> None: """Drop Snowflake table""" @@ -87,7 +87,6 @@ class TestSnowflakeSpecific: def test_stage_loading_method(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): """Test Snowflake stage-based loading (Snowflake-specific optimization)""" - import pyarrow as pa cleanup_tables.append(test_table_name) diff --git a/tests/integration/loaders/conftest.py b/tests/integration/loaders/conftest.py index 7e078ad..26ce3a1 100644 --- a/tests/integration/loaders/conftest.py +++ b/tests/integration/loaders/conftest.py @@ -141,11 +141,7 @@ def loader(self, request): loader_config = loader_config.copy() # Only enable state if not explicitly configured if 'state' not in loader_config: - loader_config['state'] = { - 'enabled': True, - 'storage': 'memory', - 'store_batch_id': True - } + loader_config['state'] = {'enabled': True, 'storage': 'memory', 'store_batch_id': True} # Create and return the loader instance return self.config.loader_class(loader_config) diff --git a/tests/integration/loaders/test_base_streaming.py b/tests/integration/loaders/test_base_streaming.py index 76768f8..2ceb4d8 100644 --- a/tests/integration/loaders/test_base_streaming.py +++ b/tests/integration/loaders/test_base_streaming.py @@ -138,7 +138,9 @@ def test_reorg_deletion(self, loader, test_table_name, cleanup_tables): reorg_response = ResponseBatch.reorg_batch( invalidation_ranges=[BlockRange(network='ethereum', start=104, end=108)] ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name, connection_name='test_conn')) + reorg_results = list( + loader.load_stream_continuous(iter([reorg_response]), test_table_name, connection_name='test_conn') + ) assert len(reorg_results) == 1 assert reorg_results[0].success @@ -174,7 +176,9 @@ def test_reorg_overlapping_ranges(self, loader, test_table_name, cleanup_tables) ) # Load via streaming API (with connection_name for state tracking) - results = list(loader.load_stream_continuous(iter([response]), test_table_name, connection_name='test_conn')) + results = list( + loader.load_stream_continuous(iter([response]), test_table_name, connection_name='test_conn') + ) assert len(results) == 1 assert results[0].success @@ -187,7 +191,9 @@ def test_reorg_overlapping_ranges(self, loader, test_table_name, cleanup_tables) reorg_response = ResponseBatch.reorg_batch( invalidation_ranges=[BlockRange(network='ethereum', start=160, end=180)] ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name, connection_name='test_conn')) + reorg_results = list( + loader.load_stream_continuous(iter([reorg_response]), test_table_name, connection_name='test_conn') + ) assert len(reorg_results) == 1 assert reorg_results[0].success @@ -248,7 +254,9 @@ def test_reorg_multi_network(self, loader, test_table_name, cleanup_tables): reorg_response = ResponseBatch.reorg_batch( invalidation_ranges=[BlockRange(network='ethereum', start=100, end=100)] ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name, connection_name='test_conn')) + reorg_results = list( + loader.load_stream_continuous(iter([reorg_response]), test_table_name, connection_name='test_conn') + ) assert len(reorg_results) == 1 assert reorg_results[0].success
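
Editor's note: the Snowpipe Streaming tests above call a `wait_for_snowpipe_data` helper that is defined outside this excerpt. A minimal polling sketch is shown below, assuming the helper's name and signature match the call sites in the tests (`wait_for_snowpipe_data(loader, table_name, expected_rows, max_wait=...)`) and that the cursor returns dict-style rows keyed by `'COUNT(*)'` as elsewhere in these tests; the poll interval and default timeout are illustrative assumptions, not the project's actual implementation.

```python
import time


def wait_for_snowpipe_data(
    loader, table_name: str, expected_rows: int, max_wait: int = 30, poll_interval: float = 2.0
) -> int:
    """Poll a table until Snowpipe Streaming data becomes queryable or the timeout expires.

    Returns the last observed row count so callers can assert on it directly.
    """
    deadline = time.time() + max_wait
    count = 0
    while time.time() < deadline:
        loader.cursor.execute(f'SELECT COUNT(*) FROM {table_name}')
        count = loader.cursor.fetchone()['COUNT(*)']
        if count >= expected_rows:
            break
        # Data is committed asynchronously by Snowpipe Streaming; back off and retry.
        time.sleep(poll_interval)
    return count
```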