@@ -14,8 +14,9 @@ def patch_fix_RoundToNearest(tf_version: str, tf_src_root: str):
1414 print (f"warning: skip applying `patch_fix_RoundToNearest` to tf version `{ tf_version } `" )
1515 return
1616
17- # kernels/internal/optimized/neon_tensor_utils.cc
18- neon_tensor_utils_cc = Path (tf_src_root ) / 'kernels' / 'internal' / 'optimized' / 'neon_tensor_utils.cc'
17+ # tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
18+ tf_lite_src_root = Path (tf_src_root ) / 'tensorflow' / 'lite'
19+ neon_tensor_utils_cc = Path (tf_lite_src_root ) / 'kernels' / 'internal' / 'optimized' / 'neon_tensor_utils.cc'
1920 fixed = StringIO ()
2021 patched_1 = False
2122 skip_next_line = False
@@ -43,7 +44,8 @@ def patch_compiling_telemetry_cc(tf_version: str, tf_src_root: str):
4344 print (f"warning: skip applying `patch_compiling_telemetry_cc` to tf version `{ tf_version } `" )
4445 return
4546
46- cmakelists_txt = Path (tf_src_root ) / 'CMakeLists.txt'
47+ tf_lite_src_root = Path (tf_src_root ) / 'tensorflow' / 'lite'
48+ cmakelists_txt = Path (tf_lite_src_root ) / 'CMakeLists.txt'
4749 fixed = StringIO ()
4850 patched_1 = False
4951 with open (cmakelists_txt , 'r' ) as source :
@@ -62,7 +64,64 @@ def patch_compiling_telemetry_cc(tf_version: str, tf_src_root: str):
6264 dst .write (fixed .getvalue ())
6365
6466
65- patches = [patch_fix_RoundToNearest , patch_compiling_telemetry_cc ]
67+ def patch_cpuinfo_riscv64_sys_hwprobe (tf_version : str , tf_src_root : str ):
68+ if tf_version not in ['2.16.0' , '2.16.1' ]:
69+ print (f"warning: skip applying `patch_cpuinfo_riscv64_sys_hwprobe` to tf version `{ tf_version } `" )
70+ return
71+
72+ # tensorflow/lite/tools/cmake/modules/cpuinfo.cmake
73+ # tensorflow/workspace2.bzl
74+ tf_lite_src_root = Path (tf_src_root ) / 'tensorflow' / 'lite'
75+ cpuinfo_cmake = Path (tf_lite_src_root ) / 'tools' / 'cmake' / 'modules' / 'cpuinfo.cmake'
76+ fixed = StringIO ()
77+ patched_1 = False
78+ lines_to_skip = 0
79+ with open (cpuinfo_cmake , 'r' ) as source :
80+ for line in source :
81+ line_strip = line .strip ()
82+ if not patched_1 and line_strip == '# Sync with tensorflow/third_party/cpuinfo/workspace.bzl' :
83+ fixed .write (f" { line_strip } # fixed\n " )
84+ fixed .write (" GIT_TAG 6543fec09b2f04ac4a666882998b534afc9c1349\n " )
85+ patched_1 = True
86+ lines_to_skip = 1
87+ else :
88+ if lines_to_skip > 0 :
89+ lines_to_skip -= 1
90+ continue
91+ fixed .write (line )
92+
93+ if patched_1 :
94+ with open (cpuinfo_cmake , 'w' ) as dst :
95+ dst .truncate (0 )
96+ dst .write (fixed .getvalue ())
97+
98+ workspace2_bzl = Path (tf_src_root ) / 'tensorflow' / 'workspace2.bzl'
99+ fixed = StringIO ()
100+ patched_2 = False
101+ with open (workspace2_bzl , 'r' ) as source :
102+ for line in source :
103+ line_strip = line .strip ()
104+ if not patched_2 and line_strip == 'name = "cpuinfo",' :
105+ fixed .write (' name = "cpuinfo", # fixed\n ' )
106+ fixed .write (' strip_prefix = "cpuinfo-6543fec09b2f04ac4a666882998b534afc9c1349",\n ' )
107+ fixed .write (' sha256 = "17180581df58b811ef93cfafd074598966a185f48e5a574e8947ca51419f7ca6",\n ' )
108+ fixed .write (' urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/6543fec09b2f04ac4a666882998b534afc9c1349.zip"),\n ' )
109+ patched_2 = True
110+ lines_to_skip = 3
111+ else :
112+ if lines_to_skip > 0 :
113+ lines_to_skip -= 1
114+ continue
115+ fixed .write (line )
116+
117+ if patched_2 :
118+ with open (workspace2_bzl , 'w' ) as dst :
119+ dst .truncate (0 )
120+ dst .write (fixed .getvalue ())
121+
122+
123+
124+ patches = [patch_fix_RoundToNearest , patch_compiling_telemetry_cc , patch_cpuinfo_riscv64_sys_hwprobe ]
66125
67126if __name__ == '__main__' :
68127 tf_version = None
0 commit comments