Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/ci_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@v7

- uses: pyvista/setup-headless-display-action@v3

- name: Set up Conda
uses: conda-incubator/setup-miniconda@v4
with:
Expand All @@ -26,7 +28,7 @@ jobs:
- name: Create Conda environment
shell: bash -l {0}
run: |
conda install -c conda-forge python pip fenics-dolfinx=${{ matrix.dolfinx }} scifem adios4dolfinx
conda install -c conda-forge python pip fenics-dolfinx=${{ matrix.dolfinx }} scifem adios4dolfinx pyvista

- name: Install local package and dependencies
shell: bash -l {0}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci_docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:

- name: Install local package and dependencies
run: |
python -m pip install scipy # needed in scifem, can be removed when scifem adds it to their dependencies
python -m pip install scipy pyvista # scipy needed in scifem, can be removed when scifem adds it to their dependencies
python -m pip install .[test]

- name: Overload adios4dolfinx
Expand Down
1 change: 1 addition & 0 deletions src/festim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from .mesh.mesh import Mesh
from .mesh.mesh_1d import Mesh1D
from .mesh.mesh_from_xdmf import MeshFromXDMF
from .plotting import plot
from .problem import ProblemBase
from .reaction import Reaction
from .settings import Settings
Expand Down
108 changes: 108 additions & 0 deletions src/festim/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from pathlib import Path

from festim.species import Species

DEFAULT_TITLE_FONT_SIZE = 12


def _normalize_fields(field: Species | list[Species]) -> list[Species]:
if isinstance(field, Species):
return [field]
if isinstance(field, list) and all(isinstance(f, Species) for f in field):
return field
raise TypeError("field must be of type festim.Species or a list of festim.Species")


def _get_solution(field: Species, subdomain=None):
if subdomain is None:
if field.post_processing_solution is None:
raise ValueError(
f"Species {field.name} has no post_processing_solution to plot."
)
return field.post_processing_solution
else:
if field.post_processing_solution is not None:
raise ValueError(
"Problem seems to be HydrogenTransportProblem but a subdomain"
" was provided."
)
if not field.subdomain_to_post_processing_solution:
raise ValueError(
f"Species {field.name} has no subdomain post-processing solutions."
)
if subdomain not in field.subdomain_to_post_processing_solution:
raise ValueError(
f"Species {field.name} has no post-processing solution on subdomain "
f"{subdomain}."
)
return field.subdomain_to_post_processing_solution[subdomain]


def _make_ugrid(solution, pyvista_module, name="c"):
from dolfinx import plot as dolfinx_plot

topology, cell_types, geometry = dolfinx_plot.vtk_mesh(solution.function_space)
u_grid = pyvista_module.UnstructuredGrid(topology, cell_types, geometry)
u_grid.point_data[name] = solution.x.array.real
u_grid.set_active_scalars(name)
return u_grid


def plot(
field: Species | list[Species],
subdomain=None,
filename: str | Path | None = None,
show_edges: bool = False,
split_colourbars: bool = False,
**kwargs,
):
"""
Plot one or several species fields with pyvista.

Args:
field: one species or a list of species.
subdomain: optional volume subdomain used in mixed-domain problems.
filename: optional output image path. If provided, a screenshot is saved.
show_edges: whether to show mesh edges.
split_colourbars: whether to use a different colourbar for each species.
**kwargs: additional arguments forwarded to ``pyvista.Plotter.add_mesh``.
"""
try:
import pyvista
except ImportError as import_error:
raise ImportError(
"pyvista is required for plotting. Install it with `pip install pyvista`."
) from import_error

fields = _normalize_fields(field)
shape = (1, len(fields)) if len(fields) > 1 else (1, 1)
plotter = pyvista.Plotter(shape=shape)

for i, spe in enumerate(fields):
if len(fields) > 1:
plotter.subplot(0, i)
# if subdomain is None but the species has .subdomain_to_post_processing_solution,
# we need to plot on all subdomains
if subdomain is None:
if spe.subdomain_to_post_processing_solution:
for solution in spe.subdomain_to_post_processing_solution.values():
u_grid = _make_ugrid(solution, pyvista)
plotter.add_mesh(u_grid, show_edges=show_edges, **kwargs)
continue

