diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index 62a7c0a5..5a2155b2 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -129,11 +129,10 @@ def get_type_from_column(cls, column: Column) -> type: elif issubclass(column_type, types.ARRAY): annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined] else: - annotation = ( - column.type.impl.python_type # pyright: ignore[reportGeneralTypeIssues] - if hasattr(column.type, "impl") - else column.type.python_type - ) + try: + annotation = column.type.python_type + except NotImplementedError: + annotation = column.type.impl.python_type # type: ignore[attr-defined] if column.nullable: annotation = Union[annotation, None] # type: ignore[assignment] diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index 6a7657da..ca77fd1e 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Type +from typing import Any, Callable, Type, Union +from uuid import UUID import pytest from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, inspect, orm, types @@ -347,3 +348,37 @@ class ModelFactory(SQLAlchemyFactory[ModelWithAlias]): result = ModelFactory.build() assert isinstance(result.name, str) + + +@pytest.mark.parametrize("python_type_", (UUID, None)) +def test_sqlalchemy_custom_type_from_type_decorator(python_type_: Union[type, None]) -> None: + class CustomType(types.TypeDecorator): + impl = types.CHAR(32) + cache_ok = True + + if python_type_ is not None: + + @property + def python_type(self) -> type: + return python_type_ + + class Base(metaclass=DeclarativeMeta): + __abstract__ = True + __allow_unmapped__ = True + + registry = _registry + metadata = _registry.metadata + + class Model(Base): + __tablename__ = f"model_with_custom_types_{python_type_}" + + id: Any = Column(Integer(), primary_key=True) + custom_type: Any = Column(type_=CustomType(), nullable=False) + + class ModelFactory(SQLAlchemyFactory[Model]): + __model__ = Model + + instance = ModelFactory.build() + + expected_type = python_type_ if python_type_ is not None else CustomType.impl.python_type + assert isinstance(instance.custom_type, expected_type)