Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions flopy4/mf6/codec/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Any

import numpy as np
import xarray as xr
from modflow_devtools.dfns.schema.field import Field
from modflow_devtools.dfns.schema.v2 import FieldType
Expand All @@ -12,13 +13,13 @@ def field_type(value: Any) -> FieldType:

if isinstance(value, Field):
return value.type
if isinstance(value, bool):
if isinstance(value, (bool, np.bool)):
return "keyword"
if isinstance(value, int):
if isinstance(value, (int, np.integer)):
return "integer"
if isinstance(value, float):
if isinstance(value, (float, np.floating)):
return "double"
if isinstance(value, str):
if isinstance(value, (str, np.str_)):
return "string"
if isinstance(value, (dict, tuple)):
return "record"
Expand Down
15 changes: 6 additions & 9 deletions flopy4/mf6/gwf/sto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from numpy.typing import NDArray
from xattree import xattree

from flopy4.mf6.constants import LENBOUNDNAME
from flopy4.mf6.converter import structure_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field, path
from flopy4.mf6.utils.grid import update_maxbound
from flopy4.utils import to_path


Expand Down Expand Up @@ -72,17 +74,12 @@ class Sto(Package):
converter=Converter(structure_array, takes_self=True, takes_field=True),
longname="specific yield",
)
steady_state: Optional[NDArray[np.bool_]] = array(
storage: Optional[NDArray[np.str_]] = array(
dtype=f"<U{LENBOUNDNAME}",
block="period",
dims=("nper",),
default=None,
converter=Converter(structure_array, takes_self=True, takes_field=True),
longname="steady state indicator",
)
transient: Optional[NDArray[np.bool_]] = array(
block="period",
dims=("nper",),
default=None,
converter=Converter(structure_array, takes_self=True, takes_field=True),
longname="transient indicator",
on_setattr=update_maxbound,
longname="storage type",
)
71 changes: 28 additions & 43 deletions flopy4/mf6/utils/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any

import numpy as np
import sparse
import xarray as xr
import xugrid as xu
from attrs import fields
Expand All @@ -11,7 +10,7 @@
from xarray.core.indexes import PandasIndex
from xattree import Scalar

from flopy4.mf6.constants import FILL_DNODATA, FILL_INT64
from flopy4.mf6.constants import FILL_INT64


class StructuredGrid(LegacyStructuredGrid):
Expand Down Expand Up @@ -1319,55 +1318,41 @@ def get_coords(grid: LegacyStructuredGrid) -> dict[str, Any]:
return coords


def update_maxbound(instance, attribute, new_value):
"""
Generalized function to update maxbound when period block arrays change.

This function automatically finds all period block arrays in the instance
and calculates maxbound based on the maximum number of non-default values
across all arrays.
def get_default_fill_value(dtype):
if np.issubdtype(dtype, np.integer):
return 0
elif np.issubdtype(dtype, np.floating):
return np.nan
elif dtype.kind in ["U", "S"]:
return ""
else:
return None

Args:
instance: The package instance
attribute: The attribute being set (from attrs on_setattr)
new_value: The new value being set

Returns:
The new_value (unchanged)
def update_maxbound(instance, attribute, new_value):
"""
Recompute maxbound after period block data has changed.
Called by attrs on_setattr hook.
"""

period_arrays = []
instance_fields = fields(instance.__class__)
for f in instance_fields:
if (
bounds = []
for f in fields(instance.__class__):
if not (
f.metadata
and f.metadata.get("block") == "period"
and f.metadata.get("xattree", {}).get("dims")
):
period_arrays.append(f.name)

maxbound_values = []
for array_name in period_arrays:
if attribute and attribute.name == array_name:
array_val = new_value
continue # select period block arrays
if attribute and attribute.name == f.name:
d = new_value
else:
array_val = getattr(instance, array_name, None)

if array_val is not None:
if isinstance(array_val.data, sparse.SparseArray):
# densify if the array is sparse
array_data = array_val.data.todense()
else:
# handle memoryview and other array-likes
array_data = np.asarray(array_val.data)

if array_data.dtype.kind in ["U", "S"]: # String arrays
non_default_count = len(np.where(array_data != "")[0])
else: # Numeric arrays
non_default_count = len(np.where(array_data != FILL_DNODATA)[0])

maxbound_values.append(non_default_count)
if maxbound_values:
instance.maxbound = max(maxbound_values)
d = getattr(instance, f.name, None)
if d is None:
continue
d = np.asarray(d.data) # handle memoryview etc
fill = f.metadata.get("fill", get_default_fill_value(d.dtype))
bounds.append((d != fill).sum())
if bounds:
instance.maxbound = max(bounds)

return new_value
Loading
Loading