@@ -67,17 +67,17 @@ def get_musa_bare_metal_version(self, musa_dir):
6767 def get_rocm_bare_metal_version (self , rocm_dir ):
6868 """
6969 Get the ROCm version from the ROCm installation directory.
70-
70+
7171 Args:
7272 rocm_dir: Path to the ROCm installation directory
73-
73+
7474 Returns:
7575 A string representation of the ROCm version (e.g., "63" for ROCm 6.3)
7676 """
7777 try :
7878 # Try using rocm_agent_enumerator to get version info
7979 raw_output = subprocess .check_output (
80- [rocm_dir + "/bin/rocminfo" , "--version" ],
80+ [rocm_dir + "/bin/rocminfo" , "--version" ],
8181 universal_newlines = True ,
8282 stderr = subprocess .STDOUT )
8383 # Extract version number from output
@@ -90,7 +90,7 @@ def get_rocm_bare_metal_version(self, rocm_dir):
9090 except (subprocess .CalledProcessError , FileNotFoundError ):
9191 # If rocminfo --version fails, try alternative methods
9292 pass
93-
93+
9494 try :
9595 # Try reading version from release file
9696 with open (os .path .join (rocm_dir , "share/doc/hip/version.txt" ), "r" ) as f :
@@ -100,7 +100,7 @@ def get_rocm_bare_metal_version(self, rocm_dir):
100100 return rocm_version
101101 except (FileNotFoundError , IOError ):
102102 pass
103-
103+
104104 # If all else fails, try to extract from directory name
105105 dir_name = os .path .basename (os .path .normpath (rocm_dir ))
106106 match = re .search (r'rocm-(\d+\.\d+)' , dir_name )
@@ -109,7 +109,7 @@ def get_rocm_bare_metal_version(self, rocm_dir):
109109 version = parse (version_str )
110110 rocm_version = f"{ version .major } { version .minor } "
111111 return rocm_version
112-
112+
113113 # Fallback to extracting from hipcc version
114114 try :
115115 raw_output = subprocess .check_output (
@@ -124,7 +124,7 @@ def get_rocm_bare_metal_version(self, rocm_dir):
124124 return rocm_version
125125 except (subprocess .CalledProcessError , FileNotFoundError ):
126126 pass
127-
127+
128128 # If we still can't determine the version, raise an error
129129 raise ValueError (f"Could not determine ROCm version from directory: { rocm_dir } " )
130130
@@ -319,7 +319,7 @@ def build_extension(self, ext) -> None:
319319 raise ValueError ("Unsupported backend: CUDA_HOME and MUSA_HOME are not set." )
320320 # log cmake_args
321321 print ("CMake args:" , cmake_args )
322-
322+
323323 build_args = []
324324 if "CMAKE_ARGS" in os .environ :
325325 cmake_args += [
@@ -398,6 +398,23 @@ def build_extension(self, ext) -> None:
398398 ["cmake" , "--build" , "." , "--verbose" , * build_args ], cwd = build_temp , check = True
399399 )
400400
401+ dependencies = [
402+ "torch >= 2.3.0" ,
403+ "transformers == 4.43.2" ,
404+ "fastapi >= 0.111.0" ,
405+ "uvicorn >= 0.30.1" ,
406+ "langchain >= 0.2.0" ,
407+ "blessed >= 1.20.0" ,
408+ "accelerate >= 0.31.0" ,
409+ "sentencepiece >= 0.1.97" ,
410+ "setuptools" ,
411+ "ninja" ,
412+ "wheel" ,
413+ "colorlog" ,
414+ "build" ,
415+ "fire" ,
416+ "protobuf"
417+ ]
401418if CUDA_HOME is not None or ROCM_HOME is not None :
402419 ops_module = CUDAExtension ('KTransformersOps' , [
403420 'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu' ,
@@ -415,6 +432,7 @@ def build_extension(self, ext) -> None:
415432 }
416433 )
417434elif MUSA_HOME is not None :
435+ dependencies .remove ("torch >= 2.3.0" )
418436 SimplePorting (cuda_dir_path = "ktransformers/ktransformers_ext/cuda" , mapping_rule = {
419437 # Common rules
420438 "at::cuda" : "at::musa" ,
@@ -443,6 +461,7 @@ def build_extension(self, ext) -> None:
443461setup (
444462 name = VersionInfo .PACKAGE_NAME ,
445463 version = VersionInfo ().get_package_version (),
464+ install_requires = dependencies ,
446465 cmdclass = {"bdist_wheel" :BuildWheelsCommand ,"build_ext" : CMakeBuild },
447466 ext_modules = [
448467 CMakeExtension ("cpuinfer_ext" ),
0 commit comments