@@ -4,28 +4,60 @@ use assert_fs::NamedTempFile;
44use polars:: prelude:: * ;
55use std:: { fs:: File , path:: PathBuf , process:: Command , result:: Result } ;
66
7+ enum OverwriteOption {
8+ Overwrite ( NamedTempFile ) ,
9+ DoNotOverwrite ,
10+ }
11+
712fn cli_data_to_parquet (
813 base_file_name : & str ,
14+ overwrite : OverwriteOption ,
915 rows_to_stream : Option < u32 > ,
1016) -> Result < ( Command , NamedTempFile ) , Box < dyn std:: error:: Error > > {
11- let tempfile = NamedTempFile :: new ( format ! ( "{}.parquet" , base_file_name) ) ?;
12-
1317 let mut cmd = Command :: cargo_bin ( "readstat" ) ?;
1418
15- if let Some ( rows) = rows_to_stream {
16- cmd. arg ( "data" )
17- . arg ( format ! ( "tests/data/{}.sas7bdat" , base_file_name) )
18- . args ( [ "--format" , "parquet" ] )
19- . args ( [ "--output" , tempfile. as_os_str ( ) . to_str ( ) . unwrap ( ) ] )
20- . args ( [ "--stream-rows" , rows. to_string ( ) . as_str ( ) ] )
21- . arg ( "--overwrite" ) ;
22- } else {
23- cmd. arg ( "data" )
24- . arg ( format ! ( "tests/data/{}.sas7bdat" , base_file_name) )
25- . args ( [ "--format" , "parquet" ] )
26- . args ( [ "--output" , tempfile. as_os_str ( ) . to_str ( ) . unwrap ( ) ] )
27- . arg ( "--overwrite" ) ;
28- }
19+ let tempfile = match ( overwrite, rows_to_stream) {
20+ ( OverwriteOption :: Overwrite ( tempfile) , Some ( rows) ) => {
21+ cmd. arg ( "data" )
22+ . arg ( format ! ( "tests/data/{}.sas7bdat" , base_file_name) )
23+ . args ( [ "--format" , "parquet" ] )
24+ . args ( [ "--output" , tempfile. as_os_str ( ) . to_str ( ) . unwrap ( ) ] )
25+ . args ( [ "--stream-rows" , rows. to_string ( ) . as_str ( ) ] )
26+ . arg ( "--overwrite" ) ;
27+
28+ tempfile
29+ }
30+ ( OverwriteOption :: DoNotOverwrite , Some ( rows) ) => {
31+ let tempfile = NamedTempFile :: new ( format ! ( "{}.parquet" , base_file_name) ) ?;
32+
33+ cmd. arg ( "data" )
34+ . arg ( format ! ( "tests/data/{}.sas7bdat" , base_file_name) )
35+ . args ( [ "--format" , "parquet" ] )
36+ . args ( [ "--output" , tempfile. as_os_str ( ) . to_str ( ) . unwrap ( ) ] )
37+ . args ( [ "--stream-rows" , rows. to_string ( ) . as_str ( ) ] ) ;
38+
39+ tempfile
40+ }
41+ ( OverwriteOption :: Overwrite ( tempfile) , None ) => {
42+ cmd. arg ( "data" )
43+ . arg ( format ! ( "tests/data/{}.sas7bdat" , base_file_name) )
44+ . args ( [ "--format" , "parquet" ] )
45+ . args ( [ "--output" , tempfile. as_os_str ( ) . to_str ( ) . unwrap ( ) ] )
46+ . arg ( "--overwrite" ) ;
47+
48+ tempfile
49+ }
50+ ( OverwriteOption :: DoNotOverwrite , None ) => {
51+ let tempfile = NamedTempFile :: new ( format ! ( "{}.parquet" , base_file_name) ) ?;
52+
53+ cmd. arg ( "data" )
54+ . arg ( format ! ( "tests/data/{}.sas7bdat" , base_file_name) )
55+ . args ( [ "--format" , "parquet" ] )
56+ . args ( [ "--output" , tempfile. as_os_str ( ) . to_str ( ) . unwrap ( ) ] ) ;
57+
58+ tempfile
59+ }
60+ } ;
2961
3062 Ok ( ( cmd, tempfile) )
3163}
@@ -40,7 +72,8 @@ fn parquet_to_df(path: PathBuf) -> Result<DataFrame, Box<dyn std::error::Error>>
4072
4173#[ test]
4274fn cars_to_parquet ( ) {
43- let ( mut cmd, tempfile) = cli_data_to_parquet ( "cars" , None ) . unwrap ( ) ;
75+ let ( mut cmd, tempfile) =
76+ cli_data_to_parquet ( "cars" , OverwriteOption :: DoNotOverwrite , None ) . unwrap ( ) ;
4477
4578 cmd. assert ( ) . success ( ) . stdout ( predicate:: str:: contains (
4679 "In total, wrote 1,081 rows from file cars.sas7bdat into cars.parquet" ,
@@ -58,7 +91,36 @@ fn cars_to_parquet() {
5891
5992#[ test]
6093fn cars_to_parquet_with_streaming ( ) {
61- let ( mut cmd, tempfile) = cli_data_to_parquet ( "cars" , Some ( 500 ) ) . unwrap ( ) ;
94+ let ( mut cmd, tempfile) =
95+ cli_data_to_parquet ( "cars" , OverwriteOption :: DoNotOverwrite , Some ( 500 ) ) . unwrap ( ) ;
96+
97+ cmd. assert ( ) . success ( ) . stdout ( predicate:: str:: contains (
98+ "In total, wrote 1,081 rows from file cars.sas7bdat into cars.parquet" ,
99+ ) ) ;
100+
101+ let df = parquet_to_df ( tempfile. to_path_buf ( ) ) . unwrap ( ) ;
102+
103+ let ( height, width) = df. shape ( ) ;
104+
105+ assert_eq ! ( height, 1081 ) ;
106+ assert_eq ! ( width, 13 ) ;
107+
108+ tempfile. close ( ) . unwrap ( ) ;
109+ }
110+
111+ #[ test]
112+ fn cars_to_parquet_overwrite ( ) {
113+ // first stream
114+ let ( mut cmd, tempfile) =
115+ cli_data_to_parquet ( "cars" , OverwriteOption :: DoNotOverwrite , Some ( 500 ) ) . unwrap ( ) ;
116+
117+ cmd. assert ( ) . success ( ) . stdout ( predicate:: str:: contains (
118+ "In total, wrote 1,081 rows from file cars.sas7bdat into cars.parquet" ,
119+ ) ) ;
120+
121+ // next do not stream
122+ let ( mut cmd, tempfile) =
123+ cli_data_to_parquet ( "cars" , OverwriteOption :: Overwrite ( tempfile) , None ) . unwrap ( ) ;
62124
63125 cmd. assert ( ) . success ( ) . stdout ( predicate:: str:: contains (
64126 "In total, wrote 1,081 rows from file cars.sas7bdat into cars.parquet" ,
0 commit comments