@@ -729,37 +729,77 @@ where
729729}
730730
731731/// Remove axes with length one, except never removing the last axis.
732+ ///
733+ /// This function is a no-op for const dim.
732734pub ( crate ) fn squeeze < D > ( dim : & mut D , strides : & mut D )
733735where
734736 D : Dimension ,
735737{
736738 if let Some ( _) = D :: NDIM {
737739 return ;
738740 }
741+
742+ // infallible for dyn dim
743+ let ( d, s) = squeeze_into ( dim, strides) . unwrap ( ) ;
744+ * dim = d;
745+ * strides = s;
746+ }
747+
748+ /// Remove axes with length one, except never removing the last axis.
749+ ///
750+ /// Return an error if there are more non-unitary dimensions than can be stored
751+ /// in `E`. Infallible for dyn dim.
752+ ///
753+ /// Squeeze does not shrink dyn dim down to smaller than 1D, but if the input is
754+ /// dynamic 0D, the output can be too.
755+ ///
756+ /// For const dim, this may instead pad the dimensionality with ones if it needs
757+ /// to grow to fill the target dimensionality; the dimension is padded in the
758+ /// start.
759+ pub ( crate ) fn squeeze_into < D , E > ( dim : & D , strides : & D ) -> Result < ( E , E ) , ShapeError >
760+ where
761+ D : Dimension ,
762+ E : Dimension ,
763+ {
739764 debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
740765
741766 // Count axes with dim == 1; we keep axes with d == 0 or d > 1
742767 let mut ndim_new = 0 ;
743768 for & d in dim. slice ( ) {
744769 if d != 1 { ndim_new += 1 ; }
745770 }
746- ndim_new = Ord :: max ( 1 , ndim_new) ;
747- let mut new_dim = D :: zeros ( ndim_new) ;
748- let mut new_strides = D :: zeros ( ndim_new) ;
771+ let mut fill_ones = 0 ;
772+ if let Some ( e_ndim) = E :: NDIM {
773+ if e_ndim < ndim_new {
774+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
775+ }
776+ fill_ones = e_ndim - ndim_new;
777+ ndim_new = e_ndim;
778+ } else {
779+ // dynamic-dimensional
780+ // use minimum one dimension unless input has less than one dim
781+ if dim. ndim ( ) > 0 && ndim_new == 0 {
782+ ndim_new = 1 ;
783+ fill_ones = 1 ;
784+ }
785+ }
786+
787+ let mut new_dim = E :: zeros ( ndim_new) ;
788+ let mut new_strides = E :: zeros ( ndim_new) ;
749789 let mut i = 0 ;
790+ while i < fill_ones {
791+ new_dim[ i] = 1 ;
792+ new_strides[ i] = 1 ;
793+ i += 1 ;
794+ }
750795 for ( & d, & s) in izip ! ( dim. slice( ) , strides. slice( ) ) {
751796 if d != 1 {
752797 new_dim[ i] = d;
753798 new_strides[ i] = s;
754799 i += 1 ;
755800 }
756801 }
757- if i == 0 {
758- new_dim[ i] = 1 ;
759- new_strides[ i] = 1 ;
760- }
761- * dim = new_dim;
762- * strides = new_strides;
802+ Ok ( ( new_dim, new_strides) )
763803}
764804
765805
@@ -1148,6 +1188,91 @@ mod test {
11481188 assert_eq ! ( s, sans) ;
11491189 }
11501190
1191+ #[ test]
1192+ #[ cfg( feature = "std" ) ]
1193+ fn test_squeeze_into ( ) {
1194+ use super :: squeeze_into;
1195+
1196+ let dyndim = Dim :: < & [ usize ] > ;
1197+
1198+ // squeeze to ixdyn
1199+ let d = dyndim ( & [ 1 , 2 , 1 , 1 , 3 , 1 ] ) ;
1200+ let s = dyndim ( & [ !0 , !0 , !0 , 9 , 10 , !0 ] ) ;
1201+ let dans = dyndim ( & [ 2 , 3 ] ) ;
1202+ let sans = dyndim ( & [ !0 , 10 ] ) ;
1203+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1204+ assert_eq ! ( d2, dans) ;
1205+ assert_eq ! ( s2, sans) ;
1206+
1207+ // squeeze to ixdyn does not go below 1D
1208+ let d = dyndim ( & [ 1 , 1 ] ) ;
1209+ let s = dyndim ( & [ 3 , 4 ] ) ;
1210+ let dans = dyndim ( & [ 1 ] ) ;
1211+ let sans = dyndim ( & [ 1 ] ) ;
1212+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1213+ assert_eq ! ( d2, dans) ;
1214+ assert_eq ! ( s2, sans) ;
1215+
1216+ let d = Dim ( [ 1 , 1 ] ) ;
1217+ let s = Dim ( [ 3 , 4 ] ) ;
1218+ let dans = Dim ( [ 1 ] ) ;
1219+ let sans = Dim ( [ 1 ] ) ;
1220+ let ( d2, s2) = squeeze_into :: < _ , Ix1 > ( & d, & s) . unwrap ( ) ;
1221+ assert_eq ! ( d2, dans) ;
1222+ assert_eq ! ( s2, sans) ;
1223+
1224+ // squeeze to zero-dim
1225+ let ( d2, s2) = squeeze_into :: < _ , Ix0 > ( & d, & s) . unwrap ( ) ;
1226+ assert_eq ! ( d2, Ix0 ( ) ) ;
1227+ assert_eq ! ( s2, Ix0 ( ) ) ;
1228+
1229+ let d = Dim ( [ 0 , 1 , 3 , 4 ] ) ;
1230+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1231+ let dans = Dim ( [ 0 , 3 , 4 ] ) ;
1232+ let sans = Dim ( [ 2 , 4 , 5 ] ) ;
1233+ let ( d2, s2) = squeeze_into :: < _ , Ix3 > ( & d, & s) . unwrap ( ) ;
1234+ assert_eq ! ( d2, dans) ;
1235+ assert_eq ! ( s2, sans) ;
1236+
1237+ // Pad with ones
1238+ let d = Dim ( [ 0 , 1 , 3 , 1 ] ) ;
1239+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1240+ let dans = Dim ( [ 1 , 0 , 3 ] ) ;
1241+ let sans = Dim ( [ 1 , 2 , 4 ] ) ;
1242+ let ( d2, s2) = squeeze_into :: < _ , Ix3 > ( & d, & s) . unwrap ( ) ;
1243+ assert_eq ! ( d2, dans) ;
1244+ assert_eq ! ( s2, sans) ;
1245+
1246+ // Try something that doesn't fit
1247+ let d = Dim ( [ 0 , 1 , 3 , 1 ] ) ;
1248+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1249+ let res = squeeze_into :: < _ , Ix1 > ( & d, & s) ;
1250+ assert ! ( res. is_err( ) ) ;
1251+ let res = squeeze_into :: < _ , Ix0 > ( & d, & s) ;
1252+ assert ! ( res. is_err( ) ) ;
1253+
1254+ // Squeeze 0d to 0d
1255+ let d = Dim ( [ ] ) ;
1256+ let s = Dim ( [ ] ) ;
1257+ let res = squeeze_into :: < _ , Ix0 > ( & d, & s) ;
1258+ assert ! ( res. is_ok( ) ) ;
1259+ // grow 0d to 2d
1260+ let dans = Dim ( [ 1 , 1 ] ) ;
1261+ let sans = Dim ( [ 1 , 1 ] ) ;
1262+ let ( d2, s2) = squeeze_into :: < _ , Ix2 > ( & d, & s) . unwrap ( ) ;
1263+ assert_eq ! ( d2, dans) ;
1264+ assert_eq ! ( s2, sans) ;
1265+
1266+ // Squeeze 0d to 0d dynamic
1267+ let d = dyndim ( & [ ] ) ;
1268+ let s = dyndim ( & [ ] ) ;
1269+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1270+ let dans = d;
1271+ let sans = s;
1272+ assert_eq ! ( d2, dans) ;
1273+ assert_eq ! ( s2, sans) ;
1274+ }
1275+
11511276 #[ test]
11521277 fn test_merge_axes_from_the_back ( ) {
11531278 let dyndim = Dim :: < & [ usize ] > ;
0 commit comments