Skip to content

Commit 6991676

Browse files
committed
[red-knot] Decompose bool to Literal[True, False] in unions and intersections
1 parent c824140 commit 6991676

File tree

8 files changed

+192
-145
lines changed

8 files changed

+192
-145
lines changed

crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,22 @@ reveal_type(c >= d) # revealed: Literal[True]
5858
#### Results with Ambiguity
5959

6060
```py
61-
def _(x: bool, y: int):
61+
class P:
62+
def __lt__(self, other: "P") -> bool:
63+
return True
64+
65+
def __le__(self, other: "P") -> bool:
66+
return True
67+
68+
def __gt__(self, other: "P") -> bool:
69+
return True
70+
71+
def __ge__(self, other: "P") -> bool:
72+
return True
73+
74+
class Q(P): ...
75+
76+
def _(x: P, y: Q):
6277
a = (x,)
6378
b = (y,)
6479

crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,9 @@ else:
455455
reveal_type(x) # revealed: slice
456456
finally:
457457
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice`
458-
reveal_type(x) # revealed: bool | float | slice
458+
reveal_type(x) # revealed: bool | slice | float
459459

460-
reveal_type(x) # revealed: bool | float | slice
460+
reveal_type(x) # revealed: bool | slice | float
461461
```
462462

463463
## Nested `try`/`except` blocks
@@ -534,7 +534,7 @@ try:
534534
reveal_type(x) # revealed: slice
535535
finally:
536536
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice`
537-
reveal_type(x) # revealed: bool | float | slice
537+
reveal_type(x) # revealed: bool | slice | float
538538
x = 2
539539
reveal_type(x) # revealed: Literal[2]
540540
reveal_type(x) # revealed: Literal[2]

crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,22 @@ else:
2121
if x and not x:
2222
reveal_type(x) # revealed: Never
2323
else:
24-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
24+
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
2525

2626
if not (x and not x):
27-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
27+
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
2828
else:
2929
reveal_type(x) # revealed: Never
3030

3131
if x or not x:
32-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
32+
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
3333
else:
3434
reveal_type(x) # revealed: Never
3535

3636
if not (x or not x):
3737
reveal_type(x) # revealed: Never
3838
else:
39-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
39+
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
4040

4141
if (isinstance(x, int) or isinstance(x, str)) and x:
4242
reveal_type(x) # revealed: Literal[-1, True, "foo"]

crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,48 @@ class R: ...
118118
static_assert(is_equivalent_to(Intersection[tuple[P | Q], R], Intersection[tuple[Q | P], R]))
119119
```
120120

121+
## Unions containing tuples containing `bool`
122+
123+
```py
124+
from knot_extensions import is_equivalent_to, static_assert
125+
from typing_extensions import Literal
126+
127+
class P: ...
128+
129+
static_assert(is_equivalent_to(tuple[Literal[True, False]] | P, tuple[bool] | P))
130+
static_assert(is_equivalent_to(P | tuple[bool], P | tuple[Literal[True, False]]))
131+
```
132+
133+
## Unions and intersections involving `AlwaysTruthy`, `bool` and `AlwaysFalsy`
134+
135+
```py
136+
from knot_extensions import AlwaysTruthy, AlwaysFalsy, static_assert, is_equivalent_to, Not
137+
from typing_extensions import Literal
138+
139+
static_assert(is_equivalent_to(AlwaysTruthy | bool, Literal[False] | AlwaysTruthy))
140+
static_assert(is_equivalent_to(AlwaysFalsy | bool, Literal[True] | AlwaysFalsy))
141+
static_assert(is_equivalent_to(Not[AlwaysTruthy] | bool, Not[AlwaysTruthy] | Literal[True]))
142+
static_assert(is_equivalent_to(Not[AlwaysFalsy] | bool, Literal[False] | Not[AlwaysFalsy]))
143+
```
144+
145+
## Unions and intersections involving `AlwaysTruthy`, `LiteralString` and `AlwaysFalsy`
146+
147+
```py
148+
from knot_extensions import AlwaysTruthy, AlwaysFalsy, static_assert, is_equivalent_to, Not, Intersection
149+
from typing_extensions import Literal, LiteralString
150+
151+
# TODO: these should all pass!
152+
153+
# error: [static-assert-error]
154+
static_assert(is_equivalent_to(AlwaysTruthy | LiteralString, Literal[""] | AlwaysTruthy))
155+
# error: [static-assert-error]
156+
static_assert(is_equivalent_to(AlwaysFalsy | LiteralString, Intersection[LiteralString, Not[Literal[""]]] | AlwaysFalsy))
157+
# error: [static-assert-error]
158+
static_assert(is_equivalent_to(Not[AlwaysFalsy] | LiteralString, Literal[""] | Not[AlwaysFalsy]))
159+
# error: [static-assert-error]
160+
static_assert(
161+
is_equivalent_to(Not[AlwaysTruthy] | LiteralString, Not[AlwaysTruthy] | Intersection[LiteralString, Not[Literal[""]]])
162+
)
163+
```
164+
121165
[the equivalence relation]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-equivalent

crates/red_knot_python_semantic/src/types.rs

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,31 @@ impl<'db> Type<'db> {
811811
}
812812
}
813813

814+
/// Normalize the type `bool` -> `Literal[True, False]`.
815+
///
816+
/// Using this method in various type-relational methods
817+
/// ensures that the following invariants hold true:
818+
///
819+
/// - bool ≡ Literal[True, False]
820+
/// - bool | T ≡ Literal[True, False] | T
821+
/// - bool <: Literal[True, False]
822+
/// - bool | T <: Literal[True, False] | T
823+
/// - Literal[True, False] <: bool
824+
/// - Literal[True, False] | T <: bool | T
825+
#[must_use]
826+
pub fn with_normalized_bools(self, db: &'db dyn Db) -> Self {
827+
const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)];
828+
829+
match self {
830+
Type::Instance(InstanceType { class }) if class.is_known(db, KnownClass::Bool) => {
831+
Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS)))
832+
}
833+
// TODO: decompose `LiteralString` into `Literal[""] | TruthyLiteralString`?
834+
// We'd need to rename this method... --Alex
835+
_ => self,
836+
}
837+
}
838+
814839
/// Return a normalized version of `self` in which all unions and intersections are sorted
815840
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
816841
#[must_use]
@@ -859,6 +884,12 @@ impl<'db> Type<'db> {
859884
return true;
860885
}
861886

887+
let normalized_self = self.with_normalized_bools(db);
888+
let normalized_target = target.with_normalized_bools(db);
889+
if normalized_self != self || normalized_target != target {
890+
return normalized_self.is_subtype_of(db, normalized_target);
891+
}
892+
862893
// Non-fully-static types do not participate in subtyping.
863894
//
864895
// Type `A` can only be a subtype of type `B` if the set of possible runtime objects
@@ -961,7 +992,7 @@ impl<'db> Type<'db> {
961992
KnownClass::Str.to_instance(db).is_subtype_of(db, target)
962993
}
963994
(Type::BooleanLiteral(_), _) => {
964-
KnownClass::Bool.to_instance(db).is_subtype_of(db, target)
995+
KnownClass::Int.to_instance(db).is_subtype_of(db, target)
965996
}
966997
(Type::IntLiteral(_), _) => KnownClass::Int.to_instance(db).is_subtype_of(db, target),
967998
(Type::BytesLiteral(_), _) => {
@@ -1077,6 +1108,11 @@ impl<'db> Type<'db> {
10771108
if self.is_gradual_equivalent_to(db, target) {
10781109
return true;
10791110
}
1111+
let normalized_self = self.with_normalized_bools(db);
1112+
let normalized_target = target.with_normalized_bools(db);
1113+
if normalized_self != self || normalized_target != target {
1114+
return normalized_self.is_assignable_to(db, normalized_target);
1115+
}
10801116
match (self, target) {
10811117
// Never can be assigned to any type.
10821118
(Type::Never, _) => true,
@@ -1177,6 +1213,13 @@ impl<'db> Type<'db> {
11771213
pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
11781214
// TODO equivalent but not identical types: TypedDicts, Protocols, type aliases, etc.
11791215

1216+
let normalized_self = self.with_normalized_bools(db);
1217+
let normalized_other = other.with_normalized_bools(db);
1218+
1219+
if normalized_self != self || normalized_other != other {
1220+
return normalized_self.is_equivalent_to(db, normalized_other);
1221+
}
1222+
11801223
match (self, other) {
11811224
(Type::Union(left), Type::Union(right)) => left.is_equivalent_to(db, right),
11821225
(Type::Intersection(left), Type::Intersection(right)) => {
@@ -1218,6 +1261,13 @@ impl<'db> Type<'db> {
12181261
///
12191262
/// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations
12201263
pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
1264+
let normalized_self = self.with_normalized_bools(db);
1265+
let normalized_other = other.with_normalized_bools(db);
1266+
1267+
if normalized_self != self || normalized_other != other {
1268+
return normalized_self.is_gradual_equivalent_to(db, normalized_other);
1269+
}
1270+
12211271
if self == other {
12221272
return true;
12231273
}
@@ -1250,6 +1300,12 @@ impl<'db> Type<'db> {
12501300
/// Note: This function aims to have no false positives, but might return
12511301
/// wrong `false` answers in some cases.
12521302
pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool {
1303+
let normalized_self = self.with_normalized_bools(db);
1304+
let normalized_other = other.with_normalized_bools(db);
1305+
if normalized_self != self || normalized_other != other {
1306+
return normalized_self.is_disjoint_from(db, normalized_other);
1307+
}
1308+
12531309
match (self, other) {
12541310
(Type::Never, _) | (_, Type::Never) => true,
12551311

@@ -4642,18 +4698,19 @@ pub struct TupleType<'db> {
46424698
}
46434699

46444700
impl<'db> TupleType<'db> {
4645-
pub fn from_elements<T: Into<Type<'db>>>(
4646-
db: &'db dyn Db,
4647-
types: impl IntoIterator<Item = T>,
4648-
) -> Type<'db> {
4701+
pub fn from_elements<I, T>(db: &'db dyn Db, types: I) -> Type<'db>
4702+
where
4703+
I: IntoIterator<Item = T>,
4704+
T: Into<Type<'db>>,
4705+
{
46494706
let mut elements = vec![];
46504707

46514708
for ty in types {
4652-
let ty = ty.into();
4709+
let ty: Type<'db> = ty.into();
46534710
if ty.is_never() {
46544711
return Type::Never;
46554712
}
4656-
elements.push(ty);
4713+
elements.push(ty.with_normalized_bools(db));
46574714
}
46584715

46594716
Type::Tuple(Self::new(db, elements.into_boxed_slice()))

0 commit comments

Comments
 (0)