
Commit a782bd3

committed
Using pathlib.Path directly, added encoding on load
1 parent 82c3502 commit a782bd3

2 files changed: +20 -20 lines changed

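Both changes named in the commit message appear in the hunks below. For the "added encoding on load" part, a minimal standalone sketch (hypothetical file name, not code from this repository) of why passing an explicit encoding matters when loading the replaypack text files:

import json
from pathlib import Path

# Without an explicit encoding, open() falls back to the platform's preferred
# locale encoding (e.g. cp1252 on some Windows setups), which can mis-decode
# UTF-8 JSON/log files; encoding="utf-8" makes the read deterministic.
summary_path = Path("2022_IEM_summary.json")  # hypothetical file
with summary_path.open(encoding="utf-8") as summary_file:
    summary = json.load(summary_file)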

src/sc2_datasets/torch/datasets/sc2_replaypack_dataset.py

Lines changed: 1 addition & 7 deletions
@@ -54,19 +54,13 @@ def __init__(
         self.unpack_n_workers = unpack_n_workers
         self.download_dir = Path(download_dir).resolve()
 
-        if not self.download_dir.is_dir():
-            raise Exception("Download directory must be a directory!")
-
         # Replaypack download directory must exist, we create it if it does not exist:
         if not self.download_dir.exists():
            self.download_dir.mkdir(parents=True, exist_ok=True)
 
         self.unpack_dir = Path(unpack_dir).resolve()
         # Replaypack unpack directory must exist, we create it if it does not exist:
         # This is because otherwise we will not be able to load any data:
-        if not self.unpack_dir.is_dir():
-            raise Exception("Replaypack unpack directory must be a directory!")
-
         if not self.unpack_dir.exists():
             self.unpack_dir.mkdir(parents=True, exist_ok=True)
 
@@ -120,7 +114,7 @@ def __init__(
             self._replaypack_summary,
         ) = load_replaypack_information(
             replaypack_name=self.replaypack_name,
-            replaypack_path=self.replaypack_unpack_path.as_posix(),
+            replaypack_path=self.replaypack_unpack_path,
             unpack_n_workers=self.unpack_n_workers,
         )
 
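The two deleted is_dir() guards lean on the semantics of the Path.mkdir call that is kept; a short sketch of that behaviour (hypothetical path, not the dataset class itself):

from pathlib import Path

# mkdir(parents=True, exist_ok=True) creates any missing parent directories and
# is a no-op when the directory already exists; if the path exists but is a
# regular file, it still raises FileExistsError.
download_dir = Path("downloads", "replaypacks").resolve()  # hypothetical path
download_dir.mkdir(parents=True, exist_ok=True)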

src/sc2_datasets/utils/dataset_utils.py

Lines changed: 19 additions & 13 deletions
@@ -1,5 +1,6 @@
 import json
 import os
+from pathlib import Path
 
 from sc2_datasets.utils.zip_utils import unpack_zipfile
 
@@ -8,7 +9,7 @@
 
 def load_replaypack_information(
     replaypack_name: str,
-    replaypack_path: str,
+    replaypack_path: Path,
     unpack_n_workers: int,
 ) -> Tuple[str, Dict[str, str], Dict[str, str]]:
     """
@@ -56,9 +57,9 @@ def load_replaypack_information(
     >>> assert unpack_n_workers >= 1
     """
 
-    replaypack_files = os.listdir(replaypack_path)
+    replaypack_files = list(replaypack_path.iterdir())
     # Initializing variables that should be returned:
-    replaypack_data_path = os.path.join(replaypack_path, replaypack_name + "_data")
+    replaypack_data_path = Path(replaypack_path, replaypack_name + "_data").resolve()
     replaypack_main_log_obj_list = []
     replaypack_processed_failed = {}
     replaypack_summary = {}
@@ -67,29 +68,34 @@ def load_replaypack_information(
     # Extracting the nested .zip files,
     # and loading replaypack information files:
     for file in replaypack_files:
-        if file.endswith("_data.zip"):
+        filename = file.name
+        if filename.endswith("_data.zip"):
             # Unpack the .zip archive only if it is not unpacked already:
-            if not os.path.isdir(replaypack_data_path):
+            if not replaypack_data_path.is_dir():
                 replaypack_data_path = unpack_zipfile(
                     destination_dir=replaypack_path,
                     subdir=replaypack_name + "_data",
                     zip_path=os.path.join(replaypack_path, file),
                     n_workers=unpack_n_workers,
                 )
-        if file.endswith("_main_log.log"):
-            with open(os.path.join(replaypack_path, file)) as main_log_file:
+        if filename.endswith("_main_log.log"):
+            main_log_filepath = Path(replaypack_path, file).resolve()
+            with main_log_filepath.open(encoding="utf-8") as main_log_file:
                 # Reading the lines of the log file and parsing them:
                 for line in main_log_file.readlines():
                     log_object = json.loads(line)
                     replaypack_main_log_obj_list.append(log_object)
-        if file.endswith("_processed_failed.log"):
-            with open(os.path.join(replaypack_path, file)) as processed_files:
+        if filename.endswith("_processed_failed.log"):
+            processed_files_filepath = Path(replaypack_path, file).resolve()
+            with processed_files_filepath.open(encoding="utf-8") as processed_files:
                 replaypack_processed_failed = json.load(processed_files)
-        if file.endswith("_processed_mapping.json"):
-            with open(os.path.join(replaypack_path, file)) as mapping_file:
+        if filename.endswith("_processed_mapping.json"):
+            mapping_file_filepath = Path(replaypack_path, file).resolve()
+            with mapping_file_filepath.open(encoding="utf-8") as mapping_file:
                 replaypack_dir_mapping = json.load(mapping_file)
-        if file.endswith("_summary.json"):
-            with open(os.path.join(replaypack_path, file)) as summary_file:
+        if filename.endswith("_summary.json"):
+            summary_file_filepath = Path(replaypack_path, file).resolve()
+            with summary_file_filepath.open(encoding="utf-8") as summary_file:
                 replaypack_summary = json.load(summary_file)
 
     return (
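With replaypack_path now a pathlib.Path, listing and joining move to Path methods, while the remaining os.path.join call keeps working because os.path functions accept os.PathLike objects (Python 3.6+). A hedged interop sketch with hypothetical names:

import os
from pathlib import Path

replaypack_path = Path("replaypacks", "2022_IEM")    # hypothetical directory
zip_name = "2022_IEM_data.zip"                       # hypothetical file name

as_str = os.path.join(replaypack_path, zip_name)     # returns str, Path accepted
as_path = Path(replaypack_path, zip_name).resolve()  # returns Path
assert Path(as_str).resolve() == as_path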
