Skip to content

Commit 8b6a5fd

Browse files
committed
[red-knot] Decompose bool to Literal[True, False] in unions and intersections
1 parent fcd0f34 commit 8b6a5fd

File tree

8 files changed

+255
-170
lines changed

8 files changed

+255
-170
lines changed

crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,22 @@ reveal_type(c >= d) # revealed: Literal[True]
5858
#### Results with Ambiguity
5959

6060
```py
61-
def _(x: bool, y: int):
61+
class P:
62+
def __lt__(self, other: "P") -> bool:
63+
return True
64+
65+
def __le__(self, other: "P") -> bool:
66+
return True
67+
68+
def __gt__(self, other: "P") -> bool:
69+
return True
70+
71+
def __ge__(self, other: "P") -> bool:
72+
return True
73+
74+
class Q(P): ...
75+
76+
def _(x: P, y: Q):
6277
a = (x,)
6378
b = (y,)
6479

crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,9 @@ else:
455455
reveal_type(x) # revealed: slice
456456
finally:
457457
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice`
458-
reveal_type(x) # revealed: bool | float | slice
458+
reveal_type(x) # revealed: bool | slice | float
459459

460-
reveal_type(x) # revealed: bool | float | slice
460+
reveal_type(x) # revealed: bool | slice | float
461461
```
462462

463463
## Nested `try`/`except` blocks
@@ -534,7 +534,7 @@ try:
534534
reveal_type(x) # revealed: slice
535535
finally:
536536
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice`
537-
reveal_type(x) # revealed: bool | float | slice
537+
reveal_type(x) # revealed: bool | slice | float
538538
x = 2
539539
reveal_type(x) # revealed: Literal[2]
540540
reveal_type(x) # revealed: Literal[2]

crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,22 @@ else:
2121
if x and not x:
2222
reveal_type(x) # revealed: Never
2323
else:
24-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
24+
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
2525

2626
if not (x and not x):
27-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
27+
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
2828
else:
2929
reveal_type(x) # revealed: Never
3030

3131
if x or not x:
32-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
32+
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
3333
else:
3434
reveal_type(x) # revealed: Never
3535

3636
if not (x or not x):
3737
reveal_type(x) # revealed: Never
3838
else:
39-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
39+
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
4040

4141
if (isinstance(x, int) or isinstance(x, str)) and x:
4242
reveal_type(x) # revealed: Literal[-1, True, "foo"]

crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,56 @@ static_assert(
8484
)
8585
```
8686

87+
## Unions containing tuples containing `bool`
88+
89+
```py
90+
from knot_extensions import is_equivalent_to, static_assert
91+
from typing_extensions import Never, Literal
92+
93+
class P: ...
94+
class Q: ...
95+
96+
static_assert(is_equivalent_to(tuple[Literal[True, False]] | P, tuple[bool] | P))
97+
static_assert(is_equivalent_to(P | tuple[bool], P | tuple[Literal[True, False]]))
98+
99+
static_assert(
100+
is_equivalent_to(
101+
tuple[tuple[tuple[P | Q]]] | P,
102+
tuple[tuple[tuple[Q | P]]] | P,
103+
)
104+
)
105+
```
106+
107+
## Unions and intersections involving `AlwaysTruthy`, `bool` and `AlwaysFalsy`
108+
109+
```py
110+
from knot_extensions import AlwaysTruthy, AlwaysFalsy, static_assert, is_equivalent_to, Not
111+
from typing_extensions import Literal
112+
113+
static_assert(is_equivalent_to(AlwaysTruthy | bool, Literal[False] | AlwaysTruthy))
114+
static_assert(is_equivalent_to(AlwaysFalsy | bool, Literal[True] | AlwaysFalsy))
115+
static_assert(is_equivalent_to(Not[AlwaysTruthy] | bool, Not[AlwaysTruthy] | Literal[True]))
116+
static_assert(is_equivalent_to(Not[AlwaysFalsy] | bool, Literal[False] | Not[AlwaysFalsy]))
117+
```
118+
119+
## Unions and intersections involving `AlwaysTruthy`, `LiteralString` and `AlwaysFalsy`
120+
121+
```py
122+
from knot_extensions import AlwaysTruthy, AlwaysFalsy, static_assert, is_equivalent_to, Not, Intersection
123+
from typing_extensions import Literal, LiteralString
124+
125+
# TODO: these should all pass!
126+
127+
# error: [static-assert-error]
128+
static_assert(is_equivalent_to(AlwaysTruthy | LiteralString, Literal[""] | AlwaysTruthy))
129+
# error: [static-assert-error]
130+
static_assert(is_equivalent_to(AlwaysFalsy | LiteralString, Intersection[LiteralString, Not[Literal[""]]] | AlwaysFalsy))
131+
# error: [static-assert-error]
132+
static_assert(is_equivalent_to(Not[AlwaysFalsy] | LiteralString, Literal[""] | Not[AlwaysFalsy]))
133+
# error: [static-assert-error]
134+
static_assert(
135+
is_equivalent_to(Not[AlwaysTruthy] | LiteralString, Not[AlwaysTruthy] | Intersection[LiteralString, Not[Literal[""]]])
136+
)
137+
```
138+
87139
[the equivalence relation]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-equivalent

crates/red_knot_python_semantic/src/types.rs

Lines changed: 112 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,47 @@ impl<'db> Type<'db> {
811811
}
812812
}
813813

814+
#[must_use]
815+
pub fn with_normalized_bools(self, db: &'db dyn Db) -> Self {
816+
const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)];
817+
818+
match self {
819+
Type::Instance(InstanceType { class }) if class.is_known(db, KnownClass::Bool) => {
820+
Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS)))
821+
}
822+
// TODO: decompose `LiteralString` into `Literal[""] | TruthyLiteralString`?
823+
// We'd need to rename this method... --Alex
824+
_ => self,
825+
}
826+
}
827+
828+
#[must_use]
829+
pub fn with_sorted_unions(self, db: &'db dyn Db) -> Self {
830+
match self {
831+
Type::Union(union) => Type::Union(union.to_sorted_union(db)),
832+
Type::Intersection(intersection) => {
833+
Type::Intersection(intersection.to_sorted_intersection(db))
834+
}
835+
Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_unions(db)),
836+
Type::LiteralString
837+
| Type::Instance(_)
838+
| Type::AlwaysFalsy
839+
| Type::AlwaysTruthy
840+
| Type::BooleanLiteral(_)
841+
| Type::SliceLiteral(_)
842+
| Type::BytesLiteral(_)
843+
| Type::StringLiteral(_)
844+
| Type::Dynamic(_)
845+
| Type::Never
846+
| Type::FunctionLiteral(_)
847+
| Type::ModuleLiteral(_)
848+
| Type::ClassLiteral(_)
849+
| Type::KnownInstance(_)
850+
| Type::IntLiteral(_)
851+
| Type::SubclassOf(_) => self,
852+
}
853+
}
854+
814855
/// Return true if this type is a [subtype of] type `target`.
815856
///
816857
/// This method returns `false` if either `self` or `other` is not fully static.
@@ -830,6 +871,12 @@ impl<'db> Type<'db> {
830871
return true;
831872
}
832873

874+
let left = self.with_normalized_bools(db);
875+
let right = target.with_normalized_bools(db);
876+
if left != self || right != target {
877+
return left.is_subtype_of(db, right);
878+
}
879+
833880
// Non-fully-static types do not participate in subtyping.
834881
//
835882
// Type `A` can only be a subtype of type `B` if the set of possible runtime objects
@@ -932,7 +979,7 @@ impl<'db> Type<'db> {
932979
KnownClass::Str.to_instance(db).is_subtype_of(db, target)
933980
}
934981
(Type::BooleanLiteral(_), _) => {
935-
KnownClass::Bool.to_instance(db).is_subtype_of(db, target)
982+
KnownClass::Int.to_instance(db).is_subtype_of(db, target)
936983
}
937984
(Type::IntLiteral(_), _) => KnownClass::Int.to_instance(db).is_subtype_of(db, target),
938985
(Type::BytesLiteral(_), _) => {
@@ -1048,6 +1095,11 @@ impl<'db> Type<'db> {
10481095
if self.is_gradual_equivalent_to(db, target) {
10491096
return true;
10501097
}
1098+
let normalized_self = self.with_normalized_bools(db);
1099+
let normalized_target = target.with_normalized_bools(db);
1100+
if normalized_self != self || normalized_target != target {
1101+
return normalized_self.is_assignable_to(db, normalized_target);
1102+
}
10511103
match (self, target) {
10521104
// Never can be assigned to any type.
10531105
(Type::Never, _) => true,
@@ -1148,13 +1200,20 @@ impl<'db> Type<'db> {
11481200
pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
11491201
// TODO equivalent but not identical types: TypedDicts, Protocols, type aliases, etc.
11501202

1203+
let normalized_self = self.with_normalized_bools(db);
1204+
let normalized_other = other.with_normalized_bools(db);
1205+
1206+
if normalized_self != self || normalized_other != other {
1207+
return normalized_self.is_equivalent_to(db, normalized_other);
1208+
}
1209+
11511210
match (self, other) {
11521211
(Type::Union(left), Type::Union(right)) => left.is_equivalent_to(db, right),
11531212
(Type::Intersection(left), Type::Intersection(right)) => {
11541213
left.is_equivalent_to(db, right)
11551214
}
11561215
(Type::Tuple(left), Type::Tuple(right)) => left.is_equivalent_to(db, right),
1157-
_ => self.is_fully_static(db) && other.is_fully_static(db) && self == other,
1216+
_ => self == other && self.is_fully_static(db) && other.is_fully_static(db),
11581217
}
11591218
}
11601219

@@ -1189,6 +1248,13 @@ impl<'db> Type<'db> {
11891248
///
11901249
/// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations
11911250
pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
1251+
let normalized_self = self.with_normalized_bools(db);
1252+
let normalized_other = other.with_normalized_bools(db);
1253+
1254+
if normalized_self != self || normalized_other != other {
1255+
return normalized_self.is_gradual_equivalent_to(db, normalized_other);
1256+
}
1257+
11921258
if self == other {
11931259
return true;
11941260
}
@@ -1221,6 +1287,12 @@ impl<'db> Type<'db> {
12211287
/// Note: This function aims to have no false positives, but might return
12221288
/// wrong `false` answers in some cases.
12231289
pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool {
1290+
let normalized_self = self.with_normalized_bools(db);
1291+
let normalized_other = other.with_normalized_bools(db);
1292+
if normalized_self != self || normalized_other != other {
1293+
return normalized_self.is_disjoint_from(db, normalized_other);
1294+
}
1295+
12241296
match (self, other) {
12251297
(Type::Never, _) | (_, Type::Never) => true,
12261298

@@ -4352,12 +4424,11 @@ impl<'db> UnionType<'db> {
43524424
/// Create a new union type with the elements sorted according to a canonical ordering.
43534425
#[must_use]
43544426
pub fn to_sorted_union(self, db: &'db dyn Db) -> Self {
4355-
let mut new_elements = self.elements(db).to_vec();
4356-
for element in &mut new_elements {
4357-
if let Type::Intersection(intersection) = element {
4358-
intersection.sort(db);
4359-
}
4360-
}
4427+
let mut new_elements: Vec<Type<'db>> = self
4428+
.elements(db)
4429+
.iter()
4430+
.map(|element| element.with_sorted_unions(db))
4431+
.collect();
43614432
new_elements.sort_unstable_by(union_elements_ordering);
43624433
UnionType::new(db, new_elements.into_boxed_slice())
43634434
}
@@ -4453,19 +4524,24 @@ impl<'db> IntersectionType<'db> {
44534524
/// according to a canonical ordering.
44544525
#[must_use]
44554526
pub fn to_sorted_intersection(self, db: &'db dyn Db) -> Self {
4456-
let mut positive = self.positive(db).clone();
4457-
positive.sort_unstable_by(union_elements_ordering);
4458-
4459-
let mut negative = self.negative(db).clone();
4460-
negative.sort_unstable_by(union_elements_ordering);
4527+
fn normalized_set<'db>(
4528+
db: &'db dyn Db,
4529+
elements: &FxOrderSet<Type<'db>>,
4530+
) -> FxOrderSet<Type<'db>> {
4531+
let mut elements: FxOrderSet<Type<'db>> = elements
4532+
.iter()
4533+
.map(|ty| ty.with_sorted_unions(db))
4534+
.collect();
44614535

4462-
IntersectionType::new(db, positive, negative)
4463-
}
4536+
elements.sort_unstable_by(union_elements_ordering);
4537+
elements
4538+
}
44644539

4465-
/// Perform an in-place sort of this [`IntersectionType`] instance
4466-
/// according to a canonical ordering.
4467-
fn sort(&mut self, db: &'db dyn Db) {
4468-
*self = self.to_sorted_intersection(db);
4540+
IntersectionType::new(
4541+
db,
4542+
normalized_set(db, self.positive(db)),
4543+
normalized_set(db, self.negative(db)),
4544+
)
44694545
}
44704546

44714547
pub fn is_fully_static(self, db: &'db dyn Db) -> bool {
@@ -4591,23 +4667,34 @@ pub struct TupleType<'db> {
45914667
}
45924668

45934669
impl<'db> TupleType<'db> {
4594-
pub fn from_elements<T: Into<Type<'db>>>(
4595-
db: &'db dyn Db,
4596-
types: impl IntoIterator<Item = T>,
4597-
) -> Type<'db> {
4670+
pub fn from_elements<I, T>(db: &'db dyn Db, types: I) -> Type<'db>
4671+
where
4672+
I: IntoIterator<Item = T>,
4673+
T: Into<Type<'db>>,
4674+
{
45984675
let mut elements = vec![];
45994676

46004677
for ty in types {
4601-
let ty = ty.into();
4678+
let ty: Type<'db> = ty.into();
46024679
if ty.is_never() {
46034680
return Type::Never;
46044681
}
4605-
elements.push(ty);
4682+
elements.push(ty.with_normalized_bools(db));
46064683
}
46074684

46084685
Type::Tuple(Self::new(db, elements.into_boxed_slice()))
46094686
}
46104687

4688+
#[must_use]
4689+
pub fn with_sorted_unions(self, db: &'db dyn Db) -> Self {
4690+
let elements: Box<[Type<'db>]> = self
4691+
.elements(db)
4692+
.iter()
4693+
.map(|ty| ty.with_sorted_unions(db))
4694+
.collect();
4695+
TupleType::new(db, elements)
4696+
}
4697+
46114698
pub fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
46124699
let self_elements = self.elements(db);
46134700
let other_elements = other.elements(db);

0 commit comments

Comments
 (0)