From df12a6295092ef7e346646a76b8fea3b5f2a3c67 Mon Sep 17 00:00:00 2001 From: Liang Hong Date: Thu, 19 Feb 2026 05:55:52 +0000 Subject: [PATCH 1/2] Add merged structure IO and mmCIF parity - add Python split compression and merged fragment database reads - expose source fragment indices for merged entries - support format-selectable decompression in Python and CLI (pdb|mmcif|cif) - add shared mmCIF atom writer in C++ output path - harden tar/db output path handling with parent-directory checks - expand tests and docs using existing multichain fixture --- README.md | 28 ++- foldcomp/foldcomp.cxx | 485 +++++++++++++++++++++++++++++++++++++--- src/atom_coordinate.cpp | 136 +++++++++++ src/atom_coordinate.h | 6 + src/main.cpp | 153 ++++++++++--- test/test_foldcomp.py | 115 ++++++++++ 6 files changed, 851 insertions(+), 72 deletions(-) diff --git a/README.md b/README.md index 200064a..279df01 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Foldcomp compresses protein structures with torsion angles effectively. It compr Foldcomp efficient compressed format stores protein structures requiring only 13 bytes per residue, which reduces the required storage space by an order of magnitude compared to saving 3D coordinates directly. We achieve this reduction by encoding the torsion angles of the backbone as well as the side-chain angles in a compact binary file format (FCZ). -> Foldcomp currently only supports compression of single chain PDB files +> Foldcomp's compression core handles a single chain per FCZ chunk. Multi-chain or discontinuous inputs are split into multiple FCZ entries.

