Skip to content

Commit b263c35

Browse files
committed
MUSA: Conditionally remove torch and numpy from dependencies
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
1 parent 649489d commit b263c35

File tree

4 files changed

+41
-32
lines changed

4 files changed

+41
-32
lines changed

Makefile

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
flake_find:
2-
cd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' -
2+
cd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' -
33
format:
44
@cd ktransformers && black .
55
@black setup.py
@@ -14,7 +14,11 @@ dev_install:
1414

1515
# install ktransformers
1616
echo "Installing python dependencies from requirements.txt"
17-
pip install -r requirements-local_chat.txt
17+
@if command -v mcc > /dev/null 2>&1; then \
18+
bash -c 'pip install -r <(grep -v -E "torch|numpy" requirements-local_chat.txt)'; \
19+
else \
20+
pip install -r requirements-local_chat.txt; \
21+
fi
1822

1923
echo "Installing ktransformers"
2024
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . -v --no-build-isolation

install.sh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
set -e
2+
set -e
33

44
# clear build dirs
55
rm -rf build
@@ -10,7 +10,11 @@ rm -rf ktransformers/ktransformers_ext/cuda/dist
1010
rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info
1111

1212
echo "Installing python dependencies from requirements.txt"
13-
pip install -r requirements-local_chat.txt
13+
if command -v mcc > /dev/null 2>&1; then
14+
bash -c 'pip install -r <(grep -v -E "torch|numpy" requirements-local_chat.txt)'
15+
else
16+
pip install -r requirements-local_chat.txt
17+
fi
1418

1519
echo "Installing ktransformers"
1620
KTRANSFORMERS_FORCE_BUILD=TRUE pip install . --no-build-isolation

pyproject.toml

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[build-system]
22
requires = [
33
"setuptools",
4-
"torch >= 2.3.0",
4+
"torch >= 2.3.0",
55
"ninja",
66
"packaging",
77
"cpufeature"
@@ -12,25 +12,7 @@ build-backend = "setuptools.build_meta"
1212

1313
name = "ktransformers"
1414

15-
dynamic = ["version"]
16-
17-
dependencies = [
18-
"torch >= 2.3.0",
19-
"transformers == 4.43.2",
20-
"fastapi >= 0.111.0",
21-
"uvicorn >= 0.30.1",
22-
"langchain >= 0.2.0",
23-
"blessed >= 1.20.0",
24-
"accelerate >= 0.31.0",
25-
"sentencepiece >= 0.1.97",
26-
"setuptools",
27-
"ninja",
28-
"wheel",
29-
"colorlog",
30-
"build",
31-
"fire",
32-
"protobuf"
33-
]
15+
dynamic = ["version", "dependencies"]
3416

3517
requires-python = ">=3.10"
3618

setup.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,17 @@ def get_musa_bare_metal_version(self, musa_dir):
6767
def get_rocm_bare_metal_version(self, rocm_dir):
6868
"""
6969
Get the ROCm version from the ROCm installation directory.
70-
70+
7171
Args:
7272
rocm_dir: Path to the ROCm installation directory
73-
73+
7474
Returns:
7575
A string representation of the ROCm version (e.g., "63" for ROCm 6.3)
7676
"""
7777
try:
7878
# Try using rocm_agent_enumerator to get version info
7979
raw_output = subprocess.check_output(
80-
[rocm_dir + "/bin/rocminfo", "--version"],
80+
[rocm_dir + "/bin/rocminfo", "--version"],
8181
universal_newlines=True,
8282
stderr=subprocess.STDOUT)
8383
# Extract version number from output
@@ -90,7 +90,7 @@ def get_rocm_bare_metal_version(self, rocm_dir):
9090
except (subprocess.CalledProcessError, FileNotFoundError):
9191
# If rocminfo --version fails, try alternative methods
9292
pass
93-
93+
9494
try:
9595
# Try reading version from release file
9696
with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f:
@@ -100,7 +100,7 @@ def get_rocm_bare_metal_version(self, rocm_dir):
100100
return rocm_version
101101
except (FileNotFoundError, IOError):
102102
pass
103-
103+
104104
# If all else fails, try to extract from directory name
105105
dir_name = os.path.basename(os.path.normpath(rocm_dir))
106106
match = re.search(r'rocm-(\d+\.\d+)', dir_name)
@@ -109,7 +109,7 @@ def get_rocm_bare_metal_version(self, rocm_dir):
109109
version = parse(version_str)
110110
rocm_version = f"{version.major}{version.minor}"
111111
return rocm_version
112-
112+
113113
# Fallback to extracting from hipcc version
114114
try:
115115
raw_output = subprocess.check_output(
@@ -124,7 +124,7 @@ def get_rocm_bare_metal_version(self, rocm_dir):
124124
return rocm_version
125125
except (subprocess.CalledProcessError, FileNotFoundError):
126126
pass
127-
127+
128128
# If we still can't determine the version, raise an error
129129
raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}")
130130

@@ -319,7 +319,7 @@ def build_extension(self, ext) -> None:
319319
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
320320
# log cmake_args
321321
print("CMake args:", cmake_args)
322-
322+
323323
build_args = []
324324
if "CMAKE_ARGS" in os.environ:
325325
cmake_args += [
@@ -398,6 +398,23 @@ def build_extension(self, ext) -> None:
398398
["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
399399
)
400400

401+
dependencies = [
402+
"torch >= 2.3.0",
403+
"transformers == 4.43.2",
404+
"fastapi >= 0.111.0",
405+
"uvicorn >= 0.30.1",
406+
"langchain >= 0.2.0",
407+
"blessed >= 1.20.0",
408+
"accelerate >= 0.31.0",
409+
"sentencepiece >= 0.1.97",
410+
"setuptools",
411+
"ninja",
412+
"wheel",
413+
"colorlog",
414+
"build",
415+
"fire",
416+
"protobuf"
417+
]
401418
if CUDA_HOME is not None or ROCM_HOME is not None:
402419
ops_module = CUDAExtension('KTransformersOps', [
403420
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
@@ -415,6 +432,7 @@ def build_extension(self, ext) -> None:
415432
}
416433
)
417434
elif MUSA_HOME is not None:
435+
dependencies.remove("torch >= 2.3.0")
418436
SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
419437
# Common rules
420438
"at::cuda": "at::musa",
@@ -443,6 +461,7 @@ def build_extension(self, ext) -> None:
443461
setup(
444462
name=VersionInfo.PACKAGE_NAME,
445463
version=VersionInfo().get_package_version(),
464+
install_requires=dependencies,
446465
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
447466
ext_modules=[
448467
CMakeExtension("cpuinfer_ext"),

0 commit comments

Comments
 (0)