Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix strict behavior for unions #1638

Merged
merged 4 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2525,7 +2525,6 @@ def union_schema(
custom_error_message: str | None = None,
custom_error_context: dict[str, str | int] | None = None,
mode: Literal['smart', 'left_to_right'] | None = None,
strict: bool | None = None,
ref: str | None = None,
metadata: dict[str, Any] | None = None,
serialization: SerSchema | None = None,
Expand All @@ -2551,7 +2550,6 @@ def union_schema(
mode: How to select which choice to return
* `smart` (default) will try to return the choice which is the closest match to the input value
* `left_to_right` will return the first choice in `choices` which succeeds validation
strict: Whether the underlying schemas should be validated with strict mode
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
Expand All @@ -2564,7 +2562,6 @@ def union_schema(
custom_error_message=custom_error_message,
custom_error_context=custom_error_context,
mode=mode,
strict=strict,
ref=ref,
metadata=metadata,
serialization=serialization,
Expand Down
22 changes: 2 additions & 20 deletions src/validators/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use pyo3::{intern, PyTraverseError, PyVisit};
use smallvec::SmallVec;

use crate::build_tools::py_schema_err;
use crate::build_tools::{is_strict, schema_or_config};
use crate::build_tools::schema_or_config;
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
use crate::errors::{ErrorType, ToErrorValue, ValError, ValLineError, ValResult};
use crate::input::{BorrowInput, Input, ValidatedDict};
Expand Down Expand Up @@ -43,7 +43,6 @@ pub struct UnionValidator {
mode: UnionMode,
choices: Vec<(CombinedValidator, Option<String>)>,
custom_error: Option<CustomError>,
strict: bool,
name: String,
}

Expand Down Expand Up @@ -91,7 +90,6 @@ impl BuildValidator for UnionValidator {
mode,
choices,
custom_error: CustomError::build(schema, config, definitions)?,
strict: is_strict(schema, config)?,
name: format!("{}[{descr}]", Self::EXPECTED_TYPE),
}
.into())
Expand All @@ -110,17 +108,11 @@ impl UnionValidator {
let old_exactness = state.exactness;
let old_fields_set_count = state.fields_set_count;

let strict = state.strict_or(self.strict);
let mut errors = MaybeErrors::new(self.custom_error.as_ref());

let mut best_match: Option<(Py<PyAny>, Exactness, Option<usize>)> = None;

for (choice, label) in &self.choices {
let state = &mut state.rebind_extra(|extra| {
if strict {
extra.strict = Some(strict);
}
});
state.exactness = Some(Exactness::Exact);
state.fields_set_count = None;
let result = choice.validate(py, input, state);
Expand Down Expand Up @@ -197,14 +189,6 @@ impl UnionValidator {
) -> ValResult<PyObject> {
let mut errors = MaybeErrors::new(self.custom_error.as_ref());

let mut rebound_state;
let state = if state.strict_or(self.strict) {
rebound_state = state.rebind_extra(|extra| extra.strict = Some(true));
&mut rebound_state
} else {
state
};

for (validator, label) in &self.choices {
match validator.validate(py, input, state) {
Err(ValError::LineErrors(lines)) => errors.push(validator, label.as_deref(), lines),
Expand Down Expand Up @@ -300,7 +284,6 @@ pub struct TaggedUnionValidator {
discriminator: Discriminator,
lookup: LiteralLookup<CombinedValidator>,
from_attributes: bool,
strict: bool,
custom_error: Option<CustomError>,
tags_repr: String,
discriminator_repr: String,
Expand Down Expand Up @@ -349,7 +332,6 @@ impl BuildValidator for TaggedUnionValidator {
discriminator,
lookup,
from_attributes,
strict: is_strict(schema, config)?,
custom_error: CustomError::build(schema, config, definitions)?,
tags_repr,
discriminator_repr,
Expand All @@ -371,7 +353,7 @@ impl Validator for TaggedUnionValidator {
match &self.discriminator {
Discriminator::LookupKey(lookup_key) => {
let from_attributes = state.extra().from_attributes.unwrap_or(self.from_attributes);
let dict = input.validate_model_fields(self.strict, from_attributes)?;
let dict = input.validate_model_fields(state.strict_or(false), from_attributes)?;
// note this methods returns PyResult<Option<(data, data)>>, the outer Err is just for
// errors when getting attributes which should be "raised"
let tag = match dict.get_item(lookup_key)? {
Expand Down
8 changes: 5 additions & 3 deletions tests/benchmarks/test_micro_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,16 +686,18 @@ def test_smart_union_coerce_core(self, benchmark):
def test_strict_union_core(self, benchmark):
v = SchemaValidator(
schema=core_schema.union_schema(
strict=True, choices=[core_schema.bool_schema(), core_schema.int_schema(), core_schema.str_schema()]
)
choices=[core_schema.bool_schema(), core_schema.int_schema(), core_schema.str_schema()]
),
config=CoreConfig(strict=True),
)

benchmark(v.validate_python, 1)

@pytest.mark.benchmark(group='strict-union-error')
def test_strict_union_error_core(self, benchmark):
v = SchemaValidator(
schema=core_schema.union_schema(strict=True, choices=[core_schema.bool_schema(), core_schema.str_schema()])
schema=core_schema.union_schema(choices=[core_schema.bool_schema(), core_schema.str_schema()]),
config=CoreConfig(strict=True),
)

def validate_with_expected_error():
Expand Down
2 changes: 1 addition & 1 deletion tests/validators/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_constrained_bytes(py_and_json: PyAndJson, opts: dict[str, Any], input,


def test_union():
v = SchemaValidator(cs.union_schema(choices=[cs.str_schema(), cs.bytes_schema()], strict=True))
v = SchemaValidator(cs.union_schema(choices=[cs.str_schema(strict=True), cs.bytes_schema(strict=True)]))
assert v.validate_python('oh, a string') == 'oh, a string'
assert v.validate_python(b'oh, bytes') == b'oh, bytes'

Expand Down
8 changes: 4 additions & 4 deletions tests/validators/test_definitions_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,11 +611,11 @@ def test_union_cycle(strict: bool):
'foobar': core_schema.typed_dict_field(
core_schema.list_schema(core_schema.definition_reference_schema('root-schema'))
)
}
},
strict=strict,
)
],
auto_collapse=False,
strict=strict,
ref='root-schema',
)
],
Expand Down Expand Up @@ -700,11 +700,11 @@ def f(input_value, info):
)
],
auto_collapse=False,
strict=strict,
ref='root-schema',
)
],
)
),
config=CoreConfig(strict=strict),
)

