Skip to content

Commit

Permalink
fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed Jan 25, 2025
1 parent a93f407 commit 226a4cb
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ static_assert(
## Unions containing tuples containing tuples containing unions (etc.)

```py
from knot_extensions import is_equivalent_to, static_assert
from knot_extensions import is_equivalent_to, static_assert, Intersection

class P: ...
class Q: ...
Expand All @@ -98,6 +98,12 @@ static_assert(
tuple[tuple[tuple[Q | P]]] | P,
)
)
static_assert(
is_equivalent_to(
tuple[tuple[tuple[tuple[tuple[Intersection[P, Q]]]]]],
tuple[tuple[tuple[tuple[tuple[Intersection[Q, P]]]]]],
)
)
```

[the equivalence relation]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-equivalent
81 changes: 63 additions & 18 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,35 @@ impl<'db> Type<'db> {
}
}

/// Return a normalized version of `self` in which all unions and intersections are sorted
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
#[must_use]
pub fn with_sorted_unions(self, db: &'db dyn Db) -> Self {
match self {
Type::Union(union) => Type::Union(union.to_sorted_union(db)),
Type::Intersection(intersection) => {
Type::Intersection(intersection.to_sorted_intersection(db))
}
Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_unions(db)),
Type::LiteralString
| Type::Instance(_)
| Type::AlwaysFalsy
| Type::AlwaysTruthy
| Type::BooleanLiteral(_)
| Type::SliceLiteral(_)
| Type::BytesLiteral(_)
| Type::StringLiteral(_)
| Type::Dynamic(_)
| Type::Never
| Type::FunctionLiteral(_)
| Type::ModuleLiteral(_)
| Type::ClassLiteral(_)
| Type::KnownInstance(_)
| Type::IntLiteral(_)
| Type::SubclassOf(_) => self,
}
}

/// Return true if this type is a [subtype of] type `target`.
///
/// This method returns `false` if either `self` or `other` is not fully static.
Expand Down Expand Up @@ -1154,7 +1183,7 @@ impl<'db> Type<'db> {
left.is_equivalent_to(db, right)
}
(Type::Tuple(left), Type::Tuple(right)) => left.is_equivalent_to(db, right),
_ => self.is_fully_static(db) && other.is_fully_static(db) && self == other,
_ => self == other && self.is_fully_static(db) && other.is_fully_static(db),
}
}

Expand Down Expand Up @@ -4352,12 +4381,11 @@ impl<'db> UnionType<'db> {
/// Create a new union type with the elements sorted according to a canonical ordering.
#[must_use]
pub fn to_sorted_union(self, db: &'db dyn Db) -> Self {
let mut new_elements = self.elements(db).to_vec();
for element in &mut new_elements {
if let Type::Intersection(intersection) = element {
intersection.sort(db);
}
}
let mut new_elements: Vec<Type<'db>> = self
.elements(db)
.iter()
.map(|element| element.with_sorted_unions(db))
.collect();
new_elements.sort_unstable_by(union_elements_ordering);
UnionType::new(db, new_elements.into_boxed_slice())
}
Expand Down Expand Up @@ -4453,19 +4481,24 @@ impl<'db> IntersectionType<'db> {
/// according to a canonical ordering.
#[must_use]
pub fn to_sorted_intersection(self, db: &'db dyn Db) -> Self {
let mut positive = self.positive(db).clone();
positive.sort_unstable_by(union_elements_ordering);

let mut negative = self.negative(db).clone();
negative.sort_unstable_by(union_elements_ordering);
fn normalized_set<'db>(
db: &'db dyn Db,
elements: &FxOrderSet<Type<'db>>,
) -> FxOrderSet<Type<'db>> {
let mut elements: FxOrderSet<Type<'db>> = elements
.iter()
.map(|ty| ty.with_sorted_unions(db))
.collect();

IntersectionType::new(db, positive, negative)
}
elements.sort_unstable_by(union_elements_ordering);
elements
}

/// Perform an in-place sort of this [`IntersectionType`] instance
/// according to a canonical ordering.
fn sort(&mut self, db: &'db dyn Db) {
*self = self.to_sorted_intersection(db);
IntersectionType::new(
db,
normalized_set(db, self.positive(db)),
normalized_set(db, self.negative(db)),
)
}

pub fn is_fully_static(self, db: &'db dyn Db) -> bool {
Expand Down Expand Up @@ -4608,6 +4641,18 @@ impl<'db> TupleType<'db> {
Type::Tuple(Self::new(db, elements.into_boxed_slice()))
}

/// Return a normalized version of `self` in which all unions and intersections are sorted
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
#[must_use]
pub fn with_sorted_unions(self, db: &'db dyn Db) -> Self {
let elements: Box<[Type<'db>]> = self
.elements(db)
.iter()
.map(|ty| ty.with_sorted_unions(db))
.collect();
TupleType::new(db, elements)
}

pub fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
let self_elements = self.elements(db);
let other_elements = other.elements(db);
Expand Down

0 comments on commit 226a4cb

Please sign in to comment.