Skip to content

Commit 4ce8fb1

Browse files
Take care when overwriting - need to truncate file if it already exists. Otherwise writing will start at the beginning and overwrite and potentially leave extra data at the end of the file.
1 parent e16f481 commit 4ce8fb1

File tree

2 files changed

+120
-26
lines changed

2 files changed

+120
-26
lines changed

readstat/src/rs_write.rs

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -230,11 +230,19 @@ impl ReadStatWriter {
230230
fn write_data_to_csv(&mut self, d: &ReadStatData, rsp: &ReadStatPath) -> Result<(), Box<dyn Error + Send + Sync>> {
231231
if let Some(p) = &rsp.out_path {
232232
// if already started writing, then need to append to file; otherwise create file
233-
let f = OpenOptions::new()
233+
let f = if self.wrote_start {
234+
OpenOptions::new()
234235
.write(true)
235236
.create(true)
236237
.append(true)
237-
.open(p)?;
238+
.open(p)?
239+
} else {
240+
OpenOptions::new()
241+
.write(true)
242+
.create(true)
243+
.truncate(true)
244+
.open(p)?
245+
};
238246

239247
// set message for what is being read/written
240248
self.write_message_for_rows(d, rsp)?;
@@ -274,11 +282,19 @@ impl ReadStatWriter {
274282
fn write_data_to_feather(&mut self, d: &ReadStatData, rsp: &ReadStatPath) -> Result<(), Box<dyn Error + Send + Sync>> {
275283
if let Some(p) = &rsp.out_path {
276284
// if already started writing, then need to append to file; otherwise create file
277-
let f = OpenOptions::new()
285+
let f = if self.wrote_start {
286+
OpenOptions::new()
278287
.write(true)
279288
.create(true)
280289
.append(true)
281-
.open(p)?;
290+
.open(p)?
291+
} else {
292+
OpenOptions::new()
293+
.write(true)
294+
.create(true)
295+
.truncate(true)
296+
.open(p)?
297+
};
282298

283299
// set message for what is being read/written
284300
self.write_message_for_rows(d, rsp)?;
@@ -332,11 +348,19 @@ impl ReadStatWriter {
332348
fn write_data_to_ndjson(&mut self, d: &ReadStatData, rsp: &ReadStatPath) -> Result<(), Box<dyn Error + Send + Sync>> {
333349
if let Some(p) = &rsp.out_path {
334350
// if already started writing, then need to append to file; otherwise create file
335-
let f = OpenOptions::new()
351+
let f = if self.wrote_start {
352+
OpenOptions::new()
336353
.write(true)
337354
.create(true)
338355
.append(true)
339-
.open(p)?;
356+
.open(p)?
357+
} else {
358+
OpenOptions::new()
359+
.write(true)
360+
.create(true)
361+
.truncate(true)
362+
.open(p)?
363+
};
340364

341365
// set message for what is being read/written
342366
self.write_message_for_rows(d, rsp)?;
@@ -380,11 +404,19 @@ impl ReadStatWriter {
380404
fn write_data_to_parquet(&mut self, d: &ReadStatData, rsp: &ReadStatPath) -> Result<(), Box<dyn Error + Send + Sync>> {
381405
if let Some(p) = &rsp.out_path {
382406
// if already started writing, then need to append to file; otherwise create file
383-
let f = OpenOptions::new()
407+
let f = if self.wrote_start {
408+
OpenOptions::new()
384409
.write(true)
385410
.create(true)
386411
.append(true)
387-
.open(p)?;
412+
.open(p)?
413+
} else {
414+
OpenOptions::new()
415+
.write(true)
416+
.create(true)
417+
.truncate(true)
418+
.open(p)?
419+
};
388420

389421
// set message for what is being read/written
390422
self.write_message_for_rows(d, rsp)?;

readstat/tests/cli_data_parquet.rs

Lines changed: 80 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,60 @@ use assert_fs::NamedTempFile;
44
use polars::prelude::*;
55
use std::{fs::File, path::PathBuf, process::Command, result::Result};
66

7+
enum OverwriteOption {
8+
Overwrite(NamedTempFile),
9+
DoNotOverwrite,
10+
}
11+
712
fn cli_data_to_parquet(
813
base_file_name: &str,
14+
overwrite: OverwriteOption,
915
rows_to_stream: Option<u32>,
1016
) -> Result<(Command, NamedTempFile), Box<dyn std::error::Error>> {
11-
let tempfile = NamedTempFile::new(format!("{}.parquet", base_file_name))?;
12-
1317
let mut cmd = Command::cargo_bin("readstat")?;
1418

15-
if let Some(rows) = rows_to_stream {
16-
cmd.arg("data")
17-
.arg(format!("tests/data/{}.sas7bdat", base_file_name))
18-
.args(["--format", "parquet"])
19-
.args(["--output", tempfile.as_os_str().to_str().unwrap()])
20-
.args(["--stream-rows", rows.to_string().as_str()])
21-
.arg("--overwrite");
22-
} else {
23-
cmd.arg("data")
24-
.arg(format!("tests/data/{}.sas7bdat", base_file_name))
25-
.args(["--format", "parquet"])
26-
.args(["--output", tempfile.as_os_str().to_str().unwrap()])
27-
.arg("--overwrite");
28-
}
19+
let tempfile = match (overwrite, rows_to_stream) {
20+
(OverwriteOption::Overwrite(tempfile), Some(rows)) => {
21+
cmd.arg("data")
22+
.arg(format!("tests/data/{}.sas7bdat", base_file_name))
23+
.args(["--format", "parquet"])
24+
.args(["--output", tempfile.as_os_str().to_str().unwrap()])
25+
.args(["--stream-rows", rows.to_string().as_str()])
26+
.arg("--overwrite");
27+
28+
tempfile
29+
}
30+
(OverwriteOption::DoNotOverwrite, Some(rows)) => {
31+
let tempfile = NamedTempFile::new(format!("{}.parquet", base_file_name))?;
32+
33+
cmd.arg("data")
34+
.arg(format!("tests/data/{}.sas7bdat", base_file_name))
35+
.args(["--format", "parquet"])
36+
.args(["--output", tempfile.as_os_str().to_str().unwrap()])
37+
.args(["--stream-rows", rows.to_string().as_str()]);
38+
39+
tempfile
40+
}
41+
(OverwriteOption::Overwrite(tempfile), None) => {
42+
cmd.arg("data")
43+
.arg(format!("tests/data/{}.sas7bdat", base_file_name))
44+
.args(["--format", "parquet"])
45+
.args(["--output", tempfile.as_os_str().to_str().unwrap()])
46+
.arg("--overwrite");
47+
48+
tempfile
49+
}
50+
(OverwriteOption::DoNotOverwrite, None) => {
51+
let tempfile = NamedTempFile::new(format!("{}.parquet", base_file_name))?;
52+
53+
cmd.arg("data")
54+
.arg(format!("tests/data/{}.sas7bdat", base_file_name))
55+
.args(["--format", "parquet"])
56+
.args(["--output", tempfile.as_os_str().to_str().unwrap()]);
57+
58+
tempfile
59+
}
60+
};
2961

3062
Ok((cmd, tempfile))
3163
}
@@ -40,7 +72,8 @@ fn parquet_to_df(path: PathBuf) -> Result<DataFrame, Box<dyn std::error::Error>>
4072

4173
#[test]
4274
fn cars_to_parquet() {
43-
let (mut cmd, tempfile) = cli_data_to_parquet("cars", None).unwrap();
75+
let (mut cmd, tempfile) =
76+
cli_data_to_parquet("cars", OverwriteOption::DoNotOverwrite, None).unwrap();
4477

4578
cmd.assert().success().stdout(predicate::str::contains(
4679
"In total, wrote 1,081 rows from file cars.sas7bdat into cars.parquet",
@@ -58,7 +91,36 @@ fn cars_to_parquet() {
5891

5992
#[test]
6093
fn cars_to_parquet_with_streaming() {
61-
let (mut cmd, tempfile) = cli_data_to_parquet("cars", Some(500)).unwrap();
94+
let (mut cmd, tempfile) =
95+
cli_data_to_parquet("cars", OverwriteOption::DoNotOverwrite, Some(500)).unwrap();
96+
97+
cmd.assert().success().stdout(predicate::str::contains(
98+
"In total, wrote 1,081 rows from file cars.sas7bdat into cars.parquet",
99+
));
100+
101+
let df = parquet_to_df(tempfile.to_path_buf()).unwrap();
102+
103+
let (height, width) = df.shape();
104+
105+
assert_eq!(height, 1081);
106+
assert_eq!(width, 13);
107+
108+
tempfile.close().unwrap();
109+
}
110+
111+
#[test]
112+
fn cars_to_parquet_overwrite() {
113+
// first stream
114+
let (mut cmd, tempfile) =
115+
cli_data_to_parquet("cars", OverwriteOption::DoNotOverwrite, Some(500)).unwrap();
116+
117+
cmd.assert().success().stdout(predicate::str::contains(
118+
"In total, wrote 1,081 rows from file cars.sas7bdat into cars.parquet",
119+
));
120+
121+
// next do not stream
122+
let (mut cmd, tempfile) =
123+
cli_data_to_parquet("cars", OverwriteOption::Overwrite(tempfile), None).unwrap();
62124

63125
cmd.assert().success().stdout(predicate::str::contains(
64126
"In total, wrote 1,081 rows from file cars.sas7bdat into cars.parquet",

0 commit comments

Comments
 (0)