Skip to content

Commit 47fcb39

Browse files
authored
Use filter="data" option of TarFile.extractall. (#21760)
For Python versions between 3.12 (inclusive) and 3.14 (exclusive). The "data" filter performs a number of additional checks on links and paths. The `filter` option was added in Python 3.12. The `filter="data"` option became the default in Python 3.14. Also: - added similar path filtering when extracting zip archives - shared the extraction code between `file_utils` and `saving_lib`
1 parent 869b31a commit 47fcb39

File tree

3 files changed

+54
-16
lines changed

3 files changed

+54
-16
lines changed

keras/src/saving/saving_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -943,7 +943,7 @@ def __init__(self, root_path, archive=None, mode=None):
943943
if self.archive:
944944
self.tmp_dir = get_temp_dir()
945945
if self.mode == "r":
946-
self.archive.extractall(path=self.tmp_dir)
946+
file_utils.extract_open_archive(self.archive, self.tmp_dir)
947947
self.working_dir = file_utils.join(
948948
self.tmp_dir, self.root_path
949949
).replace("\\", "/")

keras/src/utils/file_utils.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import re
44
import shutil
5+
import sys
56
import tarfile
67
import tempfile
78
import urllib
@@ -52,17 +53,32 @@ def is_link_in_dir(info, base):
5253
return is_path_in_dir(info.linkname, base_dir=tip)
5354

5455

55-
def filter_safe_paths(members):
56+
def filter_safe_zipinfos(members):
5657
base_dir = resolve_path(".")
5758
for finfo in members:
5859
valid_path = False
59-
if is_path_in_dir(finfo.name, base_dir):
60+
if is_path_in_dir(finfo.filename, base_dir):
6061
valid_path = True
6162
yield finfo
62-
elif finfo.issym() or finfo.islnk():
63+
if not valid_path:
64+
warnings.warn(
65+
"Skipping invalid path during archive extraction: "
66+
f"'{finfo.name}'.",
67+
stacklevel=2,
68+
)
69+
70+
71+
def filter_safe_tarinfos(members):
72+
base_dir = resolve_path(".")
73+
for finfo in members:
74+
valid_path = False
75+
if finfo.issym() or finfo.islnk():
6376
if is_link_in_dir(finfo, base_dir):
6477
valid_path = True
6578
yield finfo
79+
elif is_path_in_dir(finfo.name, base_dir):
80+
valid_path = True
81+
yield finfo
6682
if not valid_path:
6783
warnings.warn(
6884
"Skipping invalid path during archive extraction: "
@@ -71,6 +87,35 @@ def filter_safe_paths(members):
7187
)
7288

7389

90+
def extract_open_archive(archive, path="."):
91+
"""Extracts an open tar or zip archive to the provided directory.
92+
93+
This function filters unsafe paths during extraction.
94+
95+
Args:
96+
archive: The archive object, either a `TarFile` or a `ZipFile`.
97+
path: Where to extract the archive file.
98+
"""
99+
if isinstance(archive, zipfile.ZipFile):
100+
# Zip archive.
101+
archive.extractall(
102+
path, members=filter_safe_zipinfos(archive.infolist())
103+
)
104+
else:
105+
# Tar archive.
106+
extractall_kwargs = {}
107+
# The `filter="data"` option was added in Python 3.12. It became the
108+
# default starting from Python 3.14. So we only specify it between
109+
# those two versions.
110+
if sys.version_info >= (3, 12) and sys.version_info < (3, 14):
111+
extractall_kwargs = {"filter": "data"}
112+
archive.extractall(
113+
path,
114+
members=filter_safe_tarinfos(archive),
115+
**extractall_kwargs,
116+
)
117+
118+
74119
def extract_archive(file_path, path=".", archive_format="auto"):
75120
"""Extracts an archive if it matches a support format.
76121
@@ -112,14 +157,7 @@ def extract_archive(file_path, path=".", archive_format="auto"):
112157
if is_match_fn(file_path):
113158
with open_fn(file_path) as archive:
114159
try:
115-
if zipfile.is_zipfile(file_path):
116-
# Zip archive.
117-
archive.extractall(path)
118-
else:
119-
# Tar archive, perhaps unsafe. Filter paths.
120-
archive.extractall(
121-
path, members=filter_safe_paths(archive)
122-
)
160+
extract_open_archive(archive, path)
123161
except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
124162
if os.path.exists(path):
125163
if os.path.isfile(path):

keras/src/utils/file_utils_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def test_member_within_base_dir(self):
142142
with tarfile.open(self.tar_path, "w") as tar:
143143
tar.add(__file__, arcname="safe_path.txt")
144144
with tarfile.open(self.tar_path, "r") as tar:
145-
members = list(file_utils.filter_safe_paths(tar.getmembers()))
145+
members = list(file_utils.filter_safe_tarinfos(tar.getmembers()))
146146
self.assertEqual(len(members), 1)
147147
self.assertEqual(members[0].name, "safe_path.txt")
148148

@@ -156,7 +156,7 @@ def test_symlink_within_base_dir(self):
156156
with tarfile.open(self.tar_path, "w") as tar:
157157
tar.add(symlink_path, arcname="symlink.txt")
158158
with tarfile.open(self.tar_path, "r") as tar:
159-
members = list(file_utils.filter_safe_paths(tar.getmembers()))
159+
members = list(file_utils.filter_safe_tarinfos(tar.getmembers()))
160160
self.assertEqual(len(members), 1)
161161
self.assertEqual(members[0].name, "symlink.txt")
162162
os.remove(symlink_path)
@@ -173,7 +173,7 @@ def test_invalid_path_warning(self):
173173
) # Path intended to be outside of base dir
174174
with tarfile.open(self.tar_path, "r") as tar:
175175
with patch("warnings.warn") as mock_warn:
176-
_ = list(file_utils.filter_safe_paths(tar.getmembers()))
176+
_ = list(file_utils.filter_safe_tarinfos(tar.getmembers()))
177177
warning_msg = (
178178
"Skipping invalid path during archive extraction: "
179179
"'../../invalid.txt'."
@@ -196,7 +196,7 @@ def test_symbolic_link_in_base_dir(self):
196196
tar.add(symlink_path, arcname="symlink.txt")
197197

198198
with tarfile.open(self.tar_path, "r") as tar:
199-
members = list(file_utils.filter_safe_paths(tar.getmembers()))
199+
members = list(file_utils.filter_safe_tarinfos(tar.getmembers()))
200200
self.assertEqual(len(members), 1)
201201
self.assertEqual(members[0].name, "symlink.txt")
202202
self.assertTrue(

0 commit comments

Comments
 (0)