Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions atompack-py/python/atompack/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ class Database:
Parameters
----------
index : int
Molecule index (0-based)
Molecule index (0-based). Negative indices count from the end, so -1 is the last molecule.

Returns
-------
Expand All @@ -553,7 +553,7 @@ class Database:
Parameters
----------
indices : list of int
Molecule indices (0-based)
Molecule indices (0-based). Negative indices count from the end, so -1 is the last molecule.

Returns
-------
Expand Down Expand Up @@ -612,7 +612,7 @@ class Database:
Parameters
----------
index : int
Molecule index (0-based)
Molecule index (0-based). Negative indices count from the end, so -1 is the last molecule.

Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions atompack-py/python/atompack/_atompack_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ class PyAtomDatabase:
Parameters
----------
index : int
Molecule index (0-based)
Molecule index (0-based). Negative indices count from the end, so -1 is the last molecule.

Returns
-------
Expand All @@ -491,7 +491,7 @@ class PyAtomDatabase:
Parameters
----------
indices : sequence of int
Molecule indices (0-based)
Molecule indices (0-based). Negative indices count from the end, so -1 is the last molecule.

Returns
-------
Expand Down
47 changes: 35 additions & 12 deletions atompack-py/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,31 @@ impl PyAtomDatabase {
SoaMoleculeView::from_owned_bytes(decompressed, ctx)
}

/// Resolve a Python-style index into an in-bounds `usize` position.
///
/// Non-negative values are used as-is; negative values count back from the
/// end of the database, so `-1` refers to the last entry.
///
/// # Errors
///
/// Returns `PyIndexError` when the resolved position falls outside
/// `0..len`, where `len` is `self.inner.len()`.
fn normalize_index(&self, index: isize) -> PyResult<usize> {
    let len = self.inner.len();
    // Stay in `usize` throughout: `unsigned_abs` gives the exact magnitude of
    // a negative index (including `isize::MIN`) and `checked_sub` rejects
    // anything reaching past the front, so `len` never has to go through a
    // lossy `as isize` cast.
    let resolved = if index < 0 {
        len.checked_sub(index.unsigned_abs())
    } else {
        Some(index.unsigned_abs()).filter(|&position| position < len)
    };
    resolved.ok_or_else(|| {
        PyIndexError::new_err(format!(
            "Index {} out of bounds for database of length {}",
            index, len
        ))
    })
}

/// Resolve every entry of `indices` through `normalize_index`, keeping order.
///
/// Stops at the first entry that fails to resolve and returns its error.
fn normalize_indices(&self, indices: Vec<isize>) -> PyResult<Vec<usize>> {
    let mut resolved = Vec::with_capacity(indices.len());
    for raw in indices {
        resolved.push(self.normalize_index(raw)?);
    }
    Ok(resolved)
}

