Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions atompack-py/python/atompack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,35 @@
`Database.open(path, mmap=False)` if you want to append molecules.
"""

from typing import Any, Iterator

from . import hub
from ._atompack_rs import PyAtom as Atom
from ._atompack_rs import PyAtomDatabase as Database
from ._atompack_rs import PyMolecule as Molecule
from .ase_bridge import add_ase_batch, from_ase, to_ase, to_ase_batch


def _database_iter_batches(
database: Database,
batch_size: int,
*,
flat: bool = False,
drop_last: bool = False,
) -> Iterator[list[Molecule] | dict[str, Any]]:
if batch_size <= 0:
raise ValueError("batch_size must be a positive integer")

getter = database.get_molecules_flat if flat else database.get_molecules
for start in range(0, len(database), batch_size):
stop = min(start + batch_size, len(database))
if drop_last and stop - start < batch_size:
break
yield getter(list(range(start, stop)))


# Attach the pure-Python helper to the extension-backed ``Database`` class so
# it can be called as ``database.iter_batches(batch_size, ...)``.
Database.iter_batches = _database_iter_batches

# Package version string.
__version__ = "0.2.1"
__all__ = [
"Atom",
Expand Down
16 changes: 15 additions & 1 deletion atompack-py/python/atompack/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Type stubs for atompack"""

from typing import Any, Sequence, overload
from typing import Any, Iterator, Sequence, overload

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -569,6 +569,20 @@ class Database:
``properties`` and ``atom_properties`` dictionaries when present.
"""
...
def iter_batches(
    self,
    batch_size: int,
    *,
    flat: bool = False,
    drop_last: bool = False,
) -> Iterator[list[Molecule] | dict[str, Any]]:
    """
    Yield contiguous batches from the database.

    Batches cover indices in ascending order, ``batch_size`` entries at a
    time. ``batch_size`` must be a positive integer; otherwise a
    ``ValueError`` is raised.

    Set ``flat=True`` to yield ``get_molecules_flat`` payloads instead of
    materialized Molecule objects.

    Set ``drop_last=True`` to skip a final batch that contains fewer than
    ``batch_size`` entries.
    """
    ...
def to_ase_batch(
self,
indices: list[int] | None = None,
Expand Down
31 changes: 31 additions & 0 deletions atompack-py/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,37 @@ def test_database_add_arrays_batch_promotes_to_float64_geometry_when_needed(
assert flat["forces"].dtype == np.float64


def test_database_iter_batches_supports_object_and_flat_batches(tmp_path: Path) -> None:
    """Exercise ``iter_batches`` in object mode, in flat mode, and with a bad size."""
    db_file = str(tmp_path / "iter_batches.atp")
    energies = [-1.0, -2.0, -3.0, -4.0, -5.0]

    # Write five molecules, then reopen the file for reading.
    writer = atompack.Database(db_file)
    writer.add_molecules([_make_molecule(energy) for energy in energies])
    writer.flush()
    reader = atompack.Database.open(db_file)

    # Object mode: the short trailing batch is kept by default.
    batched_energies = [
        [molecule.energy for molecule in batch] for batch in reader.iter_batches(2)
    ]
    assert batched_energies == [[-1.0, -2.0], [-3.0, -4.0], [-5.0]]

    # Flat mode with drop_last: the trailing single-molecule batch is discarded.
    flat_batches = list(reader.iter_batches(2, flat=True, drop_last=True))
    expected_flat = [np.array([-1.0, -2.0]), np.array([-3.0, -4.0])]
    assert len(flat_batches) == 2
    for payload, expected in zip(flat_batches, expected_flat):
        np.testing.assert_allclose(payload["energy"], expected)

    # A non-positive batch size is rejected.
    with pytest.raises(ValueError, match="positive"):
        list(reader.iter_batches(0))


@pytest.mark.parametrize("mmap", [False, True])
@pytest.mark.parametrize("compression", ["none", "lz4", "zstd"])
def test_database_single_item_reads_are_view_compatible(
Expand Down
2 changes: 1 addition & 1 deletion atompack-py/tests/test_stub_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_private_stub_tracks_low_level_surface() -> None:

def test_public_stub_exposes_flat_batch_reader() -> None:
    """Check that the public stub's ``Database`` class declares the batch readers."""
    database_methods = _class_method_names(PUBLIC_STUB, "Database")
    # One subset check covers both names; a separate membership assert for
    # "get_molecules_flat" would be redundant.
    assert {"get_molecules_flat", "iter_batches"} <= database_methods


def test_hub_stub_has_public_docstrings() -> None:
Expand Down
Loading