Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion atompack-py/python/atompack/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,18 @@ def __len__(self) -> int:
self._ensure_open()
return self._total_length

def __getitem__(self, index: int) -> Molecule:
def __getitem__(self, index: int | slice) -> Molecule | list[Molecule]:
if isinstance(index, slice):
self._ensure_open()
start, stop, step = index.indices(self._total_length)
return self.get_molecules(list(range(start, stop, step)))
return self.get_molecule(index)

def __iter__(self):
self._ensure_open()
for index in range(self._total_length):
yield self.get_molecule(index)

def get_molecule(self, index: int) -> Molecule:
db_index, local_index = self._locate(index)
return self._databases[db_index][local_index]
Expand Down
8 changes: 7 additions & 1 deletion atompack-py/python/atompack/hub.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ from __future__ import annotations

from pathlib import Path
from types import TracebackType
from typing import Any, Sequence
from typing import Any, Iterator, Sequence, overload

from . import Molecule

Expand All @@ -26,6 +26,7 @@ class AtompackReader:
def __len__(self) -> int:
"""Return the total number of molecules across all opened files."""
...
@overload
def __getitem__(self, index: int) -> Molecule:
"""
Fetch one molecule by index.
Expand All @@ -34,6 +35,8 @@ class AtompackReader:
dataset, not within a single shard.
"""
...
@overload
def __getitem__(self, index: slice) -> list[Molecule]: ...
def get_molecule(self, index: int) -> Molecule:
"""
Fetch one molecule by global index across the underlying shard set.
Expand Down Expand Up @@ -75,6 +78,9 @@ class AtompackReader:
def close(self) -> None:
"""Close the underlying databases and invalidate the reader."""
...
def __iter__(self) -> Iterator[Molecule]:
"""Iterate over molecules in logical reader order."""
...

def download(
repo_id: str,
Expand Down
15 changes: 15 additions & 0 deletions atompack-py/tests/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,21 @@ def test_open_path_directory_flattens_lexicographically(tmp_path: Path) -> None:
assert [reader[i].energy for i in range(len(reader))] == pytest.approx([-1.0, -2.0, -3.0])


def test_reader_supports_iteration_and_slices(tmp_path: Path) -> None:
shard_dir = tmp_path / "shards"
shard_dir.mkdir()
_make_db(shard_dir / "a.atp", [-1.0, -2.0])
_make_db(shard_dir / "b.atp", [-3.0, -4.0])

reader = atompack.hub.open_path(shard_dir)

assert [molecule.energy for molecule in reader] == pytest.approx([-1.0, -2.0, -3.0, -4.0])
assert [molecule.energy for molecule in reader[1:4:2]] == pytest.approx([-2.0, -4.0])
assert [molecule.energy for molecule in reader[::-1]] == pytest.approx(
[-4.0, -3.0, -2.0, -1.0]
)


def test_open_path_context_manager_closes_reader(tmp_path: Path) -> None:
source = tmp_path / "single.atp"
_make_db(source, [-1.0])
Expand Down
3 changes: 3 additions & 0 deletions atompack-py/tests/test_stub_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def test_public_stub_exposes_flat_batch_reader() -> None:


def test_hub_stub_has_public_docstrings() -> None:
reader_methods = _class_method_names(HUB_STUB, "AtompackReader")
assert {"__getitem__", "__iter__"} <= reader_methods

reader_doc = _class_docstring(HUB_STUB, "AtompackReader") or ""
assert "lexicographically ordered shard set" in reader_doc

Expand Down
Loading