with pytest.raises(ValidationError) as exc_info:
Expand Down
48 changes: 42 additions & 6 deletions tests/validators/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from dirty_equals import IsFloat, IsInt

from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema
from pydantic_core import CoreConfig, SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema

from ..conftest import plain_repr

Expand Down Expand Up @@ -262,16 +262,47 @@ def test_one_choice():
assert v.validate_python('hello') == 'hello'


def test_strict_union():
def test_strict_union_flag() -> None:
v = SchemaValidator(core_schema.union_schema(choices=[core_schema.bool_schema(), core_schema.int_schema()]))
assert v.validate_python(1, strict=True) == 1
assert v.validate_python(123, strict=True) == 123

with pytest.raises(ValidationError) as exc_info:
v.validate_python('123', strict=True)

assert exc_info.value.errors(include_url=False) == [
{'type': 'bool_type', 'loc': ('bool',), 'msg': 'Input should be a valid boolean', 'input': '123'},
{'type': 'int_type', 'loc': ('int',), 'msg': 'Input should be a valid integer', 'input': '123'},
]


def test_strict_union_config_level() -> None:
v = SchemaValidator(
core_schema.union_schema(strict=True, choices=[core_schema.bool_schema(), core_schema.int_schema()])
core_schema.union_schema(choices=[core_schema.bool_schema(), core_schema.int_schema()]),
config=CoreConfig(strict=True),
)

assert v.validate_python(1) == 1
assert v.validate_python(123) == 123

with pytest.raises(ValidationError) as exc_info:
v.validate_python('123')
assert exc_info.value.errors(include_url=False) == [
{'type': 'bool_type', 'loc': ('bool',), 'msg': 'Input should be a valid boolean', 'input': '123'},
{'type': 'int_type', 'loc': ('int',), 'msg': 'Input should be a valid integer', 'input': '123'},
]


def test_strict_union_member_level() -> None:
v = SchemaValidator(
core_schema.union_schema(choices=[core_schema.bool_schema(strict=True), core_schema.int_schema(strict=True)])
)

assert v.validate_python(1) == 1
assert v.validate_python(123) == 123

with pytest.raises(ValidationError) as exc_info:
v.validate_python('123')
assert exc_info.value.errors(include_url=False) == [
{'type': 'bool_type', 'loc': ('bool',), 'msg': 'Input should be a valid boolean', 'input': '123'},
{'type': 'int_type', 'loc': ('int',), 'msg': 'Input should be a valid integer', 'input': '123'},
Expand Down Expand Up @@ -469,10 +500,10 @@ def test_left_to_right_union():


def test_left_to_right_union_strict():
choices = [core_schema.int_schema(), core_schema.float_schema()]
choices = [core_schema.int_schema(strict=True), core_schema.float_schema(strict=True)]

# left_to_right union will select not cast if int first (strict int will not accept float)
v = SchemaValidator(core_schema.union_schema(choices, mode='left_to_right', strict=True))
v = SchemaValidator(core_schema.union_schema(choices, mode='left_to_right'))
out = v.validate_python(1)
assert out == 1
assert isinstance(out, int)
Expand All @@ -482,7 +513,12 @@ def test_left_to_right_union_strict():
assert isinstance(out, float)

# reversing union will select float always (as strict float will accept int)
v = SchemaValidator(core_schema.union_schema(list(reversed(choices)), mode='left_to_right', strict=True))
v = SchemaValidator(
core_schema.union_schema(
list(reversed(choices)),
mode='left_to_right',
)
)
out = v.validate_python(1.0)
assert out == 1.0
assert isinstance(out, float)
Expand Down
Loading