@@ -785,37 +785,77 @@ where
785785}
786786
787787/// Remove axes with length one, except never removing the last axis.
788+ ///
789+ /// This function is a no-op for const dim.
788790pub ( crate ) fn squeeze < D > ( dim : & mut D , strides : & mut D )
789791where
790792 D : Dimension ,
791793{
792794 if let Some ( _) = D :: NDIM {
793795 return ;
794796 }
797+
798+ // infallible for dyn dim
799+ let ( d, s) = squeeze_into ( dim, strides) . unwrap ( ) ;
800+ * dim = d;
801+ * strides = s;
802+ }
803+
804+ /// Remove axes with length one, except never removing the last axis.
805+ ///
806+ /// Return an error if there are more non-unitary dimensions than can be stored
807+ /// in `E`. Infallible for dyn dim.
808+ ///
809+ /// Squeeze does not shrink dyn dim down to smaller than 1D, but if the input is
810+ /// dynamic 0D, the output can be too.
811+ ///
812+ /// For const dim, this may instead pad the dimensionality with ones if it needs
813+ /// to grow to fill the target dimensionality; the dimension is padded in the
814+ /// start.
815+ pub ( crate ) fn squeeze_into < D , E > ( dim : & D , strides : & D ) -> Result < ( E , E ) , ShapeError >
816+ where
817+ D : Dimension ,
818+ E : Dimension ,
819+ {
795820 debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
796821
797822 // Count axes with dim == 1; we keep axes with d == 0 or d > 1
798823 let mut ndim_new = 0 ;
799824 for & d in dim. slice ( ) {
800825 if d != 1 { ndim_new += 1 ; }
801826 }
802- ndim_new = Ord :: max ( 1 , ndim_new) ;
803- let mut new_dim = D :: zeros ( ndim_new) ;
804- let mut new_strides = D :: zeros ( ndim_new) ;
827+ let mut fill_ones = 0 ;
828+ if let Some ( e_ndim) = E :: NDIM {
829+ if e_ndim < ndim_new {
830+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
831+ }
832+ fill_ones = e_ndim - ndim_new;
833+ ndim_new = e_ndim;
834+ } else {
835+ // dynamic-dimensional
836+ // use minimum one dimension unless input has less than one dim
837+ if dim. ndim ( ) > 0 && ndim_new == 0 {
838+ ndim_new = 1 ;
839+ fill_ones = 1 ;
840+ }
841+ }
842+
843+ let mut new_dim = E :: zeros ( ndim_new) ;
844+ let mut new_strides = E :: zeros ( ndim_new) ;
805845 let mut i = 0 ;
846+ while i < fill_ones {
847+ new_dim[ i] = 1 ;
848+ new_strides[ i] = 1 ;
849+ i += 1 ;
850+ }
806851 for ( & d, & s) in izip ! ( dim. slice( ) , strides. slice( ) ) {
807852 if d != 1 {
808853 new_dim[ i] = d;
809854 new_strides[ i] = s;
810855 i += 1 ;
811856 }
812857 }
813- if i == 0 {
814- new_dim[ i] = 1 ;
815- new_strides[ i] = 1 ;
816- }
817- * dim = new_dim;
818- * strides = new_strides;
858+ Ok ( ( new_dim, new_strides) )
819859}
820860
821861
@@ -1220,6 +1260,91 @@ mod test {
12201260 assert_eq ! ( s, sans) ;
12211261 }
12221262
1263+ #[ test]
1264+ #[ cfg( feature = "std" ) ]
1265+ fn test_squeeze_into ( ) {
1266+ use super :: squeeze_into;
1267+
1268+ let dyndim = Dim :: < & [ usize ] > ;
1269+
1270+ // squeeze to ixdyn
1271+ let d = dyndim ( & [ 1 , 2 , 1 , 1 , 3 , 1 ] ) ;
1272+ let s = dyndim ( & [ !0 , !0 , !0 , 9 , 10 , !0 ] ) ;
1273+ let dans = dyndim ( & [ 2 , 3 ] ) ;
1274+ let sans = dyndim ( & [ !0 , 10 ] ) ;
1275+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1276+ assert_eq ! ( d2, dans) ;
1277+ assert_eq ! ( s2, sans) ;
1278+
1279+ // squeeze to ixdyn does not go below 1D
1280+ let d = dyndim ( & [ 1 , 1 ] ) ;
1281+ let s = dyndim ( & [ 3 , 4 ] ) ;
1282+ let dans = dyndim ( & [ 1 ] ) ;
1283+ let sans = dyndim ( & [ 1 ] ) ;
1284+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1285+ assert_eq ! ( d2, dans) ;
1286+ assert_eq ! ( s2, sans) ;
1287+
1288+ let d = Dim ( [ 1 , 1 ] ) ;
1289+ let s = Dim ( [ 3 , 4 ] ) ;
1290+ let dans = Dim ( [ 1 ] ) ;
1291+ let sans = Dim ( [ 1 ] ) ;
1292+ let ( d2, s2) = squeeze_into :: < _ , Ix1 > ( & d, & s) . unwrap ( ) ;
1293+ assert_eq ! ( d2, dans) ;
1294+ assert_eq ! ( s2, sans) ;
1295+
1296+ // squeeze to zero-dim
1297+ let ( d2, s2) = squeeze_into :: < _ , Ix0 > ( & d, & s) . unwrap ( ) ;
1298+ assert_eq ! ( d2, Ix0 ( ) ) ;
1299+ assert_eq ! ( s2, Ix0 ( ) ) ;
1300+
1301+ let d = Dim ( [ 0 , 1 , 3 , 4 ] ) ;
1302+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1303+ let dans = Dim ( [ 0 , 3 , 4 ] ) ;
1304+ let sans = Dim ( [ 2 , 4 , 5 ] ) ;
1305+ let ( d2, s2) = squeeze_into :: < _ , Ix3 > ( & d, & s) . unwrap ( ) ;
1306+ assert_eq ! ( d2, dans) ;
1307+ assert_eq ! ( s2, sans) ;
1308+
1309+ // Pad with ones
1310+ let d = Dim ( [ 0 , 1 , 3 , 1 ] ) ;
1311+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1312+ let dans = Dim ( [ 1 , 0 , 3 ] ) ;
1313+ let sans = Dim ( [ 1 , 2 , 4 ] ) ;
1314+ let ( d2, s2) = squeeze_into :: < _ , Ix3 > ( & d, & s) . unwrap ( ) ;
1315+ assert_eq ! ( d2, dans) ;
1316+ assert_eq ! ( s2, sans) ;
1317+
1318+ // Try something that doesn't fit
1319+ let d = Dim ( [ 0 , 1 , 3 , 1 ] ) ;
1320+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1321+ let res = squeeze_into :: < _ , Ix1 > ( & d, & s) ;
1322+ assert ! ( res. is_err( ) ) ;
1323+ let res = squeeze_into :: < _ , Ix0 > ( & d, & s) ;
1324+ assert ! ( res. is_err( ) ) ;
1325+
1326+ // Squeeze 0d to 0d
1327+ let d = Dim ( [ ] ) ;
1328+ let s = Dim ( [ ] ) ;
1329+ let res = squeeze_into :: < _ , Ix0 > ( & d, & s) ;
1330+ assert ! ( res. is_ok( ) ) ;
1331+ // grow 0d to 2d
1332+ let dans = Dim ( [ 1 , 1 ] ) ;
1333+ let sans = Dim ( [ 1 , 1 ] ) ;
1334+ let ( d2, s2) = squeeze_into :: < _ , Ix2 > ( & d, & s) . unwrap ( ) ;
1335+ assert_eq ! ( d2, dans) ;
1336+ assert_eq ! ( s2, sans) ;
1337+
1338+ // Squeeze 0d to 0d dynamic
1339+ let d = dyndim ( & [ ] ) ;
1340+ let s = dyndim ( & [ ] ) ;
1341+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1342+ let dans = d;
1343+ let sans = s;
1344+ assert_eq ! ( d2, dans) ;
1345+ assert_eq ! ( s2, sans) ;
1346+ }
1347+
12231348 #[ test]
12241349 fn test_merge_axes_from_the_back ( ) {
12251350 let dyndim = Dim :: < & [ usize ] > ;
0 commit comments