Skip to content

Commit a347b0d

Browse files
committed
[red-knot] Decompose bool to Literal[True, False] in unions and intersections
1 parent fcd0f34 commit a347b0d

File tree

8 files changed

+220
-160
lines changed

8 files changed

+220
-160
lines changed

crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,22 @@ reveal_type(c >= d) # revealed: Literal[True]
5858
#### Results with Ambiguity
5959

6060
```py
61-
def _(x: bool, y: int):
61+
class P:
62+
def __lt__(self, other: "P") -> bool:
63+
return True
64+
65+
def __le__(self, other: "P") -> bool:
66+
return True
67+
68+
def __gt__(self, other: "P") -> bool:
69+
return True
70+
71+
def __ge__(self, other: "P") -> bool:
72+
return True
73+
74+
class Q(P): ...
75+
76+
def _(x: P, y: Q):
6277
a = (x,)
6378
b = (y,)
6479

crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,9 @@ else:
455455
reveal_type(x) # revealed: slice
456456
finally:
457457
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice`
458-
reveal_type(x) # revealed: bool | float | slice
458+
reveal_type(x) # revealed: bool | slice | float
459459

460-
reveal_type(x) # revealed: bool | float | slice
460+
reveal_type(x) # revealed: bool | slice | float
461461
```
462462

463463
## Nested `try`/`except` blocks
@@ -534,7 +534,7 @@ try:
534534
reveal_type(x) # revealed: slice
535535
finally:
536536
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice`
537-
reveal_type(x) # revealed: bool | float | slice
537+
reveal_type(x) # revealed: bool | slice | float
538538
x = 2
539539
reveal_type(x) # revealed: Literal[2]
540540
reveal_type(x) # revealed: Literal[2]

crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,22 @@ else:
2121
if x and not x:
2222
reveal_type(x) # revealed: Never
2323
else:
24-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
24+
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
2525

2626
if not (x and not x):
27-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
27+
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
2828
else:
2929
reveal_type(x) # revealed: Never
3030

3131
if x or not x:
32-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
32+
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
3333
else:
3434
reveal_type(x) # revealed: Never
3535

3636
if not (x or not x):
3737
reveal_type(x) # revealed: Never
3838
else:
39-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
39+
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
4040

4141
if (isinstance(x, int) or isinstance(x, str)) and x:
4242
reveal_type(x) # revealed: Literal[-1, True, "foo"]

crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,46 @@ static_assert(
8484
)
8585
```
8686

87+
## Unions containing tuples containing `bool`
88+
89+
```py
90+
from knot_extensions import is_equivalent_to, static_assert
91+
from typing_extensions import Never, Literal
92+
93+
class P: ...
94+
95+
static_assert(is_equivalent_to(tuple[Literal[True, False]] | P, tuple[bool] | P))
96+
static_assert(is_equivalent_to(P | tuple[bool], P | tuple[Literal[True, False]]))
97+
```
98+
99+
## Unions and intersections involving `AlwaysTruthy`, `bool` and `AlwaysFalsy`
100+
101+
```py
102+
from knot_extensions import AlwaysTruthy, AlwaysFalsy, static_assert, is_equivalent_to, Not
103+
from typing_extensions import Literal
104+
105+
static_assert(is_equivalent_to(AlwaysTruthy | bool, Literal[False] | AlwaysTruthy))
106+
static_assert(is_equivalent_to(AlwaysFalsy | bool, Literal[True] | AlwaysFalsy))
107+
static_assert(is_equivalent_to(Not[AlwaysTruthy] | bool, Not[AlwaysTruthy] | Literal[True]))
108+
static_assert(is_equivalent_to(Not[AlwaysFalsy] | bool, Literal[False] | Not[AlwaysFalsy]))
109+
```
110+
111+
## Unions and intersections involving `AlwaysTruthy`, `LiteralString` and `AlwaysFalsy`
112+
113+
```py
114+
from knot_extensions import AlwaysTruthy, AlwaysFalsy, static_assert, is_equivalent_to, Not, Intersection
115+
from typing_extensions import Literal, LiteralString
116+
117+
# TODO: these should all pass!
118+
119+
# error: [static-assert-error]
120+
static_assert(is_equivalent_to(AlwaysTruthy | LiteralString, Literal[""] | AlwaysTruthy))
121+
# error: [static-assert-error]
122+
static_assert(is_equivalent_to(AlwaysFalsy | LiteralString, Intersection[LiteralString, Not[Literal[""]]] | AlwaysFalsy))
123+
# error: [static-assert-error]
124+
static_assert(is_equivalent_to(Not[AlwaysTruthy] | LiteralString, Not[AlwaysTruthy] | Intersection[LiteralString, Not[Literal[""]]]))
125+
# error: [static-assert-error]
126+
static_assert(is_equivalent_to(Not[AlwaysFalsy] | LiteralString, Literal[""] | Not[AlwaysFalsy]))
127+
```
128+
87129
[the equivalence relation]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-equivalent

crates/red_knot_python_semantic/src/types.rs

Lines changed: 87 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,19 @@ impl<'db> Type<'db> {
811811
}
812812
}
813813

814+
#[must_use]
815+
pub fn normalized(self, db: &'db dyn Db) -> Self {
816+
const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)];
817+
818+
match self {
819+
Type::Instance(InstanceType { class }) if class.is_known(db, KnownClass::Bool) => {
820+
Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS)))
821+
}
822+
// TODO: decompose `LiteralString` into `Literal[""] | TruthyLiteralString`? --Alex
823+
_ => self,
824+
}
825+
}
826+
814827
/// Return true if this type is a [subtype of] type `target`.
815828
///
816829
/// This method returns `false` if either `self` or `other` is not fully static.
@@ -840,7 +853,7 @@ impl<'db> Type<'db> {
840853
return false;
841854
}
842855

843-
match (self, target) {
856+
match (self.normalized(db), target.normalized(db)) {
844857
// We should have handled these immediately above.
845858
(Type::Dynamic(_), _) | (_, Type::Dynamic(_)) => {
846859
unreachable!("Non-fully-static types do not participate in subtyping!")
@@ -932,7 +945,7 @@ impl<'db> Type<'db> {
932945
KnownClass::Str.to_instance(db).is_subtype_of(db, target)
933946
}
934947
(Type::BooleanLiteral(_), _) => {
935-
KnownClass::Bool.to_instance(db).is_subtype_of(db, target)
948+
KnownClass::Int.to_instance(db).is_subtype_of(db, target)
936949
}
937950
(Type::IntLiteral(_), _) => KnownClass::Int.to_instance(db).is_subtype_of(db, target),
938951
(Type::BytesLiteral(_), _) => {
@@ -1048,6 +1061,14 @@ impl<'db> Type<'db> {
10481061
if self.is_gradual_equivalent_to(db, target) {
10491062
return true;
10501063
}
1064+
let normalized_self = self.normalized(db);
1065+
if normalized_self != self {
1066+
return normalized_self.is_assignable_to(db, target);
1067+
}
1068+
let normalized_target = target.normalized(db);
1069+
if normalized_target != target {
1070+
return self.is_assignable_to(db, normalized_target);
1071+
}
10511072
match (self, target) {
10521073
// Never can be assigned to any type.
10531074
(Type::Never, _) => true,
@@ -1148,13 +1169,13 @@ impl<'db> Type<'db> {
11481169
pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
11491170
// TODO equivalent but not identical types: TypedDicts, Protocols, type aliases, etc.
11501171

1151-
match (self, other) {
1172+
match (self.normalized(db), other.normalized(db)) {
11521173
(Type::Union(left), Type::Union(right)) => left.is_equivalent_to(db, right),
11531174
(Type::Intersection(left), Type::Intersection(right)) => {
11541175
left.is_equivalent_to(db, right)
11551176
}
11561177
(Type::Tuple(left), Type::Tuple(right)) => left.is_equivalent_to(db, right),
1157-
_ => self.is_fully_static(db) && other.is_fully_static(db) && self == other,
1178+
(left, right) => left == right && left.is_fully_static(db) && right.is_fully_static(db),
11581179
}
11591180
}
11601181

@@ -1189,11 +1210,14 @@ impl<'db> Type<'db> {
11891210
///
11901211
/// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations
11911212
pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
1192-
if self == other {
1213+
let left = self.normalized(db);
1214+
let right = other.normalized(db);
1215+
1216+
if left == right {
11931217
return true;
11941218
}
11951219

1196-
match (self, other) {
1220+
match (left, right) {
11971221
(Type::Dynamic(_), Type::Dynamic(_)) => true,
11981222

11991223
(Type::SubclassOf(first), Type::SubclassOf(second)) => {
@@ -1221,6 +1245,15 @@ impl<'db> Type<'db> {
12211245
/// Note: This function aims to have no false positives, but might return
12221246
/// wrong `false` answers in some cases.
12231247
pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool {
1248+
let normalized_self = self.normalized(db);
1249+
if normalized_self != self {
1250+
return normalized_self.is_disjoint_from(db, other);
1251+
}
1252+
let normalized_other = other.normalized(db);
1253+
if normalized_other != other {
1254+
return self.is_disjoint_from(db, normalized_other);
1255+
}
1256+
12241257
match (self, other) {
12251258
(Type::Never, _) | (_, Type::Never) => true,
12261259

@@ -4354,8 +4387,10 @@ impl<'db> UnionType<'db> {
43544387
pub fn to_sorted_union(self, db: &'db dyn Db) -> Self {
43554388
let mut new_elements = self.elements(db).to_vec();
43564389
for element in &mut new_elements {
4357-
if let Type::Intersection(intersection) = element {
4358-
intersection.sort(db);
4390+
match element {
4391+
Type::Intersection(intersection) => intersection.sort(db),
4392+
Type::Tuple(tuple) => tuple.sort_inner_unions(db),
4393+
_ => {}
43594394
}
43604395
}
43614396
new_elements.sort_unstable_by(union_elements_ordering);
@@ -4453,10 +4488,26 @@ impl<'db> IntersectionType<'db> {
44534488
/// according to a canonical ordering.
44544489
#[must_use]
44554490
pub fn to_sorted_intersection(self, db: &'db dyn Db) -> Self {
4456-
let mut positive = self.positive(db).clone();
4491+
let mut positive: FxOrderSet<Type<'db>> = self
4492+
.positive(db)
4493+
.iter()
4494+
.map(|ty| match ty {
4495+
Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_inner_unions(db)),
4496+
_ => *ty,
4497+
})
4498+
.collect();
4499+
44574500
positive.sort_unstable_by(union_elements_ordering);
44584501

4459-
let mut negative = self.negative(db).clone();
4502+
let mut negative: FxOrderSet<Type<'db>> = self
4503+
.negative(db)
4504+
.iter()
4505+
.map(|ty| match ty {
4506+
Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_inner_unions(db)),
4507+
_ => *ty,
4508+
})
4509+
.collect();
4510+
44604511
negative.sort_unstable_by(union_elements_ordering);
44614512

44624513
IntersectionType::new(db, positive, negative)
@@ -4591,23 +4642,44 @@ pub struct TupleType<'db> {
45914642
}
45924643

45934644
impl<'db> TupleType<'db> {
4594-
pub fn from_elements<T: Into<Type<'db>>>(
4595-
db: &'db dyn Db,
4596-
types: impl IntoIterator<Item = T>,
4597-
) -> Type<'db> {
4645+
pub fn from_elements<I, T>(db: &'db dyn Db, types: I) -> Type<'db>
4646+
where
4647+
I: IntoIterator<Item = T>,
4648+
T: Into<Type<'db>>,
4649+
{
45984650
let mut elements = vec![];
45994651

46004652
for ty in types {
46014653
let ty = ty.into();
46024654
if ty.is_never() {
46034655
return Type::Never;
46044656
}
4605-
elements.push(ty);
4657+
elements.push(ty.normalized(db));
46064658
}
46074659

46084660
Type::Tuple(Self::new(db, elements.into_boxed_slice()))
46094661
}
46104662

4663+
#[must_use]
4664+
pub fn with_sorted_inner_unions(self, db: &'db dyn Db) -> Self {
4665+
let elements: Box<[Type<'db>]> = self
4666+
.elements(db)
4667+
.iter()
4668+
.map(|ty| match ty {
4669+
Type::Union(union) => Type::Union(union.to_sorted_union(db)),
4670+
Type::Intersection(intersection) => {
4671+
Type::Intersection(intersection.to_sorted_intersection(db))
4672+
}
4673+
_ => *ty,
4674+
})
4675+
.collect();
4676+
TupleType::new(db, elements)
4677+
}
4678+
4679+
pub fn sort_inner_unions(&mut self, db: &'db dyn Db) {
4680+
*self = self.with_sorted_inner_unions(db);
4681+
}
4682+
46114683
pub fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
46124684
let self_elements = self.elements(db);
46134685
let other_elements = other.elements(db);

0 commit comments

Comments
 (0)