@@ -328,6 +328,23 @@ def build_extension(self, ext) -> None:
328328 ["cmake" , "--build" , "." , "--verbose" , * build_args ], cwd = build_temp , check = True
329329 )
330330
331+ dependencies = [
332+ "torch >= 2.3.0" ,
333+ "transformers == 4.43.2" ,
334+ "fastapi >= 0.111.0" ,
335+ "uvicorn >= 0.30.1" ,
336+ "langchain >= 0.2.0" ,
337+ "blessed >= 1.20.0" ,
338+ "accelerate >= 0.31.0" ,
339+ "sentencepiece >= 0.1.97" ,
340+ "setuptools" ,
341+ "ninja" ,
342+ "wheel" ,
343+ "colorlog" ,
344+ "build" ,
345+ "fire" ,
346+ "protobuf"
347+ ]
331348if CUDA_HOME is not None :
332349 ops_module = CUDAExtension ('KTransformersOps' , [
333350 'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu' ,
@@ -345,6 +362,7 @@ def build_extension(self, ext) -> None:
345362 }
346363 )
347364elif MUSA_HOME is not None :
365+ dependencies .remove ("torch >= 2.3.0" )
348366 SimplePorting (cuda_dir_path = "ktransformers/ktransformers_ext/cuda" , mapping_rule = {
349367 # Common rules
350368 "at::cuda" : "at::musa" ,
@@ -372,6 +390,7 @@ def build_extension(self, ext) -> None:
372390
373391setup (
374392 version = VersionInfo ().get_package_version (),
393+ install_requires = dependencies ,
375394 cmdclass = {"bdist_wheel" :BuildWheelsCommand ,"build_ext" : CMakeBuild },
376395 ext_modules = [
377396 CMakeExtension ("cpuinfer_ext" ),
0 commit comments