solution = _get_solution(spe, subdomain=subdomain)
u_grid = _make_ugrid(
solution, pyvista, name=spe.name if split_colourbars else "c"
)
plotter.add_mesh(u_grid, show_edges=show_edges, **kwargs)
plotter.view_xy()
if spe.name:
plotter.add_text(spe.name, font_size=DEFAULT_TITLE_FONT_SIZE)

if filename is not None:
plotter.show()
plotter.screenshot(str(filename))
elif not pyvista.OFF_SCREEN:
plotter.show()

return plotter
116 changes: 116 additions & 0 deletions test/test_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from mpi4py import MPI

import numpy as np
import pytest
from dolfinx import fem, mesh

import festim as F

pyvista = pytest.importorskip("pyvista")
pyvista.OFF_SCREEN = True


def create_mock_solution():
test_mesh = mesh.create_unit_interval(MPI.COMM_WORLD, 10)
V = fem.functionspace(test_mesh, ("Lagrange", 1))
u = fem.Function(V)
u.x.array[:] = np.ones_like(u.x.array)
return u


def test_plot_single_species():
species = F.Species("H")
species.post_processing_solution = create_mock_solution()

plotter = F.plot(species, show_edges=True, opacity=0.5)

assert isinstance(plotter, pyvista.Plotter)
assert plotter.shape == (1, 1)
plotter.close()


def test_plot_multiple_species_creates_subplots():
h = F.Species("H")
d = F.Species("D")
h.post_processing_solution = create_mock_solution()
d.post_processing_solution = create_mock_solution()

plotter = F.plot([h, d])

assert plotter.shape == (1, 2)
plotter.close()


def test_plot_subdomain_uses_subdomain_solution():
material = F.Material(D_0=1, E_D=0)
vol_1 = F.VolumeSubdomain(id=1, material=material)
vol_2 = F.VolumeSubdomain(id=2, material=material)
sol_1 = create_mock_solution()
sol_2 = create_mock_solution()

h = F.Species("H")
h.subdomain_to_post_processing_solution = {vol_1: sol_1, vol_2: sol_2}

plotter = F.plot(h, subdomain=vol_2)
assert isinstance(plotter, pyvista.Plotter)
plotter.close()


def test_plot_with_filename_saves_screenshot(tmp_path):
species = F.Species("H")
species.post_processing_solution = create_mock_solution()
filename = tmp_path / "out.png"

plotter = F.plot(species, filename=filename)
plotter.close()
assert filename.exists()


def test_plot_with_string_filename_saves_screenshot(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)

species = F.Species("H")
species.post_processing_solution = create_mock_solution()

plotter = F.plot(species, filename="out.png")
plotter.close()

assert (tmp_path / "out.png").exists()


def test_plot_raises_for_invalid_field_type():
with pytest.raises(
TypeError,
match=r"field must be of type festim\.Species or a list of festim\.Species",
):
F.plot("H")


def test_plot_raises_if_no_solution():
with pytest.raises(ValueError, match="has no post_processing_solution to plot"):
F.plot(F.Species("H"))


def test_plot_default_show_edges_and_empty_name():
species = F.Species()
species.post_processing_solution = create_mock_solution()

plotter = F.plot(species)

assert isinstance(plotter, pyvista.Plotter)
plotter.close()


def test_plot_default_with_several_subdomains():
material = F.Material(D_0=1, E_D=0)
vol_1 = F.VolumeSubdomain(id=1, material=material)
vol_2 = F.VolumeSubdomain(id=2, material=material)
sol_1 = create_mock_solution()
sol_2 = create_mock_solution()

h = F.Species("H")
h.subdomain_to_post_processing_solution = {vol_1: sol_1, vol_2: sol_2}

plotter = F.plot(h)
assert isinstance(plotter, pyvista.Plotter)
plotter.close()
Loading