Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 81 additions & 12 deletions meshmode/discretization/connection/opposite_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import numpy as np
import numpy.linalg as la
from meshmode.discretization.connection.direct import InterpolationBatch

import logging
logger = logging.getLogger(__name__)
Expand All @@ -43,15 +44,13 @@ def _make_cross_face_batches(actx,
i_tgt_grp, i_src_grp,
tgt_bdry_element_indices, src_bdry_element_indices):

from meshmode.discretization.connection.direct import InterpolationBatch
if tgt_bdry_discr.dim == 0:
yield InterpolationBatch(
return [InterpolationBatch(
from_group_index=i_src_grp,
from_element_indices=freeze_from_numpy(actx, src_bdry_element_indices),
to_element_indices=freeze_from_numpy(actx, tgt_bdry_element_indices),
result_unit_nodes=src_bdry_discr.groups[i_src_grp].unit_nodes,
to_element_face=None)
return
to_element_face=None)]

tgt_bdry_nodes = np.array([
thaw_to_numpy(actx, ary[i_tgt_grp])[tgt_bdry_element_indices]
Expand All @@ -68,11 +67,72 @@ def _make_cross_face_batches(actx,
src_mesh_grp = src_bdry_discr.mesh.groups[i_src_grp]
src_grp = src_bdry_discr.groups[i_src_grp]

src_unit_nodes = _find_src_unit_nodes_by_matching(
tgt_bdry_nodes=tgt_bdry_nodes,
src_bdry_nodes=src_bdry_nodes,
src_grp=src_grp, tol=tol)
if src_unit_nodes is None:
src_unit_nodes = _find_src_unit_nodes_via_gauss_newton(
tgt_bdry_nodes=tgt_bdry_nodes,
src_bdry_nodes=src_bdry_nodes,
src_grp=src_grp, src_mesh_grp=src_mesh_grp,
tgt_bdry_discr=tgt_bdry_discr, src_bdry_discr=src_bdry_discr,
tol=tol)

return list(_find_src_unit_nodes_batches(
actx=actx, src_unit_nodes=src_unit_nodes,
i_src_grp=i_src_grp,
tgt_bdry_element_indices=tgt_bdry_element_indices,
src_bdry_element_indices=src_bdry_element_indices,
tol=tol))

# }}}


# {{{ _find_src_unit_nodes_by_matching

def _find_src_unit_nodes_by_matching(
tgt_bdry_nodes,
src_bdry_nodes,
src_grp, tol):
ambient_dim, nelements, ntgt_unit_nodes = tgt_bdry_nodes.shape

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe check whether ntgt_unit_nodes == nsrc_unit_nodes first before going on to the pairwise comparison stuff?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the thinking, but the criterion you propose is actually too restrictive. The comparison is also supposed to work in the (ideally common, see #225) case of face restrictions, where the face nodes are a subset of the volume nodes.

dist_vecs = (tgt_bdry_nodes.reshape(ambient_dim, nelements, -1, 1)
- src_bdry_nodes.reshape(ambient_dim, nelements, 1, -1))
Comment on lines +100 to +101
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to think of a faster way to do this, but I wasn't able to come up with anything that didn't involve either SpatialBinaryTreeBucket (which might not be faster most of the time) or rand(). 🙂

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah... it is clearly an O(nunit_nodes^2) algorithm; my hope is just that numpy crushes it. Anything with a Python loop will be way slower. I am a bit worried about the size of that temporary. If it's too big, we can always do a Python loop for one of those axes.


# shape: (nelements, num_tgt_nodes, num_source_nodes)
is_close = la.norm(dist_vecs, axis=0, ord=2) < tol

num_close_vertices = np.sum(is_close.astype(np.int32), axis=-1)
if not (num_close_vertices == 1).all():
return None

# Success: it's just a permutation
source_indices = np.where(is_close)[-1].reshape(nelements, ntgt_unit_nodes)

# check
matched_src_bdry_nodes = src_bdry_nodes[
:, np.arange(nelements).reshape(-1, 1), source_indices]
dist_vecs = tgt_bdry_nodes - matched_src_bdry_nodes
is_close = la.norm(dist_vecs, axis=0, ord=2) < tol
assert is_close.all()

return src_grp.unit_nodes[:, source_indices]

# }}}


# {{{ _find_src_unit_nodes_via_gauss_newton

def _find_src_unit_nodes_via_gauss_newton(
tgt_bdry_nodes,
src_bdry_nodes,
src_grp, src_mesh_grp,
tgt_bdry_discr, src_bdry_discr,
tol):
dim = src_grp.dim
_, nelements, ntgt_unit_nodes = tgt_bdry_nodes.shape

# {{{ invert face map (using Gauss-Newton)

initial_guess = np.mean(src_mesh_grp.vertex_unit_coordinates(), axis=0)
src_unit_nodes = np.empty((dim, nelements, ntgt_unit_nodes))
src_unit_nodes[:] = initial_guess.reshape(-1, 1, 1)
Expand Down Expand Up @@ -162,7 +222,7 @@ def get_map_jacobian(unit_nodes):

# }}}

logger.debug("_make_cross_face_batches: begin gauss-newton")
logger.debug("_find_src_unit_nodes_via_gauss_newton: begin")

niter = 0
while True:
Expand Down Expand Up @@ -223,18 +283,27 @@ def get_map_jacobian(unit_nodes):
max_resid = np.max(np.abs(resid))

if max_resid < tol:
logger.debug("_make_cross_face_batches: gauss-newton: done, "
logger.debug("_find_src_unit_nodes_via_gauss_newton: done, "
"final residual: %g", max_resid)
break
return src_unit_nodes

niter += 1
if niter > 10:
raise RuntimeError("Gauss-Newton (for finding opposite-face reference "
"coordinates) did not converge (residual: %g)" % max_resid)

# }}}
raise AssertionError()

# {{{ find batches of src_unit_nodes
# }}}


# {{{ _find_src_unit_nodes_batches

def _find_src_unit_nodes_batches(
actx, src_unit_nodes, i_src_grp,
tgt_bdry_element_indices, src_bdry_element_indices,
tol):
dim, nelements, _ = src_unit_nodes.shape

done_elements = np.zeros(nelements, dtype=bool)
while True:
Expand All @@ -261,7 +330,7 @@ def get_map_jacobian(unit_nodes):
result_unit_nodes=template_unit_nodes,
to_element_face=None)

# }}}
# }}}


def _find_ibatch_for_face(vbc_tgt_grp_batches, iface):
Expand Down
16 changes: 2 additions & 14 deletions test/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,8 @@ def test_bdry_restriction_is_permutation(actx_factory, group_factory, dim, order

assert connection_is_permutation(actx, bdry_connection)

is_lgl = group_factory is LegendreGaussLobattoTensorProductGroupFactory

# FIXME: This should pass unconditionally
should_pass = (
(dim == 3 and order < 2)
or (dim == 2 and not is_lgl)
or (dim == 2 and is_lgl and order < 4)
)

if should_pass:
opp_face = make_opposite_face_connection(actx, bdry_connection)
assert connection_is_permutation(actx, opp_face)
else:
pytest.xfail("https://github.com/inducer/meshmode/pull/105")
opp_face = make_opposite_face_connection(actx, bdry_connection)
assert connection_is_permutation(actx, opp_face)


if __name__ == "__main__":
Expand Down