@@ -811,6 +811,31 @@ impl<'db> Type<'db> {
811811 }
812812 }
813813
814+ /// Normalize the type `bool` -> `Literal[True, False]`.
815+ ///
816+ /// Using this method in various type-relational methods
817+ /// ensures that the following invariants hold true:
818+ ///
819+ /// - bool ≡ Literal[True, False]
820+ /// - bool | T ≡ Literal[True, False] | T
821+ /// - bool <: Literal[True, False]
822+ /// - bool | T <: Literal[True, False] | T
823+ /// - Literal[True, False] <: bool
824+ /// - Literal[True, False] | T <: bool | T
825+ #[ must_use]
826+ pub fn with_normalized_bools ( self , db : & ' db dyn Db ) -> Self {
827+ const LITERAL_BOOLS : [ Type ; 2 ] = [ Type :: BooleanLiteral ( false ) , Type :: BooleanLiteral ( true ) ] ;
828+
829+ match self {
830+ Type :: Instance ( InstanceType { class } ) if class. is_known ( db, KnownClass :: Bool ) => {
831+ Type :: Union ( UnionType :: new ( db, Box :: from ( LITERAL_BOOLS ) ) )
832+ }
833+ // TODO: decompose `LiteralString` into `Literal[""] | TruthyLiteralString`?
834+ // We'd need to rename this method... --Alex
835+ _ => self ,
836+ }
837+ }
838+
814839 /// Return a normalized version of `self` in which all unions and intersections are sorted
815840 /// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
816841 #[ must_use]
@@ -859,6 +884,12 @@ impl<'db> Type<'db> {
859884 return true ;
860885 }
861886
887+ let normalized_self = self . with_normalized_bools ( db) ;
888+ let normalized_target = target. with_normalized_bools ( db) ;
889+ if normalized_self != self || normalized_target != target {
890+ return normalized_self. is_subtype_of ( db, normalized_target) ;
891+ }
892+
862893 // Non-fully-static types do not participate in subtyping.
863894 //
864895 // Type `A` can only be a subtype of type `B` if the set of possible runtime objects
@@ -961,7 +992,7 @@ impl<'db> Type<'db> {
961992 KnownClass :: Str . to_instance ( db) . is_subtype_of ( db, target)
962993 }
963994 ( Type :: BooleanLiteral ( _) , _) => {
964- KnownClass :: Bool . to_instance ( db) . is_subtype_of ( db, target)
995+ KnownClass :: Int . to_instance ( db) . is_subtype_of ( db, target)
965996 }
966997 ( Type :: IntLiteral ( _) , _) => KnownClass :: Int . to_instance ( db) . is_subtype_of ( db, target) ,
967998 ( Type :: BytesLiteral ( _) , _) => {
@@ -1077,6 +1108,11 @@ impl<'db> Type<'db> {
10771108 if self . is_gradual_equivalent_to ( db, target) {
10781109 return true ;
10791110 }
1111+ let normalized_self = self . with_normalized_bools ( db) ;
1112+ let normalized_target = target. with_normalized_bools ( db) ;
1113+ if normalized_self != self || normalized_target != target {
1114+ return normalized_self. is_assignable_to ( db, normalized_target) ;
1115+ }
10801116 match ( self , target) {
10811117 // Never can be assigned to any type.
10821118 ( Type :: Never , _) => true ,
@@ -1177,6 +1213,13 @@ impl<'db> Type<'db> {
11771213 pub ( crate ) fn is_equivalent_to ( self , db : & ' db dyn Db , other : Type < ' db > ) -> bool {
11781214 // TODO equivalent but not identical types: TypedDicts, Protocols, type aliases, etc.
11791215
1216+ let normalized_self = self . with_normalized_bools ( db) ;
1217+ let normalized_other = other. with_normalized_bools ( db) ;
1218+
1219+ if normalized_self != self || normalized_other != other {
1220+ return normalized_self. is_equivalent_to ( db, normalized_other) ;
1221+ }
1222+
11801223 match ( self , other) {
11811224 ( Type :: Union ( left) , Type :: Union ( right) ) => left. is_equivalent_to ( db, right) ,
11821225 ( Type :: Intersection ( left) , Type :: Intersection ( right) ) => {
@@ -1218,6 +1261,13 @@ impl<'db> Type<'db> {
12181261 ///
12191262 /// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations
12201263 pub ( crate ) fn is_gradual_equivalent_to ( self , db : & ' db dyn Db , other : Type < ' db > ) -> bool {
1264+ let normalized_self = self . with_normalized_bools ( db) ;
1265+ let normalized_other = other. with_normalized_bools ( db) ;
1266+
1267+ if normalized_self != self || normalized_other != other {
1268+ return normalized_self. is_gradual_equivalent_to ( db, normalized_other) ;
1269+ }
1270+
12211271 if self == other {
12221272 return true ;
12231273 }
@@ -1250,6 +1300,12 @@ impl<'db> Type<'db> {
12501300 /// Note: This function aims to have no false positives, but might return
12511301 /// wrong `false` answers in some cases.
12521302 pub ( crate ) fn is_disjoint_from ( self , db : & ' db dyn Db , other : Type < ' db > ) -> bool {
1303+ let normalized_self = self . with_normalized_bools ( db) ;
1304+ let normalized_other = other. with_normalized_bools ( db) ;
1305+ if normalized_self != self || normalized_other != other {
1306+ return normalized_self. is_disjoint_from ( db, normalized_other) ;
1307+ }
1308+
12531309 match ( self , other) {
12541310 ( Type :: Never , _) | ( _, Type :: Never ) => true ,
12551311
@@ -4642,18 +4698,19 @@ pub struct TupleType<'db> {
46424698}
46434699
46444700impl < ' db > TupleType < ' db > {
4645- pub fn from_elements < T : Into < Type < ' db > > > (
4646- db : & ' db dyn Db ,
4647- types : impl IntoIterator < Item = T > ,
4648- ) -> Type < ' db > {
4701+ pub fn from_elements < I , T > ( db : & ' db dyn Db , types : I ) -> Type < ' db >
4702+ where
4703+ I : IntoIterator < Item = T > ,
4704+ T : Into < Type < ' db > > ,
4705+ {
46494706 let mut elements = vec ! [ ] ;
46504707
46514708 for ty in types {
4652- let ty = ty. into ( ) ;
4709+ let ty: Type < ' db > = ty. into ( ) ;
46534710 if ty. is_never ( ) {
46544711 return Type :: Never ;
46554712 }
4656- elements. push ( ty) ;
4713+ elements. push ( ty. with_normalized_bools ( db ) ) ;
46574714 }
46584715
46594716 Type :: Tuple ( Self :: new ( db, elements. into_boxed_slice ( ) ) )
0 commit comments