def _find_degenerate_xgrid_faces(x, y, z, threshold_factor=10):
    """Flag structured-grid cells whose extent is anomalously large.

    Such cells typically appear when the mesh is incomplete (e.g., missing
    lon/lat points stored as 0.0). They can span a large portion of the
    underlying hash grid and lead to high memory requirements for the hash table.

    Detection is based on straight-line (chord) distances between the four
    corner nodes of each cell: the four edges and both diagonals. A cell is
    flagged when the longest of these six chords exceeds ``threshold_factor``
    multiplied by the 99th percentile of that per-cell maximum, taken over all
    cells with finite coordinates.

    Cells touching a non-finite (NaN/inf) node are excluded from the percentile
    and are never flagged, so a stray NaN cannot disable detection for the rest
    of the grid. A grid with fewer than two nodes along either axis has no
    cells and yields an empty mask.

    NOTE(review): because the reference length is the 99th percentile, a grid in
    which more than roughly 1% of the cells are oversized may raise the
    threshold itself and go undetected.

    Parameters
    ----------
    x, y, z : ndarray, shape (ny, nx)
        Unit-sphere Cartesian coordinates of the grid nodes.
    threshold_factor : float, optional
        Multiplier applied to the 99th-percentile chord length to set the
        threshold. Default is 10.

    Returns
    -------
    degenerate : ndarray of bool, shape (ny-1, nx-1)
        True for each cell whose maximum chord length exceeds the threshold.
    """

    def _chord(p1, p2):
        # Euclidean distance between matching nodes, shape (ny-1, nx-1)
        return np.sqrt(((p1 - p2) ** 2).sum(axis=-1))

    # Stack to (ny, nx, 3) so that each corner slice below is (ny-1, nx-1, 3)
    pts = np.stack([x, y, z], axis=-1)
    c00, c01 = pts[:-1, :-1], pts[:-1, 1:]
    c10, c11 = pts[1:, :-1], pts[1:, 1:]

    # Maximum chord across all four edges and both diagonals
    max_chord = np.maximum.reduce(
        [
            _chord(c00, c01),
            _chord(c10, c11),
            _chord(c00, c10),
            _chord(c01, c11),
            _chord(c00, c11),
            _chord(c01, c10),
        ]
    )

    degenerate = np.zeros(max_chord.shape, dtype=bool)

    # np.percentile propagates NaN and is undefined for empty input, so the
    # reference length is computed from the finite cells only.
    finite = np.isfinite(max_chord)
    if not finite.any():
        return degenerate

    threshold = threshold_factor * np.percentile(max_chord[finite], 99)
    degenerate[finite] = max_chord[finite] > threshold
    return degenerate