Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import ufl
import finat.ufl
from finat.quadrature import QuadratureRule

from ufl.cell import CellSequence
from ufl.duals import is_dual, is_primal
Expand Down Expand Up @@ -416,6 +417,31 @@ def broken_space(self):
self.mesh(), finat.ufl.BrokenElement(self.ufl_element()),
name=f"{self.name}_broken" if self.name else None)

def quadrature_space(self):
"""Return a :class:`.WithGeometryBase` with a ``Quadrature`` element
defined on the point set required for interpolating external data into this space.

Returns
-------
WithGeometryBase :
The new function space with a ``Quadrature`` FiniteElement.
"""
ufl_element = self.ufl_element()
if not self.finat_element.has_pointwise_dual_basis:
# Grab the point set for interpolation
_, ps = self.finat_element.dual_basis
# Invalidate the weights. This quadrature scheme is not for integration.
weights = numpy.full(len(ps.points), numpy.nan)
quad_scheme = QuadratureRule(ps, weights, self.finat_element.cell)

ufl_element = finat.ufl.FiniteElement("Quadrature",
cell=ufl_element.cell,
degree=self.finat_element.degree,
quad_scheme=quad_scheme)
if self.value_shape:
ufl_element = finat.ufl.TensorElement(ufl_element, shape=self.value_shape)
return self.collapse().reconstruct(element=ufl_element)

