From a8c9a08e2dd71c711aca8c0c280e6457075688fe Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 29 Jan 2026 12:57:08 -0800 Subject: [PATCH 1/2] first pass Signed-off-by: Yee Hing Tong --- dc_with_mashu.py | 30 ++ dev-requirements.in | 3 + flytekit/configuration/__init__.py | 12 +- flytekit/core/type_engine.py | 259 ++++++++++++++---- flytekit/extras/pytorch/checkpoint.py | 4 +- flytekit/extras/tensorflow/record.py | 4 +- flytekit/interaction/click_types.py | 10 +- flytekit/types/directory/types.py | 7 +- flytekit/types/file/file.py | 4 +- flytekit/types/schema/types.py | 4 +- .../types/structured/structured_dataset.py | 6 +- pyproject.toml | 1 - .../unit/interaction/test_click_types.py | 4 +- 13 files changed, 256 insertions(+), 92 deletions(-) create mode 100644 dc_with_mashu.py diff --git a/dc_with_mashu.py b/dc_with_mashu.py new file mode 100644 index 0000000000..998a6fdb9a --- /dev/null +++ b/dc_with_mashu.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from dataclasses_json import DataClassJsonMixin # dataclasses-json +from mashumaro import DataClassDictMixin # mashumaro +from mashumaro.codecs.json import JSONEncoder, JSONDecoder + + +@dataclass +class User(DataClassJsonMixin, DataClassDictMixin): + id: int + + +def main(): + user = User(id=42) + + print("=== Mashumaro serialize (codec) ===") + enc = JSONEncoder(User) + s = enc.encode(user) + print("Serialized:", s) + + print("\n=== Mashumaro deserialize (codec) ===") + dec = JSONDecoder(User) + u2 = dec.decode(s) + print("Deserialized:", u2, "type:", type(u2)) + + print("\n=== dataclasses-json still exists (separate API) ===") + print("dataclasses-json to_json():", user.to_json()) + + +if __name__ == "__main__": + main() diff --git a/dev-requirements.in b/dev-requirements.in index cef6ce1929..7172acb323 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -62,3 +62,6 @@ ipykernel orjson kubernetes>=12.0.1 httpx + +# dataclasses-json for backward compatibility testing +dataclasses-json>=0.6.7 diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 61818a0e36..f44c140dae 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -124,7 +124,7 @@ from typing import Dict, List, Optional import yaml -from dataclasses_json import DataClassJsonMixin +from mashumaro.mixins.json import DataClassJSONMixin from flytekit.configuration import internal as _internal from flytekit.configuration.default_images import DefaultImages @@ -146,7 +146,7 @@ @dataclass(init=True, repr=True, eq=True, frozen=True) -class Image(DataClassJsonMixin): +class Image(DataClassJSONMixin): """ Image is a structured wrapper for task container images used in object serialization. @@ -236,7 +236,7 @@ def _parse_image_identifier(image_identifier: str) -> typing.Tuple[str, Optional @dataclass(init=True, repr=True, eq=True, frozen=True) -class ImageConfig(DataClassJsonMixin): +class ImageConfig(DataClassJSONMixin): """ We recommend you to use ImageConfig.auto(img_name=None) to create an ImageConfig. For example, ImageConfig.auto(img_name=""ghcr.io/flyteorg/flytecookbook:v1.0.0"") will create an ImageConfig. @@ -792,7 +792,7 @@ def for_endpoint( @dataclass -class EntrypointSettings(DataClassJsonMixin): +class EntrypointSettings(DataClassJSONMixin): """ This object carries information about the path of the entrypoint command that will be invoked at runtime. This is where `pyflyte-execute` code can be found. This is useful for cases like pyspark execution. 
@@ -802,7 +802,7 @@ class EntrypointSettings(DataClassJsonMixin): @dataclass -class FastSerializationSettings(DataClassJsonMixin): +class FastSerializationSettings(DataClassJSONMixin): """ This object hold information about settings necessary to serialize an object so that it can be fast-registered. """ @@ -817,7 +817,7 @@ class FastSerializationSettings(DataClassJsonMixin): # TODO: ImageConfig, python_interpreter, venv_root, fast_serialization_settings.destination_dir should be combined. @dataclass -class SerializationSettings(DataClassJsonMixin): +class SerializationSettings(DataClassJSONMixin): """ These settings are provided while serializing a workflow and task, before registration. This is required to get runtime information at serialization time, as well as some defaults. diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 24a78f184b..30a5c84043 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -22,7 +22,6 @@ import msgpack from cachetools import LRUCache -from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import literals_pb2 from fsspec.asyn import _run_coros_in_chunks # pylint: disable=W0212 from google.protobuf import json_format as _json_format @@ -68,8 +67,170 @@ # In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True. # This is relevant for cases like Dict[int, str]. # If strict_map_key=False is not used, the decoder will raise an error when trying to decode keys that are not strictly typed.` + + +def _filter_dataclass_json_config(obj: Any) -> Any: + """Recursively remove dataclass_json_config from dicts. + + This provides compatibility with dataclasses-json 0.6.x which adds + dataclass_json_config as a class annotation that mashumaro picks up + during serialization. + """ + if isinstance(obj, dict): + return {k: _filter_dataclass_json_config(v) for k, v in obj.items() if k != "dataclass_json_config"} + elif isinstance(obj, list): + return [_filter_dataclass_json_config(item) for item in obj] + return obj + + def _default_msgpack_decoder(data: bytes) -> Any: - return msgpack.unpackb(data, strict_map_key=False) + result = msgpack.unpackb(data, strict_map_key=False) + return _filter_dataclass_json_config(result) + + +# Lock for thread-safe patching of DataClassJsonMixin.__annotations__ +_dataclass_json_mixin_lock = threading.Lock() + +# Cache for JSONDecoders and encoders used with dataclass_json_config patching +_patched_decoders: Dict[type, JSONDecoder] = {} +_patched_json_encoders: Dict[type, JSONEncoder] = {} +_patched_msgpack_encoders: Dict[type, MessagePackEncoder] = {} + + +def _deserialize_with_dataclass_json_config_patch( + expected_python_type: type, json_str: str, decoder_cache: Optional[Dict[type, Any]] = None +) -> Any: + """Deserialize JSON to a dataclass, handling dataclasses-json 0.6.x compatibility. + + dataclasses-json 0.6.x adds dataclass_json_config to DataClassJsonMixin's __annotations__. + This causes mashumaro's JSONDecoder to fail because it expects this field in the data. + + For classes inheriting from dataclasses-json's mixin, this function temporarily removes + the annotation from the parent class, creates a decoder, deserializes, and restores it. 
+ + Args: + expected_python_type: The dataclass type to deserialize to + json_str: The JSON string to deserialize + decoder_cache: Optional cache dict to store/retrieve decoders (for non-patched cases) + """ + try: + from dataclasses_json import DataClassJsonMixin as DCJsonMixin + except ImportError: + DCJsonMixin = None # type: ignore + + # Check if the expected type inherits from dataclasses-json's mixin + needs_patch = DCJsonMixin is not None and issubclass(expected_python_type, DCJsonMixin) + + if not needs_patch: + # Not using dataclasses-json, use JSONDecoder directly with caching + if decoder_cache is not None: + try: + decoder = decoder_cache[expected_python_type] + except KeyError: + decoder = JSONDecoder(expected_python_type) + decoder_cache[expected_python_type] = decoder + else: + decoder = JSONDecoder(expected_python_type) + return decoder.decode(json_str) + + # For dataclasses-json mixin classes, we need to patch temporarily + # Check if we already have a cached decoder (created with patch) + if expected_python_type in _patched_decoders: + return _patched_decoders[expected_python_type].decode(json_str) + + # Temporarily patch DataClassJsonMixin.__annotations__ to remove dataclass_json_config + with _dataclass_json_mixin_lock: + original_annotations = DCJsonMixin.__annotations__.copy() + DCJsonMixin.__annotations__ = {} + try: + decoder = JSONDecoder(expected_python_type) + _patched_decoders[expected_python_type] = decoder + return decoder.decode(json_str) + finally: + DCJsonMixin.__annotations__ = original_annotations + + +def _get_msgpack_encoder_with_patch(python_type: type, encoder_cache: Optional[Dict[type, Any]] = None) -> MessagePackEncoder: + """Get a MessagePackEncoder for a dataclass, handling dataclasses-json 0.6.x compatibility. + + Similar to _deserialize_with_dataclass_json_config_patch but for encoding. 
+ """ + try: + from dataclasses_json import DataClassJsonMixin as DCJsonMixin + except ImportError: + DCJsonMixin = None # type: ignore + + # Only check for mixin inheritance if python_type is actually a class (not a generic like Dict[str, int]) + needs_patch = ( + DCJsonMixin is not None + and isinstance(python_type, type) + and issubclass(python_type, DCJsonMixin) + ) + + if not needs_patch: + if encoder_cache is not None: + try: + return encoder_cache[python_type] + except KeyError: + encoder = MessagePackEncoder(python_type) + encoder_cache[python_type] = encoder + return encoder + return MessagePackEncoder(python_type) + + # Check cache first + if python_type in _patched_msgpack_encoders: + return _patched_msgpack_encoders[python_type] + + # Create encoder with patch + with _dataclass_json_mixin_lock: + original_annotations = DCJsonMixin.__annotations__.copy() + DCJsonMixin.__annotations__ = {} + try: + encoder = MessagePackEncoder(python_type) + _patched_msgpack_encoders[python_type] = encoder + return encoder + finally: + DCJsonMixin.__annotations__ = original_annotations + + +def _get_json_encoder_with_patch(python_type: type, encoder_cache: Optional[Dict[type, Any]] = None) -> JSONEncoder: + """Get a JSONEncoder for a dataclass, handling dataclasses-json 0.6.x compatibility.""" + try: + from dataclasses_json import DataClassJsonMixin as DCJsonMixin + except ImportError: + DCJsonMixin = None # type: ignore + + # Only check for mixin inheritance if python_type is actually a class (not a generic like Dict[str, int]) + needs_patch = ( + DCJsonMixin is not None + and isinstance(python_type, type) + and issubclass(python_type, DCJsonMixin) + ) + + if not needs_patch: + if encoder_cache is not None: + try: + return encoder_cache[python_type] + except KeyError: + encoder = JSONEncoder(python_type) + encoder_cache[python_type] = encoder + return encoder + return JSONEncoder(python_type) + + # Check cache first + if python_type in _patched_json_encoders: + return _patched_json_encoders[python_type] + + # Create encoder with patch + with _dataclass_json_mixin_lock: + original_annotations = DCJsonMixin.__annotations__.copy() + DCJsonMixin.__annotations__ = {} + try: + encoder = JSONEncoder(python_type) + _patched_json_encoders[python_type] = encoder + return encoder + finally: + DCJsonMixin.__annotations__ = original_annotations class BatchSize: @@ -621,35 +782,13 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict() except Exception as e: - logger.error( - f"Failed to extract schema for object {t}, error: {e}\n" - f"Please remove `DataClassJsonMixin` and `dataclass_json` decorator from the dataclass definition" + # https://github.com/lovasoa/marshmallow_dataclass/issues/13 + logger.warning( + f"Failed to extract schema for object {t}, (will run schemaless) error: {e}" + f"If you have postponed annotations turned on (PEP 563) turn it off please. Postponed" + f"evaluation doesn't work with json dataclasses" ) - if schema is None: - try: - # This produce JSON SCHEMA draft 2020-12 - from marshmallow_enum import EnumField, LoadDumpOptions - - if issubclass(t, DataClassJsonMixin): - s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema() - for _, v in s.fields.items(): - # marshmallow-jsonschema only supports enums loaded by name. 
- # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 - if isinstance(v, EnumField): - v.load_by = LoadDumpOptions.name - # check if DataClass mixin - from marshmallow_jsonschema import JSONSchema - - schema = JSONSchema().dump(s) - except Exception as e: - # https://github.com/lovasoa/marshmallow_dataclass/issues/13 - logger.warning( - f"Failed to extract schema for object {t}, (will run schemaless) error: {e}" - f"If you have postponed annotations turned on (PEP 563) turn it off please. Postponed" - f"evaluation doesn't work with json dataclasses" - ) - # Recursively construct the dataclass_type which contains the literal type of each field literal_type = {} @@ -699,11 +838,8 @@ def to_generic_literal( if isinstance(python_val, DataClassJSONMixin): json_str = python_val.to_json() else: - try: - encoder = self._json_encoder[python_type] - except KeyError: - encoder = JSONEncoder(python_type) - self._json_encoder[python_type] = encoder + # Use the patching helper to handle dataclasses-json 0.6.x compatibility. + encoder = _get_json_encoder_with_patch(python_type, encoder_cache=self._json_encoder) try: json_str = encoder.encode(python_val) @@ -741,11 +877,8 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp else: # The function looks up or creates a MessagePackEncoder specifically designed for the object's type. # This encoder is then used to convert a data class into MessagePack Bytes. - try: - encoder = self._msgpack_encoder[python_type] - except KeyError: - encoder = MessagePackEncoder(python_type) - self._msgpack_encoder[python_type] = encoder + # Use the patching helper to handle dataclasses-json 0.6.x compatibility. + encoder = _get_msgpack_encoder_with_patch(python_type, encoder_cache=self._msgpack_encoder) try: msgpack_bytes = encoder.encode(python_val) @@ -906,17 +1039,21 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T: if binary_idl_object.tag == MESSAGEPACK: + # Decode msgpack to dict and filter dataclass_json_config for compatibility + # with dataclasses-json 0.6.x which adds this field as a class annotation. 
+ dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False) + dict_obj = _filter_dataclass_json_config(dict_obj) + json_str = json.dumps(dict_obj) + if issubclass(expected_python_type, DataClassJSONMixin): - dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False) - json_str = json.dumps(dict_obj) + # For mashumaro's DataClassJSONMixin, use from_json directly dc = expected_python_type.from_json(json_str) # type: ignore else: - try: - decoder = self._msgpack_decoder[expected_python_type] - except KeyError: - decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder) - self._msgpack_decoder[expected_python_type] = decoder - dc = decoder.decode(binary_idl_object.value) + # For other dataclasses, use JSONDecoder with dataclass_json_config patch + # Pass _msgpack_decoder as cache since this is the msgpack path + dc = _deserialize_with_dataclass_json_config_patch( + expected_python_type, json_str, decoder_cache=self._msgpack_decoder + ) return dc else: @@ -934,21 +1071,21 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: json_str = _json_format.MessageToJson(lv.scalar.generic) + # Filter out dataclass_json_config for compatibility with dataclasses-json 0.6.x + dict_obj = json.loads(json_str) + dict_obj = _filter_dataclass_json_config(dict_obj) + json_str = json.dumps(dict_obj) + # The `from_json` function is provided from mashumaro's `DataClassJSONMixin`. # It deserializes a JSON string into a data class, and supports additional functionality over JSONDecoder # We can't use hasattr(expected_python_type, "from_json") here because we rely on mashumaro's API to customize the deserialization behavior for Flyte types. if issubclass(expected_python_type, DataClassJSONMixin): dc = expected_python_type.from_json(json_str) # type: ignore else: - # The function looks up or creates a JSONDecoder specifically designed for the object's type. - # This decoder is then used to convert a JSON string into a data class. - try: - decoder = self._json_decoder[expected_python_type] - except KeyError: - decoder = JSONDecoder(expected_python_type) - self._json_decoder[expected_python_type] = decoder - - dc = decoder.decode(json_str) + # For other dataclasses, use JSONDecoder with dataclass_json_config patch + dc = _deserialize_with_dataclass_json_config_patch( + expected_python_type, json_str, decoder_cache=self._json_decoder + ) return self._fix_dataclass_int(expected_python_type, dc) @@ -2182,7 +2319,8 @@ async def dict_to_generic_literal( try: try: # JSONEncoder is mashumaro's codec and this can triggered Flyte Types customized serialization and deserialization. - encoder = JSONEncoder(python_type) + # Use the patching helper to handle dataclasses-json 0.6.x compatibility. + encoder = _get_json_encoder_with_patch(python_type) json_str = encoder.encode(v) except NotImplementedError: raise NotImplementedError( @@ -2217,7 +2355,8 @@ async def dict_to_binary_literal( try: # Handle dictionaries with non-string keys (e.g., Dict[int, Type]) - encoder = MessagePackEncoder(python_type) + # Use the patching helper to handle dataclasses-json 0.6.x compatibility. 
+ encoder = _get_msgpack_encoder_with_patch(python_type) msgpack_bytes = encoder.encode(v) return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK))) except TypeError as e: @@ -2488,7 +2627,7 @@ def convert_marshmallow_json_schema_to_python_class(schema: dict, schema_name: t """ attribute_list = generate_attribute_list_from_dataclass_json(schema, schema_name) - return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) + return dataclasses.make_dataclass(schema_name, attribute_list) def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> type: @@ -2499,7 +2638,7 @@ def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typ """ attribute_list = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name) - return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) + return dataclasses.make_dataclass(schema_name, attribute_list) def _get_element_type(element_property: typing.Dict[str, str]) -> Type: diff --git a/flytekit/extras/pytorch/checkpoint.py b/flytekit/extras/pytorch/checkpoint.py index dfb21f5932..f110be1b98 100644 --- a/flytekit/extras/pytorch/checkpoint.py +++ b/flytekit/extras/pytorch/checkpoint.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, NamedTuple, Optional, Type, Union import torch -from dataclasses_json import DataClassJsonMixin +from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Protocol from flytekit.core.context_manager import FlyteContext @@ -21,7 +21,7 @@ class IsDataclass(Protocol): @dataclass -class PyTorchCheckpoint(DataClassJsonMixin): +class PyTorchCheckpoint(DataClassJSONMixin): """ This class is helpful to save a checkpoint. """ diff --git a/flytekit/extras/tensorflow/record.py b/flytekit/extras/tensorflow/record.py index 3e86b6b2ee..ff3cd430ca 100644 --- a/flytekit/extras/tensorflow/record.py +++ b/flytekit/extras/tensorflow/record.py @@ -3,7 +3,7 @@ from typing import Optional, Tuple, Type, Union import tensorflow as tf -from dataclasses_json import DataClassJsonMixin +from mashumaro.mixins.json import DataClassJSONMixin from tensorflow.python.data.ops.readers import TFRecordDatasetV2 from typing_extensions import Annotated, get_args, get_origin @@ -17,7 +17,7 @@ @dataclass -class TFRecordDatasetConfig(DataClassJsonMixin): +class TFRecordDatasetConfig(DataClassJSONMixin): """ TFRecordDatasetConfig can be used while creating tf.data.TFRecordDataset comprising record of one or more TFRecord files. diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 056fa2db61..df5ed2d324 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -16,7 +16,7 @@ import yaml from click import Parameter from click import __version__ as click_version -from dataclasses_json import DataClassJsonMixin, dataclass_json +from flytekit.core.type_engine import _deserialize_with_dataclass_json_config_patch from packaging.version import Version from pytimeparse import parse @@ -438,11 +438,9 @@ def has_nested_dataclass(t: typing.Type) -> bool: # The behavior of the Pydantic v1 plugin. 
return self._python_type.parse_raw(json.dumps(parsed_value)) - # Ensure that the python type has `from_json` function - if not hasattr(self._python_type, "from_json"): - self._python_type = dataclass_json(self._python_type) - - return cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(parsed_value)) + # Use mashumaro's JSONDecoder to deserialize the dataclass + # Use the patching helper to handle dataclasses-json 0.6.x compatibility + return _deserialize_with_dataclass_json_config_patch(self._python_type, json.dumps(parsed_value)) def modify_literal_uris(lit: Literal): diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index ec081aca94..cde38815dc 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -14,11 +14,10 @@ import fsspec import msgpack -from dataclasses_json import DataClassJsonMixin, config +from mashumaro.mixins.json import DataClassJSONMixin from fsspec.utils import get_protocol from google.protobuf import json_format as _json_format from google.protobuf.struct_pb2 import Struct -from marshmallow import fields from mashumaro.types import SerializableType from flytekit.core.constants import MESSAGEPACK @@ -72,8 +71,8 @@ def __call__(self): @dataclass -class FlyteDirectory(SerializableType, DataClassJsonMixin, os.PathLike, typing.Generic[T]): - path: PathType = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore +class FlyteDirectory(SerializableType, DataClassJSONMixin, os.PathLike, typing.Generic[T]): + path: PathType = field(default=None) # type: ignore """ > [!WARNING] > This class should not be used on very large datasets, as merely listing the dataset will cause diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 47915add8e..14d462e405 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -11,10 +11,8 @@ from urllib.parse import unquote import msgpack -from dataclasses_json import config from google.protobuf import json_format as _json_format from google.protobuf.struct_pb2 import Struct -from marshmallow import fields from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.types import SerializableType @@ -60,7 +58,7 @@ def __call__(self): @dataclass class FlyteFile(SerializableType, os.PathLike, typing.Generic[T], DataClassJSONMixin): - path: typing.Union[str, os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore + path: typing.Union[str, os.PathLike] = field(default=None) # type: ignore metadata: typing.Optional[dict[str, str]] = None """ Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 34dcc18058..3b4f910dc9 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -11,10 +11,8 @@ from typing import Dict, Optional, Type import msgpack -from dataclasses_json import config from google.protobuf import json_format as _json_format from google.protobuf.struct_pb2 import Struct -from marshmallow import fields from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.types import SerializableType @@ -184,7 +182,7 @@ def get_handler(cls, t: Type) -> SchemaHandler: @dataclass class FlyteSchema(SerializableType, DataClassJSONMixin): - remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) + remote_path: typing.Optional[str] = field(default=None) """ This 
is the main schema class that users should use. """ diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 7dd9532382..07deb7ad2d 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -10,11 +10,9 @@ from typing import Dict, Generator, Generic, List, Optional, Type, Union import msgpack -from dataclasses_json import config from fsspec.utils import get_protocol from google.protobuf import json_format as _json_format from google.protobuf.struct_pb2 import Struct -from marshmallow import fields from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.types import SerializableType from typing_extensions import Annotated, TypeAlias, get_args, get_origin @@ -59,8 +57,8 @@ class StructuredDataset(SerializableType, DataClassJSONMixin): class (that is just a model, a Python class representation of the protobuf). """ - uri: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) - file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String())) + uri: typing.Optional[str] = field(default=None) + file_format: typing.Optional[str] = field(default=GENERIC_FORMAT) def _serialize(self) -> Dict[str, Optional[str]]: # dataclass case diff --git a/pyproject.toml b/pyproject.toml index d2c1e98fb0..967f38332a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,6 @@ dependencies = [ "click>=6.6", "cloudpickle>=2.0.0", "croniter>=0.3.20", - "dataclasses-json>=0.5.2,<0.5.12", # TODO: remove upper-bound after fixing change in contract "diskcache>=5.2.1", "docker>=4.0.0", "docstring-parser>=0.9.0", diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py index 4992e5a91d..b47be639db 100644 --- a/tests/flytekit/unit/interaction/test_click_types.py +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -422,7 +422,9 @@ class Datum: t = JsonParamType(Datum) value = { "x": parquet_file, "y": DIR_NAME, "z": os.path.join(DIR_NAME, "testdata")} - with pytest.raises(AttributeError): + from mashumaro.exceptions import InvalidFieldValue + with pytest.raises((AttributeError, InvalidFieldValue)): + # mashumaro raises InvalidFieldValue for invalid field values t.convert(value=value, param=None, ctx=None) def test_dataclass_with_optional_fields(): From aee4b58ea46ac2ef1f8cd2b4ffa35eb036ed6e01 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 29 Jan 2026 13:01:24 -0800 Subject: [PATCH 2/2] update test to use mashumaro encoder directly Signed-off-by: Yee Hing Tong --- tests/flytekit/unit/core/test_dataclass_guessing.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/flytekit/unit/core/test_dataclass_guessing.py b/tests/flytekit/unit/core/test_dataclass_guessing.py index e3face7342..80eafd1ba4 100644 --- a/tests/flytekit/unit/core/test_dataclass_guessing.py +++ b/tests/flytekit/unit/core/test_dataclass_guessing.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from flytekit.core.type_engine import TypeEngine, strict_type_hint_matching -from mashumaro.codecs.json import JSONDecoder +from mashumaro.codecs.json import JSONDecoder, JSONEncoder class JobConfig(BaseModel): @@ -51,7 +51,8 @@ def test_guessing_of_nested_pydantic(): input_config_dc_version = decoder.decode(input_config_json) # recover the dataclass back into json, and then back into pydantic, and make sure it matches. 
- json_dc_version = input_config_dc_version.to_json() + encoder = JSONEncoder(guessed_type) + json_dc_version = encoder.encode(input_config_dc_version) reconstituted_pydantic = SchedulerConfig.model_validate_json(json_dc_version) assert reconstituted_pydantic == input_config @@ -80,7 +81,8 @@ def test_nested_pydantic_reconstruction_from_raw_json(): input_config_dc_version = decoder.decode(existing_json) # recover the dataclass back into json, and then back into pydantic, and make sure it matches. - json_dc_version = input_config_dc_version.to_json() + encoder = JSONEncoder(guessed_type) + json_dc_version = encoder.encode(input_config_dc_version) reconstituted_pydantic = SchedulerConfig.model_validate_json(json_dc_version) assert reconstituted_pydantic == SchedulerConfig( input_storage_bucket="s3://input-storage-bucket", @@ -120,7 +122,8 @@ def test_guessing_of_nested_pydantic_mapped(): input_config_dc_version = decoder.decode(input_config_json) # recover the dataclass back into json, and then back into pydantic, and make sure it matches. - json_dc_version = input_config_dc_version.to_json() + encoder = JSONEncoder(guessed_type) + json_dc_version = encoder.encode(input_config_dc_version) reconstituted_pydantic = SchedulerConfigMapped.model_validate_json(json_dc_version) assert reconstituted_pydantic == input_config
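
Note: the compatibility trick used throughout this series (clearing the parent mixin's __annotations__ while a mashumaro codec is built, then filtering the dataclass_json_config key out of decoded payloads) can be exercised outside flytekit with a short standalone script. The sketch below is illustrative only: it assumes dataclasses-json>=0.6 and mashumaro are installed, and the helper names (build_decoder_without_dcj_config, drop_dcj_config) are local to this example, not the helpers added to flytekit/core/type_engine.py.

import threading
from dataclasses import dataclass
from typing import Any, Dict

from dataclasses_json import DataClassJsonMixin
from mashumaro.codecs.json import JSONDecoder

_lock = threading.Lock()
_decoder_cache: Dict[type, JSONDecoder] = {}


def build_decoder_without_dcj_config(tp: type) -> JSONDecoder:
    """Build and cache a mashumaro JSONDecoder while DataClassJsonMixin's
    __annotations__ are temporarily empty, so the dataclass_json_config
    annotation added by dataclasses-json 0.6.x is not treated as a field."""
    with _lock:
        if tp in _decoder_cache:
            return _decoder_cache[tp]
        saved = dict(DataClassJsonMixin.__annotations__)
        DataClassJsonMixin.__annotations__ = {}
        try:
            decoder = JSONDecoder(tp)
            _decoder_cache[tp] = decoder
            return decoder
        finally:
            DataClassJsonMixin.__annotations__ = saved


def drop_dcj_config(obj: Any) -> Any:
    """Recursively drop the dataclass_json_config key from decoded payloads,
    mirroring what _filter_dataclass_json_config does in this patch."""
    if isinstance(obj, dict):
        return {k: drop_dcj_config(v) for k, v in obj.items() if k != "dataclass_json_config"}
    if isinstance(obj, list):
        return [drop_dcj_config(v) for v in obj]
    return obj


@dataclass
class User(DataClassJsonMixin):
    id: int


if __name__ == "__main__":
    original = User(id=42)
    # Serialize with the dataclasses-json API, decode with a patched mashumaro decoder.
    json_str = original.to_json()
    decoder = build_decoder_without_dcj_config(User)
    print(decoder.decode(json_str))  # User(id=42)
    # Payloads written by older flytekit versions may carry the stray key; filter it out.
    print(drop_dcj_config({"id": 42, "dataclass_json_config": None}))  # {'id': 42}

Caching the constructed decoder, as the patch does with _patched_decoders, matters because the workaround only needs to be in effect while the codec is compiled; subsequent decode calls are unaffected and take no lock.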