diff --git a/.github/workflows/ci_conda.yml b/.github/workflows/ci_conda.yml index a1b07ad16..613466ff7 100644 --- a/.github/workflows/ci_conda.yml +++ b/.github/workflows/ci_conda.yml @@ -15,6 +15,8 @@ jobs: - name: Checkout code uses: actions/checkout@v7 + - uses: pyvista/setup-headless-display-action@v3 + - name: Set up Conda uses: conda-incubator/setup-miniconda@v4 with: @@ -26,7 +28,7 @@ jobs: - name: Create Conda environment shell: bash -l {0} run: | - conda install -c conda-forge python pip fenics-dolfinx=${{ matrix.dolfinx }} scifem adios4dolfinx + conda install -c conda-forge python pip fenics-dolfinx=${{ matrix.dolfinx }} scifem adios4dolfinx pyvista - name: Install local package and dependencies shell: bash -l {0} diff --git a/.github/workflows/ci_docker.yml b/.github/workflows/ci_docker.yml index 8996d0893..66fc498db 100644 --- a/.github/workflows/ci_docker.yml +++ b/.github/workflows/ci_docker.yml @@ -20,7 +20,7 @@ jobs: - name: Install local package and dependencies run: | - python -m pip install scipy # needed in scifem, can be removed when scifem adds it to their dependencies + python -m pip install scipy pyvista # scipy needed in scifem, can be removed when scifem adds it to their dependencies python -m pip install .[test] - name: Overload adios4dolfinx diff --git a/src/festim/__init__.py b/src/festim/__init__.py index b2095f861..722474873 100644 --- a/src/festim/__init__.py +++ b/src/festim/__init__.py @@ -72,6 +72,7 @@ from .mesh.mesh import Mesh from .mesh.mesh_1d import Mesh1D from .mesh.mesh_from_xdmf import MeshFromXDMF +from .plotting import plot from .problem import ProblemBase from .reaction import Reaction from .settings import Settings diff --git a/src/festim/plotting.py b/src/festim/plotting.py new file mode 100644 index 000000000..ea99d0907 --- /dev/null +++ b/src/festim/plotting.py @@ -0,0 +1,108 @@ +from pathlib import Path + +from festim.species import Species + +DEFAULT_TITLE_FONT_SIZE = 12 + + +def _normalize_fields(field: Species | list[Species]) -> list[Species]: + if isinstance(field, Species): + return [field] + if isinstance(field, list) and all(isinstance(f, Species) for f in field): + return field + raise TypeError("field must be of type festim.Species or a list of festim.Species") + + +def _get_solution(field: Species, subdomain=None): + if subdomain is None: + if field.post_processing_solution is None: + raise ValueError( + f"Species {field.name} has no post_processing_solution to plot." + ) + return field.post_processing_solution + else: + if field.post_processing_solution is not None: + raise ValueError( + "Problem seems to be HydrogenTransportProblem but a subdomain" + " was provided." + ) + if not field.subdomain_to_post_processing_solution: + raise ValueError( + f"Species {field.name} has no subdomain post-processing solutions." + ) + if subdomain not in field.subdomain_to_post_processing_solution: + raise ValueError( + f"Species {field.name} has no post-processing solution on subdomain " + f"{subdomain}." + ) + return field.subdomain_to_post_processing_solution[subdomain] + + +def _make_ugrid(solution, pyvista_module, name="c"): + from dolfinx import plot as dolfinx_plot + + topology, cell_types, geometry = dolfinx_plot.vtk_mesh(solution.function_space) + u_grid = pyvista_module.UnstructuredGrid(topology, cell_types, geometry) + u_grid.point_data[name] = solution.x.array.real + u_grid.set_active_scalars(name) + return u_grid + + +def plot( + field: Species | list[Species], + subdomain=None, + filename: str | Path | None = None, + show_edges: bool = False, + split_colourbars: bool = False, + **kwargs, +): + """ + Plot one or several species fields with pyvista. + + Args: + field: one species or a list of species. + subdomain: optional volume subdomain used in mixed-domain problems. + filename: optional output image path. If provided, a screenshot is saved. + show_edges: whether to show mesh edges. + split_colourbars: whether to use a different colourbar for each species. + **kwargs: additional arguments forwarded to ``pyvista.Plotter.add_mesh``. + """ + try: + import pyvista + except ImportError as import_error: + raise ImportError( + "pyvista is required for plotting. Install it with `pip install pyvista`." + ) from import_error + + fields = _normalize_fields(field) + shape = (1, len(fields)) if len(fields) > 1 else (1, 1) + plotter = pyvista.Plotter(shape=shape) + + for i, spe in enumerate(fields): + if len(fields) > 1: + plotter.subplot(0, i) + # if subdomain is None but the species has .subdomain_to_post_processing_solution, + # we need to plot on all subdomains + if subdomain is None: + if spe.subdomain_to_post_processing_solution: + for solution in spe.subdomain_to_post_processing_solution.values(): + u_grid = _make_ugrid(solution, pyvista) + plotter.add_mesh(u_grid, show_edges=show_edges, **kwargs) + continue + + solution = _get_solution(spe, subdomain=subdomain) + u_grid = _make_ugrid( + solution, pyvista, name=spe.name if split_colourbars else "c" + ) + plotter.add_mesh(u_grid, show_edges=show_edges, **kwargs) + plotter.view_xy() + if spe.name: + plotter.add_text(spe.name, font_size=DEFAULT_TITLE_FONT_SIZE) + + if filename is not None: + plotter.show() + plotter.screenshot(str(filename)) + elif not pyvista.OFF_SCREEN: + plotter.show() + + return plotter diff --git a/test/test_plot.py b/test/test_plot.py new file mode 100644 index 000000000..f0130e315 --- /dev/null +++ b/test/test_plot.py @@ -0,0 +1,116 @@ +from mpi4py import MPI + +import numpy as np +import pytest +from dolfinx import fem, mesh + +import festim as F + +pyvista = pytest.importorskip("pyvista") +pyvista.OFF_SCREEN = True + + +def create_mock_solution(): + test_mesh = mesh.create_unit_interval(MPI.COMM_WORLD, 10) + V = fem.functionspace(test_mesh, ("Lagrange", 1)) + u = fem.Function(V) + u.x.array[:] = np.ones_like(u.x.array) + return u + + +def test_plot_single_species(): + species = F.Species("H") + species.post_processing_solution = create_mock_solution() + + plotter = F.plot(species, show_edges=True, opacity=0.5) + + assert isinstance(plotter, pyvista.Plotter) + assert plotter.shape == (1, 1) + plotter.close() + + +def test_plot_multiple_species_creates_subplots(): + h = F.Species("H") + d = F.Species("D") + h.post_processing_solution = create_mock_solution() + d.post_processing_solution = create_mock_solution() + + plotter = F.plot([h, d]) + + assert plotter.shape == (1, 2) + plotter.close() + + +def test_plot_subdomain_uses_subdomain_solution(): + material = F.Material(D_0=1, E_D=0) + vol_1 = F.VolumeSubdomain(id=1, material=material) + vol_2 = F.VolumeSubdomain(id=2, material=material) + sol_1 = create_mock_solution() + sol_2 = create_mock_solution() + + h = F.Species("H") + h.subdomain_to_post_processing_solution = {vol_1: sol_1, vol_2: sol_2} + + plotter = F.plot(h, subdomain=vol_2) + assert isinstance(plotter, pyvista.Plotter) + plotter.close() + + +def test_plot_with_filename_saves_screenshot(tmp_path): + species = F.Species("H") + species.post_processing_solution = create_mock_solution() + filename = tmp_path / "out.png" + + plotter = F.plot(species, filename=filename) + plotter.close() + assert filename.exists() + + +def test_plot_with_string_filename_saves_screenshot(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + + species = F.Species("H") + species.post_processing_solution = create_mock_solution() + + plotter = F.plot(species, filename="out.png") + plotter.close() + + assert (tmp_path / "out.png").exists() + + +def test_plot_raises_for_invalid_field_type(): + with pytest.raises( + TypeError, + match=r"field must be of type festim\.Species or a list of festim\.Species", + ): + F.plot("H") + + +def test_plot_raises_if_no_solution(): + with pytest.raises(ValueError, match="has no post_processing_solution to plot"): + F.plot(F.Species("H")) + + +def test_plot_default_show_edges_and_empty_name(): + species = F.Species() + species.post_processing_solution = create_mock_solution() + + plotter = F.plot(species) + + assert isinstance(plotter, pyvista.Plotter) + plotter.close() + + +def test_plot_default_with_several_subdomains(): + material = F.Material(D_0=1, E_D=0) + vol_1 = F.VolumeSubdomain(id=1, material=material) + vol_2 = F.VolumeSubdomain(id=2, material=material) + sol_1 = create_mock_solution() + sol_2 = create_mock_solution() + + h = F.Species("H") + h.subdomain_to_post_processing_solution = {vol_1: sol_1, vol_2: sol_2} + + plotter = F.plot(h) + assert isinstance(plotter, pyvista.Plotter) + plotter.close()