@@ -762,6 +762,53 @@ where
762762 * strides = new_strides;
763763}
764764
765+ /// Remove axes with length one, except never removing the last axis.
766+ pub ( crate ) fn squeeze_into < D , E > ( dim : & D , strides : & D ) -> Result < ( E , E ) , ShapeError >
767+ where
768+ D : Dimension ,
769+ E : Dimension ,
770+ {
771+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
772+
773+ // Count axes with dim == 1; we keep axes with d == 0 or d > 1
774+ let mut ndim_new = 0 ;
775+ for & d in dim. slice ( ) {
776+ if d != 1 { ndim_new += 1 ; }
777+ }
778+ let mut fill_ones = 0 ;
779+ if let Some ( e_ndim) = E :: NDIM {
780+ if e_ndim < ndim_new {
781+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
782+ }
783+ fill_ones = e_ndim - ndim_new;
784+ ndim_new = e_ndim;
785+ } else {
786+ // dynamic-dimensional
787+ // use minimum one dimension unless input has less than one dim
788+ if dim. ndim ( ) > 0 && ndim_new == 0 {
789+ ndim_new = 1 ;
790+ fill_ones = 1 ;
791+ }
792+ }
793+
794+ let mut new_dim = E :: zeros ( ndim_new) ;
795+ let mut new_strides = E :: zeros ( ndim_new) ;
796+ let mut i = 0 ;
797+ while i < fill_ones {
798+ new_dim[ i] = 1 ;
799+ new_strides[ i] = 1 ;
800+ i += 1 ;
801+ }
802+ for ( & d, & s) in izip ! ( dim. slice( ) , strides. slice( ) ) {
803+ if d != 1 {
804+ new_dim[ i] = d;
805+ new_strides[ i] = s;
806+ i += 1 ;
807+ }
808+ }
809+ Ok ( ( new_dim, new_strides) )
810+ }
811+
765812
766813/// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
767814/// stride
@@ -1148,6 +1195,91 @@ mod test {
11481195 assert_eq ! ( s, sans) ;
11491196 }
11501197
1198+ #[ test]
1199+ #[ cfg( feature = "std" ) ]
1200+ fn test_squeeze_into ( ) {
1201+ use super :: squeeze_into;
1202+
1203+ let dyndim = Dim :: < & [ usize ] > ;
1204+
1205+ // squeeze to ixdyn
1206+ let d = dyndim ( & [ 1 , 2 , 1 , 1 , 3 , 1 ] ) ;
1207+ let s = dyndim ( & [ !0 , !0 , !0 , 9 , 10 , !0 ] ) ;
1208+ let dans = dyndim ( & [ 2 , 3 ] ) ;
1209+ let sans = dyndim ( & [ !0 , 10 ] ) ;
1210+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1211+ assert_eq ! ( d2, dans) ;
1212+ assert_eq ! ( s2, sans) ;
1213+
1214+ // squeeze to ixdyn does not go below 1D
1215+ let d = dyndim ( & [ 1 , 1 ] ) ;
1216+ let s = dyndim ( & [ 3 , 4 ] ) ;
1217+ let dans = dyndim ( & [ 1 ] ) ;
1218+ let sans = dyndim ( & [ 1 ] ) ;
1219+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1220+ assert_eq ! ( d2, dans) ;
1221+ assert_eq ! ( s2, sans) ;
1222+
1223+ let d = Dim ( [ 1 , 1 ] ) ;
1224+ let s = Dim ( [ 3 , 4 ] ) ;
1225+ let dans = Dim ( [ 1 ] ) ;
1226+ let sans = Dim ( [ 1 ] ) ;
1227+ let ( d2, s2) = squeeze_into :: < _ , Ix1 > ( & d, & s) . unwrap ( ) ;
1228+ assert_eq ! ( d2, dans) ;
1229+ assert_eq ! ( s2, sans) ;
1230+
1231+ // squeeze to zero-dim
1232+ let ( d2, s2) = squeeze_into :: < _ , Ix0 > ( & d, & s) . unwrap ( ) ;
1233+ assert_eq ! ( d2, Ix0 ( ) ) ;
1234+ assert_eq ! ( s2, Ix0 ( ) ) ;
1235+
1236+ let d = Dim ( [ 0 , 1 , 3 , 4 ] ) ;
1237+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1238+ let dans = Dim ( [ 0 , 3 , 4 ] ) ;
1239+ let sans = Dim ( [ 2 , 4 , 5 ] ) ;
1240+ let ( d2, s2) = squeeze_into :: < _ , Ix3 > ( & d, & s) . unwrap ( ) ;
1241+ assert_eq ! ( d2, dans) ;
1242+ assert_eq ! ( s2, sans) ;
1243+
1244+ // Pad with ones
1245+ let d = Dim ( [ 0 , 1 , 3 , 1 ] ) ;
1246+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1247+ let dans = Dim ( [ 1 , 0 , 3 ] ) ;
1248+ let sans = Dim ( [ 1 , 2 , 4 ] ) ;
1249+ let ( d2, s2) = squeeze_into :: < _ , Ix3 > ( & d, & s) . unwrap ( ) ;
1250+ assert_eq ! ( d2, dans) ;
1251+ assert_eq ! ( s2, sans) ;
1252+
1253+ // Try something that doesn't fit
1254+ let d = Dim ( [ 0 , 1 , 3 , 1 ] ) ;
1255+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1256+ let res = squeeze_into :: < _ , Ix1 > ( & d, & s) ;
1257+ assert ! ( res. is_err( ) ) ;
1258+ let res = squeeze_into :: < _ , Ix0 > ( & d, & s) ;
1259+ assert ! ( res. is_err( ) ) ;
1260+
1261+ // Squeeze 0d to 0d
1262+ let d = Dim ( [ ] ) ;
1263+ let s = Dim ( [ ] ) ;
1264+ let res = squeeze_into :: < _ , Ix0 > ( & d, & s) ;
1265+ assert ! ( res. is_ok( ) ) ;
1266+ // grow 0d to 2d
1267+ let dans = Dim ( [ 1 , 1 ] ) ;
1268+ let sans = Dim ( [ 1 , 1 ] ) ;
1269+ let ( d2, s2) = squeeze_into :: < _ , Ix2 > ( & d, & s) . unwrap ( ) ;
1270+ assert_eq ! ( d2, dans) ;
1271+ assert_eq ! ( s2, sans) ;
1272+
1273+ // Squeeze 0d to 0d dynamic
1274+ let d = dyndim ( & [ ] ) ;
1275+ let s = dyndim ( & [ ] ) ;
1276+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1277+ let dans = d;
1278+ let sans = s;
1279+ assert_eq ! ( d2, dans) ;
1280+ assert_eq ! ( s2, sans) ;
1281+ }
1282+
11511283 #[ test]
11521284 fn test_merge_axes_from_the_back ( ) {
11531285 let dyndim = Dim :: < & [ usize ] > ;
0 commit comments