diff --git a/CHANGELOG.md b/CHANGELOG.md index 129c6c70217a..5d686dda9383 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ - Ensure errors raised in state handlers are trapped appropriately in Cloud Runners - [#554](https://github.com/PrefectHQ/prefect/pull/554) - Ensure unexpected errors raised in FlowRunners are robustly handled - [#568](https://github.com/PrefectHQ/prefect/pull/568) - Fixed non-deterministic errors in mapping caused by clients resolving futures of other clients - [#569](https://github.com/PrefectHQ/prefect/pull/569) +- Older versions of Prefect will now ignore fields added by newer versions when deserializing objects - [#583](https://github.com/PrefectHQ/prefect/pull/583) ### Breaking Changes @@ -37,6 +38,7 @@ - Convert `timeout` kwarg from `timedelta` to `integer` - [#540](https://github.com/PrefectHQ/prefect/issues/540) - Remove `timeout` kwarg from `executor.wait` - [#569](https://github.com/PrefectHQ/prefect/pull/569) - Serialization of States will _ignore_ any result data that hasn't been processed - [#581](https://github.com/PrefectHQ/prefect/pull/581) +- Removes `VersionedSchema` in favor of implicit versioning: serializers will ignore unknown fields and the `create_object` method is responsible for recreating missing ones - [#583](https://github.com/PrefectHQ/prefect/pull/583) ## 0.4.0 diff --git a/docs/outline.toml b/docs/outline.toml index d459f816c569..e9becf6ad9e7 100644 --- a/docs/outline.toml +++ b/docs/outline.toml @@ -194,8 +194,8 @@ functions = ["callback_factory", "slack_notifier", "gmail_notifier"] [pages.utilities.serialization] title = "Serialization" module = "prefect.utilities.serialization" -classes = ["VersionedSchema", "JSONCompatible", "Nested", "Bytes", "UUID", "FunctionReference"] -functions = ["to_qualified_name", "from_qualified_name", "version"] +classes = ["JSONCompatible", "Nested", "Bytes", "UUID", "FunctionReference"] +functions = ["to_qualified_name", "from_qualified_name"] [pages.utilities.tasks] title = "Tasks" diff --git a/src/prefect/serialization/edge.py b/src/prefect/serialization/edge.py index ffdae7bc8e13..4c0af047e9e4 100644 --- a/src/prefect/serialization/edge.py +++ b/src/prefect/serialization/edge.py @@ -2,11 +2,10 @@ import prefect from prefect.serialization.task import TaskSchema -from prefect.utilities.serialization import VersionedSchema, version +from prefect.utilities.serialization import ObjectSchema -@version("0.3.3") -class EdgeSchema(VersionedSchema): +class EdgeSchema(ObjectSchema): class Meta: object_class = lambda: prefect.core.Edge diff --git a/src/prefect/serialization/environment.py b/src/prefect/serialization/environment.py index 9202ab44cf7d..f03b91915028 100644 --- a/src/prefect/serialization/environment.py +++ b/src/prefect/serialization/environment.py @@ -9,14 +9,12 @@ Bytes, JSONCompatible, OneOfSchema, - VersionedSchema, + ObjectSchema, to_qualified_name, - version, ) -@version("0.3.3") -class LocalEnvironmentSchema(VersionedSchema): +class LocalEnvironmentSchema(ObjectSchema): class Meta: object_class = prefect.environments.LocalEnvironment @@ -24,8 +22,7 @@ class Meta: serialized_flow = Bytes(allow_none=True) -@version("0.3.3") -class ContainerEnvironmentSchema(VersionedSchema): +class ContainerEnvironmentSchema(ObjectSchema): class Meta: object_class = prefect.environments.ContainerEnvironment diff --git a/src/prefect/serialization/flow.py b/src/prefect/serialization/flow.py index 194386448ec5..eff319feec81 100644 --- a/src/prefect/serialization/flow.py +++ b/src/prefect/serialization/flow.py @@ -8,9 +8,8 @@ from prefect.utilities.serialization import ( JSONCompatible, Nested, - VersionedSchema, + ObjectSchema, to_qualified_name, - version, ) @@ -28,11 +27,10 @@ def get_reference_tasks(obj, context): return utils.get_value(obj, "reference_tasks") -@version("0.3.3") -class FlowSchema(VersionedSchema): +class FlowSchema(ObjectSchema): class Meta: object_class = lambda: prefect.core.Flow - object_class_exclude = ["id", "type", "parameters"] + exclude_fields = ["id", "type", "parameters"] # ordered to make sure Task objects are loaded before Edge objects, due to Task caching ordered = True diff --git a/src/prefect/serialization/result_handlers.py b/src/prefect/serialization/result_handlers.py index 65f54f4d96f5..970f9678442f 100644 --- a/src/prefect/serialization/result_handlers.py +++ b/src/prefect/serialization/result_handlers.py @@ -8,19 +8,16 @@ from prefect.utilities.serialization import ( JSONCompatible, OneOfSchema, - VersionedSchema, + ObjectSchema, to_qualified_name, - version, ) -@version("0.4.0") -class BaseResultHandlerSchema(VersionedSchema): +class BaseResultHandlerSchema(ObjectSchema): class Meta: object_class = ResultHandler -@version("0.4.0") class CloudResultHandlerSchema(BaseResultHandlerSchema): class Meta: object_class = CloudResultHandler @@ -28,7 +25,6 @@ class Meta: result_handler_service = fields.String(allow_none=True) -@version("0.4.0") class LocalResultHandlerSchema(BaseResultHandlerSchema): class Meta: object_class = LocalResultHandler diff --git a/src/prefect/serialization/schedule.py b/src/prefect/serialization/schedule.py index fccffc11634e..52d71836c8cf 100644 --- a/src/prefect/serialization/schedule.py +++ b/src/prefect/serialization/schedule.py @@ -4,16 +4,10 @@ from marshmallow import fields import prefect -from prefect.utilities.serialization import ( - OneOfSchema, - VersionedSchema, - to_qualified_name, - version, -) +from prefect.utilities.serialization import OneOfSchema, ObjectSchema, to_qualified_name -@version("0.3.3") -class IntervalScheduleSchema(VersionedSchema): +class IntervalScheduleSchema(ObjectSchema): class Meta: object_class = prefect.schedules.IntervalSchedule @@ -22,8 +16,7 @@ class Meta: interval = fields.TimeDelta(precision="microseconds", required=True) -@version("0.3.3") -class CronScheduleSchema(VersionedSchema): +class CronScheduleSchema(ObjectSchema): class Meta: object_class = prefect.schedules.CronSchedule diff --git a/src/prefect/serialization/state.py b/src/prefect/serialization/state.py index 1a7e7ff0534d..16c0a6c915ec 100644 --- a/src/prefect/serialization/state.py +++ b/src/prefect/serialization/state.py @@ -7,9 +7,8 @@ from prefect.utilities.serialization import ( JSONCompatible, OneOfSchema, - VersionedSchema, + ObjectSchema, to_qualified_name, - version, ) @@ -29,13 +28,8 @@ def _serialize(self, value, attr, obj, **kwargs): ) return super()._serialize(value, attr, obj, **kwargs) - def _deserialize(self, value, attr, data, **kwargs): - value = super()._deserialize(value, attr, data, **kwargs) - return value - -@version("0.3.3") -class BaseStateSchema(VersionedSchema): +class BaseStateSchema(ObjectSchema): class Meta: object_class = state.State @@ -51,7 +45,6 @@ def create_object(self, data): return base_obj -@version("0.3.3") class PendingSchema(BaseStateSchema): class Meta: object_class = state.Pending @@ -59,7 +52,6 @@ class Meta: cached_inputs = ResultHandlerField(allow_none=True) -@version("0.3.3") class SubmittedSchema(BaseStateSchema): class Meta: object_class = state.Submitted @@ -67,7 +59,6 @@ class Meta: state = fields.Nested("StateSchema", allow_none=True) -@version("0.3.3") class CachedStateSchema(PendingSchema): class Meta: object_class = state.CachedState @@ -77,7 +68,6 @@ class Meta: cached_result_expiration = fields.DateTime(allow_none=True) -@version("0.3.3") class ScheduledSchema(PendingSchema): class Meta: object_class = state.Scheduled @@ -85,13 +75,11 @@ class Meta: start_time = fields.DateTime(allow_none=True) -@version("0.3.3") class ResumeSchema(ScheduledSchema): class Meta: object_class = state.Resume -@version("0.3.3") class RetryingSchema(ScheduledSchema): class Meta: object_class = state.Retrying @@ -99,19 +87,16 @@ class Meta: run_count = fields.Int(allow_none=True) -@version("0.3.3") class RunningSchema(BaseStateSchema): class Meta: object_class = state.Running -@version("0.3.3") class FinishedSchema(BaseStateSchema): class Meta: object_class = state.Finished -@version("0.3.3") class SuccessSchema(FinishedSchema): class Meta: object_class = state.Success @@ -119,7 +104,6 @@ class Meta: cached = fields.Nested(CachedStateSchema, allow_none=True) -@version("0.3.3") class MappedSchema(SuccessSchema): class Meta: exclude = ["result", "map_states"] @@ -138,13 +122,11 @@ def create_object(self, data): return super().create_object(data) -@version("0.3.3") class FailedSchema(FinishedSchema): class Meta: object_class = state.Failed -@version("0.3.3") class TimedOutSchema(FinishedSchema): class Meta: object_class = state.TimedOut @@ -152,20 +134,17 @@ class Meta: cached_inputs = ResultHandlerField(allow_none=True) -@version("0.3.3") class TriggerFailedSchema(FailedSchema): class Meta: object_class = state.TriggerFailed -@version("0.3.3") class SkippedSchema(SuccessSchema): class Meta: object_class = state.Skipped - object_class_exclude = ["cached"] + exclude_fields = ["cached"] -@version("0.3.3") class PausedSchema(PendingSchema): class Meta: object_class = state.Paused diff --git a/src/prefect/serialization/task.py b/src/prefect/serialization/task.py index 29ac0d158205..0895d3601200 100644 --- a/src/prefect/serialization/task.py +++ b/src/prefect/serialization/task.py @@ -16,10 +16,9 @@ UUID, FunctionReference, JSONCompatible, - VersionedSchema, + ObjectSchema, from_qualified_name, to_qualified_name, - version, ) @@ -55,11 +54,10 @@ def create_object(self, data): return self.context["task_id_cache"][task_id] -@version("0.3.3") -class TaskSchema(TaskMethodsMixin, VersionedSchema): +class TaskSchema(TaskMethodsMixin, ObjectSchema): class Meta: object_class = lambda: prefect.core.Task - object_class_exclude = ["id", "type"] + exclude_fields = ["id", "type"] id = UUID() type = fields.Function(lambda task: to_qualified_name(type(task)), lambda x: x) @@ -101,11 +99,10 @@ class Meta: ) -@version("0.3.3") -class ParameterSchema(TaskMethodsMixin, VersionedSchema): +class ParameterSchema(TaskMethodsMixin, ObjectSchema): class Meta: object_class = lambda: prefect.core.task.Parameter - object_class_exclude = ["id", "type"] + exclude_fields = ["id", "type"] id = UUID() type = fields.Function(lambda task: to_qualified_name(type(task)), lambda x: x) diff --git a/src/prefect/utilities/serialization.py b/src/prefect/utilities/serialization.py index 8cee485308cc..6aac7007d101 100644 --- a/src/prefect/utilities/serialization.py +++ b/src/prefect/utilities/serialization.py @@ -1,4 +1,5 @@ # Licensed under LICENSE.md; also available at https://www.prefect.io/licenses/alpha-eula + import base64 import json import sys @@ -9,6 +10,7 @@ import pendulum from marshmallow import ( Schema, + EXCLUDE, SchemaOpts, ValidationError, fields, @@ -21,9 +23,6 @@ import prefect from prefect.utilities.collections import DotDict, as_nested_dict -MAX_VERSION = "__MAX_VERSION__" -VERSIONS = {} # type: Dict[str, Dict[str, VersionedSchema]] - def to_qualified_name(obj: Any) -> str: """ @@ -71,135 +70,27 @@ def from_qualified_name(obj_str: str) -> Any: ) -def version(version: str) -> Callable: - """ - Decorator that registers a schema with a specific version of Prefect. - - Args: - - version (str): the version to associated with the schema - - Returns: - - Callable: the decorated VersionedSchema class - """ - if not isinstance(version, str): - raise TypeError("Version must be a string.") - - def wrapper(cls): - if not issubclass(cls, VersionedSchema): - raise TypeError("Expected VersionedSchema") - VERSIONS.setdefault(to_qualified_name(cls), {})[version] = cls - return cls - - return wrapper - - -def get_versioned_schema(schema: "VersionedSchema", version: str) -> "VersionedSchema": - """ - Attempts to retrieve the registered VersionedSchema corresponding to name with the - highest version less or than equal to `version`. - - Args: - - name (str): the fully-qualified name of a registered VersionedSchema - - version (str): the version number - - Returns: - - VersionedSchema: the matching schema, or the original schema if no better match - was found. - - """ - name = to_qualified_name(type(schema)) - versions = VERSIONS.get(name, None) - - if name not in VERSIONS: - raise ValueError("Unregistered VersionedSchema") - - elif version is MAX_VERSION: - return VERSIONS[name][max(VERSIONS[name])] - - else: - matching_versions = [v for v in VERSIONS[name] if v <= version] - if not matching_versions: - raise ValueError( - "No VersionSchema was registered for version {}".format(version) - ) - return VERSIONS[name][max(matching_versions)] - - -class VersionedSchemaOptions(SchemaOpts): +class ObjectSchemaOptions(SchemaOpts): def __init__(self, meta, **kwargs) -> None: super().__init__(meta, **kwargs) self.object_class = getattr(meta, "object_class", None) - self.object_class_exclude = getattr(meta, "object_class_exclude", None) or [] + self.exclude_fields = getattr(meta, "exclude_fields", None) or [] + self.unknown = getattr(meta, "unknown", EXCLUDE) -class VersionedSchema(Schema): +class ObjectSchema(Schema): """ - This Marshmallow Schema automatically adds a `__version__` field when it serializes an - object, corresponding to the version of Prefect that did the serialization. - - Subclasses of VersionedSchema can be registered for specific versions of Prefect - by using the `version` decorator. - - When a VersionedSchema is deserialized, it reads the `__version__` field (if available) - and uses the VersionedSchema with the highest registered version less than or equal to - the `__version__` value. - - This allows object schemas to be migrated while maintaining compatibility with - - Args: - - *args (Any): the arguments accepted by `marshmallow.Schema` - - **kwargs (Any): the keyword arguments accepted by `marshmallow.Schema` + This Marshmallow Schema automatically instantiates an object whose type is indicated by the + `object_class` attribute of the class `Meta`. All deserialized fields are passed to the + constructor's `__init__()` unless the name of the field appears in `Meta.exclude_fields`. """ - OPTIONS_CLASS = VersionedSchemaOptions + OPTIONS_CLASS = ObjectSchemaOptions class Meta: object_class = None # type: type - object_class_exclude = [] # type: List[str] - - def __init__(self, *args, **kwargs) -> None: - self._args = args - self._kwargs = kwargs - super().__init__(*args, **kwargs) - - def load( - self, - data: dict, - create_object: bool = True, - check_version: bool = True, - **kwargs - ) -> Any: - """ - Loads an object by first retrieving the appropate schema version (based on the data's - __version__ key). - - Args: - - data (dict): the serialized data - - create_object (bool): if True, an instantiated object will be returned. Otherwise, - the deserialized data dict will be returned. - - check_version (bool): if True, the version will be checked and the appropriate schema - loaded. - - **kwargs (Any): additional keyword arguments for the load() method - - Returns: - - Any: the deserialized object or data - """ - self.context.setdefault("create_object", create_object) - - if check_version and isinstance(data, dict): - schema = get_versioned_schema(self, data.get("__version__", MAX_VERSION)) - else: - return super().load(data, **kwargs) - - # if we got a different (or no) schema, instantiate it - if schema is not self: - schema_instance = schema(*self._args, **self._kwargs) - schema_instance.context = self.context - return schema_instance.load( - data, create_object=create_object, check_version=False, **kwargs - ) - else: - return super().load(data, **kwargs) + exclude_fields = [] # type: List[str] + unknown = EXCLUDE @pre_load def _remove_version(self, data: dict) -> dict: @@ -229,10 +120,27 @@ def _add_version(self, data: dict) -> dict: data.setdefault("__version__", prefect.__version__) return data + def load(self, data: dict, create_object: bool = True, **kwargs) -> Any: + """ + Loads an object by first retrieving the appropate schema version (based on the data's + __version__ key). + + Args: + - data (dict): the serialized data + - create_object (bool): if True, an instantiated object will be returned. Otherwise, + the deserialized data dict will be returned. + - **kwargs (Any): additional keyword arguments for the load() method + + Returns: + - Any: the deserialized object or data + """ + self.context.setdefault("create_object", create_object) + return super().load(data, **kwargs) + @post_load def create_object(self, data: dict) -> Any: """ - By default, returns an instantiated object using the VersionedSchema's `object_class`. + By default, returns an instantiated object using the ObjectSchema's `object_class`. Otherwise, returns a data dict. Args: @@ -247,9 +155,7 @@ def create_object(self, data: dict) -> Any: if isinstance(object_class, types.FunctionType): object_class = object_class() init_data = { - k: v - for k, v in data.items() - if k not in self.opts.object_class_exclude + k: v for k, v in data.items() if k not in self.opts.exclude_fields } return object_class(**init_data) return data @@ -317,6 +223,9 @@ class OneOfSchema(marshmallow_oneofschema.OneOfSchema): A subclass of marshmallow_oneofschema.OneOfSchema that can load DotDicts """ + class Meta: + unknown = EXCLUDE + def _load(self, data, partial=None, unknown=None): if isinstance(data, DotDict): data = as_nested_dict(data, dict) diff --git a/tests/serialization/test_deserialization/README.md b/tests/serialization/test_deserialization/README.md new file mode 100644 index 000000000000..10344230fbf2 --- /dev/null +++ b/tests/serialization/test_deserialization/README.md @@ -0,0 +1,5 @@ +# Testing Deserialization + +It is important for Prefect to maintain backwards-compatibility with older versions of Prefect by ensuring that objects serialized under older versions can be deserialized under newer ones. + +These tests use snapshots of objects from older versions to ensure compatibility as the Schemas evolve. diff --git a/tests/serialization/test_deserialization/__init__.py b/tests/serialization/test_deserialization/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/serialization/test_deserialization/test_deserialization.py b/tests/serialization/test_deserialization/test_deserialization.py new file mode 100644 index 000000000000..f2f302cac46f --- /dev/null +++ b/tests/serialization/test_deserialization/test_deserialization.py @@ -0,0 +1,41 @@ +import json +import os + +import pytest + +import prefect +from prefect import serialization as s + +file_dir = os.path.dirname(__file__) + + +@pytest.fixture +def version_0_3_0(): + with open(os.path.join(file_dir, "version_0_3_0.json")) as f: + return json.load(f) + + +@pytest.fixture +def version_0_4_0(): + with open(os.path.join(file_dir, "version_0_4_0.json")) as f: + return json.load(f) + + +class Test_Version_0_3_0: + def test_deserialize_success(self, version_0_3_0): + state = s.state.StateSchema().load(version_0_3_0["states"]["success"]) + assert state.is_successful() + + def test_deserialize_retrying(self, version_0_3_0): + state = s.state.StateSchema().load(version_0_3_0["states"]["retrying"]) + assert isinstance(state, prefect.engine.state.Retrying) + + +class Test_Version_0_4_0: + def test_deserialize_success(self, version_0_4_0): + state = s.state.StateSchema().load(version_0_4_0["states"]["success"]) + assert state.is_successful() + + def test_deserialize_retrying(self, version_0_4_0): + state = s.state.StateSchema().load(version_0_4_0["states"]["retrying"]) + assert isinstance(state, prefect.engine.state.Retrying) diff --git a/tests/serialization/test_deserialization/version_0_3_0.json b/tests/serialization/test_deserialization/version_0_3_0.json new file mode 100644 index 000000000000..59f390b81731 --- /dev/null +++ b/tests/serialization/test_deserialization/version_0_3_0.json @@ -0,0 +1,16 @@ +{ + "states": { + "success": { + "type": "Success", + "message": "Success", + "result": 1, + "__version__": "0.3.0" + }, + "retrying": { + "type": "Retrying", + "message": "Retrying", + "result": 1, + "__version__": "0.3.0" + } + } +} diff --git a/tests/serialization/test_deserialization/version_0_4_0.json b/tests/serialization/test_deserialization/version_0_4_0.json new file mode 100644 index 000000000000..fac14b8f272e --- /dev/null +++ b/tests/serialization/test_deserialization/version_0_4_0.json @@ -0,0 +1,22 @@ +{ + "states": { + "success": { + "type": "Success", + "message": "Success", + "metadata": {}, + "result": 1, + "__version__": "0.3.0" + }, + "retrying": { + "type": "Retrying", + "message": "Retrying", + "metadata": { + "result": { + "is_raw": true + } + }, + "result": 1, + "__version__": "0.3.0" + } + } +} diff --git a/tests/serialization/test_states.py b/tests/serialization/test_states.py index 825d5f53f596..2acafbb16011 100644 --- a/tests/serialization/test_states.py +++ b/tests/serialization/test_states.py @@ -309,3 +309,15 @@ def test_deserialize_json_without_version(): assert deserialized.is_running() assert deserialized.message == "test" assert deserialized.result == 1 + + +def test_deserialize_handles_unknown_fields(): + """ensure that deserialization can happen even if a newer version of prefect created unknown fields""" + deserialized = StateSchema().load( + { + "type": "Success", + "success_message_that_definitely_wont_exist_on_a_real_state!": 1, + } + ) + + assert deserialized.is_successful() diff --git a/tests/serialization/test_versioned_schemas.py b/tests/serialization/test_versioned_schemas.py deleted file mode 100644 index 492b5594927e..000000000000 --- a/tests/serialization/test_versioned_schemas.py +++ /dev/null @@ -1,316 +0,0 @@ -import datetime - -import marshmallow -import pendulum -import pytest - -import prefect -from prefect.utilities.collections import DotDict, as_nested_dict -from prefect.utilities.serialization import ( - VERSIONS, - OneOfSchema, - VersionedSchema, - get_versioned_schema, - to_qualified_name, - version, -) - - -@pytest.fixture(autouse=True) -def clear_versions(): - VERSIONS.clear() - - -def test_register_addsto_qualified_name_to_VERSIONS(): - @version("1") - class Schema(VersionedSchema): - pass - - assert to_qualified_name(Schema) in VERSIONS - assert list(VERSIONS[to_qualified_name(Schema)]) == ["1"] - - -def test_register_multiple_versions(): - @version("1") - class Schema(VersionedSchema): - pass - - v1 = Schema - - @version("2") - class Schema(VersionedSchema): - pass - - v2 = Schema - - assert to_qualified_name(v1) == to_qualified_name(v2) - assert to_qualified_name(Schema) in VERSIONS - assert set(VERSIONS[to_qualified_name(Schema)]) == {"1", "2"} - assert VERSIONS[to_qualified_name(Schema)]["1"] is v1 - assert VERSIONS[to_qualified_name(Schema)]["2"] is v2 - - -def test_get_versioned_schemas(): - @version("1") - class Schema(VersionedSchema): - pass - - v1 = Schema - - @version("2") - class Schema(VersionedSchema): - pass - - v2 = Schema - - assert get_versioned_schema(Schema(), version="1") is v1 - assert get_versioned_schema(Schema(), version="1.5") is v1 - assert get_versioned_schema(Schema(), version="2") is v2 - assert get_versioned_schema(Schema(), version="2.5") is v2 - - -def testget_versioned_schemas_when_unregistered(): - class Schema(VersionedSchema): - pass - - with pytest.raises(ValueError) as exc: - get_versioned_schema(Schema(), version="1") - assert "unregistered" in str(exc).lower() - - -def testget_versioned_schemas_when_version_doesnt_match(): - @version("1") - class Schema(VersionedSchema): - pass - - with pytest.raises(ValueError) as exc: - get_versioned_schema(Schema(), version="0") - assert "no versionschema was registered" in str(exc).lower() - - -def test_version_must_be_a_string(): - - with pytest.raises(TypeError) as exc: - version(1) - assert "version must be a string" in str(exc).lower() - - -def test_cant_register_non_versioned_schemas(): - - with pytest.raises(TypeError) as exc: - - @version("1") - class Schema(marshmallow.Schema): - pass - - assert "expected versionedschema" in str(exc).lower() - - -def test_schema_writes_version_when_serialized_and_removes_when_deserialized(): - @version("0") - class Schema(VersionedSchema): - x = marshmallow.fields.String() - - serialized = Schema().dump({"x": "1"}) - assert serialized == {"x": "1", "__version__": prefect.__version__} - deserialized = Schema().load(serialized) - assert deserialized == {"x": "1"} - - -def test_version_determines_which_schema_to_load(): - @version("2") - class Schema(VersionedSchema): - y = marshmallow.fields.Int() - - @version("1") - class Schema(VersionedSchema): - x = marshmallow.fields.String() - - serialized = {"y": 1} - serialized_v1 = {"x": "hi", "__version__": "1"} - serialized_v2 = {"y": 2, "__version__": "2"} - serialized_v2_wrong_version = {"y": 2, "__version__": "1"} - assert Schema().load(serialized_v1) == {"x": "hi"} - assert Schema().load(serialized_v2) == {"y": 2} - with pytest.raises(marshmallow.exceptions.ValidationError): - Schema().load(serialized_v2_wrong_version) == {} - - -def test_no_version_uses_most_recent_version(): - @version("2") - class Schema(VersionedSchema): - y = marshmallow.fields.Int() - - @version("1") - class Schema(VersionedSchema): - x = marshmallow.fields.String() - - serialized = {"y": 1} - assert Schema().load(serialized) == {"y": 1} - - -def test_version_determines_which_nested_schema_to_load(): - @version("2") - class Schema(VersionedSchema): - y = marshmallow.fields.Int() - nested = marshmallow.fields.Nested("self") - - @version("1") - class Schema(VersionedSchema): - x = marshmallow.fields.String() - - serialized_v2 = {"y": 1, "nested": {"y": 1}} - serialized_v2_v1 = {"y": 1, "nested": {"x": "hi", "__version__": "1"}} - serialized_v2_v2 = {"y": 1, "nested": {"y": 1, "__version__": "2"}} - assert Schema().load(serialized_v2) == {"y": 1, "nested": {"y": 1}} - assert Schema().load(serialized_v2_v2) == {"y": 1, "nested": {"y": 1}} - assert Schema().load(serialized_v2_v1) == {"y": 1, "nested": {"x": "hi"}} - - -def test_versions_inherit_init_args(): - @version("2") - class Schema(VersionedSchema): - y = marshmallow.fields.Int() - - @version("1") - class Schema(VersionedSchema): - x = marshmallow.fields.String() - - # the version 1 schema will pass the unknown=True kwarg to the version 2 schema when - # it loads it - assert Schema(unknown=True).load({"y": 1, "z": 3, "__version__": "2"}) - - -def test_schema_creates_object(): - class TestObject: - def __init__(self, x): - self.x = x - - @version("0") - class Schema(VersionedSchema): - class Meta: - object_class = TestObject - - x = marshmallow.fields.Int() - - deserialized = Schema().load({"x": "1"}) - assert isinstance(deserialized, TestObject) - assert deserialized.x == 1 - - -def test_schema_does_not_create_object(): - class TestObject: - def __init__(self, x): - self.x = x - - @version("0") - class Schema(VersionedSchema): - class Meta: - object_class = TestObject - - x = marshmallow.fields.Int() - - deserialized = Schema().load({"x": "1"}, create_object=False) - assert deserialized == {"x": 1} - - -def test_nested_schema_creates_object(): - class TestObject: - def __init__(self, x, nested=None): - self.x = x - self.nested = nested - - @version("0") - class Schema(VersionedSchema): - class Meta: - object_class = TestObject - - x = marshmallow.fields.Int() - nested = marshmallow.fields.Nested("self", allow_none=True) - - deserialized = Schema().load({"x": "1", "nested": {"x": "2"}}) - assert isinstance(deserialized, TestObject) - assert isinstance(deserialized.nested, TestObject) - assert deserialized.nested.x == 2 - - -def test_nested_schema_does_not_create_object(): - class TestObject: - def __init__(self, y, nested=None): - self.x = x - self.nested = nested - - @version("0") - class Schema(VersionedSchema): - class Meta: - object_class = TestObject - - x = marshmallow.fields.Int() - nested = marshmallow.fields.Nested("self", allow_none=True) - - deserialized = Schema().load({"x": "1", "nested": {"x": "2"}}, create_object=False) - assert deserialized == {"x": 1, "nested": {"x": 2}} - - -def test_schema_creates_object_with_lambda(): - @version("0") - class Schema(VersionedSchema): - class Meta: - object_class = lambda: TestObject - - y = marshmallow.fields.Int() - - class TestObject: - def __init__(self, y): - self.y = y - - deserialized = Schema().load({"y": "1"}) - assert isinstance(deserialized, TestObject) - assert deserialized.y == 1 - - -def test_schema_doesnt_create_object_if_arg_is_false(): - class TestObject: - def __init__(self, y): - self.y = y - - @version("0") - class Schema(VersionedSchema): - class Meta: - object_class = TestObject - - y = marshmallow.fields.Int() - - assert Schema().load({"y": 1}, create_object=False) == {"y": 1} - - -def test_nested_schemas_pass_context_on_load(): - @version("0") - class Child(VersionedSchema): - x = marshmallow.fields.Function(None, lambda x, context: context["x"]) - - @version("0") - class Parent(VersionedSchema): - child = marshmallow.fields.Nested(Child) - - @marshmallow.pre_load - def set_context(self, obj): - self.context.update(x=5) - return obj - - assert Parent().load({"child": {"x": 1}})["child"]["x"] == 5 - - -def test_oneofschema_load_dotdict(): - """ - Tests that modified OneOfSchema can load data from a DotDict (standard can not) - """ - - class ChildSchema(marshmallow.Schema): - x = marshmallow.fields.Integer() - - class ParentSchema(OneOfSchema): - type_schemas = {"Child": ChildSchema} - - child = ParentSchema().load(DotDict(type="Child", x="5")) - assert child["x"] == 5 diff --git a/tests/utilities/test_serialization.py b/tests/utilities/test_serialization.py index 7f1f8e976c77..08202246aa31 100644 --- a/tests/utilities/test_serialization.py +++ b/tests/utilities/test_serialization.py @@ -11,7 +11,10 @@ FunctionReference, JSONCompatible, Nested, + ObjectSchema, + OneOfSchema, ) +import prefect json_test_values = [ 1, @@ -170,3 +173,138 @@ def test_deserialize_none(self): with pytest.raises(marshmallow.ValidationError): self.Schema().load({"f": None}) assert self.Schema().load({"f_none": None})["f_none"] is None + + +class TestObjectSchema: + def test_schema_writes_version_to_serialized_object(self): + class TestObject: + def __init__(self, x): + self.x = x + + class Schema(ObjectSchema): + class Meta: + object_class = TestObject + + x = marshmallow.fields.Int() + + serialized = Schema().dump(TestObject(x=5)) + assert serialized == {"__version__": prefect.__version__, "x": 5} + + def test_schema_creates_object(self): + class TestObject: + def __init__(self, x): + self.x = x + + class Schema(ObjectSchema): + class Meta: + object_class = TestObject + + x = marshmallow.fields.Int() + + deserialized = Schema().load({"x": "1"}) + assert isinstance(deserialized, TestObject) + assert deserialized.x == 1 + + def test_schema_does_not_create_object_if_arg_is_false(self): + class TestObject: + def __init__(self, x): + self.x = x + + class Schema(ObjectSchema): + class Meta: + object_class = TestObject + + x = marshmallow.fields.Int() + + deserialized = Schema().load({"x": "1"}, create_object=False) + assert deserialized == {"x": 1} + + def test_schema_has_error_if_fields_cant_be_supplied_to_init(self): + class TestObject: + def __init__(self, x): + self.x = x + + class Schema(ObjectSchema): + class Meta: + object_class = TestObject + + x = marshmallow.fields.Int() + y = marshmallow.fields.Int() + + with pytest.raises(TypeError): + Schema().load({"x": "1", "y": "2"}) + + def test_schema_with_excluded_fields(self): + class TestObject: + def __init__(self, x): + self.x = x + + class Schema(ObjectSchema): + class Meta: + object_class = TestObject + exclude_fields = ["y"] + + x = marshmallow.fields.Int() + y = marshmallow.fields.Int() + + deserialized = Schema().load({"x": "1", "y": "2"}) + assert isinstance(deserialized, TestObject) + assert deserialized.x == 1 + assert not hasattr(deserialized, "y") + + def test_schema_creates_object_with_lambda(self): + class Schema(ObjectSchema): + class Meta: + object_class = lambda: TestObject + + x = marshmallow.fields.Int() + + class TestObject: + def __init__(self, x): + self.x = x + + deserialized = Schema().load({"x": "1"}) + assert isinstance(deserialized, TestObject) + assert deserialized.x == 1 + + def test_schema_handles_unknown_fields(self): + class TestObject: + def __init__(self, x): + self.x = x + + class Schema(ObjectSchema): + class Meta: + object_class = TestObject + + x = marshmallow.fields.Int() + + deserialized = Schema().load({"x": "1", "y": "2"}) + assert isinstance(deserialized, TestObject) + assert not hasattr(deserialized, "y") + + +class TestOneOfSchema: + def test_oneofschema_load_dotdict(self): + """ + Tests that modified OneOfSchema can load data from a DotDict (standard can not) + """ + + class ChildSchema(marshmallow.Schema): + x = marshmallow.fields.Integer() + + class ParentSchema(OneOfSchema): + type_schemas = {"Child": ChildSchema} + + child = ParentSchema().load(DotDict(type="Child", x="5")) + assert child["x"] == 5 + + def test_oneofschema_handles_unknown_fields(self): + class ChildSchema(marshmallow.Schema): + x = marshmallow.fields.Integer() + + class ParentSchema(OneOfSchema): + type_schemas = {"Child": ChildSchema} + + child = ParentSchema().load(DotDict(type="Child", x="5", y="6")) + assert child["x"] == 5 + assert not hasattr(child, "y")