fn single_molecule_view(&self, py: Python<'_>, index: usize) -> PyResult<SoaMoleculeView> {
let compression = self.inner.compression();
let ctx = self.soa_context()?;
Expand Down Expand Up @@ -273,15 +298,11 @@ impl PyAtomDatabase {
}

/// Get a molecule by index as a lazy view-backed molecule.
fn get_molecule(&self, py: Python<'_>, index: usize) -> PyResult<PyMolecule> {
let len = self.inner.len();
if index >= len {
return Err(PyIndexError::new_err(format!(
"Index {} out of bounds for database of length {}",
index, len
)));
}
Ok(PyMolecule::from_view(self.single_molecule_view(py, index)?))
fn get_molecule(&self, py: Python<'_>, index: isize) -> PyResult<PyMolecule> {
let normalized = self.normalize_index(index)?;
Ok(PyMolecule::from_view(
self.single_molecule_view(py, normalized)?,
))
}

/// Get multiple molecules by indices (parallel batch reading)
Expand All @@ -294,10 +315,11 @@ impl PyAtomDatabase {
///
/// Returns:
/// - List of molecules
fn get_molecules(&self, py: Python<'_>, indices: Vec<usize>) -> PyResult<Vec<PyMolecule>> {
fn get_molecules(&self, py: Python<'_>, indices: Vec<isize>) -> PyResult<Vec<PyMolecule>> {
if indices.is_empty() {
return Ok(Vec::new());
}
let indices = self.normalize_indices(indices)?;
let views = self.molecule_views(py, indices)?;
Ok(views.into_iter().map(PyMolecule::from_view).collect())
}
Expand Down Expand Up @@ -327,8 +349,9 @@ impl PyAtomDatabase {
fn get_molecules_flat<'py>(
&self,
py: Python<'py>,
indices: Vec<usize>,
indices: Vec<isize>,
) -> PyResult<Bound<'py, PyDict>> {
let indices = self.normalize_indices(indices)?;
flat::get_molecules_flat_soa_impl(&self.inner, py, indices)
}

Expand All @@ -338,7 +361,7 @@ impl PyAtomDatabase {
}

/// Enable indexing: db[i]
fn __getitem__(&self, py: Python<'_>, index: usize) -> PyResult<PyMolecule> {
fn __getitem__(&self, py: Python<'_>, index: isize) -> PyResult<PyMolecule> {
self.get_molecule(py, index)
}

Expand Down
32 changes: 25 additions & 7 deletions atompack-py/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,28 @@ def test_database_add_arrays_batch_promotes_to_float64_geometry_when_needed(
assert flat["forces"].dtype == np.float64


def test_database_negative_indices_work_across_read_apis(tmp_path: Path) -> None:
    """Negative indices resolve from the end in each read API exercised here."""
    db_file = str(tmp_path / "negative_indices.atp")
    writer = atompack.Database(db_file)
    writer.add_molecules([_make_molecule(energy) for energy in (-1.0, -2.0, -3.0)])
    writer.flush()

    reader = atompack.Database.open(db_file)

    # Single-item access, through both __getitem__ and get_molecule.
    assert reader[-1].energy == pytest.approx(-3.0)
    assert reader.get_molecule(-2).energy == pytest.approx(-2.0)

    # Batch access with mixed-sign indices.
    batch_energies = [mol.energy for mol in reader.get_molecules([-1, 0, -3])]
    assert batch_energies == pytest.approx([-3.0, -1.0, -1.0])

    # Flat batch access.
    flat_energy = reader.get_molecules_flat([-1, -2])["energy"]
    np.testing.assert_allclose(flat_energy, np.array([-3.0, -2.0], dtype=np.float64))

    # One step past the front of a three-entry database must be rejected.
    with pytest.raises(IndexError, match="out of bounds"):
        reader.get_molecule(-4)


@pytest.mark.parametrize("mmap", [False, True])
@pytest.mark.parametrize("compression", ["none", "lz4", "zstd"])
def test_database_single_item_reads_are_view_compatible(
Expand Down Expand Up @@ -735,19 +757,15 @@ def test_database_open_mmap_populate(tmp_path: Path) -> None:
assert db_r[0].energy == pytest.approx(-3.0)


def test_database_negative_indexing_raises_overflow_error(tmp_path: Path) -> None:
# Database does not support negative indexing today. PyO3 extracts the
# index argument as `usize`, so a negative integer raises OverflowError
# at the FFI boundary. If wraparound semantics are ever added, this
# test will fail loudly so the intent is explicit.
def test_database_negative_indexing_out_of_bounds_raises_index_error(tmp_path: Path) -> None:
path = tmp_path / "negidx.atp"
db = atompack.Database(str(path))
db.add_molecule(_make_molecule(-1.0))
db.flush()

db_r = atompack.Database.open(str(path))
with pytest.raises(OverflowError, match=r"negative"):
_ = db_r[-1]
with pytest.raises(IndexError, match=r"out of bounds"):
_ = db_r[-2]


def test_database_empty_molecule_roundtrip(tmp_path: Path) -> None:
Expand Down
Loading