def reconstruct(
self,
mesh: MeshGeometry | None = None,
Expand Down
38 changes: 13 additions & 25 deletions firedrake/mg/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,11 @@
from enum import IntEnum
from firedrake.petsc import PETSc
from firedrake.embedding import get_embedding_dg_element

from finat.element_factory import create_element

__all__ = ("TransferManager", )


native_families = frozenset(["Lagrange", "Discontinuous Lagrange", "Real", "Q", "DQ", "BrokenElement", "Crouzeix-Raviart", "Kong-Mulder-Veldhuizen"])
alfeld_families = frozenset(["Hsieh-Clough-Tocher", "Reduced-Hsieh-Clough-Tocher", "Johnson-Mercier",
"Alfeld-Sorokina", "Arnold-Qin", "Reduced-Arnold-Qin", "Christiansen-Hu",
"Guzman-Neilan", "Guzman-Neilan Bubble"])
non_native_variants = frozenset(["integral", "fdm", "alfeld"])


def get_embedding_element(element, value_shape):
broken_cg = element.sobolev_space in {ufl.H1, ufl.H2}
dg_element = get_embedding_dg_element(element, value_shape, broken_cg=broken_cg)
variant = element.variant() or "default"
family = element.family()
# Elements on Alfeld splits are embedded onto DG Powell-Sabin.
# This yields supermesh projection
if (family in alfeld_families) or ("alfeld" in variant.lower() and family != "Discontinuous Lagrange"):
dg_element = dg_element.reconstruct(variant="powell-sabin")
return dg_element


class Op(IntEnum):
PROLONG = 0
RESTRICT = 1
Expand Down Expand Up @@ -68,14 +49,21 @@ def __init__(self, *, native_transfers=None, use_averaging=True):
self.caches = {}

def is_native(self, element, op):
if element in self.native_transfers.keys():
if element in self.native_transfers:
return self.native_transfers[element][op] is not None
if isinstance(element.cell, ufl.TensorProductCell):
if isinstance(element, finat.ufl.TensorProductElement):
return all(self.is_native(e, op) for e in element.factor_elements)
elif isinstance(element, finat.ufl.MixedElement):
return all(self.is_native(e, op) for e in element.sub_elements)
return (element.family() in native_families) and not (element.variant() in non_native_variants)

# Can we interpolate into this element?
finat_element = create_element(element)
try:
finat_element.dual_basis
return True
except NotImplementedError:
return False

def _native_transfer(self, element, op):
try:
Expand Down Expand Up @@ -253,8 +241,8 @@ def op(self, source, target, transfer_op):
if not self.requires_transfer(Vs, transfer_op, source, target):
return

if all(self.is_native(e, transfer_op) for e in (source_element, target_element)):
self._native_transfer(source_element, transfer_op)(source, target)
if self.is_native(target_element, transfer_op):
self._native_transfer(target_element, transfer_op)(source, target)
elif type(source_element) is finat.ufl.MixedElement:
assert type(target_element) is finat.ufl.MixedElement
for source_, target_ in zip(source.subfunctions, target.subfunctions):
Expand Down Expand Up @@ -318,7 +306,7 @@ def restrict(self, source, target):
if not self.requires_transfer(Vs_star, Op.RESTRICT, source, target):
return

if all(self.is_native(e, Op.RESTRICT) for e in (source_element, target_element)):
if self.is_native(source_element, Op.RESTRICT):
self._native_transfer(source_element, Op.RESTRICT)(source, target)
elif type(source_element) is finat.ufl.MixedElement:
assert type(target_element) is finat.ufl.MixedElement
Expand Down
126 changes: 76 additions & 50 deletions firedrake/mg/interface.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pyop2 import op2

import firedrake
from firedrake import ufl_expr
from firedrake import ufl_expr, dmhooks
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from firedrake.petsc import PETSc
from ufl.duals import is_dual
from . import utils
Expand All @@ -13,10 +14,10 @@

def check_arguments(coarse, fine, needs_dual=False):
if is_dual(coarse) != needs_dual:
expected_type = firedrake.Cofunction if needs_dual else firedrake.Function
expected_type = Cofunction if needs_dual else Function
raise TypeError("Coarse argument is a %s, not a %s" % (type(coarse).__name__, expected_type.__name__))
if is_dual(fine) != needs_dual:
expected_type = firedrake.Cofunction if needs_dual else firedrake.Function
expected_type = Cofunction if needs_dual else Function
raise TypeError("Fine argument is a %s, not a %s" % (type(fine).__name__, expected_type.__name__))
cfs = coarse.function_space()
ffs = fine.function_space()
Expand All @@ -41,7 +42,7 @@ def prolong(coarse, fine):
if len(Vc) != len(Vf):
raise ValueError("Mixed spaces have different lengths")
for in_, out in zip(coarse.subfunctions, fine.subfunctions):
manager = firedrake.dmhooks.get_transfer_manager(in_.function_space().dm)
manager = dmhooks.get_transfer_manager(in_.function_space().dm)
manager.prolong(in_, out)
return fine

Expand All @@ -58,20 +59,26 @@ def prolong(coarse, fine):
repeat = (fine_level - coarse_level)*refinements_per_level
next_level = coarse_level * refinements_per_level

if needs_quadrature := not Vf.finat_element.has_pointwise_dual_basis:
# Introduce an intermediate quadrature target space
Vf = Vf.quadrature_space()

finest = fine
Vfinest = finest.function_space()
meshes = hierarchy._meshes
for j in range(repeat):
next_level += 1
if j == repeat - 1:
next = fine
Vf = fine.function_space()
if j == repeat - 1 and not needs_quadrature:
fine = finest
else:
Vf = Vc.reconstruct(mesh=meshes[next_level])
next = firedrake.Function(Vf)
fine = Function(Vf.reconstruct(mesh=meshes[next_level]))
Vf = fine.function_space()
Vc = coarse.function_space()

coarse_coords = get_coordinates(Vc)
fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc)
fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space())
kernel = kernels.prolong_kernel(coarse)
kernel = kernels.prolong_kernel(coarse, Vf)

