From 3fca3b080088245df0bb85378d42658c21e41460 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Sat, 23 Mar 2024 16:20:17 +0000 Subject: [PATCH 1/2] fix: favour SA mapped type over impl type --- polyfactory/factories/sqlalchemy_factory.py | 9 ++-- .../test_sqlalchemy_factory_common.py | 45 ++++++++++++++++++- 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index 62a7c0a5..5a2155b2 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -129,11 +129,10 @@ def get_type_from_column(cls, column: Column) -> type: elif issubclass(column_type, types.ARRAY): annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined] else: - annotation = ( - column.type.impl.python_type # pyright: ignore[reportGeneralTypeIssues] - if hasattr(column.type, "impl") - else column.type.python_type - ) + try: + annotation = column.type.python_type + except NotImplementedError: + annotation = column.type.impl.python_type # type: ignore[attr-defined] if column.nullable: annotation = Union[annotation, None] # type: ignore[assignment] diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index 6a7657da..01f7ab52 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Type +from typing import Any, Callable, Type, Union +from uuid import UUID import pytest from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, inspect, orm, types @@ -347,3 +348,45 @@ class ModelFactory(SQLAlchemyFactory[ModelWithAlias]): result = ModelFactory.build() assert isinstance(result.name, str) + + +@pytest.mark.parametrize("python_type_", (UUID, None)) +@pytest.mark.parametrize( + "impl_", + ( + types.Uuid(), + types.Uuid(native_uuid=False), + types.CHAR(32), + ), +) +def test_sqlalchemy_custom_type_from_type_decorator(impl_: types.TypeEngine, python_type_: Union[type, None]) -> None: + class CustomType(types.TypeDecorator): + impl = impl_ + cache_ok = True + + if python_type_ is not None: + + @property + def python_type(self) -> type: + return python_type_ + + class Base(orm.DeclarativeBase): + type_annotation_map = { + UUID: CustomType, + } + + class Model(Base): + __tablename__ = "model_with_custom_types" + + id: orm.Mapped[int] = orm.mapped_column(primary_key=True) + custom_type: orm.Mapped[UUID] = orm.mapped_column(type_=CustomType(), nullable=False) + custom_type_from_annotation_map: orm.Mapped[UUID] + + class ModelFactory(SQLAlchemyFactory[Model]): + __model__ = Model + + instance = ModelFactory.build() + + expected_type = python_type_ if python_type_ is not None else CustomType.impl.python_type + assert isinstance(instance.custom_type, expected_type) + assert isinstance(instance.custom_type_from_annotation_map, expected_type) From db59dc501c580143fe86c555a6bb8a2a8d2b83dc Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Sat, 23 Mar 2024 16:34:16 +0000 Subject: [PATCH 2/2] fix: SA1.4 compatability --- .../test_sqlalchemy_factory_common.py | 30 +++++++------------ 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index 01f7ab52..ca77fd1e 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -351,17 +351,9 @@ class ModelFactory(SQLAlchemyFactory[ModelWithAlias]): @pytest.mark.parametrize("python_type_", (UUID, None)) -@pytest.mark.parametrize( - "impl_", - ( - types.Uuid(), - types.Uuid(native_uuid=False), - types.CHAR(32), - ), -) -def test_sqlalchemy_custom_type_from_type_decorator(impl_: types.TypeEngine, python_type_: Union[type, None]) -> None: +def test_sqlalchemy_custom_type_from_type_decorator(python_type_: Union[type, None]) -> None: class CustomType(types.TypeDecorator): - impl = impl_ + impl = types.CHAR(32) cache_ok = True if python_type_ is not None: @@ -370,17 +362,18 @@ class CustomType(types.TypeDecorator): def python_type(self) -> type: return python_type_ - class Base(orm.DeclarativeBase): - type_annotation_map = { - UUID: CustomType, - } + class Base(metaclass=DeclarativeMeta): + __abstract__ = True + __allow_unmapped__ = True + + registry = _registry + metadata = _registry.metadata class Model(Base): - __tablename__ = "model_with_custom_types" + __tablename__ = f"model_with_custom_types_{python_type_}" - id: orm.Mapped[int] = orm.mapped_column(primary_key=True) - custom_type: orm.Mapped[UUID] = orm.mapped_column(type_=CustomType(), nullable=False) - custom_type_from_annotation_map: orm.Mapped[UUID] + id: Any = Column(Integer(), primary_key=True) + custom_type: Any = Column(type_=CustomType(), nullable=False) class ModelFactory(SQLAlchemyFactory[Model]): __model__ = Model @@ -389,4 +382,3 @@ class ModelFactory(SQLAlchemyFactory[Model]): expected_type = python_type_ if python_type_ is not None else CustomType.impl.python_type assert isinstance(instance.custom_type, expected_type) - assert isinstance(instance.custom_type_from_annotation_map, expected_type)