
Commit a3bb30e

Authored by Zhihao Shan (zhihaoshan-google)

An experimental JAX inference framework for prototyping new ideas. (#161)

It has the following features (some of them are limited versions):

Performance:
1. Paged Attention
2. Chunked Prefill and Piggybacking Decode
3. Collective Matmul

Framework:
1. Pythonic model builder
2. JAX manual sharding
3. Interface for different hardware support
4. On-the-fly HF model conversion and deployment

Please refer to the README file for more information. More reports will be added in later commits.

Co-authored-by: Zhihao Shan <zhihaoshan@google.com>

1 parent 9a7f10b commit a3bb30e


53 files changed: +6976 -0 lines changed

experimental/jax/README.md

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
# An experimental JAX inference framework for prototyping new ideas.

## About

It has the following features (some of them are limited versions):

```
Performance:
1. Paged Attention
2. Chunked Prefill and Piggybacking Decode
3. Collective Matmul

Framework:
1. Pythonic model builder
2. JAX manual sharding
3. Interface for different hardware support
4. On-the-fly HF model conversion and deployment
```
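
As background on the first feature: in paged attention, the KV cache lives in a pool of fixed-size pages, and each request owns a page table mapping its logical token positions to physical pages, so cache memory is allocated page by page instead of per maximum sequence length. Below is a minimal sketch of that lookup in JAX; the names, shapes, and constants are illustrative assumptions, not this repo's implementation.

```
import jax.numpy as jnp

# All constants below are assumed for illustration.
PAGE_SIZE = 16    # tokens stored per page
NUM_PAGES = 1024  # pages in the shared pool
HEAD_DIM = 128    # per-head feature dimension

# One shared pool of key pages: [num_pages, page_size, head_dim].
k_pages = jnp.zeros((NUM_PAGES, PAGE_SIZE, HEAD_DIM), jnp.bfloat16)

def gather_keys(page_table: jnp.ndarray, seq_len: int) -> jnp.ndarray:
  """Gathers one request's keys from its pages.

  `page_table` holds the request's physical page ids in logical order;
  only the first `seq_len` gathered rows are valid.
  """
  pages = k_pages[page_table]         # [pages_per_req, page_size, head_dim]
  keys = pages.reshape(-1, HEAD_DIM)  # contiguous logical view of the cache
  return keys[:seq_len]               # drop the unused tail of the last page
```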

## Quick Start

So far, the experimental code only works for Llama 2 7B on a TPU v5e-8. The whole process takes less than 10 minutes if you have a Cloud TPU v5e-8 ready.

### 1. Create a Cloud TPU v5e-8 on Google Cloud:

```
gcloud alpha compute tpus queued-resources create ${QR_NAME} \
  --node-id ${NODE_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --accelerator-type v5litepod-8 \
  --runtime-version v2-alpha-tpuv5-lite
```

For more information, see the [queued resources documentation](https://cloud.google.com/tpu/docs/queued-resources).

### 2. Set up the LLM Server and serve requests:

SSH into your Cloud TPU VM first and run the following commands:

Set up a new Python env.
```
virtualenv jax-inference
source jax-inference/bin/activate
```

Clone the repo and install the dependencies.
```
git clone https://github.com/AI-Hypercomputer/JetStream.git

cd JetStream/experimental/jax

pip install -r requirements.txt
```

Log in to Hugging Face (make sure your account has permission to access `meta-llama/Llama-2-7b-chat-hf`):

```
huggingface-cli login
```

### 3. Offline Benchmarking:

Note: the current setup uses 8-way TP, which is just for experimenting and comparing with the current JetStream + MaxText numbers.

```
export PYTHONPATH=$(pwd)
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"
python inference/entrypoint/mini_offline_benchmarking.py
```

Offline benchmarking result:

This number is around `45%` better than the current MaxText + JetStream number (as of 2024/08/16) in the same setting.

```
Benchmarking result:
Total requests: 1000
Total input tokens: 218743
Total output tokens: 291740
Input token throughput: 2980.654636529649 tokens/sec
Output token throughput: 3975.332621666338 tokens/sec
```

Note: The online numbers should be even better than the current MaxText and JetStream, as the experimental framework runs prefill and decode together in one model forward pass.
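
To make that note concrete, here is a minimal sketch of piggybacking decode onto a chunked prefill step; the token budget and names are assumptions for illustration, not the framework's actual scheduler. Each forward pass consumes a fixed token budget that packs one token from every in-flight decode request together with the next chunk of a prefilling request, so decode requests never stall behind a long prefill.

```
import numpy as np

TOKEN_BUDGET = 256  # tokens per forward pass (assumed)

def build_step_batch(decode_tokens: np.ndarray,
                     prefill_tokens: np.ndarray) -> np.ndarray:
  """Packs decode tokens plus one prefill chunk into a fixed-size batch.

  `decode_tokens`: one next token per in-flight decode request.
  `prefill_tokens`: the not-yet-processed tokens of one prefilling request.
  """
  budget_left = TOKEN_BUDGET - len(decode_tokens)
  chunk = prefill_tokens[:budget_left]  # the "chunked" part of prefill
  batch = np.concatenate([decode_tokens, chunk])
  # Pad to the static budget so compiled program shapes never change.
  return np.pad(batch, (0, TOKEN_BUDGET - len(batch)))
```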

### 4. Online Serving Example:

Start the server:

```
python inference/entrypoint/run_simple_server.py &
```

Send a request:

```
curl --no-buffer -H 'Content-Type: application/json' \
    -d '{ "prompt": "Today is a good day" }' \
    -X POST \
    localhost:8000/generate
```
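
The same request from Python, as a small sketch that assumes the server streams plain-text chunks back (matching curl's `--no-buffer`):

```
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "Today is a good day"},
    stream=True,  # print tokens as they arrive
)
resp.raise_for_status()
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
  print(chunk, end="", flush=True)
```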
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
experimental/jax/inference/entrypoint/mini_offline_benchmarking.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import time
import pandas
from inference.runtime.request_type import *
from inference.runtime import offline_inference


def load_openorca_dataset_pkl():
  # Read pickle file
  current_dir = os.path.dirname(__file__)
  samples = pandas.read_pickle(
      f"{current_dir}/open_orca_gpt4_tokenized_llama.calibration_1000.pkl"
  )

  prompts = []
  outputs = []
  for _, row in samples.iterrows():
    prompts.append(row["input"])
    outputs.append(row["output"])

  return [(prompt, output) for prompt, output in zip(prompts, outputs)]


def benchmarking():
  dataset = load_openorca_dataset_pkl()

  ds = dataset[:1000]
  ds = [d[0] for d in ds]

  inference_instance = offline_inference.OfflineInference()

  start_time = time.perf_counter()
  res_list: list[Response] = inference_instance(ds)
  end_time = time.perf_counter()
  duration = end_time - start_time

  input_tokens = []
  for res in res_list:
    input_tokens = input_tokens + res.input_tokens

  output_tokens = []
  for res in res_list:
    output_tokens = output_tokens + res.generated_tokens

  num_input_tokens = len(input_tokens)
  num_output_tokens = len(output_tokens)

  print("Benchmarking result: ")
  # Hardcode the number of requests as 1000 based on the test
  # dataset.
  print(" Total requests: 1000")
  print(" Total input tokens:", num_input_tokens)
  print(" Total output tokens:", num_output_tokens)
  print(f" Input token throughput: {num_input_tokens/duration} tokens/sec")
  print(f" Output token throughput: {num_output_tokens/duration} tokens/sec")


if __name__ == "__main__":
  benchmarking()
Binary file not shown.
experimental/jax/inference/entrypoint/run_simple_server.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import uvicorn


if __name__ == "__main__":
  print("start")
  current_dir = os.path.dirname(__file__)
  parent_dir = os.path.dirname(current_dir)

  uvicorn.run(
      app_dir=f"{parent_dir}/server",
      app="simple_server:app",
      host="0.0.0.0",
      port=8000,
  )
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from .attention_ops import *
from .attention.tpu.quantization_utils import *
from .collective_matmul_ops import *
from .linear.tpu.collective_matmul import (
    prepare_rhs_for_all_gather_collective_matmul,
    prepare_rhs_for_collective_matmul_reduce_scatter,
)
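
The collective matmul imports above are the repo's own TPU kernels. As background, the technique splits a sharded matmul into per-device steps so the collective's communication overlaps with compute. Below is a minimal, self-contained sketch of the all-gather variant in plain JAX with `shard_map`; the mesh, shapes, and names are assumptions for illustration, not the repo's implementation. Instead of all-gathering the row-sharded `lhs` and then multiplying, each device multiplies the row block it currently holds by its local `rhs` shard and hands the block to the next device on a ring.

```
import functools
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ("x",))
n = mesh.shape["x"]  # static ring size

@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=(P("x", None), P(None, "x")),  # lhs row-sharded, rhs col-sharded
    out_specs=P(None, "x"),                 # output col-sharded
    check_rep=False,
)
def ag_collective_matmul(lhs_block, rhs_shard):
  """Computes lhs @ rhs while rotating lhs row blocks around the ring."""
  idx = jax.lax.axis_index("x")
  rows = lhs_block.shape[0]
  out = jnp.zeros((rows * n, rhs_shard.shape[1]), lhs_block.dtype)
  perm = [(j, (j + 1) % n) for j in range(n)]

  def step(i, carry):
    out, blk = carry
    src = (idx - i) % n  # which lhs row block this device holds at step i
    out = jax.lax.dynamic_update_slice_in_dim(
        out, blk @ rhs_shard, src * rows, axis=0
    )
    blk = jax.lax.ppermute(blk, "x", perm)  # pass the block along the ring
    return out, blk

  out, _ = jax.lax.fori_loop(0, n, step, (out, lhs_block))
  return out
```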
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from .chunked_prefill_attention import *
from .paged_attention import *
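
The chunked-prefill attention import above is the repo's kernel; one detail worth noting is the mask shape: a prefill chunk's queries attend to every token already in the KV cache and causally within the chunk itself. A minimal sketch of that mask (an assumed helper, not the repo's code):

```
import jax.numpy as jnp

def chunk_attention_mask(cache_len: int, chunk_len: int) -> jnp.ndarray:
  """Boolean mask for one prefill chunk of length `chunk_len`:
  full visibility over the `cache_len` cached tokens, causal within
  the chunk itself."""
  cached = jnp.ones((chunk_len, cache_len), jnp.bool_)
  local = jnp.tril(jnp.ones((chunk_len, chunk_len), jnp.bool_))
  return jnp.concatenate([cached, local], axis=1)
```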
