diff --git a/atompack-py/python/atompack/__init__.py b/atompack-py/python/atompack/__init__.py index 34b97bb..3e83c7c 100644 --- a/atompack-py/python/atompack/__init__.py +++ b/atompack-py/python/atompack/__init__.py @@ -37,12 +37,35 @@ `Database.open(path, mmap=False)` if you want to append molecules. """ +from typing import Any, Iterator + from . import hub from ._atompack_rs import PyAtom as Atom from ._atompack_rs import PyAtomDatabase as Database from ._atompack_rs import PyMolecule as Molecule from .ase_bridge import add_ase_batch, from_ase, to_ase, to_ase_batch + +def _database_iter_batches( + database: Database, + batch_size: int, + *, + flat: bool = False, + drop_last: bool = False, +) -> Iterator[list[Molecule] | dict[str, Any]]: + if batch_size <= 0: + raise ValueError("batch_size must be a positive integer") + + getter = database.get_molecules_flat if flat else database.get_molecules + for start in range(0, len(database), batch_size): + stop = min(start + batch_size, len(database)) + if drop_last and stop - start < batch_size: + break + yield getter(list(range(start, stop))) + + +Database.iter_batches = _database_iter_batches + __version__ = "0.2.1" __all__ = [ "Atom", diff --git a/atompack-py/python/atompack/__init__.pyi b/atompack-py/python/atompack/__init__.pyi index a28c97e..f04243d 100644 --- a/atompack-py/python/atompack/__init__.pyi +++ b/atompack-py/python/atompack/__init__.pyi @@ -1,6 +1,6 @@ """Type stubs for atompack""" -from typing import Any, Sequence, overload +from typing import Any, Iterator, Sequence, overload import numpy as np import numpy.typing as npt @@ -569,6 +569,20 @@ class Database: ``properties`` and ``atom_properties`` dictionaries when present. """ ... + def iter_batches( + self, + batch_size: int, + *, + flat: bool = False, + drop_last: bool = False, + ) -> Iterator[list[Molecule] | dict[str, Any]]: + """ + Yield contiguous batches from the database. + + Set ``flat=True`` to yield ``get_molecules_flat`` payloads instead of + materialized Molecule objects. + """ + ... def to_ase_batch( self, indices: list[int] | None = None, diff --git a/atompack-py/tests/test_database.py b/atompack-py/tests/test_database.py index 4d83dfd..d022035 100644 --- a/atompack-py/tests/test_database.py +++ b/atompack-py/tests/test_database.py @@ -324,6 +324,37 @@ def test_database_add_arrays_batch_promotes_to_float64_geometry_when_needed( assert flat["forces"].dtype == np.float64 +def test_database_iter_batches_supports_object_and_flat_batches(tmp_path: Path) -> None: + path = tmp_path / "iter_batches.atp" + db = atompack.Database(str(path)) + db.add_molecules( + [ + _make_molecule(-1.0), + _make_molecule(-2.0), + _make_molecule(-3.0), + _make_molecule(-4.0), + _make_molecule(-5.0), + ] + ) + db.flush() + + reopened = atompack.Database.open(str(path)) + object_batches = list(reopened.iter_batches(2)) + assert [[m.energy for m in batch] for batch in object_batches] == [ + [-1.0, -2.0], + [-3.0, -4.0], + [-5.0], + ] + + flat_batches = list(reopened.iter_batches(2, flat=True, drop_last=True)) + assert len(flat_batches) == 2 + np.testing.assert_allclose(flat_batches[0]["energy"], np.array([-1.0, -2.0])) + np.testing.assert_allclose(flat_batches[1]["energy"], np.array([-3.0, -4.0])) + + with pytest.raises(ValueError, match="positive"): + list(reopened.iter_batches(0)) + + @pytest.mark.parametrize("mmap", [False, True]) @pytest.mark.parametrize("compression", ["none", "lz4", "zstd"]) def test_database_single_item_reads_are_view_compatible( diff --git a/atompack-py/tests/test_stub_surface.py b/atompack-py/tests/test_stub_surface.py index b883fcc..6cf572e 100644 --- a/atompack-py/tests/test_stub_surface.py +++ b/atompack-py/tests/test_stub_surface.py @@ -72,7 +72,7 @@ def test_private_stub_tracks_low_level_surface() -> None: def test_public_stub_exposes_flat_batch_reader() -> None: database_methods = _class_method_names(PUBLIC_STUB, "Database") - assert "get_molecules_flat" in database_methods + assert {"get_molecules_flat", "iter_batches"} <= database_methods def test_hub_stub_has_public_docstrings() -> None: