Commit 8b2ddb2
Qualcomm AI Engine Direct - Support MaskedSoftmax in static llama (#12745)
Summary:
- Add a unit test for masked softmax.
- Add amin op support.
- Add a flag `--enable_masked_softmax` to enable the masked softmax feature. It is designed to improve LLM accuracy and performance on the HTP backend. MaskedSoftmax replaces the Softmax(Add(In, Mask)) structure in the attention blocks of LLMs during backend optimization. For more details, please refer to the QNN documentation. Note that it is only supported starting from QNN 2.35.

cc: @haowhsu-quic, @winskuo-quic
1 parent bed0f9e commit 8b2ddb2
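For context, a minimal PyTorch sketch (not part of this commit; shapes and the fill value are illustrative) of the Softmax(Add(In, Mask)) structure that the optimization targets in attention blocks:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: batch 1, 1 head, 1 query token, 128 KV positions.
scores = torch.randn(1, 1, 1, 128)   # raw attention logits ("In")
mask = torch.zeros(1, 1, 1, 128)     # additive mask: 0 = attend
mask[..., 64:] = -65535.0            # large negative = masked out

# The structure targeted for replacement: Softmax(Add(In, Mask)).
attn_weights = F.softmax(scores + mask, dim=-1)

# With --enable_masked_softmax, the backend replaces this structure with
# QNN's MaskedSoftmax during lowering (QNN >= 2.35, HTP backend).
```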

12 files changed: +259 −6 lines changed

backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions

@@ -63,6 +63,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.abs.default,
         exir_ops.edge.aten.add.Tensor,
         exir_ops.edge.aten.amax.default,
+        exir_ops.edge.aten.amin.default,
         exir_ops.edge.aten.atan.default,
         exir_ops.edge.aten.bitwise_or.Tensor,
         exir_ops.edge.aten.bmm.default,

backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp

Lines changed: 18 additions & 0 deletions

@@ -7,6 +7,7 @@
  */
 #include <executorch/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h>
 #include <pybind11/pybind11.h>
+#include "QnnSdkBuildId.h"

 namespace py = pybind11;
 namespace executorch {
@@ -15,10 +16,27 @@ namespace qnn {

 using executorch::runtime::Error;

+std::string GetQnnSdkBuildId(std::string library_path) {
+  QnnImplementation qnn_loaded_backend = QnnImplementation(library_path);
+  ET_CHECK_MSG(
+      qnn_loaded_backend.Load(nullptr) == Error::Ok,
+      "Fail to load Qnn library");
+  const char* id = nullptr;
+  // Safe to call any time, backend does not have to be created.
+  Qnn_ErrorHandle_t err =
+      qnn_loaded_backend.GetQnnInterface().qnn_backend_get_build_id(&id);
+  if (err != QNN_SUCCESS || id == nullptr) {
+    throw std::runtime_error("Failed to get QNN backend build ID");
+  }
+  qnn_loaded_backend.TerminateAllBackends();
+  return std::string(id);
+}
+
 PYBIND11_MODULE(PyQnnManagerAdaptor, m) {
   // TODO: Add related documents for configurations listed below
   using namespace qnn_delegate;

+  m.def("GetQnnSdkBuildId", &GetQnnSdkBuildId);
   py::class_<QnnExecuTorchContextBinary>(m, "QnnExecuTorchContextBinary")
       .def(py::init<>());

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -10,6 +10,7 @@
     op_adaptive_avg_pool2d,
     op_add,
     op_amax,
+    op_amin,
     op_and,
     op_arange,
     op_argmax,
@@ -106,6 +107,7 @@
     op_adaptive_avg_pool2d,
     op_add,
     op_amax,
+    op_amin,
     op_and,
     op_arange,
     op_argmax,

backends/qualcomm/builders/op_amin.py

Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict, List
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
+
+from .node_visitor import NodeVisitor
+from .node_visitor_manager import register_node_visitor
+from .qnn_constants import OpReduceMin, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class AMin(NodeVisitor):
+    target = ["aten.amin.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = self.get_node(node.args[0])
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        # mean dims and keep dims
+        mean_dims = cast(List[int], node.args[1])
+        mean_dims = [
+            mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
+        ]
+        if QCOM_AXIS_ORDER in node.meta:
+            mean_dims = [
+                node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
+            ]
+        mean_dims_shape = [len(mean_dims)]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        reduce_min_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpReduceMin.op_name,
+        )
+        reduce_min_op.AddInputTensors([input_tensor_wrapper])
+        reduce_min_op.AddOutputTensors([output_tensor_wrapper])
+        reduce_min_op.AddTensorParam(
+            OpReduceMin.param_axes,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(mean_dims_shape),
+            mean_dims_shape,
+            np.array(mean_dims, dtype=np.uint32),
+            True,
+        )
+        if len(node.args) > 2:
+            keep_dims = cast(bool, node.args[2])
+            reduce_min_op.AddScalarParam(
+                OpReduceMin.param_keep_dims,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+                {QCOM_DATA: keep_dims},
+            )
+
+        return reduce_min_op
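As a quick reference, a small eager-mode sketch of the op this builder lowers, mirroring the AMin test module added in models.py (sizes are illustrative):

```python
import torch

x = torch.randn(4, 4)

# aten.amin.default with explicit dims and keepdim, as exercised by the new
# test_qnn_backend_amin cases; op_amin.py maps it to QNN ReduceMin with the
# `axes` tensor param and the `keep_dims` scalar param.
out = torch.amin(x, dim=1, keepdim=True)
print(out.shape)  # torch.Size([4, 1])
```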

backends/qualcomm/quantizer/annotators.py

Lines changed: 5 additions & 0 deletions

@@ -217,6 +217,11 @@ def annotate_argmax(node: Node, quantization_config: QuantizationConfig) -> None
     annotate_single_in(node, quantization_config)


+@register_annotator([torch.ops.aten.amin.default])
+def annotate_amin(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_binary(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.argmin.default])
 def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in(node, quantization_config)

backends/qualcomm/runtime/backends/QnnFunctionInterface.h

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ class QnnInterface {

   // --------- QnnBackend ---------
   DEFINE_SHIM_FUNCTION_INTERFACE(backend_create, backendCreate);
+  DEFINE_SHIM_FUNCTION_INTERFACE(backend_get_build_id, backendGetBuildId);
   DEFINE_SHIM_FUNCTION_INTERFACE(backend_free, backendFree);
   DEFINE_SHIM_FUNCTION_INTERFACE(
       backend_register_op_package,

backends/qualcomm/tests/models.py

Lines changed: 21 additions & 0 deletions

@@ -102,6 +102,16 @@ def forward(self, x):
         return torch.amax(x, dim=self.dim, keepdim=self.keepdim)


+class AMin(torch.nn.Module):
+    def __init__(self, dim=None, keepdim=False):
+        super().__init__()
+        self.dim = dim
+        self.keepdim = keepdim
+
+    def forward(self, x):
+        return torch.amin(x, dim=self.dim, keepdim=self.keepdim)
+
+
 class Arange(torch.nn.Module):
     def __init__(self, start, end, step, dtype):
         super().__init__()
@@ -1155,6 +1165,17 @@ def forward(self, attn_mask):
         )


+class MaskedSoftmax(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, attention_mask, input):
+        attn_weights = torch.where(
+            attention_mask == 0, input, torch.amin(input, dim=3, keepdim=True) + (-20)
+        )
+        return torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+
+
 class MaxDim(torch.nn.Module):
     def __init__(self):
         super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 68 additions & 4 deletions

@@ -46,17 +46,14 @@
     from_context_binary,
     generate_htp_compiler_spec,
     generate_qnn_executorch_compiler_spec,
+    is_qnn_sdk_version_less_than,
     PyQnnManagerAdaptor,
     rewrite_prepared_observer,
     skip_annotation,
     to_edge_transform_and_lower_to_qnn,
     update_spill_fill_size,
 )

-from executorch.examples.models.llama.llama_transformer import MOEFeedForward
-
-from executorch.examples.models.llama.model_args import ModelArgs
-
 from executorch.examples.qualcomm.utils import (
     make_quantizer,
     setup_common_args_and_variables,
@@ -136,6 +133,13 @@ def test_qnn_backend_amax(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_amin(self):
+        modules = [AMin(dim=1, keepdim=False), AMin(dim=1, keepdim=True)]  # noqa: F405
+        sample_input = (torch.randn(4, 4),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_any(self):
         modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(3, 3, 3) > 0,)
@@ -1227,6 +1231,9 @@ def test_qnn_backend_lift_add_tensor(self):

     @unittest.skip("Fail because of bad accuracy")
     def test_qnn_backend_moe_feed_forward(self):
+        from executorch.examples.models.llama.llama_transformer import MOEFeedForward
+        from executorch.examples.models.llama.model_args import ModelArgs
+
         args = ModelArgs()
         args.dim = 32
         args.n_heads = 8
@@ -1421,6 +1428,14 @@ def test_qnn_backend_amax(self):
                 module = self.get_qdq_module(module, sample_input)
                 self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_amin(self):
+        modules = [AMin(dim=1, keepdim=False), AMin(dim=1, keepdim=True)]  # noqa: F405
+        sample_input = (torch.randn(4, 4),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_any(self):
         modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(3, 3, 3) > 0,)
@@ -2643,8 +2658,57 @@ def test_qnn_backend_einsum_outer_product_relu(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    @unittest.skipIf(is_qnn_sdk_version_less_than("2.35"), "UT pass after QNN 2.35")
+    def test_qnn_backend_masked_softmax(self):
+        if self.enable_x86_64:
+            self.skipTest(
+                "At the moment, testing is only being conducted on the device."
+            )
+        module = MaskedSoftmax()  # noqa: F405
+        kv_arange = torch.arange(128)
+        reshaped_cache_position = torch.tensor([[0]])
+
+        # Simplest and most efficient way to obtain a causal mask
+        causal_mask = kv_arange <= reshaped_cache_position
+        atten_mask = torch.full((1, 128), torch.tensor(-65535.0))
+        atten_mask = atten_mask.masked_fill(causal_mask, 0)
+        atten_mask = atten_mask[None, None, :, :].expand(1, -1, -1, -1)
+        sample_input = (atten_mask, torch.randn([1, 1, 1, 128]))
+        # Masked softmax is only support in quantized model
+        module = self.get_qdq_module(
+            module, sample_input, quant_dtype=QuantDtype.use_16a8w
+        )
+        backend_options = generate_htp_compiler_spec(use_fp16=False)
+        compiler_spec = generate_qnn_executorch_compiler_spec(
+            soc_model=self.chipset_table[TestQNN.model],
+            backend_options=backend_options,
+            optrace=True,
+        )
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
+                module, sample_input, compiler_spec
+            ).to_executorch()
+            pte_path = f"{tmp_dir}/model.pte"
+            with open(pte_path, "wb") as f:
+                edge_prog_mgr.write_to_file(f)
+            adb = self.get_adb_tool(pte_path)
+            binaries_trace = generate_optrace(
+                tmp_dir, self.chipset_table[self.model], adb, pte_path, sample_input
+            )
+            has_masked_softmax = False
+            for _, (_, qhas) in binaries_trace.items():
+                with open(qhas, "r") as qhas_file:
+                    qhas_data = json.load(qhas_file)
+                    for row in qhas_data["data"]["htp_op_types"]["data"]:
+                        if "MaskedSoftmax" in row["op"]:
+                            has_masked_softmax = True
+            self.assertTrue(has_masked_softmax)
+
     @unittest.skip("UT pass before QNN 2.26, segfault during partitioner")
     def test_qnn_backend_moe_feed_forward(self):
+        from executorch.examples.models.llama.llama_transformer import MOEFeedForward
+        from executorch.examples.models.llama.model_args import ModelArgs
+
         args = ModelArgs()
         args.dim = 32
         args.n_heads = 8

backends/qualcomm/utils/utils.py

Lines changed: 27 additions & 1 deletion

@@ -4,14 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import operator
+import os
+import re
 import warnings
 from collections import defaultdict, OrderedDict
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor

 import executorch.exir as exir
-
 import torch

 from executorch.backends.qualcomm._passes import AnnotateStack, AnnotateUnbind
@@ -1167,3 +1168,28 @@ def rewrite_prepared_observer(
             continue
         for target_name in module_name_list[old_module]:
             setattr(graph_module, target_name, new_observer)
+
+
+def get_sdk_build_id():
+    htp_library_path = (
+        os.environ.get("QNN_SDK_ROOT", None) + "/lib/x86_64-linux-clang/libQnnHtp.so"
+    )
+    # The GetQnnSdkBuildId API can be used without needing to create a backend first, so it works regardless of which backend is used.
+    sdk_build_id = PyQnnManagerAdaptor.GetQnnSdkBuildId(htp_library_path)
+    return sdk_build_id
+
+
+def is_qnn_sdk_version_less_than(target_version):
+    current_version = get_sdk_build_id()
+
+    match = re.search(r"v(\d+)\.(\d+)", current_version)
+    if match:
+        current_major, current_minor = map(int, match.groups()[:2])
+    else:
+        raise ValueError(
+            f"Failed to get current major and minor version from QNN sdk Build id {current_version}"
+        )
+
+    target_major, target_minor = map(int, target_version.split(".")[:2])
+
+    return current_major == target_major and current_minor < target_minor
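A short sketch of how this helper is intended to be consumed, mirroring the skipIf gate added in test_qnn_delegate.py (it assumes QNN_SDK_ROOT is set so libQnnHtp.so can be located; the test class and method names here are hypothetical):

```python
import unittest

from executorch.backends.qualcomm.utils.utils import is_qnn_sdk_version_less_than


class ExampleGatedTest(unittest.TestCase):
    # MaskedSoftmax is only available starting from QNN 2.35, so skip older SDKs.
    @unittest.skipIf(
        is_qnn_sdk_version_less_than("2.35"), "MaskedSoftmax requires QNN >= 2.35"
    )
    def test_masked_softmax_feature(self):
        ...  # hypothetical test body
```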

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 9 additions & 0 deletions

@@ -124,12 +124,16 @@ On the other hand, if you already have a pre-compiled .pte model, you can perfor
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
 ```

+#### KV Cache Updater
+
 You can select the KV Cache update mechanism at runtime by setting the `KV_UPDATER` variable to either "shift_pointer" or "smart_mask". By default, it is set to "smart_mask".
 `KV_UPDATER` = "shift_pointer"
 ```bash
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
 ```

+#### Lookahead Decoding Mode
+
 You can choose the lookahead mode to enhance decoding speed. To use this mode, you need to specify the following parameters:
 - `--ngram` (N-gram size): Represents the size of the n-grams used in the lookahead process.
 - `--window` (window size): Determines how many future tokens the algorithm attempts to predict in each step.
@@ -140,3 +144,8 @@ For more details, please refer to the paper ["Break the Sequential Dependency of
 ```bash
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
 ```
+
+#### Masked Softmax
+
+You can enable the MaskedSoftmax feature by providing the flag `--enable_masked_softmax`. It is designed to improve LLM accuracy and performance on the HTP backend. MaskedSoftmax replaces the Softmax(Add(In, Mask)) structure in the attention blocks of LLMs during backend optimization. For more details, please refer to the QNN documentation.
+Note that it is only supported starting from QNN 2.35.
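For context, MaskedSoftmax targets the usual additive attention mask, where attended positions hold 0 and masked positions hold a large negative value. A minimal sketch of building such a mask, modeled on the unit test in this commit (the 128-token KV length and -65535.0 fill are illustrative):

```python
import torch

kv_len = 128
kv_arange = torch.arange(kv_len)
cache_position = torch.tensor([[0]])  # current decode position

# 0 where attention is allowed, a large negative value where it is masked,
# so Softmax(Add(scores, mask)) suppresses the masked positions.
causal_mask = kv_arange <= cache_position
atten_mask = torch.full((1, kv_len), torch.tensor(-65535.0))
atten_mask = atten_mask.masked_fill(causal_mask, 0)
atten_mask = atten_mask[None, None, :, :]  # shape (1, 1, 1, kv_len)
```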
