diff --git a/doc/changes/DM-54976.api.md b/doc/changes/DM-54976.api.md new file mode 100644 index 00000000..e1e1dfbd --- /dev/null +++ b/doc/changes/DM-54976.api.md @@ -0,0 +1,3 @@ +Removed the `obs_info` component from `Image`, `Mask`, and `MaskedImage`, in favor of defining it directly on `VisitImage`. + +Fully unified the butler formatters into `lsst.images.formatters.GenericFormatter` and deleted the old ones. diff --git a/python/lsst/images/_backgrounds.py b/python/lsst/images/_backgrounds.py index 123e2b97..6a876aea 100644 --- a/python/lsst/images/_backgrounds.py +++ b/python/lsst/images/_backgrounds.py @@ -20,7 +20,7 @@ import pydantic from .fields import Field, FieldSerializationModel -from .serialization import ArchiveTree, InputArchive, OutputArchive +from .serialization import ArchiveTree, InputArchive, InvalidParameterError, OutputArchive @dataclasses.dataclass(frozen=True) @@ -152,8 +152,9 @@ class BackgroundMapSerializationModel(ArchiveTree): description="Name of the background that was subtracted, or `None` if no background was subtracted.", ) - def deserialize(self, archive: InputArchive[Any]) -> BackgroundMap: - """Read a background map from an archive.""" + def deserialize(self, archive: InputArchive[Any], **kwargs: Any) -> BackgroundMap: + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for BackgroundMap: {set(kwargs.keys())}.") return BackgroundMap( [ Background( diff --git a/python/lsst/images/_color_image.py b/python/lsst/images/_color_image.py index a677efe3..2170bf28 100644 --- a/python/lsst/images/_color_image.py +++ b/python/lsst/images/_color_image.py @@ -25,7 +25,7 @@ from ._geom import Box from ._image import Image, ImageSerializationModel from ._transforms import Projection, ProjectionSerializationModel -from .serialization import ArchiveTree, InputArchive, MetadataValue, OutputArchive +from .serialization import ArchiveTree, InputArchive, InvalidParameterError, MetadataValue, OutputArchive from .utils import is_none @@ -201,7 +201,9 @@ def bbox(self) -> Box: """The bounding box of the image.""" return self.red.bbox - def deserialize(self, archive: InputArchive[Any], *, bbox: Box | None = None) -> ColorImage: + def deserialize( + self, archive: InputArchive[Any], *, bbox: Box | None = None, **kwargs: Any + ) -> ColorImage: """Deserialize a image from an input archive. Parameters @@ -214,6 +216,8 @@ def deserialize(self, archive: InputArchive[Any], *, bbox: Box | None = None) -> bbox Bounding box of a subimage to read instead. """ + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for ColoImage: {set(kwargs.keys())}.") r = self.red.deserialize(archive, bbox=bbox) g = self.green.deserialize(archive, bbox=bbox) b = self.blue.deserialize(archive, bbox=bbox) diff --git a/python/lsst/images/_image.py b/python/lsst/images/_image.py index 7c3ae9a2..ac2dca55 100644 --- a/python/lsst/images/_image.py +++ b/python/lsst/images/_image.py @@ -24,7 +24,6 @@ import numpy as np import numpy.typing as npt import pydantic -from astro_metadata_translator import ObservationInfo from lsst.resources import ResourcePath, ResourcePathExpression @@ -39,6 +38,7 @@ InlineArrayModel, InlineArrayQuantityModel, InputArchive, + InvalidParameterError, MetadataValue, OutputArchive, no_header_updates, @@ -72,9 +72,6 @@ class Image(GeneralizedImage): Units for the image's pixel values. projection Projection that maps the pixel grid to the sky. - obs_info - General information about the associated observation in standardized - form. metadata Arbitrary flexible metadata to associate with the image. @@ -111,7 +108,6 @@ def __init__( dtype: npt.DTypeLike | None = None, unit: astropy.units.UnitBase | None = None, projection: Projection[Any] | None = None, - obs_info: ObservationInfo | None = None, metadata: dict[str, MetadataValue] | None = None, ): super().__init__(metadata) @@ -140,7 +136,6 @@ def __init__( self._bbox: Box = bbox self._unit = unit self._projection = projection - self._obs_info = obs_info @property def array(self) -> np.ndarray: @@ -192,26 +187,13 @@ def projection(self) -> Projection[Any] | None: """ return self._projection - @property - def obs_info(self) -> ObservationInfo | None: - """General information about the associated observation in standard - form. (`~astro_metadata_translator.ObservationInfo` | `None`). - """ - return self._obs_info - def __getitem__(self, bbox: Box | EllipsisType) -> Image: if bbox is ...: return self super().__getitem__(bbox) indices = bbox.slice_within(self._bbox) return self._transfer_metadata( - Image( - self._array[indices], - bbox=bbox, - unit=self._unit, - projection=self._projection, - obs_info=self._obs_info, - ), + Image(self._array[indices], bbox=bbox, unit=self._unit, projection=self._projection), bbox=bbox, ) @@ -235,13 +217,7 @@ def __eq__(self, other: object) -> bool: def copy(self) -> Image: return self._transfer_metadata( - Image( - self._array.copy(), - bbox=self._bbox, - unit=self._unit, - projection=self._projection, - obs_info=self._obs_info, - ), + Image(self._array.copy(), bbox=self._bbox, unit=self._unit, projection=self._projection), copy=True, ) @@ -251,7 +227,6 @@ def view( unit: astropy.units.UnitBase | None | EllipsisType = ..., projection: Projection | None | EllipsisType = ..., start: Sequence[int] | EllipsisType = ..., - obs_info: ObservationInfo | None | EllipsisType = ..., ) -> Image: """Make a view of the image, with optional updates.""" if unit is ...: @@ -260,11 +235,7 @@ def view( projection = self._projection if start is ...: start = self._bbox.start - if obs_info is ...: - obs_info = self._obs_info - return self._transfer_metadata( - Image(self._array, start=start, unit=unit, projection=projection, obs_info=obs_info) - ) + return self._transfer_metadata(Image(self._array, start=start, unit=unit, projection=projection)) def serialize[P: pydantic.BaseModel]( self, @@ -272,7 +243,6 @@ def serialize[P: pydantic.BaseModel]( *, update_header: Callable[[astropy.io.fits.Header], None] = no_header_updates, save_projection: bool = True, - save_obs_info: bool = True, add_offset_wcs: str | None = "A", ) -> ImageSerializationModel[P]: """Serialize the image to an output archive. @@ -291,10 +261,6 @@ def serialize[P: pydantic.BaseModel]( is one. This does not affect whether a FITS WCS corresponding to the projection is written (it always is, if available, and if ``add_offset_wcs`` is not ``" "``). - save_obs_info - If `True`, save the - `~astro_metadata_translator.ObservationInfo` attached to the - image, if there is one. add_offset_wcs A FITS WCS single-character suffix to use when adding a linear WCS that maps the FITS array to the logical pixel coordinates @@ -326,7 +292,6 @@ def _update_header(header: astropy.io.fits.Header) -> None: data=data, start=list(self.bbox.start), projection=serialized_projection, - obs_info=self._obs_info if save_obs_info else None, metadata=self.metadata, ) @@ -532,11 +497,6 @@ class ImageSerializationModel[P: pydantic.BaseModel](ArchiveTree): exclude_if=is_none, description="Projection that maps the logical pixel grid onto the sky.", ) - obs_info: ObservationInfo | None = pydantic.Field( - default=None, - exclude_if=is_none, - description="Standardized description of image metadata", - ) @property def bbox(self) -> Box: @@ -554,6 +514,7 @@ def deserialize( *, bbox: Box | None = None, strip_header: Callable[[astropy.io.fits.Header], None] = no_header_updates, + **kwargs: Any, ) -> Image: """Deserialize an image from an input archive. @@ -567,7 +528,12 @@ def deserialize( A callable that strips out any FITS header cards added by the ``update_header`` argument in the corresponding call to `Image.serialize`. + **kwargs + Unsupported keyword arguments are accepted only to provide better + error messages (raising `serialization.InvalidParameterError`). """ + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for Image: {set(kwargs.keys())}.") array_model: ArrayReferenceModel | InlineArrayModel unit: astropy.units.UnitBase | None = None if isinstance(self.data, ArrayReferenceQuantityModel | InlineArrayQuantityModel): @@ -590,5 +556,9 @@ def _strip_header(header: astropy.io.fits.Header) -> None: start=self.start if bbox is None else bbox.start, unit=unit, projection=projection, - obs_info=self.obs_info, )._finish_deserialize(self) + + def deserialize_component(self, component: str, archive: InputArchive[Any], **kwargs: Any) -> Any: + if kwargs: + raise InvalidParameterError(f"Unsupported parameters for Image components: {set(kwargs.keys())}.") + return super().deserialize_component(component, archive) diff --git a/python/lsst/images/_mask.py b/python/lsst/images/_mask.py index b59c7fe5..03427c60 100644 --- a/python/lsst/images/_mask.py +++ b/python/lsst/images/_mask.py @@ -32,7 +32,6 @@ import numpy as np import numpy.typing as npt import pydantic -from astro_metadata_translator import ObservationInfo from lsst.resources import ResourcePath, ResourcePathExpression @@ -47,6 +46,7 @@ InlineArrayModel, InputArchive, IntegerType, + InvalidParameterError, MetadataValue, NumberType, OutputArchive, @@ -322,9 +322,6 @@ class Mask(GeneralizedImage): include the last dimension of the array. projection Projection that maps the pixel grid to the sky. - obs_info - General information about the associated observation in standardized - form. metadata Arbitrary flexible metadata to associate with the mask. @@ -351,7 +348,6 @@ def __init__( start: Sequence[int] | None = None, shape: Sequence[int] | None = None, projection: Projection | None = None, - obs_info: ObservationInfo | None = None, metadata: dict[str, MetadataValue] | None = None, ): super().__init__(metadata) @@ -386,7 +382,6 @@ def __init__( self._bbox: Box = bbox self._schema: MaskSchema = schema self._projection = projection - self._obs_info = obs_info @property def array(self) -> np.ndarray: @@ -428,13 +423,6 @@ def projection(self) -> Projection[Any] | None: """ return self._projection - @property - def obs_info(self) -> ObservationInfo | None: - """General information about the associated observation in standard - form. (`~astro_metadata_translator.ObservationInfo` | `None`). - """ - return self._obs_info - def __getitem__(self, bbox: Box | EllipsisType) -> Mask: if bbox is ...: return self @@ -444,6 +432,7 @@ def __getitem__(self, bbox: Box | EllipsisType) -> Mask: self.array[bbox.y.slice_within(self._bbox.y), bbox.x.slice_within(self._bbox.x), :], bbox=bbox, schema=self.schema, + projection=self._projection, ), bbox=bbox, ) @@ -471,13 +460,7 @@ def __eq__(self, other: object) -> bool: def copy(self) -> Mask: """Deep-copy the mask and metadata.""" return self._transfer_metadata( - Mask( - self._array.copy(), - bbox=self._bbox, - schema=self._schema, - projection=self._projection, - obs_info=self._obs_info, - ), + Mask(self._array.copy(), bbox=self._bbox, schema=self._schema, projection=self._projection), copy=True, ) @@ -487,7 +470,6 @@ def view( schema: MaskSchema | EllipsisType = ..., projection: Projection | None | EllipsisType = ..., start: Sequence[int] | EllipsisType = ..., - obs_info: ObservationInfo | None | EllipsisType = ..., ) -> Mask: """Make a view of the mask, with optional updates. @@ -505,11 +487,7 @@ def view( projection = self._projection if start is ...: start = self._bbox.start - if obs_info is ...: - obs_info = self._obs_info - return self._transfer_metadata( - Mask(self._array, start=start, schema=schema, projection=projection, obs_info=obs_info) - ) + return self._transfer_metadata(Mask(self._array, start=start, schema=schema, projection=projection)) def update(self, other: Mask) -> None: """Update ``self`` to include all common mask values set in ``other``. @@ -595,7 +573,6 @@ def serialize[P: pydantic.BaseModel]( *, update_header: Callable[[astropy.io.fits.Header], None] = no_header_updates, save_projection: bool = True, - save_obs_info: bool = True, add_offset_wcs: str | None = "A", ) -> MaskSerializationModel[P]: """Serialize the mask to an output archive. @@ -615,10 +592,6 @@ def serialize[P: pydantic.BaseModel]( is one. This does not affect whether a FITS WCS corresponding to the projection is written (it always is, if available, and if ``add_offset_wcs`` is not ``" "``). - save_obs_info - If `True`, save the - `~astro_metadata_translator.ObservationInfo` attached to the - image, if there is one. add_offset_wcs A FITS WCS single-character suffix to use when adding a linear WCS that maps the FITS array to the logical pixel coordinates @@ -639,9 +612,7 @@ def serialize[P: pydantic.BaseModel]( else: data = [] for schema_2d in self.schema.split(np.int32): - mask_2d = Mask( - 0, bbox=self.bbox, schema=schema_2d, projection=self._projection, obs_info=self._obs_info - ) + mask_2d = Mask(0, bbox=self.bbox, schema=schema_2d, projection=self._projection) mask_2d.update(self) data.append( mask_2d._serialize_2d(archive, update_header=update_header, add_offset_wcs=add_offset_wcs) @@ -657,7 +628,6 @@ def serialize[P: pydantic.BaseModel]( planes=list(self.schema), dtype=serialized_dtype, projection=serialized_projection, - obs_info=self._obs_info if save_obs_info else None, metadata=self.metadata, ) @@ -892,11 +862,6 @@ class MaskSerializationModel[P: pydantic.BaseModel](ArchiveTree): exclude_if=is_none, description="Projection that maps the logical pixel grid onto the sky.", ) - obs_info: ObservationInfo | None = pydantic.Field( - default=None, - exclude_if=is_none, - description="Standardized description of image metadata", - ) @property def bbox(self) -> Box: @@ -912,6 +877,7 @@ def deserialize( *, bbox: Box | None = None, strip_header: Callable[[astropy.io.fits.Header], None] = no_header_updates, + **kwargs: Any, ) -> Mask: """Deserialize a mask from an input archive. @@ -925,7 +891,12 @@ def deserialize( A callable that strips out any FITS header cards added by the ``update_header`` argument in the corresponding call to `Mask.serialize`. + **kwargs + Unsupported keyword arguments are accepted only to provide better + error messages (raising `serialization.InvalidParameterError`). """ + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for Mask: {set(kwargs.keys())}.") slices: tuple[slice, ...] | EllipsisType = ... if bbox is not None: slices = bbox.slice_within(self.bbox) @@ -939,20 +910,8 @@ def deserialize( storage_slices = slices if slices is ... else (slice(None),) + slices array = archive.get_array(self.data[0], strip_header=strip_header, slices=storage_slices) array = np.moveaxis(array, 0, -1) - return Mask( - array, - schema=schema, - bbox=bbox, - projection=projection, - obs_info=self.obs_info, - )._finish_deserialize(self) - result = Mask( - 0, - schema=schema, - bbox=bbox, - projection=projection, - obs_info=self.obs_info, - ) + return Mask(array, schema=schema, bbox=bbox, projection=projection)._finish_deserialize(self) + result = Mask(0, schema=schema, bbox=bbox, projection=projection) schemas_2d = schema.split(np.int32) if len(schemas_2d) != len(self.data): raise ArchiveReadError( @@ -983,6 +942,11 @@ def _strip_header(header: astropy.io.fits.Header) -> None: array_2d = archive.get_array(ref, strip_header=_strip_header, slices=slices) return Mask(array_2d[:, :, np.newaxis], schema=schema_2d, start=start) + def deserialize_component(self, component: str, archive: InputArchive[Any], **kwargs: Any) -> Any: + if kwargs: + raise InvalidParameterError(f"Unsupported parameters for Mask components: {set(kwargs.keys())}.") + return super().deserialize_component(component, archive) + def _archive_prefers_native_mask_arrays(archive: OutputArchive[Any]) -> bool: """Return whether an archive wants masks in their native 3-D layout.""" diff --git a/python/lsst/images/_masked_image.py b/python/lsst/images/_masked_image.py index af753854..e5abc69f 100644 --- a/python/lsst/images/_masked_image.py +++ b/python/lsst/images/_masked_image.py @@ -24,7 +24,6 @@ import astropy.wcs import numpy as np import pydantic -from astro_metadata_translator import ObservationInfo from lsst.resources import ResourcePath, ResourcePathExpression @@ -34,7 +33,7 @@ from ._image import Image, ImageSerializationModel from ._mask import Mask, MaskPlane, MaskSchema, MaskSerializationModel from ._transforms import Frame, Projection, ProjectionSerializationModel -from .serialization import ArchiveTree, InputArchive, MetadataValue, OutputArchive +from .serialization import ArchiveTree, InputArchive, InvalidParameterError, MetadataValue, OutputArchive from .utils import is_none @@ -61,9 +60,6 @@ class MaskedImage(GeneralizedImage): not provided. projection Projection that maps the pixel grid to the sky. - obs_info - General information about the associated observation in standardized - form. metadata Arbitrary flexible metadata to associate with the image. """ @@ -76,7 +72,6 @@ def __init__( variance: Image | None = None, mask_schema: MaskSchema | None = None, projection: Projection | None = None, - obs_info: ObservationInfo | None = None, metadata: dict[str, MetadataValue] | None = None, ): super().__init__(metadata) @@ -84,20 +79,16 @@ def __init__( projection = image.projection else: image = image.view(projection=projection) - if obs_info is None: - obs_info = image.obs_info - else: - image = image.view(obs_info=obs_info) if mask is None: if mask_schema is None: raise TypeError("'mask_schema' must be provided if 'mask' is not.") - mask = Mask(schema=mask_schema, bbox=image.bbox, projection=projection, obs_info=obs_info) + mask = Mask(schema=mask_schema, bbox=image.bbox, projection=projection) elif mask_schema is not None: raise TypeError("'mask_schema' may not be provided if 'mask' is.") else: if image.bbox != mask.bbox: raise ValueError(f"Image ({image.bbox}) and mask ({mask.bbox}) bboxes do not agree.") - mask = mask.view(projection=projection, obs_info=obs_info) + mask = mask.view(projection=projection) if variance is None: variance = Image( 1.0, @@ -105,12 +96,11 @@ def __init__( bbox=image.bbox, unit=None if image.unit is None else image.unit**2, projection=projection, - obs_info=obs_info, ) else: if image.bbox != variance.bbox: raise ValueError(f"Image ({image.bbox}) and variance ({variance.bbox}) bboxes do not agree.") - variance = variance.view(projection=projection, obs_info=obs_info) + variance = variance.view(projection=projection) if image.unit is None: if variance.unit is not None: raise ValueError(f"Image has no units but variance does ({variance.unit}).") @@ -158,13 +148,6 @@ def projection(self) -> Projection[Any] | None: """ return self._image.projection - @property - def obs_info(self) -> ObservationInfo | None: - """General information about the associated observation in standard - form. (`~astro_metadata_translator.ObservationInfo` | `None`). - """ - return self._image.obs_info - def __getitem__(self, bbox: Box | EllipsisType) -> MaskedImage: if bbox is ...: return self @@ -206,13 +189,13 @@ def serialize(self, archive: OutputArchive[Any]) -> MaskedImageSerializationMode Archive to write to. """ serialized_image = archive.serialize_direct( - "image", functools.partial(self.image.serialize, save_projection=False, save_obs_info=False) + "image", functools.partial(self.image.serialize, save_projection=False) ) serialized_mask = archive.serialize_direct( - "mask", functools.partial(self.mask.serialize, save_projection=False, save_obs_info=False) + "mask", functools.partial(self.mask.serialize, save_projection=False) ) serialized_variance = archive.serialize_direct( - "variance", functools.partial(self.variance.serialize, save_projection=False, save_obs_info=False) + "variance", functools.partial(self.variance.serialize, save_projection=False) ) serialized_projection = ( archive.serialize_direct("projection", self.projection.serialize) @@ -224,7 +207,6 @@ def serialize(self, archive: OutputArchive[Any]) -> MaskedImageSerializationMode mask=serialized_mask, variance=serialized_variance, projection=serialized_projection, - obs_info=self.obs_info, metadata=self.metadata, ) @@ -495,18 +477,15 @@ class MaskedImageSerializationModel[P: pydantic.BaseModel](ArchiveTree): exclude_if=is_none, description="Projection that maps the pixel grid to the sky.", ) - obs_info: ObservationInfo | None = pydantic.Field( - default=None, - exclude_if=is_none, - description="Standardized description of image metadata", - ) @property def bbox(self) -> Box: """The bounding box of the image.""" return self.image.bbox - def deserialize(self, archive: InputArchive[Any], *, bbox: Box | None = None) -> MaskedImage: + def deserialize( + self, archive: InputArchive[Any], *, bbox: Box | None = None, **kwargs: Any + ) -> MaskedImage: """Deserialize an image from an input archive. Parameters @@ -515,11 +494,23 @@ def deserialize(self, archive: InputArchive[Any], *, bbox: Box | None = None) -> Archive to read from. bbox Bounding box of a subimage to read instead. + **kwargs + Unsupported keyword arguments are accepted only to provide better + error messages (raising `serialization.InvalidParameterError`). """ + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for MaskedImage: {set(kwargs.keys())}.") image = self.image.deserialize(archive, bbox=bbox) mask = self.mask.deserialize(archive, bbox=bbox) variance = self.variance.deserialize(archive, bbox=bbox) projection = self.projection.deserialize(archive) if self.projection is not None else None - return MaskedImage( - image, mask=mask, variance=variance, projection=projection, obs_info=self.obs_info - )._finish_deserialize(self) + return MaskedImage(image, mask=mask, variance=variance, projection=projection)._finish_deserialize( + self + ) + + def deserialize_component(self, component: str, archive: InputArchive[Any], **kwargs: Any) -> Any: + if component == "bbox" and kwargs: + raise InvalidParameterError( + f"Unrecognized parameters for MaskedImage.bbox: {set(kwargs.keys())}." + ) + return super().deserialize_component(component, archive, **kwargs) diff --git a/python/lsst/images/_transforms/_camera_frame_set.py b/python/lsst/images/_transforms/_camera_frame_set.py index 664ec496..38bc61a7 100644 --- a/python/lsst/images/_transforms/_camera_frame_set.py +++ b/python/lsst/images/_transforms/_camera_frame_set.py @@ -19,7 +19,7 @@ import pydantic from .._geom import Bounds, Box -from ..serialization import ArchiveTree, InputArchive, OutputArchive +from ..serialization import ArchiveTree, InputArchive, InvalidParameterError, OutputArchive from . import _ast as astshim from . import _frames # use this import style to facilitate pattern matching from ._frame_set import FrameLookupError, FrameSet @@ -215,12 +215,17 @@ class CameraFrameSetSerializationModel(ArchiveTree): description="A serialized Starlink AST FrameSet, using the AST native encoding." ) - def deserialize(self, archive: InputArchive[Any]) -> CameraFrameSet: + def deserialize(self, archive: InputArchive[Any], **kwargs: Any) -> CameraFrameSet: """Deserialize a frame set from an archive. Parameters ---------- archive Archive to read from. + **kwargs + Unsupported keyword arguments are accepted only to provide better + error messages (raising `serialization.InvalidParameterError`). """ + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for CameraFrameSet: {set(kwargs.keys())}.") return CameraFrameSet(self.instrument, astshim.FrameSet.fromString(self.ast)) diff --git a/python/lsst/images/_transforms/_projection.py b/python/lsst/images/_transforms/_projection.py index 385fe490..ca7d25e6 100644 --- a/python/lsst/images/_transforms/_projection.py +++ b/python/lsst/images/_transforms/_projection.py @@ -24,7 +24,7 @@ from astropy.wcs.wcsapi import BaseLowLevelWCS, HighLevelWCSMixin from .._geom import XY, YX, Bounds, Box -from ..serialization import ArchiveTree, InputArchive, OutputArchive +from ..serialization import ArchiveTree, InputArchive, InvalidParameterError, OutputArchive from ..utils import is_none from . import _ast as astshim from ._frames import Frame, SkyFrame @@ -476,14 +476,19 @@ class ProjectionSerializationModel[P: pydantic.BaseModel](ArchiveTree): exclude_if=is_none, ) - def deserialize(self, archive: InputArchive[P]) -> Projection[Any]: + def deserialize(self, archive: InputArchive[P], **kwargs: Any) -> Projection[Any]: """Deserialize a projection from an archive. Parameters ---------- archive Archive to read from. + **kwargs + Unsupported keyword arguments are accepted only to provide better + error messages (raising `serialization.InvalidParameterError`). """ + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for Projection: {set(kwargs.keys())}.") pixel_to_sky = self.pixel_to_sky.deserialize(archive) fits_approximation = ( self.fits_approximation.deserialize(archive) if self.fits_approximation is not None else None diff --git a/python/lsst/images/_transforms/_transform.py b/python/lsst/images/_transforms/_transform.py index 19c42da0..973e6ae8 100644 --- a/python/lsst/images/_transforms/_transform.py +++ b/python/lsst/images/_transforms/_transform.py @@ -28,7 +28,7 @@ from .._concrete_bounds import SerializableBounds from .._geom import XY, Bounds, Box -from ..serialization import ArchiveReadError, ArchiveTree, InputArchive, OutputArchive +from ..serialization import ArchiveReadError, ArchiveTree, InputArchive, InvalidParameterError, OutputArchive from . import _ast as astshim from ._frames import Frame, SerializableFrame, SkyFrame @@ -561,14 +561,19 @@ class TransformSerializationModel[P: pydantic.BaseModel](ArchiveTree): ), ) - def deserialize(self, archive: InputArchive[P]) -> Transform[Any, Any]: + def deserialize(self, archive: InputArchive[P], **kwargs: Any) -> Transform[Any, Any]: """Deserialize a transform from an archive. Parameters ---------- archive Archive to read from. + **kwargs + Unsupported keyword arguments are accepted only to provide better + error messages (raising `serialization.InvalidParameterError`). """ + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for Transform: {set(kwargs.keys())}.") if len(self.frames) != len(self.bounds): raise ArchiveReadError( f"Inconsistent lengths for 'frames' ({len(self.frames)}) and 'bounds' ({len(self.bounds)})." diff --git a/python/lsst/images/_visit_image.py b/python/lsst/images/_visit_image.py index 99f4e914..fd4cdf37 100644 --- a/python/lsst/images/_visit_image.py +++ b/python/lsst/images/_visit_image.py @@ -52,7 +52,7 @@ PSFExSerializationModel, PSFExWrapper, ) -from .serialization import ArchiveReadError, InputArchive, MetadataValue, OutputArchive +from .serialization import ArchiveReadError, InputArchive, InvalidParameterError, MetadataValue, OutputArchive from .utils import is_none @@ -199,15 +199,15 @@ def __init__( variance=variance, mask_schema=mask_schema, projection=projection, - obs_info=obs_info, metadata=metadata, ) if self.image.unit is None: raise TypeError("The image component of a VisitImage must have units.") if self.image.projection is None: raise TypeError("The projection component of a VisitImage cannot be None.") - if self.image.obs_info is None: + if obs_info is None: raise TypeError("The observation info component of a VisitImage cannot be None.") + self._obs_info = obs_info if not isinstance(self.image.projection.pixel_frame, DetectorFrame): raise TypeError("The projection's pixel frame must be a DetectorFrame for VisitImage.") if summary_stats is None: @@ -246,9 +246,7 @@ def obs_info(self) -> ObservationInfo: """General information about this observation in standard form. (`~astro_metadata_translator.ObservationInfo`). """ - obs_info = self.image.obs_info - assert obs_info is not None - return obs_info + return self._obs_info @property def astropy_wcs(self) -> ProjectionAstropyView: @@ -496,7 +494,6 @@ def serialize(self, archive: OutputArchive[Any]) -> VisitImageSerializationModel f"Cannot serialize VisitImage with unrecognized PSF type {type(self._psf).__name__}." ) assert masked_image_model.projection is not None, "VisitImage always has a projection." - assert masked_image_model.obs_info is not None, "VisitImage always has observation info." serialized_detector = archive.serialize_direct("detector", self._detector.serialize) serialized_photometric_scaling = ( archive.serialize_direct("photometric_scaling", self._photometric_scaling.serialize) @@ -513,7 +510,7 @@ def serialize(self, archive: OutputArchive[Any]) -> VisitImageSerializationModel mask=masked_image_model.mask, variance=masked_image_model.variance, projection=masked_image_model.projection, - obs_info=masked_image_model.obs_info, + obs_info=self.obs_info, photometric_scaling=serialized_photometric_scaling, psf=serialized_psf, summary_stats=self.summary_stats, @@ -914,7 +911,7 @@ def read_legacy( # type: ignore[override] if component is not None: # This is the image, mask, or variance; attach the projection and # obs_info and return - return from_masked_image.view(projection=projection, obs_info=obs_info) + return from_masked_image.view(projection=projection) legacy_polygon = reader.readValidPolygon() result = VisitImage( from_masked_image.image, @@ -982,9 +979,17 @@ class VisitImageSerializationModel[P: pydantic.BaseModel](MaskedImageSerializati description="Background models associated with this image.", ) - def deserialize(self, archive: InputArchive[Any], *, bbox: Box | None = None) -> VisitImage: + def deserialize( + self, archive: InputArchive[Any], *, bbox: Box | None = None, **kwargs: Any + ) -> VisitImage: + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for VisitImage: {set(kwargs.keys())}.") masked_image = super().deserialize(archive, bbox=bbox) - psf = self.deserialize_psf(archive) + try: + psf = self.psf.deserialize(archive) + except ArchiveReadError as err: + # Defer this until/unless somebody actually asks for the PSF. + psf = err detector = self.detector.deserialize(archive) aperture_corrections = self.aperture_corrections.deserialize(archive) photometric_scaling = ( @@ -996,7 +1001,7 @@ def deserialize(self, archive: InputArchive[Any], *, bbox: Box | None = None) -> variance=masked_image.variance, psf=psf, projection=masked_image.projection, - obs_info=masked_image.obs_info, + obs_info=self.obs_info, summary_stats=self.summary_stats, detector=detector, aperture_corrections=aperture_corrections, @@ -1005,14 +1010,12 @@ def deserialize(self, archive: InputArchive[Any], *, bbox: Box | None = None) -> backgrounds=self.backgrounds.deserialize(archive), )._finish_deserialize(self) - def deserialize_psf(self, archive: InputArchive[Any]) -> PointSpreadFunction | ArchiveReadError: - """Finish deserializing the PSF model, or *return* any exception - raised in the attempt. - """ - try: - return self.psf.deserialize(archive) - except ArchiveReadError as err: - return err + def deserialize_component(self, component: str, archive: InputArchive[Any], **kwargs: Any) -> Any: + if kwargs and component not in ("image", "mask", "variance"): + raise InvalidParameterError( + f"Unsupported parameters for VisitImage component {component}: {set(kwargs.keys())}." + ) + return super().deserialize_component(component, archive) def _extract_or_check_value[T]( diff --git a/python/lsst/images/aperture_corrections.py b/python/lsst/images/aperture_corrections.py index d4ed9937..d2c9dc7b 100644 --- a/python/lsst/images/aperture_corrections.py +++ b/python/lsst/images/aperture_corrections.py @@ -22,7 +22,7 @@ import pydantic from .fields import Field, FieldSerializationModel, field_from_legacy -from .serialization import ArchiveTree, InputArchive, OutputArchive +from .serialization import ArchiveTree, InputArchive, InvalidParameterError, OutputArchive if TYPE_CHECKING: try: @@ -83,6 +83,8 @@ def serialize( ) return result - def deserialize(self, archive: InputArchive[Any]) -> ApertureCorrectionMap: + def deserialize(self, archive: InputArchive[Any], **kwargs: Any) -> ApertureCorrectionMap: """Read an aperture correction map from an archive.""" + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for Image: {set(kwargs.keys())}.") return {name: field.deserialize(archive) for name, field in self.fields.items()} diff --git a/python/lsst/images/cameras.py b/python/lsst/images/cameras.py index edef7df4..863b8091 100644 --- a/python/lsst/images/cameras.py +++ b/python/lsst/images/cameras.py @@ -45,6 +45,7 @@ ArchiveTree, InlineArray, InputArchive, + InvalidParameterError, OutputArchive, Quantity, ) @@ -699,7 +700,9 @@ class DetectorSerializationModel(ArchiveTree): visit: int | None = pydantic.Field(description="ID of the visit this detector is associated with.") - def deserialize(self, archive: InputArchive[Any], frames: CameraFrameSet | None = None) -> Detector: + def deserialize( + self, archive: InputArchive[Any], frames: CameraFrameSet | None = None, **kwargs: Any + ) -> Detector: """Deserialize this detector from an archive. Parameters @@ -710,6 +713,8 @@ def deserialize(self, archive: InputArchive[Any], frames: CameraFrameSet | None Coordinate systems and transforms to use instead of what is saved in ``model``. Must be provided if ``model.frames`` is `None`. """ + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for Detector: {set(kwargs.keys())}.") if frames is None: if self.frames is None: raise ArchiveReadError( diff --git a/python/lsst/images/cells/_coadd.py b/python/lsst/images/cells/_coadd.py index 9c300aaa..f573e408 100644 --- a/python/lsst/images/cells/_coadd.py +++ b/python/lsst/images/cells/_coadd.py @@ -30,7 +30,7 @@ from .._mask import Mask, MaskPlane, MaskSchema, MaskSerializationModel from .._masked_image import MaskedImage, MaskedImageSerializationModel from .._transforms import Projection, ProjectionSerializationModel, TractFrame -from ..serialization import ArchiveReadError, InputArchive, OutputArchive +from ..serialization import InputArchive, InvalidParameterError, OutputArchive from ._provenance import CoaddProvenance, CoaddProvenanceSerializationModel from ._psf import CellPointSpreadFunction, CellPointSpreadFunctionSerializationModel @@ -449,6 +449,7 @@ def deserialize( # type: ignore[override] *, bbox: Box | None = None, provenance: bool = True, + **kwargs: Any, ) -> CellCoadd: """Deserialize an image from an input archive. @@ -460,7 +461,12 @@ def deserialize( # type: ignore[override] Bounding box of a subimage to read instead. provenance Whether to read and attach provenance information. + **kwargs + Unsupported keyword arguments are accepted only to provide better + error messages (raising `.serialization.InvalidParameterError`). """ + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for CellCoadd: {set(kwargs.keys())}.") masked_image = super().deserialize(archive, bbox=bbox) mask_fractions = { k.removeprefix("mask_fractions/"): v.deserialize(archive) for k, v in self.mask_fractions.items() @@ -488,12 +494,13 @@ def deserialize( # type: ignore[override] backgrounds=backgrounds, )._finish_deserialize(self) - def deserialize_psf(self, archive: InputArchive[Any], bbox: Box | None = None) -> CellPointSpreadFunction: - """Finish deserializing the PSF model.""" - return self.psf.deserialize(archive, bbox=bbox) - - def deserialize_provenance(self, archive: InputArchive[Any]) -> CoaddProvenance: - """Finish deserializing the provenance information.""" - if self.provenance is not None: - return self.provenance.deserialize(archive) - raise ArchiveReadError("No coadd provenance stored in this file.") + def deserialize_component(self, component: str, archive: InputArchive[Any], **kwargs: Any) -> Any: + match component: + case "mask_fractions": + return { + name: image_model.deserialize(archive, **kwargs) + for name, image_model in self.mask_fractions.items() + } + case "noise_realizations": + return [image_model.deserialize(archive, **kwargs) for image_model in self.noise_realizations] + return super().deserialize_component(component, archive, **kwargs) diff --git a/python/lsst/images/cells/_provenance.py b/python/lsst/images/cells/_provenance.py index 631a7d80..d233caf7 100644 --- a/python/lsst/images/cells/_provenance.py +++ b/python/lsst/images/cells/_provenance.py @@ -23,7 +23,7 @@ from .._cell_grid import CellIJ from .._polygon import Polygon -from ..serialization import ArchiveTree, InputArchive, OutputArchive, TableModel +from ..serialization import ArchiveTree, InputArchive, InvalidParameterError, OutputArchive, TableModel if TYPE_CHECKING: try: @@ -244,7 +244,7 @@ class CoaddProvenanceSerializationModel(ArchiveTree): inputs: TableModel = pydantic.Field(description="Table of all inputs to the coadd.") contributions: TableModel = pydantic.Field(description="Table of per-cell contributions to the coadd.") - def deserialize(self, archive: InputArchive[Any]) -> CoaddProvenance: + def deserialize(self, archive: InputArchive[Any], **kwargs: Any) -> CoaddProvenance: """Deserialize a provenance from an input archive. Parameters @@ -259,6 +259,8 @@ def deserialize(self, archive: InputArchive[Any]) -> CoaddProvenance: had from doing this during deserialization (the table data is not ordered by cell, and hence there's read-slicing we can do). """ + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for CoaddProvenance: {set(kwargs.keys())}.") inputs = archive.get_table(self.inputs) contributions = archive.get_table(self.contributions) CoaddProvenanceSerializationModel._fix_str_for_deserialization( diff --git a/python/lsst/images/cells/_psf.py b/python/lsst/images/cells/_psf.py index 464cab72..602e5873 100644 --- a/python/lsst/images/cells/_psf.py +++ b/python/lsst/images/cells/_psf.py @@ -23,7 +23,13 @@ from .._geom import YX, Bounds, BoundsError, Box from .._image import Image from ..psfs import PointSpreadFunction -from ..serialization import ArchiveTree, ArrayReferenceModel, InputArchive, OutputArchive +from ..serialization import ( + ArchiveTree, + ArrayReferenceModel, + InputArchive, + InvalidParameterError, + OutputArchive, +) from ..utils import round_half_up if TYPE_CHECKING: @@ -215,7 +221,13 @@ class CellPointSpreadFunctionSerializationModel(ArchiveTree): ) ) - def deserialize(self, archive: InputArchive[Any], *, bbox: Box | None = None) -> CellPointSpreadFunction: + def deserialize( + self, archive: InputArchive[Any], *, bbox: Box | None = None, **kwargs: Any + ) -> CellPointSpreadFunction: + if kwargs: + raise InvalidParameterError( + f"Unrecognized parameters for CellPointSpreadFunction: {set(kwargs.keys())}." + ) bounds = self.bounds if bbox is not None: bounds, slices = CellPointSpreadFunction._subset_impl(bounds, bbox) diff --git a/python/lsst/images/fields/_chebyshev.py b/python/lsst/images/fields/_chebyshev.py index f7e9a35d..1c8e60a8 100644 --- a/python/lsst/images/fields/_chebyshev.py +++ b/python/lsst/images/fields/_chebyshev.py @@ -23,7 +23,7 @@ from .._concrete_bounds import SerializableBounds from .._geom import YX, Bounds, Box from .._image import Image -from ..serialization import ArchiveTree, InlineArray, InputArchive, OutputArchive, Unit +from ..serialization import ArchiveTree, InlineArray, InputArchive, InvalidParameterError, OutputArchive, Unit from ._base import BaseField if TYPE_CHECKING: @@ -397,6 +397,8 @@ class ChebyshevFieldSerializationModel(ArchiveTree): field_type: Literal["CHEBYSHEV"] = "CHEBYSHEV" - def deserialize(self, archive: InputArchive) -> ChebyshevField: + def deserialize(self, archive: InputArchive, **kwargs: Any) -> ChebyshevField: """Deserialize the Chebyshev field from an input archive.""" + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for ChebyshevField: {set(kwargs.keys())}.") return ChebyshevField(self.bounds.deserialize(), self.coefficients, unit=self.unit) diff --git a/python/lsst/images/fields/_product.py b/python/lsst/images/fields/_product.py index 9dbc0356..7b053c5f 100644 --- a/python/lsst/images/fields/_product.py +++ b/python/lsst/images/fields/_product.py @@ -22,7 +22,7 @@ from .._geom import Bounds, Box from .._image import Image -from ..serialization import ArchiveTree, InputArchive, OutputArchive +from ..serialization import ArchiveTree, InputArchive, InvalidParameterError, OutputArchive from ._base import BaseField if TYPE_CHECKING: @@ -161,6 +161,8 @@ class ProductFieldSerializationModel(ArchiveTree): field_type: Literal["PRODUCT"] = "PRODUCT" - def deserialize(self, archive: InputArchive) -> ProductField: + def deserialize(self, archive: InputArchive, **kwargs: Any) -> ProductField: """Deserialize the field from an input archive.""" + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for ProductField: {set(kwargs.keys())}.") return ProductField([operand.deserialize(archive) for operand in self.operands]) diff --git a/python/lsst/images/fields/_spline.py b/python/lsst/images/fields/_spline.py index 6325ad64..6e7e84e9 100644 --- a/python/lsst/images/fields/_spline.py +++ b/python/lsst/images/fields/_spline.py @@ -23,7 +23,15 @@ from .._concrete_bounds import SerializableBounds from .._geom import Bounds, Box from .._image import Image -from ..serialization import ArchiveTree, ArrayReferenceModel, InlineArray, InputArchive, OutputArchive, Unit +from ..serialization import ( + ArchiveTree, + ArrayReferenceModel, + InlineArray, + InputArchive, + InvalidParameterError, + OutputArchive, + Unit, +) from ._base import BaseField if TYPE_CHECKING: @@ -275,8 +283,10 @@ class SplineFieldSerializationModel(ArchiveTree): field_type: Literal["SPLINE"] = "SPLINE" - def deserialize(self, archive: InputArchive) -> SplineField: + def deserialize(self, archive: InputArchive, **kwargs: Any) -> SplineField: """Deserialize the spline field from an input archive.""" + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for SplineField: {set(kwargs.keys())}.") return SplineField( self.bounds.deserialize(), archive.get_array(self.data), diff --git a/python/lsst/images/fields/_sum.py b/python/lsst/images/fields/_sum.py index 589d0555..52865cbb 100644 --- a/python/lsst/images/fields/_sum.py +++ b/python/lsst/images/fields/_sum.py @@ -22,7 +22,7 @@ from .._geom import Bounds, Box from .._image import Image -from ..serialization import ArchiveTree, InputArchive, OutputArchive +from ..serialization import ArchiveTree, InputArchive, InvalidParameterError, OutputArchive from ._base import BaseField if TYPE_CHECKING: @@ -156,6 +156,8 @@ class SumFieldSerializationModel(ArchiveTree): field_type: Literal["SUM"] = "SUM" - def deserialize(self, archive: InputArchive) -> SumField: + def deserialize(self, archive: InputArchive, **kwargs: Any) -> SumField: """Deserialize the field from an input archive.""" + if kwargs: + raise InvalidParameterError(f"Unrecognized parameters for SumField: {set(kwargs.keys())}.") return SumField([operand.deserialize(archive) for operand in self.operands]) diff --git a/python/lsst/images/fits/formatters.py b/python/lsst/images/fits/formatters.py deleted file mode 100644 index c5f44a81..00000000 --- a/python/lsst/images/fits/formatters.py +++ /dev/null @@ -1,84 +0,0 @@ -# This file is part of lsst-images. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# Use of this source code is governed by a 3-clause BSD-style -# license that can be found in the LICENSE file. - -"""Deprecated re-exports of the unified ``lsst.images.formatters`` module. - -These names are kept so that deployed butler configs in -``daf_butler/configs/datastores/formatters.yaml`` continue to work. -Each class is a one-line subclass of the corresponding unified -formatter that emits a `DeprecationWarning` on first instantiation. -""" - -from __future__ import annotations - -__all__ = ( - "CellCoaddFormatter", - "GenericFormatter", - "ImageFormatter", - "MaskedImageFormatter", - "VisitImageFormatter", -) - -import warnings -from typing import Any - -from .. import formatters as _unified - - -def _warn(name: str) -> None: - warnings.warn( - f"lsst.images.fits.formatters.{name} is deprecated; " - f"use lsst.images.formatters.{name} instead. The fits-only " - f"formatter forwards to the unified one and will be removed " - f"in a future release.", - DeprecationWarning, - stacklevel=3, - ) - - -class GenericFormatter(_unified.GenericFormatter): - """Deprecated alias for `lsst.images.formatters.GenericFormatter`.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - _warn("GenericFormatter") - super().__init__(*args, **kwargs) - - -class ImageFormatter(_unified.ImageFormatter): - """Deprecated alias for `lsst.images.formatters.ImageFormatter`.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - _warn("ImageFormatter") - super().__init__(*args, **kwargs) - - -class MaskedImageFormatter(_unified.MaskedImageFormatter): - """Deprecated alias for `lsst.images.formatters.MaskedImageFormatter`.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - _warn("MaskedImageFormatter") - super().__init__(*args, **kwargs) - - -class VisitImageFormatter(_unified.VisitImageFormatter): - """Deprecated alias for `lsst.images.formatters.VisitImageFormatter`.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - _warn("VisitImageFormatter") - super().__init__(*args, **kwargs) - - -class CellCoaddFormatter(_unified.CellCoaddFormatter): - """Deprecated alias for `lsst.images.formatters.CellCoaddFormatter`.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - _warn("CellCoaddFormatter") - super().__init__(*args, **kwargs) diff --git a/python/lsst/images/formatters.py b/python/lsst/images/formatters.py index 52f7e229..90962ed9 100644 --- a/python/lsst/images/formatters.py +++ b/python/lsst/images/formatters.py @@ -20,83 +20,22 @@ from __future__ import annotations -__all__ = ( - "CellCoaddFormatter", - "ComponentSentinel", - "GenericFormatter", - "ImageFormatter", - "MaskedImageFormatter", - "VisitImageFormatter", -) +__all__ = ("GenericFormatter",) -import enum import hashlib import json as _stdlib_json # disambiguates from .json subpackage -from collections.abc import Callable -from dataclasses import dataclass +from collections.abc import Callable, Iterator +from contextlib import contextmanager from typing import Any, ClassVar import astropy.io.fits -from astro_metadata_translator import ObservationInfo from lsst.daf.butler import DatasetProvenance, FormatterV2 from lsst.resources import ResourcePath from . import fits as _fits from . import json as _json -from ._geom import Box -from ._masked_image import MaskedImageSerializationModel -from ._transforms import ProjectionSerializationModel -from ._visit_image import VisitImageSerializationModel -from .fits._common import FitsCompressionOptions -from .fits._common import PointerModel as _FitsPointerModel -from .fits._input_archive import FitsInputArchive as _FitsInputArchive -from .serialization import ButlerInfo - -try: - from . import ndf as _ndf - from .ndf._common import NdfPointerModel as _NdfPointerModel - from .ndf._input_archive import NdfInputArchive as _NdfInputArchive - - _HAVE_NDF = True -except ImportError: # h5py is optional; see ndf/__init__.py - _ndf = None # type: ignore[assignment] - _NdfPointerModel = None # type: ignore[assignment,misc] - _NdfInputArchive = None # type: ignore[assignment,misc] - _HAVE_NDF = False - - -@dataclass(frozen=True) -class _Backend: - """One row of the extension-to-backend lookup table.""" - - read: Callable[..., Any] - write: Callable[..., Any] - input_archive: type | None - pointer_model: type | None - - -_BACKENDS: dict[str, _Backend] = { - ".fits": _Backend( - read=_fits.read, - write=_fits.write, - input_archive=_FitsInputArchive, - pointer_model=_FitsPointerModel, - ), - ".json": _Backend( - read=_json.read, - write=_json.write, - input_archive=None, - pointer_model=None, - ), -} -if _HAVE_NDF: - _BACKENDS[".sdf"] = _Backend( - read=_ndf.read, - write=_ndf.write, - input_archive=_NdfInputArchive, - pointer_model=_NdfPointerModel, - ) +from .serialization import ArchiveTree, ButlerInfo, InputArchive, JsonRef class GenericFormatter(FormatterV2): @@ -143,17 +82,25 @@ def _validate_write_parameters(self) -> None: def write_local_file(self, in_memory_dataset: Any, uri: ResourcePath) -> None: self._validate_write_parameters() ext = self.get_write_extension() - backend = _BACKENDS[ext] butler_info = ButlerInfo( dataset=self.dataset_ref.to_simple(), provenance=self.butler_provenance if self.butler_provenance is not None else DatasetProvenance(), ) kwargs: dict[str, Any] = {"butler_info": butler_info} - if ext == ".fits": - kwargs["update_header"] = self._update_header - kwargs["compression_options"] = self._get_compression_options() - kwargs["compression_seed"] = self._get_compression_seed() - backend.write(in_memory_dataset, uri.ospath, **kwargs) + write_func: Callable[..., ArchiveTree] + match ext: + case ".fits": + kwargs["update_header"] = self._update_header + kwargs["compression_options"] = self._get_compression_options() + kwargs["compression_seed"] = self._get_compression_seed() + write_func = _fits.write + case ".json": + write_func = _json.write + case ".sdf": + from . import ndf as _ndf + + write_func = _ndf.write + write_func(in_memory_dataset, uri.ospath, **kwargs) def add_provenance( self, @@ -184,7 +131,7 @@ def _get_compression_seed(self) -> int: # 10000] range allowed by FITS. return 1 + int.from_bytes(hash_bytes) % 9999 - def _get_compression_options(self) -> dict[str, FitsCompressionOptions]: + def _get_compression_options(self) -> dict[str, _fits.FitsCompressionOptions]: recipe = self.write_parameters.get("recipe", "default") try: config = self.write_recipes[recipe] @@ -193,7 +140,7 @@ def _get_compression_options(self) -> dict[str, FitsCompressionOptions]: # If there's no default recipe just use the software defaults. return {} raise RuntimeError(f"Invalid recipe for GenericFormatter: {recipe!r}.") from None - return {k: FitsCompressionOptions.model_validate(v) for k, v in config.items()} + return {k: _fits.FitsCompressionOptions.model_validate(v) for k, v in config.items()} def _update_header(self, header: astropy.io.fits.Header) -> None: # Logic here largely lifted from lsst.obs.base.utils, which we @@ -222,38 +169,31 @@ def _extension_from_uri(self, uri: ResourcePath) -> str: raise RuntimeError(f"Cannot read {uri}: unsupported extension {ext!r}.") return ext - def read_from_uri( - self, - uri: ResourcePath, - component: str | None = None, - expected_size: int = -1, - ) -> Any: - pytype = self.dataset_ref.datasetType.storageClass.pytype + @contextmanager + def _open_archive_and_tree( + self, uri: ResourcePath, partial: bool + ) -> Iterator[tuple[InputArchive[Any], ArchiveTree]]: + pytype: type[Any] = self.dataset_ref.datasetType.storageClass.pytype ext = self._extension_from_uri(uri) - backend = _BACKENDS[ext] - kwargs = self.file_descriptor.parameters or {} - return backend.read(pytype, uri, **kwargs).deserialized - - -class ComponentSentinel(enum.Enum): - """Special return values from `ImageFormatter.read_component`.""" - - UNRECOGNIZED_COMPONENT = enum.auto() - """Subclasses might still recognise this component.""" - - INVALID_COMPONENT_MODEL = enum.auto() - """Component name is known but the model attribute is missing or - has the wrong type. - """ - - -class ImageFormatter(GenericFormatter): - """Adds component-level read support for image-like types. - - Subclasses override `read_component` to handle additional components - (image/mask/variance for MaskedImage; psf/summary_stats/etc. for - VisitImage). - """ + archive: InputArchive[Any] + match ext: + case ".fits": + tree_type = pytype._get_archive_tree_type(_fits.PointerModel) + with _fits.FitsInputArchive.open(uri, partial=partial) as archive: + tree = archive.get_tree(tree_type) + yield archive, tree + case ".json": + tree_type = pytype._get_archive_tree_type(JsonRef) + tree = tree_type.model_validate_json(ResourcePath(uri).read()) + archive = _json.JsonInputArchive(tree.indirect) + yield archive, tree + case ".sdf": + from . import ndf as _ndf + + tree_type = pytype._get_archive_tree_type(_ndf.NdfPointerModel) + with _ndf.NdfInputArchive.open(uri) as archive: + tree = archive.get_tree(tree_type) + yield archive, tree def read_from_uri( self, @@ -261,154 +201,11 @@ def read_from_uri( component: str | None = None, expected_size: int = -1, ) -> Any: - pytype: Any = self.file_descriptor.storageClass.pytype - ext = self._extension_from_uri(uri) - backend = _BACKENDS[ext] - if component is None: - result = backend.read(pytype, uri, bbox=self.pop_bbox_from_parameters()).deserialized - else: - result = self._read_component_from_uri(component, uri) - self.check_unhandled_parameters() - return result - - def _read_component_from_uri(self, component: str, uri: ResourcePath) -> Any: - ext = self._extension_from_uri(uri) - backend = _BACKENDS[ext] - pytype: Any = self.file_descriptor.storageClass.pytype - if ext == ".json": - obj = backend.read(pytype, uri).deserialized - try: - return getattr(obj, component) - except AttributeError as exc: - raise NotImplementedError(f"Unrecognized component {component!r} for JSON read.") from exc - # FITS/NDF archive path. backend.input_archive and pointer_model are - # typed as `type | None` to allow the JSON row to opt out; here we - # know they are populated. - archive_cls: Any = backend.input_archive - pointer_model: Any = backend.pointer_model - assert archive_cls is not None - assert pointer_model is not None - # FitsInputArchive uses partial=True for component reads; NDF - # has no such kwarg. - open_kwargs = {"partial": True} if ext == ".fits" else {} - with archive_cls.open(uri, **open_kwargs) as archive: - tree_type = pytype._get_archive_tree_type(pointer_model) - tree = archive.get_tree(tree_type) - result = self.read_component(component, tree, archive) - if result is ComponentSentinel.UNRECOGNIZED_COMPONENT: - raise NotImplementedError(f"Unrecognized component {component!r} for {type(self).__name__}.") - if result is ComponentSentinel.INVALID_COMPONENT_MODEL: - raise NotImplementedError( - f"Invalid serialization model for component {component!r} for {type(self).__name__}." - ) - return result - - def pop_bbox_from_parameters(self) -> Box | None: - parameters = self.file_descriptor.parameters or {} - return parameters.pop("bbox", None) - - def check_unhandled_parameters(self) -> None: - parameters = self.file_descriptor.parameters - if parameters: - raise RuntimeError(f"Parameters {list(parameters.keys())} not recognized.") - - def read_component(self, component: str, tree: Any, archive: Any) -> Any: - match component: - case "projection": - if isinstance( - p := getattr(tree, "projection", None), - ProjectionSerializationModel, - ): - return p.deserialize(archive) - return ComponentSentinel.INVALID_COMPONENT_MODEL - case "bbox": - if isinstance(bbox := getattr(tree, "bbox", None), Box): - return bbox - return ComponentSentinel.INVALID_COMPONENT_MODEL - case "obs_info": - if isinstance(oi := getattr(tree, "obs_info", None), ObservationInfo): - return oi - return ComponentSentinel.INVALID_COMPONENT_MODEL - return ComponentSentinel.UNRECOGNIZED_COMPONENT - - -class MaskedImageFormatter(ImageFormatter): - """Adds image/mask/variance component support.""" - - def read_component(self, component: str, tree: Any, archive: Any) -> Any: - match super().read_component(component, tree, archive): - case ComponentSentinel(): - pass - case handled: - return handled - if not isinstance(tree, MaskedImageSerializationModel): - return ComponentSentinel.INVALID_COMPONENT_MODEL - match component: - case "image": - return tree.image.deserialize(archive, bbox=self.pop_bbox_from_parameters()) - case "mask": - return tree.mask.deserialize(archive, bbox=self.pop_bbox_from_parameters()) - case "variance": - return tree.variance.deserialize(archive, bbox=self.pop_bbox_from_parameters()) - return ComponentSentinel.UNRECOGNIZED_COMPONENT - - -class VisitImageFormatter(MaskedImageFormatter): - """Adds psf/summary_stats/detector/aperture_corrections.""" - - def read_component(self, component: str, tree: Any, archive: Any) -> Any: - match super().read_component(component, tree, archive): - case ComponentSentinel(): - pass - case handled: - return handled - if not isinstance(tree, VisitImageSerializationModel): - return ComponentSentinel.INVALID_COMPONENT_MODEL - match component: - case "psf": - # The FITS path uses tree.psf.deserialize; the NDF tree - # exposes deserialize_psf for the same effect. - if hasattr(tree, "deserialize_psf"): - return tree.deserialize_psf(archive) - return tree.psf.deserialize(archive) - case "summary_stats": - return tree.summary_stats - case "detector": - if getattr(tree, "detector", None) is not None: - return tree.detector.deserialize(archive) - return ComponentSentinel.INVALID_COMPONENT_MODEL - case "aperture_corrections": - return tree.aperture_corrections.deserialize(archive) - case "photometric_scaling": - return ( - tree.photometric_scaling.deserialize(archive) - if tree.photometric_scaling is not None - else None - ) - case "backgrounds": - return tree.backgrounds.deserialize(archive) - return ComponentSentinel.UNRECOGNIZED_COMPONENT - - -class CellCoaddFormatter(MaskedImageFormatter): - """Adds CellCoadd-specific psf and provenance components.""" - - def read_component(self, component: str, tree: Any, archive: Any) -> Any: - from .cells import CellCoaddSerializationModel # avoid cycles - - match super().read_component(component, tree, archive): - case ComponentSentinel(): - pass - case handled: - return handled - if not isinstance(tree, CellCoaddSerializationModel): - return ComponentSentinel.INVALID_COMPONENT_MODEL - match component: - case "psf": - bbox = self.pop_bbox_from_parameters() - return tree.deserialize_psf(archive, bbox=bbox) - case "provenance": - return tree.deserialize_provenance(archive) - case "backgrounds": - return tree.backgrounds.deserialize(archive) - return ComponentSentinel.UNRECOGNIZED_COMPONENT + kwargs = self.file_descriptor.parameters or {} + with self._open_archive_and_tree(uri, partial=bool(kwargs or component)) as (archive, tree): + if component is None: + result = tree.deserialize(archive, **kwargs) + result._opaque_metadata = archive.get_opaque_metadata() + return result + else: + return tree.deserialize_component(component, archive, **kwargs) diff --git a/python/lsst/images/json/formatters.py b/python/lsst/images/json/formatters.py deleted file mode 100644 index babe7ac9..00000000 --- a/python/lsst/images/json/formatters.py +++ /dev/null @@ -1,54 +0,0 @@ -# This file is part of lsst-images. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# Use of this source code is governed by a 3-clause BSD-style -# license that can be found in the LICENSE file. - -"""Deprecated re-export of the unified ``lsst.images.formatters`` module. - -`lsst.images.json.formatters.GenericFormatter` exists so that deployed -butler configs that point Transform and Projection storage classes at -this path keep working. The shim overrides ``default_extension`` to -``.json`` so writes default to JSON output when no ``format`` write -parameter is supplied. -""" - -from __future__ import annotations - -__all__ = ("GenericFormatter",) - -import warnings -from typing import Any, ClassVar - -from .. import formatters as _unified - - -def _warn(name: str) -> None: - warnings.warn( - f"lsst.images.json.formatters.{name} is deprecated; " - f"use lsst.images.formatters.{name} with format='json' " - f"instead. The json-only formatter forwards to the unified " - f"one and will be removed in a future release.", - DeprecationWarning, - stacklevel=3, - ) - - -class GenericFormatter(_unified.GenericFormatter): - """Deprecated alias for `lsst.images.formatters.GenericFormatter`. - - Defaults to ``.json`` output so existing butler configs that point - Transform/Projection storage classes here keep producing JSON - without specifying a ``format`` write parameter. - """ - - default_extension: ClassVar[str] = ".json" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - _warn("GenericFormatter") - super().__init__(*args, **kwargs) diff --git a/python/lsst/images/psfs/_gaussian.py b/python/lsst/images/psfs/_gaussian.py index 5e60f94a..856a0516 100644 --- a/python/lsst/images/psfs/_gaussian.py +++ b/python/lsst/images/psfs/_gaussian.py @@ -127,7 +127,13 @@ class GaussianPSFSerializationModel(serialization.ArchiveTree): description="The bounds object that represents the PSF's validity region." ) - def deserialize(self, archive: serialization.InputArchive[Any]) -> GaussianPointSpreadFunction: + def deserialize( + self, archive: serialization.InputArchive[Any], **kwargs: Any + ) -> GaussianPointSpreadFunction: + if kwargs: + raise serialization.InvalidParameterError( + f"Unrecognized parameters for GaussianPointSpreadFunction: {set(kwargs.keys())}." + ) return GaussianPointSpreadFunction( sigma=self.sigma, bounds=self.bounds.deserialize(), stamp_size=self.stamp_size ) diff --git a/python/lsst/images/psfs/_legacy.py b/python/lsst/images/psfs/_legacy.py index 93c785c7..83f5d8d8 100644 --- a/python/lsst/images/psfs/_legacy.py +++ b/python/lsst/images/psfs/_legacy.py @@ -174,12 +174,16 @@ class PSFExSerializationModel(serialization.ArchiveTree): model_config = pydantic.ConfigDict(ser_json_inf_nan="constants") - def deserialize(self, archive: serialization.InputArchive[Any]) -> PSFExWrapper: + def deserialize(self, archive: serialization.InputArchive[Any], **kwargs: Any) -> PSFExWrapper: """Deserialize the PSF from an archive. This method is intended to be usable as the callback function passed to `.serialization.InputArchive.deserialize_pointer`. """ + if kwargs: + raise serialization.InvalidParameterError( + f"Unrecognized parameters for PsfExWrapper: {set(kwargs.keys())}." + ) try: from lsst.meas.extensions.psfex import PsfexPsf, PsfexPsfSerializationData except ImportError: diff --git a/python/lsst/images/psfs/_piff.py b/python/lsst/images/psfs/_piff.py index 62ce8f2d..343edad7 100644 --- a/python/lsst/images/psfs/_piff.py +++ b/python/lsst/images/psfs/_piff.py @@ -224,12 +224,16 @@ class PiffSerializationModel(serialization.ArchiveTree): description="The bounds object that represents the PSF's validity region." ) - def deserialize(self, archive: serialization.InputArchive[Any]) -> PiffWrapper: + def deserialize(self, archive: serialization.InputArchive[Any], **kwargs: Any) -> PiffWrapper: """Deserialize the PSF from an archive. This method is intended to be usable as the callback function passed to `.serialization.InputArchive.deserialize_pointer`. """ + if kwargs: + raise serialization.InvalidParameterError( + f"Unrecognized parameters for PiffWrapper: {set(kwargs.keys())}." + ) try: from piff import PSF from piff.config import PiffLogger diff --git a/python/lsst/images/serialization/_common.py b/python/lsst/images/serialization/_common.py index 99c6a844..3eff2c22 100644 --- a/python/lsst/images/serialization/_common.py +++ b/python/lsst/images/serialization/_common.py @@ -15,6 +15,8 @@ "ArchiveReadError", "ArchiveTree", "ButlerInfo", + "InvalidComponentError", + "InvalidParameterError", "JsonRef", "MetadataValue", "OpaqueArchiveMetadata", @@ -93,10 +95,82 @@ class ArchiveTree( ) @abstractmethod - def deserialize(self, archive: InputArchive[Any]) -> Any: - """Return the in-memory object that was serialized to this tree.""" + def deserialize(self, archive: InputArchive[Any], **kwargs: Any) -> Any: + """Return the in-memory object that was serialized to this tree. + + Parameters + ---------- + archive + The input archive to read from. + **kwargs + Additional keyword arguments specific to this type. + + Raises + ------ + ~lsst.images.serialization.InvalidParameterError + Raised for unsupported ``**kwargs``. + + Notes + ----- + Subclass implementations may take additional keyword-only arguments. + Callers that invoke this method without knowing what those might be + should catch `TypeError` and re-raise as + `~lsst.images.serialization.InvalidParameterError` if they pass + additional keyword arguments. + """ raise NotImplementedError() + def deserialize_component(self, component: str, archive: InputArchive[Any], **kwargs: Any) -> Any: + """Return a component in-memory object that was serialized to this + tree. + + Parameters + ---------- + component + Name of the component to read. + archive + The input archive to read from. + **kwargs + Additional keyword arguments specific to this type. + + Raises + ------ + ~lsst.images.serialization.InvalidComponentError + Raise if ``component`` is not recognized. + ~lsst.images.serialization.InvalidParameterError + Raised for unsupported ``**kwargs``. + + Notes + ----- + The default implementation for this method tries to get an attribute + with the component's name from ``self``, and then: + + - returns `None` if it is `None`; + - calls `deserialize` on that object if it is also an + `~lsst.images.serialization.ArchiveTree`; + - returns it directly otherwise. + + If there is no such attribute, it raises + `~lsst.images.serialization.InvalidComponentError`. + + ``**kwargs`` are forwarded to component `deserialize` methods, but + are otherwise not checked. Subclasses are generally expected to + implement this method to do that checking and handle any components + for which the other will not work, and then delegate to `super` at + the end. + """ + try: + component_model = getattr(self, component) + except AttributeError: + raise InvalidComponentError( + f"Component {component!r} is not recognized by {type(self).__name__}." + ) from None + if component_model is None: + return None + if isinstance(component_model, ArchiveTree): + return component_model.deserialize(archive, **kwargs) + return component_model + class ReadResult[T: Any](NamedTuple): """A struct that can be used to return both a deserialized object and @@ -118,6 +192,19 @@ class ArchiveReadError(RuntimeError): """Exception raised when the contents of an archive cannot be read.""" +class InvalidParameterError(ArchiveReadError): + """Exception raised by `ArchiveTree.deserialize` or + `ArchiveTree.deserialize_component` when passed an invalid keyword + argument. + """ + + +class InvalidComponentError(ArchiveReadError): + """Exception `ArchiveTree.deserialize_component` when passed an invalid + component name. + """ + + class OpaqueArchiveMetadata(Protocol): """Interface for opaque archive metadata. diff --git a/tests/test_cell_coadd.py b/tests/test_cell_coadd.py index 7e20b65b..8e47af0d 100644 --- a/tests/test_cell_coadd.py +++ b/tests/test_cell_coadd.py @@ -13,15 +13,13 @@ import os import pickle -import tempfile import unittest from typing import Any import numpy as np -from lsst.images import YX, Box, Interval, fits, get_legacy_deep_coadd_mask_planes +from lsst.images import YX, Box, Interval, get_legacy_deep_coadd_mask_planes from lsst.images.cells import CellCoadd, CellIJ -from lsst.images.formatters import CellCoaddFormatter from lsst.images.tests import ( DP2_COADD_DATA_ID, DP2_COADD_MISSING_CELL, @@ -29,9 +27,7 @@ assert_masked_images_equal, assert_psfs_equal, compare_cell_coadd_to_legacy, - make_test_formatter, ) -from lsst.resources import ResourcePath DATA_DIR = os.environ.get("TESTDATA_IMAGES_DIR", None) @@ -152,27 +148,5 @@ def test_roundtrip(self) -> None: ) -@unittest.skipUnless(DATA_DIR is not None, "TESTDATA_IMAGES_DIR is not in the environment.") -class CellCoaddFormatterComponentReadTestCase(unittest.TestCase): - """CellCoaddFormatter reads psf/provenance components from FITS. - - Reuses `CellCoaddTestCase`'s class-level fixture rather than - inheriting from it, so the parent's tests don't run twice. - """ - - @classmethod - def setUpClass(cls) -> None: - CellCoaddTestCase.setUpClass() - cls.cell_coadd = CellCoaddTestCase.cell_coadd - - def test_fits_psf_component(self): - with tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False) as tmp: - tmp.close() - fits.write(self.cell_coadd, tmp.name) - formatter = make_test_formatter(CellCoaddFormatter, CellCoadd) - psf = formatter._read_component_from_uri("psf", ResourcePath(tmp.name)) - self.assertIsNotNone(psf) - - if __name__ == "__main__": unittest.main() diff --git a/tests/test_formatters.py b/tests/test_formatters.py deleted file mode 100644 index 738e09d6..00000000 --- a/tests/test_formatters.py +++ /dev/null @@ -1,379 +0,0 @@ -# This file is part of lsst-images. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# Use of this source code is governed by a 3-clause BSD-style -# license that can be found in the LICENSE file. - -from __future__ import annotations - -import tempfile -import unittest -import warnings - -import numpy as np - -from lsst.images import ( - Box, - Image, - MaskedImage, - MaskPlane, - MaskSchema, - VisitImage, - fits, -) -from lsst.images import json as images_json -from lsst.images.fits import formatters as fits_shim -from lsst.images.fits._common import PointerModel -from lsst.images.fits._input_archive import FitsInputArchive -from lsst.images.formatters import ( - _BACKENDS, - GenericFormatter, - ImageFormatter, - MaskedImageFormatter, - VisitImageFormatter, -) -from lsst.images.json import formatters as json_shim -from lsst.images.tests import make_test_formatter -from lsst.resources import ResourcePath - -try: - from lsst.images import ndf - from lsst.images.ndf._common import NdfPointerModel - from lsst.images.ndf._input_archive import NdfInputArchive - - HAVE_H5PY = True -except ImportError: - HAVE_H5PY = False - - -class BackendsTableTestCase(unittest.TestCase): - """The private _BACKENDS table wires extension -> read/write/archive.""" - - def test_table_keys(self): - expected = {".fits", ".json"} - if HAVE_H5PY: - expected.add(".sdf") - self.assertEqual(set(_BACKENDS), expected) - - def test_fits_backend_wires_fits_read_write(self): - backend = _BACKENDS[".fits"] - self.assertIs(backend.read, fits.read) - self.assertIs(backend.write, fits.write) - self.assertIs(backend.input_archive, FitsInputArchive) - self.assertIs(backend.pointer_model, PointerModel) - - @unittest.skipUnless(HAVE_H5PY, "h5py is not installed") - def test_sdf_backend_wires_ndf_read_write(self): - backend = _BACKENDS[".sdf"] - self.assertIs(backend.read, ndf.read) - self.assertIs(backend.write, ndf.write) - self.assertIs(backend.input_archive, NdfInputArchive) - self.assertIs(backend.pointer_model, NdfPointerModel) - - def test_json_backend_wires_json_read_write_no_archive(self): - backend = _BACKENDS[".json"] - self.assertIs(backend.read, images_json.read) - self.assertIs(backend.write, images_json.write) - self.assertIsNone(backend.input_archive) - self.assertIsNone(backend.pointer_model) - - -class GetWriteExtensionTestCase(unittest.TestCase): - """`get_write_extension` reads the `format` write parameter.""" - - def _make_formatter(self, write_parameters: dict[str, str] | None = None): - return make_test_formatter(GenericFormatter, Image, write_parameters=write_parameters) - - def test_default_returns_fits(self): - formatter = self._make_formatter() - self.assertEqual(formatter.get_write_extension(), ".fits") - - def test_explicit_fits(self): - formatter = self._make_formatter({"format": "fits"}) - self.assertEqual(formatter.get_write_extension(), ".fits") - - def test_explicit_json(self): - formatter = self._make_formatter({"format": "json"}) - self.assertEqual(formatter.get_write_extension(), ".json") - - @unittest.skipUnless(HAVE_H5PY, "h5py is not installed") - def test_explicit_sdf(self): - formatter = self._make_formatter({"format": "sdf"}) - self.assertEqual(formatter.get_write_extension(), ".sdf") - - def test_unknown_format_raises(self): - formatter = self._make_formatter({"format": "pickle"}) - with self.assertRaisesRegex(RuntimeError, "is not supported"): - formatter.get_write_extension() - - def test_recipe_with_non_fits_format_raises(self): - # `recipe` is FITS-only; using it with format=json must error. - formatter = self._make_formatter({"format": "json", "recipe": "default"}) - with self.assertRaisesRegex(RuntimeError, "only valid for FITS"): - formatter._validate_write_parameters() - - -class ExtensionFromUriTestCase(unittest.TestCase): - """`read_from_uri` routes based on `uri.getExtension()`.""" - - def _make_formatter(self): - return make_test_formatter(GenericFormatter, Image) - - def test_fits(self): - formatter = self._make_formatter() - uri = ResourcePath("/tmp/x.fits") - self.assertEqual(formatter._extension_from_uri(uri), ".fits") - - @unittest.skipUnless(HAVE_H5PY, "h5py is not installed") - def test_sdf(self): - formatter = self._make_formatter() - uri = ResourcePath("/tmp/x.sdf") - self.assertEqual(formatter._extension_from_uri(uri), ".sdf") - - def test_json(self): - formatter = self._make_formatter() - uri = ResourcePath("/tmp/x.json") - self.assertEqual(formatter._extension_from_uri(uri), ".json") - - def test_unknown(self): - formatter = self._make_formatter() - uri = ResourcePath("/tmp/x.pickle") - with self.assertRaisesRegex(RuntimeError, "unsupported extension"): - formatter._extension_from_uri(uri) - - def test_compressed_fits_unsupported(self): - # We don't claim to handle .fits.gz; getExtension returns - # '.fits.gz' and the lookup misses. - formatter = self._make_formatter() - uri = ResourcePath("/tmp/x.fits.gz") - with self.assertRaisesRegex(RuntimeError, "unsupported extension"): - formatter._extension_from_uri(uri) - - -class ImageFormatterFullReadTestCase(unittest.TestCase): - """`read_from_uri(component=None)` round-trips each backend.""" - - def _make_image(self): - return Image( - np.arange(20, dtype=np.float32).reshape(4, 5), - bbox=Box.factory[10:14, 20:25], - ) - - def test_fits_full_read(self): - image = self._make_image() - with tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False) as tmp: - tmp.close() - fits.write(image, tmp.name) - formatter = make_test_formatter(ImageFormatter, Image) - result = formatter.read_from_uri(ResourcePath(tmp.name)) - np.testing.assert_array_equal(result.array, image.array) - - @unittest.skipUnless(HAVE_H5PY, "h5py is not installed") - def test_sdf_full_read(self): - image = self._make_image() - with tempfile.NamedTemporaryFile(suffix=".sdf", delete_on_close=False) as tmp: - tmp.close() - ndf.write(image, tmp.name) - formatter = make_test_formatter(ImageFormatter, Image) - result = formatter.read_from_uri(ResourcePath(tmp.name)) - np.testing.assert_array_equal(result.array, image.array) - - def test_json_full_read(self): - image = self._make_image() - with tempfile.NamedTemporaryFile(suffix=".json", delete_on_close=False) as tmp: - tmp.close() - images_json.write(image, tmp.name) - formatter = make_test_formatter(ImageFormatter, Image) - result = formatter.read_from_uri(ResourcePath(tmp.name)) - np.testing.assert_array_equal(result.array, image.array) - - -class ImageFormatterComponentReadTestCase(unittest.TestCase): - """ImageFormatter routes component reads per extension.""" - - def _make_image(self): - return Image( - np.arange(20, dtype=np.float32).reshape(4, 5), - bbox=Box.factory[10:14, 20:25], - ) - - def test_fits_bbox_component(self): - image = self._make_image() - with tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False) as tmp: - tmp.close() - fits.write(image, tmp.name) - formatter = make_test_formatter(ImageFormatter, Image) - bbox = formatter._read_component_from_uri("bbox", ResourcePath(tmp.name)) - self.assertEqual(bbox, image.bbox) - - @unittest.skipUnless(HAVE_H5PY, "h5py is not installed") - def test_sdf_bbox_component(self): - image = self._make_image() - with tempfile.NamedTemporaryFile(suffix=".sdf", delete_on_close=False) as tmp: - tmp.close() - ndf.write(image, tmp.name) - formatter = make_test_formatter(ImageFormatter, Image) - bbox = formatter._read_component_from_uri("bbox", ResourcePath(tmp.name)) - self.assertEqual(bbox, image.bbox) - - def test_json_bbox_component_via_whole_object(self): - image = self._make_image() - with tempfile.NamedTemporaryFile(suffix=".json", delete_on_close=False) as tmp: - tmp.close() - images_json.write(image, tmp.name) - formatter = make_test_formatter(ImageFormatter, Image) - bbox = formatter._read_component_from_uri("bbox", ResourcePath(tmp.name)) - self.assertEqual(bbox, image.bbox) - - def test_json_unknown_component_raises(self): - image = self._make_image() - with tempfile.NamedTemporaryFile(suffix=".json", delete_on_close=False) as tmp: - tmp.close() - images_json.write(image, tmp.name) - formatter = make_test_formatter(ImageFormatter, Image) - with self.assertRaises(NotImplementedError): - formatter._read_component_from_uri("nonexistent", ResourcePath(tmp.name)) - - -class MaskedImageFormatterComponentReadTestCase(unittest.TestCase): - """MaskedImageFormatter routes image/mask/variance per extension.""" - - def _make_masked_image(self): - rng = np.random.default_rng(11) - return MaskedImage( - Image(rng.normal(100.0, 8.0, size=(10, 12)), start=(0, 0)), - mask_schema=MaskSchema([MaskPlane("BAD", "bad pixel")]), - ) - - def test_fits_image_component(self): - mi = self._make_masked_image() - with tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False) as tmp: - tmp.close() - fits.write(mi, tmp.name) - formatter = make_test_formatter(MaskedImageFormatter, MaskedImage) - image = formatter._read_component_from_uri("image", ResourcePath(tmp.name)) - self.assertEqual(image.bbox, mi.image.bbox) - - @unittest.skipUnless(HAVE_H5PY, "h5py is not installed") - def test_sdf_mask_component(self): - mi = self._make_masked_image() - with tempfile.NamedTemporaryFile(suffix=".sdf", delete_on_close=False) as tmp: - tmp.close() - ndf.write(mi, tmp.name) - formatter = make_test_formatter(MaskedImageFormatter, MaskedImage) - mask = formatter._read_component_from_uri("mask", ResourcePath(tmp.name)) - self.assertEqual(mask.bbox, mi.mask.bbox) - - def test_json_variance_component_via_whole_object(self): - mi = self._make_masked_image() - with tempfile.NamedTemporaryFile(suffix=".json", delete_on_close=False) as tmp: - tmp.close() - images_json.write(mi, tmp.name) - formatter = make_test_formatter(MaskedImageFormatter, MaskedImage) - variance = formatter._read_component_from_uri("variance", ResourcePath(tmp.name)) - self.assertEqual(variance.bbox, mi.variance.bbox) - - -class VisitImageFormatterComponentReadTestCase(unittest.TestCase): - """VisitImageFormatter reads VisitImage-specific components.""" - - def _make_visit_image(self): - # Reuse the existing test helper from tests/test_visit_image.py. - # Pytest places the tests directory on sys.path, so import the - # sibling module by its bare name. - from test_visit_image import VisitImageTestCase # local import - - VisitImageTestCase.setUpClass() - return VisitImageTestCase.visit_image - - def test_fits_summary_stats_component(self): - vi = self._make_visit_image() - with tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False) as tmp: - tmp.close() - fits.write(vi, tmp.name) - formatter = make_test_formatter(VisitImageFormatter, VisitImage) - summary = formatter._read_component_from_uri("summary_stats", ResourcePath(tmp.name)) - self.assertEqual(summary, vi.summary_stats) - - @unittest.skipUnless(HAVE_H5PY, "h5py is not installed") - def test_sdf_psf_component(self): - vi = self._make_visit_image() - with tempfile.NamedTemporaryFile(suffix=".sdf", delete_on_close=False) as tmp: - tmp.close() - ndf.write(vi, tmp.name) - formatter = make_test_formatter(VisitImageFormatter, VisitImage) - psf = formatter._read_component_from_uri("psf", ResourcePath(tmp.name)) - self.assertEqual(type(psf), type(vi.psf)) - - def test_json_aperture_corrections_via_whole_object(self): - vi = self._make_visit_image() - with tempfile.NamedTemporaryFile(suffix=".json", delete_on_close=False) as tmp: - tmp.close() - images_json.write(vi, tmp.name) - formatter = make_test_formatter(VisitImageFormatter, VisitImage) - ap = formatter._read_component_from_uri("aperture_corrections", ResourcePath(tmp.name)) - # ChebyshevField has no __eq__; compare keys and types. - self.assertEqual(ap.keys(), vi.aperture_corrections.keys()) - for k, v in vi.aperture_corrections.items(): - self.assertEqual(type(ap[k]), type(v)) - - -class FitsDeprecationShimTestCase(unittest.TestCase): - """lsst.images.fits.formatters is a deprecation shim.""" - - def test_image_formatter_warns(self): - with warnings.catch_warnings(record=True) as recorded: - warnings.simplefilter("always") - make_test_formatter(fits_shim.ImageFormatter, Image) - self.assertTrue( - any( - issubclass(w.category, DeprecationWarning) - and "fits.formatters.ImageFormatter is deprecated" in str(w.message) - for w in recorded - ), - f"No deprecation warning observed; got: {[str(w.message) for w in recorded]}", - ) - - def test_subclass_is_unified_class(self): - from lsst.images import formatters as unified - - self.assertTrue(issubclass(fits_shim.GenericFormatter, unified.GenericFormatter)) - self.assertTrue(issubclass(fits_shim.ImageFormatter, unified.ImageFormatter)) - self.assertTrue(issubclass(fits_shim.MaskedImageFormatter, unified.MaskedImageFormatter)) - self.assertTrue(issubclass(fits_shim.VisitImageFormatter, unified.VisitImageFormatter)) - self.assertTrue(issubclass(fits_shim.CellCoaddFormatter, unified.CellCoaddFormatter)) - - -class JsonDeprecationShimTestCase(unittest.TestCase): - """lsst.images.json.formatters is a deprecation shim. - - The shim defaults to ``.json`` output. - """ - - def test_generic_formatter_warns(self): - with warnings.catch_warnings(record=True) as recorded: - warnings.simplefilter("always") - make_test_formatter(json_shim.GenericFormatter, Image) - self.assertTrue( - any( - issubclass(w.category, DeprecationWarning) - and "json.formatters.GenericFormatter is deprecated" in str(w.message) - for w in recorded - ) - ) - - def test_default_extension_is_json(self): - self.assertEqual(json_shim.GenericFormatter.default_extension, ".json") - - def test_default_write_extension_is_json(self): - formatter = make_test_formatter(json_shim.GenericFormatter, Image) - self.assertEqual(formatter.get_write_extension(), ".json") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_image.py b/tests/test_image.py index eea7d7fc..9b20762e 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -17,7 +17,6 @@ import astropy.io.fits import astropy.units as u import numpy as np -from astro_metadata_translator import ObservationInfo import lsst.utils.tests from lsst.images import Box, DetectorFrame, Image @@ -168,7 +167,6 @@ def test_read_write(self): """ data = np.array([[1.0, 2.0, np.nan, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]) md = {"int": 1, "float": 42.0, "bool": False, "long string header": "This is a string"} - obsinfo = ObservationInfo(telescope="Simonyi", instrument="LSSTCam", relative_humidity=23.5) det_frame = DetectorFrame(instrument="Inst", visit=1234, detector=1, bbox=Box.factory[1:4096, 1:4096]) rng = np.random.default_rng(500) projection = make_random_projection(rng, det_frame, Box.factory[1:4096, 1:4096]) @@ -177,7 +175,6 @@ def test_read_write(self): data, unit=u.dn, metadata=md, - obs_info=obsinfo, bbox=Box.factory[-2:1, 3:7], projection=projection, ) @@ -189,7 +186,6 @@ def test_read_write(self): self.assertEqual(new, image) # __eq__ does not test all components. - self.assertEqual(new.obs_info, image.obs_info) self.assertEqual(new.metadata, image.metadata) self.maxDiff = None assert_projections_equal(self, new.projection, image.projection, expect_identity=False) diff --git a/tests/test_mask.py b/tests/test_mask.py index df92cfdd..340f4e2d 100644 --- a/tests/test_mask.py +++ b/tests/test_mask.py @@ -15,7 +15,6 @@ import unittest import numpy as np -from astro_metadata_translator import ObservationInfo import lsst.utils.tests from lsst.images import Box, Mask, MaskPlane, MaskSchema, get_legacy_visit_image_mask_planes @@ -88,13 +87,11 @@ def test_basics(self) -> None: schema=schema, bbox=bbox, metadata={"four_and_a_half": 4.5}, - obs_info=ObservationInfo(instrument="LSSTCam"), ) self.assertIs(mask[...], mask) self.assertEqual(mask.__eq__(42), NotImplemented) self.assertEqual(mask, mask) - self.assertEqual(mask.obs_info.instrument, "LSSTCam") self.maxDiff = None self.assertEqual( str(mask), @@ -136,7 +133,6 @@ def test_read_write(self) -> None: schema=schema, bbox=bbox, metadata={"four_and_a_half": 4.5}, - obs_info=ObservationInfo(instrument="LSSTCam"), ) with lsst.utils.tests.getTempFilePath(".fits") as tmpFile: mask.write_fits(tmpFile) @@ -145,8 +141,6 @@ def test_read_write(self) -> None: self.assertEqual(new, mask) # __eq__ ignores metadata. self.assertEqual(new.metadata["four_and_a_half"], 4.5) - self.assertEqual(new.obs_info.instrument, "LSSTCam") - self.assertEqual(new.obs_info, mask.obs_info) self.assertEqual(new.metadata, mask.metadata) def test_serialize_multi(self) -> None: diff --git a/tests/test_masked_image.py b/tests/test_masked_image.py index 9b44271a..09b73882 100644 --- a/tests/test_masked_image.py +++ b/tests/test_masked_image.py @@ -18,7 +18,6 @@ import astropy.io.fits import astropy.units as u import numpy as np -from astro_metadata_translator import ObservationInfo from lsst.images import Box, Image, MaskedImage, MaskPlane, MaskSchema, get_legacy_visit_image_mask_planes from lsst.images.fits import FitsCompressionOptions @@ -45,7 +44,6 @@ class MaskedImageTestCase(unittest.TestCase): def setUp(self) -> None: self.maxDiff = None self.rng = np.random.default_rng(500) - self.obs_info = ObservationInfo(instrument="LSSTCam", detector_num=4) self.masked_image = MaskedImage( Image(self.rng.normal(100.0, 8.0, size=(200, 251)), dtype=np.float64, unit=u.nJy, start=(5, 8)), mask_schema=MaskSchema( @@ -55,7 +53,6 @@ def setUp(self) -> None: ] ), metadata={"fifty": "5 * 10"}, - obs_info=self.obs_info, ) self.masked_image.mask.array |= np.multiply.outer( self.masked_image.image.array < 102.0, @@ -78,7 +75,6 @@ def test_construction(self) -> None: self.assertEqual(self.masked_image.unit, u.nJy) self.assertEqual(self.masked_image.variance.unit, u.nJy**2) self.assertEqual(self.masked_image.metadata, {"fifty": "5 * 10"}) - self.assertEqual(self.masked_image.obs_info.instrument, "LSSTCam") # The checks below are subject to the vagaries of the RNG, but we want # the seed to be such that they all pass, or other tests will be # weaker. @@ -220,7 +216,6 @@ def test_round_trip_ndf_incompatible_mask(self): start=(0, 0), ), mask_schema=MaskSchema(planes), - obs_info=self.obs_info, ) wide.variance.array = rng.normal(64.0, 0.5, size=wide.bbox.shape) with RoundtripNdf(self, wide) as roundtrip: @@ -239,7 +234,6 @@ def test_round_trip_ndf_many_plane_mask(self): start=(0, 0), ), mask_schema=MaskSchema(planes), - obs_info=self.obs_info, ) wide.mask.set("P0", wide.image.array > 100.0) wide.mask.set("P17", wide.image.array < 95.0) diff --git a/tests/test_visit_image.py b/tests/test_visit_image.py index 50cc11d4..ca88ffcc 100644 --- a/tests/test_visit_image.py +++ b/tests/test_visit_image.py @@ -403,16 +403,13 @@ def test_component_reads(self) -> None: assert_projections_equal(self, proj, visit.projection, expect_identity=False) image = VisitImage.read_legacy(self.filename, component="image") self.assertEqual(image, visit.image) - self.check_legacy_obs_info(image.obs_info) assert_projections_equal(self, proj, image.projection, expect_identity=False) variance = VisitImage.read_legacy(self.filename, component="variance") self.assertEqual(variance, visit.variance) assert_projections_equal(self, proj, variance.projection, expect_identity=False) - self.check_legacy_obs_info(variance.obs_info) mask = VisitImage.read_legacy(self.filename, component="mask") self.assertEqual(mask, visit.mask) assert_projections_equal(self, proj, mask.projection, expect_identity=False) - self.check_legacy_obs_info(mask.obs_info) psf = VisitImage.read_legacy(self.filename, component="psf") self.assertIsInstance(psf, PointSpreadFunction) obs_info = VisitImage.read_legacy(self.filename, component="obs_info") @@ -677,39 +674,39 @@ def test_convert_unit(self) -> None: os.path.join(EXTERNAL_DATA_DIR, "dp2", "legacy", "visit_summary.fits") ) legacy_photo_calib = visit_summary.find(DP2_VISIT_DETECTOR_DATA_ID["detector"]).getPhotoCalib() - self.visit_image.photometric_scaling = field_from_legacy_photo_calib( + visit_image_nJy.photometric_scaling = field_from_legacy_photo_calib( legacy_photo_calib, bounds=self.visit_image.detector.bbox, instrumental_unit=u.electron ) compare_photo_calib_to_legacy( self, - self.visit_image.photometric_scaling, + visit_image_nJy.photometric_scaling, self.legacy_exposure.getPhotoCalib(), applied_legacy_photo_calib=legacy_photo_calib, - subimage_bbox=self.visit_image.bbox, + subimage_bbox=visit_image_nJy.bbox, ) # We still can't convert to completely unrelated units. with self.assertRaises(u.UnitConversionError): - self.visit_image.convert_unit(u.mm) + visit_image_nJy.convert_unit(u.mm) # Uncalibrating via the photometric_scaling matches what legacy code # does, and by default it copies everything. with self.assertRaises(u.UnitConversionError): - self.visit_image.convert_unit(u.electron, copy=False) + visit_image_nJy.convert_unit(u.electron, copy=False) legacy_masked_image_e = legacy_photo_calib.uncalibrateImage(self.legacy_exposure.maskedImage) - visit_image_e = self.visit_image.convert_unit(u.electron) + visit_image_e = visit_image_nJy.convert_unit(u.electron) assert_close(self, visit_image_e.image.array, legacy_masked_image_e.image.array) assert_close(self, visit_image_e.variance.array, legacy_masked_image_e.variance.array) - self.assertFalse(np.may_share_memory(visit_image_e.mask.array, self.visit_image.mask.array)) + self.assertFalse(np.may_share_memory(visit_image_e.mask.array, visit_image_nJy.mask.array)) # We can also uncalibrate if we start with an image that has units # that are compatible with the photometric_scaling but not identical # to it. - visit_image_mJy.photometric_scaling = self.visit_image.photometric_scaling + visit_image_mJy.photometric_scaling = visit_image_nJy.photometric_scaling visit_image_e = visit_image_mJy.convert_unit(u.electron) assert_close(self, visit_image_e.image.array, legacy_masked_image_e.image.array) assert_close(self, visit_image_e.variance.array, legacy_masked_image_e.variance.array) # We can re-apply the scaling go go back to calibrated units. - visit_image_nJy = visit_image_e.convert_unit(u.nJy) - assert_close(self, visit_image_nJy.image.array, self.visit_image.image.array) - assert_close(self, visit_image_nJy.variance.array, self.visit_image.variance.array) + visit_image_nJy_2 = visit_image_e.convert_unit(u.nJy) + assert_close(self, visit_image_nJy_2.image.array, visit_image_nJy.image.array) + assert_close(self, visit_image_nJy_2.variance.array, self.visit_image.variance.array) @unittest.skipUnless(EXTERNAL_DATA_DIR is not None, "TESTDATA_IMAGES_DIR is not in the environment.")