Commit 6c35ca7

Merge pull request #809 from KMSorSMS/develop-0.2.3
⚡ release v0.2.3
2 parents 034a116 + 848fe8a commit 6c35ca7

File tree

9 files changed: +233 -8 lines changed


doc/en/FAQ.md

Lines changed: 69 additions & 1 deletion
@@ -1,4 +1,18 @@
+<!-- omit in toc -->
 # FAQ
+- [Install](#install)
+  - [Q: ImportError: /lib/x86\_64-linux-gnu/libstdc++.so.6: version GLIBCXX\_3.4.32' not found](#q-importerror-libx86_64-linux-gnulibstdcso6-version-glibcxx_3432-not-found)
+  - [Q: DeepSeek-R1 not outputting initial token](#q-deepseek-r1-not-outputting-initial--token)
+- [Usage](#usage)
+  - [Q: If I got more VRAM than the model's requirement, how can I fully utilize it?](#q-if-i-got-more-vram-than-the-models-requirement-how-can-i-fully-utilize-it)
+  - [Q: If I don't have enough VRAM, but I have multiple GPUs, how can I utilize them?](#q-if-i-dont-have-enough-vram-but-i-have-multiple-gpus-how-can-i-utilize-them)
+  - [Q: How to get the best performance?](#q-how-to-get-the-best-performance)
+  - [Q: My DeepSeek-R1 model is not thinking.](#q-my-deepseek-r1-model-is-not-thinking)
+  - [Q: Loading gguf error](#q-loading-gguf-error)
+  - [Q: Version \`GLIBCXX\_3.4.30' not found](#q-version-glibcxx_3430-not-found)
+  - [Q: When running the bfloat16 moe model, the data shows NaN](#q-when-running-the-bfloat16-moe-model-the-data-shows-nan)
+  - [Q: Using fp8 prefill very slow.](#q-using-fp8-prefill-very-slow)
+  - [Q: Possible ways to run graphics cards using volta and turing architectures](#q-possible-ways-to-run-graphics-cards-using-volta-and-turing-architectures)
 ## Install
 ### Q: ImportError: /lib/x86_64-linux-gnu/libstdc++.so.6: version GLIBCXX_3.4.32' not found
 ```
@@ -96,4 +110,58 @@ RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
 
 ### Q: Using fp8 prefill very slow.
 
-The FP8 kernel is build by JIT, so the first run will be slow. The subsequent runs will be faster.
+The FP8 kernel is built by JIT, so the first run will be slow. The subsequent runs will be faster.
+
+### Q: Possible ways to run graphics cards using volta and turing architectures
+
+From: https://github.com/kvcache-ai/ktransformers/issues/374
+
+1. First, download the latest source code using git.
+2. Then, modify DeepSeek-V3-Chat-multi-gpu-4.yaml in the source code and all related yaml files, replacing all instances of KLinearMarlin with KLinearTorch.
+3. Next, compile ktransformers from source until it builds successfully on your local machine.
+4. Then, install flash-attn. It won't be used, but not installing it will cause an error.
+5. Then, modify local_chat.py, replacing all instances of flash_attention_2 with eager.
+6. Then, run local_chat.py. Be sure to follow the official tutorial's commands and adjust them to your local machine's parameters.
+7. While it runs, watch the memory usage with the top command. The memory capacity of a single CPU must be greater than the complete size of the model. (With multiple CPUs, each one just holds a copy.)
+Finally, confirm that the model is fully loaded into memory and that the designated weight layers are fully loaded into GPU memory. Then try entering text in the chat interface and check for errors.
+
+Attention: for better performance, you can check this [method](https://github.com/kvcache-ai/ktransformers/issues/374#issuecomment-2667520838) in the issue.
+>
+>https://github.com/kvcache-ai/ktransformers/blob/89f8218a2ab7ff82fa54dbfe30df741c574317fc/ktransformers/operators/attention.py#L274-L279
+>
+>```diff
+>+ original_dtype = query_states.dtype
+>+ target_dtype = torch.half
+>+ query_states = query_states.to(target_dtype)
+>+ compressed_kv_with_k_pe = compressed_kv_with_k_pe.to(target_dtype)
+>+ compressed_kv = compressed_kv.to(target_dtype)
+>+ attn_output = attn_output.to(target_dtype)
+>
+>decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
+>                             page_table,
+>                             position_ids.squeeze(0).to(torch.int32)+1, attn_logits,
+>                             4, #num_kv_splits # follow vLLM, fix it TODO
+>                             self.softmax_scale,
+>                             past_key_value.page_size)
+>
+>+ attn_output = attn_output.to(original_dtype)
+>```
+>
+>https://github.com/kvcache-ai/ktransformers/blob/89f8218a2ab7ff82fa54dbfe30df741c574317fc/ktransformers/operators/attention.py#L320-L326
+>
+>```diff
+>- attn_output = flash_attn_func(
+>-     query_states,
+>-     key_states,
+>-     value_states_padded,
+>-     softmax_scale=self.softmax_scale,
+>-     causal=True,
+>- )
+>+ attn_output = F.scaled_dot_product_attention(
+>+     query_states.transpose(1, 2),
+>+     key_states.transpose(1, 2),
+>+     value_states_padded.transpose(1, 2),
+>+     scale=self.softmax_scale,
+>+     is_causal=True
+>+ ).transpose(1, 2)
+>```
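
A minimal sketch of steps 2 and 5 above as a script, illustrative only: it assumes the optimize-rule files live under `ktransformers/optimize/optimize_rules/` (as the links in doc/en/benchmark.md suggest) and that `local_chat.py` sits in the `ktransformers/` package directory, so adjust the paths to your checkout.

```python
# Editor's sketch (not part of this commit): apply the Volta/Turing fallbacks
# described in steps 2 and 5 of the FAQ entry above as plain text replacements.
from pathlib import Path

# Step 2: swap the Marlin linear op for the plain torch op in every rule file.
rules_dir = Path("ktransformers/optimize/optimize_rules")  # assumed layout
for yml in rules_dir.glob("*.yaml"):
    text = yml.read_text()
    if "KLinearMarlin" in text:
        yml.write_text(text.replace("KLinearMarlin", "KLinearTorch"))

# Step 5: fall back from flash_attention_2 to eager attention in local_chat.py.
chat = Path("ktransformers/local_chat.py")  # adjust if your checkout differs
chat.write_text(chat.read_text().replace("flash_attention_2", "eager"))
```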

doc/en/benchmark.md

Lines changed: 7 additions & 3 deletions
@@ -26,7 +26,7 @@ Given that we have only tested 1,000 cases, which provides only a preliminary ju
 
 
 ## The Result Table
-
+Uses DeepSeek-V3 model (Some specific cases are R1)
 | | | | | | | | |
 | ------------------------ | ----------------- | ---------- | ----------------- | ------- | ---------- | ------------------------------------------------------ | ------------ |
 | DataSet | CPU Weight Format | CPU Kernel | GPU Weight Format | GEMM Kernel | MLA Kernel | [Siliconflow](https://cloud.siliconflow.cn/models)<br> | Ktrans Point |
@@ -37,9 +37,11 @@ Given that we have only tested 1,000 cases, which provides only a preliminary ju
 | 4 | q4km | cpuinfer | q4km->marlin 8 | marlin | triton | 81.6 | 81.1 |
 | 5 | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 81.6 | 81 |
 | 6 | q4km | cpuinfer | fp8 | fp8gemm | triton | 81.6 | 81.5 |
-| MMLU-pro | | | | | | | |
+| 7 (DeepSeek-R1) | iq1 | cpuinfer | fp8 | fp8gemm | triton | 78.6 | 83.6 |
+| MMLU-pro<br>(shuffle 1k) | | | | | | | |
 | 1 | q4km | cpuinfer | fp8 | fp8gemm | triton | 57.7 | 57.6 |
 | 2 | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 57.7 | 57.5 |
+| 3 (DeepSeek-R1) | iq1 | cpuinfer | fp8 | fp8gemm | triton | 71.9 | tbd |
 | HumanEval | tbd | tbd | tbd | tbd | tbd | tbd | tbd |
 | GSM8K | tbd | tbd | tbd | tbd | tbd | tbd | tbd |
 
@@ -54,6 +56,8 @@ By default, The MLA kernel uses triton in linux and torch in windows. But we nee
 4. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You don't need to change the source code as they both use q4km. But note the yaml file [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml#L29) and [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml#L18), below these lines you need to add `num_bits: 8` (in other words: add this kwargs to all that use `KLinearMarlin`). The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)
 5. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)
 6. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case.
+7. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case.
 - MMLU-pro test
 1. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case.
-2. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)
+2. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)
+3. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case.
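
The note for test case 4 above asks you to add `num_bits: 8` under every rule that uses `KLinearMarlin` in DeepSeek-V3-Chat.yaml. Below is a minimal, hedged sketch of doing that with PyYAML; it assumes each rule is a mapping with a `replace`/`kwargs` section, which may not match the exact schema, so verify against the yaml file itself before relying on it.

```python
# Illustrative only: add `num_bits: 8` to every KLinearMarlin rule in the
# DeepSeek-V3-Chat optimize-rule file (rule schema assumed, check before use).
import yaml

path = "ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml"
with open(path) as f:
    rules = yaml.safe_load(f)

for rule in rules:
    if "KLinearMarlin" in str(rule):  # the rule mentions the Marlin op somewhere
        kwargs = rule.setdefault("replace", {}).setdefault("kwargs", {})
        kwargs["num_bits"] = 8        # the kwarg the benchmark note asks for

with open(path, "w") as f:
    yaml.safe_dump(rules, f, sort_keys=False)
```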

doc/en/install.md

Lines changed: 9 additions & 1 deletion
@@ -1,5 +1,11 @@
-
+<!-- omit in toc -->
 # How to Run DeepSeek-R1
+- [Preparation](#preparation)
+- [Installation](#installation)
+  - [Attention](#attention)
+- [Supported models include:](#supported-models-include)
+- [Support quantize format:](#support-quantize-format)
+
 In this document, we will show you how to install and run KTransformers on your local machine. There are two versions:
 * V0.2 is the current main branch.
 * V0.3 is a preview version only provides binary distribution for now.
@@ -56,6 +62,8 @@ Some preparation:
 - At the same time, you should download and install the corresponding version of flash-attention from https://github.com/Dao-AILab/flash-attention/releases.
 
 ## Installation
+### Attention
+If you want to use numa support, not only do you need to set USE_NUMA=1, but you also need to make sure you have installed libnuma-dev (`sudo apt-get install libnuma-dev` may help you).
 
 <!-- 1. ~~Use a Docker image, see [documentation for Docker](./doc/en/Docker.md)~~
 
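
Related to the new Attention note above, here is a minimal editor's sketch (assuming a Linux host) that probes whether the NUMA shared library is present before building with USE_NUMA=1; the build headers still come from the libnuma-dev package.

```python
# Rough check that libnuma is available before attempting a USE_NUMA=1 build.
import ctypes

try:
    ctypes.CDLL("libnuma.so.1")
    print("libnuma found; libnuma-dev also provides the headers the build needs")
except OSError:
    print("libnuma not found; try `sudo apt-get install libnuma-dev` first")
```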

ktransformers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,4 +8,4 @@
 LastEditors : chenxl
 LastEditTime : 2025-02-15 03:53:02
 '''
-__version__ = "0.2.2rc1"
+__version__ = "0.2.3"

ktransformers/tests/.gitignore

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-humaneval/results
+results/
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
# adapted from https://github.com/abacaj/code-eval?tab=readme-ov-file
import argparse
import json
import os
import time
import requests
import tqdm

from evaluation import filter_answer
from prompts import instruct_prompt
import pandas as pd
from datasets import load_dataset

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'


def generate_text(api_url, question, model_name, stream=False, auth_token=None):
    headers = {
        'accept': 'application/json',
        'Content-Type': 'application/json',
        # add the API key
        'Authorization': 'Bearer ' + auth_token if auth_token else ''
    }
    question = instruct_prompt(question)
    data = {
        "messages": [{"content": question, "role": "user"}],
        "model": model_name,
        "stream": stream,
        "temperature": 0.6,
        "max_tokens": 10240,
    }
    print(f"content: {question}")
    response = requests.post(api_url, headers=headers, json=data, verify=False)
    if response.status_code == 200:
        result = response.json()
        results = result.get('choices', [{}])[0].get('message', {}).get('content', '')
        return filter_answer(results)
    else:
        print(f"API Request failed with status code {response.status_code}")
        return None


def load_data(file_path):
    """
    Load data from a Parquet file into a list.
    Each record in the Parquet file should represent an individual record.
    """
    # read the Parquet-backed dataset
    # dataset = load_dataset('parquet', data_files=file_path)
    data = []
    ds = load_dataset(file_path)
    df = pd.DataFrame(ds['train'])
    for _, row in df.iterrows():
        data.append(row.to_dict())
    return data


def get_score(pred, answer):
    """
    Score the prediction against the reference answer.
    Exact match, with a numeric fallback so "113" and 113 compare equal.
    :param pred: The predicted string.
    :param answer: The reference answer.
    :return: 1 if the answers match, otherwise 0.
    """
    if pred == answer:
        return 1
    # if we need to compare a str with a number, convert the str to a number
    try:
        pred = float(pred)
        answer = float(answer)
    except:
        pass
    if pred == answer:
        return 1
    return 0


def run_eval_api(
    api_url: str,
    model_name: str,
    out_path: str,
    format_tabs: bool = False,
    auth_token: str = None,
    problem_file: str = None,
    append: bool = False
):
    data = load_data(problem_file)
    pbar = tqdm.tqdm(total=len(data) * 1)

    for i in range(len(data)):
        data_item = data[i]
        question = data_item['Problem']
        # Start the timer for this evaluation
        start_time = time.time()
        try:
            completion = generate_text(api_url, question, model_name, auth_token=auth_token)
            if completion is None:
                raise Exception(f"Failed to get prediction for {question}")
            answer = data_item['Answer']
            score = get_score(completion, answer)
            elapsed_time = time.time() - start_time
            result = {
                "question_id": data_item["ID"],
                "answer": answer,
                "prediction": completion,
                "score": score,
                "time": elapsed_time
            }
            with open(out_path, "a" if append else "w") as f:
                f.write(json.dumps(result) + "\n")

        except Exception as e:
            print(f"Failed to get prediction for {question}")
            print(e)
            continue

        pbar.update(1)


def main(output_path, api_url, model_name, auth_token, format_tabs, problem_file, append):
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file, append)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="API Generate Tester")
    parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
    parser.add_argument("--model_name", type=str, default="Pro/deepseek-ai/DeepSeek-R1", help="Model Name")
    parser.add_argument("--out_path", type=str, default="results/api/eval_aime.jsonl", help="Output Path")
    parser.add_argument("--auth_token", type=str, default=None, help="Auth Token")
    parser.add_argument("--format_tabs", action="store_true", help="Format Tabs")
    parser.add_argument("--problem_file", type=str, default="Maxwell-Jia/AIME_2024", help="Evalset File")
    parser.add_argument("--no_append", action="store_false", help="Set to overwrite instead of appending to the output file")
    args = parser.parse_args()
    # api_url = "https://api.siliconflow.cn/v1/chat/completions"
    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append)
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
# reference: https://github.com/declare-lab/instruct-eval/blob/main/human_eval/main.py#L35
def filter_answer(completion: str) -> str:
    # the answer is the last part of the completion; it's an integer
    # get the last line
    completion = completion.strip().split("\n")[-1]
    # handle the $\\boxed{...}$ format
    if "$\\boxed{" in completion:
        return completion.split("}")[0].split("{")[-1]
    return completion.split()[-1]
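
A short usage illustration of `filter_answer` (editor's addition): the AIME script above imports it as `from evaluation import filter_answer`, so the snippet assumes the module is importable under that name.

```python
from evaluation import filter_answer  # assumes the module above is on the path

boxed = "Some working...\nThe final answer is $\\boxed{113}$."
plain = "Step 1 ...\nStep 2 ...\nAnswer: 113"
print(filter_answer(boxed))  # -> 113 (extracted from the \boxed{} wrapper)
print(filter_answer(plain))  # -> 113 (last token of the last line)
```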
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
def instruct_prompt(prompt: str) -> str:
    return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nSolve the following math problem without any tests or explanation, only one answer surrounded by '$\\boxed{{}}$'\n{prompt}\n\n### Response:"""
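
And a quick look at what the prompt wrapper produces (editor's addition, assuming the module above is importable as `prompts`, which is how the AIME script imports it):

```python
from prompts import instruct_prompt  # assumes the module above is on the path

# Prints the instruction-style template with the problem spliced in and the
# request for a single $\boxed{}$ answer appended.
print(instruct_prompt("What is the remainder when 7^100 is divided by 5?"))
```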

ktransformers/tests/humaneval/eval_api.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def generate_text(api_url,question , model_name, stream=False, auth_token=None):
         'accept': 'application/json',
         'Content-Type': 'application/json',
         # add the API key
-        'Authorization' : 'Bearer ' + auth_token
+        'Authorization' : 'Bearer ' + auth_token if auth_token else ''
     }
     question = instruct_prompt(question)
     data = {