# XXX: Should be able to figure out locations by pushing forward
# reference cell node locations to physical space.
Expand All @@ -82,13 +89,17 @@ def prolong(coarse, fine):
for d in [coarse, coarse_coords]:
d.dat.global_to_local_begin(op2.READ)
d.dat.global_to_local_end(op2.READ)
op2.par_loop(kernel, next.node_set,
next.dat(op2.WRITE),
op2.par_loop(kernel, fine.node_set,
fine.dat(op2.WRITE),
coarse.dat(op2.READ, fine_to_coarse),
node_locations.dat(op2.READ),
coarse_coords.dat(op2.READ, fine_to_coarse_coords))
coarse = next
Vc = Vf

if needs_quadrature:
# Transfer to the actual target space
new_fine = finest if j == repeat-1 else Function(Vfinest.reconstruct(mesh=meshes[next_level]))
fine = new_fine.interpolate(fine)
coarse = fine
return fine


Expand All @@ -101,7 +112,7 @@ def restrict(fine_dual, coarse_dual):
if len(Vc) != len(Vf):
raise ValueError("Mixed spaces have different lengths")
for in_, out in zip(fine_dual.subfunctions, coarse_dual.subfunctions):
manager = firedrake.dmhooks.get_transfer_manager(in_.function_space().dm)
manager = dmhooks.get_transfer_manager(in_.function_space().dm)
manager.restrict(in_, out)
return coarse_dual

Expand All @@ -118,17 +129,25 @@ def restrict(fine_dual, coarse_dual):
repeat = (fine_level - coarse_level)*refinements_per_level
next_level = fine_level * refinements_per_level

meshes = hierarchy._meshes
if needs_quadrature := not Vf.finat_element.has_pointwise_dual_basis:
# Introduce an intermediate quadrature source space
Vq = Vf.quadrature_space()

coarsest = coarse_dual.zero()
meshes = hierarchy._meshes
for j in range(repeat):
if needs_quadrature:
# Transfer to the quadrature source space
fine_dual = Function(Vq.reconstruct(mesh=meshes[next_level])).interpolate(fine_dual)

next_level -= 1
if j == repeat - 1:
coarse_dual.dat.zero()
next = coarse_dual
coarse_dual = coarsest
else:
Vc = Vf.reconstruct(mesh=meshes[next_level])
next = firedrake.Cofunction(Vc)
Vc = next.function_space()
coarse_dual = Function(Vc.reconstruct(mesh=meshes[next_level]))
Vf = fine_dual.function_space()
Vc = coarse_dual.function_space()

# XXX: Should be able to figure out locations by pushing forward
# reference cell node locations to physical space.
# x = \sum_i c_i \phi_i(x_hat)
Expand All @@ -144,12 +163,11 @@ def restrict(fine_dual, coarse_dual):
d.dat.global_to_local_end(op2.READ)
kernel = kernels.restrict_kernel(Vf, Vc)
op2.par_loop(kernel, fine_dual.node_set,
next.dat(op2.INC, fine_to_coarse),
coarse_dual.dat(op2.INC, fine_to_coarse),
fine_dual.dat(op2.READ),
node_locations.dat(op2.READ),
coarse_coords.dat(op2.READ, fine_to_coarse_coords))
fine_dual = next
Vf = Vc
fine_dual = coarse_dual
return coarse_dual


Expand All @@ -162,7 +180,7 @@ def inject(fine, coarse):
if len(Vc) != len(Vf):
raise ValueError("Mixed spaces have different lengths")
for in_, out in zip(fine.subfunctions, coarse.subfunctions):
manager = firedrake.dmhooks.get_transfer_manager(in_.function_space().dm)
manager = dmhooks.get_transfer_manager(in_.function_space().dm)
manager.inject(in_, out)
return

Expand All @@ -184,46 +202,50 @@ def inject(fine, coarse):
# For DG, for each coarse cell, instead:
# solve inner(u_c, v_c)*dx_c == inner(f, v_c)*dx_c

kernel, dg = kernels.inject_kernel(Vf, Vc)
hierarchy, coarse_level = utils.get_level(ufl_expr.extract_unique_domain(coarse))
if dg and not hierarchy.nested:
raise NotImplementedError("Sorry, we can't do supermesh projections yet!")
_, fine_level = utils.get_level(ufl_expr.extract_unique_domain(fine))
refinements_per_level = hierarchy.refinements_per_level
repeat = (fine_level - coarse_level)*refinements_per_level
next_level = fine_level * refinements_per_level

meshes = hierarchy._meshes
if needs_quadrature := not Vc.finat_element.has_pointwise_dual_basis:
# Introduce an intermediate quadrature target space
Vc = Vc.quadrature_space()

kernel, dg = kernels.inject_kernel(Vf, Vc)
if dg and not hierarchy.nested:
raise NotImplementedError("Sorry, we can't do supermesh projections yet!")

coarsest = coarse.zero()
Vcoarsest = coarsest.function_space()
meshes = hierarchy._meshes
for j in range(repeat):
next_level -= 1
if j == repeat - 1:
coarse.dat.zero()
next = coarse
Vc = next.function_space()
if j == repeat - 1 and not needs_quadrature:
coarse = coarsest
else:
Vc = Vf.reconstruct(mesh=meshes[next_level])
next = firedrake.Function(Vc)
coarse = Function(Vc.reconstruct(mesh=meshes[next_level]))
Vc = coarse.function_space()
Vf = fine.function_space()
if not dg:
node_locations = utils.physical_node_locations(Vc)

fine_coords = get_coordinates(Vf)
coarse_node_to_fine_nodes = utils.coarse_node_to_fine_node_map(Vc, Vf)
coarse_node_to_fine_coords = utils.coarse_node_to_fine_node_map(Vc, fine_coords.function_space())
coarse_to_fine = utils.coarse_node_to_fine_node_map(Vc, Vf)
coarse_to_fine_coords = utils.coarse_node_to_fine_node_map(Vc, fine_coords.function_space())

node_locations = utils.physical_node_locations(Vc)
# Have to do this, because the node set core size is not right for
# this expanded stencil
for d in [fine, fine_coords]:
d.dat.global_to_local_begin(op2.READ)
d.dat.global_to_local_end(op2.READ)
op2.par_loop(kernel, next.node_set,
next.dat(op2.INC),
op2.par_loop(kernel, coarse.node_set,
coarse.dat(op2.WRITE),
fine.dat(op2.READ, coarse_to_fine),
node_locations.dat(op2.READ),
fine.dat(op2.READ, coarse_node_to_fine_nodes),
fine_coords.dat(op2.READ, coarse_node_to_fine_coords))
fine_coords.dat(op2.READ, coarse_to_fine_coords))
else:
coarse_coords = Vc.mesh().coordinates
fine_coords = Vf.mesh().coordinates
coarse_coords = get_coordinates(Vc)
fine_coords = get_coordinates(Vf)
coarse_cell_to_fine_nodes = utils.coarse_cell_to_fine_node_map(Vc, Vf)
coarse_cell_to_fine_coords = utils.coarse_cell_to_fine_node_map(Vc, fine_coords.function_space())
# Have to do this, because the node set core size is not right for
Expand All @@ -232,18 +254,22 @@ def inject(fine, coarse):
d.dat.global_to_local_begin(op2.READ)
d.dat.global_to_local_end(op2.READ)
op2.par_loop(kernel, Vc.mesh().cell_set,
next.dat(op2.INC, next.cell_node_map()),
coarse.dat(op2.INC, coarse.cell_node_map()),
fine.dat(op2.READ, coarse_cell_to_fine_nodes),
fine_coords.dat(op2.READ, coarse_cell_to_fine_coords),
coarse_coords.dat(op2.READ, coarse_coords.cell_node_map()))
fine = next
Vf = Vc

if needs_quadrature:
# Transfer to the actual target space
new_coarse = coarsest if j == repeat - 1 else Function(Vcoarsest.reconstruct(mesh=meshes[next_level]))
coarse = new_coarse.interpolate(coarse)
fine = coarse
return coarse


def get_coordinates(V):
coords = V.mesh().coordinates
if V.boundary_set:
W = V.reconstruct(element=coords.function_space().ufl_element())
coords = firedrake.Function(W).interpolate(coords)
coords = Function(W).interpolate(coords)
return coords
Loading