diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 5b233393..1b5a27e1 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -10,7 +10,7 @@ This guide explains how to set up your environment, make changes, and submit the ## Getting Started -Before contributing, please review the [Developer Guide](https://codeentropy.readthedocs.io/en/latest/developer_guide.html). +Before contributing, please review the [Developer Guide](https://codeentropy.readthedocs.io/en/latest/developer_guide.html). It covers CodeEntropy’s architecture, setup instructions, and contribution workflow. If you’re new to the project, we also recommend: @@ -23,19 +23,19 @@ If you’re new to the project, we also recommend: When you’re ready to submit your work: -1. **Push your branch** to GitHub. -2. **Open a [pull request](https://help.github.com/articles/using-pull-requests/)** against the `main` branch. +1. **Push your branch** to GitHub. +2. **Open a [pull request](https://help.github.com/articles/using-pull-requests/)** against the `main` branch. 3. **Fill out the PR template**, including: - - A concise summary of what your PR does - - A list of all changes introduced - - Details on how these changes affect the repository (features, tests, documentation, etc.) + - A concise summary of what your PR does + - A list of all changes introduced + - Details on how these changes affect the repository (features, tests, documentation, etc.) 4. **Verify before submission**: - - All tests pass - - Pre-commit checks succeed - - Documentation is updated where applicable + - All tests pass + - Pre-commit checks succeed + - Documentation is updated where applicable 5. **Review process**: - - Your PR will be reviewed by the core development team. - - At least **one approval** is required before merging. + - Your PR will be reviewed by the core development team. + - At least **one approval** is required before merging. 
We aim to provide constructive feedback quickly and appreciate your patience during the review process. @@ -45,12 +45,12 @@ We aim to provide constructive feedback quickly and appreciate your patience dur Found a bug or have a feature request? -1. **Open a new issue** on GitHub. -2. Provide a **clear and descriptive title**. +1. **Open a new issue** on GitHub. +2. Provide a **clear and descriptive title**. 3. Include: - - Steps to reproduce the issue (if applicable) - - Expected vs. actual behavior - - Relevant logs, screenshots, or input files + - Steps to reproduce the issue (if applicable) + - Expected vs. actual behavior + - Relevant logs, screenshots, or input files Well-documented issues help us address problems faster and keep CodeEntropy stable and robust. diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index b7215b1f..4492a51e 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -12,11 +12,11 @@ assignees: '' --- -## To Reproduce +## To Reproduce ### YAML configuration - ```yaml # Paste the YAML snippet here @@ -46,7 +46,7 @@ Remove unrelated fields to make it minimal. --> - Python Version: - Package list: - If using conda, run: `conda list > packages.txt` and paste the contents here. 
- + ``` bash # Paste packages.txt here ``` diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 41e7b7c8..6b7db578 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -5,16 +5,16 @@ ### Change 1 : - -- +- ### Change 2 : - -- +- ### Change 3 : - -- +- ## Impact - -- \ No newline at end of file +- diff --git a/.github/renovate.json b/.github/renovate.json index 02834d3c..371a631f 100644 --- a/.github/renovate.json +++ b/.github/renovate.json @@ -18,4 +18,4 @@ "automerge": false } ] -} \ No newline at end of file +} diff --git a/.github/workflows/daily.yaml b/.github/workflows/daily.yaml new file mode 100644 index 00000000..218e8345 --- /dev/null +++ b/.github/workflows/daily.yaml @@ -0,0 +1,38 @@ +name: CodeEntropy Daily + +on: + schedule: + - cron: '0 8 * * 1-5' + workflow_dispatch: + +concurrency: + group: daily-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit: + name: Unit (${{ matrix.os }}, ${{ matrix.python-version }}) + runs-on: ${{ matrix.os }} + timeout-minutes: 30 + strategy: + fail-fast: false + matrix: + os: [ubuntu-24.04, macos-15, windows-2025] + python-version: ["3.12", "3.13", "3.14"] + steps: + - name: Checkout + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install (testing) + run: | + python -m pip install --upgrade pip + python -m pip install -e .[testing] + + - name: Pytest (unit) • ${{ matrix.os }} • py${{ matrix.python-version }} + run: python -m pytest tests/unit -q diff --git a/.github/workflows/mdanalysis-compatibility-failure.md b/.github/workflows/mdanalysis-compatibility-failure.md deleted file mode 100644 index 5815a9c7..00000000 --- a/.github/workflows/mdanalysis-compatibility-failure.md +++ /dev/null @@ -1,11 +0,0 @@ ---- 
-title: "CI Failure: MDAnalysis v{{ env.MDA_VERSION }} / Python {{ env.PYTHON_VERSION }}" -labels: - - "CI Failure" - - "MDAnalysis Compatibility" ---- - -Automated MDAnalysis Compatibility Test Failure -MDAnalysis version: {{ env.MDA_VERSION }} -Python version: {{ env.PYTHON_VERSION }} -Workflow Run: [Run #{{ env.RUN_NUMBER }}]({{ env.RUN_URL }}) diff --git a/.github/workflows/mdanalysis-compatibility.yaml b/.github/workflows/mdanalysis-compatibility.yaml deleted file mode 100644 index 8b671e3f..00000000 --- a/.github/workflows/mdanalysis-compatibility.yaml +++ /dev/null @@ -1,49 +0,0 @@ -name: MDAnalysis Compatibility - -on: - schedule: - - cron: '0 8 * * 1' # Weekly Monday checks - workflow_dispatch: - -jobs: - mdanalysis-compatibility: - name: MDAnalysis Compatibility Tests - runs-on: ${{ matrix.os }} - - strategy: - matrix: - os: [ubuntu-24.04, windows-2025, macos-15] - python-version: ["3.11", "3.12", "3.13", "3.14"] - mdanalysis-version: ["2.10.0"] - - steps: - - name: Checkout repo - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: ${{ matrix.python-version }} - - - name: Install dependencies with MDAnalysis ${{ matrix.mdanalysis-version }} - run: | - pip install --upgrade pip - pip install -e .[testing] - pip install "MDAnalysis==${{ matrix.mdanalysis-version }}" - - - name: Run compatibility tests - run: pytest --cov CodeEntropy --cov-report=term-missing --cov-append - - - name: Create Issue on Failure - if: failure() - uses: JasonEtco/create-an-issue@1b14a70e4d8dc185e5cc76d3bec9eab20257b2c5 # v2 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - PYTHON_VERSION: ${{ matrix.python-version }} - MDA_VERSION: ${{ matrix.mdanalysis-version }} - RUN_NUMBER: ${{ github.run_number }} - RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} - with: 
- filename: .github/workflows/mdanalysis-compatibility-failure.md - update_existing: true - search_existing: open diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml new file mode 100644 index 00000000..a53493f4 --- /dev/null +++ b/.github/workflows/pr.yaml @@ -0,0 +1,166 @@ +name: CodeEntropy CI + +on: + pull_request: + +concurrency: + group: pr-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit: + name: Unit + runs-on: ${{ matrix.os }} + timeout-minutes: 25 + strategy: + fail-fast: false + matrix: + os: [ubuntu-24.04, macos-15, windows-2025] + python-version: ["3.12", "3.13", "3.14"] + steps: + - name: Checkout + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install testing dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -e .[testing] + + - name: Pytest (unit) • ${{ matrix.os }}, ${{ matrix.python-version }} + run: python -m pytest tests/unit -q + + regression-quick: + name: Regression (quick) + needs: unit + runs-on: ubuntu-24.04 + timeout-minutes: 35 + steps: + - name: Checkout + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 + + - name: Set up Python 3.14 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.14" + cache: pip + + - name: Cache testdata + uses: actions/cache@v4 + with: + path: .testdata + key: codeentropy-testdata-v1-${{ runner.os }}-py3.14 + + - name: Install (testing) + run: | + python -m pip install --upgrade pip + python -m pip install -e .[testing] + + - name: Pytest (regression quick) + run: python -m pytest tests/regression -q + + - name: Upload artifacts (failure) + if: failure() + uses: actions/upload-artifact@v4 + with: + name: quick-regression-failure + path: | + 
.testdata/** + tests/regression/**/.pytest_cache/** + /tmp/pytest-of-*/pytest-*/**/config.yaml + /tmp/pytest-of-*/pytest-*/**/codeentropy_stdout.txt + /tmp/pytest-of-*/pytest-*/**/codeentropy_stderr.txt + /tmp/pytest-of-*/pytest-*/**/codeentropy_output.json + + docs: + name: Docs + runs-on: ubuntu-24.04 + timeout-minutes: 25 + steps: + - name: Checkout + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 + + - name: Set up Python 3.14 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.14" + cache: pip + + - name: Install (docs) + run: | + python -m pip install --upgrade pip + python -m pip install -e .[docs] + + - name: Build docs + run: | + cd docs + make + + pre-commit: + name: Pre-commit + runs-on: ubuntu-24.04 + timeout-minutes: 15 + steps: + - name: Checkout + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 + + - name: Set up Python 3.14 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.14" + cache: pip + + - name: Install (pre-commit) + run: | + python -m pip install --upgrade pip + python -m pip install -e .[pre-commit] + + - name: Run pre-commit + shell: bash + run: | + pre-commit install + pre-commit run --all-files || { + git status --short + git diff + exit 1 + } + + coverage: + name: Coverage + needs: unit + runs-on: ubuntu-24.04 + timeout-minutes: 30 + steps: + - name: Checkout repo + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 + + - name: Set up Python 3.14 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 + with: + python-version: "3.14" + cache: pip + + - name: Install (testing) + run: | + python -m pip install --upgrade pip + python -m pip install -e .[testing] + + - name: Run unit test suite with coverage + run: | + pytest tests/unit \ + --cov CodeEntropy \ + --cov-report term-missing \ + --cov-report xml \ + -q + + - name: Upload to Coveralls + uses: 
coverallsapp/github-action@5cbfd81b66ca5d10c19b062c04de0199c215fb6e + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + file: coverage.xml + fail-on-error: false diff --git a/.github/workflows/project-ci.yaml b/.github/workflows/project-ci.yaml deleted file mode 100644 index ae8ab9a0..00000000 --- a/.github/workflows/project-ci.yaml +++ /dev/null @@ -1,85 +0,0 @@ -name: CodeEntropy CI - -on: - push: - branches: [main] - pull_request: - schedule: - - cron: '0 8 * * 1-5' - workflow_dispatch: - -jobs: - tests: - name: Run tests - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-24.04, windows-2025, macos-15] - python-version: ["3.12", "3.13", "3.14"] - steps: - - name: Checkout repo - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: ${{ matrix.python-version }} - - - name: Install CodeEntropy and its testing dependencies - run: pip install -e .[testing] - - - name: Run test suite - run: pytest --cov CodeEntropy --cov-report term-missing --cov-append . 
- - - name: Coveralls GitHub Action - uses: coverallsapp/github-action@5cbfd81b66ca5d10c19b062c04de0199c215fb6e # v2.3.7 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - parallel: true - - docs: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-24.04, windows-2025, macos-15] - python-version: ["3.12", "3.13", "3.14"] - timeout-minutes: 15 - steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install python dependencies - run: | - pip install --upgrade pip - pip install -e .[docs] - - name: Build docs - run: cd docs && make - - pre-commit: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-24.04, windows-2025, macos-15] - python-version: ["3.12", "3.13", "3.14"] - timeout-minutes: 15 - steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install python dependencies - run: | - pip install --upgrade pip - pip install -e .[pre-commit] - - name: Run pre-commit - shell: bash - run: | - pre-commit install - pre-commit run --all-files || { - git status --short - git diff - exit 1 - } diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 7279ae57..05102594 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -146,4 +146,3 @@ jobs: env: FLIT_USERNAME: __token__ FLIT_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} - diff --git a/.github/workflows/weekly-docs.yaml b/.github/workflows/weekly-docs.yaml new file mode 100644 index 00000000..6c38e5b3 --- /dev/null +++ b/.github/workflows/weekly-docs.yaml @@ -0,0 +1,46 @@ +name: CodeEntropy Weekly Docs + +on: 
+ schedule: + - cron: '0 8 * * 1' + workflow_dispatch: + +concurrency: + group: weekly-docs-${{ github.ref }} + cancel-in-progress: true + +jobs: + docs: + name: Docs build (${{ matrix.os }}, python ${{ matrix.python-version }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-24.04, windows-2025, macos-15] + python-version: ["3.12", "3.13", "3.14"] + timeout-minutes: 30 + steps: + - name: Checkout repo + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install python dependencies + run: | + pip install --upgrade pip + pip install -e .[docs] + + - name: Build docs + run: cd docs && make + + - name: Upload docs artifacts on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: docs-${{ matrix.os }}-py${{ matrix.python-version }}-failure + path: | + docs/_build/** diff --git a/.github/workflows/weekly-regression.yaml b/.github/workflows/weekly-regression.yaml new file mode 100644 index 00000000..7c8fa5a3 --- /dev/null +++ b/.github/workflows/weekly-regression.yaml @@ -0,0 +1,52 @@ +name: CodeEntropy Weekly Regression + +on: + schedule: + - cron: '0 8 * * 1' # Weekly Monday checks + workflow_dispatch: + +concurrency: + group: weekly-regression-${{ github.ref }} + cancel-in-progress: true + +jobs: + regression: + name: Regression tests (including slow) + runs-on: ubuntu-24.04 + timeout-minutes: 180 + steps: + - name: Checkout repo + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 + + - name: Set up Python 3.14 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.14" + cache: pip + + - name: Cache regression test data downloads + uses: actions/cache@v4 + with: + path: .testdata + key: codeentropy-testdata-${{ 
runner.os }}-py314 + + - name: Install CodeEntropy and its testing dependencies + run: | + pip install --upgrade pip + pip install -e .[testing] + + - name: Run regression test suite + run: pytest tests/regression -q --run-slow + + - name: Upload regression artifacts on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: regression-failure-artifacts + path: | + .testdata/** + tests/regression/**/.pytest_cache/** + /tmp/pytest-of-*/pytest-*/**/config.yaml + /tmp/pytest-of-*/pytest-*/**/codeentropy_stdout.txt + /tmp/pytest-of-*/pytest-*/**/codeentropy_stderr.txt + /tmp/pytest-of-*/pytest-*/**/codeentropy_output.json diff --git a/.gitignore b/.gitignore index a4773328..20fa8d8f 100644 --- a/.gitignore +++ b/.gitignore @@ -125,3 +125,8 @@ job* *.err *.com *.txt + +.testdata/ + +!tests/regression/baselines/ +!tests/regression/baselines/*.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6be5d3ec..ca087bfa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,28 +1,19 @@ repos: - - repo: https://github.com/psf/black - rev: 25.1.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.2 hooks: - - id: black + - id: ruff-check + args: [--fix] + - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: check-added-large-files - - id: check-ast - - id: check-case-conflict - - id: check-executables-have-shebangs - id: check-merge-conflict - - id: check-toml - id: check-yaml - - - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - additional_dependencies: [flake8-pyproject] - - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile=black"] \ No newline at end of file + - id: check-toml + - id: check-case-conflict + - id: check-ast + - id: end-of-file-fixer + - id: trailing-whitespace diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 729da91b..2b7c3c7f 100644 --- 
a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -125,4 +125,4 @@ enforcement ladder](https://github.com/mozilla/diversity). For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at -https://www.contributor-covenant.org/translations. \ No newline at end of file +https://www.contributor-covenant.org/translations. diff --git a/CodeEntropy/__main__.py b/CodeEntropy/__main__.py new file mode 100644 index 00000000..55b3fbe2 --- /dev/null +++ b/CodeEntropy/__main__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from CodeEntropy.cli import main + +if __name__ == "__main__": + main() diff --git a/CodeEntropy/axes.py b/CodeEntropy/axes.py deleted file mode 100644 index d04b2fc0..00000000 --- a/CodeEntropy/axes.py +++ /dev/null @@ -1,511 +0,0 @@ -import logging - -import numpy as np -from MDAnalysis.lib.mdamath import make_whole - -logger = logging.getLogger(__name__) - - -class AxesManager: - """ - Manages the structural and dynamic levels involved in entropy calculations. This - includes selecting relevant levels, computing axes for translation and rotation, - and handling bead-based representations of molecular systems. Provides utility - methods to extract averaged positions, convert coordinates to spherical systems, - compute weighted forces and torques, and manipulate matrices used in entropy - analysis. - """ - - def __init__(self): - """ - Initializes the AxesManager with placeholders for level-related data, - including translational and rotational axes, number of beads, and a - general-purpose data container. - """ - self.data_container = None - self._levels = None - self._trans_axes = None - self._rot_axes = None - self._number_of_beads = None - - def get_residue_axes(self, data_container, index): - """ - The translational and rotational axes at the residue level. 
- - Args: - data_container (MDAnalysis.Universe): the molecule and trajectory data - index (int): residue index - - Returns: - trans_axes : translational axes (3,3) - rot_axes : rotational axes (3,3) - center: center of mass (3,) - moment_of_inertia: moment of inertia (3,) - """ - # TODO refine selection so that it will work for branched polymers - index_prev = index - 1 - index_next = index + 1 - atom_set = data_container.select_atoms( - f"(resindex {index_prev} or resindex {index_next}) " - f"and bonded resid {index}" - ) - residue = data_container.select_atoms(f"resindex {index}") - center = residue.atoms.center_of_mass(unwrap=True) - - if len(atom_set) == 0: - # No bonds to other residues - # Use a custom principal axes, from a MOI tensor - # that uses positions of heavy atoms only, but including masses - # of heavy atom + bonded hydrogens - UAs = residue.select_atoms("mass 2 to 999") - UA_masses = self.get_UA_masses(residue) - moment_of_inertia_tensor = self.get_moment_of_inertia_tensor( - center, UAs.positions, UA_masses, data_container.dimensions[:3] - ) - rot_axes, moment_of_inertia = self.get_custom_principal_axes( - moment_of_inertia_tensor - ) - trans_axes = ( - rot_axes # set trans axes to same as rot axes as per Jon's code - ) - else: - # if bonded to other residues, use default axes and MOI - make_whole(data_container.atoms) - trans_axes = data_container.atoms.principal_axes() - rot_axes, moment_of_inertia = self.get_vanilla_axes(residue) - center = residue.center_of_mass(unwrap=True) - - return trans_axes, rot_axes, center, moment_of_inertia - - def get_UA_axes(self, data_container, index): - """ - The translational and rotational axes at the united-atom level. 
- - Args: - data_container (MDAnalysis.Universe): the molecule and trajectory data - index (int): residue index - - Returns: - trans_axes : translational axes (3,3) - rot_axes : rotational axes (3,3) - center: center of mass (3,) - moment_of_inertia: moment of inertia (3,) - """ - - index = int(index) # bead index - - # use the same customPI trans axes as the residue level - heavy_atoms = data_container.select_atoms("prop mass > 1.1") - if len(heavy_atoms) > 1: - UA_masses = self.get_UA_masses(data_container.atoms) - center = data_container.atoms.center_of_mass(unwrap=True) - moment_of_inertia_tensor = self.get_moment_of_inertia_tensor( - center, heavy_atoms.positions, UA_masses, data_container.dimensions[:3] - ) - trans_axes, _moment_of_inertia = self.get_custom_principal_axes( - moment_of_inertia_tensor - ) - else: - # use standard PA for UA not bonded to anything else - make_whole(data_container.atoms) - trans_axes = data_container.atoms.principal_axes() - - # look for heavy atoms in residue of interest - heavy_atom_indices = [] - for atom in heavy_atoms: - heavy_atom_indices.append(atom.index) - # we find the nth heavy atom - # where n is the bead index - heavy_atom_index = heavy_atom_indices[index] - heavy_atom = data_container.select_atoms(f"index {heavy_atom_index}") - - center = heavy_atom.positions[0] - rot_axes, moment_of_inertia = self.get_bonded_axes( - data_container, heavy_atom[0], data_container.dimensions[:3] - ) - - logger.debug(f"Translational Axes: {trans_axes}") - logger.debug(f"Rotational Axes: {rot_axes}") - logger.debug(f"Center: {center}") - logger.debug(f"Moment of Inertia: {moment_of_inertia}") - - return trans_axes, rot_axes, center, moment_of_inertia - - def get_bonded_axes(self, system, atom, dimensions): - """ - For a given heavy atom, use its bonded atoms to get the axes - for rotating forces around. 
Few cases for choosing united atom axes, - which are dependent on the bonds to the atom: - - :: - - X -- H = bonded to zero or more light atom/s (case1) - - X -- R = bonded to one heavy atom (case2) - - R -- X -- H = bonded to one heavy and at least one light atom (case3) - - R1 -- X -- R2 = bonded to two heavy atoms (case4) - - R1 -- X -- R2 = bonded to more than two heavy atoms (case5) - | - R3 - - Note that axis2 is calculated by taking the cross product between axis1 and - the vector chosen for each case, dependent on bonding: - - - case1: if all the bonded atoms are hydrogens, use the principal axes. - - - case2: use XR vector as axis1, arbitrary axis2. - - - case3: use XR vector as axis1, vector XH to calculate axis2 - - - case4: use vector XR1 as axis1, and XR2 to calculate axis2 - - - case5: get the sum of all XR normalised vectors as axis1, then use vector - R1R2 to calculate axis2 - - axis3 is always the cross product of axis1 and axis2. - - Args: - system: mdanalysis instance of all atoms in current frame - atom: mdanalysis instance of a heavy atom - dimensions: dimensions of the simulation box (3,) - - Returns: - custom_axes: custom axes for the UA, (3,3) array - custom_moment_of_inertia - """ - # check atom is a heavy atom - if not atom.mass > 1.1: - return None - # set default values - custom_moment_of_inertia = None - custom_axes = None - - # find the heavy bonded atoms and light bonded atoms - heavy_bonded, light_bonded = self.find_bonded_atoms(atom.index, system) - UA = atom + light_bonded - UA_all = atom + heavy_bonded + light_bonded - - # now find which atoms to select to find the axes for rotating forces: - # case1 - if len(heavy_bonded) == 0: - custom_axes, custom_moment_of_inertia = self.get_vanilla_axes(UA_all) - # case2 - if len(heavy_bonded) == 1 and len(light_bonded) == 0: - custom_axes = self.get_custom_axes( - atom.position, [heavy_bonded[0].position], np.zeros(3), dimensions - ) - # case3 - if len(heavy_bonded) == 1 and 
len(light_bonded) >= 1: - custom_axes = self.get_custom_axes( - atom.position, - [heavy_bonded[0].position], - light_bonded[0].position, - dimensions, - ) - # case4, not used in Jon's 2019 paper code, use case5 instead - # case5 - if len(heavy_bonded) >= 2: - custom_axes = self.get_custom_axes( - atom.position, - heavy_bonded.positions, - heavy_bonded[1].position, - dimensions, - ) - - if custom_moment_of_inertia is None: - # find moment of inertia using custom axes and atom position as COM - custom_moment_of_inertia = self.get_custom_moment_of_inertia( - UA, custom_axes, atom.position, dimensions - ) - - # get the moment of inertia from the custom axes - if custom_axes is not None: - # flip axes to face correct way wrt COM - custom_axes = self.get_flipped_axes( - UA, custom_axes, atom.position, dimensions - ) - - return custom_axes, custom_moment_of_inertia - - def find_bonded_atoms(self, atom_idx: int, system): - """ - for a given atom, find its bonded heavy and H atoms - - Args: - atom_idx: atom index to find bonded heavy atom for - system: mdanalysis instance of all atoms in current frame - - Returns: - bonded_heavy_atoms: MDAnalysis instance of bonded heavy atoms - bonded_H_atoms: MDAnalysis instance of bonded hydrogen atoms - """ - bonded_atoms = system.select_atoms(f"bonded index {atom_idx}") - bonded_heavy_atoms = bonded_atoms.select_atoms("mass 2 to 999") - bonded_H_atoms = bonded_atoms.select_atoms("mass 1 to 1.1") - return bonded_heavy_atoms, bonded_H_atoms - - def get_vanilla_axes(self, molecule): - """ - Compute the principal axes and sorted moments of inertia for a molecule. - - This method computes the translationally invariant principal axes and - corresponding moments of inertia for a molecular selection using the - default MDAnalysis routines. The molecule is first made whole to ensure - correct handling of periodic boundary conditions. 
- - The moments of inertia are obtained by diagonalising the moment of inertia - tensor and are returned sorted from largest to smallest magnitude. - - Args: - molecule (MDAnalysis.core.groups.AtomGroup): - AtomGroup representing the molecule or bead for which the axes - and moments of inertia are to be computed. - - Returns: - Tuple[np.ndarray, np.ndarray]: - A tuple containing: - - - principal_axes (np.ndarray): - Array of shape ``(3, 3)`` whose rows correspond to the - principal axes of the molecule. - - moment_of_inertia (np.ndarray): - Array of shape ``(3,)`` containing the moments of inertia - sorted in descending order. - """ - moment_of_inertia = molecule.moment_of_inertia(unwrap=True) - make_whole(molecule.atoms) - principal_axes = molecule.principal_axes() - - eigenvalues, _eigenvectors = np.linalg.eig(moment_of_inertia) - - # Sort eigenvalues from largest to smallest magnitude - order = np.argsort(np.abs(eigenvalues))[::-1] - moment_of_inertia = eigenvalues[order] - - return principal_axes, moment_of_inertia - - def get_custom_axes( - self, a: np.ndarray, b_list: list, c: np.ndarray, dimensions: np.ndarray - ): - r""" - For atoms a, b_list and c, calculate the axis to rotate forces around: - - - axis1: use the normalised vector ab as axis1. If there is more than one bonded - heavy atom (HA), average over all the normalised vectors calculated from b_list - and use this as axis1). b_list contains all the bonded heavy atom - coordinates. - - - axis2: use the cross product of normalised vector ac and axis1 as axis2. - If there are more than two bonded heavy atoms, then use normalised vector - b[0]c to cross product with axis1, this gives the axis perpendicular - (represented by |_ symbol below) to axis1. - - - axis3: the cross product of axis1 and axis2, which is perpendicular to - axis1 and axis2. 
- - Args: - a: central united-atom coordinates (3,) - b_list: list of heavy bonded atom positions (3,N) - c: atom coordinates of either a second heavy atom or a hydrogen atom - if there are no other bonded heavy atoms in b_list (where N=1 in b_list) - (3,) - dimensions: dimensions of the simulation box (3,) - - :: - - a 1 = norm_ab - / \ 2 = |_ norm_ab and norm_ac (use bc if more than 2 HAs) - / \ 3 = |_ 1 and 2 - b c - - Returns: - custom_axes: (3,3) array of the axes used to rotate forces - """ - unscaled_axis1 = np.zeros(3) - # average of all heavy atom covalent bond vectors for axis1 - for b in b_list: - ab_vector = self.get_vector(a, b, dimensions) - unscaled_axis1 += ab_vector - if len(b_list) >= 2: - # use the first heavy bonded atom as atom a - ac_vector = self.get_vector(c, b_list[0], dimensions) - else: - ac_vector = self.get_vector(c, a, dimensions) - - unscaled_axis2 = np.cross(ac_vector, unscaled_axis1) - unscaled_axis3 = np.cross(unscaled_axis2, unscaled_axis1) - - unscaled_custom_axes = np.array( - (unscaled_axis1, unscaled_axis2, unscaled_axis3) - ) - mod = np.sqrt(np.sum(unscaled_custom_axes**2, axis=1)) - scaled_custom_axes = unscaled_custom_axes / mod[:, np.newaxis] - - return scaled_custom_axes - - def get_custom_moment_of_inertia( - self, - UA, - custom_rotation_axes: np.ndarray, - center_of_mass: np.ndarray, - dimensions: np.ndarray, - ): - """ - Get the moment of inertia (specifically used for the united atom level) - from a set of rotation axes and a given center of mass - (COM is usually the heavy atom position in a UA). 
- - Args: - UA: MDAnalysis instance of a united-atom - custom_rotation_axes: (3,3) arrray of rotation axes - center_of_mass: (3,) center of mass for collection of atoms N - - Returns: - custom_moment_of_inertia: (3,) array for moment of inertia - """ - translated_coords = self.get_vector(center_of_mass, UA.positions, dimensions) - custom_moment_of_inertia = np.zeros(3) - for coord, mass in zip(translated_coords, UA.masses): - axis_component = np.sum( - np.cross(custom_rotation_axes, coord) ** 2 * mass, axis=1 - ) - custom_moment_of_inertia += axis_component - - # Remove lowest MOI degree of freedom if UA only has a single bonded H - if len(UA) == 2: - order = custom_moment_of_inertia.argsort()[::-1] # decending order - custom_moment_of_inertia[order[-1]] = 0 - - return custom_moment_of_inertia - - def get_flipped_axes(self, UA, custom_axes, center_of_mass, dimensions): - """ - For a given set of custom axes, ensure the axes are pointing in the - correct direction wrt the heavy atom position and the chosen center - of mass. - - Args: - UA: MDAnalysis instance of a united-atom - custom_axes: (3,3) array of the rotation axes - center_of_mass: (3,) array for center of mass (usually HA position) - dimensions: (3,) array of system box dimensions. - """ - # sorting out PIaxes for MoI for UA fragment - - # get dot product of Paxis1 and CoM->atom1 vect - # will just be [0,0,0] - RRaxis = self.get_vector(UA[0].position, center_of_mass, dimensions) - - # flip each Paxis if its pointing out of UA - custom_axis = np.sum(custom_axes**2, axis=1) - custom_axes_flipped = custom_axes / custom_axis**0.5 - for i in range(3): - dotProd1 = np.dot(custom_axes_flipped[i], RRaxis) - custom_axes_flipped[i] = np.where( - dotProd1 < 0, -custom_axes_flipped[i], custom_axes_flipped[i] - ) - return custom_axes_flipped - - def get_vector(self, a: np.ndarray, b: np.ndarray, dimensions: np.ndarray): - """ - For vector of two coordinates over periodic boundary conditions (PBCs). 
- - Args: - a: (N,3) array of atom cooordinates - b: (3,) array of atom cooordinates - dimensions: (3,) array of system box dimensions. - - Returns: - delta_wrapped: (N,3) array of the vector - """ - delta = b - a - delta -= dimensions * np.round(delta / dimensions) - - return delta - - def get_moment_of_inertia_tensor( - self, - center_of_mass: np.ndarray, - positions: np.ndarray, - masses: list, - dimensions: np.array, - ) -> np.ndarray: - """ - Calculate a custom moment of inertia tensor. - E.g., for cases where the mass list will contain masses of UAs rather than - individual atoms and the postions will be those for the UAs only - (excluding the H atoms coordinates). - - Args: - center_of_mass: a (3,) array of the chosen center of mass - positions: a (N,3) array of point positions - masses: a (N,) list of point masses - - Returns: - moment_of_inertia_tensor: a (3,3) moment of inertia tensor - """ - r = self.get_vector(center_of_mass, positions, dimensions) - r2 = np.sum(r**2, axis=1) - moment_of_inertia_tensor = np.eye(3) * np.sum(masses * r2) - moment_of_inertia_tensor -= np.einsum("i,ij,ik->jk", masses, r, r) - - return moment_of_inertia_tensor - - def get_custom_principal_axes( - self, moment_of_inertia_tensor: np.ndarray - ) -> tuple[np.ndarray, np.ndarray]: - """ - Principal axes and centre of axes from the ordered eigenvalues - and eigenvectors of a moment of inertia tensor. This function allows for - a custom moment of inertia tensor to be used, which isn't possible with - the built-in MDAnalysis principal_axes() function. 
- - Args: - moment_of_inertia_tensor: a (3,3) array of a custom moment of - inertia tensor - - Returns: - principal_axes: a (3,3) array for the principal axes - moment_of_inertia: a (3,) array of the principal axes center - """ - eigenvalues, eigenvectors = np.linalg.eig(moment_of_inertia_tensor) - order = abs(eigenvalues).argsort()[::-1] # decending order - transposed = np.transpose(eigenvectors) # turn columns to rows - moment_of_inertia = eigenvalues[order] - principal_axes = transposed[order] - - # point z axis in correct direction, as per Jon's code - cross_xy = np.cross(principal_axes[0], principal_axes[1]) - dot_z = np.dot(cross_xy, principal_axes[2]) - if dot_z < 0: - principal_axes[2] *= -1 - - return principal_axes, moment_of_inertia - - def get_UA_masses(self, molecule) -> list[float]: - """ - For a given molecule, return a list of masses of UAs - (combination of the heavy atoms + bonded hydrogen atoms. This list is used to - get the moment of inertia tensor for molecules larger than one UA. - - Args: - molecule: mdanalysis instance of molecule - - Returns: - UA_masses: list of masses for each UA in a molecule - """ - UA_masses = [] - for atom in molecule: - if atom.mass > 1.1: - UA_mass = atom.mass - bonded_atoms = molecule.select_atoms(f"bonded index {atom.index}") - bonded_H_atoms = bonded_atoms.select_atoms("mass 1 to 1.1") - for H in bonded_H_atoms: - UA_mass += H.mass - UA_masses.append(UA_mass) - else: - continue - return UA_masses diff --git a/CodeEntropy/cli.py b/CodeEntropy/cli.py new file mode 100644 index 00000000..c30cc731 --- /dev/null +++ b/CodeEntropy/cli.py @@ -0,0 +1,35 @@ +"""Command-line entry point for CodeEntropy. + +This module provides the CLI entry point used to run the multiscale cell +correlation entropy workflow. + +The entry point is intentionally small and only responsible for: + 1) Creating a job folder. + 2) Constructing a CodeEntropyRunner. + 3) Executing the entropy workflow. 
+ 4) Handling fatal errors with a non-zero exit code. +""" + +from __future__ import annotations + +import logging + +from CodeEntropy.config.runtime import CodeEntropyRunner + +logger = logging.getLogger(__name__) + + +def main() -> None: + """Run the entropy workflow. + + Raises: + SystemExit: Exits with status code 1 on any unhandled exception. + """ + folder = CodeEntropyRunner.create_job_folder() + + try: + run_manager = CodeEntropyRunner(folder=folder) + run_manager.run_entropy_workflow() + except Exception: + logger.exception("Fatal error during entropy calculation") + raise SystemExit(1) from None diff --git a/CodeEntropy/config/arg_config_manager.py b/CodeEntropy/config/arg_config_manager.py deleted file mode 100644 index be066423..00000000 --- a/CodeEntropy/config/arg_config_manager.py +++ /dev/null @@ -1,286 +0,0 @@ -import argparse -import glob -import logging -import os - -import yaml - -# Set up logger -logger = logging.getLogger(__name__) - -arg_map = { - "top_traj_file": { - "type": str, - "nargs": "+", - "help": "Path to structure/topology file followed by trajectory file", - }, - "force_file": { - "type": str, - "default": None, - "help": "Optional path to force file if forces are not in trajectory file", - }, - "file_format": { - "type": str, - "default": None, - "help": "String for file format as recognised by MDAnalysis", - }, - "kcal_force_units": { - "type": bool, - "default": False, - "help": "Set this to True if you have a separate force file with kcal units.", - }, - "selection_string": { - "type": str, - "help": "Selection string for CodeEntropy", - "default": "all", - }, - "start": { - "type": int, - "help": "Start analysing the trajectory from this frame index", - "default": 0, - }, - "end": { - "type": int, - "help": ( - "Stop analysing the trajectory at this frame index. This is " - "the frame index of the last frame to be included, so for example" - "if start=0 and end=500 there would be 501 frames analysed. 
The " - "default -1 will include the last frame." - ), - "default": -1, - }, - "step": { - "type": int, - "help": "Interval between two consecutive frames to be read index", - "default": 1, - }, - "bin_width": { - "type": int, - "help": "Bin width in degrees for making the histogram", - "default": 30, - }, - "temperature": { - "type": float, - "help": "Temperature for entropy calculation (K)", - "default": 298.0, - }, - "verbose": { - "action": "store_true", - "help": "Enable verbose output", - }, - "output_file": { - "type": str, - "help": ( - "Name of the output file to write results to (filename only). Defaults " - "to output_file.json" - ), - "default": "output_file.json", - }, - "force_partitioning": {"type": float, "help": "Force partitioning", "default": 0.5}, - "water_entropy": { - "type": bool, - "help": "If set to False, disables the calculation of water entropy", - "default": True, - }, - "grouping": { - "type": str, - "help": "How to group molecules for averaging", - "default": "molecules", - }, - "combined_forcetorque": { - "type": bool, - "help": """Use combined force-torque matrix for residue - level vibrational entropies""", - "default": True, - }, - "customised_axes": { - "type": bool, - "help": """Use bonded axes to rotate forces for UA - level vibrational entropies""", - "default": True, - }, -} - - -class ConfigManager: - def __init__(self): - self.arg_map = arg_map - - def load_config(self, file_path): - """Load YAML configuration file from the given directory.""" - yaml_files = glob.glob(os.path.join(file_path, "*.yaml")) - - if not yaml_files: - return {"run1": {}} - - try: - with open(yaml_files[0], "r") as file: - config = yaml.safe_load(file) - logger.info(f"Loaded configuration from: {yaml_files[0]}") - if config is None: - config = {"run1": {}} - except Exception as e: - logger.error(f"Failed to load config file: {e}") - config = {"run1": {}} - - return config - - def str2bool(self, value): - """ - Convert a string or boolean input into 
a boolean value. - - Accepts common string representations of boolean values such as: - - True values: "true", "t", "yes", "1" - - False values: "false", "f", "no", "0" - - If the input is already a boolean, it is returned as-is. - Raises: - argparse.ArgumentTypeError: If the input cannot be interpreted as a boolean. - - Args: - value (str or bool): The input value to convert. - - Returns: - bool: The corresponding boolean value. - """ - if isinstance(value, bool): - return value - value = value.lower() - if value in {"true", "t", "yes", "1"}: - return True - elif value in {"false", "f", "no", "0"}: - return False - else: - raise argparse.ArgumentTypeError("Boolean value expected (True/False).") - - def setup_argparse(self): - """Setup argument parsing dynamically based on arg_map.""" - parser = argparse.ArgumentParser( - description="CodeEntropy: Entropy calculation with MCC method." - ) - - for arg, properties in self.arg_map.items(): - help_text = properties.get("help", "") - default = properties.get("default", None) - - if properties.get("type") == bool: - parser.add_argument( - f"--{arg}", - type=self.str2bool, - default=default, - help=f"{help_text} (default: {default})", - ) - else: - kwargs = {k: v for k, v in properties.items() if k != "help"} - parser.add_argument(f"--{arg}", **kwargs, help=help_text) - - return parser - - def merge_configs(self, args, run_config): - """Merge CLI arguments with YAML configuration and adjust logging level.""" - if run_config is None: - run_config = {} - - if not isinstance(run_config, dict): - raise TypeError("run_config must be a dictionary or None.") - - # Convert argparse Namespace to dictionary - args_dict = vars(args) - - # Reconstruct parser and check which arguments were explicitly provided via CLI - parser = self.setup_argparse() - default_args = parser.parse_args([]) - default_dict = vars(default_args) - - cli_provided_args = { - key for key, value in args_dict.items() if value != default_dict.get(key) - } - - # 
Step 1: Apply YAML values if CLI didn't explicitly set the argument - for key, yaml_value in run_config.items(): - if yaml_value is not None and key not in cli_provided_args: - logger.debug(f"Using YAML value for {key}: {yaml_value}") - setattr(args, key, yaml_value) - - # Step 2: Ensure all arguments have at least their default values - for key, params in self.arg_map.items(): - if getattr(args, key, None) is None: - setattr(args, key, params.get("default")) - - # Step 3: Ensure CLI arguments always take precedence - for key in self.arg_map.keys(): - cli_value = args_dict.get(key) - if cli_value is not None: - run_config[key] = cli_value - - # Adjust logging level based on 'verbose' flag - if getattr(args, "verbose", False): - logger.setLevel(logging.DEBUG) - for handler in logger.handlers: - handler.setLevel(logging.DEBUG) - logger.debug("Verbose mode enabled. Logger set to DEBUG level.") - else: - logger.setLevel(logging.INFO) - for handler in logger.handlers: - handler.setLevel(logging.INFO) - - return args - - def input_parameters_validation(self, u, args): - """Check the validity of the user inputs against sensible values""" - - self._check_input_start(u, args) - self._check_input_end(u, args) - self._check_input_step(args) - self._check_input_bin_width(args) - self._check_input_temperature(args) - self._check_input_force_partitioning(args) - - def _check_input_start(self, u, args): - """Check that the input does not exceed the length of the trajectory.""" - if args.start > len(u.trajectory): - raise ValueError( - f"Invalid 'start' value: {args.start}. It exceeds the trajectory length" - " of {len(u.trajectory)}." - ) - - def _check_input_end(self, u, args): - """Check that the end index does not exceed the trajectory length.""" - if args.end > len(u.trajectory): - raise ValueError( - f"Invalid 'end' value: {args.end}. It exceeds the trajectory length of" - " {len(u.trajectory)}." 
- ) - - def _check_input_step(self, args): - """Check that the step value is non-negative.""" - if args.step < 0: - logger.warning( - f"Negative 'step' value provided: {args.step}. This may lead to" - " unexpected behavior." - ) - - def _check_input_bin_width(self, args): - """Check that the bin width is within the valid range [0, 360].""" - if args.bin_width < 0 or args.bin_width > 360: - raise ValueError( - f"Invalid 'bin_width': {args.bin_width}. It must be between 0 and 360" - " degrees." - ) - - def _check_input_temperature(self, args): - """Check that the temperature is non-negative.""" - if args.temperature < 0: - raise ValueError( - f"Invalid 'temperature': {args.temperature}. Temperature cannot be" - " below 0." - ) - - def _check_input_force_partitioning(self, args): - """Warn if force partitioning is not set to the default value.""" - default_value = arg_map["force_partitioning"]["default"] - if args.force_partitioning != default_value: - logger.warning( - f"'force_partitioning' is set to {args.force_partitioning}," - f" which differs from the default {default_value}." - ) diff --git a/CodeEntropy/config/argparse.py b/CodeEntropy/config/argparse.py new file mode 100644 index 00000000..ec737301 --- /dev/null +++ b/CodeEntropy/config/argparse.py @@ -0,0 +1,436 @@ +"""Configuration and CLI argument management for CodeEntropy. + +This module provides: + +1) A declarative argument specification (`ARG_SPECS`) used to build an + ``argparse.ArgumentParser``. +2) A `ConfigResolver` that: + - loads YAML configuration (if present), + - merges YAML values with CLI values (CLI wins), + - adjusts logging verbosity, + - validates a subset of runtime inputs against the trajectory. + +Notes: +- Boolean arguments are parsed via `str2bool` to support YAML/CLI interop and + common string forms like "true"/"false". 
+""" + +from __future__ import annotations + +import argparse +import glob +import logging +import os +from dataclasses import dataclass +from typing import Any, Dict, Optional, Set + +import yaml + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ArgSpec: + """Argument specification used to build an argparse parser. + + Attributes: + help: Help text shown in CLI usage. + default: Default value if not provided via CLI or YAML. + type: Python type for parsing (e.g., int, float, str, bool). If bool, + `ConfigResolver.str2bool` will be used. + action: Optional argparse action (e.g., "store_true"). + nargs: Optional nargs spec (e.g., "+"). + """ + + help: str + default: Any = None + type: Any = None + action: Optional[str] = None + nargs: Optional[str] = None + + +ARG_SPECS: Dict[str, ArgSpec] = { + "top_traj_file": ArgSpec( + type=str, + nargs="+", + help="Path to structure/topology file followed by trajectory file", + ), + "force_file": ArgSpec( + type=str, + default=None, + help="Optional path to force file if forces are not in trajectory file", + ), + "file_format": ArgSpec( + type=str, + default=None, + help="String for file format as recognised by MDAnalysis", + ), + "kcal_force_units": ArgSpec( + type=bool, + default=False, + help="Set this to True if you have a separate force file with kcal units.", + ), + "selection_string": ArgSpec( + type=str, + default="all", + help="Selection string for CodeEntropy", + ), + "start": ArgSpec( + type=int, + default=0, + help="Start analysing the trajectory from this frame index", + ), + "end": ArgSpec( + type=int, + default=-1, + help=( + "Stop analysing the trajectory at this frame index. This is " + "the frame index of the last frame to be included, so for example " + "if start=0 and end=500 there would be 501 frames analysed. The " + "default -1 will include the last frame." 
+ ), + ), + "step": ArgSpec( + type=int, + default=1, + help="Interval between two consecutive frames to be read index", + ), + "bin_width": ArgSpec( + type=int, + default=30, + help="Bin width in degrees for making the histogram", + ), + "temperature": ArgSpec( + type=float, + default=298.0, + help="Temperature for entropy calculation (K)", + ), + "verbose": ArgSpec( + action="store_true", + help="Enable verbose output", + ), + "output_file": ArgSpec( + type=str, + default="output_file.json", + help=( + "Name of the output file to write results to (filename only). Defaults " + "to output_file.json" + ), + ), + "force_partitioning": ArgSpec( + type=float, + default=0.5, + help="Force partitioning", + ), + "water_entropy": ArgSpec( + type=bool, + default=True, + help="If set to False, disables the calculation of water entropy", + ), + "grouping": ArgSpec( + type=str, + default="molecules", + help="How to group molecules for averaging", + ), + "combined_forcetorque": ArgSpec( + type=bool, + default=True, + help="Use combined force-torque matrix for residue level vibrational entropies", + ), + "customised_axes": ArgSpec( + type=bool, + default=True, + help="Use bonded axes to rotate forces for UA level vibrational entropies", + ), +} + + +class ConfigResolver: + """Load, merge, and validate CodeEntropy configuration. + + This class provides a consistent interface for: + - YAML config discovery/loading + - CLI parser construction + - merging YAML values with CLI values (CLI wins) + - setting logging verbosity + - validating trajectory-related numeric parameters + """ + + def __init__(self, arg_specs: Optional[Dict[str, ArgSpec]] = None) -> None: + """Initialize the manager. + + Args: + arg_specs: Optional override for argument specs. If omitted, uses + `ARG_SPECS`. + """ + self._arg_specs = dict(arg_specs or ARG_SPECS) + + def load_config(self, directory_path: str) -> Dict[str, Any]: + """Load the first YAML config file found in a directory. 
+ + The current behavior matches your existing workflow: + - searches for ``*.yaml`` in `directory_path`, + - loads the first match, + - returns ``{"run1": {}}`` if none found or file is empty/invalid. + + Args: + directory_path: Directory to search for YAML files. + + Returns: + A configuration dictionary. + """ + yaml_files = glob.glob(os.path.join(directory_path, "*.yaml")) + if not yaml_files: + return {"run1": {}} + + config_path = yaml_files[0] + try: + with open(config_path, "r", encoding="utf-8") as file: + config = yaml.safe_load(file) or {"run1": {}} + logger.info("Loaded configuration from: %s", config_path) + return config + except Exception as exc: + logger.error("Failed to load config file: %s", exc) + return {"run1": {}} + + @staticmethod + def str2bool(value: Any) -> bool: + """Convert a string or boolean input into a boolean. + + Accepts common string representations: + - True values: "true", "t", "yes", "1" + - False values: "false", "f", "no", "0" + + If the input is already a boolean, it is returned as-is. + + Args: + value: Input value to convert. + + Returns: + The corresponding boolean. + + Raises: + argparse.ArgumentTypeError: If the input cannot be interpreted as a boolean. + """ + if isinstance(value, bool): + return value + if not isinstance(value, str): + raise argparse.ArgumentTypeError("Boolean value expected (True/False).") + + lowered = value.lower() + if lowered in {"true", "t", "yes", "1"}: + return True + if lowered in {"false", "f", "no", "0"}: + return False + raise argparse.ArgumentTypeError("Boolean value expected (True/False).") + + def build_parser(self) -> argparse.ArgumentParser: + """Build an ArgumentParser from argument specs. + + Returns: + An argparse.ArgumentParser configured with all supported flags. + """ + parser = argparse.ArgumentParser( + description="CodeEntropy: Entropy calculation with MCC method." 
+ ) + + for name, spec in self._arg_specs.items(): + arg_name = f"--{name}" + + if spec.action is not None: + parser.add_argument(arg_name, action=spec.action, help=spec.help) + continue + + if spec.type is bool: + parser.add_argument( + arg_name, + type=self.str2bool, + default=spec.default, + help=f"{spec.help} (default: {spec.default})", + ) + continue + + kwargs: Dict[str, Any] = {} + if spec.type is not None: + kwargs["type"] = spec.type + if spec.default is not None: + kwargs["default"] = spec.default + if spec.nargs is not None: + kwargs["nargs"] = spec.nargs + + parser.add_argument(arg_name, help=spec.help, **kwargs) + + return parser + + def resolve( + self, args: argparse.Namespace, run_config: Optional[Dict[str, Any]] + ) -> argparse.Namespace: + """Merge CLI arguments with YAML configuration and adjust logging level. + + Merge rule: + - CLI explicitly-provided values take precedence. + - YAML values fill in missing values. + - Defaults fill in anything still unset. + + Args: + args: Parsed CLI arguments. + run_config: Dict of YAML values for a specific run, or None. + + Returns: + The mutated argparse.Namespace with merged values. + + Raises: + TypeError: If `run_config` is not a dict or None. + """ + if run_config is None: + run_config = {} + if not isinstance(run_config, dict): + raise TypeError("run_config must be a dictionary or None.") + + args_dict = vars(args) + + parser = self.build_parser() + default_args = parser.parse_args([]) + default_dict = vars(default_args) + + cli_provided = self._detect_cli_overrides(args_dict, default_dict) + + self._apply_yaml_defaults(args, run_config, cli_provided) + self._ensure_defaults(args) + self._apply_logging_level(bool(getattr(args, "verbose", False))) + + return args + + @staticmethod + def _detect_cli_overrides( + args_dict: Dict[str, Any], default_dict: Dict[str, Any] + ) -> Set[str]: + """Detect which args were explicitly overridden in the CLI. + + Args: + args_dict: Parsed arg values. 
+ default_dict: Parser defaults. + + Returns: + Set of argument names that differ from defaults. + """ + return {k for k, v in args_dict.items() if v != default_dict.get(k)} + + def _apply_yaml_defaults( + self, + args: argparse.Namespace, + run_config: Dict[str, Any], + cli_provided: Set[str], + ) -> None: + """Apply YAML values onto args for keys not provided by CLI. + + Args: + args: Parsed CLI arguments (mutated in-place). + run_config: YAML dict for this run. + cli_provided: Keys explicitly set via CLI. + """ + for key, yaml_value in run_config.items(): + if yaml_value is None or key in cli_provided: + continue + if key in self._arg_specs: + logger.debug("Using YAML value for %s: %s", key, yaml_value) + setattr(args, key, yaml_value) + + def _ensure_defaults(self, args: argparse.Namespace) -> None: + """Ensure all known args have defaults if still unset. + + Args: + args: Parsed arg namespace (mutated in-place). + """ + for key, spec in self._arg_specs.items(): + if getattr(args, key, None) is None: + setattr(args, key, spec.default) + + @staticmethod + def _apply_logging_level(verbose: bool) -> None: + """Adjust logging levels for this module's logger and its handlers. + + Args: + verbose: Whether to enable DEBUG logging. + """ + level = logging.DEBUG if verbose else logging.INFO + logger.setLevel(level) + for handler in logger.handlers: + handler.setLevel(level) + if verbose: + logger.debug("Verbose mode enabled. Logger set to DEBUG level.") + + def validate_inputs(self, u: Any, args: argparse.Namespace) -> None: + """Validate user inputs against sensible runtime constraints. + + Args: + u: MDAnalysis universe (or compatible) with a `trajectory`. + args: Parsed/merged arguments. + + Raises: + ValueError: If a parameter is invalid. 
+ """ + self._check_input_start(u, args) + self._check_input_end(u, args) + self._check_input_step(args) + self._check_input_bin_width(args) + self._check_input_temperature(args) + self._check_input_force_partitioning(args) + + @staticmethod + def _check_input_start(u: Any, args: argparse.Namespace) -> None: + """Check that the start index does not exceed the trajectory length.""" + traj_len = len(u.trajectory) + if args.start > traj_len: + raise ValueError( + f"Invalid 'start' value: {args.start}. It exceeds the trajectory " + f"length of {traj_len}." + ) + + @staticmethod + def _check_input_end(u: Any, args: argparse.Namespace) -> None: + """Check that the end index does not exceed the trajectory length.""" + traj_len = len(u.trajectory) + if args.end > traj_len: + raise ValueError( + f"Invalid 'end' value: {args.end}. It exceeds the trajectory length of " + f"{traj_len}." + ) + + @staticmethod + def _check_input_step(args: argparse.Namespace) -> None: + """Warn if the step value is negative.""" + if args.step < 0: + logger.warning( + "Negative 'step' value provided: %s. This may lead to unexpected " + "behavior.", + args.step, + ) + + @staticmethod + def _check_input_bin_width(args: argparse.Namespace) -> None: + """Check that the bin width is within the valid range [0, 360].""" + if args.bin_width < 0 or args.bin_width > 360: + raise ValueError( + f"Invalid 'bin_width': {args.bin_width}. It must be between 0 and 360 " + f"degrees." + ) + + @staticmethod + def _check_input_temperature(args: argparse.Namespace) -> None: + """Check that the temperature is non-negative.""" + if args.temperature < 0: + raise ValueError( + f"Invalid 'temperature': {args.temperature}. Temperature cannot be " + f"below 0." 
+ ) + + def _check_input_force_partitioning(self, args: argparse.Namespace) -> None: + """Warn if force partitioning is not set to the default value.""" + default_value = self._arg_specs["force_partitioning"].default + if args.force_partitioning != default_value: + logger.warning( + "'force_partitioning' is set to %s, which differs from the default %s.", + args.force_partitioning, + default_value, + ) diff --git a/CodeEntropy/config/data_logger.py b/CodeEntropy/config/data_logger.py deleted file mode 100644 index 4345e115..00000000 --- a/CodeEntropy/config/data_logger.py +++ /dev/null @@ -1,109 +0,0 @@ -import json -import logging -import re - -import numpy as np -from rich.console import Console -from rich.table import Table - -from CodeEntropy.config.logging_config import LoggingConfig - -# Set up logger -logger = logging.getLogger(__name__) -console = LoggingConfig.get_console() - - -class DataLogger: - def __init__(self, console=None): - self.console = console or Console() - self.molecule_data = [] - self.residue_data = [] - self.group_labels = {} - - def save_dataframes_as_json(self, molecule_df, residue_df, output_file): - """Save multiple DataFrames into a single JSON file with separate keys""" - data = { - "molecule_data": molecule_df.to_dict(orient="records"), - "residue_data": residue_df.to_dict(orient="records"), - } - - # Write JSON data to file - with open(output_file, "w") as out: - json.dump(data, out, indent=4) - - def clean_residue_name(self, resname): - """Ensures residue names are stripped and cleaned before being stored""" - return re.sub(r"[-–—]", "", str(resname)) - - def add_results_data(self, group_id, level, entropy_type, value): - """Add data for molecule-level entries""" - self.molecule_data.append((group_id, level, entropy_type, value)) - - def add_residue_data( - self, group_id, resname, level, entropy_type, frame_count, value - ): - """Add data for residue-level entries""" - resname = self.clean_residue_name(resname) - if 
isinstance(frame_count, np.ndarray): - frame_count = frame_count.tolist() - self.residue_data.append( - [group_id, resname, level, entropy_type, frame_count, value] - ) - - def add_group_label(self, group_id, label, residue_count=None, atom_count=None): - """Store a mapping from group ID to a descriptive label and metadata""" - self.group_labels[group_id] = { - "label": label, - "residue_count": residue_count, - "atom_count": atom_count, - } - - def log_tables(self): - """Display rich tables in terminal""" - - if self.molecule_data: - table = Table( - title="Molecule Entropy Results", show_lines=True, expand=True - ) - table.add_column("Group ID", justify="center", style="bold cyan") - table.add_column("Level", justify="center", style="magenta") - table.add_column("Type", justify="center", style="green") - table.add_column("Result (J/mol/K)", justify="center", style="yellow") - - for row in self.molecule_data: - table.add_row(*[str(cell) for cell in row]) - - console.print(table) - - if self.residue_data: - table = Table(title="Residue Entropy Results", show_lines=True, expand=True) - table.add_column("Group ID", justify="center", style="bold cyan") - table.add_column("Residue Name", justify="center", style="cyan") - table.add_column("Level", justify="center", style="magenta") - table.add_column("Type", justify="center", style="green") - table.add_column("Count", justify="center", style="green") - table.add_column("Result (J/mol/K)", justify="center", style="yellow") - - for row in self.residue_data: - table.add_row(*[str(cell) for cell in row]) - - console.print(table) - - if self.group_labels: - label_table = Table( - title="Group ID to Residue Label Mapping", show_lines=True, expand=True - ) - label_table.add_column("Group ID", justify="center", style="bold cyan") - label_table.add_column("Residue Label", justify="center", style="green") - label_table.add_column("Residue Count", justify="center", style="magenta") - label_table.add_column("Atom Count", 
justify="center", style="yellow") - - for group_id, info in self.group_labels.items(): - label_table.add_row( - str(group_id), - info["label"], - str(info.get("residue_count", "")), - str(info.get("atom_count", "")), - ) - - console.print(label_table) diff --git a/CodeEntropy/config/logging_config.py b/CodeEntropy/config/logging_config.py deleted file mode 100644 index aea5f893..00000000 --- a/CodeEntropy/config/logging_config.py +++ /dev/null @@ -1,164 +0,0 @@ -import logging -import os - -from rich.console import Console -from rich.logging import RichHandler - - -class ErrorFilter(logging.Filter): - """ - Logging filter that only allows records with level ERROR or higher. - - This ensures that the attached handler only processes error and critical logs, - filtering out all lower level messages such as DEBUG and INFO. - """ - - def filter(self, record): - return record.levelno >= logging.ERROR - - -class LoggingConfig: - """ - Configures logging with Rich console output and multiple file handlers. - Provides a single Rich Console instance that records all output for later export. - - Attributes: - _console (Console): Shared Rich Console instance with output recording enabled. - log_dir (str): Directory path to store log files. - level (int): Logging level (e.g., logging.INFO). - console (Console): The Rich Console instance used for output and logging. - handlers (dict): Dictionary of logging handlers for console and files. - """ - - _console = None # Shared Console with recording enabled - - @classmethod - def get_console(cls): - """ - Get or create a singleton Rich Console instance with recording enabled. - - Returns: - Console: Rich Console instance that prints to terminal and records output. - """ - if cls._console is None: - # Create console that records output for later export - cls._console = Console(record=True) - return cls._console - - def __init__(self, folder, level=logging.INFO): - """ - Initialize the logging configuration. 
- - Args: - folder (str): Base folder where 'logs' directory will be created. - level (int): Logging level (default: logging.INFO). - """ - self.log_dir = os.path.join(folder, "logs") - os.makedirs(self.log_dir, exist_ok=True) - self.level = level - - # Use the single recorded console instance - self.console = self.get_console() - - self._setup_handlers() - - def _setup_handlers(self): - paths = { - "main": os.path.join(self.log_dir, "program.log"), - "error": os.path.join(self.log_dir, "program.err"), - "command": os.path.join(self.log_dir, "program.com"), - "mdanalysis": os.path.join(self.log_dir, "mdanalysis.log"), - } - - formatter = logging.Formatter( - "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s" - ) - - self.handlers = { - "rich": RichHandler( - console=self.console, - markup=True, - rich_tracebacks=True, - show_time=True, - show_level=True, - show_path=False, - ), - "main": logging.FileHandler(paths["main"]), - "error": logging.FileHandler(paths["error"]), - "command": logging.FileHandler(paths["command"]), - "mdanalysis": logging.FileHandler(paths["mdanalysis"]), - } - - self.handlers["rich"].setLevel(logging.INFO) - self.handlers["main"].setLevel(self.level) - self.handlers["error"].setLevel(logging.ERROR) - self.handlers["command"].setLevel(logging.INFO) - self.handlers["mdanalysis"].setLevel(self.level) - - for name, handler in self.handlers.items(): - if name != "rich": - handler.setFormatter(formatter) - - # Add filter to error handler to ensure only ERROR and above are logged - self.handlers["error"].addFilter(ErrorFilter()) - - def setup_logging(self): - """ - Configure the root logger and specific loggers with the prepared handlers. - - Returns: - logging.Logger: Logger instance for the current module (__name__). 
- """ - root = logging.getLogger() - root.setLevel(self.level) - root.addHandler(self.handlers["rich"]) - root.addHandler(self.handlers["main"]) - root.addHandler(self.handlers["error"]) - - logging.getLogger("commands").addHandler(self.handlers["command"]) - logging.getLogger("commands").setLevel(logging.INFO) - logging.getLogger("commands").propagate = False - - logging.getLogger("MDAnalysis").addHandler(self.handlers["mdanalysis"]) - logging.getLogger("MDAnalysis").setLevel(self.level) - logging.getLogger("MDAnalysis").propagate = False - - return logging.getLogger(__name__) - - def update_logging_level(self, log_level): - """ - Update the logging level for the root logger and specific sub-loggers. - - Args: - log_level (int): New logging level (e.g., logging.DEBUG, logging.WARNING). - """ - root_logger = logging.getLogger() - root_logger.setLevel(log_level) - for handler in root_logger.handlers: - if isinstance(handler, logging.FileHandler): - handler.setLevel(log_level) - else: - # Keep RichHandler at INFO or higher for nicer console output - handler.setLevel(logging.INFO) - - for logger_name in ["commands", "MDAnalysis"]: - logger = logging.getLogger(logger_name) - logger.setLevel(log_level) - for handler in logger.handlers: - if isinstance(handler, logging.FileHandler): - handler.setLevel(log_level) - else: - handler.setLevel(logging.INFO) - - def save_console_log(self, filename="program_output.txt"): - """ - Save all recorded console output to a text file. - - Args: - filename (str): Name of the file to write console output to. - Defaults to 'program_output.txt' in the logs directory. 
- """ - output_path = os.path.join(self.log_dir, filename) - os.makedirs(self.log_dir, exist_ok=True) - with open(output_path, "w", encoding="utf-8") as f: - f.write(self.console.export_text()) diff --git a/CodeEntropy/config/runtime.py b/CodeEntropy/config/runtime.py new file mode 100644 index 00000000..f807261e --- /dev/null +++ b/CodeEntropy/config/runtime.py @@ -0,0 +1,390 @@ +"""Run orchestration for CodeEntropy. + +This module provides the CodeEntropyRunner, which is responsible for: +- Creating a new job folder for each run +- Loading YAML configuration and merging it with CLI arguments +- Setting up logging and displaying a Rich splash screen +- Building the MDAnalysis Universe (including optional force merging) +- Wiring dependencies and executing the EntropyWorkflow workflow +- Providing physical-constants helpers used by entropy calculations + +Notes on design: +- CodeEntropyRunner focuses on orchestration and simple utilities only. +- Computational logic lives in EntropyWorkflow and the level/entropy DAG modules. 
+""" + +from __future__ import annotations + +import logging +import os +import pickle +from typing import Any, Dict, Optional + +import MDAnalysis as mda +import requests +import yaml +from art import text2art +from rich.align import Align +from rich.console import Group +from rich.padding import Padding +from rich.panel import Panel +from rich.rule import Rule +from rich.table import Table +from rich.text import Text + +from CodeEntropy.config.argparse import ConfigResolver +from CodeEntropy.core.logging import LoggingConfig +from CodeEntropy.entropy.workflow import EntropyWorkflow +from CodeEntropy.levels.dihedrals import ConformationStateBuilder +from CodeEntropy.levels.mda import UniverseOperations +from CodeEntropy.molecules.grouping import MoleculeGrouper +from CodeEntropy.results.reporter import ResultsReporter + +logger = logging.getLogger(__name__) +console = LoggingConfig.get_console() + + +class CodeEntropyRunner: + """Coordinate setup and execution of entropy analysis runs. + + Responsibilities: + - Bootstrapping: job folder, logging, splash screen + - Configuration: YAML loading + CLI parsing + merge and validation + - Universe creation: MDAnalysis Universe (optionally merging forces) + - Dependency wiring and execution: EntropyWorkflow + - Utilities used by downstream modules: constants and unit conversions + + Attributes: + folder: Working directory for the current job (e.g., job001). + """ + + _N_AVOGADRO = 6.0221415e23 + _DEF_TEMPER = 298 + + def __init__(self, folder: str) -> None: + """Initialize a CodeEntropyRunner for a given working folder. + + This sets up configuration helpers, data logging, and logging configuration. + It also defines physical constants used in entropy calculations. + + Args: + folder: Job folder path where logs and outputs will be written. 
+ """ + self.folder = folder + self._config_manager = ConfigResolver() + self._reporter = ResultsReporter() + self._logging_config = LoggingConfig(folder) + + @property + def N_AVOGADRO(self) -> float: + """Return Avogadro's number used in entropy calculations.""" + return self._N_AVOGADRO + + @property + def DEF_TEMPER(self) -> float: + """Return the default temperature (K) used in the analysis.""" + return self._DEF_TEMPER + + @staticmethod + def create_job_folder() -> str: + """Create a new job folder (job###) in the current working directory. + + The method searches existing folders that start with "job" and picks the next + integer suffix. If none exist, it creates job001. + + Returns: + The full path to the newly created job folder. + """ + current_dir = os.getcwd() + existing_folders = [f for f in os.listdir(current_dir) if f.startswith("job")] + + job_numbers = [] + for folder in existing_folders: + try: + job_numbers.append(int(folder[3:])) + except ValueError: + continue + + next_job_number = 1 if not job_numbers else max(job_numbers) + 1 + new_job_folder = f"job{next_job_number:03d}" + new_folder_path = os.path.join(current_dir, new_job_folder) + os.makedirs(new_folder_path, exist_ok=True) + + return new_folder_path + + def load_citation_data(self) -> Optional[Dict[str, Any]]: + """Load CITATION.cff from GitHub. + + If the request fails (offline, blocked, etc.), returns None. + + Returns: + Parsed CITATION.cff content as a dict, or None if unavailable. 
+ """ + url = ( + "https://raw.githubusercontent.com/CCPBioSim/" + "CodeEntropy/refs/heads/main/CITATION.cff" + ) + try: + response = requests.get(url, timeout=10) + response.raise_for_status() + return yaml.safe_load(response.text) + except requests.exceptions.RequestException: + return None + + def show_splash(self) -> None: + """Render a Rich splash screen with optional citation metadata.""" + citation = self.load_citation_data() + + if citation: + ascii_title = text2art(citation.get("title", "CodeEntropy")) + ascii_render = Align.center(Text(ascii_title, style="bold white")) + + version = citation.get("version", "?") + release_date = citation.get("date-released", "?") + url = citation.get("url", citation.get("repository-code", "")) + + version_text = Align.center( + Text(f"Version {version} | Released {release_date}", style="green") + ) + url_text = Align.center(Text(url, style="blue underline")) + + abstract = citation.get("abstract", "No description available.") + description_title = Align.center( + Text("Description", style="bold magenta underline") + ) + description_body = Align.center( + Padding(Text(abstract, style="white", justify="left"), (0, 4)) + ) + + contributors_title = Align.center( + Text("Contributors", style="bold magenta underline") + ) + + author_table = Table( + show_header=True, header_style="bold yellow", box=None, pad_edge=False + ) + author_table.add_column("Name", style="bold", justify="center") + author_table.add_column("Affiliation", justify="center") + + for author in citation.get("authors", []): + name = ( + f"{author.get('given-names', '')} {author.get('family-names', '')}" + ).strip() + affiliation = author.get("affiliation", "") + author_table.add_row(name, affiliation) + + contributors_table = Align.center(Padding(author_table, (0, 4))) + + splash_content = Group( + ascii_render, + Rule(style="cyan"), + version_text, + url_text, + Text(), + description_title, + description_body, + Text(), + contributors_title, + 
contributors_table, + ) + else: + ascii_title = text2art("CodeEntropy") + ascii_render = Align.center(Text(ascii_title, style="bold white")) + splash_content = Group(ascii_render) + + splash_panel = Panel( + splash_content, + title="[bold bright_cyan]Welcome to CodeEntropy", + title_align="center", + border_style="bright_cyan", + padding=(1, 4), + expand=True, + ) + + console.print(splash_panel) + + def print_args_table(self, args: Any) -> None: + """Print a Rich table of the run configuration arguments. + + Args: + args: argparse Namespace or object with attributes for configuration. + """ + table = Table(title="Run Configuration", expand=True) + table.add_column("Argument", style="cyan", no_wrap=True) + table.add_column("Value", style="magenta") + + for arg in vars(args): + table.add_row(arg, str(getattr(args, arg))) + + console.print(table) + + def run_entropy_workflow(self) -> None: + """Run the end-to-end entropy workflow. + + This method: + - Sets up logging and prints the splash screen + - Loads YAML config from CWD and parses CLI args + - Merges args with YAML per-run config + - Builds the MDAnalysis Universe (with optional force merging) + - Validates user parameters + - Constructs dependencies and executes EntropyWorkflow + - Saves recorded console output to a log file + - Logs run arguments if an error occurs to aid debugging + + Raises: + RuntimeError: If the workflow fails for any reason. The original + exception is chained to preserve traceback information. 
+ """ + args = None + try: + run_logger = self._logging_config.configure() + self.show_splash() + + current_directory = os.getcwd() + config = self._config_manager.load_config(current_directory) + + parser = self._config_manager.build_parser() + args, _ = parser.parse_known_args() + args.output_file = os.path.join(self.folder, args.output_file) + + for run_name, run_config in config.items(): + if not isinstance(run_config, dict): + run_logger.warning( + "Run configuration for %s is not a dictionary.", run_name + ) + continue + + args = self._config_manager.resolve(args, run_config) + + log_level = ( + logging.DEBUG if getattr(args, "verbose", False) else logging.INFO + ) + self._logging_config.set_level(log_level) + + command = " ".join(os.sys.argv) + logging.getLogger("commands").info(command) + + self._validate_required_args(args) + + self.print_args_table(args) + + universe_operations = UniverseOperations() + u = self._build_universe(args, universe_operations) + + self._config_manager.validate_inputs(u, args) + + group_molecules = MoleculeGrouper() + dihedral_analysis = ConformationStateBuilder( + universe_operations=universe_operations + ) + + entropy_manager = EntropyWorkflow( + run_manager=self, + args=args, + universe=u, + reporter=self._reporter, + group_molecules=group_molecules, + dihedral_analysis=dihedral_analysis, + universe_operations=universe_operations, + ) + entropy_manager.execute() + + self._logging_config.export_console() + + except Exception as exc: + if args is not None: + try: + logger.error("Run arguments at failure: %s", vars(args)) + except Exception: + logger.error("Run arguments at failure could not be serialized") + + raise RuntimeError("CodeEntropyRunner encountered an error") from exc + + @staticmethod + def _validate_required_args(args: Any) -> None: + """Validate presence of required arguments. + + Args: + args: argparse Namespace or similar. + + Raises: + ValueError: If required arguments are missing. 
+        """
+        if not getattr(args, "top_traj_file", None):
+            raise ValueError("Missing 'top_traj_file' argument.")
+        if not getattr(args, "selection_string", None):
+            raise ValueError("Missing 'selection_string' argument.")
+
+    @staticmethod
+    def _build_universe(
+        args: Any, universe_operations: UniverseOperations
+    ) -> mda.Universe:
+        """Create an MDAnalysis Universe from args.
+
+        Args:
+            args: Parsed arguments containing topology/trajectory and force settings.
+            universe_operations: UniverseOperations utility instance.
+
+        Returns:
+            An MDAnalysis Universe ready for analysis.
+        """
+        tprfile = args.top_traj_file[0]
+        trrfile = args.top_traj_file[1:]
+        forcefile = args.force_file
+        fileformat = args.file_format
+        kcal_units = args.kcal_force_units
+
+        if forcefile is None:
+            logger.debug("Loading Universe with %s and %s", tprfile, trrfile)
+            return mda.Universe(tprfile, trrfile, format=fileformat)
+
+        return universe_operations.merge_forces(
+            tprfile, trrfile, forcefile, fileformat, kcal_units
+        )
+
+    def write_universe(self, u: mda.Universe, name: str = "default") -> str:
+        """Write a universe to disk as a pickle.
+
+        Parameters
+        ----------
+        u : MDAnalysis.Universe
+            A Universe object with all topology, dihedrals, coordinates and force
+            information
+        name : str, Optional. default: 'default'
+            The name of the file, saved with the suffix .pkl
+
+        Returns
+        -------
+        name : str
+            filename of saved universe
+        """
+        filename = f"{name}.pkl"
+        with open(filename, "wb") as f:
+            pickle.dump(u, f)
+        return name
+
+    def read_universe(self, path: str) -> mda.Universe:
+        """Read a universe from disk (pickle).
+
+        Parameters
+        ----------
+        path : str
+            The path to file.
+
+        Returns
+        -------
+        u : MDAnalysis.Universe
+            A Universe object with all topology, dihedrals, coordinates and force
+            information.
+ """ + with open(path, "rb") as f: + return pickle.load(f) + + def change_lambda_units(self, arg_lambdas: Any) -> Any: + """Unit of lambdas : kJ2 mol-2 A-2 amu-1 + change units of lambda to J/s2""" + return arg_lambdas * 1e29 / self.N_AVOGADRO + + def get_KT2J(self, arg_temper: float) -> float: + """A temperature dependent KT to Joule conversion""" + return 4.11e-21 * arg_temper / self.DEF_TEMPER diff --git a/tests/data/__init__.py b/CodeEntropy/core/__init__.py similarity index 100% rename from tests/data/__init__.py rename to CodeEntropy/core/__init__.py diff --git a/CodeEntropy/core/logging.py b/CodeEntropy/core/logging.py new file mode 100644 index 00000000..4a0bfac4 --- /dev/null +++ b/CodeEntropy/core/logging.py @@ -0,0 +1,220 @@ +"""Logging configuration utilities for CodeEntropy. + +This module configures consistent logging across the project with: + +- Rich console output (with tracebacks) for human-readable terminal logs +- File handlers for main logs, error-only logs, command logs, and MDAnalysis logs +- A singleton Rich Console instance with recording enabled, so terminal output + can be exported to disk at the end of a run + +The design keeps responsibilities separated: +- ErrorFilter: filter logic only +- LoggingConfig: handler creation, logger wiring, and exporting recorded output +""" + +from __future__ import annotations + +import logging +import os +from typing import Dict, Optional + +from rich.console import Console +from rich.logging import RichHandler + + +class ErrorFilter(logging.Filter): + """Allow only ERROR and CRITICAL log records. + + This filter is intended for the error file handler so that the file contains + only high-severity records and does not include DEBUG/INFO/WARNING output. + """ + + def filter(self, record: logging.LogRecord) -> bool: + """Return True if the record should be logged. + + Args: + record: The log record being evaluated. + + Returns: + True if record.levelno >= logging.ERROR, otherwise False. 
+ """ + return record.levelno >= logging.ERROR + + +class LoggingConfig: + """Configure project logging with Rich console output and file handlers. + + This class wires a set of handlers onto the root logger and a few named + loggers. It also provides a singleton Rich Console instance with recording + enabled so that all console output can be exported to a text file later. + + Attributes: + log_dir: Directory where log files are written. + level: Base logging level for the root logger and file handlers. + console: Shared Rich Console instance used by RichHandler. + handlers: Mapping of handler name to handler instance. + """ + + _console: Optional[Console] = None + + @classmethod + def get_console(cls) -> Console: + """Get or create the singleton Rich Console with recording enabled. + + Returns: + A Rich Console instance that prints to terminal and records output. + """ + if cls._console is None: + cls._console = Console(record=True) + return cls._console + + def __init__(self, folder: str, level: int = logging.INFO) -> None: + """Initialize logging configuration. + + Args: + folder: Base folder where the 'logs' directory will be created. + level: Logging level for the root logger and most file handlers. 
+ """ + self.log_dir = os.path.join(folder, "logs") + os.makedirs(self.log_dir, exist_ok=True) + + self.level = level + self.console = self.get_console() + self.handlers: Dict[str, logging.Handler] = {} + + self._setup_handlers() + + def _setup_handlers(self) -> None: + """Create handlers and assign formatters/levels/filters.""" + paths = { + "main": os.path.join(self.log_dir, "program.log"), + "error": os.path.join(self.log_dir, "program.err"), + "command": os.path.join(self.log_dir, "program.com"), + "mdanalysis": os.path.join(self.log_dir, "mdanalysis.log"), + } + + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s" + ) + + rich_handler = RichHandler( + console=self.console, + markup=True, + rich_tracebacks=True, + show_time=True, + show_level=True, + show_path=False, + ) + rich_handler.setLevel(logging.INFO) + + main_handler = logging.FileHandler(paths["main"]) + main_handler.setLevel(self.level) + main_handler.setFormatter(formatter) + + error_handler = logging.FileHandler(paths["error"]) + error_handler.setLevel(logging.ERROR) + error_handler.setFormatter(formatter) + error_handler.addFilter(ErrorFilter()) + + command_handler = logging.FileHandler(paths["command"]) + command_handler.setLevel(logging.INFO) + command_handler.setFormatter(formatter) + + mdanalysis_handler = logging.FileHandler(paths["mdanalysis"]) + mdanalysis_handler.setLevel(self.level) + mdanalysis_handler.setFormatter(formatter) + + self.handlers = { + "rich": rich_handler, + "main": main_handler, + "error": error_handler, + "command": command_handler, + "mdanalysis": mdanalysis_handler, + } + + def configure(self) -> logging.Logger: + """Attach configured handlers to the appropriate loggers. 
+ + This method: + - Attaches rich/main/error handlers to the root logger + - Attaches command handler to the 'commands' logger (non-propagating) + - Attaches MDAnalysis handler to the 'MDAnalysis' logger (non-propagating) + + Returns: + A logger for the current module. + """ + root = logging.getLogger() + root.setLevel(self.level) + + self._add_handler_once(root, self.handlers["rich"]) + self._add_handler_once(root, self.handlers["main"]) + self._add_handler_once(root, self.handlers["error"]) + + commands_logger = logging.getLogger("commands") + commands_logger.setLevel(logging.INFO) + commands_logger.propagate = False + self._add_handler_once(commands_logger, self.handlers["command"]) + + mda_logger = logging.getLogger("MDAnalysis") + mda_logger.setLevel(self.level) + mda_logger.propagate = False + self._add_handler_once(mda_logger, self.handlers["mdanalysis"]) + + return logging.getLogger(__name__) + + @staticmethod + def _add_handler_once(logger_obj: logging.Logger, handler: logging.Handler) -> None: + """Attach a handler to a logger only if it isn't already attached. + + Args: + logger_obj: Logger to modify. + handler: Handler to attach. + """ + if handler not in logger_obj.handlers: + logger_obj.addHandler(handler) + + def set_level(self, log_level: int) -> None: + """Update logging levels for root and named loggers. + + Notes: + - FileHandlers are set to the new log_level. + - RichHandler is kept at INFO (or higher) for cleaner console output. + + Args: + log_level: New logging level (e.g., logging.DEBUG). 
+ """ + root_logger = logging.getLogger() + root_logger.setLevel(log_level) + + self._set_logger_handlers_level(root_logger, log_level) + + for logger_name in ("commands", "MDAnalysis"): + named_logger = logging.getLogger(logger_name) + named_logger.setLevel(log_level) + self._set_logger_handlers_level(named_logger, log_level) + + @staticmethod + def _set_logger_handlers_level(logger_obj: logging.Logger, log_level: int) -> None: + """Apply level rules to all handlers on a logger. + + Args: + logger_obj: Logger whose handlers should be updated. + log_level: Target logging level for file handlers. + """ + for handler in logger_obj.handlers: + if isinstance(handler, logging.FileHandler): + handler.setLevel(log_level) + else: + handler.setLevel(logging.INFO) + + def export_console(self, filename: str = "program_output.txt") -> None: + """Save recorded console output to a file. + + Args: + filename: Output filename inside the log directory. + """ + output_path = os.path.join(self.log_dir, filename) + os.makedirs(self.log_dir, exist_ok=True) + + with open(output_path, "w", encoding="utf-8") as f: + f.write(self.console.export_text()) diff --git a/CodeEntropy/dihedral_tools.py b/CodeEntropy/dihedral_tools.py deleted file mode 100644 index 1c5d24df..00000000 --- a/CodeEntropy/dihedral_tools.py +++ /dev/null @@ -1,392 +0,0 @@ -import logging - -import numpy as np -from MDAnalysis.analysis.dihedrals import Dihedral -from rich.progress import ( - BarColumn, - Progress, - SpinnerColumn, - TextColumn, - TimeElapsedColumn, -) - -logger = logging.getLogger(__name__) - - -class DihedralAnalysis: - """ - Functions for finding dihedral angles and analysing them to get the - states needed for the conformational entropy functions. - """ - - def __init__(self, universe_operations=None): - """ - Initialise with placeholders. 
- """ - self._universe_operations = universe_operations - self.data_container = None - self.states_ua = None - self.states_res = None - - def build_conformational_states( - self, - data_container, - levels, - groups, - start, - end, - step, - bin_width, - ): - """ - Build the conformational states descriptors based on dihedral angles - needed for the calculation of the conformational entropy. - """ - number_groups = len(groups) - states_ua = {} - states_res = [None] * number_groups - - total_items = sum( - len(levels[mol_id]) for mols in groups.values() for mol_id in mols - ) - - with Progress( - SpinnerColumn(), - TextColumn("[bold blue]{task.fields[title]}", justify="right"), - BarColumn(), - TextColumn("[progress.percentage]{task.percentage:>3.1f}%"), - TimeElapsedColumn(), - ) as progress: - - task = progress.add_task( - "[green]Building Conformational States...", - total=total_items, - title="Starting...", - ) - - for group_id in groups.keys(): - molecules = groups[group_id] - mol = self._universe_operations.get_molecule_container( - data_container, molecules[0] - ) - num_residues = len(mol.residues) - dihedrals_ua = [[] for _ in range(num_residues)] - peaks_ua = [{} for _ in range(num_residues)] - dihedrals_res = [] - peaks_res = {} - - # Identify dihedral AtomGroups - for level in levels[molecules[0]]: - if level == "united_atom": - for res_id in range(num_residues): - selection1 = mol.residues[res_id].atoms.indices[0] - selection2 = mol.residues[res_id].atoms.indices[-1] - res_container = self._universe_operations.new_U_select_atom( - mol, - f"index {selection1}:" f"{selection2}", - ) - heavy_res = self._universe_operations.new_U_select_atom( - res_container, "prop mass > 1.1" - ) - - dihedrals_ua[res_id] = self._get_dihedrals(heavy_res, level) - - elif level == "residue": - dihedrals_res = self._get_dihedrals(mol, level) - - # Identify peaks - for level in levels[molecules[0]]: - if level == "united_atom": - for res_id in range(num_residues): - if 
len(dihedrals_ua[res_id]) == 0: - # No dihedrals means no histogram or peaks - peaks_ua[res_id] = [] - else: - peaks_ua[res_id] = self._identify_peaks( - data_container, - molecules, - dihedrals_ua[res_id], - bin_width, - start, - end, - step, - ) - - elif level == "residue": - if len(dihedrals_res) == 0: - # No dihedrals means no histogram or peaks - peaks_res = [] - else: - peaks_res = self._identify_peaks( - data_container, - molecules, - dihedrals_res, - bin_width, - start, - end, - step, - ) - - # Assign states for each group - for level in levels[molecules[0]]: - if level == "united_atom": - for res_id in range(num_residues): - key = (group_id, res_id) - if len(dihedrals_ua[res_id]) == 0: - # No conformational states - states_ua[key] = [] - else: - states_ua[key] = self._assign_states( - data_container, - molecules, - dihedrals_ua[res_id], - peaks_ua[res_id], - start, - end, - step, - ) - - elif level == "residue": - if len(dihedrals_res) == 0: - # No conformational states - states_res[group_id] = [] - else: - states_res[group_id] = self._assign_states( - data_container, - molecules, - dihedrals_res, - peaks_res, - start, - end, - step, - ) - - progress.advance(task) - - return states_ua, states_res - - def _get_dihedrals(self, data_container, level): - """ - Define the set of dihedrals for use in the conformational entropy function. - If united atom level, the dihedrals are defined from the heavy atoms - (4 bonded atoms for 1 dihedral). - If residue level, use the bonds between residues to cast dihedrals. - Note: not using improper dihedrals only ones with 4 atoms/residues - in a linear arrangement. 
- - Args: - data_container (MDAnalysis.Universe): system information - level (str): level of the hierarchy (should be residue or polymer) - - Returns: - dihedrals (array): set of dihedrals - """ - # Start with empty array - dihedrals = [] - atom_groups = [] - - # if united atom level, read dihedrals from MDAnalysis universe - if level == "united_atom": - dihedrals = data_container.dihedrals - num_dihedrals = len(dihedrals) - for index in range(num_dihedrals): - atom_groups.append(dihedrals[index].atoms) - - # if residue level, looking for dihedrals involving residues - if level == "residue": - num_residues = len(data_container.residues) - logger.debug(f"Number Residues: {num_residues}") - if num_residues < 4: - logger.debug("no residue level dihedrals") - - else: - # find bonds between residues N-3:N-2 and N-1:N - for residue in range(4, num_residues + 1): - # Using MDAnalysis selection, - # assuming only one covalent bond between neighbouring residues - # TODO not written for branched polymers - atom_string = ( - "resindex " - + str(residue - 4) - + " and bonded resindex " - + str(residue - 3) - ) - atom1 = data_container.select_atoms(atom_string) - - atom_string = ( - "resindex " - + str(residue - 3) - + " and bonded resindex " - + str(residue - 4) - ) - atom2 = data_container.select_atoms(atom_string) - - atom_string = ( - "resindex " - + str(residue - 2) - + " and bonded resindex " - + str(residue - 1) - ) - atom3 = data_container.select_atoms(atom_string) - - atom_string = ( - "resindex " - + str(residue - 1) - + " and bonded resindex " - + str(residue - 2) - ) - atom4 = data_container.select_atoms(atom_string) - - atom_group = atom1 + atom2 + atom3 + atom4 - atom_groups.append(atom_group) - - logger.debug(f"Level: {level}, Dihedrals: {atom_groups}") - - return atom_groups - - def _identify_peaks( - self, - data_container, - molecules, - dihedrals, - bin_width, - start, - end, - step, - ): - """ - Build a histogram of the dihedral data and identify the peaks. 
- This is to give the information needed for the adaptive method - of identifying dihedral states. - """ - peak_values = [] * len(dihedrals) - for dihedral_index in range(len(dihedrals)): - phi = [] - # get the values of the angle for the dihedral - # loop over all molecules in the averaging group - # dihedral angle values have a range from -180 to 180 - for molecule in molecules: - mol = self._universe_operations.get_molecule_container( - data_container, molecule - ) - number_frames = len(mol.trajectory) - dihedral_results = Dihedral(dihedrals).run() - for timestep in range(number_frames): - value = dihedral_results.results.angles[timestep][dihedral_index] - - # We want postive values in range 0 to 360 to make - # the peak assignment. - # works using the fact that dihedrals have circular symetry - # (i.e. -15 degrees = +345 degrees) - if value < 0: - value += 360 - phi.append(value) - - # create a histogram using numpy - number_bins = int(360 / bin_width) - popul, bin_edges = np.histogram(a=phi, bins=number_bins, range=(0, 360)) - bin_value = [ - 0.5 * (bin_edges[i] + bin_edges[i + 1]) for i in range(0, len(popul)) - ] - - # identify "convex turning-points" and populate a list of peaks - # peak : a bin whose neighboring bins have smaller population - # NOTE might have problems if the peak is wide with a flat or - # sawtooth top in which case check you have a sensible bin width - - peaks = [] - for bin_index in range(number_bins): - # if there is no dihedrals in a bin then it cannot be a peak - if popul[bin_index] == 0: - pass - # being careful of the last bin - # (dihedrals have circular symmetry, the histogram does not) - elif ( - bin_index == number_bins - 1 - ): # the -1 is because the index starts with 0 not 1 - if ( - popul[bin_index] >= popul[bin_index - 1] - and popul[bin_index] >= popul[0] - ): - peaks.append(bin_value[bin_index]) - else: - if ( - popul[bin_index] >= popul[bin_index - 1] - and popul[bin_index] >= popul[bin_index + 1] - ): - 
peaks.append(bin_value[bin_index]) - - peak_values.append(peaks) - - logger.debug(f"Dihedral: {dihedral_index}, Peak Values: {peak_values}") - - return peak_values - - def _assign_states( - self, - data_container, - molecules, - dihedrals, - peaks, - start, - end, - step, - ): - """ - Turn the dihedral values into conformations based on the peaks - from the histogram. - Then combine these to form states for each molecule. - """ - states = None - - # get the values of the angle for the dihedral - # dihedral angle values have a range from -180 to 180 - for molecule in molecules: - conformations = [] - mol = self._universe_operations.get_molecule_container( - data_container, molecule - ) - number_frames = len(mol.trajectory) - dihedral_results = Dihedral(dihedrals).run() - for dihedral_index in range(len(dihedrals)): - conformation = [] - for timestep in range(number_frames): - value = dihedral_results.results.angles[timestep][dihedral_index] - - # We want postive values in range 0 to 360 to make - # the peak assignment. - # works using the fact that dihedrals have circular symetry - # (i.e. -15 degrees = +345 degrees) - if value < 0: - value += 360 - - # Find the turning point/peak that the snapshot is closest to. 
- distances = [abs(value - peak) for peak in peaks[dihedral_index]] - conformation.append(np.argmin(distances)) - - logger.debug( - f"Dihedral: {dihedral_index} Conformations: {conformation}" - ) - conformations.append(conformation) - - # for all the dihedrals available concatenate the label of each - # dihedral into the state for that frame - mol_states = [ - state - for state in ( - "".join( - str(int(conformations[d][f])) for d in range(len(dihedrals)) - ) - for f in range(number_frames) - ) - if state - ] - - if states is None: - states = mol_states - else: - states.extend(mol_states) - - logger.debug(f"States: {states}") - - return states diff --git a/CodeEntropy/entropy.py b/CodeEntropy/entropy.py deleted file mode 100644 index 9b7016e2..00000000 --- a/CodeEntropy/entropy.py +++ /dev/null @@ -1,1152 +0,0 @@ -import logging -import math -from collections import defaultdict - -import numpy as np -import pandas as pd -import waterEntropy.recipes.interfacial_solvent as GetSolvent -from numpy import linalg as la -from rich.progress import ( - BarColumn, - Progress, - SpinnerColumn, - TextColumn, - TimeElapsedColumn, -) - -from CodeEntropy.config.logging_config import LoggingConfig - -logger = logging.getLogger(__name__) -console = LoggingConfig.get_console() - - -class EntropyManager: - """ - Manages entropy calculations at multiple molecular levels, based on a - molecular dynamics trajectory. - """ - - def __init__( - self, - run_manager, - args, - universe, - data_logger, - level_manager, - group_molecules, - dihedral_analysis, - universe_operations, - ): - """ - Initializes the EntropyManager with required components. - - Args: - run_manager: Manager for universe and selection operations. - args: Argument namespace containing user parameters. - universe: MDAnalysis universe representing the simulation system. - data_logger: Logger for storing and exporting entropy data. - level_manager: Provides level-specific data such as matrices and dihedrals. 
- group_molecules: includes the grouping functions for averaging over - molecules. - """ - self._run_manager = run_manager - self._args = args - self._universe = universe - self._data_logger = data_logger - self._level_manager = level_manager - self._group_molecules = group_molecules - self._dihedral_analysis = dihedral_analysis - self._universe_operations = universe_operations - self._GAS_CONST = 8.3144598484848 - - def execute(self): - """ - Run the full entropy computation workflow. - - This method orchestrates the entire entropy analysis pipeline, including: - - Handling water entropy if present. - - Initializing molecular structures and levels. - - Building force and torque covariance matrices. - - Computing vibrational and conformational entropies. - - Finalizing and logging results. - """ - # Set up initial information - start, end, step = self._get_trajectory_bounds() - number_frames = self._get_number_frames(start, end, step) - - console.print( - f"Analyzing a total of {number_frames} frames in this calculation." 
- ) - - ve = VibrationalEntropy( - self._run_manager, - self._args, - self._universe, - self._data_logger, - self._level_manager, - self._group_molecules, - self._dihedral_analysis, - self._universe_operations, - ) - ce = ConformationalEntropy( - self._run_manager, - self._args, - self._universe, - self._data_logger, - self._level_manager, - self._group_molecules, - self._dihedral_analysis, - self._universe_operations, - ) - - reduced_atom, number_molecules, levels, groups = self._initialize_molecules() - logger.debug(f"Universe 3: {reduced_atom}") - water_atoms = self._universe.select_atoms("water") - water_resids = set(res.resid for res in water_atoms.residues) - - water_groups = { - gid: g - for gid, g in groups.items() - if any( - res.resid in water_resids - for mol in [self._universe.atoms.fragments[i] for i in g] - for res in mol.residues - ) - } - nonwater_groups = { - gid: g for gid, g in groups.items() if gid not in water_groups - } - - if self._args.water_entropy and water_groups: - self._handle_water_entropy(start, end, step, water_groups) - else: - nonwater_groups.update(water_groups) - - force_matrices, torque_matrices, forcetorque_matrices, frame_counts = ( - self._level_manager.build_covariance_matrices( - self, - reduced_atom, - levels, - nonwater_groups, - start, - end, - step, - number_frames, - self._args.force_partitioning, - self._args.combined_forcetorque, - self._args.customised_axes, - ) - ) - - # Identify the conformational states from dihedral angles for the - # conformational entropy calculations - states_ua, states_res = self._dihedral_analysis.build_conformational_states( - reduced_atom, - levels, - nonwater_groups, - start, - end, - step, - self._args.bin_width, - ) - - # Complete the entropy calculations - self._compute_entropies( - reduced_atom, - levels, - nonwater_groups, - force_matrices, - torque_matrices, - forcetorque_matrices, - states_ua, - states_res, - frame_counts, - number_frames, - ve, - ce, - ) - - # Print the results 
in a nicely formated way - self._finalize_molecule_results() - self._data_logger.log_tables() - - def _handle_water_entropy(self, start, end, step, water_groups): - """ - Compute water entropy for each water group, log data, and update selection - string to exclude water from further analysis. - - Args: - start (int): Start frame index - end (int): End frame index - step (int): Step size - water_groups (dict): {group_id: [atom indices]} for water - """ - if not water_groups or not self._args.water_entropy: - return - - for group_id, atom_indices in water_groups.items(): - - self._calculate_water_entropy( - universe=self._universe, - start=start, - end=end, - step=step, - group_id=group_id, - ) - - self._args.selection_string = ( - self._args.selection_string + " and not water" - if self._args.selection_string != "all" - else "not water" - ) - - logger.debug(f"WaterEntropy: molecule_data: {self._data_logger.molecule_data}") - logger.debug(f"WaterEntropy: residue_data: {self._data_logger.residue_data}") - - def _initialize_molecules(self): - """ - Prepare the reduced universe and determine molecule-level configurations. - - Returns: - tuple: A tuple containing: - - reduced_atom (Universe): The reduced atom selection. - - number_molecules (int): Number of molecules in the system. - - levels (list): List of entropy levels per molecule. - - groups (dict): Groups for averaging over molecules. 
- """ - # Based on the selection string, create a new MDAnalysis universe - reduced_atom = self._get_reduced_universe() - - # Count the molecules and identify the length scale levels for each one - number_molecules, levels = self._level_manager.select_levels(reduced_atom) - - # Group the molecules for averaging - grouping = self._args.grouping - groups = self._group_molecules.grouping_molecules(reduced_atom, grouping) - - return reduced_atom, number_molecules, levels, groups - - def _compute_entropies( - self, - reduced_atom, - levels, - groups, - force_matrices, - torque_matrices, - forcetorque_matrices, - states_ua, - states_res, - frame_counts, - number_frames, - ve, - ce, - ): - """ - Compute vibrational and conformational entropies for all molecules and levels. - - This method iterates over each molecule and its associated entropy levels - (united_atom, residue, polymer), computing the corresponding entropy - contributions using force/torque matrices and dihedral conformations. - - For each level: - - "united_atom": Computes per-residue conformational states and entropy. - - "residue": Computes molecule-level conformational and vibrational entropy. - - "polymer": Computes only vibrational entropy. - - Parameters: - reduced_atom (Universe): The reduced atom selection from the trajectory. - levels (list): List of entropy levels per molecule. - groups (dict): Groups for averaging over molecules. - force_matrices (dict): Precomputed force covariance matrices. - torque_matrices (dict): Precomputed torque covariance matrices. - states_ua (dict): Dictionary to store united-atom conformational states. - states_res (list): List to store residue-level conformational states. - frames_count (dict): Dictionary to store the frame counts - number_frames (int): Total number of trajectory frames to process. 
- ve: Vibrational Entropy object - ce: Conformational Entropy object - """ - with Progress( - SpinnerColumn(), - TextColumn("[bold blue]{task.fields[title]}", justify="right"), - BarColumn(), - "[progress.percentage]{task.percentage:>3.1f}%", - TimeElapsedColumn(), - ) as progress: - - task = progress.add_task( - "[green]Calculating Entropy...", - total=len(groups), - title="Starting...", - ) - - for group_id in groups.keys(): - mol = self._universe_operations.get_molecule_container( - reduced_atom, groups[group_id][0] - ) - - residue_group = "_".join( - sorted(set(res.resname for res in mol.residues)) - ) - group_residue_count = len(groups[group_id]) - group_atom_count = 0 - for mol_id in groups[group_id]: - each_mol = self._universe_operations.get_molecule_container( - reduced_atom, mol_id - ) - group_atom_count += len(each_mol.atoms) - self._data_logger.add_group_label( - group_id, residue_group, group_residue_count, group_atom_count - ) - - resname = mol.atoms[0].resname - resid = mol.atoms[0].resid - segid = mol.atoms[0].segid - - mol_label = f"{resname}_{resid} (segid {segid})" - - for level in levels[groups[group_id][0]]: - progress.update( - task, - title=f"Calculating entropy values | " - f"Molecule: {mol_label} | " - f"Level: {level}", - ) - highest = level == levels[groups[group_id][0]][-1] - forcetorque_matrix = None - - if level == "united_atom": - self._process_united_atom_entropy( - group_id, - mol, - ve, - ce, - level, - force_matrices["ua"], - torque_matrices["ua"], - states_ua, - frame_counts["ua"], - highest, - number_frames, - ) - - elif level == "residue": - if highest: - forcetorque_matrix = forcetorque_matrices["res"][group_id] - self._process_vibrational_entropy( - group_id, - mol, - number_frames, - ve, - level, - force_matrices["res"][group_id], - torque_matrices["res"][group_id], - forcetorque_matrix, - highest, - ) - - self._process_conformational_entropy( - group_id, - mol, - ce, - level, - states_res, - number_frames, - ) - - elif 
level == "polymer": - if highest: - forcetorque_matrix = forcetorque_matrices["poly"][group_id] - self._process_vibrational_entropy( - group_id, - mol, - number_frames, - ve, - level, - force_matrices["poly"][group_id], - torque_matrices["poly"][group_id], - forcetorque_matrix, - highest, - ) - - progress.advance(task) - - def _get_trajectory_bounds(self): - """ - Returns the start, end, and step frame indices based on input arguments. - - Returns: - Tuple of (start, end, step) frame indices. - """ - start = self._args.start or 0 - end = len(self._universe.trajectory) if self._args.end == -1 else self._args.end - step = self._args.step or 1 - - return start, end, step - - def _get_number_frames(self, start, end, step): - """ - Calculates the total number of trajectory frames used in the calculation. - - Args: - start (int): Start frame index. - end (int): End frame index. If -1, it refers to the end of the trajectory. - step (int): Frame step size. - - Returns: - int: Total number of frames considered. - """ - return math.floor((end - start) / step) - - def _get_reduced_universe(self): - """ - Applies atom selection based on the user's input. - - Returns: - MDAnalysis.Universe: Selected subset of the system. 
- """ - # If selection string is "all" the universe does not change - if self._args.selection_string == "all": - return self._universe - - # Otherwise create a new (smaller) universe based on the selection - u = self._universe - selection_string = self._args.selection_string - reduced = self._universe_operations.new_U_select_atom(u, selection_string) - name = f"{len(reduced.trajectory)}_frame_dump_atom_selection" - self._run_manager.write_universe(reduced, name) - - return reduced - - def _process_united_atom_entropy( - self, - group_id, - mol_container, - ve, - ce, - level, - force_matrix, - torque_matrix, - states, - frame_counts, - highest, - number_frames, - ): - """ - Calculates translational, rotational, and conformational entropy at the - united-atom level. - - Args: - group_id (int): ID of the group. - mol_container (Universe): Universe for the selected molecule. - ve: VibrationalEntropy object. - ce: ConformationalEntropy object. - level (str): Granularity level (should be 'united_atom'). - start, end, step (int): Trajectory frame parameters. - n_frames (int): Number of trajectory frames. - frame_counts: Number of frames counted - highest (bool): Whether this is the highest level of resolution for - the molecule. - number_frames (int): The number of frames analysed. 
- """ - S_trans, S_rot, S_conf = 0, 0, 0 - - # The united atom entropy is calculated separately for each residue - # This is to allow residue by residue information - # and prevents the matrices from becoming too large - for residue_id, residue in enumerate(mol_container.residues): - - key = (group_id, residue_id) - - # Find the relevant force and torque matrices and tidy them up - # by removing rows and columns that are all zeros - f_matrix = force_matrix[key] - f_matrix = self._level_manager.filter_zero_rows_columns(f_matrix) - - t_matrix = torque_matrix[key] - t_matrix = self._level_manager.filter_zero_rows_columns(t_matrix) - - # Calculate the vibrational entropy - S_trans_res = ve.vibrational_entropy_calculation( - f_matrix, "force", self._args.temperature, highest - ) - S_rot_res = ve.vibrational_entropy_calculation( - t_matrix, "torque", self._args.temperature, highest - ) - - # Get the relevant conformational states - values = states[key] - # Check if there is information in the states array - contains_non_empty_states = ( - np.any(values) if isinstance(values, np.ndarray) else any(values) - ) - - # Calculate the conformational entropy - # If there are no conformational states (i.e. 
no dihedrals) - # then the conformational entropy is zero - S_conf_res = ( - ce.conformational_entropy_calculation(values) - if contains_non_empty_states - else 0 - ) - - # Add the data to the united atom level entropy - S_trans += S_trans_res - S_rot += S_rot_res - S_conf += S_conf_res - - # Print out the data for each residue - self._data_logger.add_residue_data( - group_id, - residue.resname, - level, - "Transvibrational", - frame_counts[key], - S_trans_res, - ) - self._data_logger.add_residue_data( - group_id, - residue.resname, - level, - "Rovibrational", - frame_counts[key], - S_rot_res, - ) - self._data_logger.add_residue_data( - group_id, - residue.resname, - level, - "Conformational", - frame_counts[key], - S_conf_res, - ) - - # Print the total united atom level data for the molecule group - self._data_logger.add_results_data(group_id, level, "Transvibrational", S_trans) - self._data_logger.add_results_data(group_id, level, "Rovibrational", S_rot) - self._data_logger.add_results_data(group_id, level, "Conformational", S_conf) - - residue_group = "_".join( - sorted(set(res.resname for res in mol_container.residues)) - ) - - logger.debug(f"residue_group {residue_group}") - - def _process_vibrational_entropy( - self, - group_id, - mol_container, - number_frames, - ve, - level, - force_matrix, - torque_matrix, - forcetorque_matrix, - highest, - ): - """ - Calculates vibrational entropy. - - Args: - group_id (int): Group ID. - ve: VibrationalEntropy object. - level (str): Current granularity level. - force_matrix : Force covariance matrix - torque_matrix : Torque covariance matrix - frame_count: - highest (bool): Flag indicating if this is the highest granularity - level. 
- """ - # Find the relevant force and torque matrices and tidy them up - # by removing rows and columns that are all zeros - - if forcetorque_matrix is not None: - forcetorque_matrix = self._level_manager.filter_zero_rows_columns( - forcetorque_matrix - ) - - S_FTtrans = ve.vibrational_entropy_calculation( - forcetorque_matrix, "forcetorqueTRANS", self._args.temperature, highest - ) - S_FTrot = ve.vibrational_entropy_calculation( - forcetorque_matrix, "forcetorqueROT", self._args.temperature, highest - ) - - self._data_logger.add_results_data( - group_id, level, "FTmat-Transvibrational", S_FTtrans - ) - self._data_logger.add_results_data( - group_id, level, "FTmat-Rovibrational", S_FTrot - ) - - else: - force_matrix = self._level_manager.filter_zero_rows_columns(force_matrix) - - torque_matrix = self._level_manager.filter_zero_rows_columns(torque_matrix) - - # Calculate the vibrational entropy - S_trans = ve.vibrational_entropy_calculation( - force_matrix, "force", self._args.temperature, highest - ) - S_rot = ve.vibrational_entropy_calculation( - torque_matrix, "torque", self._args.temperature, highest - ) - - # Print the vibrational entropy for the molecule group - self._data_logger.add_results_data( - group_id, level, "Transvibrational", S_trans - ) - self._data_logger.add_results_data(group_id, level, "Rovibrational", S_rot) - - residue_group = "_".join( - sorted(set(res.resname for res in mol_container.residues)) - ) - residue_count = len(mol_container.residues) - atom_count = len(mol_container.atoms) - self._data_logger.add_group_label( - group_id, residue_group, residue_count, atom_count - ) - - def _process_conformational_entropy( - self, group_id, mol_container, ce, level, states, number_frames - ): - """ - Computes conformational entropy at the residue level (whole-molecule dihedral - analysis). - - Args: - mol_id (int): ID of the molecule. - mol_container (Universe): Selected molecule's universe. - ce: ConformationalEntropy object. 
- level (str): Level name (should be 'residue'). - states (array): The conformational states. - number_frames (int): Number of frames used. - """ - # Get the relevant conformational states - # Check if there is information in the states array - group_states = states[group_id] if group_id < len(states) else None - - if group_states is not None: - contains_state_data = ( - group_states.any() - if isinstance(group_states, np.ndarray) - else any(group_states) - ) - else: - contains_state_data = False - - # Calculate the conformational entropy - # If there are no conformational states (i.e. no dihedrals) - # then the conformational entropy is zero - S_conf = ( - ce.conformational_entropy_calculation(group_states) - if contains_state_data - else 0 - ) - self._data_logger.add_results_data(group_id, level, "Conformational", S_conf) - - residue_group = "_".join( - sorted(set(res.resname for res in mol_container.residues)) - ) - residue_count = len(mol_container.residues) - atom_count = len(mol_container.atoms) - self._data_logger.add_group_label( - group_id, residue_group, residue_count, atom_count - ) - - def _finalize_molecule_results(self): - """ - Aggregates and logs total entropy and frame counts per molecule. 
- """ - entropy_by_molecule = defaultdict(float) - for ( - mol_id, - level, - entropy_type, - result, - ) in self._data_logger.molecule_data: - if level != "Group Total": - try: - entropy_by_molecule[mol_id] += float(result) - except ValueError: - logger.warning(f"Skipping invalid entry: {mol_id}, {result}") - - for mol_id in entropy_by_molecule.keys(): - total_entropy = entropy_by_molecule[mol_id] - - self._data_logger.molecule_data.append( - ( - mol_id, - "Group Total", - "Group Total Entropy", - total_entropy, - ) - ) - - self._data_logger.save_dataframes_as_json( - pd.DataFrame( - self._data_logger.molecule_data, - columns=[ - "Group ID", - "Level", - "Type", - "Result (J/mol/K)", - ], - ), - pd.DataFrame( - self._data_logger.residue_data, - columns=[ - "Group ID", - "Residue Name", - "Level", - "Type", - "Frame Count", - "Result (J/mol/K)", - ], - ), - self._args.output_file, - ) - - def _calculate_water_entropy(self, universe, start, end, step, group_id=None): - """ - Calculate and aggregate the entropy of water molecules in a simulation. - - This function computes orientational, translational, and rotational - entropy components for all water molecules, aggregates them per residue, - and maps all waters to a single group ID. It also logs the total results - and labels the water group in the data logger. - - Parameters - ---------- - universe : MDAnalysis.Universe - The simulation universe containing water molecules. - start : int - The starting frame for analysis. - end : int - The ending frame for analysis. - step : int - Frame interval for analysis. - group_id : int or str, optional - The group ID to which all water molecules will be assigned. 
- """ - Sorient_dict, covariances, vibrations, _, water_count = ( - GetSolvent.get_interfacial_water_orient_entropy( - universe, start, end, step, self._args.temperature, parallel=True - ) - ) - - self._calculate_water_orientational_entropy(Sorient_dict, group_id) - self._calculate_water_vibrational_translational_entropy( - vibrations, group_id, covariances - ) - self._calculate_water_vibrational_rotational_entropy( - vibrations, group_id, covariances - ) - - water_selection = universe.select_atoms("resname WAT") - actual_water_residues = len(water_selection.residues) - residue_names = { - resname - for res_dict in Sorient_dict.values() - for resname in res_dict.keys() - if resname.upper() in water_selection.residues.resnames - } - - residue_group = "_".join(sorted(residue_names)) if residue_names else "WAT" - self._data_logger.add_group_label( - group_id, residue_group, actual_water_residues, len(water_selection.atoms) - ) - - def _calculate_water_orientational_entropy(self, Sorient_dict, group_id): - """ - Aggregate orientational entropy for all water molecules into a single group. - - Parameters - ---------- - Sorient_dict : dict - Dictionary containing orientational entropy values per residue. - group_id : int or str - The group ID to which the water residues belong. - covariances : object - Covariance object. - """ - for resid, resname_dict in Sorient_dict.items(): - for resname, values in resname_dict.items(): - if isinstance(values, list) and len(values) == 2: - Sor, count = values - self._data_logger.add_residue_data( - group_id, resname, "Water", "Orientational", count, Sor - ) - - def _calculate_water_vibrational_translational_entropy( - self, vibrations, group_id, covariances - ): - """ - Aggregate translational vibrational entropy for all water molecules. - - Parameters - ---------- - vibrations : object - Object containing translational entropy data (vibrations.translational_S). - group_id : int or str - The group ID for the water residues. 
- covariances : object - Covariance object. - """ - - for (solute_id, _), entropy in vibrations.translational_S.items(): - if isinstance(entropy, (list, np.ndarray)): - entropy = float(np.sum(entropy)) - - count = covariances.counts.get((solute_id, "WAT"), 1) - resname = solute_id.rsplit("_", 1)[0] if "_" in solute_id else solute_id - self._data_logger.add_residue_data( - group_id, resname, "Water", "Transvibrational", count, entropy - ) - - def _calculate_water_vibrational_rotational_entropy( - self, vibrations, group_id, covariances - ): - """ - Aggregate rotational vibrational entropy for all water molecules. - - Parameters - ---------- - vibrations : object - Object containing rotational entropy data (vibrations.rotational_S). - group_id : int or str - The group ID for the water residues. - covariances : object - Covariance object. - """ - for (solute_id, _), entropy in vibrations.rotational_S.items(): - if isinstance(entropy, (list, np.ndarray)): - entropy = float(np.sum(entropy)) - - count = covariances.counts.get((solute_id, "WAT"), 1) - - resname = solute_id.rsplit("_", 1)[0] if "_" in solute_id else solute_id - self._data_logger.add_residue_data( - group_id, resname, "Water", "Rovibrational", count, entropy - ) - - -class VibrationalEntropy(EntropyManager): - """ - Performs vibrational entropy calculations using molecular trajectory data. - Extends the base EntropyManager with constants and logic specific to - vibrational modes and thermodynamic properties. - """ - - def __init__( - self, - run_manager, - args, - universe, - data_logger, - level_manager, - group_molecules, - dihedral_analysis, - universe_operations, - ): - """ - Initializes the VibrationalEntropy manager with all required components and - defines physical constants used in vibrational entropy calculations. 
- """ - super().__init__( - run_manager, - args, - universe, - data_logger, - level_manager, - group_molecules, - dihedral_analysis, - universe_operations, - ) - self._PLANCK_CONST = 6.62607004081818e-34 - - def frequency_calculation(self, lambdas, temp): - """ - Function to calculate an array of vibrational frequencies from the eigenvalues - of the covariance matrix. - - Calculated from eq. (3) in Higham, S.-Y. Chou, F. Gräter and R. H. Henchman, - Molecular Physics, 2018, 116, 1965–1976//eq. (3) in A. Chakravorty, J. Higham - and R. H. Henchman, J. Chem. Inf. Model., 2020, 60, 5540–5551 - - frequency=sqrt(λ/kT)/2π - - Args: - lambdas : array of floats - eigenvalues of the covariance matrix - temp: float - temperature - - Returns: - frequencies : array of floats - corresponding vibrational frequencies - """ - pi = np.pi - # get kT in Joules from given temperature - kT = self._run_manager.get_KT2J(temp) - logger.debug(f"Temperature: {temp}, kT: {kT}") - - lambdas = np.array(lambdas) # Ensure input is a NumPy array - logger.debug(f"Eigenvalues (lambdas): {lambdas}") - - # Filter out lambda values that are negative or imaginary numbers - # As these will produce supurious entropy results that can crash - # the calculation - lambdas = np.real_if_close(lambdas, tol=1000) - valid_mask = ( - np.isreal(lambdas) & (lambdas > 0) & (~np.isclose(lambdas, 0, atol=1e-07)) - ) - - # If any lambdas were removed by the filter, warn the user - # as this will suggest insufficient sampling in the simulation data - if len(lambdas) > np.count_nonzero(valid_mask): - logger.warning( - f"{len(lambdas) - np.count_nonzero(valid_mask)} " - f"invalid eigenvalues excluded (complex, non-positive, or near-zero)." 
- ) - - lambdas = lambdas[valid_mask].real - - # Compute frequencies safely - frequencies = 1 / (2 * pi) * np.sqrt(lambdas / kT) - logger.debug(f"Calculated frequencies: {frequencies}") - - return frequencies - - def vibrational_entropy_calculation(self, matrix, matrix_type, temp, highest_level): - """ - Function to calculate the vibrational entropy for each level calculated from - eq. (4) in J. Higham, S.-Y. Chou, F. Gräter and R. H. Henchman, Molecular - Physics, 2018, 116, 1965–1976 / eq. (2) in A. Chakravorty, J. Higham and - R. H. Henchman, J. Chem. Inf. Model., 2020, 60, 5540–5551. - - Args: - matrix : matrix - force/torque covariance matrix - matrix_type: string - temp: float - temperature - highest_level: bool - is this the highest level of the heirarchy - - Returns: - S_vib_total : float - transvibrational/rovibrational entropy - """ - # N beads at a level => 3N x 3N covariance matrix => 3N eigenvalues - # Get eigenvalues of the given matrix and change units to SI units - logger.debug(f"matrix_type: {matrix_type}") - lambdas = la.eigvals(matrix) - logger.debug(f"Eigenvalues (lambdas) before unit change: {lambdas}") - - lambdas = self._run_manager.change_lambda_units(lambdas) - logger.debug(f"Eigenvalues (lambdas) after unit change: {lambdas}") - - # Calculate frequencies from the eigenvalues - frequencies = self.frequency_calculation(lambdas, temp) - logger.debug(f"Calculated frequencies: {frequencies}") - - # Sort frequencies lowest to highest - frequencies = np.sort(frequencies) - logger.debug(f"Sorted frequencies: {frequencies}") - - kT = self._run_manager.get_KT2J(temp) - logger.debug(f"Temperature: {temp}, kT: {kT}") - exponent = self._PLANCK_CONST * frequencies / kT - logger.debug(f"Exponent values: {exponent}") - power_positive = np.power(np.e, exponent) - power_negative = np.power(np.e, -exponent) - logger.debug(f"Power positive values: {power_positive}") - logger.debug(f"Power negative values: {power_negative}") - S_components = exponent / 
(power_positive - 1) - np.log(1 - power_negative) - S_components = ( - S_components * self._GAS_CONST - ) # multiply by R - get entropy in J mol^{-1} K^{-1} - logger.debug(f"Entropy components: {S_components}") - # N beads at a level => 3N x 3N covariance matrix => 3N eigenvalues - if matrix_type == "force": # force covariance matrix - if ( - highest_level - ): # whole molecule level - we take all frequencies into account - S_vib_total = sum(S_components) - - # discard the 6 lowest frequencies to discard translation and rotation of - # the whole unit the overall translation and rotation of a unit is an - # internal motion of the level above - else: - S_vib_total = sum(S_components[6:]) - - elif matrix_type == "forcetorqueTRANS": # three lowest are translations - S_vib_total = sum(S_components[:3]) - elif matrix_type == "forcetorqueROT": # three highest are rotations - S_vib_total = sum(S_components[3:]) - - else: # torque covariance matrix - we always take all values into account - S_vib_total = sum(S_components) - - logger.debug(f"Total vibrational entropy: {S_vib_total}") - - return S_vib_total - - -class ConformationalEntropy(EntropyManager): - """ - Performs conformational entropy calculations based on molecular dynamics data. - Inherits from EntropyManager and includes constants specific to conformational - analysis using statistical mechanics principles. - """ - - def __init__( - self, - run_manager, - args, - universe, - data_logger, - level_manager, - group_molecules, - dihedral_analysis, - universe_operations, - ): - """ - Initializes the ConformationalEntropy manager with all required components and - sets the gas constant used in conformational entropy calculations. - """ - super().__init__( - run_manager, - args, - universe, - data_logger, - level_manager, - group_molecules, - dihedral_analysis, - universe_operations, - ) - - def conformational_entropy_calculation(self, states): - """ - Function to calculate conformational entropies using eq. 
(7) in Higham, - S.-Y. Chou, F. Gräter and R. H. Henchman, Molecular Physics, 2018, 116, - 1965–1976 / eq. (4) in A. Chakravorty, J. Higham and R. H. Henchman, - J. Chem. Inf. Model., 2020, 60, 5540–5551. - - Uses the adaptive enumeration method (AEM). - - Args: - states (array): Conformational states in the molecule - - Returns: - S_conf_total (float) : conformational entropy - """ - - S_conf_total = 0 - - # Count how many times each state occurs, then use the probability - # to get the entropy - # entropy = sum over states p*ln(p) - values, counts = np.unique(states, return_counts=True) - total_count = len(states) - for state in range(len(values)): - logger.debug(f"Unique states: {values}") - logger.debug(f"Counts: {counts}") - count = counts[state] - probability = count / total_count - entropy = probability * np.log(probability) - S_conf_total += entropy - - # multiply by gas constant to get the units J/mol/K - S_conf_total *= -1 * self._GAS_CONST - - logger.debug(f"Total conformational entropy: {S_conf_total}") - - return S_conf_total - - -class OrientationalEntropy(EntropyManager): - """ - Performs orientational entropy calculations using molecular dynamics data. - Inherits from EntropyManager and includes constants relevant to rotational - and orientational degrees of freedom. - """ - - def __init__( - self, - run_manager, - args, - universe, - data_logger, - level_manager, - group_molecules, - dihedral_analysis, - universe_operations, - ): - """ - Initializes the OrientationalEntropy manager with all required components and - sets the gas constant used in orientational entropy calculations. - """ - super().__init__( - run_manager, - args, - universe, - data_logger, - level_manager, - group_molecules, - dihedral_analysis, - universe_operations, - ) - - def orientational_entropy_calculation(self, neighbours_dict): - """ - Function to calculate orientational entropies from eq. (10) in J. Higham, - S.-Y. Chou, F. Gräter and R. H. 
Henchman, Molecular Physics, 2018, 116, - 3 1965–1976. Number of orientations, Ω, is calculated using eq. (8) in - J. Higham, S.-Y. Chou, F. Gräter and R. H. Henchman, Molecular Physics, - 2018, 116, 3 1965–1976. - - σ is assumed to be 1 for the molecules we're concerned with and hence, - max {1, (Nc^3*π)^(1/2)} will always be (Nc^3*π)^(1/2). - - TODO future release - function for determing symmetry and symmetry numbers - maybe? - - Input - ----- - neighbours_dict : dictionary - dictionary of neighbours for the molecule - - should contain the type of neighbour molecule and the number of neighbour - molecules of that species - - Returns - ------- - S_or_total : float - orientational entropy - """ - - # Replaced molecule with neighbour as this is what the for loop uses - S_or_total = 0 - for neighbour in neighbours_dict: # we are going through neighbours - if neighbour in ["H2O"]: # water molecules - call POSEIDON functions - pass # TODO temporary until function is written - else: - # the bound ligand is always going to be a neighbour - omega = np.sqrt((neighbours_dict[neighbour] ** 3) * math.pi) - logger.debug(f"Omega for neighbour {neighbour}: {omega}") - # orientational entropy arising from each neighbouring species - # - we know the species is going to be a neighbour - S_or_component = math.log(omega) - logger.debug( - f"S_or_component (log(omega)) for neighbour {neighbour}: " - f"{S_or_component}" - ) - S_or_component *= self._GAS_CONST - logger.debug( - f"S_or_component after multiplying by GAS_CONST for neighbour " - f"{neighbour}: {S_or_component}" - ) - S_or_total += S_or_component - logger.debug( - f"S_or_total after adding component for neighbour {neighbour}: " - f"{S_or_total}" - ) - # TODO for future releases - # implement a case for molecules with hydrogen bonds but to a lesser - # extent than water - - logger.debug(f"Final total orientational entropy: {S_or_total}") - - return S_or_total diff --git a/CodeEntropy/entropy/__init__.py 
b/CodeEntropy/entropy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CodeEntropy/entropy/configurational.py b/CodeEntropy/entropy/configurational.py new file mode 100644 index 00000000..be6c31cd --- /dev/null +++ b/CodeEntropy/entropy/configurational.py @@ -0,0 +1,264 @@ +"""Conformational entropy utilities. + +This module provides: + * Assigning discrete conformational states for a single dihedral time series. + * Computing conformational entropy from a sequence of state labels. + +The public surface area is intentionally small to keep responsibilities clear. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ConformationConfig: + """Configuration for assigning conformational states from a dihedral. + + Attributes: + bin_width: Histogram bin width in degrees for peak detection. + start: Inclusive start frame index for trajectory slicing. + end: Exclusive end frame index for trajectory slicing. + step: Stride for trajectory slicing (must be positive). + """ + + bin_width: int + start: int + end: int + step: int + + +class ConformationalEntropy: + """Assign dihedral conformational states and compute conformational entropy. + + This class contains two independent responsibilities: + 1) `assign_conformation`: Map a single dihedral angle time series to discrete + state labels by detecting histogram peaks and assigning the nearest peak. + 2) `conformational_entropy_calculation`: Compute Shannon entropy of the + state distribution (in J/mol/K). + + Notes: + `number_frames` is accepted by `conformational_entropy_calculation` for + compatibility with calling sites that track frame counts, but the entropy + is computed from the observed state counts (i.e., `len(states)`), which is + the correct normalization for the sampled distribution. 
+ """ + + _GAS_CONST: float = 8.3144598484848 + + def __init__(self) -> None: + """Math-only engine. + + This class assigns conformational states and computes conformational entropy. + It does not depend on the workflow runner, universe, grouping, or reporting. + """ + pass + + def assign_conformation( + self, + data_container: Any, + dihedral: Any, + number_frames: int, + bin_width: int, + start: int, + end: int, + step: int, + ) -> np.ndarray: + """Assign discrete conformational states for a single dihedral. + + The dihedral angle time series is: + 1) Collected across the trajectory slice [start:end:step]. + 2) Converted to [0, 360) degrees. + 3) Histogrammed using `bin_width`. + 4) Peaks are identified as bins with locally maximal population. + 5) Each frame is assigned the index of the nearest peak. + + Args: + data_container: MDAnalysis Universe/AtomGroup with a trajectory. + dihedral: Object providing `value()` for the current frame dihedral. + number_frames: Provided for call-site compatibility; not used for sizing. + bin_width: Histogram bin width in degrees. + start: Inclusive start frame index. + end: Exclusive end frame index. + step: Stride for trajectory slicing. + + Returns: + Array of integer state labels of length equal to the trajectory slice. + Returns an empty array if the slice is empty. + + Raises: + ValueError: If `bin_width` or `step` are invalid. 
+ """ + _ = number_frames + + config = ConformationConfig( + bin_width=int(bin_width), + start=int(start), + end=int(end), + step=int(step), + ) + self._validate_assignment_config(config) + + traj_slice = data_container.trajectory[config.start : config.end : config.step] + n_slice = len(traj_slice) + if n_slice <= 0: + return np.array([], dtype=int) + + phi = self._collect_dihedral_angles(traj_slice, dihedral) + peak_values = self._find_histogram_peaks(phi, config.bin_width) + + if peak_values.size == 0: + return np.zeros(n_slice, dtype=int) + + states = self._assign_nearest_peaks(phi, peak_values) + logger.debug("Final conformations: %s", states) + return states + + def conformational_entropy_calculation( + self, states: Any, number_frames: int + ) -> float: + """Compute conformational entropy for a sequence of state labels. + + Entropy is computed as: + S = -R * sum_i p_i * ln(p_i) + where p_i is the observed probability of state i in `states`. + + Args: + states: Sequence/array of discrete state labels. Empty/None yields 0.0. + number_frames: Frame count metadata. + + Returns: + Conformational entropy in J/mol/K. + """ + _ = number_frames + + arr = self._to_1d_array(states) + if arr is None or arr.size == 0: + return 0.0 + + values, counts = np.unique(arr, return_counts=True) + total_count = int(np.sum(counts)) + if total_count <= 0 or values.size <= 1: + return 0.0 + + probs = counts.astype(float) / float(total_count) + probs = probs[probs > 0.0] + + s_conf = -self._GAS_CONST * float(np.sum(probs * np.log(probs))) + logger.debug("Total conformational entropy: %s", s_conf) + return s_conf + + @staticmethod + def _validate_assignment_config(config: ConformationConfig) -> None: + """Validate conformation assignment configuration. + + Args: + config: Assignment configuration. + + Raises: + ValueError: If configuration values are invalid. 
+ """ + if config.step <= 0: + raise ValueError("step must be a positive integer") + if config.bin_width <= 0 or config.bin_width > 360: + raise ValueError("bin_width must be in the range (0, 360]") + if 360 % config.bin_width != 0: + logger.warning( + "bin_width=%s does not evenly divide 360; histogram bins will be " + "uneven.", + config.bin_width, + ) + + @staticmethod + def _collect_dihedral_angles(traj_slice: Any, dihedral: Any) -> np.ndarray: + """Collect dihedral angles for each frame in the trajectory slice. + + Args: + traj_slice: Slice of a trajectory iterable where iterating advances frames. + dihedral: Object with `value()` returning the dihedral in degrees. + + Returns: + Array of dihedral values mapped into [0, 360). + """ + phi = np.zeros(len(traj_slice), dtype=float) + for i, _ts in enumerate(traj_slice): + value = float(dihedral.value()) + if value < 0.0: + value += 360.0 + phi[i] = value + return phi + + @staticmethod + def _find_histogram_peaks(phi: np.ndarray, bin_width: int) -> np.ndarray: + """Identify peak bin centers from a histogram of dihedral angles. + + A peak is defined as a bin whose population is greater than or equal to + its immediate neighbors (with circular handling at the final bin). + + Args: + phi: Dihedral angles in degrees, in [0, 360). + bin_width: Histogram bin width in degrees. + + Returns: + 1D array of peak bin center values (degrees). Empty if no peaks found. 
+ """ + number_bins = int(360 / bin_width) + popul, bin_edges = np.histogram(phi, bins=number_bins, range=(0.0, 360.0)) + bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + + peaks: list[float] = [] + for idx in range(number_bins): + if popul[idx] == 0: + continue + + left = popul[idx - 1] if idx > 0 else popul[number_bins - 1] + right = popul[idx + 1] if idx < number_bins - 1 else popul[0] + + if popul[idx] >= left and popul[idx] >= right: + peaks.append(float(bin_centers[idx])) + + return np.asarray(peaks, dtype=float) + + @staticmethod + def _assign_nearest_peaks(phi: np.ndarray, peak_values: np.ndarray) -> np.ndarray: + """Assign each phi value to the index of its nearest peak. + + Args: + phi: Dihedral angles in degrees. + peak_values: Peak centers (degrees). + + Returns: + Integer state labels aligned with `phi`. + """ + distances = np.abs(phi[:, None] - peak_values[None, :]) + return np.argmin(distances, axis=1).astype(int) + + @staticmethod + def _to_1d_array(states: Any) -> Optional[np.ndarray]: + """Convert a state sequence into a 1D numpy array. + + Args: + states: Input sequence/array. + + Returns: + 1D numpy array, or None if input is not usable. + """ + if states is None: + return None + + if isinstance(states, np.ndarray): + arr = states.reshape(-1) + else: + try: + arr = np.asarray(list(states)).reshape(-1) + except TypeError: + return None + + return arr diff --git a/CodeEntropy/entropy/graph.py b/CodeEntropy/entropy/graph.py new file mode 100644 index 00000000..d243907d --- /dev/null +++ b/CodeEntropy/entropy/graph.py @@ -0,0 +1,141 @@ +"""Entropy graph orchestration. + +This module defines `EntropyGraph`, a small directed acyclic graph (DAG) that +executes entropy calculation nodes in dependency order. + +The graph is intentionally simple: + * Vibrational entropy + * Configurational entropy + * Aggregation of results + +The nodes themselves encapsulate the detailed calculations. 
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Dict + +import networkx as nx + +from CodeEntropy.entropy.nodes.aggregate import AggregateEntropyNode +from CodeEntropy.entropy.nodes.configurational import ConfigurationalEntropyNode +from CodeEntropy.entropy.nodes.vibrational import VibrationalEntropyNode + +logger = logging.getLogger(__name__) + + +SharedData = Dict[str, Any] + + +@dataclass(frozen=True) +class NodeSpec: + """Specification for a node within the entropy graph. + + Attributes: + name: Unique node name. + node: Node instance. Must implement `run(shared_data, **kwargs)`. + deps: Optional list of node names that must run before this node. + """ + + name: str + node: Any + deps: tuple[str, ...] = () + + +class EntropyGraph: + """Build and execute the entropy calculation DAG. + + The graph is built once via `build()` and executed via `execute()`. + + Examples: + graph = EntropyGraph().build() + results = graph.execute(shared_data) + """ + + def __init__(self) -> None: + """Initialize an empty entropy graph.""" + self._graph: nx.DiGraph = nx.DiGraph() + self._nodes: Dict[str, Any] = {} + + def build(self) -> "EntropyGraph": + """Populate the graph with the standard entropy workflow. + + Returns: + Self for fluent chaining. + """ + specs = ( + NodeSpec("vibrational_entropy", VibrationalEntropyNode()), + NodeSpec("configurational_entropy", ConfigurationalEntropyNode()), + NodeSpec( + "aggregate_entropy", + AggregateEntropyNode(), + deps=("vibrational_entropy", "configurational_entropy"), + ), + ) + + for spec in specs: + self._add_node(spec) + + return self + + def execute( + self, shared_data: SharedData, *, progress: object | None = None + ) -> Dict[str, Any]: + """Execute the entropy graph in topological order. + + Nodes are executed in dependency order (topological sort). Each node reads + from and may mutate `shared_data`. 
Dict-like outputs returned by nodes are + merged into a single results dictionary. + + This method intentionally does *not* create a progress bar/task for the + entropy graph itself because the graph is typically very fast. If a progress + sink is provided, it is forwarded to nodes that accept it. + + Args: + shared_data: Mutable shared data dictionary passed to each node. + progress: Optional progress sink (e.g., from ResultsReporter.progress()). + Forwarded to node `run()` methods that accept a `progress` keyword. + + Returns: + Dictionary containing merged dict outputs produced by nodes. On key + collision, later nodes overwrite earlier keys. + + Raises: + KeyError: If a node name is missing from the internal node registry. + """ + results: Dict[str, Any] = {} + + for node_name in nx.topological_sort(self._graph): + node = self._nodes[node_name] + + if progress is not None: + try: + out = node.run(shared_data, progress=progress) + except TypeError: + out = node.run(shared_data) + else: + out = node.run(shared_data) + + if isinstance(out, dict): + results.update(out) + return results + + def _add_node(self, spec: NodeSpec) -> None: + """Add a node and its dependencies to the graph. + + Args: + spec: Node specification. + + Raises: + ValueError: If a duplicate node name is added. 
+ """ + if spec.name in self._nodes: + raise ValueError(f"Duplicate node name: {spec.name}") + + self._nodes[spec.name] = spec.node + self._graph.add_node(spec.name) + + for dep in spec.deps: + self._graph.add_edge(dep, spec.name) diff --git a/CodeEntropy/entropy/nodes/__init__.py b/CodeEntropy/entropy/nodes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CodeEntropy/entropy/nodes/aggregate.py b/CodeEntropy/entropy/nodes/aggregate.py new file mode 100644 index 00000000..6855ab12 --- /dev/null +++ b/CodeEntropy/entropy/nodes/aggregate.py @@ -0,0 +1,85 @@ +"""Aggregates entropy outputs produced by upstream DAG nodes.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Mapping, MutableMapping, Optional + +EntropyResults = Dict[str, Any] + + +@dataclass(frozen=True, slots=True) +class AggregateEntropyNode: + """Aggregate entropy results into a single shared output object. + + This node is intentionally small and single-purpose: + it gathers previously-computed entropy components from `shared_data` + and writes a canonical `shared_data["entropy_results"]` mapping. + + Attributes: + vibrational_key: Key in `shared_data` where vibrational entropy is stored. + configurational_key: Key in `shared_data` where configurational entropy is + stored. + output_key: Key in `shared_data` where the aggregated mapping is written. + """ + + vibrational_key: str = "vibrational_entropy" + configurational_key: str = "configurational_entropy" + output_key: str = "entropy_results" + + def run( + self, shared_data: MutableMapping[str, Any], **_: Any + ) -> Dict[str, EntropyResults]: + """Run the aggregation step. + + Args: + shared_data: Shared workflow state. Must contain (or may contain) keys + for vibrational and configurational entropy results. + + Returns: + A dict containing a single key, `"entropy_results"`, which maps to the + aggregated results dict. 
+ + Side Effects: + Writes the aggregated results to `shared_data[self.output_key]`. + + Notes: + This node does not validate the shapes/types of upstream results. + Validation should live with the producer nodes (single responsibility). + """ + results = self._collect_entropy_results(shared_data) + shared_data[self.output_key] = results + return {self.output_key: results} + + def _collect_entropy_results( + self, shared_data: Mapping[str, Any] + ) -> EntropyResults: + """Collect entropy results from shared data. + + Args: + shared_data: Shared workflow state. + + Returns: + A mapping with keys `"vibrational_entropy"` and `"configurational_entropy"`. + """ + return { + "vibrational_entropy": self._get_optional( + shared_data, self.vibrational_key + ), + "configurational_entropy": self._get_optional( + shared_data, self.configurational_key + ), + } + + @staticmethod + def _get_optional(shared_data: Mapping[str, Any], key: str) -> Optional[Any]: + """Fetch an optional value from shared data. + + Args: + shared_data: Shared workflow state. + key: Key to fetch. + + Returns: + The value if present, otherwise None. 
+ """ + return shared_data.get(key) diff --git a/CodeEntropy/entropy/nodes/configurational.py b/CodeEntropy/entropy/nodes/configurational.py new file mode 100644 index 00000000..a6a8c483 --- /dev/null +++ b/CodeEntropy/entropy/nodes/configurational.py @@ -0,0 +1,229 @@ +"""Node for computing configurational entropy from conformational states.""" + +from __future__ import annotations + +import logging +from typing import ( + Any, + Dict, + Iterable, + Mapping, + MutableMapping, + Optional, + Sequence, + Tuple, + Union, +) + +import numpy as np + +from CodeEntropy.entropy.configurational import ConformationalEntropy + +logger = logging.getLogger(__name__) + +GroupId = int +ResidueId = int +StateKey = Tuple[GroupId, ResidueId] +StateSequence = Union[Sequence[Any], np.ndarray] + + +class ConfigurationalEntropyNode: + """Compute configurational entropy using precomputed conformational states. + + This node reads conformational state assignments from ``shared_data`` and + computes entropy contributions at different structural levels. + + Results are written back into ``shared_data["configurational_entropy"]``. + """ + + def run(self, shared_data: MutableMapping[str, Any], **_: Any) -> Dict[str, Any]: + """Execute configurational entropy calculation. + + Args: + shared_data: Shared workflow state dictionary. + + Returns: + Dictionary containing configurational entropy results. + + Raises: + KeyError: If required keys are missing. 
+ """ + n_frames = self._get_n_frames(shared_data) + groups = shared_data["groups"] + levels = shared_data["levels"] + universe = shared_data["reduced_universe"] + reporter = shared_data.get("reporter") + + states_ua, states_res = self._get_state_containers(shared_data) + ce = self._build_entropy_engine() + + fragments = universe.atoms.fragments + results: Dict[int, Dict[str, float]] = {} + + for group_id, mol_ids in groups.items(): + results[group_id] = {"ua": 0.0, "res": 0.0, "poly": 0.0} + if not mol_ids: + continue + + rep_mol_id = mol_ids[0] + rep_mol = fragments[rep_mol_id] + level_list = levels[rep_mol_id] + + if "united_atom" in level_list: + ua_total = self._compute_ua_entropy_for_group( + ce=ce, + group_id=group_id, + residues=rep_mol.residues, + states_ua=states_ua, + n_frames=n_frames, + reporter=reporter, + ) + results[group_id]["ua"] = ua_total + + if "residue" in level_list: + res_val = self._compute_residue_entropy_for_group( + ce=ce, + group_id=group_id, + states_res=states_res, + n_frames=n_frames, + ) + results[group_id]["res"] = res_val + + if reporter is not None: + reporter.add_results_data( + group_id, "residue", "Conformational", res_val + ) + + shared_data["configurational_entropy"] = results + + return {"configurational_entropy": results} + + def _build_entropy_engine(self) -> ConformationalEntropy: + """Create the entropy calculation engine.""" + return ConformationalEntropy() + + def _get_state_containers( + self, shared_data: Mapping[str, Any] + ) -> Tuple[ + Dict[StateKey, StateSequence], + Union[Dict[GroupId, StateSequence], Sequence[Optional[StateSequence]]], + ]: + """Retrieve conformational state containers. + + Args: + shared_data: Shared workflow state. + + Returns: + Tuple of united atom and residue state containers. 
+ """ + conf_states = shared_data.get("conformational_states", {}) or {} + return conf_states.get("ua", {}) or {}, conf_states.get("res", {}) + + def _get_n_frames(self, shared_data: Mapping[str, Any]) -> int: + """Return the number of frames analysed. + + Args: + shared_data: Shared workflow state. + + Returns: + Number of frames. + + Raises: + KeyError: If frame count is missing. + """ + n_frames = shared_data.get("n_frames", shared_data.get("number_frames")) + if n_frames is None: + raise KeyError("shared_data must contain n_frames or number_frames") + return int(n_frames) + + def _compute_ua_entropy_for_group( + self, + *, + ce: ConformationalEntropy, + group_id: int, + residues: Iterable[Any], + states_ua: Mapping[StateKey, StateSequence], + n_frames: int, + reporter: Optional[Any], + ) -> float: + """Compute united atom entropy for a group. + + Args: + ce: Entropy calculator. + group_id: Group identifier. + residues: Residue iterable. + states_ua: Mapping of states. + n_frames: Frame count. + reporter: Optional logger. + + Returns: + Total entropy for united atom level. 
+ """ + total = 0.0 + + for res_id, res in enumerate(residues): + states = states_ua.get((group_id, res_id)) + val = self._entropy_or_zero(ce, states, n_frames) + total += val + + if reporter is not None: + reporter.add_residue_data( + group_id=group_id, + resname=getattr(res, "resname", "UNK"), + level="united_atom", + entropy_type="Conformational", + frame_count=n_frames, + value=val, + ) + + if reporter is not None: + reporter.add_results_data(group_id, "united_atom", "Conformational", total) + + return total + + def _compute_residue_entropy_for_group( + self, + *, + ce: ConformationalEntropy, + group_id: int, + states_res: Union[Dict[int, StateSequence], Sequence[Optional[StateSequence]]], + n_frames: int, + ) -> float: + """Compute residue-level entropy for a group.""" + group_states = self._get_group_states(states_res, group_id) + return self._entropy_or_zero(ce, group_states, n_frames) + + def _entropy_or_zero( + self, + ce: ConformationalEntropy, + states: Optional[StateSequence], + n_frames: int, + ) -> float: + """Return entropy value or zero if no state data exists.""" + if not self._has_state_data(states): + return 0.0 + return float(ce.conformational_entropy_calculation(states, n_frames)) + + @staticmethod + def _get_group_states( + states_res: Union[Dict[int, StateSequence], Sequence[Optional[StateSequence]]], + group_id: int, + ) -> Optional[StateSequence]: + """Fetch group states from container.""" + if isinstance(states_res, dict): + return states_res.get(group_id) + if group_id < len(states_res): + return states_res[group_id] + return None + + @staticmethod + def _has_state_data(states: Optional[StateSequence]) -> bool: + """Check if state container has usable data.""" + if states is None: + return False + if isinstance(states, np.ndarray): + return bool(np.any(states)) + try: + return any(states) + except TypeError: + return bool(states) diff --git a/CodeEntropy/entropy/nodes/vibrational.py b/CodeEntropy/entropy/nodes/vibrational.py new file mode 
100644 index 00000000..6191e14b --- /dev/null +++ b/CodeEntropy/entropy/nodes/vibrational.py @@ -0,0 +1,431 @@ +"""Node for computing vibrational entropy from covariance matrices.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple + +import numpy as np + +from CodeEntropy.entropy.vibrational import VibrationalEntropy +from CodeEntropy.levels.linalg import MatrixUtils + +logger = logging.getLogger(__name__) + + +GroupId = int +ResidueId = int +CovKey = Tuple[GroupId, ResidueId] + + +@dataclass(frozen=True) +class EntropyPair: + """Container for paired translational and rotational entropy values. + + Attributes: + trans: Translational vibrational entropy value. + rot: Rotational vibrational entropy value. + """ + + trans: float + rot: float + + +class VibrationalEntropyNode: + """Compute vibrational entropy from force/torque (and optional FT) covariances. + + This node reads covariance matrices from a shared data mapping, computes + translational and rotational vibrational entropy at requested hierarchy levels, + and stores results back into the shared data structure. + + The node supports: + - Force and torque covariance matrices ("force" / "torque") at residue/polymer + levels. + - United-atom per-residue covariances keyed by (group_id, residue_id). + - Optional combined force-torque covariance matrices ("forcetorque") for the + highest level when enabled via args.combined_forcetorque. + """ + + def __init__(self) -> None: + """Initialize the node with matrix utilities and numerical tolerances.""" + self._mat_ops = MatrixUtils() + self._zero_atol = 1e-8 + + def run(self, shared_data: MutableMapping[str, Any], **_: Any) -> Dict[str, Any]: + """Run vibrational entropy calculations and update the shared data mapping. + + Args: + shared_data: Mutable mapping containing inputs (covariances, groups, + levels, args, etc.) and where outputs will be written. 
+ **_: Unused keyword arguments, accepted for framework compatibility. + + Returns: + A dict containing the computed vibrational entropy results under the + key "vibrational_entropy". + + Raises: + ValueError: If an unknown level is encountered in the level list for a + representative molecule. + """ + ve = self._build_entropy_engine(shared_data) + temp = shared_data["args"].temperature + + groups = shared_data["groups"] + levels = shared_data["levels"] + fragments = shared_data["reduced_universe"].atoms.fragments + + gid2i = self._get_group_id_to_index(shared_data) + + force_cov = shared_data["force_covariances"] + torque_cov = shared_data["torque_covariances"] + + combined = bool(getattr(shared_data["args"], "combined_forcetorque", False)) + ft_cov = shared_data.get("forcetorque_covariances") if combined else None + + ua_frame_counts = self._get_ua_frame_counts(shared_data) + reporter = shared_data.get("reporter") + + results: Dict[int, Dict[str, Dict[str, float]]] = {} + + for group_id, mol_ids in groups.items(): + results[group_id] = {} + if not mol_ids: + continue + + rep_mol_id = mol_ids[0] + rep_mol = fragments[rep_mol_id] + level_list = levels[rep_mol_id] + + for level in level_list: + highest = level == level_list[-1] + + if level == "united_atom": + pair = self._compute_united_atom_entropy( + ve=ve, + temp=temp, + group_id=group_id, + residues=rep_mol.residues, + force_ua=force_cov["ua"], + torque_ua=torque_cov["ua"], + ua_frame_counts=ua_frame_counts, + reporter=reporter, + n_frames_default=shared_data.get("n_frames", 0), + highest=highest, + ) + self._store_results(results, group_id, level, pair) + self._log_molecule_level_results( + reporter, group_id, level, pair, use_ft_labels=False + ) + continue + + if level in ("residue", "polymer"): + gi = gid2i[group_id] + + if combined and highest and ft_cov is not None: + ft_key = "res" if level == "residue" else "poly" + ftmat = self._get_indexed_matrix(ft_cov.get(ft_key, []), gi) + + pair = 
self._compute_ft_entropy(ve=ve, temp=temp, ftmat=ftmat) + self._store_results(results, group_id, level, pair) + self._log_molecule_level_results( + reporter, group_id, level, pair, use_ft_labels=True + ) + continue + + cov_key = "res" if level == "residue" else "poly" + fmat = self._get_indexed_matrix(force_cov.get(cov_key, []), gi) + tmat = self._get_indexed_matrix(torque_cov.get(cov_key, []), gi) + + pair = self._compute_force_torque_entropy( + ve=ve, + temp=temp, + fmat=fmat, + tmat=tmat, + highest=highest, + ) + self._store_results(results, group_id, level, pair) + self._log_molecule_level_results( + reporter, group_id, level, pair, use_ft_labels=False + ) + continue + + raise ValueError(f"Unknown level: {level}") + + shared_data["vibrational_entropy"] = results + return {"vibrational_entropy": results} + + def _build_entropy_engine( + self, shared_data: Mapping[str, Any] + ) -> VibrationalEntropy: + """Construct the vibrational entropy engine used for calculations. + + Args: + shared_data: Read-only mapping containing a "run_manager" entry. + + Returns: + A configured VibrationalEntropy instance. + """ + return VibrationalEntropy( + run_manager=shared_data["run_manager"], + ) + + def _get_group_id_to_index(self, shared_data: Mapping[str, Any]) -> Dict[int, int]: + """Return a mapping from group_id to contiguous index used by covariance lists. + + If a precomputed mapping is provided under "group_id_to_index", it is used. + Otherwise, the mapping is derived from the insertion order of "groups". + + Args: + shared_data: Read-only mapping containing "groups" and optionally + "group_id_to_index". + + Returns: + Dictionary mapping each group_id to an integer index. 
+ """ + gid2i = shared_data.get("group_id_to_index") + if isinstance(gid2i, dict) and gid2i: + return gid2i + groups = shared_data["groups"] + return {gid: i for i, gid in enumerate(groups.keys())} + + def _get_ua_frame_counts(self, shared_data: Mapping[str, Any]) -> Dict[CovKey, int]: + """Extract per-(group,residue) frame counts for united-atom covariances. + + Args: + shared_data: Read-only mapping which may contain nested frame count data + under shared_data["frame_counts"]["ua"]. + + Returns: + A dict keyed by (group_id, residue_id) containing frame counts. Returns + an empty dict if not present or not well-formed. + """ + counts = shared_data.get("frame_counts", {}) + if isinstance(counts, dict): + ua_counts = counts.get("ua", {}) + if isinstance(ua_counts, dict): + return ua_counts + return {} + + def _compute_united_atom_entropy( + self, + *, + ve: VibrationalEntropy, + temp: float, + group_id: int, + residues: Any, + force_ua: Mapping[CovKey, Any], + torque_ua: Mapping[CovKey, Any], + ua_frame_counts: Mapping[CovKey, int], + reporter: Optional[Any], + n_frames_default: int, + highest: bool, + ) -> EntropyPair: + """Compute total united-atom vibrational entropy for a group's residues. + + Iterates over residues, looks up per-residue force and torque covariance + matrices keyed by (group_id, residue_index), computes entropy contributions, + accumulates totals, and optionally reports per-residue values. + + Args: + ve: VibrationalEntropy calculation engine. + temp: Temperature (K) for entropy calculation. + group_id: Identifier for the group being processed. + residues: Residue container/sequence for the representative molecule. + force_ua: Mapping from (group_id, residue_id) to force covariance matrix. + torque_ua: Mapping from (group_id, residue_id) to torque covariance matrix. + ua_frame_counts: Mapping from (group_id, residue_id) to frame counts. + reporter: Optional reporter object supporting add_residue_data calls. 
+ n_frames_default: Fallback frame count if per-residue count missing. + highest: Whether this computation is at the highest requested level. + + Returns: + EntropyPair with summed translational and rotational entropy across residues + """ + s_trans_total = 0.0 + s_rot_total = 0.0 + + for res_id, res in enumerate(residues): + key = (group_id, res_id) + fmat = force_ua.get(key) + tmat = torque_ua.get(key) + + pair = self._compute_force_torque_entropy( + ve=ve, + temp=temp, + fmat=fmat, + tmat=tmat, + highest=highest, + ) + + s_trans_total += pair.trans + s_rot_total += pair.rot + + if reporter is not None: + frame_count = ua_frame_counts.get(key, int(n_frames_default or 0)) + reporter.add_residue_data( + group_id=group_id, + resname=getattr(res, "resname", "UNK"), + level="united_atom", + entropy_type="Transvibrational", + frame_count=frame_count, + value=pair.trans, + ) + reporter.add_residue_data( + group_id=group_id, + resname=getattr(res, "resname", "UNK"), + level="united_atom", + entropy_type="Rovibrational", + frame_count=frame_count, + value=pair.rot, + ) + + return EntropyPair(trans=float(s_trans_total), rot=float(s_rot_total)) + + def _compute_force_torque_entropy( + self, + *, + ve: VibrationalEntropy, + temp: float, + fmat: Any, + tmat: Any, + highest: bool, + ) -> EntropyPair: + """Compute vibrational entropy from separate force and torque covariances. + + Matrices are filtered to remove (near-)zero rows/columns before computation. + If either matrix is missing or becomes empty after filtering, returns zeros. + + Args: + ve: VibrationalEntropy calculation engine. + temp: Temperature (K) for entropy calculation. + fmat: Force covariance matrix (array-like) or None. + tmat: Torque covariance matrix (array-like) or None. + highest: Whether this computation is at the highest requested level. + + Returns: + EntropyPair containing translational entropy (from force covariance) and + rotational entropy (from torque covariance). 
+ """ + if fmat is None or tmat is None: + return EntropyPair(trans=0.0, rot=0.0) + + f = self._mat_ops.filter_zero_rows_columns( + np.asarray(fmat), atol=self._zero_atol + ) + t = self._mat_ops.filter_zero_rows_columns( + np.asarray(tmat), atol=self._zero_atol + ) + + if f.size == 0 or t.size == 0: + return EntropyPair(trans=0.0, rot=0.0) + + s_trans = ve.vibrational_entropy_calculation( + f, "force", temp, highest_level=highest + ) + s_rot = ve.vibrational_entropy_calculation( + t, "torque", temp, highest_level=highest + ) + return EntropyPair(trans=float(s_trans), rot=float(s_rot)) + + def _compute_ft_entropy( + self, + *, + ve: VibrationalEntropy, + temp: float, + ftmat: Any, + ) -> EntropyPair: + """Compute vibrational entropy from a combined force-torque covariance matrix. + + The combined covariance matrix is filtered to remove (near-)zero rows/columns + before computation. If missing or empty after filtering, returns zeros. + + Args: + ve: VibrationalEntropy calculation engine. + temp: Temperature (K) for entropy calculation. + ftmat: Combined force-torque covariance matrix (array-like) or None. + + Returns: + EntropyPair containing translational and rotational entropy values derived + from the combined covariance matrix. + """ + if ftmat is None: + return EntropyPair(trans=0.0, rot=0.0) + + ft = self._mat_ops.filter_zero_rows_columns( + np.asarray(ftmat), atol=self._zero_atol + ) + if ft.size == 0: + return EntropyPair(trans=0.0, rot=0.0) + + s_trans = ve.vibrational_entropy_calculation( + ft, "forcetorqueTRANS", temp, highest_level=True + ) + s_rot = ve.vibrational_entropy_calculation( + ft, "forcetorqueROT", temp, highest_level=True + ) + return EntropyPair(trans=float(s_trans), rot=float(s_rot)) + + @staticmethod + def _store_results( + results: Dict[int, Dict[str, Dict[str, float]]], + group_id: int, + level: str, + pair: EntropyPair, + ) -> None: + """Store entropy results for a group/level into the results structure. 
+ + Args: + results: Nested results dict indexed by group_id then level. + group_id: Group identifier to store under. + level: Hierarchy level name (e.g., "united_atom", "residue", "polymer"). + pair: EntropyPair containing translational and rotational values. + """ + results[group_id][level] = {"trans": pair.trans, "rot": pair.rot} + + @staticmethod + def _log_molecule_level_results( + reporter: Optional[Any], + group_id: int, + level: str, + pair: EntropyPair, + *, + use_ft_labels: bool, + ) -> None: + """Log molecule-level entropy results to the reporter, if available. + + Args: + reporter: Optional reporter object supporting add_results_data calls. + group_id: Group identifier being reported. + level: Hierarchy level name being reported. + pair: EntropyPair containing translational and rotational values. + use_ft_labels: Whether to use FT-specific labels for the entropy types. + """ + if reporter is None: + return + + if use_ft_labels: + reporter.add_results_data( + group_id, level, "FTmat-Transvibrational", pair.trans + ) + reporter.add_results_data(group_id, level, "FTmat-Rovibrational", pair.rot) + return + + reporter.add_results_data(group_id, level, "Transvibrational", pair.trans) + reporter.add_results_data(group_id, level, "Rovibrational", pair.rot) + + @staticmethod + def _get_indexed_matrix(mats: Any, index: int) -> Any: + """Safely retrieve mats[index] if mats is indexable and index is in range. + + Args: + mats: Indexable container of matrices (e.g., list/tuple) or other object. + index: Desired index. + + Returns: + The matrix at the given index if available; otherwise None. + """ + try: + return mats[index] if index < len(mats) else None + except TypeError: + return None diff --git a/CodeEntropy/entropy/orientational.py b/CodeEntropy/entropy/orientational.py new file mode 100644 index 00000000..bd4786e4 --- /dev/null +++ b/CodeEntropy/entropy/orientational.py @@ -0,0 +1,158 @@ +"""Orientational entropy calculations. 
+ +This module defines `OrientationalEntropy`, which computes orientational entropy +from a neighbor-count mapping. + +The current implementation supports non-water neighbors. Water-specific behavior +can be implemented later behind an interface so the core calculation remains +stable and testable. +""" + +from __future__ import annotations + +import logging +import math +from dataclasses import dataclass +from typing import Any, Mapping + +import numpy as np + +logger = logging.getLogger(__name__) + +_GAS_CONST_J_PER_MOL_K = 8.3144598484848 + + +@dataclass(frozen=True) +class OrientationalEntropyResult: + """Result of an orientational entropy calculation. + + Attributes: + total: Total orientational entropy (J/mol/K). + """ + + total: float + + +class OrientationalEntropy: + """Compute orientational entropy from neighbor counts. + + This class is intentionally small and focused: it provides a single public + method that converts a mapping of neighbor species to neighbor counts into + an orientational entropy value. + + Notes: + The manager-like constructor signature is kept for compatibility with + the rest of the codebase, but the calculation itself does not depend on + those objects. + """ + + def __init__( + self, + run_manager: Any, + args: Any, + universe: Any, + reporter: Any, + group_molecules: Any, + gas_constant: float = _GAS_CONST_J_PER_MOL_K, + ) -> None: + """Initialize the orientational entropy calculator. + + Args: + run_manager: Run manager (currently unused by this class). + args: User arguments (currently unused by this class). + universe: MDAnalysis Universe (currently unused by this class). + reporter: Data logger (currently unused by this class). + group_molecules: Grouping helper (currently unused by this class). + gas_constant: Gas constant in J/(mol*K). 
+ """ + self._run_manager = run_manager + self._args = args + self._universe = universe + self._reporter = reporter + self._group_molecules = group_molecules + self._gas_constant = float(gas_constant) + + def calculate(self, neighbours: Mapping[str, int]) -> OrientationalEntropyResult: + """Calculate orientational entropy from neighbor counts. + + For each neighbor species (except water), the number of orientations is + estimated as: + + Ω = sqrt(Nc^3 * π) + + and the entropy contribution is: + + S = R * ln(Ω) + + where Nc is the neighbor count and R is the gas constant. + + Args: + neighbours: Mapping of neighbor species name to count. + + Returns: + OrientationalEntropyResult containing the total entropy in J/mol/K. + """ + total = 0.0 + for species, count in neighbours.items(): + if self._is_water(species): + logger.debug( + "Skipping water species %s in orientational entropy.", species + ) + continue + + contribution = self._entropy_contribution(count) + logger.debug( + "Orientational entropy contribution for %s: %s", species, contribution + ) + total += contribution + + logger.debug("Final orientational entropy total: %s", total) + return OrientationalEntropyResult(total=float(total)) + + @staticmethod + def _is_water(species: str) -> bool: + """Return True if the species should be treated as water. + + Args: + species: Species identifier. + + Returns: + True if the species is considered water. + """ + return species in {"H2O", "WAT", "HOH"} + + def _entropy_contribution(self, neighbour_count: int) -> float: + """Compute the entropy contribution for a single neighbor count. + + Args: + neighbour_count: Number of neighbors (Nc). + + Returns: + Entropy contribution in J/mol/K. + + Raises: + ValueError: If neighbour_count is negative. 
+ """ + if neighbour_count < 0: + raise ValueError(f"neighbour_count must be >= 0, got {neighbour_count}") + + if neighbour_count == 0: + return 0.0 + + omega = self._omega(neighbour_count) + if omega <= 0.0: + return 0.0 + + return self._gas_constant * math.log(omega) + + @staticmethod + def _omega(neighbour_count: int) -> float: + """Compute the number of orientations Ω. + + Args: + neighbour_count: Number of neighbors (Nc). + + Returns: + Ω (unitless). + """ + return float(np.sqrt((neighbour_count**3) * math.pi)) diff --git a/CodeEntropy/entropy/vibrational.py b/CodeEntropy/entropy/vibrational.py new file mode 100644 index 00000000..f1baf132 --- /dev/null +++ b/CodeEntropy/entropy/vibrational.py @@ -0,0 +1,252 @@ +"""Vibrational entropy calculations. + +This module provides `VibrationalEntropy`, which computes vibrational entropy +from force, torque, or combined force-torque covariance matrices. + +The implementation is intentionally split into small, single-purpose methods: +- Eigenvalue extraction + unit conversion +- Frequency calculation with robust filtering +- Entropy component computation +- Mode selection / summation rules based on matrix type +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Literal, Tuple + +import numpy as np +from numpy import linalg as la + +logger = logging.getLogger(__name__) + +MatrixType = Literal["force", "torque", "forcetorqueTRANS", "forcetorqueROT"] + + +@dataclass(frozen=True) +class VibrationalEntropyResult: + """Result of a vibrational entropy computation. + + Attributes: + total: Computed entropy value (J/mol/K) for the requested matrix type. + n_modes: Number of vibrational modes used (after filtering eigenvalues). + """ + + total: float + n_modes: int + + +class VibrationalEntropy: + """Compute vibrational entropy from covariance matrices. 
+ + This class focuses only on vibrational entropy math and relies on `run_manager` + for unit conversions (eigenvalue unit conversion and kT conversion). + """ + + def __init__( + self, + run_manager: Any, + planck_const: float = 6.62607004081818e-34, + gas_const: float = 8.3144598484848, + ) -> None: + """Initialize the vibrational entropy calculator. + + Args: + run_manager: Provides thermodynamic conversions (e.g., kT in Joules) + and eigenvalue unit conversion. + planck_const: Planck constant (J*s). + gas_const: Gas constant (J/(mol*K)). + """ + self._run_manager = run_manager + self._planck_const = float(planck_const) + self._gas_const = float(gas_const) + + def vibrational_entropy_calculation( + self, + matrix: np.ndarray, + matrix_type: MatrixType, + temp: float, + highest_level: bool, + ) -> float: + """Compute vibrational entropy for the given covariance matrix. + + Supported matrix types: + - "force": 3N x 3N force covariance. + - "torque": 3N x 3N torque covariance. + - "forcetorqueTRANS": 6N x 6N combined covariance (translational part). + - "forcetorqueROT": 6N x 6N combined covariance (rotational part). + + Mode handling: + - Frequencies are computed from eigenvalues, filtered to valid values, + then sorted ascending. + - For "force": + - If highest_level, include all modes. + - Otherwise, drop the lowest 6 modes. + - For "torque": include all modes. + - For combined "forcetorque*": + - Split the sorted spectrum into two halves (first 3N, last 3N). + - If not highest_level, drop the lowest 6 modes only within the + translational half. + + Args: + matrix: Covariance matrix (shape depends on matrix_type). + matrix_type: Type of covariance matrix. + temp: Temperature in Kelvin. + highest_level: Whether this is the highest level in the hierarchy. + + Returns: + Vibrational entropy value in J/mol/K. + + Raises: + ValueError: If matrix_type is unknown. 
+ """ + components = self._entropy_components(matrix, temp) + total = self._sum_components(components, matrix_type, highest_level) + return float(total) + + def _entropy_components(self, matrix: np.ndarray, temp: float) -> np.ndarray: + """Compute per-mode entropy components from a covariance matrix. + + Args: + matrix: Covariance matrix. + temp: Temperature in Kelvin. + + Returns: + Array of entropy components (J/mol/K) for each valid mode. + """ + lambdas = self._matrix_eigenvalues(matrix) + lambdas = self._convert_lambda_units(lambdas) + + freqs = self._frequencies_from_lambdas(lambdas, temp) + if freqs.size == 0: + return np.array([], dtype=float) + + freqs = np.sort(freqs) + return self._entropy_components_from_frequencies(freqs, temp) + + @staticmethod + def _matrix_eigenvalues(matrix: np.ndarray) -> np.ndarray: + """Compute eigenvalues of a matrix. + + Args: + matrix: Input matrix. + + Returns: + Eigenvalues as a NumPy array. + """ + matrix = np.asarray(matrix, dtype=float) + return la.eigvals(matrix) + + def _convert_lambda_units(self, lambdas: np.ndarray) -> np.ndarray: + """Convert eigenvalues into SI units using run_manager. + + Args: + lambdas: Eigenvalues. + + Returns: + Converted eigenvalues. + """ + return self._run_manager.change_lambda_units(lambdas) + + def _frequencies_from_lambdas(self, lambdas: np.ndarray, temp: float) -> np.ndarray: + """Convert eigenvalues to frequencies with robust filtering. + + Filters out eigenvalues that are complex, non-positive, or near-zero to + avoid invalid frequencies and unstable entropies. + + Args: + lambdas: Eigenvalues (post unit conversion). + temp: Temperature in Kelvin. + + Returns: + Frequencies in Hz. 
+ """ + lambdas = np.asarray(lambdas) + lambdas = np.real_if_close(lambdas, tol=1000) + + valid_mask = ( + np.isreal(lambdas) & (lambdas > 0) & (~np.isclose(lambdas, 0, atol=1e-7)) + ) + + removed = int(len(lambdas) - np.count_nonzero(valid_mask)) + if removed: + logger.warning( + "%d invalid eigenvalues excluded (complex, non-positive, " + "or near-zero).", + removed, + ) + + lambdas = np.asarray(lambdas[valid_mask].real, dtype=float) + if lambdas.size == 0: + return np.array([], dtype=float) + + kT = float(self._run_manager.get_KT2J(temp)) + pi = float(np.pi) + return (1.0 / (2.0 * pi)) * np.sqrt(lambdas / kT) + + def _entropy_components_from_frequencies( + self, frequencies: np.ndarray, temp: float + ) -> np.ndarray: + """Compute per-mode entropy components from frequencies. + + Args: + frequencies: Frequencies (Hz), sorted ascending. + temp: Temperature in Kelvin. + + Returns: + Per-mode entropy components in J/mol/K. + """ + kT = float(self._run_manager.get_KT2J(temp)) + exponent = (self._planck_const * frequencies) / kT + + exp_pos = np.exp(exponent) + exp_neg = np.exp(-exponent) + + components = exponent / (exp_pos - 1.0) - np.log(1.0 - exp_neg) + return components * self._gas_const + + @staticmethod + def _split_halves(components: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Split a component array into two equal halves. + + Args: + components: Array with an even length. + + Returns: + Tuple of (first_half, second_half). If odd-length, returns + (components, empty). + + Notes: + For combined force-torque matrices (6N x 6N), the valid number of modes + should be 6N. After sorting, we split into two halves of size 3N. 
+ """ + n = int(components.size) + if n % 2 != 0: + return components, np.array([], dtype=float) + half = n // 2 + return components[:half], components[half:] + + def _sum_components( + self, + components: np.ndarray, + matrix_type: MatrixType, + highest_level: bool, + ) -> float: + if components.size == 0: + return 0.0 + + if matrix_type == "force": + return float( + np.sum(components) if highest_level else np.sum(components[6:]) + ) + + if matrix_type == "torque": + return float(np.sum(components)) + + if matrix_type in ("forcetorqueTRANS", "forcetorqueROT"): + if matrix_type == "forcetorqueTRANS": + return float(np.sum(components[:3])) + return float(np.sum(components[3:])) + + raise ValueError(f"Unknown matrix_type: {matrix_type}") diff --git a/CodeEntropy/entropy/water.py b/CodeEntropy/entropy/water.py new file mode 100644 index 00000000..5be7d13c --- /dev/null +++ b/CodeEntropy/entropy/water.py @@ -0,0 +1,243 @@ +"""Water entropy aggregation. + +This module wraps the waterEntropy routines and maps their +outputs into the project `ResultsReporter` format. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Callable, Mapping, Optional, Tuple + +import numpy as np +import waterEntropy.recipes.interfacial_solvent as GetSolvent + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class WaterEntropyInput: + """Inputs for water entropy computation. + + Attributes: + universe: MDAnalysis Universe containing the system. + start: Start frame index (inclusive). + end: End frame index (exclusive, or -1 depending on caller convention). + step: Frame stride. + temperature: Temperature in Kelvin. + group_id: Group ID used for logging. + """ + + universe: Any + start: int + end: int + step: int + temperature: float + group_id: Optional[int] = None + + +class WaterEntropy: + """Compute and log water entropy contributions. 
+ + This class calls the external `waterEntropy` routine to compute: + - orientational entropy per residue + - translational vibrational entropy + - rotational vibrational entropy + + Then it logs residue-level entries and adds a group label. + """ + + def __init__( + self, + args: Any, + reporter: Any, + solver: Callable[..., Tuple[dict, Any, Any, Any, Any]] = ( + GetSolvent.get_interfacial_water_orient_entropy + ), + ) -> None: + """Initialize the water entropy calculator. + + Args: + args: Argument namespace; must include `temperature`. + reporter: Logger used to record residue and group results. + solver: Callable compatible with + `get_interfacial_water_orient_entropy + (universe, start, end, step, temperature, parallel=True)`. + Dependency injection allows unit testing without the external package. + """ + self._args = args + self._reporter = reporter + self._solver = solver + + def calculate_and_log( + self, + universe: Any, + start: int, + end: int, + step: int, + group_id: Optional[int] = None, + ) -> None: + """Compute water entropy and write results to the data logger. + + Args: + universe: MDAnalysis Universe containing water. + start: Start frame index. + end: End frame index. + step: Frame stride. + group_id: Group ID to assign all water contributions to. 
+ """ + inputs = WaterEntropyInput( + universe=universe, + start=start, + end=end, + step=step, + temperature=float(self._args.temperature), + group_id=group_id, + ) + self._calculate_and_log_from_inputs(inputs) + + def _calculate_and_log_from_inputs(self, inputs: WaterEntropyInput) -> None: + """Run the solver and log all returned entropy components.""" + Sorient_dict, covariances, vibrations, _unused, _water_count = self._run_solver( + inputs + ) + + self._log_orientational_entropy(Sorient_dict, inputs.group_id) + self._log_translational_entropy(vibrations, covariances, inputs.group_id) + self._log_rotational_entropy(vibrations, covariances, inputs.group_id) + self._log_group_label(inputs.universe, Sorient_dict, inputs.group_id) + + def _run_solver(self, inputs: WaterEntropyInput): + """Call the external solver. + + Args: + inputs: WaterEntropyInput. + + Returns: + Tuple of solver outputs. + """ + logger.info( + "[WaterEntropy] Computing water entropy (start=%s, end=%s, step=%s)", + inputs.start, + inputs.end, + inputs.step, + ) + return self._solver( + inputs.universe, + inputs.start, + inputs.end, + inputs.step, + inputs.temperature, + parallel=True, + ) + + def _log_orientational_entropy( + self, Sorient_dict: Mapping[Any, Mapping[str, Any]], group_id: Optional[int] + ) -> None: + """Log orientational entropy entries. + + Args: + Sorient_dict: Mapping of residue ids to {resname: [entropy, count]}. + group_id: Group ID to assign logs to. + """ + for _resid, resname_dict in Sorient_dict.items(): + for resname, values in resname_dict.items(): + if isinstance(values, list) and len(values) == 2: + entropy, count = values + self._reporter.add_residue_data( + group_id, resname, "Water", "Orientational", count, entropy + ) + + def _log_translational_entropy( + self, vibrations: Any, covariances: Any, group_id: Optional[int] + ) -> None: + """Log translational vibrational entropy entries. + + Args: + vibrations: Solver vibrations object with `translational_S`. 
+ covariances: Solver covariances object with `counts`. + group_id: Group ID to assign logs to. + """ + translational = getattr(vibrations, "translational_S", {}) or {} + counts = getattr(covariances, "counts", {}) or {} + + for (solute_id, _), entropy in translational.items(): + value = ( + float(np.sum(entropy)) + if isinstance(entropy, (list, np.ndarray)) + else float(entropy) + ) + count = counts.get((solute_id, "WAT"), 1) + resname = self._solute_id_to_resname(solute_id) + self._reporter.add_residue_data( + group_id, resname, "Water", "Transvibrational", count, value + ) + + def _log_rotational_entropy( + self, vibrations: Any, covariances: Any, group_id: Optional[int] + ) -> None: + """Log rotational vibrational entropy entries. + + Args: + vibrations: Solver vibrations object with `rotational_S`. + covariances: Solver covariances object with `counts`. + group_id: Group ID to assign logs to. + """ + rotational = getattr(vibrations, "rotational_S", {}) or {} + counts = getattr(covariances, "counts", {}) or {} + + for (solute_id, _), entropy in rotational.items(): + value = ( + float(np.sum(entropy)) + if isinstance(entropy, (list, np.ndarray)) + else float(entropy) + ) + count = counts.get((solute_id, "WAT"), 1) + resname = self._solute_id_to_resname(solute_id) + self._reporter.add_residue_data( + group_id, resname, "Water", "Rovibrational", count, value + ) + + def _log_group_label( + self, + universe: Any, + Sorient_dict: Mapping[Any, Mapping[str, Any]], + group_id: Optional[int], + ) -> None: + """Log a group label summarizing the water entries. + + Args: + universe: MDAnalysis Universe. + Sorient_dict: Orientational entropy dict used to infer residue names. + group_id: Group ID. 
+ """ + water_selection = universe.select_atoms("resname WAT") + actual_water_residues = len(water_selection.residues) + + water_resnames = set(water_selection.residues.resnames) + residue_names = { + resname + for res_dict in Sorient_dict.values() + for resname in res_dict.keys() + if str(resname).upper() in {str(r).upper() for r in water_resnames} + } + + residue_group = "_".join(sorted(residue_names)) if residue_names else "WAT" + self._reporter.add_group_label( + group_id, residue_group, actual_water_residues, len(water_selection.atoms) + ) + + @staticmethod + def _solute_id_to_resname(solute_id: str) -> str: + """Convert a solver solute_id to a residue-like name. + + Args: + solute_id: Identifier returned by the solver. + + Returns: + Residue name string. + """ + if "_" in str(solute_id): + return str(solute_id).rsplit("_", 1)[0] + return str(solute_id) diff --git a/CodeEntropy/entropy/workflow.py b/CodeEntropy/entropy/workflow.py new file mode 100644 index 00000000..c5a76652 --- /dev/null +++ b/CodeEntropy/entropy/workflow.py @@ -0,0 +1,359 @@ +"""Entropy manager orchestration. + +This module defines `EntropyWorkflow`, which coordinates the end-to-end entropy +workflow: + * Determine trajectory bounds and frame count. + * Build a reduced universe based on atom selection. + * Identify molecule groups and hierarchy levels. + * Optionally compute water entropy and adjust selection. + * Execute the level DAG (matrix/state preparation). + * Execute the entropy graph (entropy calculations and aggregation). + * Finalize and persist results. + +The manager intentionally delegates calculations to dedicated components. 
+""" + +from __future__ import annotations + +import logging +import math +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, Mapping, Tuple + +import pandas as pd + +from CodeEntropy.core.logging import LoggingConfig +from CodeEntropy.entropy.graph import EntropyGraph +from CodeEntropy.entropy.water import WaterEntropy +from CodeEntropy.levels.hierarchy import HierarchyBuilder +from CodeEntropy.levels.level_dag import LevelDAG + +logger = logging.getLogger(__name__) +console = LoggingConfig.get_console() + +SharedData = Dict[str, Any] + + +@dataclass(frozen=True) +class TrajectorySlice: + """Trajectory slicing parameters. + + Attributes: + start: Inclusive start frame index. + end: Exclusive end frame index (or a concrete index derived from args). + step: Step size between frames. + n_frames: Number of frames in the slice. + """ + + start: int + end: int + step: int + n_frames: int + + +class EntropyWorkflow: + """Coordinate entropy calculations across structural levels. + + This class is responsible for orchestration and IO-level concerns (selection, + grouping, running graphs, and finalizing results). Domain calculations live in + dedicated components (LevelDAG, EntropyGraph, WaterEntropy, etc.). + """ + + def __init__( + self, + run_manager: Any, + args: Any, + universe: Any, + reporter: Any, + group_molecules: Any, + dihedral_analysis: Any, + universe_operations: Any, + ) -> None: + """Initialize the entropy workflow manager. + + Args: + run_manager: Manager for universe IO and unit conversions. + args: Parsed CLI/user arguments. + universe: MDAnalysis Universe representing the simulation system. + reporter: Collector for per-molecule and per-residue outputs. + group_molecules: Component that groups molecules for averaging. + dihedral_analysis: Component used to compute conformational states. + (Stored for completeness; computation is typically triggered by nodes.) 
+ universe_operations: Adapter providing common universe operations. + """ + self._run_manager = run_manager + self._args = args + self._universe = universe + self._reporter = reporter + self._group_molecules = group_molecules + self._dihedral_analysis = dihedral_analysis + self._universe_operations = universe_operations + + def execute(self) -> None: + """Run the full entropy workflow and emit results. + + This method orchestrates the complete pipeline, populates shared data, + and triggers the DAG/graph executions. Final results are logged and saved + via `ResultsReporter`. + """ + traj = self._build_trajectory_slice() + console.print( + f"Analyzing a total of {traj.n_frames} frames in this calculation." + ) + + reduced_universe = self._build_reduced_universe() + + levels = self._detect_levels(reduced_universe) + groups = self._group_molecules.grouping_molecules( + reduced_universe, self._args.grouping + ) + + nonwater_groups, water_groups = self._split_water_groups(groups) + + if self._args.water_entropy and water_groups: + self._compute_water_entropy(traj, water_groups) + else: + nonwater_groups.update(water_groups) + + shared_data = self._build_shared_data( + reduced_universe=reduced_universe, + levels=levels, + groups=nonwater_groups, + traj=traj, + ) + + with self._reporter.progress(transient=False) as p: + self._run_level_dag(shared_data, progress=p) + self._run_entropy_graph(shared_data, progress=p) + + self._finalize_molecule_results() + self._reporter.log_tables() + + def _build_shared_data( + self, + reduced_universe: Any, + levels: Any, + groups: Mapping[int, Any], + traj: TrajectorySlice, + ) -> SharedData: + """Build the shared_data dict used by nodes and graphs. + + Args: + reduced_universe: Universe after applying selection. + levels: Level definition per molecule id. + groups: Mapping of group id -> list of molecule ids. + traj: Trajectory slice parameters. + + Returns: + Shared data dictionary for DAG/graph execution. 
+ """ + shared_data: SharedData = { + "entropy_manager": self, + "run_manager": self._run_manager, + "reporter": self._reporter, + "args": self._args, + "universe": self._universe, + "reduced_universe": reduced_universe, + "levels": levels, + "groups": dict(groups), + "start": traj.start, + "end": traj.end, + "step": traj.step, + "n_frames": traj.n_frames, + } + return shared_data + + def _run_level_dag( + self, shared_data: SharedData, *, progress: object | None = None + ) -> None: + """Execute the structural/level DAG. + + Args: + shared_data: Shared data dict that will be mutated by the DAG. + progress: Optional progress sink provided by ResultsReporter.progress(). + """ + LevelDAG(self._universe_operations).build().execute( + shared_data, progress=progress + ) + + def _run_entropy_graph( + self, shared_data: SharedData, *, progress: object | None = None + ) -> None: + """Execute the entropy calculation graph and merge results into shared_data. + + Args: + shared_data: Shared data dict that will be mutated by the graph. + progress: Optional progress sink provided by ResultsReporter.progress(). + """ + entropy_results = EntropyGraph().build().execute(shared_data, progress=progress) + shared_data.update(entropy_results) + + def _build_trajectory_slice(self) -> TrajectorySlice: + """Compute trajectory slicing parameters from args. + + Returns: + A TrajectorySlice describing the frames to analyze. + """ + start, end, step = self._get_trajectory_bounds() + n_frames = self._get_number_frames(start, end, step) + return TrajectorySlice(start=start, end=end, step=step, n_frames=n_frames) + + def _get_trajectory_bounds(self) -> Tuple[int, int, int]: + """Return start, end, and step frame indices from args. + + Returns: + Tuple of (start, end, step). 
+ """ + start = self._args.start or 0 + end = len(self._universe.trajectory) if self._args.end == -1 else self._args.end + step = self._args.step or 1 + return start, end, step + + def _get_number_frames(self, start: int, end: int, step: int) -> int: + """Compute the number of frames in a trajectory slice. + + Args: + start: Inclusive start frame index. + end: Exclusive end frame index. + step: Step between frames. + + Returns: + Number of frames processed. + """ + return math.floor((end - start) / step) + + def _build_reduced_universe(self) -> Any: + """Apply atom selection and return the reduced universe. + + If `selection_string` is "all", the original universe is returned. + + Returns: + MDAnalysis Universe (original or reduced). + """ + selection = self._args.selection_string + if selection == "all": + return self._universe + + reduced = self._universe_operations.select_atoms(self._universe, selection) + name = f"{len(reduced.trajectory)}_frame_dump_atom_selection" + self._run_manager.write_universe(reduced, name) + return reduced + + def _detect_levels(self, reduced_universe: Any) -> Any: + """Detect hierarchy levels for each molecule in the reduced universe. + + Args: + reduced_universe: Reduced MDAnalysis Universe. + + Returns: + Levels structure as returned by `HierarchyBuilder.select_levels`. + """ + level_hierarchy = HierarchyBuilder() + _number_molecules, levels = level_hierarchy.select_levels(reduced_universe) + return levels + + def _split_water_groups( + self, groups: Mapping[int, Any] + ) -> Tuple[Dict[int, Any], Dict[int, Any]]: + """Partition molecule groups into water and non-water groups. + + Args: + groups: Mapping of group id -> molecule ids. + + Returns: + Tuple of (nonwater_groups, water_groups). 
+ """ + water_atoms = self._universe.select_atoms("water") + water_resids = {res.resid for res in water_atoms.residues} + + water_groups = { + gid: mol_ids + for gid, mol_ids in groups.items() + if any( + res.resid in water_resids + for mol in [self._universe.atoms.fragments[i] for i in mol_ids] + for res in mol.residues + ) + } + nonwater_groups = { + gid: g for gid, g in groups.items() if gid not in water_groups + } + return nonwater_groups, water_groups + + def _compute_water_entropy( + self, traj: TrajectorySlice, water_groups: Mapping[int, Any] + ) -> None: + """Compute water entropy for each water group and adjust selection string. + + Args: + traj: Trajectory slice parameters. + water_groups: Mapping of group id -> molecule ids for waters. + """ + if not water_groups or not self._args.water_entropy: + return + + water_entropy = WaterEntropy(self._args) + + for group_id in water_groups.keys(): + water_entropy._calculate_water_entropy( + universe=self._universe, + start=traj.start, + end=traj.end, + step=traj.step, + group_id=group_id, + ) + + self._args.selection_string = ( + f"{self._args.selection_string} and not water" + if self._args.selection_string != "all" + else "not water" + ) + + logger.debug("WaterEntropy: molecule_data=%s", self._reporter.molecule_data) + logger.debug("WaterEntropy: residue_data=%s", self._reporter.residue_data) + + def _finalize_molecule_results(self) -> None: + """Aggregate group totals and persist results to JSON. + + Computes total entropy per group and appends "Group Total" rows to the + molecule results table, then writes molecule and residue tables to the + configured output file via the data logger. 
+ """ + entropy_by_group = defaultdict(float) + + for group_id, level, _etype, result in self._reporter.molecule_data: + if level == "Group Total": + continue + try: + entropy_by_group[group_id] += float(result) + except (TypeError, ValueError): + logger.warning("Skipping invalid entry: %s, %s", group_id, result) + + for group_id, total in entropy_by_group.items(): + self._reporter.molecule_data.append( + (group_id, "Group Total", "Group Total Entropy", total) + ) + + molecule_df = pd.DataFrame( + self._reporter.molecule_data, + columns=["Group ID", "Level", "Type", "Result (J/mol/K)"], + ) + residue_df = pd.DataFrame( + self._reporter.residue_data, + columns=[ + "Group ID", + "Residue Name", + "Level", + "Type", + "Frame Count", + "Result (J/mol/K)", + ], + ) + self._reporter.save_dataframes_as_json( + molecule_df, + residue_df, + self._args.output_file, + args=self._args, + include_raw_tables=False, + ) diff --git a/CodeEntropy/group_molecules.py b/CodeEntropy/group_molecules.py deleted file mode 100644 index 417f293e..00000000 --- a/CodeEntropy/group_molecules.py +++ /dev/null @@ -1,106 +0,0 @@ -import logging - -logger = logging.getLogger(__name__) - - -class GroupMolecules: - """ - Groups molecules for averaging. - """ - - def __init__(self): - """ - Initializes the class with relevant information. - - """ - self._molecule_groups = None - - def grouping_molecules(self, universe, grouping): - """ - Grouping molecules by desired level of detail. - - Args: - universe: MDAnalysis univers object for the system of interest. - grouping (str): how to group molecules for averaging - - Returns: - molecule_groups (dict): molecule indices for each group. 
- """ - - molecule_groups = {} - - if grouping == "each": - molecule_groups = self._by_none(universe) - - if grouping == "molecules": - molecule_groups = self._by_molecules(universe) - - number_groups = len(molecule_groups) - - logger.info(f"Number of molecule groups: {number_groups}") - logger.debug(f"Molecule groups are: {molecule_groups}") - - return molecule_groups - - def _by_none(self, universe): - """ - Don't group molecules. Every molecule is in its own group. - - Args: - universe: MDAnalysis universe - - Returns: - molecule_groups (dict): molecule indices for each group. - """ - - # fragments is MDAnalysis terminology for molecules - number_molecules = len(universe.atoms.fragments) - - molecule_groups = {} - - for molecule_i in range(number_molecules): - molecule_groups[molecule_i] = [molecule_i] - - return molecule_groups - - def _by_molecules(self, universe): - """ - Group molecules by chemical type. - Based on number of atoms and atom names. - - Args: - universe: MDAnalysis universe - - Returns: - molecule_groups (dict): molecule indices for each group. - """ - - # fragments is MDAnalysis terminology for molecules - number_molecules = len(universe.atoms.fragments) - fragments = universe.atoms.fragments - - molecule_groups = {} - - for molecule_i in range(number_molecules): - names_i = fragments[molecule_i].names - number_atoms_i = len(names_i) - - for molecule_j in range(number_molecules): - names_j = fragments[molecule_j].names - number_atoms_j = len(names_j) - - # If molecule_i has the same number of atoms and same - # atom names as molecule_j, then index i is added to group j - # The index of molecule_j is the group key, the keys are - # all integers, but may not be consecutive numbers. 
- if number_atoms_i == number_atoms_j and all( - i == j for i, j in zip(names_i, names_j) - ): - if molecule_j in molecule_groups.keys(): - molecule_groups[molecule_j].append(molecule_i) - else: - molecule_groups[molecule_j] = [] - molecule_groups[molecule_j].append(molecule_i) - break - - return molecule_groups diff --git a/CodeEntropy/levels.py b/CodeEntropy/levels.py deleted file mode 100644 index 14262459..00000000 --- a/CodeEntropy/levels.py +++ /dev/null @@ -1,932 +0,0 @@ -import logging - -import numpy as np -from MDAnalysis.lib.mdamath import make_whole -from rich.progress import ( - BarColumn, - Progress, - SpinnerColumn, - TextColumn, - TimeElapsedColumn, -) - -from CodeEntropy.axes import AxesManager - -logger = logging.getLogger(__name__) - - -class LevelManager: - """ - Manages the structural and dynamic levels involved in entropy calculations. This - includes selecting relevant levels, computing axes for translation and rotation, - and handling bead-based representations of molecular systems. Provides utility - methods to extract averaged positions, convert coordinates to spherical systems, - compute weighted forces and torques, and manipulate matrices used in entropy - analysis. - """ - - def __init__(self, universe_operations): - """ - Initializes the LevelManager with placeholders for level-related data, - including translational and rotational axes, number of beads, and a - general-purpose data container. - """ - self.data_container = None - self._levels = None - self._trans_axes = None - self._rot_axes = None - self._number_of_beads = None - self._universe_operations = universe_operations - - def select_levels(self, data_container): - """ - Function to read input system and identify the number of molecules and - the levels (i.e. united atom, residue and/or polymer) that should be used. - The level refers to the size of the bead (atom or collection of atoms) - that will be used in the entropy calculations. 
- - Args: - arg_DataContainer: MDAnalysis universe object containing the system of - interest - - Returns: - number_molecules (int): Number of molecules in the system. - levels (array): Strings describing the length scales for each molecule. - """ - - # fragments is MDAnalysis terminology for what chemists would call molecules - number_molecules = len(data_container.atoms.fragments) - logger.debug(f"The number of molecules is {number_molecules}.") - - fragments = data_container.atoms.fragments - levels = [[] for _ in range(number_molecules)] - - for molecule in range(number_molecules): - levels[molecule].append( - "united_atom" - ) # every molecule has at least one atom - - atoms_in_fragment = fragments[molecule].select_atoms("prop mass > 1.1") - number_residues = len(atoms_in_fragment.residues) - - if len(atoms_in_fragment) > 1: - levels[molecule].append("residue") - - if number_residues > 1: - levels[molecule].append("polymer") - - logger.debug(f"levels {levels}") - - return number_molecules, levels - - def get_matrices( - self, - data_container, - level, - highest_level, - force_matrix, - torque_matrix, - force_partitioning, - customised_axes, - ): - """ - Compute and accumulate force/torque covariance matrices for a given level. - - Parameters: - data_container (MDAnalysis.Universe): Data for a molecule or residue. - level (str): 'polymer', 'residue', or 'united_atom'. - highest_level (bool): Whether this is the top (largest bead size) level. - force_matrix, torque_matrix (np.ndarray or None): Accumulated matrices to add - to. - force_partitioning (float): Factor to adjust force contributions, - default is 0.5. - customised_axes (bool): Whether to use customised axes for rotating - forces. - - Returns: - force_matrix (np.ndarray): Accumulated force covariance matrix. - torque_matrix (np.ndarray): Accumulated torque covariance matrix. 
- """ - - # Make beads - list_of_beads = self.get_beads(data_container, level) - - # number of beads and frames in trajectory - number_beads = len(list_of_beads) - - # initialize force and torque arrays - weighted_forces = [None for _ in range(number_beads)] - weighted_torques = [None for _ in range(number_beads)] - - # Calculate forces/torques for each bead - for bead_index in range(number_beads): - bead = list_of_beads[bead_index] - - # Set up axes - # translation and rotation use different axes - # how the axes are defined depends on the level - axes_manager = AxesManager() - if level == "united_atom" and customised_axes: - trans_axes, rot_axes, center, moment_of_inertia = ( - axes_manager.get_UA_axes(data_container, bead_index) - ) - elif level == "residue" and customised_axes: - trans_axes, rot_axes, center, moment_of_inertia = ( - axes_manager.get_residue_axes(data_container, bead_index) - ) - else: - make_whole(data_container.atoms) - make_whole(bead) - trans_axes = data_container.atoms.principal_axes() - rot_axes = np.real(bead.principal_axes()) - eigenvalues, _ = np.linalg.eig(bead.moment_of_inertia(unwrap=True)) - moment_of_inertia = sorted(eigenvalues, reverse=True) - center = bead.center_of_mass(unwrap=True) - - # Sort out coordinates, forces, and torques for each atom in the bead - weighted_forces[bead_index] = self.get_weighted_forces( - data_container, - bead, - trans_axes, - highest_level, - force_partitioning, - ) - weighted_torques[bead_index] = self.get_weighted_torques( - bead, - rot_axes, - center, - force_partitioning, - moment_of_inertia, - axes_manager, - ) - - # Create covariance submatrices - force_submatrix = [ - [0 for _ in range(number_beads)] for _ in range(number_beads) - ] - torque_submatrix = [ - [0 for _ in range(number_beads)] for _ in range(number_beads) - ] - - for i in range(number_beads): - for j in range(i, number_beads): - f_sub = self.create_submatrix(weighted_forces[i], weighted_forces[j]) - t_sub = 
self.create_submatrix(weighted_torques[i], weighted_torques[j]) - force_submatrix[i][j] = f_sub - force_submatrix[j][i] = f_sub.T - torque_submatrix[i][j] = t_sub - torque_submatrix[j][i] = t_sub.T - - # Convert block matrices to full matrix - force_block = np.block( - [ - [force_submatrix[i][j] for j in range(number_beads)] - for i in range(number_beads) - ] - ) - torque_block = np.block( - [ - [torque_submatrix[i][j] for j in range(number_beads)] - for i in range(number_beads) - ] - ) - - # Enforce consistent shape before accumulation - if force_matrix is None: - force_matrix = np.zeros_like(force_block) - force_matrix = force_block # add first set of forces - elif force_matrix.shape != force_block.shape: - raise ValueError( - f"Inconsistent force matrix shape: existing " - f"{force_matrix.shape}, new {force_block.shape}" - ) - else: - force_matrix = force_block - - if torque_matrix is None: - torque_matrix = np.zeros_like(torque_block) - torque_matrix = torque_block # add first set of torques - elif torque_matrix.shape != torque_block.shape: - raise ValueError( - f"Inconsistent torque matrix shape: existing " - f"{torque_matrix.shape}, new {torque_block.shape}" - ) - else: - torque_matrix = torque_block - - return force_matrix, torque_matrix - - def get_combined_forcetorque_matrices( - self, - data_container, - level, - highest_level, - forcetorque_matrix, - force_partitioning, - customised_axes, - ): - """ - Compute and accumulate combined force/torque covariance matrices for - a given level. - - Parameters: - data_container (MDAnalysis.Universe): Data for a molecule or residue. - level (str): 'polymer', 'residue', or 'united_atom'. - highest_level (bool): Whether this is the top (largest bead size) level. - forcetorque_matrix (np.ndarray or None): Accumulated matrices to add - to. - force_partitioning (float): Factor to adjust force contributions, - default is 0.5. - customised_axes (bool): Whether to use customised axes for rotating - forces. 
- - Returns: - forcetorque_matrix (np.ndarray): Accumulated torque covariance matrix. - """ - - # Make beads - list_of_beads = self.get_beads(data_container, level) - - # number of beads and frames in trajectory - number_beads = len(list_of_beads) - - # initialize force and torque arrays - weighted_forces = [None for _ in range(number_beads)] - weighted_torques = [None for _ in range(number_beads)] - - # Create axes manager for custom axes - axes_manager = AxesManager() - - # Calculate forces/torques for each bead - for bead_index in range(number_beads): - bead = list_of_beads[bead_index] - # Set up axes - # translation and rotation use different axes - # how the axes are defined depends on the level - if level == "residue" and customised_axes: - trans_axes, rot_axes, center, moment_of_inertia = ( - axes_manager.get_residue_axes(data_container, bead_index) - ) - else: - # ensure molecule is whole for PA calcs - make_whole(data_container.atoms) - make_whole(bead) - trans_axes = data_container.atoms.principal_axes() - rot_axes, moment_of_inertia = axes_manager.get_vanilla_axes(bead) - center = bead.center_of_mass(unwrap=True) - - # Sort out coordinates, forces, and torques for each atom in the bead - weighted_forces[bead_index] = self.get_weighted_forces( - data_container, - bead, - trans_axes, - highest_level, - force_partitioning, - ) - weighted_torques[bead_index] = self.get_weighted_torques( - bead, - rot_axes, - center, - force_partitioning, - moment_of_inertia, - axes_manager, - ) - - # Create covariance submatrices - forcetorque_submatrix = [ - [0 for _ in range(number_beads)] for _ in range(number_beads) - ] - - for i in range(number_beads): - for j in range(i, number_beads): - ft_sub = self.create_FTsubmatrix( - np.concatenate((weighted_forces[i], weighted_torques[i])), - np.concatenate((weighted_forces[j], weighted_torques[j])), - ) - forcetorque_submatrix[i][j] = ft_sub - forcetorque_submatrix[j][i] = ft_sub.T - - # Convert block matrices to full matrix - 
forcetorque_block = np.block( - [ - [forcetorque_submatrix[i][j] for j in range(number_beads)] - for i in range(number_beads) - ] - ) - - # Enforce consistent shape before accumulation - if forcetorque_matrix is None: - forcetorque_matrix = np.zeros_like(forcetorque_block) - forcetorque_matrix = forcetorque_block # add first set of torques - elif forcetorque_matrix.shape != forcetorque_block.shape: - raise ValueError( - f"Inconsistent forcetorque matrix shape: existing " - f"{forcetorque_matrix.shape}, new {forcetorque_block.shape}" - ) - else: - forcetorque_matrix = forcetorque_block - - return forcetorque_matrix - - def get_beads(self, data_container, level): - """ - Function to define beads depending on the level in the hierarchy. - - Args: - data_container (MDAnalysis.Universe): the molecule data - level (str): the heirarchy level (polymer, residue, or united atom) - - Returns: - list_of_beads : the relevent beads - """ - - if level == "polymer": - list_of_beads = [] - atom_group = "all" - list_of_beads.append(data_container.select_atoms(atom_group)) - - if level == "residue": - list_of_beads = [] - num_residues = len(data_container.residues) - for residue in range(num_residues): - atom_group = "resindex " + str(residue) - list_of_beads.append(data_container.select_atoms(atom_group)) - - if level == "united_atom": - list_of_beads = [] - heavy_atoms = data_container.select_atoms("prop mass > 1.1") - if len(heavy_atoms) == 0: - # molecule without heavy atoms would be a hydrogen molecule - list_of_beads.append(data_container.select_atoms("all")) - else: - # Select one heavy atom and all light atoms bonded to it - for atom in heavy_atoms: - atom_group = ( - "index " - + str(atom.index) - + " or ((prop mass <= 1.1) and bonded index " - + str(atom.index) - + ")" - ) - list_of_beads.append(data_container.select_atoms(atom_group)) - - logger.debug(f"List of beads: {list_of_beads}") - - return list_of_beads - - def get_weighted_forces( - self, data_container, bead, 
trans_axes, highest_level, force_partitioning - ): - """ - Compute mass-weighted translational forces for a bead. - - The forces acting on all atoms belonging to the bead are first transformed - into the provided translational reference frame and summed. If this bead - corresponds to the highest level of a hierarchical coarse-graining scheme, - the total force is scaled by a force-partitioning factor to avoid double - counting forces from weakly correlated atoms. - - The resulting force vector is then normalized by the square root of the - bead's total mass. - - Parameters - ---------- - data_container : MDAnalysis.Universe - Container holding atomic positions and forces. - bead : object - Molecular subunit whose atoms contribute to the force. - trans_axes : np.ndarray - Transformation matrix defining the translational reference frame. - highest_level : bool - Whether this bead is the highest level in the length-scale hierarchy. - If True, force partitioning is applied. - force_partitioning : float - Scaling factor applied to forces to avoid over-counting correlated - contributions (typically 0.5). - - Returns - ------- - weighted_force : np.ndarray - Mass-weighted translational force acting on the bead. - - Raises - ------ - ValueError - If the bead mass is zero or negative. - """ - forces_trans = np.zeros((3,)) - - for atom in bead.atoms: - forces_local = np.matmul(trans_axes, data_container.atoms[atom.index].force) - forces_trans += forces_local - - if highest_level: - forces_trans = force_partitioning * forces_trans - - mass = bead.total_mass() - - if mass <= 0: - raise ValueError( - f"Invalid mass value: {mass}. Mass must be positive to compute the " - f"square root." 
- ) - - weighted_force = forces_trans / np.sqrt(mass) - - logger.debug(f"Weighted Force: {weighted_force}") - - return weighted_force - - def get_weighted_torques( - self, - bead, - rot_axes, - center, - force_partitioning, - moment_of_inertia, - axes_manager, - ): - """ - Compute moment-of-inertia weighted torques for a bead. - - Atomic coordinates and forces are transformed into the provided rotational - reference frame. Torques are computed as the cross product of position - vectors (relative to the bead center of mass) and forces, with a - force-partitioning factor applied to reduce over-counting of correlated - atomic contributions. - - The total torque vector is then weighted by the square root of the bead's - principal moments of inertia. Weighting is performed component-wise using - the sorted eigenvalues of the moment of inertia tensor. - - To ensure numerical stability: - - Torque components that are effectively zero, zero or negative are skipped. - - Parameters - ---------- - data_container : object - Container holding atomic positions and forces. - bead : object - Molecular subunit whose atoms contribute to the torque. - rot_axes : np.ndarray - Transformation matrix defining the rotational reference frame. - force_partitioning : float - Scaling factor applied to forces to avoid over-counting correlated - contributions (typically 0.5). - moment_of_inertia : np.ndarray - Moment of inertia (3,) - customised_axes: bool - Whether to use customised axes for rotating forces. - - Returns - ------- - weighted_torque : np.ndarray - Moment-of-inertia weighted torque acting on the bead. 
- """ - - # translate and rotate positions and forces - # translated_coords = bead.positions - center - translated_coords = axes_manager.get_vector( - center, bead.positions, bead.dimensions[:3] - ) - rotated_coords = np.tensordot(translated_coords, rot_axes.T, axes=1) - rotated_forces = np.tensordot(bead.forces, rot_axes.T, axes=1) - # scale forces - rotated_forces *= force_partitioning - # get torques here - torques = np.cross(rotated_coords, rotated_forces) - torques = np.sum(torques, axis=0) - - weighted_torque = np.zeros((3,)) - for dimension in range(3): - if np.isclose(torques[dimension], 0): - weighted_torque[dimension] = 0 - continue - - if np.iscomplex(moment_of_inertia[dimension]): - weighted_torque[dimension] = 0 - continue - - if moment_of_inertia[dimension] == 0: - moment_of_inertia[dimension] = 0 - continue - - if moment_of_inertia[dimension] < 0: - moment_of_inertia[dimension] = 0 - continue - - # Compute weighted torque - weighted_torque[dimension] = torques[dimension] / np.sqrt( - moment_of_inertia[dimension] - ) - - logger.debug(f"Weighted Torque: {weighted_torque}") - - return weighted_torque - - def create_submatrix(self, data_i, data_j): - """ - Function for making covariance matrices. - - Args - ----- - data_i : values for bead i - data_j : values for bead j - - Returns - ------ - submatrix : 3x3 matrix for the covariance between i and j - """ - - # Start with 3 by 3 matrix of zeros - submatrix = np.zeros((3, 3)) - - # For each frame calculate the outer product (cross product) of the data from - # the two beads and add the result to the submatrix - outer_product_matrix = np.outer(data_i, data_j) - submatrix = np.add(submatrix, outer_product_matrix) - - logger.debug(f"Submatrix: {submatrix}") - - return submatrix - - def create_FTsubmatrix(self, data_i, data_j): - """ - Function for making covariance matrices. 
- - Args - ----- - data_i : values for bead i - data_j : values for bead j - - Returns - ------ - submatrix : 6x6 matrix for the covariance between i and j - """ - - # Start with 6 by 6 matrix of zeros - submatrix = np.zeros((6, 6)) - - # For each frame calculate the outer product (cross product) of the data from - # the two beads and add the result to the submatrix - outer_product_matrix = np.outer(data_i, data_j) - submatrix = np.add(submatrix, outer_product_matrix) - - return submatrix - - def build_covariance_matrices( - self, - entropy_manager, - reduced_atom, - levels, - groups, - start, - end, - step, - number_frames, - force_partitioning, - combined_forcetorque, - customised_axes, - ): - """ - Construct average force and torque covariance matrices for all molecules and - entropy levels. - - Parameters - ---------- - entropy_manager : EntropyManager - Instance of the EntropyManager. - reduced_atom : Universe - The reduced atom selection. - levels : dict - Dictionary mapping molecule IDs to lists of entropy levels. - groups : dict - Dictionary mapping group IDs to lists of molecule IDs. - start : int - Start frame index. - end : int - End frame index. - step : int - Step size for frame iteration. - number_frames : int - Total number of frames to process. - force_partitioning : float - Factor to adjust force contributions, default is 0.5. - combined_forcetorque : bool - Whether to use combined forcetorque covariance matrix. - - Returns - ------- - tuple - force_avg : dict - Averaged force covariance matrices by entropy level. - torque_avg : dict - Averaged torque covariance matrices by entropy level. 
- """ - number_groups = len(groups) - - force_avg = { - "ua": {}, - "res": [None] * number_groups, - "poly": [None] * number_groups, - } - torque_avg = { - "ua": {}, - "res": [None] * number_groups, - "poly": [None] * number_groups, - } - - forcetorque_avg = { - "ua": {}, - "res": [None] * number_groups, - "poly": [None] * number_groups, - } - - total_steps = len(reduced_atom.trajectory[start:end:step]) - total_items = ( - sum(len(levels[mol_id]) for mols in groups.values() for mol_id in mols) - * total_steps - ) - - frame_counts = { - "ua": {}, - "res": np.zeros(number_groups, dtype=int), - "poly": np.zeros(number_groups, dtype=int), - } - - with Progress( - SpinnerColumn(), - TextColumn("[bold blue]{task.fields[title]}", justify="right"), - BarColumn(), - TextColumn("[progress.percentage]{task.percentage:>3.1f}%"), - TimeElapsedColumn(), - ) as progress: - - task = progress.add_task( - "[green]Processing...", - total=total_items, - title="Starting...", - ) - - indices = list(range(number_frames)) - for time_index, _ in zip(indices, reduced_atom.trajectory[start:end:step]): - for group_id, molecules in groups.items(): - for mol_id in molecules: - mol = self._universe_operations.get_molecule_container( - reduced_atom, mol_id - ) - for level in levels[mol_id]: - resname = mol.atoms[0].resname - resid = mol.atoms[0].resid - segid = mol.atoms[0].segid - - mol_label = f"{resname}_{resid} (segid {segid})" - - progress.update( - task, - title=f"Building covariance matrices | " - f"Timestep {time_index} | " - f"Molecule: {mol_label} | " - f"Level: {level}", - ) - - self.update_force_torque_matrices( - entropy_manager, - mol, - group_id, - level, - levels[mol_id], - time_index, - number_frames, - force_avg, - torque_avg, - forcetorque_avg, - frame_counts, - force_partitioning, - combined_forcetorque, - customised_axes, - ) - - progress.advance(task) - - return force_avg, torque_avg, forcetorque_avg, frame_counts - - def update_force_torque_matrices( - self, - 
entropy_manager, - mol, - group_id, - level, - level_list, - time_index, - num_frames, - force_avg, - torque_avg, - forcetorque_avg, - frame_counts, - force_partitioning, - combined_forcetorque, - customised_axes, - ): - """ - Update the running averages of force and torque covariance matrices - for a given molecule and entropy level. - - This function computes the force and torque covariance matrices for the - current frame and updates the existing averages in-place using the incremental - mean formula: - - new_avg = old_avg + (value - old_avg) / n - - where n is the number of frames processed so far for that molecule/level - combination. This ensures that the averages are maintained without storing - all previous frame data. - - Parameters - ---------- - entropy_manager : EntropyManager - Instance of the EntropyManager. - mol : AtomGroup - The molecule to process. - group_id : int - Index of the group to which the molecule belongs. - level : str - Current entropy level ("united_atom", "residue", or "polymer"). - level_list : list - List of entropy levels for the molecule. - time_index : int - Index of the current frame relative to the start of the trajectory slice. - num_frames : int - Total number of frames to process. - force_avg : dict - Dictionary holding the running average force matrices, keyed by entropy - level. - torque_avg : dict - Dictionary holding the running average torque matrices, keyed by entropy - level. - frame_counts : dict - Dictionary holding the count of frames processed for each molecule/level - combination. - force_partitioning : float - Factor to adjust force contributions, default is 0.5. - combined_forcetorque : bool - Whether to use combined forcetorque covariance matrix. - customised_axes: bool - Whether to use bonded axes for UA rovib calculations - Returns - ------- - None - Updates are performed in-place on `force_avg`, `torque_avg`, and - `frame_counts`. 
- """ - highest = level == level_list[-1] - - # United atom level calculations are done separately for each residue - # This allows information per residue to be output and keeps the - # matrices from becoming too large - if level == "united_atom": - for res_id, residue in enumerate(mol.residues): - key = (group_id, res_id) - res = self._universe_operations.new_U_select_atom( - mol, f"index {residue.atoms.indices[0]}:{residue.atoms.indices[-1]}" - ) - - # This is to get MDAnalysis to get the information from the - # correct frame of the trajectory - res.trajectory[time_index] - - # Build the matrices, adding data from each timestep - # Being careful for the first timestep when data has not yet - # been added to the matrices - f_mat, t_mat = self.get_matrices( - res, - level, - highest, - None if key not in force_avg["ua"] else force_avg["ua"][key], - None if key not in torque_avg["ua"] else torque_avg["ua"][key], - force_partitioning, - customised_axes, - ) - - if key not in force_avg["ua"]: - force_avg["ua"][key] = f_mat.copy() - torque_avg["ua"][key] = t_mat.copy() - frame_counts["ua"][key] = 1 - else: - frame_counts["ua"][key] += 1 - n = frame_counts["ua"][key] - force_avg["ua"][key] += (f_mat - force_avg["ua"][key]) / n - torque_avg["ua"][key] += (t_mat - torque_avg["ua"][key]) / n - - elif level in ["residue", "polymer"]: - # This is to get MDAnalysis to get the information from the - # correct frame of the trajectory - mol.trajectory[time_index] - - key = "res" if level == "residue" else "poly" - - # Build the matrices, adding data from each timestep - # Being careful for the first timestep when data has not yet - # been added to the matrices - if highest and combined_forcetorque: - # use combined forcetorque covariance matrix for the highest level only - ft_mat = self.get_combined_forcetorque_matrices( - mol, - level, - highest, - ( - None - if forcetorque_avg[key][group_id] is None - else forcetorque_avg[key][group_id] - ), - force_partitioning, - 
customised_axes, - ) - - if forcetorque_avg[key][group_id] is None: - forcetorque_avg[key][group_id] = ft_mat.copy() - frame_counts[key][group_id] = 1 - else: - frame_counts[key][group_id] += 1 - n = frame_counts[key][group_id] - forcetorque_avg[key][group_id] += ( - ft_mat - forcetorque_avg[key][group_id] - ) / n - else: - f_mat, t_mat = self.get_matrices( - mol, - level, - highest, - ( - None - if force_avg[key][group_id] is None - else force_avg[key][group_id] - ), - ( - None - if torque_avg[key][group_id] is None - else torque_avg[key][group_id] - ), - force_partitioning, - customised_axes, - ) - - if force_avg[key][group_id] is None: - force_avg[key][group_id] = f_mat.copy() - torque_avg[key][group_id] = t_mat.copy() - frame_counts[key][group_id] = 1 - else: - frame_counts[key][group_id] += 1 - n = frame_counts[key][group_id] - force_avg[key][group_id] += (f_mat - force_avg[key][group_id]) / n - torque_avg[key][group_id] += (t_mat - torque_avg[key][group_id]) / n - - return frame_counts - - def filter_zero_rows_columns(self, arg_matrix): - """ - function for removing rows and columns that contain only zeros from a matrix - - Args: - arg_matrix : matrix - - Returns: - arg_matrix : the reduced size matrix - """ - - # record the initial size - init_shape = np.shape(arg_matrix) - - zero_indices = list( - filter( - lambda row: np.all(np.isclose(arg_matrix[row, :], 0.0)), - np.arange(np.shape(arg_matrix)[0]), - ) - ) - all_indices = np.ones((np.shape(arg_matrix)[0]), dtype=bool) - all_indices[zero_indices] = False - arg_matrix = arg_matrix[all_indices, :] - - all_indices = np.ones((np.shape(arg_matrix)[1]), dtype=bool) - zero_indices = list( - filter( - lambda col: np.all(np.isclose(arg_matrix[:, col], 0.0)), - np.arange(np.shape(arg_matrix)[1]), - ) - ) - all_indices[zero_indices] = False - arg_matrix = arg_matrix[:, all_indices] - - # get the final shape - final_shape = np.shape(arg_matrix) - - if init_shape != final_shape: - logger.debug( - "A shape change has 
occurred ({},{}) -> ({}, {})".format( - *init_shape, *final_shape - ) - ) - - logger.debug(f"arg_matrix: {arg_matrix}") - - return arg_matrix diff --git a/CodeEntropy/levels/__init__.py b/CodeEntropy/levels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CodeEntropy/levels/axes.py b/CodeEntropy/levels/axes.py new file mode 100644 index 00000000..07f1918f --- /dev/null +++ b/CodeEntropy/levels/axes.py @@ -0,0 +1,633 @@ +"""Axes utilities for entropy calculations. + +This module contains the :class:`AxesCalculator`, a geometry-focused helper used by +the entropy pipeline to compute translational and rotational axes, centres, and +moments of inertia at different hierarchy levels (residue / united-atom). +""" + +from __future__ import annotations + +import logging +from typing import Sequence, Tuple + +import numpy as np +from MDAnalysis.lib.mdamath import make_whole + +logger = logging.getLogger(__name__) + + +class AxesCalculator: + """Compute translation/rotation axes and inertia utilities used by entropy. + + Manages the structural and dynamic levels involved in entropy calculations. + This includes selecting relevant levels, computing axes for translation and + rotation, and handling bead-based representations of molecular systems. + + Provides utility methods to: + - extract averaged positions, + - convert coordinates to spherical systems (future/legacy scope), + - compute axes used to rotate forces around, + - compute custom moments of inertia, + - manipulate vectors under periodic boundary conditions (PBC), + - construct custom moment-of-inertia tensors and principal axes. + + Notes: + This class deliberately does **not**: + - compute weighted forces/torques (that belongs in ForceTorqueCalculator), + - build covariances, + - compute entropies. + """ + + def __init__(self) -> None: + """Initialize the AxesCalculator. + + The original implementation stored a few placeholders for level-related + data (axes, bead counts, etc.). 
In the current design, AxesCalculator is a + stateless helper, but we keep the attributes for compatibility and + debugging/extension. + + Attributes: + data_container: Optional container used by legacy workflows. + _levels: Optional levels list (legacy/placeholder). + _trans_axes: Optional cached translation axes (legacy/placeholder). + _rot_axes: Optional cached rotation axes (legacy/placeholder). + _number_of_beads: Optional bead count (legacy/placeholder). + """ + self.data_container = None + self._levels = None + self._trans_axes = None + self._rot_axes = None + self._number_of_beads = None + + def get_residue_axes(self, data_container, index: int, residue=None): + """Compute residue-level translational and rotational axes. + + The translational and rotational axes at the residue level. + + - Identify the residue (either provided or selected by `resindex index`). + - Determine whether the residue is bonded to neighbouring residues + (previous/next in sequence) using MDAnalysis bonded selections. + - If there are *no* bonds to other residues: + * Use a custom principal axes, from a moment-of-inertia (MOI) tensor + that uses positions of heavy atoms only, but including masses of + heavy atom + bonded hydrogens. + * Set translational axes equal to rotational axes (as per the original + code convention). + - If bonded to other residues: + * Use default axes and MOI (MDAnalysis principal axes / inertia). + + Args: + data_container (MDAnalysis.Universe or AtomGroup): + Molecule and trajectory data (the fragment/molecule container). + index (int): + Residue index (resindex) within `data_container`. + residue (MDAnalysis.AtomGroup, optional): + If provided, this residue selection will be used rather than + selecting again. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + - trans_axes: Translational axes array of shape (3, 3). + - rot_axes: Rotational axes array of shape (3, 3). + - center: Center of mass array of shape (3,). 
+ - moment_of_inertia: Principal moments array of shape (3,). + + Raises: + ValueError: + If the residue selection is empty. + """ + # TODO refine selection so that it will work for branched polymers + index_prev = index - 1 + index_next = index + 1 + + if residue is None: + residue = data_container.select_atoms(f"resindex {index}") + if len(residue) == 0: + raise ValueError(f"Empty residue selection for resindex={index}") + + center = residue.atoms.center_of_mass(unwrap=True) + atom_set = data_container.select_atoms( + f"(resindex {index_prev} or resindex {index_next}) and bonded resid {index}" + ) + + if len(atom_set) == 0: + # No bonds to other residues. + # Use a custom principal axes, from a MOI tensor that uses positions of + # heavy atoms only, but including masses of heavy atom + bonded H. + uas = residue.select_atoms("mass 2 to 999") + ua_masses = self.get_UA_masses(residue) + moi_tensor = self.get_moment_of_inertia_tensor( + center_of_mass=center, + positions=uas.positions, + masses=ua_masses, + dimensions=data_container.dimensions[:3], + ) + rot_axes, moment_of_inertia = self.get_custom_principal_axes(moi_tensor) + trans_axes = rot_axes # per original convention + else: + # If bonded to other residues, use default axes and MOI. + make_whole(data_container.atoms) + trans_axes = data_container.atoms.principal_axes() + rot_axes, moment_of_inertia = self.get_vanilla_axes(residue) + center = residue.center_of_mass(unwrap=True) + + return trans_axes, rot_axes, center, moment_of_inertia + + def get_UA_axes(self, data_container, index: int): + """Compute united-atom-level translational and rotational axes. + + The translational and rotational axes at the united-atom level. + + This preserves the original behaviour and its rationale: + + - Translational axes: + Use the same custom principal-axes approach as residue level: + compute a custom MOI tensor using heavy-atom coordinates but UA masses + (heavy + bonded H masses), then compute the principal axes from it. 
+ + - Rotational axes: + Identify heavy atoms in the residue/molecule of interest and choose + the `index`-th heavy atom (where index corresponds to the bead index). + Use bonded topology around that heavy atom to determine UA rotational + axes (see :meth:`get_bonded_axes`). + + Args: + data_container (MDAnalysis.Universe or AtomGroup): + Molecule and trajectory data. + index (int): + Bead index (ordinal among heavy atoms). + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + - trans_axes: Translational axes (3, 3). + - rot_axes: Rotational axes (3, 3). + - center: Rotation centre (3,) (heavy atom position). + - moment_of_inertia: (3,) moments for the UA around rot_axes. + + Raises: + IndexError: + If `index` does not correspond to an existing heavy atom. + ValueError: + If bonded-axis construction fails. + """ + + index = int(index) # bead index + + # use the same customPI trans axes as the residue level + heavy_atoms = data_container.select_atoms("prop mass > 1.1") + if len(heavy_atoms) > 1: + UA_masses = self.get_UA_masses(data_container.atoms) + center = data_container.atoms.center_of_mass(unwrap=True) + moment_of_inertia_tensor = self.get_moment_of_inertia_tensor( + center, heavy_atoms.positions, UA_masses, data_container.dimensions[:3] + ) + trans_axes, _moment_of_inertia = self.get_custom_principal_axes( + moment_of_inertia_tensor + ) + else: + # use standard PA for UA not bonded to anything else + make_whole(data_container.atoms) + trans_axes = data_container.atoms.principal_axes() + + # look for heavy atoms in residue of interest + heavy_atom_indices = [] + for atom in heavy_atoms: + heavy_atom_indices.append(atom.index) + # we find the nth heavy atom + # where n is the bead index + heavy_atom_index = heavy_atom_indices[index] + heavy_atom = data_container.select_atoms(f"index {heavy_atom_index}") + + center = heavy_atom.positions[0] + rot_axes, moment_of_inertia = self.get_bonded_axes( + system=data_container, + atom=heavy_atom[0], + 
dimensions=data_container.dimensions[:3], + ) + if rot_axes is None or moment_of_inertia is None: + raise ValueError("Unable to compute bonded axes for UA bead.") + + logger.debug("Translational Axes: %s", trans_axes) + logger.debug("Rotational Axes: %s", rot_axes) + logger.debug("Center: %s", center) + logger.debug("Moment of Inertia: %s", moment_of_inertia) + + return trans_axes, rot_axes, center, moment_of_inertia + + def get_bonded_axes(self, system, atom, dimensions: np.ndarray): + r"""Compute UA rotational axes from bonded topology around a heavy atom. + + For a given heavy atom, use its bonded atoms to get the axes for rotating + forces around. Few cases for choosing united atom axes, which are dependent + on the bonds to the atom: + + :: + + X -- H = bonded to zero or more light atom/s (case1) + + X -- R = bonded to one heavy atom (case2) + + R -- X -- H = bonded to one heavy and at least one light atom (case3) + + R1 -- X -- R2 = bonded to two heavy atoms (case4) + + R1 -- X -- R2 = bonded to more than two heavy atoms (case5) + | + R3 + + Note that axis2 is calculated by taking the cross product between axis1 and + the vector chosen for each case, dependent on bonding: + + - case1: if all the bonded atoms are hydrogens, use the principal axes. + + - case2: use XR vector as axis1, arbitrary axis2. + + - case3: use XR vector as axis1, vector XH to calculate axis2 + + - case4: use vector XR1 as axis1, and XR2 to calculate axis2 + + - case5: get the sum of all XR normalised vectors as axis1, then use vector + R1R2 to calculate axis2 + + axis3 is always the cross product of axis1 and axis2. + + Args: + system: + MDAnalysis selection containing all atoms in current frame. + atom: + MDAnalysis Atom for the heavy atom. + dimensions: + Simulation box dimensions (3,). + + Returns: + Tuple[np.ndarray | None, np.ndarray | None]: + - custom_axes: Custom axes (3, 3), or None if atom is not heavy. + - custom_moment_of_inertia: (3,) moment of inertia around axes. 
+ + Notes: + If custom_moment_of_inertia is not produced by the chosen method, it is + computed using :meth:`get_custom_moment_of_inertia` with the heavy atom + as COM (matching original behaviour). + """ + # check atom is a heavy atom + if not atom.mass > 1.1: + return None, None + + custom_moment_of_inertia = None + custom_axes = None + + heavy_bonded, light_bonded = self.find_bonded_atoms(atom.index, system) + ua = atom + light_bonded + ua_all = atom + heavy_bonded + light_bonded + + # case1 + if len(heavy_bonded) == 0: + custom_axes, custom_moment_of_inertia = self.get_vanilla_axes(ua_all) + + # case2 + if len(heavy_bonded) == 1 and len(light_bonded) == 0: + custom_axes = self.get_custom_axes( + a=atom.position, + b_list=[heavy_bonded[0].position], + c=np.zeros(3), + dimensions=dimensions, + ) + + # case3 + if len(heavy_bonded) == 1 and len(light_bonded) >= 1: + custom_axes = self.get_custom_axes( + a=atom.position, + b_list=[heavy_bonded[0].position], + c=light_bonded[0].position, + dimensions=dimensions, + ) + + # case4 (not used in original 2019 code; case5 used instead) + # case5 + if len(heavy_bonded) >= 2: + custom_axes = self.get_custom_axes( + a=atom.position, + b_list=heavy_bonded.positions, + c=heavy_bonded[1].position, + dimensions=dimensions, + ) + + if custom_axes is None: + return None, None + + if custom_moment_of_inertia is None: + custom_moment_of_inertia = self.get_custom_moment_of_inertia( + UA=ua, + custom_rotation_axes=custom_axes, + center_of_mass=atom.position, + dimensions=dimensions, + ) + + # flip axes to face correct way wrt COM + custom_axes = self.get_flipped_axes(ua, custom_axes, atom.position, dimensions) + + return custom_axes, custom_moment_of_inertia + + def find_bonded_atoms(self, atom_idx: int, system): + """Find bonded heavy and hydrogen atoms for a given atom. + + Args: + atom_idx: Atom index to find bonded atoms for. + system: MDAnalysis selection containing all atoms in current frame. 
+ + Returns: + Tuple[AtomGroup, AtomGroup]: + - bonded_heavy_atoms: bonded heavy atoms (mass 2 to 999) + - bonded_H_atoms: bonded hydrogen atoms (mass 1 to 1.1) + """ + bonded_atoms = system.select_atoms(f"bonded index {atom_idx}") + bonded_heavy_atoms = bonded_atoms.select_atoms("mass 2 to 999") + bonded_H_atoms = bonded_atoms.select_atoms("mass 1 to 1.1") + return bonded_heavy_atoms, bonded_H_atoms + + def get_vanilla_axes(self, molecule): + """Get principal axes and sorted principal moments (vanilla method). + + Compute the principal axes and moments of inertia for a molecule using + MDAnalysis built-in functionality. + + The original description is preserved: + - The molecule is made whole to ensure correct handling of PBC. + - The moments are obtained by diagonalising the moment of inertia tensor. + - Eigenvalues are returned sorted from largest to smallest magnitude. + + Args: + molecule (MDAnalysis.core.groups.AtomGroup): + AtomGroup representing the molecule/bead. + + Returns: + Tuple[np.ndarray, np.ndarray]: + - principal_axes: (3, 3) axes. + - moment_of_inertia: (3,) moments sorted descending by |value|. + """ + moment_of_inertia_tensor = molecule.moment_of_inertia(unwrap=True) + make_whole(molecule.atoms) + principal_axes = molecule.principal_axes() + + eigenvalues, _ = np.linalg.eig(moment_of_inertia_tensor) + order = np.argsort(np.abs(eigenvalues))[::-1] + moment_of_inertia = eigenvalues[order] + + return principal_axes, moment_of_inertia + + def get_custom_axes( + self, + a: np.ndarray, + b_list: Sequence[np.ndarray], + c: np.ndarray, + dimensions: np.ndarray, + ) -> np.ndarray: + r"""Compute custom rotation axes from bonded vectors (PBC-aware). + + For atoms a, b_list and c, calculate the axis to rotate forces around: + + - axis1: use the normalised vector ab as axis1. If there is more than one + bonded heavy atom (HA), average over all the normalised vectors + calculated from b_list and use this as axis1). 
b_list contains all the + bonded heavy atom coordinates. + + - axis2: use the cross product of normalised vector ac and axis1 as axis2. + If there are more than two bonded heavy atoms, then use normalised vector + b[0]c to cross product with axis1, this gives the axis perpendicular + (represented by |_ symbol below) to axis1. + + - axis3: the cross product of axis1 and axis2, which is perpendicular to + axis1 and axis2. + + Args: + a: Central united-atom coordinates (3,). + b_list: Positions of heavy bonded atoms. + c: Coordinates of a second heavy atom or a hydrogen atom. + dimensions: Simulation box dimensions (3,). + + :: + + a 1 = norm_ab + / \ 2 = |_ norm_ab and norm_ac (use bc if more than 2 HAs) + / \ 3 = |_ 1 and 2 + b c + + Returns: + np.ndarray: (3, 3) array of the axes used to rotate forces. + + Raises: + ValueError: If axes cannot be normalized due to degeneracy. + """ + unscaled_axis1 = np.zeros(3, dtype=float) + for b in b_list: + ab_vector = self.get_vector(a, b, dimensions) + unscaled_axis1 += ab_vector + + if np.allclose(unscaled_axis1, 0.0): + raise ValueError("Degenerate axis1: summed bonded vectors are zero.") + + if len(b_list) >= 2: + ac_vector = self.get_vector(c, np.asarray(b_list)[0], dimensions) + else: + ac_vector = self.get_vector(c, a, dimensions) + + unscaled_axis2 = np.cross(ac_vector, unscaled_axis1) + unscaled_axis3 = np.cross(unscaled_axis2, unscaled_axis1) + + unscaled_custom_axes = np.array( + (unscaled_axis1, unscaled_axis2, unscaled_axis3), dtype=float + ) + mod = np.sqrt(np.sum(unscaled_custom_axes**2, axis=1)) + if np.any(np.isclose(mod, 0.0)): + raise ValueError("Degenerate custom axes: cannot normalize (zero norm).") + + scaled_custom_axes = unscaled_custom_axes / mod[:, np.newaxis] + return scaled_custom_axes + + def get_custom_moment_of_inertia( + self, + UA, + custom_rotation_axes: np.ndarray, + center_of_mass: np.ndarray, + dimensions: np.ndarray, + ) -> np.ndarray: + """Compute moment of inertia around custom axes for 
a UA. + + Get the moment of inertia (specifically used for the united atom level) + from a set of rotation axes and a given center of mass (COM is usually the + heavy atom position in a UA). + + Original behaviour preserved: + - Uses PBC-aware translated coordinates. + - Sums contributions from each atom: |axis x r|^2 * mass. + - Removes the lowest MOI degree of freedom if the UA only has a single + bonded H (i.e. UA has 2 atoms total). + + Args: + UA: MDAnalysis AtomGroup for the UA (heavy + bonded H atoms). + custom_rotation_axes: (3, 3) array of rotation axes. + center_of_mass: (3,) COM for the UA (typically HA position). + dimensions: (3,) simulation box dimensions. + + Returns: + np.ndarray: (3,) moment of inertia array. + """ + translated_coords = self.get_vector(center_of_mass, UA.positions, dimensions) + custom_moment_of_inertia = np.zeros(3, dtype=float) + + for coord, mass in zip(translated_coords, UA.masses, strict=True): + axis_component = np.sum( + np.cross(custom_rotation_axes, coord) ** 2 * mass, axis=1 + ) + custom_moment_of_inertia += axis_component + + if len(UA) == 2: + order = custom_moment_of_inertia.argsort()[::-1] # descending order + custom_moment_of_inertia[order[-1]] = 0.0 + + return custom_moment_of_inertia + + def get_flipped_axes( + self, + UA, + custom_axes: np.ndarray, + center_of_mass: np.ndarray, + dimensions: np.ndarray, + ): + """Flip custom axes to a consistent direction with respect to the UA. + + For a given set of custom axes, ensure the axes are pointing in the + correct direction with respect to the heavy atom position and the chosen + center of mass. + + Args: + UA: MDAnalysis AtomGroup for the UA. + custom_axes: (3, 3) array of rotation axes. + center_of_mass: (3,) COM reference (usually HA position). + dimensions: (3,) simulation box dimensions. + + Returns: + np.ndarray: (3, 3) array of flipped/normalized axes. 
+ """ + rr_axis = self.get_vector(UA[0].position, center_of_mass, dimensions) + + axis_norm = np.sqrt(np.sum(custom_axes**2, axis=1)) + custom_axes_flipped = custom_axes / axis_norm[:, np.newaxis] + + for i in range(3): + dot_prod = float(np.dot(custom_axes_flipped[i], rr_axis)) + if dot_prod < 0.0: + custom_axes_flipped[i] *= -1.0 + + return custom_axes_flipped + + def get_vector(self, a: np.ndarray, b: np.ndarray, dimensions: np.ndarray): + """Compute PBC-wrapped displacement vector(s). + + For vector of two coordinates over periodic boundary conditions (PBCs). + + Args: + a: (3,) or (N, 3) array of coordinates. + b: (3,) or (N, 3) array of coordinates. + dimensions: (3,) simulation box dimensions. + + Returns: + np.ndarray: Wrapped displacement vector(s) with broadcasted shape. + """ + delta = b - a + delta -= dimensions * np.round(delta / dimensions) + return delta + + def get_moment_of_inertia_tensor( + self, + center_of_mass: np.ndarray, + positions: np.ndarray, + masses: Sequence[float], + dimensions: np.ndarray, + ) -> np.ndarray: + """Compute a custom moment of inertia tensor. + + Calculate a custom moment of inertia tensor. + E.g., for cases where the mass list will contain masses of UAs rather than + individual atoms and the positions will be those for the UAs only + (excluding the H atoms coordinates). + + Args: + center_of_mass: (3,) chosen centre for the tensor. + positions: (N, 3) point positions. + masses: (N,) point masses corresponding to positions. + dimensions: (3,) simulation box dimensions. + + Returns: + np.ndarray: (3, 3) moment of inertia tensor. 
+ """ + r = self.get_vector(center_of_mass, positions, dimensions) + r2 = np.sum(r**2, axis=1) + + masses_arr = np.asarray(list(masses), dtype=float) + moment_of_inertia_tensor = np.eye(3) * np.sum(masses_arr * r2) + moment_of_inertia_tensor -= np.einsum("i,ij,ik->jk", masses_arr, r, r) + + return moment_of_inertia_tensor + + def get_custom_principal_axes( + self, moment_of_inertia_tensor: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + """Compute principal axes and moments from a custom MOI tensor. + + Principal axes and centre of axes from the ordered eigenvalues and + eigenvectors of a moment of inertia tensor. This function allows for a + custom moment of inertia tensor to be used, which isn't possible with the + built-in MDAnalysis principal_axes() function. + + Original behaviour preserved: + - Eigenvalues are sorted by descending absolute magnitude. + - Eigenvectors are transposed so axes are returned as rows. + - Z axis is flipped to enforce the same handedness convention as the + original implementation. + + Args: + moment_of_inertia_tensor: (3, 3) custom inertia tensor. + + Returns: + Tuple[np.ndarray, np.ndarray]: + - principal_axes: (3, 3) principal axes (rows). + - moment_of_inertia: (3,) principal moments. + """ + eigenvalues, eigenvectors = np.linalg.eig(moment_of_inertia_tensor) + order = np.abs(eigenvalues).argsort()[::-1] # descending order + transposed = np.transpose(eigenvectors) # columns -> rows + moment_of_inertia = eigenvalues[order] + principal_axes = transposed[order] + + # point z axis in correct direction, as per original code + cross_xy = np.cross(principal_axes[0], principal_axes[1]) + dot_z = float(np.dot(cross_xy, principal_axes[2])) + if dot_z < 0: + principal_axes[2] *= -1 + + return principal_axes, moment_of_inertia + + def get_UA_masses(self, molecule) -> list[float]: + """Return united-atom (UA) masses for a molecule. 
+ + For a given molecule, return a list of masses of UAs (combination of the + heavy atoms + bonded hydrogen atoms). This list is used to get the moment + of inertia tensor for molecules larger than one UA. + + Args: + molecule: MDAnalysis AtomGroup representing the molecule. + + Returns: + list[float]: UA masses for each heavy atom. + """ + ua_masses: list[float] = [] + for atom in molecule: + if atom.mass > 1.1: + ua_mass = float(atom.mass) + bonded_atoms = molecule.select_atoms(f"bonded index {atom.index}") + bonded_h_atoms = bonded_atoms.select_atoms("mass 1 to 1.1") + for h in bonded_h_atoms: + ua_mass += float(h.mass) + ua_masses.append(ua_mass) + return ua_masses diff --git a/CodeEntropy/levels/dihedrals.py b/CodeEntropy/levels/dihedrals.py new file mode 100644 index 00000000..ff5bf853 --- /dev/null +++ b/CodeEntropy/levels/dihedrals.py @@ -0,0 +1,492 @@ +"""Dihedral state assignment for conformational entropy. + +This module converts dihedral angle time series into discrete conformational +state labels. The resulting state labels are used downstream to compute +conformational entropy. +""" + +import logging +from typing import Dict, List, Tuple + +import numpy as np +from MDAnalysis.analysis.dihedrals import Dihedral + +logger = logging.getLogger(__name__) + +UAKey = Tuple[int, int] + + +class ConformationStateBuilder: + """Build conformational state labels from dihedral angles.""" + + def __init__(self, universe_operations=None): + """Initializes the analysis helper. + + Args: + universe_operations: Object providing helper methods: + - extract_fragment(data_container, molecule_id) + - select_atoms(atomgroup, selection_string) + """ + self._universe_operations = universe_operations + + def build_conformational_states( + self, + data_container, + levels, + groups, + start: int, + end: int, + step: int, + bin_width: float, + progress: object | None = None, + ): + """Build conformational state labels from trajectory dihedrals. 
+ + This method constructs discrete conformational state descriptors used in + configurational entropy calculations. It supports united-atom (UA) and + residue-level state generation depending on which hierarchy levels are + enabled per molecule. + + Progress reporting is optional and UI-agnostic: if a progress sink is + provided, the method will create a single task and advance it once per + molecule group. + + Args: + data_container: MDAnalysis Universe (or compatible container) used to + extract fragments and compute dihedral time series. + levels: Mapping of molecule_id -> iterable of enabled level names + (e.g., ["united_atom", "residue"]). + groups: Mapping of group_id -> list of molecule_ids. + start: Inclusive start frame index. + end: Exclusive end frame index. + step: Frame stride. + bin_width: Histogram bin width in degrees used when identifying peak + dihedral populations. + progress: Optional progress sink (e.g., from ResultsReporter.progress()). + Must expose add_task(), update(), and advance(). + + Returns: + Tuple of: + states_ua: Dict mapping (group_id, local_residue_id) -> list of state + labels (strings) across the analyzed trajectory. + states_res: List-like structure indexed by group_id (or equivalent) + containing residue-level state labels (strings) across the + analyzed trajectory. + + Notes: + - This function advances progress once per group_id. + - Frame slicing arguments (start/end/step) are forwarded to downstream + helpers as implemented in this module. 
+ """ + number_groups = len(groups) + states_ua: Dict[UAKey, List[str]] = {} + states_res: List[List[str]] = [None] * number_groups + + task = None + if progress is not None: + total = max(1, len(groups)) + task = progress.add_task( + "[green]Conformational states", + total=total, + title="Initializing", + ) + + if not groups: + if task is not None: + progress.update(task, title="No groups") + progress.advance(task) + return states_ua, states_res + + for group_id in groups.keys(): + molecules = groups[group_id] + if not molecules: + if task is not None: + progress.update(task, title=f"Group {group_id} (empty)") + progress.advance(task) + continue + + if task is not None: + progress.update(task, title=f"Group {group_id}") + + mol = self._universe_operations.extract_fragment( + data_container, molecules[0] + ) + + dihedrals_ua, dihedrals_res = self._collect_dihedrals_for_group( + mol=mol, + level_list=levels[molecules[0]], + ) + + peaks_ua, peaks_res = self._collect_peaks_for_group( + data_container=data_container, + molecules=molecules, + dihedrals_ua=dihedrals_ua, + dihedrals_res=dihedrals_res, + bin_width=bin_width, + start=start, + end=end, + step=step, + level_list=levels[molecules[0]], + ) + + self._assign_states_for_group( + data_container=data_container, + group_id=group_id, + molecules=molecules, + dihedrals_ua=dihedrals_ua, + peaks_ua=peaks_ua, + dihedrals_res=dihedrals_res, + peaks_res=peaks_res, + start=start, + end=end, + step=step, + level_list=levels[molecules[0]], + states_ua=states_ua, + states_res=states_res, + ) + + if task is not None: + progress.advance(task) + + return states_ua, states_res + + def _collect_dihedrals_for_group(self, mol, level_list): + """Collect UA and residue dihedral AtomGroups for a group. + + Args: + mol: Representative molecule AtomGroup. + level_list: List of enabled hierarchy levels. + + Returns: + Tuple: + dihedrals_ua: List of per-residue dihedral AtomGroups. + dihedrals_res: List of residue-level dihedral AtomGroups. 
+ """ + num_residues = len(mol.residues) + dihedrals_ua: List[List] = [[] for _ in range(num_residues)] + dihedrals_res: List = [] + + for level in level_list: + if level == "united_atom": + for res_id in range(num_residues): + heavy_res = self._select_heavy_residue(mol, res_id) + dihedrals_ua[res_id] = self._get_dihedrals(heavy_res, level) + + elif level == "residue": + dihedrals_res = self._get_dihedrals(mol, level) + + return dihedrals_ua, dihedrals_res + + def _select_heavy_residue(self, mol, res_id: int): + """Select heavy atoms in a residue by residue index. + + Args: + mol: Representative molecule AtomGroup. + res_id: Residue index. + + Returns: + AtomGroup containing heavy atoms in the residue selection. + """ + selection1 = mol.residues[res_id].atoms.indices[0] + selection2 = mol.residues[res_id].atoms.indices[-1] + + res_container = self._universe_operations.select_atoms( + mol, f"index {selection1}:{selection2}" + ) + return self._universe_operations.select_atoms(res_container, "prop mass > 1.1") + + def _get_dihedrals(self, data_container, level: str): + """Return dihedral AtomGroups for a container at a given level. + + Args: + data_container: MDAnalysis container (AtomGroup/Universe). + level: Either "united_atom" or "residue". + + Returns: + List of AtomGroups (each representing a dihedral definition). 
+ """ + atom_groups = [] + + if level == "united_atom": + dihedrals = data_container.dihedrals + for d in dihedrals: + atom_groups.append(d.atoms) + + if level == "residue": + num_residues = len(data_container.residues) + if num_residues >= 4: + for residue in range(4, num_residues + 1): + atom1 = data_container.select_atoms( + f"resindex {residue - 4} and bonded resindex {residue - 3}" + ) + atom2 = data_container.select_atoms( + f"resindex {residue - 3} and bonded resindex {residue - 4}" + ) + atom3 = data_container.select_atoms( + f"resindex {residue - 2} and bonded resindex {residue - 1}" + ) + atom4 = data_container.select_atoms( + f"resindex {residue - 1} and bonded resindex {residue - 2}" + ) + atom_groups.append(atom1 + atom2 + atom3 + atom4) + + logger.debug("Level: %s, Dihedrals: %s", level, atom_groups) + return atom_groups + + def _collect_peaks_for_group( + self, + data_container, + molecules, + dihedrals_ua, + dihedrals_res, + bin_width, + start, + end, + step, + level_list, + ): + """Compute histogram peaks for UA and residue dihedral sets. 
+ + Returns: + Tuple: + peaks_ua: list of peaks per residue + (each item is list-of-peaks per dihedral) + peaks_res: list-of-peaks per dihedral for residue level (or []) + """ + peaks_ua = [{} for _ in range(len(dihedrals_ua))] + peaks_res = {} + + for level in level_list: + if level == "united_atom": + for res_id in range(len(dihedrals_ua)): + if len(dihedrals_ua[res_id]) == 0: + peaks_ua[res_id] = [] + else: + peaks_ua[res_id] = self._identify_peaks( + data_container=data_container, + molecules=molecules, + dihedrals=dihedrals_ua[res_id], + bin_width=bin_width, + start=start, + end=end, + step=step, + ) + + elif level == "residue": + if len(dihedrals_res) == 0: + peaks_res = [] + else: + peaks_res = self._identify_peaks( + data_container=data_container, + molecules=molecules, + dihedrals=dihedrals_res, + bin_width=bin_width, + start=start, + end=end, + step=step, + ) + + return peaks_ua, peaks_res + + def _identify_peaks( + self, + data_container, + molecules, + dihedrals, + bin_width, + start, + end, + step, + ): + """Identify histogram peaks ("convex turning points") for each dihedral. + + Important: + This function intentionally preserves the legacy behavior: + it samples over the full trajectory length for each molecule + and does not apply start/end/step to the Dihedral run. + + Args: + data_container: MDAnalysis universe. + molecules: Molecule ids in the group. + dihedrals: Dihedral AtomGroups. + bin_width: Histogram bin width (degrees). + start: Unused in legacy sampling. + end: Unused in legacy sampling. + step: Unused in legacy sampling. + + Returns: + List of peaks per dihedral (peak_values[dihedral_index] -> list of peaks). 
+ """ + peak_values = [] * len(dihedrals) + + for dihedral_index in range(len(dihedrals)): + phi = [] + + for molecule in molecules: + mol = self._universe_operations.extract_fragment( + data_container, molecule + ) + number_frames = len(mol.trajectory) + + dihedral_results = Dihedral(dihedrals).run() + + for timestep in range(number_frames): + value = dihedral_results.results.angles[timestep][dihedral_index] + if value < 0: + value += 360 + phi.append(value) + + number_bins = int(360 / bin_width) + popul, bin_edges = np.histogram(a=phi, bins=number_bins, range=(0, 360)) + bin_value = [ + 0.5 * (bin_edges[i] + bin_edges[i + 1]) for i in range(0, len(popul)) + ] + + peaks = self._find_histogram_peaks(popul=popul, bin_value=bin_value) + peak_values.append(peaks) + + logger.debug("Dihedral: %s, Peak Values: %s", dihedral_index, peak_values) + + return peak_values + + @staticmethod + def _find_histogram_peaks(popul, bin_value): + """Return convex turning-point peaks from a histogram.""" + number_bins = len(popul) + peaks = [] + + for bin_index in range(number_bins): + if popul[bin_index] == 0: + continue + + if bin_index == number_bins - 1: + if ( + popul[bin_index] >= popul[bin_index - 1] + and popul[bin_index] >= popul[0] + ): + peaks.append(bin_value[bin_index]) + else: + if ( + popul[bin_index] >= popul[bin_index - 1] + and popul[bin_index] >= popul[bin_index + 1] + ): + peaks.append(bin_value[bin_index]) + + return peaks + + def _assign_states_for_group( + self, + data_container, + group_id, + molecules, + dihedrals_ua, + peaks_ua, + dihedrals_res, + peaks_res, + start, + end, + step, + level_list, + states_ua, + states_res, + ): + """Assign UA and residue states for a group into output containers.""" + for level in level_list: + if level == "united_atom": + for res_id in range(len(dihedrals_ua)): + key = (group_id, res_id) + if len(dihedrals_ua[res_id]) == 0: + states_ua[key] = [] + else: + states_ua[key] = self._assign_states( + data_container=data_container, + 
molecules=molecules, + dihedrals=dihedrals_ua[res_id], + peaks=peaks_ua[res_id], + start=start, + end=end, + step=step, + ) + + elif level == "residue": + if len(dihedrals_res) == 0: + states_res[group_id] = [] + else: + states_res[group_id] = self._assign_states( + data_container=data_container, + molecules=molecules, + dihedrals=dihedrals_res, + peaks=peaks_res, + start=start, + end=end, + step=step, + ) + + def _assign_states( + self, + data_container, + molecules, + dihedrals, + peaks, + start, + end, + step, + ): + """Assign discrete state labels for the provided dihedrals. + + Important: + This function intentionally preserves the legacy behavior: + it samples over the full trajectory length for each molecule + and does not apply start/end/step to the Dihedral run. + + Args: + data_container: MDAnalysis universe. + molecules: Molecule ids in the group. + dihedrals: Dihedral AtomGroups. + peaks: Peaks per dihedral. + start: Unused in legacy sampling. + end: Unused in legacy sampling. + step: Unused in legacy sampling. + + Returns: + List of state labels (strings). 
+ """ + states = None + + for molecule in molecules: + conformations = [] + mol = self._universe_operations.extract_fragment(data_container, molecule) + number_frames = len(mol.trajectory) + + dihedral_results = Dihedral(dihedrals).run() + + for dihedral_index in range(len(dihedrals)): + conformation = [] + for timestep in range(number_frames): + value = dihedral_results.results.angles[timestep][dihedral_index] + if value < 0: + value += 360 + + distances = [abs(value - peak) for peak in peaks[dihedral_index]] + conformation.append(np.argmin(distances)) + + conformations.append(conformation) + + mol_states = [ + state + for state in ( + "".join( + str(int(conformations[d][f])) for d in range(len(dihedrals)) + ) + for f in range(number_frames) + ) + if state + ] + + if states is None: + states = mol_states + else: + states.extend(mol_states) + + logger.debug("States: %s", states) + return states diff --git a/CodeEntropy/levels/forces.py b/CodeEntropy/levels/forces.py new file mode 100644 index 00000000..3b382b0c --- /dev/null +++ b/CodeEntropy/levels/forces.py @@ -0,0 +1,311 @@ +"""Force/torque weighting and per-frame second-moment construction. + +This module provides utilities for transforming atomic forces into bead-level +generalized forces (translation) and torques (rotation), and for assembling +per-frame second-moment matrices used downstream in entropy calculations. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Optional, Sequence, Tuple + +import numpy as np + +logger = logging.getLogger(__name__) + +Vector3 = np.ndarray +Matrix = np.ndarray + + +@dataclass(frozen=True) +class TorqueInputs: + """Container for torque computation inputs. + + Attributes: + rot_axes: Rotation matrix mapping lab-frame vectors into the bead frame, + shape (3, 3). + center: Reference center for torque arm vectors, shape (3,). 
+ force_partitioning: Scaling factor applied to forces before torque + accumulation. + moment_of_inertia: Principal moments (aligned with rot_axes), shape (3,). + axes_manager: Optional object that provides: + get_vector(center, positions, box) -> displacement vectors (PBC-aware). + box: Optional periodic box passed to axes_manager.get_vector. + """ + + rot_axes: Matrix + center: Vector3 + force_partitioning: float + moment_of_inertia: Vector3 + axes_manager: Optional[Any] = None + box: Optional[np.ndarray] = None + + +class ForceTorqueCalculator: + """Computes weighted generalized forces/torques and per-frame second moments. + + This class provides: + - Mass-weighted generalized translational forces from per-atom forces. + - Moment-of-inertia-weighted generalized torques from per-atom positions + and forces, optionally using an axes_manager for PBC-aware displacements. + - Per-frame second-moment (outer product) matrices for concatenated bead + vectors, used downstream for covariance/entropy calculations. + """ + + def get_weighted_forces( + self, + bead: Any, + trans_axes: Matrix, + highest_level: bool, + force_partitioning: float, + ) -> Vector3: + """Compute a mass-weighted translational generalized force. + + Args: + bead: MDAnalysis AtomGroup-like bead with .atoms and .total_mass(). + Each atom must provide .force (shape (3,)). + trans_axes: Transform matrix for translational forces, shape (3, 3). + highest_level: If True, apply force_partitioning scaling. + force_partitioning: Scaling factor applied when highest_level is True. + + Returns: + Mass-weighted generalized force vector, shape (3,). + + Raises: + ValueError: If mass is non-positive or trans_axes shape is invalid. 
+ """ + return self._compute_weighted_force( + bead=bead, + trans_axes=trans_axes, + apply_partitioning=highest_level, + force_partitioning=force_partitioning, + ) + + def get_weighted_torques( + self, + bead: Any, + rot_axes: Matrix, + center: Vector3, + force_partitioning: float, + moment_of_inertia: Vector3, + axes_manager: Optional[Any], + box: Optional[np.ndarray], + ) -> Vector3: + """Compute a moment-weighted generalized torque. + + Args: + bead: MDAnalysis AtomGroup-like bead with .positions and .forces (N,3). + rot_axes: Rotation matrix into bead frame, shape (3,3). + center: Reference center for displacement vectors, shape (3,). + force_partitioning: Scaling factor applied to forces before torque sum. + moment_of_inertia: Principal moments aligned with rot_axes, shape (3,). + axes_manager: Optional PBC displacement provider. + box: Periodic box passed to axes_manager when used. + + Returns: + Weighted torque vector, shape (3,). + + Raises: + ValueError: If shapes are invalid. + """ + inputs = TorqueInputs( + rot_axes=np.asarray(rot_axes, dtype=float), + center=np.asarray(center, dtype=float).reshape(3), + force_partitioning=float(force_partitioning), + moment_of_inertia=np.asarray(moment_of_inertia), + axes_manager=axes_manager, + box=box, + ) + return self._compute_weighted_torque(bead=bead, inputs=inputs) + + def compute_frame_covariance( + self, + force_vecs: Sequence[Vector3], + torque_vecs: Sequence[Vector3], + ) -> Tuple[Matrix, Matrix]: + """Compute per-frame second-moment matrices for force/torque vectors. + + Note: + This returns outer(x, x) where x is the concatenation of all bead + vectors in the frame. + + Args: + force_vecs: Sequence of per-bead force vectors (3,). + torque_vecs: Sequence of per-bead torque vectors (3,). + + Returns: + Tuple (F, T) where each is a (3N, 3N) second-moment matrix. 
+ """ + return self._compute_frame_second_moments(force_vecs, torque_vecs) + + def _compute_weighted_force( + self, + bead: Any, + trans_axes: Matrix, + *, + apply_partitioning: bool, + force_partitioning: float, + ) -> Vector3: + """Compute a translational generalized force vector for a bead. + + The bead's atomic forces are transformed by trans_axes, summed, optionally + scaled by force_partitioning, and then mass-weighted by 1/sqrt(mass). + + Args: + bead: Bead-like object with .atoms and .total_mass(). Each atom must + provide .force with shape (3,). + trans_axes: Transform matrix for translational forces, shape (3, 3). + apply_partitioning: Whether to apply the force_partitioning scaling. + force_partitioning: Scaling factor applied when apply_partitioning is True. + + Returns: + Mass-weighted generalized force vector of shape (3,). + + Raises: + ValueError: If trans_axes is not (3,3) or bead mass is non-positive. + """ + trans_axes = np.asarray(trans_axes, dtype=float) + if trans_axes.shape != (3, 3): + raise ValueError(f"trans_axes must be (3,3), got {trans_axes.shape}") + + forces_trans = np.zeros((3,), dtype=float) + for atom in bead.atoms: + forces_trans += trans_axes @ np.asarray(atom.force, dtype=float) + + if apply_partitioning: + forces_trans *= float(force_partitioning) + + mass = float(bead.total_mass()) + if mass <= 0.0: + raise ValueError(f"Invalid bead mass: {mass}. Mass must be positive.") + + return forces_trans / np.sqrt(mass) + + def _compute_weighted_torque(self, bead: Any, inputs: TorqueInputs) -> Vector3: + """Compute a rotational generalized torque vector for a bead. + + Positions are displaced relative to inputs.center (optionally PBC-aware), + rotated into the bead frame, and crossed with rotated (and scaled) forces + to form a torque vector. Each component is then weighted by 1/sqrt(I_d) + where I_d is the corresponding principal moment of inertia. + + Args: + bead: Bead-like object with .positions (N,3) and .forces (N,3). 
+ inputs: TorqueInputs containing axes, center, scaling, and inertia. + + Returns: + Moment-of-inertia-weighted torque vector of shape (3,). + + Raises: + ValueError: If rot_axes is not (3,3) or moment_of_inertia is not length 3. + """ + rot_axes = np.asarray(inputs.rot_axes, dtype=float) + if rot_axes.shape != (3, 3): + raise ValueError(f"rot_axes must be (3,3), got {rot_axes.shape}") + + moi = np.asarray(inputs.moment_of_inertia) + moi = np.real_if_close(moi, tol=1000) + moi = np.asarray(moi, dtype=float).reshape(-1) + if moi.size != 3: + raise ValueError(f"moment_of_inertia must be (3,), got {moi.shape}") + + translated = self._displacements_relative_to_center( + center=np.asarray(inputs.center, dtype=float).reshape(3), + positions=np.asarray(bead.positions, dtype=float), + axes_manager=inputs.axes_manager, + box=inputs.box, + ) + + rotated_coords = np.tensordot(translated, rot_axes.T, axes=1) + rotated_forces = np.tensordot( + np.asarray(bead.forces, dtype=float), rot_axes.T, axes=1 + ) + rotated_forces *= float(inputs.force_partitioning) + + torques = np.sum(np.cross(rotated_coords, rotated_forces), axis=0) + + weighted = np.zeros((3,), dtype=float) + for d in range(3): + if np.isclose(torques[d], 0.0): + continue + if moi[d] <= 0.0: + continue + weighted[d] = torques[d] / np.sqrt(moi[d]) + + return weighted + + def _compute_frame_second_moments( + self, + force_vectors: Sequence[Vector3], + torque_vectors: Sequence[Vector3], + ) -> Tuple[Matrix, Matrix]: + """Build outer-product second-moment matrices for a single frame. + + Args: + force_vectors: Sequence of per-bead force vectors of shape (3,). + torque_vectors: Sequence of per-bead torque vectors of shape (3,). + + Returns: + Tuple (F, T) where each is the outer-product second moment of the + concatenated vectors, with shape (3N, 3N). 
+ """ + f = self._outer_second_moment(force_vectors) + t = self._outer_second_moment(torque_vectors) + return f, t + + @staticmethod + def _displacements_relative_to_center( + *, + center: Vector3, + positions: np.ndarray, + axes_manager: Optional[Any], + box: Optional[np.ndarray], + ) -> np.ndarray: + """Compute displacement vectors from center to positions. + + This method delegates displacement computation to axes_manager.get_vector, + which is expected to handle periodic boundary conditions if applicable. + + Args: + center: Reference center position of shape (3,). + positions: Array of positions of shape (N, 3). + axes_manager: Object providing get_vector(center, positions, box). + box: Periodic box passed through to axes_manager.get_vector. + + Returns: + Displacement vectors of shape (N, 3). + + Raises: + AttributeError: If axes_manager does not provide get_vector. + """ + return axes_manager.get_vector(center, positions, box) + + @staticmethod + def _outer_second_moment(vectors: Sequence[Vector3]) -> Matrix: + """Compute outer(flat, flat) for concatenated 3-vectors. + + Args: + vectors: Sequence of vectors of shape (3,). + + Returns: + Second-moment matrix with shape (3N, 3N). Returns (0,0) if empty. + + Raises: + ValueError: If any vector is not length 3. + """ + if not vectors: + return np.zeros((0, 0), dtype=float) + + parts = [] + for v in vectors: + arr = np.asarray(v, dtype=float).reshape(-1) + if arr.size != 3: + raise ValueError( + f"Expected vector of length 3, got shape {np.asarray(v).shape}" + ) + parts.append(arr) + + flat = np.concatenate(parts, axis=0) + return np.outer(flat, flat) diff --git a/CodeEntropy/levels/frame_dag.py b/CodeEntropy/levels/frame_dag.py new file mode 100644 index 00000000..dd27a3b2 --- /dev/null +++ b/CodeEntropy/levels/frame_dag.py @@ -0,0 +1,116 @@ +"""Frame-local DAG execution. + +This module defines the frame-scoped DAG used during the MAP stage of the +hierarchy workflow. 
Each frame is processed independently to produce
+frame-local outputs (e.g., axes and covariance data), which are later reduced
+outside this DAG.
+"""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import networkx as nx
+
+from CodeEntropy.levels.nodes.covariance import FrameCovarianceNode
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class FrameContext:
+    """Container for per-frame execution context.
+
+    Attributes:
+        shared: Shared workflow data (mutated across the full workflow).
+        frame_index: Absolute trajectory frame index being processed.
+        frame_covariance: Frame-local covariance output produced by FrameCovarianceNode.
+        data: Additional frame-local scratch space for nodes, if needed.
+    """
+
+    shared: Dict[str, Any]
+    frame_index: int
+    frame_covariance: Any = None
+    data: Optional[Dict[str, Any]] = None
+
+
+class FrameGraph:
+    """Execute a frame-local directed acyclic graph.
+
+    The graph is run once per trajectory frame. Nodes may read shared inputs from
+    `ctx["shared"]` and must write only frame-local outputs into the frame context.
+
+    Expected node outputs:
+        - "frame_covariance"
+    """
+
+    def __init__(self, universe_operations: Optional[Any] = None) -> None:
+        """Initialise a FrameGraph.
+
+        Args:
+            universe_operations: Optional adapter providing universe operations used
+                by frame-level nodes (e.g., selections / molecule containers).
+        """
+        self._universe_operations = universe_operations
+        self._graph = nx.DiGraph()
+        self._nodes: Dict[str, Any] = {}
+
+    def build(self) -> "FrameGraph":
+        """Build the default frame DAG topology.
+
+        Returns:
+            Self, to allow fluent chaining.
+        """
+        self._add("frame_covariance", FrameCovarianceNode())
+        return self
+
+    def execute_frame(self, shared_data: Dict[str, Any], frame_index: int) -> Any:
+        """Execute the frame DAG for a single trajectory frame.
+
+        Args:
+            shared_data: Shared workflow data dict.
+ frame_index: Absolute trajectory frame index. + + Returns: + Frame-local covariance payload produced by FrameCovarianceNode. + """ + ctx = self._make_frame_ctx(shared_data=shared_data, frame_index=frame_index) + + for node_name in nx.topological_sort(self._graph): + logger.debug("[FrameGraph] running %s @ frame=%s", node_name, frame_index) + self._nodes[node_name].run(ctx) + + return ctx["frame_covariance"] + + def _add(self, name: str, node: Any, deps: Optional[list[str]] = None) -> None: + """Register a node and its dependencies in the DAG.""" + self._nodes[name] = node + self._graph.add_node(name) + for dep in deps or []: + self._graph.add_edge(dep, name) + + @staticmethod + def _make_frame_ctx( + shared_data: Dict[str, Any], frame_index: int + ) -> Dict[str, Any]: + """Create a frame context dictionary for node execution. + + Notes: + - The context includes a reference to `shared_data` via "shared". + - The context is intentionally frame-scoped and should not be used as + a replacement for shared workflow state. + + Args: + shared_data: Shared workflow data dict. + frame_index: Absolute trajectory frame index. + + Returns: + Frame context dict with required keys. + """ + return { + "shared": shared_data, + "frame_index": frame_index, + "frame_covariance": None, + } diff --git a/CodeEntropy/levels/hierarchy.py b/CodeEntropy/levels/hierarchy.py new file mode 100644 index 00000000..76c23311 --- /dev/null +++ b/CodeEntropy/levels/hierarchy.py @@ -0,0 +1,140 @@ +"""Hierarchy level selection and bead construction. + +This module defines `HierarchyBuilder`, which is responsible for: + 1) Determining which hierarchy levels apply to each molecule. + 2) Constructing "beads" (AtomGroups) for a given molecule at a given level. + +Notes: +- The "residue" bead construction must use residues attached to the provided + AtomGroup/container. 
Using `resindex` selection strings is unsafe because + `resindex` is global to the Universe and can produce empty/incorrect beads + when operating on per-molecule containers beyond the first molecule. +""" + +from __future__ import annotations + +import logging +from typing import List, Tuple + +logger = logging.getLogger(__name__) + + +class HierarchyBuilder: + """Determine applicable hierarchy levels and build beads for each level. + + A "level" represents a resolution scale used throughout the entropy workflow: + - united_atom: heavy-atom-centered beads (plus bonded hydrogens) + - residue: residue beads + - polymer: whole-molecule bead + + This class intentionally does not perform any entropy calculations. It only + provides structural information (levels and beads). + """ + + def select_levels(self, data_container) -> Tuple[int, List[List[str]]]: + """Select applicable hierarchy levels for each molecule in the container. + + A molecule is always assigned the `united_atom` level. + + Additional levels are included if: + - `residue`: the heavy-atom subset has more than one atom. + - `polymer`: the heavy-atom subset spans more than one residue. + + Args: + data_container: An MDAnalysis Universe (or compatible object) with + `atoms.fragments` available. + + Returns: + A tuple of: + - number_molecules: Number of molecular fragments. + - levels: List where `levels[mol_id]` is a list of level names + (strings) for that molecule in increasing coarseness. 
+ """ + number_molecules = len(data_container.atoms.fragments) + logger.debug("The number of molecules is %d.", number_molecules) + + fragments = data_container.atoms.fragments + levels: List[List[str]] = [[] for _ in range(number_molecules)] + + for mol_id in range(number_molecules): + levels[mol_id].append("united_atom") + + heavy_atoms = fragments[mol_id].select_atoms("prop mass > 1.1") + if len(heavy_atoms) > 1: + levels[mol_id].append("residue") + + number_residues = len(heavy_atoms.residues) + if number_residues > 1: + levels[mol_id].append("polymer") + + logger.debug("Selected levels: %s", levels) + return number_molecules, levels + + def get_beads(self, data_container, level: str) -> List: + """Build beads for a given container at a given hierarchy level. + + Args: + data_container: An MDAnalysis AtomGroup representing a molecule or + other container that has `.select_atoms(...)` and `.residues`. + level: One of {"united_atom", "residue", "polymer"}. + + Returns: + A list of MDAnalysis AtomGroups representing beads at that level. + + Raises: + ValueError: If `level` is not a supported hierarchy level. + """ + if level == "polymer": + return [data_container.select_atoms("all")] + + if level == "residue": + return self._build_residue_beads(data_container) + + if level == "united_atom": + return self._build_united_atom_beads(data_container) + + raise ValueError(f"Unknown level: {level}") + + def _build_residue_beads(self, data_container) -> List: + """Build one bead per residue using the container's residues. + + Args: + data_container: MDAnalysis AtomGroup with `.residues`. + + Returns: + List of residue AtomGroups. + """ + beads = [res.atoms for res in data_container.residues] + logger.debug("Residue beads sizes: %s", [len(b) for b in beads]) + return beads + + def _build_united_atom_beads(self, data_container) -> List: + """Build united-atom beads from heavy atoms and their bonded hydrogens. 
+ + For each heavy atom, a bead is defined as: + - that heavy atom, plus + - any bonded atoms with mass <= 1.1 (hydrogen-like). + + If no heavy atoms exist in the container, the entire container becomes + a single bead. + + Args: + data_container: MDAnalysis AtomGroup representing a molecule. + + Returns: + List of bead AtomGroups. + """ + heavy_atoms = data_container.select_atoms("prop mass > 1.1") + if len(heavy_atoms) == 0: + return [data_container.select_atoms("all")] + + beads = [] + for atom in heavy_atoms: + selection = ( + f"index {atom.index} " + f"or ((prop mass <= 1.1) and bonded index {atom.index})" + ) + beads.append(data_container.select_atoms(selection)) + + logger.debug("United-atom bead sizes: %s", [len(b) for b in beads]) + return beads diff --git a/CodeEntropy/levels/level_dag.py b/CodeEntropy/levels/level_dag.py new file mode 100644 index 00000000..39b4a1da --- /dev/null +++ b/CodeEntropy/levels/level_dag.py @@ -0,0 +1,340 @@ +"""Hierarchy-level DAG orchestration and reduction. + +This module defines the `LevelDAG`, which coordinates two stages of the hierarchy +workflow: + +1) Static stage (runs once): + - Detect molecules and available resolution levels. + - Build beads for each (molecule, level) definition. + - Initialise accumulators used during per-frame reduction. + - Compute conformational state descriptors required later by entropy nodes. + +2) Frame stage (runs for each trajectory frame): + - Execute the `FrameGraph` to produce frame-local covariance outputs. + - Reduce frame-local outputs into running (incremental) means. 
+""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional + +import networkx as nx + +from CodeEntropy.levels.axes import AxesCalculator +from CodeEntropy.levels.frame_dag import FrameGraph +from CodeEntropy.levels.nodes.accumulators import InitCovarianceAccumulatorsNode +from CodeEntropy.levels.nodes.beads import BuildBeadsNode +from CodeEntropy.levels.nodes.conformations import ComputeConformationalStatesNode +from CodeEntropy.levels.nodes.detect_levels import DetectLevelsNode +from CodeEntropy.levels.nodes.detect_molecules import DetectMoleculesNode + +logger = logging.getLogger(__name__) + + +class LevelDAG: + """Execute hierarchy detection, per-frame covariance calculation, and reduction. + + The LevelDAG is responsible for: + - Running a static DAG (once) to prepare shared inputs. + - Running a per-frame DAG (for each frame) to compute frame-local outputs. + - Reducing frame-local outputs into shared running means. + + The reduction performed here is an incremental mean across frames (and across + molecules within a group when frame nodes average within-frame first). + """ + + def __init__(self, universe_operations: Optional[Any] = None) -> None: + """Initialise a LevelDAG. + + Args: + universe_operations: Optional adapter providing universe operations. + Passed to the FrameGraph and the conformational-state node. + """ + self._universe_operations = universe_operations + + self._static_graph = nx.DiGraph() + self._static_nodes: Dict[str, Any] = {} + + self._frame_dag = FrameGraph(universe_operations=universe_operations) + + def build(self) -> "LevelDAG": + """Build the static and frame DAG topology. + + This registers all static nodes and their dependencies, and builds the + internal FrameGraph used for per-frame execution. + + Returns: + Self, to allow fluent chaining. 
+ """ + self._add_static("detect_molecules", DetectMoleculesNode()) + self._add_static("detect_levels", DetectLevelsNode(), deps=["detect_molecules"]) + self._add_static("build_beads", BuildBeadsNode(), deps=["detect_levels"]) + + self._add_static( + "init_covariance_accumulators", + InitCovarianceAccumulatorsNode(), + deps=["detect_levels"], + ) + self._add_static( + "compute_conformational_states", + ComputeConformationalStatesNode(self._universe_operations), + deps=["detect_levels"], + ) + + self._frame_dag.build() + return self + + def execute( + self, shared_data: Dict[str, Any], *, progress: object | None = None + ) -> Dict[str, Any]: + """Execute the full hierarchy workflow and mutate shared_data. + + This method ensures required shared components exist, runs the static stage + once, then iterates through trajectory frames to run the per-frame stage and + reduce outputs into running means. + + Args: + shared_data: Shared workflow data dict. This mapping is mutated in-place + by both static and frame stages. + progress: Optional progress sink passed through to nodes and used for + per-frame progress reporting when supported. + + Returns: + The same shared_data mapping passed in, after mutation. + """ + shared_data.setdefault("axes_manager", AxesCalculator()) + self._run_static_stage(shared_data, progress=progress) + self._run_frame_stage(shared_data, progress=progress) + return shared_data + + def _run_static_stage( + self, shared_data: Dict[str, Any], *, progress: object | None = None + ) -> None: + """Run all static nodes in dependency order. + + Nodes are executed in topological order of the static DAG. If a progress + object is provided, it is passed to node.run when the node accepts it. + + Args: + shared_data: Shared workflow data dict to be mutated by static nodes. + progress: Optional progress sink to pass to nodes that support it. 
+ """ + for node_name in nx.topological_sort(self._static_graph): + node = self._static_nodes[node_name] + if progress is not None: + try: + node.run(shared_data, progress=progress) + continue + except TypeError: + pass + node.run(shared_data) + + def _add_static( + self, name: str, node: Any, deps: Optional[list[str]] = None + ) -> None: + """Register a static node and its dependencies in the static DAG. + + Args: + name: Unique node name used in the static DAG. + node: Node object exposing a run(shared_data, **kwargs) method. + deps: Optional list of upstream node names that must run before this node. + + Returns: + None. Mutates the internal static graph and node registry. + """ + self._static_nodes[name] = node + self._static_graph.add_node(name) + for dep in deps or []: + self._static_graph.add_edge(dep, name) + + def _run_frame_stage( + self, shared_data: Dict[str, Any], *, progress: object | None = None + ) -> None: + """Execute the per-frame DAG stage and reduce frame outputs. + + This method iterates over the selected trajectory frames, executes the + frame-local DAG for each frame, and reduces the resulting outputs into the + shared accumulators stored in `shared_data`. + + Progress reporting is optional. If a progress sink is provided, a task is + always created. When the total number of frames cannot be determined, the + task is created with total=None (indeterminate). + + Args: + shared_data: Shared data dictionary. Must contain: + - "reduced_universe": MDAnalysis Universe providing the trajectory. + - "start", "end", "step": frame slicing parameters. + - any additional keys required by the frame DAG and reducer. + progress: Optional progress sink (e.g., from ResultsReporter.progress()). + Must expose add_task(), update(), and advance(). + + Returns: + None. Mutates `shared_data` in-place via reduction. + + Notes: + The task title shows the current frame index being processed. 
+ """ + u = shared_data["reduced_universe"] + start, end, step = shared_data["start"], shared_data["end"], shared_data["step"] + + task = None + total_frames = None + + if progress is not None: + try: + n_frames = len(u.trajectory) + + s = 0 if start is None else int(start) + e = n_frames if end is None else int(end) + + if e < 0: + e = n_frames + e + + e = max(0, min(e, n_frames)) + s = max(0, min(s, e)) + + st = 1 if step is None else int(step) + if st > 0: + total_frames = max(0, (e - s + st - 1) // st) + except Exception: + total_frames = None + + task = progress.add_task( + "[green]Frame processing", + total=total_frames, + title="Initializing", + ) + + for ts in u.trajectory[start:end:step]: + if task is not None: + progress.update(task, title=f"Frame {ts.frame}") + + frame_out = self._frame_dag.execute_frame(shared_data, ts.frame) + self._reduce_one_frame(shared_data, frame_out) + + if task is not None: + progress.advance(task) + + @staticmethod + def _incremental_mean(old: Any, new: Any, n: int) -> Any: + """Compute an incremental mean. + + Args: + old: Previous running mean (or None for first sample). + new: New sample to incorporate. + n: 1-based sample count after adding `new`. + + Returns: + Updated running mean. + """ + if old is None: + return new.copy() if hasattr(new, "copy") else new + return old + (new - old) / float(n) + + def _reduce_one_frame( + self, shared_data: Dict[str, Any], frame_out: Dict[str, Any] + ) -> None: + """Reduce one frame's covariance outputs into shared running means. + + Args: + shared_data: Shared workflow data dict containing accumulators. + frame_out: Frame-local covariance outputs produced by FrameGraph. + """ + self._reduce_force_and_torque(shared_data, frame_out) + self._reduce_forcetorque(shared_data, frame_out) + + def _reduce_force_and_torque( + self, shared_data: Dict[str, Any], frame_out: Dict[str, Any] + ) -> None: + """Reduce force/torque covariance outputs into shared accumulators. 
+ + Args: + shared_data: Shared workflow data dict containing: + - "force_covariances", "torque_covariances": accumulator structures. + - "frame_counts": running sample counts for each accumulator slot. + - "group_id_to_index": mapping from group id to accumulator index. + frame_out: Frame-local outputs containing "force" and "torque" sections. + + Returns: + None. Mutates accumulator values and counts in shared_data in-place. + """ + f_cov = shared_data["force_covariances"] + t_cov = shared_data["torque_covariances"] + counts = shared_data["frame_counts"] + gid2i = shared_data["group_id_to_index"] + + f_frame = frame_out["force"] + t_frame = frame_out["torque"] + + for key, F in f_frame["ua"].items(): + counts["ua"][key] = counts["ua"].get(key, 0) + 1 + n = counts["ua"][key] + f_cov["ua"][key] = self._incremental_mean(f_cov["ua"].get(key), F, n) + + for key, T in t_frame["ua"].items(): + if key not in counts["ua"]: + counts["ua"][key] = counts["ua"].get(key, 0) + 1 + n = counts["ua"][key] + t_cov["ua"][key] = self._incremental_mean(t_cov["ua"].get(key), T, n) + + for gid, F in f_frame["res"].items(): + gi = gid2i[gid] + counts["res"][gi] += 1 + n = counts["res"][gi] + f_cov["res"][gi] = self._incremental_mean(f_cov["res"][gi], F, n) + + for gid, T in t_frame["res"].items(): + gi = gid2i[gid] + if counts["res"][gi] == 0: + counts["res"][gi] += 1 + n = counts["res"][gi] + t_cov["res"][gi] = self._incremental_mean(t_cov["res"][gi], T, n) + + for gid, F in f_frame["poly"].items(): + gi = gid2i[gid] + counts["poly"][gi] += 1 + n = counts["poly"][gi] + f_cov["poly"][gi] = self._incremental_mean(f_cov["poly"][gi], F, n) + + for gid, T in t_frame["poly"].items(): + gi = gid2i[gid] + if counts["poly"][gi] == 0: + counts["poly"][gi] += 1 + n = counts["poly"][gi] + t_cov["poly"][gi] = self._incremental_mean(t_cov["poly"][gi], T, n) + + def _reduce_forcetorque( + self, shared_data: Dict[str, Any], frame_out: Dict[str, Any] + ) -> None: + """Reduce combined force-torque 
covariance outputs into shared accumulators. + + Args: + shared_data: Shared workflow data dict containing: + - "forcetorque_covariances": accumulator structures. + - "forcetorque_counts": running sample counts for each accumulator slot. + - "group_id_to_index": mapping from group id to accumulator index. + frame_out: Frame-local outputs that may include a "forcetorque" section. + + Returns: + None. Mutates accumulator values and counts in shared_data in-place. + """ + if "forcetorque" not in frame_out: + return + + ft_cov = shared_data["forcetorque_covariances"] + ft_counts = shared_data["forcetorque_counts"] + gid2i = shared_data["group_id_to_index"] + ft_frame = frame_out["forcetorque"] + + for gid, M in ft_frame.get("res", {}).items(): + gi = gid2i[gid] + ft_counts["res"][gi] += 1 + n = ft_counts["res"][gi] + ft_cov["res"][gi] = self._incremental_mean(ft_cov["res"][gi], M, n) + + for gid, M in ft_frame.get("poly", {}).items(): + gi = gid2i[gid] + ft_counts["poly"][gi] += 1 + n = ft_counts["poly"][gi] + ft_cov["poly"][gi] = self._incremental_mean(ft_cov["poly"][gi], M, n) diff --git a/CodeEntropy/levels/linalg.py b/CodeEntropy/levels/linalg.py new file mode 100644 index 00000000..b86defec --- /dev/null +++ b/CodeEntropy/levels/linalg.py @@ -0,0 +1,107 @@ +"""Matrix utilities used across covariance and entropy calculations. + +This module contains small, focused helpers for matrix construction and cleanup. +All functions are pure (no side effects beyond logging) and operate on NumPy +arrays. + +Key behaviors: +- `create_submatrix` computes a 3x3 outer-product block for two 3-vectors. +- `filter_zero_rows_columns` removes rows/columns that are all (near) zero. 
+""" + +from __future__ import annotations + +import logging + +import numpy as np + +logger = logging.getLogger(__name__) + + +class MatrixUtils: + """Utility operations for small matrix manipulations.""" + + def create_submatrix(self, data_i: np.ndarray, data_j: np.ndarray) -> np.ndarray: + """Create a 3x3 covariance-style submatrix from two 3-vectors. + + This computes the outer product of `data_i` and `data_j`: + + submatrix = outer(data_i, data_j) + + Args: + data_i: Vector of shape (3,) representing bead i values. + data_j: Vector of shape (3,) representing bead j values. + + Returns: + A (3, 3) NumPy array corresponding to the outer product. + + Raises: + ValueError: If either input cannot be reshaped to (3,). + """ + v_i = np.asarray(data_i, dtype=float).reshape(-1) + v_j = np.asarray(data_j, dtype=float).reshape(-1) + + if v_i.shape[0] != 3 or v_j.shape[0] != 3: + raise ValueError( + f"Expected 3-vectors for outer product, got {v_i.shape} " + f"and {v_j.shape}." + ) + + submatrix = np.outer(v_i, v_j) + logger.debug("Submatrix: %s", submatrix) + return submatrix + + def filter_zero_rows_columns( + self, matrix: np.ndarray, atol: float = 0.0 + ) -> np.ndarray: + """Remove rows and columns that are entirely (near) zero. + + A row (or column) is removed if all entries are close to zero according + to `np.isclose(..., atol=atol)`. + + Args: + matrix: Input 2D array. + atol: Absolute tolerance used to determine "zero". Defaults to 0.0. + + Returns: + A new matrix with all-zero rows and columns removed. If no such rows + or columns exist, returns a view/copy of the original with consistent + NumPy typing. + + Raises: + ValueError: If `matrix` is not 2D. 
+ """ + mat = np.asarray(matrix, dtype=float) + if mat.ndim != 2: + raise ValueError(f"Expected a 2D matrix, got ndim={mat.ndim}.") + + init_shape = mat.shape + + row_mask = self._nonzero_row_mask(mat, atol=atol) + mat = mat[row_mask, :] + + col_mask = self._nonzero_col_mask(mat, atol=atol) + mat = mat[:, col_mask] + + final_shape = mat.shape + if init_shape != final_shape: + logger.debug( + "Matrix shape changed %s -> %s after removing zero rows/cols.", + init_shape, + final_shape, + ) + + logger.debug("Filtered matrix: %s", mat) + return mat + + @staticmethod + def _nonzero_row_mask(matrix: np.ndarray, atol: float) -> np.ndarray: + """Return a boolean mask selecting rows that are not all (near) zero.""" + is_zero_row = np.all(np.isclose(matrix, 0.0, atol=atol), axis=1) + return ~is_zero_row + + @staticmethod + def _nonzero_col_mask(matrix: np.ndarray, atol: float) -> np.ndarray: + """Return a boolean mask selecting columns that are not all (near) zero.""" + is_zero_col = np.all(np.isclose(matrix, 0.0, atol=atol), axis=0) + return ~is_zero_col diff --git a/CodeEntropy/levels/mda.py b/CodeEntropy/levels/mda.py new file mode 100644 index 00000000..c76b6f50 --- /dev/null +++ b/CodeEntropy/levels/mda.py @@ -0,0 +1,290 @@ +""" +MDAnalysis universe utilities. + +This module contains helpers for creating reduced MDAnalysis `Universe` objects by +sub-selecting frames and/or atoms, and for building a `Universe` that combines +coordinates from one trajectory with forces sourced from a second trajectory. +""" + +from __future__ import annotations + +import logging +from typing import Optional + +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisFromFunction +from MDAnalysis.coordinates.memory import MemoryReader +from MDAnalysis.exceptions import NoDataError + +logger = logging.getLogger(__name__) + + +class UniverseOperations: + """Functions to create and manipulate MDAnalysis Universe objects. 
+ + This helper provides methods to: + - Build reduced universes by selecting subsets of frames or atoms. + - Extract a single fragment (molecule) into a standalone universe. + - Merge coordinates from one trajectory with forces from another trajectory. + """ + + def __init__(self) -> None: + """Initialise the operations helper.""" + self._universe = None + + def select_frames( + self, + u: mda.Universe, + start: Optional[int] = None, + end: Optional[int] = None, + step: int = 1, + ) -> mda.Universe: + """Create a reduced universe by dropping frames according to user selection. + + Args: + u: A Universe object with topology, coordinates and (optionally) forces. + start: Frame index to start analysis. If None, defaults to 0. + end: Frame index to stop analysis (Python slicing semantics). If None, + defaults to the full trajectory length. + step: Step size between frames. + + Returns: + A reduced universe containing the selected frames, with coordinates, + forces (if present) and unit cell dimensions loaded into memory. + """ + if start is None: + start = 0 + if end is None: + end = len(u.trajectory) + + select_atom = u.select_atoms("all", updating=True) + + coordinates = self._extract_timeseries(select_atom, kind="positions")[ + start:end:step + ] + forces = self._extract_timeseries(select_atom, kind="forces")[start:end:step] + dimensions = self._extract_timeseries(select_atom, kind="dimensions")[ + start:end:step + ] + + u2 = mda.Merge(select_atom) + u2.load_new( + coordinates, + format=MemoryReader, + forces=forces, + dimensions=dimensions, + ) + + logger.debug("MDAnalysis.Universe - reduced universe (frame-selected): %s", u2) + return u2 + + def select_atoms(self, u: mda.Universe, select_string: str = "all") -> mda.Universe: + """Create a reduced universe by dropping atoms according to user selection. + + Args: + u: A Universe object with topology, coordinates and (optionally) forces. + select_string: MDAnalysis `select_atoms` selection string. 
+ + Returns: + A reduced universe containing only the selected atoms. Coordinates, + forces (if present) and dimensions are loaded into memory. + """ + select_atom = u.select_atoms(select_string, updating=True) + + coordinates = self._extract_timeseries(select_atom, kind="positions") + forces = self._extract_timeseries(select_atom, kind="forces") + dimensions = self._extract_timeseries(select_atom, kind="dimensions") + + u2 = mda.Merge(select_atom) + u2.load_new( + coordinates, + format=MemoryReader, + forces=forces, + dimensions=dimensions, + ) + + logger.debug("MDAnalysis.Universe - reduced universe (atom-selected): %s", u2) + return u2 + + def extract_fragment( + self, universe: mda.Universe, molecule_id: int + ) -> mda.Universe: + """Extract a single molecule (fragment) as a standalone reduced universe. + + Args: + universe: The source universe. + molecule_id: Fragment index in `universe.atoms.fragments`. + + Returns: + A reduced universe containing only the atoms of the selected fragment. + """ + frag = universe.atoms.fragments[molecule_id] + selection_string = f"index {frag.indices[0]}:{frag.indices[-1]}" + return self.select_atoms(universe, selection_string) + + def merge_forces( + self, + tprfile: str, + trrfile, + forcefile: str, + fileformat: Optional[str] = None, + kcal: bool = False, + *, + force_format: Optional[str] = None, + fallback_to_positions_if_no_forces: bool = True, + ) -> mda.Universe: + """Create a universe by merging coordinates and forces from different files. + + This method loads: + - coordinates + dimensions from the coordinate trajectory (tprfile + trrfile) + - forces from the force trajectory (tprfile + forcefile) + + If the force trajectory does not expose forces in MDAnalysis (e.g. 
the file + does not contain forces, or the reader does not provide them), then: + - if `fallback_to_positions_if_no_forces` is True, positions from the force + trajectory are used as the "forces" array (backwards-compatible behaviour + with earlier implementations). + - otherwise, the underlying `NoDataError` is raised. + + Args: + tprfile: Topology input file. + trrfile: Coordinate trajectory file(s). This can be a single path or a + list, as accepted by MDAnalysis. + forcefile: Trajectory containing forces. + fileformat: Optional file format for the coordinate trajectory, as + recognised by MDAnalysis. + kcal: If True, scale the force array by 4.184 to convert from kcal to kJ. + force_format: Optional file format for the force trajectory. If not + provided, uses `fileformat`. + fallback_to_positions_if_no_forces: If True, and the force trajectory has + no accessible forces, use positions from the force trajectory as a + fallback (legacy behaviour). + + Returns: + A new Universe containing coordinates, forces and dimensions loaded into + memory. 
+ """ + logger.debug("Loading coordinate Universe with %s", trrfile) + u = mda.Universe(tprfile, trrfile, format=fileformat) + + ff = force_format if force_format is not None else fileformat + logger.debug("Loading force Universe with %s", forcefile) + u_force = mda.Universe(tprfile, forcefile, format=ff) + + select_atom = u.select_atoms("all") + select_atom_force = u_force.select_atoms("all") + + coordinates = self._extract_timeseries(select_atom, kind="positions") + dimensions = self._extract_timeseries(select_atom, kind="dimensions") + + forces = self._extract_force_timeseries_with_fallback( + select_atom_force, + fallback_to_positions_if_no_forces=fallback_to_positions_if_no_forces, + ) + + if kcal: + forces *= 4.184 + + logger.debug("Merging forces with coordinates universe.") + new_universe = mda.Merge(select_atom) + new_universe.load_new( + coordinates, + forces=forces, + dimensions=dimensions, + ) + + return new_universe + + def _extract_timeseries(self, atomgroup, *, kind: str): + """Extract a time series array for the requested kind from an AtomGroup. + + Args: + atomgroup: MDAnalysis AtomGroup (may be updating). + kind: One of {"positions", "forces", "dimensions"}. + + Returns: + Time series with shape: + - positions: (n_frames, n_atoms, 3) + - forces: (n_frames, n_atoms, 3) if available, else raises NoDataError + - dimensions: (n_frames, 6) or (n_frames, 3) depending on reader + + Raises: + ValueError: If kind is not one of the supported values. + NoDataError: If kind is "forces" and the trajectory does not provide + forces via the configured reader. + """ + if kind == "positions": + func = self._positions_copy + elif kind == "forces": + func = self._forces_copy + elif kind == "dimensions": + func = self._dimensions_copy + else: + raise ValueError(f"Unknown timeseries kind: {kind}") + + return AnalysisFromFunction(func, atomgroup).run().results["timeseries"] + + def _positions_copy(self, ag): + """Return a copy of positions for AnalysisFromFunction. 
+ + Args: + ag: MDAnalysis AtomGroup. + + Returns: + Copy of ag.positions. + """ + return ag.positions.copy() + + def _forces_copy(self, ag): + """Return a copy of forces for AnalysisFromFunction. + + Args: + ag: MDAnalysis AtomGroup. + + Returns: + Copy of ag.forces. + """ + return ag.forces.copy() + + def _dimensions_copy(self, ag): + """Return a copy of box dimensions for AnalysisFromFunction. + + Args: + ag: MDAnalysis AtomGroup. + + Returns: + Copy of ag.dimensions. + """ + return ag.dimensions.copy() + + def _extract_force_timeseries_with_fallback( + self, + atomgroup_force, + *, + fallback_to_positions_if_no_forces: bool, + ): + """Extract force timeseries, optionally falling back to positions. + + This isolates the behaviour that changed your runtime outcome: older code + used positions from the force trajectory, which never triggered `NoDataError`. + This method keeps that behaviour available for backwards compatibility. + + Args: + atomgroup_force: MDAnalysis AtomGroup sourced from the force trajectory. + fallback_to_positions_if_no_forces: If True, fall back to extracting + positions when forces are unavailable; otherwise re-raise NoDataError. + + Returns: + A time series array of shape (n_frames, n_atoms, 3). The returned array + contains forces when available, otherwise positions if fallback is enabled. + + Raises: + NoDataError: If forces are unavailable and + fallback_to_positions_if_no_forces is False. 
+ """ + try: + return self._extract_timeseries(atomgroup_force, kind="forces") + except NoDataError: + if not fallback_to_positions_if_no_forces: + raise + return self._extract_timeseries(atomgroup_force, kind="positions") diff --git a/CodeEntropy/levels/nodes/__init__.py b/CodeEntropy/levels/nodes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CodeEntropy/levels/nodes/accumulators.py b/CodeEntropy/levels/nodes/accumulators.py new file mode 100644 index 00000000..8c2cb64c --- /dev/null +++ b/CodeEntropy/levels/nodes/accumulators.py @@ -0,0 +1,193 @@ +"""Initialize covariance accumulators. + +This module defines a LevelDAG static node that allocates all per-frame reduction +accumulators (means) and counters used by downstream frame processing. + +The node owns only initialization concerns (single responsibility): +- create group-id <-> index mappings +- allocate force/torque covariance mean containers +- allocate optional combined force-torque (FT) mean containers +- allocate per-level frame counters + +The structure created here is treated as the canonical storage layout for the +rest of the pipeline. 
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, MutableMapping + +import numpy as np + +logger = logging.getLogger(__name__) + +SharedData = MutableMapping[str, Any] + + +@dataclass(frozen=True) +class GroupIndex: + """Bidirectional mapping between group ids and contiguous indices.""" + + group_id_to_index: Dict[int, int] + index_to_group_id: List[int] + + +@dataclass(frozen=True) +class CovarianceAccumulators: + """Container for covariance mean accumulators and frame counters.""" + + force_covariances: Dict[str, Any] + torque_covariances: Dict[str, Any] + frame_counts: Dict[str, Any] + forcetorque_covariances: Dict[str, Any] + forcetorque_counts: Dict[str, Any] + + +class InitCovarianceAccumulatorsNode: + """Allocate accumulators and counters for per-frame reductions. + + Produces the following keys in `shared_data`: + + Canonical mean accumulators: + - force_covariances: {"ua": dict, "res": list, "poly": list} + - torque_covariances: {"ua": dict, "res": list, "poly": list} + - forcetorque_covariances: {"res": list, "poly": list} (6N x 6N means) + + Counters: + - frame_counts: {"ua": dict, "res": np.ndarray[int], "poly": np.ndarray[int]} + - forcetorque_counts: {"res": np.ndarray[int], "poly": np.ndarray[int]} + + Group index mapping: + - group_id_to_index: {group_id: index} + - index_to_group_id: [group_id_by_index] + + Backwards-compatible aliases (kept for older consumers): + - force_torque_stats -> forcetorque_covariances + - force_torque_counts -> forcetorque_counts + """ + + def run(self, shared_data: Dict[str, Any]) -> Dict[str, Any]: + """Initialize and attach all accumulator structures into shared_data. + + Args: + shared_data: Shared pipeline dictionary. Must contain "groups". + + Returns: + A dict of keys written into shared_data. + + Raises: + KeyError: If "groups" is missing from shared_data. 
+ """ + groups = shared_data["groups"] + group_index = self._build_group_index(groups) + + accumulators = self._build_accumulators( + n_groups=len(group_index.index_to_group_id) + ) + + self._attach_to_shared_data(shared_data, group_index, accumulators) + self._attach_backwards_compatible_aliases(shared_data) + + return self._build_return_payload(shared_data) + + @staticmethod + def _build_group_index(groups: Dict[int, Any]) -> GroupIndex: + """Build group id <-> index mappings. + + Args: + groups: Mapping of group id to group members. + + Returns: + GroupIndex mapping object. + """ + group_ids = list(groups.keys()) + gid2i = {gid: i for i, gid in enumerate(group_ids)} + return GroupIndex(group_id_to_index=gid2i, index_to_group_id=list(group_ids)) + + @staticmethod + def _build_accumulators(n_groups: int) -> CovarianceAccumulators: + """Allocate empty covariance means and counters. + + Args: + n_groups: Number of molecule groups. + + Returns: + CovarianceAccumulators containing allocated containers. + """ + force_cov = {"ua": {}, "res": [None] * n_groups, "poly": [None] * n_groups} + torque_cov = {"ua": {}, "res": [None] * n_groups, "poly": [None] * n_groups} + + frame_counts = { + "ua": {}, + "res": np.zeros(n_groups, dtype=int), + "poly": np.zeros(n_groups, dtype=int), + } + + forcetorque_cov = {"res": [None] * n_groups, "poly": [None] * n_groups} + forcetorque_counts = { + "res": np.zeros(n_groups, dtype=int), + "poly": np.zeros(n_groups, dtype=int), + } + + return CovarianceAccumulators( + force_covariances=force_cov, + torque_covariances=torque_cov, + frame_counts=frame_counts, + forcetorque_covariances=forcetorque_cov, + forcetorque_counts=forcetorque_counts, + ) + + @staticmethod + def _attach_to_shared_data( + shared_data: SharedData, group_index: GroupIndex, acc: CovarianceAccumulators + ) -> None: + """Attach canonical structures to shared_data. + + Args: + shared_data: Shared pipeline dictionary. + group_index: GroupIndex object. 
+ acc: CovarianceAccumulators object. + """ + shared_data["group_id_to_index"] = group_index.group_id_to_index + shared_data["index_to_group_id"] = group_index.index_to_group_id + + shared_data["force_covariances"] = acc.force_covariances + shared_data["torque_covariances"] = acc.torque_covariances + shared_data["frame_counts"] = acc.frame_counts + + shared_data["forcetorque_covariances"] = acc.forcetorque_covariances + shared_data["forcetorque_counts"] = acc.forcetorque_counts + + @staticmethod + def _attach_backwards_compatible_aliases(shared_data: SharedData) -> None: + """Attach backwards-compatible aliases. + + Args: + shared_data: Shared pipeline dictionary. + """ + shared_data["force_torque_stats"] = shared_data["forcetorque_covariances"] + shared_data["force_torque_counts"] = shared_data["forcetorque_counts"] + + @staticmethod + def _build_return_payload(shared_data: SharedData) -> Dict[str, Any]: + """Build the return payload containing initialized keys. + + Args: + shared_data: Shared pipeline dictionary. + + Returns: + Dict of keys to values that were set in shared_data. + """ + return { + "group_id_to_index": shared_data["group_id_to_index"], + "index_to_group_id": shared_data["index_to_group_id"], + "force_covariances": shared_data["force_covariances"], + "torque_covariances": shared_data["torque_covariances"], + "frame_counts": shared_data["frame_counts"], + "forcetorque_covariances": shared_data["forcetorque_covariances"], + "forcetorque_counts": shared_data["forcetorque_counts"], + "force_torque_stats": shared_data["force_torque_stats"], + "force_torque_counts": shared_data["force_torque_counts"], + } diff --git a/CodeEntropy/levels/nodes/beads.py b/CodeEntropy/levels/nodes/beads.py new file mode 100644 index 00000000..abd5b122 --- /dev/null +++ b/CodeEntropy/levels/nodes/beads.py @@ -0,0 +1,231 @@ +"""Build bead (AtomGroup index) definitions for each molecule and hierarchy level. 
+ +This module defines the `BuildBeadsNode`, a static DAG node that constructs bead +definitions once, in reduced-universe index space. These bead definitions are +used by later frame-level nodes (e.g., covariance construction) without needing +to re-run selection logic every frame. + +Beads are stored as arrays of atom indices (in the reduced universe). +""" + +from __future__ import annotations + +import logging +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, DefaultDict, Dict, List, MutableMapping, Tuple + +import numpy as np + +from CodeEntropy.levels.hierarchy import HierarchyBuilder + +logger = logging.getLogger(__name__) + +BeadKey = Tuple[int, str] | Tuple[int, str, int] +BeadsMap = Dict[BeadKey, List[np.ndarray]] + + +@dataclass(frozen=True) +class UnitedAtomBead: + """A united-atom bead associated with a residue bucket. + + Attributes: + residue_id: Local residue index within the molecule (0..n_residues-1). + atom_indices: Atom indices (reduced-universe index space) belonging to the bead. + """ + + residue_id: int + atom_indices: np.ndarray + + +class BuildBeadsNode: + """Build bead definitions once, in reduced-universe index space. + + Output contract: + Writes `shared_data["beads"]` with keys: + - (mol_id, "united_atom", res_id) -> list[np.ndarray] + - (mol_id, "residue") -> list[np.ndarray] + - (mol_id, "polymer") -> list[np.ndarray] + + Notes: + United-atom beads are generated at the molecule level (preserving the + underlying ordering provided by `HierarchyBuilder.get_beads`) and then + grouped into residue buckets based on the heavy atom that defines the bead. + """ + + def __init__(self, hierarchy: HierarchyBuilder | None = None) -> None: + """Initialize the node. + + Args: + hierarchy: Optional `HierarchyBuilder` dependency. If not provided, + a default instance is created. 
+ """ + self._hier = hierarchy or HierarchyBuilder() + + def run(self, shared_data: Dict[str, Any]) -> Dict[str, Any]: + """Build bead definitions for all molecules and levels. + + Args: + shared_data: Shared data dictionary. Requires: + - "reduced_universe": MDAnalysis.Universe + - "levels": list[list[str]] + + Returns: + Dict containing the "beads" mapping (also written into shared_data). + + Raises: + KeyError: If required keys are missing from `shared_data`. + """ + u = shared_data["reduced_universe"] + levels: List[List[str]] = shared_data["levels"] + + beads: BeadsMap = {} + fragments = u.atoms.fragments + + for mol_id, level_list in enumerate(levels): + mol = fragments[mol_id] + + if "united_atom" in level_list: + self._add_united_atom_beads(beads=beads, mol_id=mol_id, mol=mol) + + if "residue" in level_list: + self._add_residue_beads(beads=beads, mol_id=mol_id, mol=mol) + + if "polymer" in level_list: + self._add_polymer_beads(beads=beads, mol_id=mol_id, mol=mol) + + shared_data["beads"] = beads + return {"beads": beads} + + def _add_united_atom_beads( + self, beads: MutableMapping[BeadKey, List[np.ndarray]], mol_id: int, mol + ) -> None: + """Compute and store united-atom beads grouped into residue buckets. + + Args: + beads: Output bead mapping mutated in-place. + mol_id: Molecule (fragment) index. + mol: MDAnalysis AtomGroup representing the molecule. 
+ """ + ua_beads = self._hier.get_beads(mol, "united_atom") + + buckets: DefaultDict[int, List[np.ndarray]] = defaultdict(list) + for bead_i, bead in enumerate(ua_beads): + atom_indices = self._validate_bead_indices( + bead, mol_id=mol_id, level="united_atom", bead_i=bead_i + ) + if atom_indices is None: + continue + + residue_id = self._infer_local_residue_id(mol=mol, bead=bead) + buckets[residue_id].append(atom_indices) + + for local_res_id in range(len(mol.residues)): + beads[(mol_id, "united_atom", local_res_id)] = buckets.get(local_res_id, []) + + def _add_residue_beads( + self, beads: MutableMapping[BeadKey, List[np.ndarray]], mol_id: int, mol + ) -> None: + """Compute and store residue beads. + + Args: + beads: Output bead mapping mutated in-place. + mol_id: Molecule (fragment) index. + mol: MDAnalysis AtomGroup representing the molecule. + """ + res_beads = self._hier.get_beads(mol, "residue") + kept: List[np.ndarray] = [] + + for bead_i, bead in enumerate(res_beads): + atom_indices = self._validate_bead_indices( + bead, mol_id=mol_id, level="residue", bead_i=bead_i + ) + if atom_indices is None: + continue + kept.append(atom_indices) + + beads[(mol_id, "residue")] = kept + + if len(kept) == 0: + logger.error( + "[BuildBeadsNode] No residue beads kept for mol=%s. Residue-level " + "entropy may be 0.0.", + mol_id, + ) + + def _add_polymer_beads( + self, beads: MutableMapping[BeadKey, List[np.ndarray]], mol_id: int, mol + ) -> None: + """Compute and store polymer beads. + + Args: + beads: Output bead mapping mutated in-place. + mol_id: Molecule (fragment) index. + mol: MDAnalysis AtomGroup representing the molecule. 
+ """ + poly_beads = self._hier.get_beads(mol, "polymer") + kept: List[np.ndarray] = [] + + for bead_i, bead in enumerate(poly_beads): + atom_indices = self._validate_bead_indices( + bead, mol_id=mol_id, level="polymer", bead_i=bead_i + ) + if atom_indices is None: + continue + kept.append(atom_indices) + + beads[(mol_id, "polymer")] = kept + + @staticmethod + def _validate_bead_indices( + bead, mol_id: int, level: str, bead_i: int + ) -> np.ndarray | None: + """Return a bead's atom indices, or None if the bead is empty. + + Args: + bead: MDAnalysis AtomGroup representing the bead. + mol_id: Molecule id used only for logging context. + level: Level name used only for logging context. + bead_i: Bead index used only for logging context. + + Returns: + A copy of the bead indices as a NumPy array, or None if the bead is empty. + """ + if len(bead) == 0: + logger.warning( + "[BuildBeadsNode] Empty bead skipped: mol=%s level=%s bead_i=%s", + mol_id, + level, + bead_i, + ) + return None + return bead.indices.copy() + + @staticmethod + def _infer_local_residue_id(mol, bead) -> int: + """Infer the local residue bucket for a united-atom bead. + + Strategy: + - Select heavy atoms in the bead (mass > 1.1). + - Use the first heavy atom's `resindex` (universe-level). + - Map that universe-level `resindex` back to the molecule's local residue + index by scanning `mol.residues`. + + Args: + mol: Molecule AtomGroup. + bead: United-atom bead AtomGroup. + + Returns: + Local residue index in [0, len(mol.residues) - 1]. Falls back to 0 if + mapping cannot be determined. 
+ """ + heavy = bead.select_atoms("prop mass > 1.1") + if len(heavy) == 0: + return 0 + + target_resindex = int(heavy[0].resindex) + for local_i, res in enumerate(mol.residues): + if int(res.resindex) == target_resindex: + return local_i + + return 0 diff --git a/CodeEntropy/levels/nodes/conformations.py b/CodeEntropy/levels/nodes/conformations.py new file mode 100644 index 00000000..90a088fb --- /dev/null +++ b/CodeEntropy/levels/nodes/conformations.py @@ -0,0 +1,116 @@ +"""Compute conformational states for configurational entropy calculations. + +This module defines a static DAG node that scans the trajectory and builds +conformational state descriptors (united-atom and residue level). The resulting +states are stored in `shared_data` for later use by configurational entropy +calculations. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict + +from CodeEntropy.levels.dihedrals import ConformationStateBuilder + +SharedData = Dict[str, Any] +ConformationalStates = Dict[str, Any] + + +@dataclass(frozen=True) +class ConformationalStateConfig: + """Configuration for conformational state construction. + + Attributes: + start: Start frame index (inclusive). + end: End frame index (exclusive). + step: Frame stride. + bin_width: Histogram bin width in degrees. + """ + + start: int + end: int + step: int + bin_width: int + + +class ComputeConformationalStatesNode: + """Static node that computes conformational states from trajectory dihedrals. + + Produces: + shared_data["conformational_states"] = {"ua": states_ua, "res": states_res} + + Where: + - states_ua is a dict keyed by (group_id, local_residue_id) + - states_res is a list-like structure indexed by group_id (or equivalent) + """ + + def __init__(self, universe_operations: Any) -> None: + """Initialize the node. + + Args: + universe_operations: Object providing universe selection utilities used + by `ConformationStateBuilder`. 
+ """ + self._dihedral_analysis = ConformationStateBuilder( + universe_operations=universe_operations + ) + + def run( + self, shared_data: SharedData, *, progress: object | None = None + ) -> Dict[str, ConformationalStates]: + """Compute conformational states and store them in shared_data. + + Args: + shared_data: Shared data dictionary. Requires: + - "reduced_universe" + - "levels" + - "groups" + - "start", "end", "step" + - "args" with attribute "bin_width" + progress: Optional progress sink provided by ResultsReporter.progress(). + + Returns: + Dict containing "conformational_states" (also written into shared_data). + """ + u = shared_data["reduced_universe"] + levels = shared_data["levels"] + groups = shared_data["groups"] + + cfg = self._build_config(shared_data) + + states_ua, states_res = self._dihedral_analysis.build_conformational_states( + data_container=u, + levels=levels, + groups=groups, + start=cfg.start, + end=cfg.end, + step=cfg.step, + bin_width=cfg.bin_width, + progress=progress, + ) + + conformational_states: ConformationalStates = { + "ua": states_ua, + "res": states_res, + } + shared_data["conformational_states"] = conformational_states + return {"conformational_states": conformational_states} + + @staticmethod + def _build_config(shared_data: SharedData) -> ConformationalStateConfig: + """Extract and validate configuration from shared_data. + + Args: + shared_data: Shared data dictionary. + + Returns: + ConformationalStateConfig with normalized integer fields. 
+ """ + start = int(shared_data["start"]) + end = int(shared_data["end"]) + step = int(shared_data["step"]) + bin_width = int(shared_data["args"].bin_width) + return ConformationalStateConfig( + start=start, end=end, step=step, bin_width=bin_width + ) diff --git a/CodeEntropy/levels/nodes/covariance.py b/CodeEntropy/levels/nodes/covariance.py new file mode 100644 index 00000000..1085aad4 --- /dev/null +++ b/CodeEntropy/levels/nodes/covariance.py @@ -0,0 +1,699 @@ +"""Frame-level covariance (second-moment) construction. + +This module computes per-frame second-moment matrices for force and torque +vectors at each hierarchy level (united_atom, residue, polymer). Results are +incrementally averaged across molecules within a group for the current frame. + +Responsibilities: +- Build bead-level force/torque vectors using ForceTorqueCalculator. +- Construct per-frame force/torque second moments (outer products). +- Optionally construct combined force-torque block matrices. +- Average per-frame matrices across molecules in the same group. + +Not responsible for: +- Defining groups/levels/beads mapping (provided via shared context). +- Axis construction policy (delegated to axes_manager). +- Accumulating across frames (handled by the higher-level reducer). +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +from MDAnalysis.lib.mdamath import make_whole + +from CodeEntropy.levels.forces import ForceTorqueCalculator + +logger = logging.getLogger(__name__) + +FrameCtx = Dict[str, Any] +Matrix = np.ndarray + + +class FrameCovarianceNode: + """Build per-frame covariance-like (second-moment) matrices for each group. 
+ + This node computes per-frame second-moment matrices (outer products) for + force and torque generalized vectors at hierarchy levels: + - united_atom + - residue + - polymer + + Within a single frame, outputs are incrementally averaged across molecules + that belong to the same group. Frame-to-frame accumulation is handled + elsewhere (by a higher-level reducer). + """ + + def __init__(self) -> None: + """Initialise the frame covariance node.""" + self._ft = ForceTorqueCalculator() + + def run(self, ctx: FrameCtx) -> Dict[str, Any]: + """Compute and store per-frame force/torque (and optional FT) matrices. + + Args: + ctx: Frame context dict expected to include: + - "shared": dict containing reduced_universe, groups, levels, beads, + args + - shared["axes_manager"] (created in static stage) + + Returns: + The frame covariance payload also stored at ctx["frame_covariance"]. + + Raises: + KeyError: If ctx is missing required fields. + """ + shared = self._get_shared(ctx) + + u = shared["reduced_universe"] + groups = shared["groups"] + levels = shared["levels"] + beads = shared["beads"] + args = shared["args"] + axes_manager = shared.get("axes_manager") + + fp = float(args.force_partitioning) + combined = bool(getattr(args, "combined_forcetorque", False)) + customised_axes = bool(getattr(args, "customised_axes", False)) + + box = self._try_get_box(u) + fragments = u.atoms.fragments + + out_force: Dict[str, Dict[Any, Matrix]] = {"ua": {}, "res": {}, "poly": {}} + out_torque: Dict[str, Dict[Any, Matrix]] = {"ua": {}, "res": {}, "poly": {}} + out_ft: Optional[Dict[str, Dict[Any, Matrix]]] = ( + {"ua": {}, "res": {}, "poly": {}} if combined else None + ) + + ua_molcount: Dict[Tuple[int, int], int] = {} + res_molcount: Dict[int, int] = {} + poly_molcount: Dict[int, int] = {} + + for group_id, mol_ids in groups.items(): + for mol_id in mol_ids: + mol = fragments[mol_id] + level_list = levels[mol_id] + + if "united_atom" in level_list: + self._process_united_atom( + u=u, + 
mol=mol, + mol_id=mol_id, + group_id=group_id, + beads=beads, + axes_manager=axes_manager, + box=box, + force_partitioning=fp, + customised_axes=customised_axes, + is_highest=("united_atom" == level_list[-1]), + out_force=out_force, + out_torque=out_torque, + molcount=ua_molcount, + ) + + if "residue" in level_list: + self._process_residue( + u=u, + mol=mol, + mol_id=mol_id, + group_id=group_id, + beads=beads, + axes_manager=axes_manager, + box=box, + customised_axes=customised_axes, + force_partitioning=fp, + is_highest=("residue" == level_list[-1]), + out_force=out_force, + out_torque=out_torque, + out_ft=out_ft, + molcount=res_molcount, + combined=combined, + ) + + if "polymer" in level_list: + self._process_polymer( + u=u, + mol=mol, + mol_id=mol_id, + group_id=group_id, + beads=beads, + axes_manager=axes_manager, + box=box, + force_partitioning=fp, + is_highest=("polymer" == level_list[-1]), + out_force=out_force, + out_torque=out_torque, + out_ft=out_ft, + molcount=poly_molcount, + combined=combined, + ) + + frame_cov: Dict[str, Any] = {"force": out_force, "torque": out_torque} + if combined and out_ft is not None: + frame_cov["forcetorque"] = out_ft + + ctx["frame_covariance"] = frame_cov + return frame_cov + + def _process_united_atom( + self, + *, + u: Any, + mol: Any, + mol_id: int, + group_id: int, + beads: Dict[Any, List[Any]], + axes_manager: Any, + box: Optional[np.ndarray], + force_partitioning: float, + customised_axes: bool, + is_highest: bool, + out_force: Dict[str, Dict[Any, Matrix]], + out_torque: Dict[str, Dict[Any, Matrix]], + molcount: Dict[Tuple[int, int], int], + ) -> None: + """Compute UA-level force/torque second moments for one molecule. + + For each residue in the molecule, bead vectors are computed for all UA + beads in that residue. The resulting second-moment matrices are then + incrementally averaged across molecules in the same group for this frame. + + Args: + u: MDAnalysis Universe (or compatible) providing atom access. 
+ mol: Molecule/fragment object providing residues/atoms. + mol_id: Molecule id used for bead keying. + group_id: Group identifier used for within-frame averaging. + beads: Mapping from bead keys to lists of atom indices. + axes_manager: Axes manager used to determine axes/centers/MOI. + box: Optional box vector used for PBC-aware displacements. + force_partitioning: Force scaling factor applied at highest level. + customised_axes: Whether to use customised axes methods when available. + is_highest: Whether the UA level is the highest level for the molecule. + out_force: Output accumulator for UA force second moments. + out_torque: Output accumulator for UA torque second moments. + molcount: Per-(group_id, local_res_i) molecule counters for averaging. + + Returns: + None. Mutates out_force/out_torque and molcount in-place. + """ + for local_res_i, res in enumerate(mol.residues): + bead_key = (mol_id, "united_atom", local_res_i) + bead_idx_list = beads.get(bead_key, []) + if not bead_idx_list: + continue + + bead_groups = [u.atoms[idx] for idx in bead_idx_list] + if any(len(bg) == 0 for bg in bead_groups): + continue + + force_vecs, torque_vecs = self._build_ua_vectors( + residue_atoms=res.atoms, + bead_groups=bead_groups, + axes_manager=axes_manager, + box=box, + force_partitioning=force_partitioning, + customised_axes=customised_axes, + is_highest=is_highest, + ) + + F, T = self._ft.compute_frame_covariance(force_vecs, torque_vecs) + + key = (group_id, local_res_i) + n = molcount.get(key, 0) + 1 + out_force["ua"][key] = self._inc_mean(out_force["ua"].get(key), F, n) + out_torque["ua"][key] = self._inc_mean(out_torque["ua"].get(key), T, n) + molcount[key] = n + + def _process_residue( + self, + *, + u: Any, + mol: Any, + mol_id: int, + group_id: int, + beads: Dict[Any, List[Any]], + axes_manager: Any, + box: Optional[np.ndarray], + customised_axes: bool, + force_partitioning: float, + is_highest: bool, + out_force: Dict[str, Dict[Any, Matrix]], + out_torque: 
Dict[str, Dict[Any, Matrix]], + out_ft: Optional[Dict[str, Dict[Any, Matrix]]], + molcount: Dict[int, int], + combined: bool, + ) -> None: + """Compute residue-level force/torque (and optional FT) moments for one + molecule. + + Residue bead vectors are constructed for the molecule and used to compute + per-frame force and torque second-moment matrices. Outputs are then + incrementally averaged across molecules in the same group for this frame. + If combined FT matrices are enabled and this is the highest level, a + force-torque block matrix is also constructed and averaged. + + Args: + u: MDAnalysis Universe (or compatible) providing atom access. + mol: Molecule/fragment object providing atoms/residues. + mol_id: Molecule id used for bead keying. + group_id: Group identifier used for within-frame averaging. + beads: Mapping from bead keys to lists of atom indices. + axes_manager: Axes manager used to determine axes/centers/MOI. + box: Optional box vector used for PBC-aware displacements. + customised_axes: Whether to use customised axes methods when available. + force_partitioning: Force scaling factor applied at highest level. + is_highest: Whether residue level is the highest level for the molecule. + out_force: Output accumulator for residue force second moments. + out_torque: Output accumulator for residue torque second moments. + out_ft: Optional output accumulator for residue combined FT matrices. + molcount: Per-group molecule counter for within-frame averaging. + combined: Whether combined force-torque matrices are enabled. + + Returns: + None. Mutates output dictionaries and molcount in-place. 
+ """ + bead_key = (mol_id, "residue") + bead_idx_list = beads.get(bead_key, []) + if not bead_idx_list: + return + + bead_groups = [u.atoms[idx] for idx in bead_idx_list] + if any(len(bg) == 0 for bg in bead_groups): + return + + force_vecs, torque_vecs = self._build_residue_vectors( + mol=mol, + bead_groups=bead_groups, + axes_manager=axes_manager, + box=box, + customised_axes=customised_axes, + force_partitioning=force_partitioning, + is_highest=is_highest, + ) + + F, T = self._ft.compute_frame_covariance(force_vecs, torque_vecs) + + n = molcount.get(group_id, 0) + 1 + out_force["res"][group_id] = self._inc_mean( + out_force["res"].get(group_id), F, n + ) + out_torque["res"][group_id] = self._inc_mean( + out_torque["res"].get(group_id), T, n + ) + molcount[group_id] = n + + if combined and is_highest and out_ft is not None: + M = self._build_ft_block(force_vecs, torque_vecs) + out_ft["res"][group_id] = self._inc_mean(out_ft["res"].get(group_id), M, n) + + def _process_polymer( + self, + *, + u: Any, + mol: Any, + mol_id: int, + group_id: int, + beads: Dict[Any, List[Any]], + axes_manager: Any, + box: Optional[np.ndarray], + force_partitioning: float, + is_highest: bool, + out_force: Dict[str, Dict[Any, Matrix]], + out_torque: Dict[str, Dict[Any, Matrix]], + out_ft: Optional[Dict[str, Dict[Any, Matrix]]], + molcount: Dict[int, int], + combined: bool, + ) -> None: + """Compute polymer-level force/torque (and optional FT) moments for one + molecule. + + Polymer level uses a single bead. Translation/rotation axes, center, and + principal moments of inertia are computed, then used to build the + generalized force and torque vectors. Outputs are incrementally averaged + across molecules in the same group for this frame. If combined FT matrices + are enabled and this is the highest level, a force-torque block matrix is + also constructed and averaged. + + Args: + u: MDAnalysis Universe (or compatible) providing atom access. 
+ mol: Molecule/fragment object providing atoms. + mol_id: Molecule id used for bead keying. + group_id: Group identifier used for within-frame averaging. + beads: Mapping from bead keys to lists of atom indices. + axes_manager: Axes manager used to determine axes/centers/MOI. + box: Optional box vector used for PBC-aware displacements. + force_partitioning: Force scaling factor applied at highest level. + is_highest: Whether polymer level is the highest level for the molecule. + out_force: Output accumulator for polymer force second moments. + out_torque: Output accumulator for polymer torque second moments. + out_ft: Optional output accumulator for polymer combined FT matrices. + molcount: Per-group molecule counter for within-frame averaging. + combined: Whether combined force-torque matrices are enabled. + + Returns: + None. Mutates output dictionaries and molcount in-place. + """ + bead_key = (mol_id, "polymer") + bead_idx_list = beads.get(bead_key, []) + if not bead_idx_list: + return + + bead_groups = [u.atoms[idx] for idx in bead_idx_list] + if any(len(bg) == 0 for bg in bead_groups): + return + + bead = bead_groups[0] + + trans_axes, rot_axes, center, moi = self._get_polymer_axes( + mol=mol, bead=bead, axes_manager=axes_manager + ) + + force_vecs = [ + self._ft.get_weighted_forces( + bead=bead, + trans_axes=np.asarray(trans_axes), + highest_level=is_highest, + force_partitioning=force_partitioning, + ) + ] + torque_vecs = [ + self._ft.get_weighted_torques( + bead=bead, + rot_axes=np.asarray(rot_axes), + center=np.asarray(center), + force_partitioning=force_partitioning, + moment_of_inertia=np.asarray(moi), + axes_manager=axes_manager, + box=box, + ) + ] + + F, T = self._ft.compute_frame_covariance(force_vecs, torque_vecs) + + n = molcount.get(group_id, 0) + 1 + out_force["poly"][group_id] = self._inc_mean( + out_force["poly"].get(group_id), F, n + ) + out_torque["poly"][group_id] = self._inc_mean( + out_torque["poly"].get(group_id), T, n + ) + 
molcount[group_id] = n + + if combined and is_highest and out_ft is not None: + M = self._build_ft_block(force_vecs, torque_vecs) + out_ft["poly"][group_id] = self._inc_mean( + out_ft["poly"].get(group_id), M, n + ) + + def _build_ua_vectors( + self, + *, + bead_groups: List[Any], + residue_atoms: Any, + axes_manager: Any, + box: Optional[np.ndarray], + force_partitioning: float, + customised_axes: bool, + is_highest: bool, + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """Build force/torque vectors for UA-level beads of one residue. + + Args: + bead_groups: List of UA bead AtomGroups for the residue. + residue_atoms: AtomGroup for the residue atoms (used for axes when vanilla). + axes_manager: Axes manager used to determine axes/centers/MOI. + box: Optional box vector used for PBC-aware displacements. + force_partitioning: Force scaling factor applied at highest level. + customised_axes: Whether to use customised axes methods when available. + is_highest: Whether UA level is the highest level for the molecule. + + Returns: + A tuple (force_vecs, torque_vecs), each a list of (3,) vectors ordered + by UA bead index within the residue. 
+ """ + force_vecs: List[np.ndarray] = [] + torque_vecs: List[np.ndarray] = [] + + for ua_i, bead in enumerate(bead_groups): + if customised_axes: + trans_axes, rot_axes, center, moi = axes_manager.get_UA_axes( + residue_atoms, ua_i + ) + else: + make_whole(residue_atoms) + make_whole(bead) + + trans_axes = residue_atoms.principal_axes() + rot_axes, moi = axes_manager.get_vanilla_axes(bead) + center = bead.center_of_mass(unwrap=True) + + force_vecs.append( + self._ft.get_weighted_forces( + bead=bead, + trans_axes=np.asarray(trans_axes), + highest_level=is_highest, + force_partitioning=force_partitioning, + ) + ) + torque_vecs.append( + self._ft.get_weighted_torques( + bead=bead, + rot_axes=np.asarray(rot_axes), + center=np.asarray(center), + force_partitioning=force_partitioning, + moment_of_inertia=np.asarray(moi), + axes_manager=axes_manager, + box=box, + ) + ) + + return force_vecs, torque_vecs + + def _build_residue_vectors( + self, + *, + mol: Any, + bead_groups: List[Any], + axes_manager: Any, + box: Optional[np.ndarray], + customised_axes: bool, + force_partitioning: float, + is_highest: bool, + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """Build force/torque vectors for residue-level beads of one molecule. + + Args: + mol: Molecule/fragment object providing residues/atoms. + bead_groups: List of residue bead AtomGroups for the molecule. + axes_manager: Axes manager used to determine axes/centers/MOI. + box: Optional box vector used for PBC-aware displacements. + customised_axes: Whether to use customised axes methods when available. + force_partitioning: Force scaling factor applied at highest level. + is_highest: Whether residue level is the highest level for the molecule. + + Returns: + A tuple (force_vecs, torque_vecs), each a list of (3,) vectors ordered + by residue index within the molecule. 
+ """ + force_vecs: List[np.ndarray] = [] + torque_vecs: List[np.ndarray] = [] + + for local_res_i, bead in enumerate(bead_groups): + trans_axes, rot_axes, center, moi = self._get_residue_axes( + mol=mol, + bead=bead, + local_res_i=local_res_i, + axes_manager=axes_manager, + customised_axes=customised_axes, + ) + + force_vecs.append( + self._ft.get_weighted_forces( + bead=bead, + trans_axes=np.asarray(trans_axes), + highest_level=is_highest, + force_partitioning=force_partitioning, + ) + ) + torque_vecs.append( + self._ft.get_weighted_torques( + bead=bead, + rot_axes=np.asarray(rot_axes), + center=np.asarray(center), + force_partitioning=force_partitioning, + moment_of_inertia=np.asarray(moi), + axes_manager=axes_manager, + box=box, + ) + ) + + return force_vecs, torque_vecs + + def _get_residue_axes( + self, + *, + mol: Any, + bead: Any, + local_res_i: int, + axes_manager: Any, + customised_axes: bool, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Get translation/rotation axes, center and MOI for a residue bead. + + Args: + mol: Molecule/fragment object providing residues/atoms. + bead: Residue bead AtomGroup. + local_res_i: Residue index within the molecule. + axes_manager: Axes manager used to determine axes/centers/MOI. + customised_axes: Whether to use customised axes methods when available. 
+ + Returns: + Tuple (trans_axes, rot_axes, center, moi) where: + - trans_axes: (3, 3) translation axes + - rot_axes: (3, 3) rotation axes + - center: (3,) center of mass + - moi: (3,) principal moments of inertia + """ + if customised_axes: + res = mol.residues[local_res_i] + return axes_manager.get_residue_axes(mol, local_res_i, residue=res.atoms) + + make_whole(mol.atoms) + make_whole(bead) + + trans_axes = mol.atoms.principal_axes() + rot_axes, moi = axes_manager.get_vanilla_axes(bead) + center = bead.center_of_mass(unwrap=True) + return ( + np.asarray(trans_axes), + np.asarray(rot_axes), + np.asarray(center), + np.asarray(moi), + ) + + def _get_polymer_axes( + self, + *, + mol: Any, + bead: Any, + axes_manager: Any, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Get translation/rotation axes, center and MOI for a polymer bead. + + Args: + mol: Molecule/fragment object providing atoms. + bead: Polymer bead AtomGroup. + axes_manager: Axes manager used to determine axes/centers/MOI. + + Returns: + Tuple (trans_axes, rot_axes, center, moi) with shapes (3,3), (3,3), (3,), + and (3,) respectively. + """ + make_whole(mol.atoms) + make_whole(bead) + + trans_axes = mol.atoms.principal_axes() + rot_axes, moi = axes_manager.get_vanilla_axes(bead) + center = bead.center_of_mass(unwrap=True) + + return ( + np.asarray(trans_axes), + np.asarray(rot_axes), + np.asarray(center), + np.asarray(moi), + ) + + @staticmethod + def _get_shared(ctx: FrameCtx) -> Dict[str, Any]: + """Fetch shared context from a frame context dict. + + Args: + ctx: Frame context dictionary expected to contain a "shared" key. + + Returns: + The shared context dict stored at ctx["shared"]. + + Raises: + KeyError: If "shared" is not present in ctx. 
+ """ + if "shared" not in ctx: + raise KeyError("FrameCovarianceNode expects ctx['shared'].") + return ctx["shared"] + + @staticmethod + def _try_get_box(u: Any) -> Optional[np.ndarray]: + """Extract a (3,) box vector from an MDAnalysis universe when available. + + Args: + u: MDAnalysis Universe (or compatible) that may expose dimensions. + + Returns: + A numpy array of shape (3,) containing box lengths, or None if not + available. + """ + try: + return np.asarray(u.dimensions[:3], dtype=float) + except Exception: + return None + + @staticmethod + def _inc_mean(old: Optional[np.ndarray], new: np.ndarray, n: int) -> np.ndarray: + """Compute an incremental mean (streaming average). + + Args: + old: Previous running mean value, or None for the first sample. + new: New sample to incorporate. + n: 1-based sample count after adding the new sample. + + Returns: + Updated running mean. + """ + if old is None: + return new.copy() + return old + (new - old) / float(n) + + @staticmethod + def _build_ft_block( + force_vecs: List[np.ndarray], torque_vecs: List[np.ndarray] + ) -> np.ndarray: + """Build a combined force-torque block matrix for a frame. + + For each bead i, create a 6-vector [Fi, Ti]. The block matrix is built + from outer products of these 6-vectors. + + Args: + force_vecs: List of per-bead force vectors, each of shape (3,). + torque_vecs: List of per-bead torque vectors, each of shape (3,). + + Returns: + A block matrix of shape (6N, 6N) where N is the number of beads. + + Raises: + ValueError: If force_vecs and torque_vecs have different lengths, if no + bead vectors are provided, or if any input vector is not length 3. 
+ """ + if len(force_vecs) != len(torque_vecs): + raise ValueError("force_vecs and torque_vecs must have the same length.") + + n = len(force_vecs) + if n == 0: + raise ValueError("No bead vectors available to build an FT matrix.") + + bead_vecs: List[np.ndarray] = [] + for Fi, Ti in zip(force_vecs, torque_vecs, strict=True): + Fi = np.asarray(Fi, dtype=float).reshape(-1) + Ti = np.asarray(Ti, dtype=float).reshape(-1) + if Fi.size != 3 or Ti.size != 3: + raise ValueError("Each force/torque vector must be length 3.") + bead_vecs.append(np.concatenate([Fi, Ti], axis=0)) + + blocks: List[List[np.ndarray]] = [[None] * n for _ in range(n)] + for i in range(n): + for j in range(i, n): + sub = np.outer(bead_vecs[i], bead_vecs[j]) + blocks[i][j] = sub + blocks[j][i] = sub.T + + return np.block(blocks) diff --git a/CodeEntropy/levels/nodes/detect_levels.py b/CodeEntropy/levels/nodes/detect_levels.py new file mode 100644 index 00000000..6faecaaa --- /dev/null +++ b/CodeEntropy/levels/nodes/detect_levels.py @@ -0,0 +1,65 @@ +"""Detect hierarchy levels present for each molecule in the reduced universe. + +This module defines a static DAG node responsible for determining which +hierarchical levels (united_atom, residue, polymer) apply to each molecule. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Tuple + +from CodeEntropy.levels.hierarchy import HierarchyBuilder + +SharedData = Dict[str, Any] +Levels = List[List[str]] + + +class DetectLevelsNode: + """Static node that determines hierarchy levels per molecule. + + Produces: + shared_data["levels"] + shared_data["number_molecules"] + """ + + def __init__(self) -> None: + """Initialize the node with a HierarchyBuilder helper.""" + self._hierarchy = HierarchyBuilder() + + def run(self, shared_data: SharedData) -> Dict[str, Any]: + """Detect levels and store results in shared_data. + + Args: + shared_data: Shared data dictionary. 
Requires: + - "reduced_universe" + + Returns: + Dict containing: + - "levels": List of levels per molecule. + - "number_molecules": Total molecule count. + + Raises: + KeyError: If required keys are missing. + """ + universe = shared_data["reduced_universe"] + + number_molecules, levels = self._detect_levels(universe) + + shared_data["levels"] = levels + shared_data["number_molecules"] = number_molecules + + return { + "levels": levels, + "number_molecules": number_molecules, + } + + def _detect_levels(self, universe: Any) -> Tuple[int, Levels]: + """Delegate level detection to HierarchyBuilder. + + Args: + universe: Reduced MDAnalysis universe. + + Returns: + Tuple of molecule count and levels list. + """ + return self._hierarchy.select_levels(universe) diff --git a/CodeEntropy/levels/nodes/detect_molecules.py b/CodeEntropy/levels/nodes/detect_molecules.py new file mode 100644 index 00000000..bfc188d5 --- /dev/null +++ b/CodeEntropy/levels/nodes/detect_molecules.py @@ -0,0 +1,108 @@ +"""Detect molecules and build grouping definitions for the reduced universe. + +This module defines a static DAG node responsible for ensuring a reduced +universe is available and generating molecule groupings using the configured +grouping strategy. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict + +from CodeEntropy.molecules.grouping import MoleculeGrouper + +logger = logging.getLogger(__name__) + +SharedData = Dict[str, Any] + + +class DetectMoleculesNode: + """Static node that establishes molecule groups. + + Produces: + shared_data["reduced_universe"] + shared_data["groups"] + shared_data["number_molecules"] + """ + + def __init__(self) -> None: + """Initialize the node with a molecule grouping helper.""" + self._grouping = MoleculeGrouper() + + def run(self, shared_data: SharedData) -> Dict[str, Any]: + """Detect molecules and create grouping definitions. + + Args: + shared_data: Shared data dictionary. 
Requires: + - "universe" + - "args" + + Returns: + Dict containing: + - "groups": Molecule grouping dictionary. + - "number_molecules": Total molecule count. + + Raises: + KeyError: If required keys are missing. + """ + universe = self._ensure_reduced_universe(shared_data) + + grouping_strategy = self._get_grouping_strategy(shared_data) + + groups = self._grouping.grouping_molecules(universe, grouping_strategy) + number_molecules = self._count_molecules(universe) + + shared_data["groups"] = groups + shared_data["number_molecules"] = number_molecules + + return { + "groups": groups, + "number_molecules": number_molecules, + } + + def _ensure_reduced_universe(self, shared_data: SharedData) -> Any: + """Ensure reduced_universe exists in shared_data. + + Args: + shared_data: Shared data dictionary. + + Returns: + Reduced universe object. + + Raises: + KeyError: If no universe is available. + """ + universe = shared_data.get("reduced_universe") + + if universe is None: + universe = shared_data.get("universe") + if universe is None: + raise KeyError("shared_data must contain 'universe'") + shared_data["reduced_universe"] = universe + + return universe + + def _get_grouping_strategy(self, shared_data: SharedData) -> str: + """Extract grouping strategy from args. + + Args: + shared_data: Shared data dictionary. + + Returns: + Grouping strategy string. + """ + args = shared_data["args"] + return getattr(args, "grouping", "each") + + @staticmethod + def _count_molecules(universe: Any) -> int: + """Count molecules in the universe. + + Args: + universe: MDAnalysis universe. + + Returns: + Number of molecular fragments. 
+ """ + return len(universe.atoms.fragments) diff --git a/CodeEntropy/main.py b/CodeEntropy/main.py deleted file mode 100644 index 3bf24ee6..00000000 --- a/CodeEntropy/main.py +++ /dev/null @@ -1,28 +0,0 @@ -import logging -import sys - -from CodeEntropy.run import RunManager - -logger = logging.getLogger(__name__) - - -def main(): - """ - Main function for calculating the entropy of a system using the multiscale cell - correlation method. - """ - - # Setup initial services - folder = RunManager.create_job_folder() - - try: - run_manager = RunManager(folder=folder) - run_manager.run_entropy_workflow() - except Exception as e: - logger.critical(f"Fatal error during entropy calculation: {e}", exc_info=True) - sys.exit(1) - - -if __name__ == "__main__": - - main() # pragma: no cover diff --git a/CodeEntropy/mda_universe_operations.py b/CodeEntropy/mda_universe_operations.py deleted file mode 100644 index 306dfd3a..00000000 --- a/CodeEntropy/mda_universe_operations.py +++ /dev/null @@ -1,179 +0,0 @@ -import logging - -import MDAnalysis as mda -from MDAnalysis.analysis.base import AnalysisFromFunction -from MDAnalysis.coordinates.memory import MemoryReader - -logger = logging.getLogger(__name__) - - -class UniverseOperations: - """ - Functions to create and manipulate MDAnalysis Universe objects. - """ - - def __init__(self): - """ - Initialise class - """ - self._universe = None - - def new_U_select_frame(self, u, start=None, end=None, step=1): - """Create a reduced universe by dropping frames according to - user selection. - - Parameters - ---------- - u : MDAnalyse.Universe - A Universe object will all topology, dihedrals,coordinates and force - information - start : int or None, Optional, default: None - Frame id to start analysis. Default None will start from frame 0 - end : int or None, Optional, default: None - Frame id to end analysis. Default None will end at last frame - step : int, Optional, default: 1 - Steps between frame. 
- - Returns - ------- - u2 : MDAnalysis.Universe - reduced universe - """ - if start is None: - start = 0 - if end is None: - end = len(u.trajectory) - select_atom = u.select_atoms("all", updating=True) - coordinates = ( - AnalysisFromFunction(lambda ag: ag.positions.copy(), select_atom) - .run() - .results["timeseries"][start:end:step] - ) - forces = ( - AnalysisFromFunction(lambda ag: ag.forces.copy(), select_atom) - .run() - .results["timeseries"][start:end:step] - ) - dimensions = ( - AnalysisFromFunction(lambda ag: ag.dimensions.copy(), select_atom) - .run() - .results["timeseries"][start:end:step] - ) - u2 = mda.Merge(select_atom) - u2.load_new( - coordinates, format=MemoryReader, forces=forces, dimensions=dimensions - ) - logger.debug(f"MDAnalysis.Universe - reduced universe: {u2}") - - return u2 - - def new_U_select_atom(self, u, select_string="all"): - """Create a reduced universe by dropping atoms according to - user selection. - - Parameters - ---------- - u : MDAnalyse.Universe - A Universe object will all topology, dihedrals,coordinates and force - information - select_string : str, Optional, default: 'all' - MDAnalysis.select_atoms selection string. 
- - Returns - ------- - u2 : MDAnalysis.Universe - reduced universe - - """ - select_atom = u.select_atoms(select_string, updating=True) - coordinates = ( - AnalysisFromFunction(lambda ag: ag.positions.copy(), select_atom) - .run() - .results["timeseries"] - ) - forces = ( - AnalysisFromFunction(lambda ag: ag.forces.copy(), select_atom) - .run() - .results["timeseries"] - ) - dimensions = ( - AnalysisFromFunction(lambda ag: ag.dimensions.copy(), select_atom) - .run() - .results["timeseries"] - ) - u2 = mda.Merge(select_atom) - u2.load_new( - coordinates, format=MemoryReader, forces=forces, dimensions=dimensions - ) - logger.debug(f"MDAnalysis.Universe - reduced universe: {u2}") - - return u2 - - def get_molecule_container(self, universe, molecule_id): - """ - Extracts the atom group corresponding to a single molecule from the universe. - - Args: - universe (MDAnalysis.Universe): The reduced universe. - molecule_id (int): Index of the molecule to extract. - - Returns: - MDAnalysis.Universe: Universe containing only the selected molecule. - """ - # Identify the atoms in the molecule - frag = universe.atoms.fragments[molecule_id] - selection_string = f"index {frag.indices[0]}:{frag.indices[-1]}" - - return self.new_U_select_atom(universe, selection_string) - - def merge_forces(self, tprfile, trrfile, forcefile, fileformat=None, kcal=False): - """ - Creates a universe by merging the coordinates and forces from - different input files. 
- - Args: - tprfile : Topology input file - trrfile : Coordinate trajectory file - forcefile : Force trajectory file - format : Optional string for MDAnalysis identifying the file format - kcal : Optional Boolean for when the forces are in kcal not kJ - - Returns: - MDAnalysis Universe object - """ - - logger.debug(f"Loading Universe with {trrfile}") - u = mda.Universe(tprfile, trrfile, format=fileformat) - - logger.debug(f"Loading Universe with {forcefile}") - u_force = mda.Universe(tprfile, forcefile, format=fileformat) - - select_atom = u.select_atoms("all") - select_atom_force = u_force.select_atoms("all") - - coordinates = ( - AnalysisFromFunction(lambda ag: ag.positions.copy(), select_atom) - .run() - .results["timeseries"] - ) - forces = ( - AnalysisFromFunction(lambda ag: ag.positions.copy(), select_atom_force) - .run() - .results["timeseries"] - ) - - dimensions = ( - AnalysisFromFunction(lambda ag: ag.dimensions.copy(), select_atom) - .run() - .results["timeseries"] - ) - - if kcal: - # Convert from kcal to kJ - forces *= 4.184 - - logger.debug("Merging forces with coordinates universe.") - new_universe = mda.Merge(select_atom) - new_universe.load_new(coordinates, forces=forces, dimensions=dimensions) - - return new_universe diff --git a/CodeEntropy/molecules/__init__.py b/CodeEntropy/molecules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CodeEntropy/molecules/grouping.py b/CodeEntropy/molecules/grouping.py new file mode 100644 index 00000000..79f0b145 --- /dev/null +++ b/CodeEntropy/molecules/grouping.py @@ -0,0 +1,204 @@ +"""Utilities for grouping molecules for entropy analysis. + +This module provides strategies for grouping molecular fragments from an +MDAnalysis Universe into deterministic groups used for statistical averaging +during entropy calculations. + +Grouping strategies are designed to be stable and reproducible so that group +IDs remain consistent across runs given the same input system. 
+ +Available strategies: + - each: Every molecule is treated as its own group. + - molecules: Molecules are grouped by chemical signature + (atom count and atom names in order). +""" + +import logging +from dataclasses import dataclass +from typing import Callable, Dict, List, Mapping, Sequence, Tuple + +logger = logging.getLogger(__name__) + +GroupId = int +MoleculeId = int +MoleculeGroups = Dict[GroupId, List[MoleculeId]] +Signature = Tuple[int, Tuple[str, ...]] + + +@dataclass(frozen=True) +class GroupingConfig: + """Configuration for molecule grouping. + + Attributes: + strategy: Grouping strategy name. Supported values are: + - "each": each molecule gets its own group. + - "molecules": group molecules by chemical signature + (atom count + atom names in order). + """ + + strategy: str + + +class MoleculeGrouper: + """Build groups of molecules for averaging. + + This class provides strategies for grouping molecule fragments from an + MDAnalysis Universe. Groups are returned as a mapping: + + group_id -> [molecule_id, molecule_id, ...] + + Group IDs are deterministic and stable. + + Supported strategies: + - "each": Every molecule is its own group. + - "molecules": Molecules are grouped by chemical signature + (atom count, atom names in order). The group ID is the first molecule + index observed for that signature. + """ + + def grouping_molecules(self, universe, grouping: str) -> MoleculeGroups: + """Group molecules according to a selected strategy. + + Args: + universe: MDAnalysis Universe containing atoms and fragments. + grouping: Strategy name ("each" or "molecules"). + + Returns: + A dict mapping group IDs to molecule indices. + + Raises: + ValueError: If `grouping` is not a supported strategy. 
+ """ + config = GroupingConfig(strategy=grouping) + grouper = self._get_strategy(config.strategy) + groups = grouper(universe) + + self._log_summary(groups) + return groups + + def _get_strategy(self, strategy: str) -> Callable[[object], MoleculeGroups]: + """Resolve a strategy name to a grouping implementation. + + Args: + strategy: Strategy name. + + Returns: + Callable that accepts a Universe and returns molecule groups. + + Raises: + ValueError: If the strategy is unknown. + """ + strategies: Mapping[str, Callable[[object], MoleculeGroups]] = { + "each": self._group_each, + "molecules": self._group_by_signature, + } + + try: + return strategies[strategy] + except KeyError as exc: + raise ValueError(f"Unknown grouping strategy: {strategy!r}") from exc + + def _group_each(self, universe) -> MoleculeGroups: + """Create one group per molecule. + + Args: + universe: MDAnalysis Universe. + + Returns: + Dict where each molecule id maps to a singleton list [molecule id]. + """ + n_molecules = self._num_molecules(universe) + return {mol_id: [mol_id] for mol_id in range(n_molecules)} + + def _group_by_signature(self, universe) -> MoleculeGroups: + """Group molecules by chemical signature with stable group IDs. + + Signature is defined as: + (atom_count, atom_names_in_order) + + Group ID selection is stable and matches the previous behavior: + the first molecule index encountered for a signature is the group ID. + + Args: + universe: MDAnalysis Universe. + + Returns: + Dict mapping representative molecule id -> list of all molecule ids + sharing the same signature. 
+ """ + fragments = self._fragments(universe) + + signature_to_rep: Dict[Signature, MoleculeId] = {} + groups: MoleculeGroups = {} + + for mol_id, fragment in enumerate(fragments): + signature = self._signature(fragment) + rep_id = self._representative_id(signature_to_rep, signature, mol_id) + groups.setdefault(rep_id, []).append(mol_id) + + return groups + + def _num_molecules(self, universe) -> int: + """Return number of molecule fragments. + + Args: + universe: MDAnalysis Universe. + + Returns: + Number of fragments (molecules). + """ + return len(self._fragments(universe)) + + def _fragments(self, universe) -> Sequence[object]: + """Return universe fragments (molecules). + + Args: + universe: MDAnalysis Universe. + + Returns: + Sequence of fragments. + """ + return universe.atoms.fragments + + def _signature(self, fragment) -> Signature: + """Build a chemical signature for a fragment. + + Args: + fragment: MDAnalysis AtomGroup representing a fragment. + + Returns: + A tuple of (atom_count, atom_names_in_order). + """ + names = tuple(fragment.names) + return (len(names), names) + + def _representative_id( + self, + signature_to_rep: Dict[Signature, MoleculeId], + signature: Signature, + candidate_id: MoleculeId, + ) -> MoleculeId: + """Return stable representative id for a signature. + + Args: + signature_to_rep: Cache mapping signature -> representative id. + signature: Chemical signature of current molecule. + candidate_id: Current molecule id. + + Returns: + Representative id for this signature (first seen molecule id). + """ + rep_id = signature_to_rep.get(signature) + if rep_id is None: + signature_to_rep[signature] = candidate_id + return candidate_id + return rep_id + + def _log_summary(self, groups: MoleculeGroups) -> None: + """Log grouping summary. + + Args: + groups: Group mapping to summarize. 
+ """ + logger.debug("Number of molecule groups: %d", len(groups)) + logger.debug("Molecule groups: %s", groups) diff --git a/CodeEntropy/results/__init__.py b/CodeEntropy/results/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CodeEntropy/results/reporter.py b/CodeEntropy/results/reporter.py new file mode 100644 index 00000000..f9579f38 --- /dev/null +++ b/CodeEntropy/results/reporter.py @@ -0,0 +1,549 @@ +""" +Utilities for logging entropy results and exporting data. + +This module provides the ResultsReporter class, which is responsible for: + +- Collecting molecule-level entropy results +- Collecting residue-level entropy results +- Storing group metadata labels +- Rendering rich tables to the console +- Exporting results to JSON +""" + +from __future__ import annotations + +import json +import logging +import os +import platform +import re +import subprocess +import sys +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, +) +from rich.table import Table + +from CodeEntropy.core.logging import LoggingConfig + +logger = logging.getLogger(__name__) +console = LoggingConfig.get_console() + + +class _RichProgressSink: + """Thin wrapper around rich.Progress. + + Keeps Rich usage inside the reporting layer so compute/orchestration code + can emit progress without importing Rich. + """ + + def __init__(self, progress: Progress): + """Initialise a progress sink that delegates to a rich.Progress instance. + + Args: + progress: Rich Progress instance used to create/update/advance tasks. + """ + self._progress = progress + + def add_task(self, description: str, total: int, **fields): + """Add a progress task to the underlying rich.Progress instance. + + Args: + description: Task description shown by Rich. 
+ total: Total number of steps for the task. + **fields: Additional Rich task fields (e.g., title). + + Returns: + The task id returned by rich.Progress.add_task. + """ + fields.setdefault("title", "") + return self._progress.add_task(description, total=total, **fields) + + def advance(self, task_id, step: int = 1) -> None: + """Advance a progress task by a number of steps. + + Args: + task_id: Rich task identifier. + step: Number of steps to advance the task by. + """ + self._progress.advance(task_id, step) + + def update(self, task_id, **fields) -> None: + """Update fields for an existing progress task. + + Args: + task_id: Rich task identifier. + **fields: Task fields to update. If "title" is provided as None, it is + coerced to an empty string for compatibility with Rich rendering. + """ + if "title" in fields and fields["title"] is None: + fields["title"] = "" + self._progress.update(task_id, **fields) + + +class ResultsReporter: + """Collect, format, and output entropy calculation results. + + This reporter accumulates: + - Molecule-level results (group_id, level, entropy_type, value) + - Residue-level results (group_id, resname, level, entropy_type, frame_count, + value) + - Group metadata labels (label, residue_count, atom_count) + + It can render tables using Rich and export grouped results to JSON with basic + provenance metadata. + """ + + def __init__(self, console: Optional[Console] = None) -> None: + """Initialise a ResultsReporter. + + Args: + console: Optional Rich Console to use for rendering. If None, a default + Console instance is created. + """ + self.console: Console = console or Console() + self.molecule_data: List[Tuple[Any, Any, Any, Any]] = [] + self.residue_data: List[List[Any]] = [] + self.group_labels: Dict[Any, Dict[str, Any]] = {} + + @staticmethod + def clean_residue_name(resname: Any) -> str: + """Clean residue name by removing dash-like characters. + + Args: + resname: Residue name (any type, will be converted to str). 
+ + Returns: + Residue name with dash-like characters removed. + """ + return re.sub(r"[-–—]", "", str(resname)) + + @staticmethod + def _gid_sort_key(x: Any) -> Tuple[int, Any]: + """Stable sort key for group IDs. + + Group IDs may be numeric strings, ints, or other objects. + + Returns a tuple (rank, value): + - numeric IDs -> (0, int_value) + - non-numeric -> (1, str_value) + + Args: + x: Group identifier. + + Returns: + Tuple used as a stable sorting key. + """ + sx = str(x) + try: + return (0, int(sx)) + except Exception: + return (1, sx) + + @staticmethod + def _safe_float(value: Any) -> Optional[float]: + """Convert value to float if possible; otherwise return None. + + Args: + value: Value to convert. + + Returns: + Float representation of value, or None if conversion is not possible or + value is a boolean. + """ + try: + if isinstance(value, bool): + return None + return float(value) + except Exception: + return None + + def add_results_data( + self, group_id: Any, level: str, entropy_type: str, value: Any + ) -> None: + """Add molecule-level entropy result. + + Args: + group_id: Group identifier. + level: Hierarchy level label. + entropy_type: Entropy component/type label. + value: Result value to store (kept as-is). + """ + self.molecule_data.append((group_id, level, entropy_type, value)) + + def add_residue_data( + self, + group_id: Any, + resname: str, + level: str, + entropy_type: str, + frame_count: Any, + value: Any, + ) -> None: + """Add residue-level entropy result. + + Args: + group_id: Group identifier. + resname: Residue name (will be cleaned to remove dash-like characters). + level: Hierarchy level label. + entropy_type: Entropy component/type label. + frame_count: Number of frames contributing to the value (may be ndarray). + value: Result value to store (kept as-is). 
+ """ + resname = self.clean_residue_name(resname) + if isinstance(frame_count, np.ndarray): + frame_count = frame_count.tolist() + self.residue_data.append( + [group_id, resname, level, entropy_type, frame_count, value] + ) + + def add_group_label( + self, + group_id: Any, + label: str, + residue_count: Optional[int] = None, + atom_count: Optional[int] = None, + ) -> None: + """Store metadata label for a group. + + Args: + group_id: Group identifier. + label: Human-readable label for the group. + residue_count: Optional residue count for the group. + atom_count: Optional atom count for the group. + """ + self.group_labels[group_id] = { + "label": label, + "residue_count": residue_count, + "atom_count": atom_count, + } + + def log_tables(self) -> None: + """Render all collected data as Rich tables (grouped by group id).""" + self._log_grouped_results_tables() + self._log_residue_table_grouped() + self._log_group_label_table() + + def _log_grouped_results_tables(self) -> None: + """Print molecule-level results grouped by group_id with components + total. + + Results are grouped by group_id and rendered as separate tables per group. 
+ """ + if not self.molecule_data: + return + + grouped: Dict[Any, List[Tuple[Any, Any, Any, Any]]] = {} + for row in self.molecule_data: + gid = row[0] + grouped.setdefault(gid, []).append(row) + + for gid in sorted(grouped.keys(), key=self._gid_sort_key): + label = self.group_labels.get(gid, {}).get("label", "") + title = f"Entropy Results — Group {gid}" + (f" ({label})" if label else "") + + table = Table(title=title, show_lines=True, expand=True) + table.add_column("Level", justify="center", style="magenta") + table.add_column("Type", justify="center", style="green") + table.add_column("Result (J/mol/K)", justify="center", style="yellow") + + rows = grouped[gid] + non_total: List[Tuple[str, str, Any]] = [] + totals: List[Tuple[str, str, Any]] = [] + + for _gid, level, typ, val in rows: + level_s = str(level) + typ_s = str(typ) + is_total = level_s.lower().startswith( + "group total" + ) or typ_s.lower().startswith("group total") + if is_total: + totals.append((level_s, typ_s, val)) + else: + non_total.append((level_s, typ_s, val)) + + for level_s, typ_s, val in sorted(non_total, key=lambda r: (r[0], r[1])): + table.add_row(level_s, typ_s, str(val)) + + for level_s, typ_s, val in totals: + table.add_row(level_s, typ_s, str(val)) + + console.print(table) + + def _log_residue_table_grouped(self) -> None: + """Render residue entropy table grouped by group id.""" + if not self.residue_data: + return + + grouped: Dict[Any, List[List[Any]]] = {} + for row in self.residue_data: + gid = row[0] + grouped.setdefault(gid, []).append(row) + + for gid in sorted(grouped.keys(), key=self._gid_sort_key): + label = self.group_labels.get(gid, {}).get("label", "") + title = f"Residue Entropy — Group {gid}" + (f" ({label})" if label else "") + + table = Table(title=title, show_lines=True, expand=True) + table.add_column("Residue Name", justify="center", style="cyan") + table.add_column("Level", justify="center", style="magenta") + table.add_column("Type", justify="center", 
style="green") + table.add_column("Count", justify="center", style="green") + table.add_column("Result (J/mol/K)", justify="center", style="yellow") + + for _gid, resname, level, typ, count, val in grouped[gid]: + table.add_row(str(resname), str(level), str(typ), str(count), str(val)) + + console.print(table) + + def _log_group_label_table(self) -> None: + """Render group label metadata table.""" + if not self.group_labels: + return + + table = Table(title="Group Metadata", show_lines=True, expand=True) + table.add_column("Group ID", justify="center", style="bold cyan") + table.add_column("Label", justify="center", style="green") + table.add_column("Residue Count", justify="center", style="magenta") + table.add_column("Atom Count", justify="center", style="yellow") + + for group_id in sorted(self.group_labels.keys(), key=self._gid_sort_key): + info = self.group_labels[group_id] + table.add_row( + str(group_id), + str(info.get("label", "")), + str(info.get("residue_count", "")), + str(info.get("atom_count", "")), + ) + + console.print(table) + + def save_dataframes_as_json( + self, + molecule_df, + residue_df, + output_file: str, + *, + args: Optional[Any] = None, + include_raw_tables: bool = False, + ) -> None: + """Save results to a grouped JSON structure. + + JSON contains: + - args: arguments used (serialized) + - provenance: version, python, platform, optional git sha + - groups: { "": { components: {...}, total: ... } } + + Args: + molecule_df: Pandas DataFrame containing molecule results. + residue_df: Pandas DataFrame containing residue results. + output_file: Path to JSON output file. + args: Optional argparse Namespace or dict of arguments used. + include_raw_tables: If True, also include old "molecule_data"/"residue_data" + arrays for debugging/backwards-compat. 
+ """ + payload = self._build_grouped_payload( + molecule_df=molecule_df, + residue_df=residue_df, + args=args, + include_raw_tables=include_raw_tables, + ) + + with open(output_file, "w") as out: + json.dump(payload, out, indent=2) + + def _build_grouped_payload( + self, + *, + molecule_df, + residue_df, + args: Optional[Any], + include_raw_tables: bool, + ) -> Dict[str, Any]: + """Build a grouped JSON-serializable payload from result dataframes. + + Args: + molecule_df: Pandas DataFrame containing molecule results. + residue_df: Pandas DataFrame containing residue results. + args: Optional argparse Namespace or dict of arguments used. + include_raw_tables: If True, include raw dataframe record arrays in payload. + + Returns: + Dictionary payload suitable for JSON serialization. + """ + mol_rows = molecule_df.to_dict(orient="records") + res_rows = residue_df.to_dict(orient="records") + + groups: Dict[str, Dict[str, Any]] = {} + + for row in mol_rows: + gid = str(row.get("Group ID")) + level = str(row.get("Level")) + typ = str(row.get("Type")) + + raw_val = row.get("Result (J/mol/K)") + val = self._safe_float(raw_val) + if val is None: + continue + + groups.setdefault(gid, {"components": {}, "total": None}) + + is_total = level.lower().startswith( + "group total" + ) or typ.lower().startswith("group total") + if is_total: + groups[gid]["total"] = val + else: + key = f"{level}:{typ}" + groups[gid]["components"][key] = val + + for _gid, g in groups.items(): + if g["total"] is None: + comps = g["components"].values() + g["total"] = float(sum(comps)) if comps else 0.0 + + payload: Dict[str, Any] = { + "args": self._serialize_args(args), + "provenance": self._provenance(), + "groups": groups, + } + + if include_raw_tables: + payload["molecule_data"] = mol_rows + payload["residue_data"] = res_rows + + return payload + + @staticmethod + def _serialize_args(args: Optional[Any]) -> Dict[str, Any]: + """Turn argparse Namespace / dict / object into a JSON-serializable dict. 
+ + Args: + args: argparse Namespace, dict, or other object with __dict__/iterable. + + Returns: + JSON-serializable dict of argument values. Unsupported/unreadable inputs + return an empty dict. + """ + if args is None: + return {} + + if isinstance(args, dict): + base = dict(args) + else: + base = getattr(args, "__dict__", None) + if not base: + try: + base = dict(args) + except Exception: + return {} + + out: Dict[str, Any] = {} + for k, v in base.items(): + if isinstance(v, np.ndarray): + out[k] = v.tolist() + elif isinstance(v, Path): + out[k] = str(v) + else: + out[k] = v + return out + + @staticmethod + def _provenance() -> Dict[str, Any]: + """Build a provenance dictionary for exported results. + + Returns: + Dictionary with python version, platform string, CodeEntropy package + version (if available), and git sha (if available). + """ + prov: Dict[str, Any] = { + "python": sys.version.split()[0], + "platform": platform.platform(), + } + + try: + from importlib.metadata import version + + prov["codeentropy_version"] = version("CodeEntropy") + except Exception: + prov["codeentropy_version"] = None + + prov["git_sha"] = ResultsReporter._try_get_git_sha() + return prov + + @staticmethod + def _try_get_git_sha() -> Optional[str]: + """Try to determine the current git commit SHA. + + The SHA is obtained from: + 1) Environment variable CODEENTROPY_GIT_SHA, if set. + 2) A git repository discovered by walking up from this file path and + running `git rev-parse HEAD`. + + Returns: + Git SHA string if found, otherwise None. 
+ """ + env_sha = os.environ.get("CODEENTROPY_GIT_SHA") + if env_sha: + return env_sha + + try: + here = Path(__file__).resolve() + repo_guess = here.parents[2] + + if not (repo_guess / ".git").exists(): + for p in here.parents: + if (p / ".git").exists(): + repo_guess = p + break + else: + return None + + proc = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=str(repo_guess), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + if proc.returncode != 0: + return None + sha = (proc.stdout or "").strip() + return sha or None + except Exception: + return None + + @contextmanager + def progress(self, *, transient: bool = True): + """Create a workflow progress context. + + Usage: + with reporter.progress() as p: + ... + + Args: + transient: Whether the progress display should be removed on exit. + + Yields: + A _RichProgressSink that exposes add_task(), update(), and advance(). + """ + progress = Progress( + SpinnerColumn(), + TextColumn("[bold blue]{task.fields[title]}", justify="right"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.1f}%"), + TimeElapsedColumn(), + console=self.console, + transient=transient, + ) + with progress: + yield _RichProgressSink(progress) diff --git a/CodeEntropy/run.py b/CodeEntropy/run.py deleted file mode 100644 index 20cbae6d..00000000 --- a/CodeEntropy/run.py +++ /dev/null @@ -1,343 +0,0 @@ -import logging -import os -import pickle - -import MDAnalysis as mda -import requests -import yaml -from art import text2art -from rich.align import Align -from rich.console import Group -from rich.padding import Padding -from rich.panel import Panel -from rich.rule import Rule -from rich.table import Table -from rich.text import Text - -from CodeEntropy.config.arg_config_manager import ConfigManager -from CodeEntropy.config.data_logger import DataLogger -from CodeEntropy.config.logging_config import LoggingConfig -from CodeEntropy.dihedral_tools import DihedralAnalysis -from CodeEntropy.entropy import 
EntropyManager -from CodeEntropy.group_molecules import GroupMolecules -from CodeEntropy.levels import LevelManager -from CodeEntropy.mda_universe_operations import UniverseOperations - -logger = logging.getLogger(__name__) -console = LoggingConfig.get_console() - - -class RunManager: - """ - Handles the setup and execution of entropy analysis runs, including configuration - loading, logging, and access to physical constants used in calculations. - """ - - def __init__(self, folder): - """ - Initializes the RunManager with the working folder and sets up configuration, - data logging, and logging systems. Also defines physical constants used in - entropy calculations. - """ - self.folder = folder - self._config_manager = ConfigManager() - self._data_logger = DataLogger() - self._logging_config = LoggingConfig(folder) - self._N_AVOGADRO = 6.0221415e23 - self._DEF_TEMPER = 298 - - @property - def N_AVOGADRO(self): - """Returns Avogadro's number used in entropy calculations.""" - return self._N_AVOGADRO - - @property - def DEF_TEMPER(self): - """Returns the default temperature (in Kelvin) used in the analysis.""" - return self._DEF_TEMPER - - @staticmethod - def create_job_folder(): - """ - Create a new job folder with an incremented job number based on existing - folders. 
- """ - # Get the current working directory - current_dir = os.getcwd() - - # Get a list of existing folders that start with "job" - existing_folders = [f for f in os.listdir(current_dir) if f.startswith("job")] - - # Extract numbers from existing folder names - job_numbers = [] - for folder in existing_folders: - try: - # Assuming folder names are in the format "jobXXX" - job_number = int(folder[3:]) # Get the number part after "job" - job_numbers.append(job_number) - except ValueError: - continue # Ignore any folder names that don't follow the pattern - - # If no folders exist, start with job001 - if not job_numbers: - next_job_number = 1 - else: - next_job_number = max(job_numbers) + 1 - - # Create the new job folder name - new_job_folder = f"job{next_job_number:03d}" - - # Create the full path to the new folder - new_folder_path = os.path.join(current_dir, new_job_folder) - - # Create the directory - os.makedirs(new_folder_path, exist_ok=True) - - # Return the path of the newly created folder - return new_folder_path - - def load_citation_data(self): - """ - Load CITATION.cff from GitHub into memory. - Return empty dict if offline. 
- """ - url = ( - "https://raw.githubusercontent.com/CCPBioSim/" - "CodeEntropy/refs/heads/main/CITATION.cff" - ) - try: - response = requests.get(url, timeout=10) - response.raise_for_status() - return yaml.safe_load(response.text) - except requests.exceptions.RequestException: - return None - - def show_splash(self): - """Render splash screen with optional citation metadata.""" - citation = self.load_citation_data() - - if citation: - # ASCII Title - ascii_title = text2art(citation.get("title", "CodeEntropy")) - ascii_render = Align.center(Text(ascii_title, style="bold white")) - - # Metadata - version = citation.get("version", "?") - release_date = citation.get("date-released", "?") - url = citation.get("url", citation.get("repository-code", "")) - - version_text = Align.center( - Text(f"Version {version} | Released {release_date}", style="green") - ) - url_text = Align.center(Text(url, style="blue underline")) - - # Description block - abstract = citation.get("abstract", "No description available.") - description_title = Align.center( - Text("Description", style="bold magenta underline") - ) - description_body = Align.center( - Padding(Text(abstract, style="white", justify="left"), (0, 4)) - ) - - # Contributors table - contributors_title = Align.center( - Text("Contributors", style="bold magenta underline") - ) - - author_table = Table( - show_header=True, header_style="bold yellow", box=None, pad_edge=False - ) - author_table.add_column("Name", style="bold", justify="center") - author_table.add_column("Affiliation", justify="center") - - for author in citation.get("authors", []): - name = ( - f"{author.get('given-names', '')} {author.get('family-names', '')}" - ).strip() - affiliation = author.get("affiliation", "") - author_table.add_row(name, affiliation) - - contributors_table = Align.center(Padding(author_table, (0, 4))) - - # Full layout - splash_content = Group( - ascii_render, - Rule(style="cyan"), - version_text, - url_text, - Text(), - 
description_title, - description_body, - Text(), - contributors_title, - contributors_table, - ) - else: - # ASCII Title - ascii_title = text2art("CodeEntropy") - ascii_render = Align.center(Text(ascii_title, style="bold white")) - - splash_content = Group( - ascii_render, - ) - - splash_panel = Panel( - splash_content, - title="[bold bright_cyan]Welcome to CodeEntropy", - title_align="center", - border_style="bright_cyan", - padding=(1, 4), - expand=True, - ) - - console.print(splash_panel) - - def print_args_table(self, args): - table = Table(title="Run Configuration", expand=True) - - table.add_column("Argument", style="cyan", no_wrap=True) - table.add_column("Value", style="magenta") - - for arg in vars(args): - table.add_row(arg, str(getattr(args, arg))) - - console.print(table) - - def run_entropy_workflow(self): - """ - Runs the entropy analysis workflow by setting up logging, loading configuration - files, parsing arguments, and executing the analysis for each configured run. - Initializes the MDAnalysis Universe and supporting managers, and logs all - relevant inputs and commands. - """ - try: - logger = self._logging_config.setup_logging() - self.show_splash() - - current_directory = os.getcwd() - - config = self._config_manager.load_config(current_directory) - parser = self._config_manager.setup_argparse() - args, _ = parser.parse_known_args() - args.output_file = os.path.join(self.folder, args.output_file) - - for run_name, run_config in config.items(): - if not isinstance(run_config, dict): - logger.warning( - f"Run configuration for {run_name} is not a dictionary." 
- ) - continue - - args = self._config_manager.merge_configs(args, run_config) - - log_level = logging.DEBUG if args.verbose else logging.INFO - self._logging_config.update_logging_level(log_level) - - command = " ".join(os.sys.argv) - logging.getLogger("commands").info(command) - - if not getattr(args, "top_traj_file", None): - raise ValueError("Missing 'top_traj_file' argument.") - if not getattr(args, "selection_string", None): - raise ValueError("Missing 'selection_string' argument.") - - self.print_args_table(args) - - # Load MDAnalysis Universe - tprfile = args.top_traj_file[0] - trrfile = args.top_traj_file[1:] - forcefile = args.force_file - fileformat = args.file_format - kcal_units = args.kcal_force_units - - # Create shared UniverseOperations instance - universe_operations = UniverseOperations() - - if forcefile is None: - logger.debug(f"Loading Universe with {tprfile} and {trrfile}") - u = mda.Universe(tprfile, trrfile, format=fileformat) - else: - u = universe_operations.merge_forces( - tprfile, trrfile, forcefile, fileformat, kcal_units - ) - - self._config_manager.input_parameters_validation(u, args) - - # Create LevelManager instance - level_manager = LevelManager(universe_operations) - - # Create GroupMolecules instance - group_molecules = GroupMolecules() - - # Create shared DihedralAnalysis with injected universe_operations - dihedral_analysis = DihedralAnalysis( - universe_operations=universe_operations - ) - - # Inject all dependencies into EntropyManager - entropy_manager = EntropyManager( - run_manager=self, - args=args, - universe=u, - data_logger=self._data_logger, - level_manager=level_manager, - group_molecules=group_molecules, - dihedral_analysis=dihedral_analysis, - universe_operations=universe_operations, - ) - - entropy_manager.execute() - - self._logging_config.save_console_log() - - except Exception as e: - logger.error(f"RunManager encountered an error: {e}", exc_info=True) - raise - - def write_universe(self, u, name="default"): - 
"""Write a universe to working directories as pickle - - Parameters - ---------- - u : MDAnalyse.Universe - A Universe object will all topology, dihedrals,coordinates and force - information - name : str, Optional. default: 'default' - The name of file with sub file name .pkl - - Returns - ------- - name : str - filename of saved universe - """ - filename = f"{name}.pkl" - pickle.dump(u, open(filename, "wb")) - return name - - def read_universe(self, path): - """read a universe to working directories as pickle - - Parameters - ---------- - path : str - The path to file. - - Returns - ------- - u : MDAnalysis.Universe - A Universe object will all topology, dihedrals,coordinates and force - information. - """ - u = pickle.load(open(path, "rb")) - return u - - def change_lambda_units(self, arg_lambdas): - """Unit of lambdas : kJ2 mol-2 A-2 amu-1 - change units of lambda to J/s2""" - # return arg_lambdas * N_AVOGADRO * N_AVOGADRO * AMU2KG * 1e-26 - return arg_lambdas * 1e29 / self.N_AVOGADRO - - def get_KT2J(self, arg_temper): - """A temperature dependent KT to Joule conversion""" - return 4.11e-21 * arg_temper / self.DEF_TEMPER diff --git a/README.md b/README.md index 2eadcccf..00a9a818 100644 --- a/README.md +++ b/README.md @@ -3,23 +3,25 @@ CodeEntropy | Category | Badges | |----------------|--------| -| **Build** | [![CodeEntropy CI](https://github.com/CCPBioSim/CodeEntropy/actions/workflows/project-ci.yaml/badge.svg)](https://github.com/CCPBioSim/CodeEntropy/actions/workflows/project-ci.yaml) | -| **Documentation** | [![Docs - Status](https://app.readthedocs.org/projects/codeentropy/badge/?version=latest)](https://codeentropy.readthedocs.io/en/latest/?badge=latest) | +| **Build** | [![PR Checks](https://github.com/CCPBioSim/CodeEntropy/actions/workflows/pr.yaml/badge.svg)](https://github.com/CCPBioSim/CodeEntropy/actions/workflows/pr.yaml) [![Daily 
Tests](https://github.com/CCPBioSim/CodeEntropy/actions/workflows/daily.yaml/badge.svg)](https://github.com/CCPBioSim/CodeEntropy/actions/workflows/daily.yaml) | +| **Regression** | [![Weekly Regression](https://github.com/CCPBioSim/CodeEntropy/actions/workflows/weekly-regression.yaml/badge.svg)](https://github.com/CCPBioSim/CodeEntropy/actions/workflows/weekly-regression.yaml) | +| **Documentation** | [![Weekly Docs](https://github.com/CCPBioSim/CodeEntropy/actions/workflows/weekly-docs.yaml/badge.svg)](https://github.com/CCPBioSim/CodeEntropy/actions/workflows/weekly-docs.yaml) [![Docs - Status](https://app.readthedocs.org/projects/codeentropy/badge/?version=latest)](https://codeentropy.readthedocs.io/en/latest/?badge=latest) | | **Citation** | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.17570721.svg)](https://doi.org/10.5281/zenodo.17570721) | -| **PyPI** | ![PyPI - Status](https://img.shields.io/pypi/status/codeentropy?logo=pypi&logoColor=white) ![PyPI - Version](https://img.shields.io/pypi/v/codeentropy?logo=pypi&logoColor=white) ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/CodeEntropy) ![PyPI - Total Downloads](https://img.shields.io/pepy/dt/codeentropy?logo=pypi&logoColor=white&color=blue) ![PyPI - Monthly Downloads](https://img.shields.io/pypi/dm/CodeEntropy?logo=pypi&logoColor=white&color=blue)| +| **PyPI** | ![PyPI - Status](https://img.shields.io/pypi/status/codeentropy?logo=pypi&logoColor=white) ![PyPI - Version](https://img.shields.io/pypi/v/codeentropy?logo=pypi&logoColor=white) ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/CodeEntropy) ![PyPI - Total Downloads](https://img.shields.io/pepy/dt/codeentropy?logo=pypi&logoColor=white&color=blue) ![PyPI - Monthly Downloads](https://img.shields.io/pypi/dm/CodeEntropy?logo=pypi&logoColor=white&color=blue) | | **Quality** | [![Coverage 
Status](https://coveralls.io/repos/github/CCPBioSim/CodeEntropy/badge.svg?branch=main)](https://coveralls.io/github/CCPBioSim/CodeEntropy?branch=main) | CodeEntropy is a Python package for computing the configurational entropy of macromolecular systems using forces sampled from molecular dynamics (MD) simulations. It implements the multiscale cell correlation method to provide accurate and efficient entropy estimates, supporting a wide range of applications in molecular simulation and statistical mechanics.

-CodeEntropy logo + CodeEntropy logo + CodeEntropy logo

See [CodeEntropy’s documentation](https://codeentropy.readthedocs.io/en/latest/) for more information. ## Acknowledgements - -Project based on + +Project based on - [arghya90/CodeEntropy](https://github.com/arghya90/CodeEntropy) version 0.3 - [jkalayan/PoseidonBeta](https://github.com/jkalayan/PoseidonBeta) diff --git a/docs/Makefile b/docs/Makefile index 08b52d2c..1f4fb783 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -17,4 +17,4 @@ help: # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/README.md b/docs/README.md index cc030c72..79687e62 100644 --- a/docs/README.md +++ b/docs/README.md @@ -14,11 +14,10 @@ Once installed, you can use the `Makefile` in this directory to compile static H make html ``` -The compiled docs will be in the `_build` directory and can be viewed by opening `index.html` (which may itself +The compiled docs will be in the `_build` directory and can be viewed by opening `index.html` (which may itself be inside a directory called `html/` depending on what version of Sphinx is installed). A configuration file for [Read The Docs](https://readthedocs.org/) (readthedocs.yaml) is included in the top level of the repository. To use Read the Docs to host your documentation, go to https://readthedocs.org/ and connect this repository. You may need to change your default branch to `main` under Advanced Settings for the project. If you would like to use Read The Docs with `autodoc` (included automatically) and your package has dependencies, you will need to include those dependencies in your documentation yaml file (`docs/requirements.yaml`). 
- diff --git a/docs/_static/README.md b/docs/_static/README.md index 2f0cf843..122b610b 100644 --- a/docs/_static/README.md +++ b/docs/_static/README.md @@ -1,11 +1,11 @@ # Static Doc Directory Add any paths that contain custom static files (such as style sheets) here, -relative to the `conf.py` file's directory. +relative to the `conf.py` file's directory. They are copied after the builtin static files, so a file named "default.css" will overwrite the builtin "default.css". -The path to this folder is set in the Sphinx `conf.py` file in the line: +The path to this folder is set in the Sphinx `conf.py` file in the line: ```python templates_path = ['_static'] ``` diff --git a/docs/_static/custom.css b/docs/_static/custom.css index 07d420be..7d2e89e6 100644 --- a/docs/_static/custom.css +++ b/docs/_static/custom.css @@ -1,3 +1,3 @@ .tight-table td { white-space: normal !important; -} \ No newline at end of file +} diff --git a/docs/_templates/README.md b/docs/_templates/README.md index 3f4f8043..485f82ad 100644 --- a/docs/_templates/README.md +++ b/docs/_templates/README.md @@ -1,11 +1,11 @@ # Templates Doc Directory -Add any paths that contain templates here, relative to +Add any paths that contain templates here, relative to the `conf.py` file's directory. They are copied after the builtin template files, so a file named "page.html" will overwrite the builtin "page.html". 
-The path to this folder is set in the Sphinx `conf.py` file in the line: +The path to this folder is set in the Sphinx `conf.py` file in the line: ```python html_static_path = ['_templates'] ``` diff --git a/docs/api.rst b/docs/api.rst index de67a2a3..cc91d767 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -57,7 +57,7 @@ Vibrational Entropy CodeEntropy.entropy.VibrationalEntropy CodeEntropy.entropy.VibrationalEntropy.frequency_calculation CodeEntropy.entropy.VibrationalEntropy.vibrational_entropy_calculation - + Conformational Entropy ^^^^^^^^^^^^^^^^^^^^^^ @@ -67,4 +67,3 @@ Conformational Entropy CodeEntropy.entropy.ConformationalEntropy CodeEntropy.entropy.ConformationalEntropy.assign_conformation CodeEntropy.entropy.ConformationalEntropy.conformational_entropy_calculation - diff --git a/docs/conf.py b/docs/conf.py index 2ee1e608..cf51cf86 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -101,7 +101,10 @@ # a list of builtin themes. # html_theme = "furo" -html_logo = "images/biosim-codeentropy_logo_grey.svg" +html_theme_options = { + "light_logo": "images/biosim-codeentropy_logo_light.png", + "dark_logo": "images/biosim-codeentropy_logo_dark.png", +} # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the diff --git a/docs/config.yaml b/docs/config.yaml index 3f6fd8f8..3b1240ae 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -3,7 +3,7 @@ run1: top_traj_file: ["1AKI_prod.tpr", "1AKI_prod.trr"] force_file: - file_formate: + file_formate: selection_string: 'all' start: 0 end: 500 diff --git a/docs/developer_guide.rst b/docs/developer_guide.rst index 416c46e1..d747b563 100644 --- a/docs/developer_guide.rst +++ b/docs/developer_guide.rst @@ -1,7 +1,10 @@ Developer Guide =============== -CodeEntropy is open-source, and we welcome contributions from the wider community to help improve and extend its functionality. 
This guide walks you through setting up a development environment, running tests, submitting contributions, and maintaining coding standards. +CodeEntropy is open-source, and we welcome contributions from the wider community +to help improve and extend its functionality. This guide walks you through setting +up a development environment, running tests, contributing code, and understanding +the continuous integration workflows. Getting Started for Developers ------------------------------ @@ -24,44 +27,77 @@ Install development dependencies:: Running Tests ------------- -Run the full test suite:: +CodeEntropy uses **pytest** with separate unit and regression suites. - pytest -v +Run all tests:: -Run tests with code coverage:: + pytest + +Run only unit tests:: + + pytest tests/unit + +Run regression tests:: + + pytest tests/regression + +Run regression tests excluding slow systems:: + + pytest tests/regression -m "not slow" + +Run slow regression tests:: + + pytest tests/regression -m slow + +Run tests with coverage:: pytest --cov CodeEntropy --cov-report=term-missing -Run tests for a specific module:: +Update regression baselines:: - pytest CodeEntropy/tests/test_CodeEntropy/test_levels.py + pytest tests/regression --update-baselines Run a specific test:: - pytest CodeEntropy/tests/test_CodeEntropy/test_levels.py::test_select_levels + pytest tests/unit/.../test_file.py::test_function + +Regression Test Data +-------------------- + +Regression datasets are automatically downloaded from the CCPBioSim filestore +and cached locally in ``.testdata/`` when tests are run. + +No manual setup is required. + +The test configuration files reference datasets using the ``${TESTDATA}`` +placeholder, which is expanded automatically during test execution. Coding Standards ---------------- -We use **pre-commit hooks** to maintain code quality and consistent style. To enable these hooks:: +We use **pre-commit hooks** to maintain code quality and consistent style. 
+ +Enable hooks:: pre-commit install -This ensures: +Our tooling stack: -- **Formatting** via ``black`` (`psf/black`) -- **Import sorting** via ``isort`` with the ``black`` profile -- **Linting** via ``flake8`` with ``flake8-pyproject`` -- **Basic checks** via ``pre-commit-hooks``, including: - - - Detection of large added files - - AST validity checks - - Case conflict detection - - Executable shebang verification - - Merge conflict detection - - TOML and YAML syntax validation +- **Linting and formatting** via ``ruff`` +- **Basic repository checks** via ``pre-commit-hooks`` -To skip pre-commit checks for a commit:: +Ruff performs: + +- Code formatting +- Import sorting +- Static analysis +- Style enforcement + +Run checks manually:: + + pre-commit run --all-files + +Skip checks for a commit (not recommended):: git commit -n @@ -72,33 +108,50 @@ To skip pre-commit checks for a commit:: Continuous Integration (CI) --------------------------- -CodeEntropy uses **GitHub Actions** to automatically: +CodeEntropy uses **GitHub Actions** with multiple workflows to ensure stability +across platforms and Python versions. + +Pull Request checks include: + +- Unit tests on Linux, macOS, and Windows +- Python versions 3.12-3.14 +- Quick regression tests +- Documentation build +- Pre-commit validation + +Daily workflow: -- Run all tests -- Check coding style -- Build documentation -- Validate versioning +- Runs automated test validation -Every pull request will trigger these checks to ensure quality and stability. +Weekly workflows: + +- Full regression suite including slow tests +- Documentation build across all Python versions + +CI also caches regression datasets to improve performance. Building Documentation ---------------------- -Build locally:: +Build documentation locally:: cd docs make html -The generated HTML files will be in ``docs/build/html/``. Open ``index.html`` in your browser to view the documentation. +The generated HTML files will be in ``docs/build/html/``. 
+ +Open ``index.html`` in your browser to preview. -Edit docs in the following directories: +Documentation sources are located in: - ``docs/user_guide/`` - ``docs/developer_guide/`` Contributing Code ----------------- -If you would to contribution to **CodeEntropy** please refer to our `Contributing Guidelines `_ + +If you would like to contribute to **CodeEntropy**, please refer to the +`Contributing Guidelines `_. Creating an Issue ^^^^^^^^^^^^^^^^^ @@ -114,7 +167,7 @@ Branching - Never commit directly to ``main``. - Create a branch named after the issue:: - git checkout -b 123-fix-levels + git checkout -b 123-feature-description Pull Requests ^^^^^^^^^^^^^ @@ -132,6 +185,6 @@ Full developer setup:: git clone https://github.com/CCPBioSim/CodeEntropy.git cd CodeEntropy - pip install -e .[testing,docs,pre-commit] + pip install -e ".[testing,docs,pre-commit]" pre-commit install - pytest --cov CodeEntropy --cov-report=term-missing + pytest diff --git a/docs/faq.rst b/docs/faq.rst index a119a272..7ef18da6 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -4,15 +4,15 @@ Frequently asked questions Why do I get a ``WARNING`` about invalid eigenvalues? ----------------------------------------------------- -Insufficient sampling might introduce noise and cause matrix elements to deviate to values that would not reflect the uncorrelated nature of force-force covariance of distantly positioned residues. -Try increasing the sampling time. -This is especially true at the residue level. +Insufficient sampling might introduce noise and cause matrix elements to deviate to values that would not reflect the uncorrelated nature of force-force covariance of distantly positioned residues. +Try increasing the sampling time. +This is especially true at the residue level. 
-For example in a lysozyme system, the residue level contains the largest force and torque covariance matrices because at this level we have the largest number of beads (which is equal to the number of residues in a protein) compared to the molecule level (3 beads) and united-atom level (~10 beads per amino acid). +For example in a lysozyme system, the residue level contains the largest force and torque covariance matrices because at this level we have the largest number of beads (which is equal to the number of residues in a protein) compared to the molecule level (3 beads) and united-atom level (~10 beads per amino acid). What do I do if there is an error from MDAnalysis not recognising the file type? ------------------------------------------------------------------------------- -Use the file_format option. +Use the file_format option. The MDAnalysis documentation has a list of acceptable formats: https://userguide.mdanalysis.org/1.1.1/formats/index.html#id1. The first column of their table gives you the string you need (case sensitive). diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 06f6d9ef..2f300213 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -179,7 +179,7 @@ The ``top_traj_file`` argument is required; other arguments have default values. - ``molecules`` - ``str`` * - ``--kcal_force_units`` - - Set input units as kcal/mol + - Set input units as kcal/mol - ``bool`` - ``False`` * - ``--combined_forcetorque`` @@ -228,7 +228,7 @@ Create or edit ``config.yaml`` in your working directory: start: 0 end: -1 step: 1 - + Run CodeEntropy from that directory: .. 
code-block:: bash diff --git a/docs/images/biosim-codeentropy_logo_dark.png b/docs/images/biosim-codeentropy_logo_dark.png new file mode 100644 index 00000000..b1856eb2 Binary files /dev/null and b/docs/images/biosim-codeentropy_logo_dark.png differ diff --git a/docs/images/biosim-codeentropy_logo_grey.svg b/docs/images/biosim-codeentropy_logo_grey.svg deleted file mode 100644 index 0a2386da..00000000 --- a/docs/images/biosim-codeentropy_logo_grey.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/docs/images/biosim-codeentropy_logo_light.png b/docs/images/biosim-codeentropy_logo_light.png new file mode 100644 index 00000000..2e4c2841 Binary files /dev/null and b/docs/images/biosim-codeentropy_logo_light.png differ diff --git a/docs/science.rst b/docs/science.rst index 23e64889..6d89005e 100644 --- a/docs/science.rst +++ b/docs/science.rst @@ -3,9 +3,9 @@ Multiscale Cell Correlation Theory This section is to describe the scientific theory behind the method used in CodeEntropy. -The multiscale cell correlation (MCC) method [1-3] has been developed in the group of Richard Henchman to calculate entropy from molecular dynamics (MD) simulations. +The multiscale cell correlation (MCC) method [1-3] has been developed in the group of Richard Henchman to calculate entropy from molecular dynamics (MD) simulations. It has been applied to liquids [1,3,4], proteins [2,5,6], solutions [6-9], and complexes [6,7]. -The purpose of this project is to develop and release well written code that enables users from any group to calculate the entropy from their simulations using the MCC. +The purpose of this project is to develop and release well written code that enables users from any group to calculate the entropy from their simulations using the MCC. The latest code can be found at github.com/ccpbiosim/codeentropy. The method requires forces to be written to the MD trajectory files along with the coordinates. 
@@ -42,8 +42,8 @@ Additional application examples Hierarchy --------- - -Atoms are grouped into beads. + +Atoms are grouped into beads. The levels refer to the size of the beads and the different entropy terms are calculated at each level, taking care to avoid over counting. This is done at three different levels of the hierarchy - united atom, residues, and polymers. Not all molecules have all the levels of hierarchy, for example water has only the united atom level, benzene would have united atoms and residue, and a protein would have all three levels. @@ -68,7 +68,7 @@ The axes for this transformation are calculated for each bead in each time step. For the polymer level, the translational and rotational axes are defined as the principal axes of the molecule. -For the residue level, there are two situations. +For the residue level, there are two situations. When the residue is not bonded to any other residues, the translational and rotational axes are the principal axes of the molecule. When the residue is part of a larger polymer, the translational axes are the principal axes of the polymer, and the rotational axes are defined from the average position of the bonds to neighbouring residues. @@ -80,13 +80,13 @@ Conformational Entropy This term is based on the intramolecular conformational states. -The united atom level dihedrals are defined for every linear sequence of four bonded atoms, but only using the heavy atoms no hydrogens are involved. +The united atom level dihedrals are defined for every linear sequence of four bonded atoms, but only using the heavy atoms no hydrogens are involved. The MDAnalysis package is used to identify and calculate the united atom dihedral values. -For the residue level dihedrals, the bond between the first and second residues and the bond between the third and fourth residues are found. +For the residue level dihedrals, the bond between the first and second residues and the bond between the third and fourth residues are found. 
The four atoms at the ends of these two bonds are used as points for the dihedral angle calculation. -To discretise dihedrals, a histogram is constructed from each set of dihedral values and peaks are identified. +To discretise dihedrals, a histogram is constructed from each set of dihedral values and peaks are identified. Then at each timestep, every dihedral is assigned to its nearest peak and a state is created from all the assigned peaks in the residue (for united atom level) or molecule (for residue level). Once the states are defined, the probability of finding the residue or molecule in each state is calculated. Then the Boltzmann equation is used to calculate the entropy: @@ -96,7 +96,7 @@ Then the Boltzmann equation is used to calculate the entropy: Orientational Entropy --------------------- -Orientational entropy is the term that comes from the molecule's environment (or the intermolecular configuration). +Orientational entropy is the term that comes from the molecule's environment (or the intermolecular configuration). The different environments are the different states for the molecule, and the statistics can be used to calculate the entropy. The simplest part is counting the number of neighbours, but symmetry should be accounted for in determining the number of orientations. 
diff --git a/pyproject.toml b/pyproject.toml index 42b7998d..4d8ba438 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,8 @@ dependencies = [ "python-json-logger>=4.0,<5.0", "rich>=14.2,<15.0", "art>=6.5,<7.0", + "networkx>=3.6,<3.7", + "matplotlib>=3.10,<3.11", "waterEntropy>=2,<2.1", "requests>=2.32,<3.0", ] @@ -62,10 +64,7 @@ testing = [ ] pre-commit = [ "pre-commit>=4.5,<5.0", - "black>=26.1,<27.0", - "flake8>=7.3,<8.0", - "flake8-pyproject>=1.2,<2.0", - "isort>=7.0,<8.0", + "ruff>=0.15,<0.16", "pylint>=4.0,<5.0" ] docs = [ @@ -80,14 +79,15 @@ docs = [ ] [project.scripts] -CodeEntropy = "CodeEntropy.main:main" +CodeEntropy = "CodeEntropy.cli:main" -[tool.isort] -profile = "black" +[tool.ruff] +line-length = 88 +target-version = "py311" -[tool.flake8] -max-line-length = 88 -extend-select = "B950" -extend-ignore = [ - "E203", # whitespace before `:` -] +[tool.ruff.lint] +select = ["E", "F", "I", "B"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" diff --git a/readthedocs.yml b/readthedocs.yml index 0f8e627d..b5efb63a 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -16,4 +16,4 @@ python: sphinx: # Path to your Sphinx configuration file. 
- configuration: docs/conf.py \ No newline at end of file + configuration: docs/conf.py diff --git a/tests/data/md_A4_dna.tpr b/tests/data/md_A4_dna.tpr deleted file mode 100644 index 1557a12e..00000000 Binary files a/tests/data/md_A4_dna.tpr and /dev/null differ diff --git a/tests/data/md_A4_dna_xf.trr b/tests/data/md_A4_dna_xf.trr deleted file mode 100644 index e77158a9..00000000 Binary files a/tests/data/md_A4_dna_xf.trr and /dev/null differ diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 00000000..64a89db8 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,10 @@ +[pytest] +testpaths = + unit + regression + +markers = + regression: end-to-end regression tests against baselines + slow: long-running regression tests (20-30+ minutes) + +addopts = -ra diff --git a/tests/regression/baselines/benzaldehyde.json b/tests/regression/baselines/benzaldehyde.json new file mode 100644 index 00000000..54103991 --- /dev/null +++ b/tests/regression/baselines/benzaldehyde.json @@ -0,0 +1,43 @@ +{ + "args": { + "top_traj_file": [ + "/home/tdo96567/BioSim/CodeEntropy/tests/data/benzaldehyde/molecules.top", + "/home/tdo96567/BioSim/CodeEntropy/tests/data/benzaldehyde/trajectory.crd" + ], + "force_file": "/home/tdo96567/BioSim/CodeEntropy/tests/data/benzaldehyde/forces.frc", + "file_format": "MDCRD", + "kcal_force_units": false, + "selection_string": "all", + "start": 0, + "end": 1, + "step": 1, + "bin_width": 30, + "temperature": 298.0, + "verbose": false, + "output_file": "/tmp/pytest-of-tdo96567/pytest-60/test_regression_matches_baseli0/job001/output_file.json", + "force_partitioning": 0.5, + "water_entropy": true, + "grouping": "molecules", + "combined_forcetorque": true, + "customised_axes": true + }, + "provenance": { + "python": "3.14.0", + "platform": "Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39", + "codeentropy_version": "1.0.7", + "git_sha": "226b37f7b206adba1b60253c41c7a0d467e75a58" + }, + "groups": { + "0": { + "components": { + 
"united_atom:Transvibrational": 158.90339720185818, + "united_atom:Rovibrational": 143.87250586343512, + "residue:FTmat-Transvibrational": 106.71035236014967, + "residue:FTmat-Rovibrational": 95.07735227595549, + "united_atom:Conformational": 0.0, + "residue:Conformational": 0.0 + }, + "total": 504.56360770139844 + } + } +} diff --git a/tests/regression/baselines/benzene.json b/tests/regression/baselines/benzene.json new file mode 100644 index 00000000..3517db61 --- /dev/null +++ b/tests/regression/baselines/benzene.json @@ -0,0 +1,43 @@ +{ + "args": { + "top_traj_file": [ + "/home/tdo96567/BioSim/CodeEntropy/tests/data/benzene/molecules.top", + "/home/tdo96567/BioSim/CodeEntropy/tests/data/benzene/trajectory.crd" + ], + "force_file": "/home/tdo96567/BioSim/CodeEntropy/tests/data/benzene/forces.frc", + "file_format": "MDCRD", + "kcal_force_units": false, + "selection_string": "all", + "start": 0, + "end": 1, + "step": 1, + "bin_width": 30, + "temperature": 298.0, + "verbose": false, + "output_file": "/tmp/pytest-of-tdo96567/pytest-64/test_regression_matches_baseli0/job001/output_file.json", + "force_partitioning": 0.5, + "water_entropy": true, + "grouping": "molecules", + "combined_forcetorque": true, + "customised_axes": true + }, + "provenance": { + "python": "3.14.0", + "platform": "Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39", + "codeentropy_version": "1.0.7", + "git_sha": "226b37f7b206adba1b60253c41c7a0d467e75a58" + }, + "groups": { + "0": { + "components": { + "united_atom:Transvibrational": 93.55450341182438, + "united_atom:Rovibrational": 143.68264201362132, + "residue:FTmat-Transvibrational": 108.34125737284016, + "residue:FTmat-Rovibrational": 95.57598285903227, + "united_atom:Conformational": 0.0, + "residue:Conformational": 0.0 + }, + "total": 441.15438565731813 + } + } +} diff --git a/tests/regression/baselines/cyclohexane.json b/tests/regression/baselines/cyclohexane.json new file mode 100644 index 00000000..aea6f560 --- /dev/null +++ 
b/tests/regression/baselines/cyclohexane.json @@ -0,0 +1,43 @@ +{ + "args": { + "top_traj_file": [ + "/home/tdo96567/BioSim/CodeEntropy/tests/data/cyclohexane/molecules.top", + "/home/tdo96567/BioSim/CodeEntropy/tests/data/cyclohexane/trajectory.crd" + ], + "force_file": "/home/tdo96567/BioSim/CodeEntropy/tests/data/cyclohexane/forces.frc", + "file_format": "MDCRD", + "kcal_force_units": false, + "selection_string": "all", + "start": 0, + "end": 1, + "step": 1, + "bin_width": 30, + "temperature": 298.0, + "verbose": false, + "output_file": "/tmp/pytest-of-tdo96567/pytest-60/test_regression_matches_baseli2/job001/output_file.json", + "force_partitioning": 0.5, + "water_entropy": true, + "grouping": "molecules", + "combined_forcetorque": true, + "customised_axes": true + }, + "provenance": { + "python": "3.14.0", + "platform": "Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39", + "codeentropy_version": "1.0.7", + "git_sha": "226b37f7b206adba1b60253c41c7a0d467e75a58" + }, + "groups": { + "0": { + "components": { + "united_atom:Transvibrational": 109.02761125847158, + "united_atom:Rovibrational": 227.2888326629934, + "residue:FTmat-Transvibrational": 106.06698045971194, + "residue:FTmat-Rovibrational": 99.10449330958527, + "united_atom:Conformational": 0.0, + "residue:Conformational": 0.0 + }, + "total": 541.4879176907622 + } + } +} diff --git a/tests/regression/baselines/dna.json b/tests/regression/baselines/dna.json new file mode 100644 index 00000000..f1d58088 --- /dev/null +++ b/tests/regression/baselines/dna.json @@ -0,0 +1,58 @@ +{ + "args": { + "top_traj_file": [ + "/home/tdo96567/BioSim/CodeEntropy/tests/data/dna/md_A4_dna.tpr", + "/home/tdo96567/BioSim/CodeEntropy/tests/data/dna/md_A4_dna_xf.trr" + ], + "force_file": null, + "file_format": null, + "kcal_force_units": false, + "selection_string": "all", + "start": 0, + "end": 1, + "step": 1, + "bin_width": 30, + "temperature": 298.0, + "verbose": false, + "output_file": 
"/tmp/pytest-of-tdo96567/pytest-60/test_regression_matches_baseli3/job001/output_file.json", + "force_partitioning": 0.5, + "water_entropy": true, + "grouping": "molecules", + "combined_forcetorque": true, + "customised_axes": true + }, + "provenance": { + "python": "3.14.0", + "platform": "Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39", + "codeentropy_version": "1.0.7", + "git_sha": "226b37f7b206adba1b60253c41c7a0d467e75a58" + }, + "groups": { + "0": { + "components": { + "united_atom:Transvibrational": 0.0, + "united_atom:Rovibrational": 0.002160679012128457, + "residue:Transvibrational": 0.0, + "residue:Rovibrational": 3.376800684085249, + "polymer:FTmat-Transvibrational": 12.341104347192612, + "polymer:FTmat-Rovibrational": 0.0, + "united_atom:Conformational": 7.269386795471401, + "residue:Conformational": 0.0 + }, + "total": 22.989452505761392 + }, + "1": { + "components": { + "united_atom:Transvibrational": 0.0, + "united_atom:Rovibrational": 0.01846427765949586, + "residue:Transvibrational": 0.0, + "residue:Rovibrational": 2.3863201082544565, + "polymer:FTmat-Transvibrational": 11.11037253388596, + "polymer:FTmat-Rovibrational": 0.0, + "united_atom:Conformational": 6.410455987098191, + "residue:Conformational": 0.46183561256411515 + }, + "total": 20.387448519462218 + } + } +} diff --git a/tests/regression/baselines/ethyl-acetate.json b/tests/regression/baselines/ethyl-acetate.json new file mode 100644 index 00000000..e9614b14 --- /dev/null +++ b/tests/regression/baselines/ethyl-acetate.json @@ -0,0 +1,43 @@ +{ + "args": { + "top_traj_file": [ + "/home/tdo96567/BioSim/CodeEntropy/tests/data/ethyl-acetate/molecules.top", + "/home/tdo96567/BioSim/CodeEntropy/tests/data/ethyl-acetate/trajectory.crd" + ], + "force_file": "/home/tdo96567/BioSim/CodeEntropy/tests/data/ethyl-acetate/forces.frc", + "file_format": "MDCRD", + "kcal_force_units": false, + "selection_string": "all", + "start": 0, + "end": 1, + "step": 1, + "bin_width": 30, + "temperature": 
298.0, + "verbose": false, + "output_file": "/tmp/pytest-of-tdo96567/pytest-60/test_regression_matches_baseli4/job001/output_file.json", + "force_partitioning": 0.5, + "water_entropy": true, + "grouping": "molecules", + "combined_forcetorque": true, + "customised_axes": true + }, + "provenance": { + "python": "3.14.0", + "platform": "Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39", + "codeentropy_version": "1.0.7", + "git_sha": "226b37f7b206adba1b60253c41c7a0d467e75a58" + }, + "groups": { + "0": { + "components": { + "united_atom:Transvibrational": 119.77870290196522, + "united_atom:Rovibrational": 144.2366436580796, + "residue:FTmat-Transvibrational": 103.5819666889598, + "residue:FTmat-Rovibrational": 95.68311953660015, + "united_atom:Conformational": 8.140778318198597, + "residue:Conformational": 0.0 + }, + "total": 471.4212111038034 + } + } +} diff --git a/tests/regression/baselines/methane.json b/tests/regression/baselines/methane.json new file mode 100644 index 00000000..482088b7 --- /dev/null +++ b/tests/regression/baselines/methane.json @@ -0,0 +1,40 @@ +{ + "args": { + "top_traj_file": [ + "/home/tdo96567/BioSim/CodeEntropy/tests/data/methane/molecules.top", + "/home/tdo96567/BioSim/CodeEntropy/tests/data/methane/trajectory.crd" + ], + "force_file": "/home/tdo96567/BioSim/CodeEntropy/tests/data/methane/forces.frc", + "file_format": "MDCRD", + "kcal_force_units": false, + "selection_string": "all", + "start": 0, + "end": 1, + "step": 1, + "bin_width": 30, + "temperature": 112.0, + "verbose": false, + "output_file": "/tmp/pytest-of-tdo96567/pytest-60/test_regression_matches_baseli5/job001/output_file.json", + "force_partitioning": 0.5, + "water_entropy": true, + "grouping": "molecules", + "combined_forcetorque": true, + "customised_axes": true + }, + "provenance": { + "python": "3.14.0", + "platform": "Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39", + "codeentropy_version": "1.0.7", + "git_sha": 
"226b37f7b206adba1b60253c41c7a0d467e75a58" + }, + "groups": { + "0": { + "components": { + "united_atom:Transvibrational": 75.73291215434239, + "united_atom:Rovibrational": 68.80103728327107, + "united_atom:Conformational": 0.0 + }, + "total": 144.53394943761344 + } + } +} diff --git a/tests/regression/baselines/methanol.json b/tests/regression/baselines/methanol.json new file mode 100644 index 00000000..c5f1c8f3 --- /dev/null +++ b/tests/regression/baselines/methanol.json @@ -0,0 +1,43 @@ +{ + "args": { + "top_traj_file": [ + "/home/tdo96567/BioSim/CodeEntropy/tests/data/methanol/molecules.top", + "/home/tdo96567/BioSim/CodeEntropy/tests/data/methanol/trajectory.crd" + ], + "force_file": "/home/tdo96567/BioSim/CodeEntropy/tests/data/methanol/forces.frc", + "file_format": "MDCRD", + "kcal_force_units": false, + "selection_string": "all", + "start": 0, + "end": 1, + "step": 1, + "bin_width": 30, + "temperature": 298.0, + "verbose": false, + "output_file": "/tmp/pytest-of-tdo96567/pytest-60/test_regression_matches_baseli6/job001/output_file.json", + "force_partitioning": 0.5, + "water_entropy": true, + "grouping": "molecules", + "combined_forcetorque": true, + "customised_axes": true + }, + "provenance": { + "python": "3.14.0", + "platform": "Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39", + "codeentropy_version": "1.0.7", + "git_sha": "226b37f7b206adba1b60253c41c7a0d467e75a58" + }, + "groups": { + "0": { + "components": { + "united_atom:Transvibrational": 0.0, + "united_atom:Rovibrational": 85.74870264018092, + "residue:FTmat-Transvibrational": 93.59616431728384, + "residue:FTmat-Rovibrational": 59.61417719536213, + "united_atom:Conformational": 0.0, + "residue:Conformational": 0.0 + }, + "total": 238.9590441528269 + } + } +} diff --git a/tests/regression/baselines/octonol.json b/tests/regression/baselines/octonol.json new file mode 100644 index 00000000..1856b344 --- /dev/null +++ b/tests/regression/baselines/octonol.json @@ -0,0 +1,43 @@ +{ + 
"args": { + "top_traj_file": [ + "/home/tdo96567/BioSim/CodeEntropy/tests/data/octonol/molecules.top", + "/home/tdo96567/BioSim/CodeEntropy/tests/data/octonol/trajectory.crd" + ], + "force_file": "/home/tdo96567/BioSim/CodeEntropy/tests/data/octonol/forces.frc", + "file_format": "MDCRD", + "kcal_force_units": false, + "selection_string": "all", + "start": 0, + "end": 1, + "step": 1, + "bin_width": 30, + "temperature": 298.0, + "verbose": false, + "output_file": "/tmp/pytest-of-tdo96567/pytest-65/test_regression_matches_baseli0/job001/output_file.json", + "force_partitioning": 0.5, + "water_entropy": true, + "grouping": "molecules", + "combined_forcetorque": true, + "customised_axes": true + }, + "provenance": { + "python": "3.14.0", + "platform": "Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39", + "codeentropy_version": "1.0.7", + "git_sha": "226b37f7b206adba1b60253c41c7a0d467e75a58" + }, + "groups": { + "0": { + "components": { + "united_atom:Transvibrational": 222.4800037654818, + "united_atom:Rovibrational": 345.86413400118744, + "residue:FTmat-Transvibrational": 101.79847675768119, + "residue:FTmat-Rovibrational": 92.71423842383722, + "united_atom:Conformational": 20.4159084259166, + "residue:Conformational": 0.0 + }, + "total": 783.2727613741043 + } + } +} diff --git a/tests/regression/configs/benzaldehyde/config.yaml b/tests/regression/configs/benzaldehyde/config.yaml new file mode 100644 index 00000000..40a9f932 --- /dev/null +++ b/tests/regression/configs/benzaldehyde/config.yaml @@ -0,0 +1,12 @@ +--- + +run1: + force_file: ".testdata/benzaldehyde/forces.frc" + top_traj_file: + - ".testdata/benzaldehyde/molecules.top" + - ".testdata/benzaldehyde/trajectory.crd" + selection_string: "all" + start: 0 + end: 1 + step: 1 + file_format: "MDCRD" diff --git a/tests/regression/configs/benzene/config.yaml b/tests/regression/configs/benzene/config.yaml new file mode 100644 index 00000000..204e19f2 --- /dev/null +++ 
b/tests/regression/configs/benzene/config.yaml @@ -0,0 +1,12 @@ +--- + +run1: + force_file: ".testdata/benzene/forces.frc" + top_traj_file: + - ".testdata/benzene/molecules.top" + - ".testdata/benzene/trajectory.crd" + selection_string: 'all' + start: 0 + end: 1 + step: 1 + file_format: 'MDCRD' diff --git a/tests/regression/configs/cyclohexane/config.yaml b/tests/regression/configs/cyclohexane/config.yaml new file mode 100644 index 00000000..cadb32b3 --- /dev/null +++ b/tests/regression/configs/cyclohexane/config.yaml @@ -0,0 +1,12 @@ +--- + +run1: + force_file: ".testdata/cyclohexane/forces.frc" + top_traj_file: + - ".testdata/cyclohexane/molecules.top" + - ".testdata/cyclohexane/trajectory.crd" + selection_string: 'all' + start: 0 + end: 1 + step: 1 + file_format: 'MDCRD' diff --git a/tests/regression/configs/dna/config.yaml b/tests/regression/configs/dna/config.yaml new file mode 100644 index 00000000..854dc148 --- /dev/null +++ b/tests/regression/configs/dna/config.yaml @@ -0,0 +1,10 @@ +--- + +run1: + top_traj_file: + - ".testdata/dna/md_A4_dna.tpr" + - ".testdata/dna/md_A4_dna_xf.trr" + selection_string: 'all' + start: 0 + end: 1 + step: 1 diff --git a/tests/regression/configs/ethyl-acetate/config.yaml b/tests/regression/configs/ethyl-acetate/config.yaml new file mode 100644 index 00000000..84d53a1a --- /dev/null +++ b/tests/regression/configs/ethyl-acetate/config.yaml @@ -0,0 +1,12 @@ +--- + +run1: + force_file: ".testdata/ethyl-acetate/forces.frc" + top_traj_file: + - ".testdata/ethyl-acetate/molecules.top" + - ".testdata/ethyl-acetate/trajectory.crd" + selection_string: 'all' + start: 0 + end: 1 + step: 1 + file_format: 'MDCRD' diff --git a/tests/regression/configs/methane/config.yaml b/tests/regression/configs/methane/config.yaml new file mode 100644 index 00000000..c7bd75d5 --- /dev/null +++ b/tests/regression/configs/methane/config.yaml @@ -0,0 +1,13 @@ +--- + +run1: + force_file: ".testdata/methane/forces.frc" + top_traj_file: + - 
".testdata/methane/molecules.top" + - ".testdata/methane/trajectory.crd" + selection_string: all + start: 0 + end: 1 + step: 1 + file_format: MDCRD + temperature: 112.0 diff --git a/tests/regression/configs/methanol/config.yaml b/tests/regression/configs/methanol/config.yaml new file mode 100644 index 00000000..52ad6c77 --- /dev/null +++ b/tests/regression/configs/methanol/config.yaml @@ -0,0 +1,12 @@ +--- + +run1: + force_file: ".testdata/methanol/forces.frc" + top_traj_file: + - ".testdata/methanol/molecules.top" + - ".testdata/methanol/trajectory.crd" + selection_string: 'all' + start: 0 + end: 1 + step: 1 + file_format: 'MDCRD' diff --git a/tests/regression/configs/octonol/config.yaml b/tests/regression/configs/octonol/config.yaml new file mode 100644 index 00000000..29b31c06 --- /dev/null +++ b/tests/regression/configs/octonol/config.yaml @@ -0,0 +1,12 @@ +--- + +run1: + force_file: ".testdata/octonol/forces.frc" + top_traj_file: + - ".testdata/octonol/molecules.top" + - ".testdata/octonol/trajectory.crd" + selection_string: 'all' + start: 0 + end: 1 + step: 1 + file_format: 'MDCRD' diff --git a/tests/regression/conftest.py b/tests/regression/conftest.py new file mode 100644 index 00000000..6693f568 --- /dev/null +++ b/tests/regression/conftest.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import pytest + + +def pytest_addoption(parser: pytest.Parser) -> None: + parser.addoption( + "--run-slow", + action="store_true", + default=False, + help="Run slow regression tests (20-30+ minutes).", + ) + parser.addoption( + "--update-baselines", + action="store_true", + default=False, + help="Overwrite regression baselines with the newly produced outputs.", + ) + parser.addoption( + "--codeentropy-debug", + action="store_true", + default=False, + help="Print CodeEntropy stdout/stderr and paths for easier debugging.", + ) + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line("markers", "regression: end-to-end regression tests") + 
config.addinivalue_line("markers", "slow: long-running tests (20-30+ minutes)") + + +def pytest_collection_modifyitems( + config: pytest.Config, items: list[pytest.Item] +) -> None: + """ + Regression tests are collected and runnable by default. + + Only @pytest.mark.slow tests are skipped unless you pass --run-slow. + """ + if config.getoption("--run-slow"): + return + + skip_slow = pytest.mark.skip(reason="Skipped slow test (use --run-slow to run).") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) diff --git a/tests/regression/helpers.py b/tests/regression/helpers.py new file mode 100644 index 00000000..0404b967 --- /dev/null +++ b/tests/regression/helpers.py @@ -0,0 +1,407 @@ +from __future__ import annotations + +import json +import os +import subprocess +import tarfile +import urllib.request +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional + +import yaml + +DEFAULT_TESTDATA_BASE_URL = "https://www.ccpbiosim.ac.uk/file-store/codeentropy-testing" + + +@dataclass(frozen=True) +class RunResult: + """Holds outputs and metadata from a single CodeEntropy regression run. + + Attributes: + workdir: Working directory used to run CodeEntropy. + job_dir: The most recent job directory created by CodeEntropy. + output_json: Path to the JSON output produced by CodeEntropy. + payload: Parsed JSON payload. + stdout: Captured stdout from the CodeEntropy process. + stderr: Captured stderr from the CodeEntropy process. + """ + + workdir: Path + job_dir: Path + output_json: Path + payload: Dict[str, Any] + stdout: str + stderr: str + + +def _repo_root_from_this_file() -> Path: + """Return repository root inferred from this file location. + + Returns: + Repository root path. + """ + return Path(__file__).resolve().parents[2] + + +def _testdata_root() -> Path: + """Return the local on-disk cache root for regression input datasets. + + Returns: + Path to the local test data cache root. 
+ """ + return _repo_root_from_this_file() / ".testdata" + + +def _is_within_directory(base: Path, target: Path) -> bool: + """Check whether a target path is within a base directory. + + Args: + base: Base directory. + target: Target path. + + Returns: + True if target is within base, otherwise False. + """ + base = base.resolve() + try: + target = target.resolve() + return str(target).startswith(str(base) + os.sep) or target == base + except FileNotFoundError: + return _is_within_directory(base, target.parent) + + +def _safe_extract_tar_gz(tar_gz: Path, dest_dir: Path) -> None: + """Extract a .tar.gz file safely into dest_dir. + + This prevents path traversal by validating extracted member paths. + + Args: + tar_gz: Path to the tar.gz archive. + dest_dir: Destination directory. + """ + dest_dir.mkdir(parents=True, exist_ok=True) + with tarfile.open(tar_gz, "r:gz") as tf: + for member in tf.getmembers(): + member_path = dest_dir / member.name + if not _is_within_directory(dest_dir, member_path): + raise RuntimeError(f"Unsafe path in tarball: {member.name}") + tf.extractall(dest_dir) + + +def _download(url: str, dst: Path) -> None: + """Download a URL to a local file path. + + Args: + url: Source URL. + dst: Destination file path. + """ + dst.parent.mkdir(parents=True, exist_ok=True) + try: + with urllib.request.urlopen(url) as r, dst.open("wb") as f: + f.write(r.read()) + except Exception as e: + raise RuntimeError(f"Failed to download from {url}: {e}") from e + + +def ensure_testdata_for_system(system: str, *, required_paths: list[Path]) -> Path: + """Ensure the filestore dataset for a system exists locally. + + This downloads and extracts .tar.gz from the CCPBioSim HTTPS filestore + into /.testdata. The archive is expected to contain a top-level + '/' directory. + + Args: + system: System name (e.g., 'methane'). + required_paths: Absolute paths that must exist after extraction. + + Returns: + Path to the system directory under the local cache. 
+ + Raises: + RuntimeError: If download/extraction fails or required files remain missing. + """ + root = _testdata_root() + system_dir = root / system + tar_path = root / f"{system}.tar.gz" + url = f"{DEFAULT_TESTDATA_BASE_URL.rstrip('/')}/{system}.tar.gz" + + def all_required_exist() -> bool: + return all(p.exists() for p in required_paths) + + if required_paths and all_required_exist(): + return system_dir + + root.mkdir(parents=True, exist_ok=True) + _download(url, tar_path) + + if system_dir.exists(): + for p in sorted(system_dir.rglob("*"), reverse=True): + try: + if p.is_file() or p.is_symlink(): + p.unlink() + elif p.is_dir(): + p.rmdir() + except OSError: + pass + try: + system_dir.rmdir() + except OSError: + pass + + _safe_extract_tar_gz(tar_path, root) + + if not system_dir.exists(): + raise RuntimeError( + f"Extraction did not create expected folder {system_dir}. " + f"Tarball may not contain '{system}/'. url={url}" + ) + + if required_paths and not all_required_exist(): + found = [ + str(p.relative_to(system_dir)) for p in system_dir.rglob("*") if p.is_file() + ] + found.sort() + raise RuntimeError( + "Regression data extracted but required files are missing.\n" + f"system={system}\n" + f"expected:\n - " + "\n - ".join(str(p) for p in required_paths) + "\n" + f"found in {system_dir}:\n - " + + ("\n - ".join(found) if found else "") + + "\n" + f"url={url}\n" + ) + + return system_dir + + +def _find_latest_job_dir(workdir: Path) -> Path: + """Find the most recent CodeEntropy job directory in workdir. + + Args: + workdir: Working directory. + + Returns: + Path to the latest job directory. + + Raises: + FileNotFoundError: If no job directory exists. 
+ """ + job_dirs = sorted( + [p for p in workdir.iterdir() if p.is_dir() and p.name.startswith("job")] + ) + if not job_dirs: + raise FileNotFoundError(f"No job*** folder created in {workdir}") + return job_dirs[-1] + + +def _pick_output_json(job_dir: Path) -> Path: + """Pick the primary JSON output file from a CodeEntropy job directory. + + Args: + job_dir: CodeEntropy job directory. + + Returns: + Path to the chosen JSON output. + + Raises: + FileNotFoundError: If no JSON output is found. + """ + for name in ("output.json", "output_file.json"): + p = job_dir / name + if p.exists(): + return p + jsons = sorted(job_dir.glob("*.json")) + if not jsons: + raise FileNotFoundError(f"No JSON output found in job dir: {job_dir}") + return jsons[0] + + +def _resolve_path(value: Any, *, base_dir: Path) -> Optional[str]: + """Resolve a path-like config value to an absolute path string. + + Paths beginning with '.testdata/' are resolved relative to the repository root. + Other relative paths are resolved relative to base_dir. + + Args: + value: Path-like config value. + base_dir: Directory to resolve relative paths against. + + Returns: + Absolute path string or None. + """ + if value is None: + return None + s = str(value) + if not s: + return None + + s_norm = s.replace("\\", "/") + if s_norm.startswith(".testdata/"): + repo_root = _repo_root_from_this_file() + return str((repo_root / s_norm).resolve()) + + p = Path(s) + if p.is_absolute(): + return str(p) + return str((base_dir / p).resolve()) + + +def _resolve_path_list(value: Any, *, base_dir: Path) -> list[str]: + """Resolve a config value representing a path or list of paths. + + Args: + value: Path-like value or list of path-like values. + base_dir: Directory to resolve relative paths against. + + Returns: + List of absolute path strings. 
+ """ + if value is None: + return [] + if isinstance(value, (list, tuple)): + out: list[str] = [] + for v in value: + rp = _resolve_path(v, base_dir=base_dir) + if rp: + out.append(rp) + return out + rp = _resolve_path(value, base_dir=base_dir) + return [rp] if rp else [] + + +def _abspathify_config_paths( + config: Dict[str, Any], *, base_dir: Path +) -> Dict[str, Any]: + """Convert configured input paths into absolute paths. + + Args: + config: Parsed config mapping. + base_dir: Base directory for resolving relative paths. + + Returns: + A new config dict with resolved paths. + """ + path_keys = {"force_file"} + list_path_keys = {"top_traj_file"} + + out: Dict[str, Any] = {} + for run_name, run_cfg in config.items(): + if not isinstance(run_cfg, dict): + out[run_name] = run_cfg + continue + + run_cfg2 = dict(run_cfg) + for k in list(run_cfg2.keys()): + if k in path_keys: + run_cfg2[k] = _resolve_path(run_cfg2.get(k), base_dir=base_dir) + if k in list_path_keys: + run_cfg2[k] = _resolve_path_list(run_cfg2.get(k), base_dir=base_dir) + + out[run_name] = run_cfg2 + + return out + + +def _assert_inputs_exist(cooked: Dict[str, Any]) -> None: + """Assert that required input files referenced in cooked config exist.""" + run1 = cooked.get("run1") + if not isinstance(run1, dict): + return + + for p in run1.get("top_traj_file") or []: + if isinstance(p, str) and p: + assert Path(p).exists(), f"Missing input file: {p}" + + ff = run1.get("force_file") + if isinstance(ff, str) and ff.strip(): + assert Path(ff).exists(), f"Missing force file: {ff}" + + +def run_codeentropy_with_config(*, workdir: Path, config_src: Path) -> RunResult: + """Run CodeEntropy using a regression config file. + + This function loads the YAML config, resolves input paths, ensures required + dataset files exist by downloading from the filestore if needed, then runs + CodeEntropy and returns the parsed output JSON. + + Args: + workdir: Temporary working directory for running CodeEntropy. 
+ config_src: Path to the YAML regression config. + + Returns: + RunResult containing outputs and metadata. + + Raises: + RuntimeError: If CodeEntropy fails or required data cannot be fetched. + ValueError: If config does not parse as a dict or output JSON lacks expected + keys. + FileNotFoundError: If job output files cannot be found. + """ + workdir.mkdir(parents=True, exist_ok=True) + + raw = yaml.safe_load(config_src.read_text()) + if not isinstance(raw, dict): + raise ValueError( + f"Config must parse to a dict. Got {type(raw)} from {config_src}" + ) + + system = config_src.parent.name + cooked = _abspathify_config_paths(raw, base_dir=config_src.parent) + + required: list[Path] = [] + run1 = cooked.get("run1") + if isinstance(run1, dict): + ff = run1.get("force_file") + if isinstance(ff, str) and ff: + required.append(Path(ff)) + for p in run1.get("top_traj_file") or []: + if isinstance(p, str) and p: + required.append(Path(p)) + + if required: + ensure_testdata_for_system(system, required_paths=required) + + _assert_inputs_exist(cooked) + + (workdir / "config.yaml").write_text(yaml.safe_dump(cooked, sort_keys=False)) + + proc = subprocess.run( + ["python", "-m", "CodeEntropy"], + cwd=str(workdir), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + env={**os.environ}, + ) + + (workdir / "codeentropy_stdout.txt").write_text(proc.stdout or "") + (workdir / "codeentropy_stderr.txt").write_text(proc.stderr or "") + + if proc.returncode != 0: + raise RuntimeError( + "CodeEntropy regression run failed\n" + f"cwd={workdir}\n" + f"stdout saved to: {workdir / 'codeentropy_stdout.txt'}\n" + f"stderr saved to: {workdir / 'codeentropy_stderr.txt'}\n" + ) + + job_dir = _find_latest_job_dir(workdir) + out_json = _pick_output_json(job_dir) + payload = json.loads(out_json.read_text()) + + (workdir / "codeentropy_output.json").write_text(json.dumps(payload, indent=2)) + + if "groups" not in payload: + raise ValueError( + f"Regression output JSON did not contain 
'groups'. output={out_json}" + ) + + return RunResult( + workdir=workdir, + job_dir=job_dir, + output_json=out_json, + payload=payload, + stdout=proc.stdout or "", + stderr=proc.stderr or "", + ) diff --git a/tests/regression/test_regression.py b/tests/regression/test_regression.py new file mode 100644 index 00000000..d5ca7f10 --- /dev/null +++ b/tests/regression/test_regression.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict + +import numpy as np +import pytest + +from tests.regression.helpers import run_codeentropy_with_config + + +def _group_index(payload: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: + """Return the groups mapping from a regression payload. + + Args: + payload: Parsed JSON payload. + + Returns: + Mapping of group id to group data. + + Raises: + TypeError: If payload["groups"] is not a dict. + """ + groups = payload.get("groups", {}) + if not isinstance(groups, dict): + raise TypeError("payload['groups'] must be a dict") + return groups + + +def _compare_grouped( + *, + got_payload: Dict[str, Any], + baseline_payload: Dict[str, Any], + rtol: float, + atol: float, +) -> None: + """Compare grouped regression outputs against baseline values. + + Args: + got_payload: Newly produced payload. + baseline_payload: Baseline payload. + rtol: Relative tolerance. + atol: Absolute tolerance. + + Raises: + AssertionError: If any required group/component differs from baseline. 
+ """ + got_groups = _group_index(got_payload) + base_groups = _group_index(baseline_payload) + + missing_groups = sorted(set(base_groups.keys()) - set(got_groups.keys())) + assert not missing_groups, f"Missing groups in output: {missing_groups}" + + mismatches: list[str] = [] + + for gid, base_g in base_groups.items(): + got_g = got_groups[gid] + + base_components = base_g.get("components", {}) + got_components = got_g.get("components", {}) + + if not isinstance(base_components, dict) or not isinstance( + got_components, dict + ): + mismatches.append(f"group {gid}: components must be dicts") + continue + + missing_keys = sorted(set(base_components.keys()) - set(got_components.keys())) + if missing_keys: + mismatches.append(f"group {gid}: missing component keys: {missing_keys}") + continue + + for k, expected in base_components.items(): + actual = got_components[k] + try: + np.testing.assert_allclose( + float(actual), float(expected), rtol=rtol, atol=atol + ) + except AssertionError: + mismatches.append( + f"group {gid} component {k}: expected={expected} got={actual}" + ) + + if "total" in base_g: + try: + np.testing.assert_allclose( + float(got_g.get("total", 0.0)), + float(base_g["total"]), + rtol=rtol, + atol=atol, + ) + except AssertionError: + mismatches.append( + f"group {gid} total: expected={base_g['total']} " + f"got={got_g.get('total')}" + ) + + assert not mismatches, "Mismatches:\n" + "\n".join(" " + m for m in mismatches) + + +@pytest.mark.regression +@pytest.mark.parametrize( + "system", + [ + pytest.param("benzaldehyde", marks=pytest.mark.slow), + pytest.param("benzene", marks=pytest.mark.slow), + pytest.param("cyclohexane", marks=pytest.mark.slow), + "dna", + pytest.param("ethyl-acetate", marks=pytest.mark.slow), + "methane", + "methanol", + pytest.param("octonol", marks=pytest.mark.slow), + ], +) +def test_regression_matches_baseline( + tmp_path: Path, system: str, request: pytest.FixtureRequest +) -> None: + """Run a regression test for one system 
and compare to its baseline. + + Args: + tmp_path: Pytest-provided temporary directory. + system: System name parameter. + request: Pytest request fixture for reading CLI options. + """ + repo_root = Path(__file__).resolve().parents[2] + config_path = ( + repo_root / "tests" / "regression" / "configs" / system / "config.yaml" + ) + baseline_path = repo_root / "tests" / "regression" / "baselines" / f"{system}.json" + + assert config_path.exists(), f"Missing config: {config_path}" + assert baseline_path.exists(), f"Missing baseline: {baseline_path}" + + baseline_payload = json.loads(baseline_path.read_text()) + run = run_codeentropy_with_config(workdir=tmp_path, config_src=config_path) + + if request.config.getoption("--codeentropy-debug"): + print("\n[CodeEntropy regression debug]") + print("workdir:", run.workdir) + print("job_dir:", run.job_dir) + print("output_json:", run.output_json) + print("payload copy saved:", run.workdir / "codeentropy_output.json") + + if request.config.getoption("--update-baselines"): + baseline_path.write_text(json.dumps(run.payload, indent=2)) + pytest.skip(f"Baseline updated for {system}: {baseline_path}") + + _compare_grouped( + got_payload=run.payload, + baseline_payload=baseline_payload, + rtol=1e-9, + atol=1e-8, + ) diff --git a/tests/test_CodeEntropy/test_arg_config_manager.py b/tests/test_CodeEntropy/test_arg_config_manager.py deleted file mode 100644 index bf6d2202..00000000 --- a/tests/test_CodeEntropy/test_arg_config_manager.py +++ /dev/null @@ -1,601 +0,0 @@ -import argparse -import logging -import os -import unittest -from unittest.mock import MagicMock, mock_open, patch - -import tests.data as data -from CodeEntropy.config.arg_config_manager import ConfigManager -from CodeEntropy.main import main -from tests.test_CodeEntropy.test_base import BaseTestCase - - -class TestArgConfigManager(BaseTestCase): - """ - Unit tests for the ConfigManager. 
- """ - - def setUp(self): - super().setUp() - - self.test_data_dir = os.path.dirname(data.__file__) - self.config_file = os.path.join(self.test_dir, "config.yaml") - - # Create a mock config file - with patch("builtins.open", new_callable=mock_open) as mock_file: - self.setup_file(mock_file) - with open(self.config_file, "w") as f: - f.write(mock_file.return_value.read()) - - def list_data_files(self): - """ - List all files in the test data directory. - """ - return os.listdir(self.test_data_dir) - - def setup_file(self, mock_file): - """ - Mock the contents of a configuration file. - """ - mock_file.return_value = mock_open( - read_data="--- \n \nrun1:\n " - "top_traj_file: ['/path/to/tpr', '/path/to/trr']\n " - "selection_string: 'all'\n " - "start: 0\n " - "end: -1\n " - "step: 1\n " - "bin_width: 30\n " - "tempra: 298.0\n " - "verbose: False\n " - "thread: 1\n " - "output_file: 'output_file.json'\n " - "force_partitioning: 0.5\n " - "water_entropy: False" - ).return_value - - @patch("builtins.open") - @patch("glob.glob", return_value=["config.yaml"]) - def test_load_config(self, mock_glob, mock_file): - """ - Test loading a valid configuration file. 
- """ - # Setup the mock file content - self.setup_file(mock_file) - - arg_config = ConfigManager() - config = arg_config.load_config("/some/path") - - self.assertIn("run1", config) - self.assertEqual( - config["run1"]["top_traj_file"], ["/path/to/tpr", "/path/to/trr"] - ) - self.assertEqual(config["run1"]["selection_string"], "all") - self.assertEqual(config["run1"]["start"], 0) - self.assertEqual(config["run1"]["end"], -1) - self.assertEqual(config["run1"]["step"], 1) - self.assertEqual(config["run1"]["bin_width"], 30) - self.assertEqual(config["run1"]["tempra"], 298.0) - self.assertFalse(config["run1"]["verbose"]) - self.assertEqual(config["run1"]["thread"], 1) - self.assertEqual(config["run1"]["output_file"], "output_file.json") - self.assertEqual(config["run1"]["force_partitioning"], 0.5) - self.assertFalse(config["run1"]["water_entropy"]) - - @patch("glob.glob", return_value=[]) - def test_load_config_no_yaml_files(self, mock_glob): - arg_config = ConfigManager() - config = arg_config.load_config("/some/path") - self.assertEqual(config, {"run1": {}}) - - @patch("builtins.open", side_effect=FileNotFoundError) - @patch("glob.glob", return_value=["config.yaml"]) - def test_load_config_file_not_found(self, mock_glob, mock_open): - """ - Test loading a configuration file that exists but cannot be opened. - Should return default config instead of raising an error. - """ - arg_config = ConfigManager() - config = arg_config.load_config("/some/path") - self.assertEqual(config, {"run1": {}}) - - @patch.object(ConfigManager, "load_config", return_value=None) - def test_no_cli_no_yaml(self, mock_load_config): - """Test behavior when no CLI arguments and no YAML file are provided.""" - with self.assertRaises(SystemExit) as context: - main() - self.assertEqual(context.exception.code, 1) - - def test_invalid_run_config_type(self): - """ - Test that passing an invalid type for run_config raises a TypeError. 
- """ - arg_config = ConfigManager() - args = MagicMock() - invalid_configs = ["string", 123, 3.14, ["list"], {("tuple_key",): "value"}] - - for invalid in invalid_configs: - with self.assertRaises(TypeError): - arg_config.merge_configs(args, invalid) - - @patch( - "argparse.ArgumentParser.parse_args", - return_value=MagicMock( - top_traj_file=["/path/to/tpr", "/path/to/trr"], - selection_string="all", - start=0, - end=-1, - step=1, - bin_width=30, - tempra=298.0, - verbose=False, - thread=1, - output_file="output_file.json", - force_partitioning=0.5, - water_entropy=False, - ), - ) - def test_setup_argparse(self, mock_args): - """ - Test parsing command-line arguments. - """ - arg_config = ConfigManager() - parser = arg_config.setup_argparse() - args = parser.parse_args() - self.assertEqual(args.top_traj_file, ["/path/to/tpr", "/path/to/trr"]) - self.assertEqual(args.selection_string, "all") - - @patch( - "argparse.ArgumentParser.parse_args", - return_value=MagicMock( - top_traj_file=["/path/to/tpr", "/path/to/trr"], - start=10, - water_entropy=False, - ), - ) - def test_setup_argparse_false_boolean(self, mock_args): - """ - Test that non-boolean arguments are parsed correctly. 
- """ - arg_config = ConfigManager() - parser = arg_config.setup_argparse() - args = parser.parse_args() - - self.assertEqual(args.top_traj_file, ["/path/to/tpr", "/path/to/trr"]) - self.assertEqual(args.start, 10) - self.assertFalse(args.water_entropy) - - def test_str2bool_true_variants(self): - """Test that various string representations of True are correctly parsed.""" - arg_config = ConfigManager() - - self.assertTrue(arg_config.str2bool("true")) - self.assertTrue(arg_config.str2bool("True")) - self.assertTrue(arg_config.str2bool("t")) - self.assertTrue(arg_config.str2bool("yes")) - self.assertTrue(arg_config.str2bool("1")) - - def test_str2bool_false_variants(self): - """Test that various string representations of False are correctly parsed.""" - arg_config = ConfigManager() - - self.assertFalse(arg_config.str2bool("false")) - self.assertFalse(arg_config.str2bool("False")) - self.assertFalse(arg_config.str2bool("f")) - self.assertFalse(arg_config.str2bool("no")) - self.assertFalse(arg_config.str2bool("0")) - - def test_str2bool_boolean_passthrough(self): - """Test that boolean values passed directly are returned unchanged.""" - arg_config = ConfigManager() - - self.assertTrue(arg_config.str2bool(True)) - self.assertFalse(arg_config.str2bool(False)) - - def test_str2bool_invalid_input(self): - """Test that invalid string inputs raise an ArgumentTypeError.""" - arg_config = ConfigManager() - - with self.assertRaises(Exception) as context: - arg_config.str2bool("maybe") - self.assertIn("Boolean value expected", str(context.exception)) - - def test_str2bool_empty_string(self): - """Test that an empty string raises an ArgumentTypeError.""" - arg_config = ConfigManager() - - with self.assertRaises(Exception) as context: - arg_config.str2bool("") - self.assertIn("Boolean value expected", str(context.exception)) - - def test_str2bool_unexpected_number(self): - """Test that unexpected numeric strings raise an ArgumentTypeError.""" - arg_config = ConfigManager() - - 
with self.assertRaises(Exception) as context: - arg_config.str2bool("2") - self.assertIn("Boolean value expected", str(context.exception)) - - def test_cli_overrides_defaults(self): - """ - Test if CLI parameters override default values. - """ - arg_config = ConfigManager() - parser = arg_config.setup_argparse() - args = parser.parse_args( - ["--top_traj_file", "/cli/path", "--selection_string", "cli_value"] - ) - self.assertEqual(args.top_traj_file, ["/cli/path"]) - self.assertEqual(args.selection_string, "cli_value") - - def test_cli_overrides_yaml(self): - """ - Test if CLI parameters override YAML parameters correctly. - """ - arg_config = ConfigManager() - parser = arg_config.setup_argparse() - args = parser.parse_args( - ["--top_traj_file", "/cli/path", "--selection_string", "cli_value"] - ) - run_config = {"top_traj_file": ["/yaml/path"], "selection_string": "yaml_value"} - merged_args = arg_config.merge_configs(args, run_config) - self.assertEqual(merged_args.top_traj_file, ["/cli/path"]) - self.assertEqual(merged_args.selection_string, "cli_value") - - def test_cli_overrides_yaml_with_multiple_values(self): - """ - Ensures that CLI arguments override YAML when multiple values are provided in - YAML. - """ - arg_config = ConfigManager() - yaml_config = {"top_traj_file": ["/yaml/path1", "/yaml/path2"]} - args = argparse.Namespace(top_traj_file=["/cli/path"]) - - merged_args = arg_config.merge_configs(args, yaml_config) - - self.assertEqual(merged_args.top_traj_file, ["/cli/path"]) - - def test_yaml_overrides_defaults(self): - """ - Test if YAML parameters override default values. 
- """ - run_config = {"top_traj_file": ["/yaml/path"], "selection_string": "yaml_value"} - args = argparse.Namespace() - arg_config = ConfigManager() - merged_args = arg_config.merge_configs(args, run_config) - self.assertEqual(merged_args.top_traj_file, ["/yaml/path"]) - self.assertEqual(merged_args.selection_string, "yaml_value") - - def test_yaml_does_not_override_cli_if_set(self): - """ - Ensure YAML does not override CLI arguments that are set. - """ - arg_config = ConfigManager() - - yaml_config = {"bin_width": 50} - args = argparse.Namespace(bin_width=100) - - merged_args = arg_config.merge_configs(args, yaml_config) - - self.assertEqual(merged_args.bin_width, 100) - - def test_yaml_overrides_defaults_when_no_cli(self): - """ - Test if YAML parameters override default values when no CLI input is given. - """ - arg_config = ConfigManager() - - yaml_config = { - "top_traj_file": ["/yaml/path"], - "bin_width": 50, - } - - args = argparse.Namespace() - - merged_args = arg_config.merge_configs(args, yaml_config) - - self.assertEqual(merged_args.top_traj_file, ["/yaml/path"]) - self.assertEqual(merged_args.bin_width, 50) - - def test_yaml_none_does_not_override_defaults(self): - """ - Ensures that YAML values set to `None` do not override existing CLI values. - """ - arg_config = ConfigManager() - yaml_config = {"bin_width": None} - args = argparse.Namespace(bin_width=100) - - merged_args = arg_config.merge_configs(args, yaml_config) - - self.assertEqual(merged_args.bin_width, 100) - - def test_hierarchy_cli_yaml_defaults(self): - """ - Test if CLI arguments override YAML, and YAML overrides defaults. 
- """ - arg_config = ConfigManager() - - yaml_config = { - "top_traj_file": ["/yaml/path", "/yaml/path"], - "bin_width": "50", - } - - args = argparse.Namespace( - top_traj_file=["/cli/path", "/cli/path"], bin_width=100 - ) - - merged_args = arg_config.merge_configs(args, yaml_config) - - self.assertEqual(merged_args.top_traj_file, ["/cli/path", "/cli/path"]) - self.assertEqual(merged_args.bin_width, 100) - - def test_merge_configs(self): - """ - Test merging default arguments with a run configuration. - """ - arg_config = ConfigManager() - args = MagicMock( - top_traj_file=None, - selection_string=None, - start=None, - end=None, - step=None, - bin_width=None, - tempra=None, - verbose=None, - thread=None, - output_file=None, - force_partitioning=None, - water_entropy=None, - ) - run_config = { - "top_traj_file": ["/path/to/tpr", "/path/to/trr"], - "selection_string": "all", - "start": 0, - "end": -1, - "step": 1, - "bin_width": 30, - "tempra": 298.0, - "verbose": False, - "thread": 1, - "output_file": "output_file.json", - "force_partitioning": 0.5, - "water_entropy": False, - } - merged_args = arg_config.merge_configs(args, run_config) - self.assertEqual(merged_args.top_traj_file, ["/path/to/tpr", "/path/to/trr"]) - self.assertEqual(merged_args.selection_string, "all") - - def test_merge_with_none_yaml(self): - """ - Ensure merging still works if no YAML config is provided. - """ - arg_config = ConfigManager() - - args = argparse.Namespace(top_traj_file=["/cli/path"]) - yaml_config = None - - merged_args = arg_config.merge_configs(args, yaml_config) - - self.assertEqual(merged_args.top_traj_file, ["/cli/path"]) - - @patch("CodeEntropy.config.arg_config_manager.logger") - def test_merge_configs_sets_debug_logging(self, mock_logger): - """ - Ensure logging is set to DEBUG when verbose=True. 
- """ - arg_config = ConfigManager() - args = argparse.Namespace(verbose=True) - for key in arg_config.arg_map: - if not hasattr(args, key): - setattr(args, key, None) - - # Mock logger handlers - mock_handler = MagicMock() - mock_logger.handlers = [mock_handler] - - arg_config.merge_configs(args, {}) - - mock_logger.setLevel.assert_called_with(logging.DEBUG) - mock_handler.setLevel.assert_called_with(logging.DEBUG) - mock_logger.debug.assert_called_with( - "Verbose mode enabled. Logger set to DEBUG level." - ) - - @patch("CodeEntropy.config.arg_config_manager.logger") - def test_merge_configs_sets_info_logging(self, mock_logger): - """ - Ensure logging is set to INFO when verbose=False. - """ - arg_config = ConfigManager() - args = argparse.Namespace(verbose=False) - for key in arg_config.arg_map: - if not hasattr(args, key): - setattr(args, key, None) - - # Mock logger handlers - mock_handler = MagicMock() - mock_logger.handlers = [mock_handler] - - arg_config.merge_configs(args, {}) - - mock_logger.setLevel.assert_called_with(logging.INFO) - mock_handler.setLevel.assert_called_with(logging.INFO) - - @patch("argparse.ArgumentParser.parse_args") - def test_default_values(self, mock_parse_args): - """ - Test if argument parser assigns default values correctly. - """ - arg_config = ConfigManager() - mock_parse_args.return_value = MagicMock( - top_traj_file=["example.top", "example.traj"] - ) - parser = arg_config.setup_argparse() - args = parser.parse_args() - self.assertEqual(args.top_traj_file, ["example.top", "example.traj"]) - - def test_fallback_to_defaults(self): - """ - Ensure arguments fall back to defaults if neither YAML nor CLI provides them. 
- """ - arg_config = ConfigManager() - - yaml_config = {} - args = argparse.Namespace() - - merged_args = arg_config.merge_configs(args, yaml_config) - - self.assertEqual(merged_args.step, 1) - self.assertEqual(merged_args.end, -1) - - @patch( - "argparse.ArgumentParser.parse_args", return_value=MagicMock(top_traj_file=None) - ) - def test_missing_required_arguments(self, mock_args): - """ - Test behavior when required arguments are missing. - """ - arg_config = ConfigManager() - parser = arg_config.setup_argparse() - args = parser.parse_args() - with self.assertRaises(ValueError): - if not args.top_traj_file: - raise ValueError( - "The 'top_traj_file' argument is required but not provided." - ) - - def test_invalid_argument_type(self): - """ - Test handling of invalid argument types. - """ - arg_config = ConfigManager() - parser = arg_config.setup_argparse() - with self.assertRaises(SystemExit): - parser.parse_args(["--start", "invalid"]) - - @patch( - "argparse.ArgumentParser.parse_args", return_value=MagicMock(start=-1, end=-10) - ) - def test_edge_case_argument_values(self, mock_args): - """ - Test parsing of edge case values. - """ - arg_config = ConfigManager() - parser = arg_config.setup_argparse() - args = parser.parse_args() - self.assertEqual(args.start, -1) - self.assertEqual(args.end, -10) - - @patch("builtins.open", new_callable=mock_open, read_data="--- \n") - @patch("glob.glob", return_value=["config.yaml"]) - def test_empty_yaml_config(self, mock_glob, mock_file): - """ - Test behavior when an empty YAML file is provided. - Should return default config {'run1': {}}. 
- """ - arg_config = ConfigManager() - config = arg_config.load_config("/some/path") - - self.assertIsInstance(config, dict) - self.assertEqual(config, {"run1": {}}) - - def test_input_parameters_validation_all_valid(self): - """Test that input_parameters_validation passes with all valid inputs.""" - manager = ConfigManager() - u = MagicMock() - u.trajectory = [0] * 100 - - args = MagicMock( - start=10, - end=90, - step=1, - bin_width=30, - temperature=298.0, - force_partitioning=0.5, - ) - - with patch.dict( - "CodeEntropy.config.arg_config_manager.arg_map", - {"force_partitioning": {"default": 0.5}}, - ): - manager.input_parameters_validation(u, args) - - def test_check_input_start_valid(self): - """Test that a valid 'start' value does not raise an error.""" - args = MagicMock(start=50) - u = MagicMock() - u.trajectory = [0] * 100 - ConfigManager()._check_input_start(u, args) - - def test_check_input_start_invalid(self): - """Test that an invalid 'start' value raises a ValueError.""" - args = MagicMock(start=150) - u = MagicMock() - u.trajectory = [0] * 100 - with self.assertRaises(ValueError): - ConfigManager()._check_input_start(u, args) - - def test_check_input_end_valid(self): - """Test that a valid 'end' value does not raise an error.""" - args = MagicMock(end=100) - u = MagicMock() - u.trajectory = [0] * 100 - ConfigManager()._check_input_end(u, args) - - def test_check_input_end_invalid(self): - """Test that an 'end' value exceeding trajectory length raises a ValueError.""" - args = MagicMock(end=101) - u = MagicMock() - u.trajectory = [0] * 100 - with self.assertRaises(ValueError): - ConfigManager()._check_input_end(u, args) - - @patch("CodeEntropy.config.arg_config_manager.logger") - def test_check_input_step_negative(self, mock_logger): - """Test that a negative 'step' value triggers a warning.""" - args = MagicMock(step=-1) - ConfigManager()._check_input_step(args) - mock_logger.warning.assert_called_once() - - def 
test_check_input_bin_width_valid(self): - """Test that a valid 'bin_width' value does not raise an error.""" - args = MagicMock(bin_width=180) - ConfigManager()._check_input_bin_width(args) - - def test_check_input_bin_width_invalid_low(self): - """Test that a negative 'bin_width' value raises a ValueError.""" - args = MagicMock(bin_width=-10) - with self.assertRaises(ValueError): - ConfigManager()._check_input_bin_width(args) - - def test_check_input_bin_width_invalid_high(self): - """Test that a 'bin_width' value above 360 raises a ValueError.""" - args = MagicMock(bin_width=400) - with self.assertRaises(ValueError): - ConfigManager()._check_input_bin_width(args) - - def test_check_input_temperature_valid(self): - """Test that a valid 'temperature' value does not raise an error.""" - args = MagicMock(temperature=298.0) - ConfigManager()._check_input_temperature(args) - - def test_check_input_temperature_invalid(self): - """Test that a negative 'temperature' value raises a ValueError.""" - args = MagicMock(temperature=-5) - with self.assertRaises(ValueError): - ConfigManager()._check_input_temperature(args) - - @patch("CodeEntropy.config.arg_config_manager.logger") - def test_check_input_force_partitioning_warning(self, mock_logger): - """Test that a non-default 'force_partitioning' value triggers a warning.""" - args = MagicMock(force_partitioning=0.7) - with patch.dict( - "CodeEntropy.config.arg_config_manager.arg_map", - {"force_partitioning": {"default": 0.5}}, - ): - ConfigManager()._check_input_force_partitioning(args) - mock_logger.warning.assert_called_once() - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_CodeEntropy/test_axes.py b/tests/test_CodeEntropy/test_axes.py deleted file mode 100644 index b5f41aa6..00000000 --- a/tests/test_CodeEntropy/test_axes.py +++ /dev/null @@ -1,476 +0,0 @@ -from unittest.mock import MagicMock, patch - -import numpy as np - -from CodeEntropy.axes import AxesManager -from 
class TestAxesManager(BaseTestCase):
    """
    Unit tests for ``AxesManager``.

    Each test isolates one method of ``AxesManager`` by replacing its
    collaborators (MDAnalysis containers, sibling helper methods) with
    ``MagicMock`` objects, then asserting on the exact selection strings,
    call order and returned arrays.  NOTE(review): the ``side_effect`` lists
    encode the expected *order* of ``select_atoms`` calls — keep them in sync
    with the implementation.
    """

    def setUp(self):
        # Delegates to BaseTestCase: temp working dir + logs folder per test.
        super().setUp()

    def test_get_residue_axes_no_bonds_custom_axes_branch(self):
        """
        Tests that: atom_set empty (len == 0) -> custom axes branch
        """
        axes_manager = AxesManager()
        data_container = MagicMock()

        # Empty bonded-neighbour selection forces the custom-axes branch.
        atom_set = MagicMock()
        atom_set.__len__.return_value = 0

        residue = MagicMock()

        # First select_atoms call returns the (empty) bonded set,
        # second returns the residue itself.
        data_container.select_atoms.side_effect = [atom_set, residue]

        center = np.array([1.0, 2.0, 3.0])
        residue.atoms.center_of_mass.return_value = center

        UAs = MagicMock()
        UAs.positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
        residue.select_atoms.return_value = UAs

        UA_masses = [12.0, 14.0]
        axes_manager.get_UA_masses = MagicMock(return_value=UA_masses)

        moi_tensor = np.eye(3) * 5.0
        axes_manager.get_moment_of_inertia_tensor = MagicMock(return_value=moi_tensor)

        rot_axes = np.eye(3)
        moi = np.array([10.0, 9.0, 8.0])
        axes_manager.get_custom_principal_axes = MagicMock(return_value=(rot_axes, moi))

        trans_axes_out, rot_axes_out, center_out, moi_out = (
            axes_manager.get_residue_axes(
                data_container=data_container,
                index=5,
            )
        )

        calls = data_container.select_atoms.call_args_list
        assert len(calls) == 2
        assert calls[0].args[0] == "(resindex 4 or resindex 6) and bonded resid 5"
        assert calls[1].args[0] == "resindex 5"

        residue.select_atoms.assert_called_once_with("mass 2 to 999")

        axes_manager.get_UA_masses.assert_called_once_with(residue)

        axes_manager.get_moment_of_inertia_tensor.assert_called_once()
        tensor_args, tensor_kwargs = axes_manager.get_moment_of_inertia_tensor.call_args
        np.testing.assert_array_equal(tensor_args[0], center)
        np.testing.assert_array_equal(tensor_args[1], UAs.positions)
        assert tensor_args[2] == UA_masses
        assert tensor_kwargs == {}

        axes_manager.get_custom_principal_axes.assert_called_once_with(moi_tensor)

        # In the custom branch, translational and rotational axes coincide.
        np.testing.assert_array_equal(trans_axes_out, rot_axes)
        np.testing.assert_array_equal(rot_axes_out, rot_axes)
        np.testing.assert_array_equal(center_out, center)
        np.testing.assert_array_equal(moi_out, moi)

    def test_get_residue_axes_bonded_default_axes_branch(self):
        """
        Tests that: atom_set non-empty (len != 0) -> default/bonded branch
        """
        axes_manager = AxesManager()
        data_container = MagicMock()
        data_container.atoms = MagicMock()

        # Non-empty bonded set selects the default (vanilla) branch.
        atom_set = MagicMock()
        atom_set.__len__.return_value = 2

        residue = MagicMock()
        data_container.select_atoms.side_effect = [atom_set, residue]

        trans_axes_expected = np.eye(3) * 2
        data_container.atoms.principal_axes.return_value = trans_axes_expected

        rot_axes_expected = np.eye(3) * 3
        residue.principal_axes.return_value = rot_axes_expected

        moi_tensor = np.eye(3)
        residue.moment_of_inertia.return_value = moi_tensor

        center_expected = np.array([9.0, 8.0, 7.0])
        residue.atoms.center_of_mass.return_value = center_expected
        residue.center_of_mass.return_value = center_expected

        # make_whole is patched out (no real PBC unwrapping on mocks);
        # get_vanilla_axes stubbed to a fixed (axes, moi) pair.
        with (
            patch("CodeEntropy.axes.make_whole", autospec=True),
            patch.object(
                AxesManager,
                "get_vanilla_axes",
                return_value=(rot_axes_expected, np.array([3.0, 2.0, 1.0])),
            ),
        ):
            trans_axes_out, rot_axes_out, center_out, moi_out = (
                axes_manager.get_residue_axes(
                    data_container=data_container,
                    index=5,
                )
            )

        np.testing.assert_allclose(trans_axes_out, trans_axes_expected)
        np.testing.assert_allclose(rot_axes_out, rot_axes_expected)
        np.testing.assert_allclose(center_out, center_expected)
        np.testing.assert_allclose(moi_out, np.array([3.0, 2.0, 1.0]))

    @patch("CodeEntropy.axes.make_whole", autospec=True)
    def test_get_UA_axes_returns_expected_outputs(self, mock_make_whole):
        """
        Tests that: `get_UA_axes` returns expected UA axes.
        """
        axes = AxesManager()

        dc = MagicMock()
        dc.atoms = MagicMock()
        dc.dimensions = np.array([1.0, 2.0, 3.0, 90.0, 90.0, 90.0])
        dc.atoms.center_of_mass.return_value = np.array([0.0, 0.0, 0.0])

        a0 = MagicMock()
        a0.index = 7
        a1 = MagicMock()
        a1.index = 9

        heavy_atoms = MagicMock()
        heavy_atoms.__len__.return_value = 2
        heavy_atoms.__iter__.return_value = iter([a0, a1])
        heavy_atoms.positions = np.array([[9.9, 8.8, 7.7], [1.1, 2.2, 3.3]])

        # Same heavy-atom group is returned for both selections.
        dc.select_atoms.side_effect = [heavy_atoms, heavy_atoms]

        axes.get_UA_masses = MagicMock(return_value=[1.0, 1.0])
        axes.get_moment_of_inertia_tensor = MagicMock(return_value=np.eye(3))

        trans_axes_expected = np.eye(3)
        axes.get_custom_principal_axes = MagicMock(
            return_value=(trans_axes_expected, np.array([1.0, 1.0, 1.0]))
        )

        rot_axes_expected = np.eye(3) * 2
        moi_expected = np.array([3.0, 2.0, 1.0])
        axes.get_bonded_axes = MagicMock(return_value=(rot_axes_expected, moi_expected))

        trans_axes, rot_axes, center, moi = axes.get_UA_axes(dc, index=1)

        np.testing.assert_array_equal(trans_axes, trans_axes_expected)
        np.testing.assert_array_equal(rot_axes, rot_axes_expected)
        # Center for index=1 is the first heavy atom's position.
        np.testing.assert_array_equal(center, heavy_atoms.positions[0])
        np.testing.assert_array_equal(moi, moi_expected)

        calls = [c.args[0] for c in dc.select_atoms.call_args_list]
        assert calls[0] == "prop mass > 1.1"
        assert calls[1] == "index 9"

    def test_get_bonded_axes_returns_none_for_light_atom(self):
        """
        Tests that: bonded axes return none for light atoms
        """
        axes = AxesManager()

        # Hydrogen-mass atom: no united-atom axes should be computed.
        atom = MagicMock()
        atom.mass = 1.0
        system = MagicMock()

        out = axes.get_bonded_axes(
            system=system, atom=atom, dimensions=np.array([1.0, 2.0, 3.0])
        )
        assert out is None

    def test_get_bonded_axes_case2_one_heavy_zero_light(self):
        """
        Tests that: bonded return one heavy and zero light atoms
        """
        axes = AxesManager()

        system = MagicMock()
        atom = MagicMock()
        atom.mass = 12.0
        atom.index = 0
        atom.position = np.array([0.0, 0.0, 0.0])

        heavy0 = MagicMock()
        heavy0.position = np.array([1.0, 0.0, 0.0])

        heavy_bonded = [heavy0]
        light_bonded = []

        axes.find_bonded_atoms = MagicMock(return_value=(heavy_bonded, light_bonded))

        custom_axes = np.eye(3)
        axes.get_custom_axes = MagicMock(return_value=custom_axes)

        moi = np.array([1.0, 2.0, 3.0])
        axes.get_custom_moment_of_inertia = MagicMock(return_value=moi)

        flipped_axes = np.eye(3) * 2
        axes.get_flipped_axes = MagicMock(return_value=flipped_axes)

        out_axes, out_moi = axes.get_bonded_axes(
            system, atom, np.array([10.0, 10.0, 10.0])
        )

        np.testing.assert_array_equal(out_axes, flipped_axes)
        np.testing.assert_array_equal(out_moi, moi)

        axes.get_custom_axes.assert_called_once()
        args, _ = axes.get_custom_axes.call_args
        np.testing.assert_array_equal(args[0], atom.position)
        assert len(args[1]) == 1
        np.testing.assert_array_equal(args[1][0], heavy0.position)
        # With no light atoms, the third reference point is the origin.
        np.testing.assert_array_equal(args[2], np.zeros(3))
        np.testing.assert_array_equal(args[3], np.array([10.0, 10.0, 10.0]))

    def test_get_bonded_axes_case3_one_heavy_with_light(self):
        """
        Tests that: bonded axes return one heavy with one light atom
        """
        axes = AxesManager()

        system = MagicMock()
        atom = MagicMock()
        atom.mass = 12.0
        atom.index = 0
        atom.position = np.array([0.0, 0.0, 0.0])

        heavy0 = MagicMock()
        heavy0.position = np.array([1.0, 0.0, 0.0])

        light0 = MagicMock()
        light0.position = np.array([0.0, 1.0, 0.0])

        heavy_bonded = [heavy0]
        light_bonded = [light0]

        axes.find_bonded_atoms = MagicMock(return_value=(heavy_bonded, light_bonded))

        custom_axes = np.eye(3)
        axes.get_custom_axes = MagicMock(return_value=custom_axes)
        axes.get_custom_moment_of_inertia = MagicMock(
            return_value=np.array([1.0, 1.0, 1.0])
        )
        axes.get_flipped_axes = MagicMock(return_value=custom_axes)

        axes.get_bonded_axes(system, atom, np.array([10.0, 10.0, 10.0]))

        axes.get_custom_axes.assert_called_once()
        args, _ = axes.get_custom_axes.call_args

        # With a light atom present, it supplies the third reference point.
        np.testing.assert_array_equal(args[2], light0.position)

    def test_get_bonded_axes_case5_two_or_more_heavy(self):
        """
        Tests that: bonded axes return two or more heavy atoms
        """
        axes = AxesManager()

        system = MagicMock()
        atom = MagicMock()
        atom.mass = 12.0
        atom.index = 0
        atom.position = np.array([0.0, 0.0, 0.0])

        heavy0 = MagicMock()
        heavy0.position = np.array([1.0, 0.0, 0.0])
        heavy1 = MagicMock()
        heavy1.position = np.array([0.0, 1.0, 0.0])

        heavy_bonded = MagicMock()
        heavy_bonded.__len__.return_value = 2
        heavy_bonded.positions = np.array([heavy0.position, heavy1.position])

        heavy_bonded.__getitem__.side_effect = lambda i: [heavy0, heavy1][i]

        light_bonded = []

        axes.find_bonded_atoms = MagicMock(return_value=(heavy_bonded, light_bonded))

        custom_axes = np.eye(3)
        axes.get_custom_axes = MagicMock(return_value=custom_axes)
        axes.get_custom_moment_of_inertia = MagicMock(
            return_value=np.array([9.0, 9.0, 9.0])
        )
        axes.get_flipped_axes = MagicMock(return_value=custom_axes)

        axes.get_bonded_axes(system, atom, np.array([10.0, 10.0, 10.0]))

        axes.get_custom_axes.assert_called_once()
        args, _ = axes.get_custom_axes.call_args

        # With >= 2 heavy neighbours, all positions are passed and the
        # second heavy atom supplies the third reference point.
        np.testing.assert_array_equal(args[1], heavy_bonded.positions)
        np.testing.assert_array_equal(args[2], heavy1.position)

    def test_find_bonded_atoms_splits_heavy_and_h(self):
        """
        Tests that: Bonded atoms split into heavy and hydrogen.
        """
        axes = AxesManager()

        system = MagicMock()
        bonded = MagicMock()
        heavy = MagicMock()
        hydrogens = MagicMock()

        system.select_atoms.return_value = bonded
        bonded.select_atoms.side_effect = [heavy, hydrogens]

        out_heavy, out_h = axes.find_bonded_atoms(5, system)

        system.select_atoms.assert_called_once_with("bonded index 5")
        assert bonded.select_atoms.call_args_list[0].args[0] == "mass 2 to 999"
        assert bonded.select_atoms.call_args_list[1].args[0] == "mass 1 to 1.1"
        assert out_heavy is heavy
        assert out_h is hydrogens

    def test_get_vector_wraps_pbc(self):
        """
        Tests that: The vector wraps across periodic boundary.
        """
        axes = AxesManager()

        # 9 -> 1 in a box of 10 should wrap to a displacement of 2, not -8.
        a = np.array([9.0, 0.0, 0.0])
        b = np.array([1.0, 0.0, 0.0])
        dims = np.array([10.0, 10.0, 10.0])

        out = axes.get_vector(a, b, dims)
        np.testing.assert_array_equal(out, np.array([2.0, 0.0, 0.0]))

    def test_get_custom_axes_returns_unit_axes(self):
        """
        Tests that: `get_axes` returns normalized 3x3 axes.
        """
        axes = AxesManager()

        a = np.zeros(3)
        b_list = [np.array([1.0, 0.0, 0.0])]
        c = np.array([0.0, 1.0, 0.0])
        dims = np.array([100.0, 100.0, 100.0])

        out = axes.get_custom_axes(a, b_list, c, dims)

        assert out.shape == (3, 3)
        # Every row of the axes matrix must be a unit vector.
        norms = np.linalg.norm(out, axis=1)
        np.testing.assert_allclose(norms, np.ones(3))

    def test_get_custom_axes_uses_bc_vector_when_multiple_heavy_atoms(self):
        """
        Tests that: `get_custom_axes` uses c → b_list[0] vector when b_list has
        ≥ 2 atoms.
        """
        axes = AxesManager()

        a = np.zeros(3)
        b0 = np.array([1.0, 0.0, 0.0])
        b1 = np.array([0.0, 1.0, 0.0])
        b_list = [b0, b1]
        c = np.array([0.0, 0.0, 1.0])
        dimensions = np.array([10.0, 10.0, 10.0])

        # Track calls to get_vector
        axes.get_vector = MagicMock(return_value=np.array([1.0, 0.0, 0.0]))

        axes.get_custom_axes(a, b_list, c, dimensions)

        # get_vector should be called
        calls = axes.get_vector.call_args_list

        # Last call must be (c, b_list[0], dimensions)
        last_args = calls[-1].args
        np.testing.assert_array_equal(last_args[0], c)
        np.testing.assert_array_equal(last_args[1], b0)
        np.testing.assert_array_equal(last_args[2], dimensions)

    def test_get_custom_moment_of_inertia_len2_zeroed(self):
        """
        Tests that: `get_custom_moment_of_inertia` zeroes one MOI component for
        two-atom UA.
        """
        axes = AxesManager()

        UA = MagicMock()
        UA.positions = np.array([[1, 0, 0], [0, 1, 0]])
        UA.masses = np.array([12.0, 1.0])
        UA.__len__.return_value = 2

        dimensions = np.array([100.0, 100.0, 100.0])

        moi = axes.get_custom_moment_of_inertia(UA, np.eye(3), np.zeros(3), dimensions)

        assert moi.shape == (3,)
        # A diatomic has no inertia about its bond axis -> one component ~0.
        assert np.any(np.isclose(moi, 0.0))

    def test_get_flipped_axes_flips_negative_dot(self):
        """
        Tests that: `get_flipped_axes` flips axis when dot product is negative.
        """
        axes = AxesManager()

        UA = MagicMock()
        atom0 = MagicMock()
        atom0.position = np.zeros(3)
        UA.__getitem__.return_value = atom0

        # Reference vector anti-parallel to the first axis row.
        axes.get_vector = MagicMock(return_value=np.array([-1.0, 0.0, 0.0]))

        custom_axes = np.eye(3)
        out = axes.get_flipped_axes(
            UA, custom_axes, np.zeros(3), np.array([10, 10, 10])
        )

        np.testing.assert_array_equal(out[0], np.array([-1.0, 0.0, 0.0]))

    def test_get_moment_of_inertia_tensor_simple(self):
        """
        Tests that: `get_moment_of_inertia` Computes inertia tensor correctly.
        """
        axes = AxesManager()

        center = np.zeros(3)
        pos = np.array([[1, 0, 0], [0, 1, 0]])
        masses = np.array([1.0, 1.0])
        dimensions = np.array([100.0, 100.0, 100.0])

        tensor = axes.get_moment_of_inertia_tensor(center, pos, masses, dimensions)

        expected = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 2]])
        np.testing.assert_array_equal(tensor, expected)

    def test_get_custom_principal_axes_flips_z(self):
        """
        Tests that: `get_custom_principle_axes` ensures right-handed axes.
        """
        axes = AxesManager()

        # Force eig to return a left-handed eigenvector set (z row negated).
        with patch("CodeEntropy.axes.np.linalg.eig") as eig:
            eig.return_value = (
                np.array([3, 2, 1]),
                np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]),
            )

            axes_out, moi = axes.get_custom_principal_axes(np.eye(3))

        np.testing.assert_array_equal(axes_out[2], np.array([0, 0, 1]))

    def test_get_UA_masses_sums_hydrogens(self):
        """
        Tests that: `get_UA_masses` sums heavy atom with bonded hydrogens.
        """
        axes = AxesManager()

        heavy = MagicMock(mass=12.0, index=0)
        light = MagicMock(mass=1.0, index=1)

        mol = MagicMock()
        mol.__iter__.return_value = iter([heavy, light])

        bonded = MagicMock()
        H = MagicMock(mass=1.0)
        mol.select_atoms.return_value = bonded
        bonded.select_atoms.return_value = [H]

        out = axes.get_UA_masses(mol)

        # One united atom: carbon (12) + its one hydrogen (1) = 13.
        assert out == [13.0]
- """ - - def setUp(self): - """ - Prepare the test environment before each test method runs. - - Actions performed: - 1. Creates a unique temporary directory for the test. - 2. Creates a 'logs' subdirectory within the temp directory. - 3. Changes the current working directory to the temporary directory. - """ - # Create a unique temporary test directory - self.test_dir = tempfile.mkdtemp(prefix="CodeEntropy_") - self.logs_path = os.path.join(self.test_dir, "logs") - os.makedirs(self.logs_path, exist_ok=True) - - self._orig_dir = os.getcwd() - os.chdir(self.test_dir) - - def tearDown(self): - """ - Clean up the test environment after each test method runs. - - Actions performed: - 1. Restores the original working directory. - 2. Deletes the temporary test directory along with all its contents. - """ - os.chdir(self._orig_dir) - - if os.path.exists(self.test_dir): - shutil.rmtree(self.test_dir, ignore_errors=True) diff --git a/tests/test_CodeEntropy/test_data_logger.py b/tests/test_CodeEntropy/test_data_logger.py deleted file mode 100644 index 9d657f13..00000000 --- a/tests/test_CodeEntropy/test_data_logger.py +++ /dev/null @@ -1,136 +0,0 @@ -import json -import unittest - -import numpy as np -import pandas as pd - -from CodeEntropy.config.data_logger import DataLogger -from CodeEntropy.config.logging_config import LoggingConfig -from CodeEntropy.main import main -from tests.test_CodeEntropy.test_base import BaseTestCase - - -class TestDataLogger(BaseTestCase): - """ - Unit tests for the DataLogger class. - """ - - def setUp(self): - super().setUp() - self.code_entropy = main - self.logger = DataLogger() - self.output_file = "test_output.json" - - def test_init(self): - """ - Test that the DataLogger initializes with empty molecule and residue data lists. 
- """ - self.assertEqual(self.logger.molecule_data, []) - self.assertEqual(self.logger.residue_data, []) - - def test_add_results_data(self): - """ - Test that add_results_data correctly appends a molecule-level entry. - """ - self.logger.add_results_data( - 0, "united_atom", "Transvibrational", 653.4041220313459 - ) - self.assertEqual( - self.logger.molecule_data, - [(0, "united_atom", "Transvibrational", 653.4041220313459)], - ) - - def test_add_residue_data(self): - """ - Test that add_residue_data correctly appends a residue-level entry. - """ - self.logger.add_residue_data( - 0, "DA", "united_atom", "Transvibrational", 10, 122.61216935211893 - ) - self.assertEqual( - self.logger.residue_data, - [[0, "DA", "united_atom", "Transvibrational", 10, 122.61216935211893]], - ) - - def test_add_residue_data_with_numpy_array(self): - """ - Test that add_residue_data correctly converts a NumPy array to a list. - """ - frame_array = np.array([10]) - self.logger.add_residue_data( - 1, "DT", "united_atom", "Transvibrational", frame_array, 98.123456789 - ) - self.assertEqual( - self.logger.residue_data, - [[1, "DT", "united_atom", "Transvibrational", [10], 98.123456789]], - ) - - def test_save_dataframes_as_json(self): - """ - Test that save_dataframes_as_json correctly writes molecule and residue data - to a JSON file with the expected structure and values. 
- """ - molecule_df = pd.DataFrame( - [ - { - "Molecule ID": 0, - "Level": "united_atom", - "Type": "Transvibrational (J/mol/K)", - "Result": 653.404, - }, - { - "Molecule ID": 1, - "Level": "united_atom", - "Type": "Rovibrational (J/mol/K)", - "Result": 236.081, - }, - ] - ) - residue_df = pd.DataFrame( - [ - { - "Molecule ID": 0, - "Residue": 0, - "Type": "Transvibrational (J/mol/K)", - "Result": 122.612, - }, - { - "Molecule ID": 1, - "Residue": 0, - "Type": "Conformational (J/mol/K)", - "Result": 6.845, - }, - ] - ) - - self.logger.save_dataframes_as_json(molecule_df, residue_df, self.output_file) - - with open(self.output_file, "r") as f: - data = json.load(f) - - self.assertIn("molecule_data", data) - self.assertIn("residue_data", data) - self.assertEqual(data["molecule_data"][0]["Type"], "Transvibrational (J/mol/K)") - self.assertEqual(data["residue_data"][0]["Residue"], 0) - - def test_log_tables_rich_output(self): - console = LoggingConfig.get_console() - - self.logger.add_results_data( - 0, "united_atom", "Transvibrational", 653.4041220313459 - ) - self.logger.add_residue_data( - 0, "DA", "united_atom", "Transvibrational", 10, 122.61216935211893 - ) - self.logger.add_group_label(0, "DA", 10, 100) - - self.logger.log_tables() - - output = console.export_text() - assert "Molecule Entropy Results" in output - assert "Residue Entropy Results" in output - assert "Group ID to Residue Label Mapping" in output - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_CodeEntropy/test_dihedral_tools.py b/tests/test_CodeEntropy/test_dihedral_tools.py deleted file mode 100644 index 99071f93..00000000 --- a/tests/test_CodeEntropy/test_dihedral_tools.py +++ /dev/null @@ -1,571 +0,0 @@ -from unittest.mock import MagicMock, patch - -from CodeEntropy.dihedral_tools import DihedralAnalysis -from tests.test_CodeEntropy.test_base import BaseTestCase - - -class TestDihedralAnalysis(BaseTestCase): - """ - Unit tests for DihedralAnalysis. 
- """ - - def setUp(self): - super().setUp() - self.analysis = DihedralAnalysis() - - def test_get_dihedrals_united_atom(self): - """ - Test `_get_dihedrals` for 'united_atom' level. - - The function should: - - read dihedrals from `data_container.dihedrals` - - extract `.atoms` from each dihedral - - return a list of atom groups - - Expected behavior: - If dihedrals = [d1, d2, d3] and each dihedral has an `.atoms` - attribute, then the returned list must be: - [d1.atoms, d2.atoms, d3.atoms] - """ - data_container = MagicMock() - - # Mock dihedral objects with `.atoms` - d1 = MagicMock() - d1.atoms = "atoms1" - d2 = MagicMock() - d2.atoms = "atoms2" - d3 = MagicMock() - d3.atoms = "atoms3" - - data_container.dihedrals = [d1, d2, d3] - - result = self.analysis._get_dihedrals(data_container, level="united_atom") - - self.assertEqual(result, ["atoms1", "atoms2", "atoms3"]) - - def test_get_dihedrals_residue(self): - """ - Test `_get_dihedrals` for 'residue' level with 5 residues. - - The implementation: - - iterates over residues 4 → N - - for each, selects 4 bonded atom groups - - merges them using __add__ to form a single atom_group - - appends to result list - - For 5 residues (0–4), two dihedral groups should be created. - Expected: - - result of length 2 - - each item equal to the merged mock atom group - """ - data_container = MagicMock() - data_container.residues = [0, 1, 2, 3, 4] - - mock_atom_group = MagicMock() - mock_atom_group.__add__.return_value = mock_atom_group - - # Every MDAnalysis selection returns the same mock atom group - data_container.select_atoms.return_value = mock_atom_group - - result = self.analysis._get_dihedrals(data_container, level="residue") - - self.assertEqual(len(result), 2) - self.assertTrue(all(r is mock_atom_group for r in result)) - - def test_get_dihedrals_no_residue(self): - """ - Test `_get_dihedrals` for 'residue' level when fewer than - 4 residues exist (here: 3 residues). 
- - Expected: - - The function returns an empty list. - """ - data_container = MagicMock() - data_container.residues = [0, 1, 2] # Only 3 residues → too few - - result = self.analysis._get_dihedrals(data_container, level="residue") - - self.assertEqual(result, []) - - @patch("CodeEntropy.dihedral_tools.Dihedral") - def test_identify_peaks_empty_dihedrals(self, Dihedral_patch): - """ - Test `_identify_peaks` returns an empty list when the - input dihedral list is empty. - - Expected: - - No angle extraction occurs. - - No histograms computed. - - Return value is an empty list. - """ - universe_operations = MagicMock() - analysis = DihedralAnalysis(universe_operations) - - peaks = analysis._identify_peaks( - data_container=MagicMock(), - molecules=[0], - dihedrals=[], - bin_width=10, - start=0, - end=360, - step=1, - ) - - assert peaks == [] - - @patch("CodeEntropy.dihedral_tools.Dihedral") - def test_identify_peaks_negative_angles_become_positive(self, Dihedral_patch): - """ - Test that negative dihedral angles are converted into the - 0–360° range before histogramming. - - Scenario: - - A single dihedral produces a single angle: -15°. - - This should be converted to +345°. - - With 90° bins, it falls into the final bin → one peak. - - Expected: - - One peak detected. - - Peak center lies between 300° and 360°. 
- """ - universe_operations = MagicMock() - analysis = DihedralAnalysis(universe_operations) - - R = MagicMock() - R.results.angles = [[-15]] - Dihedral_patch.return_value.run.return_value = R - - mol = MagicMock() - mol.trajectory = [0] - universe_operations.get_molecule_container.return_value = mol - - peaks = analysis._identify_peaks( - MagicMock(), - [0], - dihedrals=[MagicMock()], - bin_width=90, - start=0, - end=360, - step=1, - ) - - assert len(peaks) == 1 - assert len(peaks[0]) == 1 - assert 300 <= peaks[0][0] <= 360 - - @patch("CodeEntropy.dihedral_tools.Dihedral") - def test_identify_peaks_internal_peak_detection(self, Dihedral_patch): - """ - Test the detection of a peak located in a middle histogram bin. - - Scenario: - - Angles fall into bin #1 (45°, 50°, 55°). - - Bin 1 has higher population than its neighbors. - - Expected: - - Exactly one peak is detected. - """ - universe_operations = MagicMock() - analysis = DihedralAnalysis(universe_operations) - - R = MagicMock() - R.results.angles = [[45], [50], [55]] - Dihedral_patch.return_value.run.return_value = R - - mol = MagicMock() - mol.trajectory = [0, 1, 2] - universe_operations.get_molecule_container.return_value = mol - - peaks = analysis._identify_peaks( - MagicMock(), - [0], - dihedrals=[MagicMock()], - bin_width=90, - start=0, - end=360, - step=1, - ) - - assert len(peaks[0]) == 1 - - @patch("CodeEntropy.dihedral_tools.Dihedral") - def test_identify_peaks_circular_boundary(self, Dihedral_patch): - """ - Test that `_identify_peaks` handles circular histogram boundaries - correctly when identifying peaks in the last bin. - - Setup: - - All angles are near 350°, falling into the final bin. - - Expected: - - The final bin is correctly identified as a peak. 
- """ - ops = MagicMock() - analysis = DihedralAnalysis(ops) - - R = MagicMock() - R.results.angles = [[350], [355], [349]] - Dihedral_patch.return_value.run.return_value = R - - mol = MagicMock() - mol.trajectory = [0, 1, 2] - ops.get_molecule_container.return_value = mol - - peaks = analysis._identify_peaks( - MagicMock(), - [0], - dihedrals=[MagicMock()], - bin_width=90, - start=0, - end=360, - step=1, - ) - - assert len(peaks[0]) == 1 - - @patch("CodeEntropy.dihedral_tools.Dihedral") - def test_identify_peaks_circular_last_bin(self, Dihedral_patch): - """ - Test peak detection for circular histogram boundaries, where the - last bin compares against the first bin. - - Scenario: - - All angles near 350° fall into the final bin. - - Final bin should be considered a peak if it exceeds both - previous and first bins. - - Expected: - - One peak detected in the last bin. - """ - universe_operations = MagicMock() - analysis = DihedralAnalysis(universe_operations) - - R = MagicMock() - R.results.angles = [[350], [355], [349]] - Dihedral_patch.return_value.run.return_value = R - - mol = MagicMock() - mol.trajectory = [0, 1, 2] - universe_operations.get_molecule_container.return_value = mol - - peaks = analysis._identify_peaks( - MagicMock(), - [0], - dihedrals=[MagicMock()], - bin_width=90, - start=0, - end=360, - step=1, - ) - - assert len(peaks[0]) == 1 - - @patch("CodeEntropy.dihedral_tools.Dihedral") - def test_assign_states_negative_angle_conversion(self, Dihedral_patch): - """ - Test `_assign_states` converts negative angles correctly and assigns - the dihedral to the nearest peak. - - Scenario: - - Angle returned = -10° → converted to 350°. - - Peak list contains [350]. - - Expected: - - Assigned state is "0". 
- """ - universe_operations = MagicMock() - analysis = DihedralAnalysis(universe_operations) - - R = MagicMock() - R.results.angles = [[-10]] - Dihedral_patch.return_value.run.return_value = R - - mol = MagicMock() - mol.trajectory = [0] - universe_operations.get_molecule_container.return_value = mol - - states = analysis._assign_states( - MagicMock(), - [0], - dihedrals=[MagicMock()], - peaks=[[350]], - start=0, - end=360, - step=1, - ) - - assert states == ["0"] - - @patch("CodeEntropy.dihedral_tools.Dihedral") - def test_assign_states_closest_peak_selection(self, Dihedral_patch): - """ - Test that `_assign_states` selects the peak nearest to each dihedral - angle. - - Setup: - - Angle = 30°. - - Peaks = [20, 100]. - - Nearest peak = 20 (index 0). - - Expected: - - Returned state is ["0"]. - """ - ops = MagicMock() - analysis = DihedralAnalysis(ops) - - R = MagicMock() - R.results.angles = [[30]] - Dihedral_patch.return_value.run.return_value = R - - mol = MagicMock() - mol.trajectory = [0] - ops.get_molecule_container.return_value = mol - - states = analysis._assign_states( - MagicMock(), - [0], - dihedrals=[MagicMock()], - peaks=[[20, 100]], - start=0, - end=360, - step=1, - ) - - assert states == ["0"] - - @patch("CodeEntropy.dihedral_tools.Dihedral") - def test_assign_states_closest_peak(self, Dihedral_patch): - """ - Test assignment to the correct peak based on minimum angular distance. - - Scenario: - - Angle = 30°. - - Peaks = [20, 100]. - - Closest peak is 20° → index 0. - - Expected: - - Returned state is "0". 
- """ - universe_operations = MagicMock() - analysis = DihedralAnalysis(universe_operations) - - R = MagicMock() - R.results.angles = [[30]] - Dihedral_patch.return_value.run.return_value = R - - mol = MagicMock() - mol.trajectory = [0] - universe_operations.get_molecule_container.return_value = mol - - states = analysis._assign_states( - MagicMock(), - [0], - dihedrals=[MagicMock()], - peaks=[[20, 100]], - start=0, - end=360, - step=1, - ) - - assert states == ["0"] - - @patch("CodeEntropy.dihedral_tools.Dihedral") - def test_assign_states_multiple_dihedrals(self, Dihedral_patch): - """ - Test concatenation of state labels across multiple dihedrals. - - Scenario: - - Two dihedrals, one frame: - dihedral 0 → 10° → closest peak 0 - dihedral 1 → 200° → closest peak 180 (index 0) - - Resulting frame state: "00". - - Expected: - - Returned list: ["00"]. - """ - universe_operations = MagicMock() - analysis = DihedralAnalysis(universe_operations) - - R = MagicMock() - R.results.angles = [[10, 200]] - Dihedral_patch.return_value.run.return_value = R - - mol = MagicMock() - mol.trajectory = [0] - universe_operations.get_molecule_container.return_value = mol - - peaks = [[0, 180], [180, 300]] - - states = analysis._assign_states( - MagicMock(), - [0], - dihedrals=[MagicMock(), MagicMock()], - peaks=peaks, - start=0, - end=360, - step=1, - ) - - assert states == ["00"] - - def test_assign_states_multiple_molecules(self): - """ - Test that `_assign_states` generates different conformational state - labels for different molecules when their dihedral angle trajectories - differ. - - Molecule 0 is mocked to produce an angle near peak 0. - Molecule 1 is mocked to produce an angle near peak 1. - - Expected: - The returned state list reflects these differences as - ["0", "1"]. 
- """ - - universe_operations = MagicMock() - analysis = DihedralAnalysis(universe_operations) - - mol1 = MagicMock() - mol1.trajectory = [0] - - mol2 = MagicMock() - mol2.trajectory = [0] - - universe_operations.get_molecule_container.side_effect = [mol1, mol2] - - # Two different R objects - R1 = MagicMock() - R1.results.angles = [[10]] # peak index 0 - - R2 = MagicMock() - R2.results.angles = [[200]] # peak index 1 - - peaks = [[0, 180]] - - # Patch where Dihedral is *used* - with patch("CodeEntropy.dihedral_tools.Dihedral") as Dihedral_patch: - instance = Dihedral_patch.return_value - instance.run.side_effect = [R1, R2] - - states = analysis._assign_states( - MagicMock(), - molecules=[0, 1], - dihedrals=[MagicMock()], - peaks=peaks, - start=0, - end=360, - step=1, - ) - - assert states == ["0", "1"] - - def test_build_states_united_atom_no_dihedrals(self): - """ - Test that UA-level state building produces empty state lists when no - dihedrals are found for any residue. - """ - ops = MagicMock() - analysis = DihedralAnalysis(ops) - - mol = MagicMock() - mol.residues = [MagicMock()] - ops.get_molecule_container.return_value = mol - ops.new_U_select_atom.return_value = MagicMock() - - analysis._get_dihedrals = MagicMock(return_value=[]) - analysis._identify_peaks = MagicMock(return_value=[]) - analysis._assign_states = MagicMock(return_value=[]) - - groups = {0: [0]} - levels = {0: ["united_atom"]} - - states_ua, states_res = analysis.build_conformational_states( - MagicMock(), levels, groups, start=0, end=360, step=1, bin_width=10 - ) - - assert states_ua[(0, 0)] == [] - - def test_build_states_united_atom_accumulate(self): - """ - Test that UA-level state building assigns states independently to each - residue and accumulates them correctly. 
- """ - ops = MagicMock() - analysis = DihedralAnalysis(ops) - - mol = MagicMock() - mol.residues = [MagicMock(), MagicMock()] - ops.get_molecule_container.return_value = mol - ops.new_U_select_atom.return_value = MagicMock() - - analysis._get_dihedrals = MagicMock(return_value=[1]) - analysis._identify_peaks = MagicMock(return_value=[[10]]) - analysis._assign_states = MagicMock(return_value=["A"]) - - groups = {0: [0]} - levels = {0: ["united_atom"]} - - states_ua, _ = analysis.build_conformational_states( - MagicMock(), levels, groups, start=0, end=360, step=1, bin_width=10 - ) - - assert states_ua[(0, 0)] == ["A"] - assert states_ua[(0, 1)] == ["A"] - - def test_build_states_residue_no_dihedrals(self): - """ - Test that residue-level state building returns an empty list when - `_get_dihedrals` reports no available dihedral groups. - """ - ops = MagicMock() - analysis = DihedralAnalysis(ops) - - mol = MagicMock() - mol.residues = [MagicMock()] - ops.get_molecule_container.return_value = mol - - analysis._get_dihedrals = MagicMock(return_value=[]) - analysis._identify_peaks = MagicMock(return_value=[]) - analysis._assign_states = MagicMock(return_value=[]) - - groups = {0: [0]} - levels = {0: ["residue"]} - - _, states_res = analysis.build_conformational_states( - MagicMock(), levels, groups, start=0, end=360, step=1, bin_width=10 - ) - - assert states_res[0] == [] - - def test_build_states_residue_accumulate(self): - """ - Test that residue-level state building delegates all molecules in a group - to a single `_assign_states` call, and stores its returned list directly. - - Expected: - _assign_states returns ["A", "B"], so states_res[0] == ["A", "B"]. 
- """ - ops = MagicMock() - analysis = DihedralAnalysis(ops) - - mol1 = MagicMock() - mol1.residues = [MagicMock()] - mol2 = MagicMock() - mol2.residues = [MagicMock()] - - ops.get_molecule_container.side_effect = [mol1, mol2] - - analysis._get_dihedrals = MagicMock(return_value=[1]) - analysis._identify_peaks = MagicMock(return_value=[[10]]) - - # One call for the whole group → one return value - analysis._assign_states = MagicMock(return_value=["A", "B"]) - - groups = {0: [0, 1]} - levels = {0: ["residue"], 1: ["residue"]} - - _, states_res = analysis.build_conformational_states( - MagicMock(), levels, groups, start=0, end=360, step=1, bin_width=10 - ) - - assert states_res[0] == ["A", "B"] diff --git a/tests/test_CodeEntropy/test_entropy.py b/tests/test_CodeEntropy/test_entropy.py deleted file mode 100644 index 72b54664..00000000 --- a/tests/test_CodeEntropy/test_entropy.py +++ /dev/null @@ -1,2135 +0,0 @@ -import logging -import math -import os -import shutil -import tempfile -import unittest -from unittest.mock import MagicMock, PropertyMock, call, patch - -import MDAnalysis as mda -import numpy as np -import numpy.linalg as la -import pytest - -import tests.data as data -from CodeEntropy.config.data_logger import DataLogger -from CodeEntropy.entropy import ( - ConformationalEntropy, - EntropyManager, - OrientationalEntropy, - VibrationalEntropy, -) -from CodeEntropy.levels import LevelManager -from CodeEntropy.main import main -from CodeEntropy.mda_universe_operations import UniverseOperations -from CodeEntropy.run import ConfigManager, RunManager -from tests.test_CodeEntropy.test_base import BaseTestCase - - -class TestEntropyManager(BaseTestCase): - """ - Unit tests for EntropyManager. 
- """ - - def setUp(self): - super().setUp() - self.test_data_dir = os.path.dirname(data.__file__) - - # Disable MDAnalysis and commands file logging entirely - logging.getLogger("MDAnalysis").handlers = [logging.NullHandler()] - logging.getLogger("commands").handlers = [logging.NullHandler()] - - def test_execute_full_workflow(self): - # Setup universe and args - tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") - trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") - u = mda.Universe(tprfile, trrfile) - - args = MagicMock( - bin_width=0.1, temperature=300, selection_string="all", water_entropy=False - ) - run_manager = RunManager("mock_folder/job001") - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - data_logger = DataLogger() - group_molecules = MagicMock() - dihedral_analysis = MagicMock() - entropy_manager = EntropyManager( - run_manager, - args, - u, - data_logger, - level_manager, - group_molecules, - dihedral_analysis, - universe_operations, - ) - - # Mocks for trajectory and molecules - entropy_manager._get_trajectory_bounds = MagicMock(return_value=(0, 10, 1)) - entropy_manager._get_number_frames = MagicMock(return_value=11) - entropy_manager._handle_water_entropy = MagicMock() - - mock_reduced_atom = MagicMock() - mock_reduced_atom.trajectory = [1] * 11 - - mock_groups = {0: [0], 1: [1], 2: [2]} - mock_levels = { - 0: ["united_atom", "polymer", "residue"], - 1: ["united_atom", "polymer", "residue"], - 2: ["united_atom", "polymer", "residue"], - } - - entropy_manager._initialize_molecules = MagicMock( - return_value=(mock_reduced_atom, 3, mock_levels, mock_groups) - ) - entropy_manager._level_manager.build_covariance_matrices = MagicMock( - return_value=( - "force_matrices", - "torque_matrices", - "forcetorque_avg", - "frame_counts", - ) - ) - entropy_manager._dihedral_analysis.build_conformational_states = MagicMock( - return_value=(["state_ua"], ["state_res"]) - ) - 
entropy_manager._compute_entropies = MagicMock() - entropy_manager._finalize_molecule_results = MagicMock() - entropy_manager._data_logger.log_tables = MagicMock() - - # Create mocks for VibrationalEntropy and ConformationalEntropy - ve = MagicMock() - ce = MagicMock() - - # Patch both VibrationalEntropy, ConformationalEntropy AND u.atoms.fragments - mock_molecule = MagicMock() - mock_molecule.residues = [] - - with ( - patch("CodeEntropy.entropy.VibrationalEntropy", return_value=ve), - patch("CodeEntropy.entropy.ConformationalEntropy", return_value=ce), - patch.object( - type(u.atoms), "fragments", new_callable=PropertyMock - ) as mock_fragments, - ): - mock_fragments.return_value = [mock_molecule] * 10 - entropy_manager.execute() - - # Assert the key calls happened with expected arguments - build_states = entropy_manager._dihedral_analysis.build_conformational_states - build_states.assert_called_once_with( - mock_reduced_atom, - mock_levels, - mock_groups, - 0, - 10, - 1, - args.bin_width, - ) - - entropy_manager._compute_entropies.assert_called_once_with( - mock_reduced_atom, - mock_levels, - mock_groups, - "force_matrices", - "torque_matrices", - "forcetorque_avg", - ["state_ua"], - ["state_res"], - "frame_counts", - 11, - ve, - ce, - ) - - entropy_manager._finalize_molecule_results.assert_called_once() - entropy_manager._data_logger.log_tables.assert_called_once() - - def test_execute_triggers_handle_water_entropy_minimal(self): - """ - Minimal test to ensure _handle_water_entropy line is executed. 
- """ - tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") - trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") - u = mda.Universe(tprfile, trrfile) - - args = MagicMock( - bin_width=0.1, temperature=300, selection_string="all", water_entropy=True - ) - run_manager = RunManager("mock_folder/job001") - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - data_logger = DataLogger() - group_molecules = MagicMock() - dihedral_analysis = MagicMock() - entropy_manager = EntropyManager( - run_manager, - args, - u, - data_logger, - level_manager, - group_molecules, - dihedral_analysis, - universe_operations, - ) - - entropy_manager._get_trajectory_bounds = MagicMock(return_value=(0, 10, 1)) - entropy_manager._get_number_frames = MagicMock(return_value=11) - entropy_manager._initialize_molecules = MagicMock( - return_value=(MagicMock(), 3, {}, {0: [0]}) - ) - entropy_manager._level_manager.build_covariance_matrices = MagicMock( - return_value=( - "force_matrices", - "torque_matrices", - "forcetorque_avg", - "frame_counts", - ) - ) - entropy_manager._dihedral_analysis.build_conformational_states = MagicMock( - return_value=(["state_ua"], ["state_res"]) - ) - entropy_manager._compute_entropies = MagicMock() - entropy_manager._finalize_molecule_results = MagicMock() - entropy_manager._data_logger.log_tables = MagicMock() - - with ( - patch("CodeEntropy.entropy.VibrationalEntropy", return_value=MagicMock()), - patch( - "CodeEntropy.entropy.ConformationalEntropy", return_value=MagicMock() - ), - patch.object( - type(u.atoms), "fragments", new_callable=PropertyMock - ) as mock_fragments, - patch.object(u, "select_atoms") as mock_select_atoms, - patch.object( - entropy_manager, "_handle_water_entropy" - ) as mock_handle_water_entropy, - ): - mock_fragments.return_value = [MagicMock(residues=[MagicMock(resid=1)])] - mock_select_atoms.return_value = MagicMock(residues=[MagicMock(resid=1)]) - - entropy_manager.execute() 
- - mock_handle_water_entropy.assert_called_once() - - def test_water_entropy_sets_selection_string_when_all(self): - """ - If selection_string is 'all' and water entropy is enabled, - _handle_water_entropy should update it to 'not water'. - """ - mock_universe = MagicMock() - args = MagicMock(water_entropy=True, selection_string="all") - manager = EntropyManager( - MagicMock(), - args, - mock_universe, - DataLogger(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - manager._calculate_water_entropy = MagicMock() - manager._data_logger.add_group_label = MagicMock() - - water_groups = {0: [0, 1, 2]} - - manager._handle_water_entropy(0, 10, 1, water_groups) - - assert manager._args.selection_string == "not water" - manager._calculate_water_entropy.assert_called_once() - - def test_water_entropy_appends_to_custom_selection_string(self): - """ - If selection_string is custom and water entropy is enabled, - _handle_water_entropy appends ' and not water'. - """ - mock_universe = MagicMock() - args = MagicMock(water_entropy=True, selection_string="protein") - manager = EntropyManager( - MagicMock(), - args, - mock_universe, - DataLogger(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - manager._calculate_water_entropy = MagicMock() - manager._data_logger.add_group_label = MagicMock() - - water_groups = {0: [0, 1, 2]} - - manager._handle_water_entropy(0, 10, 1, water_groups) - - manager._calculate_water_entropy.assert_called_once() - assert args.selection_string == "protein and not water" - - def test_handle_water_entropy_returns_early(self): - """ - Verifies that _handle_water_entropy returns immediately if: - 1. water_groups is empty - 2. 
water_entropy is disabled - """ - mock_universe = MagicMock() - args = MagicMock(water_entropy=True, selection_string="protein") - manager = EntropyManager( - MagicMock(), - args, - mock_universe, - DataLogger(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - # Patch _calculate_water_entropy to track if called - manager._calculate_water_entropy = MagicMock() - - # Case 1: empty water_groups - result = manager._handle_water_entropy(0, 10, 1, {}) - assert result is None - manager._calculate_water_entropy.assert_not_called() - - # Case 2: water_entropy disabled - manager._args.water_entropy = False - result = manager._handle_water_entropy(0, 10, 1, {0: [0, 1, 2]}) - assert result is None - manager._calculate_water_entropy.assert_not_called() - - def test_initialize_molecules(self): - """ - Test _initialize_molecules returns expected tuple by mocking internal methods. - - - Ensures _get_reduced_universe is called and its return is used. - - Ensures _level_manager.select_levels is called with the reduced atom - selection. - - Ensures _group_molecules.grouping_molecules is called with the reduced atom - and grouping arg. - - Verifies the returned tuple matches the mocked values. 
- """ - - args = MagicMock( - bin_width=0.1, temperature=300, selection_string="all", water_entropy=False - ) - run_manager = RunManager("mock_folder/job001") - level_manager = LevelManager(MagicMock()) - data_logger = DataLogger() - group_molecules = MagicMock() - manager = EntropyManager( - run_manager, - args, - MagicMock(), - data_logger, - level_manager, - group_molecules, - MagicMock(), - MagicMock(), - ) - - # Mock dependencies - manager._get_reduced_universe = MagicMock(return_value="mock_reduced_atom") - manager._level_manager = MagicMock() - manager._level_manager.select_levels = MagicMock( - return_value=(5, ["level1", "level2"]) - ) - manager._group_molecules = MagicMock() - manager._group_molecules.grouping_molecules = MagicMock( - return_value=["groupA", "groupB"] - ) - manager._args = MagicMock() - manager._args.grouping = "custom_grouping" - - # Call the method under test - result = manager._initialize_molecules() - - # Assert calls - manager._get_reduced_universe.assert_called_once() - manager._level_manager.select_levels.assert_called_once_with( - "mock_reduced_atom" - ) - manager._group_molecules.grouping_molecules.assert_called_once_with( - "mock_reduced_atom", "custom_grouping" - ) - - # Assert return value - expected = ("mock_reduced_atom", 5, ["level1", "level2"], ["groupA", "groupB"]) - self.assertEqual(result, expected) - - def test_get_trajectory_bounds(self): - """ - Tests that `_get_trajectory_bounds` runs and returns expected types. 
- """ - - config_manager = ConfigManager() - - parser = config_manager.setup_argparse() - args, _ = parser.parse_known_args() - - entropy_manager = EntropyManager( - MagicMock(), - args, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - self.assertIsInstance(entropy_manager._args.start, int) - self.assertIsInstance(entropy_manager._args.end, int) - self.assertIsInstance(entropy_manager._args.step, int) - - self.assertEqual(entropy_manager._get_trajectory_bounds(), (0, 0, 1)) - - @patch( - "argparse.ArgumentParser.parse_args", - return_value=MagicMock( - start=0, - end=-1, - step=1, - ), - ) - def test_get_number_frames(self, mock_args): - """ - Test `_get_number_frames` when the end index is -1. - - Ensures that the function correctly counts all frames from start to - the end of the trajectory. - """ - config_manager = ConfigManager() - parser = config_manager.setup_argparse() - args = parser.parse_args() - - # Mock universe with a trajectory of 10 frames - mock_universe = MagicMock() - mock_universe.trajectory = range(10) - - entropy_manager = EntropyManager( - MagicMock(), - args, - mock_universe, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - # Use _get_trajectory_bounds to convert end=-1 into the actual last frame - start, end, step = entropy_manager._get_trajectory_bounds() - number_frames = entropy_manager._get_number_frames(start, end, step) - - # Expect all frames to be counted - self.assertEqual(number_frames, 10) - - @patch( - "argparse.ArgumentParser.parse_args", - return_value=MagicMock( - start=0, - end=20, - step=1, - ), - ) - def test_get_number_frames_sliced_trajectory(self, mock_args): - """ - Test `_get_number_frames` with a valid slicing range. - - Verifies that the function correctly calculates the number of frames - when slicing from 0 to 20 with a step of 1, expecting 21 frames. 
- """ - config_manager = ConfigManager() - parser = config_manager.setup_argparse() - args = parser.parse_args() - - # Mock universe with 30 frames - mock_universe = MagicMock() - mock_universe.trajectory = range(30) - - entropy_manager = EntropyManager( - MagicMock(), - args, - mock_universe, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - start, end, step = entropy_manager._get_trajectory_bounds() - number_frames = entropy_manager._get_number_frames(start, end, step) - - self.assertEqual(number_frames, 20) - - @patch( - "argparse.ArgumentParser.parse_args", - return_value=MagicMock( - start=0, - end=-1, - step=5, - ), - ) - def test_get_number_frames_sliced_trajectory_step(self, mock_args): - """ - Test `_get_number_frames` with a step that skips frames. - - Ensures that the function correctly counts the number of frames - when a step size of 5 is applied. - """ - config_manager = ConfigManager() - parser = config_manager.setup_argparse() - args = parser.parse_args() - - # Mock universe with 20 frames - mock_universe = MagicMock() - mock_universe.trajectory = range(20) - - entropy_manager = EntropyManager( - MagicMock(), - args, - mock_universe, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - start, end, step = entropy_manager._get_trajectory_bounds() - number_frames = entropy_manager._get_number_frames(start, end, step) - - # Expect 20 frames divided by step of 5 = 4 frames - self.assertEqual(number_frames, 4) - - @patch( - "argparse.ArgumentParser.parse_args", - return_value=MagicMock( - selection_string="all", - ), - ) - def test_get_reduced_universe_all(self, mock_args): - """ - Test `_get_reduced_universe` with 'all' selection. - - Verifies that the full universe is returned when the selection string - is set to 'all', and the number of atoms remains unchanged. 
- """ - # Load MDAnalysis Universe - tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") - trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") - u = mda.Universe(tprfile, trrfile) - - config_manager = ConfigManager() - - parser = config_manager.setup_argparse() - args = parser.parse_args() - - entropy_manager = EntropyManager( - MagicMock(), - args, - u, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - entropy_manager._get_reduced_universe() - - self.assertEqual(entropy_manager._universe.atoms.n_atoms, 254) - - @patch( - "argparse.ArgumentParser.parse_args", - return_value=MagicMock( - selection_string="resname DA", - ), - ) - def test_get_reduced_universe_reduced(self, mock_args): - """ - Test `_get_reduced_universe` with a specific atom selection. - - Ensures that the reduced universe contains fewer atoms than the original - when a specific selection string is used. - """ - - # Load MDAnalysis Universe - tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") - trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") - u = mda.Universe(tprfile, trrfile) - - universe_operations = UniverseOperations() - - config_manager = ConfigManager() - run_manager = RunManager("mock_folder/job001") - - parser = config_manager.setup_argparse() - args = parser.parse_args() - - entropy_manager = EntropyManager( - run_manager, - args, - u, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - universe_operations, - ) - - reduced_u = entropy_manager._get_reduced_universe() - - # Assert that the reduced universe has fewer atoms - assert len(reduced_u.atoms) < len(u.atoms) - - @patch( - "argparse.ArgumentParser.parse_args", - return_value=MagicMock( - selection_string="all", - ), - ) - def test_process_united_atom_entropy(self, selection_string_mock): - """ - Tests that `_process_united_atom_entropy` correctly logs global and - residue-level entropy results for a mocked molecular system. 
- """ - # Setup managers and arguments - args = MagicMock(bin_width=0.1, temperature=300, selection_string="all") - universe_operations = UniverseOperations() - run_manager = MagicMock(universe_operations) - level_manager = MagicMock() - data_logger = DataLogger() - group_molecules = MagicMock() - manager = EntropyManager( - run_manager, - args, - MagicMock(), - data_logger, - level_manager, - group_molecules, - MagicMock(), - universe_operations, - ) - - # Mock molecule container with residues and atoms - n_residues = 3 - mock_residues = [MagicMock(resname=f"RES{i}") for i in range(n_residues)] - mock_atoms_per_mol = 3 - mock_atoms = [MagicMock() for _ in range(mock_atoms_per_mol)] # per molecule - mol_container = MagicMock(residues=mock_residues, atoms=mock_atoms) - - # Create dummy matrices and states - force_matrix = {(0, i): np.eye(3) for i in range(n_residues)} - torque_matrix = {(0, i): np.eye(3) * 2 for i in range(n_residues)} - states = {(0, i): np.ones((10, 3)) for i in range(n_residues)} - - # Mock entropy calculators - ve = MagicMock() - ce = MagicMock() - ve.vibrational_entropy_calculation.side_effect = lambda m, t, temp, high: ( - 1.0 if t == "force" else 2.0 - ) - ce.conformational_entropy_calculation.return_value = 3.0 - - # Manually add the group label so group_id=0 exists - data_logger.add_group_label( - 0, - "_".join(f"RES{i}" for i in range(n_residues)), # label - n_residues, # residue_count - len(mock_atoms) * n_residues, # total atoms for the group - ) - - # Run the method - manager._process_united_atom_entropy( - group_id=0, - mol_container=mol_container, - ve=ve, - ce=ce, - level="united_atom", - force_matrix=force_matrix, - torque_matrix=torque_matrix, - states=states, - highest=True, - number_frames=10, - frame_counts={(0, i): 10 for i in range(n_residues)}, - ) - - # Check molecule-level results - df = data_logger.molecule_data - assert len(df) == 3 # Trans, Rot, Conf - - # Check residue-level results - residue_df = 
data_logger.residue_data - assert len(residue_df) == 3 * n_residues # 3 types per residue - - # Check that all expected types are present - expected_types = {"Transvibrational", "Rovibrational", "Conformational"} - actual_types = set(entry[2] for entry in df) - assert actual_types == expected_types - - residue_types = set(entry[3] for entry in residue_df) - assert residue_types == expected_types - - # Check group label logging - group_label = data_logger.group_labels[0] # Access by group_id key - assert group_label["label"] == "_".join(f"RES{i}" for i in range(n_residues)) - assert group_label["residue_count"] == n_residues - assert group_label["atom_count"] == len(mock_atoms) * n_residues - - def test_process_vibrational_only_levels(self): - """ - Tests that `_process_vibrational_entropy` correctly logs vibrational - entropy results for a known molecular system using MDAnalysis. - """ - # Load a known test universe - tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") - trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") - u = mda.Universe(tprfile, trrfile) - - # Setup managers and arguments - args = MagicMock(bin_width=0.1, temperature=300, selection_string="all") - run_manager = RunManager("mock_folder/job001") - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - data_logger = DataLogger() - group_molecules = MagicMock() - manager = EntropyManager( - run_manager, - args, - u, - data_logger, - level_manager, - group_molecules, - MagicMock(), - universe_operations, - ) - - # Prepare mock molecule container - reduced_atom = manager._get_reduced_universe() - mol_container = universe_operations.get_molecule_container(reduced_atom, 0) - - # Simulate trajectory length - mol_container.trajectory = [None] * 10 # 10 frames - - # Create dummy matrices - force_matrix = np.eye(3) - torque_matrix = np.eye(3) * 2 - - # Mock entropy calculator - ve = MagicMock() - ve.vibrational_entropy_calculation.side_effect = 
[1.11, 2.22] - - forcetorque_matrix = np.eye(6) - - # Run the method - manager._process_vibrational_entropy( - group_id=0, - mol_container=mol_container, - number_frames=10, - ve=ve, - level="Vibrational", - force_matrix=force_matrix, - torque_matrix=torque_matrix, - forcetorque_matrix=forcetorque_matrix, - highest=True, - ) - - # Check that results were logged - df = data_logger.molecule_data - self.assertEqual(len(df), 2) # Transvibrational and Rovibrational - - expected_types = {"FTmat-Transvibrational", "FTmat-Rovibrational"} - actual_types = set(entry[2] for entry in df) - self.assertSetEqual(actual_types, expected_types) - - results = [entry[3] for entry in df] - self.assertIn(1.11, results) - self.assertIn(2.22, results) - - def test_process_vibrational_entropy_else_branch(self): - """ - Atomic unit test for EntropyManager._process_vibrational_entropy else-branch: - - forcetorque_matrix is None - - force/torque matrices are filtered - - ve.vibrational_entropy_calculation called for force & torque - - results logged as Transvibrational/Rovibrational - - group label added from mol_container residues/atoms - """ - manager = MagicMock() - manager._args = MagicMock(temperature=300) - - manager._level_manager = MagicMock() - manager._data_logger = MagicMock() - - force_matrix = np.eye(3) - torque_matrix = np.eye(3) * 2 - - filtered_force = np.eye(3) * 7 - filtered_torque = np.eye(3) * 9 - manager._level_manager.filter_zero_rows_columns.side_effect = [ - filtered_force, - filtered_torque, - ] - - ve = MagicMock() - ve.vibrational_entropy_calculation.side_effect = [1.11, 2.22] - - res1 = MagicMock(resname="ALA") - res2 = MagicMock(resname="GLY") - res3 = MagicMock(resname="ALA") - mol_container = MagicMock() - mol_container.residues = [res1, res2, res3] - mol_container.atoms = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] - - EntropyManager._process_vibrational_entropy( - manager, - group_id=0, - mol_container=mol_container, - number_frames=10, - ve=ve, - 
level="Vibrational", - force_matrix=force_matrix, - torque_matrix=torque_matrix, - forcetorque_matrix=None, - highest=True, - ) - - filter_calls = manager._level_manager.filter_zero_rows_columns.call_args_list - assert len(filter_calls) == 2 - - np.testing.assert_array_equal(filter_calls[0].args[0], force_matrix) - np.testing.assert_array_equal(filter_calls[1].args[0], torque_matrix) - - ve_calls = ve.vibrational_entropy_calculation.call_args_list - assert len(ve_calls) == 2 - - np.testing.assert_array_equal(ve_calls[0].args[0], filtered_force) - assert ve_calls[0].args[1:] == ("force", 300, True) - - np.testing.assert_array_equal(ve_calls[1].args[0], filtered_torque) - assert ve_calls[1].args[1:] == ("torque", 300, True) - - manager._data_logger.add_results_data.assert_any_call( - 0, "Vibrational", "Transvibrational", 1.11 - ) - manager._data_logger.add_results_data.assert_any_call( - 0, "Vibrational", "Rovibrational", 2.22 - ) - manager._data_logger.add_group_label.assert_called_once_with(0, "ALA_GLY", 3, 4) - - def test_compute_entropies_polymer_branch(self): - """ - Test _compute_entropies triggers _process_vibrational_entropy for 'polymer' - level. 
- """ - args = MagicMock(bin_width=0.1) - run_manager = MagicMock() - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - data_logger = DataLogger() - group_molecules = MagicMock() - manager = EntropyManager( - run_manager, - args, - MagicMock(), - data_logger, - level_manager, - group_molecules, - MagicMock(), - universe_operations, - ) - - reduced_atom = MagicMock() - number_frames = 5 - groups = {0: [0]} # One molecule only - levels = [["polymer"]] # One level for that molecule - - force_matrices = {"poly": {0: np.eye(3)}} - torque_matrices = {"poly": {0: np.eye(3) * 2}} - states_ua = {} - states_res = [] - frame_counts = 10 - - mol_mock = MagicMock() - mol_mock.residues = [] - universe_operations.get_molecule_container = MagicMock(return_value=mol_mock) - manager._process_vibrational_entropy = MagicMock() - - ve = MagicMock() - ve.vibrational_entropy_calculation.side_effect = [1.11] - - ce = MagicMock() - ce.conformational_entropy_calculation.return_value = 3.33 - - manager._compute_entropies( - reduced_atom, - levels, - groups, - force_matrices, - torque_matrices, - force_matrices, - states_ua, - states_res, - frame_counts, - number_frames, - ve, - ce, - ) - - manager._process_vibrational_entropy.assert_called_once() - - def test_process_conformational_residue_level(self): - """ - Tests that `_process_conformational_entropy` correctly logs conformational - entropy results at the residue level for a known molecular system using - MDAnalysis. 
- """ - # Load a known test universe - tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") - trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") - u = mda.Universe(tprfile, trrfile) - - # Setup managers and arguments - args = MagicMock(bin_width=0.1, temperature=300, selection_string="all") - run_manager = RunManager("mock_folder/job001") - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - data_logger = DataLogger() - group_molecules = MagicMock() - manager = EntropyManager( - run_manager, - args, - u, - data_logger, - level_manager, - group_molecules, - MagicMock(), - universe_operations, - ) - - # Create dummy states - states = {0: np.ones((10, 3))} - - # Mock entropy calculator - ce = MagicMock() - ce.conformational_entropy_calculation.return_value = 3.33 - - # Run the method - manager._process_conformational_entropy( - group_id=0, - mol_container=MagicMock(), - ce=ce, - level="residue", - states=states, - number_frames=10, - ) - - # Check that results were logged - df = data_logger.molecule_data - self.assertEqual(len(df), 1) - - expected_types = {"Conformational"} - actual_types = set(entry[2] for entry in df) - self.assertSetEqual(actual_types, expected_types) - - results = [entry[3] for entry in df] - self.assertIn(3.33, results) - - def test_process_conformational_entropy_no_states_entry(self): - """ - Tests that `_process_conformational_entropy` logs zero entropy when - the group_id is not present in the states dictionary. 
- """ - # Setup minimal mock universe - u = MagicMock() - - # Setup managers and arguments - args = MagicMock() - universe_operations = MagicMock() - run_manager = MagicMock() - level_manager = MagicMock() - data_logger = DataLogger() - group_molecules = MagicMock() - manager = EntropyManager( - run_manager, - args, - u, - data_logger, - level_manager, - group_molecules, - MagicMock(), - universe_operations, - ) - - # States dict does NOT contain group_id=1 - states = {0: np.ones((10, 3))} - - # Mock entropy calculator - ce = MagicMock() - - # Run method with group_id=1 (not in states) - manager._process_conformational_entropy( - group_id=1, - mol_container=MagicMock(), - ce=ce, - level="residue", - states=states, - number_frames=10, - ) - - # Assert entropy is zero - self.assertEqual(data_logger.molecule_data[0][3], 0) - - # Assert calculator was not called - ce.conformational_entropy_calculation.assert_not_called() - - def test_compute_entropies_united_atom(self): - """ - Test that _process_united_atom_entropy is called correctly for 'united_atom' - level with highest=False when it's the only level. 
- """ - args = MagicMock(bin_width=0.1) - universe_operations = UniverseOperations() - run_manager = MagicMock() - level_manager = MagicMock() - data_logger = DataLogger() - group_molecules = MagicMock() - manager = EntropyManager( - run_manager, - args, - MagicMock(), - data_logger, - level_manager, - group_molecules, - MagicMock(), - universe_operations, - ) - - reduced_atom = MagicMock() - number_frames = 10 - groups = {0: [0]} - levels = [["united_atom"]] # single level - - force_matrices = {"ua": {0: "force_ua"}} - torque_matrices = {"ua": {0: "torque_ua"}} - states_ua = {} - states_res = [] - frame_counts = {"ua": {(0, 0): 10}} - - mol_mock = MagicMock() - mol_mock.residues = [] - universe_operations.get_molecule_container = MagicMock(return_value=mol_mock) - manager._process_united_atom_entropy = MagicMock() - - force_torque_matrices = MagicMock() - - ve = MagicMock() - ce = MagicMock() - - manager._compute_entropies( - reduced_atom, - levels, - groups, - force_matrices, - torque_matrices, - force_torque_matrices, - states_ua, - states_res, - frame_counts, - number_frames, - ve, - ce, - ) - - manager._process_united_atom_entropy.assert_called_once_with( - 0, - mol_mock, - ve, - ce, - "united_atom", - force_matrices["ua"], - torque_matrices["ua"], - states_ua, - frame_counts["ua"], - True, # highest is True since only level - number_frames, - ) - - def test_compute_entropies_residue(self): - """ - Test that _process_vibrational_entropy and _process_conformational_entropy - are called correctly for 'residue' level with highest=True when it's the - only level. 
- """ - # Setup - args = MagicMock(bin_width=0.1) - universe_operations = UniverseOperations() - run_manager = MagicMock() - level_manager = MagicMock() - data_logger = DataLogger() - group_molecules = MagicMock() - manager = EntropyManager( - run_manager, - args, - MagicMock(), - data_logger, - level_manager, - group_molecules, - MagicMock(), - universe_operations, - ) - - reduced_atom = MagicMock() - number_frames = 10 - groups = {0: [0]} - levels = [["residue"]] # single level - - force_matrices = {"res": {0: "force_res"}} - torque_matrices = {"res": {0: "torque_res"}} - states_ua = {} - states_res = ["states_res"] - - # Frame counts for residue level - frame_counts = {"res": {(0, 0): 10}} - - # Mock molecule - mol_mock = MagicMock() - mol_mock.residues = [] - universe_operations.get_molecule_container = MagicMock(return_value=mol_mock) - manager._process_vibrational_entropy = MagicMock() - manager._process_conformational_entropy = MagicMock() - - force_torque_matrices = MagicMock() - - # Mock entropy calculators - ve = MagicMock() - ce = MagicMock() - - # Call the method under test - manager._compute_entropies( - reduced_atom, - levels, - groups, - force_matrices, - torque_matrices, - force_torque_matrices, - states_ua, - states_res, - frame_counts, - number_frames, - ve, - ce, - ) - - # Assert that the per-level processing methods were called - manager._process_vibrational_entropy.assert_called() - manager._process_conformational_entropy.assert_called() - - def test_compute_entropies_polymer(self): - args = MagicMock(bin_width=0.1) - universe_operations = UniverseOperations() - run_manager = MagicMock() - level_manager = MagicMock() - data_logger = DataLogger() - group_molecules = MagicMock() - dihedral_analysis = MagicMock() - - manager = EntropyManager( - run_manager, - args, - MagicMock(), - data_logger, - level_manager, - group_molecules, - dihedral_analysis, - universe_operations, - ) - - reduced_atom = MagicMock() - number_frames = 10 - groups = {0: [0]} 
- levels = [["polymer"]] - - force_matrices = {"poly": {0: "force_poly"}} - torque_matrices = {"poly": {0: "torque_poly"}} - force_torque_matrices = {"poly": {0: "ft_poly"}} - - states_ua = {} - states_res = [] - frame_counts = {"poly": {(0, 0): 10}} - - mol_mock = MagicMock() - mol_mock.residues = [] - universe_operations.get_molecule_container = MagicMock(return_value=mol_mock) - manager._process_vibrational_entropy = MagicMock() - - ve = MagicMock() - ce = MagicMock() - - manager._compute_entropies( - reduced_atom, - levels, - groups, - force_matrices, - torque_matrices, - force_torque_matrices, - states_ua, - states_res, - frame_counts, - number_frames, - ve, - ce, - ) - - manager._process_vibrational_entropy.assert_called_once_with( - 0, - mol_mock, - number_frames, - ve, - "polymer", - force_matrices["poly"][0], - torque_matrices["poly"][0], - force_torque_matrices["poly"][0], - True, - ) - - def test_finalize_molecule_results_aggregates_and_logs_total_entropy(self): - """ - Tests that `_finalize_molecule_results` correctly aggregates entropy values per - molecule from `molecule_data`, appends a 'Group Total' entry, and calls - `save_dataframes_as_json` with the expected DataFrame structure. 
- """ - # Setup - args = MagicMock(output_file="mock_output.json") - data_logger = DataLogger() - data_logger.molecule_data = [ - ("mol1", "united_atom", "Transvibrational", 1.0), - ("mol1", "united_atom", "Rovibrational", 2.0), - ("mol1", "united_atom", "Conformational", 3.0), - ("mol2", "polymer", "Transvibrational", 4.0), - ] - data_logger.residue_data = [] - - manager = EntropyManager(None, args, None, data_logger, None, None, None, None) - - # Patch save method - data_logger.save_dataframes_as_json = MagicMock() - - # Execute - manager._finalize_molecule_results() - - # Check that totals were added - totals = [ - entry for entry in data_logger.molecule_data if entry[1] == "Group Total" - ] - self.assertEqual(len(totals), 2) - - # Check correct aggregation - mol1_total = next(entry for entry in totals if entry[0] == "mol1")[3] - mol2_total = next(entry for entry in totals if entry[0] == "mol2")[3] - self.assertEqual(mol1_total, 6.0) - self.assertEqual(mol2_total, 4.0) - - # Check save was called - data_logger.save_dataframes_as_json.assert_called_once() - - @patch("CodeEntropy.entropy.logger") - def test_finalize_molecule_results_skips_invalid_entries(self, mock_logger): - """ - Tests that `_finalize_molecule_results` skips entries with non-numeric entropy - values and logs a warning without raising an exception. 
- """ - args = MagicMock(output_file="mock_output.json") - data_logger = DataLogger() - data_logger.molecule_data = [ - ("mol1", "united_atom", "Transvibrational", 1.0), - ( - "mol1", - "united_atom", - "Rovibrational", - "not_a_number", - ), # Should trigger ValueError - ("mol1", "united_atom", "Conformational", 2.0), - ] - data_logger.residue_data = [] - - manager = EntropyManager(None, args, None, data_logger, None, None, None, None) - - # Patch save method - data_logger.save_dataframes_as_json = MagicMock() - - # Run the method - manager._finalize_molecule_results() - - # Check that only valid values were aggregated - totals = [ - entry for entry in data_logger.molecule_data if entry[1] == "Group Total" - ] - self.assertEqual(len(totals), 1) - self.assertEqual(totals[0][3], 3.0) # 1.0 + 2.0 - - # Check that a warning was logged - mock_logger.warning.assert_called_once_with( - "Skipping invalid entry: mol1, not_a_number" - ) - - -class TestVibrationalEntropy(unittest.TestCase): - """ - Unit tests for the functionality of Vibrational entropy calculations. - """ - - def setUp(self): - """ - Set up test environment. - """ - self.test_dir = tempfile.mkdtemp(prefix="CodeEntropy_") - self.test_data_dir = os.path.dirname(data.__file__) - self.code_entropy = main - - # Change to test directory - self._orig_dir = os.getcwd() - os.chdir(self.test_dir) - - self.entropy_manager = EntropyManager( - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - def tearDown(self): - """ - Clean up after each test. - """ - os.chdir(self._orig_dir) - if os.path.exists(self.test_dir): - shutil.rmtree(self.test_dir) - - def test_vibrational_entropy_init(self): - """ - Test initialization of the `VibrationalEntropy` class. - - Verifies that the object is correctly instantiated and that key arguments - such as temperature and bin width are properly assigned. 
- """ - # Mock dependencies - universe = MagicMock() - args = MagicMock() - args.bin_width = 0.1 - args.temperature = 300 - args.selection_string = "all" - - run_manager = RunManager("mock_folder/job001") - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - data_logger = DataLogger() - group_molecules = MagicMock() - dihedral_analysis = MagicMock() - - # Instantiate VibrationalEntropy - ve = VibrationalEntropy( - run_manager, - args, - universe, - data_logger, - level_manager, - group_molecules, - dihedral_analysis, - universe_operations, - ) - - # Basic assertions to check initialization - self.assertIsInstance(ve, VibrationalEntropy) - self.assertEqual(ve._args.temperature, 300) - self.assertEqual(ve._args.bin_width, 0.1) - - # test when lambda is zero - def test_frequency_calculation_0(self): - """ - Test `frequency_calculation` with zero eigenvalue. - - Ensures that the method returns 0 when the input eigenvalue (lambda) is zero. - """ - lambdas = [0] - temp = 298 - - run_manager = RunManager("mock_folder/job001") - - ve = VibrationalEntropy( - run_manager, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - frequencies = ve.frequency_calculation(lambdas, temp) - - assert np.allclose(frequencies, [0.0]) - - def test_frequency_calculation_positive(self): - """ - Test `frequency_calculation` with positive eigenvalues. - - Verifies that the method correctly computes frequencies from a set of - positive eigenvalues at a given temperature. 
- """ - lambdas = np.array([585495.0917897299, 658074.5130064893, 782425.305888707]) - temp = 298 - - # Create a mock RunManager and set return value for get_KT2J - run_manager = RunManager("mock_folder/job001") - - # Instantiate VibrationalEntropy with mocks - ve = VibrationalEntropy( - run_manager, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - # Call the method under test - frequencies = ve.frequency_calculation(lambdas, temp) - - assert frequencies == pytest.approx( - [1899594266400.4016, 2013894687315.6213, 2195940987139.7097] - ) - - def test_frequency_calculation_filters_invalid(self): - """ - Test `frequency_calculation` filters out invalid eigenvalues. - - Ensures that negative, complex, and near-zero eigenvalues are excluded, - and frequencies are calculated only for valid ones. - """ - lambdas = np.array( - [585495.0917897299, -658074.5130064893, 0.0, 782425.305888707] - ) - temp = 298 - - # Create a mock RunManager and set return value for get_KT2J - run_manager = MagicMock() - run_manager.get_KT2J.return_value = 2.479e-21 # example value in Joules - - # Instantiate VibrationalEntropy with mocks - ve = VibrationalEntropy( - run_manager, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - # Call the method - frequencies = ve.frequency_calculation(lambdas, temp) - - # Expected: only two valid eigenvalues used - expected_lambdas = np.array([585495.0917897299, 782425.305888707]) - expected_frequencies = ( - 1 - / (2 * np.pi) - * np.sqrt(expected_lambdas / run_manager.get_KT2J.return_value) - ) - - # Assert frequencies match expected - np.testing.assert_allclose(frequencies, expected_frequencies, rtol=1e-5) - - def test_frequency_calculation_filters_invalid_with_warning(self): - """ - Test `frequency_calculation` filters out invalid eigenvalues and logs a warning. 
- - Ensures that negative, complex, and near-zero eigenvalues are excluded, - and a warning is logged about the exclusions. - """ - lambdas = np.array( - [585495.0917897299, -658074.5130064893, 0.0, 782425.305888707] - ) - temp = 298 - - run_manager = MagicMock() - run_manager.get_KT2J.return_value = 2.479e-21 # example value - - ve = VibrationalEntropy( - run_manager, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - with self.assertLogs("CodeEntropy.entropy", level="WARNING") as cm: - frequencies = ve.frequency_calculation(lambdas, temp) - - # Check that warning was logged - warning_messages = "\n".join(cm.output) - self.assertIn("invalid eigenvalues excluded", warning_messages) - - # Check that only valid frequencies are returned - expected_lambdas = np.array([585495.0917897299, 782425.305888707]) - expected_frequencies = ( - 1 - / (2 * np.pi) - * np.sqrt(expected_lambdas / run_manager.get_KT2J.return_value) - ) - np.testing.assert_allclose(frequencies, expected_frequencies, rtol=1e-5) - - def test_vibrational_entropy_calculation_force_not_highest(self): - """ - Test `vibrational_entropy_calculation` for a force matrix with - `highest_level=False`. - - Verifies that the entropy is correctly computed using mocked frequency values - and a dummy identity matrix, excluding the first six modes. 
- """ - # Mock RunManager - run_manager = MagicMock() - run_manager.change_lambda_units.return_value = np.array([1e-20] * 12) - run_manager.get_KT2J.return_value = 2.47e-21 - - # Instantiate VibrationalEntropy with mocks - ve = VibrationalEntropy( - run_manager, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - # Patch frequency_calculation to return known frequencies - ve.frequency_calculation = MagicMock(return_value=np.array([1.0] * 12)) - - # Create a dummy 12x12 matrix - matrix = np.identity(12) - - # Run the method - result = ve.vibrational_entropy_calculation( - matrix=matrix, matrix_type="force", temp=298, highest_level=False - ) - - # Manually compute expected entropy components - exponent = ve._PLANCK_CONST * 1.0 / 2.47e-21 - power_positive = np.exp(exponent) - power_negative = np.exp(-exponent) - S_component = exponent / (power_positive - 1) - np.log(1 - power_negative) - S_component *= ve._GAS_CONST - expected = S_component * 6 # sum of components[6:] - - self.assertAlmostEqual(result, expected, places=5) - - def test_vibrational_entropy_polymer_force(self): - """ - Test `vibrational_entropy_calculation` with a real force matrix and - `highest_level='yes'`. - - Ensures that the entropy is computed correctly for a small polymer system - using a known force matrix and temperature. 
- """ - matrix = np.array( - [ - [4.67476, -0.04069, -0.19714], - [-0.04069, 3.86300, -0.17922], - [-0.19714, -0.17922, 3.66307], - ] - ) - matrix_type = "force" - temp = 298 - highest_level = "yes" - - run_manager = RunManager("mock_folder/job001") - ve = VibrationalEntropy( - run_manager, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - S_vib = ve.vibrational_entropy_calculation( - matrix, matrix_type, temp, highest_level - ) - - assert S_vib == pytest.approx(52.88123410327823) - - def test_vibrational_entropy_polymer_torque(self): - """ - Test `vibrational_entropy_calculation` with a torque matrix and - `highest_level='yes'`. - - Verifies that the entropy is computed correctly for a torque matrix, - simulating rotational degrees of freedom. - """ - matrix = np.array( - [ - [6.69611, 0.39754, 0.57763], - [0.39754, 4.63265, 0.38648], - [0.57763, 0.38648, 6.34589], - ] - ) - matrix_type = "torque" - temp = 298 - highest_level = "yes" - - run_manager = RunManager("mock_folder/job001") - ve = VibrationalEntropy( - run_manager, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - S_vib = ve.vibrational_entropy_calculation( - matrix, matrix_type, temp, highest_level - ) - - assert S_vib == pytest.approx(48.45003266069881) - - def test_vibrational_entropy_calculation_forcetorqueTRANS(self): - """ - Test for matrix_type='forcetorqueTRANS': - - verifies S_vib_total = sum(S_components[:3]) - """ - run_manager = MagicMock() - run_manager.change_lambda_units.side_effect = lambda x: x - kT = 2.47e-21 - run_manager.get_KT2J.return_value = kT - - ve = VibrationalEntropy( - run_manager, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - orig_eigvals = la.eigvals - la.eigvals = lambda m: np.array( - [1.0] * 6 - ) # length 6 -> 6 frequencies/components - - try: - freqs = np.array([6.0, 5.0, 4.0, 
3.0, 2.0, 1.0]) - ve.frequency_calculation = MagicMock(return_value=freqs) - - matrix = np.identity(6) - - result = ve.vibrational_entropy_calculation( - matrix=matrix, - matrix_type="forcetorqueTRANS", - temp=298, - highest_level=True, - ) - - sorted_freqs = np.sort(freqs) - exponent = ve._PLANCK_CONST * sorted_freqs / kT - power_positive = np.exp(exponent) - power_negative = np.exp(-exponent) - S_components = exponent / (power_positive - 1) - np.log(1 - power_negative) - S_components *= ve._GAS_CONST - - expected = float(np.sum(S_components[:3])) - self.assertAlmostEqual(result, expected, places=6) - - finally: - la.eigvals = orig_eigvals - - def test_vibrational_entropy_calculation_forcetorqueROT(self): - """ - Test for matrix_type='forcetorqueROT': - - verifies S_vib_total = sum(S_components[3:]) - """ - run_manager = MagicMock() - run_manager.change_lambda_units.side_effect = lambda x: x - kT = 2.47e-21 - run_manager.get_KT2J.return_value = kT - - ve = VibrationalEntropy( - run_manager, - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - MagicMock(), - ) - - orig_eigvals = la.eigvals - la.eigvals = lambda m: np.array([1.0] * 6) - - try: - freqs = np.array([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]) - ve.frequency_calculation = MagicMock(return_value=freqs) - - matrix = np.identity(6) - - result = ve.vibrational_entropy_calculation( - matrix=matrix, - matrix_type="forcetorqueROT", - temp=298, - highest_level=True, - ) - - sorted_freqs = np.sort(freqs) - exponent = ve._PLANCK_CONST * sorted_freqs / kT - power_positive = np.exp(exponent) - power_negative = np.exp(-exponent) - S_components = exponent / (power_positive - 1) - np.log(1 - power_negative) - S_components *= ve._GAS_CONST - - expected = float(np.sum(S_components[3:])) - self.assertAlmostEqual(result, expected, places=3) - - finally: - la.eigvals = orig_eigvals - - def test_calculate_water_orientational_entropy(self): - """ - Test that orientational entropy values are 
correctly extracted from Sorient_dict - and logged per residue. - """ - Sorient_dict = {1: {"mol1": [1.0, 2]}, 2: {"mol1": [3.0, 4]}} - group_id = 0 - - self.entropy_manager._data_logger = MagicMock() - - self.entropy_manager._calculate_water_orientational_entropy( - Sorient_dict, group_id - ) - - expected_calls = [ - call(group_id, "mol1", "Water", "Orientational", 2, 1.0), - call(group_id, "mol1", "Water", "Orientational", 4, 3.0), - ] - - self.entropy_manager._data_logger.add_residue_data.assert_has_calls( - expected_calls, any_order=False - ) - assert self.entropy_manager._data_logger.add_residue_data.call_count == 2 - - def test_calculate_water_vibrational_translational_entropy(self): - mock_vibrations = MagicMock() - mock_vibrations.translational_S = { - ("res1", 10): [1.0, 2.0], - ("resB_invalid", 10): 4.0, - ("res2", 10): 3.0, - } - mock_covariances = MagicMock() - mock_covariances.counts = { - ("res1", "WAT"): 10, - # resB_invalid and res2 will use default count = 1 - } - - group_id = 0 - self.entropy_manager._data_logger = MagicMock() - - self.entropy_manager._calculate_water_vibrational_translational_entropy( - mock_vibrations, group_id, mock_covariances - ) - - expected_calls = [ - call(group_id, "res1", "Water", "Transvibrational", 10, 3.0), - call(group_id, "resB", "Water", "Transvibrational", 1, 4.0), - call(group_id, "res2", "Water", "Transvibrational", 1, 3.0), - ] - - self.entropy_manager._data_logger.add_residue_data.assert_has_calls( - expected_calls, any_order=False - ) - assert self.entropy_manager._data_logger.add_residue_data.call_count == 3 - - def test_calculate_water_vibrational_rotational_entropy(self): - mock_vibrations = MagicMock() - mock_vibrations.rotational_S = { - ("resA_101", 14): [2.0, 3.0], - ("resB_invalid", 14): 4.0, - ("resC", 14): 5.0, - } - mock_covariances = MagicMock() - mock_covariances.counts = {("resA_101", "WAT"): 14} - - group_id = 0 - self.entropy_manager._data_logger = MagicMock() - - 
self.entropy_manager._calculate_water_vibrational_rotational_entropy( - mock_vibrations, group_id, mock_covariances - ) - - expected_calls = [ - call(group_id, "resA", "Water", "Rovibrational", 14, 5.0), - call(group_id, "resB", "Water", "Rovibrational", 1, 4.0), - call(group_id, "resC", "Water", "Rovibrational", 1, 5.0), - ] - - self.entropy_manager._data_logger.add_residue_data.assert_has_calls( - expected_calls, any_order=False - ) - assert self.entropy_manager._data_logger.add_residue_data.call_count == 3 - - def test_empty_vibrational_entropy_dicts(self): - mock_vibrations = MagicMock() - mock_vibrations.translational_S = {} - mock_vibrations.rotational_S = {} - - group_id = 0 - mock_covariances = MagicMock() - mock_covariances.counts = {} - - self.entropy_manager._data_logger = MagicMock() - - self.entropy_manager._calculate_water_vibrational_translational_entropy( - mock_vibrations, group_id, mock_covariances - ) - self.entropy_manager._calculate_water_vibrational_rotational_entropy( - mock_vibrations, group_id, mock_covariances - ) - - self.entropy_manager._data_logger.add_residue_data.assert_not_called() - - @patch( - "waterEntropy.recipes.interfacial_solvent.get_interfacial_water_orient_entropy" - ) - def test_calculate_water_entropy(self, mock_get_entropy): - mock_vibrations = MagicMock() - mock_vibrations.translational_S = {("res1", "mol1"): 2.0} - mock_vibrations.rotational_S = {("res1", "mol1"): 3.0} - - mock_get_entropy.return_value = ( - {1: {"mol1": [1.0, 5]}}, # orientational - MagicMock(counts={("res1", "WAT"): 1}), - mock_vibrations, - None, - 1, - ) - - mock_universe = MagicMock() - self.entropy_manager._data_logger = MagicMock() - - self.entropy_manager._calculate_water_entropy(mock_universe, 0, 10, 5) - - expected_calls = [ - call(None, "mol1", "Water", "Orientational", 5, 1.0), - call(None, "res1", "Water", "Transvibrational", 1, 2.0), - call(None, "res1", "Water", "Rovibrational", 1, 3.0), - ] - - 
self.entropy_manager._data_logger.add_residue_data.assert_has_calls( - expected_calls, any_order=False - ) - assert self.entropy_manager._data_logger.add_residue_data.call_count == 3 - - @patch( - "waterEntropy.recipes.interfacial_solvent.get_interfacial_water_orient_entropy" - ) - def test_calculate_water_entropy_minimal(self, mock_get_entropy): - mock_vibrations = MagicMock() - mock_vibrations.translational_S = {("ACE_1", "WAT"): 10.0} - mock_vibrations.rotational_S = {("ACE_1", "WAT"): 2.0} - - mock_get_entropy.return_value = ( - {}, # no orientational entropy - MagicMock(counts={("ACE_1", "WAT"): 1}), - mock_vibrations, - None, - 1, - ) - - mock_logger = MagicMock() - self.entropy_manager._data_logger = mock_logger - - mock_residue = MagicMock(resnames=["WAT"]) - mock_selection = MagicMock(residues=mock_residue, atoms=[MagicMock()]) - mock_universe = MagicMock() - mock_universe.select_atoms.return_value = mock_selection - - self.entropy_manager._calculate_water_entropy( - mock_universe, 0, 10, 1, group_id=None - ) - - mock_logger.add_group_label.assert_called_once_with( - None, "WAT", len(mock_selection.residues), len(mock_selection.atoms) - ) - - @patch( - "waterEntropy.recipes.interfacial_solvent.get_interfacial_water_orient_entropy" - ) - def test_calculate_water_entropy_adds_resname(self, mock_get_entropy): - mock_vibrations = MagicMock() - mock_vibrations.translational_S = {("res1", "WAT"): 2.0} - mock_vibrations.rotational_S = {("res1", "WAT"): 3.0} - - mock_get_entropy.return_value = ( - {1: {"WAT": [1.0, 5]}}, # orientational - MagicMock(counts={("res1", "WAT"): 1}), - mock_vibrations, - None, - 1, - ) - - mock_water_selection = MagicMock() - mock_residues_group = MagicMock() - mock_residues_group.resnames = ["WAT"] - mock_water_selection.residues = mock_residues_group - mock_water_selection.atoms = [1, 2, 3] - mock_universe = MagicMock() - mock_universe.select_atoms.return_value = mock_water_selection - - group_id = 0 - 
self.entropy_manager._data_logger = MagicMock() - - self.entropy_manager._calculate_water_entropy( - mock_universe, start=0, end=1, step=1, group_id=group_id - ) - - self.entropy_manager._data_logger.add_group_label.assert_called_with( - group_id, - "WAT", - len(mock_water_selection.residues), - len(mock_water_selection.atoms), - ) - - # TODO test for error handling on invalid inputs - - -class TestConformationalEntropy(unittest.TestCase): - """ - Unit tests for the functionality of conformational entropy calculations. - """ - - def setUp(self): - """ - Set up test environment. - """ - self.test_dir = tempfile.mkdtemp(prefix="CodeEntropy_") - self.test_data_dir = os.path.dirname(data.__file__) - self.code_entropy = main - - # Change to test directory - self._orig_dir = os.getcwd() - os.chdir(self.test_dir) - - def tearDown(self): - """ - Clean up after each test. - """ - os.chdir(self._orig_dir) - if os.path.exists(self.test_dir): - shutil.rmtree(self.test_dir) - - def test_confirmational_entropy_init(self): - """ - Test initialization of the `ConformationalEntropy` class. - - Verifies that the object is correctly instantiated and that key arguments - such as temperature and bin width are properly assigned during initialization. 
- """ - # Mock dependencies - universe = MagicMock() - args = MagicMock() - args.bin_width = 0.1 - args.temperature = 300 - args.selection_string = "all" - - run_manager = RunManager("mock_folder/job001") - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - data_logger = DataLogger() - group_molecules = MagicMock() - - # Instantiate ConformationalEntropy - ce = ConformationalEntropy( - run_manager, - args, - universe, - data_logger, - level_manager, - group_molecules, - MagicMock(), - universe_operations, - ) - - # Basic assertions to check initialization - self.assertIsInstance(ce, ConformationalEntropy) - self.assertEqual(ce._args.temperature, 300) - self.assertEqual(ce._args.bin_width, 0.1) - - def test_conformational_entropy_calculation(self): - """ - Test `conformational_entropy_calculation` method to verify - correct entropy calculation from a simple discrete state array. - """ - - # Setup managers and arguments - args = MagicMock(bin_width=0.1, temperature=300, selection_string="all") - run_manager = RunManager("mock_folder/job001") - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - data_logger = DataLogger() - group_molecules = MagicMock() - - ce = ConformationalEntropy( - run_manager, - args, - MagicMock(), - data_logger, - level_manager, - group_molecules, - MagicMock(), - universe_operations, - ) - - # Create a simple array of states with known counts - states = np.array([0, 0, 1, 1, 1, 2]) # 2x state 0, 3x state 1, 1x state 2 - - # Manually compute expected entropy - probs = np.array([2 / 6, 3 / 6, 1 / 6]) - expected_entropy = -np.sum(probs * np.log(probs)) * ce._GAS_CONST - - # Run the method under test - result = ce.conformational_entropy_calculation(states) - - # Assert the result is close to expected entropy - self.assertAlmostEqual(result, expected_entropy, places=6) - - -class TestOrientationalEntropy(unittest.TestCase): - """ - Unit tests for the 
functionality of orientational entropy calculations. - """ - - def setUp(self): - """ - Set up test environment. - """ - self.test_dir = tempfile.mkdtemp(prefix="CodeEntropy_") - self.code_entropy = main - - # Change to test directory - self._orig_dir = os.getcwd() - os.chdir(self.test_dir) - - def tearDown(self): - """ - Clean up after each test. - """ - os.chdir(self._orig_dir) - if os.path.exists(self.test_dir): - shutil.rmtree(self.test_dir) - - def test_orientational_entropy_init(self): - """ - Test initialization of the `OrientationalEntropy` class. - - Verifies that the object is correctly instantiated and that key arguments - such as temperature and bin width are properly assigned during initialization. - """ - # Mock dependencies - universe = MagicMock() - args = MagicMock() - args.bin_width = 0.1 - args.temperature = 300 - args.selection_string = "all" - - run_manager = RunManager("mock_folder/job001") - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - data_logger = DataLogger() - group_molecules = MagicMock() - - # Instantiate OrientationalEntropy - oe = OrientationalEntropy( - run_manager, - args, - universe, - data_logger, - level_manager, - group_molecules, - MagicMock(), - universe_operations, - ) - - # Basic assertions to check initialization - self.assertIsInstance(oe, OrientationalEntropy) - self.assertEqual(oe._args.temperature, 300) - self.assertEqual(oe._args.bin_width, 0.1) - - def test_orientational_entropy_calculation(self): - """ - Tests that `orientational_entropy_calculation` correctly computes the total - orientational entropy for a given dictionary of neighboring species using - the internal gas constant. 
- """ - # Setup a mock neighbours dictionary - neighbours_dict = { - "ligandA": 2, - "ligandB": 3, - } - - # Create an instance of OrientationalEntropy with dummy dependencies - oe = OrientationalEntropy(None, None, None, None, None, None, None, None) - - # Run the method - result = oe.orientational_entropy_calculation(neighbours_dict) - - # Manually compute expected result using the class's internal gas constant - expected = ( - math.log(math.sqrt((2**3) * math.pi)) - + math.log(math.sqrt((3**3) * math.pi)) - ) * oe._GAS_CONST - - # Assert the result is as expected - self.assertAlmostEqual(result, expected, places=6) - - def test_orientational_entropy_water_branch_is_covered(self): - """ - Tests that the placeholder branch for water molecules is executed to ensure - coverage of the `if neighbour in [...]` block. - """ - neighbours_dict = {"H2O": 1} # Matches the condition exactly - - oe = OrientationalEntropy(None, None, None, None, None, None, None, None) - result = oe.orientational_entropy_calculation(neighbours_dict) - - # Since the logic is skipped, total entropy should be 0.0 - self.assertEqual(result, 0.0) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_CodeEntropy/test_group_molecules.py b/tests/test_CodeEntropy/test_group_molecules.py deleted file mode 100644 index 6f158058..00000000 --- a/tests/test_CodeEntropy/test_group_molecules.py +++ /dev/null @@ -1,78 +0,0 @@ -import unittest -from unittest.mock import MagicMock - -import numpy as np - -from CodeEntropy.group_molecules import GroupMolecules -from tests.test_CodeEntropy.test_base import BaseTestCase - - -class TestGroupMolecules(BaseTestCase): - """ - Unit tests for GroupMolecules. - """ - - def setUp(self): - super().setUp() - self.group_molecules = GroupMolecules() - - def test_by_none_returns_individual_groups(self): - """ - Test _by_none returns each molecule in its own group when grouping is 'each'. 
- """ - mock_universe = MagicMock() - # Simulate universe.atoms.fragments has 3 molecules - mock_universe.atoms.fragments = [MagicMock(), MagicMock(), MagicMock()] - - groups = self.group_molecules._by_none(mock_universe) - expected = {0: [0], 1: [1], 2: [2]} - self.assertEqual(groups, expected) - - def test_by_molecules_groups_by_chemical_type(self): - """ - Test _by_molecules groups molecules with identical atom counts and names - together. - """ - mock_universe = MagicMock() - - fragment0 = MagicMock() - fragment0.names = np.array(["H", "O", "H"]) - fragment1 = MagicMock() - fragment1.names = np.array(["H", "O", "H"]) - fragment2 = MagicMock() - fragment2.names = np.array(["C", "C", "H", "H"]) - - mock_universe.atoms.fragments = [fragment0, fragment1, fragment2] - - groups = self.group_molecules._by_molecules(mock_universe) - - # Expect first two grouped, third separate - self.assertIn(0, groups) - self.assertIn(2, groups) - self.assertCountEqual(groups[0], [0, 1]) - self.assertEqual(groups[2], [2]) - - def test_grouping_molecules_dispatches_correctly(self): - """ - Test grouping_molecules method dispatches to correct grouping strategy. 
- """ - mock_universe = MagicMock() - mock_universe.atoms.fragments = [MagicMock()] # Just 1 molecule to keep simple - - # When grouping='each', calls _by_none - groups = self.group_molecules.grouping_molecules(mock_universe, "each") - self.assertEqual(groups, {0: [0]}) - - # When grouping='molecules', calls _by_molecules (mock to test call) - self.group_molecules._by_molecules = MagicMock(return_value={"mocked": [42]}) - groups = self.group_molecules.grouping_molecules(mock_universe, "molecules") - self.group_molecules._by_molecules.assert_called_once_with(mock_universe) - self.assertEqual(groups, {"mocked": [42]}) - - # If grouping unknown, should return empty dict - groups = self.group_molecules.grouping_molecules(mock_universe, "unknown") - self.assertEqual(groups, {}) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_CodeEntropy/test_levels.py b/tests/test_CodeEntropy/test_levels.py deleted file mode 100644 index eb853df0..00000000 --- a/tests/test_CodeEntropy/test_levels.py +++ /dev/null @@ -1,1726 +0,0 @@ -from unittest.mock import MagicMock, patch - -import numpy as np - -from CodeEntropy.axes import AxesManager -from CodeEntropy.levels import LevelManager -from CodeEntropy.mda_universe_operations import UniverseOperations -from tests.test_CodeEntropy.test_base import BaseTestCase - - -class TestLevels(BaseTestCase): - """ - Unit tests for Levels. - """ - - def setUp(self): - super().setUp() - - def test_select_levels(self): - """ - Test `select_levels` with a mocked data container containing two molecules: - - The first molecule has 2 atoms and 1 residue (should return 'united_atom' and - 'residue'). - - The second molecule has 3 atoms and 2 residues (should return all three - levels). - - Asserts that the number of molecules and the levels list match expected values. 
- """ - # Create a mock data_container - data_container = MagicMock() - - # Mock fragments (2 molecules) - fragment1 = MagicMock() - fragment2 = MagicMock() - - # Mock select_atoms return values - atoms1 = MagicMock() - atoms1.__len__.return_value = 2 - atoms1.residues = [1] # 1 residue - - atoms2 = MagicMock() - atoms2.__len__.return_value = 3 - atoms2.residues = [1, 2] # 2 residues - - fragment1.select_atoms.return_value = atoms1 - fragment2.select_atoms.return_value = atoms2 - - data_container.atoms.fragments = [fragment1, fragment2] - - universe_operations = UniverseOperations() - - # Import the class and call the method - level_manager = LevelManager(universe_operations) - number_molecules, levels = level_manager.select_levels(data_container) - - # Assertions - self.assertEqual(number_molecules, 2) - self.assertEqual( - levels, [["united_atom", "residue"], ["united_atom", "residue", "polymer"]] - ) - - def test_get_matrices(self): - """ - Atomic unit test for LevelManager.get_matrices: - - AxesManager is mocked - - No inertia / MDAnalysis math - - Verifies block matrix construction and shape only - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - # Two beads - bead1 = MagicMock() - bead1.principal_axes.return_value = np.ones(3) - - bead2 = MagicMock() - bead2.principal_axes.return_value = np.ones(3) - - level_manager.get_beads = MagicMock(return_value=[bead1, bead2]) - - level_manager.get_weighted_forces = MagicMock( - return_value=np.array([1.0, 2.0, 3.0]) - ) - level_manager.get_weighted_torques = MagicMock( - return_value=np.array([0.5, 1.5, 2.5]) - ) - - # Deterministic 3x3 submatrix for every (i,j) call - I3 = np.identity(3) - level_manager.create_submatrix = MagicMock(return_value=I3) - - data_container = MagicMock() - data_container.atoms = MagicMock() - data_container.atoms.principal_axes.return_value = np.ones(3) - - dummy_trans_axes = np.eye(3) - dummy_rot_axes = np.eye(3) - dummy_center = 
np.zeros(3) - dummy_moi = np.eye(3) - - with patch("CodeEntropy.levels.AxesManager") as AxesManagerMock: - axes = AxesManagerMock.return_value - axes.get_residue_axes.return_value = ( - dummy_trans_axes, - dummy_rot_axes, - dummy_center, - dummy_moi, - ) - - force_matrix, torque_matrix = level_manager.get_matrices( - data_container=data_container, - level="residue", - highest_level=True, - force_matrix=None, - torque_matrix=None, - force_partitioning=0.5, - customised_axes=True, - ) - - # Shape: 2 beads × 3 dof => 6×6 - assert force_matrix.shape == (6, 6) - assert torque_matrix.shape == (6, 6) - - # Expected block structure when every block is I3: - expected = np.block([[I3, I3], [I3, I3]]) - np.testing.assert_array_equal(force_matrix, expected) - np.testing.assert_array_equal(torque_matrix, expected) - - # Lightweight behavioral assertions - level_manager.get_beads.assert_called_once_with(data_container, "residue") - assert axes.get_residue_axes.call_count == 2 - - # For 2 beads: (0,0), (0,1), (1,1) => 3 pairs; - # each pair calls create_submatrix twice (force+torque) - assert level_manager.create_submatrix.call_count == 6 - - def test_get_matrices_force_shape_mismatch(self): - """ - Test that get_matrices raises a ValueError when the provided force_matrix - has a shape mismatch with the computed force block matrix. 
- """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - # Two beads -> force_block will be 6x6 - bead1 = MagicMock() - bead1.principal_axes.return_value = np.ones(3) - - bead2 = MagicMock() - bead2.principal_axes.return_value = np.ones(3) - - level_manager.get_beads = MagicMock(return_value=[bead1, bead2]) - - level_manager.get_weighted_forces = MagicMock( - return_value=np.array([1.0, 2.0, 3.0]) - ) - level_manager.get_weighted_torques = MagicMock( - return_value=np.array([0.5, 1.5, 2.5]) - ) - - level_manager.create_submatrix = MagicMock(return_value=np.identity(3)) - - data_container = MagicMock() - data_container.atoms = MagicMock() - data_container.atoms.principal_axes.return_value = np.ones(3) - - bad_force_matrix = np.zeros((3, 3)) - correct_torque_matrix = np.zeros((6, 6)) - - dummy_trans_axes = np.eye(3) - dummy_rot_axes = np.eye(3) - dummy_center = np.zeros(3) - dummy_moi = np.eye(3) - - with patch("CodeEntropy.levels.AxesManager") as AxesManagerMock: - axes = AxesManagerMock.return_value - axes.get_residue_axes.return_value = ( - dummy_trans_axes, - dummy_rot_axes, - dummy_center, - dummy_moi, - ) - - with self.assertRaises(ValueError) as context: - level_manager.get_matrices( - data_container=data_container, - level="residue", - highest_level=True, - force_matrix=bad_force_matrix, - torque_matrix=correct_torque_matrix, - force_partitioning=0.5, - customised_axes=True, - ) - - assert "force matrix shape" in str(context.exception) - - def test_get_matrices_torque_shape_mismatch(self): - """ - Test that get_matrices raises a ValueError when the provided torque_matrix - has a shape mismatch with the computed torque block matrix. 
- """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - bead1 = MagicMock() - bead1.principal_axes.return_value = np.ones(3) - - bead2 = MagicMock() - bead2.principal_axes.return_value = np.ones(3) - - level_manager.get_beads = MagicMock(return_value=[bead1, bead2]) - - level_manager.get_weighted_forces = MagicMock( - return_value=np.array([1.0, 2.0, 3.0]) - ) - level_manager.get_weighted_torques = MagicMock( - return_value=np.array([0.5, 1.5, 2.5]) - ) - level_manager.create_submatrix = MagicMock(return_value=np.identity(3)) - - data_container = MagicMock() - data_container.atoms = MagicMock() - data_container.atoms.principal_axes.return_value = np.ones(3) - - correct_force_matrix = np.zeros((6, 6)) - bad_torque_matrix = np.zeros((3, 3)) # Incorrect shape (should be 6x6) - - # Mock AxesManager return tuple to satisfy unpacking - dummy_trans_axes = np.eye(3) - dummy_rot_axes = np.eye(3) - dummy_center = np.zeros(3) - dummy_moi = np.eye(3) - - with patch("CodeEntropy.levels.AxesManager") as AxesManagerMock: - axes = AxesManagerMock.return_value - axes.get_residue_axes.return_value = ( - dummy_trans_axes, - dummy_rot_axes, - dummy_center, - dummy_moi, - ) - - with self.assertRaises(ValueError) as context: - level_manager.get_matrices( - data_container=data_container, - level="residue", - highest_level=True, - force_matrix=correct_force_matrix, - torque_matrix=bad_torque_matrix, - force_partitioning=0.5, - customised_axes=True, - ) - - assert "torque matrix shape" in str(context.exception) - - def test_get_matrices_torque_consistency(self): - """ - Test that get_matrices returns consistent force and torque matrices - when called multiple times with the same inputs. 
- """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - bead1 = MagicMock() - bead1.principal_axes.return_value = np.ones(3) - - bead2 = MagicMock() - bead2.principal_axes.return_value = np.ones(3) - - level_manager.get_beads = MagicMock(return_value=[bead1, bead2]) - - level_manager.get_weighted_forces = MagicMock( - return_value=np.array([1.0, 2.0, 3.0]) - ) - level_manager.get_weighted_torques = MagicMock( - return_value=np.array([0.5, 1.5, 2.5]) - ) - level_manager.create_submatrix = MagicMock(return_value=np.identity(3)) - - data_container = MagicMock() - data_container.atoms = MagicMock() - data_container.atoms.principal_axes.return_value = np.ones(3) - - initial_force_matrix = np.zeros((6, 6)) - initial_torque_matrix = np.zeros((6, 6)) - - # Mock AxesManager return tuple (unpacked by get_matrices) - dummy_trans_axes = np.eye(3) - dummy_rot_axes = np.eye(3) - dummy_center = np.zeros(3) - dummy_moi = np.eye(3) - - with patch("CodeEntropy.levels.AxesManager") as AxesManagerMock: - axes = AxesManagerMock.return_value - axes.get_residue_axes.return_value = ( - dummy_trans_axes, - dummy_rot_axes, - dummy_center, - dummy_moi, - ) - - force_matrix_1, torque_matrix_1 = level_manager.get_matrices( - data_container=data_container, - level="residue", - highest_level=True, - force_matrix=initial_force_matrix.copy(), - torque_matrix=initial_torque_matrix.copy(), - force_partitioning=0.5, - customised_axes=True, - ) - - force_matrix_2, torque_matrix_2 = level_manager.get_matrices( - data_container=data_container, - level="residue", - highest_level=True, - force_matrix=initial_force_matrix.copy(), - torque_matrix=initial_torque_matrix.copy(), - force_partitioning=0.5, - customised_axes=True, - ) - - np.testing.assert_array_equal(force_matrix_1, force_matrix_2) - np.testing.assert_array_equal(torque_matrix_1, torque_matrix_2) - - assert force_matrix_1.shape == (6, 6) - assert torque_matrix_1.shape == (6, 6) - - def 
test_get_matrices_united_atom_customised_axes(self): - """ - Test that: level='united_atom' with customised_axes=True - Verifies: - - UA axes path is taken - - block matrix shape is correct for 1 bead (3x3) - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - bead = MagicMock() - level_manager.get_beads = MagicMock(return_value=[bead]) - - level_manager.get_weighted_forces = MagicMock( - return_value=np.array([1.0, 2.0, 3.0]) - ) - level_manager.get_weighted_torques = MagicMock( - return_value=np.array([0.5, 1.5, 2.5]) - ) - - I3 = np.identity(3) - level_manager.create_submatrix = MagicMock(return_value=I3) - - data_container = MagicMock() - data_container.atoms = MagicMock() - data_container.atoms.principal_axes.return_value = np.ones(3) - - dummy_trans_axes = np.eye(3) - dummy_rot_axes = np.eye(3) - dummy_center = np.zeros(3) - dummy_moi = np.array([1.0, 1.0, 1.0]) - - with patch("CodeEntropy.levels.AxesManager") as AxesManagerMock: - axes = AxesManagerMock.return_value - axes.get_UA_axes.return_value = ( - dummy_trans_axes, - dummy_rot_axes, - dummy_center, - dummy_moi, - ) - - force_matrix, torque_matrix = level_manager.get_matrices( - data_container=data_container, - level="united_atom", - highest_level=True, - force_matrix=None, - torque_matrix=None, - force_partitioning=0.5, - customised_axes=True, - ) - - assert force_matrix.shape == (3, 3) - assert torque_matrix.shape == (3, 3) - np.testing.assert_array_equal(force_matrix, I3) - np.testing.assert_array_equal(torque_matrix, I3) - - axes.get_UA_axes.assert_called_once() - assert axes.get_residue_axes.call_count == 0 - - def test_get_matrices_non_customised_axes_path_atomic(self): - """ - Tests that `customised_axes=False` triggers the non-customised axes path. 
- - Verifies that: - - translational axes are taken from `data_container.atoms.principal_axes()` - - rotational axes are taken from `bead.principal_axes()` (real-valued) - - bead moment of inertia and center of mass are queried - - force and torque matrices are assembled with size (3N, 3N) for N beads - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - bead1, bead2 = MagicMock(), MagicMock() - bead1.principal_axes.return_value = np.eye(3) * (1 + 2j) - bead2.principal_axes.return_value = np.eye(3) * (1 + 2j) - bead1.center_of_mass.return_value = np.zeros(3) - bead2.center_of_mass.return_value = np.zeros(3) - bead1.moment_of_inertia.return_value = np.eye(3) - bead2.moment_of_inertia.return_value = np.eye(3) - - level_manager.get_beads = MagicMock(return_value=[bead1, bead2]) - level_manager.get_weighted_forces = MagicMock( - return_value=np.array([1.0, 2.0, 3.0]) - ) - level_manager.get_weighted_torques = MagicMock( - return_value=np.array([0.5, 1.5, 2.5]) - ) - level_manager.create_submatrix = MagicMock(return_value=np.eye(3)) - - data_container = MagicMock() - data_container.atoms = MagicMock() - data_container.atoms.principal_axes.return_value = np.eye(3) - - with ( - patch("CodeEntropy.levels.make_whole", autospec=True), - patch( - "CodeEntropy.levels.np.linalg.eig", - return_value=(np.array([1.0, 3.0, 2.0]), None), - ), - ): - force_matrix, torque_matrix = level_manager.get_matrices( - data_container=data_container, - level="polymer", - highest_level=True, - force_matrix=None, - torque_matrix=None, - force_partitioning=0.5, - customised_axes=False, - ) - - data_container.atoms.principal_axes.assert_called() - bead1.principal_axes.assert_called() - bead2.principal_axes.assert_called() - bead1.center_of_mass.assert_called() - bead2.center_of_mass.assert_called() - bead1.moment_of_inertia.assert_called() - bead2.moment_of_inertia.assert_called() - - assert force_matrix.shape == (6, 6) - assert 
torque_matrix.shape == (6, 6) - - def test_get_matrices_accepts_existing_same_shape(self): - """ - Test that: if force_matrix and torque_matrix are provided with correct shape, - no error is raised and returned matrices match the newly computed blocks. - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - bead1 = MagicMock() - bead2 = MagicMock() - level_manager.get_beads = MagicMock(return_value=[bead1, bead2]) - - level_manager.get_weighted_forces = MagicMock( - return_value=np.array([1.0, 2.0, 3.0]) - ) - level_manager.get_weighted_torques = MagicMock( - return_value=np.array([0.5, 1.5, 2.5]) - ) - - I3 = np.identity(3) - level_manager.create_submatrix = MagicMock(return_value=I3) - - data_container = MagicMock() - data_container.atoms = MagicMock() - data_container.atoms.principal_axes.return_value = np.ones(3) - - dummy_trans_axes = np.eye(3) - dummy_rot_axes = np.eye(3) - dummy_center = np.zeros(3) - dummy_moi = np.array([1.0, 1.0, 1.0]) - - existing_force = np.zeros((6, 6)) - existing_torque = np.zeros((6, 6)) - - with patch("CodeEntropy.levels.AxesManager") as AxesManagerMock: - axes = AxesManagerMock.return_value - axes.get_residue_axes.return_value = ( - dummy_trans_axes, - dummy_rot_axes, - dummy_center, - dummy_moi, - ) - - force_matrix, torque_matrix = level_manager.get_matrices( - data_container=data_container, - level="residue", - highest_level=True, - force_matrix=existing_force, - torque_matrix=existing_torque, - force_partitioning=0.5, - customised_axes=True, - ) - - expected = np.block([[I3, I3], [I3, I3]]) - np.testing.assert_array_equal(force_matrix, expected) - np.testing.assert_array_equal(torque_matrix, expected) - - def test_get_combined_forcetorque_matrices_residue_customised_init(self): - """ - Test: level='residue', customised_axes=True uses AxesManager.get_residue_axes - and returns a (6N x 6N) block matrix for N beads. 
- """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - bead1 = MagicMock() - bead2 = MagicMock() - level_manager.get_beads = MagicMock(return_value=[bead1, bead2]) - - wf = np.array([1.0, 2.0, 3.0]) - wt = np.array([4.0, 5.0, 6.0]) - level_manager.get_weighted_forces = MagicMock(return_value=wf) - level_manager.get_weighted_torques = MagicMock(return_value=wt) - - I6 = np.identity(6) - level_manager.create_FTsubmatrix = MagicMock(return_value=I6) - - data_container = MagicMock() - data_container.atoms = MagicMock() - - dummy_trans_axes = np.eye(3) - dummy_rot_axes = np.eye(3) - dummy_center = np.zeros(3) - dummy_moi = np.array([1.0, 1.0, 1.0]) - - with patch("CodeEntropy.levels.AxesManager") as AxesManagerMock: - axes = AxesManagerMock.return_value - axes.get_residue_axes.return_value = ( - dummy_trans_axes, - dummy_rot_axes, - dummy_center, - dummy_moi, - ) - - ft_matrix = level_manager.get_combined_forcetorque_matrices( - data_container=data_container, - level="residue", - highest_level=True, - forcetorque_matrix=None, - force_partitioning=0.5, - customised_axes=True, - ) - - assert ft_matrix.shape == (12, 12) - - expected = np.block([[I6, I6], [I6, I6]]) - np.testing.assert_array_equal(ft_matrix, expected) - - assert axes.get_residue_axes.call_count == 2 - assert level_manager.create_FTsubmatrix.call_count == 3 - - def test_get_combined_forcetorque_matrices_noncustomised_axes_path(self): - """ - Test that: customised_axes=False forces else-path: - - make_whole(data_container.atoms) and make_whole(bead) called - - trans_axes = data_container.atoms.principal_axes() - - rot_axes, moment_of_inertia = AxesManager.get_vanilla_axes(bead) - - center = bead.center_of_mass(unwrap=True) - - FT block matrix assembled via create_FTsubmatrix and np.block - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - bead1 = MagicMock(name="bead1") - bead2 = MagicMock(name="bead2") 
- beads = [bead1, bead2] - - level_manager.get_beads = MagicMock(return_value=beads) - - data_container = MagicMock(name="data_container") - data_container.atoms = MagicMock(name="atoms") - data_container.atoms.principal_axes.return_value = np.eye(3) - - # Forces/torques are 3-vectors -> concatenated to length 6 - level_manager.get_weighted_forces = MagicMock( - side_effect=[ - np.array([1.0, 2.0, 3.0]), - np.array([1.1, 2.1, 3.1]), - ] - ) - level_manager.get_weighted_torques = MagicMock( - side_effect=[ - np.array([4.0, 5.0, 6.0]), - np.array([4.1, 5.1, 6.1]), - ] - ) - - level_manager.create_FTsubmatrix = MagicMock(return_value=np.identity(6)) - - rot_axes_expected = np.eye(3) - moi_expected = np.array([3.0, 2.0, 1.0]) - - with ( - patch("CodeEntropy.levels.make_whole", autospec=True) as mw_mock, - patch( - "CodeEntropy.axes.AxesManager.get_vanilla_axes", - autospec=True, - return_value=(rot_axes_expected, moi_expected), - ) as vanilla_mock, - ): - bead1.center_of_mass.return_value = np.zeros(3) - bead2.center_of_mass.return_value = np.zeros(3) - - ft_matrix = level_manager.get_combined_forcetorque_matrices( - data_container=data_container, - level="polymer", - highest_level=True, - forcetorque_matrix=None, - force_partitioning=0.5, - customised_axes=False, - ) - - data_container.atoms.principal_axes.assert_called() - bead1.center_of_mass.assert_called_with(unwrap=True) - bead2.center_of_mass.assert_called_with(unwrap=True) - - assert vanilla_mock.call_count == 2 # once per bead - - # make_whole is called twice per bead: on data_container.atoms and on bead - assert mw_mock.call_count == 4 - mw_mock.assert_any_call(data_container.atoms) - mw_mock.assert_any_call(bead1) - mw_mock.assert_any_call(bead2) - - # result shape: (6N, 6N) with N=2 - assert ft_matrix.shape == (12, 12) - - def test_get_combined_forcetorque_matrices_shape_mismatch_raises(self): - """ - Test that: raises ValueError when existing forcetorque_matrix has wrong shape. 
- """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - bead1 = MagicMock() - bead2 = MagicMock() - level_manager.get_beads = MagicMock(return_value=[bead1, bead2]) - - level_manager.get_weighted_forces = MagicMock( - return_value=np.array([1.0, 2.0, 3.0]) - ) - level_manager.get_weighted_torques = MagicMock( - return_value=np.array([4.0, 5.0, 6.0]) - ) - level_manager.create_FTsubmatrix = MagicMock(return_value=np.identity(6)) - - data_container = MagicMock() - data_container.atoms = MagicMock() - - dummy_trans_axes = np.eye(3) - dummy_rot_axes = np.eye(3) - dummy_center = np.zeros(3) - dummy_moi = np.array([1.0, 1.0, 1.0]) - - bad_existing = np.zeros((6, 6)) - - with patch("CodeEntropy.levels.AxesManager") as AxesManagerMock: - axes = AxesManagerMock.return_value - axes.get_residue_axes.return_value = ( - dummy_trans_axes, - dummy_rot_axes, - dummy_center, - dummy_moi, - ) - - with self.assertRaises(ValueError) as ctx: - level_manager.get_combined_forcetorque_matrices( - data_container=data_container, - level="residue", - highest_level=True, - forcetorque_matrix=bad_existing, - force_partitioning=0.5, - customised_axes=True, - ) - - assert "forcetorque matrix shape" in str(ctx.exception) - - def test_get_combined_forcetorque_matrices_existing_same_shape(self): - """ - Test that: if existing forcetorque_matrix has correct shape, function returns - the newly computed block (no ValueError). 
- """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - bead1 = MagicMock() - bead2 = MagicMock() - level_manager.get_beads = MagicMock(return_value=[bead1, bead2]) - - level_manager.get_weighted_forces = MagicMock( - return_value=np.array([1.0, 2.0, 3.0]) - ) - level_manager.get_weighted_torques = MagicMock( - return_value=np.array([4.0, 5.0, 6.0]) - ) - - I6 = np.identity(6) - level_manager.create_FTsubmatrix = MagicMock(return_value=I6) - - data_container = MagicMock() - data_container.atoms = MagicMock() - - dummy_trans_axes = np.eye(3) - dummy_rot_axes = np.eye(3) - dummy_center = np.zeros(3) - dummy_moi = np.array([1.0, 1.0, 1.0]) - - existing_ok = np.zeros((12, 12)) - - with patch("CodeEntropy.levels.AxesManager") as AxesManagerMock: - axes = AxesManagerMock.return_value - axes.get_residue_axes.return_value = ( - dummy_trans_axes, - dummy_rot_axes, - dummy_center, - dummy_moi, - ) - - ft_matrix = level_manager.get_combined_forcetorque_matrices( - data_container=data_container, - level="residue", - highest_level=True, - forcetorque_matrix=existing_ok, - force_partitioning=0.5, - customised_axes=True, - ) - - expected = np.block([[I6, I6], [I6, I6]]) - np.testing.assert_array_equal(ft_matrix, expected) - - def test_get_beads_polymer_level(self): - """ - Test `get_beads` for 'polymer' level. - Should return a single atom group representing the whole system. - """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - data_container = MagicMock() - mock_atom_group = MagicMock() - - data_container.select_atoms.return_value = mock_atom_group - - result = level_manager.get_beads(data_container, level="polymer") - - self.assertEqual(len(result), 1) - self.assertEqual(result[0], mock_atom_group) - data_container.select_atoms.assert_called_once_with("all") - - def test_get_beads_residue_level(self): - """ - Test `get_beads` for 'residue' level. 
- Should return one atom group per residue. - """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - data_container = MagicMock() - data_container.residues = [0, 1, 2] # 3 residues - mock_atom_group = MagicMock() - data_container.select_atoms.return_value = mock_atom_group - - result = level_manager.get_beads(data_container, level="residue") - - self.assertEqual(len(result), 3) - self.assertTrue(all(bead == mock_atom_group for bead in result)) - self.assertEqual(data_container.select_atoms.call_count, 3) - - def test_get_beads_united_atom_level(self): - """ - Test `get_beads` for 'united_atom' level. - Should return one bead per heavy atom, including bonded hydrogens. - """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - data_container = MagicMock() - heavy_atoms = [MagicMock(index=i) for i in range(3)] - data_container.select_atoms.side_effect = [ - heavy_atoms, - "bead0", - "bead1", - "bead2", - ] - - result = level_manager.get_beads(data_container, level="united_atom") - - self.assertEqual(len(result), 3) - self.assertEqual(result, ["bead0", "bead1", "bead2"]) - self.assertEqual( - data_container.select_atoms.call_count, 4 - ) # 1 for heavy_atoms + 3 beads - - def test_get_beads_hydrogen_molecule(self): - """ - Test `get_beads` for 'united_atom' level. - Should return one bead for molecule with no heavy atoms. 
- """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - data_container = MagicMock() - heavy_atoms = [] - data_container.select_atoms.side_effect = [ - heavy_atoms, - "hydrogen", - ] - - result = level_manager.get_beads(data_container, level="united_atom") - - self.assertEqual(len(result), 1) - self.assertEqual(result, ["hydrogen"]) - self.assertEqual( - data_container.select_atoms.call_count, 2 - ) # 1 for heavy_atoms + 1 beads - - def test_get_weighted_force_with_partitioning(self): - """ - Test correct weighted force calculation with partitioning enabled. - """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - atom = MagicMock() - atom.index = 0 - - bead = MagicMock() - bead.atoms = [atom] - bead.total_mass.return_value = 4.0 - - data_container = MagicMock() - data_container.atoms.__getitem__.return_value.force = np.array([2.0, 0.0, 0.0]) - - trans_axes = np.identity(3) - - result = level_manager.get_weighted_forces( - data_container, bead, trans_axes, highest_level=True, force_partitioning=0.5 - ) - - expected = (0.5 * np.array([2.0, 0.0, 0.0])) / np.sqrt(4.0) - np.testing.assert_array_almost_equal(result, expected) - - def test_get_weighted_force_without_partitioning(self): - """ - Test correct weighted force calculation with partitioning disabled. 
- """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - atom = MagicMock() - atom.index = 0 - - bead = MagicMock() - bead.atoms = [atom] - bead.total_mass.return_value = 1.0 - - data_container = MagicMock() - data_container.atoms.__getitem__.return_value.force = np.array([3.0, 0.0, 0.0]) - - trans_axes = np.identity(3) - - result = level_manager.get_weighted_forces( - data_container, - bead, - trans_axes, - highest_level=False, - force_partitioning=0.5, - ) - - expected = np.array([3.0, 0.0, 0.0]) / np.sqrt(1.0) - np.testing.assert_array_almost_equal(result, expected) - - def test_get_weighted_forces_zero_mass_raises_value_error(self): - """ - Test that a zero mass raises a ValueError. - """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - atom = MagicMock() - atom.index = 0 - - bead = MagicMock() - bead.atoms = [atom] - bead.total_mass.return_value = 0.0 - - data_container = MagicMock() - data_container.atoms.__getitem__.return_value.force = np.array([1.0, 0.0, 0.0]) - - trans_axes = np.identity(3) - - with self.assertRaises(ValueError): - level_manager.get_weighted_forces( - data_container, - bead, - trans_axes, - highest_level=True, - force_partitioning=0.5, - ) - - def test_get_weighted_forces_negative_mass_raises_value_error(self): - """ - Test that a negative mass raises a ValueError. 
- """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - atom = MagicMock() - atom.index = 0 - - bead = MagicMock() - bead.atoms = [atom] - bead.total_mass.return_value = -1.0 - - data_container = MagicMock() - data_container.atoms.__getitem__.return_value.force = np.array([1.0, 0.0, 0.0]) - - trans_axes = np.identity(3) - - with self.assertRaises(ValueError): - level_manager.get_weighted_forces( - data_container, - bead, - trans_axes, - highest_level=True, - force_partitioning=0.5, - ) - - def test_get_weighted_torques_weighted_torque_basic(self): - """ - Test basic weighted torque calculation for a single-atom bead. - - Setup: - r = [1, 0, 0], F = [0, 1, 0] => r x F = [0, 0, 1] - With force_partitioning=0.5, rot_axes=I, MOI=[1,1,1], - expected weighted torque is [0, 0, 0.5]. - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - axes_manager = AxesManager() - - bead = MagicMock() - bead.positions = np.array([[1.0, 0.0, 0.0]]) - bead.forces = np.array([[0.0, 1.0, 0.0]]) - bead.dimensions = np.array([10.0, 10.0, 10.0]) - - rot_axes = np.eye(3) - center = np.zeros(3) - force_partitioning = 0.5 - moment_of_inertia = np.array([1.0, 1.0, 1.0]) - - with patch.object( - AxesManager, "get_vector", return_value=bead.positions - center - ) as gv_mock: - result = level_manager.get_weighted_torques( - bead=bead, - rot_axes=rot_axes, - center=center, - force_partitioning=force_partitioning, - moment_of_inertia=moment_of_inertia, - axes_manager=axes_manager, - ) - - gv_mock.assert_called() - - expected = np.array([0.0, 0.0, 0.5]) - np.testing.assert_allclose(result, expected) - - def test_get_weighted_torques_zero_torque_skips_division(self): - """ - Test that zero torque components skip division and remain zero. 
- """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - axes_manager = AxesManager() - - bead = MagicMock() - bead.positions = np.array([[0.0, 0.0, 0.0]]) - bead.forces = np.array([[0.0, 0.0, 0.0]]) - bead.dimensions = np.array([10.0, 10.0, 10.0]) - - rot_axes = np.identity(3) - center = np.array([0.0, 0.0, 0.0]) - force_partitioning = 0.5 - moment_of_inertia = np.array([1.0, 2.0, 3.0]) - - with patch.object( - AxesManager, "get_vector", return_value=bead.positions - center - ): - result = level_manager.get_weighted_torques( - bead=bead, - rot_axes=rot_axes, - center=center, - force_partitioning=force_partitioning, - moment_of_inertia=moment_of_inertia, - axes_manager=axes_manager, - ) - - np.testing.assert_array_equal(result, np.zeros(3)) - - def test_get_weighted_torques_zero_moi(self): - """ - Should set torque to 0 when moment of inertia is zero in a dimension - and torque in that dimension is non-zero. - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - axes_manager = AxesManager() - - bead = MagicMock() - bead.positions = np.array([[1.0, 0.0, 0.0]]) - bead.forces = np.array([[0.0, 1.0, 0.0]]) - bead.dimensions = np.array([10.0, 10.0, 10.0]) - - rot_axes = np.identity(3) - center = np.array([0.0, 0.0, 0.0]) - force_partitioning = 0.5 - moment_of_inertia = np.array([1.0, 1.0, 0.0]) - - with patch.object( - AxesManager, "get_vector", return_value=bead.positions - center - ): - torque = level_manager.get_weighted_torques( - bead=bead, - rot_axes=rot_axes, - center=center, - force_partitioning=force_partitioning, - moment_of_inertia=moment_of_inertia, - axes_manager=axes_manager, - ) - - np.testing.assert_array_equal(torque, np.zeros(3)) - - def test_get_weighted_torques_negative_moi_sets_zero(self): - """ - Negative moment of inertia components should be skipped and set to 0 - even if the corresponding torque component is non-zero. 
- """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - axes_manager = AxesManager() - - bead = MagicMock() - bead.positions = np.array([[1.0, 0.0, 0.0]]) - bead.forces = np.array([[0.0, 1.0, 0.0]]) - bead.dimensions = np.array([10.0, 10.0, 10.0]) - - rot_axes = np.identity(3) - center = np.array([0.0, 0.0, 0.0]) - force_partitioning = 0.5 - moment_of_inertia = np.array([1.0, 1.0, -1.0]) - - with patch.object( - AxesManager, "get_vector", return_value=bead.positions - center - ): - result = level_manager.get_weighted_torques( - bead=bead, - rot_axes=rot_axes, - center=center, - force_partitioning=force_partitioning, - moment_of_inertia=moment_of_inertia, - axes_manager=axes_manager, - ) - - np.testing.assert_array_equal(result, np.zeros(3)) - - def test_create_submatrix_basic_outer_product(self): - """ - Test with known vectors to verify correct outer product. - """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - data_i = np.array([1, 0, 0]) - data_j = np.array([0, 1, 0]) - - expected = np.outer(data_i, data_j) - result = level_manager.create_submatrix(data_i, data_j) - - np.testing.assert_array_equal(result, expected) - - def test_create_submatrix_zero_vectors_returns_zero_matrix(self): - """ - Test that all-zero input vectors return a zero matrix. - """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - data_i = np.zeros(3) - data_j = np.zeros(3) - result = level_manager.create_submatrix(data_i, data_j) - - np.testing.assert_array_equal(result, np.zeros((3, 3))) - - def test_create_submatrix_single_frame(self): - """ - Test that one frame should return the outer product of the single pair of - vectors. 
- """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - vec_i = np.array([1, 2, 3]) - vec_j = np.array([4, 5, 6]) - expected = np.outer(vec_i, vec_j) - - result = level_manager.create_submatrix([vec_i], [vec_j]) - np.testing.assert_array_almost_equal(result, expected) - - def test_create_submatrix_symmetric_result_when_data_equal(self): - """ - Test that if data_i == data_j, the result is symmetric. - """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - data = np.array([1, 2, 3]) - result = level_manager.create_submatrix(data, data) - - self.assertTrue(np.allclose(result, result.T)) # Check symmetry - - def test_create_FTsubmatrix_basic_outer_product(self): - """ - Test that: - - create_FTsubmatrix returns the outer product of two 6D vectors - - shape is (6, 6) - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - data_i = np.array([1, 2, 3, 4, 5, 6], dtype=float) - data_j = np.array([6, 5, 4, 3, 2, 1], dtype=float) - - result = level_manager.create_FTsubmatrix(data_i, data_j) - - expected = np.outer(data_i, data_j) - - assert result.shape == (6, 6) - np.testing.assert_array_equal(result, expected) - - def test_create_FTsubmatrix_zero_input(self): - """ - Test that: - - if either input is zero, the result is a zero matrix - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - data_i = np.zeros(6) - data_j = np.array([1, 2, 3, 4, 5, 6], dtype=float) - - result = level_manager.create_FTsubmatrix(data_i, data_j) - - np.testing.assert_array_equal(result, np.zeros((6, 6))) - - def test_create_FTsubmatrix_transpose_property(self): - """ - Test that: - - outer(i, j).T == outer(j, i) - - required by block-matrix symmetry logic - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - data_i = np.arange(1, 7, 
dtype=float) - data_j = np.arange(7, 13, dtype=float) - - sub_ij = level_manager.create_FTsubmatrix(data_i, data_j) - sub_ji = level_manager.create_FTsubmatrix(data_j, data_i) - - np.testing.assert_array_equal(sub_ij.T, sub_ji) - - def test_create_FTsubmatrix_dtype(self): - """ - Test that: - - output dtype follows NumPy outer-product rules - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - data_i = np.ones(6, dtype=np.float64) - data_j = np.ones(6, dtype=np.float64) - - result = level_manager.create_FTsubmatrix(data_i, data_j) - - assert result.dtype == np.float64 - - def test_build_covariance_matrices_atomic(self): - """ - Test `build_covariance_matrices` to ensure it correctly orchestrates - calls and returns dictionaries with the expected structure. - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - entropy_manager = MagicMock() - - # Fake atom with minimal attributes - atom = MagicMock() - atom.resname = "RES" - atom.resid = 1 - atom.segid = "A" - - fake_mol = MagicMock() - fake_mol.atoms = [atom] - - universe_operations.get_molecule_container = MagicMock(return_value=fake_mol) - - timestep1 = MagicMock() - timestep1.frame = 0 - timestep2 = MagicMock() - timestep2.frame = 1 - - reduced_atom = MagicMock() - reduced_atom.trajectory.__getitem__.return_value = [timestep1, timestep2] - - groups = {"ua": ["mol1", "mol2"]} - levels = {"mol1": ["level1", "level2"], "mol2": ["level1"]} - - level_manager.update_force_torque_matrices = MagicMock() - - force_matrices, torque_matrices, *_ = level_manager.build_covariance_matrices( - entropy_manager=entropy_manager, - reduced_atom=reduced_atom, - levels=levels, - groups=groups, - start=0, - end=2, - step=1, - number_frames=2, - force_partitioning=0.5, - combined_forcetorque=False, - customised_axes=True, - ) - - self.assertIsInstance(force_matrices, dict) - self.assertIsInstance(torque_matrices, dict) - 
self.assertSetEqual(set(force_matrices.keys()), {"ua", "res", "poly"}) - self.assertSetEqual(set(torque_matrices.keys()), {"ua", "res", "poly"}) - - self.assertIsInstance(force_matrices["res"], list) - self.assertIsInstance(force_matrices["poly"], list) - self.assertEqual(len(force_matrices["res"]), len(groups)) - self.assertEqual(len(force_matrices["poly"]), len(groups)) - - self.assertEqual(universe_operations.get_molecule_container.call_count, 4) - self.assertEqual(level_manager.update_force_torque_matrices.call_count, 6) - - def test_update_force_torque_matrices_united_atom(self): - """ - Test that update_force_torque_matrices() correctly initializes force and torque - matrices for the 'united_atom' level. - - Ensures: - - The matrices are initialized for each UA group key. - - Frame counts are incremented correctly. - """ - universe_operations = UniverseOperations() - universe_operations.new_U_select_atom = MagicMock() - - level_manager = LevelManager(universe_operations) - - entropy_manager = MagicMock() - run_manager = MagicMock() - entropy_manager._run_manager = run_manager - - mock_res = MagicMock() - mock_res.trajectory = MagicMock() - mock_res.trajectory.__getitem__.return_value = None - - universe_operations.new_U_select_atom.return_value = mock_res - - mock_residue1 = MagicMock() - mock_residue1.atoms.indices = [0, 2] - mock_residue2 = MagicMock() - mock_residue2.atoms.indices = [3, 5] - - mol = MagicMock() - mol.residues = [mock_residue1, mock_residue2] - - f_mat = np.array([[1]]) - t_mat = np.array([[2]]) - level_manager.get_matrices = MagicMock(return_value=(f_mat, t_mat)) - - force_avg = {"ua": {}, "res": [None], "poly": [None]} - torque_avg = {"ua": {}, "res": [None], "poly": [None]} - forcetorque_avg = {"ua": {}, "res": [None], "poly": [None]} - frame_counts = {"ua": {}, "res": [None], "poly": [None]} - - level_manager.update_force_torque_matrices( - entropy_manager=entropy_manager, - mol=mol, - group_id=0, - level="united_atom", - 
level_list=["residue", "united_atom"], - time_index=0, - num_frames=10, - force_avg=force_avg, - torque_avg=torque_avg, - forcetorque_avg=forcetorque_avg, - frame_counts=frame_counts, - force_partitioning=0.5, - combined_forcetorque=False, - customised_axes=True, - ) - - assert (0, 0) in force_avg["ua"] - assert (0, 1) in force_avg["ua"] - assert (0, 0) in torque_avg["ua"] - assert (0, 1) in torque_avg["ua"] - - np.testing.assert_array_equal(force_avg["ua"][(0, 0)], f_mat) - np.testing.assert_array_equal(force_avg["ua"][(0, 1)], f_mat) - np.testing.assert_array_equal(torque_avg["ua"][(0, 0)], t_mat) - np.testing.assert_array_equal(torque_avg["ua"][(0, 1)], t_mat) - - assert frame_counts["ua"][(0, 0)] == 1 - assert frame_counts["ua"][(0, 1)] == 1 - - assert forcetorque_avg["ua"] == {} - - def test_update_force_torque_matrices_united_atom_increment(self): - """ - Test that update_force_torque_matrices() correctly updates (increments) - existing force and torque matrices for the 'united_atom' level. - - Confirms correct incremental averaging behavior. 
- """ - universe_operations = UniverseOperations() - universe_operations.new_U_select_atom = MagicMock() - - level_manager = LevelManager(universe_operations) - - entropy_manager = MagicMock() - mol = MagicMock() - - residue = MagicMock() - residue.atoms.indices = [0, 1] - mol.residues = [residue] - mol.trajectory = MagicMock() - mol.trajectory.__getitem__.return_value = None - - selected_atoms = MagicMock() - selected_atoms.trajectory = MagicMock() - selected_atoms.trajectory.__getitem__.return_value = None - universe_operations.new_U_select_atom.return_value = selected_atoms - - f_mat_1 = np.array([[1.0]]) - t_mat_1 = np.array([[2.0]]) - - f_mat_2 = np.array([[3.0]]) - t_mat_2 = np.array([[4.0]]) - - level_manager.get_matrices = MagicMock( - side_effect=[(f_mat_1, t_mat_1), (f_mat_2, t_mat_2)] - ) - - force_avg = {"ua": {}, "res": [None], "poly": [None]} - torque_avg = {"ua": {}, "res": [None], "poly": [None]} - forcetorque_avg = {"ua": {}, "res": [None], "poly": [None]} - frame_counts = {"ua": {}, "res": [None], "poly": [None]} - - level_manager.update_force_torque_matrices( - entropy_manager=entropy_manager, - mol=mol, - group_id=0, - level="united_atom", - level_list=["residue", "united_atom"], - time_index=0, - num_frames=10, - force_avg=force_avg, - torque_avg=torque_avg, - forcetorque_avg=forcetorque_avg, - frame_counts=frame_counts, - force_partitioning=0.5, - combined_forcetorque=False, - customised_axes=True, - ) - - # Second update - level_manager.update_force_torque_matrices( - entropy_manager=entropy_manager, - mol=mol, - group_id=0, - level="united_atom", - level_list=["residue", "united_atom"], - time_index=1, - num_frames=10, - force_avg=force_avg, - torque_avg=torque_avg, - forcetorque_avg=forcetorque_avg, - frame_counts=frame_counts, - force_partitioning=0.5, - combined_forcetorque=False, - customised_axes=True, - ) - - expected_force = f_mat_1 + (f_mat_2 - f_mat_1) / 2 - expected_torque = t_mat_1 + (t_mat_2 - t_mat_1) / 2 - - 
np.testing.assert_array_almost_equal(force_avg["ua"][(0, 0)], expected_force) - np.testing.assert_array_almost_equal(torque_avg["ua"][(0, 0)], expected_torque) - assert frame_counts["ua"][(0, 0)] == 2 - - assert forcetorque_avg["ua"] == {} - - def test_update_force_torque_matrices_residue(self): - """ - Test that `update_force_torque_matrices` correctly updates force and torque - matrices for the 'residue' level, assigning whole-molecule matrices and - incrementing frame counts. - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - entropy_manager = MagicMock() - mol = MagicMock() - mol.trajectory.__getitem__.return_value = None - - f_mat_mock = np.array([[1]]) - t_mat_mock = np.array([[2]]) - level_manager.get_matrices = MagicMock(return_value=(f_mat_mock, t_mat_mock)) - - force_avg = {"ua": {}, "res": [None], "poly": [None]} - torque_avg = {"ua": {}, "res": [None], "poly": [None]} - forcetorque_avg = {"ua": {}, "res": [None], "poly": [None]} - frame_counts = {"ua": {}, "res": [None], "poly": [None]} - - level_manager.update_force_torque_matrices( - entropy_manager=entropy_manager, - mol=mol, - group_id=0, - level="residue", - level_list=["residue", "united_atom"], - time_index=3, - num_frames=10, - force_avg=force_avg, - torque_avg=torque_avg, - forcetorque_avg=forcetorque_avg, - frame_counts=frame_counts, - force_partitioning=0.5, - combined_forcetorque=False, - customised_axes=True, - ) - - np.testing.assert_array_equal(force_avg["res"][0], f_mat_mock) - np.testing.assert_array_equal(torque_avg["res"][0], t_mat_mock) - assert frame_counts["res"][0] == 1 - - assert forcetorque_avg["res"][0] is None - - def test_update_force_torque_matrices_incremental_average(self): - """ - Test that `update_force_torque_matrices` correctly applies the incremental - mean formula when updating force and torque matrices over multiple frames. - - Ensures that float precision is maintained and no casting errors occur. 
- """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - entropy_manager = MagicMock() - mol = MagicMock() - mol.trajectory.__getitem__.return_value = None - - # Ensure matrices are float64 to avoid casting errors - f_mat_1 = np.array([[1.0]], dtype=np.float64) - t_mat_1 = np.array([[2.0]], dtype=np.float64) - f_mat_2 = np.array([[3.0]], dtype=np.float64) - t_mat_2 = np.array([[4.0]], dtype=np.float64) - - level_manager.get_matrices = MagicMock( - side_effect=[(f_mat_1, t_mat_1), (f_mat_2, t_mat_2)] - ) - - force_avg = {"ua": {}, "res": [None], "poly": [None]} - torque_avg = {"ua": {}, "res": [None], "poly": [None]} - forcetorque_avg = {"ua": {}, "res": [None], "poly": [None]} - frame_counts = {"ua": {}, "res": [None], "poly": [None]} - - # First update - level_manager.update_force_torque_matrices( - entropy_manager=entropy_manager, - mol=mol, - group_id=0, - level="residue", - level_list=["residue", "united_atom"], - time_index=0, - num_frames=10, - force_avg=force_avg, - torque_avg=torque_avg, - forcetorque_avg=forcetorque_avg, - frame_counts=frame_counts, - force_partitioning=0.5, - combined_forcetorque=False, - customised_axes=True, - ) - - # Second update - level_manager.update_force_torque_matrices( - entropy_manager=entropy_manager, - mol=mol, - group_id=0, - level="residue", - level_list=["residue", "united_atom"], - time_index=1, - num_frames=10, - force_avg=force_avg, - torque_avg=torque_avg, - forcetorque_avg=forcetorque_avg, - frame_counts=frame_counts, - force_partitioning=0.5, - combined_forcetorque=False, - customised_axes=True, - ) - - expected_force = f_mat_1 + (f_mat_2 - f_mat_1) / 2 - expected_torque = t_mat_1 + (t_mat_2 - t_mat_1) / 2 - - np.testing.assert_array_almost_equal(force_avg["res"][0], expected_force) - np.testing.assert_array_almost_equal(torque_avg["res"][0], expected_torque) - - assert frame_counts["res"][0] == 2 - assert forcetorque_avg["res"][0] is None - - def 
test_update_force_torque_matrices_residue_combined_ft_init(self): - """ - Test that: When highest=True and combined_forcetorque=True at residue level, - update_force_torque_matrices should: - - call get_combined_forcetorque_matrices - - store ft_mat into forcetorque_avg['res'][group_id] - - set frame_counts['res'][group_id] = 1 - - NOT touch force_avg/torque_avg - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - mol = MagicMock() - mol.trajectory.__getitem__.return_value = None - - ft_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float) - - level_manager.get_combined_forcetorque_matrices = MagicMock(return_value=ft_mat) - level_manager.get_matrices = MagicMock() - - force_avg = {"ua": {}, "res": [None], "poly": [None]} - torque_avg = {"ua": {}, "res": [None], "poly": [None]} - forcetorque_avg = {"ua": {}, "res": [None], "poly": [None]} - frame_counts = {"ua": {}, "res": [None], "poly": [None]} - - level_manager.update_force_torque_matrices( - entropy_manager=MagicMock(), - mol=mol, - group_id=0, - level="residue", - level_list=["united_atom", "residue"], - time_index=0, - num_frames=10, - force_avg=force_avg, - torque_avg=torque_avg, - forcetorque_avg=forcetorque_avg, - frame_counts=frame_counts, - force_partitioning=0.5, - combined_forcetorque=True, - customised_axes=True, - ) - - level_manager.get_combined_forcetorque_matrices.assert_called_once() - args = level_manager.get_combined_forcetorque_matrices.call_args.args - assert args[0] is mol - assert args[1] == "residue" - assert args[2] is True - assert args[3] is None - assert args[4] == 0.5 - assert args[5] is True - - np.testing.assert_array_equal(forcetorque_avg["res"][0], ft_mat) - assert frame_counts["res"][0] == 1 - - level_manager.get_matrices.assert_not_called() - - assert force_avg["res"][0] is None - assert torque_avg["res"][0] is None - - def test_update_force_torque_matrices_residue_combined_ft_incremental_avg_no_helper( - self, - ): - """ - 
Test that: highest=True and combined_forcetorque=True - - initializes forcetorque_avg on first call - - updates it via incremental mean on second call - - avoids asserting the mutable 'existing' arg passed into the mock - """ - universe_operations = UniverseOperations() - level_manager = LevelManager(universe_operations) - - mol = MagicMock() - mol.trajectory.__getitem__.return_value = None - - ft1 = np.array([[1.0, 1.0], [1.0, 1.0]], dtype=float) - ft2 = np.array([[3.0, 3.0], [3.0, 3.0]], dtype=float) - - level_manager.get_combined_forcetorque_matrices = MagicMock( - side_effect=[ft1, ft2] - ) - level_manager.get_matrices = MagicMock() - - force_avg = {"ua": {}, "res": [None], "poly": [None]} - torque_avg = {"ua": {}, "res": [None], "poly": [None]} - forcetorque_avg = {"ua": {}, "res": [None], "poly": [None]} - frame_counts = {"ua": {}, "res": [None], "poly": [None]} - - level_manager.update_force_torque_matrices( - entropy_manager=MagicMock(), - mol=mol, - group_id=0, - level="residue", - level_list=["united_atom", "residue"], - time_index=0, - num_frames=10, - force_avg=force_avg, - torque_avg=torque_avg, - forcetorque_avg=forcetorque_avg, - frame_counts=frame_counts, - force_partitioning=0.5, - combined_forcetorque=True, - customised_axes=True, - ) - - np.testing.assert_array_equal(forcetorque_avg["res"][0], ft1) - assert frame_counts["res"][0] == 1 - - level_manager.update_force_torque_matrices( - entropy_manager=MagicMock(), - mol=mol, - group_id=0, - level="residue", - level_list=["united_atom", "residue"], - time_index=1, - num_frames=10, - force_avg=force_avg, - torque_avg=torque_avg, - forcetorque_avg=forcetorque_avg, - frame_counts=frame_counts, - force_partitioning=0.5, - combined_forcetorque=True, - customised_axes=True, - ) - - expected = ft1 + (ft2 - ft1) / 2.0 - np.testing.assert_array_almost_equal(forcetorque_avg["res"][0], expected) - assert frame_counts["res"][0] == 2 - - level_manager.get_matrices.assert_not_called() - assert 
level_manager.get_combined_forcetorque_matrices.call_count == 2 - - def test_filter_zero_rows_columns_no_zeros(self): - """ - Test that matrix with no zero-only rows or columns should return unchanged. - """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - matrix = np.array([[1, 2], [3, 4]]) - result = level_manager.filter_zero_rows_columns(matrix) - np.testing.assert_array_equal(result, matrix) - - def test_filter_zero_rows_columns_remove_rows_and_columns(self): - """ - Test that matrix with zero-only rows and columns should return reduced matrix. - """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - matrix = np.array([[0, 0, 0], [0, 5, 0], [0, 0, 0]]) - expected = np.array([[5]]) - result = level_manager.filter_zero_rows_columns(matrix) - np.testing.assert_array_equal(result, expected) - - def test_filter_zero_rows_columns_all_zeros(self): - """ - Test that matrix with all zeros should return an empty matrix. - """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - matrix = np.zeros((3, 3)) - result = level_manager.filter_zero_rows_columns(matrix) - self.assertEqual(result.size, 0) - self.assertEqual(result.shape, (0, 0)) - - def test_filter_zero_rows_columns_partial_zero_removal(self): - """ - Matrix with zeros in specific rows/columns should remove only those. 
- """ - universe_operations = UniverseOperations() - - level_manager = LevelManager(universe_operations) - - matrix = np.array([[0, 0, 0], [1, 2, 3], [0, 0, 0]]) - expected = np.array([[1, 2, 3]]) - result = level_manager.filter_zero_rows_columns(matrix) - np.testing.assert_array_equal(result, expected) diff --git a/tests/test_CodeEntropy/test_logging_config.py b/tests/test_CodeEntropy/test_logging_config.py deleted file mode 100644 index 7a07b2aa..00000000 --- a/tests/test_CodeEntropy/test_logging_config.py +++ /dev/null @@ -1,103 +0,0 @@ -import logging -import os -import unittest -from unittest.mock import MagicMock - -from CodeEntropy.config.logging_config import LoggingConfig -from tests.test_CodeEntropy.test_base import BaseTestCase - - -class TestLoggingConfig(BaseTestCase): - """ - Unit tests for LoggingConfig. - """ - - def setUp(self): - super().setUp() - self.log_dir = self.logs_path - self.logging_config = LoggingConfig(folder=self.test_dir) - - self.mock_text = "Test console output" - self.logging_config.console.export_text = MagicMock(return_value=self.mock_text) - - def test_log_directory_created(self): - """Check if the log directory is created upon init""" - self.assertTrue(os.path.exists(self.log_dir)) - self.assertTrue(os.path.isdir(self.log_dir)) - - def test_setup_logging_returns_logger(self): - """Ensure setup_logging returns a logger instance""" - logger = self.logging_config.setup_logging() - self.assertIsInstance(logger, logging.Logger) - - def test_expected_log_files_created(self): - """Ensure log file paths are configured correctly in the logging config""" - self.logging_config.setup_logging() - - # Map expected filenames to the corresponding handler keys in LoggingConfig - expected_handlers = { - "program.log": "main", - "program.err": "error", - "program.com": "command", - "mdanalysis.log": "mdanalysis", - } - - for filename, handler_key in expected_handlers.items(): - expected_path = os.path.join(self.logging_config.log_dir, filename) 
- actual_path = self.logging_config.handlers[handler_key].baseFilename - self.assertEqual(actual_path, expected_path) - - def test_update_logging_level(self): - """Ensure logging levels are updated correctly""" - self.logging_config.setup_logging() - - # Update to DEBUG - self.logging_config.update_logging_level(logging.DEBUG) - root_logger = logging.getLogger() - self.assertEqual(root_logger.level, logging.DEBUG) - - # Check that at least one handler is DEBUG - handler_levels = [h.level for h in root_logger.handlers] - self.assertIn(logging.DEBUG, handler_levels) - - # Update to INFO - self.logging_config.update_logging_level(logging.INFO) - self.assertEqual(root_logger.level, logging.INFO) - - def test_mdanalysis_and_command_loggers_exist(self): - """Ensure specialized loggers are set up with correct configuration""" - log_level = logging.DEBUG - self.logging_config = LoggingConfig(folder=self.test_dir, level=log_level) - self.logging_config.setup_logging() - - mda_logger = logging.getLogger("MDAnalysis") - cmd_logger = logging.getLogger("commands") - - self.assertEqual(mda_logger.level, log_level) - self.assertEqual(cmd_logger.level, logging.INFO) - self.assertFalse(mda_logger.propagate) - self.assertFalse(cmd_logger.propagate) - - def test_save_console_log_writes_file(self): - """ - Test that save_console_log creates a log file in the expected location - and writes the console's recorded output correctly. 
- """ - filename = "test_log.txt" - self.logging_config.save_console_log(filename) - - output_path = os.path.join(self.test_dir, "logs", filename) - # Check file exists - self.assertTrue(os.path.exists(output_path)) - - # Read content and check it matches mocked export_text output - with open(output_path, "r", encoding="utf-8") as f: - content = f.read() - self.assertEqual(content, self.mock_text) - - # Ensure export_text was called once - self.logging_config.console.export_text.assert_called_once() - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_CodeEntropy/test_main.py b/tests/test_CodeEntropy/test_main.py deleted file mode 100644 index 76feb528..00000000 --- a/tests/test_CodeEntropy/test_main.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -import shutil -import subprocess -import sys -import unittest -from unittest.mock import MagicMock, patch - -from CodeEntropy.main import main -from tests.test_CodeEntropy.test_base import BaseTestCase - - -class TestMain(BaseTestCase): - """ - Unit tests for the main functionality of CodeEntropy. - """ - - def setUp(self): - super().setUp() - self.code_entropy = main - - @patch("CodeEntropy.main.sys.exit") - @patch("CodeEntropy.main.RunManager") - def test_main_successful_run(self, mock_RunManager, mock_exit): - """ - Test that main runs successfully and does not call sys.exit. 
- """ - # Mock RunManager's methods to simulate successful execution - mock_run_manager_instance = MagicMock() - mock_RunManager.return_value = mock_run_manager_instance - - # Simulate that RunManager.create_job_folder returns a folder - mock_RunManager.create_job_folder.return_value = "mock_folder/job001" - - # Simulate the successful completion of the run_entropy_workflow method - mock_run_manager_instance.run_entropy_workflow.return_value = None - - # Run the main function - main() - - # Verify that sys.exit was not called - mock_exit.assert_not_called() - - # Verify that RunManager's methods were called correctly - mock_RunManager.create_job_folder.assert_called_once() - mock_run_manager_instance.run_entropy_workflow.assert_called_once() - - @patch("CodeEntropy.main.sys.exit") - @patch("CodeEntropy.main.RunManager") - @patch("CodeEntropy.main.logger") - def test_main_exception_triggers_exit( - self, mock_logger, mock_RunManager, mock_exit - ): - """ - Test that main logs a critical error and exits if RunManager - raises an exception. 
- """ - # Simulate an exception being raised in run_entropy_workflow - mock_run_manager_instance = MagicMock() - mock_RunManager.return_value = mock_run_manager_instance - - # Simulate that RunManager.create_job_folder returns a folder - mock_RunManager.create_job_folder.return_value = "mock_folder/job001" - - # Simulate an exception in the run_entropy_workflow method - mock_run_manager_instance.run_entropy_workflow.side_effect = Exception( - "Test exception" - ) - - # Run the main function and mock sys.exit to ensure it gets called - main() - - # Ensure sys.exit(1) was called due to the exception - mock_exit.assert_called_once_with(1) - - # Ensure that the logger logged the critical error with exception details - mock_logger.critical.assert_called_once_with( - "Fatal error during entropy calculation: Test exception", exc_info=True - ) - - def test_main_entry_point_runs(self): - """ - Test that the CLI entry point (main.py) runs successfully with minimal required - arguments. - """ - # Prepare input files - data_dir = os.path.abspath( - os.path.join(os.path.dirname(__file__), "..", "data") - ) - tpr_path = shutil.copy(os.path.join(data_dir, "md_A4_dna.tpr"), self.test_dir) - trr_path = shutil.copy( - os.path.join(data_dir, "md_A4_dna_xf.trr"), self.test_dir - ) - - config_path = os.path.join(self.test_dir, "config.yaml") - with open(config_path, "w") as f: - f.write("run1:\n" " end: 1\n" " selection_string: resname DA\n") - - citation_path = os.path.join(self.test_dir, "CITATION.cff") - with open(citation_path, "w") as f: - f.write("\n") - - result = subprocess.run( - [ - sys.executable, - "-X", - "utf8", - "-m", - "CodeEntropy.main", - "--top_traj_file", - tpr_path, - trr_path, - ], - cwd=self.test_dir, - capture_output=True, - encoding="utf-8", - ) - - self.assertEqual(result.returncode, 0) - - # Check for job folder and output file - job_dir = os.path.join(self.test_dir, "job001") - output_file = os.path.join(job_dir, "output_file.json") - - 
self.assertTrue(os.path.exists(job_dir)) - self.assertTrue(os.path.exists(output_file)) - - with open(output_file) as f: - content = f.read() - print(content) - self.assertIn("DA", content) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_CodeEntropy/test_mda_universe_operations.py b/tests/test_CodeEntropy/test_mda_universe_operations.py deleted file mode 100644 index 46e68a76..00000000 --- a/tests/test_CodeEntropy/test_mda_universe_operations.py +++ /dev/null @@ -1,304 +0,0 @@ -import logging -import os -from unittest.mock import MagicMock, patch - -import MDAnalysis as mda -import numpy as np - -import tests.data as data -from CodeEntropy.mda_universe_operations import UniverseOperations -from tests.test_CodeEntropy.test_base import BaseTestCase - - -class TestUniverseOperations(BaseTestCase): - """ - Unit tests for UniverseOperations. - """ - - def setUp(self): - super().setUp() - self.test_data_dir = os.path.dirname(data.__file__) - - # Disable MDAnalysis and commands file logging entirely - logging.getLogger("MDAnalysis").handlers = [logging.NullHandler()] - logging.getLogger("commands").handlers = [logging.NullHandler()] - - @patch("CodeEntropy.mda_universe_operations.AnalysisFromFunction") - @patch("CodeEntropy.mda_universe_operations.mda.Merge") - def test_new_U_select_frame(self, MockMerge, MockAnalysisFromFunction): - """ - Unit test for UniverseOperations.new_U_select_frame(). 
- """ - # Mock Universe and its components - mock_universe = MagicMock() - mock_trajectory = MagicMock() - mock_trajectory.__len__.return_value = 10 - mock_universe.trajectory = mock_trajectory - - mock_select_atoms = MagicMock() - mock_universe.select_atoms.return_value = mock_select_atoms - - # Mock AnalysisFromFunction results for coordinates, forces, and dimensions - coords = np.random.rand(10, 100, 3) - forces = np.random.rand(10, 100, 3) - dims = np.random.rand(10, 6) - - mock_coords_analysis = MagicMock() - mock_coords_analysis.run.return_value.results = {"timeseries": coords} - - mock_forces_analysis = MagicMock() - mock_forces_analysis.run.return_value.results = {"timeseries": forces} - - mock_dims_analysis = MagicMock() - mock_dims_analysis.run.return_value.results = {"timeseries": dims} - - MockAnalysisFromFunction.side_effect = [ - mock_coords_analysis, - mock_forces_analysis, - mock_dims_analysis, - ] - - # Mock merge operation - mock_merged_universe = MagicMock() - MockMerge.return_value = mock_merged_universe - - ops = UniverseOperations() - result = ops.new_U_select_frame(mock_universe) - - # Basic behavior checks - mock_universe.select_atoms.assert_called_once_with("all", updating=True) - - # AnalysisFromFunction called 3 times (coords, forces, dimensions) - assert MockAnalysisFromFunction.call_count == 3 - mock_coords_analysis.run.assert_called_once() - mock_forces_analysis.run.assert_called_once() - mock_dims_analysis.run.assert_called_once() - - # Merge called with selected AtomGroup - MockMerge.assert_called_once_with(mock_select_atoms) - - assert result == mock_merged_universe - - @patch("CodeEntropy.mda_universe_operations.AnalysisFromFunction") - @patch("CodeEntropy.mda_universe_operations.mda.Merge") - def test_new_U_select_atom(self, MockMerge, MockAnalysisFromFunction): - """ - Unit test for UniverseOperations.new_U_select_atom(). 
- - Ensures that: - - The Universe is queried with the correct selection string - - Coordinates, forces, and dimensions are extracted via AnalysisFromFunction - - mda.Merge receives the AtomGroup from select_atoms - - The new universe is populated with the expected data via load_new() - - The returned universe is the object created by Merge - """ - # Mock Universe and its components - mock_universe = MagicMock() - mock_select_atoms = MagicMock() - mock_universe.select_atoms.return_value = mock_select_atoms - - # Mock AnalysisFromFunction results for coordinates, forces, and dimensions - coords = np.random.rand(10, 100, 3) - forces = np.random.rand(10, 100, 3) - dims = np.random.rand(10, 6) - - mock_coords_analysis = MagicMock() - mock_coords_analysis.run.return_value.results = {"timeseries": coords} - - mock_forces_analysis = MagicMock() - mock_forces_analysis.run.return_value.results = {"timeseries": forces} - - mock_dims_analysis = MagicMock() - mock_dims_analysis.run.return_value.results = {"timeseries": dims} - - MockAnalysisFromFunction.side_effect = [ - mock_coords_analysis, - mock_forces_analysis, - mock_dims_analysis, - ] - - # Mock the merge operation - mock_merged_universe = MagicMock() - MockMerge.return_value = mock_merged_universe - - ops = UniverseOperations() - - result = ops.new_U_select_atom(mock_universe, select_string="resid 1-10") - mock_universe.select_atoms.assert_called_once_with("resid 1-10", updating=True) - - # AnalysisFromFunction called for coords, forces, dimensions - assert MockAnalysisFromFunction.call_count == 3 - mock_coords_analysis.run.assert_called_once() - mock_forces_analysis.run.assert_called_once() - mock_dims_analysis.run.assert_called_once() - - # Merge called with the selected AtomGroup - MockMerge.assert_called_once_with(mock_select_atoms) - - # Returned universe should be the merged universe - assert result == mock_merged_universe - - def test_get_molecule_container(self): - """ - Integration test for 
UniverseOperations.get_molecule_container(). - - Uses a real MDAnalysis Universe loaded from test trajectory files. - Confirms that: - - The correct fragment for a given molecule index is selected - - The returned reduced Universe contains exactly the expected atom indices - - The number of atoms matches the original fragment - """ - tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") - trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") - - u = mda.Universe(tprfile, trrfile) - - ops = UniverseOperations() - - molecule_id = 0 - - fragment = u.atoms.fragments[molecule_id] - expected_indices = fragment.indices - - mol_u = ops.get_molecule_container(u, molecule_id) - - selected_indices = mol_u.atoms.indices - - self.assertSetEqual(set(selected_indices), set(expected_indices)) - self.assertEqual(len(selected_indices), len(expected_indices)) - - @patch("CodeEntropy.mda_universe_operations.AnalysisFromFunction") - @patch("CodeEntropy.mda_universe_operations.mda.Merge") - @patch("CodeEntropy.mda_universe_operations.mda.Universe") - def test_merge_forces(self, MockUniverse, MockMerge, MockAnalysisFromFunction): - """ - Unit test for UniverseOperations.merge_forces(). 
- """ - # Two Universes created: coords and forces - mock_u_coords = MagicMock() - mock_u_force = MagicMock() - MockUniverse.side_effect = [mock_u_coords, mock_u_force] - - # Each universe returns an AtomGroup from select_atoms("all") - mock_ag_coords = MagicMock() - mock_ag_force = MagicMock() - mock_u_coords.select_atoms.return_value = mock_ag_coords - mock_u_force.select_atoms.return_value = mock_ag_force - - coords = np.random.rand(5, 10, 3) - forces = np.random.rand(5, 10, 3) - dims = np.random.rand(5, 6) - - mock_coords_analysis = MagicMock() - mock_coords_analysis.run.return_value.results = {"timeseries": coords} - - mock_forces_analysis = MagicMock() - mock_forces_analysis.run.return_value.results = {"timeseries": forces} - - mock_dims_analysis = MagicMock() - mock_dims_analysis.run.return_value.results = {"timeseries": dims} - - MockAnalysisFromFunction.side_effect = [ - mock_coords_analysis, - mock_forces_analysis, - mock_dims_analysis, - ] - - mock_merged = MagicMock() - MockMerge.return_value = mock_merged - - ops = UniverseOperations() - result = ops.merge_forces( - tprfile="topol.tpr", - trrfile="traj.trr", - forcefile="forces.trr", - fileformat=None, - kcal=False, - ) - - # Universe construction - assert MockUniverse.call_count == 2 - - # Selection - mock_u_coords.select_atoms.assert_called_once_with("all") - mock_u_force.select_atoms.assert_called_once_with("all") - - # AnalysisFromFunction usage - assert MockAnalysisFromFunction.call_count == 3 - mock_coords_analysis.run.assert_called_once() - mock_forces_analysis.run.assert_called_once() - mock_dims_analysis.run.assert_called_once() - - # Merge called with coordinate AtomGroup - MockMerge.assert_called_once_with(mock_ag_coords) - - # Returned object is merged universe - assert result == mock_merged - - @patch("CodeEntropy.mda_universe_operations.AnalysisFromFunction") - @patch("CodeEntropy.mda_universe_operations.mda.Merge") - @patch("CodeEntropy.mda_universe_operations.mda.Universe") - def 
test_merge_forces_kcal_conversion( - self, MockUniverse, MockMerge, MockAnalysisFromFunction - ): - """ - Unit test for UniverseOperations.merge_forces() covering the kcal→kJ - conversion branch. - """ - mock_u_coords = MagicMock() - mock_u_force = MagicMock() - MockUniverse.side_effect = [mock_u_coords, mock_u_force] - - mock_ag_coords = MagicMock() - mock_ag_force = MagicMock() - mock_u_coords.select_atoms.return_value = mock_ag_coords - mock_u_force.select_atoms.return_value = mock_ag_force - - coords = np.ones((2, 3, 3)) - - original_forces = np.ones((2, 3, 3)) - mock_forces_array = original_forces.copy() - - dims = np.ones((2, 6)) - - # Mock AnalysisFromFunction return values - mock_coords_analysis = MagicMock() - mock_coords_analysis.run.return_value.results = {"timeseries": coords} - - mock_forces_analysis = MagicMock() - mock_forces_analysis.run.return_value.results = { - "timeseries": mock_forces_array - } - - mock_dims_analysis = MagicMock() - mock_dims_analysis.run.return_value.results = {"timeseries": dims} - - MockAnalysisFromFunction.side_effect = [ - mock_coords_analysis, - mock_forces_analysis, - mock_dims_analysis, - ] - - mock_merged = MagicMock() - MockMerge.return_value = mock_merged - - ops = UniverseOperations() - result = ops.merge_forces("t.tpr", "c.trr", "f.trr", kcal=True) - - # select_atoms("all") (your code uses no updating=True) - mock_u_coords.select_atoms.assert_called_once_with("all") - mock_u_force.select_atoms.assert_called_once_with("all") - - # AnalysisFromFunction called three times - assert MockAnalysisFromFunction.call_count == 3 - - # Forces are multiplied exactly once by 4.184 when kcal=True - np.testing.assert_allclose( - mock_forces_array, original_forces * 4.184, rtol=0, atol=0 - ) - - # Merge called with coordinate AtomGroup - MockMerge.assert_called_once_with(mock_ag_coords) - - # Returned universe is the merged universe - assert result == mock_merged diff --git a/tests/test_CodeEntropy/test_run.py 
b/tests/test_CodeEntropy/test_run.py deleted file mode 100644 index 0174df13..00000000 --- a/tests/test_CodeEntropy/test_run.py +++ /dev/null @@ -1,632 +0,0 @@ -import os -import unittest -from io import StringIO -from unittest.mock import MagicMock, mock_open, patch - -import requests -import yaml -from rich.console import Console - -from CodeEntropy.run import RunManager -from tests.test_CodeEntropy.test_base import BaseTestCase - - -class TestRunManager(BaseTestCase): - """ - Unit tests for the RunManager class. These tests verify the - correct behavior of run manager. - """ - - def setUp(self): - super().setUp() - self.config_file = os.path.join(self.test_dir, "CITATION.cff") - # Create mock config - with patch("builtins.open", new_callable=mock_open) as mock_file: - self.setup_citation_file(mock_file) - with open(self.config_file, "w") as f: - f.write(mock_file.return_value.read()) - self.run_manager = RunManager(folder=self.test_dir) - - def setup_citation_file(self, mock_file): - """ - Mock the contents of the CITATION.cff file. - """ - citation_content = """\ - authors: - - given-names: Alice - family-names: Smith - """ - - mock_file.return_value = mock_open(read_data=citation_content).return_value - - @patch("os.makedirs") - @patch("os.listdir") - def test_create_job_folder_empty_directory(self, mock_listdir, mock_makedirs): - """ - Test that 'job001' is created when the directory is initially empty. - """ - mock_listdir.return_value = [] - new_folder_path = RunManager.create_job_folder() - expected_path = os.path.join(self.test_dir, "job001") - self.assertEqual( - os.path.realpath(new_folder_path), os.path.realpath(expected_path) - ) - - @patch("os.makedirs") - @patch("os.listdir") - def test_create_job_folder_with_existing_folders(self, mock_listdir, mock_makedirs): - """ - Test that the next sequential job folder (e.g., 'job004') is created when - existing folders 'job001', 'job002', and 'job003' are present. 
- """ - mock_listdir.return_value = ["job001", "job002", "job003"] - new_folder_path = RunManager.create_job_folder() - expected_path = os.path.join(self.test_dir, "job004") - - # Normalize paths cross-platform - normalized_new = os.path.normcase( - os.path.realpath(os.path.normpath(new_folder_path)) - ) - normalized_expected = os.path.normcase( - os.path.realpath(os.path.normpath(expected_path)) - ) - - self.assertEqual(normalized_new, normalized_expected) - - called_args, called_kwargs = mock_makedirs.call_args - normalized_called = os.path.normcase( - os.path.realpath(os.path.normpath(called_args[0])) - ) - self.assertEqual(normalized_called, normalized_expected) - self.assertTrue(called_kwargs.get("exist_ok", False)) - - @patch("os.makedirs") - @patch("os.listdir") - def test_create_job_folder_with_non_matching_folders( - self, mock_listdir, mock_makedirs - ): - """ - Test that 'job001' is created when the directory contains only non-job-related - folders. - """ - mock_listdir.return_value = ["folderA", "another_one"] - - new_folder_path = RunManager.create_job_folder() - expected_path = os.path.join(self.test_dir, "job001") - - normalized_new = os.path.normcase( - os.path.realpath(os.path.normpath(new_folder_path)) - ) - normalized_expected = os.path.normcase( - os.path.realpath(os.path.normpath(expected_path)) - ) - self.assertEqual(normalized_new, normalized_expected) - - called_args, called_kwargs = mock_makedirs.call_args - normalized_called = os.path.normcase( - os.path.realpath(os.path.normpath(called_args[0])) - ) - self.assertEqual(normalized_called, normalized_expected) - self.assertTrue(called_kwargs.get("exist_ok", False)) - - @patch("os.makedirs") - @patch("os.listdir") - def test_create_job_folder_mixed_folder_names(self, mock_listdir, mock_makedirs): - """ - Test that the correct next job folder (e.g., 'job003') is created when both - job and non-job folders exist in the directory. 
- """ - mock_listdir.return_value = ["job001", "abc", "job002", "random"] - new_folder_path = RunManager.create_job_folder() - expected_path = os.path.join(self.test_dir, "job003") - - normalized_new = os.path.normcase( - os.path.realpath(os.path.normpath(new_folder_path)) - ) - normalized_expected = os.path.normcase( - os.path.realpath(os.path.normpath(expected_path)) - ) - self.assertEqual(normalized_new, normalized_expected) - - called_args, called_kwargs = mock_makedirs.call_args - normalized_called = os.path.normcase( - os.path.realpath(os.path.normpath(called_args[0])) - ) - self.assertEqual(normalized_called, normalized_expected) - self.assertTrue(called_kwargs.get("exist_ok", False)) - - @patch("os.makedirs") - @patch("os.listdir") - def test_create_job_folder_with_invalid_job_suffix( - self, mock_listdir, mock_makedirs - ): - """ - Test that invalid job folder names like 'jobABC' are ignored when determining - the next job number. - """ - # Simulate existing folders, one of which is invalid - mock_listdir.return_value = ["job001", "jobABC", "job002"] - - new_folder_path = RunManager.create_job_folder() - expected_path = os.path.join(self.test_dir, "job003") - - normalized_new = os.path.normcase( - os.path.realpath(os.path.normpath(new_folder_path)) - ) - normalized_expected = os.path.normcase( - os.path.realpath(os.path.normpath(expected_path)) - ) - self.assertEqual(normalized_new, normalized_expected) - - called_args, called_kwargs = mock_makedirs.call_args - normalized_called = os.path.normcase( - os.path.realpath(os.path.normpath(called_args[0])) - ) - self.assertEqual(normalized_called, normalized_expected) - self.assertTrue(called_kwargs.get("exist_ok", False)) - - @patch("requests.get") - def test_load_citation_data_success(self, mock_get): - """Should return parsed dict when CITATION.cff loads successfully.""" - mock_yaml = """ - authors: - - given-names: Alice - family-names: Smith - title: TestProject - version: 1.0 - date-released: 2025-01-01 - 
""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.text = mock_yaml - mock_get.return_value = mock_response - - instance = RunManager("dummy") - data = instance.load_citation_data() - - self.assertIsInstance(data, dict) - self.assertEqual(data["title"], "TestProject") - self.assertEqual(data["authors"][0]["given-names"], "Alice") - - @patch("requests.get") - def test_load_citation_data_network_error(self, mock_get): - """Should return None if network request fails.""" - mock_get.side_effect = requests.exceptions.ConnectionError("Network down") - - instance = RunManager("dummy") - data = instance.load_citation_data() - - self.assertIsNone(data) - - @patch("requests.get") - def test_load_citation_data_http_error(self, mock_get): - """Should return None if HTTP response is non-200.""" - mock_response = MagicMock() - mock_response.status_code = 404 - mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError() - mock_get.return_value = mock_response - - instance = RunManager("dummy") - data = instance.load_citation_data() - - self.assertIsNone(data) - - @patch("requests.get") - def test_load_citation_data_invalid_yaml(self, mock_get): - """Should raise YAML error if file content is invalid YAML.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.text = "invalid: [oops" - mock_get.return_value = mock_response - - instance = RunManager("dummy") - - with self.assertRaises(yaml.YAMLError): - instance.load_citation_data() - - @patch.object(RunManager, "load_citation_data") - def test_show_splash_with_citation(self, mock_load): - """Should render full splash screen when citation data is present.""" - mock_load.return_value = { - "title": "TestProject", - "version": "1.0", - "date-released": "2025-01-01", - "url": "https://example.com", - "abstract": "This is a test abstract.", - "authors": [ - {"given-names": "Alice", "family-names": "Smith", "affiliation": "Uni"} - ], - } - - buf = StringIO() - 
test_console = Console(file=buf, force_terminal=False, width=80) - - instance = RunManager("dummy") - with patch("CodeEntropy.run.console", test_console): - instance.show_splash() - - output = buf.getvalue() - - self.assertIn("Version 1.0", output) - self.assertIn("2025-01-01", output) - self.assertIn("https://example.com", output) - self.assertIn("This is a test abstract.", output) - self.assertIn("Alice Smith", output) - - @patch.object(RunManager, "load_citation_data", return_value=None) - def test_show_splash_without_citation(self, mock_load): - """Should render minimal splash screen when no citation data.""" - buf = StringIO() - test_console = Console(file=buf, force_terminal=False, width=80) - - instance = RunManager("dummy") - with patch("CodeEntropy.run.console", test_console): - instance.show_splash() - - output = buf.getvalue() - - self.assertNotIn("Version", output) - self.assertNotIn("Contributors", output) - self.assertIn("Welcome to CodeEntropy", output) - - @patch.object(RunManager, "load_citation_data") - def test_show_splash_missing_fields(self, mock_load): - """Should gracefully handle missing optional fields in citation data.""" - mock_load.return_value = { - "title": "PartialProject", - # no version, no date, no authors, no abstract - } - - buf = StringIO() - test_console = Console(file=buf, force_terminal=False, width=80) - - instance = RunManager("dummy") - with patch("CodeEntropy.run.console", test_console): - instance.show_splash() - - output = buf.getvalue() - - self.assertIn("Version ?", output) - self.assertIn("No description available.", output) - - def test_run_entropy_workflow(self): - """ - Test the run_entropy_workflow method to ensure it initializes and executes - correctly with mocked dependencies. 
- """ - run_manager = RunManager("mock_folder/job001") - run_manager._logging_config = MagicMock() - run_manager._config_manager = MagicMock() - run_manager.load_citation_data = MagicMock() - run_manager._data_logger = MagicMock() - run_manager.folder = self.test_dir - - mock_logger = MagicMock() - run_manager._logging_config.setup_logging.return_value = mock_logger - - run_manager._config_manager.load_config.return_value = { - "test_run": { - "top_traj_file": ["/path/to/tpr", "/path/to/trr"], - "force_file": None, - "file_format": None, - "selection_string": "all", - "output_file": "output.json", - "verbose": True, - } - } - - run_manager.load_citation_data.return_value = { - "cff-version": "1.2.0", - "title": "CodeEntropy", - "message": ( - "If you use this software, please cite it using the " - "metadata from this file." - ), - "type": "software", - "authors": [ - { - "given-names": "Forename", - "family-names": "Sirname", - "email": "test@email.ac.uk", - } - ], - } - - mock_args = MagicMock() - mock_args.output_file = "output.json" - mock_args.verbose = True - mock_args.top_traj_file = ["/path/to/tpr", "/path/to/trr"] - mock_args.force_file = None - mock_args.file_format = None - mock_args.selection_string = "all" - parser = run_manager._config_manager.setup_argparse.return_value - parser.parse_known_args.return_value = (mock_args, []) - - run_manager._config_manager.merge_configs.return_value = mock_args - - mock_entropy_manager = MagicMock() - with ( - unittest.mock.patch( - "CodeEntropy.run.EntropyManager", return_value=mock_entropy_manager - ), - unittest.mock.patch("CodeEntropy.run.mda.Universe") as mock_universe, - ): - - run_manager.run_entropy_workflow() - - mock_universe.assert_called_once_with( - "/path/to/tpr", ["/path/to/trr"], format=None - ) - mock_entropy_manager.execute.assert_called_once() - - def test_run_entropy_workflow_with_forcefile(self): - """ - Test the else-branch in run_entropy_workflow where forcefile is not None. 
- """ - run_manager = RunManager("mock_folder/job001") - run_manager._logging_config = MagicMock() - run_manager._config_manager = MagicMock() - run_manager.load_citation_data = MagicMock() - run_manager.show_splash = MagicMock() - run_manager._data_logger = MagicMock() - run_manager.folder = self.test_dir - - # Logger mock - mock_logger = MagicMock() - run_manager._logging_config.setup_logging.return_value = mock_logger - - # Config contains force_file - run_manager._config_manager.load_config.return_value = { - "test_run": { - "top_traj_file": ["/path/to/tpr", "/path/to/trr"], - "force_file": "/path/to/forces", - "file_format": "gro", - "kcal_force_units": "kcal", - "selection_string": "all", - "output_file": "output.json", - "verbose": False, - } - } - - # Parse args mock - mock_args = MagicMock() - mock_args.output_file = "output.json" - mock_args.verbose = False - mock_args.top_traj_file = ["/path/to/tpr", "/path/to/trr"] - mock_args.force_file = "/path/to/forces" - mock_args.file_format = "gro" - mock_args.kcal_force_units = "kcal" - mock_args.selection_string = "all" - - parser = run_manager._config_manager.setup_argparse.return_value - parser.parse_known_args.return_value = (mock_args, []) - run_manager._config_manager.merge_configs.return_value = mock_args - - # Mock UniverseOperations.merge_forces - with ( - unittest.mock.patch( - "CodeEntropy.run.EntropyManager", return_value=MagicMock() - ) as Entropy_patch, - unittest.mock.patch("CodeEntropy.run.UniverseOperations") as UOps_patch, - unittest.mock.patch("CodeEntropy.run.mda.Universe") as mock_universe, - ): - mock_universe_ops = UOps_patch.return_value - mock_universe_ops.merge_forces.return_value = MagicMock() - - run_manager.run_entropy_workflow() - - # Ensure merge_forces is used - mock_universe_ops.merge_forces.assert_called_once_with( - "/path/to/tpr", - ["/path/to/trr"], - "/path/to/forces", - "gro", - "kcal", - ) - - mock_universe.assert_not_called() - - 
Entropy_patch.return_value.execute.assert_called_once() - - def test_run_configuration_warning(self): - """ - Test that a warning is logged when the config entry is not a dictionary. - """ - run_manager = RunManager("mock_folder/job001") - run_manager._logging_config = MagicMock() - run_manager._config_manager = MagicMock() - run_manager.load_citation_data = MagicMock() - run_manager._data_logger = MagicMock() - run_manager.folder = self.test_dir - - mock_logger = MagicMock() - run_manager._logging_config.setup_logging.return_value = mock_logger - - run_manager._config_manager.load_config.return_value = { - "invalid_run": "this_should_be_a_dict" - } - - run_manager.load_citation_data.return_value = { - "cff-version": "1.2.0", - "title": "CodeEntropy", - "message": ( - "If you use this software, please cite it using the " - "metadata from this file." - ), - "type": "software", - "authors": [ - { - "given-names": "Forename", - "family-names": "Sirname", - "email": "test@email.ac.uk", - } - ], - } - - mock_args = MagicMock() - mock_args.output_file = "output.json" - mock_args.verbose = False - - parser = run_manager._config_manager.setup_argparse.return_value - parser.parse_known_args.return_value = (mock_args, []) - run_manager._config_manager.merge_configs.return_value = mock_args - - run_manager.run_entropy_workflow() - - mock_logger.warning.assert_called_with( - "Run configuration for invalid_run is not a dictionary." - ) - - def test_run_entropy_workflow_missing_traj_file(self): - """ - Test that a ValueError is raised when 'top_traj_file' is missing. 
- """ - run_manager = RunManager("mock_folder/job001") - run_manager._logging_config = MagicMock() - run_manager._config_manager = MagicMock() - run_manager.load_citation_data = MagicMock() - run_manager._data_logger = MagicMock() - run_manager.folder = self.test_dir - - mock_logger = MagicMock() - run_manager._logging_config.setup_logging.return_value = mock_logger - - run_manager._config_manager.load_config.return_value = { - "test_run": { - "top_traj_file": None, - "output_file": "output.json", - "verbose": False, - } - } - - run_manager.load_citation_data.return_value = { - "cff-version": "1.2.0", - "title": "CodeEntropy", - "message": ( - "If you use this software, please cite it using the " - "metadata from this file." - ), - "type": "software", - "authors": [ - { - "given-names": "Forename", - "family-names": "Sirname", - "email": "test@email.ac.uk", - } - ], - } - - mock_args = MagicMock() - mock_args.output_file = "output.json" - mock_args.verbose = False - mock_args.top_traj_file = None - mock_args.selection_string = None - - parser = run_manager._config_manager.setup_argparse.return_value - parser.parse_known_args.return_value = (mock_args, []) - run_manager._config_manager.merge_configs.return_value = mock_args - - with self.assertRaisesRegex(ValueError, "Missing 'top_traj_file' argument."): - run_manager.run_entropy_workflow() - - def test_run_entropy_workflow_missing_selection_string(self): - """ - Test that a ValueError is raised when 'selection_string' is missing. 
- """ - run_manager = RunManager("mock_folder/job001") - run_manager._logging_config = MagicMock() - run_manager._config_manager = MagicMock() - run_manager.load_citation_data = MagicMock() - run_manager._data_logger = MagicMock() - run_manager.folder = self.test_dir - - mock_logger = MagicMock() - run_manager._logging_config.setup_logging.return_value = mock_logger - - run_manager._config_manager.load_config.return_value = { - "test_run": { - "top_traj_file": ["/path/to/tpr", "/path/to/trr"], - "output_file": "output.json", - "verbose": False, - } - } - - run_manager.load_citation_data.return_value = { - "cff-version": "1.2.0", - "title": "CodeEntropy", - "message": ( - "If you use this software, please cite it using the " - "metadata from this file." - ), - "type": "software", - "authors": [ - { - "given-names": "Forename", - "family-names": "Sirname", - "email": "test@email.ac.uk", - } - ], - } - - mock_args = MagicMock() - mock_args.output_file = "output.json" - mock_args.verbose = False - mock_args.top_traj_file = ["/path/to/tpr", "/path/to/trr"] - mock_args.selection_string = None - - parser = run_manager._config_manager.setup_argparse.return_value - parser.parse_known_args.return_value = (mock_args, []) - run_manager._config_manager.merge_configs.return_value = mock_args - - with self.assertRaisesRegex(ValueError, "Missing 'selection_string' argument."): - run_manager.run_entropy_workflow() - - @patch("CodeEntropy.run.pickle.dump") - @patch("CodeEntropy.run.open", create=True) - def test_write_universe(self, mock_open, mock_pickle_dump): - # Mock Universe - mock_universe = MagicMock() - - # Mock the file object returned by open - mock_file = MagicMock() - mock_open.return_value = mock_file - - run_manager = RunManager("mock_folder/job001") - result = run_manager.write_universe(mock_universe, name="test_universe") - - mock_open.assert_called_once_with("test_universe.pkl", "wb") - - # Ensure pickle.dump() was called - 
mock_pickle_dump.assert_called_once_with(mock_universe, mock_file) - - # Ensure the method returns the correct filename - self.assertEqual(result, "test_universe") - - @patch("CodeEntropy.run.pickle.load") - @patch("CodeEntropy.run.open", create=True) - def test_read_universe(self, mock_open, mock_pickle_load): - # Mock the file object returned by open - mock_file = MagicMock() - mock_open.return_value = mock_file - - # Mock Universe to return when pickle.load is called - mock_universe = MagicMock() - mock_pickle_load.return_value = mock_universe - - # Path to the mock file - path = "test_universe.pkl" - - run_manager = RunManager("mock_folder/job001") - result = run_manager.read_universe(path) - - mock_open.assert_called_once_with(path, "rb") - - # Ensure pickle.load() was called with the mock file object - mock_pickle_load.assert_called_once_with(mock_file) - - # Ensure the method returns the correct mock universe - self.assertEqual(result, mock_universe) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/CodeEntropy/cli/test___main__.py b/tests/unit/CodeEntropy/cli/test___main__.py new file mode 100644 index 00000000..ce8c3c29 --- /dev/null +++ b/tests/unit/CodeEntropy/cli/test___main__.py @@ -0,0 +1,13 @@ +import runpy +from unittest.mock import MagicMock + +import CodeEntropy.cli as cli + + +def test___main___invokes_cli_main(monkeypatch): + main_spy = MagicMock() + monkeypatch.setattr(cli, "main", main_spy) + + runpy.run_module("CodeEntropy", run_name="__main__") + + main_spy.assert_called_once_with() diff --git a/tests/unit/CodeEntropy/cli/test_cli.py b/tests/unit/CodeEntropy/cli/test_cli.py new file mode 100644 index 00000000..2df420cc --- /dev/null +++ b/tests/unit/CodeEntropy/cli/test_cli.py @@ -0,0 +1,38 @@ +from unittest.mock import MagicMock + +import pytest + +import CodeEntropy.cli as entry + + +def test_main_creates_job_folder_and_runs_workflow(monkeypatch): + fake_runner_cls = MagicMock() + 
fake_runner_cls.create_job_folder.return_value = "/tmp/job" + + fake_runner = MagicMock() + fake_runner_cls.return_value = fake_runner + + monkeypatch.setattr(entry, "CodeEntropyRunner", fake_runner_cls) + + entry.main() + + fake_runner_cls.create_job_folder.assert_called_once_with() + fake_runner_cls.assert_called_once_with(folder="/tmp/job") + fake_runner.run_entropy_workflow.assert_called_once_with() + + +def test_main_logs_and_exits_nonzero_on_exception(monkeypatch): + fake_runner_cls = MagicMock() + fake_runner_cls.create_job_folder.return_value = "/tmp/job" + + fake_runner = MagicMock() + fake_runner.run_entropy_workflow.side_effect = RuntimeError("boom") + fake_runner_cls.return_value = fake_runner + + monkeypatch.setattr(entry, "CodeEntropyRunner", fake_runner_cls) + + with pytest.raises(SystemExit) as exc: + entry.main() + + assert exc.value.code == 1 + fake_runner.run_entropy_workflow.assert_called_once_with() diff --git a/tests/unit/CodeEntropy/config/argparse/conftest.py b/tests/unit/CodeEntropy/config/argparse/conftest.py new file mode 100644 index 00000000..7ad0a4bc --- /dev/null +++ b/tests/unit/CodeEntropy/config/argparse/conftest.py @@ -0,0 +1,50 @@ +from types import SimpleNamespace + +import pytest + +from CodeEntropy.config.argparse import ConfigResolver + + +class DummyUniverse: + """Minimal MDAnalysis-like Universe stub for validate_inputs tests.""" + + def __init__(self, length: int): + self.trajectory = [None] * length + + +@pytest.fixture() +def resolver(): + return ConfigResolver() + + +@pytest.fixture() +def dummy_universe(): + # default length used in many tests + return DummyUniverse(length=100) + + +@pytest.fixture() +def make_args(): + """Factory to build an args-like object with defaults used by validation checks.""" + + def _make(**overrides): + base = dict( + start=0, + end=10, + step=1, + bin_width=30, + temperature=298.0, + force_partitioning=0.5, + ) + base.update(overrides) + # validation functions only require attribute 
access; SimpleNamespace is ideal + return SimpleNamespace(**base) + + return _make + + +@pytest.fixture() +def empty_cli_args(resolver): + """Argparse Namespace with all parser defaults.""" + parser = resolver.build_parser() + return parser.parse_args([]) diff --git a/tests/unit/CodeEntropy/config/argparse/test_argparse_build_parser.py b/tests/unit/CodeEntropy/config/argparse/test_argparse_build_parser.py new file mode 100644 index 00000000..4505ad9f --- /dev/null +++ b/tests/unit/CodeEntropy/config/argparse/test_argparse_build_parser.py @@ -0,0 +1,39 @@ +from CodeEntropy.config.argparse import ConfigResolver + + +def test_build_parser_parses_selection_string(): + resolver = ConfigResolver() + parser = resolver.build_parser() + + args = parser.parse_args(["--selection_string", "protein"]) + + assert args.selection_string == "protein" + + +def test_build_parser_parses_bool_with_str2bool(): + resolver = ConfigResolver() + parser = resolver.build_parser() + + args = parser.parse_args(["--kcal_force_units", "true"]) + + assert args.kcal_force_units is True + + +def test_build_parser_store_true_flag_verbose_defaults_false_and_sets_true(): + resolver = ConfigResolver() + parser = resolver.build_parser() + + args_default = parser.parse_args([]) + assert args_default.verbose is False + + args_verbose = parser.parse_args(["--verbose"]) + assert args_verbose.verbose is True + + +def test_build_parser_nargs_plus_parses_top_traj_file_list(): + resolver = ConfigResolver() + parser = resolver.build_parser() + + args = parser.parse_args(["--top_traj_file", "a.tpr", "b.trr"]) + + assert args.top_traj_file == ["a.tpr", "b.trr"] diff --git a/tests/unit/CodeEntropy/config/argparse/test_argparse_load_config.py b/tests/unit/CodeEntropy/config/argparse/test_argparse_load_config.py new file mode 100644 index 00000000..bfd5a6fe --- /dev/null +++ b/tests/unit/CodeEntropy/config/argparse/test_argparse_load_config.py @@ -0,0 +1,50 @@ +from unittest.mock import mock_open, patch + +from 
CodeEntropy.config.argparse import ConfigResolver + + +def test_load_config_valid_yaml_returns_dict(): + yaml_content = """ +run1: + selection_string: protein +""" + with ( + patch("glob.glob", return_value=["/fake/config.yaml"]), + patch("builtins.open", mock_open(read_data=yaml_content)), + ): + resolver = ConfigResolver() + config = resolver.load_config("/fake") + + assert "run1" in config + assert config["run1"]["selection_string"] == "protein" + + +def test_load_config_no_yaml_files_returns_default(): + with patch("glob.glob", return_value=[]): + resolver = ConfigResolver() + config = resolver.load_config("/fake") + + assert config == {"run1": {}} + + +def test_load_config_yaml_empty_returns_default_run1(): + yaml_content = "" # yaml.safe_load -> None + with ( + patch("glob.glob", return_value=["/fake/config.yaml"]), + patch("builtins.open", mock_open(read_data=yaml_content)), + ): + resolver = ConfigResolver() + config = resolver.load_config("/fake") + + assert config == {"run1": {}} + + +def test_load_config_open_error_returns_default(): + with ( + patch("glob.glob", return_value=["/fake/config.yaml"]), + patch("builtins.open", side_effect=OSError("boom")), + ): + resolver = ConfigResolver() + config = resolver.load_config("/fake") + + assert config == {"run1": {}} diff --git a/tests/unit/CodeEntropy/config/argparse/test_argparse_resolve.py b/tests/unit/CodeEntropy/config/argparse/test_argparse_resolve.py new file mode 100644 index 00000000..39893d93 --- /dev/null +++ b/tests/unit/CodeEntropy/config/argparse/test_argparse_resolve.py @@ -0,0 +1,77 @@ +import logging + +import pytest + +from CodeEntropy.config.argparse import ConfigResolver, logger + + +def test_resolve_run_config_wrong_type_raises_type_error(resolver, empty_cli_args): + with pytest.raises(TypeError): + resolver.resolve(empty_cli_args, run_config="not-a-dict") + + +def test_resolve_none_run_config_treated_as_empty(resolver, empty_cli_args): + resolved = resolver.resolve(empty_cli_args, None) + 
# should still have defaults applied + assert resolved.selection_string is not None + + +def test_resolve_yaml_applied_when_cli_not_provided(resolver, empty_cli_args): + run_config = {"selection_string": "yaml_value"} + + resolved = resolver.resolve(empty_cli_args, run_config) + + assert resolved.selection_string == "yaml_value" + + +def test_resolve_cli_value_overrides_yaml(resolver): + parser = resolver.build_parser() + args = parser.parse_args(["--selection_string", "cli_value"]) + run_config = {"selection_string": "yaml_value"} + + resolved = resolver.resolve(args, run_config) + + assert resolved.selection_string == "cli_value" + + +def test_resolve_does_not_apply_yaml_key_not_in_arg_specs(resolver, empty_cli_args): + run_config = {"not_a_real_arg": 123} + + resolved = resolver.resolve(empty_cli_args, run_config) + + assert not hasattr(resolved, "not_a_real_arg") + + +def test_resolve_ensure_defaults_sets_none_values(resolver): + # If a known arg is None, _ensure_defaults should fill it. 
+ parser = resolver.build_parser() + args = parser.parse_args([]) + + # force a known arg to None to simulate partial/mutated namespace + args.selection_string = None + + resolved = resolver.resolve(args, {}) + + assert resolved.selection_string == "all" + + +def test_resolve_verbose_sets_logger_debug_level(resolver): + parser = resolver.build_parser() + args = parser.parse_args(["--verbose"]) + + resolver.resolve(args, {}) + + assert logger.level == logging.DEBUG + + +def test_apply_logging_level_updates_handler_level(): + handler = logging.StreamHandler() + handler.setLevel(logging.WARNING) + logger.addHandler(handler) + + try: + ConfigResolver._apply_logging_level(verbose=True) + assert logger.level == logging.DEBUG + assert handler.level == logging.DEBUG + finally: + logger.removeHandler(handler) diff --git a/tests/unit/CodeEntropy/config/argparse/test_argparse_str2bool.py b/tests/unit/CodeEntropy/config/argparse/test_argparse_str2bool.py new file mode 100644 index 00000000..3306dd9c --- /dev/null +++ b/tests/unit/CodeEntropy/config/argparse/test_argparse_str2bool.py @@ -0,0 +1,30 @@ +import argparse as _argparse + +import pytest + +from CodeEntropy.config.argparse import ConfigResolver + + +@pytest.mark.parametrize("value", ["true", "True", "t", "yes", "1"]) +def test_str2bool_true_variants(value): + assert ConfigResolver.str2bool(value) is True + + +@pytest.mark.parametrize("value", ["false", "False", "f", "no", "0"]) +def test_str2bool_false_variants(value): + assert ConfigResolver.str2bool(value) is False + + +def test_str2bool_bool_passthrough(): + assert ConfigResolver.str2bool(True) is True + assert ConfigResolver.str2bool(False) is False + + +def test_str2bool_non_string_non_bool_raises(): + with pytest.raises(_argparse.ArgumentTypeError): + ConfigResolver.str2bool(123) + + +def test_str2bool_invalid_string_raises(): + with pytest.raises(_argparse.ArgumentTypeError): + ConfigResolver.str2bool("maybe") diff --git 
a/tests/unit/CodeEntropy/config/argparse/test_argparse_validate_inputs.py b/tests/unit/CodeEntropy/config/argparse/test_argparse_validate_inputs.py new file mode 100644 index 00000000..1769a9c3 --- /dev/null +++ b/tests/unit/CodeEntropy/config/argparse/test_argparse_validate_inputs.py @@ -0,0 +1,59 @@ +import logging + +import pytest + + +def test_validate_inputs_valid_does_not_raise(resolver, dummy_universe, make_args): + args = make_args() + resolver.validate_inputs(dummy_universe, args) + + +def test_check_input_start_raises_when_start_exceeds_trajectory(resolver, make_args): + u = type("U", (), {"trajectory": [None] * 10})() + args = make_args(start=11) + + with pytest.raises(ValueError): + resolver._check_input_start(u, args) + + +def test_check_input_end_raises_when_end_exceeds_trajectory(resolver, make_args): + u = type("U", (), {"trajectory": [None] * 10})() + args = make_args(end=11) + + with pytest.raises(ValueError): + resolver._check_input_end(u, args) + + +def test_check_input_step_negative_logs_warning(resolver, make_args, caplog): + args = make_args(step=-1) + + with caplog.at_level(logging.WARNING): + resolver._check_input_step(args) + + assert "Negative 'step' value" in caplog.text + + +@pytest.mark.parametrize("bin_width", [-1, 361]) +def test_check_input_bin_width_out_of_range_raises(resolver, make_args, bin_width): + args = make_args(bin_width=bin_width) + + with pytest.raises(ValueError): + resolver._check_input_bin_width(args) + + +def test_check_input_temperature_negative_raises(resolver, make_args): + args = make_args(temperature=-0.1) + + with pytest.raises(ValueError): + resolver._check_input_temperature(args) + + +def test_check_input_force_partitioning_non_default_logs_warning( + resolver, make_args, caplog +): + args = make_args(force_partitioning=0.7) + + with caplog.at_level(logging.WARNING): + resolver._check_input_force_partitioning(args) + + assert "differs from the default" in caplog.text diff --git 
a/tests/unit/CodeEntropy/config/runtime/conftest.py b/tests/unit/CodeEntropy/config/runtime/conftest.py new file mode 100644 index 00000000..8649dc57 --- /dev/null +++ b/tests/unit/CodeEntropy/config/runtime/conftest.py @@ -0,0 +1,10 @@ +import pytest + +from CodeEntropy.config.runtime import CodeEntropyRunner + + +@pytest.fixture() +def runner(tmp_path, monkeypatch): + # keep filesystem effects isolated + monkeypatch.chdir(tmp_path) + return CodeEntropyRunner(folder=str(tmp_path)) diff --git a/tests/unit/CodeEntropy/config/runtime/test_build_universe_branches.py b/tests/unit/CodeEntropy/config/runtime/test_build_universe_branches.py new file mode 100644 index 00000000..2d8571b4 --- /dev/null +++ b/tests/unit/CodeEntropy/config/runtime/test_build_universe_branches.py @@ -0,0 +1,33 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + + +def test_build_universe_no_force(runner): + args = SimpleNamespace( + top_traj_file=["tpr", "trr"], + force_file=None, + file_format=None, + kcal_force_units=False, + ) + uops = MagicMock() + + with patch("CodeEntropy.config.runtime.mda.Universe", return_value="U"): + out = runner._build_universe(args, uops) + + assert out == "U" + uops.merge_forces.assert_not_called() + + +def test_build_universe_with_force(runner): + args = SimpleNamespace( + top_traj_file=["tpr", "trr"], + force_file="force", + file_format="gro", + kcal_force_units=True, + ) + uops = MagicMock() + uops.merge_forces.return_value = "U2" + + out = runner._build_universe(args, uops) + + assert out == "U2" diff --git a/tests/unit/CodeEntropy/config/runtime/test_citation_data.py b/tests/unit/CodeEntropy/config/runtime/test_citation_data.py new file mode 100644 index 00000000..d5876662 --- /dev/null +++ b/tests/unit/CodeEntropy/config/runtime/test_citation_data.py @@ -0,0 +1,23 @@ +from unittest.mock import MagicMock, patch + +import requests + + +def test_load_citation_data_success(runner): + mock_response = MagicMock() + 
mock_response.raise_for_status.return_value = None + mock_response.text = """ +title: TestProject +authors: + - given-names: Alice +""" + + with patch("requests.get", return_value=mock_response): + data = runner.load_citation_data() + + assert data["title"] == "TestProject" + + +def test_load_citation_data_network_error_returns_none(runner): + with patch("requests.get", side_effect=requests.exceptions.ConnectionError()): + assert runner.load_citation_data() is None diff --git a/tests/unit/CodeEntropy/config/runtime/test_create_job_folder.py b/tests/unit/CodeEntropy/config/runtime/test_create_job_folder.py new file mode 100644 index 00000000..1c88ca33 --- /dev/null +++ b/tests/unit/CodeEntropy/config/runtime/test_create_job_folder.py @@ -0,0 +1,51 @@ +import os +from unittest.mock import patch + +from CodeEntropy.config.runtime import CodeEntropyRunner + + +def test_create_job_folder_empty_creates_job001(): + with ( + patch("os.getcwd", return_value="/cwd"), + patch("os.listdir", return_value=[]), + patch("os.makedirs") as mock_makedirs, + ): + path = CodeEntropyRunner.create_job_folder() + + assert path == os.path.join("/cwd", "job001") + mock_makedirs.assert_called_once() + + +def test_create_job_folder_existing_creates_next(): + with ( + patch("os.getcwd", return_value="/cwd"), + patch("os.listdir", return_value=["job001", "job002"]), + patch("os.makedirs") as mock_makedirs, + ): + path = CodeEntropyRunner.create_job_folder() + + assert path == os.path.join("/cwd", "job003") + mock_makedirs.assert_called_once() + + +def test_create_job_folder_ignores_invalid_names(): + with ( + patch("os.getcwd", return_value="/cwd"), + patch("os.listdir", return_value=["job001", "abc", "job002"]), + patch("os.makedirs") as _, + ): + path = CodeEntropyRunner.create_job_folder() + + assert path == os.path.join("/cwd", "job003") + + +def test_create_job_folder_skips_value_error_suffix(): + # jobABC triggers int("ABC") -> ValueError -> continue + with ( + patch("os.getcwd", 
return_value="/cwd"), + patch("os.listdir", return_value=["jobABC", "job001"]), + patch("os.makedirs"), + ): + path = CodeEntropyRunner.create_job_folder() + + assert path == os.path.join("/cwd", "job002") diff --git a/tests/unit/CodeEntropy/config/runtime/test_print_args_table.py b/tests/unit/CodeEntropy/config/runtime/test_print_args_table.py new file mode 100644 index 00000000..286b3164 --- /dev/null +++ b/tests/unit/CodeEntropy/config/runtime/test_print_args_table.py @@ -0,0 +1,22 @@ +from io import StringIO +from types import SimpleNamespace +from unittest.mock import patch + +from rich.console import Console + + +def test_print_args_table_prints_all_args(runner): + args = SimpleNamespace(alpha=1, beta="two") + + buf = StringIO() + test_console = Console(file=buf, force_terminal=False, width=120) + + with patch("CodeEntropy.config.runtime.console", test_console): + runner.print_args_table(args) + + out = buf.getvalue() + assert "Run Configuration" in out + assert "alpha" in out + assert "1" in out + assert "beta" in out + assert "two" in out diff --git a/tests/unit/CodeEntropy/config/runtime/test_required_args.py b/tests/unit/CodeEntropy/config/runtime/test_required_args.py new file mode 100644 index 00000000..c9035954 --- /dev/null +++ b/tests/unit/CodeEntropy/config/runtime/test_required_args.py @@ -0,0 +1,28 @@ +import pytest + +from CodeEntropy.config.runtime import CodeEntropyRunner + + +class Args: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +def test_validate_required_args_missing_traj_raises(): + args = Args(top_traj_file=None, selection_string="all") + + with pytest.raises(ValueError): + CodeEntropyRunner._validate_required_args(args) + + +def test_validate_required_args_missing_selection_raises(): + args = Args(top_traj_file=["a"], selection_string=None) + + with pytest.raises(ValueError): + CodeEntropyRunner._validate_required_args(args) + + +def test_validate_required_args_ok(): + args = 
Args(top_traj_file=["a"], selection_string="all") + CodeEntropyRunner._validate_required_args(args) diff --git a/tests/unit/CodeEntropy/config/runtime/test_run_entropy_workflow_branches.py b/tests/unit/CodeEntropy/config/runtime/test_run_entropy_workflow_branches.py new file mode 100644 index 00000000..9ef6f761 --- /dev/null +++ b/tests/unit/CodeEntropy/config/runtime/test_run_entropy_workflow_branches.py @@ -0,0 +1,158 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import CodeEntropy.config.runtime as runtime_mod + + +def test_run_entropy_workflow_warns_and_skips_non_dict_run_config(runner): + runner._logging_config = MagicMock() + runner._config_manager = MagicMock() + runner._reporter = MagicMock() + + run_logger = MagicMock() + runner._logging_config.configure.return_value = run_logger + + runner._config_manager.load_config.return_value = {"bad_run": "not_a_dict"} + + parser = MagicMock() + parser.parse_known_args.return_value = (SimpleNamespace(output_file="out.json"), []) + runner._config_manager.build_parser.return_value = parser + + runner.show_splash = MagicMock() + + runner.run_entropy_workflow() + + run_logger.warning.assert_called_once() + runner._config_manager.resolve.assert_not_called() + + +def test_run_entropy_workflow_raises_when_required_args_missing(runner): + runner._logging_config = MagicMock() + runner._config_manager = MagicMock() + runner._reporter = MagicMock() + + runner._logging_config.configure.return_value = MagicMock() + runner.show_splash = MagicMock() + + runner._config_manager.load_config.return_value = {"run1": {}} + + parser = MagicMock() + args = SimpleNamespace( + output_file="out.json", + verbose=False, + top_traj_file=None, + selection_string=None, + force_file=None, + file_format=None, + kcal_force_units=False, + ) + parser.parse_known_args.return_value = (args, []) + runner._config_manager.build_parser.return_value = parser + runner._config_manager.resolve.return_value 
= args + + with pytest.raises(RuntimeError) as exc: + runner.run_entropy_workflow() + + assert str(exc.value) == "CodeEntropyRunner encountered an error" + assert isinstance(exc.value.__cause__, ValueError) + assert "Missing 'top_traj_file' argument." in str(exc.value.__cause__) + + +def test_run_entropy_workflow_happy_path_calls_execute_once(runner): + runner._logging_config = MagicMock() + runner._config_manager = MagicMock() + runner._reporter = MagicMock() + runner.show_splash = MagicMock() + runner.print_args_table = MagicMock() + + run_logger = MagicMock() + runner._logging_config.configure.return_value = run_logger + + runner._config_manager.load_config.return_value = {"run1": {}} + + args = SimpleNamespace( + output_file="out.json", + verbose=False, + top_traj_file=["top.tpr", "traj.trr"], + selection_string="all", + force_file=None, + file_format=None, + kcal_force_units=False, + ) + parser = MagicMock() + parser.parse_known_args.return_value = (args, []) + runner._config_manager.build_parser.return_value = parser + + runner._config_manager.resolve.return_value = args + + runner._build_universe = MagicMock(return_value="U") + runner._config_manager.validate_inputs = MagicMock() + + with ( + patch("CodeEntropy.config.runtime.UniverseOperations") as _, + patch("CodeEntropy.config.runtime.MoleculeGrouper") as _, + patch("CodeEntropy.config.runtime.ConformationStateBuilder") as _, + patch("CodeEntropy.config.runtime.EntropyWorkflow") as EWCls, + ): + entropy_instance = MagicMock() + EWCls.return_value = entropy_instance + + runner.run_entropy_workflow() + + runner.print_args_table.assert_called_once_with(args) + runner._build_universe.assert_called_once() + runner._config_manager.validate_inputs.assert_called_once_with("U", args) + EWCls.assert_called_once() + entropy_instance.execute.assert_called_once() + runner._logging_config.export_console.assert_called_once() + + +def test_run_entropy_workflow_logs_when_args_cannot_be_serialized(runner, monkeypatch): + 
runner._logging_config = MagicMock() + runner._config_manager = MagicMock() + runner._reporter = MagicMock() + + runner._logging_config.configure.return_value = MagicMock() + runner.show_splash = MagicMock() + runner._config_manager.load_config.return_value = {"run1": {}} + + class BadArgs: + __slots__ = ( + "output_file", + "verbose", + "top_traj_file", + "selection_string", + "force_file", + "file_format", + "kcal_force_units", + ) + + args = BadArgs() + args.output_file = "out.json" + args.verbose = False + args.top_traj_file = None + args.selection_string = None + args.force_file = None + args.file_format = None + args.kcal_force_units = False + + parser = MagicMock() + parser.parse_known_args.return_value = (args, []) + runner._config_manager.build_parser.return_value = parser + runner._config_manager.resolve.return_value = args + + error_spy = MagicMock() + monkeypatch.setattr(runtime_mod.logger, "error", error_spy) + + with pytest.raises(RuntimeError) as exc: + runner.run_entropy_workflow() + + assert str(exc.value) == "CodeEntropyRunner encountered an error" + assert isinstance(exc.value.__cause__, ValueError) + + assert any( + "Run arguments at failure could not be serialized" in str(call.args[0]) + for call in error_spy.call_args_list + ) diff --git a/tests/unit/CodeEntropy/config/runtime/test_runtime_properties.py b/tests/unit/CodeEntropy/config/runtime/test_runtime_properties.py new file mode 100644 index 00000000..cd2cae36 --- /dev/null +++ b/tests/unit/CodeEntropy/config/runtime/test_runtime_properties.py @@ -0,0 +1,7 @@ +def test_n_avogadro_property_returns_internal_value(runner): + # uses the runner fixture (tmp folder) + assert runner.N_AVOGADRO == runner._N_AVOGADRO + + +def test_def_temper_property_returns_internal_value(runner): + assert runner.DEF_TEMPER == runner._DEF_TEMPER diff --git a/tests/unit/CodeEntropy/config/runtime/test_show_splash.py b/tests/unit/CodeEntropy/config/runtime/test_show_splash.py new file mode 100644 index 
00000000..171927f1 --- /dev/null +++ b/tests/unit/CodeEntropy/config/runtime/test_show_splash.py @@ -0,0 +1,46 @@ +from io import StringIO +from unittest.mock import patch + +from rich.console import Console + + +def test_show_splash_with_citation(runner): + citation = { + "title": "TestProject", + "version": "1.0", + "date-released": "2025-01-01", + "url": "https://example.com", + "abstract": "This is a test abstract.", + "authors": [{"given-names": "Alice", "family-names": "Smith"}], + } + + buf = StringIO() + console = Console(file=buf, force_terminal=False, width=120) + + with ( + patch.object(runner, "load_citation_data", return_value=citation), + patch("CodeEntropy.config.runtime.console", console), + ): + runner.show_splash() + + out = buf.getvalue() + + assert "Welcome to CodeEntropy" in out + assert "Version 1.0" in out + assert "2025-01-01" in out + assert "https://example.com" in out + assert "This is a test abstract." in out + assert "Alice Smith" in out + + +def test_show_splash_without_citation(runner): + buf = StringIO() + console = Console(file=buf, force_terminal=False) + + with ( + patch.object(runner, "load_citation_data", return_value=None), + patch("CodeEntropy.config.runtime.console", console), + ): + runner.show_splash() + + assert "Welcome to CodeEntropy" in buf.getvalue() diff --git a/tests/unit/CodeEntropy/config/runtime/test_unit_conversions.py b/tests/unit/CodeEntropy/config/runtime/test_unit_conversions.py new file mode 100644 index 00000000..31b7543d --- /dev/null +++ b/tests/unit/CodeEntropy/config/runtime/test_unit_conversions.py @@ -0,0 +1,6 @@ +def test_change_lambda_units(runner): + assert runner.change_lambda_units(2.0) == 2.0 * 1e29 / runner.N_AVOGADRO + + +def test_get_kt2j(runner): + assert runner.get_KT2J(298.0) == 4.11e-21 * 298.0 / runner.DEF_TEMPER diff --git a/tests/unit/CodeEntropy/config/runtime/test_universe_io.py b/tests/unit/CodeEntropy/config/runtime/test_universe_io.py new file mode 100644 index 00000000..023d4271 
--- /dev/null +++ b/tests/unit/CodeEntropy/config/runtime/test_universe_io.py @@ -0,0 +1,28 @@ +from unittest.mock import MagicMock, patch + + +def test_write_universe(runner): + u = MagicMock() + + with patch("pickle.dump") as mock_dump: + name = runner.write_universe(u, name="test") + + assert name == "test" + mock_dump.assert_called_once() + + +def test_read_universe(runner): + mock_file = MagicMock() + + # Make open() context manager return mock_file + mock_file.__enter__.return_value = mock_file + + with ( + patch("builtins.open", return_value=mock_file) as mock_open, + patch("pickle.load", return_value="U") as mock_load, + ): + out = runner.read_universe("file.pkl") + + mock_open.assert_called_once_with("file.pkl", "rb") + mock_load.assert_called_once_with(mock_file) + assert out == "U" diff --git a/tests/unit/CodeEntropy/core/logging/conftest.py b/tests/unit/CodeEntropy/core/logging/conftest.py new file mode 100644 index 00000000..20867d87 --- /dev/null +++ b/tests/unit/CodeEntropy/core/logging/conftest.py @@ -0,0 +1,23 @@ +import logging + +import pytest + +from CodeEntropy.core.logging import LoggingConfig + + +@pytest.fixture(autouse=True) +def _isolate_global_logging(): + """ + LoggingConfig modifies global loggers. Keep tests atomic by clearing handlers + after each test so tests don't leak state into each other. 
+ """ + yield + for name in ("", "commands", "MDAnalysis"): + lg = logging.getLogger(name) + lg.handlers.clear() + lg.propagate = True + + +@pytest.fixture() +def config(tmp_path): + return LoggingConfig(folder=str(tmp_path), level=logging.INFO) diff --git a/tests/unit/CodeEntropy/core/logging/test_configure_loggers.py b/tests/unit/CodeEntropy/core/logging/test_configure_loggers.py new file mode 100644 index 00000000..94a1fa39 --- /dev/null +++ b/tests/unit/CodeEntropy/core/logging/test_configure_loggers.py @@ -0,0 +1,26 @@ +import logging + + +def test_configure_attaches_handlers(config): + config.configure() + + root = logging.getLogger() + assert config.handlers["rich"] in root.handlers + assert config.handlers["main"] in root.handlers + assert config.handlers["error"] in root.handlers + + +def test_configure_commands_logger_non_propagating_with_handler(config): + config.configure() + + commands_logger = logging.getLogger("commands") + assert commands_logger.propagate is False + assert config.handlers["command"] in commands_logger.handlers + + +def test_configure_mdanalysis_logger_non_propagating_with_handler(config): + config.configure() + + mda_logger = logging.getLogger("MDAnalysis") + assert mda_logger.propagate is False + assert config.handlers["mdanalysis"] in mda_logger.handlers diff --git a/tests/unit/CodeEntropy/core/logging/test_error_filter.py b/tests/unit/CodeEntropy/core/logging/test_error_filter.py new file mode 100644 index 00000000..be6e6670 --- /dev/null +++ b/tests/unit/CodeEntropy/core/logging/test_error_filter.py @@ -0,0 +1,23 @@ +import logging + +from CodeEntropy.core.logging import ErrorFilter + + +def test_error_filter_allows_error_and_critical(): + f = ErrorFilter() + + record_error = logging.LogRecord("x", logging.ERROR, "f.py", 1, "msg", (), None) + record_crit = logging.LogRecord("x", logging.CRITICAL, "f.py", 1, "msg", (), None) + + assert f.filter(record_error) is True + assert f.filter(record_crit) is True + + +def 
test_error_filter_blocks_below_error(): + f = ErrorFilter() + + record_warn = logging.LogRecord("x", logging.WARNING, "f.py", 1, "msg", (), None) + record_info = logging.LogRecord("x", logging.INFO, "f.py", 1, "msg", (), None) + + assert f.filter(record_warn) is False + assert f.filter(record_info) is False diff --git a/tests/unit/CodeEntropy/core/logging/test_export_console.py b/tests/unit/CodeEntropy/core/logging/test_export_console.py new file mode 100644 index 00000000..65f5276f --- /dev/null +++ b/tests/unit/CodeEntropy/core/logging/test_export_console.py @@ -0,0 +1,17 @@ +import os +from unittest.mock import MagicMock + + +def test_export_console_writes_recorded_output(config): + # Make export_text deterministic + config.console.export_text = MagicMock(return_value="HELLO") + + config.export_console("out.txt") + + out_path = os.path.join(config.log_dir, "out.txt") + assert os.path.exists(out_path) + + with open(out_path, "r", encoding="utf-8") as f: + assert f.read() == "HELLO" + + config.console.export_text.assert_called_once() diff --git a/tests/unit/CodeEntropy/core/logging/test_get_console_singleton.py b/tests/unit/CodeEntropy/core/logging/test_get_console_singleton.py new file mode 100644 index 00000000..53900eb2 --- /dev/null +++ b/tests/unit/CodeEntropy/core/logging/test_get_console_singleton.py @@ -0,0 +1,19 @@ +from CodeEntropy.core.logging import LoggingConfig + + +def test_get_console_returns_singleton(): + # Reset singleton to make the test independent + LoggingConfig._console = None + + c1 = LoggingConfig.get_console() + c2 = LoggingConfig.get_console() + + assert c1 is c2 + + +def test_get_console_records_output_enabled(): + LoggingConfig._console = None + c = LoggingConfig.get_console() + + # Rich Console uses 'record' attribute when recording is enabled + assert getattr(c, "record", False) is True diff --git a/tests/unit/CodeEntropy/core/logging/test_handlers_setup.py b/tests/unit/CodeEntropy/core/logging/test_handlers_setup.py new file mode 
100644 index 00000000..496f17be --- /dev/null +++ b/tests/unit/CodeEntropy/core/logging/test_handlers_setup.py @@ -0,0 +1,41 @@ +import logging +import os + +from rich.logging import RichHandler + +from CodeEntropy.core.logging import LoggingConfig + + +def test_init_creates_log_dir(tmp_path): + cfg = LoggingConfig(folder=str(tmp_path)) + assert os.path.isdir(cfg.log_dir) + + +def test_setup_handlers_creates_expected_handlers(config): + assert set(config.handlers.keys()) == { + "rich", + "main", + "error", + "command", + "mdanalysis", + } + + assert isinstance(config.handlers["rich"], RichHandler) + assert isinstance(config.handlers["main"], logging.FileHandler) + assert isinstance(config.handlers["error"], logging.FileHandler) + assert isinstance(config.handlers["command"], logging.FileHandler) + assert isinstance(config.handlers["mdanalysis"], logging.FileHandler) + + +def test_handler_paths_match_expected_filenames(config): + expected = { + "main": "program.log", + "error": "program.err", + "command": "program.com", + "mdanalysis": "mdanalysis.log", + } + + for handler_key, filename in expected.items(): + handler = config.handlers[handler_key] + assert os.path.basename(handler.baseFilename) == filename + assert os.path.dirname(handler.baseFilename) == config.log_dir diff --git a/tests/unit/CodeEntropy/core/logging/test_set_level.py b/tests/unit/CodeEntropy/core/logging/test_set_level.py new file mode 100644 index 00000000..b28a0bf5 --- /dev/null +++ b/tests/unit/CodeEntropy/core/logging/test_set_level.py @@ -0,0 +1,27 @@ +import logging + + +def test_set_level_updates_root_and_named_loggers(config): + config.configure() + + config.set_level(logging.DEBUG) + + root = logging.getLogger() + assert root.level == logging.DEBUG + + assert logging.getLogger("commands").level == logging.DEBUG + assert logging.getLogger("MDAnalysis").level == logging.DEBUG + + +def test_set_level_sets_filehandlers_to_log_level_and_other_handlers_to_info(config): + config.configure() + + 
config.set_level(logging.DEBUG) + + root = logging.getLogger() + + for h in root.handlers: + if isinstance(h, logging.FileHandler): + assert h.level == logging.DEBUG + else: + assert h.level == logging.INFO diff --git a/tests/unit/CodeEntropy/entropy/conftest.py b/tests/unit/CodeEntropy/entropy/conftest.py new file mode 100644 index 00000000..4a05ccbb --- /dev/null +++ b/tests/unit/CodeEntropy/entropy/conftest.py @@ -0,0 +1,76 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + + +class FakeResidue: + def __init__(self, resname="RES"): + self.resname = resname + + +class FakeMol: + def __init__(self, residues=None): + self.residues = residues or [FakeResidue("ALA"), FakeResidue("GLY")] + + +class FakeAtoms: + def __init__(self, fragments): + self.fragments = fragments + + +class FakeUniverse: + def __init__(self, n_frames=10, fragments=None): + self.trajectory = list(range(n_frames)) + self.atoms = FakeAtoms(fragments or [FakeMol()]) + + +@pytest.fixture() +def args(): + # Only fields used by entropy modules. 
+ return SimpleNamespace( + temperature=298.0, + bin_width=30, + grouping="molecules", + water_entropy=False, + combined_forcetorque=True, + selection_string="all", + start=0, + end=-1, + step=1, + ) + + +@pytest.fixture() +def reporter(): + return MagicMock() + + +@pytest.fixture() +def run_manager(): + rm = MagicMock() + rm.change_lambda_units.side_effect = lambda x: x + rm.get_KT2J.return_value = 2.479e-21 + return rm + + +@pytest.fixture() +def reduced_universe(): + return FakeUniverse(n_frames=12, fragments=[FakeMol()]) + + +@pytest.fixture() +def shared_data(args, reporter, run_manager, reduced_universe): + return { + "args": args, + "reporter": reporter, + "run_manager": run_manager, + "reduced_universe": reduced_universe, + "universe": reduced_universe, + "groups": {0: [0]}, + "levels": {0: ["united_atom", "residue"]}, + "start": 0, + "end": 12, + "step": 1, + "n_frames": 12, + } diff --git a/tests/unit/CodeEntropy/entropy/nodes/test_aggregate_node.py b/tests/unit/CodeEntropy/entropy/nodes/test_aggregate_node.py new file mode 100644 index 00000000..87230f40 --- /dev/null +++ b/tests/unit/CodeEntropy/entropy/nodes/test_aggregate_node.py @@ -0,0 +1,22 @@ +from CodeEntropy.entropy.nodes.aggregate import AggregateEntropyNode + + +def test_aggregate_node_collects_values_and_writes_shared_data(): + node = AggregateEntropyNode() + shared = {"vibrational_entropy": {"v": 1}, "configurational_entropy": {"c": 2}} + + out = node.run(shared) + + assert out["entropy_results"]["vibrational_entropy"] == {"v": 1} + assert out["entropy_results"]["configurational_entropy"] == {"c": 2} + assert shared["entropy_results"] == out["entropy_results"] + + +def test_aggregate_node_missing_upstreams_yields_none_values(): + node = AggregateEntropyNode() + shared = {} + + out = node.run(shared) + + assert out["entropy_results"]["vibrational_entropy"] is None + assert out["entropy_results"]["configurational_entropy"] is None diff --git 
a/tests/unit/CodeEntropy/entropy/nodes/test_configurational_node.py b/tests/unit/CodeEntropy/entropy/nodes/test_configurational_node.py new file mode 100644 index 00000000..b31cff3c --- /dev/null +++ b/tests/unit/CodeEntropy/entropy/nodes/test_configurational_node.py @@ -0,0 +1,72 @@ +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from CodeEntropy.entropy.nodes.configurational import ConfigurationalEntropyNode + + +def test_config_node_raises_if_frame_count_missing(): + node = ConfigurationalEntropyNode() + with pytest.raises(KeyError): + node._get_n_frames({}) + + +def test_config_node_run_writes_results(shared_data): + node = ConfigurationalEntropyNode() + + shared_data["conformational_states"] = { + "ua": {(0, 0): [0, 0, 1, 1]}, + "res": {0: [0, 1, 1, 1]}, + } + + shared_data["levels"] = {0: ["united_atom", "residue"]} + shared_data["groups"] = {0: [0]} + + out = node.run(shared_data) + + assert "configurational_entropy" in out + assert "configurational_entropy" in shared_data + assert 0 in shared_data["configurational_entropy"] + + +def test_run_skips_empty_mol_ids_group(): + node = ConfigurationalEntropyNode() + + shared_data = { + "n_frames": 5, + "groups": {0: []}, + "levels": {0: ["united_atom"]}, + "reduced_universe": MagicMock(), + "conformational_states": {"ua": {}, "res": {}}, + "reporter": None, + } + + out = node.run(shared_data) + + assert "configurational_entropy" in out + assert out["configurational_entropy"][0]["ua"] == 0.0 + + +def test_get_group_states_sequence_in_range_returns_value(): + states_res = [None, [1, 2, 3]] + out = ConfigurationalEntropyNode._get_group_states(states_res, group_id=1) + assert out == [1, 2, 3] + + +def test_get_group_states_sequence_out_of_range_returns_none(): + states_res = [None] + out = ConfigurationalEntropyNode._get_group_states(states_res, group_id=2) + assert out is None + + +def test_has_state_data_numpy_array_uses_np_any_branch(): + assert 
ConfigurationalEntropyNode._has_state_data(np.array([0, 0, 1])) is True + assert ConfigurationalEntropyNode._has_state_data(np.array([0, 0, 0])) is False + + +def test_has_state_data_noniterable_hits_typeerror_fallback(): + # any(0) raises TypeError -> returns bool(0) == False + assert ConfigurationalEntropyNode._has_state_data(0) is False + # any(1) raises TypeError -> returns bool(1) == True + assert ConfigurationalEntropyNode._has_state_data(1) is True diff --git a/tests/unit/CodeEntropy/entropy/nodes/test_vibrational_node.py b/tests/unit/CodeEntropy/entropy/nodes/test_vibrational_node.py new file mode 100644 index 00000000..45f2d05e --- /dev/null +++ b/tests/unit/CodeEntropy/entropy/nodes/test_vibrational_node.py @@ -0,0 +1,382 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from CodeEntropy.entropy.nodes.vibrational import EntropyPair, VibrationalEntropyNode + + +@pytest.fixture() +def shared_data_base(): + frag = MagicMock() + frag.residues = [MagicMock(resname="RES")] + reduced_universe = MagicMock() + reduced_universe.atoms.fragments = [frag] + + return { + "run_manager": MagicMock(), + "args": SimpleNamespace(temperature=298.0, combined_forcetorque=False), + "groups": {0: [0]}, + "levels": {0: ["united_atom"]}, + "reduced_universe": reduced_universe, + "force_covariances": {"ua": {}, "res": [], "poly": []}, + "torque_covariances": {"ua": {}, "res": [], "poly": []}, + "n_frames": 5, + "reporter": MagicMock(), + } + + +@pytest.fixture() +def shared_groups(): + frag = MagicMock() + frag.residues = [MagicMock(resname="RES")] + ru = MagicMock() + ru.atoms.fragments = [frag] + + return { + "run_manager": MagicMock(), + "args": SimpleNamespace(temperature=298.0, combined_forcetorque=False), + "groups": {5: [0]}, + "levels": {0: ["united_atom"]}, + "reduced_universe": ru, + "force_covariances": {"ua": {}, "res": [], "poly": []}, + "torque_covariances": {"ua": {}, "res": [], "poly": []}, + 
"n_frames": 5, + "reporter": MagicMock(), + } + + +def test_united_atom_branch_logs_and_stores(shared_data_base, monkeypatch): + node = VibrationalEntropyNode() + + monkeypatch.setattr( + node, + "_compute_united_atom_entropy", + MagicMock(return_value=EntropyPair(trans=1.0, rot=2.0)), + ) + monkeypatch.setattr(node, "_log_molecule_level_results", MagicMock()) + + out = node.run(shared_data_base) + + assert out["vibrational_entropy"][0]["united_atom"]["trans"] == 1.0 + assert out["vibrational_entropy"][0]["united_atom"]["rot"] == 2.0 + + +def test_residue_branch_noncombined(shared_data_base, monkeypatch): + node = VibrationalEntropyNode() + + shared_data_base["levels"] = {0: ["residue"]} + shared_data_base["group_id_to_index"] = {0: 0} + shared_data_base["force_covariances"]["res"] = [np.eye(3)] + shared_data_base["torque_covariances"]["res"] = [np.eye(3)] + + monkeypatch.setattr( + node, + "_compute_force_torque_entropy", + MagicMock(return_value=EntropyPair(trans=3.0, rot=4.0)), + ) + monkeypatch.setattr(node, "_log_molecule_level_results", MagicMock()) + + out = node.run(shared_data_base) + + assert out["vibrational_entropy"][0]["residue"]["trans"] == 3.0 + assert out["vibrational_entropy"][0]["residue"]["rot"] == 4.0 + + +def test_polymer_branch_combined_ft_at_highest(shared_data_base, monkeypatch): + node = VibrationalEntropyNode() + + shared_data_base["args"].combined_forcetorque = True + shared_data_base["levels"] = {0: ["polymer"]} + shared_data_base["group_id_to_index"] = {0: 0} + shared_data_base["forcetorque_covariances"] = {"poly": [np.eye(6)]} + + monkeypatch.setattr( + node, + "_compute_ft_entropy", + MagicMock(return_value=EntropyPair(trans=5.0, rot=6.0)), + ) + monkeypatch.setattr(node, "_log_molecule_level_results", MagicMock()) + + out = node.run(shared_data_base) + + assert out["vibrational_entropy"][0]["polymer"]["trans"] == 5.0 + assert out["vibrational_entropy"][0]["polymer"]["rot"] == 6.0 + + +def 
test_get_indexed_matrix_typeerror_returns_none(): + node = VibrationalEntropyNode() + assert node._get_indexed_matrix(mats=123, index=0) is None + + +def test_get_group_id_to_index_uses_cached_mapping(shared_data): + node = VibrationalEntropyNode() + shared_data["group_id_to_index"] = {7: 0} + assert node._get_group_id_to_index(shared_data) == {7: 0} + + +def test_get_group_id_to_index_falls_back_to_enumeration(shared_data): + node = VibrationalEntropyNode() + shared_data.pop("group_id_to_index", None) + shared_data["groups"] = {5: [0], 9: [1]} + assert node._get_group_id_to_index(shared_data) == {5: 0, 9: 1} + + +def test_run_raises_on_unknown_level(shared_data, monkeypatch): + node = VibrationalEntropyNode() + + shared_data["levels"] = {0: ["banana"]} + shared_data["groups"] = {0: [0]} + + shared_data["force_covariances"] = {"ua": {}, "res": [], "poly": []} + shared_data["torque_covariances"] = {"ua": {}, "res": [], "poly": []} + + with pytest.raises(ValueError): + node.run(shared_data) + + +def test_run_united_atom_branch_stores_results(shared_data, monkeypatch): + node = VibrationalEntropyNode() + + shared_data["levels"] = {0: ["united_atom"]} + shared_data["groups"] = {0: [0]} + shared_data["force_covariances"] = {"ua": {}, "res": [], "poly": []} + shared_data["torque_covariances"] = {"ua": {}, "res": [], "poly": []} + + fake_pair = MagicMock(trans=1.0, rot=2.0) + monkeypatch.setattr( + node, "_compute_united_atom_entropy", MagicMock(return_value=fake_pair) + ) + monkeypatch.setattr(node, "_log_molecule_level_results", MagicMock()) + + out = node.run(shared_data) + + assert "vibrational_entropy" in out + assert shared_data["vibrational_entropy"][0]["united_atom"]["trans"] == 1.0 + assert shared_data["vibrational_entropy"][0]["united_atom"]["rot"] == 2.0 + + +def test_unknown_level_raises(shared_data): + node = VibrationalEntropyNode() + + shared_data["levels"] = {0: ["invalid"]} + shared_data["groups"] = {0: [0]} + shared_data["force_covariances"] = {"ua": {}, 
"res": [], "poly": []} + shared_data["torque_covariances"] = {"ua": {}, "res": [], "poly": []} + + with pytest.raises(ValueError): + node.run(shared_data) + + +def test_polymer_branch_executes(shared_data, monkeypatch): + node = VibrationalEntropyNode() + + shared_data["levels"] = {0: ["polymer"]} + shared_data["groups"] = {0: [0]} + + shared_data["force_covariances"] = {"ua": {}, "res": [], "poly": [MagicMock()]} + shared_data["torque_covariances"] = {"ua": {}, "res": [], "poly": [MagicMock()]} + + shared_data["reduced_universe"].atoms.fragments = [MagicMock(residues=[])] + + monkeypatch.setattr( + node, + "_compute_force_torque_entropy", + MagicMock(return_value=EntropyPair(trans=1.0, rot=1.0)), + ) + monkeypatch.setattr(node, "_log_molecule_level_results", MagicMock()) + + out = node.run(shared_data) + + assert "vibrational_entropy" in out + assert out["vibrational_entropy"][0]["polymer"]["trans"] == 1.0 + assert out["vibrational_entropy"][0]["polymer"]["rot"] == 1.0 + + +def test_run_skips_empty_mol_ids_group(): + node = VibrationalEntropyNode() + + shared_groups = { + "run_manager": MagicMock(), + "args": SimpleNamespace(temperature=298.0, combined_forcetorque=False), + "groups": {0: []}, + "levels": {0: ["united_atom"]}, + "reduced_universe": MagicMock(atoms=MagicMock(fragments=[])), + "force_covariances": {"ua": {}, "res": [], "poly": []}, + "torque_covariances": {"ua": {}, "res": [], "poly": []}, + "n_frames": 5, + "reporter": None, + } + + out = node.run(shared_groups) + assert "vibrational_entropy" in out + assert out["vibrational_entropy"][0] == {} + + +def test_get_ua_frame_counts_falls_back_to_empty_when_shape_wrong(): + node = VibrationalEntropyNode() + assert node._get_ua_frame_counts({"frame_counts": "not-a-dict"}) == {} + + +def test_compute_united_atom_entropy_logs_residue_data_when_reporter_present(): + node = VibrationalEntropyNode() + ve = MagicMock() + + node._compute_force_torque_entropy = MagicMock(return_value=EntropyPair(1.0, 2.0)) + + 
reporter = MagicMock() + residues = [SimpleNamespace(resname="A"), SimpleNamespace(resname="B")] + + out = node._compute_united_atom_entropy( + ve=ve, + temp=298.0, + group_id=7, + residues=residues, + force_ua={}, + torque_ua={}, + ua_frame_counts={(7, 0): 3, (7, 1): 4}, + reporter=reporter, + n_frames_default=10, + highest=True, + ) + + assert out == EntropyPair(trans=2.0, rot=4.0) + assert reporter.add_residue_data.call_count == 4 + + +def test_compute_force_torque_entropy_success_calls_vibrational_engine(): + node = VibrationalEntropyNode() + ve = MagicMock() + ve.vibrational_entropy_calculation.side_effect = [10.0, 20.0] + + out = node._compute_force_torque_entropy( + ve=ve, + temp=298.0, + fmat=np.eye(3), + tmat=np.eye(3), + highest=False, + ) + + assert out == EntropyPair(trans=10.0, rot=20.0) + assert ve.vibrational_entropy_calculation.call_count == 2 + + +def test_compute_ft_entropy_success_calls_vibrational_engine_for_trans_and_rot(): + node = VibrationalEntropyNode() + ve = MagicMock() + ve.vibrational_entropy_calculation.side_effect = [1.5, 2.5] + + out = node._compute_ft_entropy(ve=ve, temp=298.0, ftmat=np.eye(6)) + + assert out == EntropyPair(trans=1.5, rot=2.5) + assert ve.vibrational_entropy_calculation.call_count == 2 + + +def test_log_molecule_level_results_returns_when_reporter_none(): + VibrationalEntropyNode._log_molecule_level_results( + reporter=None, + group_id=1, + level="residue", + pair=EntropyPair(1.0, 2.0), + use_ft_labels=False, + ) + + +def test_log_molecule_level_results_writes_trans_and_rot_labels(): + reporter = MagicMock() + VibrationalEntropyNode._log_molecule_level_results( + reporter=reporter, + group_id=1, + level="residue", + pair=EntropyPair(3.0, 4.0), + use_ft_labels=False, + ) + + reporter.add_results_data.assert_any_call(1, "residue", "Transvibrational", 3.0) + reporter.add_results_data.assert_any_call(1, "residue", "Rovibrational", 4.0) + + +def test_get_group_id_to_index_builds_from_groups(shared_groups): + node = 
VibrationalEntropyNode() + gid2i = node._get_group_id_to_index(shared_groups) + assert gid2i == {5: 0} + + +def test_get_ua_frame_counts_returns_empty_when_missing(shared_groups): + node = VibrationalEntropyNode() + assert node._get_ua_frame_counts(shared_groups) == {} + + +def test_compute_force_torque_entropy_returns_zero_when_missing_matrix(shared_groups): + node = VibrationalEntropyNode() + ve = MagicMock() + pair = node._compute_force_torque_entropy( + ve=ve, temp=298.0, fmat=None, tmat=np.eye(3), highest=True + ) + assert pair == EntropyPair(trans=0.0, rot=0.0) + + +def test_compute_force_torque_entropy_returns_zero_when_filter_removes_all(monkeypatch): + node = VibrationalEntropyNode() + ve = MagicMock() + + monkeypatch.setattr( + node._mat_ops, "filter_zero_rows_columns", lambda a, atol: np.array([]) + ) + + pair = node._compute_force_torque_entropy( + ve=ve, temp=298.0, fmat=np.eye(3), tmat=np.eye(3), highest=True + ) + assert pair == EntropyPair(trans=0.0, rot=0.0) + + +def test_compute_ft_entropy_returns_zero_when_none(): + node = VibrationalEntropyNode() + ve = MagicMock() + assert node._compute_ft_entropy(ve=ve, temp=298.0, ftmat=None) == EntropyPair( + trans=0.0, rot=0.0 + ) + + +def test_log_molecule_level_results_ft_labels_branch(): + node = VibrationalEntropyNode() + reporter = MagicMock() + + node._log_molecule_level_results( + reporter, 1, "residue", EntropyPair(1.0, 2.0), use_ft_labels=True + ) + + reporter.add_results_data.assert_any_call( + 1, "residue", "FTmat-Transvibrational", 1.0 + ) + reporter.add_results_data.assert_any_call(1, "residue", "FTmat-Rovibrational", 2.0) + + +def test_get_indexed_matrix_out_of_range_returns_none(): + node = VibrationalEntropyNode() + assert node._get_indexed_matrix([np.eye(3)], 5) is None + + +def test_run_unknown_level_raises(shared_groups): + node = VibrationalEntropyNode() + shared_groups["levels"] = {0: ["nope"]} + + with pytest.raises(ValueError): + node.run(shared_groups) + + +def 
test_compute_ft_entropy_returns_zeros_when_filtered_ft_matrix_is_empty(monkeypatch): + node = VibrationalEntropyNode() + ve = MagicMock() + + monkeypatch.setattr( + node._mat_ops, + "filter_zero_rows_columns", + lambda _arr, atol: np.empty((0, 0), dtype=float), + ) + + out = node._compute_ft_entropy(ve=ve, temp=298.0, ftmat=np.eye(6)) + + assert out == EntropyPair(trans=0.0, rot=0.0) + ve.vibrational_entropy_calculation.assert_not_called() diff --git a/tests/unit/CodeEntropy/entropy/test_configurational_edges.py b/tests/unit/CodeEntropy/entropy/test_configurational_edges.py new file mode 100644 index 00000000..1b8f9792 --- /dev/null +++ b/tests/unit/CodeEntropy/entropy/test_configurational_edges.py @@ -0,0 +1,90 @@ +import logging +from types import SimpleNamespace +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from CodeEntropy.entropy.configurational import ConformationalEntropy + + +def test_validate_assignment_config_step_must_be_positive(): + ce = ConformationalEntropy() + with pytest.raises(ValueError): + ce.assign_conformation( + data_container=SimpleNamespace(trajectory=list(range(5))), + dihedral=MagicMock(value=lambda: 10.0), + number_frames=5, + bin_width=30, + start=0, + end=5, + step=0, + ) + + +def test_validate_assignment_config_bin_width_out_of_range(): + ce = ConformationalEntropy() + with pytest.raises(ValueError): + ce.assign_conformation( + data_container=SimpleNamespace(trajectory=list(range(5))), + dihedral=MagicMock(value=lambda: 10.0), + number_frames=5, + bin_width=0, + start=0, + end=5, + step=1, + ) + + +def test_validate_assignment_config_warns_when_bin_width_not_dividing_360(caplog): + ce = ConformationalEntropy() + caplog.set_level(logging.WARNING) + + data_container = SimpleNamespace(trajectory=list(range(5))) + dihedral = MagicMock() + dihedral.value.return_value = 10.0 + + ce.assign_conformation( + data_container=data_container, + dihedral=dihedral, + number_frames=5, + bin_width=7, + start=0, + end=5, + 
step=1, + ) + + assert any("does not evenly divide 360" in r.message for r in caplog.records) + + +def test_collect_dihedral_angles_normalizes_negative_values(): + ce = ConformationalEntropy() + + traj_slice = list(range(3)) + dihedral = MagicMock() + dihedral.value.side_effect = [-10.0, 0.0, 10.0] + + phi = ce._collect_dihedral_angles(traj_slice, dihedral) + + assert phi[0] == pytest.approx(350.0) + + +def test_to_1d_array_returns_none_for_non_iterable_state_input(): + ce = ConformationalEntropy() + # int is not iterable -> list(states) raises TypeError -> returns None + assert ce._to_1d_array(123) is None + + +def test_find_histogram_peaks_skips_zero_population_bins(): + ce = ConformationalEntropy() + + phi = np.zeros(50, dtype=float) + + peaks = ce._find_histogram_peaks(phi, bin_width=30) + + assert peaks.size >= 1 + + +def test_to_1d_array_returns_none_when_states_is_none(): + ce = ConformationalEntropy() + assert ce._to_1d_array(None) is None diff --git a/tests/unit/CodeEntropy/entropy/test_configurational_entropy_math.py b/tests/unit/CodeEntropy/entropy/test_configurational_entropy_math.py new file mode 100644 index 00000000..7624fbf6 --- /dev/null +++ b/tests/unit/CodeEntropy/entropy/test_configurational_entropy_math.py @@ -0,0 +1,120 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from CodeEntropy.entropy.configurational import ConformationalEntropy + + +def test_find_histogram_peaks_empty_histogram_returns_empty(): + ce = ConformationalEntropy() + phi = np.zeros(100, dtype=float) + + peaks = ce._find_histogram_peaks(phi, bin_width=30) + + assert isinstance(peaks, np.ndarray) + assert peaks.dtype == float + + +def test_find_histogram_peaks_returns_empty_for_empty_phi(): + ce = ConformationalEntropy() + phi = np.array([], dtype=float) + + peaks = ce._find_histogram_peaks(phi, bin_width=30) + + assert isinstance(peaks, np.ndarray) + assert peaks.size == 0 + + +def 
test_assign_nearest_peaks_with_single_peak_assigns_all_zero(): + ce = ConformationalEntropy() + phi = np.array([0.0, 10.0, 20.0], dtype=float) + peak_values = np.array([15.0], dtype=float) + + states = ce._assign_nearest_peaks(phi, peak_values) + + assert np.all(states == 0) + + +def test_assign_conformation_no_peaks_returns_all_zero(): + ce = ConformationalEntropy() + + data_container = SimpleNamespace(trajectory=[]) + dihedral = MagicMock() + + states = ce.assign_conformation( + data_container=data_container, + dihedral=dihedral, + number_frames=0, + bin_width=30, + start=0, + end=0, + step=1, + ) + + assert states.size == 0 + + +def test_assign_conformation_fallback_when_peak_finder_returns_empty(monkeypatch): + ce = ConformationalEntropy() + data_container = SimpleNamespace(trajectory=list(range(5))) + dihedral = MagicMock() + dihedral.value.return_value = 10.0 + + monkeypatch.setattr( + ce, "_find_histogram_peaks", lambda phi, bw: np.array([], dtype=float) + ) + + states = ce.assign_conformation( + data_container=data_container, + dihedral=dihedral, + number_frames=5, + bin_width=30, + start=0, + end=5, + step=1, + ) + assert np.all(states == 0) + + +def test_assign_conformation_detects_multiple_states(): + ce = ConformationalEntropy() + + values = [0.0] * 50 + [180.0] * 50 + data_container = SimpleNamespace(trajectory=list(range(len(values)))) + dihedral = MagicMock() + dihedral.value.side_effect = values + + states = ce.assign_conformation( + data_container=data_container, + dihedral=dihedral, + number_frames=len(values), + bin_width=30, + start=0, + end=len(values), + step=1, + ) + + assert len(np.unique(states)) >= 2 + + +def test_conformational_entropy_empty_returns_zero(): + ce = ConformationalEntropy() + assert ce.conformational_entropy_calculation([], number_frames=10) == 0.0 + + +def test_conformational_entropy_single_state_returns_zero(): + ce = ConformationalEntropy() + assert ce.conformational_entropy_calculation([0, 0, 0], number_frames=3) == 0.0 
+ + +def test_conformational_entropy_known_distribution_matches_expected(): + ce = ConformationalEntropy() + states = np.array([0, 0, 1, 1, 1, 2]) + + probs = np.array([2 / 6, 3 / 6, 1 / 6], dtype=float) + expected = -ce._GAS_CONST * float(np.sum(probs * np.log(probs))) + + got = ce.conformational_entropy_calculation(states, number_frames=6) + assert got == pytest.approx(expected) diff --git a/tests/unit/CodeEntropy/entropy/test_graph.py b/tests/unit/CodeEntropy/entropy/test_graph.py new file mode 100644 index 00000000..92f4e2f3 --- /dev/null +++ b/tests/unit/CodeEntropy/entropy/test_graph.py @@ -0,0 +1,79 @@ +from unittest.mock import MagicMock + +import pytest + +from CodeEntropy.entropy.graph import EntropyGraph, NodeSpec + + +def test_build_creates_expected_nodes_and_edges(): + g = EntropyGraph().build() + + assert set(g._nodes.keys()) == { + "vibrational_entropy", + "configurational_entropy", + "aggregate_entropy", + } + + assert g._graph.has_edge("vibrational_entropy", "aggregate_entropy") + assert g._graph.has_edge("configurational_entropy", "aggregate_entropy") + + +def test_execute_runs_nodes_in_topological_order_and_merges_dict_outputs(shared_data): + g = EntropyGraph() + + node_a = MagicMock() + node_b = MagicMock() + node_c = MagicMock() + + node_a.run.return_value = {"a": 1} + node_b.run.return_value = {"b": 2} + node_c.run.return_value = "not-a-dict" + + g._add_node(NodeSpec("a", node_a)) + g._add_node(NodeSpec("b", node_b)) + g._add_node(NodeSpec("c", node_c, deps=("a", "b"))) + + out = g.execute(shared_data) + + assert node_a.run.called + assert node_b.run.called + assert node_c.run.called + assert out == {"a": 1, "b": 2} + + +def test_add_node_duplicate_name_raises(): + g = EntropyGraph() + g._add_node(NodeSpec("x", object())) + with pytest.raises(ValueError): + g._add_node(NodeSpec("x", object())) + + +def test_execute_forwards_progress_to_nodes_that_accept_it(shared_data): + g = EntropyGraph() + + node_a = MagicMock() + node_a.run.return_value = 
{"a": 1} + + g._add_node(NodeSpec("a", node_a)) + + progress = MagicMock() + out = g.execute(shared_data, progress=progress) + + node_a.run.assert_called_once_with(shared_data, progress=progress) + assert out == {"a": 1} + + +def test_execute_falls_back_when_node_run_does_not_accept_progress(shared_data): + g = EntropyGraph() + + class NoProgressNode: + def run(self, shared_data): + return {"x": 1} + + node = NoProgressNode() + g._add_node(NodeSpec("x", node)) + + progress = MagicMock() + out = g.execute(shared_data, progress=progress) + + assert out == {"x": 1} diff --git a/tests/unit/CodeEntropy/entropy/test_orientational_entropy.py b/tests/unit/CodeEntropy/entropy/test_orientational_entropy.py new file mode 100644 index 00000000..1411d08d --- /dev/null +++ b/tests/unit/CodeEntropy/entropy/test_orientational_entropy.py @@ -0,0 +1,54 @@ +import pytest + +from CodeEntropy.entropy.orientational import OrientationalEntropy + + +def test_orientational_skips_water_species(): + oe = OrientationalEntropy(None, None, None, None, None) + res = oe.calculate({"WAT": 10, "LIG": 2}) + assert res.total > 0 + + +def test_orientational_negative_count_raises(): + oe = OrientationalEntropy(None, None, None, None, None) + with pytest.raises(ValueError): + oe.calculate({"LIG": -1}) + + +def test_orientational_zero_count_contributes_zero(): + oe = OrientationalEntropy(None, None, None, None, None) + res = oe.calculate({"LIG": 0}) + assert res.total == 0.0 + + +def test_orientational_skips_water_resname_all_uppercase(): + oe = OrientationalEntropy(None, None, None, None, None) + res = oe.calculate({"WAT": 10}) + assert res.total == 0.0 + + +def test_orientational_entropy_skips_water_species(): + oe = OrientationalEntropy(None, None, None, None, None) + res = oe.calculate({"WAT": 10, "Na+": 2}) + assert res.total > 0.0 + + +def test_orientational_calculate_only_water_returns_zero(): + oe = OrientationalEntropy(None, None, None, None, None) + res = oe.calculate({"WAT": 5}) + assert 
res.total == 0.0 + + +def test_calculate_skips_water_species_branch(): + oe = OrientationalEntropy(None, None, None, None, None) + out = oe.calculate({"WAT": 10, "Na+": 2}) + + assert out.total > 0.0 + + +def test_entropy_contribution_returns_zero_when_omega_nonpositive(monkeypatch): + oe = OrientationalEntropy(None, None, None, None, None) + + monkeypatch.setattr(OrientationalEntropy, "_omega", staticmethod(lambda n: 0.0)) + + assert oe._entropy_contribution(5) == 0.0 diff --git a/tests/unit/CodeEntropy/entropy/test_vibrational_entropy_math.py b/tests/unit/CodeEntropy/entropy/test_vibrational_entropy_math.py new file mode 100644 index 00000000..43fb8e04 --- /dev/null +++ b/tests/unit/CodeEntropy/entropy/test_vibrational_entropy_math.py @@ -0,0 +1,161 @@ +import numpy as np +import pytest + +from CodeEntropy.entropy.vibrational import VibrationalEntropy + + +@pytest.fixture() +def run_manager(): + class RM: + def change_lambda_units(self, x): + return np.asarray(x) + + def get_KT2J(self, temp): + return 1e-34 + + return RM() + + +def test_matrix_eigenvalues_returns_complex_dtype_possible(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + m = np.array([[0.0, -1.0], [1.0, 0.0]]) + eigs = ve._matrix_eigenvalues(m) + assert eigs.shape == (2,) + + +def test_frequencies_from_lambdas_filters_nonpositive_and_near_zero(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + + lambdas = np.array([-1.0, 0.0, 1e-12, 1.0, 4.0]) + freqs = ve._frequencies_from_lambdas(lambdas, temp=298.0) + + assert freqs.size == 2 + assert np.all(freqs > 0) + + +def test_frequencies_from_lambdas_filters_complex(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + lambdas = np.array([1.0 + 2.0j, 9.0 + 0.0j, 16.0]) + + freqs = ve._frequencies_from_lambdas(lambdas, temp=298.0) + + assert freqs.size == 2 + assert np.all(freqs > 0) + + +def test_entropy_components_returns_empty_when_all_invalid(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + 
+ ve._matrix_eigenvalues = lambda m: np.array([-1.0, 0.0, 0.0]) + comps = ve._entropy_components(np.eye(3), temp=298.0) + + assert comps.size == 0 + + +def test_entropy_components_from_frequencies_returns_correct_shape(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + + freqs = np.array([1.0, 2.0, 3.0], dtype=float) + comps = ve._entropy_components_from_frequencies(freqs, temp=298.0) + + assert comps.shape == (3,) + assert isinstance(comps, np.ndarray) + + assert np.all(comps >= 0) or np.isinf(comps).any() + + +def test_split_halves_even_length(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + + arr = np.arange(10, dtype=float) + a, b = ve._split_halves(arr) + + assert a.shape == (5,) + assert b.shape == (5,) + assert np.all(a == np.array([0, 1, 2, 3, 4], dtype=float)) + assert np.all(b == np.array([5, 6, 7, 8, 9], dtype=float)) + + +def test_split_halves_odd_length_returns_empty_second(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + + arr = np.arange(5, dtype=float) + a, b = ve._split_halves(arr) + + assert a.shape == (5,) + assert b.size == 0 + + +def test_sum_components_empty_returns_zero(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + assert ( + ve._sum_components(np.array([], dtype=float), "force", highest_level=True) + == 0.0 + ) + + +def test_sum_components_force_highest_sums_all(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + comps = np.arange(12, dtype=float) + assert ve._sum_components(comps, "force", highest_level=True) == pytest.approx( + float(np.sum(comps)) + ) + + +def test_sum_components_force_not_highest_drops_first_six(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + comps = np.arange(12, dtype=float) + assert ve._sum_components(comps, "force", highest_level=False) == pytest.approx( + float(np.sum(comps[6:])) + ) + + +def test_sum_components_torque_sums_all(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + comps = 
np.arange(12, dtype=float) + assert ve._sum_components(comps, "torque", highest_level=False) == pytest.approx( + float(np.sum(comps)) + ) + + +def test_sum_components_forcetorque_trans_uses_first_three(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + comps = np.arange(10, dtype=float) + assert ve._sum_components( + comps, "forcetorqueTRANS", highest_level=False + ) == pytest.approx(float(np.sum(comps[:3]))) + + +def test_sum_components_forcetorque_rot_uses_after_three(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + comps = np.arange(10, dtype=float) + assert ve._sum_components( + comps, "forcetorqueROT", highest_level=False + ) == pytest.approx(float(np.sum(comps[3:]))) + + +def test_sum_components_unknown_matrix_type_raises(run_manager): + ve = VibrationalEntropy(run_manager=run_manager) + comps = np.arange(6, dtype=float) + + with pytest.raises(ValueError): + ve._sum_components(comps, "nope", highest_level=True) + + +def test_vibrational_entropy_calculation_end_to_end_returns_float( + run_manager, monkeypatch +): + ve = VibrationalEntropy(run_manager=run_manager) + + monkeypatch.setattr(ve, "_matrix_eigenvalues", lambda m: np.array([1.0, 2.0, 3.0])) + monkeypatch.setattr(ve, "_convert_lambda_units", lambda x: np.asarray(x)) + monkeypatch.setattr( + ve, "_frequencies_from_lambdas", lambda lambdas, temp: np.array([1.0, 2.0, 3.0]) + ) + + out = ve.vibrational_entropy_calculation( + np.eye(3), matrix_type="torque", temp=298.0, highest_level=False + ) + + assert isinstance(out, float) + assert out >= 0.0 diff --git a/tests/unit/CodeEntropy/entropy/test_water_entropy.py b/tests/unit/CodeEntropy/entropy/test_water_entropy.py new file mode 100644 index 00000000..67de69d0 --- /dev/null +++ b/tests/unit/CodeEntropy/entropy/test_water_entropy.py @@ -0,0 +1,188 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +from CodeEntropy.entropy.water import WaterEntropy + + +def _make_fake_universe_with_water(): + 
universe = MagicMock() + + water_selection = MagicMock() + + water_residues = MagicMock() + water_residues.resnames = ["WAT"] + water_residues.__len__.return_value = 1 + + water_selection.residues = water_residues + water_selection.atoms = [1, 2, 3] + + universe.select_atoms.return_value = water_selection + return universe + + +def test_water_entropy_calls_solver_and_logs_components(): + args = SimpleNamespace(temperature=298.0) + reporter = MagicMock() + + Sorient_dict = {1: {"WAT": [3.0, 5]}} + + covariances = SimpleNamespace(counts={(7, "WAT"): 2}) + + vibrations = SimpleNamespace( + translational_S={(7, "WAT"): [1.0, 2.0]}, + rotational_S={(7, "WAT"): 4.0}, + ) + + solver = MagicMock(return_value=(Sorient_dict, covariances, vibrations, None, 123)) + + we = WaterEntropy(args=args, reporter=reporter, solver=solver) + + we._solute_id_to_resname = MagicMock(return_value="SOL") + + universe = _make_fake_universe_with_water() + + we.calculate_and_log(universe=universe, start=0, end=10, step=1, group_id=9) + + solver.assert_called_once() + + reporter.add_residue_data.assert_any_call( + 9, "WAT", "Water", "Orientational", 5, 3.0 + ) + + reporter.add_residue_data.assert_any_call( + 9, "SOL", "Water", "Transvibrational", 2, 3.0 + ) + + reporter.add_residue_data.assert_any_call( + 9, "SOL", "Water", "Rovibrational", 2, 4.0 + ) + + reporter.add_group_label.assert_called_once() + + +def test_water_entropy_handles_empty_solver_results_gracefully(): + args = SimpleNamespace(temperature=298.0) + reporter = MagicMock() + + solver = MagicMock( + return_value=( + {}, + SimpleNamespace(counts={}), + SimpleNamespace(translational_S={}, rotational_S={}), + None, + 0, + ) + ) + + we = WaterEntropy(args=args, reporter=reporter, solver=solver) + we._solute_id_to_resname = MagicMock(return_value="SOL") + + universe = _make_fake_universe_with_water() + + we.calculate_and_log(universe=universe, start=0, end=10, step=1, group_id=1) + + reporter.add_residue_data.assert_not_called() + 
reporter.add_group_label.assert_called_once() + + +def test_water_group_label_handles_multiple_water_resnames(): + args = SimpleNamespace(temperature=298.0) + reporter = MagicMock() + + solver = MagicMock( + return_value=( + {}, + SimpleNamespace(counts={}), + SimpleNamespace(translational_S={}, rotational_S={}), + None, + 0, + ) + ) + we = WaterEntropy(args=args, reporter=reporter, solver=solver) + we._solute_id_to_resname = MagicMock(return_value="SOL") + + universe = MagicMock() + water_selection = MagicMock() + + water_residues = MagicMock() + water_residues.resnames = ["WAT", "TIP3"] + water_residues.__len__.return_value = 2 + + water_selection.residues = water_residues + water_selection.atoms = [1, 2, 3, 4] + + universe.select_atoms.return_value = water_selection + + we.calculate_and_log(universe=universe, start=0, end=10, step=1, group_id=1) + + reporter.add_group_label.assert_called_once() + + +def test_log_group_label_defaults_to_WAT_when_no_residue_names_match(): + args = SimpleNamespace(temperature=298.0) + reporter = MagicMock() + + solver = MagicMock( + return_value=( + {}, + SimpleNamespace(counts={}), + SimpleNamespace(translational_S={}, rotational_S={}), + None, + 0, + ) + ) + we = WaterEntropy(args=args, reporter=reporter, solver=solver) + + universe = MagicMock() + + water_selection = MagicMock() + water_residues = MagicMock() + water_residues.resnames = ["WAT"] + water_residues.__len__.return_value = 1 + + water_selection.residues = water_residues + water_selection.atoms = [1, 2, 3] + universe.select_atoms.return_value = water_selection + + Sorient_dict = {1: {"TIP3": [1.0, 1]}} + + we._log_group_label(universe, Sorient_dict, group_id=7) + + reporter.add_group_label.assert_called_once() + _, residue_group, *_ = reporter.add_group_label.call_args.args + assert residue_group == "WAT" + + +def test_log_group_label_defaults_to_WAT_when_no_names_match(): + args = SimpleNamespace(temperature=298.0) + reporter = MagicMock() + we = 
WaterEntropy(args=args, reporter=reporter, solver=MagicMock()) + + universe = MagicMock() + water_selection = MagicMock() + + residues = MagicMock() + residues.resnames = ["WAT"] + residues.__len__.return_value = 2 + + water_selection.residues = residues + water_selection.atoms = [1, 2, 3, 4] + universe.select_atoms.return_value = water_selection + + Sorient_dict = {1: {"TIP3": [1.0, 2]}} + + we._log_group_label(universe, Sorient_dict, group_id=7) + + reporter.add_group_label.assert_called_once() + + assert reporter.add_group_label.call_args.args[1] == "WAT" + + +def test_solute_id_to_resname_strips_suffix_after_last_underscore(): + assert WaterEntropy._solute_id_to_resname("ALA_0") == "ALA" + assert WaterEntropy._solute_id_to_resname("ALA_BLA_12") == "ALA_BLA" + + +def test_solute_id_to_resname_returns_string_when_no_underscore(): + assert WaterEntropy._solute_id_to_resname("WAT") == "WAT" + assert WaterEntropy._solute_id_to_resname(123) == "123" diff --git a/tests/unit/CodeEntropy/entropy/test_workflow.py b/tests/unit/CodeEntropy/entropy/test_workflow.py new file mode 100644 index 00000000..29b93cc6 --- /dev/null +++ b/tests/unit/CodeEntropy/entropy/test_workflow.py @@ -0,0 +1,426 @@ +import logging +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from CodeEntropy.entropy.workflow import EntropyWorkflow + + +def _make_wf(args): + return EntropyWorkflow( + run_manager=MagicMock(), + args=args, + universe=MagicMock(), + reporter=MagicMock(molecule_data=[], residue_data=[]), + group_molecules=MagicMock(), + dihedral_analysis=MagicMock(), + universe_operations=MagicMock(), + ) + + +def test_execute_calls_level_dag_and_entropy_graph_and_logs_tables(): + args = SimpleNamespace( + start=0, + end=-1, + step=1, + grouping="molecules", + water_entropy=False, + selection_string="all", + ) + + universe = MagicMock() + universe.trajectory = list(range(5)) + + wf = EntropyWorkflow( + run_manager=MagicMock(), + args=args, + universe=universe, + 
reporter=MagicMock(), + group_molecules=MagicMock(), + dihedral_analysis=MagicMock(), + universe_operations=MagicMock(), + ) + + wf._build_reduced_universe = MagicMock(return_value=MagicMock()) + wf._detect_levels = MagicMock(return_value={0: ["united_atom"]}) + wf._split_water_groups = MagicMock(return_value=({0: [0]}, {})) + wf._finalize_molecule_results = MagicMock() + + wf._group_molecules.grouping_molecules.return_value = {0: [0]} + + with ( + patch("CodeEntropy.entropy.workflow.LevelDAG") as LevelDAGCls, + patch("CodeEntropy.entropy.workflow.EntropyGraph") as GraphCls, + ): + LevelDAGCls.return_value.build.return_value.execute.return_value = None + GraphCls.return_value.build.return_value.execute.return_value = {"x": 1} + + wf.execute() + + wf._reporter.log_tables.assert_called_once() + + +def test_execute_water_entropy_branch_calls_water_entropy_solver(): + args = SimpleNamespace( + start=0, + end=-1, + step=1, + grouping="molecules", + water_entropy=True, + selection_string="all", + output_file="out.json", + ) + + universe = MagicMock() + universe.trajectory = list(range(5)) + + reporter = MagicMock() + reporter.molecule_data = [] + reporter.residue_data = [] + reporter.save_dataframes_as_json = MagicMock() + + wf = EntropyWorkflow( + run_manager=MagicMock(), + args=args, + universe=universe, + reporter=reporter, + group_molecules=MagicMock(), + dihedral_analysis=MagicMock(), + universe_operations=MagicMock(), + ) + + wf._build_reduced_universe = MagicMock(return_value=MagicMock()) + wf._detect_levels = MagicMock(return_value={0: ["united_atom"]}) + + wf._split_water_groups = MagicMock(return_value=({0: [0]}, {9: [1, 2]})) + wf._finalize_molecule_results = MagicMock() + + wf._group_molecules.grouping_molecules.return_value = {0: [0], 9: [1, 2]} + + with ( + patch("CodeEntropy.entropy.workflow.WaterEntropy") as WaterCls, + patch("CodeEntropy.entropy.workflow.LevelDAG") as LevelDAGCls, + patch("CodeEntropy.entropy.workflow.EntropyGraph") as GraphCls, + ): + 
water_instance = WaterCls.return_value + water_instance._calculate_water_entropy = MagicMock() + + LevelDAGCls.return_value.build.return_value.execute.return_value = None + GraphCls.return_value.build.return_value.execute.return_value = {} + + wf.execute() + + water_instance._calculate_water_entropy.assert_called_once() + _, kwargs = water_instance._calculate_water_entropy.call_args + assert kwargs["universe"] is universe + assert kwargs["start"] == 0 + assert kwargs["end"] == 5 + assert kwargs["step"] == 1 + assert kwargs["group_id"] == 9 + + +def test_get_trajectory_bounds_end_minus_one_uses_trajectory_length(): + args = SimpleNamespace( + start=0, + end=-1, + step=2, + grouping="molecules", + water_entropy=False, + selection_string="all", + ) + universe = SimpleNamespace(trajectory=list(range(10))) + + wf = EntropyWorkflow( + run_manager=MagicMock(), + args=args, + universe=universe, + reporter=MagicMock(), + group_molecules=MagicMock(), + dihedral_analysis=MagicMock(), + universe_operations=MagicMock(), + ) + + start, end, step = wf._get_trajectory_bounds() + assert (start, end, step) == (0, 10, 2) + + +def test_get_number_frames_matches_python_slice_math(): + wf = EntropyWorkflow( + run_manager=MagicMock(), + args=MagicMock(), + universe=MagicMock(), + reporter=MagicMock(), + group_molecules=MagicMock(), + dihedral_analysis=MagicMock(), + universe_operations=MagicMock(), + ) + assert wf._get_number_frames(0, 10, 1) == 10 + assert wf._get_number_frames(0, 10, 2) == 5 + + +def test_finalize_results_called_even_if_empty(): + args = SimpleNamespace(output_file="out.json") + reporter = MagicMock() + reporter.molecule_data = [] + reporter.residue_data = [] + reporter.save_dataframes_as_json = MagicMock() + + wf = EntropyWorkflow( + run_manager=MagicMock(), + args=args, + universe=MagicMock(), + reporter=reporter, + group_molecules=MagicMock(), + dihedral_analysis=MagicMock(), + universe_operations=MagicMock(), + ) + + wf._finalize_molecule_results() + + 
reporter.save_dataframes_as_json.assert_called_once() + + +def test_split_water_groups_returns_empty_when_none(): + wf = EntropyWorkflow( + run_manager=MagicMock(), + args=MagicMock(water_entropy=False), + universe=MagicMock(), + reporter=MagicMock(), + group_molecules=MagicMock(), + dihedral_analysis=MagicMock(), + universe_operations=MagicMock(), + ) + + groups, water = wf._split_water_groups({0: [1, 2]}) + + assert water == {} + + +def test_build_reduced_universe_non_all_selects_and_writes_universe(): + args = SimpleNamespace( + selection_string="protein", + grouping="molecules", + start=0, + end=-1, + step=1, + water_entropy=False, + output_file="out.json", + ) + universe = MagicMock() + universe.trajectory = list(range(3)) + + reduced = MagicMock() + reduced.trajectory = list(range(2)) + + uops = MagicMock() + uops.select_atoms.return_value = reduced + + run_manager = MagicMock() + reporter = MagicMock() + + wf = EntropyWorkflow( + run_manager=run_manager, + args=args, + universe=universe, + reporter=reporter, + group_molecules=MagicMock(), + dihedral_analysis=MagicMock(), + universe_operations=uops, + ) + + out = wf._build_reduced_universe() + + assert out is reduced + uops.select_atoms.assert_called_once_with(universe, "protein") + run_manager.write_universe.assert_called_once() + + +def test_compute_water_entropy_updates_selection_string_and_calls_internal_method(): + args = SimpleNamespace( + selection_string="all", water_entropy=True, temperature=298.0 + ) + wf = EntropyWorkflow( + run_manager=MagicMock(), + args=args, + universe=MagicMock(), + reporter=MagicMock(molecule_data=[], residue_data=[]), + group_molecules=MagicMock(), + dihedral_analysis=MagicMock(), + universe_operations=MagicMock(), + ) + + traj = SimpleNamespace(start=0, end=5, step=1) + water_groups = {9: [1, 2]} + + with patch("CodeEntropy.entropy.workflow.WaterEntropy") as WaterCls: + inst = WaterCls.return_value + inst._calculate_water_entropy = MagicMock() + + 
wf._compute_water_entropy(traj, water_groups) + + inst._calculate_water_entropy.assert_called_once() + assert wf._args.selection_string == "not water" + + +def test_finalize_molecule_results_skips_invalid_entries_with_warning(caplog): + args = SimpleNamespace(output_file="out.json") + reporter = MagicMock() + + reporter.molecule_data = [(1, "united_atom", "Trans", "not-a-number")] + reporter.residue_data = [] + reporter.save_dataframes_as_json = MagicMock() + + wf = EntropyWorkflow( + run_manager=MagicMock(), + args=args, + universe=MagicMock(), + reporter=reporter, + group_molecules=MagicMock(), + dihedral_analysis=MagicMock(), + universe_operations=MagicMock(), + ) + + caplog.set_level(logging.WARNING) + wf._finalize_molecule_results() + + assert any("Skipping invalid entry" in r.message for r in caplog.records) + reporter.save_dataframes_as_json.assert_called_once() + + +def test_build_reduced_universe_all_returns_original_universe(): + args = SimpleNamespace( + selection_string="all", + start=0, + end=-1, + step=1, + grouping="molecules", + water_entropy=False, + output_file="out.json", + ) + universe = MagicMock() + uops = MagicMock() + run_manager = MagicMock() + wf = EntropyWorkflow( + run_manager, args, universe, MagicMock(), MagicMock(), MagicMock(), uops + ) + + out = wf._build_reduced_universe() + + assert out is universe + uops.select_atoms.assert_not_called() + run_manager.write_universe.assert_not_called() + + +def test_split_water_groups_partitions_correctly(): + args = SimpleNamespace( + start=0, + end=-1, + step=1, + grouping="molecules", + water_entropy=False, + selection_string="all", + output_file="out.json", + ) + universe = MagicMock() + + water_res = MagicMock() + water_res.resid = 10 + water_atoms = MagicMock() + water_atoms.residues = [water_res] + universe.select_atoms.return_value = water_atoms + + frag0 = MagicMock() + r0 = MagicMock() + r0.resid = 10 + frag0.residues = [r0] + + frag1 = MagicMock() + r1 = MagicMock() + r1.resid = 99 + 
frag1.residues = [r1] + + universe.atoms.fragments = [frag0, frag1] + + wf = EntropyWorkflow( + MagicMock(), args, universe, MagicMock(), MagicMock(), MagicMock(), MagicMock() + ) + + groups = {0: [0], 1: [1]} + nonwater, water = wf._split_water_groups(groups) + + assert 0 in water + assert 1 in nonwater + + +def test_compute_water_entropy_instantiates_waterentropy_and_updates_selection_string(): + args = SimpleNamespace( + selection_string="all", water_entropy=True, temperature=298.0 + ) + universe = MagicMock() + reporter = MagicMock() + wf = EntropyWorkflow( + MagicMock(), args, universe, reporter, MagicMock(), MagicMock(), MagicMock() + ) + + traj = SimpleNamespace(start=0, end=5, step=1, n_frames=5) + water_groups = {9: [0]} + + with patch("CodeEntropy.entropy.workflow.WaterEntropy") as WaterCls: + inst = WaterCls.return_value + inst._calculate_water_entropy = MagicMock() + + wf._compute_water_entropy(traj, water_groups) + + WaterCls.assert_called_once_with(args) + inst._calculate_water_entropy.assert_called_once() + assert wf._args.selection_string == "not water" + + +def test_detect_levels_calls_hierarchy_builder(): + args = SimpleNamespace( + selection_string="all", water_entropy=False, output_file="out.json" + ) + wf = _make_wf(args) + + with patch("CodeEntropy.entropy.workflow.HierarchyBuilder") as HB: + HB.return_value.select_levels.return_value = (123, {"levels": "ok"}) + + out = wf._detect_levels(reduced_universe=MagicMock()) + + assert out == {"levels": "ok"} + HB.return_value.select_levels.assert_called_once() + + +def test_compute_water_entropy_returns_early_when_disabled_or_empty_groups(): + args = SimpleNamespace( + selection_string="all", + water_entropy=False, + temperature=298.0, + output_file="out.json", + ) + wf = _make_wf(args) + + traj = SimpleNamespace(start=0, end=5, step=1, n_frames=5) + + # empty water groups OR water_entropy disabled -> early return + wf._compute_water_entropy(traj, water_groups={}) + # no exception and no side effects 
expected + + +def test_finalize_molecule_results_skips_group_total_rows(): + args = SimpleNamespace( + output_file="out.json", selection_string="all", water_entropy=False + ) + wf = _make_wf(args) + + wf._reporter.molecule_data = [ + (1, "Group Total", "Group Total Entropy", 999.0), # should be skipped + (1, "united_atom", "Transvibrational", 1.5), # should count + ] + wf._reporter.residue_data = [] + + wf._finalize_molecule_results() + + # should append a new "Group Total" row based only on the non-total entries + assert any( + row[1] == "Group Total" and row[3] == 1.5 for row in wf._reporter.molecule_data + ) diff --git a/tests/unit/CodeEntropy/levels/conftest.py b/tests/unit/CodeEntropy/levels/conftest.py new file mode 100644 index 00000000..11a2c6d5 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/conftest.py @@ -0,0 +1,58 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import numpy as np +import pytest + + +@pytest.fixture +def args(): + # minimal args object used by nodes + return SimpleNamespace(grouping="each") + + +@pytest.fixture +def reduced_universe(): + """ + Minimal Universe-like object: + - .atoms.fragments exists and is list-like + """ + u = MagicMock() + u.atoms = MagicMock() + u.atoms.fragments = [] + return u + + +@pytest.fixture +def universe_with_fragments(): + """ + Universe with 3 fragments. + Each fragment can be customized by tests. + """ + u = MagicMock() + u.atoms = MagicMock() + u.atoms.fragments = [MagicMock(), MagicMock(), MagicMock()] + return u + + +@pytest.fixture +def simple_ts_list(): + # list supports slicing directly: lst[start:end:step] + return [SimpleNamespace(frame=i) for i in range(10)] + + +@pytest.fixture +def axes_manager_identity(): + """ + AxesCalculator-like adapter used by ForceTorqueCalculator for displacements. + Returns positions-center (no PBC). 
+ """ + mgr = MagicMock() + + def _get_vector(center, positions, box): + center = np.asarray(center, dtype=float).reshape(3) + positions = np.asarray(positions, dtype=float) + return positions - center + + mgr.get_vector.side_effect = _get_vector + return mgr diff --git a/tests/unit/CodeEntropy/levels/nodes/test_build_beads_node.py b/tests/unit/CodeEntropy/levels/nodes/test_build_beads_node.py new file mode 100644 index 00000000..a48e83e1 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/nodes/test_build_beads_node.py @@ -0,0 +1,190 @@ +from unittest.mock import MagicMock + +import numpy as np + +from CodeEntropy.levels.nodes.beads import BuildBeadsNode + + +def _bead(indices, heavy_resindex=None, empty=False): + b = MagicMock() + b.__len__.return_value = 0 if empty else len(indices) + b.indices = np.asarray(indices, dtype=int) + + heavy = MagicMock() + if heavy_resindex is None: + heavy.__len__.return_value = 0 + heavy.__iter__.return_value = iter([]) + else: + a0 = MagicMock() + a0.resindex = int(heavy_resindex) + heavy.__len__.return_value = 1 + heavy.__getitem__.side_effect = lambda i: a0 + heavy.__iter__.return_value = iter([a0]) + + b.select_atoms.return_value = heavy + return b + + +def test_build_beads_node_groups_united_atom_beads_into_local_residue_buckets(): + r0 = MagicMock() + r0.resindex = 10 + r1 = MagicMock() + r1.resindex = 11 + + mol = MagicMock() + mol.residues = [r0, r1] + + ua0 = _bead([1, 2], heavy_resindex=10) + ua1 = _bead([3], heavy_resindex=11) + ua_empty = _bead([], heavy_resindex=10, empty=True) + + hier = MagicMock() + hier.get_beads.side_effect = lambda m, lvl: ( + [ua0, ua1, ua_empty] if lvl == "united_atom" else [] + ) + + node = BuildBeadsNode(hierarchy=hier) + + u = MagicMock() + u.atoms = MagicMock() + u.atoms.fragments = [mol] + + shared = {"reduced_universe": u, "levels": [["united_atom"]]} + + out = node.run(shared) + beads = out["beads"] + + assert (0, "united_atom", 0) in beads + assert (0, "united_atom", 1) in beads + assert 
len(beads[(0, "united_atom", 0)]) == 1 + assert len(beads[(0, "united_atom", 1)]) == 1 + + np.testing.assert_array_equal(beads[(0, "united_atom", 0)][0], np.array([1, 2])) + np.testing.assert_array_equal(beads[(0, "united_atom", 1)][0], np.array([3])) + + +def test_add_residue_beads_logs_error_if_none_kept(caplog): + hier = MagicMock() + # returns one empty bead -> skipped -> kept stays 0 + empty_bead = MagicMock() + empty_bead.__len__.return_value = 0 + hier.get_beads.return_value = [empty_bead] + + node = BuildBeadsNode(hierarchy=hier) + + beads = {} + mol = MagicMock() + mol.residues = [MagicMock()] + + node._add_residue_beads(beads=beads, mol_id=0, mol=mol) + + assert (0, "residue") in beads + assert beads[(0, "residue")] == [] + assert any("No residue beads kept" in rec.message for rec in caplog.records) + + +def test_infer_local_residue_id_returns_zero_if_no_heavy_atoms(): + mol = MagicMock() + mol.residues = [MagicMock(resindex=10), MagicMock(resindex=11)] + + bead = MagicMock() + heavy = MagicMock() + heavy.__len__.return_value = 0 + bead.select_atoms.return_value = heavy + + out = BuildBeadsNode._infer_local_residue_id(mol=mol, bead=bead) + assert out == 0 + + +def test_infer_local_residue_id_returns_zero_if_resindex_not_found(): + mol = MagicMock() + mol.residues = [MagicMock(resindex=10), MagicMock(resindex=11)] + + bead = MagicMock() + heavy = MagicMock() + heavy.__len__.return_value = 1 + heavy0 = MagicMock(resindex=999) + heavy.__getitem__.return_value = heavy0 + bead.select_atoms.return_value = heavy + + out = BuildBeadsNode._infer_local_residue_id(mol=mol, bead=bead) + assert out == 0 + + +def test_build_beads_node_skips_when_no_levels(): + """ + Covers: early return when levels missing (92/95 style guard branches) + """ + node = BuildBeadsNode(hierarchy=MagicMock()) + out = node.run({"reduced_universe": MagicMock(), "levels": []}) + assert out["beads"] == {} + + +def test_build_beads_node_residue_level_adds_residue_beads(): + """ + Covers: residue 
path + _add_residue_beads bookkeeping (log around 145 and 166-177) + """ + r0 = MagicMock(resindex=10) + r1 = MagicMock(resindex=11) + + mol = MagicMock() + mol.residues = [r0, r1] + + res0 = _bead([100, 101], heavy_resindex=10) + res1 = _bead([200], heavy_resindex=11) + + ua0 = _bead([1, 2], heavy_resindex=10) + ua1 = _bead([3], heavy_resindex=11) + + hier = MagicMock() + + def _get_beads(m, lvl): + if lvl == "residue": + return [res0, res1] + if lvl == "united_atom": + return [ua0, ua1] + return [] + + hier.get_beads.side_effect = _get_beads + + node = BuildBeadsNode(hierarchy=hier) + + u = MagicMock() + u.atoms = MagicMock() + u.atoms.fragments = [mol] + + shared = {"reduced_universe": u, "levels": [["residue", "united_atom"]]} + out = node.run(shared) + + beads = out["beads"] + + assert (0, "residue") in beads + assert len(beads[(0, "residue")]) == 2 + assert np.array_equal(beads[(0, "residue")][0], np.array([100, 101])) + assert np.array_equal(beads[(0, "residue")][1], np.array([200])) + + +def test_build_beads_node_polymer_level_adds_polymer_beads_and_skips_empty(): + mol0 = MagicMock() + mol0.residues = [MagicMock(resindex=10)] + + u = MagicMock() + u.atoms.fragments = [mol0] + + polymer_beads = [_bead([]), _bead([7, 8, 9])] + + hier = MagicMock() + hier.get_beads.side_effect = lambda m, lvl: ( + polymer_beads if lvl == "polymer" else [] + ) + + node = BuildBeadsNode(hierarchy=hier) + + shared = {"reduced_universe": u, "levels": [["polymer"]]} + out = node.run(shared) + + beads = out["beads"] + assert (0, "polymer") in beads + + assert len(beads[(0, "polymer")]) == 1 + np.testing.assert_array_equal(beads[(0, "polymer")][0], np.array([7, 8, 9])) diff --git a/tests/unit/CodeEntropy/levels/nodes/test_conformations_node.py b/tests/unit/CodeEntropy/levels/nodes/test_conformations_node.py new file mode 100644 index 00000000..a03fb555 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/nodes/test_conformations_node.py @@ -0,0 +1,30 @@ +from types import 
SimpleNamespace +from unittest.mock import MagicMock + +from CodeEntropy.levels.nodes.conformations import ComputeConformationalStatesNode + + +def test_compute_conformational_states_node_runs_and_writes_shared_data(): + uops = MagicMock() + node = ComputeConformationalStatesNode(universe_operations=uops) + + node._dihedral_analysis.build_conformational_states = MagicMock( + return_value=({"ua_key": ["0", "1"]}, [["00", "01"]]) + ) + + shared = { + "reduced_universe": MagicMock(), + "levels": {0: ["united_atom"]}, + "groups": {0: [0]}, + "start": 0, + "end": 10, + "step": 1, + "args": SimpleNamespace(bin_width=10), + } + + out = node.run(shared) + + assert "conformational_states" in out + assert shared["conformational_states"]["ua"] == {"ua_key": ["0", "1"]} + assert shared["conformational_states"]["res"] == [["00", "01"]] + node._dihedral_analysis.build_conformational_states.assert_called_once() diff --git a/tests/unit/CodeEntropy/levels/nodes/test_detect_levels_node.py b/tests/unit/CodeEntropy/levels/nodes/test_detect_levels_node.py new file mode 100644 index 00000000..e9500810 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/nodes/test_detect_levels_node.py @@ -0,0 +1,19 @@ +from unittest.mock import patch + +from CodeEntropy.levels.nodes.detect_levels import DetectLevelsNode + + +def test_detect_levels_node_stores_results(reduced_universe): + node = DetectLevelsNode() + shared = {"reduced_universe": reduced_universe} + + with patch.object( + node._hierarchy, + "select_levels", + return_value=(2, [["united_atom"], ["united_atom", "residue"]]), + ): + out = node.run(shared) + + assert shared["number_molecules"] == 2 + assert shared["levels"] == [["united_atom"], ["united_atom", "residue"]] + assert out["levels"] == shared["levels"] diff --git a/tests/unit/CodeEntropy/levels/nodes/test_detect_molecules_node.py b/tests/unit/CodeEntropy/levels/nodes/test_detect_molecules_node.py new file mode 100644 index 00000000..249db1e4 --- /dev/null +++ 
b/tests/unit/CodeEntropy/levels/nodes/test_detect_molecules_node.py @@ -0,0 +1,45 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from CodeEntropy.levels.nodes.detect_molecules import DetectMoleculesNode + + +def test_run_sets_reduced_universe_when_missing(args, universe_with_fragments): + node = DetectMoleculesNode() + + shared = { + "universe": universe_with_fragments, + "args": args, + } + + with patch.object(node._grouping, "grouping_molecules", return_value={0: [1]}): + out = node.run(shared) + + assert shared["reduced_universe"] is universe_with_fragments + assert shared["groups"] == {0: [1]} + assert shared["number_molecules"] == len(universe_with_fragments.atoms.fragments) + assert out["number_molecules"] == shared["number_molecules"] + + +def test_run_uses_args_grouping_strategy(universe_with_fragments): + node = DetectMoleculesNode() + shared = { + "universe": universe_with_fragments, + "args": SimpleNamespace(grouping="molecules"), + } + + with patch.object( + node._grouping, "grouping_molecules", return_value={"g": [1]} + ) as gm: + node.run(shared) + + gm.assert_called_once() + assert gm.call_args[0][1] == "molecules" + + +def test_ensure_reduced_universe_raises_if_missing_universe(): + node = DetectMoleculesNode() + with pytest.raises(KeyError): + node._ensure_reduced_universe({}) diff --git a/tests/unit/CodeEntropy/levels/nodes/test_frame_covariance_node.py b/tests/unit/CodeEntropy/levels/nodes/test_frame_covariance_node.py new file mode 100644 index 00000000..1b51006b --- /dev/null +++ b/tests/unit/CodeEntropy/levels/nodes/test_frame_covariance_node.py @@ -0,0 +1,634 @@ +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from CodeEntropy.levels.nodes import covariance as covmod +from CodeEntropy.levels.nodes.covariance import FrameCovarianceNode + + +class _BeadGroup: + def __init__(self, n=1): + self._n = n + + def __len__(self): + return self._n + + def center_of_mass(self, 
unwrap=False): + return np.array([0.0, 0.0, 0.0], dtype=float) + + +class _EmptyGroup: + def __len__(self): + return 0 + + +def _mk_atomgroup(n=1): + g = MagicMock() + g.__len__.return_value = n + return g + + +def test_get_shared_missing_raises_keyerror(): + node = FrameCovarianceNode() + with pytest.raises(KeyError): + node._get_shared({}) + + +def test_try_get_box_returns_none_on_failure(): + node = FrameCovarianceNode() + u = MagicMock() + type(u).dimensions = property(lambda self: (_ for _ in ()).throw(RuntimeError("x"))) + assert node._try_get_box(u) is None + + +def test_inc_mean_first_sample_copies(): + node = FrameCovarianceNode() + new = np.eye(2) + out = node._inc_mean(None, new, n=1) + np.testing.assert_allclose(out, new) + new[0, 0] = 999.0 + assert out[0, 0] != 999.0 + + +def test_inc_mean_updates_streaming_average(): + node = FrameCovarianceNode() + old = np.array([[2.0, 2.0], [2.0, 2.0]]) + new = np.array([[4.0, 0.0], [0.0, 4.0]]) + out = node._inc_mean(old, new, n=2) + np.testing.assert_allclose(out, np.array([[3.0, 1.0], [1.0, 3.0]])) + + +def test_build_ft_block_rejects_mismatched_lengths(): + node = FrameCovarianceNode() + with pytest.raises(ValueError): + node._build_ft_block([np.zeros(3)], [np.zeros(3), np.zeros(3)]) + + +def test_build_ft_block_rejects_empty(): + node = FrameCovarianceNode() + with pytest.raises(ValueError): + node._build_ft_block([], []) + + +def test_build_ft_block_rejects_non_length3_vectors(): + node = FrameCovarianceNode() + with pytest.raises(ValueError): + node._build_ft_block([np.zeros(2)], [np.zeros(3)]) + + +def test_build_ft_block_returns_symmetric_block_matrix(): + node = FrameCovarianceNode() + + force_vecs = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 2.0, 0.0])] + torque_vecs = [np.array([0.0, 0.0, 3.0]), np.array([4.0, 0.0, 0.0])] + + M = node._build_ft_block(force_vecs, torque_vecs) + assert M.shape == (12, 12) + + np.testing.assert_allclose(M, M.T) + + +def 
test_process_residue_skips_when_no_beads_key_present(): + node = FrameCovarianceNode() + + shared = { + "reduced_universe": MagicMock(), + "groups": {0: [0]}, + "levels": [["residue"]], + "beads": {}, + "args": MagicMock( + force_partitioning=1.0, combined_forcetorque=False, customised_axes=False + ), + "axes_manager": MagicMock(), + } + ctx = {"shared": shared} + + out = node.run(ctx) + assert out["force"]["res"] == {} + assert out["torque"]["res"] == {} + assert "forcetorque" not in out + + +def test_process_residue_combined_only_when_highest_level(): + node = FrameCovarianceNode() + + u = MagicMock() + u.atoms = MagicMock() + frag = MagicMock() + frag.residues = [MagicMock()] + u.atoms.fragments = [frag] + u.atoms.__getitem__.side_effect = lambda idx: _mk_atomgroup(1) + u.dimensions = np.array([10.0, 10.0, 10.0, 90.0, 90.0, 90.0]) + + args = MagicMock() + args.force_partitioning = 1.0 + args.combined_forcetorque = True + args.customised_axes = True + + axes_manager = MagicMock() + axes_manager.get_residue_axes.return_value = ( + np.eye(3), + np.eye(3), + np.zeros(3), + np.array([1.0, 1.0, 1.0]), + ) + + shared = { + "reduced_universe": u, + "groups": {7: [0]}, + "levels": [["residue"]], + "beads": {(0, "residue"): [np.array([1, 2, 3])]}, + "args": args, + "axes_manager": axes_manager, + } + + with ( + patch.object( + node._ft, "get_weighted_forces", return_value=np.array([1.0, 0.0, 0.0]) + ), + patch.object( + node._ft, "get_weighted_torques", return_value=np.array([0.0, 1.0, 0.0]) + ), + patch.object( + node._ft, + "compute_frame_covariance", + return_value=(np.eye(3), 2.0 * np.eye(3)), + ), + ): + ctx = {"shared": shared} + out = node.run(ctx) + + assert "forcetorque" in out + assert 7 in out["force"]["res"] + assert 7 in out["torque"]["res"] + assert 7 in out["forcetorque"]["res"] + + +def test_process_residue_combined_not_added_if_not_highest_level(): + node = FrameCovarianceNode() + + u = MagicMock() + u.atoms = MagicMock() + frag = MagicMock() + 
frag.residues = [MagicMock()] + u.atoms.fragments = [frag] + u.atoms.__getitem__.side_effect = lambda idx: _mk_atomgroup(1) + u.dimensions = np.array([10.0, 10.0, 10.0, 90.0, 90.0, 90.0]) + + args = MagicMock( + force_partitioning=1.0, combined_forcetorque=True, customised_axes=True + ) + + axes_manager = MagicMock() + axes_manager.get_residue_axes.return_value = ( + np.eye(3), + np.eye(3), + np.zeros(3), + np.ones(3), + ) + + shared = { + "reduced_universe": u, + "groups": {7: [0]}, + "levels": [["united_atom", "residue", "polymer"]], + "beads": {(0, "residue"): [np.array([1, 2, 3])]}, + "args": args, + "axes_manager": axes_manager, + } + + with ( + patch.object( + node._ft, "get_weighted_forces", return_value=np.array([1.0, 0.0, 0.0]) + ), + patch.object( + node._ft, "get_weighted_torques", return_value=np.array([0.0, 1.0, 0.0]) + ), + patch.object( + node._ft, + "compute_frame_covariance", + return_value=(np.eye(3), 2.0 * np.eye(3)), + ), + ): + out = node.run({"shared": shared}) + + assert "forcetorque" in out + assert out["forcetorque"]["res"] == {} + + +def test_process_united_atom_returns_when_no_beads_for_level(): + node = FrameCovarianceNode() + + res = MagicMock() + res.atoms = MagicMock() + mol = MagicMock() + mol.residues = [res] + + axes_manager = MagicMock() + + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + molcount = {} + + node._process_united_atom( + u=MagicMock(), + mol=mol, + mol_id=0, + group_id=0, + beads={}, + axes_manager=axes_manager, + box=np.array([10.0, 10.0, 10.0], dtype=float), + force_partitioning=1.0, + customised_axes=False, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + molcount=molcount, + ) + + assert out_force["ua"] == {} + assert out_torque["ua"] == {} + assert molcount == {} + axes_manager.get_UA_axes.assert_not_called() + axes_manager.get_vanilla_axes.assert_not_called() + + +def test_get_residue_axes_vanilla_branch_returns_arrays(monkeypatch): + node = 
FrameCovarianceNode() + + monkeypatch.setattr( + "CodeEntropy.levels.nodes.covariance.make_whole", lambda _ag: None + ) + + mol = MagicMock() + mol.atoms.principal_axes.return_value = np.eye(3) * 2 + + bead = MagicMock() + bead.center_of_mass.return_value = np.array([1.0, 2.0, 3.0]) + + axes_manager = MagicMock() + axes_manager.get_vanilla_axes.return_value = (np.eye(3), np.array([9.0, 8.0, 7.0])) + + trans, rot, center, moi = node._get_residue_axes( + mol=mol, + bead=bead, + local_res_i=0, + axes_manager=axes_manager, + customised_axes=False, + ) + + assert trans.shape == (3, 3) + assert rot.shape == (3, 3) + assert center.shape == (3,) + assert moi.shape == (3,) + assert np.allclose(trans, np.eye(3) * 2) + assert np.allclose(rot, np.eye(3)) + assert np.allclose(center, np.array([1.0, 2.0, 3.0])) + assert np.allclose(moi, np.array([9.0, 8.0, 7.0])) + + +def test_get_polymer_axes_returns_arrays(monkeypatch): + node = FrameCovarianceNode() + + monkeypatch.setattr( + "CodeEntropy.levels.nodes.covariance.make_whole", lambda _ag: None + ) + + mol = MagicMock() + mol.atoms.principal_axes.return_value = np.eye(3) * 3 + + bead = MagicMock() + bead.center_of_mass.return_value = np.array([0.0, 0.0, 0.0]) + + axes_manager = MagicMock() + axes_manager.get_vanilla_axes.return_value = (np.eye(3), np.array([1.0, 1.0, 1.0])) + + trans, rot, center, moi = node._get_polymer_axes( + mol=mol, + bead=bead, + axes_manager=axes_manager, + ) + + assert trans.shape == (3, 3) + assert rot.shape == (3, 3) + assert center.shape == (3,) + assert moi.shape == (3,) + assert np.allclose(trans, np.eye(3) * 3) + assert np.allclose(rot, np.eye(3)) + assert np.allclose(center, np.array([0.0, 0.0, 0.0])) + assert np.allclose(moi, np.array([1.0, 1.0, 1.0])) + + +def test_process_united_atom_updates_outputs_and_molcount(): + node = FrameCovarianceNode() + + node._build_ua_vectors = MagicMock( + return_value=( + [np.array([1.0, 0.0, 0.0])], + [np.array([0.0, 1.0, 0.0])], + ) + ) + + F = np.eye(3) + T = 
np.eye(3) * 2 + node._ft.compute_frame_covariance = MagicMock(return_value=(F, T)) + + u = MagicMock() + u.atoms = MagicMock() + u.atoms.__getitem__.side_effect = lambda idx: _BeadGroup(1) + + res = MagicMock() + res.atoms = MagicMock() + mol = MagicMock() + mol.residues = [res] + + beads = {(0, "united_atom", 0): [123]} + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + molcount = {} + + node._process_united_atom( + u=u, + mol=mol, + mol_id=0, + group_id=7, + beads=beads, + axes_manager=MagicMock(), + box=np.array([10.0, 10.0, 10.0]), + force_partitioning=1.0, + customised_axes=False, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + molcount=molcount, + ) + + key = (7, 0) + assert np.allclose(out_force["ua"][key], F) + assert np.allclose(out_torque["ua"][key], T) + assert molcount[key] == 1 + + +def test_process_residue_returns_early_when_no_beads(): + node = FrameCovarianceNode() + + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + + node._process_residue( + u=MagicMock(), + mol=MagicMock(), + mol_id=0, + group_id=0, + beads={}, + axes_manager=MagicMock(), + box=np.array([10.0, 10.0, 10.0]), + customised_axes=False, + force_partitioning=1.0, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + out_ft=None, + molcount={}, + combined=False, + ) + + assert out_force["res"] == {} + assert out_torque["res"] == {} + + +def test_build_ua_vectors_customised_axes_true_calls_get_UA_axes(): + node = FrameCovarianceNode() + + bead = _BeadGroup(1) + residue_atoms = MagicMock() + + axes_manager = MagicMock() + axes_manager.get_UA_axes.return_value = ( + np.eye(3), + np.eye(3), + np.array([0.0, 0.0, 0.0]), + np.array([1.0, 1.0, 1.0]), + ) + + node._ft.get_weighted_forces = MagicMock(return_value=np.array([1.0, 2.0, 3.0])) + node._ft.get_weighted_torques = MagicMock(return_value=np.array([4.0, 5.0, 6.0])) + + force_vecs, torque_vecs = 
node._build_ua_vectors( + bead_groups=[bead], + residue_atoms=residue_atoms, + axes_manager=axes_manager, + box=np.array([10.0, 10.0, 10.0]), + force_partitioning=1.0, + customised_axes=True, + is_highest=True, + ) + + axes_manager.get_UA_axes.assert_called_once() + assert len(force_vecs) == 1 and len(torque_vecs) == 1 + + +def test_build_ua_vectors_vanilla_path_uses_principal_axes_and_vanilla_axes( + monkeypatch, +): + node = FrameCovarianceNode() + + residue_atoms = MagicMock() + residue_atoms.principal_axes.return_value = np.eye(3) + + bead = _BeadGroup(1) + + axes_manager = MagicMock() + axes_manager.get_vanilla_axes.return_value = ( + np.eye(3) * 2, + np.array([9.0, 8.0, 7.0]), + ) + + monkeypatch.setattr(covmod, "make_whole", lambda *_: None) + + node._ft.get_weighted_forces = MagicMock(return_value=np.array([1.0, 0.0, 0.0])) + node._ft.get_weighted_torques = MagicMock(return_value=np.array([0.0, 1.0, 0.0])) + + force_vecs, torque_vecs = node._build_ua_vectors( + bead_groups=[bead], + residue_atoms=residue_atoms, + axes_manager=axes_manager, + box=np.array([10.0, 10.0, 10.0]), + force_partitioning=1.0, + customised_axes=False, + is_highest=True, + ) + + axes_manager.get_vanilla_axes.assert_called_once() + assert len(force_vecs) == 1 and len(torque_vecs) == 1 + + +def test_process_united_atom_skips_when_any_bead_group_is_empty(): + node = FrameCovarianceNode() + + res = MagicMock() + res.atoms = MagicMock() + mol = MagicMock() + mol.residues = [res] + + u = MagicMock() + u.atoms = MagicMock() + u.atoms.__getitem__.side_effect = lambda idx: _EmptyGroup() + + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + + node._process_united_atom( + u=u, + mol=mol, + mol_id=0, + group_id=0, + beads={(0, "united_atom", 0): [123]}, + axes_manager=MagicMock(), + box=np.array([10.0, 10.0, 10.0]), + force_partitioning=1.0, + customised_axes=False, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + molcount={}, + ) 
+ + assert out_force["ua"] == {} + assert out_torque["ua"] == {} + + +def test_process_residue_returns_early_when_any_bead_group_is_empty(): + node = FrameCovarianceNode() + + u = MagicMock() + u.atoms = MagicMock() + u.atoms.__getitem__.side_effect = lambda idx: _EmptyGroup() + + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + + node._process_residue( + u=u, + mol=MagicMock(), + mol_id=0, + group_id=0, + beads={(0, "residue"): [np.array([1, 2, 3])]}, + axes_manager=MagicMock(), + box=np.array([10.0, 10.0, 10.0]), + customised_axes=False, + force_partitioning=1.0, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + out_ft=None, + molcount={}, + combined=False, + ) + + assert out_force["res"] == {} + assert out_torque["res"] == {} + + +def test_process_polymer_skips_when_any_bead_group_is_empty(): + node = FrameCovarianceNode() + + u = MagicMock() + u.atoms = MagicMock() + u.atoms.__getitem__.side_effect = lambda idx: _EmptyGroup() + + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + out_ft = {"ua": {}, "res": {}, "poly": {}} + + node._process_polymer( + u=u, + mol=MagicMock(), + mol_id=0, + group_id=7, + beads={(0, "polymer"): [np.array([1, 2, 3])]}, + axes_manager=MagicMock(), + box=np.array([10.0, 10.0, 10.0]), + force_partitioning=1.0, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + out_ft=out_ft, + molcount={}, + combined=True, + ) + + assert out_force["poly"] == {} + assert out_torque["poly"] == {} + assert out_ft["poly"] == {} + + +def test_process_polymer_happy_path_updates_force_torque_and_optional_ft(): + node = FrameCovarianceNode() + + u = MagicMock() + u.atoms = MagicMock() + + bead_obj = _BeadGroup(1) + u.atoms.__getitem__.side_effect = lambda idx: bead_obj + + mol = MagicMock() + mol.atoms = MagicMock() + + axes_manager = MagicMock() + + f_vec = np.array([1.0, 0.0, 0.0], dtype=float) + t_vec = np.array([0.0, 1.0, 0.0], 
dtype=float) + + F = np.eye(3) + T = 2.0 * np.eye(3) + FT = np.eye(6) + + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + out_ft = {"ua": {}, "res": {}, "poly": {}} + molcount = {} + + with ( + patch.object( + node, + "_get_polymer_axes", + return_value=(np.eye(3), np.eye(3), np.zeros(3), np.ones(3)), + ) as axes_spy, + patch.object(node._ft, "get_weighted_forces", return_value=f_vec) as f_spy, + patch.object(node._ft, "get_weighted_torques", return_value=t_vec) as t_spy, + patch.object( + node._ft, "compute_frame_covariance", return_value=(F, T) + ) as cov_spy, + patch.object(node, "_build_ft_block", return_value=FT) as ft_spy, + ): + node._process_polymer( + u=u, + mol=mol, + mol_id=0, + group_id=7, + beads={(0, "polymer"): [np.array([1, 2, 3])]}, + axes_manager=axes_manager, + box=np.array([10.0, 10.0, 10.0]), + force_partitioning=0.5, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + out_ft=out_ft, + molcount=molcount, + combined=True, + ) + + assert u.atoms.__getitem__.call_count == 1 + axes_spy.assert_called_once_with(mol=mol, bead=bead_obj, axes_manager=axes_manager) + + f_spy.assert_called_once() + t_spy.assert_called_once() + cov_spy.assert_called_once() + + np.testing.assert_allclose(out_force["poly"][7], F) + np.testing.assert_allclose(out_torque["poly"][7], T) + assert molcount[7] == 1 + + ft_spy.assert_called_once() + np.testing.assert_allclose(out_ft["poly"][7], FT) diff --git a/tests/unit/CodeEntropy/levels/nodes/test_init_covariance_accumulators_node.py b/tests/unit/CodeEntropy/levels/nodes/test_init_covariance_accumulators_node.py new file mode 100644 index 00000000..fd5d31e8 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/nodes/test_init_covariance_accumulators_node.py @@ -0,0 +1,23 @@ +import numpy as np + +from CodeEntropy.levels.nodes.accumulators import InitCovarianceAccumulatorsNode + + +def test_init_covariance_accumulators_allocates_and_sets_aliases(): + node = 
InitCovarianceAccumulatorsNode() + + shared = {"groups": {9: [1, 2], 2: [3]}} + + out = node.run(shared) + + assert out["group_id_to_index"] == {9: 0, 2: 1} + assert out["index_to_group_id"] == [9, 2] + + assert shared["force_covariances"]["res"] == [None, None] + assert shared["torque_covariances"]["poly"] == [None, None] + + assert np.all(shared["frame_counts"]["res"] == np.array([0, 0])) + assert np.all(shared["forcetorque_counts"]["poly"] == np.array([0, 0])) + + assert shared["force_torque_stats"] is shared["forcetorque_covariances"] + assert shared["force_torque_counts"] is shared["forcetorque_counts"] diff --git a/tests/unit/CodeEntropy/levels/test_axes.py b/tests/unit/CodeEntropy/levels/test_axes.py new file mode 100644 index 00000000..d6d0093f --- /dev/null +++ b/tests/unit/CodeEntropy/levels/test_axes.py @@ -0,0 +1,701 @@ +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from CodeEntropy.levels.axes import AxesCalculator + + +class _FakeAtom: + def __init__(self, index: int, mass: float, position): + self.index = int(index) + self.mass = float(mass) + self.position = np.asarray(position, dtype=float) + + def __add__(self, other): + # atom + atomgroup => atomgroup + if isinstance(other, _FakeAtomGroup): + return _FakeAtomGroup([self] + list(other._atoms)) + if isinstance(other, _FakeAtom): + return _FakeAtomGroup([self, other]) + raise TypeError(f"Unsupported add: _FakeAtom + {type(other)}") + + +class _FakeAtomGroup: + def __init__(self, atoms, positions=None, select_map=None): + self._atoms = list(atoms) + self._select_map = dict(select_map or {}) + + if positions is None: + if self._atoms: + self.positions = np.vstack([a.position for a in self._atoms]).astype( + float + ) + else: + self.positions = np.zeros((0, 3), dtype=float) + else: + self.positions = np.asarray(positions, dtype=float) + + def __len__(self): + return len(self._atoms) + + def __iter__(self): + return iter(self._atoms) + + def __getitem__(self, idx): + return 
self._atoms[idx] + + @property + def masses(self): + return np.asarray([a.mass for a in self._atoms], dtype=float) + + def select_atoms(self, query: str): + return self._select_map.get(query, _FakeAtomGroup([])) + + def __add__(self, other): + if isinstance(other, _FakeAtomGroup): + return _FakeAtomGroup(self._atoms + other._atoms) + if isinstance(other, _FakeAtom): + return _FakeAtomGroup(self._atoms + [other]) + raise TypeError(f"Unsupported add: _FakeAtomGroup + {type(other)}") + + +def _atom(index=0, mass=12.0, pos=(0.0, 0.0, 0.0), resindex=0): + a = MagicMock() + a.index = index + a.mass = mass + a.position = np.array(pos, dtype=float) + a.resindex = resindex + return a + + +def test_get_residue_axes_empty_residue_raises(): + ax = AxesCalculator() + u = MagicMock() + u.select_atoms.return_value = [] + + with pytest.raises(ValueError): + ax.get_residue_axes(u, index=5) + + +def test_get_residue_axes_no_bonds_uses_custom_principal_axes(monkeypatch): + ax = AxesCalculator() + + # residue selection: non-empty, has heavy atoms and positions + residue = MagicMock() + residue.__len__.return_value = 1 + residue.atoms.center_of_mass.return_value = np.array([0.0, 0.0, 0.0]) + residue.select_atoms.return_value = MagicMock(positions=np.zeros((2, 3))) + residue.center_of_mass.return_value = np.array([0.0, 0.0, 0.0]) + + u = MagicMock() + u.dimensions = np.array([10.0, 10.0, 10.0, 90, 90, 90]) + + # atom_set empty => "no bonds to other residues" branch + def _select_atoms(q): + if q.startswith("(resindex"): + return [] + if q.startswith("resindex "): + return residue + return [] + + u.select_atoms.side_effect = _select_atoms + + monkeypatch.setattr(ax, "get_UA_masses", lambda mol: [10.0, 12.0]) + monkeypatch.setattr(ax, "get_moment_of_inertia_tensor", lambda **kwargs: np.eye(3)) + monkeypatch.setattr( + ax, + "get_custom_principal_axes", + lambda moi: (np.eye(3), np.array([3.0, 2.0, 1.0])), + ) + + trans, rot, center, moi = ax.get_residue_axes(u, index=7) + + assert 
np.allclose(trans, np.eye(3)) + assert np.allclose(rot, np.eye(3)) + assert np.allclose(center, np.array([0.0, 0.0, 0.0])) + assert np.allclose(moi, np.array([3.0, 2.0, 1.0])) + + +def test_get_residue_axes_with_bonds_uses_vanilla_axes(monkeypatch): + ax = AxesCalculator() + + residue = MagicMock() + residue.__len__.return_value = 1 + residue.atoms.center_of_mass.return_value = np.array([1.0, 2.0, 3.0]) + residue.center_of_mass.return_value = np.array([1.0, 2.0, 3.0]) + + u = MagicMock() + u.dimensions = np.array([10.0, 10.0, 10.0, 90, 90, 90]) + u.atoms.principal_axes.return_value = np.eye(3) + + # atom_set non-empty => bonded branch + def _select_atoms(q): + if q.startswith("(resindex"): + return [1] # non-empty + if q.startswith("resindex "): + return residue + return [] + + u.select_atoms.side_effect = _select_atoms + + monkeypatch.setattr("CodeEntropy.levels.axes.make_whole", lambda _ag: None) + monkeypatch.setattr( + ax, "get_vanilla_axes", lambda mol: (np.eye(3) * 2, np.array([9.0, 8.0, 7.0])) + ) + + trans, rot, center, moi = ax.get_residue_axes(u, index=10) + + assert np.allclose(trans, np.eye(3)) + assert np.allclose(rot, np.eye(3) * 2) + assert np.allclose(moi, np.array([9.0, 8.0, 7.0])) + + +def test_get_UA_axes_uses_principal_axes_when_single_heavy(monkeypatch): + ax = AxesCalculator() + + u = MagicMock() + u.dimensions = np.array([10.0, 10.0, 10.0, 90, 90, 90]) + u.atoms.principal_axes.return_value = np.eye(3) + + # heavy_atoms length <= 1 => principal_axes path + heavy_atom = MagicMock(index=5) + heavy_atoms = [heavy_atom] + + def _sel(q): + if q == "prop mass > 1.1": + return heavy_atoms + if q.startswith("index "): + # return atom group with positions + ag = MagicMock() + ag.positions = np.array([[4.0, 0.0, 0.0]]) + ag.__getitem__.return_value = MagicMock( + mass=12.0, position=np.array([4.0, 0.0, 0.0]), index=5 + ) + return ag + return [] + + u.select_atoms.side_effect = _sel + + monkeypatch.setattr("CodeEntropy.levels.axes.make_whole", lambda 
_ag: None) + monkeypatch.setattr( + ax, + "get_bonded_axes", + lambda system, atom, dimensions: (np.eye(3), np.array([1.0, 2.0, 3.0])), + ) + + trans, rot, center, moi = ax.get_UA_axes(u, index=0) + + assert np.allclose(trans, np.eye(3)) + assert np.allclose(rot, np.eye(3)) + assert np.allclose(center, np.array([4.0, 0.0, 0.0])) + assert np.allclose(moi, np.array([1.0, 2.0, 3.0])) + + +def test_get_UA_axes_raises_when_bonded_axes_fail(monkeypatch): + ax = AxesCalculator() + u = MagicMock() + u.dimensions = np.array([10.0, 10.0, 10.0, 90, 90, 90]) + + heavy_atom = MagicMock(index=5) + heavy_atoms = [heavy_atom] + + def _sel(q): + if q == "prop mass > 1.1": + return heavy_atoms + if q.startswith("index "): + ag = MagicMock() + ag.positions = np.array([[1.0, 1.0, 1.0]]) + ag.__getitem__.return_value = MagicMock( + mass=12.0, position=np.array([1.0, 1.0, 1.0]), index=5 + ) + return ag + return [] + + u.select_atoms.side_effect = _sel + monkeypatch.setattr("CodeEntropy.levels.axes.make_whole", lambda _ag: None) + monkeypatch.setattr(ax, "get_bonded_axes", lambda **kwargs: (None, None)) + + with pytest.raises(ValueError): + ax.get_UA_axes(u, index=0) + + +def test_get_custom_axes_degenerate_axis1_raises(): + ax = AxesCalculator() + a = np.zeros(3) + b_list = [np.zeros(3)] + with pytest.raises(ValueError): + ax.get_custom_axes( + a=a, b_list=b_list, c=np.zeros(3), dimensions=np.array([10.0, 10.0, 10.0]) + ) + + +def test_get_custom_axes_normalizes_and_uses_bc_when_multiple_b(monkeypatch): + ax = AxesCalculator() + a = np.array([0.0, 0.0, 0.0]) + b_list = [np.array([1.0, 0.0, 0.0]), np.array([1.0, 0.0, 0.0])] + c = np.array([0.0, 1.0, 0.0]) + + axes = ax.get_custom_axes( + a=a, b_list=b_list, c=c, dimensions=np.array([10.0, 10.0, 10.0]) + ) + assert axes.shape == (3, 3) + # axes rows must be unit length + assert np.allclose(np.linalg.norm(axes, axis=1), 1.0) + + +def test_get_custom_moment_of_inertia_two_atom_sets_smallest_to_zero(): + ax = AxesCalculator() + + a0 = 
_FakeAtom(0, 12.0, [0.0, 0.0, 0.0]) + a1 = _FakeAtom(1, 1.0, [1.0, 0.0, 0.0]) + ua = _FakeAtomGroup([a0, a1]) + + axes = np.eye(3) + moi = ax.get_custom_moment_of_inertia( + UA=ua, + custom_rotation_axes=axes, + center_of_mass=np.array([0.0, 0.0, 0.0]), + dimensions=np.array([10.0, 10.0, 10.0]), + ) + assert moi.shape == (3,) + assert np.isclose(np.min(moi), 0.0) + + +def test_get_flipped_axes_flips_negative_dot(): + ax = AxesCalculator() + + a0 = _FakeAtom(0, 12.0, [0.0, 0.0, 0.0]) + ua = _FakeAtomGroup([a0]) + + # axis0 points opposite to rr_axis -> should flip + custom_axes = np.array( + [ + [-1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ], + dtype=float, + ) + flipped = ax.get_flipped_axes( + UA=ua, + custom_axes=custom_axes, + center_of_mass=np.array([1.0, 0.0, 0.0]), + dimensions=np.array([10.0, 10.0, 10.0]), + ) + assert np.allclose(flipped[0], np.array([1.0, 0.0, 0.0])) + + +def test_get_custom_principal_axes_flips_z_when_left_handed(): + ax = AxesCalculator() + moi = np.eye(3) + axes, vals = ax.get_custom_principal_axes(moi) + assert axes.shape == (3, 3) + assert vals.shape == (3,) + + +def test_get_UA_masses_sums_bonded_hydrogens(): + ax = AxesCalculator() + + heavy = _FakeAtom(index=0, mass=12.0, position=[0, 0, 0]) + h1 = _FakeAtom(index=1, mass=1.0, position=[1, 0, 0]) + h2 = _FakeAtom(index=2, mass=1.0, position=[0, 1, 0]) + + bonded_atoms = _FakeAtomGroup( + [h1, h2], select_map={"mass 1 to 1.1": _FakeAtomGroup([h1, h2])} + ) + + mol = _FakeAtomGroup( + [heavy, h1, h2], + select_map={ + "bonded index 0": bonded_atoms, + }, + ) + + masses = ax.get_UA_masses(mol) + assert masses == [14.0] + + +def test_get_vanilla_axes_sorts_eigenvalues_desc_by_abs(monkeypatch): + ax = AxesCalculator() + mol = MagicMock() + moi_tensor = np.diag([1.0, -10.0, 3.0]) + mol.moment_of_inertia.return_value = moi_tensor + mol.principal_axes.return_value = np.eye(3) + mol.atoms = MagicMock() + + # avoid real MDAnalysis unwrap + 
monkeypatch.setattr("CodeEntropy.levels.axes.make_whole", lambda _ag: None) + + axes, moments = ax.get_vanilla_axes(mol) + + assert axes.shape == (3, 3) + # sorted by abs descending => -10, 3, 1 + assert np.allclose(moments, np.array([-10.0, 3.0, 1.0])) + + +def test_find_bonded_atoms_selects_heavy_and_hydrogen_groups(): + ax = AxesCalculator() + + bonded = MagicMock() + heavy = MagicMock() + hyd = MagicMock() + bonded.select_atoms.side_effect = [heavy, hyd] + + system = MagicMock() + system.select_atoms.return_value = bonded + + out_heavy, out_h = ax.find_bonded_atoms(atom_idx=7, system=system) + + system.select_atoms.assert_called_once_with("bonded index 7") + bonded.select_atoms.assert_any_call("mass 2 to 999") + bonded.select_atoms.assert_any_call("mass 1 to 1.1") + assert out_heavy is heavy + assert out_h is hyd + + +def test_get_bonded_axes_non_heavy_returns_none(): + ax = AxesCalculator() + system = MagicMock() + atom = _atom(index=1, mass=1.0) + + out_axes, out_moi = ax.get_bonded_axes( + system, atom, dimensions=np.array([10.0, 10.0, 10.0]) + ) + assert out_axes is None + assert out_moi is None + + +def test_get_bonded_axes_case1_uses_vanilla_axes_and_returns_flipped(monkeypatch): + ax = AxesCalculator() + system = MagicMock() + atom = _atom(index=1, mass=12.0, pos=(1, 0, 0)) + + heavy = _FakeAtomGroup([]) # len == 0 -> case1 + hyd = _FakeAtomGroup([_atom(index=2, mass=1.0)]) + monkeypatch.setattr(ax, "find_bonded_atoms", lambda _idx, _sys: (heavy, hyd)) + + monkeypatch.setattr( + ax, "get_vanilla_axes", lambda _ag: (np.eye(3) * 7, np.array([1.0, 2.0, 3.0])) + ) + monkeypatch.setattr(ax, "get_flipped_axes", lambda ua, axes, com, dims: axes * -1) + + out_axes, out_moi = ax.get_bonded_axes(system, atom, np.array([10.0, 10.0, 10.0])) + + assert np.allclose(out_axes, -np.eye(3) * 7) + assert np.allclose(out_moi, np.array([1.0, 2.0, 3.0])) + + +def test_get_bonded_axes_case2_one_heavy_no_h_calls_get_custom_axes_and_custom_moi( + monkeypatch, +): + ax = 
AxesCalculator() + system = MagicMock() + atom = _atom(index=1, mass=12.0, pos=(0, 0, 0)) + + heavy = _FakeAtomGroup( + [_atom(index=3, mass=12.0, pos=(1, 0, 0))], positions=np.array([[1, 0, 0]]) + ) + hyd = _FakeAtomGroup([]) + + monkeypatch.setattr(ax, "find_bonded_atoms", lambda _idx, _sys: (heavy, hyd)) + monkeypatch.setattr(ax, "get_custom_axes", lambda **kwargs: np.eye(3)) + monkeypatch.setattr( + ax, "get_custom_moment_of_inertia", lambda **kwargs: np.array([9.0, 8.0, 7.0]) + ) + monkeypatch.setattr(ax, "get_flipped_axes", lambda ua, axes, com, dims: axes) + + out_axes, out_moi = ax.get_bonded_axes(system, atom, np.array([10.0, 10.0, 10.0])) + + assert out_axes.shape == (3, 3) + assert np.allclose(out_moi, np.array([9.0, 8.0, 7.0])) + + +def test_get_bonded_axes_case3_one_heavy_with_h_calls_get_custom_axes(monkeypatch): + ax = AxesCalculator() + system = MagicMock() + atom = _atom(index=1, mass=12.0, pos=(0, 0, 0)) + + heavy = _FakeAtomGroup( + [_atom(index=3, mass=12.0, pos=(1, 0, 0))], positions=np.array([[1, 0, 0]]) + ) + hyd = _FakeAtomGroup([_atom(index=4, mass=1.0, pos=(0, 1, 0))]) + + monkeypatch.setattr(ax, "find_bonded_atoms", lambda _idx, _sys: (heavy, hyd)) + called = {"n": 0} + + def _custom_axes(**kwargs): + called["n"] += 1 + return np.eye(3) * 2 + + monkeypatch.setattr(ax, "get_custom_axes", _custom_axes) + monkeypatch.setattr( + ax, "get_custom_moment_of_inertia", lambda **kwargs: np.array([1.0, 1.0, 1.0]) + ) + monkeypatch.setattr(ax, "get_flipped_axes", lambda ua, axes, com, dims: axes) + + out_axes, out_moi = ax.get_bonded_axes(system, atom, np.array([10.0, 10.0, 10.0])) + + assert called["n"] == 1 + assert np.allclose(out_axes, np.eye(3) * 2) + assert np.allclose(out_moi, np.array([1.0, 1.0, 1.0])) + + +def test_get_bonded_axes_case5_two_heavy_calls_get_custom_axes(monkeypatch): + ax = AxesCalculator() + system = MagicMock() + atom = _atom(index=1, mass=12.0, pos=(0, 0, 0)) + + heavy_atoms = [ + _atom(index=3, mass=12.0, pos=(1, 0, 0)), + 
_atom(index=5, mass=12.0, pos=(0, 1, 0)), + ] + heavy = _FakeAtomGroup(heavy_atoms, positions=np.array([[1, 0, 0], [0, 1, 0]])) + heavy.positions = np.array([[1, 0, 0], [0, 1, 0]]) + hyd = _FakeAtomGroup([]) + + monkeypatch.setattr(ax, "find_bonded_atoms", lambda _idx, _sys: (heavy, hyd)) + monkeypatch.setattr(ax, "get_custom_axes", lambda **kwargs: np.eye(3) * 3) + monkeypatch.setattr( + ax, "get_custom_moment_of_inertia", lambda **kwargs: np.array([2.0, 2.0, 2.0]) + ) + monkeypatch.setattr(ax, "get_flipped_axes", lambda ua, axes, com, dims: axes) + + out_axes, out_moi = ax.get_bonded_axes(system, atom, np.array([10.0, 10.0, 10.0])) + + assert np.allclose(out_axes, np.eye(3) * 3) + assert np.allclose(out_moi, np.array([2.0, 2.0, 2.0])) + + +def test_get_residue_axes_no_bonds_custom_path(monkeypatch): + ax = AxesCalculator() + + residue = MagicMock() + residue.__len__.return_value = 1 + residue.atoms.center_of_mass.return_value = np.array([0.0, 0.0, 0.0]) + residue.select_atoms.return_value = MagicMock(positions=np.zeros((2, 3))) + residue.center_of_mass.return_value = np.array([0.0, 0.0, 0.0]) + + u = MagicMock() + u.dimensions = np.array([10.0, 10.0, 10.0, 90, 90, 90]) + + def _select_atoms(q): + if q.startswith("(resindex"): + return [] # no bonds + if q.startswith("resindex "): + return residue + return [] + + u.select_atoms.side_effect = _select_atoms + + monkeypatch.setattr(ax, "get_UA_masses", lambda mol: [10.0, 12.0]) + monkeypatch.setattr(ax, "get_moment_of_inertia_tensor", lambda **kwargs: np.eye(3)) + monkeypatch.setattr( + ax, + "get_custom_principal_axes", + lambda moi: (np.eye(3), np.array([3.0, 2.0, 1.0])), + ) + + trans, rot, center, moi = ax.get_residue_axes(u, index=7) + + assert trans.shape == (3, 3) + assert rot.shape == (3, 3) + assert np.allclose(center, np.array([0.0, 0.0, 0.0])) + assert np.allclose(moi, np.array([3.0, 2.0, 1.0])) + + +def test_get_residue_axes_with_bonds_vanilla_path(monkeypatch): + ax = AxesCalculator() + + residue = 
MagicMock() + residue.__len__.return_value = 1 + residue.atoms.principal_axes.return_value = np.eye(3) * 2 + residue.atoms.center_of_mass.return_value = np.array([1.0, 2.0, 3.0]) + residue.center_of_mass.return_value = np.array([1.0, 2.0, 3.0]) + + u = MagicMock() + u.dimensions = np.array([10.0, 10.0, 10.0, 90, 90, 90]) + u.atoms.principal_axes.return_value = np.eye(3) * 2 + + def _select_atoms(q): + if q.startswith("(resindex"): + return [1] + if q.startswith("resindex "): + return residue + return [] + + u.select_atoms.side_effect = _select_atoms + + monkeypatch.setattr("CodeEntropy.levels.axes.make_whole", lambda _ag: None) + monkeypatch.setattr( + ax, "get_vanilla_axes", lambda mol: (np.eye(3) * 2, np.array([9.0, 8.0, 7.0])) + ) + + trans, rot, center, moi = ax.get_residue_axes(u, index=10) + + assert np.allclose(trans, np.eye(3) * 2) + assert np.allclose(rot, np.eye(3) * 2) + assert np.allclose(center, np.array([1.0, 2.0, 3.0])) + assert np.allclose(moi, np.array([9.0, 8.0, 7.0])) + + +def test_get_vector_wraps_periodic_boundaries(): + ac = AxesCalculator() + dims = np.array([10.0, 10.0, 10.0]) + + a = np.array([9.0, 0.0, 0.0]) + b = np.array([1.0, 0.0, 0.0]) + out = ac.get_vector(a, b, dims) + np.testing.assert_allclose(out, np.array([2.0, 0.0, 0.0])) + + +def test_get_custom_axes_raises_when_axis1_degenerate(): + ac = AxesCalculator() + a = np.zeros(3) + b_list = [np.zeros(3), np.zeros(3)] + c = np.ones(3) + dims = np.array([10.0, 10.0, 10.0]) + with pytest.raises(ValueError): + ac.get_custom_axes(a=a, b_list=b_list, c=c, dimensions=dims) + + +def test_get_custom_axes_raises_when_normalization_degenerate(): + ac = AxesCalculator() + dims = np.array([10.0, 10.0, 10.0]) + + a = np.zeros(3) + b_list = [np.array([1.0, 0.0, 0.0])] + c = np.array([2.0, 0.0, 0.0]) + + with pytest.raises(ValueError): + ac.get_custom_axes(a=a, b_list=b_list, c=c, dimensions=dims) + + +def test_get_custom_principal_axes_flips_z_for_handedness(): + ac = AxesCalculator() + + moi = 
np.diag([3.0, 2.0, 1.0]) + axes, vals = ac.get_custom_principal_axes(moi) + + assert axes.shape == (3, 3) + assert vals.shape == (3,) + + cross_xy = np.cross(axes[0], axes[1]) + assert float(np.dot(cross_xy, axes[2])) > 0.0 + + +def test_get_moment_of_inertia_tensor_shape_and_symmetry(): + ac = AxesCalculator() + dims = np.array([10.0, 10.0, 10.0]) + com = np.zeros(3) + positions = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]) + masses = [1.0, 3.0] + + moi = ac.get_moment_of_inertia_tensor(com, positions, masses, dims) + + assert moi.shape == (3, 3) + np.testing.assert_allclose(moi, moi.T) + + +def test_get_custom_moment_of_inertia_len2_zeros_smallest_component(): + ac = AxesCalculator() + dims = np.array([10.0, 10.0, 10.0]) + + UA = MagicMock() + UA.positions = np.array([[1.0, 0.0, 0.0], [2.0, 0.0, 0.0]]) + UA.masses = [12.0, 1.0] + UA.__len__.return_value = 2 + + axes = np.eye(3) + com = np.zeros(3) + + moi = ac.get_custom_moment_of_inertia(UA, axes, com, dims) + + assert moi.shape == (3,) + assert np.isclose(np.min(moi), 0.0) + + +def test_get_UA_axes_multiple_heavy_atoms_uses_custom_principal_axes(monkeypatch): + ax = AxesCalculator() + + heavy_atoms = _FakeAtomGroup( + [ + _FakeAtom(0, 12.0, [0, 0, 0]), + _FakeAtom(1, 12.0, [1, 0, 0]), + ], + positions=np.array([[0, 0, 0], [1, 0, 0]], dtype=float), + ) + + system_atom = _FakeAtom(index=0, mass=12.0, position=[0, 0, 0]) + heavy_atom_selection = _FakeAtomGroup( + [system_atom], positions=np.array([[0, 0, 0]], dtype=float) + ) + + class _Atoms: + def center_of_mass(self, *args, **kwargs): + return np.array([0.0, 0.0, 0.0], dtype=float) + + def __getitem__(self, idx): + return system_atom + + data_container = MagicMock() + data_container.atoms = _Atoms() + data_container.dimensions = np.array([10.0, 10.0, 10.0, 90, 90, 90], dtype=float) + + def _select_atoms(q): + if q == "prop mass > 1.1": + return heavy_atoms + if q.startswith("index "): + return heavy_atom_selection + return _FakeAtomGroup([]) + + 
data_container.select_atoms.side_effect = _select_atoms + + monkeypatch.setattr( + ax, + "get_bonded_axes", + lambda system, atom, dimensions: (np.eye(3), np.array([1.0, 1.0, 1.0])), + ) + monkeypatch.setattr(ax, "get_UA_masses", lambda _ag: [12.0, 12.0]) + + got_tensor = MagicMock(return_value=np.eye(3)) + monkeypatch.setattr(ax, "get_moment_of_inertia_tensor", got_tensor) + + got_custom_axes = MagicMock(return_value=(np.eye(3), np.array([3.0, 2.0, 1.0]))) + monkeypatch.setattr(ax, "get_custom_principal_axes", got_custom_axes) + + trans_axes, rot_axes, center, moi = ax.get_UA_axes(data_container, index=0) + + assert trans_axes.shape == (3, 3) + assert rot_axes.shape == (3, 3) + assert np.allclose(center, np.array([0.0, 0.0, 0.0])) + assert moi.shape == (3,) + got_tensor.assert_called_once() + got_custom_axes.assert_called_once() + + +def test_get_bonded_axes_returns_none_none_if_custom_axes_none(monkeypatch): + ax = AxesCalculator() + + atom = _FakeAtom(index=7, mass=12.0, position=[0, 0, 0]) + system = MagicMock() + dimensions = np.array([10.0, 10.0, 10.0], dtype=float) + + heavy_bonded = _FakeAtomGroup( + [_FakeAtom(8, 12.0, [1, 0, 0])], + positions=np.array([[1.0, 0.0, 0.0]], dtype=float), + ) + light_bonded = _FakeAtomGroup([], positions=np.zeros((0, 3), dtype=float)) + + monkeypatch.setattr( + ax, "find_bonded_atoms", lambda _idx, _sys: (heavy_bonded, light_bonded) + ) + + monkeypatch.setattr(ax, "get_custom_axes", lambda **kwargs: None) + + custom_axes, moi = ax.get_bonded_axes( + system=system, atom=atom, dimensions=dimensions + ) + + assert custom_axes is None + assert moi is None diff --git a/tests/unit/CodeEntropy/levels/test_dihedrals.py b/tests/unit/CodeEntropy/levels/test_dihedrals.py new file mode 100644 index 00000000..8ea0f1fe --- /dev/null +++ b/tests/unit/CodeEntropy/levels/test_dihedrals.py @@ -0,0 +1,619 @@ +import contextlib +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy as np + +from 
CodeEntropy.levels.dihedrals import ConformationStateBuilder + + +class _AddableAG: + def __init__(self, name: str): + self.name = name + + def __add__(self, other: "_AddableAG") -> "_AddableAG": + return _AddableAG(f"({self.name}+{other.name})") + + +class _FakeProgress: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def add_task(self, *args, **kwargs): + return 1 + + def advance(self, *args, **kwargs): + return None + + +@contextlib.contextmanager +def _fake_progress_bar(*_args, **_kwargs): + yield _FakeProgress() + + +def test_select_heavy_residue_builds_two_selections(): + uops = MagicMock() + dt = ConformationStateBuilder(universe_operations=uops) + + mol = MagicMock() + mol.residues = [MagicMock()] + mol.residues[0].atoms.indices = np.array([10, 11, 12], dtype=int) + + uops.select_atoms.side_effect = ["res_container", "heavy_only"] + + out = dt._select_heavy_residue(mol, res_id=0) + + assert out == "heavy_only" + assert uops.select_atoms.call_count == 2 + uops.select_atoms.assert_any_call(mol, "index 10:12") + uops.select_atoms.assert_any_call("res_container", "prop mass > 1.1") + + +def test_get_dihedrals_united_atom_collects_atoms_from_dihedral_objects(): + dt = ConformationStateBuilder(universe_operations=MagicMock()) + + d0 = MagicMock() + d0.atoms = "A0" + d1 = MagicMock() + d1.atoms = "A1" + + container = MagicMock() + container.dihedrals = [d0, d1] + + assert dt._get_dihedrals(container, level="united_atom") == ["A0", "A1"] + + +def test_get_dihedrals_residue_returns_empty_when_less_than_4_residues(): + dt = ConformationStateBuilder(universe_operations=MagicMock()) + + mol = MagicMock() + mol.residues = [MagicMock(), MagicMock(), MagicMock()] + mol.select_atoms = MagicMock() + + assert dt._get_dihedrals(mol, level="residue") == [] + mol.select_atoms.assert_not_called() + + +def test_get_dihedrals_residue_builds_one_dihedral_when_4_residues(): + dt = 
ConformationStateBuilder(universe_operations=MagicMock()) + + mol = MagicMock() + mol.residues = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + mol.select_atoms = MagicMock( + side_effect=[ + _AddableAG("a1"), + _AddableAG("a2"), + _AddableAG("a3"), + _AddableAG("a4"), + ] + ) + + out = dt._get_dihedrals(mol, level="residue") + + assert len(out) == 1 + assert isinstance(out[0], _AddableAG) + assert mol.select_atoms.call_count == 4 + + +def test_collect_dihedrals_for_group_handles_both_levels(): + dt = ConformationStateBuilder(universe_operations=MagicMock()) + + mol = MagicMock() + mol.residues = [MagicMock(), MagicMock()] + + with ( + patch.object( + dt, "_select_heavy_residue", side_effect=["heavy0", "heavy1"] + ) as sel_spy, + patch.object( + dt, "_get_dihedrals", side_effect=[["ua0"], ["ua1"], ["res_d0"]] + ) as get_spy, + ): + ua, res = dt._collect_dihedrals_for_group( + mol=mol, level_list=["united_atom", "residue"] + ) + + assert ua == [["ua0"], ["ua1"]] + assert res == ["res_d0"] + assert sel_spy.call_count == 2 + assert get_spy.call_count == 3 + + +def test_collect_peaks_for_group_sets_empty_outputs_when_no_dihedrals(): + dt = ConformationStateBuilder(universe_operations=MagicMock()) + + dihedrals_ua = [[], []] + dihedrals_res = [] + + with patch.object(dt, "_identify_peaks") as identify_spy: + peaks_ua, peaks_res = dt._collect_peaks_for_group( + data_container=MagicMock(), + molecules=[0], + dihedrals_ua=dihedrals_ua, + dihedrals_res=dihedrals_res, + bin_width=30.0, + start=0, + end=10, + step=1, + level_list=["united_atom", "residue"], + ) + + assert peaks_ua == [[], []] + assert peaks_res == [] + identify_spy.assert_not_called() + + +def test_identify_peaks_wraps_negative_angles_and_calls_find_histogram_peaks(): + uops = MagicMock() + dt = ConformationStateBuilder(universe_operations=uops) + + mol = MagicMock() + mol.trajectory = [0, 1] + uops.extract_fragment.return_value = mol + + angles = np.array([[-10.0], [10.0]], dtype=float) + + class 
_FakeDihedral: + def __init__(self, _dihedrals): + pass + + def run(self): + return SimpleNamespace(results=SimpleNamespace(angles=angles)) + + with ( + patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral), + patch.object(dt, "_find_histogram_peaks", return_value=[15.0]) as peaks_spy, + ): + out = dt._identify_peaks( + data_container=MagicMock(), + molecules=[0], + dihedrals=[MagicMock()], + bin_width=180.0, + start=0, + end=2, + step=1, + ) + + assert out == [[15.0]] + peaks_spy.assert_called_once() + + +def test_find_histogram_peaks_hits_interior_and_wraparound_last_bin(): + popul = [0, 2, 0, 3] + bin_value = [10.0, 20.0, 30.0, 40.0] + assert ConformationStateBuilder._find_histogram_peaks(popul, bin_value) == [ + 20.0, + 40.0, + ] + + +def test_assign_states_initialises_then_extends_for_multiple_molecules(): + uops = MagicMock() + dt = ConformationStateBuilder(universe_operations=uops) + + mol = MagicMock() + mol.trajectory = [0, 1] + uops.extract_fragment.return_value = mol + + angles = np.array([[5.0], [15.0]], dtype=float) + peaks = [[5.0, 15.0]] + + class _FakeDihedral: + def __init__(self, _dihedrals): + pass + + def run(self): + return SimpleNamespace(results=SimpleNamespace(angles=angles)) + + with patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral): + states = dt._assign_states( + data_container=MagicMock(), + molecules=[0, 1], + dihedrals=["D0"], + peaks=peaks, + start=0, + end=2, + step=1, + ) + + assert states == ["0", "1", "0", "1"] + + +def test_assign_states_for_group_sets_empty_lists_and_delegates_for_nonempty(): + dt = ConformationStateBuilder(universe_operations=MagicMock()) + + states_ua = {} + states_res = [None, None] + + with patch.object(dt, "_assign_states", return_value=["x"]) as assign_spy: + dt._assign_states_for_group( + data_container=MagicMock(), + group_id=1, + molecules=[99], + dihedrals_ua=[[], ["UA"]], + peaks_ua=[[], [["p"]]], + dihedrals_res=[], + peaks_res=[], + start=0, + end=2, + step=1, + 
level_list=["united_atom", "residue"], + states_ua=states_ua, + states_res=states_res, + ) + + assert states_ua[(1, 0)] == [] + assert states_ua[(1, 1)] == ["x"] + assert states_res[1] == [] + assert assign_spy.call_count == 1 + + +def test_build_conformational_states_runs_group_and_skips_empty_group(monkeypatch): + uops = MagicMock() + dt = ConformationStateBuilder(universe_operations=uops) + + groups = {0: [], 1: [7]} + levels = {7: ["residue"]} + + uops.extract_fragment.return_value = MagicMock(trajectory=[0]) + + monkeypatch.setattr(dt, "_collect_dihedrals_for_group", lambda **kw: ([], [])) + monkeypatch.setattr(dt, "_collect_peaks_for_group", lambda **kw: ([], [])) + monkeypatch.setattr(dt, "_assign_states_for_group", lambda **kw: None) + + states_ua, states_res = dt.build_conformational_states( + data_container=MagicMock(), + levels=levels, + groups=groups, + start=0, + end=1, + step=1, + bin_width=30.0, + ) + + assert states_ua == {} + assert len(states_res) == 2 + + +def test_identify_peaks_handles_multiple_dihedrals_and_calls_histogram_each_time(): + uops = MagicMock() + dt = ConformationStateBuilder(universe_operations=uops) + + mol = MagicMock() + mol.trajectory = [0, 1] + uops.extract_fragment.return_value = mol + + angles = np.array( + [ + [-10.0, 10.0], + [20.0, -20.0], + ], + dtype=float, + ) + + class _FakeDihedral: + def __init__(self, _dihedrals): + pass + + def run(self): + return SimpleNamespace(results=SimpleNamespace(angles=angles)) + + with ( + patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral), + patch( + "CodeEntropy.levels.dihedrals.np.histogram", wraps=np.histogram + ) as hist_spy, + ): + out = dt._identify_peaks( + data_container=MagicMock(), + molecules=[0], + dihedrals=["D0", "D1"], + bin_width=180.0, + start=0, + end=2, + step=1, + ) + + assert len(out) == 2 + assert hist_spy.call_count == 2 + + +def test_assign_states_filters_out_empty_state_strings_when_no_dihedrals(): + uops = MagicMock() + dt = 
ConformationStateBuilder(universe_operations=uops) + + mol = MagicMock() + mol.trajectory = [0, 1, 2] + uops.extract_fragment.return_value = mol + + class _FakeDihedral: + def __init__(self, _dihedrals): + pass + + def run(self): + return SimpleNamespace(results=SimpleNamespace(angles=[])) + + with patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral): + out = dt._assign_states( + data_container=MagicMock(), + molecules=[0], + dihedrals=[], + peaks=[], + start=0, + end=3, + step=1, + ) + + assert out == [] + + +def test_identify_peaks_multiple_molecules_real_histogram(): + uops = MagicMock() + dt = ConformationStateBuilder(universe_operations=uops) + + mol0 = MagicMock() + mol0.trajectory = [0, 1] + mol1 = MagicMock() + mol1.trajectory = [0, 1] + + uops.extract_fragment.side_effect = [mol0, mol1] + + angles = np.array([[10.0], [20.0]], dtype=float) + + class _FakeDihedral: + def __init__(self, _): + pass + + def run(self): + return SimpleNamespace(results=SimpleNamespace(angles=angles)) + + with patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral): + peaks = dt._identify_peaks( + data_container=MagicMock(), + molecules=[0, 1], + dihedrals=["D0"], + bin_width=90.0, + start=0, + end=2, + step=1, + ) + + assert len(peaks) == 1 + + +def test_identify_peaks_real_histogram_without_spy(): + uops = MagicMock() + dt = ConformationStateBuilder(universe_operations=uops) + + mol = MagicMock() + mol.trajectory = [0, 1] + uops.extract_fragment.return_value = mol + + angles = np.array([[10.0], [20.0]], dtype=float) + + class _FakeDihedral: + def __init__(self, _): + pass + + def run(self): + return SimpleNamespace(results=SimpleNamespace(angles=angles)) + + with patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral): + peaks = dt._identify_peaks( + data_container=MagicMock(), + molecules=[0], + dihedrals=["D0"], + bin_width=90.0, + start=0, + end=2, + step=1, + ) + + assert isinstance(peaks, list) + + +def 
test_assign_states_for_group_residue_nonempty_calls_assign_states(): + dt = ConformationStateBuilder(universe_operations=MagicMock()) + + states_ua = {} + states_res = [None, None] + + with patch.object(dt, "_assign_states", return_value=["A"]) as spy: + dt._assign_states_for_group( + data_container=MagicMock(), + group_id=1, + molecules=[0], + dihedrals_ua=[[]], + peaks_ua=[[]], + dihedrals_res=["D"], + peaks_res=[["p"]], + start=0, + end=1, + step=1, + level_list=["residue"], + states_ua=states_ua, + states_res=states_res, + ) + + assert states_res[1] == ["A"] + spy.assert_called_once() + + +def test_assign_states_first_empty_then_extend(): + uops = MagicMock() + dt = ConformationStateBuilder(universe_operations=uops) + + mol0 = MagicMock() + mol0.trajectory = [] + mol1 = MagicMock() + mol1.trajectory = [0] + + uops.extract_fragment.side_effect = [mol0, mol1] + + angles = np.array([[10.0]], dtype=float) + + class _FakeDihedral: + def __init__(self, _): + pass + + def run(self): + return SimpleNamespace(results=SimpleNamespace(angles=angles)) + + with patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral): + states = dt._assign_states( + data_container=MagicMock(), + molecules=[0, 1], + dihedrals=["D0"], + peaks=[[10.0]], + start=0, + end=1, + step=1, + ) + + assert states == ["0"] + + +def test_collect_peaks_for_group_calls_identify_peaks_for_ua_and_residue(): + dt = ConformationStateBuilder(universe_operations=MagicMock()) + + dihedrals_ua = [["UA_D0"]] + dihedrals_res = ["RES_D0"] + + with patch.object( + dt, + "_identify_peaks", + side_effect=[[["ua_peak"]], [["res_peak"]]], + ) as identify_spy: + peaks_ua, peaks_res = dt._collect_peaks_for_group( + data_container=MagicMock(), + molecules=[0], + dihedrals_ua=dihedrals_ua, + dihedrals_res=dihedrals_res, + bin_width=30.0, + start=0, + end=10, + step=1, + level_list=["united_atom", "residue"], + ) + + assert peaks_ua == [[["ua_peak"]]] + assert peaks_res == [["res_peak"]] + assert identify_spy.call_count 
== 2 + + +def test_assign_states_wraps_negative_angles(): + uops = MagicMock() + dt = ConformationStateBuilder(universe_operations=uops) + + mol = MagicMock() + mol.trajectory = [0, 1] + uops.extract_fragment.return_value = mol + + angles = np.array([[-10.0], [10.0]], dtype=float) + peaks = [[10.0, 350.0]] + + class _FakeDihedral: + def __init__(self, _dihedrals): + pass + + def run(self): + return SimpleNamespace(results=SimpleNamespace(angles=angles)) + + with patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral): + states = dt._assign_states( + data_container=MagicMock(), + molecules=[0], + dihedrals=["D0"], + peaks=peaks, + start=0, + end=2, + step=1, + ) + + assert states == ["1", "0"] + + +def test_build_conformational_states_with_progress_handles_no_groups(): + uops = MagicMock() + dt = ConformationStateBuilder(universe_operations=uops) + + progress = MagicMock() + progress.add_task.return_value = 123 + + states_ua, states_res = dt.build_conformational_states( + data_container=MagicMock(), + levels={}, + groups={}, # empty + start=0, + end=1, + step=1, + bin_width=30.0, + progress=progress, + ) + + assert states_ua == {} + assert states_res == [] + progress.add_task.assert_called_once() + progress.update.assert_called_once_with(123, title="No groups") + progress.advance.assert_called_once_with(123) + + +def test_build_conformational_states_with_progress_skips_empty_molecule_group(): + uops = MagicMock() + dt = ConformationStateBuilder(universe_operations=uops) + + progress = MagicMock() + progress.add_task.return_value = 5 + + groups = {0: []} + levels = {} + + states_ua, states_res = dt.build_conformational_states( + data_container=MagicMock(), + levels=levels, + groups=groups, + start=0, + end=1, + step=1, + bin_width=30.0, + progress=progress, + ) + + assert states_ua == {} + assert len(states_res) == 1 + progress.update.assert_called_with(5, title="Group 0 (empty)") + progress.advance.assert_called_with(5) + + +def 
test_build_conformational_states_with_progress_updates_title_per_group(monkeypatch): + uops = MagicMock() + dt = ConformationStateBuilder(universe_operations=uops) + + progress = MagicMock() + progress.add_task.return_value = 9 + + groups = {1: [7]} + levels = {7: ["residue"]} + + uops.extract_fragment.return_value = MagicMock(trajectory=[0]) + + monkeypatch.setattr(dt, "_collect_dihedrals_for_group", lambda **kw: ([], [])) + monkeypatch.setattr(dt, "_collect_peaks_for_group", lambda **kw: ([], [])) + monkeypatch.setattr(dt, "_assign_states_for_group", lambda **kw: None) + + dt.build_conformational_states( + data_container=MagicMock(), + levels=levels, + groups=groups, + start=0, + end=1, + step=1, + bin_width=30.0, + progress=progress, + ) + + progress.update.assert_any_call(9, title="Group 1") + progress.advance.assert_called_with(9) diff --git a/tests/unit/CodeEntropy/levels/test_forces_force_torque_calculator.py b/tests/unit/CodeEntropy/levels/test_forces_force_torque_calculator.py new file mode 100644 index 00000000..d9111f80 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/test_forces_force_torque_calculator.py @@ -0,0 +1,238 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from CodeEntropy.levels.forces import ForceTorqueCalculator, TorqueInputs + + +def test_get_weighted_forces_applies_partitioning_when_highest_level(): + calc = ForceTorqueCalculator() + + bead = MagicMock() + bead.atoms = [ + SimpleNamespace(force=np.array([1.0, 0.0, 0.0])), + SimpleNamespace(force=np.array([0.0, 2.0, 0.0])), + ] + bead.total_mass.return_value = 4.0 + + trans_axes = np.eye(3) + + out = calc.get_weighted_forces( + bead=bead, + trans_axes=trans_axes, + highest_level=True, + force_partitioning=2.0, + ) + + np.testing.assert_allclose(out, np.array([1.0, 2.0, 0.0])) + + +def test_get_weighted_forces_no_partitioning_when_not_highest_level(): + calc = ForceTorqueCalculator() + + bead = MagicMock() + bead.atoms = 
[SimpleNamespace(force=np.array([2.0, 0.0, 0.0]))] + bead.total_mass.return_value = 4.0 + + out = calc.get_weighted_forces( + bead=bead, + trans_axes=np.eye(3), + highest_level=False, + force_partitioning=999.0, + ) + + np.testing.assert_allclose(out, np.array([1.0, 0.0, 0.0])) + + +def test_get_weighted_forces_raises_on_non_positive_mass(): + calc = ForceTorqueCalculator() + + bead = MagicMock() + bead.atoms = [SimpleNamespace(force=np.array([1.0, 0.0, 0.0]))] + bead.total_mass.return_value = 0.0 + + with pytest.raises(ValueError): + calc.get_weighted_forces( + bead=bead, + trans_axes=np.eye(3), + highest_level=False, + force_partitioning=1.0, + ) + + +def test_get_weighted_torques_uses_axes_manager_displacements(axes_manager_identity): + calc = ForceTorqueCalculator() + + bead = MagicMock() + bead.positions = np.array([[1.0, 0.0, 0.0]]) + bead.forces = np.array([[0.0, 1.0, 0.0]]) + + out = calc.get_weighted_torques( + bead=bead, + rot_axes=np.eye(3), + center=np.array([0.0, 0.0, 0.0]), + force_partitioning=1.0, + moment_of_inertia=np.array([4.0, 9.0, 16.0]), + axes_manager=axes_manager_identity, + box=None, + ) + + np.testing.assert_allclose(out, np.array([0.0, 0.0, 0.25])) + + +def test_get_weighted_torques_skips_zero_or_invalid_moi_components( + axes_manager_identity, +): + calc = ForceTorqueCalculator() + + bead = MagicMock() + bead.positions = np.array([[1.0, 0.0, 0.0]]) + bead.forces = np.array([[0.0, 1.0, 0.0]]) + + out = calc.get_weighted_torques( + bead=bead, + rot_axes=np.eye(3), + center=np.array([0.0, 0.0, 0.0]), + force_partitioning=1.0, + moment_of_inertia=np.array([0.0, -1.0, 16.0]), + axes_manager=axes_manager_identity, + box=None, + ) + + np.testing.assert_allclose(out, np.array([0.0, 0.0, 0.25])) + + +def test_compute_frame_covariance_outer_products(): + calc = ForceTorqueCalculator() + + f_vecs = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 2.0, 0.0])] + t_vecs = [np.array([0.0, 0.0, 3.0])] + + F, T = calc.compute_frame_covariance(f_vecs, t_vecs) 
+ + assert F.shape == (6, 6) + assert T.shape == (3, 3) + + flat_f = np.array([1.0, 0.0, 0.0, 0.0, 2.0, 0.0]) + np.testing.assert_allclose(F, np.outer(flat_f, flat_f)) + + flat_t = np.array([0.0, 0.0, 3.0]) + np.testing.assert_allclose(T, np.outer(flat_t, flat_t)) + + +def test_outer_second_moment_empty_returns_0x0(): + calc = ForceTorqueCalculator() + F, T = calc.compute_frame_covariance([], []) + assert F.shape == (0, 0) + assert T.shape == (0, 0) + + +def test_outer_second_moment_raises_if_vector_not_length_3(): + calc = ForceTorqueCalculator() + with pytest.raises(ValueError): + calc.compute_frame_covariance([np.array([1.0, 2.0])], []) + + +def test_compute_weighted_force_rejects_wrong_axes_shape(): + ft = ForceTorqueCalculator() + bead = MagicMock() + bead.atoms = [] + bead.total_mass.return_value = 10.0 + + with pytest.raises(ValueError): + ft._compute_weighted_force( + bead, + trans_axes=np.zeros((2, 2)), + apply_partitioning=False, + force_partitioning=1.0, + ) + + +def test_compute_weighted_torque_rejects_wrong_rot_axes_shape(): + ft = ForceTorqueCalculator() + bead = MagicMock() + bead.positions = np.zeros((1, 3)) + bead.forces = np.zeros((1, 3)) + + inputs = TorqueInputs( + rot_axes=np.zeros((2, 2)), + center=np.zeros(3), + moment_of_inertia=np.ones(3), + axes_manager=MagicMock(), + box=None, + force_partitioning=1.0, + ) + + with pytest.raises(ValueError): + ft._compute_weighted_torque(bead, inputs) + + +def test_compute_weighted_torque_rejects_wrong_moi_shape(): + ft = ForceTorqueCalculator() + bead = MagicMock() + bead.positions = np.zeros((1, 3)) + bead.forces = np.zeros((1, 3)) + + inputs = TorqueInputs( + rot_axes=np.eye(3), + center=np.zeros(3), + moment_of_inertia=np.ones(2), + axes_manager=MagicMock(), + box=None, + force_partitioning=1.0, + ) + + with pytest.raises(ValueError): + ft._compute_weighted_torque(bead, inputs) + + +def test_compute_weighted_torque_skips_zero_torque_and_nonpositive_moi(monkeypatch): + ft = ForceTorqueCalculator() + + 
bead = MagicMock() + bead.positions = np.array([[1.0, 0.0, 0.0]]) + bead.forces = np.array([[0.0, 0.0, 0.0]]) + + monkeypatch.setattr( + ft, + "_displacements_relative_to_center", + lambda **kwargs: np.array([[1.0, 0.0, 0.0]]), + ) + + inputs = TorqueInputs( + rot_axes=np.eye(3), + center=np.zeros(3), + moment_of_inertia=np.array([0.0, -1.0, 2.0]), + axes_manager=MagicMock(), + box=None, + force_partitioning=1.0, + ) + + out = ft._compute_weighted_torque(bead, inputs) + assert np.allclose(out, np.zeros(3)) + + +def test_compute_weighted_torque_skips_nonpositive_moi_components(): + calc = ForceTorqueCalculator() + + bead = SimpleNamespace( + positions=np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=float), + forces=np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]], dtype=float), + ) + + inputs = SimpleNamespace( + center=np.array([0.0, 0.0, 0.0]), + rot_axes=np.eye(3), + moment_of_inertia=np.array([1.0, 0.0, -1.0], dtype=float), # triggers skips + force_partitioning=1.0, + axes_manager=None, + box=np.array([10.0, 10.0, 10.0], dtype=float), + ) + + calc._displacements_relative_to_center = lambda **kwargs: bead.positions + + weighted = calc._compute_weighted_torque(bead=bead, inputs=inputs) + + assert np.allclose(weighted, np.array([1.0, 0.0, 0.0])) diff --git a/tests/unit/CodeEntropy/levels/test_frame_graph.py b/tests/unit/CodeEntropy/levels/test_frame_graph.py new file mode 100644 index 00000000..e4490c01 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/test_frame_graph.py @@ -0,0 +1,50 @@ +from unittest.mock import MagicMock + +from CodeEntropy.levels.frame_dag import FrameGraph + + +def test_make_frame_ctx_has_required_keys(): + ctx = FrameGraph._make_frame_ctx(shared_data={"x": 1}, frame_index=7) + assert ctx["shared"] == {"x": 1} + assert ctx["frame_index"] == 7 + assert ctx["frame_covariance"] is None + + +def test_add_registers_node_and_deps_edges(): + fg = FrameGraph() + n1 = MagicMock() + n2 = MagicMock() + + fg._add("a", n1) + fg._add("b", n2, deps=["a"]) + + 
assert "a" in fg._nodes and "b" in fg._nodes + assert ("a", "b") in fg._graph.edges + + +def test_execute_frame_runs_nodes_in_topological_order_and_returns_frame_covariance(): + fg = FrameGraph() + + a = MagicMock() + b = MagicMock() + + fg._add("a", a) + fg._add("b", b, deps=["a"]) + + def _b_run(ctx): + ctx["frame_covariance"] = {"ok": True} + + b.run.side_effect = _b_run + + out = fg.execute_frame(shared_data={"S": 1}, frame_index=3) + + assert out == {"ok": True} + assert a.run.call_count == 1 + assert b.run.call_count == 1 + + +def test_build_adds_frame_covariance_node(): + fg = FrameGraph() + fg.build() + assert "frame_covariance" in fg._nodes + assert "frame_covariance" in fg._graph.nodes diff --git a/tests/unit/CodeEntropy/levels/test_hierarchy_builder.py b/tests/unit/CodeEntropy/levels/test_hierarchy_builder.py new file mode 100644 index 00000000..99628f78 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/test_hierarchy_builder.py @@ -0,0 +1,108 @@ +from unittest.mock import MagicMock + +import pytest + +from CodeEntropy.levels.hierarchy import HierarchyBuilder + + +def _heavy_atoms_group(n_atoms: int, n_residues: int): + heavy = MagicMock() + heavy.__len__.return_value = n_atoms + heavy.residues = [MagicMock() for _ in range(n_residues)] + return heavy + + +def test_select_levels_assigns_expected_levels(): + hb = HierarchyBuilder() + + u = MagicMock() + u.atoms = MagicMock() + frag0 = MagicMock() + frag1 = MagicMock() + frag2 = MagicMock() + + frag0.select_atoms.return_value = _heavy_atoms_group(n_atoms=1, n_residues=1) + frag1.select_atoms.return_value = _heavy_atoms_group(n_atoms=2, n_residues=1) + frag2.select_atoms.return_value = _heavy_atoms_group(n_atoms=3, n_residues=2) + + u.atoms.fragments = [frag0, frag1, frag2] + + n_mols, levels = hb.select_levels(u) + + assert n_mols == 3 + assert levels[0] == ["united_atom"] + assert levels[1] == ["united_atom", "residue"] + assert levels[2] == ["united_atom", "residue", "polymer"] + + +def 
test_get_beads_unknown_level_raises(): + hb = HierarchyBuilder() + with pytest.raises(ValueError): + hb.get_beads(MagicMock(), "nonsense") + + +def test_get_beads_polymer_returns_single_all_selection(): + hb = HierarchyBuilder() + mol = MagicMock() + mol.select_atoms.return_value = "ALL" + out = hb.get_beads(mol, "polymer") + assert out == ["ALL"] + mol.select_atoms.assert_called_once_with("all") + + +def test_get_beads_residue_returns_residue_atomgroups(): + hb = HierarchyBuilder() + + mol = MagicMock() + r0 = MagicMock() + r1 = MagicMock() + r0.atoms = "R0_ATOMS" + r1.atoms = "R1_ATOMS" + mol.residues = [r0, r1] + + out = hb.get_beads(mol, "residue") + assert out == ["R0_ATOMS", "R1_ATOMS"] + + +def test_get_beads_united_atom_no_heavy_atoms_falls_back_to_all(): + hb = HierarchyBuilder() + + mol = MagicMock() + heavy = MagicMock() + heavy.__len__.return_value = 0 + mol.select_atoms.side_effect = lambda sel: ( + heavy if sel == "prop mass > 1.1" else "ALL" + ) + out = hb.get_beads(mol, "united_atom") + + assert out == ["ALL"] + + +def test_get_beads_united_atom_builds_selection_per_heavy_atom(): + hb = HierarchyBuilder() + + mol = MagicMock() + h0 = MagicMock() + h1 = MagicMock() + h0.index = 7 + h1.index = 9 + + heavy = [h0, h1] + heavy_group = MagicMock() + heavy_group.__len__.return_value = 2 + heavy_group.__iter__.return_value = iter(heavy) + + def _select(sel): + if sel == "prop mass > 1.1": + return heavy_group + bead = MagicMock() + bead.__len__.return_value = 1 + return bead + + mol.select_atoms.side_effect = _select + + out = hb.get_beads(mol, "united_atom") + assert len(out) == 2 + calls = [c.args[0] for c in mol.select_atoms.call_args_list] + assert any("index 7" in s for s in calls) + assert any("index 9" in s for s in calls) diff --git a/tests/unit/CodeEntropy/levels/test_level_dag_orchestration.py b/tests/unit/CodeEntropy/levels/test_level_dag_orchestration.py new file mode 100644 index 00000000..17a5a93a --- /dev/null +++ 
b/tests/unit/CodeEntropy/levels/test_level_dag_orchestration.py @@ -0,0 +1,381 @@ +from unittest.mock import MagicMock, patch + +import numpy as np + +from CodeEntropy.levels.level_dag import LevelDAG + + +def _shared(): + return { + "levels": [["united_atom"]], + "frame_counts": {}, + "force_covariances": {}, + "torque_covariances": {}, + "force_counts": {}, + "torque_counts": {}, + "reduced_force_covariances": {}, + "reduced_torque_covariances": {}, + "reduced_force_counts": {}, + "reduced_torque_counts": {}, + "group_id_to_index": {0: 0}, + } + + +def test_execute_sets_default_axes_manager_once(): + dag = LevelDAG() + + shared = { + "reduced_universe": MagicMock(), + "start": 0, + "end": 0, + "step": 1, + } + + dag._run_static_stage = MagicMock() + dag._run_frame_stage = MagicMock() + + dag.execute(shared) + + assert "axes_manager" in shared + dag._run_static_stage.assert_called_once() + dag._run_frame_stage.assert_called_once() + + +def test_run_static_stage_calls_nodes_in_topological_sort_order(): + dag = LevelDAG() + dag._static_graph.add_node("a") + dag._static_graph.add_node("b") + + dag._static_nodes["a"] = MagicMock() + dag._static_nodes["b"] = MagicMock() + + with patch("networkx.topological_sort", return_value=["a", "b"]): + dag._run_static_stage({"X": 1}) + + dag._static_nodes["a"].run.assert_called_once() + dag._static_nodes["b"].run.assert_called_once() + + +def test_run_frame_stage_iterates_selected_frames_and_reduces_each(): + dag = LevelDAG() + + ts0 = MagicMock(frame=10) + ts1 = MagicMock(frame=11) + u = MagicMock() + u.trajectory = [ts0, ts1] + + shared = {"reduced_universe": u, "start": 0, "end": 2, "step": 1} + + dag._frame_dag = MagicMock() + dag._frame_dag.execute_frame.side_effect = [ + { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + } + ] * 2 + dag._reduce_one_frame = MagicMock() + + dag._run_frame_stage(shared) + + assert dag._frame_dag.execute_frame.call_count == 2 + assert 
dag._reduce_one_frame.call_count == 2 + dag._frame_dag.execute_frame.assert_any_call(shared, 10) + dag._frame_dag.execute_frame.assert_any_call(shared, 11) + + +def test_incremental_mean_handles_non_copyable_values(): + out = LevelDAG._incremental_mean(old=None, new=3.0, n=1) + assert out == 3.0 + + +def test_reduce_forcetorque_no_key_is_noop(): + dag = LevelDAG() + shared = { + "forcetorque_covariances": {"res": [None], "poly": [None]}, + "forcetorque_counts": { + "res": np.zeros(1, dtype=int), + "poly": np.zeros(1, dtype=int), + }, + "group_id_to_index": {9: 0}, + } + dag._reduce_forcetorque(shared, frame_out={}) + assert shared["forcetorque_counts"]["res"][0] == 0 + assert shared["forcetorque_covariances"]["res"][0] is None + + +def test_build_registers_static_nodes_and_builds_frame_dag(): + with ( + patch("CodeEntropy.levels.level_dag.DetectMoleculesNode") as _, + patch("CodeEntropy.levels.level_dag.DetectLevelsNode") as _, + patch("CodeEntropy.levels.level_dag.BuildBeadsNode") as _, + patch("CodeEntropy.levels.level_dag.InitCovarianceAccumulatorsNode") as _, + patch("CodeEntropy.levels.level_dag.ComputeConformationalStatesNode") as _, + ): + dag = LevelDAG(universe_operations=MagicMock()) + dag._frame_dag.build = MagicMock() + + dag.build() + + assert "detect_molecules" in dag._static_nodes + assert "detect_levels" in dag._static_nodes + assert "build_beads" in dag._static_nodes + assert "init_covariance_accumulators" in dag._static_nodes + assert "compute_conformational_states" in dag._static_nodes + dag._frame_dag.build.assert_called_once() + + +def test_add_static_adds_dependency_edges(): + dag = LevelDAG() + dag._add_static("A", MagicMock()) + dag._add_static("B", MagicMock(), deps=["A"]) + + assert ("A", "B") in dag._static_graph.edges + + +def test_reduce_force_and_torque_hits_zero_count_branches(): + dag = LevelDAG() + + shared = { + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": [None], "poly": 
[None]}, + "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, + "group_id_to_index": {7: 0}, + } + + frame_out = { + "force": { + "ua": {(7, 0): np.eye(1)}, + "res": {7: np.eye(2)}, + "poly": {7: np.eye(3)}, + }, + "torque": { + "ua": {(7, 0): np.eye(1)}, + "res": {7: np.eye(2)}, + "poly": {7: np.eye(3)}, + }, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert shared["frame_counts"]["ua"][(7, 0)] == 1 + assert (7, 0) in shared["force_covariances"]["ua"] + assert (7, 0) in shared["torque_covariances"]["ua"] + + assert shared["frame_counts"]["res"][0] == 1 + assert shared["frame_counts"]["poly"][0] == 1 + + +def test_reduce_force_and_torque_handles_empty_frame_gracefully(): + dag = LevelDAG() + + shared = { + "group_id_to_index": {0: 0}, + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, + } + + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + } + + dag._reduce_force_and_torque(shared_data=shared, frame_out=frame_out) + + assert shared["force_covariances"]["ua"] == {} + assert shared["torque_covariances"]["ua"] == {} + assert shared["frame_counts"]["res"][0] == 0 + assert shared["frame_counts"]["poly"][0] == 0 + + +def test_reduce_force_and_torque_increments_res_and_poly_counts_from_zero(): + dag = LevelDAG() + + shared = { + "group_id_to_index": {7: 0}, + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, + } + + F = np.eye(3) + T = np.eye(3) * 2 + + frame_out = { + "force": {"ua": {}, "res": {7: F}, "poly": {7: F}}, + "torque": {"ua": {}, "res": {7: T}, "poly": {7: T}}, + } + + dag._reduce_force_and_torque(shared_data=shared, frame_out=frame_out) + + assert shared["frame_counts"]["res"][0] == 1 + assert 
shared["frame_counts"]["poly"][0] == 1 + assert np.allclose(shared["torque_covariances"]["res"][0], T) + assert np.allclose(shared["torque_covariances"]["poly"][0], T) + + +def test_reduce_one_frame_skips_missing_force_and_torque_keys(): + dag = LevelDAG() + shared = _shared() + + bead_key = (0, "united_atom", 0) + frame_out = { + "beads": {bead_key: [1, 2, 3]}, + "counts": {bead_key: 1}, + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + } + + dag._reduce_one_frame(shared_data=shared, frame_out=frame_out) + + assert shared["force_covariances"] == {} + assert shared["torque_covariances"] == {} + + +def test_reduce_force_and_torque_skips_when_counts_are_zero(): + dag = LevelDAG() + shared = _shared() + + k = (0, "united_atom", 0) + shared["force_covariances"][k] = np.eye(3) + shared["torque_covariances"][k] = np.eye(3) + shared["force_counts"][k] = 0 + shared["torque_counts"][k] = 0 + shared["frame_counts"][k] = 0 + + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + "beads": {}, + } + + dag._reduce_force_and_torque(shared_data=shared, frame_out=frame_out) + + assert shared["reduced_force_covariances"] == {} + assert shared["reduced_torque_covariances"] == {} + assert shared["reduced_force_counts"] == {} + assert shared["reduced_torque_counts"] == {} + + +def test_run_static_stage_forwards_progress_when_node_accepts_it(): + dag = LevelDAG() + dag._static_graph.add_node("a") + + node = MagicMock() + dag._static_nodes["a"] = node + + progress = MagicMock() + + with patch("networkx.topological_sort", return_value=["a"]): + dag._run_static_stage({"X": 1}, progress=progress) + + node.run.assert_called_once_with({"X": 1}, progress=progress) + + +def test_run_static_stage_falls_back_when_node_does_not_accept_progress(): + dag = LevelDAG() + dag._static_graph.add_node("a") + + class NoProgressNode: + def run(self, shared_data): + return None + + dag._static_nodes["a"] = 
NoProgressNode() + progress = MagicMock() + + with patch("networkx.topological_sort", return_value=["a"]): + dag._run_static_stage({"X": 1}, progress=progress) # should not raise + + +def test_run_frame_stage_with_progress_creates_task_and_updates_titles(): + dag = LevelDAG() + + ts0 = MagicMock(frame=10) + ts1 = MagicMock(frame=11) + u = MagicMock() + u.trajectory = [ts0, ts1] + + shared = {"reduced_universe": u, "start": 0, "end": 2, "step": 1} + + dag._frame_dag = MagicMock() + dag._frame_dag.execute_frame.return_value = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + } + dag._reduce_one_frame = MagicMock() + + progress = MagicMock() + progress.add_task.return_value = 77 + + dag._run_frame_stage(shared, progress=progress) + + progress.add_task.assert_called_once() + progress.update.assert_any_call(77, title="Frame 10") + progress.update.assert_any_call(77, title="Frame 11") + assert progress.advance.call_count == 2 + + +def test_run_frame_stage_with_negative_end_computes_total_frames(): + dag = LevelDAG() + + ts_list = [MagicMock(frame=i) for i in range(10)] + u = MagicMock() + u.trajectory = ts_list + + shared = { + "reduced_universe": u, + "start": 0, + "end": -1, + "step": 1, + } + + dag._frame_dag = MagicMock() + dag._frame_dag.execute_frame.return_value = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + } + dag._reduce_one_frame = MagicMock() + + progress = MagicMock() + progress.add_task.return_value = 123 + + dag._run_frame_stage(shared, progress=progress) + + progress.add_task.assert_called_once() + _, kwargs = progress.add_task.call_args + assert kwargs["total"] == 9 + + assert progress.advance.call_count == 9 + + +def test_run_frame_stage_progress_total_frames_falls_back_to_none_on_error(): + + dag = LevelDAG() + + class BadTrajectory: + def __len__(self): + raise RuntimeError("boom") + + def __getitem__(self, item): + return [] + + u = type("U", (), {})() 
+ u.trajectory = BadTrajectory() + + shared = { + "reduced_universe": u, + "start": 0, + "end": 10, + "step": 1, + } + + dag._frame_dag = MagicMock() + dag._reduce_one_frame = MagicMock() + + progress = MagicMock() + progress.add_task.return_value = 99 + + dag._run_frame_stage(shared, progress=progress) + + _, kwargs = progress.add_task.call_args + assert kwargs["total"] is None diff --git a/tests/unit/CodeEntropy/levels/test_level_dag_reduce.py b/tests/unit/CodeEntropy/levels/test_level_dag_reduce.py new file mode 100644 index 00000000..cb14bc1a --- /dev/null +++ b/tests/unit/CodeEntropy/levels/test_level_dag_reduce.py @@ -0,0 +1,208 @@ +import numpy as np + +from CodeEntropy.levels.level_dag import LevelDAG + + +def test_incremental_mean_first_sample_copies(): + x = np.array([1.0, 2.0]) + out = LevelDAG._incremental_mean(None, x, n=1) + assert np.allclose(out, x) + x[0] = 999.0 + assert out[0] != 999.0 + + +def test_reduce_force_and_torque_exercises_count_branches(): + dag = LevelDAG() + + shared = { + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, + "group_id_to_index": {7: 0}, + } + + frame_out = { + "force": { + "ua": {(9, 0): np.array([1.0])}, + "res": {7: np.array([2.0])}, + "poly": {7: np.array([3.0])}, + }, + "torque": { + "ua": {(9, 0): np.array([4.0])}, + "res": {7: np.array([5.0])}, + "poly": {7: np.array([6.0])}, + }, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert (9, 0) in shared["torque_covariances"]["ua"] + assert shared["frame_counts"]["res"][0] == 1 + assert shared["frame_counts"]["poly"][0] == 1 + + +def test_reduce_forcetorque_returns_when_missing_key(): + dag = LevelDAG() + shared = { + "forcetorque_covariances": {"res": [None], "poly": [None]}, + "forcetorque_counts": {"res": [0], "poly": [0]}, + "group_id_to_index": {7: 0}, + } + dag._reduce_forcetorque(shared, frame_out={}) + assert 
shared["forcetorque_counts"]["res"][0] == 0 + + +def test_reduce_forcetorque_updates_res_and_poly(): + dag = LevelDAG() + + shared = { + "forcetorque_covariances": {"res": [None], "poly": [None]}, + "forcetorque_counts": {"res": [0], "poly": [0]}, + "group_id_to_index": {7: 0}, + } + + frame_out = { + "forcetorque": { + "res": {7: np.array([1.0, 1.0])}, + "poly": {7: np.array([2.0, 2.0])}, + } + } + + dag._reduce_forcetorque(shared, frame_out) + + assert shared["forcetorque_counts"]["res"][0] == 1 + assert shared["forcetorque_counts"]["poly"][0] == 1 + assert shared["forcetorque_covariances"]["res"][0] is not None + assert shared["forcetorque_covariances"]["poly"][0] is not None + + +def test_reduce_force_and_torque_res_torque_increments_when_res_count_is_zero(): + dag = LevelDAG() + shared = { + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, + "group_id_to_index": {7: 0}, + } + + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {7: np.eye(3)}, "poly": {}}, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert shared["frame_counts"]["res"][0] == 1 + assert shared["torque_covariances"]["res"][0] is not None + + +def test_reduce_force_and_torque_poly_torque_increments_when_poly_count_is_zero(): + dag = LevelDAG() + shared = { + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, + "group_id_to_index": {7: 0}, + } + + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {7: np.eye(3)}}, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert shared["frame_counts"]["poly"][0] == 1 + assert shared["torque_covariances"]["poly"][0] is not None + + +def 
test_reduce_force_and_torque_increments_ua_frame_counts_for_force(): + dag = LevelDAG() + + shared = { + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, + "group_id_to_index": {7: 0}, + } + + k = (9, 0) + frame_out = { + "force": {"ua": {k: np.eye(3)}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert shared["frame_counts"]["ua"][k] == 1 + assert k in shared["force_covariances"]["ua"] + + +def test_reduce_force_and_torque_increments_ua_counts_from_zero(): + dag = LevelDAG() + + key = (9, 0) + F = np.eye(3) + + shared = { + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, + "group_id_to_index": {7: 0}, + } + + frame_out = { + "force": {"ua": {key: F}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert shared["frame_counts"]["ua"][key] == 1 + + np.testing.assert_array_equal(shared["force_covariances"]["ua"][key], F) + + +def test_reduce_force_and_torque_hits_ua_force_count_increment_line(): + dag = LevelDAG() + key = (9, 0) + + shared = { + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, + "group_id_to_index": {7: 0}, + } + + frame_out = { + "force": {"ua": {key: np.eye(3)}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert shared["frame_counts"]["ua"][key] == 1 + + +def test_reduce_force_and_torque_ua_torque_increments_count_when_force_missing_key(): + dag = LevelDAG() + + key = (9, 0) + T = np.eye(3) + 
+ shared = { + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, + "group_id_to_index": {7: 0}, + } + + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {key: T}, "res": {}, "poly": {}}, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert shared["frame_counts"]["ua"][key] == 1 + np.testing.assert_array_equal(shared["torque_covariances"]["ua"][key], T) diff --git a/tests/unit/CodeEntropy/levels/test_level_dag_reduction.py b/tests/unit/CodeEntropy/levels/test_level_dag_reduction.py new file mode 100644 index 00000000..5c8c4716 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/test_level_dag_reduction.py @@ -0,0 +1,91 @@ +from unittest.mock import MagicMock + +import numpy as np + +from CodeEntropy.levels.level_dag import LevelDAG + + +def test_incremental_mean_none_returns_copy_for_numpy(): + arr = np.array([1.0, 2.0]) + out = LevelDAG._incremental_mean(None, arr, n=1) + assert np.all(out == arr) + arr[0] = 999.0 + assert out[0] != 999.0 + + +def test_incremental_mean_updates_mean_correctly(): + old = np.array([2.0, 2.0]) + new = np.array([4.0, 0.0]) + out = LevelDAG._incremental_mean(old, new, n=2) + np.testing.assert_allclose(out, np.array([3.0, 1.0])) + + +def test_reduce_force_and_torque_updates_counts_and_means(): + dag = LevelDAG() + + shared = { + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "frame_counts": { + "ua": {}, + "res": np.zeros(1, dtype=int), + "poly": np.zeros(1, dtype=int), + }, + "group_id_to_index": {9: 0}, + } + + F1 = np.eye(3) + T1 = 2.0 * np.eye(3) + + frame_out = { + "force": {"ua": {(0, 0): F1}, "res": {9: F1}, "poly": {}}, + "torque": {"ua": {(0, 0): T1}, "res": {9: T1}, "poly": {}}, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert 
shared["frame_counts"]["ua"][(0, 0)] == 1 + np.testing.assert_allclose(shared["force_covariances"]["ua"][(0, 0)], F1) + np.testing.assert_allclose(shared["torque_covariances"]["ua"][(0, 0)], T1) + + assert shared["frame_counts"]["res"][0] == 1 + np.testing.assert_allclose(shared["force_covariances"]["res"][0], F1) + np.testing.assert_allclose(shared["torque_covariances"]["res"][0], T1) + + +def test_reduce_forcetorque_no_key_is_noop(): + dag = LevelDAG() + shared = { + "forcetorque_covariances": {"res": [None], "poly": [None]}, + "forcetorque_counts": { + "res": np.zeros(1, dtype=int), + "poly": np.zeros(1, dtype=int), + }, + "group_id_to_index": {9: 0}, + } + + dag._reduce_forcetorque(shared, frame_out={}) + assert shared["forcetorque_counts"]["res"][0] == 0 + assert shared["forcetorque_covariances"]["res"][0] is None + + +def test_run_frame_stage_calls_execute_frame_for_each_ts(simple_ts_list): + dag = LevelDAG() + + u = MagicMock() + u.trajectory = simple_ts_list + + shared = {"reduced_universe": u, "start": 0, "end": 3, "step": 1} + + dag._frame_dag = MagicMock() + dag._frame_dag.execute_frame.side_effect = lambda shared_data, frame_index: { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + } + + dag._reduce_one_frame = MagicMock() + + dag._run_frame_stage(shared) + + assert dag._frame_dag.execute_frame.call_count == 3 + assert dag._reduce_one_frame.call_count == 3 diff --git a/tests/unit/CodeEntropy/levels/test_linalg_matrix_utils.py b/tests/unit/CodeEntropy/levels/test_linalg_matrix_utils.py new file mode 100644 index 00000000..d49b6b11 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/test_linalg_matrix_utils.py @@ -0,0 +1,57 @@ +import numpy as np +import pytest + +from CodeEntropy.levels.linalg import MatrixUtils + + +def test_create_submatrix_outer_product_correct(): + mu = MatrixUtils() + a = np.array([1.0, 2.0, 3.0]) + b = np.array([4.0, 5.0, 6.0]) + + out = mu.create_submatrix(a, b) + + assert out.shape == 
(3, 3) + np.testing.assert_allclose(out, np.outer(a, b)) + + +def test_create_submatrix_rejects_non_3_vectors(): + mu = MatrixUtils() + with pytest.raises(ValueError): + mu.create_submatrix(np.array([1.0, 2.0]), np.array([1.0, 2.0, 3.0])) + + +def test_filter_zero_rows_columns_removes_all_zero_rows_and_cols(): + mu = MatrixUtils() + mat = np.array( + [ + [0.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) + + out = mu.filter_zero_rows_columns(mat) + + assert out.shape == (1, 1) + assert out[0, 0] == 2.0 + + +def test_filter_zero_rows_columns_uses_atol(): + mu = MatrixUtils() + mat = np.array( + [ + [1e-9, 0.0], + [0.0, 1.0], + ] + ) + + out = mu.filter_zero_rows_columns(mat, atol=1e-8) + assert out.shape == (1, 1) + assert out[0, 0] == 1.0 + + +def test_filter_zero_rows_columns_rejects_non_2d(): + mu = MatrixUtils() + with pytest.raises(ValueError): + mu.filter_zero_rows_columns(np.array([1.0, 2.0, 3.0])) diff --git a/tests/unit/CodeEntropy/levels/test_mda_universe_operations.py b/tests/unit/CodeEntropy/levels/test_mda_universe_operations.py new file mode 100644 index 00000000..98e3a7d7 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/test_mda_universe_operations.py @@ -0,0 +1,217 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +from MDAnalysis.exceptions import NoDataError + +from CodeEntropy.levels.mda import UniverseOperations + + +class _FakeAF: + """Fake AnalysisFromFunction that avoids MDAnalysis trajectory requirements.""" + + def __init__(self, func, atomgroup): + self._func = func + self._ag = atomgroup + self.results = {} + + def run(self): + self.results["timeseries"] = self._func(self._ag) + return self + + +def test_extract_timeseries_unknown_kind_raises(): + ops = UniverseOperations() + with pytest.raises(ValueError): + ops._extract_timeseries(MagicMock(), kind="nope") + + +def test_extract_force_timeseries_fallback_to_positions_when_no_forces(): + ops = 
UniverseOperations() + ag_force = MagicMock() + + def _extract(atomgroup, *, kind): + if kind == "forces": + raise NoDataError("no forces") + return np.ones((2, 3, 3)) + + ops._extract_timeseries = MagicMock(side_effect=_extract) + + out = ops._extract_force_timeseries_with_fallback( + ag_force, fallback_to_positions_if_no_forces=True + ) + assert out.shape == (2, 3, 3) + + +def test_extract_force_timeseries_raises_when_no_fallback(): + ops = UniverseOperations() + ops._extract_timeseries = MagicMock(side_effect=NoDataError("no forces")) + + with pytest.raises(NoDataError): + ops._extract_force_timeseries_with_fallback( + MagicMock(), fallback_to_positions_if_no_forces=False + ) + + +def test_select_frames_defaults_start_end_and_slices(monkeypatch): + ops = UniverseOperations() + + u = MagicMock() + u.trajectory = list(range(10)) + u.select_atoms.return_value = MagicMock() + + # timeseries arrays + ops._extract_timeseries = MagicMock( + side_effect=[ + np.zeros((10, 2, 3)), # positions + np.ones((10, 2, 3)), # forces + np.zeros((10, 6)), # dimensions + ] + ) + + merged = MagicMock() + merged.load_new = MagicMock() + monkeypatch.setattr("CodeEntropy.levels.mda.mda.Merge", lambda ag: merged) + + out = ops.select_frames(u, start=None, end=None, step=2) + + assert out is merged + merged.load_new.assert_called_once() + + +def test_merge_forces_scales_kcal(monkeypatch): + ops = UniverseOperations() + + u = MagicMock() + u.select_atoms.return_value = MagicMock() + u_force = MagicMock() + u_force.select_atoms.return_value = MagicMock() + + monkeypatch.setattr( + "CodeEntropy.levels.mda.mda.Universe", MagicMock(side_effect=[u, u_force]) + ) + + ops._extract_timeseries = MagicMock( + side_effect=[ + np.zeros((2, 2, 3)), # coordinates + np.zeros((2, 6)), # dimensions + ] + ) + + forces = np.ones((2, 2, 3), dtype=float) + ops._extract_force_timeseries_with_fallback = MagicMock(return_value=forces) + + merged = MagicMock() + merged.load_new = MagicMock() + 
monkeypatch.setattr("CodeEntropy.levels.mda.mda.Merge", lambda ag: merged) + + out = ops.merge_forces( + tprfile="tpr", + trrfile="trr", + forcefile="force.trr", + fileformat=None, + kcal=True, + ) + + assert out is merged + assert np.allclose(forces, np.ones((2, 2, 3)) * 4.184) + + +def test_select_atoms_builds_merged_universe_and_loads_timeseries(monkeypatch): + ops = UniverseOperations() + + u = MagicMock() + sel = MagicMock() + u.select_atoms.return_value = sel + + monkeypatch.setattr( + ops, + "_extract_timeseries", + lambda _sel, kind: np.zeros((2, 3)) if kind == "positions" else np.zeros((2,)), + ) + + merged = MagicMock() + with ( + patch("CodeEntropy.levels.mda.mda.Merge", return_value=merged) as MergeCls, + patch("CodeEntropy.levels.mda.MemoryReader"), + ): + out = ops.select_atoms(u, "name CA") + + u.select_atoms.assert_called_once_with("name CA", updating=True) + MergeCls.assert_called_once_with(sel) + merged.load_new.assert_called_once() + assert out is merged + + +def test_extract_fragment_selects_by_resindices(monkeypatch): + u = MagicMock() + frag0 = MagicMock() + frag0.indices = np.array([10, 11, 12], dtype=int) + + u.atoms.fragments = [frag0] + + uops = UniverseOperations() + + select_spy = MagicMock(return_value="FRAG") + monkeypatch.setattr(uops, "select_atoms", select_spy) + + out = uops.extract_fragment(u, molecule_id=0) + + assert out == "FRAG" + select_spy.assert_called_once_with(u, "index 10:12") + + +def test_extract_timeseries_kind_positions_returns_xyz_array(): + uops = UniverseOperations() + + ag = MagicMock() + ag.positions = np.array([[1.0, 2.0, 3.0]], dtype=float) + + class _FakeAnalysisFromFunction: + def __init__(self, func, atomgroup): + self.func = func + self.atomgroup = atomgroup + + def run(self): + return SimpleNamespace(results={"timeseries": self.func(self.atomgroup)}) + + with patch( + "CodeEntropy.levels.mda.AnalysisFromFunction", _FakeAnalysisFromFunction + ): + out = uops._extract_timeseries(atomgroup=ag, 
kind="positions") + + assert out.shape == (1, 3) + assert np.allclose(out, np.array([[1.0, 2.0, 3.0]])) + + +def test_extract_timeseries_invalid_kind_raises_value_error(): + uops = UniverseOperations() + ag = MagicMock() + + with pytest.raises(ValueError): + uops._extract_timeseries(atomgroup=ag, kind="not-a-kind") + + +def test_extract_timeseries_forces_branch_uses_forces_copy(): + uops = UniverseOperations() + + ag = MagicMock() + ag.forces = np.array([[1.0, 2.0, 3.0]], dtype=float) + + with patch("CodeEntropy.levels.mda.AnalysisFromFunction", _FakeAF): + out = uops._extract_timeseries(ag, kind="forces") + + assert np.allclose(out, ag.forces) + + +def test_extract_timeseries_dimensions_branch_uses_dimensions_copy(): + uops = UniverseOperations() + + ag = MagicMock() + ag.dimensions = np.array([10.0, 10.0, 10.0, 90, 90, 90], dtype=float) + + with patch("CodeEntropy.levels.mda.AnalysisFromFunction", _FakeAF): + out = uops._extract_timeseries(ag, kind="dimensions") + + assert np.allclose(out, ag.dimensions) diff --git a/tests/unit/CodeEntropy/molecules/test_grouping.py b/tests/unit/CodeEntropy/molecules/test_grouping.py new file mode 100644 index 00000000..53312585 --- /dev/null +++ b/tests/unit/CodeEntropy/molecules/test_grouping.py @@ -0,0 +1,127 @@ +import logging +from unittest.mock import MagicMock + +import pytest + +from CodeEntropy.molecules.grouping import MoleculeGrouper + + +def _universe_with_fragments(fragments): + u = MagicMock() + u.atoms.fragments = fragments + return u + + +def _fragment(names): + f = MagicMock() + f.names = names + return f + + +def test_get_strategy_returns_each_callable(): + g = MoleculeGrouper() + fn = g._get_strategy("each") + assert callable(fn) + assert fn == g._group_each + + +def test_get_strategy_returns_molecules_callable(): + g = MoleculeGrouper() + fn = g._get_strategy("molecules") + assert callable(fn) + assert fn == g._group_by_signature + + +def test_get_strategy_raises_value_error_for_unknown(): + g = 
MoleculeGrouper() + with pytest.raises(ValueError, match="Unknown grouping strategy"): + g._get_strategy("nope") + + +def test_fragments_returns_universe_fragments(): + g = MoleculeGrouper() + frags = [MagicMock(), MagicMock()] + u = _universe_with_fragments(frags) + assert g._fragments(u) is frags + + +def test_num_molecules_counts_fragments_length(): + g = MoleculeGrouper() + u = _universe_with_fragments([MagicMock(), MagicMock(), MagicMock()]) + assert g._num_molecules(u) == 3 + + +def test_group_each_returns_one_group_per_molecule(): + g = MoleculeGrouper() + u = _universe_with_fragments([MagicMock(), MagicMock(), MagicMock()]) + assert g._group_each(u) == {0: [0], 1: [1], 2: [2]} + + +def test_signature_uses_atom_count_and_ordered_names(): + g = MoleculeGrouper() + frag = _fragment(["H", "O", "H"]) + assert g._signature(frag) == (3, ("H", "O", "H")) + + +def test_representative_id_first_seen_sets_rep_and_returns_candidate(): + g = MoleculeGrouper() + cache = {} + sig = (3, ("H", "O", "H")) + rep = g._representative_id(cache, sig, candidate_id=5) + assert rep == 5 + assert cache[sig] == 5 + + +def test_representative_id_returns_existing_rep_when_seen_before(): + g = MoleculeGrouper() + sig = (3, ("H", "O", "H")) + cache = {sig: 2} + rep = g._representative_id(cache, sig, candidate_id=9) + assert rep == 2 + assert cache[sig] == 2 + + +def test_group_by_signature_groups_identical_signatures_and_uses_first_id_as_group_id(): + g = MoleculeGrouper() + f0 = _fragment(["H", "O", "H"]) + f1 = _fragment(["H", "O", "H"]) # same signature as f0 + f2 = _fragment(["C", "C", "H", "H"]) # different signature + u = _universe_with_fragments([f0, f1, f2]) + + out = g._group_by_signature(u) + + assert out == {0: [0, 1], 2: [2]} + + +def test_group_by_signature_is_deterministic_for_first_seen_representative(): + g = MoleculeGrouper() + f0 = _fragment(["B"]) + f1 = _fragment(["A"]) + f2 = _fragment(["B"]) + u = _universe_with_fragments([f0, f1, f2]) + + out = 
g._group_by_signature(u) + + assert out[0] == [0, 2] + assert out[1] == [1] + + +def test_grouping_molecules_dispatches_each_and_logs_summary(caplog): + g = MoleculeGrouper() + u = _universe_with_fragments([MagicMock(), MagicMock()]) + + caplog.set_level(logging.INFO) + out = g.grouping_molecules(u, "each") + + assert out == {0: [0], 1: [1]} + + +def test_grouping_molecules_dispatches_molecules_strategy(): + g = MoleculeGrouper() + f0 = _fragment(["H", "O", "H"]) + f1 = _fragment(["H", "O", "H"]) + u = _universe_with_fragments([f0, f1]) + + out = g.grouping_molecules(u, "molecules") + + assert out == {0: [0, 1]} diff --git a/tests/unit/CodeEntropy/results/test_reporter.py b/tests/unit/CodeEntropy/results/test_reporter.py new file mode 100644 index 00000000..f21b0148 --- /dev/null +++ b/tests/unit/CodeEntropy/results/test_reporter.py @@ -0,0 +1,549 @@ +import json +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock + +import numpy as np +import pandas as pd +from rich.console import Console + +import CodeEntropy.results.reporter as reporter_mod +from CodeEntropy.results.reporter import ResultsReporter, _RichProgressSink + + +class FakeTable: + def __init__(self, title=None, show_lines=None, expand=None): + self.title = title + self.columns = [] + self.rows = [] + + def add_column(self, *args, **kwargs): + self.columns.append((args, kwargs)) + + def add_row(self, *args, **kwargs): + self.rows.append((args, kwargs)) + + +def test_init_uses_provided_console(): + c = Console() + rr = ResultsReporter(console=c) + assert rr.console is c + + +def test_clean_residue_name_removes_dash_like(): + assert ResultsReporter.clean_residue_name("ALA-GLY") == "ALAGLY" + assert ResultsReporter.clean_residue_name("ALA–GLY") == "ALAGLY" + assert ResultsReporter.clean_residue_name("ALA—GLY") == "ALAGLY" + assert ResultsReporter.clean_residue_name(123) == "123" + + +def test_add_results_data_appends(): + rr = ResultsReporter() + 
rr.add_results_data(group_id=1, level="L", entropy_type="T", value=1.23) + assert rr.molecule_data == [(1, "L", "T", 1.23)] + + +def test_add_residue_data_converts_ndarray_frame_count_to_list(): + rr = ResultsReporter() + rr.add_residue_data( + group_id=1, + resname="ALA-1", + level="L", + entropy_type="T", + frame_count=np.array([1, 2, 3]), + value=9.0, + ) + assert rr.residue_data == [[1, "ALA1", "L", "T", [1, 2, 3], 9.0]] + + +def test_add_residue_data_keeps_scalar_frame_count(): + rr = ResultsReporter() + rr.add_residue_data( + group_id=1, + resname="ALA-1", + level="L", + entropy_type="T", + frame_count=7, + value=9.0, + ) + assert rr.residue_data == [[1, "ALA1", "L", "T", 7, 9.0]] + + +def test_add_group_label_stores_metadata(): + rr = ResultsReporter() + rr.add_group_label(1, "protein", residue_count=10, atom_count=100) + assert rr.group_labels[1]["label"] == "protein" + assert rr.group_labels[1]["residue_count"] == 10 + assert rr.group_labels[1]["atom_count"] == 100 + + +def test_gid_sort_key_numeric_before_string_and_numeric_order(): + rr = ResultsReporter() + gids = ["10", "2", "A", "1", "B"] + out = sorted(gids, key=rr._gid_sort_key) + assert out == ["1", "2", "10", "A", "B"] + + +def test_safe_float_valid_invalid(): + assert ResultsReporter._safe_float("1.25") == 1.25 + assert ResultsReporter._safe_float(3) == 3.0 + assert ResultsReporter._safe_float("bad") is None + assert ResultsReporter._safe_float(None) is None + assert ResultsReporter._safe_float(True) is None + + +def test_build_grouped_payload_components_and_total_from_sum(monkeypatch): + rr = ResultsReporter() + mol = pd.DataFrame( + [ + {"Group ID": 1, "Level": "Trans", "Type": "A", "Result (J/mol/K)": 1.0}, + {"Group ID": 1, "Level": "Rovib", "Type": "B", "Result (J/mol/K)": 2.0}, + ] + ) + res = pd.DataFrame([]) + monkeypatch.setattr( + ResultsReporter, "_provenance", staticmethod(lambda: {"git_sha": None}) + ) + payload = rr._build_grouped_payload( + molecule_df=mol, residue_df=res, 
args=None, include_raw_tables=False + ) + comps = payload["groups"]["1"]["components"] + assert comps == {"Trans:A": 1.0, "Rovib:B": 2.0} or comps == { + "Rovib:B": 2.0, + "Trans:A": 1.0, + } + assert payload["groups"]["1"]["total"] == 3.0 + + +def test_build_grouped_payload_prefers_explicit_total(monkeypatch): + rr = ResultsReporter() + mol = pd.DataFrame( + [ + {"Group ID": 1, "Level": "Trans", "Type": "A", "Result (J/mol/K)": 1.0}, + { + "Group ID": 1, + "Level": "Group Total", + "Type": "Total", + "Result (J/mol/K)": 99.0, + }, + {"Group ID": 1, "Level": "Rovib", "Type": "B", "Result (J/mol/K)": 2.0}, + ] + ) + res = pd.DataFrame([]) + monkeypatch.setattr( + ResultsReporter, "_provenance", staticmethod(lambda: {"git_sha": None}) + ) + payload = rr._build_grouped_payload( + molecule_df=mol, residue_df=res, args=None, include_raw_tables=False + ) + assert payload["groups"]["1"]["total"] == 99.0 + assert payload["groups"]["1"]["components"]["Trans:A"] == 1.0 + assert payload["groups"]["1"]["components"]["Rovib:B"] == 2.0 + + +def test_build_grouped_payload_skips_non_numeric_results(monkeypatch): + rr = ResultsReporter() + mol = pd.DataFrame( + [ + {"Group ID": 1, "Level": "Trans", "Type": "A", "Result (J/mol/K)": 1.0}, + {"Group ID": 1, "Level": "Rovib", "Type": "B", "Result (J/mol/K)": "bad"}, + { + "Group ID": 1, + "Level": "Group Total", + "Type": "Total", + "Result (J/mol/K)": None, + }, + ] + ) + res = pd.DataFrame([]) + monkeypatch.setattr( + ResultsReporter, "_provenance", staticmethod(lambda: {"git_sha": None}) + ) + payload = rr._build_grouped_payload( + molecule_df=mol, residue_df=res, args=None, include_raw_tables=False + ) + assert payload["groups"]["1"]["components"] == {"Trans:A": 1.0} + assert payload["groups"]["1"]["total"] == 1.0 + + +def test_build_grouped_payload_invalid_total_row_skipped(monkeypatch): + rr = ResultsReporter() + mol = pd.DataFrame( + [ + { + "Group ID": 1, + "Level": "Group Total", + "Type": "Total", + "Result (J/mol/K)": "bad", 
+ }, + ] + ) + res = pd.DataFrame([]) + monkeypatch.setattr( + ResultsReporter, "_provenance", staticmethod(lambda: {"git_sha": None}) + ) + payload = rr._build_grouped_payload( + molecule_df=mol, residue_df=res, args=None, include_raw_tables=False + ) + assert payload["groups"] == {} + + +def test_build_grouped_payload_include_raw_tables(monkeypatch): + rr = ResultsReporter() + mol = pd.DataFrame( + [{"Group ID": 1, "Level": "Trans", "Type": "A", "Result (J/mol/K)": 1.0}] + ) + res = pd.DataFrame([{"Group ID": 1, "Residue": "ALA", "Result": 0.5}]) + monkeypatch.setattr( + ResultsReporter, "_provenance", staticmethod(lambda: {"git_sha": None}) + ) + payload = rr._build_grouped_payload( + molecule_df=mol, residue_df=res, args=None, include_raw_tables=True + ) + assert "molecule_data" in payload + assert "residue_data" in payload + assert payload["molecule_data"][0]["Group ID"] == 1 + assert payload["residue_data"][0]["Group ID"] == 1 + + +def test_serialize_args_dict_converts_ndarray_and_path(): + p = Path("x/y") + args = {"arr": np.array([1, 2]), "p": p, "n": 3} + out = ResultsReporter._serialize_args(args) + assert out == {"arr": [1, 2], "p": str(p), "n": 3} + + +def test_serialize_args_namespace_converts_types(): + ns = SimpleNamespace(a=np.array([1]), b=Path("z")) + assert ResultsReporter._serialize_args(ns) == {"a": [1], "b": "z"} + + +def test_serialize_args_falls_back_to_dict_protocol(): + class PairIterable: + def __iter__(self): + return iter([("k", 1)]) + + assert ResultsReporter._serialize_args(PairIterable()) == {"k": 1} + + +def test_serialize_args_unserializable_returns_empty(): + class Unserializable: + __slots__ = () + + assert ResultsReporter._serialize_args(Unserializable()) == {} + + +def test_provenance_sets_version_none_on_failure(monkeypatch): + import importlib.metadata + + monkeypatch.setattr( + importlib.metadata, + "version", + lambda _: (_ for _ in ()).throw(Exception("nope")), + ) + monkeypatch.setattr(ResultsReporter, "_try_get_git_sha", 
staticmethod(lambda: None)) + prov = ResultsReporter._provenance() + assert "python" in prov + assert "platform" in prov + assert prov["codeentropy_version"] is None + assert prov["git_sha"] is None + + +def test_provenance_sets_version_on_success(monkeypatch): + import importlib.metadata + + monkeypatch.setattr(importlib.metadata, "version", lambda _: "9.9.9") + monkeypatch.setattr( + ResultsReporter, "_try_get_git_sha", staticmethod(lambda: "sha") + ) + prov = ResultsReporter._provenance() + assert prov["codeentropy_version"] == "9.9.9" + assert prov["git_sha"] == "sha" + + +def test_try_get_git_sha_env_override(monkeypatch): + monkeypatch.setenv("CODEENTROPY_GIT_SHA", "abc123") + assert ResultsReporter._try_get_git_sha() == "abc123" + + +def test_try_get_git_sha_subprocess_success(monkeypatch, tmp_path): + monkeypatch.delenv("CODEENTROPY_GIT_SHA", raising=False) + fake_file = tmp_path / "a" / "b" / "c.py" + fake_file.parent.mkdir(parents=True, exist_ok=True) + (tmp_path / "a" / ".git").mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(reporter_mod, "__file__", str(fake_file)) + mock_run = MagicMock() + mock_run.return_value = SimpleNamespace( + returncode=0, stdout="deadbeef\n", stderr="" + ) + monkeypatch.setattr(reporter_mod.subprocess, "run", mock_run) + assert ResultsReporter._try_get_git_sha() == "deadbeef" + + +def test_try_get_git_sha_subprocess_failure(monkeypatch, tmp_path): + monkeypatch.delenv("CODEENTROPY_GIT_SHA", raising=False) + fake_file = tmp_path / "a" / "b" / "c.py" + fake_file.parent.mkdir(parents=True, exist_ok=True) + (tmp_path / "a" / ".git").mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(reporter_mod, "__file__", str(fake_file)) + mock_run = MagicMock() + mock_run.return_value = SimpleNamespace(returncode=1, stdout="", stderr="err") + monkeypatch.setattr(reporter_mod.subprocess, "run", mock_run) + assert ResultsReporter._try_get_git_sha() is None + + +def test_try_get_git_sha_returns_none_when_no_git_anywhere(monkeypatch, 
tmp_path): + monkeypatch.delenv("CODEENTROPY_GIT_SHA", raising=False) + fake_file = tmp_path / "a" / "b" / "c.py" + fake_file.parent.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(reporter_mod, "__file__", str(fake_file)) + monkeypatch.setattr( + reporter_mod.subprocess, + "run", + lambda *a, **k: (_ for _ in ()).throw( + AssertionError("subprocess.run should not be called") + ), + ) + assert ResultsReporter._try_get_git_sha() is None + + +def test_try_get_git_sha_executes_subprocess_kwargs_block(monkeypatch, tmp_path): + monkeypatch.delenv("CODEENTROPY_GIT_SHA", raising=False) + fake_file = tmp_path / "a" / "b" / "c.py" + fake_file.parent.mkdir(parents=True, exist_ok=True) + (tmp_path / "a" / ".git").mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(reporter_mod, "__file__", str(fake_file)) + + original_resolve = reporter_mod.Path.resolve + + def fake_resolve(self): + if str(self) == str(fake_file): + return fake_file + return original_resolve(self) + + monkeypatch.setattr(reporter_mod.Path, "resolve", fake_resolve) + + mock_run = MagicMock() + mock_run.return_value = SimpleNamespace(returncode=0, stdout="sha\n", stderr="") + monkeypatch.setattr(reporter_mod.subprocess, "run", mock_run) + + assert ResultsReporter._try_get_git_sha() == "sha" + _args, kwargs = mock_run.call_args + assert "stdout" in kwargs + assert "stderr" in kwargs + assert kwargs.get("text") is True + + +def test_log_grouped_results_tables_hits_non_total_add_row(monkeypatch): + rr = ResultsReporter() + rr.add_results_data("1", "AAA", "BBB", 123.0) + printed = [] + monkeypatch.setattr(reporter_mod, "Table", FakeTable) + monkeypatch.setattr(reporter_mod.console, "print", lambda t: printed.append(t)) + rr._log_grouped_results_tables() + assert len(printed) == 1 + assert printed[0].rows == [(("AAA", "BBB", "123.0"), {})] + + +def test_log_grouped_results_tables_prints_in_sorted_gid_order(monkeypatch): + rr = ResultsReporter() + rr.add_results_data("10", "L", "T", 1.0) + 
rr.add_results_data("2", "L", "T", 2.0) + rr.add_results_data("A", "L", "T", 3.0) + printed_titles = [] + monkeypatch.setattr( + reporter_mod.console, + "print", + lambda obj: printed_titles.append(getattr(obj, "title", None)), + ) + rr._log_grouped_results_tables() + assert printed_titles[0].startswith("Entropy Results — Group 2") + assert printed_titles[1].startswith("Entropy Results — Group 10") + assert printed_titles[2].startswith("Entropy Results — Group A") + + +def test_log_residue_table_grouped_prints_table(monkeypatch): + rr = ResultsReporter() + rr.add_group_label("2", "ResidLabel") + rr.residue_data.append(["2", "ALA", "LevelX", "TypeY", 10, 0.5]) + printed = [] + monkeypatch.setattr(reporter_mod.console, "print", lambda obj: printed.append(obj)) + rr._log_residue_table_grouped() + assert len(printed) == 1 + assert getattr(printed[0], "title", "").startswith("Residue Entropy — Group 2") + + +def test_log_group_label_table_hits_label_add_column(monkeypatch): + rr = ResultsReporter() + rr.add_group_label("1", "LabelHere", residue_count=2, atom_count=3) + printed = [] + monkeypatch.setattr(reporter_mod, "Table", FakeTable) + monkeypatch.setattr(reporter_mod.console, "print", lambda t: printed.append(t)) + rr._log_group_label_table() + assert len(printed) == 1 + assert printed[0].columns[1][0][0] == "Label" + + +def test_log_tables_calls_all_subtables(monkeypatch): + rr = ResultsReporter() + m1 = MagicMock() + m2 = MagicMock() + m3 = MagicMock() + monkeypatch.setattr(rr, "_log_grouped_results_tables", m1) + monkeypatch.setattr(rr, "_log_residue_table_grouped", m2) + monkeypatch.setattr(rr, "_log_group_label_table", m3) + rr.log_tables() + m1.assert_called_once() + m2.assert_called_once() + m3.assert_called_once() + + +def test_save_dataframes_as_json_writes_file(tmp_path, monkeypatch): + rr = ResultsReporter() + mol = pd.DataFrame( + [{"Group ID": 1, "Level": "Trans", "Type": "A", "Result (J/mol/K)": 1.0}] + ) + res = pd.DataFrame([]) + monkeypatch.setattr( 
+ ResultsReporter, "_provenance", staticmethod(lambda: {"git_sha": None}) + ) + out = tmp_path / "out.json" + rr.save_dataframes_as_json( + mol, res, str(out), args={"x": 1}, include_raw_tables=False + ) + data = json.loads(out.read_text()) + assert data["args"] == {"x": 1} + assert data["provenance"] == {"git_sha": None} + assert data["groups"]["1"]["components"] == {"Trans:A": 1.0} + assert data["groups"]["1"]["total"] == 1.0 + + +def test_save_dataframes_as_json_uses_default_include_raw_tables(tmp_path, monkeypatch): + rr = ResultsReporter() + mol = pd.DataFrame( + [{"Group ID": 1, "Level": "L", "Type": "T", "Result (J/mol/K)": 1.0}] + ) + res = pd.DataFrame([]) + monkeypatch.setattr( + ResultsReporter, "_provenance", staticmethod(lambda: {"git_sha": None}) + ) + out = tmp_path / "out.json" + rr.save_dataframes_as_json(mol, res, str(out), args={"x": 1}) + assert out.exists() + + +def test_log_grouped_results_tables_returns_when_empty(monkeypatch): + rr = ResultsReporter() + monkeypatch.setattr( + reporter_mod.console, + "print", + lambda *_: (_ for _ in ()).throw(AssertionError("should not print")), + ) + rr._log_grouped_results_tables() + + +def test_log_grouped_results_tables_handles_total_row(monkeypatch): + rr = ResultsReporter() + rr.add_results_data("1", "A", "B", 1.0) + rr.add_results_data("1", "Group Total", "Group Total", 3.0) + + printed = [] + monkeypatch.setattr(reporter_mod, "Table", FakeTable) + monkeypatch.setattr(reporter_mod.console, "print", lambda t: printed.append(t)) + + rr._log_grouped_results_tables() + + assert len(printed) == 1 + assert ("Group Total", "Group Total", "3.0") in [r[0] for r in printed[0].rows] + + +def test_log_residue_table_grouped_returns_when_empty(monkeypatch): + rr = ResultsReporter() + monkeypatch.setattr( + reporter_mod.console, + "print", + lambda *_: (_ for _ in ()).throw(AssertionError("should not print")), + ) + rr._log_residue_table_grouped() + + +def test_log_group_label_table_returns_when_empty(monkeypatch): 
+ rr = ResultsReporter() + monkeypatch.setattr( + reporter_mod.console, + "print", + lambda *_: (_ for _ in ()).throw(AssertionError("should not print")), + ) + rr._log_group_label_table() + + +def test_try_get_git_sha_returns_none_on_exception(monkeypatch): + monkeypatch.delenv("CODEENTROPY_GIT_SHA", raising=False) + + def boom(self): + raise RuntimeError("boom") + + monkeypatch.setattr(reporter_mod.Path, "resolve", boom) + assert ResultsReporter._try_get_git_sha() is None + + +def test_progress_context_yields_progress_sink(): + rr = ResultsReporter() + with rr.progress(transient=True) as p: + assert hasattr(p, "add_task") + assert hasattr(p, "update") + assert hasattr(p, "advance") + + +def test_progress_sink_update_normalizes_none_title(monkeypatch): + rr = ResultsReporter() + + with rr.progress(transient=True) as sink: + inner = sink._progress + spy = MagicMock() + monkeypatch.setattr(inner, "update", spy) + + sink.update(1, title=None) + + spy.assert_called_once() + _args, kwargs = spy.call_args + assert kwargs["title"] == "" + + +def test_rich_progress_sink_add_task_sets_default_title(): + inner = MagicMock() + inner.add_task.return_value = 7 + + sink = _RichProgressSink(inner) + task_id = sink.add_task("Stage", total=3) + + assert task_id == 7 + + inner.add_task.assert_called_once() + args, kwargs = inner.add_task.call_args + + assert args[0] == "Stage" + assert kwargs["total"] == 3 + assert kwargs["title"] == "" + + +def test_rich_progress_sink_update_normalizes_title_none(): + inner = MagicMock() + sink = _RichProgressSink(inner) + + sink.update(99, title=None) + + inner.update.assert_called_once_with(99, title="") + + +def test_gid_sort_key_handles_non_numeric_group_id(): + assert ResultsReporter._gid_sort_key("abc") == (1, "abc") + + +def test_rich_progress_sink_advance_forwards_to_inner_progress(): + inner = MagicMock() + sink = _RichProgressSink(inner) + + sink.advance(123, step=5) + + inner.advance.assert_called_once_with(123, 5)