Skip to content

Commit 012e4d5

Browse files
committed
patch tensorflow v2.12.0
1 parent de5f6e2 commit 012e4d5

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

patches/apply_patch.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
def patch_fix_RoundToNearest(tf_version: str, tf_src_root: str):
13-
if tf_version not in ['2.11.0', '2.11.1']:
13+
if tf_version not in ['2.11.0', '2.11.1', '2.12.0']:
1414
print(f"warning: skip applying `patch_fix_RoundToNearest` to tf version `{tf_version}`")
1515
return
1616

@@ -38,8 +38,31 @@ def patch_fix_RoundToNearest(tf_version: str, tf_src_root: str):
3838
dst.truncate(0)
3939
dst.write(fixed.getvalue())
4040

41-
patches = [patch_fix_RoundToNearest]
41+
def patch_compiling_telemetry_cc(tf_version: str, tf_src_root: str):
42+
if tf_version not in ['2.12.0']:
43+
print(f"warning: skip applying `patch_compiling_telemetry_cc` to tf version `{tf_version}`")
44+
return
45+
46+
cmakelists_txt = Path(tf_src_root) / 'CMakeLists.txt'
47+
fixed = StringIO()
48+
patched_1 = False
49+
with open(cmakelists_txt, 'r') as source:
50+
for line in source:
51+
line_strip = line.strip()
52+
if not patched_1 and line_strip == '${TFLITE_SOURCE_DIR}/profiling/telemetry/profiler.cc':
53+
fixed.write(f" {line_strip} # fixed\n")
54+
fixed.write(" ${TFLITE_SOURCE_DIR}/profiling/telemetry/telemetry.cc\n")
55+
patched_1 = True
56+
else:
57+
fixed.write(line)
58+
59+
if patched_1:
60+
with open(cmakelists_txt, 'w') as dst:
61+
dst.truncate(0)
62+
dst.write(fixed.getvalue())
63+
4264

65+
patches = [patch_fix_RoundToNearest, patch_compiling_telemetry_cc]
4366

4467
if __name__ == '__main__':
4568
tf_version = None

0 commit comments

Comments
 (0)