diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md index 8ae96d733aaba5..44fa07756c2bcd 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md @@ -84,4 +84,26 @@ static_assert( ) ``` +## Unions containing tuples containing tuples containing unions (etc.) + +```py +from knot_extensions import is_equivalent_to, static_assert, Intersection + +class P: ... +class Q: ... + +static_assert( + is_equivalent_to( + tuple[tuple[tuple[P | Q]]] | P, + tuple[tuple[tuple[Q | P]]] | P, + ) +) +static_assert( + is_equivalent_to( + tuple[tuple[tuple[tuple[tuple[Intersection[P, Q]]]]]], + tuple[tuple[tuple[tuple[tuple[Intersection[Q, P]]]]]], + ) +) +``` + [the equivalence relation]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-equivalent diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 69940e4a9d78f9..14dfb013946c0e 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -811,6 +811,35 @@ impl<'db> Type<'db> { } } + /// Return a normalized version of `self` in which all unions and intersections are sorted + /// according to a canonical order, no matter how "deeply" a union/intersection may be nested. + #[must_use] + pub fn with_sorted_unions(self, db: &'db dyn Db) -> Self { + match self { + Type::Union(union) => Type::Union(union.to_sorted_union(db)), + Type::Intersection(intersection) => { + Type::Intersection(intersection.to_sorted_intersection(db)) + } + Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_unions(db)), + Type::LiteralString + | Type::Instance(_) + | Type::AlwaysFalsy + | Type::AlwaysTruthy + | Type::BooleanLiteral(_) + | Type::SliceLiteral(_) + | Type::BytesLiteral(_) + | Type::StringLiteral(_) + | Type::Dynamic(_) + | Type::Never + | Type::FunctionLiteral(_) + | Type::ModuleLiteral(_) + | Type::ClassLiteral(_) + | Type::KnownInstance(_) + | Type::IntLiteral(_) + | Type::SubclassOf(_) => self, + } + } + /// Return true if this type is a [subtype of] type `target`. /// /// This method returns `false` if either `self` or `other` is not fully static. @@ -1154,7 +1183,7 @@ impl<'db> Type<'db> { left.is_equivalent_to(db, right) } (Type::Tuple(left), Type::Tuple(right)) => left.is_equivalent_to(db, right), - _ => self.is_fully_static(db) && other.is_fully_static(db) && self == other, + _ => self == other && self.is_fully_static(db) && other.is_fully_static(db), } } @@ -4352,12 +4381,11 @@ impl<'db> UnionType<'db> { /// Create a new union type with the elements sorted according to a canonical ordering. #[must_use] pub fn to_sorted_union(self, db: &'db dyn Db) -> Self { - let mut new_elements = self.elements(db).to_vec(); - for element in &mut new_elements { - if let Type::Intersection(intersection) = element { - intersection.sort(db); - } - } + let mut new_elements: Vec> = self + .elements(db) + .iter() + .map(|element| element.with_sorted_unions(db)) + .collect(); new_elements.sort_unstable_by(union_elements_ordering); UnionType::new(db, new_elements.into_boxed_slice()) } @@ -4453,19 +4481,24 @@ impl<'db> IntersectionType<'db> { /// according to a canonical ordering. #[must_use] pub fn to_sorted_intersection(self, db: &'db dyn Db) -> Self { - let mut positive = self.positive(db).clone(); - positive.sort_unstable_by(union_elements_ordering); - - let mut negative = self.negative(db).clone(); - negative.sort_unstable_by(union_elements_ordering); + fn normalized_set<'db>( + db: &'db dyn Db, + elements: &FxOrderSet>, + ) -> FxOrderSet> { + let mut elements: FxOrderSet> = elements + .iter() + .map(|ty| ty.with_sorted_unions(db)) + .collect(); - IntersectionType::new(db, positive, negative) - } + elements.sort_unstable_by(union_elements_ordering); + elements + } - /// Perform an in-place sort of this [`IntersectionType`] instance - /// according to a canonical ordering. - fn sort(&mut self, db: &'db dyn Db) { - *self = self.to_sorted_intersection(db); + IntersectionType::new( + db, + normalized_set(db, self.positive(db)), + normalized_set(db, self.negative(db)), + ) } pub fn is_fully_static(self, db: &'db dyn Db) -> bool { @@ -4608,6 +4641,18 @@ impl<'db> TupleType<'db> { Type::Tuple(Self::new(db, elements.into_boxed_slice())) } + /// Return a normalized version of `self` in which all unions and intersections are sorted + /// according to a canonical order, no matter how "deeply" a union/intersection may be nested. + #[must_use] + pub fn with_sorted_unions(self, db: &'db dyn Db) -> Self { + let elements: Box<[Type<'db>]> = self + .elements(db) + .iter() + .map(|ty| ty.with_sorted_unions(db)) + .collect(); + TupleType::new(db, elements) + } + pub fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool { let self_elements = self.elements(db); let other_elements = other.elements(db);