@@ -57,7 +57,7 @@ foldcomp compress [] foldcomp compress [-t number] [] # Decompression -foldcomp decompress [] +foldcomp decompress [] foldcomp decompress [-t number] [] # Decompressing a subset of Foldcomp database @@ -94,6 +94,7 @@ foldcomp rmsd --fasta, --amino-acid extract amino acid sequence (only for extraction mode) --no-merge do not merge output files (only for extraction mode) --use-title use TITLE as the output file name (only for extraction mode) + --output-format output format for decompression: pdb|mmcif|cif [default=pdb] --time measure time for compression/decompression ``` @@ -137,7 +138,8 @@ with open("test/compressed.fcz", "rb") as fcz: fcz_binary = fcz.read() # Decompress - (name, pdb) = foldcomp.decompress(fcz_binary) # pdb_out[0]: file name, pdb_out[1]: pdb binary string + (name, pdb) = foldcomp.decompress(fcz_binary) # tuple[str, str] where second is decompressed structure text (PDB) + (name, cif) = foldcomp.decompress(fcz_binary, format="mmcif") # mmCIF text output # Save to a pdb file with open(name, "w") as pdb_file: @@ -163,8 +165,27 @@ with foldcomp.open("test/example_db", ids=ids) as db: # save entries as seperate pdb files with open(name + ".pdb", "w") as pdb_file: pdb_file.write(pdb) + +# 03. Multi-chain/discontinuous handling in Python +with open("test/multichain.pdb", "r") as f: + pdb_in = f.read() + +# Split into chain/fragments (CLI-compatible naming) and compress each chunk +chunks = foldcomp.compress("multichain.pdb", pdb_in, split=True) # list[(chunk_name, fcz_bytes)] + +# If a Foldcomp DB stores these chunks, you can reconstruct one whole structure text per lookup group +with foldcomp.open("test/example_db", merge_fragments=True, format="mmcif") as db: + name, mmcif_text = db[0] +``` + +### CLI format example +```sh +# Write mmCIF directly during decompression +foldcomp decompress --output-format mmcif input.fcz output.cif ``` +For tar/database outputs, parent directories in the output path are created automatically. 
+ ## Subsetting Databases If you are dealing with millions of entries, we recommend using `createsubdb` command of [mmseqs2](https://mmseqs.com) to subset databases. @@ -182,4 +203,3 @@ Please note that the IDs in afdb_uniprot_v4 are in the format `AF-A0A5S3Y9Q7-F1- - diff --git a/foldcomp/foldcomp.cxx b/foldcomp/foldcomp.cxx index 8e37ffe..c008ae9 100644 --- a/foldcomp/foldcomp.cxx +++ b/foldcomp/foldcomp.cxx @@ -3,29 +3,46 @@ #include #include +#include #include +#include #include #include #include #include // IWYU pragma: keep +#include #include "atom_coordinate.h" #include "foldcomp.h" #include "database_reader.h" +#include "utility.h" static PyObject *FoldcompError; +enum class OutputFormat { + PDB, + MMCIF +}; + typedef struct { PyObject_HEAD std::vector* user_indices; + std::vector* merged_names; + std::vector>* merged_source_indices; + bool merge_fragments; bool decompress; + OutputFormat output_format; void* memory_handle; } FoldcompDatabaseObject; -int decompress(const char* input, size_t input_size, bool use_alt_order, std::ostream& oss, std::string& name); +int decompressToAtoms(const char* input, size_t input_size, bool use_alt_order, std::vector& atomCoordinates, std::string& name); +int decompress(const char* input, size_t input_size, bool use_alt_order, OutputFormat format, std::ostream& oss, std::string& name); +int parsePDB(const std::string& pdb_input, std::vector& atomCoordinates); static PyObject* FoldcompDatabase_close(PyObject* self); static PyObject* FoldcompDatabase_enter(PyObject* self); static PyObject* FoldcompDatabase_exit(PyObject* self, PyObject* args); +static PyObject* FoldcompDatabase_source_indices(PyObject* self, PyObject* args); +void FoldcompDatabase_release_resources(FoldcompDatabaseObject* db); PyObject* vectorToList_Int64(const std::vector& data); #pragma GCC diagnostic push @@ -34,6 +51,7 @@ PyObject* vectorToList_Int64(const std::vector& data); #pragma GCC diagnostic ignored "-Wcast-function-type" static PyMethodDef 
FoldcompDatabase_methods[] = { {"close", (PyCFunction)FoldcompDatabase_close, METH_NOARGS, "Close the database."}, + {"source_indices", (PyCFunction)FoldcompDatabase_source_indices, METH_VARARGS, "Get source fragment indices for a merged entry."}, {"__enter__", (PyCFunction)FoldcompDatabase_enter, METH_NOARGS, "Enter the runtime context related to this object."}, {"__exit__", (PyCFunction)FoldcompDatabase_exit, METH_VARARGS, "Exit the runtime context related to this object."}, {NULL, NULL, 0, NULL} /* Sentinel */ @@ -43,19 +61,138 @@ static PyMethodDef FoldcompDatabase_methods[] = { // FoldcompDatabase_sq_length static Py_ssize_t FoldcompDatabase_sq_length(PyObject* self) { FoldcompDatabaseObject* db = (FoldcompDatabaseObject*)self; + if (db->merge_fragments && db->merged_names != NULL) { + return (Py_ssize_t)db->merged_names->size(); + } if (db->user_indices != NULL) { return db->user_indices->size(); } return (Py_ssize_t)reader_get_size(db->memory_handle); } +static void appendPDBFragment(std::ostringstream& out, const std::string& pdb, bool include_title) { + std::istringstream iss(pdb); + std::string line; + while (std::getline(iss, line)) { + if (!include_title && stringStartsWith("TITLE", line)) { + continue; + } + out << line << "\n"; + } +} + +static std::string toLower(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { + return (char)std::tolower(c); + }); + return value; +} + +static bool parseOutputFormat(const char* format, OutputFormat& output_format) { + if (format == NULL) { + output_format = OutputFormat::PDB; + return true; + } + std::string format_lower = toLower(std::string(format)); + if (format_lower == "pdb") { + output_format = OutputFormat::PDB; + return true; + } + if (format_lower == "mmcif" || format_lower == "cif") { + output_format = OutputFormat::MMCIF; + return true; + } + return false; +} + +static void writeAtomsByFormat(std::vector& atomCoordinates, const std::string& name, 
OutputFormat output_format, std::ostream& oss) { + if (output_format == OutputFormat::MMCIF) { + writeAtomCoordinatesToMMCIF(atomCoordinates, name, oss); + } else { + writeAtomCoordinatesToPDB(atomCoordinates, name, oss); + } +} + // FoldcompDatabase_sq_item static PyObject* FoldcompDatabase_sq_item(PyObject* self, Py_ssize_t index) { FoldcompDatabaseObject* db = (FoldcompDatabaseObject*)self; - const char* data; - size_t length; - int64_t id; + if (db->merge_fragments) { + if (!db->decompress) { + PyErr_SetString(PyExc_TypeError, "merge_fragments requires decompress=True"); + return NULL; + } + if (db->merged_names == NULL || db->merged_source_indices == NULL) { + PyErr_SetString(PyExc_RuntimeError, "merged database state is not initialized"); + return NULL; + } + if (index < 0 || index >= (Py_ssize_t)db->merged_names->size()) { + PyErr_SetString(PyExc_IndexError, "index out of range"); + return NULL; + } + + const std::string& merged_name = db->merged_names->at(index); + const std::vector& source_ids = db->merged_source_indices->at(index); + if (db->output_format == OutputFormat::PDB) { + std::ostringstream merged_oss; + bool include_title = true; + for (int64_t source_id : source_ids) { + const char* data = reader_get_data(db->memory_handle, source_id); + size_t length = std::max(reader_get_length(db->memory_handle, source_id), (int64_t)1) - (int64_t)1; + if (data == NULL) { + PyErr_SetString(FoldcompError, "Failed to read source fragment from database."); + return NULL; + } + std::ostringstream fragment_oss; + std::string fragment_name; + int err = decompress(data, length, false, OutputFormat::PDB, fragment_oss, fragment_name); + if (err != 0) { + std::string err_msg = "Error decompressing: " + fragment_name; + PyErr_SetString(FoldcompError, err_msg.c_str()); + return NULL; + } + std::string fragment_pdb = fragment_oss.str(); + appendPDBFragment(merged_oss, fragment_pdb, include_title); + include_title = false; + } + std::string merged_pdb = merged_oss.str(); + 
PyObject* pdb = PyUnicode_FromKindAndData(PyUnicode_1BYTE_KIND, merged_pdb.c_str(), merged_pdb.size()); + PyObject* result = Py_BuildValue("(s,O)", merged_name.c_str(), pdb); + Py_DECREF(pdb); + return result; + } + + std::vector merged_atoms; + for (int64_t source_id : source_ids) { + const char* data = reader_get_data(db->memory_handle, source_id); + size_t length = std::max(reader_get_length(db->memory_handle, source_id), (int64_t)1) - (int64_t)1; + if (data == NULL) { + PyErr_SetString(FoldcompError, "Failed to read source fragment from database."); + return NULL; + } + std::vector fragment_atoms; + std::string fragment_name; + int err = decompressToAtoms(data, length, false, fragment_atoms, fragment_name); + if (err != 0) { + std::string err_msg = "Error decompressing: " + fragment_name; + PyErr_SetString(FoldcompError, err_msg.c_str()); + return NULL; + } + merged_atoms.insert(merged_atoms.end(), fragment_atoms.begin(), fragment_atoms.end()); + } + + std::ostringstream merged_oss; + writeAtomsByFormat(merged_atoms, merged_name, db->output_format, merged_oss); + std::string merged_text = merged_oss.str(); + PyObject* structure = PyUnicode_FromKindAndData(PyUnicode_1BYTE_KIND, merged_text.c_str(), merged_text.size()); + PyObject* result = Py_BuildValue("(s,O)", merged_name.c_str(), structure); + Py_DECREF(structure); + return result; + } + + const char* data = NULL; + size_t length = 0; + int64_t id = -1; if (db->user_indices != NULL) { if (index >= (Py_ssize_t)db->user_indices->size()) { PyErr_SetString(PyExc_IndexError, "index out of range"); @@ -75,13 +212,14 @@ static PyObject* FoldcompDatabase_sq_item(PyObject* self, Py_ssize_t index) { if (db->decompress) { std::ostringstream oss; std::string name; - int err = decompress(data, length, false, oss, name); + int err = decompress(data, length, false, db->output_format, oss, name); if (err != 0) { std::string err_msg = "Error decompressing: " + name; PyErr_SetString(FoldcompError, err_msg.c_str()); return 
NULL; } - PyObject* pdb = PyUnicode_FromKindAndData(PyUnicode_1BYTE_KIND, oss.str().c_str(), oss.str().size()); + std::string structure_text = oss.str(); + PyObject* pdb = PyUnicode_FromKindAndData(PyUnicode_1BYTE_KIND, structure_text.c_str(), structure_text.size()); PyObject* result = Py_BuildValue("(s,O)", name.c_str(), pdb); Py_DECREF(pdb); return result; @@ -166,10 +304,7 @@ static PyObject* FoldcompDatabase_close(PyObject* self) { return NULL; } FoldcompDatabaseObject* db = (FoldcompDatabaseObject*)self; - if (db->memory_handle != NULL) { - free_reader(db->memory_handle); - db->memory_handle = NULL; - } + FoldcompDatabase_release_resources(db); Py_RETURN_NONE; } @@ -184,6 +319,49 @@ static PyObject *FoldcompDatabase_exit(PyObject *self, PyObject* /* args */) { return FoldcompDatabase_close(self); } +static PyObject* FoldcompDatabase_source_indices(PyObject* self, PyObject* args) { + if (!PyObject_TypeCheck(self, &FoldcompDatabaseType)) { + PyErr_SetString(PyExc_TypeError, "Expected FoldcompDatabase object."); + return NULL; + } + FoldcompDatabaseObject* db = (FoldcompDatabaseObject*)self; + if (!db->merge_fragments || db->merged_source_indices == NULL) { + PyErr_SetString(PyExc_TypeError, "source_indices is available only when merge_fragments=True."); + return NULL; + } + Py_ssize_t index; + if (!PyArg_ParseTuple(args, "n", &index)) { + return NULL; + } + if (index < 0 || index >= (Py_ssize_t)db->merged_source_indices->size()) { + PyErr_SetString(PyExc_IndexError, "index out of range"); + return NULL; + } + return vectorToList_Int64(db->merged_source_indices->at(index)); +} + +void FoldcompDatabase_release_resources(FoldcompDatabaseObject* db) { + if (db == NULL) { + return; + } + if (db->memory_handle != NULL) { + free_reader(db->memory_handle); + db->memory_handle = NULL; + } + if (db->user_indices != NULL) { + delete db->user_indices; + db->user_indices = NULL; + } + if (db->merged_names != NULL) { + delete db->merged_names; + db->merged_names = NULL; + } + 
if (db->merged_source_indices != NULL) { + delete db->merged_source_indices; + db->merged_source_indices = NULL; + } +} + // https://stackoverflow.com/questions/1448467/initializing-a-c-stdistringstream-from-an-in-memory-buffer/1449527 struct OneShotReadBuf : public std::streambuf { @@ -194,48 +372,66 @@ struct OneShotReadBuf : public std::streambuf }; // Decompress -int decompress(const char* input, size_t input_size, bool use_alt_order, std::ostream& oss, std::string& name) { +int decompressToAtoms(const char* input, size_t input_size, bool use_alt_order, std::vector& atomCoordinates, std::string& name) { OneShotReadBuf buf((char*)input, input_size); std::istream istr(&buf); + std::ios_base::iostate cout_state = std::cout.rdstate(); std::cout.setstate(std::ios_base::failbit); Foldcomp compRes; int flag = compRes.read(istr); if (flag != 0) { + std::cout.clear(cout_state); return 1; } - std::vector atomCoordinates; compRes.useAltAtomOrder = use_alt_order; flag = compRes.decompress(atomCoordinates); if (flag != 0) { + std::cout.clear(cout_state); return 1; } - // Write decompressed data to file - writeAtomCoordinatesToPDB(atomCoordinates, compRes.strTitle, oss); - std::cout.clear(); + std::cout.clear(cout_state); name = compRes.strTitle; return 0; } + +int decompress(const char* input, size_t input_size, bool use_alt_order, OutputFormat format, std::ostream& oss, std::string& name) { + std::vector atomCoordinates; + int flag = decompressToAtoms(input, input_size, use_alt_order, atomCoordinates, name); + if (flag != 0) { + return flag; + } + writeAtomsByFormat(atomCoordinates, name, format, oss); + return 0; +} // Python binding for decompress -static PyObject *foldcomp_decompress(PyObject* /* self */, PyObject *args) { +static PyObject *foldcomp_decompress(PyObject* /* self */, PyObject *args, PyObject* kwargs) { // Unpack a string from the arguments const char *strArg; Py_ssize_t strSize; - if (!PyArg_ParseTuple(args, "y#", &strArg, &strSize)) { + const char* 
format = "pdb"; + static const char *kwlist[] = {"input", "format", NULL}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y#|$s", const_cast(kwlist), &strArg, &strSize, &format)) { + return NULL; + } + OutputFormat output_format; + if (!parseOutputFormat(format, output_format)) { + PyErr_SetString(PyExc_ValueError, "format must be one of: 'pdb', 'mmcif', 'cif'"); return NULL; } std::ostringstream oss; std::string name; - int err = decompress(strArg, strSize, false, oss, name); + int err = decompress(strArg, strSize, false, output_format, oss, name); if (err != 0) { PyErr_SetString(FoldcompError, "Error decompressing."); return NULL; } - return Py_BuildValue("(s,O)", name.c_str(), PyUnicode_FromKindAndData(PyUnicode_1BYTE_KIND, oss.str().c_str(), oss.str().size())); + std::string structure_text = oss.str(); + return Py_BuildValue("(s,O)", name.c_str(), PyUnicode_FromKindAndData(PyUnicode_1BYTE_KIND, structure_text.c_str(), structure_text.size())); } std::string trim(const std::string& str, const std::string& whitespace = " \t") { @@ -249,10 +445,8 @@ std::string trim(const std::string& str, const std::string& whitespace = " \t") return str.substr(strBegin, strRange); } -// Compress -int compress(const std::string& name, const std::string& pdb_input, std::ostream& oss, int anchor_residue_threshold) { - std::vector atomCoordinates; - // parse ATOM lines from PDB file into atomCoordinates +int parsePDB(const std::string& pdb_input, std::vector& atomCoordinates) { + atomCoordinates.clear(); std::istringstream iss(pdb_input); std::string line; std::string chain = ""; @@ -279,6 +473,16 @@ int compress(const std::string& name, const std::string& pdb_input, std::ostream if (atomCoordinates.size() == 0) { return 1; // FLAG 1: no ATOM lines } + return 0; +} + +// Compress +int compress(const std::string& name, const std::string& pdb_input, std::ostream& oss, int anchor_residue_threshold) { + std::vector atomCoordinates; + int parse_flag = parsePDB(pdb_input, 
atomCoordinates); + if (parse_flag != 0) { + return parse_flag; + } removeAlternativePosition(atomCoordinates); @@ -296,8 +500,9 @@ static PyObject *foldcomp_compress(PyObject* /* self */, PyObject *args, PyObjec const char* name; const char* pdb_input; PyObject* anchor_residue_threshold = NULL; - static const char *kwlist[] = {"name", "pdb_content", "anchor_residue_threshold", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ss|$O", const_cast(kwlist), &name, &pdb_input, &anchor_residue_threshold)) { + PyObject* split = NULL; + static const char *kwlist[] = {"name", "pdb_content", "anchor_residue_threshold", "split", NULL}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ss|$OO", const_cast(kwlist), &name, &pdb_input, &anchor_residue_threshold, &split)) { return NULL; } @@ -305,11 +510,131 @@ static PyObject *foldcomp_compress(PyObject* /* self */, PyObject *args, PyObjec PyErr_SetString(PyExc_TypeError, "anchor_residue_threshold must be an integer"); return NULL; } + if (split != NULL && !PyBool_Check(split)) { + PyErr_SetString(PyExc_TypeError, "split must be a boolean"); + return NULL; + } int threshold = DEFAULT_ANCHOR_THRESHOLD; if (anchor_residue_threshold != NULL) { threshold = PyLong_AsLong(anchor_residue_threshold); } + bool split_flag = split != NULL && PyObject_IsTrue(split); + + if (split_flag) { + std::string pdb_content(pdb_input); + std::vector> chains; + std::istringstream chain_stream(pdb_content); + std::string line; + std::string current_chain = ""; + std::string current_chain_chunk = ""; + while (std::getline(chain_stream, line)) { + if (!stringStartsWith("ATOM", line)) { + continue; + } + std::string chain = line.substr(21, 1); + if (current_chain.empty()) { + current_chain = chain; + } else if (chain != current_chain) { + if (!current_chain_chunk.empty()) { + chains.emplace_back(current_chain, current_chain_chunk); + } + current_chain = chain; + current_chain_chunk.clear(); + } + current_chain_chunk += line + "\n"; + } + if 
(!current_chain_chunk.empty()) { + chains.emplace_back(current_chain, current_chain_chunk); + } + if (chains.size() == 0) { + PyErr_SetString(FoldcompError, "No ATOM lines found"); + return NULL; + } + + std::string base = baseName(name); + std::pair output_parts = getFileParts(base); + std::string output_base = output_parts.first; + + PyObject* outputs = PyList_New(0); + if (outputs == NULL) { + return NULL; + } + + for (size_t i = 0; i < chains.size(); i++) { + std::vector fragments; + std::istringstream fragment_stream(chains[i].second); + std::string fragment; + int prev_n_res_idx = 0; + bool has_prev_n = false; + while (std::getline(fragment_stream, line)) { + if (!stringStartsWith("ATOM", line)) { + continue; + } + std::string atom = trim(line.substr(12, 4)); + if (atom == "N") { + int curr_res_idx = 0; + bool parsed = true; + try { + curr_res_idx = std::stoi(line.substr(22, 4)); + } catch (...) { + parsed = false; + } + if (parsed && has_prev_n && curr_res_idx - prev_n_res_idx > 1) { + if (!fragment.empty()) { + fragments.push_back(fragment); + fragment.clear(); + } + } + if (parsed) { + prev_n_res_idx = curr_res_idx; + has_prev_n = true; + } + } + fragment += line + "\n"; + } + if (!fragment.empty()) { + fragments.push_back(fragment); + } + + for (size_t j = 0; j < fragments.size(); j++) { + std::ostringstream oss; + int flag = compress(name, fragments[j], oss, threshold); + if (flag != 0) { + continue; + } + + std::string chunk_name = output_base; + if (chains.size() > 1) { + chunk_name += chains[i].first; + } + if (fragments.size() > 1) { + chunk_name += "_" + std::to_string(j); + } + chunk_name += ".fcz"; + + std::string payload = oss.str(); + PyObject* py_payload = PyBytes_FromStringAndSize(payload.c_str(), payload.size()); + if (py_payload == NULL) { + Py_DECREF(outputs); + return NULL; + } + PyObject* py_item = Py_BuildValue("(s,O)", chunk_name.c_str(), py_payload); + Py_DECREF(py_payload); + if (py_item == NULL) { + Py_DECREF(outputs); + return NULL; 
+ } + if (PyList_Append(outputs, py_item) != 0) { + Py_DECREF(py_item); + Py_DECREF(outputs); + return NULL; + } + Py_DECREF(py_item); + } + } + return outputs; + } std::ostringstream oss; int flag = compress(name, pdb_input, oss, threshold); @@ -324,7 +649,8 @@ static PyObject *foldcomp_compress(PyObject* /* self */, PyObject *args, PyObjec return NULL; } - return PyBytes_FromStringAndSize(oss.str().c_str(), oss.str().length()); + std::string compressed_bytes = oss.str(); + return PyBytes_FromStringAndSize(compressed_bytes.c_str(), compressed_bytes.length()); } @@ -335,9 +661,11 @@ static PyObject *foldcomp_open(PyObject* /* self */, PyObject* args, PyObject* k PyObject* user_ids = NULL; PyObject* decompress = NULL; PyObject* err_on_missing = NULL; // Raise an error if the file is missing. Default: False + PyObject* merge_fragments = NULL; + const char* format = "pdb"; - static const char *kwlist[] = {"path", "ids", "decompress", "err_on_missing", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&|$OOO", const_cast(kwlist), PyUnicode_FSConverter, &path, &user_ids, &decompress, &err_on_missing)) { + static const char *kwlist[] = {"path", "ids", "decompress", "err_on_missing", "merge_fragments", "format", NULL}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&|$OOOOs", const_cast(kwlist), PyUnicode_FSConverter, &path, &user_ids, &decompress, &err_on_missing, &merge_fragments, &format)) { return NULL; } if (path == NULL) { @@ -368,6 +696,17 @@ static PyObject *foldcomp_open(PyObject* /* self */, PyObject* args, PyObject* k PyErr_SetString(PyExc_TypeError, "err_on_missing must be a boolean"); return NULL; } + if (merge_fragments != NULL && !PyBool_Check(merge_fragments)) { + Py_DECREF(path); + PyErr_SetString(PyExc_TypeError, "merge_fragments must be a boolean"); + return NULL; + } + OutputFormat output_format; + if (!parseOutputFormat(format, output_format)) { + Py_DECREF(path); + PyErr_SetString(PyExc_ValueError, "format must be one of: 'pdb', 'mmcif', 
'cif'"); + return NULL; + } std::string dbname(pathCStr); std::string index = dbname + ".index"; @@ -379,9 +718,19 @@ static PyObject *foldcomp_open(PyObject* /* self */, PyObject* args, PyObject* k PyErr_SetString(PyExc_MemoryError, "Could not allocate memory for FoldcompDatabaseObject"); return NULL; } + obj->memory_handle = NULL; + obj->user_indices = NULL; + obj->merged_names = NULL; + obj->merged_source_indices = NULL; + obj->merge_fragments = false; + obj->decompress = true; + obj->output_format = output_format; int mode = DB_READER_USE_DATA; - if (user_ids != NULL && PySequence_Length(user_ids) > 0) { + bool merge_fragments_flag = merge_fragments != NULL && PyObject_IsTrue(merge_fragments); + if (merge_fragments_flag) { + mode |= DB_READER_USE_LOOKUP_REVERSE; + } else if (user_ids != NULL && PySequence_Length(user_ids) > 0) { mode |= DB_READER_USE_LOOKUP; } @@ -396,16 +745,85 @@ static PyObject *foldcomp_open(PyObject* /* self */, PyObject* args, PyObject* k } else { err_on_missing_flag = PyObject_IsTrue(err_on_missing); } + obj->merge_fragments = merge_fragments_flag; + if (obj->merge_fragments && !obj->decompress) { + FoldcompDatabase_release_resources(obj); + Py_DECREF(obj); + PyErr_SetString(PyExc_TypeError, "merge_fragments requires decompress=True"); + return NULL; + } obj->memory_handle = make_reader(dbname.c_str(), index.c_str(), mode); + if (obj->memory_handle == NULL) { + FoldcompDatabase_release_resources(obj); + Py_DECREF(obj); + PyErr_SetString(PyExc_RuntimeError, "Failed to open Foldcomp database."); + return NULL; + } - obj->user_indices = NULL; - if (user_ids != NULL && PySequence_Length(user_ids) > 0) { + if (obj->merge_fragments) { + obj->merged_names = new std::vector(); + obj->merged_source_indices = new std::vector>(); + std::unordered_map group_idx_by_name; + int64_t size = reader_get_size(obj->memory_handle); + for (int64_t id = 0; id < size; id++) { + uint32_t key = reader_get_key(obj->memory_handle, id); + const char* lookup_name = 
reader_lookup_name_alloc(obj->memory_handle, key); + std::string group_name; + if (lookup_name != NULL && lookup_name[0] != '\0') { + group_name = lookup_name; + free((void*)lookup_name); + } else { + group_name = std::to_string(key); + } + + auto it = group_idx_by_name.find(group_name); + size_t group_idx; + if (it == group_idx_by_name.end()) { + group_idx = obj->merged_names->size(); + group_idx_by_name[group_name] = group_idx; + obj->merged_names->push_back(group_name); + obj->merged_source_indices->emplace_back(); + } else { + group_idx = it->second; + } + obj->merged_source_indices->at(group_idx).push_back(id); + } + + if (user_ids != NULL && PySequence_Length(user_ids) > 0) { + std::vector filtered_names; + std::vector> filtered_sources; + size_t id_count = (size_t)PySequence_Length(user_ids); + for (Py_ssize_t i = 0; i < (Py_ssize_t)id_count; i++) { + PyObject* item = PySequence_GetItem(user_ids, i); + const char* data = PyUnicode_AsUTF8(item); + Py_DECREF(item); + auto it = group_idx_by_name.find(data); + if (it == group_idx_by_name.end()) { + std::string err_msg = "Skipping entry "; + err_msg += data; + err_msg += " which is not in the database."; + if (err_on_missing_flag) { + FoldcompDatabase_release_resources(obj); + Py_DECREF(obj); + PyErr_SetString(PyExc_KeyError, err_msg.c_str()); + return NULL; + } else { + std::cerr << err_msg << std::endl; + continue; + } + } + filtered_names.push_back(data); + filtered_sources.push_back(obj->merged_source_indices->at(it->second)); + } + *(obj->merged_names) = filtered_names; + *(obj->merged_source_indices) = filtered_sources; + } + } else if (user_ids != NULL && PySequence_Length(user_ids) > 0) { size_t id_count = (size_t)PySequence_Length(user_ids); // Reserve memory for the user indices obj->user_indices = new std::vector(); obj->user_indices->reserve(id_count); - // user_indices.reserve(id_count); for (Py_ssize_t i = 0; i < (Py_ssize_t)id_count; i++) { // Iterate over all entries in the database and store ids 
in a vector of int64_t PyObject* item = PySequence_GetItem(user_ids, i); @@ -419,6 +837,7 @@ static PyObject *foldcomp_open(PyObject* /* self */, PyObject* args, PyObject* k err_msg += data; err_msg += " which is not in the database."; if (err_on_missing_flag) { + FoldcompDatabase_release_resources(obj); Py_DECREF(obj); PyErr_SetString(PyExc_KeyError, err_msg.c_str()); return NULL; @@ -701,7 +1120,7 @@ static PyObject* foldcomp_get_data(PyObject* /* self */, PyObject* args, PyObjec #pragma GCC diagnostic ignored "-Wcast-function-type" static PyMethodDef foldcomp_methods[] = { // {"compress", foldcomp_compress, METH_VARARGS, "Compress a PDB file."}, - {"decompress", foldcomp_decompress, METH_VARARGS, "Decompress FCZ content to PDB."}, + {"decompress", (PyCFunction)foldcomp_decompress, METH_VARARGS | METH_KEYWORDS, "Decompress FCZ content to PDB or mmCIF."}, {"compress", (PyCFunction)foldcomp_compress, METH_VARARGS | METH_KEYWORDS, "Compress PDB content to FCZ."}, {"open", (PyCFunction)foldcomp_open, METH_VARARGS | METH_KEYWORDS, "Open a Foldcomp database."}, {"get_data", (PyCFunction)foldcomp_get_data, METH_VARARGS | METH_KEYWORDS, "Get data from FCZ or PDB content."}, diff --git a/src/atom_coordinate.cpp b/src/atom_coordinate.cpp index 4217e29..fbda883 100644 --- a/src/atom_coordinate.cpp +++ b/src/atom_coordinate.cpp @@ -13,6 +13,8 @@ */ #include "atom_coordinate.h" +#include +#include #include #include #include @@ -217,6 +219,53 @@ void fast_ftoa(float n, char* s) { // } } +static std::string sanitizeMMCIFBlockName(const std::string& title) { + std::string block = title; + if (block.empty()) { + block = "foldcomp"; + } + std::transform(block.begin(), block.end(), block.begin(), [](unsigned char c) { + if (std::isalnum(c) || c == '_' || c == '-') { + return (char)c; + } + return '_'; + }); + return block.empty() ? 
"foldcomp" : block; +} + +static void writeMMCIFValue(std::ostream& out, const std::string& value) { + if (value.empty()) { + out << "."; + return; + } + bool needs_quote = false; + for (char c : value) { + if (std::isspace((unsigned char)c) || c == '\'' || c == '"' || c == ';' || c == '#') { + needs_quote = true; + break; + } + } + if (!needs_quote) { + out << value; + return; + } + std::string escaped = value; + std::replace(escaped.begin(), escaped.end(), '\n', '_'); + std::replace(escaped.begin(), escaped.end(), '\r', '_'); + std::replace(escaped.begin(), escaped.end(), '\'', '_'); + std::replace(escaped.begin(), escaped.end(), '"', '_'); + out << "'" << escaped << "'"; +} + +static char inferElementSymbol(const std::string& atom_name) { + for (char c : atom_name) { + if (std::isalpha((unsigned char)c)) { + return (char)std::toupper((unsigned char)c); + } + } + return '?'; +} + void writeAtomCoordinatesToPDB( std::vector& atoms, std::string title, std::ostream& pdb_stream ) { @@ -301,6 +350,93 @@ int writeAtomCoordinatesToPDBFile( return 0; } +void writeAtomCoordinatesToMMCIF( + const std::vector& atoms, const std::string& title, std::ostream& mmcif_stream +) { + static const char* MMCIF_ATOM_SITE_HEADER = + "#\n" + "loop_\n" + "_atom_site.group_PDB\n" + "_atom_site.id\n" + "_atom_site.type_symbol\n" + "_atom_site.label_atom_id\n" + "_atom_site.label_alt_id\n" + "_atom_site.label_comp_id\n" + "_atom_site.label_asym_id\n" + "_atom_site.label_entity_id\n" + "_atom_site.label_seq_id\n" + "_atom_site.pdbx_PDB_ins_code\n" + "_atom_site.Cartn_x\n" + "_atom_site.Cartn_y\n" + "_atom_site.Cartn_z\n" + "_atom_site.occupancy\n" + "_atom_site.B_iso_or_equiv\n" + "_atom_site.pdbx_formal_charge\n" + "_atom_site.auth_seq_id\n" + "_atom_site.auth_comp_id\n" + "_atom_site.auth_asym_id\n" + "_atom_site.auth_atom_id\n" + "_atom_site.pdbx_PDB_model_num\n"; + + mmcif_stream << "data_" << sanitizeMMCIFBlockName(title) << "\n"; + mmcif_stream << MMCIF_ATOM_SITE_HEADER; + + char 
xbuf[16]; + char ybuf[16]; + char zbuf[16]; + char occbuf[16]; + char bbuf[16]; + for (const AtomCoordinate& atom : atoms) { + const std::string chain = atom.chain.empty() ? "." : atom.chain; + const float occupancy = atom.occupancy > 0.0f ? atom.occupancy : 1.0f; + const char element = inferElementSymbol(atom.atom); + fast_ftoa<1000, 3>(atom.coordinate.x, xbuf); + fast_ftoa<1000, 3>(atom.coordinate.y, ybuf); + fast_ftoa<1000, 3>(atom.coordinate.z, zbuf); + fast_ftoa<100, 2>(occupancy, occbuf); + fast_ftoa<100, 2>(atom.tempFactor, bbuf); + + mmcif_stream << "ATOM " + << atom.atom_index << " " + << element << " "; + writeMMCIFValue(mmcif_stream, atom.atom); + mmcif_stream << " . "; + writeMMCIFValue(mmcif_stream, atom.residue); + mmcif_stream << " "; + writeMMCIFValue(mmcif_stream, chain); + mmcif_stream << " " + << "1 " + << atom.residue_index << " " + << "? " + << xbuf << " " + << ybuf << " " + << zbuf << " " + << occbuf << " " + << bbuf << " " + << "? " + << atom.residue_index << " " + ; + writeMMCIFValue(mmcif_stream, atom.residue); + mmcif_stream << " "; + writeMMCIFValue(mmcif_stream, chain); + mmcif_stream << " "; + writeMMCIFValue(mmcif_stream, atom.atom); + mmcif_stream << " 1\n"; + } + mmcif_stream << "#\n"; +} + +int writeAtomCoordinatesToMMCIFFile( + const std::vector& atoms, const std::string& title, const std::string& mmcif_path +) { + std::ofstream mmcif_file(mmcif_path); + if (!mmcif_file) { + return 1; + } + writeAtomCoordinatesToMMCIF(atoms, title, mmcif_file); + return 0; +} + std::vector< std::vector > splitAtomByResidue( const tcb::span& atomCoordinates ) { diff --git a/src/atom_coordinate.h b/src/atom_coordinate.h index 3b06efb..d7eb551 100644 --- a/src/atom_coordinate.h +++ b/src/atom_coordinate.h @@ -85,6 +85,12 @@ void writeAtomCoordinatesToPDB( int writeAtomCoordinatesToPDBFile( std::vector& atoms, std::string title, std::string pdb_path ); +void writeAtomCoordinatesToMMCIF( + const std::vector& atoms, const std::string& title, 
std::ostream& mmcif_stream +); +int writeAtomCoordinatesToMMCIFFile( + const std::vector& atoms, const std::string& title, const std::string& mmcif_path +); std::vector> splitAtomByResidue( const tcb::span& atomCoordinates diff --git a/src/main.cpp b/src/main.cpp index 7c7896b..06bbd3d 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -29,10 +29,12 @@ #include "input_processor.h" // Standard libraries +#include +#include #include +#include #include // IWYU pragma: keep #ifdef _WIN32 -#include #include "windows/getopt.h" #include "windows/dirent.h" #else @@ -62,10 +64,37 @@ static int overwrite = 0; // version #define FOLDCOMP_VERSION "1.0.0" +static bool ensureDirectoryExists(const std::string& dir_path) { + if (dir_path.empty()) { + return true; + } + std::error_code ec; + std::filesystem::create_directories(std::filesystem::path(dir_path), ec); + if (ec) { + std::cerr << "[Error] Could not create directory: " << dir_path << std::endl; + return false; + } + return true; +} + +static bool ensureParentDirectoryExists(const std::string& file_path) { + std::filesystem::path parent = std::filesystem::path(file_path).parent_path(); + if (parent.empty()) { + return true; + } + std::error_code ec; + std::filesystem::create_directories(parent, ec); + if (ec) { + std::cerr << "[Error] Could not create parent directory: " << parent.string() << std::endl; + return false; + } + return true; +} + int print_usage(void) { std::cout << "Usage: foldcomp compress []" << std::endl; std::cout << " foldcomp compress [-t number] []" << std::endl; - std::cout << " foldcomp decompress []" << std::endl; + std::cout << " foldcomp decompress []" << std::endl; std::cout << " foldcomp decompress [-t number] []" << std::endl; std::cout << " foldcomp extract [--plddt|--amino-acid] []" << std::endl; std::cout << " foldcomp extract [--plddt|--amino-acid] [-t number] []" << std::endl; @@ -92,6 +121,7 @@ int print_usage(void) { std::cout << " --fasta, --amino-acid extract amino acid sequence (only for 
extraction mode)" << std::endl; std::cout << " --no-merge do not merge output files (only for extraction mode)" << std::endl; std::cout << " --use-title use TITLE as the output file name (only for extraction mode)" << std::endl; + std::cout << " --output-format output format for decompression: pdb|mmcif|cif [default=pdb]" << std::endl; std::cout << " --time measure time for compression/decompression" << std::endl; std::cout << " --use-cache use cached index for database input [default=false]" << std::endl; return 0; @@ -154,6 +184,7 @@ int main(int argc, char* const *argv) { int check_before_decompression = 0; int id_mode = 1; int use_cache = 0; + std::string output_format = "pdb"; std::string user_id_file = ""; std::vector user_names; std::vector user_ids; @@ -190,12 +221,13 @@ int main(int argc, char* const *argv) { {"id-list", required_argument, 0, 'l'}, {"id-mode", required_argument, 0, 'm'}, {"plddt-digits", required_argument, 0, 'p'}, + {"output-format", required_argument, 0, 'x'}, {"use-cache", no_argument, &use_cache, 1 }, {0, 0, 0, 0 } }; // Parse command line options with getopt_long - int flag = getopt_long(argc, argv, "hadzrfyvt:b:l:p:", long_options, &option_index); + int flag = getopt_long(argc, argv, "hadzrfyvt:b:l:p:x:", long_options, &option_index); while (flag != -1) { switch (flag) { case 'h': @@ -237,6 +269,20 @@ int main(int argc, char* const *argv) { case 'p': ext_plddt_digits = atoi(optarg); break; + case 'x': { + output_format = std::string(optarg); + std::transform(output_format.begin(), output_format.end(), output_format.begin(), [](unsigned char c) { + return (char)std::tolower(c); + }); + if (output_format == "mmcif") { + output_format = "cif"; + } + if (output_format != "pdb" && output_format != "cif") { + std::cerr << "[Error] Invalid output format. Please use pdb, mmcif, or cif." 
<< std::endl; + return print_usage(); + } + break; + } case 'v': return print_version(); case '?': @@ -244,7 +290,7 @@ int main(int argc, char* const *argv) { default: break; } - flag = getopt_long(argc, argv, "hadzrfyt:b:l:p:", long_options, &option_index); + flag = getopt_long(argc, argv, "hadzrfyt:b:l:p:x:", long_options, &option_index); } // Parse non-option arguments @@ -273,7 +319,7 @@ int main(int argc, char* const *argv) { } else if (strcmp(argv[optind], "decompress") == 0) { mode = DECOMPRESS; mayHaveOutput = true; - outputSuffix = "pdb"; + outputSuffix = output_format.c_str(); } else if (strcmp(argv[optind], "extract") == 0) { mode = EXTRACT; mayHaveOutput = true; @@ -372,20 +418,28 @@ int main(int argc, char* const *argv) { rmsd(input, output); } else if (mode == COMPRESS) { // output variants - void* handle; + void* handle = NULL; mtar_t tar_out; if (save_as_tar) { - mtar_open(&tar_out, output.c_str(), "w"); + if (!ensureParentDirectoryExists(output)) { + return EXIT_FAILURE; + } + if (mtar_open(&tar_out, output.c_str(), "w") != MTAR_ESUCCESS) { + std::cerr << "[Error] Could not open tar output: " << output << std::endl; + return EXIT_FAILURE; + } } else if (db_output) { + if (!ensureParentDirectoryExists(output)) { + return EXIT_FAILURE; + } handle = make_writer(output.c_str(), (output + ".index").c_str()); + if (handle == NULL) { + std::cerr << "[Error] Could not open database output: " << output << std::endl; + return EXIT_FAILURE; + } } else if (!isSingleFileInput) { - struct stat st; - if (stat(output.c_str(), &st) == -1) { -#ifdef _WIN32 - _mkdir(output.c_str()); -#else - mkdir(output.c_str(), 0755); -#endif + if (!ensureDirectoryExists(output)) { + return EXIT_FAILURE; } } @@ -544,20 +598,28 @@ int main(int argc, char* const *argv) { mtar_close(&tar_out); } } else if (mode == DECOMPRESS) { - void* handle; + void* handle = NULL; mtar_t tar_out; if (save_as_tar) { - mtar_open(&tar_out, output.c_str(), "w"); + if 
(!ensureParentDirectoryExists(output)) { + return EXIT_FAILURE; + } + if (mtar_open(&tar_out, output.c_str(), "w") != MTAR_ESUCCESS) { + std::cerr << "[Error] Could not open tar output: " << output << std::endl; + return EXIT_FAILURE; + } } else if (db_output) { + if (!ensureParentDirectoryExists(output)) { + return EXIT_FAILURE; + } handle = make_writer(output.c_str(), (output + ".index").c_str()); + if (handle == NULL) { + std::cerr << "[Error] Could not open database output: " << output << std::endl; + return EXIT_FAILURE; + } } else if (!isSingleFileInput) { - struct stat st; - if (stat(output.c_str(), &st) == -1) { -#ifdef _WIN32 - _mkdir(output.c_str()); -#else - mkdir(output.c_str(), 0755); -#endif + if (!ensureDirectoryExists(output)) { + return EXIT_FAILURE; } } @@ -576,6 +638,19 @@ int main(int argc, char* const *argv) { } unsigned int key = 0; + auto writeDecompressedStructure = [&](std::vector& atomCoordinates, const std::string& title, std::ostream& out) { + if (output_format == "pdb") { + writeAtomCoordinatesToPDB(atomCoordinates, title, out); + } else { + writeAtomCoordinatesToMMCIF(atomCoordinates, title, out); + } + }; + auto writeDecompressedStructureFile = [&](std::vector& atomCoordinates, const std::string& title, const std::string& outputFile) -> int { + if (output_format == "pdb") { + return writeAtomCoordinatesToPDBFile(atomCoordinates, title, outputFile); + } + return writeAtomCoordinatesToMMCIFFile(atomCoordinates, title, outputFile); + }; for (size_t i = 0; i < inputs.size() + 1; i++) { const std::string& input = (i == inputs.size()) ? 
"" : inputs[i]; Processor* processor; @@ -655,7 +730,7 @@ int main(int argc, char* const *argv) { if (db_output) { std::ostringstream oss; - writeAtomCoordinatesToPDB(atomCoordinates, compRes.strTitle, oss); + writeDecompressedStructure(atomCoordinates, compRes.strTitle, oss); oss << '\0'; std::string os = oss.str(); #pragma omp critical @@ -665,7 +740,7 @@ int main(int argc, char* const *argv) { } } else if (save_as_tar) { std::ostringstream oss; - writeAtomCoordinatesToPDB(atomCoordinates, compRes.strTitle, oss); + writeDecompressedStructure(atomCoordinates, compRes.strTitle, oss); #pragma omp critical { std::string os = oss.str(); @@ -679,7 +754,7 @@ int main(int argc, char* const *argv) { std::cerr << "[Error] Output file already exists: " << baseName(outputFile) << std::endl; return false; } - flag = writeAtomCoordinatesToPDBFile(atomCoordinates, compRes.strTitle, outputFile); + flag = writeDecompressedStructureFile(atomCoordinates, compRes.strTitle, outputFile); if (flag != 0) { std::cerr << "[Error] Writing decompressed data to file: " << output << std::endl; return false; @@ -697,20 +772,28 @@ int main(int argc, char* const *argv) { mtar_close(&tar_out); } } else if (mode == EXTRACT) { - void* handle; + void* handle = NULL; mtar_t tar_out; if (save_as_tar) { - mtar_open(&tar_out, output.c_str(), "w"); + if (!ensureParentDirectoryExists(output)) { + return EXIT_FAILURE; + } + if (mtar_open(&tar_out, output.c_str(), "w") != MTAR_ESUCCESS) { + std::cerr << "[Error] Could not open tar output: " << output << std::endl; + return EXIT_FAILURE; + } } else if (db_output) { + if (!ensureParentDirectoryExists(output)) { + return EXIT_FAILURE; + } handle = make_writer(output.c_str(), (output + ".index").c_str()); + if (handle == NULL) { + std::cerr << "[Error] Could not open database output: " << output << std::endl; + return EXIT_FAILURE; + } } else { - struct stat st; - if (stat(output.c_str(), &st) == -1 && !ext_merge) { -#ifdef _WIN32 - _mkdir(output.c_str()); 
-#else - mkdir(output.c_str(), 0755); -#endif + if (!ext_merge && !ensureDirectoryExists(output)) { + return EXIT_FAILURE; } } diff --git a/test/test_foldcomp.py b/test/test_foldcomp.py index 21bd578..539d325 100644 --- a/test/test_foldcomp.py +++ b/test/test_foldcomp.py @@ -1,4 +1,6 @@ import foldcomp +import pytest +import shutil from pathlib import Path @@ -8,6 +10,24 @@ def test_decompress(pytestconfig): print(foldcomp.decompress(foldcomp.compress("test", data))) +def test_decompress_mmcif(pytestconfig): + with open(pytestconfig.rootpath.joinpath("test/test.pdb"), "rb") as f: + data = f.read().decode("utf-8") + fcz = foldcomp.compress("test", data) + name, mmcif = foldcomp.decompress(fcz, format="mmcif") + assert isinstance(name, str) + assert isinstance(mmcif, str) + assert mmcif.startswith("data_") + assert "_atom_site.Cartn_x" in mmcif + + +def test_decompress_invalid_format(pytestconfig): + with open(pytestconfig.rootpath.joinpath("test/test_af.fcz"), "rb") as f: + fcz = f.read() + with pytest.raises(ValueError, match="format must be one of"): + foldcomp.decompress(fcz, format="xyz") + + def test_open_db_all(pytestconfig): path = Path(pytestconfig.rootpath.joinpath("test/example_db")) with foldcomp.open(path) as db: @@ -25,3 +45,98 @@ def test_open_db_ids(pytestconfig): def test_open_db_str(pytestconfig): with foldcomp.open(str(pytestconfig.rootpath.joinpath("test/example_db"))) as db: pass + + +def test_compress_multichain_requires_split(pytestconfig): + with open(pytestconfig.rootpath.joinpath("test/multichain.pdb"), "rb") as f: + data = f.read().decode("utf-8") + with pytest.raises(foldcomp.error, match="Multiple chains found"): + foldcomp.compress("multichain.pdb", data) + + +def test_compress_split_multichain(pytestconfig): + with open(pytestconfig.rootpath.joinpath("test/multichain.pdb"), "rb") as f: + data = f.read().decode("utf-8") + chunks = foldcomp.compress("multichain.pdb", data, split=True) + assert isinstance(chunks, list) + assert len(chunks) 
> 1 + first_name, first_fcz = chunks[0] + assert first_name.endswith(".fcz") + assert isinstance(first_fcz, bytes) + name, pdb = foldcomp.decompress(first_fcz) + assert isinstance(name, str) + assert isinstance(pdb, str) + assert "ATOM" in pdb + + +def _copy_example_db(rootpath, tmp_path): + source = rootpath.joinpath("test/example_db") + target = tmp_path.joinpath("example_db_copy") + for suffix in ["", ".dbtype", ".index", ".lookup", ".source"]: + src = Path(str(source) + suffix) + if src.exists(): + shutil.copy(src, Path(str(target) + suffix)) + return target + + +def test_open_merge_fragments_lookup_grouping(pytestconfig, tmp_path): + dbpath = _copy_example_db(pytestconfig.rootpath, tmp_path) + lookup_path = Path(str(dbpath) + ".lookup") + lines = lookup_path.read_text().splitlines() + first_cols = lines[0].split("\t") + second_cols = lines[1].split("\t") + first_cols[1] = "merged_entry" + second_cols[1] = "merged_entry" + lines[0] = "\t".join(first_cols) + lines[1] = "\t".join(second_cols) + lookup_path.write_text("\n".join(lines) + "\n") + + with foldcomp.open(str(dbpath)) as raw_db: + raw_len = len(raw_db) + + with foldcomp.open(str(dbpath), merge_fragments=True) as merged_db: + assert len(merged_db) == raw_len - 1 + found = False + for i in range(len(merged_db)): + name, pdb = merged_db[i] + if name == "merged_entry": + assert isinstance(pdb, str) + assert "ATOM" in pdb + source = merged_db.source_indices(i) + assert len(source) == 2 + found = True + break + assert found + + with foldcomp.open(str(dbpath), ids=["merged_entry"], merge_fragments=True) as merged_filtered: + assert len(merged_filtered) == 1 + name, pdb = merged_filtered[0] + assert name == "merged_entry" + assert "ATOM" in pdb + assert len(merged_filtered.source_indices(0)) == 2 + + with foldcomp.open(str(dbpath), ids=["merged_entry"], merge_fragments=True, format="mmcif") as merged_mmcif: + assert len(merged_mmcif) == 1 + name, mmcif = merged_mmcif[0] + assert name == "merged_entry" + assert 
isinstance(mmcif, str) + assert mmcif.startswith("data_") + assert mmcif.count("data_") == 1 + assert "_atom_site.Cartn_x" in mmcif + + +def test_open_merge_fragments_requires_decompress(pytestconfig, tmp_path): + dbpath = _copy_example_db(pytestconfig.rootpath, tmp_path) + with pytest.raises(TypeError, match="merge_fragments requires decompress=True"): + foldcomp.open(str(dbpath), merge_fragments=True, decompress=False) + + +def test_source_indices_requires_merge(pytestconfig): + with foldcomp.open(str(pytestconfig.rootpath.joinpath("test/example_db"))) as db: + with pytest.raises(TypeError, match="source_indices is available only"): + db.source_indices(0) + + +def test_open_invalid_format(pytestconfig): + with pytest.raises(ValueError, match="format must be one of"): + foldcomp.open(str(pytestconfig.rootpath.joinpath("test/example_db")), format="xyz") From cfddb60c5f13814ebd7b4737f3aea77ca6549d57 Mon Sep 17 00:00:00 2001 From: Liang Hong Date: Mon, 23 Feb 2026 07:05:59 +0000 Subject: [PATCH 2/2] Format test file with black --- test/test_foldcomp.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/test_foldcomp.py b/test/test_foldcomp.py index 539d325..2e15100 100644 --- a/test/test_foldcomp.py +++ b/test/test_foldcomp.py @@ -108,14 +108,18 @@ def test_open_merge_fragments_lookup_grouping(pytestconfig, tmp_path): break assert found - with foldcomp.open(str(dbpath), ids=["merged_entry"], merge_fragments=True) as merged_filtered: + with foldcomp.open( + str(dbpath), ids=["merged_entry"], merge_fragments=True + ) as merged_filtered: assert len(merged_filtered) == 1 name, pdb = merged_filtered[0] assert name == "merged_entry" assert "ATOM" in pdb assert len(merged_filtered.source_indices(0)) == 2 - with foldcomp.open(str(dbpath), ids=["merged_entry"], merge_fragments=True, format="mmcif") as merged_mmcif: + with foldcomp.open( + str(dbpath), ids=["merged_entry"], merge_fragments=True, format="mmcif" + ) as merged_mmcif: assert 
len(merged_mmcif) == 1 name, mmcif = merged_mmcif[0] assert name == "merged_entry" @@ -139,4 +143,6 @@ def test_source_indices_requires_merge(pytestconfig): def test_open_invalid_format(pytestconfig): with pytest.raises(ValueError, match="format must be one of"): - foldcomp.open(str(pytestconfig.rootpath.joinpath("test/example_db")), format="xyz") + foldcomp.open( + str(pytestconfig.rootpath.joinpath("test/example_db")), format="xyz" + )